niftynet.engine.application_variables module

Managing tf.Tensor variables for initialising/evaluating networks

class niftynet.engine.application_variables.GradientsCollector(n_devices=1)

Bases: object

This collector holds a list of all gradients collected while the TensorFlow graph is being constructed. Gradients from multiple GPUs are averaged, and the averaging op is added to the graph by the application driver.

add_to_collection(gradients)

Adds the gradients generated by optimiser.compute_gradients to a dictionary, which is retrieved during training in training_op. The gradients can be a list of model updates; an application can implement training_ops() to specify how the gradients are used at each iteration.

Parameters: gradients – generated by optimiser.compute_gradients(loss)

gradients

Returns the gradients averaged over devices; this property is used by the application driver.

Returns: averaged gradients over devices
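The averaging step can be illustrated in plain Python. This is a minimal sketch, not NiftyNet's implementation: it assumes each device contributes one list of (gradient, variable) pairs in the same variable order, as optimiser.compute_gradients returns them, and it stands in flat lists of floats for tf.Tensor gradients and name strings for variables. The helper average_gradients is hypothetical; inside a TensorFlow graph the element-wise mean would be built from ops such as tf.reduce_mean instead.

```python
from typing import List, Sequence, Tuple

# Hypothetical stand-in types: a gradient is a flat list of floats and a
# variable is identified by its name (NiftyNet works with tf.Tensor objects).
GradVar = Tuple[List[float], str]


def average_gradients(tower_grads: Sequence[Sequence[GradVar]]) -> List[GradVar]:
    """Average (gradient, variable) pairs collected from several devices.

    tower_grads[d][i] is the i-th (gradient, variable) pair computed on device d.
    All devices are assumed to list the same variables in the same order.
    """
    if len(tower_grads) == 1:
        # Single device: nothing to average.
        return list(tower_grads[0])
    averaged = []
    # zip(*...) groups the i-th pair from every device together.
    for pairs in zip(*tower_grads):
        var_name = pairs[0][1]
        assert all(v == var_name for _, v in pairs), "variable order differs across devices"
        n_devices = len(pairs)
        # Element-wise mean of the gradient values over devices.
        mean_grad = [sum(vals) / n_devices for vals in zip(*(g for g, _ in pairs))]
        averaged.append((mean_grad, var_name))
    return averaged


gpu0 = [([1.0, 2.0], "conv/w"), ([0.5], "conv/b")]
gpu1 = [([3.0, 4.0], "conv/w"), ([1.5], "conv/b")]
print(average_gradients([gpu0, gpu1]))
# [([2.0, 3.0], 'conv/w'), ([1.0], 'conv/b')]
```

Keeping the single-device case as a pass-through avoids adding a redundant averaging op when only one GPU is used.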

class niftynet.engine.application_variables.OutputsCollector(n_devices=1)

Bases: object

Collects all tf.Tensor objects to be evaluated by tf.Session.run(). These objects are grouped into:

NETWORK_OUTPUT: to be decoded by an aggregator
CONSOLE: to be printed on the command line
TF_SUMMARIES: to be added to the TensorBoard visualisation

add_to_collection(var, name, average_over_devices=False, collection='niftynetconsole', summary_type='scalar')

Adds a tf.Tensor to be evaluated to a dictionary. The dictionaries are retrieved and evaluated by the application driver in the train/infer loops.

Parameters:
var – tf.Tensor to be evaluated by tf.Session()
name – name of the variable (for display purposes)
average_over_devices – whether the value should be averaged over devices (see finalise_output_op())
collection – one of [CONSOLE, TF_SUMMARIES, NETWORK_OUTPUT]
summary_type – if adding to TF_SUMMARIES, there are a few possible ways to visualise the Tensor value; see SUPPORTED_SUMMARY

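The bookkeeping behind add_to_collection can be modelled without TensorFlow. The class below is a hypothetical, framework-free sketch, not NiftyNet's code: it assumes each collection is a dictionary mapping a display name to a list with one value per device, and it uses plain Python values in place of tf.Tensor objects. Only the default key 'niftynetconsole' comes from the signature above; the other two collection strings are placeholders.

```python
from collections import defaultdict
from typing import Any, Dict, List, Tuple

# 'niftynetconsole' is the documented default of `collection`;
# the other two values are placeholders chosen for this sketch.
CONSOLE = "niftynetconsole"
NETWORK_OUTPUT = "network_output"  # hypothetical value
TF_SUMMARIES = "tf_summaries"      # hypothetical value


class ToyOutputsCollector:
    """Hypothetical model of OutputsCollector.add_to_collection."""

    def __init__(self, n_devices: int = 1) -> None:
        self.n_devices = n_devices
        # collection -> name -> list of per-device values
        self._values: Dict[str, Dict[str, List[Any]]] = {
            key: defaultdict(list) for key in (CONSOLE, NETWORK_OUTPUT, TF_SUMMARIES)
        }
        # (collection, name) -> (average_over_devices, summary_type),
        # kept for a later finalise step.
        self.options: Dict[Tuple[str, str], Tuple[bool, str]] = {}

    def add_to_collection(self, var, name, average_over_devices=False,
                          collection=CONSOLE, summary_type="scalar"):
        if collection not in self._values:
            raise ValueError("unknown collection: %s" % collection)
        per_device = self._values[collection][name]
        # Each device is expected to add a given name at most once.
        if len(per_device) >= self.n_devices:
            raise ValueError("%s added more times than there are devices" % name)
        per_device.append(var)
        self.options[(collection, name)] = (average_over_devices, summary_type)

    def collected(self, collection: str = CONSOLE) -> Dict[str, List[Any]]:
        # Plain-dict copy of the per-device values in one collection.
        return {k: list(v) for k, v in self._values[collection].items()}


collector = ToyOutputsCollector(n_devices=2)
collector.add_to_collection(0.9, "loss", average_over_devices=True)
collector.add_to_collection(0.7, "loss", average_over_devices=True)
print(collector.collected(CONSOLE))  # {'loss': [0.9, 0.7]}
```

Storing one entry per device keeps the reduction separate from collection, so whether to average can be decided once the whole graph has been built.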
finalise_output_op()

This function checks the dictionary: if a variable needs to be averaged over devices, a reduce_mean node is added to the graph. It should be called in the ApplicationDriver.create_graph function.
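The reduction performed by finalise_output_op can be sketched as a pure function over per-device values. This is an illustration under stated assumptions, not NiftyNet's code: finalise_outputs is a hypothetical helper, the arithmetic mean stands in for the reduce_mean node, and the handling of names that are not averaged (a single value is unwrapped, several values are returned as a list) is an assumption made for this sketch.

```python
from typing import Any, Dict, List


def finalise_outputs(collected: Dict[str, List[Any]],
                     average_flags: Dict[str, bool]) -> Dict[str, Any]:
    """Reduce per-device values to one entry per name (hypothetical helper).

    collected: name -> list with one value per device
    average_flags: name -> whether the value should be averaged over devices
    """
    final = {}
    for name, per_device in collected.items():
        if average_flags.get(name, False) and len(per_device) > 1:
            # Counterpart of the reduce_mean node: mean over devices.
            final[name] = sum(per_device) / len(per_device)
        elif len(per_device) == 1:
            # Assumption: a single contribution is unwrapped.
            final[name] = per_device[0]
        else:
            # Assumption: values not flagged for averaging stay per device.
            final[name] = list(per_device)
    return final


print(finalise_outputs(
    {"loss": [0.5, 1.5], "window": ["a", "b"], "lr": [0.01]},
    {"loss": True},
))
# {'loss': 1.0, 'window': ['a', 'b'], 'lr': 0.01}
```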

variables(collection='niftynetconsole')

Gets the tf.Tensors to be evaluated by tf.Session().run().

Parameters: collection – one of [CONSOLE, TF_SUMMARIES, NETWORK_OUTPUT]
Returns: a variable dictionary

niftynet.engine.application_variables.global_vars_init_or_restore(var_list=None)

For any scope added to the RESTORABLE collection, a variable is restored from a checkpoint if it exists in the specified checkpoint and no ancestor scope can restore it.
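One reading of this rule is that, for each variable, the outermost registered scope whose checkpoint contains the variable wins, and variables that no scope can restore are initialised. The sketch below encodes that reading in plain Python. It is an assumption-laden illustration, not NiftyNet's implementation: plan_init_or_restore, the restorable mapping (scope to checkpoint id, standing in for the RESTORABLE collection) and the checkpoints mapping (checkpoint id to stored variable names, assumed to match graph names exactly) are all hypothetical.

```python
from typing import Dict, Iterable, Set, Tuple


def plan_init_or_restore(
        var_names: Iterable[str],
        restorable: Dict[str, str],
        checkpoints: Dict[str, Set[str]]) -> Tuple[Dict[str, str], Set[str]]:
    """Decide, per variable, whether it is restored or freshly initialised.

    Returns (restore_from, to_initialise), where restore_from maps each
    restored variable name to the checkpoint id it is restored from.
    """
    # Ancestor scopes (fewer '/' separators) are examined first, so an
    # ancestor that can restore a variable takes precedence over descendants.
    scopes = sorted(restorable, key=lambda s: s.count("/"))
    restore_from: Dict[str, str] = {}
    to_initialise: Set[str] = set()
    for var in var_names:
        for scope in scopes:
            in_scope = var.startswith(scope.rstrip("/") + "/")
            ckpt = restorable[scope]
            if in_scope and var in checkpoints.get(ckpt, set()):
                restore_from[var] = ckpt
                break
        else:
            # No registered scope could restore it: initialise instead.
            to_initialise.add(var)
    return restore_from, to_initialise


restorable = {"encoder": "pretrained.ckpt", "encoder/block2": "finetuned.ckpt"}
checkpoints = {
    "pretrained.ckpt": {"encoder/block1/w", "encoder/block2/b"},
    "finetuned.ckpt": {"encoder/block2/w", "encoder/block2/b"},
}
restore, init = plan_init_or_restore(
    ["encoder/block1/w", "encoder/block2/w", "encoder/block2/b", "decoder/w"],
    restorable, checkpoints)
# encoder/block2/b is in both checkpoints; the ancestor scope "encoder" wins.
print(restore["encoder/block2/b"])  # pretrained.ckpt
print(init)                         # {'decoder/w'}
```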