代码改变世界

一个用于多种数据库连接,并且可以反射出自定义类型对象的DBHelper

2010-08-11 17:48 by Tsanie, ... 阅读, ... 评论, 收藏, 编辑

写这篇随笔的目的只是为了记录一下我在MVC架构下的数据处理思路的改变,由于我文采不好,写的难以忍受或者瞎眼的地方请多多见谅。另外我非常希望能够得到一个更严谨、完善的DBHelper,大家有什么批评指正的尽管提。

 

之所以要实现DBHelper,目的就是帮助我们写Service的时候可以专注于业务实现而不用考虑琐碎的.Open()与.Close()以及那些for循环。

为了提高代码的复用性,可以用以下思路来实现DBHelper:

因为ADO.NET有一套统一的接口(IDbConnection、IDbCommand、IDbDataParameter、IDbDataAdapter等),所以我们可以用一套代码来完成不同数据库的处理功能。

通过DbProviderFactories.GetFactory(providerInvariantName)来获取某个Data Provider的XXXXFactory,由于这个类继承DbProviderFactory,于是我们可以使用.CreateConnection()来创建连接、.CreateDataAdapter()来创建适配器等等……

不过我尝试了指定MySql.Data.MySqlClient为provider来创建Factory结果没成功,网上搜了一下据说是未注册到GAC。但是由于服务器权限不够我只能考虑换一种办法,使用Assembly.LoadFrom(assemblyFile)来加载MySql.Data.dll,然后获取MySql.Data.MySqlClient.MySqlClientFactory的Type,最后通过.GetConstructor(Type.EmptyTypes).Invoke(null)来反射实例化一个MySqlClientFactory对象。

 

到了这里(获取到了DbProviderFactory),我们就可以增加自己的便利方法了(ExecuteSql、ExecuteQueryCount、ExecuteQueryTable等)。

然后,我们还会想到一个问题,为了更加专心与业务逻辑的实现,我们甚至不想去使用DataTable(我们认为要记住那些繁多的ColumnNames实在是太浪费脑细胞了),于是我想到了自定义类型以及泛型(于是有了IEnumerable<TEntity> ExecuteQuery<TEntity>()与TEntity ExecuteQuerySingle<TEntity>())。原理就是遍历TEntity的所有公共属性,然后从DbDataReader中以属性名为列名查找数据赋给TEntity……

 

代码实现:

代码
1 using System;
2  using System.Collections.Generic;
3  using System.Reflection;
4  using System.Data;
5  using System.Data.Common;
6
7  namespace Tsanie.Web.Models.Utils {
8
9 /// <summary>
10 /// 数据库连接帮助类
11 /// </summary>
12   public class DBHelper {
13
14 #region - 静态 -
15 private readonly static string _connStr;
16 private readonly static DbProviderFactory _provider;
17
18 /// <summary>
19 /// 静态构造
20 /// </summary>
21   static DBHelper() {
22 System.Collections.Specialized.NameValueCollection appSettings
23 = System.Configuration.ConfigurationManager.AppSettings;
24 // 连接字符串
25 _connStr = appSettings["connStr"];
26 // 优先provider,如果存在则用DbProviderFactories.GetFactory()创建
27 string provider = appSettings["provider"];
28 if (provider != null) {
29 _provider = DbProviderFactories.GetFactory(provider);
30 return;
31 }
32 // 否则使用Assembly.LoadFrom()载入程序集
33 // 因为DbProviderFactories.GetFactory()无法创建MySqlClientFactory
34 string factory = appSettings["factory"];
35 Type typeProvider = Assembly
36 .LoadFrom(AppDomain.CurrentDomain.BaseDirectory + appSettings["driver"])
37 .GetType(factory);
38 _provider = typeProvider
39 .GetConstructor(Type.EmptyTypes)
40 .Invoke(null) as DbProviderFactory;
41 }
42 #endregion
43
44 #region - 属性 -
45 /// <summary>
46 /// 获取或设置SQL语句文本
47 /// </summary>
48 public string Command { get; set; }
49 /// <summary>
50 /// 获取或设置语句类型
51 /// </summary>
52 public CommandType CommandType { get; set; }
53 /// <summary>
54 /// 获取语句参数
55 /// </summary>
56 public List<DbParameter> Parameters { get; private set; }
57 #endregion
58
59 #region - 构造 -
60 private DbConnection _conn;
61
62 /// <summary>
63 /// 构造数据库助手
64 /// </summary>
65 public DBHelper() : this(null, CommandType.Text) { }
66
67 /// <summary>
68 /// 构造数据库助手
69 /// </summary>
70 /// <param name="sql">要执行的SQL语句</param>
71 public DBHelper(string sql) : this(sql, CommandType.Text) { }
72
73 /// <summary>
74 /// 构造数据库助手
75 /// </summary>
76 /// <param name="sql">要执行的SQL语句</param>
77 /// <param name="type">语句类型</param>
78 public DBHelper(string sql, CommandType type) {
79 _conn = _provider.CreateConnection();
80 _conn.ConnectionString = _connStr;
81 Command = sql;
82 CommandType = type;
83 Parameters = null;
84 }
85 #endregion
86
87 #region - 公共方法 -
88 /// <summary>
89 /// 添加参数
90 /// </summary>
91 /// <param name="parameterName">参数名</param>
92 /// <param name="value">参数值</param>
93 /// <returns>参数</returns>
94 public DbParameter AddParameter(string parameterName, object value) {
95 DbParameter param = _provider.CreateParameter();
96 param.ParameterName = parameterName;
97 param.Value = value;
98 return AddParameter(param);
99 }
100
101 /// <summary>
102 /// 添加参数
103 /// </summary>
104 /// <param name="parameterName">参数名</param>
105 /// <param name="value">参数值</param>
106 /// <param name="dbType">参数数据类型</param>
107 /// <returns>参数</returns>
108 public DbParameter AddParameter(string parameterName, object value, DbType dbType) {
109 DbParameter param = _provider.CreateParameter();
110 param.ParameterName = parameterName;
111 param.Value = value;
112 param.DbType = dbType;
113 return AddParameter(param);
114 }
115
116 /// <summary>
117 /// 添加参数
118 /// </summary>
119 /// <param name="parameterName">参数名</param>
120 /// <param name="value">参数值</param>
121 /// <param name="dbType">参数数据类型</param>
122 /// <param name="size">参数长度</param>
123 /// <returns>参数</returns>
124 public DbParameter AddParameter(string parameterName, object value, DbType dbType, int size) {
125 return AddParameter(parameterName, value, dbType, size, ParameterDirection.Input);
126 }
127
128 /// <summary>
129 /// 添加参数
130 /// </summary>
131 /// <param name="parameterName">参数名</param>
132 /// <param name="value">参数值</param>
133 /// <param name="dbType">参数数据类型</param>
134 /// <param name="size">参数长度</param>
135 /// <param name="direction">参数类型</param>
136 /// <returns>参数</returns>
137 public DbParameter AddParameter(string parameterName, object value, DbType dbType,
138 int size, ParameterDirection direction) {
139 DbParameter param = _provider.CreateParameter();
140 param.ParameterName = parameterName;
141 param.Value = value;
142 param.DbType = dbType;
143 param.Size = size;
144 param.Direction = direction;
145 return AddParameter(param);
146 }
147
148 /// <summary>
149 /// 执行非查询的SQL语句
150 /// </summary>
151 /// <returns>DAO处理结果</returns>
152 public int ExecuteSql() {
153 using (DbCommand cmd = PrepareCommand()) {
154 try {
155 int rows = cmd.ExecuteNonQuery();
156 cmd.Parameters.Clear();
157 return rows;
158 } catch (DbException e) {
159 throw e;
160 } finally {
161 CloseConnection();
162 }
163 }
164 }
165
166 /// <summary>
167 /// 执行查询SQL语句
168 /// </summary>
169 /// <returns>查询列表</returns>
170 public IEnumerable<TEntity> ExecuteQuery<TEntity>() where TEntity : new() {
171 using (DbCommand cmd = PrepareCommand()) {
172 DbDataReader reader = null;
173 try {
174 reader = cmd.ExecuteReader();
175 return createEnumerable<TEntity>(reader);
176 } catch (DbException e) {
177 throw e;
178 } finally {
179 CloseConnection(reader);
180 }
181 }
182 }
183
184 /// <summary>
185 /// 执行查询SQL语句
186 /// </summary>
187 /// <returns>查询到的单个对象</returns>
188 public TEntity ExecuteQuerySingle<TEntity>() where TEntity : new() {
189 using (DbCommand cmd = PrepareCommand()) {
190 DbDataReader reader = null;
191 try {
192 reader = cmd.ExecuteReader();
193 if (reader.Read()) {
194 return createSingle<TEntity>(reader);
195 }
196 return default(TEntity);
197 } catch (DbException e) {
198 throw e;
199 } finally {
200 CloseConnection(reader);
201 }
202 }
203 }
204
205 /// <summary>
206 /// 执行查询SQL语句
207 /// </summary>
208 /// <returns>查询结果表</returns>
209 public DataTable ExecuteQueryTable() {
210 using (DbCommand cmd = PrepareCommand()) {
211 DbDataAdapter adapter = _provider.CreateDataAdapter();
212 adapter.SelectCommand = cmd;
213 try {
214 DataTable table = new DataTable();
215 adapter.Fill(table);
216 return table;
217 } catch (DbException ex) {
218 throw ex;
219 } finally {
220 CloseConnection();
221 }
222 }
223 }
224
225 /// <summary>
226 /// 执行查询SQL语句并返回个数
227 /// </summary>
228 /// <returns>查询到的个数</returns>
229 public long ExecuteQueryCount() {
230 using (DbCommand cmd = PrepareCommand()) {
231 try {
232 object result = cmd.ExecuteScalar();
233 return result == null ? -1 : (long)result;
234 } catch (DbException e) {
235 throw e;
236 } finally {
237 CloseConnection();
238 }
239 }
240 }
241 #endregion
242
243 #region - 私有方法 -
244 /// <summary>
245 /// 向参数集合里添加参数
246 /// </summary>
247 /// <param name="param"></param>
248 private DbParameter AddParameter(DbParameter param) {
249 if (Parameters == null)
250 Parameters = new List<DbParameter>();
251 Parameters.Add(param);
252 return param;
253 }
254 /// <summary>
255 /// 准备SQL命令对象
256 /// </summary>
257 /// <returns>SQL命令对象</returns>
258 private DbCommand PrepareCommand() {
259 try {
260 if (_conn.State != ConnectionState.Open)
261 _conn.Open();
262 DbCommand cmd = _conn.CreateCommand();
263 cmd.CommandText = Command;
264 cmd.CommandType = CommandType;
265 if (Parameters != null && Parameters.Count > 0) {
266 foreach (DbParameter param in Parameters) {
267 cmd.Parameters.Add(param);
268 }
269 }
270 return cmd;
271 } catch (DbException e) {
272 throw e;
273 }
274 }
275
276 /// <summary>
277 /// 创建集合
278 /// </summary>
279 /// <typeparam name="TEntity">实体类型</typeparam>
280 /// <param name="reader">数据读取器</param>
281 /// <returns>实体集合</returns>
282 private IEnumerable<TEntity> createEnumerable<TEntity>(DbDataReader reader) where TEntity : new() {
283 List<TEntity> list = new List<TEntity>();
284 // 遍历结果集
285 while (reader.Read()) {
286 // 将新添加的实例存放在ArrayList中
287 list.Add(createSingle<TEntity>(reader));
288 }
289 return list;
290 }
291
292 /// <summary>
293 /// 创建实体
294 /// </summary>
295 /// <typeparam name="TEntity">实体类型</typeparam>
296 /// <param name="reader">数据读取器</param>
297 /// <returns>实体对象</returns>
298 private TEntity createSingle<TEntity>(DbDataReader reader) where TEntity : new() {
299 // 创建该类的实例
300 TEntity curObj = new TEntity();
301 Type type = curObj.GetType();
302
303 // 遍历所有的公有属性
304 foreach (PropertyInfo pInfo in type.GetProperties()) {
305 MethodInfo method = pInfo.GetSetMethod(false);
306 if (method != null) {
307 try {
308 object value = reader[pInfo.Name];
309 if (value == null || value is System.DBNull) {
310 method.Invoke(curObj, new object[] { null });
311 } else {
312 method.Invoke(curObj, new object[] { reader[pInfo.Name] });
313 }
314 } catch (Exception e) {
315 throw e;
316 }
317 }
318 }
319 return curObj;
320 }
321
322 /// <summary>
323 /// 关闭连接
324 /// </summary>
325 private void CloseConnection() {
326 CloseConnection(null);
327 }
328
329 /// <summary>
330 /// 关闭连接
331 /// </summary>
332 private void CloseConnection(DbDataReader reader) {
333 if (reader != null && !reader.IsClosed) {
334 reader.Close();
335 }
336 if (_conn != null && _conn.State == ConnectionState.Open) {
337 _conn.Close();
338 }
339 }
340 #endregion
341
342 }
343 }
344

 

至此,我们所有的工作就完成了,然后Service里我们可以这么调用:

1、查询个数

string sql = "select count(*) from A";
int count = new DBHelper(sql).ExecuteQueryCount();

 

2、查询反射到自定义类型

public class MemoModel {
    public int MemoID { get; set; }
    public string UserName { get; set; }
    public string MemoText { get; set; }
    public DateTime MemoDate { get; set; }
}

...
string sql = "select A.MemoID,B.UserName,A.MemoText,A.MemoDate"
    + " from Memos A left join Users B on B.UserID=A.UserID"
    + " where A.MemoID=@MemoID";
DBHelper helper = new DBHelper();
helper.Command = sql;
helper.AddParameter("@MemoID", memoId);
IEnumerable<MemoModel> memos = helper.ExecuteQuery<MemoModel>();
...

 

DBHelper.cs 下载地址:

http://files.cnblogs.com/tsorgy/DBHelper.zip