快速将Java数组转换为NumPy数组(Py4J)

3

有一些很好的例子,说明如何将NumPy数组转换为Java数组,但是没有反之的例子——如何将数据从Java对象转换回NumPy数组。我有一个像这样的Python脚本:

    from py4j.java_gateway import JavaGateway
    gateway = JavaGateway()            # connect to the JVM
    my_java = gateway.jvm.JavaClass();  # my Java object
    ....
    int_array=my_java.doSomething(int_array); # do something

    my_numpy=np.zeros((size_y,size_x));
    for jj in range(size_y):
        for ii in range(size_x):
            my_numpy[jj,ii]=int_array[jj][ii];

my_numpy 是一个 Numpy 数组,int_array 是一个 Java 数组,其中包含整数 - 一种 int[][] 类型的数组。在 Python 脚本中初始化如下:

    int_class=gateway.jvm.int       # make int class
    double_class=gateway.jvm.double # make double class

    int_array = gateway.new_array(int_class,size_y,size_x)
    double_array = gateway.new_array(double_class,size_y,size_x)

尽管它能够正常工作,但它不是最快的方式,而且速度相对比较慢——对于一个1000x1000的数组,转换需要超过5分钟。

有没有什么方法可以在合理的时间内完成这个任务?

如果我尝试:

    test=np.array(int_array)

我理解为:

    ValueError: invalid __array_struct__
2个回答

4
我曾经遇到类似的问题,并找到了一种解决方案,对于我测试的情况,速度提升了近220倍:将一个大小为1628x120的short整数数组从Java传输到Numpy时,运行时间从11秒减少到0.05秒。感谢这个相关的StackOverflow问题,我开始研究py4j字节数组,结果发现py4j可以高效地将Java字节数组转换为Python字节对象以及反向转换(通过值传递而非引用)。虽然这是一种比较迂回的方法,但并不太难。

因此,如果你想要传输一个大小为x 的整数数组intArray(为了例子简单起见,我假设它们都存储在你的对象实例变量中),你可以先编写一个Java函数将它转换成byte[]形式,如下所示:
public byte[] getByteArray() {
    // Set up a ByteBuffer called intBuffer
    ByteBuffer intBuffer = ByteBuffer.allocate(4*iMax*jMax); // 4 bytes in an int
    intBuffer.order(ByteOrder.LITTLE_ENDIAN); // Java's default is big-endian

    // Copy ints from intArray into intBuffer as bytes
    for (int i = 0; i < iMax; i++) {
        for (int j = 0; j < jMax; j++){
            intBuffer.putInt(intArray[i][j]);
        }
    }

    // Convert the ByteBuffer to a byte array and return it
    byte[] byteArray = intBuffer.array();
    return byteArray;
}

然后,您可以编写Python 3代码来接收字节数组并将其转换为正确形状的numpy数组:

byteArray = gateway.entry_point.getByteArray()
intArray = np.frombuffer(byteArray, dtype=np.int32)
intArray = intArray.reshape((iMax, jMax))

1
你的回答确实帮了我,非常感谢。但是需要做一个小修正 intArray = intArray.reshape((iMax, jMax)),reshape 应该有两个括号。 - Venkataramana Madugula
你说得对,@VenkataramanaMadugula,感谢您的纠正!我已经相应地编辑了我的答案。 - Erlend Magnus Viggen

2

我曾经遇到过类似的问题,只是试图绘制我从Java端通过py4j获取的光谱向量(Java数组)。 在这里,通过list()函数将Java数组转换为Python列表。这可能会提供一些线索,如何使用它来填充NumPy数组...

vectors = space.getVectorsAsArray(); # Java array (MxN)
wvl = space.getAverageWavelengths(); # Java array (N)

wavelengths = list(wvl)

import matplotlib.pyplot as mp
mp.hold
for i, dataset in enumerate(vectors):
    mp.plot(wavelengths, list(dataset))

我无法确定这是否比您使用的嵌套for循环更快,但它也可以完成任务:

import numpy
from numpy  import array
x = array(wavelengths)
v = array(list(vectors))

mp.plot(x, numpy.rot90(v))

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接