# -*- coding: utf-8 -*-
"""
Message stores status info of the current iteration.
"""
import itertools
import time
from niftynet.engine.application_variables import CONSOLE, TF_SUMMARIES
from niftynet.engine.signal import TRAIN, VALID, INFER
from niftynet.utilities.util_common import look_up_operations
# Template for one console line: phase, iteration index, formatted console
# variables and the iteration duration in seconds.
# ``{:.3f}`` prints the duration with three decimals; the previous ``{:3f}``
# only set a minimum field width of 3 and still printed six decimals.
CONSOLE_FORMAT = "{} iter {}, {} ({:.3f}s)"
# The only phase values accepted by ``IterationMessage.phase``.
SUPPORTED_PHASES = set([TRAIN, VALID, INFER])
class IterationMessage(object):
    """
    This class consists of network variables and operations at each iteration.
    It is generated by the application engine but can be modified by the
    application as well.
    """
    # Class-level defaults; the setters/lazy getters below rebind these
    # names on the instance, so no mutable state is shared between messages.
    _current_iter = 0
    _current_iter_tic = 0
    _current_iter_toc = 0
    _current_iter_output = None
    _data_feed_dict = None
    _ops_to_run = None
    _phase = TRAIN
    _should_stop = None

    @property
    def current_iter(self):
        """
        Current iteration index; can be used to create complex schedules
        for the iterative training/validation/inference procedure.

        :return: integer of iteration
        """
        return self._current_iter

    @current_iter.setter
    def current_iter(self, value):
        # Setting the index starts a new iteration: record the start time
        # and discard any output left over from the previous iteration.
        self._current_iter = int(value)
        self._current_iter_tic = time.time()
        self._current_iter_output = None

    @property
    def ops_to_run(self):
        """
        Operations (tf graph elements) to be fed into
        ``session.run(...)``. This is currently mainly used
        for passing network gradient update ops to ``session.run``.
        To modify the operations, either mutate the returned dictionary
        in place or assign a new dictionary to ``self.ops_to_run``.

        :return: the operation dictionary itself (not a copy); an empty
            dictionary is created on first access
        """
        if self._ops_to_run is None:
            self._ops_to_run = {}
        # The setter accepts any value, so the type is checked on read.
        assert isinstance(self._ops_to_run, dict), \
            'ops to run should be a dictionary'
        return self._ops_to_run

    @ops_to_run.setter
    def ops_to_run(self, value):
        self._ops_to_run = value

    @property
    def data_feed_dict(self):
        """
        A dictionary that maps graph elements to values
        to be fed into ``session.run(...)`` as the ``feed_dict`` parameter.

        :return: the feed dictionary itself (not a copy); an empty
            dictionary is created on first access
        """
        if self._data_feed_dict is None:
            self._data_feed_dict = {}
        return self._data_feed_dict

    @data_feed_dict.setter
    def data_feed_dict(self, value):
        assert isinstance(value, dict), \
            'data_feed_dict should be a dictionary of placeholders:values'
        self._data_feed_dict = value

    @property
    def current_iter_output(self):
        """
        This property stores graph output received
        by running ``session.run()``.

        :return: the stored output, or None if it has not been set since
            the last ``current_iter`` update
        """
        return self._current_iter_output

    @current_iter_output.setter
    def current_iter_output(self, value):
        # Recording the output marks the end of the iteration (see
        # ``iter_duration``).
        self._current_iter_output = value
        self._current_iter_toc = time.time()

    @property
    def should_stop(self):
        """
        The engine checks this property after each iteration.
        It can be set by the application (e.g. in
        ``application.set_iteration_update()``)
        to create training schedules such as early stopping.

        :return: None or a handler that requested to stop the loop
        """
        return self._should_stop

    @should_stop.setter
    def should_stop(self, value):
        self._should_stop = value

    @property
    def phase(self):
        """
        A value indicating the phase: train, validation or inference.

        :return: the phase, validated against ``SUPPORTED_PHASES`` by the
            setter (defaults to ``TRAIN``)
        """
        return self._phase

    @phase.setter
    def phase(self, value):
        self._phase = look_up_operations(value, SUPPORTED_PHASES)

    @property
    def is_training(self):
        """
        :return: boolean value indicating if the phase is training
        """
        return self.phase == TRAIN

    @property
    def is_validation(self):
        """
        :return: boolean value indicating if the phase is validation
        """
        return self.phase == VALID

    @property
    def is_inference(self):
        """
        :return: boolean value indicating if the phase is inference
        """
        return self.phase == INFER

    @property
    def iter_duration(self):
        """
        Time used from setting ``self.current_iter`` to setting
        ``self.current_iter_output``.

        :return: duration of the iteration in seconds; 0 when the output
            has not been set after the latest ``current_iter`` update
        """
        # A toc older than tic belongs to a previous iteration; clamp to 0.
        current_toc = max(self._current_iter_toc, self._current_iter_tic)
        return current_toc - self._current_iter_tic

    def to_console_string(self):
        """
        Convert ``current_iter_output`` to a string for console display.

        :return: summary string formatted with ``CONSOLE_FORMAT`` (indented
            for validation iterations)
        """
        summary_indentation = " " if self.is_validation else ""
        summary_format = summary_indentation + CONSOLE_FORMAT
        try:
            console_content = self.current_iter_output.get(CONSOLE, '')
        except AttributeError:
            # Output is not dict-like (e.g. still None): print a usage hint.
            console_content = "print to console -- set current_iter_output " \
                              "to a dictionary of {CONSOLE: 'content'}."
        result_str = _console_vars_to_str(console_content)
        summary = summary_format.format(
            self.phase, self.current_iter, result_str, self.iter_duration)
        return summary

    def to_tf_summary(self, writer=None):
        """
        Convert ``current_iter_output`` to a tf summary and write it to
        ``writer``. Does nothing if there is no writer or no summary.

        :param writer: writer instance for summary output
        :return: None
        """
        if writer is None:
            return
        try:
            summary = self.current_iter_output.get(TF_SUMMARIES, {})
        except AttributeError:
            # Output is not dict-like (e.g. still None): nothing to write.
            summary = None
        if not summary:
            return
        writer.add_summary(summary, self.current_iter)
class IterationMessageGenerator(object):
    """
    Factory of iteration-message sequences.

    Calling an instance returns a generator of ``IterationMessage`` objects:
    training iterations interleaved with validation iterations when
    ``is_training_action`` is truthy, otherwise an endless inference
    sequence.
    """

    def __init__(self,
                 initial_iter=0,
                 final_iter=0,
                 validation_every_n=0,
                 validation_max_iter=0,
                 is_training_action=True,
                 **_unused):
        # The start index is clamped to -1 (the training generator begins at
        # ``initial_iter + 1``), and the end is never before the start.
        start = max(initial_iter, -1)
        self.initial_iter = start
        self.final_iter = max(final_iter, start)
        self.validation_every_n = validation_every_n
        self.validation_max_iter = validation_max_iter
        self.is_training_action = is_training_action

    def __call__(self):
        if self.is_training_action:
            return _train_iter_generator(
                initial_iter=self.initial_iter,
                final_iter=self.final_iter,
                validation_every_n=self.validation_every_n,
                validation_max_iter=self.validation_max_iter)
        return _infer_iter_generator()
def _infer_iter_generator():
    """
    Yield an unbounded sequence of inference iteration messages,
    numbered 0, 1, 2, ...

    :return: iteration message instances
    """
    for message in _iter_msg_generator(itertools.count(), INFER):
        yield message
def _train_iter_generator(initial_iter=0,
                          final_iter=0,
                          validation_every_n=0,
                          validation_max_iter=0):
    """
    Yield training iteration messages numbered ``initial_iter + 1`` up to
    and including ``final_iter``. After each positive training iteration
    that is a multiple of ``validation_every_n``, yield
    ``validation_max_iter`` validation messages carrying the same
    iteration number.

    :param initial_iter: starting iteration of the training sequence
    :param final_iter: ending iteration of the training sequence
    :param validation_every_n: validation at every n training iterations
        (values <= 0 disable validation)
    :param validation_max_iter: number of validation iterations per round
    :return: iteration message instances
    """
    def _due_for_validation(iter_index):
        return (iter_index > 0 and validation_every_n > 0 and
                iter_index % validation_every_n == 0)

    train_range = range(initial_iter + 1, final_iter + 1)
    for train_msg in _iter_msg_generator(train_range, TRAIN):
        yield train_msg
        iter_index = train_msg.current_iter
        if not _due_for_validation(iter_index):
            continue
        # Validation messages reuse the training iteration number rather
        # than advancing it.
        repeated_index = [iter_index] * validation_max_iter
        for valid_msg in _iter_msg_generator(repeated_index, VALID):
            yield valid_msg
def _iter_msg_generator(count_generator, phase):
    """
    Yield one fresh ``IterationMessage`` per number produced by
    ``count_generator``, each tagged with ``phase``.

    :param count_generator: iterable yielding iteration numbers
    :param phase: one of TRAIN, VALID or INFER
    :return: iteration message instances
    """
    for iter_index in count_generator:
        message = IterationMessage()
        # Index first (this also stamps the start time), then the phase.
        message.current_iter = iter_index
        message.phase = phase
        yield message
def _console_vars_to_str(console_dict):
"""
Printing values of variable evaluations to command line output.
"""
if not console_dict:
return ''
if isinstance(console_dict, dict):
console_str = ', '.join('{}={}'.format(key, val)
for (key, val) in console_dict.items())
else:
console_str = '{}'.format(console_dict)
return console_str