最终，我得到了以下可用的代码：
// Eight triangles packed in structure-of-arrays form: each __m256 holds the
// same quantity for 8 triangles, one triangle per lane (PackedRay::intersect
// reports the winning lane index 0..7).
// Each [3] array holds one vector component per entry (presumably x, y, z —
// TODO confirm with the code that fills these).
struct PackedTriangles
{
__m256 e1[3];        // first edge vector (presumably v1 - v0); intersect computes a = e1 . (d x e2)
__m256 e2[3];        // second edge vector (presumably v2 - v0)
__m256 v0[3];        // base vertex; intersect computes s = ray origin - v0
__m256 inactiveMask; // OR-ed into the failure mask: a lane whose sign bit is set here is always a miss
};
// Closest hit accepted so far by PackedRay::intersect.
// intersect() only accepts hits strictly closer than the current `t`, so the
// same object can presumably be threaded through successive calls to find the
// nearest hit across many packets.
struct PackedIntersectionResult
{
    // Distance of the closest accepted hit; starts at Math::infinity
    // (presumably +inf, i.e. "nothing hit yet").
    float t = Math::infinity<float>();
    // Lane (0..7) of that hit inside its packet. Defaults to -1 so the value is
    // well-defined even if no intersect() call ever writes it (previously it was
    // left uninitialized).
    int idx = -1;
};
// A ray stored as AVX registers so it can be tested against all 8 triangles of
// a PackedTriangles packet at once. Presumably the same ray is broadcast into
// every lane (intersect reports a single closest hit) — TODO confirm with the
// code that builds it.
struct PackedRay
{
__m256 m_origin[3];    // ray origin, one component per entry
__m256 m_direction[3]; // ray direction, one component per entry
__m256 m_length;       // intersect rejects hits whose parameter t exceeds this
// Tests the ray against the 8 triangles in packedTris; returns true and updates
// `result` when a hit closer than result.t is found.
bool intersect(const PackedTriangles& packedTris, PackedIntersectionResult& result) const;
};
// Short aliases for the AVX/FMA intrinsics used by the helpers below.
// NOTE(review): these are object-like macros with very common names, so every
// later occurrence of the token `mul`, `div`, `cmp`, ... in this translation
// unit is rewritten. For example, `div` collides with the C library's div(),
// and a local variable or lambda named `cmp` would silently become
// _mm256_cmp_ps. Consider #undef-ing them after use or replacing them with
// namespaced inline wrappers. `cmp` would need to stay a macro or template
// because _mm256_cmp_ps requires a compile-time predicate.
#define or8f _mm256_or_ps
#define mul _mm256_mul_ps
#define fmsub _mm256_fmsub_ps
#define fmadd _mm256_fmadd_ps
#define cmp _mm256_cmp_ps
#define div _mm256_div_ps
// Lane-wise 3D cross product, result = a x b, evaluated independently in each
// of the 8 float lanes. Each output component is a single fused
// multiply-subtract of the form p*q - (r*s). The inputs for a component are
// read before that component is written, exactly as in the compact form.
void avx_multi_cross(__m256 result[3], const __m256 a[3], const __m256 b[3])
{
    // x component: a.y*b.z - b.y*a.z
    const __m256 byTimesAz = _mm256_mul_ps(b[1], a[2]);
    result[0] = _mm256_fmsub_ps(a[1], b[2], byTimesAz);

    // y component: a.z*b.x - b.z*a.x
    const __m256 bzTimesAx = _mm256_mul_ps(b[2], a[0]);
    result[1] = _mm256_fmsub_ps(a[2], b[0], bzTimesAx);

    // z component: a.x*b.y - b.x*a.y
    const __m256 bxTimesAy = _mm256_mul_ps(b[0], a[1]);
    result[2] = _mm256_fmsub_ps(a[0], b[1], bxTimesAy);
}
// Lane-wise 3D dot product: returns a[0]*b[0] + a[1]*b[1] + a[2]*b[2] for each
// of the 8 lanes. The first term is a plain multiply; the second and third
// terms are folded in, in that order, with fused multiply-add.
__m256 avx_multi_dot(const __m256 a[3], const __m256 b[3])
{
    __m256 sum = _mm256_mul_ps(a[0], b[0]);
    sum = _mm256_fmadd_ps(a[1], b[1], sum);
    sum = _mm256_fmadd_ps(a[2], b[2], sum);
    return sum;
}
// Lane-wise 3D vector difference: result[k] = a[k] - b[k] for k = 0, 1, 2,
// evaluated independently in each of the 8 float lanes.
void avx_multi_sub(__m256 result[3], const __m256 a[3], const __m256 b[3])
{
    for (int component = 0; component < 3; ++component)
    {
        result[component] = _mm256_sub_ps(a[component], b[component]);
    }
}
const __m256 oneM256 = _mm256_set1_ps(1.0f);
const __m256 minusOneM256 = _mm256_set1_ps(-1.0f);
const __m256 positiveEpsilonM256 = _mm256_set1_ps(1e-6f);
const __m256 negativeEpsilonM256 = _mm256_set1_ps(-1e-6f);
const __m256 zeroM256 = _mm256_set1_ps(0.0f);
// Möller–Trumbore ray/triangle test of this ray against the 8 triangles in
// `packedTris`, one triangle per AVX lane.
//
// packedTris: 8 triangles in SoA layout; lanes whose inactiveMask sign bit is
//             set are always treated as misses.
// result:     in/out. On entry result.t is the closest hit known so far; only
//             hits strictly closer than it are accepted.
// Returns true if some lane produced a hit with t < result.t (on entry). In
// that case result.t is set to the hit distance and result.idx to its lane
// (0..7); ties keep the lowest lane. Returns false and leaves `result`
// completely unchanged otherwise.
bool PackedRay::intersect(const PackedTriangles& packedTris, PackedIntersectionResult& result) const
{
    // Determinant a = e1 . (d x e2) and its reciprocal f.
    __m256 q[3];
    avx_multi_cross(q, m_direction, packedTris.e2);
    const __m256 a = avx_multi_dot(packedTris.e1, q);
    const __m256 f = div(oneM256, a);

    // Barycentric u, from s = origin - v0.
    __m256 s[3];
    avx_multi_sub(s, m_origin, packedTris.v0);
    const __m256 u = mul(f, avx_multi_dot(s, q));

    // Barycentric v and ray parameter t.
    __m256 r[3];
    avx_multi_cross(r, s, packedTris.e1);
    const __m256 v = mul(f, avx_multi_dot(m_direction, r));
    const __m256 t = mul(f, avx_multi_dot(packedTris.e2, r));

    // Reject lanes with a near-zero determinant (-eps < a < eps).
    __m256 failed = _mm256_and_ps(
        cmp(a, negativeEpsilonM256, _CMP_GT_OQ),
        cmp(a, positiveEpsilonM256, _CMP_LT_OQ));
    // The remaining tests use negated *unordered* predicates ("not >=", "not <=").
    // For ordinary values they equal the plain < / > tests, but a NaN in u, v or
    // t now marks the lane as failed instead of passing every ordered compare.
    failed = or8f(failed, cmp(u, zeroM256, _CMP_NGE_UQ));                  // u < 0
    failed = or8f(failed, cmp(v, zeroM256, _CMP_NGE_UQ));                  // v < 0
    failed = or8f(failed, cmp(_mm256_add_ps(u, v), oneM256, _CMP_NLE_UQ)); // u + v > 1
    failed = or8f(failed, cmp(t, zeroM256, _CMP_NGE_UQ));                  // behind the origin
    failed = or8f(failed, cmp(t, m_length, _CMP_NLE_UQ));                  // beyond the ray length
    failed = or8f(failed, packedTris.inactiveMask);

    // Bit i set <=> lane i is a valid hit. The mask is derived from `failed`
    // rather than from the sign bits of t. Otherwise a valid hit at t == -0.0
    // would look like a miss, and the early-out would drop it.
    const int hitMask = ~_mm256_movemask_ps(failed) & 0xFF;
    if (hitMask == 0)
    {
        return false;
    }

    // Read the lanes through a well-defined store instead of type-punning the
    // register with a C-style pointer cast.
    alignas(32) float tLanes[8];
    _mm256_store_ps(tLanes, t);

    // Closest valid lane that beats the caller's current best; strict '<' plus
    // ascending lane order keeps the lowest lane on ties.
    int bestLane = -1;
    float bestT = result.t;
    for (int lane = 0; lane < 8; ++lane)
    {
        const bool laneHit = ((hitMask >> lane) & 1) != 0;
        if (laneHit && tLanes[lane] < bestT)
        {
            bestT = tLanes[lane];
            bestLane = lane;
        }
    }

    if (bestLane < 0)
    {
        // Valid hits exist but none is closer than result.t. Leave `result`
        // (including idx from an earlier packet) untouched; it used to be reset to -1.
        return false;
    }

    result.t = bestT;
    result.idx = bestLane;
    return true;
}
结果
这些结果非常好：对于包含 100,000 个三角形的场景，我获得了 84% 的加速！而对于非常小的场景（只有 20 个三角形），性能下降了 13%。不过这没关系，因为这种场景并不常见。