使用 ThreadPoolTaskExecutor 结合 CompletableFuture 执行批量任务:将每个任务的执行状态记录到数据库,仅当整批任务全部成功时才将这批数据发送给下游,超时时间为 10 分钟
核心需求:
实现一个基于线程池的批量任务执行器,具备以下功能特性:
- 任务并发执行:将一批任务并发提交至线程池执行
- 执行状态监控:实时监控所有任务的执行状态,包括成功、失败及异常信息
- 超时控制机制:设置执行超时阈值(10分钟),超时后中断未完成任务的执行
- 持久化状态记录:
  - 将每个任务的执行结果(成功/失败状态及异常信息)持久化至数据库
  - 记录批次整体执行状态和统计信息
- 条件性下游通知:
  - 仅当批次内所有任务均执行成功时,触发消息队列推送
  - 执行失败或超时的批次不进行下游通知
- 事务一致性:确保任务执行结果与数据库状态记录的一致性
技术约束:
- 基于Java线程池(ThreadPoolExecutor)实现
- 支持异常隔离,单个任务失败不影响其他任务执行
- 需要具备优雅的资源清理机制
- 支持批次执行结果的详细统计和状态查询
业务场景:
典型的批量数据处理场景,要求高并发、高可靠性,并具备完整的执行审计和下游系统解耦能力。
1. 使用 ThreadPoolTaskExecutor 和 CompletableFuture 实现:
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.stereotype.Service;

import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
public class BatchTaskExecutor {
private final ThreadPoolTaskExecutor taskExecutor;
private final DatabaseService databaseService;
private final MQService mqService;
public BatchTaskExecutor(ThreadPoolTaskExecutor taskExecutor,
DatabaseService databaseService,
MQService mqService) {
this.taskExecutor = taskExecutor;
this.databaseService = databaseService;
this.mqService = mqService;
}
/**
* 执行批量任务
* @param taskDataList 任务数据列表
* @param batchId 批次ID
* @param timeoutMinutes 超时时间(分钟)
* @return 执行结果
*/
public CompletableFuture<BatchExecutionResult> executeBatchAsync(List<TaskData> taskDataList,
String batchId,
int timeoutMinutes) {
// 创建所有任务的CompletableFuture
List<CompletableFuture<TaskResult>> taskFutures = taskDataList.stream()
.map(taskData -> createTaskFuture(taskData))
.collect(Collectors.toList());
// 组合所有任务的Future
CompletableFuture<List<TaskResult>> allTasksFuture =
CompletableFuture.allOf(taskFutures.toArray(new CompletableFuture[0]))
.thenApply(v -> taskFutures.stream()
.map(CompletableFuture::join)
.collect(Collectors.toList()));
// 添加超时控制
CompletableFuture<List<TaskResult>> timeoutFuture = addTimeoutControl(
allTasksFuture, taskFutures, timeoutMinutes);
// 处理执行结果
return timeoutFuture.thenCompose(taskResults ->
processBatchResult(taskResults, batchId, taskDataList));
}
/**
* 同步执行批量任务
*/
public BatchExecutionResult executeBatch(List<TaskData> taskDataList,
String batchId,
int timeoutMinutes) {
try {
return executeBatchAsync(taskDataList, batchId, timeoutMinutes).get();
} catch (Exception e) {
throw new RuntimeException("批量任务执行失败", e);
}
}
private CompletableFuture<TaskResult> createTaskFuture(TaskData taskData) {
return CompletableFuture
.supplyAsync(() -> {
try {
// 执行具体业务逻辑
executeBusinessLogic(taskData);
return new TaskResult(taskData.getId(), true, null);
} catch (Exception e) {
return new TaskResult(taskData.getId(), false, e.getMessage());
}
}, taskExecutor.getThreadPoolExecutor())
.exceptionally(throwable ->
new TaskResult(taskData.getId(), false, throwable.getMessage()));
}
private CompletableFuture<List<TaskResult>> addTimeoutControl(
CompletableFuture<List<TaskResult>> allTasksFuture,
List<CompletableFuture<TaskResult>> taskFutures,
int timeoutMinutes) {
// 创建超时Future
CompletableFuture<List<TaskResult>> timeoutFuture = new CompletableFuture<>();
// 设置超时处理
ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
scheduler.schedule(() -> {
if (!timeoutFuture.isDone()) {
// 取消所有未完成的任务
taskFutures.forEach(future -> future.cancel(true));
// 收集已完成的结果,未完成的标记为超时
List<TaskResult> results = new ArrayList<>();
for (CompletableFuture<TaskResult> future : taskFutures) {
if (future.isDone() && !future.isCancelled()) {
try {
results.add(future.get());
} catch (Exception e) {
results.add(new TaskResult("unknown", false, "任务执行异常"));
}
} else {
results.add(new TaskResult("unknown", false, "任务执行超时"));
}
}
timeoutFuture.complete(results);
}
scheduler.shutdown();
}, timeoutMinutes, TimeUnit.MINUTES);
// 正常完成的情况
allTasksFuture.whenComplete((results, throwable) -> {
if (!timeoutFuture.isDone()) {
if (throwable != null) {
timeoutFuture.completeExceptionally(throwable);
} else {
timeoutFuture.complete(results);
}
}
scheduler.shutdown();
});
return timeoutFuture;
}
private CompletableFuture<BatchExecutionResult> processBatchResult(
List<TaskResult> taskResults, String batchId, List<TaskData> taskDataList) {
return CompletableFuture.supplyAsync(() -> {
// 统计执行结果
long successCount = taskResults.stream().mapToLong(r -> r.isSuccess() ? 1 : 0).sum();
long failureCount = taskResults.size() - successCount;
boolean allSuccess = successCount == taskResults.size();
// 创建批次执行结果
BatchExecutionResult batchResult = new BatchExecutionResult(
batchId, allSuccess, true, (int)successCount, (int)failureCount, taskResults);
// 异步更新数据库状态
updateDatabaseStatusAsync(batchResult);
// 如果全部成功,推送到MQ
if (allSuccess) {
mqService.sendBatchAsync(batchId, taskDataList);
}
return batchResult;
}, taskExecutor.getThreadPoolExecutor());
}
private void updateDatabaseStatusAsync(BatchExecutionResult batchResult) {
// 异步更新任务状态
List<CompletableFuture<Void>> updateFutures = batchResult.getTaskResults().stream()
.map(taskResult -> CompletableFuture.runAsync(() -> {
String status = taskResult.isSuccess() ? "SUCCESS" : "FAILED";
databaseService.updateTaskStatus(taskResult.getTaskId(), status, taskResult.getErrorMessage());
}, taskExecutor.getThreadPoolExecutor()))
.collect(Collectors.toList());
// 等待所有任务状态更新完成后,更新批次状态
CompletableFuture.allOf(updateFutures.toArray(new CompletableFuture[0]))
.thenRunAsync(() -> {
String batchStatus = batchResult.isAllSuccess() ? "SUCCESS" : "FAILED";
databaseService.updateBatchStatus(batchResult.getBatchId(), batchStatus,
batchResult.getSuccessCount(), batchResult.getFailureCount());
}, taskExecutor.getThreadPoolExecutor());
}
/**
* 批量任务执行,支持任务分组和限流
*/
public CompletableFuture<List<BatchExecutionResult>> executeBatchesWithLimit(
List<List<TaskData>> batchGroups,
String baseBatchId,
int timeoutMinutes,
int concurrentBatchLimit) {
AtomicInteger batchCounter = new AtomicInteger(0);
// 创建信号量控制并发批次数
Semaphore semaphore = new Semaphore(concurrentBatchLimit);
List<CompletableFuture<BatchExecutionResult>> batchFutures = batchGroups.stream()
.map(taskDataList -> CompletableFuture
.supplyAsync(() -> {
try {
semaphore.acquire();
String batchId = baseBatchId + "_" + batchCounter.incrementAndGet();
return executeBatch(taskDataList, batchId, timeoutMinutes);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException("批次执行被中断", e);
} finally {
semaphore.release();
}
}, taskExecutor.getThreadPoolExecutor()))
.collect(Collectors.toList());
return CompletableFuture.allOf(batchFutures.toArray(new CompletableFuture[0]))
.thenApply(v -> batchFutures.stream()
.map(CompletableFuture::join)
.collect(Collectors.toList()));
}
// 具体的业务逻辑执行方法
private void executeBusinessLogic(TaskData taskData) throws Exception {
// 实现具体的业务逻辑
Thread.sleep(100); // 模拟处理时间
// 模拟可能的业务异常
if (taskData.getId().contains("error")) {
throw new RuntimeException("Business logic failed for task: " + taskData.getId());
}
}
}
// Task input data
/**
 * Mutable input of a single task: an identifier plus an opaque business payload.
 */
class TaskData {
    private String id;
    private Object data;

    public TaskData(String id, Object data) {
        this.id = id;
        this.data = data;
    }

    /** Returns the task identifier. */
    public String getId() {
        return this.id;
    }

    /** Replaces the task identifier. */
    public void setId(String id) {
        this.id = id;
    }

    /** Returns the business payload. */
    public Object getData() {
        return this.data;
    }

    /** Replaces the business payload. */
    public void setData(Object data) {
        this.data = data;
    }
}
// Result of a single task
/**
 * Outcome of one task: its id, whether it succeeded, and an optional error message.
 */
class TaskResult {
    private final String taskId;
    private final boolean success;
    private final String errorMessage;

    public TaskResult(String taskId, boolean success, String errorMessage) {
        this.taskId = taskId;
        this.success = success;
        this.errorMessage = errorMessage;
    }

    /** Returns the id of the task this result belongs to. */
    public String getTaskId() {
        return this.taskId;
    }

    /** Returns true if the task finished without error. */
    public boolean isSuccess() {
        return this.success;
    }

    /** Returns the failure detail, or whatever was passed in (null for successes here). */
    public String getErrorMessage() {
        return this.errorMessage;
    }
}
// 批次执行结果
class BatchExecutionResult {
private String batchId;
private boolean allSuccess;
private boolean completedWithinTimeout;
private int successCount;
private int failureCount;
private List
public BatchExecutionResult(String batchId, boolean allSuccess, boolean completedWithinTimeout,
int successCount, int failureCount, List<TaskResult> taskResults) {
this.batchId = batchId;
this.allSuccess = allSuccess;
this.completedWithinTimeout = completedWithinTimeout;
this.successCount = successCount;
this.failureCount = failureCount;
this.taskResults = taskResults;
}
public String getBatchId() { return batchId; }
public boolean isAllSuccess() { return allSuccess; }
public boolean isCompletedWithinTimeout() { return completedWithinTimeout; }
public int getSuccessCount() { return successCount; }
public int getFailureCount() { return failureCount; }
public List<TaskResult> getTaskResults() { return taskResults; }
}
// Database service interface
/**
 * Persistence port used to record task and batch status.
 */
interface DatabaseService {

    /**
     * Stores the outcome of a single task.
     *
     * @param taskId       id of the task
     * @param status       status text; the batch executor passes "SUCCESS" or "FAILED"
     * @param errorMessage failure detail; may be {@code null}
     */
    void updateTaskStatus(String taskId, String status, String errorMessage);

    /**
     * Stores the aggregate outcome of a batch.
     *
     * @param batchId      id of the batch
     * @param status       status text; the batch executor passes "SUCCESS" or "FAILED"
     * @param successCount number of successful tasks
     * @param failureCount number of failed tasks
     */
    void updateBatchStatus(String batchId, String status, int successCount, int failureCount);
}
// MQ服务接口
interface MQService {
void sendBatchAsync(String batchId, List
}
// Spring configuration example
/**
 * Declares the thread pool shared by batch tasks.
 */
@Configuration
class ThreadPoolConfig {

    @Bean("batchTaskExecutor")
    public ThreadPoolTaskExecutor batchTaskExecutor() {
        ThreadPoolTaskExecutor pool = new ThreadPoolTaskExecutor();

        // Sizing: 10 core threads; the pool grows towards 20 only once the 500-slot queue is full.
        pool.setCorePoolSize(10);
        pool.setQueueCapacity(500);
        pool.setMaxPoolSize(20);
        pool.setKeepAliveSeconds(60);
        pool.setThreadNamePrefix("BatchTask-");

        // Saturation: run the task on the submitting thread instead of rejecting it.
        pool.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());

        // Shutdown: let queued/running tasks finish, waiting at most 60 seconds.
        pool.setAwaitTerminationSeconds(60);
        pool.setWaitForTasksToCompleteOnShutdown(true);

        pool.initialize();
        return pool;
    }
}
// Usage example
/**
 * Demonstrates the asynchronous and the synchronous entry points of the batch executor.
 */
@Service
class BatchProcessingService {

    @Autowired
    private BatchTaskExecutor batchTaskExecutor;

    public void processBatch() {
        List<TaskData> tasks = Arrays.asList(
                new TaskData("task1", "data1"),
                new TaskData("task2", "data2"),
                new TaskData("task3", "data3"));

        // Asynchronous: returns immediately; the callbacks print the outcome later.
        batchTaskExecutor.executeBatchAsync(tasks, "batch001", 10)
                .thenAccept(this::printOutcome)
                .exceptionally(error -> {
                    System.err.println("批次执行异常: " + error.getMessage());
                    return null;
                });

        // Synchronous: blocks until the whole batch has been processed.
        BatchExecutionResult outcome = batchTaskExecutor.executeBatch(tasks, "batch002", 10);
        System.out.println("同步执行结果: " + outcome.isAllSuccess());
    }

    /** Prints a one-line summary of an asynchronously finished batch. */
    private void printOutcome(BatchExecutionResult result) {
        if (!result.isAllSuccess()) {
            System.out.println("批次执行失败,成功:" + result.getSuccessCount() +
                    ", 失败:" + result.getFailureCount());
            return;
        }
        System.out.println("批次执行成功,已推送到MQ");
    }
}
2. 如果想了解其背后的实现原理,可以参考下面基于原生 ThreadPoolExecutor 和 Future 的传统实现:
import java.util.;
import java.util.concurrent.;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
/**
 * Plain-JDK variant of the batch executor.
 *
 * <p>It runs a batch on a {@link ThreadPoolExecutor} and waits up to the timeout. It then persists
 * per-task and batch status, and sends the batch to MQ only if every task succeeded in time.
 */
public class BatchTaskExecutor {
    private final ThreadPoolExecutor threadPool;
    private final DatabaseService databaseService;
    private final MQService mqService;

    public BatchTaskExecutor(ThreadPoolExecutor threadPool,
                             DatabaseService databaseService,
                             MQService mqService) {
        this.threadPool = threadPool;
        this.databaseService = databaseService;
        this.mqService = mqService;
    }

    /**
     * Executes a batch synchronously.
     *
     * @param taskDataList   tasks of the batch; the result order follows this list
     * @param batchId        batch identifier
     * @param timeoutMinutes maximum time to wait for all tasks, in minutes
     * @return the batch result. {@code isAllSuccess()} is true only if every task succeeded and
     *         the batch finished before the deadline.
     */
    public BatchExecutionResult executeBatch(List<TaskData> taskDataList,
                                             String batchId,
                                             int timeoutMinutes) {
        // Submit all tasks; future i belongs to taskDataList.get(i).
        List<Future<TaskResult>> futures = new ArrayList<>(taskDataList.size());
        for (int i = 0; i < taskDataList.size(); i++) {
            TaskData taskData = taskDataList.get(i);
            int taskIndex = i;
            Callable<TaskResult> job = () -> runTask(taskData, taskIndex);
            futures.add(threadPool.submit(job));
        }

        boolean completedWithinTimeout = waitForCompletion(futures, timeoutMinutes);
        BatchExecutionResult batchResult =
                collectResults(batchId, taskDataList, futures, completedWithinTimeout);

        // Persist first; MQ only after every status write succeeded.
        updateDatabaseStatus(batchResult);
        if (batchResult.isAllSuccess()) {
            mqService.sendBatch(batchId, taskDataList);
        }
        return batchResult;
    }

    /**
     * Waits for every future against a single shared deadline. On timeout or interruption, it
     * cancels (and interrupts) whatever is still running.
     *
     * @return true if every task finished before the deadline
     */
    private boolean waitForCompletion(List<Future<TaskResult>> futures, int timeoutMinutes) {
        // nanoTime is monotonic, so wall-clock adjustments cannot shorten or stretch the timeout.
        long deadline = System.nanoTime() + TimeUnit.MINUTES.toNanos(timeoutMinutes);
        boolean allCompleted = true;
        try {
            for (Future<TaskResult> future : futures) {
                try {
                    // A non-positive wait still succeeds immediately for futures that are already
                    // done. So tasks that finished in time are never misreported as timed out.
                    future.get(deadline - System.nanoTime(), TimeUnit.NANOSECONDS);
                } catch (ExecutionException e) {
                    // The task's failure is picked up by collectResults().
                }
            }
        } catch (TimeoutException e) {
            allCompleted = false;
        } catch (InterruptedException e) {
            // Keep the caller's interrupt status and stop waiting.
            Thread.currentThread().interrupt();
            allCompleted = false;
        }
        if (!allCompleted) {
            for (Future<TaskResult> future : futures) {
                if (!future.isDone()) {
                    future.cancel(true);
                }
            }
        }
        return allCompleted;
    }

    /**
     * Builds the batch result from the futures' final states. Unfinished or cancelled tasks count
     * as failures under their real task id. The counts come from the same snapshot as the list,
     * so the two always agree.
     */
    private BatchExecutionResult collectResults(String batchId,
                                                List<TaskData> taskDataList,
                                                List<Future<TaskResult>> futures,
                                                boolean completedWithinTimeout) {
        List<TaskResult> taskResults = new ArrayList<>(futures.size());
        int successCount = 0;
        for (int i = 0; i < futures.size(); i++) {
            TaskResult result = resultOf(futures.get(i), taskDataList.get(i).getId(), i);
            if (result.isSuccess()) {
                successCount++;
            }
            taskResults.add(result);
        }
        int failureCount = taskResults.size() - successCount;
        boolean allSuccess = completedWithinTimeout && failureCount == 0;
        return new BatchExecutionResult(batchId, allSuccess, completedWithinTimeout,
                successCount, failureCount, taskResults);
    }

    /** Reads one finished future without blocking; a pending or cancelled one counts as timed out. */
    private static TaskResult resultOf(Future<TaskResult> future, String taskId, int taskIndex) {
        if (!future.isDone() || future.isCancelled()) {
            return new TaskResult(taskId, false, "Task timeout or cancelled", taskIndex);
        }
        try {
            return future.get();
        } catch (ExecutionException e) {
            Throwable cause = e.getCause() != null ? e.getCause() : e;
            return new TaskResult(taskId, false, cause.getMessage(), taskIndex);
        } catch (InterruptedException e) {
            // Not expected for a future that already reports isDone(); keep the interrupt status.
            Thread.currentThread().interrupt();
            return new TaskResult(taskId, false, "Task timeout or cancelled", taskIndex);
        }
    }

    /** Runs one task; any exception becomes a failed TaskResult so other tasks are unaffected. */
    private TaskResult runTask(TaskData taskData, int taskIndex) {
        try {
            executeBusinessLogic(taskData);
            return new TaskResult(taskData.getId(), true, null, taskIndex);
        } catch (InterruptedException e) {
            // Cancelled by the timeout: restore the flag for the pool worker.
            Thread.currentThread().interrupt();
            return new TaskResult(taskData.getId(), false, e.getMessage(), taskIndex);
        } catch (Exception e) {
            return new TaskResult(taskData.getId(), false, e.getMessage(), taskIndex);
        }
    }

    /** Writes every task status and then the aggregate batch status. */
    private void updateDatabaseStatus(BatchExecutionResult batchResult) {
        for (TaskResult taskResult : batchResult.getTaskResults()) {
            if (taskResult.isSuccess()) {
                databaseService.updateTaskStatus(taskResult.getTaskId(), "SUCCESS", null);
            } else {
                databaseService.updateTaskStatus(taskResult.getTaskId(), "FAILED", taskResult.getErrorMessage());
            }
        }
        // Aggregate batch status.
        String batchStatus = batchResult.isAllSuccess() ? "SUCCESS" : "FAILED";
        databaseService.updateBatchStatus(batchResult.getBatchId(), batchStatus,
                batchResult.getSuccessCount(), batchResult.getFailureCount());
    }

    // Business logic of a single task; replace with the real implementation.
    // It must throw an exception to signal failure.
    private void executeBusinessLogic(TaskData taskData) throws Exception {
        Thread.sleep(100); // simulated processing time
        // Simulated business failure.
        if (taskData.getId().contains("error")) {
            throw new RuntimeException("Business logic failed for task: " + taskData.getId());
        }
    }
}
// Batch result tracker
/**
 * Thread-safe collector for the outcomes of one batch, keyed by task index.
 *
 * <p>Worker threads call {@link #recordResult} concurrently. {@link #getBatchResult} builds a
 * point-in-time snapshot.
 */
class BatchResultTracker {
    private final int totalTasks;
    private final String batchId;
    private final AtomicInteger completedCount = new AtomicInteger(0);
    private final AtomicInteger successCount = new AtomicInteger(0);
    private final Map<Integer, TaskResult> results = new ConcurrentHashMap<>();

    public BatchResultTracker(int totalTasks, String batchId) {
        this.totalTasks = totalTasks;
        this.batchId = batchId;
    }

    /** Stores a task outcome under its index and updates the live progress counters. */
    public void recordResult(TaskResult result) {
        results.put(result.getTaskIndex(), result);
        completedCount.incrementAndGet();
        if (result.isSuccess()) {
            successCount.incrementAndGet();
        }
    }

    /** Returns the number of results recorded so far (live progress, for status queries). */
    public int getCompletedCount() {
        return completedCount.get();
    }

    /** Returns the number of successful results recorded so far (live progress). */
    public int getSuccessCount() {
        return successCount.get();
    }

    /**
     * Builds a snapshot of the batch. Indexes without a recorded result are reported as failed
     * "Task timeout or cancelled".
     *
     * @param completedWithinTimeout whether the batch finished before its deadline
     */
    public BatchExecutionResult getBatchResult(boolean completedWithinTimeout) {
        List<TaskResult> taskResults = new ArrayList<>(totalTasks);
        int snapshotSuccessCount = 0;
        for (int i = 0; i < totalTasks; i++) {
            TaskResult result = results.get(i);
            if (result == null) {
                // Unfinished task: mark it as a timeout failure.
                result = new TaskResult("unknown", false, "Task timeout or cancelled", i);
            }
            if (result.isSuccess()) {
                snapshotSuccessCount++;
            }
            taskResults.add(result);
        }
        // Counts come from this snapshot rather than the live counter. A task that finishes while
        // the loop runs would otherwise be counted as a success yet listed as timed out.
        int finalFailureCount = totalTasks - snapshotSuccessCount;
        boolean allSuccess = completedWithinTimeout && (snapshotSuccessCount == totalTasks);
        return new BatchExecutionResult(batchId, allSuccess, completedWithinTimeout,
                snapshotSuccessCount, finalFailureCount, taskResults);
    }
}
// Task input data
/**
 * Input of a single task: an identifier plus the business payload to process.
 * Both fields are mutable through setters.
 */
class TaskData {
    private String id;
    private Object data; // business payload

    public TaskData(String id, Object data) {
        this.id = id;
        this.data = data;
    }

    /** Returns the task identifier. */
    public String getId() {
        return this.id;
    }

    /** Sets the task identifier. */
    public void setId(String id) {
        this.id = id;
    }

    /** Returns the business payload. */
    public Object getData() {
        return this.data;
    }

    /** Sets the business payload. */
    public void setData(Object data) {
        this.data = data;
    }
}
// Result of a single task
/**
 * Outcome of one task, including its position within the batch.
 */
class TaskResult {
    private final String taskId;
    private final boolean success;
    private final String errorMessage;
    private final int taskIndex;

    public TaskResult(String taskId, boolean success, String errorMessage, int taskIndex) {
        this.taskId = taskId;
        this.success = success;
        this.errorMessage = errorMessage;
        this.taskIndex = taskIndex;
    }

    /** Returns the id of the task. */
    public String getTaskId() {
        return this.taskId;
    }

    /** Returns true if the task finished without error. */
    public boolean isSuccess() {
        return this.success;
    }

    /** Returns the failure detail as passed to the constructor. */
    public String getErrorMessage() {
        return this.errorMessage;
    }

    /** Returns the task's 0-based position in the submitted list. */
    public int getTaskIndex() {
        return this.taskIndex;
    }
}
// 批次执行结果
class BatchExecutionResult {
private String batchId;
private boolean allSuccess;
private boolean completedWithinTimeout;
private int successCount;
private int failureCount;
private List
public BatchExecutionResult(String batchId, boolean allSuccess, boolean completedWithinTimeout,
int successCount, int failureCount, List<TaskResult> taskResults) {
this.batchId = batchId;
this.allSuccess = allSuccess;
this.completedWithinTimeout = completedWithinTimeout;
this.successCount = successCount;
this.failureCount = failureCount;
this.taskResults = taskResults;
}
// getters
public String getBatchId() { return batchId; }
public boolean isAllSuccess() { return allSuccess; }
public boolean isCompletedWithinTimeout() { return completedWithinTimeout; }
public int getSuccessCount() { return successCount; }
public int getFailureCount() { return failureCount; }
public List<TaskResult> getTaskResults() { return taskResults; }
}
// Database service interface (to be implemented)
/**
 * Persistence port for task and batch status; an implementation must be supplied.
 */
interface DatabaseService {

    /**
     * Persists the status of one task.
     *
     * @param taskId       task id
     * @param status       "SUCCESS" or "FAILED" as used by the batch executor
     * @param errorMessage failure detail; {@code null} for successful tasks
     */
    void updateTaskStatus(String taskId, String status, String errorMessage);

    /**
     * Persists the aggregate status of one batch.
     *
     * @param batchId      batch id
     * @param status       "SUCCESS" or "FAILED" as used by the batch executor
     * @param successCount number of successful tasks
     * @param failureCount number of failed tasks
     */
    void updateBatchStatus(String batchId, String status, int successCount, int failureCount);
}
// MQ服务接口(需要实现)
interface MQService {
void sendBatch(String batchId, List
}
// Usage example
/**
 * Wires a plain {@link ThreadPoolExecutor} and the demo services into the batch executor and runs
 * one batch with a 10-minute timeout.
 */
class Example {
    public void example() {
        // Bounded pool: 5 core / 10 max threads, a 100-slot queue, and the caller runs tasks
        // when the pool is saturated.
        ThreadPoolExecutor threadPool = new ThreadPoolExecutor(
                5, 10, 60L, TimeUnit.SECONDS,
                new LinkedBlockingQueue<>(100),
                new ThreadPoolExecutor.CallerRunsPolicy()
        );
        try {
            // Service instances (replace with real implementations).
            DatabaseService databaseService = new DatabaseServiceImpl();
            MQService mqService = new MQServiceImpl();
            BatchTaskExecutor executor = new BatchTaskExecutor(threadPool, databaseService, mqService);

            List<TaskData> taskDataList = Arrays.asList(
                    new TaskData("task1", "data1"),
                    new TaskData("task2", "data2"),
                    new TaskData("task3", "data3")
            );

            // Run the batch with a 10-minute timeout.
            BatchExecutionResult result = executor.executeBatch(taskDataList, "batch001", 10);
            if (result.isAllSuccess()) {
                System.out.println("所有任务执行成功,已推送到MQ");
            } else {
                System.out.println("批次执行失败,成功:" + result.getSuccessCount() +
                        ", 失败:" + result.getFailureCount());
            }
        } finally {
            // The default thread factory creates non-daemon threads. Without shutdown(), the idle
            // core threads would keep the JVM alive after the example returns.
            threadPool.shutdown();
        }
    }
}
// Example database service implementation
/**
 * Console-only stand-in for a real persistence layer: each call prints what would be stored.
 */
class DatabaseServiceImpl implements DatabaseService {

    @Override
    public void updateTaskStatus(String taskId, String status, String errorMessage) {
        // Replace with the real task-status update.
        String line = "更新任务状态: " + taskId + " -> " + status;
        System.out.println(line);
    }

    @Override
    public void updateBatchStatus(String batchId, String status, int successCount, int failureCount) {
        // Replace with the real batch-status update.
        String line = "更新批次状态: " + batchId + " -> " + status;
        System.out.println(line);
    }
}
// MQ服务实现示例
class MQServiceImpl implements MQService {
@Override
public void sendBatch(String batchId, List
// 实现MQ推送逻辑
System.out.println("推送批次到MQ: " + batchId + ", 任务数: " + taskDataList.size());
}
}
浙公网安备 33010602011771号