Source code for niftynet.layer.resampler

# -*- coding: utf-8 -*-
"""
Resampler layer initially implemented in
https://cmiclab.cs.ucl.ac.uk/CMIC/NiftyNet/blob/v0.2.0.post1/niftynet/layer/spatial_transformer.py
"""
from __future__ import absolute_import, division, print_function

import tensorflow as tf

from niftynet.layer.base_layer import Layer
from niftynet.layer.layer_util import infer_spatial_rank
from niftynet.utilities.util_common import look_up_operations

COORDINATES_TYPE = tf.int32
EPS = 1e-6


class ResamplerLayer(Layer):
    """
    resample inputs according to ``sample_coords``
    """

    def __init__(self,
                 interpolation="LINEAR",
                 boundary="ZERO",
                 name="resampler"):
        super(ResamplerLayer, self).__init__(name=name)
        self.boundary = boundary.upper()
        self.boundary_func = look_up_operations(
            self.boundary, SUPPORTED_BOUNDARY)
        self.interpolation = look_up_operations(
            interpolation.upper(), SUPPORTED_INTERPOLATION)

        if self.boundary == 'ZERO' and self.interpolation == 'BSPLINE':
            tf.logging.fatal('Zero padding is not supported for BSPLINE mode')
            raise NotImplementedError

        if self.boundary == 'ZERO' and self.interpolation == 'IDW':
            tf.logging.fatal('Zero padding is not supported for IDW mode')
            raise NotImplementedError

    def layer_op(self, inputs, sample_coords):
        if inputs.dtype not in SUPPORTED_INPUT_DTYPE:
            # tf.logging.warning('input datatype should be in %s ',
            #                    SUPPORTED_INPUT_DTYPE)
            # raise TypeError
            inputs = tf.to_float(inputs)
        if sample_coords.dtype not in SUPPORTED_INPUT_DTYPE:
            sample_coords = tf.to_float(sample_coords)

        if self.interpolation == 'LINEAR':
            return self._resample_linear(inputs, sample_coords)
        if self.interpolation == 'NEAREST':
            return self._resample_nearest(inputs, sample_coords)
        if self.interpolation == 'BSPLINE':
            return self._resample_bspline(inputs, sample_coords)
        if self.interpolation == 'IDW':
            return self._resample_inv_dst_weighting(inputs, sample_coords)
        tf.logging.fatal('interpolation method not implemented')
        raise NotImplementedError

    def _resample_nearest(self, inputs, sample_coords):
        in_size = inputs.get_shape().as_list()
        in_spatial_size = in_size[1:-1]

        # This is forward only as no gradient for tf.round
        spatial_coords = self.boundary_func(
            tf.round(sample_coords), in_spatial_size)
        spatial_coords = tf.cast(spatial_coords, COORDINATES_TYPE)
        output = tf.stack([
            tf.gather_nd(img, coords) for (img, coords) in
            zip(tf.unstack(inputs), tf.unstack(spatial_coords))])

        if self.boundary == 'ZERO':
            # mask out samples whose coordinates fall outside the input volume
            scale = 1. / (tf.constant(in_spatial_size, dtype=tf.float32) - 1)
            mask = tf.logical_and(
                tf.reduce_all(sample_coords > 0,
                              axis=-1, keep_dims=True),
                tf.reduce_all(scale * sample_coords < 1,
                              axis=-1, keep_dims=True))
            return output * tf.to_float(mask)
        return output

    def _resample_linear(self, inputs, sample_coords):
        in_size = inputs.get_shape().as_list()
        in_spatial_size = in_size[1:-1]
        in_spatial_rank = infer_spatial_rank(inputs)
        batch_size = in_size[0]

        out_spatial_rank = infer_spatial_rank(sample_coords)
        out_spatial_size = sample_coords.get_shape().as_list()[1:-1]

        if in_spatial_rank == 2 and self.boundary == 'ZERO':
            inputs = tf.transpose(inputs, [0, 2, 1, 3])
            return tf.contrib.resampler.resampler(inputs, sample_coords)

        # floor and ceil voxel coordinates surrounding each sampling point
        xy = tf.unstack(sample_coords, axis=-1)
        base_coords = [tf.floor(coords) for coords in xy]
        floor_coords = [
            tf.cast(self.boundary_func(x, in_spatial_size[idx]),
                    COORDINATES_TYPE)
            for (idx, x) in enumerate(base_coords)]
        ceil_coords = [
            tf.cast(self.boundary_func(x + 1.0, in_spatial_size[idx]),
                    COORDINATES_TYPE)
            for (idx, x) in enumerate(base_coords)]

        # fractional distances to the floor/ceil corners,
        # used as interpolation weights
        if self.boundary == 'ZERO':
            weight_0 = [tf.expand_dims(x - tf.cast(i, tf.float32), -1)
                        for (x, i) in zip(xy, floor_coords)]
            weight_1 = [tf.expand_dims(tf.cast(i, tf.float32) - x, -1)
                        for (x, i) in zip(xy, ceil_coords)]
        else:
            weight_0 = [tf.expand_dims(x - i, -1)
                        for (x, i) in zip(xy, base_coords)]
            weight_1 = [1.0 - w for w in weight_0]

        batch_ids = tf.reshape(
            tf.range(batch_size), [batch_size] + [1] * out_spatial_rank)
        batch_ids = tf.tile(batch_ids, [1] + out_spatial_size)
        sc = (floor_coords, ceil_coords)

        def get_knot(binary_code):
            # gather one of the 2**spatial_rank corner intensities;
            # binary_code selects floor (0) or ceil (1) per spatial axis
            coord = [sc[code][ind] for ind, code in enumerate(binary_code)]
            coord = tf.stack([batch_ids] + coord, -1)
            return tf.gather_nd(inputs, coord)

        def _pyramid_combination(two_samples, w_0, w_1):
            # linearly combine the corner samples one spatial axis at a time
            if len(w_0) == 1:
                return two_samples[0] * w_1[0] + two_samples[1] * w_0[0]
            f_0 = _pyramid_combination(two_samples[::2], w_0[:-1], w_1[:-1])
            f_1 = _pyramid_combination(two_samples[1::2], w_0[:-1], w_1[:-1])
            return f_0 * w_1[-1] + f_1 * w_0[-1]

        binary_neighbour_ids = [
            [int(c) for c in format(i, '0%ib' % in_spatial_rank)]
            for i in range(2 ** in_spatial_rank)]
        samples = [get_knot(bc) for bc in binary_neighbour_ids]
        return _pyramid_combination(samples, weight_0, weight_1)

    def _resample_bspline(self, inputs, sample_coords):
        in_size = inputs.get_shape().as_list()
        batch_size = in_size[0]
        in_spatial_size = in_size[1:-1]
        in_spatial_rank = infer_spatial_rank(inputs)

        out_spatial_rank = infer_spatial_rank(sample_coords)
        if in_spatial_rank == 2:
            raise NotImplementedError(
                'bspline interpolation not implemented for 2d yet')
        floor_coords = tf.floor(sample_coords)

        # Compute voxels to use for interpolation
        grid = tf.meshgrid([-1., 0., 1., 2.],
                           [-1., 0., 1., 2.],
                           [-1., 0., 1., 2.],
                           indexing='ij')
        offset_shape = [1, -1] + [1] * out_spatial_rank + [in_spatial_rank]
        offsets = tf.reshape(tf.stack(grid, 3), offset_shape)
        spatial_coords = offsets + tf.expand_dims(floor_coords, 1)
        spatial_coords = self.boundary_func(spatial_coords, in_spatial_size)
        spatial_coords = tf.cast(spatial_coords, COORDINATES_TYPE)
        knot_size = spatial_coords.get_shape().as_list()

        # Compute weights for each voxel
        def build_coef(u, d):
            coeff_list = [tf.pow(1 - u, 3),
                          3 * tf.pow(u, 3) - 6 * tf.pow(u, 2) + 4,
                          -3 * tf.pow(u, 3) + 3 * tf.pow(u, 2) + 3 * u + 1,
                          tf.pow(u, 3)]
            return tf.concat(coeff_list, d) / 6

        weight = tf.reshape(sample_coords - floor_coords, [batch_size, -1, 3])
        coef_shape = [batch_size, 1, 1, 1, -1]
        Bu = build_coef(tf.reshape(weight[:, :, 0], coef_shape), 1)
        Bv = build_coef(tf.reshape(weight[:, :, 1], coef_shape), 2)
        Bw = build_coef(tf.reshape(weight[:, :, 2], coef_shape), 3)
        all_weights = tf.reshape(Bu * Bv * Bw,
                                 [batch_size] + knot_size[1:-1] + [1])

        # Gather voxel values and compute weighted sum
        batch_coords = tf.reshape(
            tf.range(batch_size), [batch_size] + [1] * (len(knot_size) - 1))
        batch_coords = tf.tile(batch_coords, [1] + knot_size[1:-1] + [1])
        raw_samples = tf.gather_nd(
            inputs, tf.concat([batch_coords, spatial_coords], -1))
        return tf.reduce_sum(all_weights * raw_samples, reduction_indices=1)

    def _resample_inv_dst_weighting(self, inputs, sample_coords):
        in_size = inputs.get_shape().as_list()
        in_spatial_size = in_size[1:-1]
        in_spatial_rank = infer_spatial_rank(inputs)
        out_rank = len(sample_coords.get_shape().as_list())

        self.N = 2 ** in_spatial_rank
        binary_neighbour_ids = [
            [int(c) for c in format(i, '0%ib' % in_spatial_rank)]
            for i in range(self.N)]
        weight_id = [[[c, i] for i, c in enumerate(bc)]
                     for bc in binary_neighbour_ids]

        sample_coords = tf.transpose(
            sample_coords, [out_rank - 1, 0] + list(range(1, out_rank - 1)))
        # broadcasting input spatial size for boundary functions
        b_size = tf.reshape(in_spatial_size,
                            [len(in_spatial_size)] + [1] * (out_rank - 1))
        # find floor and ceil coordinates
        all_coords_f = tf.stack([
            self.boundary_func(tf.floor(sample_coords), b_size),
            self.boundary_func(tf.ceil(sample_coords), b_size)])

        # find N weights associated to each output point
        diff = tf.stack(
            [tf.squared_difference(sample_coords - EPS, all_coords_f[0]),
             tf.squared_difference(sample_coords + EPS, all_coords_f[1])])

        # gather_nd for both matrices, the same as:
        # point_weights = tf.gather_nd(diff, weight_id)
        # knots_id = tf.gather_nd(all_coords_f, weight_id)
        n_val = tf.gather_nd(
            tf.stack([diff, all_coords_f], axis=-1), weight_id)
        n_val = tf.unstack(n_val, axis=-1)
        point_weights, knots_id = n_val[0], n_val[1]

        # inverse distance weighting
        # sum_i (w_i*p_i/(sum_j w_j)) with w_i = 1/((p-p_i)^2)
        # point_weights shape:
        #   `[N, input_rank, b, sp_dim_0, ..., sp_dim_K]`
        # where:
        #   `N` is 2**source data spatial rank,
        #   `b` is batch size,
        #   `sp_dim_0` is the first output spatial dimension.
        #
        # `point_weights` represents (p - p_i)^2
        #   with i = 0...2**source_rank neighbours
        # (to do: these operations could be refactored as a resampling kernel)
        point_weights = tf.reduce_sum(point_weights, axis=1)
        # skip this as power = 2.0:
        # self.power = 1.0
        # point_weights = tf.pow(point_weights, self.power / 2.0)
        point_weights = tf.reciprocal(point_weights)
        point_weights = point_weights / tf.reduce_sum(point_weights, axis=0)

        # find N neighbours associated to each output point
        knots_id = tf.transpose(
            tf.cast(knots_id, COORDINATES_TYPE),
            [0] + list(range(2, out_rank + 1)) + [1])
        # get values of N neighbours
        samples = [
            tf.gather_nd(img, knots) for (img, knots) in
            zip(tf.unstack(inputs, axis=0), tf.unstack(knots_id, axis=1))]
        samples = tf.stack(samples, axis=1)
        # weighted average over N neighbours
        return tf.reduce_sum(
            samples * tf.expand_dims(point_weights, axis=-1), axis=0)
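
# The linear resampler above gathers the 2 ** spatial_rank corner intensities
# around each sampling point and folds them together with the recursive
# `_pyramid_combination`. The stand-alone sketch below (illustrative only, not
# part of the NiftyNet API) reproduces the same reduction in plain Python for a
# single 2D point, so the recursion can be checked against the usual bilinear
# formula by hand.
def _pyramid_combination_sketch():
    # corner samples ordered by binary neighbour id (00, 01, 10, 11),
    # i.e. (floor, floor), (floor, ceil), (ceil, floor), (ceil, ceil)
    samples = [1.0, 2.0, 3.0, 4.0]
    # fractional distances to the floor corner along each spatial axis
    weight_0 = [0.25, 0.5]
    weight_1 = [1.0 - w for w in weight_0]

    def combine(two_samples, w_0, w_1):
        if len(w_0) == 1:
            return two_samples[0] * w_1[0] + two_samples[1] * w_0[0]
        f_0 = combine(two_samples[::2], w_0[:-1], w_1[:-1])
        f_1 = combine(two_samples[1::2], w_0[:-1], w_1[:-1])
        return f_0 * w_1[-1] + f_1 * w_0[-1]

    # equals 0.375*1 + 0.375*2 + 0.125*3 + 0.125*4 = 2.0,
    # the standard bilinear interpolation of the four corner values
    return combine(samples, weight_0, weight_1)
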
def _boundary_replicate(sample_coords, input_size):
    # clamp coordinates to [0, input_size - 1] (edge replication)
    sample_coords, input_size = _param_type_and_shape(
        sample_coords, input_size)
    return tf.maximum(tf.minimum(sample_coords, input_size - 1), 0)


def _boundary_circular(sample_coords, input_size):
    # wrap coordinates periodically (modulo input_size)
    sample_coords, input_size = _param_type_and_shape(
        sample_coords, input_size)
    return tf.mod(tf.mod(sample_coords, input_size) + input_size, input_size)


def _boundary_symmetric(sample_coords, input_size):
    # reflect coordinates at the volume borders
    sample_coords, input_size = _param_type_and_shape(
        sample_coords, input_size)
    circular_size = input_size + input_size - 2
    return (input_size - 1) - tf.abs(
        (input_size - 1) - _boundary_circular(sample_coords, circular_size))


def _param_type_and_shape(sample_coords, input_size):
    # cast the spatial size to a float tensor so that it broadcasts
    # against the sampling coordinates
    # sample_coords = tf.cast(sample_coords, COORDINATES_TYPE)
    try:
        input_size = tf.constant(input_size, dtype=tf.float32)
    except TypeError:
        pass
    try:
        input_size = tf.to_float(input_size)
    except TypeError:
        pass
    return sample_coords, input_size


@tf.RegisterGradient('FloorMod')
def _floormod_grad(op, grad):
    # treat FloorMod (tf.mod) as non-differentiable: no gradient is
    # propagated through the boundary index computations
    return [None, None]


SUPPORTED_INTERPOLATION = set(['BSPLINE', 'LINEAR', 'NEAREST', 'IDW'])

SUPPORTED_BOUNDARY = {
    'ZERO': _boundary_replicate,
    'REPLICATE': _boundary_replicate,
    'CIRCULAR': _boundary_circular,
    'SYMMETRIC': _boundary_symmetric}

SUPPORTED_INPUT_DTYPE = set([tf.float32])
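
# A minimal usage sketch (an illustrative assumption, not part of this
# module): resample a toy 3D volume at two floating-point locations using
# linear interpolation with replicated borders. Shapes follow the conventions
# used above: `inputs` is [batch, x, y, z, channels] and `sample_coords` is
# [batch, n_points, spatial_rank]. Calling the layer instance is assumed to
# dispatch to `layer_op`; the code targets the TF 1.x graph-mode API that the
# layer itself uses, and the NumPy values are arbitrary.
if __name__ == '__main__':
    import numpy as np

    volume = np.random.rand(1, 8, 8, 8, 1).astype(np.float32)
    coords = np.asarray(
        [[[2.5, 3.0, 4.25],
          [0.0, 7.9, 1.5]]], dtype=np.float32)

    resampler = ResamplerLayer(interpolation='LINEAR', boundary='REPLICATE')
    resampled = resampler(tf.constant(volume), tf.constant(coords))

    with tf.Session() as sess:
        # one interpolated intensity per sampling point: shape (1, 2, 1)
        print(sess.run(resampled))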