sqlalchemy_cruder
sqlalchemy_cruder: a generic CRUD helper class for SQLAlchemy models
https://github.com/fanqingsong/sqlalchemy_cruder
Core implementation
from typing import Any, Dict, Generic, List, Optional, Type, TypeVar, Union
from contextlib import contextmanager

from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.ext.declarative import DeclarativeMeta, declarative_base

# NOTE(review): this couples the library to an application package
# (`app.db.base`); in this module it is only used as the ModelType bound.
from app.db.base import Base

ModelType = TypeVar("ModelType", bound=Base)
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)


def _is_mapped_class(cls: Any) -> bool:
    """Return True if *cls* is a class usable as an ORM model.

    Classes built on a ``DeclarativeMeta`` base are accepted as before.
    SQLAlchemy 2.0 ``DeclarativeBase`` models do not use that metaclass,
    so we also ask SQLAlchemy directly whether the class is mapped.
    """
    if not isinstance(cls, type):
        return False
    if any(isinstance(base, DeclarativeMeta) for base in cls.__mro__):
        return True
    return sa_inspect(cls, raiseerr=False) is not None


def _schema_to_dict(obj_in: Union[BaseModel, Dict[str, Any]]) -> Dict[str, Any]:
    """Return the explicitly-set fields of *obj_in* as a plain dict.

    Dicts are returned unchanged. Pydantic v2's ``model_dump`` is used
    when available; otherwise the v1 ``dict`` method is called.
    """
    if isinstance(obj_in, dict):
        return obj_in
    dump = getattr(obj_in, "model_dump", None)
    if callable(dump):
        return dump(exclude_unset=True)
    return obj_in.dict(exclude_unset=True)


def _apply_update(db_obj: Any, update_data: Dict[str, Any]) -> None:
    """Copy values from *update_data* onto the mapped attributes of *db_obj*.

    The field list comes from the model's mapper rather than from the
    instance's loaded state. Expired or detached instances are therefore
    still updated. Keys that are not mapped attributes are ignored.
    """
    for field in sa_inspect(db_obj).mapper.attrs.keys():
        if field in update_data:
            setattr(db_obj, field, update_data[field])


def _commit_and_refresh(db: Session, objs: List[Any]) -> None:
    """Commit *db*, then reload *objs* so they remain readable after close.

    With the default ``expire_on_commit=True``, ``commit()`` expires every
    instance. If the objects were returned without this refresh, reading an
    attribute after the session closes would fail with DetachedInstanceError.
    """
    try:
        db.commit()
    except Exception:
        db.rollback()
        raise
    for obj in objs:
        db.refresh(obj)


class CRUDER(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    """Generic Create/Read/Update/Delete helper for one SQLAlchemy model.

    Every method opens its own session from ``session_maker`` and closes it
    before returning. Returned model instances are therefore detached.
    """

    def __init__(self, model: Type[ModelType], session_maker: "sessionmaker[Session]"):
        """
        CRUD object with default methods to Create, Read, Update, Delete (CRUD).

        **Parameters**

        * `model`: A mapped SQLAlchemy model class
        * `session_maker`: Factory used to open a new Session for each operation

        **Raises**

        * `ValueError`: if `model` is not a mapped / declarative model class
        """
        if not _is_mapped_class(model):
            raise ValueError(f"model must be a subclass of a class with DeclarativeMeta metaclass, got {model}")

        self.model = model
        self.session_maker = session_maker
        # NOTE(review): not used by any method of this class.
        self.db: Optional[Session] = None

    @contextmanager
    def get_db_session(self):
        """Yield a fresh Session and always close it afterwards."""
        db = self.session_maker()
        try:
            yield db
        finally:
            db.close()

    @contextmanager
    def transaction(self):
        """Yield a Session; commit on success, roll back and re-raise on error."""
        db = self.session_maker()
        try:
            yield db
            db.commit()
        except Exception:
            db.rollback()
            raise
        finally:
            db.close()

    def query_by_condition(self, condition: CreateSchemaType) -> List[ModelType]:
        """
        Query records matching the fields explicitly set on a Pydantic model.

        **Parameters**

        * `condition`: A Pydantic BaseModel. Each explicitly-set field becomes
          an equality filter, and the filters are combined with AND.

        **Returns**

        * The list of matching records.
        """
        condition_data = _schema_to_dict(condition)
        with self.get_db_session() as db:
            query = db.query(self.model)
            for field, value in condition_data.items():
                query = query.filter(getattr(self.model, field) == value)
            return query.all()

    def query_by_id(self, id: Any) -> Optional[ModelType]:
        """Return the record whose ``id`` column equals *id*, or None."""
        with self.get_db_session() as db:
            return db.query(self.model).filter(self.model.id == id).first()

    def query_by_pagination(self, skip: int = 0, limit: int = 100) -> List[ModelType]:
        """Return up to *limit* records after skipping the first *skip*."""
        with self.get_db_session() as db:
            return db.query(self.model).offset(skip).limit(limit).all()

    def query_by_like(self, field: str, value: str) -> List[ModelType]:
        """Return records whose *field* contains *value* (SQL ``LIKE %value%``)."""
        with self.get_db_session() as db:
            return db.query(self.model).filter(getattr(self.model, field).like(f"%{value}%")).all()

    def query_by_range(self, field: str, start: Any, end: Any) -> List[ModelType]:
        """Return records whose *field* is ``BETWEEN start AND end``."""
        with self.get_db_session() as db:
            return db.query(self.model).filter(getattr(self.model, field).between(start, end)).all()

    def create(self, obj_in: CreateSchemaType) -> ModelType:
        """Insert one record built from *obj_in* and return it, refreshed."""
        with self.get_db_session() as db:
            db_obj = self.model(**jsonable_encoder(obj_in))  # type: ignore
            db.add(db_obj)
            _commit_and_refresh(db, [db_obj])
            return db_obj

    def update(self, db_obj: Optional[ModelType], obj_in: Union[UpdateSchemaType, Dict[str, Any]]) -> Optional[ModelType]:
        """Apply *obj_in* to *db_obj*, commit, and return the refreshed object.

        Returns None, without touching the database, when *db_obj* is None.
        """
        if db_obj is None:
            return None
        update_data = _schema_to_dict(obj_in)
        with self.get_db_session() as db:
            _apply_update(db_obj, update_data)
            # Re-attach the (typically detached) instance to this session.
            db.add(db_obj)
            _commit_and_refresh(db, [db_obj])
            return db_obj

    def remove(self, id: int) -> Optional[ModelType]:
        """Delete the record with primary key *id*.

        Returns the deleted object, or None if no such record exists.
        """
        with self.get_db_session() as db:
            obj = db.get(self.model, id)
            if obj is None:
                return None
            db.delete(obj)
            db.commit()
            return obj

    def create_multi(self, objs_in: List[CreateSchemaType]) -> List[ModelType]:
        """Insert all of *objs_in* in one commit and return the refreshed objects."""
        with self.get_db_session() as db:
            db_objs = [self.model(**jsonable_encoder(obj_in)) for obj_in in objs_in]
            db.add_all(db_objs)
            _commit_and_refresh(db, db_objs)
            return db_objs

    def update_multi(self, db_objs: List[ModelType], obj_in: Union[UpdateSchemaType, Dict[str, Any]]) -> List[ModelType]:
        """Apply the same *obj_in* to every object in *db_objs* in one commit."""
        update_data = _schema_to_dict(obj_in)
        with self.get_db_session() as db:
            for db_obj in db_objs:
                _apply_update(db_obj, update_data)
                db.add(db_obj)
            _commit_and_refresh(db, db_objs)
            return db_objs

    def remove_multi(self, ids: List[int]) -> List[Optional[ModelType]]:
        """Delete every existing record whose primary key is in *ids*.

        All deletions happen in a single transaction. Ids with no matching
        record are skipped. Returns only the objects that were deleted.
        """
        with self.transaction() as db:
            removed_objs = []
            for id in ids:
                obj = db.get(self.model, id)
                if obj:
                    db.delete(obj)
                    removed_objs.append(obj)
            return removed_objs

    def combined_operation(self, create_obj: CreateSchemaType, update_obj: ModelType, update_data: Union[UpdateSchemaType, Dict[str, Any]]):
        """Create one record and update another atomically, in a single commit.

        Returns a ``(created_object, updated_object)`` tuple, both refreshed.
        """
        update_dict = _schema_to_dict(update_data)
        with self.get_db_session() as db:
            # Create step
            db_create_obj = self.model(**jsonable_encoder(create_obj))
            db.add(db_create_obj)

            # Update step
            _apply_update(update_obj, update_dict)
            db.add(update_obj)

            _commit_and_refresh(db, [db_create_obj, update_obj])
            return db_create_obj, update_obj

    def execute_query(self, statement) -> List[ModelType]:
        """
        Execute a query statement.

        :param statement: A SQLAlchemy select statement
        :return: The list of scalar results (e.g. model instances)
        """
        with self.get_db_session() as db:
            result = db.execute(statement)
            return result.scalars().all()

    def execute_statement(self, statement) -> int:
        """
        Execute a non-query statement (e.g. UPDATE or DELETE) and commit.

        :param statement: A SQLAlchemy DML statement
        :return: The number of affected rows
        """
        with self.get_db_session() as db:
            result = db.execute(statement)
            db.commit()
            return result.rowcount
Usage example
import logging
import sys
import os

# Add the project root directory to the Python path so that the local
# `sqlalchemy_cruder` package is importable when this file runs directly.
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(project_root)

from sqlalchemy import Table, Column, Integer, String, MetaData
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy import create_engine, select, update
from sqlalchemy_cruder.sqlalchemy_cruder import CRUDER
from pydantic import BaseModel

# Create a base class for declarative models
Base = declarative_base()


# Define the User model class
class User(Base):
    __tablename__ = 'users'
    id = Column(Integer, primary_key=True)
    name = Column(String)


# Define Pydantic schemas for create and update operations
class UserCreate(BaseModel):
    name: str


class UserUpdate(BaseModel):
    name: str


# Example database configuration: an in-memory SQLite database.
engine = create_engine('sqlite:///:memory:')
Session = sessionmaker(bind=engine)

# Create all tables
Base.metadata.create_all(engine)


def create_cruder(model, create_schema, update_schema, session_maker):
    """Build a CRUDER parameterised with the given model and schema types.

    Equivalent to writing
    ``CRUDER[User, UserCreate, UserUpdate](model=User, session_maker=Session)``
    by hand, but reusable for any model/schema combination.
    """
    return CRUDER[model, create_schema, update_schema](model=model, session_maker=session_maker)


# Create the CRUDER instance
cruder = create_cruder(User, UserCreate, UserUpdate, Session)


def main() -> None:
    """Insert a user, query it, update it, and show the result of each step."""
    # Insert a sample user
    user_create = UserCreate(name='old_name')
    created_user = cruder.create(user_create)

    # Query example
    select_stmt = select(User).where(User.id == created_user.id)
    results = cruder.execute_query(select_stmt)
    print("Query results:", [(user.id, user.name) for user in results])

    # Update example
    update_stmt = (
        update(User)
        .where(User.id == created_user.id)
        .values(name='new_name')
    )
    affected_rows = cruder.execute_statement(update_stmt)
    print("Affected rows:", affected_rows)

    # Re-read the row to show that the update took effect.
    updated_user = cruder.query_by_id(created_user.id)
    print("Name after update:", updated_user.name if updated_user else None)


if __name__ == "__main__":
    main()
出处:http://www.cnblogs.com/lightsong/
本文版权归作者和博客园共有,欢迎转载,但未经作者同意必须保留此段声明,且在文章页面明显位置给出原文连接。