Source code for niftynet.network.interventional_dense_net

# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function

import numpy as np
import tensorflow as tf
from tensorflow.contrib.layers.python.layers import regularizers

from niftynet.engine.application_initializer import GlorotUniform
from niftynet.layer.convolution import ConvolutionalLayer as Conv
from niftynet.layer.downsample_res_block import DownBlock as DownRes
from niftynet.layer.grid_warper import _create_affine_features
from niftynet.layer.layer_util import infer_spatial_rank, check_spatial_dims
from niftynet.layer.linear_resize import LinearResizeLayer as Resize
from niftynet.layer.spatial_gradient import SpatialGradientLayer as ImgGrad
from niftynet.layer.upsample_res_block import UpBlock as UpRes
from niftynet.network.base_net import BaseNet


class INetDense(BaseNet):
    """
    Dense displacement-field estimator for pairwise image registration
    (encoder/decoder of residual down/up-sampling blocks with optional
    multi-scale fusion of the predicted fields).
    """

    def __init__(self,
                 decay=0.0,
                 smoothing=0,
                 disp_w_initializer=None,
                 disp_b_initializer=None,
                 acti_func='relu',
                 multi_scale_fusion=True,
                 name='inet-dense'):
        """
        The network estimates dense displacement fields from a pair
        of moving and fixed images:

            Hu et al., Label-driven weakly-supervised learning for
            multimodal deformable image registration, arXiv:1711.01666
            https://arxiv.org/abs/1711.01666

            Hu et al., Weakly-Supervised Convolutional Neural Networks for
            Multimodal Image Registration, Medical Image Analysis (2018)
            https://doi.org/10.1016/j.media.2018.07.002

            see also:
            https://github.com/YipengHu/label-reg

        :param decay: weight passed to ``l2_regularizer`` for the kernels of
            the down/up-sampling blocks and of the displacement convolutions
        :param smoothing: sigma (in voxels) of a Gaussian filter applied to
            the fused displacement field; no smoothing when ``<= 0``
        :param disp_w_initializer: initialisation of the kernels of the
            displacement-field output convolutions; passed through as given
            (``None`` is forwarded unchanged to the convolution layer)
        :param disp_b_initializer: initialisation of the biases of the
            displacement-field output convolutions; zeros when ``None``
        :param acti_func: activation function of the down/up-sampling blocks
        :param multi_scale_fusion: True/False indicating whether to use
            multiscale feature fusion.
        :param name: name of the network
        """
        BaseNet.__init__(self, name=name)

        # feature widths per resolution level; other widths that were
        # tried: [40, 80, 160, 320, 640], [32, 64, 128, 256, 512],
        # [16, 32, 64, 128, 256]
        self.fea = [30, 60, 120, 240, 480]
        self.k_conv = 3
        self.multi_scale_fusion = multi_scale_fusion

        self.down_res_param = {
            'w_initializer': GlorotUniform.get_instance(''),
            'w_regularizer': regularizers.l2_regularizer(decay),
            'acti_func': acti_func}

        self.up_res_param = {
            'acti_func': acti_func,
            'w_initializer': GlorotUniform.get_instance(''),
            'w_regularizer': regularizers.l2_regularizer(decay),
            'is_residual_upsampling': True,
            'type_string': 'bn_acti_conv'}

        # displacement initialiser & regulariser
        # The weight initialiser is deliberately left as supplied (possibly
        # None). Whether or not it is given must not affect the bias
        # initialiser: only a missing bias initialiser defaults to zeros.
        if disp_b_initializer is None:
            disp_b_initializer = tf.constant_initializer(0.0)
        self.disp_param = {
            'w_initializer': disp_w_initializer,
            'w_regularizer': regularizers.l2_regularizer(decay),
            'b_initializer': disp_b_initializer,
            'b_regularizer': None}

        if smoothing > 0:
            self.smoothing_func = _smoothing_func(smoothing)
        else:
            self.smoothing_func = None

    def layer_op(self,
                 fixed_image,
                 moving_image,
                 base_grid=None,
                 is_training=True,
                 **unused_kwargs):
        """
        :param fixed_image: fixed image tensor; every spatial dimension
            must be divisible by 16
        :param moving_image: moving image tensor; it is resized to the
            fixed image's spatial shape before the two are concatenated
            along the channel axis
        :param base_grid: reference grid added to the estimated
            displacement; built with ``_create_affine_features`` when None
        :param is_training: forwarded to the down/up-sampling blocks
        :return: estimated dense displacement field with ``base_grid``
            added to it
        """
        spatial_rank = infer_spatial_rank(fixed_image)
        spatial_shape = fixed_image.get_shape().as_list()[1:-1]
        # presumably each of the four DownRes stages halves the spatial
        # size, hence the divisibility-by-16 requirement
        check_spatial_dims(fixed_image, lambda x: x % 16 == 0)

        #  resize the moving image to match the fixed
        moving_image = Resize(spatial_shape)(moving_image)
        img = tf.concat([moving_image, fixed_image], axis=-1)

        # encoder: each DownRes also returns a skip tensor (conv_0_*)
        down_res_0, conv_0_0, _ = \
            DownRes(self.fea[0], kernel_size=7,
                    **self.down_res_param)(img, is_training)
        down_res_1, conv_0_1, _ = \
            DownRes(self.fea[1], **self.down_res_param)(down_res_0, is_training)
        down_res_2, conv_0_2, _ = \
            DownRes(self.fea[2], **self.down_res_param)(down_res_1, is_training)
        down_res_3, conv_0_3, _ = \
            DownRes(self.fea[3], **self.down_res_param)(down_res_2, is_training)

        conv_4 = Conv(n_output_chns=self.fea[4],
                      kernel_size=self.k_conv,
                      **self.down_res_param)(down_res_3, is_training)

        # decoder: up-sample and merge with the matching skip tensor
        up_res_0 = UpRes(self.fea[3], **self.up_res_param)(
            conv_4, conv_0_3, is_training)
        up_res_1 = UpRes(self.fea[2], **self.up_res_param)(
            up_res_0, conv_0_2, is_training)
        up_res_2 = UpRes(self.fea[1], **self.up_res_param)(
            up_res_1, conv_0_1, is_training)
        up_res_3 = UpRes(self.fea[0], **self.up_res_param)(
            up_res_2, conv_0_0, is_training)

        if self.multi_scale_fusion:
            output_list = [up_res_3, up_res_2, up_res_1, up_res_0, conv_4]
        else:
            output_list = [up_res_3]

        # converting all output layers to displacement fields
        # (spatial_rank channels each, resized to the fixed image shape)
        dense_fields = []
        for scale_out in output_list:
            field = Conv(n_output_chns=spatial_rank,
                         kernel_size=self.k_conv,
                         with_bias=True,
                         with_bn=False,
                         acti_func=None,
                         **self.disp_param)(scale_out)
            resized_field = Resize(new_size=spatial_shape)(field)
            dense_fields.append(resized_field)

        if base_grid is None:
            # adding a reference grid if it doesn't exist
            in_spatial_size = [None] * spatial_rank
            base_grid = _create_affine_features(output_shape=spatial_shape,
                                                source_shape=in_spatial_size)
            # drop the last row of the affine features before reshaping
            # into [1, *spatial_shape, spatial_rank]
            base_grid = np.asarray(base_grid[:-1])
            base_grid = np.reshape(
                base_grid.T, [-1] + spatial_shape + [spatial_rank])
            base_grid = tf.constant(base_grid, dtype=dense_fields[0].dtype)

        if self.multi_scale_fusion and len(dense_fields) > 1:
            dense_field = tf.reduce_sum(dense_fields, axis=0)
        else:
            dense_field = dense_fields[0]

        # TODO filtering
        if self.smoothing_func is not None:
            dense_field = self.smoothing_func(dense_field, spatial_rank)

        # regularisation terms are exposed through graph collections
        tf.add_to_collection('bending_energy',
                             _computing_bending_energy(dense_field))
        tf.add_to_collection('gradient_norm',
                             _computing_gradient_norm(dense_field))

        dense_field = dense_field + base_grid
        return dense_field
def _get_smoothing_kernel(sigma, spatial_rank):
    """
    Build a normalised, truncated Gaussian kernel.

    :param sigma: standard deviation in voxels (not in free-form
        deformation grid units); must be positive
    :param spatial_rank: number of spatial dimensions, 2 or 3
    :return: numpy array with ``2 * int(2 * sigma) + 1`` entries along each
        of ``spatial_rank`` axes, summing to one
    :raises NotImplementedError: if ``sigma <= 0`` or ``spatial_rank`` is
        neither 2 nor 3
    """
    # sigma defined in voxel not in freeform deformation grid
    if sigma <= 0:
        raise NotImplementedError(
            "Gaussian smoothing requires sigma > 0, got {}".format(sigma))
    if spatial_rank not in (2, 3):
        raise NotImplementedError(
            "Gaussian smoothing kernel for {}-d input".format(spatial_rank))
    # truncate the kernel at two standard deviations
    tail = int(sigma * 2)
    coords = np.mgrid[(slice(-tail, tail + 1),) * spatial_rank]
    sq_dist = np.sum(coords * coords, axis=0)
    # parenthesised variance: ``/ sigma * sigma`` would cancel sigma out
    g = np.exp(-0.5 * sq_dist / (sigma * sigma))
    return g / g.sum()


def _smoothing_func(sigma):
    """
    Create a function that Gaussian-smooths each component of a field.

    :param sigma: standard deviation in voxels (see ``_get_smoothing_kernel``)
    :return: callable ``smoothing(dense_field, spatial_rank)`` that filters
        every channel of ``dense_field`` independently with the same kernel
        ('SAME' padding) and concatenates the results on the last axis
    """
    def smoothing(dense_field, spatial_rank):
        kernel = _get_smoothing_kernel(sigma, spatial_rank)
        kernel = tf.constant(kernel, dtype=dense_field.dtype)
        # append input/output channel axes: [*kernel_spatial, 1, 1]
        kernel = tf.expand_dims(kernel, axis=-1)
        kernel = tf.expand_dims(kernel, axis=-1)
        smoothed = [
            tf.nn.convolution(tf.expand_dims(coord, axis=-1), kernel,
                              padding='SAME')
            for coord in tf.unstack(dense_field, axis=-1)]
        return tf.concat(smoothed, axis=-1)

    return smoothing


def _computing_bending_energy(displacement):
    """
    Bending energy of a displacement field, dispatched on spatial rank.

    :param displacement: dense displacement tensor
    :return: scalar tensor
    :raises NotImplementedError: for spatial ranks other than 2 or 3
    """
    spatial_rank = infer_spatial_rank(displacement)
    if spatial_rank == 2:
        return _computing_bending_energy_2d(displacement)
    if spatial_rank == 3:
        return _computing_bending_energy_3d(displacement)
    raise NotImplementedError(
        "Not implmented: bending energy for {}-d input".format(spatial_rank))


def _computing_bending_energy_2d(displacement):
    """
    Mean of the squared second-order finite differences of a 2-D field,
    with the mixed (xy) term weighted by 2.

    NOTE(review): ``tf.reduce_mean`` over the list stacks the terms, which
    assumes ``ImgGrad`` yields equally shaped outputs for every axis.
    """
    dTdx = ImgGrad(spatial_axis=0)(displacement)
    dTdy = ImgGrad(spatial_axis=1)(displacement)
    dTdxx = ImgGrad(spatial_axis=0)(dTdx)
    dTdyy = ImgGrad(spatial_axis=1)(dTdy)
    dTdxy = ImgGrad(spatial_axis=1)(dTdx)
    energy = tf.reduce_mean([dTdxx * dTdxx, dTdyy * dTdyy, 2 * dTdxy * dTdxy])
    return energy


def _computing_bending_energy_3d(displacement):
    """
    Mean of the squared second-order finite differences of a 3-D field,
    with the mixed (xy, xz, yz) terms weighted by 2.

    NOTE(review): as in the 2-D case, stacking the terms assumes ``ImgGrad``
    yields equally shaped outputs for every axis.
    """
    dTdx = ImgGrad(spatial_axis=0)(displacement)
    dTdy = ImgGrad(spatial_axis=1)(displacement)
    dTdz = ImgGrad(spatial_axis=2)(displacement)
    dTdxx = ImgGrad(spatial_axis=0)(dTdx)
    dTdyy = ImgGrad(spatial_axis=1)(dTdy)
    dTdzz = ImgGrad(spatial_axis=2)(dTdz)
    dTdxy = ImgGrad(spatial_axis=1)(dTdx)
    dTdyz = ImgGrad(spatial_axis=2)(dTdy)
    dTdxz = ImgGrad(spatial_axis=2)(dTdx)
    energy = tf.reduce_mean(
        [dTdxx * dTdxx, dTdyy * dTdyy, dTdzz * dTdzz,
         2 * dTdxy * dTdxy, 2 * dTdxz * dTdxz, 2 * dTdyz * dTdyz])
    return energy


def _computing_gradient_norm(displacement, flag_L1=False):
    """
    Mean first-order gradient magnitude of a displacement field.

    :param displacement: dense displacement tensor
    :param flag_L1: if True accumulate ``|dT/dt|``, otherwise ``(dT/dt)**2``
    :return: scalar tensor, the mean over all spatial axes and voxels
    """
    norms = []
    for spatial_ind in range(infer_spatial_rank(displacement)):
        dTdt = ImgGrad(spatial_axis=spatial_ind)(displacement)
        if flag_L1:
            norms.append(tf.abs(dTdt))
        else:
            norms.append(dTdt * dTdt)
    return tf.reduce_mean(norms)