# Source code for niftynet.engine.handler_performance
# -*- coding: utf-8 -*-
"""
This module tracks model validation performance over training
"""
import tensorflow as tf
from niftynet.engine.application_variables import CONSOLE
from niftynet.engine.signal import ITER_FINISHED
class PerformanceLogger(object):
    """
    This class handles iteration events to store the current performance as
    an attribute of the sender (i.e. application).

    The sender's ``performance_history`` list is used as a sliding window:
    new validation losses are appended until the list holds ``patience``
    entries, after which the oldest entry is dropped for each new one.
    """

    def __init__(self, **_unused):
        # Subscribe so that every finished iteration is inspected.
        ITER_FINISHED.connect(self.update_performance_history)

    def update_performance_history(self, _sender, **msg):
        """
        Record the current validation loss in the sender's history.

        Non-validation iterations are ignored.  For a validation iteration
        the ``'total_loss'`` entry of the console output is appended to
        ``_sender.performance_history``; once the history already holds
        ``_sender.patience`` entries, the oldest one is discarded.  If the
        loss (or a required attribute) is not available, a warning is
        logged and the history is left unchanged.

        :param _sender: object exposing ``performance_history`` (a list)
            and ``patience``
        :param msg: keyword arguments; ``msg['iter_msg']`` is the
            iteration message instance
        :return: None
        """
        iter_msg = msg['iter_msg']
        if not iter_msg.is_validation:
            return

        try:
            # Fall back to an empty dict (not a string): looking up
            # 'total_loss' on a str raises TypeError, which is not handled
            # below and would abort the run instead of logging a warning.
            console_content = iter_msg.current_iter_output.get(CONSOLE, {})
            current_loss = console_content['total_loss']

            history = _sender.performance_history
            if len(history) < _sender.patience:
                history.append(current_loss)
            else:
                # Window is full: drop the oldest loss, keep the newest.
                _sender.performance_history = history[1:] + [current_loss]
        except (AttributeError, KeyError):
            tf.logging.warning("does not contain any performance field "
                               "called total loss.")