np.newaxis在Numba nopython中的应用

7

有没有办法在Numba nopython 模式下使用 np.newaxis?以便在不退回到Python的情况下应用广播函数?

例如:

@jit(nopython=True)
def toto():
    a = np.random.randn(20, 10)
    b = np.random.randn(20) 
    c = np.random.randn(10)
    d = a - b[:, np.newaxis] * c[np.newaxis, :]
    return d

谢谢

3个回答

9

在我的情况下 (numba: 0.35, numpy: 1.14.0),expand_dims 很好地发挥了作用:

import numpy as np
from numba import jit

@jit(nopython=True)
def toto():
    a = np.random.randn(20, 10)
    b = np.random.randn(20) 
    c = np.random.randn(10)
    d = a - np.expand_dims(b, -1) * np.expand_dims(c, 0)
    return d

当然,我们可以使用广播省略第二个expand_dims

7

您可以使用reshape来完成这个操作,看起来目前不支持[:, None]的索引。请注意,这可能不比在python中进行向量化处理更快,因为已经进行了向量化处理。

@jit(nopython=True)
def toto():
    a = np.random.randn(20, 10)
    b = np.random.randn(20) 
    c = np.random.randn(10)
    d = a - b.reshape((-1, 1)) * c.reshape((1,-1))
    return d

1
我已经尝试过了,但是出现了“reshape()仅支持连续的数组”的错误。当然,“toto()”只是一个示例,而非我的实际函数。 - EntrustName
你可以使用 b.copy().reshape((-1,1))。如果你的数组不是连续的,我相信这个操作仍然会进行复制,但不能百分之百确定。 - chrisb

1

这可以通过最新版本的Numba(0.27)和numpy stride_tricks实现。你需要小心处理,而且有点丑陋。阅读as_strideddocstring确保你理解发生了什么,因为这不是“安全”的,它不检查形状或步幅。

import numpy as np
import numba as nb

a = np.random.randn(20, 10)
b = np.random.randn(20) 
c = np.random.randn(10)

def toto(a, b, c):

    d = a - b[:, np.newaxis] * c[np.newaxis, :]
    return d

@nb.jit(nopython=True)
def toto2(a, b, c):
    _b = np.lib.stride_tricks.as_strided(b, shape=(b.shape[0], 1), strides=(b.strides[0], 0))
    _c = np.lib.stride_tricks.as_strided(c, shape=(1, c.shape[0]), strides=(0, c.strides[0]))
    d = a - _b * _c

    return d

x = toto(a,b,c)
y = toto2(a,b,c)
print np.allclose(x, y) # True

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接