niftynet.engine.application_variables module
Managing tf.Tensor variables for initialising/evaluating networks
class niftynet.engine.application_variables.GradientsCollector(n_devices=1)
Bases: object
This collector holds a list of all gradients, collected while the tf graph is constructed. Gradients from multiple GPUs are averaged, and the application driver adds the averaging op to the graph.
add_to_collection(gradients)
Add the gradients generated by optimiser.compute_gradients to a dictionary, from which they are retrieved during training in training_op. The gradients can be a list of model updates; an application can implement training_ops() to specify how the gradients are used at each iteration.
Parameters: gradients – generated by optimiser.compute_gradients(loss)
Returns:
gradients
This function returns the averaged gradient over devices used by the application driver.
Returns: averaged gradients over devices
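The averaging step described for GradientsCollector can be sketched in plain Python. This is only an illustration under stated assumptions, not NiftyNet's implementation: the function name average_gradients is hypothetical, plain floats stand in for gradient tensors, and each device is assumed to report (gradient, variable name) pairs in the same variable order.

```python
def average_gradients(per_device_grads):
    """Average (gradient, variable_name) pairs across devices.

    per_device_grads holds one list per device; every list contains
    (gradient_value, variable_name) pairs in the same variable order.
    Hypothetical helper for illustration only.
    """
    if not per_device_grads:
        raise ValueError("no gradients collected")
    n_devices = len(per_device_grads)
    averaged = []
    # zip(*...) groups the pairs that belong to the same variable.
    for pairs in zip(*per_device_grads):
        names = {name for _, name in pairs}
        if len(names) != 1:
            raise ValueError("devices disagree on variable order")
        mean_grad = sum(grad for grad, _ in pairs) / n_devices
        averaged.append((mean_grad, pairs[0][1]))
    return averaged


# Two simulated devices, each with gradients for the same two variables.
gpu_0 = [(0.25, "conv/w"), (1.0, "conv/b")]
gpu_1 = [(0.75, "conv/w"), (3.0, "conv/b")]
result = average_gradients([gpu_0, gpu_1])
```

With a single device the averaged list equals that device's own list, which matches the default n_devices=1 case.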
class niftynet.engine.application_variables.OutputsCollector(n_devices=1)
Bases: object
Collects all tf.Tensor objects to be evaluated by tf.Session.run(). These objects are grouped into:
NETORK_OUTPUT: to be decoded by an aggregator
CONSOLE: to be printed on command line
TF_SUMMARIES: to be added to tensorboard visualisation
add_to_collection(var, name, average_over_devices=False, collection='niftynetconsole', summary_type='scalar')
Add tf.Tensors to be evaluated to a dictionary. The dictionaries will be retrieved and evaluated by the application driver in the train/infer loops.
Parameters:
var – tf.Tensor to be evaluated by tf.Session()
name – name of the variable (for display purposes)
average_over_devices – whether the variable should be averaged over devices (see finalise_output_op)
collection – one of [CONSOLE, TF_SUMMARIES, NETORK_OUTPUT]
summary_type – if adding to TF_SUMMARIES, there are a few possible ways to visualise the Tensor value; see SUPPORTED_SUMMARY
Returns:
finalise_output_op()
This function checks the dictionary; if a variable needs to be averaged over devices, a reduce_mean node is added to the graph. It should be called in the ApplicationDriver.create_graph function.
variables(collection='niftynetconsole')
Get tf.Tensors to be evaluated by tf.Session().run().
Parameters: collection – one of [CONSOLE, TF_SUMMARIES, NETORK_OUTPUT]
Returns: a variable dictionary
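The collect, finalise, retrieve cycle of OutputsCollector can be illustrated with a small stand-in class. This is a sketch under assumptions rather than the library's code: ToyOutputsCollector is a hypothetical name, plain numbers take the place of tf.Tensor objects, an arithmetic mean replaces the reduce_mean graph node, and keeping a per-device list for non-averaged entries is a choice made for this example only.

```python
class ToyOutputsCollector:
    """Hypothetical stand-in; plain numbers play the role of tf.Tensors."""

    def __init__(self, n_devices=1):
        # Kept only to mirror the documented constructor signature.
        self.n_devices = n_devices
        # collection name -> variable name -> {"values", "average"}
        self._store = {}

    def add_to_collection(self, var, name, average_over_devices=False,
                          collection="niftynetconsole"):
        entries = self._store.setdefault(collection, {})
        entry = entries.setdefault(
            name, {"values": [], "average": average_over_devices})
        entry["values"].append(var)  # one value per device

    def finalise_output_op(self):
        # Replace per-device values with their mean where requested.
        for entries in self._store.values():
            for entry in entries.values():
                if entry["average"]:
                    values = entry["values"]
                    entry["values"] = sum(values) / len(values)
                    entry["average"] = False  # makes a second call a no-op

    def variables(self, collection="niftynetconsole"):
        entries = self._store.get(collection, {})
        return {name: entry["values"] for name, entry in entries.items()}


collector = ToyOutputsCollector(n_devices=2)
collector.add_to_collection(1.0, "loss", average_over_devices=True)  # device 0
collector.add_to_collection(3.0, "loss", average_over_devices=True)  # device 1
collector.add_to_collection(7, "step")
collector.finalise_output_op()
console = collector.variables()
```

Here console maps "loss" to the mean over the two simulated devices, while "step" keeps its single recorded value.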
niftynet.engine.application_variables.global_vars_init_or_restore(var_list=None)
For any scope added to the RESTORABLE collection: a variable will be restored from a checkpoint if it exists in the specified checkpoint and no scope ancestor can restore it.
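One possible reading of the restore rule is sketched below in plain Python. Everything here is an assumption made for illustration: the function plan_restore, the (scope, checkpoint, checkpoint scope) tuple layout, and the dictionary of checkpoint contents are hypothetical and are not taken from NiftyNet. The sketch visits scopes in sorted order so that an ancestor scope is tried before its descendants, and a variable that no scope can restore is left for ordinary initialisation.

```python
def plan_restore(variable_names, restorable, checkpoints):
    """Decide which variables to restore and which to initialise.

    restorable:  list of (scope, checkpoint_id, checkpoint_scope) tuples
                 (hypothetical layout for this sketch).
    checkpoints: dict mapping checkpoint_id to the set of variable
                 names stored in that checkpoint.
    Returns (plan, to_initialise), where plan maps a variable name to
    (checkpoint_id, name_in_checkpoint).
    """
    plan = {}
    # A scope such as "net" sorts before "net/conv", so ancestors go first.
    for scope, ckpt, ckpt_scope in sorted(restorable):
        prefix = scope + "/"
        for name in variable_names:
            if not name.startswith(prefix) or name in plan:
                continue  # out of scope, or an ancestor already restores it
            stored = ckpt_scope + "/" + name[len(prefix):]
            if stored in checkpoints.get(ckpt, set()):
                plan[name] = (ckpt, stored)
    to_initialise = [n for n in variable_names if n not in plan]
    return plan, to_initialise


names = ["net/conv/w", "net/conv/b", "net/fc/w"]
restorable = [("net/conv", "ckpt_b", "old/conv"), ("net", "ckpt_a", "model")]
checkpoints = {
    "ckpt_a": {"model/conv/w"},
    "ckpt_b": {"old/conv/w", "old/conv/b"},
}
plan, to_initialise = plan_restore(names, restorable, checkpoints)
```

In this example "net/conv/w" is taken from the ancestor scope's checkpoint, "net/conv/b" falls through to the "net/conv" scope because the ancestor's checkpoint does not contain it, and "net/fc/w" is initialised normally.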