为什么NumPy向量化比for循环慢?

3
以下代码有两个函数,它们的功能相同:检查连接两点之间的线是否与圆相交。

from line_profiler import LineProfiler
from math import sqrt
import numpy as np


class Point:
    """A point in the plane, identified by its two Cartesian coordinates."""

    x: float  # horizontal coordinate
    y: float  # vertical coordinate

    def __init__(self, x: float, y: float):
        """Store both coordinates on the new instance."""
        self.x, self.y = x, y

    def __repr__(self):
        """Return a constructor-like text form, e.g. ``Point(x=1.0, y=2.0)``."""
        return f"Point(x={self.x}, y={self.y})"


class Circle:
    """A circle, stored as its centre point together with its radius."""

    ctr: Point  # centre of the circle
    r: float  # radius

    def __init__(self, ctr: Point, r: float):
        """Store the centre and the radius on the new instance."""
        self.ctr, self.r = ctr, r

    def __repr__(self):
        """Return a constructor-like text form with the radius listed first."""
        return f"Circle(r={self.r}, ctr={self.ctr})"


def loop(p1: Point, p2: Point, circles: list[Circle]):
    """Check, circle by circle, that the segment p1-p2 crosses none of them.

    A circle counts as hit when either end point lies strictly inside it, or
    when both intersections of the segment's carrier line with the circle
    fall within the segment's extent. The scan stops at the first hit.

    Args:
        p1: First end point of the segment (read via ``.x`` and ``.y``).
        p2: Second end point of the segment (read via ``.x`` and ``.y``).
        circles: Circles to test; each is read via ``.ctr.x``, ``.ctr.y``
            and ``.r``.

    Returns:
        ``False`` as soon as one circle is hit, ``True`` when no circle is
        hit (including when ``circles`` is empty).
    """
    vertical = p1.x == p2.x
    if vertical:
        # The slope form y = m*x + n does not exist for a vertical segment
        # (it would divide by zero), so the segment extent is tracked along y.
        low, high = min(p1.y, p2.y), max(p1.y, p2.y)
    else:
        m = (p1.y - p2.y) / (p1.x - p2.x)
        n = p1.y - m * p1.x
        low, high = min(p1.x, p2.x), max(p1.x, p2.x)
        # Leading coefficient of the quadratic; it does not depend on the circle.
        a = m ** 2 + 1

    for circle in circles:
        cx, cy, r = circle.ctr.x, circle.ctr.y, circle.r

        # An end point strictly inside the circle is a hit on its own.
        if (sqrt((cx - p1.x) ** 2 + (cy - p1.y) ** 2) < r
                or sqrt((cx - p2.x) ** 2 + (cy - p2.y) ** 2) < r):
            return False

        if vertical:
            # Line x = p1.x meets the circle at y = cy +/- sqrt(r^2 - dx^2).
            half_chord_sq = r ** 2 - (p1.x - cx) ** 2
            if half_chord_sq <= 0:
                # line misses the circle or is only tangent to it
                continue
            half_chord = sqrt(half_chord_sq)
            first = low <= cy + half_chord <= high
            second = low <= cy - half_chord <= high
        else:
            # Substitute y = m*x + n into the circle equation: a*x^2 + b*x + c = 0.
            b = 2 * (m * n - m * cy - cx)
            c = cx ** 2 + cy ** 2 + n ** 2 - r ** 2 - 2 * n * cy

            discriminant = b ** 2 - 4 * a * c
            if discriminant <= 0:
                # no two distinct real roots: line misses or is only tangent
                continue

            root = sqrt(discriminant)
            first = low <= (-b + root) / (2 * a) <= high
            second = low <= (-b - root) / (2 * a) <= high

        # Both intersection points must lie within the segment's extent.
        if first and second:
            return False

    return True


def vectorized(p1: Point, p2: Point, circles):
    """Check with array operations that the segment p1-p2 crosses no circle.

    A circle counts as hit when either end point lies strictly inside it, or
    when both intersections of the segment's carrier line with the circle
    fall within the segment's extent. All circles are evaluated at once.

    Args:
        p1: First end point of the segment (read via ``.x`` and ``.y``).
        p2: Second end point of the segment (read via ``.x`` and ``.y``).
        circles: Object indexable by ``'x'``, ``'y'`` and ``'r'``, each
            yielding a NumPy array with one entry per circle (for example a
            structured array with those three fields).

    Returns:
        ``False`` when at least one circle is hit, ``True`` otherwise
        (including when there are no circles).
    """
    circle_ctr_x = circles['x']
    circle_ctr_y = circles['y']
    circle_radius = circles['r']

    # An end point strictly inside any circle is a hit on its own.
    for end in (p1, p2):
        if np.any(np.sqrt((circle_ctr_x - end.x) ** 2 + (circle_ctr_y - end.y) ** 2) < circle_radius):
            return False

    if p1.x == p2.x:
        # The slope form y = m*x + n does not exist for a vertical segment
        # (it would divide by zero). The line x = p1.x meets a circle at
        # y = cy +/- sqrt(r^2 - dx^2), so the extent is checked along y.
        min_y = min(p1.y, p2.y)
        max_y = max(p1.y, p2.y)

        half_chord_sq = circle_radius ** 2 - (circle_ctr_x - p1.x) ** 2
        crossed = half_chord_sq > 0  # drop circles the line misses or only touches
        half_chord = np.sqrt(half_chord_sq[crossed])
        ctr_y = circle_ctr_y[crossed]

        # lower root >= min_y and upper root <= max_y puts both roots in range
        in_range = (min_y <= ctr_y - half_chord) & (ctr_y + half_chord <= max_y)
        return not np.any(in_range)

    m = (p1.y - p2.y) / (p1.x - p2.x)
    n = p1.y - m * p1.x

    max_x = max(p1.x, p2.x)
    min_x = min(p1.x, p2.x)

    # Substitute y = m*x + n into each circle equation: a*x^2 + b*x + c = 0.
    a = m ** 2 + 1
    b = 2 * (m * n - m * circle_ctr_y - circle_ctr_x)
    c = circle_ctr_x ** 2 + circle_ctr_y ** 2 + n ** 2 - circle_radius ** 2 - 2 * n * circle_ctr_y

    discriminant = b ** 2 - 4 * a * c
    two_roots = discriminant > 0  # circles whose line intersection has two points
    if not two_roots.any():
        return True

    root = np.sqrt(discriminant[two_roots])
    b = b[two_roots]

    x1 = (-b + root) / (2 * a)
    x2 = (-b - root) / (2 * a)

    # Both intersection points must lie within the segment's x extent.
    in_range = (min_x <= x1) & (x1 <= max_x) & (min_x <= x2) & (x2 <= max_x)
    return not np.any(in_range)


# End points of the segment used for the profiling run.
a = Point(x=-2.47496075130008, y=1.3609840363748935)
b = Point(x=3.4637947060471084, y=-3.7779123453298817)

# The 39 circles the segment is tested against, one per line.
c = [
    Circle(r=1.2587063082677084, ctr=Point(x=3.618533781361757, y=2.179925931180058)),
    Circle(r=0.7625751871124099, ctr=Point(x=-0.3173290200183132, y=4.256206636932641)),
    Circle(r=0.4926043225930364, ctr=Point(x=-4.626312261120341, y=-1.5754603504419196)),
    Circle(r=0.6026364956540792, ctr=Point(x=3.775240278691819, y=1.7381168262343072)),
    Circle(r=1.2804597877349562, ctr=Point(x=4.403273380178893, y=-1.6890127555343681)),
    Circle(r=1.1562415624767421, ctr=Point(x=-1.0675000352105801, y=-0.23952113329203994)),
    Circle(r=1.112718432321835, ctr=Point(x=2.500137075066017, y=-2.77748519509295)),
    Circle(r=0.979889574640609, ctr=Point(x=4.494971251199753, y=-1.0530995423779388)),
    Circle(r=0.7817624050358268, ctr=Point(x=3.2419454348696544, y=4.3303373486692465)),
    Circle(r=1.0271176198616367, ctr=Point(x=-0.9740272820753071, y=-4.282195116754338)),
    Circle(r=1.1585218836700681, ctr=Point(x=-0.42096876790888915, y=2.135161027254492)),
    Circle(r=1.0242603387003988, ctr=Point(x=2.2617850544260767, y=-4.59942951839469)),
    Circle(r=1.5704233297828027, ctr=Point(x=-1.1182365440831088, y=4.2411408333943506)),
    Circle(r=0.37137272043983655, ctr=Point(x=3.280499587987774, y=-4.87871834733383)),
    Circle(r=1.1829610109115543, ctr=Point(x=-0.27755604766113606, y=-3.68429580935016)),
    Circle(r=1.0993567600839198, ctr=Point(x=0.23602306761027925, y=0.47530122196024704)),
    Circle(r=1.3865045367147553, ctr=Point(x=-2.537565761732492, y=4.719766182202855)),
    Circle(r=0.9492796511909753, ctr=Point(x=-3.7047245796551973, y=-2.501817905967274)),
    Circle(r=0.9866916911482386, ctr=Point(x=1.3021813533479742, y=4.754952371169189)),
    Circle(r=0.9053004331885084, ctr=Point(x=-3.4912157984801784, y=-0.5269727600532836)),
    Circle(r=1.3058987272565075, ctr=Point(x=-1.6983878085276427, y=-2.2910189455221053)),
    Circle(r=0.5342716756987732, ctr=Point(x=4.948676886704507, y=-1.2467089784975183)),
    Circle(r=1.0603926633240575, ctr=Point(x=-4.390462974765324, y=0.785568745976325)),
    Circle(r=0.3448422804513971, ctr=Point(x=-1.6459756952994697, y=2.7608629057950362)),
    Circle(r=0.8521457455807724, ctr=Point(x=-4.503217369041699, y=3.93796926957188)),
    Circle(r=0.602438849989669, ctr=Point(x=-2.0703406576157493, y=0.6142570312870999)),
    Circle(r=0.6453692950682722, ctr=Point(x=-0.14802220452893144, y=4.08189682338989)),
    Circle(r=0.6983361689325062, ctr=Point(x=0.09362196694661651, y=-1.0953438275586391)),
    Circle(r=1.880331563921456, ctr=Point(x=0.23481661751521776, y=-4.09217120864087)),
    Circle(r=0.5766225363413416, ctr=Point(x=3.149434524126505, y=-4.639582956406762)),
    Circle(r=0.6177559628867022, ctr=Point(x=-1.6758918144661683, y=-0.7954935787503492)),
    Circle(r=0.7347952666955615, ctr=Point(x=-3.1907522890427575, y=0.7048509241855683)),
    Circle(r=1.2795003337464894, ctr=Point(x=-1.777244415863577, y=2.936422879898364)),
    Circle(r=0.9181024765780231, ctr=Point(x=4.212544425778317, y=-1.953546993038261)),
    Circle(r=1.7681384709020282, ctr=Point(x=-1.3702722387909405, y=-1.7013020424154368)),
    Circle(r=0.5420789771729688, ctr=Point(x=4.063803796292818, y=-3.7159871611415065)),
    Circle(r=1.3863651881788939, ctr=Point(x=0.7685002210812408, y=-3.994230705171357)),
    Circle(r=0.5739750223225826, ctr=Point(x=0.08779554290638258, y=4.879912451441914)),
    Circle(r=1.2019825386919343, ctr=Point(x=-4.206623233886995, y=-1.1617382464768689)),
]

# Record layout for the array form of the circles: centre x, centre y, radius.
circle_dt = np.dtype([('x', float), ('y', float), ('r', float)])
# Same circles as in ``c``, packed into one structured array for ``vectorized``.
np_c = np.array(
    [(circle.ctr.x, circle.ctr.y, circle.r) for circle in c],
    dtype=circle_dt,
)


# Profile the plain-Python implementation line by line and print the report.
lp1 = LineProfiler()
loop_wrapper = lp1(loop)
loop_wrapper(a, b, c)
lp1.print_stats()

# Profile the NumPy implementation the same way, with its own profiler.
lp2 = LineProfiler()
vectorized_wrapper = lp2(vectorized)
vectorized_wrapper(a, b, np_c)
lp2.print_stats()

其中一个是常规的for循环实现,另一个是使用NumPy的向量化实现。以我对向量化的有限了解,我本以为向量化函数会更快,但正如下面的分析结果所示,情况并非总是如此:

Total time: 4.36e-05 s
Function: loop at line 31

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    31                                           def loop(p1: Point, p2: Point, circles: list[Circle]):
    32         1          9.0      9.0      2.1      m = (p1.y - p2.y) / (p1.x - p2.x)
    33         1          5.0      5.0      1.1      n = p1.y - m * p1.x
    34                                           
    35         1         19.0     19.0      4.4      max_x = max(p1.x, p2.x)
    36         1          5.0      5.0      1.1      min_x = min(p1.x, p2.x)
    37                                           
    38         6         30.0      5.0      6.9      for circle in circles:
    39         6         73.0     12.2     16.7          if sqrt((circle.ctr.x - p1.x) ** 2 + (circle.ctr.y - p1.y) ** 2) < circle.r \
    40         6         62.0     10.3     14.2                  or sqrt((circle.ctr.x - p2.x) ** 2 + (circle.ctr.y - p2.y) ** 2) < circle.r:
    41                                                       return False
    42                                           
    43         6         29.0      4.8      6.7          a = m ** 2 + 1
    44         6         32.0      5.3      7.3          b = 2 * (m * n - m * circle.ctr.y - circle.ctr.x)
    45         6         82.0     13.7     18.8          c = circle.ctr.x ** 2 + circle.ctr.y ** 2 + n ** 2 - circle.r ** 2 - 2 * n * circle.ctr.y
    46                                           
    47                                                   # compute the intersection points
    48         6         33.0      5.5      7.6          discriminant = b ** 2 - 4 * a * c
    49         5         11.0      2.2      2.5          if discriminant <= 0:
    50                                                       # no real roots, the line does not intersect the circle
    51         5         22.0      4.4      5.0              continue
    52                                           
    53                                                   # two real roots, the line intersects the circle at two points
    54         1          7.0      7.0      1.6          x1 = (-b + sqrt(discriminant)) / (2 * a)
    55         1          4.0      4.0      0.9          x2 = (-b - sqrt(discriminant)) / (2 * a)
    56                                           
    57                                                   # check if one point in range
    58         1          5.0      5.0      1.1          first = min_x < x1 < max_x
    59         1          3.0      3.0      0.7          second = min_x < x2 < max_x
    60         1          2.0      2.0      0.5          if first and second:
    61         1          3.0      3.0      0.7              return False
    62                                           
    63                                                   return True

Total time: 0.0001534 s
Function: vectorized at line 66

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    66                                           def vectorized(p1: Point, p2: Point, circles):
    67         1         10.0     10.0      0.7      m = (p1.y - p2.y) / (p1.x - p2.x)
    68         1          5.0      5.0      0.3      n = p1.y - m * p1.x
    69                                           
    70         1          7.0      7.0      0.5      max_x = max(p1.x, p2.x)
    71         1          4.0      4.0      0.3      min_x = min(p1.x, p2.x)
    72                                           
    73         1         10.0     10.0      0.7      circle_ctr_x = circles['x']
    74         1          3.0      3.0      0.2      circle_ctr_y = circles['y']
    75         1          3.0      3.0      0.2      circle_radius = circles['r']
    76                                           
    77                                               # Pt 1 inside circle
    78         1        652.0    652.0     42.5      if np.any(np.sqrt((circle_ctr_x - p1.x) ** 2 + (circle_ctr_y - p1.y) ** 2) < circle_radius):
    79                                                   return False
    80                                               # Pt 2 inside circle
    81         1        161.0    161.0     10.5      if np.any(np.sqrt((circle_ctr_x - p2.x) ** 2 + (circle_ctr_y - p2.y) ** 2) < circle_radius):
    82                                                   return False
    83                                               # Line intersects with circle in range
    84         1         13.0     13.0      0.8      a = m ** 2 + 1
    85         1        120.0    120.0      7.8      b = 2 * (m * n - m * circle_ctr_y - circle_ctr_x)
    86         1         77.0     77.0      5.0      c = circle_ctr_x ** 2 + circle_ctr_y ** 2 + n ** 2 - circle_radius ** 2 - 2 * n * circle_ctr_y
    87                                           
    88                                               # compute the intersection points
    89         1         25.0     25.0      1.6      discriminant = b**2 - 4*a*c
    90         1         46.0     46.0      3.0      discriminant_bigger_than_zero = discriminant > 0
    91         1         56.0     56.0      3.7      discriminant = discriminant[discriminant_bigger_than_zero]
    92                                           
    93         1          6.0      6.0      0.4      if discriminant.size == 0:
    94                                                   return True
    95                                           
    96         1         12.0     12.0      0.8      b = b[discriminant_bigger_than_zero]
    97                                           
    98                                               # two real roots, the line intersects the circle at two points
    99         1         77.0     77.0      5.0      x1 = (-b + np.sqrt(discriminant)) / (2 * a)
   100         1         28.0     28.0      1.8      x2 = (-b - np.sqrt(discriminant)) / (2 * a)
   101                                           
   102                                               # check if both points in range
   103         1         96.0     96.0      6.3      in_range = (min_x <= x1) & (x1 <= max_x) & (min_x <= x2) & (x2 <= max_x)
   104         1        123.0    123.0      8.0      return not np.any(in_range)


出于某些原因,非向量化函数运行得更快。
我的简单猜测是:向量化函数每次都要对整个数组进行运算,而非向量化函数在找到与圆的交点时就停止了。
所以我的问题是:
  1. 是否有numpy函数不会迭代整个数组,但在结果为false时停止?
  2. 为什么向量化函数需要更长时间来运行?
  3. 欢迎任何一般性的优化建议。

1
在这个特定的例子中,共有39个圆,而迭代在第6个圆处就结束了。前五个圆跳过了循环的一部分,因为discriminant <= 0为True。要确定第六个圆是否结束迭代,需要完成该圆的所有计算。同样,在向量化变体中,必须在每个步骤中处理所有圆;虽然您尝试提前返回(短路),但在这里没有任何帮助,因为始终至少有一个圆存在,而且您从未缩小工作集。 - Dan Mašek
1
你可能也在为中间变量的分配付出很多开销,算法的某些部分可以进行优化。对于良好的测试来说,样本大小39太小了——将其增加到100000(只需反复重复相同的几个步骤)。通过创建仅最后一个成功或中间一个成功的情景,使测试更加公平。 - Dan Mašek
你的程序中还存在一个bug。如果我从输入中移除第6和第7个圆,loop函数将返回True,但是vectorized函数会返回False。这两个函数并不执行相同的操作。我刚刚注意到,在判别式检验之后你确实减少了工作集,但为时已晚。 - Dan Mašek
好的,错误在loop中的return True缩进上。将其移至循环外部。可能是笔误。 - Dan Mašek
此外,这应该在少量的圆圈上运行(在十个左右),因此圆圈数量大约是那个大小。我相信随着样本大小的增长,情况会有所不同。 - Idan
显示剩余3条评论
1个回答

3
有没有一种numpy函数可以不遍历整个数组,但在结果为false时停止?
没有。这是Numpy用户长期要求的功能,但它肯定不会被添加到Numpy中。对于简单的情况,比如返回布尔数组的第一个索引,Numpy可以实现这个功能,但问题在于布尔数组需要先完全创建出来。为了支持一般情况,Numpy应该合并多个操作并进行某种形式的惰性计算。这基本上意味着为高效实现重新编写Numpy(这是一项巨大的工作)。
如果你需要这样做,有两个主要解决方案:
- 按块(chunk)进行操作,以便提前停止计算(提前停止时最多会多计算len(chunk)个元素);
- 使用Numba或Cython(带视图)编写自己的快速编译实现。
向量化函数为什么运行时间更长?
输入非常小,而且Numpy没有针对小数组进行优化。事实上,每次调用Numpy函数通常需要0.4-4微秒(例如在我的i5-9600KF处理器上)。这是因为Numpy有许多检查要执行,需要分配新数组,构建通用内部迭代器等等。因此,像np.any(np.sqrt((circle_ctr_x - p1.x) ** 2 + (circle_ctr_y - p1.y) ** 2) < circle_radius)这样的代码行执行了8个Numpy调用并创建了7个临时数组,在我的机器上大约需要8微秒。第二个类似的行也需要同样的时间。这两行加起来就已经比非向量化版本慢了。
正如问题和评论中指出的那样,非向量化函数可以提前停止,这也有助于非向量化版本比其他版本更快。
有什么一般性的优化建议?
关于你的代码,使用Numba(配合简单循环和Numpy数组)对性能肯定是一个好主意。注意,由于编译时间的原因,首次调用可能会更慢(你可以提供函数签名,使编译在加载时完成,或者直接使用AOT编译器,例如Cython)。
请注意,结构体数组(array of structures)通常效率不高,因为它们妨碍了SIMD指令的有效使用。此外,由于这种数据类型是动态创建的,而Numpy的代码是预先编译的(因此无法针对这一特定数据类型提供专门实现的函数),它们在Numpy中的计算效率也肯定不高:Numpy必须对数组的每个元素使用通用的动态操作,这比处理基本数据类型慢得多。请考虑改用由数组组成的结构(structure of arrays)。有关更多信息,请阅读这篇文章以及更具一般性的这篇文章

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接