为什么使用PYMC3进行线性回归训练效果如此之差?

3

我是PYMC3的新手。也许这是一个幼稚的问题,但我搜索了很多资料,没有找到任何关于这个问题的解释。基本上,我想在PYMC3中进行线性回归,但是训练速度非常慢,而且模型在训练集上的表现也非常差。以下是我的代码:

X_Tr = np.array([ 13.99802212,  13.8512075 ,  13.9531636 ,  13.97432944,
    13.89211468,  13.91357953,  13.95987483,  13.86476587,
    13.9501789 ,  13.92698143,  13.9653932 ,  14.06663115,
    13.91697969,  13.99629862,  14.01392784,  13.96495713,
    13.98697998,  13.97516973,  14.01048397,  14.05918188,
    14.08342002,  13.89350606,  13.81768849,  13.94942447,
    13.90465027,  13.93969029,  14.18771189,  14.08631113,
    14.03718829,  14.01836206,  14.06758363,  14.05243539,
    13.96287123,  13.93011351,  14.01616973,  14.01923812,
    13.97424024,  13.9587175 ,  13.85669845,  13.97778302,
    14.04192138,  13.93775494,  13.86693585,  13.79985956,
    13.82679677,  14.06474544,  13.90821822,  13.71648423,
    13.78899668,  13.76857337,  13.87201756,  13.86152949,
    13.80447525,  13.99609891,  14.0210165 ,  13.986906  ,
    13.97479211,  14.04562055,  14.03293095,  14.15178043,
    14.32413197,  14.2330354 ,  13.99247751,  13.92962912,
    13.95394525,  13.87888254,  13.82743111,  14.10724699,
    14.23638905,  14.15731881,  14.13239278,  14.13386722,
    13.91442452,  14.01056255,  14.19378649,  14.22233852,
    14.30405399,  14.25880108,  14.23985258,  14.21184303,
    14.4443183 ,  14.55710331,  14.42102092,  14.29047616,
    14.43712609,  14.58666212])
y_tr = np.array([ 13.704,  13.763,  13.654,  13.677,  13.66 ,  13.735,  13.845,
    13.747,  13.747,  13.606,  13.819,  13.867,  13.817,  13.68 ,
    13.823,  13.779,  13.814,  13.936,  13.956,  13.912,  13.982,
    13.979,  13.919,  13.944,  14.094,  13.983,  13.887,  13.902,
    13.899,  13.881,  13.784,  13.909,  13.99 ,  14.06 ,  13.834,
    13.778,  13.703,  13.965,  14.02 ,  13.992,  13.927,  14.009,
    13.988,  14.022,  13.754,  13.837,  13.91 ,  13.907,  13.867,
    14.014,  13.952,  13.796,  13.92 ,  14.051,  13.773,  13.837,
    13.745,  14.034,  13.923,  14.041,  14.077,  14.125,  13.989,
    14.174,  13.967,  13.952,  14.024,  14.171,  14.175,  14.091,
    14.267,  14.22 ,  14.071,  14.112,  14.174,  14.289,  14.146,
    14.356,  14.5  ,  14.265,  14.259,  14.406,  14.463,  14.473,
    14.413,  14.507])
sns.regplot(x=X_tr, y=y_tr.flatten());

enter image description here

在此我使用PYMC3来训练模型:

shA_X = shared(X_tr)
with pm.Model() as linear_model:    
    alpha = pm.Normal("alpha", mu=14,sd=100)
    betas = pm.Normal("betas", mu=0, sd=100, shape=1)
    sigma = pm.HalfCauchy('sigma', beta=10, testval=1.)
    mu = alpha + betas * shA_X
    forecast = pm.Normal("forecast", mu=mu, sd=sigma, observed=y_tr)
    step = pm.NUTS()
    trace=pm.sample(3000, tune=1000)

然后检查性能:

ppc_w = pm.sample_ppc(trace, 1000, linear_model,
                    progressbar=False)
plt.plot(ppc_w['forecast'].mean(axis=0),'r')
plt.plot(y_tr, color='k')`

在此输入图片描述

为什么训练集上的预测结果如此糟糕?欢迎任何建议和想法。


看起来你正在展示 plt.plot(ppc_w['forecast'].mean(axis=1), 'r') 的图表。你发布的代码实际上生成了一个合理的图表。 - colcarroll
1个回答

3
这个模型表现不错 - 我认为混淆的原因在于如何处理PyMC3对象(感谢您提供易于使用的示例!)。一般来说,PyMC3将用于量化模型中的不确定性。
例如,trace['betas'].mean()约为0.83(这取决于您的随机种子),而最小二乘估计值,例如sklearn给出的值为0.826。同样,trace['alpha'].mean()值为2.34,而“真实”值为2.38。
您还可以使用跟踪来绘制许多不同的最佳拟合线的可行抽样:
for draw in trace[::100]:
    pred = draw['betas'] * X_tr + draw['alpha']
    plt.plot(X_tr, pred, '--', alpha=0.2, color='grey')


plt.plot(X_tr, y_tr, 'o');

请注意,这些来自于您的数据的“最佳拟合”分布。您还使用sigma来建模噪声,并且您也可以将此值绘制出来:

enter image description here

for draw in trace[::100]:
    pred = draw['betas'] * X_tr + draw['alpha']

    plt.plot(X_tr, pred, '-', alpha=0.2, color='grey')
    plt.plot(X_tr, pred + draw['sigma'], '-', alpha=0.05, color='red')
    plt.plot(X_tr, pred - draw['sigma'], '-', alpha=0.05, color='red');


plt.plot(X_tr, y_tr, 'o');

enter image description here

使用sample_ppc从后验分布中绘制观测值,因此ppc_w['forecast']的每一行都是数据在“下一次”生成的合理方式。您可以按照以下方式使用该对象:

ppc_w = pm.sample_ppc(trace, 1000, linear_model,
                      progressbar=False)
for draw in ppc_w['forecast'][::5]:
    sns.regplot(X_tr, draw, scatter_kws={'alpha': 0.005, 'color': 'r'}, fit_reg=False)
sns.regplot(X_tr, y_tr, color='k', fit_reg=False);

enter image description here


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接