Spring Boot + MyBatis-Plus:自定义租户拦截
MyBatis-Plus 自带的租户拦截方案(TenantLineHandler)使用起来不够灵活,因此这里改为自定义拦截器实现。
自定义 MyBatis-Plus 拦截器
package com.minex.configure.tenantconfig; import com.baomidou.mybatisplus.core.mapper.BaseMapper; import com.baomidou.mybatisplus.core.toolkit.StringUtils; import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor; import org.apache.ibatis.executor.Executor; import org.apache.ibatis.mapping.BoundSql; import org.apache.ibatis.mapping.MappedStatement; import org.apache.ibatis.session.RowBounds; import org.apache.ibatis.session.ResultHandler; import java.lang.reflect.Field; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.util.List; import java.util.Objects; import java.util.regex.Pattern; public class FirmCodeInterceptor implements InnerInterceptor { private static final Pattern FROM_OR_JOIN_PATTERN = Pattern.compile("\\bFROM\\b|\\bJOIN\\b", Pattern.CASE_INSENSITIVE); @Override public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) { String originalSql = boundSql.getSql(); if (shouldSkip(boundSql, ms)) { return; } String firmCode = TenantContext.getFirmCode(); if (StringUtils.isBlank(firmCode)) { return; } // 生成新的 SQL(添加 firm_code 条件) String newSql = new FirmCodeInterceptorUtil().addFirmCodeCondition(originalSql, firmCode); // 通过反射修改 BoundSql 中的 SQL try { Field sqlField = boundSql.getClass().getDeclaredField("sql"); sqlField.setAccessible(true); sqlField.set(boundSql, newSql); } catch (Exception e) { throw new RuntimeException("Failed to modify SQL for tenant isolation", e); } } /** * 判断实体类是否存在 firmCode 字段 */ private boolean hasFirmCodeField(Class<?> entityClass) { try { entityClass.getDeclaredField("firmCode"); return true; } catch (NoSuchFieldException e) { return false; } } /** * 获取当前操作的实体类 */ private Class<?> getEntityClass(MappedStatement ms) { String id = ms.getId(); int lastDotIndex = id.lastIndexOf('.'); if (lastDotIndex <= 0) { return null; } String mapperName = id.substring(0, lastDotIndex); try { Class<?> 
mapperInterface = Class.forName(mapperName); // 获取 BaseMapper 的泛型参数 for (Type genericInterface : mapperInterface.getGenericInterfaces()) { if (genericInterface instanceof ParameterizedType) { ParameterizedType parameterizedType = (ParameterizedType) genericInterface; Type rawType = parameterizedType.getRawType(); if (rawType instanceof Class<?> && BaseMapper.class.isAssignableFrom((Class<?>) rawType)) { Type[] actualTypeArguments = parameterizedType.getActualTypeArguments(); if (actualTypeArguments.length > 0 && actualTypeArguments[0] instanceof Class<?>) { return (Class<?>) actualTypeArguments[0]; } } } } // 递归查找父接口 return findBaseMapperEntityTypeRecursively(mapperInterface); } catch (ClassNotFoundException e) { return null; } } // 递归查找 BaseMapper 泛型参数 private Class<?> findBaseMapperEntityTypeRecursively(Class<?> mapperInterface) { for (Type genericInterface : mapperInterface.getGenericInterfaces()) { if (genericInterface instanceof ParameterizedType) { ParameterizedType pt = (ParameterizedType) genericInterface; Type rawType = pt.getRawType(); if (rawType instanceof Class<?> && BaseMapper.class.isAssignableFrom((Class<?>) rawType)) { Type[] args = pt.getActualTypeArguments(); if (args.length > 0 && args[0] instanceof Class<?>) { return (Class<?>) args[0]; } } } } // 递归处理父类 Type superClass = mapperInterface.getGenericSuperclass(); if (superClass instanceof ParameterizedType) { return findBaseMapperEntityTypeRecursively((Class<?>) ((ParameterizedType) superClass).getRawType()); } return null; } /** * 跳过不需要租户隔离的 MappedStatement */ private boolean shouldSkip(BoundSql boundSql, MappedStatement ms) { String sql = boundSql.getSql().toLowerCase(); // 获取当前操作的实体类 Class<?> entityClass = getEntityClass(ms); // 存在租户筛选条件 boolean condition1 = sql.contains("firm_code ="); // 不存在实体类或者实体类不存在firmCode字段 boolean condition2 = Objects.isNull(entityClass) || !hasFirmCodeField(entityClass); // 不操作复杂sql List<String> list = List.of(" union ", " join ", " with "); boolean condition3 = 
list.stream().anyMatch(sql::contains); // 忽略表 List<String> ignoreTable = List.of(" user "); boolean condition4 = ignoreTable.stream().anyMatch(sql::contains); if (condition1 || condition2 || condition3 || condition4) { return true; } return false; } }
package com.minex.configure.tenantconfig; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.operators.conditional.AndExpression; import net.sf.jsqlparser.expression.operators.relational.EqualsTo; import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.statement.Statement; import net.sf.jsqlparser.statement.select.*; import net.sf.jsqlparser.statement.update.Update; import net.sf.jsqlparser.statement.delete.Delete; import net.sf.jsqlparser.expression.*; import net.sf.jsqlparser.schema.Column; public class FirmCodeInterceptorUtil { private static final String TENANT_CONDITION = "firm_code = '{firmCode}'"; /** * 动态添加 firm_code 条件(使用 JSQLParser 解析 SQL) */ public String addFirmCodeCondition(String originalSql, String firmCode) { try { Statement stmt = CCJSqlParserUtil.parse(originalSql); if (stmt instanceof Select) { return handleSelect((Select) stmt, firmCode); } else if (stmt instanceof Update) { return handleUpdate((Update) stmt, firmCode); } else if (stmt instanceof Delete) { return handleDelete((Delete) stmt, firmCode); } else { // 不支持的 SQL 类型,直接返回原语句 return originalSql; } } catch (JSQLParserException e) { // 解析失败时返回原始 SQL return originalSql; } } /** * 处理 SELECT 语句 */ private String handleSelect(Select select, String firmCode) { SelectBody selectBody = select.getSelectBody(); if (selectBody instanceof PlainSelect) { PlainSelect plainSelect = (PlainSelect) selectBody; Expression where = plainSelect.getWhere(); // 构建新的 WHERE 条件 Expression newCondition = new EqualsTo(); ((EqualsTo) newCondition).setLeftExpression(new Column("firm_code")); ((EqualsTo) newCondition).setRightExpression(new StringValue(firmCode)); if (where == null) { // 没有 WHERE,添加 WHERE 1=1 AND ... plainSelect.setWhere(new AndExpression( trueCondition(), newCondition )); } else { // 有 WHERE,追加 AND ... 
plainSelect.setWhere(new AndExpression(where, newCondition)); } // 更新后的 SQL return select.toString(); } else { // 复杂 SELECT(如 UNION),不处理 return select.toString(); } } /** * 处理 UPDATE 语句 */ private String handleUpdate(Update update, String firmCode) { Expression where = update.getWhere(); // 构建新的 WHERE 条件 Expression newCondition = new EqualsTo(); ((EqualsTo) newCondition).setLeftExpression(new Column("firm_code")); ((EqualsTo) newCondition).setRightExpression(new StringValue(firmCode)); if (where == null) { // 没有 WHERE,添加 WHERE 1=1 AND ... update.setWhere(new AndExpression( trueCondition(), newCondition )); } else { // 有 WHERE,追加 AND ... update.setWhere(new AndExpression(where, newCondition)); } return update.toString(); } /** * 处理 DELETE 语句 */ private String handleDelete(Delete delete, String firmCode) { Expression where = delete.getWhere(); // 构建新的 WHERE 条件 Expression newCondition = new EqualsTo(); ((EqualsTo) newCondition).setLeftExpression(new Column("firm_code")); ((EqualsTo) newCondition).setRightExpression(new StringValue(firmCode)); if (where == null) { // 没有 WHERE,添加 WHERE 1=1 AND ... delete.setWhere(new AndExpression( trueCondition(), newCondition )); } else { // 有 WHERE,追加 AND ... delete.setWhere(new AndExpression(where, newCondition)); } return delete.toString(); } private Expression trueCondition() { Expression trueCondition = new EqualsTo(); ((EqualsTo) trueCondition).setLeftExpression(new LongValue(1)); ((EqualsTo) trueCondition).setRightExpression(new LongValue(1)); return trueCondition; } }
package com.minex.configure.tenantconfig; import com.alibaba.fastjson.JSON; import com.minex.common.util.StringUtils; import com.minex.web.auth.entity.CurrentUser; import com.minex.web.common.util.SecurityUtils; import com.minex.web.tenant.entity.dto.TenantDTO; import com.minex.web.tenant.service.TenantService; import io.swagger.annotations.ApiOperation; import org.aspectj.lang.ProceedingJoinPoint; import org.aspectj.lang.annotation.Around; import org.aspectj.lang.annotation.Aspect; import org.aspectj.lang.annotation.Pointcut; import org.springframework.context.annotation.Lazy; import org.springframework.stereotype.Component; import org.springframework.web.context.request.RequestContextHolder; import org.springframework.web.context.request.ServletRequestAttributes; import javax.annotation.Resource; import javax.servlet.http.HttpServletRequest; import java.security.Security; import java.util.Objects; @Aspect @Component public class TenantAspect { @Resource @Lazy private TenantService tenantService; @Pointcut("@annotation(io.swagger.annotations.ApiOperation)") public void tenantPointCut() {} @Around("tenantPointCut()") public Object around(ProceedingJoinPoint joinPoint) throws Throwable { String firmCode = ObtainFirmCode(); if (StringUtils.isNotBlank(firmCode)) { TenantContext.setFirmCode(firmCode); } try { return joinPoint.proceed(); } finally { TenantContext.clear(); // 清理上下文 } } private String ObtainFirmCode() { HttpServletRequest request = ((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest(); String firmCode = null; CurrentUser loginUser = null; try { loginUser = SecurityUtils.getLoginUser(); } catch (Exception ignored) { } if (Objects.nonNull(loginUser)) { firmCode = tenantService.getUserSelectedTenantInfo(loginUser).getCode(); } // 从请求参数中提取租户信息 String firmCodeFromParam = request.getParameter("firmCode"); if (StringUtils.isNotBlank(firmCodeFromParam)) { firmCode = firmCodeFromParam; } return firmCode; } }
package com.minex.configure.tenantconfig;

import com.alibaba.ttl.TransmittableThreadLocal;

/**
 * Static holder for the tenant code ("firm code") of the current thread.
 *
 * <p>The value lives in a {@code TransmittableThreadLocal} (Alibaba TTL) exposed through the
 * plain {@code ThreadLocal} API. Whether it reaches pooled worker threads depends on how those
 * pools are wrapped with TTL; that cannot be confirmed from here.
 */
public class TenantContext {

    /** Tenant code for the current thread; {@code null} when no tenant has been set. */
    private static final ThreadLocal<String> FIRM_CODE_HOLDER = new TransmittableThreadLocal<>();

    /**
     * Reads the current thread's tenant code.
     *
     * @return the tenant code, or {@code null} if none is set
     */
    public static String getFirmCode() {
        return FIRM_CODE_HOLDER.get();
    }

    /**
     * Stores the tenant code for the current thread.
     *
     * @param firmCode tenant code to store
     */
    public static void setFirmCode(String firmCode) {
        FIRM_CODE_HOLDER.set(firmCode);
    }

    /** Removes the current thread's tenant code. */
    public static void clear() {
        FIRM_CODE_HOLDER.remove();
    }
}
@Bean public MybatisPlusInterceptor mybatisPlusInterceptor() { MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor(); // 添加自定义租户隔离拦截器 interceptor.addInnerInterceptor(new FirmCodeInterceptor()); interceptor.addInnerInterceptor(new PaginationInnerInterceptor(DbType.MYSQL)); return interceptor; }
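注意:
1. 切面拦截的是标注了 Swagger `@ApiOperation` 注解的方法;未标注该注解的方法不会设置租户上下文。
2. 注意 MyBatis-Plus 内部拦截器的注册顺序:租户拦截器(FirmCodeInterceptor)必须在分页拦截器(PaginationInnerInterceptor)之前添加,顺序不正确会导致租户拦截器不生效。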
⎛⎝官萧何⎠⎞一只快乐的爪哇程序猿;邮箱:1570608034@qq.com

浙公网安备 33010602011771号