函数层次结构的组成

11
有没有一种规范的方式来表达一个由函数树组成的函数?
以下是“函数树组合”的具体示例。以一个带有标记函数的根节点的树为例,如下所示:

enter image description here

每个节点的函数都是其子节点的函数组合。与树相关联的函数本身就是这些组合的结果。
F = a0(b0(c0(e0, e1, e2)), b1(d0(f0), d1(g0, g1)))

更明确地说,F 是一个由 6 个参数组成的函数,这些参数通过叶子节点上的函数进行求值:
F(x0, ... , x5) == a0(b0(c0(e0(x0), e1(x1), e2(x2))),
                      b1(d0(f0(x3)), d1(g0(x4), g1(x5))))

通用问题
给定一个根树T和一个与T的节点对应的函数列表L,是否有一种“规范”的方式来编写一个函数F,该函数接受参数T和L,并返回按照树T结构组合的L中的函数的结果?
因此,组合的“布线”——树T与其内部“组件”——列表L分开。一个“规范”的解决方案应该包括特别适用于这个问题的T和L的表示。
我怀疑这个问题在函数式编程语言中有一个微不足道的解决方案,但理想情况下,我想在像Python这样的动态类型命令式语言中找到一个解决方案,例如:
def treecomp(tree, list_of_funcs):
    ...
    return function

F = treecomp(T, L)

补充:

与此同时,我想出了自己的解决方案(如下所示)。

虽然我对它的经济性和概念上的简单性感到满意,但我仍然对其他本质上不同的方法感兴趣,特别是那些利用另一种语言的优势,在Python中缺乏或支持不佳的方法。

直觉

通过适当的数据结构——不会本质上重现所需的输出!——函数式编程习惯应该能够实现非常简短的解决方案。


听起来像是一张奇怪的调用图。有趣! :) - erip
@EugeneHa 我正在使用函数字典而不是列表,这样可以吗? - JeD
@JeD:使用字典也是相当合理的,但列表在某种程度上更方便。在示例中,LT都可以表示为词典列表:L = [a0, b0, b1, c0, ... , g1]T = [(0,), (0,0), (0,1), (0,0,0), ... , (0,1,1,1)]。(类似于书中的章节编号。)然后,带有函数标签的树是列表TL = list(zip(T, L))。因此,在编写treecomp时,一种方法是展开TL中的(隐式)树结构并递归地组合函数。但我不确定如果我选择将TL作为列表是否已经出现了问题!尽管这看起来很自然。 - egnha
@JeD:我并没有被搞糊涂;参见我的解决方案,虽然仍然有些笨拙。:( - egnha
@JeD:我的解决方案中的笨拙已经被纠正。 - egnha
5个回答

2

听起来很有趣。我尝试了一下,这是结果。

class Node(object):
    def __init__(self, parents, fn):
        self.parents = parents
        self.fn = fn

    def get_number_of_args(self):
        if not self.parents:
            return 1
        if not hasattr(self, '_cached_no_args'):
            self._cached_no_args = sum(
                parent.get_number_of_args() for parent in self.parents
            )
        return self._cached_no_args

    def compose(self):
        if not self.parents:
            def composition(*args):
                return self.fn(*args)
            return composition

        fns = []
        fns_args = []
        for parent in self.parents:
            fns.append(parent.compose())
            fns_args.append(parent.get_number_of_args())
        number_of_args = sum(fns_args)
        length = len(self.parents)

        def composition(*args):
            if len(args) != number_of_args:
                raise TypeError
            sub_args = []
            last_no_args = 0
            reached_no_args = 0
            for i in range(length):
                fn = fns[i]
                no_args = fns_args[i]
                reached_no_args += no_args
                args_cut = args[last_no_args: reached_no_args]
                sub_call = fn(*args_cut)
                sub_args.append(sub_call)
                last_no_args = no_args
            return self.fn(*sub_args)

        return composition

您没有说明如何实现树结构,因此我将节点和函数组合成一个结构(您始终可以自己进行映射)。现在看用法:

>>> def fn(x):
...     return x
>>> def fn2(x):
...    return 1
>>> def add(x, y):
...    return x + y
>>> n1 = Node([], fn)
>>> n2 = Node([], fn2)
>>> n3 = Node([n1, n2], add)
>>> fn = n3.compose()
>>> print(fn(5, 7))
6

正如预期的那样。请随意测试它(我实际上还没有在更深层次的树上尝试过),如果您发现任何问题,请告诉我。


感谢您的输入。事实证明,不需要进行任何节点计数的簿记。 - egnha

2
我们定义一个名为 treecomp 的函数,它根据一棵根树 T 的结构,将一组函数列表 L 合成。该函数接受 LT 作为分开的参数。请保留 HTML 标签。
F = treecomp(T, L)

与迄今提出的其他解决方案不同,它不会被不必要的簿记(如跟踪叶子或参数数量,这些最好由装饰器处理)所复杂。 treecomp的简单构造 treecomp的一个简单实现如下:它仅生成树组合的符号(字符串)表达式。然后,将其插入并评估生成的表达式就是一件简单的事情。
这个天真的想法可以使用相当基本的数据结构来实现:用于树和函数的列表,以及用于函数标记树的简单类。(命名元组也可以做到。但是,通过使用具有特殊比较方法的类,我们可以编写更加语义化自然的代码。)
数据结构
作为“平面”列表,根树的最经济编码方式是作为“节点地址”的列表。在对@JeD的评论中,我暗示可以通过“绘制”树来完成此操作:
T = [(0,),
         (0, 0),
             (0, 0, 0),
                 (0, 0, 0, 0), (0, 0, 0, 1), (0, 0, 0, 2),
         (0, 1),
             (0, 1, 0),
                 (0, 1, 0, 0),
             (0, 1, 1),
                 (0, 1, 1, 0), (0, 1, 1, 1)]

在这里,(0,) 是对应于 a0 的节点,(0, 0) 是对应于 b0 的节点,(0, 1) 是对应于 b1 的节点,依此类推,就像书中章节的编号一样。最长(或“最高”)的元组是叶子节点。
函数列表 L 可以按照 T 中节点的顺序给出一个匹配列表。
L = [a0, b0, c0, e0, e1, e2, b1, d0, f0, d1, g0, g1]

由于树T的节点是由L中的函数标记的,因此为此需要一个数据结构。我们定义了一个类来记录节点的地址和标记它的函数的字面名称;其方法实现相对于树的偏序比较(其中根是最小元素):

class SymbNode:
    '''Class that records a node's address and symbol.'''

    def __init__(self, addr, symb):
        self.addr = addr
        self.symb = symb

    def __len__(self): # how "high" a node is above the root
        return len(self.addr)

    def _compare(self, other, segment):
        return self.addr == other.addr[:segment]

    def __le__(self, other):
        return self._compare(other, segment=len(self))

    def begets(self, other):
        return self._compare(other, segment=-1)

实现

treecomp的简单两步机制如下所示。通过规范化SymbNodes列表的顺序,我们可以通过简单地“剥离”树的每一层来构建符号表达式。

from functools import partial
from operator import attrgetter

def treecomp(tree, funcs):
    '''Returns the composition of a tree of functions.'''
    symbtree = makesymbtree(tree, funcs)
    symbexp = makesymbexp(symbtree)
    return partial(evalsymbexp, symbexp=symbexp)

FUNC_CALL = '{func}({{}})'

def makesymbtree(tree, funcs):
    '''Returns the symbolic expression of a tree composition.'''
    symbols = [FUNC_CALL.format(func=func.__name__) for func in funcs]
    symbtree = sorted((SymbNode(*x) for x in zip(tree, symbols)),
                      key=attrgetter('addr'))
    symbtree.sort(key=len)
    return symbtree

def makesymbexp(symbtree):
    root = symbtree[0]
    if len(symbtree) == 1: # symbtree is a leaf node
        return root.symb
    symbargs = [makesymbexp(subsymbtree(symbtree, root=node))
                for node in symbtree if root.begets(node)]
    return root.symb.format(','.join(symbargs))

def subsymbtree(symbtree, root):
    subsymbtree = [node for node in symbtree if root <= node]
    return subsymbtree

ARGS = 'args[{idx}]'

def evalsymbexp(symbexp, *args):
    '''Returns the evaluation of a symbolic expression on arguments.'''
    argnames = [ARGS.format(idx=str(n)) for n, _ in enumerate(args)]
    return eval(symbexp.format(*argnames))

验证

由于 treecomp 的分隔性,我们只需要验证函数 makesymbexp 生成正确的符号表达式,以及函数 evalsymbexp 正确评估符号表达式。

函数evalsymbexp(基本上只有一行)应该采用字符串模板并插入参数名称'args [0]''args [1]'等,然后评估结果。它显然做到了这一点。

至于 makesymbexp,在没有正式证明(我们避免这样做)的情况下,我们可以通过检查其在一些测试数据上的输出来获得其正确性。例如,考虑以下函数:

def D(x): return 2*x
def M(x): return -x
def S(*xs): return sum(xs)

a0 = S
b0, b1 = D, S
c0, d0, d1 = S, D, S
e0, e1, e2, f0, g0, g1 = D, M, D, M, D, M

使用上述的TL,我们可以检查是否得到正确的符号表达式:
makesymbexp(makesymbtree(T, L))

确实会返回字符串

'S(D(S(D({}),M({}),D({}))),S(D(M({})),S(D({}),M({}))))'

为了检查treecompevalsymbexp的委托,作为一个部分函数,我验证了以下值:

F = treecomp(T, L)
F(x0, x1, x2, x3, x4, x5)

同意价值观

a0(b0(c0(e0(x0), e1(x1), e2(x2))), b1(d0(f0(x3)), d1(g0(x4), g1(x5))))

从-100到100之间的整数中,随机抽取1000个样本x0,…,x5


1
这是一个我编写的简单示例:

Here's a simple example that I've cooked up:

from collections import deque

class Node(object):
    def __init__(self, children, func):
        self.children = children
        if children:
            self.leaf_count = sum(c.leaf_count for c in children)
        else:
            self.leaf_count = 1  # It is a leaf node.

        self.func = func

    def __call__(self, *args):
        if not self.children:
            assert len(args) == 1, 'leaf can only accept 1 argument.'
            return self.func(*args)  # Base case.

        d_args = deque(args)
        func_results = []
        for child in self.children:
            f_args = [d_args.popleft() for _ in xrange(child.leaf_count)]
            func_results.append(child(*f_args))
        assert not d_args, 'Called with the wrong number of arguments'
        return self.func(*func_results)

基本上,“技巧”就是要跟踪每个节点有多少叶子节点,因为叶子节点的数量是它期望接受的参数数量。
  • 如果一个节点是叶子节点,则只需使用单个输入参数调用其委托函数。
  • 如果一个节点不是叶子节点,则根据子树中叶子节点的数量调用每个子节点并提供相应的参数。

一些实现注意事项:

我使用了collections.deque来获取正确数量的参数以传递给子节点。这是为了效率,因为deque让我们在O(1)时间内获取这些参数。否则,我们将得到类似以下内容的东西:

for child in self.children:
    f_args = args[:child.leaf_count]
    args = args[child.leaf_count:]
    func_results.append(child(*args))

但是每个遍历都需要O(N)的时间。对于小树,这可能无关紧要。对于大树,可能会很重要:-)。

我还使用了叶子计数的静态成员,这意味着您需要从叶子到根来构建树。当然,您可以根据问题约束使用不同的策略。例如,您可以构建树,然后在开始评估函数之前在单个遍历中填写leaf_count,或者您可以将leaf_count转换为函数(@property),每次调用它都会计算叶子(对于大树来说可能会变得很昂贵)。

现在进行一些测试……我能想到的最简单的情况是,如果叶节点都与身份函数相关联,那么非叶节点就是将输入值相加的函数。在这种情况下,结果应该总是输入值的总和:

def my_sum(*args):
    return sum(args)

def identity(value):
    return value

e0, e1, e2, f0, g0, g1 = [Node([], identity) for _ in xrange(6)]
c0 = Node([e0, e1, e2], my_sum)
d0 = Node([f0], my_sum)
d1 = Node([g0, g1], my_sum)
b0 = Node([c0], my_sum)
b1 = Node([d0, d1], my_sum)
a0 = Node([b0, b1], my_sum)

arg_tests = [
    (1, 1, 1, 1, 1, 1),
    (1, 2, 3, 4, 5, 6)
]
for args in arg_tests:
    assert a0(*args) == sum(args)
print('Pass!')

你的实现非常清晰易懂。事实证明,如果将树放在规范化形式中,则“技巧”是不必要的。(如果用户提供了错误数量的参数,请让执行失败——宁愿请求原谅也不要征得许可。;) 看起来,你设置的树标签方式基本上相当于手写组合。 - egnha

1

如果你想将函数和树解耦,可以这样做:

#root=RootNode, funcs=Map from Node to function, values=list of inputs
#nodes need isLeaf and children field
def Func(root,funcs,values):
    #check if leaf
    if root.isLeaf:
        #removes value from list
        val=values.pop(0)
        #returns function of root
        #(can be identity if you just want to input values)
        return funcs[root](val)

    #else we do a recursive iteration:
    else:
        nextVals=[]
        #for each child
        for child in root.children:
            #call this function->DFS for roots, removes values that are used
            nextVals.append(Func(child,funcs,values))
        #unpack list and call function
        return funcs[root](*nextVals)

Here an example:

class Node:
    children=[]
    isLeaf=False

    def __init__(self,isLeaf):
        self.isLeaf=isLeaf

    def add(self,n):
        self.children.append(n)




def Func(root,funcs,values):
    #check if leaf
    if root.isLeaf:
        #removes value from list
        val=values.pop(0)
        #returns function of root
        #(can be identity if you just want to input values)
        return funcs[root](val)

    #else we do a recursive iteration:
    else:
        nextVals=[]
        #for each child
        for child in root.children:
            #call this function->DFS for roots, removes values that are used
            nextVals.append(Func(child,funcs,values))
        #unpack list and call function
        return funcs[root](*nextVals)


def sum3(a,b,c):
    return a+b+c


import math

funcMap={}
funcMap[root]=sum3

root=Node(False)
layer1=[Node(True) for i in range(3)]
for i in range(3):
    root.add(layer1[i])
    funcMap[layer1[i]]=math.sin




print Func(root,funcMap,[1,2,3])
print math.sin(1)+math.sin(2)+math.sin(3)

这将返回相同的值(使用Python 2.7)


0

这是一个很好的面向对象编程(OOP)的候选。例如,您可以使用以下三个类:

  1. 节点(Node)
  2. 叶子(Leaf)
  3. 树(Tree)

对于处理树形结构,递归方法通常更容易。

或者,您也可以通过嵌套元组来直接构建递归结构。例如:

n1 = ( 'L', 'e0' )
n2 = ( 'L', 'e1' )
n3 = ( 'L', 'e2' )
n4 = ( 'N', 'c0', n1, n2, n3 )
n5 = ( 'N', 'b0', n4 )

这不是你完整的树,但它可以很容易地扩展。只需使用print(n5)查看结果。

这不是唯一的方法,可能会有变化。对于每个元组,第一个项目是指定它是叶子“L”还是节点“N”的字母--这将使递归函数更容易。第二个项目是名称(从您的绘图中获取)。对于节点,其他项目是子节点。

(注意:我曾经使用“元组内嵌元组”来实现Huffmann编码算法--它也适用于树结构)。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接