Source code for niftynet.layer.rand_elastic_deform

# -*- coding: utf-8 -*-
"""
Data augmentation using elastic deformations as used by:
Milletari, F., Navab, N., & Ahmadi, S. A. (2016) V-net:
Fully convolutional neural networks for volumetric medical
image segmentation
"""

from __future__ import absolute_import, print_function

import warnings

import numpy as np

from niftynet.layer.base_layer import RandomisedLayer
from niftynet.utilities.util_import import require_module

# Resolve SimpleITK through the project's import helper instead of a plain
# `import`. NOTE(review): presumably `require_module` returns a falsy
# placeholder when the package is unavailable -- the layer in this module
# tests `if not sitk` to switch augmentation off; confirm in util_import.
sitk = require_module('SimpleITK')

# Silence every UserWarning and RuntimeWarning. These filters are installed
# at import time and are process-wide, so they also affect warnings raised
# by code outside this module.
warnings.simplefilter("ignore", UserWarning)
warnings.simplefilter("ignore", RuntimeWarning)

class RandomElasticDeformationLayer(RandomisedLayer):
    """
    Generate randomised elastic deformations along each dim for
    data augmentation.

    Usage is two-step: ``randomise`` draws a new random B-spline
    transformation, then ``layer_op`` applies that same transformation to
    every modality of every input field.
    """

    def __init__(self,
                 num_controlpoints=4,
                 std_deformation_sigma=15,
                 proportion_to_augment=0.5,
                 spatial_rank=3):
        """
        This layer elastically deforms the inputs,
        for data-augmentation purposes
        (may be computationally expensive).

        :param num_controlpoints: B-spline mesh size used along every
            image dimension; values below 2 are raised to 2
        :param std_deformation_sigma: standard deviation of the Gaussian
            noise added to the B-spline parameters; values below 1 are
            raised to 1
        :param proportion_to_augment: what fraction of the images to do
            augmentation on; forced to -1 (never augment) when SimpleITK
            is not available
        :param spatial_rank: number of leading (spatial) dimensions that
            must agree across inputs, and the minimum dimensionality of
            the arrays handed to the resampler
        """
        super(RandomElasticDeformationLayer, self).__init__(
            name='random_elastic_deformation')

        self._bspline_transformation = None
        self.num_controlpoints = max(num_controlpoints, 2)
        self.std_deformation_sigma = max(std_deformation_sigma, 1)
        self.proportion_to_augment = proportion_to_augment
        if not sitk:
            # a negative proportion disables both randomise and layer_op
            self.proportion_to_augment = -1
        self.spatial_rank = spatial_rank

    def randomise(self, image_dict):
        """
        Draw a new random B-spline transformation for the given images.

        The transformation is only generated when ``image_dict`` is
        non-empty, all images share the same leading ``spatial_rank``
        dimensions and augmentation is enabled; otherwise a message is
        printed and no transformation is kept.

        :param image_dict: dictionary of field name -> image array
        """
        # Always drop the previous transformation first, so that a failed
        # randomisation can never lead to a stale deformation (drawn for
        # an earlier sample) being applied in ``layer_op``.
        self._bspline_transformation = None

        images = list(image_dict.values())
        equal_shapes = bool(images) and all(
            images[0].shape[:self.spatial_rank] ==
            image.shape[:self.spatial_rank]
            for image in images)
        if equal_shapes and self.proportion_to_augment >= 0:
            self._randomise_bspline_transformation(images[0].shape)
        else:
            # currently not supported spatial rank for elastic deformation
            # should support classification in the future
            print("randomising elastic deformation FAILED")

    def _randomise_bspline_transformation(self, shape):
        """
        Build a B-spline transform over an image of the given shape and
        perturb its parameters with Gaussian noise.

        :param shape: shape of the image to be deformed
        """
        # generate transformation
        if len(shape) == 5:  # for niftynet reader outputs
            squeezed_shape = [dim for dim in shape[:3] if dim > 1]
        else:
            squeezed_shape = shape[:self.spatial_rank]
        # a zero-filled image is enough: only its geometry is used to
        # initialise the transform domain
        itkimg = sitk.GetImageFromArray(np.zeros(squeezed_shape))
        trans_from_domain_mesh_size = \
            [self.num_controlpoints] * itkimg.GetDimension()
        self._bspline_transformation = sitk.BSplineTransformInitializer(
            itkimg, trans_from_domain_mesh_size)

        params = self._bspline_transformation.GetParameters()
        params_numpy = np.asarray(params, dtype=float)
        params_numpy = params_numpy + np.random.randn(
            params_numpy.shape[0]) * self.std_deformation_sigma

        # remove z deformations! The resolution in z is too bad
        # params_numpy[0:int(len(params) / 3)] = 0

        params = tuple(params_numpy)
        self._bspline_transformation.SetParameters(params)

    def _apply_bspline_transformation(self, image, interp_order=3):
        """
        Apply randomised transformation to 2D or 3D image

        :param image: 2D or 3D array
        :param interp_order: order of interpolation; above 1 selects
            B-spline, 1 linear and 0 nearest-neighbour interpolation.
            Any other value returns ``image`` unchanged.
        :return: the transformed image, with the same shape as ``image``
        """
        resampler = sitk.ResampleImageFilter()
        if interp_order > 1:
            resampler.SetInterpolator(sitk.sitkBSpline)
        elif interp_order == 1:
            resampler.SetInterpolator(sitk.sitkLinear)
        elif interp_order == 0:
            resampler.SetInterpolator(sitk.sitkNearestNeighbor)
        else:
            return image

        squeezed_image = np.squeeze(image)
        while squeezed_image.ndim < self.spatial_rank:
            # pad to the required number of dimensions
            squeezed_image = squeezed_image[..., None]
        sitk_image = sitk.GetImageFromArray(squeezed_image)

        resampler.SetReferenceImage(sitk_image)
        resampler.SetDefaultPixelValue(0)
        resampler.SetTransform(self._bspline_transformation)
        out_img_sitk = resampler.Execute(sitk_image)
        out_img = sitk.GetArrayFromImage(out_img_sitk)
        # restore the singleton dimensions removed by the squeeze above
        return out_img.reshape(image.shape)

    def layer_op(self, inputs, interp_orders, *args, **kwargs):
        """
        Deform every modality of every input field with the current
        transformation, with probability ``proportion_to_augment``.

        The arrays in ``inputs`` are modified in place.

        :param inputs: dictionary of field name -> image array whose last
            axis indexes modalities (3-D/4-D arrays, or 5-D arrays whose
            second-to-last axis is looped over as well), or None
        :param interp_orders: dictionary of field name -> sequence of
            interpolation orders, one per modality
        :return: ``inputs`` (deformed in place when augmentation happens)
        :raises NotImplementedError: if ``inputs``/``interp_orders`` are
            not dictionaries or an image has an unsupported rank
        """
        if inputs is None:
            return inputs

        # only do augmentation with a probability `proportion_to_augment`
        do_augmentation = np.random.rand() < self.proportion_to_augment
        if not do_augmentation:
            return inputs

        if self._bspline_transformation is None:
            # randomise() has not produced a transformation for this
            # sample (never called, or it failed): pass the data through
            # instead of handing an invalid transform to SimpleITK.
            return inputs

        if isinstance(inputs, dict) and isinstance(interp_orders, dict):
            for (field, image) in inputs.items():
                assert image.shape[-1] == len(interp_orders[field]), \
                    "interpolation orders should be " \
                    "specified for each inputs modality"
                for mod_i, interp_order in enumerate(interp_orders[field]):
                    if image.ndim in (3, 4):  # for 2/3d images
                        inputs[field][..., mod_i] = \
                            self._apply_bspline_transformation(
                                image[..., mod_i], interp_order)
                    elif image.ndim == 5:
                        for t in range(image.shape[-2]):
                            inputs[field][..., t, mod_i] = \
                                self._apply_bspline_transformation(
                                    image[..., t, mod_i], interp_order)
                    else:
                        raise NotImplementedError("unknown input format")
        else:
            raise NotImplementedError("unknown input format")
        return inputs