# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
import tensorflow as tf
from tensorflow.python.training import moving_averages
from niftynet.layer.base_layer import TrainableLayer
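
# moving-average update ops created by BNLayer are registered under
# the standard TensorFlow UPDATE_OPS collection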
BN_COLLECTION = tf.GraphKeys.UPDATE_OPS


class BNLayer(TrainableLayer):
    """
    Batch normalisation layer, with a trainable shift parameter 'beta'
    and scale parameter 'gamma'. 'beta' is initialised to 0.0 and 'gamma'
    is initialised to 1.0. This class assumes 'beta' and 'gamma' share the
    same type of regulariser.
    """

    def __init__(self,
                 regularizer=None,
                 moving_decay=0.9,
                 eps=1e-5,
                 name='batch_norm'):
        super(BNLayer, self).__init__(name=name)
        self.eps = eps
        self.moving_decay = moving_decay

        self.initializers = {'beta': tf.constant_initializer(0.0),
                             'gamma': tf.constant_initializer(1.0),
                             'moving_mean': tf.constant_initializer(0.0),
                             'moving_variance': tf.constant_initializer(1.0)}
        self.regularizers = {'beta': regularizer, 'gamma': regularizer}

    def layer_op(self, inputs, is_training, use_local_stats=False):
        input_shape = inputs.shape

        # statistics are computed over all dims except the last (channel) dim
        params_shape = input_shape[-1:]
        axes = list(range(input_shape.ndims - 1))

        # create trainable variables and moving average variables
        beta = tf.get_variable(
            'beta',
            shape=params_shape,
            initializer=self.initializers['beta'],
            regularizer=self.regularizers['beta'],
            dtype=tf.float32, trainable=True)
        gamma = tf.get_variable(
            'gamma',
            shape=params_shape,
            initializer=self.initializers['gamma'],
            regularizer=self.regularizers['gamma'],
            dtype=tf.float32, trainable=True)

        collections = [tf.GraphKeys.GLOBAL_VARIABLES]
        moving_mean = tf.get_variable(
            'moving_mean',
            shape=params_shape,
            initializer=self.initializers['moving_mean'],
            dtype=tf.float32, trainable=False, collections=collections)
        moving_variance = tf.get_variable(
            'moving_variance',
            shape=params_shape,
            initializer=self.initializers['moving_variance'],
            dtype=tf.float32, trainable=False, collections=collections)

        # compute batch mean and variance, and the ops that fold them
        # into the moving averages
        mean, variance = tf.nn.moments(inputs, axes)
        update_moving_mean = moving_averages.assign_moving_average(
            moving_mean, mean, self.moving_decay).op
        update_moving_variance = moving_averages.assign_moving_average(
            moving_variance, variance, self.moving_decay).op
        tf.add_to_collection(BN_COLLECTION, update_moving_mean)
        tf.add_to_collection(BN_COLLECTION, update_moving_variance)
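        # each update op applies an exponential moving average:
        #     moving_stat = decay * moving_stat + (1 - decay) * batch_stat
        # the ops are only collected here; the training loop is expected
        # to run tf.GraphKeys.UPDATE_OPS alongside the train op
        # (see the usage sketch after this class)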

        # call the normalisation function
        if is_training or use_local_stats:
            # with tf.control_dependencies(
            #         [update_moving_mean, update_moving_variance]):
            outputs = tf.nn.batch_normalization(
                inputs, mean, variance,
                beta, gamma, self.eps, name='batch_norm')
        else:
            outputs = tf.nn.batch_normalization(
                inputs, moving_mean, moving_variance,
                beta, gamma, self.eps, name='batch_norm')
        outputs.set_shape(inputs.get_shape())
        return outputs

        # # Regularizers are not currently supported for fused batch norm.
        # return tf.contrib.layers.batch_norm(
        #     inputs,
        #     decay=self.moving_decay,
        #     center=True,
        #     scale=True,
        #     epsilon=self.eps,
        #     activation_fn=None,
        #     param_initializers=self.initializers,
        #     param_regularizers=self.regularizers,
        #     updates_collections=tf.GraphKeys.UPDATE_OPS,
        #     is_training=is_training,
        #     reuse=None,
        #     variables_collections=[tf.GraphKeys.MOVING_AVERAGE_VARIABLES,
        #                            tf.GraphKeys.GLOBAL_VARIABLES],
        #     outputs_collections=None,
        #     trainable=True,
        #     batch_weights=None,
        #     fused=False,
        #     data_format='NHWC',
        #     zero_debias_moving_mean=False,
        #     scope=None)
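

# Usage sketch (illustrative; `images`, `loss` and `optimizer` are
# hypothetical placeholders). Assuming the standard NiftyNet convention
# that calling a layer invokes its layer_op, the important detail is that
# the moving-average update ops collected under tf.GraphKeys.UPDATE_OPS
# must be attached to the train op, otherwise the moving statistics used
# at inference time are never updated:
#
#     images = tf.placeholder(tf.float32, [None, 16, 16, 16, 1])
#     normalised = BNLayer(moving_decay=0.9)(images, is_training=True)
#     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
#     with tf.control_dependencies(update_ops):
#         train_op = optimizer.minimize(loss)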


class InstanceNormLayer(TrainableLayer):
    """
    Instance normalisation layer, a thin wrapper around
    `tf.contrib.layers.instance_norm`.
    """

    def __init__(self, eps=1e-6, gamma_initializer=None, name='instance_norm'):
        TrainableLayer.__init__(self, name=name)
        self.eps = eps
        self.gamma_initializer = gamma_initializer

    def layer_op(self, inputs):
        if self.gamma_initializer is None:
            self.gamma_initializer = tf.constant_initializer(1.0)
        return tf.contrib.layers.instance_norm(
            inputs,
            center=True,
            scale=True,
            epsilon=self.eps,
            param_initializers={'gamma': self.gamma_initializer},
            reuse=None,
            variables_collections=None,
            outputs_collections=None,
            trainable=True,
            data_format='NHWC',
            scope=None)
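

# Usage sketch (illustrative; `images` is a hypothetical placeholder).
# Unlike BNLayer, instance normalisation computes statistics per example,
# so no is_training flag or update ops are needed:
#
#     images = tf.placeholder(tf.float32, [None, 16, 16, 16, 1])
#     normalised = InstanceNormLayer(eps=1e-6)(images)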