KDTree Implementation in Java

55
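I'm looking for a KDTree implementation in Java.
I've done a Google search and the results seem pretty haphazard. There are plenty of results, but most are small one-off implementations, and I'd rather find something with a bit more "production value". Something like Apache Collections or the excellent C5 collection library for .NET, where I can see the public bug tracker and check when the last SVN commit happened. Ideally I'd find a nice, well-designed API for spatial data structures, with KDTree as just one class in that library.
For this project I'll only be working in 2 or 3 dimensions, and I'm mostly interested in a good nearest-neighbor implementation.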

12
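Looks like it's your turn to write something and share it. - Peter Wone
2
Your first link is dead and the second one takes you to http://code.openhub.net/ ... please update or remove these links. - grepit
11 Answers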

25
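The book Algorithms in a Nutshell has a kd-tree implementation in Java, along with a few variations. All of the code is on oreilly.com, and the book itself walks you through the algorithm so you could build one yourself.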

2
Specifically: http://examples.oreilly.com/9780596516246/Releases/ADK-1.0.zip, under ADK-1.0\ADK\Deployment\JavaCode\src\algs\model\kdtree - John Kurlak
3
It is also available on GitHub: https://github.com/heineman/algorithms-nutshell-2ed/tree/master/JavaCode/src/algs/model/kdtree - Jim Andreas

17

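The nice thing about this library (compared with, for example, the Algorithms in a Nutshell implementation) is that its API uses native double arrays for keys and ranges rather than custom objects. - Oliver Coleman
1
The KDTree implementation in java-ml is exactly Prof. Levy's, but a considerably outdated version. - Dmitry Avtonomov
2
Too bad it isn't published to a Maven repository. - Display Name
1
Actually, setting the groupId to net.sf and the artifactId to javaml worked for me. See https://dev59.com/_YHba4cB1Zd3GeqPXOKF. - besil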

12
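I've had success with Professor Levy's implementation. I realize you're looking for something more production-certified, so it may not be a good fit.
Still, as a note to any passers-by: I've been using it in my mosaic project for a while now without any issues. No guarantees, but better than nothing :)

I've had a lot of success with this too. +1 (Note: it is LGPL-licensed.) - Tom
Awesome!! Exactly what I needed: easy to integrate, and apparently it works out of the box. I haven't tested performance yet, but I'm very happy! - rupps
I've been using Professor Levy's implementation for the past two years. It has some odd bugs, but then I am also working with more than 10 million data points. One of the bugs is that the nearest method returns the wrong index. - Wisienkas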

5

3

2

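Here is a complete KD-Tree implementation. I used a library (edu.princeton.cs.algs4) to store points and rectangles; it is freely available. You can also use this code by writing your own classes to store points and rectangles. Feedback is welcome.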

import java.util.ArrayList;
import java.util.List;
import edu.princeton.cs.algs4.In;
import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.RectHV;
import edu.princeton.cs.algs4.StdDraw;
public class KdTree {
    private static class Node {
        public Point2D point; // the point
        public RectHV rect; // the axis-aligned rectangle corresponding to this node
        public Node lb; // the left/bottom subtree
        public Node rt; // the right/top subtree
        public int size;
        public double x = 0;
        public double y = 0;
        public Node(Point2D p, RectHV rect, Node lb, Node rt) {
            super();
            this.point = p;
            this.rect = rect;
            this.lb = lb;
            this.rt = rt;
            x = p.x();
            y = p.y();
        }

    }
    private Node root = null;

    public KdTree() {
    }

    public boolean isEmpty() {
        return root == null;
    }

    public int size() {
        return rechnenSize(root);
    }

    private int rechnenSize(Node node) {
        if (node == null) {
            return 0;
        } else {
            return node.size;
        }
    }

    public void insert(Point2D p) {
        if (p == null) {
            throw new NullPointerException();
        }
        if (isEmpty()) {
            root = insertInternal(p, root, 0);
            root.rect = new RectHV(0, 0, 1, 1);
        } else {
            root = insertInternal(p, root, 1);
        }
    }

    // at odd level we will compare x coordinate, and at even level we will
    // compare y coordinate
    private Node insertInternal(Point2D pointToInsert, Node node, int level) {
        if (node == null) {
            Node newNode = new Node(pointToInsert, null, null, null);
            newNode.size = 1;
            return newNode;
        }
        if (level % 2 == 0) {//Horizontal partition line
            if (pointToInsert.y() < node.y) {//Traverse in bottom area of partition
                node.lb = insertInternal(pointToInsert, node.lb, level + 1);
                if(node.lb.rect == null){
                    node.lb.rect = new RectHV(node.rect.xmin(), node.rect.ymin(),
                            node.rect.xmax(), node.y);
                }
            } else {//Traverse in top area of partition
                if (!node.point.equals(pointToInsert)) {
                    node.rt = insertInternal(pointToInsert, node.rt, level + 1);
                    if(node.rt.rect == null){
                        node.rt.rect = new RectHV(node.rect.xmin(), node.y,
                                node.rect.xmax(), node.rect.ymax());
                    }
                }
            }

        } else {//Vertical partition line
            if (pointToInsert.x() < node.x) {//Traverse in left area of partition
                node.lb = insertInternal(pointToInsert, node.lb, level + 1);
                if(node.lb.rect == null){
                    node.lb.rect = new RectHV(node.rect.xmin(), node.rect.ymin(),
                            node.x, node.rect.ymax());
                }
            } else {//Traverse in right area of partition
                if (!node.point.equals(pointToInsert)) {
                    node.rt = insertInternal(pointToInsert, node.rt, level + 1);
                    if(node.rt.rect == null){
                        node.rt.rect = new RectHV(node.x, node.rect.ymin(),
                                node.rect.xmax(), node.rect.ymax());
                    }
                }
            }
        }
        node.size = 1 + rechnenSize(node.lb) + rechnenSize(node.rt);
        return node;
    }

    public boolean contains(Point2D p) {
        return containsInternal(p, root, 1);
    }

    private boolean containsInternal(Point2D pointToSearch, Node node, int level) {
        if (node == null) {
            return false;
        }
        if (level % 2 == 0) {//Horizontal partition line
            if (pointToSearch.y() < node.y) {
                return containsInternal(pointToSearch, node.lb, level + 1);
            } else {
                if (node.point.equals(pointToSearch)) {
                    return true;
                }
                return containsInternal(pointToSearch, node.rt, level + 1);
            }
        } else {//Vertical partition line
            if (pointToSearch.x() < node.x) {
                return containsInternal(pointToSearch, node.lb, level + 1);
            } else {
                if (node.point.equals(pointToSearch)) {
                    return true;
                }
                return containsInternal(pointToSearch, node.rt, level + 1);
            }
        }

    }

    public void draw() {
        StdDraw.clear();
        drawInternal(root, 1);
    }

    private void drawInternal(Node node, int level) {
        if (node == null) {
            return;
        }
        StdDraw.setPenColor(StdDraw.BLACK);
        StdDraw.setPenRadius(0.02);
        node.point.draw();
        double sx = node.rect.xmin();
        double ex = node.rect.xmax();
        double sy = node.rect.ymin();
        double ey = node.rect.ymax();
        StdDraw.setPenRadius(0.01);
        if (level % 2 == 0) {
            StdDraw.setPenColor(StdDraw.BLUE);
            sy = ey = node.y;
        } else {
            StdDraw.setPenColor(StdDraw.RED);
            sx = ex = node.x;
        }
        StdDraw.line(sx, sy, ex, ey);
        drawInternal(node.lb, level + 1);
        drawInternal(node.rt, level + 1);
    }

    /**
     * Finds all points that lie inside the given rectangle.
     * @param rect the query rectangle
     * @return the points contained in rect
     */
    public Iterable<Point2D> range(RectHV rect) {
        List<Point2D> resultList = new ArrayList<Point2D>();
        rangeInternal(root, rect, resultList);
        return resultList;
    }

    private void rangeInternal(Node node, RectHV rect, List<Point2D> resultList) {
        if (node == null) {
            return;
        }
        if (node.rect.intersects(rect)) {
            if (rect.contains(node.point)) {
                resultList.add(node.point);
            }
            rangeInternal(node.lb, rect, resultList);
            rangeInternal(node.rt, rect, resultList);
        }

    }

    public Point2D nearest(Point2D p) {
        if(root == null){
            return null;
        }
        Champion champion = new Champion(root.point,Double.MAX_VALUE);
        return nearestInternal(p, root, champion, 1).champion;
    }

    private Champion nearestInternal(Point2D targetPoint, Node node,
            Champion champion, int level) {
        if (node == null) {
            return champion;
        }
        double dist = targetPoint.distanceSquaredTo(node.point);
        int newLevel = level + 1;
        if (dist < champion.championDist) {
            champion.champion = node.point;
            champion.championDist = dist;
        }
        boolean goLeftOrBottom = false;
        //We will decide which part to be visited first, based upon in which part point lies.
        //If point is towards left or bottom part, we traverse in that area first, and later on decide
        //if we need to search in other part too.
        if(level % 2 == 0){
            if(targetPoint.y() < node.y){
                goLeftOrBottom = true;
            }
        } else {
            if(targetPoint.x() < node.x){
                goLeftOrBottom = true;
            }
        }
        if(goLeftOrBottom){
            nearestInternal(targetPoint, node.lb, champion, newLevel);
            Point2D orientationPoint = createOrientationPoint(node.x,node.y,targetPoint,level);
            double orientationDist = orientationPoint.distanceSquaredTo(targetPoint);
            //Search the other side only if the partition line is closer to the target point
            //than the best point (champion) found so far.
            if(orientationDist < champion.championDist){
                nearestInternal(targetPoint, node.rt, champion, newLevel);
            }
        } else {
            nearestInternal(targetPoint, node.rt, champion, newLevel);
            Point2D orientationPoint = createOrientationPoint(node.x,node.y,targetPoint,level);
            //Search the other side only if the partition line is closer to the target point
            //than the best point (champion) found so far.
            double orientationDist = orientationPoint.distanceSquaredTo(targetPoint);
            if(orientationDist < champion.championDist){
                nearestInternal(targetPoint, node.lb, champion, newLevel);
            }

        }
        return champion;
    }
    /**
     * Returns the point from a partitioned line, which can be directly used to calculate
     * distance between partitioned line and the target point for which neighbours are to be searched.
     * @param linePointX
     * @param linePointY
     * @param targetPoint
     * @param level
     * @return
     */
    private Point2D createOrientationPoint(double linePointX, double linePointY, Point2D targetPoint, int level){
        if(level % 2 == 0){
            return new Point2D(targetPoint.x(),linePointY);
        } else {
            return new Point2D(linePointX,targetPoint.y());
        }
    }

    private static class Champion{
        public Point2D champion;
        public double championDist;
        public Champion(Point2D c, double d){
            champion = c;
            championDist = d;
        }
    }

    public static void main(String[] args) {
        String filename = "/home/raman/Downloads/kdtree/circle100.txt";
        In in = new In(filename);
        KdTree kdTree = new KdTree();
        while (!in.isEmpty()) {
            double x = in.readDouble();
            double y = in.readDouble();
            Point2D p = new Point2D(x, y);
            kdTree.insert(p);
        }
        // kdTree.print();
        System.out.println(kdTree.size());
        kdTree.draw();
        System.out.println(kdTree.nearest(new Point2D(0.4, 0.5)));
        System.out.println(new Point2D(0.7, 0.4).distanceSquaredTo(new Point2D(0.9,0.5)));
        System.out.println(new Point2D(0.7, 0.4).distanceSquaredTo(new Point2D(0.9,0.4)));

    }
}
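
One way to sanity-check a nearest() like the one above is to compare its answers against a plain linear scan on the same points. The following is only a sketch of that idea: the class BruteForceNearest and its method names are hypothetical, and it works on double[] points rather than Point2D so that it has no external dependencies.

```java
import java.util.Arrays;
import java.util.List;

/** Linear-scan nearest neighbour, usable as a reference when testing a kd-tree. */
final class BruteForceNearest {

    /** Squared Euclidean distance between two points of the same dimension. */
    static double distanceSquared(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    /** Returns the index of the point closest to query, or -1 if the list is empty. */
    static int nearestIndex(List<double[]> points, double[] query) {
        int best = -1;
        double bestDist = Double.POSITIVE_INFINITY;
        for (int i = 0; i < points.size(); i++) {
            double d = distanceSquared(points.get(i), query);
            if (d < bestDist) { // strict comparison: the first of several ties wins
                bestDist = d;
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        List<double[]> points = Arrays.asList(
                new double[]{0.7, 0.2},
                new double[]{0.5, 0.4},
                new double[]{0.2, 0.3},
                new double[]{0.4, 0.7},
                new double[]{0.9, 0.6});
        int idx = nearestIndex(points, new double[]{0.4, 0.5});
        System.out.println("nearest index = " + idx); // prints "nearest index = 1"
    }
}
```

The scan is O(n) per query, so it is only meant for tests: run random queries through both implementations and flag any query where the tree's answer is farther away than the scan's.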

1
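You're right, there aren't many sites with kd-tree implementations for Java! In any case, a kd-tree is basically a binary search tree in which, at each level, the median along that level's dimension is typically used as the split. Here is a simple KDNode; for the nearest-neighbor method or a full implementation, take a look at this GitHub project. It was the best one I could find for you. Hope this helps.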
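private class KDNode {
    KDNode left;
    KDNode right;
    E val;      // stored element; E is assumed to be the enclosing tree's type parameter
    int depth;  // depth in the tree, which determines the splitting dimension

    private KDNode(E e, int depth) {
        this.left = null;
        this.right = null;
        this.val = e;
        this.depth = depth;
    }
}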
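
To make the "median of that dimension" idea concrete, here is a minimal sketch of building a balanced tree by median split, with a nearest-neighbor search on top. It is not the code from the linked project: the class MedianKdTree, its methods, and the use of double[] points are assumptions made for illustration.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Sketch of a balanced kd-tree built by splitting on the median of each axis. */
final class MedianKdTree {

    static final class Node {
        final double[] point;
        final Node left, right;
        Node(double[] point, Node left, Node right) {
            this.point = point;
            this.left = left;
            this.right = right;
        }
    }

    /** Builds a subtree from points[from, to); note that it reorders the input array. */
    static Node build(double[][] points, int from, int to, int depth) {
        if (from >= to) return null;
        int axis = depth % points[from].length; // cycle through the dimensions
        // After sorting the slice on this axis, the median sits at the middle index.
        Arrays.sort(points, from, to, Comparator.comparingDouble((double[] p) -> p[axis]));
        int mid = from + (to - from) / 2;
        return new Node(points[mid],
                build(points, from, mid, depth + 1),
                build(points, mid + 1, to, depth + 1));
    }

    static double dist2(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) sum += (a[i] - b[i]) * (a[i] - b[i]);
        return sum;
    }

    /** Returns the stored point closest to query; pass null as the initial best. */
    static double[] nearest(Node node, double[] query, int depth, double[] best) {
        if (node == null) return best;
        if (best == null || dist2(node.point, query) < dist2(best, query)) best = node.point;
        int axis = depth % query.length;
        double diff = query[axis] - node.point[axis];
        Node near = diff < 0 ? node.left : node.right;
        Node far = diff < 0 ? node.right : node.left;
        best = nearest(near, query, depth + 1, best);
        // The far side can only hold a closer point if the splitting plane
        // is closer to the query than the best point found so far.
        if (diff * diff < dist2(best, query)) best = nearest(far, query, depth + 1, best);
        return best;
    }

    static int height(Node n) {
        return n == null ? 0 : 1 + Math.max(height(n.left), height(n.right));
    }

    public static void main(String[] args) {
        double[][] pts = { {0.7, 0.2}, {0.5, 0.4}, {0.2, 0.3}, {0.4, 0.7}, {0.9, 0.6} };
        Node root = build(pts, 0, pts.length, 0);
        System.out.println(height(root)); // prints 3
        System.out.println(Arrays.toString(nearest(root, new double[]{0.972, 0.887}, 0, null))); // prints [0.9, 0.6]
    }
}
```

Sorting at every level keeps the sketch short at the cost of O(n log² n) build time; a selection algorithm for the median would bring that down, but the resulting tree has the same shape.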

2
The link is broken. - PlsWork

1

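There is also the JTS Topology Suite.

Its KdTree implementation only provides range search (no nearest neighbor).

If nearest neighbor is what you need, look at STRtree.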


0
Maybe this will be of interest to someone. Here is my Java implementation of the nearest() function (along with the rest of a KdTree class) for a 2D tree:
import edu.princeton.cs.algs4.Point2D;
import edu.princeton.cs.algs4.RectHV;
import edu.princeton.cs.algs4.StdDraw;

import java.util.ArrayList;
import java.util.List;

public class KdTree {
    private Node root;
    private int size;

    private static class Node {
        private Point2D p;      // the point
        private RectHV rect;    // the axis-aligned rectangle corresponding to this node
        private Node lb;        // the left/bottom subtree
        private Node rt;        // the right/top subtree
        public Node(Point2D p, RectHV rect) {
            this.p = p;
            this.rect = rect;
        }
    }

    public KdTree() {
    }

    public boolean isEmpty() {
        return size == 0;
    }

    public int size() {
        return size;
    }

    public boolean contains(Point2D p) {
        if (p == null) throw new IllegalArgumentException("argument to contains() is null");
        return contains(root, p, 1);
    }

    private boolean contains(Node node, Point2D p, int level) {
        if (node == null) return false; // a base case for recursive call

        if (node.p.equals(p)) return true;

        if (level % 2 == 0) { // search by y coordinate (node with horizontal partition line)
            if (p.y() < node.p.y())
                return contains(node.lb, p, level + 1);
            else
                return contains(node.rt, p, level + 1);
        }
        else { // search by x coordinate (node with vertical partition line)
            if (p.x() < node.p.x())
                return contains(node.lb, p, level + 1);
            else
                return contains(node.rt, p, level + 1);
        }
    }

    public void insert(Point2D p) {
        if (p == null) throw new IllegalArgumentException("calls insert() with a null point");
        root = insert(root, p, 1);
    }

    private Node insert(Node x, Point2D p, int level) {
        if (x == null) {
            size++;
            return new Node(p, new RectHV(0, 0, 1, 1));
        }

        if (x.p.equals(p)) return x; // the point already exists, so return its node unchanged

        if (level % 2 == 0) { // search by y coordinate (node with horizontal partition line)
            if (p.y() < x.p.y()) {
                x.lb = insert(x.lb, p, level + 1);
                if (x.lb.rect.equals(root.rect))
                    x.lb.rect = new RectHV(x.rect.xmin(), x.rect.ymin(), x.rect.xmax(), x.p.y());
            }
            else {
                x.rt = insert(x.rt, p, level + 1);
                if (x.rt.rect.equals(root.rect))
                    x.rt.rect = new RectHV(x.rect.xmin(), x.p.y(), x.rect.xmax(), x.rect.ymax());
            }
        }
        else { // search by x coordinate (node with vertical partition line)
            if (p.x() < x.p.x()) {
                x.lb = insert(x.lb, p, level + 1);
                if (x.lb.rect.equals(root.rect))
                    x.lb.rect = new RectHV(x.rect.xmin(), x.rect.ymin(), x.p.x(), x.rect.ymax());
            }
            else {
                x.rt = insert(x.rt, p, level + 1);
                if (x.rt.rect.equals(root.rect))
                    x.rt.rect = new RectHV(x.p.x(), x.rect.ymin(), x.rect.xmax(), x.rect.ymax());
            }
        }
        return x;
    }

    public void draw() {
        draw(root, 1);
    }

    private void draw(Node node, int level) {
        if (node == null) return;

        StdDraw.setPenColor(StdDraw.BLACK);
        StdDraw.setPenRadius(0.01);
        node.p.draw();
        StdDraw.setPenRadius();

        if (level % 2 == 0) {
            StdDraw.setPenColor(StdDraw.BLUE);
            StdDraw.line(node.rect.xmin(), node.p.y(), node.rect.xmax(), node.p.y());
        }
        else {
            StdDraw.setPenColor(StdDraw.RED);
            StdDraw.line(node.p.x(), node.rect.ymin(), node.p.x(), node.rect.ymax());
        }

        draw(node.lb, level + 1);
        draw(node.rt, level + 1);
    }

    public Iterable<Point2D> range(RectHV rect) {
        if (rect == null) throw new IllegalArgumentException("calls range() with a null rect");
        List<Point2D> points = new ArrayList<>(); // create an Iterable object with all points we found
        range(root, rect, points); // call helper method with rects intersects comparing
        
        return points; // return an Iterable object (It could be any type - Queue, LinkedList etc)
    }

    private void range(Node node, RectHV rect, List<Point2D> points) {
        if (node == null || !node.rect.intersects(rect)) return; // a base case for recursive call


        if (rect.contains(node.p))
            points.add(node.p);
        range(node.lb, rect, points);
        range(node.rt, rect, points);

    }    

    public Point2D nearest(Point2D query) {
        // validate the argument first, so a null query is rejected even on an empty tree
        if (query == null) throw new IllegalArgumentException("calls nearest() with a null point");
        if (isEmpty()) return null;
        // set the start distance from root to query point
        double best = root.p.distanceSquaredTo(query);
        // StdDraw.setPenColor(StdDraw.BLACK); // just for debugging
        // StdDraw.setPenRadius(0.01);
        // query.draw();
        return nearest(root, query, root.p, best, 1); // call a helper method
    }

    private Point2D nearest(Node node, Point2D query, Point2D champ, double best, int level) {
        // a base case for the recursive call
        if (node == null || best < node.rect.distanceSquaredTo(query)) return champ;
        // we'll need to set an actual best distance when we recur
        best = champ.distanceSquaredTo(query);
        // check whether a distance from query point to the traversed node less than
        // distance from current champion to query point
        double temp = node.p.distanceSquaredTo(query);
        if (temp < best) {
            best = temp;
            champ = node.p;
        }

        if (level % 2 == 0) { // search by y coordinate (node with horizontal partition line)
            // we compare y coordinate and decide go up or down
            if (node.p.y() < query.y()) { // if true go up
                champ = nearest(node.rt, query, champ, best, level + 1);
                // important case - when we traverse node and go back up through the tree
                // we need to decide whether we need to go down(left) in this node or not
                // we just check our bottom (left) node on null && compare distance
                // from query point to the nearest point of the node's rectangle and
                // the distance from current champ point to the query point
                if (node.lb != null && node.lb.rect.distanceSquaredTo(query) < champ.distanceSquaredTo(query)) {
                    champ = nearest(node.lb, query, champ, best, level + 1);
                }

            }
            else { // if false go down
                champ = nearest(node.lb, query, champ, best, level + 1);
                if (node.rt != null && node.rt.rect.distanceSquaredTo(query) < champ.distanceSquaredTo(query))
                    // when we traverse node and go back up through the tree
                    // we need to decide whether we need to go up(right) in this node or not
                    // we just check our top (right) node on null && compare distance
                    // from query point to the nearest point of the node's rectangle and
                    // the distance from current champ point to the query point
                    champ = nearest(node.rt, query, champ, best, level + 1);

            }

        }
        else {
            // search by x coordinate (node with vertical partition line)
            if (node.p.x() < query.x()) { // if true go right
                champ = nearest(node.rt, query, champ, best, level + 1);
                // the same check as mentioned above when we search by y coordinate
                if (node.lb != null && node.lb.rect.distanceSquaredTo(query) < champ.distanceSquaredTo(query))
                    champ = nearest(node.lb, query, champ, best, level + 1);
            }
            else { // if false go left
                champ = nearest(node.lb, query, champ, best, level + 1);
                if (node.rt != null && node.rt.rect.distanceSquaredTo(query) < champ.distanceSquaredTo(query))
                    champ = nearest(node.rt, query, champ, best, level + 1);
            }
        }
        return champ;
    }



    public static void main(String[] args) {
        // unit tests
        KdTree kd = new KdTree();
        Point2D p1 = new Point2D(0.7, 0.2);
        Point2D p2 = new Point2D(0.5, 0.4);
        Point2D p3 = new Point2D(0.2, 0.3);
        Point2D p4 = new Point2D(0.4, 0.7);
        Point2D p5 = new Point2D(0.9, 0.6);
        // Point2D query = new Point2D(0.676, 0.736);
        Point2D query1 = new Point2D(0.972, 0.887);
        // RectHV test = new RectHV(0, 0, 0.7, 0.4);
        // Point2D query = new Point2D(0.331, 0.762);

        // Point2D p6 = new Point2D(0.4, 0.4);
        // Point2D p7 = new Point2D(0.1, 0.6);
        // RectHV rect = new RectHV(0.05, 0.1, 0.15, 0.6);

        kd.insert(p1);
        kd.insert(p2);
        kd.insert(p3);
        kd.insert(p4);
        kd.insert(p5);
        System.out.println(kd.nearest(query1));
        // System.out.println("Dist query to 0.4,0.7= " + query.distanceSquaredTo(p4));
        // System.out.println("Dist query to RectHV 0.2,0,3= " + test.distanceSquaredTo(p4));
        // kd.insert(p6);
        // kd.insert(p7);
        // System.out.println(kd.size);
        // System.out.println(kd.contains(p3));
        // // System.out.println(kd.range(rect));

        kd.draw();
        

    }
}
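
The pruning in nearest() above depends on the distance from the query point to a node's rectangle: a subtree is skipped when even the closest point of its rectangle is farther away than the current champion. As a rough sketch of what such a point-to-rectangle distance computes (a standalone illustration with a hypothetical class name, not the library's source):

```java
/** Squared distance from a point to an axis-aligned rectangle, as used for pruning. */
final class RectDistance {

    /** Returns 0 if (px, py) lies inside or on the rectangle, else the squared distance to it. */
    static double distanceSquared(double xmin, double ymin, double xmax, double ymax,
                                  double px, double py) {
        // Per axis: how far the point lies outside the interval [min, max], or 0 if inside it.
        double dx = Math.max(Math.max(xmin - px, 0.0), px - xmax);
        double dy = Math.max(Math.max(ymin - py, 0.0), py - ymax);
        return dx * dx + dy * dy;
    }

    public static void main(String[] args) {
        // Unit square, as used for the root rectangle in the code above.
        System.out.println(distanceSquared(0, 0, 1, 1, 0.5, 0.5)); // prints 0.0 (inside)
        System.out.println(distanceSquared(0, 0, 1, 1, 2.0, 0.5)); // prints 1.0 (right of the square)
        System.out.println(distanceSquared(0, 0, 1, 1, 2.0, 3.0)); // prints 5.0 (beyond the corner)
    }
}
```

Working with squared distances throughout avoids a square root on every comparison, and the ordering of candidates stays the same.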

0

Many thanks, theosem!

Based on the library he posted (http://java-ml.sourceforge.net/), I put together this code example:

package kdtreeexample; //place your package name here
import net.sf.javaml.core.kdtree.KDTree; //import library
public class KDTreeExample {

    public static void main(String[] args) {
        KDTree kdTree = new KDTree(2); // 2 dimensions (x, y)
        // point insertion:
        kdTree.insert(new double[]{4, 3}, 0); // insert point (x=4, y=3), index = 0
        kdTree.insert(new double[]{1, 10}, 1); // insert point (x=1, y=10), index = 1
        kdTree.insert(new double[]{10, 10}, 2); // insert point (x=10, y=10), index = 2
        kdTree.insert(new double[]{5, 1}, 3); // insert point (x=5, y=1), index = 3
        // nearest index to the point at coordinates x, y:
        int x = 0; // x coordinate of the target point
        int y = 11; // y coordinate of the target point
        int nearestIndex = (int) kdTree.nearest(new double[]{x, y}); // nearest-neighbor lookup
        // result:
        System.out.println("Nearest point value index to point(" + x + ", " + y + ") = " + nearestIndex);
        System.out.println(kdTree.toString()); // check the data
    }
}

Content provided by Stack Overflow.