提高numpy映射操作的性能

7

我有一个大小为(4,X,Y)的numpy数组,其中第一维代表一个(R,G,B,A)四元组。

我的目标是将每个X*Y的RGBA四元组转置为X*Y的浮点值,并给出一个匹配它们的字典。

我的当前代码如下:

codeTable = {
    (255, 255, 255, 127): 5.5,
    (128, 128, 128, 255): 6.5,
    (0  , 0  , 0  , 0  ): 7.5,
}

for i in range(0, rows):
    for j in range(0, cols):
        new_data[i,j] = codeTable.get(tuple(data[:,i,j]), -9999)

这段代码中,data是一个大小为(4, rows, cols)的numpy数组,而new_data的大小为(rows, cols)

这段代码工作良好,但运行时间较长。我应该如何优化这段代码?

以下是完整示例:

import numpy

codeTable = {
    (253, 254, 255, 127): 5.5,
    (128, 129, 130, 255): 6.5,
    (0  , 0  , 0  , 0  ): 7.5,
}

# test data
rows = 2
cols = 2
data = numpy.array([
    [[253, 0], [128,   0], [128,  0]],
    [[254, 0], [129, 144], [129,  0]],
    [[255, 0], [130, 243], [130,  5]],
    [[127, 0], [255, 120], [255,  5]],
])

new_data = numpy.zeros((rows,cols), numpy.float32)

for i in range(0, rows):
    for j in range(0, cols):
        new_data[i,j] = codeTable.get(tuple(data[:,i,j]), -9999)

# expected result for `new_data`:
# array([[  5.50000000e+00,   7.50000000e+00],
#        [  6.50000000e+00,  -9.99900000e+03],
#        [  6.50000000e+00,  -9.99900000e+03], dtype=float32)

有多少行(rows)和列(cols)? - Will
也许这会有所帮助:http://stackoverflow.com/questions/36480358/whats-a-fast-non-loop-way-to-apply-a-dict-to-a-ndarray-meaning-use-elements - hpaulj
你能否创建一个可以运行的代码片段?当前的代码无法工作,因此很难帮助你。 - John Karasinski
第二列是否总是为-9999? - John Karasinski
@karasinski 不,那只是一个例子。实际上,输入数据是任意的PNG图像。 - lesenk
显示剩余2条评论
2个回答

1

numpy_indexed 包含一个向量化的 nd-array 变体,能够解决列表索引的问题,可以高效且简洁地解决您的问题(免责声明:我是该包的作者):

import numpy_indexed as npi
map_keys = np.array(list(codeTable.keys()))
map_values = np.array(list(codeTable.values()))
indices = npi.indices(map_keys, data.reshape(4, -1).T, missing='mask')
remapped = np.where(indices.mask, -9999, map_values[indices.data]).reshape(data.shape[1:])

你的解决方案似乎运行得非常完美,谢谢!后面我会讨论关于性能方面的改进。 - lesenk

1
这里有一种方法可以返回您期望的结果,但由于数据量很小,很难知道这是否对您来说更快。然而,由于我避免了双重循环,我想您会看到相当不错的加速效果。
import numpy
import pandas as pd


codeTable = {
    (253, 254, 255, 127): 5.5,
    (128, 129, 130, 255): 6.5,
    (0  , 0  , 0  , 0  ): 7.5,
}

# test data
rows = 3
cols = 2
data = numpy.array([
    [[253, 0], [128,   0], [128,  0]],
    [[254, 0], [129, 144], [129,  0]],
    [[255, 0], [130, 243], [130,  5]],
    [[127, 0], [255, 120], [255,  5]],
])

new_data = numpy.zeros((rows,cols), numpy.float32)

for i in range(0, rows):
    for j in range(0, cols):
        new_data[i,j] = codeTable.get(tuple(data[:,i,j]), -9999)

def create_output(data):
    # Reshape your two data sources to be a bit more sane
    reshaped_data = data.reshape((4, -1))
    df = pd.DataFrame(reshaped_data).T

    reshaped_codeTable = []
    for key in codeTable.keys():
        reshaped = list(key) + [codeTable[key]]
        reshaped_codeTable.append(reshaped)
    ct = pd.DataFrame(reshaped_codeTable)

    # Merge on the data, replace missing merges with -9999
    result = df.merge(ct, how='left')
    newest_data = result[4].fillna(-9999)

    # Reshape
    output = newest_data.reshape(rows, cols)
    return output

output = create_output(data)
print(output)
# array([[  5.50000000e+00,   7.50000000e+00],
#        [  6.50000000e+00,  -9.99900000e+03],
#        [  6.50000000e+00,  -9.99900000e+03])

print(numpy.array_equal(new_data, output))
# True

你的解决方案似乎只适用于正方形输入数据,并且在cols != rows时无法工作。但是感谢你的想法,我会进行调查。无论如何,速度比我的朴素双重循环解决方案要令人满意得多。 - lesenk
已修复!现在将按请求的行数和列数进行操作。 - John Karasinski
好的,你的代码对于其他数据形状不起作用。我已经在我的初始消息中更新了一个更复杂的例子。你的代码返回了正确的结果,但是在输出数组中的位置错误。 - lesenk
我根据你提供的新测试数据更新了我的答案,但仍然得到了正确的结果。你能否提供一个导致我的方法失败的示例数据? - John Karasinski
我第一条消息中的示例在我的电脑上失败了,使用的是Python 2.7(我已将解包*key更改为key[0],key[1],key[2],key[3])。我现在无法使用Python3,因为数据是使用目前仅在Python2中可用的库加载的。 - lesenk
显示剩余2条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接