如何高效地计算一个运行标准差

97
我有一个包含数字列表的数组,例如:
[0] (0.01, 0.01, 0.02, 0.04, 0.03)
[1] (0.00, 0.02, 0.02, 0.03, 0.02)
[2] (0.01, 0.02, 0.02, 0.03, 0.02)
     ...
[n] (0.01, 0.00, 0.01, 0.05, 0.03)

我想要高效地计算列表中每个索引的平均值和标准差,跨越所有数组元素。
为了计算平均值,我一直在循环遍历数组,并将给定索引处的值相加。最后,我将我的“平均值列表”中的每个值除以 n(我正在处理的是总体,而不是总体的样本)。
为了计算标准差,现在我已经计算出了平均值,所以我再次进行循环遍历。
我希望能够避免两次遍历数组,一次用于计算平均值,一次用于计算标准差(在我有了平均值之后)。
是否有一种高效的方法可以同时计算这两个值,只需一次遍历数组?任何解释性语言(例如 Perl 或 Python)或伪代码都可以。

7
不同的语言,但是相同的算法:https://dev59.com/4nNA5IYBdhLWcg3wmfK5 - dmckee --- ex-moderator kitten
1
此外,在http://rosettacode.org/wiki/Standard_Deviation上还有几个示例。 - glenn jackman
1
维基百科有一个Python实现 http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#On-line_algorithm - Hamish Grubijan
1
也许可以尝试使用现有的numpy实现Welford算法:https://github.com/a-mitani/welford - Alex Reynolds
1
@AlexReynolds 我在考虑将它翻译成PyTorch,因此需要复制粘贴部分...但也许可以用于那个或者只是插入和播放...谢谢,稍后会看一下。 - Charlie Parker
显示剩余3条评论
17个回答

138

答案是使用Welford算法,在“naive methods”之后非常清晰地定义在:

相比其他方法中提到的两遍遍历或在线简单平方和收集器,它更具数值稳定性。当您有许多接近数值时,稳定性才真正重要,因为这些接近数值会导致浮点数文献中所称的“灾难性取消”。

您也可能想要了解在方差计算(平方偏差)中通过样本数量(N)和N-1进行除法的区别。通过N-1进行除法可以得到来自样本的无偏估计方差,而通过N进行除法则平均低估方差(因为它未考虑样本均值与真实均值之间的方差)。


5
感谢您关注从 Welford 算法中删除值的问题,为此给您点赞(+1)。我会在翻译过程中尽力保留原文意思,同时使翻译更加通俗易懂。 - Svisstack
6
回答不错,加一分提醒读者如何区分总体标准差和样本标准差。 - Assad Ebrahim
但它也更慢,所以这是数值稳定性和性能之间的权衡(换句话说,这取决于情况)。 "...在循环内部进行除法运算,因此可能效率不高。" - Peter Mortensen
这不应该是被接受的答案。正如@PeterMortensen所说,在循环内添加除法运算符对现代硬件来说太慢而不实用。当像GPT这样的模型在NVIDIA GPU上运行时,默认情况下使用Jonathan的答案(请参见FasterTransformer中的LayerNorm kernel)。在极少数需要更高准确性的情况下,会使用双通道算法。 - Gavin Uberti
这不应该是被接受的答案。正如@PeterMortensen所说,在循环内添加除法运算符对于现代硬件来说速度太慢,不实用。当像GPT这样的模型在NVIDIA GPU上运行时,默认情况下使用Jonathan的答案(请参见FasterTransformer中的LayerNorm kernel)。在极少数需要更高精度的情况下,会使用两次传递的算法。 - undefined

82

基本答案是在计算过程中累加x的总和(称为'sum_x1')和x2的总和(称为'sum_x2')。标准偏差的值就是:

stdev = sqrt((sum_x2 / n) - (mean * mean)) 

在哪里

mean = sum_x / n

这是样本标准差; 如果使用“n”而不是“n-1”作为除数,则可以获得总体标准差。

如果处理大样本,可能需要担心两个大数之间的差异对数值稳定性的影响。有关更多信息,请参阅其他答案中的外部引用(维基百科等)。


2
我决定采用Welford算法,因为它在相同的计算开销下表现更加可靠。 - Alex Reynolds
4
这是答案的简化版本,根据输入可能会产生非真实结果(例如,当sum_x2 < sum_x1 * sum_x1时)。为确保获得有效的真实结果,请使用 sd = sqrt(((n * sum_x2) - (sum_x1 * sum_x1)) / (n * (n - 1))) - Dan Tao
4
@Dan指出了一个有效的问题-上面的公式在x>1时会失效,因为你最终会取一个负数的平方根。Knuth的方法是:sqrt((sum_x2 / n) - (mean * mean)),其中mean = (sum_x / n)。 - G__
@flies:自从我在一年前留下这条评论以来,答案已经发生了变化,而Greg则在两个月前离开了。使用的公式曾经是sqrt((sum_x2 - sum_x1 * sum_x1) / (n - 1)),但除非我弄错了,否则实际上是不正确的。 - Dan Tao
1
@UriLoya — 你还没有说你是如何计算这些值的。然而,如果你在C语言中使用int来存储平方和,你会遇到溢出问题,尤其是对于你列出的这些值。 - Jonathan Leffler
显示剩余3条评论

60

这里是John D. Cook在他的优秀文章“准确计算运行方差”中实现的Welford算法的纯Python字面翻译:

文件 running_stats.py

import math

class RunningStats:

    def __init__(self):
        self.n = 0
        self.old_m = 0
        self.new_m = 0
        self.old_s = 0
        self.new_s = 0

    def clear(self):
        self.n = 0

    def push(self, x):
        self.n += 1

        if self.n == 1:
            self.old_m = self.new_m = x
            self.old_s = 0
        else:
            self.new_m = self.old_m + (x - self.old_m) / self.n
            self.new_s = self.old_s + (x - self.old_m) * (x - self.new_m)

            self.old_m = self.new_m
            self.old_s = self.new_s

    def mean(self):
        return self.new_m if self.n else 0.0

    def variance(self):
        return self.new_s / (self.n - 1) if self.n > 1 else 0.0

    def standard_deviation(self):
        return math.sqrt(self.variance())

使用方法:

rs = RunningStats()
rs.push(17.0)
rs.push(19.0)
rs.push(24.0)

mean = rs.mean()
variance = rs.variance()
stdev = rs.standard_deviation()

print(f'Mean: {mean}, Variance: {variance}, Std. Dev.: {stdev}')

15
作为唯一正确并展示了算法且提及了Knuth的答案,这应该是被接受的回答。 - Johan Lundberg
对于最近编辑此答案的贡献者,我不得不拒绝你们的编辑,因为我认为它是不正确的。该编辑删除了 push 方法中的 n == 1 特殊情况,但我认为这种情况在使用 clear() 方法后产生正确的结果是必需的,我怀疑你们忽视了这一点。 - Marc Liyanage

26

也许不是你所问的,但是...如果你使用一个NumPy数组,它会高效地完成工作:

from numpy import array

nums = array(((0.01, 0.01, 0.02, 0.04, 0.03),
              (0.00, 0.02, 0.02, 0.03, 0.02),
              (0.01, 0.02, 0.02, 0.03, 0.02),
              (0.01, 0.00, 0.01, 0.05, 0.03)))

print nums.std(axis=1)
# [ 0.0116619   0.00979796  0.00632456  0.01788854]

print nums.mean(axis=1)
# [ 0.022  0.018  0.02   0.02 ]

19

Python runstats Module适用于这种情况。从PyPI安装runstats

pip install runstats

Runstats摘要可以在一次数据通行中生成均值、方差、标准差、偏度和峰度。我们可以利用这个来创建您的“运行”版本。

from runstats import Statistics

stats = [Statistics() for num in range(len(data[0]))]

for row in data:

    for index, val in enumerate(row):
        stats[index].push(val)

    for index, stat in enumerate(stats):
        print 'Index', index, 'mean:', stat.mean()
        print 'Index', index, 'standard deviation:', stat.stddev()

统计摘要基于Knuth和Welford方法,在一次遍历中计算标准偏差,如《计算机程序设计艺术》第2卷第232页第3版所述。这样做的好处是获得数值稳定和准确的结果。

免责声明:我是Python runstats模块的作者。


1
不错的模块。如果有一个“统计”方法有一个“.pop”方法,那么滚动统计也可以被计算出来,这将会很有趣。 - Gustavo Bezerra
@GustavoBezerra “runstats” 不维护内部值列表,所以我不确定是否有可能。但欢迎提交拉取请求。 - GrantJ

9

Statistics::Descriptive 是一个非常不错的 Perl 模块,可以用于这些类型的计算:

#!/usr/bin/perl

use strict; use warnings;

use Statistics::Descriptive qw( :all );

my $data = [
    [ 0.01, 0.01, 0.02, 0.04, 0.03 ],
    [ 0.00, 0.02, 0.02, 0.03, 0.02 ],
    [ 0.01, 0.02, 0.02, 0.03, 0.02 ],
    [ 0.01, 0.00, 0.01, 0.05, 0.03 ],
];

my $stat = Statistics::Descriptive::Full->new;
# You also have the option of using sparse data structures

for my $ref ( @$data ) {
    $stat->add_data( @$ref );
    printf "Running mean: %f\n", $stat->mean;
    printf "Running stdev: %f\n", $stat->standard_deviation;
}
__END__

输出:

Running mean: 0.022000
Running stdev: 0.013038
Running mean: 0.020000
Running stdev: 0.011547
Running mean: 0.020000
Running stdev: 0.010000
Running mean: 0.020000
Running stdev: 0.012566

__END__”是什么?它是否必要? - Peter Mortensen
我正在使用它来标记脚本的结尾。请参考“你应该知道的Perl标记”。 - Sinan Ünür

8

请查看PDL(发音为“piddle!”)。

这是Perl数据语言,专为高精度数学和科学计算而设计。

以下是使用您的数字的示例....

use strict;
use warnings;
use PDL;

my $figs = pdl [
    [0.01, 0.01, 0.02, 0.04, 0.03],
    [0.00, 0.02, 0.02, 0.03, 0.02],
    [0.01, 0.02, 0.02, 0.03, 0.02],
    [0.01, 0.00, 0.01, 0.05, 0.03],
];

my ( $mean, $prms, $median, $min, $max, $adev, $rms ) = statsover( $figs );

say "Mean scores:     ", $mean;
say "Std dev? (adev): ", $adev;
say "Std dev? (prms): ", $prms;
say "Std dev? (rms):  ", $rms;

这将产生:

Mean scores:     [0.022 0.018 0.02 0.02]
Std dev? (adev): [0.0104 0.0072 0.004 0.016]
Std dev? (prms): [0.013038405 0.010954451 0.0070710678 0.02]
Std dev? (rms):  [0.011661904 0.009797959 0.0063245553 0.017888544]

请查看 PDL::Primitive 了解更多关于 statsover 函数的信息。这似乎表明 ADEV 是“标准差”。
然而,它可能是 PRMS(如 Sinan 的 Statistics::Descriptive 示例所示)或 RMS(如 ars's NumPy example 所示)。我猜其中一个必须是正确的 ;-)
要获取更多 PDL 信息,请查看:

2
这不是一个正在运行的计算。 - Jake

3

我喜欢用这种方式来表达更新:

def running_update(x, N, mu, var):
    '''
        @arg x: the current data sample
        @arg N : the number of previous samples
        @arg mu: the mean of the previous samples
        @arg var : the variance over the previous samples
        @retval (N+1, mu', var') -- updated mean, variance and count
    '''
    N = N + 1
    rho = 1.0/N
    d = x - mu
    mu += rho*d
    var += rho*((1-rho)*d**2 - var)
    return (N, mu, var)

这样一来,单次函数看起来会像这样:
def one_pass(data):
    N = 0
    mu = 0.0
    var = 0.0
    for x in data:
        N = N + 1
        rho = 1.0/N
        d = x - mu
        mu += rho*d
        var += rho*((1-rho)*d**2 - var)
        # could yield here if you want partial results
   return (N, mu, var)

请注意,这里计算的是样本方差(1/N),而不是无偏估计的总体方差(使用1/(N-1)的标准化因子)。与其他答案不同的是,跟踪运行方差的变量var不会按比例随着样本数量增加而增长。它始终只是迄今为止所见样本的方差(在获取方差时没有最终的“除以n”步骤)。
在课堂上,它看起来像这样:
class RunningMeanVar(object):
    def __init__(self):
        self.N = 0
        self.mu = 0.0
        self.var = 0.0
    def push(self, x):
        self.N = self.N + 1
        rho = 1.0/N
        d = x-self.mu
        self.mu += rho*d
        self.var += + rho*((1-rho)*d**2-self.var)
    # reset, accessors etc. can be setup as you see fit

这也适用于加权样本:
def running_update(w, x, N, mu, var):
    '''
        @arg w: the weight of the current sample
        @arg x: the current data sample
        @arg mu: the mean of the previous N sample
        @arg var : the variance over the previous N samples
        @arg N : the number of previous samples
        @retval (N+w, mu', var') -- updated mean, variance and count
    '''
    N = N + w
    rho = w/N
    d = x - mu
    mu += rho*d
    var += rho*((1-rho)*d**2 - var)
    return (N, mu, var)

但是除法可能是一项非常昂贵的操作?也许可以在回答中解决这个问题?(但是不要包括“编辑:”、“更新:”或类似的内容——答案应该看起来像是今天写的。) - Peter Mortensen
@PeterMortensen 如果你的应用程序中,每个集合样本都需要消除一次除法运算才能使你的算法成功,那么你可能需要考虑采用不同的方法。 - Dave

3

除非你的数组有成千上万个元素,否则不用担心循环两次。 代码简单且易于测试。

我的偏好是使用NumPy 数组数学扩展将你的数组转换为 NumPy 二维数组并直接获得标准差:

>>> x = [ [ 1, 2, 4, 3, 4, 5 ], [ 3, 4, 5, 6, 7, 8 ] ] * 10
>>> import numpy
>>> a = numpy.array(x)
>>> a.std(axis=0)
array([ 1. ,  1. ,  0.5,  1.5,  1.5,  1.5])
>>> a.mean(axis=0)
array([ 2. ,  3. ,  4.5,  4.5,  5.5,  6.5])

如果这不是一个选择,你需要一个纯Python解决方案,请继续阅读...
如果您的数组是:
x = [
      [ 1, 2, 4, 3, 4, 5 ],
      [ 3, 4, 5, 6, 7, 8 ],
      ....
]

那么标准差就是:

d = len(x[0])
n = len(x)
sum_x = [ sum(v[i] for v in x) for i in range(d) ]
sum_x2 = [ sum(v[i]**2 for v in x) for i in range(d) ]
std_dev = [ sqrt((sx2 - sx**2)/N)  for sx, sx2 in zip(sum_x, sum_x2) ]

如果你决定只循环一次数组,那么可以合并运行的总和。
sum_x  = [ 0 ] * d
sum_x2 = [ 0 ] * d
for v in x:
   for i, t in enumerate(v):
   sum_x[i] += t
   sum_x2[i] += t**2

这个解决方案并不像上面的列表推导式那么优雅。

1
我确实需要处理大量数字,这正是激发我需求高效解决方案的动力。谢谢! - Alex Reynolds
1
它不是关于数据集有多大,而是关于我需要每秒对500个元素进行3500个不同标准差计算的频率。 - PirateApp

1

回答 Charlie Parker 的 2021 年问题:

我想要一个可以直接复制粘贴到我的 numpy 代码中的答案。我的输入是大小为 [N, 1] 的矩阵,其中 N 是数据点的数量,我已经计算出了运行平均值,并假设我们已经计算出了运行标准差/方差,如何更新新的数据批次。

这里有两种实现函数的方法,它们都需要原始平均值、原始方差和原始大小以及新样本,返回组合原始样本和新样本的总平均值和总方差(要获取标准差,只需使用 **(1/2) 取方差的平方根)。第一种方法使用 NumPy,第二种方法使用 Welford。您可以选择最适合您情况的方法。

def mean_and_variance_update_numpy(previous_mean, previous_var, previous_size, sample_to_append):
    if type(sample_to_append) is np.matrix:
        sample_to_append = sample_to_append.A1
    else:
        sample_to_append = sample_to_append.flatten()
    sample_to_append_mean = np.mean(sample_to_append)
    sample_to_append_size = len(sample_to_append)
    total_size = previous_size+sample_to_append_size
    total_mean = (previous_mean*previous_size+sample_to_append_mean*sample_to_append_size)/total_size
    total_var = (((previous_var+(total_mean-previous_mean)**2)*previous_size)+((np.var(sample_to_append)+(sample_to_append_mean-tm)**2)*sample_to_append_size))/total_size
    return (total_mean, total_var)

def mean_and_variance_update_welford(previous_mean, previous_var, previous_size, sample_to_append):
    if type(sample_to_append) is np.matrix:
        sample_to_append = sample_to_append.A1
    else:
        sample_to_append = sample_to_append.flatten()
    pos = previous_size
    mean = previous_mean
    v = previous_var*previous_size
    for value in sample_to_append:
        pos += 1
        mean_next = mean + (value - mean) / pos
        v = v + (value - mean)*(value - mean_next)
        mean = mean_next
    return (mean, v/pos)

让我们检查一下它是否有效:

import numpy as np

def mean_and_variance_udpate_numpy:
    ...
def mean_and_variance_udpate_welford:
    ...

# Making the samples and results deterministic
np.random.seed(0)

# Our initial sample has 100 samples, we want to append 10
n0, n1 = 100, 10

# Using np.matrix only, because it was in the question. 'np.array' is more common
s0 = np.matrix(1e3+np.random.random_sample(n0)*1e-3).T
s1 = np.matrix(1e3+np.random.random_sample(n1)*1e-3).T

# Precalculating our mean and var for initial sample:
s0mean, s0var = np.mean(s0), np.var(s0)

# Calculating mean and variance for s0+s1 using our NumPy updater
mean_and_variance_update_numpy(s0mean, s0var, len(s0), s1)
# (1000.0004826329636, 8.24577589696613e-08)

# Calculating mean and variance for s0+s1 using our Welford updater
mean_and_variance_update_welford(s0mean, s0var, len(s0), s1)
# (1000.0004826329634, 8.245775896913623e-08)

# Similar results, now checking with NumPy's calculation over the concatenation of s0 and s1
s0s1 = np.concatenate([s0,s1])
(np.mean(s0s1), np.var(s0s1))
# (1000.0004826329638, 8.245775896917313e-08)

这里的三个结果更接近:

# np(s0s1)        (1000.0004826329638, 8.245775896917313e-08)
# np(s0)updnp(s1) (1000.0004826329636, 8.245775896966130e-08)
# np(s0)updwf(s1) (1000.0004826329634, 8.245775896913623e-08)

可以发现结果非常相似。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接