Source code for niftynet.layer.gan_blocks

# -*- coding: utf-8 -*-
from __future__ import absolute_import, print_function, division

import tensorflow as tf

from niftynet.layer.base_layer import TrainableLayer


class GANImageBlock(TrainableLayer):
    """
    Couple a generator with a discriminator.

    The block produces a fake image with the generator and runs the
    discriminator on both the fake image and the (optionally clipped)
    training image.
    """

    def __init__(self,
                 generator,
                 discriminator,
                 clip=None,
                 name='GAN_image_block'):
        """
        :param generator: callable invoked as
            ``generator(random_source, shape, conditioning, is_training)``
        :param discriminator: callable invoked as
            ``discriminator(image, conditioning, is_training)``
        :param clip: when truthy, training images are limited to the
            range ``[-clip, clip]`` before being shown to the discriminator
        :param name: layer name forwarded to the parent constructor
        """
        self._generator = generator
        self._discriminator = discriminator
        self.clip = clip
        super(GANImageBlock, self).__init__(name=name)

    def layer_op(self,
                 random_source,
                 training_image,
                 conditioning,
                 is_training):
        """
        Generate a fake image and score the fake and the real image.

        :param random_source: forwarded unchanged to the generator
        :param training_image: real image tensor; every dimension except
            the first defines the shape requested from the generator
            (presumably the first dimension is the batch -- confirm
            against the caller)
        :param conditioning: forwarded to both generator and discriminator
        :param is_training: forwarded to both generator and discriminator
        :return: tuple ``(fake_image, real_logits, fake_logits)``
        """
        # Drop the leading dimension; the generator receives the rest.
        target_shape = training_image.shape.as_list()[1:]

        # The fake branch is built first, then the real branch.
        fake_image = self._generator(
            random_source, target_shape, conditioning, is_training)
        fake_logits = self._discriminator(
            fake_image, conditioning, is_training)

        real_image = training_image
        # Truthiness test: a clip of None (or 0) leaves the image untouched.
        if self.clip:
            with tf.name_scope('clip_real_images'):
                lower_bound = -self.clip
                capped = tf.minimum(self.clip, training_image)
                real_image = tf.maximum(lower_bound, capped)

        real_logits = self._discriminator(
            real_image, conditioning, is_training)
        return fake_image, real_logits, fake_logits
class BaseGenerator(TrainableLayer):
    """
    Base class for generator layers; it only sets the layer name.
    """

    def __init__(self, name='generator', *args, **kwargs):
        """
        :param name: layer name forwarded to the parent constructor
        :param args: accepted but not used by this constructor
        :param kwargs: accepted but not used by this constructor
        """
        # Only ``name`` reaches the parent; extra arguments are dropped here.
        super(BaseGenerator, self).__init__(name=name)
class BaseDiscriminator(TrainableLayer):
    """
    Base class for discriminator layers; it only sets the layer name.
    """

    def __init__(self, name='discriminator', *args, **kwargs):
        """
        :param name: layer name forwarded to the parent constructor
        :param args: accepted but not used by this constructor
        :param kwargs: accepted but not used by this constructor
        """
        # Only ``name`` reaches the parent; extra arguments are dropped here.
        super(BaseDiscriminator, self).__init__(name=name)