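"""niftynet.contrib.segmentation_bf_aug.segmentation_application_bfaug

Segmentation application that extends the standard NiftyNet segmentation
application with random bias-field augmentation during training.
"""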
import tensorflow as tf
from niftynet.application.segmentation_application import \
SegmentationApplication, SUPPORTED_INPUT
from niftynet.io.image_reader import ImageReader
from niftynet.layer.binary_masking import BinaryMaskingLayer
from niftynet.layer.discrete_label_normalisation import \
DiscreteLabelNormalisationLayer
from niftynet.layer.histogram_normalisation import \
HistogramNormalisationLayer
from niftynet.layer.mean_variance_normalisation import \
MeanVarNormalisationLayer
from niftynet.layer.pad import PadLayer
from niftynet.layer.rand_bias_field import RandomBiasFieldLayer
from niftynet.layer.rand_flip import RandomFlipLayer
from niftynet.layer.rand_rotation import RandomRotationLayer
from niftynet.layer.rand_spatial_scaling import RandomSpatialScalingLayer
class SegmentationApplicationBFAug(SegmentationApplication):
    """Segmentation application with optional random bias-field augmentation.

    Inherits everything else from :class:`SegmentationApplication`.
    Its dataset loader builds the usual padding, normalisation and
    augmentation pipeline. In addition, it appends a
    :class:`RandomBiasFieldLayer` to the training augmentation layers when
    ``action_param.bias_field_range`` is set.
    """

    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, is_training):
        """Forward construction to :class:`SegmentationApplication` and log it."""
        SegmentationApplication.__init__(
            self, net_param, action_param, is_training)
        tf.logging.info('starting segmentation application')

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create the image readers and attach their preprocessing layers.

        :param data_param: input data specification, forwarded to every
            ``ImageReader.initialise`` call.
        :param task_param: segmentation task parameters. They are stored as
            ``self.segmentation_param`` and read for the ``image``/``label``
            modality mapping and the ``label_normalisation`` and
            ``output_prob`` flags.
        :param data_partitioner: object providing ``get_file_lists_by``.
            It must not be ``None`` despite the default value.
        :raises ValueError: if the current action is not training,
            inference or evaluation.
        """
        self.data_param = data_param
        self.segmentation_param = task_param

        # initialise input image readers
        reader_names = self._bfaug_reader_names()
        # `dataset_to_infer` is optional on the action parameters
        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            reader_phase = None
        file_lists = data_partitioner.get_file_lists_by(
            phase=reader_phase, action=self.action)
        self.readers = [
            ImageReader(reader_names).initialise(
                data_param, task_param, file_list)
            for file_list in file_lists]

        normalisation_layers = self._bfaug_normalisation_layers(task_param)
        augmentation_layers = self._bfaug_augmentation_layers()
        volume_padding_layer = [PadLayer(
            image_name=SUPPORTED_INPUT,
            border=self.net_param.volume_padding_size,
            mode=self.net_param.volume_padding_mode,
            pad_to=self.net_param.volume_padding_to_size)]

        # Augmentation is attached to the first reader only. The remaining
        # readers (presumably validation splits; confirm against the
        # partitioner) get padding and normalisation only.
        self.readers[0].add_preprocessing_layers(
            volume_padding_layer + normalisation_layers + augmentation_layers)
        for reader in self.readers[1:]:
            reader.add_preprocessing_layers(
                volume_padding_layer + normalisation_layers)

    def _bfaug_reader_names(self):
        """Return the input sections the readers must load for this action."""
        if self.is_training:
            return ('image', 'label', 'weight', 'sampler')
        if self.is_inference:
            # in the inference process use `image` input only
            return ('image',)
        if self.is_evaluation:
            return ('image', 'label', 'inferred')
        tf.logging.fatal(
            'Action `%s` not supported. Expected one of %s',
            self.action, self.SUPPORTED_PHASES)
        raise ValueError(
            'Action `{}` not supported. Expected one of {}'.format(
                self.action, self.SUPPORTED_PHASES))

    def _bfaug_normalisation_layers(self, task_param):
        """Build the enabled normalisation layers, in application order.

        The histogram-based layers (image histogram and discrete label
        normalisation) can only be built when
        ``net_param.histogram_ref_file`` is set. If such a layer is enabled
        but cannot be built, it is skipped with a warning. This avoids
        passing ``None`` to the readers.
        """
        foreground_masking_layer = None
        if self.net_param.normalise_foreground_only:
            foreground_masking_layer = BinaryMaskingLayer(
                type_str=self.net_param.foreground_type,
                multimod_fusion=self.net_param.multimod_foreground_type,
                threshold=0.0)

        mean_var_normaliser = MeanVarNormalisationLayer(
            image_name='image', binary_masking_func=foreground_masking_layer)

        histogram_normaliser = None
        label_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                binary_masking_func=foreground_masking_layer,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')
            label_normaliser = DiscreteLabelNormalisationLayer(
                image_name='label',
                modalities=vars(task_param).get('label'),
                model_filename=self.net_param.histogram_ref_file)

        layers = []
        if self.net_param.normalisation:
            if histogram_normaliser is None:
                tf.logging.warning(
                    'Histogram normalisation requested but no '
                    'histogram_ref_file is set; skipping it.')
            else:
                layers.append(histogram_normaliser)
        if self.net_param.whitening:
            layers.append(mean_var_normaliser)
        # label mapping is skipped at inference/evaluation when raw
        # probabilities are requested
        if task_param.label_normalisation and \
                (self.is_training or not task_param.output_prob):
            if label_normaliser is None:
                tf.logging.warning(
                    'Label normalisation requested but no '
                    'histogram_ref_file is set; skipping it.')
            else:
                layers.append(label_normaliser)
        return layers

    def _bfaug_augmentation_layers(self):
        """Build the random augmentation layers; empty unless training."""
        if not self.is_training:
            return []
        params = self.action_param
        layers = []
        if params.random_flipping_axes != -1:
            layers.append(RandomFlipLayer(
                flip_axes=params.random_flipping_axes))
        if params.scaling_percentage:
            layers.append(RandomSpatialScalingLayer(
                min_percentage=params.scaling_percentage[0],
                max_percentage=params.scaling_percentage[1]))
        if params.rotation_angle or \
                params.rotation_angle_x or \
                params.rotation_angle_y or \
                params.rotation_angle_z:
            rotation_layer = RandomRotationLayer()
            # a single `rotation_angle` takes precedence over per-axis angles
            if params.rotation_angle:
                rotation_layer.init_uniform_angle(params.rotation_angle)
            else:
                rotation_layer.init_non_uniform_angle(
                    params.rotation_angle_x,
                    params.rotation_angle_y,
                    params.rotation_angle_z)
            layers.append(rotation_layer)
        if params.bias_field_range:
            bias_field_layer = RandomBiasFieldLayer()
            bias_field_layer.init_order(params.bf_order)
            bias_field_layer.init_uniform_coeff(params.bias_field_range)
            layers.append(bias_field_layer)
        return layers