Caffe:如何获取Python层的阶段?

10

我在caffe中创建了一个名为"myLayer""Python"层,并将其用于网络train_val.prototxt中。插入该层的方式如下:

layer {
  name: "my_py_layer"
  type: "Python"
  bottom: "in"
  top: "out"
  python_param {
    module: "my_module_name"
    layer: "myLayer"
  }
  include { phase: TRAIN } # THIS IS THE TRICKY PART!
}

现在,我的层仅参与网络的TRAINing阶段,
我该如何在我的层的setup函数中知道这一点呢?


class myLayer(caffe.Layer):
  def setup(self, bottom, top):
     # I want to know here what is the phase?!!
  ...

PS,
我也在“Caffe用户”Google群组上发布了这个问题。如果那里有任何消息,我会更新的。

2个回答

7
正如galloguille所指出的,caffe现在将phase暴露给了python层类。这个新特性使得此答案有点多余。但是知道在caffe python层中使用param_str传递其他参数是很有用的。
以下是原始答案:
据我所知,没有简单的方法来获取阶段信息。不过,可以将任意参数从网络prototxt传递到python。这可以通过python_paramparam_str参数来完成。
以下是实现方式:
layer {
  type: "Python"
  ...
  python_param {
    ...
    param_str: '{"phase":"TRAIN","numeric_arg":5}' # passing params as a STRING

在Python中,在层的setup函数中获取param_str:
import caffe, json
class myLayer(caffe.Layer):
  def setup(self, bottom, top):
    param = json.loads( self.param_str ) # use JSON to convert string to dict
    self.phase = param['phase']
    self.other_param = int( param['numeric_arg'] ) # I might want to use this as well...

6
这是一个非常好的解决方案,但如果您只想将"phase"作为参数传递,现在您可以将其作为图层的属性来访问。这个特性只是在6天前被合并 https://github.com/BVLC/caffe/pull/3995
具体的提交:https://github.com/BVLC/caffe/commit/de8ac32a02f3e324b0495f1729bff2446d402c2c 有了这个新功能,您只需要使用属性self.phase。例如,您可以执行以下操作:
class PhaseLayer(caffe.Layer):
"""A layer for checking attribute `phase`"""

def setup(self, bottom, top):
    pass

def reshape(self, bootom, top):
    top[0].reshape()

def forward(self, bottom, top):
    top[0].data[()] = self.phase

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接