使用jCUDA进行复杂矩阵运算

4

使用jCuda操作复数的最佳方式是什么?我应该使用cuComplex格式还是其他解决方案(例如一个数组,实部和虚部一个接一个)?我会非常感激有这种计算类型的Java代码示例。

由于我的目的是使用GPU解决具有复数的大型线性方程组,我不想只依赖于jCuda。有哪些替代方法可以使用GPU进行此类计算?


使用cuDoubleComplex来表示您的复数。此外,jCUDA附带了jCUBLAS,它将允许您表示线性方程组。 - user3248346
在jCublas库文档中,除了像cublasSetMatrix这样的方法之外,开发人员还写道:“提供指向包含复数的浮点数组的指针要高效得多,其中数组中每对连续的数字描述一个复数的实部和虚部。”那么我应该如何操作指针呢? - user3173452
Java没有指针,所以你不能这样做。那些开发人员可能在使用另一种语言,可能是C或C++。在Java中,我看到他们创建了一个Pointer类来模仿C或C++指针。 - user3248346
不,那并非Java的自带功能。这是用户定义的。您需要创建该类并模拟指针行为。 - user3248346
好的,你说得有道理。然而,我仍然不明白如何使用分别带有实部和虚部的数组进行复杂计算。 - user3173452
显示剩余2条评论
1个回答

3
首先,关于Java在GPU上的计算问题,我在这里写了一些相关内容(链接)
您的应用场景似乎非常特定。您可能需要更详细地描述您的实际意图,因为这将决定所有设计决策。到目前为止,我只能给出一些基本提示。决定哪种方案最合适取决于您。
在Java世界和GPU世界之间搭建桥梁的主要困难之一是根本不同的内存处理方式。

C/C++中的内存布局

CUDA中的cuComplex结构体定义如下:
typedef float2 cuFloatComplex
typedef cuFloatComplex cuComplex;

在IT技术方面,float2基本上类似于一个

struct float2 {
    float x; 
    float y; 
};

(包括一些额外的对齐等指定符)

现在,当您在C/C++程序中分配cuComplex值的数组时,只需编写以下内容:

cuComplex *c = new cuComplex[100];

在这种情况下,可以保证所有这些cuComplex值的内存都是一个单一的、连续的内存块。这个内存块只由复数的所有xy值依次组成:
      _____________________________
c -> | x0 | y0 | x1 | y1 | x2 | y2 |... 
     |____|____|____|____|____|____|

这个连续的内存块可以轻松地复制到设备上:只需取指针,并调用如下命令:

cudaMemcpy(device, c, sizeof(cuComplex)*n, cudaMemcpyHostToDevice);

Java中的内存布局

假设您创建了一个在结构上等同于cuComplex结构的Java类,并分配了这些类的数组:

 class cuComplex {
     public float x;
     public float y;
 }

 cuComplex c[] = new cuComplex[100];

那么你就没有一个连续的内存块,其中包含float值。相反,你有一个cuComplex对象引用的数组,而相应的xy值则分散在各处:

      ____________________
c -> |  c0  |  c1  |  c2  |... 
     |______|______|______|
        |       |      |
        v       v      v
      [x,y]   [x,y]  [x,y]

关键点是:您不能将(Java)cuComplex对象的数组复制到设备上!
这有几个影响。在评论中,您已经提到了cublasSetVector方法,该方法以cuComplex数组作为参数,并且我试图强调这不是最有效的解决方案,而只是为了方便。实际上,此方法通过内部创建一个新的ByteBuffer,以便具有连续的内存块,使用从cuComplex []数组中获取的值填充此ByteBuffer,然后将此ByteBuffer复制到设备上。
当然,这会带来一个开销,在性能关键应用程序中,您很可能希望避免这种情况。
有几种解决此问题的选项。幸运的是,对于复数,解决方案相对容易:
不要使用cuComplex结构表示复数数组
相反,应将复数数组表示为单个连续内存块,其中复数的实部和虚部交错,分别是单个floatdouble值。这将允许不同后端之间的最大互操作性(省略某些详细信息,如对齐要求)。
不幸的是,这可能会引起一些不便和问题,而且对于此并没有通用的解决方案。
如果试图将其泛化,不仅涉及到复数,而是"结构"总体,则可以应用某种"模式":可以为这些结构创建接口,并创建一个该集合,其中包括实现此接口的类的实例的列表,这些类都由连续的内存块支持。这对于某些情况可能是适当的。但对于复数来说,每个复数的Java对象的内存开销可能过大。
另一种极端只处理裸float []double []数组也可能不是最佳解决方案。例如:如果您有一个表示复数的float值数组,那么如何将其中一个复数乘以另一个复数?
一个"中间"解决方案可以是创建一个允许访问复数的实部和虚部的接口。在实现中,这些复数存储在单个数组中,如上所述。
我在这里勾画了这样的实现。
注:
这只是一个示例,旨在展示基本思想并展示如何与类似JCublas的东西一起工作。对于您来说,根据实际目标,可能需要采用不同的策略:除了JCuda之外,还应该有哪些其他后端?在Java端处理复数应该有多“方便”?用于处理Java端复数的结构(类/接口)应该是什么样子的?
简而言之,在开始实现之前,您应该非常清楚地知道您的应用程序/库应该能够做什么。
import static jcuda.jcublas.JCublas2.*;
import static jcuda.jcublas.cublasOperation.CUBLAS_OP_N;
import static jcuda.runtime.JCuda.*;

import java.util.Random;

import jcuda.*;
import jcuda.jcublas.cublasHandle;
import jcuda.runtime.cudaMemcpyKind;

// An interface describing an array of complex numbers, residing
// on the host, with methods for accessing the real and imaginary
// parts of the complex numbers, as well as methods for copying
// the underlying data from and to the device
interface cuComplexHostArray
{
    int size();

    float getReal(int i);
    float getImag(int i);

    void setReal(int i, float real);
    void setImag(int i, float imag);

    void set(int i, cuComplex c);
    void set(int i, float real, float imag);

    cuComplex get(int i, cuComplex c);

    void copyToDevice(Pointer devicePointer);
    void copyFromDevice(Pointer devicePointer);
}

// A default implementation of a cuComplexHostArray, backed
// by a single float[] array
class DefaultCuComplexHostArray implements cuComplexHostArray
{
    private final int size;
    private final float data[];

    DefaultCuComplexHostArray(int size)
    {
        this.size = size;
        this.data = new float[size * 2];
    }

    @Override
    public int size()
    {
        return size;
    }

    @Override
    public float getReal(int i)
    {
        return data[i+i];
    }

    @Override
    public float getImag(int i)
    {
        return data[i+i+1];
    }

    @Override
    public void setReal(int i, float real)
    {
        data[i+i] = real;
    }

    @Override
    public void setImag(int i, float imag)
    {
        data[i+i+1] = imag;
    }

    @Override
    public void set(int i, cuComplex c)
    {
        data[i+i+0] = c.x;
        data[i+i+1] = c.y;
    }

    @Override
    public void set(int i, float real, float imag)
    {
        data[i+i+0] = real;
        data[i+i+1] = imag;
    }

    @Override
    public cuComplex get(int i, cuComplex c)
    {
        float real = getReal(i);
        float imag = getImag(i);
        if (c != null)
        {
            c.x = real;
            c.y = imag;
            return c;
        }
        return cuComplex.cuCmplx(real, imag);
    }

    @Override
    public void copyToDevice(Pointer devicePointer)
    {
        cudaMemcpy(devicePointer, Pointer.to(data),
            size * Sizeof.FLOAT * 2,
            cudaMemcpyKind.cudaMemcpyHostToDevice);
    }

    @Override
    public void copyFromDevice(Pointer devicePointer)
    {
        cudaMemcpy(Pointer.to(data), devicePointer,
            size * Sizeof.FLOAT * 2,
            cudaMemcpyKind.cudaMemcpyDeviceToHost);
    }
}

// An example that performs a "gemm" with complex numbers, once
// in Java and once in JCublas2, and verifies the result
public class JCublas2ComplexSample
{
    public static void main(String args[])
    {
        testCgemm(500);
    }

    public static void testCgemm(int n)
    {
        cuComplex alpha = cuComplex.cuCmplx(0.3f, 0.2f);
        cuComplex beta  = cuComplex.cuCmplx(0.1f, 0.7f);
        int nn = n * n;

        System.out.println("Creating input data...");
        Random random = new Random(0);
        cuComplex[] rhA = createRandomComplexRawArray(nn, random);
        cuComplex[] rhB = createRandomComplexRawArray(nn, random);
        cuComplex[] rhC = createRandomComplexRawArray(nn, random);

        random = new Random(0);
        cuComplexHostArray hA = createRandomComplexHostArray(nn, random);
        cuComplexHostArray hB = createRandomComplexHostArray(nn, random);
        cuComplexHostArray hC = createRandomComplexHostArray(nn, random);

        System.out.println("Performing Cgemm with Java...");
        cgemmJava(n, alpha, rhA, rhB, beta, rhC);

        System.out.println("Performing Cgemm with JCublas...");
        cgemmJCublas(n, alpha, hA, hB, beta, hC);

        boolean passed = isCorrectResult(hC, rhC);
        System.out.println("testCgemm "+(passed?"PASSED":"FAILED"));
    }

    private static void cgemmJCublas(
        int n,
        cuComplex alpha,
        cuComplexHostArray A,
        cuComplexHostArray B,
        cuComplex beta,
        cuComplexHostArray C)
    {
        int nn = n * n;

        // Create a CUBLAS handle
        cublasHandle handle = new cublasHandle();
        cublasCreate(handle);

        // Allocate memory on the device
        Pointer dA = new Pointer();
        Pointer dB = new Pointer();
        Pointer dC = new Pointer();
        cudaMalloc(dA, nn * Sizeof.FLOAT * 2);
        cudaMalloc(dB, nn * Sizeof.FLOAT * 2);
        cudaMalloc(dC, nn * Sizeof.FLOAT * 2);

        // Copy the memory from the host to the device
        A.copyToDevice(dA);
        B.copyToDevice(dB);
        C.copyToDevice(dC);

        // Execute cgemm
        Pointer pAlpha = Pointer.to(new float[]{alpha.x, alpha.y});
        Pointer pBeta = Pointer.to(new float[]{beta.x, beta.y});
        cublasCgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n,
            pAlpha, dA, n, dB, n, pBeta, dC, n);

        // Copy the result from the device to the host
        C.copyFromDevice(dC);

        // Clean up
        cudaFree(dA);
        cudaFree(dB);
        cudaFree(dC);
        cublasDestroy(handle);
    }

    private static void cgemmJava(
        int n,
        cuComplex alpha,
        cuComplex A[],
        cuComplex B[],
        cuComplex beta,
        cuComplex C[])
    {
        for (int i = 0; i < n; ++i)
        {
            for (int j = 0; j < n; ++j)
            {
                cuComplex prod = cuComplex.cuCmplx(0, 0);
                for (int k = 0; k < n; ++k)
                {
                    cuComplex ab =
                        cuComplex.cuCmul(A[k * n + i], B[j * n + k]);
                    prod = cuComplex.cuCadd(prod, ab);
                }
                cuComplex ap = cuComplex.cuCmul(alpha, prod);
                cuComplex bc = cuComplex.cuCmul(beta, C[j * n + i]);
                C[j * n + i] = cuComplex.cuCadd(ap, bc);
            }
        }
    }

    private static cuComplex[] createRandomComplexRawArray(
        int n, Random random)
    {
        cuComplex c[] = new cuComplex[n];
        for (int i = 0; i < n; i++)
        {
            float real = random.nextFloat();
            float imag = random.nextFloat();
            c[i] = cuComplex.cuCmplx(real, imag);
        }
        return c;
    }

    private static cuComplexHostArray createRandomComplexHostArray(
        int n, Random random)
    {
        cuComplexHostArray c = new DefaultCuComplexHostArray(n);
        for (int i = 0; i < n; i++)
        {
            float real = random.nextFloat();
            float imag = random.nextFloat();
            c.setReal(i, real);
            c.setImag(i, imag);
        }
        return c;
    }

    private static boolean isCorrectResult(
        cuComplexHostArray result, cuComplex reference[])
    {
        float errorNormX = 0;
        float errorNormY = 0;
        float refNormX = 0;
        float refNormY = 0;
        for (int i = 0; i < result.size(); i++)
        {
            float diffX = reference[i].x - result.getReal(i);
            float diffY = reference[i].y - result.getImag(i);
            errorNormX += diffX * diffX;
            errorNormY += diffY * diffY;
            refNormX += reference[i].x * result.getReal(i);
            refNormY += reference[i].y * result.getImag(i);
        }
        errorNormX = (float) Math.sqrt(errorNormX);
        errorNormY = (float) Math.sqrt(errorNormY);
        refNormX = (float) Math.sqrt(refNormX);
        refNormY = (float) Math.sqrt(refNormY);
        if (Math.abs(refNormX) < 1e-6)
        {
            return false;
        }
        if (Math.abs(refNormY) < 1e-6)
        {
            return false;
        }
        return
            (errorNormX / refNormX < 1e-6f) &&
            (errorNormY / refNormY < 1e-6f);
    }
}

(顺便说一句:我可能会采用这个答案的部分并将它们扩展成样本和/或“如何......”页面,以供JCuda使用。提供这样的信息的任务已经在我的“待办事项”清单上有一段时间了。)


这是一个使用cgemm的超级有用的演示。将其扩展到zgemm并不困难。感谢!来自IntelliJ IDEA的小代码反馈 - cuComplex A[] 是C风格的声明,建议写成 cuComplex[] A - ATutorMe

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接