Source code for niftynet.contrib.ultrasound_simulator_gan.ultrasound_simulator_gan

# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function

import tensorflow as tf
import numpy as np

from niftynet.layer.convolution import ConvolutionalLayer
from niftynet.layer.deconvolution import DeconvolutionalLayer
from niftynet.layer.gan_blocks import GANImageBlock, BaseGenerator, BaseDiscriminator


[docs]class SimulatorGAN(GANImageBlock): def __init__(self, load_from_checkpoint=True, name='simulator_GAN'): generator = ImageGenerator(load_from_checkpoint=load_from_checkpoint, name='generator') discriminator = ImageDiscriminator(load_from_checkpoint=load_from_checkpoint, name='discriminator') super(SimulatorGAN, self).__init__(generator, discriminator, clip=None, name=name)
[docs]class ImageGenerator(BaseGenerator): def __init__(self, load_from_checkpoint=True, name='generator'): self.load_from_checkpoint = load_from_checkpoint super(ImageGenerator, self).__init__(name=name) self.initializers = {'w': tf.contrib.layers.variance_scaling_initializer(), 'b': tf.constant_initializer(0)} self.noise_channels_per_layer = 0 self.generator_shortcuts = [True, True, True, True, False]
[docs] def layer_op(self, random_source, image_size, conditioning, is_training): spatial_rank = len(image_size) - 1 add_noise = self.noise_channels_per_layer conditioning_channels = conditioning.shape.as_list()[ -1] + add_noise if not conditioning is None else add_noise batch_size = random_source.shape.as_list()[0] noise_size = random_source.shape.as_list()[1] w_init = tf.random_normal_initializer(0, 0.02) b_init = tf.constant_initializer(0.001) ch = [512] sz = [[160, 120]] keep_prob_ph = 1 # not passed in as a placeholder for i in range(4): ch.append(round((ch[-1] + conditioning_channels * self.generator_shortcuts[i]) / 2)) sz = [[round(i / 2) for i in sz[0]]] + sz if spatial_rank == 3: def resize_func(x, sz): sz_x = x.shape.as_list() r1 = tf.image.resize_images(tf.reshape(x, sz_x[:3] + [-1]), sz[0:2]) r2 = tf.image.resize_images(tf.reshape(r1, [sz_x[0], sz[0] * sz[1], sz_x[3], -1]), [sz[0] * sz[1], sz[2]]) return tf.reshape(r2, [sz_x[0]] + sz + [sz_x[-1]]) elif spatial_rank == 2: resize_func = tf.image.resize_bilinear def concat_cond(x, i): if add_noise: noise = [tf.random_normal(x.shape.as_list()[0:-1] + [add_noise], 0, .1)] else: noise = [] if not conditioning is None and self.generator_shortcuts[i]: with tf.name_scope('concat_conditioning'): return tf.concat([x, resize_func(conditioning, x.shape.as_list()[1:-1])] + noise, axis=3) else: return x def conv(ch, x): with tf.name_scope('conv'): conv_layer = ConvolutionalLayer(ch, 3, feature_normalization=None, w_initializer=w_init) c = conv_layer(x, is_training=is_training) return tf.nn.relu(tf.contrib.layers.batch_norm(c)) def up(ch, x, hack=False): with tf.name_scope('up'): deconv = DeconvolutionalLayer(ch, 3, feature_normalization=None, stride=2, w_initializer=w_init)(x, is_training=is_training) if hack: deconv = deconv[:, :, 1:, :] # hack to match Yipeng's image size return tf.nn.relu(tf.contrib.layers.batch_norm(deconv)) def up_block(ch, x, i, hack=False): with tf.name_scope('up_block'): cond = concat_cond(up(ch, x, hack), i) return conv(cond.shape.as_list()[-1], cond) with tf.name_scope('noise_to_image'): g_no_0 = np.prod(sz[0]) * ch[0] w1p = tf.get_variable("G_W1p", shape=[noise_size, g_no_0], initializer=w_init) b1p = tf.get_variable('G_b1p', shape=[g_no_0], initializer=b_init) g_h1p = tf.nn.dropout(tf.nn.relu(tf.matmul(random_source, w1p) + b1p), keep_prob_ph) g_h1p = tf.reshape(g_h1p, [batch_size] + sz[0] + [ch[0]]) g_h1p = concat_cond(g_h1p, 0) g_h1 = conv(ch[0] + conditioning_channels, g_h1p) g_h2 = up_block(ch[1], g_h1, 1, hack=True) g_h3 = up_block(ch[2], g_h2, 2) g_h4 = up_block(ch[3], g_h3, 3) g_h5 = up_block(ch[4], g_h4, 4) with tf.name_scope('final_image'): if add_noise: noise = tf.random_normal(g_h5.shape.as_list()[0:-1] + [add_noise], 0, .1) g_h5 = tf.concat([g_h5, noise],axis=3) x_sample = ConvolutionalLayer(1, 3, feature_normalization=None, with_bias=True, w_initializer=w_init, b_initializer=b_init)(g_h5, is_training=is_training) x_sample = tf.nn.dropout(tf.nn.tanh(x_sample), keep_prob_ph) if self.load_from_checkpoint: checkpoint_name = '/home/egibson/deeplearning/TensorFlow/NiftyNet/NiftyNet/' + \ 'contrib/ultrasound_simulator_gan/ultrasound_simulator_gan' restores = [ ['G_Wo', 'generator/conv_5/conv_/w'], ['G_W5t', 'generator/deconv_3/deconv_/w'], ['G_W5', 'generator/conv_4/conv_/w'], ['G_W4t', 'generator/deconv_2/deconv_/w'], ['G_W4', 'generator/conv_3/conv_/w'], ['G_W3t', 'generator/deconv_1/deconv_/w'], ['G_W3', 'generator/conv_2/conv_/w'], ['G_W2t', 'generator/deconv/deconv_/w'], ['G_W2', 'generator/conv_1/conv_/w'], ['G_W1p', 'generator/G_W1p'], ['G_W1', 'generator/conv/conv_/w'], ['G_b1p', 'generator/G_b1p']] [tf.add_to_collection('NiftyNetObjectsToRestore', (r[1], checkpoint_name, r[0])) for r in restores] return x_sample
[docs]class ImageDiscriminator(BaseDiscriminator): def __init__(self, load_from_checkpoint=True, name='discriminator'): self.load_from_checkpoint=load_from_checkpoint super(ImageDiscriminator, self).__init__(name=name) self.initializers = {'w': tf.contrib.layers.variance_scaling_initializer(), 'b': tf.constant_initializer(0)}
[docs] def layer_op(self, image, conditioning, is_training): conditioning = tf.image.resize_images(conditioning, [160, 120]) batch_size = image.shape.as_list()[0] w_init = tf.random_normal_initializer(0, 0.02) b_init = tf.constant_initializer(0.001) def leaky_relu(x, alpha=0.2): with tf.name_scope('leaky_relu'): return 0.5 * (1 + alpha) * x + 0.5 * (1 - alpha) * abs(x) ch = [32, 64, 128, 256, 512, 1024] def down(ch, x): with tf.name_scope('downsample'): c = ConvolutionalLayer(ch, 3, stride=2, feature_normalization=None, w_initializer=w_init)(x, is_training=is_training) c=tf.contrib.layers.batch_norm(c) c = leaky_relu(c) return c def convr(ch, x): c= ConvolutionalLayer(ch, 3, feature_normalization=None, w_initializer=w_init)(x, is_training=is_training) return leaky_relu(tf.contrib.layers.batch_norm(c)) def conv(ch, x, s): c = (ConvolutionalLayer(ch, 3, feature_normalization=None, w_initializer=w_init)(x, is_training=is_training)) return leaky_relu(tf.contrib.layers.batch_norm(c) + s) def down_block(ch, x): with tf.name_scope('down_resnet'): s = down(ch, x) r = convr(ch, s) return conv(ch, r, s) if not conditioning is None: image = tf.concat([image, conditioning], axis=-1) with tf.name_scope('feature'): d_h1s = ConvolutionalLayer(ch[0], 5, with_bias=True, feature_normalization=None, w_initializer=w_init, b_initializer=b_init)(image, is_training=is_training) d_h1s = leaky_relu(d_h1s) d_h1r = convr(ch[0], d_h1s) d_h1 = conv(ch[0], d_h1r, d_h1s) d_h2 = down_block(ch[1], d_h1) d_h3 = down_block(ch[2], d_h2) d_h4 = down_block(ch[3], d_h3) d_h5 = down_block(ch[4], d_h4) d_h6 = down_block(ch[5], d_h5) with tf.name_scope('fc'): d_hf = tf.reshape(d_h6, [batch_size, -1]) d_nf_o = np.prod(d_hf.shape.as_list()[1:]) D_Wo = tf.get_variable("D_Wo", shape=[d_nf_o, 1], initializer=w_init) D_bo = tf.get_variable('D_bo', shape=[1], initializer=b_init) d_logit = tf.matmul(d_hf, D_Wo) + D_bo if self.load_from_checkpoint: checkpoint_name = '/home/egibson/deeplearning/TensorFlow/NiftyNet/NiftyNet/' + \ 'contrib/ultrasound_simulator_gan/ultrasound_simulator_gan' restores = [ ['D_b1s', 'discriminator/conv/conv_/b'], ['D_W1s', 'discriminator/conv/conv_/w'], ['D_W1r1', 'discriminator/conv_1/conv_/w'], ['D_W1r2', 'discriminator/conv_2/conv_/w'], ['D_W2s', 'discriminator/conv_3/conv_/w'], ['D_W2r1', 'discriminator/conv_4/conv_/w'], ['D_W2r2', 'discriminator/conv_5/conv_/w'], ['D_W3s', 'discriminator/conv_6/conv_/w'], ['D_W3r1', 'discriminator/conv_7/conv_/w'], ['D_W3r2', 'discriminator/conv_8/conv_/w'], ['D_W4s', 'discriminator/conv_9/conv_/w'], ['D_W4r1', 'discriminator/conv_10/conv_/w'], ['D_W4r2', 'discriminator/conv_11/conv_/w'], ['D_W5s', 'discriminator/conv_12/conv_/w'], ['D_W5r1', 'discriminator/conv_13/conv_/w'], ['D_W5r2', 'discriminator/conv_14/conv_/w'], ['D_W6s', 'discriminator/conv_15/conv_/w'], ['D_W6r1', 'discriminator/conv_16/conv_/w'], ['D_W6r2', 'discriminator/conv_17/conv_/w'], ['D_Wo', 'discriminator/D_Wo'], ['D_bo', 'discriminator/D_bo']] [tf.add_to_collection('NiftyNetObjectsToRestore', (r[1], checkpoint_name, r[0])) for r in restores] return d_logit