Source code for niftynet.engine.image_window

# -*- coding: utf-8 -*-
"""
This module provides an interface for data elements passed
from sampler to network.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

# Number of spatial dimensions of a window: partially specified window
# sizes are padded up to this length, and the location placeholders have
# ``1 + N_SPATIAL * 2`` elements per sample.
N_SPATIAL = 3
# Template used to derive the dictionary key of the location placeholder
# from an input name, e.g. ``'image'`` -> ``'image_location'``.
LOCATION_FORMAT = "{}_location"
# Data type of the location (coordinates) placeholders.
BUFFER_POSITION_DTYPE = tf.int32


# TF_NP_DTYPES = {tf.int32: np.int32, tf.float32: np.float32}


class ImageWindow(object):
    """
    Each window is associated with a tuple of coordinates.
    These data properties are used to create TF placeholders
    when constructing a TF graph. Samplers read the data specifications
    and fill the placeholder with data.
    """

    def __init__(self, names, shapes, dtypes):
        """
        :param names: iterable of input names (keys of ``shapes``
            and ``dtypes``)
        :param shapes: dict of ``{name: window shape}``; a ``None`` or
            zero component marks a dimension to be inferred at runtime
        :param dtypes: dict of ``{name: data type}``
        """
        self.names = names
        self.shapes = shapes
        self.dtypes = dtypes
        self.n_samples = 1
        self.has_dynamic_shapes = self._check_dynamic_shapes()
        # created lazily by ``placeholders_dict``
        self._placeholders_dict = None

    @classmethod
    def from_data_reader_properties(cls,
                                    source_names,
                                    image_shapes,
                                    image_dtypes,
                                    data_param):
        """
        Create a window instance with input data properties
        each property is grouped into dict, with pairs of
        image_name: data_value. Some input images is a
        concatenated data array from multiple data sources.
        example of input::

            source_names={
                'image': (u'modality1', u'modality2'),
                'label': (u'modality3',)},
            image_shapes={
                'image': (192, 160, 192, 1, 2),
                'label': (192, 160, 192, 1, 1)},
            image_dtypes={
                'image': tf.float32,
                'label': tf.float32},
            data_param={
                'modality1': ParserNamespace(spatial_window_size=(10, 10, 2)),
                'modality2': ParserNamespace(spatial_window_size=(10, 10, 2)),
                'modality3': ParserNamespace(spatial_window_size=(5, 5, 1))}

        see ``niftynet.io.ImageReader`` for more details.

        :param source_names: input image names
        :param image_shapes: tuple of image window shapes
        :param image_dtypes: tuple of image window data types
        :param data_param: dict of each input source specifications
        :return: an ImageWindow instance
        :raises TypeError: if ``source_names`` is not iterable
        :raises KeyError: if a name is missing from ``image_shapes``
            or a modality is missing from ``data_param``
        """
        try:
            input_names = tuple(source_names)
        except TypeError:
            tf.logging.fatal('image names should be a dictionary of strings')
            raise
        try:
            # complete window shapes based on user input and input_image sizes
            spatial_shapes = {
                name: _read_window_sizes(modalities, data_param)
                for (name, modalities) in source_names.items()}
            shapes = {
                name: _complete_partial_window_sizes(
                    spatial_shapes[name], image_shapes[name])
                for name in input_names}
        except KeyError:
            tf.logging.fatal('data_param wrong format %s', data_param)
            raise
        # create ImageWindow instance
        return cls(names=input_names,
                   shapes=shapes,
                   dtypes=image_dtypes)

    def set_spatial_shape(self, spatial_window):
        """
        Overrides all spatial window defined in input modalities sections
        this is useful when do inference with a spatial window
        which is different from the training specifications.

        If the placeholders were already created, they are rebuilt
        with the new shapes.

        :param spatial_window: tuple of integers specifying new shape
        :return:
        :raises TypeError, ValueError: if ``spatial_window`` cannot be
            converted to a list of integers
        """
        try:
            spatial_window = [int(win_size) for win_size in spatial_window]
        except (TypeError, ValueError):
            # TypeError covers a non-iterable window and ``None`` elements,
            # which would otherwise propagate without the log message.
            tf.logging.fatal("spatial window should be an array of int")
            raise
        self.shapes = {
            name: _complete_partial_window_sizes(
                spatial_window, self.shapes[name])
            for name in self.names}
        # update based on the latest spatial shapes
        self.has_dynamic_shapes = self._check_dynamic_shapes()
        if self._placeholders_dict is not None:
            self._update_placeholders_dict(n_samples=self.n_samples)

    def _update_placeholders_dict(self, n_samples=1):
        """
        (Re)create the image and location placeholders and store them
        in ``self._placeholders_dict``.

        :param n_samples: number of windows per batch; forced to 1
            when the window shapes are dynamic
        :raises TypeError: if a window shape is not iterable
        """
        # batch size=1 if the shapes are dynamic
        self.n_samples = 1 if self.has_dynamic_shapes else n_samples

        names = list(self.names)
        try:
            placeholders = [
                tf.placeholder(
                    dtype=self.dtypes[name],
                    shape=[self.n_samples] + list(self.shapes[name]),
                    name=name)
                for name in names]
        except TypeError:
            tf.logging.fatal(
                'shape should be defined as dict of iterable %s', self.shapes)
            raise

        # dictionary keys of the coordinates, one per input name
        location_names = [LOCATION_FORMAT.format(name) for name in names]
        # presumably a window identifier followed by start/end coordinates
        # of each spatial dim -- confirm against the samplers
        location_shape = [self.n_samples, 1 + N_SPATIAL * 2]
        # NOTE(review): the location placeholders are given the image name
        # as op name (not the ``*_location`` name), so the op names in the
        # graph duplicate the image placeholder names. Kept as is because
        # renaming graph ops could affect callers -- confirm before changing.
        placeholders.extend(
            tf.placeholder(dtype=BUFFER_POSITION_DTYPE,
                           shape=location_shape,
                           name=name)
            for name in self.names)
        self._placeholders_dict = dict(
            zip(names + location_names, placeholders))

    def placeholders_dict(self, n_samples=1):
        """
        This function create a dictionary with items of
        ``{name: placeholders}``
        name should match the queue input names
        placeholders corresponds to the image window data
        for each of these items an additional ``{location_name: placeholders}``
        is created to hold the spatial location of the image window

        The dictionary is created on the first call only; on later calls
        the cached dictionary is returned and ``n_samples`` is ignored.

        :param n_samples: specifies the number of image windows
        :return: a dictionary with window data and locations placeholders
        """
        if self._placeholders_dict is None:
            self._update_placeholders_dict(n_samples)
        return self._placeholders_dict

    def coordinates_placeholder(self, name):
        """
        get coordinates placeholder, location name is formed
        using ``LOCATION_FORMAT``

        :param name: input name string
        :return: coordinates placeholder
        :raises TypeError: if the placeholders are not initialised yet
        """
        try:
            return self._placeholders_dict[LOCATION_FORMAT.format(name)]
        except TypeError:
            # ``self._placeholders_dict`` is still None
            tf.logging.fatal('call placeholders_dict to initialise first')
            raise

    def image_data_placeholder(self, name):
        """
        get the image data placeholder by name

        :param name: input name string
        :return: image placeholder
        :raises TypeError: if the placeholders are not initialised yet
        """
        try:
            return self._placeholders_dict[name]
        except TypeError:
            # ``self._placeholders_dict`` is still None
            tf.logging.fatal('call placeholders_dict to initialise first')
            raise

    def _check_dynamic_shapes(self):
        """
        Check whether the shape of the window is fully specified

        :return: True indicates it's dynamic, False indicates
            the window size is fully specified.
        """
        for shape in self.shapes.values():
            try:
                # a ``None`` or zero component is resolved at runtime
                if any(not dim_length for dim_length in shape):
                    return True
            except TypeError:
                # a non-iterable entry cannot be dynamic; keep checking
                # the remaining shapes so that the result does not
                # depend on the dictionary order
                continue
        return False

    def match_image_shapes(self, image_shapes):
        """
        if the window has dynamic shapes, this function
        infers the fully specified shape from the image_shapes.

        :param image_shapes: dict of ``{name: image shape}``
        :return: dict of fully specified window shapes
        """
        if not self.has_dynamic_shapes:
            return self.shapes

        static_window_shapes = self.shapes.copy()
        # fill the None element in dynamic shapes using image_sizes
        for name in self.names:
            static_window_shapes[name] = tuple(
                win_size if win_size else image_shape
                for (win_size, image_shape)
                in zip(self.shapes[name], image_shapes[name]))
        return static_window_shapes
def _read_window_sizes(input_mod_list, input_data_param): """ Read window_size from config dict group them based on output names, this function ensures that in the multimodality case the spatial window sizes are the same across modalities. :param input_mod_list: list of input source names :param input_data_param: input source properties obtained by parameters parser :return: spatial window size """ try: window_sizes = [input_data_param[input_name].spatial_window_size for input_name in input_mod_list] except (AttributeError, TypeError, KeyError): tf.logging.fatal('unknown input_data_param format %s %s', input_mod_list, input_data_param) raise if not all(window_sizes): window_sizes = [win_size for win_size in window_sizes if win_size] uniq_window_set = set(window_sizes) if len(uniq_window_set) > 1: # pylint: disable=logging-format-interpolation tf.logging.fatal( "trying to combine input sources " "with different window sizes: %s", window_sizes) raise NotImplementedError window_shape = None if uniq_window_set: window_shape = uniq_window_set.pop() try: return tuple(int(win_size) for win_size in window_shape) except (TypeError, ValueError): pass try: # try to make it a tuple return int(window_shape), except (TypeError, ValueError): tf.logging.fatal('unknown spatial_window_size param %s, %s', input_mod_list, input_data_param) raise def _complete_partial_window_sizes(win_size, img_size): """ Window size can be partially specified in the config. This function complete the window size by making it the same ndim as img_size, and set the not added dim to size None. None values in window will be realised when each image is loaded. 
:param win_size: a tuple of (partial) window size :param img_size: a tuple of image size :return: a window size with the same ndim as image size, with None components to be inferred at runtime """ img_ndims = len(img_size) # crop win_size list if it's longer than img_size win_size = list(win_size[:img_ndims]) while len(win_size) < N_SPATIAL: win_size.append(-1) # complete win_size list if it's shorter than img_size while len(win_size) < img_ndims: win_size.append(img_size[len(win_size)]) # replace zero with full length in the n-th dim of image win_size = [win if win > 0 else None for win in win_size] return tuple(win_size)