Java 多线程之信号量-Semaphore
计数信号量(Counting Semaphore)用来控制同时访问某个特定资源的操作数量,或者执行某个特定操作的数量。计数信号量还可以用来实现某种资源池或者对容器加边界。
Semaphore中管理着一组虚拟许可(permit),许可的初始数量可以通过在构造方法中指定。在执行操作时首先获取许可(permit),在使用完成后释放许可。如果没有许可,那么acquire则一直阻塞到有许可或者超时中断。
/**
* 计数信号量(Counting Semaphore)用来控制同时访问某个特定资源的操作数量,或者执行某个指定操作的数量。
* 计数信号量还可以用来实现某种资源池或者对容器进行加边界
* Semaphore 中管理着一组虚拟的许可(permit),许可的初始化数量可以通过构造函数来指定。在执行操作时可以
* 首先获取许可(只要还有剩余的许可),并在使用后释放许可(permit)。 如果没有许可(permit)那么aquire将阻塞
* 直到有许可或者直到被中断、操作超时。release方法将释放一个许可,返回给信号量(semaphore)
*
* @author zhangwei_david
* @version $Id: BoundedHashSet.java, v 0.1 2014年11月11日 下午1:49:24 zhangwei_david Exp $
*/
public class BoundedHashSet<T> {
//
private final Set<T> set;
// 信号量
private final Semaphore sem;
public BoundedHashSet(int bound) {
//初始化一个同步Set
this.set = Collections.synchronizedSet(new HashSet<T>());
//初始化信号量大小
sem = new Semaphore(bound);
}
public boolean add(T o) throws InterruptedException {
// 获取许可
sem.acquire();
boolean wasAdded = false;
try {
wasAdded = set.add(o);
return wasAdded;
} finally {
if (!wasAdded) {
// 如果增加失败
sem.release();
}
}
}
public boolean remove(Object o) {
boolean wasRemoved = set.remove(o);
if (wasRemoved) {
// 删除后释放许可
sem.release();
}
return wasRemoved;
}
}
下面看看源码具体实现方式
public class Semaphore implements java.io.Serializable {
private static final long serialVersionUID = -3222578661600680210L;
/** All mechanics via AbstractQueuedSynchronizer subclass */
private final Sync sync;
abstract static class Sync extends AbstractQueuedSynchronizer {
private static final long serialVersionUID = 1192457210091910933L;
//构造方法,初始化许可数量
Sync(int permits) {
setState(permits);
}
//获取总许可数量
final int getPermits() {
return getState();
}
//非公平获取许可
final int nonfairTryAcquireShared(int acquires) {
for (;;) {
//当前可用的总许可数量
int available = getState();
//当前有效许可数减去本次期望获取的数量得到的是本次操作后可能剩余许可数量
int remaining = available - acquires;
//如果剩余许可数量小于0,表示有效许可数不满足需求,或者CAS更新失败则自旋
if (remaining < 0 ||
compareAndSetState(available, remaining))
return remaining;
}
}
protected final boolean tryReleaseShared(int releases) {
for (;;) {
//获取当前许可数量
int current = getState();
//释放后预期的许可数量
int next = current + releases;
//如果释放后许可数量小于当前许可数量则表示溢出
if (next < current)
throw new Error("Maximum permit count exceeded");
//CAS更新成功则返回true,更新失败则自旋
if (compareAndSetState(current, next))
return true;
}
}
//直接减少许可数量
final void reducePermits(int reductions) {
for (;;) {
//获取当前许可数量
int current = getState();
//扣减后预期的许可数量
int next = current - reductions;
//如果扣减后还大于当前许可数表示溢出
if (next > current)
throw new Error("Permit count underflow");
//CAS更新成功就返回,否则自旋
if (compareAndSetState(current, next))
return;
}
}
//排干许可,即通过自旋CAS将许可树更新为0
final int drainPermits() {
for (;;) {
int current = getState();
if (current == 0 || compareAndSetState(current, 0))
return current;
}
}
}
/**
* NonFair version
*/
static final class NonfairSync extends Sync {
private static final long serialVersionUID = -2694183684443567898L;
NonfairSync(int permits) {
super(permits);
}
protected int tryAcquireShared(int acquires) {
//调用非公平获取方法
return nonfairTryAcquireShared(acquires);
}
}
static final class FairSync extends Sync {
private static final long serialVersionUID = 2014338818796000944L;
FairSync(int permits) {
super(permits);
}
protected int tryAcquireShared(int acquires) {
for (;;) {
//是否有排队
if (hasQueuedPredecessors())
return -1;
//
int available = getState();
int remaining = available - acquires;
if (remaining < 0 ||
compareAndSetState(available, remaining))
return remaining;
}
}
}
//默认非公平
public Semaphore(int permits) {
sync = new NonfairSync(permits);
}
public Semaphore(int permits, boolean fair) {
sync = fair ? new FairSync(permits) : new NonfairSync(permits);
}
//获取1个许可
public void acquire() throws InterruptedException {
sync.acquireSharedInterruptibly(1);
}
public void acquireUninterruptibly() {
sync.acquireShared(1);
}
public boolean tryAcquire() {
return sync.nonfairTryAcquireShared(1) >= 0;
}
public boolean tryAcquire(long timeout, TimeUnit unit)
throws InterruptedException {
return sync.tryAcquireSharedNanos(1, unit.toNanos(timeout));
}
public void release() {
sync.releaseShared(1);
}
public void acquire(int permits) throws InterruptedException {
if (permits < 0) throw new IllegalArgumentException();
sync.acquireSharedInterruptibly(permits);
}
public void acquireUninterruptibly(int permits) {
if (permits < 0) throw new IllegalArgumentException();
sync.acquireShared(permits);
}
public boolean tryAcquire(int permits) {
if (permits < 0) throw new IllegalArgumentException();
return sync.nonfairTryAcquireShared(permits) >= 0;
}
public boolean tryAcquire(int permits, long timeout, TimeUnit unit)
throws InterruptedException {
if (permits < 0) throw new IllegalArgumentException();
return sync.tryAcquireSharedNanos(permits, unit.toNanos(timeout));
}
public void release(int permits) {
if (permits < 0) throw new IllegalArgumentException();
sync.releaseShared(permits);
}
public int availablePermits() {
return sync.getPermits();
}
public int drainPermits() {
return sync.drainPermits();
}
protected void reducePermits(int reduction) {
if (reduction < 0) throw new IllegalArgumentException();
sync.reducePermits(reduction);
}
public boolean isFair() {
return sync instanceof FairSync;
}
public final boolean hasQueuedThreads() {
return sync.hasQueuedThreads();
}
public final int getQueueLength() {
return sync.getQueueLength();
}
protected Collection<Thread> getQueuedThreads() {
return sync.getQueuedThreads();
}
}

浙公网安备 33010602011771号