# Source code for niftynet.layer.layer_util
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
import numpy as np
def check_spatial_dims(input_tensor, criteria):
    """
    Validate each spatial dim of ``input_tensor`` against ``criteria``.

    ``criteria`` can be a lambda function,
    e.g. ``lambda x: x > 10`` checks whether each dim is greater than 10.

    :param input_tensor: tensor with shape ``[batch, spatial..., channel]``
    :param criteria: callable applied to each spatial dim size
    :return: True if all spatial dims satisfy ``criteria`` (also True
        when the shape is not fully defined and checking is skipped)
    :raises ValueError: if any spatial dim fails ``criteria``
    """
    input_shape = input_tensor.shape
    if not input_shape.is_fully_defined():
        # skip checking if the input has dynamic shapes
        return True
    input_shape.with_rank_at_least(3)
    # spatial dims exclude the first (batch) and last (channel) dims
    spatial_dims = input_shape[1:-1].as_list()
    all_dims_satisfied = np.all([criteria(x) for x in spatial_dims])
    if not all_dims_satisfied:
        # deferred import: inspect is only needed on the error path
        import inspect
        raise ValueError("input tensor's spatial dimensionality not"
                         " compatible, please tune "
                         "the input window sizes:\n{}".format(
                             inspect.getsource(criteria)))
    return all_dims_satisfied
def infer_spatial_rank(input_tensor):
    """
    Infer the number of spatial dims of a tensor.

    e.g. given an input tensor [Batch, X, Y, Z, Feature] the spatial
    rank is 3: the batch and channel dims are excluded from the count.
    """
    shape = input_tensor.shape
    shape.with_rank_at_least(3)
    return int(shape.ndims) - 2
def trivial_kernel(kernel_shape):
    """
    Generate an all-zero kernel with a single 1 at its spatial centre.

    e.g. trivial_kernel((3, 3, 1, 1,)) returns a kernel of::

        [[[[0]], [[0]], [[0]]],
         [[[0]], [[1]], [[0]]],
         [[[0]], [[0]], [[0]]]]

    kernel_shape[-1] and kernel_shape[-2] should be 1, so that it operates
    on the spatial dims only. However, there is no exact spatial centre
    if np.any((kernel_shape % 2) == 0). This is fine in many cases
    as np.sum(trivial_kernel(kernel_shape)) == 1
    """
    assert kernel_shape[-1] == 1
    assert kernel_shape[-2] == 1
    # assert np.all((kernel_shape % 2) == 1)
    n_elements = np.prod(kernel_shape)
    flat = np.zeros(n_elements)
    # the middle element of the flattened kernel is the spatial centre
    flat[n_elements // 2] = 1
    return flat.reshape(kernel_shape)
def expand_spatial_params(input_param, spatial_rank, param_type=int):
    """
    Expand an input parameter to one value per spatial dim.

    e.g., ``kernel_size=3`` is converted to ``kernel_size=(3, 3, 3)``
    for 3D images (when ``spatial_rank == 3``).  Sequence inputs are
    flattened, cast, and truncated to the first ``spatial_rank`` items.

    :param input_param: a scalar or a sequence of parameter values
    :param spatial_rank: number of spatial dims to expand to
    :param param_type: element type of the output, ``int`` or ``float``
    :return: a tuple of length ``spatial_rank``
    """
    spatial_rank = int(spatial_rank)
    try:
        # scalar input: replicate it across all spatial dims
        if param_type == int:
            input_param = int(input_param)
        else:
            input_param = float(input_param)
        return (input_param,) * spatial_rank
    except (ValueError, TypeError):
        pass
    try:
        # sequence input: flatten and cast each element.
        # builtin int/float are used here -- the np.int/np.float
        # aliases were removed in NumPy 1.24 and behave identically.
        if param_type == int:
            input_param = \
                np.asarray(input_param).flatten().astype(int).tolist()
        else:
            input_param = \
                np.asarray(input_param).flatten().astype(float).tolist()
    except (ValueError, TypeError):
        # skip type casting if it's a TF tensor
        pass
    assert len(input_param) >= spatial_rank, \
        'param length should be at least have the length of spatial rank'
    return tuple(input_param[:spatial_rank])
# class RequireKeywords(object):
# def __init__(self, *list_of_keys):
# self.keys = list_of_keys
#
# def __call__(self, f):
# def wrapped(*args, **kwargs):
# for key in self.keys:
# if key not in kwargs:
# raise ValueError("{}: specify keywords: '{}'".format(
# args[0].layer_scope().name, self.keys))
# return f(*args, **kwargs)
# return wrapped
def check_divisible_channels(input_tensor, n_channel_splits):
    """
    Check if the number of channels (last dim) of the input tensor
    is divisible by ``n_channel_splits``. If True, returns
    ``n_input_channels / n_channel_splits``, raises AssertionError otherwise

    :param input_tensor: tensor whose last dim is the channel dim
    :param n_channel_splits: positive number of channel groups
    :return: n_input_channels / n_channel_splits (as an int)
    """
    n_input_channels = int(input_tensor.shape.as_list()[-1])
    n_channel_splits = int(n_channel_splits)
    assert n_channel_splits > 0 and n_input_channels % n_channel_splits == 0, \
        "Number of feature channels should be divisible by " \
        "n_channel_splits {}, so that given an input with n_input_channels, " \
        "the output tensor will have " \
        "n_input_channels / n_channel_splits.".format(n_channel_splits)
    # floor division: exact here (divisibility asserted above) and keeps
    # the result an int, as expected for a channel count (Python 3's
    # true division would return a float such as 2.0)
    return n_input_channels // n_channel_splits