# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
import tensorflow as tf
from niftynet.layer.base_layer import TrainableLayer
from niftynet.layer.convolution import ConvLayer
from niftynet.layer.convolution import ConvolutionalLayer
from niftynet.layer.dilatedcontext import DilatedTensor
from niftynet.layer.downsample import DownSampleLayer
from niftynet.layer.upsample import UpSampleLayer
from niftynet.network.base_net import BaseNet
from niftynet.network.highres3dnet import HighResBlock
class HolisticNet(BaseNet):
    """
    Implementation of HolisticNet detailed in
    Fidon, L. et al. (2017) Generalised Wasserstein Dice Score for Imbalanced
    Multi-class Segmentation using Holistic Convolutional Networks.
    MICCAI 2017 (BrainLes)

    Class scores are computed at four stages:

    * scale 1: full resolution, HighResBlocks with dilation factor 1;
    * scale 2: full resolution, HighResBlocks with dilation factor 2;
    * scale 3: after a stride-2 average-pooling downsample, dilation 1;
    * scale 4: same resolution as scale 3, dilation 2.

    Scores of scales 3 and 4 are upsampled back by a factor 2 with a
    channel-wise deconvolution. Every score map is passed through a softmax
    and the results are fused by ``MergeLayer('WEIGHTED_AVERAGE')``.
    """

    def __init__(self,
                 num_classes,
                 w_initializer=None,
                 w_regularizer=None,
                 b_initializer=None,
                 b_regularizer=None,
                 acti_func='elu',
                 name='HolisticNet'):
        super(HolisticNet, self).__init__(
            num_classes=num_classes,
            acti_func=acti_func,
            name=name,
            w_initializer=w_initializer,
            w_regularizer=w_regularizer,
            b_initializer=b_initializer,
            b_regularizer=b_regularizer)
        # Number of HighResBlocks at each of the four scales.
        self.num_res_blocks = [3, 3, 3, 3]
        # Number of feature channels at each of the four scales.
        self.num_features = [70] * 4
        # ScoreLayer channel configuration per scale. The first entry is not
        # used as an output size by ScoreLayer (it mirrors the incoming
        # channel count); the following entries are hidden conv widths.
        # A comprehension gives each scale its own list instead of four
        # references to a single shared list.
        self.num_fea_score_layers = [[70, 140] for _ in range(4)]

    def layer_op(self,
                 input_tensor,
                 is_training=True,
                 layer_id=-1,
                 **unused_kwargs):
        """
        Build the multi-scale HolisticNet graph.

        :param input_tensor: image tensor fed to the first convolution
        :param is_training: forwarded to every sub-layer that takes it
            (they are built with batch normalisation)
        :param layer_id: unused, kept for interface compatibility
        :param unused_kwargs: ignored
        :return: if ``is_training`` is true, the list
            ``[score_1, score_2, up_score_3, up_score_4, fused_score]``
            where the per-scale entries are raw ScoreLayer outputs (scales
            3 and 4 upsampled) and ``fused_score`` is the merge of their
            softmaxes; otherwise only ``fused_score``.
        """
        layer_instances = []
        scores_instances = []

        first_conv_layer = ConvolutionalLayer(
            n_output_chns=self.num_features[0],
            with_bn=True,
            kernel_size=3,
            w_initializer=self.initializers['w'],
            w_regularizer=self.regularizers['w'],
            acti_func=self.acti_func,
            name='conv_1_1')
        flow = first_conv_layer(input_tensor, is_training)
        layer_instances.append((first_conv_layer, flow))

        # NOTE: layers are created in exactly the same order as before the
        # refactoring so that the automatically uniquified variable scopes
        # (e.g. ScoreLayer, ScoreLayer_1, ...) keep the same names.

        # SCALE 1: full resolution, dilation factor 1.
        flow = self._res_blocks(flow, 0, 1, is_training, layer_instances)
        scores_instances.append(self._score(flow, 0, is_training))

        # SCALE 2: full resolution, dilation factor 2 (no upsampling needed).
        flow = self._res_blocks(flow, 1, 2, is_training, layer_instances)
        scores_instances.append(self._score(flow, 1, is_training))

        # SCALE 3: downsample by 2 (average pooling), dilation factor 1.
        downsample_scale3 = DownSampleLayer(
            func='AVG', kernel_size=2, stride=2)
        flow = downsample_scale3(flow)
        layer_instances.append((downsample_scale3, flow))
        flow = self._res_blocks(flow, 2, 1, is_training, layer_instances)
        score_3 = self._score(flow, 2, is_training)
        scores_instances.append(self._upsample_score(score_3))

        # SCALE 4: same (downsampled) resolution as scale 3, dilation 2.
        flow = self._res_blocks(flow, 3, 2, is_training, layer_instances)
        score_4 = self._score(flow, 3, is_training)
        scores_instances.append(self._upsample_score(score_4))

        # FUSED SCALES: softmax each scale's scores, then merge them.
        merge_layer = MergeLayer('WEIGHTED_AVERAGE')
        soft_scores = [tf.nn.softmax(score) for score in scores_instances]
        fused_score = merge_layer(soft_scores)
        scores_instances.append(fused_score)
        if is_training:
            return scores_instances
        return fused_score

    def _res_blocks(self, flow, scale, dilation_factor, is_training,
                    layer_instances):
        """Apply the HighResBlocks of ``scale`` (0-based) under the given
        dilation factor, record them in ``layer_instances`` and return the
        output tensor."""
        with DilatedTensor(flow, dilation_factor=dilation_factor) as dilated:
            for j in range(self.num_res_blocks[scale]):
                res_block = HighResBlock(
                    self.num_features[scale],
                    acti_func=self.acti_func,
                    w_initializer=self.initializers['w'],
                    w_regularizer=self.regularizers['w'],
                    name='res_%d_%d' % (scale + 1, j))
                dilated.tensor = res_block(dilated.tensor, is_training)
                layer_instances.append((res_block, dilated.tensor))
        # Read the tensor after leaving the context manager, as the
        # original code did.
        return dilated.tensor

    def _score(self, flow, scale, is_training):
        """Create the ScoreLayer of ``scale`` (0-based) and apply it."""
        score_layer = ScoreLayer(
            num_features=self.num_fea_score_layers[scale],
            w_initializer=self.initializers['w'],
            w_regularizer=self.regularizers['w'],
            num_classes=self.num_classes,
            acti_func=self.acti_func)
        # ScoreLayer.layer_op takes (input_tensor, is_training). Scale 4
        # previously also passed the feature list positionally, so the
        # always-truthy list ended up as ``is_training``.
        return score_layer(flow, is_training)

    @staticmethod
    def _upsample_score(score):
        """Upsample a half-resolution score map by 2 with a channel-wise
        deconvolution whose weights are initialised to 1.0."""
        # kernel_size == stride: with a 1-wide kernel and stride 2 (as
        # scale 4 used) the transposed convolution only writes one voxel
        # per stride block and leaves the others at zero.
        upsample = UpSampleLayer(
            func='CHANNELWISE_DECONV',
            kernel_size=2,
            stride=2,
            w_initializer=tf.constant_initializer(1.0, dtype=tf.float32))
        return upsample(score)
class ScoreLayer(TrainableLayer):
    """
    Stack of convolutions mapping a feature map to ``num_classes`` scores.

    ``num_features`` is a sequence of channel counts. Its first entry is not
    used as an output size (it mirrors the channel count of the incoming
    feature map). Each following entry ``num_features[i]`` adds one
    ``ConvolutionalLayer`` (batch norm, kernel size 3, ``acti_func``) with
    that many output channels. A final kernel-size-1 ``ConvolutionalLayer``
    produces ``num_classes`` channels. ``num_features=None`` means no hidden
    layers, only that final projection.
    """

    def __init__(self,
                 num_features=None,
                 w_initializer=None,
                 w_regularizer=None,
                 num_classes=1,
                 acti_func='elu',
                 name='ScoreLayer'):
        super(ScoreLayer, self).__init__(name=name)
        self.num_classes = num_classes
        self.acti_func = acti_func
        # Copy so that a list shared by several callers is never aliased;
        # ``None`` (the default) previously crashed on ``len(None)``.
        self.num_features = (
            list(num_features) if num_features is not None else [])
        self.n_layers = len(self.num_features)
        self.initializers = {'w': w_initializer}
        self.regularizers = {'w': w_regularizer}

    def layer_op(self, input_tensor, is_training, layer_id=-1):
        """
        Apply the hidden convolutions and the final scoring convolution.

        :param input_tensor: feature map to score
        :param is_training: forwarded to every ConvolutionalLayer
        :param layer_id: unused, kept for interface compatibility
        :return: tensor with ``num_classes`` output channels
        """
        output_tensor = input_tensor
        # All layers except the last one: BN + Conv (kernel 3) + Activation.
        for layer in range(self.n_layers - 1):
            hidden_layer = ConvolutionalLayer(
                n_output_chns=self.num_features[layer + 1],
                with_bn=True,
                kernel_size=3,
                w_initializer=self.initializers['w'],
                w_regularizer=self.regularizers['w'],
                acti_func=self.acti_func,
                name='conv_fc_%d' % layer)
            output_tensor = hidden_layer(output_tensor, is_training)
        # Final projection to class scores. It now uses the same weight
        # initializer/regularizer as the hidden layers instead of silently
        # ignoring them.
        last_layer = ConvolutionalLayer(
            n_output_chns=self.num_classes,
            kernel_size=1,
            w_initializer=self.initializers['w'],
            w_regularizer=self.regularizers['w'])
        return last_layer(output_tensor, is_training)
# Merge functions understood by MergeLayer.
SUPPORTED_OPS = {'AVERAGE', 'WEIGHTED_AVERAGE', 'MAXOUT'}
class MergeLayer(TrainableLayer):
    """
    Merge a list of same-shaped tensors into a single tensor.

    Supported ``func`` values (see ``SUPPORTED_OPS``):

    * ``'MAXOUT'``: element-wise maximum over the inputs;
    * ``'AVERAGE'``: element-wise mean over the inputs;
    * ``'WEIGHTED_AVERAGE'``: for every index of the inputs' last axis, a
      learned kernel-size-1 ``ConvLayer`` combines that slice across all
      inputs into one output channel; the per-index outputs are
      concatenated along the last axis.

    :raises ValueError: if ``func`` is not one of ``SUPPORTED_OPS``.
    """

    def __init__(self,
                 func,
                 w_initializer=None,
                 w_regularizer=None,
                 acti_func='elu',
                 name='MergeLayer'):
        # Fail fast: an unknown ``func`` used to make layer_op return None.
        if func not in SUPPORTED_OPS:
            raise ValueError(
                'MergeLayer func must be one of %s, got %r'
                % (sorted(SUPPORTED_OPS), func))
        super(MergeLayer, self).__init__(name=name)
        self.func = func
        self.acti_func = acti_func
        self.initializers = {'w': w_initializer}
        self.regularizers = {'w': w_regularizer}

    def layer_op(self, roots):
        """
        Merge ``roots`` according to ``self.func``.

        :param roots: list of tensors that ``tf.stack`` can combine
            (identical shapes)
        :return: merged tensor
        """
        if self.func == 'MAXOUT':
            return tf.reduce_max(tf.stack(roots, axis=-1), axis=-1)
        if self.func == 'AVERAGE':
            return tf.reduce_mean(tf.stack(roots, axis=-1), axis=-1)
        if self.func == 'WEIGHTED_AVERAGE':
            # New trailing axis indexes the inputs: [..., C, n_roots].
            stacked = tf.stack(roots, axis=-1)
            perm = list(range(stacked.shape.ndims))
            perm[-2], perm[-1] = perm[-1], perm[-2]
            # Swap to [..., n_roots, C] and split into C tensors of shape
            # [..., n_roots], one per channel of the original inputs.
            per_channel = tf.unstack(
                tf.transpose(stacked, perm=perm), axis=-1)
            merged = []
            for channel_tensor in per_channel:
                # One learned combination of the inputs per channel; the
                # constructor's initializer/regularizer are now honoured.
                conv_layer = ConvLayer(
                    n_output_chns=1,
                    kernel_size=1,
                    stride=1,
                    w_initializer=self.initializers['w'],
                    w_regularizer=self.regularizers['w'])
                merged.append(conv_layer(channel_tensor))
            return tf.concat(merged, axis=-1)
        # Only reachable if ``func`` was changed after construction.
        raise ValueError('Unsupported MergeLayer func: %r' % (self.func,))