TensorFlow:将GRUCell的权重从compat.v1转换为TensorFlow 2

4

我正在尝试将一个在tensorflow 1中保存的模型转换为tensorflow 2。如tensorflow文档中所示,我正在将代码迁移到tensorflow 2。然而,我想简单地将我的model_weights.ckpt更新到tensorflow 2。一些权重(LinearEmbdedding)与tensorflow 2语法具有类似的形状,但我正在努力将我的GRUCell权重从旧版compat.v1.nn.rnn_cell.GRUCell转换为keras.layers.GRUCell

如何将GRUCell权重从compat.v1.nn.rnn_cell.GRUCell转换为keras.layers.GRUCell

GRUCell有四个权重:

  • gru_cell/gates/kernel:0的形状为(S + H, 2 x H),
  • gru_cell/gates/bias:0的形状为(2 x H, ),
  • gru_cell/candidate/kernel:0的形状为(S + H, H),
  • gru_cell/candidate/bias:0的形状为(H, )

我想要的权重与tensorflow 2 API(或PyTorch API)具有类似的形状,即一个带有以下权重的GRUCell

  • gru_cell/kernel:0的形状为(S, 3 x H)
  • gru_cell/recurrent_kernel:0的形状为(H, 3 x H)
  • gru_cell/bias:0的形状为(2, 3 x H)

你可以复制以下结果以说明:

1. 具有tensorflow 1 API的GRUCell

import tensorflow as tf

SEQ_LENGTH = 4
HIDDEN_SIZE = 512
BATCH_SIZE = 1
inputs = tf.random.normal([BATCH_SIZE, SEQ_LENGTH])

# GRU cell
gru = tf.compat.v1.nn.rnn_cell.GRUCell(HIDDEN_SIZE)
# Hidden state
state = gru.zero_state(BATCH_SIZE, tf.float32)
# Forward
output, state = gru(inputs, state)

for weight in gru.weights:
    print(weight.name, weight.shape)

输出:

gru_cell/gates/kernel:0 (516, 1024)
gru_cell/gates/bias:0 (1024,)
gru_cell/candidate/kernel:0 (516, 512)
gru_cell/candidate/bias:0 (512,)

2. 使用tensorflow 2 API的GRUCell

import tensorflow as tf

SEQ_LENGTH = 4
HIDDEN_SIZE = 512
BATCH_SIZE = 1
inputs = tf.random.normal([BATCH_SIZE , SEQ_LENGTH])

# GRU cell
gru = tf.keras.layers.GRUCell(HIDDEN_SIZE)
# Hidden state
state = tf.zeros((BATCH_SIZE, HIDDEN_SIZE), dtype=tf.float32)
# Forward
output, state = gru(inputs, state)

# Display the weigths
for weight in gru.weights:
    print(weight.name, weight.shape)

输出:

gru_cell/kernel:0 (4, 1536)
gru_cell/recurrent_kernel:0 (512, 1536)
gru_cell/bias:0 (2, 1536)

注意


这里提到了一些检查点兼容性的解决方法,如果您还没有尝试过,请看一下并使用这些解决方法,让我们看看结果。谢谢。 - user11530462
这并没有解决问题。检查点转换器适用于DNNLinearCombined估算器。它不支持复杂模型(例如GAN和其他模型)。然而,如果我能够手动将GRUCell权重转换为keras语法,那么它将解决我的问题,因为我将能够加载检查点/模型。 - polop
1个回答

1

为了造福社区,即使它是在Github上展示的,我们也提供解决方案。

简而言之,compat.v1.nn.rnn_cell.GRUCellkeras.layers.GRUCell之间的权重彼此不兼容。我们没有一个函数来转换它们,如果你真的想这么做,你需要手动处理。

数学上来说,如果你有v1权重的numpy值,公式如下:

B = batch_size

H = state_size

  1. all_kernel = np.concatenate([gru_cell/gates/kernel, gru_cell/candidate/kernel], axis=1) # 形状为 (B+H, 3 * H)
  2. kernel = all_kernel[:B] # 形状为(B, 3 * H)
  3. recurrent_kernel = all_kernel[B:] # 形状为 (H, 3 * H)
  4. bias = np.concatenate([gru_cell/gates/bias, gru_cell/candidate/bias], axis=0) # 形状为 (B, 3 * H)
  5. zero_bias = np.zeros([B, 3 * H])
  6. bias = np.concatenate([bias, zero_bias], axis=0)

偏置的形状是否有错别字?没有B,应该是[3 * H, ] - Vimos

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接