ThreadLocalMap为什么用线性探测解决哈希冲突

前言

ThreadLocal 本身不存储值,访问的是当前线程 ThreadLocalMap 里存储的数据副本,实现了线程间的数据隔离。只有当前线程能访问,也就不存在并发访问时的安全问题了。
ThreadLocal 的核心是 ThreadLocalMap,它和 HashMap 不同的是:面对哈希冲突时,后者用的是链表法,而前者用的是线性探测法,为什么呢???

理解线性探测

哈希表的实现,一般是创建一个数组,然后根据 key 计算哈希值,再根据数组长度计算一个下标,把 key value 封装成一个 Entry 写入当前下标位置即可。一旦发生哈希冲突,问题就来了,一块地址不能同时存储两个元素啊,这时就出现了几种常见的哈希冲突的解决方案。

链表法是 HashMap 的解决方案,把哈希冲突的 Entry 构建成链表即可,查找时得遍历整个链表。
线性探测是 ThreadLocalMap 的解决方案,它的思路是:一旦发生哈希冲突,就继续往后找(环形),找到第一个空节点的位置,再把当前 Entry 放进去。查找的过程也是一样的,先根据哈希值计算下标,再从这个位置开始往后找,如果找到第一个空节点还没找到,就认为 key 不存在。
所以,使用线性探测法有一个前提,数组必须能容纳所有的元素,否则就会出现死循环。一般情况下,使用线性探测法的哈希表,每次放入一个 Entry 后都要判断是否要扩容,确保有足够的容量存储下一个 Entry。

如下图所示,元素8、9分别占用了下标2、3的位置,此时元素14要放进来,下标计算也是2,但是因为已经有元素8了,所以只能往后继续找,直到发现下标4的地方是空的,元素就可以放进去了。
image.png

ThreadLocalMap

要知道为啥 ThreadLocalMap 用线性探测法,必然和它的某些特性相关,那就深入源码一探究竟。

ThreadLocalMap 虽然叫Map,但是并没有实现Map接口,只提供了简单的get、set方法。内部节点类 Entry 对 Key Value 进行了封装,且继承自 WeakReference,也就是说 Entry 对 ThreadLocal 是弱引用,这主要是为了清理过期节点,避免内存泄漏。

static class ThreadLocalMap {

  static class Entry extends WeakReference<ThreadLocal<?>> {
      Object value;

      Entry(ThreadLocal<?> k, Object v) {
          super(k);
          value = v;
      }
  }
}

ThreadLocal#set 其实是把 ThreadLocal 作为 key,参数作为 value 写入到 ThreadLocalMap,方法是ThreadLocalMap#set 。在set新值时,有几种情况:

  • 如果发现key已经存在,直接替换value即可
  • 否则线性探测向后查找,如果发现了过期节点,替换过期节点。
  • 如果以上都没发生,那么肯定会找到一个空节点,直接插入即可,插入后判断是否扩容
private void set(ThreadLocal<?> key, Object value) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);

    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        if (k == key) {
            // key相同,直接替换Value
            e.value = value;
            return;
        }

        if (k == null) {
            // 过程中找到过期节点,有两种情况要处理:
            // 1. key不存在,直接替换掉过期节点即可,线性探测也不会中断
            // 2. key存在(在后面),接着往后找,把目标节点替换到当前过期节点
            replaceStaleEntry(key, value, i);
            return;
        }
    }
    // 找到了空节点,插入
    tab[i] = new Entry(key, value);
    int sz = ++size;
    // 因为插入一个新节点,所以要判断是否需要扩容
    // 如果尝可以清理掉部分过期节点,那就无需扩容
    // 如果数量没有超过阈值,也无需扩容
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

上述第2步要特别注意:并不能单纯的替换过期节点,因为相同的key可能在过期节点的后面,如果直接替换就会导致相同的key被插入两次,程序就出bug了,所以专门写了个replaceStaleEntry() 用来替换过期节点。
替换过期节点时,还捎带做了一些其他事:

  • 从当前过期节点往前找,看看能否扫描到一些过期节点,捎带清理一下
  • 从当前过期节点往后找,如果发现相同key,就替换指定节点。找不到相同key,就替换过期节点
  • 最后把整个过程中扫描到的过期节点做清理
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                               int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    // 往前查找最久的过期节点,稍后清理掉
    int slotToExpunge = staleSlot;
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;

    // 线性探测往后找,找到key就替换,找不到就直接插入到当前过期节点
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        
        // 找到了key
        if (k == key) {
            // 替换value
            e.value = value;
            // 替换节点
            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;
            
            if (slotToExpunge == staleSlot)
                // 前面没有过期节点,从当前节点开始往后清理&rehash
                slotToExpunge = i;
            // 从slotToExpunge开始往后清理&rehash
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }
        
        // 往前没找到过期节点,但是往后找到了,那就从当前位置开始清理
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    // key没找到,直接写入当前过期节点
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    // 往前/往后扫描到了过期节点,清理它
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

清理过期节点并不容易,不是把它从数组里移除那么简单。因为线性探测的原因,如果过期节点后面还有节点的话,单纯的把过期节点移除掉,会导致整个探测的链路断掉,程序就出bug了。
expungeStaleEntry() 用来清理过期节点,主要做了两件事:

  • 将当前过期节点设为null
  • 向后查找,扫描到过期节点就清理掉,正常节点就rehash操作,重新放进哈希表,以此来保证探测的链路完整
private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // 删除过期节点,本质是置为null
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    // rehash后续不为null的节点,不然会中断线性探测
    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        if (k == null) { // 发现过期节点,顺带清理掉
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            // 重新计算下标,如果不一样就rehash
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;
                
                // rehash:线性探测,直到发现一个空节点插入
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i;
}

尾巴

回到问题本身,为什么 ThreadLocalMap 用线性探测来解决哈希冲突而不是链表法?本质上它和 HashMap 的定位是不一样的,HashMap 是性能优先,尽可能的保证元素的高效访问。ThreadLocalMap 性能不是第一要素,看完源码你会发现,如果数组元素比较密集的话,ThreadLocalMap 不管是 set 还是 get 都会不可避免地扫描很多节点,这肯定会影响性能。但是换来的收益,就是 ThreadLocalMap 可以在扫描节点时主动发现过期节点且清理掉,尽可能的避免内存泄漏,这比牺牲一点性能更加值得。