Source code for niftynet.layer.crop
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
import tensorflow as tf
from niftynet.layer import layer_util
from niftynet.layer.base_layer import Layer
class CropLayer(Layer):
    """
    This class defines a cropping operation:
    Removing ``2*border`` pixels from each spatial dim of the input,
    and return the spatially centred elements extracted from the input.

    The first and the last dimensions of the input are never cropped
    (presumably batch and channel axes -- confirm against callers).
    """

    def __init__(self, border, name='crop'):
        """
        :param border: number of elements removed from *each side* of every
            spatial dimension; converted with ``int()`` when the layer is
            applied.
        :param name: layer name forwarded to the base ``Layer``.
        """
        super(CropLayer, self).__init__(name=name)
        self.border = border

    def layer_op(self, inputs):
        """
        Extract the spatially centred region of ``inputs``.

        :param inputs: tensor whose dimensions between the first and the
            last one are the spatial dimensions; these are converted with
            ``int()``, so they need to be statically known.
        :return: the result of ``tf.slice`` with ``border`` elements removed
            from both ends of every spatial dimension.
        :raises ValueError: if ``border`` is negative, or if ``2*border``
            exceeds the size of a spatial dimension.
        """
        spatial_rank = layer_util.infer_spatial_rank(inputs)
        border = int(self.border)
        if border < 0:
            raise ValueError(
                'border should be non-negative, got {}'.format(border))
        offsets = [0] + [border] * spatial_rank + [0]

        # inferring the shape of the output by subtracting the border dimension
        spatial_shape = [int(d) for d in list(inputs.shape)[1:-1]]
        out_shape = [d - 2 * border for d in spatial_shape]

        # ``tf.slice`` interprets a size of -1 as "all remaining elements",
        # so a computed size of exactly -1 would silently produce a wrongly
        # sized output instead of an error; reject every negative size here.
        if any(size < 0 for size in out_shape):
            raise ValueError(
                'border {} is too large for the spatial shape {}'.format(
                    border, spatial_shape))

        # -1 keeps the first and the last dimensions in full.
        out_shape = [-1] + out_shape + [-1]
        output_tensor = tf.slice(inputs, offsets, out_shape)
        return output_tensor