# Source code for niftynet.layer.deconvolution

# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function

import numpy as np
import tensorflow as tf

from niftynet.layer import layer_util
from niftynet.layer.activation import ActiLayer
from niftynet.layer.base_layer import TrainableLayer
from niftynet.layer.bn import BNLayer
from niftynet.utilities.util_common import look_up_operations

# Transposed-convolution op to use, keyed by spatial rank ('2D' / '3D').
SUPPORTED_OP = {
    '2D': tf.nn.conv2d_transpose,
    '3D': tf.nn.conv3d_transpose,
}
# Padding modes accepted by the deconvolution layers in this module.
SUPPORTED_PADDING = {'SAME', 'VALID'}


def default_w_initializer():
    """
    Return the default initializer for deconvolution kernels.

    The returned callable samples a truncated normal distribution with mean
    0 and standard deviation
    ``sqrt(2 / (prod(shape[:-2]) * shape[-1]))``, i.e. the product of all
    but the last two dimensions of the kernel times its last dimension.

    :return: initializer callable with signature
        ``(shape, dtype=None, partition_info=None)``
    """
    def _initializer(shape, dtype=None, partition_info=None):
        # ``partition_info`` is part of TensorFlow's initializer calling
        # convention; it is accepted for compatibility but not used here.
        if dtype is None:
            dtype = tf.float32
        stddev = np.sqrt(2.0 / (np.prod(shape[:-2]) * shape[-1]))
        from tensorflow.python.ops import random_ops
        # Use the requested dtype so the initial value matches the dtype of
        # the variable being created.
        return random_ops.truncated_normal(shape, 0.0, stddev, dtype=dtype)
    return _initializer
def default_b_initializer():
    """
    Return the default bias initializer, which fills every element with 0.0.

    :return: a ``tf.constant_initializer`` with value 0.0
    """
    initial_value = 0.0
    return tf.constant_initializer(initial_value)
def infer_output_dims(input_dims, strides, kernel_sizes, padding):
    """
    Infer the spatial output size of a transposed convolution.

    Each spatial direction is computed independently, so strides and kernel
    sizes may differ between directions. Dilation is not considered here.

    :param input_dims: spatial input sizes, one entry per direction
    :param strides: stride per direction (same length as ``input_dims``)
    :param kernel_sizes: kernel size per direction (same length as
        ``input_dims``)
    :param padding: ``'VALID'`` selects the VALID formula; any other value
        uses the SAME formula
    :return: list of output sizes, one per direction
    """
    assert len(input_dims) == len(strides)
    assert len(input_dims) == len(kernel_sizes)
    per_direction = zip(input_dims, strides, kernel_sizes)
    if padding != 'VALID':
        # SAME: the input is simply upsampled by the stride.
        return [size * step for size, step, _ in per_direction]
    # VALID: any kernel overhang beyond one stride adds extra elements.
    return [size * step + max(kernel - step, 0)
            for size, step, kernel in per_direction]
class DeconvLayer(TrainableLayer):
    """
    This class defines a simple deconvolution with an optional bias term.
    Please consider ``DeconvolutionalLayer`` if batch_norm and activation
    are also used.
    """

    def __init__(self,
                 n_output_chns,
                 kernel_size=3,
                 stride=1,
                 padding='SAME',
                 with_bias=False,
                 w_initializer=None,
                 w_regularizer=None,
                 b_initializer=None,
                 b_regularizer=None,
                 name='deconv'):
        """
        :param n_output_chns: number of output channels (cast to ``int``)
        :param kernel_size: kernel size, expanded to every spatial direction
            with ``layer_util.expand_spatial_params``
        :param stride: stride, expanded the same way as ``kernel_size``
        :param padding: ``'SAME'`` or ``'VALID'`` (upper-cased before lookup)
        :param with_bias: whether to add a bias term after the deconvolution
        :param w_initializer: kernel initializer; ``default_w_initializer()``
            is used when falsy
        :param w_regularizer: kernel regularizer passed to ``tf.get_variable``
        :param b_initializer: bias initializer; ``default_b_initializer()``
            is used when falsy
        :param b_regularizer: bias regularizer passed to ``tf.get_variable``
        :param name: layer name
        """
        super(DeconvLayer, self).__init__(name=name)

        self.padding = look_up_operations(padding.upper(), SUPPORTED_PADDING)
        self.n_output_chns = int(n_output_chns)
        self.kernel_size = kernel_size
        self.stride = stride
        self.with_bias = with_bias

        self.initializers = {
            'w': w_initializer if w_initializer else default_w_initializer(),
            'b': b_initializer if b_initializer else default_b_initializer()}
        self.regularizers = {'w': w_regularizer, 'b': b_regularizer}

    def layer_op(self, input_tensor):
        """
        Apply the transposed convolution (and optional bias) to the input.

        :param input_tensor: input tensor; its first axis is used as the
            batch size, its last axis as the input channels and the axes in
            between as spatial dimensions
        :return: output tensor with ``n_output_chns`` channels
        :raises ValueError: if the spatial rank is neither 2 nor 3
        """
        input_shape = input_tensor.get_shape().as_list()
        n_input_chns = input_shape[-1]
        spatial_rank = layer_util.infer_spatial_rank(input_tensor)

        # Select the op before creating any variable, so an unsupported
        # spatial rank fails without leaving an orphan kernel variable.
        if spatial_rank == 2:
            op_ = SUPPORTED_OP['2D']
        elif spatial_rank == 3:
            op_ = SUPPORTED_OP['3D']
        else:
            raise ValueError(
                "Only 2D and 3D spatial deconvolutions are supported")

        # initialize conv kernels/strides and then apply
        kernel_size_all_dim = layer_util.expand_spatial_params(
            self.kernel_size, spatial_rank)
        # kernel shape: spatial extent + (output channels, input channels)
        w_full_size = kernel_size_all_dim + (self.n_output_chns, n_input_chns)
        stride_all_dim = layer_util.expand_spatial_params(
            self.stride, spatial_rank)
        full_stride = (1,) + stride_all_dim + (1,)

        deconv_kernel = tf.get_variable(
            'w', shape=w_full_size,
            initializer=self.initializers['w'],
            regularizer=self.regularizers['w'])

        output_dims = infer_output_dims(input_shape[1:-1],
                                        stride_all_dim,
                                        kernel_size_all_dim,
                                        self.padding)
        batch_size = input_shape[0]
        if batch_size is None:
            # The static batch size is unknown, so a Python list containing
            # None cannot be used as ``output_shape``. Build it as a tensor
            # from the dynamic batch size instead; the spatial and channel
            # sizes stay static.
            full_output_size = tf.stack(
                [tf.shape(input_tensor)[0]] + output_dims +
                [self.n_output_chns])
        else:
            full_output_size = \
                [batch_size] + output_dims + [self.n_output_chns]

        output_tensor = op_(value=input_tensor,
                            filter=deconv_kernel,
                            output_shape=full_output_size,
                            strides=full_stride,
                            padding=self.padding,
                            name='deconv')
        if not self.with_bias:
            return output_tensor

        # adding the bias term
        bias_full_size = (self.n_output_chns,)
        bias_term = tf.get_variable(
            'b', shape=bias_full_size,
            initializer=self.initializers['b'],
            regularizer=self.regularizers['b'])
        output_tensor = tf.nn.bias_add(output_tensor,
                                       bias_term,
                                       name='add_bias')
        return output_tensor
class DeconvolutionalLayer(TrainableLayer):
    """
    This class defines a composite layer with optional components::

        deconvolution -> batch_norm -> activation -> dropout

    The b_initializer and b_regularizer are applied to the DeconvLayer.
    The w_initializer is applied to the DeconvLayer only; the w_regularizer
    is applied to the DeconvLayer, the batch normalisation layer, and the
    activation layer (for 'prelu').
    """

    def __init__(self,
                 n_output_chns,
                 kernel_size=3,
                 stride=1,
                 padding='SAME',
                 with_bias=False,
                 with_bn=True,
                 acti_func=None,
                 w_initializer=None,
                 w_regularizer=None,
                 b_initializer=None,
                 b_regularizer=None,
                 moving_decay=0.9,
                 eps=1e-5,
                 name="deconv"):
        """
        :param n_output_chns: number of output channels of the deconvolution
        :param kernel_size: deconvolution kernel size
        :param stride: deconvolution stride
        :param padding: deconvolution padding mode
        :param with_bias: whether the deconvolution adds a bias term
        :param with_bn: whether to apply batch normalisation
        :param acti_func: name of the activation function, or None to skip
        :param w_initializer: kernel initializer for the DeconvLayer
        :param w_regularizer: regularizer for the DeconvLayer kernel, the
            batch normalisation layer and the activation layer
        :param b_initializer: bias initializer for the DeconvLayer
        :param b_regularizer: bias regularizer for the DeconvLayer
        :param moving_decay: moving-average decay passed to the BN layer
        :param eps: epsilon passed to the BN layer
        :param name: base layer name; '_bn' and '_<acti_func>' are appended
            when those components are enabled
        """
        self.acti_func = acti_func
        self.with_bn = with_bn
        self.layer_name = '{}'.format(name)
        if self.with_bn:
            self.layer_name += '_bn'
        if self.acti_func is not None:
            self.layer_name += '_{}'.format(self.acti_func)
        super(DeconvolutionalLayer, self).__init__(name=self.layer_name)

        # for DeconvLayer
        self.n_output_chns = n_output_chns
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.with_bias = with_bias

        # for BNLayer
        self.moving_decay = moving_decay
        self.eps = eps

        self.initializers = {
            'w': w_initializer if w_initializer else default_w_initializer(),
            'b': b_initializer if b_initializer else default_b_initializer()}
        self.regularizers = {'w': w_regularizer, 'b': b_regularizer}

    def layer_op(self, input_tensor, is_training=None, keep_prob=None):
        """
        Apply deconvolution, then optional batch norm, activation and dropout.

        :param input_tensor: input tensor for the deconvolution
        :param is_training: training flag passed to the BN layer; required
            when ``with_bn`` is True
        :param keep_prob: if not None, dropout is applied with this keep
            probability
        :return: output tensor of the composite layer
        :raises ValueError: if ``with_bn`` is True and ``is_training`` is None
        """
        # Validate up front so a missing ``is_training`` is reported before
        # any sub-layer (and its variables) is created.
        if self.with_bn and is_training is None:
            raise ValueError('is_training argument should be '
                             'True or False unless with_bn is False')

        # init sub-layers
        deconv_layer = DeconvLayer(n_output_chns=self.n_output_chns,
                                   kernel_size=self.kernel_size,
                                   stride=self.stride,
                                   padding=self.padding,
                                   with_bias=self.with_bias,
                                   w_initializer=self.initializers['w'],
                                   w_regularizer=self.regularizers['w'],
                                   b_initializer=self.initializers['b'],
                                   b_regularizer=self.regularizers['b'],
                                   name='deconv_')
        output_tensor = deconv_layer(input_tensor)

        if self.with_bn:
            bn_layer = BNLayer(
                regularizer=self.regularizers['w'],
                moving_decay=self.moving_decay,
                eps=self.eps,
                name='bn_')
            output_tensor = bn_layer(output_tensor, is_training)

        if self.acti_func is not None:
            acti_layer = ActiLayer(
                func=self.acti_func,
                regularizer=self.regularizers['w'],
                name='acti_')
            output_tensor = acti_layer(output_tensor)

        if keep_prob is not None:
            dropout_layer = ActiLayer(func='dropout', name='dropout_')
            output_tensor = dropout_layer(output_tensor, keep_prob=keep_prob)

        return output_tensor