如何高效地传递函数?

16

动机

看看以下图片。

enter image description here

给出了红、蓝和绿曲线。我想在每个 x 轴上找到占优势的曲线。如图所示,这是作为黑色图表显示的。由于红、绿和蓝曲线的属性(在一段时间后增加和恒定),这归结于在最右边找到占优势的曲线,然后向左移动,找到所有交点并更新占优曲线。

需要解决此问题T次。问题的最后一个变化是:下一次迭代的蓝、绿和红曲线是通过前一次迭代的占优解加上某些可变参数进行构造的。例如,在上面的图中:解决方案是黑色函数。使用此函数生成新的蓝、绿和红曲线。然后问题开始寻找这些新曲线的占优解。

问题简述
在每个迭代中,我从固定的最右侧开始,并评估所有三个函数,以查看哪个是占优解。这些评估随着迭代的进行而变得越来越慢。 我的感觉是我没有最优地传递旧的占优函数来构造新的蓝、绿和红曲线。原因:在早期版本中,我遇到了最大递归深度错误。代码的其他部分需要当前占优函数的值 (它可能是绿色、红色或蓝色曲线),这也要随着迭代而变得越来越慢。

对于5次迭代,仅在最右侧的一个点上评估函数将会增长:

结果是通过以下方式产生的:

test = A(5, 120000, 100000) 

然后运行

test.find_all_intersections()

>>> test.find_all_intersections()
iteration 4
to compute function values it took
0.0102479457855
iteration 3
to compute function values it took
0.0134601593018
iteration 2
to compute function values it took
0.0294270515442
iteration 1
to compute function values it took
0.109843969345
iteration 0
to compute function values it took
0.823768854141

我想知道为什么这是这样的,以及是否能更有效地编程。
详细的代码解释:
1. 方法:为了生成上述绿色、红色和蓝色曲线的新批次,我们需要旧的主导曲线。u 是在第一次迭代中使用的初始化。

2. 方法_function_template:该函数使用不同的参数生成绿色、蓝色和红色曲线的版本。它返回一个单输入函数。

3. 方法eval:这是每次生成蓝色、绿色和红色版本的核心函数。每次迭代都会有三个可变参数:vfunction 是上一步的主导函数,m 和 s 是影响结果曲线形状的两个参数(flaots)。其他参数在每次迭代中都是相同的。在代码中,对于每次迭代,都有 m 和 s 的示例值。对于更极客的人来说,这是为了近似求解积分,其中 m 和 s 是基础正态分布的期望均值和标准差。通过 Gauss-Hermite 节点/权重来进行近似计算。

4. 方法find_all_intersections:这是找到每次迭代中主导函数的核心方法。它通过将蓝色、绿色和红色曲线拼接成段来构建一个主导函数。这是通过函数 piecewise 实现的。

以下是完整的代码。
import numpy as np
import pandas as pd
from scipy.optimize import brentq
import multiprocessing as mp
import pathos as pt
import timeit
import math
class A(object):
    def u(self, w):
        _w = np.asarray(w).copy()
        _w[_w >= 120000] = 120000
        _p = np.maximum(0, 100000 - _w)
        return _w - 1000*_p**2

    def __init__(self, T, upper_bound, lower_bound):
        self.T = T
        self.upper_bound = upper_bound
        self.lower_bound = lower_bound

    def _function_template(self, *args):
        def _f(x):
            return self.evalv(x, *args)
        return _f

    def evalv(self, w, c, vfunction, g, m, s, gauss_weights, gauss_nodes):
        _A = np.tile(1 + m + math.sqrt(2) * s * gauss_nodes, (np.size(w), 1))
        _W = (_A.T * w).T
        _W = gauss_weights * vfunction(np.ravel(_W)).reshape(np.size(w),
                                                             len(gauss_nodes))
        evalue = g*1/math.sqrt(math.pi)*np.sum(_W, axis=1)
        return c + evalue

    def find_all_intersections(self):

        # the hermite gauss weights and nodes for integration
        # and additional paramters used for function generation

        gauss = np.polynomial.hermite.hermgauss(10)
        gauss_nodes = gauss[0]
        gauss_weights = gauss[1]
        r = np.asarray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                        1., 1., 1., 1., 1., 1., 1., 1., 1.])
        m = [[0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624]]

        s = [[0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142]]

        self.solution = []

        n_cpu = mp.cpu_count()
        pool = pt.multiprocessing.ProcessPool(n_cpu)

        # this function is used for multiprocessing
        def call_f(f, x):
            return f(x)

        # this function takes differences for getting cross points
        def _diff(f_dom, f_other):
            def h(x):
                return f_dom(x) - f_other(x)
            return h

        # finds the root of two function
        def find_roots(F, u_bound, l_bound):
                try:
                    sol = brentq(F, a=l_bound,
                                 b=u_bound)
                    if np.absolute(sol - u_bound) > 1:
                        return sol
                    else:
                        return l_bound
                except ValueError:
                    return l_bound

        # piecewise function
        def piecewise(l_comp, l_f):
            def f(x):
                _ind_f = np.digitize(x, l_comp) - 1
                if np.isscalar(x):
                    return l_f[_ind_f](x)
                else:
                    return np.asarray([l_f[_ind_f[i]](x[i])
                                       for i in range(0, len(x))]).ravel()
            return f

        _u = self.u

        for t in range(self.T-1, -1, -1):
            print('iteration' + ' ' + str(t))

            l_bound, u_bound = 0.5*self.lower_bound, self.upper_bound
            l_ordered_functions = []
            l_roots = []
            l_solution = []

            # build all function variations

            l_functions = [self._function_template(0, _u, r[t], m[t][i], s[t][i],
                                                   gauss_weights, gauss_nodes)
                           for i in range(0, len(m[t]))]

            # get the best solution for the upper bound on the very
            # right hand side of wealth interval

            array_functions = np.asarray(l_functions)
            start_time = timeit.default_timer()
            functions_values = pool.map(call_f, array_functions.tolist(),
                                        len(m[t]) * [u_bound])
            elapsed = timeit.default_timer() - start_time
            print('to compute function values it took')
            print(elapsed)

            ind = np.argmax(functions_values)
            cross_points = len(m[t]) * [u_bound]
            l_roots.insert(0, u_bound)
            max_m = m[t][ind]
            l_solution.insert(0, max_m)

            # move from the upper bound twoards the lower bound
            # and find the dominating solution by exploring all cross
            # points.

            test = True

            while test:
                l_ordered_functions.insert(0, array_functions[ind])
                current_max = l_ordered_functions[0]

                l_c_max = len(m[t]) * [current_max]
                l_u_cross = len(m[t]) * [cross_points[ind]]

                # Find new cross points on the smaller interval

                diff = pool.map(_diff, l_c_max, array_functions.tolist())
                cross_points = pool.map(find_roots, diff,
                                        l_u_cross, len(m[t]) * [l_bound])

                # update the solution, cross points and current
                # dominating function.

                ind = np.argmax(cross_points)
                l_roots.insert(0, cross_points[ind])
                max_m = m[t][ind]
                l_solution.insert(0, max_m)

                if cross_points[ind] <= l_bound:
                    test = False

            l_ordered_functions.insert(0, l_functions[0])
            l_roots.insert(0, 0)
            l_roots[-1] = np.inf

            l_comp = l_roots[:]
            l_f = l_ordered_functions[:]

            # build piecewise function which is used for next
            # iteration.

            _u = piecewise(l_comp, l_f)
            _sol = pd.DataFrame(data=l_solution,
                                index=np.asarray(l_roots)[0:-1])
            self.solution.insert(0, _sol)
        return self.solution

3
我认为你的问题对于SO来说太大了。虽然我可能会花费数小时测试和撰写答案,但在第一次阅读时,我很少花费超过30秒钟的时间。 - hpaulj
1
如果你正在开发新的代码,你的目标应该是Python 3,而不是2.7。 - tripleee
1
看起来这只是一个递归问题 - 您的 eval / vfunction 在每次迭代中都会增加复杂度,因为它需要重新评估所有基础和前置函数。 - Kirk Broadhurst
@math,您可以通过复制值或输出来“硬拷贝”,但不能通过引用函数来进行。请查看我的回答。 - Kirk Broadhurst
1
根据快速查看,如果您对先前的函数进行恒定调用并且这些函数相同,即为同一子组件。那么它非常类似于动态规划方法,例如背包问题。您可以使用参数输入作为索引来实现前一个函数结果的具体化。然后,每个函数执行结果查找(如果结果已计算)和计算(如果结果未计算)。 - ZhijieWang
显示剩余6条评论
2个回答

4

让我们先更改代码以输出当前迭代:

_u = self.u
for t in range(0, self.T):
    print(t)
    lparams = np.random.randint(self.a, self.b, 6).reshape(3, 2).tolist()
    functions = [self._function_template(_u, *lparams[i])
                 for i in range(0, 3)]
    # evaluate functions
    pairs = list(itertools.combinations(functions, 2))
    fval = [F(diff(*pairs[i]), self.a, self.b) for i in range(0, 3)]
    ind = np.sort(np.unique(np.random.randint(self.a, self.b, 10)))
    _u = _temp(ind, np.asarray(functions)[ind % 3])

检查导致此行为的代码行,

fval = [F(diff(*pairs[i]), self.a, self.b) for i in range(0, 3)]

需要关注的函数是Fdiff。后者很简单,前者:

def F(f, a, b):
    try:
        brentq(f, a=a, b=b)
    except ValueError:
        pass

嗯,捕获异常。让我们看看如果我们这样做会发生什么:

def F(f, a, b):
    brentq(f, a=a, b=b)

立即,在第一个函数和第一次迭代中,会抛出一个错误:

ValueError: f(a)和f(b)必须具有不同的符号

查看文档得知,这是根查找函数brentq的前提条件。让我们再次更改定义以监视每次迭代时的此条件。

def F(f, a, b):
    try:
        brentq(f, a=a, b=b)
    except ValueError as e:
        print(e)

输出结果为:
i
f(a) and f(b) must have different signs
f(a) and f(b) must have different signs
f(a) and f(b) must have different signs

对于变量 i 范围从0到57,这意味着函数 F 第一次进行任何实际工作是在 i=58 时。并且它会持续为更高的 i 值执行此操作。

结论:对于这些较高的值,需要更长的时间,因为:

  1. 对于较低的值,根永远不会被计算
  2. 随着 i>58 ,计算次数呈线性增长

我试图提供一个简单的玩具例子。不幸的是,对于这个玩具例子来说,这真的是个问题。然而,对于真正的问题来说,情况并非如此(在我看来)。我正在更新问题以反映真正的问题。请注意,我已经在线上有一个扩展版本,但被要求缩小它。对于造成的不便,我深感抱歉。 - math
2
修订版本3中已经有一个扩展版本在线了。 - greybeard

3
您的代码过于复杂,难以解释您的问题 - 努力编写更简单的代码。有时候,您需要编写代码来演示问题。
根据您的描述,我猜测您的问题(虽然我运行了代码并进行了验证)。以下是您的问题:
方法eval:这是每次生成蓝色、绿色和红色版本的核心函数。它在每次迭代中都会使用三个不同的参数:vfunction是上一步中占主导地位的函数,m和s是两个参数(flaots),影响结果曲线的形状。
在每次迭代中,您的vfunction参数变得更加复杂。您正在传递一个由先前迭代建立起来的嵌套函数,这会导致递归执行。每次迭代都会增加递归调用的深度。
如何避免这种情况?没有简单或内置的方法。最简单的答案是 - 假设这些函数的输入是一致的 - 存储功能结果(即数字)而不是函数本身。只要您有有限数量的已知输入,就可以这样做。
如果底层函数的输入不一致,则没有捷径。您需要反复评估这些底层函数。我看到您正在对底层函数进行分段拼接 - 您可以测试是否超出了这样做的成本,是否超出了仅取每个底层函数的最大值的成本。
我运行的测试(10次迭代)花费了几秒钟。我认为这不是一个问题。

@老程序员,我觉得你的说法很令人困惑。上下文是“如果底层函数的输入不一致,那么就没有捷径”。我不是在谈论一致的函数 - 我是在谈论这些函数的一致输入。 - Kirk Broadhurst
我在下一句话结束时忘记了那个上下文,或者可能没有意识到它是“显而易见”的。 - greybeard

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接