如何在SymPy中展开矩阵表达式?

12
在SymPy中,我正在尝试执行矩阵乘法并在此后进行展开。然而,SymPy似乎不支持矩阵表达式的展开。例如,以下是矩阵的四阶Runge-Kutta(RK4):
from sympy import init_session
init_session()
from sympy import *

A = MatrixSymbol('A', 3, 3)
x = MatrixSymbol('x', 3, 1)
dt = symbols('dt')

k1 = A*x
k2 = A*(x + S(1)/2*k1*dt)
k3 = A*(x + S(1)/2*k2*dt)
k4 = A*(x + k3*dt)
final = dt*S(1)/6*(k1 + 2*k2 + 2*k3 + k4)
final.expand()

产生结果的

Traceback (most recent call last)
<ipython-input-38-b3ff67883c61> in <module>()
     12 final = dt*1/6*(k1+2*k2+2*k3+k4)
     13 
---> 14 final.expand()

AttributeError: 'MatMul' object has no attribute 'expand'

我希望这个表达式可以像标量变量一样扩展:

A,x,dt = symbols('A x dt')
k1 = A*x
k2 = A*(x+k1*dt*S(1)/2)
k3 = A*(x+k2*dt*S(1)/2)
k4 = A*(x+k3*dt)
final = x+dt*(S(1)/6)*(k1+k2+k3+k4)
collect(expand((final)),x)

结果为:

x*(A**4*dt**4/24 + A**3*dt**3/8 + A**2*dt**2/3 + 2*A*dt/3 + 1)

是否可能对矩阵表达式进行类似的改变?

nicoguaro的答案消除了错误,但将整个表达式扩展为一个矩阵。正如标量示例所示,这不是我要找的。


1
应该可以工作。我打开了https://github.com/sympy/sympy/issues/10360。 - asmeurer
2个回答

3

Matrix(final)会创建一个显式矩阵,其中包含您的个人方程式。将它们留在矩阵中可能很方便,这样可以对所有条目执行相同的操作。要实现这一点,请将您想要进行的操作定义为applyfunc的函数参数:

>>> ex = Matrix(final)
>>> ex = ex.applyfunc(expand)
>>> ex = ex.applyfunc(lambda i: collect(i, dt))
...

为了节约打印成本,我使用带有紧凑符号条目的矩阵来计算,并在简化后的矩阵上运行来进行如下操作:
>>> A = Matrix(3, 3, var('a:3:3'))
>>> x = Matrix(3, 1, var('x:3'))
>>> dt = symbols('dt')
>>> k1 = A*x
>>> k2 = A*(x + S(1)/2*k1*dt)
>>> k3 = A*(x + S(1)/2*k2*dt)
>>> k4 = A*(x + k3*dt)
>>> final = dt*S(1)/6*(k1 + 2*k2 + 2*k3 + k4)
>>> eqi = Matrix(final)
>>> print(cse(eqi.applyfunc(simplify)))
([(x3, 4*x0), (x4, a01*x1), (x5, a02*x2), (x6, dt*(a00*x0 + x4 + x5) +
2*x0), (x7, a00*x6), (x8, a11*x1), (x9, a12*x2), (x10, dt*(a10*x0 + x8
+ x9) + 2*x1), (x11, a01*x10), (x12, a21*x1), (x13, a22*x2), (x14,
dt*(a20*x0 + x12 + x13) + 2*x2), (x15, a02*x14), (x16, dt*(x11 + x15 +
x7) + x3), (x17, a00*x16), (x18, 4*x1), (x19, a10*x6), (x20, a11*x10),
(x21, a12*x14), (x22, dt*(x19 + x20 + x21) + x18), (x23, a01*x22),
(x24, 4*x2), (x25, a20*x6), (x26, a21*x10), (x27, a22*x14), (x28,
dt*(x25 + x26 + x27) + x24), (x29, a02*x28), (x30, dt*(x17 + x23 +
x29) + x3), (x31, a10*x16), (x32, a11*x22), (x33, a12*x28), (x34,
dt*(x31 + x32 + x33) + x18), (x35, a20*x16), (x36, a21*x22), (x37,
a22*x28), (x38, dt*(x35 + x36 + x37) + x24), (x39, dt/24)], [Matrix([
[   x39*(a00*x3 + a00*x30 + a01*x34 + a02*x38 + 4*x11 + 4*x15 + 2*x17
+ 2*x23 + 2*x29 + 4*x4 + 4*x5 + 4*x7)], [  x39*(a10*x3 + a10*x30 +
a11*x34 + a12*x38 + 4*x19 + 4*x20 + 4*x21 + 2*x31 + 2*x32 + 2*x33 +
4*x8 + 4*x9)], [x39*(a20*x3 + a20*x30 + a21*x34 + a22*x38 + 4*x12 +
4*x13 + 4*x25 + 4*x26 + 4*x27 + 2*x35 + 2*x36 + 2*x37)]])])

还有对非交换表达式的有限支持,可以帮助解决这种情况:

>>> A, x = symbols("A x", commutative=False)
>>> dt = symbols('dt')
>>> k1 = A*x
>>> k2 = A*(x + S(1)/2*k1*dt)
>>> k3 = A*(x + S(1)/2*k2*dt)
>>> k4 = A*(x + k3*dt)
>>> final = dt*S(1)/6*(k1 + 2*k2 + 2*k3 + k4)
>>> final.expand()
dt**4*A**4*x/24 + dt**3*A**3*x/6 + dt**2*A**2*x/2 + dt*A*x
>>> factor(_)
dt*A*(dt**3*A**3/24 + dt**2*A**2/6 + dt*A/2 + 1)*x

但并非所有简化程序都支持nc(这是已知的问题):

>>> collect(final,x)
Traceback (most recent call last):
...
AttributeError: Can not collect noncommutative symbol

这不是期望的结果。期望的结果是 (A**4*dt**4/24 + A**3*dt**3/6 + A**2*dt**2/2 + A*dt + I)*x - Lutz Lehmann
我明白 - 你想要一个MatExpr来允许标量的扩展和收集。 - smichr
或者,如果可能的话,将 Ax 声明为非交换标量。 - Lutz Lehmann

0

我认为你可以扩展矩阵表达式。但是你拥有的不是一个矩阵,而是两个符号矩阵(Matsymbols)的乘积。如果你将你的表达式转换成一个矩阵,你就可以得到你想要的扩展。请参见下面的额外行。

from sympy import init_session
init_session()
from sympy import *

A = MatrixSymbol('A', 3, 3)
x = MatrixSymbol('x', 3, 1)
dt = symbols('dt')

k1 = A*x
k2 = A*(x + S(1)/2*k1*dt)
k3 = A*(x + S(1)/2*k2*dt)
k4 = A*(x + k3*dt)
final = dt*S(1)/6*(k1 + k2 + k3 + k4)
Matrix(final).expand()

谢谢你的回答!虽然这不会出现错误,但这并不是我要找的,我希望表达式能像标量变量一样简化、扩展和收集(我已经在问题中添加了标量变量)。 - D.Thomas

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接