Source code for niftynet.layer.downsample
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
import tensorflow as tf
from niftynet.layer import layer_util
from niftynet.layer.base_layer import Layer
from niftynet.utilities.util_common import look_up_operations
# Down-sampling functions accepted by DownSampleLayer (compared upper-cased).
SUPPORTED_OP = {'AVG', 'MAX', 'CONSTANT'}
# Padding modes accepted by DownSampleLayer (compared upper-cased).
SUPPORTED_PADDING = {'SAME', 'VALID'}
class DownSampleLayer(Layer):
    """Down-sampling layer built on ``tf.nn.pool`` or a fixed convolution.

    The operation is selected by ``func`` (case-insensitive); the accepted
    values are those listed in ``SUPPORTED_OP``:

    * ``'CONSTANT'`` -- every slice along the last axis of the input is
      convolved separately with one constant kernel obtained from
      ``layer_util.trivial_kernel``.
    * any other supported value -- passed to ``tf.nn.pool`` as
      ``pooling_type``.
    """

    def __init__(self,
                 func,
                 kernel_size=3,
                 stride=2,
                 padding='SAME',
                 name='pooling'):
        """Store the down-sampling configuration.

        :param func: name of the down-sampling function; stored upper-cased.
            It is only checked against ``SUPPORTED_OP`` when ``layer_op``
            runs, not here.
        :param kernel_size: window/kernel size, expanded to all spatial
            dimensions by ``layer_util.expand_spatial_params`` in
            ``layer_op``.
        :param stride: stride, expanded the same way as ``kernel_size``.
        :param padding: padding mode; upper-cased and checked against
            ``SUPPORTED_PADDING`` with ``look_up_operations``
            (presumably raises on an unsupported value -- confirm in
            ``util_common``).
        :param name: suffix of the layer name; the full name is
            ``'<func lower-cased>_<name>'``.
        """
        self.func = func.upper()
        # The layer name must exist before the base-class initialiser runs,
        # because it is handed over as the layer's name.
        self.layer_name = '{}_{}'.format(self.func.lower(), name)
        super(DownSampleLayer, self).__init__(name=self.layer_name)

        self.padding = padding.upper()
        look_up_operations(self.padding, SUPPORTED_PADDING)

        self.kernel_size = kernel_size
        self.stride = stride

    def layer_op(self, input_tensor):
        """Apply the configured down-sampling to ``input_tensor``.

        :param input_tensor: tensor to down-sample. Assumed to be
            channels-last, since the ``'CONSTANT'`` path splits and
            re-joins it along the last axis -- TODO confirm with callers.
        :return: the down-sampled tensor.
        """
        rank = layer_util.infer_spatial_rank(input_tensor)
        look_up_operations(self.func, SUPPORTED_OP)
        window = layer_util.expand_spatial_params(self.kernel_size, rank)
        strides = layer_util.expand_spatial_params(self.stride, rank)

        if self.func != 'CONSTANT':
            # Pooling path: ``self.func`` is used directly as the pooling
            # type, with no dilation.
            return tf.nn.pool(
                input=input_tensor,
                window_shape=window,
                pooling_type=self.func,
                padding=self.padding,
                dilation_rate=[1] * rank,
                strides=strides,
                name=self.layer_name)

        # Constant-kernel path. Two trailing singleton dimensions are
        # appended to the spatial kernel size, so the kernel applies to
        # inputs whose last dimension is 1.
        kernel = tf.constant(
            layer_util.trivial_kernel(window + (1, 1)),
            dtype=tf.float32)
        # Split along the last axis and give each slice back a trailing
        # dimension of size 1.
        channels = [
            tf.expand_dims(channel, -1)
            for channel in tf.unstack(input_tensor, axis=-1)]
        # The same kernel is applied to every slice independently.
        downsampled = [
            tf.nn.convolution(
                input=channel,
                filter=kernel,
                strides=strides,
                padding=self.padding,
                name='conv')
            for channel in channels]
        return tf.concat(downsampled, axis=-1)