MyBatis 分页拦截器
首先，在 Spring 配置文件中定义 SqlSessionFactoryBean 时，通过其 plugins 属性注册该分页拦截器：
<!-- sqlSessionFactory -->
<bean id="sqlSessionFactory" class="org.mybatis.spring.SqlSessionFactoryBean">
    <!-- Database connection pool -->
    <property name="dataSource" ref="dataSource" />

    <!-- Register type aliases for every class in this package -->
    <property name="typeAliasesPackage" value="ssm.po" />

    <!-- With Spring integration no standalone MyBatis config file is needed; the mapper XML files are loaded directly -->
    <property name="mapperLocations" value="classpath:ssm/mapper/*.xml" />

    <!-- Pagination plugin (MybatisSpringPageInterceptor); "plugins" is array-valued, so list interceptors explicitly -->
    <property name="plugins">
        <array>
            <bean class="ssm.utils.MybatisSpringPageInterceptor" />
        </array>
    </property>
</bean>
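拦截器代码如下（其中 Page 是项目自定义的分页参数类，未在此列出；拦截器会调用它的 getPageNo、getPageSize、getTotalPage、setTotalRecord 和 setResults 方法）：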
1 package ssm.utils; 2 3 import java.lang.reflect.Field; 4 import java.sql.Connection; 5 import java.sql.PreparedStatement; 6 import java.sql.ResultSet; 7 import java.sql.SQLException; 8 import java.util.List; 9 import java.util.Map; 10 import java.util.Properties; 11 12 import org.apache.ibatis.executor.Executor; 13 import org.apache.ibatis.executor.parameter.ParameterHandler; 14 import org.apache.ibatis.executor.statement.RoutingStatementHandler; 15 import org.apache.ibatis.executor.statement.StatementHandler; 16 import org.apache.ibatis.mapping.BoundSql; 17 import org.apache.ibatis.mapping.MappedStatement; 18 import org.apache.ibatis.mapping.ParameterMapping; 19 import org.apache.ibatis.plugin.Interceptor; 20 import org.apache.ibatis.plugin.Intercepts; 21 import org.apache.ibatis.plugin.Invocation; 22 import org.apache.ibatis.plugin.Plugin; 23 import org.apache.ibatis.plugin.Signature; 24 import org.apache.ibatis.scripting.defaults.DefaultParameterHandler; 25 import org.apache.ibatis.session.ResultHandler; 26 import org.apache.ibatis.session.RowBounds; 27 import org.slf4j.Logger; 28 import org.slf4j.LoggerFactory; 29 30 @Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class }), 31 @Signature(method = "query", type = Executor.class, args = { MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class }) }) 32 public class MybatisSpringPageInterceptor implements Interceptor { 33 private static final Logger log = LoggerFactory.getLogger(MybatisSpringPageInterceptor.class); 34 35 public static final String MYSQL = "mysql"; 36 public static final String ORACLE = "oracle"; 37 38 protected String databaseType;// 数据库类型,不同的数据库有不同的分页方法 39 40 @SuppressWarnings("rawtypes") 41 protected ThreadLocal<Page> pageThreadLocal = new ThreadLocal<Page>(); 42 43 public String getDatabaseType() { 44 return databaseType; 45 } 46 47 public void setDatabaseType(String databaseType) { 48 if (!databaseType.equalsIgnoreCase(MYSQL) 
&& !databaseType.equalsIgnoreCase(ORACLE)) { 49 throw new PageNotSupportException("Page not support for the type of database, database type [" + databaseType + "]"); 50 } 51 this.databaseType = databaseType; 52 } 53 54 @Override 55 public Object plugin(Object target) { 56 return Plugin.wrap(target, this); 57 } 58 59 @Override 60 public void setProperties(Properties properties) { 61 String databaseType = properties.getProperty("databaseType"); 62 if (databaseType != null) { 63 setDatabaseType(databaseType); 64 } 65 } 66 67 @Override 68 @SuppressWarnings({ "unchecked", "rawtypes" }) 69 public Object intercept(Invocation invocation) throws Throwable { 70 if (invocation.getTarget() instanceof StatementHandler) { // 控制SQL和查询总数的地方 71 Page page = pageThreadLocal.get(); 72 if (page == null) { //不是分页查询 73 return invocation.proceed(); 74 } 75 76 RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget(); 77 StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate"); 78 BoundSql boundSql = delegate.getBoundSql(); 79 80 Connection connection = (Connection) invocation.getArgs()[0]; 81 prepareAndCheckDatabaseType(connection); // 准备数据库类型 82 83 if (page.getTotalPage() > -1) { 84 if (log.isTraceEnabled()) { 85 log.trace("已经设置了总页数, 不需要再查询总数."); 86 } 87 } else { 88 Object parameterObj = boundSql.getParameterObject(); 89 MappedStatement mappedStatement = (MappedStatement) ReflectUtil.getFieldValue(delegate, "mappedStatement"); 90 queryTotalRecord(page, parameterObj, mappedStatement, connection); 91 } 92 93 String sql = boundSql.getSql(); 94 String pageSql = buildPageSql(page, sql); 95 if (log.isDebugEnabled()) { 96 log.debug("分页时, 生成分页pageSql: " + pageSql); 97 } 98 ReflectUtil.setFieldValue(boundSql, "sql", pageSql); 99 100 return invocation.proceed(); 101 } else { // 查询结果的地方 102 // 获取是否有分页Page对象 103 Page<?> page = findPageObject(invocation.getArgs()[1]); 104 if (page == null) { 105 if (log.isTraceEnabled()) { 106 
log.trace("没有Page对象作为参数, 不是分页查询."); 107 } 108 return invocation.proceed(); 109 } else { 110 if (log.isTraceEnabled()) { 111 log.trace("检测到分页Page对象, 使用分页查询."); 112 } 113 } 114 //设置真正的parameterObj 115 invocation.getArgs()[1] = extractRealParameterObject(invocation.getArgs()[1]); 116 117 pageThreadLocal.set(page); 118 try { 119 Object resultObj = invocation.proceed(); // Executor.query(..) 120 if (resultObj instanceof List) { 121 /* @SuppressWarnings({ "unchecked", "rawtypes" }) */ 122 page.setResults((List) resultObj); 123 } 124 return resultObj; 125 } finally { 126 pageThreadLocal.remove(); 127 } 128 } 129 } 130 131 protected Page<?> findPageObject(Object parameterObj) { 132 if (parameterObj instanceof Page<?>) { 133 return (Page<?>) parameterObj; 134 } else if (parameterObj instanceof Map) { 135 for (Object val : ((Map<?, ?>) parameterObj).values()) { 136 if (val instanceof Page<?>) { 137 return (Page<?>) val; 138 } 139 } 140 } 141 return null; 142 } 143 144 /** 145 * <pre> 146 * 把真正的参数对象解析出来 147 * Spring会自动封装对个参数对象为Map<String, Object>对象 148 * 对于通过@Param指定key值参数我们不做处理,因为XML文件需要该KEY值 149 * 而对于没有@Param指定时,Spring会使用0,1作为主键 150 * 对于没有@Param指定名称的参数,一般XML文件会直接对真正的参数对象解析, 151 * 此时解析出真正的参数作为根对象 152 * </pre> 153 * @param parameterObj 154 * @return 155 */ 156 protected Object extractRealParameterObject(Object parameterObj) { 157 if (parameterObj instanceof Map<?, ?>) { 158 Map<?, ?> parameterMap = (Map<?, ?>) parameterObj; 159 if (parameterMap.size() == 2) { 160 boolean springMapWithNoParamName = true; 161 for (Object key : parameterMap.keySet()) { 162 if (!(key instanceof String)) { 163 springMapWithNoParamName = false; 164 break; 165 } 166 String keyStr = (String) key; 167 if (!"0".equals(keyStr) && !"1".equals(keyStr)) { 168 springMapWithNoParamName = false; 169 break; 170 } 171 } 172 if (springMapWithNoParamName) { 173 for (Object value : parameterMap.values()) { 174 if (!(value instanceof Page<?>)) { 175 return value; 176 } 177 } 178 } 179 } 180 } 181 return 
parameterObj; 182 } 183 184 protected void prepareAndCheckDatabaseType(Connection connection) throws SQLException { 185 if (databaseType == null) { 186 String productName = connection.getMetaData().getDatabaseProductName(); 187 if (log.isTraceEnabled()) { 188 log.trace("Database productName: " + productName); 189 } 190 productName = productName.toLowerCase(); 191 if (productName.indexOf(MYSQL) != -1) { 192 databaseType = MYSQL; 193 } else if (productName.indexOf(ORACLE) != -1) { 194 databaseType = ORACLE; 195 } else { 196 throw new PageNotSupportException("Page not support for the type of database, database product name [" + productName + "]"); 197 } 198 if (log.isInfoEnabled()) { 199 log.info("自动检测到的数据库类型为: " + databaseType); 200 } 201 } 202 } 203 204 /** 205 * <pre> 206 * 生成分页SQL 207 * </pre> 208 * 209 * @param page 210 * @param sql 211 * @return 212 */ 213 protected String buildPageSql(Page<?> page, String sql) { 214 if (MYSQL.equalsIgnoreCase(databaseType)) { 215 return buildMysqlPageSql(page, sql); 216 } else if (ORACLE.equalsIgnoreCase(databaseType)) { 217 return buildOraclePageSql(page, sql); 218 } 219 return sql; 220 } 221 222 /** 223 * <pre> 224 * 生成Mysql分页查询SQL 225 * </pre> 226 * 227 * @param page 228 * @param sql 229 * @return 230 */ 231 protected String buildMysqlPageSql(Page<?> page, String sql) { 232 // 计算第一条记录的位置,Mysql中记录的位置是从0开始的。 233 int offset = (page.getPageNo() - 1) * page.getPageSize(); 234 return new StringBuilder(sql).append(" limit ").append(offset).append(",").append(page.getPageSize()).toString(); 235 } 236 237 /** 238 * <pre> 239 * 生成Oracle分页查询SQL 240 * </pre> 241 * 242 * @param page 243 * @param sql 244 * @return 245 */ 246 protected String buildOraclePageSql(Page<?> page, String sql) { 247 // 计算第一条记录的位置,Oracle分页是通过rownum进行的,而rownum是从1开始的 248 int offset = (page.getPageNo() - 1) * page.getPageSize() + 1; 249 StringBuilder sb = new StringBuilder(sql); 250 sb.insert(0, "select u.*, rownum r from (").append(") u where rownum < 
").append(offset + page.getPageSize()); 251 sb.insert(0, "select * from (").append(") where r >= ").append(offset); 252 return sb.toString(); 253 } 254 255 /** 256 * <pre> 257 * 查询总数 258 * </pre> 259 * 260 * @param page 261 * @param parameterObject 262 * @param mappedStatement 263 * @param connection 264 * @throws SQLException 265 */ 266 protected void queryTotalRecord(Page<?> page, Object parameterObject, MappedStatement mappedStatement, Connection connection) throws SQLException { 267 BoundSql boundSql = mappedStatement.getBoundSql(page); 268 String sql = boundSql.getSql(); 269 String countSql = this.buildCountSql(sql); 270 if (log.isDebugEnabled()) { 271 log.debug("分页时, 生成countSql: " + countSql); 272 } 273 274 List<ParameterMapping> parameterMappings = boundSql.getParameterMappings(); 275 BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, parameterObject); 276 ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, countBoundSql); 277 PreparedStatement pstmt = null; 278 ResultSet rs = null; 279 try { 280 pstmt = connection.prepareStatement(countSql); 281 parameterHandler.setParameters(pstmt); 282 rs = pstmt.executeQuery(); 283 if (rs.next()) { 284 long totalRecord = rs.getLong(1); 285 page.setTotalRecord(totalRecord); 286 } 287 } finally { 288 if (rs != null) 289 try { 290 rs.close(); 291 } catch (Exception e) { 292 if (log.isWarnEnabled()) { 293 log.warn("关闭ResultSet时异常.", e); 294 } 295 } 296 if (pstmt != null) 297 try { 298 pstmt.close(); 299 } catch (Exception e) { 300 if (log.isWarnEnabled()) { 301 log.warn("关闭PreparedStatement时异常.", e); 302 } 303 } 304 } 305 } 306 307 /** 308 * 根据原Sql语句获取对应的查询总记录数的Sql语句 309 * 310 * @param sql 311 * @return 312 */ 313 protected String buildCountSql(String sql) { 314 int index = sql.indexOf("from"); 315 return "select count(*) " + sql.substring(index); 316 } 317 318 /** 319 * 利用反射进行操作的一个工具类 320 * 321 */ 322 private static class 
ReflectUtil { 323 /** 324 * 利用反射获取指定对象的指定属性 325 * 326 * @param obj 目标对象 327 * @param fieldName 目标属性 328 * @return 目标属性的值 329 */ 330 public static Object getFieldValue(Object obj, String fieldName) { 331 Object result = null; 332 Field field = ReflectUtil.getField(obj, fieldName); 333 if (field != null) { 334 field.setAccessible(true); 335 try { 336 result = field.get(obj); 337 } catch (IllegalArgumentException e) { 338 // TODO Auto-generated catch block 339 e.printStackTrace(); 340 } catch (IllegalAccessException e) { 341 // TODO Auto-generated catch block 342 e.printStackTrace(); 343 } 344 } 345 return result; 346 } 347 348 /** 349 * 利用反射获取指定对象里面的指定属性 350 * 351 * @param obj 目标对象 352 * @param fieldName 目标属性 353 * @return 目标字段 354 */ 355 private static Field getField(Object obj, String fieldName) { 356 Field field = null; 357 for (Class<?> clazz = obj.getClass(); clazz != Object.class; clazz = clazz.getSuperclass()) { 358 try { 359 field = clazz.getDeclaredField(fieldName); 360 break; 361 } catch (NoSuchFieldException e) { 362 // 这里不用做处理,子类没有该字段可能对应的父类有,都没有就返回null。 363 } 364 } 365 return field; 366 } 367 368 /** 369 * 利用反射设置指定对象的指定属性为指定的值 370 * 371 * @param obj 目标对象 372 * @param fieldName 目标属性 373 * @param fieldValue 目标值 374 */ 375 public static void setFieldValue(Object obj, String fieldName, String fieldValue) { 376 Field field = ReflectUtil.getField(obj, fieldName); 377 if (field != null) { 378 try { 379 field.setAccessible(true); 380 field.set(obj, fieldValue); 381 } catch (IllegalArgumentException e) { 382 // TODO Auto-generated catch block 383 e.printStackTrace(); 384 } catch (IllegalAccessException e) { 385 // TODO Auto-generated catch block 386 e.printStackTrace(); 387 } 388 } 389 } 390 } 391 392 public static class PageNotSupportException extends RuntimeException { 393 394 /** serialVersionUID*/ 395 private static final long serialVersionUID = 1L; 396 397 public PageNotSupportException() { 398 super(); 399 } 400 401 public PageNotSupportException(String 
message, Throwable cause) { 402 super(message, cause); 403 } 404 405 public PageNotSupportException(String message) { 406 super(message); 407 } 408 409 public PageNotSupportException(Throwable cause) { 410 super(cause); 411 } 412 } 413 414 }

浙公网安备 33010602011771号