在Android上运行TensorFlow图时遇到输出维度问题

4
我已经将我的tensorflow图输出到Android,并尝试运行它。我从CSV中输入了一些数据,似乎工作正常,但最终节点的输出是批次x时间x特征维度,而我所能看到的仅有单个数组的输出函数。
我收到的错误信息是:
08-28 10:01:44.162 10602-10602/com.example.rob.android_kds E/TensorFlowInferenceInterface: Failed to run TensorFlow inference with inputs:[the_input], outputs:[output_node0]
08-28 10:01:44.162 10602-10602/com.example.rob.android_kds E/TensorFlowInferenceInterface: Inference exception: java.lang.IllegalArgumentException: Input shape axis 0 must equal 3, got shape [1]
                                                                                               [[Node: fc1/unstack = Unpack[T=DT_INT32, axis=0, num=3, _device="/job:localhost/replica:0/task:0/cpu:0"](fc1/Shape)]]
08-28 10:01:44.162 10602-10602/com.example.rob.android_kds I/System.out: readOutput
08-28 10:01:44.172 10602-10602/com.example.rob.android_kds E/AndroidRuntime: FATAL EXCEPTION: main
                                                                            Process: com.example.rob.android_kds, PID: 10602
                                                                            java.lang.IndexOutOfBoundsException: Invalid index 0, size is 0
                                                                                at java.util.ArrayList.throwIndexOutOfBoundsException(ArrayList.java:255)
                                                                                at java.util.ArrayList.get(ArrayList.java:308)
                                                                                at org.tensorflow.contrib.android.TensorFlowInferenceInterface.getTensor(TensorFlowInferenceInterface.java:486)
                                                                                at org.tensorflow.contrib.android.TensorFlowInferenceInterface.readNodeIntoFloatBuffer(TensorFlowInferenceInterface.java:332)
                                                                                at org.tensorflow.contrib.android.TensorFlowInferenceInterface.readNodeFloat(TensorFlowInferenceInterface.java:287)
                                                                                at com.example.rob.android_kds.MainActivity$1.onClick(MainActivity.java:171)
                                                                                at android.view.View.performClick(View.java:5697)
                                                                                at android.view.View$PerformClick.run(View.java:22526)
                                                                                at android.os.Handler.handleCallback(Handler.java:739)
                                                                                at android.os.Handler.dispatchMessage(Handler.java:95)
                                                                                at android.os.Looper.loop(Looper.java:158)
                                                                                at android.app.ActivityThread.main(ActivityThread.java:7225)
                                                                                at java.lang.reflect.Method.invoke(Native Method)
                                                                                at com.android.internal.os.ZygoteInit$MethodAndArgsCaller.run(ZygoteInit.java:1230)
                                                                                at com.android.internal.os.ZygoteInit.main(ZygoteInit.java:1120)

这是我的代码片段:

               // Copy the input data into TensorFlow.
               System.out.println("inputNode");
               Trace.beginSection("fillNodeFloat");
               //input is 3x234x26 and array is a unravelled arr = 18252
               tensorflow.fillNodeFloat(
                       "the_input", new int[]{3 * 234 * 26}, arr);
               Trace.endSection();

               // Run the inference call.
               System.out.println("runInference");
               Trace.beginSection("runInference");
               String outputNode = "output_node0";
               String[] outputNodes = {outputNode};
               tensorflow.runInference(outputNodes);
               Trace.endSection();

               // Copy the output Tensor back into the output array.
               System.out.println("readOutput");
               Trace.beginSection("readNodeFloat");
               //output should be batchxtimex29 (3 x 234 x 29) = 20358 flattened array
               float[] output=new float[20358];

               tensorflow.readNodeFloat(outputNode, output); // ERROR HERE
               Trace.endSection();

需要帮助(完整代码在此处https://github.com/mlrobsmt/kds2Droid),谢谢。

2个回答

0

根据您的错误信息,它说“输入形状轴0必须等于3,得到形状[1]”,以及您代码中定义输入数组的这一行“float[] arr=new float[18252];”,我不确定但希望您异常的原因是您的输入没有适当的形状。实际上,我认为您的输入应该是一个三维数组而不是一个向量。


我同意Pooyan的观点,然而从TensorFlow输入选项来看,似乎没有其他方法可以将3D数组作为输入,除非将其展开并使用tensorflow.fillNodeFloat("the_input", new int[]{3 * 234 * 26}, arr)。 - robmsmt

0

我认为问题出在输出维度上。在TensorFlow Android中,它只接受1D数组作为输出。因此,您需要确保在您的TensorFlow模型中仅输出1D数组。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接