MyBatis 利用拦截器实现统一分页

mybatis利用拦截器做统一分页

 

查询时传入 Page 对象,或传入继承自 Page 的对象作为参数。拦截器在原查询执行完毕后,将原查询 SQL 包装为 COUNT 查询以获取总记录数,并将总数赋值给 Page 对象,然后返回原查询结果。

示例项目:https://github.com/windwant/spring-boot-service

https://github.com/windwant/spring-dubbo-service/tree/master/spring-boot-server


拦截器:

package com.xxx;

import com.xxx.Page;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Properties;

/**
 * MyBatis plugin that fills in the total row count of paged queries.
 *
 * <p>It intercepts {@code Executor.query(MappedStatement, Object, RowBounds, ResultHandler)}.
 * When the query parameter is a {@link Page} (or a subclass) whose {@link Page#isCount()} is
 * {@code true}, the original query runs first. Then the statement's SQL is wrapped in
 * {@code SELECT COUNT(*) FROM (...)}, and the result is stored via {@link Page#setTotal(int)}.
 *
 * <p>While the count SQL is built and bound, the page is temporarily switched to page 1 with an
 * unbounded limit. That way {@code offset}/{@code limit} values read from the {@code Page} do not
 * restrict the counted rows. The caller's page number and limit are restored afterwards.
 *
 * <p>NOTE(review): the count query runs on a separate connection obtained from the configured
 * {@code DataSource}, not on the executor's own connection — confirm this is acceptable inside
 * transactions.
 */
@Intercepts({ // intercept Executor.query
    @Signature(type = Executor.class, method = "query",
        args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class})
})
public class PageIntercept implements Interceptor {

    public static final Logger logger = LoggerFactory.getLogger(PageIntercept.class);

    /**
     * Runs the intercepted query and, for countable {@link Page} parameters, records the total.
     *
     * @param invocation the intercepted {@code Executor.query} call
     * @return the result of the original query, unchanged
     * @throws Throwable anything thrown by the original query or by the count query
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object returnValue = invocation.proceed();
        Object parameter = invocation.getArgs()[1];
        if (parameter instanceof Page && ((Page) parameter).isCount()) {
            MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
            Page page = (Page) parameter;
            page.setTotal(queryTotal(mappedStatement, page));
        }
        return returnValue;
    }

    /**
     * Executes {@code SELECT COUNT(*)} over the statement's SQL for the given page parameter.
     *
     * @param mappedStatement the statement being paged
     * @param page the query parameter; its page number and limit are restored before returning
     * @return the number of rows the un-paged statement yields, or 0 if no row is returned
     * @throws SQLException if obtaining the connection or running the count query fails
     */
    private int queryTotal(MappedStatement mappedStatement, Page page) throws SQLException {
        int originalPage = page.getPage();
        int originalLimit = page.getLimit();
        // Page 1 with an unbounded limit gives getOffset() == 0 and getLimit() == MAX_VALUE.
        // A LIMIT clause bound from this Page therefore does not cut the counted rows.
        // (Page.setOffset is not used: getOffset() is derived from page and limit.)
        page.setPage(1);
        page.setLimit(Integer.MAX_VALUE);
        try {
            BoundSql boundSql = mappedStatement.getBoundSql(page);
            String countSql = "SELECT COUNT(*) FROM (" + boundSql.getSql().trim() + ") aliasForPage";
            logger.debug("count sql: {}", countSql);
            BoundSql countBoundSql = copyFromBoundSql(mappedStatement, boundSql, countSql);
            DefaultParameterHandler parameterHandler =
                new DefaultParameterHandler(mappedStatement, boundSql.getParameterObject(), countBoundSql);
            try (Connection connection = mappedStatement.getConfiguration().getEnvironment()
                        .getDataSource().getConnection();
                 PreparedStatement countStmt = connection.prepareStatement(countSql)) {
                parameterHandler.setParameters(countStmt);
                try (ResultSet rs = countStmt.executeQuery()) {
                    return rs.next() ? rs.getInt(1) : 0;
                }
            }
        } finally {
            // Never leak the temporary paging values back to the caller.
            page.setLimit(originalLimit);
            page.setPage(originalPage);
        }
    }

    /**
     * Creates a {@link BoundSql} for {@code sql} based on an existing one.
     *
     * <p>It reuses the original's parameter mappings and parameter object. It also copies every
     * additional parameter registered on the original for a mapped property.
     *
     * @param ms the mapped statement supplying the configuration
     * @param boundSql the original bound SQL
     * @param sql the replacement SQL text
     * @return a new bound SQL carrying {@code sql} and the original's parameters
     */
    private BoundSql copyFromBoundSql(MappedStatement ms, BoundSql boundSql, String sql) {
        BoundSql newBoundSql = new BoundSql(ms.getConfiguration(), sql,
            boundSql.getParameterMappings(), boundSql.getParameterObject());
        for (ParameterMapping mapping : boundSql.getParameterMappings()) {
            String prop = mapping.getProperty();
            if (boundSql.hasAdditionalParameter(prop)) {
                newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));
            }
        }
        return newBoundSql;
    }

    /** Wraps {@code target} in a proxy for the signatures declared in {@link Intercepts}. */
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /** This plugin takes no configuration properties. */
    @Override
    public void setProperties(Properties properties) {
        // intentionally empty
    }

}

数据库配置:

package com.xxx.config;

import javax.sql.DataSource;

import com.xxx.PageIntercept;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.session.SqlSessionFactory;
import org.mybatis.spring.SqlSessionFactoryBean;
import org.mybatis.spring.annotation.MapperScan;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.boot.autoconfigure.jdbc.DataSourceBuilder;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.jdbc.datasource.DataSourceTransactionManager;

/**
 * Spring configuration for the persistence layer.
 *
 * <p>It exposes three primary beans:
 * <ul>
 *   <li>the {@code DataSource}, bound to the {@code datasource.planes} properties;</li>
 *   <li>the MyBatis {@code SqlSessionFactory}, with the {@link PageIntercept} plugin registered;</li>
 *   <li>the JDBC transaction manager.</li>
 * </ul>
 * Mapper interfaces under {@code com.xxx.dao} are bound to the {@code sqlSessionFactory} bean.
 */
@Configuration
@MapperScan(basePackages = "com.xxx.dao", sqlSessionFactoryRef = "sqlSessionFactory")
public class DBConfig {

    private static final Logger logger = LoggerFactory.getLogger(DBConfig.class);

    /** Builds the primary data source from properties prefixed with {@code datasource.planes}. */
    @Primary
    @Bean(name = "dataSource")
    @ConfigurationProperties(prefix = "datasource.planes")
    public DataSource dataSource() {
        return DataSourceBuilder.create().build();
    }

    /** Creates the primary MyBatis session factory with the paging interceptor installed. */
    @Primary
    @Bean(name = "sqlSessionFactory")
    public SqlSessionFactory sqlSessionFactory() throws Exception {
        Interceptor[] plugins = {new PageIntercept()};
        SqlSessionFactoryBean factoryBean = new SqlSessionFactoryBean();
        factoryBean.setDataSource(dataSource());
        factoryBean.setPlugins(plugins);
        return factoryBean.getObject();
    }

    /** Creates the primary transaction manager over {@link #dataSource()}. */
    @Primary
    @Bean(name = "transactionManager")
    public DataSourceTransactionManager transactionManager() {
        DataSource source = dataSource();
        return new DataSourceTransactionManager(source);
    }

}

Page对象:

package com.xxx;

import com.xxx.Constants;

/**
 * Paging parameter object for queries.
 *
 * <p>{@code PageIntercept} reads {@link #isCount()} on it and stores the row count via
 * {@link #setTotal(int)}. The row offset returned by {@link #getOffset()} is always derived from
 * the 1-based page number and the page size: {@code (page - 1) * limit}.
 *
 * Created by Administrator on 2018/1/5.
 */
public class Page {

    // Stored by the two-arg constructor and setOffset; NOT read by getOffset(), which is
    // derived from page and limit. Kept for backward compatibility.
    private int offset = 0;

    // Page size.
    private int limit = Constants.DEAFULT_PAGE_LIMIT;

    // Total number of rows; filled in after a counting query.
    private int total = 0;

    // 1-based page number; never below 1.
    private int page = 1;

    // Whether a total-count query should be executed for this page.
    private boolean count = true;

    public boolean isCount() {
        return count;
    }

    public void setCount(boolean count) {
        this.count = count;
    }

    public Page() {}

    /**
     * Creates a page from a row offset and a page size.
     *
     * <p>The page number is derived as {@code offset / limit + 1}. As a result,
     * {@link #getOffset()} returns {@code offset} when it is a multiple of {@code limit}, and
     * rounds down to the nearest multiple otherwise. If {@code limit <= 0} or {@code offset <= 0},
     * the page stays 1.
     *
     * @param offset number of rows to skip
     * @param limit page size
     */
    public Page(int offset, int limit) {
        this.offset = offset;
        this.limit = limit;
        if (limit > 0 && offset > 0) {
            // long arithmetic: offset / 1 + 1 would overflow for offset == Integer.MAX_VALUE
            this.page = (int) Math.min(Integer.MAX_VALUE, (long) offset / limit + 1);
        }
    }

    public int getPage() {
        return page;
    }

    /** Sets the 1-based page number; values below 1 are clamped to 1. */
    public void setPage(int page) {
        this.page = page < 1 ? 1 : page;
    }

    /**
     * Returns the row offset {@code (page - 1) * limit}.
     *
     * <p>The product is computed in {@code long} and clamped to the {@code int} range, so large
     * page numbers cannot wrap around to a negative offset.
     */
    public int getOffset() {
        long rows = (long) (page - 1) * limit;
        if (rows > Integer.MAX_VALUE) {
            return Integer.MAX_VALUE;
        }
        if (rows < Integer.MIN_VALUE) {
            return Integer.MIN_VALUE;
        }
        return (int) rows;
    }

    /** Stores an offset value; note this does not change what {@link #getOffset()} returns. */
    public void setOffset(int offset) {
        this.offset = offset;
    }

    public int getLimit() {
        return limit;
    }

    public void setLimit(int limit) {
        this.limit = limit;
    }

    public int getTotal() {
        return total;
    }

    public void setTotal(int total) {
        this.total = total;
    }
}

  

 

 

posted @ 2018-01-08 16:15  WindWant  阅读(2157)  评论(0)  编辑  收藏  举报
文章精选列表