我有一台配备了多个GPU的服务器,希望在Java应用程序内进行模型推断时充分利用它们。 默认情况下,TensorFlow会占用所有可用的GPU,但只使用第一个GPU。
我能想到三个选项来解决这个问题:
在进程级别上限制设备可见性,即使用
CUDA_VISIBLE_DEVICES
环境变量。这将要求我运行几个Java应用程序实例并在它们之间分配流量。 不是那么诱人的想法。
在单个应用程序中启动多个会话,并尝试通过
ConfigProto
将一个设备分配给每个会话:public class DistributedPredictor { private Predictor[] nested; private int[] counters; // ... public DistributedPredictor(String modelPath, int numDevices, int numThreadsPerDevice) { nested = new Predictor[numDevices]; counters = new int[numDevices]; for (int i = 0; i < nested.length; i++) { nested[i] = new Predictor(modelPath, i, numDevices, numThreadsPerDevice); } } public Prediction predict(Data data) { int i = acquirePredictorIndex(); Prediction result = nested[i].predict(data); releasePredictorIndex(i); return result; } private synchronized int acquirePredictorIndex() { int i = argmin(counters); counters[i] += 1; return i; } private synchronized void releasePredictorIndex(int i) { counters[i] -= 1; } } public class Predictor { private Session session; public Predictor(String modelPath, int deviceIdx, int numDevices, int numThreadsPerDevice) { GPUOptions gpuOptions = GPUOptions.newBuilder() .setVisibleDeviceList("" + deviceIdx) .setAllowGrowth(true) .build(); ConfigProto config = ConfigProto.newBuilder() .setGpuOptions(gpuOptions) .setInterOpParallelismThreads(numDevices * numThreadsPerDevice) .build(); byte[] graphDef = Files.readAllBytes(Paths.get(modelPath)); Graph graph = new Graph(); graph.importGraphDef(graphDef); this.session = new Session(graph, config.toByteArray()); } public Prediction predict(Data data) { // ... } }
这种方法乍一看似乎很有效,但有时会忽略
setVisibleDeviceList
选项并全部使用第一个设备,导致内存不足崩溃。使用
tf.device()
规范在Python中以多塔的方式构建模型。在Java端,为不同的Predictor
在共享会话中分配不同的塔。对我来说感觉笨重且用法不正确。
更新:正如@ash提出的那样,还有另一种选择:
通过修改其定义(
graphDef
),为现有图中的每个操作分配适当的设备。要完成此操作,可以根据第2种方法中的代码进行调整:
public class Predictor { private Session session; public Predictor(String modelPath, int deviceIdx, int numDevices, int numThreadsPerDevice) { byte[] graphDef = Files.readAllBytes(Paths.get(modelPath)); graphDef = setGraphDefDevice(graphDef, deviceIdx) Graph graph = new Graph(); graph.importGraphDef(graphDef); ConfigProto config = ConfigProto.newBuilder() .setAllowSoftPlacement(true) .build(); this.session = new Session(graph, config.toByteArray()); } private static byte[] setGraphDefDevice(byte[] graphDef, int deviceIdx) throws InvalidProtocolBufferException { String deviceString = String.format("/gpu:%d", deviceIdx); GraphDef.Builder builder = GraphDef.parseFrom(graphDef).toBuilder(); for (int i = 0; i < builder.getNodeCount(); i++) { builder.getNodeBuilder(i).setDevice(deviceString); } return builder.build().toByteArray(); } public Prediction predict(Data data) { // ... } }
就像其他提到的方法一样,这种方法并不能使我从手动分配数据给不同设备中获得解放。但至少它工作稳定,实现起来相对容易。总的来说,这看起来像是一种(几乎)正常的技术。
在tensorflow java API中有一种优雅的方法来做这样基本的事情吗?欢迎任何想法。