趋势线(回归,曲线拟合)Java库

16

我正在开发一个应用程序,计算与Excel相同的趋势线,但针对更大的数据集。

enter image description here

但我找不到任何Java库来计算这些回归。 对于线性模型,我使用Apache Commons math,对于其他模型,Michael Thomas Flanagan有一个很好的数值库,但自一月以来它不再可用:

http://www.ee.ucl.ac.uk/~mflanaga/java/

您是否知道其他库或代码存储库可以在Java中计算这些回归。最好,


为什么不自己动手写呢?至少数学编程起来还算容易,对吧?换句话说,“你尝试过什么?” - Shark
3个回答

42

由于它们都基于线性拟合,因此OLSMultipleLinearRegression是您需要的所有线性、多项式、指数、对数和幂趋势线。

您的问题给了我一个借口下载并使用commons math回归工具,并且我编写了一些趋势线工具:

一个接口:

public interface TrendLine {
    public void setValues(double[] y, double[] x); // y ~ f(x)
    public double predict(double x); // get a predicted y for a given x
}

一个用于回归趋势线的抽象类:

public abstract class OLSTrendLine implements TrendLine {

    RealMatrix coef = null; // will hold prediction coefs once we get values

    protected abstract double[] xVector(double x); // create vector of values from x
    protected abstract boolean logY(); // set true to predict log of y (note: y must be positive)

    @Override
    public void setValues(double[] y, double[] x) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(String.format("The numbers of y and x values must be equal (%d != %d)",y.length,x.length));
        }
        double[][] xData = new double[x.length][]; 
        for (int i = 0; i < x.length; i++) {
            // the implementation determines how to produce a vector of predictors from a single x
            xData[i] = xVector(x[i]);
        }
        if(logY()) { // in some models we are predicting ln y, so we replace each y with ln y
            y = Arrays.copyOf(y, y.length); // user might not be finished with the array we were given
            for (int i = 0; i < x.length; i++) {
                y[i] = Math.log(y[i]);
            }
        }
        OLSMultipleLinearRegression ols = new OLSMultipleLinearRegression();
        ols.setNoIntercept(true); // let the implementation include a constant in xVector if desired
        ols.newSampleData(y, xData); // provide the data to the model
        coef = MatrixUtils.createColumnRealMatrix(ols.estimateRegressionParameters()); // get our coefs
    }

    @Override
    public double predict(double x) {
        double yhat = coef.preMultiply(xVector(x))[0]; // apply coefs to xVector
        if (logY()) yhat = (Math.exp(yhat)); // if we predicted ln y, we still need to get y
        return yhat;
    }
}

多项式或线性模型的实现:

(对于线性模型,只需在调用构造函数时将度数设置为1即可。)

public class PolyTrendLine extends OLSTrendLine {
    final int degree;
    public PolyTrendLine(int degree) {
        if (degree < 0) throw new IllegalArgumentException("The degree of the polynomial must not be negative");
        this.degree = degree;
    }
    protected double[] xVector(double x) { // {1, x, x*x, x*x*x, ...}
        double[] poly = new double[degree+1];
        double xi=1;
        for(int i=0; i<=degree; i++) {
            poly[i]=xi;
            xi*=x;
        }
        return poly;
    }
    @Override
    protected boolean logY() {return false;}
}

指数和幂模型更容易:

注意:我们现在预测的是对数 y,这很重要。这两种模型都只适用于正 y。

public class ExpTrendLine extends OLSTrendLine {
    @Override
    protected double[] xVector(double x) {
        return new double[]{1,x};
    }

    @Override
    protected boolean logY() {return true;}
}

并且

public class PowerTrendLine extends OLSTrendLine {
    @Override
    protected double[] xVector(double x) {
        return new double[]{1,Math.log(x)};
    }

    @Override
    protected boolean logY() {return true;}

}

同时还有一个对数模型:

(该模型将x取对数,但预测的不是ln y,而是y)

public class LogTrendLine extends OLSTrendLine {
    @Override
    protected double[] xVector(double x) {
        return new double[]{1,Math.log(x)};
    }

    @Override
    protected boolean logY() {return false;}
}

你可以这样使用:

public static void main(String[] args) {
    TrendLine t = new PolyTrendLine(2);
    Random rand = new Random();
    double[] x = new double[1000*1000];
    double[] err = new double[x.length];
    double[] y = new double[x.length];
    for (int i=0; i<x.length; i++) { x[i] = 1000*rand.nextDouble(); }
    for (int i=0; i<x.length; i++) { err[i] = 100*rand.nextGaussian(); } 
    for (int i=0; i<x.length; i++) { y[i] = x[i]*x[i]+err[i]; } // quadratic model
    t.setValues(y,x);
    System.out.println(t.predict(12)); // when x=12, y should be... , eg 143.61380202745192
}

既然你只需要趋势线,当我完成ols模型后就将其排除了,但你可能需要保留一些关于拟合度等数据。

对于使用移动平均,移动中位数等实现,看起来你可以继续使用commons math。尝试使用DescriptiveStatistics并指定一个窗口。你可能需要进行一些平滑处理,使用另一个答案中建议的插值方法。


我已经尝试了你的代码,它很好,但我想知道如何在获取Y数组后绘制趋势线。另外,它只适用于y=x2+const方程吗?还是我可以将其更改为我的方程y=4x2+3x+6.4? - Nitish Patel
是的,使用PolyTrendLine(2)创建的TrendLine将会估计方程y=b0+b1*x+b2*x^2的系数。仔细观察xVector,你会发现它使用 x^0, x^1, ... x^k 作为 k 次回归的估算器。 - maybeWeCouldStealAVan
通过简单地调用main()函数就可以在图表上画线吗?我已经在我的onCreate()函数中尝试了一段代码,但它没有在我的UI上画出任何东西。请查看我的问题http://stackoverflow.com/questions/22808204/how-to-draw-trade-line-on-scatter-chart-in-android。 - Nitish Patel
也许我们可以偷一辆货车。如何获取r的绝对值或平方?| r | 或 r ^ 2。 - Dickey Singh
@maybeWeCouldStealAVan 感谢您提供这个好答案。Apache Commons支持多元回归吗?我正在处理的问题需要有2个因变量,所以通过这个库是否可能实现? - Dania
显示剩余2条评论

4
除了maybeWeCouldStealAVa提到的内容;
commons-math3库也可以在maven repository中找到。
当前版本为3.2,依赖标签为:
    <dependency>
        <groupId>org.apache.commons</groupId>
        <artifactId>commons-math3</artifactId>
        <version>3.2</version>
    </dependency>

4

您可以使用org.apache.commons.math3.analysis.interpolation中提供的不同类型的插值器,包括线性插值器、Loess插值器和Neville插值器等。


FYI。更新链接。Math4。https://commons.apache.org/proper/commons-math/apidocs/org/apache/commons/math4/analysis/interpolation/package-tree.html - Wpigott

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接