针对函数中变量固定的情况,优化Python数学代码

3

我有一个非常长的数学公式(仅为了让您了解:它有293095个字符),实际上将成为Python函数的主体。该函数有15个输入参数,如下:

def math_func(t,X,P,n1,n2,R,r):
    x,y,z = X
    a,b,c = P
    u1,v1,w1 = n1
    u2,v2,w2 = n2
    return <long math formula>

该公式使用简单的数学运算符+ - * ** /和一个函数调用arctan。以下是公式的摘录:
r*((-16*(r**6*t*u1**6 - 6*r**6*u1**5*u2 - 15*r**6*t*u1**4*u2**2 +
 20*r**6*u1**3*u2**3 + 15*r**6*t*u1**2*u2**4 - 6*r**6*u1*u2**5 -
 r**6*t*u2**6 + 3*r**6*t*u1**4*v1**2 - 12*r**6*u1**3*u2*v1**2 -
 18*r**6*t*u1**2*u2**2*v1**2 + 12*r**6*u1*u2**3*v1**2 +
 3*r**6*t*u2**4*v1**2 + 3*r**6*t*u1**2*v1**4 - 6*r**6*u1*u2*v1**4 -
 3*r**6*t*u2**2*v1**4 + r**6*t*v1**6 - 6*r**6*u1**4*v1*v2 -
 24*r**6*t*u1**3*u2*v1*v2 + 36*r**6*u1**2*u2**2*v1*v2 +
 24*r**6*t*u1*u2**3*v1*v2 - 6*r**6*u2**4*v1*v2 -
 12*r**6*u1**2*v1**3*v2 - 24*r**6*t*u1*u2*v1**3*v2 +
 12*r**6*u2**2*v1**3*v2 - 6*r**6*v1**5*v2 - 3*r**6*t*u1**4*v2**2 + ...  

现在的重点是,实际中将批量评估此函数时会针对固定值 P、n1、n2、Rr 进行计算,从而将“自由”变量集减少为仅有四个,理论上具有较少参数的公式应该更快。

因此问题是:我如何在Python中实现此优化?

我知道可以将所有内容放入字符串中,并执行某种类型的 replacecompileeval 操作,就像以下示例:

formula = formula.replace('r','1').replace('R','2')....
code = compile(formula,'formula-name','eval')
math_func = lambda t,x,y,z: eval(code)

如果一些操作(如幂)可以用它们的值替换,那将是很好的,例如18*r**6*t*u1**2*u2**2*v1**2应该在r=u1=u2=v1=1的情况下变为18*t。我认为compile应该这样做,但无论如何我都不确定。 compile实际上执行此优化吗? 我的解决方案加速了计算,但如果我能够压缩它更多,那将是非常好的。注意:最好在标准Python内完成(稍后我可以尝试Cython)。
一般来说,我对以“Pythonic”的方式实现我的目标感兴趣,也许需要一些额外的库:有什么合理好的方法可以做到这一点吗?我的解决方案是一个好的方法吗? 编辑:(为了提供更多背景)
巨大的表达式是圆弧线积分的输出。圆弧在空间中由半径r,两个正交规范向量(如二维版本中的x和y轴)n1 =(u1,v1,w1)n2 =(u2,v2,w2)和中心P =(a,b,c)给出。其余部分是我执行积分的点X =(x,y,z)和函数R的参数。 SympyMaple需要很长时间来计算这个表达式,实际输出来自Mathematica
如果您对公式感到好奇,请看这里(伪伪代码):
G(u) = P + r*(1-u**2)/(1+u**2)*n1 + r*2*u/(1+u**2)*n2
integral of (1-|X-G(t)|^2/R^2)^3 over t

2
看看 sympy - 一个用于符号数学的 Python 包。 - Leon
  1. 鉴于该表达式的巨大规模,我认为您要么做错了什么,要么可以用更简单的形式来描述该表达式(例如,它可能是一个复杂表达式的1000个值的求和?)。如果您能提供您正在使用的完整综合代数表达式,那将非常有帮助。
  2. Python不进行任何优化,因为这些变量的类型可以是任何类型的对象,这些对象可以使用这些运算符执行任何操作,因此无法安全地执行任何优化,编译器甚至不尝试。
- Bakuriu
顺便问一下,我很感兴趣:你是怎么生成那个表达式的?请注意,Python解析器有50个括号嵌套的限制,这可能不足以处理如此庞大的表达式(但这取决于表达式的具体构造方式)。 - Bakuriu
@Leon 我已经尝试过 sympy(实际上我是 sympy 的日常用户),但在这个表达式上它有点慢。虽然它给出了一个很好的输出结果,但很高兴知道你认为这是一个不错的选择 ;) - Alvaro Fuentes
@Bakuriu 好的,但是Python不应该将2**3*x优化为8*x吗?特别是如果我将2**a*x中的a替换为3并编译(我假设,也许是错误的),它应该会进行优化。 - Alvaro Fuentes
显示剩余3条评论
3个回答

2
您可以使用Sympy:
>>> from sympy import symbols
>>> x,y,z,a,b,c,u1,v1,w1,u2,v2,w2,t,r = symbols("x,y,z,a,b,c,u1,v1,w1,u2,v2,w2,t,r")
>>> r=u1=u2=v1=1
>>> a = 18*r**6*t*u1**2*u2**2*v1**2
>>> a
18*t

然后您可以创建一个类似这样的Python函数:
>>> from sympy import lambdify
>>> f = lambdify(t, a)
>>> f(1)
18

而那个f函数确实只是18*t

>>> import dis
>>> dis.dis(f)
  1           0 LOAD_CONST               1 (18)
              3 LOAD_FAST                0 (_Dummy_18)
              6 BINARY_MULTIPLY
              7 RETURN_VALUE

如果你想将生成的代码编译成机器码,可以尝试使用JIT编译器,例如NumbaTheanoParakeet


是的,现在我正在做的是 expr = sympy.parsing.sympy_parser(formula,local_dict={'a':1,'b':0...}) 然后 sympy.lambdify(expr,(t,x,y,z)) 这将给出函数的快速版本。虽然我正在探索是否可能在Python编译中实现。无论如何,如果没有其他答案,我将把这个标记为已接受。 - Alvaro Fuentes

1
这是我解决这个问题的方法:
  1. 将函数编译为AST(抽象语法树)而不是普通的字节码函数-请参阅标准的ast模块了解详情。
  2. 遍历AST,用其固定值替换所有对固定参数的引用。有一些库,例如Macropy可能会有所帮助,但我没有具体的建议。
  3. 再次遍历AST,执行可能启用的任何优化,例如Mult(1, X) => X。您无需担心两个常量之间的操作,因为Python(自2.6以来)已经进行了优化。
  4. 将AST编译为普通函数。调用它,并希望速度增加足够多,以证明所有预优化都是合理的。

请注意,Python永远不会自动优化类似于1*X这样的东西,因为它不能知道运行时X的类型-它可以是实现任意方式的乘法操作的类的实例,因此结果不一定是X。只有您知道所有变量都是普通数字,遵守算术的常规规则,才能使此优化有效。


1

解决这类问题的“正确方法”可以是以下一种或多种:

  1. 找到更高效的公式
  2. 符号化简和缩减项
  3. 使用向量化(例如NumPy)
  4. 使用已经优化的低级库(例如在C或Fortran等语言中隐含地进行强表达式优化,而不是Python,它什么也不做)。

假设暂时无法使用1、3和4方法,并且必须在Python中完成此操作。那么简化和“提升”常见子表达式是您的主要工具。

好消息是,有很多机会。例如,表达式r**6重复了26次。您只需一次性分配r_6 = r ** 6,然后替换每次出现r**6,即可节省25次计算。

当您开始在这里寻找常见表达式时,您会发现它们无处不在。机械化这个过程会很好,对吧?一般来说,这需要完整的表达式解析器(例如来自ast模块),并且是指数级时间优化问题。但是,您的表达式有点特殊。虽然长而且各种各样,但并不特别复杂。它只有少量内部括号分组,因此我们可以采用更快、更脏的方法。
在介绍如何之前,以下是最终代码:
sa = r**6                      # 26 occurrences
sb = u1**2                     # 5 occurrences
sc = u2**2                     # 5 occurrences
sd = v1**2                     # 5 occurrences
se = u1**4                     # 4 occurrences
sf = u2**3                     # 3 occurrences
sg = u1**3                     # 3 occurrences
sh = v1**4                     # 3 occurrences
si = u2**4                     # 3 occurrences
sj = v1**3                     # 3 occurrences
sk = v2**2                     # 1 occurrence
sl = v1**6                     # 1 occurrence
sm = v1**5                     # 1 occurrence
sn = u1**6                     # 1 occurrence
so = u1**5                     # 1 occurrence
sp = u2**6                     # 1 occurrence
sq = u2**5                     # 1 occurrence
sr = 6*sa                      # 6 occurrences
ss = 3*sa                      # 5 occurrences
st = ss*t                      # 5 occurrences
su = 12*sa                     # 4 occurrences
sv = sa*t                      # 3 occurrences
sw = v1*v2                     # 5 occurrences
sx = sj*v2                     # 3 occurrences
sy = 24*sv                     # 3 occurrences
sz = 15*sv                     # 2 occurrences
sA = sr*u1                     # 2 occurrences
sB = sy*u1                     # 2 occurrences
sC = sb*sc                     # 2 occurrences
sD = st*se                     # 2 occurrences

# revised formula
sv*sn - sr*so*u2 - sz*se*sc +
20*sa*sg*sf + sz*sb*si - sA*sq -
sv*sp + sD*sd - su*sg*u2*sd -
18*sv*sC*sd + su*u1*sf*sd +
st*si*sd + st*sb*sh - sA*u2*sh -
st*sc*sh + sv*sl - sr*se*sw -
sy*sg*u2*sw + 36*sa*sC*sw +
sB*sf*sw - sr*si*sw -
su*sb*sx - sB*u2*sx +
su*sc*sx - sr*sm*v2 - sD*sk

那可以避免81次计算。这只是一个粗略的计算,甚至结果可以进一步改善。例如,子表达式sr * swsu * sd也可以预先计算。但是我们将把下一级留给另一天。
请注意,这不包括起始的r *(-16 *(。大部分简化工作可以在表达式的核心上完成,而不是在其外部项上进行。因此,我暂时剥离了它们,一旦计算出共同的核心,就可以将它们添加回来。
你怎么做到的?
f = """
r**6*t*u1**6 - 6*r**6*u1**5*u2 - 15*r**6*t*u1**4*u2**2 +
20*r**6*u1**3*u2**3 + 15*r**6*t*u1**2*u2**4 - 6*r**6*u1*u2**5 -
r**6*t*u2**6 + 3*r**6*t*u1**4*v1**2 - 12*r**6*u1**3*u2*v1**2 -
18*r**6*t*u1**2*u2**2*v1**2 + 12*r**6*u1*u2**3*v1**2 +
3*r**6*t*u2**4*v1**2 + 3*r**6*t*u1**2*v1**4 - 6*r**6*u1*u2*v1**4 -
3*r**6*t*u2**2*v1**4 + r**6*t*v1**6 - 6*r**6*u1**4*v1*v2 -
24*r**6*t*u1**3*u2*v1*v2 + 36*r**6*u1**2*u2**2*v1*v2 +
24*r**6*t*u1*u2**3*v1*v2 - 6*r**6*u2**4*v1*v2 -
12*r**6*u1**2*v1**3*v2 - 24*r**6*t*u1*u2*v1**3*v2 +
12*r**6*u2**2*v1**3*v2 - 6*r**6*v1**5*v2 - 3*r**6*t*u1**4*v2**2
""".strip()


from collections import Counter
import re

expre = re.compile('(?<!\w)\w+\*\*\d+')
multre = re.compile('(?<!\w)\w+\*\w+')

expr_saved = 0
stmts = []


secache = {}
seindex = 0
def subexpr(e):
    global seindex
    cached = secache.get(e)
    if cached:
        return cached
    base = ord('a') if seindex < 26 else ord('A') - 26
    name = 's' + chr(seindex + base)
    seindex += 1
    secache[e] = name
    return name

def hoist(e, flat, c):
    """
    Hoist the expression e into name defined by flat.
    c is the count of how many times seen in incoming
    formula.
    """
    global expr_saved

    assign = "{} = {}".format(flat, e)
    s = "{:30} # {} occurrence{}".format(assign, c, '' if c == 1 else 's')
    stmts.append(s)
    print "{} needless computations quashed with {}".format(c-1, flat)
    expr_saved += c - 1

def common_exp(form):
    """
    Replace ALL exponentiation operations with a hoisted
    sub-expression.
    """
    # find the exponentiation operations
    exponents = re.findall(expre, form)

    # find and count exponentiation operations
    expcount = Counter(re.findall(expre, form))

    # for each exponentiation, create a hoisted sub-expression
    for e, c in expcount.most_common():
        hoist(e, subexpr(e), c)

    # replace all exponentiation operations with their sub-expressions
    form = re.sub(expre, lambda x: subexpr(x.group(0)), form)
    return form


def common_mult(f):
    """
    Replace multiplication operations with a hoisted
    sub-expression if they occur > 1 time. Also, only
    replaces one sub-expression at a time (the most common)
    because it may affect further expressions
    """
    mults = re.findall(multre, f)
    for e, c in Counter(mults).most_common():
        # unlike exponents, only replace if >1 occurrence
        if c == 1:
            return f
        # occurs >1 time, so hoist
        hoist(e, subexpr(e), c)
        # replace in loop and return
        return re.sub('(?<!\w)' + re.escape(e), subexpr(e), f)
        # return f.replace(e, flat(e))
    return f

# fix all exponents
form = common_exp(f)

# fix selected multiplies
prev = form
while True:
    form = common_mult(form)
    if form == prev:
        # have converged; no more replacements possible
        break
    prev = form

print "--"
mults = re.split(r'\s*[+-]\s*', form)
smults = ['*'.join(sorted(terms.split('*'))) for terms in mults]
print smults

# print the hoisted statements and the revised expression
print '\n'.join(stmts)
print
print "# revised formula"
print form

使用正则表达式进行解析是一个风险很高的事情。这个过程容易出错,让人感到悲伤和后悔。我通过提升一些不严格需要的指数,并将随机值插入到前后公式中来防范不良结果。如果这是生产代码,我建议采用“转向C”的策略。但如果你不能...

有趣的回答。我认为你提出的所有观点都是有道理的。这段代码只是一个原型,可能会成为一个“真实世界”的实现,肯定会使用 C++,也许还会加上 Python 接口。我正在寻找一些容易优化代码的方法,以便我可以玩弄它并决定它是否值得最终投入时间和精力 ;) - Alvaro Fuentes
遗憾的是,在Python中优化表达式没有特别好或容易的方法。它甚至不会做最简单的“窥孔优化”,这在C和Fortran 40-50年前就已经成功了,更不用说公共子表达式消除、强度降低、管道重新排序等了。 - Jonathan Eunice

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接