为大量输入值快速评估数学表达式(函数)

17
以下问题:

及其相应的答案让我思考如何高效地解析一个单一的数学表达式(通常类似于此答案https://dev59.com/K3RB5IYBdhLWcg3wiHv7#594294)由一个(更或少可信的)用户给出的20k到30k个输入值来自数据库。我实现了一个快速而简单的基准测试,以便比较不同的解决方案。

# Runs with Python 3(.4)
import pprint
import time

# This is what I have
userinput_function = '5*(1-(x*0.1))' # String - numbers should be handled as floats
demo_len = 20000 # Parameter for benchmark (20k to 30k in real life)
print_results = False

# Some database, represented by an array of dicts (simplified for this example)

database_xy = []
for a in range(1, demo_len, 1):
    database_xy.append({
        'x':float(a),
        'y_eval':0,
        'y_sympya':0,
        'y_sympyb':0,
        'y_sympyc':0,
        'y_aevala':0,
        'y_aevalb':0,
        'y_aevalc':0,
        'y_numexpr': 0,
        'y_simpleeval':0
        })

# 解决方案 #1:eval [是的,完全不安全]

time_start = time.time()
func = eval("lambda x: " + userinput_function)
for item in database_xy:
    item['y_eval'] = func(item['x'])
time_end = time.time()
if print_results:
    pprint.pprint(database_xy)
print('1 eval: ' + str(round(time_end - time_start, 4)) + ' seconds')

# 解决方案 #2a: sympy - evalf (http://www.sympy.org)

import sympy
time_start = time.time()
x = sympy.symbols('x')
sympy_function = sympy.sympify(userinput_function)
for item in database_xy:
    item['y_sympya'] = float(sympy_function.evalf(subs={x:item['x']}))
time_end = time.time()
if print_results:
    pprint.pprint(database_xy)
print('2a sympy: ' + str(round(time_end - time_start, 4)) + ' seconds')

# 解决方案 #2b: Sympy - lambdify (http://www.sympy.org)

from sympy.utilities.lambdify import lambdify
import sympy
import numpy
time_start = time.time()
sympy_functionb = sympy.sympify(userinput_function)
func = lambdify(x, sympy_functionb, 'numpy') # returns a numpy-ready function
xx = numpy.zeros(len(database_xy))
for index, item in enumerate(database_xy):
    xx[index] = item['x']
yy = func(xx)
for index, item in enumerate(database_xy):
    item['y_sympyb'] = yy[index]
time_end = time.time()
if print_results:
    pprint.pprint(database_xy)
print('2b sympy: ' + str(round(time_end - time_start, 4)) + ' seconds')

# 解决方案 #2c: 使用sympy - lambdify和numexpr [以及numpy] (http://www.sympy.org)

from sympy.utilities.lambdify import lambdify
import sympy
import numpy
import numexpr
time_start = time.time()
sympy_functionb = sympy.sympify(userinput_function)
func = lambdify(x, sympy_functionb, 'numexpr') # returns a numpy-ready function
xx = numpy.zeros(len(database_xy))
for index, item in enumerate(database_xy):
    xx[index] = item['x']
yy = func(xx)
for index, item in enumerate(database_xy):
    item['y_sympyc'] = yy[index]
time_end = time.time()
if print_results:
    pprint.pprint(database_xy)
print('2c sympy: ' + str(round(time_end - time_start, 4)) + ' seconds')

# 解决方案 #3a:asteval [基于ast] - 带有字符串魔法(http://newville.github.io/asteval/index.html

from asteval import Interpreter
aevala = Interpreter()
time_start = time.time()
aevala('def func(x):\n\treturn ' + userinput_function)
for item in database_xy:
    item['y_aevala'] = aevala('func(' + str(item['x']) + ')')
time_end = time.time()
if print_results:
    pprint.pprint(database_xy)
print('3a aeval: ' + str(round(time_end - time_start, 4)) + ' seconds')

# 解决方案 #3b (M Newville):asteval [基于ast] - 解析和运行(http://newville.github.io/asteval/index.html)

from asteval import Interpreter
aevalb = Interpreter()
time_start = time.time()
exprb = aevalb.parse(userinput_function)
for item in database_xy:
    aevalb.symtable['x'] = item['x']
    item['y_aevalb'] = aevalb.run(exprb)
time_end = time.time()
print('3b aeval: ' + str(round(time_end - time_start, 4)) + ' seconds')

# 解决方案 #3c (M Newville): asteval [基于ast] - 使用numpy解析和运行 (http://newville.github.io/asteval/index.html)

from asteval import Interpreter
import numpy
aevalc = Interpreter()
time_start = time.time()
exprc = aevalc.parse(userinput_function)
x = numpy.array([item['x'] for item in database_xy])
aevalc.symtable['x'] = x
y = aevalc.run(exprc)
for index, item in enumerate(database_xy):
    item['y_aevalc'] = y[index]
time_end = time.time()
print('3c aeval: ' + str(round(time_end - time_start, 4)) + ' seconds')

# 解决方案 #4: simpleeval [基于ast] (https://github.com/danthedeckie/simpleeval)

from simpleeval import simple_eval
time_start = time.time()
for item in database_xy:
    item['y_simpleeval'] = simple_eval(userinput_function, names={'x': item['x']})
time_end = time.time()
if print_results:
    pprint.pprint(database_xy)
print('4 simpleeval: ' + str(round(time_end - time_start, 4)) + ' seconds')

# 解决方案 #5 numexpr [和 numpy] (https://github.com/pydata/numexpr)

import numpy
import numexpr
time_start = time.time()
x = numpy.zeros(len(database_xy))
for index, item in enumerate(database_xy):
    x[index] = item['x']
y = numexpr.evaluate(userinput_function)
for index, item in enumerate(database_xy):
    item['y_numexpr'] = y[index]
time_end = time.time()
if print_results:
    pprint.pprint(database_xy)
print('5 numexpr: ' + str(round(time_end - time_start, 4)) + ' seconds')

在我的旧测试机上(Python 3.4,Linux 3.11 x86_64,双核,1.8GHz),我得到了以下结果:
1 eval: 0.0185 seconds
2a sympy: 10.671 seconds
2b sympy: 0.0315 seconds
2c sympy: 0.0348 seconds
3a aeval: 2.8368 seconds
3b aeval: 0.5827 seconds
3c aeval: 0.0246 seconds
4 simpleeval: 1.2363 seconds
5 numexpr: 0.0312 seconds

突出的是eval的不可思议的速度,但我不想在实际生活中使用它。第二好的解决方案似乎是依赖于numpynumexpr,虽然这不是硬性要求,但我想避免这种依赖。下一个最好的选择是围绕ast构建的simpleeval。另一个基于ast的解决方案aeval的问题在于,我必须先将每个单独的浮点输入值转换为字符串,而我找不到解决方法。最初我最喜欢sympy,因为它提供了最灵活和显然最安全的解决方案,但它最终以惊人的距离落后于倒数第二个解决方案。 更新1:有一种更快的方法可以使用sympy。请参见2b解决方案。它几乎和numexpr一样好,尽管我不确定sympy是否真正在内部使用它。 更新2:现在sympy的实现使用sympify而不是simplify(由其首席开发人员asmeurer推荐-感谢)。除非明确要求使用numexpr(参见解决方案2c),否则不会使用它。我还添加了两个基于asteval的显着更快的解决方案(感谢M Newville)。

我有哪些选项可以进一步加快任何相对更安全的解决方案?例如,是否可以直接使用ast进行其他安全(或相对安全)的方法?


使用 ast 解析用户表达式,检查它是否只包含白名单中允许的操作(最好要非常严格),然后再使用 eval/compile?(但这并不能防止拒绝服务攻击。) - Ry-
lambdify 不会使用 numexpr,除非你设置 modules='numexpr' - asmeurer
@asmeurer 谢谢,我相应地添加了一个新的比较方案。 - s-m-e
sympify() 函数几乎和 eval() 一样不安全。你能为它们添加一个注释吗? - user
1
自从这个问题被发布以来,我已经基于pyparsing开发了plusminus包,提供了一个安全、可嵌入、可扩展的算术解析器/求值器,支持包括许多Unicode数学运算符在内的表达式,|a-b|表示abs(a-b),以及对应于您预编译表达式的延迟求值表达式。plusminus有一个在线的、面向互联网开放的演示,您可以尝试一下。 - PaulMcG
显示剩余8条评论
5个回答

3
我过去使用过C++ ExprTK 库,并且取得了巨大的成功。这里 是一个速度测试,与其他C++解析器(例如Muparser、MathExpr、ATMSP等)相比,ExprTK排名第一。
有一个Python封装程序,称为cexprtk,我已经使用过,并发现其非常快速。您可以编译数学表达式一次,然后根据需要多次评估序列化表达式。以下是使用cexprtkuserinput_function的简单示例代码:
import cexprtk
import time

userinput_function = '5*(1-(x*0.1))' # String - numbers should be handled as floats
demo_len = 20000 # Parameter for benchmark (20k to 30k in real life)

time_start = time.time()
x = 1

st = cexprtk.Symbol_Table({"x":x}, add_constants = True) # Setup the symbol table
Expr = cexprtk.Expression(userinput_function, st) # Apply the symbol table to the userinput_function

for x in range(0,demo_len,1):
    st.variables['x'] = x # Update the symbol table with the new x value
    Expr() # evaluate expression
time_end = time.time()

print('1 cexprtk: ' + str(round(time_end - time_start, 4)) + ' seconds')

在我的机器上(Linux,双核,2.5GHz),对于演示长度为20000,这需要完成0.0202秒。
对于演示长度为2,000,000,cexprtk需要1.23秒完成。

你是否将结果与原始的C++函数进行比较?例如,“double foo(double x) { return 5*(1-(x*0.1)); }”? - macomphy

2

由于您询问了asteval,这里有一种使用它并获得更快结果的方法:

aeval = Interpreter()
time_start = time.time()
expr = aeval.parse(userinput_function)
for item in database_xy:
    aeval.symtable['x'] = item['x']
    item['y_aeval'] = aeval.run(expr)
time_end = time.time()

也就是说,您可以先解析(“预编译”)用户输入的函数,然后将每个新值插入符号表并使用Interpreter.run()来评估该值的已编译表达式。按照您的规模,我认为这将使您接近0.5秒。

如果您愿意使用numpy,则可以采用混合方案:

aeval = Interpreter()
time_start = time.time()
expr = aeval.parse(userinput_function)
x = numpy.array([item['x'] for item in database_xy])
aeval.symtable['x'] = x
y = aeval.run(expr)
time_end = time.time()

应该会快得多,并且在运行时间上与使用numexpr相当。

非常感谢。我将你的代码(加上次数)添加到解决方案列表中,参见3b和3c。你的第二段代码实际上执行得几乎和空白的eval语句一样好。 - s-m-e

2

CPython(以及pypy)在执行函数时使用非常简单的堆栈语言,使用ast模块编写字节码相当容易。

import sys
PY3 = sys.version_info.major > 2
import ast
from ast import parse
import types
from dis import opmap

ops = {
    ast.Mult: opmap['BINARY_MULTIPLY'],
    ast.Add: opmap['BINARY_ADD'],
    ast.Sub: opmap['BINARY_SUBTRACT'],
    ast.Div: opmap['BINARY_TRUE_DIVIDE'],
    ast.Pow: opmap['BINARY_POWER'],
}
LOAD_CONST = opmap['LOAD_CONST']
RETURN_VALUE = opmap['RETURN_VALUE']
LOAD_FAST = opmap['LOAD_FAST']
def process(consts, bytecode, p, stackSize=0):
    if isinstance(p, ast.Expr):
        return process(consts, bytecode, p.value, stackSize)
    if isinstance(p, ast.BinOp):
        szl = process(consts, bytecode, p.left, stackSize)
        szr = process(consts, bytecode, p.right, stackSize)
        if type(p.op) in ops:
            bytecode.append(ops[type(p.op)])
        else:
            print(p.op)
            raise Exception("unspported opcode")
        return max(szl, szr) + stackSize + 1
    if isinstance(p, ast.Num):
        if p.n not in consts:
            consts.append(p.n)
        idx = consts.index(p.n)
        bytecode.append(LOAD_CONST)
        bytecode.append(idx % 256)
        bytecode.append(idx // 256)
        return stackSize + 1
    if isinstance(p, ast.Name):
        bytecode.append(LOAD_FAST)
        bytecode.append(0)
        bytecode.append(0)
        return stackSize + 1
    raise Exception("unsupported token")

def makefunction(inp):
    def f(x):
        pass

    if PY3:
        oldcode = f.__code__
        kwonly = oldcode.co_kwonlyargcount
    else:
        oldcode = f.func_code
    stack_size = 0
    consts = [None]
    bytecode = []
    p = ast.parse(inp).body[0]
    stack_size = process(consts, bytecode, p, stack_size)
    bytecode.append(RETURN_VALUE)
    bytecode = bytes(bytearray(bytecode))
    consts = tuple(consts)
    if PY3:
        code = types.CodeType(oldcode.co_argcount, oldcode.co_kwonlyargcount, oldcode.co_nlocals, stack_size, oldcode.co_flags, bytecode, consts, oldcode.co_names, oldcode.co_varnames, oldcode.co_filename, 'f', oldcode.co_firstlineno, b'')
        f.__code__ = code
    else:
        code = types.CodeType(oldcode.co_argcount, oldcode.co_nlocals, stack_size, oldcode.co_flags, bytecode, consts, oldcode.co_names, oldcode.co_varnames, oldcode.co_filename, 'f', oldcode.co_firstlineno, '')
        f.func_code = code
    return f

这种方法的优点是生成的函数与eval基本相同,并且几乎与compile+eval一样有效(compile步骤比eval略慢,eval会预计算任何可以预计算的内容(例如1+1+x编译为2+x)。

相比之下,eval在0.0125秒内完成了20k测试,而makefunction则需要0.014秒。将迭代次数增加到2,000,000,eval需要1.23秒完成,而makefunction需要1.32秒完成。

有趣的是,pypy认识到evalmakefunction生成的函数基本相同,因此第一个的JIT热身会加速第二个。


1
我不是Python程序员,所以无法提供Python代码。但我认为我可以提供一个简单的方案,最小化您的依赖关系并且仍然运行得非常快。
关键在于构建一个类似于eval的东西,但又不是eval。因此,您想要做的就是将用户方程式“编译”成可以快速评估的内容。OP已经展示了许多解决方案。
这里是另一种基于将方程式作为逆波兰表达式进行评估的方法。
为了讨论,假设您可以将方程式转换为RPN(逆波兰表示法)。这意味着操作数在运算符之前,例如,对于用户公式:
        sqrt(x**2 + y**2)

您左到右阅读,可以获得逆波兰表达式的等价形式:
          x 2 ** y 2 ** + sqrt

实际上,我们可以将“操作数”(例如变量和常量)视为需要零个操作数的运算符。现在,逆波兰式中的每个元素都是一个运算符。
如果我们将每个运算符元素视为一个令牌(假设下面写成“RPNelement”的唯一小整数),并将它们存储在数组“RPN”中,我们可以使用一个下推栈相当快地评估这样的公式:
       stack = {};  // make the stack empty
       do i=1,len(RPN),1
          case RPN[i]:
              "0":  push(stack,0);
              "1": push(stack,1);
              "+":  push(stack,pop(stack)+pop(stack));break;
               "-": push(stack,pop(stack)-pop(stack));break;
               "**": push(stack,power(pop(stack),pop(stack)));break;
               "x": push(stack,x);break;
               "y": push(stack,y);break;
               "K1": push(stack,K1);break;
                ... // as many K1s as you have typical constants in a formula
           endcase
       enddo
       answer=pop(stack);

您可以内联推送和弹出操作以加快速度。如果提供的逆波兰式格式正确,此代码是完全安全的。
现在,如何获得逆波兰式?答案:构建一个小型递归下降解析器,其操作将RPN运算符附加到RPN数组中。有关典型方程的详细信息,请参见我的SO答案,了解如何轻松构建递归下降解析器
您将不得不组织将解析中遇到的常量放入K1、K2等中,如果它们不是特殊的、常见的值(如我展示的“0”和“1”),则可以添加更多。
这个解决方案最多只需几百行,对其他包没有任何依赖。
(Python专家:请随意编辑代码,使其更符合Python风格)。

1
如果您将字符串传递给sympy.simplify(这不是推荐的用法;推荐明确使用sympify),那么它将使用sympy.sympify将其转换为SymPy表达式,其中使用了内部的eval

我根据您的建议修改了我的问题中的sympy解决方案。 - s-m-e

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接