# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
import tensorflow as tf
from niftynet.application.base_application import BaseApplication
from niftynet.engine.application_factory import \
ApplicationNetFactory, InitializerFactory, OptimiserFactory
from niftynet.io.image_reader import ImageReader
from niftynet.engine.sampler_grid_v2 import GridSampler
from niftynet.engine.sampler_resize_v2 import ResizeSampler
from niftynet.engine.sampler_uniform_v2 import UniformSampler
from niftynet.engine.sampler_weighted_v2 import WeightedSampler
from niftynet.engine.sampler_balanced_v2 import BalancedSampler
from niftynet.engine.windows_aggregator_grid import GridSamplesAggregator
from niftynet.engine.windows_aggregator_resize import ResizeSamplesAggregator
from niftynet.engine.windows_aggregator_identity import WindowAsImageAggregator
# from niftynet.engine.windows_aggregator_classifier import ClassifierSamplesAggregator
from niftynet.layer.loss_segmentation import LossFunction
from niftynet.layer.post_processing import PostProcessingLayer
from niftynet.engine.application_variables import CONSOLE, TF_SUMMARIES, NETWORK_OUTPUT
# Input section names recognised by this application; every reader name
# requested in `initialise_dataset_loader` is a member of this set.
SUPPORTED_INPUT = {'image', 'label', 'weight', 'sampler', 'inferred'}
class MultiOutputApplication(BaseApplication):
    """Segmentation-style test application with several inference outputs.

    In training it builds a segmentation loss on the network output and
    collects its gradients. In inference it exports, for every window,
    the argmax map, the softmax map and two scalars (the sum and the
    minimum of the softmax output) together with the window location.
    """
    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    def __init__(self, net_param, action_param, action):
        """Store the configuration and register the sampling strategies.

        :param net_param: network-level parameters; this class reads
            ``batch_size``, ``queue_length``, ``window_sampling``,
            ``smaller_final_batch_mode``, ``reg_type``, ``decay``,
            ``keep_prob``, ``activation_function`` and the two
            initializer names from it.
        :param action_param: parameters of the current action
            (training or inference settings).
        :param action: the action being run; used when selecting file
            lists and in error messages.
        """
        BaseApplication.__init__(self)
        tf.logging.info('starting multioutput test')

        self.action = action
        self.net_param = net_param
        self.action_param = action_param

        # Both are filled in by `initialise_dataset_loader`.
        self.data_param = None
        self.multioutput_param = None

        # window_sampling name -> (training sampler initialiser,
        #                          inference sampler initialiser,
        #                          aggregator initialiser)
        # The tuple positions are consumed by `initialise_sampler`
        # (indices 0 and 1) and `initialise_aggregator` (index 2).
        self.SUPPORTED_SAMPLING = {
            'uniform': (self.initialise_uniform_sampler,
                        self.initialise_grid_sampler,
                        self.initialise_grid_aggregator),
            'weighted': (self.initialise_weighted_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
            'resize': (self.initialise_resize_sampler,
                       self.initialise_resize_sampler,
                       self.initialise_resize_aggregator),
            # Registered so that `initialise_balanced_sampler` is
            # reachable; inference mirrors the `weighted` entry.
            'balanced': (self.initialise_balanced_sampler,
                         self.initialise_grid_sampler,
                         self.initialise_grid_aggregator),
            'classifier': (self.initialise_resize_sampler,
                           self.initialise_resize_sampler,
                           self.initialise_classifier_aggregator),
            'identity': (self.initialise_uniform_sampler,
                         self.initialise_resize_sampler,
                         self.initialise_identity_aggregator),
        }

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create one ``ImageReader`` per file list of the current action.

        :param data_param: input data specification; stored as
            ``self.data_param`` and passed to every reader.
        :param task_param: task specification; stored as
            ``self.multioutput_param`` and passed to every reader.
        :param data_partitioner: object providing ``get_file_lists_by``.
        :raises ValueError: if the application is not in a training,
            inference or evaluation phase.
        """
        self.data_param = data_param
        self.multioutput_param = task_param

        # Choose which input sections the readers should load.
        if self.is_training:
            reader_names = ('image', 'label', 'weight', 'sampler')
        elif self.is_inference:
            # in the inference process use `image` input only
            reader_names = ('image',)
        elif self.is_evaluation:
            reader_names = ('image', 'label', 'inferred')
        else:
            tf.logging.fatal(
                'Action `%s` not supported. Expected one of %s',
                self.action, self.SUPPORTED_PHASES)
            raise ValueError(
                'Action `{}` not supported. Expected one of {}'.format(
                    self.action, self.SUPPORTED_PHASES))

        # `dataset_to_infer` is optional on the action parameters.
        reader_phase = getattr(self.action_param, 'dataset_to_infer', None)
        file_lists = data_partitioner.get_file_lists_by(
            phase=reader_phase, action=self.action)
        self.readers = [
            ImageReader(reader_names).initialise(
                data_param, task_param, file_list)
            for file_list in file_lists]

    def initialise_uniform_sampler(self):
        """Create one ``UniformSampler`` per reader.

        Referenced by the ``uniform`` and ``identity`` entries of
        ``SUPPORTED_SAMPLING``; takes the same arguments as the
        weighted sampler initialiser.
        """
        self.sampler = [[UniformSampler(
            reader=reader,
            window_sizes=self.data_param,
            batch_size=self.net_param.batch_size,
            windows_per_image=self.action_param.sample_per_volume,
            queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_weighted_sampler(self):
        """Create one ``WeightedSampler`` per reader."""
        self.sampler = [[WeightedSampler(
            reader=reader,
            window_sizes=self.data_param,
            batch_size=self.net_param.batch_size,
            windows_per_image=self.action_param.sample_per_volume,
            queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_resize_sampler(self):
        """Create one ``ResizeSampler`` per reader.

        Shuffling is enabled only while training.
        """
        self.sampler = [[ResizeSampler(
            reader=reader,
            window_sizes=self.data_param,
            batch_size=self.net_param.batch_size,
            shuffle=self.is_training,
            smaller_final_batch_mode=self.net_param.smaller_final_batch_mode,
            queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_grid_sampler(self):
        """Create one ``GridSampler`` per reader."""
        self.sampler = [[GridSampler(
            reader=reader,
            window_sizes=self.data_param,
            batch_size=self.net_param.batch_size,
            spatial_window_size=self.action_param.spatial_window_size,
            window_border=self.action_param.border,
            smaller_final_batch_mode=self.net_param.smaller_final_batch_mode,
            queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_balanced_sampler(self):
        """Create one ``BalancedSampler`` per reader."""
        self.sampler = [[BalancedSampler(
            reader=reader,
            window_sizes=self.data_param,
            batch_size=self.net_param.batch_size,
            windows_per_image=self.action_param.sample_per_volume,
            queue_length=self.net_param.queue_length)
            for reader in self.readers]]

    def initialise_grid_aggregator(self):
        """Decode outputs with a ``GridSamplesAggregator`` on reader 0."""
        self.output_decoder = GridSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix,
            fill_constant=self.action_param.fill_constant)

    def initialise_resize_aggregator(self):
        """Decode outputs with a ``ResizeSamplesAggregator`` on reader 0."""
        self.output_decoder = ResizeSamplesAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            window_border=self.action_param.border,
            interp_order=self.action_param.output_interp_order,
            postfix=self.action_param.output_postfix)

    def initialise_identity_aggregator(self):
        """Decode outputs with a ``WindowAsImageAggregator`` on reader 0."""
        self.output_decoder = WindowAsImageAggregator(
            image_reader=self.readers[0],
            output_path=self.action_param.save_seg_dir,
            postfix=self.action_param.output_postfix)

    def initialise_classifier_aggregator(self):
        """Placeholder for the classifier aggregator.

        No aggregator is created here, so ``self.output_decoder`` is
        not assigned by this method; a warning is logged to make that
        visible instead of failing silently.
        """
        # TODO: build a ClassifierSamplesAggregator(image_reader=...,
        # output_path=..., postfix=...) once the import of
        # niftynet.engine.windows_aggregator_classifier (commented out
        # at the top of this file) is restored.
        tf.logging.warning(
            'No aggregator is available for `classifier` window sampling; '
            '`output_decoder` is not initialised and inference outputs '
            'cannot be decoded.')

    def initialise_sampler(self):
        """Run the sampler initialiser matching the current phase.

        Uses the training initialiser while training and the inference
        initialiser during inference; does nothing in any other phase.
        """
        initialisers = self.SUPPORTED_SAMPLING[self.net_param.window_sampling]
        if self.is_training:
            initialisers[0]()
        elif self.is_inference:
            initialisers[1]()

    def initialise_aggregator(self):
        """Run the aggregator initialiser for the configured sampling."""
        self.SUPPORTED_SAMPLING[self.net_param.window_sampling][2]()

    def initialise_network(self):
        """Instantiate the ``toynet`` network as ``self.net``.

        Weight and bias regularisers are created only when ``reg_type``
        is ``l1`` or ``l2`` (case-insensitive) and ``decay`` is positive;
        otherwise both are ``None``.
        """
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type in ('l1', 'l2') and decay > 0:
            # Imported lazily so tf.contrib is only touched when needed.
            from tensorflow.contrib.layers.python.layers import regularizers
            make_regularizer = (
                regularizers.l2_regularizer if reg_type == 'l2'
                else regularizers.l1_regularizer)
            w_regularizer = make_regularizer(decay)
            b_regularizer = make_regularizer(decay)

        self.net = ApplicationNetFactory.create('toynet')(
            num_classes=self.multioutput_param.num_classes,
            w_initializer=InitializerFactory.get_initializer(
                name=self.net_param.weight_initializer),
            b_initializer=InitializerFactory.get_initializer(
                name=self.net_param.bias_initializer),
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer,
            acti_func=self.net_param.activation_function)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Wire the sampler output through the network.

        Training: computes the data loss (plus the mean regularisation
        loss when ``decay`` is positive), registers the gradients with
        ``gradients_collector`` and the data loss with
        ``outputs_collector`` for console and summary output.

        Inference: registers the argmax and softmax windows, the sum
        and minimum of the softmax output, and the window location as
        network outputs, then initialises the output aggregator.

        :param outputs_collector: collector receiving output tensors.
        :param gradients_collector: collector receiving the gradients
            (used in training only).
        """

        def switch_sampler(for_training):
            # The first sampler is used for training, the last one
            # otherwise (validation or inference).
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        if self.is_training:
            # extract data
            if self.action_param.validation_every_n > 0:
                # Pick the batch source at run time from `is_validation`.
                data_dict = tf.cond(
                    tf.logical_not(self.is_validation),
                    lambda: switch_sampler(for_training=True),
                    lambda: switch_sampler(for_training=False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {'is_training': self.is_training,
                        'keep_prob': self.net_param.keep_prob}
            net_out = self.net(image, **net_args)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)

            loss_func = LossFunction(
                n_class=self.multioutput_param.num_classes,
                loss_type=self.action_param.loss_type,
                softmax=self.multioutput_param.softmax)
            data_loss = loss_func(
                prediction=net_out,
                ground_truth=data_dict.get('label', None),
                weight_map=data_dict.get('weight', None))

            reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
            if self.net_param.decay > 0.0 and reg_losses:
                reg_loss = tf.reduce_mean(
                    [tf.reduce_mean(single_loss)
                     for single_loss in reg_losses])
                loss = data_loss + reg_loss
            else:
                loss = data_loss

            # set the optimiser and the gradient
            to_optimise = tf.trainable_variables()
            vars_to_freeze = \
                self.action_param.vars_to_freeze or \
                self.action_param.vars_to_restore
            if vars_to_freeze:
                import re
                var_regex = re.compile(vars_to_freeze)
                # Only optimise vars that are not frozen
                to_optimise = \
                    [v for v in to_optimise if not var_regex.search(v.name)]
                tf.logging.info(
                    "Optimizing %d out of %d trainable variables, "
                    "the other variables fixed (--vars_to_freeze %s)",
                    len(to_optimise),
                    len(tf.trainable_variables()),
                    vars_to_freeze)

            grads = self.optimiser.compute_gradients(
                loss, var_list=to_optimise, colocate_gradients_with_ops=True)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])
            # collecting output variables; note that the reported value
            # is the data loss, without the regularisation term
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=False, collection=CONSOLE)
            outputs_collector.add_to_collection(
                var=data_loss, name='loss',
                average_over_devices=True, summary_type='scalar',
                collection=TF_SUMMARIES)

        elif self.is_inference:
            data_dict = switch_sampler(for_training=False)
            image = tf.cast(data_dict['image'], tf.float32)
            net_args = {'is_training': self.is_training,
                        'keep_prob': self.net_param.keep_prob}
            net_out = self.net(image, **net_args)

            num_classes = self.multioutput_param.num_classes
            argmax_layer = PostProcessingLayer(
                'ARGMAX', num_classes=num_classes)
            softmax_layer = PostProcessingLayer(
                'SOFTMAX', num_classes=num_classes)
            arg_max_out = argmax_layer(net_out)
            soft_max_out = softmax_layer(net_out)

            # Scalar reductions over the whole softmax tensor.
            sum_prob_out = tf.reduce_sum(soft_max_out)
            min_prob_out = tf.reduce_min(soft_max_out)

            # The names registered here are the keys read back in
            # `interpret_output`.
            outputs_collector.add_to_collection(
                var=arg_max_out, name='window_argmax',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=soft_max_out, name='window_softmax',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=sum_prob_out, name='csv_sum',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=min_prob_out, name='csv_min',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=False, collection=NETWORK_OUTPUT)
            self.initialise_aggregator()

    def interpret_output(self, batch_output):
        """Hand one batch of inference outputs to the aggregator.

        :param batch_output: mapping holding the ``window_argmax``,
            ``window_softmax``, ``csv_sum``, ``csv_min`` and
            ``location`` entries registered during inference.
        :return: the result of ``output_decoder.decode_batch`` during
            inference; ``True`` in every other phase (presumably
            telling the caller to keep iterating -- confirm against
            the engine).
        """
        if self.is_inference:
            output_names = (
                'window_argmax', 'window_softmax', 'csv_sum', 'csv_min')
            return self.output_decoder.decode_batch(
                {name: batch_output[name] for name in output_names},
                batch_output['location'])
        return True