TransmittableThreadLocal 的反复与纠缠

TransmittableThreadLocal 相信很多人用过,一个在多线程情况下共享线程上下文的利器
名字好长,以下简称 ttl
本文以之前一个真实项目开发遇到的问题,还原当时从源码的角度分析并解决问题的过程


环境

item version
java 8
springboot 2.2.2.RELEASE
ttl 2.11.4

代码如下,主线程并行启复数任务丢给线程池处理
点击查看代码
List<ProcDef> defs = createsValidate(procCreate);
List<CompletableFuture<CreateResult>> cfs = Lists.newArrayListWithCapacity(defs.size());
defs.forEach(def -> {
    AbstractTransientVariable variable = procCreate.getProcDefKeyVars().get(def.getProcDefKey());
    CompletableFuture<CreateResult> cf = CompletableFuture.supplyAsync(() -> create(def, variable), threadPoolTaskExecutor)
	    .handle((r, e) -> {
		if (e != null) {
		    expHandle(def.getProcDefKey(), procCreate.getUserId(), variable, e);
		}
		return r;
	    });
    cfs.add(cf);
});

List<CreateResult> results = cfs.stream().map(CompletableFuture::join).filter(Objects::nonNull).collect(Collectors.toList());
if (defs.size() != results.size()) {
    log.error("Process create fail exist: [{}]", JacksonUtil.toJsonString(procCreate));
}

第一行 createsValidate 方法在主线程设置了当前用户的上下文
将用户信息放入了 ttl,方便后续子线程使用
测试的时候, 线程池在处理任务的时候,有时会获取不到主线程 ttl 信息
很奇怪,之前也是一直这样使用,为什么没有问题


于是本地main方法模拟

点击查看代码
UserContext.set(new ContextUser().setUserId("mycx26"));

IntStream.range(0, 10).forEach(e -> {
    Supplier<Void> supplier = () -> {
	String userId = UserContext.get() != null ? UserContext.getUserId() : null;
	System.out.println(Thread.currentThread().getName() + " get: " + userId);
	return null;
    };
    CompletableFuture.supplyAsync(supplier);
});

Thread.currentThread().join();
这里主线程将用户信息放入 ttl,依次将异步任务丢给线程池,任务执行获取 ttl 并打印
点击查看代码
ForkJoinPool.commonPool-worker-9 get: mycx26
ForkJoinPool.commonPool-worker-6 get: mycx26
ForkJoinPool.commonPool-worker-13 get: mycx26
ForkJoinPool.commonPool-worker-4 get: mycx26
ForkJoinPool.commonPool-worker-11 get: mycx26
ForkJoinPool.commonPool-worker-2 get: mycx26
ForkJoinPool.commonPool-worker-6 get: mycx26
ForkJoinPool.commonPool-worker-15 get: mycx26
ForkJoinPool.commonPool-worker-8 get: mycx26
ForkJoinPool.commonPool-worker-9 get: mycx26

从输出结果看,各线程都拿到了用户信息,似乎又没有问题


将相同的代码放到工程的单元测试方法里跑

点击查看代码
ForkJoinPool.commonPool-worker-10 get: null
ForkJoinPool.commonPool-worker-15 get: null
ForkJoinPool.commonPool-worker-9 get: null
ForkJoinPool.commonPool-worker-1 get: null
ForkJoinPool.commonPool-worker-13 get: null
ForkJoinPool.commonPool-worker-8 get: null
ForkJoinPool.commonPool-worker-3 get: null
ForkJoinPool.commonPool-worker-2 get: null
ForkJoinPool.commonPool-worker-11 get: null
ForkJoinPool.commonPool-worker-15 get: null
结果却截然相反,到这里我有点怀疑 ttl 对于多线程支持的泛用性了

找到 ttl 的 github readme 阅读
要保证线程池中传递值,一种方式是修饰 Runnable 和 Callable,Supplier 也有类似的包装器
于是修改代码重新测试,测试通过

虽然问题是解决了,但是原因却无从得知,等于还是绕过了问题
下次遇到 ttl 的问题,不知道原理还是无从下手
找到了一个已经 closed 类似的 issue
https://github.com/alibaba/transmittable-thread-local/issues/138
但还是没有解决我的疑问
没有办法,只能看源码了,问题还是要一个一个解决


一. main方法没有修饰的任务为什么能跨越线程池传递 ttl

1.1 首先看看 ttl 的 set 方法做了什么

点击查看代码
public final void set(T value) {
    if (!disableIgnoreNullValueSemantics && null == value) {
        // may set null to remove value
        remove();
    } else {
        super.set(value);
        addThisToHolder();
    }
}
else 走了父类 ThreadLocal 的 set 方法
点击查看代码
public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        map.set(this, value);
    } else {
        createMap(t, value);
    }
}

先看一下 ttl 的继承体系
ttl 继承 InheritableThreadLocal,InheritableThreadLocal 继承 ThreadLocal
InheritableThreadLocal 可以让子线程访问父线程设置的本地变量
点击查看代码
ThreadLocalMap getMap(Thread t) {
   return t.inheritableThreadLocals;
}

void createMap(Thread t, T firstValue) {
    t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
}

通过重写 getMap 和 createMap 方法将 ThreadLocal 的维护职责
由 threadLocals 转移给了 inheritableThreadLocals
threadLocals 和 inheritableThreadLocals 类型一样
是 ThreadLocal 中的静态内部类 ThreadLocalMap
为了维护线程本地变量定制化的哈希map, 两者由 Thread 持有

回到上文 TheadLocal set方法
首先获取当前线程,入参调用 getMap 方法获取当前线程的 inheritableThreadLocals

  • map不为null
    将 ttl 做为 key,value 作为值,放入当前线程的 inheritableThreadLocals

  • map为null
    将 ttl 和 value 构造一个新的 ThreadLocalMap,初始化当前线程的 inheritableThreadLocals

1.2 接下来看 CompletableFuture 的 supplyAsync 方法
这个方法调用栈很深,如果多线程功力不深,基本看不懂
但这不妨碍排查这个问题
supplyAsync 默认用的 ForkJoinPool 跑任务
那么必然会启一个线程
即必然会调用 Thread 的 init 方法初始化线程

首先将断点加到 CompletableFuture.supplyAsync(supplier); 这行
debug跑起来
然后将断点加到 Thread init 方法的第一行
(防止jvm启动初始化的线程产生干扰,比如 c2 complier thread)

点击查看代码
Thread parent = currentThread();

if (inheritThreadLocals && parent.inheritableThreadLocals != null)
        this.inheritableThreadLocals =
            ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);

重点是这里,判断父线程的 inheritableThreadLocals 如果不为 null
就把父线程的 inheritableThreadLocals 复制到子线程


1.3 接着看 ttl 的 get 方法

点击查看代码
public final T get() {
    T value = super.get();
    if (disableIgnoreNullValueSemantics || null != value) addThisToHolder();
    return value;
}

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

同样走的 ThreadLocal 的 get 方法
首先获取当前线程,getMap 获取其对应的 inheritableThreadLocals
顺利拿到之前父线程设置的变量
到这里,第一个问题算是解了


二. 同样的代码跑在单元测试,没有修饰的任务为什么不能跨越线程池传递 ttl

这里我有理由怀疑是 spring 容器在拉起的时候,提前用到了 ForkJoinPool 的 commonPool
但是项目依赖众多,如何定位
既然用到了,那么将断点加在 ForkJoinPool 启动线程
然后沿着调用栈帧一直向上找不就行了
将断点加在 ForkJoinPool 的 createWorker方法的第一行
开始找

点击查看代码
/* 检查逻辑删除字段只能有最多一个 */
    Assert.isTrue(fieldList.parallelStream().filter(TableFieldInfo::isLogicDelete).count() < 2L,
        String.format("annotation of @TableLogic can't more than one in class : %s.", clazz.getName()));

果然,熟悉的身影,mybatis plus
spring容器拉起时在初始化 SqlSessionFactory 时
会调用 TableInfoHelper 的 initTableFields 方法初始化表主键和字段
注意这里用的 stream 的并行流 parallel stream,很熟悉了
底层默认用的 ForkJoinPool 的 commonPool
那么在主线程设置的 TTL,线程池中的线程之前已经初始化,当然就拿不到了
好,这是第二个问题


三. 为什么项目中自定义线程池获取不到前面主线程创建的 ttl
和二是相同的问题,执行操作前,线程池已经被调度执行任务了
线程如果池化,那么后续在跑异步任务时就没有父子线程之说了
那么现在只剩最后一个问题


四. 为什么项目中任务加了包装器后又拿到了

点击查看代码
TtlWrappers.wrap(() -> create(def, variable))

没有什么办法,跟进去吧

4.1 看看 TtlWrappers 的静态方法 wrap 做了什么

点击查看代码
public static <T> Supplier<T> wrap(@Nullable Supplier<T> supplier) {
    if (supplier == null) return null;
    else if (supplier instanceof TtlEnhanced) return supplier;
    else return new TtlSupplier<T>(supplier);
}

看样子,大概是想用装饰模式包装 Supplier 为 TTL wrapper
new 了一个 TtlSupplier,这是 TtlWrappers 的一个静态内部类
继续进去

点击查看代码
TtlSupplier(@NonNull Supplier<T> supplier) {
    this.supplier = supplier;
    this.capture = capture();
}

supplier完成赋值后,重点是后面的 capture

点击查看代码
/**
 * Capture all {@link TransmittableThreadLocal} and registered {@link ThreadLocal} values in the current thread.
 *
 * @return the captured {@link TransmittableThreadLocal} values
 * @since 2.3.0
 */
@NonNull
public static Object capture() {
    return new Snapshot(captureTtlValues(), captureThreadLocalValues());
}

capture 静态方法位于 ttl 的静态内部类 Transmitter 中
注释很清晰,捕获当前线程的所有 ttl 和 ThreadLocal 的值
new Snapshort 继续跟进去

点击查看代码
private static class Snapshot {
    final WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value;
    final WeakHashMap<ThreadLocal<Object>, Object> threadLocal2Value;

    private Snapshot(WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value, WeakHashMap<ThreadLocal<Object>, Object> threadLocal2Value) {
        this.ttl2Value = ttl2Value;
        this.threadLocal2Value = threadLocal2Value;
    }
}

Snaphost 同样是 ttl 的静态内部类
构造方法的第一个参数方法 captureTtlValues 跟进去

点击查看代码
private static WeakHashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
    WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
    for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
        ttl2Value.put(threadLocal, threadLocal.copyValue());
    }
    return ttl2Value;
}

同样来自 Transmitter
好,代码并不复杂,重点是 holder

点击查看代码
// Note about holder:
// 1. holder self is a InheritableThreadLocal(a *ThreadLocal*).
// 2. The type of value in holder is WeakHashMap<TransmittableThreadLocal<Object>, ?>.
//    2.1 but the WeakHashMap is used as a *Set*:
//        - the value of WeakHashMap is *always null,
//        - and never be used.
//    2.2 WeakHashMap support *null* value.
private static InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
        new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
            @Override
            protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
                return new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
            }

            @Override
            protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
                return new WeakHashMap<TransmittableThreadLocal<Object>, Object>(parentValue);
            }
        };

holder 为 ttl 的静态成员变量,类型为 InheritableThreadLocal 的匿名内部类
重写了 initialValue 和 childValue 方法
再看注释
这里 value 的 type 是 WeakHashMap 并且这个 map 被当作 set 用了
还记得上文分析 ttl 的 set 方法吗,有一块没有讲
对,就是 else 的 addThisToHolder

点击查看代码
private void addThisToHolder() {
    if (!holder.get().containsKey(this)) {
        holder.get().put((TransmittableThreadLocal<Object>) this, null); // WeakHashMap supports null value.
    }
}

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

set方法
第一步将 ttl 做为 key,value 作为值,放入当前线程的 inheritableThreadLocals
第二步 addThisToHolder
holder.get()方法获取当前线程的 inheritableThreadLocals 变量
因为现在是 holder, 这里一定要理解
第一次肯定是拿不到的,那么这里为什么没有npe?
上文提到他重写了 initialValue 方法
所以holder.get()方法第一次获取到一个空的 WeakHashMap
到此
当前线程的 inhrritableThreadLocals 会多一个 Entry
key 为 holder,value 是空的 WeakHashMap

继续
判断这个 map 是否包含 ttl
第一次肯定没有
那么把 ttl 本身作为key, value 为 null, 放入 map
到此
当前线程的 inhrritableThreadLocals 的 Entry 没有变化
key 为 holder 的 WeakHashMap 里会多一个值
key 为 当前 ttl, value 为 null

至此 TtlSupplier 的 capture 属性已经持有了主线程的所有 ttl 快照


4.2 接下来看 TtlSupplier 重写的 get 方法

这是核心的行为,可以断定,其必然做了增强

点击查看代码
public T get() {
    final Object backup = replay(capture);
    try {
        return supplier.get();
    } finally {
        restore(backup);
    }
}

结构很清晰,先replay,再执行核心行为,最后restore
replay 跟进去

点击查看代码
/**
 * Replay the captured {@link TransmittableThreadLocal} and registered {@link ThreadLocal} values from {@link #capture()},
 * and return the backup {@link TransmittableThreadLocal} values in the current thread before replay.
 *
 * @param captured captured {@link TransmittableThreadLocal} values from other thread from {@link #capture()}
 * @return the backup {@link TransmittableThreadLocal} values before replay
 * @see #capture()
 * @since 2.3.0
 */
@NonNull
public static Object replay(@NonNull Object captured) {
    final Snapshot capturedSnapshot = (Snapshot) captured;
    return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
}

同样是位于 ttl 的静态内部类 Transmitter 的静态方法
上文 Tramsmitter capture()方法捕获的主线程快照这里用到了
replayTtlValues 方法跟进去

点击查看代码
private static WeakHashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> captured) {
    WeakHashMap<TransmittableThreadLocal<Object>, Object> backup = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();

    for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
        TransmittableThreadLocal<Object> threadLocal = iterator.next();

        // backup
        backup.put(threadLocal, threadLocal.get());

        // clear the TTL values that is not in captured
        // avoid the extra TTL values after replay when run task
        if (!captured.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    // set TTL values to captured
    setTtlValuesTo(captured);

    // call beforeExecute callback
    doExecuteCallback(true);

    return backup;
}

private static void setTtlValuesTo(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {
    for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {
        TransmittableThreadLocal<Object> threadLocal = entry.getKey();
        threadLocal.set(entry.getValue());
    }
}

注意这里已经到子线程了
首先通过 holder.get() 获取子线程的 inheritableThreadLocals 变量
很可能没有
但是如果线程之前已经池化用完没有remove,这里是有的
如果 map 的 size 大于0,遍历 map 的 key ttl
这里为了不污染子线程上下文,先做了备份
对于快照中不包含的 ttl 信息依次从子线程 remove
然后遍历快照信息设置到当前线程的 inheritableThreadLocals
doExecuteCallback 方法是 ttl 为开发者留的一个勾子方法
时机在任务执行前
最后返回子线程 ttl 备份

再回到 TtlSupplier 的 get 方法
supplier 的 get 方法执行任务
最后还剩 restore,传入上面子线程的 ttl 备份

点击查看代码
/**
 * Restore the backup {@link TransmittableThreadLocal} and
 * registered {@link ThreadLocal} values from {@link #replay(Object)}/{@link #clear()}.
 *
 * @param backup the backup {@link TransmittableThreadLocal} values from {@link #replay(Object)}/{@link #clear()}
 * @see #replay(Object)
 * @see #clear()
 * @since 2.3.0
 */
public static void restore(@NonNull Object backup) {
    final Snapshot backupSnapshot = (Snapshot) backup;
    restoreTtlValues(backupSnapshot.ttl2Value);
    restoreThreadLocalValues(backupSnapshot.threadLocal2Value);
}

private static void restoreTtlValues(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> backup) {
    // call afterExecute callback
    doExecuteCallback(false);

    for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
        TransmittableThreadLocal<Object> threadLocal = iterator.next();

        // clear the TTL values that is not in backup
        // avoid the extra TTL values after restore
        if (!backup.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    // restore TTL values
    setTtlValuesTo(backup);
}

主要看 restoreTtlValues 方法
doExecuteCallback 和上面逻辑类似
区别在时机在任务执行后
holder 获取子线程的 inheritableThreadLocals 变量
遍历 map 的 key ttl
对于不在备份的 ttl 全部删除
最后恢复子线程的 ttl

仿佛一切没有发生过

至此最后一个问题解决

你对的不一定对,你错了一定是错了
源码面前,没有什么秘密可言了

posted @ 2024-03-13 10:25  mycx26  阅读(336)  评论(0)    收藏  举报