工具类-循环栅栏 CyclicBarrier

概念

一组线程互相等待,直到所有线程都到达某个屏障点(barrier point)后再继续执行,特性如下:

  1. 循环使用:与 CountDownLatch 不同,CyclicBarrier 可以重复使用
  2. 屏障点回调:支持在所有线程到达屏障后执行特定操作(Runnable)
  3. 多线程同步:协调固定数量的线程在屏障点等待
  4. 异常处理:提供 BrokenBarrierException 处理屏障破坏情况

类结构

public class CyclicBarrier {
  	// 可重入锁
    private final ReentrantLock lock = new ReentrantLock();
  	// 条件变量
    private final Condition trip = lock.newCondition();
  	// 屏障阈值,构造 CyclicBarrier 时确定,不可变,决定多少个线程调用 await() 后屏障才放行
    private final int parties;
  	// 屏障命令(回调函数)
    private final Runnable barrierCommand;
  	// 屏障
    private Generation generation = new Generation();
  	// 当前尚未达到屏障处的线程数,每个线程调用 await(),count 就 -1,直到0(恢复为 parties)
    private int count;
}

核心方法

// 非使当前线程等待,直到所有参与线程都调用了此方法
// 当最后一个线程调用后,屏障被触发,所有线程被释放
public int await() {}

// 带超时的等待。在指定时间内等待其他线程到达,超时后抛出 TimeoutException 并破坏屏障
public int await(long timeout, TimeUnit unit) {}

// 将屏障重置为初始状态(专门提供出来,在外部手动调用)
public void reset() {}

// 未到达屏障的线程数量
public int getNumberWaiting() {}

// 触发屏障所需的线程数(构造时指定的 parties 值)
public int getParties() { return parties; }

// 屏障是否处于破坏状态
public boolean isBroken() {}

// 打破屏障,当出现异常时调用
private void breakBarrier() {}

// 刷新屏障,当最后一个线程到达屏障时调用
private void nextGeneration() {}

使用示例

代码

public static void main(String[] args) {
    // 创建 CyclicBarrier(规定需要3个线程到达屏障)
    CyclicBarrier barrier = new CyclicBarrier(3, 
            () -> System.out.println("屏障触发,所有线程到达"));

    // 创建3个测试线程
    for (int i = 0; i < 3; i++) {
        new Thread(() -> {
            try {
                System.out.println(Thread.currentThread().getName() + " 到达屏障1");
                barrier.await();

                System.out.println(Thread.currentThread().getName() + " 到达屏障2");
                barrier.await();

                System.out.println(Thread.currentThread().getName() + " 到达屏障3");
                barrier.await();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }).start();
    }
}

运行结果

Thread-0 到达屏障1
Thread-2 到达屏障1
Thread-1 到达屏障1
屏障触发,所有线程到达 ----- 第 1 次到达屏障
Thread-1 到达屏障2
Thread-2 到达屏障2
Thread-0 到达屏障2
屏障触发,所有线程到达 ----- 第 2 次到达屏障
Thread-0 到达屏障3
Thread-1 到达屏障3
Thread-2 到达屏障3
屏障触发,所有线程到达 ----- 第 3 次到达屏障

流程图

flowchart TD A[线程1调用await] --> B{count--} B -->|count=2| C[线程1阻塞] D[线程2调用await] --> E{count--} E -->|count=1| F[线程2阻塞] G[线程3调用await] --> H{count--} H -->|count=0| I[执行屏障命令] I --> J[唤醒所有线程] J --> K[重置count=3] K --> L[线程1继续执行] K --> M[线程2继续执行] K --> N[线程3继续执行] L --> O[线程1二次await] M --> P[线程2二次await] N --> Q[线程3二次await]

工作原理

// java.util.concurrent.CyclicBarrier#await()
public int await() throws InterruptedException, BrokenBarrierException {
    try {
        return dowait(false, 0L); // 和超时方法是同一个,只是传入参数表示不超时
    } catch (TimeoutException toe) {
        throw new Error(toe); // cannot happen
    }
}

// java.util.concurrent.CyclicBarrier#dowait
private int dowait(boolean timed, long nanos) throws InterruptedException, BrokenBarrierException, TimeoutException {
    
  	// 使用 ReentrantLock 来实现的,这是独占锁
  	final ReentrantLock lock = this.lock;
    lock.lock(); // 每个线程都要获取锁
    try {
        final Generation g = generation; // 当前屏障

        if (g.broken) throw new BrokenBarrierException(); // 屏障如果已经被破坏,直接抛出异常

      	// 初始 count == parties,每一个线程到达屏障 count 就 -1
        int index = --count;
      
      	// 最后一个到达的线程屏障点的处理
        if (index == 0) {
            boolean ranAction = false;
            try {
                final Runnable command = barrierCommand;
                if (command != null)
                    command.run(); // 执行屏障任务
                ranAction = true;
                nextGeneration(); // 重置屏障,唤醒所有线程
                return 0;
            } finally {
                if (!ranAction) // 屏障任务执行失败,破坏屏障
                    breakBarrier();
            }
        }

        // 非最后一个线程到达屏障点的处理
        for (;;) { // 无限循环,防止虚假唤醒(即使没有线程调用 signal(),操作系统也可能随机唤醒线程,概率低但存在)
            try { // 处理也很简单就是利用 condition.await() 阻塞线程(条件变量,这里不赘述,想详细了解去看看 AQS)
                if (!timed) // 如果是无限期等待
                    trip.await();
                else if (nanos > 0L) // 如果是超时等待
                    nanos = trip.awaitNanos(nanos);
            } catch (InterruptedException ie) {
               ...
            }
          
          	// 唤醒后的线程会执行下面的逻辑

            if (g != generation) // 如果发现屏障已经刷新,返回当前索引(返回也不会用到,主要是会跳出循环)
                return index;

            if (timed && nanos <= 0L) { // 走到这里说明屏障未刷新,并且超时时间也到了,破坏屏障并抛出异常
                breakBarrier();
                throw new TimeoutException();
            }
        }
    } finally {
        lock.unlock(); // 释放锁,每次 await() 后都会释放锁,不然别的线程拿不到锁,虽然释放锁但是线程也会挂起
    }
}

逻辑还是比较清晰,不是最后一个线程就挂起,是最后一个线程就执行回调,并且刷新屏障,遇到异常就打破屏障
利用了 ReenTrantLock 和 Condition 实现的,每个线程必须持有锁才能执行,没获取到锁就进入条件队列,ReenTrantLock 这里也不赘述了

再看看怎么刷新和打破屏障的,源码如下:

// 刷新屏障
private void nextGeneration() {

    trip.signalAll(); // 唤醒所有线程,AQS 的,这里不赘述

    count = parties; // 恢复 count 为 parties
    generation = new Generation(); // 新一代屏障
}

// 打破屏障
private void breakBarrier() {
    generation.broken = true;
    count = parties;
    trip.signalAll();
}

注意事项

必须与构造函数中的 parties 参数一致

posted @ 2023-05-24 16:53  CyrusHuang  阅读(186)  评论(0)    收藏  举报