# Source code for niftynet.layer.spatial_gradient
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
import tensorflow as tf
from niftynet.layer.base_layer import Layer
from niftynet.layer.layer_util import infer_spatial_rank
class SpatialGradientLayer(Layer):
    """
    Computing image spatial gradients.

    The gradient is obtained by subtracting two crops of the input that are
    shifted by two elements relative to each other along ``spatial_axis``.
    """

    def __init__(self,
                 spatial_axis=0,
                 do_cropping=True,
                 name='spatial_gradient'):
        """
        :param spatial_axis: index of the spatial dimension to differentiate
            along, counted after the batch dimension (``0`` is the first
            spatial dimension); converted with ``int()``.
        :param do_cropping: if True, two elements are removed from every
            spatial dimension; otherwise only from ``spatial_axis``.
        :param name: layer name forwarded to the base ``Layer``.
        """
        Layer.__init__(self, name=name)
        self.spatial_axis = int(spatial_axis)
        self.do_cropping = do_cropping

    def layer_op(self, input_tensor):
        """
        Computing spatial gradient of ``input_tensor`` along
        ``self.spatial_axis``.

        output is equivalent to convolve along ``spatial_axis`` with a
        kernel: ``[-1, 0, 1]``

        This layer assumes the first and the last dimension of the input
        tensor represent batch and feature channels.
        Therefore ``spatial_axis=1`` is computing gradient along the
        third dimension of input tensor, i.e., ``input_tensor[:, :, y, ...]``

        Given the input with shape ``[B, X, Y, Z, C]``, and ``spatial_axis=1``
        the output shape is::

            [B, X-2, Y-2, Z-2, C] if do_cropping is True
            [B, X, Y-2, Z, C] otherwise

        Setting do_cropping to True makes the output tensor has the same
        dimensionality for different ``spatial_axis``.

        :param input_tensor: a batch of images with a shape of
            ``[Batch, x[, y, z, ... ], Channel]``
        :return: spatial gradients of ``input_tensor``
        :raises ValueError: if a spatial dimension is not statically known,
            or if a dimension that has to be shrunk has fewer than two
            elements.
        """
        spatial_rank = infer_spatial_rank(input_tensor)
        spatial_size = input_tensor.get_shape().as_list()[1:-1]
        axis = self.spatial_axis

        # The slice sizes below are computed in Python, so every spatial
        # dimension has to be known at graph-construction time.
        if any(dim is None for dim in spatial_size):
            raise ValueError(
                'SpatialGradientLayer requires statically known spatial '
                'dimensions, got spatial shape {}.'.format(spatial_size))

        if self.do_cropping:
            # remove two elements in all spatial dims
            spatial_size = [dim - 2 for dim in spatial_size]
            spatial_begins = [1] * spatial_rank
        else:
            # remove two elements along the gradient dim only
            spatial_size[axis] = spatial_size[axis] - 2
            spatial_begins = [0] * spatial_rank

        # A negative size must never reach tf.slice: in particular ``-1``
        # means "all remaining elements" there, so a dimension of length 1
        # would silently produce a wrongly shaped result.
        if any(dim < 0 for dim in spatial_size):
            raise ValueError(
                'SpatialGradientLayer needs at least two elements in every '
                'spatial dimension it shrinks, got shrunk spatial shape '
                '{}.'.format(spatial_size))

        # Both crops share the same offsets except along the gradient axis,
        # where they start two elements apart.
        begins_0 = list(spatial_begins)
        begins_0[axis] = 2
        begins_1 = list(spatial_begins)
        begins_1[axis] = 0
        begins_0 = [0] + begins_0 + [0]
        begins_1 = [0] + begins_1 + [0]

        # ``-1`` keeps the full batch and channel dimensions.
        sizes = [-1] + spatial_size + [-1]

        image_gradients = \
            tf.slice(input_tensor, begins_0, sizes) - \
            tf.slice(input_tensor, begins_1, sizes)
        return image_gradients