Source code for niftynet.contrib.niftyreg_image_resampling.niftyreg_image_resampling

from __future__ import print_function, division

import tensorflow as tf
from tensorflow.python.framework import ops
from niftynet.layer.base_layer import Layer
from niftynet.layer.layer_util import infer_spatial_rank

from niftynet.contrib.niftyreg_image_resampling.niftyreg_module_loader import get_niftyreg_module


# NiftyNet boundary types to NiftyReg code mapping
__BOUNDARY_CODES__ = {
    'ZERO': 0,
    'NAN': 1,
    'REPLICATE': 2,
    'SYMMETRIC': 3
}


# Exposure of supported boundary types for compat. w/ ResamplerLayer
SUPPORTED_BOUNDARY = {k for k in __BOUNDARY_CODES__}


# NiftyNet interpolation types to NiftyReg code mapping
__INTERPOLATION_CODES__ = {'NEAREST': 0,
                           'LINEAR': 1,
                           'BSPLINE': 3}


# Exposure of supported interpolation types for compat. w/ ResamplerLayer
SUPPORTED_INTERPOLATION = {k for k in __INTERPOLATION_CODES__}


# NiftyReg expects displacement components to be
# indexed w/ slowest index
def _transpose(data):
    nof_dims = len(data.shape) - 1
    perm = [0] + list(range(nof_dims, 0, -1))
    perm += list(range(nof_dims + 1, len(data.shape)))
    assert len(perm) == len(data.shape)

    return tf.transpose(data, perm)


@ops.RegisterGradient("NiftyregImageResampling")
def _niftyreg_resampling_grad(op, grad):
    grad_op = get_niftyreg_module().niftyreg_image_resampling_gradient(
        op.inputs[0],
        op.inputs[1],
        interpolation=op.get_attr('interpolation'),
        boundary=op.get_attr('boundary'))

    chained_grad = None

    nof_modalities = op.inputs[0].shape.as_list()[1]
    if not nof_modalities is None and nof_modalities != 1:
        nof_dims = op.inputs[1].shape.as_list()[1]

        assert grad_op.shape.as_list()[1] == nof_modalities*nof_dims

        chained_grads = []
        for m in range(nof_modalities):
            mod_grad = grad_op[:,(nof_dims*m):((m+1)*nof_dims),...]
            out_mod_grad = tf.expand_dims(grad[:,m,...], axis=1)

            out_grad = tf.tile(out_mod_grad, [1] + [nof_dims]
                               + [1]*(len(grad_op.shape) - 2))

            chained_grads.append(tf.multiply(mod_grad, out_grad))

        chained_grad = tf.reduce_sum(tf.stack(chained_grads, axis=0), axis=0)

    else:
        grad_rep = tf.tile(grad, [1] + [grad_op.shape[1]]
                           + [1]*(len(grad_op.shape) - 2))
        chained_grad = tf.multiply(grad_rep, grad_op)

    image_grad_op \
        = get_niftyreg_module().niftyreg_image_resampling_image_gradient(
            op.inputs[0],
            op.inputs[1],
            grad,
            interpolation=op.get_attr('interpolation'),
            boundary=op.get_attr('boundary'))

    return [image_grad_op, chained_grad]


[docs]class NiftyregImageResamplingLayer(Layer): def __init__(self, interpolation, boundary='ZERO', **kwargs): super(NiftyregImageResamplingLayer, self).__init__(**kwargs) self._interpolation = __INTERPOLATION_CODES__[interpolation.upper()] self._boundary = boundary.upper()
[docs] def layer_op(self, inputs, deformation, **kwargs): nof_dims = infer_spatial_rank(inputs) nof_output_dims = infer_spatial_rank(deformation) batch_size = inputs.shape.as_list()[0] if deformation.shape.as_list()[0] != batch_size: deformation = tf.tile(deformation, [batch_size] + [1]*(nof_output_dims + 1)) output_spatial_dims = deformation.shape.as_list()[1:-1] input_dims = [d if d else -1 for d in inputs.shape.as_list()] if len(output_spatial_dims) != nof_dims: resample_def = deformation while len(resample_def.shape) < len(inputs.shape): resample_def = tf.expand_dims(resample_def, axis=len(resample_def.shape) - 2) else: resample_def = deformation assert infer_spatial_rank(resample_def) == nof_dims resampled = get_niftyreg_module().niftyreg_image_resampling( _transpose(inputs), _transpose(resample_def), interpolation=self._interpolation, boundary=__BOUNDARY_CODES__[self._boundary]) return tf.reshape( _transpose(resampled), [batch_size] + output_spatial_dims + [input_dims[-1]])