Source code for niftynet.engine.windows_aggregator_identity

# -*- coding: utf-8 -*-
"""
windows aggregator saves each item in a batch output as an image.
"""
from __future__ import absolute_import, division, print_function

import os
from collections import OrderedDict

import numpy as np
import pandas as pd

import niftynet.io.misc_io as misc_io
from niftynet.engine.windows_aggregator_base import ImageWindowsAggregator

# pylint: disable=too-many-branches


class WindowAsImageAggregator(ImageWindowsAggregator):
    """
    This class saves each item in a batch output to its own file.

    Entries of the ``window`` dictionary whose key contains ``'window'``
    are written as ``.nii.gz`` images; every other entry is flattened
    and written as a one-row ``.csv`` table.

    The output filenames can be defined in three ways:

    1. location is None (input image from a random distribution):
       a uuid is generated as output filename.
    2. the length of the location array is 2
       (indicates output image is from an interpolation of two
       input images):

       - ``location[batch_id, 0]`` is used as a ``base_name``,
       - ``location[batch_id, 1]`` is used as a ``relative_id``

       output file name is ``"{}_{}".format(base_name, relative_id)``.
    3. the length of the location array is greater than 2
       (indicates output image is from single input image):
       ``location[batch_id, 0]`` is used as the file name.
    """

    def __init__(self,
                 image_reader=None,
                 output_path=os.path.join('.', 'output'),
                 postfix='_niftynet_generated'):
        """
        :param image_reader: reader used to map image indices to subject
            ids; forwarded to the base class.
        :param output_path: folder receiving the generated files
            (stored as an absolute path).
        :param postfix: string inserted before the extension of every
            generated file name.
        """
        ImageWindowsAggregator.__init__(
            self, image_reader=image_reader, output_path=output_path)
        self.output_path = os.path.abspath(output_path)
        self.inferred_csv = os.path.join(self.output_path, 'inferred.csv')
        # 'relative_id' counts the outputs written so far under the
        # current 'base_name' so that repeated names stay unique.
        self.output_id = {'base_name': None, 'relative_id': 0}
        self.postfix = postfix
        # start from a clean log: records are appended during decoding
        if os.path.exists(self.inferred_csv):
            os.remove(self.inferred_csv)

    def _decode_subject_name(self, location=None):
        """
        Return the subject id for ``location``, or a random uuid string
        when there is no reader or no location to look up.

        :param location: value convertible to an ``int`` image index.
        :return: a name used as the output file base name.
        """
        # ``location is not None`` guards against ``int(None)`` when a
        # reader is available but the batch carries no location.
        if self.reader and location is not None:
            image_id = int(location)
            return self.reader.get_subject_id(image_id)
        import uuid
        return str(uuid.uuid4())

    def decode_batch(self, window, location=None):
        """
        Save every item of a batch of network outputs to disk.

        :param window: dictionary of batch outputs; each value is indexed
            by batch item along its first axis.
        :param location: optional array with one row per batch item, used
            to derive output file names.
        :return: ``False`` as soon as a stopping signal is found in
            ``location``, ``True`` once the whole batch has been saved.
        """
        if location is not None:
            n_samples = location.shape[0]
        else:
            n_samples = window[sorted(window)[0]].shape[0]

        for batch_id in range(n_samples):
            location_b = location[batch_id] if location is not None else None
            if self._is_stopping_signal(location_b):
                return False

            if location_b is not None:
                filename = self._decode_subject_name(location_b[0])
            else:
                filename = self._decode_subject_name()

            # if base file name changed, reset relative name index
            if filename != self.output_id['base_name']:
                self.output_id['base_name'] = filename
                self.output_id['relative_id'] = 0

            # when location has two components, the name should
            # be constructed as a composite of two input filenames
            if (location_b is not None) and (len(location_b) == 2):
                output_name = '{}_{}'.format(
                    self.output_id['base_name'],
                    self._decode_subject_name(location_b[1]))
            else:
                output_name = self.output_id['base_name']

            for key in window:
                output_name_k = '{}_{}'.format(output_name, key)
                if 'window' in key:
                    self._save_current_image(self.output_id['relative_id'],
                                             output_name_k,
                                             window[key][batch_id, ...])
                else:
                    csv_table = self._window_to_csv_table(
                        key, window[key], batch_id, n_samples)
                    self._save_current_csv(self.output_id['relative_id'],
                                           output_name_k, csv_table)
            self.output_id['relative_id'] += 1
        return True

    @staticmethod
    def _window_to_csv_table(key, values, batch_id, n_samples):
        """
        Build a one-row table holding the values of a single batch item.

        :param key: output name, used to build the column headers.
        :param values: array-like whose first axis has ``n_samples``
            entries.
        :param batch_id: index of the batch item to extract.
        :param n_samples: number of items in the batch.
        :return: a ``pandas.DataFrame`` with exactly one row.
        """
        # Work on a local reshaped array so the caller's dictionary is
        # left untouched.
        flat_values = np.asarray(values).reshape([n_samples, -1])
        n_elements = flat_values.shape[-1]
        if n_elements > 1:
            table_header = [
                '{}_{}'.format(key, idx) for idx in range(n_elements)]
        else:
            table_header = ['{}'.format(key)]
        # Select this item's own row: ravelling the whole batch would
        # always pair the header with the values of batch item 0.
        # The constructor replaces ``DataFrame.append`` (removed in
        # pandas 2.0).
        return pd.DataFrame(
            [list(flat_values[batch_id])], columns=table_header)

    def _unique_name(self, idx, filename, extension):
        """
        Build the output file name; a non-zero ``idx`` is used as prefix
        to keep names unique when the same base name is written again.
        """
        if idx == 0:
            return "{}{}{}".format(filename, self.postfix, extension)
        return "{}_{}{}{}".format(idx, filename, self.postfix, extension)

    def _append_inferred_record(self, idx, filename):
        """
        Append an ``idx,path`` line to the ``inferred.csv`` log.
        """
        # NOTE(review): the path recorded here is built from the base
        # ``filename`` (no index prefix, postfix or extension), not from
        # the name of the file actually written -- kept as is because
        # consumers of ``inferred.csv`` may rely on it; confirm.
        with open(self.inferred_csv, 'a') as csv_file:
            csv_file.write('{},{}\n'.format(
                idx, os.path.join(self.output_path, filename)))

    def _save_current_image(self, idx, filename, image):
        """
        Save ``image`` as a ``.nii.gz`` file in the output folder and
        log it in ``inferred.csv``; does nothing when ``image`` is None.

        :param idx: relative index of this output for the base name.
        :param filename: base name of the output file.
        :param image: image data to be saved.
        """
        if image is None:
            return
        uniq_name = self._unique_name(idx, filename, '.nii.gz')
        misc_io.save_data_array(self.output_path, uniq_name, image, None)
        self._append_inferred_record(idx, filename)

    def _save_current_csv(self, idx, filename, csv_data):
        """
        Save ``csv_data`` as a ``.csv`` file in the output folder and
        log it in ``inferred.csv``; does nothing when ``csv_data`` is
        None.

        :param idx: relative index of this output for the base name.
        :param filename: base name of the output file.
        :param csv_data: table to be saved.
        """
        if csv_data is None:
            return
        uniq_name = self._unique_name(idx, filename, '.csv')
        misc_io.save_csv_array(self.output_path, uniq_name, csv_data)
        self._append_inferred_record(idx, filename)