niftynet.layer.loss_classification module

Loss functions for multi-class classification

class LossFunction(n_class, loss_type='CrossEntropy', loss_func_params=None, name='loss_function')

Bases: niftynet.layer.base_layer.Layer

make_callable_loss_func(type_str)
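This page does not describe `make_callable_loss_func`. Its name and the `loss_type='CrossEntropy'` default suggest that it turns a loss name into a callable loss function. The sketch below assumes a plain dictionary lookup. `LOSS_REGISTRY` and the pure-Python `cross_entropy` stand-in are hypothetical illustrations, not NiftyNet code.

```python
import math
from typing import Callable, Dict, Sequence


def cross_entropy(prediction: Sequence[Sequence[float]],
                  ground_truth: Sequence[int]) -> float:
    """Mean softmax cross-entropy over a batch of logits (pure-Python stand-in)."""
    total = 0.0
    for logits, label in zip(prediction, ground_truth):
        m = max(logits)  # subtract the max logit for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_sum_exp - logits[label]
    return total / len(prediction)


# Hypothetical registry mapping loss_type strings to loss functions.
LOSS_REGISTRY: Dict[str, Callable[..., float]] = {
    'CrossEntropy': cross_entropy,
}


def make_callable_loss_func(type_str: str) -> Callable[..., float]:
    """Resolve a loss name to a callable; raise ValueError on unknown names."""
    try:
        return LOSS_REGISTRY[type_str]
    except KeyError:
        raise ValueError(
            f"unknown loss_type {type_str!r}; "
            f"choose one of {sorted(LOSS_REGISTRY)}") from None
```

With this design, `make_callable_loss_func('CrossEntropy')` returns the `cross_entropy` function itself. An unknown name fails at construction time instead of during training.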
layer_op(prediction, ground_truth=None, var_scope=None)

Compute the loss from the prediction and the ground truth.

If `prediction` is a list of tensors, each element of the list is compared against `ground_truth`.

Parameters:
  • prediction – input will be reshaped into (N, num_classes)
  • ground_truth – input will be reshaped into (N,)
  • var_scope
Returns:
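The reshaping and list behaviour described for `layer_op` can be sketched without TensorFlow. This is a pure-Python illustration under stated assumptions, not NiftyNet's implementation. Plain lists stand in for tensors, and `num_classes` is passed explicitly, whereas the real layer takes `n_class` in its constructor. Averaging the per-element losses into one scalar is also an assumption, because the page does not say how they are combined.

```python
import math
from typing import List, Sequence, Union

Matrix = List[List[float]]


def reshape_prediction(flat: Sequence[float], num_classes: int) -> Matrix:
    """Reshape a flat sequence of logits into (N, num_classes) rows."""
    if len(flat) % num_classes != 0:
        raise ValueError("prediction size is not a multiple of num_classes")
    return [list(flat[i:i + num_classes])
            for i in range(0, len(flat), num_classes)]


def cross_entropy(prediction: Matrix, ground_truth: Sequence[int]) -> float:
    """Mean softmax cross-entropy over N rows of logits."""
    total = 0.0
    for logits, label in zip(prediction, ground_truth):
        m = max(logits)  # stable log-sum-exp
        total += m + math.log(sum(math.exp(z - m) for z in logits)) - logits[label]
    return total / len(prediction)


def layer_op(prediction: Union[Sequence[float], List[Sequence[float]]],
             ground_truth: Sequence[int],
             num_classes: int) -> float:
    """Compare each prediction against ground_truth and average the losses."""
    # A single flat prediction is treated as a one-element list.
    if isinstance(prediction[0], (list, tuple)):
        predictions = prediction
    else:
        predictions = [prediction]
    labels = list(ground_truth)  # assumed to be already flat, shape (N,)
    losses = []
    for pred in predictions:
        rows = reshape_prediction(pred, num_classes)  # shape (N, num_classes)
        if len(rows) != len(labels):
            raise ValueError("prediction and ground_truth disagree on N")
        losses.append(cross_entropy(rows, labels))
    # Assumption: per-element losses are combined by a simple mean.
    return sum(losses) / len(losses)
```

For example, `layer_op([0.0, 0.0, 0.0, 0.0], [0, 1], 2)` reshapes four logits into two rows of two classes. With uniform logits, each row contributes log 2.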

cross_entropy(prediction, ground_truth)

Function to calculate the cross-entropy loss.

Parameters:
  • prediction – the logits (before softmax)
  • ground_truth – the classification ground truth
Returns:
  the loss
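Because `prediction` holds logits rather than probabilities, the per-sample loss for logits z and label y is -log softmax(z)[y] = logsumexp(z) - z[y]. Below is a minimal pure-Python sketch of that computation. Averaging over the N samples is an assumption. In TensorFlow the same per-sample quantity is commonly obtained with `tf.nn.sparse_softmax_cross_entropy_with_logits`, but this page does not state which op NiftyNet uses.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax of one row of logits."""
    m = max(logits)  # shifting by the max avoids overflow in exp()
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum_exp for z in logits]


def cross_entropy(prediction: Sequence[Sequence[float]],
                  ground_truth: Sequence[int]) -> float:
    """Mean of -log softmax(logits)[label] over the N samples."""
    if len(prediction) != len(ground_truth):
        raise ValueError("need one label per row of logits")
    losses = [-log_softmax(row)[label]
              for row, label in zip(prediction, ground_truth)]
    return sum(losses) / len(losses)
```

Working in log space means a logit as large as 1000.0 does not overflow `math.exp`. A naive softmax followed by `math.log` would fail on such an input.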