# Source code for niftynet.layer.reshape

# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function

import tensorflow as tf

from niftynet.layer.base_layer import Layer


class ReshapeLayer(Layer):
    """
    Reshape an input tensor to a fixed target shape.

    Principally intended for passing feature maps on to fully connected
    layers.
    """

    def __init__(self, output_size, name='reshaper'):
        """
        :param output_size: target shape, forwarded unchanged as the
            ``shape`` argument of ``tf.reshape``. Presumably a list/tuple
            of ints, possibly containing one -1 -- TODO confirm against
            callers.
        :param name: layer name; also used as the name of the reshape op
            created in ``layer_op``.
        """
        super(ReshapeLayer, self).__init__(name=name)
        self.output_size = output_size
        # Stored explicitly here; layer_op reads it to name the reshape op.
        self.name = name

    def layer_op(self, input_tensor):
        """
        Reshape ``input_tensor`` to ``self.output_size``.

        :param input_tensor: tensor to be reshaped
        :return: result of ``tf.reshape`` with op name ``self.name``
        """
        return tf.reshape(tensor=input_tensor,
                          shape=self.output_size,
                          name=self.name)