Source code for niftynet.layer.deconvolution

# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function

import numpy as np
import tensorflow as tf

from niftynet.layer import layer_util
from niftynet.layer.activation import ActiLayer
from niftynet.layer.base_layer import TrainableLayer
from niftynet.layer.bn import BNLayer, InstanceNormLayer
from niftynet.layer.gn import GNLayer
from niftynet.utilities.util_common import look_up_operations

SUPPORTED_OP = {
    '2D': tf.nn.conv2d_transpose,
    '3D': tf.nn.conv3d_transpose}
SUPPORTED_PADDING = set(['SAME', 'VALID'])


def default_w_initializer():
    def _initializer(shape, dtype, partition_info):
        stddev = np.sqrt(2.0 / (np.prod(shape[:-2]) * shape[-1]))
        from tensorflow.python.ops import random_ops
        return random_ops.truncated_normal(
            shape, 0.0, stddev, dtype=tf.float32)
        # return tf.truncated_normal_initializer(
        #     mean=0.0, stddev=stddev, dtype=tf.float32)

    return _initializer


def default_b_initializer():
    return tf.constant_initializer(0.0)


def infer_output_dims(input_dims, strides, kernel_sizes, padding):
    """
    Infer output dims from a list of input dims;
    the dim can be different in different directions.
    Note: dilation is not considered here.
    """
    assert len(input_dims) == len(strides)
    assert len(input_dims) == len(kernel_sizes)
    output_dims = []
    for (i, dim) in enumerate(input_dims):
        if dim is None:
            output_dims.append(None)
            continue
        if padding == 'VALID':
            output_dims.append(
                dim * strides[i] + max(kernel_sizes[i] - strides[i], 0))
        else:
            output_dims.append(dim * strides[i])
    return output_dims


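# Worked example (added for illustration; the helper name below is not part
# of the original module). 'SAME' padding gives dim * stride, while 'VALID'
# adds the kernel overhang max(kernel - stride, 0); a None (unknown static)
# dim is propagated as None.
def _example_infer_output_dims():
    # 2-D input of spatial size [8, 8], kernel 3, stride 2
    same_dims = infer_output_dims([8, 8], [2, 2], [3, 3], 'SAME')    # [16, 16]
    valid_dims = infer_output_dims([8, 8], [2, 2], [3, 3], 'VALID')  # [17, 17]
    return same_dims, valid_dims

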
class DeconvLayer(TrainableLayer):
    """
    This class defines a simple deconvolution with an optional bias term.
    Please consider ``DeconvolutionalLayer`` if batch_norm and activation
    are also used.
    """

    def __init__(self,
                 n_output_chns,
                 kernel_size=3,
                 stride=1,
                 padding='SAME',
                 with_bias=False,
                 w_initializer=None,
                 w_regularizer=None,
                 b_initializer=None,
                 b_regularizer=None,
                 name='deconv'):
        super(DeconvLayer, self).__init__(name=name)

        self.padding = look_up_operations(padding.upper(), SUPPORTED_PADDING)
        self.n_output_chns = int(n_output_chns)
        self.kernel_size = kernel_size
        self.stride = stride
        self.with_bias = with_bias

        self.initializers = {
            'w': w_initializer if w_initializer else default_w_initializer(),
            'b': b_initializer if b_initializer else default_b_initializer()}
        self.regularizers = {'w': w_regularizer, 'b': b_regularizer}

    def layer_op(self, input_tensor):
        input_shape = input_tensor.shape.as_list()
        n_input_chns = input_shape[-1]
        spatial_rank = layer_util.infer_spatial_rank(input_tensor)

        # initialize conv kernels/strides and then apply
        kernel_size_all_dim = layer_util.expand_spatial_params(
            self.kernel_size, spatial_rank)
        w_full_size = kernel_size_all_dim + (
            self.n_output_chns, n_input_chns)
        stride_all_dim = layer_util.expand_spatial_params(
            self.stride, spatial_rank)
        full_stride = (1,) + stride_all_dim + (1,)

        deconv_kernel = tf.get_variable(
            'w', shape=w_full_size,
            initializer=self.initializers['w'],
            regularizer=self.regularizers['w'])
        if spatial_rank == 2:
            op_ = SUPPORTED_OP['2D']
        elif spatial_rank == 3:
            op_ = SUPPORTED_OP['3D']
        else:
            raise ValueError(
                "Only 2D and 3D spatial deconvolutions are supported")

        spatial_shape = []
        for (i, dim) in enumerate(input_shape[:-1]):
            if i == 0:
                continue
            if dim is None:
                spatial_shape.append(tf.shape(input_tensor)[i])
            else:
                spatial_shape.append(dim)
        output_dims = infer_output_dims(spatial_shape,
                                        stride_all_dim,
                                        kernel_size_all_dim,
                                        self.padding)
        if input_tensor.shape.is_fully_defined():
            full_output_size = \
                [input_shape[0]] + output_dims + [self.n_output_chns]
        else:
            batch_size = tf.shape(input_tensor)[0]
            full_output_size = tf.stack(
                [batch_size] + output_dims + [self.n_output_chns])
        output_tensor = op_(value=input_tensor,
                            filter=deconv_kernel,
                            output_shape=full_output_size,
                            strides=full_stride,
                            padding=self.padding,
                            name='deconv')
        if not self.with_bias:
            return output_tensor

        # adding the bias term
        bias_full_size = (self.n_output_chns,)
        bias_term = tf.get_variable(
            'b', shape=bias_full_size,
            initializer=self.initializers['b'],
            regularizer=self.regularizers['b'])
        output_tensor = tf.nn.bias_add(output_tensor,
                                       bias_term,
                                       name='add_bias')
        return output_tensor


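# Usage sketch (added for illustration; the helper name below is not part of
# the original module). It assumes the usual NiftyNet/TF1 graph-mode
# convention in which calling the layer object builds the op under the
# layer's variable scope.
def _example_deconv_layer():
    # a 3-D feature map: batch=1, spatial 16x16x16, 8 channels
    volume = tf.placeholder(tf.float32, shape=(1, 16, 16, 16, 8))
    # a stride-2 transpose convolution doubles each spatial dim with 'SAME'
    deconv = DeconvLayer(n_output_chns=4, kernel_size=3, stride=2)
    return deconv(volume)  # static shape: (1, 32, 32, 32, 4)

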
class DeconvolutionalLayer(TrainableLayer):
    """
    This class defines a composite layer with optional components::

        deconvolution -> batch_norm -> activation -> dropout

    The b_initializer and b_regularizer are applied to the DeconvLayer.
    The w_initializer and w_regularizer are applied to the DeconvLayer,
    the batch normalisation layer, and the activation layer (for 'prelu').
    """

    def __init__(self,
                 n_output_chns,
                 kernel_size=3,
                 stride=1,
                 padding='SAME',
                 with_bias=False,
                 feature_normalization='batch',
                 group_size=-1,
                 acti_func=None,
                 w_initializer=None,
                 w_regularizer=None,
                 b_initializer=None,
                 b_regularizer=None,
                 moving_decay=0.9,
                 eps=1e-5,
                 name="deconv"):

        self.acti_func = acti_func
        self.feature_normalization = feature_normalization
        self.group_size = group_size
        self.layer_name = '{}'.format(name)
        if self.feature_normalization != 'group' and group_size > 0:
            raise ValueError(
                'You cannot have a group_size > 0 if not using group norm')
        elif self.feature_normalization == 'group' and group_size <= 0:
            raise ValueError(
                'You cannot have a group_size <= 0 if using group norm')
        if self.feature_normalization is not None:
            # appending, for example, '_bn' to the name
            self.layer_name += '_' + self.feature_normalization[0] + 'n'
        if self.acti_func is not None:
            self.layer_name += '_{}'.format(self.acti_func)
        super(DeconvolutionalLayer, self).__init__(name=self.layer_name)

        # for DeconvLayer
        self.n_output_chns = n_output_chns
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.with_bias = with_bias

        # for BNLayer
        self.moving_decay = moving_decay
        self.eps = eps

        self.initializers = {
            'w': w_initializer if w_initializer else default_w_initializer(),
            'b': b_initializer if b_initializer else default_b_initializer()}
        self.regularizers = {'w': w_regularizer, 'b': b_regularizer}

    def layer_op(self, input_tensor, is_training=None, keep_prob=None):
        # init sub-layers
        deconv_layer = DeconvLayer(n_output_chns=self.n_output_chns,
                                   kernel_size=self.kernel_size,
                                   stride=self.stride,
                                   padding=self.padding,
                                   with_bias=self.with_bias,
                                   w_initializer=self.initializers['w'],
                                   w_regularizer=self.regularizers['w'],
                                   b_initializer=self.initializers['b'],
                                   b_regularizer=self.regularizers['b'],
                                   name='deconv_')
        output_tensor = deconv_layer(input_tensor)

        if self.feature_normalization == 'batch':
            if is_training is None:
                raise ValueError('is_training argument should be '
                                 'True or False when feature_normalization '
                                 'is "batch"')
            bn_layer = BNLayer(
                regularizer=self.regularizers['w'],
                moving_decay=self.moving_decay,
                eps=self.eps,
                name='bn_')
            output_tensor = bn_layer(output_tensor, is_training)
        elif self.feature_normalization == 'instance':
            in_layer = InstanceNormLayer(eps=self.eps, name='in_')
            output_tensor = in_layer(output_tensor)
        elif self.feature_normalization == 'group':
            gn_layer = GNLayer(
                regularizer=self.regularizers['w'],
                group_size=self.group_size,
                eps=self.eps,
                name='gn_')
            output_tensor = gn_layer(output_tensor)
        if self.acti_func is not None:
            acti_layer = ActiLayer(
                func=self.acti_func,
                regularizer=self.regularizers['w'],
                name='acti_')
            output_tensor = acti_layer(output_tensor)
        if keep_prob is not None:
            dropout_layer = ActiLayer(func='dropout', name='dropout_')
            output_tensor = dropout_layer(output_tensor, keep_prob=keep_prob)
        return output_tensor


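# Usage sketch (added for illustration; the helper name below is not part of
# the original module). It chains deconvolution -> batch_norm -> relu ->
# dropout; is_training must be provided because feature_normalization
# defaults to 'batch'.
def _example_deconvolutional_layer():
    # a 2-D feature map: batch=2, spatial 32x32, 16 channels
    features = tf.placeholder(tf.float32, shape=(2, 32, 32, 16))
    block = DeconvolutionalLayer(n_output_chns=8,
                                 kernel_size=3,
                                 stride=2,
                                 acti_func='relu')
    # static output shape with 'SAME' padding: (2, 64, 64, 8)
    return block(features, is_training=True, keep_prob=0.8)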