# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function, division
import tensorflow as tf
from niftynet.layer.base_layer import TrainableLayer
from niftynet.layer.convolution import ConvolutionalLayer
from niftynet.network.base_net import BaseNet
from niftynet.network.highres3dnet import HighRes3DNet, HighResBlock
from niftynet.utilities.util_common import look_up_operations
class ScaleNet(BaseNet):
    """
    Implementation of ScaleNet:

        Fidon et al., "Scalable multimodal convolutional
        networks for brain tumour segmentation", MICCAI '17

    ### Diagram

        INPUT --> [BACKEND] ----> [MERGING] ----> [FRONTEND] ---> OUTPUT

    [BACKEND] and [MERGING] are provided by the ScaleBlock below
    [FRONTEND]: it can be any NiftyNet network (default: HighRes3dnet)

    ### Constraints
    - Input image size should be divisible by 8
    - more than one modality should be used
    """

    def __init__(self,
                 num_classes,
                 w_initializer=None,
                 w_regularizer=None,
                 b_initializer=None,
                 b_regularizer=None,
                 acti_func='prelu',
                 name='ScaleNet'):
        """
        :param num_classes: int, number of channels of output
        :param w_initializer: weight initialisation for network
        :param w_regularizer: weight regularisation for network
        :param b_initializer: bias initialisation for network
        :param b_regularizer: bias regularisation for network
        :param acti_func: activation function to use
        :param name: layer name
        """
        super(ScaleNet, self).__init__(
            num_classes=num_classes,
            w_initializer=w_initializer,
            w_regularizer=w_regularizer,
            b_initializer=b_initializer,
            b_regularizer=b_regularizer,
            acti_func=acti_func,
            name=name)

        # number of feature channels each per-modality branch is mapped to
        self.n_features = 16

    def layer_op(self, images, is_training=True, layer_id=-1, **unused_kwargs):
        """
        :param images: tensor, concatenation of multiple input modalities
            along the last (channel) axis
        :param is_training: boolean, True if network is in training mode
        :param layer_id: not in use
        :param unused_kwargs: additional keyword arguments, not in use
        :return: predicted tensor
        """
        n_modality = images.shape.as_list()[-1]
        rank = images.shape.ndims
        # ScaleNet is only meaningful for multi-modal input; raise instead
        # of assert so the check is not stripped under `python -O`
        if n_modality <= 1:
            raise ValueError(
                'ScaleNet requires more than one input modality, '
                'got {}'.format(n_modality))

        # per-modality branch: map each modality to self.n_features channels
        roots = tf.split(images, n_modality, axis=rank - 1)
        for (idx, root) in enumerate(roots):
            conv_layer = ConvolutionalLayer(
                n_output_chns=self.n_features,
                kernel_size=3,
                w_initializer=self.initializers['w'],
                w_regularizer=self.regularizers['w'],
                acti_func=self.acti_func,
                name='conv_{}'.format(idx))
            roots[idx] = conv_layer(root, is_training)
        # stack branches along a new trailing axis:
        # result shape is (..., n_features, n_modality)
        roots = tf.stack(roots, axis=-1)

        # [BACKEND] + [MERGING]: cross-modality mixing then averaging
        back_end = ScaleBlock('AVERAGE', n_layers=1)
        output_tensor = back_end(roots, is_training)

        # [FRONTEND]: default segmentation network
        front_end = HighRes3DNet(self.num_classes)
        output_tensor = front_end(output_tensor, is_training)
        return output_tensor
# merging operations accepted by ScaleBlock (see ScaleBlock.layer_op)
SUPPORTED_OP = {'MAX', 'AVERAGE'}
class ScaleBlock(TrainableLayer):
    """
    Implementation of the ScaleBlock described in

        Fidon et al., "Scalable multimodal convolutional
        networks for brain tumour segmentation", MICCAI '17

    See Fig 2(a) for diagram details - SN BackEnd
    """

    def __init__(self,
                 func,
                 n_layers=1,
                 w_initializer=None,
                 w_regularizer=None,
                 acti_func='relu',
                 name='scaleblock'):
        """
        :param func: merging function (SUPPORTED_OP: MAX, AVERAGE)
        :param n_layers: int, number of layers
        :param w_initializer: weight initialisation for network
        :param w_regularizer: weight regularisation for network
        :param acti_func: activation function to use
        :param name: layer name
        """
        # validate the merging op first so an unsupported name fails fast
        self.func = look_up_operations(func.upper(), SUPPORTED_OP)
        super(ScaleBlock, self).__init__(name=name)
        self.n_layers = n_layers
        self.acti_func = acti_func
        self.initializers = {'w': w_initializer}
        self.regularizers = {'w': w_regularizer}

    def _apply_highres_blocks(self, tensors, n_output_chns,
                              name_format, is_training):
        """
        Apply one HighResBlock to each tensor of an unstacked axis.

        :param tensors: list of tensors (slices of the unstacked last axis);
            modified in place and also returned
        :param n_output_chns: number of output channels of each block
        :param name_format: format string taking the slice index,
            e.g. ``'M_F_0_{}'``
        :param is_training: boolean, True if network is in training mode
        :return: the list of transformed tensors
        """
        for (idx, tensor) in enumerate(tensors):
            highresblock_op = HighResBlock(
                n_output_chns=n_output_chns,
                kernels=(3, 1),
                with_res=True,
                w_initializer=self.initializers['w'],
                w_regularizer=self.regularizers['w'],
                acti_func=self.acti_func,
                name=name_format.format(idx))
            tensors[idx] = highresblock_op(tensor, is_training)
            # layer-construction logging, matching NiftyNet convention
            print(highresblock_op)
        return tensors

    def layer_op(self, input_tensor, is_training):
        """
        :param input_tensor: tensor, input to the network; the last axis
            indexes modalities and the second-to-last indexes feature
            channels (as produced by ScaleNet.layer_op)
        :param is_training: boolean, True if network is in training mode
        :return: merged tensor after backend layers
        """
        n_modality = input_tensor.shape.as_list()[-1]
        n_chns = input_tensor.shape.as_list()[-2]
        rank = input_tensor.shape.ndims
        # permutation that swaps the last two axes
        # (feature channels <-> modalities)
        perm = list(range(rank))
        perm[-2], perm[-1] = perm[-1], perm[-2]

        output_tensor = input_tensor
        for layer in range(self.n_layers):
            # modalities => feature channels:
            # mix information across modalities for each feature channel
            output_tensor = tf.transpose(output_tensor, perm=perm)
            output_tensor = tf.unstack(output_tensor, axis=-1)
            output_tensor = self._apply_highres_blocks(
                output_tensor, n_modality,
                'M_F_{}_{{}}'.format(layer), is_training)
            output_tensor = tf.stack(output_tensor, axis=-1)

            # feature channels => modalities:
            # mix information across feature channels for each modality
            output_tensor = tf.transpose(output_tensor, perm=perm)
            output_tensor = tf.unstack(output_tensor, axis=-1)
            output_tensor = self._apply_highres_blocks(
                output_tensor, n_chns,
                'F_M_{}_{{}}'.format(layer), is_training)
            output_tensor = tf.stack(output_tensor, axis=-1)

        # [MERGING]: reduce over the modality axis with the validated op
        if self.func == 'MAX':
            output_tensor = tf.reduce_max(output_tensor, axis=-1)
        elif self.func == 'AVERAGE':
            output_tensor = tf.reduce_mean(output_tensor, axis=-1)
        return output_tensor