# -*- coding: utf-8 -*-
"""
Resampler layer initially implemented in
https://github.com/niftk/NiftyNet/blob/v0.2.0.post1/niftynet/layer/spatial_transformer.py
"""
from __future__ import absolute_import, division, print_function
import tensorflow as tf
from niftynet.layer.base_layer import Layer
from niftynet.layer.layer_util import infer_spatial_rank
from niftynet.utilities.util_common import look_up_operations
COORDINATES_TYPE = tf.int32
EPS = 1e-6
[docs]class ResamplerLayer(Layer):
"""
resample inputs according to ``sample_coords``
"""
def __init__(self,
interpolation="LINEAR",
boundary="REPLICATE",
name="resampler"):
super(ResamplerLayer, self).__init__(name=name)
self.boundary = boundary.upper()
self.boundary_func = look_up_operations(
self.boundary, SUPPORTED_BOUNDARY)
self.interpolation = look_up_operations(
interpolation.upper(), SUPPORTED_INTERPOLATION)
if self.boundary == 'ZERO' and self.interpolation == 'BSPLINE':
tf.logging.fatal('Zero padding is not supported for BSPLINE mode')
raise NotImplementedError
if self.boundary == 'ZERO' and self.interpolation == 'IDW':
tf.logging.warning('Zero padding is not supported for IDW mode')
# raise NotImplementedError
[docs] def layer_op(self, inputs, sample_coords):
"""
This layer resamples 2D or 3D data given the coordinates.
In terms of 3D inputs,
when the shape of ``inputs`` is ``[batch, x, y, z, num_channels]``,
the shape of ``sample_coords`` can be
``[1, d0, d1, ..., 3]`` or ``[batch, d0, d1, ..., 3]``.
The output shape would be ``[batch, d0, d1, ..., num_channels]``.
Similarly, in 2D,
when the shape of ``inputs`` is ``[batch, x, y, num_channels]``,
the shape of ``sample_coords`` can be
``[1, d0, d1, ..., 2]`` or ``[batch, d0, d1, ..., 2]``.
The output shape would be ``[batch, d0, d1, ... num_channels]``
(If the shape of ``inputs`` is not fully specified, ``sample_coords``
must be checked before using this function, to make sure the
coordinates are pointing to locations within the inputs.)
(Resampling 2D inputs is implemented by calling
``tf.contrib.resampler.resampler``. The interpretaion of coordinates is
different in between this function and
``tf.contrib.resampler.resampler``:
using ``self.layer_op(inputs, sample_coords)`` for 2D data
is equivalent to (apart from the batch size broadcasting feature)::
tf.contrib.resampler.resampler(
tf.transpose(inputs, [0, 2, 1, 3]), sample_coords)
(No gradient is computed for ``NEAREST`` method, and
some of the padding modes.)
"""
# check the input dims
try:
batch_inputs = int(inputs.shape[0])
batch_sample_coords = int(sample_coords.shape[0])
except (TypeError, ValueError):
tf.logging.fatal('Unknown input shape, at least batch size '
'needs to be specified.')
raise
if batch_inputs != batch_sample_coords and batch_sample_coords > 1:
tf.logging.fatal(
'\nOnly the following two cases are currently supported:\n'
' - batch size of inputs == batch size of sample_coords\n'
' - batch size of sample_coords == 1\n'
'In the second case, sample_coords will be applied to each of '
'the batch component of the inputs.')
raise ValueError
# input_spatial_rank = infer_spatial_rank(inputs)
# if input_spatial_rank != 2 and input_spatial_rank != 3:
# tf.logging.fatal('Only 2D or 3D inputs are supported.')
# raise ValueError
try:
coords_n_dim = int(sample_coords.shape[-1])
except (TypeError, ValueError):
tf.logging.fatal(
'The last dim of the coordinates must have 2 or 3 elements.')
raise
if infer_spatial_rank(inputs) != coords_n_dim:
tf.logging.fatal(
'sample_coords.shape[-1] must be the same as the spatial rank '
'of the inputs.')
raise ValueError
# currently converting everything to floats
if inputs.dtype not in SUPPORTED_INPUT_DTYPE:
inputs = tf.to_float(inputs)
if sample_coords.dtype not in SUPPORTED_INPUT_DTYPE:
sample_coords = tf.to_float(sample_coords)
if self.interpolation == 'LINEAR':
return self._resample_linear(inputs, sample_coords)
if self.interpolation == 'NEAREST':
return self._resample_nearest(inputs, sample_coords)
if self.interpolation == 'BSPLINE':
return self._resample_bspline(inputs, sample_coords)
if self.interpolation == 'IDW':
return self._resample_inv_dst_weighting(inputs, sample_coords)
tf.logging.fatal('interpolation method not implmented')
raise NotImplementedError
def _resample_nearest(self, inputs, sample_coords):
# This is forward only as no gradient for tf.round
# read input shape
in_size = inputs.shape
partial_shape = not in_size.is_fully_defined()
in_spatial_size = None
try:
batch_size = int(in_size[0])
n_coords = int(sample_coords.shape[0])
except (TypeError, AssertionError, ValueError):
tf.logging.fatal('Unknown input shape, at least batch size '
'and rank of the inputs are required.')
raise
# quantise coordinates
spatial_coords = tf.round(sample_coords)
if not partial_shape:
in_spatial_size = in_size.as_list()[1:-1]
spatial_coords = self.boundary_func(
spatial_coords, in_spatial_size)
spatial_coords = tf.cast(spatial_coords, COORDINATES_TYPE)
if batch_size == n_coords:
batch_inputs = tf.unstack(inputs)
batch_coords = tf.unstack(spatial_coords)
gathered_image = [tf.gather_nd(img, coord) for (img, coord) in
zip(batch_inputs, batch_coords)]
elif n_coords == 1 and batch_size > 1:
gathered_image = [tf.gather_nd(img, spatial_coords[0])
for img in tf.unstack(inputs)]
else:
raise NotImplementedError
output = tf.stack(gathered_image, axis=0)
if self.boundary == 'ZERO' and in_spatial_size:
scale = 1. / (tf.constant(in_spatial_size, dtype=tf.float32) - 1.)
mask = tf.logical_and(
tf.reduce_all(sample_coords > 0, -1, True),
tf.reduce_all(scale * sample_coords < 1, -1, True))
return output * tf.to_float(mask)
return output
def _resample_linear(self, inputs, sample_coords):
"""
Bilinear or trilinear resampling.
:param inputs:
:param sample_coords:
:return:
"""
# read input shape
in_size = inputs.shape
partial_shape = not in_size.is_fully_defined()
try:
batch_size = int(in_size[0])
n_coords = int(sample_coords.shape[0])
in_spatial_rank = infer_spatial_rank(inputs)
in_spatial_size = \
None if partial_shape else in_size.as_list()[1:-1]
except (TypeError, AssertionError, ValueError):
tf.logging.fatal('Unknown input shape, at least batch size '
'and rank of the inputs are required.')
raise
# read output shape
out_spatial_rank = infer_spatial_rank(sample_coords)
out_spatial_size = sample_coords.shape.as_list()[1:-1]
if in_spatial_rank == 2 and self.boundary == 'ZERO':
# calling TF's resampler
inputs = tf.transpose(inputs, [0, 2, 1, 3])
if batch_size == n_coords:
return tf.contrib.resampler.resampler(inputs, sample_coords)
outputs = [
tf.contrib.resampler.resampler(
tf.expand_dims(img), sample_coords)
for img in tf.unstack(inputs)]
return tf.concat(outputs, axis=0)
xy = tf.unstack(sample_coords, axis=-1)
base_coords = [tf.floor(coords) for coords in xy]
if partial_shape:
# if input shape is not defined, unable to compute
# boundary elements
floor_coords = [coord for coord in base_coords]
ceil_coords = [coord + 1.0 for coord in base_coords]
else:
floor_coords = [self.boundary_func(x, in_spatial_size[idx])
for (idx, x) in enumerate(base_coords)]
ceil_coords = [self.boundary_func(x + 1.0, in_spatial_size[idx])
for (idx, x) in enumerate(base_coords)]
if self.boundary == 'ZERO':
weight_0 = [tf.expand_dims(x - i, -1)
for (x, i) in zip(xy, floor_coords)]
weight_1 = [tf.expand_dims(i - x, -1)
for (x, i) in zip(xy, ceil_coords)]
else:
weight_0 = [tf.expand_dims(x - i, -1)
for (x, i) in zip(xy, base_coords)]
weight_1 = [1.0 - w for w in weight_0]
sc = (tf.cast(floor_coords, COORDINATES_TYPE),
tf.cast(ceil_coords, COORDINATES_TYPE))
if n_coords == 1 and batch_size > 1:
# fetch neighbours with the same coordinates across the input batch
inputs = tf.unstack(inputs)
def _get_knot(bc):
coord = [sc[c][i] for i, c in enumerate(bc)]
coord = tf.stack(coord, axis=-1)
batch_samples = [tf.gather_nd(img, coord) for img in inputs]
batch_samples = tf.concat(batch_samples, axis=0)
return batch_samples
elif n_coords == batch_size:
batch_ids = tf.reshape(
tf.range(batch_size), [batch_size] + [1] * out_spatial_rank)
batch_ids = tf.tile(batch_ids, [1] + out_spatial_size)
def _get_knot(bc):
coord = [batch_ids] + [sc[c][i] for i, c in enumerate(bc)]
coord = tf.stack(coord, axis=-1)
return tf.gather_nd(inputs, coord)
else:
raise NotImplementedError
def _pyramid_combination(samples, w_0, w_1):
# the case where n_coords = 1 and batch_size > 1 is handled by
# shape broadcasting
if len(w_0) == 1:
return samples[0] * w_1[0] + samples[1] * w_0[0]
f_0 = _pyramid_combination(samples[::2], w_0[:-1], w_1[:-1])
f_1 = _pyramid_combination(samples[1::2], w_0[:-1], w_1[:-1])
return f_0 * w_1[-1] + f_1 * w_0[-1]
binary_neighbour_ids = _binary_neighbour_ids(in_spatial_rank)
samples = [_get_knot(bc) for bc in binary_neighbour_ids]
return _pyramid_combination(samples, weight_0, weight_1)
def _resample_bspline(self, inputs, sample_coords):
assert inputs.shape.is_fully_defined(), \
"input shape should be fully defined for bspline interpolation"
in_size = inputs.shape.as_list()
batch_size = in_size[0]
in_spatial_size = in_size[1:-1]
in_spatial_rank = infer_spatial_rank(inputs)
out_spatial_rank = infer_spatial_rank(sample_coords)
if in_spatial_rank == 2:
raise NotImplementedError(
'bspline interpolation not implemented for 2d yet')
assert batch_size == int(sample_coords.get_shape()[0])
floor_coords = tf.floor(sample_coords)
# Compute voxels to use for interpolation
grid = tf.meshgrid([-1., 0., 1., 2.],
[-1., 0., 1., 2.],
[-1., 0., 1., 2.],
indexing='ij')
offset_shape = [1, -1] + [1] * out_spatial_rank + [in_spatial_rank]
offsets = tf.reshape(tf.stack(grid, 3), offset_shape)
spatial_coords = offsets + tf.expand_dims(floor_coords, 1)
spatial_coords = self.boundary_func(spatial_coords, in_spatial_size)
spatial_coords = tf.cast(spatial_coords, COORDINATES_TYPE)
knot_size = spatial_coords.shape.as_list()
# Compute weights for each voxel
def build_coef(u, d):
coeff_list = [tf.pow(1 - u, 3),
3 * tf.pow(u, 3) - 6 * tf.pow(u, 2) + 4,
-3 * tf.pow(u, 3) + 3 * tf.pow(u, 2) + 3 * u + 1,
tf.pow(u, 3)]
return tf.concat(coeff_list, d) / 6
weight = tf.reshape(sample_coords - floor_coords, [batch_size, -1, 3])
coef_shape = [batch_size, 1, 1, 1, -1]
Bu = build_coef(tf.reshape(weight[:, :, 0], coef_shape), 1)
Bv = build_coef(tf.reshape(weight[:, :, 1], coef_shape), 2)
Bw = build_coef(tf.reshape(weight[:, :, 2], coef_shape), 3)
all_weights = tf.reshape(Bu * Bv * Bw,
[batch_size] + knot_size[1:-1] + [1])
# Gather voxel values and compute weighted sum
batch_coords = tf.reshape(
tf.range(batch_size), [batch_size] + [1] * (len(knot_size) - 1))
batch_coords = tf.tile(batch_coords, [1] + knot_size[1:-1] + [1])
raw_samples = tf.gather_nd(
inputs, tf.concat([batch_coords, spatial_coords], -1))
return tf.reduce_sum(all_weights * raw_samples, reduction_indices=1)
def _resample_inv_dst_weighting(self, inputs, sample_coords):
# inverse distance weighting using 2^(sptial_rank) neighbours
in_size = inputs.shape
partial_shape = not in_size.is_fully_defined()
try:
batch_size = int(in_size[0])
n_coords = int(sample_coords.shape[0])
in_spatial_rank = infer_spatial_rank(inputs)
in_spatial_size = \
None if partial_shape else in_size.as_list()[1:-1]
except (TypeError, AssertionError, ValueError):
tf.logging.fatal('Unknown input shape, at least batch size '
'and rank of the inputs are required.')
raise
out_rank = len(sample_coords.get_shape())
binary_neighbour_ids = _binary_neighbour_ids(in_spatial_rank)
weight_id = [[[c, i] for i, c in enumerate(bc)]
for bc in binary_neighbour_ids]
sample_coords_shape = [out_rank - 1, 0] + list(range(1, out_rank - 1))
sample_coords = tf.transpose(sample_coords, sample_coords_shape)
if partial_shape or in_spatial_size is None:
all_coords_f = tf.stack(
[tf.floor(sample_coords), tf.ceil(sample_coords)])
else:
# broadcasting input spatial size for boundary functions
expanded_spatial_size = \
[len(in_spatial_size)] + [1] * (out_rank - 1)
b_size = tf.reshape(in_spatial_size, expanded_spatial_size)
# find floor and ceil coordinates
all_coords_f = tf.stack([
self.boundary_func(tf.floor(sample_coords), b_size),
self.boundary_func(tf.ceil(sample_coords), b_size)])
# find N weights associated to each output point
diff = tf.stack(
[tf.squared_difference(sample_coords - EPS, all_coords_f[0]),
tf.squared_difference(sample_coords + EPS, all_coords_f[1])])
# gather_nd for both matrices, the same as:
# point_weights = tf.gather_nd(diff, weight_id)
# knots_id = tf.gather_nd(all_coords_f, weight_id)
n_val = tf.gather_nd(
tf.stack([diff, all_coords_f], axis=-1), weight_id)
n_val = tf.unstack(n_val, axis=-1)
point_weights, knots_id = n_val[0], n_val[1]
# inverse distance weighting
# sum_i (w_i*p_i/(sum_j w_j)) where w_i = 1/((p-p_i)^2)
# point_weights has the shape:100
# `[N, input_rank, b, sp_dim_0, ..., sp_dim_K]`
# where:
# `N` is 2**source data spatial rank
# `b` is batch size,
# `sp_dim_0` is the output spatial output 0,
#
# `point_weights` represents (p - p_i)^2
# with i= 0...2**source_rank neighbours
# (to do: these operations could be refactored as a resampling kernel)
point_weights = tf.reduce_sum(point_weights, axis=1)
# skip this as power = 2.0:
# self.power = 2.0
# point_weights = tf.pow(point_weights, self.power / 2.0)
point_weights = tf.reciprocal(point_weights)
point_weights = point_weights / tf.reduce_sum(point_weights, axis=0)
# find N neighbours associated to each output point
# knots_shape = tf.concat([[0], tf.range(2, out_rank + 1), [1]], 0)
knots_shape = [0] + list(range(2, out_rank + 1)) + [1]
knots_id = tf.cast(knots_id, COORDINATES_TYPE)
knots_id = tf.transpose(knots_id, knots_shape)
# get values of N neighbours
batch_inputs = tf.unstack(inputs, axis=0)
batch_knots = tf.unstack(knots_id, axis=1)
if batch_size == n_coords:
samples = [tf.gather_nd(img, knot)
for (img, knot) in zip(batch_inputs, batch_knots)]
elif n_coords == 1 and batch_size > 1:
samples = [tf.gather_nd(img, batch_knots[0])
for img in batch_inputs]
else:
raise NotImplementedError
samples = tf.stack(samples, axis=1)
# weighted average over N neighbours
return tf.reduce_sum(
samples * tf.expand_dims(point_weights, axis=-1), axis=0)
def _boundary_replicate(sample_coords, input_size):
sample_coords, input_size = _param_type_and_shape(sample_coords, input_size)
return tf.maximum(tf.minimum(sample_coords, input_size - 1), 0)
def _boundary_circular(sample_coords, input_size):
sample_coords, input_size = _param_type_and_shape(sample_coords, input_size)
return tf.mod(tf.mod(sample_coords, input_size) + input_size, input_size)
def _boundary_symmetric(sample_coords, input_size):
sample_coords, input_size = _param_type_and_shape(sample_coords, input_size)
circular_size = input_size + input_size - 2
return (input_size - 1) - tf.abs(
(input_size - 1) - _boundary_circular(sample_coords, circular_size))
def _param_type_and_shape(sample_coords, input_size):
# sample_coords = tf.cast(sample_coords, COORDINATES_TYPE)
try:
# if input_size is a numpy array
input_size = tf.constant(input_size, dtype=sample_coords.dtype)
except (TypeError, AttributeError):
pass
try:
# if input_size is a TF Tensor
input_size = tf.cast(input_size, dtype=sample_coords.dtype)
except (TypeError, AttributeError):
pass
return sample_coords, input_size
def _binary_neighbour_ids(spatial_rank):
"""
returns combinatorial binary indices::
1-D: [[0], [1]]
2-D: [[0, 0], [0, 1], [1, 0], [1, 1]]
3-D: [[0, 0, 0], [0, 0, 1], [0, 1, 0],
[0, 1, 1], [1, 0, 0], [1, 0, 1],
[1, 1, 0], [1, 1, 1]]
"""
return [[int(c) for c in format(i, '0%ib' % spatial_rank)]
for i in range(2 ** spatial_rank)]
try: # Some tf versions have this defined already
@tf.RegisterGradient('FloorMod')
def _floormod_grad(op, grad):
return [None, None]
except:
pass
SUPPORTED_INTERPOLATION = {'BSPLINE', 'LINEAR', 'NEAREST', 'IDW'}
SUPPORTED_BOUNDARY = {
'ZERO': _boundary_replicate,
'REPLICATE': _boundary_replicate,
'CIRCULAR': _boundary_circular,
'SYMMETRIC': _boundary_symmetric}
SUPPORTED_INPUT_DTYPE = {tf.float32}