高效浮点数比较(Cortex-A8)

6
有一个大约有10万个浮点变量的数组,并且有一个阈值(也是浮点数)。
问题是,我必须将数组中的每个变量与阈值进行比较,但NEON标志传输需要很长时间(根据分析器约20个周期)。
有没有有效的方法来比较这些值? 注意:由于舍入误差无关紧要,我尝试了以下内容:
float arr[10000];
float threshold; 
....

int a = arr[20]; // e.g.
int t = threshold;
if (t > a) {....}

但在这种情况下,我得到了以下处理器命令序列:
vldr.32        s0, [r0]
vcvt.s32.f32   s0, s0
vmov           r0, s0    <--- takes 20 cycles as `vmrs APSR_nzcv, fpscr` in case of 
cmp            r0, r1         floating point comparison

由于转换发生在NEON上,因此无论我按照所描述的方式比较整数还是浮点数都没有问题。


3
您的代码与问题陈述不一致 - 数据是浮点数,但您显示阈值为整数 - 您还将每个浮点数据值强制转换为整数 - 为什么?如果您的数据是浮点数,那么阈值应该是浮点数,并且您应该进行浮点比较(即不进行整数和浮点数之间的转换)。此外,您打算如何处理大于(或小于)阈值的值(这将决定NEON是否合适)? - Paul R
2
许多人因为不知道如何避免和正确地编写SIMD程序而认为NEON比ARM慢,从而放弃了它。根据你的需求,要么一开始就不适合使用SIMD,要么你不知道如何处理NEON中的if-else语句。 - Jake 'Alquimista' LEE
2
if-else在NEON上的实现:1.使用VCnn指令创建掩码值。2.在两个不同的寄存器中计算if和else情况下的结果。3.使用步骤1中的掩码值,通过VBnn指令将两个结果合并。 - Jake 'Alquimista' LEE
2
由于在NEON上必须计算满足条件的两种情况,所以如果计算非常复杂,它可能比在ARM上慢。如果有多个相关的if-else语句,请忘记NEON。 - Jake 'Alquimista' LEE
2
NEON->ARM寄存器传输总是需要11~14个周期,应该在循环内尽可能避免。另外,即使没有舍入误差,你也可以非常快地比较浮点数,为什么还要进行类型转换呢?我将在另一个答案中向您展示如何实现。 - Jake 'Alquimista' LEE
显示剩余2条评论
4个回答

5
如果浮点数是32位IEEE-754格式,整数也是32位,并且如果没有+无穷大、-无穷大和NaN值,我们可以用一个小技巧将浮点数比作整数进行比较:
#include <stdio.h>
#include <limits.h>
#include <assert.h>

#define C_ASSERT(expr) extern char CAssertExtern[(expr)?1:-1]
C_ASSERT(sizeof(int) == sizeof(float));
C_ASSERT(sizeof(int) * CHAR_BIT == 32);

int isGreater(float* f1, float* f2)
{
  int i1, i2, t1, t2;

  i1 = *(int*)f1;
  i2 = *(int*)f2;

  t1 = i1 >> 31;
  i1 = (i1 ^ t1) + (t1 & 0x80000001);

  t2 = i2 >> 31;
  i2 = (i2 ^ t2) + (t2 & 0x80000001);

  return i1 > i2;
}

int main(void)
{
  float arr[9] = { -3, -2, -1.5, -1, 0, 1, 1.5, 2, 3 };
  float thr;
  int i;

  // Make sure floats are 32-bit IEE754 and
  // reinterpreted as integers as we want/expect
  {
    static const float testf = 8873283.0f;
    unsigned testi = *(unsigned*)&testf;
    assert(testi == 0x4B076543);
  }

  thr = -1.5;
  for (i = 0; i < 9; i++)
  {
    printf("%f %s %f\n", arr[i], "<=\0> " + 3*isGreater(&arr[i], &thr), thr);
  }

  thr = 1.5;
  for (i = 0; i < 9; i++)
  {
    printf("%f %s %f\n", arr[i], "<=\0> " + 3*isGreater(&arr[i], &thr), thr);
  }

  return 0;
}

输出:

-3.000000 <= -1.500000
-2.000000 <= -1.500000
-1.500000 <= -1.500000
-1.000000 >  -1.500000
0.000000 >  -1.500000
1.000000 >  -1.500000
1.500000 >  -1.500000
2.000000 >  -1.500000
3.000000 >  -1.500000
-3.000000 <= 1.500000
-2.000000 <= 1.500000
-1.500000 <= 1.500000
-1.000000 <= 1.500000
0.000000 <= 1.500000
1.000000 <= 1.500000
1.500000 <= 1.500000
2.000000 >  1.500000
3.000000 >  1.500000

当然,如果你的阈值不变,预先计算在isGreater()中使用比较运算符的最终整数值是有意义的。
如果你担心在C/C++中出现未定义行为,你可以用汇编重写上述代码。

看起来是个好主意。我仍然在处理vmov.32的问题,但总体来说这是一个好主意。谢谢。 - Alex
@vasile:你在说什么错误?什么是复杂的?如果是,你怎么让它变得更简单? - Alexey Frunze
我是在参考@Paul-R的答案。 - Sam

2

如果你的数据是浮点型,那么你应该使用浮点型进行比较,例如:

float arr[10000];
float threshold;
....

float a = arr[20]; // e.g.
if (threshold > a) {....}

否则,您将需要进行昂贵的浮点数和整数之间的转换。

如果我比较两个浮点数,它会导致昂贵的标志寄存器传输。这就是为什么我尝试比较两个整数的原因。 - Alex
当阈值测试为真/假时,您随后执行哪些操作? - Paul R
vcmpe.f32 s17, s16 vmrs APSR_nzcv, fpscr 如果我理解你的问题了。 - Alex
Alex:我的意思是:在你的阈值测试后面的 { ... } 中会发生什么?这将决定是否适合使用 NEON 进行测试。 - Paul R
测试之后有很多代码,所以我将其标记为“...”。 - Alex
1
如果代码量很大,那么测试成本应该是可以忽略不计的(除非绝大多数数据点不需要处理?)。 - Paul R

2
你的例子展示了编译器生成的代码有多糟糕:
它使用NEON加载一个值,只为将其转换为int,然后执行NEON->ARM转移,导致浪费11~14个周期。
最好的解决方案是完全用手写汇编来编写函数。
然而,有一个简单的技巧可以在不进行类型转换和截断的情况下快速比较浮点数:
阈值正数(与int比较一样快):
void example(float * pSrc, float threshold, unsigned int count)
{
  typedef union {
    int ival,
    unsigned int uval,
    float fval
  } unitype;

  unitype v, t;
  if (count==0) return;
  t.fval = threshold;
  do {
    v.fval = *pSrc++;
    if (v.ival < t.ival) {
      // your code here
    }
    else {
      // your code here (optional)
    }
  } while (--count);
}

阈值负数(比整数比较多1个周期):

void example(float * pSrc, float threshold, unsigned int count)
{
  typedef union {
    int ival,
    unsigned int uval,
    float fval
  } unitype;

  unitype v, t, temp;
  if (count==0) return;
  t.fval = threshold;
  t.uval &= 0x7fffffff;
  do {
    v.fval = *pSrc++;
    temp.uval = v.uval ^ 0x80000000;
    if (temp.ival >= t.ival) {
      // your code here
    }
    else {
      // your code here (optional)
    }
  } while (--count);
}

我认为这个比上面接受的那个要快得多。再次,我有点晚了。



0

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接