我正在使用dynamic_rnn来处理MNIST数据:
# LSTM Cell
lstm = rnn_cell.LSTMCell(num_units=200,
forget_bias=1.0,
initializer=tf.random_normal)
# Initial state
istate = lstm.zero_state(batch_size, "float")
# Get lstm cell output
output, states = rnn.dynamic_rnn(lstm, X, initial_state=istate)
# Output at last time point T
output_at_T = output[:, 27, :]
完整代码:http://pastebin.com/bhf9MgMe
LSTM的输入是(batch_size, sequence_length, input_size)
因此,output_at_T
的维度为 (batch_size, sequence_length, num_units)
,其中 num_units=200
。
我需要获取沿着 sequence_length
维度的最后一个输出。在上面的代码中,这是硬编码为 27
。然而,在我的应用程序中,我不知道 sequence_length
的先验知识,因为它可以从批次到批次改变。
我尝试了:
output_at_T = output[:, -1, :]
但是它说负索引尚未实现,我尝试使用占位符变量以及常量(可以理想地将sequence_length
馈送到特定批次中);但是两者都没有起作用。
有没有办法在tensorflow中实现这样的功能?