From c04e36b3e44ba9b6ae9a91fa89a824e06deffe96 Mon Sep 17 00:00:00 2001 From: Peter Rudenko Date: Wed, 14 Jul 2021 13:36:00 +0300 Subject: [PATCH] Support Spark-2.1 version. (#33) --- .github/workflows/sparkucx-ci.yml | 2 +- .github/workflows/sparkucx-release.yml | 2 +- pom.xml | 103 +++++++++++-- .../spark_2_1/OnOffsetsFetchCallback.java | 70 +++++++++ .../compat/spark_2_1/UcxShuffleClient.java | 111 ++++++++++++++ .../spark_2_1/UcxShuffleBlockResolver.scala | 44 ++++++ .../compat/spark_2_1/UcxShuffleManager.scala | 53 +++++++ .../compat/spark_2_1/UcxShuffleReader.scala | 139 ++++++++++++++++++ 8 files changed, 507 insertions(+), 17 deletions(-) create mode 100755 src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/OnOffsetsFetchCallback.java create mode 100755 src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/UcxShuffleClient.java create mode 100755 src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleBlockResolver.scala create mode 100755 src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleManager.scala create mode 100755 src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleReader.scala diff --git a/.github/workflows/sparkucx-ci.yml b/.github/workflows/sparkucx-ci.yml index 33b4f321..3f627287 100755 --- a/.github/workflows/sparkucx-ci.yml +++ b/.github/workflows/sparkucx-ci.yml @@ -9,7 +9,7 @@ jobs: build-sparkucx: strategy: matrix: - spark_version: ["2.4", "3.0"] + spark_version: ["2.1", "2.4", "3.0"] runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 diff --git a/.github/workflows/sparkucx-release.yml b/.github/workflows/sparkucx-release.yml index 7ffe6637..cfa93c58 100644 --- a/.github/workflows/sparkucx-release.yml +++ b/.github/workflows/sparkucx-release.yml @@ -13,7 +13,7 @@ jobs: release: strategy: matrix: - spark_version: ["2.4", "3.0"] + spark_version: ["2.1", "2.4", "3.0"] runs-on: ubuntu-latest steps: - name: Checkout code diff --git a/pom.xml b/pom.xml index 92a0b23e..b23b54ac 100755 --- a/pom.xml +++ b/pom.xml @@ -5,8 +5,8 @@ See file LICENSE for terms. --> 4.0.0 org.openucx @@ -34,12 +34,68 @@ See file LICENSE for terms. + + spark-2.1 + + + + org.apache.maven.plugins + maven-compiler-plugin + + + **/spark_3_0/** + **/spark_2_4/** + + + + + net.alchim31.maven + scala-maven-plugin + + + **/spark_3_0/** + **/spark_2_4/** + + + + + + + 2.1.0 + **/spark_3_0/**, **/spark_2_4/** + 2.11.12 + 2.11 + + spark-2.4 + + + + org.apache.maven.plugins + maven-compiler-plugin + + + **/spark_3_0/** + **/spark_2_1/** + + + + + net.alchim31.maven + scala-maven-plugin + + + **/spark_2_1/** + **/spark_3_0/** + + + + + 2.4.0 - **/spark_3_0/** - **/spark_3_0/** + **/spark_3_0/**, **/spark_2_1/** 2.11.12 2.11 @@ -49,12 +105,35 @@ See file LICENSE for terms. true + + + + org.apache.maven.plugins + maven-compiler-plugin + + + **/spark_2_1/** + **/spark_2_4/** + + + + + net.alchim31.maven + scala-maven-plugin + + + **/spark_2_1/** + **/spark_2_4/** + + + + + 3.0.1 2.12.10 2.12 - **/spark_2_4/** - **/spark_2_4/** + **/spark_2_1/**, **/spark_2_4/** @@ -84,9 +163,6 @@ See file LICENSE for terms. 1.8 1.8 - - ${project.excludes} - @@ -95,9 +171,6 @@ See file LICENSE for terms. 4.3.0 all - - ${project.excludes} - -nobootcp -Xexperimental @@ -111,9 +184,9 @@ See file LICENSE for terms. 
compile - - compile - + + compile + compile
diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/OnOffsetsFetchCallback.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/OnOffsetsFetchCallback.java
new file mode 100755
index 00000000..5f95caf6
--- /dev/null
+++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/OnOffsetsFetchCallback.java
@@ -0,0 +1,70 @@
+/*
+ * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+ * See file LICENSE for terms.
+ */
+package org.apache.spark.shuffle.ucx.reducer.compat.spark_2_1;
+
+import org.apache.spark.network.shuffle.BlockFetchingListener;
+import org.apache.spark.shuffle.ucx.UnsafeUtils;
+import org.apache.spark.shuffle.ucx.memory.RegisteredMemory;
+import org.apache.spark.shuffle.ucx.reducer.OnBlocksFetchCallback;
+import org.apache.spark.shuffle.ucx.reducer.ReducerCallback;
+import org.apache.spark.storage.ShuffleBlockId;
+import org.openucx.jucx.UcxUtils;
+import org.openucx.jucx.ucp.UcpEndpoint;
+import org.openucx.jucx.ucp.UcpRemoteKey;
+import org.openucx.jucx.ucp.UcpRequest;
+
+import java.nio.ByteBuffer;
+import java.util.Map;
+
+/**
+ * Callback invoked when all block offsets have been fetched.
+ */
+public class OnOffsetsFetchCallback extends ReducerCallback {
+  private final RegisteredMemory offsetMemory;
+  private final long[] dataAddresses;
+  private Map<Integer, UcpRemoteKey> dataRkeysCache;
+
+  public OnOffsetsFetchCallback(ShuffleBlockId[] blockIds, UcpEndpoint endpoint, BlockFetchingListener listener,
+                                RegisteredMemory offsetMemory, long[] dataAddresses,
+                                Map<Integer, UcpRemoteKey> dataRkeysCache) {
+    super(blockIds, endpoint, listener);
+    this.offsetMemory = offsetMemory;
+    this.dataAddresses = dataAddresses;
+    this.dataRkeysCache = dataRkeysCache;
+  }
+
+  @Override
+  public void onSuccess(UcpRequest request) {
+    ByteBuffer resultOffset = offsetMemory.getBuffer();
+    long totalSize = 0;
+    int[] sizes = new int[blockIds.length];
+    int offsetSize = UnsafeUtils.LONG_SIZE;
+    for (int i = 0; i < blockIds.length; i++) {
+      // Blocks in the metadata buffer are laid out as | blockOffsetStart | blockOffsetEnd |
+      long blockOffset = resultOffset.getLong(i * 2 * offsetSize);
+      long blockLength = resultOffset.getLong(i * 2 * offsetSize + offsetSize) - blockOffset;
+      assert (blockLength > 0) && (blockLength <= Integer.MAX_VALUE);
+      sizes[i] = (int) blockLength;
+      totalSize += blockLength;
+      dataAddresses[i] += blockOffset;
+    }
+
+    assert (totalSize > 0) && (totalSize < Integer.MAX_VALUE);
+    mempool.put(offsetMemory);
+    RegisteredMemory blocksMemory = mempool.get((int) totalSize);
+
+    long offset = 0;
+    // Submit N fetch-block requests.
+    for (int i = 0; i < blockIds.length; i++) {
+      endpoint.getNonBlockingImplicit(dataAddresses[i], dataRkeysCache.get(((ShuffleBlockId)blockIds[i]).mapId()),
+          UcxUtils.getAddress(blocksMemory.getBuffer()) + offset, sizes[i]);
+      offset += sizes[i];
+    }
+
+    // Process blocks once all of them are fetched.
+    // The flush guarantees that the callback is invoked only after all fetch requests have completed.
+    endpoint.flushNonBlocking(new OnBlocksFetchCallback(this, blocksMemory, sizes));
+  }
+}
diff --git a/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/UcxShuffleClient.java b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/UcxShuffleClient.java
new file mode 100755
index 00000000..0595bfc8
--- /dev/null
+++ b/src/main/java/org/apache/spark/shuffle/ucx/reducer/compat/spark_2_1/UcxShuffleClient.java
@@ -0,0 +1,111 @@
+/*
+ * Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+ * See file LICENSE for terms.
+ */
+package org.apache.spark.shuffle.ucx.reducer.compat.spark_2_1;
+
+import org.apache.spark.SparkEnv;
+import org.apache.spark.executor.TempShuffleReadMetrics;
+import org.apache.spark.network.shuffle.BlockFetchingListener;
+import org.apache.spark.network.shuffle.ShuffleClient;
+import org.apache.spark.shuffle.DriverMetadata;
+import org.apache.spark.shuffle.UcxShuffleManager;
+import org.apache.spark.shuffle.UcxWorkerWrapper;
+import org.apache.spark.shuffle.ucx.UnsafeUtils;
+import org.apache.spark.shuffle.ucx.memory.MemoryPool;
+import org.apache.spark.shuffle.ucx.memory.RegisteredMemory;
+import org.apache.spark.storage.BlockId;
+import org.apache.spark.storage.BlockManagerId;
+import org.apache.spark.storage.ShuffleBlockId;
+import org.openucx.jucx.UcxUtils;
+import org.openucx.jucx.ucp.UcpEndpoint;
+import org.openucx.jucx.ucp.UcpRemoteKey;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import scala.Option;
+
+import java.util.Arrays;
+import java.util.HashMap;
+
+public class UcxShuffleClient extends ShuffleClient {
+  private final MemoryPool mempool;
+  private static final Logger logger = LoggerFactory.getLogger(UcxShuffleClient.class);
+  private final UcxShuffleManager ucxShuffleManager;
+  private final TempShuffleReadMetrics shuffleReadMetrics;
+  private final UcxWorkerWrapper workerWrapper;
+  final HashMap<Integer, UcpRemoteKey> offsetRkeysCache = new HashMap<>();
+  final HashMap<Integer, UcpRemoteKey> dataRkeysCache = new HashMap<>();
+
+  public UcxShuffleClient(TempShuffleReadMetrics shuffleReadMetrics,
+                          UcxWorkerWrapper workerWrapper) {
+    this.ucxShuffleManager = (UcxShuffleManager) SparkEnv.get().shuffleManager();
+    this.mempool = ucxShuffleManager.ucxNode().getMemoryPool();
+    this.shuffleReadMetrics = shuffleReadMetrics;
+    this.workerWrapper = workerWrapper;
+  }
+
+  /**
+   * Submits n non-blocking fetch requests to get the offsets needed for n blocks.
+   */
+  private void submitFetchOffsets(UcpEndpoint endpoint, ShuffleBlockId[] blockIds,
+                                  long[] dataAddresses, RegisteredMemory offsetMemory) {
+    DriverMetadata driverMetadata = workerWrapper.fetchDriverMetadataBuffer(blockIds[0].shuffleId());
+    for (int i = 0; i < blockIds.length; i++) {
+      ShuffleBlockId blockId = blockIds[i];
+
+      long offsetAddress = driverMetadata.offsetAddress(blockId.mapId());
+      dataAddresses[i] = driverMetadata.dataAddress(blockId.mapId());
+
+      offsetRkeysCache.computeIfAbsent(blockId.mapId(), mapId ->
+        endpoint.unpackRemoteKey(driverMetadata.offsetRkey(blockId.mapId())));
+
+      dataRkeysCache.computeIfAbsent(blockId.mapId(), mapId ->
+        endpoint.unpackRemoteKey(driverMetadata.dataRkey(blockId.mapId())));
+
+      endpoint.getNonBlockingImplicit(
+        offsetAddress + blockId.reduceId() * UnsafeUtils.LONG_SIZE,
+        offsetRkeysCache.get(blockId.mapId()),
+        UcxUtils.getAddress(offsetMemory.getBuffer()) + (i * 2L * UnsafeUtils.LONG_SIZE),
+        2L * UnsafeUtils.LONG_SIZE);
+    }
+  }
+
+  /**
+   * Reducer entry point. Fetches remote blocks using 2 ucp_get calls.
+   * This method is called inside ShuffleFetchIterator's loop over hosts.
+   * It first fetches the block offsets from the index file, and then fetches the blocks themselves.
+   */
+  @Override
+  public void fetchBlocks(String host, int port, String execId,
+                          String[] blockIds, BlockFetchingListener listener) {
+    long startTime = System.currentTimeMillis();
+
+    BlockManagerId blockManagerId = BlockManagerId.apply(execId, host, port, Option.empty());
+    UcpEndpoint endpoint = workerWrapper.getConnection(blockManagerId);
+
+    long[] dataAddresses = new long[blockIds.length];
+
+    // Need to fetch 2 long offsets (current block + next block) to calculate the exact block size.
+    RegisteredMemory offsetMemory = mempool.get(2 * UnsafeUtils.LONG_SIZE * blockIds.length);
+
+    ShuffleBlockId[] shuffleBlockIds = Arrays.stream(blockIds)
+      .map(blockId -> (ShuffleBlockId) BlockId.apply(blockId)).toArray(ShuffleBlockId[]::new);
+
+    // Submit N implicit get requests without a callback.
+    submitFetchOffsets(endpoint, shuffleBlockIds, dataAddresses, offsetMemory);
+
+    // The flush guarantees that all of those requests have completed when the callback is called.
+    // TODO: fix https://github.com/openucx/ucx/issues/4267 and use endpoint flush.
+    workerWrapper.worker().flushNonBlocking(
+      new OnOffsetsFetchCallback(shuffleBlockIds, endpoint, listener, offsetMemory,
+        dataAddresses, dataRkeysCache));
+    shuffleReadMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime);
+  }
+
+  @Override
+  public void close() {
+    offsetRkeysCache.values().forEach(UcpRemoteKey::close);
+    dataRkeysCache.values().forEach(UcpRemoteKey::close);
+    logger.info("Shuffle read metrics, fetch wait time: {}ms", shuffleReadMetrics.fetchWaitTime());
+  }
+}
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleBlockResolver.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleBlockResolver.scala
new file mode 100755
index 00000000..90084f39
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleBlockResolver.scala
@@ -0,0 +1,44 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.compat.spark_2_1
+
+import java.io.{File, RandomAccessFile}
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.shuffle.{CommonUcxShuffleBlockResolver, CommonUcxShuffleManager, IndexShuffleBlockResolver}
+import org.apache.spark.storage.ShuffleIndexBlockId
+
+/**
+ * Mapper entry point for the UcxShuffle plugin. Performs memory registration
+ * of data and index files and publishes their addresses to the driver metadata buffer.
+ */
+class UcxShuffleBlockResolver(ucxShuffleManager: CommonUcxShuffleManager)
+  extends CommonUcxShuffleBlockResolver(ucxShuffleManager) {
+
+  private def getIndexFile(shuffleId: Int, mapId: Int): File = {
+    SparkEnv.get.blockManager
+      .diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID))
+  }
+
+  /**
+   * Mapper commit protocol extension. Registers the index and data files and publishes all needed
+   * metadata to the driver.
+   */
+  override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Int,
+                                       lengths: Array[Long], dataTmp: File): Unit = {
+    super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp)
+    val dataFile = getDataFile(shuffleId, mapId)
+    val dataBackFile = new RandomAccessFile(dataFile, "rw")
+
+    if (dataBackFile.length() == 0) {
+      dataBackFile.close()
+      return
+    }
+
+    val indexFile = getIndexFile(shuffleId, mapId)
+    val indexBackFile = new RandomAccessFile(indexFile, "rw")
+    writeIndexFileAndCommitCommon(shuffleId, mapId, lengths, dataTmp, indexBackFile, dataBackFile)
+  }
+}
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleManager.scala
new file mode 100755
index 00000000..7087e491
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleManager.scala
@@ -0,0 +1,53 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle
+
+import org.apache.spark.shuffle.compat.spark_2_1.{UcxShuffleBlockResolver, UcxShuffleReader}
+import org.apache.spark.util.ShutdownHookManager
+import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
+
+/**
+ * Main entry point of the UCX shuffle plugin. It extends Spark's default SortShuffleManager
+ * and injects the needed logic in overridden methods.
+ */
+class UcxShuffleManager(override val conf: SparkConf, isDriver: Boolean) extends CommonUcxShuffleManager(conf, isDriver) {
+  ShutdownHookManager.addShutdownHook(Int.MaxValue - 1)(stop)
+
+  /**
+   * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
+   * Called on the driver; Spark guarantees that the shuffle on executors starts only after this.
+   */
+  override def registerShuffle[K, V, C](shuffleId: ShuffleId,
+                                        numMaps: Int,
+                                        dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+    assume(isDriver)
+    val baseHandle = super.registerShuffle(shuffleId, numMaps, dependency).asInstanceOf[BaseShuffleHandle[K, V, C]]
+    registerShuffleCommon(baseHandle, shuffleId, numMaps)
+  }
+
+  /**
+   * Mapper callback on the executor. Just starts the UcxNode if needed and reuses Spark's mapper logic.
+   */
+  override def getWriter[K, V](handle: ShuffleHandle, mapId: Int,
+                               context: TaskContext): ShuffleWriter[K, V] = {
+    startUcxNodeIfMissing()
+    shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K,V,_]])
+    super.getWriter(handle.asInstanceOf[UcxShuffleHandle[K,V,_]].baseHandle, mapId, context)
+  }
+
+  override val shuffleBlockResolver: UcxShuffleBlockResolver = new UcxShuffleBlockResolver(this)
+
+  /**
+   * Reducer callback on the executor.
+   */
+  override def getReader[K, C](handle: ShuffleHandle, startPartition: Int,
+                               endPartition: Int, context: TaskContext): ShuffleReader[K, C] = {
+    startUcxNodeIfMissing()
+    shuffleIdToHandle.putIfAbsent(handle.shuffleId, handle.asInstanceOf[UcxShuffleHandle[K,_,C]])
+    new UcxShuffleReader(handle.asInstanceOf[UcxShuffleHandle[K,_,C]], startPartition,
+      endPartition, context)
+  }
+}
+
diff --git a/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleReader.scala
new file mode 100755
index 00000000..1c7e1511
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/compat/spark_2_1/UcxShuffleReader.scala
@@ -0,0 +1,139 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.compat.spark_2_1
+
+import java.io.InputStream
+import java.util.concurrent.LinkedBlockingQueue
+
+import org.apache.spark.internal.{Logging, config}
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.shuffle.ucx.reducer.compat.spark_2_1.UcxShuffleClient
+import org.apache.spark.shuffle.{ShuffleReader, UcxShuffleHandle, UcxShuffleManager}
+import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockFetcherIterator}
+import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.collection.ExternalSorter
+import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
+
+/**
+ * Extension of Spark's shuffle reader that injects UcxShuffleClient and drives UCX progress
+ * lazily, only when the result queue is empty.
+ */
+class UcxShuffleReader[K, C](handle: UcxShuffleHandle[K, _, C],
+                             startPartition: Int,
+                             endPartition: Int,
+                             context: TaskContext,
+                             serializerManager: SerializerManager = SparkEnv.get.serializerManager,
+                             blockManager: BlockManager = SparkEnv.get.blockManager,
+                             mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
+  extends ShuffleReader[K, C] with Logging {
+
+  private val dep = handle.baseHandle.dependency
+
+  /** Read the combined key-values for this reduce task */
+  override def read(): Iterator[Product2[K, C]] = {
+    val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
+    val workerWrapper = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager]
+      .ucxNode.getThreadLocalWorker
+    val shuffleClient = new UcxShuffleClient(shuffleMetrics, workerWrapper)
+    val wrappedStreams = new ShuffleBlockFetcherIterator(
+      context,
+      shuffleClient,
+      blockManager,
+      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId,
+        startPartition, endPartition),
+      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
+      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
+      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))
+
+    // UCX shuffle logic:
+    // use Java reflection to get access to the private results queue.
+    val queueField = wrappedStreams.getClass.getDeclaredField(
+      "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results")
+    queueField.setAccessible(true)
+    val resultQueue = queueField.get(wrappedStreams).asInstanceOf[LinkedBlockingQueue[_]]
+
+    // Progress the UCX worker if the queue is empty before calling next on the shuffle iterator.
+    val ucxWrappedStream = new Iterator[(BlockId, InputStream)] {
+      override def next(): (BlockId, InputStream) = {
+        val startTime = System.currentTimeMillis()
+        workerWrapper.fillQueueWithBlocks(resultQueue)
+        shuffleMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime)
+        wrappedStreams.next()
+      }
+
+      override def hasNext: Boolean = {
+        val result = wrappedStreams.hasNext
+        if (!result) {
+          shuffleClient.close()
+        }
+        result
+      }
+    }
+    // End of UCX shuffle logic
+
+    val serializerInstance = dep.serializer.newInstance()
+    val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) =>
+      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
+      // NextIterator. The NextIterator makes sure that close() is called on the
+      // underlying InputStream when all records have been read.
+      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+    }
+
+    // Update the context task metrics for each record read.
+ val readMetrics = context.taskMetrics.createTempShuffleReadMetrics() + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map { record => + readMetrics.incRecordsRead(1) + record + }, + context.taskMetrics().mergeShuffleReadMetrics()) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) + + val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { + if (dep.mapSideCombine) { + // We are reading values that are already combined + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) + } else { + // We don't know the value type, but also don't care -- the dependency *should* + // have made sure its compatible w/ this aggregator, which will convert the value + // type to the combined type C + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) + } + } else { + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] + } + + // Sort the output if there is a sort ordering defined. + val resultIter = dep.keyOrdering match { + case Some(keyOrd: Ordering[K]) => + // Create an ExternalSorter to sort the data. + val sorter = + new ExternalSorter[K, C, C](context, + ordering = Some(keyOrd), serializer = dep.serializer) + sorter.insertAll(aggregatedIter) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) + // Use completion callback to stop sorter if task was finished/cancelled. + CompletionIterator[Product2[K, C], + Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) + case None => + aggregatedIter + } + + resultIter match { + case _: InterruptibleIterator[Product2[K, C]] => resultIter + case _ => + // Use another interruptible iterator here to support task cancellation as aggregator + // or(and) sorter may have consumed previous interruptible iterator. + new InterruptibleIterator[Product2[K, C]](context, resultIter) + } + } + +}
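The reducer path added here works in two rounds of ucp_get: UcxShuffleClient first fetches, for every requested block, the pair of longs that bound it in the mapper's index file, and OnOffsetsFetchCallback then turns those pairs into block sizes and issues the data fetches. Below is a minimal, self-contained Scala sketch of that offset-decoding step; the | blockOffsetStart | blockOffsetEnd | layout and the 8-byte long size mirror the patch, while OffsetDecoding, decode and the sample values are illustrative names only.

import java.nio.ByteBuffer

object OffsetDecoding {
  // Mirrors UnsafeUtils.LONG_SIZE in the patch.
  final val LongSize = 8

  /** Decodes N (start, end) pairs into per-block start offsets and sizes. */
  def decode(resultOffset: ByteBuffer, numBlocks: Int): (Array[Long], Array[Int]) = {
    val starts = new Array[Long](numBlocks)
    val sizes = new Array[Int](numBlocks)
    for (i <- 0 until numBlocks) {
      val start = resultOffset.getLong(i * 2 * LongSize)
      val end = resultOffset.getLong(i * 2 * LongSize + LongSize)
      val length = end - start
      require(length > 0 && length <= Int.MaxValue, s"bad block length: $length")
      starts(i) = start
      sizes(i) = length.toInt
    }
    (starts, sizes)
  }

  def main(args: Array[String]): Unit = {
    // Two blocks, [0, 100) and [100, 164), written as start/end pairs.
    val buf = ByteBuffer.allocate(4 * LongSize)
    buf.putLong(0L).putLong(100L).putLong(100L).putLong(164L)
    val (starts, sizes) = decode(buf, numBlocks = 2)
    println(starts.toList) // List(0, 100)
    println(sizes.toList)  // List(100, 64)
  }
}

In the patch, the decoded start offsets are added to the registered data addresses and a single flushNonBlocking carries the completion callback for all of the resulting get requests.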
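On the reduce side, UcxShuffleReader.read() keeps Spark's ShuffleBlockFetcherIterator but reaches into its private results queue via reflection and drives the UCX worker only when that queue is about to run dry. Below is a sketch of the wrapping pattern, with fillQueue standing in for workerWrapper.fillQueueWithBlocks(resultQueue) and onExhausted for shuffleClient.close(); LazyProgressIterator is a hypothetical name, not a class from the patch.

// Illustrative wrapper only; in the patch this is an anonymous Iterator inside read().
class LazyProgressIterator[T](underlying: Iterator[T],
                              fillQueue: () => Unit,
                              onExhausted: () => Unit) extends Iterator[T] {

  override def hasNext: Boolean = {
    val more = underlying.hasNext
    if (!more) onExhausted() // the patch closes its UcxShuffleClient here
    more
  }

  override def next(): T = {
    fillQueue() // progress UCX only when the consumer actually asks for the next block
    underlying.next()
  }
}

This keeps UCX progress off the critical path until the reducer actually needs another block, which is why the fetch wait time is measured around the fillQueueWithBlocks call in the patch.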
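Once the plugin is built with the new spark-2.1 profile (for example, mvn package -Pspark-2.1), a Spark 2.1 job opts in by pointing spark.shuffle.manager at the UcxShuffleManager class this patch adds. Below is a hedged sketch of that wiring; the jar path is a placeholder, and the classpath settings are just one way to put the plugin on the driver and executors.

import org.apache.spark.SparkConf

object UcxShuffleOnSpark21 {
  // The manager class name comes from this patch (it is declared in package
  // org.apache.spark.shuffle); the jar location below is a deployment-specific placeholder.
  val conf: SparkConf = new SparkConf()
    .set("spark.shuffle.manager", "org.apache.spark.shuffle.UcxShuffleManager")
    .set("spark.driver.extraClassPath", "/path/to/spark-ucx-for-spark-2.1.jar")
    .set("spark.executor.extraClassPath", "/path/to/spark-ucx-for-spark-2.1.jar")
}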