基于Entity Framework 4.1实现一个适用于单元测试的MockDbContext(下)
Posted on 2011-07-24 05:13 Saar 阅读(2255) 评论(7) 收藏 举报上篇中提到,我们在利用修改做的MockDbContext进行单元测试时,在获取数据时出现了问题。原因在于,在写获取的代码时,我们直接调用了DbContext的Set<T>()方法,而这个方法会从数据库中取数据。
我们来看一个例子:
这是一个简化的业务逻辑代码来获取全部的Handbook:
        public IList<Handbook> GetMyHandbooks()
          {
  var handbookSet = dbContext.Set<Handbook>();
            return handbookSet.ToList();
  }
对应的单元测试:
   1:          [TestMethod]
  2: public void GetMyHandBooksTest()
   3:          {
  4: var mockDbContext = new MockMTBContainer();
5: BizHandbook target = new BizHandbook(mockDbContext);
   6:   
  7: mockDbContext.Handbooks1.Add(new Handbook() { HandbookID = 1 });
8: mockDbContext.Handbooks1.Add(new Handbook() { HandbookID = 2 });
   9:   
    10:              var result = target.GetMyHandbooks();
    11:              Assert.IsNotNull(result);
    12:              Assert.AreEqual(2, result.Count);
    13:          }
行4-5创建一个业务逻辑对象,也是我们的测试目标target,第5行中的构造函数以以mockDbContext为参数,此业务逻辑对象会使用mockDbContext而非默认的DbContext。第10行代码调用了GetMyHandbooks()方法。第11行和12行分别验能够获得结果并且结果集中有两个对象。大家已经知道使用上篇中的mockDbContext,这个单元测试会Fail。第12行Assert预期为2,实际为0。
直接问题出在GetMyHandbooks方法的return语句中的.ToList()方法。这是一个System.Linq中提供的扩展方法,它所做的事情是:调用Expression对象进行数据查询得到一个IEnumerator<TEntity>,然后创建一个List<TEntity>对象,进行迭代将数据填入列表,最后返回。谁提供了Expression?DbContext中的DbSet<TEntity>。
在上篇中,我们仅仅是简单的重写了(或者说,废掉了^_^)DbContext的SaveChanges(),但是在查询数据的时候,Expression仍然会从对应的DB中获取数据。另外,对Local属性的理解也有问题,Local里并不是所有数据,但是新增或变更过的数据。于是,单元测试12行出现actual=0的现象也就得到合理解释了。看来,之前的想法太过简单了。需要重新构思一个mockDbContext。
由于DbSet<TEntity>的Expression属性不是virtual属性,简单重写一下Expression的想法行不通。曾经想过用一个同名属性覆盖它,但是,这样应该也有问题(之所以说应该,是因为没有试过),因为DbContext中使用的都是DbSet<TEntity>类,会调用原来的Expression属性的。
通过以上的分析,我们知道,要让MockDbContext好用,涉及两个类:DbContext和DbSet<TEntity>,因此,现在的思路是,实现两个新的类,替换掉它们。
由于要在业务逻辑不作任何修改的情况下调用DbContext和MockDbContext,这两个类要求实现同一接口;
同理,DbSet<TEntity>和新写的类(叫MockDbSet<TEntity>吧)同样也要实现同一个接口。
我们从MockDbSet<TEntity>着手。由于DbSet<TEntity>实现了IDbSet<TEntity>接口,因此,对于MockDbSet<TEntity>来说,实现IDbSet<TEntity>即可(这个类方法比较多,但都是非常基本的方法,文章最后有下载):
1: public class MockDbSet<TEntity> : IDbSet<TEntity>
2: where TEntity : class
   3:      {
  4: private ObservableCollection<TEntity> storage = new ObservableCollection<TEntity>();
   5:   
     6:   
  7: public TEntity Add(TEntity entity)
   8:          {
  9: if (entity != null)
  10:              {
    11:                  storage.Add(entity);
    12:              }
  13: return entity;
  14:          }
    15:   
  16: public TEntity Attach(TEntity entity)
  17:          {
    18:              storage.Add(entity);
  19: return entity;
  20:          }
    21:   
  22: public TDerivedEntity Create<TDerivedEntity>() where TDerivedEntity : class, TEntity
  23:          {
  24: return Activator.CreateInstance<TDerivedEntity>();
  25:          }
    26:   
  27: public TEntity Create()
  28:          {
  29: return Activator.CreateInstance<TEntity>();
  30:          }
    31:   
  32: public TEntity Find(params object[] keyValues)
  33:          {
  34: int currentKeyPropertyIndex;
35: foreach (var entity in storage)
  36:              {
    37:                  currentKeyPropertyIndex = 0;
  38: foreach (var property in typeof(TEntity).GetProperties())
  39:                  {
  40: if (property.Name.Contains("ID"))
  41:                      {
  42: if (property.GetValue(entity).Equals(keyValues[currentKeyPropertyIndex]))
  43:                          {
    44:                              currentKeyPropertyIndex++;
  45: if (currentKeyPropertyIndex == keyValues.Length)
  46:                              {
  47: return entity;
  48:                              }
    49:                          }
    50:                      }
    51:                  }
    52:              }
  53: return null;
  54:          }
    55:   
  56: public System.Collections.ObjectModel.ObservableCollection<TEntity> Local
  57:          {
  58: get { return new ObservableCollection<TEntity>(this.storage); }
  59:          }
    60:   
  61: public TEntity Remove(TEntity entity)
  62:          {
    63:              storage.Remove(entity);
  64: return entity;
  65:          }
    66:   
  67: public IEnumerator<TEntity> GetEnumerator()
  68:          {
  69: return this.storage.GetEnumerator();
  70:          }
    71:   
    72:          System.Collections.IEnumerator System.Collections.IEnumerable.GetEnumerator()
    73:          {
  74: return this.storage.GetEnumerator();
  75:          }
    76:   
  77: public Type ElementType
  78:          {
    79:              get
    80:              {
  81: return storage.AsQueryable().ElementType;
  82:              }
    83:          }
    84:   
  85: public System.Linq.Expressions.Expression Expression
  86:          {
    87:              get
    88:              {
  89: return storage.AsQueryable().Expression;
  90:              }
    91:          }
    92:   
  93: public IQueryProvider Provider
  94:          {
    95:              get
    96:              {
  97: return storage.AsQueryable().Provider;
  98:              }
    99:          }
   100:      }
延续前文思路,在第4行添加了一个ObservableCollection<T>的集合,用作本地存储;增、删、改全部针对这个集合来完成。第85到91行,在返回Expression的时候,把storage.AsQueryable()的表达式返回出去,这样,查询的的时候就会获取到查询storage中元素的表达式。
接下来,我们要把使用DbSet<TEntity>的DbContext和使用MockDbSet<TEntity>的MockDbContext统一起来。查看一下DbContext……呃~没有实现任何接口-.-||(曾经有位伟人说过……算了,不说了)既然没有接口,我们就来创建一个:
using System.Data.Entity;
  namespace MTB.Data
  {
  public interface ITestableDbContext
    {
  IDbSet<TEntity> Set<TEntity>() where TEntity : class;
        int SaveChanges();
  }
}
这个接口做两件事,第一,可以实现数据持久化——SaveChanges()。第二,可以获取到IDbSet<TEntity>集合以对集合中数据进行操作。然后,做两件事:第一,让具体的DbContext实现这个接口;第二,写一个MockDbContext实现这个接口。 
  
我们先来看DbContext实现接口的部分:
1: public partial class MTBContainer : DbContext, ITestableDbContext
   2:      {
  3: public MTBContainer()
4: : base("name=MTBContainer")
   5:          {
     6:          }
     7:      
  8: public new IDbSet<TEntity> Set<TEntity>() where TEntity : class
   9:          {
  10: return base.Set<TEntity>();
  11:          }
    12:      
  13: //... Other code...
  14:      
  15: public IDbSet<Handbook> Handbooks1 { get; set; }
16: public IDbSet<Trip> Trips { get; set; }
17: // ... more item
  18:      }
其中,第8到11行,虽然覆盖了Set<TEntity>(),但调用的仍然是DbContext类中的Set<TEntity>()方法;然后把对应的DbSet<TEntity对象集合全部改为IDbSet<TEntity>,大功告成。由于用的是Model First的EF4.1,因此,其实这个类是通过修改模板而来的。修改过的模板会在文章结束时附上。
MockDbContext实现:
1: public partial class MockMTBContainer : DbContext, ITestableDbContext
   2:      {
  3: public MockMTBContainer()
4: : base("name=MTBContainer")
   5:          {
  6: Handbooks1 = new MockDbSet<Handbook>();
7: Trips = new MockDbSet<Trip>();
// Other code
  15:          }
    16:      
  17: public new IDbSet<TEntity> Set<TEntity>() where TEntity : class
  18:          {
  19: foreach (PropertyInfo property in typeof(MockMTBContainer).GetProperties())
  20:                  {
  21: if (property.PropertyType == typeof(IDbSet<TEntity>))
  22:                      {
  23: return property.GetValue(this, null) as IDbSet<TEntity>;
  24:                      }
    25:                  }
  26: throw new Exception("Type collection not found");
  27:          }
    28:      
  29: public override int SaveChanges()
  30:          {
  31: //Do nothing
32: return 0;
  33:          }
    34:      
  35: // ...
  36:      
  37: public IDbSet<Handbook> Handbooks1 { get; set; }
38: public IDbSet<Trip> Trips { get; set; }
39: // ...
  40:      }
这个MockDbContext同样继承自DbContext类(嗯,可以少写不少代码呢)。但是,覆盖了Set<TEntity>()方法;当然,第29到33行废掉SaveChanges()的事仍然不可不做——如果想让测试更全面一些,看看那些个增、删、改方法有没有忘调用SaveChanges(),利用这个重写设置一个标志为也不错啊
Okay,一切就绪,我们把这一切综合起来使用:
首先,业务逻辑:
1: public class BizHandbook
   2:      {
  3: ITestableDbContext dbContext = null;
   4:   
  5: #region Constructions
6: public BizHandbook()
7: : this(null)
   8:          {
     9:          }
    10:   
    11:   
  12: public BizHandbook(ITestableDbContext dbContextl)
  13:          {
  14: if (dbContext == null)
  15:              {
  16: dbContext = new MTBContainer();
  17:              }
  18: this.dbContext = dbContext;
  19:          }
  20: #endregion
  21:   
  22: #region Publich Methods
23: public IList<Handbook> GetMyHandbooks()
  24:          {
    25:              var handbookSet = dbContext.Set<Handbook>();
  26: return handbookSet.ToList();
  27:          }
  28: #endregion
  29:      }
第12行构造函数会要求一个实现了ITestableDbContext的对象,如果为null,那么使用默认的DbContext。增删改查一律该怎么写就怎么写。模板下载,使用的时候记得修改对应的using和inputFile变量——点击下载。
内容:
1. 使用IDbSet<T>的DbContext的模板;
2. MockDbContext模板;
3. MockDbSet<TEntity>类;
4. ITestableDbContext接口;
对了,顺便提一下,在MockDbSet<TEntity>实现的时候,Find()方法为了方便起见,使用了一个Hack的方法来判断属性是否为Key属性——属性名称中是否含有ID。
Little knowledge is dangerous.
 
                    
                     
                    
                 
                    
                 
        
 
                
            
         浙公网安备 33010602011771号
浙公网安备 33010602011771号