StatsModels.jl:在Julia中使用指数公式 `formula(y ~ (x1 + x2 + x3 + x4 + x5)^2`

5

我一直在尝试在Julia中复制这个矩阵(R)。

data = read.csv("https://gist.githubusercontent.com/TJhon/158daa0c2dd06010d01a72dae2af8314/raw/61df065c98ec90b9ea3b8598d1996fb5371a64aa/rnd.csv")

head(model.matrix(y ~ (x1 + x2 + x3 + x4 + x5)^2, data), 3)
#>   (Intercept)         x1         x2         x3        x4 x5       x1:x2
#> 1           1 -0.3007225 -1.3710894  0.3423409  1.322547  2  0.41231744
#> 2           1  0.4674170  0.8728939  0.9534157 -1.007083  1  0.40800548
#> 3           1  0.2085316 -0.3657995 -0.3043694 -1.036938  4 -0.07628076
#>         x1:x3      x1:x4      x1:x5      x2:x3      x2:x4      x2:x5      x3:x4
#> 1 -0.10294961 -0.3977198 -0.6014450 -0.4693799 -1.8133307 -2.7421787  0.4527620
#> 2  0.44564276 -0.4707279  0.4674170  0.8322308 -0.8790769  0.8728939 -0.9601690
#> 3 -0.06347064 -0.2162343  0.8341265  0.1113382  0.3793113 -1.4631979  0.3156121
#>        x3:x5     x4:x5
#> 1  0.6846817  2.645095
#> 2  0.9534157 -1.007083
#> 3 -1.2174775 -4.147751

本文创建于2022年10月18日,使用reprex程序包(v2.0.1)

我尝试了

using CSV, DataFrames, StatsModels, StatsBase

data = CSV.read(download("https://gist.githubusercontent.com/TJhon/158daa0c2dd06010d01a72dae2af8314/raw/61df065c98ec90b9ea3b8598d1996fb5371a64aa/rnd.csv"), DataFrame) 

ModelMatrix(ModelFrame(@formula(y ~ (x1 + x2 + x3 + x4 + x5) * (x1 + x2 + x3 + x4 + x5)), data)).m

9×31 Matrix{Float64}:
 1.0  -0.300723  -1.37109    0.342341   1.32255   2.0  0.090434    0.412317   -0.10295    -0.397720.452762    1.74913     2.64509   -0.601445  -2.74218    0.684682   2.64509    4.0
 1.0   0.467417   0.872894   0.953416  -1.00708   1.0  0.218479    0.408005    0.445643   -0.470728      -0.960169    1.01422    -1.00708    0.467417   0.872894   0.953416  -1.00708    1.0

 1.0   0.395908  -1.15159   -0.204683  -0.207952  2.0  0.156743   -0.455924   -0.0810355  -0.0823297      0.0425641   0.0432439  -0.415903   0.791816  -2.30318   -0.409365  -0.415903   4.0

ModelMatrix(ModelFrame(@formula(y ~ (x1 + x2 + x3 + x4 + x5) & (x1 + x2 + x3 + x4 + x5)), data)).m

9×26 Matrix{Float64}:
 1.0  0.090434    0.412317   -0.10295    -0.39772    -0.601445   0.412317   1.87989   -0.46938   -1.813330.452762    1.74913     2.64509   -0.601445  -2.74218    0.684682   2.64509    4.0   
 1.0  0.218479    0.408005    0.445643   -0.470728    0.467417   0.408005   0.761944   0.832231  -0.879077      -0.960169    1.01422    -1.00708    0.467417   0.872894   0.953416  

 1.0  0.156743   -0.455924   -0.0810355  -0.0823297   0.791816  -0.455924   1.32616    0.235711   0.239475       0.0425641   0.0432439  -0.415903   0.791816  -2.30318   -0.409365  -0.415903   4.0   

ModelMatrix(ModelFrame(@formula(y ~ (x1 + x2 + x3 + x4 + x5)^2), data)).m

9×2 Matrix{Float64}:
 1.0   3.97235
 1.0   5.22874
 :     :
 1.0   0.691696

我希望相同的数组和向量变量名称能够被转换为数据框以供日后使用。
1个回答

3
xs = term.((Symbol("x$i") for i=1:5))
ff = vcat(term(1), xs, [xs[a] & xs[b] for a in 1:5 for b in a+1:5])
ModelMatrix(ModelFrame(FormulaTerm(Term(:y),Tuple(ff)), data)).m

这个方法运行正常,但是比 R 版本更难看。也许有更好的方法。

另外:

varnames = vcat("(intercept)", ["x$i" for i=1:5], ["x$(a):x$(b)" for a in 1:5 for b in a+1:5])

更新

上述解决方案(虽然有效)并不是特别好或朱利安风格,因此这里有一个重新编写的解决方案,尝试更加通用:

using CSV, DataFrames, StatsModels, StatsBase

URL = join([
  "https:","","gist.githubusercontent.com",
  "TJhon","158daa0c2dd06010d01a72dae2af8314",
  "raw","61df065c98ec90b9ea3b8598d1996fb5371a64aa","rnd.csv"],
  '/')

data = CSV.read(download(URL), DataFrame) 

using Combinatorics

subsets(X; from=0, upto=length(X)) =
  Iterators.flatten(combinations(X,i) for i=max(0,from):min(upto,length(X)))

xs = term.((Symbol("x$i") for i=1:5))

term_vec = collect(subsets(xs; from=1, upto=2))
rhs = vcat(ConstantTerm(1), map(x->reduce(&, x), term_vec))
rhs_names = vcat("(intercept)", [join(string.(x),'*') for x in term_vec])

ModelMatrix(ModelFrame(FormulaTerm(Term(:y),Tuple(rhs)), data)).m

看起来很长,但初始部分仅为复制粘贴方便而设计,并且通过在term_vec定义中将upto=2替换为upto=3,它还具有允许三项交互的好处。

特别是,subsets迭代器在许多情况下非常有用,并且将其添加到迭代器中是个好主意(如同意,请评论)。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接