# Source code for niftynet.layer.channel_sparse_convolution

# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function

import math

import tensorflow as tf
from tensorflow.python.training import moving_averages

import numpy as np

import niftynet.layer.convolution
import niftynet.layer.deconvolution
import niftynet.layer.bn
from niftynet.utilities.util_common import look_up_operations
from niftynet.layer import layer_util
from niftynet.layer.activation import ActiLayer
from niftynet.layer.base_layer import TrainableLayer
from niftynet.layer.deconvolution import infer_output_dims

# Maps spatial rank label to the matching TensorFlow transposed-convolution
# op; looked up by ChannelSparseDeconvLayer.layer_op.
SUPPORTED_OP = {'2D': tf.nn.conv2d_transpose,
                '3D': tf.nn.conv3d_transpose}
class ChannelSparseDeconvLayer(niftynet.layer.deconvolution.DeconvLayer):
    """
    Channel sparse deconvolutions perform transposed convolutions over
    a subset of image channels and generate a subset of output channels.
    This enables spatial dropout without wasted computations.
    """

    def __init__(self, *args, **kwargs):
        super(ChannelSparseDeconvLayer, self).__init__(*args, **kwargs)

    def layer_op(self, input_tensor, input_mask=None, output_mask=None):
        """
        Parameters:

        :param input_tensor: image to convolve with kernel
        :param input_mask: 1-Tensor with a binary mask of input channels
            to use. If this is None, all channels are used.
        :param output_mask: 1-Tensor with a binary mask of output channels
            to generate. If this is None, all channels are used and the
            number of output channels is set at graph-creation time.
        :return: the deconvolved (optionally biased) tensor
        """
        input_shape = input_tensor.get_shape().as_list()
        if input_mask is None:
            _input_mask = tf.ones([input_shape[-1]]) > 0
        else:
            _input_mask = input_mask
        if output_mask is None:
            n_sparse_output_chns = self.n_output_chns
            _output_mask = tf.ones([self.n_output_chns]) > 0
        else:
            n_sparse_output_chns = tf.reduce_sum(
                tf.cast(output_mask, tf.float32))
            _output_mask = output_mask
        n_full_input_chns = _input_mask.get_shape().as_list()[0]
        spatial_rank = layer_util.infer_spatial_rank(input_tensor)

        # initialize conv kernels/strides and then apply.
        # tf deconvolution kernels are laid out [spatial..., out, in]
        w_full_size = np.vstack((
            [self.kernel_size] * spatial_rank,
            self.n_output_chns, n_full_input_chns)).flatten()
        full_stride = np.vstack((
            1, [self.stride] * spatial_rank, 1)).flatten()
        deconv_kernel = tf.get_variable(
            'w', shape=w_full_size.tolist(),
            initializer=self.initializers['w'],
            regularizer=self.regularizers['w'])
        # mask the output-channel axis (3) then the input-channel axis (4)
        # by rotating each to the front, applying tf.boolean_mask, and
        # rotating back to [spatial..., out, in]
        sparse_kernel = tf.transpose(tf.boolean_mask(
            tf.transpose(tf.boolean_mask(
                tf.transpose(deconv_kernel, [3, 4, 2, 1, 0]),
                _output_mask), [1, 0, 2, 3, 4]),
            _input_mask), [4, 3, 2, 1, 0])

        if spatial_rank == 2:
            op_ = SUPPORTED_OP['2D']
        elif spatial_rank == 3:
            op_ = SUPPORTED_OP['3D']
        else:
            raise ValueError(
                "Only 2D and 3D spatial deconvolutions are supported")

        # NOTE(review): output size is inferred from the first spatial dim
        # only — assumes isotropic spatial dims; confirm against callers
        output_dim = infer_output_dims(input_shape[1],
                                       self.stride,
                                       self.kernel_size,
                                       self.padding)
        # BUG FIX: build a flat 1-D shape vector; stacking
        # [scalar, list, scalar] mixes ranks and fails
        sparse_output_size = tf.stack(
            [input_shape[0]] +
            [output_dim] * spatial_rank +
            [n_sparse_output_chns], 0)
        # BUG FIX: convolve with the masked sparse_kernel (it was computed
        # but the dense deconv_kernel was passed, which both ignores the
        # masks and mismatches the sparse output_shape channel count)
        output_tensor = op_(value=input_tensor,
                            filter=sparse_kernel,
                            output_shape=sparse_output_size,
                            strides=full_stride.tolist(),
                            padding=self.padding,
                            name='deconv')
        if output_mask is None:
            # If all output channels are used, we can specify
            # the number of output channels which is useful for later layers
            old_shape = output_tensor.get_shape().as_list()
            old_shape[-1] = self.n_output_chns
            output_tensor.set_shape(old_shape)
        if not self.with_bias:
            return output_tensor

        # adding the bias term
        bias_full_size = (self.n_output_chns,)
        bias_term = tf.get_variable(
            'b', shape=bias_full_size,
            initializer=self.initializers['b'],
            regularizer=self.regularizers['b'])
        sparse_bias = tf.boolean_mask(bias_term, _output_mask)
        output_tensor = tf.nn.bias_add(output_tensor,
                                       sparse_bias,
                                       name='add_bias')
        return output_tensor
class ChannelSparseConvLayer(niftynet.layer.convolution.ConvLayer):
    """
    Channel sparse convolutions perform convolutions over
    a subset of image channels and generate a subset of output channels.
    This enables spatial dropout without wasted computations.
    """

    def __init__(self, *args, **kwargs):
        super(ChannelSparseConvLayer, self).__init__(*args, **kwargs)

    def layer_op(self, input_tensor, input_mask, output_mask):
        """
        Parameters:

        :param input_tensor: image to convolve with kernel
        :param input_mask: 1-Tensor with a binary mask of input channels
            to use. If this is None, all channels are used.
        :param output_mask: 1-Tensor with a binary mask of output channels
            to generate. If this is None, all channels are used and the
            number of output channels is set at graph-creation time.
        :return: the convolved (optionally biased) tensor
        """
        sparse_input_shape = input_tensor.get_shape().as_list()
        if input_mask is None:
            _input_mask = tf.ones([sparse_input_shape[-1]]) > 0
        else:
            _input_mask = input_mask
        if output_mask is None:
            _output_mask = tf.ones([self.n_output_chns]) > 0
        else:
            _output_mask = output_mask
        n_full_input_chns = _input_mask.get_shape().as_list()[0]
        spatial_rank = layer_util.infer_spatial_rank(input_tensor)

        # initialize conv kernels/strides and then apply.
        # tf convolution kernels are laid out [spatial..., in, out]
        w_full_size = layer_util.expand_spatial_params(
            self.kernel_size, spatial_rank)
        # expand kernel size to include number of features
        w_full_size = w_full_size + (n_full_input_chns, self.n_output_chns)
        full_stride = layer_util.expand_spatial_params(
            self.stride, spatial_rank)
        full_dilation = layer_util.expand_spatial_params(
            self.dilation, spatial_rank)
        conv_kernel = tf.get_variable(
            'w', shape=w_full_size,
            initializer=self.initializers['w'],
            regularizer=self.regularizers['w'])
        # Mask output channels (axis spatial_rank+1), then input channels
        # (axis spatial_rank), by rotating each axis to the front for
        # tf.boolean_mask and rotating back.  GENERALIZED: the perms are
        # built from spatial_rank so both 2D and 3D kernels work (the
        # previous hard-coded 5-element perms only supported 3D).
        spatial_axes = list(range(spatial_rank))
        to_out_first = [spatial_rank + 1, spatial_rank] + spatial_axes
        swap_first_two = [1, 0] + list(range(2, spatial_rank + 2))
        to_spatial_first = list(range(2, spatial_rank + 2)) + [0, 1]
        sparse_kernel = tf.transpose(tf.boolean_mask(
            tf.transpose(tf.boolean_mask(
                tf.transpose(conv_kernel, to_out_first),
                _output_mask), swap_first_two),
            _input_mask), to_spatial_first)

        output_tensor = tf.nn.convolution(input=input_tensor,
                                          filter=sparse_kernel,
                                          strides=full_stride,
                                          dilation_rate=full_dilation,
                                          padding=self.padding,
                                          name='conv')
        if output_mask is None:
            # If all output channels are used, we can specify
            # the number of output channels which is useful for later layers
            old_shape = output_tensor.get_shape().as_list()
            old_shape[-1] = self.n_output_chns
            output_tensor.set_shape(old_shape)
        if not self.with_bias:
            return output_tensor

        # adding the bias term
        bias_term = tf.get_variable(
            'b', shape=self.n_output_chns,
            initializer=self.initializers['b'],
            regularizer=self.regularizers['b'])
        # BUG FIX: mask the bias with _output_mask (which has the all-ones
        # default applied); the raw output_mask argument may be None and
        # would crash tf.boolean_mask
        sparse_bias = tf.boolean_mask(bias_term, _output_mask)
        output_tensor = tf.nn.bias_add(output_tensor, sparse_bias,
                                       name='add_bias')
        return output_tensor
class ChannelSparseBNLayer(niftynet.layer.bn.BNLayer):
    """
    Batch normalisation over a sparse subset of channels, as produced by
    a channel sparse convolution.  Moving averages are stored for the full
    (dense) set of channels; only the masked entries are updated each step.
    """

    def __init__(self, n_dense_channels, *args, **kwargs):
        # total number of channels in the dense (unmasked) layer
        self.n_dense_channels = n_dense_channels
        super(ChannelSparseBNLayer, self).__init__(*args, **kwargs)

    def layer_op(self, inputs, is_training, mask, use_local_stats=False):
        """
        Parameters:

        :param inputs: image to normalize. This typically represents a
            sparse subset of channels from a sparse convolution.
        :param is_training: boolean that is True during training.
            When True, the layer uses batch statistics for normalization
            and records a moving average of means and variances.
            When False, the layer uses previously computed moving averages
            for normalization.
        :param mask: 1-Tensor with a binary mask identifying the sparse
            channels represented in inputs; None means all channels.
        :param use_local_stats: when True, normalize with batch statistics
            even outside training.
        :return: the normalized tensor, same shape as ``inputs``
        """
        if mask is None:
            mask = tf.ones([self.n_dense_channels]) > 0
        input_shape = inputs.get_shape()
        mask_shape = mask.get_shape()
        # operates on all dims except the last dim
        params_shape = mask_shape[-1:]
        # BUG FIX: validate with an explicit exception; the previous
        # `assert` is silently stripped when Python runs with -O
        if params_shape[0] != self.n_dense_channels:
            raise ValueError(
                'Mask size {} must match n_dense_channels {}.'.format(
                    params_shape[0], self.n_dense_channels))
        axes = list(range(input_shape.ndims - 1))

        # create trainable variables and moving average variables
        beta = tf.get_variable(
            'beta', shape=params_shape,
            initializer=self.initializers['beta'],
            regularizer=self.regularizers['beta'],
            dtype=tf.float32, trainable=True)
        gamma = tf.get_variable(
            'gamma', shape=params_shape,
            initializer=self.initializers['gamma'],
            regularizer=self.regularizers['gamma'],
            dtype=tf.float32, trainable=True)
        collections = [tf.GraphKeys.GLOBAL_VARIABLES]
        moving_mean = tf.get_variable(
            'moving_mean', shape=params_shape,
            initializer=self.initializers['moving_mean'],
            dtype=tf.float32, trainable=False, collections=collections)
        moving_variance = tf.get_variable(
            'moving_variance', shape=params_shape,
            initializer=self.initializers['moving_variance'],
            dtype=tf.float32, trainable=False, collections=collections)

        # mean and var
        mean, variance = tf.nn.moments(inputs, axes)
        # only update masked moving averages: scatter the fresh batch
        # statistics into the masked positions and keep the stored values
        # in the unmasked positions
        mean_update = tf.dynamic_stitch(
            [tf.to_int32(tf.where(mask)[:, 0]),
             tf.to_int32(tf.where(~mask)[:, 0])],
            [mean, tf.boolean_mask(moving_mean, ~mask)])
        variance_update = tf.dynamic_stitch(
            [tf.to_int32(tf.where(mask)[:, 0]),
             tf.to_int32(tf.where(~mask)[:, 0])],
            [variance, tf.boolean_mask(moving_variance, ~mask)])
        update_moving_mean = moving_averages.assign_moving_average(
            moving_mean, mean_update, self.moving_decay).op
        update_moving_variance = moving_averages.assign_moving_average(
            moving_variance, variance_update, self.moving_decay).op
        # moving-average updates are run via the UPDATE_OPS collection
        # rather than control dependencies on the output
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_moving_mean)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_moving_variance)

        # call the normalisation function
        if is_training or use_local_stats:
            outputs = tf.nn.batch_normalization(
                inputs, mean, variance,
                tf.boolean_mask(beta, mask),
                tf.boolean_mask(gamma, mask),
                self.eps, name='batch_norm')
        else:
            outputs = tf.nn.batch_normalization(
                inputs,
                tf.boolean_mask(moving_mean, mask),
                tf.boolean_mask(moving_variance, mask),
                tf.boolean_mask(beta, mask),
                tf.boolean_mask(gamma, mask),
                self.eps, name='batch_norm')
        outputs.set_shape(inputs.get_shape())
        return outputs
class ChannelSparseConvolutionalLayer(TrainableLayer):
    """
    This class defines a composite layer with optional components::

        channel sparse convolution -> batchwise-spatial dropout ->
            batch_norm -> activation

    The b_initializer and b_regularizer are applied to the
    ChannelSparseConvLayer.  The w_initializer and w_regularizer are
    applied to the ChannelSparseConvLayer, the batch normalisation layer,
    and the activation layer (for 'prelu').
    """

    def __init__(self,
                 n_output_chns,
                 kernel_size=3,
                 stride=1,
                 dilation=1,
                 padding='SAME',
                 with_bias=False,
                 with_bn=True,
                 acti_func=None,
                 w_initializer=None,
                 w_regularizer=None,
                 b_initializer=None,
                 b_regularizer=None,
                 moving_decay=0.9,
                 eps=1e-5,
                 name="conv"):
        self.acti_func = acti_func
        self.with_bn = with_bn
        # compose the scope name from the enabled components
        composed_name = '{}'.format(name)
        if self.with_bn:
            composed_name += '_bn'
        if self.acti_func is not None:
            composed_name += '_{}'.format(self.acti_func)
        self.layer_name = composed_name
        super(ChannelSparseConvolutionalLayer, self).__init__(
            name=self.layer_name)

        # parameters forwarded to the ConvLayer
        self.n_output_chns = n_output_chns
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.with_bias = with_bias

        # parameters forwarded to the BNLayer
        self.moving_decay = moving_decay
        self.eps = eps

        default_w = niftynet.layer.convolution.default_w_initializer
        default_b = niftynet.layer.convolution.default_b_initializer
        self.initializers = {
            'w': w_initializer if w_initializer else default_w(),
            'b': b_initializer if b_initializer else default_b()}
        self.regularizers = {'w': w_regularizer, 'b': b_regularizer}

    def layer_op(self, input_tensor, input_mask=None,
                 is_training=None, keep_prob=None):
        """
        Build the composite graph: sparse convolution, then (optionally)
        sparse batch-norm and activation.  Returns the output tensor and
        the output-channel mask used for this call.
        """
        conv_layer = ChannelSparseConvLayer(
            n_output_chns=self.n_output_chns,
            kernel_size=self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=self.padding,
            with_bias=self.with_bias,
            w_initializer=self.initializers['w'],
            w_regularizer=self.regularizers['w'],
            b_initializer=self.initializers['b'],
            b_regularizer=self.regularizers['b'],
            name='conv_')

        if keep_prob is None:
            # no dropout: keep every output channel
            output_mask = tf.ones([self.n_output_chns]) > 0
            n_output_ch = self.n_output_chns
        else:
            # batchwise spatial dropout: randomly keep a keep_prob
            # fraction of the output channels
            shuffled_idx = tf.random_shuffle(tf.range(self.n_output_chns))
            output_mask = tf.to_float(shuffled_idx) \
                < keep_prob * self.n_output_chns
            n_output_ch = math.ceil(keep_prob * self.n_output_chns)

        output_tensor = conv_layer(input_tensor, input_mask, output_mask)

        if self.with_bn:
            if is_training is None:
                raise ValueError('is_training argument should be '
                                 'True or False unless with_bn is False')
            bn_layer = ChannelSparseBNLayer(
                self.n_output_chns,
                regularizer=self.regularizers['w'],
                moving_decay=self.moving_decay,
                eps=self.eps,
                name='bn_')
            output_tensor = bn_layer(output_tensor, is_training, output_mask)

        if self.acti_func is not None:
            acti_layer = ActiLayer(
                func=self.acti_func,
                regularizer=self.regularizers['w'],
                name='acti_')
            output_tensor = acti_layer(output_tensor)

        # pin the (possibly reduced) channel count on the static shape
        known_shape = output_tensor.get_shape().as_list()[:-1] \
            + [n_output_ch]
        output_tensor.set_shape(known_shape)
        return output_tensor, output_mask