如何加速Keras Attention计算?

15

我已经为 AttentiveLSTMCellAttentiveLSTM(RNN) 写了一个自定义的 Keras 层,符合 keras 对 RNN 的 方法。这种注意力机制由 Bahdanau 描述,在编码器/解码器模型中,由所有编码器和解码器当前隐藏状态的输出创建“上下文”向量。然后,我在每个时间步骤将上下文向量附加到输入中。

该模型用于创建对话代理,但在架构上与 NMT 模型非常相似(类似的任务)。

然而,通过添加这个注意力机制,我使网络训练速度减慢了 5 倍,我真的很想知道如何以更有效的方式编写代码部分来提高训练效率。

计算的核心部分在此处完成:

h_tm1 = states[0]  # previous memory state
c_tm1 = states[1]  # previous carry state

# attention mechanism

# repeat the hidden state to the length of the sequence
_stm = K.repeat(h_tm1, self.annotation_timesteps)

# multiplty the weight matrix with the repeated (current) hidden state
_Wxstm = K.dot(_stm, self.kernel_w)

# calculate the attention probabilities
# self._uh is of shape (batch, timestep, self.units)
et = K.dot(activations.tanh(_Wxstm + self._uh), K.expand_dims(self.kernel_v))

at = K.exp(et)
at_sum = K.sum(at, axis=1)
at_sum_repeated = K.repeat(at_sum, self.annotation_timesteps)
at /= at_sum_repeated  # vector of size (batchsize, timesteps, 1)

# calculate the context vector
context = K.squeeze(K.batch_dot(at, self.annotations, axes=1), axis=1)

# append the context vector to the inputs
inputs = K.concatenate([inputs, context])

AttentiveLSTMCellcall 方法中(一个时间步长)。

完整的代码可以在这里找到。如果需要提供一些数据和与模型交互的方式,那我可以做到。

有什么想法吗?当然,如果有什么聪明的方法,我会在GPU上进行训练。


2
你能发布一些使用 tensorflow.python.client.timeline.Timeline 在一些训练轮次上的输出吗?如果没有好的分析器数据,只是猜测原因就像在黑暗中射击。最好收集直接证据。 - ely
是的,我可以稍后开始处理 @ely。 - modesitt
你有为你的代码做性能分析吗?猜测哪些地方需要优化往往是徒劳的。我喜欢使用Python的line-profiler [kernprof](https://github.com/rkern/line_profiler),而且在Keras中,你可以利用TensorFlow提供的工具,如[TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard)。 - BoltzmannBrain
2个回答

1
你修改了LSTM类,这对于CPU计算非常好,但是你提到你正在GPU上训练。
我建议研究cudnn-recurrent实现或进一步研究所使用的tf部分。也许你可以在那里扩展代码。

1
我建议使用relu而不是tanh来训练模型,因为这个操作的计算速度更快。这将节省你计算时间,大约为你的训练样本数量*每个样本的平均序列长度*epoch数。
此外,我建议考虑追加上下文向量的性能改进,但请注意,这会降低其他参数的迭代周期。如果没有给你带来太多的改进,那么尝试其他方法可能值得一试。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接