juc:工具类-CountDownLatch

 

概念

CountDownLatch是一个辅助同步器类,用来作计数使用,

它的作用有点类似于生活中的倒数计数器,先设定一个计数初始值,当计数降到0时,将会触发一些事件,如火箭的倒数计时。

初始计数值在构造CountDownLatch对象时传入,每调用一次 countDown() 方法,计数值就会减1。

线程可以调用CountDownLatch的await方法进入阻塞,当计数值降到0时,所有之前调用await阻塞的线程都会释放。

注意:CountDownLatch的初始计数值一旦降到0,无法重置。如果需要重置,可以考虑使用CyclicBarrier。

用法示例

作为一个开关/入口

将初始计数值为1的 CountDownLatch 作为一个的开关或入口:

在调用 countDown() 的线程打开入口前,所有调用 await 的线程都一直在入口处等待。

public class Driver {
    private static final int N = 10;
    public static void main() throws InterruptedException {
        CountDownLatch switcher = new CountDownLatch(1);
        for (int i = 0; i < N; ++i) {
            new Thread(new Worker(switcher)).start();
        }
        doSomething();
        switcher.countDown();       // 主线程开启开关
    }
 
    public static void doSomething() {
    }
}
 
class Worker implements Runnable {
    private final CountDownLatch startSignal;
    Worker(CountDownLatch startSignal) {
        this.startSignal = startSignal;
    }
 
    public void run() {
        try {
            startSignal.await();    //所有执行线程在此处等待开关开启
            doWork();
        } catch (InterruptedException ex) {
        }
    }
    void doWork() { ...}
}
View Code

作为一个完成信号

将初始计数值为N的 CountDownLatch作为一个完成信号点:使某个线程在其它N个线程完成某项操作之前一直等待。

public class Driver {
    private static final int N = 10;
 
    public static void main() throws InterruptedException {
        CountDownLatch compsignal = new CountDownLatch(N);
        for (int i = 0; i < N; ++i) {
            new Thread(new Worker(compsignal)).start();
        }
 
        compsignal.await();       // 主线程等待其它N个线程完成
        doSomething();
    }
 
    public static void doSomething() {
    }
}
 
class Worker implements Runnable {
    private final CountDownLatch compSignal;
 
    Worker(CountDownLatch compSignal) {
        this.compSignal = compSignal;
    }
 
    public void run() {
        try {
            doWork();
            compSignal.countDown(); //每个线程做完自己的事情后,就将计数器减去1
        } catch (InterruptedException ex) {
        }
    }
 
    void doWork() { ...}
}
View Code

初步使用

package threads.countdownLatch;

import java.util.concurrent.CountDownLatch;

public class MyService {
    private CountDownLatch downLatch=new CountDownLatch(1);

    public void testMethod(){
        try {
            System.out.println("A");
            //此处await方法还可以重载 超时的那种
            downLatch.await();//当运行此方法后代码不再继续往下运行,等待计数为0
            System.out.println("B");
        }catch (InterruptedException ex){
            ex.printStackTrace();
        }
    }

    public void downMethod(){
        System.out.println("X");
        downLatch.countDown();
    }
}
View Code
package threads.countdownLatch;


public class MyThread extends Thread{
    private MyService myService;

    public MyThread(MyService myService){
        super();
        this.myService=myService;
    }

    @Override
    public void run(){
        myService.testMethod();
    }

}
View Code
package threads.countdownLatch;

public class Run {
    public static void main(String[] args) throws InterruptedException{
        MyService service=new MyService();
        MyThread t=new MyThread(service);
        t.start();
        Thread.sleep(2000);
        service.downMethod();
    }
}
View Code

等待所有的线程都到达同步点后才会继续执行

package threads.countdownLatch;


import java.util.concurrent.CountDownLatch;

public class MyThread extends Thread{

    private CountDownLatch maxRunner;
    public MyThread(CountDownLatch maxRunner){
        super();
        this.maxRunner=maxRunner;
    }

    @Override
    public void run(){
        try {
            Thread.sleep(2000);
            System.out.println(Thread.currentThread().getName());
            maxRunner.countDown();
        }catch (InterruptedException ex){
            ex.printStackTrace();
        }
    }

}
View Code
package threads.countdownLatch;

import java.util.concurrent.CountDownLatch;

public class Run {
    public static void main(String[] args) throws InterruptedException{
        CountDownLatch maxRunner=new CountDownLatch(10);
        MyThread[] myThreads=new MyThread[Integer.parseInt(""+maxRunner.getCount())];
        for (int i=0;i<myThreads.length;i++){
            myThreads[i]=new MyThread(maxRunner);
            myThreads[i].setName("线程"+(i+1));
            myThreads[i].start();
        }
        maxRunner.await();//等待所有的线程都到达同步点后才会继续执行
        System.out.println("all in...");
    }
}
View Code

 

join 的使用

 

假如有这样一个需求,当我们需要解析一个Excel里多个sheet的数据时,可以考虑使用多线程,每个线程解析一个sheet里的数据,等到所有的sheet都解析完之后,程序需要统计解析总耗时。分析一下:解析每个sheet耗时可能不一样,总耗时就是最长耗时的那个操作。

 

我们能够想到的最简单的做法是使用join,代码如下:

package com.itsoku.chat13;

import java.util.concurrent.TimeUnit;

/**
 * 微信公众号:javacode2018,获取年薪50万课程
 */
public class Demo1 {

    public static class T extends Thread {
        //休眠时间(秒)
        int sleepSeconds;

        public T(String name, int sleepSeconds) {
            super(name);
            this.sleepSeconds = sleepSeconds;
        }

        @Override
        public void run() {
            Thread ct = Thread.currentThread();
            long startTime = System.currentTimeMillis();
            System.out.println(startTime + "," + ct.getName() + ",开始处理!");
            try {
                //模拟耗时操作,休眠sleepSeconds秒
                TimeUnit.SECONDS.sleep(this.sleepSeconds);
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            long endTime = System.currentTimeMillis();
            System.out.println(endTime + "," + ct.getName() + ",处理完毕,耗时:" + (endTime - startTime));
        }
    }

    public static void main(String[] args) throws InterruptedException {
        long starTime = System.currentTimeMillis();
        T t1 = new T("解析sheet1线程", 2);
        t1.start();

        T t2 = new T("解析sheet2线程", 5);
        t2.start();

        t1.join();
        t2.join();
        long endTime = System.currentTimeMillis();
        System.out.println("总耗时:" + (endTime - starTime));

    }
}
输出:

1563767560271,解析sheet1线程,开始处理!
1563767560272,解析sheet2线程,开始处理!
1563767562273,解析sheet1线程,处理完毕,耗时:2002
1563767565274,解析sheet2线程,处理完毕,耗时:5002
总耗时:5005
View Code

代码中启动了2个解析sheet的线程,第一个耗时2秒,第二个耗时5秒,最终结果中总耗时:5秒。上面的关键技术点是线程的join()方法,此方法会让当前线程等待被调用的线程完成之后才能继续。可以看一下join的源码,内部其实是在synchronized方法中调用了线程的wait方法,最后被调用的线程执行完毕之后,由jvm自动调用其notifyAll()方法,唤醒所有等待中的线程。这个notifyAll()方法是由jvm内部自动调用的,jdk源码中是看不到的,需要看jvm源码,有兴趣的同学可以去查一下。所以JDK不推荐在线程上调用wait、notify、notifyAll方法。

而在JDK1.5之后的并发包中提供的CountDownLatch也可以实现join的这个功能。

CountDownLatch介绍

CountDownLatch称之为闭锁,它可以使一个或一批线程在闭锁上等待,等到其他线程执行完相应操作后,闭锁打开,这些等待的线程才可以继续执行。确切的说,闭锁在内部维护了一个倒计数器。通过该计数器的值来决定闭锁的状态,从而决定是否允许等待的线程继续执行。

常用方法:

public CountDownLatch(int count):构造方法,count表示计数器的值,不能小于0,否者会报异常。

public void await() throws InterruptedException:调用await()会让当前线程等待,直到计数器为0的时候,方法才会返回,此方法会响应线程中断操作。

public boolean await(long timeout, TimeUnit unit) throws InterruptedException:限时等待,在超时之前,计数器变为了0,方法返回true,否者直到超时,返回false,此方法会响应线程中断操作。

public void countDown():让计数器减1

CountDownLatch使用步骤:

创建CountDownLatch对象

调用其实例方法await(),让当前线程等待

调用countDown()方法,让计数器减1

当计数器变为0的时候,await()方法会返回

示例1:一个简单的示例

我们使用CountDownLatch来完成上面示例中使用join实现的功能,代码如下:

package com.itsoku.chat13;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

/**
 * 微信公众号:javacode2018,获取年薪50万课程
 */
public class Demo2 {

    public static class T extends Thread {
        //休眠时间(秒)
        int sleepSeconds;
        CountDownLatch countDownLatch;

        public T(String name, int sleepSeconds, CountDownLatch countDownLatch) {
            super(name);
            this.sleepSeconds = sleepSeconds;
            this.countDownLatch = countDownLatch;
        }

        @Override
        public void run() {
            Thread ct = Thread.currentThread();
            long startTime = System.currentTimeMillis();
            System.out.println(startTime + "," + ct.getName() + ",开始处理!");
            try {
                //模拟耗时操作,休眠sleepSeconds秒
                TimeUnit.SECONDS.sleep(this.sleepSeconds);
            } catch (InterruptedException e) {
                e.printStackTrace();
            } finally {
                countDownLatch.countDown();
            }
            long endTime = System.currentTimeMillis();
            System.out.println(endTime + "," + ct.getName() + ",处理完毕,耗时:" + (endTime - startTime));
        }
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(System.currentTimeMillis() + "," + Thread.currentThread().getName() + "线程 start!");
        CountDownLatch countDownLatch = new CountDownLatch(2);

        long starTime = System.currentTimeMillis();
        T t1 = new T("解析sheet1线程", 2, countDownLatch);
        t1.start();

        T t2 = new T("解析sheet2线程", 5, countDownLatch);
        t2.start();

        countDownLatch.await();
        System.out.println(System.currentTimeMillis() + "," + Thread.currentThread().getName() + "线程 end!");
        long endTime = System.currentTimeMillis();
        System.out.println("总耗时:" + (endTime - starTime));

    }
}
输出:

1563767580511,main线程 start!
1563767580513,解析sheet1线程,开始处理!
1563767580513,解析sheet2线程,开始处理!
1563767582515,解析sheet1线程,处理完毕,耗时:2002
1563767585515,解析sheet2线程,处理完毕,耗时:5002
1563767585515,main线程 end!
总耗时:5003
View Code

从结果中看出,效果和join实现的效果一样,代码中创建了计数器为2的CountDownLatch,主线程中调用countDownLatch.await();会让主线程等待,t1、t2线程中模拟执行耗时操作,最终在finally中调用了countDownLatch.countDown();,此方法每调用一次,CountDownLatch内部计数器会减1,当计数器变为0的时候,主线程中的await()会返回,然后继续执行。注意:上面的countDown()这个是必须要执行的方法,所以放在finally中执行。

示例2:等待指定的时间

还是上面的示例,2个线程解析2个sheet,主线程等待2个sheet解析完成。主线程说,我等待2秒,你们还是无法处理完成,就不等待了,直接返回。如下代码:

package com.itsoku.chat13;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

/**
 * 微信公众号:javacode2018,获取年薪50万课程
 */
public class Demo3 {

    public static class T extends Thread {
        //休眠时间(秒)
        int sleepSeconds;
        CountDownLatch countDownLatch;

        public T(String name, int sleepSeconds, CountDownLatch countDownLatch) {
            super(name);
            this.sleepSeconds = sleepSeconds;
            this.countDownLatch = countDownLatch;
        }

        @Override
        public void run() {
            Thread ct = Thread.currentThread();
            long startTime = System.currentTimeMillis();
            System.out.println(startTime + "," + ct.getName() + ",开始处理!");
            try {
                //模拟耗时操作,休眠sleepSeconds秒
                TimeUnit.SECONDS.sleep(this.sleepSeconds);
            } catch (InterruptedException e) {
                e.printStackTrace();
            } finally {
                countDownLatch.countDown();
            }
            long endTime = System.currentTimeMillis();
            System.out.println(endTime + "," + ct.getName() + ",处理完毕,耗时:" + (endTime - startTime));
        }
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(System.currentTimeMillis() + "," + Thread.currentThread().getName() + "线程 start!");
        CountDownLatch countDownLatch = new CountDownLatch(2);

        long starTime = System.currentTimeMillis();
        T t1 = new T("解析sheet1线程", 2, countDownLatch);
        t1.start();

        T t2 = new T("解析sheet2线程", 5, countDownLatch);
        t2.start();

        boolean result = countDownLatch.await(2, TimeUnit.SECONDS);

        System.out.println(System.currentTimeMillis() + "," + Thread.currentThread().getName() + "线程 end!");
        long endTime = System.currentTimeMillis();
        System.out.println("主线程耗时:" + (endTime - starTime) + ",result:" + result);

    }
}
输出:

1563767637316,main线程 start!
1563767637320,解析sheet1线程,开始处理!
1563767637320,解析sheet2线程,开始处理!
1563767639321,解析sheet1线程,处理完毕,耗时:2001
1563767639322,main线程 end!
主线程耗时:2004,result:false
1563767642322,解析sheet2线程,处理完毕,耗时:5002
View Code

从输出结果中可以看出,线程2耗时了5秒,主线程耗时了2秒,主线程中调用countDownLatch.await(2, TimeUnit.SECONDS);,表示最多等2秒,不管计数器是否为0,await方法都会返回,若等待时间内,计数器变为0了,立即返回true,否则超时后返回false。

示例3:2个CountDown结合使用的示例

有3个人参见跑步比赛,需要先等指令员发指令枪后才能开跑,所有人都跑完之后,指令员喊一声,大家跑完了。

示例代码:

package com.itsoku.chat13;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

/**
 * 微信公众号:javacode2018,获取年薪50万课程
 */
public class Demo4 {

    public static class T extends Thread {
        //跑步耗时(秒)
        int runCostSeconds;
        CountDownLatch commanderCd;
        CountDownLatch countDown;

        public T(String name, int runCostSeconds, CountDownLatch commanderCd, CountDownLatch countDown) {
            super(name);
            this.runCostSeconds = runCostSeconds;
            this.commanderCd = commanderCd;
            this.countDown = countDown;
        }

        @Override
        public void run() {
            //等待指令员枪响
            try {
                commanderCd.await();
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
            Thread ct = Thread.currentThread();
            long startTime = System.currentTimeMillis();
            System.out.println(startTime + "," + ct.getName() + ",开始跑!");
            try {
                //模拟耗时操作,休眠runCostSeconds秒
                TimeUnit.SECONDS.sleep(this.runCostSeconds);
            } catch (InterruptedException e) {
                e.printStackTrace();
            } finally {
                countDown.countDown();
            }
            long endTime = System.currentTimeMillis();
            System.out.println(endTime + "," + ct.getName() + ",跑步结束,耗时:" + (endTime - startTime));
        }
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(System.currentTimeMillis() + "," + Thread.currentThread().getName() + "线程 start!");
        CountDownLatch commanderCd = new CountDownLatch(1);
        CountDownLatch countDownLatch = new CountDownLatch(3);

        long starTime = System.currentTimeMillis();
        T t1 = new T("小张", 2, commanderCd, countDownLatch);
        t1.start();

        T t2 = new T("小李", 5, commanderCd, countDownLatch);
        t2.start();

        T t3 = new T("路人甲", 10, commanderCd, countDownLatch);
        t3.start();

        //主线程休眠5秒,模拟指令员准备发枪耗时操作
        TimeUnit.SECONDS.sleep(5);
        System.out.println(System.currentTimeMillis() + ",枪响了,大家开始跑");
        commanderCd.countDown();

        countDownLatch.await();
        long endTime = System.currentTimeMillis();
        System.out.println(System.currentTimeMillis() + "," + Thread.currentThread().getName() + "所有人跑完了,主线程耗时:" + (endTime - starTime));

    }
}
输出:

1563767691087,main线程 start!
1563767696092,枪响了,大家开始跑
1563767696092,小张,开始跑!
1563767696092,小李,开始跑!
1563767696092,路人甲,开始跑!
1563767698093,小张,跑步结束,耗时:2001
1563767701093,小李,跑步结束,耗时:5001
1563767706093,路人甲,跑步结束,耗时:10001
1563767706093,main所有人跑完了,主线程耗时:15004
View Code

代码中,t1、t2、t3启动之后,都阻塞在commanderCd.await();,主线程模拟发枪准备操作耗时5秒,然后调用commanderCd.countDown();模拟发枪操作,此方法被调用以后,阻塞在commanderCd.await();的3个线程会向下执行。主线程调用countDownLatch.await();之后进行等待,每个人跑完之后,调用countDown.countDown();通知一下countDownLatch让计数器减1,最后3个人都跑完了,主线程从countDownLatch.await();返回继续向下执行。

手写一个并行处理任务的工具类

package com.itsoku.chat13;

import org.springframework.util.CollectionUtils;

import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * 微信公众号:javacode2018,获取年薪50万课程
 */
public class TaskDisposeUtils {
    //并行线程数
    public static final int POOL_SIZE;

    static {
        POOL_SIZE = Integer.max(Runtime.getRuntime().availableProcessors(), 5);
    }

    /**
     * 并行处理,并等待结束
     *
     * @param taskList 任务列表
     * @param consumer 消费者
     * @param <T>
     * @throws InterruptedException
     */
    public static <T> void dispose(List<T> taskList, Consumer<T> consumer) throws InterruptedException {
        dispose(true, POOL_SIZE, taskList, consumer);
    }

    /**
     * 并行处理,并等待结束
     *
     * @param moreThread 是否多线程执行
     * @param poolSize   线程池大小
     * @param taskList   任务列表
     * @param consumer   消费者
     * @param <T>
     * @throws InterruptedException
     */
    public static <T> void dispose(boolean moreThread, int poolSize, List<T> taskList, Consumer<T> consumer) throws InterruptedException {
        if (CollectionUtils.isEmpty(taskList)) {
            return;
        }
        if (moreThread && poolSize > 1) {
            poolSize = Math.min(poolSize, taskList.size());
            ExecutorService executorService = null;
            try {
                executorService = Executors.newFixedThreadPool(poolSize);
                CountDownLatch countDownLatch = new CountDownLatch(taskList.size());
                for (T item : taskList) {
                    executorService.execute(() -> {
                        try {
                            consumer.accept(item);
                        } finally {
                            countDownLatch.countDown();
                        }
                    });
                }
                countDownLatch.await();
            } finally {
                if (executorService != null) {
                    executorService.shutdown();
                }
            }
        } else {
            for (T item : taskList) {
                consumer.accept(item);
            }
        }
    }

    public static void main(String[] args) throws InterruptedException {
        //生成1-10的10个数字,放在list中,相当于10个任务
        List<Integer> list = Stream.iterate(1, a -> a + 1).limit(10).collect(Collectors.toList());
        //启动多线程处理list中的数据,每个任务休眠时间为list中的数值
        TaskDisposeUtils.dispose(list, item -> {
            try {
                long startTime = System.currentTimeMillis();
                TimeUnit.SECONDS.sleep(item);
                long endTime = System.currentTimeMillis();

                System.out.println(System.currentTimeMillis() + ",任务" + item + "执行完毕,耗时:" + (endTime - startTime));
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        });
        //上面所有任务处理完毕完毕之后,程序才能继续
        System.out.println(list + "中的任务都处理完毕!");
    }
}
运行代码输出:

1563769828130,任务1执行完毕,耗时:1000
1563769829130,任务2执行完毕,耗时:2000
1563769830131,任务3执行完毕,耗时:3001
1563769831131,任务4执行完毕,耗时:4001
1563769832131,任务5执行完毕,耗时:5001
1563769833130,任务6执行完毕,耗时:6000
1563769834131,任务7执行完毕,耗时:7001
1563769835131,任务8执行完毕,耗时:8001
1563769837131,任务9执行完毕,耗时:9001
1563769839131,任务10执行完毕,耗时:10001
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]中的任务都处理完毕!
View Code

TaskDisposeUtils是一个并行处理的工具类,可以传入n个任务内部使用线程池进行处理,等待所有任务都处理完成之后,方法才会返回。比如我们发送短信,系统中有1万条短信,我们使用上面的工具,每次取100条并行发送,待100个都处理完毕之后,再取一批按照同样的逻辑发送。

 

 

参考文章

死磕Java并发】—–J.U.C之并发工具类:CountDownLatch

Java多线程进阶(十八)—— J.U.C之synchronizer框架:CountDownLatch

Java多线程同步工具箱之CountDownLatch篇

 

posted @ 2019-12-15 15:57  弱水三千12138  阅读(195)  评论(0)    收藏  举报