在 NumPy 数组的某个轴上计算唯一元素的数量

7
我有一个三维数组,如下所示:
A = np.array([[[1, 1],
                [1, 0]],
[[1, 2], [1, 0]],
[[1, 0], [0, 0]]])
现在我想要获得一个数组,如果一个位置上只有一个非零值(或为零),则该位置上的值为非零值。如果该位置上只有零或多个非零值,则其值为零。对于上面的例子,我想要的是:
[[1, 0],
 [1, 0]]
因为:
  • A[:,0,0]中只有1
  • A[:,0,1]中有012,因此有多个非零值
  • A[:,1,0]中有01,因此保留1
  • A[:,1,1]中只有0
我可以使用np.count_nonzero(A, axis=0)找出有多少个非零元素,但即使有多个12,我也想保留它们。我查看了np.unique,但它似乎不支持我想要做的事情。
理想情况下,我希望有一个类似于np.count_unique(A, axis=0)的函数,它将返回一个原始形状的数组,例如[[1, 3], [2, 1]],这样我就可以检查是否发生了3个或更多值,并忽略该位置。

如果我理解正确的话,np.count_nonzero(A, axis=0) 可以解决你的第一个问题。接下来的部分是从列表中保留唯一的1和2,对吗?为什么不直接使用 np.greater 来实现呢?然后将结果列表附加上去就可以了。 - Adeel Ahmad
2个回答

4
你可以使用np.diff来完成第二个任务,这样可以保持在numpy级别上。
def diffcount(A):
    B=A.copy()
    B.sort(axis=0)
    C=np.diff(B,axis=0)>0
    D=C.sum(axis=0)+1
    return D

# [[1 3]
#  [2 1]]

在大数组上,它似乎会更快一些:

In [62]: A=np.random.randint(0,100,(100,100,100))

In [63]: %timeit diffcount(A)
46.8 ms ± 769 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

In [64]: timeit [[len(np.unique(A[:, i, j])) for j in range(A.shape[2])]\
for i in range(A.shape[1])]
149 ms ± 700 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

最后,计算唯一值比排序更简单,可以节省 ln(A.shape[0]) 的时间。

一个节省时间的方法是使用 set 机制:

In [81]: %timeit np.apply_along_axis(lambda a:len(set(a)),axis=0,A) 
183 ms ± 1.17 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

很遗憾,这并不更快。

另一种方法是手动执行:
def countunique(A,Amax):
    res=np.empty(A.shape[1:],A.dtype)
    c=np.empty(Amax+1,A.dtype)
    for i in range(A.shape[1]):
        for j in range(A.shape[2]):
            T=A[:,i,j]
            for k in range(c.size): c[k]=0 
            for x in T:
                c[x]=1
            res[i,j]= c.sum()
    return res 

在Python层面上:

In [70]: %timeit countunique(A,100)
429 ms ± 18.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

对于纯 Python 方法来说,这并不算太糟糕。然后,只需使用 Numba 将此代码转换为低级代码:

import numba    
countunique2=numba.jit(countunique)  

In [71]: %timeit countunique2(A,100)
3.63 ms ± 70.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

这将会很难大幅改善。


(B[1:] != B[:-1]).sum(0)+1 可能更快,特别适用于具有大型第一轴长度的数组。 - Divakar
{btsdaf} - B. M.
{btsdaf} - Divakar
{btsdaf} - Eric Duminil
{btsdaf} - Raketenolli
显示剩余3条评论

2

一种方法是使用A作为第一个轴索引,设置一个布尔数组,其长度沿着其他两个轴相同,然后仅计算它的第一个轴上的非零数。有两种变体可用 - 一种保持其为3D,另一种是为了获得一些性能优势而重新调整为2D,因为索引到2D会更快。因此,这两种实现方式分别为 -

def nunique_axis0_maskcount_app1(A):
    m,n = A.shape[1:]
    mask = np.zeros((A.max()+1,m,n),dtype=bool)
    mask[A,np.arange(m)[:,None],np.arange(n)] = 1
    return mask.sum(0)

def nunique_axis0_maskcount_app2(A):
    m,n = A.shape[1:]
    A.shape = (-1,m*n)
    maxn = A.max()+1
    N = A.shape[1]
    mask = np.zeros((maxn,N),dtype=bool)
    mask[A,np.arange(N)] = 1
    A.shape = (-1,m,n)
    return mask.sum(0).reshape(m,n)

运行时测试 -

In [154]: A = np.random.randint(0,100,(100,100,100))

# @B. M.'s soln
In [155]: %timeit f(A)
10 loops, best of 3: 28.3 ms per loop

# @B. M.'s soln using slicing : (B[1:] != B[:-1]).sum(0)+1
In [156]: %timeit f2(A)
10 loops, best of 3: 26.2 ms per loop

In [157]: %timeit nunique_axis0_maskcount_app1(A)
100 loops, best of 3: 12 ms per loop

In [158]: %timeit nunique_axis0_maskcount_app2(A)
100 loops, best of 3: 9.14 ms per loop

Numba方法

使用与nunique_axis0_maskcount_app2相同的策略,直接通过numba在C级别获取计数,我们将得到-

from numba import njit

@njit
def nunique_loopy_func(mask, N, A, p, count):
    for j in range(N):
        mask[:] = True
        mask[A[0,j]] = False
        c = 1
        for i in range(1,p):
            if mask[A[i,j]]:
                c += 1
            mask[A[i,j]] = False
        count[j] = c
    return count

def nunique_axis0_numba(A):
    p,m,n = A.shape
    A.shape = (-1,m*n)
    maxn = A.max()+1
    N = A.shape[1]
    mask = np.empty(maxn,dtype=bool)
    count = np.empty(N,dtype=int)
    out = nunique_loopy_func(mask, N, A, p, count).reshape(m,n)
    A.shape = (-1,m,n)
    return out

运行时测试 -

In [328]: np.random.seed(0)

In [329]: A = np.random.randint(0,100,(100,100,100))

In [330]: %timeit nunique_axis0_maskcount_app2(A)
100 loops, best of 3: 11.1 ms per loop

# @B.M.'s numba soln
In [331]: %timeit countunique2(A,A.max()+1)
100 loops, best of 3: 3.43 ms per loop

# Numba soln posted in this post
In [332]: %timeit nunique_axis0_numba(A)
100 loops, best of 3: 2.76 ms per loop

nunique_axis0_numbacountunique2 在我的电脑上花费的时间完全相同。 - B. M.
{btsdaf} - Divakar

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接