Source code for niftynet.layer.crf

# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function

import numpy
import tensorflow as tf

from niftynet.layer.base_layer import TrainableLayer

"""
Re-implementation of [1] for volumetric image processing.


[1] Zheng, Shuai, et al. "Conditional random names as recurrent neural networks."
CVPR 2015.
"""

def permutohedral_prepare(position_vectors):
    batch_size = int(position_vectors.get_shape()[0])
    nCh = int(position_vectors.get_shape()[-1])
    nVoxels = int(position_vectors.get_shape().num_elements()) // batch_size // nCh
    # reshaping batches and voxels into one dimension means we can use 1D
    # gather and hashing easily
    position_vectors = tf.reshape(position_vectors, [-1, nCh])

    ## Generate position vectors in lattice space
    x = position_vectors / (numpy.sqrt(2. / 3.) * (nCh + 1))

    # Embed in lattice space using black magic from the permutohedral paper
    alpha = lambda i: numpy.sqrt(float(i) / (float(i) + 1.))
    Ex = [None] * (nCh + 1)
    Ex[nCh] = -alpha(nCh) * x[:, nCh - 1]
    for dit in range(nCh - 1, 0, -1):
        Ex[dit] = -alpha(dit) * x[:, dit - 1] + x[:, dit] / alpha(dit + 1) + Ex[dit + 1]
    Ex[0] = x[:, 0] / alpha(1) + Ex[1]
    Ex = tf.stack(Ex, 1)

    ## Compute coordinates
    # Get closest remainder-0 point
    v = tf.to_int32(tf.round(Ex * (1. / float(nCh + 1))))
    rem0 = v * (nCh + 1)
    sumV = tf.reduce_sum(v, 1, keep_dims=True)

    # Find the simplex we are in and store it in rank (where rank describes
    # what position coordinate i has in the sorted order of the feature values)
    di = Ex - tf.to_float(rem0)
    _, index = tf.nn.top_k(di, nCh + 1, sorted=True)
    _, rank = tf.nn.top_k(-index, nCh + 1, sorted=True)
    # This can be done more efficiently if necessary following the permutohedral paper

    # if the point doesn't lie on the plane (sum != 0) bring it back
    rank = tf.to_int32(rank) + sumV
    addMinusSub = tf.to_int32(rank < 0) * (nCh + 1) - tf.to_int32(rank >= nCh + 1) * (nCh + 1)
    rank = rank + addMinusSub
    rem0 = rem0 + addMinusSub

    # Compute the barycentric coordinates (p.10 in [Adams et al. 2010])
    v2 = (Ex - tf.to_float(rem0)) * (1. / float(nCh + 1))
    # the barycentric coordinates are v_sorted - v_sorted[..., [-1, 1:-1]] + [1, 0, 0, ...]
    # CRF2RNN uses the calculated ranks to get v2 sorted in O(nCh) time.
    # We cheat here by using the easy-to-implement but slower method of
    # sorting again in O(nCh log nCh); we might get this even more efficient
    # if we correct the original sorted data above.
    v_sortedDesc, _ = tf.nn.top_k(v2, nCh + 1, sorted=True)
    v_sorted = tf.reverse(v_sortedDesc, [1])
    barycentric = v_sorted - tf.concat([v_sorted[:, -1:] - 1., v_sorted[:, 0:nCh]], 1)

    # Compute all vertices and their offset
    canonical = [[i] * (nCh + 1 - i) + [i - nCh - 1] * i for i in range(nCh + 1)]
    # WARNING: This hash function does not guarantee uniqueness of different position_vectors
    hashVector = tf.constant(
        numpy.power(int(numpy.floor(numpy.power(tf.int64.max, 1. / (nCh + 2)))),
                    [range(1, nCh + 1)]), dtype=tf.int64)
    hash = lambda key: tf.reduce_sum(tf.to_int64(key) * hashVector, 1)
    hashtable = tf.contrib.lookup.MutableDenseHashTable(
        tf.int64, tf.int64,
        default_value=tf.constant([-1] * nCh, dtype=tf.int64),
        empty_key=-1, initial_num_buckets=8, checkpoint=False)
    indextable = tf.contrib.lookup.MutableDenseHashTable(
        tf.int64, tf.int64,
        default_value=0,
        empty_key=-1, initial_num_buckets=8, checkpoint=False)

    numSimplexCorners = nCh + 1
    keys = [None] * numSimplexCorners
    i64keys = [None] * numSimplexCorners
    insertOps = []
    for scit in range(numSimplexCorners):
        keys[scit] = tf.gather(canonical[scit], rank[:, :-1]) + rem0[:, :-1]
        i64keys[scit] = hash(keys[scit])
        insertOps.append(hashtable.insert(i64keys[scit], tf.to_int64(keys[scit])))
    with tf.control_dependencies(insertOps):
        fusedI64Keys, fusedKeys = hashtable.export()
        fusedKeys = tf.boolean_mask(fusedKeys, tf.not_equal(fusedI64Keys, -1))
        fusedI64Keys = tf.boolean_mask(fusedI64Keys, tf.not_equal(fusedI64Keys, -1))

    insertIndices = indextable.insert(
        fusedI64Keys,
        tf.expand_dims(tf.transpose(
            tf.range(1, tf.to_int64(tf.size(fusedI64Keys) + 1), dtype=tf.int64)), 1))
    blurNeighbours1 = [None] * (nCh + 1)
    blurNeighbours2 = [None] * (nCh + 1)
    indices = [None] * (nCh + 1)
    with tf.control_dependencies([insertIndices]):
        for dit in range(nCh + 1):
            offset = tf.constant([nCh if i == dit else -1 for i in range(nCh)], dtype=tf.int64)
            blurNeighbours1[dit] = indextable.lookup(hash(fusedKeys + offset))
            blurNeighbours2[dit] = indextable.lookup(hash(fusedKeys - offset))
            batch_index = tf.reshape(
                tf.meshgrid(tf.range(batch_size), tf.zeros([nVoxels], dtype=tf.int32))[0], [-1, 1])
            # where in the splat variable each simplex vertex is
            indices[dit] = tf.stack(
                [tf.to_int32(indextable.lookup(i64keys[dit])), batch_index[:, 0]], 1)
    return barycentric, blurNeighbours1, blurNeighbours2, indices
def permutohedral_compute(data_vectors, barycentric, blurNeighbours1,
                          blurNeighbours2, indices, name, reverse):
    batch_size = tf.shape(data_vectors)[0]
    numSimplexCorners = int(barycentric.get_shape()[-1])
    nCh = numSimplexCorners - 1
    nChData = tf.shape(data_vectors)[-1]
    data_vectors = tf.reshape(data_vectors, [-1, nChData])
    # Convert to homogeneous coordinates
    data_vectors = tf.concat([data_vectors, tf.ones_like(data_vectors[:, 0:1])], 1)

    ## Splatting
    initialSplat = tf.zeros([tf.shape(blurNeighbours1[0])[0] + 1, batch_size, nChData + 1])
    with tf.variable_scope(name):
        # WARNING: we use local variables so the graph must initialize local
        # variables with tf.local_variables_initializer()
        splat = tf.contrib.framework.local_variable(
            tf.ones([0, 0]), validate_shape=False, name='splatbuffer')
    with tf.control_dependencies([splat.initialized_value()]):
        resetSplat = tf.assign(splat, initialSplat, validate_shape=False, name='assign')
    # This is needed to force tensorflow to update the cache
    with tf.control_dependencies([resetSplat]):
        uncachedSplat = splat.read_value()
    for scit in range(numSimplexCorners):
        data = data_vectors * barycentric[:, scit:scit + 1]
        with tf.control_dependencies([uncachedSplat]):
            splat = tf.scatter_nd_add(splat, indices[scit], data)

    ## Blur
    with tf.control_dependencies([splat]):
        blurred = [splat]
        order = range(nCh, -1, -1) if reverse else range(nCh + 1)
        for dit in order:
            with tf.control_dependencies([blurred[-1]]):
                b1 = 0.5 * tf.gather(blurred[-1], blurNeighbours1[dit])
                b2 = blurred[-1][1:, :, :]
                b3 = 0.5 * tf.gather(blurred[-1], blurNeighbours2[dit])
                blurred.append(tf.concat([blurred[-1][0:1, :, :], b2 + b1 + b3], 0))
    # Alpha is a magic scaling constant from the CRFAsRNN code
    alpha = 1. / (1. + numpy.power(2., -nCh))
    normalized = blurred[-1][:, :, :-1] / blurred[-1][:, :, -1:]

    ## Slice
    sliced = tf.gather_nd(normalized, indices[0]) * barycentric[:, 0:1] * alpha
    for scit in range(1, numSimplexCorners):
        sliced = sliced + tf.gather_nd(normalized, indices[scit]) * barycentric[:, scit:scit + 1] * alpha
    return sliced
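A minimal usage sketch of the two functions above (not part of the original module; the shapes, the demo_filter name, and the session setup are illustrative assumptions): permutohedral_prepare builds the lattice for a set of position vectors, and permutohedral_compute then splats, blurs, and slices arbitrary per-voxel data through that lattice, approximating Gaussian filtering in the position space.

    positions = tf.constant(numpy.random.rand(1, 1000, 3), dtype=tf.float32)  # [batch, nVoxels, nCh]
    features = tf.constant(numpy.random.rand(1, 1000, 2), dtype=tf.float32)   # [batch, nVoxels, nChData]
    barycentric, blur1, blur2, indices = permutohedral_prepare(positions)
    filtered = permutohedral_compute(features, barycentric, blur1, blur2,
                                     indices, name='demo_filter', reverse=False)
    with tf.Session() as sess:
        # the splat buffer is a local variable (see the WARNING above)
        sess.run(tf.local_variables_initializer())
        print(sess.run(filtered).shape)  # (1000, 2): one filtered vector per voxel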
# Differentiation can be done using the permutohedral lattice with the
# Gaussian filter order reversed.
# To get this to work with automatic differentiation we use a hack attributed
# to Sergey Ioffe, mentioned here:
# http://stackoverflow.com/questions/36456436/how-can-i-define-only-the-gradient-for-a-tensorflow-subgraph/36480182

# Define a custom py_func which takes also a grad op as argument, from
# https://gist.github.com/harpone/3453185b41d8d985356cbe5e57d67342
def py_func(func, inp, Tout, stateful=True, name=None, grad=None):
    # Need to generate a unique name to avoid duplicates:
    rnd_name = 'PyFuncGrad' + str(numpy.random.randint(0, 1E+8))
    tf.RegisterGradient(rnd_name)(grad)  # see _MySquareGrad for grad example
    g = tf.get_default_graph()
    with g.gradient_override_map({"PyFunc": rnd_name}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)
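For illustration, this wrapper can attach a hand-written gradient to any numpy-backed forward pass. A hypothetical example in the spirit of the _MySquareGrad example referenced in the comment (the names my_square and _my_square_grad are not part of the module):

    def _my_square_grad(op, grad):
        # d(x^2)/dx = 2x
        return grad * 2. * op.inputs[0]

    def my_square(x, name=None):
        # the forward pass runs in numpy; the backward pass uses _my_square_grad
        return py_func(lambda v: v * v, [x], [tf.float32],
                       name=name, grad=_my_square_grad)[0]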
def gradientStub(data_vectors, barycentric, blurNeighbours1, blurNeighbours2, indices, name):
    # This is a stub operator whose purpose is to allow us to overwrite the gradient:
    # the forward pass gives zeros and the backward pass gives the correct
    # gradients for the permutohedral_compute function
    return py_func(
        lambda data_vectors, barycentric, blurNeighbours1, blurNeighbours2, indices: data_vectors * 0,
        [data_vectors, barycentric, blurNeighbours1, blurNeighbours2, indices],
        [tf.float32],
        name=name,
        grad=lambda op, grad: [permutohedral_compute(grad,
                                                     op.inputs[1], op.inputs[2],
                                                     op.inputs[3], op.inputs[4],
                                                     name, reverse=True)] +
                              [tf.zeros_like(i) for i in op.inputs[1:]])
def permutohedral_gen(permutohedral, data_vectors, name):
    barycentric, blurNeighbours1, blurNeighbours2, indices = permutohedral
    return gradientStub(data_vectors, barycentric, blurNeighbours1,
                        blurNeighbours2, indices, name) + \
           tf.stop_gradient(tf.reshape(
               permutohedral_compute(data_vectors, barycentric, blurNeighbours1,
                                     blurNeighbours2, indices, name, reverse=False),
               data_vectors.get_shape()))
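Note how the sum above works: gradientStub contributes zeros to the forward value but carries the backward pass (permutohedral filtering with the blur order reversed), while the tf.stop_gradient term carries the forward value and no backward pass. Their sum therefore behaves like permutohedral_compute equipped with the custom gradient.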
def ftheta(U, H1, permutohedrals, mu, kernel_weights, aspect_ratio, name):
    nCh = U.get_shape().as_list()[-1]
    batch_size = int(U.get_shape()[0])
    # Message Passing
    data = tf.reshape(tf.nn.softmax(H1), [batch_size, -1, nCh])
    Q1 = [None] * len(permutohedrals)
    with tf.device('/cpu:0'):
        for idx, permutohedral in enumerate(permutohedrals):
            Q1[idx] = tf.reshape(
                permutohedral_gen(permutohedral, data, name + str(idx)), U.get_shape())
    # Weighting Filter Outputs
    Q2 = tf.add_n([q1 * w for q1, w in zip(Q1, kernel_weights)])
    # Compatibility Transform
    Q3 = tf.nn.conv3d(Q2, mu, strides=[1, 1, 1, 1, 1], padding='SAME')
    # Adding Unary Potentials
    Q4 = U - Q3
    # Normalizing
    return Q4  # output logits, not the softmax
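For reference, each call to ftheta performs one mean-field update in the sense of [1] and [2]: with Q = softmax(H1), it returns the new logits

    U - mu * (sum_m w_m * (k_m ⊗ Q))

where k_m ⊗ Q denotes approximate Gaussian filtering of Q by the m-th permutohedral kernel, w_m are the learned filter weights, and mu is the learned label-compatibility transform.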
class CRFAsRNNLayer(TrainableLayer):
    """
    This class defines a layer implementing CRFAsRNN described in [1] using
    a bilateral and a spatial kernel as in [2].
    Essentially, this layer smooths its input based on a distance in a feature
    space comprising spatial and feature dimensions.

    [1] Zheng, Shuai, et al. "Conditional random fields as recurrent neural
        networks." ICCV 2015.
    [2] https://arxiv.org/pdf/1210.5644.pdf
    """
    def __init__(self, alpha=5., beta=5., gamma=5., T=5,
                 aspect_ratio=(1., 1., 1.), name="crf_as_rnn"):
        """
        Parameters:
            alpha: bandwidth for spatial coordinates in the bilateral kernel.
                Higher values cause more spatial blurring
            beta: bandwidth for feature coordinates in the bilateral kernel.
                Higher values cause more feature blurring
            gamma: bandwidth for spatial coordinates in the spatial kernel.
                Higher values cause more spatial blurring
            T: number of stacked layers in the RNN
            aspect_ratio: spacing of adjacent voxels (allows isotropic spatial
                smoothing when voxels are not isotropic)
        """
        super(CRFAsRNNLayer, self).__init__(name=name)
        self._T = T
        self._aspect_ratio = aspect_ratio
        self._alpha = alpha
        self._beta = beta
        self._gamma = gamma
        self._name = name
    def layer_op(self, I, U):
        """
        Parameters:
            I: feature maps defining the non-spatial dimensions within which
                smoothing is performed. For example, to smooth U within regions
                of similar intensity, this would be the image intensity.
            U: activation maps to smooth
        """
        batch_size = int(U.get_shape()[0])
        H1 = [U]
        # Build permutohedral structures for smoothing
        coords = tf.tile(
            tf.expand_dims(
                tf.stack(tf.meshgrid(
                    *[numpy.array(range(int(i)), dtype=numpy.float32) * a
                      for i, a in zip(U.get_shape()[1:4], self._aspect_ratio)]), 3), 0),
            [batch_size, 1, 1, 1, 1])
        bilateralCoords = tf.reshape(
            tf.concat([coords / self._alpha, I / self._beta], 4),
            [batch_size, -1, int(I.get_shape()[-1]) + 3])
        spatialCoords = tf.reshape(coords / self._gamma, [batch_size, -1, 3])
        kernel_coords = [bilateralCoords, spatialCoords]
        permutohedrals = [permutohedral_prepare(coords) for coords in kernel_coords]

        nCh = U.get_shape()[-1]
        mu = tf.get_variable(
            'Compatibility',
            initializer=tf.constant(
                numpy.reshape(numpy.eye(nCh), [1, 1, 1, nCh, nCh]), dtype=tf.float32))
        kernel_weights = [
            tf.get_variable(
                "FilterWeights" + str(idx), shape=[1, 1, 1, 1, nCh],
                initializer=tf.zeros_initializer())
            for idx, k in enumerate(permutohedrals)]
        for t in range(self._T):
            H1.append(ftheta(U, H1[-1], permutohedrals, mu, kernel_weights,
                             aspect_ratio=self._aspect_ratio,
                             name=self._name + str(t)))
        return H1[-1]
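A minimal usage sketch of the layer (not part of the original module; the shapes are illustrative, and the layer is applied through the usual NiftyNet TrainableLayer call):

    image = tf.placeholder(tf.float32, [2, 32, 32, 32, 1])   # intensities I
    logits = tf.placeholder(tf.float32, [2, 32, 32, 32, 4])  # unary activations U
    crf = CRFAsRNNLayer(alpha=5., beta=5., gamma=5., T=5, aspect_ratio=(1., 1., 1.))
    smoothed = crf(image, logits)  # same shape as logits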