侧边栏

DBLock

package com.example;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.text.MessageFormat;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.AbstractOwnableSynchronizer;
import java.util.concurrent.locks.LockSupport;

/**
 * A distributed lock backed by a PostgreSQL table ({@code lock_noblock}). Supported operations:<br/>
 * <pre>
 *   lock
 *     blocks until the lock is acquired
 *     reentrant
 *   tryLock
 *     single non-blocking attempt
 *     reentrant
 *   tryLock(time)
 *     waits up to the given time for the lock
 *     reentrant
 *   unlock
 *   *** Deadlock avoidance: if the process holding the lock dies, the lock must still be released
 *       eventually. Every acquisition stores an expiry time (lease) in the row, and a row whose
 *       lease has expired can be taken over by another acquirer.
 * </pre>
 * Not yet supported:<br/>
 * <pre>
 *   lockInterruptibly
 *   tryLock(time)
 *     responding to interruption
 *   newCondition
 *   automatic lease renewal (a lock held longer than its lease can be taken over by others)
 * </pre>
 * Table layout implied by the SQL below (NOTE(review): confirm against the real DDL):
 * {@code lock_noblock(lock_name, state, expire_time)}; {@code on conflict do nothing} presumably
 * relies on a unique constraint on {@code lock_name}.
 */
public class DBLock extends AbstractOwnableSynchronizer {

    /** Creates the lock row in the unlocked state (state = 0, already expired) if it does not exist. */
    private static final String insertSQL =
            "insert into lock_noblock(lock_name,state,expire_time) values(?,?,now()) on conflict do nothing";
    /**
     * Sets {@code state} and a new expiry of {@code now() + ? ms}, but only when the row is in the expected
     * state or its lease has already expired. Parameters: new state, lease in ms, lock name, expected state.
     */
    private static final String compareAndSwapSetExpireSQL =
            "update lock_noblock set state = ? ,expire_time=now() + ? * interval '1 millisecond'"
                    + " where lock_name = ? and ( state = ? or expire_time < now())";
    /** Moves the expiry of a locked row to {@code now() + ? ms}. Parameters: lease in ms, lock name. */
    private static final String updateExpireSQL =
            "update lock_noblock set expire_time=now() + ? * interval '1 millisecond'"
                    + " where lock_name = ? and state = 1";

    /**
     * Remaining-timeout threshold in nanoseconds below which {@code tryLock(time, ...)} retries without
     * parking (same idea as {@code AbstractQueuedSynchronizer.spinForTimeoutThreshold}).
     */
    static final long spinForTimeoutThreshold = 1000L;
    /** Default lease in milliseconds, used whenever no positive lease time is supplied. */
    static final long expire = 30000L;

    /**
     * Reentrancy count of the owning thread. It is only meaningful for reentrancy; {@code state == 0}
     * cannot be used to conclude that the DB lock has been released.
     */
    private int state;
    private final DataSource ds;
    private final String lockName;
    private final LockDBSupport lockDBSupport;

    /**
     * Creates a handle on the lock named {@code lockName}, inserting its row if it does not exist yet.
     *
     * @param lockName name of the lock row
     * @param ds       data source used for every lock operation
     * @throws RuntimeException if the row cannot be created
     */
    public DBLock(String lockName, DataSource ds) {
        this.lockName = lockName;
        this.ds = ds;
        initData();
        this.lockDBSupport = new LockDBSupport(lockName, ds);
    }

    /** Acquires the lock with the default lease, blocking until it is available. Reentrant. */
    public void lock() {
        lock(-1, TimeUnit.SECONDS);
    }

    /**
     * Acquires the lock, blocking until it is available. Reentrant.
     *
     * @param leaseTime how long the row stays locked before other acquirers may take it over;
     *                  values {@code <= 0} select the default lease of {@link #expire} ms
     * @param unit      unit of {@code leaseTime}
     */
    public void lock(long leaseTime, TimeUnit unit) {
        final long leaseMillis = toLeaseMillis(leaseTime, unit);
        while (!tryGetLock(leaseMillis)) {
            lockDBSupport.park();
        }
        // TODO: schedule lease renewal once renewal() is production-ready.
    }

    /**
     * Makes a single non-blocking attempt to acquire the lock with the default lease. Reentrant.
     *
     * @return {@code true} if the lock is now held by the current thread
     */
    boolean tryLock() {
        return tryGetLock(expire);
    }

    /**
     * Tries to acquire the lock with the default lease, waiting up to {@code time}. Reentrant.
     *
     * @return {@code true} if the lock is now held by the current thread
     */
    boolean tryLock(long time, TimeUnit unit) {
        return tryLock(time, -1, unit);
    }

    /**
     * Tries to acquire the lock, waiting up to {@code time}. Reentrant. Like
     * {@link java.util.concurrent.locks.Lock#tryLock(long, TimeUnit)}, one attempt is always made, even
     * when {@code time <= 0}.
     *
     * @param time      maximum time to wait
     * @param leaseTime lease of the row; values {@code <= 0} select the default of {@link #expire} ms
     * @param unit      unit of both {@code time} and {@code leaseTime}
     * @return {@code true} if the lock is now held by the current thread, {@code false} on timeout
     */
    boolean tryLock(long time, long leaseTime, TimeUnit unit) {
        final long leaseMillis = toLeaseMillis(leaseTime, unit);
        if (tryGetLock(leaseMillis)) {
            return true;
        }
        long nanosTimeout = unit.toNanos(time);
        if (nanosTimeout <= 0L) {
            return false;
        }
        final long deadline = System.nanoTime() + nanosTimeout;
        while (true) {
            nanosTimeout = deadline - System.nanoTime();
            if (nanosTimeout <= 0L) {
                return false;
            }
            // Very short remaining waits are spun rather than parked.
            if (nanosTimeout > spinForTimeoutThreshold) {
                lockDBSupport.parkNanos(nanosTimeout);
            }
            if (tryGetLock(leaseMillis)) {
                return true;
            }
        }
    }

    /**
     * Releases one hold of the lock. The DB row is only released when the reentrancy count drops to zero.
     *
     * @throws IllegalMonitorStateException if the current thread does not hold the lock
     */
    void unlock() {
        if (!isHeldByCurrentThread()) {
            throw new IllegalMonitorStateException();
        }
        int c = getState() - 1;
        setState(c);
        if (c != 0) {
            // Still held reentrantly: the DB row must stay locked.
            return;
        }
        // Clear local ownership BEFORE releasing the row, so a thread in this JVM that acquires the row
        // right after the release cannot have its ownership overwritten by this thread.
        setExclusiveOwnerThread(null);
        // NOTE(review): the release is not tied to an owner id; if this holder's lease already expired and
        // another process took the row, this clears that process's lock. Fixing it needs an owner column.
        compareAndSetStateAndSetExpire(1, 0, 0);
        lockDBSupport.unpark();
    }

    /** Inserts the lock row in the unlocked state if it does not exist yet. */
    private void initData() {
        try (Connection conn = ds.getConnection()) {
            conn.setAutoCommit(true);
            try (PreparedStatement ps = conn.prepareStatement(insertSQL)) {
                ps.setString(1, lockName);
                ps.setInt(2, 0);
                ps.execute();
            }
        } catch (SQLException e) {
            throw new RuntimeException(MessageFormat.format("无法创建[{0}]锁", lockName), e);
        }
    }

    private int getState() {
        return state;
    }

    private void setState(int state) {
        this.state = state;
    }

    /** Converts a caller-supplied lease to milliseconds; non-positive values select {@link #expire}. */
    private static long toLeaseMillis(long leaseTime, TimeUnit unit) {
        return leaseTime > 0 ? unit.toMillis(leaseTime) : expire;
    }

    /**
     * Single acquisition attempt: re-enters if the current thread already owns the lock, otherwise
     * tries to flip the row from unlocked (or expired) to locked with the given lease.
     *
     * @param leaseMillis lease to store on a fresh acquisition
     * @return {@code true} if the current thread now holds the lock
     */
    private boolean tryGetLock(long leaseMillis) {
        if (isHeldByCurrentThread()) {
            int nextc = getState() + 1;
            if (nextc < 0) { // overflow
                throw new Error("Maximum lock count exceeded");
            }
            setState(nextc);
            return true;
        }
        if (compareAndSetStateAndSetExpire(0, 1, leaseMillis)) {
            setState(1);
            setExclusiveOwnerThread(Thread.currentThread());
            return true;
        }
        return false;
    }

    /**
     * Compare-and-set on the row's state that also sets a new expiry time.
     *
     * @param expect      state the row must have (unless its lease has expired)
     * @param update      state to write
     * @param leaseMillis new expiry, in milliseconds from now
     * @return {@code true} if exactly one row was updated
     */
    private boolean compareAndSetStateAndSetExpire(int expect, int update, long leaseMillis) {
        try (Connection conn = ds.getConnection()) {
            // Same as initData: each statement must commit on its own to be visible to other processes.
            conn.setAutoCommit(true);
            try (PreparedStatement ps = conn.prepareStatement(compareAndSwapSetExpireSQL)) {
                ps.setInt(1, update);
                ps.setLong(2, leaseMillis);
                ps.setString(3, lockName);
                ps.setInt(4, expect);
                return ps.executeUpdate() == 1;
            }
        } catch (SQLException e) {
            throw new RuntimeException(MessageFormat.format("找到[{0}]锁", lockName), e);
        }
    }

    /**
     * Extends the lease of the locked row once, after two thirds of {@code leaseMillis} have elapsed, on a
     * daemon thread. NOTE(review): not called anywhere yet; a production watchdog would repeat this while
     * the lock is held and stop on unlock.
     *
     * @param leaseMillis lease to set, in milliseconds from the moment of renewal
     */
    private void renewal(final long leaseMillis) {
        Thread thread = new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    Thread.sleep(leaseMillis * 2 / 3);
                } catch (InterruptedException e) {
                    // Interrupted: give up on this renewal and keep the interrupt status.
                    Thread.currentThread().interrupt();
                    return;
                }
                try (Connection conn = ds.getConnection()) {
                    conn.setAutoCommit(true);
                    try (PreparedStatement ps = conn.prepareStatement(updateExpireSQL)) {
                        ps.setLong(1, leaseMillis);
                        ps.setString(2, lockName);
                        ps.executeUpdate();
                    }
                } catch (SQLException e) {
                    throw new RuntimeException(MessageFormat.format("找到[{0}]锁", lockName), e);
                }
            }
        });
        thread.setDaemon(true);
        thread.start();
    }

    /**
     * Whether the current thread owns the lock locally. <br/>
     * {@code exclusiveOwnerThread} need not be volatile: a thread always sees its own writes, and any
     * other thread compares against a reference that can never equal itself.
     *
     * @return {@code true} if the lock is held by the current thread
     */
    private boolean isHeldByCurrentThread() {
        return Thread.currentThread() == getExclusiveOwnerThread();
    }
}

 

posted on 2023-07-31 00:05  SmilingEye  阅读(8)  评论(0)  编辑  收藏  举报

导航