# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
import math
import numpy as np
import tensorflow as tf
from tensorflow.python.training import moving_averages
import niftynet.layer.bn
import niftynet.layer.convolution
import niftynet.layer.deconvolution
from niftynet.layer import layer_util
from niftynet.layer.activation import ActiLayer
from niftynet.layer.base_layer import TrainableLayer
from niftynet.layer.deconvolution import infer_output_dims
# Transposed-convolution op to use for each supported spatial rank.
SUPPORTED_OP = {
    '2D': tf.nn.conv2d_transpose,
    '3D': tf.nn.conv3d_transpose,
}
class ChannelSparseDeconvLayer(niftynet.layer.deconvolution.DeconvLayer):
    """
    Channel sparse transposed convolution.

    Channel sparse convolutions perform convolutions over
    a subset of image channels and generate a subset of output
    channels. This enables spatial dropout without wasted computations.
    """

    def __init__(self, *args, **kwargs):
        super(ChannelSparseDeconvLayer, self).__init__(*args, **kwargs)

    def layer_op(self, input_tensor, input_mask=None, output_mask=None):
        """
        Apply a transposed convolution restricted to the masked channels.

        :param input_tensor: image to convolve with kernel; the last
            dimension is the channel dimension.
            NOTE(review): when ``input_mask`` is given, the tensor is
            presumably expected to hold only the channels selected by
            that mask -- confirm against callers.
        :param input_mask: 1-Tensor with a binary mask of input channels to
            use. If this is None, all channels are used. Its static length
            defines the number of dense input channels of the kernel.
        :param output_mask: 1-Tensor with a binary mask of output channels
            to generate. If this is None, all channels are used and the
            number of output channels is set at graph-creation time.
        :return: the deconvolved tensor (with the masked bias added when
            ``self.with_bias`` is set)
        :raises ValueError: if the spatial rank is neither 2 nor 3
        """
        input_shape = input_tensor.shape.as_list()
        if input_mask is None:
            _input_mask = tf.ones([input_shape[-1]]) > 0
        else:
            _input_mask = input_mask
        if output_mask is None:
            n_sparse_output_chns = self.n_output_chns
            _output_mask = tf.ones([self.n_output_chns]) > 0
        else:
            # the op's output_shape must be integer valued, so count the
            # selected channels as int32 rather than float32
            n_sparse_output_chns = tf.reduce_sum(
                tf.cast(output_mask, tf.int32))
            _output_mask = output_mask
        n_full_input_chns = _input_mask.shape.as_list()[0]
        spatial_rank = layer_util.infer_spatial_rank(input_tensor)

        # initialize conv kernels/strides and then apply
        # kernel layout: spatial dims + (output channels, input channels)
        w_full_size = layer_util.expand_spatial_params(
            self.kernel_size, spatial_rank)
        w_full_size = w_full_size + (self.n_output_chns, n_full_input_chns)
        full_stride = (1,) + layer_util.expand_spatial_params(
            self.stride, spatial_rank) + (1,)
        deconv_kernel = tf.get_variable(
            'w', shape=list(w_full_size),
            initializer=self.initializers['w'],
            regularizer=self.regularizers['w'])
        if spatial_rank == 2:
            op_ = SUPPORTED_OP['2D']
            transpositions = [[3, 2, 1, 0], [1, 0, 2, 3], [3, 2, 0, 1]]
        elif spatial_rank == 3:
            op_ = SUPPORTED_OP['3D']
            transpositions = [[4, 3, 2, 1, 0], [1, 0, 2, 3, 4], [4, 3, 2, 0, 1]]
        else:
            raise ValueError(
                "Only 2D and 3D spatial deconvolutions are supported")

        if input_mask is None and output_mask is None:
            # dense case: use the full kernel and keep its static shape
            sparse_kernel = deconv_kernel
        else:
            # tf.boolean_mask filters along the leading dimension, so each
            # channel axis is moved to the front before it is masked.
            # [spatial..., out, in] -> [in, out, reversed spatial...]
            sparse_kernel = tf.transpose(deconv_kernel, transpositions[0])
            sparse_kernel = tf.boolean_mask(sparse_kernel, _input_mask)
            # -> [out, in', reversed spatial...]
            sparse_kernel = tf.transpose(sparse_kernel, transpositions[1])
            sparse_kernel = tf.boolean_mask(sparse_kernel, _output_mask)
            # -> [spatial..., out', in']
            sparse_kernel = tf.transpose(sparse_kernel, transpositions[2])

        # NOTE(review): the output size is inferred from the first spatial
        # dimension only and reused for every spatial dimension, i.e. the
        # input is assumed to be spatially isotropic -- confirm.
        output_dim = infer_output_dims(input_shape[1],
                                       self.stride,
                                       self.kernel_size,
                                       self.padding)
        # flat [batch, spatial..., channels] list (a nested list cannot
        # be stacked into a 1-D shape tensor)
        sparse_output_size = \
            [input_shape[0]] + [output_dim] * spatial_rank + \
            [n_sparse_output_chns]
        sparse_output_size = tf.stack(sparse_output_size, 0)
        output_tensor = op_(value=input_tensor,
                            filter=sparse_kernel,
                            output_shape=sparse_output_size,
                            strides=list(full_stride),
                            padding=self.padding,
                            name='deconv')
        if output_mask is None:
            # If all output channels are used, we can specify
            # the number of output channels which is useful for later layers
            old_shape = output_tensor.shape.as_list()
            old_shape[-1] = self.n_output_chns
            output_tensor.set_shape(old_shape)
        if not self.with_bias:
            return output_tensor

        # adding the bias term (only the entries of the selected channels)
        bias_full_size = (self.n_output_chns,)
        bias_term = tf.get_variable(
            'b', shape=bias_full_size,
            initializer=self.initializers['b'],
            regularizer=self.regularizers['b'])
        sparse_bias = tf.boolean_mask(bias_term, _output_mask)
        output_tensor = tf.nn.bias_add(
            output_tensor, sparse_bias, name='add_bias')
        return output_tensor
class ChannelSparseConvLayer(niftynet.layer.convolution.ConvLayer):
    """
    Channel sparse convolutions perform convolutions over
    a subset of image channels and generate a subset of output
    channels. This enables spatial dropout without wasted computations.
    """

    def __init__(self, *args, **kwargs):
        super(ChannelSparseConvLayer, self).__init__(*args, **kwargs)

    def layer_op(self, input_tensor, input_mask=None, output_mask=None):
        """
        Convolve the input with the kernel slices selected by the masks.

        :param input_tensor: image to convolve with kernel
        :param input_mask: 1-Tensor with a binary mask of input channels to
            use. If this is None, all channels are used. Its static length
            defines the number of dense input channels of the kernel.
        :param output_mask: 1-Tensor with a binary mask of output channels
            to generate. If this is None, all channels are used and
            the number of output channels is set at graph-creation time.
        :return: the convolved tensor (with the masked bias added when
            ``self.with_bias`` is set)
        :raises NotImplementedError: if the spatial rank is neither 2 nor 3
        """
        sparse_input_shape = input_tensor.shape.as_list()
        if input_mask is None:
            _input_mask = tf.ones([sparse_input_shape[-1]]) > 0
        else:
            _input_mask = input_mask
        if output_mask is None:
            _output_mask = tf.ones([self.n_output_chns]) > 0
        else:
            _output_mask = output_mask
        n_full_input_chns = _input_mask.shape.as_list()[0]
        spatial_rank = layer_util.infer_spatial_rank(input_tensor)

        # initialize conv kernels/strides and then apply
        w_full_size = layer_util.expand_spatial_params(
            self.kernel_size, spatial_rank)
        # expand kernel size to include number of features
        w_full_size = w_full_size + (n_full_input_chns, self.n_output_chns)
        full_stride = layer_util.expand_spatial_params(
            self.stride, spatial_rank)
        full_dilation = layer_util.expand_spatial_params(
            self.dilation, spatial_rank)
        conv_kernel = tf.get_variable(
            'w', shape=w_full_size,
            initializer=self.initializers['w'],
            regularizer=self.regularizers['w'])

        if spatial_rank == 2:
            transpositions = [[3, 2, 1, 0], [1, 0, 2, 3], [3, 2, 0, 1]]
        elif spatial_rank == 3:
            transpositions = [[4, 3, 2, 1, 0], [1, 0, 2, 3, 4], [4, 3, 2, 0, 1]]
        else:
            raise NotImplementedError("spatial rank not supported")

        # tf.boolean_mask filters along the leading dimension, so each
        # channel axis is moved to the front before it is masked.
        # [spatial..., in, out] -> [out, in, reversed spatial...]
        sparse_kernel = tf.transpose(conv_kernel, transpositions[0])
        sparse_kernel = tf.boolean_mask(sparse_kernel, _output_mask)
        # -> [in, out', reversed spatial...]
        sparse_kernel = tf.transpose(sparse_kernel, transpositions[1])
        sparse_kernel = tf.boolean_mask(sparse_kernel, _input_mask)
        # -> [spatial..., in', out']
        sparse_kernel = tf.transpose(sparse_kernel, transpositions[2])

        output_tensor = tf.nn.convolution(input=input_tensor,
                                          filter=sparse_kernel,
                                          strides=full_stride,
                                          dilation_rate=full_dilation,
                                          padding=self.padding,
                                          name='conv')
        if output_mask is None:
            # If all output channels are used, we can specify
            # the number of output channels which is useful for later layers
            old_shape = output_tensor.shape.as_list()
            old_shape[-1] = self.n_output_chns
            output_tensor.set_shape(old_shape)
        if not self.with_bias:
            return output_tensor

        # adding the bias term (only the entries of the selected channels)
        bias_term = tf.get_variable(
            'b', shape=(self.n_output_chns,),
            initializer=self.initializers['b'],
            regularizer=self.regularizers['b'])
        # use the resolved mask: the raw ``output_mask`` may be None
        sparse_bias = tf.boolean_mask(bias_term, _output_mask)
        output_tensor = tf.nn.bias_add(
            output_tensor, sparse_bias, name='add_bias')
        return output_tensor
class ChannelSparseBNLayer(niftynet.layer.bn.BNLayer):
    """
    Batch normalisation for the output of a channel sparse convolution.

    The trainable parameters and moving statistics are stored for all
    ``n_dense_channels`` channels; a binary mask selects the entries that
    correspond to the channels present in the (sparse) input.
    """

    def __init__(self, n_dense_channels, *args, **kwargs):
        self.n_dense_channels = n_dense_channels
        super(ChannelSparseBNLayer, self).__init__(*args, **kwargs)

    def layer_op(self, inputs, is_training, mask, use_local_stats=False):
        """
        :param inputs: image to normalize. This typically represents a sparse
            subset of channels from a sparse convolution.
        :param is_training: boolean that is True during training.
            When True, the layer uses batch statistics for normalization and
            records a moving average of means and variances.
            When False, the layer uses previously computed moving averages
            for normalization.
        :param mask: 1-Tensor with a binary mask identifying the sparse
            channels represented in inputs. If this is None, all
            ``n_dense_channels`` channels are selected.
        :param use_local_stats: when True, batch statistics are used for
            normalization even if ``is_training`` is False.
        :return: the normalized tensor, with the same static shape as inputs
        """
        if mask is None:
            mask = tf.ones([self.n_dense_channels]) > 0

        # parameters are stored per dense channel (last dim of the mask)
        params_shape = mask.shape[-1:]
        assert params_shape[0] == self.n_dense_channels, \
            'Mask size {} must match n_dense_channels {}.'.format(
                params_shape[0], self.n_dense_channels)
        # statistics are computed over all dims except the last one
        axes = list(range(inputs.shape.ndims - 1))

        def _stitch_with_moving(batch_stat, moving_stat):
            # Dense statistic holding the batch values at the masked
            # channels and the current moving values elsewhere, so that
            # only the masked moving averages effectively change.
            return tf.dynamic_stitch(
                [tf.to_int32(tf.where(mask)[:, 0]),
                 tf.to_int32(tf.where(~mask)[:, 0])],
                [batch_stat,
                 tf.boolean_mask(moving_stat, ~mask)])

        # create trainable variables, then keep the masked entries only
        dense_beta = tf.get_variable(
            'beta',
            shape=params_shape,
            initializer=self.initializers['beta'],
            regularizer=self.regularizers['beta'],
            dtype=tf.float32, trainable=True)
        dense_gamma = tf.get_variable(
            'gamma',
            shape=params_shape,
            initializer=self.initializers['gamma'],
            regularizer=self.regularizers['gamma'],
            dtype=tf.float32, trainable=True)
        beta = tf.boolean_mask(dense_beta, mask)
        gamma = tf.boolean_mask(dense_gamma, mask)

        # create the (non-trainable) moving average variables
        collections = [tf.GraphKeys.GLOBAL_VARIABLES]
        moving_mean = tf.get_variable(
            'moving_mean',
            shape=params_shape,
            initializer=self.initializers['moving_mean'],
            dtype=tf.float32, trainable=False, collections=collections)
        moving_variance = tf.get_variable(
            'moving_variance',
            shape=params_shape,
            initializer=self.initializers['moving_variance'],
            dtype=tf.float32, trainable=False, collections=collections)

        # batch statistics of the sparse input
        batch_mean, batch_variance = tf.nn.moments(inputs, axes)

        # register the moving average updates
        mean_update = _stitch_with_moving(batch_mean, moving_mean)
        variance_update = _stitch_with_moving(batch_variance, moving_variance)
        for moving_stat, new_value in ((moving_mean, mean_update),
                                       (moving_variance, variance_update)):
            update_op = moving_averages.assign_moving_average(
                moving_stat, new_value, self.moving_decay).op
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_op)

        # choose the statistics, then call the normalisation function
        if is_training or use_local_stats:
            norm_mean, norm_variance = batch_mean, batch_variance
        else:
            norm_mean = tf.boolean_mask(moving_mean, mask)
            norm_variance = tf.boolean_mask(moving_variance, mask)
        outputs = tf.nn.batch_normalization(
            inputs, norm_mean, norm_variance,
            beta, gamma, self.eps, name='batch_norm')
        outputs.set_shape(inputs.get_shape())
        return outputs
class ChannelSparseConvolutionalLayer(TrainableLayer):
    """
    This class defines a composite layer with optional components::

        channel sparse convolution ->
        batchwise-spatial dropout ->
        batch_norm ->
        activation

    The b_initializer and b_regularizer are applied to
    the ChannelSparseConvLayer, the w_initializer and w_regularizer
    are applied to the ChannelSparseConvLayer, the batch normalisation
    layer, and the activation layer (for 'prelu')
    """

    def __init__(self,
                 n_output_chns,
                 kernel_size=3,
                 stride=1,
                 dilation=1,
                 padding='SAME',
                 with_bias=False,
                 with_bn=True,
                 acti_func=None,
                 w_initializer=None,
                 w_regularizer=None,
                 b_initializer=None,
                 b_regularizer=None,
                 moving_decay=0.9,
                 eps=1e-5,
                 name="conv"):
        self.acti_func = acti_func
        self.with_bn = with_bn
        # the layer name records which optional components are enabled
        self.layer_name = '{}'.format(name)
        if self.with_bn:
            self.layer_name += '_bn'
        if self.acti_func is not None:
            self.layer_name += '_{}'.format(self.acti_func)
        super(ChannelSparseConvolutionalLayer, self).__init__(
            name=self.layer_name)

        # for ConvLayer
        self.n_output_chns = n_output_chns
        self.kernel_size = kernel_size
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.with_bias = with_bias

        # for BNLayer
        self.moving_decay = moving_decay
        self.eps = eps

        self.initializers = {
            'w': w_initializer if w_initializer else
                 niftynet.layer.convolution.default_w_initializer(),
            'b': b_initializer if b_initializer else
                 niftynet.layer.convolution.default_b_initializer()}
        self.regularizers = {'w': w_regularizer, 'b': b_regularizer}

    def layer_op(self,
                 input_tensor,
                 input_mask=None,
                 is_training=None,
                 keep_prob=None):
        """
        Apply the sparse convolution, then optional batch norm/activation.

        :param input_tensor: image to convolve
        :param input_mask: 1-Tensor with a binary mask of the input
            channels, forwarded to the channel sparse convolution
            (None: all channels are used)
        :param is_training: boolean forwarded to the batch normalisation
            layer; must not be None when ``with_bn`` is True
        :param keep_prob: fraction of the output channels to generate.
            A random subset of ``ceil(keep_prob * n_output_chns)`` channels
            (limited to the range ``[0, n_output_chns]``) is selected.
            If this is None, all output channels are generated.
        :return: tuple ``(output_tensor, output_mask)`` where output_mask
            is the binary mask of the generated output channels
        :raises ValueError: if ``with_bn`` is True and is_training is None
        """
        conv_layer = ChannelSparseConvLayer(
            n_output_chns=self.n_output_chns,
            kernel_size=self.kernel_size,
            stride=self.stride,
            dilation=self.dilation,
            padding=self.padding,
            with_bias=self.with_bias,
            w_initializer=self.initializers['w'],
            w_regularizer=self.regularizers['w'],
            b_initializer=self.initializers['b'],
            b_regularizer=self.regularizers['b'],
            name='conv_')
        if keep_prob is not None:
            # Compute the number of kept channels once, as an integer, and
            # derive both the mask and the static shape from it. Comparing
            # a float threshold inside the graph (float32) while taking
            # the ceiling in Python (float64) can disagree by one channel.
            n_output_ch = int(math.ceil(keep_prob * self.n_output_chns))
            n_output_ch = max(0, min(n_output_ch, self.n_output_chns))
            output_mask = \
                tf.random_shuffle(tf.range(self.n_output_chns)) < n_output_ch
        else:
            output_mask = tf.ones([self.n_output_chns]) > 0
            n_output_ch = self.n_output_chns

        output_tensor = conv_layer(input_tensor, input_mask, output_mask)
        # the masked kernel hides the channel count from shape inference,
        # so restore it for the following layers
        output_tensor.set_shape(
            output_tensor.shape.as_list()[:-1] + [n_output_ch])

        if self.with_bn:
            if is_training is None:
                raise ValueError('is_training argument should be '
                                 'True or False unless with_bn is False')
            bn_layer = ChannelSparseBNLayer(
                self.n_output_chns,
                regularizer=self.regularizers['w'],
                moving_decay=self.moving_decay,
                eps=self.eps,
                name='bn_')
            output_tensor = bn_layer(output_tensor, is_training, output_mask)

        if self.acti_func is not None:
            acti_layer = ActiLayer(
                func=self.acti_func,
                regularizer=self.regularizers['w'],
                name='acti_')
            output_tensor = acti_layer(output_tensor)
        return output_tensor, output_mask