# Source code for niftynet.network.holistic_net

# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function

import tensorflow as tf

from niftynet.layer.base_layer import TrainableLayer
from niftynet.layer.convolution import ConvLayer
from niftynet.layer.convolution import ConvolutionalLayer
from niftynet.layer.dilatedcontext import DilatedTensor
from niftynet.layer.downsample import DownSampleLayer
from niftynet.layer.upsample import UpSampleLayer
from niftynet.network.base_net import BaseNet
from niftynet.network.highres3dnet import HighResBlock


class HolisticNet(BaseNet):
    """
    Implementation of HolisticNet detailed in
    Fidon, L. et al. (2017) Generalised Wasserstein Dice Score for Imbalanced
    Multi-class Segmentation using Holistic Convolutional Networks.
    MICCAI 2017 (BrainLes)

    The network builds four scales of HighResBlock stacks:

    - scales 1 and 2 at the input resolution (dilation 1 and 2);
    - scales 3 and 4 after a stride-2 average pooling (dilation 1 and 2).

    Each scale produces a class-score map through a ``ScoreLayer``. The maps
    of scales 3 and 4 are upsampled by a factor of 2. The softmax of all four
    maps is fused with ``MergeLayer('WEIGHTED_AVERAGE')``.
    """

    def __init__(self,
                 num_classes,
                 w_initializer=None,
                 w_regularizer=None,
                 b_initializer=None,
                 b_regularizer=None,
                 acti_func='elu',
                 name='HolisticNet'):
        super(HolisticNet, self).__init__(
            num_classes=num_classes,
            acti_func=acti_func,
            name=name,
            w_initializer=w_initializer,
            w_regularizer=w_regularizer,
            b_initializer=b_initializer,
            b_regularizer=b_regularizer)

        # Number of HighResBlocks at each of the four scales.
        self.num_res_blocks = [3, 3, 3, 3]
        # Number of feature channels of the HighResBlocks at each scale.
        self.num_features = [70] * 4
        # ScoreLayer channel configuration for each scale. The comprehension
        # gives every scale its own list instead of four references to one.
        self.num_fea_score_layers = [[70, 140] for _ in range(4)]

    def layer_op(self,
                 input_tensor,
                 is_training=True,
                 layer_id=-1,
                 **unused_kwargs):
        """
        Build the multi-scale network on ``input_tensor``.

        :param input_tensor: image tensor fed to the first convolution
        :param is_training: flag forwarded to every sub-layer (used, e.g.,
            by batch normalisation); it is also tested with a Python ``if``
            to select the return value
        :param layer_id: unused, kept for interface compatibility
        :return: if ``is_training`` is truthy, a list of the four per-scale
            score maps followed by the fused map; otherwise only the fused
            map
        """
        first_conv_layer = ConvolutionalLayer(
            n_output_chns=self.num_features[0],
            with_bn=True,
            kernel_size=3,
            w_initializer=self.initializers['w'],
            w_regularizer=self.regularizers['w'],
            acti_func=self.acti_func,
            name='conv_1_1')
        flow = first_conv_layer(input_tensor, is_training)

        scores_instances = []

        # SCALE 1: input resolution, dilation factor 1
        flow = self._residual_stack(flow, 0, 1, is_training)
        scores_instances.append(self._score_map(flow, 0, is_training))

        # SCALE 2: input resolution, dilation factor 2
        flow = self._residual_stack(flow, 1, 2, is_training)
        scores_instances.append(self._score_map(flow, 1, is_training))

        # SCALE 3: downsampling factor 2 (average pooling), dilation factor 1
        downsample_scale3 = DownSampleLayer(
            func='AVG', kernel_size=2, stride=2)
        flow = downsample_scale3(flow)
        flow = self._residual_stack(flow, 2, 1, is_training)
        score_3 = self._score_map(flow, 2, is_training)
        upsample_indep_scale3 = UpSampleLayer(
            func='CHANNELWISE_DECONV',
            kernel_size=2,
            stride=2,
            w_initializer=tf.constant_initializer(1.0, dtype=tf.float32))
        scores_instances.append(upsample_indep_scale3(score_3))

        # SCALE 4: same (downsampled) resolution as scale 3, dilation factor 2
        flow = self._residual_stack(flow, 3, 2, is_training)
        # ScoreLayer.layer_op takes (input_tensor, is_training). Passing the
        # flag directly keeps batch norm at this scale in step with the
        # requested mode.
        score_4 = self._score_map(flow, 3, is_training)
        # NOTE(review): scale 3 upsamples with kernel_size=2, but this uses
        # kernel_size=1. A stride-2 transposed convolution with a 1-voxel
        # kernel presumably leaves gaps between the copied voxels. Confirm
        # whether kernel_size=2 was intended; changing it alters the
        # variable shapes stored in existing checkpoints.
        upsample_indep_scale4 = UpSampleLayer(
            func='CHANNELWISE_DECONV',
            kernel_size=1,
            stride=2,
            w_initializer=tf.constant_initializer(1.0, dtype=tf.float32))
        scores_instances.append(upsample_indep_scale4(score_4))

        # FUSED SCALES: learned merge of the per-scale softmax maps.
        # Per-scale losses can be attached by the caller to the list returned
        # in training mode.
        merge_layer = MergeLayer('WEIGHTED_AVERAGE')
        soft_scores = [tf.nn.softmax(score) for score in scores_instances]
        fused_score = merge_layer(soft_scores)
        scores_instances.append(fused_score)
        if is_training:
            return scores_instances
        return fused_score

    def _residual_stack(self, flow, scale, dilation_factor, is_training):
        """
        Apply the HighResBlocks of one scale inside a DilatedTensor context.

        :param flow: input feature tensor
        :param scale: 0-based scale index into the per-scale settings; blocks
            are named ``res_<scale + 1>_<j>``
        :param dilation_factor: dilation factor given to ``DilatedTensor``
        :param is_training: forwarded to each HighResBlock
        :return: ``dilated.tensor`` read after the context exits
        """
        with DilatedTensor(flow, dilation_factor=dilation_factor) as dilated:
            for j in range(self.num_res_blocks[scale]):
                res_block = HighResBlock(
                    self.num_features[scale],
                    acti_func=self.acti_func,
                    w_initializer=self.initializers['w'],
                    w_regularizer=self.regularizers['w'],
                    name='res_%d_%d' % (scale + 1, j))
                dilated.tensor = res_block(dilated.tensor, is_training)
        return dilated.tensor

    def _score_map(self, flow, scale, is_training):
        """Build the ScoreLayer of ``scale`` (0-based) and apply it to flow."""
        score_layer = ScoreLayer(
            num_features=self.num_fea_score_layers[scale],
            num_classes=self.num_classes)
        return score_layer(flow, is_training)
class ScoreLayer(TrainableLayer):
    """
    Map a feature tensor to ``num_classes`` score channels.

    ``num_features`` is a list of channel counts. Its first entry is not used
    to build a layer (presumably it describes the input; not checked here).
    Each following entry adds one ``ConvolutionalLayer`` with kernel size 3,
    batch norm and ``acti_func``. A final ``ConvolutionalLayer`` with kernel
    size 1 produces ``num_classes`` output channels.
    """

    def __init__(self,
                 num_features=None,
                 w_initializer=None,
                 w_regularizer=None,
                 num_classes=1,
                 acti_func='elu',
                 name='ScoreLayer'):
        super(ScoreLayer, self).__init__(name=name)
        self.num_classes = num_classes
        self.acti_func = acti_func
        # ``None`` (the default) means "no hidden layers" instead of failing
        # in len(None).
        self.num_features = [] if num_features is None else num_features
        self.n_layers = len(self.num_features)
        self.initializers = {'w': w_initializer}
        self.regularizers = {'w': w_regularizer}

    def layer_op(self, input_tensor, is_training, layer_id=-1):
        """
        :param input_tensor: feature tensor to score
        :param is_training: forwarded to every ConvolutionalLayer
        :param layer_id: unused, kept for interface compatibility
        :return: output of the final kernel-size-1 ConvolutionalLayer, with
            ``num_classes`` channels
        """
        output_tensor = input_tensor
        # Every layer except the last one: BN + Conv (kernel 3) + activation,
        # one per entry of num_features after the first.
        for layer, n_chns in enumerate(self.num_features[1:]):
            hidden_layer = ConvolutionalLayer(
                n_output_chns=n_chns,
                with_bn=True,
                kernel_size=3,
                w_initializer=self.initializers['w'],
                w_regularizer=self.regularizers['w'],
                acti_func=self.acti_func,
                name='conv_fc_%d' % layer)
            output_tensor = hidden_layer(output_tensor, is_training)
        # Last layer: kernel-size-1 projection to the class channels, using
        # the ConvolutionalLayer defaults for everything else.
        last_layer = ConvolutionalLayer(n_output_chns=self.num_classes,
                                        kernel_size=1)
        return last_layer(output_tensor, is_training)
# Merge functions accepted by MergeLayer.
SUPPORTED_OPS = {'AVERAGE', 'WEIGHTED_AVERAGE', 'MAXOUT'}
class MergeLayer(TrainableLayer):
    """
    Merge a list of equally-shaped tensors into a single tensor.

    ``func`` selects the merge (see ``SUPPORTED_OPS``):

    - ``'MAXOUT'``: element-wise maximum over the inputs;
    - ``'AVERAGE'``: element-wise mean over the inputs;
    - ``'WEIGHTED_AVERAGE'``: for each channel (last axis) independently, a
      learned kernel-size-1 ``ConvLayer`` with one output channel combines
      that channel across the inputs. The weights are not constrained to sum
      to one.
    """

    def __init__(self,
                 func,
                 w_initializer=None,
                 w_regularizer=None,
                 acti_func='elu',
                 name='MergeLayer'):
        super(MergeLayer, self).__init__(name=name)
        self.func = func
        # Stored for interface compatibility; layer_op does not use it.
        self.acti_func = acti_func
        self.initializers = {'w': w_initializer}
        self.regularizers = {'w': w_regularizer}

    def layer_op(self, roots):
        """
        :param roots: list of tensors with identical shapes; the last axis is
            treated as the channel axis
        :return: the merged tensor
        :raises ValueError: if ``self.func`` is not a supported merge
        """
        if self.func == 'MAXOUT':
            return tf.reduce_max(tf.stack(roots, axis=-1), axis=-1)
        if self.func == 'AVERAGE':
            return tf.reduce_mean(tf.stack(roots, axis=-1), axis=-1)
        if self.func == 'WEIGHTED_AVERAGE':
            return self._weighted_average(roots)
        # Previously an unknown func fell through and returned None.
        raise ValueError(
            'Unsupported merge function %r; expected one of %s'
            % (self.func, sorted(SUPPORTED_OPS)))

    def _weighted_average(self, roots):
        """Combine each channel across ``roots`` with its own 1x1 conv."""
        # stack -> [..., channels, n_roots]; swap the last two axes ->
        # [..., n_roots, channels]; unstack the last axis -> one
        # [..., n_roots] tensor per channel.
        stacked = tf.stack(roots, axis=-1)
        perm = list(range(stacked.shape.ndims))
        perm[-2], perm[-1] = perm[-1], perm[-2]
        per_channel = tf.unstack(tf.transpose(stacked, perm=perm), axis=-1)

        merged = []
        for channel_tensor in per_channel:
            # Forward the user-supplied initializer/regularizer; None keeps
            # ConvLayer's own default.
            conv_layer = ConvLayer(
                n_output_chns=1,
                kernel_size=1,
                stride=1,
                w_initializer=self.initializers['w'],
                w_regularizer=self.regularizers['w'])
            merged.append(conv_layer(channel_tensor))
        return tf.concat(merged, axis=-1)