TensorFlow 随机森林回归

5
我希望能够实现一个简单的随机森林回归来预测值。输入是一些具有多个特征的样本,标签是一个值。然而,我找不到关于随机森林回归问题的简单示例。因此,我看到了tensorflow的文档并发现:
一个可以训练和评估随机森林的估计器。 示例:
  python
  params = tf.contrib.tensor_forest.python.tensor_forest.ForestHParams(
      num_classes=2, num_features=40, num_trees=10, max_nodes=1000)
  # Estimator using the default graph builder.
  estimator = TensorForestEstimator(params, model_dir=model_dir)
  # Or estimator using TrainingLossForest as the graph builder.
  estimator = TensorForestEstimator(
      params, graph_builder_class=tensor_forest.TrainingLossForest,
      model_dir=model_dir)
  # Input builders
  def input_fn_train: # returns x, y
    ...
  def input_fn_eval: # returns x, y
    ...
  estimator.fit(input_fn=input_fn_train)
  estimator.evaluate(input_fn=input_fn_eval)
  # Predict returns an iterable of dicts.
  results = list(estimator.predict(x=x))
  prob0 = results[0][eval_metrics.INFERENCE_PROB_NAME]
  prediction0 = results[0][eval_metrics.INFERENCE_PRED_NAME]

然而,当我按照示例操作时,在这行代码prob0 = results[0][eval_metrics.INFERENCE_PROB_NAME]中遇到了错误,该错误显示为:
Example conversion:
est = Estimator(...) -> est = SKCompat(Estimator(...))
Traceback (most recent call last):
  File "RF_2.py", line 312, in <module>
    main()
  File "RF_2.py", line 298, in main
    train_eval(x_train, y_train, x_validation, y_validation, x_test, y_test, num_tree)
  File "RF_2.py", line 221, in train_eval
    prob0 = results[0][eval_metrics.INFERENCE_PROB_NAME]
KeyError: 'probabilities'

我认为错误出在INFERENCE_PROB_NAME上,我看了文档,但仍不知道该用什么词来替换INFERENCE_PROB_NAME

我尝试使用get_metric('accuracy')来替换INFERENCE_PROB_NAME,但它返回了错误:KeyError: <function _accuracy at 0x11a06eaa0>

我还尝试使用get_prediction_key('accuracy')来替换INFERENCE_PROB_NAME,但它返回了错误:KeyError: 'classes'

如果你知道可能的答案,请告诉我。谢谢。

2个回答

1

num_classes=0 在 TensorFlow 1.3.0 中是错误的。

根据 Mehdi Rezaie 的链接,num_classes 是回归问题输出维度的数量。

你必须使用 num_classes=1 或更大的值来设置 num_classes。 否则,你将会得到类似于 ValueError: Invalid logits_dimension 0. 的错误。


1

我认为你在无意中进行分类问题,因为你给了错误的num_classes=2并且没有改变regression=False的默认值。请参见参数部分这里。只是一个快速测试,将num_classes=0regression=True设置,并重新运行你的代码。


num_classes=0对于v1.3不适用。 - Majo_Jose

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接