Source code for niftynet.contrib.layer.rand_elastic_deform

# -*- coding: utf-8 -*-
# Data augmentation using elastic deformations, as used by:
# Milletari, F., Navab, N., & Ahmadi, S. A. (2016). V-Net:
# Fully convolutional neural networks for volumetric medical
# image segmentation.


from __future__ import absolute_import, print_function
from niftynet.utilities.util_import import check_module
check_module('SimpleITK')

import warnings

import numpy as np
import SimpleITK as sitk

from niftynet.layer.base_layer import RandomisedLayer

# Ignore every UserWarning and RuntimeWarning from this point on.
# NOTE(review): simplefilter() edits the global warnings filter list, so this
# silences these categories for all modules in the process, not only for this
# file -- confirm that this is intended.
warnings.simplefilter("ignore", UserWarning)
warnings.simplefilter("ignore", RuntimeWarning)


class RandomElasticDeformationLayer(RandomisedLayer):
    """
    Generate randomised elastic deformations along each dim
    for data augmentation.

    ``randomise`` draws a new B-spline transformation and ``layer_op``
    resamples every 3D volume of the inputs with it.  Only 3D volumes are
    handled (4D arrays with a trailing modality axis, or 5D arrays with
    trailing time and modality axes).
    """

    def __init__(self,
                 num_controlpoints=4,
                 std_deformation_sigma=15,
                 name='random_elastic_deformation'):
        """
        :param num_controlpoints: B-spline mesh size used along every
            dimension; values below 2 are raised to 2.
        :param std_deformation_sigma: standard deviation of the Gaussian
            noise added to the B-spline parameters; values below 1 are
            raised to 1.
        :param name: layer name forwarded to the base class.
        """
        super(RandomElasticDeformationLayer, self).__init__(name=name)
        self.num_controlpoints = max(num_controlpoints, 2)
        self.std_deformation_sigma = max(std_deformation_sigma, 1)
        # set by ``randomise``; None means "no usable transformation"
        self.bspline_transformation = None

    def randomise(self, image_dict, spatial_rank=3):
        """
        Draw a new random B-spline transformation for the given images.

        A transformation is only generated when ``spatial_rank`` is 3 and
        all arrays in ``image_dict`` have the same shape.  Otherwise a
        failure message is printed and any previous transformation is
        discarded, so that ``layer_op`` leaves its inputs untouched.

        :param image_dict: dictionary whose values are arrays exposing
            a ``shape`` attribute.
        :param spatial_rank: number of spatial dimensions; only 3 is
            supported.
        """
        # dict views are not subscriptable on Python 3, hence the list copy
        images = list(image_dict.values())
        equal_shapes = bool(images) and all(
            image.shape == images[0].shape for image in images)
        if spatial_rank == 3 and equal_shapes:
            self._randomise_bspline_transformation_3d(images[0].shape)
        else:
            # currently not supported spatial rank for elastic deformation;
            # drop any stale transformation from an earlier call so it is
            # not applied to data it was not generated for
            self.bspline_transformation = None
            print("randomising elastic deformation FAILED")

    def _randomise_bspline_transformation_3d(self, shape):
        """
        Initialise a B-spline transformation over a volume of size
        ``shape[:3]`` and perturb its parameters with Gaussian noise
        scaled by ``self.std_deformation_sigma``.

        The result is stored in ``self.bspline_transformation``.
        """
        # an all-zero volume is only used to define the transform domain
        itkimg = sitk.GetImageFromArray(np.zeros(shape[:3]))
        mesh_size = [self.num_controlpoints] * itkimg.GetDimension()
        self.bspline_transformation = sitk.BSplineTransformInitializer(
            itkimg, mesh_size)

        params = np.asarray(
            self.bspline_transformation.GetParameters(), dtype=float)
        noise = np.random.randn(params.shape[0]) * self.std_deformation_sigma
        params = params + noise
        # NOTE: the original author left a disabled option here that zeroed
        # the first third of the parameters to remove z deformations
        # ("the resolution in z is too bad"); all parameters are perturbed.
        self.bspline_transformation.SetParameters(tuple(params))

    def _apply_bspline_transformation_3d(self, image, interp_order=3):
        """
        Resample one 3D volume with the current B-spline transformation.

        The transformation is always applied; the random decision whether
        to deform at all is taken once per call in ``layer_op`` so that
        every field and modality is treated consistently.

        :param image: 3D array to be resampled.
        :param interp_order: 3 selects B-spline, 2 linear, and 1 or 0
            nearest-neighbour interpolation.
        :return: the resampled volume, or ``image`` itself when no
            transformation is available.
        :raises RuntimeError: if ``interp_order`` is not one of 0, 1, 2, 3.
        """
        if self.bspline_transformation is None:
            return image

        # NOTE(review): order 2 -> linear and order 1 -> nearest neighbour
        # is kept exactly as in the original code; confirm whether order 1
        # was meant to select linear interpolation.
        if interp_order == 3:
            interpolator = sitk.sitkBSpline
        elif interp_order == 2:
            interpolator = sitk.sitkLinear
        elif interp_order in (0, 1):
            interpolator = sitk.sitkNearestNeighbor
        else:
            raise RuntimeError("not supported interpolation_order")

        sitk_image = sitk.GetImageFromArray(image)
        resampler = sitk.ResampleImageFilter()
        resampler.SetReferenceImage(sitk_image)
        resampler.SetInterpolator(interpolator)
        resampler.SetDefaultPixelValue(0)
        resampler.SetTransform(self.bspline_transformation)
        return sitk.GetArrayFromImage(resampler.Execute(sitk_image))

    def layer_op(self, inputs, interp_orders, *args, **kwargs):
        """
        Apply the current elastic deformation to every volume in ``inputs``.

        With probability 0.5 (drawn once per call) the inputs are returned
        unchanged, so that either all fields/modalities are deformed with
        the same transformation or none of them is.  The arrays in
        ``inputs`` are modified in place.

        :param inputs: dictionary mapping field names to 4D arrays
            (last axis: modality) or 5D arrays (last two axes: time,
            modality); ``None`` is passed through.
        :param interp_orders: dictionary mapping each field name to a
            sequence holding one interpolation order per modality.
        :return: ``inputs``.
        :raises NotImplementedError: if the arguments are not dictionaries
            or an array is neither 4D nor 5D.
        :raises AssertionError: if the number of interpolation orders does
            not match the number of modalities of a field.
        """
        if inputs is None:
            return inputs
        if not (isinstance(inputs, dict) and isinstance(interp_orders, dict)):
            raise NotImplementedError("unknown input format")

        # validate everything first so a failure cannot leave the inputs
        # partially deformed
        for field, image in inputs.items():
            assert image.shape[-1] == len(interp_orders[field]), \
                "interpolation orders should be " \
                "specified for each inputs modality"
            if image.ndim not in (4, 5):
                raise NotImplementedError("unknown input format")

        # nothing to apply when randomise() was never called or failed
        if self.bspline_transformation is None:
            return inputs
        # do not apply deformations always, just sometimes; decided once
        # per call so images and labels stay spatially aligned
        if not np.random.rand() > 0.5:
            return inputs

        for field, image in inputs.items():
            for mod_i, interp_order in enumerate(interp_orders[field]):
                if image.ndim == 4:
                    inputs[field][..., mod_i] = \
                        self._apply_bspline_transformation_3d(
                            image[..., mod_i], interp_order)
                else:
                    for t in range(image.shape[-2]):
                        inputs[field][..., t, mod_i] = \
                            self._apply_bspline_transformation_3d(
                                image[..., t, mod_i], interp_order)
        return inputs