Fine-tuning networks

With NiftyNet, you can initialize a neural net with pre-trained variables and then fine-tune it for a separate but similar task. This functionality is provided through two config file parameters: vars_to_restore and vars_to_freeze.

Setting up your model directory

To fine-tune your model on a new dataset, first create a new model directory and specify its location through the model_dir parameter in your config file. Inside model_dir, create a folder named models and place the three checkpoint files that constitute your model inside it. Your directory structure should look like this:

model_dir/
  models/
    model.ckpt-###.data-00000-of-00001
    model.ckpt-###.index
    model.ckpt-###.meta

You can specify which checkpoint supplies the pre-trained variables by setting the starting_iter parameter in your config file. Set starting_iter to either the ### in your checkpoint filenames or to -1, which uses the latest checkpoint in the models folder. You can freely rename the ### in the filenames to any number you like; however, a starting_iter of 0 causes NiftyNet to ignore existing models and create a new model with randomized variables. Just rename ### to 1 if you want a fresh iteration counter.
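
For example, the relevant parameters might be set as follows (a minimal sketch; placing model_dir under [SYSTEM] and starting_iter under [TRAINING] follows the usual NiftyNet config layout, and the path is a placeholder):

[SYSTEM]
# directory created above, containing the models/ folder
model_dir = /path/to/model_dir

[TRAINING]
# restore from model.ckpt-1; use -1 to pick the latest checkpoint instead
starting_iter = 1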

Selecting variables to restore

Next, decide which variables you would like to restore. The vars_to_restore parameter lets you specify a regular expression that matches variable names in a checkpoint file. Only the variable names matched by the regex will be restored by NiftyNet; the rest are initialized to random values.

You can obtain a list of all the variables in your model using the following bit of code:

import os
import tensorflow as tf

# ckpt_path: full path to checkpoint file (e.g.: /path/to/ckpt/model.ckpt-###)
# output_file: name of output file (e.g.: /path/to/file/net_vars.txt)
def get_ckpt_vars(ckpt_path, output_file):
    # expanduser is needed because open() does not expand '~' paths;
    # each entry written is a (variable_name, shape) tuple
    with open(os.path.expanduser(output_file), 'w') as out:
        for var in tf.train.list_variables(os.path.expanduser(ckpt_path)):
            out.write(str(var) + '\n')

get_ckpt_vars('~/Desktop/model_outputs/models/model.ckpt-1',
              '~/Desktop/net_vars.txt')

Once you’ve determined which variables you plan to restore, you must write a regex that matches them. If you have little experience with regex, here are a few examples to get you started:

For matching variables:

vars_to_restore=^.*(conv_1|conv_2).*$ will match all trainable variables that have conv_1 or conv_2 in their name (the | acts like a boolean OR).

For excluding variables:

vars_to_restore=^((?!DenseVNet\/(skip_conv|fin_conv)).)*$ will not match any trainable variables that contain DenseVNet/skip_conv or DenseVNet/fin_conv in their name (the \ is an escape character for /).

Once you’ve written your regex, it’s recommended that you use a tool like RegEx101 to double-check that it works as expected. Set the test string to the list of variable names produced by get_ckpt_vars(). If the regex successfully selects the lines you intended, you can use it to set vars_to_restore.
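
If you’d rather check the regex programmatically, the same test can be done in Python (a minimal sketch, assuming a net_vars.txt produced by get_ckpt_vars() above; the helper name check_regex is made up for illustration):

import re

# pattern: the regex you plan to use for vars_to_restore
# vars_file: the variable list written by get_ckpt_vars()
def check_regex(pattern, vars_file):
    regex = re.compile(pattern)
    with open(vars_file) as f:
        for line in f:
            # each line looks like ('variable/name', [shape]); extract the name
            name = line.split("'")[1]
            if regex.match(name):
                print(name)

# print every variable name matched by the example regex from above
check_regex(r'^.*(conv_1|conv_2).*$', 'net_vars.txt')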

Freezing model weights

The model variables matched by vars_to_restore are restored from the checkpoint and, by default, remain fixed during training. To change this parameter-updating behaviour, you can also specify a regular expression through vars_to_freeze. The matched variables within the collection of trainable parameters (tf.trainable_variables()) will not be updated during training. This is useful for training a subset of a network's layers while leaving the rest of the network fixed.
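
Putting both parameters together, the training section of a config might look like this (a minimal sketch reusing the example regex from above):

[TRAINING]
# restore the conv_1 and conv_2 variables from the checkpoint
vars_to_restore = ^.*(conv_1|conv_2).*$
# keep those same variables fixed during training; all other layers train
vars_to_freeze = ^.*(conv_1|conv_2).*$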

Common Pitfalls

Transfer learning in NiftyNet works between any two models that share the variables being restored. In other words, you can completely change the network layers you plan to randomise, but if you try to restore variables whose name or shape don't match exactly between the two models, TensorFlow and NiftyNet will throw an error. For example, you may encounter the following error in your training log:

CRITICAL:niftynet:2018-10-24 00:53:22,438: checkpoint ~/outputs/models/
model.ckpt-### not found or variables to restore do not match the current
application graph

This means that certain variables in your network are not present in your model checkpoint. This often occurs unexpectedly when you restore variables that were frozen during the previous round of training: since those variables weren't being trained, the optimizer-specific variables used by methods like Adam were never created and therefore never saved in the checkpoint. You can overcome this by simply restoring all variables except those used by Adam: vars_to_restore = ^((?!(Adam)).)*$. In general, reading the error thrown by TensorFlow should tell you which variables are causing the problem.
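
To check whether your checkpoint contains the optimizer variables at all, you can filter the variable list from earlier (a minimal sketch; the helper name list_adam_vars and the path are made up for illustration):

import tensorflow as tf

# print any Adam slot variables stored in the checkpoint
def list_adam_vars(ckpt_path):
    for name, shape in tf.train.list_variables(ckpt_path):
        if 'Adam' in name:
            print(name, shape)

list_adam_vars('model_dir/models/model.ckpt-1')

If this prints nothing, the checkpoint contains no Adam variables, and the exclusion regex above is the way to go.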