# Source code for niftynet.engine.sampler_uniform
# -*- coding: utf-8 -*-
"""
Generating uniformly distributed image window from input image
This can also be considered as a "random cropping" layer of the
input image.
"""
from __future__ import absolute_import, division, print_function
import numpy as np
import tensorflow as tf
from niftynet.engine.image_window import ImageWindow, N_SPATIAL
from niftynet.engine.image_window_buffer import InputBatchQueueRunner
from niftynet.layer.base_layer import Layer
# pylint: disable=too-many-arguments
class UniformSampler(Layer, InputBatchQueueRunner):
    """
    Sampler that crops windows at random locations from each input volume.

    Window coordinates are randomised along the spatial dimensions only
    (the first three dimensions of the image); any trailing dimensions
    are kept whole.  The layer therefore acts as a "random cropping"
    layer applied to the input image.
    """

    def __init__(self,
                 reader,
                 data_param,
                 batch_size,
                 windows_per_image,
                 queue_length=10):
        """
        Set up the window definition and the input queue.

        :param reader: image reader; called in ``layer_op`` to fetch
            volumes, and queried here for ``input_sources``, ``shapes``
            and ``tf_dtypes``
        :param data_param: passed through to
            ``ImageWindow.from_data_reader_properties``
        :param batch_size: dequeue size of the input queue
        :param windows_per_image: enqueue size of the input queue
        :param queue_length: capacity of the input queue
        """
        self.reader = reader
        Layer.__init__(self, name='input_buffer')
        InputBatchQueueRunner.__init__(
            self, capacity=queue_length, shuffle=True)

        tf.logging.info('reading size of preprocessed images')
        source_reader = self.reader
        self.window = ImageWindow.from_data_reader_properties(
            source_reader.input_sources,
            source_reader.shapes,
            source_reader.tf_dtypes,
            data_param)
        tf.logging.info('initialised window instance')

        self._create_queue_and_ops(
            self.window,
            enqueue_size=windows_per_image,
            dequeue_size=batch_size)
        tf.logging.info("initialised sampler output %s "
                        " [-1 for dynamic size]", self.window.shapes)

        # function used by ``layer_op`` to draw the window locations
        self.spatial_coordinates_generator = rand_spatial_coordinates

    # pylint: disable=too-many-locals
    def layer_op(self):
        """
        Generator feeding sampling windows to the input buffer.

        For every volume returned by ``self.reader()`` this method:

        1. completes the window shapes using the image shapes;
        2. asks ``self.spatial_coordinates_generator`` for random
           window coordinates;
        3. crops one window per coordinate row and stacks the crops
           along a new leading axis.

        Iteration stops as soon as the reader returns empty data.

        :return: output data dictionary ``{placeholders: data_array}``
        """
        while True:
            image_id, data, _ = self.reader(idx=None, shuffle=True)
            if not data:
                break

            image_shapes = {
                name: data[name].shape for name in self.window.names}
            static_window_shapes = \
                self.window.match_image_shapes(image_shapes)

            # random coordinates derived from window and image shapes
            coordinates = self.spatial_coordinates_generator(
                image_id,
                data,
                image_shapes,
                static_window_shapes,
                self.window.n_samples)

            # placeholders are the dictionary keys, so that the result
            # can be used directly as `feed_dict=output_dict`
            # by the enqueue operation
            output_dict = {}
            for name in list(data):
                coordinates_key = self.window.coordinates_placeholder(name)
                image_data_key = self.window.image_data_placeholder(name)

                # coordinates of all windows of this input source
                locations = coordinates[name]
                output_dict[coordinates_key] = locations

                # crop each window; column 0 of a row is skipped and
                # the remaining six values are the start/end corners
                crops = []
                for window_id in range(self.window.n_samples):
                    (x_start, y_start, z_start,
                     x_end, y_end, z_end) = locations[window_id, 1:]
                    try:
                        crop = data[name][
                            x_start:x_end, y_start:y_end, z_start:z_end, ...]
                        crops.append(crop[np.newaxis, ...])
                    except ValueError:
                        tf.logging.fatal(
                            "dimensionality miss match in input volumes, "
                            "please specify spatial_window_size with a "
                            "3D tuple and make sure each element is "
                            "smaller than the image length in each dim.")
                        raise

                output_dict[image_data_key] = (
                    np.concatenate(crops, axis=0)
                    if len(crops) > 1 else crops[0])

            # the output image shape should be
            # [enqueue_batch_size, x, y, z, time, modality]
            # where enqueue_batch_size = windows_per_image
            yield output_dict
def rand_spatial_coordinates(subject_id,
                             data,
                             img_sizes,
                             win_sizes,
                             n_samples=1):
    """
    Draw random, concentric window coordinates for every input source.

    ``win_sizes`` could be different (for example in segmentation network
    input image window size is ``32x32x10``,
    training label window is ``16x16x10`` -- the network reduces x-y plane
    spatial resolution.)

    This function handles this situation by first finding the largest
    window across these window definitions and drawing its starting
    corner uniformly from every position at which that window still fits
    inside the image (both image borders are reachable).
    The coordinates are then adjusted for each of the
    smaller window sizes (the output windows are concentric).

    :param subject_id: integer identifier written into the first column
        of every coordinate row
    :param data: image data from the reader; only checked for ``None``
    :param img_sizes: dictionary ``{name: image_shape}``; the first
        ``N_SPATIAL`` elements of all shapes must be identical
    :param win_sizes: dictionary ``{name: window_shape}``
    :param n_samples: number of windows to generate (at least one
        window is always generated)
    :return: dictionary with the keys of ``win_sizes``; each value is an
        integer array of shape ``(n_samples, 1 + 2 * N_SPATIAL)`` whose
        columns are ``[subject_id, starts..., ends...]``
    :raises AssertionError: if ``data`` is ``None`` or the largest window
        exceeds the image size in any spatial dimension
    :raises NotImplementedError: if the input sources have different
        spatial shapes
    """
    assert data is not None, "No input from image reader. Please check " \
                             "the configuration file."
    n_samples = max(n_samples, 1)

    # tuple() makes the shapes hashable even if they are given as lists
    uniq_spatial_size = set(tuple(img_size[:N_SPATIAL])
                            for img_size in img_sizes.values())
    if len(uniq_spatial_size) > 1:
        tf.logging.fatal("Don't know how to generate sampling "
                         "locations: Spatial dimensions of the "
                         "grouped input sources are not "
                         "consistent. %s", uniq_spatial_size)
        raise NotImplementedError
    uniq_spatial_size = uniq_spatial_size.pop()

    # find spatial window location based on the largest spatial window
    spatial_win_sizes = np.asarray(
        [win_size[:N_SPATIAL] for win_size in win_sizes.values()],
        dtype=np.int32)
    max_spatial_win = np.max(spatial_win_sizes, axis=0)
    max_coords = np.zeros((n_samples, N_SPATIAL), dtype=np.int32)
    for i in range(N_SPATIAL):
        assert uniq_spatial_size[i] >= max_spatial_win[i], \
            "window size {} is larger than image size {}".format(
                max_spatial_win[i], uniq_spatial_size[i])
        # the upper bound of randint is exclusive: ``+ 1`` makes the last
        # valid starting position (image size - window size) reachable
        n_positions = uniq_spatial_size[i] - max_spatial_win[i] + 1
        max_coords[:, i] = np.random.randint(
            0, max(n_positions, 1), n_samples)

    # the subject id column is identical for all input sources
    subject_column = np.ones((n_samples,), dtype=np.int32) * subject_id
    subject_column = subject_column[:, None]

    # adjust max spatial coordinates based on each spatial window size
    all_coordinates = {}
    for mod in list(win_sizes):
        win_size = np.asarray(win_sizes[mod][:N_SPATIAL], dtype=np.int32)
        # shift starting coords of the window
        # so that smaller windows are centred within the large windows
        half_win_diff = (max_spatial_win - win_size) // 2
        spatial_coords = np.zeros((n_samples, N_SPATIAL * 2), dtype=np.int32)
        spatial_coords[:, :N_SPATIAL] = max_coords + half_win_diff
        spatial_coords[:, N_SPATIAL:] = \
            spatial_coords[:, :N_SPATIAL] + win_size
        # include the subject id
        all_coordinates[mod] = np.concatenate(
            (subject_column, spatial_coords), axis=1)
    return all_coordinates