Source code for niftynet.engine.handler_tensorboard

# -*- coding: utf-8 -*-
"""
This module implements a TensorBoard log writer.
"""
import os

import tensorflow as tf

from niftynet.engine.application_variables import TF_SUMMARIES
from niftynet.engine.signal import \
    TRAIN, VALID, ITER_STARTED, ITER_FINISHED, GRAPH_CREATED
from niftynet.io.misc_io import get_latest_subfolder


class TensorBoardLogger(object):
    """
    This class handles iteration events to log summaries to
    the TensorBoard log.

    Summary writers are created on ``GRAPH_CREATED``, summary ops are
    requested from the application on ``ITER_STARTED``, and results are
    written on ``ITER_FINISHED`` every ``tensorboard_every_n`` iterations.
    """

    def __init__(self,
                 model_dir=None,
                 initial_iter=0,
                 tensorboard_every_n=0,
                 **_unused):
        """
        :param model_dir: model directory; summaries are stored under
            ``<model_dir>/logs``. When ``None``, TensorBoard logging is
            disabled instead of failing.
        :param initial_iter: starting iteration number; a new summary
            subfolder is requested only when this is 0 (not finetuning).
        :param tensorboard_every_n: write a summary every n iterations;
            values <= 0 disable logging.
        :param _unused: extra keyword arguments, ignored.
        """
        self.tensorboard_every_n = tensorboard_every_n
        self.summary_dir = self._resolve_summary_dir(model_dir, initial_iter)
        self.writer_train = None
        self.writer_valid = None

        GRAPH_CREATED.connect(self.init_writer)
        ITER_STARTED.connect(self.read_tensorboard_op)
        ITER_FINISHED.connect(self.write_tensorboard)

    @staticmethod
    def _resolve_summary_dir(model_dir, initial_iter):
        """
        Return the summary folder under ``<model_dir>/logs``, or ``None``
        when no model directory is given.

        :param model_dir: model directory or ``None``
        :param initial_iter: starting iteration number; a new subfolder is
            requested only when it equals 0 (i.e. not finetuning)
        :return: summary directory path, or ``None`` when disabled
        """
        if model_dir is None:
            # os.path.join(None, ...) raises TypeError; a missing model
            # directory simply means there is nowhere to write logs.
            return None
        # creating new summary subfolder if it's not finetuning
        return get_latest_subfolder(
            os.path.join(model_dir, 'logs'),
            create_new=initial_iter == 0)

    def init_writer(self, _sender, **_unused_msg):
        """
        Initialise summary writers for training and validation.

        Does nothing when there is no summary directory or when
        ``tensorboard_every_n <= 0``.

        :param _sender: signal sender (unused)
        :param _unused_msg: signal payload (unused)
        :return:
        """
        if not self.summary_dir or self.tensorboard_every_n <= 0:
            return
        graph = tf.get_default_graph()
        self.writer_train = tf.summary.FileWriter(
            os.path.join(self.summary_dir, TRAIN), graph)
        self.writer_valid = tf.summary.FileWriter(
            os.path.join(self.summary_dir, VALID), graph)

    def read_tensorboard_op(self, sender, **msg):
        """
        Get TensorBoard summary_op from application at the beginning of
        each iteration.

        The ops are attached to ``iter_msg.ops_to_run[TF_SUMMARIES]`` only
        for non-inference iterations on which a summary is due.

        :param sender: a niftynet.application instance
        :param msg: should contain an IterationMessage instance
            under the key ``'iter_msg'``
        """
        _iter_msg = msg['iter_msg']
        if _iter_msg.is_inference:
            return
        if not self._is_writing(_iter_msg.current_iter):
            return
        tf_summary_ops = sender.outputs_collector.variables(TF_SUMMARIES)
        _iter_msg.ops_to_run[TF_SUMMARIES] = tf_summary_ops

    def write_tensorboard(self, _sender, **msg):
        """
        Write to tensorboard when received the iteration finished signal.

        Training iterations go to the training writer and validation
        iterations to the validation writer; any other iteration type is
        ignored.

        :param _sender: signal sender (unused)
        :param msg: should contain an IterationMessage instance
            under the key ``'iter_msg'``
        """
        _iter_msg = msg['iter_msg']
        if not self._is_writing(_iter_msg.current_iter):
            return
        if _iter_msg.is_training:
            _iter_msg.to_tf_summary(self.writer_train)
        elif _iter_msg.is_validation:
            _iter_msg.to_tf_summary(self.writer_valid)

    def _is_writing(self, c_iter):
        """
        Decide whether to save a TensorBoard log entry for a given iteration.

        :param c_iter: Integer of the current iteration number
        :return: boolean True if is writing at the current iteration
        """
        if self.writer_valid is None or self.writer_train is None:
            return False
        if not self.summary_dir:
            return False
        if self.tensorboard_every_n <= 0:
            # explicit guard against modulo-by-zero if the writers were
            # set up by other means
            return False
        return c_iter % self.tensorboard_every_n == 0