NumPy - np.searchsorted 用于二维数组

3

np.searchsorted 只适用于 1D 数组。

我有一个按字典顺序排列的 2D 数组,意味着第0行已排序,然后对于0行相同值的元素,1行对应的元素也已排序,对于1行相同值的元素,2行的值也已排序。换句话说,由列组成的元组是已排序的。

我还有另一个需要插入到第一个 2D 数组中正确位置的元组列的 2D 数组。对于 1D 情况,通常使用 np.searchsorted 来找到正确的位置。

但是对于 2D 数组,是否有 np.searchsorted 的替代方法?类似于 np.lexsort 是 1D np.argsort 的 2D 替代品。

如果没有这样的功能,那么是否可以使用现有的numpy函数以高效的方式实现此功能?
我对任何dtype的数组的高效解决方案感兴趣,包括np.object_。
处理任何dtype情况的一种天真的方法是将两个数组的每一列转换为1D数组(或元组),然后将这些列存储为另一个dtype为np.object_的1D数组。也许这并不是那么天真,特别是如果列相当高的话,可能会更快。

1
请访问以下链接以了解有关使用NumPy的矢量化searchsorted函数的更多信息:https://dev59.com/xZzha4cB1Zd3GeqPFngu 或 https://dev59.com/crTma4cB1Zd3GeqP4lOs? - Divakar
@Divakar 有趣的解决方案,但只适用于数字。是否可以解决包括 np.object_ 在内的任何 dtype 的一般情况? - Arty
1
是的,我认为这些仅限于整数。也许考虑使用np.unique(... return_inverse=True)转换为整数。 - Divakar
@Divakar用np.unique的想法似乎是一个不错的解决方案,我会尝试编写它。 - Arty
你能以某种方式哈希值吗?实际上,如果你正在寻找对象数组,为什么不将行转换为元组并将其视为1D数组呢? - Mad Physicist
显示剩余4条评论
3个回答

2
我已经创建了几个更高级的策略。
还实现了一种简单的策略,使用像 另一个我的回答 中的元组。
所有解决方案的时间都被测量。
大多数策略都使用 np.searchsorted 作为底层引擎。为了实现这些高级策略,使用了一个特殊的包装类 _CmpIx,以提供自定义比较函数 (__lt__) 用于调用 np.searchsorted
  1. py.tuples策略将所有列转换为元组并将它们存储为np.object_ dtype的numpy 1D数组,然后进行常规搜索排序。
  2. py.zip使用python的zip来懒惰地执行相同的任务。
  3. np.lexsort策略只是使用np.lexsort来按字典顺序比较两列。
  4. np.nonzero使用np.flatnonzero(a != b)表达式。
  5. cmp_numba_CmpIx包装器中使用预先编译的numba代码,用于快速按字典顺序懒惰地比较两个提供的元素。
  6. np.searchsorted使用标准的numpy函数,但仅针对1D情况进行测量。
  7. 对于numba策略,整个搜索算法都是使用Numba引擎从头实现的,该算法基于二分搜索。这个算法有_py_nm两种变体,_nm更快,因为它使用了Numba编译器,而_py是相同的算法但未编译。还有_sorted变体,它对要插入的数组进行了额外的优化,如果已经排序,则更快。
  8. view1d - 这个方法是由@MadPhysicist在这个答案中提出的。在代码中注释掉它们,因为它们对于所有关键字长度>1的大多数测试返回不正确的答案,可能是由于原始查看数组的某些问题。

在线尝试!

class SearchSorted2D:
    class _CmpIx:
        def __init__(self, t, p, i):
            self.p, self.i = p, i
            self.leg = self.leg_cache()[t]
            self.lt = lambda o: self.leg(self, o, False) if self.i != o.i else False
            self.le = lambda o: self.leg(self, o, True) if self.i != o.i else True
        @classmethod
        def leg_cache(cls):
            if not hasattr(cls, 'leg_cache_data'):
                cls.leg_cache_data = {
                    'py.zip': cls._leg_py_zip, 'np.lexsort': cls._leg_np_lexsort,
                    'np.nonzero': cls._leg_np_nonzero, 'cmp_numba': cls._leg_numba_create(),
                }
            return cls.leg_cache_data
        def __eq__(self, o): return not self.lt(o) and self.le(o)
        def __ne__(self, o): return self.lt(o) or not self.le(o)
        def __lt__(self, o): return self.lt(o)
        def __le__(self, o): return self.le(o)
        def __gt__(self, o): return not self.le(o)
        def __ge__(self, o): return not self.lt(o)
        @staticmethod
        def _leg_np_lexsort(self, o, eq):
            import numpy as np
            ia, ib = (self.i, o.i) if eq else (o.i, self.i)
            return (np.lexsort(self.p.ab[::-1, ia : (ib + (-1, 1)[ib >= ia], None)[ib == 0] : ib - ia])[0] == 0) == eq
        @staticmethod
        def _leg_py_zip(self, o, eq):
            for l, r in zip(self.p.ab[:, self.i], self.p.ab[:, o.i]):
                if l < r:
                    return True
                if l > r:
                    return False
            return eq
        @staticmethod
        def _leg_np_nonzero(self, o, eq):
            import numpy as np
            a, b = self.p.ab[:, self.i], self.p.ab[:, o.i]
            ix = np.flatnonzero(a != b)
            return a[ix[0]] < b[ix[0]] if ix.size != 0 else eq
        @staticmethod
        def _leg_numba_create():
            import numpy as np

            try:
                from numba.pycc import CC
                cc = CC('ss_numba_mod')
                @cc.export('ss_numba_i8', 'b1(i8[:],i8[:],b1)')
                def ss_numba(a, b, eq):
                    for i in range(a.size):
                        if a[i] < b[i]:
                            return True
                        elif b[i] < a[i]:
                            return False
                    return eq
                cc.compile()
                success = True
            except:    
                success = False
                
            if success:
                try:
                    import ss_numba_mod
                except:
                    success = False
            
            def odo(self, o, eq):
                a, b = self.p.ab[:, self.i], self.p.ab[:, o.i]
                assert a.ndim == 1 and a.shape == b.shape, (a.shape, b.shape)
                return ss_numba_mod.ss_numba_i8(a, b, eq)
                
            return odo if success else None

    def __init__(self, type_):
        import numpy as np
        self.type_ = type_
        self.ci = np.array([], dtype = np.object_)
    def __call__(self, a, b, *pargs, **nargs):
        import numpy as np
        self.ab = np.concatenate((a, b), axis = 1)
        self._grow(self.ab.shape[1])
        ix = np.searchsorted(self.ci[:a.shape[1]], self.ci[a.shape[1] : a.shape[1] + b.shape[1]], *pargs, **nargs)
        return ix
    def _grow(self, to):
        import numpy as np
        if self.ci.size >= to:
            return
        import math
        to = 1 << math.ceil(math.log(to) / math.log(2))
        self.ci = np.concatenate((self.ci, [self._CmpIx(self.type_, self, i) for i in range(self.ci.size, to)]))

class SearchSorted2DNumba:
    @classmethod
    def do(cls, a, v, side = 'left', *, vsorted = False, numba_ = True):
        import numpy as np

        if not hasattr(cls, '_ido_numba'):
            def _ido_regular(a, b, vsorted, lrt):
                nk, na, nb = a.shape[0], a.shape[1], b.shape[1]
                res = np.zeros((2, nb), dtype = np.int64)
                max_depth = 0
                if nb == 0:
                    return res, max_depth
                #lb, le, rb, re = 0, 0, 0, 0
                lrb, lre = 0, 0
                
                if vsorted:
                    brngs = np.zeros((nb, 6), dtype = np.int64)
                    brngs[0, :4] = (-1, 0, nb >> 1, nb)
                    i, j, size = 0, 1, 1
                    while i < j:
                        for k in range(i, j):
                            cbrng = brngs[k]
                            bp, bb, bm, be = cbrng[:4]
                            if bb < bm:
                                brngs[size, :4] = (k, bb, (bb + bm) >> 1, bm)
                                size += 1
                            bmp1 = bm + 1
                            if bmp1 < be:
                                brngs[size, :4] = (k, bmp1, (bmp1 + be) >> 1, be)
                                size += 1
                        i, j = j, size
                    assert size == nb
                    brngs[:, 4:] = -1

                for ibc in range(nb):
                    if not vsorted:
                        ib, lrb, lre = ibc, 0, na
                    else:
                        ibpi, ib = int(brngs[ibc, 0]), int(brngs[ibc, 2])
                        if ibpi == -1:
                            lrb, lre = 0, na
                        else:
                            ibp = int(brngs[ibpi, 2])
                            if ib < ibp:
                                lrb, lre = int(brngs[ibpi, 4]), int(res[1, ibp])
                            else:
                                lrb, lre = int(res[0, ibp]), int(brngs[ibpi, 5])
                        brngs[ibc, 4 : 6] = (lrb, lre)
                        assert lrb != -1 and lre != -1
                        
                    for ik in range(nk):
                        if lrb >= lre:
                            if ik > max_depth:
                                max_depth = ik
                            break

                        bv = b[ik, ib]
                        
                        # Binary searches
                        
                        if nk != 1 or lrt == 2:
                            cb, ce = lrb, lre
                            while cb < ce:
                                cm = (cb + ce) >> 1
                                av = a[ik, cm]
                                if av < bv:
                                    cb = cm + 1
                                elif bv < av:
                                    ce = cm
                                else:
                                    break
                            lrb, lre = cb, ce
                                
                        if nk != 1 or lrt >= 1:
                            cb, ce = lrb, lre
                            while cb < ce:
                                cm = (cb + ce) >> 1
                                if not (bv < a[ik, cm]):
                                    cb = cm + 1
                                else:
                                    ce = cm
                            #rb, re = cb, ce
                            lre = ce
                                
                        if nk != 1 or lrt == 0 or lrt == 2:
                            cb, ce = lrb, lre
                            while cb < ce:
                                cm = (cb + ce) >> 1
                                if a[ik, cm] < bv:
                                    cb = cm + 1
                                else:
                                    ce = cm
                            #lb, le = cb, ce
                            lrb = cb
                            
                        #lrb, lre = lb, re
                            
                    res[:, ib] = (lrb, lre)
                    
                return res, max_depth

            cls._ido_regular = _ido_regular
            
            import numba
            cls._ido_numba = numba.jit(nopython = True, nogil = True, cache = True)(cls._ido_regular)
            
        assert side in ['left', 'right', 'left_right'], side
        a, v = np.array(a), np.array(v)
        assert a.ndim == 2 and v.ndim == 2 and a.shape[0] == v.shape[0], (a.shape, v.shape)
        res, max_depth = (cls._ido_numba if numba_ else cls._ido_regular)(
            a, v, vsorted, {'left': 0, 'right': 1, 'left_right': 2}[side],
        )
        return res[0] if side == 'left' else res[1] if side == 'right' else res

def Test():
    import time
    import numpy as np
    np.random.seed(0)
    
    def round_float_fixed_str(x, n = 0):
        if type(x) is int:
            return str(x)
        s = str(round(float(x), n))
        if n > 0:
            s += '0' * (n - (len(s) - 1 - s.rfind('.')))
        return s

    def to_tuples(x):
        r = np.empty([x.shape[1]], dtype = np.object_)
        r[:] = [tuple(e) for e in x.T]
        return r
    
    searchsorted2d = {
        'py.zip': SearchSorted2D('py.zip'),
        'np.nonzero': SearchSorted2D('np.nonzero'),
        'np.lexsort': SearchSorted2D('np.lexsort'),
        'cmp_numba': SearchSorted2D('cmp_numba'),
    }
    
    for iklen, klen in enumerate([1, 1, 2, 5, 10, 20, 50, 100, 200]):
        times = {}
        for side in ['left', 'right']:
            a = np.zeros((klen, 0), dtype = np.int64)
            tac = to_tuples(a)

            for itest in range((15, 100)[iklen == 0]):
                b = np.random.randint(0, (3, 100000)[iklen == 0], (klen, np.random.randint(1, (1000, 2000)[iklen == 0])), dtype = np.int64)
                b = b[:, np.lexsort(b[::-1])]
                
                if iklen == 0:
                    assert klen == 1, klen
                    ts = time.time()
                    ix1 = np.searchsorted(a[0], b[0], side = side)
                    te = time.time()
                    times['np.searchsorted'] = times.get('np.searchsorted', 0.) + te - ts
                    
                for cached in [False, True]:
                    ts = time.time()
                    tb = to_tuples(b)
                    ta = tac if cached else to_tuples(a)
                    ix1 = np.searchsorted(ta, tb, side = side)
                    if not cached:
                        ix0 = ix1
                    tac = np.insert(tac, ix0, tb) if cached else tac
                    te = time.time()
                    timesk = f'py.tuples{("", "_cached")[cached]}'
                    times[timesk] = times.get(timesk, 0.) + te - ts

                for type_ in searchsorted2d.keys():
                    if iklen == 0 and type_ in ['np.nonzero', 'np.lexsort']:
                        continue
                    ss = searchsorted2d[type_]
                    try:
                        ts = time.time()
                        ix1 = ss(a, b, side = side)
                        te = time.time()
                        times[type_] = times.get(type_, 0.) + te - ts
                        assert np.array_equal(ix0, ix1)
                    except Exception:
                        times[type_ + '!failed'] = 0.

                for numba_ in [False, True]:
                    for vsorted in [False, True]:
                        if numba_:
                            # Heat-up/pre-compile numba
                            SearchSorted2DNumba.do(a, b, side = side, vsorted = vsorted, numba_ = numba_)
                        
                        ts = time.time()
                        ix1 = SearchSorted2DNumba.do(a, b, side = side, vsorted = vsorted, numba_ = numba_)
                        te = time.time()
                        timesk = f'numba{("_py", "_nm")[numba_]}{("", "_sorted")[vsorted]}'
                        times[timesk] = times.get(timesk, 0.) + te - ts
                        assert np.array_equal(ix0, ix1)


                # View-1D methods suggested by @MadPhysicist
                if False: # Commented out as working just some-times
                    aT, bT = np.copy(a.T), np.copy(b.T)
                    assert aT.ndim == 2 and bT.ndim == 2 and aT.shape[1] == klen and bT.shape[1] == klen, (aT.shape, bT.shape, klen)
                    
                    for ty in ['if', 'cf']:
                        try:
                            dt = np.dtype({'if': [('', b.dtype)] * klen, 'cf': [('row', b.dtype, klen)]}[ty])
                            ts = time.time()
                            va = np.ndarray(aT.shape[:1], dtype = dt, buffer = aT)
                            vb = np.ndarray(bT.shape[:1], dtype = dt, buffer = bT)
                            ix1 = np.searchsorted(va, vb, side = side)
                            te = time.time()
                            assert np.array_equal(ix0, ix1), (ix0.shape, ix1.shape, ix0[:20], ix1[:20])
                            times[f'view1d_{ty}'] = times.get(f'view1d_{ty}', 0.) + te - ts
                        except Exception:
                            raise
                
                a = np.insert(a, ix0, b, axis = 1)
            
        stimes = ([f'key_len: {str(klen).rjust(3)}'] +
            [f'{k}: {round_float_fixed_str(v, 4).rjust(7)}' for k, v in times.items()])
        nlines = 4
        print('-' * 50 + '\n' + ('', '!LARGE!:\n')[iklen == 0], end = '')
        for i in range(nlines):
            print(',  '.join(stimes[len(stimes) * i // nlines : len(stimes) * (i + 1) // nlines]), flush = True)
            
Test()

输出:

--------------------------------------------------
!LARGE!:
key_len:   1,  np.searchsorted:  0.0250
py.tuples_cached:  3.3113,  py.tuples: 30.5263,  py.zip: 40.9785
cmp_numba: 25.7826,  numba_py:  3.6673
numba_py_sorted:  6.8926,  numba_nm:  0.0466,  numba_nm_sorted:  0.0505
--------------------------------------------------
key_len:   1,  py.tuples_cached:  0.1371
py.tuples:  0.4698,  py.zip:  1.2005,  np.nonzero:  4.7827
np.lexsort:  4.4672,  cmp_numba:  1.0644,  numba_py:  0.2748
numba_py_sorted:  0.5699,  numba_nm:  0.0005,  numba_nm_sorted:  0.0020
--------------------------------------------------
key_len:   2,  py.tuples_cached:  0.1131
py.tuples:  0.3643,  py.zip:  1.0670,  np.nonzero:  4.5199
np.lexsort:  3.4595,  cmp_numba:  0.8582,  numba_py:  0.4958
numba_py_sorted:  0.6454,  numba_nm:  0.0025,  numba_nm_sorted:  0.0025
--------------------------------------------------
key_len:   5,  py.tuples_cached:  0.1876
py.tuples:  0.4493,  py.zip:  1.6342,  np.nonzero:  5.5168
np.lexsort:  4.6086,  cmp_numba:  1.0939,  numba_py:  1.0607
numba_py_sorted:  0.9737,  numba_nm:  0.0050,  numba_nm_sorted:  0.0065
--------------------------------------------------
key_len:  10,  py.tuples_cached:  0.6017
py.tuples:  1.2275,  py.zip:  3.5276,  np.nonzero: 13.5460
np.lexsort: 12.4183,  cmp_numba:  2.5404,  numba_py:  2.8334
numba_py_sorted:  2.3991,  numba_nm:  0.0165,  numba_nm_sorted:  0.0155
--------------------------------------------------
key_len:  20,  py.tuples_cached:  0.8316
py.tuples:  1.3759,  py.zip:  3.4238,  np.nonzero: 13.7834
np.lexsort: 16.2164,  cmp_numba:  2.4483,  numba_py:  2.6405
numba_py_sorted:  2.2226,  numba_nm:  0.0170,  numba_nm_sorted:  0.0160
--------------------------------------------------
key_len:  50,  py.tuples_cached:  1.0443
py.tuples:  1.4085,  py.zip:  2.2475,  np.nonzero:  9.1673
np.lexsort: 19.5266,  cmp_numba:  1.6181,  numba_py:  1.7731
numba_py_sorted:  1.4637,  numba_nm:  0.0415,  numba_nm_sorted:  0.0405
--------------------------------------------------
key_len: 100,  py.tuples_cached:  2.0136
py.tuples:  2.5380,  py.zip:  2.2279,  np.nonzero:  9.2929
np.lexsort: 33.9505,  cmp_numba:  1.5722,  numba_py:  1.7158
numba_py_sorted:  1.4208,  numba_nm:  0.0871,  numba_nm_sorted:  0.0851
--------------------------------------------------
key_len: 200,  py.tuples_cached:  3.5945
py.tuples:  4.1847,  py.zip:  2.3553,  np.nonzero: 11.3781
np.lexsort: 66.0104,  cmp_numba:  1.8153,  numba_py:  1.9449
numba_py_sorted:  1.6463,  numba_nm:  0.1661,  numba_nm_sorted:  0.1651

根据时间,numba_nm实现是最快的,它比下一个最快的实现(py.zippy.tuples_cached)快了15-100x倍。对于一维情况,它的速度与标准的np.searchsorted相当(1.85x较慢)。此外,使用插入数组排序信息的_sorted版本并没有改善情况。 cmp_numba方法是机器代码编译的,平均比同样使用纯Python算法的py.zip1.5x。由于平均最大相等键深度约为15-18个元素,因此在这里numba没有获得很大的加速。如果深度达到数百个,则numba代码可能会有巨大的加速。
对于键长<= 100的情况,py.tuples_cached策略比py.zip快。
此外,看起来np.lexsort实际上非常缓慢,可能是因为它没有针对仅具有两列的情况进行优化,或者它花费时间在预处理方面,例如将行拆分成列表,或者它执行的词典序比较不是惰性的,最后一种情况可能是真正的原因,因为随着键长度增加,lexsort会变慢。 np.nonzero策略也是不惰性的,因此工作速度也很慢,并且随着键长度的增长而减慢(但速度不如np.lexsort慢)。
以上时间可能不太准确,因为我的CPU在过热时随机降低核心频率2-2.3倍,而且它经常过热,因为它是笔记本电脑内部功能强大的CPU。

你的计时器包含了很多不必要的东西,比如在计时字典中进行关键字查找。 - Mad Physicist
只有在微观测量时才需要关注时间中包含的微小细节,例如当您测量代码 a + b 的速度时,需要运行100万个周期。但是,当您在一秒钟内仅有20个周期时,像字典查找或 if 条件这样的微小细节对于计时几乎没有任何意义。 - Arty
@MadPhysicist 其他类似 to_tuples(a) 或者 np.insert(tac, ix0, tb) 的内容应该被包含在时间测量中,因为它们是搜索排序的支持代码,应该始终包含在客户端代码中,因此不能不进行测量。 - Arty
@MadPhysicist 是的,让代码整洁是非常好的事情,我有时也喜欢这样做,特别是为了制作最终版以供公共使用。但是,对于速度的影响更大的是(有时达到10个数量级),通过找到一个好的算法。 - Arty
@MadPhysicist 刚刚决定使用 Numba 模块和二分查找方法,从头开始实现 我的答案 中的整个 searchsorted 2D 算法。它的性能比下一个最快的算法(使用 ziptuples)快了 15x-100x 倍,并且在特定的 1D 情况下,接近标准 numpy 的 np.searchsorted 的速度(慢了 1.85x)。 - Arty
显示剩余3条评论

2
两件事可以帮助你:(1)你可以对结构化数组进行排序和搜索,(2)如果你有有限的集合可以映射到整数,那么你可以利用它。

作为一维查看

假设你有一个字符串数组要插入:

data = np.array([['a', '1'], ['a', 'z'], ['b', 'a']], dtype=object)

由于数组永远不会是不规则的,因此您可以构造一个大小为行的dtype:

dt = np.dtype([('', data.dtype)] * data.shape[1])

使用我无耻地插入的答案这里,现在您可以将原始的二维数组视为一维数组:

view = np.ndarray(data.shape[:1], dtype=dt, buffer=data)

现在可以非常简单地进行搜索:

key = np.array([('a', 'a')], dtype=dt)
index = np.searchsorted(view, key)

你甚至可以通过使用适当的最小值找到不完整元素的插入索引。对于字符串,这将是'' Faster Comparison

You may get better mileage out of the comparison if you don't have to check each field of the dtype. You can make a similar dtype with a single homogeneous field:

dt2 = np.dtype([('row', data.dtype, data.shape[1])])

Constructing the view is the same as before:

view = np.ndarray(data.shape[:1], dtype=dt2, buffer=data)

The key is done a little differently this time (another plug here):

key = np.array([(['a', 'a'],)], dtype=dt2)
这个方法使用自定义dtype对对象数组进行行排序所施加的排序顺序不正确。我在这里留下一个参考,以防链接的问题有解决方法。此外,它仍然非常适用于整数排序。
整数映射
如果您要搜索有限数量的对象,将它们映射到整数会更容易:
idata = np.empty(data.shape, dtype=int)
keys = [None] * data.shape[1]     # Map index to key per column
indices = [None] * data.shape[1]  # Map key to index per column
for i in range(data.shape[1]):
    keys[i], idata[:, i] = np.unique(data[:, i], return_inverse=True)
    indices[i] = {k: i for i, k in enumerate(keys[i])}  # Assumes hashable objects

idt = np.dtype([('row', idata.dtype, idata.shape[1])])
view = idata.view(idt).ravel()

只有在data中实际包含每列中所有可能的键时,此方法才有效。否则,您将需要通过其他方式获取正向和反向映射。一旦确定了这一点,设置键就变得更加简单,只需要使用indices

key = np.array([index[k] for index, k in zip(indices, ['a', 'a'])])

进一步改进

如果您拥有的类别数量不超过八个,并且每个类别都有256个或更少的元素,则可以通过将所有内容放入单个np.uint64元素中来构建更好的哈希。

k = math.ceil(math.log(data.shape[1], 2))  # math.log provides base directly
assert 0 < k <= 64
idata = np.empty((data.shape[:1], k), dtype=np.uint8)
...
idata = idata.view(f'>u{k}').ravel()

同样,键也是这样制作的:

key = np.array([index[k] for index, k in zip(indices, ['a', 'a'])]).view(f'>u{k}')

时间

我已经对这里展示的方法进行了计时(没有包括其他答案),使用随机洗牌的字符串。关键的计时参数为:

  • M:行数:10 ** {2, 3, 4, 5}
  • N:列数:2 ** {3, 4, 5, 6}
  • K:要插入的元素数:1, 10, M // 10
  • 方法:individual_fieldscombined_fieldint_mappingint_packing。下面显示了函数。

对于最后两种方法,我假设您将预先将数据转换为映射的数据类型,但不是搜索键。因此,我传递已转换的数据,但计时键的转换。

import numpy as np
from math import ceil, log

def individual_fields(data, keys):
    dt = [('', data.dtype)] * data.shape[1]
    dview = np.ndarray(data.shape[:1], dtype=dt, buffer=data)
    kview = np.ndarray(keys.shape[:1], dtype=dt, buffer=keys)
    return np.searchsorted(dview, kview)

def combined_fields(data, keys):
    dt = [('row', data.dtype, data.shape[1])]
    dview = np.ndarray(data.shape[:1], dtype=dt, buffer=data)
    kview = np.ndarray(keys.shape[:1], dtype=dt, buffer=keys)
    return np.searchsorted(dview, kview)

def int_mapping(idata, keys, indices):
    idt = np.dtype([('row', idata.dtype, idata.shape[1])])
    dview = idata.view(idt).ravel()
    kview = np.empty(keys.shape[0], dtype=idt)
    for i, (index, key) in enumerate(zip(indices, keys.T)):
        kview['row'][:, i] = [index[k] for k in key]
    return np.searchsorted(dview, kview)

def int_packing(idata, keys, indices):
    idt = f'>u{idata.shape[1]}'
    dview = idata.view(idt).ravel()
    kview = np.empty(keys.shape, dtype=np.uint8)
    for i, (index, key) in enumerate(zip(indices, keys.T)):
        kview[:, i] = [index[k] for k in key]
    kview = kview.view(idt).ravel()
    return np.searchsorted(dview, kview)

计时代码:

from math import ceil, log
from string import ascii_lowercase
from timeit import Timer

def time(m, n, k, fn, *args):
    t = Timer(lambda: fn(*args))
    s = t.autorange()[0]
    print(f'M={m}; N={n}; K={k} {fn.__name__}: {min(t.repeat(5, s)) / s}')

selection = np.array(list(ascii_lowercase), dtype=object)
for lM in range(2, 6):
    M = 10**lM
    for lN in range(3, 6):
        N = 2**lN
        data = np.random.choice(selection, size=(M, N))
        np.ndarray(data.shape[0], dtype=[('', data.dtype)] * data.shape[1], buffer=data).sort()
        idata = np.array([[ord(a) - ord('a') for a in row] for row in data], dtype=np.uint8)
        ikeys = [selection] * data.shape[1]
        indices = [{k: i for i, k in enumerate(selection)}] * data.shape[1]
        for K in (1, 10, M // 10):
            key = np.random.choice(selection, size=(K, N))
            time(M, N, K, individual_fields, data, key)
            time(M, N, K, combined_fields, data, key)
            time(M, N, K, int_mapping, idata, key, indices)
            if N <= 8:
                time(M, N, K, int_packing, idata, key, indices)

结果:

M=100(单位=微秒)

   |                           K                           |
   +---------------------------+---------------------------+
N  |             1             |            10             |
   +------+------+------+------+------+------+------+------+
   |  IF  |  CF  |  IM  |  IP  |  IF  |  CF  |  IM  |  IP  |
---+------+------+------+------+------+------+------+------+
 8 | 25.9 | 18.6 | 52.6 | 48.2 | 35.8 | 22.7 | 76.3 | 68.2 | 
16 | 40.1 | 19.0 | 87.6 |  --  | 51.1 | 22.8 | 130. |  --  |
32 | 68.3 | 18.7 | 157. |  --  | 79.1 | 22.4 | 236. |  --  |
64 | 125. | 18.7 | 290. |  --  | 135. | 22.4 | 447. |  --  |
---+------+------+------+------+------+------+------+------+

M=1000(单位为us)
   |                                         K                                         |
   +---------------------------+---------------------------+---------------------------+
N  |             1             |            10             |            100            |
   +------+------+------+------+------+------+------+------+------+------+------+------+
   |  IF  |  CF  |  IM  |  IP  |  IF  |  CF  |  IM  |  IP  |  IF  |  CF  |  IM  |  IP  |
---+------+------+------+------+------+------+------+------+------+------+------+------+
 8 | 26.9 | 19.1 | 55.0 | 55.0 | 44.8 | 25.1 | 79.2 | 75.0 | 218. | 74.4 | 305. | 250. |
16 | 41.0 | 19.2 | 90.5 |  --  | 59.3 | 24.6 | 134. |  --  | 244. | 79.0 | 524. |  --  | 
32 | 68.5 | 19.0 | 159. |  --  | 87.4 | 24.7 | 241. |  --  | 271. | 80.5 | 984. |  --  |
64 | 128. | 19.7 | 312. |  --  | 168. | 26.0 | 549. |  --  | 396. | 7.78 | 2.0k |  --  |
---+------+------+------+------+------+------+------+------+------+------+------+------+

M=10K (单位=微秒)

   |                                         K                                         |
   +---------------------------+---------------------------+---------------------------+
N  |             1             |            10             |           1000            |
   +------+------+------+------+------+------+------+------+------+------+------+------+
   |  IF  |  CF  |  IM  |  IP  |  IF  |  CF  |  IM  |  IP  |  IF  |  CF  |  IM  |  IP  |
---+------+------+------+------+------+------+------+------+------+------+------+------+
 8 | 28.8 | 19.5 | 54.5 | 107. | 57.0 | 27.2 | 90.5 | 128. | 3.2k | 762. | 2.7k | 2.1k |
16 | 42.5 | 19.6 | 90.4 |  --  | 73.0 | 27.2 | 140. |  --  | 3.3k | 752. | 4.6k |  --  |
32 | 73.0 | 19.7 | 164. |  --  | 104. | 26.7 | 246. |  --  | 3.4k | 803. | 8.6k |  --  |
64 | 135. | 19.8 | 302. |  --  | 162. | 26.1 | 466. |  --  | 3.7k | 791. | 17.k |  --  |
---+------+------+------+------+------+------+------+------+------+------+------+------+

individual_fields (IF) 是最快的工作方法。它的复杂性与列数成比例增长。不幸的是,combined_fields (CF) 对于对象数组不起作用。否则,它不仅是最快的方法,而且在增加列时不会增加复杂度。

我认为所有其他可能更快的技术都不行,因为将 Python 对象映射到键很慢(例如,实际查找打包的 int 数组比结构化数组快得多得多)。

参考资料

以下是我必须提出的其他问题,以使此代码正常工作:


@Arty. 它根据文档进行词典排序,除非你传递一个自定义的排序器。 - Mad Physicist
@Arty。我的唯一困惑是当dtype有一个单独的数组字段时,它似乎对指针进行排序而不是它们的内容,这是出乎意料的。 - Mad Physicist
你有没有将你的view1d方法的结果与任何参考简单方法进行比较,例如tuple method here?我刚刚将你的IF/CF方法放入我的测量代码中在这个答案中,看起来你的方法MOST-TIMES返回的输出索引与np.searchsorted不同,当键长度>1时,与我的许多方法都返回相同的结果。可能是因为numpy在比较这种自定义类型时不是按字典顺序比较,或者存在一些原始视图的内部问题。 - Arty
@Arty。 combined_fields 对于对象不起作用,而 int_pacing 仅适用于 N<=8。但是,我已经检查了其他方法返回相同且合理的结果。 - Mad Physicist
@Arty。复制很可能会将底层数据实际上转置为内存中的C顺序。 - Mad Physicist
显示剩余11条评论

0

发布我在问题中提到的第一个天真的解决方案,它只是将2D数组转换为包含原始列的Python元组的1D dtype = np.object_数组,然后使用1D np.searchsorted,该解决方案适用于任何dtype。实际上,这个解决方案并不那么天真,它相当快,正如我在当前问题的另一个答案中所测量的那样,特别是对于键长度低于100的情况下速度很快。

在线试用!

import numpy as np
np.random.seed(0)

def to_obj(x):
    res = np.empty((x.shape[0],), dtype = np.object_)
    res[:] = [tuple(np.squeeze(e, 0)) for e in np.split(x, x.shape[0], axis = 0)]
    return res

a = np.random.randint(0, 3, (10, 23))
b = np.random.randint(0, 3, (10, 15))

a, b = [x[:, np.lexsort(x[::-1])] for x in (a, b)]

print(np.concatenate((np.arange(a.shape[1])[None, :], a)), '\n\n', b, '\n')

a, b = [to_obj(x.T) for x in (a, b)]

print(np.searchsorted(a, b))

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接