Source code for niftynet.network.interventional_affine_net

# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function

import numpy as np
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import regularizers

from niftynet.engine.application_initializer import GlorotUniform
from niftynet.layer.convolution import ConvolutionalLayer as Conv
from niftynet.layer.downsample_res_block import DownBlock as DownRes
from niftynet.layer.fully_connected import FullyConnectedLayer as FC
from niftynet.layer.grid_warper import AffineGridWarperLayer as Grid
from niftynet.layer.layer_util import infer_spatial_rank
from niftynet.layer.linear_resize import LinearResizeLayer as Resize
from niftynet.network.base_net import BaseNet


class INetAffine(BaseNet):
    """
    ### Description
        This network estimates affine transformations from
        a pair of moving and fixed image:

            Hu et al., Label-driven weakly-supervised learning for
            multimodal deformable image registration, arXiv:1711.01666
            https://arxiv.org/abs/1711.01666

            Hu et al., Weakly-Supervised Convolutional Neural Networks for
            Multimodal Image Registration, Medical Image Analysis (2018)
            https://doi.org/10.1016/j.media.2018.07.002

        see also:
            https://github.com/YipengHu/label-reg

    ### Building blocks
    [DOWN CONV]  - Convolutional layer + Residual Unit + Max pooling
    [CONV]       - Convolutional layer
    [FC]         - Fully connected layer, outputs the affine matrix
    [WARPER]     - Grid resampling with the obtained affine matrix

    ### Diagram
    INPUT PAIR --> [DOWN CONV]x4 --> [CONV] --> [FC] --> [WARPER]
               --> DISPLACEMENT FIELD

    ### Constraints
        - input spatial rank should be either 2 or 3 (2D or 3D images only)
    """

    def __init__(self,
                 decay=1e-6,
                 affine_w_initializer=None,
                 affine_b_initializer=None,
                 acti_func='relu',
                 name='inet-affine'):
        """

        :param decay: float, regularisation decay
        :param affine_w_initializer: weight initialisation for affine
            registration network; ``None`` selects ``init_affine_w()``
            when the network is built
        :param affine_b_initializer: bias initialisation for affine
            registration network; ``None`` selects
            ``init_affine_b(spatial_rank)`` when the network is built
        :param acti_func: activation function to use
        :param name: layer name
        """
        BaseNet.__init__(self, name=name)

        # number of output channels: fea[0..3] feed the four down-sampling
        # residual blocks, fea[4] feeds the final convolution
        self.fea = [4, 8, 16, 32, 64]
        self.k_conv = 3
        self.affine_w_initializer = affine_w_initializer
        self.affine_b_initializer = affine_b_initializer
        # shared keyword arguments for the residual blocks and the last conv
        self.res_param = {
            'w_initializer': GlorotUniform.get_instance(''),
            'w_regularizer': regularizers.l2_regularizer(decay),
            'acti_func': acti_func}
        # keyword arguments for the fully connected (affine) layer
        self.affine_param = {
            'w_regularizer': regularizers.l2_regularizer(decay),
            'b_regularizer': None}

    def layer_op(self,
                 fixed_image,
                 moving_image,
                 is_training=True,
                 **unused_kwargs):
        """

        :param fixed_image: tensor, fixed image for registration
            (defines reference space)
        :param moving_image: tensor, moving image to be registered to fixed
        :param is_training: boolean, True if network is in training mode
        :return: displacement fields transformed by estimating affine
        :raises NotImplementedError: if the spatial rank is neither 2 nor 3
        """
        spatial_rank = infer_spatial_rank(moving_image)

        # Validate the rank before any layer is instantiated so that an
        # unsupported input fails without first building the conv stack.
        if spatial_rank == 2:
            affine_size = 6  # 2 x 3 affine matrix
        elif spatial_rank == 3:
            affine_size = 12  # 3 x 4 affine matrix
        else:
            tf.logging.fatal('Not supported spatial rank')
            raise NotImplementedError(
                'Not supported spatial rank: {}'.format(spatial_rank))

        # Resolve default initialisers into locals rather than writing them
        # back to ``self``: the bias default depends on the spatial rank of
        # this particular call and must not leak into the instance state.
        w_initializer = self.affine_w_initializer
        if w_initializer is None:
            w_initializer = init_affine_w()
        b_initializer = self.affine_b_initializer
        if b_initializer is None:
            b_initializer = init_affine_b(spatial_rank)

        # spatial dimensions only (first and last axes are dropped)
        spatial_shape = fixed_image.get_shape().as_list()[1:-1]

        # resize the moving image to match the fixed
        moving_image = Resize(spatial_shape)(moving_image)
        img = tf.concat([moving_image, fixed_image], axis=-1)

        # only element 0 of each down-sampling block's output is passed on
        res_1 = DownRes(self.fea[0], kernel_size=7, **self.res_param)(
            img, is_training)[0]
        res_2 = DownRes(self.fea[1], **self.res_param)(
            res_1, is_training)[0]
        res_3 = DownRes(self.fea[2], **self.res_param)(
            res_2, is_training)[0]
        res_4 = DownRes(self.fea[3], **self.res_param)(
            res_3, is_training)[0]

        conv_5 = Conv(n_output_chns=self.fea[4],
                      kernel_size=self.k_conv,
                      with_bias=False,
                      feature_normalization='batch',
                      **self.res_param)(res_4, is_training)

        # regress the affine parameters from the final feature map
        affine = FC(n_output_chns=affine_size,
                    feature_normalization=None,
                    w_initializer=w_initializer,
                    b_initializer=b_initializer,
                    **self.affine_param)(conv_5)

        # turn the affine parameters into a sampling grid in the
        # fixed-image space
        grid_global = Grid(source_shape=spatial_shape,
                           output_shape=spatial_shape)(affine)
        return grid_global


def init_affine_w(std=1e-8):
    """
    Create a zero-mean normal random initialiser for the affine weights.

    :param std: float, standard deviation of normal distribution
        for weight initialisation
    :return: random weight initialisation from normal distribution
        with zero mean
    """
    zero_mean = 0
    return tf.random_normal_initializer(mean=zero_mean, stddev=std)


def init_affine_b(spatial_rank, initial_bias=0.0):
    """
    Create the bias initialiser for the affine parameters.

    The initial value is the flattened identity transform, i.e. the first
    ``spatial_rank`` rows of the ``(spatial_rank + 1)``-sized identity
    matrix (2 x 3 in 2D, 3 x 4 in 3D), reshaped to a single row, with
    ``initial_bias`` added to every element.

    :param spatial_rank: int, rank of inputs (either 2D or 3D)
    :param initial_bias: float, initial bias
    :return: bias initialisation for the affine matrix
    :raises NotImplementedError: if ``spatial_rank`` is neither 2 nor 3
    """
    if spatial_rank not in (2, 3):
        tf.logging.fatal('Not supported spatial rank')
        raise NotImplementedError(
            'Not supported spatial rank: {}'.format(spatial_rank))

    rank = int(spatial_rank)
    # np.eye(rank, rank + 1) is always floating point, so the 2D and 3D
    # cases share one dtype regardless of the type of ``initial_bias``.
    identity = np.eye(rank, rank + 1).reshape([1, -1])
    return tf.constant_initializer(identity + initial_bias)