脚手架和tf.train.MonitoredTrainingSession

4

我想知道如何在使用tf.train.MonitoredTrainingSession时使用Scaffold,并且将图的权重初始化为从Numpy数组导入的特定值。我找不到任何类似用法的明确示例。谢谢

1个回答

2

有几种方法可以实现这个目标。

保存图表检查点的方法

  • 构建图表。
  • 初始化所有变量。
  • 运行会话以为每个变量分配值。
  • 保存检查点以在训练时加载。
  • 在训练时使用检查点。

使用模型初始化和恢复

您可以在此处查看更多详细信息:Tensorflow模型恢复。基本上,您可以创建tf.train.Scaffold并将init_fn与您的init函数分配。

我只测试过第一种方法,可以分享一些代码:

  with tf.Graph().as_default():

    # build the graph as it is in training
    some code...

    sess = tf.Session()
    with sess.as_default():

        # Add an op to initialize the variables.
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        #Update your graph with starting variables
        data_dict = np.load('your_pass/model.npy', encoding='latin1').item()
        #
        var = tf.get_variable(param_name)
        sess.run(var.assign(data_dict))
        print('assignment done!')

    saver = tf.train.Saver()

    # Save the variables to disk.
    save_path = saver.save(sess, FLAGS.train_dir)
    print("Model saved in file: %s" % save_path)    

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接