# -*- coding: utf-8 -*-
"""
This module defines a general procedure for running applications.
Example usage::

    app_driver = ApplicationDriver()
    app_driver.initialise_application(system_param, input_data_param)
    app_driver.run(app_driver.app)

``system_param`` and ``input_data_param`` should be generated using:
``niftynet.utilities.user_parameters_parser.run()``
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import tensorflow as tf
from niftynet.engine.application_factory import \
ApplicationFactory, EventHandlerFactory, IteratorFactory
from niftynet.engine.application_iteration import IterationMessage
from niftynet.engine.application_variables import \
GradientsCollector, OutputsCollector
from niftynet.engine.signal import TRAIN, \
ITER_STARTED, ITER_FINISHED, GRAPH_CREATED, SESS_FINISHED, SESS_STARTED
from niftynet.io.image_sets_partitioner import ImageSetsPartitioner
from niftynet.io.misc_io import infer_latest_model_file
from niftynet.utilities.user_parameters_default import \
DEFAULT_EVENT_HANDLERS, DEFAULT_ITERATION_GENERATOR
from niftynet.utilities.util_common import \
set_cuda_device, tf_config, device_string
from niftynet.utilities.util_common import traverse_nested
# pylint: disable=too-many-instance-attributes
class ApplicationDriver(object):
    """
    This class initialises an application by building a TF graph,
    and maintaining a session. It controls the
    starting/stopping of an application. Applications should be
    implemented by inheriting ``niftynet.application.base_application``
    to be compatible with this driver.
    """

    def __init__(self):
        # NOTE: the attribute names of this instance are part of its
        # runtime interface -- ``vars(self)`` is splatted as keyword
        # arguments into the iteration generator (see ``run``) and into
        # every event handler constructor (see ``load_event_handlers``).
        self.app = None                    # application instance
        self.is_training_action = True     # True for 'train', False otherwise
        self.num_threads = 0               # sampler threads (training only)
        self.num_gpus = 0                  # devices used for data parallelism
        self.model_dir = None              # folder holding models/logs
        self.max_checkpoints = 2           # checkpoint files to keep
        self.save_every_n = 0              # checkpoint frequency (iterations)
        self.tensorboard_every_n = -1      # summary frequency (-1: disabled)
        self.vars_to_restore = ''          # which variables to restore
        self.initial_iter = 0              # first iteration index
        self.final_iter = 0                # last training iteration index
        self.validation_every_n = -1       # validation frequency (-1: disabled)
        self.validation_max_iter = 1       # iterations per validation run
        self.data_partitioner = ImageSetsPartitioner()

        self._event_handlers = None        # dict of handler instances, keyed by class
        self._generator = None             # iteration-message generator class

    def initialise_application(self, workflow_param, data_param=None):
        """
        This function receives all parameters from user config file,
        create an instance of application.

        :param workflow_param: a dictionary of user parameters,
            keys correspond to sections in the config file
        :param data_param: a dictionary of input image parameters,
            keys correspond to data properties to be used by image_reader
        :return:
        """
        try:
            system_param = workflow_param.get('SYSTEM', None)
            net_param = workflow_param.get('NETWORK', None)
            train_param = workflow_param.get('TRAINING', None)
            infer_param = workflow_param.get('INFERENCE', None)
            app_param = workflow_param.get('CUSTOM', None)
        except AttributeError:
            tf.logging.fatal('parameters should be dictionaries')
            raise

        assert os.path.exists(system_param.model_dir), \
            'Model folder not exists {}'.format(system_param.model_dir)
        self.model_dir = system_param.model_dir
        # the action string selects training vs. inference mode
        self.is_training_action = TRAIN.startswith(system_param.action.lower())

        # hardware-related parameters
        # inference is capped at one sampler thread and at most one GPU
        self.num_threads = max(system_param.num_threads, 1) \
            if self.is_training_action else 1
        self.num_gpus = system_param.num_gpus \
            if self.is_training_action else min(system_param.num_gpus, 1)
        set_cuda_device(system_param.cuda_devices)

        # set training params.
        if self.is_training_action:
            assert train_param, 'training parameters not specified'
            self.initial_iter = train_param.starting_iter
            # final_iter is clamped so resuming past max_iter is a no-op
            self.final_iter = max(train_param.max_iter, self.initial_iter)
            self.save_every_n = train_param.save_every_n
            self.tensorboard_every_n = train_param.tensorboard_every_n
            self.max_checkpoints = max(self.max_checkpoints,
                                       train_param.max_checkpoints)
            self.validation_every_n = train_param.validation_every_n
            # ``vars_to_restore`` is optional in older configs; guard with hasattr
            self.vars_to_restore = train_param.vars_to_restore \
                if hasattr(train_param, 'vars_to_restore') else ''
            if self.validation_every_n > 0:
                self.validation_max_iter = max(self.validation_max_iter,
                                               train_param.validation_max_iter)
            action_param = train_param
        else:  # set inference params.
            assert infer_param, 'inference parameters not specified'
            self.initial_iter = infer_param.inference_iter
            action_param = infer_param

        # infer the initial iteration from model files
        # (a negative inference_iter means "use the latest checkpoint")
        if self.initial_iter < 0:
            self.initial_iter = infer_latest_model_file(
                os.path.join(self.model_dir, 'models'))

        # create an application instance
        assert app_param, 'application specific param. not specified'
        app_module = ApplicationFactory.create(app_param.name)
        self.app = app_module(net_param, action_param, system_param.action)

        # clear the cached file lists
        self.data_partitioner.reset()
        if data_param:
            # a new train/validation/inference partition is only generated
            # when training, no split file exists yet, and a non-zero
            # exclusion fraction is requested
            do_new_partition = \
                self.is_training_action and \
                (not os.path.isfile(system_param.dataset_split_file)) and \
                (train_param.exclude_fraction_for_validation > 0 or
                 train_param.exclude_fraction_for_inference > 0)
            data_fractions = (train_param.exclude_fraction_for_validation,
                              train_param.exclude_fraction_for_inference) \
                if do_new_partition else None
            self.data_partitioner.initialise(
                data_param=data_param,
                new_partition=do_new_partition,
                ratios=data_fractions,
                data_split_file=system_param.dataset_split_file)
            # requesting validation without a validation split is a
            # configuration error -- fail early with guidance
            assert self.data_partitioner.has_validation or \
                self.validation_every_n <= 0, \
                'validation_every_n is set to {}, ' \
                'but train/validation splitting not available.\nPlease ' \
                'check dataset partition list {} ' \
                '(remove file to generate a new dataset partition), ' \
                'check "exclude_fraction_for_validation" ' \
                '(current config value: {}).\nAlternatively, ' \
                'set "validation_every_n" to -1.'.format(
                    self.validation_every_n,
                    system_param.dataset_split_file,
                    train_param.exclude_fraction_for_validation)

        # initialise readers
        self.app.initialise_dataset_loader(
            data_param, app_param, self.data_partitioner)

        # make the list of initialised event handler instances.
        self.load_event_handlers(
            system_param.event_handler or DEFAULT_EVENT_HANDLERS)
        self._generator = IteratorFactory.create(
            system_param.iteration_generator or DEFAULT_ITERATION_GENERATOR)

    def run(self, application, graph=None):
        """
        Initialise a TF graph, connect data sampler and network within
        the graph context, run training loops or inference loops.

        The main loop terminates normally, on ``KeyboardInterrupt``, or
        when the input ``Dataset`` is exhausted; ``SESS_FINISHED`` is
        broadcast and ``application.stop()`` is called in all cases.

        :param application: a niftynet application
        :param graph: default base graph to run the application
        :return:
        """
        if graph is None:
            graph = ApplicationDriver.create_graph(
                application=application,
                num_gpus=self.num_gpus,
                num_threads=self.num_threads,
                is_training_action=self.is_training_action)

        start_time = time.time()
        # loop_status is shared with ``loop`` so the last iteration index
        # survives an abnormal exit
        loop_status = {'current_iter': self.initial_iter, 'normal_exit': False}
        with tf.Session(config=tf_config(), graph=graph):
            try:
                # broadcasting event of session started
                SESS_STARTED.send(application, iter_msg=None)

                # create a iteration message generator and
                # iteratively run the graph (the main engine loop)
                # NOTE: ``vars(self)`` supplies the generator's kwargs,
                # so attribute names above must match its signature
                iteration_messages = self._generator(**vars(self))()
                ApplicationDriver.loop(
                    application=application,
                    iteration_messages=iteration_messages,
                    loop_status=loop_status)
            except KeyboardInterrupt:
                tf.logging.warning('User cancelled application')
            except (tf.errors.OutOfRangeError, EOFError):
                if not loop_status.get('normal_exit', False):
                    # reached the end of inference Dataset
                    loop_status['normal_exit'] = True
            except RuntimeError:
                # print the traceback but fall through to the cleanup below
                import sys
                import traceback
                exc_type, exc_value, exc_traceback = sys.exc_info()
                traceback.print_exception(
                    exc_type, exc_value, exc_traceback, file=sys.stdout)
            finally:
                tf.logging.info('cleaning up...')
                # broadcasting session finished event
                iter_msg = IterationMessage()
                iter_msg.current_iter = loop_status.get('current_iter', -1)
                SESS_FINISHED.send(application, iter_msg=iter_msg)
                application.stop()
                if not loop_status.get('normal_exit', False):
                    # loop didn't finish normally
                    tf.logging.warning('stopped early, incomplete iterations.')
                tf.logging.info(
                    "%s stopped (time in second %.2f).",
                    type(application).__name__, (time.time() - start_time))

    # pylint: disable=not-context-manager
    @staticmethod
    def create_graph(
            application, num_gpus=1, num_threads=1, is_training_action=False):
        """
        Create a TF graph based on self.app properties
        and engine parameters.

        Broadcasts ``GRAPH_CREATED`` once the graph is fully built.

        :return:
        """
        graph = tf.Graph()
        main_device = device_string(num_gpus, 0, False, is_training_action)
        outputs_collector = OutputsCollector(n_devices=max(num_gpus, 1))
        gradients_collector = GradientsCollector(n_devices=max(num_gpus, 1))

        # start constructing the graph, handling training and inference cases
        with graph.as_default(), tf.device(main_device):

            # initialise sampler
            with tf.name_scope('Sampler'):
                application.initialise_sampler()
                for sampler in traverse_nested(application.get_sampler()):
                    sampler.set_num_threads(num_threads)

            # initialise network, these are connected in
            # the context of multiple gpus
            application.initialise_network()
            application.add_validation_flag()

            # for data parallelism --
            # defining and collecting variables from multiple devices
            for gpu_id in range(0, max(num_gpus, 1)):
                worker_device = device_string(
                    num_gpus, gpu_id, True, is_training_action)
                scope_string = 'worker_{}'.format(gpu_id)
                with tf.name_scope(scope_string), tf.device(worker_device):
                    # setup network for each of the multiple devices
                    application.connect_data_and_network(
                        outputs_collector, gradients_collector)

            with tf.name_scope('MergeOutputs'):
                outputs_collector.finalise_output_op()

            application.outputs_collector = outputs_collector
            application.gradients_collector = gradients_collector
            GRAPH_CREATED.send(application, iter_msg=None)
        return graph

    def load_event_handlers(self, names):
        """
        Import event handler modules and create a list of handler instances.
        The event handler instances will be stored with this engine.

        :param names: strings of event handlers
        :return:
        """
        if not names:
            return
        if self._event_handlers:
            # disconnect all handlers (assuming always weak connection)
            for handler in list(self._event_handlers):
                del self._event_handlers[handler]
        self._event_handlers = {}
        for name in set(names):
            the_event_class = EventHandlerFactory.create(name)
            # initialise all registered event handler classes
            # (each handler receives the full engine state as kwargs)
            engine_config_dict = vars(self)
            key = '{}'.format(the_event_class)
            self._event_handlers[key] = the_event_class(**engine_config_dict)

    @staticmethod
    def loop(application,
             iteration_messages=(),
             loop_status=None):
        """
        Running ``loop_step`` with ``IterationMessage`` instances
        generated by ``iteration_generator``.

        This loop stops when any of the condition satisfied:
            1. no more element from the ``iteration_generator``;
            2. ``application.interpret_output`` returns False;
            3. any exception raised.

        Broadcasting SESS_* signals at the beginning and end of this method.

        This function should be used in a context of
        ``tf.Session`` or ``session.as_default()``.

        :param application: a niftynet.application instance, application
            will provides ``tensors`` to be fetched by ``tf.session.run()``.
        :param iteration_messages:
            a generator of ``engine.IterationMessage`` instances
        :param loop_status: optional dictionary used to capture the loop status,
            useful when the loop exited in an unexpected manner.
        :return:
        """
        loop_status = loop_status or {}
        for iter_msg in iteration_messages:
            # record the current index so callers can see where an
            # abnormal exit happened
            loop_status['current_iter'] = iter_msg.current_iter

            # run an iteration
            ApplicationDriver.loop_step(application, iter_msg)

            # Checking stopping conditions
            if iter_msg.should_stop:
                tf.logging.info('stopping -- event handler: %s.',
                                iter_msg.should_stop)
                break
        # loop finished without any exception
        loop_status['normal_exit'] = True

    @staticmethod
    def loop_step(application, iteration_message):
        """
        Calling ``tf.session.run`` with parameters encapsulated in
        iteration message as an iteration.
        Broadcasting ITER_* events before and afterward.

        :param application:
        :param iteration_message: an ``engine.IterationMessage`` instances
        :return:
        """
        # broadcasting event of starting an iteration
        ITER_STARTED.send(application, iter_msg=iteration_message)

        # ``iter_msg.ops_to_run`` are populated with the ops to run in
        # each iteration, fed into ``session.run()`` and then
        # passed to the application (and observers) for interpretation.
        sess = tf.get_default_session()
        assert sess, 'method should be called within a TF session context.'
        iteration_message.current_iter_output = sess.run(
            iteration_message.ops_to_run,
            feed_dict=iteration_message.data_feed_dict)

        # broadcasting event of finishing an iteration
        ITER_FINISHED.send(application, iter_msg=iteration_message)