从Sympy生成优化的Octave代码

3
我有一些包含sin(q)、cos(q)以及它们的和/积的大型矩阵需要导出。Sympy可以计算并将其导出到Octave,这非常棒!但是,由于这些都是大型矩阵,我需要某种形式的cse甚至更好的专门优化。
我找到了一个关于带有cse的C代码的很好的教程。所以我试图自己进行移植,但我在打印机类中遇到了一些细节问题。我认为这是无限递归导致的RecursionError: maximum recursion depth exceeded
我的问题是:是否有示例展示了sympy-octave代码生成和优化如何结合使用?或者有没有人能够帮助我运行附加的mwe?
import sympy as sp
t = sp.symbols('t')

from sympy.printing.octave import OctaveCodePrinter
from sympy.printing.octave import Assignment
class matlabMatrixPrinter(OctaveCodePrinter):

    def _print_ImmutableDenseMatrix(self, expr):
        sub_exprs, simplified = sp.cse(expr)
        lines = []
        for var, sub_expr in sub_exprs:
            lines.append( self._print(Assignment(var, sub_expr)))
        M = sp.MatrixSymbol('M', *expr.shape)
        return '\n'.join(lines) + '\n' + self._print(Assignment(M, expr))

tmp = sp.sin(t)+sp.sin(t)**2
tmp = sp.ImmutableDenseMatrix((1,1,tmp))
se, ex = sp.cse(tmp)
print((ex,se))
print('\n')
#tmp = sp.Matrix([2*sp.sin(t),sp.sin(t)])
p = matlabMatrixPrinter()
print(p.doprint(tmp))

编辑:我现在明白了,return语句中的第二个赋值也运行了函数_print_ImmutableDenseMatrix,因此这最终成为了递归。我不知道为什么在教程中C代码没有问题,但在这里会运行递归。似乎只是简化的表达式本身不能调用self._print函数的问题。也许有人知道这些打印机是如何打印矩阵和这个单一赋值的吗?!


我已经开始阅读这个笔记本了,所以让我们试一试吧。 - Yakov Dan
1个回答

0

经过许多实验,我感觉自己对codePrinter的意图工作流程只有一点点理解。然而,我编写了一个子类,完全按照我的意图执行(请小心,这可能只适用于矩阵!)。

也许这对某人有用!对我来说,这绝对证明了sympy作为一个可行的工具,否则成千上万次的sin评估将是不可行的代码。

我仍然非常希望得到别人的评论和想法,那些知道如何正确实现这些功能的人!

import sympy as sp
t = sp.symbols('t')
from sympy.printing.octave import OctaveCodePrinter
from sympy.printing.octave import Assignment
class matlabMatrixPrinter(OctaveCodePrinter):
    def print2(self,expr_list,names=None):
        sub_exprs, simplified = sp.cse(expr_list)
        lines = []
        for var, sub_expr in sub_exprs:
            lines.append(self._print(Assignment(var, sub_expr)))
        lines.append('')
        for k,expr in enumerate(simplified):
            if names:
                M = sp.MatrixSymbol(names[k],*expr.shape)
            else:
                M = sp.MatrixSymbol('M{k}'.format(k=k), *expr.shape)
            lines.append(self._print(Assignment(M,expr)))
        result = ''
        return '\n'.join(lines)

tmp = sp.Matrix([sp.sin(t)+sp.sin(t)**2 ])
tmp2 = sp.Matrix([sp.sin(t),sp.cos(t),2*sp.sin(t),sp.cos(t)**2])

p = matlabMatrixPrinter()
#print(p.print2([tmp,tmp2]))
print(p.print2([tmp,tmp2],['scalar_matrix','matrix']));

这将会得到预期的输出:

x0 = sin(t);
x1 = cos(t);
scalar_matrix = x0.^2 + x0;
matrix = [x0; x1; 2*x0; x1.^2];

如上所述:使用需谨慎 :)

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接