Java中递归算法的优化

10

背景

我有一组有序的数据点,存储在一个 TreeSet<DataPoint> 中。每个数据点都有一个 position 和一个 Event 对象的 SetHashSet<Event>)。

有 4 种可能的 Event 对象 ABCD。每个 DataPoint 都有其中的 2 个,例如 AC,除了集合中的第一个和最后一个 DataPoint 对象,它们只有大小为 1 的 T

我的算法是找到一个新的 DataPoint Q 在位置 x 上具有 Event q 在这个集合中的概率。

我通过计算此数据集的值 S,然后将 Q 添加到集合中并再次计算 S 来实现这一点。然后,我将第二个 S 除以第一个,以分离新的 DataPoint Q 的概率。

算法

计算 S 的公式如下:

http://mathbin.net/equations/105225_0.png

其中

http://mathbin.net/equations/105225_1.png

http://mathbin.net/equations/105225_2.png

对于 http://mathbin.net/equations/105225_3.png

以及

http://mathbin.net/equations/105225_4.png

http://mathbin.net/equations/105225_5.png是一个昂贵的概率函数,它只取决于其参数而不受其他影响(以及http://mathbin.net/equations/105225_6.png),http://mathbin.net/equations/105225_7.png是集合中的最后一个DataPoint(右侧节点),http://mathbin.net/equations/105225_8.png是第一个DataPoint(左侧节点),http://mathbin.net/equations/105225_9.png是不是节点的最右侧DataPointhttp://mathbin.net/equations/105225_10.png是一个DataPointhttp://mathbin.net/equations/105225_12.png是此DataPoint的事件Set

因此,具有事件qQ的概率为:

http://mathbin.net/equations/105225_11.png

实现

我用Java实现了这个算法:

public class ProbabilityCalculator {
    private Double p(DataPoint right, Event rightEvent, DataPoint left, Event leftEvent) {
        // do some stuff
    }
    
    private Double f(DataPoint right, Event rightEvent, NavigableSet<DataPoint> points) {
        DataPoint left = points.lower(right);
        
        Double result = 0.0;
        
        if(left.isLefthandNode()) {
            result = 0.25 * p(right, rightEvent, left, null);
        } else if(left.isQ()) {
            result = p(right, rightEvent, left, left.getQEvent()) * f(left, left.getQEvent(), points);
        } else { // if M_k
            for(Event leftEvent : left.getEvents())
                result += p(right, rightEvent, left, leftEvent) * f(left, leftEvent, points);
        }
        
        return result;
    }
    
    public Double S(NavigableSet<DataPoint> points) {
        return f(points.last(), points.last().getRightNodeEvent(), points)
    }
}

因此,要找到在x处概率为Qq

Double S1 = S(points);
points.add(Q);
Double S2 = S(points);
Double probability = S2/S1;

问题

目前的实现严格遵循数学算法。然而,在实践中,这并不是一个特别好的想法,因为对于每个 DataPointf 会调用自身两次。所以对于http://mathbin.net/equations/105225_9.pngf 被调用两次,然后对于前面每个调用,n-1 又会再次调用两次 f,依此类推。这导致复杂度为 O(2^n),考虑到每个 Set 中可能有超过 1000 个 DataPoints,这相当糟糕。由于 p() 除了它的参数之外与其他一切都无关,因此我包括了一个缓存函数,如果已经为这些参数计算了 p(),则它只返回先前的结果,但这并不能解决固有的复杂性问题。在重复计算方面,我是否漏掉了什么,或者这种算法的复杂性是不可避免的?


1
为什么不把 f 也缓存起来呢?只需将参数 points 从函数参数移动到类成员即可。 - Dialecticus
@Dialecticus 我认为如果我将points的子集存储到right左侧,这将使缓存在将Q添加到点之后仍然可以使用,一旦Q已经在过程中被传递。 - bountiful
是的,所以主函数将运行以下操作:清除P缓存,清除F缓存,获取S1,添加Q,清除F缓存,获取S2。 - Dialecticus
我认为不需要清除整个F缓存,只需清除Q右侧的部分即可。 - bountiful
3个回答

2

您还需要对前两个参数的f进行备忘录。第三个参数始终会被传递,所以您不需要担心它。这将将您的代码时间复杂度从O(2^n)降低到O(n),使其更加高效。


0

更新:

正如下面的评论所述,无法使用顺序来帮助优化,必须使用另一种方法。由于大多数P值将被多次计算(并且如注释所述,这是昂贵的),因此一种优化方法是将它们缓存起来。我不确定最佳键是什么,但您可以想象更改代码类似于:

....
private Map<String, Double> previousResultMap = new ....


private Double p(DataPoint right, Event rightEvent, DataPoint left, Event leftEvent) {
   String key = // calculate unique key from inputs
   Double previousResult = previousResultMap.get(key);
   if (previousResult != null) {
      return previousResult;
   } 

   // do some stuff
   previousResultMap.put(key, result);
   return result;
}

这种方法应该能够有效地减少很多冗余计算 - 但是,由于您比我更了解数据,您需要确定设置密钥的最佳方式(甚至是否使用字符串作为最佳表示形式)。


如果我理解正确的话,我认为这不会起作用。因为S不仅取决于点的数量。如果我将Q放在x和y之间的位置x上,那么没有Q的S将调用p(x, xEvent, y, yEvent),然而有Q的S将调用p(x, xEvent, q, qEvent),然后调用p(q, qEvent, y, yEvent)。但是我可以同时调用两个S,并且只有当其中一个到达Q时才转移。 - bountiful
你应该看看我下面的答案! - bountiful

0

感谢您的所有建议。我通过为已计算的PF值创建新的嵌套类来实现我的解决方案,然后使用HashMap来存储结果。在进行计算之前,查询HashMap是否存在结果;如果存在,则直接返回结果,如果不存在,则计算结果并将其添加到HashMap中。

最终的产品看起来有点像这样:

public class ProbabilityCalculator {

    private NavigableSet<DataPoint> points;

    private ProbabilityCalculator(NavigableSet<DataPoint> points) {
        this.points = points;
    }

    private static class P {
        public final DataPoint left;
        public final Event leftEvent;
        public final DataPoint right;
        public final Event rightEvent;

        public P(DataPoint left, Event leftEvent, DataPoint right, Event rightEvent) {
            this.left = left;
            this.leftEvent = leftEvent;
            this.right = right;
            this.rightEvent = rightEvent;
        }

        public boolean equals(Object o) {
            if(!(o instanceof P)) return false;
            P p = (P) o;

            if(!(this.leftEvent == null ? p.leftEvent == null : this.leftEvent.equals(p.leftEvent)))
                return false;
            if(!(this.rightEvent == null ? p.rightEvent == null : this.rightEvent.equals(p.rightEvent)))
                return false;

            return this.left.equals(p.left) && this.right.equals(p.right);
        }

        public int hashCode() {
            int result = 93;

            result = 31 * result + this.left.hashCode();
            result = 31 * result + this.right.hashCode();
            result = this.leftEvent != null ? 31 * result + this.leftEvent.hashCode() : 31 * result;
            result = this.rightEvent != null ? 31 * result + this.rightEvent.hashCode() : 31 * result;

            return result;
        }
    }

    private Map<P, Double> usedPs = new HashMap<P, Double>();

    private static class F {
        public final DataPoint left;
        public final Event leftEvent;
        public final NavigableSet<DataPoint> dataPointsToLeft;

        public F(DataPoint dataPoint, Event dataPointEvent, NavigableSet<DataPoint> dataPointsToLeft) {
            this.dataPoint = dataPoint;
            this.dataPointEvent = dataPointEvent;
            this.dataPointsToLeft = dataPointsToLeft;
        }

        public boolean equals(Object o) {
            if(!(o instanceof F)) return false;
            F f = (F) o;
            return this.dataPoint.equals(f.dataPoint) && this.dataPointEvent.equals(f.dataPointEvent) && this.dataPointsToLeft.equals(f.dataPointsToLeft);
        }

        public int hashCode() {
            int result = 7;

            result = 31 * result + this.dataPoint.hashCode();
            result = 31 * result + this.dataPointEvent.hashCode();
            result = 31 * result + this.dataPointsToLeft.hashCode();

            return result;
        }

    }

    private Map<F, Double> usedFs = new HashMap<F, Double>();

    private Double p(DataPoint right, Event rightEvent, DataPoint left, Event leftEvent) {
        P newP = new P(right, rightEvent, left, leftEvent);

        if(this.usedPs.containsKey(newP)) return usedPs.get(newP);


        // do some stuff

        usedPs.put(newP, result);
        return result;

    }

    private Double f(DataPoint right, Event rightEvent) {

        NavigableSet<DataPoint> dataPointsToLeft = dataPoints.headSet(right, false);

        F newF = new F(right, rightEvent, dataPointsToLeft);

        if(usedFs.containsKey(newF)) return usedFs.get(newF);

        DataPoint left = points.lower(right);

        Double result = 0.0;

        if(left.isLefthandNode()) {
            result = 0.25 * p(right, rightEvent, left, null);
        } else if(left.isQ()) {
            result = p(right, rightEvent, left, left.getQEvent()) * f(left, left.getQEvent(), points);
        } else { // if M_k
            for(Event leftEvent : left.getEvents())
                result += p(right, rightEvent, left, leftEvent) * f(left, leftEvent, points);
        }

        usedFs.put(newF, result)

        return result;
    }

    public Double S() {
        return f(points.last(), points.last().getRightNodeEvent(), points)
    }

    public static probabilityOfQ(DataPoint q, NavigableSet<DataPoint> points) {
        ProbabilityCalculator pc = new ProbabilityCalculator(points);

        Double S1 = S();

        points.add(q);

        Double S2 = S();

        return S2/S1;

    }
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接