# Source code for niftynet.engine.sampler_weighted_v2

# -*- coding: utf-8 -*-
"""
Generating image window by weighted sampling map from input image
This can also be considered as a "weighted random cropping" layer of the
input image.
"""
from __future__ import absolute_import, division, print_function

import numpy as np
import tensorflow as tf

from niftynet.engine.sampler_uniform_v2 import UniformSampler
from niftynet.engine.image_window import N_SPATIAL


class WeightedSampler(UniformSampler):
    """
    Sampler that draws image windows according to a user-provided
    frequency map for each input volume.

    The likelihood of a voxel (and of the window around it) being sampled
    is proportional to its frequency in the map.  The window centres are
    produced by ``weighted_spatial_coordinates``, which uses a cumulative
    histogram of the map for efficiency.

    This layer can be considered as a "weighted random cropping" layer of
    the input image.
    """

    def __init__(self,
                 reader,
                 window_sizes,
                 batch_size=1,
                 windows_per_image=1,
                 queue_length=10,
                 name='weighted_sampler'):
        """
        Set up the parent ``UniformSampler`` and switch the window-centre
        sampling function to the weighted version.

        All arguments are forwarded unchanged, by keyword, to
        ``UniformSampler.__init__``.
        """
        parent_kwargs = {
            'reader': reader,
            'window_sizes': window_sizes,
            'batch_size': batch_size,
            'windows_per_image': windows_per_image,
            'queue_length': queue_length,
            'name': name,
        }
        UniformSampler.__init__(self, **parent_kwargs)
        tf.logging.info('Initialised weighted sampler window instance')
        # replace the parent's centre sampler with the weighted one
        self.window_centers_sampler = weighted_spatial_coordinates
def weighted_spatial_coordinates(
        n_samples, img_spatial_size, win_spatial_size, sampler_map):
    """
    Weighted sampling from a map.

    This function uses a cumulative histogram for fast sampling.

    see also `sampler_uniform.rand_spatial_coordinates`

    :param n_samples: number of random coordinates to generate
    :param img_spatial_size: input image size
    :param win_spatial_size: input window size
    :param sampler_map: sampling prior map, its spatial shape should be
        consistent with `img_spatial_size`
    :return: (n_samples, N_SPATIAL) int32 coordinates representing sampling
        window centres relative to img_spatial_size
    :raises AssertionError: if `sampler_map` is None or its leading
        N_SPATIAL dimensions differ from `img_spatial_size`
    """
    assert sampler_map is not None, \
        'sampling prior map is not specified, ' \
        'please check `sampler=` option in the config.'
    # Compare as arrays: `list == tuple` is always False in Python, which
    # would wrongly reject matching shapes given as different sequence types.
    assert np.array_equal(
        np.asarray(img_spatial_size)[:N_SPATIAL],
        np.asarray(sampler_map.shape[:N_SPATIAL])), \
        'image and sampling map shapes do not match'

    win_spatial_size = np.asarray(win_spatial_size, dtype=np.int32)
    # only voxels whose surrounding window fits in the image are candidates
    cropped_map = crop_sampling_map(sampler_map, win_spatial_size)
    # `flatten` returns a copy, so the shift below never touches the input
    flatten_map = cropped_map.flatten()
    flatten_map_min = np.min(flatten_map)
    if flatten_map_min < 0:
        # shift so that every frequency is non-negative
        flatten_map -= flatten_map_min
    normaliser = flatten_map.sum()
    n_candidates = len(flatten_map)

    # get the sorting indexes so that we can invert the sorting later on.
    sorted_indexes = np.argsort(flatten_map)
    # an all-zero map cannot be normalised; fall back to uniform sampling
    # (and skip the 0/0 division that would only yield NaNs and a warning)
    uniform_fallback = normaliser == 0
    if uniform_fallback:
        sorted_data = None
    else:
        # cumulative sum of the normalised, sorted frequencies
        sorted_data = np.cumsum(
            np.true_divide(flatten_map[sorted_indexes], normaliser))

    middle_coords = np.zeros((n_samples, N_SPATIAL), dtype=np.int32)
    for sample in range(n_samples):
        # get n_sample from the cumulative histogram, spaced by 1/n_samples,
        # plus a random perturbation to give us a stochastic sampler
        sample_ratio = 1 - (np.random.random() + sample) / n_samples
        try:
            if uniform_fallback:
                sample_index = np.random.randint(n_candidates)
            else:
                # first index where the cumulative value reaches the
                # threshold.  Rounding can leave the final cumulative value
                # slightly below 1.0; clamp to the last (highest-frequency)
                # entry instead of silently wrapping to index 0.
                sample_index = min(
                    int(np.searchsorted(
                        sorted_data, sample_ratio, side='left')),
                    n_candidates - 1)
        except ValueError:
            tf.logging.fatal("unable to choose sampling window based on "
                             "the current frequency map.")
            raise
        # invert the sample index to the pre-sorted index
        inverted_sample_index = sorted_indexes[sample_index]
        # get the x,y,z coordinates on the cropped_map
        middle_coords[sample, :N_SPATIAL] = np.unravel_index(
            inverted_sample_index, cropped_map.shape)[:N_SPATIAL]

    # re-shift coords due to the crop
    half_win = np.floor(win_spatial_size / 2).astype(np.int32)
    middle_coords[:, :N_SPATIAL] = \
        middle_coords[:, :N_SPATIAL] + half_win[:N_SPATIAL]
    return middle_coords
def crop_sampling_map(input_map, win_spatial_size):
    """
    Utility function for generating a cropped version of the
    input sampling prior map (the input weight map where the centre of
    the window might be). If the centre of the window was outside of
    this crop area, the patch would be outside of the field of view.

    The map is indexed with three spatial slices followed by ``0, 0``,
    i.e. it is expected to have five dimensions; the first element along
    the 4th and 5th axes is used.

    :param input_map: the input weight map where the centre of
        the window might be
    :param win_spatial_size: size of the borders to be cropped
    :return: cropped sampling map (three-dimensional)
    :raises IndexError, KeyError, TypeError, AssertionError: re-raised
        after logging when the map and the window size are incompatible
        (e.g. the window is larger than the map along some axis)
    """
    # prepare cropping indices
    _start, _end = [], []
    for win_size, img_size in \
            zip(win_spatial_size[:N_SPATIAL], input_map.shape[:N_SPATIAL]):
        # cropping floor of the half window
        d_start = int(win_size / 2.0)
        # using ceil of half window
        d_end = img_size - win_size + int(win_size / 2.0 + 0.6)

        _start.append(d_start)
        # keep at least one voxel when the window exactly fills the axis
        _end.append(d_end + 1 if d_start == d_end else d_end)

    try:
        assert len(_start) == 3
        cropped_map = input_map[
            _start[0]:_end[0], _start[1]:_end[1], _start[2]:_end[2], 0, 0]
        # every cropped dimension must be non-empty; check each dimension
        # explicitly rather than comparing the boolean of np.all(...) to 0
        assert all(dim > 0 for dim in cropped_map.shape)
    except (IndexError, KeyError, TypeError, AssertionError):
        tf.logging.fatal(
            "incompatible map: %s and window size: %s\n"
            "try smaller (fully-specified) spatial window sizes?",
            input_map.shape, win_spatial_size)
        raise
    return cropped_map