How to interpret the leaf values of XGBoost trees in a multiclass classification problem

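I have been using the XGBoost Python library on a multiclass classification problem with the multi:softmax objective. In general, I am not sure how to interpret the leaf values of the decision trees that are output when I use xgb.plot_tree() or dump the model to a txt file with bst.dump_model(). My problem has 6 classes, labeled 0-5, and I have set my model to run two boosting iterations (at least for now, since I am trying to better understand how XGBoost works).

From searching online (in particular https://github.com/dmlc/xgboost/issues/1746), I have gathered that the tree at booster[x] belongs to boosting iteration int(x / num_classes) + 1 and is the tree for class x % num_classes. For example, in my txt file, booster[7] is the tree for class 1 in the 2nd boosting iteration. I also found that when I apply the softmax function to the leaf values within a single tree, the softmax values of all the leaves sum to 1.

Beyond that, I am confused about how the leaf values of all these trees determine which class XGBoost picks. My questions are:
  1. How do the trees from different boosting iterations affect the output? For example, how do booster[0] and booster[6] (which are the first and second boosting iterations for my class 0) affect the final output, or the final probability of class 0?
  2. What is the math behind how the leaf values of all the trees determine which class XGBoost picks?
In case an answer by demonstration helps, I have included the dumped txt file below, along with a sample input and its outputs using multi:softprob and multi:softmax as the objective.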

dump.raw.txt:

booster[0]:
0:[f0<0.5] yes=1,no=2,missing=1
    1:[f8<19.5299988] yes=3,no=4,missing=3
        3:leaf=0.244897947
        4:leaf=-0.042857144
    2:leaf=-0.0595400333
booster[1]:
0:[f2<0.5] yes=1,no=2,missing=1
    1:leaf=-0.0594852231
    2:[f8<0.389999986] yes=3,no=4,missing=3
        3:leaf=0.272727251
        4:[f9<0.607749999] yes=5,no=6,missing=5
            5:[f9<0.290250003] yes=7,no=8,missing=7
                7:[f8<6.75] yes=11,no=12,missing=11
                    11:leaf=0.0157894716
                    12:leaf=-0.0348837189
                8:leaf=0.11249999
            6:[f8<12.6100006] yes=9,no=10,missing=9
                9:leaf=-0.0483870953
                10:[f8<15.1700001] yes=13,no=14,missing=13
                    13:leaf=0.0157894716
                    14:leaf=-0.0348837189
booster[2]:
0:[f3<0.5] yes=1,no=2,missing=1
    1:leaf=-0.0595029891
    2:[f8<0.439999998] yes=3,no=4,missing=3
        3:[f5<0.5] yes=5,no=6,missing=5
            5:leaf=-0.042857144
            6:leaf=0.226027399
        4:[f9<-0.606250048] yes=7,no=8,missing=7
            7:leaf=0.0157894716
            8:leaf=-0.0545454584
booster[3]:
0:[f3<0.5] yes=1,no=2,missing=1
    1:leaf=-0.0595029891
    2:[f5<0.5] yes=3,no=4,missing=3
        3:[f8<19.6599998] yes=5,no=6,missing=5
            5:leaf=0.260869563
            6:leaf=-0.0452054814
        4:leaf=-0.0524475537
booster[4]:
0:[f9<-0.477999985] yes=1,no=2,missing=1
    1:[f9<-0.622750044] yes=3,no=4,missing=3
        3:leaf=-0.0557312258
        4:[f10<0] yes=7,no=8,missing=7
            7:[f5<0.5] yes=11,no=12,missing=11
                11:leaf=0.0069767423
                12:leaf=0.0631578937
            8:leaf=-0.0483870953
    2:[f8<0.400000006] yes=5,no=6,missing=5
        5:leaf=-0.0563139915
        6:[f10<0] yes=9,no=10,missing=9
            9:[f8<19.5200005] yes=13,no=14,missing=13
                13:[f2<0.5] yes=17,no=18,missing=17
                    17:[f9<1.14275002] yes=23,no=24,missing=23
                        23:[f8<15.2000008] yes=27,no=28,missing=27
                            27:leaf=-0.0483870953
                            28:leaf=0.0157894716
                        24:leaf=0.0631578937
                    18:leaf=0.226829246
                14:leaf=0.293398529
            10:[f9<0.492500007] yes=15,no=16,missing=15
                15:[f8<17.2700005] yes=19,no=20,missing=19
                    19:leaf=0.152054787
                    20:leaf=-0.0570247956
                16:[f8<13.4099998] yes=21,no=22,missing=21
                    21:[f2<0.5] yes=25,no=26,missing=25
                        25:leaf=-0.0348837189
                        26:leaf=0.132558137
                    22:leaf=0.275871307
booster[5]:
0:[f9<-0.181999996] yes=1,no=2,missing=1
    1:[f10<0] yes=3,no=4,missing=3
        3:[f9<-0.49150002] yes=7,no=8,missing=7
            7:[f4<0.5] yes=13,no=14,missing=13
                13:leaf=0.0157894716
                14:leaf=0.226829246
            8:leaf=-0.0529411733
        4:[f8<12.9099998] yes=9,no=10,missing=9
            9:leaf=-0.0396226421
            10:leaf=0.285522789
    2:[f9<0.490750015] yes=5,no=6,missing=5
        5:[f10<0] yes=11,no=12,missing=11
            11:leaf=-0.0577405877
            12:[f8<17.2800007] yes=15,no=16,missing=15
                15:leaf=-0.0521739125
                16:[f2<0.5] yes=17,no=18,missing=17
                    17:leaf=0.274038434
                    18:leaf=0.0631578937
        6:leaf=-0.0589545034
booster[6]:
0:[f0<0.5] yes=1,no=2,missing=1
    1:[f8<19.5299988] yes=3,no=4,missing=3
        3:leaf=0.200149015
        4:leaf=-0.0419149213
    2:leaf=-0.0587796457
booster[7]:
0:[f2<0.5] yes=1,no=2,missing=1
    1:leaf=-0.0587093942
    2:[f8<0.389999986] yes=3,no=4,missing=3
        3:leaf=0.212223038
        4:[f9<0.607749999] yes=5,no=6,missing=5
            5:[f9<0.290250003] yes=7,no=8,missing=7
                7:[f8<6.75] yes=11,no=12,missing=11
                    11:leaf=0.0150387408
                    12:leaf=-0.0345491134
                8:leaf=0.102861121
            6:[f10<0] yes=9,no=10,missing=9
                9:leaf=-0.047783535
                10:[f9<0.93175] yes=13,no=14,missing=13
                    13:leaf=0.0160113405
                    14:leaf=-0.0342122875
booster[8]:
0:[f3<0.5] yes=1,no=2,missing=1
    1:leaf=-0.0587323084
    2:[f8<0.439999998] yes=3,no=4,missing=3
        3:[f5<0.5] yes=5,no=6,missing=5
            5:leaf=-0.0419248194
            6:leaf=0.187167063
        4:[f9<-0.606250048] yes=7,no=8,missing=7
            7:leaf=0.0154749081
            8:leaf=-0.0537380874
booster[9]:
0:[f3<0.5] yes=1,no=2,missing=1
    1:leaf=-0.0587323084
    2:[f5<0.5] yes=3,no=4,missing=3
        3:[f8<19.6599998] yes=5,no=6,missing=5
            5:leaf=0.207475975
            6:leaf=-0.0443004556
        4:leaf=-0.0517353415
booster[10]:
0:[f9<-0.477999985] yes=1,no=2,missing=1
    1:[f9<-0.622750044] yes=3,no=4,missing=3
        3:leaf=-0.0549092069
        4:[f10<0] yes=7,no=8,missing=7
            7:[f8<19.9899998] yes=11,no=12,missing=11
                11:leaf=0.0621421933
                12:leaf=0.00554796588
            8:leaf=-0.0474151336
    2:[f8<0.400000006] yes=5,no=6,missing=5
        5:leaf=-0.0555005781
        6:[f0<0.5] yes=9,no=10,missing=9
            9:leaf=-0.0508832447
            10:[f10<0] yes=13,no=14,missing=13
                13:[f3<0.5] yes=15,no=16,missing=15
                    15:leaf=0.220791802
                    16:[f9<0.988499999] yes=19,no=20,missing=19
                        19:leaf=-0.0421211571
                        20:leaf=0.059088923
                14:[f9<0.492500007] yes=17,no=18,missing=17
                    17:[f8<17.2700005] yes=21,no=22,missing=21
                        21:leaf=0.162014976
                        22:leaf=-0.0559271388
                    18:[f3<0.5] yes=23,no=24,missing=23
                        23:leaf=0.217694834
                        24:leaf=0.0335121229
booster[11]:
0:[f9<-0.181999996] yes=1,no=2,missing=1
    1:[f8<19.3400002] yes=3,no=4,missing=3
        3:leaf=-0.0464246981
        4:[f10<0] yes=7,no=8,missing=7
            7:[f9<-0.49150002] yes=11,no=12,missing=11
                11:leaf=0.178972095
                12:leaf=-0.0509003103
            8:leaf=0.218449697
    2:[f9<0.490750015] yes=5,no=6,missing=5
        5:[f10<0] yes=9,no=10,missing=9
            9:leaf=-0.0568957441
            10:[f8<17.2800007] yes=13,no=14,missing=13
                13:leaf=-0.0513576232
                14:[f2<0.5] yes=15,no=16,missing=15
                    15:leaf=0.212948546
                    16:leaf=0.0586818419
        6:leaf=-0.0581783429

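Sample input, with its expected label: [0, 1, 0, 0, 1, 0, 1, 20, 16.8799, 0.587, 0.5], label: 0
multi:softmax output: [0]
multi:softprob output (in case it helps): [[0.24506968 0.13953298 0.13952732 0.13952732 0.19666144 0.13968122]]
I know this is a complicated question and I hope I explained it clearly. Thank you very much in advance for your help!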
1 Answer

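  1. The trees for each class build on the trees from the earlier iterations for that class (that is the boosting!). In your example, booster[0] and booster[6] both contribute to the numerator of the softmax probability for class 0.

More generally, booster[i] and booster[i+6] both contribute to the numerator of the softmax probability for class i. If you increase the number of iterations beyond 2, you will have booster[i], booster[i+6], ..., booster[i+6n] all contributing to class i, where n+1 is the number of iterations.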
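The indexing rule can be sketched in a few lines of Python. The helper names `tree_slot` and `trees_for_class` are made up for this illustration and are not part of the XGBoost API; the sketch only encodes the ordering described in the question (one tree per class per round, classes cycling fastest).

```python
def tree_slot(tree_index, num_class):
    """Return (boosting_round, class_id) for a tree in the dump; rounds are 1-based."""
    round_idx, class_id = divmod(tree_index, num_class)
    return round_idx + 1, class_id


def trees_for_class(class_id, num_class, num_rounds):
    """List the dump indices of every tree that contributes to `class_id`."""
    return [class_id + num_class * r for r in range(num_rounds)]


print(tree_slot(7, 6))           # (2, 1): second round, class 1
print(trees_for_class(0, 6, 2))  # [0, 6]
```

With 6 classes and 2 rounds this gives booster[0] and booster[6] for class 0, matching the example in the question.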

  2. We can use your example to demonstrate this:

Given your input and the dumped txt file, we can find the leaf value that the input reaches in each booster:

Booster 0: 0.24489
Booster 1: -0.0594
Booster 2: -0.0595
Booster 3: -0.0595
Booster 4: 0.27587
Booster 5: -0.0589
Booster 6: 0.2
Booster 7: -0.0587
Booster 8: -0.0587
Booster 9: -0.0587
Booster 10: -0.0508
Booster 11: -0.0582
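The lookups above can be reproduced by walking the text dump by hand or with a small script. The following is a minimal sketch, assuming the dump format shown in the question (`id:[fN<threshold] yes=..,no=..,missing=..` for splits and `id:leaf=value` for leaves). The names `parse_tree` and `leaf_value` are hypothetical helpers, and representing a missing feature as `None` is a choice made for this sketch. Only booster[0] is included to keep it short.

```python
import re

# booster[0], copied from the dump in the question.
DUMP = """\
0:[f0<0.5] yes=1,no=2,missing=1
    1:[f8<19.5299988] yes=3,no=4,missing=3
        3:leaf=0.244897947
        4:leaf=-0.042857144
    2:leaf=-0.0595400333
"""

SPLIT = re.compile(r"(\d+):\[f(\d+)<([^\]]+)\] yes=(\d+),no=(\d+),missing=(\d+)")
LEAF = re.compile(r"(\d+):leaf=(\S+)")


def parse_tree(text):
    """Map node id -> ('leaf', value) or ('split', feature, threshold, yes, no, missing)."""
    nodes = {}
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        m = SPLIT.fullmatch(line)
        if m:
            nid, feat, thr, yes, no, miss = m.groups()
            nodes[int(nid)] = ("split", int(feat), float(thr), int(yes), int(no), int(miss))
            continue
        m = LEAF.fullmatch(line)
        if m:
            nodes[int(m.group(1))] = ("leaf", float(m.group(2)))
    return nodes


def leaf_value(nodes, x):
    """Follow the splits from the root (node 0) until a leaf is reached."""
    node = nodes[0]
    while node[0] == "split":
        _, feat, thr, yes, no, miss = node
        value = x[feat]
        if value is None:      # missing feature (sketch convention)
            node = nodes[miss]
        elif value < thr:      # condition true -> "yes" branch
            node = nodes[yes]
        else:
            node = nodes[no]
    return node[1]


x = [0, 1, 0, 0, 1, 0, 1, 20, 16.8799, 0.587, 0.5]
print(leaf_value(parse_tree(DUMP), x))
```

For the sample input, f0 = 0 < 0.5 and f8 = 16.8799 < 19.53, so the walk ends at node 3, the 0.2449 leaf listed for Booster 0.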

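Now we just plug these into the softmax formula to get the probability of each of the six classes under softprob. For each class, add the leaf values from its two boosters and exponentiate the sum: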
Z_0 = e^{0.24489+0.2} = 1.5603
Z_1 = e^{-0.0594-0.0587} = 0.8886
Z_2 = e^{-0.0595-0.0587} = 0.8885
Z_3 = e^{-0.0595-0.0587} = 0.8885
Z_4 = e^{0.2758-0.0508} = 1.2523
Z_5 = e^{-0.0589-0.0582} = 0.8895

Summing these gives the denominator of the softmax probability: 6.3677.
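Written in general form, the calculation looks like the following. The notation is introduced here for illustration only: f_{t,k}(x) is the leaf value the input x reaches in the tree for class k at boosting round t, T is the number of rounds (2 here), and K is the number of classes (6 here).

```latex
\[
Z_k = \exp\!\Big(\sum_{t=1}^{T} f_{t,k}(x)\Big), \qquad
P(y = k \mid x) = \frac{Z_k}{\sum_{j=0}^{K-1} Z_j}
\]
```

Any constant added equally to every class's sum multiplies each Z_k by the same factor and therefore cancels in the ratio.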

We can therefore compute the soft probability of each class:

P(output=0) = 1.5603/6.3677 = 0.2450
P(output=1) = 0.8886/6.3677 = 0.1395
P(output=2) = 0.8885/6.3677 = 0.1395
P(output=3) = 0.8885/6.3677 = 0.1395
P(output=4) = 1.2523/6.3677 = 0.1967
P(output=5) = 0.8895/6.3677 = 0.1397

Picking the class with the highest probability (class 0) gives the multi:softmax output you observed.
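The whole calculation can be checked with a short script. This is a sketch rather than XGBoost's own code: it takes the twelve leaf values at full precision from the dump (the hand calculation uses rounded values), sums them per class using the booster index modulo the number of classes, and applies a plain softmax.

```python
import math

NUM_CLASS = 6

# Leaf reached by the sample input in booster[0]..booster[11], read from the dump.
leaves = [
    0.244897947, -0.0594852231, -0.0595029891, -0.0595029891, 0.275871307, -0.0589545034,
    0.200149015, -0.0587093942, -0.0587323084, -0.0587323084, -0.0508832447, -0.0581783429,
]

# Tree i scores class i % NUM_CLASS, so sum every NUM_CLASS-th leaf.
margins = [sum(leaves[k::NUM_CLASS]) for k in range(NUM_CLASS)]

# Softmax; subtracting the maximum is only for numerical stability.
top = max(margins)
exps = [math.exp(v - top) for v in margins]
total = sum(exps)
probs = [e / total for e in exps]

# multi:softmax corresponds to taking the most probable class.
predicted = max(range(NUM_CLASS), key=probs.__getitem__)
print([round(p, 4) for p in probs])
print(predicted)
```

The resulting probabilities should agree with the multi:softprob row quoted in the question to within rounding, and the predicted class is 0.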
