如何封装Entity Framework以在执行前拦截LINQ表达式?

26

我想在执行之前重写LINQ表达式的某些部分。但是我在将我的重写器注入到正确位置时遇到了问题(实际上是完全无法注入)。

查看Entity Framework源代码(使用reflector),最终涉及到IQueryProvider.Execute,在EF中由ObjectContext通过提供internal IQueryProvider Provider { get; }属性与表达式耦合。

因此,我创建了一个包装类(实现IQueryProvider),在调用Execute时对表达式进行重写,然后将其传递给原始提供程序。

问题是,Provider背后的字段为private ObjectQueryProvider _queryProvider;。 这个ObjectQueryProvider是一个internal sealed class,这意味着不可能创建一个子类以提供额外的重写功能。

因此,由于非常紧密地耦合了ObjectContext,这种方法让我陷入了死胡同。

如何解决这个问题? 我是在错误的方向上寻找吗? 也许有一种方法可以在ObjectQueryProvider周围注入自己?

更新:虽然提供的解决方案都适用于使用存储库模式“包装”ObjectContext的情况,但最好的解决方案是允许直接使用从ObjectContext生成的子类。 从而保持与Dynamic Data脚手架的兼容性。

3个回答

15

根据Arthur的回答,我创建了一个可用的包装器。

提供的代码片段提供了一种使用自己的QueryProvider和IQueryable根节点来包装每个LINQ查询的方法。这意味着您必须控制初始查询的起点(因为在使用任何模式时,大多数时间都是如此)。

这种方法的问题在于它不是透明的,更理想的情况是在构造函数级别向实体容器注入某些内容。

我创建了一个可编译的实现,让它能够与Entity Framework一起工作,并添加了对ObjectQuery.Include方法的支持。表达式访问者类可以从MSDN复制。

public class QueryTranslator<T> : IOrderedQueryable<T>
{
    private Expression expression = null;
    private QueryTranslatorProvider<T> provider = null;

    public QueryTranslator(IQueryable source)
    {
        expression = Expression.Constant(this);
        provider = new QueryTranslatorProvider<T>(source);
    }

    public QueryTranslator(IQueryable source, Expression e)
    {
        if (e == null) throw new ArgumentNullException("e");
        expression = e;
        provider = new QueryTranslatorProvider<T>(source);
    }

    public IEnumerator<T> GetEnumerator()
    {
        return ((IEnumerable<T>)provider.ExecuteEnumerable(this.expression)).GetEnumerator();
    }

    System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return provider.ExecuteEnumerable(this.expression).GetEnumerator();
    }

    public QueryTranslator<T> Include(String path)
    {
        ObjectQuery<T> possibleObjectQuery = provider.source as ObjectQuery<T>;
        if (possibleObjectQuery != null)
        {
            return new QueryTranslator<T>(possibleObjectQuery.Include(path));
        }
        else
        {
            throw new InvalidOperationException("The Include should only happen at the beginning of a LINQ expression");
        }
    }

    public Type ElementType
    {
        get { return typeof(T); }
    }

    public Expression Expression
    {
        get { return expression; }
    }

    public IQueryProvider Provider
    {
        get { return provider; }
    }
}

public class QueryTranslatorProvider<T> : ExpressionVisitor, IQueryProvider
{
    internal IQueryable source;

    public QueryTranslatorProvider(IQueryable source)
    {
        if (source == null) throw new ArgumentNullException("source");
        this.source = source;
    }

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        return new QueryTranslator<TElement>(source, expression) as IQueryable<TElement>;
    }

    public IQueryable CreateQuery(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");
        Type elementType = expression.Type.GetGenericArguments().First();
        IQueryable result = (IQueryable)Activator.CreateInstance(typeof(QueryTranslator<>).MakeGenericType(elementType),
            new object[] { source, expression });
        return result;
    }

    public TResult Execute<TResult>(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");
        object result = (this as IQueryProvider).Execute(expression);
        return (TResult)result;
    }

    public object Execute(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        Expression translated = this.Visit(expression);
        return source.Provider.Execute(translated);
    }

    internal IEnumerable ExecuteEnumerable(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        Expression translated = this.Visit(expression);
        return source.Provider.CreateQuery(translated);
    }

    #region Visitors
    protected override Expression VisitConstant(ConstantExpression c)
    {
        // fix up the Expression tree to work with EF again
        if (c.Type == typeof(QueryTranslator<T>))
        {
            return source.Expression;
        }
        else
        {
            return base.VisitConstant(c);
        }
    }
    #endregion
}

您的存储库中的示例用法:

示例用法在您的存储库中:

public IQueryable<User> List()
{
    return new QueryTranslator<User>(entities.Users).Include("Department");
}

你现在需要的东西都有了吗?我应该再提供一些辅助方法还是查找我的代码中的某些内容? - Arthur
我搞定了,但奇怪的是你忘记了把它改回 EF 查询的部分。 - Davy Landman

10

我有你需要的源代码,但不知道如何附加文件。

以下是一些片段(片段!我已经改编了这段代码,所以可能无法编译):

IQueryable:

public class QueryTranslator<T> : IOrderedQueryable<T>
{
    private Expression _expression = null;
    private QueryTranslatorProvider<T> _provider = null;

    public QueryTranslator(IQueryable source)
    {
        _expression = Expression.Constant(this);
        _provider = new QueryTranslatorProvider<T>(source);
    }

    public QueryTranslator(IQueryable source, Expression e)
    {
        if (e == null) throw new ArgumentNullException("e");
        _expression = e;
        _provider = new QueryTranslatorProvider<T>(source);
    }

    public IEnumerator<T> GetEnumerator()
    {
        return ((IEnumerable<T>)_provider.ExecuteEnumerable(this._expression)).GetEnumerator();
    }

    IEnumerator System.Collections.IEnumerable.GetEnumerator()
    {
        return _provider.ExecuteEnumerable(this._expression).GetEnumerator();
    }

    public Type ElementType
    {
        get { return typeof(T); }
    }

    public Expression Expression
    {
        get { return _expression; }
    }

    public IQueryProvider Provider
    {
        get { return _provider; }
    }
}

IQueryProvider:

public class QueryTranslatorProvider<T> : ExpressionTreeTranslator, IQueryProvider
{
    IQueryable _source;

    public QueryTranslatorProvider(IQueryable source)
    {
        if (source == null) throw new ArgumentNullException("source");
        _source = source;
    }

    #region IQueryProvider Members

    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");
        return new QueryTranslator<TElement>(_source, expression) as IQueryable<TElement>;
    }

    public IQueryable CreateQuery(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        Type elementType = expression.Type.FindElementTypes().First();
        IQueryable result = (IQueryable)Activator.CreateInstance(typeof(QueryTranslator<>).MakeGenericType(elementType),
            new object[] { _source, expression });
        return result;
    }

    public TResult Execute<TResult>(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");
        object result = (this as IQueryProvider).Execute(expression);
        return (TResult)result;
    }

    public object Execute(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        Expression translated = this.Visit(expression);

        return _source.Provider.Execute(translated);            
    }

    internal IEnumerable ExecuteEnumerable(Expression expression)
    {
        if (expression == null) throw new ArgumentNullException("expression");

        Expression translated = this.Visit(expression);

        return _source.Provider.CreateQuery(translated);
    }

    #endregion        

    #region Visits
    protected override MethodCallExpression VisitMethodCall(MethodCallExpression m)
    {
        return m;
    }

    protected override Expression VisitUnary(UnaryExpression u)
    {
         return Expression.MakeUnary(u.NodeType, base.Visit(u.Operand), u.Type.ToImplementationType(), u.Method);
    }
    #endregion
}

使用方法(警告:代码已调整!可能无法编译):

private Dictionary<Type, object> _table = new Dictionary<Type, object>();
public override IQueryable<T> GetObjectQuery<T>()
{
    if (!_table.ContainsKey(type))
    {
        _table[type] = new QueryTranslator<T>(
            _ctx.CreateQuery<T>("[" + typeof(T).Name + "]"));
    }

    return (IQueryable<T>)_table[type];
}

表达式访问器/翻译器:

http://blogs.msdn.com/mattwar/archive/2007/07/31/linq-building-an-iqueryable-provider-part-ii.aspx

http://msdn.microsoft.com/en-us/library/bb882521.aspx

编辑:添加了FindElementTypes()。希望现在所有方法都存在。

    /// <summary>
    /// Finds all implemented IEnumerables of the given Type
    /// </summary>
    public static IQueryable<Type> FindIEnumerables(this Type seqType)
    {
        if (seqType == null || seqType == typeof(object) || seqType == typeof(string))
            return new Type[] { }.AsQueryable();

        if (seqType.IsArray || seqType == typeof(IEnumerable))
            return new Type[] { typeof(IEnumerable) }.AsQueryable();

        if (seqType.IsGenericType && seqType.GetGenericArguments().Length == 1 && seqType.GetGenericTypeDefinition() == typeof(IEnumerable<>))
        {
            return new Type[] { seqType, typeof(IEnumerable) }.AsQueryable();
        }

        var result = new List<Type>();

        foreach (var iface in (seqType.GetInterfaces() ?? new Type[] { }))
        {
            result.AddRange(FindIEnumerables(iface));
        }

        return FindIEnumerables(seqType.BaseType).Union(result);
    }

    /// <summary>
    /// Finds all element types provided by a specified sequence type.
    /// "Element types" are T for IEnumerable&lt;T&gt; and object for IEnumerable.
    /// </summary>
    public static IQueryable<Type> FindElementTypes(this Type seqType)
    {
        return seqType.FindIEnumerables().Select(t => t.IsGenericType ? t.GetGenericArguments().Single() : typeof(object));
    }

如果我没记错的话,我需要将ObjectContext生成的子类中的Sets从base.CreateQuery调用更改为使用这个包装器?这并不是一个好的解决方案,因为重新生成会破坏我的更改,难道我对你的用法示例有误解吗? - Davy Landman
你好,能提供一下“ExpressionTreeTranslator”吗?我猜它是Expression Tree访问者模式的一个实现? - Davy Landman
@第一条评论:没错,你可以包装CreateQuery调用。我有自己的生成器,所以没有问题。我还有一个通用的GetQuery方法,它创建正确的EF查询并进行包装。我会发布那个方法。 @第二条:您可以在此处找到QueryTranslator:http://msdn.microsoft.com/en-us/library/bb882521.aspx或者在此处找到http://blogs.msdn.com/mattwar/archive/2007/07/31/linq-building-an-iqueryable-provider-part-ii.aspx - Arthur
好的,所以那些类只是改名了,我就是这么想的。但是你能提供一下叫做FindElementTypes()的扩展方法吗?我在谷歌上也找不到。 - Davy Landman
很抱歉,但我无法让它真正工作,EF提供程序不喜欢使用QueryTranslater包装的查询。--System.NotSupportedException:无法创建类型为“QueryTranslator`1”的常量值。在此上下文中,仅支持基元类型(例如Int32、String和Guid)。 - Davy Landman
这行代码public override IQueryable<T> GetObjectQuery<T>()在哪里?你提供的代码好像每个人都知道它在哪里一样。令人惊讶的是,它竟然能获得11票(不管我的-1)。 - Hopeless

4

我想补充一下Arthur的例子。

正如Arthur所警告的那样,他的GetObjectQuery()方法中确实存在一个bug。

它使用typeof(T).Name作为EntitySet名称来创建基本查询。

EntitySet名称与类型名称非常不同。

如果您正在使用EF 4,则应执行以下操作:

public override IQueryable<T> GetObjectQuery<T>()
{
    if (!_table.ContainsKey(type))
    {
        _table[type] = new QueryTranslator<T>(
            _ctx.CreateObjectSet<T>();
    }

    return (IQueryable<T>)_table[type];
}

只要您没有多个实体集合类型(MEST),这种方法就有效,但这种情况非常罕见。
如果您使用的是3.5版本,则可以使用我在Tip 13中提供的代码来获取EntitySet名称,并像这样进行馈送:
public override IQueryable<T> GetObjectQuery<T>()
{
    if (!_table.ContainsKey(type))
    {
        _table[type] = new QueryTranslator<T>(
            _ctx.CreateQuery<T>("[" + GetEntitySetName<T>() + "]"));

    } 
    return (IQueryable<T>)_table[type];
}

希望这有所帮助

Alex

Entity Framework 提示


谢谢修复错误 ;) 但是也许更改链接为非 URL 缩短服务,以保持 StackOverflow 数据库的清洁? - Davy Landman
当然,如果我的博客具备良好的引荐统计功能就好了。 - Alex James
哦,我能理解那个。我不知道在blogs.msdn.com上是否允许使用Google Analytics,但在我看来,它们可以很好地概述您的访问者。 - Davy Landman

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接