Python Numpy向量化嵌套的for循环以进行组合计算

8

给定一个nxn的实数正数数组A,我正在尝试找到二维数组的所有三行元素最小值的最大值的最小值。使用for循环,代码如下:

import numpy as np

n = 100
np.random.seed(2)
A = np.random.rand(n,n)
global_best = np.inf

for i in range(n-2):
    for j in range(i+1, n-1):
        for k in range(j+1, n):
            # find the maximum of the element-wise minimum of the three vectors
            local_best = np.amax(np.array([A[i,:], A[j,:], A[k,:]]).min(0))
            # if local_best is lower than global_best, update global_best
            if (local_best < global_best):
                global_best = local_best
                save_rows = [i, j, k]

print global_best, save_rows

如果 n = 100,那么输出应该是这样的:

Out[]: 0.492652949593 [6, 41, 58]

我有一个感觉,使用Numpy向量化可以更快地完成这项任务,如果能得到任何帮助将不胜感激。谢谢。


你在寻找类似这样的东西吗?https://dev59.com/sarka4cB1Zd3GeqPgJrb#49821744 它并不完全相同,但获取3个组合而不是2个可能会解决问题。 - OriolAbril
n的大小是多少? - Brad Solomon
@BradSolomon 变量。我们就说始终<1000吧。 - ToneDaBass
1
如果您能提供一个小的输入和输出示例,那将会很有帮助。 - John Zwinck
1
@JohnZwinck 更新了一个可独立运行的脚本。您可以改变 n 的值来测试更大规模的问题。 - ToneDaBass
4个回答

6

这个方案在 n=100的情况下快了5倍:

coms = np.fromiter(itertools.combinations(np.arange(n), 3), 'i,i,i').view(('i', 3))
best = A[coms].min(1).max(1)
at = best.argmin()
global_best = best[at]
save_rows = coms[at]

第一行有点绕,但它将`itertools.combinations`的结果转换为一个包含所有可能的`[i,j,k]`索引组合的NumPy数组。然后,只需使用所有可能的索引组合来对`A`进行索引,然后沿着适当的轴进行缩减即可。这种解决方案消耗了更多的内存,因为它构建了所有可能组合`A[coms]`的具体数组。对于较小的`n`,比如小于250,可以节省时间,但是对于大的`n`,内存流量将非常高,并且可能比原始代码慢。

非常有趣。事实上,这正是我最初想要的思路,但随着规模的增大,内存问题确实成为了一个真正的问题。我尝试了n=430,结果代码崩溃了。这是一个很好的例子,说明随着问题规模的增加,内存最终会支配算法效率。 - ToneDaBass
在阅读这个答案之前,我不知道np.fromiter的存在,现在我已经成为了它的忠实粉丝。 - OriolAbril

5

按块处理可以结合矢量化计算的速度,同时避免遇到内存错误。下面是将嵌套循环转换为按块矢量化的示例。

从与问题相同的变量开始,定义一个块长度,以便在块内矢量化计算,并仅循环处理块,而不是组合。

chunk = 2000 # define chunk length, if to small, the code won't take advantage 
             # of vectorization, if it is too large, excessive memory usage will 
             # slow down execution, or Memory Error will be risen 
combinations = itertools.combinations(range(n),3) # generate iterator containing 
                                        # all possible combinations of 3 columns
N = n*(n-1)*(n-2)//6 # number of combinations (length of combinations cannot be 
                     # retrieved because it is an iterator)
# generate a list containing how many elements of combinations will be retrieved 
# per iteration
n_chunks, remainder = divmod(N,chunk)
counts_list = [chunk for _ in range(n_chunks)]
if remainder:
    counts_list.append(remainder)

# Iterate one chunk at a time, using vectorized code to treat the chunk
for counts in counts_list:
    # retrieve combinations in current chunk
    current_comb = np.fromiter(combinations,dtype='i,i,i',count=counts)\
                     .view(('i',3)) 
    # maximum of element-wise minimum in current chunk
    chunk_best = np.minimum(np.minimum(A[current_comb[:,0],:],A[current_comb[:,1],:]),
                            A[current_comb[:,2],:]).max(axis=1) 
    ravel_save_row = chunk_best.argmin() # minimum of maximums in current chunk
    # check if current chunk contains global minimum
    if chunk_best[ravel_save_row] < global_best: 
        global_best = chunk_best[ravel_save_row]
        save_rows = current_comb[ravel_save_row]
print(global_best,save_rows)

我使用嵌套循环进行了性能比较,得出以下结果 (chunk_length = 1000):
  • n=100
    • 嵌套循环: 1.13 秒 ± 16.6 毫秒
    • 分块处理:108 毫秒 ± 565 微秒
  • n=150
    • 嵌套循环: 4.16 秒 ± 39.3 毫秒
    • 分块处理:523 毫秒 ± 4.75 毫秒
  • n=500
    • 嵌套循环: 3 分钟 18 秒 ± 3.21 秒
    • 分块处理:1 分钟 12 秒 ± 1.6 秒

注意

通过对代码进行剖析,我发现 np.min 是最耗时间的部分,因为其调用了 np.maximum.reduce 。我直接将其转换为 np.maximum ,这样可以稍微提高一下性能。


1
这是一个很好的方法,可以在避免内存问题的同时获得快速计算的好处,我从这个帖子中学到了很多! - ToneDaBass

2
你可以使用Python标准库中的itertools来帮助你消除所有嵌套循环。
from itertools import combinations
import numpy as np

n = 100
np.random.seed(2)
A = np.random.rand(n,n)
global_best = 1000000000000000.0

for i, j, k in combinations(range(n), 3):
    local_best = np.amax(np.array([A[i,:], A[j,:], A[k,:]]).min(0))
    if local_best < global_best:
        global_best = local_best
        save_rows = [i, j, k]

print global_best, save_rows

1
是的,这确实使代码看起来更好。但它的运行速度几乎与三个嵌套循环相同。我在想是否有任何方法可以显着提高运行速度。 - ToneDaBass

2
不要试图向量化不易向量化的循环。相反,使用像Numba这样的jit编译器或使用Cython。如果结果代码更易读,则向量化解决方案是好的,但就性能而言,编译解决方案通常更快,或者在最坏的情况下与向量化解决方案一样快(除了BLAS例程)。
单线程示例
import numba as nb
import numpy as np

#Min and max library calls may be costly for only 3 values
@nb.njit()
def max_min_3(A,B,C):
  max_of_min=-np.inf
  for i in range(A.shape[0]):
    loc_min=A[i]
    if (B[i]<loc_min):
      loc_min=B[i]
    if (C[i]<loc_min):
      loc_min=C[i]

    if (max_of_min<loc_min):
      max_of_min=loc_min

  return max_of_min

@nb.njit()
def your_func(A):
  n=A.shape[0]
  save_rows=np.zeros(3,dtype=np.uint64)
  global_best=np.inf
  for i in range(n):
      for j in range(i+1, n):
          for k in range(j+1, n):
              # find the maximum of the element-wise minimum of the three vectors
              local_best = max_min_3(A[i,:], A[j,:], A[k,:])
              # if local_best is lower than global_best, update global_best
              if (local_best < global_best):
                  global_best = local_best
                  save_rows[0] = i
                  save_rows[1] = j
                  save_rows[2] = k

  return global_best, save_rows

单线程版本的性能

n=100
your_version: 1.56s
compiled_version: 0.0168s (92x speedup)

n=150
your_version: 5.41s
compiled_version: 0.08122s (66x speedup)

n=500
your_version: 283s
compiled_version: 8.86s (31x speedup)

第一次调用大约需要0.3-1秒的恒定开销。为了测量计算时间本身的性能,只需调用一次,然后测量性能。

通过少量代码更改,此任务也可以并行化。

多线程示例

@nb.njit(parallel=True)
def your_func(A):
  n=A.shape[0]
  all_global_best=np.inf
  rows=np.empty((3),dtype=np.uint64)

  save_rows=np.empty((n,3),dtype=np.uint64)
  global_best_Temp=np.empty((n),dtype=A.dtype)
  global_best_Temp[:]=np.inf

  for i in range(n):
      for j in nb.prange(i+1, n):
          row_1=0
          row_2=0
          row_3=0
          global_best=np.inf
          for k in range(j+1, n):
              # find the maximum of the element-wise minimum of the three vectors

              local_best = max_min_3(A[i,:], A[j,:], A[k,:])
              # if local_best is lower than global_best, update global_best
              if (local_best < global_best):
                  global_best = local_best
                  row_1 = i
                  row_2 = j
                  row_3 = k

          save_rows[j,0]=row_1
          save_rows[j,1]=row_2
          save_rows[j,2]=row_3
          global_best_Temp[j]=global_best

      ind=np.argmin(global_best_Temp)
      if (global_best_Temp[ind]<all_global_best):
          rows[0] = save_rows[ind,0]
          rows[1] = save_rows[ind,1]
          rows[2] = save_rows[ind,2]
          all_global_best=global_best_Temp[ind]

  return all_global_best, rows

多线程版本的性能

n=100
your_version: 1.56s
compiled_version: 0.0078s (200x speedup)

n=150
your_version: 5.41s
compiled_version: 0.0282s (191x speedup)

n=500
your_version: 283s
compiled_version: 2.95s (96x speedup)

编辑

在较新的Numba版本中(通过Anaconda Python分发安装),我需要手动安装tbb才能获得有效的并行化。


numba的后续版本似乎会在您的并行方法中出现错误,您知道如何修复吗? - ToneDaBass
@ToneDaBass 我在更新的 Numba 版本中遇到了相同的问题。安装 TBB 线程后端应该能够解决这个问题。 - max9111
问题似乎是由代码中的 save_rows[j,0]=row_1 部分引起的,导致出现以下错误 Cannot resolve setitem: array(uint64, 2d, C)[(int64, Literal[int](0))] = array(int64, 1d, C) - ToneDaBass
1
@ToneDaBass 我也稍微编辑了一下代码(插入row_1=0,row_2=0,row_3=0),似乎没有这个编辑,Numba无法识别线程本地变量row_1、row_2和row_3。 - max9111
我之前忘了评论并感谢您发布的修改,它们完美地解决了问题。谢谢! - ToneDaBass
显示剩余2条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接