# -*- coding: utf-8 -*-
"""
Loss functions for multi-class segmentation
"""
from __future__ import absolute_import, print_function, division
import numpy as np
import tensorflow as tf
from niftynet.engine.application_factory import LossSegmentationFactory
from niftynet.layer.base_layer import Layer
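# Distance matrix on the label space, used below by the generalised
# Wasserstein Dice loss; following Fidon et al. (2017), this is presumably
# the tree-structured distance between the five BraTS tumour labels.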
M_tree = np.array([[0., 1., 1., 1., 1.],
                   [1., 0., 0.6, 0.2, 0.5],
                   [1., 0.6, 0., 0.6, 0.7],
                   [1., 0.2, 0.6, 0., 0.5],
                   [1., 0.5, 0.7, 0.5, 0.]], dtype=np.float64)

class LossFunction(Layer):
    def __init__(self,
                 n_class,
                 loss_type='Dice',
                 loss_func_params=None,
                 name='loss_function'):
        super(LossFunction, self).__init__(name=name)
        self._num_classes = n_class
        if loss_func_params is not None:
            self._loss_func_params = loss_func_params
        else:
            self._loss_func_params = {}
        self._data_loss_func = None
        self.make_callable_loss_func(loss_type)

    def make_callable_loss_func(self, type_str):
        self._data_loss_func = LossSegmentationFactory.create(type_str)

    def layer_op(self,
                 prediction,
                 ground_truth=None,
                 weight_map=None,
                 var_scope=None):
        """
        Compute the loss from `prediction` and `ground_truth`;
        the computed loss map is weighted by `weight_map`.

        If `prediction` is a list of tensors, each element of the list
        will be compared against `ground_truth` and weighted by
        `weight_map`, and the resulting losses are averaged.

        :param prediction: input will be reshaped into (N, num_classes)
        :param ground_truth: input will be reshaped into (N,)
        :param weight_map: input will be reshaped into (N,)
        :param var_scope:
        :return: a scalar loss value
        """
        with tf.device('/cpu:0'):
            if ground_truth is not None:
                ground_truth = tf.reshape(ground_truth, [-1])
            if weight_map is not None:
                weight_map = tf.reshape(weight_map, [-1])
            if not isinstance(prediction, (list, tuple)):
                prediction = [prediction]
            # prediction should be a list for holistic networks
            if self._num_classes > 0:
                # reshape the prediction to [n_voxels, num_classes]
                prediction = [tf.reshape(pred, [-1, self._num_classes])
                              for pred in prediction]
            data_loss = []
            for pred in prediction:
                if self._loss_func_params:
                    data_loss.append(self._data_loss_func(
                        pred, ground_truth, weight_map,
                        **self._loss_func_params))
                else:
                    data_loss.append(self._data_loss_func(
                        pred, ground_truth, weight_map))
            return tf.reduce_mean(data_loss)
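
# A minimal usage sketch, not used by NiftyNet itself: it assumes a
# TensorFlow 1.x session, that Layer instances are callable (dispatching
# to `layer_op`), and that the factory maps 'Dice' to `dice` below.
def _demo_loss_function():
    logits = tf.constant(np.random.randn(16, 4), dtype=tf.float32)
    labels = tf.constant(np.random.randint(0, 4, size=16), dtype=tf.int64)
    loss_layer = LossFunction(n_class=4, loss_type='Dice')
    loss = loss_layer(logits, labels)
    with tf.Session() as sess:
        print('Dice loss:', sess.run(loss))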

def generalised_dice_loss(prediction,
                          ground_truth,
                          weight_map=None,
                          type_weight='Square'):
    """
    Function to calculate the Generalised Dice Loss defined in
    Sudre, C. et al. (2017) Generalised Dice overlap as a deep learning
    loss function for highly unbalanced segmentations. DLMIA 2017

    :param prediction: the logits (before softmax)
    :param ground_truth: the segmentation ground truth
    :param weight_map: voxel-wise weights applied to both volumes
    :param type_weight: type of weighting applied between labels (choice
        between Square (square of inverse of volume),
        Simple (inverse of volume) and Uniform (no weighting))
    :return: the loss
    """
    ground_truth = tf.to_int64(ground_truth)
    n_voxels = ground_truth.get_shape()[0].value
    n_classes = prediction.get_shape()[1].value
    prediction = tf.nn.softmax(prediction)
    ids = tf.constant(np.arange(n_voxels), dtype=tf.int64)
    ids = tf.stack([ids, ground_truth], axis=1)
    one_hot = tf.SparseTensor(indices=ids,
                              values=tf.ones([n_voxels], dtype=tf.float32),
                              dense_shape=[n_voxels, n_classes])

    if weight_map is not None:
        # broadcast each voxel's weight across all classes
        weight_map_nclasses = tf.tile(
            tf.expand_dims(weight_map, axis=-1), [1, n_classes])
        ref_vol = tf.sparse_reduce_sum(
            weight_map_nclasses * one_hot, reduction_axes=[0])
        intersect = tf.sparse_reduce_sum(
            weight_map_nclasses * one_hot * prediction, reduction_axes=[0])
        seg_vol = tf.reduce_sum(
            tf.multiply(weight_map_nclasses, prediction), 0)
    else:
        ref_vol = tf.sparse_reduce_sum(one_hot, reduction_axes=[0])
        intersect = tf.sparse_reduce_sum(one_hot * prediction,
                                         reduction_axes=[0])
        seg_vol = tf.reduce_sum(prediction, 0)

    if type_weight == 'Square':
        weights = tf.reciprocal(tf.square(ref_vol))
    elif type_weight == 'Simple':
        weights = tf.reciprocal(ref_vol)
    elif type_weight == 'Uniform':
        weights = tf.ones_like(ref_vol)
    else:
        raise ValueError('The variable type_weight "{}" '
                         'is not defined.'.format(type_weight))

    # classes absent from the reference volume produce inf weights;
    # replace them with the largest finite weight so they cannot dominate
    new_weights = tf.where(tf.is_inf(weights), tf.zeros_like(weights), weights)
    weights = tf.where(tf.is_inf(weights),
                       tf.ones_like(weights) * tf.reduce_max(new_weights),
                       weights)
    generalised_dice_numerator = \
        2 * tf.reduce_sum(tf.multiply(weights, intersect))
    generalised_dice_denominator = \
        tf.reduce_sum(tf.multiply(weights, seg_vol + ref_vol))
    generalised_dice_score = \
        generalised_dice_numerator / generalised_dice_denominator
    return 1 - generalised_dice_score
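
# A minimal usage sketch, not used by NiftyNet itself: the first
# dimension must be statically known because the function reads
# `get_shape()[0].value`, hence the use of tf.constant inputs.
def _demo_generalised_dice_loss():
    logits = tf.constant(np.random.randn(16, 4), dtype=tf.float32)
    labels = tf.constant(np.random.randint(0, 4, size=16), dtype=tf.int64)
    loss = generalised_dice_loss(logits, labels, type_weight='Square')
    with tf.Session() as sess:
        print('generalised Dice loss:', sess.run(loss))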

def sensitivity_specificity_loss(prediction,
                                 ground_truth,
                                 weight_map=None,
                                 r=0.05):
    """
    Function to calculate a multiple-ground_truth version of
    the sensitivity-specificity loss defined in "Deep Convolutional
    Encoder Networks for Multiple Sclerosis Lesion Segmentation",
    Brosch et al, MICCAI 2015,
    https://link.springer.com/chapter/10.1007/978-3-319-24574-4_1

    The error is the sum of r * (specificity part) and
    (1 - r) * (sensitivity part).

    :param prediction: the logits (before softmax).
    :param ground_truth: segmentation ground_truth.
    :param weight_map: currently unused.
    :param r: the 'sensitivity ratio'
        (authors suggest values from 0.01-0.10 will have similar effects)
    :return: the loss
    """
    ground_truth = tf.to_int64(ground_truth)
    n_voxels = ground_truth.get_shape()[0].value
    n_classes = prediction.get_shape()[1].value
    prediction = tf.nn.softmax(prediction)
    ids = tf.constant(np.arange(n_voxels), dtype=tf.int64)
    ids = tf.stack([ids, ground_truth], axis=1)
    one_hot = tf.SparseTensor(indices=ids,
                              values=tf.ones([n_voxels], dtype=tf.float32),
                              dense_shape=[n_voxels, n_classes])
    one_hot = tf.sparse_tensor_to_dense(one_hot)
    # value of unity everywhere except for the previous 'hot' locations
    one_cold = 1 - one_hot

    # chosen region may contain no voxels of a given label. Prevents nans.
    epsilon_denominator = 1e-5

    squared_error = tf.square(one_hot - prediction)
    specificity_part = \
        tf.reduce_sum(squared_error * one_hot, 0) / \
        (tf.reduce_sum(one_hot, 0) + epsilon_denominator)
    sensitivity_part = \
        tf.reduce_sum(tf.multiply(squared_error, one_cold), 0) / \
        (tf.reduce_sum(one_cold, 0) + epsilon_denominator)
    return tf.reduce_sum(r * specificity_part + (1 - r) * sensitivity_part)
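
# A minimal usage sketch, not used by NiftyNet itself: both tensor
# dimensions must be statically known because the function reads
# `get_shape()` values, so constant inputs are used here.
def _demo_sensitivity_specificity_loss():
    logits = tf.constant(np.random.randn(16, 4), dtype=tf.float32)
    labels = tf.constant(np.random.randint(0, 4, size=16), dtype=tf.int64)
    loss = sensitivity_specificity_loss(logits, labels, r=0.05)
    with tf.Session() as sess:
        print('sensitivity-specificity loss:', sess.run(loss))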

def l2_reg_loss(scope):
    """
    Sum of `tf.nn.l2_loss` (half the squared l2 norm) over all variables
    in the 'reg_var' collection under `scope`; 0.0 if the collection is
    empty.
    """
    if not tf.get_collection('reg_var', scope):
        return 0.0
    return tf.add_n([tf.nn.l2_loss(reg_var)
                     for reg_var in tf.get_collection('reg_var', scope)])
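
# A minimal usage sketch, not used by NiftyNet itself: registers one
# variable in the 'reg_var' collection, then evaluates the regulariser
# restricted to that variable scope.
def _demo_l2_reg_loss():
    with tf.variable_scope('demo_net'):
        w = tf.get_variable('w', shape=[3, 3],
                            initializer=tf.ones_initializer())
        tf.add_to_collection('reg_var', w)
    reg = l2_reg_loss('demo_net')
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print('l2 regularisation:', sess.run(reg))  # 0.5 * 9 = 4.5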

def cross_entropy(prediction, ground_truth, weight_map=None):
    """
    Function to calculate the cross-entropy loss function

    :param prediction: the logits (before softmax)
    :param ground_truth: the segmentation ground truth
    :param weight_map: voxel-wise weights; rescaled so that their sum
        equals the number of voxels before being applied
    :return: the cross-entropy loss
    """
    # the sparse cross-entropy op expects integer labels
    ground_truth = tf.cast(ground_truth, tf.int32)
    entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=prediction, labels=ground_truth)
    if weight_map is not None:
        # rescale the weights so that the mean entropy stays comparable
        # to the unweighted case
        weight_map = tf.cast(tf.size(entropy), dtype=tf.float32) / \
                     tf.reduce_sum(weight_map) * weight_map
        entropy = tf.multiply(entropy, weight_map)
    return tf.reduce_mean(entropy)
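
# A minimal usage sketch, not used by NiftyNet itself: exercises the
# weighted variant on a two-voxel, two-class toy problem.
def _demo_cross_entropy():
    logits = tf.constant([[2.0, 0.5], [0.1, 1.5]], dtype=tf.float32)
    labels = tf.constant([0, 1], dtype=tf.int64)
    weights = tf.constant([1.0, 3.0], dtype=tf.float32)
    loss = cross_entropy(logits, labels, weight_map=weights)
    with tf.Session() as sess:
        print('weighted cross-entropy:', sess.run(loss))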

def wasserstein_disagreement_map(prediction, ground_truth, M):
    """
    Function to calculate the pixel-wise Wasserstein distance between the
    flattened prediction probabilities and the flattened labels
    (ground_truth), with respect to the distance matrix M on the label
    space.

    :param prediction: the probabilities (after softmax)
    :param ground_truth: the one-hot segmentation ground_truth
    :param M: distance matrix on the label space
    :return: the pixelwise distance map (wass_dis_map)
    """
    # pixel-wise Wasserstein distance (W) between flat_pred_proba and
    # flat_labels, wrt the distance matrix on the label space M
    n_classes = prediction.get_shape()[1].value
    unstack_labels = tf.unstack(ground_truth, axis=-1)
    unstack_labels = tf.cast(unstack_labels, dtype=tf.float64)
    unstack_pred = tf.unstack(prediction, axis=-1)
    unstack_pred = tf.cast(unstack_pred, dtype=tf.float64)
    # W is a weighted sum of all pairwise correlations (pred_ci x labels_cj)
    pairwise_correlations = []
    for i in range(n_classes):
        for j in range(n_classes):
            pairwise_correlations.append(
                M[i, j] * tf.multiply(unstack_pred[i], unstack_labels[j]))
    wass_dis_map = tf.add_n(pairwise_correlations)
    return wass_dis_map
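
# A minimal usage sketch, not used by NiftyNet itself: evaluates the
# disagreement map for random class probabilities against one-hot labels,
# using the M_tree metric defined at the top of this module.
def _demo_wasserstein_disagreement_map():
    probs = tf.constant(
        np.random.dirichlet(np.ones(5), size=8), dtype=tf.float64)
    labels = tf.constant(
        np.eye(5)[np.random.randint(0, 5, size=8)], dtype=tf.float64)
    dis_map = wasserstein_disagreement_map(probs, labels, M_tree)
    with tf.Session() as sess:
        print('disagreement map:', sess.run(dis_map))  # shape (8,)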

def generalised_wasserstein_dice_loss(prediction,
                                      ground_truth,
                                      weight_map=None):
    """
    Function to calculate the Generalised Wasserstein Dice Loss defined in
    Fidon, L. et al. (2017) Generalised Wasserstein Dice Score
    for Imbalanced Multi-class Segmentation using Holistic
    Convolutional Networks. MICCAI 2017 (BrainLes)

    :param prediction: the logits (before softmax)
    :param ground_truth: the segmentation ground_truth
    :param weight_map: currently unused
    :return: the loss
    """
    # apply softmax to the prediction scores
    ground_truth = tf.cast(ground_truth, dtype=tf.int64)
    pred_proba = tf.nn.softmax(tf.cast(prediction, dtype=tf.float64))
    n_classes = prediction.get_shape()[1].value
    n_voxels = prediction.get_shape()[0].value
    ids = tf.constant(np.arange(n_voxels), dtype=tf.int64)
    ids = tf.stack([ids, ground_truth], axis=1)
    one_hot = tf.SparseTensor(indices=ids,
                              values=tf.ones([n_voxels], dtype=tf.float32),
                              dense_shape=[n_voxels, n_classes])
    one_hot = tf.sparse_tensor_to_dense(one_hot)
    # compute the disagreement map (delta) wrt the label-space metric M_tree
    M = M_tree
    delta = wasserstein_disagreement_map(pred_proba, one_hot, M)
    # compute the generalisation of all errors for multi-class seg
    all_error = tf.reduce_sum(delta)
    # compute the generalisation of true positives for multi-class seg
    one_hot = tf.cast(one_hot, dtype=tf.float64)
    true_pos = tf.reduce_sum(
        tf.multiply(tf.constant(M[0, :n_classes], dtype=tf.float64), one_hot),
        axis=1)
    true_pos = tf.reduce_sum(tf.multiply(true_pos, 1. - delta), axis=0)
    WGDL = 1. - (2. * true_pos) / (2. * true_pos + all_error)
    return tf.cast(WGDL, dtype=tf.float32)
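
# A minimal usage sketch, not used by NiftyNet itself: this loss is
# hard-wired to the 5-label M_tree above, so the logits must have
# exactly 5 classes, and both dimensions must be statically known.
def _demo_generalised_wasserstein_dice_loss():
    logits = tf.constant(np.random.randn(8, 5), dtype=tf.float32)
    labels = tf.constant(np.random.randint(0, 5, size=8), dtype=tf.int64)
    loss = generalised_wasserstein_dice_loss(logits, labels)
    with tf.Session() as sess:
        print('generalised Wasserstein Dice loss:', sess.run(loss))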

def dice_nosquare(prediction, ground_truth, weight_map=None):
    """
    Function to calculate the classical Dice loss
    (no squared terms in the denominator)

    :param prediction: the logits (before softmax)
    :param ground_truth: the segmentation ground_truth
    :param weight_map: voxel-wise weights applied to both volumes
    :return: the loss
    """
    ground_truth = tf.to_int64(ground_truth)
    n_voxels = ground_truth.get_shape()[0].value
    n_classes = prediction.get_shape()[1].value
    prediction = tf.nn.softmax(prediction)
    # construct sparse matrix for ground_truth to save space
    ids = tf.constant(np.arange(n_voxels), dtype=tf.int64)
    ids = tf.stack([ids, ground_truth], axis=1)
    one_hot = tf.SparseTensor(indices=ids,
                              values=tf.ones([n_voxels], dtype=tf.float32),
                              dense_shape=[n_voxels, n_classes])
    # dice
    if weight_map is not None:
        # broadcast each voxel's weight across all classes
        weight_map_nclasses = tf.tile(
            tf.expand_dims(weight_map, axis=-1), [1, n_classes])
        dice_numerator = 2.0 * tf.sparse_reduce_sum(
            weight_map_nclasses * one_hot * prediction, reduction_axes=[0])
        dice_denominator = \
            tf.reduce_sum(prediction * weight_map_nclasses,
                          reduction_indices=[0]) + \
            tf.sparse_reduce_sum(weight_map_nclasses * one_hot,
                                 reduction_axes=[0])
    else:
        dice_numerator = 2.0 * tf.sparse_reduce_sum(one_hot * prediction,
                                                    reduction_axes=[0])
        dice_denominator = \
            tf.reduce_sum(prediction, reduction_indices=[0]) + \
            tf.sparse_reduce_sum(one_hot, reduction_axes=[0])
    epsilon_denominator = 0.00001

    dice_score = dice_numerator / (dice_denominator + epsilon_denominator)
    # minimising (1 - dice_coefficients)
    return 1.0 - tf.reduce_mean(dice_score)
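
# A minimal usage sketch, not used by NiftyNet itself: exercises the
# weighted path, with statically known shapes as the function requires.
def _demo_dice_nosquare():
    logits = tf.constant(np.random.randn(16, 4), dtype=tf.float32)
    labels = tf.constant(np.random.randint(0, 4, size=16), dtype=tf.int64)
    weights = tf.constant(np.random.rand(16), dtype=tf.float32)
    loss = dice_nosquare(logits, labels, weight_map=weights)
    with tf.Session() as sess:
        print('Dice (no square) loss:', sess.run(loss))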

def dice(prediction, ground_truth, weight_map=None):
    """
    Function to calculate the Dice loss with the definition given in
    Milletari, F., Navab, N., & Ahmadi, S. A. (2016)
    V-Net: Fully convolutional neural networks
    for volumetric medical image segmentation. 3DV 2016,
    using squared terms in the denominator

    :param prediction: the logits (before softmax)
    :param ground_truth: the segmentation ground_truth
    :param weight_map: voxel-wise weights applied to both volumes
    :return: the loss
    """
    ground_truth = tf.to_int64(ground_truth)
    prediction = tf.cast(prediction, tf.float32)
    prediction = tf.nn.softmax(prediction)
    ids = tf.range(tf.to_int64(tf.shape(ground_truth)[0]), dtype=tf.int64)
    ids = tf.stack([ids, ground_truth], axis=1)
    one_hot = tf.SparseTensor(
        indices=ids,
        values=tf.ones_like(ground_truth, dtype=tf.float32),
        dense_shape=tf.to_int64(tf.shape(prediction)))
    if weight_map is not None:
        n_classes = prediction.get_shape()[1].value
        # broadcast each voxel's weight across all classes
        weight_map_nclasses = tf.tile(
            tf.expand_dims(weight_map, axis=-1), [1, n_classes])
        dice_numerator = 2.0 * tf.sparse_reduce_sum(
            weight_map_nclasses * one_hot * prediction, reduction_axes=[0])
        dice_denominator = \
            tf.reduce_sum(weight_map_nclasses * tf.square(prediction),
                          reduction_indices=[0]) + \
            tf.sparse_reduce_sum(one_hot * weight_map_nclasses,
                                 reduction_axes=[0])
    else:
        dice_numerator = 2.0 * tf.sparse_reduce_sum(
            one_hot * prediction, reduction_axes=[0])
        dice_denominator = \
            tf.reduce_sum(tf.square(prediction), reduction_indices=[0]) + \
            tf.sparse_reduce_sum(one_hot, reduction_axes=[0])
    epsilon_denominator = 0.00001

    dice_score = dice_numerator / (dice_denominator + epsilon_denominator)
    # minimising (1 - dice_coefficients)
    return 1.0 - tf.reduce_mean(dice_score)
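
# A minimal usage sketch, not used by NiftyNet itself: unlike the other
# losses in this module, `dice` builds its one-hot encoding from dynamic
# shapes, so it should also accept placeholders of unknown batch size.
def _demo_dice():
    logits = tf.placeholder(tf.float32, shape=[None, 3])
    labels = tf.placeholder(tf.int64, shape=[None])
    loss = dice(logits, labels)
    with tf.Session() as sess:
        print('Dice loss:', sess.run(loss, feed_dict={
            logits: np.random.randn(10, 3),
            labels: np.random.randint(0, 3, size=10)}))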