使用numpy向量化for循环以计算胶带重叠部分。

7
我正在使用Python创建一个应用程序,用于计算胶带重叠(模拟旋转鼓上的产品应用)。
我有一个可以正常工作的程序,但速度非常慢。我正在寻求优化“for”循环以填充numpy数组的解决方案。有人能帮我矢量化下面的代码吗?
import numpy as np
import matplotlib.pyplot as plt

# Some parameters
width = 264
bbddiam = 940
accuracy = 4 #2 points per pixel

drum = np.zeros(accuracy**2 * width * bbddiam).reshape((bbddiam * accuracy , width * accuracy))

# The "slow" function
def line_mask(drum, coef, intercept, upper=True, accuracy=accuracy):
    """Masks a half of the array"""
    to_return = np.zeros(drum.shape)
    for index, v in np.ndenumerate(to_return):
        if upper == True:
            if index[0] * coef + intercept > index[1]:
                to_return[index] = 1
        else:
            if index[0] * coef + intercept <= index[1]:
                to_return[index] = 1
    return to_return


def get_band(drum, coef, intercept, bandwidth):
    """Calculate a ribbon path on the drum"""
    to_return = np.zeros(drum.shape)
    t1 = line_mask(drum, coef, intercept + bandwidth / 2, upper=True)
    t2 = line_mask(drum, coef, intercept - bandwidth / 2, upper=False)
    to_return = t1 + t2
    return np.where(to_return == 2, 1, 0)

single_band = get_band(drum, 1 / 10, 130, bandwidth=15)

# Visualize the result !
plt.imshow(single_band)
plt.show()

Numba 对我的代码产生了神奇的作用,将运行时间从 5.8 秒降低到 86 毫秒(特别感谢 @Maarten-vd-Sande)。
from numba import jit
@jit(nopython=True, parallel=True)
def line_mask(drum, coef, intercept, upper=True, accuracy=accuracy):
    ...

一个更好的解决方案仍然欢迎使用numpy;-)

5
你可以尝试使用Numba,保留原有的代码即可 :) - Maarten-vd-Sande
确实,Numba 热爱循环!我会尝试一下。 - Laurent R
3
@Maarten-vd-Sande,使用numba后结果令人印象深刻-->执行时间从5.8秒降至86毫秒!!谢谢。 - Laurent R
1
没有时间提供完整的答案,但基于重复的计算,我会查看 https://docs.scipy.org/doc/numpy/reference/routines.math.html(包括numpy.sum和numpy.prod)来快速执行矩阵运算,并使用逻辑函数进行比较(https://docs.scipy.org/doc/numpy/reference/routines.logic.html)。 - Jeff
1
不确定纯numpy是否能击败numba,但是使用适当的numpy可以相当简单地完成。 - Mad Physicist
1个回答

8

这里完全不需要进行任何循环。您实际上有两个不同的line_mask函数,都不需要显式循环,但您可能会从将其重写为一对for循环中的ifelse而不是在for循环中的ifelse中获得显着的加速,因为后者会被多次评估。

真正numpythonic的做法是将代码适当地向量化,以在没有任何循环的情况下操作整个数组。以下是一个向量化的line_mask版本:

def line_mask(drum, coef, intercept, upper=True, accuracy=accuracy):
    """Masks a half of the array"""
    r = np.arange(drum.shape[0]).reshape(-1, 1)
    c = np.arange(drum.shape[1]).reshape(1, -1)
    comp = c.__lt__ if upper else c.__ge__
    return comp(r * coef + intercept)

rc 的形状设置为 (m, 1)(n, 1),以便结果是 (m, n),这被称为广播,是numpy向量化的基础。

更新后的 line_mask 的结果是一个布尔掩码(顾名思义),而不是浮点数数组。这使它更小,并希望完全绕过浮点运算。现在,您可以重写 get_band 以使用屏蔽而不是添加:

def get_band(drum, coef, intercept, bandwidth):
    """Calculate a ribbon path on the drum"""
    t1 = line_mask(drum, coef, intercept + bandwidth / 2, upper=True)
    t2 = line_mask(drum, coef, intercept - bandwidth / 2, upper=False)
    return t1 & t2

由于这些函数保留了所有接口,因此程序的其余部分应保持不变。

如果您愿意,可以用三行代码(尽管有点难以辨认)重写大部分程序:

coeff = 1/10
intercept = 130
bandwidth = 15

r, c = np.ogrid[:drum.shape[0], :drum.shape[1]]
check = r * coeff + intercept
single_band = ((check + bandwidth / 2 > c) & (check - bandwidth /  2 <= c))

62.6毫秒,非常感谢! - Laurent R
1
@LaurentR。整洁。我敢打赌,我刚刚添加的三行代码会更快,但可能不会快多少。 - Mad Physicist
2
@LaurentR。在我看来,看到numpy打败numba总是很有趣的。Proper numba就像一只不可阻挡的野兽。 - Mad Physicist
之前不知道ndarrays上有__lt____ge__运算符 - 谢谢! - Jeff
@Jeff。是的。在Python中使用的任何运算符,如<>=in等,都是通过一个魔法类方法实现的,你可以将其绑定到该实例上。对于x < yx >= yy in x,分别调用type(x).__lt__(x, y)type(x).__ge__(x, y)type(x).__contains__(x, y)。这对于任何Python对象都是正确的,包括数组。 - Mad Physicist

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接