快速、无分支的无符号整数绝对差

20

我有一个程序,它大部分时间都在计算RGB值之间的欧几里得距离(三元组无符号8位Word8)。我需要一种快速、无分支的无符号整数绝对差函数,使得

unsigned_difference :: Word8 -> Word8 -> Word8
unsigned_difference a b = max a b - min a b

特别地,

unsigned_difference a b == unsigned_difference b a

我想到了以下内容,使用了GHC 7.8中的新原语:

-- (a < b) * (b - a) + (a > b) * (a - b)
unsigned_difference (I# a) (I# b) =
    I# ((a <# b) *# (b -# a) +# (a ># b) *# (a -# b))]

ghc -O2 -S编译后的结果是什么?

.Lc42U:
    movq 7(%rbx),%rax
    movq $ghczmprim_GHCziTypes_Izh_con_info,-8(%r12)
    movq 8(%rbp),%rbx
    movq %rbx,%rcx
    subq %rax,%rcx
    cmpq %rax,%rbx
    setg %dl
    movzbl %dl,%edx
    imulq %rcx,%rdx
    movq %rax,%rcx
    subq %rbx,%rcx
    cmpq %rax,%rbx
    setl %al
    movzbl %al,%eax
    imulq %rcx,%rax
    addq %rdx,%rax
    movq %rax,(%r12)
    leaq -7(%r12),%rbx
    addq $16,%rbp
    jmp *(%rbp)

使用ghc -O2 -fllvm -optlo -O3 -S编译会生成以下汇编代码:

.LBB6_1:
    movq    7(%rbx), %rsi
    movq    $ghczmprim_GHCziTypes_Izh_con_info, 8(%rax)
    movq    8(%rbp), %rcx
    movq    %rsi, %rdx
    subq    %rcx, %rdx
    xorl    %edi, %edi
    subq    %rsi, %rcx
    cmovleq %rdi, %rcx
    cmovgeq %rdi, %rdx
    addq    %rcx, %rdx
    movq    %rdx, 16(%rax)
    movq    16(%rbp), %rax
    addq    $16, %rbp
    leaq    -7(%r12), %rbx
    jmpq    *%rax  # TAILCALL
LLVM可以用(更高效?)的条件移动指令替换比较。不幸的是,使用-fllvm编译对我的程序运行时间几乎没有影响。
然而,这个函数存在两个问题。
  • 我想比较Word8,但比较primops需要使用Int。这会导致不必要的分配,因为我被迫存储64位的Int而不是Word8

我进行了分析并确认fromIntegral :: Word8 -> Int的使用占程序总分配量的42.4%。

  • 我的版本使用2个比较、2个乘法和2个减法。我想知道是否有一种更有效率的方法,使用位运算或SIMD指令,并利用我正在比较的Word8

之前我标记了问题C/C ++以吸引那些更倾向于位操作的人的注意。我的问题使用Haskell,但我将接受任何语言实现的正确方法的答案。

结论:

我决定使用

w8_sad :: Word8 -> Word8 -> Int16
w8_sad a b = xor (diff + mask) mask
    where diff = fromIntegral a - fromIntegral b
          mask = unsafeShiftR diff 15

我决定使用标量版本,因为它比原来的unsigned_difference函数更快,而且实现简单。尽管Haskell中的SIMD内在函数速度更快,但这些函数还没有成熟。


@cdk:你可以看一下例如(http://bytes.com/topic/c/answers/212935-bitwise-absolute-value)这样的网站(它基本上使用符号位来计算掩码),但很可能对于整数基础的“primops”并没有帮助。 :-( - Cheers and hth. - Alf
4
@cdk 我怀疑最好的答案是上一级。请解释代码为什么需要RGB之间的欧几里得距离,以及该值如何使用。也许欧几里得距离的平方足以满足要求。 (a - b)*(a - b) - chux - Reinstate Monica
2
@cdk:我猜chux的意思是加法和乘法在模2^8下形成一个环,所以(a-b)*(a-b) = (b-a)*(b-a) - Niklas B.
1
为什么计算欧几里得距离需要 abs 函数?你只有在处理最大规范之类的问题时才需要使用 abs - nwellnhof
3
x86 SSE2具有psadbw指令,可为8个Word8 SAD操作提供总和。 因此,如果您将2个输入字节零扩展到XMM寄存器中,则psadbw会执行所需的操作。 它专门为视频编解码器执行多个像素块运动搜索而设计,但对于您的用例,您可以使用SSE4.1的pmovzxbq将2个字节加载到2个qword中以并行检查2个像素组件。 同时使用 pmuludq 来平方这2个结果。 我不知道如何让Haskell编译器发出它,我完全不懂Haskell。 - Peter Cordes
显示剩余13条评论
3个回答

8

嗯,我尝试进行一些基准测试。 我使用Criterion进行基准测试,因为它可以进行适当的显著性测试。 我还在这里使用QuickCheck来确保所有方法返回相同的结果。

我使用GHC 7.6.3进行编译(所以不幸的是我无法包含您的primops函数),并使用-O3

ghc -O3 AbsDiff.hs -o AbsDiff && ./AbsDiff

主要我们可以看到一个天真的实现和一点摆弄之间的区别:

absdiff1_w8 :: Word8 -> Word8 -> Word8
absdiff1_w8 a b = max a b - min a b

absdiff2_w8 :: Word8 -> Word8 -> Word8
absdiff2_w8 a b = unsafeCoerce $ xor (v + mask) mask
  where v = (unsafeCoerce a::Int64) - (unsafeCoerce b::Int64)
        mask = unsafeShiftR v 63

输出:

benchmarking absdiff_Word8/1
mean: 249.8591 us, lb 248.1229 us, ub 252.4321 us, ci 0.950
....

benchmarking absdiff_Word8/2
mean: 202.5095 us, lb 200.8041 us, ub 206.7602 us, ci 0.950
...

我使用了“位操作技巧”中的绝对整数值技巧。不幸的是,我们需要进行强制转换,我认为在Word8领域内很难解决这个问题,但是使用本机整数类型似乎是合理的(当然没有必要创建堆对象)。
虽然看起来差别不大,但我的测试设置也不完美:我正在对大量随机值的列表上映射函数,以排除分支预测使分支版本看起来比实际更有效的情况。这会导致thunk在内存中积累,这可能会极大地影响时间。当我们减去维护列表的常量开销时,我们可能会看到超过20%的加速。
生成的汇编代码实际上非常好(这是函数的内联版本)。
.Lc4BB:
    leaq 7(%rbx),%rax
    movq 8(%rbp),%rbx
    subq (%rax),%rbx
    movq %rbx,%rax
    sarq $63,%rax
    movq $base_GHCziInt_I64zh_con_info,-8(%r12)
    addq %rax,%rbx
    xorq %rax,%rbx
    movq %rbx,0(%r12)
    leaq -7(%r12),%rbx
    movq $s4z0_info,8(%rbp)

1 减法,1 加法,1 右移,1 异或,并且没有分支,如预期所述。使用LLVM后端并没有显著改善运行时间。

如果您想尝试更多内容,希望这对您有用。

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Main where

import Data.Word
import Data.Int
import Data.Bits
import Control.Arrow ((***))
import Control.DeepSeq (force)
import Control.Exception (evaluate)
import Control.Monad
import System.Random
import Unsafe.Coerce

import Test.QuickCheck hiding ((.&.))
import Criterion.Main

absdiff1_w8 :: Word8 -> Word8 -> Word8
absdiff1_w8 !a !b = max a b - min a b

absdiff1_int16 :: Int16 -> Int16 -> Int16
absdiff1_int16 a b = max a b - min a b

absdiff1_int :: Int -> Int -> Int
absdiff1_int a b = max a b - min a b

absdiff2_int16 :: Int16 -> Int16 -> Int16
absdiff2_int16 a b = xor (v + mask) mask
  where v = a - b
        mask = unsafeShiftR v 15

absdiff2_w8 :: Word8 -> Word8 -> Word8
absdiff2_w8 !a !b = unsafeCoerce $ xor (v + mask) mask
  where !v = (unsafeCoerce a::Int64) - (unsafeCoerce b::Int64)
        !mask = unsafeShiftR v 63

absdiff3_w8 :: Word8 -> Word8 -> Word8
absdiff3_w8 a b = if a > b then a - b else b - a

{-absdiff4_int :: Int -> Int -> Int-}
{-absdiff4_int (I# a) (I# b) =-}
    {-I# ((a <# b) *# (b -# a) +# (a ># b) *# (a -# b))-}

e2e :: (Enum a, Enum b) => a -> b
e2e = toEnum . fromEnum

prop_same1 x y = absdiff1_w8 x y == absdiff2_w8 x y
prop_same2 (x::Word8) (y::Word8) = absdiff1_int16 x' y' == absdiff2_int16 x' y'
    where x' = e2e x
          y' = e2e y

check = quickCheck prop_same1
     >> quickCheck prop_same2

instance (Random x, Random y) => Random (x, y) where
  random gen1 =
    let (x, gen2) = random gen1
        (y, gen3) = random gen2
    in ((x,y),gen3)

main =
    do check
       !pairs_w8 <- fmap force $ replicateM 10000 (randomIO :: IO (Word8,Word8))
       let !pairs_int16 = force $ map (e2e *** e2e) pairs_w8
       defaultMain
         [ bgroup "absdiff_Word8" [ bench "1" $ nf (map (uncurry absdiff1_w8)) pairs_w8
                                  , bench "2" $ nf (map (uncurry absdiff2_w8)) pairs_w8
                                  , bench "3" $ nf (map (uncurry absdiff3_w8)) pairs_w8
                                  ]
         , bgroup "absdiff_Int16" [ bench "1" $ nf (map (uncurry absdiff1_int16)) pairs_int16
                                  , bench "2" $ nf (map (uncurry absdiff2_int16)) pairs_int16
                                  ]
         {-, bgroup "absdiff_Int"   [ bench "1" $ whnf (absdiff1_int 13) 14-}
                                  {-, bench "2" $ whnf (absdiff3_int 13) 14-}
                                  {-]-}
         ]

4
如果你的目标系统支持SSE指令集,那么使用它可以获得良好的性能提升。我已经测试过其他发布的方法并发现这是最快的方法。
大量数值比较的示例结果如下:
diff0: 188.020679 ms // branching
diff1: 118.934970 ms // max min
diff2: 97.087710 ms  // branchless mul add
diff3: 54.495269 ms  // branchless signed
diff4: 31.159628 ms  // sse
diff5: 30.855885 ms  // sse v2

以下是我的完整测试代码。我使用了SSE2指令,这些指令现在在x86ish CPU中广泛使用,通过SSE内置函数,应该相当可移植(MSVC、GCC、Clang、Intel编译器等)。

注意事项:

  • 实际上,这个计算方法是先计算最大值,然后计算最小值,最后减去两者之差,但每次执行16个值的计算,每个指令都会执行。
  • diff5中展开循环貌似没有太大效果,但可能可以进行调整。
  • 对于最后15个或更少的值,当前的回退方法是在一个循环中使用带符号的技巧方法,但可能可以通过展开和/或SSE进一步加速。
  • 这些函数本身非常简单,因此它们应该很容易移植到任何具有SSE内置函数或asm的东西上。
  • 我使用了Windows特定的计时函数,因为std::chrono::high_resolution_clock在MSVC实现中精度较低,对于C/C++测试代码的混合使用感到抱歉。
  • 在测量性能之后,将根据参考分支实现测试结果,因此它们应该是正确的。

如果对代码或这种方法有任何问题/建议,请留下评论。

#include <cstdlib>
#include <cstdint>
#include <cstdio>
#include <cmath>
#include <random>
#include <algorithm>

#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <Windows.h>

#include <emmintrin.h> // sse2

// branching
void diff0(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
    std::size_t n)
{
    for (std::size_t i = 0; i < n; i++) {
        res[i] = a[i] > b[i] ? a[i] - b[i] : b[i] - a[i];
    }
}

// max min
void diff1(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
    std::size_t n)
{
    for (std::size_t i = 0; i < n; i++) {
        res[i] = std::max(a[i], b[i]) - std::min(a[i], b[i]);
    }
}

// branchless mul add
void diff2(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
    std::size_t n)
{
    for (std::size_t i = 0; i < n; i++) {
        res[i] = (a[i] > b[i]) * (a[i] - b[i]) + (a[i] < b[i]) * (b[i] - a[i]);
    }
}

// branchless signed
void diff3(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
    std::size_t n)
{
    for (std::size_t i = 0; i < n; i++) {
        std::int16_t  diff = a[i] - b[i];
        std::uint16_t mask = diff >> 15;
        res[i] = (diff + mask) ^ mask;
    }
}

// sse
void diff4(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
    std::size_t n)
{
    auto pA = reinterpret_cast<const __m128i*>(a);
    auto pB = reinterpret_cast<const __m128i*>(b);
    auto pRes = reinterpret_cast<__m128i*>(res);
    std::size_t i = 0;
    for (std::size_t j = n / 16; j--; i++) {
        __m128i max = _mm_max_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
        __m128i min = _mm_min_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
        _mm_store_si128(pRes + i, _mm_sub_epi8(max, min));
    }
    for (i *= 16; i < n; i++) { // fallback for the remaining <16 values
        std::int16_t  diff = a[i] - b[i];
        std::uint16_t mask = diff >> 15;
        res[i] = (diff + mask) ^ mask;
    }
}

// sse v2
void diff5(const std::uint8_t* a, const std::uint8_t* b, std::uint8_t* res,
    std::size_t n)
{
    auto pA = reinterpret_cast<const __m128i*>(a);
    auto pB = reinterpret_cast<const __m128i*>(b);
    auto pRes = reinterpret_cast<__m128i*>(res);
    std::size_t i = 0;
    const std::size_t UNROLL = 2;
    for (std::size_t j = n / (16 * UNROLL); j--; i += UNROLL) {
        __m128i max0 = _mm_max_epu8(_mm_load_si128(pA + i + 0), _mm_load_si128(pB + i + 0));
        __m128i min0 = _mm_min_epu8(_mm_load_si128(pA + i + 0), _mm_load_si128(pB + i + 0));
        __m128i max1 = _mm_max_epu8(_mm_load_si128(pA + i + 1), _mm_load_si128(pB + i + 1));
        __m128i min1 = _mm_min_epu8(_mm_load_si128(pA + i + 1), _mm_load_si128(pB + i + 1));
        _mm_store_si128(pRes + i + 0, _mm_sub_epi8(max0, min0));
        _mm_store_si128(pRes + i + 1, _mm_sub_epi8(max1, min1));
    }
    for (std::size_t j = n % (16 * UNROLL) / 16; j--; i++) {
        __m128i max = _mm_max_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
        __m128i min = _mm_min_epu8(_mm_load_si128(pA + i), _mm_load_si128(pB + i));
        _mm_store_si128(pRes + i, _mm_sub_epi8(max, min));
    }
    for (i *= 16; i < n; i++) { // fallback for the remaining <16 values
        std::int16_t  diff = a[i] - b[i];
        std::uint16_t mask = diff >> 15;
        res[i] = (diff + mask) ^ mask;
    }
}

int main() {
    const std::size_t ALIGN = 16; // sse requires 16 bit align
    const std::size_t N = 10 * 1024 * 1024 * 3;

    auto a = static_cast<uint8_t*>(_mm_malloc(N, ALIGN));
    auto b = static_cast<uint8_t*>(_mm_malloc(N, ALIGN));

    { // fill with random values
        std::mt19937 engine(std::random_device{}());
        std::uniform_int<std::uint8_t> distribution(0, 255);
        for (std::size_t i = 0; i < N; i++) {
            a[i] = distribution(engine);
            b[i] = distribution(engine);
        }
    }

    auto res0 = static_cast<uint8_t*>(_mm_malloc(N, ALIGN)); // diff0 results
    auto resX = static_cast<uint8_t*>(_mm_malloc(N, ALIGN)); // diff1+ results

    LARGE_INTEGER f, t0, t1;
    QueryPerformanceFrequency(&f);

    QueryPerformanceCounter(&t0);
    diff0(a, b, res0, N);
    QueryPerformanceCounter(&t1);
    printf("diff0: %.6f ms\n",
        static_cast<double>(t1.QuadPart - t0.QuadPart) / f.QuadPart * 1000);

#define TEST(diffX)\
    QueryPerformanceCounter(&t0);\
    diffX(a, b, resX, N);\
    QueryPerformanceCounter(&t1);\
    printf("%s: %.6f ms\n", #diffX,\
        static_cast<double>(t1.QuadPart - t0.QuadPart) / f.QuadPart * 1000);\
    for (std::size_t i = 0; i < N; i++) {\
        if (resX[i] != res0[i]) {\
            printf("error: %s(%03u, %03u) == %03u != %03u\n", #diffX,\
                a[i], b[i], resX[i], res0[i]);\
            break;\
        }\
    }

    TEST(diff1);
    TEST(diff2);
    TEST(diff3);
    TEST(diff4);
    TEST(diff5);

    _mm_free(a);
    _mm_free(b);
    _mm_free(res0);
    _mm_free(resX);

    getc(stdin);
    return 0;
}

3

编辑:更正我的答案,我在优化方面配置有误。

我在C语言中设置了一个快速测试程序,发现

a - b + (a < b) * ((b - a) << 1);

在我的环境中略微更好。我的方法的优点是消除了比较。你的版本会将a-b==0隐式处理为一个单独的情况,但这是不必要的。

我的测试结果如下:

  • 你的实现:371毫秒
  • 此实现:324毫秒
  • 加速比:14%

我尝试了一种非分支绝对值的方法,结果更好。需要注意的是,编译器是否认为输入或输出是带符号的与本题无关。它可以循环大型无符号值,但既然只需要处理小值(如问题所述),那么应该足够了。

s32 diff = a - b;
u32 mask = diff >> 31;
return (diff + mask) ^ mask;
  • 您的实现:371毫秒
  • 此实现:241毫秒
  • 加速:53%

你说得对,我复制的例子假设了>> 32的某种行为,并根据我的CPU简化了它,这不是正确的方法。我想出了一个解决方案并更新了我的答案。 - VoidStar

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接