注意力层抛出 TypeError:Keras中的Permute层不支持掩码处理。

19

我一直在遵循这个帖子,以实现在我的LSTM模型上应用注意力层

注意力层的代码:

INPUT_DIM = 2
TIME_STEPS = 20
SINGLE_ATTENTION_VECTOR = False
APPLY_ATTENTION_BEFORE_LSTM = False

def attention_3d_block(inputs):
    input_dim = int(inputs.shape[2])
    a = Permute((2, 1))(inputs)
    a = Reshape((input_dim, TIME_STEPS))(a)
    a = Dense(TIME_STEPS, activation='softmax')(a)
    if SINGLE_ATTENTION_VECTOR:
        a = Lambda(lambda x: K.mean(x, axis=1), name='dim_reduction')(a)
        a = RepeatVector(input_dim)(a)
    a_probs = Permute((2, 1), name='attention_vec')(a)
    output_attention_mul = merge(
        [inputs, a_probs],
        name='attention_mul',
        mode='mul'
    )
    return output_attention_mul

我收到的错误:

File "main_copy.py", line 244, in model = create_model(X_vocab_len, X_max_len, y_vocab_len, y_max_len, HIDDEN_DIM, LAYER_NUM) File "main_copy.py", line 189, in create_model attention_mul = attention_3d_block(temp) File "main_copy.py", line 124, in attention_3d_block a = Permute((2, 1))(inputs) File "/root/.virtualenvs/keras_tf/lib/python3.5/site-packages/keras/engine/topology.py", line 597, in call output_mask = self.compute_mask(inputs, previous_mask) File "/root/.virtualenvs/keras_tf/lib/python3.5/site-packages/keras/engine/topology.py", line 744, in compute_mask str(mask)) TypeError: Layer permute_1 does not support masking, but was passed an input_mask: Tensor("merge_2/All:0", shape=(?, 15), dtype=bool)

我查看了此讨论串,其中说:

这是在Keras源代码中进行的小更改(将Lambda层中的supports_masking类变量设置为True而不是False)。否则没有其他方法可以解决。虽然遮罩不是必需的。

我在哪里可以将supports_masking变量设置为True?此外,还有其他解决方案吗?

2个回答

0

我建议:不要使用掩码。

这个实现方式有点奇怪,试图将一个 Dense 层应用于一个可变维度(TIME_STEPS)。

这将需要在层中有可变数量的权重,这是不可能的。(使用掩码,你会告诉模型忽略每个不同样本的一些权重)。

我建议在输入中加入一个标记/单词,表示“这是句子/电影/序列的结尾”,并用此标记填充剩余长度。然后关闭或删除您在模型中使用的任何掩码(无论是在声明嵌入层时的参数还是实际的掩码层)。


试图更改keras原生代码可能会导致不稳定的行为和错误的结果(如果没有错误)。

这些层不支持掩码的原因有一些类似于关于Dense层的上述解释。如果你改变了它,谁知道会出现什么问题?除非你真的非常确定它可能产生的所有后果,否则不要去修改源代码。


如果你想使用遮罩,但是发现有一些复杂的解决方案(我没有测试过),比如这个:MaskEatingLambda层:


0

我是该软件包的作者之一。

您应该使用最新版本。 以前的版本存在一些概念上的问题。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接