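Understanding Spark 2.1 Core in Depth (11): Principles and Source Code Analysis of the Shuffle Reduce Side
In "Understanding Spark 2.1 Core in Depth (1): Principles and Source Code Analysis of RDD", we explained:
To achieve fault tolerance efficiently, RDDs provide a highly restricted form of shared memory: an RDD is read-only and can only be created through bulk operations on other RDDs (note: it can also be created from a dataset in an external storage system, such as HDFS).
It follows that what we covered in the ninth and tenth posts resembles traditional Hadoop MapReduce: the initial process of reading data from HDFS to build a HadoopRDD. Because an RDD can be created through bulk operations on other RDDs, this HadoopRDD can be regarded as the Map side for the ShuffledRDD generated next, and that ShuffledRDD can in turn be regarded as the Map side by the ShuffledRDD after it. Conversely, from the HadoopRDD's point of view, the next ShuffledRDD is the Reduce side.
In this post we look at the Reduce side of Shuffle. The RDD iteration part is similar to the ninth post; the difference is that here the call goes to rdd.ShuffledRDD.compute: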
    override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
      // Get the (single) shuffle dependency
      val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
      // Call getReader with dep.shuffleHandle, the partition range and the task context,
      // then call read() on the returned reader to obtain the iterator
      SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
        .read()
        .asInstanceOf[Iterator[(K, C)]]
    }
The getReader called here is that of shuffle.sort.SortShuffleManager:
    override def getReader[K, C](
        handle: ShuffleHandle,
        startPartition: Int,
        endPartition: Int,
        context: TaskContext): ShuffleReader[K, C] = {
      // Create and return a BlockStoreShuffleReader
      new BlockStoreShuffleReader(
        handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
    }
shuffle.BlockStoreShuffleReader.read:
    override def read(): Iterator[Product2[K, C]] = {
      // Instantiate the ShuffleBlockFetcherIterator
      val blockFetcherItr = new ShuffleBlockFetcherIterator(
        context,
        blockManager.shuffleClient,
        blockManager,
        // Ask (via an RPC message if not cached) for the metadata describing
        // where the ShuffleMapTasks stored their output
        mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
        // Maximum number of bytes that may be in flight (requested but not yet received) at once
        SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
        // Maximum number of fetch requests that may be in flight at once
        SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))

      // Wrap the streams for compression and encryption based on configuration
      val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
        serializerManager.wrapStream(blockId, inputStream)
      }

      val serializerInstance = dep.serializer.newInstance()

      // Create a key/value iterator for each stream
      val recordIter = wrappedStreams.flatMap { wrappedStream =>
        serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
      }

      // Update the task metrics for each record read
      val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
      // Build an iterator that merges the read metrics once it is exhausted
      val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
        recordIter.map { record =>
          readMetrics.incRecordsRead(1)
          record
        },
        context.taskMetrics().mergeShuffleReadMetrics())

      // Wrap metricIter in an interruptible iterator so that the task can be cancelled
      val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

      val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
        // The data needs to be aggregated
        if (dep.mapSideCombine) {
          // The Map side has already combined values, so the records are (K, C) pairs:
          // merge the combiners
          val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
          dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
        } else {
          // No Map-side combine: the records are raw values, combined only on the Reduce side
          val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
          dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
        }
      } else {
        require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
        interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
      }

      dep.keyOrdering match {
        case Some(keyOrd: Ordering[K]) =>
          // Sorting is required.
          // If spark.shuffle.spill is set to false, the sorter will not spill to disk
          val sorter =
            new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
          sorter.insertAll(aggregatedIter)
          context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
          context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
          context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
          CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
        case None =>
          aggregatedIter
      }
    }
[Figure: class call relationship diagram]
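The two aggregation branches in read() are easy to confuse, so here is a minimal sketch of the difference using plain Scala collections. It is not Spark's Aggregator: `SimpleAggregator` and `AggregationSketch` are hypothetical names, everything stays in memory, and nothing spills. Only the three function fields mirror the real aggregator.

```scala
// Hypothetical stand-in for Spark's Aggregator: same three functions,
// but a purely in-memory implementation for illustration only.
final case class SimpleAggregator[K, V, C](
    createCombiner: V => C,
    mergeValue: (C, V) => C,
    mergeCombiners: (C, C) => C) {

  // Used when the map side did NOT combine: records are raw (K, V) pairs.
  def combineValuesByKey(iter: Iterator[(K, V)]): Map[K, C] =
    iter.foldLeft(Map.empty[K, C]) { case (acc, (k, v)) =>
      val combined = acc.get(k) match {
        case Some(existing) => mergeValue(existing, v)
        case None           => createCombiner(v)
      }
      acc.updated(k, combined)
    }

  // Used when the map side already combined: records are (K, C) pairs.
  def combineCombinersByKey(iter: Iterator[(K, C)]): Map[K, C] =
    iter.foldLeft(Map.empty[K, C]) { case (acc, (k, c)) =>
      val combined = acc.get(k) match {
        case Some(existing) => mergeCombiners(existing, c)
        case None           => c
      }
      acc.updated(k, combined)
    }
}

object AggregationSketch {
  // Word-count style aggregator: V is a single count, C is a running sum.
  val sum: SimpleAggregator[String, Int, Int] =
    SimpleAggregator[String, Int, Int](v => v, (c, v) => c + v, (a, b) => a + b)

  // Raw records, as a reducer sees them without map-side combine.
  def fromRawValues: Map[String, Int] =
    sum.combineValuesByKey(Iterator("a" -> 1, "b" -> 1, "a" -> 1))

  // Partial sums, as a reducer sees them with map-side combine.
  def fromPartialSums: Map[String, Int] =
    sum.combineCombinersByKey(Iterator("a" -> 2, "b" -> 1, "a" -> 3))
}
```

With map-side combine the reducer receives fewer, pre-summed records per key, which is why that branch only needs mergeCombiners.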
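Next, let us look more closely at how ShuffleBlockFetcherIterator is instantiated:
    // Instantiate the ShuffleBlockFetcherIterator
    val blockFetcherItr = new ShuffleBlockFetcherIterator(
      context,
      blockManager.shuffleClient,
      blockManager,
      // Ask (via an RPC message if not cached) for the metadata describing
      // where the ShuffleMapTasks stored their output
      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
      // Maximum number of bytes that may be in flight (requested but not yet received) at once
      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
      // Maximum number of fetch requests that may be in flight at once
      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))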
Fetching the metadata
mapOutputTracker.getMapSizesByExecutorId
First, mapOutputTracker.getMapSizesByExecutorId is called:
    def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
        : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
      logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
      // Get the metadata
      val statuses = getStatuses(shuffleId)
      // The return format is:
      // Seq[(BlockManagerId, Seq[(shuffle block id, shuffle block size)])]
      statuses.synchronized {
        return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
      }
    }
mapOutputTracker.getStatuses
    private def getStatuses(shuffleId: Int): Array[MapStatus] = {
      // Try the local cache first
      val statuses = mapStatuses.get(shuffleId).orNull
      if (statuses == null) {
        // Nothing cached locally
        logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
        val startTime = System.currentTimeMillis
        var fetchedStatuses: Array[MapStatus] = null
        fetching.synchronized {
          // If another thread is already fetching this shuffle's statuses remotely, wait for it
          while (fetching.contains(shuffleId)) {
            try {
              fetching.wait()
            } catch {
              case e: InterruptedException =>
            }
          }

          // Try the local cache again; the other thread may have filled it
          fetchedStatuses = mapStatuses.get(shuffleId).orNull
          if (fetchedStatuses == null) {
            // We still have to fetch remotely, so register shuffleId in fetching
            fetching += shuffleId
          }
        }

        if (fetchedStatuses == null) {
          logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
          try {
            // Fetch remotely
            val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
            // Deserialize
            fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
            logInfo("Got the output locations")
            // Store the result in mapStatuses
            mapStatuses.put(shuffleId, fetchedStatuses)
          } finally {
            fetching.synchronized {
              fetching -= shuffleId
              fetching.notifyAll()
            }
          }
        }
        logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
          s"${System.currentTimeMillis - startTime} ms")

        if (fetchedStatuses != null) {
          // The statuses were obtained (by this thread or by the one we waited for): return them
          return fetchedStatuses
        } else {
          logError("Missing all output locations for shuffle " + shuffleId)
          throw new MetadataFetchFailedException(
            shuffleId, -1, "Missing all output locations for shuffle " + shuffleId)
        }
      } else {
        // Found in the local cache: return directly
        return statuses
      }
    }
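The locking in getStatuses ensures that only one thread per shuffle performs the remote fetch while the others wait and reuse its result. The sketch below isolates that pattern under simplifying assumptions: `SingleFlightCache`, `SingleFlightDemo` and the fake `fetchRemote` function are hypothetical, and unlike the Spark code this version lets an InterruptedException propagate instead of swallowing it.

```scala
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable

// Hypothetical cache mirroring the wait/notifyAll pattern of getStatuses.
final class SingleFlightCache[V <: AnyRef](fetchRemote: Int => V) {
  private val cache = new ConcurrentHashMap[Int, V]()
  private val fetching = mutable.HashSet.empty[Int]

  def get(key: Int): V = Option(cache.get(key)).getOrElse {
    var result: Option[V] = None
    fetching.synchronized {
      // Another thread is already fetching this key: wait until it finishes.
      while (fetching.contains(key)) {
        fetching.wait()
      }
      // Re-check the cache; the other thread may have filled it.
      result = Option(cache.get(key))
      // Still missing: this thread becomes the fetcher.
      if (result.isEmpty) fetching += key
    }
    result.getOrElse {
      try {
        val fetched = fetchRemote(key)
        cache.put(key, fetched)
        fetched
      } finally {
        // Always release the key and wake up the waiters, even on failure.
        fetching.synchronized {
          fetching -= key
          fetching.notifyAll()
        }
      }
    }
  }
}

object SingleFlightDemo {
  // Runs `threads` concurrent lookups of the same key and returns how many
  // times the slow (simulated) remote fetch actually ran.
  def remoteCalls(threads: Int): Int = {
    val calls = new AtomicInteger(0)
    val cache = new SingleFlightCache[String]((id: Int) => {
      calls.incrementAndGet()
      Thread.sleep(50) // simulate the RPC round trip
      s"statuses-for-shuffle-$id"
    })
    val workers = (1 to threads).map { _ =>
      new Thread(new Runnable { def run(): Unit = { cache.get(0); () } })
    }
    workers.foreach(_.start())
    workers.foreach(_.join())
    calls.get()
  }
}
```

Because the result is stored in the cache before the key is removed from `fetching`, every waiter that wakes up finds the value on its second cache check.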
mapOutputTracker.askTracker
This sends the message GetMapOutputStatuses(shuffleId) to trackerEndpoint:
    protected def askTracker[T: ClassTag](message: Any): T = {
      try {
        trackerEndpoint.askWithRetry[T](message)
      } catch {
        case e: Exception =>
          logError("Error communicating with MapOutputTracker", e)
          throw new SparkException("Error communicating with MapOutputTracker", e)
      }
    }
MapOutputTrackerMasterEndpoint.receiveAndReply
    case GetMapOutputStatuses(shuffleId: Int) =>
      val hostPort = context.senderAddress.hostPort
      logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort)
      val mapOutputStatuses = tracker.post(new GetMapOutputMessage(shuffleId, context))
As we can see, the endpoint does not reply directly here; instead it calls tracker.post:
    def post(message: GetMapOutputMessage): Unit = {
      mapOutputRequests.offer(message)
    }
This adds the GetMapOutputMessage(shuffleId, context) message to mapOutputRequests, which is a LinkedBlockingQueue (a blocking queue backed by a linked list):
    private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage]
MapOutputTrackerMaster.MessageLoop.run
MessageLoop runs on a thread that keeps trying to read requests from mapOutputRequests:
    private class MessageLoop extends Runnable {
      override def run(): Unit = {
        try {
          while (true) {
            try {
              val data = mapOutputRequests.take()
              if (data == PoisonPill) {
                // Put the PoisonPill back so that the other MessageLoops can see it
                mapOutputRequests.offer(PoisonPill)
                return
              }
              val context = data.context
              val shuffleId = data.shuffleId
              val hostPort = context.senderAddress.hostPort
              logDebug("Handling request to send map output locations for shuffle " + shuffleId +
                " to " + hostPort)
              // A request was read: serialize the map output statuses
              val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId)
              // Reply with the data
              context.reply(mapOutputStatuses)
            } catch {
              case NonFatal(e) => logError(e.getMessage, e)
            }
          }
        } catch {
          case ie: InterruptedException => // exit
        }
      }
    }
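The post/take/PoisonPill interplay above is a general producer-consumer pattern. Here is a minimal sketch of it outside Spark: `Request`, `Dispatcher` and `serve` are hypothetical names, and the `reply` callback merely stands in for the RPC call context's reply.

```scala
import java.util.concurrent.{CountDownLatch, LinkedBlockingQueue}
import scala.collection.mutable

object MessageLoopSketch {
  // Hypothetical request: a shuffle id plus a callback standing in for context.reply.
  final case class Request(shuffleId: Int, reply: String => Unit)

  // Sentinel that tells every worker thread to stop (compared by reference).
  val PoisonPill: Request = Request(-99, _ => ())

  final class Dispatcher(numThreads: Int) {
    private val requests = new LinkedBlockingQueue[Request]()

    private def loop(): Unit = {
      var running = true
      while (running) {
        val req = requests.take() // blocks until a request arrives
        if (req eq PoisonPill) {
          requests.offer(PoisonPill) // put it back so the other workers see it too
          running = false
        } else {
          req.reply(s"serialized-statuses-${req.shuffleId}")
        }
      }
    }

    private val threads = (1 to numThreads).map { i =>
      val t = new Thread(new Runnable { def run(): Unit = loop() }, s"sketch-dispatcher-$i")
      t.setDaemon(true)
      t.start()
      t
    }

    // The endpoint only enqueues; it never does the expensive work itself.
    def post(req: Request): Unit = requests.offer(req)

    def stop(): Unit = {
      post(PoisonPill)
      threads.foreach(_.join())
    }
  }

  // Posts one request per id, waits for all replies, then shuts the workers down.
  def serve(ids: Seq[Int]): Set[String] = {
    val replies = mutable.ListBuffer.empty[String]
    val done = new CountDownLatch(ids.size)
    val dispatcher = new Dispatcher(numThreads = 2)
    ids.foreach { id =>
      dispatcher.post(Request(id, msg => {
        replies.synchronized { replies += msg }
        done.countDown()
      }))
    }
    done.await()
    dispatcher.stop()
    replies.synchronized(replies.toSet)
  }
}
```

Handing the request to a queue keeps the RPC endpoint thread free while the potentially large status serialization happens on the worker threads.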
MapOutputTracker.convertMapStatuses
Let us return to MapOutputTracker.convertMapStatuses, whose result is what mapOutputTracker.getMapSizesByExecutorId returns:
    private def convertMapStatuses(
        shuffleId: Int,
        startPartition: Int,
        endPartition: Int,
        statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
      assert (statuses != null)
      val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]]
      for ((status, mapId) <- statuses.zipWithIndex) {
        if (status == null) {
          val errorMessage = s"Missing an output location for shuffle $shuffleId"
          logError(errorMessage)
          throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
        } else {
          for (part <- startPartition until endPartition) {
            // Each entry of the returned Seq is (status.location, Seq[(ShuffleBlockId, size of block)])
            splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) +=
              ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part)))
          }
        }
      }
      // Convert the map, which groups the blocks by status.location, into a Seq
      // (toSeq does not sort; the order follows the HashMap's iteration order)
      splitsByAddress.toSeq
    }
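To make the output shape of convertMapStatuses concrete, the following sketch reproduces the grouping with simplified, hypothetical types (`Location`, `BlockId`, `Status` are stand-ins for BlockManagerId, ShuffleBlockId and MapStatus, and the sizes are stored uncompressed). It returns a Map rather than a Seq purely to keep the example easy to assert on.

```scala
import scala.collection.mutable

object ConvertSketch {
  // Hypothetical, simplified stand-ins for the Spark types.
  final case class Location(executorId: String, host: String)
  final case class BlockId(shuffleId: Int, mapId: Int, reduceId: Int)
  final case class Status(location: Location, sizes: Array[Long])

  // Groups the blocks of reduce partitions [start, end) by the location of the
  // map task that produced them; the index in `statuses` is the map id.
  def convert(shuffleId: Int, start: Int, end: Int, statuses: Seq[Status])
      : Map[Location, Seq[(BlockId, Long)]] = {
    val byLocation = mutable.LinkedHashMap.empty[Location, mutable.ArrayBuffer[(BlockId, Long)]]
    for ((status, mapId) <- statuses.zipWithIndex; part <- start until end) {
      byLocation.getOrElseUpdate(status.location, mutable.ArrayBuffer.empty) +=
        ((BlockId(shuffleId, mapId, part), status.sizes(part)))
    }
    byLocation.map { case (loc, blocks) => loc -> blocks.toList }.toMap
  }
}
```

For example, with map tasks 0 and 2 on one executor and map task 1 on another, a reducer for partition 1 gets two blocks from the first executor and one from the second, so it can later batch the requests per node.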
Splitting local and remote blocks
Let us return to new ShuffleBlockFetcherIterator.
storage.ShuffleBlockFetcherIterator.initialize
When a ShuffleBlockFetcherIterator is instantiated, initialize is called:
    private[this] def initialize(): Unit = {
      context.addTaskCompletionListener(_ => cleanup())
      // Split the blocks into local and remote ones
      val remoteRequests = splitLocalRemoteBlocks()
      // Add the remote requests to the queue in random order
      fetchRequests ++= Utils.randomize(remoteRequests)
      assert ((0 == reqsInFlight) == (0 == bytesInFlight),
        "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
        ", expected bytesInFlight = 0 but found bytesInFlight = " + bytesInFlight)
      // Send remote requests to fetch blocks
      fetchUpToMaxBytes()
      val numFetches = remoteRequests.size - fetchRequests.size
      logInfo("Started " + numFetches + " remote fetches in" + Utils.getUsedTimeMs(startTime))
      // Fetch the local blocks
      fetchLocalBlocks()
      logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
    }
storage.ShuffleBlockFetcherIterator.splitLocalRemoteBlocks
    private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
      // Make remote requests at most maxBytesInFlight / 5 in length.
      // maxBytesInFlight: the maximum number of bytes that may be in flight
      // (requested but not yet received) at any one time.
      // 1/5: increases fetch parallelism by allowing up to 5 requests to fetch
      // from 5 different nodes at once, instead of blocking on a single node
      val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
      logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)

      // Buffer of the FetchRequest objects that must be sent to remote nodes
      val remoteRequests = new ArrayBuffer[FetchRequest]

      // Total number of blocks
      var totalBlocks = 0
      // As seen above, blocksByAddress is grouped by status.location
      for ((address, blockInfos) <- blocksByAddress) {
        totalBlocks += blockInfos.size
        if (address.executorId == blockManager.blockManagerId.executorId) {
          // Same executorId as this blockManagerId.executorId: read locally,
          // filtering out zero-sized blocks
          localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
          numBlocksToFetch += localBlocks.size
        } else {
          // Otherwise the blocks must be requested remotely
          val iterator = blockInfos.iterator
          // Accumulated size of the current request
          var curRequestSize = 0L
          // Blocks accumulated for the current request.
          // Accumulating matters: frequently requesting tiny blocks from one node
          // one at a time would congest the network
          var curBlocks = new ArrayBuffer[(BlockId, Long)]
          // All blocks in this iterator live on the same node
          while (iterator.hasNext) {
            val (blockId, size) = iterator.next()
            if (size > 0) {
              curBlocks += ((blockId, size))
              remoteBlocks += blockId
              numBlocksToFetch += 1
              curRequestSize += size
            } else if (size < 0) {
              throw new BlockException(blockId, "Negative block size " + size)
            }
            if (curRequestSize >= targetRequestSize) {
              // The accumulated size has reached the target request size:
              // add a FetchRequest to remoteRequests
              remoteRequests += new FetchRequest(address, curBlocks)
              curBlocks = new ArrayBuffer[(BlockId, Long)]
              logDebug(s"Creating fetch request of $curRequestSize at $address")
              curRequestSize = 0
            }
          }
          // Add the final, partially filled request
          if (curBlocks.nonEmpty) {
            remoteRequests += new FetchRequest(address, curBlocks)
          }
        }
      }
      logInfo(s"Getting $numBlocksToFetch non-empty blocks out of $totalBlocks blocks")
      remoteRequests
    }
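The batching rule for one remote node can be tried out in isolation. The sketch below is a simplification under stated assumptions: `SplitSketch`, `targetRequestSize` and `batch` are hypothetical helpers, block ids are plain strings, and a negative size fails a `require` instead of throwing a BlockException. With the default 48m setting, maxBytesInFlight is 48 * 1024 * 1024 = 50331648 bytes, so each request targets 50331648 / 5 = 10066329 bytes (integer division).

```scala
import scala.collection.mutable.ArrayBuffer

object SplitSketch {
  // One fifth of the in-flight budget, but never less than one byte.
  def targetRequestSize(maxBytesInFlight: Long): Long =
    math.max(maxBytesInFlight / 5, 1L)

  // Groups the (blockId, size) pairs of ONE remote node into requests.
  // A request is closed as soon as its accumulated size reaches the target;
  // zero-sized blocks are skipped.
  def batch(blocks: Seq[(String, Long)], maxBytesInFlight: Long): List[List[(String, Long)]] = {
    val target = targetRequestSize(maxBytesInFlight)
    val requests = ArrayBuffer.empty[List[(String, Long)]]
    var current = ArrayBuffer.empty[(String, Long)]
    var currentSize = 0L
    for ((id, size) <- blocks) {
      require(size >= 0, s"Negative block size $size for $id")
      if (size > 0) {
        current += ((id, size))
        currentSize += size
      }
      if (currentSize >= target) {
        requests += current.toList
        current = ArrayBuffer.empty[(String, Long)]
        currentSize = 0L
      }
    }
    // Whatever is left over becomes the last, smaller request.
    if (current.nonEmpty) requests += current.toList
    requests.toList
  }
}
```

Note that a request may exceed the target, because the size check happens only after a block has been added; a single large block therefore always forms a request of its own or closes the current one.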
Fetching blocks
storage.ShuffleBlockFetcherIterator.fetchUpToMaxBytes
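Let us return to the fetchUpToMaxBytes() call in storage.ShuffleBlockFetcherIterator.initialize and look at how remote blocks are fetched:
    private def fetchUpToMaxBytes(): Unit = {
      // Send fetch requests up to maxBytesInFlight.
      // A request is sent if nothing is in flight, or if sending it keeps both
      // the number of requests in flight within maxReqsInFlight
      // and the number of bytes in flight within maxBytesInFlight
      while (fetchRequests.nonEmpty &&
        (bytesInFlight == 0 ||
          (reqsInFlight + 1 <= maxReqsInFlight &&
            bytesInFlight + fetchRequests.front.size <= maxBytesInFlight))) {
        sendRequest(fetchRequests.dequeue())
      }
    }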
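The loop condition is compact, so a small simulation may help. This is a sketch under simplifying assumptions: `ThrottleSketch.Fetcher` is hypothetical, requests are represented only by their sizes, and `complete` frees a whole request at once, whereas the real iterator releases the budget as results are consumed.

```scala
import scala.collection.mutable

object ThrottleSketch {
  final class Fetcher(maxBytesInFlight: Long, maxReqsInFlight: Int) {
    // Sizes of the requests that have not been sent yet.
    val pending = mutable.Queue.empty[Long]
    var bytesInFlight = 0L
    var reqsInFlight = 0

    // Same condition as fetchUpToMaxBytes: send while the queue is non-empty and
    // either nothing is in flight or the next request fits both limits.
    // Returns the sizes of the requests sent by this call.
    def fetchUpToMaxBytes(): List[Long] = {
      val sent = mutable.ListBuffer.empty[Long]
      while (pending.nonEmpty &&
        (bytesInFlight == 0 ||
          (reqsInFlight + 1 <= maxReqsInFlight &&
            bytesInFlight + pending.front <= maxBytesInFlight))) {
        val size = pending.dequeue()
        bytesInFlight += size
        reqsInFlight += 1
        sent += size
      }
      sent.toList
    }

    // Called when a request has fully arrived; frees its share of the budget.
    def complete(size: Long): Unit = {
      bytesInFlight -= size
      reqsInFlight -= 1
    }
  }
}
```

The `bytesInFlight == 0` clause is what lets a request larger than maxBytesInFlight go out at all: it is sent alone, once nothing else is in flight.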
storage.ShuffleBlockFetcherIterator.sendRequest
    private[this] def sendRequest(req: FetchRequest) {
      logDebug("Sending request for %d blocks (%s) from %s".format(
        req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
      bytesInFlight += req.size
      reqsInFlight += 1

      // Lets us look up a block's size by its block ID
      val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
      val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
      val blockIds = req.blocks.map(_._1.toString)

      val address = req.address
      // shuffleClient.fetchBlocks will be covered in a later post
      shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
        new BlockFetchingListener {
          // The fetch succeeded
          override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
            ShuffleBlockFetcherIterator.this.synchronized {
              if (!isZombie) {
                buf.retain()
                remainingBlocks -= blockId
                results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
                  remainingBlocks.isEmpty))
                logDebug("remainingBlocks: " + remainingBlocks)
              }
            }
            logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
          }

          // The fetch failed
          override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
            logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
            results.put(new FailureFetchResult(BlockId(blockId), address, e))
          }
        }
      )
    }
storage.ShuffleBlockFetcherIterator.fetchLocalBlocks
Now let us go back and look at how the local blocks are fetched:
    private[this] def fetchLocalBlocks() {
      // Get an iterator over the local blocks
      val iter = localBlocks.iterator
      while (iter.hasNext) {
        val blockId = iter.next()
        try {
          // Read each block in turn.
          // blockManager.getBlockData will be covered in a later post
          val buf = blockManager.getBlockData(blockId)
          shuffleMetrics.incLocalBlocksFetched(1)
          shuffleMetrics.incLocalBytesRead(buf.size)
          buf.retain()
          results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf, false))
        } catch {
          case e: Exception =>
            // Record the failure and stop at the first local block that cannot be read
            logError(s"Error occurred while fetching local blocks", e)
            results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
            return
        }
      }
    }