实际上,您可以像以前一样在 model_fn 函数中实现多 GPU 训练。您可以在
这里找到完整的代码。使用 Estimator 进行训练时,它同时支持多线程队列读取器和多 GPU,可以实现非常高速的训练。
代码片段(完整代码见上文链接):
def model_fn(features, labels, mode, params):
    """Estimator ``model_fn`` that runs the network data-parallel over GPUs.

    In PREDICT mode the network is built once on the default device.

    In TRAIN and EVAL modes the batch is split along axis 0 into one shard
    per entry of ``params['gpus_list']``. Each shard runs through its own
    tower on device ``/gpu:<i>``, and all towers share variables. In TRAIN
    mode the per-tower gradients are combined with ``average_gradients``
    and applied once. An exponential moving average (decay 0.9999) of the
    trainable variables is also updated as part of the train op.

    Args:
        features: Input batch tensor; split along axis 0 across towers.
        labels: Label tensor with the same leading dimension as
            ``features``. It is fed to ``tf.losses.softmax_cross_entropy``
            and ``tf.argmax(labels, 1)``, so presumably one-hot encoded --
            TODO confirm against the input_fn.
        mode: A ``tf.estimator.ModeKeys`` value.
        params: Dict providing ``'num_classes'`` (forwarded to the network
            factory) and ``'gpus_list'``. Only ``len(params['gpus_list'])``
            is used: one tower per entry.

    Returns:
        A ``tf.estimator.EstimatorSpec``. In PREDICT mode it carries
        ``{"output": logits}``. In TRAIN/EVAL mode it carries the mean of
        the per-tower losses (cross-entropy plus regularization losses) and
        an ``"acc"`` accuracy metric. In TRAIN mode it also carries the
        train op.
    """
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=params['num_classes'],
        weight_decay=0.00004,
        is_training=is_training)

    if mode == tf.estimator.ModeKeys.PREDICT:
        logits, _ = network_fn(features)
        return tf.estimator.EstimatorSpec(mode=mode, predictions={"output": logits})

    # Optimizer and step bookkeeping are only needed when actually training;
    # EVAL ignores train_op, so don't build those ops into the eval graph.
    global_step = None
    optimizer = None
    if is_training:
        global_step = tf.train.get_global_step()
        learning_rate = get_learning_rate("exponential", FLAGS.base_lr,
                                          global_step, decay_steps=10000)
        optimizer = get_optimizer(FLAGS.optimizer, learning_rate)

    num_gpus = len(params['gpus_list'])

    # Equal-sized shards; the remainder of the division goes to the last tower.
    batch_size = tf.shape(features)[0]
    split_size = batch_size // num_gpus
    splits = [split_size] * (num_gpus - 1)
    splits.append(batch_size - split_size * (num_gpus - 1))
    features_split = tf.split(features, splits, axis=0)
    labels_split = tf.split(labels, splits, axis=0)

    tower_losses = []
    tower_grads = []
    eval_logits = []
    with tf.variable_scope(tf.get_variable_scope()):
        for i in range(num_gpus):
            # NOTE(review): the device index is the tower index, not the value
            # stored in gpus_list -- confirm this matches how GPUs are exposed.
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('%s_%d' % ("classification", i)) as scope:
                    logits, _ = network_fn(features_split[i])
                    tf.losses.softmax_cross_entropy(labels_split[i], logits)

                    # Update ops registered under this tower's name scope
                    # (e.g. batch-norm statistics) must run before the loss.
                    update_ops = tf.get_collection(
                        tf.GraphKeys.UPDATE_OPS, scope)
                    updates_op = tf.group(*update_ops)
                    with tf.control_dependencies([updates_op]):
                        losses = tf.get_collection(tf.GraphKeys.LOSSES, scope)
                        # Include the weight-decay terms; otherwise the
                        # weight_decay passed to the network has no effect.
                        # They are added to every tower, and the tower
                        # gradients are averaged, so the regularization
                        # gradient is counted once overall.
                        reg_losses = tf.losses.get_regularization_losses()
                        tower_loss = tf.add_n(losses + reg_losses,
                                              name='total_loss')

                    # Towers after the first reuse the variables created here.
                    tf.get_variable_scope().reuse_variables()

                    if is_training:
                        tower_grads.append(
                            optimizer.compute_gradients(tower_loss))
                    tower_losses.append(tower_loss)
                    eval_logits.append(logits)

    # Report the mean over all towers instead of only the last tower's loss.
    total_loss = tf.reduce_mean(tf.stack(tower_losses))

    train_op = None
    if is_training:
        grads = average_gradients(tower_grads)
        apply_gradient_op = optimizer.apply_gradients(
            grads, global_step=global_step)
        variable_averages = tf.train.ExponentialMovingAverage(
            0.9999, global_step)
        variables_averages_op = variable_averages.apply(
            tf.trainable_variables())
        train_op = tf.group(apply_gradient_op, variables_averages_op)

    # tf.split preserves order, so the concatenated tower logits line up
    # with the unsplit labels.
    predicted_classes = tf.argmax(tf.concat(eval_logits, 0), 1)
    true_classes = tf.argmax(labels, 1)
    eval_metric_ops = {
        "acc": slim.metrics.streaming_accuracy(predicted_classes, true_classes)}

    return tf.estimator.EstimatorSpec(
        mode=mode,
        loss=total_loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops)