Source code for niftynet.application.base_application

# -*- coding: utf-8 -*-
"""
Interface of NiftyNet application
"""
import os
from argparse import Namespace

import tensorflow as tf
from six import with_metaclass

from niftynet.engine.signal import TRAIN, INFER, EVAL

# Module-level slot holding the single live application instance.
# ``SingletonApplication.__call__`` fills it on first instantiation and
# ``SingletonApplication.clear()`` sets it back to None.
APP_INSTANCE = None  # global so it can be reset


# pylint: disable=global-statement
class SingletonApplication(type):
    """
    Metaclass keeping at most one application instance alive at a time.

    The first instantiation of any class using this metaclass is stored in
    the module-level ``APP_INSTANCE``; every later instantiation returns that
    same object (constructor arguments are ignored) until ``clear()`` resets
    the slot.
    """

    def __call__(cls, *args, **kwargs):
        """
        Return the shared instance, constructing it only if none exists yet.

        :return: the object stored in ``APP_INSTANCE``
        """
        global APP_INSTANCE
        if APP_INSTANCE is not None:
            # an application already exists: reuse it
            return APP_INSTANCE
        APP_INSTANCE = super(SingletonApplication, cls).__call__(
            *args, **kwargs)
        return APP_INSTANCE

    @classmethod
    def clear(metacls):
        """
        Forget the stored instance so that the next call creates a new one.

        :return:
        """
        global APP_INSTANCE
        APP_INSTANCE = None
class BaseApplication(with_metaclass(SingletonApplication, object)):
    """
    BaseApplication represents an interface.

    Each application ``type_str`` should support the standard training and
    inference driver. Instances are managed by ``SingletonApplication``, so
    only one application object exists until ``stop()`` is called.
    """

    # defines name of the customised configuration file section
    # the section collects all application specific user parameters
    REQUIRED_CONFIG_SECTION = None

    SUPPORTED_PHASES = {TRAIN, INFER, EVAL}
    _action = TRAIN
    # TF placeholders for switching network on the fly
    is_validation = None

    # input of the network
    readers = None
    sampler = None

    # the network
    net = None

    # training the network
    optimiser = None
    gradient_op = None

    # interpret network output
    output_decoder = None
    outputs_collector = None
    gradients_collector = None

    # performance
    total_loss = None
    patience = None
    # NOTE(review): this list is a class attribute, so it is shared by every
    # instance (including ones created after ``stop()``) unless a subclass
    # rebinds ``self.performance_history`` per instance.
    performance_history = []
    mode = None

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """
        This function initialises ``self.readers``.

        :param data_param: input modality specifications
        :param task_param: contains task keywords for grouping data_param
        :param data_partitioner:
                           specifies train/valid/infer splitting if needed
        :return:
        """
        raise NotImplementedError

    def initialise_sampler(self):
        """
        Samplers take ``self.readers`` as input and generate
        sequences of ImageWindow that will be fed to the networks.

        This function sets ``self.sampler``.
        """
        raise NotImplementedError

    def initialise_network(self):
        """
        This function creates an instance of network and sets ``self.net``.

        :return: None
        """
        raise NotImplementedError

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """
        Adding sampler output tensor and network tensors to the graph.

        :param outputs_collector:
        :param gradients_collector:
        :return:
        """
        raise NotImplementedError

    def interpret_output(self, batch_output):
        """
        Implement output interpretations, e.g., save to hard drive
        cache output windows.

        :param batch_output: outputs by running the tf graph
        :return: True indicates the driver should continue the loop,
            False indicates the driver should stop
        """
        raise NotImplementedError

    def add_inferred_output_like(self, data_param, task_param, name):
        """
        This function adds entries to parameter objects to enable
        the evaluation action to automatically read in the output
        of a previous inference run if inference is not explicitly
        specified.

        This can be used in an application if there is a data section
        entry in the configuration file that matches the inference output.
        In supervised learning, the reference data section would often
        match the inference output and could be used here. Otherwise,
        a template data section could be used.

        Requires ``self.action_param.save_seg_dir`` to be set; that
        attribute is not defined in this class (presumably supplied by the
        driver -- confirm).

        :param data_param: mapping of data section name -> Namespace of
            section parameters; an ``'inferred'`` entry is added if missing
        :param task_param: Namespace of task parameters; its ``inferred``
            attribute is set to ``('inferred',)`` if missing or empty
        :param name: name of input parameter to copy parameters from
            (the first data section listed under ``task_param.<name>``
            is used as the template)
        :return: modified data_param and task_param
        """
        # Add the data parameter
        if 'inferred' not in data_param:
            data_name = vars(task_param)[name][0]
            # copy the template section so it is not mutated in place
            inferred_param = Namespace(**vars(data_param[data_name]))
            inferred_param.csv_file = os.path.join(
                self.action_param.save_seg_dir, 'inferred.csv')
            data_param['inferred'] = inferred_param
        # Add the task parameter
        if 'inferred' not in task_param or not task_param.inferred:
            task_param.inferred = ('inferred',)
        return data_param, task_param

    def set_iteration_update(self, iteration_message):
        """
        At each iteration ``application_driver`` calls::

            output = tf.session.run(variables_to_eval, feed_dict=data_dict)

        to evaluate TF graph elements, where
        ``variables_to_eval`` and ``data_dict`` are retrieved from
        ``iteration_message.ops_to_run`` and
        ``iteration_message.data_feed_dict``
        (In addition to the variables collected by self.output_collector).

        The output of `tf.session.run(...)` will be stored at
        ``iteration_message.current_iter_output``, and can be accessed
        from ``engine.handler_network_output.OutputInterpreter``.

        Override this function for more complex operations
        (such as learning rate decay) according to
        ``iteration_message.current_iter``.

        This default feeds ``self.is_validation`` with False for training
        iterations and True for validation iterations; other iterations are
        left untouched.
        """
        if iteration_message.is_training:
            iteration_message.data_feed_dict[self.is_validation] = False
        elif iteration_message.is_validation:
            iteration_message.data_feed_dict[self.is_validation] = True

    def get_sampler(self):
        """
        Get samplers of the application

        :return: ``niftynet.engine.sampler_*`` instances
        """
        return self.sampler

    def add_validation_flag(self):
        """
        Add a TF placeholder for switching between train/valid graphs,
        this function sets ``self.is_validation``. ``self.is_validation``
        can be used by applications.

        The placeholder is a scalar that defaults to False when not fed.

        :return:
        """
        self.is_validation = \
            tf.placeholder_with_default(False, [], 'is_validation')

    @property
    def action(self):
        """
        A string indicating the action in train/inference/evaluation

        :return: the lower-cased action string
        """
        return self._action

    @action.setter
    def action(self, value):
        """
        The action should have at least two characters matching the one
        of the phase string TRAIN, INFER, EVAL.

        Invalid values (non-strings, or fewer than two characters) are
        reported via ``tf.logging.fatal`` and the previous action is kept,
        so an empty string can no longer make every ``is_*`` check True.

        :param value: the requested action
        :return:
        """
        try:
            new_action = value.lower()
        except AttributeError:
            new_action = None
        # explicit check instead of ``assert`` (asserts vanish under -O)
        if new_action is None or len(new_action) < 2:
            tf.logging.fatal('Error setting application action: %s', value)
            return
        self._action = new_action

    @property
    def is_training(self):
        """
        :return: boolean value indicating if the phase is training
            (the action is a prefix of ``TRAIN``)
        """
        return TRAIN.startswith(self.action)

    @property
    def is_inference(self):
        """
        :return: boolean value indicating if the phase is inference
            (the action is a prefix of ``INFER``)
        """
        return INFER.startswith(self.action)

    @property
    def is_evaluation(self):
        """
        :return: boolean value indicating if the action is evaluation
            (the action is a prefix of ``EVAL``)
        """
        return EVAL.startswith(self.action)

    @staticmethod
    def stop():
        """
        Remove the application instance if there is any.

        :return:
        """
        SingletonApplication.clear()