工具类-信号量 Semaphore

Semaphore 维护了一组 许可(permits),线程访问共享资源时需要先获取许可,类似 ReentrantLock 获取方式也分为公平和非公平两种方式

  • 当线程访问资源时,必须获取(acquire)一个许可
  • 使用完资源后,必须释放(release)许可
  • 如果没有可用许可,线程将被阻塞,直到其他线程释放许可

Semaphore 原理
1)初始时设置 AQS 的 state 赋值,这个值就是许可数量
2)每当有线程获取一个许可 state 就 -1
3)线程获取许可后如果 state 变量值小于 0 就意味着获取许可失败,线程就入队并且阻塞

如果熟悉 AQS 理解 Semaphore 就会非常简单,比 ReentrantLock、ReentrantReadWriteLock、StempLock 简单的多

类结构

类成员

public class Semaphore implements java.io.Serializable {
  
  	// 成员变量就一个同步器
    private final Sync sync;
  
  	// 构造方法,默认非公平方式
    public Semaphore(int permits) {
        sync = new NonfairSync(permits);
    }
  	
  	// 构造方法,可以指定公平还是非公平
    public Semaphore(int permits, boolean fair) {
        sync = fair ? new FairSync(permits) : new NonfairSync(permits);
    }
  
  	// 获取一个许可,获取不成功线程入队阻塞
    public void acquire() throws InterruptedException {
        sync.acquireSharedInterruptibly(1);
    }
  
  	// 获取一个许可,获取不成功线程不会入队也不会阻塞
  	public boolean tryAcquire() {
        return sync.nonfairTryAcquireShared(1) >= 0;
    }
  
  	// tryAcquire() 带一个超时时间版本
  	public boolean tryAcquire(long timeout, TimeUnit unit)
        throws InterruptedException {
        return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
    }
  
  	// 获取多个许可,获取不成功线程入队阻塞
   	public void acquire(int permits) throws InterruptedException {
        if (permits < 0) throw new IllegalArgumentException();
        sync.acquireSharedInterruptibly(permits);
    }
  
  	// 释放一个许可
    public void release() {
        sync.releaseShared(1);
    }
  
  	// 释放多个许可
  	public void release(int permits) {
        if (permits < 0) throw new IllegalArgumentException();
        sync.releaseShared(permits);
    }
  
  	// 许可池目前剩余许可数量
    public int availablePermits() {
        return sync.getPermits();
    }
    
}

内部类

// 继承自 AQS
abstract static class Sync extends AbstractQueuedSynchronizer {

  	// 构造方法
    Sync(int permits) { setState(permits); }

  	// 获取剩余许可数量
    final int getPermits() { return getState(); }

  	// 非公平方式获取许可
    final int nonfairTryAcquireShared(int acquires) { }

  	// 释放许可
    protected final boolean tryReleaseShared(int releases) { }

  	// 调整许可池许可总数量
    final void reducePermits(int reductions) {}

}

// 非公平模式
static final class NonfairSync extends Sync {

  	// 构造方法
    NonfairSync(int permits) { super(permits); }

  	// 释放许可
    protected int tryAcquireShared(int acquires) { return nonfairTryAcquireShared(acquires); }
}

// 公平模式
static final class FairSync extends Sync {

  	// 构造方法
    FairSync(int permits) { super(permits); }

  	// 释放许可
    protected int tryAcquireShared(int acquires) { }
}

许可池初始化

并不复杂,就是初始化 AQS 的 state 的值,state 就是许可数量

// java.util.concurrent.Semaphore#Semaphore(int)
public Semaphore(int permits) {
    sync = new NonfairSync(permits);
}
    
// java.util.concurrent.Semaphore.NonfairSync#NonfairSync
NonfairSync(int permits) {
    super(permits);
}

// java.util.concurrent.Semaphore.Sync#Sync
Sync(int permits) {
    setState(permits);
}

// java.util.concurrent.locks.AbstractQueuedSynchronizer#setState
protected final void setState(int newState) {
    state = newState;
}

获取许可

// java.util.concurrent.Semaphore#acquire()
public void acquire() throws InterruptedException {
    sync.acquireSharedInterruptibly(1);
}

// java.util.concurrent.locks.AbstractQueuedSynchronizer#acquireSharedInterruptibly
public final void acquireSharedInterruptibly(int arg) throws InterruptedException {
    if (Thread.interrupted()) throw new InterruptedException();
  
  	// 获取许可,如果返回值小于 0 说明获取许可失败,执行 doAcquireSharedInterruptibly()
    if (tryAcquireShared(arg) < 0) 
        doAcquireSharedInterruptibly(arg);
}

怎么获取许可的,以非公平为例,源码如下:

// java.util.concurrent.Semaphore.NonfairSync#tryAcquireShared
protected int tryAcquireShared(int acquires) {
    return nonfairTryAcquireShared(acquires);
}

// java.util.concurrent.Semaphore.Sync#nonfairTryAcquireShared
final int nonfairTryAcquireShared(int acquires) { // 返回值要么是负数(许可不足),要么是 >= 0 的数(许可剩余可用数量)
    for (;;) {
        int available = getState(); // 获取当前许可数量
        int remaining = available - acquires; // 剩余可用的许可数量
      
      	/**
      	 * remaining < 0:许可不足(|| 是短路运算不会修改 state 的值)直接返回这个负数
      	 * compareAndSetState():当 remaining >= 0 时才执行,表示许可数量是够的,这时 CAS 修改 state 的值,返回剩余可用的许可数量
      	 */
        if (remaining < 0 || compareAndSetState(available, remaining))
            return remaining; // 返回
    }
}

许可不足时怎么处理,这是共性操作,所以 AQS 中实现已经实现好了,熟悉 AQS 的话这里不陌生,我现在大致嫖一眼就知道要做什么了~~

// java.util.concurrent.locks.AbstractQueuedSynchronizer#doAcquireSharedInterruptibly
private void doAcquireSharedInterruptibly(int arg) throws InterruptedException {
    final Node node = addWaiter(Node.SHARED); // 入同步队列(队列为空:初始化,虚拟节点,挂载新节点;队列不为空:挂载新节点)
    boolean failed = true;
    try {
        for (;;) { // 熟悉的无限循环
            final Node p = node.predecessor(); // 前驱节点
            if (p == head) { // 如果前驱节点是头结点
              	// 如果当前节点的前驱节点是头结点,说明下一次本来就该当前节点获取到许可,所以这里先尝试一次,能获取到就不能避免阻塞
                int r = tryAcquireShared(arg); 
                if (r >= 0) { // 获取许可成功,不用阻塞了,原来的头结点出队,当前节点作为新的头结点(虚拟节点)
                    setHeadAndPropagate(node, r);
                    p.next = null; // help GC
                    failed = false;
                    return;
                }
            }
          	// 如果获取许可失败,那就要阻塞了
            if (shouldParkAfterFailedAcquire(p, node) && // 内部无限循环,目的是:1)队列移除已取消的节点;2)前驱节点改为 -1;3)返回 true
                parkAndCheckInterrupt()) // 挂起线程,使用 LockSupport.park()
                throw new InterruptedException();
        }
    } finally {
        if (failed)
            cancelAcquire(node);
    }
}

释放许可

// java.util.concurrent.Semaphore#release()
public void release() {
    sync.releaseShared(1);
}

// java.util.concurrent.locks.AbstractQueuedSynchronizer#releaseShared
public final boolean releaseShared(int arg) {
    if (tryReleaseShared(arg)) { // 释放许可
        doReleaseShared(); // 释放成功的处理
        return true;
    }
    return false;
}

怎么释放许可的:修改 CAS 原子 state 的值,并且保证新的 state 不能超过 Integer 的最大值。源码如下:

// java.util.concurrent.Semaphore.Sync#tryReleaseShared
protected final boolean tryReleaseShared(int releases) {
    for (;;) {
        int current = getState(); // 当前可用的许可数量
        int next = current + releases; // 计算释放后的新许可数量
        if (next < current) // Integer 最大值+1是负数(突然想到前段时间抖音上黄子韬直播间的点赞数量是负数)
            throw new Error("Maximum permit count exceeded");
        if (compareAndSetState(current, next)) // CAS原子更新state值
            return true;
    }
}

释放许可后就要唤醒线程了,这是共性操作,AQS 提供,AQS 的文章 里介绍过这个方法,这里再贴一下,源码如下:

private void doReleaseShared() {
    for (;;) { // 又是无限循环
        Node h = head;
        if (h != null && h != tail) { // 队列不为空才有可以唤醒的线程
            int ws = h.waitStatus;
            if (ws == Node.SIGNAL) { // 头结点如果是 -1(waitStatus = -1 表示当前节点有责任唤醒后继节点)那就唤醒
              	// 修改头结点为0,为什么修改为0 前面说过了,这个头结点要出队,获取到线程的节点将作为新的头结点,这里是只是先改成 0
                if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0)) 
                    continue; 
                unparkSuccessor(h); // 原来的头结点改为0后,唤醒线程,头结点会更新(头结点出队)
            }
          	// 如果头结点是0,设置为 -3(比如 Semaphore 多个线程并发调用 release)
            else if (ws == 0 && !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                continue;                // loop on failed CAS
        }
      
      	// 只要 unparkSuccessor() 执行了,并且成功唤醒节点(唤醒节点后会更新头结点),就继续循环,唤醒所有节点
      	// 如果 head 没变化,说明没有执行 unparkSuccessor() 或 unparkSuccessor() 没唤醒成功,就不用循环了
        if (h == head) 
            break;
    }
}

使用示例

import java.util.concurrent.Semaphore;

public class SemaphoreExample {
  
    // 模拟数据库连接池(只有3个连接可用)
    private static final int MAX_CONNECTIONS = 3;
    private final Semaphore semaphore = new Semaphore(MAX_CONNECTIONS, true); // 公平模式
    
    // 模拟获取数据库连接
    public void useDatabase(int threadId) {
        try {
            System.out.println("线程" + threadId + " 正在等待获取连接...");
            
            // 获取许可(如果没有可用连接会阻塞)
            semaphore.acquire();
            
            System.out.println("线程" + threadId + " 获取到连接,开始操作数据库...");
            
            // 模拟数据库操作耗时
            Thread.sleep(2000);
            
            System.out.println("线程" + threadId + " 操作完成,释放连接");
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } finally {
            // 释放许可(重要!必须放在finally块中)
            semaphore.release();
        }
    }

    public static void main(String[] args) {
        SemaphoreExample example = new SemaphoreExample();
        
        // 创建10个线程模拟并发请求
        for (int i = 1; i <= 10; i++) {
            final int threadId = i;
            new Thread(() -> example.useDatabase(threadId)).start();
            
            // 每隔500毫秒启动一个线程
            try {
                Thread.sleep(500);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
    }
}
posted @ 2023-05-24 16:11  CyrusHuang  阅读(75)  评论(0)    收藏  举报