在使用Numba时如何指定“字符串”数据类型?

8
Numba无法识别该字符串。我该如何纠正以下代码呢?谢谢!
@nb.jit(nb.float64(nb.float64[:], nb.char[:]), nopython=True, cache=True)
def func(x, y='cont'):
    """
    :param x: is np.array, x.shape=(n,)
    :param y: is a string, 
    :return: a np.array of same shape as x
    """
    return result
1个回答

6
以下内容适用于Numba 0.44版本:
import numpy as np
import numba as nb

from numba import types

@nb.jit(nb.float64[:](nb.float64[:], types.unicode_type), nopython=True, cache=True)
def func(x, y='cont'):
    """
    :param x: is np.array, x.shape=(n,)
    :param y: is a string, 
    :return: a np.array of same shape as x
    """
    print(y)
    return x

然而,如果你尝试运行没有指定y值的func函数,会出现错误,因为在定义中指定了第二个参数是必需的。我尝试着找出如何处理可选参数(看了一下types.Omitted),但是无法完全弄清楚。可能需要将签名不予指定,让numba进行正确的类型推断:

@nb.jit(nopython=True, cache=True)
def func2(x, y='cont'):
    """
    :param x: is np.array, x.shape=(n,)
    :param y: is a string, 
    :return: a np.array of same shape as x
    """
    print(y)
    return x

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接