Source code for niftynet.engine.handler_model

# -*- coding: utf-8 -*-
"""
This module implements a model checkpoint loader and writer.
"""
import os
import re

import tensorflow as tf

from niftynet.engine.application_variables import global_vars_init_or_restore
from niftynet.engine.signal import \
    ITER_FINISHED, SESS_FINISHED, SESS_STARTED
from niftynet.io.misc_io import touch_folder

FILE_PREFIX = 'model.ckpt'


def make_model_name(model_dir):
    """
    Make the model checkpoint folder.
    The checkpoint files will be located in the `model_dir/models/`
    folder; the filenames will start with FILE_PREFIX.

    :param model_dir: niftynet model folder
    :return: a partial name of a checkpoint file
        `model_dir/models/FILE_PREFIX`
    """
    _model_dir = touch_folder(os.path.join(model_dir, 'models'))
    return os.path.join(_model_dir, FILE_PREFIX)
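
# A minimal usage sketch ('/tmp/demo' is a hypothetical folder, not part of
# NiftyNet):
#
#     >>> make_model_name('/tmp/demo')  # creates '/tmp/demo/models/' if needed
#     '/tmp/demo/models/model.ckpt'
#
# tf.train.Saver appends '-<global_step>' to this prefix when writing,
# producing checkpoint files such as 'model.ckpt-100.index' and
# 'model.ckpt-100.meta'.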


class ModelRestorer(object):
    """
    This class handles restoring the model at the beginning of a session.
    """

    def __init__(self,
                 model_dir,
                 initial_iter=0,
                 is_training_action=True,
                 vars_to_restore=None,
                 **_unused):
        self.initial_iter = initial_iter
        self.vars_to_restore = vars_to_restore
        self.file_name_prefix = make_model_name(model_dir)
        # randomly initialise or restore the model
        if is_training_action and initial_iter == 0:
            SESS_STARTED.connect(self.rand_init_model)
        else:
            SESS_STARTED.connect(self.restore_model)

    def rand_init_model(self, _sender, **_unused):
        """
        Randomly initialising all trainable variables defined in
        the default session.

        :param _sender:
        :param _unused:
        :return:
        """
        with tf.name_scope('Initialisation'):
            init_op = global_vars_init_or_restore()
        tf.get_default_session().run(init_op)
        tf.logging.info('Parameters from random initialisations ...')

    def restore_model(self, _sender, **_unused):
        """
        Loading checkpoint files as variable initialisations.

        :param _sender:
        :param _unused:
        :return:
        """
        checkpoint = '{}-{}'.format(self.file_name_prefix, self.initial_iter)
        to_restore = None  # tf.train.Saver's default value, restoring all
        if self.vars_to_restore:
            # partially restore (updating the `to_restore` list)
            tf.logging.info("Finding variables to restore...")
            # determine which variables to restore using regex matching
            var_regex = re.compile(self.vars_to_restore)
            to_restore, to_randomise = [], []
            for restorable in tf.global_variables():
                if var_regex.search(restorable.name):
                    to_restore.append(restorable)
                else:
                    to_randomise.append(restorable)
            if not to_restore:
                tf.logging.fatal(
                    'vars_to_restore specified: %s, but nothing matched.',
                    self.vars_to_restore)
            assert to_restore, 'Nothing to restore (--vars_to_restore)'
            var_names = [  # take the first three items to print
                var_restore.name for var_restore in to_restore[:3]]
            tf.logging.info(
                'Restoring %s out of %s variables from %s: \n%s, ...',
                len(to_restore), len(tf.global_variables()),
                checkpoint, ',\n'.join(var_names))
            # randomly initialise the variables that will not be restored
            init_op = tf.variables_initializer(to_randomise)
            tf.get_default_session().run(init_op)
        try:
            saver = tf.train.Saver(
                var_list=to_restore, save_relative_paths=True)
            saver.restore(tf.get_default_session(), checkpoint)
        except tf.errors.NotFoundError:
            tf.logging.fatal(
                'checkpoint %s not found or variables to restore do not '
                'match the current application graph', checkpoint)
            dir_name = os.path.dirname(checkpoint)
            if dir_name and not os.path.exists(dir_name):
                tf.logging.fatal(
                    'Model folder not found %s, please check '
                    'config parameter: model_dir', dir_name)
            raise
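
# A sketch of partial restoring via ``vars_to_restore`` (the folder, variable
# names and regex below are hypothetical): in a graph holding variables
# 'Encoder/conv_1/w', 'Encoder/conv_1/b' and 'Classifier/fc/w', creating
#
#     ModelRestorer(model_dir='/tmp/demo', initial_iter=100,
#                   vars_to_restore='^Encoder')
#
# restores the two 'Encoder/*' variables from 'model.ckpt-100' (re.search
# matches the name prefix) and randomly initialises 'Classifier/fc/w',
# a typical fine-tuning configuration.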


class ModelSaver(object):
    """
    This class handles iteration events to save the model as checkpoint files.
    """

    def __init__(self,
                 model_dir,
                 save_every_n=0,
                 max_checkpoints=1,
                 is_training_action=True,
                 **_unused):
        self.save_every_n = save_every_n
        self.max_checkpoints = max_checkpoints
        self.file_name_prefix = make_model_name(model_dir)

        self.saver = None
        # initialise the saver after the graph is finalised
        SESS_STARTED.connect(self.init_saver)
        # save the training model at a positive frequency
        if self.save_every_n > 0:
            ITER_FINISHED.connect(self.save_model_interval)
        # always save the final training model before exiting
        if is_training_action:
            SESS_FINISHED.connect(self.save_model)

    def init_saver(self, _sender, **_unused):
        """
        Initialise a model saver.

        :param _sender:
        :param _unused:
        :return:
        """
        self.saver = tf.train.Saver(
            max_to_keep=self.max_checkpoints, save_relative_paths=True)

    def save_model(self, _sender, **msg):
        """
        Saving the model at the current iteration.

        :param _sender:
        :param msg: an iteration message instance
        :return:
        """
        iter_i = msg['iter_msg'].current_iter
        if iter_i >= 0:
            self._save_at(iter_i)

    def save_model_interval(self, _sender, **msg):
        """
        Saving the model according to the frequency of ``save_every_n``.

        :param _sender:
        :param msg: an iteration message instance
        :return:
        """
        if not msg['iter_msg'].is_training:
            return
        iter_i = msg['iter_msg'].current_iter
        if iter_i > 0 and iter_i % self.save_every_n == 0:
            self._save_at(iter_i)

    def _save_at(self, iter_i):
        """
        Saving the model at iteration ``iter_i`` and printing a console log.

        :param iter_i: integer of the current iteration
        :return:
        """
        if not self.saver:
            return
        self.saver.save(sess=tf.get_default_session(),
                        save_path=self.file_name_prefix,
                        global_step=iter_i)
        tf.logging.info('iter %d saved: %s', iter_i, self.file_name_prefix)
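
# A minimal end-to-end sketch of the intended wiring (the folder and the
# variable are hypothetical; in a real run the application driver, not user
# code, emits SESS_STARTED / ITER_FINISHED / SESS_FINISHED):
#
#     import tensorflow as tf
#     from niftynet.engine.signal import SESS_STARTED
#     from niftynet.engine.handler_model import ModelRestorer, ModelSaver
#
#     tf.Variable(0.0, name='demo_var')  # any graph variables to checkpoint
#     saver = ModelSaver('/tmp/demo', save_every_n=10)    # hooks ITER_FINISHED
#     restorer = ModelRestorer('/tmp/demo')               # random init, iter 0
#     with tf.Session():
#         SESS_STARTED.send('demo')  # triggers init_saver and rand_init_model
#
# Keeping `saver` and `restorer` referenced matters: the signals connect
# bound methods, so the handlers live only as long as these objects do.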