Theano(Python):逐元素梯度

4

我试图对元素进行梯度操作,例如:

输出-f(x):5×1向量,

关于输入X:5×1向量

我可以这样做:

 import theano
 import theano.tensor as T

 X = T.vector('X')   

 f = X*3    

 [rfrx, []] = theano.scan(lambda j, f,X : T.grad(f[j], X), sequences=T.arange(X.shape[0]), non_sequences=[f,X])

 fcn_rfrx = theano.function([X], rfrx)

 fcn_rfrx(np.ones(5,).astype(float32))

结果如下:

array([[ 3.,  0.,  0.,  0.,  0.],
       [ 0.,  3.,  0.,  0.,  0.],
       [ 0.,  0.,  3.,  0.,  0.],
       [ 0.,  0.,  0.,  3.,  0.],
       [ 0.,  0.,  0.,  0.,  3.]], dtype=float32)

但是因为它不够高效,我希望得到一个5乘1的向量作为结果。

通过这样做...

 [rfrx, []] = theano.scan(lambda j, f,X : T.grad(f[j], X[j]), sequences=T.arange(X.shape[0]), non_sequences=[f,X])

这个(代码/方案)无法正常工作。

有没有办法实现这个功能呢?(很抱歉格式不好,我是新手在学习)


(我添加了一个更明确的例子):

已知输入向量:x[1]、x[2]、…、x[n]

和输出向量:y[1]、y[2]、…、y[n]

其中 y[i] = f(x[i])。

我只想要 df(x[i])/dx[i] 的结果,而不需要 (i<>j) 时的 df(x[i])/dx[j]。

这样可以提高计算效率(n 表示数据量大于 10000)。


你可以直接对输出进行求和:f.sum(),并计算相对于X的梯度吗? - Alleo
@Alleo,我增加了一個範例以免混淆。這是否意味著 Theano 會自動避免計算 (i<>j) 的 df(x[i])/dx[j]? - eunsuk jung
鉴于您的额外公式(y[i] = f(x[i])),在进行求和并计算梯度时不会出现问题。Theano应该减少不必要的计算。 - Alleo
好的,我已经检查过它在我的情况下以可接受的速度运行。 - eunsuk jung
1个回答

3
你正在寻找 theano.tensor.jacobian
import theano
import theano.tensor as T

x = T.fvector()
p = T.as_tensor_variable([(x ** i).sum() for i in range(5)])

j = T.jacobian(p, x)

f = theano.function([x], [p, j])

现在评估产量
In [31]: f([1., 2., 3.])
Out[31]: 
[array([  3.,   6.,  14.,  36.,  98.], dtype=float32),
 array([[   0.,    0.,    0.],
        [   1.,    1.,    1.],
        [   2.,    4.,    6.],
        [   3.,   12.,   27.],
        [   4.,   32.,  108.]], dtype=float32)]

如果您只对某些偏导数感兴趣,也可以仅获取它们。为了能够看到这样做的效率提升有多大,需要仔细查看Theano优化规则(基准测试是第一次测试)。可能已经索引到梯度,使Theano明确知道它不需要计算其余部分。

x = T.fscalar()
y = T.fvector()
z = T.concatenate([x.reshape((1,)), y.reshape((-1,))])

e = (z ** 2).sum()
g = T.grad(e, wrt=x)

ff = theano.function([x, y], [e, g])

我想我的问题有点混淆,我添加了一个例子。我的意思是,我不想针对输入和输出具有相同向量大小的不必要元素执行梯度。 - eunsuk jung
那么这是否意味着,当我们使用像'f.sum()'这样的方法时,Theano会自动高效地不尝试计算(i<>j)的情况下的'df(x[i])/dx[j]'? - eunsuk jung
我更新了我的回答。这是否符合你的问题方向? - eickenberg
是的,我已经检查过了,使用这种方式('f.sum()')对于我的情况来说速度足够快,没有更深入的研究。 - eunsuk jung

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接