niftynet.engine.application_variables module

Managing tf.Tensor variables for initialising/evaluating networks

class GradientsCollector(n_devices=1)[source]

Bases: object

This collector holds a list of all gradients, collected while the tf graph is constructed. Gradients from multiple GPUs are averaged, and the averaging op is added to the graph by the application driver.

add_to_collection(gradients)[source]

Add the gradients generated by optimiser.compute_gradients to a dictionary, to be retrieved during training. The gradients can be a list of model updates; an application can implement the set_iteration_update() interface to specify how the gradients are used at each iteration.

Parameters:gradients – generated by optimiser.compute_gradients(loss)
Returns:
gradients

Returns the averaged gradients over devices; used by the application driver.

Returns:averaged gradients over devices
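The averaging over devices described above can be illustrated without TensorFlow. The following is a minimal sketch, not NiftyNet's implementation: the function name average_gradients and the use of plain floats in place of tf.Tensor gradients are assumptions made for illustration, and each device is assumed to report its (gradient, variable) pairs in the same order.

```python
def average_gradients(per_device_grads):
    """Average (gradient, variable_name) pairs across devices.

    per_device_grads holds one list per device. Each list contains
    (gradient, variable_name) pairs in the same variable order, loosely
    mimicking the structure returned by optimiser.compute_gradients(loss).
    Plain floats stand in for tensors (an assumption of this sketch).
    """
    if not per_device_grads:
        raise ValueError("no gradients were collected")
    averaged = []
    # zip(*...) groups the pairs that belong to the same variable
    # across all devices.
    for pairs in zip(*per_device_grads):
        grads = [grad for grad, _ in pairs]
        var_name = pairs[0][1]
        averaged.append((sum(grads) / len(grads), var_name))
    return averaged


# Two hypothetical devices, each reporting gradients for "w" and "b".
device_0 = [(0.25, "w"), (1.0, "b")]
device_1 = [(0.75, "w"), (3.0, "b")]
result = average_gradients([device_0, device_1])
print(result)
```

With a single device the function returns the input gradients unchanged, which matches the default n_devices=1.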

class OutputsCollector(n_devices=1)[source]

Bases: object

Collects all tf.Tensor objects to be evaluated by tf.Session.run(). These objects are grouped into:

NETWORK_OUTPUT: to be decoded by an aggregator
CONSOLE: to be printed on command line
TF_SUMMARIES: to be added to tensorboard visualisation
add_to_collection(var, name, average_over_devices=False, collection='niftynetconsole', summary_type='scalar')[source]

Add tf.Tensors to be evaluated to a dictionary. The dictionaries will be retrieved and evaluated by the application driver in the train/infer loops.

Parameters:
  • var – tf.Tensor to be evaluated by tf.Session()
  • name – name of the variable (for displaying purposes)
  • average_over_devices – whether the variable should be averaged over devices (see finalise_output_op)
  • collection – one of [CONSOLE, TF_SUMMARIES, NETWORK_OUTPUT]
  • summary_type – if adding to TF_SUMMARIES, there are several possible ways to visualise the Tensor value; see SUPPORTED_SUMMARY
Returns:

variables(collection='niftynetconsole')[source]

Get the tf.Tensors to be evaluated by tf.Session().run().

Parameters:collection – one of [CONSOLE, TF_SUMMARIES, NETWORK_OUTPUT]
Returns:a variable dictionary
finalise_output_op()[source]

This function checks the dictionary; if a variable needs to be averaged over devices, a reduce_mean node is added to the graph. It should be called in the ApplicationDriver.create_graph function.
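The interplay between add_to_collection, finalise_output_op and variables can be sketched in plain Python. This is an illustrative toy under stated assumptions, not the NiftyNet class: ToyOutputsCollector is a hypothetical name, plain numbers stand in for tf.Tensors, an arithmetic mean stands in for the reduce_mean node, and the string values chosen for TF_SUMMARIES and NETWORK_OUTPUT are placeholders (only 'niftynetconsole' appears in the signatures above).

```python
CONSOLE = "niftynetconsole"          # default collection name in the signatures above
TF_SUMMARIES = "toy_summaries"       # placeholder value, assumed for this sketch
NETWORK_OUTPUT = "toy_network_out"   # placeholder value, assumed for this sketch


class ToyOutputsCollector:
    """Toy stand-in for OutputsCollector using numbers instead of tensors."""

    def __init__(self, n_devices=1):
        self.n_devices = n_devices
        self._store = {CONSOLE: {}, TF_SUMMARIES: {}, NETWORK_OUTPUT: {}}
        # Values waiting to be averaged, keyed by (collection, name).
        self._pending = {}

    def add_to_collection(self, var, name,
                          average_over_devices=False, collection=CONSOLE):
        if collection not in self._store:
            raise ValueError("unknown collection: %r" % (collection,))
        if average_over_devices:
            # Keep one value per device until finalise_output_op is called.
            self._pending.setdefault((collection, name), []).append(var)
        else:
            self._store[collection][name] = var

    def finalise_output_op(self):
        # Replace each per-device list with its mean (the toy equivalent
        # of adding a reduce_mean node to the graph).
        for (collection, name), values in self._pending.items():
            self._store[collection][name] = sum(values) / len(values)
        self._pending = {}

    def variables(self, collection=CONSOLE):
        return self._store[collection]


collector = ToyOutputsCollector(n_devices=2)
collector.add_to_collection(1.0, "loss", average_over_devices=True)  # device 0
collector.add_to_collection(2.0, "loss", average_over_devices=True)  # device 1
collector.add_to_collection(7, "iteration")
collector.finalise_output_op()
console_vars = collector.variables()
print(console_vars)
```

In this sketch a variable registered with average_over_devices=True only becomes visible through variables() after finalise_output_op has run, which is one reason the finalising step belongs in graph construction rather than in the training loop.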

global_vars_init_or_restore(var_list=None)[source]

For any scope added to the RESTORABLE collection: a variable will be restored from a checkpoint if it exists in the specified checkpoint and no ancestor scope can restore it.
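One reading of this rule can be sketched in plain Python. Everything here is an assumption made for illustration: plan_restore is a hypothetical helper, the restorable mapping (scope name to the set of variable names found in that scope's checkpoint) is an invented structure, and the idea that unrestored variables are left to be initialised is inferred from the function name only.

```python
def plan_restore(var_names, restorable):
    """Decide which scope, if any, restores each variable.

    var_names: full variable names such as "net/conv1/w".
    restorable: hypothetical mapping from a scope name to the set of
        variable names present in that scope's checkpoint.
    Returns (plan, to_initialise), where plan maps a variable name to the
    scope that restores it and to_initialise lists the remaining names.
    """
    plan = {}
    # Sorted order places an ancestor scope ("net") before its
    # descendants ("net/conv1"), so ancestors get the first chance.
    for scope in sorted(restorable):
        names_in_checkpoint = restorable[scope]
        for name in var_names:
            in_scope = name == scope or name.startswith(scope + "/")
            if in_scope and name in names_in_checkpoint and name not in plan:
                plan[name] = scope
    to_initialise = [name for name in var_names if name not in plan]
    return plan, to_initialise


variables = ["net/conv1/w", "net/conv1/b", "net/conv2/w", "head/w"]
restorable = {
    "net": {"net/conv2/w", "net/conv1/b"},
    "net/conv1": {"net/conv1/w", "net/conv1/b"},
}
plan, to_initialise = plan_restore(variables, restorable)
print(plan)
print(to_initialise)
```

Here "net/conv1/b" is present in both checkpoints, and the ancestor scope "net" takes precedence; "head/w" lies under no restorable scope and is left for initialisation.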