Source code for niftynet.layer.loss_gan
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function

import tensorflow as tf

from niftynet.engine.application_factory import LossGANFactory
from niftynet.layer.base_layer import Layer


class LossFunction(Layer):
    """Composes the generator and discriminator loss terms for GAN training."""

    def __init__(self,
                 loss_type='CrossEntropy',
                 loss_func_params=None,
                 name='loss_function'):
        super(LossFunction, self).__init__(name=name)
        # optional keyword arguments forwarded to each loss term
        if loss_func_params is not None:
            self._loss_func_params = loss_func_params
        else:
            self._loss_func_params = {}
        self._data_loss_func = None
        self.make_callable_loss_func(loss_type)

    def make_callable_loss_func(self, type_str):
        self._data_loss_func = LossGANFactory.create(type_str)

    def layer_op(self, pred_real, pred_fake, var_scope=None):
        with tf.device('/cpu:0'):
            # generator loss: computed from the discriminator's output on fakes
            g_loss = self._data_loss_func['g'](
                pred_fake,
                **self._loss_func_params)
            # discriminator loss: penalise fakes scored as real
            # and reals scored as fake
            d_fake = self._data_loss_func['d_fake'](
                pred_fake,
                **self._loss_func_params)
            d_real = self._data_loss_func['d_real'](
                pred_real,
                **self._loss_func_params)
            return g_loss, (d_fake + d_real)
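

# A minimal usage sketch (hedged): `pred_real` and `pred_fake` are assumed to
# be discriminator logit tensors produced elsewhere by the GAN application.
#
#     loss_layer = LossFunction(loss_type='CrossEntropy')
#     g_loss, d_loss = loss_layer(pred_real, pred_fake)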


def cross_entropy_function(is_real, softness=.1):
    # Builds a sigmoid cross-entropy term against a constant target:
    # (1 - softness) for "real" predictions, softness for "fake" ones.
    def cross_entropy_op(pred, **kwargs):
        if is_real:
            target = (1. - softness) * tf.ones_like(pred)
        else:
            target = softness * tf.ones_like(pred)
        entropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=pred,
                                                          labels=target)
        return tf.reduce_mean(entropy)

    return cross_entropy_op

# generator and d_fake terms use hard targets (1.0 and 0.0);
# d_real applies one-sided label smoothing with a target of 0.9
cross_entropy = {'g': cross_entropy_function(True, 0),
                 'd_fake': cross_entropy_function(False, 0),
                 'd_real': cross_entropy_function(True, .1)}
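
As a quick check of how these terms behave, the following self-contained sketch evaluates them on made-up logits; it assumes NiftyNet's TensorFlow 1.x graph/session semantics and that this module is importable as niftynet.layer.loss_gan.

# Hedged example: made-up logits, TensorFlow 1.x session semantics assumed.
import tensorflow as tf

from niftynet.layer.loss_gan import cross_entropy

fake_logits = tf.constant([[-2.0], [0.5]])   # discriminator scores for generated samples
real_logits = tf.constant([[1.5], [3.0]])    # discriminator scores for real samples

g_loss = cross_entropy['g'](fake_logits)           # generator: push fakes towards target 1.0
d_loss = (cross_entropy['d_fake'](fake_logits)     # discriminator: fakes towards 0.0 ...
          + cross_entropy['d_real'](real_logits))  # ... and reals towards the smoothed 0.9

with tf.Session() as sess:
    print(sess.run([g_loss, d_loss]))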