最有效的测试lambda表达式相等性的方法

55

给定一个方法签名:

public bool AreTheSame<T>(Expression<Func<T, object>> exp1, Expression<Func<T, object>> exp2)

如何高效地判断两个表达式是否相同?这只需要适用于简单的表达式,我指的是仅支持简单的MemberExpressions,例如c => c.ID。

一个示例调用可能是:

AreTheSame<User>(u1 => u1.ID, u2 => u2.ID); --> would return true

我认为一个基本的问题是,表达式是否像匿名类型一样,即使你在某个地方定义了一个相同的表达式,该表达式树是否会被运行时缓存,以便始终只有一个基础定义。这类似于享元模式以及C#中字符串的实现方式,还有在某种程度上匿名类,就我所知。 - jpierson
5个回答

57
更新:由于对我的解决方案的兴趣,我已经更新了代码,使其支持数组、新运算符和其他东西,并以更优雅的方式比较 ASTs。 这是 Marc 代码的改进版本,现在它作为一个 nuget 包可用: nuget package
public static class LambdaCompare
{
    public static bool Eq<TSource, TValue>(
        Expression<Func<TSource, TValue>> x,
        Expression<Func<TSource, TValue>> y)
    {
        return ExpressionsEqual(x, y, null, null);
    }

    public static bool Eq<TSource1, TSource2, TValue>(
        Expression<Func<TSource1, TSource2, TValue>> x,
        Expression<Func<TSource1, TSource2, TValue>> y)
    {
        return ExpressionsEqual(x, y, null, null);
    }

    public static Expression<Func<Expression<Func<TSource, TValue>>, bool>> Eq<TSource, TValue>(Expression<Func<TSource, TValue>> y)
    {
        return x => ExpressionsEqual(x, y, null, null);
    }

    private static bool ExpressionsEqual(Expression x, Expression y, LambdaExpression rootX, LambdaExpression rootY)
    {
        if (ReferenceEquals(x, y)) return true;
        if (x == null || y == null) return false;

        var valueX = TryCalculateConstant(x);
        var valueY = TryCalculateConstant(y);

        if (valueX.IsDefined && valueY.IsDefined)
            return ValuesEqual(valueX.Value, valueY.Value);

        if (x.NodeType != y.NodeType
            || x.Type != y.Type)
        {
            if (IsAnonymousType(x.Type) && IsAnonymousType(y.Type))
                throw new NotImplementedException("Comparison of Anonymous Types is not supported");
            return false;
        }

        if (x is LambdaExpression)
        {
            var lx = (LambdaExpression)x;
            var ly = (LambdaExpression)y;
            var paramsX = lx.Parameters;
            var paramsY = ly.Parameters;
            return CollectionsEqual(paramsX, paramsY, lx, ly) && ExpressionsEqual(lx.Body, ly.Body, lx, ly);
        }
        if (x is MemberExpression)
        {
            var mex = (MemberExpression)x;
            var mey = (MemberExpression)y;
            return Equals(mex.Member, mey.Member) && ExpressionsEqual(mex.Expression, mey.Expression, rootX, rootY);
        }
        if (x is BinaryExpression)
        {
            var bx = (BinaryExpression)x;
            var by = (BinaryExpression)y;
            return bx.Method == @by.Method && ExpressionsEqual(bx.Left, @by.Left, rootX, rootY) &&
                   ExpressionsEqual(bx.Right, @by.Right, rootX, rootY);
        }
        if (x is UnaryExpression)
        {
            var ux = (UnaryExpression)x;
            var uy = (UnaryExpression)y;
            return ux.Method == uy.Method && ExpressionsEqual(ux.Operand, uy.Operand, rootX, rootY);
        }
        if (x is ParameterExpression)
        {
            var px = (ParameterExpression)x;
            var py = (ParameterExpression)y;
            return rootX.Parameters.IndexOf(px) == rootY.Parameters.IndexOf(py);
        }
        if (x is MethodCallExpression)
        {
            var cx = (MethodCallExpression)x;
            var cy = (MethodCallExpression)y;
            return cx.Method == cy.Method
                   && ExpressionsEqual(cx.Object, cy.Object, rootX, rootY)
                   && CollectionsEqual(cx.Arguments, cy.Arguments, rootX, rootY);
        }
        if (x is MemberInitExpression)
        {
            var mix = (MemberInitExpression)x;
            var miy = (MemberInitExpression)y;
            return ExpressionsEqual(mix.NewExpression, miy.NewExpression, rootX, rootY)
                   && MemberInitsEqual(mix.Bindings, miy.Bindings, rootX, rootY);
        }
        if (x is NewArrayExpression)
        {
            var nx = (NewArrayExpression)x;
            var ny = (NewArrayExpression)y;
            return CollectionsEqual(nx.Expressions, ny.Expressions, rootX, rootY);
        }
        if (x is NewExpression)
        {
            var nx = (NewExpression)x;
            var ny = (NewExpression)y;
            return
                Equals(nx.Constructor, ny.Constructor)
                && CollectionsEqual(nx.Arguments, ny.Arguments, rootX, rootY)
                && (nx.Members == null && ny.Members == null
                    || nx.Members != null && ny.Members != null && CollectionsEqual(nx.Members, ny.Members));
        }
        if (x is ConditionalExpression)
        {
            var cx = (ConditionalExpression)x;
            var cy = (ConditionalExpression)y;
            return
                ExpressionsEqual(cx.Test, cy.Test, rootX, rootY)
                && ExpressionsEqual(cx.IfFalse, cy.IfFalse, rootX, rootY)
                && ExpressionsEqual(cx.IfTrue, cy.IfTrue, rootX, rootY);
        }

        throw new NotImplementedException(x.ToString());
    }

    private static Boolean IsAnonymousType(Type type)
    {
        Boolean hasCompilerGeneratedAttribute = type.GetCustomAttributes(typeof(CompilerGeneratedAttribute), false).Any();
        Boolean nameContainsAnonymousType = type.FullName.Contains("AnonymousType");
        Boolean isAnonymousType = hasCompilerGeneratedAttribute && nameContainsAnonymousType;

        return isAnonymousType;
    }

    private static bool MemberInitsEqual(ICollection<MemberBinding> bx, ICollection<MemberBinding> by, LambdaExpression rootX, LambdaExpression rootY)
    {
        if (bx.Count != by.Count)
        {
            return false;
        }

        if (bx.Concat(by).Any(b => b.BindingType != MemberBindingType.Assignment))
            throw new NotImplementedException("Only MemberBindingType.Assignment is supported");

        return
            bx.Cast<MemberAssignment>().OrderBy(b => b.Member.Name).Select((b, i) => new { Expr = b.Expression, b.Member, Index = i })
            .Join(
                  by.Cast<MemberAssignment>().OrderBy(b => b.Member.Name).Select((b, i) => new { Expr = b.Expression, b.Member, Index = i }),
                  o => o.Index, o => o.Index, (xe, ye) => new { XExpr = xe.Expr, XMember = xe.Member, YExpr = ye.Expr, YMember = ye.Member })
                   .All(o => Equals(o.XMember, o.YMember) && ExpressionsEqual(o.XExpr, o.YExpr, rootX, rootY));
    }

    private static bool ValuesEqual(object x, object y)
    {
        if (ReferenceEquals(x, y))
            return true;
        if (x is ICollection && y is ICollection)
            return CollectionsEqual((ICollection)x, (ICollection)y);

        return Equals(x, y);
    }

    private static ConstantValue TryCalculateConstant(Expression e)
    {
        if (e is ConstantExpression)
            return new ConstantValue(true, ((ConstantExpression)e).Value);
        if (e is MemberExpression)
        {
            var me = (MemberExpression)e;
            var parentValue = TryCalculateConstant(me.Expression);
            if (parentValue.IsDefined)
            {
                var result =
                    me.Member is FieldInfo
                        ? ((FieldInfo)me.Member).GetValue(parentValue.Value)
                        : ((PropertyInfo)me.Member).GetValue(parentValue.Value);
                return new ConstantValue(true, result);
            }
        }
        if (e is NewArrayExpression)
        {
            var ae = ((NewArrayExpression)e);
            var result = ae.Expressions.Select(TryCalculateConstant);
            if (result.All(i => i.IsDefined))
                return new ConstantValue(true, result.Select(i => i.Value).ToArray());
        }
        if (e is ConditionalExpression)
        {
            var ce = (ConditionalExpression)e;
            var evaluatedTest = TryCalculateConstant(ce.Test);
            if (evaluatedTest.IsDefined)
            {
                return TryCalculateConstant(Equals(evaluatedTest.Value, true) ? ce.IfTrue : ce.IfFalse);
            }
        }

        return default(ConstantValue);
    }

    private static bool CollectionsEqual(IEnumerable<Expression> x, IEnumerable<Expression> y, LambdaExpression rootX, LambdaExpression rootY)
    {
        return x.Count() == y.Count()
               && x.Select((e, i) => new { Expr = e, Index = i })
                   .Join(y.Select((e, i) => new { Expr = e, Index = i }),
                         o => o.Index, o => o.Index, (xe, ye) => new { X = xe.Expr, Y = ye.Expr })
                   .All(o => ExpressionsEqual(o.X, o.Y, rootX, rootY));
    }

    private static bool CollectionsEqual(ICollection x, ICollection y)
    {
        return x.Count == y.Count
               && x.Cast<object>().Select((e, i) => new { Expr = e, Index = i })
                   .Join(y.Cast<object>().Select((e, i) => new { Expr = e, Index = i }),
                         o => o.Index, o => o.Index, (xe, ye) => new { X = xe.Expr, Y = ye.Expr })
                   .All(o => Equals(o.X, o.Y));
    }

    private struct ConstantValue
    {
        public ConstantValue(bool isDefined, object value)
            : this()
        {
            IsDefined = isDefined;
            Value = value;
        }

        public bool IsDefined { get; private set; }

        public object Value { get; private set; }
    }
}

请注意,它不会比较完整的AST。相反,它会折叠常量表达式并比较它们的值而不是它们的AST。 当lambda表达式引用本地变量时,它对于模拟验证非常有用。在这种情况下,通过值比较变量。

单元测试:

[TestClass]
public class Tests
{
    [TestMethod]
    public void BasicConst()
    {
        var f1 = GetBasicExpr1();
        var f2 = GetBasicExpr2();
        Assert.IsTrue(LambdaCompare.Eq(f1, f2));
    }

    [TestMethod]
    public void PropAndMethodCall()
    {
        var f1 = GetPropAndMethodExpr1();
        var f2 = GetPropAndMethodExpr2();
        Assert.IsTrue(LambdaCompare.Eq(f1, f2));
    }

    [TestMethod]
    public void MemberInitWithConditional()
    {
        var f1 = GetMemberInitExpr1();
        var f2 = GetMemberInitExpr2();
        Assert.IsTrue(LambdaCompare.Eq(f1, f2));
    }

    [TestMethod]
    public void AnonymousType()
    {
        var f1 = GetAnonymousExpr1();
        var f2 = GetAnonymousExpr2();
        Assert.Inconclusive("Anonymous Types are not supported");
    }

    private static Expression<Func<int, string, string>> GetBasicExpr2()
    {
        var const2 = "some const value";
        var const3 = "{0}{1}{2}{3}";
        return (i, s) =>
            string.Format(const3, (i + 25).ToString(CultureInfo.InvariantCulture), i + s, const2.ToUpper(), 25);
    }

    private static Expression<Func<int, string, string>> GetBasicExpr1()
    {
        var const1 = 25;
        return (first, second) =>
            string.Format("{0}{1}{2}{3}", (first + const1).ToString(CultureInfo.InvariantCulture), first + second,
                "some const value".ToUpper(), const1);
    }

    private static Expression<Func<Uri, bool>> GetPropAndMethodExpr2()
    {
        return u => Uri.IsWellFormedUriString(u.ToString(), UriKind.Absolute);
    }

    private static Expression<Func<Uri, bool>> GetPropAndMethodExpr1()
    {
        return arg1 => Uri.IsWellFormedUriString(arg1.ToString(), UriKind.Absolute);
    }

    private static Expression<Func<Uri, UriBuilder>> GetMemberInitExpr2()
    {
        var isSecure = true;
        return u => new UriBuilder(u) { Host = string.IsNullOrEmpty(u.Host) ? "abc" : "def" , Port = isSecure ? 443 : 80 };
    }

    private static Expression<Func<Uri, UriBuilder>> GetMemberInitExpr1()
    {
        var port = 443;
        return x => new UriBuilder(x) { Port = port, Host = string.IsNullOrEmpty(x.Host) ? "abc" : "def" };
    }

    private static Expression<Func<Uri, object>> GetAnonymousExpr2()
    {
        return u => new { u.Host , Port = 443, Addr = u.AbsolutePath };
    }

    private static Expression<Func<Uri, object>> GetAnonymousExpr1()
    {
        return x => new { Port = 443, x.Host, Addr = x.AbsolutePath };
    }
}

'AST' 是什么意思? - Aage
2
为什么TryCalculateConstant会同时崩溃成员访问?常量对象的属性会改变。如果它旨在成为通用比较器,这是不好的。在我看来,如果消费者想要决定将常量的属性访问折叠到常量上,则消费者应该在进行比较之前使用访问者对表达式进行变异,并对此负责。 - jnm2
1
我最终将它移动到了Github和nuget,尽情享用吧! - neleus
1
这真是太棒了! - Pratik Bhattacharya
1
@Aage AST 意味着抽象语法树,它是一种对表达式进行建模的方式。 - Thoker
显示剩余6条评论

42

嗯......我想你需要解析树,检查每个节点类型和成员。我会写一个示例...

using System;
using System.Linq.Expressions;
class Test {
    public string Foo { get; set; }
    public string Bar { get; set; }
    static void Main()
    {
        bool test1 = FuncTest<Test>.FuncEqual(x => x.Bar, y => y.Bar),
            test2 = FuncTest<Test>.FuncEqual(x => x.Foo, y => y.Bar);
    }

}
// this only exists to make it easier to call, i.e. so that I can use FuncTest<T> with
// generic-type-inference; if you use the doubly-generic method, you need to specify
// both arguments, which is a pain...
static class FuncTest<TSource>
{
    public static bool FuncEqual<TValue>(
        Expression<Func<TSource, TValue>> x,
        Expression<Func<TSource, TValue>> y)
    {
        return FuncTest.FuncEqual<TSource, TValue>(x, y);
    }
}
static class FuncTest {
    public static bool FuncEqual<TSource, TValue>(
        Expression<Func<TSource,TValue>> x,
        Expression<Func<TSource,TValue>> y)
    {
        return ExpressionEqual(x, y);
    }
    private static bool ExpressionEqual(Expression x, Expression y)
    {
        // deal with the simple cases first...
        if (ReferenceEquals(x, y)) return true;
        if (x == null || y == null) return false;
        if (   x.NodeType != y.NodeType
            || x.Type != y.Type ) return false;

        switch (x.NodeType)
        {
            case ExpressionType.Lambda:
                return ExpressionEqual(((LambdaExpression)x).Body, ((LambdaExpression)y).Body);
            case ExpressionType.MemberAccess:
                MemberExpression mex = (MemberExpression)x, mey = (MemberExpression)y;
                return mex.Member == mey.Member; // should really test down-stream expression
            default:
                throw new NotImplementedException(x.NodeType.ToString());
        }
    }
}

5
表达方式让我感到恼火,它们很有力却又缺失了很多内容。点个赞上传代码。 - Andrew Bullock
@MarcGravell,这个方法很棒,但有一个限制。你的比较方法假设lambda的返回类型必须相同(在这种情况下是string)。如果FooBar是不同的类型,为了使其通用,你需要使用TValue1TValue2。你能更新一下你的答案吗? - KFL
不递归比较 MemberExpression.Expression 会导致一些恶意 bug。此外,lambda 的 ParameterExpression 集合可以进行递归比较。 - jnm2

4

一个标准的解决方案会很棒。同时,我创建了一个 IEqualityComparer<Expression> 版本。

这是一个相当冗长的实现,因此我 为它创建了一个 gist

它旨在成为一个全面的抽象语法树比较器。为此,它比较了每种表达式类型,包括尚未被 C# 支持的表达式,如 TrySwitchBlock。它不比较的唯一类型是 GotoLabelLoopDebugInfo,因为我对它们的了解有限。

您可以指定是否以及如何比较参数和 Lambda 的名称,以及如何处理 ConstantExpression

它通过上下文按位置跟踪参数。支持嵌套的 Lambda 和 catch 块变量参数。


2

我知道这是一个老问题,但我自己编写了表达式树相等比较器 - https://github.com/yesmarket/yesmarket.Linq.Expressions

实现中广泛使用ExpressionVisitor类来确定两个表达式树是否相等。随着遍历表达式树中的节点,将比较单个节点是否相等。


0

我认为当你使用lambda-efficient collection - 我的意思是可以单独实现每个列的IEnumerable来枚举一个或多个选定列,从而实现对Lambdas的最大效率利用 - 这样可以实现column-based collection - 让我们称其为第一步 ;)

这只是我的想法,我希望有一天可以实现。目前我还没有任何线索,但我想许多人会同意,通过单变量列表枚举并检查值相比于在对象列表中检查某些属性就像证明它本身一样。

下一第二步是从函数式编程中获取更多:使用sorted-listhash-table或任何其他search-efficient collection作为表示列的集合。

class LambdaReadyColumn<int> : HashTable<int> 

还有一个第三步,使用一些指针将列之间的项目连接起来,这样就不必将底层列保持在下面:

class LambdaReadyColumn<int> : IEnumabrable<int> 

将数据保存在更接近的位置:

class LambdaReadyColumn<LambdaReadyColumnItem<T, int>> : IEnumabrable<int> 
//with example constructor like: 
public LambdaReadyColumn<LambdaReadyColumnItem<T, int>>(Hash, LambdaReadyColumnItem, LambdaReadyColumnItem, T, int); 

其中:

  • CollectionItem - 引用相同 T 项的右列和左列项
  • Hash - 为了加快搜索速度
  • int - 表示列的类型
  • T - 第四步 整个 LambdaReadyCollection 在列集合中轻松理解,只需在返回项描述符上进行 Select(T) 即可更快地避免遍历具有许多属性的项目的少数引用。

最后,通过所有步骤,我们拥有双重数据的集合:基于行和基于集合,此外还有大量的引用数据。 当然,基于行的哈希表可以保留数据,而基于列的仅保留引用,但是从语句中构建每个集合以返回将使用大量引用。

要实现这一点,您需要使用反射、动态类型或其他高级技术,具体取决于语言。


网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接