Python/NumPy子序列的第一个出现

39
在Python或NumPy中,查找子数组的第一次出现的最佳方法是什么?
例如,我有以下数组:
a = [1, 2, 3, 4, 5, 6]
b = [2, 3, 4]
什么是在a中查找b出现位置的最快方法(以运行时间为基础)?我知道对于字符串来说这非常容易,但是对于列表或numpy ndarray呢?
非常感谢!
[编辑] 我更喜欢使用numpy解决方案,因为根据我的经验,numpy向量化比Python列表推导要快得多。同时,大数组非常巨大,所以我不想将其转换为字符串;那样会(太)长。

你可以将列表转换为字符串以进行比较吗?x=''.join(str(x) for x in a) 然后使用结果字符串的find方法?还是它们必须保持为列表? - danem
10个回答

26
我假设您正在寻找一个与numpy相关的解决方案,而不是一个简单的列表推导或for循环。一种简单的方法是使用滚动窗口技术来搜索适当大小的窗口。
这种方法简单、正确,并且比任何纯Python解决方案快得多。它应该足够满足许多用例。但是,出于许多原因,它不是可能的最有效方法。对于在期望情况下渐近最优的方法,请参见norok2的答案中基于 numba 滚动哈希实现。
以下是rolling_window函数:
>>> def rolling_window(a, size):
...     shape = a.shape[:-1] + (a.shape[-1] - size + 1, size)
...     strides = a.strides + (a. strides[-1],)
...     return numpy.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
... 

那么你可以做类似这样的事情

>>> a = numpy.arange(10)
>>> numpy.random.shuffle(a)
>>> a
array([7, 3, 6, 8, 4, 0, 9, 2, 1, 5])
>>> rolling_window(a, 3) == [8, 4, 0]
array([[False, False, False],
       [False, False, False],
       [False, False, False],
       [ True,  True,  True],
       [False, False, False],
       [False, False, False],
       [False, False, False],
       [False, False, False]], dtype=bool)

为了使其真正有用,您需要使用all沿轴1进行缩小:
>>> numpy.all(rolling_window(a, 3) == [8, 4, 0], axis=1)
array([False, False, False,  True, False, False, False, False], dtype=bool)

那么您可以像使用布尔数组一样使用它。获取索引的简单方法:

>>> bool_indices = numpy.all(rolling_window(a, 3) == [8, 4, 0], axis=1)
>>> numpy.mgrid[0:len(bool_indices)][bool_indices]
array([3])

对于列表,您可以改编其中一个滚动窗口迭代器以使用类似的方法。

对于非常大的数组和子数组,您可以像这样节省内存:

>>> windows = rolling_window(a, 3)
>>> sub = [8, 4, 0]
>>> hits = numpy.ones((len(a) - len(sub) + 1,), dtype=bool)
>>> for i, x in enumerate(sub):
...     hits &= numpy.in1d(windows[:,i], [x])
... 
>>> hits
array([False, False, False,  True, False, False, False, False], dtype=bool)
>>> hits.nonzero()
(array([3]),)

另一方面,这可能会稍微慢一些。


1
这种方法的问题在于,虽然rolling_window的返回值不需要任何新的内存,并且重用原始数组的内存,但是当执行==操作时,您会实例化一个新的布尔数组,其大小是原始数组的完整大小的size倍。如果数组足够大,这可能会极大地影响性能。 - Jaime
1
你说得对,这不是渐进最优解。然而,它在简单性和效率之间取得了很好的平衡——它直截了当、正确,并且比任何纯Python方法快得多。对于那些需要可证明最优解的人,norok2非常详细的答案有几个候选方案,包括一个基于numba的滚动哈希方法,在期望情况下是渐进最优的。 - senderle

25
以下代码应该有效:
[x for x in xrange(len(a)) if a[x:x+len(b)] == b]

返回模式开始的索引。


2
这可能不是最快的解决方案,但最简单的答案值得加1。这可能适合许多用户的需求,特别是如果numpy不可用。 - David
4
在Python 3中,使用range取代xrange - Samoth
2
为了提高性能,你可以用 len(a) - len(b) + 1 替换 len(a) - norok2

23

(编辑,增加了更深入的讨论,更好的代码和更多的基准测试)


总结

为了获得原始速度和效率,可以使用一种经过Cython或Numba加速的经典算法版本(分别适用于Python序列或NumPy数组作为输入)。

建议采用以下方法:

  • find_kmp_cy() 用于Python序列(listtuple等)
  • find_kmp_nb() 用于 NumPy 数组

其他高效的方法是 find_rk_cy()find_rk_nb(),它们的内存效率更高,但不能保证以线性时间运行。

如果没有 Cython/Numba,则对于大多数用例,find_kmp()find_rk() 都是很好的通用解决方案,虽然在平均情况下和对于 Python 序列来说,某些形式的简单方法,特别是 find_pivot(),可能更快。对于 NumPy 数组,find_conv()(来自@Jaime answer)优于任何未加速的简单方法。

(完整代码在下方,以及这里那里。)


理论

这是计算机科学中的一个经典问题,称为字符串搜索或字符串匹配问题。 基于两个嵌套循环的朴素方法,平均计算复杂度为O(n + m),但最坏情况是O(n m)。 多年来,已经开发出了许多替代方法,保证了更好的最坏情况性能。

在经典算法中,最适合于通用序列(因为它们不依赖于字母表)的算法是:

这个最后的算法依靠计算一个滚动哈希来提高效率,因此可能需要一些关于输入的额外知识以获得最佳性能。最终,它最适合于同质数据,例如数值数组。Python中数值数组的一个著名示例当然是NumPy数组。

备注

  • 由于简单易懂,朴素算法可用不同的实现进行各种程度的运行时速度提升。
  • 其他算法在通过语言技巧进行优化方面的灵活性较小。
  • 在Python中,显式循环可能是速度瓶颈,可以使用几种技巧将循环执行器外部化。
  • Cython在加速通用Python代码的显式循环方面效果特别好。
  • Numba在加速NumPy数组上的显式循环方面效果特别好。
  • 这是生成器的极佳用例,因此所有代码都将使用它们而不是常规函数。

Python序列(listtuple等)

基于朴素算法

  • find_loop()find_loop_cy()find_loop_nb() 是纯Python、Cython和具有Numba JITing的显式循环实现。请注意,在Numba版本中,我们使用Python对象输入,因此需要forceobj=True
def find_loop(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        found = True
        for j in range(m):
            if seq[i + j] != subseq[j]:
                found = False
                break
        if found:
            yield i

%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


def find_loop_cy(seq, subseq):
    cdef Py_ssize_t n = len(seq)
    cdef Py_ssize_t m = len(subseq)
    for i in range(n - m + 1):
        found = True
        for j in range(m):
            if seq[i + j] != subseq[j]:
                found = False
                break
        if found:
            yield i

find_loop_nb = nb.jit(find_loop, forceobj=True)
find_loop_nb.__name__ = 'find_loop_nb'
  • find_all() 用列表生成式中的 all() 替代内部循环。
def find_all(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if all(seq[i + j] == subseq[j] for j in range(m)):
            yield i
  • find_slice()会在切片[]之后使用直接比较==来替换内部循环。
def find_slice(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if seq[i:i + m] == subseq:
            yield i
  • find_mix()find_mix2() 在切片后使用直接比较 == 替换了内部循环,但包含一两个额外的短路,在第一个(和最后一个)字符上可能更快,因为使用 int 进行切片比使用 slice() 更快。
def find_mix(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if seq[i] == subseq[0] and seq[i:i + m] == subseq:
            yield i

def find_mix2(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if seq[i] == subseq[0] and seq[i + m - 1] == subseq[m - 1] \
                and seq[i:i + m] == subseq:
            yield i
  • find_pivot()find_pivot2()用子序列的第一个项替换了外部循环,并使用切片进行内部循环,最终在最后一项(第一个匹配由构造保证)上使用附加的短路。多个.index()调用包装在index_all()生成器中(这可能会单独有用)。
def index_all(seq, item, start=0, stop=-1):
    try:
        n = len(seq)
        if n > 0:
            start %= n
            stop %= n
            i = start
            while True:
                i = seq.index(item, i)
                if i <= stop:
                    yield i
                    i += 1
                else:
                    return
        else:
            return
    except ValueError:
        pass


def find_pivot(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    for i in index_all(seq, subseq[0], 0, n - m):
        if seq[i:i + m] == subseq:
            yield i

def find_pivot2(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    for i in index_all(seq, subseq[0], 0, n - m):
        if seq[i + m - 1] == subseq[m - 1] and seq[i:i + m] == subseq:
            yield i

基于Knuth-Morris-Pratt(KMP)算法

  • find_kmp()是该算法的纯Python实现。由于没有简单循环或可以使用slice()切片的位置,因此除了使用Cython(Numba需要再次使用forceobj=True,这将导致代码缓慢)之外,没有太多可以进行优化。
def find_kmp(seq, subseq):
    n = len(seq)
    m = len(subseq)
    # : compute offsets
    offsets = [0] * m
    j = 1
    k = 0
    while j < m: 
        if subseq[j] == subseq[k]: 
            k += 1
            offsets[j] = k
            j += 1
        else: 
            if k != 0: 
                k = offsets[k - 1] 
            else: 
                offsets[j] = 0
                j += 1
    # : find matches
    i = j = 0
    while i < n: 
        if seq[i] == subseq[j]: 
            i += 1
            j += 1
        if j == m:
            yield i - j
            j = offsets[j - 1] 
        elif i < n and seq[i] != subseq[j]: 
            if j != 0: 
                j = offsets[j - 1] 
            else: 
                i += 1
  • find_kmp_cy() 是一个Cython实现的算法,使用C int数据类型来表示索引,因此代码速度更快。
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


def find_kmp_cy(seq, subseq):
    cdef Py_ssize_t n = len(seq)
    cdef Py_ssize_t m = len(subseq)
    # : compute offsets
    offsets = [0] * m
    cdef Py_ssize_t j = 1
    cdef Py_ssize_t k = 0
    while j < m: 
        if subseq[j] == subseq[k]: 
            k += 1
            offsets[j] = k
            j += 1
        else: 
            if k != 0: 
                k = offsets[k - 1] 
            else: 
                offsets[j] = 0
                j += 1
    # : find matches
    cdef Py_ssize_t i = 0
    j = 0
    while i < n: 
        if seq[i] == subseq[j]: 
            i += 1
            j += 1
        if j == m:
            yield i - j
            j = offsets[j - 1] 
        elif i < n and seq[i] != subseq[j]: 
            if j != 0: 
                j = offsets[j - 1] 
            else: 
                i += 1

基于 Rabin-Karp (RK)算法

  • find_rk() 是纯 Python 实现,它依赖于 Python 的 hash() 来计算(和比较)哈希值。这种哈希值通过简单的 sum() 进行滚动计算。接着,从前一个哈希值中减去刚刚访问的项 seq[i - 1]hash() 结果,并加上新考虑的项 seq[i + m - 1]hash() 结果来计算滚动值。
def find_rk(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if seq[:m] == subseq:
        yield 0
    hash_subseq = sum(hash(x) for x in subseq)  # compute hash
    curr_hash = sum(hash(x) for x in seq[:m])  # compute hash
    for i in range(1, n - m + 1):
        curr_hash += hash(seq[i + m - 1]) - hash(seq[i - 1])   # update hash
        if hash_subseq == curr_hash and seq[i:i + m] == subseq:
            yield i
  • find_rk_cy() 是Cython实现的算法,其中索引使用适当的C数据类型,这将导致代码运行速度更快。请注意,hash() 将"根据主机机器的位宽截断返回值"。
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


def find_rk_cy(seq, subseq):
    cdef Py_ssize_t n = len(seq)
    cdef Py_ssize_t m = len(subseq)
    if seq[:m] == subseq:
        yield 0
    cdef Py_ssize_t hash_subseq = sum(hash(x) for x in subseq)  # compute hash
    cdef Py_ssize_t curr_hash = sum(hash(x) for x in seq[:m])  # compute hash
    cdef Py_ssize_t old_item, new_item
    for i in range(1, n - m + 1):
        old_item = hash(seq[i - 1])
        new_item = hash(seq[i + m - 1])
        curr_hash += new_item - old_item  # update hash
        if hash_subseq == curr_hash and seq[i:i + m] == subseq:
            yield i

性能测试

上述功能基于以下两个输入进行评估:

  • 随机输入
def gen_input(n, k=2):
    return tuple(random.randint(0, k - 1) for _ in range(n))

(几乎)是天真算法中最坏的输入
def gen_input_worst(n, k=-2):
    result = [0] * n
    result[k] = 1
    return tuple(result)

subseq 的大小是固定的(32)。 由于有很多替代方案,因此进行了两个单独的分组,并省略了一些解决方案,这些解决方案具有非常小的变化和几乎相同的计时(即 find_mix2()find_pivot2())。 对于每个组,测试了两个输入。 对于每个基准测试,提供了完整的图表和最快方法的缩放。

在随机序列上使用 Naïve 算法

bm_full_naive_random bm_zoom_naive_random

在最坏情况下使用 Naïve 算法

bm_full_naive_worst bm_zoom_naive_worst

在随机序列上使用其他算法

bm_full_other_random bm_zoom_other_random

在最坏情况下使用其他算法

bm_full_other_worst bm_zoom_other_worst

(完整代码可在此处查看。)


NumPy 数组

基于 Naïve 算法

  • find_loop()find_loop_cy()find_loop_nb() 是纯 Python、Cython 和 Numba JITing 的显式循环实现。前两者的代码与上面相同,因此被省略了。现在,find_loop_nb() 可以享受快速的 JIT 编译。内部循环已经写成一个单独的函数,因为这样可以重复使用它用于 find_rk_nb()(在 Numba 函数内调用 Numba 函数不会产生 Python 典型的函数调用惩罚)。
@nb.jit
def _is_equal_nb(seq, subseq, m, i):
    for j in range(m):
        if seq[i + j] != subseq[j]:
            return False
    return True


@nb.jit
def find_loop_nb(seq, subseq):
    n = len(seq)
    m = len(subseq)
    for i in range(n - m + 1):
        if _is_equal_nb(seq, subseq, m, i):
            yield i
  • find_all()与上面相同,而find_slice()find_mix()find_mix2()几乎与上面相同,唯一的区别是seq[i:i + m] == subseq现在成为了np.all()的参数:np.all(seq[i:i + m] == subseq)

  • find_pivot()find_pivot2()与上面的思路相同,只是现在使用np.where()代替index_all()并且需要将数组相等性封装在np.all()函数中。

def find_pivot(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    max_i = n - m
    for i in np.where(seq == subseq[0])[0]:
        if i > max_i:
            return
        elif np.all(seq[i:i + m] == subseq):
            yield i


def find_pivot2(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if m > n:
        return
    max_i = n - m
    for i in np.where(seq == subseq[0])[0]:
        if i > max_i:
            return
        elif seq[i + m - 1] == subseq[m - 1] \
                and np.all(seq[i:i + m] == subseq):
            yield i
  • find_rolling() 通过滑动窗口表达循环,并使用 np.all() 进行匹配检查。这种方法将所有循环向量化,代价是创建大量临时对象,同时仍然基本应用了朴素算法。(该方法来自@senderle answer)。
def rolling_window(arr, size):
    shape = arr.shape[:-1] + (arr.shape[-1] - size + 1, size)
    strides = arr.strides + (arr.strides[-1],)
    return np.lib.stride_tricks.as_strided(arr, shape=shape, strides=strides)


def find_rolling(seq, subseq):
    bool_indices = np.all(rolling_window(seq, len(subseq)) == subseq, axis=1)
    yield from np.mgrid[0:len(bool_indices)][bool_indices]
  • find_rolling2()是以上算法的稍微更加内存高效的变种,其中向量化仅部分实现,并保留一个显式循环(沿着预期的最短维度——subseq长度进行循环)。 (该方法也来自@senderle答案)。
def find_rolling2(seq, subseq):
    windows = rolling_window(seq, len(subseq))
    hits = np.ones((len(seq) - len(subseq) + 1,), dtype=bool)
    for i, x in enumerate(subseq):
        hits &= np.in1d(windows[:, i], [x])
    yield from hits.nonzero()[0]

基于Knuth-Morris-Pratt(KMP)算法

  • find_kmp()与上述相同,而find_kmp_nb()是该算法的简单即时编译。
find_kmp_nb = nb.jit(find_kmp)
find_kmp_nb.__name__ = 'find_kmp_nb'

基于Rabin-Karp(RK)算法

  • find_rk()与上述实现相同,只是seq[i:i + m] == subseq再次被包含在np.all()调用中。

  • find_rk_nb()是上述方法的Numba加速版本。使用之前定义的_is_equal_nb()来明确确定匹配,对于哈希,则使用一个Numba加速的sum_hash_nb()函数,其定义非常简单。

@nb.jit
def sum_hash_nb(arr):
    result = 0
    for x in arr:
        result += hash(x)
    return result


@nb.jit
def find_rk_nb(seq, subseq):
    n = len(seq)
    m = len(subseq)
    if _is_equal_nb(seq, subseq, m, 0):
        yield 0
    hash_subseq = sum_hash_nb(subseq)  # compute hash
    curr_hash = sum_hash_nb(seq[:m])  # compute hash
    for i in range(1, n - m + 1):
        curr_hash += hash(seq[i + m - 1]) - hash(seq[i - 1])  # update hash
        if hash_subseq == curr_hash and _is_equal_nb(seq, subseq, m, i):
            yield i
  • find_conv() 使用伪 Rabin-Karp 方法,其中初始候选项使用 np.dot() 乘积进行哈希,并在 seqsubseq 的卷积中使用 np.where() 进行定位。该方法是伪的,因为虽然它仍然使用哈希来识别可能的候选项,但它可能不被视为滚动哈希(它取决于 np.correlate() 的实际实现)。此外,它需要创建一个输入大小的临时数组。(该方法来自@Jaime answer)。
def find_conv(seq, subseq):
    target = np.dot(subseq, subseq)
    candidates = np.where(np.correlate(seq, subseq, mode='valid') == target)[0]
    check = candidates[:, np.newaxis] + np.arange(len(subseq))
    mask = np.all((np.take(seq, check) == subseq), axis=-1)
    yield from candidates[mask]

基准测试

和之前一样,上述函数将在两个输入上进行评估:

  • 随机输入
def gen_input(n, k=2):
    return np.random.randint(0, k, n)
  • (几乎)是朴素算法的最坏输入
def gen_input_worst(n, k=-2):
    result = np.zeros(n, dtype=int)
    result[k] = 1
    return result
subseq 的大小固定为 32。 以下图表与以前相同,为方便起见进行了总结。
由于有很多替代方案,因此进行了两个单独的分组,并省略了一些解决方案,这些方案具有非常小的差异和几乎相同的时间(即find_mix2()find_pivot2())。 对于每个组,都会测试两个输入。 对于每个基准测试,提供完整的图和最快方法的缩放版。

随机数上的朴素算法

bm_full_naive_random bm_zoom_naive_random

最劣情况下的朴素算法

bm_full_naive_worst bm_zoom_naive_worst

随机数上的其他算法

bm_full_other_random bm_zoom_other_random

最劣情况下的其他算法

bm_full_other_worst bm_zoom_other_worst

(完整代码可以在此处找到。)

感谢您运行所有这些测试!我从我的答案中链接到了这个。我仍然喜欢我的方法,因为它相当快速,易于理解,并且不会添加任何依赖项。但是对于需要真正最优解的人来说,这是很棒的。我不同意KMP方法是最好的。对于绝大多数实际用例,RK更快,并且真正需要KMP提供的最坏情况保证的人非常少。 - senderle

20

一种基于卷积的方法,应该比基于stride_tricks的方法更节省内存:

一种基于卷积的方法,应该比基于 stride_tricks 的方法更节省内存:

def find_subsequence(seq, subseq):
    target = np.dot(subseq, subseq)
    candidates = np.where(np.correlate(seq,
                                       subseq, mode='valid') == target)[0]
    # some of the candidates entries may be false positives, double check
    check = candidates[:, np.newaxis] + np.arange(len(subseq))
    mask = np.all((np.take(seq, check) == subseq), axis=-1)
    return candidates[mask]

对于非常大的数组,可能无法使用 stride_tricks 方法,但是这种方法仍然有效:

haystack = np.random.randint(1000, size=(1e6))
needle = np.random.randint(1000, size=(100,))
# Hide 10 needles in the haystack
place = np.random.randint(1e6 - 100 + 1, size=10)
for idx in place:
    haystack[idx:idx+100] = needle

In [3]: find_subsequence(haystack, needle)
Out[3]: 
array([253824, 321497, 414169, 456777, 635055, 879149, 884282, 954848,
       961100, 973481], dtype=int64)

In [4]: np.all(np.sort(place) == find_subsequence(haystack, needle))
Out[4]: True

In [5]: %timeit find_subsequence(haystack, needle)
10 loops, best of 3: 79.2 ms per loop

虽然我非常喜欢这种方法,但是我应该指出,通常通过L2范数找到候选项并不比从needle中找到特定符号更好。但是,通过使用与needle相同长度的随机模式计算点积进行小修改后,这种方法将变得非常棒。 - Alleo
@Alleo,我没有看到使用实际子序列本身的问题在哪里。我看到的问题是,如果子序列中有重复或零,则可能会有更多的碰撞,但随机序列可能会有相同的问题。 - norok2

15

你可以调用tostring()方法将数组转换为字符串,然后使用快速字符串搜索。当你有许多子数组要检查时,这种方法可能会更快。

import numpy as np

a = np.array([1,2,3,4,5,6])
b = np.array([2,3,4])
print a.tostring().index(b.tostring())//a.itemsize

1
这个解决方案非常快速和优雅,非常感谢!稍微相关的是,我有一个项目,使用 SWIG 包装器从 C++ 抓取大约 1e8 元素的 np 数组,数组创建非常缓慢。将它们作为字符串处理可以提高实时性能。 - fr_andres
方法是不正确的。请查看np.array([0, 1]).tostring().index(np.array([256]).tostring()). - Björn Lindqvist

2

再试一次,但我相信有更符合Python风格和更高效的方法来做这个...

def array_match(a, b):
    for i in range(0, len(a)-len(b)+1):
        if a[i:i+len(b)] == b:
            return i
    return None
a = [1, 2, 3, 4, 5, 6]
b = [2, 3, 4]

print(array_match(a,b))
1

set(a) & set(b) == set(b)

两个问题:这也会匹配 [1, 3, 2, 4, 5, 6](集合没有顺序;数组有),并且它不报告匹配的位置(应该是索引1)。 - cdhowie
抱歉,我回答得太快了 :-/ - Stéphane
你可以通过将 first_occurence=i 替换为 return i,并将 return first_occurence 替换为 return None 来简化你的代码。 - Nayuki

1
这是一个相当直截了当的选项:

def first_subarray(full_array, sub_array):
    n = len(full_array)
    k = len(sub_array)
    matches = np.argwhere([np.all(full_array[start_ix:start_ix+k] == sub_array) 
                   for start_ix in range(0, n-k+1)])
    return matches[0]

然后使用原始的a、b向量,我们得到:
a = [1, 2, 3, 4, 5, 6]
b = [2, 3, 4]
first_subarray(a, b)
Out[44]: 
array([1], dtype=int64)

你可能会添加一些逻辑来处理没有匹配项的情况... - Hezi Resheff

1
a = [1, 2, 3, 4, 5, 6]
b = [2, 3, 4]

np.concatenate((np.all(np.array([a[i:len(a)-(len(b)-1-i)] for i in range(len(b))]).T == b, axis = 1), np.full((len(b)-1), False)))

array([False,  True, False, False, False, False])

UPD:对于相对较小的子数组,这个算法比使用滚动窗口更快,并且输出是与a相同大小的数组。
同样的想法:
np.all(np.lib.stride_tricks.sliding_window_view(a, len(b)) == b, axis = 1)

1
你的回答可以通过提供更多的支持性信息来改进。请编辑以添加进一步的细节,例如引用或文档,以便他人可以确认你的回答是否正确。你可以在帮助中心找到有关如何撰写好回答的更多信息。 - undefined
感谢您对Stack Overflow社区做出贡献的兴趣。这个问题已经有很多答案了,其中一个答案已经得到社区的广泛验证。您确定您的方法之前没有被提到过吗?如果是这样的话,能否解释一下您的方法与之前的方法有何不同,您的方法在什么情况下可能更好,以及为什么您认为之前的答案不够满意。您能否编辑您的答案并提供解释? - undefined

0

三种提议解决方案的快速比较(随机创建向量的100次迭代的平均时间):

import time
import collections
import numpy as np


def function_1(seq, sub):
    # direct comparison
    seq = list(seq)
    sub = list(sub)
    return [i for i in range(len(seq) - len(sub)) if seq[i:i+len(sub)] == sub]

def function_2(seq, sub):
    # Jamie's solution
    target = np.dot(sub, sub)
    candidates = np.where(np.correlate(seq, sub, mode='valid') == target)[0]
    check = candidates[:, np.newaxis] + np.arange(len(sub))
    mask = np.all((np.take(seq, check) == sub), axis=-1)
    return candidates[mask]

def function_3(seq, sub):
    # HYRY solution
    return seq.tostring().index(sub.tostring())//seq.itemsize


# --- assessment time performance
N = 100

seq = np.random.choice([0, 1, 2, 3, 4, 5, 6], 3000)
sub = np.array([1, 2, 3])

tim = collections.OrderedDict()
tim.update({function_1: 0.})
tim.update({function_2: 0.})
tim.update({function_3: 0.})

for function in tim.keys():
    for _ in range(N):
        seq = np.random.choice([0, 1, 2, 3, 4], 3000)
        sub = np.array([1, 2, 3])
        start = time.time()
        function(seq, sub)
        end = time.time()
        tim[function] += end - start

timer_dict = collections.OrderedDict()
for key, val in tim.items():
    timer_dict.update({key.__name__: val / N})

print(timer_dict)

在我的旧电脑上,这将导致:

OrderedDict([
('function_1', 0.0008518099784851074), 
('function_2', 8.157730102539063e-05), 
('function_3', 6.124973297119141e-06)
])

-1

首先,将列表转换为字符串。

a = ''.join(str(i) for i in a)
b = ''.join(str(i) for i in b)

将其转换为字符串后,您可以使用以下字符串函数轻松找到子字符串的索引。

a.index(b)

干杯!!


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接