自定义Caffe Windows CPP中的卷积层

5

我有这个网络'RGB2GRAY.prototxt'

name: "RGB2GRAY"
layer {
  name: "data"
  type: "Input"
  top: "data"
  input_param { shape: { dim: 1 dim: 3 dim: 512 dim: 512 } }
}

layer {
    name: "conv1"
    bottom: "data"
    top: "conv1"
    type: "Convolution"
    convolution_param {
        num_output: 1
        kernel_size: 1
        pad: 0
        stride: 1
        bias_term: false
        weight_filler {
        type: "constant"
        value: 1
        }
    }
}

我正在尝试使用这个公式将RGB转换为灰度图像的网络:

x = 0.299r + 0.587g + 0.114b.

基本上,我可以使用自定义权重(0.299、0.587、0.114)对大小为1的内核进行卷积。但我不知道如何修改卷积层。我已经设置了权重和偏差,但无法修改过滤器值。 我尝试了下面的方法,但它无法更新卷积过滤器。
shared_ptr<Net<float> > net_;
net_.reset(new Net<float>("path of model file", TEST));

const shared_ptr<Blob<float> >& conv_blob = net_->blob_by_name("conv1");
float* conv_weight = conv_blob->mutable_cpu_data();
conv_weight[0] =  0.299;
conv_weight[1] =  0.587;
conv_weight[2] =  0.114;

net_->Forward();

//for dumping the output
const shared_ptr<Blob<float> >& probs = net_->blob_by_name("conv1");
const float* probs_out = probs->cpu_data();

cv::Mat matout(height, width, CV_32F);

for (size_t i = 0; i < height; i++)
{
    for (size_t j = 0; j < width; j++)
    {
        matout.at<float>(i, j) = probs_out[i* width + j];
    }

}
matout.convertTo(matout, CV_8UC1);
cv::imwrite("gray.bmp", matout);

在Python中,我发现自定义卷积滤波器更容易,但我需要C++的解决方案。


1
我正在寻找相同的解决方案。 - Ankit Dixit
我不明白你的解决方案有什么问题?当没有权重(“冷启动”)时,weight_fillrer 才会被使用。如果你现在保存网络,它不会保留正确的权重吗? - Shai
@shai... 是的....权重没有更新!! - AnkitSahu
1个回答

2

只需在您的C++代码中进行小改动:

// access the convolution layer by its name
const shared_ptr<Layer<float> >& conv_layer = net_->layer_by_name("conv1");
// access the layer's blob that stores weights
shared_ptr<Blob<float> >& weight = conv_layer->blobs()[0];
float* conv_weight = weight->mutable_cpu_data();
conv_weight[0] =  0.299;
conv_weight[1] =  0.587;
conv_weight[2] =  0.114;

实际上,"conv1"指的是代码中卷积层输出的blob,而不是包含权重的blob。函数Net<Dtype>::blob_by_name(const string& blob_name)的作用是返回存储网络层之间中间结果的blob

1
@AnkitSahu 请再试一次最新的代码,我无法运行它,需要你尝试几次。:) - Dale
1
@AnkitSahu 你能打印出我回答中weight blob的形状吗?同时确保你已经将输入图像馈送到net_中。 - Dale
@ Dale...向你致敬 :) (y) - AnkitSahu

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接