flask自定义参数校验、序列化和反序列化

项目总体结构

 

我的工厂函数factory.py

from settings import setting
from flask import Flask
from models.models import db
from flask_migrate import Migrate
from urls.router import bp_te, bp_lo
# from flask_script import Manager
from utils.log import set_log
# from flask_limiter import Limiter
# from flask_limiter.util import get_remote_address
# https://www.cnblogs.com/Du704/p/13281032.html

# Copy the connection settings out of the settings module once, at import
# time, so create_app() can build the database URI from plain module names.
mysql_host, mysql_port = setting.MYSQL_HOST, setting.MYSQL_PORT
mysql_user, mysql_pwd = setting.MYSQL_USER, setting.MYSQL_PASSWORD
mysql_database = setting.MYSQL_DATABASE
env_cnf = setting.ENV_CNF


def create_app():
    """
    Application factory: build and configure the Flask app.

    Configures logging (``set_log``), the SQLAlchemy MySQL connection built
    from the module-level settings, Flask-SQLAlchemy, Flask-Migrate and the
    project's blueprints.

    :return: the configured ``Flask`` application.
    """
    from urllib.parse import quote_plus

    set_log()
    application = Flask(__name__)
    # URL-escape the credentials: characters such as '@', ':', '/' or '%' in
    # the user name or password would otherwise corrupt the connection URI.
    # str() mirrors the conversion the f-string interpolation performed before.
    user = quote_plus(str(mysql_user))
    pwd = quote_plus(str(mysql_pwd))
    DB_URI = f'mysql+pymysql://{user}:{pwd}@{mysql_host}:{mysql_port}/{mysql_database}'
    application.config['SQLALCHEMY_DATABASE_URI'] = DB_URI
    # Do not track object modifications; it costs performance and is rarely needed.
    application.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
    # Do not echo the generated SQL statements.
    application.config['SQLALCHEMY_ECHO'] = False

    # Bind the shared ``db`` object to this application.
    # NOTE(review): ``db.app = ...`` is a legacy Flask-SQLAlchemy 2.x idiom for
    # using ``db`` outside an app context; kept for compatibility - confirm it
    # is still needed with the installed version.
    db.app = application
    db.init_app(application)
    migrate = Migrate()
    migrate.init_app(application, db)

    # Register the blueprints.
    application.register_blueprint(bp_te)
    application.register_blueprint(bp_lo)

    return application


# Module-level application instance, created when this module is imported.
application = create_app()

 

配置文件setting.py,读取数据库等配置信息

from configparser import ConfigParser
from pathlib import Path

from utils.encryption import getDataAes

from configparser import Error as ConfigError

BASE_DIR = Path(__file__).resolve().parent.parent

conf = ConfigParser()
# Read config.ini from the project root first, then from the current working
# directory. ConfigParser.read() skips files that do not exist, and values
# from a later file override earlier ones, so a config.ini in the CWD still
# takes precedence as before, while starting the app from another directory
# no longer loses the project-root file.
conf.read([BASE_DIR / "config.ini", "config.ini"], encoding='utf-8')
try:
    mysqlhost = conf.get("mysql", "host")
    mysqlport = conf.get("mysql", "port")
    mysqluser = conf.get("mysql", "user")
    mysqlpassword = conf.get("mysql", "password")
    mysqlname = conf.get("mysql", "name")

    secret = conf.get('serve', 'secret')
    env_cnf = conf.get('serve', 'env')

    redishost = conf.get("redis", "host")
    redisport = conf.get("redis", "port")
    redispwd = conf.get("redis", "password")
    redislibrary = conf.get("redis", "library")

    MINIOHOST = conf.get("minio", "clienthost")
    MINIOPORT = conf.get("minio", "clientport")
    miniopwd = conf.get("minio", "password")
    miniouser = conf.get("minio", "user")
    MINIOWEBHOST = conf.get("minio", "webhost")
    MINIOWEBPORT = conf.get("minio", "webport")
    MINIOHTTP = conf.get("minio", "http")

except ConfigError as e:
    # conf.get() only raises configparser errors (missing section/option,
    # interpolation problems). A single missing entry discards EVERY value
    # read above and switches to the built-in defaults below.
    print(f'config.ini missing or incomplete ({e}); using built-in defaults')
    # NOTE(review): these defaults, including the hard-coded 'secret', are
    # development placeholders. They go through getDataAes() below exactly
    # like values read from config.ini - confirm getDataAes accepts them.
    mysqlhost = '127.0.0.1'
    mysqlport = 3306
    mysqluser = 'root'
    mysqlpassword = '000000'
    mysqlname = '0'
    secret = 'Wchime'
    env_cnf = 'develop'

    redishost = '127.0.0.1'
    redisport = '6379'
    redispwd = '000000'
    redislibrary = "1"

    MINIOHOST = '127.0.0.1'
    MINIOPORT = '9000'
    MINIOWEBPORT = '9000'
    miniopwd = '000000'
    miniouser = '000000'
    MINIOHTTP = 'http://'
    MINIOWEBHOST = '127.0.0.1'

# Public settings; credentials are passed through getDataAes() with `secret`.
MYSQL_HOST = mysqlhost
MYSQL_PORT = mysqlport
MYSQL_USER = getDataAes(secret, mysqluser)
MYSQL_PASSWORD = getDataAes(secret, mysqlpassword)
MYSQL_DATABASE = mysqlname

REDIS_HOST = redishost
REDIS_PORT = redisport
REDIS_PASSWORD = getDataAes(secret, redispwd)
REDIS_LIBRARY = redislibrary

MINIOPWD = getDataAes(secret, miniopwd)
MINIOUSER = getDataAes(secret, miniouser)

ENV_CNF = env_cnf

if __name__ == '__main__':
    # Debug helper. NOTE(review): this writes a credential value (as stored in
    # config.ini) to stdout; avoid running it where output is captured.
    print(mysqlpassword)

 

models.py数据库模型文件

  

import datetime

from utils.core import db
from sqlalchemy_serializer import SerializerMixin


class Uu(db.Model, SerializerMixin):
    """
    Model for table ``uu``.

    Each row may reference one ``Ux`` row through the nullable ``ux_id``
    foreign key.
    """

    __tablename__ = 'uu'
    id = db.Column(db.Integer, autoincrement=True, primary_key=True)
    name = db.Column(db.String(20), nullable=False)
    age = db.Column(db.Integer, nullable=False)

    # Nullable FK to ux.id, declared with ON DELETE SET NULL at the database level.
    ux_id = db.Column(db.Integer, db.ForeignKey('ux.id', ondelete='SET NULL'), nullable=True)
    # Many-to-one relationship to Ux; backref='uu' also adds a ``uu`` attribute on Ux.
    ux = db.relationship('Ux', backref='uu')        # , lazy='dynamic'

    des = db.Column(db.String(20), nullable=True)

    # String column of up to 128 chars (presumably an image path/URL - TODO confirm).
    img = db.Column(db.String(128), nullable=True)


class Ux(db.Model, SerializerMixin):
    """Model for table ``ux``; related ``Uu`` rows are reachable via the ``uu`` backref."""
    # Exclude the ``uu`` backref from to_dict() output (sqlalchemy_serializer
    # rule syntax), presumably to avoid Ux -> Uu -> Ux recursion.
    serialize_rules = ("-uu",)
    __tablename__ = 'ux'
    id = db.Column(db.Integer, autoincrement=True, primary_key=True)
    name = db.Column(db.String(20), nullable=False)

 

 序列化文件serializes.py

from models import models
from utils.base import Serialize, DeSerialize


class TestSerialize(Serialize):
    """Serialize ``Uu`` rows as id/name plus the related ux name and an img list."""

    model = models.Uu
    fields = ['id', 'name']
    # Extra output keys: 'ux_name' follows the Uu.ux relationship,
    # 'img' is produced by get_img().
    build_fiels = [
        dict(name='ux_name', source='ux.name'),
        dict(name='img', method=True),
    ]

    def get_img(self, instance):
        """Return the image list for ``instance`` (currently always empty)."""
        images = []
        return images


class TestDeSerialize(DeSerialize):
    """Create/update ``Ux`` rows: ``name`` is mandatory, responses show id and name."""
    model = models.Ux
    # Use the base class's ``req_fields`` hook. Assigning ``required_fields``
    # here shadowed the base-class property and skipped its primary-key filtering.
    req_fields = ['name']
    ser_fields = ['id', 'name']

 

base.py：自定义序列化、反序列化及参数解析文件

 

from models.models import db
from flask_restful import abort
from sqlalchemy import inspect



class DeSerialize(object):
    """
    Deserializer: validate request data, then create, update or delete a row.

    Subclasses set ``model``; instances of it are expected to provide
    ``serializable_keys`` and ``to_dict`` (as sqlalchemy_serializer's
    ``SerializerMixin`` does). Optional hooks: ``req_fields`` (mandatory keys),
    ``other_fields`` (optional keys) and ``ser_fields`` (keys returned by ``data``).
    """
    model = None
    # Mandatory request keys; None means "every serializable key of the model".
    req_fields = None
    # Optional request keys, copied only when present and not None.
    other_fields = []
    req_data = {}
    # Target row; None means save() creates a new one (name kept for compatibility).
    insatance = None
    # Keys returned by ``data``; empty means every serializable key.
    ser_fields = []

    def __init__(self, insatance=None, data=None):
        """
        :param insatance: existing model instance to update/delete, or None to create.
        :param data: request payload (mapping); defaults to a fresh empty dict.
        """
        self.insatance = insatance
        # A new dict per instance instead of one shared mutable default argument.
        self.req_data = {} if data is None else data

    @property
    def required_fields(self):
        """Mandatory keys: ``req_fields`` (or all serializable keys) minus the primary key."""
        return self._get_fileds(self.req_fields)

    @property
    def serializer_fields(self):
        """Keys returned by ``data``: ``ser_fields`` if set, else all serializable keys."""
        return self.ser_fields if self.ser_fields else self.model().serializable_keys

    def _get_fileds(self, fileds):
        """
        Return ``fileds`` as a list with the model's (first) primary key removed.

        :param fileds: iterable of key names, or None for all serializable keys.
        """
        if fileds is None:
            values_valid = list(self.model().serializable_keys)
        else:
            values_valid = list(fileds)
        try:
            primary_key = inspect(self.model).primary_key[0].name
        except Exception:
            # Model could not be inspected: fall back to the conventional 'id'.
            primary_key = 'id'
        if primary_key in values_valid:
            values_valid.remove(primary_key)
        return values_valid

    def _get_vaild_values(self):
        """
        Collect validated values from ``req_data``.

        :return: ``(dict, err)`` on success, ``(False, err)`` when a required key
                 is missing or None. An empty dict means nothing usable was sent;
                 ``err`` then explains why.
        """
        vaild_dict = {}
        for key in self.required_fields:
            value = self.req_data.get(key)
            if value is None:
                return False, f'{key} is required'
            vaild_dict[key] = value

        for key in self.other_fields:
            value = self.req_data.get(key)
            if value is not None:
                vaild_dict[key] = value

        return vaild_dict, 'request data is empty'

    def _create(self):
        """Insert a new row from the validated data; return ``(ok, message)``."""
        vaild_data, err = self._get_vaild_values()
        if not vaild_data:
            return False, err
        try:
            instance = self.model(**vaild_data)
            db.session.add(instance)
            db.session.commit()
        except Exception:
            # A failed flush/commit leaves the session unusable until rolled back.
            db.session.rollback()
            return False, 'please correct fileds'
        self.insatance = instance
        return True, 'success'

    def _update(self):
        """Apply the validated data to ``insatance`` and commit; return ``(ok, message)``."""
        vaild_data, err = self._get_vaild_values()
        if not vaild_data:
            return False, err
        instance = self.insatance
        if instance is None:
            return False, 'not find data'
        try:
            for key, value in vaild_data.items():
                setattr(instance, key, value)
            db.session.commit()
        except Exception:
            db.session.rollback()
            return False, 'please correct fileds'
        return True, 'success'

    def save(self):
        """Create (no ``insatance``) or update; abort with HTTP 400 on failure."""
        if self.insatance is None:
            ret, msg = self._create()
        else:
            ret, msg = self._update()
        if ret is False:
            abort(400, msg=msg)

    @property
    def data(self):
        """Saved row as a dict limited to ``serializer_fields``; HTTP 500 if nothing was saved."""
        if self.insatance is None:
            abort(500, msg='data is not save')
        return self.insatance.to_dict(only=tuple(self.serializer_fields))

    def delete(self):
        """Delete ``insatance`` and commit; abort with HTTP 400 when missing or on failure."""
        if self.insatance is None:
            abort(400, msg='not find data')
        else:
            try:
                db.session.delete(self.insatance)
                db.session.commit()
            except Exception:
                db.session.rollback()
                abort(400, msg='delete exception')


class Serialize(object):
    """
    Serializer: turn model instances into a list of dicts.

    Subclasses set ``model`` and ``fields`` (a list of keys, or "__all__" for
    every serializable key). Optional ``build_fiels`` adds extra output keys;
    each entry is a dict with ``name`` plus either ``source`` (a dotted
    attribute path such as ``'ux.name'``) or ``method: True`` (the value comes
    from ``self.get_<name>(instance)``).
    """
    model = None
    fields = "__all__"
    modelsDatas = []
    many = True
    date_format = "%Y-%m-%d"
    datetime_format = "%Y-%m-%d %H:%M:%S"
    time_format = "%H:%M:%S"

    build_fiels = []

    def __init__(self, serializers, many=True):
        """
        :param serializers: iterable of instances when ``many`` is True,
                            otherwise a single instance.
        :param many: whether ``serializers`` holds several instances.
        """
        self.modelsDatas = serializers
        self.many = many

    @property
    def data(self):
        """
        Serialized output, always a list (one element when ``many`` is False).

        Any error while serializing is logged with its traceback and turned
        into an HTTP 500 via ``abort``.
        """
        try:
            if self.many:
                return [self._serialize_one(item) for item in self.modelsDatas]
            return [self._serialize_one(self.modelsDatas)]
        except Exception:
            import logging
            logging.getLogger(__name__).exception('serialize error')
            abort(500, msg='serialize error')

    @property
    def serializer_fields(self):
        """Model keys to serialize: ``fields``, or all serializable keys for "__all__"."""
        return self.fields if self.fields != "__all__" else self.model().serializable_keys

    def _serialize_one(self, instance):
        """Serialize one instance: its ``to_dict`` output plus the build fields."""
        da = instance.to_dict(only=tuple(self.serializer_fields), date_format=self.date_format,
                              datetime_format=self.datetime_format, time_format=self.time_format)
        da.update(self._get_build_files_values(instance))
        return da

    def _get_build_files_values(self, data):
        """
        Compute the extra ``build_fiels`` values for one instance.

        :param data: the model instance being serialized.
        :return: dict mapping each build field's ``name`` to its value; a
                 ``source`` path that hits None yields None.
        """
        dit = {}
        for build in self.build_fiels:
            name = build['name']
            if build.get('method'):
                dit[name] = getattr(self, f"get_{name}")(data)
                continue
            value = data
            for source in build['source'].split('.'):
                value = getattr(value, source, None)
                # Stop only on a missing link. An object that is present but
                # falsy (e.g. an empty collection) must not cut the path short.
                if value is None:
                    break
            dit[name] = value

        return dit


class ParseQuery(object):
    """
    Build a filtered and ordered query for ``model`` from request parameters.

    ``filter_list`` names the accepted filter keys, each written
    ``<field>__<op>`` with ``<op>`` one of ``filer_query``; the matching value
    is looked up in ``req_data``. ``order_by`` is a field name or an iterable
    of field names, a leading ``-`` meaning descending order.
    """
    filer_query = frozenset(['gt', 'ge', 'lt', 'le', 'ne', 'eq', 'ic', 'ni', 'in'])

    def __init__(self, model, req_data, filter_list=None, order_by=None):
        """
        :param model: model class whose attributes are filtered/ordered and
                      which exposes ``query``.
        :param req_data: mapping of request parameters (needs ``.get``).
        :param filter_list: accepted ``<field>__<op>`` keys; None means no filters.
        :param order_by: field name or iterable of names, optionally '-'-prefixed.
        """
        self.model = model
        self.req_data = req_data
        # Avoid a shared mutable default argument.
        self.filter_list = [] if filter_list is None else filter_list
        self.order_by = order_by

        self._operator_funcs = {
            'gt': self.__gt_model,
            'ge': self.__ge_model,
            'lt': self.__lt_model,
            'le': self.__le_model,
            'ne': self.__ne_model,
            'eq': self.__eq_model,
            'ic': self.__ic_model,
            'ni': self.__ni_model,
            'in': self.__in_model,
        }

    @staticmethod
    def _is_blank(value):
        """True for None and for empty strings/collections; False for 0 or False."""
        if value is None:
            return True
        try:
            return len(value) == 0
        except TypeError:
            return False

    @property
    def _filter_data(self):
        """Accepted filter keys that have a non-blank value and name a model attribute."""
        search_dict = {}
        for fit in self.filter_list:
            val = self.req_data.get(fit)
            key = fit.split('__')[0]
            # A blank check instead of truthiness lets 0 / False be filtered on.
            if not self._is_blank(val) and hasattr(self.model, key):
                search_dict[fit] = val
        return search_dict

    def _parse_fields(self):
        """Translate the filter data into a list of filter expressions."""
        li = []
        for search_key, value in self._filter_data.items():
            key, _, ope = search_key.partition('__')
            # Keys with no operator, an unknown one or extra '__' parts are
            # skipped instead of raising ValueError.
            if ope in self.filer_query:
                li.append(self._operator_funcs[ope](key=key, value=value))
        return li

    def _filter(self):
        """Apply the parsed filters and, if requested, the ordering."""
        query = self.model.query.filter(*self._parse_fields())
        if self.order_by:
            query = query.order_by(*self._parse_order_by())
        return query

    @property
    def query(self):
        """The filtered (and ordered) query."""
        return self._filter()

    def pagination_class(self, page_num=1, page_size=10, max_page_size=50, error_out=False):
        """
        Paginate ``query``.

        :return: ``(items, total)`` of the requested page.
        """
        pagin = self.query.paginate(
            page=page_num,
            per_page=page_size,
            error_out=error_out,
            max_per_page=max_page_size
        )
        return pagin.items, pagin.total

    def _parse_order_by(self):
        """
        Translate ``order_by`` into ordering clauses.

        A leading '-' selects descending order; unknown fields are skipped.
        :return: tuple of ``asc()`` / ``desc()`` clauses.
        """
        order_by = self.order_by
        if isinstance(order_by, str):
            # A bare string would otherwise be iterated character by character.
            order_by = [order_by]
        li = []
        for field in order_by:
            descending = field.startswith('-')
            column = self.__by_model(field[1:] if descending else field)
            if column is None:
                continue
            li.append(column.desc() if descending else column.asc())
        return tuple(li)

    def __by_model(self, key):
        """Return the model attribute used for ordering, or None if it does not exist."""
        return getattr(self.model, key, None)

    def __gt_model(self, key, value):
        """Filter: field > value."""
        return getattr(self.model, key) > value

    def __ge_model(self, key, value):
        """Filter: field >= value."""
        return getattr(self.model, key) >= value

    def __lt_model(self, key, value):
        """Filter: field < value."""
        return getattr(self.model, key) < value

    def __le_model(self, key, value):
        """Filter: field <= value."""
        return getattr(self.model, key) <= value

    def __eq_model(self, key, value):
        """Filter: field == value."""
        return getattr(self.model, key) == value

    def __ne_model(self, key, value):
        """Filter: field != value."""
        return getattr(self.model, key) != value

    def __ic_model(self, key, value):
        """
        Filter: field contains value (LIKE '%value%').

        NOTE(review): '%' and '_' inside ``value`` act as LIKE wildcards.
        """
        return getattr(self.model, key).like('%{}%'.format(value))

    def __ni_model(self, key, value):
        """Filter: field does not contain value (NOT LIKE '%value%')."""
        return getattr(self.model, key).notlike('%{}%'.format(value))

    def __in_model(self, key, value):
        """Filter: field is one of the values in ``value`` (SQL IN)."""
        return getattr(self.model, key).in_(value)

 

自定义序列化和反序列化后,接口将变得简单

 借助上述基类，类视图只需短短几行代码即可完成

posted @ 2023-06-15 15:02  Wchime  阅读(153)  评论(0)  编辑  收藏  举报