在numpy中获取结果数组的数据类型

6

我希望预先分配一个数组操作输出的内存空间,并需要知道要将其设置为什么数据类型。下面是我想要实现的函数,但是它的样子很丑。

import numpy as np

def array_operation(arr1, arr2):
    out_shape = arr1.shape
    # Get the dtype of the output, these lines are the ones I want to replace.
    index1 = ([0],) * arr1.ndim
    index2 = ([0],) * arr2.ndim
    tmp_arr = arr1[index1] * arr2[index2]
    out_dtype = tmp_arr.dtype
    # All so I can do the following.
    out_arr = np.empty(out_shape, out_dtype)

上述代码看起来相当不简洁。numpy是否有一个能够完成此任务的函数?

你只对乘法后的dtype感兴趣吗? - unutbu
@unutbu,你能解释一下为什么你会这样问吗? - Mike Graham
@Mike Graham:我们可以枚举所有基本数据类型,计算它们相乘后的结果类型,并将结果存储在字典中。但是对于更复杂的数据类型,这种方法就不适用了... - unutbu
@ubutbu,你能举一个例子吗?在另一个操作中,乘法和另一个操作会给出不同的结果,这取决于原始数据类型。 - Mike Graham
@Mike Graham:我更多地考虑了任意(numpy)函数调用。 - unutbu
2个回答

8
你需要的是 numpy.result_type
(顺便提一下,你知道吗?你可以把所有多维数组看作是一维数组来访问。你不需要访问 x[0, 0, 0, 0, 0] -- 你可以访问 x.flat[0]。)

你使用的numpy版本是什么?1.3.0似乎没有result_type函数。 - kiyo
提醒一下,x[0] 会对除第一个轴之外的所有轴进行“full slice”。这可能会导致很多额外的乘法操作。 - kiyo

1

对于使用 numpy 版本小于 1.6 的用户,可以使用以下代码:

def result_type(arr1, arr2):
    x1 = arr1.flat[0]
    x2 = arr2.flat[0]
    return (x1 * x2).dtype

def array_operation(arr1, arr2):
    return np.empty(arr1.shape, result_type(arr1, arr2))

这与您发布的代码并没有太大区别,尽管我认为arr1.flat[0]index1 = ([0],) * arr1.ndim; arr1[index1]稍微好一些。

对于numpy版本>= 1.6,请使用Mike Graham的答案,np.result_type


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接