Caffe中的多标签回归

11

我正在按照kaggle人脸关键点比赛的要求,从输入图像中提取30个面部关键点(x,y)。

我该如何设置caffe来运行回归并产生30维输出??。

Input: 96x96 image
Output: 30 - (30 dimensions).

我该如何正确设置Caffe?我正在使用欧几里得损失(平方和)来得到回归输出。这是一个使用Caffe的简单逻辑回归模型,但它并没有起作用。似乎准确度层不能处理多标签输出。

I0120 17:51:27.039113  4113 net.cpp:394] accuracy <- label_fkp_1_split_1
I0120 17:51:27.039135  4113 net.cpp:356] accuracy -> accuracy
I0120 17:51:27.039158  4113 net.cpp:96] Setting up accuracy
F0120 17:51:27.039201  4113 accuracy_layer.cpp:26] Check failed: bottom[1]->channels() == 1 (30 vs. 1) 
*** Check failure stack trace: ***
    @     0x7f7c2711bdaa  (unknown)
    @     0x7f7c2711bce4  (unknown)
    @     0x7f7c2711b6e6  (unknown)

这里是图层文件:

name: "LogReg"
layers {
  name: "fkp"
  top: "data"
  top: "label"
  type: HDF5_DATA
  hdf5_data_param {
   source: "train.txt"
   batch_size: 100
  }
    include: { phase: TRAIN }

}

layers {
  name: "fkp"
  type: HDF5_DATA
  top: "data"
  top: "label"
  hdf5_data_param {
    source: "test.txt"
    batch_size: 100
  }

  include: { phase: TEST }
}

layers {
  name: "ip"
  type: INNER_PRODUCT
  bottom: "data"
  top: "ip"
  inner_product_param {
    num_output: 30
  }
}
layers {
  name: "loss"
  type: EUCLIDEAN_LOSS
  bottom: "ip"
  bottom: "label"
  top: "loss"
}

layers {
  name: "accuracy"
  type: ACCURACY
  bottom: "ip"
  bottom: "label"
  top: "accuracy"
  include: { phase: TEST }
}

请将可正常运行的模型定义文件(.prototxt)作为更新贴子或答案发布。 - mrgloom
1
准确度层在回归设置中不起作用,它仅适用于分类问题。 - curio17
1个回答

4

我找到了它 :)

我将SOFTLAYER替换为EUCLIDEAN_LOSS函数,并改变了输出数量。这样做是生效的。

layers {
  name: "loss"
  type: EUCLIDEAN_LOSS
  bottom: "ip1"
  bottom: "label"
  top: "loss"
}

HINGE_LOSS 是另一个可选项。

你的输出数量有何变化? - nayef
我将输入数据重塑为(total,1,96,96)的形状,并将输出标签重塑为(total,30)。 - pbu
你能否详细解释一下为什么你避免批处理模式并只拿了一个例子,以及为什么你把标签改成了30? - thetna
@pbu,你能解释一下你是怎么处理准确率层的吗?另外,在问题中,你最初的.prototxt文件已经将类型设置为EUCLIDEAN_LOSS。请在此处发布最终的.prototxt文件。这将非常有帮助。谢谢。 - iamprem
http://corpocrat.com/2015/02/24/facial-keypoints-extraction-using-deep-learning-with-caffe/ - pbu
当训练完成后,您如何处理输出?我可以看到您的标签是(总数,30),但这意味着您的标签是1D而不是2D(x、y=坐标值)@pbu - user4911648

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接