搜寻了数小时后，我仍没有找到任何能在 .NET Core 上运行、实用且免费的、可将表达式树转换为 SQL 语句的实现。后来我发现了这个实现。感谢
Ryan Wright。
我拿了他的代码并稍作修改以适应我的需求。现在我将其还给社区。
当前版本可以执行以下操作:
批量更新
int rowCount = context
.Users
.Where(x => x.Status == UserStatus.Banned)
.Update(x => new
{
DisplayName = "Bad Guy"
});
这将生成以下SQL。
DECLARE @p0 NVarChar
DECLARE @p1 Int
SET @p0 = 'Bad Guy'
SET @p1 = 3
UPDATE [Users]
SET [DisplayName] = @p0
WHERE ( [Status] = @p1 )
批量删除
int rowCount = context
.Users
.Where(x => x.UniqueName.EndsWith("012"))
.Delete();
生成的 SQL
DECLARE @p0 NVarChar
SET @p0 = '%012'
DELETE
FROM [Users]
WHERE [UniqueName] LIKE @p0
输出SQL语句
string sql = context
.Users
.Where(x => x.Status == UserStatus.LockedOut)
.OrderBy(x => x.UniqueName)
.ThenByDescending(x => x.LastLogin)
.Select(x => new
{
x.UniqueName,
x.Email
})
.ToSqlString();
这将生成SQL语句
DECLARE @p0 Int
SET @p0 = 4
SELECT [UniqueName], [Email]
FROM [Users]
WHERE ( [Status] = @p0 )
ORDER BY [LastLogin] DESC, [UniqueName] ASC
另一个示例
string sql = context
.Users
.Where(x => x.Status == UserStatus.LockedOut)
.OrderBy(x => x.UniqueName)
.ThenByDescending(x => x.LastLogin)
.Select(x => new
{
x.UniqueName,
x.Email,
x.LastLogin
})
.Take(4)
.Skip(3)
.Distinct()
.ToSqlString();
这将生成以下 SQL。
DECLARE @p0 Int
SET @p0 = 4
SELECT DISTINCT [UniqueName], [Email], [LastLogin]
FROM [Users]
WHERE ( [Status] = @p0 )
ORDER BY [LastLogin] DESC, [UniqueName] ASC OFFSET 3 ROWS FETCH NEXT 4 ROWS ONLY
下面是一个关于本地变量的例子:
string name ="venom";
string sql = context
.Users
.Where(x => x.LastLogin == DateTime.UtcNow && x.UniqueName.Contains(name))
.Select(x => x.Email)
.ToSqlString();
生成的 SQL。
DECLARE @p0 DateTime
DECLARE @p1 NVarChar
SET @p0 = '20.06.2020 19:23:46'
SET @p1 = '%venom%'
SELECT [Email]
FROM [Users]
WHERE ( ( [LastLogin] = @p0 ) AND [UniqueName] LIKE @p1 )
可以直接使用 SimpleExpressionToSQL 类本身。
var simpleExpressionToSQL = new SimpleExpressionToSQL(queryable);
simpleExpressionToSQL.ExecuteNonQuery(IsolationLevel.Snapshot);
代码
代码中使用的评估器（Evaluator.PartialEval）来自外部实现，未包含在下面的代码中。
SimpleExpressionToSQL
/// <summary>
/// Expression visitor that walks the expression tree of an <see cref="IQueryable"/> once and
/// collects the fragments (SELECT list, SET list, WHERE, GROUP BY, ORDER BY, paging) and the
/// parameters of a single T-SQL SELECT, UPDATE or DELETE statement.
/// </summary>
/// <remarks>
/// The tree is visited from the outermost method call inwards. Only the operators handled in
/// <see cref="VisitMethodCall"/>, <see cref="VisitBinary"/> and <see cref="VisitUnary"/> are
/// supported; everything else raises <see cref="NotSupportedException"/>.
/// </remarks>
public class SimpleExpressionToSQL : ExpressionVisitor
{
    [DebuggerBrowsable(DebuggerBrowsableState.Never)]
    private readonly List<string> _groupBy = new List<string>();
    [DebuggerBrowsable(DebuggerBrowsableState.Never)]
    private readonly List<string> _orderBy = new List<string>();
    [DebuggerBrowsable(DebuggerBrowsableState.Never)]
    private readonly List<SqlParameter> _parameters = new List<SqlParameter>();
    [DebuggerBrowsable(DebuggerBrowsableState.Never)]
    private readonly List<string> _select = new List<string>();
    [DebuggerBrowsable(DebuggerBrowsableState.Never)]
    private readonly List<string> _update = new List<string>();
    [DebuggerBrowsable(DebuggerBrowsableState.Never)]
    private readonly List<string> _where = new List<string>();
    [DebuggerBrowsable(DebuggerBrowsableState.Never)]
    private int? _skip;
    [DebuggerBrowsable(DebuggerBrowsableState.Never)]
    private int? _take;

    /// <summary>
    /// Visits the expression of <paramref name="queryable"/> and captures everything needed to
    /// build the SQL statement.
    /// </summary>
    /// <param name="queryable">The query to translate.</param>
    /// <exception cref="ArgumentNullException"><paramref name="queryable"/> is null.</exception>
    /// <exception cref="ArgumentException">
    /// The innermost source of the query expression is not a constant <see cref="IQueryable"/>.
    /// </exception>
    /// <exception cref="NotSupportedException">The query uses an unsupported construct.</exception>
    public SimpleExpressionToSQL(IQueryable queryable)
    {
        if (queryable is null)
        {
            throw new ArgumentNullException(nameof(queryable));
        }

        Expression expression = queryable.Expression;
        Visit(expression);

        // The original code used "as IQueryable" and dereferenced the result unchecked, which
        // surfaced as a bare NullReferenceException for unexpected query roots.
        if (!(GetEntityType(expression) is IQueryable root))
        {
            throw new ArgumentException("The root of the query expression is not an IQueryable.", nameof(queryable));
        }

        // GetTableName / GetDbContext are extension methods defined outside this class.
        TableName = queryable.GetTableName(root.ElementType);
        DbContext = queryable.GetDbContext();
    }

    /// <summary>The SQL statement without the DECLARE/SET preamble, one clause per line.</summary>
    public string CommandText => BuildSqlStatement().Join(Environment.NewLine);

    /// <summary>The context obtained from the queryable; supplies the connection used for execution.</summary>
    public DbContext DbContext { get; private set; }

    /// <summary>The FROM clause.</summary>
    public string From => $"FROM [{TableName}]";

    /// <summary>The GROUP BY clause, or null when the query has no grouping.</summary>
    public string GroupBy => _groupBy.Count == 0 ? null : "GROUP BY " + _groupBy.Join(", ");

    /// <summary>True when the query ends in a Delete/DeleteAsync call.</summary>
    public bool IsDelete { get; private set; } = false;

    /// <summary>True when the query contains a Distinct call.</summary>
    public bool IsDistinct { get; private set; }

    /// <summary>The ORDER BY clause including any OFFSET/FETCH paging.</summary>
    public string OrderBy => BuildOrderByStatement().Join(" ");

    /// <summary>A snapshot of the parameters collected so far.</summary>
    public SqlParameter[] Parameters => _parameters.ToArray();

    /// <summary>The SELECT clause (including DISTINCT / TOP when applicable).</summary>
    public string Select => BuildSelectStatement().Join(" ");

    /// <summary>The number of rows to skip, when Skip was used.</summary>
    public int? Skip => _skip;

    /// <summary>The table name resolved for the entity type of the query root.</summary>
    public string TableName { get; private set; }

    /// <summary>The number of rows to take, when Take was used.</summary>
    public int? Take => _take;

    /// <summary>The SET clause of an UPDATE statement.</summary>
    public string Update => "SET " + _update.Join(", ");

    /// <summary>The WHERE clause, or null when the query has no predicate.</summary>
    public string Where => _where.Count == 0 ? null : "WHERE " + _where.Join(" ");

    /// <summary>Converts to the full debug script, identical to <see cref="ToString"/>.</summary>
    public static implicit operator string(SimpleExpressionToSQL simpleExpression) => simpleExpression.ToString();

    /// <summary>
    /// Executes <see cref="CommandText"/> inside a transaction on the context's connection.
    /// </summary>
    /// <param name="isolationLevel">Isolation level of the transaction that wraps the command.</param>
    /// <returns>The number of affected rows reported by the provider.</returns>
    public int ExecuteNonQuery(IsolationLevel isolationLevel = IsolationLevel.RepeatableRead)
    {
        DbConnection connection = DbContext.Database.GetDbConnection();
        using (DbCommand command = connection.CreateCommand())
        {
            command.CommandText = CommandText;
            command.CommandType = CommandType.Text;
#if DEBUG
            Debug.WriteLine(ToString());
#endif
            bool closeWhenDone = false;
            try
            {
                command.Parameters.AddRange(Parameters);
                if (connection.State != ConnectionState.Open)
                {
                    connection.Open();
                    closeWhenDone = true;
                }

                // Disposing an uncommitted transaction rolls it back, so a failure in
                // ExecuteNonQuery leaves no partial changes behind.
                using (DbTransaction transaction = connection.BeginTransaction(isolationLevel))
                {
                    command.Transaction = transaction;
                    int result = command.ExecuteNonQuery();
                    transaction.Commit();
                    return result;
                }
            }
            finally
            {
                // A SqlParameter may belong to only one collection at a time; detach them so
                // this instance can be executed again.
                command.Parameters.Clear();

                // Only close what this method opened; a connection that was already open
                // belongs to the caller.
                if (closeWhenDone)
                {
                    connection.Close();
                }
            }
        }
    }

    /// <summary>
    /// Asynchronously executes <see cref="CommandText"/> inside a transaction on the context's connection.
    /// </summary>
    /// <param name="isolationLevel">Isolation level of the transaction that wraps the command.</param>
    /// <returns>The number of affected rows reported by the provider.</returns>
    public async Task<int> ExecuteNonQueryAsync(IsolationLevel isolationLevel = IsolationLevel.RepeatableRead)
    {
        DbConnection connection = DbContext.Database.GetDbConnection();
        using (DbCommand command = connection.CreateCommand())
        {
            command.CommandText = CommandText;
            command.CommandType = CommandType.Text;
#if DEBUG
            Debug.WriteLine(ToString());
#endif
            bool closeWhenDone = false;
            try
            {
                command.Parameters.AddRange(Parameters);
                if (connection.State != ConnectionState.Open)
                {
                    await connection.OpenAsync().ConfigureAwait(false);
                    closeWhenDone = true;
                }

                using (DbTransaction transaction = connection.BeginTransaction(isolationLevel))
                {
                    command.Transaction = transaction;
                    int result = await command.ExecuteNonQueryAsync().ConfigureAwait(false);
                    transaction.Commit();
                    return result;
                }
            }
            finally
            {
                // See ExecuteNonQuery: detach parameters and close only what was opened here.
                command.Parameters.Clear();
                if (closeWhenDone)
                {
                    connection.Close();
                }
            }
        }
    }

    /// <summary>
    /// Returns a debug script: DECLARE/SET lines for every parameter followed by the statement.
    /// </summary>
    public override string ToString() =>
        BuildDeclaration()
            // Concat, not Union: Union is a set operation and would silently drop repeated lines.
            .Concat(BuildSqlStatement())
            .Join(Environment.NewLine);

    /// <summary>
    /// Emits "( left operator right )" into the WHERE fragments. Equality against a null
    /// constant on the right-hand side is written as IS / IS NOT.
    /// </summary>
    /// <exception cref="NotSupportedException">The operator has no SQL mapping here.</exception>
    protected override Expression VisitBinary(BinaryExpression binaryExpression)
    {
        string sqlOperator;
        switch (binaryExpression.NodeType)
        {
            case ExpressionType.And:
            case ExpressionType.AndAlso:
                sqlOperator = "AND";
                break;
            case ExpressionType.Or:
            case ExpressionType.OrElse:
                sqlOperator = "OR";
                break;
            case ExpressionType.Equal:
                sqlOperator = IsNullConstant(binaryExpression.Right) ? "IS" : "=";
                break;
            case ExpressionType.NotEqual:
                sqlOperator = IsNullConstant(binaryExpression.Right) ? "IS NOT" : "<>";
                break;
            case ExpressionType.LessThan:
                sqlOperator = "<";
                break;
            case ExpressionType.LessThanOrEqual:
                sqlOperator = "<=";
                break;
            case ExpressionType.GreaterThan:
                sqlOperator = ">";
                break;
            case ExpressionType.GreaterThanOrEqual:
                sqlOperator = ">=";
                break;
            default:
                throw new NotSupportedException($"The binary operator '{binaryExpression.NodeType}' is not supported");
        }

        _where.Add("(");
        Visit(binaryExpression.Left);
        _where.Add(sqlOperator);
        Visit(binaryExpression.Right);
        _where.Add(")");
        return binaryExpression;
    }

    /// <summary>
    /// Emits NULL for a null constant, or a parameter for a constant whose type can be mapped
    /// to a SQL type. Other constants (such as the query root) are ignored.
    /// </summary>
    protected override Expression VisitConstant(ConstantExpression constantExpression)
    {
        if (constantExpression.Value == null)
        {
            _where.Add("NULL");
        }
        else if (constantExpression.Type.CanConvertToSqlDbType())
        {
            _where.Add(CreateParameter(constantExpression.Value).ParameterName);
        }

        return constantExpression;
    }

    /// <summary>
    /// Emits a column reference for a member of the lambda parameter, or a parameter holding
    /// the evaluated value for a member of a constant, a static member, or a member chain.
    /// </summary>
    /// <exception cref="NotSupportedException">The member's target is of an unsupported kind.</exception>
    protected override Expression VisitMember(MemberExpression memberExpression)
    {
        // Static members have no target expression; classify those by the access node itself.
        Expression target = memberExpression.Expression ?? memberExpression;
        switch (target.NodeType)
        {
            case ExpressionType.Parameter:
                _where.Add($"[{memberExpression.Member.Name}]");
                return memberExpression;
            case ExpressionType.Constant:
            case ExpressionType.MemberAccess:
                _where.Add(CreateParameter(GetValue(memberExpression)).ParameterName);
                return memberExpression;
            default:
                throw new NotSupportedException($"The member '{memberExpression.Member.Name}' is not supported");
        }
    }

    /// <summary>
    /// Dispatches on the method name: query operators fill the corresponding fragment list,
    /// string StartsWith/EndsWith/Contains become LIKE predicates, and any other instance
    /// method call is evaluated locally and passed as a parameter.
    /// </summary>
    /// <exception cref="NotSupportedException">The method cannot be translated.</exception>
    protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
    {
        switch (methodCallExpression.Method.Name)
        {
            case nameof(Queryable.Where) when methodCallExpression.Method.DeclaringType == typeof(Queryable):
            {
                Visit(methodCallExpression.Arguments[0]);

                // Chained Where calls must be combined with AND; without the joiner the
                // predicates were simply concatenated, which is not valid SQL.
                int joinerIndex = -1;
                if (_where.Count > 0)
                {
                    joinerIndex = _where.Count;
                    _where.Add("AND");
                }

                var predicate = (LambdaExpression)StripQuotes(methodCallExpression.Arguments[1]);
                Visit(predicate.Body);

                // Drop the joiner again if the predicate produced no SQL at all.
                if (joinerIndex >= 0 && _where.Count == joinerIndex + 1)
                {
                    _where.RemoveAt(joinerIndex);
                }

                return methodCallExpression;
            }
            case nameof(Queryable.Select):
                return ParseExpression(methodCallExpression, _select);
            case nameof(Queryable.GroupBy):
                return ParseExpression(methodCallExpression, _groupBy);
            case nameof(Queryable.Take):
                return ParseExpression(methodCallExpression, ref _take);
            case nameof(Queryable.Skip):
                return ParseExpression(methodCallExpression, ref _skip);
            case nameof(Queryable.OrderBy):
            case nameof(Queryable.ThenBy):
                return ParseExpression(methodCallExpression, _orderBy, "ASC", prepend: true);
            case nameof(Queryable.OrderByDescending):
            case nameof(Queryable.ThenByDescending):
                return ParseExpression(methodCallExpression, _orderBy, "DESC", prepend: true);
            case nameof(Queryable.Distinct):
                IsDistinct = true;
                return Visit(methodCallExpression.Arguments[0]);
            case nameof(string.StartsWith):
                AppendLike(methodCallExpression, string.Empty, "%");
                return methodCallExpression.Arguments[0];
            case nameof(string.EndsWith):
                AppendLike(methodCallExpression, "%", string.Empty);
                return methodCallExpression.Arguments[0];
            case nameof(string.Contains):
                AppendLike(methodCallExpression, "%", "%");
                return methodCallExpression.Arguments[0];
            case nameof(Extensions.ToSqlString):
                return Visit(methodCallExpression.Arguments[0]);
            case nameof(Extensions.Delete):
            case nameof(Extensions.DeleteAsync):
                IsDelete = true;
                return Visit(methodCallExpression.Arguments[0]);
            case nameof(Extensions.Update):
                return ParseExpression(methodCallExpression, _update);
            default:
                if (methodCallExpression.Object != null)
                {
                    _where.Add(CreateParameter(GetValue(methodCallExpression)).ParameterName);
                    return methodCallExpression;
                }
                break;
        }

        throw new NotSupportedException($"The method '{methodCallExpression.Method.Name}' is not supported");
    }

    /// <summary>
    /// Emits NOT for logical negation and looks through Convert nodes.
    /// </summary>
    /// <exception cref="NotSupportedException">Any other unary operator.</exception>
    protected override Expression VisitUnary(UnaryExpression unaryExpression)
    {
        switch (unaryExpression.NodeType)
        {
            case ExpressionType.Not:
                _where.Add("NOT");
                Visit(unaryExpression.Operand);
                break;
            case ExpressionType.Convert:
                Visit(unaryExpression.Operand);
                break;
            default:
                throw new NotSupportedException($"The unary operator '{unaryExpression.NodeType}' is not supported");
        }

        return unaryExpression;
    }

    /// <summary>Unwraps any number of Quote nodes around <paramref name="expression"/>.</summary>
    private static Expression StripQuotes(Expression expression)
    {
        while (expression.NodeType == ExpressionType.Quote)
        {
            expression = ((UnaryExpression)expression).Operand;
        }

        return expression;
    }

    /// <summary>
    /// Escapes the T-SQL LIKE wildcards (<c>[</c>, <c>%</c>, <c>_</c>) with bracket notation so
    /// that the value is matched literally.
    /// </summary>
    private static string EscapeLikePattern(string value) =>
        // "[" must be replaced first, otherwise the brackets added for % and _ would be escaped again.
        value.Replace("[", "[[]").Replace("%", "[%]").Replace("_", "[_]");

    /// <summary>
    /// Appends "[column] LIKE @pN" for a string StartsWith/EndsWith/Contains call, wrapping the
    /// escaped search value in the given wildcard prefix and suffix.
    /// </summary>
    /// <exception cref="NotSupportedException">The search value evaluates to null.</exception>
    private void AppendLike(MethodCallExpression methodCallExpression, string prefix, string suffix)
    {
        _where.AddRange(ParseExpression(methodCallExpression, methodCallExpression.Object));
        _where.Add("LIKE");

        object searchValue = GetValue(methodCallExpression.Arguments[0]);
        if (searchValue == null)
        {
            throw new NotSupportedException($"The method '{methodCallExpression.Method.Name}' does not support a null argument");
        }

        string pattern = prefix + EscapeLikePattern(searchValue.ToString()) + suffix;
        _where.Add(CreateParameter(pattern).ParameterName);
    }

    /// <summary>
    /// Produces the DECLARE and SET lines used by <see cref="ToString"/>. Null values are
    /// rendered as an unquoted NULL literal.
    /// </summary>
    [SuppressMessage("Style", "IDE0011:Add braces", Justification = "Easier to read")]
    private IEnumerable<string> BuildDeclaration()
    {
        SqlParameter[] parameters = Parameters;
        if (parameters.Length == 0) yield break;

        foreach (SqlParameter parameter in parameters) yield return $"DECLARE {parameter.ParameterName} {parameter.SqlDbType}";

        foreach (SqlParameter parameter in parameters)
        {
            object sqlValue = parameter.SqlValue;
            bool isNull =
                sqlValue == null ||
                sqlValue is DBNull ||
                (sqlValue is System.Data.SqlTypes.INullable nullable && nullable.IsNull);

            string literal;
            if (isNull) literal = "NULL";
            else if (parameter.SqlDbType.RequiresQuotes()) literal = $"'{sqlValue.ToString().Replace("'", "''")}'";
            else literal = sqlValue.ToString();

            yield return $"SET {parameter.ParameterName} = {literal}";
        }
    }

    /// <summary>
    /// Produces the ORDER BY clause and, when Skip is set, the OFFSET/FETCH paging that follows it.
    /// </summary>
    [SuppressMessage("Style", "IDE0011:Add braces", Justification = "Easier to read")]
    private IEnumerable<string> BuildOrderByStatement()
    {
        // OFFSET requires an ORDER BY, so a neutral one is emitted when only Skip is present.
        if (Skip.HasValue && _orderBy.Count == 0) yield return "ORDER BY (SELECT NULL)";
        else if (_orderBy.Count == 0) yield break;
        else if (_groupBy.Count > 0 && _orderBy[0].StartsWith("[Key]")) yield return "ORDER BY " + _groupBy.Join(", ");
        else yield return "ORDER BY " + _orderBy.Join(", ");

        if (Skip.HasValue && Take.HasValue) yield return $"OFFSET {Skip} ROWS FETCH NEXT {Take} ROWS ONLY";
        else if (Skip.HasValue && !Take.HasValue) yield return $"OFFSET {Skip} ROWS";
    }

    /// <summary>
    /// Produces the SELECT keyword, optional DISTINCT / TOP, and the column list.
    /// </summary>
    [SuppressMessage("Style", "IDE0011:Add braces", Justification = "Easier to read")]
    private IEnumerable<string> BuildSelectStatement()
    {
        yield return "SELECT";
        if (IsDistinct) yield return "DISTINCT";

        // With Skip present, Take is expressed as FETCH NEXT in the ORDER BY clause instead.
        if (Take.HasValue && !Skip.HasValue) yield return $"TOP ({Take.Value})";

        if (_select.Count == 0 && _groupBy.Count > 0) yield return _groupBy.Select(x => $"MAX({x})").Join(", ");
        else if (_select.Count == 0) yield return "*";
        else yield return _select.Join(", ");
    }

    /// <summary>
    /// Produces the statement clause by clause; clauses that are absent are skipped rather
    /// than emitted as empty lines.
    /// </summary>
    [SuppressMessage("Style", "IDE0011:Add braces", Justification = "Easier to read")]
    private IEnumerable<string> BuildSqlStatement()
    {
        if (IsDelete) yield return "DELETE";
        else if (_update.Count > 0) yield return $"UPDATE [{TableName}]";
        else yield return Select;

        if (_update.Count == 0) yield return From;
        else yield return Update;

        string where = Where;
        if (!string.IsNullOrEmpty(where)) yield return where;

        string groupBy = GroupBy;
        if (!string.IsNullOrEmpty(groupBy)) yield return groupBy;

        // OrderBy is built by joining fragments and may be empty rather than null.
        string orderBy = OrderBy;
        if (!string.IsNullOrEmpty(orderBy)) yield return orderBy;
    }

    /// <summary>
    /// Registers a new parameter named @p0, @p1, ... holding <paramref name="value"/>.
    /// </summary>
    private SqlParameter CreateParameter(object value)
    {
        var parameter = new SqlParameter()
        {
            ParameterName = $"@p{_parameters.Count}",
            // ADO.NET treats a CLR null value as "parameter not supplied"; DBNull sends SQL NULL.
            Value = value ?? DBNull.Value
        };
        _parameters.Add(parameter);
        return parameter;
    }

    /// <summary>
    /// Follows the first argument of nested method calls down to the innermost constant and
    /// returns its value; returns null when the chain ends in anything else.
    /// </summary>
    private object GetEntityType(Expression expression)
    {
        while (true)
        {
            switch (expression)
            {
                case ConstantExpression constantExpression:
                    return constantExpression.Value;
                case MethodCallExpression methodCallExpression:
                    expression = methodCallExpression.Arguments[0];
                    continue;
                default:
                    return null;
            }
        }
    }

    /// <summary>
    /// Translates the members of an anonymous-type constructor call. A member initialised from
    /// a member access yields "[Name]" (plus <paramref name="appendString"/>); any other
    /// initialiser is evaluated and yields "[Name] = @pN".
    /// </summary>
    /// <exception cref="NotSupportedException">The constructor call carries no member information.</exception>
    private IEnumerable<string> GetNewExpressionString(NewExpression newExpression, string appendString = null)
    {
        // Members is null for ordinary (non-anonymous) constructor calls.
        if (newExpression.Members == null)
        {
            throw new NotSupportedException("Only anonymous type projections are supported");
        }

        for (int i = 0; i < newExpression.Members.Count; i++)
        {
            string column = $"[{newExpression.Members[i].Name}]";
            if (newExpression.Arguments[i].NodeType == ExpressionType.MemberAccess)
            {
                yield return appendString == null ? column : $"{column} {appendString}";
            }
            else
            {
                yield return $"{column} = {CreateParameter(GetValue(newExpression.Arguments[i])).ParameterName}";
            }
        }
    }

    /// <summary>
    /// Evaluates an expression that does not depend on the lambda parameter: constants, field
    /// and property chains rooted in a constant or a static member, and method calls.
    /// </summary>
    /// <exception cref="NotSupportedException">The expression kind cannot be evaluated here.</exception>
    private object GetValue(Expression expression)
    {
        object ReadMember(MemberInfo memberInfo, object container)
        {
            switch (memberInfo)
            {
                case FieldInfo fieldInfo:
                    return fieldInfo.GetValue(container);
                case PropertyInfo propertyInfo:
                    return propertyInfo.GetValue(container);
                default:
                    return null;
            }
        }

        switch (expression)
        {
            case null:
                // Target of a static member access.
                return null;
            case ConstantExpression constantExpression:
                return constantExpression.Value;
            case MemberExpression memberExpression:
                // Evaluate the target first so chains such as captured.Child.Value resolve.
                return ReadMember(memberExpression.Member, GetValue(memberExpression.Expression));
            case MethodCallExpression methodCallExpression:
                return Expression.Lambda(methodCallExpression).Compile().DynamicInvoke();
            default:
                throw new NotSupportedException($"The expression '{expression.NodeType}' cannot be evaluated");
        }
    }

    /// <summary>True when <paramref name="expression"/> is a constant whose value is null.</summary>
    private bool IsNullConstant(Expression expression) =>
        expression is ConstantExpression constantExpression && constantExpression.Value == null;

    /// <summary>
    /// Translates a lambda body (or a method call target) into column fragments.
    /// </summary>
    /// <param name="parent">The expression that owns <paramref name="body"/>.</param>
    /// <param name="body">The expression to translate.</param>
    /// <param name="appendString">Optional suffix such as ASC or DESC.</param>
    /// <exception cref="NotSupportedException">The body has an unsupported shape.</exception>
    private IEnumerable<string> ParseExpression(Expression parent, Expression body, string appendString = null)
    {
        switch (body)
        {
            case MemberExpression memberExpression:
                return appendString == null ?
                    new string[] { $"[{memberExpression.Member.Name}]" } :
                    new string[] { $"[{memberExpression.Member.Name}] {appendString}" };
            case NewExpression newExpression:
                return GetNewExpressionString(newExpression, appendString);
            case ParameterExpression parameterExpression when parent is LambdaExpression lambdaExpression && lambdaExpression.ReturnType == parameterExpression.Type:
                // Identity projection (x => x): no explicit columns.
                return new string[0];
            case ConstantExpression constantExpression:
                // A fully evaluated object: every public instance property becomes an assignment.
                return constantExpression
                    .Type
                    .GetProperties(BindingFlags.Public | BindingFlags.Instance)
                    .Select(x => $"[{x.Name}] = {CreateParameter(x.GetValue(constantExpression.Value)).ParameterName}");
        }

        throw new NotSupportedException($"The expression '{body?.NodeType.ToString() ?? "null"}' is not supported");
    }

    /// <summary>
    /// Translates the lambda argument of a query operator into <paramref name="commandList"/>
    /// and then continues with the operator's source.
    /// </summary>
    /// <param name="expression">The query operator call.</param>
    /// <param name="commandList">The fragment list to fill.</param>
    /// <param name="appendString">Optional suffix such as ASC or DESC.</param>
    /// <param name="prepend">
    /// Insert at the front instead of appending. The tree is visited outermost call first, so
    /// ordering operators must be prepended for OrderBy(a).ThenBy(b) to yield "a, b".
    /// </param>
    private Expression ParseExpression(MethodCallExpression expression, List<string> commandList, string appendString = null, bool prepend = false)
    {
        var lambdaExpression = (LambdaExpression)StripQuotes(expression.Arguments[1]);
        lambdaExpression = (LambdaExpression)Evaluator.PartialEval(lambdaExpression);

        // Materialise once: the fragment sequence is lazy and creates parameters as it runs.
        List<string> fragments = ParseExpression(lambdaExpression, lambdaExpression.Body, appendString).ToList();
        if (prepend)
        {
            commandList.InsertRange(0, fragments);
        }
        else
        {
            commandList.AddRange(fragments);
        }

        return Visit(expression.Arguments[0]);
    }

    /// <summary>
    /// Reads the integer argument of Take/Skip into <paramref name="size"/> and continues with
    /// the operator's source.
    /// </summary>
    /// <exception cref="NotSupportedException">The argument is not an integer value.</exception>
    private Expression ParseExpression(MethodCallExpression expression, ref int? size)
    {
        object rawSize = GetValue(expression.Arguments[1]);
        if (rawSize != null && int.TryParse(rawSize.ToString(), out int value))
        {
            size = value;
            return Visit(expression.Arguments[0]);
        }

        throw new NotSupportedException($"The argument of '{expression.Method.Name}' must be an integer");
    }
}
我会在评论区贴出扩展方法（Extensions 类）。
请在生产环境中小心使用
欢迎将其制作成 Nuget 包 :)