TensorFlow 1.2 如何使用 Seq2Seq 在推理时设置时间序列预测

3

我正在尝试使用玩具模型学习TensorFlow库中的tf.contrib.seq2seq部分。目前,我的图如下所示:

tf.reset_default_graph()

# Placeholders
enc_inp = tf.placeholder(tf.float32, [None, n_steps, n_input])
expect = tf.placeholder(tf.float32, [None, n_steps, n_output])
expect_length = tf.placeholder(tf.int32, [None])
keep_prob = tf.placeholder(tf.float32, [])

# Encoder
cells = [tf.contrib.rnn.DropoutWrapper(tf.contrib.rnn.BasicLSTMCell(n_hidden), output_keep_prob=keep_prob) for i in range(layers_stacked_count)]
cell = tf.contrib.rnn.MultiRNNCell(cells)
encoded_outputs, encoded_states = tf.nn.dynamic_rnn(cell, enc_inp, dtype=tf.float32)

# Decoder
de_cells = [tf.contrib.rnn.DropoutWrapper(tf.contrib.rnn.BasicLSTMCell(n_hidden), output_keep_prob=keep_prob) for i in range(layers_stacked_count)]
de_cell = tf.contrib.rnn.MultiRNNCell(de_cells)

training_helper = tf.contrib.seq2seq.TrainingHelper(expect, expect_length)

decoder = tf.contrib.seq2seq.BasicDecoder(cell=de_cell, helper=training_helper, initial_state=encoded_states)
final_outputs, final_state, final_sequence_lengths = tf.contrib.seq2seq.dynamic_decode(decoder)

decoder_logits = final_outputs.rnn_output

h = tf.contrib.layers.fully_connected(decoder_logits, n_output)

diff = tf.squared_difference(h, expect)
batch_loss = tf.reduce_sum(diff, axis=1)
loss = tf.reduce_mean(batch_loss)

optimiser = tf.train.AdamOptimizer(1e-3)
training_op = optimiser.minimize(loss)

图表的训练非常顺利,执行也很好。然而,在推断时我不确定该怎么做,因为这个图表总是需要 "expect" 变量(即我试图预测的值)。
据我了解,TrainingHelper函数使用的是实际值作为输入,所以在推断时我需要另一个辅助函数。
我见过的大多数seq2seq模型实现似乎已经过时(tf.contrib.legacy_seq2seq)。一些最新的模型通常使用GreddyEmbeddingHelper,但我不确定它是否适用于连续时间序列预测。
我发现的另一个可能的解决方案是使用CustomHelper函数。然而,没有太多资料可以供我学习,我只能自己摸索。
如果我想为时间序列预测实现seq2seq模型,在推断时应该怎么做?
非常感谢您的帮助和建议!提前致谢!

1
缺少的部分是:“我找到的另一个可能的解决方案是使用CustomHelper函数。然而,没有太多的资料可以让我学习,我一直在摸索。”需要示例。 - George Pligoropoulos
你已经尝试使用CustomHelper解决这个问题了吗? - MrfksIV
1个回答

1

您是正确的,需要使用另一个辅助函数进行推理,但是在测试和推理之间需要共享权重。

您可以使用tf.variable_scope()实现此操作。

with tf.variable_scope("decode"):
    training_helper = ...

with tf.variable_scope("decode", reuse = True):
    inference_helper = ...

为了更全面的示例,请查看以下两个示例之一:


谢谢你的帮助,这就是我一直在寻找的! - ppyht2
我现在也面临着同样的问题。你有使用InferenceHelper来解决这个问题吗? - MrfksIV
@MrfksIV,你找到使用inferenceHelper的方法了吗? - shivam13juna

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接