# Source code for niftynet.contrib.learning_rate_schedule.decay_lr_application
import tensorflow as tf
from niftynet.application.segmentation_application import \
SegmentationApplication
from niftynet.engine.application_factory import OptimiserFactory
from niftynet.engine.application_variables import CONSOLE
from niftynet.engine.application_variables import TF_SUMMARIES
from niftynet.layer.loss_segmentation import LossFunction
# Input section names declared as supported by this module.
SUPPORTED_INPUT = {'image', 'label', 'weight'}
class DecayLearningRateApplication(SegmentationApplication):
    """
    Segmentation application demonstrating a decaying learning rate.

    The learning rate is exposed to the graph as a scalar placeholder and
    its value is supplied through the feed dictionary at every iteration.
    During training the value is divided by ``LR_DECAY_DIVISOR`` every
    ``LR_DECAY_INTERVAL`` iterations (iteration 0 excluded).
    """
    REQUIRED_CONFIG_SECTION = "SEGMENTATION"

    # Decay schedule: on each training iteration ``i`` with ``i > 0`` and
    # ``i % LR_DECAY_INTERVAL == 0`` the current learning rate is divided
    # by ``LR_DECAY_DIVISOR``. Subclasses may override these values.
    LR_DECAY_INTERVAL = 3
    LR_DECAY_DIVISOR = 2.0

    def __init__(self, net_param, action_param, is_training):
        """
        Initialise the parent application and the learning-rate state.

        :param net_param: network parameters, forwarded to the parent
        :param action_param: action parameters, forwarded to the parent;
            ``action_param.lr`` is used as the initial learning rate
        :param is_training: forwarded to the parent
        :raises NotImplementedError: if ``validation_every_n`` is greater
            than zero, as validation is not supported by this demo
        """
        SegmentationApplication.__init__(
            self, net_param, action_param, is_training)
        tf.logging.info('starting decay learning segmentation application')
        # placeholder tensor; created in `connect_data_and_network`
        # when training, remains None otherwise
        self.learning_rate = None
        # python-side value fed into the placeholder at each iteration
        self.current_lr = action_param.lr
        if self.action_param.validation_every_n > 0:
            raise NotImplementedError("validation process is not implemented "
                                      "in this demo.")

    def connect_data_and_network(self,
                                 outputs_collector=None,
                                 gradients_collector=None):
        """
        Build the graph connecting the sampler output to the network.

        When training, this creates the learning-rate placeholder, the
        optimiser, the loss (plus the averaged regularisation losses when
        ``net_param.decay`` is positive and such losses exist) and
        registers gradients and console/summary outputs with the
        collectors. Otherwise the call is delegated to
        ``SegmentationApplication.connect_data_and_network``.

        :param outputs_collector: collector receiving the output variables
        :param gradients_collector: collector receiving the gradients
        """
        # NOTE(review): the batch op and the network are also built here
        # when not training, before delegating to the parent class; this
        # looks redundant but is kept to preserve the original graph
        # construction order -- confirm against the parent implementation.
        data_dict = self.get_sampler()[0][0].pop_batch_op()
        image = tf.cast(data_dict['image'], tf.float32)
        net_out = self.net(image, self.is_training)

        if not self.is_training:
            # converting logits into final output for
            # classification probabilities or argmax classification labels
            SegmentationApplication.connect_data_and_network(
                self, outputs_collector, gradients_collector)
            return

        with tf.name_scope('Optimiser'):
            # scalar placeholder so that the learning rate can be changed
            # between iterations without rebuilding the graph
            self.learning_rate = tf.placeholder(tf.float32, shape=[])
            optimiser_class = OptimiserFactory.create(
                name=self.action_param.optimiser)
            self.optimiser = optimiser_class.get_instance(
                learning_rate=self.learning_rate)
        loss_func = LossFunction(
            n_class=self.segmentation_param.num_classes,
            loss_type=self.action_param.loss_type)
        data_loss = loss_func(
            prediction=net_out,
            ground_truth=data_dict.get('label', None),
            weight_map=data_dict.get('weight', None))

        loss = data_loss
        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        if self.net_param.decay > 0.0 and reg_losses:
            # mean over the per-variable mean regularisation losses
            reg_loss = tf.reduce_mean(
                [tf.reduce_mean(single_loss) for single_loss in reg_losses])
            loss = data_loss + reg_loss
        grads = self.optimiser.compute_gradients(loss)

        # collecting gradients variables
        gradients_collector.add_to_collection([grads])
        # collecting output variables
        # (the output is named 'dice_loss' whatever `loss_type` is set)
        outputs_collector.add_to_collection(
            var=data_loss, name='dice_loss',
            average_over_devices=False, collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=self.learning_rate, name='lr',
            average_over_devices=False, collection=CONSOLE)
        outputs_collector.add_to_collection(
            var=data_loss, name='dice_loss',
            average_over_devices=True, summary_type='scalar',
            collection=TF_SUMMARIES)

    def set_iteration_update(self, iteration_message):
        """
        This function will be called by the application engine at each
        iteration.

        It updates the current learning rate according to the decay
        schedule on training iterations, sets the ``is_validation`` feed
        value, and feeds the current learning rate into the placeholder.

        :param iteration_message: message object of the current iteration;
            its ``current_iter``, ``is_training``, ``is_validation`` and
            ``data_feed_dict`` attributes are used
        """
        current_iter = iteration_message.current_iter
        feed_dict = iteration_message.data_feed_dict
        if iteration_message.is_training:
            if current_iter > 0 and \
                    current_iter % self.LR_DECAY_INTERVAL == 0:
                self.current_lr = self.current_lr / self.LR_DECAY_DIVISOR
            feed_dict[self.is_validation] = False
        elif iteration_message.is_validation:
            feed_dict[self.is_validation] = True
        # The placeholder only exists once the training graph has been
        # built; without this check a `None` key would be inserted into
        # the feed dictionary.
        if self.learning_rate is not None:
            feed_dict[self.learning_rate] = self.current_lr