PageHelper源码分析

PageHelper源码分析

分析版本:5.2.0

项目地址:https://github.com/pagehelper/Mybatis-PageHelper

可以看下作者写的关于拦截器的文章

分页的原理,就是根据Mybatis提供的拦截器机制,来对Executor执行SQL语句时做一个拦截,并替换掉原来的SQL语句。

核心代码都在com.github.pagehelper.PageInterceptor,主要逻辑是先在分页前执行count语句(各个不同的类型的数据库可能会有差异,通过实现不同dialect来屏蔽掉),然后改写sql语句注入分页逻辑的语句块。

@Intercepts({
  @Signature(
    type = Executor.class, 
    method = "query", 
    args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}
  ),
  @Signature(
    type = Executor.class, 
    method = "query", 
    args = {
      MappedStatement.class, 
      Object.class, 
      RowBounds.class, 
      ResultHandler.class, 
      CacheKey.class, 
      BoundSql.class
    }
  )
})
public class PageInterceptor implements Interceptor {

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        try {
            Object[] args = invocation.getArgs();
            MappedStatement ms = (MappedStatement) args[0];
            Object parameter = args[1];
            RowBounds rowBounds = (RowBounds) args[2];
            ResultHandler resultHandler = (ResultHandler) args[3];
            Executor executor = (Executor) invocation.getTarget();
            CacheKey cacheKey;
            BoundSql boundSql;
            //由于逻辑关系,只会进入一次
            if (args.length == 4) {
                //4 个参数时
                boundSql = ms.getBoundSql(parameter);
                cacheKey = executor.createCacheKey(ms, parameter, rowBounds, boundSql);
            } else {
                //6 个参数时
                cacheKey = (CacheKey) args[4];
                boundSql = (BoundSql) args[5];
            }
            checkDialectExists();
            //对 boundSql 的拦截处理
            if (dialect instanceof BoundSqlInterceptor.Chain) {
                boundSql = ((BoundSqlInterceptor.Chain) dialect).doBoundSql(BoundSqlInterceptor.Type.ORIGINAL, boundSql, cacheKey);
            }
            List resultList;
            //调用方法判断是否需要进行分页,如果不需要,直接返回结果
            if (!dialect.skip(ms, parameter, rowBounds)) {
                //判断是否需要进行 count 查询
                if (dialect.beforeCount(ms, parameter, rowBounds)) {
                    //查询总数
                    Long count = count(executor, ms, parameter, rowBounds, null, boundSql);
                    //处理查询总数,返回 true 时继续分页查询,false 时直接返回
                    if (!dialect.afterCount(count, parameter, rowBounds)) {
                        //当查询总数为 0 时,直接返回空的结果
                        return dialect.afterPage(new ArrayList(), parameter, rowBounds);
                    }
                }
                // ‼️注入分页SQL
                resultList = ExecutorUtil.pageQuery(dialect, executor,
                        ms, parameter, rowBounds, resultHandler, boundSql, cacheKey);
            } else {
                //rowBounds用参数值,不使用分页插件处理时,仍然支持默认的内存分页
                resultList = executor.query(ms, parameter, rowBounds, resultHandler, cacheKey, boundSql);
            }
            return dialect.afterPage(resultList, parameter, rowBounds);
        } finally {
            if(dialect != null){
                dialect.afterAll();
            }
        }
    }
}

ExceutorUtil.pageQuery静态方法

中间会调用dialect来生成分页的sql。

public static <E> List<E> pageQuery(
  Dialect dialect, 
  Executor executor, 
  MappedStatement ms, 
  Object parameter,
  RowBounds rowBounds, 
  ResultHandler resultHandler,
  BoundSql boundSql, 
  CacheKey cacheKey
) throws SQLException {
  //判断是否需要进行分页查询
  if (dialect.beforePage(ms, parameter, rowBounds)) {
    //生成分页的缓存 key
    CacheKey pageKey = cacheKey;
    //‼️处理参数对象,注入参数映射关系
    parameter = dialect.processParameterObject(ms, parameter, boundSql, pageKey);
    //‼️调用方言获取分页 sql
    String pageSql = dialect.getPageSql(ms, boundSql, parameter, rowBounds, pageKey);
    BoundSql pageBoundSql = new BoundSql(ms.getConfiguration(), pageSql, boundSql.getParameterMappings(), parameter);

    Map<String, Object> additionalParameters = getAdditionalParameter(boundSql);
    //设置动态参数
    for (String key : additionalParameters.keySet()) {
      pageBoundSql.setAdditionalParameter(key, additionalParameters.get(key));
    }
    //对 boundSql 的拦截处理
    if (dialect instanceof BoundSqlInterceptor.Chain) {
      pageBoundSql = ((BoundSqlInterceptor.Chain) dialect).doBoundSql(BoundSqlInterceptor.Type.PAGE_SQL, pageBoundSql, pageKey);
    }
    //执行分页查询
    return executor.query(ms, parameter, RowBounds.DEFAULT, resultHandler, pageKey, pageBoundSql);
  } else {
    //不执行分页的情况下,也不执行内存分页
    return executor.query(ms, parameter, RowBounds.DEFAULT, resultHandler, cacheKey, boundSql);
  }
}

MySQLDialect(MySQL方言)

MySQLDialect用两个方法:processPageParametergetPageSql,前者生成占位符对于的参数映射关系,后者生成分页SQL语句。

public class MySqlDialect extends AbstractHelperDialect {

    // 注入参数,对应占位符?
    @Override
    public Object processPageParameter(MappedStatement ms, Map<String, Object> paramMap, Page page, BoundSql boundSql, CacheKey pageKey) {
        paramMap.put(PAGEPARAMETER_FIRST, page.getStartRow());
        paramMap.put(PAGEPARAMETER_SECOND, page.getPageSize());
        //处理pageKey
        pageKey.update(page.getStartRow());
        pageKey.update(page.getPageSize());
        //处理参数配置
        if (boundSql.getParameterMappings() != null) {
            List<ParameterMapping> newParameterMappings = new ArrayList<ParameterMapping>(boundSql.getParameterMappings());
            if (page.getStartRow() == 0) {
                newParameterMappings.add(new ParameterMapping.Builder(ms.getConfiguration(), PAGEPARAMETER_SECOND, int.class).build());
            } else {
                newParameterMappings.add(new ParameterMapping.Builder(ms.getConfiguration(), PAGEPARAMETER_FIRST, long.class).build());
                newParameterMappings.add(new ParameterMapping.Builder(ms.getConfiguration(), PAGEPARAMETER_SECOND, int.class).build());
            }
            MetaObject metaObject = MetaObjectUtil.forObject(boundSql);
            metaObject.setValue("parameterMappings", newParameterMappings);
        }
        return paramMap;
    }

    // 生成SQL语句
    @Override
    public String getPageSql(String sql, Page page, CacheKey pageKey) {
        StringBuilder sqlBuilder = new StringBuilder(sql.length() + 14);
        sqlBuilder.append(sql);
        if (page.getStartRow() == 0) {
            sqlBuilder.append("\n LIMIT ? ");
        } else {
            sqlBuilder.append("\n LIMIT ?, ? ");
        }
        return sqlBuilder.toString();
    }

}

PageHelper如何执行这个过程?

在使用PageHelper.startPage静态方法时,会自动注入一个新的Page对象到PageMethod类中的LOCAL_PAGE静态属性,LOCAL_PAGE是一个线程本地变量,类型为ThreadLocal<Page>,执行sql语句时,mybatis会保持一个线程执行一条SQL语句,不存在多线程竞争,因此,在调用分页方法时,会拦截之后第一次执行的SQL语句,并通过mybatis的拦截器机制实现改写sql语句。(⚠️注意:调用PageHelper.startPage方法不需要实例化PageHelper,因为是静态方法)

public abstract class PageMethod {  
    public static <E> Page<E> startPage(int pageNum, int pageSize, boolean count, Boolean reasonable, Boolean pageSizeZero) {
        Page<E> page = new Page<E>(pageNum, pageSize, count);
        page.setReasonable(reasonable);
        page.setPageSizeZero(pageSizeZero);
        //当已经执行过orderBy的时候
        Page<E> oldPage = getLocalPage();
        if (oldPage != null && oldPage.isOrderByOnly()) {
            page.setOrderBy(oldPage.getOrderBy());
        }
        setLocalPage(page);
        return page;
    }
}

public class PageHelper extends PageMethod implements Dialect, BoundSqlInterceptor.Chain {
  // ...
}

PageInterceptor拦截器,在拦截到Executor的query方法时,会在首次执行时实例化PageHelper。

public class PageInterceptor implements Interceptor {
		private String default_dialect_class = "com.github.pagehelper.PageHelper";
		
  	private void checkDialectExists() {
        if (dialect == null) {
            synchronized (default_dialect_class) {
                if (dialect == null) {
                    setProperties(new Properties());
                }
            }
        }
    }
  
      @Override
    public void setProperties(Properties properties) {
        // ...
        String dialectClass = properties.getProperty("dialect");
        if (StringUtil.isEmpty(dialectClass)) {
            dialectClass = default_dialect_class;
        }
        try {
          	// 使用默认无参构造器创建
            Class<?> aClass = Class.forName(dialectClass);
            dialect = (Dialect) aClass.newInstance();
        } catch (Exception e) {
            throw new PageException(e);
        }
      	// ...
    }
}

无论如何,最后到要走到dialect(PageHelper)的注入分页SQL的逻辑。在PageInterceptor中调用dialect的就是ExecutorUtil.pageQuery方法。这个方法最终又会走到PageHelper上,但是PageHelper并没有实现任何特定数据库的方言处理逻辑,而是把这些逻辑委托给PageAutoDialect。这个类会自动识别当前的数据库类型并执行对应的方言SQL。

![image-20221105033618894](/Users/wenxuan70/Library/Application Support/typora-user-images/image-20221105033618894.png)

public class PageHelper extends PageMethod implements Dialect, BoundSqlInterceptor.Chain {
    @Override
    public String getPageSql(
      MappedStatement ms, 
      BoundSql boundSql, 
      Object parameterObject, 
      RowBounds rowBounds, 
      CacheKey pageKey
    ) {
        return autoDialect.getDelegate()
          .getPageSql(ms, boundSql, parameterObject, rowBounds, pageKey);
    }
}

PageAutoDialec是如何判断数据库类型的?

很简单,通过jdbc的url来判断。

public class PageAutoDialect {
  	// 在类加载阶段,就注册号所有数据库对应的dialect
  	private static Map<String, Class<? extends Dialect>> dialectAliasMap = 
          new HashMap<String, Class<? extends Dialect>>();

    public static void registerDialectAlias(
      String alias, 
      Class<? extends Dialect> dialectClass
    ){
        dialectAliasMap.put(alias, dialectClass);
    }

    static {
        //注册别名
        registerDialectAlias("hsqldb", HsqldbDialect.class);
        registerDialectAlias("h2", HsqldbDialect.class);
        registerDialectAlias("postgresql", HsqldbDialect.class);
        registerDialectAlias("phoenix", HsqldbDialect.class);

        registerDialectAlias("mysql", MySqlDialect.class);
        registerDialectAlias("mariadb", MySqlDialect.class);
        registerDialectAlias("sqlite", MySqlDialect.class);

        registerDialectAlias("herddb", HerdDBDialect.class);

        registerDialectAlias("oracle", OracleDialect.class);
        registerDialectAlias("oracle9i", Oracle9iDialect.class);
        registerDialectAlias("db2", Db2Dialect.class);
        registerDialectAlias("informix", InformixDialect.class);
        //解决 informix-sqli #129,仍然保留上面的
        registerDialectAlias("informix-sqli", InformixDialect.class);

        registerDialectAlias("sqlserver", SqlServerDialect.class);
        registerDialectAlias("sqlserver2012", SqlServer2012Dialect.class);

        registerDialectAlias("derby", SqlServer2012Dialect.class);
        //达梦数据库,https://github.com/mybatis-book/book/issues/43
        registerDialectAlias("dm", OracleDialect.class);
        //阿里云PPAS数据库,https://github.com/pagehelper/Mybatis-PageHelper/issues/281
        registerDialectAlias("edb", OracleDialect.class);
        //神通数据库
        registerDialectAlias("oscar", MySqlDialect.class);
        registerDialectAlias("clickhouse", MySqlDialect.class);
    }
  
  	// 获取方言
    private AbstractHelperDialect getDialect(MappedStatement ms) {
        //改为对dataSource做缓存
        DataSource dataSource = ms.getConfiguration().getEnvironment().getDataSource();
        String url = getUrl(dataSource);
        if (urlDialectMap.containsKey(url)) {
            return urlDialectMap.get(url);
        }
        try {
            lock.lock();
            if (urlDialectMap.containsKey(url)) {
                return urlDialectMap.get(url);
            }
            if (StringUtil.isEmpty(url)) {
                throw new PageException("无法自动获取jdbcUrl,请在分页插件中配置dialect参数!");
            }
            String dialectStr = fromJdbcUrl(url);
            if (dialectStr == null) {
                throw new PageException("无法自动获取数据库类型,请通过 helperDialect 参数指定!");
            }
            AbstractHelperDialect dialect = initDialect(dialectStr, properties);
            urlDialectMap.put(url, dialect);
            return dialect;
        } finally {
            lock.unlock();
        }
    }
}

count语句是如何实现的?

简单点来说,就是把原有的SQL封装为一个子查询,并执行count聚合函数。

伪代码如下:

String originalSql = "select * from xxx";
String countSql = String.format("select count(0) from (%s) as temp_table", originalSql);

当然,PageHelper代码逻辑更复杂,包含对一些特殊情况的处理、对SQL语句的解析等等。


mybatis拦截器实现原理

要想实现mybatis中的拦截器,得先实现Interceptor接口,并在plugin方法中确定是否返回代理类。

mybatis框架在InterceptorChain中实现了拦截器链。

public class InterceptorChain {
  public Object pluginAll(Object target) {
    for (Interceptor interceptor : interceptors) {
      // 层层代理
      target = interceptor.plugin(target);
    }
    return target;
  }
}

可代理的目标有四类:ParameterHandlerResultSetHandlerStatementHandlerExecutor

在Configuration类中调用newXxx方法生成对应的实例时,会调用InterceptorChain的pluginAll方法获取代理对象。

public class Configuration {
  public ParameterHandler newParameterHandler(MappedStatement mappedStatement, Object parameterObject, BoundSql boundSql) {
    ParameterHandler parameterHandler = mappedStatement.getLang().createParameterHandler(mappedStatement, parameterObject, boundSql);
    // 生成代理对象
    parameterHandler = (ParameterHandler) interceptorChain.pluginAll(parameterHandler);
    return parameterHandler;
  }

  public ResultSetHandler newResultSetHandler(Executor executor, MappedStatement mappedStatement, RowBounds rowBounds, ParameterHandler parameterHandler,
      ResultHandler resultHandler, BoundSql boundSql) {
    ResultSetHandler resultSetHandler = new DefaultResultSetHandler(executor, mappedStatement, parameterHandler, resultHandler, boundSql, rowBounds);
    // 生成代理对象
    resultSetHandler = (ResultSetHandler) interceptorChain.pluginAll(resultSetHandler);
    return resultSetHandler;
  }

  public StatementHandler newStatementHandler(Executor executor, MappedStatement mappedStatement, Object parameterObject, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
    StatementHandler statementHandler = new RoutingStatementHandler(executor, mappedStatement, parameterObject, rowBounds, resultHandler, boundSql);
    // 生成代理对象
    statementHandler = (StatementHandler) interceptorChain.pluginAll(statementHandler);
    return statementHandler;
  }  

  public Executor newExecutor(Transaction transaction, ExecutorType executorType) {
    executorType = executorType == null ? defaultExecutorType : executorType;
    executorType = executorType == null ? ExecutorType.SIMPLE : executorType;
    Executor executor;
    if (ExecutorType.BATCH == executorType) {
      executor = new BatchExecutor(this, transaction);
    } else if (ExecutorType.REUSE == executorType) {
      executor = new ReuseExecutor(this, transaction);
    } else {
      executor = new SimpleExecutor(this, transaction);
    }
    if (cacheEnabled) {
      executor = new CachingExecutor(executor);
    }
    // 生成代理对象
    executor = (Executor) interceptorChain.pluginAll(executor);
    return executor;
  }
}

当然,如果不想生成对应类型的对象,这需要在拦截器的plugin方法中返回target参数即可。

如果要生成该怎么做?

调用mybatis提供的工具类Plugin的静态方法wrap。并在拦截器上配置@Intercepts@Signature注解。

这个类通过JDK自带的动态代理来生成代理对象。

public class Plugin implements InvocationHandler {
    public static Object wrap(Object target, Interceptor interceptor) {
    // 解析@Signature注解,获取接口信息和方法信息
    Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
    Class<?> type = target.getClass();
    // 获取需要代理的接口(从拦截器类上的@Signature注解获取需要被代理接口,并判断被代理的对象是否存在相应的接口,全都不存在则不会进行代理)
    Class<?>[] interfaces = getAllInterfaces(type, signatureMap);
    if (interfaces.length > 0) {
      return Proxy.newProxyInstance(
          type.getClassLoader(),
          interfaces,
          new Plugin(target, interceptor, signatureMap));
    }
    return target;
  }

  @Override
  public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    try {
      Set<Method> methods = signatureMap.get(method.getDeclaringClass());
      if (methods != null && methods.contains(method)) {
        // 命中🎯,则调用拦截器方法
        return interceptor.intercept(new Invocation(target, method, args));
      }
      // 否则,执行真实对象的方法
      return method.invoke(target, args);
    } catch (Exception e) {
      throw ExceptionUtil.unwrapThrowable(e);
    }
  }
 	
  // ...
}

那么Plugin又是如何知道拦截器要拦截对应接口的哪个方法呢?

通过拦截器的类上的注解Signature,Plugin会在生成代理对象前就把需要代理方法对签名和信息封装到一个map中,并在invoke方法中进行判断和执行。

获取需要被拦截的接口的方法代码如下:

public class Plugin implements Invocation {
  // 解析注解
  private static Map<Class<?>, Set<Method>> getSignatureMap(Interceptor interceptor) {
    Intercepts interceptsAnnotation = interceptor.getClass().getAnnotation(Intercepts.class);
    // issue #251
    if (interceptsAnnotation == null) {
      throw new PluginException("No @Intercepts annotation was found in interceptor " + interceptor.getClass().getName());
    }
    Signature[] sigs = interceptsAnnotation.value();
    Map<Class<?>, Set<Method>> signatureMap = new HashMap<>();
    for (Signature sig : sigs) {
      Set<Method> methods = signatureMap.computeIfAbsent(sig.type(), k -> new HashSet<>());
      try {
        Method method = sig.type().getMethod(sig.method(), sig.args());
        methods.add(method);
      } catch (NoSuchMethodException e) {
        throw new PluginException("Could not find method on " + sig.type() + " named " + sig.method() + ". Cause: " + e, e);
      }
    }
    return signatureMap;
  }

  private static Class<?>[] getAllInterfaces(Class<?> type, Map<Class<?>, Set<Method>> signatureMap) {
    // 获取所有接口
    Set<Class<?>> interfaces = new HashSet<>();
    while (type != null) {
      for (Class<?> c : type.getInterfaces()) {
 				// 判断是否与@Signature注解中的接口是否匹配
        if (signatureMap.containsKey(c)) {
          interfaces.add(c);
        }
      }
      // 迭代查询父接口
      type = type.getSuperclass();
    }
    return interfaces.toArray(new Class<?>[interfaces.size()]);
  }
}
posted @ 2022-11-05 04:27  yghr  阅读(386)  评论(0)    收藏  举报