如何通过TensorFlowInferenceInterface.java来填充布尔占位符?

4
我正在尝试通过Java Tensorflow API启动用Keras Tensorflow训练的图。除了标准输入图像占位符外,这个图包含一个'keras_learning_phase'占位符,需要用一个布尔值来填充。然而,在TensorFlowInferenceInterface中没有提供给boolean值的方法,你只能使用floatdoubleintbyte值进行填充。
很明显,当我尝试通过以下代码向该张量传递int时:
inferenceInterface.fillNodeInt("keras_learning_phase",  
                               new int[]{1}, new int[]{0});

我得到了

tensorflow_inference_jni.cc:207 推理过程中发生错误:内部错误: 与节点 _recv_keras_learning_phase_0 相关的类型为 int32 的输出与已声明的 bool 类型的输出不匹配, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=4742451733276497694, tensor_name="keras_learning_phase", tensor_type=DT_BOOL, _device="/job:localhost/replica:0/task:0/cpu:0"

有没有办法规避这个问题?
也许可以将图中的 Placeholder 节点显式转换为 Constant ,这样做可以解决问题吗?
或者最初就避免在图中创建此 Placeholder 节点?


请在您遇到问题的地方发布您的代码。 - Srihari
我也遇到了同样的问题,下面的答案解决了输入流问题,但在Android上,我又遇到了另一个错误:没有注册OpKernel来支持Op 'Switch',有一份文档提到了这个问题,https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms/#fixing-missing-kernel-errors-on-mobile,但我不知道该怎么做,我应该重新构建自定义的TensorFlow库吗?你是否也遇到了同样的问题,你是如何解决的?谢谢! - Piasy
@Piasy,我没有完全相同的问题,但有几次遇到了类似的情况。对我来说,更新tensorflow库到当前版本并重新构建.so库有所帮助。 - Dmitry Tochilkin
2个回答

5

TensorFlowInferenceInterface类本质上是对完整的TensorFlow Java API的方便包装,该API支持布尔值。

您可以向TensorFlowInferenceInterface添加一个方法来完成您想要的操作。类似于fillNodeInt,您可以添加以下内容(请注意,TensorFlow中的布尔值表示为一个字节):

public void fillNodeBool(String inputName, int[] dims, bool[] src) {
  byte[] b = new byte[src.length];
  for (int i = 0; i < src.length; ++i) {
    b[i] = src[i] ? 1 : 0;
  }
  addFeed(inputName, Tensor.create(DatType.BOOL, mkDims(dims), ByteBuffer.wrap(b)));
}

希望这有所帮助。如果它起作用了,我鼓励您向TensorFlow代码库做出贡献。

1
谢谢,@ash。我已经将这个方法添加到推理接口的代码中,重新构建了_libandroid_tensorflow_inference_java.jar_,并且它可以工作了。稍后我会将其提交到Tensorflow代码库中。 - Dmitry Tochilkin

0

这是在answerash的基础上的补充,因为Tensorflow API已经有了一些变化。对我来说使用这个方法起作用了:

public void feed(String inputName, boolean[] src, long... dims) {
  byte[] b = new byte[src.length];
  for (int i = 0; i < src.length; i++) {
    b[i] = src[i] ? (byte) 1 : (byte) 0;
  }
  addFeed(inputName, Tensor.create(Boolean.class, dims, ByteBuffer.wrap(b)));
}

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接