Entity Framework - 仓储检查

3
我正在使用C#/ASP.NET和Entity Framework制作一个完整的代码仓库,但目前我担心我忘记了像处理ObjectContex这样的事情。在下面的代码中,您将看到我的完整代码仓库(至少是为了让大家理解我的问题所需的内容),希望有人能够耐心查看并告诉我是否犯了一些错误。
对于我来说,这个项目非常重要,但我对仓库/EF模型还很陌生。 Global.asax
public class Global : System.Web.HttpApplication
{
    private WebObjectContextStorage _storage;

    public override void Init()
    {
        base.Init();
        _storage = new WebObjectContextStorage(this);
    }

    protected void Application_Start(object sender, EventArgs e)
    {

    }

    protected void Session_Start(object sender, EventArgs e)
    {

    }

    protected void Application_BeginRequest(object sender, EventArgs e)
    {
        ObjectContextInitializer.Instance().InitializeObjectContextOnce(() =>
        {
            ObjectContextManager.InitStorage(_storage);
        });
    }

    protected void Application_EndRequest(object sender, EventArgs e)
    {

    }

    protected void Application_AuthenticateRequest(object sender, EventArgs e)
    {

    }

    protected void Application_Error(object sender, EventArgs e)
    {

    }

    protected void Session_End(object sender, EventArgs e)
    {

    }

    protected void Application_End(object sender, EventArgs e)
    {

    }
}

ObjectContextManager

public static class ObjectContextManager
{
    public static void InitStorage(IObjectContextStorage storage)
    {
        if (storage == null) 
        {
            throw new ArgumentNullException("storage");
        }
        if ((Storage != null) && (Storage != storage))
        {
            throw new ApplicationException("A storage mechanism has already been configured for this application");
        }            
        Storage = storage;
    }

    /// <summary>
    /// The default connection string name used if only one database is being communicated with.
    /// </summary>
    public static readonly string DefaultConnectionStringName = "TraceConnection";        

    /// <summary>
    /// Used to get the current object context session if you're communicating with a single database.
    /// When communicating with multiple databases, invoke <see cref="CurrentFor()" /> instead.
    /// </summary>
    public static ObjectContext Current
    {
        get
        {
            return CurrentFor(DefaultConnectionStringName);
        }
    }

    /// <summary>
    /// Used to get the current ObjectContext associated with a key; i.e., the key 
    /// associated with an object context for a specific database.
    /// 
    /// If you're only communicating with one database, you should call <see cref="Current" /> instead,
    /// although you're certainly welcome to call this if you have the key available.
    /// </summary>
    public static ObjectContext CurrentFor(string key)
    {
        if (string.IsNullOrEmpty(key))
        {
            throw new ArgumentNullException("key");
        }

        if (Storage == null)
        {
            throw new ApplicationException("An IObjectContextStorage has not been initialized");
        }

        ObjectContext context = null;
        lock (_syncLock)
        {
            context = Storage.GetObjectContextForKey(key);

            if (context == null)
            {
                context = ObjectContextFactory.GetTraceContext(key);
                Storage.SetObjectContextForKey(key, context);
            }
        }

        return context;
    }

    /// <summary>
    /// This method is used by application-specific object context storage implementations
    /// and unit tests. Its job is to walk thru existing cached object context(s) and Close() each one.
    /// </summary>
    public static void CloseAllObjectContexts()
    {
        foreach (ObjectContext ctx in Storage.GetAllObjectContexts())
        {
            if (ctx.Connection.State == System.Data.ConnectionState.Open)
                ctx.Connection.Close();
        }
    }      

    /// <summary>
    /// An application-specific implementation of IObjectContextStorage must be setup either thru
    /// <see cref="InitStorage" /> or one of the <see cref="Init" /> overloads. 
    /// </summary>
    private static IObjectContextStorage Storage { get; set; }

    private static object _syncLock = new object();
}

ObjectContextInitializer

public class ObjectContextInitializer
{
    private static readonly object syncLock = new object();
    private static ObjectContextInitializer instance;

    protected ObjectContextInitializer() { }

    private bool isInitialized = false;

    public static ObjectContextInitializer Instance()
    {
        if (instance == null)
        {
            lock (syncLock)
            {
                if (instance == null)
                {
                    instance = new ObjectContextInitializer();
                }
            }
        }

        return instance;
    }

    /// <summary>
    /// This is the method which should be given the call to intialize the ObjectContext; e.g.,
    /// ObjectContextInitializer.Instance().InitializeObjectContextOnce(() => InitializeObjectContext());
    /// where InitializeObjectContext() is a method which calls ObjectContextManager.Init()
    /// </summary>
    /// <param name="initMethod"></param>
    public void InitializeObjectContextOnce(Action initMethod)
    {
        lock (syncLock)
        {
            if (!isInitialized)
            {
                initMethod();
                isInitialized = true;
            }
        }
    }

}

ObjectContextFactory

public static class ObjectContextFactory
{
    /// <summary>
    /// Gets the TraceContext
    /// </summary>
    /// <param name="connectionString">Connection string to use for database queries</param>
    /// <returns>The TraceContext</returns>
    public static TraceContext GetTraceContext(string configName)
    {
        string connectionString = ConfigurationManager.ConnectionStrings[configName].ConnectionString;
        return new TraceContext(connectionString);
    }
}

WebObjectContextStorage(Web对象上下文存储)
public class WebObjectContextStorage : IObjectContextStorage
{   
    public WebObjectContextStorage(HttpApplication app)
    { 
        app.EndRequest += (sender, args) =>
                              {
                                  ObjectContextManager.CloseAllObjectContexts();
                                  HttpContext.Current.Items.Remove(HttpContextObjectContextStorageKey);
                              };
    }        

    public ObjectContext GetObjectContextForKey(string key)
    {
        ObjectContextStorage storage = GetObjectContextStorage();
        return storage.GetObjectContextForKey(key);
    }

    public void SetObjectContextForKey(string factoryKey, ObjectContext session)
    {
        ObjectContextStorage storage = GetObjectContextStorage();
        storage.SetObjectContextForKey(factoryKey, session);
    }

    public IEnumerable<ObjectContext> GetAllObjectContexts()
    {
        ObjectContextStorage storage = GetObjectContextStorage();
        return storage.GetAllObjectContexts();
    }

    private ObjectContextStorage GetObjectContextStorage()
    {
        HttpContext context = HttpContext.Current;
        ObjectContextStorage storage = context.Items[HttpContextObjectContextStorageKey] as ObjectContextStorage;
        if (storage == null)
        {
            storage = new ObjectContextStorage();
            context.Items[HttpContextObjectContextStorageKey] = storage;
        }
        return storage;
    }       

    private static readonly string HttpContextObjectContextStorageKey = "HttpContextObjectContextStorageKey";       
}

ObjectContextStorage

public class ObjectContextStorage : IObjectContextStorage
{
    private Dictionary<string, ObjectContext> storage = new Dictionary<string, ObjectContext>();

    /// <summary>
    /// Initializes a new instance of the <see cref="SimpleObjectContextStorage"/> class.
    /// </summary>
    public ObjectContextStorage() { }

    /// <summary>
    /// Returns the object context associated with the specified key or
    /// null if the specified key is not found.
    /// </summary>
    /// <param name="key">The key.</param>
    /// <returns></returns>
    public ObjectContext GetObjectContextForKey(string key)
    {
        ObjectContext context;
        if (!this.storage.TryGetValue(key, out context))
            return null;
        return context;
    }


    /// <summary>
    /// Stores the object context into a dictionary using the specified key.
    /// If an object context already exists by the specified key, 
    /// it gets overwritten by the new object context passed in.
    /// </summary>
    /// <param name="key">The key.</param>
    /// <param name="objectContext">The object context.</param>
    public void SetObjectContextForKey(string key, ObjectContext objectContext)
    {           
        this.storage.Add(key, objectContext);           
    }

    /// <summary>
    /// Returns all the values of the internal dictionary of object contexts.
    /// </summary>
    /// <returns></returns>
    public IEnumerable<ObjectContext> GetAllObjectContexts()
    {
        return this.storage.Values;
    }
}

通用仓储库

public class GenericRepository : IRepository
{
    private readonly string _connectionStringName;
    private ObjectContext _objectContext;
    private readonly PluralizationService _pluralizer = PluralizationService.CreateService(CultureInfo.GetCultureInfo("en"));
    private bool _usePlurazation;

    /// <summary>
    /// Initializes a new instance of the <see cref="GenericRepository&lt;TEntity&gt;"/> class.
    /// </summary>
    public GenericRepository()
        : this(string.Empty, false)
    {
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="GenericRepository&lt;TEntity&gt;"/> class.
    /// </summary>
    /// <param name="connectionStringName">Name of the connection string.</param>
    public GenericRepository(string connectionStringName, bool usePlurazation)
    {
        this._connectionStringName = connectionStringName;
        this._usePlurazation = usePlurazation;
    }

    /// <summary>
    /// Initializes a new instance of the <see cref="GenericRepository"/> class.
    /// </summary>
    /// <param name="objectContext">The object context.</param>
    public GenericRepository(ObjectContext objectContext, bool usePlurazation)
    {
        if (objectContext == null)
            throw new ArgumentNullException("objectContext");
        this._objectContext = objectContext;
        this._usePlurazation = usePlurazation;
    }

    public TEntity GetByKey<TEntity>(object keyValue) where TEntity : class
    {
        EntityKey key = GetEntityKey<TEntity>(keyValue);

        object originalItem;
        if (ObjectContext.TryGetObjectByKey(key, out originalItem))
        {
            return (TEntity)originalItem;
        }
        return default(TEntity);
    }

    public IQueryable<TEntity> GetQuery<TEntity>() where TEntity : class
    {
        var entityName = GetEntityName<TEntity>();
        return ObjectContext.CreateQuery<TEntity>(entityName).OfType<TEntity>();
    }

    public IQueryable<TEntity> GetQuery<TEntity>(Expression<Func<TEntity, bool>> predicate) where TEntity : class
    {
        return GetQuery<TEntity>().Where(predicate);
    }

    public IQueryable<TEntity> GetQuery<TEntity>(ISpecification<TEntity> specification) where TEntity : class
    {
        return specification.SatisfyingEntitiesFrom(GetQuery<TEntity>());
    }

    public IEnumerable<TEntity> Get<TEntity>(Expression<Func<TEntity, string>> orderBy, int pageIndex, int pageSize, SortOrder sortOrder = SortOrder.Ascending) where TEntity : class
    {
        if (sortOrder == SortOrder.Ascending)
        {
            return GetQuery<TEntity>().OrderBy(orderBy).Skip(pageIndex).Take(pageSize).AsEnumerable();
        }
        return GetQuery<TEntity>().OrderByDescending(orderBy).Skip(pageIndex).Take(pageSize).AsEnumerable();
    }

    public IEnumerable<TEntity> Get<TEntity>(Expression<Func<TEntity, bool>> predicate, Expression<Func<TEntity, string>> orderBy, int pageIndex, int pageSize, SortOrder sortOrder = SortOrder.Ascending) where TEntity : class
    {
        if (sortOrder == SortOrder.Ascending)
        {
            return GetQuery<TEntity>().Where(predicate).OrderBy(orderBy).Skip(pageIndex).Take(pageSize).AsEnumerable();
        }
        return GetQuery<TEntity>().Where(predicate).OrderByDescending(orderBy).Skip(pageIndex).Take(pageSize).AsEnumerable();
    }

    public IEnumerable<TEntity> Get<TEntity>(ISpecification<TEntity> specification, Expression<Func<TEntity, string>> orderBy, int pageIndex, int pageSize, SortOrder sortOrder = SortOrder.Ascending) where TEntity : class
    {
        if (sortOrder == SortOrder.Ascending)
        {
            return specification.SatisfyingEntitiesFrom(GetQuery<TEntity>()).OrderBy(orderBy).Skip(pageIndex).Take(pageSize).AsEnumerable();
        }
        return specification.SatisfyingEntitiesFrom(GetQuery<TEntity>()).OrderByDescending(orderBy).Skip(pageIndex).Take(pageSize).AsEnumerable();
    }

    public TEntity Single<TEntity>(Expression<Func<TEntity, bool>> criteria) where TEntity : class
    {
        return GetQuery<TEntity>().SingleOrDefault<TEntity>(criteria);
    }

    public TEntity Single<TEntity>(ISpecification<TEntity> criteria) where TEntity : class
    {
        return criteria.SatisfyingEntityFrom(GetQuery<TEntity>());
    }

    public TEntity First<TEntity>(Expression<Func<TEntity, bool>> predicate) where TEntity : class
    {
        return GetQuery<TEntity>().FirstOrDefault(predicate);
    }

    public TEntity First<TEntity>(ISpecification<TEntity> criteria) where TEntity : class
    {
        return criteria.SatisfyingEntitiesFrom(GetQuery<TEntity>()).FirstOrDefault();
    }

    public void Add<TEntity>(TEntity entity) where TEntity : class
    {
        if (entity == null)
        {
            throw new ArgumentNullException("entity");
        }
        ObjectContext.AddObject(GetEntityName<TEntity>(), entity);
    }

    public void Attach<TEntity>(TEntity entity) where TEntity : class
    {
        if (entity == null)
        {
            throw new ArgumentNullException("entity");
        }

        ObjectContext.AttachTo(GetEntityName<TEntity>(), entity);
    }

    public void Delete<TEntity>(TEntity entity) where TEntity : class
    {
        if (entity == null)
        {
            throw new ArgumentNullException("entity");
        }
        ObjectContext.DeleteObject(entity);
    }

    public void Delete<TEntity>(Expression<Func<TEntity, bool>> criteria) where TEntity : class
    {
        IEnumerable<TEntity> records = Find<TEntity>(criteria);

        foreach (TEntity record in records)
        {
            Delete<TEntity>(record);
        }
    }

    public void Delete<TEntity>(ISpecification<TEntity> criteria) where TEntity : class
    {
        IEnumerable<TEntity> records = Find<TEntity>(criteria);
        foreach (TEntity record in records)
        {
            Delete<TEntity>(record);
        }
    }

    public IEnumerable<TEntity> GetAll<TEntity>() where TEntity : class
    {
        return GetQuery<TEntity>().AsEnumerable();
    }

    public void Update<TEntity>(TEntity entity) where TEntity : class
    {
        var fqen = GetEntityName<TEntity>();

        object originalItem;
        EntityKey key = ObjectContext.CreateEntityKey(fqen, entity);
        if (ObjectContext.TryGetObjectByKey(key, out originalItem))
        {
            ObjectContext.ApplyCurrentValues(key.EntitySetName, entity);
        }
    }

    public IEnumerable<TEntity> Find<TEntity>(Expression<Func<TEntity, bool>> criteria) where TEntity : class
    {
        return GetQuery<TEntity>().Where(criteria);
    }

    public TEntity FindOne<TEntity>(Expression<Func<TEntity, bool>> criteria) where TEntity : class
    {
        return GetQuery<TEntity>().Where(criteria).FirstOrDefault();
    }

    public TEntity FindOne<TEntity>(ISpecification<TEntity> criteria) where TEntity : class
    {
        return criteria.SatisfyingEntityFrom(GetQuery<TEntity>());
    }

    public IEnumerable<TEntity> Find<TEntity>(ISpecification<TEntity> criteria) where TEntity : class
    {
        return criteria.SatisfyingEntitiesFrom(GetQuery<TEntity>());
    }

    public int Count<TEntity>() where TEntity : class
    {
        return GetQuery<TEntity>().Count();
    }

    public int Count<TEntity>(Expression<Func<TEntity, bool>> criteria) where TEntity : class
    {
        return GetQuery<TEntity>().Count(criteria);
    }

    public int Count<TEntity>(ISpecification<TEntity> criteria) where TEntity : class
    {
        return criteria.SatisfyingEntitiesFrom(GetQuery<TEntity>()).Count();
    }

    public IUnitOfWork UnitOfWork
    {
        get
        {
            if (unitOfWork == null)
            {
                unitOfWork = new UnitOfWork(this.ObjectContext);
            }
            return unitOfWork;
        }
    }

    private ObjectContext ObjectContext
    {
        get
        {
            if (this._objectContext == null)
            {
                if (string.IsNullOrEmpty(this._connectionStringName))
                {
                    this._objectContext = ObjectContextManager.Current;
                }
                else
                {
                    this._objectContext = ObjectContextManager.CurrentFor(this._connectionStringName);
                }
            }
            return this._objectContext;
        }
    }

    private EntityKey GetEntityKey<TEntity>(object keyValue) where TEntity : class
    {
        var entitySetName = GetEntityName<TEntity>();
        var objectSet = ObjectContext.CreateObjectSet<TEntity>();
        var keyPropertyName = objectSet.EntitySet.ElementType.KeyMembers[0].ToString();
        var entityKey = new EntityKey(entitySetName, new[] { new EntityKeyMember(keyPropertyName, keyValue) });
        return entityKey;
    }

    private string GetEntityName<TEntity>() where TEntity : class
    {
        // WARNING! : Exceptions for inheritance


        if (_usePlurazation)
        {
             return string.Format("{0}.{1}", ObjectContext.DefaultContainerName, _pluralizer.Pluralize(typeof(TEntity).Name));

        }
        else
        {
             return string.Format("{0}.{1}", ObjectContext.DefaultContainerName, typeof(TEntity).Name);

        }
    }

    private IUnitOfWork unitOfWork;
}

我知道阅读代码需要一些时间,但如果有人看一下并给出改进意见或者指出我没有处理好对象的情况,那将对我大有帮助。
另外我有一个小问题:“我想在这个存储库之上放置一个业务层,保持像global.asax这样的东西不变,但需要静态类(对吗?),比如一个BookProvider,它可以提供关于我的书籍实体的所有数据?”
提前感谢!

使用自定义仓储库相比仅使用ObjectContext的优点是什么?public void Attach<TEntity>(TEntity entity) where TEntity : class { if (entity == null) { throw new ArgumentNullException("entity"); } ObjectContext.AttachTo(GetEntityName(), entity); }这样做确实有所添加吗? - Oxymoron
首先,我可以在不编写新代码的情况下调用模型中的每个实体。例如:repository.GetAll(Book); 或 repository.GetAll(Shirt); 此外,正如您所看到的通用存储库中使用了规范模式,这使我能够轻松地链接规范。 - Julian
你考虑过使用EF 4.1的DbContext吗?它简化了许多ObjectContext的API。请参阅http://blogs.msdn.com/b/adonet/archive/2011/01/27/using-dbcontext-in-ef-feature-ctp5-part-1-introduction-and-model.aspx。 - Philippe
@Philippe 谢谢,我听说过它,但还没有看到好的文档。我会去研究一下! - Julian
1
@Julian 官方文档在这里:http://msdn.microsoft.com/zh-cn/library/gg696172(v=vs.103).aspx - Philippe
1个回答

5
我能提供的唯一具体意见是关于处理上下文的方式:
foreach (ObjectContext ctx in Storage.GetAllObjectContexts())
{
    if (ctx.Connection.State == System.Data.ConnectionState.Open)
        ctx.Connection.Close();
}

ObjectContext 实现了 IDisposable 接口,所以在我看来标准的方式应该是:

foreach (ObjectContext ctx in Storage.GetAllObjectContexts())
    ctx.Dispose();

据我所知,ObjectContext.Dispose()只是关闭连接,所以它做的与你所做的一样。但我认为这是一个内部实现细节,可能在EF版本之间发生变化。
您的通用存储库是众多存储库中的一种。查看方法时,我想到了一些要点:
- 由于您在public IQueryable<TEntity> GetQuery<TEntity>(...)中公开了IQueryable,那么为什么需要大多数其他方法,如SingleFirstCount等呢?(为什么不使用Any等?)您可以从IQueryable获取所有这些内容。 - 您的Update方法仅适用于标量属性。但这是通用存储库的常见问题。没有易于更新实体的通用方式或根本没有解决方案。 - 您使用存储库模式的目标是什么?如果您考虑使用内存数据存储进行单元测试,则不能公开IQueryable,因为LINQ to Entities和LINQ to Objects不同。要测试您的IQueryables是否有效,您需要集成测试和实际数据库,该数据库应在生产中使用。但是,如果您不公开IQueryable,则您的存储库需要许多特定于业务的方法,这些方法将结果作为POCO、POCO集合或选择/投影属性的DTO返回,并隐藏内部查询规范,以便您可以使用内存数据模拟这些方法来测试您的业务逻辑。但是,这就是通用存储库不再足够的地方。(例如,如何在涉及一个以上实体/ObjectSet的存储库中编写LINQ Join,该存储库只具有一个实体类型作为通用参数?)
如果您询问十个人他们的存储库的架构,您将得到十个不同的答案。没有人真正知道您的存储库真正价值多少,因为它取决于您将使用此存储库构建的应用程序。我相信没有人能告诉您您的存储库真正值多少。当您开始编写应用程序时,您将在实践中看到它。对于某些应用程序,它可能过度架构化(这是我认为最危险的,因为管理和控制无意义的架构是昂贵的,而且浪费时间,您会失去编写实际应用程序内容的时间)。对于其他需求,您可能需要扩展存储库。例如:
  • 如何处理实体的导航属性上的显式加载或查询(在EF 4.1中使用CreateSourceQueryDbEntityEntry.Collection/Reference)?如果您的应用程序从不需要显式加载:很好。 如果需要,则需要扩展您的Repo。

  • 如何控制贪婪加载?有时可能只需要父实体。有时您希望Include孩子1集合,有时是孩子2引用。

  • 如何手动设置实体的状态?也许您永远不必这样做。但在下一个应用程序中,这可能非常有用。

  • 如何手动从上下文中分离实体?

  • 如何控制实体的加载行为(是否由上下文跟踪实体)?

  • 如何手动控制延迟加载行为和更改跟踪代理的创建?

  • 如何手动创建实体代理?在使用延迟加载或更改跟踪代理时可能会需要它。

  • 如何将实体加载到上下文中而不构建结果集合?也许是另一个存储库方法,也许不是。谁知道您的应用程序逻辑将需要什么。

等等,等等...


哇,谢谢你!我明天会仔细阅读,今天回家晚了。但看起来很不错!太棒了。 - Julian
更新:@Slauma 我仍在学习你所说的一切。你确实让我对一些问题有了新的认识(如使用多种类型的加载等)。 - Julian
很遗憾只有一个人回答了我的问题,我并不完全满意,因为我希望更多的人对我的代码发表“意见”。感谢Slauma的长篇回复,它对我帮助很大!(授予50分) - Julian
@Julian:一条批评的话:我也犹豫是否要回答你的问题。为什么?因为你的问题里有太多“未经审核”的代码和代码片段,这些代码并没有很好地为你的问题做准备(看看global.asax里所有的空方法体,就很烦人,需要滚动下来跳过)。我相信如果你解释一下仓库的结构和目的以及所有相关的类,然后只提供各个类中的小例子代码片段,你会得到更好的回应。在我看来,这样更加友好读者。 - Slauma
谢谢你在那方面给我的帮助。我稍后会发布一个新问题并告诉你。(不过需要几天时间) - Julian

网页内容由stack overflow 提供, 点击上面的
可以查看英文原文,
原文链接