Julia: Zygote.@adjoint与Enzyme.autodiff有什么不同?

3

给定如下函数f!

function f!(s::Vector, a::Vector, b::Vector)
  
  s .= a .+ b
  return nothing

end # f!

我该如何基于 Zygote 定义伴随,使用的方法是 Enzyme.autodiff(f!, Const, Duplicated(s, dz_ds). Duplicated(a, zero(a)), Duplicated(b, zero(b)))
Zygote.@adjoint f!(s, a, b) = f!(s, a, b), # What would come here ?

我非常推荐在Julia Slack的#autodiff频道上询问这种专业问题。编写这些库的所有人都经常访问它,比Stack Overflow更多,我想。 - phipsgabler
感谢@phipsgabler,我将分享我找到的解决方案。但下次我会在Julia Slack频道上提问。 - luciano-drozda
1个回答

4
可以想出一种方法,这里分享一下。
对于给定的函数fooZygote.pullback(foo, args...)返回foo(args...)和反向传递(允许梯度计算)。
我的目标是告诉Zygote在反向传递中使用Enzyme
可以通过Zygote.@adjoint来实现这一点(详情见此处)。
对于数组值函数,Enzyme需要一个返回nothing和其结果在args中的可变版本(详情见此处)。
问题帖子中的函数f!是两个数组和的Enzyme兼容版本。
由于f!返回nothing,所以当我们调用反向传递时,Zygote会简单地返回nothing
解决方案是将f!放在一个包装器中(比如f),使其返回数组s
并为f定义Zygote.@adjoint,而不是f!
因此,
function f(a::Vector, b::Vector)

  s = zero(a)
  f!(s, a, b)
  return s

end

function enzyme_back(dzds, a, b)

  s    = zero(a)
  dzda = zero(dzds)
  dzdb = zero(dzds)
  Enzyme.autodiff(
    f!,
    Const,
    Duplicated(s, dzds),
    Duplicated(a, dzda),
    Duplicated(b, dzdb)
  )
  return (dzda, dzdb)

end

并且

Zygote.@adjoint f(a, b) = f(a, b), dzds -> enzyme_back(dzds, a, b)

通知 Zygote 在反向传播中使用 Enzyme


最后,您可以检查在哪里调用 Zygote.gradient

g1(a::Vector, b::Vector) = sum(abs2, a + b)

或者
g2(a::Vector, b::Vector) = sum(abs2, f(a, b))

产生相同的结果。

1
唯一的问题是Enzyme将再次执行前向传递。理想情况下,我们应该使用split ABI,但目前Enzyme.jl还没有提供公共API。 - vchuravy

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接