MyBatis-Plus 拦截器应用：为分页查询增加汇总行

代码

package com.jhict.common.data.interceptor;

import com.jhict.common.core.util.StringUtil;
import com.jhict.common.data.annotation.SqlLogAnnotation;
import com.jhict.common.data.annotation.TpcoPageSummaryAnnotation;
import com.jhict.common.data.entity.SqlLogParam;
import com.jhict.common.data.entity.TpQuerySummaryParam;
import com.jhict.common.data.mybatis.entity.JhPage;
import com.jhict.common.entity.TpcoPage;
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import org.apache.ibatis.builder.StaticSqlSource;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;

import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

@Log4j2
@Aspect
@Intercepts(@Signature(type = Executor.class, method = "query",
	args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}))
@RequiredArgsConstructor
public class TpSummaryInterceptor implements Interceptor {
	private final ThreadLocal<TpQuerySummaryParam> threadLocal = new ThreadLocal<>();

	public void setThreadLocal(TpQuerySummaryParam param) {
		threadLocal.set(param);
	}

	@Around("@annotation(pageSummaryAnnotation)")
	private Object doBefore(ProceedingJoinPoint joinPoint, TpcoPageSummaryAnnotation pageSummaryAnnotation)
		throws Throwable {
		if (threadLocal.get() == null) {
			TpQuerySummaryParam param = new TpQuerySummaryParam();
			param.setColumns(pageSummaryAnnotation.value());
			threadLocal.set(param);
		}
		Object proceed = joinPoint.proceed();
		threadLocal.remove();
		return proceed;
	}

	@Override
	public Object intercept(Invocation invocation) throws Throwable {
		Object target = invocation.getTarget();
		TpcoPage page = null;
		Object[] args = invocation.getArgs();
		Object parameter = args[1];
		try {
			// 执行原始方法,并获取返回结果
			Object proceed = invocation.proceed();

			// 获取汇总
			TpQuerySummaryParam param = threadLocal.get();
			if (param != null && param.getColumns().length > 0 &&
				((Map<?, ?>)parameter).get("arg0") instanceof TpcoPage && parameter instanceof Map &&
				target instanceof Executor) {
				page = (TpcoPage)((Map<?, ?>)parameter).get("arg0");
				// 执行汇总查询
				MappedStatement mappedStatement = (MappedStatement)args[0];
				Executor executor = (Executor)target;

				Map<String, Object> paramMap = new HashMap<>();
				((Map<?, ?>)parameter).forEach((k, v) -> {
					if (!(v instanceof JhPage)) {
						paramMap.put(k.toString(), v);
					}
				});
				page.setSummary(executeSummaryQuery(mappedStatement, paramMap, executor));
			}
			return proceed;
		}
		catch (Exception e) {
			e.printStackTrace();
			throw e;
		}
	}

	private Object executeSummaryQuery(MappedStatement originalMs, Object parameter,
		Executor executor) throws SQLException {
		// 1. 创建新的MappedStatement ID
		String summaryStatementId = originalMs.getId() + "-summary";

		// 2. 获取原始SQL并转换为汇总SQL
		BoundSql originalBoundSql = originalMs.getBoundSql(parameter);
		String originalSql = originalBoundSql.getSql();
		String summarySql = convertToSummarySql(originalSql);

		// 3. 创建新的SqlSource
		SqlSource sqlSource = new StaticSqlSource(
			originalMs.getConfiguration(),
			summarySql,
			originalBoundSql.getParameterMappings()
		);

		// 4. 克隆并修改MappedStatement
		MappedStatement summaryMs = cloneMappedStatement(originalMs, summaryStatementId, sqlSource);

		// 5. 执行汇总查询
		List<Object> result = executor.query(
			summaryMs,
			parameter,
			RowBounds.DEFAULT,
			Executor.NO_RESULT_HANDLER,
			CacheKey.NULL_CACHE_KEY,
			sqlSource.getBoundSql(parameter)
		);

		return result.isEmpty() ? null : result.get(0);
	}

	private String convertToSummarySql(String originalSql) {
		// 这里实现将普通查询SQL转换为汇总SQL的逻辑
		// 示例:将查询改为只返回count和sum
		TpQuerySummaryParam param = threadLocal.get();
		String[] columns = param.getColumns();
		StringBuffer sb = new StringBuffer("select ");
		for (String column : columns) {
			sb.append(" sum(" + column + ") " + column + ",");
		}
		String summarySql = sb.toString();
		if (summarySql.endsWith(",")) {
			summarySql = summarySql.substring(0, summarySql.length() - 1);
		}
		return summarySql + " from (" + originalSql + ") t";
	}

	private MappedStatement cloneMappedStatement(MappedStatement ms, String newId, SqlSource sqlSource) {
		MappedStatement.Builder builder = new MappedStatement.Builder(
			ms.getConfiguration(),
			newId,
			sqlSource,
			ms.getSqlCommandType()
		);

		builder.resource(ms.getResource());
		builder.fetchSize(ms.getFetchSize());
		builder.statementType(ms.getStatementType());
		builder.keyGenerator(ms.getKeyGenerator());
		if (ms.getKeyProperties() != null && ms.getKeyProperties().length > 0) {
			builder.keyProperty(String.join(",", ms.getKeyProperties()));
		}
		builder.timeout(ms.getTimeout());
		builder.parameterMap(ms.getParameterMap());
		builder.resultMaps(ms.getResultMaps());
		builder.resultSetType(ms.getResultSetType());
		builder.cache(ms.getCache());
		builder.flushCacheRequired(ms.isFlushCacheRequired());
		builder.useCache(ms.isUseCache());
		builder.databaseId(ms.getDatabaseId());

		return builder.build();
	}


}

posted @ 2025-04-10 19:21  fight139  阅读(12)  评论(0)    收藏  举报