在一个Numpy数组中寻找模式

8
我正在尝试在一个名为“values”的numpy数组中找到模式。我想返回模式的起始索引位置。我知道我可以迭代每个元素并检查该元素和下一个元素是否匹配模式,但对于大型数据集来说,这样做效率非常低,我正在寻找更好的替代方法。
我已经使用np.where来搜索单个值的工作解决方案,但无法用它来查找模式或两个数字。
示例:
import numpy as np
values = np.array([0,1,2,1,2,4,5,6,1,2,1])
searchval = [1,2]
print  np.where(values == searchval)[0]

输出:

[]

预期输出:

[1, 3, 8]

也许滚动哈希可以帮助解决问题。(https://en.wikipedia.org/wiki/Rolling_hash) - betontalpfa
5个回答

7

你可以使用np.where(假设这是查找元素的最佳方法),然后仅检查满足第一个条件的模式。

import numpy as np
values = np.array([0,1,2,1,2,4,5,6,1,2,1])
searchval = [1,2]
N = len(searchval)
possibles = np.where(values == searchval[0])[0]

solns = []
for p in possibles:
    check = values[p:p+N]
    if np.all(check == searchval):
        solns.append(p)

print(solns)

如果输入是随机的(有一些重复值),那么这个解决方案将会很快。 - betontalpfa

7
这里有一个使用where的直截了当的方法。首先,从逻辑表达式开始查找匹配项:
In [670]: values = np.array([0,1,2,1,2,4,5,6,1,2,1])
     ...: searchval = [1,2]
     ...: 
In [671]: (values[:-1]==searchval[0]) & (values[1:]==searchval[1])
Out[671]: array([False,  True, False,  True, False, False, False, False,  True, False], dtype=bool)
In [672]: np.where(_)
Out[672]: (array([1, 3, 8], dtype=int32),)

这可以推广为一个循环,可以操作多个searchval。正确获取切片范围需要一些调整。另一个答案中建议的roll可能更容易,但我怀疑速度会慢一些。
只要searchval相对于values很小,这种通用方法就应该是有效的。有一个np.in1d可以进行这种匹配,但需要使用一个or测试。所以它不适用。但如果searchval列表足够小,它也使用这种迭代方法。

泛化切片

In [716]: values
Out[716]: array([0, 1, 2, 1, 2, 4, 5, 6, 1, 2, 1])
In [717]: searchvals=[1,2,1]
In [718]: idx = [np.s_[i:m-n+1+i] for i in range(n)]
In [719]: idx
Out[719]: [slice(0, 9, None), slice(1, 10, None), slice(2, 11, None)]
In [720]: [values[idx[i]] == searchvals[i] for i in range(n)]
Out[720]: 
[array([False,  True, False,  True, False, False, False, False,  True], dtype=bool),
 array([False,  True, False,  True, False, False, False, False,  True], dtype=bool),
 array([False,  True, False, False, False, False,  True, False,  True], dtype=bool)]
In [721]: np.all(_, axis=0)
Out[721]: array([False,  True, False, False, False, False, False, False,  True], dtype=bool)
In [722]: np.where(_)
Out[722]: (array([1, 8], dtype=int32),)

我使用中间的np.s_来查看切片并确保它们看起来合理。

as_strided

一个高级技巧是使用as_strided构造“滚动”数组,并在其上执行二维==测试。 as_strided很棒,但也很棘手。 要正确使用它,您必须了解步幅,并正确设置形状。

In [740]: m,n = len(values), len(searchvals)
In [741]: values.shape
Out[741]: (11,)
In [742]: values.strides
Out[742]: (4,)
In [743]: 
In [743]: M = as_strided(values, shape=(n,m-n+1),strides=(4,4))
In [744]: M
Out[744]: 
array([[0, 1, 2, 1, 2, 4, 5, 6, 1],
       [1, 2, 1, 2, 4, 5, 6, 1, 2],
       [2, 1, 2, 4, 5, 6, 1, 2, 1]])
In [745]: M == np.array(searchvals)[:,None]
Out[745]: 
array([[False,  True, False,  True, False, False, False, False,  True],
       [False,  True, False,  True, False, False, False, False,  True],
       [False,  True, False, False, False, False,  True, False,  True]], dtype=bool)
In [746]: np.where(np.all(_,axis=0))
Out[746]: (array([1, 8], dtype=int32),)

切片是个不错的想法,对于适当大小的模式应该相当有效。 - Divakar
1
你能在矩阵上使用 as_strided 吗? - Oren

3
我认为这个可以完成任务:
np.where((values == 1) & (np.roll(values,-1) == 2))[0]

3
由于我们有一系列优秀的解决方案,我对运行时间很好奇,并发现在相当长的随机数组(例如,1百万条目)上,hpaulij的解决方案比简单的滚动快大约2倍。 Ed Smith的解决方案慢了约100倍,而betontalpfa的解决方案也慢了100倍。 条目数和命中次数会使数字发生相当大的变化,但不会改变整体排名。 - jnsod

2
紧凑简单的解决方案将是“合法”的as_strided解决方案变体。其他人已经提到了np.roll。但是这里有一个通用解决方案,只需要一个循环(132微秒)。最初的回答。
seq = np.array([0,1,2,1,2,4,5,6,1,2,1])
patt = np.array([1,2])

Seq = np.vstack([np.roll(seq, shift) for shift in -np.arange(len(patt))]).T
where(all(Seq == patt, axis=1))[0]

另一种适用于小整数序列的选项是将其转换为字符串。 它比原始回答更快,快了近6倍(20微秒)。 仅适用于小正整数! 原始回答:Original Answer
import re

def to_string(arr):
    return ''.join(map(chr, arr))

array([m.start() for m in re.finditer(to_string(patt), to_string(seq))])

2
如果输入是随机的,Ed Smith的解决方案更快。但是如果您有一些可用值集,则此哈希解决方案可以帮助:
"""
Can be replaced with any revertable hash
"""
def my_hash(rem, h, add):
    return rem^h^add

"""
Imput
"""
values = np.array([0,1,2,1,2,4,5,6,1,2,1])
searchval = [1,2]


"""
Prepare
"""
sh = 0
vh = 0
ls = len(searchval)
lv = len(values)

for i in range(0, len(searchval)):
    vh = my_hash(0, vh, values[i])
    sh = my_hash(0, sh, searchval[i])

"""
Find matches
"""
for i in range(0, lv-ls):
    if sh == vh:
        eq = True
        for j in range(0, ls):
            if values[i+j] != searchval[j]:
                eq = False
                break
        if eq:
            print i
    vh = my_hash(values[i], vh, values[i+ls])

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接