我遇到了相同的问题,并提出了另一种解决方案。您可以使用结构化数据类型将多列数据视为单个条目。结构化数据类型将允许在数据上使用argsort/sort(而不是lexsort,尽管在此阶段lexsort似乎更快),然后使用标准的searchsorted。这是一个例子:
import numpy as np
from itertools import repeat
a = np.array([1,1,1,2,2,3,5,6,6])
b = np.array([10,20,30,5,10,100,10,30,40])
data = np.transpose([a,b])
data = data[np.lexsort(data.T[::-1])]
dt = np.dtype(zip(repeat(''), repeat(data.dtype, data.shape[1])))
data = np.ascontiguousarray(data).view(dt).squeeze(-1)
values = np.array([(2,7),(5,150)], dtype=dt)
pos = np.searchsorted(data, values)
这适用于任意数量的列,使用内置的numpy函数,列保持“逻辑”顺序(优先级降低),且速度应该很快。
我比较了这两种基于numpy的方法的时间。
#1 是来自@j0ker5的递归方法(下面的示例扩展了他的递归建议,并适用于任何数量的lexsorted行)
#2 是来自我的结构化数组
它们都采用相同的输入方式,基本上类似于searchsorted,除了a和v是根据lexsort而定的。
import numpy as np
def lexsearch1(a, v, side='left', sorter=None):
def _recurse(a, v):
if a.shape[1] == 0: return 0
if a.shape[0] == 1: return a.squeeze(0).searchsorted(v.squeeze(0), side)
bl = np.searchsorted(a[-1,:], v[-1], side='left')
br = np.searchsorted(a[-1,:], v[-1], side='right')
return bl + _recurse(a[:-1,bl:br], v[:-1])
a,v = np.asarray(a), np.asarray(v)
if v.ndim == 1: v = v[:,np.newaxis]
assert a.ndim == 2 and v.ndim == 2 and a.shape[0] == v.shape[0] and a.shape[0] > 1
if sorter is not None: a = a[:,sorter]
bl = np.searchsorted(a[-1,:], v[-1,:], side='left')
br = np.searchsorted(a[-1,:], v[-1,:], side='right')
for i in xrange(len(bl)): bl[i] += _recurse(a[:-1,bl[i]:br[i]], v[:-1,i])
return bl
def lexsearch2(a, v, side='left', sorter=None):
from itertools import repeat
a,v = np.asarray(a), np.asarray(v)
if v.ndim == 1: v = v[:,np.newaxis]
assert a.ndim == 2 and v.ndim == 2 and a.shape[0] == v.shape[0] and a.shape[0] > 1
a_dt = np.dtype(zip(repeat(''), repeat(a.dtype, a.shape[0])))
v_dt = np.dtype(zip(a_dt.names, repeat(v.dtype, a.shape[0])))
a = np.asfortranarray(a[::-1,:]).view(a_dt).squeeze(0)
v = np.asfortranarray(v[::-1,:]).view(v_dt).squeeze(0)
return a.searchsorted(v, side, sorter).ravel()
a = np.random.randint(100, size=(2,10000))
v = np.random.randint(100, size=(2,10000))
sorted_idx = np.lexsort(a)
a_sorted = a[:,sorted_idx]
而且时序结果(在iPython中):
# 2 rows
%timeit lexsearch1(a_sorted, v)
10 loops, best of 3: 33.4 ms per loop
%timeit lexsearch2(a_sorted, v)
100 loops, best of 3: 14 ms per loop
# 10 rows
%timeit lexsearch1(a_sorted, v)
10 loops, best of 3: 103 ms per loop
%timeit lexsearch2(a_sorted, v)
100 loops, best of 3: 14.7 ms per loop
整体来说,结构化数组方法更快,如果您设计它与
a
和
v
的翻转和转置版本一起使用,它甚至可以变得更快。当行数/键数增加时,速度会更快,从2行到10行几乎不会变慢。
我没有注意到使用
a_sorted
或a以及
sorter=sorted_idx
之间有任何显着的时间差异,因此我将它们省略以保持清晰明了。
我相信使用Cython可以制作出一个真正快速的方法,但这已经是使用纯Python和numpy的最快方法了。