使用NumPy加速嵌套的for循环

3

我正在尝试解决一个动态规划问题,我想到了一个简单的基于循环的算法,它根据一系列if语句填充2D数组,类似于这样:

s = # some string of size n
opt = numpy.zeros(shape=(n, n))

for j in range(0, n):
    for i in range(j, -1, -1):
        if j - i == 0:
            opt[i, j] = 1
        elif j - i == 1:
            opt[i, j] = 2 if s[i] == s[j] else 1
        elif s[i] == s[j] and opt[i + 1, j - 1] == (j - 1) - (i + 1) + 1:
            opt[i, j] = 2 + opt[i + 1, j - 1]
        else:
            opt[i, j] = max(opt[i + 1, j], opt[i, j - 1], opt[i + 1, j - 1])

很遗憾,对于大的N值,这段代码非常缓慢。我发现使用内置函数如numpy.wherenumpy.fill来填充数组的值要比for循环好得多,但我很难找到任何例子来解释这些函数(或其他优化的numpy方法)如何与一系列的if语句一起使用,就像我的算法所做的那样。使用内置的numpy库重写上面的代码以使其更好地优化Python的适当方式是什么?

3个回答

1

你的 if 语句和赋值语句的左操作数都包含了对正在循环中被修改的数组的引用。这意味着没有通用的方法可以将你的循环转换为数组操作。因此,你只能使用某种形式的 for 循环。

如果你有以下更简单的循环:

for j in range(0, n):
    for i in range(j, -1, -1):
        if j - i == 0:
            opt[i, j] = 1
        elif j - i == 1:
            opt[i, j] = 2
        elif s[i] == s[j]:
            opt[i, j] = 3
        else:
            opt[i, j] = 4

您可以构建布尔数组(使用一些广播),表示您的三个条件:
import numpy as np

# get arrays i and j that represent the row and column indices
i,j = np.ogrid[:n, :n]
# construct an array with the characters from s
sarr = np.fromiter(s, dtype='U1').reshape(1, -1)

cond1 = i==j             # result will be a bool arr with True wherever row index equals column index
cond2 = j==i+1           # result will be a bool arr with True wherever col index equals (row index + 1)
cond3 = sarr==sarr.T     # result will be a bool arr with True wherever s[i]==s[j]

你可以使用 numpy.select 构建所需的 opt:
opt = np.select([cond1, cond2, cond3], [1, 2, 3], default=4)

对于 n=5s='abbca',这将产生以下结果:

array([[1, 2, 4, 4, 3],
       [4, 1, 2, 4, 4],
       [4, 3, 1, 2, 4],
       [4, 4, 4, 1, 2],
       [3, 4, 4, 4, 1]])

1
这里有一个向量化的解决方案。
它创建对输出数组的对角线视图,使我们能够在对角线方向上进行累积。
逐步说明:
- 在对角线视图中评估 s[i] == s[j]。 - 仅保留那些通过从右上到左下方向的一系列True与主对角线或第一子对角线相连的元素。 - 用2替换所有True,除了主对角线,主对角线用1代替;在从左下到右上方向上进行累计求和。 - 最后,在从下到上和从左到右的方向上进行累积最大值。
由于不是完全明显,这个方法做的事情与循环代码相同。我已经使用函数stresstest在许多示例中进行了测试,看起来是正确的。而且对于中等大小的字符串(1-100个字符),速度大约快7倍。
import numpy as np

def loopy(s):
    n = len(s)
    opt = np.zeros(shape=(n, n), dtype=int)
    for j in range(0, n):
        for i in range(j, -1, -1):
            if j - i == 0:
                opt[i, j] = 1
            elif j - i == 1:
                opt[i, j] = 2 if s[i] == s[j] else 1
            elif s[i] == s[j] and opt[i + 1, j - 1] == (j - 1) - (i + 1) + 1:
                opt[i, j] = 2 + opt[i + 1, j - 1]
            else:
                opt[i, j] = max(opt[i + 1, j], opt[i, j - 1], opt[i + 1, j - 1])
    return opt

def vect(s):
    n = len(s)
    h = (n+1) // 2
    s = np.array([s, s]).view('U1').ravel()
    opt = np.zeros((n+2*h-1, n+2*h-1), int)
    y, x = opt.strides
    hh = np.lib.stride_tricks.as_strided(opt[h-1:, h-1:], (2, h, n), (x, x-y, x+y))
    p, o, c = np.ogrid[:2, :h, :n]
    hh[...] = 2 * np.logical_and.accumulate(s[c+o+p] == s[c-o], axis=1)
    np.einsum('ii->i', opt)[...] = 1
    hh[...] = hh.cumsum(axis=1)
    opt = np.maximum.accumulate(opt[-h-1:None if h == 1 else h-2:-1, h-1:-h], axis=0)[::-1]
    return np.maximum.accumulate(opt, axis=1)

def stresstest(n=100):
    from string import ascii_lowercase
    import random
    from timeit import timeit
    Tv, Tl = 0, 0
    for i in range(n):
        s = ''.join(random.choices(ascii_lowercase[:random.randint(2, 26)], k=random.randint(1, 100)))
        print(s, end=' ')
        assert np.all(vect(s) == loopy(s))
        Tv += timeit(lambda: vect(s), number=10)
        Tl += timeit(lambda: loopy(s), number=10)
    print()
    print(f"total time loopy {Tl}, vect {Tv}")

演示:

>>> stresstest(20)
caccbbdbcfbfdcacebbecffacabeddcfdededeeafaebeaeedaaedaabebfacbdd fckjhrmupcqmihlohjog dffffgalbdbhkjigladhgdjaaagelddehahbbhejkibdgjhlkbcihiejdgidljfalfhlaglcgcih eacdebdcfcdcccaacfccefbccbced agglljlhfj mvwlkedblhvwbsmvtbjpqhgbaolnceqpgkhfivtbkwgbvujskkoklgforocj jljiqlidcdolcpmbfdqbdpjjjhbklcqmnmkfckkch ohsxiviwanuafkjocpexjmdiwlcmtcbagksodasdriieikvxphksedajwrbpee mcwdxsoghnuvxglhxcxxrezcdkahpijgujqqrqaideyhepfmrgxndhyifg omhppjaenjprnd roubpjfjbiafulerejpdniniuljqpouimsfukudndgtjggtbcjbchhfcdhrgf krutrwnttvqdemuwqwidvntpvptjqmekjctvbbetrvehsgxqfsjhoivdvwonvjd adiccabdbifigeigdfaieecceciaghadiaigibehdaichfibeaggcgdciahfegefigghgebhddciaei llobdegpmebejvotsr rtnsevatjvuowmquaulfmgiwsophuvlablslbwrpnhtekmpphsenarhrptgbjvlseeqstewjgfhopqwgmcbcihljeguv gcjlfihmfjbkdmimjknamfbahiccbhnceiahbnhghnlleimmieglgbfjbnmemdgddndhinncegnmgmfmgahhhjkg nhbnfhp cyjcygpaaeotcpwfhnumcfveq snyefmeuyjhcglyluezrx hcjhejhdaejchedbce 
total time loopy 0.2523909523151815, vect 0.03500175685621798

1
我不认为np.where和np.fill可以解决您的问题。 np.where用于返回满足特定条件的numpy数组元素,但在您的情况下,条件不是在numpy数组的值上,而是在变量i和j的值上。
对于您的特定问题,我建议使用Cython针对较大的N值优化您的代码。 Cython基本上是Python和C之间的接口。 Cython的美妙之处在于它允许您保留Python语法,但使用C结构进行优化。它允许您以类似于C的方式定义变量类型以加速计算。例如,使用Cython将i和j定义为整数将相当大地加快速度,因为在每个循环迭代中都会检查i和j的类型。
此外,Cython将允许您使用C定义经典的快速二维数组。然后,您可以使用指针快速访问此2D数组的元素,而不是使用numpy数组。在您的情况下,opt将是该2D数组。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接