# Source code for niftynet.layer.post_processing

# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function

import tensorflow as tf

from niftynet.layer.base_layer import Layer
from niftynet.utilities.util_common import look_up_operations

# Names accepted (case-insensitively) as ``func`` by PostProcessingLayer.
SUPPORTED_OPS = {"SOFTMAX", "ARGMAX", "IDENTITY"}


class PostProcessingLayer(Layer):
    """
    This layer operation converts the raw network outputs into final
    inference results.

    The conversion is chosen by ``func`` (matched case-insensitively against
    ``SUPPORTED_OPS``):

    - ``"SOFTMAX"``: ``tf.nn.softmax`` of the inputs, cast to ``tf.float32``;
    - ``"ARGMAX"``: index of the maximum along the last axis, cast to
      ``tf.int32`` and given a trailing singleton axis;
    - ``"IDENTITY"``: the inputs cast to ``tf.float32``.
    """

    def __init__(self, func='', num_classes=0, name='post_processing'):
        """
        :param func: name of the post-processing operation; it is upper-cased
            and looked up in ``SUPPORTED_OPS`` via ``look_up_operations``.
        :param num_classes: number of classes, reported by
            ``num_output_channels`` when ``func`` is ``"SOFTMAX"``.
        :param name: layer name forwarded to the base ``Layer``.
        """
        super(PostProcessingLayer, self).__init__(name=name)
        self.func = look_up_operations(func.upper(), SUPPORTED_OPS)
        self.num_classes = num_classes

    def num_output_channels(self):
        """
        Return the number of output channels of this layer.

        :return: ``self.num_classes`` for ``"SOFTMAX"``, otherwise 1.
        """
        # Requires the layer to have been applied at least once; ``_op`` is
        # presumably the op/template managed by the base ``Layer`` — confirm.
        # NOTE(review): this ``assert`` is stripped under ``python -O``.
        assert self._op._variables_created
        if self.func == "SOFTMAX":
            return self.num_classes
        # NOTE(review): "IDENTITY" passes every input channel through, yet
        # 1 is reported here — confirm this is intended.
        return 1

    def layer_op(self, inputs):
        """
        Apply the configured post-processing operation to ``inputs``.

        :param inputs: raw network output tensor. "ARGMAX" reduces over its
            last axis; "SOFTMAX" uses ``tf.nn.softmax``'s default axis.
        :return: the post-processed tensor (``tf.float32`` for "SOFTMAX" and
            "IDENTITY", ``tf.int32`` for "ARGMAX").
        :raises ValueError: if ``self.func`` is not one of the handled
            operation names.
        """
        if self.func == "SOFTMAX":
            return tf.cast(tf.nn.softmax(inputs), tf.float32)
        if self.func == "ARGMAX":
            labels = tf.cast(tf.argmax(inputs, -1), tf.int32)
            # argmax drops the last axis; add it back as a singleton so the
            # output has the same rank as the input.
            return tf.expand_dims(labels, axis=-1)
        if self.func == "IDENTITY":
            return tf.cast(inputs, tf.float32)
        # Guards against ``self.func`` being changed after construction;
        # without this an unknown value surfaced as an UnboundLocalError.
        raise ValueError(
            "unsupported post-processing function: {}".format(self.func))