Source code for niftynet.layer.gn

# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function

import tensorflow as tf

from niftynet.layer.base_layer import TrainableLayer


class GNLayer(TrainableLayer):
    """
    Group normalisation layer, with trainable mean value 'beta' and
    std 'gamma'. 'beta' is initialised to 0.0 and 'gamma' is
    initialised to 1.0. This class assumes 'beta' and 'gamma' share
    the same type of regulariser.

    Reimplementation of Wu and He, Group Normalization,
    arXiv:1803.08494 (2018).
    """

    def __init__(self,
                 group_size=32,
                 regularizer=None,
                 eps=1e-5,
                 name='group_norm'):
        super(GNLayer, self).__init__(name=name)
        self.group_size = group_size
        self.eps = eps
        self.initializers = {
            'beta': tf.constant_initializer(0.0),
            'gamma': tf.constant_initializer(1.0)}
        self.regularizers = {'beta': regularizer, 'gamma': regularizer}
    def layer_op(self, inputs):
        input_shape = inputs.shape
        group_size = max(min(self.group_size, input_shape[-1]), 1)
        assert input_shape[-1] % group_size == 0, \
            'number of input channels should be divisible by group size.'
        grouped_shape = tf.stack(
            list(input_shape[:-1]) +
            [group_size, input_shape[-1] // group_size])
        grouped_inputs = tf.reshape(inputs, grouped_shape)

        # operates on all dims except the batch and grouped dim
        axes = list(range(1, input_shape.ndims - 1)) + [input_shape.ndims]

        # create the shape of trainable variables
        param_shape = [1] * (input_shape.ndims - 1) + [input_shape[-1]]

        # create trainable variables
        beta = tf.get_variable(
            'beta', shape=param_shape,
            initializer=self.initializers['beta'],
            regularizer=self.regularizers['beta'],
            dtype=tf.float32, trainable=True)
        gamma = tf.get_variable(
            'gamma', shape=param_shape,
            initializer=self.initializers['gamma'],
            regularizer=self.regularizers['gamma'],
            dtype=tf.float32, trainable=True)

        # mean and var
        mean, variance = tf.nn.moments(grouped_inputs, axes, keep_dims=True)
        outputs = (grouped_inputs - mean) / tf.sqrt(variance + self.eps)
        outputs = tf.reshape(outputs, input_shape) * gamma + beta
        return outputs
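
For context, a minimal usage sketch (not part of the module): it assumes TensorFlow 1.x and that, as with other NiftyNet TrainableLayer subclasses, calling the layer instance invokes layer_op under the layer's variable scope. Because layer_op reads inputs.shape statically, the input tensor's shape, including the batch dimension, must be fully defined.

# Hypothetical usage sketch -- assumes TF 1.x and callable NiftyNet layers.
import tensorflow as tf
from niftynet.layer.gn import GNLayer

x = tf.ones([2, 8, 8, 16])                  # fully defined NHWC tensor
gn = GNLayer(group_size=4, name='gn_demo')  # 4 groups of 16 / 4 = 4 channels
y = gn(x)                                   # same shape as x

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(y).shape)                # (2, 8, 8, 16)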
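
To make the reshape/axes arithmetic in layer_op concrete, here is an illustrative NumPy re-statement of the normalisation step for the beta = 0, gamma = 1 case (a sketch; the function name and helper are mine, not NiftyNet's):

import numpy as np

def group_norm_ref(x, group_size, eps=1e-5):
    # reshape [batch, ...spatial, C] -> [batch, ...spatial, G, C // G]
    shape = x.shape
    g = max(min(group_size, shape[-1]), 1)
    assert shape[-1] % g == 0
    grouped = x.reshape(shape[:-1] + (g, shape[-1] // g))
    # normalise over the spatial axes plus the per-group channel axis,
    # keeping independent statistics for each (sample, group) pair
    axes = tuple(range(1, x.ndim - 1)) + (x.ndim,)
    mean = grouped.mean(axis=axes, keepdims=True)
    var = grouped.var(axis=axes, keepdims=True)
    return ((grouped - mean) / np.sqrt(var + eps)).reshape(shape)

x = np.random.rand(2, 8, 8, 16).astype(np.float32)
y = group_norm_ref(x, group_size=4)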