xiaobenchi

导航

Java并发包中线程同步器原理剖析

Java并发包中线程同步器原理剖析

1. CountDownLatch原理剖析

  • 案例介绍

    日常开发中经常会遇到需要在主线程中开启多个线程去并行执行任务,并且主线程需要等待所有子线程执行完毕再进行汇总的场景。在CountDownLatch出现之前一般都使用线程的join()方法来实现这一点。

    使用CountDownLatch的代码如下

    public class JoinCountDownLatch{
        
        //创建一个CountDownLatch实例
        private static volatile CountDownLatch countDownLatch = new CountDownLatch(2);
        
        public static void main(String[] args) throws InterruptedException{
            
            Thread threadOne = new Thread(new Runnable() {
                
                @override
                public void run(){
                    
                    try{
                        Thread.sleep(1000);
                    }catch(InterruptedException e){
                        e.printStackTrace();
                    }finally{
                        countDownLatch.countDown();
                    }
                    System.out.println("child threadOne over!");
                }
            });
            
            Thread threadTwo = new Thread(new Runnable() {
                
                @override
                public void run(){
                    
                    try{
                        Thread.sleep(1000);
                    }catch(InterruptedException e){
                        e.printStackTrace();
                    }finally{
                        countDownLatch.countDown();
                    }
                    System.out.println("child threadTwo over!");
                }
            });
            
            //启动子线程
            threadOne.start();
            threadTwo.start();
            
            System.out.println("Wait all child thread over!");
            
            //等待子线程执行完毕,返回
            countDownLatch.await();
            
            System.out.println("All child thread over!");
        }
    }
    

    上面的代码还不够优雅,在项目实践中一般都避免直接操作线程,而是使用 ExecutorService线程池来管理。

    public class JoinCountDownLatch{
        
        //创建一个CountDownLatch实例
        private static volatile CountDownLatch countDownLatch = new CountDownLatch(2);
        
        public static void main(String[] args) throws InterruptedException{
            ExecutorService executorService = Executors.newFixedThreadPool(2);
            //将线程A添加到线程池
            executorService.submit(new Runnable() {
               
                public void run(){
                    
                    try{
                        Thread.sleep(1000);
                    }catch(InterruptedException e){
                        e.printStackTrace();
                    }finally{
                        countDownLatch.countDown();
                    }
                    System.out.println("child threadOne over!");
                }
            });
            
            //将线程A添加到线程池
            executorService.submit(new Runnable() {
               
                public void run(){
                    
                    try{
                        Thread.sleep(1000);
                    }catch(InterruptedException e){
                        e.printStackTrace();
                    }finally{
                        countDownLatch.countDown();
                    }
                    System.out.println("child threadTwo over!");
                }
            });
            
              System.out.println("Wait all child thread over!");
            
            //等待子线程执行完毕,返回
            countDownLatch.await();
            
            System.out.println("All child thread over!");
        }
    }
    
  • 实现原理探究

    从CountDownLatch的名字就可以猜测其内部应该有个计数器,并且这个计数器是递减的。

    • 类图

    从类图可以看出,CountDownLatch是使用AQS实现的。

    通过下面的构造函数,你会发现,实际上是把计数器的值赋给了AQS的状态变量state,也就是这里使用AQS的状态值来表示计数器值

    public CountDownLatch(int count){
        if(count < 0) throw  new IllegalArgumentException("count < 0");
        this.sync = new Sync(count);
    }
    
    Sync(int count){
        setState(count);
    }
    

    下面我们来研究CountDownLatch中的几个重要的方法,看它们是如何调用AQS来实现功能的。

    1. void await( )方法

      当线程调用CountDownLatch对象的await方法后,当前线程会被阻塞,直到下面情况之一发生才会返回:当所有线程都调用了CountDownLatch对象的countDown方法后,也就是计数器的值为0时;其他线程调用了当前线程的interrupt( )方法中断了当前线程,当前线程就会抛出InterruptedException异常,然后返回。

      //CountDownLatch的await()方法
      public void await() throws InterruptedException{
          sync.acquireSharedInterruptibly(1);
      }
      

      await()方法委托sync调用了AQS的acquireSharedInterruptibly方法,后者的代码如下:

      //AQS获取共享资源时可被中断的方法
      public final void acquireSharedInterruptibly(int arg)
          throws InterruptedException{
          //如果线程被中断则抛出异常
          if(Thread.interrupted())
              throw new InterruptedException();
          //查看当前计数器值是否为0,为0则直接返回,否则进入AQS的队列等待
          if(tryAcquireShared(arg) < 0)
              doAcquireSharedInterruptibly(arg);
      }
      
      //sync类实现的AQS接口
      proteceted int tryAcquireShared(int acquires){
          return (getState() == 0) ? 1 : -1;
      }
      

      该方法的特点是线程获取资源时可以被中断,并且获取的资源是共享资源。acquireSharedInterruptibly首先判断当前线程是否已被中断,若是则抛出异常,否则调用sync实现的tryAcquireShared方法查看当前状态值(计数器值)是否为0,是则当前线程的await()方法直接返回,否则调用AQS的doAcquireSharedInterruptibly方法让当前线程阻塞。

    2. boolean wait(long timeout, TimeUnit unit)方法

      当线程调用了CountDownLatch对象的该方法后,当前线程会被阻塞,直到下面的情况之一发生才会返回:当所有线程都调用了CountDownLatch对象的countDown方法后,也就是计数器值为0时,这时候会放回true;设置的timeout时间到了,因为超时而返回false;其他线程调用了当前线程的interrupt( )方法中断了当前线程,当前线程会抛出InterruptedException异常,然后返回。

    3. void countDown( ) 方法

      线程调用该方法后,计数器的值递减,递减后如果计数器值为0则唤醒所有因调用await方法而被阻塞的线程,否则什么都不做。

      public void countDown(){
          //委托sync调用AQS的方法
          sync.releaseShared(1);
      }
      
      //AQS的方法
      public final boolean releaseShared(int arg){
          //调用sync实现的tryReleaseShared
          if(tryReleaseShared(arg)){
              //AQS的释放资源方法
              doReleaseShared();
              return true;
          }
          return false;
      }
      

      releaseShared首先调用了sync实现的AQS的tryReleaseShared方法

      //sync的方法
      protected boolean tryReleaseShared(int releases){
          //循环进行CAS,直到当前线程成功完成CAS使计数器值(状态值state)减1并更新到state
          for(;;){
              int c = getState();
              
              //如果当前状态值为0则直接返回(1)
              if(c == 0)
                  return false;
              
              //使用CAS让计数器值减1(2)
              int nextc = c - 1;
              if(compareAndSetState(c,nextc))
                  return nextc == 0;
          }
      }
      
    4. long getCount()方法

      获取当前计数器的值,也就是AQS的state的值,一般在测试时使用该方法。

      public long getCount(){
          return sync.getCount();
      }
      
      int getCount(){
          return getState();
      }
      

      其内部还是调用了AQS的getState方法来获取state的值。

2. 回环屏障CyclicBarrier原理探究

上节介绍的CountDownLatch在解决多个线程同步方面相对于调用线程的join方法已经有了不少优化,但是CountDownLatch的计数器是一次性的,也就是等到计数器值变为0后,再调用CountDownLatch的await和countdown方法都会立刻返回,这就起不到线程同步的效果了。所以为了满足计数器可以重置的需要,JDK开发组提供了CyclicBarrier类,并且CyclicBarrier类的功能并不限于CountDownLatch的功能。从字面意思理解,CyclicBarrier是回环屏障的意思,它可以让一组线程全部达到一个状态后再全部同时执行。这里之所以叫作回环是因为当所有等待线程执行完毕,并重置CyclicBarrier的状态后它可以被重用。之所以叫作屏障是因为线程调用await方法后就会被阻塞,这个阻塞点就称为屏障点,等所有线程都调用了await方法后,线程们就会冲破屏障,继续向下运行。

  • 案例介绍

    我们要实现,使用两个线程去执行一个被分解的任务A,当两个线程把自己的任务都指向完毕后再对它们的结果进行汇总处理。

    public class CycleBarrierTest1{
        
        //创建一个CyclicBarrier实例,添加一个所有子线程全部达到屏障后执行的任务
        private static CyclicBarrier cyclicBarrier = new CyclicBarrier(2,new Runnable(){
            public void run(){
                System.out.println(Thread.currentThread() + "task1 merge result");
            }
        });
        
        public static void main(String[] args) throws InterruptedException{
            //创建一个线程个数固定为2的线程池
            ExecutorService executorService = Executors.newFixedThreadPool(2);
            //将线程A添加到线程池
            executorService.submit(new Runnable(){
                public void run(){
                    try{
                        System.out.println(Thread.currentThread() + "task1-1");
                        System.out.println(Thread.currentThread() + "enter int barrier");
                        cyclicBarrier.await();
                        System.out.println(Thread.currentThread() + "enter out barrier");
                    }catch(Exception e){
                        e.printStackTrace();
                    }
                }
            });
            
            //将线程B添加到线程池
            executorService.submit(new Runnable(){
                public void run(){
                    try{
                        System.out.println(Thread.currentThread() + "task1-2");
                        System.out.println(Thread.currentThread() + "enter int barrier");
                        cyclicBarrier.await();
                        System.out.println(Thread.currentThread() + "enter out barrier");
                    }catch(Exception e){
                        e.printStackTrace();
                    }
                }
            });
            
            //关闭线程池
            executorService.shutDown();
        }
    }
    

    下面再举个例子来说明CyclicBarrier的可复用性。

    假设一个任务由阶段1、阶段2和阶段3组成,每个线程要串行地执行阶段1、阶段2和阶段3,当多个线程执行该任务时,必须要保证所有线程的阶段1全部完成后才能进入阶段2执行,当所有线程的阶段2全部完成后才能进入阶段3执行。下面使用CyclicBarrier来完成这个需求。

    public class CycleBarrierTest2{
        //创建一个CyclicBarrier实例
        private static CyclicBarrier cyclicBarrier  = new CyclicBarrier(2);
        
        public static void main(String[] args) throws InterruptedException{
            ExecutorService executorService = Executors.newFixedThreadPool(2);
            
            //将线程A添加到线程池
            executorService.submit(new Runnable() {
                public void run() {
                    try{
                        System.out.println(Thread.currentThread() + "step1");
                        cyclicBarrier.await();
                        
                        System.out.println(Thread.currentThread() + "step2");
                        cyclicBarrier.await();
                        
                        System.out.println(Thread.currentThread() + "step3");
                        cyclicBarrier.await();
                        
                    }catch(Exception e){
                        e.printStackTrack();
                    }
                }
            });
            
            //将线程B添加到线程池
            executorService.submit(new Runnable() {
                public void run() {
                    try{
                        System.out.println(Thread.currentThread() + "step1");
                        cyclicBarrier.await();
                        
                        System.out.println(Thread.currentThread() + "step2");
                        cyclicBarrier.await();
                        
                        System.out.println(Thread.currentThread() + "step3");
                        cyclicBarrier.await();
                        
                    }catch(Exception e){
                        e.printStackTrack();
                    }
                }
            });
            
            //关闭线程池
            executorService.shutdown();
            
        }
    }
    
  • 实现原理探究

    • 类图结构

CyclicBarrier基于独占锁实现,本质底层还是基于AQS的。parties用来记录线程个数,这里表示多少线程调用await后,所有线程才会冲破屏障继续往下运行。而count一开始等于parties,每当有线程调用await方法就递减1,当count为0时就表示所有线程都到了屏障点。
  • CyclicBarrier中的几个重要方法。

    1. int await()方法

      当前线程调用CyclicBarrier的该方法时会被阻塞,直到满足下面条件之一才会返回“parties个线程都调用了await()方法,也就是线程都到达了屏障点;

      public int await() throws InterruptedException, BrokenBarrierException{
          try{
              return dowait(false,0L);
          }catch(TimeoutException toe){
              throw new Error(toe);
          }
      }
      
    2. boolean await(long timeout, TimeUnit unit)方法

      当前线程调用CyclicBarrier的该方法时会被阻塞,直到满足下面条件之一才会返回:parties个线程都调用了await()方法,也就是线程都到了屏障点,这时候返回true;设置的超时时间到了后返回false;其他线程调用当前线程的interrupt()方法中断了当前线程,则当前线程会抛出InterruptedException异常然后返回;与当前屏障点关联的Generation对象的broken标志被设置为true时,会抛出BrokenBarrierException异常,然后返回。

      public int await(long timeout,TimeUnit unit) throws InterruptedException, BrokenBarrierException, TimeoutException{
          return dowait(true,unit.toNanos(timeout));
      }
      
    3. int doawait(boolean timed, long nanos)方法

      该方法实现了CyclicBarrier的核心功能

      private int doawait(boolean timed,long nanaos) InterruptedException, BrokenBarrierException, TimeoutException{
          final ReentrantLock lock = this.lock;
          lock.lock();
          try{
              ...
             //(1)如果index==0则说明所有线程都到了屏障点,此时执行初始化时传递的任务
             int intdex = -- count;
              if(index == 0){
                  boolean ranAction = false;
                  try{
                      //(2)执行任务
                      if(command != null)
                          command.run();
                      ranAction = true;
                      //(3)激活其他因调用await方法而被阻塞的线程,并重置CyclicBarrier
                      nextGeneration();
                      //返回
                      return 0;
                  }finally{
                      if(!ranAction)
                          breakBarrier();
                  }
              }
              
              //(4)如果index != 0
              for(;;){
                  try{
                      //(5)没有设置超时时间。
                      if(!timed){
                          trip.await();
                      }//(6)设置了超时时间
                      else if(nanos > 0L)
                          nanos = trip.awaitNanos(nanos);
                  }catch(InterruptedException ie){
                      //...
                  }
                     //...
              }
          }finally{
              lock.unlock();
          }
      }
      
      private void nextGeneration(){
          //(7) 唤醒条件队列里面阻塞线程
          trip.signalAll();
          //(8)重置CyclicBarrier
          count = parties;
          generation = new Generation();
      }
      
  • 小结

    本节首先通过案例说明了CycleBarrier与CountDownLatch的不同在于,前者是可以复用的,并且前者特别适合分段任务有序执行的场景。然后分析了CycleBarrier,其通过独占锁ReentrantLock实现计数器原子性更新,并使用条件变量队列来实现线程同步。

3. 信号量Semaphore原理探究

Semaphore信号量也是Java中的一个同步器,与CountDownLatch和CycleBarrier不同的是,它内部的计数器是递增的,并且在一开始初始化Semaphore时可以指定一个初始值,但是并不需要知道需要同步的线程个数,而是在需要同步的地方调用acquire方法时指定需要同步的线程个数。

  • 案例介绍

    同样下面的例子也是在主线程中开启两个子线程让它们执行,等所有子线程执行完毕后主线程再继续向下运行。

    public class SemaphoreTest{
        //创建一个Semaphore实例
        private static Semaphore = new Semaphore(0);
        
        public static void main(String[] args) throws InterruptedException{
            
            ExecutorServie executorService = Executors.newFixedThreadPool(2);
            
            //将线程A添加到线程池
            executorService.submit(new Runnable(){
                public void run(){
                    try{
                        System.out.println(Thread.currentThread() + "over");
                        semaphore.release();
                    }catch(Exception e){
                        e.printStackTrack();
                    }
                }
            });
            
            //将线程B添加到线程池
            executorService.submit(new Runnable(){
                public void run(){
                    try{
                        System.out.println(Thread.currentThread() + "over");
                        semaphore.release();
                    }catch(Exception e){
                        e.printStackTrack();
                    }
                }
            });
            
            //等待子线程执行完毕,返回
            semaphore.acquire(2);
            System.out.println("all child thread over!");
            
            //关闭线程池
            executorService.shutdown();
        }
        
    }
    
  • 实现原理探究

    • 类图

    由该类图可知,Semaphore还是使用AQS实现的。Sync只是对AQS的一个修饰,并且Sync有两个实现类,用来指定获取信号量时是否采用公平策略。

    public Semaphore(int permits){
        sync = new NonfairSync(permits);
    }
    
    public Semaphore(int permits, boolean fair){
        sync = fair ? new FairSync(permits) : new NonfairSync(permits);
    }
    
    Sync(int permits){
        setState(permits);
    }
    
    • Semaphore实现的主要方法

      1. void acquire( )方法

        当前线程调用该方法的目的是希望获取一个信号量资源。如果当前信号量个数大于0,则当前信号量的计数会减1,然后该方法直接返回。否则如果当前信号量个数等于0,则当前线程会被放入AQS的阻塞队列。当其他线程调用了当前线程的interrupt()方法中断了当前线程时,则当前线程会抛出InterruptedException异常返回。

        public void acquire() throws InterruptedException{
        	//传递参数为1,说明要获取一个信号量资源
            sync.acquireSharedInterruptibly(1);
        }
        
        public final void acquireSharedInterruptibly(int arg) throws InterruptedException{
            //(1)如果线程被中断,则抛出中断异常
            if(Thread.interrupted())
                throw new InterruptedException();
            
            //(2)否则调用Sync子类方法尝试获取,这里根据构造函数确定使用公平策略
            if(tryAcquireShared(arg) < 0)
                
                //如果获取失败则放入阻塞队列。然后再次尝试,如果失败则调用park方法挂起当前线程
                doAcquireShareInterruption(arg);
        }
        
      2. void acquire (int permits)方法

        该方法与acquire()方法不同,后者只需要获取一个信号量值,而前者则获取permits个。

      3. void acquireUninterruptibly()方法

        该方法与acquire()类似,不同之处在于该方法对中断不响应,也就是当当前线程调用了acquireUninterruptibly获取资源时(包含被阻塞后),其他线程调用了当前线程的interrupt()方法设置了当前线程的中断标志,此时当前线程并不会抛出InterruptedException异常而返回。

      4. void acquireUninterruptibly(int permits)方法

      5. void release( )方法

        该方法的作用是把当前Semaphore对象的信号量值增加1,如果当前有线程因为调用aquire方法被阻塞而被放入了AQS的阻塞队列,则会根据公平策略选择一个信号量个数能被满足的线程进行激活,激活的线程会尝试获取刚增加的信号量。

      6. void release(int permits)方法

        该方法与不带参数的release方法的不同之处在于,前者每次调用会在信号量值原来的基础上增加permits,而后者每次增加1。

4. 总结

本章介绍了并发包中关于线程协作的一些重要类。首先CountDownLatch通过计数器提供了更灵活的控制,只要检测到计数器值为0,就可以往下执行,这相比使用join必须等待线程执行完毕后主线程才会继续向下运行更灵活。另外,CyclicBarrier也可以达到CountDownLatch的效果,但是后者在计数器值变为0后,就不能再被复用,而前者则可以使用reset方法重置后复用,前者对同一个算法但是输入参数不同的类似场景比较适用。而Semaphore采用了信号量递增的策略,一开始并不需要关心同步的线程个数,等调用aquire方法时再指定需要同步的个数,并且提供了获取信号量的公平性策略。使用本章介绍的类会大大减少你在Java中使用wait、notify等来实现线程同步的代码量,在日常开发中当需要进行线程同步时使用这些同步类会节省很多代码并且可以保证正确性。

posted on 2022-08-24 18:44  小迟在努力  阅读(42)  评论(0)    收藏  举报