# Source code for niftynet.layer.upsample_res_block
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
from niftynet.layer.additive_upsample import ResidualUpsampleLayer as ResUp
from niftynet.layer.base_layer import TrainableLayer
from niftynet.layer.deconvolution import DeconvolutionalLayer as Deconv
from niftynet.layer.elementwise import ElementwiseLayer
from niftynet.layer.residual_unit import ResidualUnit as ResUnit
class UpBlock(TrainableLayer):
    """
    Upsampling block followed by a residual unit.

    The input is first upsampled (with either ``ResidualUpsampleLayer`` or
    ``DeconvolutionalLayer``), optionally summed element-wise with a
    forwarded tensor, and finally passed through a ``ResidualUnit``.
    """

    def __init__(self,
                 n_output_chns=4,
                 kernel_size=3,
                 upsample_stride=2,
                 acti_func='relu',
                 w_initializer=None,
                 w_regularizer=None,
                 is_residual_upsampling=True,
                 type_string='bn_acti_conv',
                 name='res-upsample'):
        """
        :param n_output_chns: number of output channels; passed to the
            deconvolution and to the residual unit, and used to derive
            ``n_splits`` for the residual upsampling
        :param kernel_size: kernel size passed to the upsampling layer
            and to the residual unit
        :param upsample_stride: stride passed to the upsampling layer
        :param acti_func: activation function name passed to the
            sub-layers
        :param w_initializer: weight initializer forwarded to the
            sub-layers
        :param w_regularizer: weight regularizer forwarded to the
            sub-layers
        :param is_residual_upsampling: if True upsample with
            ``ResidualUpsampleLayer``, otherwise with
            ``DeconvolutionalLayer``
        :param type_string: ``type_string`` passed to the residual unit
        :param name: layer name passed to the parent constructor
        """
        # Start the MRO lookup at UpBlock so that TrainableLayer.__init__
        # itself runs; ``super(TrainableLayer, self)`` would skip it and
        # only run the initialisers that come after TrainableLayer.
        super(UpBlock, self).__init__(name=name)
        self.type_string = type_string
        self.n_output_chns = n_output_chns
        self.kernel_size = kernel_size
        self.acti_func = acti_func
        self.upsample_stride = upsample_stride
        self.conv_param = {'w_initializer': w_initializer,
                           'w_regularizer': w_regularizer}
        self.is_residual_upsampling = is_residual_upsampling

    def layer_op(self, inputs, forwarding=None, is_training=True):
        """
        Consists of::

            (inputs)--upsampling-+-o--conv_1--conv_2--+--(conv_res)--
                                 | |                  |
            (forwarding)---------o o------------------o

        where upsampling method could be ``DeconvolutionalLayer``
        or ``ResidualUpsampleLayer``

        :param inputs: tensor to upsample; its last dimension is read as
            the number of input channels
        :param forwarding: optional tensor summed element-wise with the
            upsampled tensor; skipped when None
        :param is_training: flag forwarded to the upsampling layer and
            to the residual unit
        :return: output of the residual unit
        """
        if self.is_residual_upsampling:
            n_input_channels = inputs.get_shape().as_list()[-1]
            # Ratio of input to output channels, kept as a float exactly
            # as before.
            # NOTE(review): a non-integer ratio is handed on unchanged;
            # how ResidualUpsampleLayer treats it is not visible here.
            n_splits = float(n_input_channels) / float(self.n_output_chns)
            upsampled = ResUp(kernel_size=self.kernel_size,
                              stride=self.upsample_stride,
                              n_splits=n_splits,
                              acti_func=self.acti_func,
                              **self.conv_param)(inputs, is_training)
        else:
            upsampled = Deconv(n_output_chns=self.n_output_chns,
                               kernel_size=self.kernel_size,
                               stride=self.upsample_stride,
                               acti_func=self.acti_func,
                               with_bias=False, with_bn=True,
                               **self.conv_param)(inputs, is_training)

        if forwarding is None:
            conv_0 = upsampled
        else:
            conv_0 = ElementwiseLayer('SUM')(upsampled, forwarding)

        conv_res = ResUnit(n_output_chns=self.n_output_chns,
                           kernel_size=self.kernel_size,
                           acti_func=self.acti_func,
                           type_string=self.type_string,
                           **self.conv_param)(conv_0, is_training)
        return conv_res