为所有模板参数的组合实例化函数,在运行时选择实例化。

3

非常抱歉如果这个问题之前已经被问过,但我没有找到完全相同的问题。

我有一个类似于下面代码的CUDA内核模板:

template<int firstTextureIndex, int secondTextureIndex, int thirdTextureIndex> __global__ void myKernel

三种纹理索引模板类型将在运行时范围从0-7,并且在运行时不会被知道。我需要实例化这个内核的所有512种组合,然后根据纹理索引的运行时值调用正确的模板。
我从来没有编写过任何预处理宏,也试图避免使用它。另一篇帖子(这里)展示了如何通过递归地为一个模板变量实例化许多类模板。
template<int i>
class loop {
    loop<i-1> x;
}

template<>
class loop<1> {
}

loop<10> l;

我正在努力将这个扩展到3个变量和一个函数(而不是一个类),以适应我的情况。即使我找出了如何以这种方式实例化所有这些内容,如何在运行时调用512种可能性中的1种,而不使用嵌套的switch语句呢?为了说明问题,我试图避免的嵌套switch语句会像这样:

switch(firstTextureIndex)
{
    case 0:
        switch(secondTextureIndex)
        {
            case 1:
                switch(thirdTextureIndex)
                {
                    case 2:
                        myKernel<0, 1, 2><<<grid, block>>>(param1, param2, param3);
                        break;
                }
             break;
        }
    break;
}

如果我找出如何为所有的0-7实例化,那么我可以这样调用它吗:
myKernel<i, j, k><<<grid, block>>>(param1, param2); 

如果我把i、j和k定义为只包含0-7的枚举类型,编译器就能知道所有可能的值,并且由于我实例化了它们,所以这是可以的吗?
请注意,这个三重模板之所以要传递纹理索引,是有充分理由的,但为了简洁起见,我省略了说明。非常感谢任何关于实例化和/或调用此内核的帮助。
编辑:Jarod42提供了一个有效的解决方案,完全符合我的要求。不幸的是,我现在意识到c++标准在这里很重要。我使用的是c++98/03与最新稳定版本的boost库相结合,因此最好使用这些方法来解决问题。我有可能会使用c++11,但由于我们编译器的限制,无法使用c++14。

1
使用 firstTextureIndex * 64 + secondTextureIndex * 8 + thirdTextureIndex,你可以限制为一个变量。 - Jarod42
好主意!这对我来说实际上是可行的,而且能够大大简化问题。在运行时,我仍然需要从512个可能性中选择一个,并实例化所有这些,但至少可以将其压缩为一个变量。 - user1777820
请考虑删除CUDA相关的引用 - 因为您的问题和答案实际上与CUDA无关。只需像@havogt的示例中那样使用一个虚拟函数即可。 - einpoklum
2个回答

4
您可以这样做:

您可以采取类似的操作:

template <std::size_t I>
void do_job()
{
    myKernel<I / 64, (I / 8) % 8, I % 8>{}();
}

template <std::size_t ... Is>
void callMyKernel(std::index_sequence<Is...>, std::size_t i, std::size_t j, std::size_t k)
{
    std::function<void()> fs[] = {&do_job<Is>...};

    fs[i * 64 + j * 8 + k]();
}

void callMyKernel(std::size_t i, std::size_t j, std::size_t k)
{
    callMyKernel(std::make_index_sequence<512>{}, i, j, k);
}

Demo


谢谢!这是一个非常好的解决方案,我从你的演示中学到了很多。不幸的是,我只能使用c++99,但至少我可以使用最新稳定版本的boost库。也许我可以说服同事升级到c++11,但我们的操作系统编译器不支持c++14。我将继续自己寻找一个最多使用c++11的解决方案,但你显然对多态性有更全面的理解。有没有关于使用boost的c++99(最好)或c++11的解决方案的想法?我会编辑我的问题,包括这些相关信息。 - user1777820
index_sequence的相关内容可以在C++11中完成(但不能在C++03中完成)。 - Jarod42
1
我希望我能够将你们两个都标记为正确。我选择@havogt的解决方案,因为我可以在不改变我们使用的C++标准的情况下使用它。不过,你对原始问题的回答是完全正确的,这是我的疏忽没有提及标准。 - user1777820

2
以下代码是基于C++98/03boost.MPL的实现。肯定有改进的余地(例如隐藏全局指针数组,检查非法组合等)。
思路是递归运行整数列表的所有组合,从而为每个组合填充函数指针数组。
我以前使用过类似但更复杂的代码来在运行时选择最佳的内核参数组合(自动调优),例如launch_bounds和其他选项:culgt/runtimechooser
以下是简化版本:
#include <iostream>

#include <boost/mpl/vector.hpp>
#include <boost/mpl/vector_c.hpp>
#include <boost/mpl/for_each.hpp>
#include <boost/mpl/push_back.hpp>
#include <boost/mpl/at.hpp>

namespace mpl = boost::mpl;

template<int index1, int index2, int index3> void execKernel()
{
    std::cout << "Kernel called with " << index1 << "/" << index2 << "/" << index3 << std::endl;
}

typedef void (*FPTR)();
FPTR ptr[512];

struct NIL
{
public:
    static const int value = 0;
};

template<typename Seq, typename T1, typename T2 = NIL> class MakeSequenceImpl
{
public:
    template<typename T> void operator()(T)
    {
        typedef MakeSequenceImpl<typename mpl::push_back<Seq,T>::type,T2> RunSeq;
        mpl::for_each<T1>( RunSeq() );
    }
};

template<typename Seq> class MakeSequenceImpl<Seq, NIL, NIL>
{
public:
    template<typename T> void operator()(T)
    {
        typedef typename mpl::push_back<Seq,T>::type FinalSeq;

        int index = mpl::at<FinalSeq,mpl::int_<0> >::type::value * 64
                + mpl::at<FinalSeq,mpl::int_<1> >::type::value * 8
                + mpl::at<FinalSeq,mpl::int_<2> >::type::value;

        ptr[index] = execKernel<mpl::at<FinalSeq,mpl::int_<0> >::type::value, mpl::at<FinalSeq,mpl::int_<1> >::type::value, mpl::at<FinalSeq,mpl::int_<2> >::type::value>;
    }
};


template<typename T0, typename T1, typename T2> class MakeSequence
{
public:
    typedef mpl::vector_c<int> Seq;

    MakeSequence()
    {
        typedef MakeSequenceImpl<Seq, T1, T2> RunSeq;
        mpl::for_each<T0>( RunSeq() );
    }
};


void callWrapper( int i, int j, int k )
{
    ptr[i*64+j*8+k]();
}

typedef mpl::vector_c< int, 0, 1, 2, 3, 4, 5, 6, 7 > list1;
typedef mpl::vector_c< int, 0, 1, 2, 3, 4, 5, 6, 7 > list2;
typedef mpl::vector_c< int, 0, 1, 2, 3, 4, 5, 6, 7 > list3;

int main()
{
    MakeSequence<list1,list2,list3> frontend;

    int i,j,k;

    std::cin >> i;
    std::cin >> j;
    std::cin >> k;

    callWrapper(i,j,k);
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接