Source code for niftynet.contrib.regression_weighted_sampler.isample_regression

import os

import tensorflow as tf

from niftynet.application.regression_application import \
    RegressionApplication, SUPPORTED_INPUT
from niftynet.engine.sampler_uniform import UniformSampler
from niftynet.engine.sampler_weighted import WeightedSampler
from niftynet.engine.application_variables import NETWORK_OUTPUT
from niftynet.io.image_reader import ImageReader
from niftynet.layer.histogram_normalisation import \
    HistogramNormalisationLayer
from niftynet.layer.mean_variance_normalisation import \
    MeanVarNormalisationLayer
from niftynet.layer.pad import PadLayer


class ISampleRegression(RegressionApplication):
    """Regression application that can emit error maps when not training.

    Training is delegated unchanged to ``RegressionApplication``.  When the
    application is not training and the ``error_map`` option is set, the
    readers are rebuilt so that the ``'output'`` section is loaded next to
    ``'image'``, and the value collected for each window is the squared
    difference between that target and the network output.
    """

    # Disabled sampler override, kept for reference only.
    # def initialise_weighted_sampler(self):
    #     if len(self.readers) == 2:
    #         training_sampler = WeightedSampler(
    #             reader=self.readers[0],
    #             data_param=self.data_param,
    #             batch_size=self.net_param.batch_size,
    #             windows_per_image=self.action_param.sample_per_volume,
    #             queue_length=self.net_param.queue_length)
    #         validation_sampler = UniformSampler(
    #             reader=self.readers[1],
    #             data_param=self.data_param,
    #             batch_size=self.net_param.batch_size,
    #             windows_per_image=self.action_param.sample_per_volume,
    #             queue_length=self.net_param.queue_length)
    #         self.sampler = [[training_sampler, validation_sampler]]
    #     else:
    #         RegressionApplication.initialise_weighted_sampler()

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Set up the readers, reloading targets when error maps are wanted.

        The parent initialisation always runs first.  Only when the
        application is not training and ``task_param.error_map`` is truthy
        are ``self.readers`` replaced by readers over ``['image', 'output']``.

        :param data_param: forwarded to the parent and to each new reader.
        :param task_param: task options; ``error_map`` is read from it and
            its ``image`` attribute names the modalities to normalise.
        :param data_partitioner: used to obtain the file lists to read.
        """
        RegressionApplication.initialise_dataset_loader(
            self, data_param, task_param, data_partitioner)
        if self.is_training or not task_param.error_map:
            return

        # modifying the original readers in regression application
        # as we need ground truth labels to generate error maps
        self.readers = []
        for file_list in self.get_file_lists(data_partitioner):
            reader = ImageReader(['image', 'output'])
            reader.initialise(data_param, task_param, file_list)
            self.readers.append(reader)

        preprocessors = []
        # Histogram normalisation needs a reference file; without one there
        # is no layer to add, so never put a ``None`` entry in the list.
        if self.net_param.normalisation and self.net_param.histogram_ref_file:
            preprocessors.append(HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer'))
        if self.net_param.whitening:
            preprocessors.append(
                MeanVarNormalisationLayer(image_name='image'))
        if self.net_param.volume_padding_size:
            preprocessors.append(PadLayer(
                image_name=SUPPORTED_INPUT,
                border=self.net_param.volume_padding_size))
        # Only the first reader receives the preprocessing layers.
        self.readers[0].add_preprocessing_layers(preprocessors)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Connect the sampler output to the network and collect results.

        When training, the parent pipeline is used as-is.  Otherwise the
        network is applied to the sampled image windows and, if
        ``self.regression_param.error_map`` is set, the collected window
        value is the squared difference between target and prediction and
        the output decoder is redirected to an ``error_maps`` folder.

        :param outputs_collector: collector receiving the ``'window'`` and
            ``'location'`` entries under ``NETWORK_OUTPUT``.
        :param gradients_collector: only forwarded to the parent pipeline
            when training.
        """
        if self.is_training:
            # using the original training pipeline
            RegressionApplication.connect_data_and_network(
                self, outputs_collector, gradients_collector)
            return

        # Third entry of the sampling tuple is called as the aggregator
        # initialiser for the configured window sampling.
        init_aggregator = \
            self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
        init_aggregator()

        # modifying the original pipeline so that
        # the error maps are computed instead of the regression output
        # NOTE(review): the original indentation was lost; only the batch
        # pop is placed under the 'validation' name scope here -- confirm.
        with tf.name_scope('validation'):
            data_dict = self.get_sampler()[0][-1].pop_batch_op()
        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, is_training=self.is_training)

        if self.regression_param.error_map:
            # writing error maps to folder without prefix
            error_map_folder = os.path.join(
                os.path.dirname(self.action_param.save_seg_dir),
                'error_maps')
            self.output_decoder.output_path = error_map_folder
            self.output_decoder.prefix = ''

            # computes the element-wise squared error (not absolute error)
            target = tf.cast(data_dict['output'], tf.float32)
            net_out = tf.squared_difference(target, net_out)

        # window output and locations for aggregating volume results
        outputs_collector.add_to_collection(
            var=net_out, name='window',
            average_over_devices=False, collection=NETWORK_OUTPUT)
        outputs_collector.add_to_collection(
            var=data_dict['image_location'], name='location',
            average_over_devices=False, collection=NETWORK_OUTPUT)