Julia中用于最小二乘问题的坐标下降算法不收敛

3
作为自己编写弹性网解算器的热身,我正在尝试使用坐标下降实现足够快的普通最小二乘法版本。
我相信我已经正确实现了坐标下降算法,但是当我使用“快速”版本(见下文)时,该算法变得非常不稳定,当特征数与样本数相比较适中时,常常输出回归系数会溢出64位浮点数。
线性回归和OLS
如果b = A * x,其中A是一个矩阵,x是未知回归系数的向量,y是输出,我想找到最小化||b-Ax||^2的x。
如果A[j]是A的第j列而A[-j]是没有第j列的A,并且A的列被规范化以使||A[j]||^2= 1对于所有j都成立,则坐标更新为:
x[j]  <--  A[j]^T * (b - A[-j] * x[-j])

我正在跟随这些笔记(第9-10页)进行学习,但推导很简单。

指出可以用更快的方法来计算A[j]^T(b - A[-j] * x[-j]),而不是一直重新计算。

快速坐标下降:

x[j]  <--  A[j]^T*r + x[j]

总剩余量r = b - Ax在坐标循环之外计算。这些更新规则的等价性是由于观察到Ax = A[j]*x[j] + A[-j]*x[-j]并重新排列项得出的。

我的问题是,虽然第二种方法确实更快,但当特征数与样本数相比较大时,对我来说非常不稳定。我想知道为什么会这样。我应该注意到,随着特征数接近样本数,第一种更稳定的方法仍然与更标准的方法不同。

Julia代码

以下是两个更新规则的一些Julia代码:

function OLS_builtin(A,b)
    x = A\b
    return(x)
end

function OLS_coord_descent(A,b)    
    N,P = size(A)
    x = zeros(P)
    for cycle in 1:1000
        for j = 1:P 
            x[j] = dot(A[:,j], b - A[:,1:P .!= j]*x[1:P .!= j])
        end    
    end
    return(x)
end

function OLS_coord_descent_fast(A,b) 
    N,P = size(A)
    x = zeros(P)
    for cycle in 1:1000
        r = b - A*x
        for j = 1:P
            x[j] += dot(A[:,j],r)
        end    
    end
    return(x)
end

问题示例

我使用以下命令生成数据:

n = 100
p = 50
σ = 0.1
β_nz = float([i*(-1)^i for i in 1:10])

β = append!(β_nz,zeros(Float64,p-length(β_nz)))
X = randn(n,p); X .-= mean(X,1); X ./= sqrt(sum(abs2(X),1))
y = X*β + σ*randn(n); y .-= mean(y);

这里我使用p=50,OLS_coord_descent(X,y)OLS_builtin(X,y)之间达成了良好的一致性,而OLS_coord_descent_fast(X,y)返回的回归系数呈指数级增长。
当p小于约20时,OLS_coord_descent_fast(X,y)与其他两种方法相符。
猜想
因为在p<<n的范围内结果是相符的,我认为该算法在形式上是正确的,但在数值上是不稳定的。是否有人对我这个猜想有什么看法?如果是这样,如何纠正不稳定性同时保留(大部分)快速版本算法的性能提升?

这是 http://stats.stackexchange.com/questions/251920/coordinate-descent-in-ordinary-least-squares-not-converging 的转载。请不要这样交叉发布。请注意,StackOverflow 是用于编程问题的,而这似乎更像是算法问题。我认为这实际上最适合 Computational Science SO,但认为应该迁移而不是第三次交叉发布。 - Chris Rackauckas
是的,这似乎是算法、统计学以及可能与Julia语言特定问题的奇怪交叉点,所以我不确定应该放在哪里。如果您想将其迁移到更合适的地方,请随意操作。 - Rory
1个回答

5
简短回答:您忘记在每次 x[j] 更新后更新 r。下面是修正后的函数,与 OLS_coord_descent 表现相同:
function OLS_coord_descent_fast(A,b) 
    N,P = size(A)
    x = zeros(P)
    for cycle in 1:1000
        r = b - A*x
        for j = 1:P
            x[j] += dot(A[:,j],r)
            r -= A[:,j]*dot(A[:,j],r)   # Add this line
        end    
    end
    return(x)
end

另外,请注意dot(A[:,j],r)是循环中两个更新的常见子表达式,因此它们可以折叠成一个计算。整个内部循环仍然是O(N)的计算。 - Dan Getz
1
是的!我认为r = b - A * x这个公式也可以被删除,因为在你的修复中,r已经被更新了,只需要在循环外将其初始化为b即可。 - Rory

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接