# Source code for niftynet.layer.rand_spatial_scaling
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
import warnings
import numpy as np
import scipy.ndimage
from niftynet.layer.base_layer import RandomisedLayer
# Ignore every UserWarning and RuntimeWarning from import time onwards.
# NOTE(review): warnings.simplefilter edits the interpreter-wide filter list,
# so this silences those categories for all code, not only this module;
# presumably it targets warnings raised by scipy.ndimage.zoom -- confirm and
# consider scoping it with warnings.catch_warnings() around the zoom calls.
for _ignored_category in (UserWarning, RuntimeWarning):
    warnings.simplefilter("ignore", _ignored_category)
del _ignored_category
class RandomSpatialScalingLayer(RandomisedLayer):
    """
    Generate randomised scaling along each spatial axis for data augmentation.

    For every spatial axis a percentage ``p`` is drawn uniformly between
    ``min_percentage`` and ``max_percentage`` and turned into the zoom factor
    ``(p + 100) / 100``. The same factors are applied to every field and
    modality passed to :meth:`layer_op` until :meth:`randomise` is called
    again.
    """

    def __init__(self,
                 min_percentage=-10.0,
                 max_percentage=10.0,
                 name='random_spatial_scaling'):
        """
        :param min_percentage: lower bound of the scaling percentage; it is
            clamped to -99.9 so that zoom factors stay positive.
        :param max_percentage: upper bound of the scaling percentage; must be
            greater than ``min_percentage`` and than -99.9.
        :param name: layer name forwarded to ``RandomisedLayer``.
        """
        super(RandomSpatialScalingLayer, self).__init__(name=name)
        assert min_percentage < max_percentage
        self._min_percentage = max(min_percentage, -99.9)
        self._max_percentage = max_percentage
        # Clamping the lower bound can invert the range when
        # max_percentage <= -99.9: numpy.random.uniform is undefined for
        # high < low, and such percentages give ~zero/negative zoom factors.
        assert self._min_percentage < self._max_percentage, \
            "max_percentage must be greater than -99.9"
        # Per-axis zoom factors; None until randomise() is called.
        self._rand_zoom = None

    def randomise(self, spatial_rank=3):
        """
        Draw a new zoom factor for each spatial axis.

        :param spatial_rank: number of spatial axes (floored to an int).
        """
        spatial_rank = int(np.floor(spatial_rank))
        rand_zoom = np.random.uniform(low=self._min_percentage,
                                      high=self._max_percentage,
                                      size=(spatial_rank,))
        # Convert percentage changes into multiplicative zoom factors.
        self._rand_zoom = (rand_zoom + 100.0) / 100.0

    def _apply_transformation(self, image, interp_order=3):
        """
        Zoom the first three axes of ``image`` by the current factors.

        :param image: 3-D array, or 4-D array whose last axis is zoomed slice
            by slice (presumably a time axis -- confirm with callers).
        :param interp_order: spline order passed to ``scipy.ndimage.zoom``.
        :return: 4-D array; a 3-D input gains a trailing singleton axis.
        :raises NotImplementedError: for any other number of dimensions.
        """
        assert self._rand_zoom is not None
        full_zoom = np.array(self._rand_zoom)
        # Axes without a random factor (spatial_rank < image.ndim) are padded
        # with 1.0 so they keep their size.
        n_missing = image.ndim - len(full_zoom)
        if n_missing > 0:
            full_zoom = np.hstack((full_zoom, np.ones(n_missing)))
        spatial_zoom = full_zoom[:3]
        if image.ndim == 4:
            # Zoom each 3-D volume along the last axis independently.
            volumes = [
                scipy.ndimage.zoom(image[..., idx], spatial_zoom,
                                   order=interp_order)
                for idx in range(image.shape[-1])
            ]
            return np.stack(volumes, axis=-1)
        if image.ndim == 3:
            scaled = scipy.ndimage.zoom(image, spatial_zoom,
                                        order=interp_order)
            return scaled[..., np.newaxis]
        raise NotImplementedError('not implemented random scaling')

    def layer_op(self, inputs, interp_orders, *args, **kwargs):
        """
        Apply the current random scaling to every field of ``inputs``.

        :param inputs: dict mapping field names to arrays whose last axis
            holds the modalities, or ``None``.
        :param interp_orders: dict mapping each field name to a sequence with
            one interpolation order per modality of that field.
        :return: ``inputs`` updated in place, or ``None`` when ``inputs`` is
            ``None``. Each modality slice ``image[..., i]`` goes through
            ``_apply_transformation``, which returns a 4-D array, so each
            resulting field is 5-D (a 4-D field gains a singleton axis before
            the modality axis).
        :raises NotImplementedError: unless both arguments are dicts.
        """
        if inputs is None:
            return inputs
        if not (isinstance(inputs, dict) and isinstance(interp_orders, dict)):
            raise NotImplementedError("unknown input format")
        for field, image in inputs.items():
            orders = interp_orders[field]
            # The two literals are concatenated at compile time; the space
            # keeps the message readable.
            assert image.shape[-1] == len(orders), \
                "interpolation orders should be " \
                "specified for each inputs modality"
            # Replacing the value of an existing key does not resize the
            # dict, so updating it while iterating items() is safe.
            inputs[field] = np.stack(
                [self._apply_transformation(image[..., mod_i], order)
                 for mod_i, order in enumerate(orders)],
                axis=-1)
        return inputs