在 Kafka broker 内部,当执行一些需要等待的任务时(比如 broker 处理 producer 的消息时,需要等待消息同步到其他副本),会用到 DelayedOperationPurgatory 和 DelayedOperation,大致流程如下图:

 

顶层的类是 DelayedOperationPurgatory,它内部包含 2 个重要的属性:WatcherList 数组和哈希时间轮。WatcherList 可以看作一个 ConcurrentHashMap,用于保存 key 到 DelayedOperation 的映射(准确地说,value 是保存 DelayedOperation 队列的 Watchers)。之所以使用数组(分成 512 个分片,每个分片一把锁),是为了降低锁的粒度;哈希时间轮则是为了在任务到期时能及时地触发执行。

  1 abstract class DelayedOperation(override val delayMs: Long,
  2                                 lockOpt: Option[Lock] = None)
  3   extends TimerTask with Logging {
  4 
  5   private val completed = new AtomicBoolean(false)
  6   // Visible for testing
  7   private[server] val lock: Lock = lockOpt.getOrElse(new ReentrantLock)
  8 
  9   /*
 10    * Force completing the delayed operation, if not already completed.
 11    * This function can be triggered when
 12    *
 13    * 1. The operation has been verified to be completable inside tryComplete()
 14    * 2. The operation has expired and hence needs to be completed right now
 15    *
 16    * Return true iff the operation is completed by the caller: note that
 17    * concurrent threads can try to complete the same operation, but only
 18    * the first thread will succeed in completing the operation and return
 19    * true, others will still return false
 20    */
 21   def forceComplete(): Boolean = {
 22     if (completed.compareAndSet(false, true)) {
 23       // cancel the timeout timer
 24       cancel()
 25       onComplete()
 26       true
 27     } else {
 28       false
 29     }
 30   }
 31 
 32   /**
 33    * Check if the delayed operation is already completed
 34    */
 35   def isCompleted: Boolean = completed.get()
 36 
 37   /**
 38    * Call-back to execute when a delayed operation gets expired and hence forced to complete.
 39    */
 40   def onExpiration(): Unit
 41 
 42   /**
 43    * Process for completing an operation; This function needs to be defined
 44    * in subclasses and will be called exactly once in forceComplete()
 45    */
 46   def onComplete(): Unit
 47 
 48   /**
 49    * Try to complete the delayed operation by first checking if the operation
 50    * can be completed by now. If yes execute the completion logic by calling
 51    * forceComplete() and return true iff forceComplete returns true; otherwise return false
 52    *
 53    * This function needs to be defined in subclasses
 54    */
 55   def tryComplete(): Boolean
 56 
 57   /**
 58    * Thread-safe variant of tryComplete() and call extra function if first tryComplete returns false
 59    * @param f else function to be executed after first tryComplete returns false
 60    * @return result of tryComplete
 61    */
 62   private[server] def safeTryCompleteOrElse(f: => Unit): Boolean = inLock(lock) {
 63     if (tryComplete()) true
 64     else {
 65       f
 66       // last completion check
 67       tryComplete()
 68     }
 69   }
 70 
 71   /**
 72    * Thread-safe variant of tryComplete()
 73    */
 74   private[server] def safeTryComplete(): Boolean = inLock(lock)(tryComplete())
 75 
 76   /*
 77    * run() method defines a task that is executed on timeout
 78    */
 79   override def run(): Unit = {
 80     if (forceComplete())
 81       onExpiration()
 82   }
 83 }
 84 
 85 object DelayedOperationPurgatory {
 86 
 87   private val Shards = 512 // Shard the watcher list to reduce lock contention
 88 
 89   def apply[T <: DelayedOperation](purgatoryName: String,
 90                                    brokerId: Int = 0,
 91                                    purgeInterval: Int = 1000,
 92                                    reaperEnabled: Boolean = true,
 93                                    timerEnabled: Boolean = true): DelayedOperationPurgatory[T] = {
 94     val timer = new SystemTimer(purgatoryName)
 95     new DelayedOperationPurgatory[T](purgatoryName, timer, brokerId, purgeInterval, reaperEnabled, timerEnabled)
 96   }
 97 
 98 }
 99 
/**
 * A helper purgatory class for bookkeeping delayed operations with a timeout, and expiring timed out operations.
 *
 * Operations are indexed by one or more watch keys. The key -> Watchers maps are sharded over
 * `DelayedOperationPurgatory.Shards` WatcherLists, each with its own lock, to reduce lock contention.
 * Operations that cannot be completed immediately are also added to `timeoutTimer` (when
 * `timerEnabled`), so that they are force-completed once they expire.
 *
 * @param purgatoryName name used in the metric tags and in the reaper thread name
 * @param timeoutTimer  timer that expires operations still pending after their delay
 * @param brokerId      broker id, used in the reaper thread name
 * @param purgeInterval threshold on the estimated number of completed-but-still-watched operations
 *                      above which advanceClock() purges the watcher lists
 * @param reaperEnabled whether to start the background reaper thread that drives advanceClock()
 * @param timerEnabled  whether pending operations are added to `timeoutTimer`
 */
final class DelayedOperationPurgatory[T <: DelayedOperation](purgatoryName: String,
                                                             timeoutTimer: Timer,
                                                             brokerId: Int = 0,
                                                             purgeInterval: Int = 1000,
                                                             reaperEnabled: Boolean = true,
                                                             timerEnabled: Boolean = true)
        extends Logging with KafkaMetricsGroup {
  /* a list of operation watching keys */
  private class WatcherList {
    // key -> Watchers; the factory presumably lets getAndMaybePut() create a missing entry on demand
    val watchersByKey = new Pool[Any, Watchers](Some((key: Any) => new Watchers(key)))

    // serializes adding watchers for a key against removing that key from watchersByKey
    val watchersLock = new ReentrantLock()

    /*
     * Return all the current watcher lists,
     * note that the returned watchers may be removed from the list by other threads
     */
    def allWatchers = {
      watchersByKey.values
    }
  }

  private val watcherLists = Array.fill[WatcherList](DelayedOperationPurgatory.Shards)(new WatcherList)

  // Selects the shard for a key. `%` keeps the value in (-Shards, Shards), so abs() is always a valid index.
  private def watcherList(key: Any): WatcherList = {
    watcherLists(Math.abs(key.hashCode() % watcherLists.length))
  }

  // the number of estimated total operations in the purgatory
  private[this] val estimatedTotalOperations = new AtomicInteger(0)

  /* background thread expiring operations that have timed out */
  private val expirationReaper = new ExpiredOperationReaper()

  private val metricsTags = Map("delayedOperation" -> purgatoryName)
  newGauge("PurgatorySize", () => watched, metricsTags)
  newGauge("NumDelayedOperations", () => numDelayed, metricsTags)

  if (reaperEnabled)
    expirationReaper.start()

  /**
   * Check if the operation can be completed, if not watch it based on the given watch keys
   *
   * Note that a delayed operation can be watched on multiple keys. It is possible that
   * an operation is completed after it has been added to the watch list for some, but
   * not all of the keys. In this case, the operation is considered completed and won't
   * be added to the watch list of the remaining keys. The expiration reaper thread will
   * remove this operation from any watcher list in which the operation exists.
   *
   * @param operation the delayed operation to be checked
   * @param watchKeys keys for bookkeeping the operation
   * @return true iff the delayed operations can be completed by the caller
   */
  def tryCompleteElseWatch(operation: T, watchKeys: Seq[Any]): Boolean = {
    assert(watchKeys.nonEmpty, "The watch key list can't be empty")

    // The cost of tryComplete() is typically proportional to the number of keys. Calling tryComplete() for each key is
    // going to be expensive if there are many keys. Instead, we do the check in the following way through safeTryCompleteOrElse().
    // If the operation is not completed, we just add the operation to all keys. Then we call tryComplete() again. At
    // this time, if the operation is still not completed, we are guaranteed that it won't miss any future triggering
    // event since the operation is already on the watcher list for all keys.
    //
    // ==============[story about lock]==============
    // Through safeTryCompleteOrElse(), we hold the operation's lock while adding the operation to watch list and doing
    // the tryComplete() check. This is to avoid a potential deadlock between the callers to tryCompleteElseWatch() and
    // checkAndComplete(). For example, the following deadlock can happen if the lock is only held for the final tryComplete()
    // 1) thread_a holds readlock of stateLock from TransactionStateManager
    // 2) thread_a is executing tryCompleteElseWatch()
    // 3) thread_a adds op to watch list
    // 4) thread_b requires writelock of stateLock from TransactionStateManager (blocked by thread_a)
    // 5) thread_c calls checkAndComplete() and holds lock of op
    // 6) thread_c is waiting for readlock of stateLock to complete op (blocked by thread_b)
    // 7) thread_a is waiting for lock of op to call the final tryComplete() (blocked by thread_c)
    //
    // Note that even with the current approach, deadlocks could still be introduced. For example,
    // 1) thread_a calls tryCompleteElseWatch() and gets lock of op
    // 2) thread_a adds op to watch list
    // 3) thread_a calls op#tryComplete and tries to acquire lock_b
    // 4) thread_b holds lock_b and calls checkAndComplete()
    // 5) thread_b sees op from watch list
    // 6) thread_b needs lock of op
    // To avoid the above scenario, we recommend DelayedOperationPurgatory.checkAndComplete() be called without holding
    // any exclusive lock. Since DelayedOperationPurgatory.checkAndComplete() completes delayed operations asynchronously,
    // holding an exclusive lock to make the call is often unnecessary.
    val completedByCaller = operation.safeTryCompleteOrElse {
      watchKeys.foreach(key => watchForOperation(key, operation))
      // kept even though the assert above guarantees nonEmpty: assertions may be elided at compile time
      if (watchKeys.nonEmpty) estimatedTotalOperations.incrementAndGet()
    }

    // if it cannot be completed by now and hence is watched, add to the expire queue also
    if (!completedByCaller && !operation.isCompleted) {
      if (timerEnabled)
        timeoutTimer.add(operation)
      if (operation.isCompleted) {
        // completed concurrently while being added to the timer: cancel the timer task
        operation.cancel()
      }
    }

    completedByCaller
  }

  /**
   * Check if some delayed operations can be completed with the given watch key,
   * and if yes complete them.
   *
   * @return the number of completed operations during this process
   */
  def checkAndComplete(key: Any): Int = {
    val wl = watcherList(key)
    val watchers = inLock(wl.watchersLock) { wl.watchersByKey.get(key) }
    val numCompleted = if (watchers == null)
      0
    else
      watchers.tryCompleteWatched()
    debug(s"Request key $key unblocked $numCompleted $purgatoryName operations")
    numCompleted
  }

  /**
   * Return the total size of watch lists the purgatory. Since an operation may be watched
   * on multiple lists, and some of its watched entries may still be in the watch lists
   * even when it has been completed, this number may be larger than the number of real operations watched
   */
  def watched: Int = {
    watcherLists.foldLeft(0) { case (sum, watcherList) => sum + watcherList.allWatchers.map(_.countWatched).sum }
  }

  /**
   * Return the number of delayed operations in the expiry queue
   */
  def numDelayed: Int = timeoutTimer.size

  /**
   * Cancel watching on any delayed operations for the given key. Note the operation will not be completed
   *
   * @return the operations that were being watched on the key (empty if the key was not watched)
   */
  def cancelForKey(key: Any): List[T] = {
    val wl = watcherList(key)
    inLock(wl.watchersLock) {
      val watchers = wl.watchersByKey.remove(key)
      if (watchers != null)
        watchers.cancel()
      else
        Nil
    }
  }

  /*
   * Add the operation to the watchers of the given key, creating the watchers if they are absent.
   * The shard's watchersLock is held so that the operation cannot be added to a Watchers that is
   * concurrently being removed from the map (see removeKeyIfEmpty() and cancelForKey()).
   */
  private def watchForOperation(key: Any, operation: T): Unit = {
    val wl = watcherList(key)
    inLock(wl.watchersLock) {
      val watcher = wl.watchersByKey.getAndMaybePut(key)
      watcher.watch(operation)
    }
  }

  /*
   * Remove the key from its watcher list if the given watchers are still the ones mapped to the key
   * and they are empty.
   */
  private def removeKeyIfEmpty(key: Any, watchers: Watchers): Unit = {
    val wl = watcherList(key)
    inLock(wl.watchersLock) {
      // Skip when the key is no longer correlated to these watchers (another thread may have removed
      // the key and a different Watchers may since have been mapped to it). Expressed as a single
      // condition rather than a non-local `return` out of the by-name block passed to inLock.
      if (wl.watchersByKey.get(key) == watchers && watchers != null && watchers.isEmpty)
        wl.watchersByKey.remove(key)
    }
  }

  /**
   * Shutdown the expire reaper thread
   */
  def shutdown(): Unit = {
    if (reaperEnabled)
      expirationReaper.shutdown()
    timeoutTimer.shutdown()
    removeMetric("PurgatorySize", metricsTags)
    removeMetric("NumDelayedOperations", metricsTags)
  }

  /**
   * A linked list of watched delayed operations based on some key
   */
  private class Watchers(val key: Any) {
    private[this] val operations = new ConcurrentLinkedQueue[T]()

    // count the current number of watched operations. This is O(n), so use isEmpty() if possible
    def countWatched: Int = operations.size

    def isEmpty: Boolean = operations.isEmpty

    // add the element to watch
    def watch(t: T): Unit = {
      operations.add(t)
    }

    // traverse the list and try to complete some watched elements
    def tryCompleteWatched(): Int = {
      var completed = 0

      val iter = operations.iterator()
      while (iter.hasNext) {
        val curr = iter.next()
        if (curr.isCompleted) {
          // another thread has completed this operation, just remove it
          iter.remove()
        } else if (curr.safeTryComplete()) {
          iter.remove()
          completed += 1
        }
      }

      if (operations.isEmpty)
        removeKeyIfEmpty(key, this)

      completed
    }

    // remove every watched operation, cancelling its timer task, and return the removed operations
    def cancel(): List[T] = {
      val iter = operations.iterator()
      val cancelled = new ListBuffer[T]()
      while (iter.hasNext) {
        val curr = iter.next()
        curr.cancel()
        iter.remove()
        cancelled += curr
      }
      cancelled.toList
    }

    // traverse the list and purge elements that are already completed by others
    def purgeCompleted(): Int = {
      var purged = 0

      val iter = operations.iterator()
      while (iter.hasNext) {
        val curr = iter.next()
        if (curr.isCompleted) {
          iter.remove()
          purged += 1
        }
      }

      if (operations.isEmpty)
        removeKeyIfEmpty(key, this)

      purged
    }
  }

  /**
   * Advance the timeout timer (see Timer.advanceClock) and, when the estimated number of completed
   * operations still sitting in watcher lists exceeds purgeInterval, purge them from all shards.
   *
   * @param timeoutMs timeout passed through to timeoutTimer.advanceClock()
   */
  def advanceClock(timeoutMs: Long): Unit = {
    timeoutTimer.advanceClock(timeoutMs)

    // Trigger a purge if the number of completed but still being watched operations is larger than
    // the purge threshold. That number is computed by the difference btw the estimated total number of
    // operations and the number of pending delayed operations.
    if (estimatedTotalOperations.get - numDelayed > purgeInterval) {
      // now set estimatedTotalOperations to delayed (the number of pending operations) since we are going to
      // clean up watchers. Note that, if more operations are completed during the clean up, we may end up with
      // a little overestimated total number of operations.
      estimatedTotalOperations.set(numDelayed)
      debug("Begin purging watch lists")
      val purged = watcherLists.foldLeft(0) {
        case (sum, watcherList) => sum + watcherList.allWatchers.map(_.purgeCompleted()).sum
      }
      debug(s"Purged $purged elements from watch lists.")
    }
  }

  /**
   * A background reaper to expire delayed operations that have timed out.
   * Each doWork() call advances the clock with a 200 ms timeout.
   */
  private class ExpiredOperationReaper extends ShutdownableThread(
    "ExpirationReaper-%d-%s".format(brokerId, purgatoryName),
    false) {

    override def doWork(): Unit = {
      advanceClock(200L)
    }
  }
}
View Code

哈希时间轮的实现类是 SystemTimer。值得注意的是,它是一个多层(分层)的时间轮;并且时间的推进是通过从 DelayQueue 中取出已到期的 bucket 来驱动的,这样在没有到期任务时不会做无效的推进(空转)。

  1 package kafka.utils.timer
  2 
  3 import java.util.concurrent.atomic.AtomicInteger
  4 import java.util.concurrent.locks.ReentrantReadWriteLock
  5 import java.util.concurrent.{DelayQueue, Executors, TimeUnit}
  6 
  7 import kafka.utils.threadsafe
  8 import org.apache.kafka.common.utils.{KafkaThread, Time}
  9 
 10 trait Timer {
 11   /**
 12     * Add a new task to this executor. It will be executed after the task's delay
 13     * (beginning from the time of submission)
 14     * @param timerTask the task to add
 15     */
 16   def add(timerTask: TimerTask): Unit
 17 
 18   /**
 19     * Advance the internal clock, executing any tasks whose expiration has been
 20     * reached within the duration of the passed timeout.
 21     * @param timeoutMs
 22     * @return whether or not any tasks were executed
 23     */
 24   def advanceClock(timeoutMs: Long): Boolean
 25 
 26   /**
 27     * Get the number of tasks pending execution
 28     * @return the number of tasks
 29     */
 30   def size: Int
 31 
 32   /**
 33     * Shutdown the timer service, leaving pending tasks unexecuted
 34     */
 35   def shutdown(): Unit
 36 }
 37 
 38 @threadsafe
 39 class SystemTimer(executorName: String,
 40                   tickMs: Long = 1,
 41                   wheelSize: Int = 20,
 42                   startMs: Long = Time.SYSTEM.hiResClockMs) extends Timer {
 43 
 44   // timeout timer
 45   private[this] val taskExecutor = Executors.newFixedThreadPool(1,
 46     (runnable: Runnable) => KafkaThread.nonDaemon("executor-" + executorName, runnable))
 47 
 48   private[this] val delayQueue = new DelayQueue[TimerTaskList]()
 49   private[this] val taskCounter = new AtomicInteger(0)
 50   private[this] val timingWheel = new TimingWheel(
 51     tickMs = tickMs,
 52     wheelSize = wheelSize,
 53     startMs = startMs,
 54     taskCounter = taskCounter,
 55     delayQueue
 56   )
 57 
 58   // Locks used to protect data structures while ticking
 59   private[this] val readWriteLock = new ReentrantReadWriteLock()
 60   private[this] val readLock = readWriteLock.readLock()
 61   private[this] val writeLock = readWriteLock.writeLock()
 62 
 63   def add(timerTask: TimerTask): Unit = {
 64     readLock.lock()
 65     try {
 66       addTimerTaskEntry(new TimerTaskEntry(timerTask, timerTask.delayMs + Time.SYSTEM.hiResClockMs))
 67     } finally {
 68       readLock.unlock()
 69     }
 70   }
 71 
 72   private def addTimerTaskEntry(timerTaskEntry: TimerTaskEntry): Unit = {
 73     if (!timingWheel.add(timerTaskEntry)) {
 74       // Already expired or cancelled
 75       if (!timerTaskEntry.cancelled)
 76         taskExecutor.submit(timerTaskEntry.timerTask)
 77     }
 78   }
 79 
 80   /*
 81    * Advances the clock if there is an expired bucket. If there isn't any expired bucket when called,
 82    * waits up to timeoutMs before giving up.
 83    */
 84   def advanceClock(timeoutMs: Long): Boolean = {
 85     var bucket = delayQueue.poll(timeoutMs, TimeUnit.MILLISECONDS)
 86     if (bucket != null) {
 87       writeLock.lock()
 88       try {
 89         while (bucket != null) {
 90           timingWheel.advanceClock(bucket.getExpiration)
 91           bucket.flush(addTimerTaskEntry)
 92           bucket = delayQueue.poll()
 93         }
 94       } finally {
 95         writeLock.unlock()
 96       }
 97       true
 98     } else {
 99       false
100     }
101   }
102 
103   def size: Int = taskCounter.get
104 
105   override def shutdown(): Unit = {
106     taskExecutor.shutdown()
107   }
108 
109 }
View Code

在 DelayedOperationPurgatory 中,推动时间轮转动的线程是 ExpiredOperationReaper。

 

DelayedOperationPurgatory 的内部数据结构如下图所示:一个(按 key 分片的)哈希 map,以及一个时间轮。

 

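 这里面有个冗余的点:假定 hash(key) % 512 = 25,则 hash(key) = 512*n + 25。由于 512 是 16 的倍数,同一个分片里所有 key 都满足 hash(key) % 16 = 9。如果分片内部的 map 直接用 hashCode 的低位定位桶,那么这个分片里的 key 就会全部落进同一个桶,map 也就退化成了 list。不过需要补充一点:Java 8 的 ConcurrentHashMap 在定位桶之前,会用 spread() 把 hashCode 的高 16 位异或到低位上,并且桶内冲突过多时会转成红黑树,所以实际影响通常比上面的推算要小。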

posted on 2023-09-20 13:25  偶尔发呆  阅读(21)  评论(0编辑  收藏  举报