Actually, setting `mask_zero=True` for the Embedding layer does not result in a zero vector being returned. Rather, the behavior of the Embedding layer does not change at all, and it would still return the embedding vector with index zero. You can confirm this by checking the Embedding layer weights (i.e., in the example you mentioned, it would be `m.layers[0].get_weights()`). Instead, it affects the behavior of subsequent layers such as RNN layers.
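Here is a quick check of that claim (a minimal sketch; the model is built just for this demonstration, using the same standalone Keras API as the rest of this answer):

import numpy as np
from keras.layers import Input, Embedding
from keras.models import Model

x = Input(shape=(3,))
e = Embedding(5, 5, mask_zero=True)(x)
m = Model(inputs=x, outputs=e)

out = m.predict(np.array([[1, 0, 2]]))
emb_matrix = m.layers[1].get_weights()[0]      # layers[0] is the Input layer here
print(np.allclose(out[0, 1], emb_matrix[0]))   # True: the masked step gets row 0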
If you look at the source code of the Embedding layer, you can see a method called `compute_mask`:
def compute_mask(self, inputs, mask=None):
    if not self.mask_zero:
        return None
    output_mask = K.not_equal(inputs, 0)
    return output_mask
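You can call `compute_mask` directly to inspect the mask it produces; a small sketch (assuming the standalone Keras backend API and a TensorFlow 1.x-style session behind `K.eval`):

import keras.backend as K
from keras.layers import Embedding

emb = Embedding(5, 5, mask_zero=True)
mask = emb.compute_mask(K.constant([[1, 0, 2, 0]]))
print(K.eval(mask))  # [[ True False  True False]]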
This output mask will be passed, as the `mask` argument, to the following layers that support masking. This is implemented in the `__call__` method of the base `Layer` class:
previous_mask = _collect_previous_mask(inputs)
user_kwargs = copy.copy(kwargs)
if not is_all_none(previous_mask):
    if has_arg(self.call, 'mask'):
        if 'mask' not in kwargs:
            kwargs['mask'] = previous_mask
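As a result, any layer whose `call` signature accepts a `mask` argument receives it automatically. Here is a minimal sketch of a custom pass-through layer that merely reports the mask handed to it (the class name is hypothetical, for illustration only):

from keras.layers import Layer

class MaskProbe(Layer):
    # Hypothetical pass-through layer that exposes the mask it receives.
    def __init__(self, **kwargs):
        super(MaskProbe, self).__init__(**kwargs)
        self.supports_masking = True  # propagate the mask to later layers

    def call(self, inputs, mask=None):
        # `mask` is filled in automatically from the previous layer's
        # compute_mask; it prints once, at graph-construction time
        if mask is not None:
            print('received mask tensor:', mask)
        return inputs

For example, inserting `MaskProbe()(e)` after the Embedding layer in the example below would print the symbolic mask tensor while the model is being built.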
This makes the following layers ignore (i.e., not take into account in their computations) those masked input timesteps. Here is a minimal example:
import numpy as np
from keras.layers import Input, Embedding, LSTM
from keras.models import Model

data_in = np.array([
    [1, 0, 2, 0]
])

x = Input(shape=(4,))
e = Embedding(5, 5, mask_zero=True)(x)
rnn = LSTM(3, return_sequences=True)(e)
m = Model(inputs=x, outputs=rnn)

m.predict(data_in)
array([[[-0.00084503, -0.00413611,  0.00049972],
        [-0.00084503, -0.00413611,  0.00049972],
        [-0.00144554, -0.00115775, -0.00293898],
        [-0.00144554, -0.00115775, -0.00293898]]], dtype=float32)
As you can see, the outputs of the LSTM layer at the second and fourth timesteps are the same as the outputs of the first and third timesteps, respectively. This means that those timesteps have been masked: the LSTM simply carries its previous output over them.
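You can confirm this programmatically as a follow-up to the prediction above:

out = m.predict(data_in)
print(np.allclose(out[0, 1], out[0, 0]))  # True: step 2 repeats step 1
print(np.allclose(out[0, 3], out[0, 2]))  # True: step 4 repeats step 3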
Update: The mask will also be taken into account when computing the loss, since the loss functions are internally augmented to support masking using `weighted_masked_objective`:
def weighted_masked_objective(fn):
    """Adds support for masking and sample-weighting to an objective function.

    It transforms an objective function `fn(y_true, y_pred)`
    into a sample-weighted, cost-masked objective function
    `fn(y_true, y_pred, weights, mask)`.

    # Arguments
        fn: The objective function to wrap,
            with signature `fn(y_true, y_pred)`.

    # Returns
        A function with signature `fn(y_true, y_pred, weights, mask)`.
    """
and this wrapper is applied when compiling the model:
weighted_losses = [weighted_masked_objective(fn) for fn in loss_functions]
You can verify this using the following example:
from keras.layers import Dense  # plus the imports from the previous example

data_in = np.array([[1, 2, 0, 0]])
data_out = np.arange(12).reshape(1, 4, 3)

x = Input(shape=(4,))
e = Embedding(5, 5, mask_zero=True)(x)
d = Dense(3)(e)
m = Model(inputs=x, outputs=d)
m.compile(loss='mse', optimizer='adam')

preds = m.predict(data_in)
loss = m.evaluate(data_in, data_out, verbose=0)
print(preds)
print('Computed Loss:', loss)
[[[ 0.009682    0.02505393 -0.00632722]
  [ 0.01756451  0.05928303  0.0153951 ]
  [-0.00146054 -0.02064196 -0.04356086]
  [-0.00146054 -0.02064196 -0.04356086]]]
Computed Loss: 9.041069030761719
# the loss is computed only over the two unmasked timesteps
print(np.square(preds[0, 0:2] - data_out[0, 0:2]).mean())
9.041070036475277
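(The tiny difference is just float32 vs. float64 rounding.) What `weighted_masked_objective` effectively computes is the mask-weighted mean of the per-timestep losses: it multiplies each timestep's loss by the mask and normalizes by the mean of the mask. A NumPy sketch of that computation (a re-derivation, not the Keras source itself):

mask = np.array([1., 1., 0., 0.])  # what compute_mask gives for [1, 2, 0, 0]
per_step = np.square(preds[0] - data_out[0]).mean(axis=-1)  # per-timestep MSE
print((per_step * mask).sum() / mask.sum())  # ~9.04107, matches m.evaluate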