从NumPy 2D数组的子数组中提取相交数组的索引

3
我有两个二维numpy正方形数组A和B。B是从A中提取出来的数组,其中某些列和行(具有相同的索引)已被剥离。它们都是对称的。例如,A和B可能是:
A = np.array([[1,2,3,4,5],
              [2,7,8,9,10],
              [3,8,13,14,15],
              [4,9,14,19,20],
              [5,10,15,20,25]])
B = np.array([[1,3,5],
              [3,13,15],
              [5,15,25]])

使缺失的索引为[1,3],交叉的索引为[0,2,4]。

有没有一种“智能”方法来提取A中与B中存在的行/列相应的索引,涉及高级索引等?我能想到的是:

        import numpy as np
        index = np.array([],dtype=int)
        n,m = len(A),len(B)
        for j in range(n):
            k = 0
            while set(np.intersect1d(B[j],A[k])) != set(B[j]) and k<m:
                k+=1
            np.append(index,k)

我知道当处理大型数组时,这种方法速度较慢且资源消耗较大。

谢谢!

编辑: 我找到了一种更聪明的方法。我从两个数组中提取对角线,并使用简单的相等检查在其上执行上述循环:

        index = []
        a = np.diag(A)
        b = np.diag(B)
        for j in range(len(b)):
            k = 0
            while a[j+k] != b[j] and k<n:
                k+=1
            index.append(k+j)

尽管它仍未使用高级索引,并且仍然迭代可能很长的列表,但这个部分解决方案看起来更清晰,我暂时会坚持使用它。

这应该类似于Matlab函数ismember。请查看此SO - Bort
@unutbu 是的,A 中可能会有重复的值,答案应该是唯一的。两个矩阵都应足够大,以便所有列或行向量都不同。 - skcidereves
A可以从500x500到2000x2000的范围内,很少会更大。B略小一些(不应剥夺许多行/列)。 - skcidereves
关于编辑: 我认为当对角线包含重复值时,新方法可能无法产生正确的结果。以A为单位矩阵为极端例子。 - unutbu
你说得对,但是只有在重复值沿对角线连续出现时才会出现错误,在实际操作中这种情况几乎从未发生过。不过我还是会将问题留给开放的状态以防万一。 - skcidereves
显示剩余3条评论
1个回答

2
考虑所有值都是不同的简单情况:
A = np.arange(25).reshape(5,5)
ans = [1,3,4]
B = A[np.ix_(ans, ans)]

In [287]: A
Out[287]: 
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24]])

In [288]: B
Out[288]: 
array([[ 6,  8,  9],
       [16, 18, 19],
       [21, 23, 24]])

如果我们将B的第一行与A的每一行进行比较,最终会得到[6, 8, 9][5, 6, 7, 8, 9]之间的比较,从中可以得出索引的候选解[1, 3, 4]
我们可以通过将B的第一行与A的每一行配对来生成所有可能的候选解集合。
如果只有一个候选解,则我们已经完成了,因为我们知道B是A的子矩阵,因此一定存在解决方案。
如果有多个候选解,则我们可以用B的第二行做同样的事情,并取候选解的交集。毕竟,一个解决方案必须是B的每一行的解决方案。
因此,我们可以循环遍历B的行,并在发现只有一个候选解时进行短路处理。再次强调,我们假设B始终是A的子矩阵。
下面的find_idx函数实现了上述思想:
import itertools as IT
import numpy as np

def find_idx_1d(rowA, rowB):
    result = []
    if np.in1d(rowB, rowA).all():
        result = [tuple(sorted(idx)) 
                  for idx in IT.product(*[np.where(rowA==b)[0] for b in rowB])]
    return result

def find_idx(A, B):
    candidates = set([idx for row in A for idx in find_idx_1d(row, B[0])])
    for Bi in B[1:]:
        if len(candidates) == 1:
            # stop when there is a unique candidate
            return candidates.pop()
        new = [idx for row in A for idx in find_idx_1d(row, Bi)]  
        candidates = candidates.intersection(new)
    if candidates:
        return candidates.pop()
    raise ValueError('no solution found')

正确性:您提出的两种解决方案在存在重复值时可能无法始终返回正确结果。例如,

def is_solution(A, B, idx):
    return np.allclose(A[np.ix_(idx, idx)], B)

def find_idx_orig(A, B):
    index = []
    for j in range(len(B)):
        k = 0
        while k<len(A) and set(np.intersect1d(B[j],A[k])) != set(B[j]):
            k+=1
        index.append(k)
    return index

def find_idx_diag(A, B):
    index = []
    a = np.diag(A)
    b = np.diag(B)
    for j in range(len(b)):
        k = 0
        while a[j+k] != b[j] and k<len(A):
            k+=1
        index.append(k+j)
    return index

def counterexample():
    """
    Show find_idx_diag, find_idx_orig may not return the correct result
    """
    A = np.array([[1,2,0],
                  [2,1,0],
                  [0,0,1]])
    ans = [0,1]
    B = A[np.ix_(ans, ans)]
    assert not is_solution(A, B, find_idx_orig(A, B))
    assert is_solution(A, B, find_idx(A, B))

    A = np.array([[1,2,0],
                  [2,1,0],
                  [0,0,1]])
    ans = [1,2]
    B = A[np.ix_(ans, ans)]

    assert not is_solution(A, B, find_idx_diag(A, B))
    assert is_solution(A, B, find_idx(A, B))

counterexample()

基准测试:忽略正确性问题,出于好奇,让我们比较这些函数的速度。

def make_AB(n, m):
    A = symmetrize(np.random.random((n, n)))
    ans = np.sort(np.random.choice(n, m, replace=False))
    B = A[np.ix_(ans, ans)]
    return A, B

def symmetrize(a):
    "https://dev59.com/xHE85IYBdhLWcg3w3Xhp#2573982 (EOL)"
    return a + a.T - np.diag(a.diagonal())

if __name__ == '__main__':
    counterexample()
    A, B = make_AB(500, 450)
    assert is_solution(A, B, find_idx(A, B))

In [283]: %timeit find_idx(A, B)
10 loops, best of 3: 74 ms per loop

In [284]: %timeit find_idx_orig(A, B)
1 loops, best of 3: 14.5 s per loop

In [285]: %timeit find_idx_diag(A, B)
100 loops, best of 3: 2.93 ms per loop

因此,find_idxfind_idx_orig快得多,但不如find_idx_diag快。


感谢您提供详细的答案!您的解决方案仅比我的慢一点,但始终是正确的,所以我会切换到它。 - skcidereves

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接