如何将表达式树转换为部分 SQL 查询?

51
当EF或LINQ to SQL运行一个查询时,它会执行以下操作:
  1. 从代码中构建一个表达式树(expression tree),
  2. 将表达式树转换为SQL查询语句,
  3. 执行查询,从数据库获得原始结果并将其转换为应用程序使用的结果。
查看堆栈跟踪,我无法确定第二部分发生在哪里。
一般来说,是否可以使用EF或(最好是)LINQ to SQL中现有的部分来将Expression对象转换为部分SQL查询(使用Transact-SQL语法),还是我必须重新发明轮子? 更新:一个评论要求提供我要做什么的示例。
实际上,下面Ryan Wright的答案完美地阐述了我想要达到的结果,只不过我的问题具体是关于如何使用.NET Framework中EF和LINQ to SQL实际使用的现有机制来完成,而不是不得不重新发明轮子并编写成千上万行未经太多测试的代码来完成类似的事情。
这里还有一个示例。请注意,没有ORM生成的代码。
/// <summary>
/// Sample entity for the question: each property carries a [DatabaseMapping] attribute naming its column.
/// </summary>
private class Product
{
    // Column "ProductId" in the database.
    [DatabaseMapping("ProductId")]
    public int Id { get; set; }

    // The C# name differs from the column name ("Price"), so a translator must read the attribute
    // rather than use the property name.
    [DatabaseMapping("Price")]
    public int PriceInCents { get; set; }
}

/// <summary>
/// Placeholder for the API the question is looking for: translate an expression into a partial
/// T-SQL query (without a WHERE keyword).
/// </summary>
/// <param name="expression">The expression to translate.</param>
/// <returns>The SQL fragment.</returns>
/// <remarks>The body is intentionally elided; as written it has no return statement and does not compile.</remarks>
private string Convert(Expression expression)
{
    // Some magic calls to .NET Framework code happen here.
    // [...]
}

/// <summary>
/// Illustrates the expected result of <c>Convert</c> for a price-range predicate.
/// </summary>
private void TestConvert()
{
    // The lambda parameters "from" and "to" are expected to surface as the SQL parameters @from and @to.
    Expression<Func<Product, int, int, bool>> inPriceRange =
        (Product product, int from, int to) =>
            product.PriceInCents >= from && product.PriceInCents <= to;

    string actualQueryPart = this.Convert(inPriceRange);

    // [Price] comes from the [DatabaseMapping("Price")] attribute on PriceInCents, not from the property name.
    Assert.AreEqual("[Price] between @from and @to", actualQueryPart);
}
在预期的查询中,Price这个名称是从哪里来的?

可以通过反射读取Product类PriceInCents属性上的自定义DatabaseMapping特性来获取该名称。

预期查询中的名称@from和@to来自何处?

这些名称是表达式参数的实际名称。

预期查询中的between … and语句来自何处?

这是该二元表达式的一种可能的翻译结果。EF或LINQ to SQL也许不会使用between … and,而是生成[Price] >= @from and [Price] <= @to;两种写法在逻辑上是等价的(这里不讨论性能)。

为什么预期查询中没有where关键字?

因为Expression中没有任何内容表明需要where关键字。实际的表达式可能只是众多表达式之一:这些表达式会通过二元运算符组合成一个更大的查询,最后再在前面加上where。


你能否提供一个示例,说明你正在尝试将什么转换成什么? - Orion Adrian
我已经打开了一个相关问题的讨论,欢迎加入:https://github.com/aspnet/AspNetCore/issues/13465 - Mehdi Dehghani
9个回答

58

是的,这是可能的,您可以使用访问者模式解析LINQ表达式树。您需要通过子类化ExpressionVisitor构造查询转换器,如下所示。通过在正确的点上插入钩子,您可以使用转换器从LINQ表达式构造SQL字符串。请注意,下面的代码仅处理基本的where/orderby/skip/take子句,但您可以根据需要填充它。希望它作为一个良好的第一步。

/// <summary>
/// Translates a LINQ expression tree into pieces of a T-SQL query by walking it with an
/// <see cref="ExpressionVisitor"/>.
/// </summary>
/// <remarks>
/// <see cref="Translate"/> returns the WHERE predicate (without the WHERE keyword); OrderBy,
/// OrderByDescending, Skip and Take calls are captured into the properties of the same name.
/// Supported: Queryable.Where, OrderBy/OrderByDescending on a member of the lambda parameter,
/// Skip/Take with constant arguments, the comparison and logical binary operators, NOT,
/// conversions, constants, and members accessed directly on the lambda parameter.
/// Everything else throws <see cref="NotSupportedException"/>.
/// Constants are inlined as SQL literals (single quotes in strings are doubled); prefer real
/// SQL parameters for values that come from untrusted input.
/// </remarks>
public class MyQueryTranslator : ExpressionVisitor
{
    private StringBuilder sb;
    private string _orderBy = string.Empty;
    private int? _skip = null;
    private int? _take = null;
    private string _whereClause = string.Empty;

    /// <summary>Row count captured from a Skip call, or null if the last translated query had none.</summary>
    public int? Skip
    {
        get
        {
            return _skip;
        }
    }

    /// <summary>Row count captured from a Take call, or null if the last translated query had none.</summary>
    public int? Take
    {
        get
        {
            return _take;
        }
    }

    /// <summary>
    /// Comma-separated ORDER BY terms such as "Name ASC", or an empty string when the last
    /// translated query had no ordering. Terms are listed outermost call first.
    /// </summary>
    public string OrderBy
    {
        get
        {
            return _orderBy;
        }
    }

    /// <summary>The WHERE predicate produced by the most recent <see cref="Translate"/> call.</summary>
    public string WhereClause
    {
        get
        {
            return _whereClause;
        }
    }

    public MyQueryTranslator()
    {
    }

    /// <summary>
    /// Translates <paramref name="expression"/> and returns its WHERE predicate text.
    /// </summary>
    /// <param name="expression">A query expression (e.g. <c>IQueryable.Expression</c>) or a single predicate.</param>
    /// <returns>The predicate without the WHERE keyword; empty when the query has no Where.</returns>
    /// <exception cref="NotSupportedException">The expression contains an untranslatable construct.</exception>
    public string Translate(Expression expression)
    {
        this.sb = new StringBuilder();

        // Reset state from a previous call so a reused translator does not accumulate
        // ORDER BY terms or report a stale Skip/Take.
        _orderBy = string.Empty;
        _skip = null;
        _take = null;

        this.Visit(expression);
        _whereClause = this.sb.ToString();
        return _whereClause;
    }

    private static Expression StripQuotes(Expression e)
    {
        while (e.NodeType == ExpressionType.Quote)
        {
            e = ((UnaryExpression)e).Operand;
        }
        return e;
    }

    protected override Expression VisitMethodCall(MethodCallExpression m)
    {
        if (m.Method.DeclaringType == typeof(Queryable) && m.Method.Name == "Where")
        {
            int lengthBefore = sb.Length;
            this.Visit(m.Arguments[0]);

            // A chained Where (its source is another Where) has already written a predicate;
            // combine the two with AND instead of gluing them together.
            if (sb.Length > lengthBefore)
            {
                sb.Append(" AND ");
            }

            LambdaExpression lambda = (LambdaExpression)StripQuotes(m.Arguments[1]);
            this.Visit(lambda.Body);
            return m;
        }
        else if (m.Method.Name == "Take")
        {
            if (this.ParseTakeExpression(m))
            {
                Expression nextExpression = m.Arguments[0];
                return this.Visit(nextExpression);
            }
        }
        else if (m.Method.Name == "Skip")
        {
            if (this.ParseSkipExpression(m))
            {
                Expression nextExpression = m.Arguments[0];
                return this.Visit(nextExpression);
            }
        }
        else if (m.Method.Name == "OrderBy")
        {
            if (this.ParseOrderByExpression(m, "ASC"))
            {
                Expression nextExpression = m.Arguments[0];
                return this.Visit(nextExpression);
            }
        }
        else if (m.Method.Name == "OrderByDescending")
        {
            if (this.ParseOrderByExpression(m, "DESC"))
            {
                Expression nextExpression = m.Arguments[0];
                return this.Visit(nextExpression);
            }
        }

        throw new NotSupportedException(string.Format("The method '{0}' is not supported", m.Method.Name));
    }

    protected override Expression VisitUnary(UnaryExpression u)
    {
        switch (u.NodeType)
        {
            case ExpressionType.Not:
                sb.Append(" NOT ");
                this.Visit(u.Operand);
                break;
            case ExpressionType.Convert:
                // Conversions (e.g. enum -> int) have no SQL counterpart; translate the operand only.
                this.Visit(u.Operand);
                break;
            default:
                throw new NotSupportedException(string.Format("The unary operator '{0}' is not supported", u.NodeType));
        }
        return u;
    }

    /// <summary>
    /// Writes a parenthesised SQL comparison or logical operation for <paramref name="b"/>.
    /// </summary>
    /// <param name="b">The binary node to translate.</param>
    /// <returns><paramref name="b"/> unchanged.</returns>
    /// <exception cref="NotSupportedException">The operator has no SQL mapping here.</exception>
    protected override Expression VisitBinary(BinaryExpression b)
    {
        Expression left = b.Left;
        Expression right = b.Right;

        // "null == x.Name" must become "Name IS NULL"; "NULL = Name" is never true in SQL.
        // Move the null constant to the right so the IS / IS NOT branches below see it.
        bool isEquality = b.NodeType == ExpressionType.Equal || b.NodeType == ExpressionType.NotEqual;
        if (isEquality && IsNullConstant(left))
        {
            left = b.Right;
            right = b.Left;
        }

        sb.Append("(");
        this.Visit(left);

        switch (b.NodeType)
        {
            case ExpressionType.And:
            case ExpressionType.AndAlso:
                sb.Append(" AND ");
                break;

            case ExpressionType.Or:
            case ExpressionType.OrElse:
                sb.Append(" OR ");
                break;

            case ExpressionType.Equal:
                sb.Append(IsNullConstant(right) ? " IS " : " = ");
                break;

            case ExpressionType.NotEqual:
                sb.Append(IsNullConstant(right) ? " IS NOT " : " <> ");
                break;

            case ExpressionType.LessThan:
                sb.Append(" < ");
                break;

            case ExpressionType.LessThanOrEqual:
                sb.Append(" <= ");
                break;

            case ExpressionType.GreaterThan:
                sb.Append(" > ");
                break;

            case ExpressionType.GreaterThanOrEqual:
                sb.Append(" >= ");
                break;

            default:
                throw new NotSupportedException(string.Format("The binary operator '{0}' is not supported", b.NodeType));
        }

        this.Visit(right);
        sb.Append(")");
        return b;
    }

    /// <summary>
    /// Writes a constant as a T-SQL literal. The query root (an <see cref="IQueryable"/> constant)
    /// contributes nothing.
    /// </summary>
    /// <exception cref="NotSupportedException">The constant is of a non-primitive type.</exception>
    protected override Expression VisitConstant(ConstantExpression c)
    {
        if (c.Value is IQueryable)
        {
            return c;
        }

        if (c.Value == null)
        {
            sb.Append("NULL");
            return c;
        }

        switch (Type.GetTypeCode(c.Value.GetType()))
        {
            case TypeCode.Boolean:
                sb.Append(((bool)c.Value) ? 1 : 0);
                break;

            case TypeCode.String:
            case TypeCode.Char:
                // Double embedded single quotes so the value cannot terminate the literal early.
                sb.Append("'");
                sb.Append(c.Value.ToString().Replace("'", "''"));
                sb.Append("'");
                break;

            case TypeCode.DateTime:
                // ISO 8601 is read the same way whatever the server's DATEFORMAT/language is,
                // unlike the culture-dependent DateTime.ToString().
                sb.Append("'");
                sb.Append(((DateTime)c.Value).ToString("yyyy-MM-dd'T'HH:mm:ss.fff", System.Globalization.CultureInfo.InvariantCulture));
                sb.Append("'");
                break;

            case TypeCode.Object:
                throw new NotSupportedException(string.Format("The constant for '{0}' is not supported", c.Value));

            default:
                // Numbers must use '.' as decimal separator regardless of the current culture.
                sb.Append(System.Convert.ToString(c.Value, System.Globalization.CultureInfo.InvariantCulture));
                break;
        }

        return c;
    }

    /// <summary>Writes the member name for members accessed directly on the lambda parameter.</summary>
    /// <exception cref="NotSupportedException">Any other member access (e.g. a captured variable).</exception>
    protected override Expression VisitMember(MemberExpression m)
    {
        if (m.Expression != null && m.Expression.NodeType == ExpressionType.Parameter)
        {
            sb.Append(m.Member.Name);
            return m;
        }

        throw new NotSupportedException(string.Format("The member '{0}' is not supported", m.Member.Name));
    }

    protected bool IsNullConstant(Expression exp)
    {
        return (exp.NodeType == ExpressionType.Constant && ((ConstantExpression)exp).Value == null);
    }

    /// <summary>
    /// Appends "Member ORDER" to <see cref="OrderBy"/> when the key selector is a member chain
    /// rooted at the lambda parameter; returns false otherwise.
    /// </summary>
    private bool ParseOrderByExpression(MethodCallExpression expression, string order)
    {
        LambdaExpression lambdaExpression = (LambdaExpression)StripQuotes(expression.Arguments[1]);

        // Only members of the row itself are columns; closures and static members are rejected.
        MemberExpression body = lambdaExpression.Body as MemberExpression;
        if (body == null || !IsRootedInParameter(body))
        {
            return false;
        }

        if (string.IsNullOrEmpty(_orderBy))
        {
            _orderBy = string.Format("{0} {1}", body.Member.Name, order);
        }
        else
        {
            _orderBy = string.Format("{0}, {1} {2}", _orderBy, body.Member.Name, order);
        }

        return true;
    }

    /// <summary>True when following <c>.Expression</c> through nested member accesses ends at a parameter.</summary>
    private static bool IsRootedInParameter(MemberExpression member)
    {
        Expression current = member.Expression;
        while (current is MemberExpression inner)
        {
            current = inner.Expression;
        }

        return current != null && current.NodeType == ExpressionType.Parameter;
    }

    private bool ParseTakeExpression(MethodCallExpression expression)
    {
        int size;
        if (TryGetIntConstant(expression.Arguments[1], out size))
        {
            _take = size;
            return true;
        }

        return false;
    }

    private bool ParseSkipExpression(MethodCallExpression expression)
    {
        int size;
        if (TryGetIntConstant(expression.Arguments[1], out size))
        {
            _skip = size;
            return true;
        }

        return false;
    }

    /// <summary>
    /// Reads an integer from a constant argument. Non-constant arguments (e.g. captured variables)
    /// return false so the caller reports NotSupportedException instead of an InvalidCastException.
    /// </summary>
    private static bool TryGetIntConstant(Expression argument, out int value)
    {
        ConstantExpression constant = argument as ConstantExpression;
        if (constant != null && constant.Value != null && int.TryParse(constant.Value.ToString(), out value))
        {
            return true;
        }

        value = 0;
        return false;
    }
}

然后通过调用以下命令来访问表达式:

// Translate returns the WHERE predicate; translator.OrderBy / Skip / Take hold the remaining clauses.
MyQueryTranslator translator = new MyQueryTranslator();
var whereClause = translator.Translate(expression);

1
虽然起步不错,但好像缺少字符串操作或比较字符串的能力。例如,包含(Contains),以……开始(StartsWith)等。再次说明,起步不错。 - Orion Adrian
15
这个“Evaluator”类位于哪里? - programad
4
需要一个评估器来解析表达式中的局部变量引用。我使用了这里提供的实现。同样值得一试的是,可以在Translate方法中部分地评估表达式(即在ParseOrderByExpression方法中删除对Evaluator的调用,如下所示): public string Translate(Expression expression) { expression = Evaluator.PartialEval(expression); - Peter
1
任何要用它生成where子句并拼接到SQL中的人请注意:生成的子句中单引号没有被转义。要修复这一点,请在VisitConstant方法的字符串分支中将sb.Append(c.Value);改为sb.Append(c.Value.ToString().Replace("'", "''")); - insomniac
1
对于那些询问“Evaluator”的人,@Peter提供的链接已经失效了,这里是你可以找到它的地方:https://github.com/mattwar/iqtoolkit/issues/18,请查看第三部分。 - Stacked
显示剩余5条评论

30
简短的回答似乎是,你不能使用 EF 或 LINQ to SQL 的一部分作为翻译的捷径。你需要至少一个 ObjectContext 子类来访问 internal protected QueryProvider 属性,这意味着需要创建上下文的所有开销,包括元数据等等。
假设你能接受这个,要获取部分 SQL 查询,例如仅 WHERE 子句,你基本上需要查询提供程序,并调用 IQueryProvider.CreateQuery(),就像 LINQ 在实现 Queryable.Where 时所做的那样。要获取更完整的查询,可以使用 ObjectQuery.ToTraceString()
至于这发生在哪里,LINQ provider basics 通常说明了这一点。
IQueryProvider返回一个IQueryable引用,后续调用会使用LINQ框架传入的已构建表达式树。一般来说,每个查询块都会被转换为一系列方法调用,而每个方法调用都涉及若干表达式。在编写我们自己的提供程序时,我们在IQueryProvider.CreateQuery方法中遍历表达式并填充一个过滤器对象,随后IQueryProvider.Execute方法使用该对象对数据存储执行查询。查询有两种执行方式:一种是在Query类(实现IQueryable)中实现GetEnumerator方法(定义于IEnumerable接口);另一种是由LINQ运行时直接执行。在调试器中检查EF可以看到,它采用的是前者。如果您不想完全重新发明轮子,而EF和LINQ to SQL又都不可行,也许下面这一系列文章能帮到您:《如何:将LINQ转换为SQL》、《如何:将LINQ转换为SQL - 第二部分》、《如何:将LINQ转换为SQL - 第三部分》。以下是一些关于创建查询提供程序的资料,不过要实现您想要的功能,可能需要投入更多的工作:

1
我简直不敢相信,居然不存在把表达式转换为(字符串)SQL语句的完整解决方案?在GitHub、NuGet gallery和Google上搜索了好几遍,但都没有任何结果...如果有人知道一个维护良好的方案,请告诉我! - Patrick
我同意,像这样的解决方案只需获取表达式树并生成原始SQL翻译将非常适合规范模式。我希望我们能找到一些东西或者组建一个小团队开始做些什么。 - George Taskos

7

这并不是一个完整的方案,但如果你之后来到这里,可以参考一下我的一些思路:

    /// <summary>
    /// Best-effort conversion of a predicate into a SQL WHERE fragment by post-processing the text of
    /// <c>predicate.Body.ToString()</c>.
    /// </summary>
    /// <param name="predicate">A predicate over the entity type.</param>
    /// <returns>The WHERE condition text (without the WHERE keyword).</returns>
    /// <remarks>
    /// Only simple comparisons on members of the lambda parameter produce valid SQL. Captured
    /// variables render as <c>value(...)</c> and method calls are emitted verbatim. Values are
    /// inlined into the text, so do not use this with untrusted input.
    /// </remarks>
    private string CreateWhereClause(Expression<Func<T, bool>> predicate)
    {
        string parameterName = predicate.Parameters[0].Name;

        // Remove the "x." prefix only where it starts a member access, so text such as "box.com"
        // (which contains "x.") is left intact.
        string text = System.Text.RegularExpressions.Regex.Replace(
            predicate.Body.ToString(),
            @"(?<![\w.])" + System.Text.RegularExpressions.Regex.Escape(parameterName) + @"\.",
            string.Empty);

        StringBuilder p = new StringBuilder(text);

        // SQL tests NULL with IS / IS NOT ("= NULL" is never true). This must run before "==" -> "=".
        p.Replace(" == null", " IS NULL");
        p.Replace(" != null", " IS NOT NULL");
        p.Replace("==", "=");
        p.Replace(" AndAlso ", " and ");
        p.Replace(" OrElse ", " or ");

        // Double any single quotes inside values before C# double quotes become SQL string delimiters.
        p.Replace("'", "''");
        p.Replace("\"", "\'");
        return p.ToString();
    }

    /// <summary>
    /// Builds a complete SELECT statement: the select command for <paramref name="maxCount"/>
    /// followed by " where " and the translated <paramref name="predicate"/>.
    /// </summary>
    private string AddWhereToSelectCommand(Expression<Func<T, bool>> predicate, int maxCount = 0)
        => $"{CreateSelectCommand(maxCount)} where {CreateWhereClause(predicate)}";

    /// <summary>
    /// Builds "Select ... from table"; a positive <paramref name="maxCount"/> limits the rows with TOP.
    /// </summary>
    private string CreateSelectCommand(int maxCount = 0)
    {
        string columns = "*";
        if (maxCount > 0)
        {
            columns = $"TOP {maxCount} * ";
        }

        return $"Select {columns} from {_tableName}";
    }

1
我觉得你的回答在这里应该得到更多的关注。但是你可能会发现像(f)=> f.SomeList.Where((g)=> g.Epicness > 30)这样的方法调用存在问题。还要注意,(f)=> f.Name != Environment.MachineName会输出类似“f.Name!= 'Environment.MachineName'”这样的内容,这可能会适得其反。 - Felype

6
在Linq2SQL中,您可以使用以下内容:
// DataContext.GetCommand takes an IQueryable, not an Expression, so `expression` here must be a query
// built from this DataContext (presumably over one of its Table<T> properties); the generated SQL
// text is then read from the returned command.
var cmd = DataContext.GetCommand(expression);
var sqlQuery = cmd.CommandText;

1
您的示例不准确。GetCommand 的参数不是 Expression,而是 IQueryable,构建 IQueryable 需要有一个 IQueryProvider。因此问题仍然存在。 - Arseni Mourzenko
你想从一个没有“QueryProvider”的表达式中创建一个SQL查询吗? - Magnus
更准确地说,我想从一个表达式中创建一个部分的SQL查询,而不必将数据库中的表添加到EF/Linq2SQL中。如果我理解正确,Linq2SQL使用的查询提供程序取决于这些表。我错了吗? - Arseni Mourzenko

6
在搜寻数小时后,我没有找到任何在.NET Core上工作、有用或者免费的将表达式树转换为SQL语句的实现。然后我发现了这个。感谢Ryan Wright。 我拿了他的代码并稍作修改以适应我的需求。现在我将其还给社区。
当前版本可以执行以下操作:
批量更新
            // Bulk UPDATE: the anonymous object lists the columns to SET for every row matching the Where.
            // rowCount presumably receives the affected-row count (Update is an extension defined elsewhere).
            int rowCount = context
                .Users
                .Where(x => x.Status == UserStatus.Banned)
                .Update(x => new
                {
                    DisplayName = "Bad Guy"
                });


这将生成以下SQL。
DECLARE @p0 NVarChar
DECLARE @p1 Int
SET @p0 = 'Bad Guy'
SET @p1 = 3
UPDATE [Users]
SET [DisplayName] = @p0
WHERE ( [Status] = @p1 )

批量删除

            // Bulk DELETE of the rows matching the Where; EndsWith becomes LIKE '%012'.
            int rowCount = context
                .Users
                .Where(x => x.UniqueName.EndsWith("012"))
                .Delete();

生成的 SQL
DECLARE @p0 NVarChar
SET @p0 = '%012'
DELETE
FROM [Users]
WHERE [UniqueName] LIKE @p0

输出SQL语句

            // ToSqlString returns the DECLARE/SET script plus the SELECT instead of executing it.
            // The anonymous projection determines the selected columns.
            string sql = context
                .Users
                .Where(x => x.Status == UserStatus.LockedOut)
                .OrderBy(x => x.UniqueName)
                .ThenByDescending(x => x.LastLogin)
                .Select(x => new
                {
                    x.UniqueName,
                    x.Email
                })
                .ToSqlString();

这将生成SQL语句

DECLARE @p0 Int
SET @p0 = 4
SELECT [UniqueName], [Email]
FROM [Users]
WHERE ( [Status] = @p0 )
ORDER BY [LastLogin] DESC, [UniqueName] ASC

另一个示例

            // Skip(3).Take(4) is the LINQ order that means OFFSET 3 ROWS FETCH NEXT 4 ROWS ONLY;
            // Take(4).Skip(3) would mean "the 4th of the first 4 rows" (at most one row). Distinct comes
            // before paging because SQL applies SELECT DISTINCT before OFFSET/FETCH.
            string sql = context
                .Users
                .Where(x => x.Status == UserStatus.LockedOut)
                .OrderBy(x => x.UniqueName)
                .ThenByDescending(x => x.LastLogin)
                .Select(x => new
                {
                    x.UniqueName,
                    x.Email,
                    x.LastLogin
                })
                .Distinct()
                .Skip(3)
                .Take(4)
                .ToSqlString();

这个 SQL

DECLARE @p0 Int
SET @p0 = 4
SELECT DISTINCT [UniqueName], [Email], [LastLogin]
FROM [Users]
WHERE ( [Status] = @p0 )
ORDER BY [LastLogin] DESC, [UniqueName] ASC OFFSET 3 ROWS FETCH NEXT 4 ROWS ONLY

下面是一个关于本地变量的例子:

            // Values that do not depend on the row (the local `name`, DateTime.UtcNow) are evaluated while
            // the SQL is built and sent as parameters (@p0, @p1). The exact equality with the current time
            // is only a demonstration of that evaluation.
            string name ="venom";

            string sql = context
                .Users
                .Where(x => x.LastLogin == DateTime.UtcNow && x.UniqueName.Contains(name))
                .Select(x => x.Email)
                .ToSqlString();

生成的 SQL。
DECLARE @p0 DateTime
DECLARE @p1 NVarChar
SET @p0 = '20.06.2020 19:23:46'
SET @p1 = '%venom%'
SELECT [Email]
FROM [Users]
WHERE ( ( [LastLogin] = @p0 ) AND [UniqueName] LIKE @p1 )

可以直接使用 SimpleExpressionToSQL 类本身。

// The constructor walks the query; ExecuteNonQuery runs the generated statement inside a transaction
// with the given isolation level.
// NOTE(review): SQL Server only accepts Snapshot when ALLOW_SNAPSHOT_ISOLATION is enabled on the
// database — confirm for the target database.
var simpleExpressionToSQL = new SimpleExpressionToSQL(queryable);
simpleExpressionToSQL.ExecuteNonQuery(IsolationLevel.Snapshot);

代码

这里使用的评估器来自这里

SimpleExpressionToSQL

    public class SimpleExpressionToSQL : ExpressionVisitor
    {
        /*
         * Original By Ryan Wright: https://dev59.com/Vmsz5IYBdhLWcg3wrJ6-
         */

        // Column terms collected from GroupBy calls.
        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private readonly List<string> _groupBy = new List<string>();

        // ORDER BY terms collected from OrderBy/ThenBy (ASC) and OrderByDescending/ThenByDescending (DESC).
        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private readonly List<string> _orderBy = new List<string>();

        // Parameters created by CreateParameter, named @p0, @p1, ... in creation order.
        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private readonly List<SqlParameter> _parameters = new List<SqlParameter>();

        // Column list collected from Select calls.
        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private readonly List<string> _select = new List<string>();

        // SET assignments collected from the Update extension call.
        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private readonly List<string> _update = new List<string>();

        // WHERE tokens (columns, operators, parentheses, parameter names), later joined with spaces.
        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private readonly List<string> _where = new List<string>();

        // Row count from a Skip call, if any.
        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private int? _skip;

        // Row count from a Take call, if any.
        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private int? _take;

        /// <summary>
        /// Walks the expression of <paramref name="queryable"/>, recording its clauses, and resolves the
        /// table name and DbContext of the query root.
        /// </summary>
        /// <param name="queryable">The query to translate.</param>
        /// <exception cref="ArgumentNullException"><paramref name="queryable"/> is null.</exception>
        /// <exception cref="NotSupportedException">The root of the call chain is not a constant IQueryable.</exception>
        public SimpleExpressionToSQL(IQueryable queryable)
        {
            if (queryable is null)
            {
                throw new ArgumentNullException(nameof(queryable));
            }

            Expression expression = queryable.Expression;
            Visit(expression);

            // The table comes from the query root's element type, not queryable.ElementType, which is
            // the projected type after a Select. Fail with a clear message instead of a NullReferenceException.
            if (!(GetEntityType(expression) is IQueryable root))
            {
                throw new NotSupportedException("The root of the query expression must be a constant IQueryable.");
            }

            TableName = queryable.GetTableName(root.ElementType);
            DbContext = queryable.GetDbContext();
        }

        /// <summary>The SQL statement lines (without DECLARE/SET) joined with newlines.</summary>
        public string CommandText => BuildSqlStatement().Join(Environment.NewLine);

        /// <summary>The DbContext the query belongs to, resolved in the constructor.</summary>
        public DbContext DbContext { get; private set; }

        /// <summary>The FROM clause for <see cref="TableName"/>.</summary>
        public string From => $"FROM [{TableName}]";

        /// <summary>The GROUP BY clause, or null when the query has no GroupBy.</summary>
        public string GroupBy => _groupBy.Count == 0 ? null : "GROUP BY " + _groupBy.Join(", ");
        /// <summary>True when the query ends in Delete/DeleteAsync.</summary>
        public bool IsDelete { get; private set; } = false;
        /// <summary>True when the query contains Distinct.</summary>
        public bool IsDistinct { get; private set; }
        /// <summary>ORDER BY plus OFFSET/FETCH text; presumably an empty string (not null) when there is nothing to order.</summary>
        public string OrderBy => BuildOrderByStatement().Join(" ");
        /// <summary>A new array holding the (shared) SqlParameter instances created so far.</summary>
        public SqlParameter[] Parameters => _parameters.ToArray();
        /// <summary>The SELECT clause (DISTINCT/TOP and column list).</summary>
        public string Select => BuildSelectStatement().Join(" ");
        /// <summary>Row count from a Skip call, or null.</summary>
        public int? Skip => _skip;
        /// <summary>Table of the query root, resolved in the constructor.</summary>
        public string TableName { get; private set; }
        /// <summary>Row count from a Take call, or null.</summary>
        public int? Take => _take;
        /// <summary>The SET clause built from the Update call's assignments.</summary>
        public string Update => "SET " + _update.Join(", ");

        /// <summary>The WHERE clause, or null when the query has no predicate.</summary>
        public string Where => _where.Count == 0 ? null : "WHERE " + _where.Join(" ");

        /// <summary>Converts to the same script as <see cref="ToString"/> (declarations plus statement).</summary>
        public static implicit operator string(SimpleExpressionToSQL simpleExpression) => simpleExpression.ToString();

        /// <summary>
        /// Executes <see cref="CommandText"/> with <see cref="Parameters"/> inside a transaction on the
        /// DbContext's connection.
        /// </summary>
        /// <param name="isolationLevel">Isolation level of the wrapping transaction.</param>
        /// <returns>The value returned by <see cref="DbCommand.ExecuteNonQuery"/>.</returns>
        public int ExecuteNonQuery(IsolationLevel isolationLevel = IsolationLevel.RepeatableRead)
        {
            DbConnection connection = DbContext.Database.GetDbConnection();
            using (DbCommand command = connection.CreateCommand())
            {
                command.CommandText = CommandText;
                command.CommandType = CommandType.Text;
                command.Parameters.AddRange(Parameters);

                // Debug.WriteLine is [Conditional("DEBUG")], so no #if is needed.
                Debug.WriteLine(ToString());

                // Only close the connection afterwards if this method is the one that opened it.
                bool openedHere = connection.State != ConnectionState.Open;
                if (openedHere)
                {
                    connection.Open();
                }

                try
                {
                    using (DbTransaction transaction = connection.BeginTransaction(isolationLevel))
                    {
                        command.Transaction = transaction;
                        int result = command.ExecuteNonQuery();
                        transaction.Commit();

                        return result;
                    }
                }
                finally
                {
                    // A SqlParameter can belong to only one parameter collection; detach the shared
                    // instances so this object can be executed again.
                    command.Parameters.Clear();

                    if (openedHere)
                    {
                        connection.Close();
                    }
                }
            }
        }

        /// <summary>
        /// Asynchronously executes <see cref="CommandText"/> with <see cref="Parameters"/> inside a
        /// transaction on the DbContext's connection.
        /// </summary>
        /// <param name="isolationLevel">Isolation level of the wrapping transaction.</param>
        /// <returns>The value returned by <see cref="DbCommand.ExecuteNonQueryAsync()"/>.</returns>
        public async Task<int> ExecuteNonQueryAsync(IsolationLevel isolationLevel = IsolationLevel.RepeatableRead)
        {
            DbConnection connection = DbContext.Database.GetDbConnection();
            using (DbCommand command = connection.CreateCommand())
            {
                command.CommandText = CommandText;
                command.CommandType = CommandType.Text;
                command.Parameters.AddRange(Parameters);

                // Debug.WriteLine is [Conditional("DEBUG")], so no #if is needed.
                Debug.WriteLine(ToString());

                // Only close the connection afterwards if this method is the one that opened it.
                bool openedHere = connection.State != ConnectionState.Open;
                if (openedHere)
                {
                    await connection.OpenAsync().ConfigureAwait(false);
                }

                try
                {
                    using (DbTransaction transaction = connection.BeginTransaction(isolationLevel))
                    {
                        command.Transaction = transaction;
                        int result = await command.ExecuteNonQueryAsync().ConfigureAwait(false);
                        transaction.Commit();

                        return result;
                    }
                }
                finally
                {
                    // A SqlParameter can belong to only one parameter collection; detach the shared
                    // instances so this object can be executed again.
                    command.Parameters.Clear();

                    if (openedHere)
                    {
                        connection.Close();
                    }
                }
            }
        }

        /// <summary>
        /// Returns a runnable script: the DECLARE/SET lines for every parameter followed by the statement.
        /// </summary>
        /// <remarks>Concat, not Union: the script is an ordered list of lines, not a set.</remarks>
        public override string ToString() =>
            BuildDeclaration()
                .Concat(BuildSqlStatement())
                .Join(Environment.NewLine);

        /// <summary>
        /// Emits a parenthesised comparison or logical operation into the WHERE token list.
        /// </summary>
        /// <param name="binaryExpression">The binary node to translate.</param>
        /// <returns><paramref name="binaryExpression"/> unchanged.</returns>
        /// <exception cref="NotSupportedException">The operator has no SQL mapping here.</exception>
        protected override Expression VisitBinary(BinaryExpression binaryExpression)
        {
            Expression left = binaryExpression.Left;
            Expression right = binaryExpression.Right;

            // "null == x.Name" must become "[Name] IS NULL"; "NULL = [Name]" is never true in SQL.
            // Move the null constant to the right so the IS / IS NOT branches below see it.
            bool isEquality = binaryExpression.NodeType == ExpressionType.Equal
                || binaryExpression.NodeType == ExpressionType.NotEqual;
            if (isEquality && IsNullConstant(left))
            {
                left = binaryExpression.Right;
                right = binaryExpression.Left;
            }

            _where.Add("(");
            Visit(left);

            switch (binaryExpression.NodeType)
            {
                case ExpressionType.And:
                case ExpressionType.AndAlso:
                    _where.Add("AND");
                    break;

                case ExpressionType.Or:
                case ExpressionType.OrElse:
                    _where.Add("OR");
                    break;

                case ExpressionType.Equal:
                    _where.Add(IsNullConstant(right) ? "IS" : "=");
                    break;

                case ExpressionType.NotEqual:
                    _where.Add(IsNullConstant(right) ? "IS NOT" : "<>");
                    break;

                case ExpressionType.LessThan:
                    _where.Add("<");
                    break;

                case ExpressionType.LessThanOrEqual:
                    _where.Add("<=");
                    break;

                case ExpressionType.GreaterThan:
                    _where.Add(">");
                    break;

                case ExpressionType.GreaterThanOrEqual:
                    _where.Add(">=");
                    break;

                default:
                    throw new NotSupportedException(string.Format("The binary operator '{0}' is not supported", binaryExpression.NodeType));
            }

            Visit(right);
            _where.Add(")");
            return binaryExpression;
        }

        /// <summary>
        /// Adds NULL for a null constant, or a parameter for a value whose CLR type maps to a SQL type.
        /// Other constants add nothing (presumably this is how the query root is skipped; see CanConvertToSqlDbType).
        /// </summary>
        protected override Expression VisitConstant(ConstantExpression constantExpression)
        {
            object value = constantExpression.Value;

            if (value == null)
            {
                _where.Add("NULL");
            }
            else if (constantExpression.Type.CanConvertToSqlDbType())
            {
                _where.Add(CreateParameter(value).ParameterName);
            }

            return constantExpression;
        }

        /// <summary>
        /// Members of the lambda parameter become bracketed column names. Members of constants, of other
        /// members, and static members (no target) are evaluated and passed as parameters.
        /// </summary>
        /// <exception cref="NotSupportedException">The member's target is of another kind.</exception>
        protected override Expression VisitMember(MemberExpression memberExpression)
        {
            // A static member has no target; it is then classified by its own node type (MemberAccess).
            Expression target = memberExpression.Expression ?? memberExpression;

            switch (target.NodeType)
            {
                case ExpressionType.Parameter:
                    _where.Add($"[{memberExpression.Member.Name}]");
                    break;

                case ExpressionType.Constant:
                case ExpressionType.MemberAccess:
                    _where.Add(CreateParameter(GetValue(memberExpression)).ParameterName);
                    break;

                default:
                    throw new NotSupportedException(string.Format("The member '{0}' is not supported", memberExpression.Member.Name));
            }

            return memberExpression;
        }

        /// <summary>
        /// Translates a method call. Queryable operators fill the matching clause, and string
        /// StartsWith/EndsWith/Contains become LIKE predicates. Any other instance call is evaluated
        /// and sent as a parameter.
        /// </summary>
        /// <exception cref="NotSupportedException">The method cannot be translated.</exception>
        protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
        {
            // Escapes the LIKE wildcards [, % and _ with T-SQL bracket syntax so values such as "50%" or
            // "a_b" match literally. "[" goes first so the brackets added for % and _ are not re-escaped.
            string EscapeLike(object value) =>
                value.ToString()
                    .Replace("[", "[[]")
                    .Replace("%", "[%]")
                    .Replace("_", "[_]");

            switch (methodCallExpression.Method.Name)
            {
                case nameof(Queryable.Where) when methodCallExpression.Method.DeclaringType == typeof(Queryable):
                    int tokensBefore = _where.Count;
                    Visit(methodCallExpression.Arguments[0]);

                    // A chained Where (its source is another Where) has already emitted a predicate;
                    // combine the two with AND instead of concatenating them.
                    if (_where.Count > tokensBefore)
                    {
                        _where.Add("AND");
                    }

                    var lambda = (LambdaExpression)StripQuotes(methodCallExpression.Arguments[1]);
                    Visit(lambda.Body);

                    return methodCallExpression;

                case nameof(Queryable.Select):
                    return ParseExpression(methodCallExpression, _select);

                case nameof(Queryable.GroupBy):
                    return ParseExpression(methodCallExpression, _groupBy);

                case nameof(Queryable.Take):
                    return ParseExpression(methodCallExpression, ref _take);

                case nameof(Queryable.Skip):
                    return ParseExpression(methodCallExpression, ref _skip);

                case nameof(Queryable.OrderBy):
                case nameof(Queryable.ThenBy):
                    return ParseExpression(methodCallExpression, _orderBy, "ASC");

                case nameof(Queryable.OrderByDescending):
                case nameof(Queryable.ThenByDescending):
                    return ParseExpression(methodCallExpression, _orderBy, "DESC");

                case nameof(Queryable.Distinct):
                    IsDistinct = true;
                    return Visit(methodCallExpression.Arguments[0]);

                case nameof(string.StartsWith):
                    _where.AddRange(ParseExpression(methodCallExpression, methodCallExpression.Object));
                    _where.Add("LIKE");
                    _where.Add(CreateParameter(EscapeLike(GetValue(methodCallExpression.Arguments[0])) + "%").ParameterName);
                    return methodCallExpression.Arguments[0];

                case nameof(string.EndsWith):
                    _where.AddRange(ParseExpression(methodCallExpression, methodCallExpression.Object));
                    _where.Add("LIKE");
                    _where.Add(CreateParameter("%" + EscapeLike(GetValue(methodCallExpression.Arguments[0]))).ParameterName);
                    return methodCallExpression.Arguments[0];

                case nameof(string.Contains):
                    _where.AddRange(ParseExpression(methodCallExpression, methodCallExpression.Object));
                    _where.Add("LIKE");
                    _where.Add(CreateParameter("%" + EscapeLike(GetValue(methodCallExpression.Arguments[0])) + "%").ParameterName);
                    return methodCallExpression.Arguments[0];

                case nameof(Extensions.ToSqlString):
                    return Visit(methodCallExpression.Arguments[0]);

                case nameof(Extensions.Delete):
                case nameof(Extensions.DeleteAsync):
                    IsDelete = true;
                    return Visit(methodCallExpression.Arguments[0]);

                case nameof(Extensions.Update):
                    return ParseExpression(methodCallExpression, _update);

                default:
                    // Any other instance call (e.g. on a captured value) is evaluated client-side.
                    if (methodCallExpression.Object != null)
                    {
                        _where.Add(CreateParameter(GetValue(methodCallExpression)).ParameterName);
                        return methodCallExpression;
                    }
                    break;
            }

            throw new NotSupportedException($"The method '{methodCallExpression.Method.Name}' is not supported");
        }

        /// <summary>
        /// NOT becomes a prefix token; Convert is transparent and only its operand is translated.
        /// </summary>
        /// <exception cref="NotSupportedException">Any other unary operator.</exception>
        protected override Expression VisitUnary(UnaryExpression unaryExpression)
        {
            ExpressionType nodeType = unaryExpression.NodeType;

            if (nodeType != ExpressionType.Not && nodeType != ExpressionType.Convert)
            {
                throw new NotSupportedException($"The unary operator '{unaryExpression.NodeType}' is not supported");
            }

            if (nodeType == ExpressionType.Not)
            {
                _where.Add("NOT");
            }

            Visit(unaryExpression.Operand);
            return unaryExpression;
        }

        /// <summary>Unwraps Quote nodes (as around Queryable lambda arguments) until a non-quote node remains.</summary>
        private static Expression StripQuotes(Expression expression) =>
            expression.NodeType == ExpressionType.Quote
                ? StripQuotes(((UnaryExpression)expression).Operand)
                : expression;

        /// <summary>
        /// Yields a DECLARE line for every parameter, then a SET line assigning its value as a T-SQL
        /// literal, so that <see cref="ToString"/> produces a script that can be run by hand.
        /// </summary>
        [SuppressMessage("Style", "IDE0011:Add braces", Justification = "Easier to read")]
        private IEnumerable<string> BuildDeclaration()
        {
            // Renders a parameter value as a T-SQL literal.
            string FormatLiteral(SqlParameter parameter)
            {
                // A missing value is the NULL keyword, not the string 'NULL'.
                if (parameter.Value is null || parameter.Value is DBNull)       /**/    return "NULL";

                // ISO 8601 is read the same way whatever the server's DATEFORMAT/language is,
                // unlike the culture-dependent default text.
                if (parameter.Value is DateTime dateTime)
                {
                    string iso = dateTime.ToString("yyyy-MM-dd'T'HH:mm:ss.fff", System.Globalization.CultureInfo.InvariantCulture);
                    return $"'{iso}'";
                }

                if (parameter.SqlDbType.RequiresQuotes())                       /**/    return $"'{parameter.SqlValue?.ToString().Replace("'", "''") ?? "NULL"}'";
                return $"{parameter.SqlValue}";
            }

            // Parameters copies the list on every access; take one snapshot.
            SqlParameter[] parameters = Parameters;

            foreach (SqlParameter parameter in parameters)     /**/    yield return $"DECLARE {parameter.ParameterName} {parameter.SqlDbType}";

            foreach (SqlParameter parameter in parameters)     /**/    yield return $"SET {parameter.ParameterName} = {FormatLiteral(parameter)}";
        }

        /// <summary>
        /// Yields the ORDER BY clause and, when Skip is set, the OFFSET/FETCH paging clause. SQL Server
        /// requires ORDER BY for OFFSET, so "(SELECT NULL)" is used when the query has no ordering.
        /// </summary>
        [SuppressMessage("Style", "IDE0011:Add braces", Justification = "Easier to read")]
        private IEnumerable<string> BuildOrderByStatement()
        {
            // The visitor reaches the outermost call (the last ThenBy) first, so _orderBy holds the terms
            // last-to-first. Reverse them so the primary sort key leads, as in LINQ.
            // NOTE(review): assumes ParseExpression records each call before visiting its source — confirm.
            IEnumerable<string> orderTerms = Enumerable.Reverse(_orderBy);

            if (Skip.HasValue && _orderBy.Count == 0)                                                   /**/   yield return "ORDER BY (SELECT NULL)";
            else if (_orderBy.Count == 0)                                                               /**/   yield break;
            else if (_groupBy.Count > 0 && _orderBy[0].StartsWith("[Key]", StringComparison.Ordinal))   /**/   yield return "ORDER BY " + _groupBy.Join(", ");
            else                                                                                        /**/   yield return "ORDER BY " + orderTerms.Join(", ");

            if (Skip.HasValue && Take.HasValue)                                                         /**/   yield return $"OFFSET {Skip} ROWS FETCH NEXT {Take} ROWS ONLY";
            else if (Skip.HasValue && !Take.HasValue)                                                   /**/   yield return $"OFFSET {Skip} ROWS";
        }

        /// <summary>
        /// Yields the SELECT clause tokens: SELECT, optional DISTINCT, TOP when Take is used without
        /// Skip, and the column list (projection, MAX() over grouped columns, or *).
        /// </summary>
        [SuppressMessage("Style", "IDE0011:Add braces", Justification = "Easier to read")]
        private IEnumerable<string> BuildSelectStatement()
        {
            yield return "SELECT";

            if (IsDistinct)
            {
                yield return "DISTINCT";
            }

            // With Skip present, paging is expressed by OFFSET/FETCH instead of TOP.
            bool useTop = Take.HasValue && !Skip.HasValue;
            if (useTop)
            {
                yield return $"TOP ({Take.Value})";
            }

            string columns;
            if (_select.Count > 0)
            {
                columns = _select.Join(", ");
            }
            else if (_groupBy.Count > 0)
            {
                columns = _groupBy.Select(column => $"MAX({column})").Join(", ");
            }
            else
            {
                columns = "*";
            }

            yield return columns;
        }

        /// <summary>
        /// Yields the statement lines in order: the DELETE, UPDATE or SELECT head, then FROM or SET,
        /// then WHERE, GROUP BY and ORDER BY when present.
        /// </summary>
        [SuppressMessage("Style", "IDE0011:Add braces", Justification = "Easier to read")]
        private IEnumerable<string> BuildSqlStatement()
        {
            bool isUpdate = _update.Count > 0;

            if (IsDelete)                                   /**/   yield return "DELETE";
            else if (isUpdate)                              /**/   yield return $"UPDATE [{TableName}]";
            else                                            /**/   yield return Select;

            if (isUpdate)                                   /**/   yield return Update;
            else                                            /**/   yield return From;

            // Where and GroupBy are null when absent, but OrderBy is (presumably) an empty string because it
            // joins an empty sequence; IsNullOrEmpty skips both, so no blank trailing line is emitted.
            string whereClause = Where;
            string groupByClause = GroupBy;
            string orderByClause = OrderBy;

            if (!string.IsNullOrEmpty(whereClause))         /**/   yield return whereClause;
            if (!string.IsNullOrEmpty(groupByClause))       /**/   yield return groupByClause;
            if (!string.IsNullOrEmpty(orderByClause))       /**/   yield return orderByClause;
        }

        /// <summary>
        /// Creates a positional SQL parameter (@p0, @p1, ...) for <paramref name="value"/>.
        /// The parameter is registered in the parameter list and then returned.
        /// </summary>
        /// <param name="value">
        /// The value to bind. A C# <c>null</c> is bound as <see cref="DBNull.Value"/>.
        /// </param>
        /// <returns>The newly registered parameter.</returns>
        private SqlParameter CreateParameter(object value)
        {
            // The index is taken before adding, so names are numbered by creation order.
            string parameterName = $"@p{_parameters.Count}";

            var parameter = new SqlParameter()
            {
                ParameterName = parameterName,
                // ADO.NET treats a null Value as "parameter not supplied", which makes
                // SQL Server reject the command. SQL NULL must be sent as DBNull.Value.
                Value = value ?? DBNull.Value
            };

            _parameters.Add(parameter);

            return parameter;
        }

        /// <summary>
        /// Follows the first argument of each nested method call down the query chain.
        /// It stops at the first node that is not a method call.
        /// </summary>
        /// <returns>
        /// The value of that node when it is a <see cref="ConstantExpression"/>; otherwise null.
        /// </returns>
        private object GetEntityType(Expression expression)
        {
            Expression current = expression;

            // Each query operator (Where, OrderBy, ...) carries its source as argument 0.
            while (current is MethodCallExpression call)
            {
                current = call.Arguments[0];
            }

            return current is ConstantExpression root ? root.Value : null;
        }

        /// <summary>
        /// Turns each member of a <see cref="NewExpression"/> (e.g. <c>x => new { x.A, B = 5 }</c>)
        /// into a SQL fragment.
        /// </summary>
        /// <remarks>
        /// A member whose argument is a member access yields "[Member]". When
        /// <paramref name="appendString"/> is given, it yields "[Member] {appendString}" instead
        /// (presumably a sort direction; confirm against callers).
        /// Any other argument is evaluated via GetValue and bound as a new parameter, yielding
        /// "[Member] = @pN"; <paramref name="appendString"/> is not used for these.
        /// The sequence is lazy, so parameters are created only when it is enumerated.
        /// </remarks>
        /// <exception cref="NotSupportedException">
        /// Raised (on enumeration) when the expression has no member mapping, i.e. a plain
        /// constructor call rather than an anonymous-type style initializer.
        /// </exception>
        private IEnumerable<string> GetNewExpressionString(NewExpression newExpression, string appendString = null)
        {
            // Members is null for NewExpressions built from a constructor without member mapping;
            // previously this surfaced as a NullReferenceException.
            if (newExpression.Members == null)
            {
                throw new NotSupportedException();
            }

            for (int i = 0; i < newExpression.Members.Count; i++)
            {
                string column = $"[{newExpression.Members[i].Name}]";
                Expression argument = newExpression.Arguments[i];

                if (argument.NodeType == ExpressionType.MemberAccess)
                {
                    yield return appendString == null ? column : $"{column} {appendString}";
                }
                else
                {
                    yield return $"{column} = {CreateParameter(GetValue(argument)).ParameterName}";
                }
            }
        }

        /// <summary>
        /// Evaluates a small set of expression shapes to their runtime value.
        /// </summary>
        /// <remarks>
        /// Supported shapes:
        /// - constants;
        /// - field/property reads on a constant instance (closures);
        /// - static field/property reads;
        /// - method calls, which are compiled and invoked.
        /// A null expression yields null.
        /// </remarks>
        /// <exception cref="NotSupportedException">The expression has any other shape.</exception>
        private object GetValue(Expression expression)
        {
            // Reads a field or property; container is null for static members.
            object ReadMember(MemberInfo info, object container)
            {
                if (info is FieldInfo field)
                {
                    return field.GetValue(container);
                }

                if (info is PropertyInfo property)
                {
                    return property.GetValue(container);
                }

                return null;
            }

            if (expression is null)
            {
                return null;
            }

            if (expression is ConstantExpression constant)
            {
                return constant.Value;
            }

            if (expression is MemberExpression member)
            {
                if (member.Expression is ConstantExpression owner)
                {
                    return ReadMember(member.Member, owner.Value);
                }

                if (member.Expression is null)
                {
                    return ReadMember(member.Member, null);
                }
            }

            if (expression is MethodCallExpression call)
            {
                return Expression.Lambda(call).Compile().DynamicInvoke();
            }

            throw new NotSupportedException();
        }

        /// <summary>
        /// Returns true when <paramref name="expression"/> is a constant node whose value is null.
        /// </summary>
        private bool IsNullConstant(Expression expression)
        {
            if (expression.NodeType != ExpressionType.Constant)
            {
                return false;
            }

            return ((ConstantExpression)expression).Value == null;
        }

        /// <summary>
        /// Converts the body of a lambda into column or assignment fragments.
        /// </summary>
        /// <remarks>
        /// Handled body shapes:
        /// - member access: "[Member]", optionally followed by <paramref name="appendString"/>;
        /// - <see cref="NewExpression"/>: delegated to GetNewExpressionString;
        /// - identity projection (<c>x => x</c>, detected via <paramref name="parent"/>'s return type):
        ///   nothing;
        /// - constant object: "[Property] = @pN" for every public instance property.
        /// </remarks>
        /// <exception cref="NotSupportedException">Thrown immediately for any other body shape.</exception>
        private IEnumerable<string> ParseExpression(Expression parent, Expression body, string appendString = null)
        {
            if (body is MemberExpression member)
            {
                string column = $"[{member.Member.Name}]";
                return new string[] { appendString == null ? column : $"{column} {appendString}" };
            }

            if (body is NewExpression newExpression)
            {
                return GetNewExpressionString(newExpression, appendString);
            }

            if (body is ParameterExpression parameter
                && parent is LambdaExpression lambda
                && lambda.ReturnType == parameter.Type)
            {
                return new string[0];
            }

            if (body is ConstantExpression constant)
            {
                // Lazy: one parameter per property is created when the result is enumerated.
                return constant.Type
                    .GetProperties(BindingFlags.Public | BindingFlags.Instance)
                    .Select(property => $"[{property.Name}] = {CreateParameter(property.GetValue(constant.Value)).ParameterName}");
            }

            throw new NotSupportedException();
        }

        /// <summary>
        /// Handles a query operator whose second argument is a quoted lambda (selector/key).
        /// </summary>
        /// <remarks>
        /// The lambda is partially evaluated and translated into fragments, which are appended to
        /// <paramref name="commandList"/>. Visiting then continues with the operator's source
        /// (argument 0).
        /// </remarks>
        private Expression ParseExpression(MethodCallExpression expression, List<string> commandList, string appendString = null)
        {
            // Lambdas passed to Queryable operators arrive wrapped in a Quote node.
            var quoted = (UnaryExpression)expression.Arguments[1];
            var lambda = (LambdaExpression)Evaluator.PartialEval((LambdaExpression)quoted.Operand);

            IEnumerable<string> fragments = ParseExpression(lambda, lambda.Body, appendString);
            commandList.AddRange(fragments);

            return Visit(expression.Arguments[0]);
        }

        /// <summary>
        /// Handles Skip/Take-style operators whose second argument is a constant count.
        /// </summary>
        /// <remarks>
        /// The count is stored in <paramref name="size"/>, and visiting continues with the source
        /// (argument 0).
        /// </remarks>
        /// <exception cref="NotSupportedException">The constant's text is not an integer.</exception>
        private Expression ParseExpression(MethodCallExpression expression, ref int? size)
        {
            var sizeExpression = (ConstantExpression)expression.Arguments[1];

            if (!int.TryParse(sizeExpression.Value.ToString(), out int parsed))
            {
                throw new NotSupportedException();
            }

            size = parsed;
            return Visit(expression.Arguments[0]);
        }
    }

我会在评论区贴出扩展方法。

请在生产环境中小心使用

欢迎将其制作成 NuGet 包 :)


兄弟,你有这个的完整工作示例吗?我在将它们全部组合在一起时遇到了麻烦。什么是上下文,为什么它有用户?我如何在具有启用布尔状态的简单类上使用它? - Secretary Of Education
@SecretaryOfEducation 上下文只是一个DBContext实例。 - Legacy Code

4
您基本上需要重新发明轮子。QueryProvider 是将表达式树转换为其底层存储原生语法的组件。它还要处理特殊情况,例如 string.Contains()、string.StartsWith() 以及处理它们所需的各种专用函数。它还会处理 ORM 各个层次中的元数据查找(在数据库优先或模型优先的 Entity Framework 中使用 *.edmx 文件)。已经有一些示例和框架可用于构建 SQL 命令,但您要寻找的似乎只是一个部分解决方案。
此外,了解正确确定什么是合法的需要表/视图元数据。查询提供程序非常复杂,并且除了将简单的表达式树转换为SQL之外,它们还为您完成了大量工作。
针对您的问题“第二部分发生在哪里”,第二部分发生在IQueryable枚举时。IQueryables也是IEnumerables,在调用GetEnumerator时,它会调用查询提供程序来处理表达式树,查询提供程序将使用其元数据生成一个SQL命令。虽然这不是确切的过程,但应该可以帮助您理解这个概念。

3

1
对于EF来说,这是否需要先建立一个ObjectContext或DbContext(这意味着几乎要把所有内容都设置好),而不能只使用EF的一部分? - Kit

1

不确定这是否完全符合您的需求,但看起来可能接近:

var companies = new[]
{
    "Consolidated Messenger", "Alpine Ski House", "Southridge Video", "City Power & Light",
    "Coho Winery", "Wide World Importers", "Graphic Design Institute", "Adventure Works",
    "Humongous Insurance", "Woodgrove Bank", "Margie's Travel", "Northwind Traders",
    "Blue Yonder Airlines", "Trey Research", "The Phone Company",
    "Wingtip Toys", "Lucerne Publishing", "Fourth Coffee",
};

// Expose the in-memory array through IQueryable so a query can be assembled
// from hand-built expression nodes instead of C# query syntax.
IQueryable<string> source = companies.AsQueryable();

// The single lambda parameter shared by both lambdas below.
ParameterExpression companyParam = Expression.Parameter(typeof(string), "company");

// company.ToLower() == "coho winery"
Expression matchesCoho = Expression.Equal(
    Expression.Call(companyParam, typeof(string).GetMethod("ToLower", Type.EmptyTypes)),
    Expression.Constant("coho winery"));

// company.Length > 16
Expression isLongName = Expression.GreaterThan(
    Expression.Property(companyParam, typeof(string).GetProperty("Length")),
    Expression.Constant(16, typeof(int)));

// company => company.ToLower() == "coho winery" || company.Length > 16
Expression predicateBody = Expression.OrElse(matchesCoho, isLongName);
Expression<Func<string, bool>> predicate = Expression.Lambda<Func<string, bool>>(predicateBody, companyParam);

// source.Where(predicate)
MethodCallExpression filtered = Expression.Call(
    typeof(Queryable),
    "Where",
    new[] { source.ElementType },
    source.Expression,
    predicate);

// filtered.OrderBy(company => company)
MethodCallExpression ordered = Expression.Call(
    typeof(Queryable),
    "OrderBy",
    new[] { source.ElementType, source.ElementType },
    filtered,
    Expression.Lambda<Func<string, string>>(companyParam, companyParam));

// Hand the composed tree back to the provider to obtain a runnable query, then run it.
IQueryable<string> results = source.Provider.CreateQuery<string>(ordered);

foreach (string name in results)
{
    Console.WriteLine(name);
}

1

Extensions for the SimpleExpressionToSQL class

    /// <summary>
    /// Extension methods that route EF Core queryables through <c>SimpleExpressionToSQL</c>
    /// (DELETE, UPDATE and SQL text generation). Also provides helpers that map CLR types to
    /// <c>SqlDbType</c> and small string/reflection utilities.
    /// </summary>
    public static class Extensions
    {
        // Marker methods whose calls AppendCall embeds into a query's expression tree.
        private static readonly MethodInfo _deleteMethod;
        private static readonly MethodInfo _deleteMethodAsync;
        private static readonly MethodInfo _toSqlStringMethod;
        private static readonly MethodInfo _updateMethod;
        private static readonly MethodInfo _updateMethodAsync;

        /// <summary>
        /// Resolves the marker <c>MethodInfo</c>s once, so each call only has to close them
        /// over its generic arguments.
        /// </summary>
        static Extensions()
        {
            Type extensionType = typeof(Extensions);

            _deleteMethod = extensionType.GetMethod(nameof(Extensions.Delete), BindingFlags.Static | BindingFlags.Public);
            _updateMethod = extensionType.GetMethod(nameof(Extensions.Update), BindingFlags.Static | BindingFlags.Public);

            _deleteMethodAsync = extensionType.GetMethod(nameof(Extensions.DeleteAsync), BindingFlags.Static | BindingFlags.Public);

            // NOTE(review): unlike DeleteAsync above, the async update marker resolves to the
            // synchronous Update method, so the translator sees an "Update" call for UpdateAsync too.
            // Left unchanged because pointing it at UpdateAsync only works if the translator also
            // recognizes that method name. Confirm before changing.
            _updateMethodAsync = extensionType.GetMethod(nameof(Extensions.Update), BindingFlags.Static | BindingFlags.Public);

            _toSqlStringMethod = extensionType.GetMethod(nameof(Extensions.ToSqlString), BindingFlags.Static | BindingFlags.Public);
        }

        /// <summary>Returns true when <paramref name="type"/> has a mapping in <c>ToSqlDbType</c>.</summary>
        public static bool CanConvertToSqlDbType(this Type type) => type.ToSqlDbTypeInternal().HasValue;

        /// <summary>Translates the query into a DELETE statement and executes it.</summary>
        /// <returns>The result of <c>ExecuteNonQuery</c> (presumably the affected row count).</returns>
        public static int Delete<T>(this IQueryable<T> queryable)
        {
            var simpleExpressionToSQL = new SimpleExpressionToSQL(queryable.AppendCall(_deleteMethod));
            return simpleExpressionToSQL.ExecuteNonQuery();
        }

        /// <summary>Asynchronous counterpart of <c>Delete</c>.</summary>
        public static async Task<int> DeleteAsync<T>(this IQueryable<T> queryable)
        {
            var simpleExpressionToSQL = new SimpleExpressionToSQL(queryable.AppendCall(_deleteMethodAsync));
            return await simpleExpressionToSQL.ExecuteNonQueryAsync();
        }

        /// <summary>
        /// Returns the value of the "Relational:TableName" annotation of the entity type mapped to
        /// <typeparamref name="TEntity"/> in the set's context model.
        /// </summary>
        public static string GetTableName<TEntity>(this DbSet<TEntity> dbSet) where TEntity : class
        {
            DbContext context = dbSet.GetService<ICurrentDbContext>().Context;
            return GetTableNameFromModel(context.Model, typeof(TEntity));
        }

        /// <summary>
        /// Returns the value of the "Relational:TableName" annotation of the entity type mapped to
        /// <paramref name="entity"/>. The model is read from the query provider's private
        /// <c>_queryCompiler</c> / <c>_model</c> fields.
        /// </summary>
        public static string GetTableName(this IQueryable query, Type entity)
        {
            QueryCompiler compiler = query.Provider.GetValueOfField<QueryCompiler>("_queryCompiler");
            IModel model = compiler.GetValueOfField<IModel>("_model");

            return GetTableNameFromModel(model, entity);
        }

        /// <summary>Maps a CLR type to a <c>SqlDbType</c>.</summary>
        /// <exception cref="InvalidCastException">No mapping exists for <paramref name="type"/>.</exception>
        public static SqlDbType ToSqlDbType(this Type type) =>
            type.ToSqlDbTypeInternal() ?? throw new InvalidCastException($"Unable to cast from '{type}' to '{typeof(SqlDbType)}'.");

        /// <summary>
        /// Returns the SQL text produced by <c>SimpleExpressionToSQL</c> for the query.
        /// This relies on that class converting to string.
        /// </summary>
        public static string ToSqlString<T>(this IQueryable<T> queryable) => new SimpleExpressionToSQL(queryable.AppendCall(_toSqlStringMethod));

        /// <summary>
        /// Translates the query plus <paramref name="selector"/> into an UPDATE statement and executes it.
        /// </summary>
        /// <returns>The result of <c>ExecuteNonQuery</c> (presumably the affected row count).</returns>
        public static int Update<TSource, TResult>(this IQueryable<TSource> queryable, Expression<Func<TSource, TResult>> selector)
        {
            var simpleExpressionToSQL = new SimpleExpressionToSQL(queryable.AppendCall(_updateMethod, selector));
            return simpleExpressionToSQL.ExecuteNonQuery();
        }

        /// <summary>Asynchronous counterpart of <c>Update</c>.</summary>
        public static async Task<int> UpdateAsync<TSource, TResult>(this IQueryable<TSource> queryable, Expression<Func<TSource, TResult>> selector)
        {
            var simpleExpressionToSQL = new SimpleExpressionToSQL(queryable.AppendCall(_updateMethodAsync, selector));
            return await simpleExpressionToSQL.ExecuteNonQueryAsync();
        }

        /// <summary>
        /// Digs the owning <c>DbContext</c> out of an EF Core query provider. It does this through
        /// the private fields <c>_queryCompiler</c>, <c>_queryContextFactory</c> and <c>_dependencies</c>.
        /// </summary>
        internal static DbContext GetDbContext(this IQueryable query)
        {
            QueryCompiler compiler = query.Provider.GetValueOfField<QueryCompiler>("_queryCompiler");
            RelationalQueryContextFactory queryContextFactory = compiler.GetValueOfField<RelationalQueryContextFactory>("_queryContextFactory");
            QueryContextDependencies dependencies = queryContextFactory.GetValueOfField<QueryContextDependencies>("_dependencies");

            return dependencies.CurrentContext.Context;
        }

        /// <summary>Shorthand for <c>string.Join(separator, values)</c>.</summary>
        internal static string Join(this IEnumerable<string> values, string separator) => string.Join(separator, values);

        /// <summary>
        /// Returns true for the character, date/time, GUID, timestamp, XML and variant types listed
        /// below. Presumably these are the types whose literals must be quoted when inlined into SQL text.
        /// </summary>
        internal static bool RequiresQuotes(this SqlDbType sqlDbType)
        {
            switch (sqlDbType)
            {
                case SqlDbType.Char:
                case SqlDbType.Date:
                case SqlDbType.DateTime:
                case SqlDbType.DateTime2:
                case SqlDbType.DateTimeOffset:
                case SqlDbType.NChar:
                case SqlDbType.NText:
                case SqlDbType.Time:
                case SqlDbType.SmallDateTime:
                case SqlDbType.Text:
                case SqlDbType.UniqueIdentifier:
                case SqlDbType.Timestamp:
                case SqlDbType.VarChar:
                case SqlDbType.Xml:
                case SqlDbType.Variant:
                case SqlDbType.NVarChar:
                    return true;

                default:
                    return false;
            }
        }

        /// <summary>
        /// Returns a copy of <paramref name="value"/> with its first character lower-cased using
        /// invariant culture. Null and empty input are returned unchanged.
        /// </summary>
        internal static string ToCamelCase(this string value)
        {
            if (string.IsNullOrEmpty(value))
            {
                return value;
            }

            // Build a new string rather than mutating a string.Copy result through an unsafe
            // pointer. string.Copy is obsolete, and .NET strings must be treated as immutable.
            return char.ToLowerInvariant(value[0]) + value.Substring(1);
        }

        /// <summary>
        /// Wraps the query expression in a call to <paramref name="methodInfo"/>, closed over
        /// TSource/TResult, with the selector as second argument. The result is returned as a new
        /// <c>EntityQueryable</c> on the same provider.
        /// </summary>
        /// <remarks>
        /// Expression.Call quotes the selector lambda automatically. If the provider is not an
        /// <c>IAsyncQueryProvider</c>, null is passed on to <c>EntityQueryable</c>.
        /// </remarks>
        private static IQueryable<TResult> AppendCall<TSource, TResult>(this IQueryable<TSource> queryable, MethodInfo methodInfo, Expression<Func<TSource, TResult>> selector)
        {
            MethodInfo methodInfoGeneric = methodInfo.MakeGenericMethod(typeof(TSource), typeof(TResult));
            MethodCallExpression methodCallExpression = Expression.Call(methodInfoGeneric, queryable.Expression, selector);

            return new EntityQueryable<TResult>(queryable.Provider as IAsyncQueryProvider, methodCallExpression);
        }

        /// <summary>
        /// Wraps the query expression in a call to <paramref name="methodInfo"/>, closed over
        /// <typeparamref name="T"/>. The result is returned as a new <c>EntityQueryable</c> on the
        /// same provider.
        /// </summary>
        private static IQueryable<T> AppendCall<T>(this IQueryable<T> queryable, MethodInfo methodInfo)
        {
            MethodInfo methodInfoGeneric = methodInfo.MakeGenericMethod(typeof(T));
            MethodCallExpression methodCallExpression = Expression.Call(methodInfoGeneric, queryable.Expression);

            return new EntityQueryable<T>(queryable.Provider as IAsyncQueryProvider, methodCallExpression);
        }

        /// <summary>
        /// Finds the entity type whose CLR type is <paramref name="clrType"/> and returns the value
        /// of its "Relational:TableName" annotation.
        /// </summary>
        /// <exception cref="InvalidOperationException">No entity type maps to <paramref name="clrType"/>.</exception>
        private static string GetTableNameFromModel(IModel model, Type clrType)
        {
            IEntityType entityType = model
                .GetEntityTypes()
                .First(t => t.ClrType == clrType);

            IAnnotation tableNameAnnotation = entityType.GetAnnotation("Relational:TableName");

            return tableNameAnnotation.Value.ToString();
        }

        /// <summary>
        /// Reads the private instance field <paramref name="name"/> declared on the runtime type of
        /// <paramref name="obj"/>.
        /// </summary>
        /// <exception cref="InvalidOperationException">No such field exists.</exception>
        private static T GetValueOfField<T>(this object obj, string name)
        {
            Type type = obj.GetType();
            FieldInfo field = type.GetField(name, BindingFlags.NonPublic | BindingFlags.Instance);

            if (field == null)
            {
                // These are private implementation fields (presumably tied to a specific EF Core
                // version). Report which one is missing instead of a bare NullReferenceException.
                throw new InvalidOperationException($"Field '{name}' was not found on type '{type}'.");
            }

            return (T)field.GetValue(obj);
        }

        /// <summary>
        /// Maps a CLR type to a <c>SqlDbType</c>, or returns null when no mapping exists.
        /// Nullable types and enums are mapped through their underlying type.
        /// </summary>
        [SuppressMessage("Style", "IDE0011:Add braces", Justification = "Easier to read than with Allman braces")]
        private static SqlDbType? ToSqlDbTypeInternal(this Type type)
        {
            if (Nullable.GetUnderlyingType(type) is Type nullableType)
                return nullableType.ToSqlDbTypeInternal();

            if (type.IsEnum)
                return Enum.GetUnderlyingType(type).ToSqlDbTypeInternal();

            if (type == typeof(long))            /**/                return SqlDbType.BigInt;
            if (type == typeof(byte[]))          /**/                return SqlDbType.VarBinary;
            if (type == typeof(bool))            /**/                return SqlDbType.Bit;
            if (type == typeof(string))          /**/                return SqlDbType.NVarChar;
            if (type == typeof(DateTime))        /**/                return SqlDbType.DateTime2;
            if (type == typeof(decimal))         /**/                return SqlDbType.Decimal;
            if (type == typeof(double))          /**/                return SqlDbType.Float;
            if (type == typeof(int))             /**/                return SqlDbType.Int;
            if (type == typeof(float))           /**/                return SqlDbType.Real;
            if (type == typeof(Guid))            /**/                return SqlDbType.UniqueIdentifier;
            if (type == typeof(short))           /**/                return SqlDbType.SmallInt;
            if (type == typeof(object))          /**/                return SqlDbType.Variant;
            if (type == typeof(DateTimeOffset))  /**/                return SqlDbType.DateTimeOffset;
            if (type == typeof(TimeSpan))        /**/                return SqlDbType.Time;
            if (type == typeof(byte))            /**/                return SqlDbType.TinyInt;

            return null;
        }
    }

网页内容由 Stack Overflow 提供,点击上面的原文链接可以查看英文原文。