Loading

基于 SQL 解析的 JPA 多租户方案

概述

最近在对一个使用 JPA 的老项目进行多租户改造,由于年代过于久远,陈年屎山让人实在不敢轻举妄动,最后只能选择一个改造成本最小的方案,那就是通过拦截器改 SQL,动态添加租户 ID 作为查询条件。
本篇文章用于记录笔者基于该方案解决此问题的踩坑和思考过程,部分代码与实际代码有所出入。如果希望直接获取可运行的代码,可以直接在 github 仓库获取。

1.SQL 拦截器

由于 JPA 底层是基于 Hibernate 实现的,而 Hibernate 本身提供了 StatementInspector 接口用于实现 SQL 拦截。因此我们只需要在这个阶段对 SQL 进行解析,然后为需要按租户进行隔离的资源表动态的添加租户 ID 的过滤条件即可。
这里我们选择使用 JSqlParser 作为我们的 SQL 解析器。它社区还算活跃,文档详细,最重要的是,API 比较简单易懂。下文若不特意澄清,则所有与 SQL 解析相关的类都来自于它。

1.1.简单实现

在最开始,我们写一个简单的实现来验证一下可行性。
假设,我们需要指定拦截针对表 t_resource 的查询语句,为其添加 tenant_id = xxx 作为查询条件,那么这个 SQL 拦截器需要做到:

  1. 将 SQL 解析为 Statement 对象,然后检查其是否为查询类 SQL;
  2. 获取 SQL 的 from 语句,并判断查询的表是否为我们要拦截的表;
  3. 解析 where 语句:
    • 若原本没有任何条件,则为其生成一个 where t.tenant_id = xxx 的条件;
    • 如果原本已经有条件了,则为其在最后拼接 and t.tenant_id = xxx 的条件;

这里我们针对这个需求给出一个简单的实现:

@Slf4j
public class TenantSQLInterceptor {

    /**
     * 处理SQL语句
     *
     * @param sql SQL语句
     * @param table 要拦截器的租户表名
     * @param column 租户字段名
     * @param value 租户字段值
     * @return 处理后的SQL语句
     */
    public String handle(String sql, String table, String column, String value) {
        log.debug("租户拦截器拦截原始 SQL: {}", sql);
        String handledSql = doHandle(sql, table, column, value);
        log.info("租户拦截器拦截后 SQL: {}", handledSql);
        return Objects.isNull(handledSql) ? sql : handledSql;
    }

    /**
     * 处理SQL语句
     *
     * @param sql SQL语句
     * @return 处理后的SQL语句
     */
    @Nullable
    public String doHandle(String sql, String table, String column, String value) {
        Statements statements = parseStatements(sql);
        if (Objects.isNull(statements)) {
            return null;
        }
        List<Statement> statementList = statements.getStatements();
        if (CollUtil.isEmpty(statementList)) {
            return null;
        }
        return statements.getStatements().stream()
            .map(statement -> doHandle(statement, table, column, value))
            .map(Statement::toString)
            .collect(Collectors.joining(";"));
    }

    @Nullable
    private Statements parseStatements(String sql) {
        Statements statements = null;
        try {
            statements = CCJSqlParserUtil.parseStatements(sql);
            return statements;
        } catch (JSQLParserException e) {
            log.error("SQL 解析失败: {}", sql, e);
            throw new CloudPluginException(ResultCodingEnum.SchedulingError, "SQL 解析失败");
        }
    }

    private Statement doHandle(Statement statement, String table, String column, String value) {
        if (!(statement instanceof Select)) {
            return statement;
        }
        try {
            SelectBody selectBody = ((Select) statement).getSelectBody();
            // 目前只处理普通的 SQL 查询
            if (selectBody instanceof PlainSelect) {
                PlainSelect plainSelect = (PlainSelect) selectBody;
                FromItem fromItem = plainSelect.getFromItem();
                Expression where = plainSelect.getWhere();

                // 如果查询的表即为要拦截的租户表,则为查询条件添加租户条件
                if (fromItem instanceof Table) {
                    String queryTable = ((Table) fromItem).getName();
                    if (Objects.equals(queryTable, table)) {
                        where = appendTenantCondition(plainSelect.getWhere(), fromItem, value, column);
                        plainSelect.setWhere(where);
                    }
                }
            }
        } catch (Exception ex) {
            log.error("SQL 处理失败: {}", statement, ex);
            throw new RuntimeException("SQL 处理失败", ex);
        }
        return statement;
    }

    private static Expression appendTenantCondition(
        @Nullable Expression original, FromItem table, String tenantId, String tenantColumn) {
        // 生成一个 tenant_id = xxx 的条件
        EqualsTo equalsTo = new EqualsTo();
        equalsTo.setLeftExpression(getColumnWithTableAlias(table, tenantColumn));
        equalsTo.setRightExpression(new StringValue(tenantId));
        if (Objects.isNull(original)) {
            return equalsTo;
        }
        return original instanceof OrExpression ?
            new AndExpression(equalsTo, new Parenthesis(original)) :
            new AndExpression(original, equalsTo);
    }

    private static Column getColumnWithTableAlias(FromItem table, String column) {
        // 如果表存在别名,则字段应该变“表别名.字段名”的格式
        return Optional.ofNullable(table)
            .map(FromItem::getAlias)
            .map(alias -> alias.getName() + "." + column)
            .map(Column::new)
            .orElse(new Column(column));
    }
}

测试一下:

public static void main(String[] args) {
    String sql = "select * from t_resource r where r.order = 1";
    TenantSQLInterceptor tenantSQLInterceptor = new TenantSQLInterceptor();
    String handledSql = tenantSQLInterceptor.handle(sql, "t_resource", "tenant_id", "1");
    System.out.println(handledSql); // = SELECT * FROM resource r WHERE r.t_resource = 1 AND r.tenant_id = '1'
}

虽然还非常简陋,不过这个拦截器已经能够初步实现我们想要的功能了,不过要投入实际场景,显然还需要做出“一点点”改进。

1.2.从上下文获取租户信息

首先,真实的使用场景中,一个 SQL 可能会同时涉及到多张需要拦截器的表,并且每张表对应的租户 ID 仍然有可能不同,因此我们最好直接将相关的配置信息提取出来,改为通过一个上下文对象进行获取:

@Slf4j
public class TenantSQLInterceptor {

    private static final ThreadLocal<TenantInfo> TENANT_INFO_CONTEXT = new TransmittableThreadLocal<>();

    /**
     * 设置租户信息
     *
     * @param tenantInfo 租户信息
     */
    public static void setTenantInfo(TenantInfo tenantInfo) {
        TENANT_INFO_CONTEXT.set(tenantInfo);
    }

    /**
     * 清除租户信息
     */
    public static void clearTenantInfo() {
        TENANT_INFO_CONTEXT.remove();
    }


    public String handle(String sql) {
        // 如果未设置租户信息,则直接返回原始SQL
        TenantInfo tenantInfo = TENANT_INFO_CONTEXT.get();
        if (Objects.isNull(tenantInfo)) {
            return sql;
        }
        log.debug("租户拦截器拦截原始 SQL: {}", sql);
        String handledSql = doHandle(sql);
        log.info("租户拦截器拦截后 SQL: {}", handledSql);
        return Objects.isNull(handledSql) ? sql : handledSql;
    }

    /**
     * 处理SQL语句
     *
     * @param sql SQL语句
     * @return 处理后的SQL语句
     */
    @Nullable
    public String doHandle(String sql) {
        Statements statements = parseStatements(sql);
        if (Objects.isNull(statements)) {
            return null;
        }
        List<Statement> statementList = statements.getStatements();
        if (CollUtil.isEmpty(statementList)) {
            return null;
        }
        return statements.getStatements().stream()
            .map(this::doHandle)
            .map(Statement::toString)
            .collect(Collectors.joining(";"));
    }

    @Nullable
    private Statements parseStatements(String sql) {
        Statements statements = null;
        try {
            statements = CCJSqlParserUtil.parseStatements(sql);
            return statements;
        } catch (JSQLParserException e) {
            log.error("SQL 解析失败: {}", sql, e);
            throw new CloudPluginException(ResultCodingEnum.SchedulingError, "SQL 解析失败");
        }
    }

    private Statement doHandle(Statement statement) {
        if (!(statement instanceof Select)) {
            return statement;
        }
        try {
            SelectBody selectBody = ((Select) statement).getSelectBody();
            if (selectBody instanceof PlainSelect) {
                PlainSelect plainSelect = (PlainSelect) selectBody;
                FromItem fromItem = plainSelect.getFromItem();
                Expression where = plainSelect.getWhere();

                // 如果查询的表即为要拦截的租户表,则为查询条件添加租户条件
                if (fromItem instanceof Table) {
                    String queryTable = ((Table) fromItem).getName();
                    TenantInfo tenantInfo = TENANT_INFO_CONTEXT.get();
                    String tenantColumn = tenantInfo.tablesWithTenantColumn.get(queryTable);
                    if (Objects.nonNull(tenantColumn)) {
                        plainSelect.setWhere(appendTenantCondition(where, fromItem, tenantInfo.tenantId, tenantColumn));
                    }
                }
            }
        } catch (Exception ex) {
            log.error("SQL 处理失败: {}", statement, ex);
            throw new RuntimeException("SQL 处理失败", ex);
        }
        return statement;
    }

    private static Expression appendTenantCondition(
        @Nullable Expression original, FromItem table, String tenantId, String tenantColumn) {
        EqualsTo equalsTo = new EqualsTo();
        equalsTo.setLeftExpression(getColumnWithTableAlias(table, tenantColumn));
        equalsTo.setRightExpression(new StringValue(tenantId));
        if (Objects.isNull(original)) {
            return equalsTo;
        }
        return original instanceof OrExpression ?
            new AndExpression(equalsTo, new Parenthesis(original)) :
            new AndExpression(original, equalsTo);
    }

    private static Column getColumnWithTableAlias(FromItem table, String column) {
        // 如果表存在别名,则字段应该变“表别名.字段名”的格式
        return Optional.ofNullable(table)
            .map(FromItem::getAlias)
            .map(alias -> alias.getName() + "." + column)
            .map(Column::new)
            .orElse(new Column(column));
    }

    /**
     * 租户信息
     */
    @RequiredArgsConstructor
    public static class TenantInfo {
        /**
         * 租户ID
         */
        private final String tenantId;
        /**
         * 要添加租户条件的表名称与对应的租户字段
         */
        private final Map<String, String> tablesWithTenantColumn;
    }
}

1.3.复杂 SQL 的解析

在实际场景中,尤其是涉及到手写 SQL 的场景中,SQL 往往比较复杂,比如:

  • 查询可能基于一张虚拟表,比如: select * from (selecrt from t1 where t1.id = xx) t2 这种情况。
  • 可能会存在关联查,比如: select * from t1 left join t2 on t2.id = t1.tid 这种情况。
  • 可能会涉及到子查询,比如:select * from t where t.id in (select t2.tid from t2 where t2.id = xxx) 这种情况。

除上述这几种情况外,我们还需要考虑各种组合的场景,比如 union 类型的联合查询,函数与子查询的嵌套,基于虚拟表的联查……等等。

1.3.1.改进方案

虽然情况有很多种,不过值得高兴的是,我们还是有办法为其归纳出一个处理流程。简单的来说,就是检查所有可能存在嵌套查询的语句,进行递归解析:

  1. 第一步,先解析语句本身,如果是 union 这种联合查询,则将其拆分为多条单体 SQL 进行递归解析;
  2. 第二步,对于单条 SQL,解析其 select 的字段,如果存在函数或者子查询,则将每个字段其作为一个单体 SQL 进行递归解析;
  3. 第三步,解析其 from 语句,如果存在函数或者基于子查询的临时表,则将子查询作为一个单体 SQL 进行递归解析;
  4. 第四步,解析 join 语句:
    1. 如果 join 的表本身是基于子查询的临时表,则将子查询作为一个单体 SQL 进行递归解析;
    2. 如果 on 条件中存在函数或者子查询,则将其作为单体 SQL 进行递归解析;
  5. 第五步,解析 where 条件,如果存在函数或者基于子查询的条件字段,则将其作为一个单体 SQL 进行递归解析。

基于上述分析,我们需要对现有的代码做出一点调整:

  • 在 doHandle 方法中,我们需要判断 fromItem 的类型,如果是子查询,则需要进行递归处理。
  • 在 doHandle 方法后,我们需要新增一部分对 join 语句的处理,由于 join 语句同样由 from 和 where 两部分组成,因此此处的逻辑应当与正常的 select 差不多。
  • 在 appendTenantCondition 方法之前,我们需要增加对特殊条件的处理,对应每个条件,我们都需要检查是否存在可能的子查询,如果存则需要进行递归处理。

1.3.2.改进后的代码

根据改进方案,我们再次调整代码:

@Slf4j
public class TenantSQLInterceptor {
    
    private static final ThreadLocal<TenantInfo> TENANT_INFO_CONTEXT = new TransmittableThreadLocal<>();

    /**
     * 设置租户信息
     *
     * @param tenantInfo 租户信息
     */
    public static void setTenantInfo(TenantInfo tenantInfo) {
        TENANT_INFO_CONTEXT.set(tenantInfo);
    }

    /**
     * 清除租户信息
     */
    public static void clearTenantInfo() {
        TENANT_INFO_CONTEXT.remove();
    }
    
    /**
     * 处理SQL语句
     *
     * @param sql SQL语句
     * @return 处理后的SQL语句
     */
    @Nullable
    public String handle(String sql) {
        Statements statements = parseStatements(sql);
        if (Objects.isNull(statements)) {
            return null;
        }
        List<Statement> statementList = statements.getStatements();
        if (CollUtil.isEmpty(statementList)) {
            return null;
        }
        return statements.getStatements().stream()
            .map(this::doHandle)
            .map(Statement::toString)
            .collect(Collectors.joining(";"));
    }

    @Nullable
    private Statements parseStatements(String sql) {
        Statements statements = null;
        try {
            statements = CCJSqlParserUtil.parseStatements(sql);
        } catch (JSQLParserException e) {
            log.error("SQL 解析失败: {}", sql, e);
            throw new RuntimeException("SQL 解析失败", e);
        }
        return statements;
    }

    private Statement doHandle(Statement statement) {
        try {
            if (statement instanceof Select) {
                processSelect(((Select) statement).getSelectBody());
            } else if (statement instanceof Update) {
                processUpdate((Update) statement);
            } else if (statement instanceof Delete) {
                processDelete((Delete) statement);
            } else if (statement instanceof Insert) {
                processInsert((Insert) statement);
            }
        } catch (Exception ex) {
            log.error("SQL 处理失败: {}", statement, ex);
            throw new RuntimeException("SQL 处理失败", ex);
        }
        return statement;
    }

    private void processSelect(SelectBody selectBody) {
        // 普通查询
        if (selectBody instanceof PlainSelect) {
            processSelect((PlainSelect) selectBody);
        }
        // 嵌套查询,比如 select xx from (select yy from t)
        else if (selectBody instanceof WithItem) {
            WithItem withItem = (WithItem) selectBody;
            if (withItem.getSelectBody() != null) {
                processSelect(withItem.getSelectBody());
            }
        }
        // 联合查询,比如 union
        else if (selectBody instanceof SetOperationList) {
            SetOperationList operationList = (SetOperationList) selectBody;
            if (CollUtil.isNotEmpty(operationList.getSelects())) {
                operationList.getSelects().forEach(this::processSelect);
            }
        }
        // 值查询,比如 select 1, 2, 3
        else if (selectBody instanceof ValuesStatement) {
            List<Expression> expressions = ((ValuesStatement) selectBody).getExpressions();
            if (CollUtil.isNotEmpty(expressions)) {
                expressions.forEach(exp -> processCondition(exp, null));
            }
        } else {
            log.error("无法解析的 select 语句:{}({})", selectBody, selectBody.getClass());
            throw new RuntimeException("不支持的查询语句:" + selectBody.getClass().getName()
        }
    }

    /**
     * 处理插入语句
     *
     * @param insert 插入语句
     */
    protected void processInsert(Insert insert) {
        // do nothing
    }

    /**
     * 处理删除语句
     *
     * @param delete 删除语句
     */
    protected void processDelete(Delete delete) {
        Table table = delete.getTable();
        delete.setWhere(processCondition(delete.getWhere(), table));
        // 如果还存在关联查询
        List<Join> joins = delete.getJoins();
        if (CollUtil.isNotEmpty(joins)) {
            joins.forEach(this::processJoin);
        }
    }

    /**
     * 处理更新语句
     *
     * @param update 更新语句
     */
    protected void processUpdate(Update update) {
        Table table = update.getTable();
        update.setWhere(processCondition(update.getWhere(), table));
        // 如果还存在关联查询
        List<Join> joins = update.getJoins();
        if (CollUtil.isNotEmpty(joins)) {
            joins.forEach(this::processJoin);
        }
    }

    /**
     * 处理查询语句
     *
     * @param plainSelect 查询语句
     */
    protected void processSelect(PlainSelect plainSelect) {
        FromItem fromItem = plainSelect.getFromItem();
        // 如果是普通的表名
        if (fromItem instanceof Table) {
            Table fromTable = (Table) fromItem;
            plainSelect.setWhere(processCondition(plainSelect.getWhere(), fromTable));
        }
        // 如果是子查询,比如 select * from (select xxx from yyy)
        else if (fromItem instanceof SubSelect) {
            SubSelect subSelect = (SubSelect) fromItem;
            if (subSelect.getSelectBody() != null) {
                processSelect(subSelect.getSelectBody());
            }
            plainSelect.setWhere(processCondition(plainSelect.getWhere(), subSelect));
        }
        // 如果是带有特殊函数的子查询,比如 lateral (select sum(*) from yyy)
        else if (fromItem instanceof SpecialSubSelect) {
            SpecialSubSelect specialSubSelect = (SpecialSubSelect) fromItem;
            if (specialSubSelect.getSubSelect() != null) {
                SubSelect subSelect = specialSubSelect.getSubSelect();
                if (subSelect.getSelectBody() != null) {
                    processSelect(subSelect.getSelectBody());
                }
            }
            plainSelect.setWhere(processCondition(plainSelect.getWhere(), specialSubSelect));
        }
        // 未知类型的查询,直接报错
        else {
            log.error("无法解析的 from 语句:{}({})", fromItem, fromItem.getClass());
            throw new RuntimeException("不支持的查询语句:" + fromItem.getClass().getName()
        }

        // 如果还存在关联查询
        List<Join> joins = plainSelect.getJoins();
        if (CollUtil.isNotEmpty(joins)) {
            joins.forEach(this::processJoin);
        }
    }

    /**
     * 处理关联查询
     *
     * @param join 关联查询
     */
    protected void processJoin(Join join) {
        FromItem joinTable = join.getRightItem();
        if (joinTable instanceof Table) {
            Table table = (Table) joinTable;
            join.setOnExpression(processCondition(join.getOnExpression(), table));
        }
        else if (joinTable instanceof SubSelect) {
            processSelect(((SubSelect) joinTable).getSelectBody());
        }
        else if (joinTable instanceof SpecialSubSelect) {
            SpecialSubSelect specialSubSelect = (SpecialSubSelect) joinTable;
            if (specialSubSelect.getSubSelect() != null) {
                SubSelect subSelect = specialSubSelect.getSubSelect();
                if (subSelect.getSelectBody() != null) {
                    processSelect(subSelect.getSelectBody());
                }
            }
        }
        else {
            log.error("无法解析的 join 语句:{}({})", joinTable, joinTable.getClass());
            throw new RuntimeException("不支持的查询语句:" + joinTable.getClass().getName());
        }
    }

    /**
     * <p>获取添加了租户条件的查询条件,若条件中存在子查询,则也会为子查询添加租户条件。
     *
     * @param expression 条件表达式
     * @param table 表
     * @return 添加租户条件后的条件表达式
     */
    protected Expression processCondition(@Nullable Expression expression, FromItem table) {
        // 如果已经不可拆分的表达式,则直接返回
        if (isBasicExpression(expression)) {
            return expression;
        }
        // 如果是子查询,则需要对子查询进行递归处理
        else if (expression instanceof SubSelect) {
            processSelect(((SubSelect) expression).getSelectBody());
        }
        // 如果是 in 条件,比如:xxx in (select xx from yy……),则需要对子查询进行递归处理
        else if (expression instanceof InExpression) {
            InExpression inExp = (InExpression) expression;
            ItemsList rightItems = inExp.getRightItemsList();
            if (rightItems instanceof SubSelect) {
                processSelect(((SubSelect) rightItems).getSelectBody());
            }
        }
        // 如果是 not 或者 != 条件,则需要对里面的条件进行递归处理
        else if (expression instanceof NotExpression) {
            NotExpression notExpression = (NotExpression) expression;
            processCondition(notExpression.getExpression(), table);
        }
        // 如果是 (xxx != xxx),则需要对括号里面的表达式进行递归处理
        else if (expression instanceof Parenthesis) {
            Parenthesis parenthesis = (Parenthesis) expression;
            Expression content = parenthesis.getExpression();
            processCondition(content, table);
        }
        // 如果是二元表达式,比如:xx = xx,xx > xx,则需要对左右两边的表达式进行递归处理
        else if (expression instanceof BinaryExpression) {
            BinaryExpression binaryExpression = (BinaryExpression) expression;
            Expression left = binaryExpression.getLeftExpression();
            processCondition(left, table);
            Expression right = binaryExpression.getRightExpression();
            processCondition(right, table);
        }
        // 如果是函数,比如:if(xx, xx) ,则需要对函数的参数进行递归处理
        else if (expression instanceof Function) {
            Function function = (Function) expression;
            ExpressionList parameters = function.getParameters();
            if (parameters != null) {
                parameters.getExpressions().forEach(param -> processCondition(param, table));
            }
        }
        // 如果是 case when 语句,则需要对 when 和 then 两个条件进行递归处理
        else if (expression instanceof WhenClause) {
            WhenClause whenClause = (WhenClause) expression;
            processCondition(whenClause.getWhenExpression(), table);
            processCondition(whenClause.getThenExpression(), table);
        }
        // 如果是 case 语句,则需要对 switch、when、then、else 四个条件进行递归处理
        else if (expression instanceof CaseExpression) {
            CaseExpression caseExpression = (CaseExpression) expression;
            processCondition(caseExpression.getSwitchExpression(), table);
            List<WhenClause> whenClauses = caseExpression.getWhenClauses();
            if (CollUtil.isNotEmpty(whenClauses)) {
                whenClauses.forEach(whenClause -> {
                    processCondition(whenClause.getWhenExpression(), table);
                    processCondition(whenClause.getThenExpression(), table);
                });
            }
            processCondition(caseExpression.getElseExpression(), table);
        }
        // 如果是 exists 语句,比如:exists (select xx from yy……),则需要对子查询进行递归处理
        else if (expression instanceof ExistsExpression) {
            Expression existsExpression = ((ExistsExpression) expression).getRightExpression();
            if (existsExpression instanceof SubSelect) {
                processSelect(((SubSelect) existsExpression).getSelectBody());
            }
        }
        // 如果是 all 或者 any 语句,比如:xx > all (select xx from yy……),则需要对子查询进行递归处理
        else if (expression instanceof AllComparisonExpression) {
            AllComparisonExpression allComparisonExpression = (AllComparisonExpression) expression;
            processSelect(allComparisonExpression.getSubSelect().getSelectBody());
        }
        else if (expression instanceof AnyComparisonExpression) {
            AnyComparisonExpression anyComparisonExpression = (AnyComparisonExpression) expression;
            processSelect(anyComparisonExpression.getSubSelect().getSelectBody());
        }
        // 如果是 cast 语句,比如:cast(xx as xx),则需要对子查询进行递归处理
        else if (expression instanceof CastExpression) {
            CastExpression castExpression = (CastExpression) expression;
            processCondition(castExpression.getLeftExpression(), table);
        }

        // 拼接查询条件
        Expression appendCondition = handleCondition(expression, table);
        return Objects.isNull(appendCondition) ? expression : appendCondition;
    }

    /**
     * 判断是否是已经是无法再拆分的基本表达式 <br/>
     * 比如:列名、常量、函数等
     *
     * @param expression 表达式
     * @return 是否是基本表达式
     */
    protected boolean isBasicExpression(@Nullable Expression expression) {
        return expression instanceof Column
            || expression instanceof LongValue
            || expression instanceof StringValue
            || expression instanceof DoubleValue
            || expression instanceof NullValue
            || expression instanceof TimeValue
            || expression instanceof TimestampValue
            || expression instanceof DateValue;
    }

    /**
     * 返回一个查询条件,该查询条件将替换{@code table}原有的{@code where}条件
     *
     * @param expression 原有的查询条件
     * @param table 指定的表
     * @return 查询条件
     */
    @Nullable
    protected Expression handleCondition(@Nullable Expression expression, FromItem table) {
        TenantInfo tenantInfo = TENANT_INFO_CONTEXT.get();
        // 如果是一个标准表名,且改表名在租户表列表中,则为查询条件添加租户条件
        if (!(table instanceof Table)) {
            return null;
        }
        String tenantColumn = tenantInfo.tablesWithTenantColumn.get(((Table) table).getName());
        if (Objects.nonNull(tenantColumn)) {
            return appendTenantCondition(expression, table, tenantInfo.tenantId, tenantColumn);
        }
        return null;
    }

    private static Expression appendTenantCondition(
        @Nullable Expression original, FromItem table, String tenantId, String tenantColumn) {
        EqualsTo equalsTo = new EqualsTo();
        equalsTo.setLeftExpression(getColumnWithTableAlias(table, tenantColumn));
        equalsTo.setRightExpression(new StringValue(tenantId));
        if (Objects.isNull(original)) {
            return equalsTo;
        }
        return original instanceof OrExpression ?
            new AndExpression(equalsTo, new Parenthesis(original)) :
            new AndExpression(original, equalsTo);
    }

    private static Column getColumnWithTableAlias(FromItem table, String column) {
        // 如果表存在别名,则字段应该变“表别名.字段名”的格式
        return Optional.ofNullable(table)
            .map(FromItem::getAlias)
            .map(alias -> alias.getName() + "." + column)
            .map(Column::new)
            .orElse(new Column(column));
    }

    /**
     * 租户信息
     */
    @RequiredArgsConstructor
    public static class TenantInfo {
        /**
         * 租户ID
         */
        private final String tenantId;
        /**
         * 要添加租户条件的表名称与对应的租户字段
         */
        private final Map<String, String> tablesWithTenantColumn;
    }
}

现在,针对预期的复杂场景,我们再来测试一下:

public static void main(String[] args) {
    Map<String, String> tablesWithTenantColumn = Maps.newHashMap();
    tablesWithTenantColumn.put("t", "tenant_id");
    TenantInfo tenantInfo = new TenantInfo("1", tablesWithTenantColumn);
    TenantSQLInterceptor.setTenantInfo(tenantInfo);

    // 处理包含的复杂子查询的SQL
    String sql = "select * " +
        "from (select * from t where a = 1) t " +
        "left join (select * from t where b = 2) t2 on t.id = t2.id " +
        "where b in (select * from t where c = 2) and d = 3";
    TenantSQLInterceptor interceptor = new TenantSQLInterceptor();
    String handledSql = interceptor.handle(sql);
    System.out.println(handledSql);
    // 输出结果:
    // select * 
    // from (select * from t where a = 1 and tenant_id = '1') t 
    // left join (select * from t where b = 2 and tenant_id = '1') t2 on t.id = t2.id 
    // where b in (select * from t where c = 2 and tenant_id = '1') and d = 3
}

完美!

1.4.分离公共代码

这个 SQL 拦截器已经可以完美满足我们的大部分需求了。现在功能已经实现,可以看看代码层面有什么可以优化的地方了。
我们再次分析一下上述代码,会注意到,上面的解析器其实干了两件事情:

  • 解析 SQL,并在递归获取不可再拆分的“根” SQL 后,替换其 where 条件。
  • 将 SQL 的 where 条件替换或追加上租户条件。

换而言之,第一步的逻辑似乎与“租户拦截”这个需求无关,它显然可以抽离为一个独立的组件以便后续复用。此外,我们现在实现的其实是一个行级别的租户拦截,如果我们日后需要表级别的租户拦截,最好也有办法基于它来实现。
综上考虑,这里我们将这个新组件根据其功能命名为 AbstractSqlHandler,并且为其添加一个 handleTable 抽象方法,使其具备拦截表名的能力:

/**
 * <p>SQL处理器,用于拦截SQL语句并修改其中的查询条件,
 * 该处理器支持处理嵌套查询、联合查询、关联查询等多种查询方式。
 *
 * @author huangchengxing
 * @see #handle
 * @see #handleCondition
 */
@Setter
@Slf4j
public abstract class AbstractSqlHandler {

    /**
     * 处理SQL语句
     *
     * @param sql SQL语句
     * @return 处理后的SQL语句
     */
    @Nullable
    public String handle(String sql) {
        Statements statements = parseStatements(sql);
        if (Objects.isNull(statements)) {
            return null;
        }
        List<Statement> statementList = statements.getStatements();
        if (CollUtil.isEmpty(statementList)) {
            return null;
        }
        return statements.getStatements().stream()
            .map(this::doHandle)
            .map(Statement::toString)
            .collect(Collectors.joining(";"));
    }

    @Nullable
    private Statements parseStatements(String sql) {
        Statements statements = null;
        try {
            statements = CCJSqlParserUtil.parseStatements(sql);
        } catch (JSQLParserException e) {
            log.error("SQL 解析失败: {}", sql, e);
            throw new RuntimeException("SQL 解析失败");
        }
        return statements;
    }

    private Statement doHandle(Statement statement) {
        try {
            if (statement instanceof Select) {
                processSelect(((Select) statement).getSelectBody());
            } else if (statement instanceof Update) {
                processUpdate((Update) statement);
            } else if (statement instanceof Delete) {
                processDelete((Delete) statement);
            } else if (statement instanceof Insert) {
                processInsert((Insert) statement);
            }
        } catch (Exception ex) {
            log.error("SQL 处理失败: {}", statement, ex);
            throw new RuntimeException("SQL 处理失败");
        }
        return statement;
    }

    private void processSelect(SelectBody selectBody) {
        // 普通查询
        if (selectBody instanceof PlainSelect) {
            processSelect((PlainSelect) selectBody);
        }
        // 嵌套查询,比如 select xx from (select yy from t)
        else if (selectBody instanceof WithItem) {
            WithItem withItem = (WithItem) selectBody;
            if (withItem.getSelectBody() != null) {
                processSelect(withItem.getSelectBody());
            }
        }
        // 联合查询,比如 union
        else if (selectBody instanceof SetOperationList) {
            SetOperationList operationList = (SetOperationList) selectBody;
            if (CollUtil.isNotEmpty(operationList.getSelects())) {
                operationList.getSelects().forEach(this::processSelect);
            }
        }
        // 值查询,比如 select 1, 2, 3
        else if (selectBody instanceof ValuesStatement) {
            List<Expression> expressions = ((ValuesStatement) selectBody).getExpressions();
            if (CollUtil.isNotEmpty(expressions)) {
                expressions.forEach(exp -> processCondition(exp, null));
            }
        } else {
            log.error("无法解析的 select 语句:{}({})", selectBody, selectBody.getClass());
            throw new RuntimeException("不支持的查询语句:" + selectBody.getClass().getName());
        }
    }

    /**
     * 处理插入语句
     *
     * @param insert 插入语句
     */
    protected void processInsert(Insert insert) {
        // do nothing
    }

    /**
     * 处理删除语句
     *
     * @param delete 删除语句
     */
    protected void processDelete(Delete delete) {
        Table table = delete.getTable();
        delete.setWhere(processCondition(delete.getWhere(), table));
        // 如果还存在关联查询
        List<Join> joins = delete.getJoins();
        if (CollUtil.isNotEmpty(joins)) {
            joins.forEach(this::processJoin);
        }
    }

    /**
     * 处理更新语句
     *
     * @param update 更新语句
     */
    protected void processUpdate(Update update) {
        Table table = update.getTable();
        update.setWhere(processCondition(update.getWhere(), table));
        // 如果还存在关联查询
        List<Join> joins = update.getJoins();
        if (CollUtil.isNotEmpty(joins)) {
            joins.forEach(this::processJoin);
        }
    }

    /**
     * 处理查询语句
     *
     * @param plainSelect 查询语句
     */
    protected void processSelect(PlainSelect plainSelect) {
        FromItem fromItem = plainSelect.getFromItem();
        // 如果是普通的表名
        if (fromItem instanceof Table) {
            Table fromTable = (Table) fromItem;
            plainSelect.setFromItem(handleTable(fromTable));
            plainSelect.setWhere(processCondition(plainSelect.getWhere(), fromTable));
        }
        // 如果是子查询,比如 select * from (select xxx from yyy)
        else if (fromItem instanceof SubSelect) {
            SubSelect subSelect = (SubSelect) fromItem;
            if (subSelect.getSelectBody() != null) {
                processSelect(subSelect.getSelectBody());
            }
            plainSelect.setWhere(processCondition(plainSelect.getWhere(), subSelect));
        }
        // 如果是带有特殊函数的子查询,比如 lateral (select sum(*) from yyy)
        else if (fromItem instanceof SpecialSubSelect) {
            SpecialSubSelect specialSubSelect = (SpecialSubSelect) fromItem;
            if (specialSubSelect.getSubSelect() != null) {
                SubSelect subSelect = specialSubSelect.getSubSelect();
                if (subSelect.getSelectBody() != null) {
                    processSelect(subSelect.getSelectBody());
                }
            }
            plainSelect.setWhere(processCondition(plainSelect.getWhere(), specialSubSelect));
        }
        // 未知类型的查询,直接报错
        else {
            log.error("无法解析的 from 语句:{}({})", fromItem, fromItem.getClass());
            throw new RuntimeException("不支持的查询语句:" + fromItem.getClass().getName());
        }

        // 如果还存在关联查询
        List<Join> joins = plainSelect.getJoins();
        if (CollUtil.isNotEmpty(joins)) {
            joins.forEach(this::processJoin);
        }
    }

    /**
     * 处理关联查询
     *
     * @param join 关联查询
     */
    protected void processJoin(Join join) {
        FromItem joinTable = join.getRightItem();
        if (joinTable instanceof Table) {
            Table table = (Table) joinTable;
            join.setRightItem(handleTable((Table) joinTable));
            join.setOnExpression(processCondition(join.getOnExpression(), table));
        }
        else if (joinTable instanceof SubSelect) {
            processSelect(((SubSelect) joinTable).getSelectBody());
        }
        else if (joinTable instanceof SpecialSubSelect) {
            SpecialSubSelect specialSubSelect = (SpecialSubSelect) joinTable;
            if (specialSubSelect.getSubSelect() != null) {
                SubSelect subSelect = specialSubSelect.getSubSelect();
                if (subSelect.getSelectBody() != null) {
                    processSelect(subSelect.getSelectBody());
                }
            }
        }
        else {
            log.error("无法解析的 join 语句:{}({})", joinTable, joinTable.getClass());
            throw new RuntimeException("不支持的查询语句:" + joinTable.getClass().getName());
        }
    }

    /**
     * <p>获取添加了租户条件的查询条件,若条件中存在子查询,则也会为子查询添加租户条件。
     *
     * @param expression 条件表达式
     * @param table 表
     * @return 添加租户条件后的条件表达式
     */
    @SuppressWarnings({"java:S6541", "java:S3776"})
    protected Expression processCondition(@Nullable Expression expression, FromItem table) {
        // 如果已经不可拆分的表达式,则直接返回
        if (isBasicExpression(expression)) {
            return expression;
        }
        // 如果是子查询,则需要对子查询进行递归处理
        else if (expression instanceof SubSelect) {
            processSelect(((SubSelect) expression).getSelectBody());
        }
        // 如果是 in 条件,比如:xxx in (select xx from yy……),则需要对子查询进行递归处理
        else if (expression instanceof InExpression) {
            InExpression inExp = (InExpression) expression;
            ItemsList rightItems = inExp.getRightItemsList();
            if (rightItems instanceof SubSelect) {
                processSelect(((SubSelect) rightItems).getSelectBody());
            }
        }
        // 如果是 not 或者 != 条件,则需要对里面的条件进行递归处理
        else if (expression instanceof NotExpression) {
            NotExpression notExpression = (NotExpression) expression;
            processCondition(notExpression.getExpression(), table);
        }
        // 如果是 (xxx != xxx),则需要对括号里面的表达式进行递归处理
        else if (expression instanceof Parenthesis) {
            Parenthesis parenthesis = (Parenthesis) expression;
            Expression content = parenthesis.getExpression();
            processCondition(content, table);
        }
        // 如果是二元表达式,比如:xx = xx,xx > xx,则需要对左右两边的表达式进行递归处理
        else if (expression instanceof BinaryExpression) {
            BinaryExpression binaryExpression = (BinaryExpression) expression;
            Expression left = binaryExpression.getLeftExpression();
            processCondition(left, table);
            Expression right = binaryExpression.getRightExpression();
            processCondition(right, table);
        }
        // 如果是函数,比如:if(xx, xx) ,则需要对函数的参数进行递归处理
        else if (expression instanceof Function) {
            Function function = (Function) expression;
            ExpressionList parameters = function.getParameters();
            if (parameters != null) {
                parameters.getExpressions().forEach(param -> processCondition(param, table));
            }
        }
        // 如果是 case when 语句,则需要对 when 和 then 两个条件进行递归处理
        else if (expression instanceof WhenClause) {
            WhenClause whenClause = (WhenClause) expression;
            processCondition(whenClause.getWhenExpression(), table);
            processCondition(whenClause.getThenExpression(), table);
        }
        // 如果是 case 语句,则需要对 switch、when、then、else 四个条件进行递归处理
        else if (expression instanceof CaseExpression) {
            CaseExpression caseExpression = (CaseExpression) expression;
            processCondition(caseExpression.getSwitchExpression(), table);
            List<WhenClause> whenClauses = caseExpression.getWhenClauses();
            if (CollUtil.isNotEmpty(whenClauses)) {
                whenClauses.forEach(whenClause -> {
                    processCondition(whenClause.getWhenExpression(), table);
                    processCondition(whenClause.getThenExpression(), table);
                });
            }
            processCondition(caseExpression.getElseExpression(), table);
        }
        // 如果是 exists 语句,比如:exists (select xx from yy……),则需要对子查询进行递归处理
        else if (expression instanceof ExistsExpression) {
            Expression existsExpression = ((ExistsExpression) expression).getRightExpression();
            if (existsExpression instanceof SubSelect) {
                processSelect(((SubSelect) existsExpression).getSelectBody());
            }
        }
        // 如果是 all 或者 any 语句,比如:xx > all (select xx from yy……),则需要对子查询进行递归处理
        else if (expression instanceof AllComparisonExpression) {
            AllComparisonExpression allComparisonExpression = (AllComparisonExpression) expression;
            processSelect(allComparisonExpression.getSubSelect().getSelectBody());
        }
        else if (expression instanceof AnyComparisonExpression) {
            AnyComparisonExpression anyComparisonExpression = (AnyComparisonExpression) expression;
            processSelect(anyComparisonExpression.getSubSelect().getSelectBody());
        }
        // 如果是 cast 语句,比如:cast(xx as xx),则需要对子查询进行递归处理
        else if (expression instanceof CastExpression) {
            CastExpression castExpression = (CastExpression) expression;
            processCondition(castExpression.getLeftExpression(), table);
        }

        // 拼接查询条件
        Expression appendCondition = handleCondition(expression, table);
        return Objects.isNull(appendCondition) ? expression : appendCondition;
    }

    /**
     * 返回一个查询条件,该查询条件将替换{@code table}原有的{@code where}条件
     *
     * @param expression 原有的查询条件
     * @param table 指定的表
     * @return 查询条件
     */
    protected abstract Expression handleCondition(@Nullable Expression expression, FromItem table);
    
    /**
     * 返回一个表名,该表名将替换原有的表名
     *
     * @param table 表名
     * @return 处理后的表名
     */
    protected FromItem handleTable(Table table) {
        return table;
    }

    /**
     * 判断是否是已经是无法再拆分的基本表达式 <br/>
     * 比如:列名、常量、函数等
     *
     * @param expression 表达式
     * @return 是否是基本表达式
     */
    protected boolean isBasicExpression(@Nullable Expression expression) {
        return expression instanceof Column
            || expression instanceof LongValue
            || expression instanceof StringValue
            || expression instanceof DoubleValue
            || expression instanceof NullValue
            || expression instanceof TimeValue
            || expression instanceof TimestampValue
            || expression instanceof DateValue;
    }
}

�接着,对于原本的 SQL 拦截器,我们令其继承 AbstractSqlHandler,然后更换一个更合适的名字 LineLevelTenantSqlHandler

/**
 * SQL拦截器,用于为SQL语句添加租户条件。
 * 每次执行SQL时,将会检查当前线程上下文中是否存在租户信息,如果存在,则会为查询语句添加租户条件,否则直接略过。
 *
 * @author huangchengxing
 * @see ContextTenantConditionSqlHandlerAdvisor
 */
@Slf4j
public class LineLevelTenantSqlHandler extends AbstractConditionSqlHandler {

    private static final ThreadLocal<TenantInfo> TENANT_INFO_CONTEXT = new TransmittableThreadLocal<>();

    /**
     * 设置租户信息
     *
     * @param tenantInfo 租户信息
     */
    public static void setTenantInfo(TenantInfo tenantInfo) {
        TENANT_INFO_CONTEXT.set(tenantInfo);
    }

    /**
     * 清除租户信息
     */
    public static void clearTenantInfo() {
        TENANT_INFO_CONTEXT.remove();
    }

    @Override
    public String handle(String sql) {
        // 如果未设置租户信息,则直接返回原始SQL
        TenantInfo tenantInfo = TENANT_INFO_CONTEXT.get();
        if (Objects.isNull(tenantInfo)) {
            return sql;
        }
        log.debug("租户拦截器拦截原始 SQL: {}", sql);
        String handledSql = super.handle(sql);
        log.info("租户拦截器拦截后 SQL: {}", handledSql);
        return Objects.isNull(handledSql) ? sql : handledSql;
    }

    @Override
    @Nullable
    protected Expression handleCondition(@Nullable Expression expression, FromItem table) {
        TenantInfo tenantInfo = TENANT_INFO_CONTEXT.get();
        // 如果是一个标准表名,且改表名在租户表列表中,则为查询条件添加租户条件
        if (!(table instanceof Table)) {
            return null;
        }
        String tenantColumn = tenantInfo.tablesWithTenantColumn.get(((Table) table).getName());
        if (Objects.nonNull(tenantColumn)) {
            return appendTenantCondition(expression, table, tenantInfo.tenantId, tenantColumn);
        }
        return null;
    }

    private static Expression appendTenantCondition(
        @Nullable Expression original, FromItem table, String tenantId, String tenantColumn) {
        EqualsTo equalsTo = new EqualsTo();
        equalsTo.setLeftExpression(getColumnWithTableAlias(table, tenantColumn));
        equalsTo.setRightExpression(new StringValue(tenantId));
        if (Objects.isNull(original)) {
            return equalsTo;
        }
        return original instanceof OrExpression ?
            new AndExpression(equalsTo, new Parenthesis(original)) :
            new AndExpression(original, equalsTo);
    }

    private static Column getColumnWithTableAlias(FromItem table, String column) {
        // 如果表存在别名,则字段应该变“表别名.字段名”的格式
        return Optional.ofNullable(table)
            .map(FromItem::getAlias)
            .map(alias -> alias.getName() + "." + column)
            .map(Column::new)
            .orElse(new Column(column));
    }

    /**
     * 租户信息
     */
    @RequiredArgsConstructor
    public static class TenantInfo {
        /**
         * 租户ID
         */
        private final String tenantId;
        /**
         * 要添加租户条件的表名称与对应的租户字段
         */
        private final Map<String, String> tablesWithTenantColumn;
    }
}

1.5.与 JPA 结合使用

JPA 的默认实现 Hibernate 提供了 StatementInspector 接口,我们实现一个自定义的实现类,然后让基础上文实现好的租户解析器即可 LineLevelTenantSqlHandler

/**
 * SQL拦截器,用于为SQL语句添加租户条件。
 * 每次执行SQL时,将会检查当前线程中是否存在租户信息,如果存在,则会为查询语句添加租户条件,否则直接略过。
 *
 * @author huangchengxing
 */
@Slf4j
@RequiredArgsConstructor
public class HibernateLineLevelTenantStatementInspector
    extends LineLevelTenantSqlHandler implements StatementInspector {

    @Override
    public String inspect(String sql) {
        return handle(sql);
    }
}

同理,我们也可以结合 Mybatis 或其他的框架实现类似的效果。

2.租户拦截器

显然,我们不可能无条件的拦截所有的查询,有些查询本身不需要进行拦截,而有些查询当访问者为管理员时也不需要拦截……总而言之,对应租户拦截,我们需要采用白名单而不是黑名单的方式,因此最好的实现方法就是搞一个切面,然后只对带有特定注解的方法的调用进行拦截。

2.1.注解类

我们定义一个 @TenantOperation 注解,该注解可以被用于方法或者类上,当用于类上的时候等于类中所有的方法都应用拦截:

/**
 * 表明方法是一个租户操作方法,需要在相关的SQL中加入租户过滤条件
 *
 * @author huangchengxing
 * @see ContextTenantConditionSqlHandlerAdvisor
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.TYPE})
public @interface TenantOperation {

    /**
     * 表配置
     *
     * @return 表配置
     */
    Tables[] value() default {};

    /**
     * 是否对当前方法与后续调用链不进行租户拦截
     *
     * @return boolean
     * @see Ignore
     */
    boolean ignore() default false;

    /**
     * 对当前方法与后续调用链不进行租户拦截
     */
    @TenantOperation(ignore = true) // 基于 Springt 合成注解机制的扩展注解
    @Documented
    @Retention(RetentionPolicy.RUNTIME)
    @Target({ElementType.METHOD, ElementType.ANNOTATION_TYPE})
    @interface Ignore {}

    /**
     * 表配置
     */
    @Documented
    @Retention(RetentionPolicy.RUNTIME)
    @interface Tables {

        /**
         * 租户字段名,不指定时默认遵循配置文件中的字段名
         *
         * @return String
         */
        String column() default "";

        /**
         * 需要添加过滤条件的表名,不指定时默认遵循配置文件中的表名
         *
         * @return String
         */
        String[] tables() default {};
    }
}

此外,为了便于使用,注解还支持直接指定要拦截的表和字段,以便覆盖默认配置文件中的配置。

2.2.方法拦截器

为了便于后续扩展,这里笔者没有基于 Aspect 注解,而是基于 Spring 的方法拦截器,自定义了切点来实现这个效果:

/**
 * 方法拦截器,用于拦截带有{@link TenantOperation}注解的方法,为涉及的查询语句添加租户过滤条件
 *
 * @author huangchengxing
 * @see LineLevelTenantOperationAdvisor
 */
@Slf4j
public class LineLevelTenantOperationAdvisor implements PointcutAdvisor, MethodInterceptor {

    private static final String INTERCEPT_REQUEST_ENTRY = "tenant";
    private static final TenantOpsInfo NULL = new TenantOpsInfo(null);
    private final Map<Method, TenantOpsInfo> tenantInfoCaches = new ConcurrentReferenceHashMap<>();
    private final TenantOpsInfo opsByDefault;

    public LineLevelTenantOperationAdvisor(Map<String, String> tableWithColumns) {
        this.opsByDefault = new TenantOpsInfo(tableWithColumns);
    }

    @Override
    public Object invoke(MethodInvocation methodInvocation) throws Throwable {
        // 从上下文获取租户ID
        String tenantId = Optional.ofNullable(RequestUserContext.getUser())
            .map(RequestUserContext.User::getUserId)
            .orElse(null);
        // 若没有上下文信息,则直接放行
        if (Objects.isNull(tenantId)) {
            return methodInvocation.proceed();
        }

        // 解析配置信息
        TenantOpsInfo info = resolveMethod(methodInvocation.getMethod());
        if (info == NULL) {
            return methodInvocation.proceed();
        }

        // 设置租户信息
        try {
            LineLevelTenantSqlHandler.setTenantInfo(info.getTenantInfo(tenantId));
            return methodInvocation.proceed();
        } finally {
            LineLevelTenantSqlHandler.clearTenantInfo();
        }
    }

    private TenantOpsInfo resolveMethod(Method method) {
        return tenantInfoCaches.computeIfAbsent(method, m -> {
            // 从方法上或类上获取注解
            TenantOperation annotation = Optional.ofNullable(AnnotatedElementUtils.findMergedAnnotation(method, TenantOperation.class))
                .orElse(AnnotatedElementUtils.findMergedAnnotation(method.getDeclaringClass(), TenantOperation.class));
            if (Objects.isNull(annotation)) {
                return NULL;
            }
            // 若注解未指定column和tables,则使用默认值
            TenantOperation.Tables[] tables = annotation.value();
            if (ArrayUtil.isEmpty(tables)) {
                return opsByDefault;
            }
            // 若指定了column和tables,则使用指定值
            Map<String, String> tableWithColumns = new HashMap<>(tables.length);
            for (TenantOperation.Tables table : tables) {
                String column = table.column();
                for (String tableName : table.tables()) {
                    tableWithColumns.put(tableName, column);
                }
            }
            return new TenantOpsInfo(tableWithColumns);
        });
    }

    @RequiredArgsConstructor
    private static class TenantOpsInfo {
        private final Map<String, String> tablesWithTenantColumn;
        public LineLevelTenantSqlHandler.TenantInfo getTenantInfo(String tenantId) {
            return new LineLevelTenantSqlHandler.TenantInfo(tenantId, tablesWithTenantColumn);
        }
    }

    @Override
    public @NonNull Pointcut getPointcut() {
        return TenantOperationPointcut.INSTANCE;
    }

    @Override
    public @NonNull Advice getAdvice() {
        return this;
    }

    @Override
    public boolean isPerInstance() {
        return false;
    }

    // 自定义切点,拦截带有 @TenantOperation 注解的方法,或声明类上带有 @TenantOperation 注解的全部方法
    private static class TenantOperationPointcut extends StaticMethodMatcher implements Pointcut {
        public static final TenantOperationPointcut INSTANCE = new TenantOperationPointcut();
        @Override
        public @NonNull ClassFilter getClassFilter() {
            return ClassFilter.TRUE;
        }
        @Override
        public @NonNull MethodMatcher getMethodMatcher() {
            return this;
        }
        @Override
        public boolean matches(@NonNull Method method, @NonNull Class<?> type) {
            return AnnotatedElementUtils.isAnnotated(method, TenantOperation.class)
                || AnnotatedElementUtils.isAnnotated(type, TenantOperation.class);
        }
    }
}

2.3.上下文传递问题

如上文,我们选择使用方法拦截器在方法执行前设置租户信息,在方法执行后清空租户信息,这种做法在当同一条调用链上,同上触发了多次拦截时就会出现问题:

// 设置租户信息
try {
    LineLevelTenantSqlHandler.setTenantInfo(info.getTenantInfo(tenantId));
    return methodInvocation.proceed();
} finally {
    LineLevelTenantSqlHandler.clearTenantInfo();
}

举个例子,假如我们存在如下的调用:

@TenantOperation(
    @Tables(column = "userId")
)
public void method1() {
    // do something
    method2();
    method3()
    // do something
}

@TenantOperation(ignore = true) // 该方法不需要进行租户拦截
public void method1() {
    // do something
}

@TenantOperation(
    @Tables(column = "tenantId")
)
public void method1() {
    // do something
}

如果我们假设每个方法都能被正确的拦截,那么按原有的代码,当执行了 method2 以后,由于直接清空了上下文,最终会导致后续的调用都没有办法正确获取到租户信息。同理,
这个问题与 Spring 的事务传播有点异曲同工,我们的解决方案也类似,那就引入“挂起”这个概念。简单的来说,如果有一个被拦截的方法触发了上下文租户信息的更新,纳那么:

  • 如果上下文已经存在租户信息,说明当前方法只是调用链中的一个环节,那么就需要先将其挂起,先放入当前方法配置的租户信息,等到执行结束后,再将旧的租户信息放回上下文;
  • 如果上下文中没有存在租户信息,说明当前方法已经是调用链的源头,那么当执行完毕后,可以直接请上下文清空。

对此,我们参照 Spring 的做法,稍微调整一下这部分代码即可:

// 暂时挂起上一层级方法设置的租户信息
LineLevelTenantSqlHandler.TenantInfo previous = LineLevelTenantSqlHandler.getTenantInfo();
// 若当前方法设置了忽略租户信息,则清空上下文,否则设置当前租户信息
if (info.isIgnore()) {
    LineLevelTenantSqlHandler.clearTenantInfo();
} else {
    LineLevelTenantSqlHandler.TenantInfo current = new LineLevelTenantSqlHandler.TenantInfo(tenantId, info.getTablesWithTenantColumn());
    LineLevelTenantSqlHandler.setTenantInfo(current);
}
try {
    return methodInvocation.proceed();
} finally {
    // 恢复挂起的租户信息
    if (Objects.nonNull(previous)) {
        LineLevelTenantSqlHandler.setTenantInfo(previous);
    }
    // 若之前没有租户信息,则清空上下文
    else {
        LineLevelTenantSqlHandler.clearTenantInfo();
    }
}

3.使用

3.1.配置类

首先,我们先定义一个配置类以在项目中启用上述组件:

/**
 * <p>租户拦截器配置,启用后可以为指定的查询方法添加租户过滤条件。 <br/>
 * 可通过配置文件进行配置:<br/>
 * <pre>
 * # JPA 启用租户 SQL 拦截器
 * spring.jpa.properties.hibernate.session_factory.statement_inspector=io.github.createsequence.wheel.spring.tenant.HibernateTenantStatementInspector
 * # 启用租户拦截器
 * tenant.interceptor.enabled=true
 * # 需要拦截的表
 * tenant.interceptor.tables[0].column = tenant_id
 * tenant.interceptor.tables[0].tableNames = table1, table2
 * </pre>
 *
 * @author huangchengxing
 */
@Slf4j
@ConditionalOnProperty(prefix = TenantInterceptorConfig.Properties.CONFIG_PREFIX, name = "enabled", havingValue = "true")
@EnableConfigurationProperties(TenantInterceptorConfig.Properties.class)
@Configuration
public class TenantInterceptorConfig {

    @Bean
    public LineLevelTenantOperationAdvisor lineLevelTenantOperationAdvisor(Properties properties) {
        log.info("启用租户拦截器,需要拦截的表:{}", properties.getTables());
        Map<String, String> tableWithColumns = new HashMap<>(16);
        properties.getTables().forEach(ts -> ts.getTableNames().forEach(t -> {
            Assert.isFalse(tableWithColumns.containsKey(t), "同一张表具备只允许具备一个租户字段:{}", t);
            tableWithColumns.put(t, ts.getColumn());
        }));
        return new LineLevelTenantOperationAdvisor(tableWithColumns);
    }

    /**
     * @author huangchengxing
     */
    @ConfigurationProperties(prefix = Properties.CONFIG_PREFIX)
    @Data
    public static class Properties {

        public static final String CONFIG_PREFIX = "tenant.interceptor";

        /**
         * 表配置
         */
        private List<Tables> tables = new ArrayList<>();

        @Data
        public static class Tables {

            /**
             * 租户字段名
             */
            private String column;

            /**
             * 需要拦截的表名
             */
            private Set<String> tableNames;
        }
    }
}

3.2.配置文件

随后在配置文件中启用拦截器,并配置好要拦截的表:

# 启用租户 SQL 拦截器
spring.jpa.properties.hibernate.session_factory.statement_inspector=io.github.createsequence.wheel.spring.tenant.HibernateTenantStatementInspector
# 启用租户拦截器
tenant.interceptor.enabled=true
# 拦截 t_user, t_resource, t_assest 表中的 tenant_id 字段
tenant.interceptor.tables[0].column=tenant_id
tenant.interceptor.tables[0].table-names=t_user, t_resource, t_assest

3.3.添加注解

最后,我们只要在对应的类或者方法上添加 @TenantOperation 即可:

@TenantOperation // 默认所有方法都要应用拦截
@RestController
public class ResourceController {

    // @TenantOperation 因为类上已经加了,所以方法上可以不用加
    @GetMapping
    public List<Resource> listResource1(List<Integer> ids) {
        // do something
    }

    @TenantOperation.Ingore // 该方法不进行拦截
    @GetMapping
    public List<Resource> listResource2(List<Integer> ids) {
        // do something
    }
}
posted @ 2024-03-25 00:59  Createsequence  阅读(71)  评论(0编辑  收藏  举报