tf.nn.dynamic_rnn
接收一个批次(使用小批量意义)的无关序列。
cell
是您想要使用的实际单元格(LSTM,GRU等)
inputs
的形状为batch_size x max_time x input_size
,其中max_time是最长序列中的步骤数(但所有序列长度可以相同)
sequence_length
是大小为batch_size
的向量,其中每个元素给出批处理中每个序列的长度(如果所有序列的大小都相同,则将其保留为默认值。该参数定义了单元格展开大小。
隐藏状态处理
处理隐藏状态的常规方法是在dynamic_rnn
之前定义一个初始状态张量,例如:
hidden_state_in = cell.zero_state(batch_size, tf.float32)
output, hidden_state_out = tf.nn.dynamic_rnn(cell,
inputs,
initial_state=hidden_state_in,
...)
在上面的代码片段中,
hidden_state_in
和
hidden_state_out
的形状相同:
[batch_size, ...]
(
实际形状取决于所使用的单元类型,但重要的是第一个维度是批处理大小)。
这样,
dynamic_rnn
为每个序列都有一个初始隐藏状态。
它会在每个时间步骤上自动传递输入参数inputs
中每个序列的隐藏状态,并将hidden_state_out
包含每个批次序列的最终输出状态。不会在同一批次序列之间传递隐藏状态,而只会在同一序列的时间步之间传递。
何时需要手动反馈隐藏状态?
通常,在训练时,每个批次都是无关的,因此在执行
session.run(output)
时不需要手动反馈隐藏状态。
然而,如果你在测试时需要每个时间步长的输出(即必须在每个时间步长上执行
session.run()
),则需要使用类似以下代码评估并反馈输出的隐藏状态:
output, hidden_state = sess.run([output, hidden_state_out],
feed_dict={hidden_state_in:hidden_state})
否则,TensorFlow将在每个时间步骤上仅使用默认的
cell.zero_state(batch_size, tf.float32)
,这相当于在每个时间步骤重新初始化隐藏状态。
batch_size
的值或隐式推断它?有任何指导吗? - Kotshidden_state_in
的占位符的形状? - Kots