# Source code for niftynet.layer.loss_classification
# -*- coding: utf-8 -*-
"""
Loss functions for multi-class classification
"""
from __future__ import absolute_import, print_function, division
import numpy as np
import tensorflow as tf
from niftynet.engine.application_factory import LossClassificationFactory
from niftynet.layer.base_layer import Layer
class LossFunction(Layer):
    """
    Classification loss layer.

    Looks up a loss callable by name through ``LossClassificationFactory``
    and applies it to one or more predictions against a shared ground truth.
    """

    def __init__(self,
                 n_class,
                 loss_type='CrossEntropy',
                 loss_func_params=None,
                 name='loss_function'):
        """
        :param n_class: number of classes; when > 0, every prediction is
            reshaped to ``(-1, n_class)`` before the loss is applied
        :param loss_type: loss name passed to
            ``LossClassificationFactory.create``
        :param loss_func_params: optional extra keyword arguments forwarded
            to the loss callable; ``None`` means "no extras"
        :param name: layer name
        """
        super(LossFunction, self).__init__(name=name)
        self._num_classes = n_class
        # The caller's object is stored as-is (not copied).
        self._loss_func_params = (
            {} if loss_func_params is None else loss_func_params)
        # Filled in by make_callable_loss_func below.
        self._data_loss_func = None
        self.make_callable_loss_func(loss_type)

    def make_callable_loss_func(self, type_str):
        """
        Resolve ``type_str`` to a loss callable and store it on the layer.

        :param type_str: loss name understood by
            ``LossClassificationFactory.create``
        """
        loss_callable = LossClassificationFactory.create(type_str)
        self._data_loss_func = loss_callable

    def layer_op(self,
                 prediction,
                 ground_truth=None,
                 var_scope=None, ):
        """
        Compute the loss of ``prediction`` against ``ground_truth``.

        If ``prediction`` is a list or tuple of tensors, the loss callable is
        applied to each element against the same ``ground_truth``. The
        results are then averaged together with ``tf.reduce_mean``. This
        presumably requires the per-prediction losses to share a shape so
        the list can be packed into one tensor.

        :param prediction: tensor or list/tuple of tensors; each is reshaped
            to ``(-1, num_classes)`` when ``num_classes > 0``
        :param ground_truth: optional tensor, flattened to shape ``(-1,)``
            when given
        :param var_scope: accepted for interface compatibility; not used here
        :return: scalar mean of all computed loss values
        """
        # All loss ops are created on the CPU device.
        with tf.device('/cpu:0'):
            if ground_truth is not None:
                ground_truth = tf.reshape(ground_truth, [-1])

            # Holistic networks may emit several outputs; wrap a single
            # tensor so both cases are handled by the same loop.
            preds = (prediction if isinstance(prediction, (list, tuple))
                     else [prediction])
            if self._num_classes > 0:
                n_cls = self._num_classes
                preds = [tf.reshape(p, [-1, n_cls]) for p in preds]

            # An empty (falsy) params object means "call without extras".
            extra_kwargs = self._loss_func_params or {}
            per_pred_losses = [
                self._data_loss_func(p, ground_truth, **extra_kwargs)
                for p in preds
            ]
            return tf.reduce_mean(per_pred_losses)
def cross_entropy(prediction,
                  ground_truth):
    """
    Sparse softmax cross-entropy between logits and integer class labels.

    :param prediction: the logits (before softmax); the last axis indexes
        classes
    :param ground_truth: the classification ground truth, one class index
        per sample; cast to ``int64`` before use
    :return: unreduced per-sample loss tensor with the same shape as
        ``ground_truth`` (the caller is responsible for any averaging)
    """
    # tf.cast is the version-stable spelling: tf.to_int64 is deprecated and
    # is not part of the TF 2.x top-level namespace.
    labels = tf.cast(ground_truth, tf.int64)
    return tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=prediction, labels=labels)