Source code for niftynet.contrib.regression_weighted_sampler.isample_regression
import os

import tensorflow as tf

from niftynet.application.regression_application import \
    RegressionApplication, SUPPORTED_INPUT
from niftynet.engine.sampler_uniform_v2 import UniformSampler
from niftynet.engine.sampler_weighted_v2 import WeightedSampler
from niftynet.engine.application_variables import NETWORK_OUTPUT
from niftynet.io.image_reader import ImageReader
from niftynet.layer.histogram_normalisation import \
    HistogramNormalisationLayer
from niftynet.layer.mean_variance_normalisation import \
    MeanVarNormalisationLayer
from niftynet.layer.pad import PadLayer

class ISampleRegression(RegressionApplication):
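    """
    A regression application which, at inference time with ``error_map``
    enabled, also loads the ground-truth ``output`` images and replaces
    the network's regression output with voxel-wise squared-error maps.
    """
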
    # Disabled alternative sampler initialisation: weighted sampling for
    # training alongside uniform sampling for validation.
    # def initialise_weighted_sampler(self):
    #     if len(self.readers) == 2:
    #         training_sampler = WeightedSampler(
    #             reader=self.readers[0],
    #             window_sizes=self.data_param,
    #             batch_size=self.net_param.batch_size,
    #             windows_per_image=self.action_param.sample_per_volume,
    #             queue_length=self.net_param.queue_length)
    #         validation_sampler = UniformSampler(
    #             reader=self.readers[1],
    #             window_sizes=self.data_param,
    #             batch_size=self.net_param.batch_size,
    #             windows_per_image=self.action_param.sample_per_volume,
    #             queue_length=self.net_param.queue_length)
    #         self.sampler = [[training_sampler, validation_sampler]]
    #     else:
    #         RegressionApplication.initialise_weighted_sampler(self)
    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
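        """
        Delegates to ``RegressionApplication.initialise_dataset_loader``;
        when inferring with ``task_param.error_map`` set, the readers are
        rebuilt so that each one also loads the ground-truth ``output``
        image required to compute error maps.
        """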
        RegressionApplication.initialise_dataset_loader(
            self, data_param, task_param, data_partitioner)
        if self.is_training:
            return
        if not task_param.error_map:
            # fall back to the regression application implementation
            return
        try:
            reader_phase = self.action_param.dataset_to_infer
        except AttributeError:
            reader_phase = None
        file_lists = data_partitioner.get_file_lists_by(
            phase=reader_phase, action=self.action)
        # replacing the original readers of the regression application,
        # as the ground-truth labels are needed to generate error maps
        self.readers = [
            ImageReader(['image', 'output']).initialise(
                data_param, task_param, file_list)
            for file_list in file_lists]
        mean_var_normaliser = MeanVarNormalisationLayer(image_name='image')
        histogram_normaliser = None
        if self.net_param.histogram_ref_file:
            histogram_normaliser = HistogramNormalisationLayer(
                image_name='image',
                modalities=vars(task_param).get('image'),
                model_filename=self.net_param.histogram_ref_file,
                norm_type=self.net_param.norm_type,
                cutoff=self.net_param.cutoff,
                name='hist_norm_layer')

        preprocessors = []
        if self.net_param.normalisation:
            preprocessors.append(histogram_normaliser)
        if self.net_param.whitening:
            preprocessors.append(mean_var_normaliser)
        if self.net_param.volume_padding_size:
            preprocessors.append(PadLayer(
                image_name=SUPPORTED_INPUT,
                border=self.net_param.volume_padding_size))
        self.readers[0].add_preprocessing_layers(preprocessors)
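        # note: only the first reader receives the preprocessing layers;
        # any additional readers built above keep raw intensities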

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
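        """
        Training uses the unmodified ``RegressionApplication`` pipeline;
        at inference the network output is turned into a squared-error
        map against the ground truth before window aggregation.
        """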
        if self.is_training:
            # using the original training pipeline
            RegressionApplication.connect_data_and_network(
                self, outputs_collector, gradients_collector)
        else:
            init_aggregator = \
                self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]
            init_aggregator()
            # modifying the original pipeline so that
            # error maps are computed instead of the regression output
            with tf.name_scope('validation'):
                data_dict = self.get_sampler()[0][-1].pop_batch_op()
                image = tf.cast(data_dict['image'], tf.float32)
                net_out = self.net(image, is_training=self.is_training)

                if self.regression_param.error_map:
                    # writing error maps to a folder, without a filename prefix
                    error_map_folder = os.path.join(
                        os.path.dirname(self.action_param.save_seg_dir),
                        'error_maps')
                    self.output_decoder.output_path = error_map_folder
                    self.output_decoder.prefix = ''
                    # computes the voxel-wise squared error
                    target = tf.cast(data_dict['output'], tf.float32)
                    net_out = tf.squared_difference(target, net_out)

                # window output and locations for aggregating volume results
                outputs_collector.add_to_collection(
                    var=net_out, name='window',
                    average_over_devices=False, collection=NETWORK_OUTPUT)
                outputs_collector.add_to_collection(
                    var=data_dict['image_location'], name='location',
                    average_over_devices=False, collection=NETWORK_OUTPUT)
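
# Usage sketch (an illustrative assumption, not part of the original module):
# with ``error_map = True`` set in the regression section of the configuration
# file, inference could be launched via NiftyNet's ``net_run`` entry point:
#
#     net_run inference \
#         -a niftynet.contrib.regression_weighted_sampler.isample_regression.ISampleRegression \
#         -c my_config.ini
#
# The resulting error maps are then written to an ``error_maps`` folder
# created next to ``save_seg_dir``, as set in ``connect_data_and_network``.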