Source code for niftynet.utilities.restore_initializer

# -*- coding: utf-8 -*-

# Copyright 2018 The Sonnet Authors. All Rights Reserved.
# Modifications copyright 2018 The NiftyNet Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""A checkpoint-restoring Tensorflow initializer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
import tensorflow as tf
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import io_ops


class _Restore(init_ops.Initializer):
    """Initializer that restores tensors from a checkpoint."""

    def __init__(self, filename, var_name, scope=None):
        """Construct a new restoring initializer.
        Will read from the checkpoint from the SSTables file `filename` using
        the RestoreV2 Tensorflow op.
        The actual variable read from the checkpoint will be
        `scope_name` + '/' + `var_name` (or just `var_name` if `scope_name` is
        empty), where `scope_name` is given by one of
        (1) The current scope's name at the point where the initializer gets called,
            if the `scope` argument to this constructor is None,
        (2) If `scope` is callable, the result of applying it to the current scope's
            name,
        (3) Otherwise, the `scope` argument to this constructor itself.
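
        For example (illustrative only), passing
        `scope=lambda name: name.replace('tower_0/', '')` would strip a
        replica prefix, so a variable created under `tower_0/conv_1` is
        restored from `conv_1/<var_name>` in the checkpoint.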

        Args:
          filename: Name of an SSTables entry where the checkpoint is hosted.
          var_name: Name of the variable to restore.
          scope: The variable scope's name of the variable to restore; see
            above.
        """
        self._filename = filename
        self._var_name = var_name
        self._scope = scope

    def _partition_spec(self, shape, partition_info):
        """Build magic (and sparsely documented) shapes_and_slices spec string."""
        if partition_info is None:
            return ''  # Empty string indicates a non-partitioned tensor.
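        # For a partitioned variable, the spec string encodes the full shape
        # followed by the slice held by this partition, along the lines of
        # '4 5 0,2:0,5' for a [2, 5] slice at offset [0, 0] of a [4, 5]
        # variable (format assumed from tf.Variable.SaveSliceInfo.spec).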
        ssi = tf.Variable.SaveSliceInfo(
            full_name=self._var_name,
            full_shape=partition_info.full_shape,
            var_offset=partition_info.var_offset,
            var_shape=shape)
        return ssi.spec

    def __call__(self, shape, dtype=None, partition_info=None):
        # Creating a separate RestoreV2 op for each tensor, when a single op
        # could output several tensors, seems inefficient; however, this is
        # also what tf.train.Saver does (via BaseSaverBuilder.restore_op).
        if self._scope is None:
            scope_name = tf.get_variable_scope().name
        elif callable(self._scope):
            scope_name = self._scope(tf.get_variable_scope().name)
        else:
            scope_name = self._scope
        tensor_name = self._var_name
        if scope_name:
            tensor_name = '{}/{}'.format(scope_name, tensor_name)
        tensor = io_ops.restore_v2(
            self._filename,
            [tensor_name],
            [self._partition_spec(shape, partition_info)],
            [dtype])[0]
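        # restore_v2 returns a tensor of statically unknown shape, so the
        # requested shape is attached before handing the tensor back to the
        # variable creator.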
        tensor.set_shape(shape)
        return tensor


# pylint: disable=invalid-name
restore_initializer = _Restore
# pylint: enable=invalid-name
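

# Usage sketch (illustrative; the checkpoint path and variable/scope names
# below are hypothetical). The initializer is passed to `tf.get_variable`,
# and the matching tensor is read from the checkpoint when the variable's
# initializer op runs:
#
#     init = restore_initializer(
#         filename='/path/to/model.ckpt', var_name='w')
#     with tf.variable_scope('conv_1'):
#         # Reads the tensor named 'conv_1/w' from the checkpoint.
#         w = tf.get_variable('w', shape=[3, 3, 1, 16], initializer=init)
#     with tf.Session() as sess:
#         sess.run(w.initializer)  # executes the underlying RestoreV2 op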