niftynet.engine.application_variables module

Managing tf.Tensor variables for initialising/evaluating networks

class GradientsCollector(n_devices=1)

Bases: object

This collector holds a list of all gradients collected while constructing the tf graph. Gradients from multiple GPUs are averaged, and the averaged op is added to the graph by the application driver.


Adds the gradients generated by optimiser.compute_gradients to a dictionary, to be retrieved during training. The gradients can be a list of model updates; an application can implement the set_iteration_update() interface to specify how the gradients are used at each iteration.

Parameters: gradients – generated by optimiser.compute_gradients(loss)

This function returns the gradients averaged over devices; it is used by the application driver.

Returns: averaged gradients over devices
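The averaging step can be pictured with a minimal pure-Python sketch. It assumes each device contributes a list of (gradient, variable_name) pairs in the same variable order, roughly what optimiser.compute_gradients returns per device, with plain floats standing in for tf.Tensor values. The function name average_gradients is hypothetical and not part of NiftyNet's API.

```python
from typing import List, Tuple

# Plain floats stand in for gradient tensors; each pair is (gradient, variable_name).
GradVar = Tuple[float, str]


def average_gradients(per_device: List[List[GradVar]]) -> List[GradVar]:
    """Average gradients variable-by-variable across devices (illustrative sketch)."""
    if not per_device:
        return []
    if len({len(grads) for grads in per_device}) != 1:
        raise ValueError("every device must report the same number of gradients")
    n_devices = len(per_device)
    averaged = []
    # zip(*per_device) groups the i-th (gradient, variable) pair from every device.
    for pairs in zip(*per_device):
        names = {name for _, name in pairs}
        if len(names) != 1:
            raise ValueError(f"devices disagree on variable order: {sorted(names)}")
        mean_grad = sum(grad for grad, _ in pairs) / n_devices
        averaged.append((mean_grad, pairs[0][1]))
    return averaged


# Two devices, each reporting gradients for the same two variables.
device_0 = [(1.0, "conv/w"), (4.0, "conv/b")]
device_1 = [(3.0, "conv/w"), (0.0, "conv/b")]
print(average_gradients([device_0, device_1]))
# → [(2.0, 'conv/w'), (2.0, 'conv/b')]
```

In a TensorFlow graph the per-variable mean would more likely be built with tf.reduce_mean over the stacked per-device gradients than with a Python sum; the sketch only shows the pairing and averaging logic.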

class OutputsCollector(n_devices=1)

Bases: object

Collects all tf.Tensor objects to be evaluated. These objects are grouped into:

NETWORK_OUTPUT: to be decoded by an aggregator
CONSOLE: to be printed on command line
TF_SUMMARIES: to be added to tensorboard visualisation
add_to_collection(var, name, average_over_devices=False, collection='niftynetconsole', summary_type='scalar')

Adds tf.Tensors to be evaluated to a dictionary. The dictionaries are retrieved and evaluated by the application driver in the train/infer loops.

Parameters:
  • var – tf.Tensor to be evaluated by tf.Session()
  • name – name of the variable (for display purposes)
  • average_over_devices – whether the tensor should be averaged over devices
  • collection – one of [CONSOLE, TF_SUMMARIES, NETWORK_OUTPUT]
  • summary_type – if adding to TF_SUMMARIES, how the Tensor value is visualised; see SUPPORTED_SUMMARY for the possible options
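The grouping described above can be sketched as a nested dictionary keyed first by collection and then by display name. This is a simplified, hypothetical model rather than NiftyNet's implementation: plain Python values replace tensors, only the CONSOLE value 'niftynetconsole' comes from the signature above, and the other two collection strings and the contents of SUPPORTED_SUMMARY are assumed placeholders.

```python
from collections import defaultdict
from typing import Any, Dict, Set, Tuple

# Only CONSOLE's value appears in the documented signature; the other two
# strings and the summary types below are hypothetical placeholders.
CONSOLE = "niftynetconsole"
TF_SUMMARIES = "tf_summaries"      # hypothetical value
NETWORK_OUTPUT = "network_output"  # hypothetical value
SUPPORTED_SUMMARY = {"scalar", "histogram", "image"}  # assumed subset


class OutputsCollectorSketch:
    """Simplified stand-in for OutputsCollector (not NiftyNet's code)."""

    def __init__(self, n_devices: int = 1):
        self.n_devices = n_devices
        # collection -> {display name: value to evaluate}
        self._vars: Dict[str, Dict[str, Any]] = defaultdict(dict)
        # (collection, name) pairs flagged for averaging over devices
        self._to_average: Set[Tuple[str, str]] = set()
        # display name -> summary type, only for TF_SUMMARIES entries
        self._summary_types: Dict[str, str] = {}

    def add_to_collection(self, var, name, average_over_devices=False,
                          collection=CONSOLE, summary_type="scalar"):
        if collection not in (CONSOLE, TF_SUMMARIES, NETWORK_OUTPUT):
            raise ValueError(f"unknown collection: {collection}")
        if collection == TF_SUMMARIES and summary_type not in SUPPORTED_SUMMARY:
            raise ValueError(f"unsupported summary type: {summary_type}")
        self._vars[collection][name] = var
        if average_over_devices:
            self._to_average.add((collection, name))
        if collection == TF_SUMMARIES:
            self._summary_types[name] = summary_type

    def variables(self, collection=CONSOLE):
        # Return a copy so callers cannot mutate the internal store.
        return dict(self._vars.get(collection, {}))


collector = OutputsCollectorSketch()
collector.add_to_collection(0.25, "loss")
collector.add_to_collection([0.1, 0.9], "seg", collection=NETWORK_OUTPUT)
print(collector.variables())  # → {'loss': 0.25}
```

Keying by collection first keeps retrieval for one loop (console printing, TensorBoard summaries, or aggregator decoding) a single dictionary lookup.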


Gets the tf.Tensors to be evaluated by tf.Session().run().

Parameters: collection – one of [CONSOLE, TF_SUMMARIES, NETWORK_OUTPUT]
Returns: a variable dictionary

This function checks the dictionary: if a variable needs to be averaged over devices, a reduce_mean node is added to the graph. It should be called in the ApplicationDriver.create_graph function.
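The finalisation step can be illustrated with a small pure-Python sketch. It assumes a hypothetical layout in which each display name maps to one value per device: names flagged for averaging are collapsed to their mean (standing in for the reduce_mean node), and all other names are left as per-device lists. The helper finalise_outputs is illustrative only, and how NiftyNet treats non-averaged multi-device outputs is not shown here.

```python
from statistics import mean
from typing import Dict, List, Set


def finalise_outputs(per_device: Dict[str, List[float]],
                     to_average: Set[str]) -> Dict[str, object]:
    """Collapse per-device values for flagged names into a single mean.

    Hypothetical helper: names in to_average are reduced with a mean,
    standing in for a reduce_mean graph node; all other names are
    returned unchanged as per-device lists.
    """
    finalised = {}
    for name, values in per_device.items():
        if name in to_average:
            finalised[name] = mean(values)
        else:
            finalised[name] = list(values)
    return finalised


outputs = {"loss": [0.5, 0.25], "dice": [0.8, 0.6]}
print(finalise_outputs(outputs, to_average={"loss"}))
# → {'loss': 0.375, 'dice': [0.8, 0.6]}
```

Doing this once, after all device towers are built, is why the averaging has to happen during graph creation rather than inside the train/infer loops.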


For any scope added to the RESTORABLE collection, a variable is restored from a checkpoint if it exists in the specified checkpoint and no ancestor scope can restore it.
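The precedence rule above can be expressed as an outermost-first search over the restorable scopes. This is a sketch under stated assumptions, not NiftyNet's implementation: the hypothetical plan_restore helper takes a mapping from each restorable scope to the set of variable names found in its checkpoint, assumes checkpoint names match full graph variable names, and returns None for variables that would be initialised instead of restored.

```python
from typing import Dict, List, Optional, Set


def _is_under(var_name: str, scope: str) -> bool:
    """True if var_name lives inside scope ('/'-separated scope names)."""
    return var_name.startswith(scope.rstrip("/") + "/")


def plan_restore(variables: List[str],
                 restorable: Dict[str, Set[str]]) -> Dict[str, Optional[str]]:
    """Map each variable to the scope whose checkpoint restores it.

    restorable maps a scope name to the variable names present in that
    scope's checkpoint (hypothetical representation). Scopes are tried
    outermost first, so an ancestor that can restore a variable takes
    precedence over its descendants. None means "initialise instead".
    """
    # Fewer '/' separators means closer to the root of the scope tree.
    ordered = sorted(restorable, key=lambda s: s.strip("/").count("/"))
    plan: Dict[str, Optional[str]] = {}
    for var in variables:
        plan[var] = None
        for scope in ordered:
            if _is_under(var, scope) and var in restorable[scope]:
                plan[var] = scope
                break
    return plan


checkpoints = {
    "net": {"net/encoder/w", "net/decoder/b"},
    "net/decoder": {"net/decoder/w", "net/decoder/b"},
}
names = ["net/encoder/w", "net/decoder/w", "net/decoder/b", "net/head/b"]
print(plan_restore(names, checkpoints))
# → {'net/encoder/w': 'net', 'net/decoder/w': 'net/decoder', 'net/decoder/b': 'net', 'net/head/b': None}
```

Here net/decoder/b exists in both checkpoints, and the ancestor scope net wins; net/decoder/w is only found in the nested scope's checkpoint, so that checkpoint restores it.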