Source code for niftynet.layer.spatial_transformer

# -*- coding: utf-8 -*-
# Copyright 2018 The Sonnet Authors. All Rights Reserved.
# Modifications copyright 2018 The NiftyNet Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

""""Implementation of Spatial Transformer networks core components."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from niftynet.layer.grid_warper import GridWarperLayer
from niftynet.layer.resampler import ResamplerLayer

# Interpolation schemes recognised by the layers in this module.
SUPPORTED_INTERPOLATION = {'BSPLINE', 'LINEAR', 'NEAREST'}
# Boundary-handling modes recognised by the layers in this module.
SUPPORTED_BOUNDARY = {'ZERO', 'REPLICATE', 'CIRCULAR', 'SYMMETRIC'}


class BSplineFieldImageGridWarperLayer(GridWarperLayer):
    """
    The fast BSpline Grid Warper defines a grid based on sampling
    coordinate values from a spatially varying displacement field (passed
    as a tensor input) along a regular cartesian grid pattern aligned with
    the field. Specifically, this class defines a grid based on BSpline
    smoothing, as described by Rueckert et al.

    To ensure that it can be done efficiently, several assumptions are made:

    1) The grid is a cartesian grid aligned with the field.
    2) Knots occur every M,N,O grid points (in X,Y,Z).

    This allows the smoothing to be represented as a 4x4x4 convolutional
    kernel with MxNxO output channels.
    """

    def __init__(self,
                 source_shape,
                 output_shape,
                 knot_spacing,
                 name='interpolated_spline_grid_warper_layer'):
        """Constructs a BSplineFieldImageGridWarperLayer.

        Args:
            source_shape: Iterable of integers determining the size of the
                source signal domain.
            output_shape: Iterable of integers determining the size of the
                destination resampled signal domain.
            knot_spacing: Iterable of intervals (in voxels) in each dimension
                where displacements are defined in the field.
            name: Name of Module.
        """
        # Number of control points per axis: the knot intervals covering
        # the output, plus the extra support needed by the 4-tap kernel.
        coeff_shape = [4 + (n - 1) // k
                       for n, k in zip(output_shape, knot_spacing)]
        # Stored as a list so layer_op can concatenate it with other lists
        # (a tuple here would raise TypeError on `list + tuple`).
        # Set before super().__init__ in case the base class builds
        # features during construction -- TODO confirm.
        self._knot_spacing = list(knot_spacing)
        super(BSplineFieldImageGridWarperLayer, self).__init__(
            source_shape=source_shape,
            output_shape=output_shape,
            coeff_shape=coeff_shape,
            name=name)

    def _create_features(self):
        """Creates the separable cubic B-spline convolutional kernel.

        Returns:
            A float32 constant of shape [4, 4, 4, 1, prod(knot_spacing)]:
            a 4x4x4 spatial kernel with one input channel and one output
            channel per sub-knot offset.
        """
        def build_coefficient(u, d):
            # Cubic B-spline basis weights evaluated at fractional offsets u.
            basis = np.stack([
                np.power(1 - u, 3) / 6,
                (3 * np.power(u, 3) - 6 * np.power(u, 2) + 4) / 6,
                (-3 * np.power(u, 3) + 3 * np.power(u, 2) + 3 * u + 1) / 6,
                np.power(u, 3) / 6], 0)
            # Put the 4 weights on kernel axis d and the len(u) offsets on
            # axis 3 + d so the per-axis arrays broadcast against each other.
            return np.reshape(basis, np.roll([4, 1, 1, len(u), 1, 1], d))

        coeffs = [build_coefficient(np.arange(k) / k, d)
                  for d, k in enumerate(self._knot_spacing)]
        # Outer product by broadcasting -> [4, 4, 4, k0, k1, k2].
        # (np.prod over this list would first build a ragged array, which
        # recent NumPy versions reject.)
        kernel = coeffs[0]
        for coeff in coeffs[1:]:
            kernel = kernel * coeff
        return tf.constant(np.reshape(kernel, [4, 4, 4, 1, -1]),
                           dtype=tf.float32)

    def layer_op(self, field):
        """Converts a batch of control-point displacement fields to a grid.

        Args:
            field: 5-D tensor whose last axis holds the displacement
                components; its spatial axes are assumed to equal
                self._coeff_shape (required by the reshape below) -- TODO
                confirm against callers.

        Returns:
            Tensor of shape [batch] + output_shape[:3] + [spatial_rank].
        """
        batch_size = int(field.shape.as_list()[0])
        spatial_rank = int(field.shape.as_list()[-1])
        knot_spacing = list(self._knot_spacing)
        # Smooth each displacement component separately. self._psi is
        # presumably the kernel from _create_features, stored by
        # GridWarperLayer -- TODO confirm. VALID padding yields
        # coeff_shape - 3 positions per axis, each with one channel per
        # sub-knot offset.
        resampled_list = [
            tf.nn.conv3d(field[:, :, :, :, d:d + 1], self._psi,
                         strides=[1] * 5, padding='VALID')
            for d in range(spatial_rank)]
        resampled = tf.stack(resampled_list, 5)
        permuted_shape = ([batch_size]
                          + [f - 3 for f in self._coeff_shape]
                          + knot_spacing
                          + [spatial_rank])
        # Interleave each knot-interval axis with its sub-knot offset axis,
        # so that merging each pair yields a dense voxel axis.
        permuted = tf.transpose(tf.reshape(resampled, permuted_shape),
                                [0, 1, 4, 2, 5, 3, 6, 7])
        valid_size = [(f - 3) * k
                      for f, k in zip(self._coeff_shape, knot_spacing)]
        reshaped = tf.reshape(permuted,
                              [batch_size] + valid_size + [spatial_rank])
        # The dense grid can extend past the requested output; crop it.
        cropped = reshaped[:,
                           :self._output_shape[0],
                           :self._output_shape[1],
                           :self._output_shape[2],
                           :]
        return cropped
class RescaledFieldImageGridWarperLayer(GridWarperLayer):
    """
    The rescaled field grid warper defines a grid based on sampling
    coordinate values from a spatially varying displacement field (passed
    as a tensor input) along a regular cartesian grid pattern aligned with
    the field. Specifically, this class defines a grid by resampling the
    field (using tf.image.resize_images with align_corners=False) to the
    output_shape.
    """

    def __init__(self,
                 source_shape,
                 output_shape,
                 coeff_shape,
                 interpolation=tf.image.ResizeMethod.BICUBIC,
                 name='rescaling_interpolated_spline_grid_warper_layer'):
        """Constructs a RescaledFieldImageGridWarperLayer.

        Args:
            source_shape: Iterable of integers determining the size of the
                source signal domain.
            output_shape: Iterable of integers determining the size of the
                destination resampled signal domain.
            coeff_shape: Shape of displacement field.
            interpolation: a tf.image.ResizeMethod, or one of the strings
                'LINEAR', 'CUBIC' or 'NEAREST', which are translated to the
                corresponding ResizeMethod.
            name: Name of Module.
        """
        # Translate NiftyNet-style names; other values pass through as-is.
        self._interpolation = interpolation
        if self._interpolation == 'LINEAR':
            self._interpolation = tf.image.ResizeMethod.BILINEAR
        elif self._interpolation == 'CUBIC':
            self._interpolation = tf.image.ResizeMethod.BICUBIC
        elif self._interpolation == 'NEAREST':
            self._interpolation = tf.image.ResizeMethod.NEAREST_NEIGHBOR
        super(RescaledFieldImageGridWarperLayer, self).__init__(
            source_shape=source_shape,
            output_shape=output_shape,
            coeff_shape=coeff_shape,
            name=name)

    def layer_op(self, field):
        """Resizes a batch of 3-D displacement fields to output_shape.

        The 3-D resize is done as two 2-D resizes: first over the first two
        spatial axes, then over the third.

        Args:
            field: tensor whose spatial axes are assumed to match
                self._coeff_shape (required by the first reshape) -- TODO
                confirm against callers.

        Returns:
            Tensor of shape [batch] + output_shape + [-1].
        """
        batch_size = int(field.shape[0])
        # Pass 1: fold the third axis and the components into channels, and
        # resize the first two spatial axes.
        folded = tf.reshape(
            field,
            [batch_size, self._coeff_shape[0], self._coeff_shape[1], -1])
        resized_xy = tf.image.resize_images(folded,
                                            self._output_shape[0:2],
                                            self._interpolation,
                                            align_corners=False)
        # Pass 2: flatten the (already resized) first two axes into one, and
        # resize the third spatial axis; the flattened axis keeps its size.
        n_xy = self._output_shape[0] * self._output_shape[1]
        unfolded = tf.reshape(
            resized_xy, [batch_size, n_xy, self._coeff_shape[2], -1])
        resized_z = tf.image.resize_images(unfolded,
                                           [n_xy, self._output_shape[2]],
                                           self._interpolation,
                                           align_corners=False)
        final_shape = [batch_size] + list(self._output_shape) + [-1]
        return tf.reshape(resized_z, final_shape)
class ResampledFieldGridWarperLayer(GridWarperLayer):
    """
    The resampled field grid warper defines a grid based on sampling
    coordinate values from a spatially varying displacement field (passed
    as a tensor input) along an affine grid pattern in the field. This
    enables grids representing small patches of a larger transform, as
    well as the composition of multiple transforms before sampling.
    """

    def __init__(self,
                 source_shape,
                 output_shape,
                 coeff_shape,
                 field_transform=None,
                 resampler=None,
                 name='resampling_interpolated_spline_grid_warper'):
        """Constructs a ResampledFieldGridWarperLayer.

        Args:
            source_shape: Iterable of integers determining the size of the
                source signal domain.
            output_shape: Iterable of integers determining the size of the
                destination resampled signal domain.
            coeff_shape: Shape of displacement field.
            field_transform: an object defining the spatial relationship
                between the output_grid and the field.
                batch_size x 4 x 4 tensor: per-image transform matrix from
                output coords to field coords.
                None (default): corners of output map to corners of field
                with an allowance for interpolation (1 for bspline, 0 for
                linear).
            resampler: a ResamplerLayer used to interpolate the deformation
                field. Defaults to LINEAR interpolation with REPLICATE
                boundary handling.
            name: Name of module.

        Raises:
            TypeError: (presumably raised by GridWarperLayer) if
                output_shape and source_shape are not both iterable.
        """
        if resampler is None:
            self._resampler = ResamplerLayer(interpolation='LINEAR',
                                             boundary='REPLICATE')
            self._interpolation = 'LINEAR'
        else:
            self._resampler = resampler
            self._interpolation = self._resampler.interpolation
        self._field_transform = field_transform
        super(ResampledFieldGridWarperLayer, self).__init__(
            source_shape=source_shape,
            output_shape=output_shape,
            coeff_shape=coeff_shape,
            name=name)

    def _create_features(self):
        """Creates the coordinates for resampling.

        If field_transform is None, these are constant and are created in
        field space. Otherwise, the final coordinates will be transformed by
        an input tensor representing a transform from output coordinates to
        field coordinates, so they are created in output coordinate space.

        Returns:
            Tensor of shape [1, N, C], where N is the number of output
            points and C is the number of coordinate components (one extra,
            homogeneous component when field_transform is given).
        """
        n_pad = len(self._source_shape) - len(self._output_shape)
        embedded_output_shape = list(self._output_shape) + [1] * n_pad
        embedded_coeff_shape = list(self._coeff_shape) + [1] * n_pad
        if self._field_transform is None:
            if self._interpolation == 'BSPLINE':
                # Keep one control point of margin on each side for the
                # B-spline support.
                def range_func(f, x):
                    return tf.linspace(1., f - 2., x)
            else:
                def range_func(f, x):
                    return tf.linspace(0., f - 1., x)
        else:
            def range_func(f, x):
                return np.arange(x, dtype=np.float32)
            # Homogeneous padding applies only to this branch: on the
            # field-space path layer_op reshapes to len(source_shape)
            # components, which leaves no room for an extra one.
            embedded_output_shape += [1]  # make homogeneous
            embedded_coeff_shape += [1]
        ranges = [range_func(f, x)
                  for f, x in zip(embedded_coeff_shape, embedded_output_shape)]
        coords = tf.stack([tf.reshape(x, [1, -1])
                           for x in tf.meshgrid(*ranges, indexing='ij')], 2)
        return coords

    def layer_op(self, field):
        """Assembles the module network and adds it to the graph.

        The internal computation graph is assembled according to the set of
        constraints provided at construction time.

        Args:
            field: Tensor containing a batch of transformation parameters.

        Returns:
            A batch of warped grids.

        Raises:
            Error: (presumably from the reshape or the resampler) if the
                input tensor size is not consistent with the constraints
                passed at construction time.
        """
        batch_size = int(field.shape[0])
        # Transform the grid into field coordinate space if necessary.
        # self._psi is presumably the output of _create_features, stored by
        # GridWarperLayer -- TODO confirm.
        if self._field_transform is None:
            coords = self._psi
        else:
            # NOTE(review): [:, :, 1:3] keeps only two columns of what the
            # constructor documents as a batch_size x 4 x 4 matrix, and
            # _psi has a leading dimension of 1, not batch_size. This path
            # looks unfinished -- confirm before relying on it.
            coords = tf.matmul(self._psi, self._field_transform[:, :, 1:3])
        # Resample the field at the (batch-replicated) coordinates.
        grid_shape = ([-1] + list(self._output_shape)
                      + [len(self._source_shape)])
        coords = tf.reshape(tf.tile(coords, [batch_size, 1, 1]), grid_shape)
        resampled_coords = self._resampler(field, coords)
        return resampled_coords