Cython中的复数

43
什么是使用Cython处理复数的正确方法?我想使用dtype为np.complex128的numpy.ndarray编写一个纯C循环。在Cython中,相关的C类型在 Cython/Includes/numpy/__init__.pxd 中定义。
ctypedef double complex complex128_t

看起来这只是一个简单的C复数。

然而,它很容易出现奇怪的行为。特别是在使用以下定义时

cimport numpy as np
import numpy as np
np.import_array()

cdef extern from "complex.h":
    pass

cdef:
    np.complex128_t varc128 = 1j
    np.float64_t varf64 = 1.
    double complex vardc = 1j
    double vard = 1.

这条线

varc128 = varc128 * varf64

可以用Cython编译,但gcc不能编译生成的C代码(错误为“testcplx.c:663:25: error: two or more data types in declaration specifiers”,可能是由于这行代码 typedef npy_float64 _Complex __pyx_t_npy_float64_complex;)。已经有人报告了这个错误(例如这里),但我没有找到任何好的说明和/或清洁的解决方案。

如果不包含complex.h,就不会出现错误(我猜这是因为没有包括typedef)。

然而,在通过cython -a testcplx.pyx生成的html文件中,这行代码varc128 = varc128 * varf64是黄色的,这意味着它还没有被翻译成纯C。相应的C代码是:

__pyx_t_2 = __Pyx_c_prod_npy_float64(__pyx_t_npy_float64_complex_from_parts(__Pyx_CREAL(__pyx_v_8testcplx_varc128), __Pyx_CIMAG(__pyx_v_8testcplx_varc128)), __pyx_t_npy_float64_complex_from_parts(__pyx_v_8testcplx_varf64, 0));
__pyx_v_8testcplx_varc128 = __pyx_t_double_complex_from_parts(__Pyx_CREAL(__pyx_t_2), __Pyx_CIMAG(__pyx_t_2));

__Pyx_CREAL__Pyx_CIMAG 是橙色的(Python调用)。

有趣的是,这行代码

vardc = vardc * vard

不会产生任何错误,并被翻译为纯C代码(只有__pyx_v_8testcplx_vardc = __Pyx_c_prod(__pyx_v_8testcplx_vardc, __pyx_t_double_complex_from_parts(__pyx_v_8testcplx_vard, 0));),但与第一个非常相似。

我可以通过使用中间变量来避免错误(并且它将被翻译为纯C代码):

vardc = varc128
vard = varf64
varc128 = vardc * vard

或者仅仅通过强制转换(但这并不相当于纯C语言):

vardc = <double complex>varc128 * <double>varf64

发生了什么?编译错误的意义是什么?有没有干净的方法可以避免它?为什么np.complex128_t和np.float64_t的乘法似乎涉及Python调用?

版本

Cython版本0.22(在提问时PyPI中最新版本)和GCC 4.9.2。

存储库

我创建了一个小型存储库,其中包含示例(hg clone https://bitbucket.org/paugier/test_cython_complex)和一个带有3个目标的小型Makefile(make cleanmake buildmake html),因此很容易测试任何东西。


一些想法(尽管我不知道答案):我认为gcc期望typedef _Complex npy_float64 __pyx_t_npy_float64_complex;(请注意,此语句中首先出现复数)。这可以通过编辑.c文件或(可能)您的Cython/Includes/numpy/__init__.pxd来交换顺序来解决。 - DavidW
(更正:编辑Cython/Includes/numpy/init.pxd文件无效)同时,修改.c文件只会得到另一个错误,但这可能是朝着正确方向迈出的一步。 - DavidW
你使用的Cython版本和GCC版本是什么?使用Cython与Numpy文档在结尾处有这样一段话:“在Cython 0.11.2中,np.complex64_tnp.complex128_t无法工作,必须改为写complex或double complex。这在0.11.3中已经修复。Cython 0.11.1及更早版本不支持复数。” - Will
不是这样的:Cython版本0.22(今天在Pypi上最新版本)和GCC 4.9.2...我将编辑问题以提供版本信息。无论如何,感谢您的评论。 - paugier
这不是答案,而是我的评论 - 在我的安装中,如果您反转乘法的顺序,即varc128 = varf64 * varc128,它似乎可以工作,但没有中间变量。您能否确认您的情况是否相同?(显然,这并不能回答更有趣的问题 - 为什么?)。我只是在尝试一些东西时注意到了这个问题,但我完全是新手cython(即今天设置环境!),虽然之前有一些C和Python的经验。 - J Richard Snape
1个回答

21
我能找到的最简单的解决此问题的方法是简单地改变乘法的顺序。
如果在testcplx.pyx中,我更改:
varc128 = varc128 * varf64

varc128 = varf64 * varc128

我从之前的失败情况转变为正确工作的情况。这种情况很有用,因为它允许直接对生成的C代码进行比较。

简短概括

乘法运算的顺序会影响翻译,这意味着在失败的版本中,乘法是通过__pyx_t_npy_float64_complex类型尝试进行的,而在工作的版本中,则是通过__pyx_t_double_complex类型进行的。这反过来引入了无效的typedef行typedef npy_float64 _Complex __pyx_t_npy_float64_complex;

我相当确定这是一个cython的bug(更新:在此处报告)。虽然这是一个非常古老的gcc bug报告,但回复明确指出(说它实际上不是gcc的bug,而是用户代码错误):

typedef R _Complex C;

This is not valid code; you can't use _Complex together with a typedef, only together with "float", "double" or "long double" in one of the forms listed in C99.

他们得出结论,double _Complex 是一个有效的类型说明符,而 ArbitraryType _Complex 不是。这份较新的报告 也有相同类型的回应——尝试在非基本类型上使用 _Complex 超出了规范,GCC 手册 指出 _Complex 只能与 floatdoublelong double 一起使用。

因此,我们可以修改 Cython 生成的 C 代码进行测试:将 typedef npy_float64 _Complex __pyx_t_npy_float64_complex; 替换为 typedef double _Complex __pyx_t_npy_float64_complex; 并验证它确实有效,并且可以使输出代码编译通过。


代码简短的远足

交换乘法顺序只是突出了编译器告诉我们的问题。在第一种情况下,有问题的行是这一行:typedef npy_float64 _Complex __pyx_t_npy_float64_complex; - 它试图将类型npy_float64关键字_Complex分配给类型__pyx_t_npy_float64_complex

float _Complexdouble _Complex是有效的类型,而npy_float64 _Complex不是。为了看到效果,您可以从该行中删除npy_float64,或者将其替换为doublefloat,代码可以编译通过。接下来的问题是为什么会产生那一行...

这似乎是由Cython源代码中的this line生成的。

为什么乘法的顺序会显着改变代码 - 以至于引入了类型__pyx_t_npy_float64_complex,并以一种失败的方式引入?

在失败的实例中,实现乘法的代码将varf64转换为__pyx_t_npy_float64_complex类型,对实部和虚部进行乘法,然后重新组装复数。在工作版本中,它通过__Pyx_c_prod函数直接使用__pyx_t_double_complex类型进行乘积。

我想这就像是Cython代码从它遇到的第一个变量中获得其乘法类型一样简单。在第一个情况下,它看到一个float64,因此生成基于该类型的(无效)C代码,而在第二个情况下,它看到了(double)complex128类型,并基于此进行翻译。这个解释有点含糊不清,如果时间允许,我希望能回到对它的分析...
关于这一点 - 在这里我们看到npy_float64typedefdouble,因此在这种情况下,修复可能包括修改这里的代码以使用double _Complex,其中typenpy_float64,但这超出了SO答案的范围,也没有提供通用解决方案。

C代码差异结果

工作版本

从行“varc128 = varf64 * varc128”创建此C代码

__pyx_v_8testcplx_varc128 = __Pyx_c_prod(__pyx_t_double_complex_from_parts(__pyx_v_8testcplx_varf64, 0), __pyx_v_8testcplx_varc128);

失败版本

从行 varc128 = varc128 * varf64 生成此C代码

__pyx_t_2 = __Pyx_c_prod_npy_float64(__pyx_t_npy_float64_complex_from_parts(__Pyx_CREAL(__pyx_v_8testcplx_varc128), __Pyx_CIMAG(__pyx_v_8testcplx_varc128)), __pyx_t_npy_float64_complex_from_parts(__pyx_v_8testcplx_varf64, 0));
  __pyx_v_8testcplx_varc128 = __pyx_t_double_complex_from_parts(__Pyx_CREAL(__pyx_t_2), __Pyx_CIMAG(__pyx_t_2));

这需要额外导入的东西 - 问题所在是这行代码:typedef npy_float64 _Complex __pyx_t_npy_float64_complex;,它尝试将类型npy_float64 类型_Complex 分配给类型__pyx_t_npy_float64_complex
#if CYTHON_CCOMPLEX
  #ifdef __cplusplus
    typedef ::std::complex< npy_float64 > __pyx_t_npy_float64_complex;
  #else
    typedef npy_float64 _Complex __pyx_t_npy_float64_complex;
  #endif
#else
    typedef struct { npy_float64 real, imag; } __pyx_t_npy_float64_complex;
#endif

/*... loads of other stuff the same ... */

static CYTHON_INLINE __pyx_t_npy_float64_complex __pyx_t_npy_float64_complex_from_parts(npy_float64, npy_float64);

#if CYTHON_CCOMPLEX
    #define __Pyx_c_eq_npy_float64(a, b)   ((a)==(b))
    #define __Pyx_c_sum_npy_float64(a, b)  ((a)+(b))
    #define __Pyx_c_diff_npy_float64(a, b) ((a)-(b))
    #define __Pyx_c_prod_npy_float64(a, b) ((a)*(b))
    #define __Pyx_c_quot_npy_float64(a, b) ((a)/(b))
    #define __Pyx_c_neg_npy_float64(a)     (-(a))
  #ifdef __cplusplus
    #define __Pyx_c_is_zero_npy_float64(z) ((z)==(npy_float64)0)
    #define __Pyx_c_conj_npy_float64(z)    (::std::conj(z))
    #if 1
        #define __Pyx_c_abs_npy_float64(z)     (::std::abs(z))
        #define __Pyx_c_pow_npy_float64(a, b)  (::std::pow(a, b))
    #endif
  #else
    #define __Pyx_c_is_zero_npy_float64(z) ((z)==0)
    #define __Pyx_c_conj_npy_float64(z)    (conj_npy_float64(z))
    #if 1
        #define __Pyx_c_abs_npy_float64(z)     (cabs_npy_float64(z))
        #define __Pyx_c_pow_npy_float64(a, b)  (cpow_npy_float64(a, b))
    #endif
 #endif
#else
    static CYTHON_INLINE int __Pyx_c_eq_npy_float64(__pyx_t_npy_float64_complex, __pyx_t_npy_float64_complex);
    static CYTHON_INLINE __pyx_t_npy_float64_complex __Pyx_c_sum_npy_float64(__pyx_t_npy_float64_complex, __pyx_t_npy_float64_complex);
    static CYTHON_INLINE __pyx_t_npy_float64_complex __Pyx_c_diff_npy_float64(__pyx_t_npy_float64_complex, __pyx_t_npy_float64_complex);
    static CYTHON_INLINE __pyx_t_npy_float64_complex __Pyx_c_prod_npy_float64(__pyx_t_npy_float64_complex, __pyx_t_npy_float64_complex);
    static CYTHON_INLINE __pyx_t_npy_float64_complex __Pyx_c_quot_npy_float64(__pyx_t_npy_float64_complex, __pyx_t_npy_float64_complex);
    static CYTHON_INLINE __pyx_t_npy_float64_complex __Pyx_c_neg_npy_float64(__pyx_t_npy_float64_complex);
    static CYTHON_INLINE int __Pyx_c_is_zero_npy_float64(__pyx_t_npy_float64_complex);
    static CYTHON_INLINE __pyx_t_npy_float64_complex __Pyx_c_conj_npy_float64(__pyx_t_npy_float64_complex);
    #if 1
        static CYTHON_INLINE npy_float64 __Pyx_c_abs_npy_float64(__pyx_t_npy_float64_complex);
        static CYTHON_INLINE __pyx_t_npy_float64_complex __Pyx_c_pow_npy_float64(__pyx_t_npy_float64_complex, __pyx_t_npy_float64_complex);
    #endif
#endif

2
@taleinat 你激励了我创建一个工单 - 引用了一个旧的工单,该工单对同一问题有不同的角度。我还在答案中编辑了一个链接。http://trac.cython.org/ticket/850#ticket - J Richard Snape

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接