Cython将双精度复数返回为浮点数复数,会导致表达式不是纯C。

3
我在使用Cython时遇到了使用complex64_t的问题。下面是我的一个简单的Cython例子。
cimport numpy as cnp

cdef extern from "complex.h":
    double complex cexp(double complex)

cpdef example():
    cdef float b = 2.0
    cdef cnp.complex64_t temp1
    cdef cnp.complex128_t temp2

    temp1 = cexp(1j * b)
    temp2 = cexp(1j * b)

当我使用以下setup.py将文件进行cythonize处理时:
from distutils.core import setup
from Cython.Build import cythonize
from distutils.extension import Extension
import numpy as np


ext_modules = [
    Extension(
        "bug_example",
        ["bug_example.pyx"],
        include_dirs=[np.get_include()],
    )
]


setup(
    name='bug_example',
    ext_modules=cythonize(ext_modules, annotate=True,
                          compiler_directives={'boundscheck': False})
)

所有内容都能够编译通过,但是在包含以下代码的行上出现了黄色警告(不是纯C语言)

temp1 = cexp(1j * b)

但不在此处。
temp2 = cexp(1j * b)

似乎存在将双倍复杂值返回为浮点复杂值的问题。我已尝试将其显式转换为浮点复杂值,例如:

temp1 = <float complex>(cexp(1j * b))

但这并没有什么区别。

有人能帮我修复代码,让temp1这一行不再有黄色,并且是纯C语言。这将允许我在Cython中使用openmp。


1
它是否实际上使用了任何Python API调用(或以其他方式阻塞openmp)?注释着色是提示;如果您单击“+”号并查看C代码,它在做什么?我认为它应该调用一些宏或内联函数,在C文件中进一步定义,只需执行类似于“(x)+(y)*(_CYTHON_COMPLEX_64_I_CONST_HERE)”的操作(如果它不使用Python(加上宏调用以获取Cython complex128的实部和虚部),您应该能够轻松验证。 - abarnert
尽管可能会有点凌乱,因为我记得 Cython 为 C++98、C11、C99 和 C89 的每个宏都编写了替代版本,并通过检查预处理器标志在它们之间进行切换。 - abarnert
另外,您没有指定语言。如果您建立C89,则会使用structs来处理complex64和complex128,这并不像使用Python之类的东西那样慢,但仍然不如内置C复杂类型(在涉及无穷大的情况下可能不正确)。 - abarnert
你使用的是哪个编译器?如果是gcc,那么我怀疑这行代码是否会阻止openmp的使用,因为宏__Pyx_CREAL会扩展到__real__扩展名。 - ead
感谢大家的帮助。 - rtclark
1个回答

1
黄色是由于 __Pyx_CREAL__Pyx_CIMAG 导致的,这不应该是一个问题,但谁知道呢...
为了避免这种情况,你需要避免从 doublefloat 的转换以及相反的转换。
例如:
cimport numpy as cnp

#take the float version (cexpf) instead of double-version (cexp)
cdef extern from "complex.h":
     float complex cexpf(float complex)

#1j maps to double complex, so create a float version
cdef float complex float_1j = 1j

cpdef example():
    cdef float b_float = 2.0                              #use float not double
    cdef cnp.complex64_t temp1 = cexpf(float_1j*b_float)  #everything is float

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接