Source code for niftynet.contrib.checkpoint_tools.rename_checkpoint_to_partial

# -*- coding: utf-8 -*-
import sys,glob,csv

import tensorflow as tf

def rename_checkpoint_to_partial(source, target, transform):
    """Write a new checkpoint holding renamed variables from ``source``.

    Each rule in ``transform`` is a ``(source_name, target_name)`` pair of
    strings and is interpreted in one of two ways:

    * both names end with ``'/'``: a scope rename. Every variable in the
      source checkpoint whose path starts with the source scope is
      selected, and that leading scope is replaced by the target scope.
    * neither name ends with ``'/'``: a single-variable rename. The rule is
      applied only if ``source_name`` is a variable in the checkpoint;
      otherwise it is silently skipped.

    The selected variables are loaded from ``source``, recreated under their
    new names in a fresh graph, and saved to ``target``.

    :param source: path of the checkpoint to read variables from
    :param target: path the new checkpoint is saved to
    :param transform: iterable of ``(source_name, target_name)`` string pairs
    :raises ValueError: if a rule has an empty name, or mixes a scope
        (trailing ``'/'``) with a variable (no trailing ``'/'``)
    """
    checkpoint_vars = tf.contrib.framework.list_variables(source)
    full_names = [name for name, _ in checkpoint_vars]
    # Built once so the per-rule membership test does not rescan the list.
    known_names = set(full_names)
    split_names = [name.split('/') for name in full_names]

    transform_pairs = []
    for s_name, t_name in transform:
        print(checkpoint_vars, s_name, t_name)
        if not s_name or not t_name:
            # Indexing an empty name would otherwise fail with an unhelpful
            # IndexError.
            raise ValueError(
                'Source and target names must be non-empty: %r->%r'
                % (s_name, t_name))
        source_is_scope = s_name.endswith('/')
        target_is_scope = t_name.endswith('/')
        if source_is_scope != target_is_scope:
            raise ValueError(
                'Cannot rename a variable to a scope or vice versa: %s->%s'
                % (s_name, t_name))
        if source_is_scope:
            # 'a/b/'.split('/') ends with an empty component; drop it.
            s_parts = s_name.split('/')[:-1]
            t_parts = t_name.split('/')[:-1]
            depth = len(s_parts)
            transform_pairs += [
                ('/'.join(parts), '/'.join(t_parts + parts[depth:]))
                for parts in split_names
                if parts[:depth] == s_parts
            ]
        elif s_name in known_names:
            transform_pairs.append((s_name, t_name))
    print(transform_pairs)

    graph = tf.Graph()
    with graph.as_default():
        with tf.Session() as sess:
            for s_name, t_name in transform_pairs:
                value = tf.contrib.framework.load_variable(source, s_name)
                # Constructing the variable is enough; the saver created
                # below is built after these variables exist in the graph.
                tf.Variable(value, name=t_name)
            saver = tf.train.Saver()
            sess.run(tf.global_variables_initializer())
            saver.save(sess, target)
# Command-line help text; the leading %s is filled with the script name.
# Scope rules are shown with a trailing '/' because that is how
# rename_checkpoint_to_partial tells a scope from a single variable.
usage = """%s source_checkpoint destination_checkpoint rename_file
  rename_file has the format:
source_scope1/,renamed_scope1/
source_scope2/variable1,renamed_scope2/renamed_variable1
  which will rename source_scope1/* to renamed_scope1/*
  and source_scope2/variable1 to renamed_scope2/renamed_variable1
  (names ending in '/' are scopes; a rule must not mix a scope with a variable)
""" % sys.argv[0]
def main(argv):
    """Command-line entry point: validate the arguments and run the rename.

    :param argv: command-line arguments without the script name, expected
        as ``[source_checkpoint, destination_checkpoint, rename_file]``
    :return: ``0`` on success, ``2`` if the arguments are missing, the
        checkpoint index or rename file cannot be found, or a row of the
        rename file does not have exactly two columns
    """
    if len(argv) < 3:
        print(usage)
        return 2
    source, target, rename_file = argv[0], argv[1], argv[2]
    if not glob.glob(source + '.index'):
        print('Checkpoint %s does not exist' % source)
        return 2
    if not glob.glob(rename_file):
        print('Transform file %s does not exist' % rename_file)
        return 2

    # The csv module wants a binary file on Python 2 but a text file opened
    # with newline='' on Python 3.
    if sys.version_info[0] < 3:
        csvfile = open(rename_file, 'rb')
    else:
        csvfile = open(rename_file, 'r', newline='')
    with csvfile:
        rows = list(csv.reader(csvfile))

    if any(len(row) != 2 for row in rows):
        print('Error %s: each line must have a source and target variable name'
              % rename_file)
        return 2

    # Pass the parsed (source, target) rows, not the path of the rename file.
    rename_checkpoint_to_partial(source, target, rows)
    return 0
if __name__ == '__main__':
    # Strip the script name and use main's return value as the exit status.
    exit_status = main(sys.argv[1:])
    sys.exit(exit_status)