diff --git a/pom.xml b/pom.xml index d9330461..aafd36e0 100755 --- a/pom.xml +++ b/pom.xml @@ -58,7 +58,7 @@ See file LICENSE for terms. - ${project.artifactId}-${project.version}-for-spark-${spark.version} + ${project.artifactId}-${project.version}-for-spark-3.0 org.apache.maven.plugins diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleBlockResolver.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleBlockResolver.scala new file mode 100755 index 00000000..0d0f32bc --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleBlockResolver.scala @@ -0,0 +1,101 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.ucx + +import java.io.File +import java.nio.{ByteBuffer, MappedByteBuffer} +import java.nio.channels.FileChannel +import java.nio.file.StandardOpenOption +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConverters._ + +import org.openucx.jucx.UcxUtils +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.unsafe.Platform + + +case class UcxShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId { + + def this(shuffleBlockId: ShuffleBlockId) = { + this(shuffleBlockId.shuffleId, shuffleBlockId.mapId, shuffleBlockId.reduceId) + } + + def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId +} + +case class BufferBackedBlock(buffer: ByteBuffer) extends Block { + override def getMemoryBlock: MemoryBlock = MemoryBlock(UcxUtils.getAddress(buffer), buffer.capacity()) +} + +class UcxShuffleBlockResolver(conf: UcxShuffleConf, transport: UcxShuffleTransport) + extends IndexShuffleBlockResolver(conf) { + + type MapId = Long + + private val numPartitionsForMapId = new ConcurrentHashMap[MapId, Int] + + override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Long, + lengths: Array[Long], dataTmp: File): Unit = { + super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp) + val dataFile = getDataFile(shuffleId, mapId) + if (!dataFile.exists()) { + return + } + numPartitionsForMapId.put(mapId, lengths.length) + val fileChannel = FileChannel.open(dataFile.toPath, StandardOpenOption.READ, + StandardOpenOption.WRITE) + val mappedBuffer = fileChannel.map(FileChannel.MapMode.READ_WRITE, 0L, dataFile.length()) + + val baseAddress = UcxUtils.getAddress(mappedBuffer) + fileChannel.close() + + // Register whole map output file as dummy block + transport.register(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.MAP_FILE), + BufferBackedBlock(mappedBuffer)) + + val offsetSize = 8 * (lengths.length + 1) + val indexBuf = Platform.allocateDirectBuffer(offsetSize) + + var offset = 0L + indexBuf.putLong(offset) + for (reduceId <- lengths.indices) { + if (lengths(reduceId) > 0) { + transport.register(UcxShuffleBlockId(shuffleId, mapId, reduceId), new Block { + private val memoryBlock = MemoryBlock(baseAddress + offset, lengths(reduceId)) + override def getMemoryBlock: MemoryBlock = memoryBlock + }) + offset += lengths(reduceId) + indexBuf.putLong(offset) + } + } + + if (transport.ucxShuffleConf.protocol == transport.ucxShuffleConf.PROTOCOL.ONE_SIDED) { + transport.register(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.INDEX_FILE), BufferBackedBlock(indexBuf)) + } + } + + override def removeDataByMap(shuffleId: ShuffleId, mapId: Long): Unit = { + transport.unregister(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.MAP_FILE)) + transport.unregister(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.INDEX_FILE)) + + val numRegisteredBlocks = numPartitionsForMapId.get(mapId) + (0 until numRegisteredBlocks) + .foreach(reduceId => transport.unregister(UcxShuffleBlockId(shuffleId, mapId, reduceId))) + super.removeDataByMap(shuffleId, mapId) + } + + override def stop(): Unit = { + numPartitionsForMapId.keys.asScala.foreach(mapId => removeDataByMap(0, mapId)) + super.stop() + } + +} + +object BlocksConstants { + val MAP_FILE: Int = -1 + val INDEX_FILE: Int = -2 +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala new file mode 100755 index 00000000..25c00daf --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala @@ -0,0 +1,85 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.ucx + +import java.util.concurrent.TimeUnit + +import org.openucx.jucx.{UcxException, UcxUtils} +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient, DownloadFileManager} +import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId, BlockId => SparkBlockId} + +class UcxShuffleClient(transport: UcxShuffleTransport, + blocksByAddress: Iterator[(BlockManagerId, Seq[(SparkBlockId, Long, Int)])]) + extends BlockStoreClient with Logging { + + private val accurateThreshold = transport.ucxShuffleConf.conf.getSizeAsBytes(SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key) + + private val blockSizes: Map[SparkBlockId, Long] = blocksByAddress + .withFilter { case (blockManagerId, _) => blockManagerId != SparkEnv.get.blockManager.blockManagerId } + .flatMap { + case (blockManagerId, blocks) => + val blockIds = blocks.map { + case (blockId, _, _) => + val sparkBlockId = blockId.asInstanceOf[ShuffleBlockId] + UcxShuffleBlockId(sparkBlockId.shuffleId, sparkBlockId.mapId, sparkBlockId.reduceId) + } + if (!transport.ucxShuffleConf.pinMemory) { + transport.prefetchBlocks(blockManagerId.executorId, blockIds) + } + blocks.map { + case (blockId, length, _) => + if (length > accurateThreshold) { + (blockId, (length * 1.2).toLong) + } else { + (blockId, accurateThreshold) + } + } + }.toMap + + override def fetchBlocks(host: String, port: Int, execId: String, + blockIds: Array[String], listener: BlockFetchingListener, + downloadFileManager: DownloadFileManager): Unit = { + val ucxBlockIds = new Array[BlockId](blockIds.length) + val memoryBlocks = new Array[MemoryBlock](blockIds.length) + val callbacks = new Array[OperationCallback](blockIds.length) + for (i <- blockIds.indices) { + val blockId = SparkBlockId.apply(blockIds(i)).asInstanceOf[ShuffleBlockId] + if (!blockSizes.contains(blockId)) { + throw new UcxException(s"No $blockId found in MapOutput blocks: ${blockSizes.keys.mkString(",")}") + } + val resultMemory = transport.memoryPool.get(blockSizes(blockId)) + ucxBlockIds(i) = UcxShuffleBlockId(blockId.shuffleId, blockId.mapId, blockId.reduceId) + memoryBlocks(i) = MemoryBlock(resultMemory.address, blockSizes(blockId)) + callbacks(i) = (result: OperationResult) => { + if (result.getStatus == OperationStatus.SUCCESS) { + val stats = result.getStats.get + logInfo(s" Received block ${ucxBlockIds(i)} " + + s"of size: ${stats.recvSize} " + + s"in ${TimeUnit.NANOSECONDS.toMillis(stats.getElapsedTimeNs)} ms") + val buffer = UcxUtils.getByteBufferView(resultMemory.address, result.getStats.get.recvSize.toInt) + listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) { + override def release: ManagedBuffer = { + transport.memoryPool.put(resultMemory) + this + } + }) + } else { + logError(s"Error fetching block $blockId of size ${blockSizes(blockId)}:" + + s" ${result.getError.getMessage}") + throw new UcxException(result.getError.getMessage) + } + } + } + transport.fetchBlocksByBlockIds(execId, ucxBlockIds, memoryBlocks, callbacks) + } + + override def close(): Unit = { + + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala index 47c9fd73..f73f0970 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala @@ -12,15 +12,20 @@ import org.apache.spark.util.Utils class UcxShuffleConf(val conf: SparkConf) extends SparkConf { private def getUcxConf(name: String) = s"spark.shuffle.ucx.$name" - private val PROTOCOL = + object PROTOCOL extends Enumeration { + val ONE_SIDED, RNDV = Value + } + + private lazy val PROTOCOL_CONF = ConfigBuilder(getUcxConf("protocol")) - .doc("Which protocol to use: rndv (default), one-sided") + .doc("Which protocol to use: RNDV (default), ONE-SIDED") .stringConf .checkValue(protocol => protocol == "rndv" || protocol == "one-sided", "Invalid protocol. Valid options: rndv / one-sided.") - .createWithDefault("rndv") + .transform(_.toUpperCase.replace("-", "_")) + .createWithDefault("RNDV") - private val MEMORY_PINNING = + private lazy val MEMORY_PINNING = ConfigBuilder(getUcxConf("memoryPinning")) .doc("Whether to pin whole shuffle data in memory") .booleanConf @@ -30,14 +35,14 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf { ConfigBuilder(getUcxConf("maxWorkerSize")) .doc("Maximum size of worker address in bytes") .bytesConf(ByteUnit.BYTE) - .createWithDefault(1000) + .createWithDefault(1024) lazy val RPC_MESSAGE_SIZE: ConfigEntry[Long] = ConfigBuilder(getUcxConf("rpcMessageSize")) .doc("Size of RPC message to send from fetchBlockByBlockId. Must contain ") .bytesConf(ByteUnit.BYTE) .checkValue(size => size > maxWorkerAddressSize, - "Rpc message must contain workerAddress") + "Rpc message must contain at least workerAddress") .createWithDefault(2000) // Memory Pool @@ -58,6 +63,12 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf { .intConf .createWithDefault(5) + private lazy val USE_SOCKADDR = + ConfigBuilder(getUcxConf("useSockAddr")) + .doc("Whether to use socket address to connect executors.") + .booleanConf + .createWithDefault(true) + private lazy val MIN_REGISTRATION_SIZE = ConfigBuilder(getUcxConf("memory.minAllocationSize")) .doc("Minimal memory registration size in memory pool.") @@ -67,7 +78,14 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf { lazy val minRegistrationSize: Int = conf.getSizeAsBytes(MIN_REGISTRATION_SIZE.key, MIN_REGISTRATION_SIZE.defaultValueString).toInt - lazy val protocol: String = conf.get(PROTOCOL.key, PROTOCOL.defaultValueString) + private lazy val USE_ODP = + ConfigBuilder(getUcxConf("useOdp")) + .doc("Whether to use on demand paging feature, to avoid memory pinning") + .booleanConf + .createWithDefault(false) + + lazy val protocol: PROTOCOL.Value = PROTOCOL.withName( + conf.get(PROTOCOL_CONF.key, PROTOCOL_CONF.defaultValueString)) lazy val useOdp: Boolean = conf.getBoolean(getUcxConf("memory.useOdp"), defaultValue = false) @@ -83,6 +101,8 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf { lazy val recvQueueSize: Int = conf.getInt(RECV_QUEUE_SIZE.key, RECV_QUEUE_SIZE.defaultValue.get) + lazy val useSockAddr: Boolean = conf.getBoolean(USE_SOCKADDR.key, USE_SOCKADDR.defaultValue.get) + lazy val preallocateBuffersMap: Map[Long, Int] = { conf.get(PREALLOCATE_BUFFERS).split(",").withFilter(s => !s.isEmpty) .map(entry => entry.split(":") match { diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleManager.scala new file mode 100755 index 00000000..3d957b0d --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleManager.scala @@ -0,0 +1,102 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.ucx + +import scala.concurrent.ExecutionContext.Implicits.global +import scala.util.Success + +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch +import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{ExecutorAdded, IntroduceAllExecutors} +import org.apache.spark.shuffle.ucx.rpc.{UcxDriverRpcEndpoint, UcxExecutorRpcEndpoint} +import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer +import org.apache.spark.util.RpcUtils +import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, TaskContext} + + +class UcxShuffleManager(conf: SparkConf, isDriver: Boolean) extends SortShuffleManager(conf) { + + val ucxShuffleConf = new UcxShuffleConf(conf) + + lazy val ucxShuffleTransport: UcxShuffleTransport = if (!isDriver) { + new UcxShuffleTransport(ucxShuffleConf, "init") + } else { + null + } + + @volatile private var initialized: Boolean = false + + override val shuffleBlockResolver = + new UcxShuffleBlockResolver(ucxShuffleConf, ucxShuffleTransport) + + logInfo("Starting UcxShuffleManager") + + def initTransport(): Unit = this.synchronized { + if (!initialized) { + val driverEndpointName = "ucx-shuffle-driver" + if (isDriver) { + val rpcEnv = SparkEnv.get.rpcEnv + val driverEndpoint = new UcxDriverRpcEndpoint(rpcEnv) + rpcEnv.setupEndpoint(driverEndpointName, driverEndpoint) + } else { + val blockManager = SparkEnv.get.blockManager.blockManagerId + ucxShuffleTransport.executorId = blockManager.executorId + val rpcEnv = RpcEnv.create("ucx-rpc-env", blockManager.host, blockManager.host, + blockManager.port, conf, new SecurityManager(conf), 1, clientMode=false) + logDebug("Initializing ucx transport") + val address = ucxShuffleTransport.init() + val executorEndpoint = new UcxExecutorRpcEndpoint(rpcEnv, ucxShuffleTransport) + val endpoint = rpcEnv.setupEndpoint( + s"ucx-shuffle-executor-${blockManager.executorId}", + executorEndpoint) + + val driverEndpoint = RpcUtils.makeDriverRef(driverEndpointName, conf, rpcEnv) + driverEndpoint.ask[IntroduceAllExecutors](ExecutorAdded(blockManager.executorId, + endpoint, new SerializableDirectBuffer(address))) + .andThen{ + case Success(msg) => + logInfo(s"Receive reply $msg") + executorEndpoint.receive(msg) + } + } + initialized = true + } + } + + override def getReader[K, C](handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( + handle.shuffleId, startPartition, endPartition) + new UcxShuffleReader(ucxShuffleTransport, + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics, + shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context)) + } + + override def getReaderForRange[K, C]( handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByRange( + handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) + new UcxShuffleReader(ucxShuffleTransport, + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics, + shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context)) + } + + override def stop(): Unit = { + if (ucxShuffleTransport != null) { + ucxShuffleTransport.close() + } + super.stop() + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleReader.scala new file mode 100755 index 00000000..6ee3e682 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleReader.scala @@ -0,0 +1,173 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle + +import java.io.InputStream +import java.util.concurrent.LinkedBlockingQueue + +import scala.util.Random + +import org.apache.spark._ +import org.apache.spark.internal.{Logging, config} +import org.apache.spark.io.CompressionCodec +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.shuffle.ucx.{UcxShuffleClient, UcxShuffleTransport} +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator} +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.ExternalSorter + +/** + * Fetches and reads the blocks from a shuffle by requesting them from other nodes' block stores. + */ +class UcxShuffleReader[K, C](transport: UcxShuffleTransport, + handle: BaseShuffleHandle[K, _, C], + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], + context: TaskContext, + readMetrics: ShuffleReadMetricsReporter, + serializerManager: SerializerManager = SparkEnv.get.serializerManager, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, + shouldBatchFetch: Boolean = false) + extends ShuffleReader[K, C] with Logging { + + private val dep = handle.dependency + + private def fetchContinuousBlocksInBatch: Boolean = { + val conf = SparkEnv.get.conf + val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects + val compressed = conf.get(config.SHUFFLE_COMPRESS) + val codecConcatenation = if (compressed) { + CompressionCodec.supportsConcatenationOfSerializedStreams(CompressionCodec.createCodec(conf)) + } else { + true + } + val useOldFetchProtocol = conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL) + + val doBatchFetch = shouldBatchFetch && serializerRelocatable && + (!compressed || codecConcatenation) && !useOldFetchProtocol + if (shouldBatchFetch && !doBatchFetch) { + logDebug("The feature tag of continuous shuffle block fetching is set to true, but " + + "we can not enable the feature because other conditions are not satisfied. " + + s"Shuffle compress: $compressed, serializer relocatable: $serializerRelocatable, " + + s"codec concatenation: $codecConcatenation, use old shuffle fetch protocol: " + + s"$useOldFetchProtocol.") + } + doBatchFetch + } + + /** Read the combined key-values for this reduce task */ + override def read(): Iterator[Product2[K, C]] = { + val (blocksByAddress1, blocksByAddress2) = blocksByAddress.duplicate + val shuffleClient = new UcxShuffleClient(transport, Random.shuffle(blocksByAddress1)) + val shuffleIterator = new ShuffleBlockFetcherIterator( + context, + shuffleClient, + blockManager, + Random.shuffle(blocksByAddress2), + serializerManager.wrapStream, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, + SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), + readMetrics, + fetchContinuousBlocksInBatch) + val wrappedStreams = shuffleIterator.toCompletionIterator + + // Ucx shuffle logic + // Java reflection to get access to private results queue + val queueField = classOf[ShuffleBlockFetcherIterator].getDeclaredField( + "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results") + queueField.setAccessible(true) + val resultQueue = queueField.get(shuffleIterator).asInstanceOf[LinkedBlockingQueue[_]] + + // Do progress if queue is empty before calling next on ShuffleIterator + val ucxWrappedStream = new Iterator[(BlockId, InputStream)] { + override def next(): (BlockId, InputStream) = { + val startTime = System.currentTimeMillis() + while (resultQueue.isEmpty) { + transport.progress() + } + readMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime) + wrappedStreams.next() + } + + override def hasNext: Boolean = { + val result = wrappedStreams.hasNext + if (!result) { + shuffleClient.close() + } + result + } + } + // End of ucx shuffle logic + + val serializerInstance = dep.serializer.newInstance() + + // Create a key/value iterator for each stream + val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } + + // Update the context task metrics for each record read. + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map { record => + readMetrics.incRecordsRead(1) + record + }, + context.taskMetrics().mergeShuffleReadMetrics()) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) + + val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { + if (dep.mapSideCombine) { + // We are reading values that are already combined + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) + } else { + // We don't know the value type, but also don't care -- the dependency *should* + // have made sure its compatible w/ this aggregator, which will convert the value + // type to the combined type C + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) + } + } else { + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] + } + + // Sort the output if there is a sort ordering defined. + val resultIter = dep.keyOrdering match { + case Some(keyOrd: Ordering[K]) => + // Create an ExternalSorter to sort the data. + val sorter = + new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer) + sorter.insertAll(aggregatedIter) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) + // Use completion callback to stop sorter if task was finished/cancelled. + context.addTaskCompletionListener[Unit](_ => { + sorter.stop() + }) + CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) + case None => + aggregatedIter + } + + resultIter match { + case _: InterruptibleIterator[Product2[K, C]] => resultIter + case _ => + // Use another interruptible iterator here to support task cancellation as aggregator + // or(and) sorter may have consumed previous interruptible iterator. + new InterruptibleIterator[Product2[K, C]](context, resultIter) + } + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index 10aab2f2..495afe4b 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -3,6 +3,7 @@ * See file LICENSE for terms. */ package org.apache.spark.shuffle.ucx +import java.net.InetSocketAddress import java.nio.ByteBuffer import java.util.concurrent.ConcurrentHashMap @@ -17,6 +18,8 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.shuffle.ucx.memory.{MemoryPool, UcxHostBounceBuffersPool} import org.apache.spark.shuffle.ucx.rpc.GlobalWorkerRpcThread +import org.apache.spark.shuffle.ucx.utils.{SerializationUtils, UcxHelperUtils} +import org.apache.spark.util.Utils /** @@ -59,7 +62,7 @@ class UcxRequest(request: UcpRequest, stats: OperationStats) extends Request { /** * UCX implementation of [[ ShuffleTransport ]] API */ -class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executorId: String) +class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executorId: String) extends ShuffleTransport with Logging { // UCX entities @@ -77,69 +80,98 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo // Mapping between executorId and it's address private[ucx] val executorIdToAddress = new ConcurrentHashMap[String, ByteBuffer]() + private[ucx] val executorIdToSockAddress = new ConcurrentHashMap[String, InetSocketAddress]() private[ucx] val clientConnections = mutable.Map.empty[String, UcpEndpoint] // Need host ucx bounce buffer memory pool to send fetchBlockByBlockId request var memoryPool: MemoryPool = _ + @volatile private var initialized: Boolean = false + + private var workerAddress: ByteBuffer = _ + /** * Initialize transport resources. This function should get called after ensuring that SparkConf * has the correct configurations since it will use the spark configuration to configure itself. */ - override def init(): ByteBuffer = { - if (ucxShuffleConf == null) { - ucxShuffleConf = new UcxShuffleConf(SparkEnv.get.conf) - } + override def init(): ByteBuffer = this.synchronized { + if (!initialized) { + if (ucxShuffleConf == null) { + ucxShuffleConf = new UcxShuffleConf(SparkEnv.get.conf) + } - if (ucxShuffleConf.useOdp) { - memMapParams.nonBlocking() - } + if (ucxShuffleConf.useOdp) { + memMapParams.nonBlocking() + } - val params = new UcpParams().requestTagFeature().requestWakeupFeature() - if (ucxShuffleConf.protocol == "one-sided") { - params.requestRmaFeature() - } - ucxContext = new UcpContext(params) - globalWorker = ucxContext.newWorker(new UcpWorkerParams().requestWakeupTagRecv() - .requestWakeupTagSend()) - - val result = globalWorker.getAddress - require(result.capacity <= ucxShuffleConf.maxWorkerAddressSize, - s"${ucxShuffleConf.WORKER_ADDRESS_SIZE.key} < ${result.capacity}") - - memoryPool = new UcxHostBounceBuffersPool(ucxShuffleConf, ucxContext) - progressThread = new GlobalWorkerRpcThread(globalWorker, memoryPool, this) - - threadLocalWorker = ThreadLocal.withInitial(() => { - val localWorker = ucxContext.newWorker(ucpWorkerParams) - val workerWrapper = new UcxWorkerWrapper(localWorker, this, ucxShuffleConf, memoryPool) - allocatedWorkers.add(workerWrapper) - workerWrapper - }) + val params = new UcpParams().requestTagFeature() + + if (ucxShuffleConf.useWakeup) { + params.requestWakeupFeature() + } + + if (ucxShuffleConf.protocol == ucxShuffleConf.PROTOCOL.ONE_SIDED) { + params.requestRmaFeature() + } + ucxContext = new UcpContext(params) + + val workerParams = new UcpWorkerParams() + + if (ucxShuffleConf.useWakeup) { + workerParams.requestWakeupTagRecv().requestWakeupTagSend() + } + globalWorker = ucxContext.newWorker(workerParams) + + workerAddress = if (ucxShuffleConf.useSockAddr) { + val listener = UcxHelperUtils.startListenerOnRandomPort(globalWorker, ucxShuffleConf.conf) + val buffer = SerializationUtils.serializeInetAddress(listener.getAddress) + buffer + } else { + val workerAddress = globalWorker.getAddress + require(workerAddress.capacity <= ucxShuffleConf.maxWorkerAddressSize, + s"${ucxShuffleConf.WORKER_ADDRESS_SIZE.key} < ${workerAddress.capacity}") + workerAddress + } + + memoryPool = new UcxHostBounceBuffersPool(ucxShuffleConf, ucxContext) + progressThread = new GlobalWorkerRpcThread(globalWorker, memoryPool, this) + + threadLocalWorker = ThreadLocal.withInitial(() => { + val localWorker = ucxContext.newWorker(ucpWorkerParams) + val workerWrapper = new UcxWorkerWrapper(localWorker, this, ucxShuffleConf, memoryPool) + allocatedWorkers.add(workerWrapper) + workerWrapper + }) - progressThread.start() - result + progressThread.start() + initialized = true + } + workerAddress } /** * Close all transport resources */ override def close(): Unit = { - progressThread.interrupt() - globalWorker.signal() - try { - progressThread.join() - } catch { - case _:InterruptedException => - case e:Throwable => logWarning(e.getLocalizedMessage) - } + if (initialized) { + progressThread.interrupt() + if (ucxShuffleConf.useWakeup) { + globalWorker.signal() + } + try { + progressThread.join() + } catch { + case _:InterruptedException => + case e:Throwable => logWarning(e.getLocalizedMessage) + } - memoryPool.close() - clientConnections.values.foreach(ep => ep.close()) - registeredBlocks.forEachKey(1, blockId => unregister(blockId)) - allocatedWorkers.forEach(_.close()) - globalWorker.close() - ucxContext.close() + memoryPool.close() + clientConnections.values.foreach(ep => ep.close()) + registeredBlocks.forEachKey(100, blockId => unregister(blockId)) + allocatedWorkers.forEach(_.close()) + globalWorker.close() + ucxContext.close() + } } /** @@ -147,23 +179,29 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo * connection establishment outside of UcxShuffleManager. */ def addExecutor(executorId: String, workerAddress: ByteBuffer): Unit = { - executorIdToAddress.put(executorId, workerAddress) + if (ucxShuffleConf.useSockAddr) { + executorIdToSockAddress.put(executorId, SerializationUtils.deserializeInetAddress(workerAddress)) + } else { + executorIdToAddress.put(executorId, workerAddress) + } + allocatedWorkers.forEach(w => w.getConnection(executorId)) } private[ucx] def handlePrefetchRequest(workerId: String, workerAddress: ByteBuffer, blockIds: Seq[BlockId]) { - logInfo(s"Prefetching blocks: ${blockIds.mkString(",")}") + val startTime = System.nanoTime() clientConnections.getOrElseUpdate(workerId, globalWorker.newEndpoint(new UcpEndpointParams().setUcpAddress(workerAddress)) ) - blockIds.foreach(blockId => { + blockIds.par.foreach(blockId => { val block = registeredBlocks.get(blockId) if (!block.isInstanceOf[UcxPinnedBlock]) { registeredBlocks.put(blockId, UcxPinnedBlock(block, pinMemory(block), prefetched = true)) } }) + logInfo(s"Prefetched ${blockIds.length} for $workerId in ${Utils.getUsedTimeNs(startTime)}") } /** @@ -171,9 +209,17 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo */ private[ucx] def replyFetchBlockRequest(workerId: String, workerAddress: ByteBuffer, blockId: BlockId, tag: Long): Unit = { + val ep = clientConnections.getOrElseUpdate(workerId, { + val epParams = new UcpEndpointParams() + if (ucxShuffleConf.useSockAddr) { + epParams.setPeerErrorHandlingMode().setSocketAddress( + SerializationUtils.deserializeInetAddress(workerAddress)) + } else { + epParams.setUcpAddress(workerAddress) + } + globalWorker.newEndpoint(epParams) + } - val ep = clientConnections.getOrElseUpdate(workerId, - globalWorker.newEndpoint(new UcpEndpointParams().setUcpAddress(workerAddress)) ) val block = registeredBlocks.get(blockId) @@ -184,9 +230,12 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo lock.lock() val blockMemory = block.getMemoryBlock - logInfo(s"Sending $blockId of size ${blockMemory.size} to tag: $tag") + logInfo(s"Sending $blockId of size ${blockMemory.size} to $workerId tag: $tag") ep.sendTaggedNonBlocking(blockMemory.address, blockMemory.size, tag, new UcxCallback { + private val startTime = System.nanoTime() override def onSuccess(request: UcpRequest): Unit = { + logInfo(s"Sent $blockId of size ${blockMemory.size} to $workerId " + + s"tag: $tag in ${Utils.getUsedTimeNs(startTime)}") if (block.isInstanceOf[UcxPinnedBlock]) { val pinnedBlock = block.asInstanceOf[UcxPinnedBlock] if (pinnedBlock.prefetched) { @@ -214,6 +263,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo * Registers blocks using blockId on SERVER side. */ override def register(blockId: BlockId, block: Block): Unit = { + logTrace(s"Registering $blockId") val registeredBock: Block = if (ucxShuffleConf.pinMemory) { UcxPinnedBlock(block, pinMemory(block)) } else { diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index 2b55eb06..e137a313 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -5,18 +5,20 @@ package org.apache.spark.shuffle.ucx import java.io.{Closeable, ObjectOutputStream} +import java.net.InetSocketAddress +import java.nio.{BufferOverflowException, ByteBuffer} import java.util.concurrent.ThreadLocalRandom import scala.collection.mutable import com.fasterxml.jackson.databind.util.ByteBufferBackedOutputStream -import org.openucx.jucx.ucp.{UcpEndpoint, UcpEndpointParams, UcpRequest, UcpWorker} +import org.openucx.jucx.ucp.{UcpEndpoint, UcpEndpointParams, UcpListener, UcpRequest, UcpWorker} import org.openucx.jucx.{UcxCallback, UcxException, UcxUtils} import org.apache.spark.internal.Logging import org.apache.spark.shuffle.ucx.memory.MemoryPool import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{FetchBlockByBlockIdRequest, FetchBlocksByBlockIdsRequest, PrefetchBlockIds} -import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer +import org.apache.spark.shuffle.ucx.utils.{SerializableDirectBuffer, SerializationUtils, UcxHelperUtils} import org.apache.spark.util.Utils /** @@ -50,13 +52,25 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon // Would not need when migrate to active messages. private val id: String = transport.executorId + s"_${Thread.currentThread().getId}" private final val connections = mutable.Map.empty[String, UcpEndpoint] - private val workerAddress = worker.getAddress + private val listener: Option[UcpListener] = if (ucxConf.useSockAddr) { + Some(UcxHelperUtils.startListenerOnRandomPort(worker, ucxConf.conf)) + } else { + None + } + + private val workerAddress = if (ucxConf.useSockAddr) { + SerializationUtils.serializeInetAddress(listener.get.getAddress) + } else { + worker.getAddress + } + override def close(): Unit = { connections.foreach{ case (_, endpoint) => endpoint.close() } connections.clear() + listener.map(_.close()) worker.close() } @@ -70,16 +84,20 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon } } - private def getConnection(executorId: String): UcpEndpoint = { - val workerAdresses = transport.executorIdToAddress + private[ucx] def getConnection(executorId: String): UcpEndpoint = { + val workerAddresses = if (ucxConf.useSockAddr) { + transport.executorIdToSockAddress + } else { + transport.executorIdToAddress + } - if (!workerAdresses.contains(executorId)) { + if (!workerAddresses.contains(executorId)) { // Block until there's no worker address for this BlockManagerID val startTime = System.currentTimeMillis() val timeout = ucxConf.conf.getTimeAsMs("spark.network.timeout", "100") - workerAdresses.synchronized { - while (workerAdresses.get(executorId) == null) { - workerAdresses.wait(timeout) + workerAddresses.synchronized { + while (workerAddresses.get(executorId) == null) { + workerAddresses.wait(timeout) if (System.currentTimeMillis() - startTime > timeout) { throw new UcxException(s"Didn't get worker address for $executorId during $timeout") } @@ -90,13 +108,20 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon connections.getOrElseUpdate(executorId, { logInfo(s"Worker from thread ${Thread.currentThread().getName} connecting to $executorId") val endpointParams = new UcpEndpointParams() - .setUcpAddress(workerAdresses.get(executorId)) + if (ucxConf.useSockAddr) { + val sockAddr = workerAddresses.get(executorId).asInstanceOf[InetSocketAddress] + logInfo(s"Connecting worker to $executorId at $sockAddr") + endpointParams.setPeerErrorHandlingMode().setSocketAddress(sockAddr) + } else { + endpointParams.setUcpAddress(workerAddresses.get(executorId).asInstanceOf[ByteBuffer]) + } + worker.newEndpoint(endpointParams) }) } private[ucx] def prefetchBlocks(executorId: String, blockIds: Seq[BlockId]): Unit = { - logInfo(s"Sending prefetch ${blockIds.length} blocks to $executorId") + logDebug(s"Sending prefetch ${blockIds.length} blocks to $executorId") val mem = memoryPool.get(transport.ucxShuffleConf.rpcMessageSize) val buffer = UcxUtils.getByteBufferView(mem.address, transport.ucxShuffleConf.rpcMessageSize.toInt) @@ -106,7 +131,15 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon Utils.tryWithResource(new ByteBufferBackedOutputStream(buffer)) { bos => val out = new ObjectOutputStream(bos) - out.writeObject(message) + try { + out.writeObject(message) + } catch { + case _: BufferOverflowException => + throw new UcxException(s"Prefetch blocks message size > " + + s"${transport.ucxShuffleConf.RPC_MESSAGE_SIZE.key}:${transport.ucxShuffleConf.rpcMessageSize}") + case ex: Exception => throw new UcxException(ex.getMessage) + } + out.flush() out.close() } @@ -142,16 +175,10 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon } val tag = ThreadLocalRandom.current().nextLong(Long.MinValue, 0) - logInfo(s"Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag") - ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize, tag, - new UcxCallback() { - override def onSuccess(request: UcpRequest): Unit = { - memoryPool.put(mem) - } - }) val requests = new Array[UcxRequest](blockIds.length) for (i <- blockIds.indices) { + logInfo(s"Receiving block ${blockIds(i)}") val stats = new UcxStats() val result = new UcxSuccessOperationResult(stats) val request = worker.recvTaggedNonBlocking(resultBuffer(i).address, resultBuffer(i).size, @@ -168,6 +195,8 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon override def onSuccess(request: UcpRequest): Unit = { stats.endTime = System.nanoTime() stats.receiveSize = request.getRecvSize + logInfo(s"Received block ${blockIds(i)} from ${executorId} " + + s"of size: ${stats.receiveSize} in ${Utils.getUsedTimeNs(stats.startTime)}") if (callbacks(i) != null) { callbacks(i).onComplete(result) } @@ -175,6 +204,14 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon }) requests(i) = new UcxRequest(request, stats) } + + logInfo(s"Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag") + ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize, tag, + new UcxCallback() { + override def onSuccess(request: UcpRequest): Unit = { + memoryPool.put(mem) + } + }) requests } @@ -198,16 +235,7 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon out.close() } - - logTrace(s"Sending message to $executorId to fetch $blockId on tag $tag," + - s"resultBuffer $resultBuffer") - ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize, tag, - new UcxCallback() { - override def onSuccess(request: UcpRequest): Unit = { - memoryPool.put(mem) - } - }) - + // To avoid unexpected messages, first posting recv val result = new UcxSuccessOperationResult(stats) val request = worker.recvTaggedNonBlocking(resultBuffer.address, resultBuffer.size, tag, -1L, new UcxCallback () { @@ -223,12 +251,24 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon override def onSuccess(request: UcpRequest): Unit = { stats.endTime = System.nanoTime() stats.receiveSize = request.getRecvSize + logInfo(s"Received block ${blockId} " + + s"of size: ${stats.receiveSize} in ${Utils.getUsedTimeNs(stats.startTime)}") if (cb != null) { cb.onComplete(result) } } }) - new UcxRequest(request, stats) + val recvRequest = new UcxRequest(request, stats) + + logInfo(s"Sending message to $executorId to fetch $blockId on tag $tag," + + s"resultBuffer $resultBuffer") + ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize, tag, + new UcxCallback() { + override def onSuccess(request: UcpRequest): Unit = { + memoryPool.put(mem) + } + }) + recvRequest } } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleExecutorComponents.scala b/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleExecutorComponents.scala new file mode 100755 index 00000000..b8e1f6ee --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleExecutorComponents.scala @@ -0,0 +1,47 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.ucx.io + +import java.util +import java.util.Optional + +import org.apache.spark.internal.Logging +import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, SingleSpillShuffleMapOutputWriter} +import org.apache.spark.shuffle.sort.io.{LocalDiskShuffleExecutorComponents, LocalDiskShuffleMapOutputWriter, LocalDiskSingleSpillMapOutputWriter} +import org.apache.spark.shuffle.ucx.{UcxShuffleBlockResolver, UcxShuffleManager, UcxShuffleTransport} +import org.apache.spark.{SparkConf, SparkEnv} + + +class UcxShuffleExecutorComponents(sparkConf: SparkConf) + extends LocalDiskShuffleExecutorComponents(sparkConf) with Logging { + + var ucxShuffleTransport: UcxShuffleTransport = _ + private var blockResolver: UcxShuffleBlockResolver = _ + + override def initializeExecutor(appId: String, execId: String, extraConfigs: util.Map[String, String]): Unit = { + val ucxShuffleManager = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager] + ucxShuffleManager.initTransport() + blockResolver = ucxShuffleManager.shuffleBlockResolver + } + + override def createMapOutputWriter(shuffleId: Int, mapTaskId: Long, numPartitions: Int): ShuffleMapOutputWriter = { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting writers.") + } + new LocalDiskShuffleMapOutputWriter( + shuffleId, mapTaskId, numPartitions, blockResolver, sparkConf) + } + + override def createSingleFileMapOutputWriter(shuffleId: Int, mapId: Long): + Optional[SingleSpillShuffleMapOutputWriter] = { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting writers.") + } + Optional.of(new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver)) + } + +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleIO.scala b/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleIO.scala new file mode 100755 index 00000000..04f03b44 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleIO.scala @@ -0,0 +1,26 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.ucx.io + +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD +import org.apache.spark.shuffle.api.{ShuffleDriverComponents, ShuffleExecutorComponents} +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO +import org.apache.spark.shuffle.ucx.UcxShuffleManager + +class UcxShuffleIO(sparkConf: SparkConf) extends LocalDiskShuffleDataIO(sparkConf) with Logging { + + sparkConf.set(SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key, "100000") + + override def driver(): ShuffleDriverComponents = { + SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager].initTransport() + super.driver() + } + + override def executor(): ShuffleExecutorComponents = { + new UcxShuffleExecutorComponents(sparkConf) + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala index c65dba23..17357907 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala @@ -21,7 +21,6 @@ class GlobalWorkerRpcThread(globalWorker: UcpWorker, memPool: MemoryPool, setDaemon(true) setName("Ucx Shuffle Transport Progress Thread") - override def run(): Unit = { val numRecvs = transport.ucxShuffleConf.recvQueueSize val msgSize = transport.ucxShuffleConf.rpcMessageSize @@ -36,9 +35,10 @@ class GlobalWorkerRpcThread(globalWorker: UcpWorker, memPool: MemoryPool, } while (!isInterrupted) { - if (globalWorker.progress() == 0) { - globalWorker.waitForEvents() - globalWorker.progress() + while (globalWorker.progress() == 0) { + if (transport.ucxShuffleConf.useWakeup) { + globalWorker.waitForEvents() + } } for (i <- 0 until numRecvs) { if (requests(i).isCompleted) { diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxDriverRpcEndpoint.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxDriverRpcEndpoint.scala new file mode 100755 index 00000000..50a9e2e2 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxDriverRpcEndpoint.scala @@ -0,0 +1,42 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.ucx.rpc + +import scala.collection.immutable.HashMap +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.rpc._ +import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{ExecutorAdded, IntroduceAllExecutors} +import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer + +class UcxDriverRpcEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + + private val endpoints: mutable.Set[RpcEndpointRef] = mutable.HashSet.empty + private var blockManagerToWorkerAddress = HashMap.empty[String, SerializableDirectBuffer] + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case message@ExecutorAdded(executorId: String, endpoint: RpcEndpointRef, + ucxWorkerAddress: SerializableDirectBuffer) => { + // Driver receives a message from executor with it's workerAddress + // 1. Introduce existing members of a cluster + logInfo(s"Received $message") + if (blockManagerToWorkerAddress.nonEmpty) { + val msg = IntroduceAllExecutors(blockManagerToWorkerAddress.keys.toSeq, + blockManagerToWorkerAddress.values.toList) + logInfo(s"replying $msg to $executorId") + context.reply(msg) + } + blockManagerToWorkerAddress += executorId -> ucxWorkerAddress + // 2. For each existing member introduce newly joined executor. + endpoints.foreach(ep => { + logInfo(s"Sending $message to $ep") + ep.send(message) + }) + logInfo(s"Connecting back to address: ${context.senderAddress}") + endpoints.add(endpoint) + } + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala new file mode 100755 index 00000000..3b8dec20 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala @@ -0,0 +1,31 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.ucx.rpc + +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.shuffle.ucx.UcxShuffleTransport +import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{ExecutorAdded, IntroduceAllExecutors} +import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer + +class UcxExecutorRpcEndpoint(override val rpcEnv: RpcEnv, transport: UcxShuffleTransport) + extends RpcEndpoint with Logging { + + + override def receive: PartialFunction[Any, Unit] = { + case ExecutorAdded(executorId: String, _: RpcEndpointRef, + ucxWorkerAddress: SerializableDirectBuffer) => { + logInfo(s"Received ExecutorAdded($executorId)") + transport.addExecutor(executorId, ucxWorkerAddress.value) + } + case IntroduceAllExecutors(executorIds: Seq[String], + ucxWorkerAddresses: Seq[SerializableDirectBuffer]) => { + logInfo(s"Received IntroduceAllExecutors(${executorIds.mkString(",")})") + executorIds.zip(ucxWorkerAddresses).foreach { + case (executorId, workerAddress) => transport.addExecutor(executorId, workerAddress.value) + } + } + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/utils/SerializableDirectBuffer.scala b/src/main/scala/org/apache/spark/shuffle/ucx/utils/SerializableDirectBuffer.scala index 009d7b08..4d453b38 100644 --- a/src/main/scala/org/apache/spark/shuffle/ucx/utils/SerializableDirectBuffer.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/utils/SerializableDirectBuffer.scala @@ -5,11 +5,12 @@ package org.apache.spark.shuffle.ucx.utils import java.io.{EOFException, ObjectInputStream, ObjectOutputStream} +import java.net.InetSocketAddress import java.nio.ByteBuffer import java.nio.channels.Channels import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} /** * A wrapper around a java.nio.ByteBuffer that is serializable through Java serialization, to make @@ -64,3 +65,28 @@ class DeserializableToExternalMemoryBuffer(@transient var buffer: ByteBuffer)() buffer.rewind() // Allow us to read it later } } + + +object SerializationUtils { + + def deserializeInetAddress(workerAddress: ByteBuffer): InetSocketAddress = { + workerAddress.rewind() + Utils.tryWithResource(new ByteBufferInputStream(workerAddress)) { bin => + val objIn = new ObjectInputStream(bin) + val obj = objIn.readObject().asInstanceOf[InetSocketAddress] + objIn.close() + obj + } + } + + def serializeInetAddress(address: InetSocketAddress): ByteBuffer = { + val hostAddress = new InetSocketAddress(Utils.localCanonicalHostName(), address.getPort) + Utils.tryWithResource(new ByteBufferOutputStream(100)) {bos => + val out = new ObjectOutputStream(bos) + out.writeObject(hostAddress) + out.flush() + out.close() + bos.toByteBuffer + } + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/utils/UcxHelperUtils.scala b/src/main/scala/org/apache/spark/shuffle/ucx/utils/UcxHelperUtils.scala new file mode 100755 index 00000000..ef342c79 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/utils/UcxHelperUtils.scala @@ -0,0 +1,28 @@ +package org.apache.spark.shuffle.ucx.utils + +import java.net.{BindException, InetSocketAddress} + +import scala.util.Random + +import org.openucx.jucx.UcxException +import org.openucx.jucx.ucp.{UcpListener, UcpListenerParams, UcpWorker} +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +object UcxHelperUtils extends Logging{ + def startListenerOnRandomPort(worker: UcpWorker, sparkConf: SparkConf): UcpListener = { + val ucpListenerParams = new UcpListenerParams() + val (listener, _) = Utils.startServiceOnPort(1024 + Random.nextInt(65535 - 1024), (port: Int) => { + ucpListenerParams.setSockAddr(new InetSocketAddress(port)) + val listener = try { + worker.newListener(ucpListenerParams) + } catch { + case ex:UcxException => throw new BindException(ex.getMessage) + } + (listener, listener.getAddress.getPort) + }, sparkConf) + logInfo(s"Started UcxListener on ${listener.getAddress}") + listener + } +}