# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
from collections import namedtuple
import abc
import tensorflow as tf
from niftynet.layer import layer_util
from niftynet.layer.base_layer import TrainableLayer
from niftynet.layer.convolution import ConvolutionalLayer
from niftynet.layer.channel_sparse_convolution \
import ChannelSparseConvolutionalLayer
from niftynet.layer.bn import BNLayer
from niftynet.layer.spatial_transformer import ResamplerLayer
from niftynet.layer.grid_warper import AffineGridWarperLayer
from niftynet.network.base_net import BaseNet
from niftynet.io.misc_io import image3_axial
# Descriptor bundling the sub-layers that make up one DenseVNet network.
DenseVNetDesc = namedtuple(
    'DenseVNetParts',
    'initial_bn initial_conv dense_vblocks seg_layer')
class DenseVNet(BaseNet):
    """
    implementation of Dense-V-Net:
      Gibson et al.
      Automatic multi-organ segmentation on abdominal CT with dense V-networks

    ### Diagram

    DFS = Dense Feature Stack Block

    - Initial image is first downsampled to a given size.
    - Each DFS+SD outputs a skip link + a downsampled output.
    - All outputs are upscaled to the initial downsampled size.
    - If initial prior is given add it to the output prediction.

    Input
      |
      --[ DFS ]-----------------------[ Conv ]------------[ Conv ]------[+]-->
           |                             |                   |           |
           -----[ DFS ]---------------[ Conv ]------         |           |
                  |                            |             |           |
                  -----[ DFS ]-------[ Conv ]---             |
                                                    [ Prior ]------------

    The layer DenseFeatureStackBlockWithSkipAndDownsample layer implements
    [DFS + Conv + Downsampling] in a single module, and outputs 2 elements:

    - Skip layer:          [ DFS + Conv]
    - Downsampled output:  [ DFS + Down]
    """

    # Default hyperparameters; individual entries can be overridden through
    # the ``hyperparameters`` argument of __init__.
    __hyper_params__ = dict(
        prior_size=12,
        p_channels_selected=0.5,
        n_dense_channels=(4, 8, 16),
        n_seg_channels=(12, 24, 24),
        n_input_channels=(24, 24, 24),
        dilation_rates=([1] * 5, [1] * 10, [1] * 10),
        final_kernel=3,
        augmentation_scale=0.1
    )

    # Default architectural switches; overridable through
    # ``architecture_parameters`` (only the defaults are fully supported).
    __net_params__ = dict(
        use_bdo=False,
        use_prior=False,
        use_dense_connections=True,
        use_coords=False
    )

    def __init__(self,
                 num_classes,
                 hyperparameters=None,
                 architecture_parameters=None,
                 w_initializer=None,
                 w_regularizer=None,
                 b_initializer=None,
                 b_regularizer=None,
                 acti_func='relu',
                 name='DenseVNet'):
        """
        :param num_classes: number of segmentation classes
        :param hyperparameters: dict overriding entries of __hyper_params__
        :param architecture_parameters: dict overriding entries of
            __net_params__
        :param w_initializer: weight initializer, forwarded to BaseNet
        :param w_regularizer: weight regularizer, forwarded to BaseNet
        :param b_initializer: bias initializer, forwarded to BaseNet
        :param b_regularizer: bias regularizer, forwarded to BaseNet
        :param acti_func: activation function name
        :param name: scope name of the network
        :raises NotImplementedError: for dilation rates != 1, non-dense
            connections, or coordinate augmentation
        """
        super(DenseVNet, self).__init__(
            num_classes=num_classes,
            w_initializer=w_initializer,
            w_regularizer=w_regularizer,
            b_initializer=b_initializer,
            b_regularizer=b_regularizer,
            acti_func=acti_func,
            name=name)

        # fix: the original defaults were shared mutable dicts ({});
        # None sentinels avoid aliasing one dict across all instances
        hyperparameters = {} if hyperparameters is None else hyperparameters
        architecture_parameters = ({} if architecture_parameters is None
                                   else architecture_parameters)

        # Override default Hyperparameters
        self.hyperparameters = dict(self.__hyper_params__)
        self.hyperparameters.update(hyperparameters)

        # Check for dilation rates
        if any([d != 1 for ds in self.hyperparameters['dilation_rates']
                for d in ds]):
            raise NotImplementedError(
                'Dilated convolutions are not yet implemented')

        # Override default architectural parameters
        self.architecture_parameters = dict(self.__net_params__)
        self.architecture_parameters.update(architecture_parameters)

        # Check available modes
        if self.architecture_parameters['use_dense_connections'] is False:
            raise NotImplementedError(
                'Non-dense connections are not yet implemented')
        if self.architecture_parameters['use_coords'] is True:
            raise NotImplementedError(
                'Image coordinate augmentation is not yet implemented')

    def create_network(self):
        """Instantiate all sub-layers and return them as a DenseVNetDesc."""
        hyper = self.hyperparameters

        # Initial Convolution (stride-2: halves the spatial resolution)
        net_initial_conv = ConvolutionalLayer(
            hyper['n_input_channels'][0],
            kernel_size=5, stride=2
        )

        # Dense Block Params; the last block has no downsampling stage
        downsample_channels = list(hyper['n_input_channels'][1:]) + [None]
        num_blocks = len(hyper["n_dense_channels"])
        use_bdo = self.architecture_parameters['use_bdo']

        # Create DenseBlocks
        net_dense_vblocks = []
        for idx in range(num_blocks):
            dense_ch = hyper["n_dense_channels"][idx]  # Num dense channels
            seg_ch = hyper["n_seg_channels"][idx]      # Num segmentation ch
            down_ch = downsample_channels[idx]         # Num of downsampling ch
            dil_rate = hyper["dilation_rates"][idx]    # Dilation rate

            # Dense feature block
            dblock = DenseFeatureStackBlockWithSkipAndDownsample(
                dense_ch, 3, dil_rate, seg_ch, down_ch, use_bdo,
                acti_func='relu'
            )
            net_dense_vblocks.append(dblock)

        # Segmentation head (plain conv with bias, no batch norm)
        net_seg_layer = ConvolutionalLayer(
            self.num_classes, kernel_size=hyper['final_kernel'],
            with_bn=False, with_bias=True
        )

        return DenseVNetDesc(initial_bn=BNLayer(),
                             initial_conv=net_initial_conv,
                             dense_vblocks=net_dense_vblocks,
                             seg_layer=net_seg_layer)

    def layer_op(self, input_tensor, is_training, layer_id=-1):
        """
        :param input_tensor: image batch, shape (batch, spatial..., channels);
            each spatial dim must be divisible by 2**len(dilation_rates)
        :param is_training: True when building the training graph (enables
            on-the-fly affine augmentation and channel dropout)
        :param layer_id: unused, kept for API compatibility
        :return: segmentation logits at the input spatial resolution
        """
        hyper = self.hyperparameters

        # Initialize DenseVNet network layers
        net = self.create_network()

        #
        # Parameter handling
        #

        # Shape and dimension variable shortcuts
        channel_dim = len(input_tensor.shape) - 1
        input_size = input_tensor.shape.as_list()
        spatial_size = input_size[1:-1]
        n_spatial_dims = input_tensor.shape.ndims - 2

        # Quick access to hyperparams
        pkeep = hyper['p_channels_selected']

        # Validate input dimension with dilation rates
        modulo = 2 ** (len(hyper['dilation_rates']))
        assert layer_util.check_spatial_dims(input_tensor,
                                             lambda x: x % modulo == 0)

        #
        # Augmentation + Downsampling + Initial Layers
        #

        # On the fly data augmentation
        if is_training and hyper['augmentation_scale'] > 0:
            if n_spatial_dims == 2:
                augmentation_class = Affine2DAugmentationLayer
            elif n_spatial_dims == 3:
                augmentation_class = Affine3DAugmentationLayer
            else:
                raise NotImplementedError(
                    'Affine augmentation only supports 2D and 3D images')
            augment_layer = augmentation_class(hyper['augmentation_scale'],
                                               'LINEAR', 'ZERO')
            input_tensor = augment_layer(input_tensor)

        # Variable storing all intermediate results -- VLinks
        all_segmentation_features = []

        # Downsample input to the network
        # NOTE(review): downsample_input is defined outside this excerpt
        down_tensor = self.downsample_input(input_tensor, n_spatial_dims)
        downsampled_img = net.initial_bn(down_tensor, is_training=is_training)

        # Add initial downsampled image VLink
        all_segmentation_features.append(downsampled_img)

        # All results should match the downsampled input's shape
        output_shape = downsampled_img.shape.as_list()[1:-1]

        init_features = net.initial_conv(input_tensor, is_training=is_training)

        #
        # Dense VNet Main Block
        #

        # `down` will handle the input of each Dense VNet block
        # Initialize it by stacking downsampled image and initial conv features
        down = tf.concat([downsampled_img, init_features], channel_dim)

        # Process Dense VNet Blocks
        for dblock in net.dense_vblocks:
            # Get skip layer and activation output
            skip, down = dblock(down, is_training=is_training, keep_prob=pkeep)

            # Resize skip layer to original shape and add VLink
            skip = image_resize(skip, output_shape)
            all_segmentation_features.append(skip)

        # Concatenate all intermediate skip layers
        inter_results = tf.concat(all_segmentation_features, channel_dim)

        # Initial segmentation output
        seg_output = net.seg_layer(inter_results, is_training=is_training)

        #
        # Dense VNet End - Now postprocess outputs
        #

        # Refine segmentation with prior if any
        if self.architecture_parameters['use_prior']:
            # fix: the prior block was added without being called (a layer
            # object cannot be added to a tensor) and the prior size was
            # hard-coded to 12 instead of using hyper['prior_size']
            xyz_prior = SpatialPriorBlock(
                [hyper['prior_size']] * n_spatial_dims, output_shape)
            seg_output += xyz_prior()

        # Invert augmentation if any
        if is_training and hyper['augmentation_scale'] > 0:
            inverse_aug = augment_layer.inverse()
            seg_output = inverse_aug(seg_output)

        # Resize output to original size
        seg_output = image_resize(seg_output, spatial_size)

        # Segmentation results
        seg_argmax = tf.to_float(tf.expand_dims(tf.argmax(seg_output, -1), -1))
        # fix: precedence bug (was 255. / num_classes - 1); labels 0..C-1 are
        # mapped onto the 0..255 grayscale range for the summary image
        seg_summary = seg_argmax * (255. / max(self.num_classes - 1, 1))

        # Image Summary: normalise the input to roughly 0..255 grayscale
        norm_axes = list(range(1, n_spatial_dims + 1))
        mean, var = tf.nn.moments(input_tensor, axes=norm_axes, keep_dims=True)
        timg = tf.to_float(input_tensor - mean) / (tf.sqrt(var) * 2.)
        timg = (timg + 1.) * 127.
        single_channel = tf.reduce_mean(timg, axis=-1, keep_dims=True)
        img_summary = tf.minimum(255., tf.maximum(0., single_channel))
        if n_spatial_dims == 2:
            tf.summary.image(
                tf.get_default_graph().unique_name('imgseg'),
                tf.concat([img_summary, seg_summary], 1),
                5, [tf.GraphKeys.SUMMARIES])
        elif n_spatial_dims == 3:
            # Show summaries
            image3_axial(
                tf.get_default_graph().unique_name('imgseg'),
                tf.concat([img_summary, seg_summary], 1),
                5, [tf.GraphKeys.SUMMARIES])
        else:
            raise NotImplementedError(
                'Image Summary only supports 2D and 3D images')

        return seg_output
def image_resize(image, output_size):
    """Resize a batch-first 2D or 3D image tensor to ``output_size``.

    A 2D batch goes straight through ``tf.image.resize_images``; a 3D batch
    is handled by two successive 2D resizes (in-plane first, then along the
    remaining axis). Returns the input unchanged when the spatial shape
    already matches.
    """
    shape_in = image.shape.as_list()
    spatial_rank = len(shape_in) - 2

    # Fast path: nothing to do when shapes already agree
    if all(target == current
           for target, current in zip(output_size, shape_in[1:-1])):
        return image

    if spatial_rank == 2:
        return tf.image.resize_images(image, output_size)

    # 3D: resize the first two spatial dims in-plane...
    merged_xy = tf.reshape(image, shape_in[0:3] + [-1])
    resized_xy = tf.image.resize_images(merged_xy, output_size[0:2])

    # ...then fold x*y into one axis and resize along the third dim
    folded_shape = shape_in[:1] + [output_size[0] * output_size[1],
                                   shape_in[3], -1]
    folded = tf.reshape(resized_xy, folded_shape)
    resized_z = tf.image.resize_images(folded,
                                       [folded_shape[1], output_size[2]])

    restored_shape = shape_in[:1] + output_size + shape_in[-1:]
    return tf.reshape(resized_z, restored_shape)
class SpatialPriorBlock(TrainableLayer):
    """Trainable low-resolution spatial prior, resampled to a target shape."""

    def __init__(self,
                 prior_shape,
                 output_shape,
                 name='spatial_prior_block'):
        """
        :param prior_shape: shape of the trainable prior map
        :param output_shape: spatial shape the prior is resized to
        """
        super(SpatialPriorBlock, self).__init__(name=name)
        self.prior_shape = prior_shape
        self.output_shape = output_shape

    def layer_op(self):
        """Return the log of the resampled prior map."""
        # The internal representation is probabilities so that
        # resampling makes sense
        prob_prior = tf.get_variable('prior',
                                     shape=self.prior_shape,
                                     initializer=tf.constant_initializer(1))
        resized_prior = image_resize(prob_prior, self.output_shape)
        return tf.log(resized_prior)
# Descriptor holding the convolution layers of one dense feature stack.
DenseFSBlockDesc = namedtuple('DenseFSDesc', ('conv_layers',))
class DenseFeatureStackBlock(TrainableLayer):
    """
    Dense Feature Stack Block

    - Stack is initialized with the input from above layers.
    - Iteratively the output of convolution layers is added to the stack.
    - Each sequential convolution is performed over all the previous stacked
      channels.

    Diagram example:

        stack = [Input]
        stack = [stack, conv(stack)]
        stack = [stack, conv(stack)]
        stack = [stack, conv(stack)]
        ...
        Output = [stack, conv(stack)]
    """

    def __init__(self,
                 n_dense_channels,
                 kernel_size,
                 dilation_rates,
                 use_bdo,
                 name='dense_feature_stack_block',
                 **kwargs):
        """
        :param n_dense_channels: output channels of every conv in the stack
        :param kernel_size: kernel size of every conv in the stack
        :param dilation_rates: one entry per conv layer; its length sets the
            depth of the stack (only rate 1 is supported upstream)
        :param use_bdo: if True use batch-wise dropout through
            channel-sparse convolutions
        :param name: scope name of the block
        :param kwargs: forwarded to every convolution layer
        """
        super(DenseFeatureStackBlock, self).__init__(name=name)
        self.n_dense_channels = n_dense_channels
        self.kernel_size = kernel_size
        self.dilation_rates = dilation_rates
        self.use_bdo = use_bdo
        self.kwargs = kwargs

    def create_block(self):
        """Instantiate one conv layer per dilation-rate entry."""
        net_conv_layers = []
        for _ in self.dilation_rates:
            if self.use_bdo:
                conv = ChannelSparseConvolutionalLayer(
                    self.n_dense_channels, kernel_size=self.kernel_size,
                    **self.kwargs
                )
            else:
                conv = ConvolutionalLayer(self.n_dense_channels,
                                          kernel_size=self.kernel_size,
                                          **self.kwargs)
            net_conv_layers.append(conv)
        return DenseFSBlockDesc(conv_layers=net_conv_layers)

    def layer_op(self, input_tensor, is_training=None, keep_prob=None):
        """
        :param input_tensor: input feature map, channels-last
        :param is_training: forwarded to each conv layer
        :param keep_prob: channel keep probability (dropout); disabled on the
            last conv of the stack
        :return: list of tensors: the input followed by each conv output
        """
        # Initialize FeatureStackBlocks
        block = self.create_block()
        stack = [input_tensor]
        channel_dim = len(input_tensor.shape) - 1
        # Channel mask grown alongside the stack (channel-sparse convs only)
        input_mask = tf.ones([input_tensor.shape.as_list()[-1]]) > 0

        # Stack all convolution outputs
        for idx, conv_layer in enumerate(block.conv_layers):
            if idx == len(self.dilation_rates) - 1:
                keep_prob = None  # no dropout on last layer of the stack
            if self.use_bdo:
                # fix: do not rebind the loop variable to its own output;
                # use a separate name for the produced feature map
                features, new_input_mask = conv_layer(
                    tf.concat(stack, channel_dim),
                    input_mask=input_mask,
                    is_training=is_training,
                    keep_prob=keep_prob)
                input_mask = tf.concat([input_mask, new_input_mask], 0)
            else:
                features = conv_layer(tf.concat(stack, channel_dim),
                                      is_training=is_training,
                                      keep_prob=keep_prob)
            stack.append(features)
        return stack
# Descriptor for a dense stack plus its skip conv and optional downsampler.
DenseSDBlockDesc = namedtuple('DenseSDBlock', ('dense_fstack', 'conv', 'down'))
class DenseFeatureStackBlockWithSkipAndDownsample(TrainableLayer):
    """
    Dense Feature Stack with Skip Layer and Downsampling

    - Downsampling is done through strided convolution.

    ---[ DenseFeatureStackBlock ]----------[ Conv ]------- Skip layer
                            |
                            -------------------- Downsampled Output

    See DenseFeatureStackBlock for more info.
    """

    def __init__(self,
                 n_dense_channels,
                 kernel_size,
                 dilation_rates,
                 n_seg_channels,
                 n_downsample_channels,
                 use_bdo,
                 name='dense_feature_stack_block',
                 **kwargs):
        """
        :param n_dense_channels: channels of each conv in the dense stack
        :param kernel_size: kernel size used throughout the block
        :param dilation_rates: one entry per conv in the dense stack
        :param n_seg_channels: channels of the skip (segmentation) conv
        :param n_downsample_channels: channels of the stride-2 downsampling
            conv, or None to skip downsampling
        :param use_bdo: if True use batch-wise dropout in the dense stack
        :param name: scope name of the block
        :param kwargs: forwarded to every convolution layer
        """
        super(DenseFeatureStackBlockWithSkipAndDownsample,
              self).__init__(name=name)
        self.n_dense_channels = n_dense_channels
        self.kernel_size = kernel_size
        self.dilation_rates = dilation_rates
        self.n_seg_channels = n_seg_channels
        self.n_downsample_channels = n_downsample_channels
        self.use_bdo = use_bdo
        self.kwargs = kwargs

    def create_block(self):
        """Instantiate the dense stack, skip conv and optional downsampler."""
        stack_layer = DenseFeatureStackBlock(
            self.n_dense_channels, self.kernel_size, self.dilation_rates,
            self.use_bdo, **self.kwargs
        )
        skip_conv = ConvolutionalLayer(
            self.n_seg_channels, kernel_size=self.kernel_size, **self.kwargs
        )
        if self.n_downsample_channels is None:
            down_layer = None
        else:
            down_layer = ConvolutionalLayer(self.n_downsample_channels,
                                            kernel_size=self.kernel_size,
                                            stride=2, **self.kwargs)
        return DenseSDBlockDesc(dense_fstack=stack_layer,
                                conv=skip_conv, down=down_layer)

    def layer_op(self, input_tensor, is_training=None, keep_prob=None):
        """
        :return: tuple (skip, down) — the skip-link features and the
            downsampled output (None when no downsampler is configured)
        """
        block = self.create_block()

        # Dense Feature Stack, concatenated along the channel axis
        feature_stack = block.dense_fstack(input_tensor,
                                           is_training=is_training,
                                           keep_prob=keep_prob)
        stacked = tf.concat(feature_stack, len(input_tensor.shape) - 1)

        # Skip-link convolution over the full stack
        skip = block.conv(stacked, is_training=is_training,
                          keep_prob=keep_prob)

        # Strided downsampling, only when configured
        downsampled = None
        if block.down is not None:
            downsampled = block.down(stacked, is_training=is_training,
                                     keep_prob=keep_prob)
        return skip, downsampled
class AffineAugmentationLayer(TrainableLayer):
    """ This layer applies a small random (per-iteration) affine
    transformation to an image. The distribution of transformations
    generally results in scaling the image up, with minimal sampling
    outside the original image."""
    __metaclass__ = abc.ABCMeta

    def __init__(self, scale, interpolation,
                 boundary, transform_func=None,
                 name='AffineAugmentation'):
        """
        scale denotes how extreme the perturbation is, with 1. meaning
        no perturbation and 0.5 giving larger perturbations.
        interpolation denotes the image value interpolation used by
        the resampling
        boundary denotes the boundary handling used by the resampling
        transform_func should be a function returning a relative
        transformation (mapping <-1..1,-1..1,-1..1> to <-1..1,-1..1,-1..1>
        or <-1..1,-1..1> to <-1..1,-1..1>)
        """
        # fix: the original docstring opened with a stray fourth quote ("""")
        super(AffineAugmentationLayer, self).__init__(name=name)
        self.scale = scale
        if transform_func is None:
            # NOTE(review): random_transform is defined outside this excerpt
            self.transform_func = self.random_transform
        else:
            self.transform_func = transform_func
        self._transform = None
        self.interpolation = interpolation
        self.boundary = boundary

    def layer_op(self, input_tensor):
        """Apply the batch of affine transforms to ``input_tensor``."""
        size = input_tensor.shape.as_list()
        grid_warper = AffineGridWarperLayer(size[1:-1],
                                            size[1:-1])
        resampler = ResamplerLayer(interpolation=self.interpolation,
                                   boundary=self.boundary)
        relative_transform = self.transform_func(size[0])
        # Conjugate the relative-space transform so it acts on voxel coords:
        # voxel -> relative -> (transform) -> relative -> voxel
        to_relative = tf.tile(self.get_tfm_to_relative(size), [size[0], 1, 1])
        from_relative = tf.matrix_inverse(to_relative)
        voxel_transform = tf.matmul(from_relative,
                                    tf.matmul(relative_transform, to_relative))
        dims = self.spatial_dims
        # Affine warp parameters: top dims x (dims + 1) sub-matrix, flattened
        warp_parameters = tf.reshape(voxel_transform[:, 0:dims, 0:dims + 1],
                                     [size[0], dims * (dims + 1)])
        grid = grid_warper(warp_parameters)
        return resampler(input_tensor, grid)

    def inverse(self, interpolation=None, boundary=None):
        """Return a layer applying the inverse of this layer's transform.

        Interpolation/boundary default to this layer's settings.
        """
        if interpolation is None:
            interpolation = self.interpolation
        if boundary is None:
            boundary = self.boundary
        # NOTE(review): inverse_transform is defined outside this excerpt
        return self.__class__(self.scale,
                              interpolation,
                              boundary,
                              self.inverse_transform)

    @abc.abstractproperty
    def spatial_dims(self):
        """Number of spatial dimensions (2 or 3) handled by the subclass."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_corners(self):
        """Corner coordinates of the relative <-1..1> coordinate cube."""
        raise NotImplementedError

    @abc.abstractmethod
    def get_tfm_to_relative(self, size):
        """Matrix mapping voxel coordinates to relative <-1..1> coordinates.

        fix: the abstract signature now takes ``size``, matching both the
        subclass implementations and the call in layer_op.
        """
        raise NotImplementedError
class Affine2DAugmentationLayer(AffineAugmentationLayer):
    """ Specialization of AffineAugmentationLayer for 2D coordinates """
    spatial_dims = 2

    def get_corners(self):
        """Corners of the 2D relative coordinate square (1 x 4 x 2)."""
        return [[[x, y] for x in (-1., 1.) for y in (-1., 1.)]]

    def get_tfm_to_relative(self, size):
        """Homogeneous matrix mapping voxel indices to relative <-1..1>."""
        scale_x = 2. / (size[1] - 1)
        scale_y = 2. / (size[2] - 1)
        return [[[scale_x, 0., -1.],
                 [0., scale_y, -1.],
                 [0., 0., 1.]]]
class Affine3DAugmentationLayer(AffineAugmentationLayer):
    """ Specialization of AffineAugmentationLayer for 3D coordinates """
    spatial_dims = 3

    def get_corners(self):
        """Corners of the 3D relative coordinate cube (1 x 8 x 3)."""
        return [[[x, y, z]
                 for x in (-1., 1.)
                 for y in (-1., 1.)
                 for z in (-1., 1.)]]

    def get_tfm_to_relative(self, size):
        """Homogeneous matrix mapping voxel indices to relative <-1..1>."""
        scale_x = 2. / (size[1] - 1)
        scale_y = 2. / (size[2] - 1)
        scale_z = 2. / (size[3] - 1)
        return [[[scale_x, 0., 0., -1.],
                 [0., scale_y, 0., -1.],
                 [0., 0., scale_z, -1.],
                 [0., 0., 0., 1.]]]