# Source code for niftynet.layer.rand_flip
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
import warnings
import numpy as np
from niftynet.layer.base_layer import RandomisedLayer
# NOTE(review): these calls edit the interpreter-wide warnings filter list at
# import time, so every UserWarning and RuntimeWarning is silenced for the
# whole process, not only for code in this module. Each call inserts its
# filter at the front of the list, so RuntimeWarning ends up first.
warnings.simplefilter(action="ignore", category=UserWarning)
warnings.simplefilter(action="ignore", category=RuntimeWarning)
class RandomFlipLayer(RandomisedLayer):
    """
    Add a random flipping layer as pre-processing.

    ``randomise`` draws one independent flip decision per spatial axis.
    ``layer_op`` then applies those same decisions to the input, or to every
    transformable field of a dict input, until ``randomise`` is called again.
    """

    def __init__(self,
                 flip_axes,
                 flip_probability=0.5,
                 name='random_flip'):
        """
        :param flip_axes: indices of the axes that may be flipped. Any
            iterable of ints, or a single int, is accepted. It is copied
            into a tuple so that one-shot iterables (e.g. generators) are not
            exhausted by the repeated membership tests in
            ``_apply_transformation``.
        :param flip_probability: the probability, drawn independently for
            each axis, of performing the flip (default = 0.5)
        :param name: layer name, forwarded to the base-class constructor
        """
        super(RandomFlipLayer, self).__init__(name=name)
        if isinstance(flip_axes, (int, np.integer)):
            # a bare axis index: treat it as a one-element selection
            flip_axes = (flip_axes,)
        # Materialise once. ``in`` on a generator consumes it, which would
        # silently stop flipping after the first transformed input.
        self._flip_axes = tuple(flip_axes)
        self._flip_probability = flip_probability
        # per-axis boolean flip decisions; None until randomise() is called
        self._rand_flip = None

    def randomise(self, spatial_rank=3):
        """
        Draw a new flip decision for each spatial axis.

        :param spatial_rank: number of axes to draw decisions for.
            Non-integer values are floored (e.g. 2.5 -> 2).
        """
        spatial_rank = int(np.floor(spatial_rank))
        # entry k is True when axis k should be flipped (if k is in flip_axes)
        self._rand_flip = np.random.random(
            size=spatial_rank) < self._flip_probability

    def _apply_transformation(self, image):
        """
        Flip ``image`` along every axis that is both listed in ``flip_axes``
        and selected by the most recent call to ``randomise``.

        :param image: array-like; axis k is paired with the k-th decision
            drawn by ``randomise``
        :return: the flipped array. ``np.flip`` may return a view of the
            input, and if no axis is flipped the input object itself is
            returned.
        :raises AssertionError: if ``randomise`` has not been called yet
        """
        assert self._rand_flip is not None, "Flip is unset -- Error!"
        for axis_number, do_flip in enumerate(self._rand_flip):
            if axis_number in self._flip_axes and do_flip:
                image = np.flip(image, axis=axis_number)
        return image

    def layer_op(self, inputs, interp_orders=None, *args, **kwargs):
        """
        Apply the current random flips to ``inputs``.

        :param inputs: ``None``, a single array, or a dict mapping field
            names to arrays
        :param interp_orders: when ``inputs`` is a dict, a dict mapping the
            same field names to sequences of interpolation orders. Fields
            whose orders are negative are returned unchanged.
        :return: ``inputs`` with the flips applied. A dict input is updated
            in place and returned; ``None`` is returned as-is.
        :raises AssertionError: if a field mixes negative and non-negative
            interpolation orders, or if a flip is applied before
            ``randomise`` has been called
        """
        if inputs is None:
            return inputs
        if isinstance(inputs, dict) and isinstance(interp_orders, dict):
            for field, image_data in inputs.items():
                orders = interp_orders[field]
                assert (all(i < 0 for i in orders) or
                        all(i >= 0 for i in orders)), \
                    'Cannot combine interpolatable and non-interpolatable data'
                # Negative orders mark non-interpolatable data (see the
                # message above); such fields are passed through untouched.
                if orders[0] < 0:
                    continue
                # Rebinding an existing key while iterating items() is safe
                # because the dict's size does not change.
                inputs[field] = self._apply_transformation(image_data)
        else:
            # NOTE(review): a dict ``inputs`` with no dict ``interp_orders``
            # also lands here and is passed whole to np.flip. That only
            # avoids an error when no axis is drawn for flipping. Confirm
            # that callers always pass interp_orders alongside dict inputs.
            inputs = self._apply_transformation(inputs)
        return inputs