我的情况:
- 定义一个RNN模型结构并使用具有固定批处理大小和序列长度的输入进行训练。
- 冻结模型(即将所有可训练变量转换为常数),从而产生一个包含使用模型进行测试所需的所有内容的
GraphDef
(通过tf.graph_util.convert_variables_to_constants
)。 - 通过
tf.import_graph_def
导入GraphDef
,并使用参数替换输入。新输入需要具有任意批处理大小和序列长度。
问题:在将使用与训练时使用的原始大小不同的批处理大小或序列长度的输入传递到测试时间图时,以上所有操作都有效。此时,我会收到以下错误:
InvalidArgumentError (see above for traceback): ConcatOp : Dimensions of inputs should match: shape[0] = [1,5] vs. shape[1] = [2,7]
[[Node: import/rnn/while/basic_rnn_cell/basic_rnn_cell_1/concat = ConcatV2[N=2, T=DT_FLOAT, Tidx=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](import/rnn/while/TensorArrayReadV3, import/rnn/while/Identity_2, import/rnn/while/basic_rnn_cell/basic_rnn_cell_1/concat/axis)]]
为了说明和重现问题,请考虑以下最简示例:
- v1:创建具有任意批次大小和序列长度的图形。这很好,但不幸的是,在训练时必须使用固定的批次大小和序列长度,并且在测试时必须使用任意批次大小和序列长度,因此我不能使用这种简单的方法。 - v2a:我们模拟使用固定批次大小(2)和序列长度(3)创建训练时间图形并冻结图形。 - v2ba:我们证明在未更改的情况下加载冻结模型仍会产生相同的结果。 - v2bb:我们证明在使用仍然使用固定批次大小和序列长度的替换输入加载冻结模型的情况下仍会产生相同的结果。 - v2bc:我们证明在使用任意批次大小和序列长度的替换输入加载冻结模型的情况下,只要输入按照原始批次大小和序列长度进行形状化,仍将产生相同的结果。它适用于data,但对于data2则失败,唯一的区别在于前者的批次大小为2,而后者的批次大小为1。
是否可以通过`input_map`参数将RNN图形更改为
tf.import_graph_def
,以便输入不再具有固定的批次大小和序列长度?以下代码适用于TensorFlow 1.1 RC2,可能也适用于TensorFlow 1.0。
import numpy
import tensorflow as tf
from tensorflow import graph_util as tf_graph_util
from tensorflow.contrib import rnn as tfc_rnn
def v1(data):
with tf.Graph().as_default():
tf.set_random_seed(1)
x = tf.placeholder(tf.float32, shape=(None, None, 5))
_, s = tf.nn.dynamic_rnn(tfc_rnn.BasicRNNCell(7), x, dtype=tf.float32)
with tf.Session() as session:
session.run(tf.global_variables_initializer())
print session.run(s, feed_dict={x: data})
def v2a():
with tf.Graph().as_default():
tf.set_random_seed(1)
x = tf.placeholder(tf.float32, shape=(2, 3, 5), name="x")
_, s = tf.nn.dynamic_rnn(tfc_rnn.BasicRNNCell(7), x, dtype=tf.float32)
with tf.Session() as session:
session.run(tf.global_variables_initializer())
return tf_graph_util.convert_variables_to_constants(
session, session.graph_def, [s.op.name]), s.name
def v2ba((graph_def, s_name), data):
with tf.Graph().as_default():
x, s = tf.import_graph_def(graph_def,
return_elements=["x:0", s_name])
with tf.Session() as session:
print '2ba', session.run(s, feed_dict={x: data})
def v2bb((graph_def, s_name), data):
with tf.Graph().as_default():
x = tf.placeholder(tf.float32, shape=(2, 3, 5))
[s] = tf.import_graph_def(graph_def, input_map={"x:0": x},
return_elements=[s_name])
with tf.Session() as session:
print '2bb', session.run(s, feed_dict={x: data})
def v2bc((graph_def, s_name), data):
with tf.Graph().as_default():
x = tf.placeholder(tf.float32, shape=(None, None, 5))
[s] = tf.import_graph_def(graph_def, input_map={"x:0": x},
return_elements=[s_name])
with tf.Session() as session:
print '2bc', session.run(s, feed_dict={x: data})
def main():
data1 = numpy.random.random_sample((2, 3, 5))
data2 = numpy.random.random_sample((1, 3, 5))
v1(data1)
model = v2a()
v2ba(model, data1)
v2bb(model, data1)
v2bc(model, data1)
v2bc(model, data2)
if __name__ == "__main__":
main()