46、锁存器和栅栏

内容来自王争 Java 编程之美

在平时开发中,有时候我们需要自己编写代码,测试某个接口在不同并发下的性能
比如测试在 N 个(N 可以是 100、200 等任意数)接口请求并发执行时接口的响应时间、QPS 等
为了模拟接口的并发请求,我们需要用到多线程,除此之外,线程之间的协调执行,还需要用到本节要讲的 CountDownLatch(锁存器)和 CyclicBarrier(栅栏)
接下来我们就结合这个需求的开发,详细讲解一下 CountDownLatch 和 CyclicBarrier 的用法以及实现原理

  • CountDownLatch 内部维护了一个计数器,通过调用 countDown() 方法来减少计数器的值,当计数器的值变为 0 时,所有调用 await() 的线程都会被唤醒
  • CyclicBarrier 内部维护了一个计数器,通过调用 await() 方法来来减少计数器的值并阻塞当前线程,当计数器的值变为 0 时,所有调用 await() 的线程都会被唤醒

1、CountDownLatch 的用法

CountDownLatch 内部维护了一个计数器,通过调用 countDown() 方法来减少计数器的值,当计数器的值变为 0 时,所有调用 await() 的线程都会被唤醒

1.1、介绍

CountDownLatch 中文名称叫作锁存器,其提供的常用方法有以下几个

// 构造函数, 传入 count 值
public CountDownLatch(int count);

// 阻塞等待 count 值变为 0
public void await() throws InterruptedException;

public boolean await(long timeout, TimeUnit unit) throws InterruptedException;

// 将 count 值减一
public void countDown();

CountDownLatch 的作用有点类似 Thread 类中的 join() 函数,用于一个线程等待其他多个线程的事件发生
对于 join() 函数来说,这里的事件指的是线程结束,对于 CountDownLatch 来说,这里的事件可以根据业务逻辑来定义
除此之外,使用 join() 需要知道被等待的线程是谁,而使用 CountDownLatch 则不需要,因此 CountDownLatch 相对于 join() 函数来说更加通用

1.2、示例

我们举个例子解释一下,代码如下所示

  • 在 DemoJoin 中,主线程(执行 main() 函数的线程)调用 join() 函数阻塞等待线程 t1 和 t2 的结束
  • 在 DemoLatch 中,CountDownLatch 中的 count 值初始化为 2,主线程调用 await() 函数阻塞等待 count 值变为 0
    另外两个线程在执行完部分逻辑之后,调用 countDown() 函数将 count 值减一
    当两个线程均执行 countDown() 函数之后,count 值变为 0,阻塞在 await() 函数上的主线程被唤醒,继续执行后续逻辑
public class DemoJoin {

    public static class RunnableForJoin implements Runnable {
        @Override
        public void run() {
            // ... do something ...
        }
    }

    public static void main(String[] args) throws InterruptedException {
        Thread t1 = new Thread(new RunnableForJoin());
        Thread t2 = new Thread(new RunnableForJoin());
        t1.start();
        t2.start();
        t1.join(); // join() 只用来等待线程执行结束, 并且必须知道被等待线程是谁
        t2.join();
    }
}
public class DemoLatch {

    private static final CountDownLatch latch = new CountDownLatch(2);

    public static class RunnableForLatch implements Runnable {
        @Override
        public void run() {
            // ... do something ...
            latch.countDown();
            // ... do other thing ...
        }
    }

    public static void main(String[] args) throws InterruptedException {
        new Thread(new RunnableForLatch()).start();
        new Thread(new RunnableForLatch()).start();
        latch.await(); // 等待 something 执行完成而非等待线程结束, 并且不需要知道在等谁
        // ... 执行后续逻辑 ...
    }
}

2、CountDownLatch 的实现原理

2.1、源码

CountDownLatch 的用法非常简单,其实现原理也不难,底层依赖 AQS 来实现,CountDownLatch 的部分源码如下所示
AQS 模板方法的使用方法是比较固定的,因此 CountDownLatch 的代码结构跟之前讲过的 ReentrantLock、ReentrantReadWriteLock、Semaphore 的代码结构是类似的
代码结构大致为:具体化抽象模板方法类 AQS,在具体类 Sync 中实现 AQS 中的抽象方法,使用具体类 Sync 中的模板方法来编程

public class CountDownLatch {

    // 具体化抽象模板方法类 AQS
    private static final class Sync extends AbstractQueuedSynchronizer {
        Sync(int count) {
            setState(count);// 将 count 值存储在 AQS 的 state 中
        }

        // 实现 AQS 的抽象方法
        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1; // 检查 count 是不是为 0 了
        }

        // 实现 AQS 的抽象方法
        protected boolean tryReleaseShared(int releases) {
            for (; ; ) { // 执行 count--
                int c = getState();
                if (c == 0) return false;
                int nextc = c - 1;
                if (compareAndSetState(c, nextc)) return nextc == 0;
            }
        }
    }

    private final Sync sync;

    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

    public void await() throws InterruptedException {
        sync.acquireSharedInterruptibly(1); // 使用 Sync 的模板方法编程
    }

    public boolean await(long timeout, TimeUnit unit) throws InterruptedException {
        // 使用 Sync 的模板方法编程
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }

    public void countDown() {
        sync.releaseShared(1); // 使用 Sync 的模板方法编程
    }
}

我们重点看下 await() 和 countDown() 这两个函数的实现原理

2.2、await()

await() -> acquireSharedInterruptibly() -> tryAcquireShared() -> doAcquireSharedInterruptibly()

从上述代码,我们可以发现,await() 函数直接调用 AQS 的 acquireSharedInterruptibly() 函数,acquireSharedInterruptibly() 函数的源码如下所示
CountDownLatch 的 count 值存储在 AQS 的 state 中,acquireSharedInterruptibly() 函数调用 tryAcquireShared() 函数查看 state 是否为 0
如果 state 为 0,则直接返回,如果 state 不为 0,则调用 doAcquireSharedInterruptibly() 函数阻塞等待 state 变为 0

tryAcquireShared() 函数在上述 CountDownLatch 的 Sync 内部类中已经给出
对于 doAcquireSharedInterruptibly() 函数,其跟我们之前在讲解 ReentrantReadWriteLock 实现原理时讲到的 doAcquireShared() 函数类似,读者可以自行查阅源码了解

public final void acquireSharedInterruptibly(int arg) throws InterruptedException {
    if (Thread.interrupted()) throw new InterruptedException();
    if (tryAcquireShared(arg) < 0) // 查看 state 是否为 0, 否则就放入等待队列等待 state 为 0
        doAcquireSharedInterruptibly(arg); // 阻塞等待 state 为 0
}

2.3、countDown()

countDown() -> releaseShared() -> tryReleaseShared() -> doReleaseShared()

我们再来看下 CountDownLatch 中的 countDown() 函数,countDown() 函数直接调用 AQS 的 releaseShared() 函数来实现,releaseShared() 函数的源码如下所示
releaseShared() 函数调用 tryReleaseShared() 函数将 state 减一
如果此时 state 变为 0,则执行 doReleaseShared() 函数唤醒等待队列中的线程,也就是唤醒调用了 await() 函数的线程

tryReleaseShared() 函数在上述 CountDownLatch 的 Sync 内部类中已经给出
对于 doReleaseShared() 函数,其在我们讲解 ReentrantReadWriteLock 实现原理时已经详细讲解,这里就不再赘述了

public final boolean releaseShared(int arg) {
    // state--, 如果 state 变为 0, 则执行 doReleaseShared()
    if (tryReleaseShared(arg)) {
        doReleaseShared(); // 唤醒等待队列中的线程
        return true;
    }
    return false;
}

3、CyclicBarrier 的用法

CyclicBarrier 内部维护了一个计数器,通过调用 await() 方法来来减少计数器的值并阻塞当前线程,当计数器的值变为 0 时,所有调用 await() 的线程都会被唤醒

CyclicBarrier 的字面意思是可循环使用(Cyclic)的屏障(Barrier)
它要做的事情是,让一组线程到达一个屏障(也可以叫同步点)时被阻塞,直到最后一个线程到达屏障时,屏障才会开门,所有被屏障拦截的线程于会继续运行

3.1、介绍

接下来我们再来看下 CyclicBarrier,CyclicBarrier 包含的常用方法有如下所示

// 构造函数, 传入 parties
public CyclicBarrier(int parties);

// 构造函数, 用于在线程到达屏障时, 优先执行 barrierAction, 方便处理更复杂的业务场景
public CyclicBarrier(int parties, Runnable barrierAction);

// 调用 await() 函数的线程会将 parties 减一, 如果不为 0, 则阻塞, 直到为 0 为止
public int await() throws InterruptedException, BrokenBarrierException;

public int await(long timeout, TimeUnit unit) throws InterruptedException, BrokenBarrierException, TimeoutException;

CyclicBarrier 的中文名为栅栏,非常形象地解释了 CyclicBarrier 的作用,用于多个线程互相等待,互相等待的线程都就位之后,再同时开始执行

3.2、示例 1

我们举个例子解释一下,如下代码所示
我们创建了一个 parties 为 10 的 CyclicBarrier 对象,用于 10 个线程之间互相等待
尽管这 10 个线程启动(执行 start() 函数)的时间不同,但每个线程启动之后,都会调用 await() 函数,将 parties 减一,然后检查 parties 是否为 0
如果 parties 不为 0,则当前线程阻塞等待,如果 parties 为 0,则当前线程唤醒所有调用了 await() 函数的线程

public class Demo {

    private static final CyclicBarrier barrier = new CyclicBarrier(10);

    public static void main(String[] args) {
        for (int i = 0; i < 10; i++) {
            new Thread(new Runnable() {
                @Override
                public void run() {
                    try {
                        barrier.await();
                    } catch (InterruptedException e) {
                        e.printStackTrace(); // 当前线程被中断
                    } catch (BrokenBarrierException e) {
                        e.printStackTrace(); // 其他线程调用 await() 期间被中断
                    }
                    // 执行业务逻辑
                }
            }).start();
        }
        // 主线程需要等待以上 10 个线程执行结束, 方法有以下 3 种
        // 1、Sleep()
        // 2、join()
        // 3、CountDownLatch
    }
}

3.3、示例 2

对于 CountDownLatch 和 CyclicBarrier,前者是用于一个线程阻塞等待其他线程,后者是用于多个线程互相等待

使用 CountDownLatch,也可以实现 CyclicBarrier 所能实现的功能,如下代码所示
我们创建一个 count 值为 1 的 CountDownLatch 对象,10 个线程均调用 await() 函数阻塞等待 count 为 0
主线程调用 countDown() 函数将 count 值减一,变为 0,然后唤醒调用了 await() 函数的这 10 个线程,以此达到让这 10 个线程同步执行的目的

public class Demo {

    private static final CountDownLatch latch = new CountDownLatch(1);

    public static void main(String[] args) {
        for (int i = 0; i < 10; i++) {
            new Thread(new Runnable() {
                @Override
                public void run() {
                    try {
                        latch.await();
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                    // 执行业务逻辑
                }
            }).start();
        }
        latch.countDown();
        // 主线程需要等待以上 10 个线程执行结束, 方法有以下 3 种
        // 1、Sleep()
        // 2、join()
        // 3、CountDownLatch
    }
}

4、CyclicBarrier 的实现原理

CyclicBarrier 的使用方法比较简单,接下来我们来看下它的实现原理
跟 CountDownLatch 不同,CyclicBarrier 并未使用 AQS 来实现,而是使用之前讲到条件变量来实现的
CyclicBarrier 类的源码如下所示,为了更清晰地展示其核心实现原理,我对 CyclicBarrier 中的源码做了简化
调用 await() 函数先将 parties 减一,然后检查 parties 是否为 0

  • 如果不为 0,则调用 Condition 上的 await() 函数让当前线程阻塞等待
  • 如果为 0,则调用 Condition 上的 signalAll() 函数唤醒所有调用了 await() 函数的线程
public class CyclicBarrier {

    private final ReentrantLock lock = new ReentrantLock();
    private final Condition trip = lock.newCondition();
    private int parties;

    public CyclicBarrier(int parties) {
        this.parties = parties;
    }

    // 函数返回值: 线程调用 await() 函数之后的 parties 剩余值
    // 注意: 以下代码逻辑省略掉了对 InterruptedException 和 BrokenBarrierException 的处理
    public int await() throws InterruptedException, BrokenBarrierException {
        final ReentrantLock lock = this.lock;
        lock.lock();
        try {
            int index = --parties;
            if (index == 0) {  // 最后一个调用 await() 的线程, 唤醒其他线程
                trip.signalAll();
                return 0;
            }
            trip.await();
            return index;
        } finally {
            lock.unlock();
        }
    }
}

5、在接口性能测试中的应用

对 CountDownLatch 和 CyclicBarrier 的用法和实现原理有了了解之后,接下来我们使用 CountDownLatch 和 CyclicBarrier 来编写开头的接口并发性能测试代码

如果要测试 N 个接口请求并发执行时接口的性能,我们需要创建 N 个测试线程,让每个测试线程循环执行接口请求,并且记录每个接口请求的响应时间
主线程通过 CountDownLatch 来等待其他测试线程执行完成,然后再通过记录的运行数据,统计接口的性能,比如:平均响应时间、QPS 等
除此之外,在以下代码实现中,我们还使用了 CyclicBarrier,让各测试线程更加精确地同时开始执行,以便更加准确地测试指定并发下的接口性能

public class ApiBenchmark {

    private static int numThread = 20;         // 并发度为 20
    private static int numReqPerThread = 1000; // 每个线程请求 1000 次接口

    private static CountDownLatch latch = new CountDownLatch(numThread); // 锁存器: 等待各测试线程完成后,唤醒主线程
    private static CyclicBarrier barrier = new CyclicBarrier(numThread); // 栅栏: 让各测试线程更加精确地同时开始执行

    public static class TestRunnable implements Runnable {
        public List<Long> respTimes = new ArrayList<>();

        @Override
        public void run() {
            try {
                barrier.await();
            } catch (InterruptedException e) {
                e.printStackTrace();
            } catch (BrokenBarrierException e) {
                e.printStackTrace();
            }
            for (int i = 0; i < numReqPerThread; i++) {
                long reqStartTime = System.nanoTime();
                // ... 调用接口
                long reqEndTime = System.nanoTime();
                respTimes.add(reqEndTime - reqStartTime);
            }
            latch.countDown();
        }
    }

    public static void main(String[] args) throws InterruptedException {
        // 创建线程
        Thread[] threads = new Thread[numThread];
        TestRunnable[] runnables = new TestRunnable[numThread];
        for (int i = 0; i < numThread; i++) {
            runnables[i] = new TestRunnable();
            threads[i] = new Thread(runnables[i]);
        }

        // 启动线程
        long startTime = System.nanoTime();
        for (int i = 0; i < numThread; i++) {
            threads[i].start();
        }

        // 等待测试线程结束
        latch.await();
        long endTime = System.nanoTime();

        // 统计接口性能
        long qps = (numThread * numReqPerThread * 1000) / ((endTime - startTime) / 1000000);
        float avgRespTime = 0.0f;
        for (int i = 0; i < numThread; i++) {
            for (Long respTime : runnables[i].respTimes) {
                avgRespTime += respTime;
            }
        }
        avgRespTime /= (numThread * numReqPerThread);
    }
}

6、课后思考题

在本节中我们讲到,CountDownLatch 底层使用 AQS 来实现,CyclicBarrier 底层使用条件变量来实现,那么本节中提到的 join() 函数是怎么实现的呢

public final void join() throws InterruptedException {
    join(0);
}

public final synchronized void join(final long millis) throws InterruptedException {
    if (millis > 0) {
        if (isAlive()) {
            final long startTime = System.nanoTime();
            long delay = millis;
            do {
                wait(delay);
            } while (isAlive() && (delay = millis -
                    TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime)) > 0);
        }
    } else if (millis == 0) {
        while (isAlive()) {
            wait(0);
        }
    } else {
        throw new IllegalArgumentException("timeout value is negative");
    }
}

join() 采用 Object 上的 wait() 函数实现,它是基于 Java 内置条件变量实现的,当线程终止时,会调用线程自身的 notifyAll() 方法
image

public synchronized Object get(long millis) throws InterruptedException {
    // 1、加锁
    long future = System.currentTimeMillis() + millis;
    long remaining = millis;

    // 2、检查状态变量是否满足条件, while 循环
    while ((result == null) && remaining > 0) {
        wait(remaining); // 3、等待并释放锁 4、被唤醒之后重新竞争获取锁
        remaining = future - System.currentTimeMillis();
    }

    // 以下为业务逻辑
    return result;
    // 5、解锁
}
posted @ 2023-06-19 12:01  lidongdongdong~  阅读(73)  评论(0)    收藏  举报