# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
from collections import namedtuple
import tensorflow as tf
from niftynet.io.misc_io import image3_axial
from niftynet.layer import layer_util
from niftynet.layer.affine_augmentation import AffineAugmentationLayer
from niftynet.layer.base_layer import TrainableLayer
from niftynet.layer.bn import BNLayer
from niftynet.layer.channel_sparse_convolution \
import ChannelSparseConvolutionalLayer
from niftynet.layer.convolution import ConvolutionalLayer
from niftynet.layer.linear_resize import LinearResizeLayer
from niftynet.layer.downsample import DownSampleLayer
from niftynet.network.base_net import BaseNet
# Container for all sub-layers of a DenseVNet network (returned by
# DenseVNet.create_network). The typename matches the module-level name so
# that repr() is accurate and instances can be pickled.
DenseVNetDesc = namedtuple(
    'DenseVNetDesc',
    ['initial_bn', 'initial_conv', 'dense_vblocks', 'seg_layer']
)
class DenseVNet(BaseNet):
    """
    implementation of Dense-V-Net:
        Gibson et al.
        Automatic multi-organ segmentation on abdominal CT with dense V-networks

    ### Diagram

    DFS = Dense Feature Stack Block

    - Initial image is first downsampled to a given size.
    - Each DFS+SD outputs a skip link + a downsampled output.
    - All outputs are upscaled to the initial downsampled size.
    - If initial prior is given add it to the output prediction.

    Input
      |
      --[ DFS ]-----------------------[ Conv ]------------[ Conv ]------[+]-->
           |                                       |  |              |
           -----[ DFS ]---------------[ Conv ]------  |              |
                   |                                  |              |
                   -----[ DFS ]-------[ Conv ]---------              |
                                                          [ Prior ]---

    The layer DenseFeatureStackBlockWithSkipAndDownsample implements
    [DFS + Conv + Downsampling] in a single module, and outputs 2 elements:
        - Skip layer:          [ DFS + Conv]
        - Downsampled output:  [ DFS + Down]
    """

    # Default hyperparameters; entries are overridden by the
    # ``hyperparameters`` constructor argument.
    __hyper_params__ = dict(
        prior_size=12,
        n_dense_channels=(4, 8, 16),
        n_seg_channels=(12, 24, 24),
        n_input_channels=(24, 24, 24),
        dilation_rates=([1] * 5, [1] * 10, [1] * 10),
        final_kernel=3,
        augmentation_scale=0.1
    )

    # Default architectural switches; entries are overridden by the
    # ``architecture_parameters`` constructor argument.
    __net_params__ = dict(
        use_bdo=False,
        use_prior=False,
        use_dense_connections=True,
        use_coords=False
    )

    def __init__(self,
                 num_classes,
                 hyperparameters=None,
                 architecture_parameters=None,
                 w_initializer=None,
                 w_regularizer=None,
                 b_initializer=None,
                 b_regularizer=None,
                 acti_func='relu',
                 name='DenseVNet'):
        """
        :param num_classes: number of output segmentation channels.
        :param hyperparameters: optional dict overriding entries of
            ``__hyper_params__``.
        :param architecture_parameters: optional dict overriding entries of
            ``__net_params__``.
        :param w_initializer, w_regularizer, b_initializer, b_regularizer,
            acti_func, name: forwarded to ``BaseNet``.
        :raises NotImplementedError: for dilation rates other than 1,
            ``use_dense_connections=False`` or ``use_coords=True``.
        """
        super(DenseVNet, self).__init__(
            num_classes=num_classes,
            w_initializer=w_initializer,
            w_regularizer=w_regularizer,
            b_initializer=b_initializer,
            b_regularizer=b_regularizer,
            acti_func=acti_func,
            name=name)

        # Override default Hyperparameters (copy, never mutate the defaults)
        self.hyperparameters = dict(self.__hyper_params__)
        self.hyperparameters.update(hyperparameters or {})

        # Check for dilation rates
        if any(d != 1 for ds in self.hyperparameters['dilation_rates']
               for d in ds):
            raise NotImplementedError(
                'Dilated convolutions are not yet implemented')

        # Override default architectural parameters
        self.architecture_parameters = dict(self.__net_params__)
        self.architecture_parameters.update(architecture_parameters or {})

        # Check available modes
        if self.architecture_parameters['use_dense_connections'] is False:
            raise NotImplementedError(
                'Non-dense connections are not yet implemented')
        if self.architecture_parameters['use_coords'] is True:
            raise NotImplementedError(
                'Image coordinate augmentation is not yet implemented')

    def create_network(self):
        """
        Instantiate (but do not connect) every sub-layer of the network.

        :return: ``DenseVNetDesc`` with the initial batch-norm layer, the
            initial stride-2 convolution, one dense block per entry of
            ``n_dense_channels`` and the final segmentation convolution.
        """
        hyper = self.hyperparameters

        # Initial Convolution
        net_initial_conv = ConvolutionalLayer(
            hyper['n_input_channels'][0],
            kernel_size=5, stride=2
        )

        # Every block but the last feeds a downsampled output to the next
        # one; ``None`` tells the last block to skip its downsampling branch.
        downsample_channels = list(hyper['n_input_channels'][1:]) + [None]
        num_blocks = len(hyper['n_dense_channels'])
        use_bdo = self.architecture_parameters['use_bdo']

        # Create DenseBlocks
        net_dense_vblocks = []
        for idx in range(num_blocks):
            dense_ch = hyper['n_dense_channels'][idx]  # Num dense channels
            seg_ch = hyper['n_seg_channels'][idx]  # Num segmentation ch
            down_ch = downsample_channels[idx]  # Num of downsampling ch
            dil_rate = hyper['dilation_rates'][idx]  # Dilation rate
            # NOTE(review): acti_func is hard-coded to 'relu' here rather
            # than following the constructor's ``acti_func`` argument.
            dblock = DenseFeatureStackBlockWithSkipAndDownsample(
                dense_ch, 3, dil_rate, seg_ch, down_ch, use_bdo,
                acti_func='relu'
            )
            net_dense_vblocks.append(dblock)

        # Segmentation
        net_seg_layer = ConvolutionalLayer(
            self.num_classes, kernel_size=hyper['final_kernel'],
            with_bn=False, with_bias=True
        )

        return DenseVNetDesc(initial_bn=BNLayer(),
                             initial_conv=net_initial_conv,
                             dense_vblocks=net_dense_vblocks,
                             seg_layer=net_seg_layer)

    def layer_op(self,
                 input_tensor,
                 is_training=True,
                 layer_id=-1,
                 keep_prob=0.5,
                 **unused_kwargs):
        """
        Build the DenseVNet graph on ``input_tensor``.

        :param input_tensor: channels-last tensor
            ``(batch, spatial..., channels)``; every spatial size must be
            divisible by ``2 ** len(dilation_rates)``.
        :param is_training: enables affine augmentation (when
            ``augmentation_scale > 0``) and is forwarded to the sub-layers.
        :param layer_id: unused, kept for interface compatibility.
        :param keep_prob: forwarded to every dense block.
        :return: segmentation output with ``num_classes`` channels, resized
            to the spatial size of ``input_tensor``.
        """
        hyper = self.hyperparameters

        # Initialize DenseVNet network layers
        net = self.create_network()

        # Shape and dimension shortcuts (channels-last layout)
        channel_dim = len(input_tensor.shape) - 1
        spatial_size = input_tensor.shape.as_list()[1:-1]
        n_spatial_dims = input_tensor.shape.ndims - 2

        # The input is halved once up front and then once per dense block
        # except the last, i.e. len(dilation_rates) times in total (assuming
        # one dilation-rate entry per block).
        modulo = 2 ** (len(hyper['dilation_rates']))
        assert layer_util.check_spatial_dims(input_tensor,
                                             lambda x: x % modulo == 0)

        # The segmentation is mapped back to the un-augmented space below,
        # so the image summary must be drawn from the un-augmented input.
        original_input = input_tensor

        # On the fly data augmentation
        augment_layer = None
        if is_training and hyper['augmentation_scale'] > 0:
            augment_layer = AffineAugmentationLayer(
                hyper['augmentation_scale'], 'LINEAR', 'ZERO')
            input_tensor = augment_layer(input_tensor)

        # Variable storing all intermediate results -- VLinks
        all_segmentation_features = []

        # Downsample input to the network
        ave_downsample_layer = DownSampleLayer(
            func='AVG', kernel_size=3, stride=2)
        down_tensor = ave_downsample_layer(input_tensor)
        downsampled_img = net.initial_bn(down_tensor, is_training=is_training)

        # Add initial downsampled image VLink
        all_segmentation_features.append(downsampled_img)

        # All results should match the downsampled input's shape
        output_shape = downsampled_img.shape.as_list()[1:-1]

        init_features = net.initial_conv(input_tensor, is_training=is_training)

        # `down` is the input of each dense block; start from the stacked
        # downsampled image and initial conv features.
        down = tf.concat([downsampled_img, init_features], channel_dim)

        for dblock in net.dense_vblocks:
            skip, down = dblock(down,
                                is_training=is_training,
                                keep_prob=keep_prob)
            # Resize skip layer to the common shape and add VLink
            skip = LinearResizeLayer(output_shape)(skip)
            all_segmentation_features.append(skip)

        # Concatenate all intermediate skip layers
        inter_results = tf.concat(all_segmentation_features, channel_dim)

        # Initial segmentation output
        seg_output = net.seg_layer(inter_results, is_training=is_training)

        # Refine segmentation with prior if any
        if self.architecture_parameters['use_prior']:
            xyz_prior = SpatialPriorBlock(
                [hyper['prior_size']] * n_spatial_dims, output_shape)
            # SpatialPriorBlock is a layer: call it to obtain the tensor.
            # NOTE(review): the prior variable has no batch/channel axes;
            # confirm it resizes and broadcasts against seg_output.
            seg_output += xyz_prior()

        # Invert augmentation if any
        if augment_layer is not None:
            seg_output = augment_layer.inverse()(seg_output)

        # Resize output to original size
        seg_output = LinearResizeLayer(spatial_size)(seg_output)

        self._add_image_summary(original_input, seg_output, n_spatial_dims)

        return seg_output

    def _add_image_summary(self, image, seg_output, n_spatial_dims):
        """
        Register a side-by-side summary of ``image`` and the argmax of
        ``seg_output`` (concatenated along the first spatial axis).

        :raises NotImplementedError: if ``n_spatial_dims`` is not 2 or 3.
        """
        # Label map scaled so labels 0..num_classes-1 span 0..255; the max()
        # guards against a division by zero with a single class.
        seg_argmax = tf.to_float(tf.expand_dims(tf.argmax(seg_output, -1), -1))
        seg_summary = seg_argmax * (255. / max(self.num_classes - 1, 1))

        # Standardise per sample/channel over the spatial axes, map roughly
        # [-2 std, 2 std] to [0, 254], average channels and clip to [0, 255].
        norm_axes = list(range(1, n_spatial_dims + 1))
        mean, var = tf.nn.moments(image, axes=norm_axes, keep_dims=True)
        timg = tf.to_float(image - mean) / (tf.sqrt(var) * 2.)
        timg = (timg + 1.) * 127.
        single_channel = tf.reduce_mean(timg, -1, True)
        img_summary = tf.minimum(255., tf.maximum(0., single_channel))

        if n_spatial_dims == 2:
            summary_fn = tf.summary.image
        elif n_spatial_dims == 3:
            summary_fn = image3_axial
        else:
            raise NotImplementedError(
                'Image Summary only supports 2D and 3D images')
        summary_fn(tf.get_default_graph().unique_name('imgseg'),
                   tf.concat([img_summary, seg_summary], 1),
                   5, [tf.GraphKeys.SUMMARIES])
class SpatialPriorBlock(TrainableLayer):
    """
    Learnable spatial prior stored at ``prior_shape`` and linearly resized
    to ``output_shape``.

    The variable is initialised to ones and kept as probabilities so that
    linear resampling is meaningful; the log is taken after resizing.
    """

    def __init__(self, prior_shape, output_shape,
                 name='spatial_prior_block'):
        super(SpatialPriorBlock, self).__init__(name=name)
        # Shape of the learnable 'prior' variable.
        self.prior_shape = prior_shape
        # Target size handed to LinearResizeLayer.
        self.output_shape = output_shape

    def layer_op(self):
        """Return ``log(resize(prior))`` at ``self.output_shape``."""
        # NOTE(review): the variable is created with exactly prior_shape
        # (no batch/channel axes) - confirm LinearResizeLayer accepts it.
        ones_init = tf.constant_initializer(1)
        prior = tf.get_variable('prior',
                                shape=self.prior_shape,
                                initializer=ones_init)
        resizer = LinearResizeLayer(self.output_shape)
        return tf.log(resizer(prior))
# Sub-layers of a DenseFeatureStackBlock; typename matches the module-level
# name so repr() is accurate and instances can be pickled.
DenseFSBlockDesc = namedtuple('DenseFSBlockDesc', ['conv_layers'])
class DenseFeatureStackBlock(TrainableLayer):
    """
    Dense Feature Stack Block

    - Stack is initialized with the input from above layers.
    - Iteratively the output of convolution layers is added to the stack.
    - Each sequential convolution is performed over all the previous stacked
      channels.

    Diagram example:

        stack = [Input]
        stack = [stack, conv(stack)]
        stack = [stack, conv(stack)]
        stack = [stack, conv(stack)]
        ...
        Output = [stack, conv(stack)]
    """

    def __init__(self,
                 n_dense_channels,
                 kernel_size,
                 dilation_rates,
                 use_bdo,
                 name='dense_feature_stack_block',
                 **kwargs):
        super(DenseFeatureStackBlock, self).__init__(name=name)
        self.n_dense_channels = n_dense_channels
        self.kernel_size = kernel_size
        # One convolution is created per entry; the rate values themselves
        # are not passed on to the convolutions.
        self.dilation_rates = dilation_rates
        # Use channel-sparse convolutions (returning an output mask).
        self.use_bdo = use_bdo
        # Extra keyword arguments forwarded to every convolution layer.
        self.kwargs = kwargs

    def create_block(self):
        """Create one convolution layer per entry of ``dilation_rates``."""
        if self.use_bdo:
            layer_class = ChannelSparseConvolutionalLayer
        else:
            layer_class = ConvolutionalLayer
        net_conv_layers = [
            layer_class(self.n_dense_channels,
                        kernel_size=self.kernel_size,
                        **self.kwargs)
            for _ in self.dilation_rates
        ]
        return DenseFSBlockDesc(conv_layers=net_conv_layers)

    def layer_op(self, input_tensor, is_training=True, keep_prob=None):
        """
        Run the dense stack on a channels-last ``input_tensor``.

        :return: list of tensors to be concatenated along the channel axis.
            Without BDO: ``[input, conv_1, ..., conv_k]``. With BDO: a single
            tensor holding only the conv channels, with zeros in channels
            that were not computed.
        """
        # Initialize FeatureStackBlocks
        block = self.create_block()
        stack = [input_tensor]
        channel_dim = len(input_tensor.shape) - 1
        n_channels = input_tensor.shape.as_list()[-1]
        # All input channels are marked as present.
        input_mask = tf.ones([n_channels]) > 0

        # Stack all convolution outputs
        for idx, conv_layer in enumerate(block.conv_layers):
            if idx == len(self.dilation_rates) - 1:
                keep_prob = None  # no dropout on last layer of the stack
            features = tf.concat(stack, channel_dim)
            if self.use_bdo:
                conv_out, new_input_mask = conv_layer(
                    features,
                    input_mask=input_mask,
                    is_training=is_training,
                    keep_prob=keep_prob)
                input_mask = tf.concat([input_mask, new_input_mask], 0)
            else:
                conv_out = conv_layer(features,
                                      is_training=is_training,
                                      keep_prob=keep_prob)
            stack.append(conv_out)

        if self.use_bdo:  # unmask the conv channels
            # modify the returning stack by:
            # 1. Removing the input of the DFS from the stack
            # 2. Unmasking the stack by filling in zeros
            # see: https://github.com/NifTK/NiftyNet/pull/101
            conv_channels = tf.concat(stack[1:], axis=-1)
            # A single all-zero channel placed at index 0, used wherever a
            # channel was not calculated. zeros_like keeps the dtype and
            # works when the batch size is only known at run time.
            zero_channel = tf.zeros_like(conv_channels[..., :1])
            conv_channels = tf.concat([zero_channel, conv_channels], axis=-1)
            # Computed channels map to their 1-based position among computed
            # channels; skipped channels map to 0 (the zero channel).
            int_mask = tf.cast(input_mask[n_channels:], tf.int32)
            indices = tf.cumsum(int_mask) * int_mask
            # rearrange stack with zeros where channels were not calculated
            conv_channels = tf.gather(conv_channels, indices, axis=-1)
            stack = [conv_channels]
        return stack
# Sub-layers of a DenseFeatureStackBlockWithSkipAndDownsample; typename
# matches the module-level name so repr() is accurate and pickling works.
DenseSDBlockDesc = namedtuple('DenseSDBlockDesc',
                              ['dense_fstack', 'conv', 'down'])
class DenseFeatureStackBlockWithSkipAndDownsample(TrainableLayer):
    """
    Dense Feature Stack with Skip Layer and Downsampling

    - Downsampling is done through strided convolution.

        ---[ DenseFeatureStackBlock ]----------[ Conv ]------- Skip layer
                                     |
                                     -------------------- Downsampled Output

    See DenseFeatureStackBlock for more info.
    """

    def __init__(self,
                 n_dense_channels,
                 kernel_size,
                 dilation_rates,
                 n_seg_channels,
                 n_downsample_channels,
                 use_bdo,
                 name='dense_feature_stack_block',
                 **kwargs):
        # NOTE(review): the default name is shared with
        # DenseFeatureStackBlock; kept unchanged to preserve variable names.
        super(DenseFeatureStackBlockWithSkipAndDownsample,
              self).__init__(name=name)
        self.n_dense_channels = n_dense_channels
        self.kernel_size = kernel_size
        self.dilation_rates = dilation_rates
        # Output channels of the skip-layer convolution.
        self.n_seg_channels = n_seg_channels
        # Output channels of the stride-2 convolution; None disables it.
        self.n_downsample_channels = n_downsample_channels
        self.use_bdo = use_bdo
        # Extra keyword arguments forwarded to every sub-layer.
        self.kwargs = kwargs

    def create_block(self):
        """Create the dense stack, skip conv and optional stride-2 conv."""
        extra = self.kwargs
        ksize = self.kernel_size
        dense_stack = DenseFeatureStackBlock(
            self.n_dense_channels, ksize, self.dilation_rates,
            self.use_bdo, **extra
        )
        skip_conv = ConvolutionalLayer(
            self.n_seg_channels, kernel_size=ksize, **extra
        )
        if self.n_downsample_channels is None:
            down_conv = None
        else:
            down_conv = ConvolutionalLayer(self.n_downsample_channels,
                                           kernel_size=ksize,
                                           stride=2, **extra)
        return DenseSDBlockDesc(dense_fstack=dense_stack,
                                conv=skip_conv, down=down_conv)

    def layer_op(self, input_tensor, is_training=True, keep_prob=None):
        """
        :return: ``(skip, down)`` - the skip-layer output and the stride-2
            downsampled output, or ``None`` for ``down`` when
            ``n_downsample_channels`` is None.
        """
        block = self.create_block()
        channel_axis = len(input_tensor.shape) - 1

        # Dense feature stack, flattened along the channel axis
        stacked = block.dense_fstack(input_tensor,
                                     is_training=is_training,
                                     keep_prob=keep_prob)
        features = tf.concat(stacked, channel_axis)

        # Skip-layer convolution
        skip = block.conv(features,
                          is_training=is_training,
                          keep_prob=keep_prob)

        if block.down is None:
            return skip, None

        # Strided downsampling branch
        down = block.down(features,
                          is_training=is_training,
                          keep_prob=keep_prob)
        return skip, down