在多核设备上运行TensorFlow

9

我有一个基本的Android TensorFlowInference示例,在单线程中运行良好。

public class InferenceExample {

    private static final String MODEL_FILE = "file:///android_asset/model.pb";
    private static final String INPUT_NODE = "intput_node0";
    private static final String OUTPUT_NODE = "output_node0"; 
    private static final int[] INPUT_SIZE = {1, 8000, 1};
    public static final int CHUNK_SIZE = 8000;
    public static final int STRIDE = 4;
    private static final int NUM_OUTPUT_STATES = 5;

    private static TensorFlowInferenceInterface inferenceInterface;

    public InferenceExample(final Context context) {
        inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), MODEL_FILE);
    }

    public float[] run(float[] data) {

        float[] res = new float[CHUNK_SIZE / STRIDE * NUM_OUTPUT_STATES];

        inferenceInterface.feed(INPUT_NODE, data, INPUT_SIZE[0], INPUT_SIZE[1], INPUT_SIZE[2]);
        inferenceInterface.run(new String[]{OUTPUT_NODE});
        inferenceInterface.fetch(OUTPUT_NODE, res);

        return res;
    }
}

以下是示例代码,当在 ThreadPool 中运行时,该示例会崩溃并抛出各种异常,包括 java.lang.ArrayIndexOutOfBoundsExceptionjava.lang.NullPointerException,因此我猜测它不是线程安全的。

InferenceExample inference = new InferenceExample(context);

ExecutorService executor = Executors.newFixedThreadPool(NUMBER_OF_CORES);    
Collection<Future<?>> futures = new LinkedList<Future<?>>();

for (int i = 1; i <= 100; i++) {
    Future<?> result = executor.submit(new Runnable() {
        public void run() {
           inference.call(randomData);
        }
    });
    futures.add(result);
}

for (Future<?> future:futures) {
    try { future.get(); }
    catch(ExecutionException | InterruptedException e) {
        Log.e("TF", e.getMessage());
    }
}

使用 TensorFlowInferenceInterface 可以利用多核 Android 设备吗?

2个回答

1
为了使 InferenceExample 线程安全,我将 TensorFlowInferenceInterfacestatic 改为非静态,并将 run 方法改为 synchronized
private TensorFlowInferenceInterface inferenceInterface;

public InferenceExample(final Context context) {
    inferenceInterface = new TensorFlowInferenceInterface(assets, model);
}

public synchronized float[] run(float[] data) { ... }

然后我对一个InterferenceExample实例列表进行循环,跨越numThreads

for (int i = 1; i <= 100; i++) {
    final int id = i % numThreads;
    Future<?> result = executor.submit(new Runnable() {
        public void run() {
            list.get(id).run(data);
        }
    });
    futures.add(result);
}

这确实提高了性能,但在8核设备上,numThreads为2时,在Android Studio监视器中仅显示约50%的CPU使用率。

{btsdaf} - ash
{btsdaf} - Chris Seymour
哦,抱歉,我看错了,并没有注意到feed()fetch()的调用是在你同步的run()内部。所以,我在上面的评论中是错误的。然而,你的方法会限制并行性,因为它实际上将TensorFlow会话的使用串行化 - 只有一个线程可以同时执行模型。 - ash
{btsdaf} - Chris Seymour
{btsdaf} - ash

0

TensorFlowInferenceInterface类不是线程安全的(因为它在feedrunfetch等调用之间保留状态)。

然而,它是建立在TensorFlow Java API之上的,Session类的对象是线程安全的。

因此,您可能希望直接使用底层的Java API,TensorFlowInferenceInterface的构造函数创建一个Session并使用从AssetManagercode)加载的Graph进行设置。

希望这有所帮助。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接