CountDownLatch、Semaphore的实现原理AQS

2021/08/07

1、CountDownLatch

使用举例

在java.util.councurrent包,我们常使用CountDownLatch来阻塞主线程,等待子线程都执行完毕,才继续往下执行首先new 一个指定数量计数器的CountDownLatch,主线程执行,调用work子线程,主线程接着调用 CountDownLatch 的 await() 方法进行阻塞,当子线程执行完后,执行countDown方法,把计数器减1,直到减为0,主线程才开始执行。

public class CountDownLatchDemo {
  public static void main(String[] args) throws InterruptedException {
    // CountDownLatch 减法计数器
    // 倒计时 6 -> 0执行
    CountDownLatch countDownLatch = new CountDownLatch(6);

    for (int i = 1; i <= 6; i++) {
      new Thread(()->{
        System.out.println(Thread.currentThread().getName() + " Start");
        countDownLatch.countDown(); // 计数器 -1
      },String.valueOf(i)).start();
    }
    // 只要计数器没有归零,我们就一直阻塞
    countDownLatch.await(); 
    // 目标:等待上面的6个线程跑完再执行,主线程才执行
    System.out.println(Thread.currentThread().getName() + " End");
  }
}

执行结果:

AQS实现

CountDownLatch的底层通过AQS实现,ReentrantLock、Semaphore 也是基于AQS实现的,AQS的一般使用方式就是以内部类的形式继承它

1、创建CountDownLatch对象

下面以CountDownLatch源码的角度分析AQS

构造方法

public CountDownLatch(int count) {
  if (count < 0) throw new IllegalArgumentException("count < 0");
  this.sync = new Sync(count);
}

调用它的内部类Sync,继承AbstractQueuedSynchronizer

/**
     * Synchronization control For CountDownLatch.
     * Uses AQS state to represent count.
     使用AQS state 来表示计数
     */
private static final class Sync extends AbstractQueuedSynchronizer {
  private static final long serialVersionUID = 4982264981922014374L;
  
  Sync(int count) {
    setState(count);
  }

  int getCount() {
    return getState();
  }

  /**
   若count为0,返回1,表示获取锁成功,此时线程将不会阻塞,正常运行
   若count不为0,则返回-1,表示获取锁失败,线程将会被阻塞
  **/
  protected int tryAcquireShared(int acquires) {
    return (getState() == 0) ? 1 : -1;
  }

  protected boolean tryReleaseShared(int releases) {
    // Decrement count; signal when transition to zero
    for (;;) {
      int c = getState();
      if (c == 0)
        return false;
      int nextc = c-1;
      if (compareAndSetState(c, nextc))
        return nextc == 0;
    }
  }
}

当我们创建CountDownLatch对象时,内部类Sync调用setState方法,它的计数器就是AQS的state变量,一个volatile变量,保证了可见性

2、主线程调用await方法,等待线程入队

public void await() throws InterruptedException {
  sync.acquireSharedInterruptibly(1);
}

主线程调用CountDownLatch.await()方法阻塞自己,它的原理是尝试获取共享锁,若获取失败,则线程将会被加入到AQS的同步队列中等待,直到获取成功为止。且这个方法是会响应中断的,线程在阻塞的过程中,若被其他线程中断,则此方法会通过抛出异常的方式结束等待。

点击方法acquireSharedInterruptibly

public final void acquireSharedInterruptibly(int arg)
  throws InterruptedException {
  if (Thread.interrupted())
    throw new InterruptedException();
   // 调用tryAcquireShared方法尝试获取锁,这个方法被Sycn类重写
  if (tryAcquireShared(arg) < 0)
    doAcquireSharedInterruptibly(arg);

方法 tryAcquireShared 被CountDownLatch内部类Sync重写,

protected int tryAcquireShared(int acquires) {
   // 若count为0,返回1,表示获取锁成功,此时线程将不会阻塞,正常运行
   // 若count不为0,则返回-1,表示获取锁失败,线程将会被阻塞
  return (getState() == 0) ? 1 : -1;
}

获取锁失败就调用AQS的方法doAcquireSharedInterruptibly,把当前线程加入AQS的同步队列中阻塞等待,直到成功获取锁,即count==0。

private void doAcquireSharedInterruptibly(int arg)
  throws InterruptedException {
  // 使用共享模式创建当前线程的节点并加入等待队列,然后返回,这里结合上面的举例,主线程就是要加入的等待节点
  final Node node = addWaiter(Node.SHARED);
  boolean failed = true;
  // 进入CAS循环,等待
  try {
    for (;;) {
      final Node p = node.predecessor(); // 获取前一个节点
      if (p == head) {
        // 如果前一个节点是头节点,尝试获取共享锁,即判断state是否0
        int r = tryAcquireShared(arg);
        if (r >= 0) {
          // state==0,说明没有子线程需要等待了,则将当前等待节点设为head头节点,并释放锁
          setHeadAndPropagate(node, r);
          p.next = null; // help GC
          failed = false;
          return;
        }
      }
      // state还不为0,就会到这里,
      // 第一次的时候,waitStatus是0,那么node的waitStatus就会被置为SIGNAL;
      // 第二次再走到这里,就会用LockSupport的park方法把当前线程阻塞住
      if (shouldParkAfterFailedAcquire(p, node) && parkAndCheckInterrupt())
        throw new InterruptedException();
    }
  } finally {
    if (failed)
      cancelAcquire(node);
  }
}

这是AQS的核心部分,用内部的一个 Node 类维护一个 CHL Node FIFO 队列,将当前线程加入等待队列,并通过 parkAndCheckInterrupt()方法实现当前线程的阻塞。

首先执行addWaiter方法,创建一个Node节点,加入等待队列的尾部,看他的源码

/**
     * Creates and enqueues node for current thread and given mode.
     *
     * @param mode Node.EXCLUSIVE for exclusive, Node.SHARED for shared
     独占模式和共享模式
     * @return the new node
     */
private Node addWaiter(Node mode) {
  Node node = new Node(Thread.currentThread(), mode);
  // Try the fast path of enq; backup to full enq on failure
  Node pred = tail;
  if (pred != null) {
    // 尾节点不为空,CAS替换新建的Node节点为新的尾节点
    node.prev = pred;
    // CAS比较更新尾节点
    if (compareAndSetTail(pred, node)) {
      pred.next = node;
      return node;
    }
  }
  // 尾节点为空,就尝试CAS入队
  enq(node);
  return node;
}

点进方法compareAndSetTail,看见是调用unsafe类提供的本地方法(带有native关键字),它的底层是C++方法,直接操作内存,在cpu层面加锁,Java是隔了一层JVM,不能操作内存。

/**
     * CAS tail field. Used only by enq.
     */
private final boolean compareAndSetTail(Node expect, Node update) {
  return unsafe.compareAndSwapObject(this, tailOffset, expect, update);
}

CAS机制就是compare and swap 也称 自旋锁,不断比较当前线程栈中的变量值(也就是期望值expect)与共享内存中的变量值是否一致,如果是则将共享内存中的变量值更新为update值,否则就把线程栈中变量直接赋值为共享内存中变量值一致,再重新判断,这就是所谓循环判断自己,来保证线程栈的update值能写入到共享内存中,保证读取与写入的原子性,所有原子类都是基于CAS机制实现的 原子操作,来保证原子类的值修改的原子性,即当前线程修改共享变量不收其他线程的干扰。

点进方法enq(node)

/**
     * Inserts node into queue, initializing if necessary. See picture above.
     * @param node the node to insert 代表当前线程的等待节点
     * @return node's predecessor
     */
private Node enq(final Node node) {
  // 死循环+CAS保证所有节点都入队
  for (;;) {
    Node t = tail;
    if (t == null) { // Must initialize
      // 头节点为空,就初始化等待队列,这里是new Node(),也就是AQS 默认要有一个虚拟节点,
      // 循环继续,进入else 当前等待节点node加入队尾。
      if (compareAndSetHead(new Node()))
        tail = head;
    } else {
      node.prev = t;
      // CAS 替换将等待节点加入队尾
      if (compareAndSetTail(t, node)) {
        t.next = node; 
        return t;
      }
    }
  }
}

继续回到上面的方法doAcquireSharedInterruptibly,当前线程的节点加入队列后,进行CAS循环阻塞,如果state=0了就会释放共享锁,执行方法setHeadAndPropagate

/**
 * Sets head of queue, and checks if successor may be waiting
 * in shared mode, if so propagating if either propagate > 0 or
 * PROPAGATE status was set.
 *
 * @param node the node
 * @param propagate the return value from a tryAcquireShared
 */
private void setHeadAndPropagate(Node node, int propagate) {
    Node h = head; // Record old head for check below
    setHead(node);
    /*
     * Try to signal next queued node if:
     *   Propagation was indicated by caller,
     *     or was recorded (as h.waitStatus either before
     *     or after setHead) by a previous operation
     *     (note: this uses sign-check of waitStatus because
     *      PROPAGATE status may transition to SIGNAL.)
     * and
     *   The next node is waiting in shared mode,
     *     or we don't know, because it appears null
     *
     * The conservatism in both of these checks may cause
     * unnecessary wake-ups, but only when there are multiple
     * racing acquires/releases, so most need signals now or soon
     * anyway.
     */
    if (propagate > 0 || h == null || h.waitStatus < 0 ||
        (h = head) == null || h.waitStatus < 0) {
        Node s = node.next;
        if (s == null || s.isShared())  // 当前阻塞节点的下一个节点为空,或者是共享模式的节点,释放所以等待线程
            doReleaseShared();
    }
}

这里的参数node是当前阻塞等待的主线程节点,propagate是1

Node 对象中有一个属性是 waitStatus ,它有四种状态,分别是:

//线程已被 cancelled ,这种状态的节点将会被忽略,并移出队列
static final int CANCELLED =  1;
// 表示当前线程已被挂起,并且后继节点可以尝试抢占锁
static final int SIGNAL    = -1;
//线程正在等待某些条件
static final int CONDITION = -2;
//共享模式下 无条件所有等待线程尝试抢占
static final int PROPAGATE = -3;

如果获取共享锁失败即state!=0,走到方法shouldParkAfterFailedAcquire(p, node)

    /**
     * Checks and updates status for a node that failed to acquire.
     * Returns true if thread should block. This is the main signal
     * control in all acquire loops.  Requires that pred == node.prev.
     *
     * @param pred node's predecessor holding status
     * @param node the node
     * @return {@code true} if thread should block
     */
    private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
        int ws = pred.waitStatus;
        if (ws == Node.SIGNAL)
            /*
             * This node has already set status asking a release
             * to signal it, so it can safely park.
             */
            return true;
        if (ws > 0) {
            /*
             * Predecessor was cancelled. Skip over predecessors and
             * indicate retry.线程已被cancelled,
             */
            do {
                node.prev = pred = pred.prev;
            } while (pred.waitStatus > 0);
            pred.next = node;
        } else {
            /*
             * waitStatus must be 0 or PROPAGATE.  Indicate that we
             * need a signal, but don't park yet.  Caller will need to
             * retry to make sure it cannot acquire before parking.
             CAS 自旋转更新上一节点为挂起状态
             */
            compareAndSetWaitStatus(pred, ws, Node.SIGNAL);
        }
        return false;
    }

第1次循环把上节点更新为挂起状态Node.SIGNAL=-1, 第2次循环则进入方法parkAndCheckInterrupt()

/**
     * Convenience method to park and then check if interrupted
     *
     * @return {@code true} if interrupted
     */
private final boolean parkAndCheckInterrupt() {
  LockSupport.park(this); // 当前CountDownLatch对象的成员变量Sync对象
  return Thread.interrupted();
}

点进LockSupport的park方法

public static void park(Object blocker) {
  Thread t = Thread.currentThread();
  setBlocker(t, blocker);
  UNSAFE.park(false, 0L);
  setBlocker(t, null);
}

private static void setBlocker(Thread t, Object arg) {
  // Even though volatile, hotspot doesn't need a write barrier here.
  UNSAFE.putObject(t, parkBlockerOffset, arg);
}

发现阻塞是调用unsafe类的本地方法putObject(),直接操作内存线程进行阻塞

3、子线程调用countDown方法,等待线程被唤醒

当执行 CountDownLatch 的 countDown()方法,将计数器减一,也就是state减一,当减到0的时候,等待队列中的线程被释放

/**
     * Decrements the count of the latch, releasing all waiting threads if
     * the count reaches zero.
     *
     * <p>If the current count is greater than zero then it is decremented.
     * If the new count is zero then all waiting threads are re-enabled for
     * thread scheduling purposes.
     *
     * <p>If the current count equals zero then nothing happens.
     */
public void countDown() {
  sync.releaseShared(1);
}

调用AQS的releaseShared方法

public final boolean releaseShared(int arg) {
  if (tryReleaseShared(arg)) {
    doReleaseShared();
    return true;
  }
  return false;
}

调用CountDownLatch的内部类Sync重写的tryReleaseShared(arg)方法

protected boolean tryReleaseShared(int releases) {
  // Decrement count; signal when transition to zero
  // 循环+CAS判断,实现计数器减1
  for (;;) {
    int c = getState();
    if (c == 0)
      // 已经释放共享锁了,返回false
      return false;
    int nextc = c-1;
    if (compareAndSetState(c, nextc))
      return nextc == 0; // 返回当前state==0,如果为true,则会执行doReleaseShared方法
  }
}

当state为0,执行doReleaseShared()唤醒被阻塞的线程

  /**
     * Release action for shared mode -- signals successor and ensures
     * propagation. (Note: For exclusive mode, release just amounts
     * to calling unparkSuccessor of head if it needs signal.)
     */
    private void doReleaseShared() {
        /*
         * Ensure that a release propagates, even if there are other
         * in-progress acquires/releases.  This proceeds in the usual
         * way of trying to unparkSuccessor of head if it needs
         * signal. But if it does not, status is set to PROPAGATE to
         * ensure that upon release, propagation continues.
         * Additionally, we must loop in case a new node is added
         * while we are doing this. Also, unlike other uses of
         * unparkSuccessor, we need to know if CAS to reset status
         * fails, if so rechecking.
         */
        for (;;) {
            Node h = head;
            if (h != null && h != tail) {
                int ws = h.waitStatus;
               // 如果节点状态为SIGNAL,则他的next节点也可以尝试被唤醒
                if (ws == Node.SIGNAL) {
                    if (!compareAndSetWaitStatus(h, Node.SIGNAL, 0))
                        continue;            // loop to recheck cases
                    unparkSuccessor(h); // 从头节点开始唤醒所有放入阻塞队列的线程
                }
              // 将节点状态设置为PROPAGATE,表示要向下传播,依次唤醒
                else if (ws == 0 &&
                         !compareAndSetWaitStatus(h, 0, Node.PROPAGATE))
                    continue;                // loop on failed CAS
            }
            if (h == head)                   // loop if head changed
                break;
        }
    }

点击唤醒方法unparkSuccessor

    /**
     * Wakes up node's successor, if one exists.
     *
     * @param node the node
     */
    private void unparkSuccessor(Node node) {
        /*
         * If status is negative (i.e., possibly needing signal) try
         * to clear in anticipation of signalling.  It is OK if this
         * fails or if status is changed by waiting thread.
         */
        int ws = node.waitStatus;
        if (ws < 0)
            compareAndSetWaitStatus(node, ws, 0);

        /*
         * Thread to unpark is held in successor, which is normally
         * just the next node.  But if cancelled or apparently null,
         * traverse backwards from tail to find the actual
         * non-cancelled successor.
         */
        Node s = node.next;
        if (s == null || s.waitStatus > 0) {
            s = null;
            for (Node t = tail; t != null && t != node; t = t.prev)
                if (t.waitStatus <= 0)
                    s = t;
        }
        if (s != null)
            LockSupport.unpark(s.thread); // 调用本地方法唤醒被阻塞的线程
    }

因为这是共享型的,当计数器为 0 后,会唤醒等待队列里的所有线程,所有调用了 await() 方法的线程都被唤醒,并发执行。这种情况对应到的场景是,有多个线程需要等待一些动作完成,比如一个线程完成初始化动作,其他5个线程都需要用到初始化的结果,那么在初始化线程调用 countDown 之前,其他5个线程都处在等待状态。一旦初始化线程调用了 countDown ,其他5个线程都被唤醒,开始执行。

总结

1、AQS 分为独占模式和共享模式,CountDownLatch 使用了它的共享模式。

2、AQS 当第一个等待线程(被包装为 Node)要入队的时候,要保证存在一个 head 节点,这个 head 节点不关联线程,也就是一个虚节点。

3、当队列中的等待节点(关联线程的,非 head 节点)抢到锁,将这个节点设置为 head 节点。

4、第一次自旋抢锁失败后,waitStatus 会被设置为 -1(SIGNAL),第二次再失败,就会被 LockSupport 阻塞挂起。

5、如果一个节点的前置节点为 SIGNAL 状态,则这个节点可以尝试抢占锁。

Post Directory