NumPy: 同时获取最大值和最小值的函数

163
numpy.amax()函数可以找到数组中的最大值,numpy.amin()函数可以找到最小值。如果我想要同时找到最大值和最小值,就必须调用这两个函数,这需要对(非常大的)数组进行两次传递,似乎很慢。

在numpy API中有一种函数可以在只经过一次数据时找到最大值和最小值吗?


1
有多大才算非常大?如果我有时间,我会运行一些测试,比较Fortran实现和amax以及amin - mgilson
1
我承认“非常大”是主观的。在我的情况下,我指的是几个GB的数组。 - Stuart Berg
1
当然,这并不是一个新颖的想法。例如,boost minmax 库(C++)提供了我正在寻找的算法的实现。 - Stuart Berg
顺便提一下,相关问题:https://dev59.com/Q2ox5IYBdhLWcg3wSCVW(虽然不完全相同,因为这个问题明确要求单次处理计算两个值) - Ken Arnold
9
虽然并非对所提出问题的答案,但这可能会引起该帖子中其他人的兴趣。在此问题中向 NumPy 提出添加“minmax”到库中的建议(https://github.com/numpy/numpy/issues/9836)。 - jakirkham
显示剩余3条评论
13个回答

73
在numpy API中没有一次通过数据找到最大值和最小值的函数。(如果有这样的功能,它的性能将比在大数组上连续调用numpy.amin()和numpy.amax()要明显得多。)。截至目前编写本文。

2
如果有人和我一样想知道使用仅两个索引的原地部分排序,例如:x.partition([0, x.size-1]) 是否比两次单独调用 x.min()x.max() 更快,那么答案是否定的。两次单独调用始终更快。我测试了以下 x 的大小:[1e2、1e3、1e4、1e5、1e6、1e7、1e8] 并测量了单独和部分排序的相对时间:[35.53%、5.10%、53.33%、36.89%、37.65%、36.03%、42.89%]。因此,如果您的数据向量具有 1 亿个样本,则与部分排序相比,最小值和最大值的单独调用仅使用 42.89% 的时间。 - Paloha

36
你可以使用Numba,它是一个基于LLVM的NumPy-aware动态Python编译器。生成的实现非常简单明了。
import numpy
import numba


@numba.jit
def minmax(x):
    maximum = x[0]
    minimum = x[0]
    for i in x[1:]:
        if i > maximum:
            maximum = i
        elif i < minimum:
            minimum = i
    return (minimum, maximum)


numpy.random.seed(1)
x = numpy.random.rand(1000000)
print(minmax(x) == (x.min(), x.max()))

它应该比NumPy的min()和max()实现更快。而且,所有这一切都不需要编写任何一行C/Fortran代码。

进行自己的性能测试,因为它始终取决于您的架构、数据和软件包版本...


2
它应该比Numpy的min()和max()实现更快。我认为这不对。Numpy不是本地Python - 它是C。x = numpy.random.rand(10000000) t = time() for i in range(1000): minmax(x) print('numba ', time() - t) t = time() for i in range(1000): x.min() x.max() print('numpy ', time() - t)结果为:('numba ', 10.299750089645386) ('numpy ', 9.898081064224243) - Authman Apatira
2
@AuthmanApatira:是的,基准测试总是这样,这就是为什么我说它“应该”(更快),并且“进行自己的性能测试,因为它始终取决于您的架构、数据等因素”。在我的情况下,我尝试了三台计算机,并得到了相同的结果(Numba比Numpy更快),但在您的计算机上,结果可能会有所不同...您是否尝试在基准测试之前执行numba函数以确保它已经JIT编译?此外,如果您使用ipython,为了简单起见,我建议您使用%timeit whatever_code()来测量时间执行。 - Peque
4
无论如何,我在这个答案中试图展示的是,有时候使用Python代码(本例中使用Numba进行JIT编译)可以和最快的C编译库一样快(至少我们在讨论同一数量级),考虑到我们只编写了纯Python代码,这是令人印象深刻的,你同意吗?^^ - Peque
我同意 =) 另外,感谢您在先前的评论中提供有关Jupyter和在计时代码之外编译函数的提示。 - Authman Apatira
4
刚刚看到这个,虽然在实际情况下没有什么影响,但是“elif”允许您的最小值大于最大值。例如,对于长度为1的数组,最大值将是该值,而最小值为+无穷大。对于一次性事件来说并不是什么大问题,但是将此代码放入生产环境中不是一个好的选择。 - Mike Williamson
1
@MikeWilliamson 你说得完全正确,感谢指出!我已经更新了我的答案,使用更好的“最小值”/“最大值”初始值。^^ - Peque

36

我认为两次遍历数组不是问题。 考虑以下伪代码:

minval = array[0]
maxval = array[0]
for i in array:
    if i < minval:
       minval = i
    if i > maxval:
       maxval = i

虽然这里只有一个循环,但仍然有两个检查(而非每个循环都只有一个检查)。你只能节约一个循环的开销。如果你所说的数组确实很大,那么与实际循环负载相比,这种开销就不算什么了。(请注意,这都是在C语言中实现的,因此这些循环本身几乎没有额外消耗)。


编辑 对不起,对于那四位投票支持我的人我感到非常抱歉。你们确实可以优化它。

以下是一些Fortran代码,可以通过f2py编译成Python模块(也许Cython大师可以来比较一下与优化后的C版本的区别...):

subroutine minmax1(a,n,amin,amax)
  implicit none
  !f2py intent(hidden) :: n
  !f2py intent(out) :: amin,amax
  !f2py intent(in) :: a
  integer n
  real a(n),amin,amax
  integer i

  amin = a(1)
  amax = a(1)
  do i=2, n
     if(a(i) > amax)then
        amax = a(i)
     elseif(a(i) < amin) then
        amin = a(i)
     endif
  enddo
end subroutine minmax1

subroutine minmax2(a,n,amin,amax)
  implicit none
  !f2py intent(hidden) :: n
  !f2py intent(out) :: amin,amax
  !f2py intent(in) :: a
  integer n
  real a(n),amin,amax
  amin = minval(a)
  amax = maxval(a)
end subroutine minmax2

通过以下方式进行编译:

f2py -m untitled -c fortran_code.f90

现在,我们可以进行测试:

import timeit

size = 100000
repeat = 10000

print timeit.timeit(
    'np.min(a); np.max(a)',
    setup='import numpy as np; a = np.arange(%d, dtype=np.float32)' % size,
    number=repeat), " # numpy min/max"

print timeit.timeit(
    'untitled.minmax1(a)',
    setup='import numpy as np; import untitled; a = np.arange(%d, dtype=np.float32)' % size,
    number=repeat), '# minmax1'

print timeit.timeit(
    'untitled.minmax2(a)',
    setup='import numpy as np; import untitled; a = np.arange(%d, dtype=np.float32)' % size,
    number=repeat), '# minmax2'

对我来说,这个结果有些惊人:

8.61869883537 # numpy min/max
1.60417699814 # minmax1
2.30169081688 # minmax2

我必须说,我并不完全理解它。 仅仅比较np.minminmax1以及minmax2还是处于劣势,所以不仅是一个内存问题...

- 将大小增加10 ** a倍,并将重复次数降低10 ** a倍(保持问题大小不变)确实会改变性能,但这并不是一种看似一致的方式,这表明在Python中存在某些内存性能和函数调用开销之间的相互作用。 即使是比较简单的Fortran中的min实现,也比numpy的快大约两倍...


25
单次遍历的优势在于节省内存。特别是如果您的数组足够大,需要被换出,这个优势会非常明显。 - Danica
5
这并不完全正确,实际上速度差不多只有一半,因为在这种类型的阵列中,存储器的速度通常是限制因素,所以速度可能会减半... - seberg
10
有时候不需要进行两次检查。如果“i < minval”为真,则“i > maxval”总是为假,因此当第二个“if”被替换为“elif”时,平均每次迭代只需要进行1.5次检查。 - Fred Foo
2
小提示:我怀疑Cython不是获得最优化的Python可调用C模块的方法。 Cython的目标是成为一种类型注释的Python,然后将其机器翻译为C,而f2py只是将手动编写的Fortran包装起来,使其可以由Python调用。一个更“公平”的测试可能是手动编写C,然后使用f2py(!)将其封装为Python。如果允许使用C ++,则Shed Skin可能是平衡编码易用性和性能的最佳选择。 - John Y
5
从NumPy 1.8版本开始,在amd64平台上,min和max已经向量化。在我的core2duo处理器上,NumPy同样表现得像这段Fortran代码一样好。但是,如果数组大小超过了更大的CPU缓存的大小,进行单次遍历会更有优势。 - jtaylor
显示剩余21条评论

33

如果你觉得有用的话,可以使用一个查找(最大值-最小值)的函数叫做numpy.ptp

>>> import numpy
>>> x = numpy.array([1,2,3,4,5,6])
>>> x.ptp()
5

但我认为没有一种方法可以在一次遍历中找到最小值和最大值。

编辑:ptp在底层调用了min和max


3
这很烦人,因为据推测,ptp的实现方式必须追踪最大值和最小值! - Andy Hayden
1
或者它只是调用max和min函数,不确定。 - jterrace
3
@hayden 原来 PTP 只是调用了最大值和最小值函数。 - jterrace
2
那是掩码数组代码;主要的ndarray代码是用C编写的。但事实证明,C代码也会对数组进行两次迭代:https://github.com/numpy/numpy/blob/29e6071d2376f9e96cd111adc7b274e07ac88e8d/numpy/core/src/multiarray/calculation.c#L318。 - Ken Arnold

28

只是为了获取一些关于数字的想法,考虑以下方法:

import numpy as np


def extrema_np(arr):
    return np.max(arr), np.min(arr)
import numba as nb


@nb.jit(nopython=True)
def extrema_loop_nb(arr):
    n = arr.size
    max_val = min_val = arr[0]
    for i in range(1, n):
        item = arr[i]
        if item > max_val:
            max_val = item
        elif item < min_val:
            min_val = item
    return max_val, min_val
import numba as nb


@nb.jit(nopython=True)
def extrema_while_nb(arr):
    n = arr.size
    odd = n % 2
    if not odd:
        n -= 1
    max_val = min_val = arr[0]
    i = 1
    while i < n:
        x = arr[i]
        y = arr[i + 1]
        if x > y:
            x, y = y, x
        min_val = min(x, min_val)
        max_val = max(y, max_val)
        i += 2
    if not odd:
        x = arr[n]
        min_val = min(x, min_val)
        max_val = max(x, max_val)
    return max_val, min_val
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


import numpy as np


cdef void _extrema_loop_cy(
        long[:] arr,
        size_t n,
        long[:] result):
    cdef size_t i
    cdef long item, max_val, min_val
    max_val = arr[0]
    min_val = arr[0]
    for i in range(1, n):
        item = arr[i]
        if item > max_val:
            max_val = item
        elif item < min_val:
            min_val = item
    result[0] = max_val
    result[1] = min_val


def extrema_loop_cy(arr):
    result = np.zeros(2, dtype=arr.dtype)
    _extrema_loop_cy(arr, arr.size, result)
    return result[0], result[1]
%%cython -c-O3 -c-march=native -a
#cython: language_level=3, boundscheck=False, wraparound=False, initializedcheck=False, cdivision=True, infer_types=True


import numpy as np


cdef void _extrema_while_cy(
        long[:] arr,
        size_t n,
        long[:] result):
    cdef size_t i, odd
    cdef long x, y, max_val, min_val
    max_val = arr[0]
    min_val = arr[0]
    odd = n % 2
    if not odd:
        n -= 1
    max_val = min_val = arr[0]
    i = 1
    while i < n:
        x = arr[i]
        y = arr[i + 1]
        if x > y:
            x, y = y, x
        min_val = min(x, min_val)
        max_val = max(y, max_val)
        i += 2
    if not odd:
        x = arr[n]
        min_val = min(x, min_val)
        max_val = max(x, max_val)
    result[0] = max_val
    result[1] = min_val


def extrema_while_cy(arr):
    result = np.zeros(2, dtype=arr.dtype)
    _extrema_while_cy(arr, arr.size, result)
    return result[0], result[1]

(extrema_loop_*()方式与此处提出的方法相似,而extrema_while_*()方式基于此处的代码)

以下时间:

bm

表明extrema_while_*()是最快的,其中extrema_while_nb()是最快的。在任何情况下,extrema_loop_nb()extrema_loop_cy()解决方案都优于仅使用np.max()np.min()的NumPy-only方法。

最后,请注意,这些方法都不如np.min()/np.max()灵活(在n维支持、axis参数等方面)。

(完整代码可在此处找到)


5
如果使用@njit(fastmath=True)修饰符来编写"extrema_while_nb",似乎可以获得额外的10%速度提升。请注意,翻译过程中不会添加解释或者其他内容。 - argenisleon

16

没有人提到numpy.percentile,所以我来提一下。如果你请求[0, 100]百分位数,它会给你一个包含两个元素的数组,最小值(0th百分位数)和最大值(100th百分位数)。

然而,它并不能满足原帖的目的:它不比单独使用min和max函数更快。这可能是由于某些机制导致无法计算极端以外的百分位数(这是一个更难的问题,应该需要更长的时间来解决)。

In [1]: import numpy

In [2]: a = numpy.random.normal(0, 1, 1000000)

In [3]: %%timeit
   ...: lo, hi = numpy.amin(a), numpy.amax(a)
   ...: 
100 loops, best of 3: 4.08 ms per loop

In [4]: %%timeit
   ...: lo, hi = numpy.percentile(a, [0, 100])
   ...: 
100 loops, best of 3: 17.2 ms per loop

In [5]: numpy.__version__
Out[5]: '1.14.4'
未来的Numpy版本可能会针对只请求 [0, 100] 的情况进行特殊处理以跳过正常的百分位数计算。尽管可以通过一次调用要求Numpy返回最小值和最大值(与被接受的答案中所说的相反),但库的标准实现没有利用这种情况使其更有价值。

13

通常情况下,您可以通过同时处理两个元素并仅将较小的元素与临时最小值进行比较,将较大的元素与临时最大值进行比较,从而减少minmax算法的比较次数。平均而言,相较于朴素方法,只需要三分之四左右的比较次数。

这可以在C或Fortran(或任何其他低级语言)中实现,并且在性能方面几乎是无与伦比的。我使用来说明原理并获得非常快速、独立于数据类型的实现:

import numba as nb
import numpy as np

@nb.njit
def minmax(array):
    # Ravel the array and return early if it's empty
    array = array.ravel()
    length = array.size
    if not length:
        return

    # We want to process two elements at once so we need
    # an even sized array, but we preprocess the first and
    # start with the second element, so we want it "odd"
    odd = length % 2
    if not odd:
        length -= 1

    # Initialize min and max with the first item
    minimum = maximum = array[0]

    i = 1
    while i < length:
        # Get the next two items and swap them if necessary
        x = array[i]
        y = array[i+1]
        if x > y:
            x, y = y, x
        # Compare the min with the smaller one and the max
        # with the bigger one
        minimum = min(x, minimum)
        maximum = max(y, maximum)
        i += 2

    # If we had an even sized array we need to compare the
    # one remaining item too.
    if not odd:
        x = array[length]
        minimum = min(x, minimum)
        maximum = max(x, maximum)

    return minimum, maximum

这绝对比Peque提出的朴素方法更快:

arr = np.random.random(3000000)
assert minmax(arr) == minmax_peque(arr)  # warmup and making sure they are identical 
%timeit minmax(arr)            # 100 loops, best of 3: 2.1 ms per loop
%timeit minmax_peque(arr)      # 100 loops, best of 3: 2.75 ms per loop

正如预期的那样,新的minmax实现所需时间仅为原始实现时间的3/4 (2.1 / 2.75 = 0.7636363636363637)


1
在我的电脑上,你的解决方案不比Peque的快。Numba 0.33。 - John Zwinck
@johnzwinck 你有运行我回答中的基准测试吗?如果是不同的,请分享一下。但是也有可能:我注意到新版本中存在一些退化。 - MSeifert
我运行了你的基准测试。你的解决方案和@Peque的时间几乎相同(约2.8毫秒)。 - John Zwinck
@JohnZwinck 这很奇怪,我刚刚再次测试了一下,在我的电脑上它明显更快。也许这与依赖于硬件的numba和LLVM有关。 - MSeifert
我现在在另一台机器上尝试了一下(一台强大的工作站),你的用时是2.4毫秒,Peque的用时是2.6毫秒。所以,这是一个小胜利。 - John Zwinck
好的,我在我的旧机器上得到了8毫秒(我的)与11毫秒(peque)。这绝对很奇怪。:( - MSeifert

9

这是一个旧的帖子,但如果有人再次查看它...

在同时寻找最小值和最大值时,可以减少比较次数。如果你在比较浮点数(我猜你是),这可能会节省一些时间,但不会影响计算复杂度。

代替(Python代码):

_max = ar[0]
_min=  ar[0]
for ii in xrange(len(ar)):
    if _max > ar[ii]: _max = ar[ii]
    if _min < ar[ii]: _min = ar[ii]

您可以先比较数组中相邻的两个值,然后只将较小的与当前最小值进行比较,将较大的与当前最大值进行比较:

## for an even-sized array
_max = ar[0]
_min = ar[0]
for ii in xrange(0, len(ar), 2)):  ## iterate over every other value in the array
    f1 = ar[ii]
    f2 = ar[ii+1]
    if (f1 < f2):
        if f1 < _min: _min = f1
        if f2 > _max: _max = f2
    else:
        if f2 < _min: _min = f2
        if f1 > _max: _max = f1

这里的代码是用Python编写的,显然为了速度,你可以使用C、Fortran或Cython,但这样做每次迭代需要进行3次比较,对于len(ar)/2次迭代,给出3/2 * len(ar)次比较。相反,采用“显而易见”的比较方法,每次迭代需要进行2次比较,从而导致2*len(ar)次比较。节省了25%的比较时间。
也许有一天会有人觉得这个有用。

6
你有进行基准测试吗?在现代的x86硬件中,你可以使用min和max的机器指令,就像第一种变体中所使用的那样,这些指令避免了分支的需要,而你的代码引入了控制依赖项,这可能不太适配硬件。 - jtaylor
我实际上还没有。如果有机会的话,我会去做。我认为很明显,纯Python代码在与任何明智的编译实现相比都会输掉,但我想知道在Cython中是否可以看到加速。 - Bennet
15
在numpy中有一个minmax实现,底层被np.bincount使用,详情见这里。它不使用你所指出的技巧,因为它比朴素方法慢了多达两倍。在PR中有一个链接指向两种方法的全面基准测试。 - Jaime

6

乍一看,numpy.histogram 看起来可以解决问题:

count, (amin, amax) = numpy.histogram(a, bins=1)

...但是,如果您查看该函数的源代码,它只是独立调用a.min()a.max(),因此无法避免本问题中提到的性能问题。 :-(

同样地,scipy.ndimage.measurements.extrema看起来也是一个可能的选择,但是它也只是独立调用a.min()a.max()


3
np.histogram 并不总是适用于此,因为返回的 (amin, amax) 值是箱子中最小值和最大值。例如,如果我有 a = np.zeros(10),那么 np.histogram(a, bins=1) 返回 (array([10]), array([-0.5, 0.5]))。在这种情况下,用户正在寻找 (amin, amax) = (0, 0)。 - eclark

3

无论如何,这对我来说都是值得努力的,所以我会在此提出最困难且最不优雅的解决方案,供有兴趣的人参考。我的解决方案是在C++中实现多线程的一趟最小/最大算法,并使用它创建一个Python扩展模块。这种努力需要花费一些时间来学习如何使用Python和NumPy C/C++ API,在这里我将展示代码,并为那些希望走上这条道路的人提供一些简要的说明和参考。

多线程的最小/最大

这里没有什么特别有趣的地方。数组被分成大小为length/workers的块。每个块的最小/最大值在future中计算,然后在全局范围内扫描。

    // mt_np.cc
    //
    // multi-threaded min/max algorithm

    #include <algorithm>
    #include <future>
    #include <vector>

    namespace mt_np {

    /*
     * Get {min,max} in interval [begin,end)
     */
    template <typename T> std::pair<T, T> min_max(T *begin, T *end) {
      T min{*begin};
      T max{*begin};
      while (++begin < end) {
        if (*begin < min) {
          min = *begin;
          continue;
        } else if (*begin > max) {
          max = *begin;
        }
      }
      return {min, max};
    }

    /*
     * get {min,max} in interval [begin,end) using #workers for concurrency
     */
    template <typename T>
    std::pair<T, T> min_max_mt(T *begin, T *end, int workers) {
      const long int chunk_size = std::max((end - begin) / workers, 1l);
      std::vector<std::future<std::pair<T, T>>> min_maxes;
      // fire up the workers
      while (begin < end) {
        T *next = std::min(end, begin + chunk_size);
        min_maxes.push_back(std::async(min_max<T>, begin, next));
        begin = next;
      }
      // retrieve the results
      auto min_max_it = min_maxes.begin();
      auto v{min_max_it->get()};
      T min{v.first};
      T max{v.second};
      while (++min_max_it != min_maxes.end()) {
        v = min_max_it->get();
        min = std::min(min, v.first);
        max = std::max(max, v.second);
      }
      return {min, max};
    }
    }; // namespace mt_np

Python扩展模块

在这里,事情开始变得棘手...在Python中使用C ++代码的一种方法是实现一个扩展模块。可以使用 distutils.core 标准模块来构建和安装此模块。这包含在Python文档中: https://docs.python.org/3/extending/extending.html。注意:肯定有其他方法可以获得类似的结果,引用 https://docs.python.org/3/extending/index.html#extending-index:

此指南仅介绍作为CPython版本的基本创建扩展的工具。第三方工具,如 Cython、cffi、SWIG 和 Numba,提供了创建 C 和 C++ 扩展的更简单和更复杂的方法。

基本上,这条路线可能更加学术而不是实际应用。话虽如此,接下来我会遵循教程并创建一个模块文件。这本质上是distutils用于知道如何处理您的代码并将其创建为Python模块的样板。在进行任何操作之前,最好创建一个Python虚拟环境,以便不污染系统软件包(请参阅 https://docs.python.org/3/library/venv.html#module-venv)。

这是模块文件:

// mt_np_forpy.cc
//
// C++ module implementation for multi-threaded min/max for np

#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION

#include <python3.6/numpy/arrayobject.h>

#include "mt_np.h"

#include <cstdint>
#include <iostream>

using namespace std;

/*
 * check:
 *  shape
 *  stride
 *  data_type
 *  byteorder
 *  alignment
 */
static bool check_array(PyArrayObject *arr) {
  if (PyArray_NDIM(arr) != 1) {
    PyErr_SetString(PyExc_RuntimeError, "Wrong shape, require (1,n)");
    return false;
  }
  if (PyArray_STRIDES(arr)[0] != 8) {
    PyErr_SetString(PyExc_RuntimeError, "Expected stride of 8");
    return false;
  }
  PyArray_Descr *descr = PyArray_DESCR(arr);
  if (descr->type != NPY_LONGLTR && descr->type != NPY_DOUBLELTR) {
    PyErr_SetString(PyExc_RuntimeError, "Wrong type, require l or d");
    return false;
  }
  if (descr->byteorder != '=') {
    PyErr_SetString(PyExc_RuntimeError, "Expected native byteorder");
    return false;
  }
  if (descr->alignment != 8) {
    cerr << "alignment: " << descr->alignment << endl;
    PyErr_SetString(PyExc_RuntimeError, "Require proper alignement");
    return false;
  }
  return true;
}

template <typename T>
static PyObject *mt_np_minmax_dispatch(PyArrayObject *arr) {
  npy_intp size = PyArray_SHAPE(arr)[0];
  T *begin = (T *)PyArray_DATA(arr);
  auto minmax =
      mt_np::min_max_mt(begin, begin + size, thread::hardware_concurrency());
  return Py_BuildValue("(L,L)", minmax.first, minmax.second);
}

static PyObject *mt_np_minmax(PyObject *self, PyObject *args) {
  PyArrayObject *arr;
  if (!PyArg_ParseTuple(args, "O", &arr))
    return NULL;
  if (!check_array(arr))
    return NULL;
  switch (PyArray_DESCR(arr)->type) {
  case NPY_LONGLTR: {
    return mt_np_minmax_dispatch<int64_t>(arr);
  } break;
  case NPY_DOUBLELTR: {
    return mt_np_minmax_dispatch<double>(arr);
  } break;
  default: {
    PyErr_SetString(PyExc_RuntimeError, "Unknown error");
    return NULL;
  }
  }
}

static PyObject *get_concurrency(PyObject *self, PyObject *args) {
  return Py_BuildValue("I", thread::hardware_concurrency());
}

static PyMethodDef mt_np_Methods[] = {
    {"mt_np_minmax", mt_np_minmax, METH_VARARGS, "multi-threaded np min/max"},
    {"get_concurrency", get_concurrency, METH_VARARGS,
     "retrieve thread::hardware_concurrency()"},
    {NULL, NULL, 0, NULL} /* sentinel */
};

static struct PyModuleDef mt_np_module = {PyModuleDef_HEAD_INIT, "mt_np", NULL,
                                          -1, mt_np_Methods};

PyMODINIT_FUNC PyInit_mt_np() { return PyModule_Create(&mt_np_module); }

在这个文件中,有大量使用Python和NumPy API的内容,更多信息请参考:https://docs.python.org/3/c-api/arg.html#c.PyArg_ParseTuple,以及NumPy:https://docs.scipy.org/doc/numpy/reference/c-api.array.html

安装模块

接下来要做的是利用distutils安装模块。这需要一个setup文件:
# setup.py

from distutils.core import setup,Extension

module = Extension('mt_np', sources = ['mt_np_module.cc'])

setup (name = 'mt_np', 
       version = '1.0', 
       description = 'multi-threaded min/max for np arrays',
       ext_modules = [module])

最终安装模块,请在虚拟环境中执行 python3 setup.py install

测试该模块

最后,我们可以测试一下 C++ 实现是否比使用 NumPy 的朴素方法更快。为了这样做,下面是一个简单的测试脚本:

# timing.py
# compare numpy min/max vs multi-threaded min/max

import numpy as np
import mt_np
import timeit

def normal_min_max(X):
  return (np.min(X),np.max(X))

print(mt_np.get_concurrency())

for ssize in np.logspace(3,8,6):
  size = int(ssize)
  print('********************')
  print('sample size:', size)
  print('********************')
  samples = np.random.normal(0,50,(2,size))
  for sample in samples:
    print('np:', timeit.timeit('normal_min_max(sample)',
                 globals=globals(),number=10))
    print('mt:', timeit.timeit('mt_np.mt_np_minmax(sample)',
                 globals=globals(),number=10))

这是我通过做这一切得到的结果:

8  
********************  
sample size: 1000  
********************  
np: 0.00012079699808964506  
mt: 0.002468645994667895  
np: 0.00011947099847020581  
mt: 0.0020772050047526136  
********************  
sample size: 10000  
********************  
np: 0.00024697799381101504  
mt: 0.002037393998762127  
np: 0.0002713389985729009  
mt: 0.0020942929986631498  
********************  
sample size: 100000  
********************  
np: 0.0007130410012905486  
mt: 0.0019842900001094677  
np: 0.0007540129954577424  
mt: 0.0029724110063398257  
********************  
sample size: 1000000  
********************  
np: 0.0094779249993735  
mt: 0.007134920000680722  
np: 0.009129883001151029  
mt: 0.012836456997320056  
********************  
sample size: 10000000  
********************  
np: 0.09471094200125663  
mt: 0.0453535050037317  
np: 0.09436299200024223  
mt: 0.04188535599678289  
********************  
sample size: 100000000  
********************  
np: 0.9537652180006262  
mt: 0.3957935369980987  
np: 0.9624398809974082  
mt: 0.4019058070043684  

这些结果比此前在线程中指示的速度提升约3.5倍的结果要逊色得多,而且没有结合多线程。我获得的结果还算合理,我预计,在数组足够大之前,线程的开销会占主导地位,此时性能提升将开始接近std::thread ::hardware_concurrency x增加。

结论

针对一些NumPy代码,特别是与多线程相关的代码,显然存在应用程序特定优化的空间。是否值得付出努力并不清楚,但肯定是一个很好的锻炼(或其他东西)。我认为,也许学习像Cython这样的“第三方工具”可能是更好的时间利用方式,但谁知道呢。


1
我开始学习你的代码,了解一些C++,但仍未使用std::future和std::async。在你的“min_max_mt”模板函数中,它是如何知道在触发和检索结果之间每个工作线程都已完成的呢?(只是为了理解而问,并不是说有什么问题) - ChrCury78
该行代码 v = min_max_it->get();get 方法会阻塞直到结果准备好并返回它。由于循环遍历每个 future,因此在所有 future 完成之前它不会结束。future.get() - Nathan Chappell

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接