如何从Java中调用scikit-learn分类器?

37

我有一个使用Python中的scikit-learn训练出的分类器。如何在Java程序中使用这个分类器?我可以使用Jython吗?是否有办法在Python中保存分类器并在Java中加载它?还有其他的使用方法吗?

6个回答

54

由于scikit-learn严重依赖于numpy和scipy,这两个库有很多编译的C和Fortran扩展,因此无法在jython中使用。

在Java环境中使用scikit-learn最简单的方法是:

  • 将分类器作为HTTP / Json服务公开,例如使用诸如 flask 瓶子cornice等微型框架,并使用HTTP客户端库从Java调用它。

  • 编写一个命令行包装应用程序,该程序使用某种格式(例如CSV或JSON(或某些较低级别的二进制表示))从stdin读取数据并在stdout上输出预测,并使用Java从 python程序调用,例如使用Apache Commons Exec

  • 使Python程序输出在拟合时学习的原始数值参数(通常作为浮点值数组),然后在Java中重新实现预测函数(对于预测线性模型来说,这通常很容易,其中预测通常只是阈值化的点积)。

如果您还需要在Java中重新实现特征提取,则最后一种方法将需要更多的工作。

最后,您可以使用Java库(例如Weka或Mahout)来实现所需算法,而不是尝试从Java中使用scikit-learn。


3
我的一位同事刚刚建议使用Jepp...这对我们有用吗?请为我翻译成中文。 - Thomas Johnson
可能是我不知道Jepp。它看起来确实适合这个任务。 - ogrisel
对于 Web 应用程序,我个人更喜欢 HTTP 暴露方法。@user939259 可以使用分类器池来处理各种应用程序,并更轻松地进行扩展(根据需求调整池的大小)。我只会考虑在桌面应用程序中使用 Jepp。尽管我很喜欢 Python,但除非 scikit-lear 的性能显著优于 Weka 或 Mahout,否则我会选择单一语言解决方案。使用多种语言/框架应被视为技术债务。 - rbanffy
我同意有关多语言技术债务的观点:在一个团队中,所有开发人员都了解Java和Python,并不得不从一种技术文化转换到另一种,这增加了项目管理中的无用复杂性。 - ogrisel
也许这是技术债务 - 但是类比一下,机器学习中你总是破产申报,因为你会尝试各种方法,发现它们不起作用,然后进行调整/丢弃。因此,在这种情况下,也许债务并不是那么重要的问题。 - Thomas Johnson

24

为此目的有JPMML项目。

首先,您可以使用sklearn2pmml库将scikit-learn模型序列化为PMML(内部为XML),直接从python中进行转换或先在python中转储它并使用jpmml-sklearn在Java中进行转换或使用此库提供的命令行。接下来,您可以加载pmml文件,反序列化并使用jpmml-evaluator在您的Java代码中执行已加载的模型。

这种方法适用于不是所有scikit-learn模型,但与许多其他模型兼容。

正如一些评论者正确指出的那样,重要的是注意JPMML项目使用GNU AGPL许可证。AGPL是一个强制版权许可证,可能会限制您使用该项目的能力。例如,如果您开发了一个公共可访问的服务并希望保持源代码闭源。


2
你如何确保特征转换部分在Python训练中和使用PMML在Java服务中的一致性? - Andrea Bergonzo
2
我尝试过这个方法,它确实可以将sklearn transformers和xgboost模型转换为Java。然而,由于AGPL许可证的原因,我们没有在生产环境中选择使用它。(虽然也有商业许可证,但协商许可证不符合我们项目的时间表。) - leon
1
我尝试了这个方法,通过Java程序保留了所有的特征提取、清洗和转换逻辑。在Java端(jpmml-evaluator)上运行良好。对于容器化的Spring Boot应用程序来说是一个不错的选择,可以大大降低DevOps的复杂性,因为Python训练的频率和时间表无法与Java程序的持续集成同步。 - Indrajit Kanjilal
@leon的评论非常重要,特别是对于那些将SO答案中的解决方案复制/粘贴作为其软件开发生命周期的重要部分的人。如果您在产品中使用jpmml-evaluator,则您的用户可能会强制您披露产品的所有源代码。这就是微软警告人们将所有开源软件等同于GPL(不是LGLP)和类似许可证下授权的库时所说的大坏狼。始终阅读您的许可证! - Christopher Schultz

6
您可以使用一个porter,我已经测试了sklearn-porter(https://github.com/nok/sklearn-porter),它对于Java非常有效。
我的代码如下:
import pandas as pd
from sklearn import tree
from sklearn_porter import Porter

train_dataset = pd.read_csv('./result2.csv').as_matrix()

X_train = train_dataset[:90, :8]
Y_train = train_dataset[:90, 8:]

X_test = train_dataset[90:, :8]
Y_test = train_dataset[90:, 8:]

print X_train.shape
print Y_train.shape


clf = tree.DecisionTreeClassifier()
clf = clf.fit(X_train, Y_train)

porter = Porter(clf, language='java')
output = porter.export(embed_data=True)
print(output)

在我的情况下,我正在使用DecisionTreeClassifier,并且

print(output)

的输出是控制台中以下代码的文本形式:

class DecisionTreeClassifier {

  private static int findMax(int[] nums) {
    int index = 0;
    for (int i = 0; i < nums.length; i++) {
        index = nums[i] > nums[index] ? i : index;
    }
    return index;
  }


  public static int predict(double[] features) {
    int[] classes = new int[2];

    if (features[5] <= 51.5) {
        if (features[6] <= 21.0) {

            // HUGE amount of ifs..........

        }
    }

    return findMax(classes);
  }

  public static void main(String[] args) {
    if (args.length == 8) {

        // Features:
        double[] features = new double[args.length];
        for (int i = 0, l = args.length; i < l; i++) {
            features[i] = Double.parseDouble(args[i]);
        }

        // Prediction:
        int prediction = DecisionTreeClassifier.predict(features);
        System.out.println(prediction);

    }
  }
}

谢谢提供信息。您能分享一下如何在Java中执行使用sklearn porter打包的sklearn模型,并用于预测的想法吗?- @gustavoresque - Sourav Saha

3
以下是JPMML解决方案的一些代码:
--PYTHON部分--
# helper function to determine the string columns which have to be one-hot-encoded in order to apply an estimator.
def determine_categorical_columns(df):
    categorical_columns = []
    x = 0
    for col in df.dtypes:
        if col == 'object':
            val = df[df.columns[x]].iloc[0]
            if not isinstance(val,Decimal):
                categorical_columns.append(df.columns[x])
        x += 1
    return categorical_columns

categorical_columns = determine_categorical_columns(df)
other_columns = list(set(df.columns).difference(categorical_columns))


#construction of transformators for our example
labelBinarizers = [(d, LabelBinarizer()) for d in categorical_columns]
nones = [(d, None) for d in other_columns]
transformators = labelBinarizers+nones

mapper = DataFrameMapper(transformators,df_out=True)
gbc = GradientBoostingClassifier()

#construction of the pipeline
lm = PMMLPipeline([
    ("mapper", mapper),
    ("estimator", gbc)
])

--JAVA PART --

//Initialisation.
String pmmlFile = "ScikitLearnNew.pmml";
PMML pmml = org.jpmml.model.PMMLUtil.unmarshal(new FileInputStream(pmmlFile));
ModelEvaluatorFactory modelEvaluatorFactory = ModelEvaluatorFactory.newInstance();
MiningModelEvaluator evaluator = (MiningModelEvaluator) modelEvaluatorFactory.newModelEvaluator(pmml);

//Determine which features are required as input
HashMap<String, Field>() inputFieldMap = new HashMap<String, Field>();
for (int i = 0; i < evaluator.getInputFields().size();i++) {
  InputField curInputField = evaluator.getInputFields().get(i);
  String fieldName = curInputField.getName().getValue();
  inputFieldMap.put(fieldName.toLowerCase(),curInputField.getField());
}


//prediction

HashMap<String,String> argsMap = new HashMap<String,String>();
//... fill argsMap with input

Map<FieldName, ?> res;
// here we keep only features that are required by the model
Map<FieldName,String> args = new HashMap<FieldName, String>();
Iterator<String> iter = argsMap.keySet().iterator();
while (iter.hasNext()) {
  String key = iter.next();
  Field f = inputFieldMap.get(key);
  if (f != null) {
    FieldName name =f.getName();
    String value = argsMap.get(key);
    args.put(name, value);
  }
}
//the model is applied to input, a probability distribution is obtained
res = evaluator.evaluate(args);
SegmentResult segmentResult = (SegmentResult) res;
Object targetValue = segmentResult.getTargetValue();
ProbabilityDistribution probabilityDistribution = (ProbabilityDistribution) targetValue;

1

1
我发现自己处于类似的情况中。 我建议开发一个分类器微服务。你可以编写一个运行在Python上的分类器微服务,然后通过一些RESTful API公开对该服务的调用,从而产生JSON/XML数据交换格式。我认为这是一种更清晰的方法。

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接