Spring Boot + MyBatis-Plus:自定义租户拦截器

MyBatis-Plus 自带的租户插件(TenantLineHandler)用起来不够方便,因此这里改为自定义拦截器来实现租户隔离。

自定义 MyBatis-Plus 内部拦截器(实现 InnerInterceptor):

package com.minex.configure.tenantconfig;

import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.session.RowBounds;
import org.apache.ibatis.session.ResultHandler;

import java.lang.reflect.Field;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.regex.Pattern;

/**
 * MyBatis-Plus inner interceptor that appends a {@code firm_code = '<tenant>'} condition to query
 * SQL, using the tenant code bound to the current thread through {@link TenantContext}.
 *
 * <p>A statement is left untouched when no tenant is bound, when it already contains a
 * {@code firm_code =} condition, when it looks like complex SQL (UNION / JOIN / WITH), when it
 * references an ignored table, or when the mapper's entity type has no {@code firmCode} field
 * (or cannot be resolved).
 *
 * <p>NOTE(review): only {@link #beforeQuery} is overridden, so statements executed through
 * {@code Executor#update} (UPDATE / DELETE) presumably never reach this interceptor, even though
 * {@link FirmCodeInterceptorUtil} is able to rewrite them — consider overriding
 * {@code beforePrepare} as well.
 */
public class FirmCodeInterceptor implements InnerInterceptor {

    /** Matches any run of whitespace (spaces, tabs, line breaks). */
    private static final Pattern WHITESPACE = Pattern.compile("\\s+");

    /** Space-delimited markers of SQL considered too complex to rewrite. */
    private static final List<String> COMPLEX_SQL_KEYWORDS = List.of(" union ", " join ", " with ");

    /** Space-delimited table names that are never tenant-filtered. */
    private static final List<String> IGNORED_TABLES = List.of(" user ");

    /** The rewriter holds no per-call state, so one instance is shared by all invocations. */
    private static final FirmCodeInterceptorUtil SQL_REWRITER = new FirmCodeInterceptorUtil();

    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
        String firmCode = TenantContext.getFirmCode();
        // No tenant bound to this thread: nothing to filter by. Checked first because it is cheap,
        // while shouldSkip() may do class loading and reflection.
        if (StringUtils.isBlank(firmCode) || shouldSkip(boundSql, ms)) {
            return;
        }
        // Build the new SQL with the firm_code condition added.
        String newSql = SQL_REWRITER.addFirmCodeCondition(boundSql.getSql(), firmCode);

        // BoundSql offers no public setter for its SQL, hence the reflective write. The field is
        // looked up on BoundSql itself so that subclasses of BoundSql are handled too.
        try {
            Field sqlField = BoundSql.class.getDeclaredField("sql");
            sqlField.setAccessible(true);
            sqlField.set(boundSql, newSql);
        } catch (ReflectiveOperationException | SecurityException e) {
            throw new RuntimeException("Failed to modify SQL for tenant isolation", e);
        }
    }

    /**
     * Tells whether the entity class, or any of its superclasses, declares a {@code firmCode} field.
     * Walking the hierarchy matters because the field is often declared on a shared base entity.
     */
    private boolean hasFirmCodeField(Class<?> entityClass) {
        for (Class<?> type = entityClass; type != null && type != Object.class; type = type.getSuperclass()) {
            try {
                type.getDeclaredField("firmCode");
                return true;
            } catch (NoSuchFieldException ignored) {
                // Not declared on this level; keep walking up the hierarchy.
            }
        }
        return false;
    }

    /**
     * Resolves the entity type handled by the mapper that owns the statement.
     *
     * <p>The mapper interface name is the statement id up to its last dot.
     *
     * @return the entity class, or {@code null} if the mapper class cannot be loaded or does not
     *         extend {@link BaseMapper} with a concrete type argument
     */
    private Class<?> getEntityClass(MappedStatement ms) {
        String id = ms.getId();
        int lastDotIndex = id.lastIndexOf('.');
        if (lastDotIndex <= 0) {
            return null;
        }
        String mapperName = id.substring(0, lastDotIndex);
        try {
            return findBaseMapperEntityTypeRecursively(Class.forName(mapperName));
        } catch (ClassNotFoundException e) {
            return null;
        }
    }

    /**
     * Searches the interface hierarchy of {@code mapperInterface} for a parameterized
     * {@link BaseMapper} (or sub-interface) and returns its first concrete type argument.
     *
     * <p>Non-generic intermediate interfaces (e.g. {@code CustomMapper extends BaseMapper<Entity>})
     * are searched recursively.
     *
     * @return the entity class, or {@code null} if none is found
     */
    private Class<?> findBaseMapperEntityTypeRecursively(Class<?> mapperInterface) {
        for (Type genericInterface : mapperInterface.getGenericInterfaces()) {
            if (genericInterface instanceof ParameterizedType) {
                ParameterizedType pt = (ParameterizedType) genericInterface;
                Type rawType = pt.getRawType();
                if (rawType instanceof Class<?> && BaseMapper.class.isAssignableFrom((Class<?>) rawType)) {
                    Type[] args = pt.getActualTypeArguments();
                    if (args.length > 0 && args[0] instanceof Class<?>) {
                        return (Class<?>) args[0];
                    }
                }
            } else if (genericInterface instanceof Class<?>
                    && BaseMapper.class.isAssignableFrom((Class<?>) genericInterface)) {
                // Interfaces have no superclass, so parent interfaces must be searched explicitly.
                Class<?> found = findBaseMapperEntityTypeRecursively((Class<?>) genericInterface);
                if (found != null) {
                    return found;
                }
            }
        }
        return null;
    }

    /**
     * Decides whether the statement should be left without tenant isolation.
     *
     * @return {@code true} if the SQL already filters on firm_code, is complex (UNION / JOIN /
     *         WITH), touches an ignored table, or its entity has no firmCode field
     */
    private boolean shouldSkip(BoundSql boundSql, MappedStatement ms) {
        // Lower-case with Locale.ROOT (a Turkish default locale would turn "JOIN" into "joın").
        // Also collapse whitespace runs to single spaces and pad both ends, so that the
        // space-delimited markers below match tokens separated by tabs/newlines (as in XML mapper
        // SQL) or sitting at the very start or end of the statement.
        String sql = " " + WHITESPACE.matcher(boundSql.getSql().toLowerCase(Locale.ROOT)).replaceAll(" ") + " ";
        // A tenant filter is already present.
        if (sql.contains("firm_code =")) {
            return true;
        }
        // Complex SQL is not rewritten.
        if (COMPLEX_SQL_KEYWORDS.stream().anyMatch(sql::contains)) {
            return true;
        }
        // Ignored tables.
        if (IGNORED_TABLES.stream().anyMatch(sql::contains)) {
            return true;
        }
        // The entity is unknown or is not tenant-scoped.
        Class<?> entityClass = getEntityClass(ms);
        return Objects.isNull(entityClass) || !hasFirmCodeField(entityClass);
    }
}

 

package com.minex.configure.tenantconfig;

import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.expression.*;
import net.sf.jsqlparser.schema.Column;

/**
 * Rewrites SELECT / UPDATE / DELETE statements with JSqlParser so that they are restricted to a
 * single tenant through a {@code firm_code = '<tenant>'} condition.
 */
public class FirmCodeInterceptorUtil {

    /** Name of the column the tenant condition is built on. */
    private static final String TENANT_COLUMN = "firm_code";

    /**
     * Adds the {@code firm_code} condition to {@code originalSql} (parsed with JSqlParser).
     *
     * @param originalSql SQL to rewrite
     * @param firmCode    tenant code; must be non-null and contain neither {@code '} nor {@code \}
     * @return the rewritten SQL; {@code originalSql} unchanged when it cannot be parsed or is not a
     *         SELECT / UPDATE / DELETE
     * @throws IllegalArgumentException if {@code firmCode} is null or contains a quote or backslash
     */
    public String addFirmCodeCondition(String originalSql, String firmCode) {
        // The code is embedded as a SQL literal and may come from a request parameter, so reject
        // anything that could terminate or escape the literal (SQL injection).
        checkFirmCode(firmCode);
        Statement stmt;
        try {
            stmt = CCJSqlParserUtil.parse(originalSql);
        } catch (JSQLParserException e) {
            // Best effort: SQL that JSqlParser cannot parse is returned unchanged (NOT tenant-filtered).
            return originalSql;
        }
        if (stmt instanceof Select) {
            return handleSelect((Select) stmt, firmCode);
        } else if (stmt instanceof Update) {
            return handleUpdate((Update) stmt, firmCode);
        } else if (stmt instanceof Delete) {
            return handleDelete((Delete) stmt, firmCode);
        }
        // Unsupported SQL type: return the original statement.
        return originalSql;
    }

    /** Rejects tenant codes that cannot be embedded safely in a single-quoted SQL literal. */
    private static void checkFirmCode(String firmCode) {
        if (firmCode == null || firmCode.indexOf('\'') >= 0 || firmCode.indexOf('\\') >= 0) {
            throw new IllegalArgumentException("Illegal firm code for tenant condition: " + firmCode);
        }
    }

    /**
     * Returns {@code where AND firm_code = '<firmCode>'}, or just the tenant condition when there
     * is no existing WHERE clause.
     */
    private Expression appendTenantCondition(Expression where, String firmCode) {
        EqualsTo tenantCondition = new EqualsTo();
        tenantCondition.setLeftExpression(new Column(TENANT_COLUMN));
        // StringValue expects an already-quoted literal: it strips the surrounding quotes and
        // prints the remainder verbatim (no escaping), hence the prior validation.
        tenantCondition.setRightExpression(new StringValue("'" + firmCode + "'"));
        if (where == null) {
            return tenantCondition;
        }
        // Parenthesize the existing filter. Otherwise "a OR b" + tenant would print as
        // "a OR b AND firm_code = 'x'", which binds as "a OR (b AND ...)" and leaks other tenants' rows.
        Expression existing = where instanceof Parenthesis ? where : new Parenthesis(where);
        return new AndExpression(existing, tenantCondition);
    }

    /**
     * Handles a SELECT statement. Only a plain SELECT is rewritten; other bodies (e.g. UNION) are
     * re-serialized unchanged.
     */
    private String handleSelect(Select select, String firmCode) {
        SelectBody selectBody = select.getSelectBody();
        if (selectBody instanceof PlainSelect) {
            PlainSelect plainSelect = (PlainSelect) selectBody;
            plainSelect.setWhere(appendTenantCondition(plainSelect.getWhere(), firmCode));
        }
        return select.toString();
    }

    /** Handles an UPDATE statement. */
    private String handleUpdate(Update update, String firmCode) {
        update.setWhere(appendTenantCondition(update.getWhere(), firmCode));
        return update.toString();
    }

    /** Handles a DELETE statement. */
    private String handleDelete(Delete delete, String firmCode) {
        delete.setWhere(appendTenantCondition(delete.getWhere(), firmCode));
        return delete.toString();
    }
}

 

package com.minex.configure.tenantconfig;

import com.alibaba.fastjson.JSON;
import com.minex.common.util.StringUtils;
import com.minex.web.auth.entity.CurrentUser;
import com.minex.web.common.util.SecurityUtils;
import com.minex.web.tenant.entity.dto.TenantDTO;
import com.minex.web.tenant.service.TenantService;
import io.swagger.annotations.ApiOperation;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.springframework.context.annotation.Lazy;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.annotation.Resource;
import javax.servlet.http.HttpServletRequest;
import java.security.Security;
import java.util.Objects;

/**
 * Binds the current tenant code to {@link TenantContext} around every method annotated with
 * Swagger's {@code @ApiOperation}, so that SQL executed during the call can be tenant-filtered.
 */
@Aspect
@Component
public class TenantAspect {

    @Resource
    @Lazy
    private TenantService tenantService;

    /** Matches every method annotated with {@code io.swagger.annotations.ApiOperation}. */
    @Pointcut("@annotation(io.swagger.annotations.ApiOperation)")
    public void tenantPointCut() {}

    /**
     * Resolves the tenant, binds it for the duration of the call, then restores whatever binding
     * was in place before the call.
     */
    @Around("tenantPointCut()")
    public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
        // An enclosing @ApiOperation call may already have bound a tenant. Restore it afterwards
        // instead of clearing, otherwise the rest of the outer call would run with no tenant.
        String previousFirmCode = TenantContext.getFirmCode();
        String firmCode = obtainFirmCode();
        if (StringUtils.isNotBlank(firmCode)) {
            TenantContext.setFirmCode(firmCode);
        }
        try {
            return joinPoint.proceed();
        } finally {
            if (previousFirmCode == null) {
                TenantContext.clear(); // clean up the context
            } else {
                TenantContext.setFirmCode(previousFirmCode);
            }
        }
    }

    /**
     * Determines the tenant code for the current call.
     *
     * <p>The starting value is the tenant selected by the logged-in user, if any. A non-blank
     * {@code firmCode} request parameter, when a servlet request is bound to the thread, overrides it.
     *
     * <p>NOTE(review): because the request parameter overrides the authenticated user's tenant, a
     * caller can presumably switch tenants by adding {@code ?firmCode=...}. Confirm this is intended;
     * if not, validate the parameter against the user's tenants.
     *
     * @return the tenant code, or {@code null} if none could be determined
     */
    private String obtainFirmCode() {
        String firmCode = null;
        CurrentUser loginUser = null;
        try {
            loginUser = SecurityUtils.getLoginUser();
        } catch (Exception ignored) {
            // getLoginUser() threw (presumably no authenticated user): only the request parameter applies.
        }
        if (Objects.nonNull(loginUser)) {
            firmCode = tenantService.getUserSelectedTenantInfo(loginUser).getCode();
        }
        // getRequestAttributes() returns null when no request is bound to the current thread
        // (e.g. the annotated method is called from a background task); the original code threw an NPE there.
        ServletRequestAttributes attributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        if (attributes != null) {
            HttpServletRequest request = attributes.getRequest();
            // Extract tenant information from the request parameters.
            String firmCodeFromParam = request.getParameter("firmCode");
            if (StringUtils.isNotBlank(firmCodeFromParam)) {
                firmCode = firmCodeFromParam;
            }
        }
        return firmCode;
    }
}

 

package com.minex.configure.tenantconfig;

import com.alibaba.ttl.TransmittableThreadLocal;

/**
 * Static holder for the tenant code ({@code firmCode}) of the current thread, backed by a
 * {@link TransmittableThreadLocal}.
 */
public class TenantContext {

    /** Per-thread storage for the current tenant code. */
    private static final ThreadLocal<String> FIRM_CODE_HOLDER = new TransmittableThreadLocal<>();

    /**
     * Returns the tenant code bound to the current thread.
     *
     * @return the bound code, or {@code null} if nothing is bound
     */
    public static String getFirmCode() {
        return FIRM_CODE_HOLDER.get();
    }

    /**
     * Binds a tenant code to the current thread, replacing any previous value.
     *
     * @param firmCode the tenant code to bind
     */
    public static void setFirmCode(String firmCode) {
        FIRM_CODE_HOLDER.set(firmCode);
    }

    /** Removes the tenant code bound to the current thread. */
    public static void clear() {
        FIRM_CODE_HOLDER.remove();
    }
}

 

/**
 * Builds the MyBatis-Plus interceptor chain: the custom tenant interceptor followed by MySQL pagination.
 */
@Bean
    public MybatisPlusInterceptor mybatisPlusInterceptor() {
        MybatisPlusInterceptor chain = new MybatisPlusInterceptor();
        // Registration order matters: the tenant interceptor is added before pagination.
        chain.addInnerInterceptor(new FirmCodeInterceptor());
        chain.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL));
        return chain;
    }

 

注意:

1. 切面拦截的是带有 Swagger `@ApiOperation` 注解的方法:只有在这些方法内执行的 SQL 才会设置租户上下文,从而可能被追加 firm_code 条件。

2. 注意 MyBatis-Plus 内部拦截器的注册顺序:租户拦截器要先于分页拦截器(PaginationInnerInterceptor)添加,如上面的配置所示;顺序不正确会导致租户拦截器不生效。

posted @ 2025-06-17 15:39  官萧何  阅读(66)  评论(0)    收藏  举报