LINQ性能:Count vs Where和Count

37
public class Group
{
   public string Name { get; set; }
}  

测试:

List<Group> _groups = new List<Group>();

for (int i = 0; i < 10000; i++)
{
    var group = new Group();

    group.Name = i + "asdasdasd";
    _groups.Add(group);
}

Stopwatch _stopwatch2 = new Stopwatch();

_stopwatch2.Start();
foreach (var group in _groups)
{
    var count = _groups.Count(x => x.Name == group.Name);
}
_stopwatch2.Stop();

Console.WriteLine(_stopwatch2.ElapsedMilliseconds);
Stopwatch _stopwatch = new Stopwatch();

_stopwatch.Start();
foreach (var group in _groups)
{
    var count = _groups.Where(x => x.Name == group.Name).Count();
}
_stopwatch.Stop();

Console.WriteLine(_stopwatch.ElapsedMilliseconds);

结果:第一个方法返回2863,第二个方法返回2185。

有人能解释一下为什么第一个方法比第二个方法慢吗?第二个方法应该返回枚举器并在其上调用Count,而第一个方法只是调用Count。第一个方法应该更快一点。

编辑:我删除了计数器列表以防止使用GC,并改变了顺序以检查排序是否有意义。结果几乎相同。

编辑2:这个性能问题不仅与Count有关。它与First(),FirstOrDefault(),Any()等有关。其中+方法总是比Method更快。


13
你的测量不准确——你没有进行热身。 - Display Name
你是如何定义 counters 和 counters2 的?我看到声明它们时使用了初始大小(例如:List<int> counter = new List<int>(500);),这会影响性能。你可能以不同的方式声明了它们吗?其次,在两个循环之间执行 GC.Collect 是否有用 - 可能是 GC 正在启动并扭曲了你的结果(这真的是一个非常长的尝试..) - PhillipH
counters 是什么类型?我猜是 List<bool>,但这并没有太多意义。另外,我认为 Skip(1).Any() 会更快。 - Oliver
我编辑了帖子,没有计数器列表。 - MistyK
1
@Zbigniew 确认一下:这是在发布模式下运行的,对吧?尝试交换这两个测试。这会改变结果吗? - usr
显示剩余9条评论
6个回答

23

关键是在Where()的实现中,如果可以,它将IEnumerable转换为List<T>。请注意,在构造WhereListIterator时进行了转换(这是通过反射获取的.Net源代码):

public static IEnumerable<TSource> Where<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate) {
    if (source is List<TSource>) return new WhereListIterator<TSource>((List<TSource>)source, predicate);
    return new WhereEnumerableIterator<TSource>(source, predicate);
}

我已通过复制 (并在可能的情况下简化) .Net 实现来验证这一点。

关键是,我实现了两个版本的 Count() - 一个叫做 TestCount(),其中我使用 IEnumerable<T>,另一个叫做 TestListCount(),我在计算项目数量之前将可枚举对象强制转换为 List<T>

这给出了与我们看到的 Where() 操作符相同的加速效果,后者 (如上所示) 在可以的情况下也会强制转换为 List<T>

(应该尝试在没有调试器附加的发布构建中运行此操作。)

这证明了使用 foreach 遍历 List<T> 要比以同样的顺序表示的 IEnumerable<T> 更快。

首先,这是完整的测试代码:

using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;

namespace Demo
{
    public class Group
    {
        public string Name
        {
            get;
            set;
        }
    }

    internal static class Program
    {
        static void Main()
        {
            int dummy = 0;
            List<Group> groups = new List<Group>();

            for (int i = 0; i < 10000; i++)
            {
                var group = new Group();

                group.Name = i + "asdasdasd";
                groups.Add(group);
            }

            Stopwatch stopwatch = new Stopwatch();

            for (int outer = 0; outer < 4; ++outer)
            {
                stopwatch.Restart();

                foreach (var group in groups)
                    dummy += TestWhere(groups, x => x.Name == group.Name).Count();

                Console.WriteLine("Using TestWhere(): " + stopwatch.ElapsedMilliseconds);

                stopwatch.Restart();

                foreach (var group in groups)
                    dummy += TestCount(groups, x => x.Name == group.Name);

                Console.WriteLine("Using TestCount(): " + stopwatch.ElapsedMilliseconds);

                stopwatch.Restart();

                foreach (var group in groups)
                    dummy += TestListCount(groups, x => x.Name == group.Name);

                Console.WriteLine("Using TestListCount(): " + stopwatch.ElapsedMilliseconds);
            }

            Console.WriteLine("Total = " + dummy);
        }

        public static int TestCount<TSource>(IEnumerable<TSource> source, Func<TSource, bool> predicate)
        {
            int count = 0;

            foreach (TSource element in source)
            {
                if (predicate(element)) 
                    count++;
            }

            return count;
        }

        public static int TestListCount<TSource>(IEnumerable<TSource> source, Func<TSource, bool> predicate)
        {
            return testListCount((List<TSource>) source, predicate);
        }

        private static int testListCount<TSource>(List<TSource> source, Func<TSource, bool> predicate)
        {
            int count = 0;

            foreach (TSource element in source)
            {
                if (predicate(element))
                    count++;
            }

            return count;
        }

        public static IEnumerable<TSource> TestWhere<TSource>(IEnumerable<TSource> source, Func<TSource, bool> predicate)
        {
            return new WhereListIterator<TSource>((List<TSource>)source, predicate);
        }
    }

    class WhereListIterator<TSource>: Iterator<TSource>
    {
        readonly Func<TSource, bool> predicate;
        List<TSource>.Enumerator enumerator;

        public WhereListIterator(List<TSource> source, Func<TSource, bool> predicate)
        {
            this.predicate = predicate;
            this.enumerator = source.GetEnumerator();
        }

        public override bool MoveNext()
        {
            while (enumerator.MoveNext())
            {
                TSource item = enumerator.Current;
                if (predicate(item))
                {
                    current = item;
                    return true;
                }
            }
            Dispose();

            return false;
        }
    }

    abstract class Iterator<TSource>: IEnumerable<TSource>, IEnumerator<TSource>
    {
        internal TSource current;

        public TSource Current
        {
            get
            {
                return current;
            }
        }

        public virtual void Dispose()
        {
            current = default(TSource);
        }

        public IEnumerator<TSource> GetEnumerator()
        {
            return this;
        }

        public abstract bool MoveNext();

        object IEnumerator.Current
        {
            get
            {
                return Current;
            }
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        void IEnumerator.Reset()
        {
            throw new NotImplementedException();
        }
    }
}

现在这里是两个关键方法的IL生成代码:TestCount()testListCount()。请记住,它们之间唯一的区别是 TestCount() 使用了 IEnumerable<T>,而 testListCount() 则是使用同样的可枚举对象,但将其强制转换为其底层的List<T>类型:

TestCount():

.method public hidebysig static int32 TestCount<TSource>(class [mscorlib]System.Collections.Generic.IEnumerable`1<!!TSource> source, class [mscorlib]System.Func`2<!!TSource, bool> predicate) cil managed
{
    .maxstack 8
    .locals init (
        [0] int32 count,
        [1] !!TSource element,
        [2] class [mscorlib]System.Collections.Generic.IEnumerator`1<!!TSource> CS$5$0000)
    L_0000: ldc.i4.0 
    L_0001: stloc.0 
    L_0002: ldarg.0 
    L_0003: callvirt instance class [mscorlib]System.Collections.Generic.IEnumerator`1<!0> [mscorlib]System.Collections.Generic.IEnumerable`1<!!TSource>::GetEnumerator()
    L_0008: stloc.2 
    L_0009: br L_0025
    L_000e: ldloc.2 
    L_000f: callvirt instance !0 [mscorlib]System.Collections.Generic.IEnumerator`1<!!TSource>::get_Current()
    L_0014: stloc.1 
    L_0015: ldarg.1 
    L_0016: ldloc.1 
    L_0017: callvirt instance !1 [mscorlib]System.Func`2<!!TSource, bool>::Invoke(!0)
    L_001c: brfalse L_0025
    L_0021: ldloc.0 
    L_0022: ldc.i4.1 
    L_0023: add.ovf 
    L_0024: stloc.0 
    L_0025: ldloc.2 
    L_0026: callvirt instance bool [mscorlib]System.Collections.IEnumerator::MoveNext()
    L_002b: brtrue.s L_000e
    L_002d: leave L_003f
    L_0032: ldloc.2 
    L_0033: brfalse L_003e
    L_0038: ldloc.2 
    L_0039: callvirt instance void [mscorlib]System.IDisposable::Dispose()
    L_003e: endfinally 
    L_003f: ldloc.0 
    L_0040: ret 
    .try L_0009 to L_0032 finally handler L_0032 to L_003f
}


testListCount():

.method private hidebysig static int32 testListCount<TSource>(class [mscorlib]System.Collections.Generic.List`1<!!TSource> source, class [mscorlib]System.Func`2<!!TSource, bool> predicate) cil managed
{
    .maxstack 8
    .locals init (
        [0] int32 count,
        [1] !!TSource element,
        [2] valuetype [mscorlib]System.Collections.Generic.List`1/Enumerator<!!TSource> CS$5$0000)
    L_0000: ldc.i4.0 
    L_0001: stloc.0 
    L_0002: ldarg.0 
    L_0003: callvirt instance valuetype [mscorlib]System.Collections.Generic.List`1/Enumerator<!0> [mscorlib]System.Collections.Generic.List`1<!!TSource>::GetEnumerator()
    L_0008: stloc.2 
    L_0009: br L_0026
    L_000e: ldloca.s CS$5$0000
    L_0010: call instance !0 [mscorlib]System.Collections.Generic.List`1/Enumerator<!!TSource>::get_Current()
    L_0015: stloc.1 
    L_0016: ldarg.1 
    L_0017: ldloc.1 
    L_0018: callvirt instance !1 [mscorlib]System.Func`2<!!TSource, bool>::Invoke(!0)
    L_001d: brfalse L_0026
    L_0022: ldloc.0 
    L_0023: ldc.i4.1 
    L_0024: add.ovf 
    L_0025: stloc.0 
    L_0026: ldloca.s CS$5$0000
    L_0028: call instance bool [mscorlib]System.Collections.Generic.List`1/Enumerator<!!TSource>::MoveNext()
    L_002d: brtrue.s L_000e
    L_002f: leave L_0042
    L_0034: ldloca.s CS$5$0000
    L_0036: constrained [mscorlib]System.Collections.Generic.List`1/Enumerator<!!TSource>
    L_003c: callvirt instance void [mscorlib]System.IDisposable::Dispose()
    L_0041: endfinally 
    L_0042: ldloc.0 
    L_0043: ret 
    .try L_0009 to L_0034 finally handler L_0034 to L_0042
}

我认为这里最重要的是调用IEnumerator::GetCurrent()IEnumerator::MoveNext()方法。

第一种情况是:

callvirt instance !0 [mscorlib]System.Collections.Generic.IEnumerator`1<!!TSource>::get_Current()
callvirt instance bool [mscorlib]System.Collections.IEnumerator::MoveNext()

在第二种情况下,它是:

call instance !0 [mscorlib]System.Collections.Generic.List`1/Enumerator<!!TSource>::get_Current()
call instance bool [mscorlib]System.Collections.Generic.List`1/Enumerator<!!TSource>::MoveNext()

值得注意的是,在第二种情况下进行的是非虚函数调用,如果它在循环中(当然是在循环中),这比虚函数调用快得多。


1
太棒了!肯定是callcallvirt的原因。Where迭代器避免了间接的虚拟表调用,而是直接调用迭代器方法。感谢您的调查。 - Pavel Gatilov

5

在我看来,差别在于Linq扩展的编码方式。我怀疑Where使用了List<>类中的优化来加速操作,但Count只是遍历一个IEnumerable<>

如果您使用相同的过程,但使用IEnumerable,那么两种方法都很接近,Where略慢一些。

List<Group> _groups = new List<Group>();

for (int i = 0; i < 10000; i++)
{
    var group = new Group();

    group.Name = i + "asdasdasd";
    _groups.Add(group);
}

IEnumerable<Group> _groupsEnumerable = from g in _groups select g;

Stopwatch _stopwatch2 = new Stopwatch();

_stopwatch2.Start();
foreach (var group in _groups)
{
    var count = _groupsEnumerable.Count(x => x.Name == group.Name);
}
_stopwatch2.Stop();

Console.WriteLine(_stopwatch2.ElapsedMilliseconds);
Stopwatch _stopwatch = new Stopwatch();

_stopwatch.Start();
foreach (var group in _groups)
{
    var count = _groupsEnumerable.Where(x => x.Name == group.Name).Count();
}
_stopwatch.Stop();

Console.WriteLine(_stopwatch.ElapsedMilliseconds);

扩展方法在哪里。请注意if (source is List<TSource>)的情况:

public static IEnumerable<TSource> Where<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
{
    if (source == null)
    {
        throw Error.ArgumentNull("source");
    }
    if (predicate == null)
    {
        throw Error.ArgumentNull("predicate");
    }
    if (source is Enumerable.Iterator<TSource>)
    {
        return ((Enumerable.Iterator<TSource>)source).Where(predicate);
    }
    if (source is TSource[])
    {
        return new Enumerable.WhereArrayIterator<TSource>((TSource[])source, predicate);
    }
    if (source is List<TSource>)
    {
        return new Enumerable.WhereListIterator<TSource>((List<TSource>)source, predicate);
    }
    return new Enumerable.WhereEnumerableIterator<TSource>(source, predicate);
}

Count方法。只需遍历IEnumerable:

public static int Count<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate)
{
    if (source == null)
    {
        throw Error.ArgumentNull("source");
    }
    if (predicate == null)
    {
        throw Error.ArgumentNull("predicate");
    }
    int num = 0;
    checked
    {
        foreach (TSource current in source)
        {
            if (predicate(current))
            {
                num++;
            }
        }
        return num;
    }
}

好发现!根据反编译器的显示,Where 对于 List<T> 有一个特殊情况。Where 占用了 99% 的运行时间。在这个测试中,Count 只与一个项目相关。将谓词 x => true 使得 Count 版本再次变快。 - usr
1
这并没有解释使用了哪些“优化”。据我所知,在两种情况下它仍然必须迭代列表。您能解释一下是哪种优化使得差异? - Pavel Gatilov
@Yuval Itzchakov。你说得对,那个编辑并不一定是针对Pavel Gatilov的。毫无疑问,这个列表必须是差异。 - Wyatt Earp
调用List是硬编码的。没有虚拟调用。可内联。否则为什么要添加特殊情况?它是出于性能原因而添加的。 - usr
我希望微软可以定义一个继承自IEnumerator<T>的'IEnhancedEnumerator<T>'类型并使像List<T>这样的类型的枚举器实现它。 这样的功能将使得可以有效地操作由Skip()Concat()等方法包装的集合,或者被ReadOnlyCollection<T>类型封装的集合。 事实上,几乎没有一种方法可以在阻止除按项顺序迭代之外的所有访问手段的情况下包装或封装集合。 - supercat
显示剩余3条评论

2

接着Matthew Watson的回答:

List<T>遍历时生成call指令而不是像IEnumerable<T>那样使用callvirt,原因是C#中的foreach语句是鸭子类型。

C#语言规范8.8.4节指出编译器“确定类型X是否具有适当的GetEnumerator方法”,这优先于可枚举接口。因此,在这里,foreach语句使用返回List<T>.EnumeratorList<T>.GetEnumerator重载版本,而不是返回IEnumerable<T>或仅返回IEnumerable的版本。

编译器还检查GetEnumerator返回的类型是否具有一个不带参数的Current属性和MoveNext方法。对于List<T>.Enumerator,这些方法没有标记为virtual,所以编译器可以编译直接调用。相比之下,在IEnumerator<T>中它们被标记为virtual,所以编译器必须生成一个callvirt指令。通过虚函数表调用的额外开销解释了性能差异。


1

我的猜测:

.Where()使用特殊的"WhereListIterator"来迭代元素,而Count()则不使用,正如Wyatt Earp所指出的那样。有趣的是,迭代器被标记为"ngenable":

 [TargetedPatchingOptOut("Performance critical to inline this type of method across NGen image boundaries")]
 public WhereListIterator(List<TSource> source, Func<TSource, bool> predicate)
 {
   this.source = source;
   this.predicate = predicate;
 }

这可能意味着“迭代器”部分作为“非托管代码”运行,而Count()作为托管代码运行。我不知道这是否有意义/如何证明它,但这是我的想法。
此外,如果您仔细重写Count()以处理List,
您可以使它相同/甚至更快:
public static class TestExt{
   public static int CountFaster<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate) {
       if (source == null) throw new Exception();
       if (predicate == null) throw new Exception();

       if(source is List<TSource>)
       {
                int finalCount=0;
                var list = (List<TSource>)source;
                var count = list.Count;
                for(var j = 0; j < count; j++){
                    if(predicate(list[j])) 
                        finalCount++;
                }
                return finalCount;
       }


       return source.Count(predicate);
   }

}

在我的测试中,使用CountFaster()后,被称为LATER的人赢了(因为冷启动)。

我认为这没有意义,最终所有的代码都是作为非托管代码运行。 - svick
我制作了一个没有属性的测试应用程序,但它使用了.Net实现,更复杂的实现仍然更快! - Matthew Watson
好吧,从某种意义上说这很重要,因为如果两个代码都是最优的,那么ngen的代码在冷启动时会获胜。附:@MatthewWatson;我已经添加了实现,在我的机器上,这两个实现现在是相同的。 - Erti-Chris Eelmaa

0
根据@Matthew Watson的帖子,我检查了一些行为。在我的例子中,“Where”总是返回空集合,因此甚至没有在接口IEnumerable上调用Count(这比枚举List元素慢得多)。我不是添加所有具有不同名称的组,而是添加所有具有相同名称的项目。然后,Count比Count + Method更快。这是因为在Count方法中,我们在接口IEnumerable上枚举所有项。在Method + Count方法中,如果所有项都相同,“Where”返回整个集合(转换为IEnumerable接口),并调用Count(),因此Where调用是多余的,或者可以说-它会减慢事情的速度。
总之,在这个例子中特定的情况导致我得出结论,即Method + Where始终更快,但这并不正确。如果“Where”返回的集合与原始集合相差不大,则“Method + Where方法”将更慢。

-3

Sarge Borsch在评论中给出了正确的答案,但没有进一步的解释。

问题在于字节码必须在第一次运行时由JIT编译器编译为x86。因此,您的测量包括了您想要测试的内容和编译时间。由于第二个测试使用的大多数东西在第一个测试期间已经被编译(列表枚举器、Name属性getter等),因此第一个测试受到编译的影响更大。

解决方案是进行“热身”:您先运行代码一次,不进行测量,通常只进行一次迭代,以便将其编译。然后您开始计时并真正运行它,需要进行足够长时间的迭代(例如一秒钟)。


这不可能是对的 - 我测试过,通过将所有初始化移动到内部循环外部,并多次循环整个测试,结果相同。 - Matthew Watson
JIT时间并不能解释700毫秒的差异。好的基准测试运行时间较长,以确保没有一次性效应的影响。这是一个很好的基准测试。 - usr
我尝试了另一种排序方式,但无论如何,“Where + Method” 总是更快。 - MistyK
我也测试过了,就像@Zbigniew一样,得出了相同的结果。 - Dirk

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接