# -*- coding: utf-8 -*-
"""
This module defines a general procedure for running applications.

Example usage::

    app_driver = ApplicationDriver()
    app_driver.initialise_application(system_param, input_data_param)
    app_driver.run_application()

``system_param`` and ``input_data_param`` should be generated using:
``niftynet.utilities.user_parameters_parser.run()``
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time
import itertools

import tensorflow as tf
from blinker import signal

from niftynet.engine.application_factory import ApplicationFactory
from niftynet.engine.application_iteration import IterationMessage
from niftynet.engine.application_variables import \
    CONSOLE, NETWORK_OUTPUT, TF_SUMMARIES
from niftynet.engine.application_variables import \
    GradientsCollector, OutputsCollector, global_vars_init_or_restore
from niftynet.io.image_sets_partitioner import ImageSetsPartitioner
from niftynet.io.image_sets_partitioner import TRAIN, VALID, INFER
from niftynet.io.misc_io import get_latest_subfolder, touch_folder
from niftynet.layer.bn import BN_COLLECTION
from niftynet.utilities.util_common import set_cuda_device, traverse_nested

FILE_PREFIX = 'model.ckpt'

# pylint: disable=too-many-instance-attributes
class ApplicationDriver(object):
    """
    This class initialises an application by building a TF graph,
    and maintaining a session and coordinator. It controls the
    starting/stopping of an application. Applications should be
    implemented by inheriting ``niftynet.application.base_application``
    to be compatible with this driver.
    """
    # pylint: disable=too-many-instance-attributes

    pre_train_iter = signal('pre_train_iter')
    post_train_iter = signal('post_train_iter')
    pre_validation_iter = signal('pre_validation_iter')
    post_validation_iter = signal('post_validation_iter')
    pre_infer_iter = signal('pre_infer_iter')
    post_infer_iter = signal('post_infer_iter')
    post_training = signal('post_training')

    def __init__(self):
        self.app = None
        self.graph = tf.Graph()
        self.saver = None

        self.is_training = True
        self.num_threads = 0
        self.num_gpus = 0

        self.model_dir = None
        self.summary_dir = None
        self.session_prefix = None
        self.max_checkpoints = 2
        self.save_every_n = 0
        self.tensorboard_every_n = -1
        self.validation_every_n = -1
        self.validation_max_iter = 1

        self.initial_iter = 0
        self.final_iter = 0

        self._coord = tf.train.Coordinator()
        self._init_op = None
        self._data_partitioner = None
        self.outputs_collector = None
        self.gradients_collector = None

        self.console = None
        self.tensorboard = None
        self.model_saver = None

    def initialise_application(self, workflow_param, data_param):
        """
        This function receives all parameters from the user config file
        and creates an instance of the application.

        :param workflow_param: a dictionary of user parameters,
            keys correspond to sections in the config file
        :param data_param: a dictionary of input image parameters,
            keys correspond to data properties to be used by image_reader
        :return:
        """
        try:
            system_param = workflow_param.get('SYSTEM', None)
            net_param = workflow_param.get('NETWORK', None)
            train_param = workflow_param.get('TRAINING', None)
            infer_param = workflow_param.get('INFERENCE', None)
            app_param = workflow_param.get('CUSTOM', None)
        except AttributeError:
            tf.logging.fatal('parameters should be dictionaries')
            raise

        assert os.path.exists(system_param.model_dir), \
            'Model folder does not exist: {}'.format(system_param.model_dir)
        self.is_training = (system_param.action == "train")

        # hardware-related parameters
        self.num_threads = max(system_param.num_threads, 1) \
            if self.is_training else 1
        self.num_gpus = system_param.num_gpus \
            if self.is_training else min(system_param.num_gpus, 1)
        set_cuda_device(system_param.cuda_devices)

        # set output TF model folders
        self.model_dir = touch_folder(
            os.path.join(system_param.model_dir, 'models'))
        self.session_prefix = os.path.join(self.model_dir, FILE_PREFIX)

        # set training params.
        if self.is_training:
            assert train_param, 'training parameters not specified'
            summary_root = os.path.join(system_param.model_dir, 'logs')
            self.summary_dir = get_latest_subfolder(
                summary_root, create_new=train_param.starting_iter == 0)

            self.initial_iter = train_param.starting_iter
            self.final_iter = max(train_param.max_iter, self.initial_iter)
            self.save_every_n = train_param.save_every_n
            self.tensorboard_every_n = train_param.tensorboard_every_n
            self.max_checkpoints = \
                max(train_param.max_checkpoints, self.max_checkpoints)
            self.gradients_collector = GradientsCollector(
                n_devices=max(self.num_gpus, 1))
            self.validation_every_n = train_param.validation_every_n
            if self.validation_every_n > 0:
                self.validation_max_iter = max(
                    self.validation_max_iter,
                    train_param.validation_max_iter)
            action_param = train_param
        else:
            # set inference params.
            assert infer_param, 'inference parameters not specified'
            self.initial_iter = infer_param.inference_iter
            action_param = infer_param

        self.outputs_collector = OutputsCollector(
            n_devices=max(self.num_gpus, 1))

        # create an application instance
        assert app_param, 'application specific param. not specified'
        app_module = ApplicationDriver._create_app(app_param.name)
        self.app = app_module(net_param, action_param, system_param.action)

        # initialise data input
        data_partitioner = ImageSetsPartitioner()
        # clear the cached file lists
        data_partitioner.reset()
        do_new_partition = \
            self.is_training and self.initial_iter == 0 and \
            (not os.path.isfile(system_param.dataset_split_file)) and \
            (train_param.exclude_fraction_for_validation > 0 or
             train_param.exclude_fraction_for_inference > 0)
        data_fractions = None
        if do_new_partition:
            assert train_param.exclude_fraction_for_validation > 0 or \
                self.validation_every_n <= 0, \
                'validation_every_n is set to {}, ' \
                'but train/validation splitting not available,\nplease ' \
                'check "exclude_fraction_for_validation" in the config ' \
                'file (current config value: {}).'.format(
                    self.validation_every_n,
                    train_param.exclude_fraction_for_validation)
            data_fractions = (train_param.exclude_fraction_for_validation,
                              train_param.exclude_fraction_for_inference)

        if data_param:
            data_partitioner.initialise(
                data_param=data_param,
                new_partition=do_new_partition,
                ratios=data_fractions,
                data_split_file=system_param.dataset_split_file)

        if data_param and self.is_training and self.validation_every_n > 0:
            assert data_partitioner.has_validation, \
                'validation_every_n is set to {}, ' \
                'but train/validation splitting not available.\nPlease ' \
                'check dataset partition list {} ' \
                '(remove file to generate a new dataset partition). ' \
                'Or set validation_every_n to -1.'.format(
                    self.validation_every_n, system_param.dataset_split_file)

        # initialise readers
        self.app.initialise_dataset_loader(
            data_param, app_param, data_partitioner)
        self._data_partitioner = data_partitioner

        # pylint: disable=not-context-manager
        with self.graph.as_default(), tf.name_scope('Sampler'):
            self.app.initialise_sampler()

    def _run_sampler_threads(self, session=None):
        """
        Get samplers from the application and try to run sampler threads.

        Note: override the ``app.get_sampler()`` method to return None
        to bypass this step.

        :param session: TF session used to fill tf.placeholders with
            sampled data
        :return:
        """
        if session is None:
            return
        if self._coord is None:
            return
        if self.num_threads <= 0:
            return
        try:
            samplers = self.app.get_sampler()
            for sampler in traverse_nested(samplers):
                if sampler is None:
                    continue
                sampler.run_threads(session, self._coord, self.num_threads)
            tf.logging.info('Filling queues (this can take a few minutes)')
        except (NameError, TypeError, AttributeError, IndexError):
            tf.logging.fatal(
                "samplers not running, pop_batch_op operations "
                "are blocked.")
            raise

    def run_application(self):
        """
        Initialise a TF graph, connect the data sampler and network within
        the graph context, and run training or inference loops.

        The training loop terminates when ``self.final_iter`` is reached.
        The inference loop terminates when there are no more image samples
        to be processed by the image reader.

        :return:
        """
        config = ApplicationDriver._tf_config()
        with tf.Session(config=config, graph=self.graph) as session:

            # start samplers' threads
            self._run_sampler_threads(session=session)
            self.graph = self._create_graph(self.graph)

            # check app variables initialised and ready to start
            self.app.check_initialisations()

            # initialise network trainable parameters
            self._rand_init_or_restore_vars(session)
            start_time = time.time()
            loop_status = {}
            try:
                # iteratively run the graph
                if self.is_training:
                    self.model_saver = ModelSaver(session,
                                                  self.saver,
                                                  self.save_every_n,
                                                  self.session_prefix)
                    loop_status['current_iter'] = self.initial_iter
                    self._training_loop(session, loop_status)
                else:
                    loop_status['all_saved_flag'] = False
                    self._inference_loop(session, loop_status)
            except KeyboardInterrupt:
                tf.logging.warning('User cancelled application')
            except tf.errors.OutOfRangeError:
                if loop_status.get('all_saved_flag', None) is not None:
                    # reached the end of inference Dataset
                    loop_status['all_saved_flag'] = True
            except RuntimeError:
                import sys
                import traceback
                exc_type, exc_value, exc_traceback = sys.exc_info()
                traceback.print_exception(
                    exc_type, exc_value, exc_traceback, file=sys.stdout)
            finally:
                tf.logging.info('Cleaning up...')
                if self.is_training:
                    # saving model at the last iteration
                    iter_msg = IterationMessage()
                    iter_msg.current_iter = loop_status.get('current_iter', -1)
                    self.post_training.send(iter_msg)
                elif not loop_status.get('all_saved_flag', None):
                    tf.logging.warning('stopped early, incomplete loops')

                tf.logging.info('stopping sampling threads')
                self.app.stop()
                tf.logging.info(
                    "%s stopped (time in seconds %.2f).",
                    type(self.app).__name__, (time.time() - start_time))

    # pylint: disable=not-context-manager
    def _create_graph(self, graph=tf.Graph()):
        """
        TensorFlow graph is only created within this function.
        """
        assert isinstance(graph, tf.Graph)
        main_device = self._device_string(0, is_worker=False)
        # start constructing the graph, handling training and inference cases
        with graph.as_default(), tf.device(main_device):

            # initialise network, these are connected in
            # the context of multiple gpus
            self.app.initialise_network()
            self.app.add_validation_flag()

            # for data parallelism --
            # defining and collecting variables from multiple devices
            bn_ops = None
            for gpu_id in range(0, max(self.num_gpus, 1)):
                worker_device = self._device_string(gpu_id, is_worker=True)
                scope_string = 'worker_{}'.format(gpu_id)
                with tf.name_scope(scope_string) as scope:
                    with tf.device(worker_device):
                        # setup network for each of the multiple devices
                        self.app.connect_data_and_network(
                            self.outputs_collector,
                            self.gradients_collector)
                        if self.is_training:
                            # batch norm statistics from the last device
                            bn_ops = tf.get_collection(BN_COLLECTION, scope)

            # assemble all training operations
            if self.is_training and self.gradients_collector:
                updates_op = []
                # batch normalisation moving averages operation
                if bn_ops:
                    updates_op.extend(bn_ops)
                # combine them with model parameter updating operation
                with tf.name_scope('ApplyGradients'):
                    with graph.control_dependencies(updates_op):
                        self.app.set_network_gradient_op(
                            self.gradients_collector.gradients)

            # initialisation operation
            with tf.name_scope('Initialization'):
                self._init_op = global_vars_init_or_restore()

            with tf.name_scope('MergedOutputs'):
                self.outputs_collector.finalise_output_op()

            # saving operation
            self.saver = tf.train.Saver(max_to_keep=self.max_checkpoints,
                                        save_relative_paths=True)

        # no more operation definitions after this point
        tf.Graph.finalize(graph)
        return graph

    def _rand_init_or_restore_vars(self, sess):
        """
        Randomly initialise all trainable variables defined in the session,
        or load checkpoint files as variable initialisations.
        """
        tf.logging.info('starting from iter %d', self.initial_iter)
        if self.is_training and self.initial_iter == 0:
            sess.run(self._init_op)
            tf.logging.info('Parameters from random initialisations ...')
            return

        # check model's folder
        assert os.path.exists(self.model_dir), \
            "Model folder not found {}, please check " \
            "config parameter: model_dir".format(self.model_dir)

        # check model's file
        ckpt_state = tf.train.get_checkpoint_state(self.model_dir)
        if ckpt_state is None:
            tf.logging.warning(
                "%s/checkpoint not found, please check "
                "config parameter: model_dir", self.model_dir)
        if self.initial_iter > 0:
            checkpoint = '{}-{}'.format(self.session_prefix,
                                        self.initial_iter)
        else:
            try:
                checkpoint = ckpt_state.model_checkpoint_path
                assert checkpoint, 'checkpoint path not found ' \
                                   'in {}/checkpoints'.format(self.model_dir)
                self.initial_iter = int(checkpoint.rsplit('-')[-1])
                tf.logging.info('set initial_iter to %d based '
                                'on checkpoints', self.initial_iter)
            except (ValueError, AttributeError):
                tf.logging.fatal(
                    'failed to get iteration number '
                    'from checkpoint path, please set '
                    'inference_iter or starting_iter to a positive integer')
                raise

        # restore session
        tf.logging.info('Accessing %s ...', checkpoint)
        try:
            self.saver.restore(sess, checkpoint)
        except tf.errors.NotFoundError:
            tf.logging.fatal(
                'checkpoint %s not found or variables to restore do not '
                'match the current application graph', checkpoint)
            raise

    def interleaved_iteration_generator(self):
        """
        This generator yields a sequence of training and validation
        iterations.
        """
        train_iters = iter_generator(
            range(self.initial_iter + 1, self.final_iter + 1), TRAIN)
        for train_iter_msg in train_iters:
            self.app.set_iteration_update(train_iter_msg)
            yield train_iter_msg

            if train_iter_msg.current_iter > 0 and \
                    self.validation_every_n > 0 and \
                    (train_iter_msg.current_iter %
                     self.validation_every_n == 0):
                val_iters = [train_iter_msg.current_iter]
                val_iters = val_iters * self.validation_max_iter
                valid_iters = iter_generator(val_iters, VALID)
                for valid_iter_msg in valid_iters:
                    self.app.set_iteration_update(valid_iter_msg)
                    yield valid_iter_msg
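
    # A worked sketch of the interleaving above (the parameter values are
    # illustrative, not from the source): with initial_iter=0, final_iter=4,
    # validation_every_n=2 and validation_max_iter=3, the generator yields
    #     TRAIN 1, TRAIN 2, VALID 2, VALID 2, VALID 2,
    #     TRAIN 3, TRAIN 4, VALID 4, VALID 4, VALID 4
    # i.e. each validation block reuses the current training iteration
    # number, repeated validation_max_iter times.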

    def _loop(self, iteration_generator, sess, loop_status):
        for iter_msg in iteration_generator:
            if self._coord.should_stop():
                break
            if iter_msg.should_stop:
                break
            loop_status['current_iter'] = iter_msg.current_iter

            iter_msg.pre_iter.send(iter_msg)

            iter_msg.ops_to_run[NETWORK_OUTPUT] = \
                self.outputs_collector.variables(NETWORK_OUTPUT)
            graph_output = sess.run(iter_msg.ops_to_run,
                                    feed_dict=iter_msg.data_feed_dict)
            iter_msg.current_iter_output = graph_output
            iter_msg.status = self.app.interpret_output(
                iter_msg.current_iter_output[NETWORK_OUTPUT])

            iter_msg.post_iter.send(iter_msg)

            if iter_msg.should_stop:
                break

    def _training_loop(self, sess, loop_status):
        """
        The training loop iterates through training (and validation)
        iterations. Each iteration is represented as an
        ``IterationMessage`` object, whose ops_to_run are populated
        with the ops to run in each iteration (by the training loop or
        by objects watching for iteration events), fed into
        ``session.run()``, and then passed to the app (and observers)
        for interpretation.
        """
        # Add observers for tensorboard, and console output (move to io?)
        self.tensorboard = TensorBoardLogger(self.outputs_collector,
                                             self.summary_dir,
                                             sess.graph,
                                             self.tensorboard_every_n)
        self.console = ConsoleLogger(self.outputs_collector)

        # Core training loop handling
        def add_gradient(iter_msg):
            """
            Event handler to add the backpropagation update.
            iter_msg is an IterationMessage object.
            """
            iter_msg.ops_to_run['gradients'] = self.app.gradient_op

        self.pre_train_iter.connect(add_gradient)
        self._loop(self.interleaved_iteration_generator(), sess, loop_status)

    def _inference_loop(self, sess, loop_status):
        """
        Runs all variables returned by outputs_collector;
        this loop stops when the return value of
        application.interpret_output is False.
        """
        loop_status['all_saved_flag'] = False
        self.console = ConsoleLogger(self.outputs_collector)

        def is_complete(iter_msg):
            """
            Event handler to trigger the completion message.
            iter_msg is an IterationMessage object.
            """
            if not iter_msg.status:
                tf.logging.info('processed all batches.')
                loop_status['all_saved_flag'] = True
                iter_msg.should_stop = True

        self.post_infer_iter.connect(is_complete)
        self._loop(iter_generator(itertools.count(), INFER),
                   sess, loop_status)

    def _device_string(self, device_id=0, is_worker=True):
        """
        Assign CPU/GPU devices based on user specifications.
        """
        # pylint: disable=no-name-in-module
        from tensorflow.python.client import device_lib
        devices = device_lib.list_local_devices()
        n_local_gpus = sum([x.device_type == 'GPU' for x in devices])
        if self.num_gpus <= 0:  # user specified no gpu at all
            return '/cpu:{}'.format(device_id)
        if self.is_training:
            # in training: use gpu only for workers,
            # and only when local GPUs are available
            device = 'gpu' if (is_worker and n_local_gpus > 0) else 'cpu'
            if device == 'gpu' and device_id >= n_local_gpus:
                tf.logging.warning(
                    'trying to use gpu id %s, but only has %s GPU(s), '
                    'please set num_gpus to %s at most',
                    device_id, n_local_gpus, n_local_gpus)
                # raise ValueError
            return '/{}:{}'.format(device, device_id)
        # in inference: use the gpu for everything
        # when local GPUs are available
        return '/gpu:0' if n_local_gpus > 0 else '/cpu:0'

    @staticmethod
    def _create_app(app_type_string):
        """
        Import the application module.
        """
        app_class = ApplicationFactory.create(app_type_string)
        app_class.clear()
        return app_class

    @staticmethod
    def _tf_config():
        """
        TensorFlow system configurations.
        """
        config = tf.ConfigProto()
        config.log_device_placement = False
        config.allow_soft_placement = True
        return config
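
# A device-placement sketch derived from _device_string and _tf_config above
# (illustrative, assuming num_gpus=2 and two visible GPUs): during training
# the main graph context uses the master device '/cpu:0' (is_worker=False),
# while worker_0 and worker_1 compute on '/gpu:0' and '/gpu:1'; during
# inference everything is placed on '/gpu:0' when a GPU is present,
# otherwise '/cpu:0'. allow_soft_placement=True additionally lets TensorFlow
# fall back to a supported device when a requested kernel is unavailable.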

def iter_generator(count_generator, phase):
    """
    Generate a numbered sequence of IterationMessage objects
    with phase-appropriate signals.

    ``count_generator`` is an iterable object yielding iteration numbers;
    ``phase`` is one of TRAIN, VALID or INFER.
    """
    signals = {TRAIN: (ApplicationDriver.pre_train_iter,
                       ApplicationDriver.post_train_iter),
               VALID: (ApplicationDriver.pre_validation_iter,
                       ApplicationDriver.post_validation_iter),
               INFER: (ApplicationDriver.pre_infer_iter,
                       ApplicationDriver.post_infer_iter)}
    for iter_i in count_generator:
        iter_msg = IterationMessage()
        iter_msg.current_iter, iter_msg.phase = iter_i, phase
        iter_msg.pre_iter = signals[phase][0]
        iter_msg.post_iter = signals[phase][1]
        yield iter_msg
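
# A minimal sketch of how iter_generator is consumed in this module (the
# loop body is elided; see ApplicationDriver._loop for the real handling):
#
#     for msg in iter_generator(range(1, 4), TRAIN):
#         msg.pre_iter.send(msg)   # observers populate msg.ops_to_run
#         ...                      # sess.run(msg.ops_to_run)
#         msg.post_iter.send(msg)  # observers log/interpret the outputs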

class ConsoleLogger(object):
    """
    This class handles iteration events to print output to the console.
    """

    def __init__(self, outputs_collector):
        self.outputs_collector = outputs_collector
        ApplicationDriver.pre_train_iter.connect(self.on_pre_iter)
        ApplicationDriver.pre_validation_iter.connect(self.on_pre_iter)
        ApplicationDriver.post_train_iter.connect(self.on_post_iter)
        ApplicationDriver.post_validation_iter.connect(self.on_post_iter)
        ApplicationDriver.post_infer_iter.connect(self.on_post_iter)

    def on_pre_iter(self, iter_msg):
        """
        Event handler to add all the console output ops to the
        iteration message.

        iter_msg is an IterationMessage object.
        """
        iter_msg.ops_to_run[CONSOLE] = \
            self.outputs_collector.variables(CONSOLE)

    # pylint: disable=no-self-use
    def on_post_iter(self, iter_msg):
        """
        Event handler to log the console outputs from the
        iteration message.

        iter_msg is an IterationMessage object.
        """
        tf.logging.info(iter_msg.to_console_string())
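
# The same signal wiring supports custom observers: any callable connected
# to an ApplicationDriver signal receives the IterationMessage. A minimal
# sketch (``IterationTracker`` is hypothetical, not part of NiftyNet):
#
#     class IterationTracker(object):
#         def __init__(self):
#             self.seen_iters = []
#             ApplicationDriver.post_train_iter.connect(self.on_post_iter)
#
#         def on_post_iter(self, iter_msg):
#             self.seen_iters.append(iter_msg.current_iter)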

class ModelSaver(object):
    """
    This class handles iteration events to save the model at regular
    intervals and at the end of training.
    """

    def __init__(self, sess, saver, save_every_n, session_prefix):
        self.sess = sess
        self.session_prefix = session_prefix
        self.save_every_n = save_every_n
        self.saver = saver
        ApplicationDriver.post_training.connect(self.save_model)
        ApplicationDriver.post_train_iter.connect(self.save_model_interval)

    def save_model_interval(self, iter_msg):
        """
        Save the model if the iteration is at the specified interval.

        iter_msg is an IterationMessage object.
        """
        if iter_msg.current_iter > 0 and \
                self.save_every_n > 0 and \
                (iter_msg.current_iter % self.save_every_n == 0):
            self.save_model(iter_msg)

    def save_model(self, iter_msg):
        """
        Save the model regardless of the iteration.

        iter_msg is an IterationMessage object.
        """
        self.saver.save(sess=self.sess,
                        save_path=self.session_prefix,
                        global_step=iter_msg.current_iter)
        tf.logging.info('iter %d saved: %s',
                        iter_msg.current_iter, self.session_prefix)
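
# Checkpoint naming note: with FILE_PREFIX = 'model.ckpt', tf.train.Saver
# appends the global step, so iteration 100 is written under
# <model_dir>/models as model.ckpt-100.*; this matches the name that
# _rand_init_or_restore_vars rebuilds with
# '{}-{}'.format(session_prefix, initial_iter) when restoring.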

class TensorBoardLogger(object):
    """
    This class handles iteration events to log summaries to the
    TensorBoard log.
    """

    def __init__(self, outputs_collector, summary_dir, graph,
                 tensorboard_every_n):
        self.summary_dir = summary_dir
        self.writer_train = tf.summary.FileWriter(
            os.path.join(self.summary_dir, TRAIN), graph)
        self.writer_valid = tf.summary.FileWriter(
            os.path.join(self.summary_dir, VALID), graph)
        self.outputs_collector = outputs_collector
        self.tensorboard_every_n = tensorboard_every_n
        ApplicationDriver.pre_train_iter.connect(self.on_pre_iter)
        ApplicationDriver.pre_validation_iter.connect(self.on_pre_iter)
        ApplicationDriver.post_train_iter.connect(self.on_post_train_iter)
        ApplicationDriver.post_validation_iter.connect(
            self.on_post_validation_iter)

    def filter(self, iter_msg):
        """
        Decide whether to save a TensorBoard log entry for a given
        iteration.

        iter_msg is an IterationMessage object.
        """
        return self.tensorboard_every_n > 0 and \
            (iter_msg.current_iter % self.tensorboard_every_n == 0)

    def on_pre_iter(self, iter_msg):
        """
        Event handler to add all the TensorBoard summaries to the
        iteration message.

        iter_msg is an IterationMessage object.
        """
        if self.filter(iter_msg):
            iter_msg.ops_to_run[TF_SUMMARIES] = \
                self.outputs_collector.variables(TF_SUMMARIES)

    def on_post_train_iter(self, iter_msg):
        """
        Event handler to write the training TensorBoard log entry
        to disk.

        iter_msg is an IterationMessage object.
        """
        if self.filter(iter_msg):
            iter_msg.to_tf_summary(self.writer_train)

    def on_post_validation_iter(self, iter_msg):
        """
        Event handler to write the validation TensorBoard log entry
        to disk.

        iter_msg is an IterationMessage object.
        """
        if self.filter(iter_msg):
            iter_msg.to_tf_summary(self.writer_valid)
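
# Usage note: summary_dir is a subfolder of <model_dir>/logs (see
# initialise_application), and the two FileWriters above write to per-phase
# subdirectories, so training and validation curves appear as separate runs
# when pointing TensorBoard at the log root, e.g.:
#
#     tensorboard --logdir <model_dir>/logs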