我已经创建了几个更高级的策略。
还实现了一种简单的策略,使用像
另一个我的回答 中的元组。
所有解决方案的时间都被测量。
大多数策略都使用
np.searchsorted
作为底层引擎。为了实现这些高级策略,使用了一个特殊的包装类
_CmpIx
,以提供自定义比较函数 (
__lt__
) 用于调用
np.searchsorted
。
py.tuples
策略将所有列转换为元组并将它们存储为np.object_
dtype的numpy 1D数组,然后进行常规搜索排序。
py.zip
使用python的zip来懒惰地执行相同的任务。
np.lexsort
策略只是使用np.lexsort
来按字典顺序比较两列。
np.nonzero
使用np.flatnonzero(a != b)
表达式。
cmp_numba
在_CmpIx
包装器中使用预先编译的numba代码,用于快速按字典顺序懒惰地比较两个提供的元素。
np.searchsorted
使用标准的numpy函数,但仅针对1D情况进行测量。
- 对于
numba
策略,整个搜索算法都是使用Numba引擎从头实现的,该算法基于二分搜索。这个算法有_py
和_nm
两种变体,_nm
更快,因为它使用了Numba编译器,而_py
是相同的算法但未编译。还有_sorted
变体,它对要插入的数组进行了额外的优化,如果已经排序,则更快。
view1d
- 这个方法是由@MadPhysicist在这个答案中提出的。在代码中注释掉它们,因为它们对于所有关键字长度>1的大多数测试返回不正确的答案,可能是由于原始查看数组的某些问题。
在线尝试!
class SearchSorted2D:
class _CmpIx:
def __init__(self, t, p, i):
self.p, self.i = p, i
self.leg = self.leg_cache()[t]
self.lt = lambda o: self.leg(self, o, False) if self.i != o.i else False
self.le = lambda o: self.leg(self, o, True) if self.i != o.i else True
@classmethod
def leg_cache(cls):
if not hasattr(cls, 'leg_cache_data'):
cls.leg_cache_data = {
'py.zip': cls._leg_py_zip, 'np.lexsort': cls._leg_np_lexsort,
'np.nonzero': cls._leg_np_nonzero, 'cmp_numba': cls._leg_numba_create(),
}
return cls.leg_cache_data
def __eq__(self, o): return not self.lt(o) and self.le(o)
def __ne__(self, o): return self.lt(o) or not self.le(o)
def __lt__(self, o): return self.lt(o)
def __le__(self, o): return self.le(o)
def __gt__(self, o): return not self.le(o)
def __ge__(self, o): return not self.lt(o)
@staticmethod
def _leg_np_lexsort(self, o, eq):
import numpy as np
ia, ib = (self.i, o.i) if eq else (o.i, self.i)
return (np.lexsort(self.p.ab[::-1, ia : (ib + (-1, 1)[ib >= ia], None)[ib == 0] : ib - ia])[0] == 0) == eq
@staticmethod
def _leg_py_zip(self, o, eq):
for l, r in zip(self.p.ab[:, self.i], self.p.ab[:, o.i]):
if l < r:
return True
if l > r:
return False
return eq
@staticmethod
def _leg_np_nonzero(self, o, eq):
import numpy as np
a, b = self.p.ab[:, self.i], self.p.ab[:, o.i]
ix = np.flatnonzero(a != b)
return a[ix[0]] < b[ix[0]] if ix.size != 0 else eq
@staticmethod
def _leg_numba_create():
import numpy as np
try:
from numba.pycc import CC
cc = CC('ss_numba_mod')
@cc.export('ss_numba_i8', 'b1(i8[:],i8[:],b1)')
def ss_numba(a, b, eq):
for i in range(a.size):
if a[i] < b[i]:
return True
elif b[i] < a[i]:
return False
return eq
cc.compile()
success = True
except:
success = False
if success:
try:
import ss_numba_mod
except:
success = False
def odo(self, o, eq):
a, b = self.p.ab[:, self.i], self.p.ab[:, o.i]
assert a.ndim == 1 and a.shape == b.shape, (a.shape, b.shape)
return ss_numba_mod.ss_numba_i8(a, b, eq)
return odo if success else None
def __init__(self, type_):
import numpy as np
self.type_ = type_
self.ci = np.array([], dtype = np.object_)
def __call__(self, a, b, *pargs, **nargs):
import numpy as np
self.ab = np.concatenate((a, b), axis = 1)
self._grow(self.ab.shape[1])
ix = np.searchsorted(self.ci[:a.shape[1]], self.ci[a.shape[1] : a.shape[1] + b.shape[1]], *pargs, **nargs)
return ix
def _grow(self, to):
import numpy as np
if self.ci.size >= to:
return
import math
to = 1 << math.ceil(math.log(to) / math.log(2))
self.ci = np.concatenate((self.ci, [self._CmpIx(self.type_, self, i) for i in range(self.ci.size, to)]))
class SearchSorted2DNumba:
@classmethod
def do(cls, a, v, side = 'left', *, vsorted = False, numba_ = True):
import numpy as np
if not hasattr(cls, '_ido_numba'):
def _ido_regular(a, b, vsorted, lrt):
nk, na, nb = a.shape[0], a.shape[1], b.shape[1]
res = np.zeros((2, nb), dtype = np.int64)
max_depth = 0
if nb == 0:
return res, max_depth
lrb, lre = 0, 0
if vsorted:
brngs = np.zeros((nb, 6), dtype = np.int64)
brngs[0, :4] = (-1, 0, nb >> 1, nb)
i, j, size = 0, 1, 1
while i < j:
for k in range(i, j):
cbrng = brngs[k]
bp, bb, bm, be = cbrng[:4]
if bb < bm:
brngs[size, :4] = (k, bb, (bb + bm) >> 1, bm)
size += 1
bmp1 = bm + 1
if bmp1 < be:
brngs[size, :4] = (k, bmp1, (bmp1 + be) >> 1, be)
size += 1
i, j = j, size
assert size == nb
brngs[:, 4:] = -1
for ibc in range(nb):
if not vsorted:
ib, lrb, lre = ibc, 0, na
else:
ibpi, ib = int(brngs[ibc, 0]), int(brngs[ibc, 2])
if ibpi == -1:
lrb, lre = 0, na
else:
ibp = int(brngs[ibpi, 2])
if ib < ibp:
lrb, lre = int(brngs[ibpi, 4]), int(res[1, ibp])
else:
lrb, lre = int(res[0, ibp]), int(brngs[ibpi, 5])
brngs[ibc, 4 : 6] = (lrb, lre)
assert lrb != -1 and lre != -1
for ik in range(nk):
if lrb >= lre:
if ik > max_depth:
max_depth = ik
break
bv = b[ik, ib]
if nk != 1 or lrt == 2:
cb, ce = lrb, lre
while cb < ce:
cm = (cb + ce) >> 1
av = a[ik, cm]
if av < bv:
cb = cm + 1
elif bv < av:
ce = cm
else:
break
lrb, lre = cb, ce
if nk != 1 or lrt >= 1:
cb, ce = lrb, lre
while cb < ce:
cm = (cb + ce) >> 1
if not (bv < a[ik, cm]):
cb = cm + 1
else:
ce = cm
lre = ce
if nk != 1 or lrt == 0 or lrt == 2:
cb, ce = lrb, lre
while cb < ce:
cm = (cb + ce) >> 1
if a[ik, cm] < bv:
cb = cm + 1
else:
ce = cm
lrb = cb
res[:, ib] = (lrb, lre)
return res, max_depth
cls._ido_regular = _ido_regular
import numba
cls._ido_numba = numba.jit(nopython = True, nogil = True, cache = True)(cls._ido_regular)
assert side in ['left', 'right', 'left_right'], side
a, v = np.array(a), np.array(v)
assert a.ndim == 2 and v.ndim == 2 and a.shape[0] == v.shape[0], (a.shape, v.shape)
res, max_depth = (cls._ido_numba if numba_ else cls._ido_regular)(
a, v, vsorted, {'left': 0, 'right': 1, 'left_right': 2}[side],
)
return res[0] if side == 'left' else res[1] if side == 'right' else res
def Test():
import time
import numpy as np
np.random.seed(0)
def round_float_fixed_str(x, n = 0):
if type(x) is int:
return str(x)
s = str(round(float(x), n))
if n > 0:
s += '0' * (n - (len(s) - 1 - s.rfind('.')))
return s
def to_tuples(x):
r = np.empty([x.shape[1]], dtype = np.object_)
r[:] = [tuple(e) for e in x.T]
return r
searchsorted2d = {
'py.zip': SearchSorted2D('py.zip'),
'np.nonzero': SearchSorted2D('np.nonzero'),
'np.lexsort': SearchSorted2D('np.lexsort'),
'cmp_numba': SearchSorted2D('cmp_numba'),
}
for iklen, klen in enumerate([1, 1, 2, 5, 10, 20, 50, 100, 200]):
times = {}
for side in ['left', 'right']:
a = np.zeros((klen, 0), dtype = np.int64)
tac = to_tuples(a)
for itest in range((15, 100)[iklen == 0]):
b = np.random.randint(0, (3, 100000)[iklen == 0], (klen, np.random.randint(1, (1000, 2000)[iklen == 0])), dtype = np.int64)
b = b[:, np.lexsort(b[::-1])]
if iklen == 0:
assert klen == 1, klen
ts = time.time()
ix1 = np.searchsorted(a[0], b[0], side = side)
te = time.time()
times['np.searchsorted'] = times.get('np.searchsorted', 0.) + te - ts
for cached in [False, True]:
ts = time.time()
tb = to_tuples(b)
ta = tac if cached else to_tuples(a)
ix1 = np.searchsorted(ta, tb, side = side)
if not cached:
ix0 = ix1
tac = np.insert(tac, ix0, tb) if cached else tac
te = time.time()
timesk = f'py.tuples{("", "_cached")[cached]}'
times[timesk] = times.get(timesk, 0.) + te - ts
for type_ in searchsorted2d.keys():
if iklen == 0 and type_ in ['np.nonzero', 'np.lexsort']:
continue
ss = searchsorted2d[type_]
try:
ts = time.time()
ix1 = ss(a, b, side = side)
te = time.time()
times[type_] = times.get(type_, 0.) + te - ts
assert np.array_equal(ix0, ix1)
except Exception:
times[type_ + '!failed'] = 0.
for numba_ in [False, True]:
for vsorted in [False, True]:
if numba_:
SearchSorted2DNumba.do(a, b, side = side, vsorted = vsorted, numba_ = numba_)
ts = time.time()
ix1 = SearchSorted2DNumba.do(a, b, side = side, vsorted = vsorted, numba_ = numba_)
te = time.time()
timesk = f'numba{("_py", "_nm")[numba_]}{("", "_sorted")[vsorted]}'
times[timesk] = times.get(timesk, 0.) + te - ts
assert np.array_equal(ix0, ix1)
if False:
aT, bT = np.copy(a.T), np.copy(b.T)
assert aT.ndim == 2 and bT.ndim == 2 and aT.shape[1] == klen and bT.shape[1] == klen, (aT.shape, bT.shape, klen)
for ty in ['if', 'cf']:
try:
dt = np.dtype({'if': [('', b.dtype)] * klen, 'cf': [('row', b.dtype, klen)]}[ty])
ts = time.time()
va = np.ndarray(aT.shape[:1], dtype = dt, buffer = aT)
vb = np.ndarray(bT.shape[:1], dtype = dt, buffer = bT)
ix1 = np.searchsorted(va, vb, side = side)
te = time.time()
assert np.array_equal(ix0, ix1), (ix0.shape, ix1.shape, ix0[:20], ix1[:20])
times[f'view1d_{ty}'] = times.get(f'view1d_{ty}', 0.) + te - ts
except Exception:
raise
a = np.insert(a, ix0, b, axis = 1)
stimes = ([f'key_len: {str(klen).rjust(3)}'] +
[f'{k}: {round_float_fixed_str(v, 4).rjust(7)}' for k, v in times.items()])
nlines = 4
print('-' * 50 + '\n' + ('', '!LARGE!:\n')[iklen == 0], end = '')
for i in range(nlines):
print(', '.join(stimes[len(stimes) * i // nlines : len(stimes) * (i + 1) // nlines]), flush = True)
Test()
输出:
--------------------------------------------------
!LARGE!:
key_len: 1, np.searchsorted: 0.0250
py.tuples_cached: 3.3113, py.tuples: 30.5263, py.zip: 40.9785
cmp_numba: 25.7826, numba_py: 3.6673
numba_py_sorted: 6.8926, numba_nm: 0.0466, numba_nm_sorted: 0.0505
--------------------------------------------------
key_len: 1, py.tuples_cached: 0.1371
py.tuples: 0.4698, py.zip: 1.2005, np.nonzero: 4.7827
np.lexsort: 4.4672, cmp_numba: 1.0644, numba_py: 0.2748
numba_py_sorted: 0.5699, numba_nm: 0.0005, numba_nm_sorted: 0.0020
--------------------------------------------------
key_len: 2, py.tuples_cached: 0.1131
py.tuples: 0.3643, py.zip: 1.0670, np.nonzero: 4.5199
np.lexsort: 3.4595, cmp_numba: 0.8582, numba_py: 0.4958
numba_py_sorted: 0.6454, numba_nm: 0.0025, numba_nm_sorted: 0.0025
--------------------------------------------------
key_len: 5, py.tuples_cached: 0.1876
py.tuples: 0.4493, py.zip: 1.6342, np.nonzero: 5.5168
np.lexsort: 4.6086, cmp_numba: 1.0939, numba_py: 1.0607
numba_py_sorted: 0.9737, numba_nm: 0.0050, numba_nm_sorted: 0.0065
--------------------------------------------------
key_len: 10, py.tuples_cached: 0.6017
py.tuples: 1.2275, py.zip: 3.5276, np.nonzero: 13.5460
np.lexsort: 12.4183, cmp_numba: 2.5404, numba_py: 2.8334
numba_py_sorted: 2.3991, numba_nm: 0.0165, numba_nm_sorted: 0.0155
--------------------------------------------------
key_len: 20, py.tuples_cached: 0.8316
py.tuples: 1.3759, py.zip: 3.4238, np.nonzero: 13.7834
np.lexsort: 16.2164, cmp_numba: 2.4483, numba_py: 2.6405
numba_py_sorted: 2.2226, numba_nm: 0.0170, numba_nm_sorted: 0.0160
--------------------------------------------------
key_len: 50, py.tuples_cached: 1.0443
py.tuples: 1.4085, py.zip: 2.2475, np.nonzero: 9.1673
np.lexsort: 19.5266, cmp_numba: 1.6181, numba_py: 1.7731
numba_py_sorted: 1.4637, numba_nm: 0.0415, numba_nm_sorted: 0.0405
--------------------------------------------------
key_len: 100, py.tuples_cached: 2.0136
py.tuples: 2.5380, py.zip: 2.2279, np.nonzero: 9.2929
np.lexsort: 33.9505, cmp_numba: 1.5722, numba_py: 1.7158
numba_py_sorted: 1.4208, numba_nm: 0.0871, numba_nm_sorted: 0.0851
--------------------------------------------------
key_len: 200, py.tuples_cached: 3.5945
py.tuples: 4.1847, py.zip: 2.3553, np.nonzero: 11.3781
np.lexsort: 66.0104, cmp_numba: 1.8153, numba_py: 1.9449
numba_py_sorted: 1.6463, numba_nm: 0.1661, numba_nm_sorted: 0.1651
根据时间,
numba_nm
实现是最快的,它比下一个最快的实现(
py.zip
或
py.tuples_cached
)快了
15-100x
倍。对于一维情况,它的速度与标准的
np.searchsorted
相当(
1.85x
较慢)。此外,使用插入数组排序信息的
_sorted
版本并没有改善情况。
cmp_numba
方法是机器代码编译的,平均比同样使用纯Python算法的
py.zip
快
1.5x
。由于平均最大相等键深度约为
15-18
个元素,因此在这里numba没有获得很大的加速。如果深度达到数百个,则numba代码可能会有巨大的加速。
对于键长
<= 100
的情况,
py.tuples_cached
策略比
py.zip
快。
此外,看起来
np.lexsort
实际上非常缓慢,可能是因为它没有针对仅具有两列的情况进行优化,或者它花费时间在预处理方面,例如将行拆分成列表,或者它执行的词典序比较不是惰性的,最后一种情况可能是真正的原因,因为随着键长度增加,lexsort会变慢。
np.nonzero
策略也是不惰性的,因此工作速度也很慢,并且随着键长度的增长而减慢(但速度不如
np.lexsort
慢)。
以上时间可能不太准确,因为我的CPU在过热时随机降低核心频率2-2.3倍,而且它经常过热,因为它是笔记本电脑内部功能强大的CPU。
np.object_
在内的任何dtype
的一般情况? - Artynp.unique(... return_inverse=True)
转换为整数。 - Divakarnp.unique
的想法似乎是一个不错的解决方案,我会尝试编写它。 - Arty