如何使用libsvm计算多类预测的概率?

4

我正在使用libsvm,文档让我相信有一种方法可以输出输出分类准确性的置信概率。这是真的吗?如果是,有人能提供一个清晰的代码示例吗?

目前,我正在以下方式使用Java库

    SvmModel model = Svm.svm_train(problem, parameters);
    SvmNode x[] = getAnArrayOfSvmNodesForProblem();
    double predictedValue = Svm.svm_predict(model, x);
2个回答

8
鉴于您的代码片段,我假设您想使用与 libSVM 打包的 Java API,而不是由 jlibsvm 提供的更冗长的 API。
要启用带概率估计的预测,请使用 svm_parameter 字段,将probability设置为1来训练模型。然后,只需更改您的代码,使其调用 svm 方法 svm_predict_probability 而不是 svm_predict
修改您的代码片段如下:
parameters.probability = 1;
svm_model model = svm.svm_train(problem, parameters);

svm_node x[] = problem.x[0]; // let's try the first data pt in problem
double[] prob_estimates = new double[NUM_LABEL_CLASSES]; 
svm.svm_predict_probability(model, x, prob_estimates);

值得注意的是,使用多类概率估计进行训练可能会改变分类器所做出的预测。有关更多信息,请参阅问题使用LibSVM计算最接近均值/标准差对的最佳匹配


@dmcer 请问哪个软件包的学习曲线较低(Java API与libSVM捆绑在一起的软件包还是jlibsvm)?我对SVMs完全是新手。 - GobiasKoffi
@rohanbk - 可能是jlibsvm,因为它看起来和感觉像一个典型的Java API。 - dmcer
@dmcer,您有使用WEKA进行支持向量机的经验吗? - GobiasKoffi
@rohanbk - 没有太多。但是,这将是一个相当不错的选择,因为它可以让您轻松地在您的数据上进行其他分类器的基准测试。 - dmcer
@dmcer 谢谢你的建议。我想我会采用jlibsvm方法。你知道哪里有好的资源或例子可以让我学习jlibsvm吗?似乎缺乏典型案例。 - GobiasKoffi
1
@rohanbk - 我所知道的唯一好的示例代码是"legacyexec"命令行工具:http://dev.davidsoergel.com/trac/jlibsvm/browser/trunk/src/main/java/edu/berkeley/compbio/jlibsvm/legacyexec - dmcer

1

被接受的答案非常有效。在训练期间,请确保将probability = 1设置。

如果您正在尝试在置信度未达到阈值时放弃预测,则可以使用以下代码示例:

double confidenceScores[] = new double[model.nr_class];
svm.svm_predict_probability(model, svmVector, confidenceScores);

/*System.out.println("text="+ text);
for (int i = 0; i < model.nr_class; i++) {
    System.out.println("i=" + i + ", labelNum:" + model.label[i] + ", name=" + classLoadMap.get(model.label[i]) + ", score="+confidenceScores[i]);
}*/

//finding max confidence; 
int maxConfidenceIndex = 0;
double maxConfidence = confidenceScores[maxConfidenceIndex];
for (int i = 1; i < confidenceScores.length; i++) {
    if(confidenceScores[i] > maxConfidence){
        maxConfidenceIndex = i;
        maxConfidence = confidenceScores[i];
    }
}

double threshold = 0.3; // set this based data & no. of classes
int labelNum = model.label[maxConfidenceIndex];
// reverse map number to name
String targetClassLabel = classLoadMap.get(labelNum); 
LOG.info("classNumber:{}, className:{}; confidence:{}; for text:{}",
        labelNum, targetClassLabel, (maxConfidence), text);
if (maxConfidence < threshold ) {
    LOG.info("Not enough confidence; threshold={}", threshold);
    targetClassLabel = null;
}
return targetClassLabel;

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接