# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
from niftynet.layer import layer_util
from niftynet.layer.base_layer import TrainableLayer
from niftynet.layer.convolution import ConvolutionalLayer
from niftynet.layer.deconvolution import DeconvolutionalLayer
from niftynet.layer.downsample import DownSampleLayer
from niftynet.layer.elementwise import ElementwiseLayer
from niftynet.layer.crop import CropLayer
from niftynet.utilities.util_common import look_up_operations
class UNet3D(TrainableLayer):
    """
    ### Description
        reimplementation of 3D U-net
          Çiçek et al., "3D U-Net: Learning dense Volumetric segmentation from
          sparse annotation", MICCAI '16

    ### Building blocks
    [dBLOCK]    - Downsampling UNet Block
    [uBLOCK]    - Upsampling UNet Block
    [nBLOCK]    - UNet Block with no final operation
    [CROP]      - Cropping layer

    ### Diagram

    INPUT --> [dBLOCK] - - - - - - - - - - - - - - - - [nBLOCK] --> [CROP] --> OUTPUT
                  |                                        |
                  [dBLOCK] - - - - - - - - - - - - [uBLOCK]
                      |                                |
                      [dBLOCK] - - - - - - - [uBLOCK]
                          |                      |
                          ------- [uBLOCK] -------

    ### Constraints
    - Every spatial dimension of the input is checked (through
      ``layer_util.check_spatial_dims``) to be divisible by 8 and to be
      at least 89
    - Label size should be more than 88
    - border is 44 (the output goes through ``CropLayer(border=44)``)
    """

    def __init__(self,
                 num_classes,
                 w_initializer=None,
                 w_regularizer=None,
                 b_initializer=None,
                 b_regularizer=None,
                 acti_func='prelu',
                 name='UNet'):
        """

        :param num_classes: int, number of final output channels
        :param w_initializer: weight initialisation for network
        :param w_regularizer: weight regularisation for network
        :param b_initializer: bias initialisation for network (stored in
            ``self.initializers['b']``; not forwarded to the blocks built
            in ``layer_op``)
        :param b_regularizer: bias regularisation for network (stored in
            ``self.regularizers['b']``; not forwarded to the blocks built
            in ``layer_op``)
        :param acti_func: activation function to use
        :param name: layer name
        """
        super(UNet3D, self).__init__(name=name)

        # number of output channels used along the network, indexed by
        # position in ``layer_op``
        self.n_features = [32, 64, 128, 256, 512]
        self.acti_func = acti_func
        self.num_classes = num_classes

        self.initializers = {'w': w_initializer, 'b': b_initializer}
        self.regularizers = {'w': w_regularizer, 'b': b_regularizer}

        print('using {}'.format(name))

    def _make_unet_block(self, func, n_chns, kernels,
                         with_downsample_branch, name):
        """
        Create a ``UNetBlock`` configured with this network's weight
        initialiser, weight regulariser and activation function.

        :param func: string, operation applied after the convolutions
        :param n_chns: tuple, output channels of each convolution
        :param kernels: tuple, kernel size of each convolution
        :param with_downsample_branch: boolean, whether the block also
            returns the tensor computed before ``func`` is applied
        :param name: string, name of the block
        :return: the new ``UNetBlock`` instance
        """
        return UNetBlock(func,
                         n_chns,
                         kernels,
                         with_downsample_branch=with_downsample_branch,
                         w_initializer=self.initializers['w'],
                         w_regularizer=self.regularizers['w'],
                         acti_func=self.acti_func,
                         name=name)

    def layer_op(self, images, is_training=True, layer_id=-1, **unused_kwargs):
        """

        :param images: tensor, input to the network
        :param is_training: boolean, True if network is in training mode
        :param layer_id: int, not in use
        :param unused_kwargs: other arguments, not in use
        :return: tensor, output of the network
        """
        # spatial sizes should be divisible by 8 and not smaller than 89
        assert layer_util.check_spatial_dims(images, lambda x: x % 8 == 0), \
            'spatial dimensions of the input should be divisible by 8'
        assert layer_util.check_spatial_dims(images, lambda x: x >= 89), \
            'spatial dimensions of the input should be at least 89'

        feats = self.n_features

        # ---- left side: three down-sampling blocks, each one also
        # ---- returning its pre-pooling tensor for the skip connection
        block_layer = self._make_unet_block(
            'DOWNSAMPLE', (feats[0], feats[1]), (3, 3),
            with_downsample_branch=True, name='L1')
        pool_1, conv_1 = block_layer(images, is_training)
        print(block_layer)

        block_layer = self._make_unet_block(
            'DOWNSAMPLE', (feats[1], feats[2]), (3, 3),
            with_downsample_branch=True, name='L2')
        pool_2, conv_2 = block_layer(pool_1, is_training)
        print(block_layer)

        block_layer = self._make_unet_block(
            'DOWNSAMPLE', (feats[2], feats[3]), (3, 3),
            with_downsample_branch=True, name='L3')
        pool_3, conv_3 = block_layer(pool_2, is_training)
        print(block_layer)

        # ---- bottom of the "U": first up-sampling block
        block_layer = self._make_unet_block(
            'UPSAMPLE', (feats[3], feats[4]), (3, 3),
            with_downsample_branch=False, name='L4')
        up_3, _ = block_layer(pool_3, is_training)
        print(block_layer)

        # ---- right side: concatenate with the matching left-side tensor,
        # ---- then up-sample. Blocks are created before the CONCAT layers
        # ---- to keep the layer construction order unchanged.
        block_layer = self._make_unet_block(
            'UPSAMPLE', (feats[3], feats[3]), (3, 3),
            with_downsample_branch=False, name='R3')
        concat_3 = ElementwiseLayer('CONCAT')(conv_3, up_3)
        up_2, _ = block_layer(concat_3, is_training)
        print(block_layer)

        block_layer = self._make_unet_block(
            'UPSAMPLE', (feats[2], feats[2]), (3, 3),
            with_downsample_branch=False, name='R2')
        concat_2 = ElementwiseLayer('CONCAT')(conv_2, up_2)
        up_1, _ = block_layer(concat_2, is_training)
        print(block_layer)

        # ---- final block: two 3-kernel convolutions and one 1-kernel
        # ---- convolution producing ``num_classes`` channels
        block_layer = self._make_unet_block(
            'NONE', (feats[1], feats[1], self.num_classes), (3, 3, 1),
            with_downsample_branch=True, name='R1_FC')
        concat_1 = ElementwiseLayer('CONCAT')(conv_1, up_1)

        # for the last layer, upsampling path is not used
        _, output_tensor = block_layer(concat_1, is_training)

        crop_layer = CropLayer(border=44, name='crop-88')
        output_tensor = crop_layer(output_tensor)
        print(block_layer)
        return output_tensor
# Operations a UNetBlock can apply after its convolutions; the ``func``
# argument of UNetBlock is looked up against this set.
SUPPORTED_OP = {'DOWNSAMPLE', 'UPSAMPLE', 'NONE'}
class UNetBlock(TrainableLayer):
    """
    A stack of convolutional layers, optionally followed by a
    down-sampling ('DOWNSAMPLE') or a deconvolution ('UPSAMPLE') step.
    """

    def __init__(self,
                 func,
                 n_chns,
                 kernels,
                 w_initializer=None,
                 w_regularizer=None,
                 with_downsample_branch=False,
                 acti_func='relu',
                 name='UNet_block'):
        """

        :param func: string, operation applied after the convolutions;
            upper-cased and looked up in ``SUPPORTED_OP``
            ('DOWNSAMPLE', 'UPSAMPLE' or 'NONE')
        :param n_chns: array, number of output channels of each
            convolutional layer of the block
        :param kernels: array, kernel size of each convolutional layer
            of the block
        :param w_initializer: weight initialisation of convolutional layers
        :param w_regularizer: weight regularisation of convolutional layers
        :param with_downsample_branch: boolean, if true the tensor computed
            before ``func`` is applied is returned as well
        :param acti_func: activation function to use
        :param name: layer name
        """
        super(UNetBlock, self).__init__(name=name)

        self.func = look_up_operations(func.upper(), SUPPORTED_OP)

        self.n_chns = n_chns
        self.kernels = kernels
        self.acti_func = acti_func
        self.with_downsample_branch = with_downsample_branch

        self.initializers = {'w': w_initializer}
        self.regularizers = {'w': w_regularizer}

    def layer_op(self, input_tensor, is_training):
        """

        :param input_tensor: tensor, input to the UNet block
        :param is_training: boolean, True if network is in training mode
        :return: tuple ``(output, branch)``: ``output`` is the tensor after
            the convolutions and the block's final operation; ``branch`` is
            the tensor before the final operation when
            ``with_downsample_branch`` is set, otherwise ``None``
        """
        features = input_tensor

        # one convolutional layer per (channels, kernel size) pair,
        # each layer named after its number of output channels
        for n_out, k_size in zip(self.n_chns, self.kernels):
            conv_layer = ConvolutionalLayer(
                n_output_chns=n_out,
                kernel_size=k_size,
                w_initializer=self.initializers['w'],
                w_regularizer=self.regularizers['w'],
                acti_func=self.acti_func,
                name='{}'.format(n_out))
            features = conv_layer(features, is_training)

        # tensor handed back for the skip connection, if requested
        skip_branch = features if self.with_downsample_branch else None

        if self.func == 'DOWNSAMPLE':
            pooling = DownSampleLayer('MAX',
                                      kernel_size=2,
                                      stride=2,
                                      name='down_2x2')
            return pooling(features), skip_branch

        if self.func == 'UPSAMPLE':
            # NOTE(review): unlike the convolutions above, this layer is
            # created without the block's weight initialiser/regulariser --
            # confirm whether that is intended.
            deconv = DeconvolutionalLayer(n_output_chns=self.n_chns[-1],
                                          kernel_size=2,
                                          stride=2,
                                          name='up_2x2')
            return deconv(features, is_training), skip_branch

        # 'NONE': nothing is applied after the convolutions
        return features, skip_branch