Source code for niftynet.engine.windows_aggregator_classifier

# -*- coding: utf-8 -*-
"""
windows aggregator resize each item
in a batch output and save as an image
"""
from __future__ import absolute_import, print_function, division

import os

import numpy as np

import niftynet.io.misc_io as misc_io
from niftynet.engine.windows_aggregator_base import ImageWindowsAggregator
from niftynet.layer.discrete_label_normalisation import \
    DiscreteLabelNormalisationLayer


[docs]class ClassifierSamplesAggregator(ImageWindowsAggregator): """ This class decodes each item in a batch by saving classification labels to a new image volume. """ def __init__(self, image_reader, name='image', output_path=os.path.join('.', 'output'), postfix='_niftynet_out'): ImageWindowsAggregator.__init__( self, image_reader=image_reader, output_path=output_path) self.name = name self.output_interp_order = 0 self.postfix = postfix self.csv_path = os.path.join(self.output_path, self.postfix+'.csv') if os.path.exists(self.csv_path): os.remove(self.csv_path)
[docs] def decode_batch(self, window, location): """ window holds the classifier labels location is a holdover from segmentation and may be removed in a later refactoring, but currently hold info about the stopping signal from the sampler """ n_samples = window.shape[0] for batch_id in range(n_samples): if self._is_stopping_signal(location[batch_id]): return False self.image_id = location[batch_id, 0] self._save_current_image(window[batch_id, ...]) return True
def _save_current_image(self, image_out): if self.input_image is None: return window_shape = [1, 1, 1, 1, image_out.shape[-1]] image_out = np.reshape(image_out, window_shape) for layer in reversed(self.reader.preprocessors): if isinstance(layer, DiscreteLabelNormalisationLayer): image_out, _ = layer.inverse_op(image_out) subject_name = self.reader.get_subject_id(self.image_id) filename = "{}{}.nii.gz".format(subject_name, self.postfix) source_image_obj = self.input_image[self.name] misc_io.save_data_array(self.output_path, filename, image_out, source_image_obj, self.output_interp_order) with open(self.csv_path, 'a') as csv_file: data_str = ','.join([str(i) for i in image_out[0, 0, 0, 0, :]]) csv_file.write(subject_name+','+data_str+'\n') self.log_inferred(subject_name, filename) return