我能否使用Cython在Python中覆盖C++虚函数?

14

我有一个带有虚函数的C++类:

//C++
class A
{

    public:
        A() {};
        virtual int override_me(int a) {return 2*a;};
        int calculate(int a) { return this->override_me(a) ;}

};

我的目标是使用Cython将这个类暴露给Python,在Python中继承这个类并调用正确的重写方法:

#python:
class B(PyA):
   def override_me(self, a):
       return 5*a
b = B()
b.calculate(1)  # should return 5 instead of 2
有没有办法做到这一点?现在我在考虑,如果我们也可以在Cython中覆盖虚拟方法(在pyx文件中),那就太好了,但是让用户能够在纯Python中实现这一点更为重要。 编辑:如果有帮助的话,一个解决方案可以使用此处提供的伪代码:http://docs.cython.org/src/userguide/pyrex_differences.html#cpdef-functions 但是会有两个问题:
  • 我不知道如何在Cython中编写这个伪代码
  • 可能有更好的方法

当然可以。它会返回2。您需要pyx源码吗(这是明显错误的,但我还没有找到解决方法)? - ascobol
不,我认为我不能帮忙。我认为boost.python支持这个。 - Sven Marnach
确实,我几年前使用了boost.python完成了它。现在我想尝试boost.python的替代方案(编译时间太长,生成的模块太大等问题)。如果Cython能够解决这个问题,我认为其余的工作将会很顺利。 - ascobol
2
我认为这不是直接支持的,但在邮件列表中提到了一个解决方法。 - Fred Foo
另一个解决方法是使用策略模式或类似的东西,而不是方法重载。 - katzenversteher
2个回答

11

解决方法有些复杂,但是可以实现。这里有一个完全可用的示例:https://bitbucket.org/chadrik/cy-cxxfwk/overview

以下是该技术的概述:

创建一个特定的class A子类,其目的是与cython扩展程序进行交互:

// created by cython when providing 'public api' keywords:
#include "mycymodule_api.h"

class CyABase : public A
{
public:
  PyObject *m_obj;

  CyABase(PyObject *obj);
  virtual ~CyABase();
  virtual int override_me(int a);
};
构造函数需要一个Python对象,这个Python对象是我们Cython扩展的实例:
CyABase::CyABase(PyObject *obj) :
  m_obj(obj)
{
  // provided by "mycymodule_api.h"
  if (import_mycymodule()) {
  } else {
    Py_XINCREF(this->m_obj);
  }
}

CyABase::~CyABase()
{
  Py_XDECREF(this->m_obj);
}

在Cython中创建此子类的扩展,以标准方式实现所有非虚拟方法。

cdef class A:
    cdef CyABase* thisptr
    def __init__(self):
        self.thisptr = new CyABase(
            <cpy_ref.PyObject*>self)

    #------- non-virutal methods --------
    def calculate(self):
        return self.thisptr.calculate()

将虚拟和纯虚拟方法创建为 public api 函数,该函数接受扩展实例、方法参数和错误指针作为参数:

cdef public api int cy_call_override_me(object self, int a, int *error):
    try:
        func = self.override_me
    except AttributeError:
        error[0] = 1
        # not sure what to do about return value here...
    else:
        error[0] = 0
        return func(a)

在你的C++中间层中使用这些函数:

int
CyABase::override_me(int a)
{
  if (this->m_obj) {
    int error;
    // call a virtual overload, if it exists
    int result = cy_call_override_me(this->m_obj, a, &error);
    if (error)
      // call parent method
      result = A::override_me(i);
    return result;
  }
  // throw error?
  return 0;
}

我很快就将我的代码适应到你的示例中,所以可能会有错误。请查看存储库中的完整示例,它应该能回答你大部分的问题。随意fork它并添加自己的实验,它远未完成!


这是一个很好的开始,非常感谢。但是 override_me() 方法能否被 Python 脚本调用?如果在 C++ 中此方法不是纯虚函数,则应该可以从 Python 部分调用它。 - ascobol
我已经编写了一份指南 https://monadical.com/posts/virtual-classes-in-cython.html - jdcaballerov

9

太好了!

虽然不完整但足够使用。我已经能够为自己的目的做到这一点。结合上面链接的来源和这篇文章,这并不容易,因为我是Cython的初学者,但我确认这是我在www上找到的唯一方法。

非常感谢你们。

很抱歉我没有太多时间进入文本细节,但这里是我的文件(可能有助于获得另一个将所有这些内容组合在一起的视角)

setup.py :

from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext

setup(
    cmdclass = {'build_ext': build_ext},
    ext_modules = [
    Extension("elps", 
              sources=["elps.pyx", "src/ITestClass.cpp"],
              libraries=["elp"],
              language="c++",
              )
    ]
)

TestClass :

#ifndef TESTCLASS_H_
#define TESTCLASS_H_


namespace elps {

class TestClass {

public:
    TestClass(){};
    virtual ~TestClass(){};

    int getA() { return this->a; };
    virtual int override_me() { return 2; };
    int calculate(int a) { return a * this->override_me(); }

private:
    int a;

};

} /* namespace elps */
#endif /* TESTCLASS_H_ */

ITestClass.h:

#ifndef ITESTCLASS_H_
#define ITESTCLASS_H_

// Created by Cython when providing 'public api' keywords
#include "../elps_api.h"

#include "../../inc/TestClass.h"

namespace elps {

class ITestClass : public TestClass {
public:
    PyObject *m_obj;

    ITestClass(PyObject *obj);
    virtual ~ITestClass();
    virtual int override_me();
};

} /* namespace elps */
#endif /* ITESTCLASS_H_ */

ITestClass.cpp :

#include "ITestClass.h"

namespace elps {

ITestClass::ITestClass(PyObject *obj): m_obj(obj) {
    // Provided by "elps_api.h"
    if (import_elps()) {
    } else {
        Py_XINCREF(this->m_obj);
    }
}

ITestClass::~ITestClass() {
    Py_XDECREF(this->m_obj);
}

int ITestClass::override_me()
{
    if (this->m_obj) {
        int error;
        // Call a virtual overload, if it exists
        int result = cy_call_func(this->m_obj, (char*)"override_me", &error);
        if (error)
            // Call parent method
            result = TestClass::override_me();
        return result;
    }
    // Throw error ?
    return 0;
}

} /* namespace elps */

编辑2:关于纯虚方法的说明(似乎是一个相当经常出现的问题)。如上所示,以这种特定方式,“TestClass :: override_me()”不能是纯虚的,因为如果在Python的扩展类中没有重写该方法,则必须能够调用它(即:不落入“错误”/“未找到覆盖”部分的“ITestClass :: override_me()”体中)。

扩展:elps.pyx:

cimport cpython.ref as cpy_ref

cdef extern from "src/ITestClass.h" namespace "elps" :
    cdef cppclass ITestClass:
        ITestClass(cpy_ref.PyObject *obj)
        int getA()
        int override_me()
        int calculate(int a)

cdef class PyTestClass:
    cdef ITestClass* thisptr

    def __cinit__(self):
       ##print "in TestClass: allocating thisptr"
       self.thisptr = new ITestClass(<cpy_ref.PyObject*>self)
    def __dealloc__(self):
       if self.thisptr:
           ##print "in TestClass: deallocating thisptr"
           del self.thisptr

    def getA(self):
       return self.thisptr.getA()

#    def override_me(self):
#        return self.thisptr.override_me()

    cpdef int calculate(self, int a):
        return self.thisptr.calculate(a) ;


cdef public api int cy_call_func(object self, char* method, int *error):
    try:
        func = getattr(self, method);
    except AttributeError:
        error[0] = 1
    else:
        error[0] = 0
        return func()

最后,Python 调用:
from elps import PyTestClass as TC;

a = TC(); 
print a.calculate(1);

class B(TC):
#   pass
    def override_me(self):
        return 5

b = B()
print b.calculate(1)

这应该能让前面链接的工作更加直接地与我们正在讨论的问题相关。另一方面,上述代码可以通过使用“hasattr”而不是try/catch块进行优化:
cdef public api int cy_call_func_int_fast(object self, char* method, bint *error):
    if (hasattr(self, method)):
        error[0] = 0
        return getattr(self, method)();
    else:
        error[0] = 1

当然,上述代码只有在我们不覆盖“override_me”方法的情况下才会有所区别。

1
对于那些想尝试这个示例的人,需要注意:TestClass()和~TestClass()的实现是缺失的。这将导致类似于“ImportError: ./elps.so: undefined symbol: _ZTIN4elps9TestClassE”的错误。只需添加一个空的内联实现即可。 - ascobol
您的解决方案中是否有一种方法可以将虚拟方法(如override_me())暴露给Python侧? - ascobol
只要你更改名称,你应该能够使用以下代码:def call_override_me(self): return self.thisptr.override_me() - Gauthier Boaglio
我更喜欢保持相同的名称。但是现在我有一个想法:cy_call_func_int_fast可以检查override_me是否已被覆盖。它需要比较类中的方法override_me和实例化对象中的方法,例如if PyTestClass.override_me != self.__class__.override_me。也许这样可以行得通... - ascobol
从我的测试来看,它似乎可以工作。在cy_call_func中,不要寻找属性"override_me",而是执行检查if self.__class__.override_me == PyA.override_me。如果为真,则将错误设置为1(我们必须使用C++版本)。如果为假,则使用Python版本,因为它已经被重载。这样,您就可以取消注释您的pyx文件中的def override_me(self) - ascobol
显示剩余2条评论

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接