# Source code for niftynet.layer.rand_spatial_scaling

# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function

import warnings

import numpy as np
import scipy.ndimage

from niftynet.layer.base_layer import RandomisedLayer

# NOTE(review): these filters are installed process-wide as an import-time
# side effect, so every UserWarning and RuntimeWarning in the program is
# silenced once this module is imported -- presumably to quiet warnings
# from scipy.ndimage.zoom; confirm before narrowing the scope.
warnings.simplefilter(action="ignore", category=UserWarning)
warnings.simplefilter(action="ignore", category=RuntimeWarning)


class RandomSpatialScalingLayer(RandomisedLayer):
    """
    Generate randomised scaling along each spatial dimension for data
    augmentation.

    ``randomise`` draws one percentage per spatial dimension uniformly from
    ``[min_percentage, max_percentage)`` and turns it into a zoom factor
    ``(percentage + 100) / 100``; ``layer_op`` then resamples the images of
    an input dictionary with ``scipy.ndimage.zoom`` using those factors.
    """

    def __init__(self,
                 min_percentage=-10.0,
                 max_percentage=10.0,
                 name='random_spatial_scaling'):
        """
        :param min_percentage: lower bound of the scaling percentage; values
            below -99.9 are clamped to -99.9 so the zoom factor stays
            strictly positive.
        :param max_percentage: upper bound of the scaling percentage; must
            be greater than the (clamped) lower bound.
        :param name: layer name forwarded to ``RandomisedLayer``.
        :raises AssertionError: if the clamped lower bound is not strictly
            smaller than ``max_percentage``.
        """
        super(RandomSpatialScalingLayer, self).__init__(name=name)
        # Clamp first, then validate: checking only the raw bounds lets e.g.
        # (-200, -150) through, which after clamping becomes the inverted
        # range (-99.9, -150) and yields non-positive zoom factors.
        clamped_min = max(min_percentage, -99.9)
        assert clamped_min < max_percentage, \
            'min_percentage must be smaller than max_percentage ' \
            '(after clamping min_percentage to -99.9)'
        self._min_percentage = clamped_min
        self._max_percentage = max_percentage
        self._rand_zoom = None

    def randomise(self, spatial_rank=3):
        """
        Draw a new set of zoom factors, one per spatial dimension.

        :param spatial_rank: number of spatial dimensions; non-integer
            values are floored (e.g. 2.5 -> 2).
        """
        spatial_rank = int(np.floor(spatial_rank))
        rand_percentage = np.random.uniform(low=self._min_percentage,
                                            high=self._max_percentage,
                                            size=(spatial_rank,))
        self._rand_zoom = (rand_percentage + 100.0) / 100.0

    def _apply_transformation(self, image, interp_order=3):
        """
        Rescale ``image`` with the zoom factors drawn by ``randomise``.

        The zoom vector is padded with 1.0 up to ``image.ndim`` and only its
        first three entries are used, so at most the first three axes are
        resampled.

        :param image: 3-D or 4-D array; a 4-D array is resampled one slice
            of its last axis at a time.
        :param interp_order: spline order passed to ``scipy.ndimage.zoom``;
            a negative value returns ``image`` unchanged.
        :return: the rescaled array. Both 3-D and 4-D inputs yield a 4-D
            result (a 3-D input gets a trailing singleton axis), whereas the
            ``interp_order < 0`` path returns ``image`` with its original
            rank.
        :raises NotImplementedError: for any other number of dimensions.
        """
        if interp_order < 0:
            return image
        assert self._rand_zoom is not None, 'call randomise() first'
        full_zoom = np.array(self._rand_zoom)
        n_missing = image.ndim - len(full_zoom)
        if n_missing > 0:
            # axes without a drawn factor are left unscaled
            full_zoom = np.concatenate((full_zoom, np.ones(n_missing)))
        spatial_zoom = full_zoom[:3]
        if image.ndim == 4:
            # resample each slice of the last axis on its own, then restack
            scaled_slices = [
                scipy.ndimage.zoom(image[..., idx],
                                   spatial_zoom,
                                   order=interp_order)[..., np.newaxis]
                for idx in range(image.shape[-1])]
            return np.concatenate(scaled_slices, axis=-1)
        if image.ndim == 3:
            scaled = scipy.ndimage.zoom(image,
                                        spatial_zoom,
                                        order=interp_order)
            return scaled[..., np.newaxis]
        raise NotImplementedError('not implemented random scaling')

    def layer_op(self, inputs, interp_orders, *args, **kwargs):
        """
        Apply the current random scaling to every image in ``inputs``.

        :param inputs: dictionary ``field -> array``; each array is split
            along its last axis and every slice is rescaled independently.
            ``None`` is returned unchanged.
        :param interp_orders: dictionary ``field -> sequence of
            interpolation orders``; only the first entry of each field is
            used, for all slices of that field.
        :return: ``inputs``, with each field's array replaced in place by
            its rescaled version.
        :raises NotImplementedError: if ``inputs`` or ``interp_orders`` is
            not a dictionary.
        """
        if inputs is None:
            return inputs
        if not (isinstance(inputs, dict) and isinstance(interp_orders, dict)):
            raise NotImplementedError("unknown input format")
        for (field, image) in inputs.items():
            interp_order = interp_orders[field][0]
            # NOTE(review): for a 4-D field the result is 5-D when
            # interp_order >= 0 but stays 4-D when interp_order < 0, because
            # only the scaling path appends an extra axis -- confirm which
            # rank downstream layers expect.
            scaled_slices = [
                self._apply_transformation(image[..., mod_i],
                                           interp_order)[..., np.newaxis]
                for mod_i in range(image.shape[-1])]
            inputs[field] = np.concatenate(scaled_slices, axis=-1)
        return inputs