MyBatis 分页拦截器

首先在 Spring 的配置文件中,通过 SqlSessionFactoryBean 的 plugins 属性注册拦截器:

 1 <!-- sqlSessionFactory -->
 2     <bean id="sqlSessionFactory" class="org.mybatis.spring.SqlSessionFactoryBean">
 3         <!-- 数据库连接池 -->
 4         <property name="dataSource" ref="dataSource" />
 5         
 6         <!-- 批量扫描别名 -->
 7         <property name="typeAliasesPackage" value="ssm.po" />  
 8         
 9         <!-- spring与mybatis整合不需要mybatis配置文件了,直接扫描mapper下的映射文件 -->
10         <property name="mapperLocations" value="classpath:ssm/mapper/*.xml" />
11         
12         <!-- MybatisSpringPageInterceptor分页拦截器 -->
13         <property name="plugins">  
14             <bean class="ssm.utils.MybatisSpringPageInterceptor"></bean>
15         </property>
16          
17     </bean>

 

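拦截器代码(代码中用到的分页参数类 Page 未在本文列出,需要自行提供,至少包含 getPageNo、getPageSize、getTotalPage、setTotalRecord、setResults 等方法):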

  1 package ssm.utils;
  2 
  3 import java.lang.reflect.Field;
  4 import java.sql.Connection;
  5 import java.sql.PreparedStatement;
  6 import java.sql.ResultSet;
  7 import java.sql.SQLException;
  8 import java.util.List;
  9 import java.util.Map;
 10 import java.util.Properties;
 11 
 12 import org.apache.ibatis.executor.Executor;
 13 import org.apache.ibatis.executor.parameter.ParameterHandler;
 14 import org.apache.ibatis.executor.statement.RoutingStatementHandler;
 15 import org.apache.ibatis.executor.statement.StatementHandler;
 16 import org.apache.ibatis.mapping.BoundSql;
 17 import org.apache.ibatis.mapping.MappedStatement;
 18 import org.apache.ibatis.mapping.ParameterMapping;
 19 import org.apache.ibatis.plugin.Interceptor;
 20 import org.apache.ibatis.plugin.Intercepts;
 21 import org.apache.ibatis.plugin.Invocation;
 22 import org.apache.ibatis.plugin.Plugin;
 23 import org.apache.ibatis.plugin.Signature;
 24 import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
 25 import org.apache.ibatis.session.ResultHandler;
 26 import org.apache.ibatis.session.RowBounds;
 27 import org.slf4j.Logger;
 28 import org.slf4j.LoggerFactory;
 29 
 30 @Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class }),
 31     @Signature(method = "query", type = Executor.class, args = { MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class }) })
 32 public class MybatisSpringPageInterceptor implements Interceptor {
 33     private static final Logger log = LoggerFactory.getLogger(MybatisSpringPageInterceptor.class);
 34 
 35     public static final String MYSQL = "mysql";
 36     public static final String ORACLE = "oracle";
 37 
 38     protected String databaseType;// 数据库类型,不同的数据库有不同的分页方法
 39 
 40     @SuppressWarnings("rawtypes")
 41     protected ThreadLocal<Page> pageThreadLocal = new ThreadLocal<Page>();
 42 
 43     public String getDatabaseType() {
 44         return databaseType;
 45     }
 46 
 47     public void setDatabaseType(String databaseType) {
 48         if (!databaseType.equalsIgnoreCase(MYSQL) && !databaseType.equalsIgnoreCase(ORACLE)) {
 49             throw new PageNotSupportException("Page not support for the type of database, database type [" + databaseType + "]");
 50         }
 51         this.databaseType = databaseType;
 52     }
 53 
 54     @Override
 55     public Object plugin(Object target) {
 56         return Plugin.wrap(target, this);
 57     }
 58 
 59     @Override
 60     public void setProperties(Properties properties) {
 61         String databaseType = properties.getProperty("databaseType");
 62         if (databaseType != null) {
 63             setDatabaseType(databaseType);
 64         }
 65     }
 66 
 67     @Override
 68     @SuppressWarnings({ "unchecked", "rawtypes" })
 69     public Object intercept(Invocation invocation) throws Throwable {
 70         if (invocation.getTarget() instanceof StatementHandler) { // 控制SQL和查询总数的地方
 71             Page page = pageThreadLocal.get();
 72             if (page == null) { //不是分页查询
 73                 return invocation.proceed();
 74             }
 75 
 76             RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget();
 77             StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate");
 78             BoundSql boundSql = delegate.getBoundSql();
 79             
 80             Connection connection = (Connection) invocation.getArgs()[0];
 81             prepareAndCheckDatabaseType(connection); // 准备数据库类型
 82 
 83             if (page.getTotalPage() > -1) {
 84                 if (log.isTraceEnabled()) {
 85                     log.trace("已经设置了总页数, 不需要再查询总数.");
 86                 }
 87             } else {
 88                 Object parameterObj = boundSql.getParameterObject();
 89                 MappedStatement mappedStatement = (MappedStatement) ReflectUtil.getFieldValue(delegate, "mappedStatement");
 90                 queryTotalRecord(page, parameterObj, mappedStatement, connection);
 91             }
 92 
 93             String sql = boundSql.getSql();
 94             String pageSql = buildPageSql(page, sql);
 95             if (log.isDebugEnabled()) {
 96                 log.debug("分页时, 生成分页pageSql: " + pageSql);
 97             }
 98             ReflectUtil.setFieldValue(boundSql, "sql", pageSql);
 99 
100             return invocation.proceed();
101         } else { // 查询结果的地方
102             // 获取是否有分页Page对象
103             Page<?> page = findPageObject(invocation.getArgs()[1]);
104             if (page == null) {
105                 if (log.isTraceEnabled()) {
106                     log.trace("没有Page对象作为参数, 不是分页查询.");
107                 }
108                 return invocation.proceed();
109             } else {
110                 if (log.isTraceEnabled()) {
111                     log.trace("检测到分页Page对象, 使用分页查询.");
112                 }
113             }
114             //设置真正的parameterObj
115             invocation.getArgs()[1] = extractRealParameterObject(invocation.getArgs()[1]);
116 
117             pageThreadLocal.set(page);
118             try {
119                 Object resultObj = invocation.proceed(); // Executor.query(..)
120                 if (resultObj instanceof List) {
121                     /* @SuppressWarnings({ "unchecked", "rawtypes" }) */
122                     page.setResults((List) resultObj);
123                 }
124                 return resultObj;
125             } finally {
126                 pageThreadLocal.remove();
127             }
128         }
129     }
130 
131     protected Page<?> findPageObject(Object parameterObj) {
132         if (parameterObj instanceof Page<?>) {
133             return (Page<?>) parameterObj;
134         } else if (parameterObj instanceof Map) {
135             for (Object val : ((Map<?, ?>) parameterObj).values()) {
136                 if (val instanceof Page<?>) {
137                     return (Page<?>) val;
138                 }
139             }
140         }
141         return null;
142     }
143 
144     /**
145      * <pre>
146      * 把真正的参数对象解析出来
147      * Spring会自动封装对个参数对象为Map<String, Object>对象
148      * 对于通过@Param指定key值参数我们不做处理,因为XML文件需要该KEY值
149      * 而对于没有@Param指定时,Spring会使用0,1作为主键
150      * 对于没有@Param指定名称的参数,一般XML文件会直接对真正的参数对象解析,
151      * 此时解析出真正的参数作为根对象
152      * </pre>
153      * @param parameterObj
154      * @return
155      */
156     protected Object extractRealParameterObject(Object parameterObj) {
157         if (parameterObj instanceof Map<?, ?>) {
158             Map<?, ?> parameterMap = (Map<?, ?>) parameterObj;
159             if (parameterMap.size() == 2) {
160                 boolean springMapWithNoParamName = true;
161                 for (Object key : parameterMap.keySet()) {
162                     if (!(key instanceof String)) {
163                         springMapWithNoParamName = false;
164                         break;
165                     }
166                     String keyStr = (String) key;
167                     if (!"0".equals(keyStr) && !"1".equals(keyStr)) {
168                         springMapWithNoParamName = false;
169                         break;
170                     }
171                 }
172                 if (springMapWithNoParamName) {
173                     for (Object value : parameterMap.values()) {
174                         if (!(value instanceof Page<?>)) {
175                             return value;
176                         }
177                     }
178                 }
179             }
180         }
181         return parameterObj;
182     }
183 
184     protected void prepareAndCheckDatabaseType(Connection connection) throws SQLException {
185         if (databaseType == null) {
186             String productName = connection.getMetaData().getDatabaseProductName();
187             if (log.isTraceEnabled()) {
188                 log.trace("Database productName: " + productName);
189             }
190             productName = productName.toLowerCase();
191             if (productName.indexOf(MYSQL) != -1) {
192                 databaseType = MYSQL;
193             } else if (productName.indexOf(ORACLE) != -1) {
194                 databaseType = ORACLE;
195             } else {
196                 throw new PageNotSupportException("Page not support for the type of database, database product name [" + productName + "]");
197             }
198             if (log.isInfoEnabled()) {
199                 log.info("自动检测到的数据库类型为: " + databaseType);
200             }
201         }
202     }
203 
204     /**
205      * <pre>
206      * 生成分页SQL
207      * </pre>
208      * 
209      * @param page
210      * @param sql
211      * @return
212      */
213     protected String buildPageSql(Page<?> page, String sql) {
214         if (MYSQL.equalsIgnoreCase(databaseType)) {
215             return buildMysqlPageSql(page, sql);
216         } else if (ORACLE.equalsIgnoreCase(databaseType)) {
217             return buildOraclePageSql(page, sql);
218         }
219         return sql;
220     }
221 
222     /**
223      * <pre>
224      * 生成Mysql分页查询SQL
225      * </pre>
226      * 
227      * @param page
228      * @param sql
229      * @return
230      */
231     protected String buildMysqlPageSql(Page<?> page, String sql) {
232         // 计算第一条记录的位置,Mysql中记录的位置是从0开始的。
233         int offset = (page.getPageNo() - 1) * page.getPageSize();
234         return new StringBuilder(sql).append(" limit ").append(offset).append(",").append(page.getPageSize()).toString();
235     }
236 
237     /**
238      * <pre>
239      * 生成Oracle分页查询SQL
240      * </pre>
241      * 
242      * @param page
243      * @param sql
244      * @return
245      */
246     protected String buildOraclePageSql(Page<?> page, String sql) {
247         // 计算第一条记录的位置,Oracle分页是通过rownum进行的,而rownum是从1开始的
248         int offset = (page.getPageNo() - 1) * page.getPageSize() + 1;
249         StringBuilder sb = new StringBuilder(sql);
250         sb.insert(0, "select u.*, rownum r from (").append(") u where rownum < ").append(offset + page.getPageSize());
251         sb.insert(0, "select * from (").append(") where r >= ").append(offset);
252         return sb.toString();
253     }
254 
255     /**
256      * <pre>
257      * 查询总数
258      * </pre>
259      * 
260      * @param page
261      * @param parameterObject
262      * @param mappedStatement
263      * @param connection
264      * @throws SQLException
265      */
266     protected void queryTotalRecord(Page<?> page, Object parameterObject, MappedStatement mappedStatement, Connection connection) throws SQLException {
267         BoundSql boundSql = mappedStatement.getBoundSql(page);
268         String sql = boundSql.getSql();
269         String countSql = this.buildCountSql(sql);
270         if (log.isDebugEnabled()) {
271             log.debug("分页时, 生成countSql: " + countSql);
272         }
273 
274         List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
275         BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, parameterObject);
276         ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, countBoundSql);
277         PreparedStatement pstmt = null;
278         ResultSet rs = null;
279         try {
280             pstmt = connection.prepareStatement(countSql);
281             parameterHandler.setParameters(pstmt);
282             rs = pstmt.executeQuery();
283             if (rs.next()) {
284                 long totalRecord = rs.getLong(1);
285                 page.setTotalRecord(totalRecord);
286             }
287         } finally {
288             if (rs != null)
289                 try {
290                     rs.close();
291                 } catch (Exception e) {
292                     if (log.isWarnEnabled()) {
293                         log.warn("关闭ResultSet时异常.", e);
294                     }
295                 }
296             if (pstmt != null)
297                 try {
298                     pstmt.close();
299                 } catch (Exception e) {
300                     if (log.isWarnEnabled()) {
301                         log.warn("关闭PreparedStatement时异常.", e);
302                     }
303                 }
304         }
305     }
306 
307     /**
308      * 根据原Sql语句获取对应的查询总记录数的Sql语句
309      * 
310      * @param sql
311      * @return
312      */
313     protected String buildCountSql(String sql) {
314         int index = sql.indexOf("from");
315         return "select count(*) " + sql.substring(index);
316     }
317 
318     /**
319      * 利用反射进行操作的一个工具类
320      * 
321      */
322     private static class ReflectUtil {
323         /**
324          * 利用反射获取指定对象的指定属性
325          * 
326          * @param obj 目标对象
327          * @param fieldName 目标属性
328          * @return 目标属性的值
329          */
330         public static Object getFieldValue(Object obj, String fieldName) {
331             Object result = null;
332             Field field = ReflectUtil.getField(obj, fieldName);
333             if (field != null) {
334                 field.setAccessible(true);
335                 try {
336                     result = field.get(obj);
337                 } catch (IllegalArgumentException e) {
338                     // TODO Auto-generated catch block
339                     e.printStackTrace();
340                 } catch (IllegalAccessException e) {
341                     // TODO Auto-generated catch block
342                     e.printStackTrace();
343                 }
344             }
345             return result;
346         }
347 
348         /**
349          * 利用反射获取指定对象里面的指定属性
350          * 
351          * @param obj 目标对象
352          * @param fieldName 目标属性
353          * @return 目标字段
354          */
355         private static Field getField(Object obj, String fieldName) {
356             Field field = null;
357             for (Class<?> clazz = obj.getClass(); clazz != Object.class; clazz = clazz.getSuperclass()) {
358                 try {
359                     field = clazz.getDeclaredField(fieldName);
360                     break;
361                 } catch (NoSuchFieldException e) {
362                     // 这里不用做处理,子类没有该字段可能对应的父类有,都没有就返回null。
363                 }
364             }
365             return field;
366         }
367 
368         /**
369          * 利用反射设置指定对象的指定属性为指定的值
370          * 
371          * @param obj 目标对象
372          * @param fieldName 目标属性
373          * @param fieldValue 目标值
374          */
375         public static void setFieldValue(Object obj, String fieldName, String fieldValue) {
376             Field field = ReflectUtil.getField(obj, fieldName);
377             if (field != null) {
378                 try {
379                     field.setAccessible(true);
380                     field.set(obj, fieldValue);
381                 } catch (IllegalArgumentException e) {
382                     // TODO Auto-generated catch block
383                     e.printStackTrace();
384                 } catch (IllegalAccessException e) {
385                     // TODO Auto-generated catch block
386                     e.printStackTrace();
387                 }
388             }
389         }
390     }
391 
392     public static class PageNotSupportException extends RuntimeException {
393 
394         /** serialVersionUID*/
395         private static final long serialVersionUID = 1L;
396 
397         public PageNotSupportException() {
398             super();
399         }
400 
401         public PageNotSupportException(String message, Throwable cause) {
402             super(message, cause);
403         }
404 
405         public PageNotSupportException(String message) {
406             super(message);
407         }
408 
409         public PageNotSupportException(Throwable cause) {
410             super(cause);
411         }
412     }
413 
414 }

 

posted @ 2016-01-09 17:48  小xuan  阅读(848)  评论(0)    收藏  举报