# Source code for niftynet.layer.grid_warper

# -*- coding: utf-8 -*-
# Copyright 2018 The Sonnet Authors. All Rights Reserved.
# Modifications copyright 2018 The NiftyNet Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""
Grid warper layer and utilities
adapted from
https://github.com/deepmind/sonnet/blob/v1.13/sonnet/python/modules/spatial_transformer.py
https://github.com/niftk/NiftyNet/blob/v0.2.0.post1/niftynet/layer/spatial_transformer.py
"""
from __future__ import absolute_import, division, print_function

from itertools import chain

import numpy as np
import tensorflow as tf

from niftynet.layer.base_layer import Layer, LayerFromCallable, Invertible


class GridWarperLayer(Layer):
    """
    Grid warper interface class.

    An object implementing the `GridWarper` interface generates a reference
    grid of feature points at construction time, and warps it via a
    parametric transformation model, specified at run time by an input
    parameter Tensor. Grid warpers must then implement a `_create_features`
    function used to generate the reference grid to be warped in the forward
    pass (according to a determined warping model), and a `layer_op` that
    performs the warp.
    """

    def __init__(self,
                 source_shape,
                 output_shape,
                 coeff_shape,
                 name,
                 **kwargs):
        """
        Constructs a GridWarper module and initializes the source grid params.

        `source_shape` and `output_shape` define the size of the source and
        output signal domains. For example, for an image of size `width=W`
        and `height=H`, `{source,output}_shape=[H, W]`; for a volume of size
        `width=W`, `height=H` and `depth=D`,
        `{source,output}_shape=[H, W, D]`.

        Args:
            source_shape: Iterable of integers determining the size of the
                source signal domain.
            output_shape: Iterable of integers determining the size of the
                destination resampled signal domain.
            coeff_shape: Shape of coefficients parameterizing the grid warp.
                For example, a 2D affine transformation will be defined by
                the [6] parameters populating the corresponding 2x3 affine
                matrix.
            name: Name of Module.
            **kwargs: Extra kwargs to be forwarded to the `_create_features`
                function, instantiating the source grid parameters.

        Raises:
            ValueError: If `len(output_shape) > len(source_shape)`.
            TypeError: If `output_shape` and `source_shape` are not both
                iterable.
        """
        super(GridWarperLayer, self).__init__(name=name)
        self._source_shape = tuple(source_shape)
        self._output_shape = tuple(output_shape)
        out_rank = len(self._output_shape)
        src_rank = len(self._source_shape)
        if out_rank > src_rank:
            message = ('Output domain dimensionality (%s) must be equal or '
                       'smaller than source domain dimensionality (%s)')
            tf.logging.fatal(message, out_rank, src_rank)
            # Carry the diagnostic on the exception too, not only in the log.
            raise ValueError(message % (out_rank, src_rank))
        self._coeff_shape = coeff_shape
        # Subclasses precompute their reference-grid features here; the
        # shapes above must already be set because they are read there.
        self._psi = self._create_features(**kwargs)

    def _create_features(self, **kwargs):
        """
        Precomputes features (e.g. sampling patterns, unconstrained feature
        matrices). Must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        message = '_create_features() should be implemented'
        tf.logging.fatal(message)
        raise NotImplementedError(message)

    def layer_op(self, *args, **kwargs):
        """
        Warps `self._psi` in the forward pass. Must be overridden by
        subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        message = 'layer_op() should be implemented to warp self._psi'
        tf.logging.fatal(message)
        raise NotImplementedError(message)

    @property
    def coeff_shape(self):
        """Returns number of coefficients of warping function."""
        return self._coeff_shape

    @property
    def psi(self):
        """Returns a list of features used to compute the grid warp."""
        return self._psi

    @property
    def source_shape(self):
        """Returns a tuple containing the shape of the source signal."""
        return self._source_shape

    @property
    def output_shape(self):
        """Returns a tuple containing the shape of the output grid."""
        return self._output_shape
class AffineGridWarperLayer(GridWarperLayer, Invertible):
    """
    Affine Grid Warper class.

    The affine grid warper generates a reference grid of n-dimensional points
    and warps it via an affine transformation model determined by an input
    parameter Tensor. Some of the transformation parameters can be fixed at
    construction time via an `AffineWarpConstraints` object.
    """

    def __init__(self,
                 source_shape,
                 output_shape,
                 constraints=None,
                 name='affine_grid_warper'):
        """Constructs an AffineGridWarper.

        `source_shape` and `output_shape` are used to define shape of source
        and output signal domains, as opposed to the shape of the respective
        Tensors. For example, for an image of size `width=W` and `height=H`,
        `{source,output}_shape=[H, W]`; for a volume of size `width=W`,
        `height=H` and `depth=D`, `{source,output}_shape=[H, W, D]`.

        Args:
            source_shape: Iterable of integers determining shape of source
                signal domain.
            output_shape: Iterable of integers determining shape of
                destination resampled signal domain.
            constraints: Either a double list of shape `[N, N+1]` defining
                constraints on the entries of a matrix defining an affine
                transformation in N dimensions, or an `AffineWarpConstraints`
                object. If the double list is passed, a numeric value bakes
                in a constraint on the corresponding entry in the
                transformation matrix, whereas `None` implies that the
                corresponding entry will be specified at run time.
            name: Name of module.

        Raises:
            ValueError: If constraints fully define the affine
                transformation; or if input grid shape and constraints have
                different dimensionality.
            TypeError: If output_shape and source_shape are not both
                iterable.
        """
        # Materialise once so that any iterable (even a one-shot generator)
        # is accepted and the same values reach the base class.
        source_shape = tuple(source_shape)
        output_shape = tuple(output_shape)
        num_dim = len(source_shape)

        if isinstance(constraints, AffineWarpConstraints):
            self._constraints = constraints
        elif constraints is None:
            self._constraints = AffineWarpConstraints.no_constraints(num_dim)
        else:
            self._constraints = AffineWarpConstraints(constraints=constraints)

        if self._constraints.num_free_params == 0:
            message = 'Transformation is fully constrained.'
            tf.logging.fatal(message)
            raise ValueError(message)
        if self._constraints.num_dim != num_dim:
            message = ('Incompatible set of constraints provided: '
                       'input grid shape and constraints have different '
                       'dimensionality.')
            tf.logging.fatal(message)
            raise ValueError(message)

        # An N-d affine matrix has N * (N + 1) coefficients (6 in 2D, 12 in
        # 3D). `inverse_op` relies on this value to reject non-2D warpers.
        GridWarperLayer.__init__(
            self,
            source_shape=source_shape,
            output_shape=output_shape,
            coeff_shape=[num_dim * (num_dim + 1)],
            name=name,
            constraints=self._constraints)

    def _create_features(self, constraints):
        """
        Creates all the matrices needed to compute the output warped grids.

        Args:
            constraints: `AffineWarpConstraints` (or a nested iterable
                accepted by its constructor).

        Returns:
            A list of `3 * N` entries, N being the source dimensionality:
            entries `[0, N)` hold, per output coordinate, the (scaled)
            homogeneous grid rows multiplied by the free parameters (or
            None if the row has no free parameter); entries `[N, 2N)` hold
            the precomputed contribution of the constrained parameters (or
            None if the row is fully free); entries `[2N, 3N)` are the
            scalar global offsets.
        """
        affine_warp_constraints = constraints
        if not isinstance(affine_warp_constraints, AffineWarpConstraints):
            affine_warp_constraints = AffineWarpConstraints(constraints)
        psi = np.asarray(_create_affine_features(
            output_shape=self._output_shape,
            source_shape=self._source_shape,
            relative=True))
        scales = [(x - 1.0) * .5 for x in self._source_shape]
        offsets = scales
        # Transforming a point x's i-th coordinate via an affine
        # transformation is performed via the following dot product:
        #
        #   x_i' = s_i * (T_i * x) + t_i                          (1)
        #
        # where T_i is the i-th row of an affine matrix, and the scalars
        # s_i and t_i define a decentering and global scaling into the
        # source space (mapping [-1, 1] onto [0, source_shape_i - 1]).
        #
        # Some of the entries of T_i are provided via the input, others are
        # fixed by the constraints assigned in the constructor. (1) is thus
        # broken down into two parts:
        #
        #   x_i' = T_i[uncon_i] * x[uncon_i, :] + offset(con_var)  (2)
        #
        # i.e. the dot product of the free parameters (coming from the
        # input) indexed by uncon_i, plus an offset obtained by
        # precomputing the fixed part of (1) according to the constraints.
        spatial_rank = len(self._source_shape)
        features = []

        # Grid rows that will be multiplied by run-time (free) parameters.
        for i in range(spatial_rank):
            row_constraints = affine_warp_constraints[i]
            x_i = np.array([
                grid_row
                for grid_row, value in zip(psi, row_constraints)
                if value is None])
            features.append(x_i * scales[i] if len(x_i) else None)

        # Precomputed contribution of the constrained (fixed) entries.
        for i in range(spatial_rank):
            row_constraints = affine_warp_constraints[i]
            if all(value is None for value in row_constraints):
                features.append(None)
                continue
            # Free entries contribute nothing to the fixed part.
            fixed_row = np.array(
                [0.0 if value is None else value
                 for value in row_constraints],
                dtype=np.float64)
            features.append(np.dot(fixed_row, psi) * scales[i])

        # Global offsets t_i.
        return features + offsets

    @property
    def constraints(self):
        """Returns the `AffineWarpConstraints` used by this warper."""
        return self._constraints

    def layer_op(self, inputs):
        """Assembles the module network and adds it to the graph.

        The internal computation graph is assembled according to the set of
        constraints provided at construction time.

        Args:
            inputs: Tensor of shape `[batch_size, num_free_params]`
                containing a batch of transformation parameters. The batch
                size does not need to be known at graph-construction time.

        Returns:
            A batch of warped grids, stacked along a trailing axis of size
            `len(source_shape)`.

        Raises:
            ValueError: If the input tensor size is not consistent with the
                constraints passed at construction time.
        """
        inputs = tf.to_float(inputs)
        # Static sizes; either entry may be None when unknown at graph time.
        batch_size, number_of_params = inputs.shape.as_list()
        input_dtype = inputs.dtype.as_numpy_dtype
        num_free_params = self._constraints.num_free_params
        if number_of_params is not None and \
                number_of_params != num_free_params:
            message = ('Input size is not consistent with constraint '
                       'definition: (N, %s) parameters expected '
                       '(where N is the batch size; > 1), but %s provided.')
            tf.logging.fatal(message, num_free_params, inputs.shape)
            raise ValueError(message % (num_free_params, inputs.shape))

        # Run-time batch size, used wherever tiling is required, so that an
        # unknown static batch dimension is supported.
        dynamic_batch_size = tf.shape(inputs)[0]
        spatial_rank = len(self._source_shape)
        num_output_points = int(np.prod(self._output_shape))
        warped_grid = []
        var_index_offset = 0
        for i in range(spatial_rank):
            dynamic_part = self._psi[i]
            fixed_part = self._psi[spatial_rank + i]
            global_offset = self._psi[2 * spatial_rank + i]
            if dynamic_part is not None:
                # The i-th output dimension is not fully specified by the
                # constraints: batched matrix multiplication with the free
                # parameters belonging to this row.
                grid_coord = dynamic_part.astype(input_dtype)
                var_start = var_index_offset
                var_index_offset += dynamic_part.shape[0]
                warped_coord = tf.matmul(
                    inputs[:, var_start:var_index_offset], grid_coord)
                if fixed_part is not None:
                    # Precomputed constrained entries of this row;
                    # broadcasting adds the [P] vector to every batch item.
                    warped_coord += fixed_part.astype(input_dtype)
            else:
                # The i-th output dimension is fully specified by the
                # constraints: replicate the precomputed coordinates over
                # the batch.
                fixed_coord = tf.constant(fixed_part.astype(input_dtype))
                warped_coord = tf.tile(
                    tf.expand_dims(fixed_coord, 0), [dynamic_batch_size, 1])
            # Global decentering offset.
            warped_coord = warped_coord + global_offset
            # Help shape inference (tiling multiples are run-time tensors).
            warped_coord.set_shape([batch_size, num_output_points])
            warped_grid.append(warped_coord)

        # Reshape every coordinate tensor to the output grid shape and stack
        # the coordinates along a new trailing axis.
        warped_grid = [tf.reshape(grid, [-1] + list(self._output_shape))
                       for grid in warped_grid]
        return tf.stack(warped_grid, -1)

    def inverse_op(self, name=None):
        """
        Returns a layer to compute inverse affine transforms.

        The function first assembles a network that given the constraints of
        the current AffineGridWarper and a set of input parameters, retrieves
        the coefficients of the corresponding inverse affine transform, then
        feeds its output into a new AffineGridWarper setup to correctly warp
        the `output` space into the `source` space.

        Args:
            name: Name of module implementing the inverse grid
                transformation. Defaults to `self.name + '_inverse'`.

        Returns:
            A `LayerFromCallable` performing the inverse affine transform of
            a reference grid of points via an AffineGridWarper module.

        Raises:
            NotImplementedError: If the function is called on a non 2D
                instance of AffineGridWarper.
        """
        if self._coeff_shape != [6]:
            message = ('AffineGridWarper currently supports '
                       'inversion only for the 2D case.')
            tf.logging.fatal(message)
            raise NotImplementedError(message)

        def _affine_grid_warper_inverse(inputs):
            """Assembles network to compute inverse affine transformation.

            Each `inputs` row potentially contains [a, b, tx, c, d, ty]
            corresponding to an affine matrix:

                A = [a, b, tx],
                    [c, d, ty]

            We want to generate a tensor containing the coefficients of the
            corresponding inverse affine transformation in a
            constraints-aware fashion. Calling M:

                M = [a, b]
                    [c, d]

            the affine matrix for the inverse transform is:

                A_inv = [M^(-1), M^(-1) * [-tx, -ty]^T]

            where

                M^(-1) = (ad - bc)^(-1) * [ d, -b]
                                          [-c,  a]

            Args:
                inputs: Tensor containing a batch of transformation
                    parameters (only the free ones, in row-major order).

            Returns:
                A tensorflow graph performing the inverse affine
                transformation parametrized by the input coefficients.
            """
            batch_size = tf.expand_dims(tf.shape(inputs)[0], 0)
            constant_shape = tf.concat(
                [batch_size, tf.convert_to_tensor((1,))], 0)

            # Free parameters are consumed from `inputs` in row-major order.
            index = iter(range(6))

            def get_variable(constraint):
                if constraint is None:
                    i = next(index)
                    return inputs[:, i:i + 1]
                # Constrained entry: constant column of shape [batch, 1].
                return tf.fill(constant_shape,
                               tf.constant(constraint, dtype=inputs.dtype))

            constraints = chain.from_iterable(self._constraints.constraints)
            a, b, tx, c, d, ty = (get_variable(constr)
                                  for constr in constraints)

            det = a * d - b * c
            a_inv = d / det
            b_inv = -b / det
            c_inv = -c / det
            d_inv = a / det

            m_inv = tf.reshape(
                tf.concat([a_inv, b_inv, c_inv, d_inv], 1), [-1, 2, 2])

            txy = tf.expand_dims(tf.concat([tx, ty], 1), 2)

            txy_inv = tf.reshape(tf.matmul(m_inv, txy), [-1, 2])
            tx_inv = txy_inv[:, 0:1]
            ty_inv = txy_inv[:, 1:2]

            inverse_gw_inputs = tf.concat(
                [a_inv, b_inv, -tx_inv, c_inv, d_inv, -ty_inv], 1)

            # The inverse maps the output space back onto the source space.
            agw = AffineGridWarperLayer(self.output_shape, self.source_shape)
            return agw(inverse_gw_inputs)  # pylint: disable=not-callable

        if name is None:
            name = self.name + '_inverse'
        return LayerFromCallable(_affine_grid_warper_inverse, name=name)
class AffineWarpConstraints(object):
    """Affine warp constraints class.

    `AffineWarpConstraints` allow for very succinct definitions of
    constraints on the values of entries in affine transform matrices.
    """

    def __init__(self, constraints=((None,) * 3,) * 2):
        """Creates a constraint definition for an affine transformation.

        Args:
            constraints: A doubly-nested iterable of shape `[N, N+1]`
                defining constraints on the entries of a matrix that
                represents an affine transformation in `N` dimensions. A
                numeric value bakes in a constraint on the corresponding
                entry in the transformation matrix, whereas `None` implies
                that the corresponding entry will be specified at run time.

        Raises:
            TypeError: If `constraints` is not a nested iterable.
            ValueError: If the double iterable `constraints` has
                inconsistent dimensions.
        """
        try:
            self._constraints = tuple(tuple(x) for x in constraints)
        except TypeError:
            message = 'constraints must be a nested iterable.'
            tf.logging.fatal(message)
            raise TypeError(message)

        # Number of rows
        self._num_dim = len(self._constraints)
        expected_num_cols = self._num_dim + 1
        if any(len(x) != expected_num_cols for x in self._constraints):
            message = ('The input list must define a Nx(N+1) matrix of '
                       'constraints.')
            tf.logging.fatal(message)
            raise ValueError(message)

    def _calc_num_free_params(self):
        """Computes number of non constrained parameters."""
        return sum(row.count(None) for row in self._constraints)

    @property
    def num_free_params(self):
        """Number of entries left to be specified at run time."""
        return self._calc_num_free_params()

    @property
    def constraints(self):
        """The constraints as a tuple of row tuples."""
        return self._constraints

    @property
    def num_dim(self):
        """Dimensionality N of the constrained N x (N+1) affine matrix."""
        return self._num_dim

    def __getitem__(self, i):
        """
        Returns the list of constraints for the i-th row of the affine matrix.
        """
        return self._constraints[i]

    def _combine(self, x, y):
        """
        Combines two constraints, raising an error if they are not compatible.

        `None` (unconstrained) yields to the other value. Falsy constraints
        such as `0` are kept, so they are checked with `is None`, not
        truthiness.
        """
        if x is None:
            return y
        if y is None:
            return x
        if x != y:
            message = 'Incompatible set of constraints provided.'
            tf.logging.fatal(message)
            raise ValueError(message)
        return x

    def __and__(self, rhs):
        """Combines two sets of constraints into a coherent single set."""
        return self.combine_with(rhs)

    def combine_with(self, additional_constraints):
        """Combines two sets of constraints into a coherent single set.

        Args:
            additional_constraints: `AffineWarpConstraints` or a nested
                iterable accepted by its constructor.

        Returns:
            A new `AffineWarpConstraints` with the element-wise combination.

        Raises:
            ValueError: If the two sets have different dimensionality or
                contain conflicting fixed values.
        """
        other = additional_constraints
        if not isinstance(other, AffineWarpConstraints):
            other = AffineWarpConstraints(other)
        if other.num_dim != self._num_dim:
            # zip() would otherwise silently truncate the larger matrix.
            message = ('Incompatible set of constraints provided: '
                       'cannot combine %dD and %dD constraints.'
                       % (self._num_dim, other.num_dim))
            tf.logging.fatal(message)
            raise ValueError(message)
        new_constraints = [
            [self._combine(lhs, rhs) for lhs, rhs in zip(left, right)]
            for left, right in zip(self._constraints, other.constraints)]
        return AffineWarpConstraints(new_constraints)

    # Collection of utilities to initialize an AffineGridWarper in 2D and 3D.

    @classmethod
    def no_constraints(cls, num_dim=2):
        """
        Empty set of constraints for a num_dim affine transform.
        """
        return cls(((None,) * (num_dim + 1),) * num_dim)

    @classmethod
    def translation_2d(cls, x=None, y=None):
        """
        Assign constraints on translation components of affine transform in 2d.
        """
        return cls([[None, None, x],
                    [None, None, y]])

    @classmethod
    def translation_3d(cls, x=None, y=None, z=None):
        """
        Assign constraints on translation components of affine transform in 3d.
        """
        return cls([[None, None, None, x],
                    [None, None, None, y],
                    [None, None, None, z]])

    @classmethod
    def scale_2d(cls, x=None, y=None):
        """
        Assigns constraints on scaling components of affine transform in 2d.
        """
        return cls([[x, None, None],
                    [None, y, None]])

    @classmethod
    def scale_3d(cls, x=None, y=None, z=None):
        """
        Assigns constraints on scaling components of affine transform in 3d.
        """
        return cls([[x, None, None, None],
                    [None, y, None, None],
                    [None, None, z, None]])

    @classmethod
    def shear_2d(cls, x=None, y=None):
        """
        Assigns constraints on shear components of affine transform in 2d.
        """
        return cls([[None, x, None],
                    [y, None, None]])

    @classmethod
    def no_shear_2d(cls):
        """
        Constrains both shear components of a 2d affine transform to zero.
        """
        return cls.shear_2d(x=0, y=0)

    @classmethod
    def no_shear_3d(cls):
        """
        Constrains all shear components of a 3d affine transform to zero.
        """
        return cls([[None, 0, 0, None],
                    [0, None, 0, None],
                    [0, 0, None, None]])
def _create_affine_features(output_shape, source_shape, relative=False):
    """
    Generates n-dimensional homogeneous coordinates for a grid definition.

    `source_shape` and `output_shape` define the size of the source and
    output signal domains. For example, for an image of size `width=W` and
    `height=H`, `{source,output}_shape=[H, W]`; for a volume of size
    `width=W`, `height=H` and `depth=D`, `{source,output}_shape=[H, W, D]`.

    Args:
        output_shape: Iterable of integers determining the shape of the grid
            to be warped.
        source_shape: Iterable of integers determining the domain of the
            signal to be resampled.
        relative: If True, each axis is sampled with
            `np.linspace(-1, 1, n)` (normalised coordinates); otherwise with
            `np.arange(n)` (integer grid positions).

    Returns:
        A list of `len(source_shape) + 1` flattened numpy arrays, laid out
        in matrix ('ij') indexing order. The final array is the homogeneous
        row of ones. When `output_shape` has fewer dimensions than
        `source_shape`, the missing trailing axes are treated as size 1, so
        their rows are constant: -1 when `relative` is True, 0 otherwise.
    """
    padded_shape = list(output_shape)
    # Pad with singleton axes up to the source dimensionality.
    padded_shape.extend([1] * (len(source_shape) - len(output_shape)))
    if relative:
        axes = [np.linspace(-1., 1., size, dtype=np.float32)
                for size in padded_shape]
    else:
        axes = [np.arange(size, dtype=np.float32) for size in padded_shape]
    # Homogeneous coordinate.
    axes.append(np.array([1.0]))
    grids = np.meshgrid(*axes, indexing='ij')
    return [grid.ravel() for grid in grids]