在Numba中,array(float64, 1d, C)和array(float64, 1d, A)有什么区别?

6

我有两个简单且相似的函数。其中一个可以使用numba编译,而另一个则不行。我不明白它们之间的区别。以下是这两个函数:

第一个函数:

@nb.njit(float64[:](float64[:], float64[:]))
def arrAdd(a,b):
    assert a.shape == b.shape
    return a + b

编译成功。当我调用它时,

arrAdd(np.array([1,2.0,21]),np.array([2,3.0,1]))

它会返回:

array([ 3.,  5., 22.])

第二个:
c = np.array([1,2.0,21])
@nb.njit
def arrAdd1(arr):
    return arrAdd(arr,c)

然而,当我调用这个函数时:

arrAdd1([2,3.0,1])

它将会显示:

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of type(CPUDispatcher(<function arrAdd at 0x00000212A9FDC670>)) with parameters (array(float64, 1d, C), readonly array(float64, 1d, C))
Known signatures:
 * (array(float64, 1d, A), array(float64, 1d, A)) -> array(float64, 1d, A)
During: resolving callee type: type(CPUDispatcher(<function arrAdd at 0x00000212A9FDC670>))
During: typing of call at <ipython-input-57-c77e552c5560> (4)


File "<ipython-input-57-c77e552c5560>", line 4:
def arrAdd1(arr):
    return arrAdd(arr, c)
    ^

那么 array(float64, 1d, C) 和 array(float64, 1d, A) 有什么区别呢?
1个回答

3

A (任意), C (按行连续) 和 F (按列连续) 是 数组布局 的三种类型。但这不是你的例子中存在的问题。

问题 1

在这一行中

arrAdd1([2,3.0,1])

你传递的是列表而不是数组。
以下简化版本可以正常工作:
@nb.njit                            # No types
def arrAdd(a, b):
    assert a.shape == b.shape
    return a + b

a = arrAdd(np.array([1, 2.0, 21]), np.array([2, 3.0, 1]))
print(a)

c = np.array([1, 2.0, 21])
@nb.njit                            # No types
def arrAdd1(arr):
    return arrAdd(arr, c)

a = arrAdd1(np.array([2,3.0,1]))    # Pass an array
print(a)

并且生成

[ 3.  5. 22.]
[ 3.  5. 22.]

问题 2

在您的例子中,arrAdd1() 被定义为闭包,因此 c 在函数中成为一个常量。

如果你真的想使用显式参数类型,你需要指定 addArr() 至少在第二个参数接收一个常量数组。

只要函数不修改它们的输入,你可以声明所有的输入参数为只读,就像这个例子一样,产生相同的结果:

vector = nb.types.Array(dtype=f8, ndim=1, layout="A")
readonly_vector = nb.types.Array(dtype=f8, ndim=1, layout="A", readonly=True)

@nb.njit(vector(readonly_vector, readonly_vector))
def arrAdd(a, b):
    assert a.shape == b.shape
    return a + b

a = arrAdd(np.array([1, 2.0, 21]), np.array([2, 3.0, 1]))
print(a)

c = np.array([1, 2.0, 21])
@nb.njit(vector(readonly_vector))
def arrAdd1(arr):
    return arrAdd(arr, c)

a = arrAdd1(np.array([2,3.0,1]))
print(a)

您可以将布局(A、C、F)更改为最适合您的布局。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接