一个小型ORM框架，基于pymysql实现，较为简单。

  1 #!/usr/bin/env python
  2 # -*- coding: utf-8 -*-
  3 
  4 import pymysql
  5 from utils import NotImplementedError
  6 
  7 '''
  8 本文件是基于mysql实现的一个ORM框架
  9 '''
 10 
 11 class MysqlConnector(object):
 12     '''Python与mysql的连接器'''
 13 
 14     def __init__(self, host, port, username, password, db):
 15         conn = pymysql.connect(host=host, port=port, user=username,
 16                                passwd=password, db=db, use_unicode=True, charset="utf8")
 17         self.conn = conn
 18         self.cursor = conn.cursor(cursor=pymysql.cursors.DictCursor)
 19 
 20     def execute(self, sql_msg):
 21         '''
 22         执行sql语句
 23         :param sql_msg:  sql语句,字符串格式
 24         :return:
 25         '''
 26         ret = self.cursor.execute(sql_msg)
 27         self.conn.commit()
 28         return ret
 29 
 30     def close(self):
 31         '''关闭连接器'''
 32         self.cursor.close()
 33         self.conn.close()
 34 
 35 class BaseModel(object):
 36     '''
 37     实现将Python语句转换为sql语句,配合MysqlConnector实现表的创建以及数据的增删查改等操作。
 38     创建表时: 支持主键PRIMARY KEY,索引INDEX,唯一索引UNIQUE,自增AUTO INCREMENT 外键语句
 39         创建的表引擎指定为InnoDB,字符集为 utf-8
 40     增删查改: 支持WHERE [LIKE] LIMIT语句
 41     其子类必须设置initialize方法,并在该方法中创建字段对象
 42     '''
 43     def __new__(cls, *args, **kwargs):
 44         _instance = super().__new__(cls)
 45         _instance.initialize()
 46         return _instance
 47 
 48     def __init__(self, table_name, sql_connector):
 49         '''
 50         :param table_name: 要建立的表名
 51         :param sql_connector: MysqlConnector实例对象
 52         '''
 53         self.table_name = table_name
 54         self.fields = []
 55         self.primary_key_field = None
 56         self.uniques_fields = []
 57         self.index_fields = []
 58         self.is_foreign_key_fields = []
 59         self.sql_connector = sql_connector
 60         self._create_fields_list()
 61         self.create_table()
 62 
 63     def initialize(self):
 64         '''BaseModel的每个子类中必需包含该方法,且在该方法中定义字段'''
 65         raise NotImplementedError("Method or function hasn't been implemented yet.")
 66 
 67     def _create_fields_list(self):
 68         '''创建list用来存储表的字段对象'''
 69         for k, v in self.__dict__.items():
 70             if isinstance(v, BaseField):
 71                 self.fields.append(v)
 72                 v.full_column = '%s.%s' % (self.table_name, v.db_column)
 73                 v.table_name = self.table_name
 74         for field in self.fields:
 75             if field.primary_key:
 76                 self.primary_key_field = field
 77             if field.unique:
 78                 self.uniques_fields.append(field)
 79             if field.db_index:
 80                 self.index_fields.append(field)
 81             if field.is_foreign_key:
 82                 self.is_foreign_key_fields.append(field)
 83 
 84     def _has_created(self):
 85         '''检测表有没有被创建'''
 86         self.sql_connector.cursor.execute('SHOW TABLES;')
 87         ret = self.sql_connector.cursor.fetchall()
 88         for table in ret:
 89             for k, v in table.items():
 90                 if v == self.table_name:
 91                     return True
 92 
 93     def _create_table(self):
 94         ret = 'CREATE TABLE %s (' % self.table_name
 95         for v in self.fields:
 96             ret += v.generate_field_sql()
 97         ret = '%s%s%s%s%s' % (ret, self._generate_primary_key(),
 98                              self._generate_unique(), self._generate_index(),
 99                              self._generate_is_foreign_key())
100         ret = ret[:-1] + ')ENGINE=InnoDB DEFAULT CHARSET=utf8;'
101         return ret
102 
103     def create_table(self):
104         '''创建表'''
105         if not self._has_created():
106             print('创建表:%s' % self.table_name)
107             sql_msg = self._create_table()
108             # print(sql_msg)
109             self.sql_connector.execute(sql_msg)
110 
111     def _generate_primary_key(self):
112         '''生成sql语句中的 primary key 语句'''
113         ret = ''
114         if self.primary_key_field:
115             ret = 'PRIMARY KEY(%s),' % self.primary_key_field.db_column
116         return ret
117 
118     def _generate_is_foreign_key(self):
119         ret = ''
120         if self.is_foreign_key_fields:
121             for field in self.is_foreign_key_fields:
122                 ret += 'FOREIGN KEY(%s) REFERENCES %s(%s) ON DELETE %s  ON UPDATE %s,' % (field.db_column,
123                                                   field.model_obj.table_name,
124                                                   field.model_obj.primary_key_field.db_column,
125                                                   field.on_delete,
126                                                   field.on_delete )
127         return ret
128 
129     def _generate_unique(self):
130         '''生成sql语句中的 unique 语句'''
131         ret = ''
132         if self.uniques_fields:
133             ret = 'UNIQUE ('
134             for field in self.uniques_fields:
135                 ret += '%s,' % field.db_column
136             ret = ret[:-1]
137             ret += '),'
138         return ret
139 
140     def _generate_index(self):
141         index = ''
142         if self.index_fields:
143             index = 'INDEX ('
144             for field in self.index_fields:
145                 index += '%s,' % field.db_column
146             index = index[:-1]
147             index += '),'
148         return index
149 
150     def _generate_where(self, condition={}):
151         '''
152         根据条件生成 where 语句
153         :param condition: 一个dict,key是字段对象,value是条件(比如 'WHERE ID=3',那么value就是'=3')
154         :return:
155         '''
156         where = ''
157         if condition:
158             where = ' WHERE '
159             for k, v in condition.items():
160                 v = v.strip()
161                 offset = 1
162                 if v.startswith('l'):
163                     offset = 4
164                 if not k.is_str:
165                     where += '%s %s and' % (k.db_column, v)
166                 else:
167                     where += '%s %s "%s" and' % (k.db_column, v[:offset], v[offset:].strip())
168             where = where[:-3]
169         return where
170 
171     def select_items(self, counts=0, select_fields=[], condition={}, join_conditions=[]):
172         '''
173         根据condition 对表进行select,并 LIMIT counts
174         :param counts:
175         :param condition:
176         :return:
177         '''
178         join_length = len(join_conditions)
179         counts_sql = ''
180         join_sql = ''
181         select_fields_sql = ''
182         where = self._generate_where(condition)
183         if counts:
184             counts_sql = 'LIMIT %s' % counts
185         if not select_fields:
186             select_fields_sql = '* '
187         if join_conditions:
188             tables_order = list(list(zip(*join_conditions))[0])
189             tables_order.insert(0, self)
190             for i in select_fields:
191                 select_fields_sql += '%s,' % i.full_column
192             for n in range(join_length):
193                 one_join_condition = join_conditions[n]
194                 if n == 0:
195                     base_table = tables_order[0].table_name
196                 else:
197                     base_table = ''
198                 bracket_counts = join_length - n - 1
199                 join_sql += '%s %s LEFT JOIN %s on %s=%s%s' % (
200                     bracket_counts*'(', base_table, tables_order[n+1].table_name,
201                     one_join_condition[1].full_column, one_join_condition[2].full_column,
202                     bracket_counts * ')', )
203         else:
204             for i in select_fields:
205                 select_fields_sql += '%s,' % i.db_column
206             join_sql = self.table_name
207         select_fields_sql = select_fields_sql[:-1]
208         select = 'SELECT %s FROM %s %s %s;' % (select_fields_sql, join_sql, where, counts_sql)
209         # print('----------------', select)
210         self.sql_connector.execute(select)
211         result = self.sql_connector.cursor.fetchall()
212         return result
213 
214     def insert_item(self, data={}):
215         '''
216         向表中插入一行
217         :param data: 一个dict,key是字段对象,value则是值
218         :return:
219         '''
220         insert = 'INSERT INTO %s (' % self.table_name
221         value = '('
222         if data:
223             for k, v in data.items():
224                 insert += '%s,' % k.db_column
225                 if k.is_str:
226                     value += '"%s",' % v
227                 else:
228                     value += '%s,' % v
229                 # print('value is ',value)
230             insert = insert[:-1] + ')  VALUES '
231             value = value[:-1] + ');'
232             insert += value
233         # print('......',insert)
234         self.sql_connector.execute(insert)
235 
236     def delete_item(self, condition={}):
237         '''删除符合condition的条目'''
238         delete = 'DELETE FROM %s ' % self.table_name
239         where = self._generate_where(condition)
240         delete += where
241         # print(delete)
242         self.sql_connector.execute(delete)
243 
244     def update_item(self, data={}, condition={}):
245         '''将符合condition的条目修改为data'''
246         update = 'UPDATE %s' % self.table_name
247         data_statement = ''
248         if data:
249             data_statement = ' SET '
250             for k, v in data.items():
251                 if not k.is_str:
252                     data_statement += '%s=%s,' % (k.db_column, v)
253                 else:
254                     data_statement += '%s="%s",' % (k.db_column, v)
255             data_statement = data_statement[:-1]
256         where = self._generate_where(condition)
257         update += data_statement + where
258         # print('---------',update)
259         self.sql_connector.execute(update)
260 
261     def get_field_value(self, field, condition={}):
262         ret = self.select_items(condition=condition)
263         # print(ret)
264         if len(ret) == 1:
265             value = ret[0][field.db_column]
266         elif len(ret) > 1:
267             value = []
268             for i in ret:
269                 value.append(i[field.db_column])
270         else:
271             value = ''
272         # print('value is ',value)
273         return value
274 
class BaseField(object):
    '''Base class of column definitions; subclasses render the column SQL.'''

    def __init__(self, db_column, null=True, blank=None, choice=None,
                 db_index=False, default=None, primary_key=False,
                 unique=False, max_length=0, auto_increment=False,
                 ):
        '''
        :param db_column: column name in the database table
        :param null: whether the column may be NULL (forced to False for a primary key)
        :param blank: value to store when the field is empty (used by subclasses)
        :param choice: allowed values for the field; only stored here, not enforced
        :param db_index: whether to create an index on this column
        :param default: default value of the column (None = no DEFAULT clause)
        :param primary_key: whether this column is the primary key
        :param unique: whether values of this column must be unique
        :param max_length: maximum length of the column (0 = subclass default)
        :param auto_increment: whether the column is AUTO_INCREMENT
        '''
        self.db_column = db_column
        self.null = null
        self.blank = blank
        # A fresh dict per field; a ``{}`` default argument would be one
        # object shared by every field created without ``choice``.
        self.choice = {} if choice is None else choice
        self.db_index = db_index
        self.default = default
        self.primary_key = primary_key
        if self.primary_key:
            # A primary-key column can never be NULL.
            self.null = False
        self.unique = unique
        self.max_length = max_length
        self.auto_increment = auto_increment
        self.is_foreign_key = False

    def generate_field_sql(self):
        '''Return this column's CREATE TABLE fragment (ending with ','); subclasses implement it.'''
        raise NotImplementedError("Method or function hasn't been implemented yet.")

    def _generate_null(self):
        '''Return 'NOT NULL' or 'NULL' according to self.null.'''
        return 'NULL' if self.null else 'NOT NULL'

    def _generate_default(self):
        '''
        Return ' DEFAULT <value>' or '' when no default is set.

        String fields (``is_str``, set by subclasses) get a double-quoted,
        escaped literal; other values are inserted as-is.
        '''
        if self.default is None:
            return ''
        if self.is_str:
            # Escape backslashes first, then quotes, so the literal stays closed.
            escaped = str(self.default).replace('\\', '\\\\').replace('"', '\\"')
            return ' DEFAULT "%s"' % escaped
        return ' DEFAULT %s' % self.default

    def _generate_auto_increment(self):
        '''Return 'AUTO_INCREMENT' when enabled, otherwise ''.'''
        return 'AUTO_INCREMENT' if self.auto_increment else ''
330         return ret
331 
class CharField(BaseField):
    '''Fixed-length CHAR column; max_length defaults to 128.'''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not self.max_length:
            self.max_length = 128
        # Fall back to ``blank`` only when no default was given at all, so an
        # explicit default such as '' still produces DEFAULT "".
        # Note: blank is None unless the caller passes it, in which case no
        # DEFAULT clause is generated.
        if self.default is None:
            self.default = self.blank
        self.is_str = True
        self.field_type = 'CHAR'

    def generate_field_sql(self):
        '''Return "<column> CHAR(<max_length>) <NULL|NOT NULL> <default>,".'''
        null = self._generate_null()
        default = self._generate_default()
        return '%s CHAR(%s) %s %s,' % (self.db_column, self.max_length, null, default)
347 
class IntField(BaseField):
    '''INT column, rendered unquoted; may carry AUTO_INCREMENT.'''

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.field_type = 'INT'
        self.is_str = False  # values are written without quotes

    def generate_field_sql(self):
        '''Return "<column> INT <NULL|NOT NULL> <default> <auto_increment>,".'''
        pieces = (
            self.db_column,
            'INT',
            self._generate_null(),
            self._generate_default(),
            self._generate_auto_increment(),
        )
        return ' '.join(pieces) + ','
359 
class ForeignKeyField(BaseField):
    '''Column referencing the primary key of another model.'''

    def __init__(self, db_column, model_obj, null=True, default=None,
                 on_delete='CASCADE', on_update=None):
        '''
        :param db_column: column name in this table
        :param model_obj: model instance whose primary_key_field is referenced
        :param null: whether the column may be NULL
        :param default: default value of the column
        :param on_delete: referential action used for ON DELETE
        :param on_update: referential action used for ON UPDATE; None means
                          "same as on_delete"
        :raises ValueError: model_obj has no primary key field
        '''
        reference = model_obj.primary_key_field
        if reference is None:
            raise ValueError('model %r has no primary key to reference'
                             % getattr(model_obj, 'table_name', model_obj))
        # Initialise the attributes shared by all fields (blank, choice, ...).
        super().__init__(db_column, null=null, default=default)
        self.model_obj = model_obj
        self.reference = reference
        self.is_str = reference.is_str
        self.on_delete = on_delete
        self.on_update = on_delete if on_update is None else on_update
        self.is_foreign_key = True

    def _column_type(self):
        '''SQL type of this column, matching the referenced primary key.'''
        if self.reference.is_str and self.reference.max_length:
            # A bare CHAR means CHAR(1) in MySQL; copy the referenced length.
            return '%s(%s)' % (self.reference.field_type, self.reference.max_length)
        return self.reference.field_type

    def generate_field_sql(self):
        '''Return "<column> <type> <NULL|NOT NULL> <default>,".'''
        null = self._generate_null()
        default = self._generate_default()
        return '%s %s %s %s,' % (self.db_column, self._column_type(), null, default)
378 
379 
# Module-level connection shared by the models; opened when the module is imported.
Connector = MysqlConnector(
    host='127.0.0.1',
    port=3306,
    username='root',
    password='',
    db='test1',
)
381 
382 
if __name__ == '__main__':
    class UserModel(BaseModel):
        '''Demo model for a "user" table with an auto-increment primary key.'''

        def initialize(self):
            self.uid = IntField('uid', primary_key=True, auto_increment=True)
            self.account = IntField('account', unique=True, null=False)
            self.password = CharField('password', null=False)
            self.name = CharField('name', null=False)
            self.class_name = CharField('class_name', null=False)
            self.profession = CharField('profession', null=False)
            self.out_date_counts = IntField('out_date_counts', default=0)

    try:
        # Constructing the model creates the "user" table if it is missing.
        UserModel('user', Connector)
    finally:
        # Release the module-level connection when the demo finishes.
        Connector.close()

 

posted on 2017-06-21 15:06  MnCu  阅读(550)  评论(0)    收藏  举报