# Source code for niftynet.layer.elementwise

# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function

import numpy as np
import tensorflow as tf

from niftynet.layer import layer_util
from niftynet.layer.base_layer import TrainableLayer
from niftynet.layer.convolution import ConvLayer
from niftynet.utilities.util_common import look_up_operations

# Combination functions accepted by ElementwiseLayer (matched upper-case).
SUPPORTED_OP = {'SUM', 'CONCAT'}


class ElementwiseLayer(TrainableLayer):
    """
    Combine two branch flows, e.g. the two paths of a residual connection.

    Two combination functions are supported (see ``SUPPORTED_OP``):

    - ``'SUM'``: element-wise sum. If the channel (last) dimensions of the
      two flows differ, ``bypass_flow`` is adjusted to match ``param_flow``:
      it is zero-padded on the channel axis when it has fewer channels, or
      projected with a 1x1 convolution when it has more channels.
    - ``'CONCAT'``: concatenation of the two flows along the last axis.
    """

    def __init__(self,
                 func,
                 initializer=None,
                 regularizer=None,
                 name='residual'):
        """
        :param func: name of the combination function, ``'SUM'`` or
            ``'CONCAT'``; it is upper-cased before being looked up in
            ``SUPPORTED_OP`` (presumably ``look_up_operations`` rejects
            unsupported names -- confirm in ``util_common``).
        :param initializer: weight initializer for the 1x1 projection
            convolution, only used by ``'SUM'`` when ``bypass_flow`` has
            more channels than ``param_flow``.
        :param regularizer: weight regularizer for that same projection.
        :param name: prefix of the layer name; the full name is
            ``'<name>_<func in lower case>'``.
        """
        self.func = look_up_operations(func.upper(), SUPPORTED_OP)
        self.layer_name = '{}_{}'.format(name, self.func.lower())
        super(ElementwiseLayer, self).__init__(name=self.layer_name)
        self.initializers = {'w': initializer}
        self.regularizers = {'w': regularizer}

    def layer_op(self, param_flow, bypass_flow):
        """
        Combine ``param_flow`` and ``bypass_flow`` according to ``self.func``.

        Both inputs are treated as channels-last: the channel count is read
        from ``shape[-1]`` and ``'CONCAT'`` joins on ``axis=-1``. The channel
        padding builds ``spatial_rank + 2`` padding pairs, i.e. it assumes a
        tensor of rank ``spatial_rank + 2`` (one leading dimension, the
        spatial dimensions, then channels) -- presumably
        (batch, spatial..., channels); TODO confirm against
        ``layer_util.infer_spatial_rank``.

        :param param_flow: output of the parametric branch; its channel count
            is the channel count of the ``'SUM'`` result.
        :param bypass_flow: output of the bypass (shortcut) branch.
        :return: ``param_flow + bypass_flow`` (after channel matching) for
            ``'SUM'``; the two flows concatenated along the last axis for
            ``'CONCAT'``.
        """
        n_param_flow = param_flow.shape[-1]
        n_bypass_flow = bypass_flow.shape[-1]
        spatial_rank = layer_util.infer_spatial_rank(param_flow)

        output_tensor = param_flow
        if self.func == 'SUM':
            if n_param_flow > n_bypass_flow:
                # Zero-pad the bypass channels up to param_flow's count. The
                # difference is split as evenly as possible; when it is odd
                # the extra channel is appended after (pad_2 >= pad_1).
                # Builtin int() is used because np.int was removed in
                # NumPy 1.24; it also converts TF1 Dimension objects.
                n_diff = n_param_flow - n_bypass_flow
                pad_1 = int(n_diff // 2)
                pad_2 = int(n_diff - pad_1)
                # One [before, after] pair per dimension: the leading and
                # spatial dimensions are left as-is, only channels grow.
                padding_dims = [[0, 0]] * (spatial_rank + 1) + [[pad_1, pad_2]]
                bypass_flow = tf.pad(tensor=bypass_flow,
                                     paddings=padding_dims,
                                     mode='CONSTANT')
            elif n_param_flow < n_bypass_flow:
                # Too many bypass channels: reduce them with a learned 1x1
                # convolution projecting onto param_flow's channel count.
                projector = ConvLayer(n_output_chns=n_param_flow,
                                      kernel_size=1,
                                      stride=1,
                                      padding='SAME',
                                      w_initializer=self.initializers['w'],
                                      w_regularizer=self.regularizers['w'],
                                      name='proj')
                bypass_flow = projector(bypass_flow)

            # element-wise sum of both paths
            output_tensor = param_flow + bypass_flow

        elif self.func == 'CONCAT':
            output_tensor = tf.concat([param_flow, bypass_flow], axis=-1)

        return output_tensor