如何使用二维掩码过滤三维数组

3

我有一个(m,n,3) 的数组 data,我想用一个 (m,n) 的掩码来过滤其值,以获得一个 (x,3)output 数组。

下面的代码可以工作,但我该如何用更高效的替代循环呢?

import numpy as np

data = np.array([
    [[11, 12, 13], [14, 15, 16], [17, 18, 19]],
    [[21, 22, 13], [24, 25, 26], [27, 28, 29]],
    [[31, 32, 33], [34, 35, 36], [37, 38, 39]],
])
mask = np.array([
    [False, False, True],
    [False, True, False],
    [True, True, False],
])

output = []
for i in range(len(mask)):
    for j in range(len(mask[i])):
        if mask[i][j] == True:
            output.append(data[i][j])
output = np.array(output)

期望输出结果为:
np.array([[17, 18, 19], [24, 25, 26], [31, 32, 33], [34, 35, 36]])

2
data[mask]?我有什么遗漏吗? - Sayandip Dutta
1
@SayandipDutta 是的,就是这样。我现在感觉有点傻。谢谢! - Florian Ludewig
1个回答

2
import numpy as np

data = np.array([
    [[11, 12, 13], [14, 15, 16], [17, 18, 19]],
    [[21, 22, 13], [24, 25, 26], [27, 28, 29]],
    [[31, 32, 33], [34, 35, 36], [37, 38, 39]],
])
mask = np.array([
    [False, False, True],
    [False, True, False],
    [True, True, False],
])

output = data[mask]

您的答案可以通过提供更多支持信息来改善。请编辑以添加进一步细节,例如引用或文档,以便他人确认您的答案是正确的。您可以在帮助中心找到有关如何撰写良好答案的更多信息。 - Community

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接