生成一个二维NumPy数组的索引

4
我想生成一个二维的numpy数组,其中的元素是由它们的位置计算得出的。就像以下代码一样:
import numpy as np

def calculate_element(i, j, other_parameters):
    # do something
    return value_at_i_j

def main():
    arr = np.zeros((M, N))  # (M, N) is the shape of the array
    for i in range(M):
        for j in range(N):
            arr[i][j] = calculate_element(i, j, ...)

由于Python中的循环效率不高,因此此代码运行非常缓慢。在这种情况下有没有更快的方法?

顺便说一下,现在我使用一个变通方法通过计算两个2D的“索引矩阵”来完成。就像这样:

def main():
    index_matrix_i = np.array([range(M)] * N).T
    index_matrix_j = np.array([range(N)] * M)

    '''
    index_matrix_i is like
    [[0,0,0,...],
     [1,1,1,...],
     [2,2,2,...],
     ...
    ]

    index_matrix_j is like
    [[0,1,2,...],
     [0,1,2,...],
     [0,1,2,...],
     ...
    ]
    '''

    arr = calculate_element(index_matrix_i, index_matrix_j, ...)

编辑1:在应用“索引矩阵”技巧后,代码变得更快了,因此我想问的主要问题是是否有一种方法使用这个技巧,因为它需要更多的内存。简而言之,我希望有一个既时间高效空间高效的解决方案。
编辑2:以下是我测试过的一些例子。
# a simple 2D Gaussian
def calculate_element(i, j, i_mid, j_mid, i_sig, j_sig):
    gaus_i = np.exp(-((i - i_mid)**2) / (2 * i_sig**2))
    gaus_j = np.exp(-((j - j_mid)**2) / (2 * j_sig**2))
    return gaus_i * gaus_j

# size of M, N
M, N = 1200, 4000

# use for loops to go through every element
# this code takes ~10 seconds
def main_1():
    arr = np.zeros((M, N))  # (M, N) is the shape of the array
    for i in range(M):
        for j in range(N):
            arr[i][j] = calculate_element(i, j, 600, 2000, 300, 500)
    # print(arr)
    plt.figure(figsize=(8, 5))
    plt.imshow(arr, aspect='auto', origin='lower')
    plt.show()

# use index matrices
# this code takes <1 second
def main_2():
    index_matrix_i = np.array([range(M)] * N).T
    index_matrix_j = np.array([range(N)] * M)
    arr = calculate_element(index_matrix_i, index_matrix_j, 600, 2000, 300, 500)

    # print(arr)
    plt.figure(figsize=(8, 5))
    plt.imshow(arr, aspect='auto', origin='lower')
    plt.show()

calculate_element 的主体是什么?无论你如何分配,你仍然会调用这个函数 M*N 次。如果你可以缓存这个函数的中间结果,你就可以加速循环。 - Jeremy
使用“索引矩阵”技巧后,速度显著提升,因此我认为循环花费的时间太长,而不是函数本身。 - Alex Chen
1
请添加一个 [mre]。这看起来可以轻松地进行向量化(空间开销为 M * N * np.intp),使用 numba 可以轻松避免。但是,如果没有 calculate_element 的主体和 N,M 的大小估计,那只是一个猜测。 - Michael Szczesny
快速编译的 numpy 方法可以处理或生成整个数组。因此,即使在临时缓冲区中,它们也可能占用大量内存。从形状和 dtype 很容易估算内存使用情况。迭代可以避免那些大型临时数组,但这不会改变最终结果的大小。但是你会失去速度 - 除非你在编译工具(如 cythonnumba)中实现循环。 - hpaulj
3个回答

4
你可以使用np.indices()生成所需的输出:
例如,
np.indices((3, 4))

输出:

[[[0 0 0 0]
  [1 1 1 1]
  [2 2 2 2]]

 [[0 1 2 3]
  [0 1 2 3]
  [0 1 2 3]]]

哇!那是一种高效生成索引矩阵的方法!但是有没有办法减少内存使用? - Alex Chen
我找到了另一种方法来做这件事: index_matrix_i,index_matrix_j = np.meshgrid(range(M), range(N), indexing ='ij')。 速度与@BrokenBenchmark的方法相当,并且比原始索引矩阵生成方法快得多。 - Alex Chen
1
meshgrid 也可以使用。顺便提一下,如果您还没有这样做的话,请记得给有用的答案点赞并接受您认为最有帮助的答案。 - BrokenBenchmark

2

在我的双核机器上,矢量化的比简单的即时编译要快。

import numpy as np
import matplotlib.pyplot as plt

M, N = 1200, 4000
i = np.arange(M)
j = np.arange(N)
i_mid, j_mid, i_sig, j_sig = 600, 2000, 300, 500

arr = np.exp(-(i - i_mid)**2 / (2 * i_sig**2))[:,None] * np.exp(-(j - j_mid)**2 / (2 * j_sig**2))
# %timeit 100 loops, best of 5: 8.82 ms per loop

plt.figure(figsize=(8, 5))
plt.imshow(arr, aspect='auto', origin='lower')
plt.show()

numpy result

使用Jitted并行化的numba
import numba as nb  # tested with numba 0.55.1

@nb.njit(parallel=True)
def calculate_element_nb(i, j, i_mid, j_mid, i_sig, j_sig):
    res = np.empty((i,j), np.float32)
    for i in nb.prange(res.shape[0]):
        for j in range(res.shape[1]):
            res[i,j] = np.exp(-(i - i_mid)**2 / (2 * i_sig**2)) * np.exp(-(j - j_mid)**2 / (2 * j_sig**2))
    return res

M, N = 1200, 4000

calculate_element_nb(M, N, 600, 2000, 300, 500)
# %timeit 10 loops, best of 5: 80.4 ms per loop

plt.figure(figsize=(8, 5))
plt.imshow(calculate_element_nb(M, N, 600, 2000, 300, 500), aspect='auto', origin='lower')
plt.show()

numba result


1

您可以使用单个循环来填充多维列表,完成所有元素后,将其转换为 np.array,如下所示:

import numpy as np

m, n = 5, 5
arr = []
for i in range(0, m*n, n):
    arr.append(list(range(i, i+n)))
print(np.array(arr))

输出:

[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]
 [15 16 17 18 19]
 [20 21 22 23 24]]

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接