Source code for niftynet.layer.upsample

# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function

import numpy as np
import tensorflow as tf

from niftynet.layer import layer_util
from niftynet.layer.base_layer import TrainableLayer
from niftynet.layer.deconvolution import DeconvLayer
from niftynet.utilities.util_common import look_up_operations

SUPPORTED_OP = set(['REPLICATE', 'CHANNELWISE_DECONV'])


[docs]class UpSampleLayer(TrainableLayer): """ This class defines channel-wise upsampling operations. Different from ``DeconvLayer``, the elements are not mixed in the channel dim. ``REPLICATE`` mode replicates each spatial_dim into ``spatial_dim*kernel_size`` `CHANNELWISE_DECONV`` mode makes a projection using a kernel. e.g., With 2D input (without loss of generality), given input ``[N, X, Y, C]``, the output is ``[N, X*kernel_size, Y*kernel_size, C]``. """ def __init__(self, func, kernel_size=3, stride=2, w_initializer=None, w_regularizer=None, with_bias=False, b_initializer=None, b_regularizer=None, name='upsample'): self.func = look_up_operations(func.upper(), SUPPORTED_OP) self.layer_name = '{}_{}'.format(self.func.lower(), name) super(UpSampleLayer, self).__init__(name=self.layer_name) self.kernel_size = kernel_size self.stride = stride self.with_bias = with_bias self.initializers = {'w': w_initializer, 'b': b_initializer} self.regularizers = {'w': w_regularizer, 'b': b_regularizer}
[docs] def layer_op(self, input_tensor): spatial_rank = layer_util.infer_spatial_rank(input_tensor) output_tensor = input_tensor if self.func == 'REPLICATE': if self.kernel_size != self.stride: raise ValueError( "`kernel_size` != `stride` currently not" "supported in `REPLICATE` mode. Please" "consider using `CHANNELWISE_DECONV` operation.") # simply replicate input values to # local regions of (kernel_size ** spatial_rank) element kernel_size_all_dims = layer_util.expand_spatial_params( self.kernel_size, spatial_rank) pixel_num = np.prod(kernel_size_all_dims) repmat = np.hstack((pixel_num, [1] * spatial_rank, 1)).flatten() output_tensor = tf.tile(input=input_tensor, multiples=repmat) output_tensor = tf.batch_to_space_nd( input=output_tensor, block_shape=kernel_size_all_dims, crops=[[0, 0]] * spatial_rank) elif self.func == 'CHANNELWISE_DECONV': output_tensor = [tf.expand_dims(x, -1) for x in tf.unstack(input_tensor, axis=-1)] output_tensor = [DeconvLayer(n_output_chns=1, kernel_size=self.kernel_size, stride=self.stride, padding='SAME', with_bias=self.with_bias, w_initializer=self.initializers['w'], w_regularizer=self.regularizers['w'], b_initializer=self.initializers['b'], b_regularizer=self.regularizers['b'], name='deconv_{}'.format(i))(x) for (i, x) in enumerate(output_tensor)] output_tensor = tf.concat(output_tensor, axis=-1) return output_tensor