保留仅在位置0具有一个唯一值的子数组

3

从Numpy的nd-array开始:

>>> arr
[
    [
        [10, 4, 5, 6, 7],
        [11, 1, 2, 3, 4],
        [11, 5, 6, 7, 8]
    ],
    [
        [12, 4, 5, 6, 7],
        [12, 1, 2, 3, 4],
        [12, 5, 6, 7, 8]
    ],
    [
        [15, 4, 5, 6, 7],
        [15, 1, 2, 3, 4],
        [15, 5, 6, 7, 8]
    ],
    [
        [13, 4, 5, 6, 7],
        [13, 1, 2, 3, 4],
        [14, 5, 6, 7, 8]
    ],
    [
        [10, 4, 5, 6, 7],
        [11, 1, 2, 3, 4],
        [12, 5, 6, 7, 8]
    ]
]

我想仅保留在位置0具有唯一值的3个子数组序列,以获得以下结果:
>>> new_arr
[
    [
        [12, 4, 5, 6, 7],
        [12, 1, 2, 3, 4],
        [12, 5, 6, 7, 8]
    ],
    [
        [15, 4, 5, 6, 7],
        [15, 1, 2, 3, 4],
        [15, 5, 6, 7, 8]
    ]
]

从初始数组中,arr[0]arr[3]arr[4] 被丢弃,因为它们在位置 0 上都有不止一个唯一值(分别是 [10, 11][13, 14][10, 11, 12])。

我尝试使用 numpy.unique() 进行调整,但只能得到所有子数组在位置 0 处的全局唯一值,这不是所需的。

-- 编辑

以下似乎让我更接近解决方案:

>>> np.unique(arr[0, :, 0])
array([10, 11])

但我不确定如何在不使用Python循环的情况下,使每个子数组满足一个比这更高一级的条件。

4个回答

5

我成功地实现了这个功能,而无需进行任何转置。

arr = np.array(arr)
arr[np.all(arr[:, :, 0] == arr[:, :1, 0], axis=1)]

TypeError: list indices must be integers or slices, not tuple - Michael Stachura
1
@MichaelStachura 你需要先将它转换为numpy数组。 - rchome

3

我很感兴趣地想看看这些方法的比较情况,因此我使用一个大数据集(4000000, 4, 4)对这里的答案进行了基准测试。

结果

--------------------------------------------------------------------------------------- benchmark: 4 tests ---------------------------------------------------------------------------------------
Name (time in ms)            Min                   Max                  Mean             StdDev                Median                IQR            Outliers     OPS            Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_np_arr_T           128.3483 (1.0)        130.5462 (1.0)        129.0869 (1.0)       0.9536 (1.01)       128.5447 (1.0)       1.5660 (1.83)          2;0  7.7467 (1.0)           8           1
test_np_arr             128.5017 (1.00)       131.2399 (1.01)       129.2841 (1.00)      0.9414 (1.0)        128.9724 (1.00)      0.8553 (1.0)           1;1  7.7349 (1.00)          7           1
test_pure_py_set      2,840.2911 (22.13)    2,849.0413 (21.82)    2,844.4716 (22.04)     3.8494 (4.09)     2,846.1608 (22.14)     6.4168 (7.50)          3;0  0.3516 (0.05)          5           1
test_pure_py          3,688.4772 (28.74)    3,750.0933 (28.73)    3,717.3411 (28.80)    24.7294 (26.27)    3,707.3502 (28.84)    37.1902 (43.48)         2;0  0.2690 (0.03)          5           1
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

这些基准测试使用 pytest-benchmark,因此需要创建一个 venv 来运行它:

python3 -m venv venv
. ./venv/bin/activate
pip install numpy pytest pytest-benchmark

运行测试:

pytest test_runs.py

test_runs.py

import numpy as np

# No guarantee this will produce sub-arrays with shared first index
ARR = np.random.randint(low=0, high=10, size=(4_000_000, 4, 4)).tolist()
# ARR = [
#     [[10, 4, 5, 6, 7], [11, 1, 2, 3, 4], [11, 5, 6, 7, 8]],
#     [[12, 4, 5, 6, 7], [12, 1, 2, 3, 4], [12, 5, 6, 7, 8]],
#     [[15, 4, 5, 6, 7], [15, 1, 2, 3, 4], [15, 5, 6, 7, 8]],
#     [[13, 4, 5, 6, 7], [13, 1, 2, 3, 4], [14, 5, 6, 7, 8]],
#     [[10, 4, 5, 6, 7], [11, 1, 2, 3, 4], [12, 5, 6, 7, 8]],
# ]

def pure_py(arr):
    new_array = []
    for i, v in enumerate(arr):
        first_elems = [x[0] for x in v]
        if all(elem == first_elems[0] for elem in first_elems):
            new_array.append(arr[i])
    return new_array

def pure_py_set(arr):
    new_array = []
    for sub_arr in arr:
        if len(set(x[0] for x in sub_arr)) == 1:
            new_array.append(sub_arr)
    return new_array

def np_arr(arr):
    return arr[np.all(arr[:, :, 0] == arr[:, :1, 0], axis=1)]

def np_arr_T(arr):
    return arr[(arr[:, :, 0].T == arr[:, 0, 0]).T.all(axis=1)]

def np_not_arr(arr):
    arr = np.array(arr)
    return arr[np.all(arr[:, :, 0] == arr[:, :1, 0], axis=1)]

RES = np_not_arr(ARR).tolist()

def test_pure_py(benchmark):
    res = benchmark(pure_py, ARR)
    assert res == RES

def test_pure_py_set(benchmark):
    res = benchmark(pure_py_set, ARR)
    assert res == RES

def test_np_arr(benchmark):
    ARR_ = np.array(ARR)
    res = benchmark(np_arr, ARR_)
    assert res.tolist() == RES

def test_np_arr_T(benchmark):
    ARR_ = np.array(ARR)
    res = benchmark(np_arr_T, ARR_)
    assert res.tolist() == RES

你可能也对如何使用Python进一步加速感兴趣。通常情况下,Python的循环速度相当慢。 numpy 可以进一步提高速度。但是它无法达到 numbacythonnumexpr 的优化程度,因为它们会尽可能地优化事物。 - mathfux
不错的比较。你也检查了numpy数组初始化时间吗? - Michael Stachura
有一个 np_not_arr 函数,它将在过滤之前将列表转换为 np.array。您可以添加另一个测试函数,仅对创建和过滤进行基准测试。当我包含它时,它总是比纯 Python 方法慢,而且遍历 ndarray 非常慢。 - Alex

1

受到回答问题的尝试启发(我拒绝了这个建议,因为它应该是一个答案),这里有一些有效的方法:

编辑

>>> arr[(arr[:,:,0].T == arr[:,0,0]).T.all(axis=1)]
[
    [
        [12, 4, 5, 6, 7],
        [12, 1, 2, 3, 4],
        [12, 5, 6, 7, 8]
    ],
    [
        [15, 4, 5, 6, 7],
        [15, 1, 2, 3, 4],
        [15, 5, 6, 7, 8]
    ]
]

技巧在于将结果转置,以便:
# all 0-th positions of each subarray
arr[:,:,0].T

# the first 0-th position of each subarray 
arr[:,0,0]

# whether each 0-th position equals the first one
(arr[:,:,0].T == arr[:,0,0]).T

# keep only the sub-array where the above is true for all positions
(arr[:,:,0].T == arr[:,0,0]).T.all(axis=1)

# lastly, apply this indexing to the initial array
arr[(arr[:,:,0].T == arr[:,0,0]).T.all(axis=1)]

1
嗨@Jivan,我提出了这个编辑。我被禁言了一周,但是我惊讶地发现了一个答案,但由于我被禁言了,我无法提交答案。所以我将其编辑到答案中,希望你会拒绝它,但也注意到并使用它 :) - user17242583
1
你当时的直觉是正确的 :) - Jivan

0

好的,我已经比较了两种解决方案来解决这个问题。一种是使用numpy(@rchome编写的脚本),另一种是不用它-纯Python。

new_array = []
for i, v in enumerate(arr):
    first_elems = [x[0] for x in v]
    if all(elem == first_elems[0] for elem in first_elems):
        new_array.append(arr[i])

该代码执行时间 = (+- 0:00:00.000015)

arr = np.array(arr)
new_array = arr[np.all(arr[:, :, 0] == arr[:, :1, 0], axis=1)]

这段代码的执行时间为 (+- 0:00:00.000060)

使用numpy大约需要4倍的时间。但是我们必须记住,这个数组非常小。也许对于更大的数组,numpy会更快 :)

--编辑-- 我将数组扩大了10倍,以下是我的结果:

  • python: 0:00:00.000205
  • numpy: 0:00:00.002710

所以,对于这个任务来说,使用numpy是多余的。


1
我认为你正在计算将列表转换为numpy数组所需的时间。在我的比较中,只有当子数组少于10个时,numpy解决方案才会慢一些。你还可以通过在单行中检查唯一性来加快解决方案的速度:if len(set(x[0] for x in v)) == 1: new_array.append(v) - Alex
如果初始数组有四百万个元素,每个子数组有16个元素呢? :) - Jivan
@Jivan 是指 (m, n, k) 吗?我会发布我的测试代码。 - Alex

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接