# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
from niftynet.layer import layer_util
from niftynet.layer.base_layer import TrainableLayer
from niftynet.layer.convolution import ConvolutionalLayer
from niftynet.layer.deconvolution import DeconvolutionalLayer
from niftynet.layer.downsample import DownSampleLayer
from niftynet.layer.elementwise import ElementwiseLayer
from niftynet.layer.crop import CropLayer
from niftynet.utilities.util_common import look_up_operations
[docs]class UNet3D(TrainableLayer):
"""
reimplementation of 3D U-net
Çiçek et al., "3D U-Net: Learning dense Volumetric segmentation from
sparse annotation", MICCAI '16
"""
def __init__(self,
num_classes,
w_initializer=None,
w_regularizer=None,
b_initializer=None,
b_regularizer=None,
acti_func='prelu',
name='UNet'):
super(UNet3D, self).__init__(name=name)
self.n_features = [32, 64, 128, 256, 512]
self.acti_func = acti_func
self.num_classes = num_classes
self.initializers = {'w': w_initializer, 'b': b_initializer}
self.regularizers = {'w': w_regularizer, 'b': b_regularizer}
print('using {}'.format(name))
[docs] def layer_op(self, images, is_training=True, layer_id=-1, **unused_kwargs):
# image_size should be divisible by 8
assert layer_util.check_spatial_dims(images, lambda x: x % 8 == 0)
assert layer_util.check_spatial_dims(images, lambda x: x >= 89)
block_layer = UNetBlock('DOWNSAMPLE',
(self.n_features[0], self.n_features[1]),
(3, 3), with_downsample_branch=True,
w_initializer=self.initializers['w'],
w_regularizer=self.regularizers['w'],
acti_func=self.acti_func,
name='L1')
pool_1, conv_1 = block_layer(images, is_training)
print(block_layer)
block_layer = UNetBlock('DOWNSAMPLE',
(self.n_features[1], self.n_features[2]),
(3, 3), with_downsample_branch=True,
w_initializer=self.initializers['w'],
w_regularizer=self.regularizers['w'],
acti_func=self.acti_func,
name='L2')
pool_2, conv_2 = block_layer(pool_1, is_training)
print(block_layer)
block_layer = UNetBlock('DOWNSAMPLE',
(self.n_features[2], self.n_features[3]),
(3, 3), with_downsample_branch=True,
w_initializer=self.initializers['w'],
w_regularizer=self.regularizers['w'],
acti_func=self.acti_func,
name='L3')
pool_3, conv_3 = block_layer(pool_2, is_training)
print(block_layer)
block_layer = UNetBlock('UPSAMPLE',
(self.n_features[3], self.n_features[4]),
(3, 3), with_downsample_branch=False,
w_initializer=self.initializers['w'],
w_regularizer=self.regularizers['w'],
acti_func=self.acti_func,
name='L4')
up_3, _ = block_layer(pool_3, is_training)
print(block_layer)
block_layer = UNetBlock('UPSAMPLE',
(self.n_features[3], self.n_features[3]),
(3, 3), with_downsample_branch=False,
w_initializer=self.initializers['w'],
w_regularizer=self.regularizers['w'],
acti_func=self.acti_func,
name='R3')
concat_3 = ElementwiseLayer('CONCAT')(conv_3, up_3)
up_2, _ = block_layer(concat_3, is_training)
print(block_layer)
block_layer = UNetBlock('UPSAMPLE',
(self.n_features[2], self.n_features[2]),
(3, 3), with_downsample_branch=False,
w_initializer=self.initializers['w'],
w_regularizer=self.regularizers['w'],
acti_func=self.acti_func,
name='R2')
concat_2 = ElementwiseLayer('CONCAT')(conv_2, up_2)
up_1, _ = block_layer(concat_2, is_training)
print(block_layer)
block_layer = UNetBlock('NONE',
(self.n_features[1],
self.n_features[1],
self.num_classes),
(3, 3, 1),
with_downsample_branch=True,
w_initializer=self.initializers['w'],
w_regularizer=self.regularizers['w'],
acti_func=self.acti_func,
name='R1_FC')
concat_1 = ElementwiseLayer('CONCAT')(conv_1, up_1)
# for the last layer, upsampling path is not used
_, output_tensor = block_layer(concat_1, is_training)
crop_layer = CropLayer(border=44, name='crop-88')
output_tensor = crop_layer(output_tensor)
print(block_layer)
return output_tensor
SUPPORTED_OP = set(['DOWNSAMPLE', 'UPSAMPLE', 'NONE'])
[docs]class UNetBlock(TrainableLayer):
def __init__(self,
func,
n_chns,
kernels,
w_initializer=None,
w_regularizer=None,
with_downsample_branch=False,
acti_func='relu',
name='UNet_block'):
super(UNetBlock, self).__init__(name=name)
self.func = look_up_operations(func.upper(), SUPPORTED_OP)
self.kernels = kernels
self.n_chns = n_chns
self.with_downsample_branch = with_downsample_branch
self.acti_func = acti_func
self.initializers = {'w': w_initializer}
self.regularizers = {'w': w_regularizer}
[docs] def layer_op(self, input_tensor, is_training):
output_tensor = input_tensor
for (kernel_size, n_features) in zip(self.kernels, self.n_chns):
conv_op = ConvolutionalLayer(n_output_chns=n_features,
kernel_size=kernel_size,
w_initializer=self.initializers['w'],
w_regularizer=self.regularizers['w'],
acti_func=self.acti_func,
name='{}'.format(n_features))
output_tensor = conv_op(output_tensor, is_training)
if self.with_downsample_branch:
branch_output = output_tensor
else:
branch_output = None
if self.func == 'DOWNSAMPLE':
downsample_op = DownSampleLayer('MAX',
kernel_size=2,
stride=2,
name='down_2x2')
output_tensor = downsample_op(output_tensor)
elif self.func == 'UPSAMPLE':
upsample_op = DeconvolutionalLayer(n_output_chns=self.n_chns[-1],
kernel_size=2,
stride=2,
name='up_2x2')
output_tensor = upsample_op(output_tensor, is_training)
elif self.func == 'NONE':
pass # do nothing
return output_tensor, branch_output