Source code for niftynet.application.base_application
# -*- coding: utf-8 -*-
"""
Interface of NiftyNet application
"""
import tensorflow as tf
from six import with_metaclass
from niftynet.layer.base_layer import TrainableLayer
from niftynet.utilities import util_common
class SingletonApplication(type):
    """
    Metaclass that makes each application class hand out a single instance.

    The first call ``AppClass(...)`` constructs the object and caches it as
    ``AppClass._instances``; later calls return the cached object and
    ignore their arguments. Assigning ``AppClass._instances = None``
    discards the cache so that the next call builds a fresh instance.

    The cache is read from the class's *own* ``__dict__`` only, so a
    subclass never receives an instance that was cached on a parent class.
    """
    _instances = None

    def __call__(cls, *args, **kwargs):
        # A plain ``cls._instances`` lookup follows the MRO: after a parent
        # application has been instantiated, instantiating a subclass would
        # return the parent's object (not even an instance of the subclass).
        instance = cls.__dict__.get('_instances')
        if instance is None:
            instance = super(SingletonApplication, cls).__call__(
                *args, **kwargs)
            cls._instances = instance
        # A repeated instantiation is deliberately served from the cache
        # instead of raising an "already started" error.
        return instance
class BaseApplication(with_metaclass(SingletonApplication, object)):
    """
    BaseApplication represents an interface.

    Each application ``type_str`` should support using
    the standard training and inference driver.
    """

    # defines name of the customised configuration file section
    # the section collects all application specific user parameters
    REQUIRED_CONFIG_SECTION = None

    # boolean flag
    is_training = True
    # TF placeholder for switching network on the fly
    # (created by ``add_validation_flag``; ``None`` until then)
    is_validation = None

    # input of the network
    readers = None
    sampler = None

    # the network
    net = None

    # training the network
    optimiser = None
    gradient_op = None

    # interpret network output
    output_decoder = None

    def check_initialisations(self):
        """
        Check that the attributes used by the driver have been set.

        ``readers``, ``sampler`` and ``net`` are always required, and
        ``net`` must be a ``TrainableLayer``. When ``is_training`` is true,
        ``optimiser`` and ``gradient_op`` are also required; otherwise
        ``output_decoder`` is required.

        :raises NotImplementedError: if a required attribute is ``None``
        :raises ValueError: if ``self.net`` is not a ``TrainableLayer``
        """
        if self.readers is None:
            raise NotImplementedError('reader should be initialised')
        if self.sampler is None:
            raise NotImplementedError('sampler should be initialised')
        if self.net is None:
            raise NotImplementedError('net should be initialised')
        if not isinstance(self.net, TrainableLayer):
            raise ValueError('self.net should be an instance'
                             ' of niftynet.layer.TrainableLayer')
        if self.optimiser is None and self.is_training:
            raise NotImplementedError('optimiser should be initialised')
        if self.gradient_op is None and self.is_training:
            raise NotImplementedError('gradient_op should be initialised')
        if self.output_decoder is None and not self.is_training:
            raise NotImplementedError('output decoder should be initialised')

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """
        This function initialises ``self.readers``.

        :param data_param: input modality specifications
        :param task_param: contains task keywords for grouping data_param
        :param data_partitioner:
                   specifies train/valid/infer splitting if needed
        :return:
        """
        raise NotImplementedError

    def initialise_sampler(self):
        """
        Samplers take ``self.readers`` as input and generate
        sequences of ImageWindow that will be fed to the networks.

        This function sets ``self.sampler``.
        """
        raise NotImplementedError

    def initialise_network(self):
        """
        This function creates an instance of network and sets ``self.net``.

        :return: None
        """
        raise NotImplementedError

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """
        Adding sampler output tensor and network tensors to the graph.

        :param outputs_collector:
        :param gradients_collector:
        :return:
        """
        raise NotImplementedError

    def interpret_output(self, batch_output):
        """
        Implement output interpretations, e.g., save to hard drive or
        cache output windows.

        :param batch_output: outputs by running the tf graph
        :return: True indicates the driver should continue the loop;
                 False indicates the driver should stop
        """
        raise NotImplementedError

    def set_network_gradient_op(self, gradients):
        """
        Create the gradient op by ``optimiser.apply_gradients``.

        This function sets ``self.gradient_op``.
        Override this function for more complex optimisations such as
        using different optimisers for sub-networks.

        :param gradients: processed gradients from the gradient_collector
        :return:
        :raises NotImplementedError: if ``gradients`` is neither nested to
            depth 2 (one network) nor depth 3 (a list of networks)
        """
        grad_list_depth = util_common.list_depth_count(gradients)
        if grad_list_depth == 3:
            # nested depth 3 means: gradients list is nested in terms of:
            # list of networks -> list of network variables
            # one apply_gradients op is created per network
            self.gradient_op = [self.optimiser.apply_gradients(grad)
                                for grad in gradients]
        elif grad_list_depth == 2:
            # nested depth 2 means:
            # gradients list is a list of variables
            self.gradient_op = self.optimiser.apply_gradients(gradients)
        else:
            raise NotImplementedError(
                'This app supports updating a network, or a list of networks.')

    def stop(self):
        """
        Close every sampler returned by ``self.get_sampler()``.

        ``get_sampler()`` is iterated as a collection of sampler sets, and
        each set as a collection of samplers; ``None`` sets and falsy
        samplers are skipped. When the sampler has not been initialised
        this is a no-op, so it is safe to call from cleanup code after a
        failed initialisation.
        """
        sampler_sets = self.get_sampler()
        if sampler_sets is None:
            # nothing was initialised, so there is nothing to close
            return
        for sampler_set in sampler_sets:
            if sampler_set is None:
                continue
            for sampler in sampler_set:
                if sampler:
                    sampler.close_all()

    def set_iteration_update(self, iteration_message):
        """
        At each iteration ``application_driver`` calls::

            output = tf.session.run(variables_to_eval, feed_dict=data_dict)

        to evaluate TF graph elements, where
        ``variables_to_eval`` and ``data_dict`` are retrieved from
        ``application_iteration.IterationMessage.ops_to_run`` and
        ``application_iteration.IterationMessage.data_feed_dict``
        (in addition to the variables collected by output_collector;
        implemented in ``application_driver.run_vars``).

        This function (called by the driver before ``tf.session.run``)
        provides an interface for accessing ``variables_to_eval`` and
        ``data_dict`` at each iteration.

        The default implementation feeds ``self.is_validation`` with
        False for training iterations and True for validation iterations,
        and feeds nothing when ``add_validation_flag`` has not been called.

        Override this function for more complex operations according to
        ``application_iteration.IterationMessage.current_iter``.
        """
        if self.is_validation is None:
            # No switch placeholder exists, and ``None`` is not a valid
            # feed_dict key, so there is nothing to feed.
            return
        if iteration_message.is_training:
            iteration_message.data_feed_dict[self.is_validation] = False
        elif iteration_message.is_validation:
            iteration_message.data_feed_dict[self.is_validation] = True

    def get_sampler(self):
        """
        Return ``self.sampler`` (``None`` until ``initialise_sampler``
        has set it).
        """
        return self.sampler

    def add_validation_flag(self):
        """
        Add a TF placeholder for switching between train/valid graphs.

        This function sets ``self.is_validation`` to a scalar boolean
        placeholder named ``'is_validation'`` that defaults to False.
        ``self.is_validation`` can be used by applications.

        :return:
        """
        self.is_validation = \
            tf.placeholder_with_default(False, [], 'is_validation')