Source code for niftynet.contrib.segmentation_bf_aug.segmentation_application_bfaug

import tensorflow as tf

from niftynet.application.segmentation_application import \
    SegmentationApplication, SUPPORTED_INPUT
from niftynet.io.image_reader import ImageReader
from niftynet.layer.binary_masking import BinaryMaskingLayer
from niftynet.layer.discrete_label_normalisation import \
    DiscreteLabelNormalisationLayer
from niftynet.layer.histogram_normalisation import \
    HistogramNormalisationLayer
from niftynet.layer.mean_variance_normalisation import \
    MeanVarNormalisationLayer
from niftynet.layer.pad import PadLayer
from niftynet.layer.rand_bias_field import RandomBiasFieldLayer
from niftynet.layer.rand_flip import RandomFlipLayer
from niftynet.layer.rand_rotation import RandomRotationLayer
from niftynet.layer.rand_spatial_scaling import RandomSpatialScalingLayer


class SegmentationApplicationBFAug(SegmentationApplication):
    """Segmentation application with optional bias-field augmentation.

    Extends ``SegmentationApplication`` and overrides
    ``initialise_dataset_loader`` so that, while training, a
    ``RandomBiasFieldLayer`` joins the augmentation layers whenever
    ``action_param.bias_field_range`` is set.
    """

    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, is_training):
        """Build the application by delegating to the parent constructor.

        :param net_param: passed through to ``SegmentationApplication``
        :param action_param: passed through to ``SegmentationApplication``
        :param is_training: passed through as the parent's third argument
        """
        SegmentationApplication.__init__(
            self, net_param, action_param, is_training)
        tf.logging.info('starting segmentation application')

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create the image readers and attach their preprocessing layers.

        Stores ``data_param`` and ``task_param`` on the instance, builds one
        ``ImageReader`` per file list returned by ``data_partitioner``, and
        registers padding, normalisation and (first reader only)
        augmentation layers on them.

        :param data_param: input data specification given to each reader
        :param task_param: task parameters; read here for
            ``label_normalisation``, ``output_prob`` and the ``image`` /
            ``label`` modality entries
        :param data_partitioner: object providing ``get_file_lists_by``
        :raises ValueError: if the application is in none of the training,
            inference or evaluation states
        """
        self.data_param = data_param
        self.segmentation_param = task_param

        # Pick the input sections the readers must load for this action.
        if self.is_training:
            input_names = ('image', 'label', 'weight', 'sampler')
        elif self.is_inference:
            # inference consumes the `image` input only
            input_names = ('image',)
        elif self.is_evaluation:
            input_names = ('image', 'label', 'inferred')
        else:
            tf.logging.fatal(
                'Action `%s` not supported. Expected one of %s',
                self.action, self.SUPPORTED_PHASES)
            raise ValueError

        # `dataset_to_infer` may be absent from the action parameters.
        try:
            phase = self.action_param.dataset_to_infer
        except AttributeError:
            phase = None
        partitions = data_partitioner.get_file_lists_by(
            phase=phase, action=self.action)

        readers = []
        for partition in partitions:
            reader = ImageReader(input_names)
            readers.append(
                reader.initialise(data_param, task_param, partition))
        self.readers = readers

        net = self.net_param

        # Optional foreground mask shared by both intensity normalisers.
        if net.normalise_foreground_only:
            masking = BinaryMaskingLayer(
                type_str=net.foreground_type,
                multimod_fusion=net.multimod_foreground_type,
                threshold=0.0)
        else:
            masking = None

        # The whitening layer is always constructed; it is only used when
        # `net.whitening` is set below.
        whitening = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=masking)

        if net.histogram_ref_file:
            hist_norm = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=net.histogram_ref_file,
                binary_masking_func=masking,
                norm_type=net.norm_type,
                cutoff=net.cutoff,
                name='hist_norm_layer')
        else:
            hist_norm = None

        if net.histogram_ref_file:
            label_norm = DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=net.histogram_ref_file)
        else:
            label_norm = None

        # NOTE(review): `hist_norm` / `label_norm` are None when no
        # `histogram_ref_file` is configured, yet they are still appended
        # if the corresponding flag is set -- confirm the reader tolerates
        # None entries.
        normalisers = []
        if net.normalisation:
            normalisers.append(hist_norm)
        if net.whitening:
            normalisers.append(whitening)
        wants_label_norm = task_param.label_normalisation and \
            (self.is_training or not task_param.output_prob)
        if wants_label_norm:
            normalisers.append(label_norm)

        # Random augmentation is configured for training only.
        augmenters = []
        if self.is_training:
            act = self.action_param

            if act.random_flipping_axes != -1:
                augmenters.append(
                    RandomFlipLayer(flip_axes=act.random_flipping_axes))

            if act.scaling_percentage:
                augmenters.append(RandomSpatialScalingLayer(
                    min_percentage=act.scaling_percentage[0],
                    max_percentage=act.scaling_percentage[1]))

            wants_rotation = (
                act.rotation_angle or
                act.rotation_angle_x or
                act.rotation_angle_y or
                act.rotation_angle_z)
            if wants_rotation:
                rotation = RandomRotationLayer()
                # A single `rotation_angle` takes precedence over the
                # per-axis settings.
                if act.rotation_angle:
                    rotation.init_uniform_angle(act.rotation_angle)
                else:
                    rotation.init_non_uniform_angle(
                        act.rotation_angle_x,
                        act.rotation_angle_y,
                        act.rotation_angle_z)
                augmenters.append(rotation)

            if act.bias_field_range:
                bias_field = RandomBiasFieldLayer()
                bias_field.init_order(act.bf_order)
                bias_field.init_uniform_coeff(act.bias_field_range)
                augmenters.append(bias_field)

        padding = [
            PadLayer(
                image_name=SUPPORTED_INPUT,
                border=net.volume_padding_size,
                mode=net.volume_padding_mode,
                pad_to=net.volume_padding_to_size),
        ]

        # Only the first reader receives the augmentation layers; every
        # other reader gets padding and normalisation alone.
        first_reader = self.readers[0]
        first_reader.add_preprocessing_layers(
            padding + normalisers + augmenters)
        for other_reader in self.readers[1:]:
            other_reader.add_preprocessing_layers(padding + normalisers)