在Cuda中使用表达式模板构建lambda表达式

3
我正在开发一个需要将任意函数传递给CUDA核心的CUDA应用程序。由于为每种可能情况声明函数指针并将它们传递到内核中会很麻烦(超过50个不同的函数),而且它们都是像sin(x)/y这样的基本函数组合,因此我希望 CUDA 核心具有一些最小的 Lambda-表达式功能。由于设备代码尚不支持 C++11 特性(就我所知),而且我在网上找不到任何相关信息,所以我决定自己教授表达式模板,并实现一些简单的 lambda-表达式规则以传递到内核中。
我想出了以下代码,这是一种可以在 NVCC 上编译并运行良好的最小实现。然而,通过这种方式,我只能实现具有1个变量的函数。是否有办法扩展我的代码来处理类似于 sin(_x) + _y 的函数组合?
提前感谢!
#include<math.h>

#ifdef __CUDACC__
#define HOST_DEVICE __host__ __device__
#else
#define HOST_DEVICE
#endif

struct Id {};

template <typename Op, typename Left, typename Right>
struct BinaryOp
{
    Left left;
    Right right;
    HOST_DEVICE BinaryOp(Left t1, Right t2) : left(t1), right(t2) {}

    HOST_DEVICE double operator() (double x) {
        return Op::apply(left(x), right(x));
    }
};

template <typename Op, typename Arg>
struct UnaryOp
{
    Arg arg;
    HOST_DEVICE UnaryOp(Arg t1) : arg(t1) {}

    HOST_DEVICE double operator() (double x) {
        return Op::apply(arg(x));
    }
};

template <>
struct UnaryOp<Id, double>
{
    HOST_DEVICE UnaryOp() {}
    HOST_DEVICE double operator() (double x) {
        return x;
    }
};

struct Sin
{
    HOST_DEVICE static double apply(double x) {
        return sin(x);
    }
};

struct Plus
{
    HOST_DEVICE static double apply(double a, double b) {
        return a + b;
    }
};

template <typename Left, typename Right>
BinaryOp<Plus, Left, Right> operator+ (Left lhs, Right rhs) {
    return BinaryOp<Plus, Left, Right>(lhs, rhs);
}

template <typename Arg>
UnaryOp<Sin, Arg> _sin(Arg arg) {
    return UnaryOp<Sin, Arg>(arg);
}

template <class T>
__global__ void test(T func, double x) {
    printf("%e\n", func(x));
}

int main () 
{
    UnaryOp<Id, double> _x;
    double x = 1.0;
    test<<<1, 1>>>(_sin(_x) + _x, x);
    cudaDeviceSynchronize();  // Needed or the host will return before kernel is finished
    return 0;
}

请查看表达式模板 - Constructor
@Constructor 谢谢,不过我已经详细阅读了那篇文章,并编写了我自己的实现表达式模板的代码。但是我认为该页面没有足够的信息来满足我想要做的事情:为多个变量构建 lambda 表达式。 - Alex Chen
你能解释一下为什么你所做的比函数指针更简单吗?我真的很想知道。我已经盯着你的代码两天了,但我仍然看不出优势。 - benathon
@portforwardpodcast 是的,原则上我可以为我需要使用的所有基本函数的可能组合定义一个设备函数,并根据需要将指针传递给内核。然而,每当我决定需要使用例如 f1(x, y) * f2(x, y) 时,我都需要定义一个新的设备函数,结果代码将被一堆我可能使用的小函数所淹没。另一个问题是,如果作为函数指针传递,Cuda 将不会内联设备函数,这会导致性能下降。 - Alex Chen
1个回答

1

所以在提问后,我花了一些时间并且想出了一个解决方案。虽然不太美观但对我自己有效。这是修改后的代码,支持最多3个自由变量。更多变量可以硬编码,但目前我的项目不需要。

#include<math.h>

#ifdef __CUDACC__
#define HOST_DEVICE __host__ __device__
#else
#define HOST_DEVICE
#endif

struct Id {};

template <typename Op, typename Left, typename Right>
struct BinaryOp
{
    Left left;
    Right right;
    HOST_DEVICE BinaryOp(Left t1, Right t2) : left(t1), right(t2) {}

    HOST_DEVICE double operator() (double x1, double x2 = 0.0, double x3 = 0.0) {
        return Op::apply(left(x1, x2, x3), right(x1, x2, x3));
    }
};

template <typename Op, typename Arg>
struct UnaryOp
{
    Arg arg;
    HOST_DEVICE UnaryOp(Arg t1) : arg(t1) {}

    HOST_DEVICE double operator() (double x1, double x2 = 0.0, double x3 = 0.0) {
        return Op::apply(arg(x1, x2, x3));
    }
};

template <int argnum>
struct Var
{
    HOST_DEVICE Var() {}
    HOST_DEVICE double operator() (double x1, double x2 = 0.0, double x3 = 0.0) {
        if (1 == argnum) return x1;
        else if (2 == argnum) return x2;
        else return x3;
    }
};

struct Sin
{
    HOST_DEVICE static double apply(double x) {
        return sin(x);
    }
};

struct Plus
{
    HOST_DEVICE static double apply(double a, double b) {
        return a + b;
    }
};

template <typename Left, typename Right>
BinaryOp<Plus, Left, Right> operator+ (Left lhs, Right rhs) {
    return BinaryOp<Plus, Left, Right>(lhs, rhs);
}

template <typename Arg>
UnaryOp<Sin, Arg> _sin(Arg arg) {
    return UnaryOp<Sin, Arg>(arg);
}

template <class T>
__global__ void test(T func, double x, double y, double z = 0.0) {
    printf("%e\n", func(x, y));
}

Var<1> _x;
Var<2> _y;

int main () 
{
    test<<<1, 1>>>(_sin(_x) + _y, 1.0, 2.0);
    cudaDeviceSynchronize();  // Needed or the host will return before kernel is finished
    return 0;
}

这显然是一个丑陋的hack。lambda表达式只适用于double(或可转换为double的类型)。然而,我目前想不到解决办法。希望NVCC能尽快支持c++11特性,这样我就不需要这种hack了。
如果有人能向我展示更好的解决方案,无论是库还是更好的组合方式,都将不胜感激。谢谢任何帮助!

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接