#!/usr/bin/env python
# -*- coding: utf8 -*-
import pymysql
import sqlparse

from app import logger
class MysqlDb:
    """Small PyMySQL helper that opens a fresh connection for every call.

    Each public method connects, runs its SQL, commits (or rolls back on
    error) and always closes the cursor and connection before returning.
    """

    def __init__(self, db, connect_timeout=5, connect_retries=3):
        """
        :param db: mapping that must contain the keys 'host', 'user',
            'password', 'database' and 'port'
        :param connect_timeout: connection timeout in seconds, passed to pymysql
        :param connect_retries: total connection attempts made by
            ``_db_connect`` before the last error is re-raised (minimum 1)
        """
        self.host = db['host']
        self.user = db['user']
        self.password = db['password']
        self.database = db['database']
        self.port = db['port']
        self.connect_timeout = connect_timeout
        self.connect_retries = connect_retries
        self.max_allowed_packet = 16 * 1024 * 1024  # 16 MiB
        self.read_timeout = 240
        self.write_timeout = 10

    def _db_connect(self):
        """
        Connect to the database and open a cursor.

        Up to ``connect_retries`` attempts are made back to back (no delay).

        :return: (conn, cur) tuple
        :raises Exception: the error from the final failed attempt,
            re-raised with its original type and traceback
        """
        attempts = max(1, self.connect_retries)
        for attempt in range(1, attempts + 1):
            conn = None
            try:
                # Keyword arguments only: PyMySQL >= 1.0 rejects positional ones.
                conn = pymysql.connect(
                    host=self.host,
                    user=self.user,
                    password=self.password,
                    database=self.database,
                    port=self.port,
                    connect_timeout=self.connect_timeout,
                    max_allowed_packet=self.max_allowed_packet,
                    read_timeout=self.read_timeout,
                    write_timeout=self.write_timeout,
                    # NOTE(review): MySQL's 'utf8' is the 3-byte utf8mb3;
                    # consider 'utf8mb4' if 4-byte characters must round-trip.
                    charset='utf8',
                )
                return conn, conn.cursor()
            except Exception:
                # Don't leak a connection whose cursor could not be opened.
                self._db_close(conn, None)
                if attempt == attempts:
                    raise

    def _db_close(self, conn, cur):
        """
        Close the cursor first, then the connection; either may be None.

        Cleanup is best-effort: errors raised while closing are logged rather
        than raised, so they cannot mask an exception already propagating
        from the caller. A connection that is no longer open (``conn.open``
        is false) is not closed again.
        """
        if cur is not None:
            try:
                cur.close()
            except Exception as e:
                logger.error(e)
        if conn is not None and conn.open:
            try:
                conn.close()
            except Exception as e:
                logger.error(e)

    def _rollback_quietly(self, conn):
        """Roll back ``conn``; a rollback failure is logged, not raised, so
        the caller's original error is the one that propagates."""
        try:
            conn.rollback()
        except Exception as e:
            logger.error(e)

    def many_insert(self, sql, param=None):
        """
        Batch insert via ``cursor.executemany`` in a single transaction.

        :param sql: "INSERT INTO table name (field1, field2) VALUES(%s, %s)"
        :param param: sequence of parameter rows, e.g. ((1, 1), (2, 2))
        :raises Exception: the driver error, after the transaction is rolled back
        """
        conn, cur = self._db_connect()
        try:
            cur.executemany(sql, param)
            conn.commit()
        except Exception:
            self._rollback_quietly(conn)
            raise
        finally:
            self._db_close(conn, cur)

    def sql_execute(self, sql, param=None):
        """
        Execute one UPDATE / DELETE / INSERT statement and commit it.

        :param sql: e.g. "INSERT INTO table name (field1, field2) VALUES(%s, %s)"
        :param param: flat parameter sequence, e.g. (1, 1)
        :return: ``cur.lastrowid`` when it is non-zero (e.g. the AUTO_INCREMENT
            id produced by an INSERT), otherwise the value returned by
            ``cur.execute`` (the affected-row count)
        :raises Exception: the driver error (logged), after rollback
        """
        result = 0
        conn, cur = self._db_connect()
        try:
            result = cur.execute(sql, param)
            conn.commit()
            # After the commit, prefer the auto-increment id of the inserted row.
            if cur.lastrowid:
                result = cur.lastrowid
        except Exception as e:
            logger.error(e)
            self._rollback_quietly(conn)
            raise
        finally:
            self._db_close(conn, cur)
        return result

    def sql_select(self, sql, param=None):
        """
        Run a query and return its column names and rows.

        :param sql: SELECT statement
        :param param: flat parameter sequence, e.g. (1, 1)
        :return: {"field": [column names], "data": rows from fetchall()};
            "field" stays [] when the statement produced no result set
        :raises Exception: the driver error, after rollback
        """
        result = {
            "field": [],
            "data": []
        }
        conn, cur = self._db_connect()
        try:
            cur.execute(sql, param)
            # description is None when the statement returns no result set.
            if cur.description:
                result["field"] = [column[0] for column in cur.description]
            result["data"] = cur.fetchall()
            conn.commit()
        except Exception:
            self._rollback_quietly(conn)
            raise
        finally:
            self._db_close(conn, cur)
        return result

    def sql_business(self, sqlcontent):
        """
        Run a multi-statement SQL script as a single transaction.

        Comments are stripped with sqlparse, the script is split into
        statements, and each non-empty statement is executed in order. Any
        error rolls back the whole script; the error is logged and printed
        but not raised.

        :param sqlcontent: the full SQL script
        :return: True if the transaction was committed, False if rolled back
        :raises Exception: only if the final commit itself fails
        """
        conn, cur = self._db_connect()
        try:
            script = sqlparse.format(sqlcontent, strip_comments=True).strip()
            for statement in sqlparse.split(script):
                if statement.strip():  # MySQL rejects empty queries
                    cur.execute(statement)
        except Exception as e:
            self._rollback_quietly(conn)  # roll back the transaction
            logger.error(e)
            print('事务处理失败', e)
            return False
        else:
            conn.commit()  # commit the transaction
            print('事务处理成功', cur.rowcount)
            return True
        finally:
            self._db_close(conn, cur)  # always close the connection