Source code for niftynet.engine.windows_aggregator_grid

# -*- coding: utf-8 -*-
"""
windows aggregator decode sampling grid coordinates and image id from
batch data, forms image level output and write to hard drive.
"""
from __future__ import absolute_import, division, print_function

import os
from collections import OrderedDict

import numpy as np
import pandas as pd

# pylint: disable=too-many-nested-blocks
# pylint: disable=too-many-branches
import niftynet.io.misc_io as misc_io
from niftynet.engine.windows_aggregator_base import ImageWindowsAggregator
from niftynet.layer.discrete_label_normalisation import \
    DiscreteLabelNormalisationLayer
from niftynet.layer.pad import PadLayer


[docs]class GridSamplesAggregator(ImageWindowsAggregator): """ This class keeps record of the currently cached image, initialised as all zeros, and the values are replaced by image window data decoded from batch. """ def __init__(self, image_reader, name='image', output_path=os.path.join('.', 'output'), window_border=(), interp_order=0, postfix='niftynet_out', fill_constant=0.0): ImageWindowsAggregator.__init__( self, image_reader=image_reader, output_path=output_path) self.name = name self.image_out = None self.csv_out = None self.window_border = window_border self.output_interp_order = interp_order self.postfix = postfix self.fill_constant = fill_constant
[docs] def decode_batch(self, window, location): """ Function used to save multiple outputs listed in the window dictionary. For the fields that have the keyword 'window' in the dictionary key, it will be saved as image. The rest will be saved as csv. CSV files will contain at saving a first line of 0 (to be changed into the header by the user), the first column being the index of the window, followed by the list of output and the location array for each considered window :param window: dictionary of output :param location: location of the input :return: """ n_samples = location.shape[0] location_cropped = {} for key in window: if 'window' in key: # all outputs to be created as images should # contained the keyword "window" window[key], location_cropped[key] = self.crop_batch( window[key], location, self.window_border) for batch_id in range(n_samples): image_id = location[batch_id, 0] if image_id != self.image_id: # image name changed: # save current result and create an empty result file self._save_current_image() self._save_current_csv() if self._is_stopping_signal(location[batch_id]): return False self.image_out, self.csv_out = {}, {} for key in window: if 'window' in key: # to be saved as image self.image_out[key] = self._initialise_empty_image( image_id=image_id, n_channels=window[key].shape[-1], dtype=window[key].dtype) else: # to be saved as csv file n_elements = np.int64( np.asarray(window[key]).size / n_samples) table_header = [ '{}_{}'.format(key, idx) for idx in range(n_elements) ] if n_elements > 1 else ['{}'.format(key)] table_header += [ 'coord_{}'.format(idx) for idx in range(location.shape[-1]) ] self.csv_out[key] = self._initialise_empty_csv( key_names=table_header) for key in window: if 'window' in key: x_start, y_start, z_start, x_end, y_end, z_end = \ location_cropped[key][batch_id, 1:] self.image_out[key][ x_start:x_end, y_start:y_end, z_start:z_end, ...] = \ window[key][batch_id, ...] else: window[key] = np.asarray(window[key]).reshape( [n_samples, -1]) window_save = window[key][batch_id:batch_id + 1, :] window_loc = location[batch_id:batch_id + 1, :] csv_row = np.concatenate([window_save, window_loc], 1) csv_row = csv_row.ravel() key_names = self.csv_out[key].columns self.csv_out[key] = self.csv_out[key].append( OrderedDict(zip(key_names, csv_row)), ignore_index=True) return True
def _initialise_empty_image(self, image_id, n_channels, dtype=np.float): """ Initialise an empty image in which to populate the output :param image_id: image_id to be used in the reader :param n_channels: numbers of channels of the saved output (for multimodal output) :param dtype: datatype used for the saving :return: the initialised empty image """ self.image_id = image_id spatial_shape = self.input_image[self.name].shape[:3] output_image_shape = spatial_shape + (n_channels, ) empty_image = np.zeros(output_image_shape, dtype=dtype) for layer in self.reader.preprocessors: if isinstance(layer, PadLayer): empty_image, _ = layer(empty_image) if self.fill_constant != 0.0: empty_image[:] = self.fill_constant return empty_image def _initialise_empty_csv(self, key_names): """ Initialise a csv output file with a first line of zeros :param n_channel: number of saved fields :return: empty first line of the array to be saved as csv """ return pd.DataFrame(columns=key_names) def _save_current_image(self): """ For all the outputs to be saved as images, go through the dictionary and save the resulting output after reversing the initial preprocessing :return: """ if self.input_image is None: return for layer in reversed(self.reader.preprocessors): if isinstance(layer, PadLayer): for i in self.image_out: self.image_out[i], _ = layer.inverse_op(self.image_out[i]) if isinstance(layer, DiscreteLabelNormalisationLayer): for i in self.image_out: self.image_out[i], _ = layer.inverse_op(self.image_out[i]) subject_name = self.reader.get_subject_id(self.image_id) for i in self.image_out: filename = "{}_{}_{}.nii.gz".format(i, subject_name, self.postfix) source_image_obj = self.input_image[self.name] misc_io.save_data_array(self.output_path, filename, self.image_out[i], source_image_obj, self.output_interp_order) self.log_inferred(subject_name, filename) return def _save_current_csv(self): """ For all output to be saved as csv, loop through the dictionary of output and create the csv :return: """ if self.input_image is None: return subject_name = self.reader.get_subject_id(self.image_id) for i in self.csv_out: filename = "{}_{}_{}.csv".format(i, subject_name, self.postfix) misc_io.save_csv_array(self.output_path, filename, self.csv_out[i]) self.log_inferred(subject_name, filename) return