如何让Cython函数接受浮点数或双精度数组输入?

3
假设我有以下(MCVE...)的Cython函数:
cimport cython

from scipy.linalg.cython_blas cimport dnrm2


cpdef double func(int n, double[:] x):
   cdef int inc = 1
   return dnrm2(&n, &x[0], &inc)

那么,我无法在一个np.float32数组x上调用它。

我该如何让func接受double[:]float[:],并交替调用dnrm2snrm2?目前我唯一的解决方案是创建两个函数,这会产生大量重复的代码。

1个回答

4

你可以使用融合类型。请注意,下面的代码在我的系统上无法编译,因为 ddotsdot 显然需要 5 个参数:

# cython: infer_types=True
cimport cython

from scipy.linalg.cython_blas cimport ddot, sdot

ctypedef fused anyfloat:
   double
   float

cpdef anyfloat func(int n, anyfloat[:] x):
   cdef int inc = 1
   if anyfloat is double:
      return ddot(&n, &x[0], &inc)
   else:
      return sdot(&n, &x[0], &inc)

是的,我混淆了,我的意思是dnrm2(不管怎样,这只是一个虚拟例子) “anyfloat是double”是否会造成昂贵的开销? - P. Camilleri
@P.Camilleri 我不是很确定,但我猜应该不行。 - Paul Panzer
1
if语句在编译时进行评估,因此完全没有开销(您可以检查生成的C代码以确认)。当从Python调用函数并且必须选择哪个版本时,开销发生(尽管它通常可以直接使用来自Cython的正确版本)。 - DavidW

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接