Source code for niftynet.layer.deconvolution

# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function

import numpy as np
import tensorflow as tf

from niftynet.layer import layer_util
from niftynet.layer.activation import ActiLayer
from niftynet.layer.base_layer import TrainableLayer
from import BNLayer
from import GNLayer
from niftynet.utilities.util_common import look_up_operations

    '2D': tf.nn.conv2d_transpose,
    '3D': tf.nn.conv3d_transpose}

def default_w_initializer():
    """
    Return the default kernel initializer for deconvolution layers.

    The returned callable follows the TensorFlow initializer calling
    convention ``(shape, dtype, partition_info)`` and draws from a truncated
    normal with mean 0 and standard deviation ``sqrt(2 / n)`` (He-style
    scaling). Here ``n = prod(shape[:-2]) * shape[-1]``, i.e. the product of
    the spatial kernel sizes and the last kernel axis.

    :return: initializer callable
    """
    def _initializer(shape, dtype=None, partition_info=None):
        # partition_info is accepted only for TF API compatibility; unused.
        fan = np.prod(shape[:-2]) * shape[-1]
        stddev = np.sqrt(2.0 / fan)
        # local import kept from the original implementation
        from tensorflow.python.ops import random_ops
        # honour the variable's dtype; the original always produced float32,
        # which mismatches non-float32 variables
        out_dtype = dtype if dtype is not None else tf.float32
        return random_ops.truncated_normal(
            shape, 0.0, stddev, dtype=out_dtype)
    return _initializer
def default_b_initializer():
    """
    Build the default bias initializer.

    :return: a TensorFlow constant initializer filling every element with 0.0
    """
    initial_value = 0.0
    return tf.constant_initializer(initial_value)
def infer_output_dims(input_dims, strides, kernel_sizes, padding):
    """
    Infer the spatial output sizes of a transposed convolution.

    Sizes are computed per direction, so strides and kernel sizes may differ
    between directions. Note: dilation is not considered here.

    For ``'VALID'`` padding each size is
    ``dim * stride + max(kernel - stride, 0)``; for any other padding it is
    ``dim * stride``. ``None`` entries (unknown sizes) are passed through as
    ``None``. Entries may also be values supporting ``*`` and ``+`` with
    ints (e.g. dynamic shape tensors).

    :param input_dims: sized sequence of input spatial sizes
    :param strides: per-direction strides, same length as ``input_dims``
    :param kernel_sizes: per-direction kernel sizes, same length as
        ``input_dims``
    :param padding: padding mode; compared case-insensitively with 'VALID'
    :return: list of output spatial sizes
    :raises ValueError: if the three sequences differ in length
    """
    # explicit check: `assert` would be stripped under `python -O`
    if len(input_dims) != len(strides) or \
            len(input_dims) != len(kernel_sizes):
        raise ValueError(
            'input_dims, strides and kernel_sizes must have the same '
            'length, got {}, {} and {}'.format(
                len(input_dims), len(strides), len(kernel_sizes)))
    is_valid = str(padding).upper() == 'VALID'
    output_dims = []
    for dim, stride, kernel in zip(input_dims, strides, kernel_sizes):
        if dim is None:
            output_dims.append(None)
        elif is_valid:
            output_dims.append(dim * stride + max(kernel - stride, 0))
        else:
            output_dims.append(dim * stride)
    return output_dims
class DeconvLayer(TrainableLayer):
    """
    A plain transposed convolution, optionally followed by a bias add.

    No normalisation or activation is applied here; use
    ``DeconvolutionalLayer`` when batch_norm and an activation are needed.
    """

    def __init__(self,
                 n_output_chns,
                 kernel_size=3,
                 stride=1,
                 padding='SAME',
                 with_bias=False,
                 w_initializer=None,
                 w_regularizer=None,
                 b_initializer=None,
                 b_regularizer=None,
                 name='deconv'):
        """
        :param n_output_chns: number of output channels (cast to ``int``)
        :param kernel_size: kernel size, later expanded over the spatial
            dimensions by ``layer_util.expand_spatial_params``
        :param stride: stride, later expanded the same way as kernel_size
        :param padding: padding mode; upper-cased and checked against
            ``SUPPORTED_PADDING``
        :param with_bias: whether a learned bias is added after deconvolving
        :param w_initializer: kernel initializer; ``default_w_initializer()``
            is used when this is falsy
        :param w_regularizer: kernel regularizer (may be None)
        :param b_initializer: bias initializer; ``default_b_initializer()``
            is used when this is falsy
        :param b_regularizer: bias regularizer (may be None)
        :param name: layer name forwarded to ``TrainableLayer``
        """
        super(DeconvLayer, self).__init__(name=name)

        # upper-case first so that lower-case spellings are accepted
        self.padding = look_up_operations(padding.upper(), SUPPORTED_PADDING)
        self.n_output_chns = int(n_output_chns)
        self.kernel_size = kernel_size
        self.stride = stride
        self.with_bias = with_bias

        w_init = w_initializer if w_initializer else default_w_initializer()
        b_init = b_initializer if b_initializer else default_b_initializer()
        self.initializers = {'w': w_init, 'b': b_init}
        self.regularizers = {'w': w_regularizer, 'b': b_regularizer}

    def layer_op(self, input_tensor):
        """
        Deconvolve ``input_tensor`` and optionally add the bias.

        Axis 0 is treated as the batch axis and the last axis as the channel
        axis. The output shape is a static list when the input shape is
        fully defined, otherwise it is assembled with ``tf.stack`` from
        dynamic ``tf.shape`` entries.

        :param input_tensor: input tensor with spatial rank 2 or 3
        :return: deconvolved tensor with ``n_output_chns`` channels
        :raises ValueError: if the spatial rank is neither 2 nor 3
        """
        static_shape = input_tensor.shape.as_list()
        n_input_chns = static_shape[-1]
        spatial_rank = layer_util.infer_spatial_rank(input_tensor)

        # kernel layout: spatial sizes + (output channels, input channels)
        kernel_dims = layer_util.expand_spatial_params(
            self.kernel_size, spatial_rank)
        kernel_shape = kernel_dims + (self.n_output_chns, n_input_chns)
        stride_dims = layer_util.expand_spatial_params(
            self.stride, spatial_rank)
        # unit stride over the batch and channel axes
        padded_strides = (1,) + stride_dims + (1,)

        deconv_kernel = tf.get_variable(
            'w',
            shape=kernel_shape,
            initializer=self.initializers['w'],
            regularizer=self.regularizers['w'])

        op_key = {2: '2D', 3: '3D'}.get(spatial_rank)
        if op_key is None:
            raise ValueError(
                "Only 2D and 3D spatial deconvolutions are supported")
        deconv_op = SUPPORTED_OP[op_key]

        # static spatial sizes where known, dynamic tf.shape entries otherwise
        spatial_dims = [
            tf.shape(input_tensor)[axis] if static_shape[axis] is None
            else static_shape[axis]
            for axis in range(1, len(static_shape) - 1)]
        output_dims = infer_output_dims(
            spatial_dims, stride_dims, kernel_dims, self.padding)

        channel_tail = output_dims + [self.n_output_chns]
        if input_tensor.shape.is_fully_defined():
            output_shape = [static_shape[0]] + channel_tail
        else:
            output_shape = tf.stack(
                [tf.shape(input_tensor)[0]] + channel_tail)

        result = deconv_op(value=input_tensor,
                           filter=deconv_kernel,
                           output_shape=output_shape,
                           strides=padded_strides,
                           padding=self.padding,
                           name='deconv')

        if self.with_bias:
            bias = tf.get_variable(
                'b',
                shape=(self.n_output_chns,),
                initializer=self.initializers['b'],
                regularizer=self.regularizers['b'])
            result = tf.nn.bias_add(result, bias, name='add_bias')
        return result
class DeconvolutionalLayer(TrainableLayer):
    """
    Composite layer chaining optional components::

        deconvolution -> batch_norm or group_norm -> activation -> dropout

    ``w_initializer``, ``b_initializer`` and ``b_regularizer`` are forwarded
    to the ``DeconvLayer`` only. ``w_regularizer`` is forwarded to the
    ``DeconvLayer``, to the normalisation layer (batch or group norm) and
    to the activation layer (relevant for 'prelu').
    """

    def __init__(self,
                 n_output_chns,
                 kernel_size=3,
                 stride=1,
                 padding='SAME',
                 with_bias=False,
                 with_bn=True,
                 group_size=-1,
                 acti_func=None,
                 w_initializer=None,
                 w_regularizer=None,
                 b_initializer=None,
                 b_regularizer=None,
                 moving_decay=0.9,
                 eps=1e-5,
                 name="deconv"):
        """
        :param with_bn: apply batch normalisation after the deconvolution
        :param group_size: apply group normalisation when > 0
        :param acti_func: activation name, or None for no activation
        :param moving_decay: forwarded to ``BNLayer``
        :param eps: forwarded to ``BNLayer`` and ``GNLayer``
        :param name: base name; '_bn', '_gn' and '_<acti_func>' are appended
            for the enabled components
        :raises ValueError: if both batch norm and group norm are requested

        The remaining parameters are forwarded to ``DeconvLayer``.
        """
        self.acti_func = acti_func
        self.with_bn = with_bn
        self.group_size = group_size

        # batch norm and group norm are mutually exclusive
        if with_bn and group_size > 0:
            raise ValueError('only choose either batchnorm or groupnorm')

        name_parts = ['{}'.format(name)]
        if with_bn:
            name_parts.append('_bn')
        if group_size > 0:
            name_parts.append('_gn')
        if acti_func is not None:
            name_parts.append('_{}'.format(acti_func))
        self.layer_name = ''.join(name_parts)

        super(DeconvolutionalLayer, self).__init__(name=self.layer_name)

        # settings forwarded to DeconvLayer
        self.n_output_chns = n_output_chns
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.with_bias = with_bias

        # settings forwarded to BNLayer (eps is also used by GNLayer)
        self.moving_decay = moving_decay
        self.eps = eps

        w_init = w_initializer if w_initializer else default_w_initializer()
        b_init = b_initializer if b_initializer else default_b_initializer()
        self.initializers = {'w': w_init, 'b': b_init}
        self.regularizers = {'w': w_regularizer, 'b': b_regularizer}

    def layer_op(self, input_tensor, is_training=None, keep_prob=None):
        """
        Run the enabled sub-layers in order on ``input_tensor``.

        :param input_tensor: input tensor
        :param is_training: forwarded to ``BNLayer``; required when
            ``with_bn`` is True
        :param keep_prob: when not None, dropout (``ActiLayer`` with
            func='dropout') is applied with this keep probability
        :return: output tensor
        :raises ValueError: if ``with_bn`` is True and ``is_training`` is None
        """
        w_regularizer = self.regularizers['w']

        deconv_layer = DeconvLayer(
            n_output_chns=self.n_output_chns,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            with_bias=self.with_bias,
            w_initializer=self.initializers['w'],
            w_regularizer=w_regularizer,
            b_initializer=self.initializers['b'],
            b_regularizer=self.regularizers['b'],
            name='deconv_')
        result = deconv_layer(input_tensor)

        if self.with_bn:
            if is_training is None:
                raise ValueError('is_training argument should be '
                                 'True or False unless with_bn is False')
            bn_layer = BNLayer(regularizer=w_regularizer,
                               moving_decay=self.moving_decay,
                               eps=self.eps,
                               name='bn_')
            result = bn_layer(result, is_training)

        if self.group_size > 0:
            gn_layer = GNLayer(regularizer=w_regularizer,
                               group_size=self.group_size,
                               eps=self.eps,
                               name='gn_')
            result = gn_layer(result)

        if self.acti_func is not None:
            acti_layer = ActiLayer(func=self.acti_func,
                                   regularizer=w_regularizer,
                                   name='acti_')
            result = acti_layer(result)

        if keep_prob is not None:
            dropout_layer = ActiLayer(func='dropout', name='dropout_')
            result = dropout_layer(result, keep_prob=keep_prob)

        return result