1) 使用tf.scan
可以使用tf.scan函数实现RNN的底层实现。例如,对于SimpleRNN,实现类似于:
Wx = tf.get_variable(name='Wx', shape=[embedding_size, rnn_size])
Wh = tf.get_variable(name='Wh', shape=[rnn_size, rnn_size])
bias_rnn = tf.get_variable(name='brnn', initializer=tf.zeros([rnn_size]))
def rnn_step(prev_hidden_state, x):
return tf.tanh(tf.matmul(x, Wx) + tf.matmul(prev_hidden_state, Wh) + bias_rnn)
hidden_states = tf.scan(fn=rnn_step,
elems=tf.transpose(embed, perm=[1, 0, 2]),
initializer=tf.zeros([batch_size, rnn_size]))
outputs = tf.transpose(hidden_states, perm=[1, 0, 2])
last_rnn_output = outputs[:, -1, :]
完整的示例请点击此处查看。
2) 使用AutoGraph
tf.scan
是一个for循环,您也可以使用Auto-graph API实现它:
from tensorflow.python import autograph as ag
@ag.convert()
def f(x):
for ch in chars:
cell_output, (state, output) = cell.call(ch, (state, output))
hidden_outputs.append(cell_output)
hidden_outputs = autograph.stack(hidden_outputs)
完整的带有自动化API的示例请参见
此处。
3)使用Numpy实现
如果您仍需要深入了解RNN的实现,请参见此教程,该教程使用numpy实现了RNN。
4)Keras中的自定义RNN单元
请参见此处。