在numpy数组中查找相同的行和列

4
我有一个nxn的布尔数组,我想检查是否有任何一行与另一行相同。如果有任何相同的行,我想检查相应的列是否也相同。
以下是一个示例:
A=np.array([[0, 1, 0, 0, 0, 1],
            [0, 0, 0, 1, 0, 1],
            [0, 1, 0, 0, 0, 1],
            [1, 0, 1, 0, 1, 1],
            [1, 1, 1, 0, 0, 0],
            [0, 1, 0, 1, 0, 1]])

我希望程序能够找到第一行和第三行相同,并检查第一列和第三列是否也相同;在这种情况下,它们是相同的。

你是否重视性能? - wim
不要太多,因为数组很小。 - cgog
4个回答

4
你可以使用np.array_equal()函数:
for i in range(len(A)):  # generate pairs
    for j in range(i + 1, len(A)): 
        if np.array_equal(A[i], A[j]):  # compare rows
            if np.array_equal(A[:,i], A[:,j]):  # compare columns
                print(i, j)
        else:
            pass

或者使用combinations()

import itertools

for pair in itertools.combinations(range(len(A)), 2):
    if np.array_equal(A[pair[0]], A[pair[1]]) and np.array_equal(A[:,pair[0]], A[:,pair[1]]):  # compare columns
        print(pair)

2

以下是应用np.unique于2D数组并返回唯一对的典型方法:

def unique_pairs(arr):
    uview = np.ascontiguousarray(arr).view(np.dtype((np.void, arr.dtype.itemsize * arr.shape[1])))
    uvals, uidx = np.unique(uview, return_inverse=True)
    pos = np.where(np.bincount(uidx) == 2)[0]

    pairs = []
    for p in pos:
        pairs.append(np.where(uidx==p)[0])

    return np.array(pairs)

我们可以执行以下操作:
row_pairs = unique_pairs(A)
col_pairs = unique_pairs(A.T)

for pair in row_pairs:
    if np.any(np.all(pair==col_pairs, axis=1)):
        print pair

>>> [0 2]

当然,还有很多优化需要进行,但主要的是使用np.unique。与其他方法相比,该方法的效率在多大程度上取决于您如何定义“小”数组。

1

既然你说性能不是关键,这里提供一种不太符合Python风格的暴力解决方案:

>>> n = len(A)
>>> for i1, row1 in enumerate(A):
...     offset = i1 + 1  # skip rows already compared 
...     for i2, row2 in enumerate(A[offset:], start=offset):
...         if (row1 == row2).all() and (A.T[i1] == A.T[i2]).all():
...             print i1, i2
...             
0 2

这可能是O(n^2)。我使用转置数组A.T来检查列的相等性。

1
对于小数组,可以通过NumPy广播的方式来避免依赖Python循环的另一种方法。
bool_array = np.logical_not(np.logical_xor(A[:,np.newaxis,:], A[np.newaxis,:,:])) # XNOR for comparison
matches_array = np.sum(bool_array, axis=2)  # count total matches for all elements in a row
row1, row2 = np.where(matches_array == A.shape[1]) # identical row = all elements in a row match
row1, row2 = row1[row2 > row1], row2[row2 > row1]  # filter self & duplicated comparisons
column_match = np.all(A[:,row1] == A[:,row2], axis=0)  # check if the corresponding columns are identical
for r1, r2, c in zip(row1, row2, column_match):
    print("Row %d and row %d : Column identical: %s" % (r1, r2, c))

如前所述,当A变得很大时,这种方法将不起作用,因为在计算过程中需要O(n^3)的存储空间(由于bool_array)。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接