检查两个numpy数组是否相同

12
假设我有一堆数组,包括x和y,并且我想检查它们是否相等。通常,我可以只使用np.all(x == y)(除了一些我现在忽略的愚蠢的边角情况)。
然而,这会评估整个(x == y)数组,这通常是不必要的。我的数组非常大,我有很多数组,两个数组相等的概率很小,因此很可能在all函数返回False之前,我只需要评估(x == y)的一小部分,所以这对我来说不是最优解决方案。
我尝试使用内置的all函数,结合itertools.izip:all(val1==val2 for val1,val2 in itertools.izip(x, y))
然而,仅当两个数组相等时,似乎比起使用np.all,这种方法速度要慢得多。我认为这是因为内置的all功能更加通用。而np.all无法用于生成器。
有没有一种更快速的方法来实现我的需求?
我知道这个问题类似于之前提出的问题(例如Comparing two numpy arrays for equality, element-wise),但它们并没有涉及到提前结束的情况。

这个函数怎么样:https://docs.scipy.org/doc/numpy-1.12.0/reference/generated/numpy.array_equal.html - Thomas Kühn
@Thomas:该函数内部只是调用了np.all,所以它有点无用。 (我确实希望专门为此目的设计一个函数进行短路,但遗憾的是没有。) - acdr
1
嗯,真遗憾。我猜想,一个numpy内部函数可能是你唯一的机会,因为任何在numpy之外的循环几乎都会更慢。你考虑过直接联系开发人员吗? - Thomas Kühn
8个回答

13

在numpy本身实现此功能之前,您可以编写自己的函数并使用numba进行jit编译:

直到此功能在numpy中本地实现之前,您可以编写自己的函数,并使用numba进行jit编译:

import numpy as np
import numba as nb


@nb.jit(nopython=True)
def arrays_equal(a, b):
    if a.shape != b.shape:
        return False
    for ai, bi in zip(a.flat, b.flat):
        if ai != bi:
            return False
    return True


a = np.random.rand(10, 20, 30)
b = np.random.rand(10, 20, 30)


%timeit np.all(a==b)  # 100000 loops, best of 3: 9.82 µs per loop
%timeit arrays_equal(a, a)  # 100000 loops, best of 3: 9.89 µs per loop
%timeit arrays_equal(a, b)  # 100000 loops, best of 3: 691 ns per loop

最坏情况下的性能(数组相等)与np.all等效,在提前停止的情况下,编译的函数有可能比np.all快很多。

最坏情况下的性能(数组相等)相当于np.all,在提前停止的情况下,编译的函数有可能远远超过np.all的性能。


@acdr 当我使用你的数组时,np.all 花费了 1.8 毫秒,而 arrays_equal 只花费了 183 微秒。如果我将 arr1 与自身进行比较,则两者都需要约 1.8 毫秒。也许这种差异是由我们的系统引起的?我使用的是 Python 3.5.2、numpy 1.12.1 和 numba 0.27.0。 - MB-F
可能吧。总的来说,我运行的是比你更老的东西:Python 2.7.10.2,numpy 1.9.1,numba 0.20.0。 - acdr
Np.all 没有分支指令。在数组相同的情况下,你会期望一个没有分支的函数比一个有分支的函数更快。这可能就是差异的来源。你应该看看你的使用情况并决定哪种情况更有可能发生。这仍然是 Python,而不是汇编语言,所以微小的优化并不总是会产生你想要的效果。 - Daniel
@dangom 你的观点可以通过编写上述array_compare的无分支版本轻松验证。我这样做了,发现你是对的 - 在最坏情况下,无分支版本略微更快。然而,np.all的性能与分支版本相当。(这使你关于np.all没有分支指令的说法产生了怀疑。)无论如何,这里的重点不是优化最坏情况,而是可能的情况,这种情况据称可以从短路终止中受益。我怀疑numba 0.20在特定情况下产生的代码比numba 0.27不太优秀。 - MB-F
1
很遗憾,这个函数并不支持。array_equal 函数会在内部调用 np.all(a==b) - MB-F
显示剩余2条评论

1

嗯,我知道这个回答并不好,但似乎没有简单的方法解决这个问题。Numpy的创建者应该修复它。我建议:

def compare(a, b):
    if len(a) > 0 and not np.array_equal(a[0], b[0]):
        return False
    if len(a) > 15 and not np.array_equal(a[:15], b[:15]):
        return False
    if len(a) > 200 and not np.array_equal(a[:200], b[:200]):
        return False
    return np.array_equal(a, b)

:)


1
因为没有人说过不能使用numpy完成这个任务,而且问题仍然存在,我认为 - Śmigło
这个答案已被接受并使用了numpy。 - MechMK1
2
它使用numba。如果您对某人坦率地告知没有更好的方法来做某事感到不满意,您可以标记它,但是我的答案至少包含创造性的解决方案。 - Śmigło

1
向数组比较中添加短路逻辑显然正在numpy页面上的github上讨论,并且可能会在未来的numpy版本中提供。

1
将原始问题分解为三个部分:“(1)我的数组非常大,(2)我有很多这样的数组,(3)两个数组相等的概率很小。”
到目前为止,所有的解决方案都集中在第一部分——优化每个相等检查的性能,有些方案可以将性能提高10倍。第二和第三点被忽略了。比较每对数组的复杂度为O(n^2),对于许多矩阵来说,这可能会变得非常巨大,而重复的概率却非常小。
使用以下通用算法,检查可以变得更快:
- 每个数组的快速哈希O(n) - 仅对具有相同哈希值的数组进行相等性检查
一个好的哈希几乎是唯一的,因此键的数量可以很容易地成为n的一个非常大的分数。平均而言,具有相同哈希值的数组数量非常小,有些情况下几乎为1。重复的数组将具有相同的哈希值,而具有相同的哈希值并不能保证它们是重复的。从这个意义上讲,该算法将捕获所有的重复项。仅比较具有相同哈希值的图像显着减少了比较次数,这几乎变成了O(n)。
对于我的问题,我需要在大约一百万个整数数组内检查重复项,每个数组都有1万个元素。仅优化数组相等检查(使用@MB-F的解决方案)的估计运行时间为5天。通过先进行哈希处理,运行时间缩短到几分钟。(我使用数组和作为哈希值,这适用于我的数组特征)
一些伪Python代码

def fast_hash(arr) -> int:
    pass

def arrays_equal(arr1, arr2) -> bool:
    pass

def make_hash_dict(array_stack, hush_fn=np.sum):

    hash_dict = defaultdict(list)
    hashes = np.squeeze(np.apply_over_axes(hush_fn, array_stack, range(1, array_stack.ndim)))
    for idx, hash_val in enumerate(hashes):
        hash_dict[hash_val].append(idx)

    return hash_dict

def get_duplicate_sets(hash_dict, array_stack):

    duplicate_sets = []
    for hash_key, ind_list in hash_dict.items():
        if len(ind_list) == 1:
            continue

        all_duplicates = []
        for idx1 in range(len(ind_list)):
            v1 = ind_list[idx1]
            if v1 in all_duplicates:
                continue

            arr1 = array_stack[v1]
            curr_duplicates = []
            for idx2 in range(idx1+1, len(ind_list)):
                v2 = ind_list[idx2]
                arr2 = array_stack[v2]
                if arrays_equal(arr1, arr2):
                    if len(curr_duplicates) == 0:
                        curr_duplicates.append(v1)
                    curr_duplicates.append(v2)
            
            if len(curr_duplicates) > 0:
                all_duplicates.extend(curr_duplicates)
                duplicate_sets.append(curr_duplicates)

    return duplicate_sets



变量duplicate_sets是一个列表,其中每个内部列表包含所有相同重复项的索引。

1

好的,我还没有检查它是否会破坏电路,所以这不是一个真正的答案:

assert_array_equal

根据文档:

如果两个 array_like 对象不相等,则引发 AssertionError。

如果不在性能敏感的代码路径上,请尝试使用 Try Except

或者跟随底层源代码,也许它很高效。


1
谢谢您的建议。不幸的是,底层代码似乎只是x == y的包装器,并添加了一些额外的步骤来处理一些边缘情况(如NaNInf)。 - acdr

1
也许了解底层数据结构的人可以优化这个程序,或者解释它是否可靠/安全/良好的实践方法,但它似乎能够正常工作。
np.all(a==b)
Out[]: True

memoryview(a.data)==memoryview(b.data)
Out[]: True

%timeit np.all(a==b)
The slowest run took 10.82 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 6.2 µs per loop

%timeit memoryview(a.data)==memoryview(b.data)
The slowest run took 8.55 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 1.85 µs per loop

如果我理解正确,ndarray.data会创建一个指向数据缓存区的指针,而memoryview则创建了一种本地python类型,可以从缓冲区中短路出来。
我想是这样的。
编辑:进一步测试显示它可能并不像展示的那样对时间有所改善。之前是a=b=np.eye(5)
a=np.random.randint(0,10,(100,100))

b=a.copy()

%timeit np.all(a==b)
The slowest run took 6.70 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 17.7 µs per loop

%timeit memoryview(a.data)==memoryview(b.data)
10000 loops, best of 3: 30.1 µs per loop

np.all(a==b)
Out[]: True

memoryview(a.data)==memoryview(b.data)
Out[]: True

这不仅仅是测试两个数组是否实际上是同一个对象的不同名称,而是具有相同值的两个不同对象吗? - acdr
据我所知,没有。使用上述的.copy()进行测试,然后按相同的方式依次操作上述两个随机数组。 - Daniel F
对我来说不起作用,使用当前Anaconda版本的numpy。也许它只是不喜欢NaNs。 - matanster
@matanster 不确定你尝试了什么,但在标准用法中 NaN != NaN。 - Daniel F

0
正如Thomas Kühn在您的帖子中所评论的那样,array_equal是一个应该解决问题的函数。它在Numpy's API reference中有描述。

0
你可以迭代数组的所有元素并检查它们是否相等。 如果这些数组很可能不相等,那么比使用.all函数更快地返回。 代码示例:
import numpy as np

a = np.array([1, 2, 3])
b = np.array([1, 3, 4])

areEqual = True

for x in range(0, a.size-1):
        if a[x] != b[x]:
                areEqual = False
                break
        else:
               print "a[x] is equal to b[x]\n"

if areEqual:
        print "The tables are equal\n"
else:
        print "The tables are not equal\n"

这实际上就是 all(val1==val2 for val1,val2 in itertools.izip(x, y)) 的作用:它遍历 xy,返回 val1val2 的一对,检查它们是否相同,并将结果传递给 all,一旦找到不相等的一对,它就会返回 False - acdr
哦,我明白了,我以为它会遍历数组的所有元素。 - Anoroah
幸运的是,内置的 all 做了电路断开,不像 np.all。 :) - acdr

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接