Source code for niftynet.application.autoencoder_application

import tensorflow as tf

from niftynet.application.base_application import BaseApplication
from niftynet.engine.application_factory import ApplicationNetFactory
from niftynet.engine.application_factory import OptimiserFactory
from niftynet.engine.application_variables import CONSOLE
from niftynet.engine.application_variables import NETWORK_OUTPUT
from niftynet.engine.application_variables import TF_SUMMARIES
from niftynet.engine.sampler_linear_interpolate_v2 import LinearInterpolateSampler
from niftynet.engine.sampler_resize_v2 import ResizeSampler
from niftynet.engine.windows_aggregator_identity import WindowAsImageAggregator
from niftynet.io.image_reader import ImageReader
from niftynet.layer.loss_autoencoder import LossFunction
from niftynet.utilities.util_common import look_up_operations

# Names of the input sections this application knows about.
SUPPORTED_INPUT = {'image', 'feature'}

# Valid values for the ``inference_type`` option; looked up when the
# application is not training.
SUPPORTED_INFERENCE = {
    'encode',
    'encode-decode',
    'sample',
    'linear_interpolation',
}


class AutoencoderApplication(BaseApplication):
    """Application that trains an autoencoder network and runs inference.

    When not training, the behaviour is selected by the ``inference_type``
    option of the task parameters, which must be one of
    ``SUPPORTED_INFERENCE``: ``'encode'``, ``'encode-decode'``, ``'sample'``
    or ``'linear_interpolation'``.
    """
    REQUIRED_CONFIG_SECTION = "AUTOENCODER"

    def __init__(self, net_param, action_param, action):
        """Store the configuration objects for later initialisation steps.

        :param net_param: network parameters (``batch_size``,
            ``queue_length``, ``reg_type``, ``decay`` and ``name`` are read
            by the other methods of this class)
        :param action_param: parameters of the current action
        :param action: the action itself; forwarded to the data partitioner
        """
        BaseApplication.__init__(self)
        tf.logging.info('starting autoencoder application')
        self.action = action

        self.net_param = net_param
        self.action_param = action_param

        # Both are filled in by ``initialise_dataset_loader``.
        self.data_param = None
        self.autoencoder_param = None

    def initialise_dataset_loader(
            self, data_param=None, task_param=None, data_partitioner=None):
        """Create the image readers needed by the current mode.

        Training builds one ``'image'`` reader per file list. For inference,
        ``'encode'`` and ``'encode-decode'`` read ``'image'`` from the first
        file list, ``'linear_interpolation'`` reads ``'feature'`` from the
        first file list, and ``'sample'`` needs no reader at all.

        :param data_param: input data parameters
        :param task_param: autoencoder section of the configuration
        :param data_partitioner: object providing ``get_file_lists_by``
        :raises NotImplementedError: when run in evaluation mode
        """
        self.data_param = data_param
        self.autoencoder_param = task_param

        # Fail early and explicitly: evaluation has no reader/sampler setup.
        if self.is_evaluation:
            raise NotImplementedError('Evaluation is not yet '
                                      'supported in this application.')

        if self.is_training:
            self._infer_type = None
        else:
            self._infer_type = look_up_operations(
                self.autoencoder_param.inference_type, SUPPORTED_INFERENCE)

        # Not every action parameter set defines ``dataset_to_infer``.
        reader_phase = getattr(self.action_param, 'dataset_to_infer', None)
        file_lists = data_partitioner.get_file_lists_by(
            phase=reader_phase, action=self.action)

        # read each line of csv files into an instance of Subject
        if self.is_training:
            self.readers = []
            for file_list in file_lists:
                reader = ImageReader(['image'])
                reader.initialise(data_param, task_param, file_list)
                self.readers.append(reader)
        elif self._infer_type in ('encode', 'encode-decode'):
            self.readers = [ImageReader(['image'])]
            self.readers[0].initialise(data_param,
                                       task_param,
                                       file_lists[0])
        elif self._infer_type == 'sample':
            # Images are generated from noise; nothing to read.
            self.readers = []
        elif self._infer_type == 'linear_interpolation':
            self.readers = [ImageReader(['feature'])]
            self.readers[0].initialise(data_param,
                                       task_param,
                                       file_lists[0])

        # NOTE: intensity normalisation is currently disabled:
        # if self.is_training or \
        #         self._infer_type in ('encode', 'encode-decode'):
        #     mean_var_normaliser = \
        #         MeanVarNormalisationLayer(image_name='image')
        #     self.reader.add_preprocessing_layers([mean_var_normaliser])

    def initialise_sampler(self):
        """Create the samplers matching the readers of the current mode.

        ``self.sampler`` ends up as a list holding one list of samplers
        (one per reader), or as an empty list for ``'sample'`` inference,
        which does not consume input data.
        """
        def resize_samplers(shuffle):
            # One whole-image resize sampler per reader.
            return [ResizeSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                windows_per_image=1,
                shuffle=shuffle,
                queue_length=self.net_param.queue_length)
                    for reader in self.readers]

        self.sampler = []
        if self.is_training:
            self.sampler.append(resize_samplers(shuffle=True))
        elif self._infer_type in ('encode', 'encode-decode'):
            self.sampler.append(resize_samplers(shuffle=False))
        elif self._infer_type == 'linear_interpolation':
            self.sampler.append([LinearInterpolateSampler(
                reader=reader,
                window_sizes=self.data_param,
                batch_size=self.net_param.batch_size,
                n_interpolations=self.autoencoder_param.n_interpolations,
                queue_length=self.net_param.queue_length)
                                 for reader in self.readers])

    def initialise_network(self):
        """Instantiate the network named by ``net_param.name``.

        L1 or L2 regularisers (same ``decay`` for weights and biases) are
        passed to the network when ``net_param.reg_type`` is ``'l1'`` or
        ``'l2'`` (case-insensitive) and ``net_param.decay`` is positive;
        otherwise both regularisers are ``None``.
        """
        w_regularizer = None
        b_regularizer = None
        reg_type = self.net_param.reg_type.lower()
        decay = self.net_param.decay
        if reg_type in ('l1', 'l2') and decay > 0:
            # Imported lazily, only when a regulariser is requested.
            from tensorflow.contrib.layers.python.layers import regularizers
            if reg_type == 'l2':
                make_regularizer = regularizers.l2_regularizer
            else:
                make_regularizer = regularizers.l1_regularizer
            w_regularizer = make_regularizer(decay)
            b_regularizer = make_regularizer(decay)

        self.net = ApplicationNetFactory.create(self.net_param.name)(
            w_regularizer=w_regularizer,
            b_regularizer=b_regularizer)

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """Build the graph connecting sampler output to the network.

        In training, this creates the optimiser, the loss (data loss plus,
        when ``net_param.decay > 0`` and regularisation losses exist, the
        mean of those losses) and its gradients, and registers console and
        TensorBoard outputs. In inference, it registers the tensors required
        by ``interpret_output`` and creates ``self.output_decoder``.

        The meaning of each ``net_output`` index is not defined in this
        class; the collector names used below are the only indication here
        of what each index holds.

        :param outputs_collector: collector receiving output tensors
        :param gradients_collector: collector receiving gradients (used in
            training only)
        :raises NotImplementedError: for an unknown inference type
        """
        def switch_sampler(for_training):
            # First sampler for training, last one for validation.
            with tf.name_scope('train' if for_training else 'validation'):
                sampler = self.get_sampler()[0][0 if for_training else -1]
                return sampler.pop_batch_op()

        def run_on_dummy_image():
            # Forward pass on a zero image so the whole network (and hence
            # the decoder layers used afterwards) gets constructed.
            image_size = (self.net_param.batch_size,) + \
                         self.action_param.spatial_window_size + (1,)
            dummy_image = tf.zeros(image_size)
            return self.net(dummy_image, is_training=False)

        if self.is_training:
            self.patience = self.action_param.patience
            self.mode = self.action_param.early_stopping_mode
            if self.action_param.validation_every_n > 0:
                data_dict = tf.cond(tf.logical_not(self.is_validation),
                                    lambda: switch_sampler(True),
                                    lambda: switch_sampler(False))
            else:
                data_dict = switch_sampler(for_training=True)

            image = tf.cast(data_dict['image'], tf.float32)
            net_output = self.net(image, is_training=self.is_training)

            with tf.name_scope('Optimiser'):
                optimiser_class = OptimiserFactory.create(
                    name=self.action_param.optimiser)
                self.optimiser = optimiser_class.get_instance(
                    learning_rate=self.action_param.lr)

            loss_func = LossFunction(loss_type=self.action_param.loss_type)
            data_loss = loss_func(net_output)
            loss = data_loss
            if self.net_param.decay > 0.0:
                reg_losses = tf.get_collection(
                    tf.GraphKeys.REGULARIZATION_LOSSES)
                if reg_losses:
                    reg_loss = tf.reduce_mean(
                        [tf.reduce_mean(single_loss)
                         for single_loss in reg_losses])
                    loss = loss + reg_loss
            self.total_loss = loss
            grads = self.optimiser.compute_gradients(
                loss, colocate_gradients_with_ops=True)
            # collecting gradients variables
            gradients_collector.add_to_collection([grads])

            # Scalars are reported both on the console and as summaries.
            scalar_outputs = ((self.total_loss, 'total_loss'),
                              (data_loss, 'variational_lower_bound'))
            for scalar, scalar_name in scalar_outputs:
                outputs_collector.add_to_collection(
                    var=scalar, name=scalar_name,
                    average_over_devices=True,
                    collection=CONSOLE)
                outputs_collector.add_to_collection(
                    var=scalar, name=scalar_name,
                    average_over_devices=True,
                    summary_type='scalar',
                    collection=TF_SUMMARIES)

            # Image summaries taken from fixed positions of the net output.
            image_outputs = ((4, 'Originals'),
                             (2, 'Means'),
                             (5, 'Variances'))
            for output_index, summary_name in image_outputs:
                outputs_collector.add_to_collection(
                    var=net_output[output_index], name=summary_name,
                    average_over_devices=False,
                    summary_type='image3_coronal',
                    collection=TF_SUMMARIES)
        elif self._infer_type in ('encode', 'encode-decode'):
            data_dict = self.get_sampler()[0][0].pop_batch_op()
            image = tf.cast(data_dict['image'], dtype=tf.float32)
            net_output = self.net(image, is_training=False)

            outputs_collector.add_to_collection(
                var=data_dict['image_location'], name='location',
                average_over_devices=True, collection=NETWORK_OUTPUT)

            if self._infer_type == 'encode-decode':
                outputs_collector.add_to_collection(
                    var=net_output[2], name='generated_image',
                    average_over_devices=True, collection=NETWORK_OUTPUT)
            if self._infer_type == 'encode':
                outputs_collector.add_to_collection(
                    var=net_output[7], name='embedded',
                    average_over_devices=True, collection=NETWORK_OUTPUT)

            self.output_decoder = WindowAsImageAggregator(
                image_reader=self.readers[0],
                output_path=self.action_param.save_seg_dir)
        elif self._infer_type == 'sample':
            net_output = run_on_dummy_image()
            # Random noise with the shape of the last network output is
            # pushed through the decoder.
            noise_shape = net_output[-1].shape.as_list()
            stddev = self.autoencoder_param.noise_stddev
            noise = tf.random_normal(shape=noise_shape,
                                     mean=0.0,
                                     stddev=stddev,
                                     dtype=tf.float32)
            partially_decoded_sample = self.net.shared_decoder(
                noise, is_training=False)
            decoder_output = self.net.decoder_means(
                partially_decoded_sample, is_training=False)

            outputs_collector.add_to_collection(
                var=decoder_output, name='generated_image',
                average_over_devices=True, collection=NETWORK_OUTPUT)
            self.output_decoder = WindowAsImageAggregator(
                image_reader=None,
                output_path=self.action_param.save_seg_dir)
        elif self._infer_type == 'linear_interpolation':
            # construct the entire network
            net_output = run_on_dummy_image()
            data_dict = self.get_sampler()[0][0].pop_batch_op()
            real_code = data_dict['feature']
            real_code = tf.reshape(real_code, net_output[-1].get_shape())
            partially_decoded_sample = self.net.shared_decoder(
                real_code, is_training=False)
            decoder_output = self.net.decoder_means(
                partially_decoded_sample, is_training=False)

            outputs_collector.add_to_collection(
                var=decoder_output, name='generated_image',
                average_over_devices=True, collection=NETWORK_OUTPUT)
            outputs_collector.add_to_collection(
                var=data_dict['feature_location'], name='location',
                average_over_devices=True, collection=NETWORK_OUTPUT)
            self.output_decoder = WindowAsImageAggregator(
                image_reader=self.readers[0],
                output_path=self.action_param.save_seg_dir)
        else:
            raise NotImplementedError

    def interpret_output(self, batch_output):
        """Hand one batch of inference outputs to ``self.output_decoder``.

        :param batch_output: dict of evaluated outputs, keyed by the names
            registered in ``connect_data_and_network``
        :return: ``True`` when training; otherwise whatever
            ``self.output_decoder.decode_batch`` returns for the batch
        """
        if self.is_training:
            return True

        infer_type = look_up_operations(
            self.autoencoder_param.inference_type,
            SUPPORTED_INFERENCE)
        if infer_type == 'encode':
            return self.output_decoder.decode_batch(
                {'window_embedded': batch_output['embedded']},
                batch_output['location'][:, 0:1])
        if infer_type == 'encode-decode':
            return self.output_decoder.decode_batch(
                {'window_generated_image': batch_output['generated_image']},
                batch_output['location'][:, 0:1])
        if infer_type == 'sample':
            # No source image, hence no location information.
            return self.output_decoder.decode_batch(
                {'generated_image': batch_output['generated_image']},
                None)
        if infer_type == 'linear_interpolation':
            return self.output_decoder.decode_batch(
                {'generated_image': batch_output['generated_image']},
                batch_output['location'][:, :2])
        return None