ThreadPoolTaskExecutor 结合 CompletableFuture 执行批量任务:将每个任务的状态记录到数据库,仅当全部任务成功时才将整批数据发送给下游,超时时间为 10 分钟
核心需求:
实现一个基于线程池的批量任务执行器,具备以下功能特性:
- 任务并发执行:将一批任务并发提交至线程池执行
- 执行状态监控:实时监控所有任务的执行状态,包括成功、失败及异常信息
- 超时控制机制:设置执行超时阈值(10 分钟),超时后中断未完成任务的执行
- 持久化状态记录:
  - 将每个任务的执行结果(成功/失败状态及异常信息)持久化至数据库
  - 记录批次整体执行状态和统计信息
 
- 条件性下游通知:
  - 仅当批次内所有任务均执行成功时,才触发消息队列推送
  - 执行失败或超时的批次不进行下游通知
 
- 事务一致性:确保任务执行结果与数据库状态记录的一致性
 
技术约束:
- 基于 Java 线程池(ThreadPoolExecutor)实现
- 支持异常隔离,单个任务失败不影响其他任务执行
- 需要具备优雅的资源清理机制
- 支持批次执行结果的详细统计和状态查询
 
业务场景:
典型的批量数据处理场景,要求高并发、高可靠性,并具备完整的执行审计和下游系统解耦能力。
1. 使用 ThreadPoolTaskExecutor 和 CompletableFuture 实现:
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.atomic.AtomicReferenceArray;
import java.util.stream.Collectors;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.stereotype.Service;
public class BatchTaskExecutor {
private final ThreadPoolTaskExecutor taskExecutor;
private final DatabaseService databaseService;
private final MQService mqService;
public BatchTaskExecutor(ThreadPoolTaskExecutor taskExecutor, 
                       DatabaseService databaseService, 
                       MQService mqService) {
    this.taskExecutor = taskExecutor;
    this.databaseService = databaseService;
    this.mqService = mqService;
}
/**
 * 执行批量任务
 * @param taskDataList 任务数据列表
 * @param batchId 批次ID
 * @param timeoutMinutes 超时时间(分钟)
 * @return 执行结果
 */
public CompletableFuture<BatchExecutionResult> executeBatchAsync(List<TaskData> taskDataList, 
                                                               String batchId, 
                                                               int timeoutMinutes) {
    
    // 创建所有任务的CompletableFuture
    List<CompletableFuture<TaskResult>> taskFutures = taskDataList.stream()
        .map(taskData -> createTaskFuture(taskData))
        .collect(Collectors.toList());
    
    // 组合所有任务的Future
    CompletableFuture<List<TaskResult>> allTasksFuture = 
        CompletableFuture.allOf(taskFutures.toArray(new CompletableFuture[0]))
            .thenApply(v -> taskFutures.stream()
                .map(CompletableFuture::join)
                .collect(Collectors.toList()));
    
    // 添加超时控制
    CompletableFuture<List<TaskResult>> timeoutFuture = addTimeoutControl(
        allTasksFuture, taskFutures, timeoutMinutes);
    
    // 处理执行结果
    return timeoutFuture.thenCompose(taskResults -> 
        processBatchResult(taskResults, batchId, taskDataList));
}
/**
 * 同步执行批量任务
 */
public BatchExecutionResult executeBatch(List<TaskData> taskDataList, 
                                       String batchId, 
                                       int timeoutMinutes) {
    try {
        return executeBatchAsync(taskDataList, batchId, timeoutMinutes).get();
    } catch (Exception e) {
        throw new RuntimeException("批量任务执行失败", e);
    }
}
private CompletableFuture<TaskResult> createTaskFuture(TaskData taskData) {
    return CompletableFuture
        .supplyAsync(() -> {
            try {
                // 执行具体业务逻辑
                executeBusinessLogic(taskData);
                return new TaskResult(taskData.getId(), true, null);
            } catch (Exception e) {
                return new TaskResult(taskData.getId(), false, e.getMessage());
            }
        }, taskExecutor.getThreadPoolExecutor())
        .exceptionally(throwable -> 
            new TaskResult(taskData.getId(), false, throwable.getMessage()));
}
private CompletableFuture<List<TaskResult>> addTimeoutControl(
        CompletableFuture<List<TaskResult>> allTasksFuture,
        List<CompletableFuture<TaskResult>> taskFutures,
        int timeoutMinutes) {
    
    // 创建超时Future
    CompletableFuture<List<TaskResult>> timeoutFuture = new CompletableFuture<>();
    
    // 设置超时处理
    ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
    scheduler.schedule(() -> {
        if (!timeoutFuture.isDone()) {
            // 取消所有未完成的任务
            taskFutures.forEach(future -> future.cancel(true));
            
            // 收集已完成的结果,未完成的标记为超时
            List<TaskResult> results = new ArrayList<>();
            for (CompletableFuture<TaskResult> future : taskFutures) {
                if (future.isDone() && !future.isCancelled()) {
                    try {
                        results.add(future.get());
                    } catch (Exception e) {
                        results.add(new TaskResult("unknown", false, "任务执行异常"));
                    }
                } else {
                    results.add(new TaskResult("unknown", false, "任务执行超时"));
                }
            }
            timeoutFuture.complete(results);
        }
        scheduler.shutdown();
    }, timeoutMinutes, TimeUnit.MINUTES);
    
    // 正常完成的情况
    allTasksFuture.whenComplete((results, throwable) -> {
        if (!timeoutFuture.isDone()) {
            if (throwable != null) {
                timeoutFuture.completeExceptionally(throwable);
            } else {
                timeoutFuture.complete(results);
            }
        }
        scheduler.shutdown();
    });
    
    return timeoutFuture;
}
private CompletableFuture<BatchExecutionResult> processBatchResult(
        List<TaskResult> taskResults, String batchId, List<TaskData> taskDataList) {
    
    return CompletableFuture.supplyAsync(() -> {
        // 统计执行结果
        long successCount = taskResults.stream().mapToLong(r -> r.isSuccess() ? 1 : 0).sum();
        long failureCount = taskResults.size() - successCount;
        boolean allSuccess = successCount == taskResults.size();
        
        // 创建批次执行结果
        BatchExecutionResult batchResult = new BatchExecutionResult(
            batchId, allSuccess, true, (int)successCount, (int)failureCount, taskResults);
        
        // 异步更新数据库状态
        updateDatabaseStatusAsync(batchResult);
        
        // 如果全部成功,推送到MQ
        if (allSuccess) {
            mqService.sendBatchAsync(batchId, taskDataList);
        }
        
        return batchResult;
    }, taskExecutor.getThreadPoolExecutor());
}
private void updateDatabaseStatusAsync(BatchExecutionResult batchResult) {
    // 异步更新任务状态
    List<CompletableFuture<Void>> updateFutures = batchResult.getTaskResults().stream()
        .map(taskResult -> CompletableFuture.runAsync(() -> {
            String status = taskResult.isSuccess() ? "SUCCESS" : "FAILED";
            databaseService.updateTaskStatus(taskResult.getTaskId(), status, taskResult.getErrorMessage());
        }, taskExecutor.getThreadPoolExecutor()))
        .collect(Collectors.toList());
    
    // 等待所有任务状态更新完成后,更新批次状态
    CompletableFuture.allOf(updateFutures.toArray(new CompletableFuture[0]))
        .thenRunAsync(() -> {
            String batchStatus = batchResult.isAllSuccess() ? "SUCCESS" : "FAILED";
            databaseService.updateBatchStatus(batchResult.getBatchId(), batchStatus,
                batchResult.getSuccessCount(), batchResult.getFailureCount());
        }, taskExecutor.getThreadPoolExecutor());
}
/**
 * 批量任务执行,支持任务分组和限流
 */
public CompletableFuture<List<BatchExecutionResult>> executeBatchesWithLimit(
        List<List<TaskData>> batchGroups, 
        String baseBatchId,
        int timeoutMinutes,
        int concurrentBatchLimit) {
    
    AtomicInteger batchCounter = new AtomicInteger(0);
    
    // 创建信号量控制并发批次数
    Semaphore semaphore = new Semaphore(concurrentBatchLimit);
    
    List<CompletableFuture<BatchExecutionResult>> batchFutures = batchGroups.stream()
        .map(taskDataList -> CompletableFuture
            .supplyAsync(() -> {
                try {
                    semaphore.acquire();
                    String batchId = baseBatchId + "_" + batchCounter.incrementAndGet();
                    return executeBatch(taskDataList, batchId, timeoutMinutes);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new RuntimeException("批次执行被中断", e);
                } finally {
                    semaphore.release();
                }
            }, taskExecutor.getThreadPoolExecutor()))
        .collect(Collectors.toList());
    
    return CompletableFuture.allOf(batchFutures.toArray(new CompletableFuture[0]))
        .thenApply(v -> batchFutures.stream()
            .map(CompletableFuture::join)
            .collect(Collectors.toList()));
}
// 具体的业务逻辑执行方法
private void executeBusinessLogic(TaskData taskData) throws Exception {
    // 实现具体的业务逻辑
    Thread.sleep(100); // 模拟处理时间
    
    // 模拟可能的业务异常
    if (taskData.getId().contains("error")) {
        throw new RuntimeException("Business logic failed for task: " + taskData.getId());
    }
}
}
/**
 * Input of a single task: an identifier plus an opaque business payload.
 */
class TaskData {

    private String id;

    private Object data;

    public TaskData(String taskId, Object payload) {
        id = taskId;
        data = payload;
    }

    public String getId() {
        return id;
    }

    public void setId(String newId) {
        id = newId;
    }

    public Object getData() {
        return data;
    }

    public void setData(Object newData) {
        data = newData;
    }
}
/**
 * Outcome of a single task: its id, whether it succeeded, and the error text when it did not.
 */
class TaskResult {

    private final String taskId;
    private final boolean success;
    private final String errorMessage;

    public TaskResult(String id, boolean ok, String error) {
        taskId = id;
        success = ok;
        errorMessage = error;
    }

    public String getTaskId() {
        return taskId;
    }

    public boolean isSuccess() {
        return success;
    }

    /** Error text for a failed task; null when the caller passed null (e.g. for successes). */
    public String getErrorMessage() {
        return errorMessage;
    }
}
// 批次执行结果
class BatchExecutionResult {
private String batchId;
private boolean allSuccess;
private boolean completedWithinTimeout;
private int successCount;
private int failureCount;
private List
public BatchExecutionResult(String batchId, boolean allSuccess, boolean completedWithinTimeout,
                          int successCount, int failureCount, List<TaskResult> taskResults) {
    this.batchId = batchId;
    this.allSuccess = allSuccess;
    this.completedWithinTimeout = completedWithinTimeout;
    this.successCount = successCount;
    this.failureCount = failureCount;
    this.taskResults = taskResults;
}
public String getBatchId() { return batchId; }
public boolean isAllSuccess() { return allSuccess; }
public boolean isCompletedWithinTimeout() { return completedWithinTimeout; }
public int getSuccessCount() { return successCount; }
public int getFailureCount() { return failureCount; }
public List<TaskResult> getTaskResults() { return taskResults; }
}
/**
 * Persistence port for task and batch status.
 */
interface DatabaseService {

    /**
     * Records the status of one task.
     *
     * @param errorMessage failure details; may be null for successful tasks
     */
    public abstract void updateTaskStatus(String taskId, String status, String errorMessage);

    /**
     * Records the overall status and counts of a batch.
     */
    public abstract void updateBatchStatus(String batchId, String status, int successCount, int failureCount);
}
// MQ服务接口
interface MQService {
void sendBatchAsync(String batchId, List
}
/**
 * Spring configuration for the pool that runs batch tasks.
 */
@Configuration
class ThreadPoolConfig {

    /**
     * Pool for batch tasks: 10 core / 20 max threads and a queue of 500. When saturated, the
     * submitting thread runs the task itself (CallerRunsPolicy).
     * NOTE(review): awaitTermination is 60s while a batch may run up to 10 minutes, so a shutdown can
     * cut long batches short; confirm this is intended.
     */
    @Bean("batchTaskExecutor")
    public ThreadPoolTaskExecutor batchTaskExecutor() {
        ThreadPoolTaskExecutor pool = new ThreadPoolTaskExecutor();
        pool.setThreadNamePrefix("BatchTask-");

        // Sizing
        pool.setCorePoolSize(10);
        pool.setMaxPoolSize(20);
        pool.setQueueCapacity(500);
        pool.setKeepAliveSeconds(60);

        // Saturation and shutdown behaviour
        pool.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
        pool.setWaitForTasksToCompleteOnShutdown(true);
        pool.setAwaitTerminationSeconds(60);

        pool.initialize();
        return pool;
    }
}
/**
 * Usage example: runs the same three-task batch once asynchronously and once synchronously.
 */
@Service
class BatchProcessingService {

    @Autowired
    private BatchTaskExecutor batchTaskExecutor;

    public void processBatch() {
        List<TaskData> tasks = Arrays.asList(
            new TaskData("task1", "data1"),
            new TaskData("task2", "data2"),
            new TaskData("task3", "data3"));

        // Asynchronous: the call returns immediately; the outcome is reported when the batch completes.
        batchTaskExecutor.executeBatchAsync(tasks, "batch001", 10)
            .thenAccept(this::reportAsyncResult)
            .exceptionally(error -> {
                System.err.println("批次执行异常: " + error.getMessage());
                return null;
            });

        // Synchronous: blocks until the batch result is available.
        BatchExecutionResult syncResult = batchTaskExecutor.executeBatch(tasks, "batch002", 10);
        System.out.println("同步执行结果: " + syncResult.isAllSuccess());
    }

    /** Prints the outcome of the asynchronous batch. */
    private void reportAsyncResult(BatchExecutionResult result) {
        if (!result.isAllSuccess()) {
            System.out.println("批次执行失败,成功:" + result.getSuccessCount() +
                               ", 失败:" + result.getFailureCount());
            return;
        }
        System.out.println("批次执行成功,已推送到MQ");
    }
}
2. 如果想了解背后的实现原理,可以参考下面基于原生 ThreadPoolExecutor 的传统实现:
import java.util.;
import java.util.concurrent.;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
/**
 * Traditional batch executor built directly on a {@link ThreadPoolExecutor}. It runs all tasks
 * concurrently and waits for them under one batch-wide deadline, cancelling (interrupting) stragglers.
 * It then persists every task result plus the batch summary, and notifies MQ only when every task
 * succeeded within the timeout.
 */
public class BatchTaskExecutor {

    private final ThreadPoolExecutor threadPool;
    private final DatabaseService databaseService;
    private final MQService mqService;

    public BatchTaskExecutor(ThreadPoolExecutor threadPool,
                             DatabaseService databaseService,
                             MQService mqService) {
        this.threadPool = threadPool;
        this.databaseService = databaseService;
        this.mqService = mqService;
    }

    /**
     * Executes a batch of tasks and blocks until it finishes or times out.
     *
     * @param taskDataList   task inputs (copied)
     * @param batchId        batch identifier
     * @param timeoutMinutes timeout in minutes for the whole batch
     * @return the batch result; tasks that did not finish in time are reported under their real ids
     * @throws NullPointerException if {@code taskDataList} is null
     */
    public BatchExecutionResult executeBatch(List<TaskData> taskDataList,
                                             String batchId,
                                             int timeoutMinutes) {
        Objects.requireNonNull(taskDataList, "taskDataList");
        List<TaskData> tasks = new ArrayList<>(taskDataList);

        // Submit all tasks; index i of futures corresponds to index i of tasks.
        List<Future<TaskResult>> futures = new ArrayList<>(tasks.size());
        for (int i = 0; i < tasks.size(); i++) {
            futures.add(submitTask(tasks.get(i), i));
        }

        // Wait for completion or timeout.
        BatchExecutionResult batchResult = waitForCompletion(futures, tasks, batchId, timeoutMinutes);

        // Persist first; if this throws, MQ is not notified.
        updateDatabaseStatus(batchResult);

        // Push downstream only when every task succeeded within the timeout.
        if (batchResult.isAllSuccess()) {
            mqService.sendBatch(batchId, tasks);
        }

        return batchResult;
    }

    /**
     * Submits one task. A rejected submission becomes an already-failed task instead of aborting the
     * whole batch with nothing recorded.
     */
    private Future<TaskResult> submitTask(TaskData taskData, int taskIndex) {
        try {
            return threadPool.submit(() -> runTask(taskData, taskIndex));
        } catch (RejectedExecutionException e) {
            return CompletableFuture.completedFuture(
                new TaskResult(taskData.getId(), false, "Task rejected: " + describe(e), taskIndex));
        }
    }

    /** Runs the business logic for one task and converts any exception into a failed result. */
    private TaskResult runTask(TaskData taskData, int taskIndex) {
        try {
            executeBusinessLogic(taskData);
            return new TaskResult(taskData.getId(), true, null, taskIndex);
        } catch (InterruptedException e) {
            // Cancelled by the batch timeout; keep the interrupt visible to the worker.
            Thread.currentThread().interrupt();
            return new TaskResult(taskData.getId(), false, describe(e), taskIndex);
        } catch (Exception e) {
            return new TaskResult(taskData.getId(), false, describe(e), taskIndex);
        }
    }

    /**
     * Waits for all futures under a single monotonic deadline and cancels unfinished tasks on timeout
     * or interrupt. It then builds the result from the futures themselves, so counts and the per-task
     * list always agree.
     */
    private BatchExecutionResult waitForCompletion(List<Future<TaskResult>> futures,
                                                   List<TaskData> tasks,
                                                   String batchId,
                                                   int timeoutMinutes) {
        long deadline = System.nanoTime() + TimeUnit.MINUTES.toNanos(timeoutMinutes);
        boolean completedWithinTimeout = true;

        for (Future<TaskResult> future : futures) {
            long remaining = deadline - System.nanoTime();
            try {
                // get(0) still returns at once for a finished task, so a task that finished just before
                // the deadline is not misreported as timed out.
                future.get(Math.max(0L, remaining), TimeUnit.NANOSECONDS);
            } catch (TimeoutException e) {
                completedWithinTimeout = false;
                break;
            } catch (InterruptedException e) {
                // The waiting thread was interrupted: stop waiting and keep the interrupt status.
                Thread.currentThread().interrupt();
                completedWithinTimeout = false;
                break;
            } catch (ExecutionException | CancellationException e) {
                // Captured per task when the snapshot is built below.
            }
        }

        if (!completedWithinTimeout) {
            // cancel(true) interrupts running tasks; it is a no-op for finished ones.
            for (Future<TaskResult> future : futures) {
                future.cancel(true);
            }
        }

        return snapshot(futures, tasks, batchId, completedWithinTimeout);
    }

    /** Builds the batch result from the current state of every future. */
    private static BatchExecutionResult snapshot(List<Future<TaskResult>> futures,
                                                 List<TaskData> tasks,
                                                 String batchId,
                                                 boolean completedWithinTimeout) {
        List<TaskResult> taskResults = new ArrayList<>(futures.size());
        int successCount = 0;
        for (int i = 0; i < futures.size(); i++) {
            TaskResult result = resultOf(futures.get(i), tasks.get(i), i);
            if (result.isSuccess()) {
                successCount++;
            }
            taskResults.add(result);
        }
        int failureCount = taskResults.size() - successCount;
        boolean allSuccess = completedWithinTimeout && failureCount == 0;
        return new BatchExecutionResult(batchId, allSuccess, completedWithinTimeout,
                                        successCount, failureCount, taskResults);
    }

    /** Result of one future; unfinished or cancelled tasks are reported as timed out under their real id. */
    private static TaskResult resultOf(Future<TaskResult> future, TaskData taskData, int taskIndex) {
        if (future.isDone() && !future.isCancelled()) {
            try {
                return future.get(); // already done, so this does not block
            } catch (ExecutionException e) {
                Throwable cause = e.getCause() != null ? e.getCause() : e;
                return new TaskResult(taskData.getId(), false, describe(cause), taskIndex);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
        return new TaskResult(taskData.getId(), false, "Task timeout or cancelled", taskIndex);
    }

    /** Writes every task status, then the overall batch status. */
    private void updateDatabaseStatus(BatchExecutionResult batchResult) {
        for (TaskResult taskResult : batchResult.getTaskResults()) {
            if (taskResult.isSuccess()) {
                databaseService.updateTaskStatus(taskResult.getTaskId(), "SUCCESS", null);
            } else {
                databaseService.updateTaskStatus(taskResult.getTaskId(), "FAILED", taskResult.getErrorMessage());
            }
        }

        String batchStatus = batchResult.isAllSuccess() ? "SUCCESS" : "FAILED";
        databaseService.updateBatchStatus(batchResult.getBatchId(), batchStatus,
                                          batchResult.getSuccessCount(), batchResult.getFailureCount());
    }

    /** Concrete business logic; throw to signal failure. Simulated here. */
    private void executeBusinessLogic(TaskData taskData) throws Exception {
        Thread.sleep(100); // simulated processing time

        // Simulated business failure
        if (taskData.getId().contains("error")) {
            throw new RuntimeException("Business logic failed for task: " + taskData.getId());
        }
    }

    /** Human-readable error text; falls back to the class name for message-less exceptions. */
    private static String describe(Throwable t) {
        return t.getMessage() != null ? t.getMessage() : t.getClass().getName();
    }
}
/**
 * Thread-safe collector of per-task results for one batch, keyed by task index.
 *
 * <p>{@link #getBatchResult(boolean)} derives the counts from the same snapshot it returns, so the
 * counts and the task list always agree even if tasks are still reporting concurrently.
 */
class BatchResultTracker {

    private final int totalTasks;
    private final String batchId;
    /** Real task ids by index, used to label tasks that never reported; null when not supplied. */
    private final List<String> taskIds;
    private final AtomicInteger completedCount = new AtomicInteger(0);
    private final Map<Integer, TaskResult> results = new ConcurrentHashMap<>();

    public BatchResultTracker(int totalTasks, String batchId) {
        this(totalTasks, batchId, null);
    }

    /**
     * Creates a tracker that knows the task ids. Tasks that never report are then recorded under
     * their real id instead of "unknown".
     */
    public BatchResultTracker(String batchId, List<String> taskIds) {
        this(taskIds.size(), batchId, new ArrayList<>(taskIds));
    }

    private BatchResultTracker(int totalTasks, String batchId, List<String> taskIds) {
        this.totalTasks = totalTasks;
        this.batchId = batchId;
        this.taskIds = taskIds;
    }

    /** Records (or replaces) the result for {@code result.getTaskIndex()}. */
    public void recordResult(TaskResult result) {
        // Count each index once so a duplicate report cannot push progress past totalTasks.
        if (results.put(result.getTaskIndex(), result) == null) {
            completedCount.incrementAndGet();
        }
    }

    /** Number of distinct tasks that have reported so far (for progress monitoring). */
    public int getCompletedCount() {
        return completedCount.get();
    }

    /**
     * Builds the batch result from the current snapshot. Missing tasks are reported as failed with
     * "Task timeout or cancelled".
     */
    public BatchExecutionResult getBatchResult(boolean completedWithinTimeout) {
        List<TaskResult> taskResults = new ArrayList<>(totalTasks);
        int finalSuccessCount = 0;
        for (int i = 0; i < totalTasks; i++) {
            TaskResult result = results.get(i);
            if (result == null) {
                // Not finished: mark as timed out / cancelled.
                result = new TaskResult(placeholderId(i), false, "Task timeout or cancelled", i);
            }
            if (result.isSuccess()) {
                finalSuccessCount++;
            }
            taskResults.add(result);
        }

        int finalFailureCount = totalTasks - finalSuccessCount;
        boolean allSuccess = completedWithinTimeout && (finalSuccessCount == totalTasks);

        return new BatchExecutionResult(batchId, allSuccess, completedWithinTimeout,
                                        finalSuccessCount, finalFailureCount, taskResults);
    }

    private String placeholderId(int index) {
        return taskIds != null ? taskIds.get(index) : "unknown";
    }
}
/**
 * Input of a single task: an id plus the business payload the task operates on.
 */
class TaskData {
    private String id;
    private Object data; // business payload

    public TaskData(String initialId, Object initialData) {
        this.id = initialId;
        this.data = initialData;
    }

    public String getId() {
        return this.id;
    }

    public void setId(String value) {
        this.id = value;
    }

    public Object getData() {
        return this.data;
    }

    public void setData(Object value) {
        this.data = value;
    }
}
/**
 * Outcome of a single task, including its position within the batch.
 */
class TaskResult {
    private final String taskId;
    private final boolean success;
    private final String errorMessage;
    private final int taskIndex; // position of the task in the submitted list

    public TaskResult(String id, boolean succeeded, String message, int index) {
        taskId = id;
        success = succeeded;
        errorMessage = message;
        taskIndex = index;
    }

    public String getTaskId() {
        return taskId;
    }

    public boolean isSuccess() {
        return success;
    }

    public String getErrorMessage() {
        return errorMessage;
    }

    public int getTaskIndex() {
        return taskIndex;
    }
}
// 批次执行结果
class BatchExecutionResult {
private String batchId;
private boolean allSuccess;
private boolean completedWithinTimeout;
private int successCount;
private int failureCount;
private List
public BatchExecutionResult(String batchId, boolean allSuccess, boolean completedWithinTimeout,
                          int successCount, int failureCount, List<TaskResult> taskResults) {
    this.batchId = batchId;
    this.allSuccess = allSuccess;
    this.completedWithinTimeout = completedWithinTimeout;
    this.successCount = successCount;
    this.failureCount = failureCount;
    this.taskResults = taskResults;
}
// getters
public String getBatchId() { return batchId; }
public boolean isAllSuccess() { return allSuccess; }
public boolean isCompletedWithinTimeout() { return completedWithinTimeout; }
public int getSuccessCount() { return successCount; }
public int getFailureCount() { return failureCount; }
public List<TaskResult> getTaskResults() { return taskResults; }
}
/**
 * Persistence port for batch and task status (to be implemented by the application).
 */
interface DatabaseService {

    /** Records the overall status and counts of a batch. */
    void updateBatchStatus(String batchId, String status, int successCount, int failureCount);

    /** Records the status of one task; {@code errorMessage} may be null for successes. */
    void updateTaskStatus(String taskId, String status, String errorMessage);
}
// MQ服务接口(需要实现)
interface MQService {
void sendBatch(String batchId, List
}
/**
 * Usage example for the traditional executor.
 */
class Example {

    /**
     * Runs one three-task batch with a 10-minute timeout and prints the outcome. The pool is shut down
     * afterwards; otherwise its non-daemon core threads would keep the JVM running.
     */
    public void example() {
        // Create the thread pool
        ThreadPoolExecutor threadPool = new ThreadPoolExecutor(
            5, 10, 60L, TimeUnit.SECONDS,
            new LinkedBlockingQueue<>(100),
            new ThreadPoolExecutor.CallerRunsPolicy()
        );
        try {
            // Service instances (replace with real implementations)
            DatabaseService databaseService = new DatabaseServiceImpl();
            MQService mqService = new MQServiceImpl();

            // Create the batch executor
            BatchTaskExecutor executor = new BatchTaskExecutor(threadPool, databaseService, mqService);

            // Prepare task data
            List<TaskData> taskDataList = Arrays.asList(
                new TaskData("task1", "data1"),
                new TaskData("task2", "data2"),
                new TaskData("task3", "data3")
            );

            // Execute the batch with a 10-minute timeout
            BatchExecutionResult result = executor.executeBatch(taskDataList, "batch001", 10);

            // Check the result
            if (result.isAllSuccess()) {
                System.out.println("所有任务执行成功,已推送到MQ");
            } else {
                System.out.println("批次执行失败,成功:" + result.getSuccessCount() +
                                   ", 失败:" + result.getFailureCount());
            }
        } finally {
            // Release worker threads; tasks still running after a timeout are allowed to finish.
            threadPool.shutdown();
        }
    }
}
/**
 * Sample DatabaseService that only logs to stdout; replace with real persistence.
 */
class DatabaseServiceImpl implements DatabaseService {

    @Override
    public void updateTaskStatus(String taskId, String status, String errorMessage) {
        // Placeholder for the real task-status update
        String line = "更新任务状态: " + taskId + " -> " + status;
        System.out.println(line);
    }

    @Override
    public void updateBatchStatus(String batchId, String status, int successCount, int failureCount) {
        // Placeholder for the real batch-status update
        String line = "更新批次状态: " + batchId + " -> " + status;
        System.out.println(line);
    }
}
// MQ服务实现示例
class MQServiceImpl implements MQService {
@Override
public void sendBatch(String batchId, List
// 实现MQ推送逻辑
System.out.println("推送批次到MQ: " + batchId + ", 任务数: " + taskDataList.size());
}
}
                    
                
                
            
        
浙公网安备 33010602011771号