JUC ThreadLocal

words: 2.2k    views:    time: 9min

ThreadLocal从其命名上就可以知道其意图是创建线程本地变量,就是希望同一个变量在不同的线程中拥有各自的值并且互不影响,其非常适合用来作为线程上下文变量,比如在一些连接池或者事务的场景中。

其思路是让每个Thread都持有一个私有的ThreadLocalMap,然后使用共享的key来保存值,而这个key就是共享的ThreadLocal实例,因此每个ThreadLocal也就对应一个本地变量。。但是,如果这个本地变量本身就是一个线程共享的对象,那么就算使用ThreadLocal也不是线程安全的。另外,设计者将具体的map操作都封装在了ThreadLocal中,然后提供统一的get/set接口,好让开发更加简洁方便。

1. 结构

先看一下Map中的元素Entry,其使用了弱引用来封装key

java.lang.ThreadLocal
1
2
3
4
5
6
7
8
9
static class Entry extends WeakReference<ThreadLocal<?>> {

Object value;

Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}

至于使用弱引用自然是为了当ThreadLocal不再共享时,能尽快被回收。考虑如下场景:在方法中创建并使用ThreadLocal

1
2
3
4
5
6
7
public void threadContext(){
ThreadLocal<Object> local = new ThreadLocal<>();
local.set(new Object());

local = new ThreadLocal<>();
//...
}

如果key使用的是强引用,那么即便这里的实例被重置了或者方法结束了,只要线程没结束,ThreadLocal实例就无法被GC,尤其在一些使用线程池的场景中。另外,虽然使用了弱引用,但在一些线程复用的场景中(比如线程池),如果确定了ThreadLocal不再使用,最好也主动remove,以免内存泄漏。

不过,一般情况下,ThreadLocal作为key都是由线程共享的,因此通常会定义为静态对象,或者由其它线程共享的对象持有。

2. 接口

2.1. set

首先获取当前线程持有的map,如果没有则创建,然后以当前ThreadLocal作为key来保存值

java.lang.ThreadLocal
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}

ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}

void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}

至于set逻辑都封装在ThreadLocalMap中,这里需要考虑hash冲突、清理已经被回收的key、以及扩容等问题,但是作者非常巧妙地通过环形遍历来解决hash冲突的问题,并充分利用了空间,具体下面的注释已经非常详细,就不再赘述

java.lang.ThreadLocal.ThreadLocalMap
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);

// 从hash确定的位置开始,向后环形遍历
for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
// 遇到相同key,则直接覆盖,并返回
ThreadLocal<?> k = e.get();
if (k == key) {
e.value = value;
return;
}

// 遇到过期Entry,则从当前位置向后环形遍历直至遇到null为止(因为由于hash冲突的处理方式,key可能在当前位置之后)
// 如果找到相同的key则进行覆盖,否则在过期Entry的位置上新建个Entry代替,
// 另外,会尝试找其它过期Entry的位置,并纪录下来以便帮忙清理
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}

// 直到为null的位置,都没遇到上面的情况,则就地新建Entry,然后尝试删除过期的Entry,没有则检查是否需要扩容
tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold) // 这里保证了每次set一定能找到空位置
rehash();
}

// 向后环型遍历
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}

private void replaceStaleEntry(ThreadLocal<?> key, Object value, int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;

// 向前环形遍历直至为null,找到最前面过期Entry的位置,并记为slotToExpunge,作为下面清理的起点
// 这样做是因为如果出现了过期Entry,则大概率在其附近还会有其它过期的Entry,毕竟GC时是平等的,没有只针对哪个
int slotToExpunge = staleSlot;
for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null; i = prevIndex(i, len))
if (e.get() == null)
slotToExpunge = i;

// 向后环形遍历直至为null
for (int i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();

// 如果遇到了相同key,则直接覆盖值,并提前结束
if (k == key) {
e.value = value;
tab[i] = tab[staleSlot];
tab[staleSlot] = e; // 交换下位置,将key放到其hash对应的正确位置上

// 如果没发现其它过期的Entry,那么就清除位置staleSlot上的,不过现在的位置是i
if (slotToExpunge == staleSlot)
slotToExpunge = i;
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}

// 如果之前向前遍历没有发现其它的过期Entry,但现在又发现了,那么重置下清理的起点
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}

// 没有遇到相同key的Entry,那么在过期Entry的位置上创建一个新的代替
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);

if (slotToExpunge != staleSlot) // 如果还存在其它的过期Entry,则进行清理
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;

// 由于环形遍历没有尽头,所以清理时默认只向后遍历log(n)次
// 如果遍历过程中发现有过期的Entry,则重置n为tab.length,并跳到下一个为null的位置,然后再向后遍历log(len)次
// 这里设定的 threshold = (2/3) * tab.length,以及默认遍历log(size)次,应该是综合考虑了效率和成本
do {
i = nextIndex(i, len);
Entry e = tab[i];
if (e != null && e.get() == null) {
n = len;
removed = true;
i = expungeStaleEntry(i); // 清除当前Entry,并尝试继续向后清除过期Entry,直至为null的位置返回
}
} while ( (n >>>= 1) != 0);
return removed;
}

private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;

// 清除过期Entry
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;

Entry e;
int i;
// 向后遍历,直到遇到Entry为null的位置
// 遍历过程中如果遇到过期的Entry,则帮忙清除
// 否则如果Entry之前由于hash冲突导致位置不对(往后顺延了),则帮忙调整一下位置(尽量靠近正确的位置)
for (i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;
while (tab[h] != null) // 如果发现位置h也已经有了Entry,则向后顺延
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}

2.2. get

获取则比较简单,就是尝试从map中获取值,如果map为空则创建,并提供了方法来初始化创建时的值

java.lang.ThreadLocal
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
public T get() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue();
}

private T setInitialValue() {
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}

protected T initialValue() {
return null;
}

至于map获取的过程也是类似,即如果位置上的key不对,则向后环形遍历,如果遇到了就返回,否则如果遇到了过期Entry则帮忙清理,如果直至null也没有遇到则返回空

java.lang.ThreadLocal.ThreadLocalMap
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
return getEntryAfterMiss(key, i, e);
}

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;

while (e != null) { // 向后环形遍历,直至为null
ThreadLocal<?> k = e.get();
if (k == key) // 找到了直接返回
return e;

if (k == null) // 过期Entry则帮忙清理,与上面一样
expungeStaleEntry(i);
else
i = nextIndex(i, len);
e = tab[i];
}
return null; // 没有这个hash对应的ThreadLocal,直接返回null
}

2.3. remove

java.lang.ThreadLocal
1
2
3
4
5
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
m.remove(this);
}
java.lang.ThreadLocal.ThreadLocalMap
1
2
3
4
5
6
7
8
9
10
11
12
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear();
expungeStaleEntry(i); // 清除当前Entry,并尝试继续向后清除过期Entry,直至为null的位置返回
return;
}
}
}

3. InheritableThreadLocal

如果希望值能够在父子线程之间传递,即当前线程的值能够在其创建的线程中继续使用,则可以使用InheritableThreadLocal,其实现基本一样,区别在于Thread初始化的时候会拿父线程InheritableThreadLocal的值来给子线程初始化


参考:

  1. https://juejin.cn/post/6844903552477822989
  2. https://www.jianshu.com/p/dde92ec37bd1