如何使用Flux.jl绘制一个函数及其梯度/导数

4

我想使用Flux.jlPlots.jl绘制一个函数及其梯度。

using Flux.Tracker
using Plots

f(x::Float64) = 3x^2 + 2x + 1
df(x::Float64) = Tracker.gradient(f, x)[1]
d2f(x::Float64) = Tracker.gradient(df, x)[1]

plot([f], -2, 2)
plot!([df], -2, 2)

I get:

ERROR: LoadError: MethodError: no method matching Float64(::Flux.Tracker.TrackedReal{Float64})
Closest candidates are:
  Float64(::Real, ::RoundingMode) where T<:AbstractFloat at rounding.jl:194
  Float64(::T<:Number) where T<:Number at boot.jl:741
  Float64(::Int8) at float.jl:60

所以我想的是将 Flux.Tracker.TrackedReal{Float64} 转换为 Float64。我该怎么做呢?
1个回答

4
您可以使用以下内容(适用于Flux 0.8.3):
f(x::Float64) = 3x^2 + 2x + 1
df(x::Float64) = Tracker.data(Tracker.gradient(f, x, nest=true)[1])
d2f(x::Float64) = Tracker.data(Tracker.gradient(df, x, nest=true)[1])

非常感谢您的回答@BogumilKaminski。为了尝试您的答案,我应该更新到Flux 0.8.3,但我尝试过并没有成功(尽管重新安装和更新了软件包),因此在这里发布了另一个问题:https://dev59.com/ouk5XIcBkEYKwwoY49on - ecjb

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接