python3 sqlite3 数据库连接

python3 sqlite3 数据库创建 & 连接 脚本

# -*- coding:utf-8 -*-
import traceback
import sqlite3
import re
import os

class DB(object):
    """Minimal wrapper around a single ``sqlite3`` connection.

    * Rows are returned as ``dict`` objects (column name -> value).
    * Two SQL functions are registered before every statement:
      ``regexp`` (enables ``value REGEXP pattern``) and ``bytes2Str``
      (decodes BLOB values so they can be compared with text).
    * The connection is opened lazily by ``query``/``script`` and is
      reopened automatically after ``close()``.

    :param dbname: database path handed to ``sqlite3.connect``.
    :param autocommit: True opens the connection with ``isolation_level=None``
        (autocommit mode); False commits after every successful
        ``query``/``script`` and rolls back when one fails.
    """

    def __init__(self, dbname, autocommit=True):
        self.dbname = dbname
        self.conn = None
        self.cursor = None
        self.connected = False
        self.autocommit = autocommit

    def connect(self):
        """(Re)open the connection, closing any previous one first."""
        self.close()
        if self.autocommit:
            # isolation_level=None disables the implicit BEGIN (autocommit mode).
            self.conn = sqlite3.connect(self.dbname, isolation_level=None)
        else:
            self.conn = sqlite3.connect(self.dbname)
        self.connected = True

    def startTransaction(self):
        """Make sure a connection is open; the transaction itself is implicit."""
        if not self.connected:
            self.connect()

    def commitTransaction(self):
        """Close the current cursor (if any) and commit."""
        self._close_cursor()
        self.conn.commit()

    def endTransaction(self):
        """No-op kept for API compatibility."""
        pass

    def rollbackTransaction(self):
        """Close the current cursor (if any) and roll back."""
        self._close_cursor()
        self.conn.rollback()

    def dict_factory(self, cursor, row):
        """sqlite3 row factory: map each column name to its value."""
        return {col[0]: value for col, value in zip(cursor.description, row)}

    def regexp(self, expr, item):
        """SQL ``regexp(pattern, value)``: True if *expr* matches anywhere in *item*.

        SQLite rewrites ``X REGEXP Y`` as ``regexp(Y, X)``, so *expr* is the
        pattern and *item* the column value.  BLOBs are decoded, other
        non-text values are converted with ``str``, and NULL never matches
        (``re.search(None)`` used to raise inside SQLite).

        :param expr: regular expression to search for.
        :param item: value passed in by SQLite.
        """
        if item is None:
            return False
        if isinstance(item, bytes):
            item = item.decode()
        elif not isinstance(item, str):
            item = str(item)
        return re.search(expr, item) is not None

    def bytes2Str(self, expr):
        """SQL ``bytes2Str(value)``: decode a non-empty BLOB to ``str``.

        BLOB values fetched from sqlite3 cannot be compared with strings
        directly, so they are decoded first; everything else is returned as-is.

        :param expr: value to convert.
        """
        if expr and isinstance(expr, bytes):
            return expr.decode()
        return expr

    def _close_cursor(self):
        """Close and forget the current cursor, if there is one."""
        cursor = getattr(self, "cursor", None)
        if cursor is not None:
            cursor.close()
        self.cursor = None

    def _prepare_cursor(self):
        """Configure the connection (row factory, SQL helpers) and open a new cursor."""
        self.conn.row_factory = self.dict_factory
        self.conn.create_function("regexp", 2, self.regexp)
        self.conn.create_function("bytes2Str", 1, self.bytes2Str)
        self.cursor = self.conn.cursor()
        return self.cursor

    def _execute(self, run, sql_text):
        """Shared body of ``query`` and ``script``.

        :param run: callable receiving a fresh cursor and executing the SQL.
        :param sql_text: the SQL text, used only in the error message.
        :return: the cursor on success, False on an ``sqlite3.Error``.
        """
        try:
            if not self.connected:
                self.connect()
            run(self._prepare_cursor())
            if not self.autocommit:
                self.conn.commit()
        except (AttributeError, sqlite3.OperationalError):
            # Missing/broken connection: reconnect and try exactly once more.
            # NOTE: an exception raised by this retry propagates to the caller
            # (the sibling ``except sqlite3.Error`` below cannot catch it).
            self.connect()
            run(self._prepare_cursor())
            if not self.autocommit:
                self.conn.commit()
        except sqlite3.Error as e:
            print("Error {0}, sql:({1})".format(e, sql_text))
            if not self.autocommit:
                self.rollbackTransaction()
            print("{0}".format(e))
            return False
        return self.cursor

    def query(self, sql, params=()):
        """Execute one statement with bound *params*; return the cursor or False."""
        return self._execute(lambda cursor: cursor.execute(sql, params), sql)

    def script(self, SQLScriptStr):
        """Execute an SQL script (several statements); return the cursor or False."""
        return self._execute(lambda cursor: cursor.executescript(SQLScriptStr), SQLScriptStr)

    def getInsertId(self):
        """Return ``last_insert_rowid()`` for this connection (None on error)."""
        returnid = None
        try:
            returnid = self.query("select last_insert_rowid()").fetchone().get("last_insert_rowid()")
        except (AttributeError, sqlite3.OperationalError):
            # query() returned False or the connection broke: reconnect and retry once.
            # NOTE(review): a brand-new connection reports 0 here, not the old rowid.
            self.connect()
            returnid = self.query("select last_insert_rowid()").fetchone().get("last_insert_rowid()")
        except sqlite3.Error as e:
            print("{0}".format(e))
        return returnid

    def close(self):
        """Close cursor and connection; the next query reconnects automatically."""
        self._close_cursor()
        conn = getattr(self, "conn", None)
        if conn is not None:
            conn.close()
        self.conn = None
        # Without this reset, later queries reused the closed connection and failed.
        self.connected = False


def check_db(dbname, create_sql):
    """Create the database *dbname* by running *create_sql* if the file is missing.

    :param dbname: database file path, absolute or relative to the current
        working directory.  It is resolved the same way ``DB``/``sqlite3.connect``
        resolve it, so the existence check and the created file always agree.
    :param create_sql: SQL script (may contain several statements) that builds
        the schema; it runs only when the file does not exist yet.
    :return: True if the database already existed or was created successfully,
        False if the directory or the database could not be created.
    """
    print("start check db")
    # Resolve relative to the CWD, exactly like DB(dbname) does. Joining onto
    # this script's directory (as before) made the check and the creation
    # target different files whenever the CWD was elsewhere.
    db_path = os.path.abspath(dbname)
    dir_path = os.path.dirname(db_path)

    if os.path.exists(db_path):
        # Updating the schema of an existing database is intentionally not
        # done for now (an earlier draft ran an ``update_sql`` script here).
        return True

    print("not find db: {0}, start create".format(db_path))
    if not os.path.exists(dir_path):
        print("not find dir: {0}, start make it".format(dir_path))
        try:
            os.makedirs(dir_path)
        except Exception:
            print("make dir: {0} failed:{1}".format(dir_path, traceback.format_exc()))
            print("check db finished")
            return False

    db = DB(dbname=db_path)
    try:
        # DB.script() returns False when the script raised an sqlite3 error.
        created = db.script(create_sql) is not False
    except Exception:
        print("create db: {0} failed, error: {1}".format(db_path, traceback.format_exc()))
        return False
    finally:
        db.close()

    if created and os.path.exists(db_path):
        print("create db: {0} success".format(db_path))
        return True
    print("create db: {0} failed".format(db_path))
    return False

def column_exists(db, table_name, column_name):
    """Return True if table *table_name* has a column named *column_name*.

    :param db: object whose ``query(sql)`` returns a cursor yielding dict rows
        (e.g. ``DB``).
    :param table_name: table to inspect; any name is accepted, including names
        with spaces, SQL keywords or quotes.
    :param column_name: exact column name to look for.
    """
    # PRAGMA arguments cannot be bound as ``?`` parameters, so pass the table
    # name as a properly escaped string literal instead of raw text.
    quoted_table = "'{0}'".format(table_name.replace("'", "''"))
    columns = db.query("PRAGMA table_info({0});".format(quoted_table)).fetchall()
    return any(column.get("name") == column_name for column in columns)

def change_db(dbname="test.db", table_name="table1", column_name="column_name",
              column_def="INTEGER DEFAULT 0"):
    """Add a column to an existing table if it is not there yet.

    Called without arguments it runs the original hard-coded example
    (``table1.column_name`` as ``INTEGER DEFAULT 0``). ``dbname`` defaults to
    ``"test.db"``, the database used by this module's ``__main__`` demo.
    NOTE(review): confirm that this default is the intended target.

    Errors are printed and swallowed, as before.

    :param dbname: database file to migrate.
    :param table_name: table that should contain the column.
    :param column_name: column to add when missing.
    :param column_def: SQL type/constraint text for the new column (trusted input).
    """
    print("start change db")
    db = None
    try:
        # DB() used to be called without its required dbname argument, so
        # this function always failed with a TypeError.
        db = DB(dbname=dbname)
        if column_exists(db, table_name, column_name):
            print("column {0}.{1} already exists".format(table_name, column_name))
            return
        # Identifiers cannot be bound as parameters; quote them instead.
        alter_sql = 'ALTER TABLE "{0}" ADD COLUMN "{1}" {2};'.format(
            table_name.replace('"', '""'), column_name.replace('"', '""'), column_def)
        # DB.query() returns False (after printing the error) on failure.
        if db.query(alter_sql) is False:
            print("column add {0}.{1} failed".format(table_name, column_name))
        else:
            print("column add {0}.{1} success".format(table_name, column_name))
    except Exception:
        print("change db failed: {0}".format(traceback.format_exc()))
    finally:
        if db is not None:
            db.close()

if __name__ == "__main__":
    # Demo: ensure test.db has the user_info table, add one row, read one back.
    dbname = "test.db"
    create_sql = """
        CREATE TABLE IF NOT EXISTS "user_info" (
            "id"	INTEGER,
            "name" char(64),
            PRIMARY KEY("id" AUTOINCREMENT)
    );
    """
    check_db(dbname=dbname, create_sql=create_sql)

    demo_db = DB(dbname=dbname)
    demo_db.query("""insert into user_info(name) values('John');""")
    first_row = demo_db.query("""select * from user_info;""").fetchone()
    print(first_row)
    demo_db.close()

补充


另一个版本
sqlite_utils.py

import sqlite3
import time
from contextlib import contextmanager


# Database file opened by SQLiteUtils when no db_path is given.
default_db_path = 'default.db'


class SQLiteUtils:
    """CRUD helper around a single ``sqlite3`` connection.

    Every operation runs inside ``_get_connection()``, which opens a cursor,
    commits on success and rolls back on any failure. Rows are returned as
    plain dicts.

    Condition dicts map ``"<field> [operator]"`` to a value, e.g.
    ``{'age >=': 18, 'name like': 'A%', 'id in': [1, 2], 'email is NULL': None}``.

    * Supported operators: ``=`` (default when no suffix), ``!=``, ``>``,
      ``<``, ``>=``, ``<=``, ``like``, ``not like``, ``in``, ``not in``,
      ``is NULL``, ``is not NULL``.
    * Word operators must be separated from the field by whitespace, so a
      column named ``login`` is not misread as ``log in``.

    NOTE: table names, field names, ``fields`` and ``other_condition`` are
    inserted into the SQL text verbatim. Only values are bound as
    parameters, so those strings must never come from untrusted input.
    """

    # Operators recognised as a suffix of a condition key, longest first so
    # that 'not like' wins over 'like', 'is not NULL' over 'is NULL' and
    # '>=' over '>'.
    _OPERATORS = tuple(sorted(
        ('not in', '>=', '<=', '>', '<', 'like', 'not like', 'in', '!=', 'is NULL', 'is not NULL'),
        key=len, reverse=True,
    ))

    # Retry policy for writes that fail with "database is locked".
    _MAX_RETRIES = 10
    _RETRY_DELAY = 0.2  # 200ms

    def __init__(self, db_path=None):
        """Open *db_path* and switch the database to WAL journaling.

        :param db_path: database file; ``None`` means the module-level
            ``default_db_path`` (looked up at call time).
        """
        import threading  # only needed for the connection lock

        if db_path is None:
            db_path = default_db_path
        self.db_path = db_path
        # check_same_thread=False lets other threads use this connection. The
        # lock serialises them so one thread's commit/rollback can never end
        # another thread's half-finished transaction.
        self._lock = threading.RLock()
        self.connection = None
        self._connect()
        # WAL mode improves read/write concurrency.
        with self._get_connection() as cursor:
            cursor.execute('PRAGMA journal_mode=WAL;')
            cursor.execute('PRAGMA synchronous=NORMAL;')

    def _connect(self):
        """Open the connection (usable across threads) with dict-like rows."""
        self.connection = sqlite3.connect(self.db_path, check_same_thread=False)
        self.connection.row_factory = sqlite3.Row

    @contextmanager
    def _get_connection(self):
        """Yield a cursor inside a transaction: commit on success, roll back on error.

        Reopens the connection if ``close()`` was called. ``sqlite3.Error`` is
        re-raised as ``RuntimeError('数据库操作失败: ...')`` with the original
        exception chained as ``__cause__`` (the retry logic relies on that).
        Any other exception from the ``with`` body is re-raised unchanged
        after the rollback.
        """
        with self._lock:
            try:
                if self.connection is None:
                    self._connect()
                cursor = self.connection.cursor()
            except sqlite3.Error as e:
                raise RuntimeError(f'数据库操作失败: {str(e)}') from e
            try:
                yield cursor
                self.connection.commit()
            except BaseException as e:
                # Roll back on *any* failure, not only sqlite3 errors, so a
                # half-done transaction is never committed by a later call.
                self.connection.rollback()
                if isinstance(e, sqlite3.Error):
                    raise RuntimeError(f'数据库操作失败: {str(e)}') from e
                raise
            finally:
                cursor.close()

    def _write_with_retry(self, error_label, work):
        """Run ``work(cursor)`` in a transaction, retrying on "database is locked".

        :param error_label: message prefix used once all retries are exhausted.
        :param work: callable taking a cursor; its return value is returned.
        :raises RuntimeError: on any database error (lock errors only after
            ``_MAX_RETRIES`` attempts).
        """
        for attempt in range(self._MAX_RETRIES):
            try:
                with self._get_connection() as cursor:
                    return work(cursor)
            except RuntimeError as e:
                cause = e.__cause__
                locked = isinstance(cause, sqlite3.OperationalError) and 'database is locked' in str(cause)
                if not locked:
                    raise
                if attempt == self._MAX_RETRIES - 1:
                    raise RuntimeError(f'{error_label}(尝试{self._MAX_RETRIES}次): {str(cause)}') from cause
                time.sleep(self._RETRY_DELAY)

    @classmethod
    def _split_key(cls, key):
        """Split a condition key such as ``'age >='`` into ``('age', '>=')``.

        Keys without a recognised operator suffix mean equality: ``(key, '=')``.
        """
        lowered = key.lower()
        for op in cls._OPERATORS:
            if not lowered.endswith(op.lower()):
                continue
            field = key[:-len(op)]
            # Word operators (in, like, is NULL, ...) need whitespace before
            # them; otherwise 'login' or 'domain' would parse as '... in'.
            if op[0].isalpha() and not field[-1:].isspace():
                continue
            field = field.strip()
            if field:
                return field, op
        return key, '='

    @classmethod
    def _build_conditions(cls, condition):
        """Turn a condition dict into (list of SQL predicates, list of bound values).

        :raises ValueError: when ``in``/``not in`` is given a non list/tuple value.
        """
        predicates, values = [], []
        for key, value in condition.items():
            field, operator = cls._split_key(key)
            if operator in ('in', 'not in'):
                if not isinstance(value, (list, tuple)):
                    raise ValueError(f'运算符{operator}要求值为列表/元组类型')
                placeholders = ', '.join('?' for _ in value)
                predicates.append(f'{field} {operator} ({placeholders})')
                values.extend(value)
            elif operator in ('is NULL', 'is not NULL'):
                predicates.append(f'{field} {operator}')  # no placeholder, value ignored
            else:
                predicates.append(f'{field} {operator} ?')
                values.append(value)
        return predicates, values

    def create_table(self, table_name, columns):
        """
        Create a table if it does not exist.
        :param table_name: table name
        :param columns: column definitions {column name: SQL type/constraints}
        """
        column_defs = ', '.join([f'{k} {v}' for k, v in columns.items()])
        create_sql = f'CREATE TABLE IF NOT EXISTS {table_name} ({column_defs})'
        with self._get_connection() as cursor:
            cursor.execute(create_sql)

    def create_index(self, table_name, index_name, columns, unique=False):
        """
        Create an index if it does not exist (retries while the database is locked).
        :param table_name: table name
        :param index_name: index name
        :param columns: indexed columns [column name, ...]
        :param unique: create a UNIQUE index (default False)
        :raises ValueError: if *columns* is empty
        """
        if not columns:
            raise ValueError('索引字段不能为空')
        columns_str = ', '.join(columns)
        unique_str = 'UNIQUE ' if unique else ''
        create_index_sql = f'CREATE {unique_str}INDEX IF NOT EXISTS {index_name} ON {table_name} ({columns_str})'
        self._write_with_retry('创建索引失败', lambda cursor: cursor.execute(create_index_sql))

    def insert(self, table_name, data):
        """
        Insert one row.
        :param table_name: table name
        :param data: {column name: value}; an empty dict inserts a row of
            column defaults (``DEFAULT VALUES``) instead of invalid SQL
        :return: rowid of the inserted row
        """
        if data:
            columns = ', '.join(data.keys())
            placeholders = ', '.join('?' for _ in data)
            insert_sql = f'INSERT INTO {table_name} ({columns}) VALUES ({placeholders})'
        else:
            insert_sql = f'INSERT INTO {table_name} DEFAULT VALUES'
        with self._get_connection() as cursor:
            cursor.execute(insert_sql, tuple(data.values()))
            return cursor.lastrowid

    def insert_batch(self, table_name, data_list):
        """
        Insert many rows in one transaction (retries while the database is locked).
        :param table_name: table name
        :param data_list: [{column name: value}, ...]; the first row's keys
            define the columns, and every row is bound by those keys, so key
            order may differ between rows (a missing key raises KeyError)
        """
        if not data_list:
            return
        keys = list(data_list[0].keys())
        columns = ', '.join(keys)
        placeholders = ', '.join('?' for _ in keys)
        insert_sql = f'INSERT INTO {table_name} ({columns}) VALUES ({placeholders})'
        # Bind by column name, not dict order: values used to land in the
        # wrong columns when a row's keys were in a different order.
        rows = [tuple(item[k] for k in keys) for item in data_list]
        self._write_with_retry('批量插入失败', lambda cursor: cursor.executemany(insert_sql, rows))

    def query(self, table_name, condition=None, fields='*', fetch_type='fetchall', other_condition=None):
        """
        Select rows.
        :param table_name: table name
        :param condition: condition dict (see class docstring); None/{} selects all rows
        :param fields: select list (default ``*``)
        :param fetch_type: 'fetchone' for a single row, anything else for all rows
        :param other_condition: raw SQL appended after WHERE (e.g. 'ORDER BY id')
        :return: list of dicts, or one dict / None when fetch_type='fetchone'
        """
        predicates, values = self._build_conditions(condition or {})
        where_clause = 'WHERE ' + ' AND '.join(predicates) if predicates else ''
        other_condition_str = ' ' + other_condition if other_condition else ''
        query_sql = f'SELECT {fields} FROM {table_name} {where_clause} {other_condition_str}'
        with self._get_connection() as cursor:
            cursor.execute(query_sql, tuple(values))
            if fetch_type == 'fetchone':
                row = cursor.fetchone()
                return dict(row) if row else None
            return [dict(row) for row in cursor.fetchall()]

    def update(self, table_name, data, condition, other_condition=None):
        """
        Update rows (retries while the database is locked).
        :param table_name: table name
        :param data: new values {column name: value}
        :param condition: condition dict (see class docstring); {} updates every row
        :param other_condition: raw SQL appended to the statement
        :return: number of affected rows
        """
        set_clause = ', '.join([f'{k}=?' for k in data.keys()])
        predicates, where_values = self._build_conditions(condition)
        update_sql = f'UPDATE {table_name} SET {set_clause}'
        if predicates:
            update_sql += ' WHERE ' + ' AND '.join(predicates)
        if other_condition:
            update_sql += ' ' + other_condition
        params = tuple(data.values()) + tuple(where_values)
        return self._write_with_retry('更新失败', lambda cursor: cursor.execute(update_sql, params).rowcount)

    def delete(self, table_name, condition=None):
        """
        Delete rows (retries while the database is locked).
        :param table_name: table name
        :param condition: condition dict (see class docstring); None/{} deletes every row
        :return: number of affected rows
        """
        predicates, values = self._build_conditions(condition or {})
        delete_sql = f'DELETE FROM {table_name}'
        if predicates:
            delete_sql += ' WHERE ' + ' AND '.join(predicates)
        params = tuple(values)
        return self._write_with_retry('删除失败', lambda cursor: cursor.execute(delete_sql, params).rowcount)

    def close(self):
        """Close the connection; the next operation transparently reopens it."""
        connection = getattr(self, 'connection', None)
        if connection is not None:
            connection.close()
        self.connection = None

    def __del__(self):
        """Close the connection when the object is garbage-collected."""
        self.close()

if __name__ == '__main__':
    # Demo / smoke test of SQLiteUtils against demo.db.
    demo = SQLiteUtils('demo.db')

    # users(id, name, age, email)
    user_columns = {
        'id': 'INTEGER PRIMARY KEY AUTOINCREMENT',
        'name': 'TEXT NOT NULL',
        'age': 'INTEGER',
        'email': 'TEXT UNIQUE'
    }
    demo.create_table('users', user_columns)

    # Start every run from an empty table.
    demo.delete('users', None)

    # A single insert returns the new rowid.
    new_id = demo.insert('users', {'name': 'Alice', 'age': 30, 'email': 'alice@example.com'})
    print(f'插入用户ID: {new_id}')

    # Batch insert.
    seed_rows = [
        {'name': 'Bob', 'age': 25, 'email': 'bob@example.com'},
        {'name': 'Charlie', 'age': 35, 'email': 'charlie@example.com'},
        {'name': 'Tom', 'age': 35, 'email': 'tom@example.com'}
    ]
    demo.insert_batch('users', seed_rows)

    print('所有用户:', demo.query('users'))

    # Update Bob's age, then read him back.
    changed = demo.update('users', {'age': 26}, {'name': 'Bob'})
    print(f'更新影响行数: {changed}')
    print('更新后Bob:', demo.query('users', {'name': 'Bob'}))

    # Delete Charlie.
    changed = demo.delete('users', {'name': 'Charlie'})
    print(f'删除影响行数: {changed}')
    print('删除后用户:', demo.query('users'))

    # Operator-suffix conditions: in / not in / > / !=
    operator_cases = [
        ('测试in条件查询(年龄在25-30之间):', {'age in': [25, 26, 30]}),
        ('测试not in条件查询(年龄不在30以上):', {'age not in': [31, 35]}),
        ('测试>条件查询(年龄在30以上):', {'age >': 30}),
        ('测试!=条件查询(邮箱不等于alice@example.com):', {'email!=': 'alice@example.com'}),
    ]
    for label, cond in operator_cases:
        print(label, demo.query('users', cond))

    # Unique index on email, then time a lookup that can use it.
    demo.create_index('users', 'idx_email', ['email'], unique=True)
    t0 = time.time()
    demo.query('users', {'email': 'bob@example.com'})
    print('索引查询耗时:', time.time() - t0, '秒')

    # A larger batch insert.
    more_rows = [{'name': f'user_{n}', 'age': 20 + n, 'email': f'user_{n}@example.com'} for n in range(10, 20)]
    demo.insert_batch('users', more_rows)
    print('\n批量插入后总用户数:', len(demo.query('users')))

    # Further examples (disabled):
    #   demo.update('users', {'age': 20}, {'name like': "%user%"})   # conditional update
    #   demo.delete('users', {'age <=': 30})                         # conditional delete
    print(demo.query('users', {"name is not NULL": ""}))

test_sqlite_performance.py

import time
import random
import threading
import multiprocessing
from sqlite_utils import SQLiteUtils
import uuid

# Test configuration
TEST_DB = 'performance_test.db'
DATA_SIZE = 1000  # rows generated and inserted by each worker (thread/process)
BATCH_SIZE = 1000  # rows per insert_batch call
THREAD_NUM = 4    # number of threads in the multi-thread test
PROCESS_NUM = 4    # number of processes in the multi-process test


def generate_test_data(size, identifier):
    """Build *size* fake user rows.

    Names repeat across workers (user_0, user_1, ...), but the email embeds
    *identifier*, so emails stay unique per worker.
    """
    rows = []
    for idx in range(size):
        rows.append({
            'name': f'user_{idx}',
            'age': random.randint(18, 60),
            'email': f'user_{identifier}_{idx}@test.com',
        })
    return rows


def single_thread_test(thread_id):
    """Run one worker's workload and return its wall-clock time in seconds.

    The workload is a batched insert, then 1000 random lookups, 500 random
    updates and 200 random deletes by name. The worker opens its own
    SQLiteUtils connection to TEST_DB.
    """
    db = SQLiteUtils(TEST_DB)
    t_start = time.perf_counter()

    def random_name():
        return f'user_{random.randint(0, DATA_SIZE-1)}'

    # 1) bulk insert in BATCH_SIZE chunks
    rows = generate_test_data(DATA_SIZE, thread_id)
    for offset in range(0, DATA_SIZE, BATCH_SIZE):
        db.insert_batch('users', rows[offset:offset + BATCH_SIZE])

    # 2) random point lookups
    for _ in range(1000):
        db.query('users', {'name': random_name()})

    # 3) random updates (new age drawn before the target name, as before)
    for _ in range(500):
        new_age = random.randint(18, 60)
        db.update('users', {'age': new_age}, {'name': random_name()})

    # 4) random deletes
    for _ in range(200):
        db.delete('users', {'name': random_name()})

    elapsed = time.perf_counter() - t_start
    print(f'线程{thread_id}完成测试,耗时: {elapsed:.2f}s')
    return elapsed


def multi_thread_test():
    """Run THREAD_NUM single_thread_test workers in threads and print the total time and throughput."""
    t_start = time.perf_counter()
    workers = []
    for n in range(THREAD_NUM):
        worker = threading.Thread(target=single_thread_test, args=(n,))
        workers.append(worker)
        worker.start()

    for worker in workers:
        worker.join()

    total = time.perf_counter() - t_start
    print(f'\n多线程测试完成,总耗时: {total:.2f}s')
    print(f'吞吐量: {THREAD_NUM * DATA_SIZE / total:.2f}条/秒')


def single_process_test(process_id):
    """Process entry point: run the single-thread workload with *process_id*.

    The workload prints its own timing; this wrapper returns nothing.
    """
    single_thread_test(process_id)
    return None


def multi_process_test():
    """Run PROCESS_NUM single_process_test workers in processes and print the total time and throughput."""
    t_start = time.perf_counter()
    procs = []
    for n in range(PROCESS_NUM):
        proc = multiprocessing.Process(target=single_process_test, args=(n,))
        procs.append(proc)
        proc.start()

    for proc in procs:
        proc.join()

    total = time.perf_counter() - t_start
    print(f'\n多进程测试完成,总耗时: {total:.2f}s')
    print(f'吞吐量: {PROCESS_NUM * DATA_SIZE / total:.2f}条/秒')


def data_consistency_check():
    """Print the users row count and, if there are rows, those named user_1 and user_2."""
    db = SQLiteUtils(TEST_DB)
    row_count = len(db.query('users'))
    print(f'\n数据一致性检查:当前用户表记录数 {row_count} 条')
    if row_count <= 0:
        print('警告:测试后表中无数据!')
        return
    for sample_name in ('user_1', 'user_2'):
        sample = db.query('users', {'name': sample_name})
        print(f'示例记录:{sample}')


if __name__ == '__main__':
    # Set up the schema (plus a non-unique index on name) once.
    setup_db = SQLiteUtils(TEST_DB)
    setup_db.create_table('users', {
        'id': 'INTEGER PRIMARY KEY AUTOINCREMENT',
        'name': 'TEXT NOT NULL',
        'age': 'INTEGER',
        'email': 'TEXT UNIQUE'
    })
    setup_db.create_index("users", "name_index", ["name"], unique=False)

    # Phase 1: a single worker in the main thread.
    print('=== 开始单线程性能测试 ===')
    setup_db.delete('users', {})  # wipe rows left by earlier runs
    single_elapsed = single_thread_test(0)
    print(f'单线程测试耗时: {single_elapsed:.2f}s')
    print(f'单线程吞吐量: {DATA_SIZE / single_elapsed:.2f}条/秒')

    # Phase 2: THREAD_NUM workers in threads.
    print('\n=== 开始多线程并发测试 ===')
    setup_db.delete('users', {})
    multi_thread_test()

    # Phase 3: PROCESS_NUM workers in processes.
    print('\n=== 开始多进程并发测试 ===')
    setup_db.delete('users', {})
    multi_process_test()

    data_consistency_check()

    # Optional cleanup of the test database:
    # import os
    # os.remove(TEST_DB)
posted @ 2024-10-12 16:32  BrianSun  阅读(122)  评论(0)    收藏  举报