OK - it turns out there is actually a simpler solution. As @Shao Tang and @Rahul mentioned, the best way to do this is by accessing the final cell state. Here is why:
- If you look at the GRUCell source code (below), you will see that the "state" the cell maintains is actually the hidden state vector itself. So when tf.nn.dynamic_rnn returns the final state, it is returning the final hidden state you are interested in. To demonstrate this, I reworked your setup and got the following results:
GRUCell Call (rnn_cell_impl.py):
def call(self, inputs, state):
    """Gated recurrent unit (GRU) with nunits cells."""
    if self._gate_linear is None:
        bias_ones = self._bias_initializer
        if self._bias_initializer is None:
            bias_ones = init_ops.constant_initializer(1.0, dtype=inputs.dtype)
        with vs.variable_scope("gates"):
            self._gate_linear = _Linear(
                [inputs, state],
                2 * self._num_units,
                True,
                bias_initializer=bias_ones,
                kernel_initializer=self._kernel_initializer)
    value = math_ops.sigmoid(self._gate_linear([inputs, state]))
    r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
    r_state = r * state
    if self._candidate_linear is None:
        with vs.variable_scope("candidate"):
            self._candidate_linear = _Linear(
                [inputs, r_state],
                self._num_units,
                True,
                bias_initializer=self._bias_initializer,
                kernel_initializer=self._kernel_initializer)
    c = self._activation(self._candidate_linear([inputs, r_state]))
    new_h = u * state + (1 - u) * c
    return new_h, new_h  # output and new state are the same tensor
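The key line is the final return: the GRU emits new_h both as its output and as its next state, which is why the final state returned by tf.nn.dynamic_rnn equals the last valid output. A minimal NumPy sketch of the same update (weight names and shapes here are illustrative, not TF's actual variables):

```python
import numpy as np

def gru_step(x, h, W_gates, b_gates, W_cand, b_cand):
    """One GRU step mirroring GRUCell.call: gates computed from [x, h],
    candidate from [x, r*h]; returns (output, state) - the same array."""
    xh = np.concatenate([x, h], axis=1)
    value = 1.0 / (1.0 + np.exp(-(xh @ W_gates + b_gates)))   # sigmoid
    r, u = np.split(value, 2, axis=1)                         # reset, update gates
    r_state = r * h
    xc = np.concatenate([x, r_state], axis=1)
    c = np.tanh(xc @ W_cand + b_cand)                         # candidate activation
    new_h = u * h + (1 - u) * c
    return new_h, new_h                                       # output IS the state

# Tiny example with random (illustrative) weights: batch=2, dim=3, hidden=4.
rng = np.random.RandomState(0)
dim, hidden = 3, 4
W_gates = rng.randn(dim + hidden, 2 * hidden)
b_gates = np.ones(2 * hidden)            # TF default: gate bias initialized to 1.0
W_cand = rng.randn(dim + hidden, hidden)
b_cand = np.zeros(hidden)
x = rng.randn(2, dim)
h = np.zeros((2, hidden))
out, state = gru_step(x, h, W_gates, b_gates, W_cand, b_cand)
print(out is state)  # True: the output and the state are literally the same object
```
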
Solution:
import numpy as np
import tensorflow as tf

batch = 2
dim = 3
hidden = 4

lengths = tf.placeholder(dtype=tf.int32, shape=[batch])
inputs = tf.placeholder(dtype=tf.float32, shape=[batch, None, dim])
cell = tf.nn.rnn_cell.GRUCell(hidden)
cell_state = cell.zero_state(batch, tf.float32)
output, state = tf.nn.dynamic_rnn(cell, inputs, lengths, initial_state=cell_state)

# float32 to match the placeholder dtype
inputs_ = np.asarray([[[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3]],
                      [[6, 6, 6], [7, 7, 7], [8, 8, 8], [9, 9, 9]]],
                     dtype=np.float32)
lengths_ = np.asarray([3, 1], dtype=np.int32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    output_, state_ = sess.run([output, state], {inputs: inputs_, lengths: lengths_})
    print(output_)
    print(state_)
Output:
[[[ 0. 0. 0. 0. ]
[-0.24305521 -0.15512943 0.06614969 0.16873555]
[-0.62767833 -0.30741733 0.14819752 0.44313088]
[ 0. 0. 0. 0. ]]
[[-0.99152333 -0.1006391 0.28767768 0.76360202]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]]]
[[-0.62767833 -0.30741733 0.14819752 0.44313088]
[-0.99152333 -0.1006391 0.28767768 0.76360202]]
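The printed values confirm the claim: for each sequence i, the final state equals the output row at time step lengths_[i] - 1 (later steps are zero-padded). A quick NumPy check over the numbers above, re-entered by hand:

```python
import numpy as np

# The arrays printed by the TF run above.
output_ = np.array([
    [[ 0.,          0.,          0.,          0.        ],
     [-0.24305521, -0.15512943,  0.06614969,  0.16873555],
     [-0.62767833, -0.30741733,  0.14819752,  0.44313088],
     [ 0.,          0.,          0.,          0.        ]],
    [[-0.99152333, -0.1006391,   0.28767768,  0.76360202],
     [ 0.,          0.,          0.,          0.        ],
     [ 0.,          0.,          0.,          0.        ],
     [ 0.,          0.,          0.,          0.        ]]])
state_ = np.array([
    [-0.62767833, -0.30741733,  0.14819752,  0.44313088],
    [-0.99152333, -0.1006391,   0.28767768,  0.76360202]])
lengths_ = np.array([3, 1])

# The final state is the output at the last valid time step of each sequence.
for i, n in enumerate(lengths_):
    assert np.allclose(output_[i, n - 1], state_[i])
print("final state matches output at step lengths-1")
```
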
For readers using the LSTMCell (another popular option), things are slightly different. The LSTMCell maintains its state differently - the cell state is either a tuple or a concatenation of the actual cell state and the hidden state. So, to access the final hidden state, set
state_is_tuple
to
True
during cell initialization, and the final state will be a tuple: (final cell state, final hidden state). In that case,
_, (_, h) = tf.nn.dynamic_rnn(cell, inputs, lengths, initial_state=cell_state)
will give you the final hidden state.
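That tuple unpacking works because, with state_is_tuple=True, the final state is the pair (c, h) and its h member is the same tensor as the last output. A NumPy sketch of one basic LSTM step under that layout (the i/j/f/o gate split and the forget bias of 1.0 follow TF's BasicLSTMCell; the weights here are illustrative):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, c, h, W, b):
    """One basic LSTM step (state_is_tuple=True layout): returns the
    output h along with the state tuple (c, h)."""
    xh = np.concatenate([x, h], axis=1)
    z = xh @ W + b
    i, j, f, o = np.split(z, 4, axis=1)      # input, candidate, forget, output gates
    new_c = c * sigmoid(f + 1.0) + sigmoid(i) * np.tanh(j)   # forget_bias = 1.0
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, (new_c, new_h)             # state tuple carries h as its 2nd member

# Tiny example with random (illustrative) weights: batch=2, dim=3, hidden=4.
rng = np.random.RandomState(1)
dim, hidden = 3, 4
W = rng.randn(dim + hidden, 4 * hidden)
b = np.zeros(4 * hidden)
x = rng.randn(2, dim)
c = np.zeros((2, hidden))
h = np.zeros((2, hidden))
out, (new_c, new_h) = lstm_step(x, c, h, W, b)
print(out is new_h)  # True: the h in the state tuple IS the output
```

This is why `_, (_, h)` recovers the final hidden state: the first `_` discards the per-step outputs, and the inner `(_, h)` unpacks the (c, h) state tuple, discarding the cell state.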
References:
c_state and m_state in TensorFlow LSTM
https://github.com/tensorflow/tensorflow/blob/438604fc885208ee05f9eef2d0f2c630e1360a83/tensorflow/python/ops/rnn_cell_impl.py#L308
https://github.com/tensorflow/tensorflow/blob/438604fc885208ee05f9eef2d0f2c630e1360a83/tensorflow/python/ops/rnn_cell_impl.py#L415
state
(per TF naming conventions) is a tuple of vectors representing the internal state of the cell, passed along between time steps, whereas the question is about the cell's final output. Incidentally, the state size is not necessarily the same as the output size. - petrux