# Source code for niftynet.layer.base_layer

# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function

from abc import ABCMeta, abstractmethod

import numpy as np
import tensorflow as tf
from six import with_metaclass

from niftynet.engine.application_variables import RESTORABLE

class Invertible(with_metaclass(ABCMeta, object)):
    """
    Abstract interface for invertible data operations.

    Concrete subclasses are required (via ``ABCMeta``) to provide
    ``inverse_op``; the class cannot be instantiated otherwise.
    """

    @abstractmethod
    def inverse_op(self, *args, **kwargs):
        """Apply the inverse operation. Abstract: always raises here."""
        raise NotImplementedError()
class Layer(object):
    """
    Base class of a layer: the computation is defined in ``layer_op``.

    The constructor wraps ``layer_op`` with ``tf.make_template`` so that
    every call of the layer object runs ``layer_op`` inside the same
    template (and therefore the same variable scope).
    """

    def __init__(self, name='untitled_op'):
        """
        :param name: name of the layer, also used as the template /
            variable-scope name
        """
        self.name = name
        # create_scope_now_=True makes tf.make_template open the variable
        # scope at construction time instead of at the first call.
        self._op = tf.make_template(name, self.layer_op,
                                    create_scope_now_=True)

    def layer_op(self, *args, **kwargs):
        """
        Define the layer's computation; subclasses must override this.

        :raises NotImplementedError: always, in this base class
        """
        msg = 'method \'layer_op\' in \'{}\''.format(type(self).__name__)
        tf.logging.fatal(msg)
        raise NotImplementedError(msg)

    def __call__(self, *args, **kwargs):
        # Forward to the template so the layer's scope is reused.
        return self._op(*args, **kwargs)

    def __str__(self):
        return self.to_string()

    def layer_scope(self):
        """Return the variable scope owned by this layer's template."""
        return self._op.variable_scope

    def to_string(self):
        """
        Return a colour-coded one-line summary of the layer.

        The summary is tagged ``(input undecided)`` while the template
        reports that its variables have not been created yet.
        """
        layer_scope_name = self.layer_scope().name
        out_str = "\033[42m[Layer]\033[0m {}".format(layer_scope_name)
        # _variables_created is a private attribute of the TF template;
        # presumably False until the template has been called once.
        if not self._op._variables_created:
            out_str += ' \033[46m(input undecided)\033[0m'
        return out_str
class TrainableLayer(Layer):
    """
    Extends the Layer object to have trainable parameters,
    adding initializers and regularizers.
    """

    def __init__(self, name='trainable_op'):
        super(TrainableLayer, self).__init__(name=name)
        # Both remain None until assigned through the property setters.
        self._initializers = None
        self._regularizers = None

    def trainable_variables(self):
        """
        Return the trainable variables that live under this layer's scope.

        :return: the ``TRAINABLE_VARIABLES`` collection filtered by the
            layer's variable-scope name
        """
        return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                 self.layer_scope().name)

    def restore_from_checkpoint(self, checkpoint_name, scope=None):
        """
        Register this layer for restoration from a checkpoint.

        Only appends ``(layer_scope_name, checkpoint_name, scope)`` to the
        ``RESTORABLE`` collection; the actual restore is not done here.

        :param checkpoint_name: checkpoint to restore from
        :param scope: scope name to use for the restore (presumably the
            scope inside the checkpoint); defaults to this layer's scope name
        """
        if scope is None:
            scope = self.layer_scope().name
        tf.add_to_collection(
            RESTORABLE, (self.layer_scope().name, checkpoint_name, scope))

    def regularizer_loss(self):
        """
        Return the regularisation losses registered under this layer's scope.

        :return: the ``REGULARIZATION_LOSSES`` collection filtered by the
            layer's variable-scope name
        """
        return tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                                 self.layer_scope().name)

    def num_trainable_params(self):
        """
        Return the total number of scalar parameters in this layer.

        :return: int, sum of the element counts of ``trainable_variables()``
        """
        n = tf.Dimension(0)
        for x in self.trainable_variables():
            # num_elements() is None for a partially-known shape; Dimension
            # arithmetic then yields an unknown dimension, and int() on it
            # presumably fails -- TODO confirm against the TF version used.
            n += x.get_shape().num_elements()
        return int(n)

    def to_string(self):
        """
        Return the Layer summary extended with trainable-variable info.

        Appends the variable names (relative to the layer's scope, with the
        ``:N`` output-index suffix removed) and the parameter count.
        """
        out_str = Layer.to_string(self)
        # try to add trainable variable info to the string
        layer_variables = self.trainable_variables()
        if not layer_variables:
            return out_str
        # strip "<layer_scope>/" from each variable name
        scope_prefix_len = len(self.layer_scope().name) + 1
        # including name of parameters
        out_str += ' \033[92m[Trainable]\033[0m '
        out_str += ', '.join(
            [v.name.split(':')[0][scope_prefix_len:]
             for v in layer_variables])
        # including number of parameters
        out_str += ' ({})'.format(self.num_trainable_params())
        return out_str

    @property
    def initializers(self):
        """Dict of initializers, or None if never assigned."""
        return self._initializers

    @property
    def regularizers(self):
        """Dict of regularizers, or None if never assigned."""
        return self._regularizers

    @initializers.setter
    def initializers(self, value):
        assert isinstance(value, dict), \
            'initializers must be a dict, got {}'.format(type(value).__name__)
        self._initializers = value

    @regularizers.setter
    def regularizers(self, value):
        assert isinstance(value, dict), \
            'regularizers must be a dict, got {}'.format(type(value).__name__)
        self._regularizers = value
class DataDependentLayer(Layer):
    """
    Abstract base for layers whose internal model is determined by a
    one-pass training through the training set.

    Subclasses provide ``train`` to fit that internal model and
    ``is_ready`` to report its status.
    """

    def __init__(self, name='data_dependent_op'):
        super(DataDependentLayer, self).__init__(name=name)

    def is_ready(self):
        """Report whether the internal model is available. Abstract."""
        raise NotImplementedError()

    def train(self, *args, **kwargs):
        """Fit the internal model. Abstract."""
        raise NotImplementedError()
class RandomisedLayer(Layer):
    """
    Abstract base for layers that need a randomisation step, which changes
    some of the layer's states on the fly.

    Subclasses provide ``randomise`` to perform that step.
    """

    def __init__(self, name='randomised_op'):
        super(RandomisedLayer, self).__init__(name=name)

    def randomise(self, *args, **kwargs):
        """Re-draw the layer's random state. Abstract."""
        raise NotImplementedError()
class LayerFromCallable(Layer):
    """
    Module wrapping a function provided by the user.
    Analogous to snt.Module
    """

    def __init__(self, layer_op, name='from_callable_op'):
        """
        :param layer_op: callable used as the layer's computation
        :param name: name of the layer / template
        :raises TypeError: if ``layer_op`` is not callable
        """
        # Validate before Layer.__init__: that constructor builds a template
        # with create_scope_now_=True, i.e. it opens a variable scope, which
        # should not happen for an invalid layer.
        if not callable(layer_op):
            tf.logging.fatal("layer_op must be callable.")
            raise TypeError("layer_op must be callable.")
        self._layer_op = layer_op
        super(LayerFromCallable, self).__init__(name=name)

    def layer_op(self, *args, **kwargs):
        """Delegate to the user-supplied callable."""
        return self._layer_op(*args, **kwargs)