MATLAB中高效的树实现

18

MATLAB中的Tree类

我正在MATLAB中实现一种树形数据结构。添加新的子节点到树中,分配和更新与节点相关的数据值是我期望执行的典型操作。每个节点都有相同类型的data与之关联。对于我来说,删除节点并不是必需的。到目前为止,我已经决定采用从handle类继承的类实现,以便能够将引用传递给将修改树的函数。

编辑:12月2日

首先,感谢所有评论和答案中的建议。它们已经帮助我改进了我的树类。

有人建议尝试在R2015b中引入的digraph。我还没有探索过这个功能,但是由于它不能像从handle类继承的类那样作为引用参数工作,因此我有点怀疑它在我的应用程序中将如何工作。此时,对于如何使用自定义data来处理节点和边仍不清楚。

编辑:(12月3日)有关主要应用程序MCTS的更多信息

最初,我认为主要应用程序的细节只是边缘兴趣,但是自从阅读了评论和@FirefoxMetzger的答案以来,我意识到它具有重要的影响。

我正在实现一种Monte Carlo树搜索算法。搜索树以迭代方式进行探索和扩展。Wikipedia提供了该过程的良好图形概述:Monte Carlo tree search

在我的应用程序中,我执行大量的搜索迭代。在每个搜索迭代中,我从根开始遍历当前树,直到叶节点,然后通过添加新节点扩展树,并重复此过程。由于该方法基于随机抽样,在每次迭代开始时,我不知道我将在每次迭代中完成哪个叶节点。相反,这是由当前树中节点的data和随机样本的结果共同确定的。我访问的任何节点都会更新其data

示例:我在节点n,它有几个子节点。我需要访问每个孩子的数据,并绘制一个随机样本来确定下一步搜索哪个孩子。重复此过程,直到到达叶子节点。实际上,我通过在根上调用search函数来实现这一点,该函数将决定下一个要展开的子节点,递归地调用search该节点,等等,最后一旦到达叶子节点就返回一个值。返回递归函数时使用此值以更新搜索迭代期间访问的节点的data

树可能非常不平衡,某些分支是非常长的节点链,而其他分支在根级别之后很快终止并且不再展开。

当前实现

以下是我的当前实现示例,其中包括添加节点、查询深度或树中节点数量等成员函数的示例。

classdef stree < handle
    %   A class for a tree object that acts like a reference
    %   parameter.
    %   The tree can be traversed in both directions by using the parent
    %   and children information.
    %   New nodes can be added to the tree. The object will automatically
    %   keep track of the number of nodes in the tree and increment the
    %   storage space as necessary.

    properties (SetAccess = private)
        % Hold the data at each node
        Node = { [] };
        % Index of the parent node. The root of the tree as a parent index
        % equal to 0.
        Parent = 0;
        num_nodes = 0;
        size_increment = 1;
        maxSize = 1;
    end

    methods
        function [obj, root_ID] = stree(data, init_siz)
            % New object with only root content, with specified initial
            % size
            obj.Node = repmat({ data },init_siz,1);
            obj.Parent = zeros(init_siz,1);
            root_ID = 1;
            obj.num_nodes = 1;
            obj.size_increment = init_siz;
            obj.maxSize = numel(obj.Parent);
        end

        function ID = addnode(obj, parent, data)
            % Add child node to specified parent
            if obj.num_nodes < obj.maxSize
                % still have room for data
                idx = obj.num_nodes + 1;
                obj.Node{idx} = data;
                obj.Parent(idx) = parent;
                obj.num_nodes = idx;
            else
                % all preallocated elements are in use, reserve more memory
                obj.Node = [
                    obj.Node
                    repmat({data},obj.size_increment,1)
                    ];

                obj.Parent = [
                    obj.Parent
                    parent
                    zeros(obj.size_increment-1,1)];
                obj.num_nodes = obj.num_nodes + 1;

                obj.maxSize = numel(obj.Parent);

            end
            ID = obj.num_nodes;
        end

        function content = get(obj, ID)
            %% GET  Return the contents of the given node IDs.
            content = [obj.Node{ID}];
        end

        function obj = set(obj, ID, content)
            %% SET  Set the content of given node ID and return the modifed tree.
            obj.Node{ID} = content;
        end

        function IDs = getchildren(obj, ID)
            % GETCHILDREN  Return the list of ID of the children of the given node ID.
            % The list is returned as a line vector.
            IDs = find( obj.Parent(1:obj.num_nodes) == ID );
            IDs = IDs';
        end
        function n = nnodes(obj)
            % NNODES  Return the number of nodes in the tree.
            % Equal to root + those whose parent is not root.
            n = 1 + sum(obj.Parent(1:obj.num_nodes) ~= 0);
            assert( obj.num_nodes == n);
        end

        function flag = isleaf(obj, ID)
            % ISLEAF  Return true if given ID matches a leaf node.
            % A leaf node is a node that has no children.
            flag = ~any( obj.Parent(1:obj.num_nodes) == ID );
        end

        function depth = depth(obj,ID)
            % DEPTH return depth of tree under ID. If ID is not given, use
            % root.
            if nargin == 1
                ID = 0;
            end
            if obj.isleaf(ID)
                depth = 0;
            else
                children = obj.getchildren(ID);
                NC = numel(children);
                d = 0; % Depth from here on out
                for k = 1:NC
                    d = max(d, obj.depth(children(k)));
                end
                depth = 1 + d;
            end
        end
    end
end

然而,有时候性能较慢,树的操作占据了大部分计算时间。有哪些具体的方法可以使实现更加高效?如果存在性能提升,甚至可以将实现更改为其他类型而不是“handle”继承类型。

当前实现的分析结果

由于向树中添加新节点是最典型的操作之一(以及更新节点的data),我对此进行了一些分析。 我使用以下基准代码在Nd=6, Ns=10下运行了分析器。

function T = benchmark(Nd, Ns)
% Tree benchmark. Nd: tree depth, Ns: number of nodes per layer
% Initialize tree
T = stree(rand, 10000);
add_layers(1, Nd);
    function add_layers(node_id, num_layers)
        if num_layers == 0
            return;
        end
        child_id = zeros(Ns,1);
        for s = 1:Ns
            % add child to current node
            child_id(s) = T.addnode(node_id, rand);

            % recursively increase depth under child_id(s)
            add_layers(child_id(s), num_layers-1);
        end
    end
end

分析器的结果: Profiler results

R2015b 性能


发现 R2015b 改进了 MATLAB 的面向对象编程功能的性能。我重做了上述基准测试,确实观察到性能提高:

R2015b profiler result

所以这已经是一个好消息,当然如果有更进一步的改进就更好了 ;)

以不同的方式预留内存

在评论中还建议使用

obj.Node = [obj.Node; data; cell(obj.size_increment - 1,1)];

为了比当前使用repmat的方法更好地保留内存,这样可以稍微提高性能。需要注意的是,我的基准代码是针对虚拟数据的,实际上由于实际数据更加复杂,因此这可能会有所帮助。以下是分析器的结果:

zeeMonkeez memory reserve style

关于进一步提高性能的问题

  1. 也许有一种更有效的方法来维护树的内存?不幸的是,我通常不知道树中会有多少个节点。
  2. 向树中添加新节点和修改现有节点的数据是我在树上执行的最典型操作。到目前为止,它们实际上占用了主应用程序大部分处理时间。任何这些功能的改进都将受到欢迎。

作为最后的说明,我希望保持实现纯MATLAB。但是,如MEX或使用一些集成的Java功能等选项也是可以接受的。


1
运行profiler可以在性能方面对你的代码进行很多揭示。运行一次,看看代码在哪里异常缓慢,它会给你一个指引,告诉你从哪里开始改进。 - Adriaan
2
除非您使用Matlab 2015b或更新版本,否则Matlab OOP会增加显着的开销,这可能会导致问题。不使用“handle”也可能无济于事。 - Daniel
1
@Adriaan 謝謝建議。我加入了一些分析器數據。 - mikkola
1
此外,根据您的数据类型,使用repmat来分配节点数据可能会增加很多开销。为什么不使用obj.Node = [obj.Node; data; cell(obj.size_increment - 1,1)];进行初始化呢? - zeeMonkeez
1
@mikkola digraph 将节点和边都实现为 MATLAB 表格,因此将额外数据存储到节点中只需要添加列即可。如果需要以面向对象的方式处理每个节点本身的数据,则可以在列中存储对象句柄。对于树本身的面向对象遍历,最清晰的方法可能是子类化 digraph 本身。 - Will
显示剩余10条评论
3个回答

9
TL:DR:每次插入都会深度复制存储的所有数据,并将parent和Node单元格初始化得比你预期需要的更大。
你的数据确实具有树形结构,但你没有在实现中利用它。相反,实现的代码是一个计算密集型的查找表(实际上是两个表),它存储树的数据和关系数据。
我这样说的原因如下:
- 要插入,您调用stree.addnote(parent, data),它将在树对象stree的字段Node = {}和Parent = []中存储所有数据。 - 您似乎已经知道要访问树中的哪个元素,因为没有提供搜索代码(如果使用stree.getchild(ID)进行搜索,则有一些坏消息)。 - 一旦处理了节点,就使用find()跟踪它,这是一个列表搜索。
这绝不意味着实现对于数据是笨拙的,甚至可能是最好的,具体取决于您正在做什么。但是这解释了您的内存分配问题,并提供了解决方法的提示。
保持数据作为查找表
存储数据的一种方法是保持底层查找表。只有在您知道要修改的第一个元素的ID而无需搜索时,才会这样做。这种情况可使您的结构更有效率,分两步实现。
首先,初始化您的数组比您预期需要存储数据的大小要大。如果查找表的容量超过了,则会初始化一个新表,该表比旧数据大X个字段,并进行旧数据的深度复制。如果您需要扩展容量一两次(在所有插入期间),那可能不是问题,但在您的情况下,每次插入都会进行深层复制!
其次,我将更改内部结构并合并两个表Node和Parent。原因是您代码中的反向传播需要O(depth_from_root * n)时间,其中n是表中节点的数量。这是因为find()会为每个父项迭代整个表。
取而代之的是,您可以实现类似以下内容的东西:
table = cell(n,1) % n bigger then expected value
end_pointer = 1 % simple pointer to the first free value

function insert(data,parent_ID)
    if end_pointer < numel(table)
        content.data = data;
        content.parent = parent_ID;
        table{end_pointer} = content;
        end_pointer = end_pointer + 1;
    else
        % need more space, make sure its enough this time
        table = [table cell(end_pointer,1)];
        insert(data,parent_ID);
    end
end

function content = get_value(ID)
    content = table(ID);
end

这将立即使您访问父节点的ID,无需先进行find()操作,每一步节省n次迭代,所以承受力变为O(depth)。如果您不知道初始节点,则必须先查找它,这将耗费O(n)。
请注意,此结构不需要is_leaf()depth()nnodes()get_children()。如果您仍需要这些功能,请告诉我更多关于您数据处理的信息,因为这会极大地影响合适的结构。

树形结构

如果您从不知道第一个节点的ID,并且因此总是需要搜索它,则此结构是有意义的。

优点是对任意节点的搜索都可以使用O(depth)完成,因此搜索是O(depth)而不是O(n),回溯是O(depth^2)而不是O(depth+n)。请注意,深度可以是完美平衡树的log(n),这可能取决于您的数据,也可以是退化树的n,它只是一个链接列表。

但是,为了提出适当的建议,我需要更多的见解,因为每种树形结构都有其自己的特点。从目前我所看到的情况来看,我建议使用未平衡的树形结构,它通过节点所需的父节点的简单顺序进行“排序”。这可能会进一步优化,具体取决于

  • 是否可以在数据上定义完全顺序
  • 如何处理重复值(相同的数据出现两次)
  • 您的数据规模是多少(数千、数百万等)
  • 查找/搜索是否总是与回溯配对
  • 您的数据中“父-子”链的长度有多长(或者使用此简单顺序时树形结构的平衡程度和深度有多大)
  • 是否始终只有一个父节点,或者相同的元素插入了两次,但具有不同的父节点

我很乐意为上面的树形结构提供示例代码,只需留下评论即可。

编辑: 在您的情况下,未平衡的树形结构(与执行MCTS并行构建)似乎是最佳选择。下面的代码假定数据分为statescore,并且进一步假定state是唯一的。如果不是,则仍然可以工作,但是有可能进行优化以提高MCTS性能。

classdef node < handle
    % A node for a tree in a MCTS
    properties
        state = {}; %some state of the search space that identifies the node
        score = 0;
        childs = cell(50,1);
        num_childs = 0;
    end
    methods
        function obj = node(state)
            % for a new node simulate a score using MC
            obj.score = simulate_from(state); % TODO implement simulation state -> finish
            obj.state = state;
        end
        function value = update(obj)
            % update the this node using MC recursively
            if obj.num_childs == numel(obj.childs)
                % there are to many childs, we have to expand the table
                obj.childs = [obj.childs cell(obj.num_childs,1)];
            end
            if obj.do_exploration() || obj.num_childs == 0
                % explore a potential state
                state_to_explore = obj.explore();

                %check if state has already been visited
                terminate = false;
                idx = 1;
                while idx <= obj.num_childs && ~terminate
                    if obj.childs{idx}.state_equals(state_to_explore)
                        terminate = true;
                    end
                    idx = idx + 1;
                end

                %preform the according action based on search
                if idx > obj.num_childs
                    % state has never been visited
                    % this action terminates the update recursion 
                    % and creates a new leaf
                    obj.num_childs = obj.num_childs + 1;
                    obj.childs{obj.num_childs} = node(state_to_explore);
                    value = obj.childs{obj.num_childs}.calculate_value();
                    obj.update_score(value);
                else
                    % state has been visited at least once
                    value = obj.childs{idx}.update();
                    obj.update_score(value);
                end
            else
                % exploit what we know already
                best_idx = 1;
                for idx = 1:obj.num_childs
                    if obj.childs{idx}.score > obj.childs{best_idx}.score
                        best_idx = idx;
                    end
                end
                value = obj.childs{best_idx}.update();
                obj.update_score(value);
            end
            value = obj.calculate_value();
        end
        function state = explore(obj)
            %select a next state to explore, that may or may not be visited
            %TODO
        end
        function bool = do_exploration(obj)
            % decide if this node should be explored or exploited
            %TODO
        end
        function bool = state_equals(obj, test_state)
            % returns true if the nodes state is equal to test_state
            %TODO
        end
        function update_score(obj, value)
            % updates the score based on some value
            %TODO
        end
        function calculate_value(obj)
            % returns the value of this node to update previous nodes
            %TODO
        end
    end
end

关于代码的一些注释:

  • 根据设置,obj.calculate_value() 可能不需要。例如,如果它是可以通过仅评估子项得分来计算的某个值
  • 如果一个state可以有多个父节点,则重用注释对象并将其覆盖在结构中是有意义的
  • 由于每个node都知道它的所有子项,因此可以使用node作为根节点轻松生成子树
  • 搜索树(无需任何更新)是简单的递归贪心搜索
  • 根据您搜索的分支因子,可能值得访问每个可能的子项一次(在节点初始化时),然后进行randsample(obj.childs,1)以进行探索,因为这避免了子项数组的复制/重新分配
  • parent属性在树被递归更新时编码,将value传递给节点完成更新后的父节点
  • 只有当单个节点具有超过50个子项时,我才重新分配内存,并且仅对该单个节点进行重新分配

这样运行会更快,因为它只关心所选择的树的部分,而不触及其他任何部分。


谢谢您的回复!我的应用程序是一种蒙特卡罗树搜索。我最初没有意识到它会对树的设计产生多大影响,对此感到抱歉!我更新了问题,包括了更多细节。希望这能帮助您更好地聚焦于回答问题。 - mikkola
另一个后续问题:针对特定应用程序和我遍历树的方式,您是否认为查找表实现实际上更合适?我以前没有考虑过这一点。树搜索方面,以及我阅读的有关MCTS的可视化和其他文档,都让我立即使用树数据结构来实现它。 - mikkola
我猜你的“数据”包括一些复杂的状态和该状态的分数值,我的理解是更新不会影响状态,而是改变分数? - user2457516
继续使用查找表的想法,另一种选择可能是为每个节点保存此历史记录,从而消除了历史记录的需要。 - mikkola
这主要取决于你的底层空间是否是马尔可夫的。如果你能够满足马尔可夫性质,那么通常最好使用它,这样就不需要整个历史记录来描述事物。 - user2457516
显示剩余3条评论

6

我知道这可能听起来很蠢...但是不妨将空闲节点的数量保持在总节点数之前? 这将需要与一个常量(为零)进行比较,这是单个属性访问。

另一个改进方法是将 .maxSize 移动到 .num_nodes 附近,并将它们两者都放在.Node 单元格之前。这样,由于.Node 属性的增长,它们在内存中的位置相对于对象的开头不会改变(这里的神秘因素是我猜测MATLAB中对象的内部实现)。

编辑后 当我将 .Node 移动到属性列表的末尾时进行分析,大部分执行时间被扩展 .Node 属性所消耗,正如您所提到的比较所期望的那样(5.45秒,相比之下,您提到的比较仅需1.25秒)。


很有趣!与常数比较似乎是个好主意,我确实观察到了轻微的改进。不过对于 .Node 属性我就不太确定了——我尝试切换属性位置并没有观察到改进。我的理解是单元数组元素不一定占用内存中连续的块,因此对性能的影响难以预测。 - mikkola
@mikkola 问题在于,单元数组的元素不在连续的内存区域中,但它们的地址必须是连续的。从概念上讲,单元数组就像指针数组,而该数组本身在连续的内存区域中增长。顺便问一下,您每次运行基准测试时是否清除类?这通常会使新实例化的对象的性能变差(JIT内容,加上需要重新创建元类)。 - user2271770
感谢对地址指针的澄清!我重复了基准测试,每次都添加了 clear allclear classes。尽管我同意您的建议很有道理,但我仍然只看到了非常微小的性能提升。 - mikkola
@mikkola,很抱歉我的建议没有带来显著的结果。当实现不透明时,我们只能尝试几个可能有意义的理论(这就是我所谓的巫术编程)。最终,如果一个人不断地攻击系统,它可能会找到自己的甜点,或者如果不能,就等待下一个版本的发布,以改善性能。 - user2271770

4
您可以尝试分配与实际填充元素数量成比例的元素数量:这是c++标准库std::vector的标准实现方式。
obj.Node = [obj.Node; data; cell(q * obj.num_nodes,1)];

我记不清楚了,但在MSCC中,q为1,而在GCC中为0.75。


这是一个使用Java的解决方案。我不是很喜欢它,但它能够完成任务。我按照你从维基百科上提取的例子实现了它。

import javax.swing.tree.DefaultMutableTreeNode

% Let's create our example tree
top = DefaultMutableTreeNode([11,21])
n1 = DefaultMutableTreeNode([7,10])
top.add(n1)
n2 = DefaultMutableTreeNode([2,4])
n1.add(n2)
n2 = DefaultMutableTreeNode([5,6])
n1.add(n2)
n3 = DefaultMutableTreeNode([2,3])
n2.add(n3)
n3 = DefaultMutableTreeNode([3,3])
n2.add(n3)
n1 = DefaultMutableTreeNode([4,8])
top.add(n1)
n2 = DefaultMutableTreeNode([1,2])
n1.add(n2)
n2 = DefaultMutableTreeNode([2,3])
n1.add(n2)
n2 = DefaultMutableTreeNode([2,3])
n1.add(n2)
n1 = DefaultMutableTreeNode([0,3])
top.add(n1)

% Element to look for, your implementation will be recursive
searching = [0 1 1];
idx = 1;
node(idx) = top;
for item = searching,
    % Java transposes the matrices, remember to transpose back when you are reading
    node(idx).getUserObject()'
    node(idx+1) = node(idx).getChildAt(item);
    idx = idx + 1;
end
node(idx).getUserObject()'

% We made a new test...
newdata = [0, 1]
newnode = DefaultMutableTreeNode(newdata)
% ...so we expand our tree at the last node we searched
node(idx).add(newnode)

% The change has to be propagated (this is where your recursion returns)
for it=length(node):-1:1,
    itnode=node(it);
    val = itnode.getUserObject()'
    newitemdata = val + newdata
    itnode.setUserObject(newitemdata)
end

% Let's see if the new values are correct
searching = [0 1 1 0];
idx = 1;
node(idx) = top;
for item = searching,
    node(idx).getUserObject()'
    node(idx+1) = node(idx).getChildAt(item);
    idx = idx + 1;
end
node(idx).getUserObject()'

这将是我的建议。每当向量用完空间时,它会预留2倍的空间。计算机喜欢2的幂 :) - Nick
感谢您的回答。澄清一下:searching = [0 1 1]; 的意思是我们想通过以下策略找到一个节点:根节点的第一个子节点,该节点的第二个子节点,最后到达该节点的第二个子节点 - 对吗?快速测试似乎表明 UserObject 可以是任何 Matlab 支持的类型,希望这也是正确的。 - mikkola
searching 数组的解释是正确的(Java 中的数组是基于 0 的)。在 Java 数组中,您可以存储的对象有限制(例如句柄),您必须尝试使用有效的对象。 - NicolaSysnet

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接