# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
import math
import tensorflow as tf
from tensorflow.python.training import moving_averages
import numpy as np
import niftynet.layer.convolution
import niftynet.layer.deconvolution
import niftynet.layer.bn
from niftynet.utilities.util_common import look_up_operations
from niftynet.layer import layer_util
from niftynet.layer.activation import ActiLayer
from niftynet.layer.base_layer import TrainableLayer
from niftynet.layer.deconvolution import infer_output_dims
# Transposed-convolution ops keyed by the spatial dimensionality of the input.
SUPPORTED_OP = {'2D': tf.nn.conv2d_transpose,
'3D': tf.nn.conv3d_transpose}
class ChannelSparseDeconvLayer(niftynet.layer.deconvolution.DeconvLayer):
    """
    Channel sparse deconvolutions (transposed convolutions) perform the
    operation over a subset of input image channels and generate a subset
    of output channels. This enables spatial dropout without wasted
    computation.
    """

    def __init__(self, *args, **kwargs):
        super(ChannelSparseDeconvLayer, self).__init__(*args, **kwargs)

    def layer_op(self, input_tensor, input_mask=None, output_mask=None):
        """
        :param input_tensor: image to convolve with kernel
        :param input_mask: 1-D boolean Tensor masking the input channels
            to use. If None, all channels are used.
        :param output_mask: 1-D boolean Tensor masking the output channels
            to generate. If None, all channels are generated and the
            number of output channels is set at graph-creation time.
        :return: deconvolved tensor holding only the masked output channels
        """
        input_shape = input_tensor.get_shape().as_list()
        if input_mask is None:
            _input_mask = tf.ones([input_shape[-1]]) > 0
        else:
            _input_mask = input_mask
        if output_mask is None:
            n_sparse_output_chns = self.n_output_chns
            _output_mask = tf.ones([self.n_output_chns]) > 0
        else:
            # BUG FIX: output_shape of a transposed convolution must be an
            # int32 vector; the original cast to float32.
            n_sparse_output_chns = tf.reduce_sum(
                tf.cast(output_mask, tf.int32))
            _output_mask = output_mask
        n_full_input_chns = _input_mask.get_shape().as_list()[0]
        spatial_rank = layer_util.infer_spatial_rank(input_tensor)
        # initialize deconv kernel/strides and then apply
        # BUG FIX: the original np.vstack(([k] * rank, out, in)) stacks rows
        # of unequal length and raises ValueError; build the flat size
        # vectors directly instead.
        # deconv kernel layout is [spatial..., out_chns, in_chns]
        w_full_size = np.array(
            [self.kernel_size] * spatial_rank +
            [self.n_output_chns, n_full_input_chns])
        full_stride = np.array([1] + [self.stride] * spatial_rank + [1])
        deconv_kernel = tf.get_variable(
            'w', shape=w_full_size.tolist(),
            initializer=self.initializers['w'],
            regularizer=self.regularizers['w'])
        # Slice the dense kernel down to the active channels.  The
        # permutations are derived from spatial_rank (the original
        # hard-coded 5-D permutations only worked for 3D kernels):
        #   1) move (out, in) in front of the spatial dims,
        #   2) mask out_chns, swap to (in, out),
        #   3) mask in_chns, restore [spatial..., out, in].
        kernel_rank = spatial_rank + 2
        to_chns_first = [spatial_rank, spatial_rank + 1] + \
            list(reversed(range(spatial_rank)))
        swap_first_two = [1, 0] + list(range(2, kernel_rank))
        back_to_dense = list(reversed(range(kernel_rank)))
        sparse_kernel = tf.transpose(tf.boolean_mask(
            tf.transpose(tf.boolean_mask(
                tf.transpose(deconv_kernel, to_chns_first),
                _output_mask), swap_first_two),
            _input_mask), back_to_dense)
        if spatial_rank == 2:
            op_ = SUPPORTED_OP['2D']
        elif spatial_rank == 3:
            op_ = SUPPORTED_OP['3D']
        else:
            raise ValueError(
                "Only 2D and 3D spatial deconvolutions are supported")
        # NOTE(review): output_dim is inferred from input_shape[1] only —
        # this assumes all spatial dims are equal; confirm against callers.
        output_dim = infer_output_dims(input_shape[1],
                                       self.stride,
                                       self.kernel_size,
                                       self.padding)
        # BUG FIX: tf.stack needs a flat list of scalars; the original
        # passed a ragged nested list ([batch, [dim]*rank, chns]).
        sparse_output_size = tf.stack(
            [input_shape[0]] +
            [output_dim] * spatial_rank +
            [n_sparse_output_chns], 0)
        # BUG FIX: deconvolve with the *sparse* kernel; the original passed
        # the full deconv_kernel, so the masked kernel was never used and
        # the output channel count could not match n_sparse_output_chns.
        output_tensor = op_(value=input_tensor,
                            filter=sparse_kernel,
                            output_shape=sparse_output_size,
                            strides=full_stride.tolist(),
                            padding=self.padding,
                            name='deconv')
        if output_mask is None:
            # If all output channels are used, we can specify the number of
            # output channels, which is useful for later layers
            old_shape = output_tensor.get_shape().as_list()
            old_shape[-1] = self.n_output_chns
            output_tensor.set_shape(old_shape)
        if not self.with_bias:
            return output_tensor
        # adding the bias term for the active output channels only
        bias_full_size = (self.n_output_chns,)
        bias_term = tf.get_variable(
            'b', shape=bias_full_size,
            initializer=self.initializers['b'],
            regularizer=self.regularizers['b'])
        sparse_bias = tf.boolean_mask(bias_term, _output_mask)
        output_tensor = tf.nn.bias_add(output_tensor,
                                       sparse_bias,
                                       name='add_bias')
        return output_tensor
class ChannelSparseConvLayer(niftynet.layer.convolution.ConvLayer):
    """
    Channel sparse convolutions perform the convolution over a subset of
    input image channels and generate a subset of output channels. This
    enables spatial dropout without wasted computation.
    """

    def __init__(self, *args, **kwargs):
        super(ChannelSparseConvLayer, self).__init__(*args, **kwargs)

    def layer_op(self, input_tensor, input_mask, output_mask):
        """
        :param input_tensor: image to convolve with kernel
        :param input_mask: 1-D boolean Tensor masking the input channels
            to use. If None, all channels are used.
        :param output_mask: 1-D boolean Tensor masking the output channels
            to generate. If None, all channels are generated and the
            number of output channels is set at graph-creation time.
        :return: convolved tensor holding only the masked output channels
        """
        sparse_input_shape = input_tensor.get_shape().as_list()
        if input_mask is None:
            _input_mask = tf.ones([sparse_input_shape[-1]]) > 0
        else:
            _input_mask = input_mask
        if output_mask is None:
            _output_mask = tf.ones([self.n_output_chns]) > 0
        else:
            _output_mask = output_mask
        n_full_input_chns = _input_mask.get_shape().as_list()[0]
        spatial_rank = layer_util.infer_spatial_rank(input_tensor)
        # initialize conv kernel/strides and then apply
        w_full_size = layer_util.expand_spatial_params(
            self.kernel_size, spatial_rank)
        # expand kernel size to include number of features
        w_full_size = w_full_size + (n_full_input_chns, self.n_output_chns)
        full_stride = layer_util.expand_spatial_params(
            self.stride, spatial_rank)
        full_dilation = layer_util.expand_spatial_params(
            self.dilation, spatial_rank)
        conv_kernel = tf.get_variable(
            'w', shape=w_full_size,
            initializer=self.initializers['w'],
            regularizer=self.regularizers['w'])
        # Slice the dense kernel down to the active channels.  The
        # permutations are derived from spatial_rank (the original
        # hard-coded 5-D permutations only worked for 3D kernels):
        #   1) reverse dims so (out, in) lead,
        #   2) mask out_chns, swap to (in, out),
        #   3) mask in_chns, restore [spatial..., in, out].
        kernel_rank = spatial_rank + 2
        to_chns_first = list(reversed(range(kernel_rank)))
        swap_first_two = [1, 0] + list(range(2, kernel_rank))
        back_to_dense = list(reversed(range(2, kernel_rank))) + [0, 1]
        sparse_kernel = tf.transpose(tf.boolean_mask(
            tf.transpose(tf.boolean_mask(
                tf.transpose(conv_kernel, to_chns_first),
                _output_mask), swap_first_two),
            _input_mask), back_to_dense)
        output_tensor = tf.nn.convolution(input=input_tensor,
                                          filter=sparse_kernel,
                                          strides=full_stride,
                                          dilation_rate=full_dilation,
                                          padding=self.padding,
                                          name='conv')
        if output_mask is None:
            # If all output channels are used, we can specify the number of
            # output channels, which is useful for later layers
            old_shape = output_tensor.get_shape().as_list()
            old_shape[-1] = self.n_output_chns
            output_tensor.set_shape(old_shape)
        if not self.with_bias:
            return output_tensor
        # adding the bias term for the active output channels only
        bias_term = tf.get_variable(
            'b', shape=self.n_output_chns,
            initializer=self.initializers['b'],
            regularizer=self.regularizers['b'])
        # BUG FIX: mask the bias with _output_mask (never None); the
        # original passed the raw output_mask argument, which crashes
        # tf.boolean_mask when output_mask is None and with_bias is True.
        sparse_bias = tf.boolean_mask(bias_term, _output_mask)
        output_tensor = tf.nn.bias_add(output_tensor, sparse_bias,
                                       name='add_bias')
        return output_tensor
class ChannelSparseBNLayer(niftynet.layer.bn.BNLayer):
    """
    Batch normalisation for tensors that hold a sparse subset of channels.

    The trainable and moving-average parameters are sized for the full
    (dense) channel set; at each call only the channels flagged by the
    mask are normalised and have their moving statistics refreshed.
    """

    def __init__(self, n_dense_channels, *args, **kwargs):
        # total channel count of the dense (unmasked) representation
        self.n_dense_channels = n_dense_channels
        super(ChannelSparseBNLayer, self).__init__(*args, **kwargs)

    def layer_op(self, inputs, is_training, mask, use_local_stats=False):
        """
        :param inputs: tensor to normalise; typically holds the sparse
            subset of channels selected by ``mask`` (e.g. from a sparse
            convolution)
        :param is_training: when True, normalise with batch statistics and
            record moving averages of means/variances; when False,
            normalise with the stored moving averages
        :param mask: 1-D boolean Tensor flagging which dense channels are
            present in ``inputs``; None means all channels
        :param use_local_stats: force batch statistics even when not
            training
        """
        if mask is None:
            mask = tf.ones([self.n_dense_channels]) > 0
        input_shape = inputs.get_shape()
        # parameters cover the full dense channel set (last mask dim)
        params_shape = mask.get_shape()[-1:]
        assert params_shape[0] == self.n_dense_channels, \
            'Mask size {} must match n_dense_channels {}.'.format(
                params_shape[0], self.n_dense_channels)
        # normalise over every axis except the channel axis
        axes = list(range(input_shape.ndims - 1))
        # trainable scale/offset plus non-trainable moving statistics
        beta = tf.get_variable(
            'beta',
            shape=params_shape,
            initializer=self.initializers['beta'],
            regularizer=self.regularizers['beta'],
            dtype=tf.float32, trainable=True)
        gamma = tf.get_variable(
            'gamma',
            shape=params_shape,
            initializer=self.initializers['gamma'],
            regularizer=self.regularizers['gamma'],
            dtype=tf.float32, trainable=True)
        collections = [tf.GraphKeys.GLOBAL_VARIABLES]
        moving_mean = tf.get_variable(
            'moving_mean',
            shape=params_shape,
            initializer=self.initializers['moving_mean'],
            dtype=tf.float32, trainable=False, collections=collections)
        moving_variance = tf.get_variable(
            'moving_variance',
            shape=params_shape,
            initializer=self.initializers['moving_variance'],
            dtype=tf.float32, trainable=False, collections=collections)
        # batch statistics of the sparse channels
        mean, variance = tf.nn.moments(inputs, axes)
        # Scatter fresh statistics into the dense layout so only the
        # active channels' moving averages change; inactive channels keep
        # their previous moving values.
        active_idx = tf.to_int32(tf.where(mask)[:, 0])
        inactive_idx = tf.to_int32(tf.where(~mask)[:, 0])
        dense_mean = tf.dynamic_stitch(
            [active_idx, inactive_idx],
            [mean, tf.boolean_mask(moving_mean, ~mask)])
        dense_variance = tf.dynamic_stitch(
            [active_idx, inactive_idx],
            [variance, tf.boolean_mask(moving_variance, ~mask)])
        update_moving_mean = moving_averages.assign_moving_average(
            moving_mean, dense_mean, self.moving_decay).op
        update_moving_variance = moving_averages.assign_moving_average(
            moving_variance, dense_variance, self.moving_decay).op
        # updates are executed via the graph's UPDATE_OPS collection
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_moving_mean)
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_moving_variance)
        # normalise with batch statistics while training (or when forced),
        # otherwise with the masked slice of the moving averages
        if is_training or use_local_stats:
            outputs = tf.nn.batch_normalization(
                inputs, mean, variance,
                tf.boolean_mask(beta, mask),
                tf.boolean_mask(gamma, mask),
                self.eps, name='batch_norm')
        else:
            outputs = tf.nn.batch_normalization(
                inputs,
                tf.boolean_mask(moving_mean, mask),
                tf.boolean_mask(moving_variance, mask),
                tf.boolean_mask(beta, mask),
                tf.boolean_mask(gamma, mask),
                self.eps, name='batch_norm')
        outputs.set_shape(inputs.get_shape())
        return outputs
class ChannelSparseConvolutionalLayer(TrainableLayer):
    """
    Composite layer with optional components::

        channel sparse convolution -> batchwise-spatial dropout ->
        batch_norm -> activation

    The b_initializer and b_regularizer are applied to the
    ChannelSparseConvLayer. The w_initializer and w_regularizer are
    applied to the ChannelSparseConvLayer, the batch normalisation layer,
    and the activation layer (for 'prelu').
    """

    def __init__(self,
                 n_output_chns,
                 kernel_size=3,
                 stride=1,
                 dilation=1,
                 padding='SAME',
                 with_bias=False,
                 with_bn=True,
                 acti_func=None,
                 w_initializer=None,
                 w_regularizer=None,
                 b_initializer=None,
                 b_regularizer=None,
                 moving_decay=0.9,
                 eps=1e-5,
                 name="conv"):
        self.acti_func = acti_func
        self.with_bn = with_bn
        # compose the layer name from the enabled components
        composed_name = '{}'.format(name)
        if with_bn:
            composed_name += '_bn'
        if acti_func is not None:
            composed_name += '_{}'.format(acti_func)
        self.layer_name = composed_name
        super(ChannelSparseConvolutionalLayer, self).__init__(
            name=self.layer_name)
        # for ConvLayer
        self.n_output_chns = n_output_chns
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.with_bias = with_bias
        # for BNLayer
        self.moving_decay = moving_decay
        self.eps = eps
        self.initializers = {
            'w': w_initializer if w_initializer
            else niftynet.layer.convolution.default_w_initializer(),
            'b': b_initializer if b_initializer
            else niftynet.layer.convolution.default_b_initializer()}
        self.regularizers = {'w': w_regularizer, 'b': b_regularizer}

    def layer_op(self, input_tensor, input_mask=None, is_training=None,
                 keep_prob=None):
        """
        :param input_tensor: image to process
        :param input_mask: 1-D boolean Tensor masking the channels of
            ``input_tensor``; None means all channels
        :param is_training: required (True/False) when with_bn is enabled
        :param keep_prob: if set, randomly keep roughly this fraction of
            output channels (batchwise spatial dropout)
        :return: (output tensor, output channel mask)
        """
        conv_layer = ChannelSparseConvLayer(
            n_output_chns=self.n_output_chns,
            kernel_size=self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=self.padding,
            with_bias=self.with_bias,
            w_initializer=self.initializers['w'],
            w_regularizer=self.regularizers['w'],
            b_initializer=self.initializers['b'],
            b_regularizer=self.regularizers['b'],
            name='conv_')
        if keep_prob is None:
            # keep every channel; the channel count is known statically
            output_mask = tf.ones([self.n_output_chns]) > 0
            n_output_ch = self.n_output_chns
        else:
            # randomly select ceil(keep_prob * n) of the output channels:
            # a shuffled 0..n-1 range compared against the threshold keeps
            # exactly that many channels
            shuffled = tf.random_shuffle(tf.range(self.n_output_chns))
            output_mask = \
                tf.to_float(shuffled) < keep_prob * self.n_output_chns
            n_output_ch = math.ceil(keep_prob * self.n_output_chns)
        output_tensor = conv_layer(input_tensor, input_mask, output_mask)
        if self.with_bn:
            if is_training is None:
                raise ValueError('is_training argument should be '
                                 'True or False unless with_bn is False')
            bn_layer = ChannelSparseBNLayer(
                self.n_output_chns,
                regularizer=self.regularizers['w'],
                moving_decay=self.moving_decay,
                eps=self.eps,
                name='bn_')
            output_tensor = bn_layer(output_tensor, is_training, output_mask)
        if self.acti_func is not None:
            acti_layer = ActiLayer(
                func=self.acti_func,
                regularizer=self.regularizers['w'],
                name='acti_')
            output_tensor = acti_layer(output_tensor)
        # record the (possibly reduced) channel count on the output shape
        output_tensor.set_shape(
            output_tensor.get_shape().as_list()[:-1] + [n_output_ch])
        return output_tensor, output_mask