为什么tensorflow的map_fn比python的for循环慢?

4
我希望对一个形状为 [batch, n_pts, 3] 的张量中的每个批次的点应用独立的旋转,我实现了两种不同的方法。第一种方法是将张量转换为numpy数组并使用基本的python for循环。第二种方法使用tensorflow的tf.map_fn()函数来消除for循环。然而,当我运行这段代码时,tensorflow的map_fn()比第一种方法慢100倍左右。
我的问题是:我在这里是否正确地使用了tf.map_fn()函数?什么情况下您会预期使用tf.map_fn()的性能提高而不是标准的numpy/python呢?
如果我使用得正确,那么我想知道为什么tensorflow的tf.map_fn()如此缓慢。
重现实验的代码如下:
import time
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K


def rotate_tf(pc):

    theta = tf.random.uniform((1,)) * 2.0 * np.pi
    cosval = tf.math.cos(theta)[0]
    sinval = tf.math.sin(theta)[0]

    R = tf.Variable([
        [cosval, -sinval, 0.0],
        [sinval, cosval, 0.0],
        [0.0, 0.0, 1.0]
    ])

    def dot(p):
        return K.dot(R, tf.expand_dims(p, axis=-1))

    return tf.squeeze(tf.map_fn(dot, pc))


def rotate_np(pc):

    theta = np.random.uniform() * 2.0 * np.pi
    cosval = np.cos(theta)
    sinval = np.sin(theta)

    R = np.array([
        [cosval, -sinval, 0.0],
        [sinval, cosval, 0.0],
        [0.0, 0.0, 1.0]
    ])

    for idx, p in enumerate(pc):
        pc[idx] = np.dot(R, p)

    return pc


pts = tf.random.uniform((8, 100, 3))
n = 10

# Start tensorflow map_fn() method
start = time.time()

for i in range(n):
    pts = tf.map_fn(rotate_tf, pts)

print('processed tf in: {:.4f} s'.format(time.time()-start))

# Start numpy method
start = time.time()

for i in range(n):

    pts = pts.numpy()
    for i, p in enumerate(pts):
        pts[i] = rotate_np(p)
    pts = tf.Variable(pts)

print('processed np in: {:.4f} s'.format(time.time()-start))

这的输出结果为:
processed tf in: 3.8427 s
processed np in: 0.0314 s
1个回答

0
你使用的计算机资源有哪些? 我猜测TF正在操作更复杂的对象(张量)。
但是你没有充分利用TF!TF有一个图模式。这个功能将加快你的计算速度。
当使用小数组时,numpy的性能优于TF(使用下面的代码)。对于更大的数据,我发现情况正好相反。我做了一些小的修改:
  • 在rotate_tf中使用@tf.function装饰器
  • 移除rotate_tf中的tf.Variable
  • 在for循环中移除numpy/tf的转换,以获取numpy for循环的执行时间
  • 改变n和pts的批量大小
import time
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K


@tf.function(jit_compile=True, reduce_retracing=True)
def rotate_tf(pc):

    theta = tf.random.uniform((1,)) * 2.0 * np.pi
    cosval = tf.math.cos(theta)[0]
    sinval = tf.math.sin(theta)[0]

    R = tf.convert_to_tensor([
        [cosval, -sinval, 0.0],
        [sinval, cosval, 0.0],
        [0.0, 0.0, 1.0]
    ])

    def dot(p):
        return K.dot(R, tf.expand_dims(p, axis=-1))

    return tf.squeeze(tf.map_fn(dot, pc))


def rotate_np(pc):

    theta = np.random.uniform() * 2.0 * np.pi
    cosval = np.cos(theta)
    sinval = np.sin(theta)

    R = np.array([
        [cosval, -sinval, 0.0],
        [sinval, cosval, 0.0],
        [0.0, 0.0, 1.0]
    ])

    for idx, p in enumerate(pc):
        pc[idx] = np.dot(R, p)

    return pc


pts = tf.random.uniform((8, 100, 3))
n = 100

# Start tensorflow map_fn() method
start = time.time()

for i in range(n):
    pts = tf.map_fn(rotate_tf, pts)

print('processed tf in: {:.4f} s'.format(time.time()-start))

# Start numpy method
start = time.time()

pts = pts.numpy()
for i in range(n):

    for i, p in enumerate(pts):
        pts[i] = rotate_np(p)
    
print('processed np in: {:.4f} s'.format(time.time()-start))

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接