python3 sqlite3 数据库连接
python3 sqlite3 数据库创建 & 连接 脚本
# -*- coding:utf-8 -*-
import traceback
import sqlite3
import re
import os
class DB(object):
    """Thin wrapper around a single sqlite3 connection.

    Result rows are returned as plain dicts (column name -> value). Before
    every statement two SQL functions are registered on the connection:
    ``regexp`` (which backs SQLite's ``X REGEXP Y`` operator) and
    ``bytes2Str``. ``query()`` / ``script()`` return the cursor on success and
    False on a sqlite3 error (the error is printed, not raised).
    """

    def __init__(self, dbname, autocommit=True):
        """
        @params dbname: database path handed to sqlite3.connect (":memory:" also works)
        @params autocommit: True opens the connection with isolation_level=None
                            (no implicit transactions); False instead commits after
                            every successful statement
        """
        self.dbname = dbname
        self.conn = None
        self.cursor = None
        self.connected = False
        self.autocommit = autocommit

    def connect(self):
        """(Re)open the connection, closing any previous one first."""
        self.close()
        if self.autocommit:
            self.conn = sqlite3.connect(self.dbname, isolation_level=None)
        else:
            self.conn = sqlite3.connect(self.dbname)
        self.connected = True

    def startTransaction(self):
        """Make sure a connection is open (sqlite3 begins transactions implicitly)."""
        if not self.connected:
            self.connect()

    def _close_cursor(self):
        """Close the current cursor if one exists."""
        if self.cursor is not None:
            self.cursor.close()
            self.cursor = None

    def commitTransaction(self):
        """Close the current cursor and commit; a no-op when not connected."""
        self._close_cursor()
        if self.conn is not None:
            self.conn.commit()

    def endTransaction(self):
        """No-op, kept for API compatibility."""
        pass

    def rollbackTransaction(self):
        """Close the current cursor and roll back; a no-op when not connected."""
        self._close_cursor()
        if self.conn is not None:
            self.conn.rollback()

    def dict_factory(self, cursor, row):
        """Row factory that turns each result row into {column name: value}."""
        return {column[0]: row[index] for index, column in enumerate(cursor.description)}

    def regexp(self, expr, item):
        """SQL function behind ``item REGEXP expr``.

        @params expr: regular expression to search for
        @params item: value supplied by sqlite; bytes (BLOB) are decoded,
                      other non-text values are converted with str()
        @return: True/False, or None (SQL NULL) when either argument is NULL,
                 so NULL columns simply do not match instead of aborting the query
        """
        if expr is None or item is None:
            return None
        if isinstance(item, bytes):
            item = item.decode()
        elif not isinstance(item, str):
            item = str(item)
        return re.search(expr, item) is not None

    def bytes2Str(self, expr):
        """SQL function: decode bytes (BLOB values) to str, pass other values through.

        A value read from a BLOB column cannot be compared with a string
        directly in sqlite3, so it is converted to a string first.
        """
        if isinstance(expr, bytes):
            return expr.decode()
        return expr

    def _new_cursor(self):
        """Connect if needed, (re)apply row factory and SQL functions, return a new cursor."""
        if not self.connected:
            self.connect()
        self.conn.row_factory = self.dict_factory
        self.conn.create_function("regexp", 2, self.regexp)
        self.conn.create_function("bytes2Str", 1, self.bytes2Str)
        self.cursor = self.conn.cursor()
        return self.cursor

    def _run(self, execute, sql_text):
        """Run ``execute(cursor)``; on AttributeError/OperationalError reconnect and retry once.

        @params execute: callable that runs the statement(s) on the given cursor
        @params sql_text: SQL shown in the error message
        @return: the cursor on success, False on any sqlite3 error (including
                 a failure of the retry)
        """
        def attempt():
            cursor = self._new_cursor()
            execute(cursor)
            if not self.autocommit:
                self.conn.commit()
            return cursor

        try:
            try:
                return attempt()
            except (AttributeError, sqlite3.OperationalError):
                # Reconnect once and retry; if the retry fails too, the outer
                # handler reports it instead of the exception escaping.
                self.connect()
                return attempt()
        except sqlite3.Error as e:
            print("Error {0}, sql:({1})".format(e, sql_text))
            if not self.autocommit:
                self.rollbackTransaction()
            return False

    def query(self, sql, params=()):
        """Execute one SQL statement with bound ``params``.

        @return: the cursor (rows are dicts), or False on a sqlite3 error
        """
        return self._run(lambda cursor: cursor.execute(sql, params), sql)

    def script(self, SQLScriptStr):
        """Execute an SQL script (several statements) via executescript.

        @return: the cursor, or False on a sqlite3 error
        """
        return self._run(lambda cursor: cursor.executescript(SQLScriptStr), SQLScriptStr)

    def getInsertId(self):
        """Return the rowid of the most recent INSERT on this connection.

        last_insert_rowid() is per connection, so it is 0 on a freshly opened
        connection. Returns None if the lookup query fails.
        """
        cursor = self.query("select last_insert_rowid()")
        if cursor is False:
            return None
        row = cursor.fetchone()
        return row.get("last_insert_rowid()") if row else None

    def close(self):
        """Close cursor and connection and mark the object as disconnected.

        A later query()/script() transparently reconnects.
        """
        if getattr(self, "cursor", None) is not None:
            self.cursor.close()
        self.cursor = None
        if getattr(self, "conn", None) is not None:
            self.conn.close()
        self.conn = None
        self.connected = False
def check_db(dbname, create_sql):
    """Check whether the database exists and create it with ``create_sql`` if not.

    The existence check, the directory creation and the database creation all
    refer to the same file: ``dbname`` resolved against the current working
    directory, i.e. the file that ``DB(dbname=dbname)`` opens.

    @params dbname: database path, as later passed to DB(dbname=...)
    @params create_sql: SQL script that creates the schema
    @return: True if the database already existed or was created successfully,
             False if the directory or the schema could not be created
    """
    print("start check db")
    # Resolve the relative path the same way sqlite3.connect(dbname) does, so
    # that the check below and the creation refer to the same file.
    db_path = os.path.abspath(dbname)
    dir_path = os.path.dirname(db_path)
    if os.path.exists(db_path):
        # Schema updates for an existing database are deliberately not run here.
        return True
    print("not find db: {0}, start create".format(db_path))
    if not os.path.exists(dir_path):
        print("not find dir: {0}, start make it".format(dir_path))
        try:
            os.makedirs(dir_path)
        except OSError:
            print("make dir: {0} failed:{1}".format(dir_path, traceback.format_exc()))
            print("check db finished")
            return False
    db = None
    try:
        db = DB(dbname=dbname)
        # DB.script() reports errors by returning False rather than raising.
        if db.script(create_sql) is False or not os.path.exists(db_path):
            print("create db: {0} failed".format(db_path))
            return False
        print("create db: {0} success".format(db_path))
    except Exception:
        print("create db: {0} failed, error: {1}".format(db_path, traceback.format_exc()))
        return False
    finally:
        if db is not None:
            db.close()
    return True
def column_exists(db, table_name, column_name):
    """
    Return True if table ``table_name`` has a column named ``column_name``.

    @params db: a DB instance (any object whose ``query(sql)`` returns a cursor
                of dict rows, or False on error)
    @params table_name: table to inspect; it is quoted as an SQL identifier, so
                        names containing spaces or quotes are handled safely
    @params column_name: column to look for (exact, case-sensitive comparison)
    @return: False when the column is absent, the table does not exist, or the
             PRAGMA query fails
    """
    quoted_table = '"{0}"'.format(table_name.replace('"', '""'))
    cursor = db.query("PRAGMA table_info({0});".format(quoted_table))
    if cursor is False:
        return False
    return any(column.get("name") == column_name for column in cursor.fetchall())
def change_db(dbname="test.db", table_name="table1", column_name="column_name",
              column_def="INTEGER DEFAULT 0"):
    """
    Add column ``column_name`` to ``table_name`` if it does not exist yet.

    The defaults reproduce the original hard-coded example
    (table1.column_name INTEGER DEFAULT 0). ``dbname`` defaults to the
    "test.db" used by the demo in this file — NOTE(review): pass the real
    database path.

    @params column_def: column type/constraint SQL, inserted verbatim (trusted input only)
    @return: True if the column exists afterwards, False on failure
    """
    def quote(name):
        # Quote an SQL identifier, doubling any embedded double quotes.
        return '"{0}"'.format(name.replace('"', '""'))

    print("start change db")
    db = None
    try:
        db = DB(dbname=dbname)
        if column_exists(db, table_name, column_name):
            print("column {0}.{1} already exists".format(table_name, column_name))
            return True
        alter_sql = "ALTER TABLE {0} ADD COLUMN {1} {2};".format(
            quote(table_name), quote(column_name), column_def)
        # DB.query() returns False on a sqlite3 error instead of raising.
        if db.query(alter_sql) is False:
            print("change db failed: could not add {0}.{1}".format(table_name, column_name))
            return False
        print("column add {0}.{1} success".format(table_name, column_name))
        return True
    except Exception:
        print("change db failed: {0}".format(traceback.format_exc()))
        return False
    finally:
        if db is not None:
            db.close()
if __name__ == "__main__":
    dbname = "test.db"
    create_sql = """
CREATE TABLE IF NOT EXISTS "user_info" (
"id" INTEGER,
"name" char(64),
PRIMARY KEY("id" AUTOINCREMENT)
);
"""
    check_db(dbname=dbname, create_sql=create_sql)
    db = DB(dbname=dbname)
    try:
        # Bind the value instead of embedding it in the SQL text.
        insert_sql = """insert into user_info(name) values(?);"""
        db.query(insert_sql, ("John",))
        get_sql = """select * from user_info;"""
        cursor = db.query(get_sql)
        # DB.query() returns False on error rather than raising.
        user_info = cursor.fetchone() if cursor is not False else None
        print(user_info)
    finally:
        db.close()
补充
另一个版本
sqlite_utils.py
import sqlite3
import time
from contextlib import contextmanager
default_db_path = 'default.db'


class SQLiteUtils:
    """Small CRUD helper around one sqlite3 connection.

    Every public method runs in its own transaction: it commits on success and
    rolls back on failure. Database errors surface as RuntimeError. Write
    methods marked "with lock retry" retry up to LOCK_RETRIES times while
    SQLite reports 'database is locked'.

    Table names, field names, ``fields`` and ``other_condition`` are inserted
    into the SQL text verbatim, so they must come from trusted code. Only the
    values are passed as bound parameters.

    Condition dicts map ``'<field> [operator]'`` to a value, e.g.
    ``{'age >=': 18, 'name like': 'A%', 'id in': [1, 2], 'email is NULL': None}``.
    Supported operators:
      - ``= != > < >= <=`` (may be glued to the field name, e.g. ``'age>='``);
      - ``in``, ``not in``, ``like``, ``not like``, ``is NULL``, ``is not NULL``
        (must be separated from the field name by whitespace; case-insensitive).
    A key without an operator means ``=``.
    """

    # Retry policy for write operations hitting 'database is locked'.
    LOCK_RETRIES = 10
    LOCK_RETRY_DELAY = 0.2  # seconds (200ms)

    # Checked in order, so each operator comes before any shorter operator it
    # ends with ('is not NULL' before 'is NULL', 'not in' before 'in', '>=' before '=').
    _WORD_OPERATORS = ('is not NULL', 'not like', 'is NULL', 'not in', 'like', 'in')
    _SYMBOL_OPERATORS = ('>=', '<=', '!=', '>', '<', '=')

    def __init__(self, db_path=default_db_path):
        """
        :param db_path: database file path (':memory:' also works)
        """
        self.db_path = db_path
        self.connection = self._connect()
        # Enable WAL mode to improve concurrency.
        with self._get_connection() as cursor:
            cursor.execute('PRAGMA journal_mode=WAL;')
            cursor.execute('PRAGMA synchronous=NORMAL;')

    def _connect(self):
        """Open a new connection that returns sqlite3.Row rows and may be used from other threads."""
        # NOTE(review): check_same_thread=False only disables sqlite3's
        # thread-ownership check. This class takes no lock, so threads sharing
        # one instance interleave their transactions; prefer one instance per thread.
        connection = sqlite3.connect(self.db_path, check_same_thread=False)
        connection.row_factory = sqlite3.Row
        return connection

    def _rollback_quietly(self):
        """Best-effort rollback that never masks the exception currently being handled."""
        connection = getattr(self, 'connection', None)
        if connection is not None:
            try:
                connection.rollback()
            except sqlite3.Error:
                pass

    @contextmanager
    def _get_connection(self):
        """Yield a cursor inside a transaction: commit on success, roll back on any error.

        sqlite3 errors are re-raised as RuntimeError('数据库操作失败: ...') with the
        original error chained as ``__cause__``. Any other exception is re-raised
        unchanged after the rollback, so a partially applied transaction is
        never left open.
        """
        cursor = None
        try:
            if getattr(self, 'connection', None) is None:
                self.connection = self._connect()
            cursor = self.connection.cursor()
            yield cursor
            self.connection.commit()
        except sqlite3.Error as e:
            self._rollback_quietly()
            raise RuntimeError(f'数据库操作失败: {str(e)}') from e
        except BaseException:
            self._rollback_quietly()
            raise
        finally:
            if cursor is not None:
                cursor.close()

    @classmethod
    def _split_operator(cls, key):
        """Split a condition key into (field, operator); a bare field name means '='."""
        stripped = key.rstrip()
        lowered = stripped.lower()
        for op in cls._WORD_OPERATORS:
            # Word operators require whitespace before them, so a column such
            # as 'login' or 'admin' is not misread as '<field> in'.
            if lowered.endswith(' ' + op.lower()):
                return stripped[:-len(op)].strip(), op
        for op in cls._SYMBOL_OPERATORS:
            if stripped.endswith(op):
                return stripped[:-len(op)].strip(), op
        return key, '='

    def _build_where(self, condition):
        """Turn a condition dict into ('<clause> AND <clause> ...', params tuple).

        Returns ('', ()) when ``condition`` is empty or None.
        :raises ValueError: if an 'in' / 'not in' value is not a list or tuple
        """
        clauses = []
        params = []
        for key, value in (condition or {}).items():
            field, operator = self._split_operator(key)
            if operator in ('in', 'not in'):
                if not isinstance(value, (list, tuple)):
                    raise ValueError(f'运算符{operator}要求值为列表/元组类型')
                placeholders = ', '.join('?' * len(value))
                clauses.append(f'{field} {operator} ({placeholders})')
                params.extend(value)
            elif operator in ('is NULL', 'is not NULL'):
                # IS [NOT] NULL takes no placeholder; the value is ignored.
                clauses.append(f'{field} {operator}')
            else:
                clauses.append(f'{field} {operator} ?')
                params.append(value)
        return ' AND '.join(clauses), tuple(params)

    def _execute_with_retry(self, sql, params, error_prefix, many=False):
        """Execute one write statement in its own transaction, retrying while the database is locked.

        :param params: bound parameters, or a list of parameter tuples when ``many`` is True
        :param error_prefix: message prefix of the RuntimeError raised for an OperationalError
        :return: cursor.rowcount of the successful attempt
        :raises RuntimeError: on a database error ('database is locked' only after LOCK_RETRIES attempts)
        """
        attempt = 0
        while True:
            attempt += 1
            try:
                with self._get_connection() as cursor:
                    if many:
                        cursor.executemany(sql, params)
                    else:
                        cursor.execute(sql, params)
                    return cursor.rowcount
            except RuntimeError as e:
                # _get_connection wraps sqlite3 errors in RuntimeError, so the
                # real error must be inspected via __cause__.
                cause = e.__cause__
                if not isinstance(cause, sqlite3.OperationalError):
                    raise
                if 'database is locked' in str(cause) and attempt < self.LOCK_RETRIES:
                    time.sleep(self.LOCK_RETRY_DELAY)
                    continue
                raise RuntimeError(f'{error_prefix}(尝试{attempt}次): {str(cause)}') from cause

    def create_table(self, table_name, columns):
        """
        Create a table if it does not exist.
        :param table_name: table name
        :param columns: column definitions {column name: column type/constraints}
        """
        column_defs = ', '.join([f'{k} {v}' for k, v in columns.items()])
        create_sql = f'CREATE TABLE IF NOT EXISTS {table_name} ({column_defs})'
        with self._get_connection() as cursor:
            cursor.execute(create_sql)

    def create_index(self, table_name, index_name, columns, unique=False):
        """
        Create an index if it does not exist (with lock retry).
        :param table_name: table name
        :param index_name: index name
        :param columns: indexed columns [column name, ...]
        :param unique: create a UNIQUE index (default False)
        :raises ValueError: if ``columns`` is empty
        """
        if not columns:
            raise ValueError('索引字段不能为空')
        columns_str = ', '.join(columns)
        unique_str = 'UNIQUE ' if unique else ''
        create_index_sql = f'CREATE {unique_str}INDEX IF NOT EXISTS {index_name} ON {table_name} ({columns_str})'
        self._execute_with_retry(create_index_sql, (), '创建索引失败')

    def insert(self, table_name, data):
        """
        Insert one row.
        :param table_name: table name
        :param data: row data {column name: value}
        :return: rowid of the inserted row
        """
        columns = ', '.join(data.keys())
        placeholders = ', '.join('?' * len(data))
        insert_sql = f'INSERT INTO {table_name} ({columns}) VALUES ({placeholders})'
        with self._get_connection() as cursor:
            cursor.execute(insert_sql, tuple(data.values()))
            return cursor.lastrowid

    def insert_batch(self, table_name, data_list):
        """
        Insert many rows in one transaction (with lock retry).
        :param table_name: table name
        :param data_list: rows [{column name: value}, ...]; the columns are taken
                          from the first row and every row must provide them
        :raises ValueError: if a row lacks one of the first row's columns
        """
        if not data_list:
            return
        keys = list(data_list[0].keys())
        columns = ', '.join(keys)
        placeholders = ', '.join('?' * len(keys))
        insert_sql = f'INSERT INTO {table_name} ({columns}) VALUES ({placeholders})'
        # Look values up by column name so dicts with a different key order
        # still line up with the column list.
        try:
            rows = [tuple(item[k] for k in keys) for item in data_list]
        except KeyError as e:
            raise ValueError(f'批量插入数据缺少字段: {e}') from e
        self._execute_with_retry(insert_sql, rows, '批量插入失败', many=True)

    def query(self, table_name, condition=None, fields='*', fetch_type='fetchall', other_condition=None):
        """
        Query rows.
        :param table_name: table name
        :param condition: condition dict {'<field> [operator]': value} (default: all rows)
        :param fields: selected columns/expressions (default '*')
        :param fetch_type: 'fetchone' (single row) or 'fetchall' (all rows, default)
        :param other_condition: raw SQL appended after WHERE, e.g. 'ORDER BY id LIMIT 10'
        :return: list of dicts (fetchall) or one dict / None (fetchone)
        """
        where_sql, params = self._build_where(condition)
        query_sql = f'SELECT {fields} FROM {table_name}'
        if where_sql:
            query_sql += f' WHERE {where_sql}'
        if other_condition:
            query_sql += ' ' + other_condition
        with self._get_connection() as cursor:
            cursor.execute(query_sql, params)
            if fetch_type == 'fetchone':
                row = cursor.fetchone()
                return dict(row) if row else None
            return [dict(row) for row in cursor.fetchall()]

    def update(self, table_name, data, condition, other_condition=None):
        """
        Update rows (with lock retry).
        :param table_name: table name
        :param data: new values {column name: value}
        :param condition: condition dict {'<field> [operator]': value}; empty updates every row
        :param other_condition: raw SQL appended at the end of the statement
        :return: number of affected rows
        """
        set_clause = ', '.join([f'{k}=?' for k in data.keys()])
        where_sql, where_params = self._build_where(condition)
        update_sql = f'UPDATE {table_name} SET {set_clause}'
        if where_sql:
            update_sql += f' WHERE {where_sql}'
        if other_condition:
            update_sql += ' ' + other_condition
        params = tuple(data.values()) + where_params
        return self._execute_with_retry(update_sql, params, '更新失败')

    def delete(self, table_name, condition=None):
        """
        Delete rows (with lock retry).
        :param table_name: table name
        :param condition: condition dict {'<field> [operator]': value} (default: delete all rows)
        :return: number of affected rows
        """
        where_sql, params = self._build_where(condition)
        delete_sql = f'DELETE FROM {table_name}'
        if where_sql:
            delete_sql += f' WHERE {where_sql}'
        return self._execute_with_retry(delete_sql, params, '删除失败')

    def close(self):
        """Close the connection; the next operation transparently reopens it."""
        connection = getattr(self, 'connection', None)
        if connection is not None:
            connection.close()
            self.connection = None

    def __del__(self):
        """Close the connection when the object is garbage-collected."""
        self.close()
if __name__ == '__main__':
    # Walk-through of the SQLiteUtils API against a local demo database.
    demo = SQLiteUtils('demo.db')

    # Users table.
    demo.create_table('users', {
        'id': 'INTEGER PRIMARY KEY AUTOINCREMENT',
        'name': 'TEXT NOT NULL',
        'age': 'INTEGER',
        'email': 'TEXT UNIQUE',
    })
    # Start from an empty table so the demo can be re-run.
    demo.delete('users', None)

    # Single-row insert.
    alice_id = demo.insert('users', {'name': 'Alice', 'age': 30, 'email': 'alice@example.com'})
    print(f'插入用户ID: {alice_id}')

    # Batch insert.
    more_users = [
        {'name': 'Bob', 'age': 25, 'email': 'bob@example.com'},
        {'name': 'Charlie', 'age': 35, 'email': 'charlie@example.com'},
        {'name': 'Tom', 'age': 35, 'email': 'tom@example.com'},
    ]
    demo.insert_batch('users', more_users)

    # Everybody.
    print('所有用户:', demo.query('users'))

    # Update Bob's age.
    rows_changed = demo.update('users', {'age': 26}, {'name': 'Bob'})
    print(f'更新影响行数: {rows_changed}')
    print('更新后Bob:', demo.query('users', {'name': 'Bob'}))

    # Remove Charlie.
    rows_changed = demo.delete('users', {'name': 'Charlie'})
    print(f'删除影响行数: {rows_changed}')
    print('删除后用户:', demo.query('users'))

    # Operator-style conditions: in / not in / > / !=
    operator_cases = (
        ('测试in条件查询(年龄在25-30之间):', {'age in': [25, 26, 30]}),
        ('测试not in条件查询(年龄不在30以上):', {'age not in': [31, 35]}),
        ('测试>条件查询(年龄在30以上):', {'age >': 30}),
        ('测试!=条件查询(邮箱不等于alice@example.com):', {'email!=': 'alice@example.com'}),
    )
    for label, where in operator_cases:
        print(label, demo.query('users', where))

    # Unique index on email, then time a lookup by email.
    demo.create_index('users', 'idx_email', ['email'], unique=True)
    t0 = time.time()
    demo.query('users', {'email': 'bob@example.com'})
    print('索引查询耗时:', time.time() - t0, '秒')

    # Bulk-insert ten more users.
    extra_users = [
        {'name': f'user_{n}', 'age': 20 + n, 'email': f'user_{n}@example.com'}
        for n in range(10, 20)
    ]
    demo.insert_batch('users', extra_users)
    print('\n批量插入后总用户数:', len(demo.query('users')))

    # Further examples (disabled):
    # demo.update('users', {'age': 20}, {'name like': "%user%"})   # conditional update
    # demo.delete('users', {'age <=': 30})                         # conditional delete

    print(demo.query('users', {"name is not NULL": ""}))
test_sqlite_performance.py
import time
import random
import threading
import multiprocessing
from sqlite_utils import SQLiteUtils
import uuid
# Test configuration
TEST_DB = 'performance_test.db'  # database file shared by all workers
DATA_SIZE = 1000  # rows generated and inserted per worker (thread or process)
BATCH_SIZE = 1000  # rows per insert_batch call; equal to DATA_SIZE, so one batch per worker
THREAD_NUM = 4  # worker threads started by multi_thread_test
PROCESS_NUM = 4  # worker processes started by multi_process_test
def generate_test_data(size, identifier):
    """Build ``size`` fake user rows for one worker.

    Names repeat across workers (user_0 ... user_{size-1}); ``identifier`` is
    embedded in the email so emails stay unique per worker. Ages are random
    integers in [18, 60].
    """
    rows = []
    for n in range(size):
        rows.append({
            'name': f'user_{n}',
            'age': random.randint(18, 60),
            'email': f'user_{identifier}_{n}@test.com',
        })
    return rows
def single_thread_test(thread_id):
    """Run one worker's workload on its own SQLiteUtils connection.

    Steps: batch-insert DATA_SIZE rows, then 1000 random lookups, 500 random
    updates and 200 random deletes by name. Prints and returns the elapsed
    wall-clock time in seconds (connection setup excluded).
    """
    db = SQLiteUtils(TEST_DB)
    started = time.perf_counter()

    # Batch insert.
    rows = generate_test_data(DATA_SIZE, thread_id)
    for offset in range(0, DATA_SIZE, BATCH_SIZE):
        db.insert_batch('users', rows[offset:offset + BATCH_SIZE])

    def random_name():
        return f'user_{random.randint(0, DATA_SIZE - 1)}'

    # Random lookups.
    for _ in range(1000):
        db.query('users', {'name': random_name()})

    # Random updates (the new age is drawn before the name, as in the call order).
    for _ in range(500):
        db.update('users', {'age': random.randint(18, 60)}, {'name': random_name()})

    # Random deletes.
    for _ in range(200):
        db.delete('users', {'name': random_name()})

    elapsed = time.perf_counter() - started
    print(f'线程{thread_id}完成测试,耗时: {elapsed:.2f}s')
    return elapsed
def multi_thread_test():
    """Run THREAD_NUM single_thread_test workers concurrently and print totals.

    The reported throughput counts only the inserted rows
    (THREAD_NUM * DATA_SIZE), not the queries/updates/deletes.
    """
    t0 = time.perf_counter()
    workers = [
        threading.Thread(target=single_thread_test, args=(idx,))
        for idx in range(THREAD_NUM)
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    total = time.perf_counter() - t0
    print(f'\n多线程测试完成,总耗时: {total:.2f}s')
    print(f'吞吐量: {THREAD_NUM * DATA_SIZE / total:.2f}条/秒')
def single_process_test(process_id):
    """Entry point of one worker process: run the same workload as one thread.

    single_thread_test prints the elapsed time itself; nothing is returned.
    """
    single_thread_test(thread_id=process_id)
def multi_process_test():
    """Run PROCESS_NUM worker processes concurrently and print totals.

    The reported throughput counts only the inserted rows
    (PROCESS_NUM * DATA_SIZE).
    """
    t0 = time.perf_counter()
    workers = [
        multiprocessing.Process(target=single_process_test, args=(idx,))
        for idx in range(PROCESS_NUM)
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
    total = time.perf_counter() - t0
    print(f'\n多进程测试完成,总耗时: {total:.2f}s')
    print(f'吞吐量: {PROCESS_NUM * DATA_SIZE / total:.2f}条/秒')
def data_consistency_check():
    """Print the number of rows left in ``users`` and sample rows for user_1 / user_2."""
    db = SQLiteUtils(TEST_DB)
    # Let SQLite count the rows instead of fetching every row just to len() it.
    total = db.query('users', fields='COUNT(*) AS total', fetch_type='fetchone')['total']
    print(f'\n数据一致性检查:当前用户表记录数 {total} 条')
    if total > 0:
        sample = db.query('users', {'name': 'user_1'})
        print(f'示例记录:{sample}')
        sample2 = db.query('users', {'name': 'user_2'})
        print(f'示例记录:{sample2}')
    else:
        print('警告:测试后表中无数据!')
if __name__ == '__main__':
    # Initialise the test database.
    test_db = SQLiteUtils(TEST_DB)
    test_db.create_table('users', {
        'id': 'INTEGER PRIMARY KEY AUTOINCREMENT',
        'name': 'TEXT NOT NULL',
        'age': 'INTEGER',
        'email': 'TEXT UNIQUE'
    })
    test_db.create_index("users", "name_index", ["name"], unique=False)

    # Run the tests.
    print('=== 开始单线程性能测试 ===')
    test_db.delete('users', {})  # clear old data
    single_elapsed = single_thread_test(0)
    print(f'单线程测试耗时: {single_elapsed:.2f}s')
    print(f'单线程吞吐量: {DATA_SIZE / single_elapsed:.2f}条/秒')

    print('\n=== 开始多线程并发测试 ===')
    test_db.delete('users', {})  # clear old data
    multi_thread_test()

    print('\n=== 开始多进程并发测试 ===')
    test_db.delete('users', {})  # clear old data
    # SQLite advises against carrying an open connection across fork(), so
    # close the parent's connection before the worker processes start.
    # Setting it to None lets SQLiteUtils reopen lazily if it is used again.
    test_db.connection.close()
    test_db.connection = None
    multi_process_test()

    # Consistency check.
    data_consistency_check()

    # Optionally remove the test database:
    # import os
    # os.remove(TEST_DB)