"""Pairwise uniform window sampler.

Module: niftynet.contrib.sampler_pairwise.sampler_pairwise_uniform
"""
from __future__ import absolute_import, division, print_function
import numpy as np
import tensorflow as tf
#from tensorflow.contrib.data.python.ops.dataset_ops import Dataset
from niftynet.engine.image_window import ImageWindow
from niftynet.layer.base_layer import Layer
from niftynet.layer.grid_warper import AffineGridWarperLayer
from niftynet.layer.resampler import ResamplerLayer
from niftynet.layer.linear_resize import LinearResizeLayer as Resize
#from niftynet.layer.approximated_smoothing import SmoothingLayer as Smooth
class PairwiseUniformSampler(Layer):
    """
    Sample randomly-shifted windows from pairs of fixed/moving volumes.

    For every sample a subject index is drawn. The fixed image and label
    are read from ``reader_0`` and the moving image and label from
    ``reader_1``, using the same index for both readers. The moving volumes
    are resized to the fixed volume's spatial shape. Everything is
    concatenated along the last axis, and a ``window_size`` window is
    resampled at a random per-sample shift.

    Output window channels are ordered: fixed image, fixed label,
    moving image, moving label.
    """

    def __init__(self,
                 reader_0,
                 reader_1,
                 data_param,
                 batch_size=1):
        """
        :param reader_0: reader for the fixed images; its outputs are read
            with the keys ``fixed_image`` and ``fixed_label``
        :param reader_1: reader for the moving images; its outputs are read
            with the keys ``moving_image`` and ``moving_label``
        :param data_param: input specification used to build the sampling
            window from ``reader_0``'s properties
        :param batch_size: number of window pairs per batch
        :raises NotImplementedError: if the window has dynamic shapes
        :raises ValueError: if ``reader_0`` has no subjects
        """
        Layer.__init__(self, name='pairwise_sampler_uniform')
        # reader for the fixed images
        self.reader_0 = reader_0
        # reader for the moving images
        self.reader_1 = reader_1
        # TODO:
        # 0) check the readers should have the same length file list
        #    (the same subject index is used for both readers)
        # 1) detect window shape mismatches or defaulting
        #    windows to the fixed image reader properties
        # 2) reshape images to (supporting multi-modal data)
        #    [batch, x, y, channel] or [batch, x, y, z, channels]
        # 3) infer spatial rank
        # 4) make ``label`` optional
        self.batch_size = batch_size
        self.spatial_rank = 3
        self.window = ImageWindow.from_data_reader_properties(
            self.reader_0.input_sources,
            self.reader_0.shapes,
            self.reader_0.tf_dtypes,
            data_param)
        if self.window.has_dynamic_shapes:
            tf.logging.fatal('Dynamic shapes not supported.\nPlease specify '
                             'all spatial dims of the input data, for the '
                             'spatial_window_size parameter.')
            raise NotImplementedError
        # TODO: check spatial dims the same across input modalities
        self.image_shape = \
            self.reader_0.shapes['fixed_image'][:self.spatial_rank]
        self.moving_image_shape = \
            self.reader_1.shapes['moving_image'][:self.spatial_rank]
        # presumably drops a leading batch dimension of the window shape
        # (layer_op treats [:spatial_rank] of this as the spatial window)
        self.window_size = self.window.shapes['fixed_image'][1:]

        # initialise a dataset prefetching pairs of image and label volumes
        n_subjects = len(self.reader_0.output_list)
        if n_subjects < 1:
            raise ValueError(
                'PairwiseUniformSampler: reader_0 has no subjects to sample.')
        # Subject ids are drawn once (with replacement); the dataset below
        # repeats and shuffles this fixed sequence. np.random.randint's
        # default integer dtype is platform dependent, while the py_func
        # below declares its first output as int64, so cast explicitly.
        rand_ints = np.random.randint(
            n_subjects, size=[n_subjects]).astype(np.int64)
        image_dataset = tf.data.Dataset.from_tensor_slices(rand_ints)
        # map each subject id to
        # (id, fixed image+label, moving image+label, fixed shape,
        #  moving shape) through a tf.py_func wrapper of
        # ``get_pairwise_inputs``
        image_dataset = image_dataset.map(
            lambda image_id: tuple(tf.py_func(
                self.get_pairwise_inputs, [image_id],
                [tf.int64, tf.float32, tf.float32, tf.int32, tf.int32])),
            num_parallel_calls=4)  # supported by tf 1.4?
        image_dataset = image_dataset.repeat()  # num_epochs can be param
        image_dataset = image_dataset.shuffle(
            buffer_size=self.batch_size * 20)
        image_dataset = image_dataset.batch(self.batch_size)
        self.iterator = image_dataset.make_initializable_iterator()

    def get_pairwise_inputs(self, image_id):
        """
        Load the fixed and moving volumes of one subject.

        Each reader is called once, so the image and label of a pair come
        from a single reader output. Any per-call preprocessing the reader
        applies (if it has any, to be confirmed) is therefore shared by the
        image and its label.

        :param image_id: subject index, used for both readers
        :return: tuple ``(image_id, fixed_inputs, moving_inputs,
            fixed_shape, moving_shape)``. ``*_inputs`` are float32 arrays
            holding image and label concatenated along the last axis;
            ``*_shape`` are their shapes as 1-D int32 arrays. ``image_id``
            is returned as int64 to match the py_func output signature.
        """
        fixed_inputs, fixed_shape = self._read_pair(
            self.reader_0, 'fixed_image', 'fixed_label', image_id)
        moving_inputs, moving_shape = self._read_pair(
            self.reader_1, 'moving_image', 'moving_label', image_id)
        return (np.int64(image_id), fixed_inputs, moving_inputs,
                fixed_shape, moving_shape)

    def _read_pair(self, reader, image_key, label_key, image_id):
        """Read one subject and concatenate its image and label channels."""
        _, data, _ = reader(idx=image_id)
        volumes = [self._to_channels_last(data[key])
                   for key in (image_key, label_key)]
        combined = np.concatenate(volumes, axis=-1)
        return combined, np.asarray(combined.shape, dtype=np.int32)

    def _to_channels_last(self, array):
        """Cast to float32 and fold all non-spatial dims into one axis."""
        image = np.asarray(array).astype(np.float32)
        return np.reshape(
            image, list(image.shape[:self.spatial_rank]) + [-1])

    def _get_image(self, image_source_type, image_id):
        """
        Read one volume of subject ``image_id`` by source name.

        Names starting with ``fixed`` are read from ``reader_0``; all other
        names are read from ``reader_1``. Kept for backward compatibility,
        since ``get_pairwise_inputs`` now reads image and label together.

        :param image_source_type: source name (``str`` or ``bytes``),
            e.g. ``'fixed_image'`` or ``'moving_label'``
        :param image_id: subject index
        :return: ``(image, image_shape)`` with ``image`` a float32 array of
            shape ``spatial dims + [channels]`` and ``image_shape`` its
            int32 shape
        """
        try:
            image_source_type = image_source_type.decode()
        except AttributeError:
            pass
        if image_source_type.startswith('fixed'):
            _, data, _ = self.reader_0(idx=image_id)
        else:  # image_source_type.startswith('moving'):
            _, data, _ = self.reader_1(idx=image_id)
        image = self._to_channels_last(data[image_source_type])
        image_shape = np.asarray(image.shape).astype(np.int32)
        return image, image_shape

    def layer_op(self):
        """
        Build a batch of randomly-shifted windows from fixed/moving pairs.

        The moving volumes are resized to the first fixed volume's spatial
        shape and concatenated with the fixed volumes along the last axis.
        A window is then resampled from the combined volume using a random
        translation per sample.

        :return: ``(windows, locations)``. ``windows`` has shape
            ``[batch] + window spatial shape + [channels]``, with channels
            ordered fixed image, fixed label, moving image, moving label.
            Each row of ``locations`` is
            ``[image_id, zeros(spatial_rank), shift(spatial_rank)]``.
        :raises NotImplementedError: if ``spatial_rank`` is not 3
        """
        if self.spatial_rank != 3:
            # the translation-only affine constraints below are 3D-only
            raise NotImplementedError(
                'PairwiseUniformSampler only supports spatial_rank == 3.')

        image_id, fixed_inputs, moving_inputs, fixed_shape, moving_shape = \
            self.iterator.get_next()
        # TODO preprocessing layer modifying
        # image shapes will not be supported
        # assuming the same shape across modalities, using the first
        image_id.set_shape((self.batch_size,))
        image_id = tf.to_float(image_id)
        # last dim is 1 image + 1 label
        fixed_inputs.set_shape(
            (self.batch_size,) + (None,) * self.spatial_rank + (2,))
        moving_inputs.set_shape(
            (self.batch_size,) + tuple(self.moving_image_shape) + (2,))
        fixed_shape.set_shape((self.batch_size, self.spatial_rank + 1))
        moving_shape.set_shape((self.batch_size, self.spatial_rank + 1))

        # resizing the moving_inputs to match the target
        # assumes the same shape across the batch
        target_spatial_shape = \
            tf.unstack(fixed_shape[0], axis=0)[:self.spatial_rank]
        moving_inputs = Resize(new_size=target_spatial_shape)(moving_inputs)
        combined_volume = tf.concat([fixed_inputs, moving_inputs], axis=-1)

        # TODO: optional Gaussian smoothing of the label channels
        # (e.g. with niftynet.layer.approximated_smoothing) and affine
        # data augmentation here

        window_spatial = tuple(self.window_size[:self.spatial_rank])
        # 4 = fixed image + fixed label + moving image + moving label
        window_channels = np.prod(self.window_size[self.spatial_rank:]) * 4
        # TODO if no affine augmentation:
        img_spatial_shape = target_spatial_shape
        # int32 to match the int32 image shape tensors
        win_spatial_shape = [tf.constant(int(dim), dtype=tf.int32)
                             for dim in window_spatial]
        # Random shift in [0, img - win - 1) per sample and axis. When
        # img == win, the upper bound is clamped to 0.01 so the shift stays
        # ~0.0; otherwise the interpolation would go out of bounds.
        batch_shift = [
            tf.random_uniform(
                shape=(self.batch_size, 1),
                minval=0,
                maxval=tf.maximum(tf.to_float(img - win - 1), 0.01))
            for (win, img) in zip(win_spatial_shape, img_spatial_shape)]
        batch_shift = tf.concat(batch_shift, axis=1)
        # Identity linear part with a free translation column (presumably
        # None marks the parameters supplied by ``batch_shift``).
        affine_constraints = ((1.0, 0.0, 0.0, None),
                              (0.0, 1.0, 0.0, None),
                              (0.0, 0.0, 1.0, None))
        computed_grid = AffineGridWarperLayer(
            source_shape=(None, None, None),
            output_shape=window_spatial,
            constraints=affine_constraints)(batch_shift)
        computed_grid.set_shape(
            (self.batch_size,) + window_spatial + (self.spatial_rank,))
        resampler = ResamplerLayer(
            interpolation='linear', boundary='replicate')
        windows = resampler(combined_volume, computed_grid)
        out_shape = [self.batch_size] + list(window_spatial) + \
            [window_channels]
        windows.set_shape(out_shape)

        image_id = tf.reshape(image_id, (self.batch_size, 1))
        # NOTE(review): the shift is written after a block of zeros; confirm
        # this matches the location layout expected by the consumer.
        start_location = tf.zeros((self.batch_size, self.spatial_rank))
        locations = tf.concat([
            image_id, start_location, batch_shift], axis=1)
        return windows, locations

    def run_threads(self, session, *args, **argvs):
        """
        Initialise the dataset iterator.

        Overrides the input-buffer start-up and is meant to be called at
        the beginning of running graph variables. Extra arguments are
        ignored.
        """
        session.run(self.iterator.initializer)