EF之结构进一步优化

针对之前的使用,做了进一步优化

1.将DAL对象缓存起来

2.仓储类不依赖固定构造的DbContext,执行操作的时候,从线程中动态读取DbContext,这一步也是为了方便将DAL对象缓存起来,解决缓存对象的DbContext的释放问题,没有依赖固定构造的DbContext就不存在释放问题了。(如果依赖固定构造的DbContext,假如webapi情景,解决方案是在ActionFilter中调用API之前声明线程标识,读取缓存的时候根据该线程标识来决定是否替换以及释放DbContext,不难做到)

3.预留IDbContextFactory,方便动态创建自定义的DbContext

4.BIZ层声明静态方法,方便调用

5.更快捷的事务调用方式

6.IRepository中添加GetDbContext()方法,供DAL对象重写实现,从而支持多库;默认情况下由IDbContextFactory的实现类来提供DbContext

所以接下来是一个最新的版本,同时修复了一些bug,譬如解决了事务嵌套的问题、ExecuteSqlCommand执行被上下文更新覆盖的问题;扩展了直接查询获取匿名对象、扩展了连接查询(不依赖LINQ查询语法)、扩展了查询单一或多个字段;修复了排序的bug、修复了附加实体时出现主键重复的问题;并将仓储类的非泛型基类与泛型基类分离,从而方便直接调用非泛型方法。

下面的代码是目前最新的版本,但不排除仍然存在bug

 /// <summary>
    /// 数据仓储基类
    /// </summary>
    public class BaseRepository : IRepository, IDisposable
    {

        /// <summary>
        /// Backing field for <see cref="DbContext"/>. The query/command methods of this class do not
        /// read it; they resolve a context per call through <see cref="GetDbContext"/>.
        /// It is disposed by <see cref="Dispose(bool)"/> when it has been set.
        /// </summary>
        protected MyDbContext _dbContext;

        /// <summary>
        /// Database context property; a plain get/set wrapper around <see cref="_dbContext"/>.
        /// </summary>
        public MyDbContext DbContext
        {
            get
            {
                return _dbContext;
            }
            set
            {
                _dbContext = value;
            }
        }

        private bool disposed;

        /// <summary>
        /// Default constructor.
        /// </summary>
        public BaseRepository()
        {
        }

        /// <summary>
        /// Constructor kept for compatibility. <paramref name="db"/> is currently ignored:
        /// the context is resolved per call through <see cref="GetDbContext"/>.
        /// </summary>
        /// <param name="db">Database source selector (currently unused).</param>
        public BaseRepository(DbSource db)
        {
        }

        /// <summary>
        /// Constructor kept for compatibility. The supplied context is currently ignored (it is not
        /// stored in <see cref="DbContext"/>); the context is resolved per call through <see cref="GetDbContext"/>.
        /// </summary>
        /// <param name="_dbContext">Database context (currently unused).</param>
        public BaseRepository(MyDbContext _dbContext)
        {
        }

        #region Raw SQL queries

        /// <summary>
        /// Executes a count query and returns its scalar result as an <see cref="int"/>.
        /// </summary>
        /// <param name="sqlText">SQL text that yields a single numeric value.</param>
        /// <param name="parms">Parameters referenced by <paramref name="sqlText"/>.</param>
        /// <returns>The count, or 0 when the query yields no row or DB NULL.</returns>
        public int GetCount(string sqlText, params DbParameter[] parms)
        {
            return GetScalar<int>(sqlText, parms);
        }

        /// <summary>
        /// Executes <paramref name="sqlText"/> and returns the first column of the first row,
        /// converted to <typeparamref name="T"/>.
        /// </summary>
        /// <remarks>
        /// The command runs on the context's own connection. When the context has an active
        /// transaction, the command is enlisted in it, so the query sees changes made through that
        /// transaction instead of blocking on them from a second connection.
        /// The parameters are detached after execution (even on failure) so the same instances
        /// can be passed to another command afterwards.
        /// </remarks>
        /// <typeparam name="T">Target type of the scalar value.</typeparam>
        /// <param name="sqlText">SQL text to execute.</param>
        /// <param name="parms">Parameters referenced by <paramref name="sqlText"/>; may be null.</param>
        /// <returns><c>default(T)</c> when the query yields no row or DB NULL; otherwise the converted value.</returns>
        public T GetScalar<T>(string sqlText, params DbParameter[] parms)
        {
            MyDbContext dbContext = this.GetDbContext();
            DbConnection conn = dbContext.Database.Connection;
            // Only close what we opened; if EF (or a transaction) holds it open, leave it that way.
            bool openedHere = conn.State == ConnectionState.Closed;
            if (openedHere)
            {
                conn.Open();
            }

            object obj;
            try
            {
                using (DbCommand command = conn.CreateCommand())
                {
                    command.CommandText = sqlText;
                    var currentTransaction = dbContext.Database.CurrentTransaction;
                    if (currentTransaction != null)
                    {
                        // Required by ADO.NET when the connection has a pending local transaction.
                        command.Transaction = currentTransaction.UnderlyingTransaction;
                    }
                    if (parms != null)
                    {
                        command.Parameters.AddRange(parms);
                    }
                    try
                    {
                        obj = command.ExecuteScalar();
                    }
                    finally
                    {
                        // A parameter may belong to only one command's collection at a time.
                        command.Parameters.Clear();
                    }
                }
            }
            finally
            {
                if (openedHere)
                {
                    conn.Close();
                }
            }

            return ConvertScalar<T>(obj);
        }

        /// <summary>
        /// Executes SQL without parameters and returns the first row mapped to <typeparamref name="TView"/>.
        /// </summary>
        /// <typeparam name="TView">Result type.</typeparam>
        /// <param name="sqlText">SQL text to execute.</param>
        /// <returns>The first row, or <c>default(TView)</c> when there is no row or the query fails (the failure is traced).</returns>
        public TView GetSingle<TView>(string sqlText)
        {
            MyDbContext dbContext = this.GetDbContext();
            try
            {
                return dbContext.Database.SqlQuery<TView>(sqlText).FirstOrDefault();
            }
            catch (Exception ex)
            {
                System.Diagnostics.Trace.TraceError("GetSingle<{0}> failed: {1}", typeof(TView).Name, ex);
                return default(TView);
            }
        }


        /// <summary>
        /// Executes parameterized SQL and returns the first row mapped to <typeparamref name="TView"/>.
        /// </summary>
        /// <typeparam name="TView">Result type.</typeparam>
        /// <param name="sqlText">SQL text to execute.</param>
        /// <param name="parms">Parameters referenced by <paramref name="sqlText"/>.</param>
        /// <returns>The first row, or <c>default(TView)</c> when there is no row or the query fails (the failure is traced).</returns>
        public TView GetSingle<TView>(string sqlText, params DbParameter[] parms)
        {
            MyDbContext dbContext = this.GetDbContext();
            try
            {
                return dbContext.Database.SqlQuery<TView>(sqlText, parms).FirstOrDefault();
            }
            catch (Exception ex)
            {
                System.Diagnostics.Trace.TraceError("GetSingle<{0}> failed: {1}", typeof(TView).Name, ex);
                return default(TView);
            }
        }

        /// <summary>
        /// Executes SQL without parameters and returns all rows mapped to <typeparamref name="TView"/>.
        /// </summary>
        /// <typeparam name="TView">Row type.</typeparam>
        /// <param name="sqlText">SQL text to execute.</param>
        /// <returns>The rows, or null when the query fails (the failure is traced).</returns>
        public List<TView> GetList<TView>(string sqlText)
        {
            MyDbContext dbContext = this.GetDbContext();
            try
            {
                return dbContext.Database.SqlQuery<TView>(sqlText).ToList();
            }
            catch (Exception ex)
            {
                System.Diagnostics.Trace.TraceError("GetList<{0}> failed: {1}", typeof(TView).Name, ex);
                return null;
            }
        }


        /// <summary>
        /// Executes parameterized SQL and returns all rows mapped to <typeparamref name="TView"/>.
        /// </summary>
        /// <typeparam name="TView">Row type.</typeparam>
        /// <param name="sqlText">SQL text to execute.</param>
        /// <param name="parms">Parameters referenced by <paramref name="sqlText"/>.</param>
        /// <returns>The rows, or null when the query fails (the failure is traced).</returns>
        public List<TView> GetList<TView>(string sqlText, params DbParameter[] parms)
        {
            MyDbContext dbContext = this.GetDbContext();
            try
            {
                return dbContext.Database.SqlQuery<TView>(sqlText, parms).ToList();
            }
            catch (Exception ex)
            {
                System.Diagnostics.Trace.TraceError("GetList<{0}> failed: {1}", typeof(TView).Name, ex);
                return null;
            }
        }

        /// <summary>
        /// Paged query over arbitrary (e.g. multi-table join) SQL using ROW_NUMBER().
        /// </summary>
        /// <remarks>
        /// <paramref name="sqlText"/> and <paramref name="orderText"/> are spliced into the SQL text
        /// verbatim; they must never contain untrusted input. Values belong in <paramref name="parms"/>.
        /// <paramref name="page"/> is normalized in place: a non-positive PageSize becomes 20, and
        /// PageIndex is clamped into [1, pageCount].
        /// </remarks>
        /// <typeparam name="TView">Row type.</typeparam>
        /// <param name="sqlText">A query that must begin with SELECT (ROW_NUMBER() is inserted right after it).</param>
        /// <param name="orderText">ORDER BY expression for ROW_NUMBER(), without the ORDER BY keywords.</param>
        /// <param name="page">Requested page index and size.</param>
        /// <param name="parms">Parameters referenced by <paramref name="sqlText"/>; used for both the count and the page query.</param>
        /// <returns>
        /// The page. DataSource is null when the requested index is out of range.
        /// It is also null when the page query itself fails.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="sqlText"/> or <paramref name="page"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="orderText"/> is empty or <paramref name="sqlText"/> does not start with SELECT.</exception>
        public PageSource<TView> GetPaged<TView>(string sqlText, string orderText, PageFilter page, params DbParameter[] parms)
        {
            if (sqlText == null)
            {
                throw new ArgumentNullException("sqlText");
            }
            if (page == null)
            {
                throw new ArgumentNullException("page");
            }
            if (string.IsNullOrWhiteSpace(orderText))
            {
                throw new ArgumentException("An ORDER BY expression is required for ROW_NUMBER() paging.", "orderText");
            }
            string trimmedSql = sqlText.Trim();
            if (!trimmedSql.StartsWith("SELECT", StringComparison.OrdinalIgnoreCase))
            {
                throw new ArgumentException("sqlText must begin with SELECT.", "sqlText");
            }

            PageSource<TView> pageSource = new PageSource<TView>();
            int pageIndexPara = page.PageIndex;
            // Total row count of the un-paged query.
            string sqlTextCount = String.Format("SELECT COUNT(1) AS CT  FROM ({0}) t", sqlText);
            pageSource.TotalCount = GetScalar<int>(sqlTextCount, parms);
            if (page.PageSize <= 0)
            {
                page.PageSize = 20;
            }
            int pageCount = pageSource.TotalCount / page.PageSize
                + (pageSource.TotalCount % page.PageSize == 0 ? 0 : 1);

            // Clamp the current page index into [1, pageCount].
            if (page.PageIndex < 1)
            {
                page.PageIndex = 1;
            }
            int currentPageIndex = page.PageIndex;
            if (currentPageIndex > pageCount)
            {
                currentPageIndex = pageCount;
                page.PageIndex = currentPageIndex;
            }
            pageSource.PageCount = pageCount;
            pageSource.PageIndex = page.PageIndex;
            pageSource.PageSize = page.PageSize;
            int startIndex = (currentPageIndex - 1) * page.PageSize;
            int endIndex = currentPageIndex * page.PageSize;

            // An out-of-range request reports the requested index and no data.
            if (pageIndexPara <= 0 || pageIndexPara > pageSource.PageCount)
            {
                pageSource.PageIndex = pageIndexPara;
                pageSource.DataSource = null;
                return pageSource;
            }

            // "SELECT" is 6 characters; the row number column is inserted right after it.
            string rowNumber = String.Format(" (ROW_NUMBER() OVER(ORDER BY {0})) AS RowNumber, ", orderText);
            string numberedSql = trimmedSql.Insert(6, rowNumber);
            string sqlTextRecord = String.Format("SELECT * FROM ({0}) TT1 WHERE RowNumber>{1} and RowNumber<={2}",
                numberedSql,
                startIndex,
                endIndex
                );

            pageSource.DataSource = GetList<TView>(sqlTextRecord, parms);
            return pageSource;
        }

        #endregion

        /// <summary>
        /// Executes a SQL command without parameters. If the context is in transaction mode and no
        /// transaction is active yet, one is started first.
        /// </summary>
        /// <param name="sqlText">SQL command text.</param>
        /// <returns>The value returned by EF's <c>ExecuteSqlCommand</c> (affected rows).</returns>
        public int ExecuteSqlCommand(string sqlText)
        {
            MyDbContext dbContext = this.GetDbContext();
            EnsureTransaction(dbContext);
            return dbContext.Database.ExecuteSqlCommand(sqlText);
        }

        /// <summary>
        /// Executes a parameterized SQL command. If the context is in transaction mode and no
        /// transaction is active yet, one is started first.
        /// </summary>
        /// <param name="sqlText">SQL command text.</param>
        /// <param name="parms">Parameters referenced by <paramref name="sqlText"/>.</param>
        /// <returns>The value returned by EF's <c>ExecuteSqlCommand</c> (affected rows).</returns>
        public int ExecuteSqlCommand(string sqlText, params DbParameter[] parms)
        {
            MyDbContext dbContext = this.GetDbContext();
            EnsureTransaction(dbContext);
            return dbContext.Database.ExecuteSqlCommand(sqlText, parms);
        }

        /// <summary>
        /// Executes a stored procedure and returns the values written back by the database.
        /// </summary>
        /// <remarks>
        /// Builds <c>EXEC [@ret =] procName @in, @out OUT, ...</c>. Output and InputOutput parameters
        /// are passed with OUT. The first ReturnValue parameter, if any, receives the return code.
        /// </remarks>
        /// <param name="procName">Stored procedure name.</param>
        /// <param name="parms">Procedure parameters; may be null.</param>
        /// <returns>The values of the Output, InputOutput and ReturnValue parameters, in the order they were supplied.</returns>
        public object[] ExecuteProc(string procName, params DbParameter[] parms)
        {
            MyDbContext dbContext = this.GetDbContext();
            DbParameter[] allParms = parms ?? new DbParameter[0];
            List<DbParameter> outParms = allParms.Where(p =>
                p.Direction == ParameterDirection.Output
                || p.Direction == ParameterDirection.InputOutput
                || p.Direction == ParameterDirection.ReturnValue).ToList();
            DbParameter returnParm = allParms.FirstOrDefault(p => p.Direction == ParameterDirection.ReturnValue);
            string proc = BuildProcCall(procName, allParms);
            if (returnParm != null)
            {
                proc = "EXEC " + returnParm.ParameterName + "=" + proc;
            }
            else
            {
                proc = "EXEC " + proc;
            }
            EnsureTransaction(dbContext);
            dbContext.Database.ExecuteSqlCommand(proc, allParms);
            return outParms.Select(r => r.Value).ToArray();
        }


        /// <summary>
        /// Executes a stored procedure that returns a result set, mapping each row to <typeparamref name="TView"/>.
        /// </summary>
        /// <typeparam name="TView">Row type.</typeparam>
        /// <param name="procName">Stored procedure name.</param>
        /// <param name="parms">Procedure parameters; Output and InputOutput parameters are passed with OUT. May be null.</param>
        /// <returns>All rows of the first result set.</returns>
        public List<TView> ExecuteProc<TView>(string procName, params DbParameter[] parms)
        {
            MyDbContext dbContext = this.GetDbContext();
            DbParameter[] allParms = parms ?? new DbParameter[0];
            string proc = BuildProcCall(procName, allParms);
            return dbContext.Database.SqlQuery<TView>(proc, allParms).ToList();
        }

        /// <summary>
        /// Creates an input parameter. An "@" prefix is added unless <paramref name="name"/> already has one.
        /// </summary>
        /// <param name="name">Parameter name, with or without a leading "@".</param>
        /// <param name="value">Parameter value.</param>
        /// <returns>The new parameter.</returns>
        public SqlParameter GetParameter(string name, object value)
        {
            return new SqlParameter(NormalizeParameterName(name), value);
        }

        /// <summary>
        /// Creates an output parameter. An "@" prefix is added unless <paramref name="name"/> already has one.
        /// </summary>
        /// <param name="name">Parameter name, with or without a leading "@".</param>
        /// <param name="type">Database type of the value.</param>
        /// <param name="size">Size of the value.</param>
        /// <returns>The new parameter with <see cref="ParameterDirection.Output"/>.</returns>
        public SqlParameter GetParameterOut(string name, DbType type, int size)
        {
            return new SqlParameter
            {
                ParameterName = NormalizeParameterName(name),
                Direction = ParameterDirection.Output,
                DbType = type,
                Size = size
            };
        }

        /// <summary>
        /// Releases the context held in <see cref="DbContext"/>, if any.
        /// </summary>
        public void Dispose()
        {
            this.Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// Dispose pattern implementation. Disposes <see cref="DbContext"/> once, when called with
        /// <paramref name="disposing"/> = true and the context has been set. The context may
        /// legitimately be null, because no constructor assigns it.
        /// </summary>
        /// <param name="disposing">True when called from <see cref="Dispose()"/>.</param>
        public virtual void Dispose(bool disposing)
        {
            if (!this.disposed && disposing && this._dbContext != null)
            {
                this._dbContext.Dispose();
            }
            this.disposed = true;
        }

        /// <summary>
        /// Returns the context used by every query/command in this class.
        /// The default implementation supports a single database; override it in a DAL subclass
        /// to target a different database.
        /// </summary>
        /// <returns>The context from <c>Transaction.DbContextFactory</c>.</returns>
        public virtual MyDbContext GetDbContext()
        {
            return Transaction.DbContextFactory.GetDbContext();
        }

        /// <summary>
        /// Starts a transaction when the context is in transaction mode but none is active yet.
        /// </summary>
        private static void EnsureTransaction(MyDbContext dbContext)
        {
            if (dbContext.IsTransaction && dbContext.Database.CurrentTransaction == null)
            {
                dbContext.Database.BeginTransaction();
            }
        }

        /// <summary>
        /// Builds "procName @a, @b OUT" from the parameters: Input parameters are listed plainly,
        /// Output/InputOutput parameters with OUT, and ReturnValue parameters are omitted.
        /// </summary>
        private static string BuildProcCall(string procName, IEnumerable<DbParameter> parms)
        {
            StringBuilder procBuilder = new StringBuilder(procName);
            foreach (DbParameter parm in parms)
            {
                if (parm.Direction == ParameterDirection.Input)
                {
                    procBuilder.AppendFormat(" {0}{1}", parm.ParameterName, ",");
                }
                else if (parm.Direction == ParameterDirection.Output
                    || parm.Direction == ParameterDirection.InputOutput)
                {
                    procBuilder.AppendFormat(" {0} {1}{2}", parm.ParameterName, "OUT", ",");
                }
            }
            return procBuilder.ToString().TrimEnd(',');
        }

        /// <summary>
        /// Converts a raw scalar result to <typeparamref name="T"/>. Null/DBNull map to <c>default(T)</c>.
        /// Nullable and enum targets are handled explicitly; other IConvertible targets use Convert.ChangeType.
        /// </summary>
        private static T ConvertScalar<T>(object value)
        {
            if (value == null || value == DBNull.Value)
            {
                return default(T);
            }
            if (value is T)
            {
                return (T)value;
            }
            Type target = Nullable.GetUnderlyingType(typeof(T)) ?? typeof(T);
            if (target.IsEnum)
            {
                return (T)Enum.ToObject(target, value);
            }
            if (value is IConvertible && typeof(IConvertible).IsAssignableFrom(target))
            {
                return (T)Convert.ChangeType(value, target, System.Globalization.CultureInfo.InvariantCulture);
            }
            return (T)value;
        }

        /// <summary>
        /// Prefixes "@" unless the name already starts with it.
        /// </summary>
        private static string NormalizeParameterName(string name)
        {
            return name != null && name.StartsWith("@", StringComparison.Ordinal) ? name : "@" + name;
        }

    }
  /// <summary>
    /// 数据仓储泛型基类
    /// </summary>
    /// <typeparam name="TEntity"></typeparam>
    public class BaseRepository<TEntity> : BaseRepository where TEntity : class
    {

        #region 增删改查

        /// <summary>
        /// Inserts a single entity by marking it as <see cref="EntityState.Added"/>.
        /// </summary>
        /// <param name="entity">Entity to insert.</param>
        /// <returns>The result produced by <c>ChangeObjectState</c>.</returns>
        public TResult Insert(TEntity entity)
        {
            EntityState targetState = EntityState.Added;
            return this.ChangeObjectState(entity, targetState);
        }

        /// <summary>
        /// Inserts a collection of entities by marking each as <see cref="EntityState.Added"/>.
        /// </summary>
        /// <param name="entities">Entities to insert.</param>
        /// <returns>The result produced by <c>ChangeObjectState</c>.</returns>
        public TResult Insert(IEnumerable<TEntity> entities)
        {
            EntityState targetState = EntityState.Added;
            return this.ChangeObjectState(entities, targetState);
        }

        /// <summary>
        /// Updates a single entity by marking it as <see cref="EntityState.Modified"/>.
        /// </summary>
        /// <param name="entity">Entity to update.</param>
        /// <returns>The result produced by <c>ChangeObjectState</c>.</returns>
        public TResult Update(TEntity entity)
        {
            EntityState targetState = EntityState.Modified;
            return this.ChangeObjectState(entity, targetState);
        }

        /// <summary>
        /// Updates a collection of entities by marking each as <see cref="EntityState.Modified"/>.
        /// </summary>
        /// <param name="entities">Entities to update.</param>
        /// <returns>The result produced by <c>ChangeObjectState</c>.</returns>
        public TResult Update(IEnumerable<TEntity> entities)
        {
            EntityState targetState = EntityState.Modified;
            return this.ChangeObjectState(entities, targetState);
        }

        /// <summary>
        /// Loads every entity matching <paramref name="predicate"/> (untracked), applies
        /// <paramref name="updateAction"/> to it, re-attaches it as Modified and saves.
        /// </summary>
        /// <param name="predicate">Filter selecting the entities to update.</param>
        /// <param name="updateAction">Mutation applied to each loaded entity.</param>
        /// <returns>The result of <c>Save(EntityState.Modified)</c>.</returns>
        /// <exception cref="ArgumentNullException">Either argument is null.</exception>
        public TResult Update(Expression<Func<TEntity, bool>> predicate, Action<TEntity> updateAction)
        {
            if (predicate == null)
            {
                throw new ArgumentNullException("predicate");
            }
            if (updateAction == null)
            {
                throw new ArgumentNullException("updateAction");
            }

            MyDbContext context = this.GetDbContext();
            List<TEntity> matches = context.Set<TEntity>()
                .AsNoTracking()
                .Where(predicate)
                .ToList();
            // ToList() never returns null, so this guard does not fire in practice.
            if (matches == null)
            {
                return new TResult(false, "参数为NULL");
            }

            foreach (TEntity item in matches)
            {
                updateAction(item);
                // Remove any already-tracked instance with the same key before re-attaching this one.
                DetachExistsEntity(item);
                context.Entry<TEntity>(item).State = EntityState.Modified;
            }
            return Save(EntityState.Modified);
        }

        /// <summary>
        /// Deletes a single entity by marking it as <see cref="EntityState.Deleted"/>.
        /// </summary>
        /// <param name="entity">Entity to delete.</param>
        /// <returns>The result produced by <c>ChangeObjectState</c>.</returns>
        public TResult Delete(TEntity entity)
        {
            EntityState targetState = EntityState.Deleted;
            return this.ChangeObjectState(entity, targetState);
        }

        /// <summary>
        /// Deletes a collection of entities by marking each as <see cref="EntityState.Deleted"/>.
        /// </summary>
        /// <param name="entities">Entities to delete.</param>
        /// <returns>The result produced by <c>ChangeObjectState</c>.</returns>
        public TResult Delete(IEnumerable<TEntity> entities)
        {
            EntityState targetState = EntityState.Deleted;
            return this.ChangeObjectState(entities, targetState);
        }

        /// <summary>
        /// Loads every entity matching <paramref name="predicate"/> (untracked), marks it Deleted and saves.
        /// </summary>
        /// <remarks>
        /// As in <c>Update(predicate, updateAction)</c>, any tracked instance with the same key is
        /// detached first. Without that step, attaching the freshly loaded copy can fail when the
        /// context already tracks that entity.
        /// </remarks>
        /// <param name="predicate">Filter selecting the entities to delete.</param>
        /// <returns>The result of <c>Save(EntityState.Deleted)</c>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="predicate"/> is null.</exception>
        public TResult Delete(Expression<Func<TEntity, bool>> predicate)
        {
            if (predicate == null)
            {
                throw new ArgumentNullException("predicate");
            }

            MyDbContext dbContext = this.GetDbContext();
            List<TEntity> matches = dbContext.Set<TEntity>().AsNoTracking().Where(predicate).ToList();
            foreach (TEntity item in matches)
            {
                DetachExistsEntity(item);
                dbContext.Entry<TEntity>(item).State = EntityState.Deleted;
            }
            return Save(EntityState.Deleted);
        }


        /// <summary>
        /// Returns the entity set of the current context as a queryable source for single-table filtering.
        /// The result is tracked: it is the set itself, without <c>AsNoTracking</c>.
        /// </summary>
        /// <returns>The <c>DbSet</c> for <typeparamref name="TEntity"/>.</returns>
        public IQueryable<TEntity> GetQueryable()
        {
            return this.GetDbContext().Set<TEntity>();
        }


        /// <summary>
        /// Projects the first matching entity to one or more fields.
        /// For several fields, use a model type or dynamic, e.g.
        /// <c>var d = GetScalar&lt;dynamic&gt;(m =&gt; m.ID == "1", m =&gt; new { m.ID, m.Name }); var v = d.Name;</c>
        /// </summary>
        /// <typeparam name="T">Projection type.</typeparam>
        /// <param name="predicate">Optional filter; null means no filter.</param>
        /// <param name="select">Projection applied to the entity.</param>
        /// <returns>The first projected value, or <c>default(T)</c> when nothing matches.</returns>
        public T GetScalar<T>(Expression<Func<TEntity, bool>> predicate, Expression<Func<TEntity, T>> select)
        {
            IQueryable<TEntity> source = this.GetDbContext().Set<TEntity>().AsNoTracking();
            if (predicate != null)
            {
                source = source.Where(predicate);
            }
            return source.Select(select).FirstOrDefault();
        }

        /// <summary>
        /// Returns the first entity matching <paramref name="predicate"/>, loaded without change tracking.
        /// </summary>
        /// <param name="predicate">Optional filter; null returns the first entity of the set.</param>
        /// <returns>The entity, or null when nothing matches.</returns>
        public TEntity GetSingle(Expression<Func<TEntity, bool>> predicate = null)
        {
            IQueryable<TEntity> source = this.GetDbContext().Set<TEntity>().AsNoTracking();
            if (predicate != null)
            {
                source = source.Where(predicate);
            }
            return source.FirstOrDefault();
        }

        /// <summary>
        /// Executes parameterized SQL and returns the first row mapped to <typeparamref name="TEntity"/>.
        /// </summary>
        /// <param name="sqlText">SQL text to execute.</param>
        /// <param name="parms">Parameters referenced by <paramref name="sqlText"/>.</param>
        /// <returns>
        /// The first row, or null when there is no row. Null is also returned when the query fails;
        /// the failure is written to <c>Trace</c> rather than swallowed silently.
        /// </returns>
        public TEntity GetSingle(string sqlText, params DbParameter[] parms)
        {
            MyDbContext dbContext = this.GetDbContext();
            try
            {
                // FirstOrDefault handles the empty case without relying on an exception.
                return dbContext.Database.SqlQuery<TEntity>(sqlText, parms).FirstOrDefault();
            }
            catch (Exception ex)
            {
                System.Diagnostics.Trace.TraceError("GetSingle<{0}> failed: {1}", typeof(TEntity).Name, ex);
                return null;
            }
        }

        /// <summary>
        /// Executes SQL without parameters and returns the first row mapped to <typeparamref name="TEntity"/>.
        /// </summary>
        /// <param name="sqlText">SQL text to execute.</param>
        /// <returns>
        /// The first row, or null when there is no row. Null is also returned when the query fails;
        /// the failure is written to <c>Trace</c> rather than swallowed silently.
        /// </returns>
        public TEntity GetSingle(string sqlText)
        {
            MyDbContext dbContext = this.GetDbContext();
            try
            {
                // FirstOrDefault handles the empty case without relying on an exception.
                return dbContext.Database.SqlQuery<TEntity>(sqlText).FirstOrDefault();
            }
            catch (Exception ex)
            {
                System.Diagnostics.Trace.TraceError("GetSingle<{0}> failed: {1}", typeof(TEntity).Name, ex);
                return null;
            }
        }

        /// <summary>
        /// Returns all entities matching <paramref name="predicate"/>, loaded without change tracking.
        /// </summary>
        /// <param name="predicate">Optional filter; null returns the whole set.</param>
        /// <returns>The matching entities.</returns>
        public List<TEntity> GetList(Expression<Func<TEntity, bool>> predicate = null)
        {
            IQueryable<TEntity> source = this.GetDbContext().Set<TEntity>().AsNoTracking();
            if (predicate != null)
            {
                source = source.Where(predicate);
            }
            return source.ToList();
        }

        /// <summary>
        /// Returns all matching entities projected through <paramref name="select"/>
        /// (a custom model or an anonymous type).
        /// </summary>
        /// <typeparam name="T">Projection type.</typeparam>
        /// <param name="predicate">Optional filter; null means no filter.</param>
        /// <param name="select">Projection applied to each entity.</param>
        /// <returns>The projected rows.</returns>
        public List<T> GetList<T>(Expression<Func<TEntity, bool>> predicate, Expression<Func<TEntity, T>> select)
        {
            IQueryable<TEntity> source = this.GetDbContext().Set<TEntity>().AsNoTracking();
            if (predicate != null)
            {
                source = source.Where(predicate);
            }
            return source.Select(select).ToList();
        }


        /// <summary>
        /// Returns the entities produced by <c>Get</c> with a lambda-based ordering.
        /// </summary>
        /// <param name="predicate">Filter passed to <c>Get</c>.</param>
        /// <param name="orderBy">Ordering function passed to <c>Get</c>.</param>
        /// <returns>The materialized result of <c>Get</c>.</returns>
        public List<TEntity> GetList(Expression<Func<TEntity, bool>> predicate, Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy)
        {
            var ordered = Get(predicate, orderBy);
            return ordered.ToList();
        }

        /// <summary>
        /// Returns the entities produced by <c>Get</c> with a text-based ordering.
        /// </summary>
        /// <param name="predicate">Filter passed to <c>Get</c>.</param>
        /// <param name="orderBy">Ordering text passed to <c>Get</c>.</param>
        /// <returns>The materialized result of <c>Get</c>.</returns>
        public List<TEntity> GetList(Expression<Func<TEntity, bool>> predicate, string orderBy)
        {
            var ordered = Get(predicate, orderBy);
            return ordered.ToList();
        }

        /// <summary>
        /// Executes SQL without parameters and returns all rows mapped to <typeparamref name="TEntity"/>.
        /// </summary>
        /// <param name="sqlText">SQL text to execute.</param>
        /// <returns>
        /// The rows, or null when the query fails. The failure is written to <c>Trace</c>
        /// rather than swallowed silently.
        /// </returns>
        public List<TEntity> GetList(string sqlText)
        {
            MyDbContext dbContext = this.GetDbContext();
            try
            {
                return dbContext.Database.SqlQuery<TEntity>(sqlText).ToList();
            }
            catch (Exception ex)
            {
                System.Diagnostics.Trace.TraceError("GetList<{0}> failed: {1}", typeof(TEntity).Name, ex);
                return null;
            }
        }

        /// <summary>
        /// Executes parameterized SQL and returns all rows mapped to <typeparamref name="TEntity"/>.
        /// </summary>
        /// <param name="sqlText">SQL text to execute.</param>
        /// <param name="parms">Parameters referenced by <paramref name="sqlText"/>.</param>
        /// <returns>
        /// The rows, or null when the query fails. The failure is written to <c>Trace</c>
        /// rather than swallowed silently.
        /// </returns>
        public List<TEntity> GetList(string sqlText, params DbParameter[] parms)
        {
            MyDbContext dbContext = this.GetDbContext();
            try
            {
                return dbContext.Database.SqlQuery<TEntity>(sqlText, parms).ToList();
            }
            catch (Exception ex)
            {
                System.Diagnostics.Trace.TraceError("GetList<{0}> failed: {1}", typeof(TEntity).Name, ex);
                return null;
            }
        }

        /// <summary>
        /// Returns one page of the <c>Get</c> result, ordered by a lambda.
        /// </summary>
        /// <param name="total">Receives the total row count before paging.</param>
        /// <param name="filter">Optional filter passed to <c>Get</c>.</param>
        /// <param name="orderBy">Optional ordering passed to <c>Get</c>.</param>
        /// <param name="index">1-based page index; values below 2 return the first page.</param>
        /// <param name="size">Page size.</param>
        /// <returns>The rows of the requested page.</returns>
        public List<TEntity> GetPaged(out int total, Expression<Func<TEntity, bool>> filter = null, Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null, int index = 1, int size = 20)
        {
            var source = Get(filter, orderBy);
            total = source.Count();
            int offset = (index - 1) * size;
            if (offset > 0)
            {
                return source.Skip(offset).Take(size).ToList();
            }
            return source.Take(size).ToList();
        }

        /// <summary>
        /// Returns one page of the <c>Get</c> result, ordered by an ordering text.
        /// </summary>
        /// <param name="total">Receives the total row count before paging.</param>
        /// <param name="filter">Optional filter passed to <c>Get</c>.</param>
        /// <param name="orderBy">Optional ordering text passed to <c>Get</c>.</param>
        /// <param name="index">1-based page index; values below 2 return the first page.</param>
        /// <param name="size">Page size.</param>
        /// <returns>The rows of the requested page.</returns>
        public List<TEntity> GetPaged(out int total, Expression<Func<TEntity, bool>> filter = null, string orderBy = null, int index = 1, int size = 20)
        {
            var source = Get(filter, orderBy);
            total = source.Count();
            int offset = (index - 1) * size;
            if (offset > 0)
            {
                return source.Skip(offset).Take(size).ToList();
            }
            return source.Take(size).ToList();
        }

        /// <summary>
        /// Single-table paged query over a caller-supplied queryable.
        /// </summary>
        /// <remarks>
        /// <paramref name="page"/> is normalized in place: a non-positive PageSize becomes 20, and
        /// PageIndex is clamped into [1, pageCount]. When the originally requested index is out of
        /// range, the result carries that requested index and a null DataSource.
        /// </remarks>
        /// <param name="query">Source query (should already be ordered if Skip is needed).</param>
        /// <param name="page">Requested page index and size.</param>
        /// <returns>The page metadata and rows.</returns>
        public PageSource<TEntity> GetPaged(IQueryable<TEntity> query, PageFilter page)
        {
            PageSource<TEntity> result = new PageSource<TEntity>();
            int requestedIndex = page.PageIndex;
            result.TotalCount = query.Count();

            if (page.PageSize <= 0)
            {
                page.PageSize = 20;
            }
            int pages = result.TotalCount / page.PageSize
                + (result.TotalCount % page.PageSize == 0 ? 0 : 1);

            if (page.PageIndex < 1)
            {
                page.PageIndex = 1;
            }
            if (page.PageIndex > pages)
            {
                page.PageIndex = pages;
            }

            result.PageCount = pages;
            result.PageIndex = page.PageIndex;
            result.PageSize = page.PageSize;
            int offset = (page.PageIndex - 1) * page.PageSize;

            bool outOfRange = requestedIndex <= 0 || requestedIndex > result.PageCount;
            if (outOfRange)
            {
                result.PageIndex = requestedIndex;
                result.DataSource = null;
                return result;
            }

            IQueryable<TEntity> window = offset > 0
                ? query.Skip(offset).Take(page.PageSize)
                : query.Take(page.PageSize);
            result.DataSource = window.ToList();
            return result;
        }


        /// <summary>
        /// Inner join of <typeparamref name="TEntity"/> with <typeparamref name="TInner"/>.
        /// NOTE(review): the key selectors are typed as <c>dynamic</c> (object keys); confirm that the
        /// LINQ provider accepts this for value-type keys.
        /// </summary>
        /// <typeparam name="TInner">Joined entity type.</typeparam>
        /// <typeparam name="TModel">Result type.</typeparam>
        /// <param name="outerKeySelector">Key of the main entity.</param>
        /// <param name="innerKeySelector">Key of the joined entity.</param>
        /// <param name="resultSelector">Builds the result from a matched pair.</param>
        /// <returns>The joined rows.</returns>
        public List<TModel> GetInnerJoin<TInner, TModel>(
            Expression<Func<TEntity, dynamic>> outerKeySelector,
            Expression<Func<TInner, dynamic>> innerKeySelector,
            Expression<Func<TEntity, TInner, TModel>> resultSelector) where TInner : class
        {
            MyDbContext context = this.GetDbContext();
            DbSet<TEntity> outer = context.Set<TEntity>();
            DbSet<TInner> inner = context.Set<TInner>();
            return outer.Join(inner, outerKeySelector, innerKeySelector, resultSelector).ToList();
        }

        /// <summary>
        /// Inner join with the filter applied to the main table before joining; supports anonymous result types.
        /// NOTE(review): the key selectors are typed as <c>dynamic</c> (object keys); confirm that the
        /// LINQ provider accepts this for value-type keys.
        /// </summary>
        /// <typeparam name="TInner">Joined entity type.</typeparam>
        /// <typeparam name="TModel">Result type.</typeparam>
        /// <param name="predicate">Filter on the main table.</param>
        /// <param name="outerKeySelector">Key of the main entity.</param>
        /// <param name="innerKeySelector">Key of the joined entity.</param>
        /// <param name="resultSelector">Builds the result from a matched pair.</param>
        /// <returns>The joined rows.</returns>
        public List<TModel> GetInnerJoin<TInner, TModel>(
           Expression<Func<TEntity, bool>> predicate,
           Expression<Func<TEntity, dynamic>> outerKeySelector,
           Expression<Func<TInner, dynamic>> innerKeySelector,
           Expression<Func<TEntity, TInner, TModel>> resultSelector) where TInner : class
        {
            MyDbContext context = this.GetDbContext();
            IQueryable<TEntity> filteredOuter = context.Set<TEntity>().Where(predicate);
            DbSet<TInner> inner = context.Set<TInner>();
            return filteredOuter.Join(inner, outerKeySelector, innerKeySelector, resultSelector).ToList();
        }

        /// <summary>
        /// Inner join with the filter applied to the joined result, so it can reference both tables
        /// through <typeparamref name="TModel"/>. Anonymous result types are not supported, because the
        /// predicate must name <typeparamref name="TModel"/>.
        /// NOTE(review): the key selectors are typed as <c>dynamic</c> (object keys); confirm that the
        /// LINQ provider accepts this for value-type keys.
        /// </summary>
        /// <typeparam name="TInner">Joined entity type.</typeparam>
        /// <typeparam name="TModel">Result type.</typeparam>
        /// <param name="outerKeySelector">Key of the main entity.</param>
        /// <param name="innerKeySelector">Key of the joined entity.</param>
        /// <param name="resultSelector">Builds the result from a matched pair.</param>
        /// <param name="predicate">Filter on the joined result.</param>
        /// <returns>The joined rows.</returns>
        public List<TModel> GetInnerJoin<TInner, TModel>(
          Expression<Func<TEntity, dynamic>> outerKeySelector,
          Expression<Func<TInner, dynamic>> innerKeySelector,
          Expression<Func<TEntity, TInner, TModel>> resultSelector,
          Expression<Func<TModel, bool>> predicate) where TInner : class
        {
            MyDbContext context = this.GetDbContext();
            DbSet<TEntity> outer = context.Set<TEntity>();
            DbSet<TInner> inner = context.Set<TInner>();
            IQueryable<TModel> joined = outer.Join(inner, outerKeySelector, innerKeySelector, resultSelector);
            return joined.Where(predicate).ToList();
        }

        /// <summary>
        /// Left outer join of <typeparamref name="TEntity"/> with <typeparamref name="TInner"/>,
        /// via the <c>LeftOuterJoin</c> extension defined elsewhere in the project.
        /// </summary>
        /// <typeparam name="TInner">Joined entity type.</typeparam>
        /// <typeparam name="TModel">Result type.</typeparam>
        /// <param name="outerKeySelector">Key of the main entity.</param>
        /// <param name="innerKeySelector">Key of the joined entity.</param>
        /// <param name="resultSelector">Builds the result from a (possibly unmatched) pair.</param>
        /// <returns>The joined rows.</returns>
        public List<TModel> GetLeftJoin<TInner, TModel>(
           Expression<Func<TEntity, dynamic>> outerKeySelector,
           Expression<Func<TInner, dynamic>> innerKeySelector,
           Expression<Func<TEntity, TInner, TModel>> resultSelector) where TInner : class
        {
            MyDbContext context = this.GetDbContext();
            DbSet<TEntity> outer = context.Set<TEntity>();
            DbSet<TInner> inner = context.Set<TInner>();
            var joined = outer.LeftOuterJoin(inner, outerKeySelector, innerKeySelector, resultSelector);
            return joined.ToList();
        }

        /// <summary>
        /// Left outer join of the TEntity rows matching <paramref name="predicate"/> with TInner,
        /// using the project's <c>LeftOuterJoin</c> extension (defined elsewhere).
        /// </summary>
        /// <typeparam name="TInner">Joined (inner) entity type.</typeparam>
        /// <typeparam name="TModel">Projected result type.</typeparam>
        /// <param name="predicate">Filter applied to TEntity before joining.</param>
        /// <param name="outerKeySelector">Join key of TEntity.</param>
        /// <param name="innerKeySelector">Join key of TInner.</param>
        /// <param name="resultSelector">Builds one result from a TEntity and its (possibly missing) TInner match.</param>
        /// <returns>The materialized join results.</returns>
        /// <remarks>NOTE(review): keys are typed as <c>dynamic</c> (object); confirm the LINQ provider
        /// translates the implied conversion for the key types actually used.</remarks>
        public List<TModel> GetLeftJoin<TInner, TModel>(
            Expression<Func<TEntity, bool>> predicate,
            Expression<Func<TEntity, dynamic>> outerKeySelector,
            Expression<Func<TInner, dynamic>> innerKeySelector,
            Expression<Func<TEntity, TInner, TModel>> resultSelector) where TInner : class
        {
            MyDbContext context = this.GetDbContext();
            var filteredOuter = context.Set<TEntity>().Where(predicate);
            var innerSet = context.Set<TInner>();
            var joined = filteredOuter.LeftOuterJoin(innerSet, outerKeySelector, innerKeySelector, resultSelector);
            return joined.ToList();
        }

        /// <summary>
        /// Left outer join of TEntity with TInner (project <c>LeftOuterJoin</c> extension, defined elsewhere),
        /// filtered on the projected <typeparamref name="TModel"/> results.
        /// </summary>
        /// <typeparam name="TInner">Joined (inner) entity type.</typeparam>
        /// <typeparam name="TModel">Projected result type; must be nameable because the predicate is typed on it.</typeparam>
        /// <param name="outerKeySelector">Join key of TEntity.</param>
        /// <param name="innerKeySelector">Join key of TInner.</param>
        /// <param name="resultSelector">Builds one result from a TEntity and its (possibly missing) TInner match.</param>
        /// <param name="predicate">Filter applied to the joined results.</param>
        /// <returns>The materialized join results.</returns>
        /// <remarks>NOTE(review): keys are typed as <c>dynamic</c> (object); confirm the LINQ provider
        /// translates the implied conversion for the key types actually used.</remarks>
        public List<TModel> GetLeftJoin<TInner, TModel>(
          Expression<Func<TEntity, dynamic>> outerKeySelector,
          Expression<Func<TInner, dynamic>> innerKeySelector,
          Expression<Func<TEntity, TInner, TModel>> resultSelector,
          Expression<Func<TModel, bool>> predicate) where TInner : class
        {
            MyDbContext context = this.GetDbContext();
            var outerSet = context.Set<TEntity>();
            var innerSet = context.Set<TInner>();
            var joined = outerSet.LeftOuterJoin(innerSet, outerKeySelector, innerKeySelector, resultSelector);
            var filtered = joined.Where(predicate);
            return filtered.ToList();
        }

        /// <summary>
        /// One-to-many group join: pairs every TEntity with the collection of TInner rows sharing its key.
        /// </summary>
        /// <typeparam name="TInner">Child entity type.</typeparam>
        /// <typeparam name="TModel">Projected result type.</typeparam>
        /// <param name="outerKeySelector">Join key of TEntity.</param>
        /// <param name="innerKeySelector">Join key of TInner.</param>
        /// <param name="resultSelector">Builds one result from a TEntity and its matching TInner rows.</param>
        /// <returns>The materialized results, one per TEntity.</returns>
        /// <remarks>NOTE(review): keys are typed as <c>dynamic</c> (object); confirm the LINQ provider
        /// translates the implied conversion for the key types actually used.</remarks>
        public List<TModel> GetGroupJoin<TInner, TModel>(
            Expression<Func<TEntity, dynamic>> outerKeySelector,
            Expression<Func<TInner, dynamic>> innerKeySelector,
            Expression<Func<TEntity, IEnumerable<TInner>, TModel>> resultSelector) where TInner : class
        {
            MyDbContext context = this.GetDbContext();
            var parents = context.Set<TEntity>();
            var children = context.Set<TInner>();
            var grouped = parents.GroupJoin(children, outerKeySelector, innerKeySelector, resultSelector);
            return grouped.ToList();
        }

        /// <summary>
        /// One-to-many group join restricted to the TEntity rows matching <paramref name="predicate"/>.
        /// </summary>
        /// <typeparam name="TInner">Child entity type.</typeparam>
        /// <typeparam name="TModel">Projected result type.</typeparam>
        /// <param name="predicate">Filter applied to TEntity before joining.</param>
        /// <param name="outerKeySelector">Join key of TEntity.</param>
        /// <param name="innerKeySelector">Join key of TInner.</param>
        /// <param name="resultSelector">Builds one result from a TEntity and its matching TInner rows.</param>
        /// <returns>The materialized results, one per matching TEntity.</returns>
        /// <remarks>NOTE(review): keys are typed as <c>dynamic</c> (object); confirm the LINQ provider
        /// translates the implied conversion for the key types actually used.</remarks>
        public List<TModel> GetGroupJoin<TInner, TModel>(
            Expression<Func<TEntity, bool>> predicate,
            Expression<Func<TEntity, dynamic>> outerKeySelector,
            Expression<Func<TInner, dynamic>> innerKeySelector,
            Expression<Func<TEntity, IEnumerable<TInner>, TModel>> resultSelector) where TInner : class
        {
            MyDbContext context = this.GetDbContext();
            var parents = context.Set<TEntity>().Where(predicate);
            var children = context.Set<TInner>();
            var grouped = parents.GroupJoin(children, outerKeySelector, innerKeySelector, resultSelector);
            return grouped.ToList();
        }

        /// <summary>
        /// One-to-many group join, filtered on the projected <typeparamref name="TModel"/> results.
        /// </summary>
        /// <typeparam name="TInner">Child entity type.</typeparam>
        /// <typeparam name="TModel">Projected result type; must be nameable because the predicate is typed on it.</typeparam>
        /// <param name="outerKeySelector">Join key of TEntity.</param>
        /// <param name="innerKeySelector">Join key of TInner.</param>
        /// <param name="resultSelector">Builds one result from a TEntity and its matching TInner rows.</param>
        /// <param name="predicate">Filter applied to the grouped results.</param>
        /// <returns>The materialized results that satisfy <paramref name="predicate"/>.</returns>
        /// <remarks>NOTE(review): keys are typed as <c>dynamic</c> (object); confirm the LINQ provider
        /// translates the implied conversion for the key types actually used.</remarks>
        public List<TModel> GetGroupJoin<TInner, TModel>(
           Expression<Func<TEntity, dynamic>> outerKeySelector,
           Expression<Func<TInner, dynamic>> innerKeySelector,
           Expression<Func<TEntity, IEnumerable<TInner>, TModel>> resultSelector,
           Expression<Func<TModel, bool>> predicate) where TInner : class
        {
            MyDbContext context = this.GetDbContext();
            var parents = context.Set<TEntity>();
            var children = context.Set<TInner>();
            var grouped = parents.GroupJoin(children, outerKeySelector, innerKeySelector, resultSelector);
            var filtered = grouped.Where(predicate);
            return filtered.ToList();
        }

        #endregion

        #region 私有方法

        /// <summary>
        /// Detaches any entity already tracked by the current context under the same key as
        /// <paramref name="entity"/>. This avoids "an object with the same key already exists"
        /// errors when the caller then attaches or marks <paramref name="entity"/>.
        /// </summary>
        /// <param name="entity">Entity whose key is looked up.</param>
        /// <returns>true when a tracked entity with that key was found and detached.</returns>
        private Boolean DetachExistsEntity(TEntity entity)
        {
            var objContext = ((IObjectContextAdapter)this.GetDbContext()).ObjectContext;
            var objSet = objContext.CreateObjectSet<TEntity>();
            var entityKey = objContext.CreateEntityKey(objSet.EntitySet.Name, entity);

            // Only the local ObjectStateManager is consulted. TryGetObjectByKey (used previously) falls back
            // to querying the database when the key is not tracked, which cost one SELECT per entity on every
            // Insert/Update/Delete, only to detach the freshly loaded object again.
            System.Data.Entity.Core.Objects.ObjectStateEntry stateEntry;
            bool tracked = objContext.ObjectStateManager.TryGetObjectStateEntry(entityKey, out stateEntry)
                && stateEntry.Entity != null;

            if (tracked)
            {
                objContext.Detach(stateEntry.Entity);
            }

            return tracked;
        }


        /// <summary>
        /// Saves the pending changes of the current context and reports the outcome as a <see cref="TResult"/>.
        /// </summary>
        /// <returns>
        /// Flag = true when at least one row was affected; otherwise Flag = false with
        /// Code "None Effect" (nothing written) or "Exception" (save failed, message copied).
        /// </returns>
        /// <remarks>
        /// Failures are normally captured in the result. While the context is inside a transaction
        /// (<c>IsTransaction</c>), the exception is rethrown so the surrounding code can roll back.
        /// </remarks>
        private TResult Save()
        {
            MyDbContext dbContext = this.GetDbContext();
            TResult result = new TResult();

            try
            {
                int effect = dbContext.SaveChanges();
                if (effect > 0)
                {
                    result.Flag = true;
                }
                else
                {
                    result.Flag = false;
                    result.Message = "无受影响行";
                    result.Code = "None Effect";
                }
            }
            catch (Exception ex)
            {
                result.Flag = false;
                result.Message = ex.Message;
                result.Code = "Exception";
                if (dbContext.IsTransaction)
                {
                    // "throw;" rather than "throw ex;" so the original stack trace is preserved.
                    throw;
                }
            }

            return result;
        }

        /// <summary>
        /// Saves the pending changes of the current context and reports the outcome as a <see cref="TResult"/>,
        /// with a success message matching the kind of change that was applied.
        /// </summary>
        /// <param name="state">The state the changed entities were put into (Added, Modified or Deleted);
        /// other values produce a successful result without a message.</param>
        /// <returns>
        /// Flag = true with an insert/update/delete message when at least one row was affected; otherwise
        /// Flag = false with Code "None Effect" (nothing written) or "Exception" (save failed, message copied).
        /// </returns>
        /// <remarks>
        /// Failures are normally captured in the result. While the context is inside a transaction
        /// (<c>IsTransaction</c>), the exception is rethrown so the surrounding code can roll back.
        /// </remarks>
        private TResult Save(EntityState state)
        {
            MyDbContext dbContext = this.GetDbContext();
            TResult result = new TResult();

            try
            {
                int effect = dbContext.SaveChanges();
                if (effect > 0)
                {
                    result.Flag = true;
                    switch (state)
                    {
                        case EntityState.Added:
                            result.Message = "添加成功";
                            result.Code = "Insert Success";
                            break;
                        case EntityState.Modified:
                            result.Message = "更新成功";
                            result.Code = "Update Success";
                            break;
                        case EntityState.Deleted:
                            result.Message = "删除成功";
                            result.Code = "Delete Success";
                            break;
                        default:
                            break;
                    }
                }
                else
                {
                    result.Flag = false;
                    result.Message = "无受影响行";
                    result.Code = "None Effect";
                }
            }
            catch (Exception ex)
            {
                result.Flag = false;
                result.Message = ex.Message;
                result.Code = "Exception";
                if (dbContext.IsTransaction)
                {
                    // "throw;" rather than "throw ex;" so the original stack trace is preserved.
                    throw;
                }
            }

            return result;
        }

        /// <summary>
        /// Puts a single entity into <paramref name="state"/> on the current context and saves immediately.
        /// Any already-tracked entity with the same key is detached first.
        /// </summary>
        /// <param name="entity">Entity to add/update/delete; null yields a failed result.</param>
        /// <param name="state">Target state (e.g. Added, Modified, Deleted).</param>
        /// <returns>The result of the save, or a failed result when <paramref name="entity"/> is null.</returns>
        private TResult ChangeObjectState(TEntity entity, EntityState state)
        {
            MyDbContext context = this.GetDbContext();
            if (entity != null)
            {
                DetachExistsEntity(entity);
                context.Entry<TEntity>(entity).State = state;
                return Save(state);
            }

            return new TResult(false, "参数为NULL");
        }

        /// <summary>
        /// Puts every entity of <paramref name="entities"/> into <paramref name="state"/> on the current
        /// context, then saves all of them with a single SaveChanges call.
        /// </summary>
        /// <param name="entities">Entities to add/update/delete; null yields a failed result.</param>
        /// <param name="state">Target state (e.g. Added, Modified, Deleted).</param>
        /// <returns>The result of the save, or a failed result when <paramref name="entities"/> is null.</returns>
        private TResult ChangeObjectState(IEnumerable<TEntity> entities, EntityState state)
        {
            if (entities == null)
            {
                return new TResult(false, "参数为NULL");
            }

            MyDbContext context = this.GetDbContext();

            // Materialize the sequence up front, before any entity is attached to the context.
            List<TEntity> pending = entities.ToList();
            foreach (TEntity item in pending)
            {
                DetachExistsEntity(item);
                context.Entry<TEntity>(item).State = state;
            }

            return Save(state);
        }


        /// <summary>
        /// Builds (without executing) a query over TEntity with an optional filter and an optional
        /// string-based ordering, applied through the <c>OrderBy(string)</c> extension available to this file.
        /// </summary>
        /// <param name="filter">Optional predicate; null means no filtering.</param>
        /// <param name="orderBy">Optional ordering expression; null or empty means no ordering.</param>
        /// <returns>The composed, deferred query.</returns>
        private IQueryable<TEntity> Get(Expression<Func<TEntity, bool>> filter = null, string orderBy = null)
        {
            IQueryable<TEntity> result = this.GetDbContext().Set<TEntity>();
            result = filter == null ? result : result.Where(filter);

            if (!string.IsNullOrEmpty(orderBy))
            {
                result = result.OrderBy(orderBy);
            }

            return result.AsQueryable();
        }



        /// <summary>
        /// Builds (without executing) a query over TEntity with an optional filter and an optional ordering.
        /// </summary>
        /// <param name="filter">Optional predicate; null means no filtering.</param>
        /// <param name="orderBy">Optional function that applies ordering to the query; null means no ordering.</param>
        /// <returns>The composed, deferred query.</returns>
        private IQueryable<TEntity> Get(Expression<Func<TEntity, bool>> filter = null, Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null)
        {
            MyDbContext dbContext = this.GetDbContext();
            IQueryable<TEntity> query = dbContext.Set<TEntity>();
            if (filter != null)
            {
                query = query.Where(filter);
            }
            if (orderBy != null)
            {
                // The ordered query must be kept; previously its result was discarded, so ordering was ignored.
                query = orderBy(query);
            }
            return query;
        }


        #endregion


    }
  /// <summary>
    /// Simple DAL factory that caches one repository instance per repository type.
    /// </summary>
    public class DALFactory
    {
        // Created eagerly and never reassigned, so no lazy double-checked locking is needed.
        private static readonly DALCache _DALCaches = new DALCache();

        /// <summary>
        /// DAL cache container, keyed by the repository type's full name.
        /// </summary>
        public static DALCache DALCaches
        {
            get
            {
                return _DALCaches;
            }
        }

        /// <summary>
        /// Returns the cached repository of type <typeparamref name="T"/>, creating and caching one on first use.
        /// </summary>
        /// <typeparam name="T">Repository type.</typeparam>
        /// <returns>The cached (shared) repository instance.</returns>
        public static T CreateDAL<T>() where T : class, IRepository, new()
        {
            string key = typeof(T).FullName;

            // Single lookup: reading the cache twice could observe the entry disappearing in between.
            object cached = DALCaches.Get(key);
            if (cached != null)
            {
                return (T)cached;
            }

            T dao = new T();
            DALCaches.Insert(key, dao);
            return dao;
        }

        /// <summary>
        /// Returns the repository of type <typeparamref name="T"/> bound to <paramref name="myDbContext"/>.
        /// A cached instance is reused and re-bound, and its previous context is disposed when it differs
        /// from the new one. Otherwise a new instance is created, bound and cached.
        /// </summary>
        /// <typeparam name="T">Repository type.</typeparam>
        /// <param name="myDbContext">Context the repository should use from now on.</param>
        /// <returns>The cached repository, now bound to <paramref name="myDbContext"/>.</returns>
        /// <remarks>
        /// NOTE(review): the cached instance is static and shared by every caller, so re-binding or disposing
        /// its context here also affects any other thread currently using that repository. Confirm this
        /// overload is only used where that is acceptable.
        /// </remarks>
        public static T CreateDAL<T>(MyDbContext myDbContext) where T : class, IRepository, new()
        {
            string key = typeof(T).FullName;

            object cached = DALCaches.Get(key);
            if (cached != null)
            {
                T existing = (T)cached;

                // Never dispose the context that is about to be (re)assigned.
                if (existing.DbContext != null && !ReferenceEquals(existing.DbContext, myDbContext))
                {
                    existing.DbContext.Dispose();
                }
                existing.DbContext = myDbContext;

                // Previously this branch fell through, discarding the re-bound instance and caching a new one.
                return existing;
            }

            T dao = new T();
            dao.DbContext = myDbContext;
            DALCaches.Insert(key, dao);
            return dao;
        }
    }
/// <summary>
    /// Simple factory that hands out one DbContext instance of each type per call context.
    /// </summary>
    public abstract class DbContextFactory
    {
        /// <summary>
        /// Returns the <typeparamref name="T"/> stored in the current <see cref="CallContext"/>, creating and
        /// storing a new instance when the slot is empty.
        /// </summary>
        /// <typeparam name="T">Concrete DbContext type; its short type name is used as the slot name.</typeparam>
        /// <returns>The context shared within the current call context.</returns>
        /// <remarks>
        /// NOTE(review): the slot is keyed by <c>typeof(T).Name</c>, so two context types with the same short
        /// name in different namespaces would share a slot. Nothing here checks whether the stored instance
        /// has already been disposed.
        /// </remarks>
        public static T GetCurrentDbContext<T>() where T : DbContext, new()
        {
            string slot = typeof(T).Name;
            T existing = CallContext.GetData(slot) as T;
            if (existing != null)
            {
                return existing;
            }

            T created = new T();
            CallContext.SetData(slot, created);
            return created;
        }
    }
/// <summary>
    /// Factory that supplies the data context. Implemented in the DAL layer and registered in the
    /// application configuration; the Transaction helper reads it from appSettings key "IDbcontextFactory"
    /// as "Namespace.ClassName, AssemblyName".
    /// </summary>
    public interface IDbContextFactory
    {
        /// <summary>
        /// Gets the data context to use for the current operation.
        /// </summary>
        /// <returns>The data context.</returns>
        MyDbContext GetDbContext();
    }
 /// <summary>
    /// Data repository contract shared by all DAL classes.
    /// </summary>
    public interface IRepository
    {
        /// <summary>
        /// Explicitly assigned data context (set, for example, by DALFactory.CreateDAL with a context argument).
        /// </summary>
        MyDbContext DbContext { get; set; }

        /// <summary>
        /// Gets the DbContext for the current thread / call context. DAL classes may override the
        /// implementation to target a different database.
        /// </summary>
        /// <returns>The data context to run operations against.</returns>
        MyDbContext GetDbContext();
    }
/// <summary>
    /// Transaction contract implemented by the data context.
    /// </summary>
    public interface ITransaction
    {
        /// <summary>
        /// True while a transaction started via BeginTransaction is active.
        /// </summary>
        bool IsTransaction { get; }

        /// <summary>
        /// Starts a transaction (or joins the already active one).
        /// </summary>
        void BeginTransaction();

        /// <summary>
        /// Commits the transaction.
        /// </summary>
        /// <returns>An implementation-defined count (MyDbContext returns rows saved plus one for the commit).</returns>
        int Commit();

        /// <summary>
        /// Rolls back the transaction.
        /// </summary>
        void Rollback();
    }
/// <summary>
    /// Data context with a simple nestable transaction API. BeginTransaction/Commit/Rollback calls are
    /// reference-counted, and only the outermost Commit or Rollback touches the underlying database transaction.
    /// </summary>
    public class MyDbContext : DbContext, ITransaction
    {
        /// <summary>
        /// Creates the context for the given connection string (or connection-string name).
        /// </summary>
        /// <param name="connectionString">Connection string or its name in the configuration file.</param>
        public MyDbContext(string connectionString)
            : base(connectionString)
        {
            // Disable lazy loading
            Configuration.LazyLoadingEnabled = false;
            // Disable proxy creation
            Configuration.ProxyCreationEnabled = false;
            // With automatic change detection off, changes to tracked entities are only saved when their
            // state is set explicitly (or DetectChanges is called).
            Configuration.AutoDetectChangesEnabled = false;
            Configuration.ValidateOnSaveEnabled = false;
        }

        /// <summary>
        /// Starts a database transaction, or joins the one already open, and increments the nesting counter.
        /// </summary>
        public void BeginTransaction()
        {
            if (this.Database.CurrentTransaction == null)
            {
                this.Database.BeginTransaction();
            }
            this.BeginCounter++;
            this.IsTransaction = true;
        }

        /// <summary>
        /// Ends one nesting level. The outermost call saves pending changes and commits the database transaction.
        /// </summary>
        /// <returns>0 for inner levels; otherwise the rows saved plus 1 when a transaction was committed.</returns>
        /// <remarks>
        /// If saving or committing fails, the transaction is rolled back, pending changes are discarded, the
        /// transaction state is reset and the exception is rethrown. A subsequent Rollback is then a no-op.
        /// </remarks>
        public int Commit()
        {
            if (this.BeginCounter <= 0)
            {
                // No active transaction: do not let the counter go negative, or every later
                // BeginTransaction/Commit pair would never reach the real commit.
                this.BeginCounter = 0;
                return 0;
            }

            this.BeginCounter--;
            if (this.BeginCounter > 0)
            {
                // Inner level: the outermost Commit does the actual work.
                return 0;
            }

            DbContextTransaction transaction = this.Database.CurrentTransaction;
            try
            {
                int result = this.SaveChanges();
                if (transaction != null)
                {
                    transaction.Commit();
                    result += 1;
                }
                return result;
            }
            catch
            {
                if (transaction != null)
                {
                    try
                    {
                        transaction.Rollback();
                    }
                    catch (Exception)
                    {
                        // Best effort: keep the original failure as the exception the caller sees.
                        // The transaction is still disposed in the finally block below.
                    }
                }
                this.DetachAllEntries();
                throw;
            }
            finally
            {
                this.IsTransaction = false;
                if (transaction != null)
                {
                    transaction.Dispose();
                }
            }
        }

        /// <summary>
        /// Ends one nesting level by rolling back. The outermost call rolls back the database transaction
        /// and discards all tracked changes. An inner call throws so the failure propagates to the outer scope.
        /// </summary>
        /// <exception cref="InvalidOperationException">Rollback requested from an inner (nested) level.</exception>
        public void Rollback()
        {
            if (this.BeginCounter <= 0)
            {
                // Nothing to roll back (e.g. a failed Commit already rolled back and reset the state).
                this.BeginCounter = 0;
                return;
            }

            this.BeginCounter--;
            if (this.BeginCounter > 0)
            {
                throw new InvalidOperationException("嵌套内部事务异常");
            }

            this.IsTransaction = false;
            DbContextTransaction transaction = this.Database.CurrentTransaction;
            try
            {
                if (transaction != null)
                {
                    transaction.Rollback();
                }
            }
            finally
            {
                if (transaction != null)
                {
                    transaction.Dispose();
                }
                // The context is reused for later work, so rolled-back changes must not stay tracked.
                this.DetachAllEntries();
            }
        }

        /// <summary>
        /// Detaches every tracked entity so no rolled-back change can be saved later by this context.
        /// </summary>
        private void DetachAllEntries()
        {
            // Copy first: changing an entry's state modifies the tracker being enumerated.
            foreach (var entry in System.Linq.Enumerable.ToList(this.ChangeTracker.Entries()))
            {
                entry.State = System.Data.Entity.EntityState.Detached;
            }
        }

        private bool isTransaction = false;

        /// <summary>
        /// True while a transaction is active.
        /// </summary>
        public bool IsTransaction
        {
            get { return isTransaction; }
            set { this.isTransaction = value; }
        }

        private int beginCounter = 0;

        /// <summary>
        /// Transaction nesting counter.
        /// </summary>
        public int BeginCounter
        {
            get { return beginCounter; }
            set { this.beginCounter = value; }
        }
    }

 

/// <summary>
    /// Static transaction helper that forwards to the data context supplied by the configured
    /// <see cref="IDbContextFactory"/>.
    /// </summary>
    public class Transaction
    {
        private static object _lock = new object();

        private static IDbContextFactory dbContextFactory = null;

        /// <summary>
        /// Factory loaded once, on first use, from the "IDbcontextFactory" appSetting.
        /// </summary>
        public static IDbContextFactory DbContextFactory
        {
            get
            {
                if (dbContextFactory == null)
                {
                    lock (_lock)
                    {
                        if (dbContextFactory == null)
                        {
                            dbContextFactory = LoadDbContextFactory();
                        }
                    }
                }
                return dbContextFactory;
            }
        }

        /// <summary>
        /// Starts (or joins) a transaction on the factory's current data context.
        /// </summary>
        public static void BeginTransaction()
        {
            MyDbContext dbContext = DbContextFactory.GetDbContext();
            dbContext.BeginTransaction();
        }

        /// <summary>
        /// Commits one transaction level on the factory's current data context.
        /// </summary>
        /// <returns>The value returned by the context's Commit.</returns>
        public static int Commit()
        {
            MyDbContext dbContext = DbContextFactory.GetDbContext();
            return dbContext.Commit();
        }

        /// <summary>
        /// Rolls back one transaction level on the factory's current data context.
        /// </summary>
        public static void Rollback()
        {
            MyDbContext dbContext = DbContextFactory.GetDbContext();
            dbContext.Rollback();
        }

        /// <summary>
        /// Instantiates the factory named by appSettings["IDbcontextFactory"], written as
        /// "Namespace.ClassName, AssemblyName". The assembly part may also be a full assembly name
        /// including Version/Culture/PublicKeyToken.
        /// </summary>
        /// <returns>The configured factory.</returns>
        /// <exception cref="ConfigurationErrorsException">
        /// The setting is missing or malformed, or the type cannot be created as an <see cref="IDbContextFactory"/>.
        /// </exception>
        private static IDbContextFactory LoadDbContextFactory()
        {
            const string SettingKey = "IDbcontextFactory";
            string factoryPath = ConfigurationManager.AppSettings[SettingKey];
            if (string.IsNullOrWhiteSpace(factoryPath))
            {
                throw new ConfigurationErrorsException(string.Format(
                    "appSettings key '{0}' is missing or empty; expected \"Namespace.ClassName, AssemblyName\".",
                    SettingKey));
            }

            // Split on the first comma only, so a full assembly name after it is kept intact.
            int comma = factoryPath.IndexOf(',');
            string className = comma < 0 ? string.Empty : factoryPath.Substring(0, comma).Trim();
            string assemblyName = comma < 0 ? string.Empty : factoryPath.Substring(comma + 1).Trim();
            if (className.Length == 0 || assemblyName.Length == 0)
            {
                throw new ConfigurationErrorsException(string.Format(
                    "appSettings key '{0}' has value '{1}'; expected \"Namespace.ClassName, AssemblyName\".",
                    SettingKey, factoryPath));
            }

            IDbContextFactory factory = Assembly.Load(assemblyName).CreateInstance(className) as IDbContextFactory;
            if (factory == null)
            {
                throw new ConfigurationErrorsException(string.Format(
                    "Type '{0}' in assembly '{1}' was not found or does not implement IDbContextFactory.",
                    className, assemblyName));
            }
            return factory;
        }
    }

接下来看看DAL层

 /// <summary>
    /// Repository for the Test entity; all behavior is inherited from BaseRepository&lt;Test&gt;.
    /// </summary>
 public class TestDAL: BaseRepository<Test>
    {
    }

 public class CustomerDbContext:MyDbContext
    {
        /// <summary>
        /// Disables the EF database initializer for this context type once per AppDomain, so changes to
        /// the entity model never trigger automatic database creation or updates. This previously ran
        /// in both instance constructors on every instantiation.
        /// </summary>
        static CustomerDbContext()
        {
            Database.SetInitializer<CustomerDbContext>(null);
        }

        /// <summary>
        /// Creates the context using the connection string named "XFW".
        /// </summary>
        public CustomerDbContext()
            : base("XFW")
        {
        }

        /// <summary>
        /// Creates the context for the given connection string (or connection-string name).
        /// </summary>
        /// <param name="connectionString">Connection string or its configuration name.</param>
        public CustomerDbContext(string connectionString)
            : base(connectionString)
        {
        }

        /// <summary>
        /// Test entity set.
        /// </summary>
        public DbSet<Test> Test { get; set; }
    }

 public class CustomerDbContextFactory : IDbContextFactory
    {
        /// <summary>
        /// Returns the CustomerDbContext of the current call context, creating it on first use.
        /// </summary>
        /// <returns>The call-context-scoped custom data context.</returns>
        public MyDbContext GetDbContext()
        {
            CustomerDbContext current = DbContextFactory.GetCurrentDbContext<CustomerDbContext>();
            return current;
        }
    }

BIZ层

 public class TestBiz
    {
        /// <summary>
        /// Demo: inserts a Test row, updates it with raw SQL and then through the repository, all inside
        /// one transaction.
        /// </summary>
        /// <returns>
        /// Flag = true when the transaction committed; otherwise Flag = false with the exception message
        /// and Code "Exception" (the transaction is rolled back).
        /// </returns>
        public static TResult AddTest()
        {
            TestDAL testdal = DALFactory.CreateDAL<TestDAL>();
            Transaction.BeginTransaction();
            try
            {
                testdal.Insert(new Test { ID = 1 });
                testdal.ExecuteSqlCommand("update  test set name='111' where id=1");
                testdal.Update(m => m.ID == 1, m => { m.Name1 = "asdafaff"; });
                Transaction.Commit();
            }
            catch (Exception ex)
            {
                // Previously the exception was swallowed and the same empty result returned, so callers
                // could not tell a rolled-back transaction from a successful one.
                Transaction.Rollback();
                TResult failed = new TResult(false, ex.Message);
                failed.Code = "Exception";
                return failed;
            }

            TResult result = new TResult();
            result.Flag = true;
            return result;
        }
    }

 

posted @ 2016-10-20 18:00  一条大河啊波浪宽啊  阅读(1029)  评论(0编辑  收藏  举报