如何使用TensorFlow中的seq2seq预测简单序列?

9

我最近开始使用tensorflow,所以我仍然在努力掌握基础知识。

我想创建简单的seq2seq预测模型。

  • 输入是0到1之间数字的列表。
  • 输出是列表中的第一个数字和其余数字乘以第一个数字。

我成功地评估了模型性能并优化了权重。我一直在苦苦寻找如何使用训练好的模型进行预测的方法。

 model_outputs, states = seq2seq.basic_rnn_seq2seq(encoder_inputs,
                                                  decoder_inputs,
                                                  rnn_cell.BasicLSTMCell(data_point_dim, state_is_tuple=True))

为了生成模型输出,我需要模型的输入和输出值,这对于评估很有用,但在预测中,我只有输入值。我猜我需要处理状态,但我不确定如何将它们转换为浮点数序列。
完整代码可在此处找到: https://gist.github.com/anonymous/be405097927758acca158666854600a2
3个回答

4

在训练过程中,您需要在每个解码器时间步提供期望输出作为解码器输入。而在测试过程中,您没有期望的输出值,所以您只能对输出进行抽样,将其作为下一个时间步的输入。

简单来说,在每个时间步将上一个时间步的解码器输出馈送回解码器即可。

编辑:一些TF代码

basic_rnn_seq2seq函数返回rnn_decoder(decoder_inputs, enc_states[-1], cell)

我们来看一下rnn_decoder

def rnn_decoder(decoder_inputs, initial_state, cell, loop_function=None, scope=None): ....

loop_function:如果不是None,则将应用于第i个输出结果以生成i+1个输入结果,除了第一个元素(“GO”符号)外,将忽略decoder_inputs。这可以用于解码,也可以用于训练以模拟http://arxiv.org/pdf/1506.03099v2.pdf

在解码过程中,您需要将此loop_function设置为True

我建议查看Tensorflow seq2seq库中的translate.py文件,以了解如何处理此问题。


如果我理解正确,按照那个解决方案,我会忽略状态变量,只使用session.run(output, feed_dict=feed)中的输出来获取结果?在这个过程中,难道不应该有一种使用状态的方法吗?我一直在使用scikit-learn,并希望有一种创建类似于model.predict方法的方式。 - Daniel Lyam Montross
添加了进一步的澄清,但是是的,你会忽略隐藏状态,因为这些隐藏状态就是被用来创建RNN最终输出的。至于代码更改,我建议看一下Tensorflow提供的seq2seq wmt示例。我认为你确实需要改变输出馈送。 - user4383691
我也在努力让类似的东西工作。是否有人可以提供一个简单的工作示例,以执行此答案中所描述的操作? - b..

0

用户user4383691之前的回答不完整。 我也遇到了同样的问题,在深入研究rnn_decoder后,发现:模型将loop_fn应用于第i个输出,因此True没有意义,因为它不是一个函数。 你应该编写一个函数,该函数可以接收第i个输出并返回第i+1个输出。我仍在制作这样的函数,并将在完成后立即更新。


0
让我们来看看源代码

prev = None    for i, inp in enumerate(decoder_inputs):
     if loop_function is not None and prev is not None:
       with variable_scope.variable_scope("loop_function", reuse=True):
         inp = loop_function(prev, i)
     if i > 0:
       variable_scope.get_variable_scope().reuse_variables()
     output, state = cell(inp, state)
     outputs.append(output)
     if loop_function is not None:
       prev = output

循环枚举解码器输入,无论您是使用提供的解码器输入进行训练还是在没有输入的情况下进行测试。这是因为在测试时,解码器输入将被循环函数的输出所替代(在上面片段的第四行)。

通常,您可以像这里一样用end_ids填充dec_inputs。

  while len(dec_inputs) < self._hps.dec_timesteps:
    dec_inputs.append(end_id)
  while len(targets) < self._hps.dec_timesteps:
    targets.append(end_id)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接