Source code for niftynet.layer.histogram_normalisation

# -*- coding: utf-8 -*-
"""
This class computes histogram based normalisation. A `training`
process is first used to find an averaged histogram mapping
from all training volumes.  This layer maintains the mapping array,
and the layer_op maps the intensity of new volumes to a normalised version.
The histogram is computed from foreground if a definition is provided for
foreground (by `binary_masking_func` or a `mask` matrix)
"""
from __future__ import absolute_import, print_function, division

import os

import numpy as np
import tensorflow as tf

import niftynet.utilities.histogram_standardisation as hs
from niftynet.layer.base_layer import DataDependentLayer
from niftynet.layer.binary_masking import BinaryMaskingLayer


class HistogramNormalisationLayer(DataDependentLayer):
    """
    Histogram based intensity normalisation layer.

    The layer keeps ``self.mapping``, a dictionary keyed by modality name
    that is read from (and, after training, written back to) the histogram
    reference file ``self.model_file``.  ``train`` fills in the entries
    that are missing for ``self.modalities``; ``layer_op`` applies the
    stored mappings to a 5-D image, one last-axis channel per modality.
    """

    def __init__(self,
                 image_name,
                 modalities,
                 model_filename=None,
                 binary_masking_func=None,
                 norm_type='percentile',
                 cutoff=(0.05, 0.95),
                 name='hist_norm'):
        """
        :param image_name: key used to select the image (and its mask)
            when ``layer_op`` is given dictionaries
        :param modalities: sequence of modality names; the i-th name is
            applied to channel ``i`` of the last image axis
        :param model_filename: path of the histogram reference file,
            defaults to ``histogram_ref_file.txt`` in the current
            working directory; must not be a directory
        :param binary_masking_func: a ``BinaryMaskingLayer`` used to
            compute a foreground mask, set to None for global mapping
        :param norm_type: forwarded unchanged to
            ``hs.transform_by_mapping``
        :param cutoff: forwarded unchanged to the ``hs`` training and
            transform helpers
        :param name: layer name passed to the base class
        """
        super(HistogramNormalisationLayer, self).__init__(name=name)
        if model_filename is None:
            model_filename = os.path.join('.', 'histogram_ref_file.txt')
        self.model_file = os.path.abspath(model_filename)
        assert not os.path.isdir(self.model_file), \
            "model_filename is a directory, " \
            "please change histogram_ref_file to a filename."
        if binary_masking_func:
            assert isinstance(binary_masking_func, BinaryMaskingLayer)
            self.binary_masking_func = binary_masking_func
        else:
            self.binary_masking_func = None
        self.norm_type = norm_type
        self.cutoff = cutoff

        # mapping is a complete cache of the model file, the total number of
        # modalities are listed in self.modalities tuple
        self.image_name = image_name
        self.modalities = modalities
        self.mapping = hs.read_mapping_file(self.model_file)

    def layer_op(self, image, mask=None):
        """
        Normalise ``image`` with the stored histogram mappings.

        The mask is taken, in order of preference, from ``mask`` (the
        entry ``mask[self.image_name]`` when ``mask`` is a dictionary),
        from ``self.binary_masking_func`` applied to the image, or
        defaults to an all-True array covering the whole image.

        :param image: a 5-D array-like, or a dictionary holding one
            under ``self.image_name``
        :param mask: optional mask array, or a dictionary of masks
        :return: ``(image, mask)``; when ``image`` is a dictionary it is
            updated in place and returned together with a mask
            dictionary (the given one updated in place, or a new one),
            otherwise the normalised array and the mask array
        """
        assert self.is_ready(), \
            "histogram normalisation layer needs to be trained first."
        image_is_dict = isinstance(image, dict)
        if image_is_dict:
            image_5d = np.asarray(image[self.image_name], dtype=np.float32)
        else:
            image_5d = np.asarray(image, dtype=np.float32)

        if isinstance(mask, dict):
            image_mask = mask.get(self.image_name, None)
        else:
            image_mask = mask
        if image_mask is None:
            # No usable mask was supplied (this includes a mask dictionary
            # without an entry for this image, which would otherwise reach
            # the normaliser as None and fail on indexing).
            if self.binary_masking_func is not None:
                image_mask = self.binary_masking_func(image_5d)
            else:
                # no access to mask, default to all image
                # (builtin bool: the np.bool alias no longer exists in
                # NumPy >= 1.24)
                image_mask = np.ones_like(image_5d, dtype=bool)

        normalised = self._normalise_5d(image_5d, image_mask)

        if not image_is_dict:
            return normalised, image_mask
        image[self.image_name] = normalised
        if isinstance(mask, dict):
            mask[self.image_name] = image_mask
        else:
            mask = {self.image_name: image_mask}
        return image, mask

    def __check_modalities_to_train(self):
        """
        :return: the set of names in ``self.modalities`` that have no
            entry in ``self.mapping`` yet
        """
        return {mod for mod in self.modalities if mod not in self.mapping}

    def is_ready(self):
        """
        :return: True when every modality in ``self.modalities`` already
            has a mapping, False otherwise
        """
        return not self.__check_modalities_to_train()

    def train(self, image_list):
        """
        Compute the missing histogram references and save them.

        Does nothing (apart from logging) when the layer is already
        ready.  Otherwise the mappings returned by
        ``hs.create_mapping_from_multimod_arrayfiles`` are merged into
        ``self.mapping``, then merged on top of the current content of
        the model file, and the result is written back to that file.

        :param image_list: training subjects, forwarded unchanged to
            ``hs.create_mapping_from_multimod_arrayfiles``; only its
            length is used here (for logging)
        """
        # check modalities to train, using the first subject in subject list
        # to find input modality list
        if self.is_ready():
            tf.logging.info(
                "normalisation histogram reference models ready"
                " for {}:{}".format(self.image_name, self.modalities))
            return
        mod_to_train = self.__check_modalities_to_train()
        tf.logging.info(
            "training normalisation histogram references "
            "for {}:{}, using {} subjects".format(
                self.image_name, mod_to_train, len(image_list)))
        trained_mapping = hs.create_mapping_from_multimod_arrayfiles(
            image_list,
            self.image_name,
            self.modalities,
            mod_to_train,
            self.cutoff,
            self.binary_masking_func)

        # merging trained_mapping dict and self.mapping dict
        self.mapping.update(trained_mapping)
        # re-read the file so entries it holds for other modalities are
        # kept; entries from self.mapping take precedence
        all_maps = hs.read_mapping_file(self.model_file)
        all_maps.update(self.mapping)
        hs.write_all_mod_mapping(self.model_file, all_maps)

    def _normalise_5d(self, data_array, mask_array):
        """
        Normalise every channel of a 5-D array in place.

        Channel ``i`` of the last axis is transformed with the mapping of
        ``self.modalities[i]``.  Channels that contain only zeros are
        treated as missing modalities and left untouched.

        :param data_array: 5-D numpy array, modified in place
        :param mask_array: array-like converted to boolean; indexed on
            its last axis with the same channel index as ``data_array``
        :return: ``data_array``
        :raises RuntimeError: when ``self.mapping`` is empty
        """
        assert self.modalities
        assert data_array.ndim == 5
        assert data_array.shape[4] <= len(self.modalities)

        if not self.mapping:
            message = "calling normaliser with empty mapping," \
                      "probably {} is not loaded".format(self.model_file)
            tf.logging.fatal(message)
            raise RuntimeError(message)
        # builtin bool: the np.bool alias no longer exists in NumPy >= 1.24
        mask_array = np.asarray(mask_array, dtype=bool)
        # The assertion above allows fewer channels than modalities, so only
        # visit the channels that are actually present in data_array.
        n_channels = data_array.shape[4]
        for mod_id, mod_name in zip(range(n_channels), self.modalities):
            if not np.any(data_array[..., mod_id]):
                continue  # missing modality
            data_array[..., mod_id] = self.__normalise(
                data_array[..., mod_id],
                mask_array[..., mod_id],
                self.mapping[mod_name])
        return data_array

    def __normalise(self, img_data, mask, mapping):
        """
        Apply one modality's mapping via ``hs.transform_by_mapping``,
        using this layer's ``cutoff`` and ``norm_type``.
        """
        return hs.transform_by_mapping(
            img_data, mask, mapping, self.cutoff, self.norm_type)