最近我在尝试使用TensorFlow (TF),遇到一个问题:假设我想计算函数
的值和梯度。其中x的索引不同,但均指向同一向量 ,而J是随机常数(在物理学中这是自旋玻璃模型)。相对于 的梯度简单地为
因此,f
对N^3个项求和,而gradf
对N^2项求和N次。我通过生成所有求和的项作为秩为3的张量,并在所有条目上进行汇总约减来实现f
。然后,为了进行微分,我应用了以下操作:
tf.gradients(f, xk)[0]
其中f是损失函数,xk是一个变量。以下是一个MWE,假设所有的J都为1。
import numpy as np
import tensorflow as tf
#first I define the variable
n=10 #size of x
x1 = tf.Variable(tf.zeros([n], dtype='float64'))
x2 = tf.placeholder(tf.float64, shape=[n])
#here I define the cost function
f_tensor = tf.mul(tf.mul(tf.reshape(x1, [n]),
tf.reshape(x2, [n,1])),
tf.reshape(x2, [n,1,1]))
f = tf.reduce_sum(f_tensor)
session = tf.Session()
init = tf.initialize_all_variables()
session.run(init)
#run on test array
xtest = np.ones(n)
res = session.run([f, tf.gradients(f, x1)[0]],
feed_dict={x1 : xtest,
x2 : xtest})
assert res[0] == 1000
assert all(res[1] == np.array([100 for _ in xrange(n)]))
我需要独立多次调用run
方法,并希望将变量赋值的次数减少到只有一次,因为x1和x2指向同一个向量。
对于n=200
的相关示例进行了一些分析(在GeForce GTX 650上),结果显示:
- cuMemcpyDtoHAsync 占用63%的时间
- cuMemcpyHtoDAsync 占用18%,以及
- cuEventRecord 占用18%。
(这个MWE的结果类似)
因此,在GPU上执行计算时,赋值操作是最昂贵的操作。显然,随着n
的增加,开销会变得更糟,从而部分抵消使用GPU的好处。
你有什么建议可以减少传输X的开销吗?
同时,如果您有任何其他减少开销的建议,那将不胜感激。
编辑
为了展示问题,我将按照的建议替换所有x2实例为x1,则MWE如下所示
#first I define the variable
n=10 #size of x
x1 = tf.Variable(tf.zeros([n], dtype='float64'))
#here I define the cost function
f_tensor = tf.mul(tf.mul(tf.reshape(x1, [n]),
tf.reshape(x1, [n,1])),
tf.reshape(x1, [n,1,1]))
f = tf.reduce_sum(f_tensor)
session = tf.Session()
init = tf.initialize_all_variables()
session.run(init)
#run on test array
xtest = np.ones(n)
session.run(x1.assign(xtest))
res = session.run([f, tf.gradients(f, x1)[0]])
assert res[0] == 1000
for g in res[1]:
assert g == 100
第二个断言会失败,因为梯度的每个条目都应该是100而不是300。原因是xi,xj,xk都引用同一个向量,但它们在符号上是不同的:如果将所有x替换为相同的变量,则会得到x^3的导数(即3*x^2),这就是第二个MWE的结果。
P.S. 为了清晰起见,我还明确地分配了x1。
x1
和x2
提供相同的向量,是否需要定义两个单独的张量?例如,如果删除x2
的定义并将所有引用x2
的地方替换为x1
,我认为您的程序将具有相同的语义。 - mrryx1.assign(…)
),那么程序中为什么会有一个tf.Variable
是不清楚的。) - mrryGradientDescentOptimize
来最小化损失函数,那么这样做是行不通的,因为为了使其起作用,我必须以相同的方式更新x1和x2,但只对x1进行导数运算。我们回到了最初的问题。 - stefano