Source code for niftynet.engine.image_window_buffer
# -*- coding: utf-8 -*-
"""
This module defines queues that store training/evaluation images (and labels).
"""
from __future__ import absolute_import, print_function, division
import threading
import numpy as np
import tensorflow as tf
from niftynet.io.misc_io import squeeze_spatial_temporal_dim
# pylint: disable=protected-access
class InputBatchQueueRunner(object):
"""
    This class defines a light wrapper around queue objects
    for input windows; the coordinates describe the original locations
    of the windows.
After initialisation, ``run_threads()`` can be called with
``tf.session`` and ``tf.coordinator`` to start generating samples
with multiple threads.
The sampling threads can be stopped by:
``close_all()`` called externally -- all threads quit immediately.
"""
# pylint: disable=too-many-instance-attributes
def __init__(self, capacity, shuffle=True):
# define queue properties
self.capacity = capacity
self.shuffle = shuffle
# dequeue size
self._batch_size = 1
# create queue and the associated operations
self.placeholders_dict = None
self.output_tensor = None
self._queue = None
self._enqueue_op = None
self._dequeue_func = None
self._query_queue_size_op = None
self._close_queue_op = None
# keep track of session and threads created by this class instance
self._session = None
self._coordinator = None
self._threads = []
self._window = None
def _create_queue_and_ops(self, window, enqueue_size=1, dequeue_size=1):
"""
Create a shuffled queue or FIFO queue, and create queue
operations. This should be called before ``tf.Graph.finalize``.
"""
self._window = window
try:
is_dynamic_window = window.has_dynamic_shapes
except AttributeError:
tf.logging.fatal(
"unrecognised window format, expecting a"
"niftynet.engine.image_window.ImageWindow instance")
raise
if is_dynamic_window and enqueue_size > 1:
tf.logging.warning(
"using dynamic window size, buffer input size is set to 1")
if is_dynamic_window and dequeue_size > 1:
tf.logging.warning(
"using dynamic window size, network batch size is set to 1")
_enqueue_size = 1 if is_dynamic_window else enqueue_size
# batch_size is 1 if is_dynamic_window because
# RandomShuffleQueue's DequeueMany and DequeueUpTo require
# the components to have specified shapes
self._batch_size = 1 if is_dynamic_window else dequeue_size
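        # the capacity is raised to at least 2.5x the batch size so that a
        # full batch can be dequeued while the shuffled queue still retains
        # its min_after_dequeue (= capacity // 2) elements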
self.capacity = int(max(
self.capacity, round(self._batch_size * 2.5)))
assert self._batch_size <= self.capacity, \
"batch size {} is larger than the buffer size {}, " \
"please increase the queue capacity " \
"or decrease the batch size".format(
self._batch_size, self.capacity)
tf.logging.info('buffering with %s windows', self.capacity)
try:
self.placeholders_dict = window.placeholders_dict(_enqueue_size)
except AttributeError:
tf.logging.fatal(
"unrecognised window format, expecting a"
"niftynet.engine.image_window.ImageWindow instance")
raise
names = list(self.placeholders_dict)
placeholders = list(self.placeholders_dict.values())
input_dtypes = [holder.dtype for holder in placeholders]
input_shapes = [holder.shape[1:] for holder in placeholders] \
if not is_dynamic_window else None
# create a queue
# pylint: disable=redefined-variable-type
if self.shuffle:
self._queue = tf.RandomShuffleQueue(
capacity=self.capacity,
min_after_dequeue=self.capacity // 2,
dtypes=input_dtypes,
shapes=input_shapes,
names=names,
name="shuffled_queue")
assert (self.capacity - self.capacity // 2) >= self._batch_size, \
"batch size larger than the largest possible dequeue size" \
"of the current queue capacity"
else:
self._queue = tf.FIFOQueue(
capacity=self.capacity,
dtypes=input_dtypes,
shapes=input_shapes,
names=names,
name="FIFO_queue")
# create queue operations
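        # dynamic-shape windows are enqueued/dequeued one element at a time;
        # fixed-shape windows are pushed and popped in groups with
        # enqueue_many/dequeue_many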
if is_dynamic_window:
self._enqueue_op = self._queue.enqueue(self.placeholders_dict)
self._dequeue_func = self._queue.dequeue
else:
self._enqueue_op = self._queue.enqueue_many(self.placeholders_dict)
self._dequeue_func = self._queue.dequeue_many
self._query_queue_size_op = self._queue.size()
self._close_queue_op = self._queue.close(cancel_pending_enqueues=True)
def __call__(self):
tf.logging.fatal(
            'input queue should be used with a '
            'niftynet.layer.base_layer.Layer instance, '
            'where a layer_op is implemented to provide '
'enqueue data')
raise NotImplementedError
def _push(self, thread_id):
tf.logging.info('New thread: %d', thread_id)
# pylint: disable=broad-except
try:
output_dict = None
for output_dict in self():
if self._session._closed or self._coordinator.should_stop():
break
self._session.run(self._enqueue_op, feed_dict=output_dict)
if output_dict is None:
tf.logging.fatal('no output from the sampler')
raise ValueError
# push a set of stopping patches
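            # windows filled with -1 unblock consumers still waiting on a
            # dequeue; downstream code may treat the -1 coordinates as
            # end-of-data markers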
for _ in range(self.capacity + self._batch_size):
if self._session._closed or self._coordinator.should_stop():
break
for name in list(output_dict):
output_dict[name] = np.ones_like(output_dict[name]) * -1
self._session.run(self._enqueue_op, feed_dict=output_dict)
except NotImplementedError:
self.close_all()
raise
except tf.errors.CancelledError:
pass
except Exception:
import sys
import traceback
exc_type, exc_value, exc_traceback = sys.exc_info()
traceback.print_exception(
exc_type, exc_value, exc_traceback, file=sys.stdout)
self.close_all()
raise
finally:
pass
def _current_queue_size(self):
        # for debugging purposes
if self._session._closed:
return 0
return self._session.run(self._query_queue_size_op)
    def pop_batch_op(self):
"""
This function is used when connecting a sampler output
to a network. e.g.::
            data_dict = self.get_sampler()[0].pop_batch_op()
net_output = net_model(data_dict, is_training)
.. caution::
            Note that it squeezes the 6-dim output tensor
            ``[batch, x, y, z, time, modality]``
            by removing all dims of length one.
:return: a tensorflow graph op
"""
assert all([thread.isAlive() for thread in self._threads]), \
"input sampling threads are not running"
if self._window.has_dynamic_shapes:
data_output = self._dequeue_func()
else:
data_output = self._dequeue_func(self._batch_size)
for (name, shape) in self._window.shapes.items():
# set first dim as the batch size
data_output[name].set_shape([self._batch_size] + list(shape[1:]))
for name in data_output:
data_output[name] = squeeze_spatial_temporal_dim(data_output[name])
# keep a copy of the sampler's output tensors
self.output_tensor = data_output
return data_output
    def run_threads(self, session, coord, num_threads=1):
"""
        This function should be called by application.driver,
        where a session and a coordinator are maintained; it
        starts sampling threads to fill the queue.
        Note that the threads will be blocked if there is no
        dequeue_op running, or the number of samples is less
        than the dequeue batch size.
:param session: a tensorflow session
:param coord: a tensorflow coordinator
        :param num_threads: an integer specifying the number of threads
:return:
"""
num_threads = max(int(num_threads), 1)
if num_threads > 1 and isinstance(self._queue, tf.FIFOQueue):
tf.logging.warning('Only one thread for FIFO Queues')
num_threads = 1
tf.logging.info('Starting preprocessing threads...')
self._session = session
self._coordinator = coord
for idx in range(num_threads):
self._threads.append(
threading.Thread(target=self._push, args=[idx]))
self._threads[idx].daemon = True
self._threads[idx].start()
    def close_all(self):
"""
        This function stops all threads immediately and closes the queue.
        Further enqueue/dequeue operations raise errors.
"""
if not self._threads:
tf.logging.warning("the queue threads is not currently running")
try:
self._coordinator.request_stop()
self._coordinator.join(threads=self._threads,
stop_grace_period_secs=0)
except (RuntimeError, AttributeError):
pass
finally:
if (self._session is not None) and (not self._session._closed):
self._session.run(self._close_queue_op)
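
A minimal usage sketch (an illustration, not part of the original module). It
assumes ``my_sampler`` is an ``InputBatchQueueRunner`` subclass (for example a
NiftyNet sampler) whose constructor has already called
``_create_queue_and_ops`` and whose ``__call__`` yields feed dicts for the
enqueue op::

    import tensorflow as tf

    batch_op = my_sampler.pop_batch_op()  # graph op that pops one batch
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        my_sampler.run_threads(sess, coord, num_threads=2)
        # blocks until a full batch has been buffered by the sampling threads
        data_dict = sess.run(batch_op)
        my_sampler.close_all()  # stop the sampling threads and close the queue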