正确舍入计算两个浮点数之和的平方根并处理溢出

3
有没有一种好的方法来计算正确舍入的结果?
sqrt(a+b)

如何计算浮点数 ab(精度相同),其中 0<=a<+inf0<=b<+inf

特别地,对于输入值,使得计算 a+b 会溢出的情况怎么处理?

这里“正确舍入”指的是与计算 sqrt 的结果相同,即返回最接近在无限精度下计算的“真实”结果的可表示值。

(注意:一种明显的方法是在更大的浮点大小中进行计算并避免溢出。然而,这通常不适用(例如,如果没有支持更大的浮点格式)。)

我已经尝试使用Herbie,但它完全放弃了。它似乎没有采样足够多的 a+b 溢出的点来检测问题,并且也似乎不能很好地处理相关采样。这很不幸,因为它通常是一个很棒的工具。


到目前为止,我一直在做以下操作(伪代码)

if a + b would overflow:
    2*sqrt(a/4 + b/4) # Cannot overflow for finite inputs, as f::MAX/4 + f::MAX/4 <= f::MAX
else:
    ... # handle non-overflow case. Also interesting; not quite the topic of this question.

...这似乎在实践中大多数情况下都能工作,但是a)它没有明确的原则,b)在实践中偶尔会返回一个由于溢出避免而误差为 epsilon 的结果(例如,真实结果为x + 0.2(x.next_larger()-x),但这里返回了x.next_larger()而不是 x)。

关于 f32 中“离 epsilon 偏差”的快速示例:

>>> import decimal
>>> decimal.getcontext().prec = 256
>>> from decimal import Decimal as D
>>> from numpy import float32 as f32
>>> a = D(f32("6.0847234e31").astype(float))
>>> b = D(f32("3.4028235e38").astype(float))
>>> res_act = (a+b).sqrt()
>>> res_calc = D(f32("1.8446744e19").astype(float)) # 2*sqrt(a/4 + b/4) in f32 precision
>>> res_best = D(f32("1.8446746e19").astype(float)) # obtained by brute-force
>>> abs(res_calc - res_act) > abs(res_best - res_act)
True # oops

(你需要相信我所计算的结果是f32精度,因为Python通常使用f64精度。这也是为什么要使用f32的原因。)

3
你是否希望处理当sqrt(a)接近舍入上限时的情况,而添加一个非常小的b会使“精确”的sqrt(a+b)舍入到更大的下一个值?你是否有一个可用的精确融合乘加操作? - undefined
如果a+b会溢出,那么将每个数除以浮点格式的基数的平方(对于二进制来说是4),然后取平方根,并乘以基数。这样做不会引入新的舍入误差,尽管仍然存在一个问题,即由于存在两个舍入误差,sqrt(a+b)可能与理想计算结果不同。 - undefined
3
需要正确舍入的结果吗?在中间计算中使用双精度技术(例如Dekker算法)可以将最大误差控制在0.5个ULP(例如0.5000001 ULP)左右。然而,如果需要确切舍入结果,则需要使用三精度技术,据我所知。 - undefined
@chtz - 是的,是的。 - undefined
@EricPostpischil - 对的。所以你的意思是问题不在于2*sqrt(a/4 + b/4)这部分,而是在于sqrt(a+b)本身。 - undefined
3个回答

4
溢出可以通过适当的二次幂缩放来轻松避免,使具有大量幅度的参数缩放到接近1。难点在于产生正确舍入的结果。我甚至不能完全确信在下一个更大的IEEE-754二进制浮点类型中执行中间计算会保证这一点,由于可能存在双重舍入问题。
在没有更广泛的浮点类型的情况下,必须退回到将多个本地精度数字链接在一起以执行具有更高中间精度的操作的方法。由Dekker提出的常见方案称为配对精度。它使用浮点数对,其中更重要的部分通常称为“头部”,不太重要的部分称为“尾部”。这两部分被归一化,使得尾部的大小最多为头部大小的半个ulp。
在此方案中,有效有效位数为2 * p + 1,其中p是基础浮点类型中的有效位数。 "额外"位由尾部的符号位表示。重要的是要注意,指数范围与底层基本类型相比不变,因此我们需要积极地朝向1进行缩放,以避免在中间计算中遇到次正常操作数。 配对精度计算无法保证正确舍入的结果。使用三元组可能有效,但需要更多的工作量,我无法投入答案。
然而,配对精度可以提供忠实舍入并几乎总是正确舍入的结果。当FMA(合并乘加)可用时,可以相当有效地构建基于牛顿-拉弗森的配对精度平方根,产生约2 * p-1个好位。这是我在下面的示例ISO-C99代码中使用的,它将float映射到IEEE-754 binary32作为本机浮点类型。应该以最高的IEEE-754标准编译配对精度代码,以防止浮点操作顺序的意外偏差。在我的情况下,我使用MSVC 2019的/fp:strict命令行开关。
通过数十亿个随机测试向量,我的测试程序报告了最大误差为0.500000179 ulp。
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <math.h>

/* compute square root of sum of two positive floating-point numbers */
float sqrt_sum_pos (float a, float b)
{
    float mn, mx, res, scale_in, scale_out;
    float r, s, t, u, v, w, x;

    /* sort arguments according to magnitude */
    mx = a < b ? b : a;
    mn = a < b ? a : b; 

    /* select scale factor: scale argument larger in magnitude towards unity */
    scale_in  = (mx > 1.0f) ? 0x1.0p-64f : 0x1.0p+64f;
    scale_out = (mx > 1.0f) ? 0x1.0p+32f : 0x1.0p-32f;

    /* scale input arguments */
    mn = mn * scale_in;
    mx = mx * scale_in;

    /* represent sum as a normalized pair s:t of 'float' */
    s = mx + mn;        // most significant bits
    t = (mx - s) + mn;  // least significant bits

    /* compute square root of s:t. Based on Alan Karp and Peter Markstein,
       "High Precision Division and Square Root", ACM TOMS, vol. 23, no. 4, 
       December 1997, pp. 561-589 
    */
    r = sqrtf (1.0f / s);
    if (s == 0.0f) r = 0.0f;
    x = r * s;
    s = fmaf (x, -x, s);
    r = 0.5f * r;
    u = s + t;
    v = (s - u) + t;
    s = r * u;
    t = fmaf (r, u, -s);
    t = fmaf (r, v, t);
    r = x + s;
    s = (x - r) + s;
    s = s + t;
    t = r + s;
    s = (r - t) + s;
    
    /* Component sum of t:s represents square root with maximum error very close to 0.5 ulp */
    w = s + t;

    /* compensate scaling of source operands */
    res = w * scale_out;

    /* handle special cases: NaN, Inf */
    t = a + b;
    if (isinf (mx)) res = mx;
    if (isnan (t)) res = t;

    return res;
}

// George Marsaglia's KISS PRNG, period 2**123. Newsgroup sci.math, 21 Jan 1999
// Bug fix: Greg Rose, "KISS: A Bit Too Simple" http://eprint.iacr.org/2011/007
static uint32_t kiss_z=362436069, kiss_w=521288629;
static uint32_t kiss_jsr=123456789, kiss_jcong=380116160;
#define znew (kiss_z=36969*(kiss_z&65535)+(kiss_z>>16))
#define wnew (kiss_w=18000*(kiss_w&65535)+(kiss_w>>16))
#define MWC  ((znew<<16)+wnew )
#define SHR3 (kiss_jsr^=(kiss_jsr<<13),kiss_jsr^=(kiss_jsr>>17), \
              kiss_jsr^=(kiss_jsr<<5))
#define CONG (kiss_jcong=69069*kiss_jcong+1234567)
#define KISS ((MWC^CONG)+SHR3)

uint32_t float_as_uint32 (float a)
{
    uint32_t r;
    memcpy (&r, &a, sizeof r);
    return r;
}

uint64_t double_as_uint64 (double a)
{
    uint64_t r;
    memcpy (&r, &a, sizeof r);
    return r;
}

float uint32_as_float (uint32_t a)
{
    float r;
    memcpy (&r, &a, sizeof r);
    return r;
}

double floatUlpErr (float res, double ref)
{
    uint64_t i, j, err, refi;
    int expoRef;
    
    /* ulp error cannot be computed if either operand is NaN, infinity, zero */
    if (isnan (res) || isnan (ref) || isinf (res) || isinf (ref) ||
        (res == 0.0f) || (ref == 0.0f)) {
        return 0.0;
    }
    /* Convert the float result to an "extended float". This is like a float
       with 56 instead of 24 effective mantissa bits
    */
    i = ((uint64_t) float_as_uint32 (res)) << 32;
    /* Convert the double reference to an "extended float". If the reference is
       >= 2^129, we need to clamp to the maximum "extended float". If reference
       is < 2^-126, we need to denormalize because of float's limited exponent
       range.
    */
    refi = double_as_uint64 (ref);
    expoRef = (int)(((refi >> 52) & 0x7ff) - 1023);
    if (expoRef >= 129) {
        j = 0x7fffffffffffffffULL;
    } else if (expoRef < -126) {
        j = ((refi << 11) | 0x8000000000000000ULL) >> 8;
        j = j >> (-(expoRef + 126));
    } else {
        j = ((refi << 11) & 0x7fffffffffffffffULL) >> 8;
        j = j | ((uint64_t)(expoRef + 127) << 55);
    }
    j = j | (refi & 0x8000000000000000ULL);
    err = (i < j) ? (j - i) : (i - j);
    return err / 4294967296.0;
}

int main (void)
{
    float arga, argb, res, reff;
    uint32_t argai, argbi, resi, refi, diff;
    double ref, ulp, maxulp = 0;
    unsigned long long int count = 0;
    
    do {
        /* random positive inputs */
        argai = KISS & 0x7fffffff;
        argbi = KISS & 0x7fffffff;

        /* increase occurence of zero, infinity */
        if ((argai & 0xffff) == 0x5555) argai = 0x00000000;
        if ((argbi & 0xffff) == 0x3333) argbi = 0x00000000;
        if ((argai & 0xffff) == 0xaaaa) argai = 0x7f800000;
        if ((argbi & 0xffff) == 0xcccc) argbi = 0x7f800000;

        arga = uint32_as_float (argai);
        argb = uint32_as_float (argbi);
        res = sqrt_sum_pos (arga, argb);
        ref = sqrt ((double)arga + (double)argb);
        reff = (float)ref;
        ulp = floatUlpErr (res, ref);
        resi = float_as_uint32 (res);
        refi = float_as_uint32 (reff);
        diff = (refi > resi) ? (refi - resi) : (resi - refi);
        if (diff > 1) {
            /* if both source operands were NaNs, result could be either NaN,
               quietened if necessary
            */
            if (!(isnan (arga) && isnan (argb) && 
                  ((resi == (argai | 0x00400000)) || 
                   (resi == (argbi | 0x00400000))))) {
                printf ("\rerror: refi=%08x  resi=%08x  a=% 15.8e %08x  b=% 15.8e %08x\n", 
                        refi, resi, arga, argai, argb, argbi);
                return EXIT_FAILURE;
            }
        }
        if (ulp > maxulp) {
            printf ("\rulp = %.9f @ a=%14.8e (%15.6a)  b=%14.8e (%15.6a) a+b=%22.13a  res=%15.6a  ref=%22.13a\n", 
                    ulp, arga, arga, argb, argb, (double)arga + argb, res, ref);
            maxulp = ulp;
        }
        count++;
        if (!(count & 0xffffff)) printf ("\r%llu", count);
    } while (1);
    printf ("\ntest passed\n");
    return EXIT_SUCCESS;
}

哇,好棒的回答。特别感谢提供的参考和评论。我需要花点时间来理解这个。那个0.000000179...真是可惜。 - undefined

1

现在有一种替代方法,@EricPostpischil和@njuffa指出了实际问题(即双重舍入),

(注意:以下讨论的是“表现良好”的数字。它不考虑精度边界或子规格,但可以扩展到这样做。)

首先,请注意sqrt(x)a+b都保证返回最接近结果的可表示值。问题在于双重舍入。也就是说,当我们想要计算round(sqrt(a+b))时,实际上我们正在计算round(sqrt(round(a+b)))。请注意缺少内部舍入。

那么,内部舍入会对结果产生多大影响?嗯,内部舍入会将加法结果的ULP增加或减少0.5。因此,假设有一个p位尾数,我们大约有sqrt((a+b)*(1 ±2**-p))

这可以简化为sqrt(a+b)*sqrt(1 ±2**-p)...但是sqrt(1 ±2**-p)(1 ±2**-p)更接近1!(它很接近,但不完全等于(1 ±2**-(p+1)),因为这是一个有限差分。您可以从围绕1的泰勒级数中看到这一点(d/dx = 1/2)。)第二次舍入将使结果再受到±0.5ULP的影响。
这意味着我们保证不会偏离“真实”结果超过1 ULP。因此,如果我们能“仅仅”找出如何选择,那么在{sqrt(a+b)-1ULP,sqrt(a+b),sqrt(a+b)+1ULP}之间选择的修复策略是可行的。
因此,让我们看看是否可以提出一种在有限精度下有效的基于比较的方法。(注意:以下内容在无限精度下,除非另有说明)
resy = float(sqrt(a+b))
resx = resy.prev_nearest()
resz = resy.next_nearest()

请注意resx < resy < resz
假设我们的浮点数有p位精度,则变为:
res = sqrt(a+b) // in infinite precision
resy = float(res)
resx = resy * (1 - 2**(1-p))
resz = resy * (1 + 2**(1-p))

让我们暂时比较一下 resxresy

distx = abs(resx - res)
disty = abs(resy - res)

checkxy: distx < disty
checkxy: abs(resx - res) < abs(resy - res)
checkxy: (resx - res)**2 < (resy - res)**2
checkxy: resx**2 - 2*resx*res - res**2 < resy**2 - 2*resy*res - res**2
checkxy: resx**2 - resy**2 < 2*resx*res - 2*resy*res
checkxy: resx**2 - resy**2 < 2*res*(resx - resy)
// Assuming resx < resy
checkxy: resx+resy > 2*res
checkxy: resx+resy > 2*sqrt(a+b)
// Assuming resx+resy >= 0
checkxy: (resx+resy)**2 > 4*(a+b)
checkxy: (resy*(2 - 2**(1-p)))**2 > 4*(a+b)
checkxy: (resy**2)*((2 - 2**(1-p)))**2 > 4*(a+b)
checkxy: (resy**2)*(4 - 2*2**(1-p) + 2**(2-2p)) > 4*(a+b)
checkxy: (resy**2)*(4 - 4*2**(0-p) + 4*2**(0-2p)) > 4*(a+b)
checkxy: (resy**2)*(1 - 2**-p + 2**-2p) > a+b

...这是我们可以在有限精度下实际进行的检查(尽管仍需要更高的精度,这很麻烦)。

同样地,对于checkyz,我们得到

checkxy: disty < distz
checkyz: (resy**2)*(1 + 2**-p + 2**-2p) < a+b

从这两个检查中,您可以选择正确的结果。...然后只需要检查/处理我上面忽略的边缘情况就可以了。
现在,在实践中,我认为与一开始就以更高的精度进行sqrt相比,这并不值得,除非有人能想出更好的选择方法。但这仍然是一个有趣的替代方案。

0

这里有一个极端的例子。让我们假设 u = 2^-p,其中 p 是浮点精度。

我们有 (1+u)^2 = (1+2u) + u^2

如果我们取 a = 1+2u,我们有 float(a)=a,a 可以用浮点数表示(它是 1 后面的下一个浮点数),并且 b= u^2float(b)=b,b 也可以用浮点数表示(作为 2^(-2p) 的幂)。

精确的 sqrt(a+b)(1+u),应该四舍五入为 float(1+u)=1,由于精确相等,它会向最接近的偶数有效数字下舍入...

float(a+b)=afloat(sqrt(a))=1,所以没问题。

但是让我们改变 b 的最后一位:b=(1+2*u)*u^2float(b)=b,b 只是按两倍精度缩小的。

现在我们有了精确的 sqrt(a+b) > 1+u,因此它应该四舍五入为 float(sqrt(a+b)) = 1+2u

我们看到,到2^(-3p+1)位(三倍于浮点数精度)的一点变化就可以改变正确的舍入!
这意味着您不应该依赖双精度来执行正确舍入的操作。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接