# -*- coding: utf-8 -*-
"""
Windows aggregator resize each item
in a batch output and save as an image.
"""
from __future__ import absolute_import, division, print_function
import os
from collections import OrderedDict
import numpy as np
import pandas as pd
import niftynet.io.misc_io as misc_io
from niftynet.engine.sampler_resize_v2 import zoom_3d
from niftynet.engine.windows_aggregator_base import ImageWindowsAggregator
from niftynet.layer.discrete_label_normalisation import \
DiscreteLabelNormalisationLayer
from niftynet.layer.pad import PadLayer
[docs]class ResizeSamplesAggregator(ImageWindowsAggregator):
"""
This class decodes each item in a batch by resizing each image
window and save as a new image volume. Multiple output image can be
proposed and csv output can be performed as well
"""
def __init__(self,
image_reader,
name='image',
output_path=os.path.join('.', 'output'),
window_border=(),
interp_order=0,
postfix='niftynet_out'):
ImageWindowsAggregator.__init__(
self, image_reader=image_reader, output_path=output_path)
self.name = name
self.image_out = None
self.csv_out = None
self.window_border = window_border
self.output_interp_order = interp_order
self.postfix = postfix
self.current_out = {}
[docs] def decode_batch(self, window, location):
"""
Resizing each output image window in the batch as an image volume
location specifies the original input image (so that the
interpolation order, original shape information retained in the
generated outputs).For the fields that have the keyword 'window' in the
dictionary key, it will be saved as image. The rest will be saved as
csv. CSV files will contain at saving a first line of 0 (to be
changed into the header by the user), the first column being the
index of the window, followed by the list of output.
"""
n_samples = location.shape[0]
for batch_id in range(n_samples):
if self._is_stopping_signal(location[batch_id]):
return False
self.image_id = location[batch_id, 0]
self.image_out, self.csv_out = {}, {}
for key in window:
if 'window' in key:
# saving image output
while window[key].ndim < 5:
window[key] = window[key][..., np.newaxis, :]
self.image_out[key] = window[key][batch_id, ...]
else:
# saving csv output
window[key] = np.asarray(window[key]).reshape(
[n_samples, -1])
n_elements = window[key].shape[-1]
table_header = [
'{}_{}'.format(key, idx) for idx in range(n_elements)
] if n_elements > 1 else ['{}'.format(key)]
self.csv_out[key] = self._initialise_empty_csv(
key_names=table_header)
csv_row = window[key][batch_id:batch_id + 1, :].ravel()
self.csv_out[key] = self.csv_out[key].append(
OrderedDict(zip(table_header, csv_row)),
ignore_index=True)
self._save_current_image()
self._save_current_csv()
return True
def _initialise_image_shape(self, image_id, n_channels):
"""
Return the shape of the empty image to be saved
:param image_id: index to find the appropriate input image from the
reader
:param n_channels: number of channels of the image
:return: shape of the empty image
"""
self.image_id = image_id
spatial_shape = self.input_image[self.name].shape[:3]
output_image_shape = spatial_shape + (1, n_channels)
empty_image = np.zeros(output_image_shape, dtype=np.bool)
for layer in self.reader.preprocessors:
if isinstance(layer, PadLayer):
empty_image, _ = layer(empty_image)
return empty_image.shape
def _save_current_image(self):
"""
Loop through the dictionary of images output and resize and reverse
the preprocessing prior to saving
:return:
"""
if self.input_image is None:
return
self.current_out = {}
for i in self.image_out:
resize_to_shape = self._initialise_image_shape(
image_id=self.image_id, n_channels=self.image_out[i].shape[-1])
window_shape = resize_to_shape
current_out = self.image_out[i]
while current_out.ndim < 5:
current_out = current_out[..., np.newaxis, :]
if self.window_border and any([b > 0 for b in self.window_border]):
np_border = self.window_border
while len(np_border) < 5:
np_border = np_border + (0, )
np_border = [(b, ) for b in np_border]
current_out = np.pad(current_out, np_border, mode='edge')
image_shape = current_out.shape
zoom_ratio = \
[float(p) / float(d) for p, d in zip(window_shape, image_shape)]
image_shape = list(image_shape[:3]) + [1, image_shape[-1]]
current_out = np.reshape(current_out, image_shape)
current_out = zoom_3d(
image=current_out,
ratio=zoom_ratio,
interp_order=self.output_interp_order)
self.current_out[i] = current_out
for layer in reversed(self.reader.preprocessors):
if isinstance(layer, PadLayer):
for i in self.image_out:
self.current_out[i], _ = layer.inverse_op(
self.current_out[i])
if isinstance(layer, DiscreteLabelNormalisationLayer):
for i in self.image_out:
self.image_out[i], _ = layer.inverse_op(self.image_out[i])
subject_name = self.reader.get_subject_id(self.image_id)
for i in self.image_out:
filename = "{}_{}_{}.nii.gz".format(i, subject_name, self.postfix)
source_image_obj = self.input_image[self.name]
misc_io.save_data_array(self.output_path, filename,
self.current_out[i], source_image_obj,
self.output_interp_order)
self.log_inferred(subject_name, filename)
return
def _save_current_csv(self):
"""
Save all csv output present in the dictionary of csv_output.
:return:
"""
if self.input_image is None:
return
subject_name = self.reader.get_subject_id(self.image_id)
for i in self.csv_out:
filename = "{}_{}_{}.csv".format(i, subject_name, self.postfix)
misc_io.save_csv_array(self.output_path, filename, self.csv_out[i])
self.log_inferred(subject_name, filename)
return
def _initialise_empty_csv(self, key_names):
"""
Initialise the array to be saved as csv as a line of zeros according
to the number of elements to be saved
:param n_channel:
:return:
"""
return pd.DataFrame(columns=key_names)