跨数组索引获取最小值

4

我有一个 n×3 的索引数组(将三角形与点相关联),以及一个与这些三角形相关联的浮点数值列表。现在,我想为每个索引 ("点") 获取最小值,即检查包含该索引(例如0)的所有行,并从相应的行中获取 vals 的最小值:

import numpy

a = numpy.array([
    [0, 1, 2],
    [2, 3, 0],
    [1, 4, 2],
    [2, 5, 3],
])
vals = numpy.array([0.1, 0.5, 0.3, 0.6])

out = [
    numpy.min(vals[numpy.any(a == i, axis=1)])
    for i in range(6)
]
# out = numpy.array([0.1, 0.1, 0.1, 0.5, 0.3, 0.6])

这种解决方法效率低下,因为它需要对每个i进行完整的数组比较。

这个问题与numpy的ufuncs非常相似,但是numpy.min.at不存在。

有什么提示吗?


你能解释一下如何得到输出中的0.1、0.1、0.1吗?有可重现的代码吗? - LazyCoder
@SANTOSHKUMARDESAI 已完成。 - Nico Schlömer
有趣的优化问题。我认为你的解决方案已经达到了最好的水平。 - LazyCoder
如果有一些ID丢失了,比如a [0,1]a [2,0]都是0,那怎么办呢?所以,在a中我们没有1。这种情况可能发生吗? - Divakar
4个回答

2

方法 #1

一种基于数组分配的方法,用来设置一个填充了 NaNs2D 数组,使用那些 a 值作为列索引(因此假设它们是整数),然后将 vals 映射到其中,并查找 nan 跳过的最小值作为最终输出 -

最初的回答:

首先,我们可以使用numpy库中的nanmin函数来计算跳过nan值的最小值。接下来,我们可以创建一个填充有NaN值的2D数组,然后将vals映射到其中。最后,我们可以使用a值作为列索引,查找每行中跳过nan值的最小值。

nr,nc = len(a),a.max()+1
m = np.full((nr,nc),np.nan)
m[np.arange(nr)[:,None],a] = vals[:,None]
out = np.nanmin(m,axis=0)

方法二

这种方法也基于数组赋值,但使用掩码和np.minimum.reduceat来处理NaNs -

nr,nc = len(a),a.max()+1
m = np.zeros((nc,nr),dtype=bool)
m[a.T,np.arange(nr)] = 1
c = m.sum(1)
shift_idx = np.r_[0,c[:-1].cumsum()]
out = np.minimum.reduceat(np.broadcast_to(vals,m.shape)[m],shift_idx)

第三种方法

另一种基于argsort的方法(假设在a中,您拥有从0a.max()的所有整数)-

原始答案:Approach #3

sidx = a.ravel().argsort()
c = np.bincount(a.ravel())
out = np.minimum.reduceat(vals[sidx//a.shape[1]],np.r_[0,c[:-1].cumsum()])

第四种方法

为了提高内存效率和性能,同时也为了完成“set”-

最初的回答

from numba import njit

@njit
def numba1(a, vals, out):
    m,n = a.shape
    for j in range(m):
        for i in range(n):
            e = a[j,i]
            if vals[j] < out[e]:
                out[e] = vals[j]
    return out

def func1(a, vals, outlen=None): # feed in output length as outlen if known
    if outlen is not None:
        N = outlen
    else:
        N = a.max()+1
    out = np.full(N,np.inf)
    return numba1(a, vals, out)

1
你可以在for循环超过6个时,使用pd.GroupByitertools.groupby进行转换。
例如:
r = n.ravel()
pd.Series(np.arange(len(r))//3).groupby(r).apply(lambda s: vals[s].min())

对于较长的循环,这个解决方案会更快,而对于较小的循环(< 50),可能会慢一些。


0

显然,numpy.minimum.at 存在:

import numpy

a = numpy.array([
    [0, 1, 2],
    [2, 3, 0],
    [1, 4, 2],
    [2, 5, 3],
])
vals = numpy.array([0.1, 0.5, 0.3, 0.6])


out = numpy.full(6, numpy.inf)
numpy.minimum.at(out, a.reshape(-1), numpy.repeat(vals, 3))

0
这是一个基于这个问答的例子:
如果你有pythran,编译
文件<stb_pthr.py>
import numpy as np

#pythran export sort_to_bins(int[:], int)

def sort_to_bins(idx, mx):
    if mx==-1:
        mx = idx.max() + 1
    cnts = np.zeros(mx + 2, int)
    for i in range(idx.size):
        cnts[idx[i]+2] += 1
    for i in range(2, cnts.size):
        cnts[i] += cnts[i-1]
    res = np.empty_like(idx)
    for i in range(idx.size):
        res[cnts[idx[i]+1]] = i
        cnts[idx[i]+1] += 1
    return res, cnts[:-1]

否则脚本将退回到基于稀疏矩阵的方法,该方法仅略慢:
import numpy as np
try:
    from stb_pthr import sort_to_bins
    HAVE_PYTHRAN = True
except:
    HAVE_PYTHRAN = False

from scipy.sparse import csr_matrix

def sort_to_bins_sparse(idx, mx):
    if mx==-1:
        mx = idx.max() + 1
    aux = csr_matrix((np.ones_like(idx),idx,np.arange(idx.size+1)),
                     (idx.size,mx)).tocsc()
    return aux.indices, aux.indptr

if not HAVE_PYTHRAN:
    sort_to_bins = sort_to_bins_sparse

def f_op():
    mx = a.max() + 1
    return np.fromiter((np.min(vals[np.any(a == i, axis=1)])
                        for i in range(mx)),vals.dtype,mx)

def f_pp():
    idx, bb = sort_to_bins(a.reshape(-1),-1)
    res = np.minimum.reduceat(vals[idx//3], bb[:-1])
    res[bb[:-1]==bb[1:]] = np.inf
    return res

def f_div_3():
    sidx = a.ravel().argsort()
    c = np.bincount(a.ravel())
    bb = np.r_[0,c.cumsum()]
    res = np.minimum.reduceat(vals[sidx//a.shape[1]],bb[:-1])
    res[bb[:-1]==bb[1:]] = np.inf
    return res

a = np.array([
    [0, 1, 2],
    [2, 3, 0],
    [1, 4, 2],
    [2, 5, 3],
])
vals = np.array([0.1, 0.5, 0.3, 0.6])

assert np.all(f_op()==f_pp())

from timeit import timeit

a = np.random.randint(0,1000,(10000,3))
vals = np.random.random(10000)
assert len(np.unique(a))==1000

assert np.all(f_op()==f_pp())
print("1000/1000 labels, 10000 rows")
print("op ", timeit(f_op, number=10)*100, 'ms')
print("pp ", timeit(f_pp, number=100)*10, 'ms')
print("div", timeit(f_div_3, number=100)*10, 'ms')

a = 1 + 2 * np.random.randint(0,5000,(1000000,3))
vals = np.random.random(1000000)
nl = len(np.unique(a))

assert np.all(f_div_3()==f_pp())
print(f"{nl}/{a.max()+1} labels, 1000000 rows")
print("pp ", timeit(f_pp, number=10)*100, 'ms')
print("div", timeit(f_div_3, number=10)*100, 'ms')

a = 1 + 2 * np.random.randint(0,100000,(1000000,3))
vals = np.random.random(1000000)
nl = len(np.unique(a))

assert np.all(f_div_3()==f_pp())
print(f"{nl}/{a.max()+1} labels, 1000000 rows")
print("pp ", timeit(f_pp, number=10)*100, 'ms')
print("div", timeit(f_div_3, number=10)*100, 'ms')

样例运行(计时包括 @Divakar 方法3 以供参考):

1000/1000 labels, 10000 rows
op  145.1122640981339 ms
pp  0.7944229000713676 ms
div 2.2905819199513644 ms
5000/10000 labels, 1000000 rows
pp  113.86540920939296 ms
div 417.2476712032221 ms
100000/200000 labels, 1000000 rows
pp  158.23634970001876 ms
div 486.13436080049723 ms

更新:@Divakar的最新方法(第4种方法)很难被超越,本质上是一种C实现。除了即时编译不是一个选择而是一个要求之外,没有任何问题(未经即时编译的代码运行起来并不好玩)。如果接受这一点,当然也可以用pythran实现:

pythran -O3 labeled_min.py

文件<labeled_min.py>

import numpy as np

#pythran export labeled_min(int[:,:], float[:])

def labeled_min(A, vals):
    mn = np.empty(A.max()+1)
    mn[:] = np.inf
    M,N = A.shape
    for i in range(M):
        v = vals[i]
        for j in range(N):
            c = A[i,j]
            if v < mn[c]:
                mn[c] = v
    return mn

两者都可以带来巨大的加速:

from labeled_min import labeled_min

func1() # do not measure jitting time    
print("nmb ", timeit(func1, number=100)*10, 'ms')
print("pthr", timeit(lambda:labeled_min(a,vals), number=100)*10, 'ms')

示例运行:

nmb  8.41792532010004 ms
pthr 8.104007659712806 ms

pythran稍微快了几个百分点,但这只是因为我将vals的查找移出了内部循环;如果没有这个操作,它们几乎相等。

作为比较,在同一问题上,之前最好的带和不带非Python助手的情况:

pp           114.04887529788539 ms
pp (py only) 147.0821460010484 ms

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接