51工具盒子

依楼听风雨
笑看云卷云舒,淡观潮起潮落

CountDownLatch 源码详解

一、前言 {#一、前言}

本篇的介绍对象是 CountDownLatch ,它同样是基于 AQS 之上扩展的一款多线程场景下的工具类,它可以使一个或多个线程等待其他线程各自执行完毕后再执行。

对于 CountDownLatch 理解,我们可以将单次拆开为 CountDownLatchCountDown 表示倒计时,Latch 表示门闩,当倒计时结束后门闩解除,门就开了。

二、使用场景 {#二、使用场景}

要完成一项复杂的任务,任务被划分为子任务1和子任务2,3,4...,为了提高执行任务的效率,采用多线程去完成。

由于子任务1的执行条件依赖于 子任务2,3,4...,需要先执行子任务2,3,4...获取到相应的结果才能执行子任务1,这是 CountDownLatch 就派上用场了。

三、工作原理 {#三、工作原理}

给定 CountDownLatch 一个倒计时数,每个线程都能访问 CountDownLatch 实例。当一批线程要协作完成任务,线程 A 可以调用 CountDownLatchawait() 进行等待阻塞。其他线程则做其他业务,当业务执行完成后调用 CountDownLatchcountDown() 减掉倒计时。最后倒计时减到 0 时,阻塞的线程 A 就会被唤醒执行后续的业务。

由于是 CountDownLatch 是基于 AQS 扩展的,因此引用 AQS 模型图可方便我们理解:

图中,state 用于保存倒计时数,Node 节点用于封装等待阻塞的线程。

四、源码解析 {#四、源码解析}

我们先通过案例了解 CountDownLatch 基本使用。

  • 案例

我们将 CountDownLatch 当作餐馆服务员,线程比作客人。当客人来到餐馆吃饭时,餐馆服务员负责记录餐桌、客人吃饭的情况。

|------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 | public class CountDownLatchTest { public static void main(String[] args) throws InterruptedException { // (1) CountDownLatch countDownLatch = new CountDownLatch(5); System.out.println(LocalDateTime.now() + " -> " + Thread.currentThread().getName() + " 上菜"); for (int i = 1; i <= 5; i++) { new Thread(() -> { try { System.out.println(LocalDateTime.now() + " -> " + Thread.currentThread().getName() + " 开始吃饭"); Double time = Math.random() * 3 + 1; TimeUnit.SECONDS.sleep(time.intValue()); System.out.println(LocalDateTime.now() + " -> " + Thread.currentThread().getName() + " 吃饭结束,走人"); // (2) 减去倒计时 countDownLatch.countDown(); } catch (InterruptedException e) { e.printStackTrace(); } }, "t" + i).start(); } // (3) 等待阻塞,当倒计时为 0 就放行 System.out.println(LocalDateTime.now() + " -> " + Thread.currentThread().getName() + " 等待客人结账"); countDownLatch.await(); System.out.println(LocalDateTime.now() + " -> " + Thread.currentThread().getName() + " 客人都走了,开始收摊"); } } |

执行结果:

|---------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | 1 2 3 4 5 6 7 8 9 10 11 12 13 | 2023-03-15T11:40:08.536 -> main 上菜 2023-03-15T11:40:08.542 -> t1 开始吃饭 2023-03-15T11:40:08.542 -> t2 开始吃饭 2023-03-15T11:40:08.542 -> main 等待客人结账 2023-03-15T11:40:08.542 -> t3 开始吃饭 2023-03-15T11:40:08.542 -> t4 开始吃饭 2023-03-15T11:40:08.542 -> t5 开始吃饭 2023-03-15T11:40:09.543 -> t2 吃饭结束,走人 2023-03-15T11:40:10.543 -> t4 吃饭结束,走人 2023-03-15T11:40:11.542 -> t3 吃饭结束,走人 2023-03-15T11:40:11.542 -> t1 吃饭结束,走人 2023-03-15T11:40:11.542 -> t5 吃饭结束,走人 2023-03-15T11:40:11.542 -> main 客人都走了,开始收摊 |

当服务员上菜给客人后,需要等待(await())所有客人吃完饭结账后才能收摊,客人吃完饭需要通知服务员吃完饭结账(countDown())。

  • 源码分析

我们按照例子中的代码执行顺序分析。

首先查看 (1) 处代码,即创建 CountDownLatch 实例,进入构造方法中:

|------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 | public class CountDownLatch { private static final class Sync extends AbstractQueuedSynchronizer { Sync(int count) { setState(count); } int getCount() { return getState(); } // (4) 尝试获取资源 protected int tryAcquireShared(int acquires) { return (getState() == 0) ? 1 : -1; } // (5) 尝试释放资源 protected boolean tryReleaseShared(int releases) { for (;;) { int c = getState(); if (c == 0) return false; int nextc = c-1; if (compareAndSetState(c, nextc)) return nextc == 0; } } } private final Sync sync; public CountDownLatch(int count) { if (count < 0) throw new IllegalArgumentException("count < 0"); this.sync = new Sync(count); } // ...省略... } |

在构造方法内部创建了 Sync 实例,而 Sync 是一个静态的内部类, 它继承 AbstractQueuedSynchronizer 类,因此 Sync 拥有了 AQS 的能力,CountDownLatch 的所有操作都是通过 Sync 实例完成的。

调用构造方法传入的 count 值(倒计时数)被传入到 Sync 的构造方法中,其内部调用 setState(count) 方法,该方法来自 AQS ,被保存到 AQSstate 中。

此时,AQS 的模型图如下:

回到案例代码中,main 线程创建好 CountDownLatch 实例后, 接着执行 for 循环,其方法体中创建新的线程执行其他业务,都是异步操作。我们顺着当前线程直接来到 (3) 处,即 countDownLatch.await(),跳进源码:

|---------------------|---------------------------------------------------------------------------------------------------------------------------------| | 1 2 3 4 5 6 | public class CountDownLatch { public void await() throws InterruptedException { sync.acquireSharedInterruptibly(1); } } |

await() 方法底层通过 Sync 实例调用了 acquireSharedInterruptibly(1) 方法,该方法来自 AQS

|---------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 | public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable { // ...省略... public final void acquireSharedInterruptibly(int arg) throws InterruptedException { if (Thread.interrupted()) throw new InterruptedException(); // (6) if (tryAcquireShared(arg) < 0) // (7) doAcquireSharedInterruptibly(arg); } } |

进入该方法:先判断 main 线程是否被中断,并没有,然后执行 (6) 处代码,即 tryAcquireShared(arg),尝试获取资源权限(判断倒计时是否为 0)。该方法是一个抽象方法,最终通过子类来实现,即上文提到的 Sync 类来实现,跳回 (4) 处:

|---------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 | public class CountDownLatch { private static final class Sync extends AbstractQueuedSynchronizer { Sync(int count) { setState(count); } int getCount() { return getState(); } // (4) protected int tryAcquireShared(int acquires) { return (getState() == 0) ? 1 : -1; } // ...省略... } // ...省略... } |

tryAcquireShared() 方法中判断 state 值(倒计时)是否为 0 ,是则返回 1,否则返回 -1。

从上文案例的执行结果可以看出,main 线程在线程阻塞之后,其他线程才陆续执行完毕,因此 state 值不可能为 0,最终方法返回 -1,然后执行 (7) 处代码,即 doAcquireSharedInterruptibly(arg) 方法:

|---------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 | public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable { // ...省略... private void doAcquireSharedInterruptibly(int arg) throws InterruptedException { // (8) 线程被封装到 Node 节点中 final Node node = addWaiter(Node.SHARED); boolean failed = true; try { for (;;) { // (9) 获取前驱节点 final Node p = node.predecessor(); if (p == head) { // (10) 再一次尝试获取资源 int r = tryAcquireShared(arg); if (r >= 0) { // (11) 设置头结点 setHeadAndPropagate(node, r); p.next = null; // help GC failed = false; return; } } // (12) 获取资源失败,修改前驱节点的 state 状态 if (shouldParkAfterFailedAcquire(p, node) && // (13) 底层调用 LockSupport.lock() 挂起当前线程 parkAndCheckInterrupt()) throw new InterruptedException(); } } finally { if (failed) cancelAcquire(node); } } // ...省略... } |

该方法在 《AQS 源码详解》 文章中详细解说过,源码上已简单注释说明,此处不多赘述。

最终,main 线程执行到 parkAndCheckInterrupt() 方法中被挂起等待。

此时,AQS 的模型图如下:

我们切换到其他线程视角,案例中 t2 线程先执行完业务调用了 countDown() 方法:

|---------------------------|--------------------------------------------------------------------------------------------------------| | 1 2 3 4 5 6 7 8 9 | public class CountDownLatch { // ...省略... public void countDown() { sync.releaseShared(1); } } |

countDown() 方法底层调用 releaseShared(1),该方法来自 AQS

|------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 | public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable { // ...省略... public final boolean releaseShared(int arg) { // (14) if (tryReleaseShared(arg)) { // (15) doReleaseShared(); return true; } return false; } } |

线程 t2 来到 releaseShared(1) 方法中先执行 (14) 处代码,即 tryReleaseShared(arg) 代码,该方法是个抽象方法,通过子类 Sync 来实现:

|---------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 | public class CountDownLatch { private static final class Sync extends AbstractQueuedSynchronizer { Sync(int count) { setState(count); } int getCount() { return getState(); } // (5) 尝试释放资源 protected boolean tryReleaseShared(int releases) { for (;;) { int c = getState(); if (c == 0) return false; int nextc = c-1; if (compareAndSetState(c, nextc)) return nextc == 0; } } // ...省略... } // ...省略... } |

进到 tryReleaseShared(arg) 方法中,开启一个无限循环:

  1. 获取 state 值,当前值为 5。
  2. 判断 state 值,如果 为 0 返回 false,否则计算 state 新值(state 旧值 -1),此时新值为 4。
  3. 通过 CAS 方式将新值赋给 state
  4. 如果 state 新值为 0 返回 true,否则返回 false。

t2 线程执行方法最终返回值为 false,线程也跟着结束。

此时,AQS 的模型图如下:

其他条线程的执行步骤与 t2 线程都一样,我们直接跳到最后的 t5 线程视角。当 t5 线程执行 tryReleaseShared(arg)state 值改为 0 后,方法返回 true,开始执行 (15) 处代码,即 doReleaseShared()

|------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 | public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable { // ...省略... private void doReleaseShared() { for (;;) { Node h = head; // (16) if (h != null && h != tail) { int ws = h.waitStatus; // (17) Node.SIGNAL:-1 if (ws == Node.SIGNAL) { if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0)) continue; // (18) unparkSuccessor(h); } else if (ws == 0 && !compareAndSetWaitStatus(h, 0, Node.PROPAGATE)) continue; } // (19) if (h == head) break; } } // ...省略... } |

该方法用于修改 CLH 队列中头结点的 waitStatus 值以及唤醒头结点的后继节点中的线程。 开启一个无限循环:

  1. 获取 CLH 的头结点
  2. 判断头结点(dummy)是否为空,同时头结点是否与尾节点相同。由 AQS 模型图可知,(16) 处的判断是成立的,随后 t5 线程进到 if 方法体中。
  3. 判断头结点(dummy)的 waitStatus 状态,当前状态值为 -1,(17) 处判断成立,将头结点的 waitStatus 通过 CAS 方式还原为 0。
  4. 修改成功后执行 (18) 处代码,即 unparkSuccessor(h),该方法用于查询头结点的后继节点 node1,并通过 LockSupport.unpark(thread) 唤醒节点中的线程(main 线程)。由于该方法在 《AQS 源码详解》 已讲解,此处不多赘述。
  5. t5 线程最后来到 (19) 处,判断成立退出无限循环。

这样 t5 线程释放锁完毕,结束线程,我们转回被唤醒的 main 线程视角:

|---------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 | public abstract class AbstractQueuedSynchronizer extends AbstractOwnableSynchronizer implements java.io.Serializable { // ...省略... private void doAcquireSharedInterruptibly(int arg) throws InterruptedException { // (8) 线程被封装到 Node 节点中 final Node node = addWaiter(Node.SHARED); boolean failed = true; try { for (;;) { // (9) 获取前驱节点 final Node p = node.predecessor(); if (p == head) { // (10) 再一次尝试获取资源 int r = tryAcquireShared(arg); if (r >= 0) { // (11) 设置头结点 setHeadAndPropagate(node, r); p.next = null; // help GC failed = false; return; } } // (12) 获取资源失败,修改前驱节点的 state 状态 if (shouldParkAfterFailedAcquire(p, node) && // (13) 底层调用 LockSupport.lock() 挂起当前线程 parkAndCheckInterrupt()) throw new InterruptedException(); } } finally { if (failed) cancelAcquire(node); } } // ...省略... } |

main 线程在执行 (13) 处代码被挂起等待,该方法是在一个无限循环中进行的,当 main 线程被 t5 线程唤醒后开始执行下一轮循环任务:

  1. 获取前驱节点,即 dummy 节点,判断是否头结点,由 AQS 模型图可知,判断成立。
  2. 调用 tryAcquireShared(arg),上文已介绍,由于 state 值被减为 0, 最终该方法返回值为 1。
  3. 之后执行 (11) 处代码,即 setHeadAndPropagate(node, r),该方法用于将 node1 节点设置为新的头结点,移除节点中的线程
  4. 旧的头结点与当前节点解除关系

最终, AQS 的模型图如下:

五、参考资料 {#五、参考资料}

CAS 原理新讲

LockSupport 工具介绍

AQS 源码详解

赞(0)
未经允许不得转载:工具盒子 » CountDownLatch 源码详解