TensorFlow dynamic_rnn 状态

3
我的问题与TensorFlow方法tf.nn.dynamic_rnn有关。它返回每个时间步的输出和最终状态。
我想知道返回的最终状态是在最大序列长度处的单元格状态还是由sequence_length参数单独确定的。
为了更好地理解一个例子:我有3个长度为[10,20,30]的序列,并返回最终状态[3,512](如果单元格的隐藏状态长度为512)。
这三个返回的隐藏状态是三个序列在时间步长30时的单元格状态,还是我得到的是时间步长[10,20,30]的状态?
1个回答

7

tf.nn.dynamic_rnn返回两个张量:outputsstates

outputs保存了批处理中所有序列的所有单元格的输出。因此,如果特定序列较短且填充为零,则最后一个单元格的outputs将为零。

states保存了最后一个单元格状态,或者等效地说,每个序列的最后一个非零输出(如果您正在使用BasicRNNCell)。

以下是示例:

import numpy as np
import tensorflow as tf

n_steps = 2
n_inputs = 3
n_neurons = 5

X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])
seq_length = tf.placeholder(tf.int32, [None])

basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, sequence_length=seq_length, dtype=tf.float32)

X_batch = np.array([
  # t = 0      t = 1
  [[0, 1, 2], [9, 8, 7]], # instance 0
  [[3, 4, 5], [0, 0, 0]], # instance 1
])
seq_length_batch = np.array([2, 1])

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  outputs_val, states_val = sess.run([outputs, states], 
                                     feed_dict={X: X_batch, seq_length: seq_length_batch})

  print('outputs:')
  print(outputs_val)
  print('\nstates:')
  print(states_val)

这将打印出类似于:

outputs:
[[[-0.85381496 -0.19517037  0.36011398 -0.18617202  0.39162001]
  [-0.99998015 -0.99461144 -0.82241321  0.93778896  0.90737367]]

 [[-0.99849552 -0.88643843  0.20635395  0.157896    0.76042926]
  [ 0.          0.          0.          0.          0.        ]]]  # because len=1

states:
[[-0.99998015 -0.99461144 -0.82241321  0.93778896  0.90737367]
 [-0.99849552 -0.88643843  0.20635395  0.157896    0.76042926]]

请注意,states 保存的向量与 output 中相同,并且它们是每个批次实例的最后一个非零输出。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接