Source code for niftynet.layer.crop

# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function

import numpy as np
import tensorflow as tf

from niftynet.layer import layer_util
from niftynet.layer.base_layer import Layer


class CropLayer(Layer):
    """
    Spatial cropping layer.

    Trims ``border`` elements from both ends of every spatial dimension
    (i.e. ``2 * border`` elements per dimension in total) and returns the
    spatially centred part of the input.

    The crop is realised as a convolution in ``VALID`` mode with a trivial
    kernel, applied to each channel separately.
    """

    def __init__(self, border, name='crop'):
        """
        :param border: number of elements to remove from each side of
            every spatial dimension
        :param name: layer name, forwarded to the base ``Layer``
        """
        super(CropLayer, self).__init__(name=name)
        self.border = border

    def layer_op(self, inputs):
        """
        Crop ``inputs`` by ``self.border`` along each spatial dimension.

        :param inputs: tensor to crop; the last axis is treated as the
            channel axis (it is split and re-joined around the convolution)
        :return: tensor holding the cropped channels, concatenated along
            the last axis
        """
        spatial_rank = layer_util.infer_spatial_rank(inputs)

        # A kernel spanning `2 * border + 1` elements per spatial dim, with
        # one input and one output channel, shrinks every spatial dim by
        # `2 * border` when convolved in VALID mode.
        kernel_width = self.border * 2 + 1
        kernel_shape = np.hstack(
            ([kernel_width] * spatial_rank, 1, 1)).flatten()

        # kernel of all 0s with the central element set to 1
        identity_kernel = layer_util.trivial_kernel(kernel_shape)
        crop_kernel = tf.constant(identity_kernel, dtype=inputs.dtype)

        # split the channel dim so that every channel becomes a
        # single-channel tensor matching the kernel's input depth of 1
        single_channels = [
            tf.expand_dims(channel, -1)
            for channel in tf.unstack(inputs, axis=-1)]

        unit_strides = [1] * spatial_rank
        cropped_channels = []
        for channel in single_channels:
            cropped_channels.append(
                tf.nn.convolution(input=channel,
                                  filter=crop_kernel,
                                  strides=unit_strides,
                                  padding='VALID',
                                  name='conv'))

        # re-assemble the channel dim
        return tf.concat(cropped_channels, axis=-1)