Source code for niftynet.layer.dilatedcontext
# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function
import tensorflow as tf
from niftynet.layer import layer_util
class DilatedTensor(object):
    """
    Context manager wrapping ``input_tensor`` for dilated processing.

    On entering the context the wrapped tensor is rearranged with
    ``tf.space_to_batch_nd``, using ``dilation_factor`` as the block size
    along every spatial dimension and zero padding.  On leaving the
    context ``tf.batch_to_space_nd`` is applied with the same parameters,
    so the wrapped tensor resumes to the original space.
    When ``dilation_factor`` is not greater than 1 the tensor is left
    untouched on both enter and exit.

    Usage sketch::

        with DilatedTensor(features, dilation_factor=2) as dilated:
            dilated.tensor = some_layer(dilated.tensor)
        features = dilated.tensor
    """

    def __init__(self, input_tensor, dilation_factor):
        """
        :param input_tensor: the tensor to wrap; it is validated with
            ``layer_util.check_spatial_dims`` using the criterion
            ``x % dilation_factor == 0``
        :param dilation_factor: block size used for every spatial
            dimension when entering/leaving the context
        :raises AssertionError: if ``check_spatial_dims`` returns a
            falsy value
        """
        # Call the check outside of an ``assert`` statement: asserts are
        # stripped under ``python -O``, which would silently skip the
        # validation altogether.
        dims_ok = layer_util.check_spatial_dims(
            input_tensor, lambda x: x % dilation_factor == 0)
        if not dims_ok:
            # AssertionError is kept so that the exception type seen by
            # existing callers does not change.
            raise AssertionError(
                'spatial dims of the input tensor are not compatible '
                'with dilation_factor {}'.format(dilation_factor))

        self._tensor = input_tensor
        self.dilation_factor = dilation_factor

        # parameters to transform input tensor
        self.spatial_rank = layer_util.infer_spatial_rank(self._tensor)
        # NOTE: the inner [0, 0] pairs are the same list object repeated;
        # treat ``zero_paddings`` as read-only.
        self.zero_paddings = [[0, 0]] * self.spatial_rank
        self.block_shape = [dilation_factor] * self.spatial_rank

    def __enter__(self):
        """
        Replace the wrapped tensor by its ``space_to_batch_nd`` version
        (only when ``dilation_factor > 1``) and return this wrapper.
        """
        if self.dilation_factor > 1:
            self._tensor = tf.space_to_batch_nd(self._tensor,
                                                self.block_shape,
                                                self.zero_paddings,
                                                name='dilated')
        return self

    def __exit__(self, *args):
        """
        Apply ``batch_to_space_nd`` to the currently wrapped tensor
        (only when ``dilation_factor > 1``).

        Nothing is returned, so an exception raised inside the ``with``
        body is not suppressed.
        """
        if self.dilation_factor > 1:
            self._tensor = tf.batch_to_space_nd(self._tensor,
                                                self.block_shape,
                                                self.zero_paddings,
                                                name='de-dilate')

    @property
    def tensor(self):
        """The currently wrapped tensor."""
        return self._tensor

    @tensor.setter
    def tensor(self, value):
        # Lets code inside the ``with`` body swap in a processed tensor,
        # which is then the one transformed back in ``__exit__``.
        self._tensor = value