在计算c属于[-1,1]时,稳定地计算sqrt((b²*c²) / (1-c²))的方法。

58

对于某些在[-1, 1]范围内的实数bc,我需要计算:

sqrt( (b²*c²) / (1-c²) ) = (|b|*|c|) / sqrt((1-c)*(1+c))

c接近1或-1时,分母会出现灾难性取消。平方根也可能没有帮助。

我想知道是否有巧妙的技巧可应用于此处,以避免在c=1和c=-1附近出现困难的情况?


19
您可以尝试使用 https://herbie.uwplse.org/,尽管之前的使用体验并不稳定。但至少它可以给您一些建议。 - alias
4
对于分母,计算 sqrt(1-c) * sqrt(1+c) 应该是相当稳定的数值方法。如果 c 接近于 1,那么根据 Sterbenz 引理,1 - c 是可以精确表示的,而且在 1 + c 中也不会产生灾难性的取消现象。同样地,如果 c 接近于 -11 + c 将被精确表示,而 1 - c 也是安全的。 - Mark Dickinson
1
也许当 |c| 接近于 1.0 时,可以使用 sqrt(0.5/(1-c)) - chux - Reinstate Monica
4
@jeff: “c”是从哪里来的?如果“c”本身被计算为类似于“1 + tiny”的某个值,那么重新将表达式表示为关于“tiny”的计算可能是解决问题的方法。 - Mark Dickinson
2
@MarkDickinson c 保持原本状态,因此用 tiny 表达不适用。 - jeff
显示剩余5条评论
2个回答

49
这个稳定性问题中最有趣的部分是分母,sqrt(1 - c*c)。对于这个问题,你只需要将其展开为sqrt(1 - c) * sqrt(1 + c)。我认为这并不真正算作是一个“巧妙的技巧”,但这就是需要的全部。
对于典型的二进制浮点格式(例如IEEE 754 binary64,但其他常见格式应该同样适用,可能会出现令人不愉快的double-double格式等不良情况),如果c接近于1,那么通过Sterbenz' Lemma1-c将被精确计算,而1+c没有任何稳定性问题。同样地,如果c接近于-1,那么1+c将被精确计算,1-c将被准确计算。平方根和乘法操作不会引入重大的新误差。
以下是一个数值演示,使用Python在具有IEEE 754 binary64浮点数和正确舍入的sqrt操作的机器上运行。
让我们取一个接近(但小于)1c
>>> c = float.fromhex('0x1.ffffffff24190p-1')
>>> c
0.9999999999

我们需要小心一点:请注意所显示的十进制值0.999999999是对c的精确值的一个近似值。精确值如十六进制字符串或分数形式562949953365017/562949953421312所示,我们关心的是获得这个精确值的好结果。
将表达式sqrt(1-c*c)的精确值舍入到小数点后100位,结果为:
0.0000141421362084401590649378320134409069878639187055610216016949959890888003204161068184484972504813

我使用Python的decimal模块进行了计算,并使用Pari/GP双重检查结果。以下是Python计算:

>>> from decimal import Decimal, getcontext
>>> getcontext().prec = 1000
>>> good = (1 - Decimal(c) * Decimal(c)).sqrt().quantize(Decimal("1e-100"))
>>> print(good)
0.0000141421362084401590649378320134409069878639187055610216016949959890888003204161068184484972504813

如果我们进行朴素计算,会得到以下结果:

>>> from math import sqrt
>>> naive = sqrt(1 - c*c)
>>> naive
1.4142136208793713e-05

我们可以轻松计算出ulps误差的近似数量(抱歉需要进行大量类型转换 - `float`和`Decimal`实例不能直接混合在算术运算中):
>>> from math import ulp
>>> float((Decimal(naive) - good) / Decimal(ulp(float(good))))
208701.28298527992

因此,天真的结果偏差了数十万个ulp-粗略地说,我们失去了大约5个小数位的准确性。

现在让我们尝试使用扩展版本:

>>> better = sqrt(1 - c) * sqrt(1 + c)
>>> better
1.4142136208440158e-05
>>> float((Decimal(better) - good) / Decimal(ulp(float(good))))
-0.7170147200803595

所以,我们的精度比1个ulp的误差更高。虽然不是完全正确舍入,但已经是下一个最好的选择了。
通过进一步的工作,应该可以在域-1 < c < 1上陈述并证明表达式sqrt(1 - c) * sqrt(1 + c)中ulp误差的绝对上界,假设采用IEEE 754二进制浮点数、四舍五入至偶数舍入模式和正确舍入操作。我没有这样做,但如果那个上界超过10 ulps,我会非常惊讶的。

8
sqrt(1-c)*sqrt(1+c)sqrt((1-c)*(1+c)) 更好吗? - chtz
4
@chtz: 啊,说得好。sqrt((1 - c) * (1 + c)) 应该更好,因为 sqrt 是一种收缩运算,往往会减小相对误差。我稍后会编辑一下。 - Mark Dickinson
如果我没记错的话,这个操作实际上非常稳定。使用双精度时,我得到了相对误差为1e-3,而在这些情况下c始终接近于1。我很快怀疑是sqrt(1-c²)的问题,但可能还有其他错误源!非常感谢您的答案,非常清晰!之前不知道Sterbenz引理。 - jeff

31
Mark Dickinson为一般情况提供了一个很好的答案,我会补充一些更专业的方法。现在许多计算环境都提供了称为融合乘加(FMA)的操作,专门设计用于这种情况。在计算fma(a, b, c)时,完整的乘积a * b(未截断和未舍入)进入与c相加的运算中,然后在最后应用单个舍入。
目前出货的GPU和CPU,包括基于ARM64、x86-64和Power架构的处理器,通常都包括快速硬件实现的FMA,并以C和C++系列等许多编程语言作为标准数学函数fma()公开。一些--通常是旧的--软件环境使用FMA的软件仿真,其中一些仿真已经发现存在错误。此外,这种仿真往往非常慢。
如果有FMA可用,则可以稳定地计算表达式,而不会出现过早的溢出和下溢风险,如fabs(b*c)/sqrt(fma(c,-c,1.0)),其中fabs()是浮点操作数的绝对值操作,sqrt()计算平方根。一些环境还提供了倒数平方根运算,通常称为rsqrt(),在这种情况下,可能的替代方法是使用fabs(b*c)*rsqrt(fma(c,-c,1.0))。使用rsqrt()避免了相对较昂贵的除法,因此通常更快。然而,rsqrt()的许多实现没有像sqrt()那样正确舍入,因此精度可能会稍差。
通过下面的代码进行快速实验似乎表明,基于FMA的表达式的最大误差约为3 ulps,只要b是一个正常的浮点数。我强调这并不证明任何错误界限。自动化Herbie工具尝试找到给定浮点表达式的数值优势重写,建议使用fabs(b*c)*sqrt(1.0/fma(c,-c,1.0))。然而,这似乎是一个虚假的结果,因为我既不能想到任何特殊的优点,也找不到实验结果。
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <math.h>

#define USE_ORIGINAL  (0)
#define USE_HERBIE    (1)

/* function under test */
float func (float b, float c)
{
#if USE_HERBIE
     return fabsf (b * c) * sqrtf (1.0f / fmaf (c, -c, 1.0f));
#else USE_HERBIE
     return fabsf (b * c) / sqrtf (fmaf (c, -c, 1.0f));
#endif // USE_HERBIE
}

/* reference */
double funcd (double b, double c)
{
#if USE_ORIGINAL
    double b2 = b * b;
    double c2 = c * c;
    return sqrt ((b2 * c2) / (1.0 - c2));
#else
    return fabs (b * c) / sqrt (fma (c, -c, 1.0));
#endif
}

uint32_t float_as_uint32 (float a)
{
    uint32_t r;
    memcpy (&r, &a, sizeof r);
    return r;
}

float uint32_as_float (uint32_t a)
{
    float r;
    memcpy (&r, &a, sizeof r);
    return r;
}

uint64_t double_as_uint64 (double a)
{
    uint64_t r;
    memcpy (&r, &a, sizeof r);
    return r;
}

double floatUlpErr (float res, double ref)
{
    uint64_t i, j, err, refi;
    int expoRef;
    
    /* ulp error cannot be computed if either operand is NaN, infinity, zero */
    if (isnan (res) || isnan (ref) || isinf (res) || isinf (ref) ||
        (res == 0.0f) || (ref == 0.0f)) {
        return 0.0;
    }
    /* Convert the float result to an "extended float". This is like a float
       with 56 instead of 24 effective mantissa bits.
    */
    i = ((uint64_t)float_as_uint32(res)) << 32;
    /* Convert the double reference to an "extended float". If the reference is
       >= 2^129, we need to clamp to the maximum "extended float". If reference
       is < 2^-126, we need to denormalize because of the float types's limited
       exponent range.
    */
    refi = double_as_uint64(ref);
    expoRef = (int)(((refi >> 52) & 0x7ff) - 1023);
    if (expoRef >= 129) {
        j = 0x7fffffffffffffffULL;
    } else if (expoRef < -126) {
        j = ((refi << 11) | 0x8000000000000000ULL) >> 8;
        j = j >> (-(expoRef + 126));
    } else {
        j = ((refi << 11) & 0x7fffffffffffffffULL) >> 8;
        j = j | ((uint64_t)(expoRef + 127) << 55);
    }
    j = j | (refi & 0x8000000000000000ULL);
    err = (i < j) ? (j - i) : (i - j);
    return err / 4294967296.0;
}

// Fixes via: Greg Rose, KISS: A Bit Too Simple. http://eprint.iacr.org/2011/007
static unsigned int z=362436069,w=521288629,jsr=362436069,jcong=123456789;
#define znew (z=36969*(z&0xffff)+(z>>16))
#define wnew (w=18000*(w&0xffff)+(w>>16))
#define MWC  ((znew<<16)+wnew)
#define SHR3 (jsr^=(jsr<<13),jsr^=(jsr>>17),jsr^=(jsr<<5)) /* 2^32-1 */
#define CONG (jcong=69069*jcong+13579)                     /* 2^32 */
#define KISS ((MWC^CONG)+SHR3)

#define N  (20)

int main (void)
{
    float b, c, errloc_b, errloc_c, res;
    double ref, err, maxerr = 0;
    
    c = -1.0f;
    while (c <= 1.0f) {
        /* try N random values of `b` per every value of `c` */
        for (int i = 0; i < N; i++) {
            /* allow only normals */
            do {
                b = uint32_as_float (KISS);
            } while (!isnormal (b));
            res = func (b, c);
            ref = funcd ((double)b, (double)c);
            err = floatUlpErr (res, ref);
            if (err > maxerr) {
                maxerr = err;
                errloc_b = b;
                errloc_c = c;
            }
        }
        c = nextafterf (c, INFINITY);
    }
#if USE_HERBIE
    printf ("HERBIE max ulp err = %.5f @ (b=% 15.8e c=% 15.8e)\n", maxerr, errloc_b, errloc_c);
#else // USE_HERBIE
    printf ("SIMPLE max ulp err = %.5f @ (b=% 15.8e c=% 15.8e)\n", maxerr, errloc_b, errloc_c);
#endif // USE_HERBIE
    
    return EXIT_SUCCESS;
}

然而,这似乎是一个虚假的结果,因为我既无法想到任何特定的优点,也无法在实验中找到任何优点。Herbie可能只是不支持rsqrt - orlp
@orlp 我的评估是在考虑到Herbie不知道rsqrt的情况下进行的。如果您知道为什么像Herbie建议的fabs(b*c)*sqrt(1.0/fma(c,-c,1.0));fabs(b*c)/sqrt(fma(c,-c,1.0))更优秀,我很乐意跟进并相应地更新我的答案。 - njuffa
谢谢你这个富有洞见的回答!我之前听说过fma,但从未在实践中使用过。从未意识到它在这种情况下是适用的。 - jeff
5
我认为难以言过其实,fma对于准确计算的普适性非常高。最明显的优点是在典型情况下只舍入一次而不是两次,但在我看来影响更大的是,在浮点数精度失准的普遍原因中,一个是相对误差的影响,另一个是将其转化为绝对误差的作用。fma处理了这类问题的大类。 - Oscar Smith

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接