为什么numpy.var的空间复杂度是O(N)?

3

我有一个大约13GB的数组。 我调用 numpy.var 来计算方差。但是,它会再次分配大约13GB的内存来执行此操作。为什么它需要O(N)的空间?或者我使用 numpy.var 的方式有误?

import numpy as np
# data = ...
print('Variance: ', np.var(data))

也许这取决于你的数组维度。 - Devesh Kumar Singh
该数组为1.2M乘以2.8K的float32 - Serge Rogatch
我建议查看文档 https://docs.scipy.org/doc/numpy/reference/generated/numpy.var.html 并使用较小的数组调整参数以检查发生了什么,以及方差是否符合您的预期。 - Devesh Kumar Singh
3
var 函数使用“平均值与各元素差的平方和”来计算方差。因此,在求和之前需要临时计算(数据-数据均值)**2numpy 经常将多个整个数组的计算组合在一起。它通过使用编译的代码来执行常见任务,而不是节省内存来提高速度。 - hpaulj
1
当NumPy创建巨大的临时数组时,不要感到惊讶。NumPy的设计使得巨大的临时数组成为几乎任何事情的最高效解决方案的一部分,即使显式循环可能更节省空间。 - user2357112
2个回答

3

为了计算方差,NumPy将创建一个中间数组来计算abs(data - data.mean()) ** 2。您可以使用循环编写自己的方差函数,并使用Numba使其更快:

import numpy as np
import numba as nb

@nb.njit(parallel=True)
def var_nb(a, ddof=0):
    n = len(a)
    s = a.sum()
    m = s / (n - ddof)
    v = 0
    for i in nb.prange(n):
        v += abs(a[i] - m) ** 2
    return v / (n - ddof)

np.random.seed(100)
a = np.random.rand(100_000)
print(np.var(a))
# 0.08349747560941487
print(var_nb(a))
# 0.08349747560941487

%timeit np.var(a)
# 143 µs ± 414 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit var_nb(a)
# 40.2 µs ± 530 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

1

这是没有并行化的更快的方式:

import numpy as np


def var(a: np.ndarray, axis: int = 0):
return np.sum(abs(a - (a.sum(axis=axis) / len(a))) ** 2, axis=axis) / len(a)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接