# -*- coding: utf-8 -*-
"""
Loading modules from a string representing the class name
or a short name that matches the dictionary item defined
in this module
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import importlib
import os
import tensorflow as tf
from niftynet.utilities.util_common import \
damerau_levenshtein_distance as edit_distance
# pylint: disable=too-few-public-methods
# Application shorthands -> absolute class names (used by ApplicationFactory).
SUPPORTED_APP = dict(
    net_regress='niftynet.application.regression_application.'
                'RegressionApplication',
    net_segment='niftynet.application.segmentation_application.'
                'SegmentationApplication',
    net_autoencoder='niftynet.application.autoencoder_application.'
                    'AutoencoderApplication',
    net_gan='niftynet.application.gan_application.GANApplication',
    net_classify='niftynet.application.classification_application.'
                 'ClassificationApplication',
)
# Network shorthands -> absolute class names under ``niftynet.network``
# (used by ApplicationNetFactory).
SUPPORTED_NETWORK = dict(
    (net_key, 'niftynet.network.' + net_path)
    for net_key, net_path in (
        # GAN
        ('simulator_gan', 'simulator_gan.SimulatorGAN'),
        ('simple_gan', 'simple_gan.SimpleGAN'),
        # Segmentation
        ('highres3dnet', 'highres3dnet.HighRes3DNet'),
        ('highres3dnet_small', 'highres3dnet_small.HighRes3DNetSmall'),
        ('highres3dnet_large', 'highres3dnet_large.HighRes3DNetLarge'),
        ('toynet', 'toynet.ToyNet'),
        ('unet', 'unet.UNet3D'),
        ('nonewnet', 'no_new_net.UNet3D'),
        ('vnet', 'vnet.VNet'),
        ('dense_vnet', 'dense_vnet.DenseVNet'),
        ('deepmedic', 'deepmedic.DeepMedic'),
        ('scalenet', 'scalenet.ScaleNet'),
        ('holisticnet', 'holistic_net.HolisticNet'),
        ('unet_2d', 'unet_2d.UNet2D'),
        # Classification
        ('resnet', 'resnet.ResNet'),
        ('se_resnet', 'se_resnet.SE_ResNet'),
        # Autoencoder
        ('vae', 'vae.VAE'),
    ))
# GAN loss shorthands (used by LossGANFactory).
SUPPORTED_LOSS_GAN = dict(
    CrossEntropy='niftynet.layer.loss_gan.cross_entropy',
)
# Segmentation loss shorthands -> functions in
# ``niftynet.layer.loss_segmentation`` (used by LossSegmentationFactory).
SUPPORTED_LOSS_SEGMENTATION = dict(
    (loss_name, 'niftynet.layer.loss_segmentation.' + loss_func)
    for loss_name, loss_func in (
        ('CrossEntropy', 'cross_entropy'),
        ('CrossEntropy_Dense', 'cross_entropy_dense'),
        ('Dice', 'dice'),
        ('Dice_NS', 'dice_nosquare'),
        ('Dice_Dense', 'dice_dense'),
        ('Dice_Dense_NS', 'dice_dense_nosquare'),
        ('Tversky', 'tversky'),
        ('GDSC', 'generalised_dice_loss'),
        ('DicePlusXEnt', 'dice_plus_xent_loss'),
        ('WGDL', 'generalised_wasserstein_dice_loss'),
        ('SensSpec', 'sensitivity_specificity_loss'),
        ('VolEnforcement', 'volume_enforcement'),
        # L1Loss, L2Loss and Huber are not registered for segmentation;
        # regression versions are listed in SUPPORTED_LOSS_REGRESSION.
    ))
# Regression loss shorthands -> functions in
# ``niftynet.layer.loss_regression`` (used by LossRegressionFactory).
SUPPORTED_LOSS_REGRESSION = dict(
    (loss_name, 'niftynet.layer.loss_regression.' + loss_func)
    for loss_name, loss_func in (
        ('L1Loss', 'l1_loss'),
        ('L2Loss', 'l2_loss'),
        ('RMSE', 'rmse_loss'),
        ('MAE', 'mae_loss'),
        ('Huber', 'huber_loss'),
        ('SmoothL1', 'smooth_l1_loss'),
        ('Cosine', 'cosine_loss'),
    ))
# Classification loss shorthands (used by LossClassificationFactory).
SUPPORTED_LOSS_CLASSIFICATION = dict(
    CrossEntropy='niftynet.layer.loss_classification.cross_entropy',
)
# Shorthands for losses in ``niftynet.layer.loss_classification_multi``
# (used by LossClassificationMultiFactory).
SUPPORTED_LOSS_CLASSIFICATION_MULTI = dict(
    ConfusionMatrix='niftynet.layer.loss_classification_multi.'
                    'loss_confusion_matrix',
    Variability='niftynet.layer.loss_classification_multi.loss_variability',
    Consistency='niftynet.layer.loss_classification_multi.rmse_consistency',
)
# Autoencoder loss shorthands (used by LossAutoencoderFactory).
SUPPORTED_LOSS_AUTOENCODER = dict(
    VariationalLowerBound='niftynet.layer.loss_autoencoder.'
                          'variational_lower_bound',
)
# Optimiser shorthands -> classes in ``niftynet.engine.application_optimiser``
# (used by OptimiserFactory).
SUPPORTED_OPTIMIZERS = dict(
    (opt_name, 'niftynet.engine.application_optimiser.' + opt_class)
    for opt_name, opt_class in (
        ('adam', 'Adam'),
        ('gradientdescent', 'GradientDescent'),
        ('momentum', 'Momentum'),
        ('nesterov', 'NesterovMomentum'),
        ('adagrad', 'Adagrad'),
        ('rmsprop', 'RMSProp'),
    ))
# Initializer shorthands -> classes in
# ``niftynet.engine.application_initializer`` (used by InitializerFactory).
SUPPORTED_INITIALIZATIONS = dict(
    (init_name, 'niftynet.engine.application_initializer.' + init_class)
    for init_name, init_class in (
        ('constant', 'Constant'),
        ('zeros', 'Zeros'),
        ('ones', 'Ones'),
        ('uniform_scaling', 'UniformUnitScaling'),
        ('orthogonal', 'Orthogonal'),
        ('variance_scaling', 'VarianceScaling'),
        ('glorot_normal', 'GlorotNormal'),
        ('glorot_uniform', 'GlorotUniform'),
        ('he_normal', 'HeNormal'),
        ('he_uniform', 'HeUniform'),
    ))
# Evaluation shorthands (used by EvaluationFactory). ``Dice``/``Jaccard``
# are capitalised aliases of ``dice``/``jaccard``.
SUPPORTED_EVALUATIONS = dict(
    # segmentation
    dice='niftynet.evaluation.segmentation_evaluations.dice',
    jaccard='niftynet.evaluation.segmentation_evaluations.jaccard',
    Dice='niftynet.evaluation.segmentation_evaluations.dice',
    Jaccard='niftynet.evaluation.segmentation_evaluations.jaccard',
    n_pos_ref='niftynet.evaluation.segmentation_evaluations.n_pos_ref',
    n_neg_ref='niftynet.evaluation.segmentation_evaluations.n_neg_ref',
    n_pos_seg='niftynet.evaluation.segmentation_evaluations.n_pos_seg',
    n_neg_seg='niftynet.evaluation.segmentation_evaluations.n_neg_seg',
    fp='niftynet.evaluation.segmentation_evaluations.fp',
    fn='niftynet.evaluation.segmentation_evaluations.fn',
    tp='niftynet.evaluation.segmentation_evaluations.tp',
    tn='niftynet.evaluation.segmentation_evaluations.tn',
    n_intersection=(
        'niftynet.evaluation.segmentation_evaluations.n_intersection'),
    n_union='niftynet.evaluation.segmentation_evaluations.n_union',
    specificity='niftynet.evaluation.segmentation_evaluations.specificity',
    sensitivity='niftynet.evaluation.segmentation_evaluations.sensitivity',
    accuracy='niftynet.evaluation.segmentation_evaluations.accuracy',
    false_positive_rate=(
        'niftynet.evaluation.segmentation_evaluations.false_positive_rate'),
    positive_predictive_values=(
        'niftynet.evaluation.segmentation_evaluations'
        '.positive_predictive_values'),
    negative_predictive_values=(
        'niftynet.evaluation.segmentation_evaluations'
        '.negative_predictive_values'),
    intersection_over_union=(
        'niftynet.evaluation.segmentation_evaluations'
        '.intersection_over_union'),
    informedness='niftynet.evaluation.segmentation_evaluations.informedness',
    markedness='niftynet.evaluation.segmentation_evaluations.markedness',
    vol_diff='niftynet.evaluation.segmentation_evaluations.vol_diff',
    average_distance=(
        'niftynet.evaluation.segmentation_evaluations.average_distance'),
    hausdorff_distance=(
        'niftynet.evaluation.segmentation_evaluations.hausdorff_distance'),
    hausdorff95_distance=(
        'niftynet.evaluation.segmentation_evaluations.hausdorff95_distance'),
    com_ref='niftynet.contrib.evaluation.segmentation_evaluations.com_ref',
    # regression
    mse='niftynet.evaluation.regression_evaluations.mse',
    rmse='niftynet.evaluation.regression_evaluations.rmse',
    mae='niftynet.evaluation.regression_evaluations.mae',
    # r2='niftynet.contrib.evaluation.regression_evaluations.r2',
    # classification
    classification_accuracy=(
        'niftynet.evaluation.classification_evaluations.accuracy'),
    roc_auc='niftynet.contrib.evaluation.classification_evaluations.roc_auc',
    roc='niftynet.contrib.evaluation.classification_evaluations.roc',
)
# Event handler shorthands -> classes in ``niftynet.engine.handler_*``
# modules (used by EventHandlerFactory).
SUPPORTED_EVENT_HANDLERS = dict(
    model_restorer='niftynet.engine.handler_model.ModelRestorer',
    model_saver='niftynet.engine.handler_model.ModelSaver',
    sampler_threading='niftynet.engine.handler_sampler.SamplerThreading',
    apply_gradients='niftynet.engine.handler_gradient.ApplyGradients',
    output_interpreter='niftynet.engine.handler_network_output.'
                       'OutputInterpreter',
    console_logger='niftynet.engine.handler_console.ConsoleLogger',
    tensorboard_logger='niftynet.engine.handler_tensorboard.'
                       'TensorBoardLogger',
    performance_logger='niftynet.engine.handler_performance.'
                       'PerformanceLogger',
    early_stopper='niftynet.engine.handler_early_stopping.EarlyStopper',
)
# Iteration message generator shorthands (used by IteratorFactory).
SUPPORTED_ITERATION_GENERATORS = dict(
    iteration_generator='niftynet.engine.application_iteration.'
                        'IterationMessageGenerator',
)
def select_module(module_name, type_str, lookup_table=None):
    """
    Import and return an object (e.g. a class or function) given its name.

    This function first tries to find the absolute name by matching the
    keys of ``lookup_table``. If ``module_name`` is not a key, it is
    treated as an absolute name: it is split at the last dot into a module
    name, which is imported, and an attribute name, which is looked up on
    that module.

    :param module_name: string that matches a key of ``lookup_table``
        or an absolute name: ``module.name.ClassName``
    :param type_str: type of the module (used for better error display)
    :param lookup_table: defines a set of shorthands for absolute names
    :return: the imported object
    :raises ValueError: if the object cannot be imported. When
        ``module_name`` is not a shorthand but is within edit distance 3
        of one, the message suggests that shorthand.
    """
    lookup_table = lookup_table or {}
    module_name = '{}'.format(module_name)
    # Keep what the caller typed: spelling suggestions must be computed
    # against this, not against the resolved absolute path.
    requested_name = module_name
    is_external = module_name not in lookup_table
    if not is_external:
        module_name = lookup_table[module_name]
    module_str, class_name = None, None
    try:
        module_str, class_name = module_name.rsplit('.', 1)
        the_module = importlib.import_module(module_str)
        the_class = getattr(the_module, class_name)
    except (AttributeError, ValueError, ImportError) as not_imported:
        tf.logging.fatal(repr(not_imported))
        if '.' not in module_name:
            err = 'Could not import {}: ' \
                  'Incorrect module name "{}"; ' \
                  'expected "module.object".'.format(type_str, module_name)
        else:
            err = '{}: Could not import object ' \
                  '"{}" from "{}"'.format(type_str, class_name, module_str)
        tf.logging.fatal(err)
        if not lookup_table or not is_external:
            # Either there is nothing to guess from, or the shorthand was
            # valid and only its target failed to import; a spelling hint
            # would be misleading in both cases.
            raise ValueError(err)
        dists = dict(
            (k, edit_distance(k, requested_name)) for k in lookup_table)
        closest = min(dists, key=dists.get)
        if dists[closest] <= 3:
            err = 'Could not import {2}: By "{0}", ' \
                  'did you mean "{1}"?\n "{0}" is ' \
                  'not a valid option. '.format(
                      requested_name, closest, type_str)
            tf.logging.fatal(err)
        raise ValueError(err)
    if is_external:
        # Print the location of the external module. Built-in modules have
        # no ``__file__`` and namespace packages may set it to None, so this
        # logging must not turn a successful import into a failure.
        module_file = getattr(the_module, '__file__', None)
        location = os.path.abspath(module_file) if module_file else module_str
        tf.logging.info('Import [%s] from %s.', class_name, location)
    return the_class
class ModuleFactory(object):
    """
    Base factory that turns a name into an imported object via
    ``select_module``.

    Subclasses override ``SUPPORTED`` (a dict of shorthand -> absolute
    name) and ``type_str`` (a label used in error messages).
    """
    type_str = 'object'
    SUPPORTED = None

    @classmethod
    def create(cls, name):
        """
        Import and return the object referred to by ``name``.

        :param name: a key of ``cls.SUPPORTED`` or an absolute
            ``module.path.ObjectName`` string
        :return: the imported object
        """
        lookup, label = cls.SUPPORTED, cls.type_str
        return select_module(name, label, lookup)
class ApplicationNetFactory(ModuleFactory):
    """
    Resolve a network class by shorthand (see ``SUPPORTED_NETWORK``, classes
    under ``niftynet.network``) or by a user-specified absolute name.
    """
    type_str = 'network'
    SUPPORTED = SUPPORTED_NETWORK
class ApplicationFactory(ModuleFactory):
    """
    Resolve an application class by shorthand (see ``SUPPORTED_APP``,
    classes under ``niftynet.application``) or by a user-specified
    absolute name.
    """
    type_str = 'application'
    SUPPORTED = SUPPORTED_APP
class LossGANFactory(ModuleFactory):
    """
    Resolve a GAN loss by shorthand (see ``SUPPORTED_LOSS_GAN``, functions
    under ``niftynet.layer``) or by a user-specified absolute name.
    """
    type_str = 'GAN loss'
    SUPPORTED = SUPPORTED_LOSS_GAN
class LossSegmentationFactory(ModuleFactory):
    """
    Resolve a segmentation loss by shorthand (see
    ``SUPPORTED_LOSS_SEGMENTATION``, functions under ``niftynet.layer``)
    or by a user-specified absolute name.
    """
    type_str = 'segmentation loss'
    SUPPORTED = SUPPORTED_LOSS_SEGMENTATION
class LossRegressionFactory(ModuleFactory):
    """
    Resolve a regression loss by shorthand (see
    ``SUPPORTED_LOSS_REGRESSION``, functions under ``niftynet.layer``)
    or by a user-specified absolute name.
    """
    type_str = 'regression loss'
    SUPPORTED = SUPPORTED_LOSS_REGRESSION
class LossClassificationFactory(ModuleFactory):
    """
    Resolve a classification loss by shorthand (see
    ``SUPPORTED_LOSS_CLASSIFICATION``, functions under ``niftynet.layer``)
    or by a user-specified absolute name.
    """
    type_str = 'classification loss'
    SUPPORTED = SUPPORTED_LOSS_CLASSIFICATION
class LossClassificationMultiFactory(ModuleFactory):
    """
    Import a loss function from ``niftynet.layer.loss_classification_multi``
    (shorthands listed in ``SUPPORTED_LOSS_CLASSIFICATION_MULTI``) or
    from a user-specified absolute name.

    This is distinct from ``LossClassificationFactory``, which covers
    ``niftynet.layer.loss_classification``.
    """
    SUPPORTED = SUPPORTED_LOSS_CLASSIFICATION_MULTI
    type_str = 'classification multi loss'
class LossAutoencoderFactory(ModuleFactory):
    """
    Resolve an autoencoder loss by shorthand (see
    ``SUPPORTED_LOSS_AUTOENCODER``, functions under ``niftynet.layer``)
    or by a user-specified absolute name.
    """
    type_str = 'autoencoder loss'
    SUPPORTED = SUPPORTED_LOSS_AUTOENCODER
class OptimiserFactory(ModuleFactory):
    """
    Resolve an optimiser class by shorthand (see ``SUPPORTED_OPTIMIZERS``,
    classes in ``niftynet.engine.application_optimiser``) or by a
    user-specified absolute name.
    """
    type_str = 'optimizer'
    SUPPORTED = SUPPORTED_OPTIMIZERS
class InitializerFactory(ModuleFactory):
    """
    Resolve an initializer class by shorthand (see
    ``SUPPORTED_INITIALIZATIONS``, classes in
    ``niftynet.engine.application_initializer``) or by a user-specified
    absolute name.
    """
    type_str = 'initializer'
    SUPPORTED = SUPPORTED_INITIALIZATIONS

    @staticmethod
    def get_initializer(name, args=None):
        """
        Resolve ``name`` to an initializer class and ask it for an instance.

        :param name: shorthand or absolute ``module.path.ClassName``
        :param args: optional parameters forwarded to the class's
            ``get_instance``; ``None`` is replaced by an empty dict
        :return: the value returned by ``get_instance``
        """
        init_class = InitializerFactory.create(name)
        return init_class.get_instance({} if args is None else args)
class EvaluationFactory(ModuleFactory):
    """
    Import an evaluation from ``niftynet.evaluation`` or
    ``niftynet.contrib.evaluation`` (shorthands listed in
    ``SUPPORTED_EVALUATIONS``) or from a user-specified absolute name.
    """
    SUPPORTED = SUPPORTED_EVALUATIONS
    type_str = 'evaluation'
class EventHandlerFactory(ModuleFactory):
    """
    Resolve an event handler class by shorthand (see
    ``SUPPORTED_EVENT_HANDLERS``, e.g. ``console_logger`` ->
    ``niftynet.engine.handler_console.ConsoleLogger``) or by a
    user-specified absolute name.
    """
    type_str = 'event handler'
    SUPPORTED = SUPPORTED_EVENT_HANDLERS
class IteratorFactory(ModuleFactory):
    """
    Resolve the iteration message generator used by the main engine loop
    (see ``SUPPORTED_ITERATION_GENERATORS``) or a user-specified one.
    """
    type_str = 'engine iterator'
    SUPPORTED = SUPPORTED_ITERATION_GENERATORS