朱莉娅语言:函数类型和性能

9

在Julia中,是否有一种方式可以将以下模式进行通用化?

function compute_sum(xs::Vector{Float64})
    res = 0
    for i in 1:length(xs)
        res += sqrt(xs[i])
    end
    res
end

这会计算每个向量元素的平方根,然后将所有结果相加。与使用数组推导或 map 的“朴素”版本相比,它要快得多,也不需要额外分配内存:

xs = rand(1000)

julia> @time compute_sum(xs)
  0.000004 seconds
676.8372556762225

julia> @time sum([sqrt(x) for x in xs])
  0.000013 seconds (3 allocations: 7.969 KiB)
676.837255676223

julia> @time sum(map(sqrt, xs))
  0.000013 seconds (3 allocations: 7.969 KiB)
676.837255676223

很不幸,“显而易见”的通用版本在性能方面很糟糕:

function compute_sum2(xs::Vector{Float64}, fn::Function)
    res = 0
    for i in 1:length(xs)
        res += fn(xs[i])
    end
    res
end

julia> @time compute_sum2(xs, x -> sqrt(x))
  0.013537 seconds (19.34 k allocations: 1.011 MiB)
676.8372556762225
4个回答

8

原因在于每次调用compute_sum2时,x -> sqrt(x) 都会被定义为一个新的匿名函数,这就导致了每次都需要重新编译。

如果你事先像这样定义它:

julia> f = x -> sqrt(x)

那么你就拥有了:

julia> @time compute_sum2(xs, f) # here you pay compilation cost
  0.010053 seconds (19.46 k allocations: 1.064 MiB)
665.2469135020949

julia> @time compute_sum2(xs, f) # here you have already compiled everything
  0.000003 seconds (1 allocation: 16 bytes)
665.2469135020949

请注意,一种自然的方法是定义一个名为这样的函数:
julia> g(x) = sqrt(x)
g (generic function with 1 method)

julia> @time compute_sum2(xs, g)
  0.000002 seconds
665.2469135020949

您可以看到,当您编写例如以下代码时:x -> sqrt(x)每次遇到它都会定义一个新的匿名函数。
julia> typeof(x -> sqrt(x))
var"#3#4"

julia> typeof(x -> sqrt(x))
var"#5#6"

julia> typeof(x -> sqrt(x))
var"#7#8"

请注意,如果在函数体中定义匿名函数,则情况将不同:
julia> h() = typeof(x -> sqrt(x))
h (generic function with 2 methods)

julia> h()
var"#11#12"

julia> h()
var"#11#12"

julia> h()
var"#11#12"

你会发现这次匿名函数每次都是相同的。


1
谢谢!我以为是因为 fn 类型不明确,从而导致类型不稳定,但实际上是编译时间的问题。 - cno

7

除了Bogumil的出色回答,我想补充一点,一个非常方便的泛化方法是使用普通的函数式编程函数,例如mapreducefold等。

在这种情况下,你正在进行一个map转换(即sqrt)和一个reduce(即+),所以你也可以用mapreduce(sqrt, +, xs)来实现结果。这基本上没有任何开销,并且在性能上与手动循环相当。

如果你有一个非常复杂的转换系列,你可以获得最佳的性能并仍然使用一个函数,使用Transducers.jl包。


知道这个真是太好了!我其实一开始尝试使用map,但是性能很差让我感到沮丧。我不知道还有mapreduce这个选项。 - cno

3

Bogumił已经回答了有关函数类型的部分。 我想指出,如果适当地进行基准测试,您的实现已经以最高效的方式运行,但可以用等效的内置函数替换:

julia> @btime compute_sum($xs)
  2.149 μs (0 allocations: 0 bytes)
661.6571623823567

julia> @btime sum(sqrt, $xs)
  2.149 μs (0 allocations: 0 bytes)
661.6571623823567

julia> @btime compute_sum2($xs, sqrt)
  2.149 μs (0 allocations: 0 bytes)
661.6571623823567

julia> @btime mapreduce(sqrt, +, $xs)
  2.149 μs (0 allocations: 0 bytes)
661.6571623823567

如果可能的话,最好使用等效于eta的非lambda函数: f 而不是 x -> f(x)。特别是对于内置函数,因为它们有时会进行分派。


1
其他回答已经很全面了,但我想指出你可以省略sum([sqrt(x) for x in xs])中的方括号来获得最快的版本:
julia> using BenchmarkTools

julia> @btime compute_sum($xs)
  1.779 μs (0 allocations: 0 bytes)
679.0943275393031

julia> @btime sum([sqrt(x) for x in $xs])
  1.626 μs (1 allocation: 7.94 KiB)
679.0943275393028

julia> @btime sum(map(sqrt, $xs))
  1.628 μs (1 allocation: 7.94 KiB)
679.0943275393028

julia> @btime sum(sqrt(x) for x in $xs)
  1.337 μs (0 allocations: 0 bytes)
679.0943275393031

请注意,在我的计算机上,使用Julia主程序计算compute_sum是最慢的方法,而不是最快的方法来对这些数字求和。

1
我认为生成器不应该比sum(sqrt, xs)更快,但它们应该是相同的 - 对吗? - Bogumił Kamiński
1
当然,为了完整起见,可以包括生成器版本:您可以在不需要中间分配的情况下获得漂亮的理解风格。 - StefanKarpinski

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接