CountDownLatch(计数器)使用和源码分析

        CountDownLatch类位于java.util.concurrent包下,利用它可以实现类似计数器的功能。比如有一个任务A,它要等待其他4个任务执行完毕之后才能执行,此时就可以利用CountDownLatch来实现这种功能。

         写个测试Demo

public class Test {
    public static void main(String[] args) {
        final CountDownLatch latch = new CountDownLatch(2);
        new Thread(){
            public void run() {
                try {
                    System.out.println("子线程"+Thread.currentThread().getName()+"正在执行");
                    Thread.sleep(3000);
                    System.out.println("子线程"+Thread.currentThread().getName()+"执行完毕");
                    latch.countDown();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            };
        }.start();
        new Thread(){
            public void run() {
                try {
                    System.out.println("子线程"+Thread.currentThread().getName()+"正在执行");
                    Thread.sleep(3000);
                    System.out.println("子线程"+Thread.currentThread().getName()+"执行完毕");
                    latch.countDown();
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            };
        }.start();
        try {
            System.out.println("等待2个子线程执行完毕...");
            latch.await();
            System.out.println("2个子线程已经执行完毕");
            System.out.println("继续执行主线程");
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
}

结果:

 

 

 Demo中的主线程(main方法)为什么会等两个子线程执行countDown方法了才继续执行呢?只能从CountDownLatch的源码中找答案

CountDownLatch countDownLatch = new CountDownLatch(2);    //新建计数器   个数为2个
countDownLatch.countDown();    //个数减1
countDownLatch.await();        //await

 

CountDownLatch的构造方法
    public CountDownLatch(int count) {
        if (count < 0) throw new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }

    Sync类    

    private static final class Sync extends AbstractQueuedSynchronizer {
        private static final long serialVersionUID = 4982264981922014374L;
        //将父类AbstractQueuedSynchronizer中的state的值设置为2
        Sync(int count) {
            setState(count);
        }
        //获得state的值
        int getCount() {
            return getState();
        }
       //如果state为0,返回1,否则返回-1
        protected int tryAcquireShared(int acquires) {
            return (getState() == 0) ? 1 : -1;
        }
        //state减1
        protected boolean tryReleaseShared(int releases) {
            // Decrement count; signal when transition to zero
            for (;;) {
                int c = getState();
                if (c == 0)
                    return false;
                int nextc = c-1;
                //cas操作将state减1
                if (compareAndSetState(c, nextc))
                    return nextc == 0;
            }
        }
    }
countDownLatch.countDown()执行
public void countDown() {
        sync.releaseShared(1);
    }


public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            doReleaseShared();
            return true;
        }
        return false;
}
tryReleaseShared   将state的值减1

   countDownLatch.await() 执行

/**
    部分源码注释
    Causes the current thread to wait until the latch has counted 
    down to  zero, unless the thread is {@linkplain 
    Thread#interrupt interrupted}.
    等待计数器为0,除非线程中断
*/
public void await() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

//AQS类中的方法
public final void acquireSharedInterruptibly(int arg)
            throws InterruptedException {
        if (Thread.interrupted())
            throw new InterruptedException();
    
        //protected int tryAcquireShared(int acquires) {
        //    return (getState() == 0) ? 1 : -1;
        //}   
        //tryAcquireShared就是返回(getState() == 0) ? 1 : -1的值
        //原先为2,刚才减1,所以返回-1
        if (tryAcquireShared(arg) < 0)
            doAcquireSharedInterruptibly(arg);
}
doAcquireSharedInterruptibly 执行
/**
 * Acquires in shared interruptible mode.
 * @param arg the acquire argument
 */
private void doAcquireSharedInterruptibly(int arg)
    throws InterruptedException {
    //添加到aqs队列中
    final Node node = addWaiter(Node.SHARED);
    boolean failed = true;
    try {
        for (;;) {
            final Node p = node.predecessor();
            if (p == head) {
                //再次检验state是否为0,true返回1,false返回-1
                int r = tryAcquireShared(arg);
                //显然目前是r=-1
                if (r >= 0) {
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
            //进入shouldParkAfterFailedAcquire
            if (shouldParkAfterFailedAcquire(p, node) &&
                parkAndCheckInterrupt())
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}
shouldParkAfterFailedAcquire  执行
private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
    int ws = pred.waitStatus;
    //ws目前为0,而Node.SIGNAL为-1
    if (ws == Node.SIGNAL)
        /*
         * This node has already set status asking a release
         * to signal it, so it can safely park.
         */
        return true;
    if (ws > 0) {
        /*
         * Predecessor was cancelled. Skip over predecessors and
         * indicate retry.
         */
        do {
            node.prev = pred = pred.prev;
        } while (pred.waitStatus > 0);
        pred.next = node;
    } else {
        /*
         * waitStatus must be 0 or PROPAGATE.  Indicate that we
         * need a signal, but don't park yet.  Caller will need to
         * retry to make sure it cannot acquire before parking.
         */
        //通过cas把pred的waitStatus设置为-1
        compareAndSetWaitStatus(pred, ws, Node.SIGNAL);
    }
    return false;
}
parkAndCheckInterrupt  执行
    /**
     * Convenience method to park and then check if interrupted
     *
     * @return {@code true} if interrupted
     */
    private final boolean parkAndCheckInterrupt() {
//当前线程挂起 LockSupport.park(
this); return Thread.interrupted(); }

主线程已经挂起,此时如果有其他线程再次执行 countDownLatch.countDown(),将state减为0

 public final boolean releaseShared(int arg) {
        if (tryReleaseShared(arg)) {
            //if成立,进入doReleaseShared
            doReleaseShared();
            return true;
        }
        return false;
    }
doReleaseShared 执行
private void doReleaseShared() {
    //自旋
    for (;;) {
        Node h = head;
        //此时h已经不为null,因为刚才主线程已经加入到AQS队列中
        if (h != null && h != tail) {
            int ws = h.waitStatus;
            if (ws == Node.SIGNAL) {
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                    continue;            
                //重点
                unparkSuccessor(h);
            }
            else if (ws == 0 &&
                     !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                continue;                // loop on failed CAS
        }
        if (h == head)                   // loop if head changed
            break;
    }
}
unparkSuccessor  执行
private void unparkSuccessor(Node node) {
    
    int ws = node.waitStatus;
    if (ws < 0)
        compareAndSetWaitStatus(node, ws, 0);

    Node s = node.next;
    if (s == null || s.waitStatus > 0) {
        s = null;
        for (Node t = tail; t != null && t != node; t = t.prev)
            if (t.waitStatus <= 0)
                s = t;
    }
    //只要s肯定不为null
    if (s != null)
        //唤醒主线程
        LockSupport.unpark(s.thread);
}
 LockSupport.unpark(s.thread)方法唤醒了主线程,主线程继续执行,流程结束。

总结:简单的点说,new CountDownLatch(数量)就是设置AQS中state的数量,线程每次countDown则state减去1,当主线程调用await时候,如果已经为0,则继续执行,如果不为0,这park当前主线程,后面其他线程countDown时候,会判断state是否为0,如果为0,会unpark主线程,即唤醒,主线程继续执行。

CyclicBarrier 回环栅栏(模拟并发场景)
测试代码:
public class Test {
    public static void main(String[] args) {
        int N = 4;
        CyclicBarrier barrier  = new CyclicBarrier(N);

        for(int i=0;i<4;i++) {
            new Writer(barrier).start();
        }
    }
    static class Writer extends Thread{
        private CyclicBarrier cyclicBarrier;
        public Writer(CyclicBarrier cyclicBarrier) {
            this.cyclicBarrier = cyclicBarrier;
        }

        @Override
        public void run() {
            System.out.println("线程"+Thread.currentThread().getName()+"正在写入数据...");
            try {
                Thread.sleep(5000);      //以睡眠来模拟写入数据操作
                System.out.println("线程"+Thread.currentThread().getName()+"写入数据完毕,等待其他线程写入完毕");

                cyclicBarrier.await();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }catch(BrokenBarrierException e){
                e.printStackTrace();
            }
            System.out.println(Thread.currentThread().getName()+"所有线程写入完毕,继续处理其他任务...");
        }
    }
}

CountDownLatch和CyclicBarrier的比较

        1.CountDownLatch是线程组之间的等待,即一个(或多个)线程等待N个线程完成某件事情之后再执行;而CyclicBarrier则是线程组内的等待,即每个线程相互等待,即N个线程都被拦截之后,然后依次执行。

        2.CountDownLatch是减计数方式,而CyclicBarrier是加计数方式。

        3.CountDownLatch计数为0无法重置,而CyclicBarrier计数达到初始值,则可以重置。

        4.CountDownLatch不可以复用,而CyclicBarrier可以复用。


        附录:   相关书籍《Java并发编程的艺术》《Java多线程编程核心技术_完整版》

posted @ 2019-07-24 11:42  Don'tYouSee  阅读(545)  评论(0)    收藏  举报