Source code for niftynet.engine.windows_aggregator_identity

# -*- coding: utf-8 -*-
windows aggregator saves each item in a batch output as an image.
from __future__ import absolute_import, print_function, division

import os

import as misc_io
from niftynet.engine.windows_aggregator_base import ImageWindowsAggregator

[docs]class WindowAsImageAggregator(ImageWindowsAggregator): """ This class saves each item in a batch output to images, the output filenames can be defined in three ways: 1. location is None (input image from a random distribution): a uuid is generated as output filename. 2. the length of the location array is 2: (indicates output image is from an interpolation of two input images): - ``location[batch_id, 0]`` is used as a ``base_name``, - ``location[batch_id, 0]`` is used as a ``relative_id`` output file name is ``"{}_{}"%(base_name, relative_id)``. 3. the length of the location array is greater than 2: (indicates output image is from single input image) ``location[batch_id, 0]`` is used as the file name """ def __init__(self, image_reader=None, output_path=os.path.join('.', 'output'), postfix='_niftynet_generated'): ImageWindowsAggregator.__init__( self, image_reader=image_reader, output_path=output_path) self.output_path = os.path.abspath(output_path) self.inferred_csv = os.path.join(self.output_path, 'inferred.csv') self.output_id = {'base_name': None, 'relative_id': 0} self.postfix = postfix if os.path.exists(self.inferred_csv): os.remove(self.inferred_csv) def _decode_subject_name(self, location=None): if self.reader: image_id = int(location) return self.reader.get_subject_id(image_id) import uuid return str(uuid.uuid4())
[docs] def decode_batch(self, window, location=None): if location is not None: n_samples = location.shape[0] for batch_id in range(n_samples): if self._is_stopping_signal(location[batch_id]): return False filename = self._decode_subject_name(location[batch_id, 0]) # if base file name changed, reset relative name index if filename != self.output_id['base_name']: self.output_id['base_name'] = filename self.output_id['relative_id'] = 0 # when location has two component, the name should # be constructed as a composite of two input filenames if len(location[batch_id]) == 2: output_name = '{}_{}'.format( self.output_id['base_name'], self._decode_subject_name(location[batch_id, 1])) else: output_name = self.output_id['base_name'] self._save_current_image(self.output_id['relative_id'], output_name, window[batch_id, ...]) self.output_id['relative_id'] += 1 return True n_samples = window.shape[0] for batch_id in range(n_samples): filename = self._decode_subject_name() self._save_current_image( batch_id, filename, window[batch_id, ...]) return False
def _save_current_image(self, idx, filename, image): if image is None: return uniq_name = "{}_{}{}.nii.gz".format(idx, filename, self.postfix) misc_io.save_data_array(self.output_path, uniq_name, image, None) with open(self.inferred_csv, 'a') as csv_file: filename = os.path.join(self.output_path, filename) csv_file.write('{},{}\n'.format(idx, uniq_name)) return