ThreadLocal使用

ThreadLocal

java.lang.ThreadLocal 是 Java 中的一种用于实现线程局部变量的工具,它允许每个线程都有自己的独立变量副本。这在多线程环境中非常有用,尤其是在需要避免共享状态的情况下。

基本使用

适用场景: 适用于需要为每个线程维护独立状态的场景,例如数据库连接、用户会话等。

主要方法

  1. T get():返回当前线程所对应的线程局部变量的值。如果该线程没有自己的副本,则返回 null

    T value = threadLocal.get();
    
  2. void set(T value):将当前线程的线程局部变量设置为给定的值。如果当前线程之前没有该变量的副本,则创建一个新的副本。

    threadLocal.set(value);
    
  3. void remove():移除当前线程的线程局部变量副本。这通常用于清理资源,防止内存泄漏。

    threadLocal.remove();
    
  4. ThreadLocal():默认构造函数,用于创建一个新的 ThreadLocal 实例。

  5. ThreadLocal(Supplier<? extends T> supplier):使用给定的 Supplier 创建一个新的 ThreadLocal 实例,以在每个线程第一次调用 get() 时初始化变量。

    ThreadLocal<MyType> threadLocal = ThreadLocal.withInitial(() -> new MyType());
    

实现原理

每个线程都存有一个ThreadLocalMap,这个是线程私有的

class Thread {
    /* ThreadLocal values pertaining to this thread. This map is maintained
     * by the ThreadLocal class. 
     */
    ThreadLocal.ThreadLocalMap threadLocals = null;
}

每次set,get操作时都要先取当前线程私有的ThreadLocalMap,所以没有数据共享带来的线程安全问题

/**
 * Get the map associated with a ThreadLocal. Overridden in
 * InheritableThreadLocal.
 *
 * @param  t the current thread
 * @return the map
 */
ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

ThreaLocalMap

key为ThreadLocal对象,value为ThreadLocal存放的值,每个Entry与普通Map有点不同,key为弱引用

/**
 * The entries in this hash map extend WeakReference, using
 * its main ref field as the key (which is always a
 * ThreadLocal object).  Note that null keys (i.e. entry.get()
 * == null) mean that the key is no longer referenced, so the
 * entry can be expunged from table.  Such entries are referred to
 * as "stale entries" in the code that follows.
 */
static class Entry extends WeakReference<ThreadLocal<?>> {
    /** The value associated with this ThreadLocal. */
    Object value;

    Entry(ThreadLocal<?> k, Object v) {
        super(k);
        value = v;
    }
}

void set(T value)

/**
 * Sets the current thread's copy of this thread-local variable
 * to the specified value.  Most subclasses will have no need to
 * override this method, relying solely on the {@link #initialValue}
 * method to set the values of thread-locals.
 *
 * @param value the value to be stored in the current thread's copy of
 *        this thread-local.
 */
public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        map.set(this, value);
    } else {
        createMap(t, value);
    }
}

/**
 * Create the map associated with a ThreadLocal. Overridden in
 * InheritableThreadLocal.
 *
 * @param t the current thread
 * @param firstValue value for the initial entry of the map
 */
void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

ThreadLocalMap#set

/**
 * Set the value associated with key.
 *
 * @param key the thread local object
 * @param value the value to be set
 */
private void set(ThreadLocal<?> key, Object value) {

    // We don't use a fast path as with get() because it is at
    // least as common to use set() to create new entries as
    // it is to replace existing ones, in which case, a fast
    // path would fail more often than not.

    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);

    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        if (k == key) {
            e.value = value;
            return;
        }

        if (k == null) {
            // 被回收了
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    tab[i] = new Entry(key, value);
    int sz = ++size;
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

replaceStaleEntry

/**
 * Replace a stale entry encountered during a set operation
 * with an entry for the specified key.  The value passed in
 * the value parameter is stored in the entry, whether or not
 * an entry already exists for the specified key.
 *
 * As a side effect, this method expunges all stale entries in the
 * "run" containing the stale entry.  (A run is a sequence of entries
 * between two null slots.)
 *
 * @param  key the key
 * @param  value the value to be associated with key
 * @param  staleSlot index of the first stale entry encountered while
 *         searching for key.
 */
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                               int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    // Back up to check for prior stale entry in current run.
    // We clean out whole runs at a time to avoid continual
    // incremental rehashing due to garbage collector freeing
    // up refs in bunches (i.e., whenever the collector runs).
    int slotToExpunge = staleSlot;
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;

    // Find either the key or trailing null slot of run, whichever
    // occurs first
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();

        // If we find key, then we need to swap it
        // with the stale entry to maintain hash table order.
        // The newly stale slot, or any other stale slot
        // encountered above it, can then be sent to expungeStaleEntry
        // to remove or rehash all of the other entries in run.
        if (k == key) {
            e.value = value;

            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            // Start expunge at preceding stale entry if it exists
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        // If we didn't find stale entry on backward scan, the
        // first stale entry seen while scanning for key is the
        // first still present in the run.
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    // If key not found, put new entry in stale slot
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    // If there are any other stale entries in run, expunge them
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

cleanSomeSlots

当新元素被添加时,会扫描key为null的Entry进行删除

/**
 * Heuristically scan some cells looking for stale entries.
 * This is invoked when either a new element is added, or
 * another stale one has been expunged. It performs a
 * logarithmic number of scans, as a balance between no
 * scanning (fast but retains garbage) and a number of scans
 * proportional to number of elements, that would find all
 * garbage but would cause some insertions to take O(n) time.
 *
 * @param i a position known NOT to hold a stale entry. The
 * scan starts at the element after i.
 *
 * @param n scan control: {@code log2(n)} cells are scanned,
 * unless a stale entry is found, in which case
 * {@code log2(table.length)-1} additional cells are scanned.
 * When called from insertions, this parameter is the number
 * of elements, but when from replaceStaleEntry, it is the
 * table length. (Note: all this could be changed to be either
 * more or less aggressive by weighting n instead of just
 * using straight log n. But this version is simple, fast, and
 * seems to work well.)
 *
 * @return true if any stale entries have been removed.
 */
private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            n = len;
            removed = true;
            i = expungeStaleEntry(i);
        }
    } while ( (n >>>= 1) != 0);
    return removed;
}

T get()

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

/**
 * Get the entry associated with key.  This method itself handles only the fast 
 * path: a direct hit of existing key. It otherwise relays to getEntryAfterMiss.  
 * This is designed to maximize performance for direct hits, in part by making 
 * this method readily inlinable.
 *
 * @param  key the thread local object
 * @return the entry associated with key, or null if no such
 */
private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        return e;
    else
        return getEntryAfterMiss(key, i, e);
}

InheritableThreadLocal

ThreadLocal不能在父子线程之间传递, 即在子线程中无法访问在父线程中设置的本地线程变量。 后来为了解决这个问题,引入了一个新的类InheritableThreadLocal

InheritableThreadLocal 是一个特殊类型的 ThreadLocal,它允许子线程继承父线程的值。这在需要在线程之间传递状态时非常有用,比如在创建子线程时传递上下文信息。

使用该方法后,子线程可以访问在创建子线程时父线程当时的本地线程变量,其实现原理就是在父线程创建子线程时将父线程当前存在的本地线程变量拷贝到子线程的本地线程变量中。主要是重写了childValuegetMapcreateMap方法

public class InheritableThreadLocal<T> extends ThreadLocal<T> {

    protected T childValue(T parentValue) {
        return parentValue;
    }

    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }

    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

主要特点

  • 继承机制: 子线程可以访问父线程的 InheritableThreadLocal 变量的值。
  • 线程局部: 每个线程依然有自己的副本,子线程获取的是在创建时父线程的值。

下面是一个使用InheritableThreadLocal的例子:

public class InheritableThreadLocalExample {
  private static final InheritableThreadLocal<String> inheritableThreadLocal =
    new InheritableThreadLocal<String>() {
      @Override
      protected String initialValue() {
        return "Initial Value";
      }
    };

  public static void main(String[] args) {
    // 设置父线程的值
    inheritableThreadLocal.set("Parent Value");

    // 创建子线程
    Thread childThread = new Thread(() -> {
      // 获取子线程的值
      String value = inheritableThreadLocal.get();
      System.out.println("Child Thread Value: " + value); // 输出 "Parent Value"
    });

    childThread.start();

    try {
      childThread.join(); // 等待子线程完成
    } catch (InterruptedException e) {
      e.printStackTrace();
    }

    // 主线程的值
    System.out.println("Main Thread Value: " + inheritableThreadLocal.get()); // 输出 "Parent Value"
  }
}

常见使用问题

脏数据

当使用线程池时,线程是复用的。如果在某个线程中设置了 ThreadLocal 变量,但在后续任务中未清理该变量,可能会导致旧数据被保留在新任务中。

解决方案: 在每个任务执行完毕后,调用 remove() 方法清理 ThreadLocal 变量。

内存泄漏

ThreadLocal在某些场景下可能存在内存泄漏问题

平常使用ThreadLocal一般都是static修饰

public class ThreadLocalExample {
  private static final ThreadLocal<Integer> 
      threadLocalValue = ThreadLocal.withInitial(() -> 0);
}

这种情况下ThreadLocal由于static修饰,是强引用,一般不会被回收,所以ThreadLocalMap把key设置为弱引用对于value的回收是没关系的,此时需要通过set覆盖上一次设置的值,那么上一次设置的值自然就不被引用而最终会被回收,或者手动remove。

如果线程会被回收,ThreadLocal本身由于被static修饰,不会随着线程被回收而回收。假设这个ThreadLocal使用频率非常低或者只使用一次,那么也可以看做是内存泄漏。

将ThreadLocal用作局部变量或者成员变量时,因为ThreadLocalMap是和线程绑定的,如果线程没被回收,那么ThreadLocalMap也不会被回收。但是此时作为局部变量的ThreadLocal或者作为成员变量所在的那个对象被回收的话,拿不到ThreadLocal的引用了,由于ThreadLocal作为key是弱引用,所以ThreadLocal本身不会存在内存泄漏,但是ThreadLocalMap的value是强引用,是绑定到线程上的。这种场景需要及时remove以防止内存泄漏。

public static void main(String[] args) {
  ThreadLocal<Integer> threadLocalValue = new ThreadLocal<>();
  ThreadPoolExecutor
    executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(3);
  for (int i = 0; i < 3; i++) {
    final int threadId = i;
    executor.execute(() -> {
      // 设置线程局部变量
      threadLocalValue.set(threadId);
      System.out.println("Thread " + Thread.currentThread().getName() + 
                         " has value: " + threadLocalValue.get());
      threadLocalValue.remove();
    });
  }
  executor.shutdown();
}

参考资料:

  1. https://stackoverflow.com/questions/17968803/threadlocal-memory-leak
posted @ 2025-08-03 07:54  vonlinee  阅读(64)  评论(0)    收藏  举报