Source code for niftynet.engine.sampler_weighted

# -*- coding: utf-8 -*-
"""
Generate image windows from an input image using a weighted sampling map.
This can also be considered a "weighted random cropping" layer applied to
the input image.
"""
from __future__ import absolute_import, division, print_function

import numpy as np
import tensorflow as tf

from niftynet.engine.image_window import N_SPATIAL
from niftynet.engine.sampler_uniform import UniformSampler


class WeightedSampler(UniformSampler):
    """
    Sampler that draws image windows according to a user-provided
    frequency map for each input volume.

    The likelihood of sampling a voxel (and the window around it) is
    proportional to the frequency assigned to that voxel. The sampling is
    done in closed form with cumulative histograms, for efficiency, over
    the spatial dims (i.e., the first three dims of the image).

    This layer can be considered as a "weighted random cropping" layer of
    the input image.
    """

    def __init__(self,
                 reader,
                 data_param,
                 batch_size,
                 windows_per_image,
                 queue_length=10):
        # All of the set-up is delegated to the uniform sampler; the only
        # difference is the coordinate generator installed afterwards.
        parent_kwargs = {
            'reader': reader,
            'data_param': data_param,
            'batch_size': batch_size,
            'windows_per_image': windows_per_image,
            'queue_length': queue_length,
        }
        UniformSampler.__init__(self, **parent_kwargs)
        tf.logging.info('Initialised weighted sampler window instance')
        # replace the window coordinate generator with the weighted version
        self.spatial_coordinates_generator = weighted_spatial_coordinates

def weighted_spatial_coordinates(subject_id,
                                 data,
                                 img_sizes,
                                 win_sizes,
                                 n_samples=1):
    """
    Draw window coordinates with a likelihood given by a frequency map.

    This is the function that actually builds the cumulative histogram of
    ``data['sampler']`` and samples window locations from it.

    Note that ``win_sizes`` could differ between sources (for example in a
    segmentation network the input image window size is 32x32x10 and the
    training label window is 16x16x10 -- the network reduces x-y plane
    spatial resolution). This is handled by first finding the largest
    window across the window definitions and generating coordinates for it.
    These coordinates are then adjusted for each of the smaller window
    sizes (the output windows are concentric).

    :param subject_id: id of the subject; copied into the first column of
        every returned row
    :param data: dict-like object whose ``'sampler'`` entry is the frequency
        map; it is indexed as ``[x, y, z, 0, 0]``
    :param img_sizes: dict of image shapes; the first ``N_SPATIAL`` elements
        of all shapes must be identical
    :param win_sizes: dict of window shapes, one per source
    :param n_samples: number of windows to draw (values below 1 are
        treated as 1)
    :return: dict with the keys of ``win_sizes``; each value is an integer
        array of shape ``(n_samples, 1 + 2 * N_SPATIAL)`` whose rows are
        ``[subject_id, start coordinates..., end coordinates...]``
    :raises RuntimeError: if no frequency map is provided
    :raises NotImplementedError: if the spatial shapes in ``img_sizes``
        are not consistent
    :raises AssertionError: if a window is larger than the image, or if no
        valid window location is left after cropping the frequency map
    :raises IndexError, KeyError: if the frequency map cannot be indexed
        as ``[x, y, z, 0, 0]``
    """
    # requiring a data['sampler'] as the frequency map.
    # the shape should be [x, y, z, 1, 1]
    if data is None or data.get('sampler', None) is None:
        tf.logging.fatal("input weight map not found. please check "
                         "the configuration file")
        raise RuntimeError
    n_samples = max(n_samples, 1)

    # shapes are converted to tuples so that list-typed shapes are hashable
    uniq_spatial_size = {tuple(img_size[:N_SPATIAL])
                         for img_size in img_sizes.values()}
    if len(uniq_spatial_size) > 1:
        tf.logging.fatal("Don't know how to generate sampling "
                         "locations: Spatial dimensions of the "
                         "grouped input sources are not "
                         "consistent. %s", uniq_spatial_size)
        raise NotImplementedError
    uniq_spatial_size = uniq_spatial_size.pop()

    # find spatial window location based on the largest spatial window
    spatial_win_sizes = np.asarray(
        [win_size[:N_SPATIAL] for win_size in win_sizes.values()],
        dtype=np.int32)
    max_spatial_win = np.max(spatial_win_sizes, axis=0)

    # testing window size
    # (raised explicitly rather than with `assert` so that the check is not
    # stripped by `python -O`; the exception type is kept for callers)
    for i in range(N_SPATIAL):
        if uniq_spatial_size[i] < max_spatial_win[i]:
            raise AssertionError(
                "window size {} is larger than image size {}".format(
                    max_spatial_win[i], uniq_spatial_size[i]))

    # get cropped version of the input weight map where the centre of
    # the window might be. If the centre of the window was outside of
    # this crop area, the patch would be outside of the field of view.
    # The stop index is written as `size - half_win` instead of `-half_win`
    # so that a window of size one (half_win == 0) keeps the whole axis
    # rather than only its first element.
    # NOTE: for an even window size the last valid window position is not
    # part of the crop, and a window as large as the image leaves an empty
    # crop, which is reported as an error below.
    half_win = max_spatial_win // 2
    sampler_map = data['sampler']
    try:
        crop = tuple(
            slice(half_win[axis], sampler_map.shape[axis] - half_win[axis])
            for axis in range(3))
        cropped_map = sampler_map[crop + (0, 0)]
    except (IndexError, KeyError):
        tf.logging.fatal("incompatible map: %s", data['sampler'].shape)
        raise
    if cropped_map.size == 0:
        tf.logging.fatal(
            "incompatible window size for weighted sampler. "
            "Please use smaller (fully-specified) spatial window sizes")
        raise AssertionError(
            "no valid window location in a frequency map of "
            "shape {}".format(sampler_map.shape))

    # Get the cumulative sum of the normalised sorted intensities
    # i.e. first sort the sampling frequencies, normalise them
    # to sum to one, and then accumulate them in order.
    # The sorting indexes are kept so that the sorting can be inverted
    # later on.
    flatten_map = cropped_map.flatten()
    sorted_indexes = np.argsort(flatten_map)
    normaliser = flatten_map.sum()
    if np.isfinite(normaliser) and normaliser > 0:
        sorted_data = np.cumsum(
            np.true_divide(flatten_map[sorted_indexes], normaliser))
    else:
        # the map cannot be normalised (e.g. it is all zeros):
        # fall back to a uniform choice among the valid locations
        sorted_data = None
    last_index = flatten_map.size - 1

    middle_coords = np.zeros((n_samples, N_SPATIAL), dtype=np.int32)
    for sample in range(n_samples):
        if sorted_data is None:
            sample_index = np.random.randint(flatten_map.size)
        else:
            # get n_sample from the cumulative histogram, spaced by
            # 1/n_samples, plus a random perturbation to give us a
            # stochastic sampler; the ratios cover the interval (0, 1]
            sample_ratio = \
                1 - (np.random.random() + sample) / float(n_samples)
            # find the index where the cumulative is above the sample
            # threshold. Rounding can leave the last cumulative value
            # slightly below one; in that case take the last (i.e.,
            # highest frequency) entry instead of argmax's default of 0
            above_threshold = sorted_data >= sample_ratio
            if above_threshold.any():
                sample_index = np.argmax(above_threshold)
            else:
                sample_index = last_index
        # invert the sample index to the pre-sorted index
        inverted_sample_index = sorted_indexes[sample_index]
        # get the x,y,z coordinates on the cropped_map
        # (note: we need to re-shift it later due to the crop)
        middle_coords[sample, :N_SPATIAL] = np.unravel_index(
            inverted_sample_index, cropped_map.shape)[:N_SPATIAL]

    # the subject id column is the same for every source
    subject_ids = np.ones((n_samples,), dtype=np.int32) * subject_id

    # adjust max spatial coordinates based on each mod spatial window size
    all_coordinates = {}
    for mod in list(win_sizes):
        win_size = np.asarray(win_sizes[mod][:N_SPATIAL], dtype=np.int32)
        half_win_diff = (max_spatial_win - win_size) // 2
        # shift starting coordinates of the window.
        # Note that we did not shift the centre coordinates
        # above to the corner of the window
        # because the shift is the same as the cropping amount.
        # Also, we need to add half_win_diff so that smaller windows
        # are centred within the large windows
        spatial_coords = np.zeros((n_samples, N_SPATIAL * 2), dtype=np.int32)
        spatial_coords[:, :N_SPATIAL] = \
            middle_coords[:, :N_SPATIAL] + half_win_diff[:N_SPATIAL]
        # the opposite corner of the window is
        # just adding the mod specific window size
        spatial_coords[:, N_SPATIAL:] = \
            spatial_coords[:, :N_SPATIAL] + win_size[:N_SPATIAL]
        # include the subject id
        all_coordinates[mod] = np.append(
            subject_ids[:, None], spatial_coords, axis=1)
    return all_coordinates