C# SQL 帮助类（可扩展）
查看代码
/// <summary>
/// Marks a class or struct as a table model and carries table-level options
/// that the DbContext family reads when generating DDL.
/// </summary>
[AttributeUsage(AttributeTargets.Struct | AttributeTargets.Class, AllowMultiple = false, Inherited = false)]
public class DbTableAttribute : Attribute
{
    /// <summary>
    /// Table name. When left empty, <c>ValidateTableAttribute</c> substitutes the CLR type name.
    /// </summary>
    public string Name { get; set; }

    /// <summary>
    /// Default character set for the table (emitted by the MySQL context as
    /// <c>default character set ...</c>).
    /// </summary>
    public string Charset { get; set; }

    /// <summary>
    /// Default collation for the table (emitted by the MySQL context as <c>collate ...</c>).
    /// </summary>
    public string Collate { get; set; }
}
/// <summary>
/// Maps a public instance property to a table column.
/// </summary>
[AttributeUsage(AttributeTargets.Property, AllowMultiple = false, Inherited = false)]
public class DbColumnAttribute : Attribute
{
    /// <summary>Column name as it appears in the database.</summary>
    public string Name { get; set; }

    /// <summary>
    /// Raw column definition appended verbatim after the quoted column name in
    /// <c>create table</c> / <c>alter table ... add</c> statements.
    /// <para>General shape: type[(length)] [[primary key]|[unique]] [unsigned] [zerofill] [not null or null] [default your_value] [comment 'your comment'] [collate 'your encoding'] ...etc.</para>
    /// <para>bigint auto_increment</para>
    /// <para>int default '0'</para>
    /// <para>varchar(50) null default null collate 'utf8_general_ci'</para>
    /// <para>datetime null</para>
    /// <para>datetime null default 'localtime'</para>
    /// <para>timestamp not null default current_timestamp on update current_timestamp</para>
    /// <para>bit(1) null default b'0' comment 'balabala'</para>
    /// <para>Note: DeleteByPrimaryKey locates the key column by searching this text for "primary key".</para>
    /// </summary>
    public string Desc { get; set; }

    /// <summary>
    /// Comma-separated index names this column belongs to, e.g. "Index1,Index2,Index3,primary key,unique".
    /// Columns sharing a name form one (multi-column) index. The MySQL context skips the
    /// "primary key"/"unique" markers when emitting indexes.
    /// </summary>
    public string Index { get; set; }

    /// <summary>When true, the column is left out of generated INSERT statements (e.g. auto-increment keys).</summary>
    public bool NotInsert { get; set; }
}
/// <summary>
/// Provider-agnostic base for a small reflection-driven SQL helper. Table models are plain
/// types annotated with <see cref="DbTableAttribute"/> / <see cref="DbColumnAttribute"/>;
/// concrete providers (MySQL, SQLite, ...) supply the ADO.NET types through the generic
/// parameters and implement the schema probes (ExistColumn / ExistIndex / ExistTable).
/// </summary>
/// <remarks>
/// An instance keeps one connection and at most one transaction in plain fields with no
/// synchronization, so it is presumably not safe to share between threads.
/// Generated SQL quotes identifiers with backticks; override <see cref="QuoteIdentifier"/>
/// for dialects that use a different quote character.
/// </remarks>
public abstract class DbContext<TTransaction, TConnection, TCommand, TParameter, TDataAdapter, TDataReader> : IDisposable
    where TTransaction : DbTransaction
    where TConnection : DbConnection
    where TCommand : DbCommand
    where TParameter : DbParameter
    where TDataAdapter : DbDataAdapter
    where TDataReader : DbDataReader
{
    /// <summary>Connection string used by <see cref="Connect"/>; set through <see cref="UseServer"/>.</summary>
    protected string connectStr;
    /// <summary>Transaction opened by <see cref="BeginTransaction"/>; null when none is active.</summary>
    protected TTransaction transaction = null;
    /// <summary>Current connection; (re)created by <see cref="Connect"/>.</summary>
    protected TConnection connection = null;

    /// <summary>Gets or sets the underlying connection.</summary>
    public TConnection Connection
    {
        get { return connection; }
        set { connection = value; }
    }

    /// <summary>Executes a statement and returns the provider's affected-row count.</summary>
    public virtual int ExecuteNonQuery(string sql, params TParameter[] parameters)
    {
        Connect();
        using (TCommand cmd = CreateCommand(sql, parameters))
        {
            return cmd.ExecuteNonQuery();
        }
    }

    /// <summary>Executes a query and returns the first column of the first row (null when there is no row).</summary>
    public virtual object ExecuteScalar(string sql, params TParameter[] parameters)
    {
        Connect();
        using (TCommand cmd = CreateCommand(sql, parameters))
        {
            return cmd.ExecuteScalar();
        }
    }

    /// <summary>
    /// Executes <paramref name="sql"/> and returns the first result set as a <see cref="DataTable"/>
    /// when the statement produces columns; otherwise returns <c>RecordsAffected</c> as a boxed int.
    /// </summary>
    public virtual object ExecuteReader(string sql, params TParameter[] parameters)
    {
        Connect();
        using (TCommand cmd = CreateCommand(sql, parameters))
        using (DbDataReader reader = cmd.ExecuteReader())
        {
            // Do not call Read()/NextResult() first: DataTable.Load consumes the current result
            // set itself, and advancing beforehand would silently drop rows.
            if (reader.FieldCount > 0)
            {
                DataTable dt = new DataTable();
                dt.Load(reader);
                return dt;
            }
            return reader.RecordsAffected;
        }
    }

    /// <summary>Runs a query through the provider's data adapter and returns the first table, or null.</summary>
    public virtual DataTable Query(string sql, params TParameter[] parameters)
    {
        Connect();
        using (TCommand cmd = CreateCommand(sql, parameters))
        using (TDataAdapter adapter = (TDataAdapter)Activator.CreateInstance(typeof(TDataAdapter), cmd))
        {
            DataSet ds = new DataSet();
            adapter.Fill(ds, "ds");
            return ds.Tables.Count > 0 ? ds.Tables[0] : null;
        }
    }

    /// <summary>
    /// Creates a command on <see cref="Connection"/>, enlisted in the active transaction (if any),
    /// with the given parameters attached.
    /// </summary>
    public virtual TCommand CreateCommand(string sql, params TParameter[] parameters)
    {
        TCommand cmd = (TCommand)Connection.CreateCommand();
        cmd.CommandText = sql;
        // Every command must join the pending transaction; several providers reject commands
        // issued on a connection with an open transaction that is not assigned to them.
        if (transaction != null)
            cmd.Transaction = transaction;
        if (parameters != null && parameters.Length > 0)
            cmd.Parameters.AddRange(parameters);
        return cmd;
    }

    /// <summary>
    /// Creates a provider parameter through its parameterless constructor; a null
    /// <paramref name="value"/> is sent as <see cref="DBNull.Value"/>.
    /// </summary>
    protected virtual TParameter CreateParameter(string name, object value)
    {
        TParameter parameter = (TParameter)Activator.CreateInstance(typeof(TParameter));
        parameter.ParameterName = name;
        parameter.Value = value ?? DBNull.Value;
        return parameter;
    }

    /// <summary>Quotes an identifier with backticks, doubling any embedded backtick.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="name"/> is null.</exception>
    protected virtual string QuoteIdentifier(string name)
    {
        if (name == null) throw new ArgumentNullException(nameof(name));
        return "`" + name.Replace("`", "``") + "`";
    }

    /// <summary>Adds the column described by <paramref name="dbColumnAttribute"/> to the table mapped by <typeparamref name="Table"/>.</summary>
    public virtual int AddColumn<Table>(DbColumnAttribute dbColumnAttribute)
    {
        if (dbColumnAttribute == null)
            dbColumnAttribute = new DbColumnAttribute();
        // Use the mapped table name, not the CLR type name.
        DbTableAttribute dbTableAttribute = GetDbTableAttribute<Table>();
        return AddColumn(dbTableAttribute.Name, dbColumnAttribute.Name, dbColumnAttribute.Desc);
    }

    /// <summary>Issues <c>alter table ... add column</c> with <paramref name="options"/> as the raw column definition.</summary>
    public virtual int AddColumn(string tableName, string columnName, string options)
    {
        string sql = $"alter table {QuoteIdentifier(tableName)} add {QuoteIdentifier(columnName)} {options}";
        return ExecuteNonQuery(sql);
    }

    /// <summary>Creates index <paramref name="index"/> over <paramref name="columns"/> (in order).</summary>
    /// <exception cref="ArgumentException"><paramref name="columns"/> is null or empty.</exception>
    public virtual void AddIndex(string tableName, string index, string[] columns)
    {
        if (columns == null || columns.Length == 0)
            throw new ArgumentException("At least one column is required.", nameof(columns));
        // Each column is quoted separately; quoting the joined list would name one bogus column.
        string columnList = string.Join(",", columns.Select(c => QuoteIdentifier(c)));
        ExecuteNonQuery($"create index {QuoteIdentifier(index)} on {QuoteIdentifier(tableName)} ({columnList})");
    }

    /// <summary>Drops an index using the <c>drop index ... on table</c> form.</summary>
    public virtual void DeleteIndex(string tableName, string index)
    {
        ExecuteNonQuery($"drop index {QuoteIdentifier(index)} on {QuoteIdentifier(tableName)}");
    }

    /// <summary>Runs a query and maps each row onto a new <typeparamref name="Table"/>.</summary>
    public virtual List<Table> GetList<Table>(string sql, params TParameter[] parameters)
    {
        DataTable dt = Query(sql, parameters);
        return TableToList<Table>(dt);
    }

    /// <summary>
    /// Ensures an open connection. Reuses the current one unless <paramref name="reconnect"/> is
    /// true or it is closed/broken; otherwise disposes it and opens a new one from <see cref="connectStr"/>.
    /// </summary>
    public virtual void Connect(bool reconnect = false)
    {
        if (!reconnect && connection != null
            && connection.State != System.Data.ConnectionState.Closed
            && connection.State != System.Data.ConnectionState.Broken)
            return;
        connection?.Dispose();
        connection = (TConnection)Activator.CreateInstance(typeof(TConnection), this.connectStr);
        connection.Open();
    }

    /// <summary>Creates the table mapped by <typeparamref name="Table"/> from its [DbColumn] properties.</summary>
    public virtual int CreateTable<Table>()
    {
        Type tableType = typeof(Table);
        var tableAttribute = GetDbTableAttribute<Table>();
        var definitions = new List<string>();
        foreach (PropertyInfo propertyInfo in tableType.GetProperties(BindingFlags.Instance | BindingFlags.Public))
        {
            var columnInfo = propertyInfo.GetCustomAttribute<DbColumnAttribute>();
            if (columnInfo == null) continue;
            definitions.Add($"\t{QuoteIdentifier(columnInfo.Name)} {columnInfo.Desc}");
        }
        StringBuilder sb = new StringBuilder($"create table {QuoteIdentifier(tableAttribute.Name)}");
        if (definitions.Count > 0)
            sb.Append($" (\r\n{string.Join(",\r\n", definitions)}\r\n)\r\n");
        return ExecuteNonQuery(sb.ToString());
    }

    /// <summary>Returns whether <paramref name="column"/> exists in <paramref name="table"/>.</summary>
    public abstract bool ExistColumn(string column, string table);
    /// <summary>Returns whether <paramref name="index"/> exists (optionally covering <paramref name="columnName"/>).</summary>
    public abstract bool ExistIndex(string tableName, string index, string columnName = null);
    /// <summary>Returns whether <paramref name="table"/> exists.</summary>
    public abstract bool ExistTable(string table);

    /// <summary>
    /// Groups column names by the (trimmed) index names listed in each column's
    /// <see cref="DbColumnAttribute.Index"/>, in property order.
    /// </summary>
    public virtual Dictionary<string, List<string>> GetIndexs(Type table)
    {
        var indexDict = new Dictionary<string, List<string>>();
        foreach (var propertyInfo in table.GetProperties(BindingFlags.Instance | BindingFlags.Public))
        {
            var column = propertyInfo.GetCustomAttribute<DbColumnAttribute>();
            if (column == null || string.IsNullOrWhiteSpace(column.Index)) continue;
            foreach (string rawIndex in column.Index.Split(','))
            {
                // "a, b" must yield "b", not " b".
                string index = rawIndex.Trim();
                if (index.Length == 0) continue;
                List<string> members;
                if (!indexDict.TryGetValue(index, out members))
                {
                    members = new List<string>();
                    indexDict.Add(index, members);
                }
                members.Add(column.Name);
            }
        }
        return indexDict;
    }

    /// <summary>Stores the connection string and forces a fresh connection.</summary>
    public virtual void UseServer(string connectStr)
    {
        this.connectStr = connectStr;
        Connect(true);
    }

    /// <summary>Creates the table if missing, then adds missing columns and indexes.</summary>
    public virtual void Repair<Table>()
    {
        var tableType = typeof(Table);
        RepairTable<Table>();
        RepairColumns<Table>();
        RepairIndex(tableType);
    }

    /// <summary>Creates the table mapped by <typeparamref name="Table"/> if it does not exist.</summary>
    public virtual void RepairTable<Table>()
    {
        var table = GetDbTableAttribute<Table>();
        var tableType = typeof(Table);
        ValidateTableAttribute(table, tableType);
        if (!ExistTable(table.Name))
            CreateTable<Table>();
    }

    /// <summary>Adds every [DbColumn] column that is missing from the table.</summary>
    public virtual void RepairColumns<Table>()
    {
        var table = GetDbTableAttribute<Table>();
        foreach (var propertyInfo in typeof(Table).GetProperties(BindingFlags.Instance | BindingFlags.Public))
        {
            var column = propertyInfo.GetCustomAttribute<DbColumnAttribute>();
            if (column == null) continue;
            // Only add columns that are absent; adding an existing one is an error.
            if (!ExistColumn(column.Name, table.Name))
                AddColumn(table.Name, column.Name, column.Desc);
        }
    }

    /// <summary>
    /// (Re)creates every declared index that is missing or does not cover all its declared columns.
    /// "primary key"/"unique" entries are constraint markers expressed through the column
    /// definition, not index names, and are skipped.
    /// </summary>
    public virtual void RepairIndex(Type table)
    {
        // Indexes live on the mapped table, not on a table named after the CLR type.
        string tableName = GetDbTableAttribute(table).Name;
        foreach (KeyValuePair<string, List<string>> entry in GetIndexs(table))
        {
            if (IsConstraintMarker(entry.Key)) continue;
            bool complete = entry.Value.All(column => ExistIndex(tableName, entry.Key, column));
            if (complete) continue;
            if (ExistIndex(tableName, entry.Key))
                DeleteIndex(tableName, entry.Key);
            AddIndex(tableName, entry.Key, entry.Value.ToArray());
        }
    }

    // Index-list entries that denote constraints rather than named indexes.
    private static bool IsConstraintMarker(string index)
    {
        switch (index.Trim().ToLowerInvariant())
        {
            case "primary":
            case "primary key":
            case "primarykey":
            case "unique":
            case "unique index":
                return true;
            default:
                return false;
        }
    }

    /// <summary>Disposes the transaction (if any) and the connection.</summary>
    public virtual void Dispose()
    {
        transaction?.Dispose();
        transaction = null;
        connection?.Dispose();
        connection = null;
    }

    /// <summary>
    /// Inserts <paramref name="model"/> into its mapped table, skipping columns marked
    /// <see cref="DbColumnAttribute.NotInsert"/>. Returns the affected-row count.
    /// </summary>
    public virtual int Insert<Table>(Table model)
    {
        DbTableAttribute dbTableAttribute = GetDbTableAttribute<Table>();
        var columnNames = new List<string>();
        var parameters = new List<TParameter>();
        foreach (PropertyInfo propertyInfo in typeof(Table).GetProperties(BindingFlags.Instance | BindingFlags.Public))
        {
            var columnInfo = propertyInfo.GetCustomAttribute<DbColumnAttribute>();
            if (columnInfo == null || columnInfo.NotInsert) continue;
            columnNames.Add(QuoteIdentifier(columnInfo.Name));
            // Positional names stay valid even if a column name is not a legal parameter name.
            parameters.Add(CreateParameter("@p" + parameters.Count, propertyInfo.GetValue(model)));
        }
        string sql = $"insert into {QuoteIdentifier(dbTableAttribute.Name)} ({string.Join(",", columnNames)}) values ({string.Join(",", parameters.Select(p => p.ParameterName))})";
        return ExecuteNonQuery(sql, parameters.ToArray());
    }

    /// <summary>
    /// Deletes rows whose primary-key column equals <paramref name="value"/>. The key column is the
    /// first [DbColumn] whose Desc contains "primary key" (case-insensitive); returns 0 if none.
    /// </summary>
    public virtual int DeleteByPrimaryKey<Table>(object value)
    {
        DbTableAttribute tableInfo = GetDbTableAttribute<Table>();
        foreach (var propertyInfo in typeof(Table).GetProperties(BindingFlags.Instance | BindingFlags.Public))
        {
            var columnInfo = propertyInfo.GetCustomAttribute<DbColumnAttribute>();
            if (columnInfo == null || columnInfo.Desc == null) continue;
            if (columnInfo.Desc.IndexOf("primary key", StringComparison.OrdinalIgnoreCase) < 0) continue;
            return Delete(tableInfo.Name, columnInfo.Name, value);
        }
        return 0;
    }

    /// <summary>Deletes rows where <paramref name="column"/> equals <paramref name="value"/>.</summary>
    public virtual int Delete(string table, string column, object value)
    {
        return ExecuteNonQuery(
            $"delete from {QuoteIdentifier(table)} where {QuoteIdentifier(column)}=@value",
            CreateParameter("@value", value));
    }

    /// <summary>Deletes rows matching a raw <paramref name="where"/> clause (callers must parameterize it).</summary>
    public virtual int Delete(string table, string where, params TParameter[] parameters)
    {
        return ExecuteNonQuery($"delete from {QuoteIdentifier(table)} where {where}", parameters);
    }

    /// <summary>Opens a transaction; subsequent commands are enlisted in it by <see cref="CreateCommand"/>.</summary>
    public virtual void BeginTransaction()
    {
        Connect();
        transaction = (TTransaction)connection.BeginTransaction();
    }

    /// <summary>
    /// Commits the active transaction; on failure rolls back and returns false.
    /// Returns false when no transaction is active.
    /// </summary>
    public virtual bool EndTransaction()
    {
        if (transaction == null) return false;
        try
        {
            transaction.Commit();
            return true;
        }
        catch
        {
            try
            {
                transaction.Rollback();
            }
            catch
            {
                // Best effort: the commit failure is already reported through the false return,
                // and a rollback error (e.g. a dead connection) must not replace it.
            }
            return false;
        }
        finally
        {
            transaction.Dispose();
            transaction = null;
        }
    }

    public DbContext() { }

    /// <summary>Creates the context and lets <paramref name="callback"/> configure it (e.g. call UseServer).</summary>
    public DbContext(Action<DbContext<TTransaction, TConnection, TCommand, TParameter, TDataAdapter, TDataReader>> callback) { callback(this); }

    /// <summary>
    /// Maps each row of <paramref name="dt"/> to a new <typeparamref name="Table"/>, matching result
    /// columns to [DbColumn] names case-insensitively. DBNull values and unconvertible values are
    /// left at the property default.
    /// </summary>
    public virtual List<Table> TableToList<Table>(DataTable dt)
    {
        var list = new List<Table>();
        if (dt == null || dt.Rows.Count == 0) return list;
        var dict = new Dictionary<string, PropertyInfo>(StringComparer.OrdinalIgnoreCase);
        foreach (var propertyInfo in typeof(Table).GetProperties(BindingFlags.Instance | BindingFlags.Public))
        {
            var column = propertyInfo.GetCustomAttribute<DbColumnAttribute>();
            if (column == null || !propertyInfo.CanWrite) continue;
            dict.Add(column.Name, propertyInfo);
        }
        foreach (DataRow row in dt.Rows)
        {
            // Box once so setters also take effect when Table is a struct.
            object model = Activator.CreateInstance<Table>();
            foreach (DataColumn column in dt.Columns)
            {
                PropertyInfo propertyInfo;
                if (!dict.TryGetValue(column.ColumnName, out propertyInfo)) continue;
                object value = row[column];
                if (value == null || value is DBNull) continue;
                object converted;
                if (TryConvertDbValue(value, propertyInfo.PropertyType, out converted))
                    propertyInfo.SetValue(model, converted);
            }
            list.Add((Table)model);
        }
        return list;
    }

    /// <summary>
    /// Converts a non-null database value to <paramref name="propertyType"/> (nullable types are
    /// unwrapped). Handles assignable values, strings, enums (via Enum.Parse of the text) and
    /// any IConvertible target such as bool, numeric types and DateTime.
    /// Returns false when no conversion applies.
    /// </summary>
    protected virtual bool TryConvertDbValue(object value, Type propertyType, out object converted)
    {
        Type target = Nullable.GetUnderlyingType(propertyType) ?? propertyType;
        if (target.IsInstanceOfType(value))
        {
            converted = value;
            return true;
        }
        if (target == typeof(string))
        {
            converted = value.ToString();
            return true;
        }
        if (target.IsEnum)
        {
            converted = Enum.Parse(target, value.ToString());
            return true;
        }
        if (value is IConvertible && typeof(IConvertible).IsAssignableFrom(target))
        {
            converted = Convert.ChangeType(value, target);
            return true;
        }
        converted = null;
        return false;
    }

    /// <summary>Returns the validated table attribute for <typeparamref name="Table"/>.</summary>
    public DbTableAttribute GetDbTableAttribute<Table>() => GetDbTableAttribute(typeof(Table));

    /// <summary>
    /// Returns the [DbTable] attribute of <paramref name="t"/> (or a fresh one when absent),
    /// after <see cref="ValidateTableAttribute"/> has filled in defaults.
    /// </summary>
    public virtual DbTableAttribute GetDbTableAttribute(Type t)
    {
        var tableAttribute = t.GetCustomAttribute<DbTableAttribute>() ?? new DbTableAttribute();
        ValidateTableAttribute(tableAttribute, t);
        return tableAttribute;
    }

    /// <summary>Fills in defaults: an empty name becomes the CLR type name.</summary>
    public virtual void ValidateTableAttribute(DbTableAttribute tableAttribute, Type type)
    {
        if (string.IsNullOrWhiteSpace(tableAttribute.Name))
            tableAttribute.Name = type.Name;
    }

    /// <summary>Returns column-name/value pairs for every [DbColumn] property of <paramref name="model"/>.</summary>
    public virtual Dictionary<string, object> GetColumns<Table>(Table model)
    {
        Type tableType = typeof(Table);
        Dictionary<string, object> dict = new Dictionary<string, object>();
        foreach (PropertyInfo propertyInfo in tableType.GetProperties(BindingFlags.Instance | BindingFlags.Public))
        {
            var columnInfo = propertyInfo.GetCustomAttribute<DbColumnAttribute>();
            if (columnInfo == null) continue;
            dict.Add(columnInfo.Name, propertyInfo.GetValue(model));
        }
        return dict;
    }

    /// <summary>Returns the [DbColumn] attributes of <typeparamref name="Table"/> in property order.</summary>
    public virtual List<DbColumnAttribute> GetColumnAttributes<Table>() => GetColumnAttributes(typeof(Table));

    /// <summary>Returns the [DbColumn] attributes of <paramref name="tableType"/> in property order.</summary>
    public virtual List<DbColumnAttribute> GetColumnAttributes(Type tableType)
    {
        List<DbColumnAttribute> dbColumnAttributes = new List<DbColumnAttribute>();
        foreach (PropertyInfo propertyInfo in tableType.GetProperties(BindingFlags.Instance | BindingFlags.Public))
        {
            var columnInfo = propertyInfo.GetCustomAttribute<DbColumnAttribute>();
            if (columnInfo == null) continue;
            dbColumnAttributes.Add(columnInfo);
        }
        return dbColumnAttributes;
    }

    /// <summary>Returns the mapped table name of <typeparamref name="Table"/>.</summary>
    public virtual string NameOf<Table>() => NameOf(typeof(Table));

    /// <summary>Returns the mapped table name of <paramref name="type"/>.</summary>
    public virtual string NameOf(Type type)
    {
        return GetDbTableAttribute(type).Name;
    }

    /// <summary>Returns the column name mapped to a property of <typeparamref name="Table"/>.</summary>
    /// <typeparam name="Table">The table model type.</typeparam>
    /// <param name="propertyName">The CLR property name, e.g. nameof(Table.Property).</param>
    /// <returns>The [DbColumn] name, or <paramref name="propertyName"/> when the property is not mapped.</returns>
    public virtual string NameOf<Table>(string propertyName) => NameOf(typeof(Table), propertyName);

    /// <summary>Returns the column name mapped to a property of <paramref name="tableType"/>.</summary>
    /// <param name="tableType">The table model type.</param>
    /// <param name="propertyName">The CLR property name, e.g. nameof(Table.Property).</param>
    /// <returns>The [DbColumn] name, or <paramref name="propertyName"/> when the property is not mapped.</returns>
    public virtual string NameOf(Type tableType, string propertyName)
    {
        foreach (PropertyInfo propertyInfo in tableType.GetProperties(BindingFlags.Instance | BindingFlags.Public))
        {
            var columnInfo = propertyInfo.GetCustomAttribute<DbColumnAttribute>();
            if (columnInfo != null && propertyInfo.Name == propertyName) return columnInfo.Name;
        }
        return propertyName;
    }
}
MySQL:
查看代码
/// <summary>
/// MySQL provider for the reflection-based DbContext. Creates tables with inline
/// secondary indexes and table charset/collation, and probes the schema through
/// information_schema using parameterized queries.
/// </summary>
public class MySqlDbContext : DbContext<MySqlTransaction, MySqlConnection, MySqlCommand, MySqlParameter, MySqlDataAdapter, MySqlDataReader>
{
    public MySqlDbContext() { }
    public MySqlDbContext(Action<DbContext<MySqlTransaction, MySqlConnection, MySqlCommand, MySqlParameter, MySqlDataAdapter, MySqlDataReader>> callback) : base(callback) { }

    /// <summary>
    /// Creates the table (with its indexes) when it is missing; otherwise repairs columns and indexes.
    /// </summary>
    public override void Repair<Table>()
    {
        var table = GetDbTableAttribute<Table>();
        var tableType = typeof(Table);
        ValidateTableAttribute(table, tableType);
        if (!ExistTable(table.Name))
        {
            CreateTable<Table>();
        }
        else
        {
            RepairColumns<Table>();
            RepairIndex(tableType);
        }
    }

    /// <summary>
    /// Issues <c>create table</c> with every [DbColumn] definition, one inline <c>index</c> per
    /// declared index name (primary/unique markers are skipped; those constraints come from Desc),
    /// followed by the default character set and collation.
    /// </summary>
    /// <exception cref="InvalidOperationException">The type declares no columns or indexes.</exception>
    public override int CreateTable<Table>()
    {
        Type tableType = typeof(Table);
        var tableAttribute = GetDbTableAttribute<Table>();
        var definitions = new List<string>();
        foreach (PropertyInfo propertyInfo in tableType.GetProperties(BindingFlags.Instance | BindingFlags.Public))
        {
            var columnInfo = propertyInfo.GetCustomAttribute<DbColumnAttribute>();
            if (columnInfo == null) continue;
            definitions.Add($"\t{QuoteName(columnInfo.Name)} {columnInfo.Desc}");
        }
        foreach (KeyValuePair<string, List<string>> entry in GetIndexs(tableType))
        {
            string key = entry.Key.Trim();
            if (key.Length == 0 || IsConstraintMarker(key)) continue;
            var quotedColumns = new List<string>();
            foreach (string column in entry.Value)
                quotedColumns.Add(QuoteName(column));
            definitions.Add($"\tindex {QuoteName(key)} ({string.Join(",", quotedColumns)})");
        }
        // Joining avoids the trailing-comma bookkeeping (and the crash when nothing was emitted).
        if (definitions.Count == 0)
            throw new InvalidOperationException($"Type {tableType.Name} declares no [DbColumn] properties; cannot create table {tableAttribute.Name}.");
        StringBuilder sb = new StringBuilder($"create table {QuoteName(tableAttribute.Name)}");
        sb.Append($" (\r\n{string.Join(",\r\n", definitions)}\r\n)\r\n");
        if (!string.IsNullOrWhiteSpace(tableAttribute.Charset))
            sb.AppendLine($"default character set {tableAttribute.Charset}");
        if (!string.IsNullOrWhiteSpace(tableAttribute.Collate))
            sb.AppendLine($"collate {tableAttribute.Collate}");
        return ExecuteNonQuery(sb.ToString());
    }

    /// <summary>Defaults: type name for the table, utf8mb4 charset, utf8mb4_unicode_ci collation.</summary>
    public override void ValidateTableAttribute(DbTableAttribute tableAttribute, Type type)
    {
        if (string.IsNullOrWhiteSpace(tableAttribute.Name))
            tableAttribute.Name = type.Name;
        if (string.IsNullOrWhiteSpace(tableAttribute.Charset))
            tableAttribute.Charset = "utf8mb4";
        if (string.IsNullOrWhiteSpace(tableAttribute.Collate))
            tableAttribute.Collate = "utf8mb4_unicode_ci";
    }

    /// <summary>Returns whether <paramref name="column"/> exists in <paramref name="table"/> of the current database.</summary>
    public override bool ExistColumn(string column, string table)
    {
        Connect();
        var r = ExecuteScalar(
            "select 1 from information_schema.columns where table_schema=@schema and table_name=@table and column_name=@column;",
            new MySqlParameter("@schema", Connection.Database),
            new MySqlParameter("@table", table),
            new MySqlParameter("@column", column));
        return r != null && r.ToString() == "1";
    }

    /// <summary>
    /// Returns whether index <paramref name="index"/> exists on <paramref name="tableName"/> in the
    /// current database, optionally requiring that it covers <paramref name="columnName"/>.
    /// </summary>
    public override bool ExistIndex(string tableName, string index, string columnName = null)
    {
        var sql = new StringBuilder("select count(*) from information_schema.statistics where table_schema = database() and table_name = @table and index_name = @index");
        var parameters = new List<MySqlParameter>
        {
            new MySqlParameter("@table", tableName),
            new MySqlParameter("@index", index)
        };
        if (!string.IsNullOrWhiteSpace(columnName))
        {
            sql.Append(" and column_name = @column");
            parameters.Add(new MySqlParameter("@column", columnName));
        }
        object r = ExecuteScalar(sql.ToString(), parameters.ToArray());
        // count(*) is a BIGINT, so it arrives as a 64-bit integer; an `is int` test never matches.
        return r != null && !(r is DBNull) && Convert.ToInt64(r) > 0;
    }

    /// <summary>Returns whether <paramref name="table"/> exists in the current database.</summary>
    public override bool ExistTable(string table)
    {
        Connect();
        var r = ExecuteScalar(
            "select 1 from information_schema.tables where table_schema=@schema and table_name=@table;",
            new MySqlParameter("@schema", Connection.Database),
            new MySqlParameter("@table", table));
        return r != null && r.ToString() == "1";
    }

    // Backtick-quotes a MySQL identifier, doubling embedded backticks.
    private static string QuoteName(string name)
    {
        if (name == null) throw new ArgumentNullException(nameof(name));
        return "`" + name.Replace("`", "``") + "`";
    }

    // Index-list entries that denote constraints (declared through Desc) rather than named indexes.
    private static bool IsConstraintMarker(string key)
    {
        switch (key.ToLower().Trim())
        {
            case "primary":
            case "primary key":
            case "primarykey":
            case "unique":
            case "unique index":
                return true;
            default:
                return false;
        }
    }
}
SQLite:
查看代码
/// <summary>
/// SQLite provider for the reflection-based DbContext. <c>connectStr</c> may be either a bare
/// database file path or a full connection string containing "data source=".
/// </summary>
public class SqliteDbContext : DbContext<SQLiteTransaction, SQLiteConnection, SQLiteCommand, SQLiteParameter, SQLiteDataAdapter, SQLiteDataReader>
{
    public SqliteDbContext() { }
    public SqliteDbContext(Action<DbContext<SQLiteTransaction, SQLiteConnection, SQLiteCommand, SQLiteParameter, SQLiteDataAdapter, SQLiteDataReader>> callback) : base(callback) { }

    /// <summary>
    /// Opens (or reopens) the connection, creating the database file and its directory first
    /// when they do not exist. The previous connection, if any, is disposed.
    /// </summary>
    /// <exception cref="InvalidOperationException">No connection string has been configured.</exception>
    public override void Connect(bool reconnect = false)
    {
        if (!reconnect && Connection != null && (Connection.State == ConnectionState.Open || Connection.State == ConnectionState.Connecting)) return;
        if (string.IsNullOrWhiteSpace(connectStr))
            throw new InvalidOperationException("No SQLite database configured; call UseServer first.");
        bool isConnectionString = connectStr.IndexOf("data source=", StringComparison.OrdinalIgnoreCase) >= 0;
        // File-system checks need the file path, not the whole connection string.
        string databaseFile = isConnectionString ? new SQLiteConnectionStringBuilder(connectStr).DataSource : connectStr;
        EnsureDatabaseFile(databaseFile);
        Connection?.Dispose();
        Connection = new SQLiteConnection(isConnectionString ? connectStr : "data source=" + connectStr);
        Connection.Open();
    }

    // Creates the parent directory and an empty database file when missing; in-memory databases are left alone.
    private static void EnsureDatabaseFile(string file)
    {
        if (string.IsNullOrWhiteSpace(file) || file == ":memory:") return;
        string directory = Path.GetDirectoryName(file);
        // A bare file name has no directory part; CreateDirectory("") would throw.
        if (!string.IsNullOrEmpty(directory) && !Directory.Exists(directory))
            Directory.CreateDirectory(directory);
        if (!File.Exists(file))
            SQLiteConnection.CreateFile(file);
    }

    /// <summary>Returns whether <paramref name="table"/> has a column named exactly <paramref name="column"/> (case-insensitive).</summary>
    public override bool ExistColumn(string column, string table)
    {
        // pragma table_info lists real column names; a LIKE over the CREATE text matched substrings.
        return HasNamedRow(Query($"pragma table_info({QuoteName(table)})"), column);
    }

    /// <summary>
    /// Returns whether index <paramref name="index"/> exists, optionally requiring that it covers
    /// <paramref name="columnName"/>. SQLite index names are unique per schema, so
    /// <paramref name="tableName"/> is not used to locate the index.
    /// </summary>
    public override bool ExistIndex(string tableName, string index, string columnName = null)
    {
        object r = ExecuteScalar(
            "select count(*) from sqlite_master where type='index' and name=@index collate nocase",
            new SQLiteParameter("@index", index));
        // SQLite integers come back as 64-bit values; an `is int` test never matches.
        if (r == null || r is DBNull || Convert.ToInt64(r) == 0) return false;
        if (string.IsNullOrWhiteSpace(columnName)) return true;
        return HasNamedRow(Query($"pragma index_info({QuoteName(index)})"), columnName);
    }

    /// <summary>Returns whether <paramref name="table"/> exists (case-insensitive, like SQLite identifiers).</summary>
    public override bool ExistTable(string table)
    {
        Connect();
        var r = ExecuteScalar(
            "select 1 from sqlite_master where type='table' and name=@table collate nocase",
            new SQLiteParameter("@table", table));
        return r != null && r.ToString() == "1";
    }

    // True when a pragma result contains a row whose "name" column equals name (case-insensitive).
    private static bool HasNamedRow(DataTable info, string name)
    {
        if (info == null) return false;
        foreach (DataRow row in info.Rows)
        {
            if (string.Equals(Convert.ToString(row["name"]), name, StringComparison.OrdinalIgnoreCase))
                return true;
        }
        return false;
    }

    // Double-quotes a SQLite identifier for use in pragma arguments, doubling embedded quotes.
    private static string QuoteName(string name)
    {
        if (name == null) throw new ArgumentNullException(nameof(name));
        return "\"" + name.Replace("\"", "\"\"") + "\"";
    }
}
使用方法:
/// <summary>
/// Example table model mapped to the <c>account</c> table.
/// </summary>
[DbTable(Name = "account")]
public class Account
{
    /// <summary>Auto-increment primary key; excluded from INSERT so the database generates it.</summary>
    [DbColumn(Name = "id", Desc = "bigint primary key auto_increment", NotInsert = true)]
    public long Id { get; set; }

    /// <summary>Unique user name; also declared as secondary index "un".</summary>
    [DbColumn(Name = "un", Desc = "varchar(50) not null unique", Index = "un")]
    public string Username { get; set; }

    /// <summary>Password column (stored as given; hash it before assigning in real code).</summary>
    [DbColumn(Name = "pwd", Desc = "varchar(20) not null")]
    public string Password { get; set; }

    /// <summary>Soft-delete flag stored as an integer.</summary>
    // The default literal must be closed: "default '0" left an unterminated string in the DDL.
    [DbColumn(Name = "is_deleted", Desc = "int(1) default '0'")]
    public bool Deleted { get; set; }

    /// <summary>Creation time, defaulted by the database.</summary>
    [DbColumn(Name = "create_time", Desc = "timestamp default localtime")]
    public DateTime CreateTime { get; set; }
}
/// <summary>
/// Demo: connect to MySQL, create/repair the account table, insert a row, read rows back,
/// then delete the inserted row.
/// </summary>
static void Main(string[] args)
{
    // Dispose the context so the connection is closed even if a step throws.
    using (MySqlDbContext dbContext = new MySqlDbContext())
    {
        dbContext.UseServer("Data Source=127.0.0.1; Database=tempdb; User ID=admin; Password=123;Charset=utf8mb4;");
        dbContext.Repair<Account>();
        // `un` and `pwd` are NOT NULL, so a default-constructed Account would be rejected.
        var account = new Account { Username = "demo", Password = "demo" };
        dbContext.Insert(account);
        // Insert does not write the generated key back into the model; read it from this session.
        account.Id = Convert.ToInt64(dbContext.ExecuteScalar("select last_insert_id()"));
        var list = dbContext.GetList<Account>("select * from account limit 10");
        dbContext.DeleteByPrimaryKey<Account>(account.Id);
    }
    Console.WriteLine("按任意键退出。");
    Console.ReadKey();
}
浙公网安备 33010602011771号