From ecfedfe56e3485e606ff00fb6344f047d311174b Mon Sep 17 00:00:00 2001 From: Peter Rudenko Date: Sat, 19 Sep 2020 18:13:17 +0300 Subject: [PATCH 1/4] Spark Shuffle manager implementation. --- pom.xml | 2 +- .../shuffle/ucx/UcxShuffleBlockResolver.scala | 101 ++++++++++ .../spark/shuffle/ucx/UcxShuffleClient.scala | 85 +++++++++ .../spark/shuffle/ucx/UcxShuffleConf.scala | 34 +++- .../spark/shuffle/ucx/UcxShuffleManager.scala | 102 +++++++++++ .../spark/shuffle/ucx/UcxShuffleReader.scala | 173 ++++++++++++++++++ .../shuffle/ucx/UcxShuffleTransport.scala | 152 +++++++++------ .../spark/shuffle/ucx/UcxWorkerWrapper.scala | 100 +++++++--- .../ucx/io/UcxShuffleExecutorComponents.scala | 47 +++++ .../spark/shuffle/ucx/io/UcxShuffleIO.scala | 26 +++ .../ucx/rpc/GlobalWorkerRpcThread.scala | 8 +- .../ucx/rpc/UcxDriverRpcEndpoint.scala | 42 +++++ .../ucx/rpc/UcxExecutorRpcEndpoint.scala | 31 ++++ .../ucx/utils/SerializableDirectBuffer.scala | 28 ++- .../shuffle/ucx/utils/UcxHelperUtils.scala | 28 +++ 15 files changed, 865 insertions(+), 94 deletions(-) create mode 100755 src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleBlockResolver.scala create mode 100755 src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala create mode 100755 src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleManager.scala create mode 100755 src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleReader.scala create mode 100755 src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleExecutorComponents.scala create mode 100755 src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleIO.scala create mode 100755 src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxDriverRpcEndpoint.scala create mode 100755 src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala create mode 100755 src/main/scala/org/apache/spark/shuffle/ucx/utils/UcxHelperUtils.scala diff --git a/pom.xml b/pom.xml index d9330461..aafd36e0 100755 --- a/pom.xml +++ b/pom.xml @@ -58,7 +58,7 @@ See file LICENSE for terms. - ${project.artifactId}-${project.version}-for-spark-${spark.version} + ${project.artifactId}-${project.version}-for-spark-3.0 org.apache.maven.plugins diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleBlockResolver.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleBlockResolver.scala new file mode 100755 index 00000000..0d0f32bc --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleBlockResolver.scala @@ -0,0 +1,101 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. 
+*/ +package org.apache.spark.shuffle.ucx + +import java.io.File +import java.nio.{ByteBuffer, MappedByteBuffer} +import java.nio.channels.FileChannel +import java.nio.file.StandardOpenOption +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConverters._ + +import org.openucx.jucx.UcxUtils +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.unsafe.Platform + + +case class UcxShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId { + + def this(shuffleBlockId: ShuffleBlockId) = { + this(shuffleBlockId.shuffleId, shuffleBlockId.mapId, shuffleBlockId.reduceId) + } + + def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId +} + +case class BufferBackedBlock(buffer: ByteBuffer) extends Block { + override def getMemoryBlock: MemoryBlock = MemoryBlock(UcxUtils.getAddress(buffer), buffer.capacity()) +} + +class UcxShuffleBlockResolver(conf: UcxShuffleConf, transport: UcxShuffleTransport) + extends IndexShuffleBlockResolver(conf) { + + type MapId = Long + + private val numPartitionsForMapId = new ConcurrentHashMap[MapId, Int] + + override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Long, + lengths: Array[Long], dataTmp: File): Unit = { + super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp) + val dataFile = getDataFile(shuffleId, mapId) + if (!dataFile.exists()) { + return + } + numPartitionsForMapId.put(mapId, lengths.length) + val fileChannel = FileChannel.open(dataFile.toPath, StandardOpenOption.READ, + StandardOpenOption.WRITE) + val mappedBuffer = fileChannel.map(FileChannel.MapMode.READ_WRITE, 0L, dataFile.length()) + + val baseAddress = UcxUtils.getAddress(mappedBuffer) + fileChannel.close() + + // Register whole map output file as dummy block + transport.register(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.MAP_FILE), + BufferBackedBlock(mappedBuffer)) + + val offsetSize = 8 * (lengths.length + 1) + val indexBuf = Platform.allocateDirectBuffer(offsetSize) + + var offset = 0L + indexBuf.putLong(offset) + for (reduceId <- lengths.indices) { + if (lengths(reduceId) > 0) { + transport.register(UcxShuffleBlockId(shuffleId, mapId, reduceId), new Block { + private val memoryBlock = MemoryBlock(baseAddress + offset, lengths(reduceId)) + override def getMemoryBlock: MemoryBlock = memoryBlock + }) + offset += lengths(reduceId) + indexBuf.putLong(offset) + } + } + + if (transport.ucxShuffleConf.protocol == transport.ucxShuffleConf.PROTOCOL.ONE_SIDED) { + transport.register(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.INDEX_FILE), BufferBackedBlock(indexBuf)) + } + } + + override def removeDataByMap(shuffleId: ShuffleId, mapId: Long): Unit = { + transport.unregister(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.MAP_FILE)) + transport.unregister(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.INDEX_FILE)) + + val numRegisteredBlocks = numPartitionsForMapId.get(mapId) + (0 until numRegisteredBlocks) + .foreach(reduceId => transport.unregister(UcxShuffleBlockId(shuffleId, mapId, reduceId))) + super.removeDataByMap(shuffleId, mapId) + } + + override def stop(): Unit = { + numPartitionsForMapId.keys.asScala.foreach(mapId => removeDataByMap(0, mapId)) + super.stop() + } + +} + +object BlocksConstants { + val MAP_FILE: Int = -1 + val INDEX_FILE: Int = -2 +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala new file mode 
100755 index 00000000..25c00daf --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala @@ -0,0 +1,85 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.ucx + +import java.util.concurrent.TimeUnit + +import org.openucx.jucx.{UcxException, UcxUtils} +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient, DownloadFileManager} +import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId, BlockId => SparkBlockId} + +class UcxShuffleClient(transport: UcxShuffleTransport, + blocksByAddress: Iterator[(BlockManagerId, Seq[(SparkBlockId, Long, Int)])]) + extends BlockStoreClient with Logging { + + private val accurateThreshold = transport.ucxShuffleConf.conf.getSizeAsBytes(SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key) + + private val blockSizes: Map[SparkBlockId, Long] = blocksByAddress + .withFilter { case (blockManagerId, _) => blockManagerId != SparkEnv.get.blockManager.blockManagerId } + .flatMap { + case (blockManagerId, blocks) => + val blockIds = blocks.map { + case (blockId, _, _) => + val sparkBlockId = blockId.asInstanceOf[ShuffleBlockId] + UcxShuffleBlockId(sparkBlockId.shuffleId, sparkBlockId.mapId, sparkBlockId.reduceId) + } + if (!transport.ucxShuffleConf.pinMemory) { + transport.prefetchBlocks(blockManagerId.executorId, blockIds) + } + blocks.map { + case (blockId, length, _) => + if (length > accurateThreshold) { + (blockId, (length * 1.2).toLong) + } else { + (blockId, accurateThreshold) + } + } + }.toMap + + override def fetchBlocks(host: String, port: Int, execId: String, + blockIds: Array[String], listener: BlockFetchingListener, + downloadFileManager: DownloadFileManager): Unit = { + val ucxBlockIds = new Array[BlockId](blockIds.length) + val memoryBlocks = new Array[MemoryBlock](blockIds.length) + val callbacks = new Array[OperationCallback](blockIds.length) + for (i <- blockIds.indices) { + val blockId = SparkBlockId.apply(blockIds(i)).asInstanceOf[ShuffleBlockId] + if (!blockSizes.contains(blockId)) { + throw new UcxException(s"No $blockId found in MapOutput blocks: ${blockSizes.keys.mkString(",")}") + } + val resultMemory = transport.memoryPool.get(blockSizes(blockId)) + ucxBlockIds(i) = UcxShuffleBlockId(blockId.shuffleId, blockId.mapId, blockId.reduceId) + memoryBlocks(i) = MemoryBlock(resultMemory.address, blockSizes(blockId)) + callbacks(i) = (result: OperationResult) => { + if (result.getStatus == OperationStatus.SUCCESS) { + val stats = result.getStats.get + logInfo(s" Received block ${ucxBlockIds(i)} " + + s"of size: ${stats.recvSize} " + + s"in ${TimeUnit.NANOSECONDS.toMillis(stats.getElapsedTimeNs)} ms") + val buffer = UcxUtils.getByteBufferView(resultMemory.address, result.getStats.get.recvSize.toInt) + listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) { + override def release: ManagedBuffer = { + transport.memoryPool.put(resultMemory) + this + } + }) + } else { + logError(s"Error fetching block $blockId of size ${blockSizes(blockId)}:" + + s" ${result.getError.getMessage}") + throw new UcxException(result.getError.getMessage) + } + } + } + transport.fetchBlocksByBlockIds(execId, ucxBlockIds, memoryBlocks, callbacks) + } + + override def close(): Unit = { + + } +} diff --git 
a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala index 47c9fd73..f73f0970 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala @@ -12,15 +12,20 @@ import org.apache.spark.util.Utils class UcxShuffleConf(val conf: SparkConf) extends SparkConf { private def getUcxConf(name: String) = s"spark.shuffle.ucx.$name" - private val PROTOCOL = + object PROTOCOL extends Enumeration { + val ONE_SIDED, RNDV = Value + } + + private lazy val PROTOCOL_CONF = ConfigBuilder(getUcxConf("protocol")) - .doc("Which protocol to use: rndv (default), one-sided") + .doc("Which protocol to use: RNDV (default), ONE-SIDED") .stringConf .checkValue(protocol => protocol == "rndv" || protocol == "one-sided", "Invalid protocol. Valid options: rndv / one-sided.") - .createWithDefault("rndv") + .transform(_.toUpperCase.replace("-", "_")) + .createWithDefault("RNDV") - private val MEMORY_PINNING = + private lazy val MEMORY_PINNING = ConfigBuilder(getUcxConf("memoryPinning")) .doc("Whether to pin whole shuffle data in memory") .booleanConf @@ -30,14 +35,14 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf { ConfigBuilder(getUcxConf("maxWorkerSize")) .doc("Maximum size of worker address in bytes") .bytesConf(ByteUnit.BYTE) - .createWithDefault(1000) + .createWithDefault(1024) lazy val RPC_MESSAGE_SIZE: ConfigEntry[Long] = ConfigBuilder(getUcxConf("rpcMessageSize")) .doc("Size of RPC message to send from fetchBlockByBlockId. Must contain ") .bytesConf(ByteUnit.BYTE) .checkValue(size => size > maxWorkerAddressSize, - "Rpc message must contain workerAddress") + "Rpc message must contain at least workerAddress") .createWithDefault(2000) // Memory Pool @@ -58,6 +63,12 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf { .intConf .createWithDefault(5) + private lazy val USE_SOCKADDR = + ConfigBuilder(getUcxConf("useSockAddr")) + .doc("Whether to use socket address to connect executors.") + .booleanConf + .createWithDefault(true) + private lazy val MIN_REGISTRATION_SIZE = ConfigBuilder(getUcxConf("memory.minAllocationSize")) .doc("Minimal memory registration size in memory pool.") @@ -67,7 +78,14 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf { lazy val minRegistrationSize: Int = conf.getSizeAsBytes(MIN_REGISTRATION_SIZE.key, MIN_REGISTRATION_SIZE.defaultValueString).toInt - lazy val protocol: String = conf.get(PROTOCOL.key, PROTOCOL.defaultValueString) + private lazy val USE_ODP = + ConfigBuilder(getUcxConf("useOdp")) + .doc("Whether to use on demand paging feature, to avoid memory pinning") + .booleanConf + .createWithDefault(false) + + lazy val protocol: PROTOCOL.Value = PROTOCOL.withName( + conf.get(PROTOCOL_CONF.key, PROTOCOL_CONF.defaultValueString)) lazy val useOdp: Boolean = conf.getBoolean(getUcxConf("memory.useOdp"), defaultValue = false) @@ -83,6 +101,8 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf { lazy val recvQueueSize: Int = conf.getInt(RECV_QUEUE_SIZE.key, RECV_QUEUE_SIZE.defaultValue.get) + lazy val useSockAddr: Boolean = conf.getBoolean(USE_SOCKADDR.key, USE_SOCKADDR.defaultValue.get) + lazy val preallocateBuffersMap: Map[Long, Int] = { conf.get(PREALLOCATE_BUFFERS).split(",").withFilter(s => !s.isEmpty) .map(entry => entry.split(":") match { diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleManager.scala 
b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleManager.scala new file mode 100755 index 00000000..3d957b0d --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleManager.scala @@ -0,0 +1,102 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.ucx + +import scala.concurrent.ExecutionContext.Implicits.global +import scala.util.Success + +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch +import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{ExecutorAdded, IntroduceAllExecutors} +import org.apache.spark.shuffle.ucx.rpc.{UcxDriverRpcEndpoint, UcxExecutorRpcEndpoint} +import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer +import org.apache.spark.util.RpcUtils +import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, TaskContext} + + +class UcxShuffleManager(conf: SparkConf, isDriver: Boolean) extends SortShuffleManager(conf) { + + val ucxShuffleConf = new UcxShuffleConf(conf) + + lazy val ucxShuffleTransport: UcxShuffleTransport = if (!isDriver) { + new UcxShuffleTransport(ucxShuffleConf, "init") + } else { + null + } + + @volatile private var initialized: Boolean = false + + override val shuffleBlockResolver = + new UcxShuffleBlockResolver(ucxShuffleConf, ucxShuffleTransport) + + logInfo("Starting UcxShuffleManager") + + def initTransport(): Unit = this.synchronized { + if (!initialized) { + val driverEndpointName = "ucx-shuffle-driver" + if (isDriver) { + val rpcEnv = SparkEnv.get.rpcEnv + val driverEndpoint = new UcxDriverRpcEndpoint(rpcEnv) + rpcEnv.setupEndpoint(driverEndpointName, driverEndpoint) + } else { + val blockManager = SparkEnv.get.blockManager.blockManagerId + ucxShuffleTransport.executorId = blockManager.executorId + val rpcEnv = RpcEnv.create("ucx-rpc-env", blockManager.host, blockManager.host, + blockManager.port, conf, new SecurityManager(conf), 1, clientMode=false) + logDebug("Initializing ucx transport") + val address = ucxShuffleTransport.init() + val executorEndpoint = new UcxExecutorRpcEndpoint(rpcEnv, ucxShuffleTransport) + val endpoint = rpcEnv.setupEndpoint( + s"ucx-shuffle-executor-${blockManager.executorId}", + executorEndpoint) + + val driverEndpoint = RpcUtils.makeDriverRef(driverEndpointName, conf, rpcEnv) + driverEndpoint.ask[IntroduceAllExecutors](ExecutorAdded(blockManager.executorId, + endpoint, new SerializableDirectBuffer(address))) + .andThen{ + case Success(msg) => + logInfo(s"Receive reply $msg") + executorEndpoint.receive(msg) + } + } + initialized = true + } + } + + override def getReader[K, C](handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( + handle.shuffleId, startPartition, endPartition) + new UcxShuffleReader(ucxShuffleTransport, + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics, + shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context)) + } + + override def getReaderForRange[K, C]( handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + val blocksByAddress = 
SparkEnv.get.mapOutputTracker.getMapSizesByRange( + handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) + new UcxShuffleReader(ucxShuffleTransport, + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics, + shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context)) + } + + override def stop(): Unit = { + if (ucxShuffleTransport != null) { + ucxShuffleTransport.close() + } + super.stop() + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleReader.scala new file mode 100755 index 00000000..6ee3e682 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleReader.scala @@ -0,0 +1,173 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle + +import java.io.InputStream +import java.util.concurrent.LinkedBlockingQueue + +import scala.util.Random + +import org.apache.spark._ +import org.apache.spark.internal.{Logging, config} +import org.apache.spark.io.CompressionCodec +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.shuffle.ucx.{UcxShuffleClient, UcxShuffleTransport} +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator} +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.ExternalSorter + +/** + * Fetches and reads the blocks from a shuffle by requesting them from other nodes' block stores. + */ +class UcxShuffleReader[K, C](transport: UcxShuffleTransport, + handle: BaseShuffleHandle[K, _, C], + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], + context: TaskContext, + readMetrics: ShuffleReadMetricsReporter, + serializerManager: SerializerManager = SparkEnv.get.serializerManager, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, + shouldBatchFetch: Boolean = false) + extends ShuffleReader[K, C] with Logging { + + private val dep = handle.dependency + + private def fetchContinuousBlocksInBatch: Boolean = { + val conf = SparkEnv.get.conf + val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects + val compressed = conf.get(config.SHUFFLE_COMPRESS) + val codecConcatenation = if (compressed) { + CompressionCodec.supportsConcatenationOfSerializedStreams(CompressionCodec.createCodec(conf)) + } else { + true + } + val useOldFetchProtocol = conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL) + + val doBatchFetch = shouldBatchFetch && serializerRelocatable && + (!compressed || codecConcatenation) && !useOldFetchProtocol + if (shouldBatchFetch && !doBatchFetch) { + logDebug("The feature tag of continuous shuffle block fetching is set to true, but " + + "we can not enable the feature because other conditions are not satisfied. 
" + + s"Shuffle compress: $compressed, serializer relocatable: $serializerRelocatable, " + + s"codec concatenation: $codecConcatenation, use old shuffle fetch protocol: " + + s"$useOldFetchProtocol.") + } + doBatchFetch + } + + /** Read the combined key-values for this reduce task */ + override def read(): Iterator[Product2[K, C]] = { + val (blocksByAddress1, blocksByAddress2) = blocksByAddress.duplicate + val shuffleClient = new UcxShuffleClient(transport, Random.shuffle(blocksByAddress1)) + val shuffleIterator = new ShuffleBlockFetcherIterator( + context, + shuffleClient, + blockManager, + Random.shuffle(blocksByAddress2), + serializerManager.wrapStream, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, + SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), + readMetrics, + fetchContinuousBlocksInBatch) + val wrappedStreams = shuffleIterator.toCompletionIterator + + // Ucx shuffle logic + // Java reflection to get access to private results queue + val queueField = classOf[ShuffleBlockFetcherIterator].getDeclaredField( + "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results") + queueField.setAccessible(true) + val resultQueue = queueField.get(shuffleIterator).asInstanceOf[LinkedBlockingQueue[_]] + + // Do progress if queue is empty before calling next on ShuffleIterator + val ucxWrappedStream = new Iterator[(BlockId, InputStream)] { + override def next(): (BlockId, InputStream) = { + val startTime = System.currentTimeMillis() + while (resultQueue.isEmpty) { + transport.progress() + } + readMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime) + wrappedStreams.next() + } + + override def hasNext: Boolean = { + val result = wrappedStreams.hasNext + if (!result) { + shuffleClient.close() + } + result + } + } + // End of ucx shuffle logic + + val serializerInstance = dep.serializer.newInstance() + + // Create a key/value iterator for each stream + val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } + + // Update the context task metrics for each record read. 
+ val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map { record => + readMetrics.incRecordsRead(1) + record + }, + context.taskMetrics().mergeShuffleReadMetrics()) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) + + val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { + if (dep.mapSideCombine) { + // We are reading values that are already combined + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) + } else { + // We don't know the value type, but also don't care -- the dependency *should* + // have made sure its compatible w/ this aggregator, which will convert the value + // type to the combined type C + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) + } + } else { + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] + } + + // Sort the output if there is a sort ordering defined. + val resultIter = dep.keyOrdering match { + case Some(keyOrd: Ordering[K]) => + // Create an ExternalSorter to sort the data. + val sorter = + new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer) + sorter.insertAll(aggregatedIter) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) + // Use completion callback to stop sorter if task was finished/cancelled. + context.addTaskCompletionListener[Unit](_ => { + sorter.stop() + }) + CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) + case None => + aggregatedIter + } + + resultIter match { + case _: InterruptibleIterator[Product2[K, C]] => resultIter + case _ => + // Use another interruptible iterator here to support task cancellation as aggregator + // or(and) sorter may have consumed previous interruptible iterator. + new InterruptibleIterator[Product2[K, C]](context, resultIter) + } + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index 10aab2f2..495afe4b 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -3,6 +3,7 @@ * See file LICENSE for terms. 
*/ package org.apache.spark.shuffle.ucx +import java.net.InetSocketAddress import java.nio.ByteBuffer import java.util.concurrent.ConcurrentHashMap @@ -17,6 +18,8 @@ import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.shuffle.ucx.memory.{MemoryPool, UcxHostBounceBuffersPool} import org.apache.spark.shuffle.ucx.rpc.GlobalWorkerRpcThread +import org.apache.spark.shuffle.ucx.utils.{SerializationUtils, UcxHelperUtils} +import org.apache.spark.util.Utils /** @@ -59,7 +62,7 @@ class UcxRequest(request: UcpRequest, stats: OperationStats) extends Request { /** * UCX implementation of [[ ShuffleTransport ]] API */ -class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executorId: String) +class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executorId: String) extends ShuffleTransport with Logging { // UCX entities @@ -77,69 +80,98 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo // Mapping between executorId and it's address private[ucx] val executorIdToAddress = new ConcurrentHashMap[String, ByteBuffer]() + private[ucx] val executorIdToSockAddress = new ConcurrentHashMap[String, InetSocketAddress]() private[ucx] val clientConnections = mutable.Map.empty[String, UcpEndpoint] // Need host ucx bounce buffer memory pool to send fetchBlockByBlockId request var memoryPool: MemoryPool = _ + @volatile private var initialized: Boolean = false + + private var workerAddress: ByteBuffer = _ + /** * Initialize transport resources. This function should get called after ensuring that SparkConf * has the correct configurations since it will use the spark configuration to configure itself. */ - override def init(): ByteBuffer = { - if (ucxShuffleConf == null) { - ucxShuffleConf = new UcxShuffleConf(SparkEnv.get.conf) - } + override def init(): ByteBuffer = this.synchronized { + if (!initialized) { + if (ucxShuffleConf == null) { + ucxShuffleConf = new UcxShuffleConf(SparkEnv.get.conf) + } - if (ucxShuffleConf.useOdp) { - memMapParams.nonBlocking() - } + if (ucxShuffleConf.useOdp) { + memMapParams.nonBlocking() + } - val params = new UcpParams().requestTagFeature().requestWakeupFeature() - if (ucxShuffleConf.protocol == "one-sided") { - params.requestRmaFeature() - } - ucxContext = new UcpContext(params) - globalWorker = ucxContext.newWorker(new UcpWorkerParams().requestWakeupTagRecv() - .requestWakeupTagSend()) - - val result = globalWorker.getAddress - require(result.capacity <= ucxShuffleConf.maxWorkerAddressSize, - s"${ucxShuffleConf.WORKER_ADDRESS_SIZE.key} < ${result.capacity}") - - memoryPool = new UcxHostBounceBuffersPool(ucxShuffleConf, ucxContext) - progressThread = new GlobalWorkerRpcThread(globalWorker, memoryPool, this) - - threadLocalWorker = ThreadLocal.withInitial(() => { - val localWorker = ucxContext.newWorker(ucpWorkerParams) - val workerWrapper = new UcxWorkerWrapper(localWorker, this, ucxShuffleConf, memoryPool) - allocatedWorkers.add(workerWrapper) - workerWrapper - }) + val params = new UcpParams().requestTagFeature() + + if (ucxShuffleConf.useWakeup) { + params.requestWakeupFeature() + } + + if (ucxShuffleConf.protocol == ucxShuffleConf.PROTOCOL.ONE_SIDED) { + params.requestRmaFeature() + } + ucxContext = new UcpContext(params) + + val workerParams = new UcpWorkerParams() + + if (ucxShuffleConf.useWakeup) { + workerParams.requestWakeupTagRecv().requestWakeupTagSend() + } + globalWorker = ucxContext.newWorker(workerParams) + + workerAddress = if (ucxShuffleConf.useSockAddr) 
{ + val listener = UcxHelperUtils.startListenerOnRandomPort(globalWorker, ucxShuffleConf.conf) + val buffer = SerializationUtils.serializeInetAddress(listener.getAddress) + buffer + } else { + val workerAddress = globalWorker.getAddress + require(workerAddress.capacity <= ucxShuffleConf.maxWorkerAddressSize, + s"${ucxShuffleConf.WORKER_ADDRESS_SIZE.key} < ${workerAddress.capacity}") + workerAddress + } + + memoryPool = new UcxHostBounceBuffersPool(ucxShuffleConf, ucxContext) + progressThread = new GlobalWorkerRpcThread(globalWorker, memoryPool, this) + + threadLocalWorker = ThreadLocal.withInitial(() => { + val localWorker = ucxContext.newWorker(ucpWorkerParams) + val workerWrapper = new UcxWorkerWrapper(localWorker, this, ucxShuffleConf, memoryPool) + allocatedWorkers.add(workerWrapper) + workerWrapper + }) - progressThread.start() - result + progressThread.start() + initialized = true + } + workerAddress } /** * Close all transport resources */ override def close(): Unit = { - progressThread.interrupt() - globalWorker.signal() - try { - progressThread.join() - } catch { - case _:InterruptedException => - case e:Throwable => logWarning(e.getLocalizedMessage) - } + if (initialized) { + progressThread.interrupt() + if (ucxShuffleConf.useWakeup) { + globalWorker.signal() + } + try { + progressThread.join() + } catch { + case _:InterruptedException => + case e:Throwable => logWarning(e.getLocalizedMessage) + } - memoryPool.close() - clientConnections.values.foreach(ep => ep.close()) - registeredBlocks.forEachKey(1, blockId => unregister(blockId)) - allocatedWorkers.forEach(_.close()) - globalWorker.close() - ucxContext.close() + memoryPool.close() + clientConnections.values.foreach(ep => ep.close()) + registeredBlocks.forEachKey(100, blockId => unregister(blockId)) + allocatedWorkers.forEach(_.close()) + globalWorker.close() + ucxContext.close() + } } /** @@ -147,23 +179,29 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo * connection establishment outside of UcxShuffleManager. 
*/ def addExecutor(executorId: String, workerAddress: ByteBuffer): Unit = { - executorIdToAddress.put(executorId, workerAddress) + if (ucxShuffleConf.useSockAddr) { + executorIdToSockAddress.put(executorId, SerializationUtils.deserializeInetAddress(workerAddress)) + } else { + executorIdToAddress.put(executorId, workerAddress) + } + allocatedWorkers.forEach(w => w.getConnection(executorId)) } private[ucx] def handlePrefetchRequest(workerId: String, workerAddress: ByteBuffer, blockIds: Seq[BlockId]) { - logInfo(s"Prefetching blocks: ${blockIds.mkString(",")}") + val startTime = System.nanoTime() clientConnections.getOrElseUpdate(workerId, globalWorker.newEndpoint(new UcpEndpointParams().setUcpAddress(workerAddress)) ) - blockIds.foreach(blockId => { + blockIds.par.foreach(blockId => { val block = registeredBlocks.get(blockId) if (!block.isInstanceOf[UcxPinnedBlock]) { registeredBlocks.put(blockId, UcxPinnedBlock(block, pinMemory(block), prefetched = true)) } }) + logInfo(s"Prefetched ${blockIds.length} blocks for $workerId in ${Utils.getUsedTimeNs(startTime)}") } /** @@ -171,9 +209,17 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo */ private[ucx] def replyFetchBlockRequest(workerId: String, workerAddress: ByteBuffer, blockId: BlockId, tag: Long): Unit = { + val ep = clientConnections.getOrElseUpdate(workerId, { + val epParams = new UcpEndpointParams() + if (ucxShuffleConf.useSockAddr) { + epParams.setPeerErrorHandlingMode().setSocketAddress( + SerializationUtils.deserializeInetAddress(workerAddress)) + } else { + epParams.setUcpAddress(workerAddress) + } + globalWorker.newEndpoint(epParams) + } - val ep = clientConnections.getOrElseUpdate(workerId, - globalWorker.newEndpoint(new UcpEndpointParams().setUcpAddress(workerAddress)) ) val block = registeredBlocks.get(blockId) @@ -184,9 +230,12 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo lock.lock() val blockMemory = block.getMemoryBlock - logInfo(s"Sending $blockId of size ${blockMemory.size} to tag: $tag") + logInfo(s"Sending $blockId of size ${blockMemory.size} to $workerId tag: $tag") ep.sendTaggedNonBlocking(blockMemory.address, blockMemory.size, tag, new UcxCallback { + private val startTime = System.nanoTime() override def onSuccess(request: UcpRequest): Unit = { + logInfo(s"Sent $blockId of size ${blockMemory.size} to $workerId " + + s"tag: $tag in ${Utils.getUsedTimeNs(startTime)}") if (block.isInstanceOf[UcxPinnedBlock]) { val pinnedBlock = block.asInstanceOf[UcxPinnedBlock] if (pinnedBlock.prefetched) { @@ -214,6 +263,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo * Registers blocks using blockId on SERVER side. 
*/ override def register(blockId: BlockId, block: Block): Unit = { + logTrace(s"Registering $blockId") val registeredBock: Block = if (ucxShuffleConf.pinMemory) { UcxPinnedBlock(block, pinMemory(block)) } else { diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index 2b55eb06..e137a313 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -5,18 +5,20 @@ package org.apache.spark.shuffle.ucx import java.io.{Closeable, ObjectOutputStream} +import java.net.InetSocketAddress +import java.nio.{BufferOverflowException, ByteBuffer} import java.util.concurrent.ThreadLocalRandom import scala.collection.mutable import com.fasterxml.jackson.databind.util.ByteBufferBackedOutputStream -import org.openucx.jucx.ucp.{UcpEndpoint, UcpEndpointParams, UcpRequest, UcpWorker} +import org.openucx.jucx.ucp.{UcpEndpoint, UcpEndpointParams, UcpListener, UcpRequest, UcpWorker} import org.openucx.jucx.{UcxCallback, UcxException, UcxUtils} import org.apache.spark.internal.Logging import org.apache.spark.shuffle.ucx.memory.MemoryPool import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{FetchBlockByBlockIdRequest, FetchBlocksByBlockIdsRequest, PrefetchBlockIds} -import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer +import org.apache.spark.shuffle.ucx.utils.{SerializableDirectBuffer, SerializationUtils, UcxHelperUtils} import org.apache.spark.util.Utils /** @@ -50,13 +52,25 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon // Will not be needed once we migrate to active messages. private val id: String = transport.executorId + s"_${Thread.currentThread().getId}" private final val connections = mutable.Map.empty[String, UcpEndpoint] - private val workerAddress = worker.getAddress + private val listener: Option[UcpListener] = if (ucxConf.useSockAddr) { + Some(UcxHelperUtils.startListenerOnRandomPort(worker, ucxConf.conf)) + } else { + None + } + + private val workerAddress = if (ucxConf.useSockAddr) { + SerializationUtils.serializeInetAddress(listener.get.getAddress) + } else { + worker.getAddress + } + override def close(): Unit = { connections.foreach{ case (_, endpoint) => endpoint.close() } connections.clear() + listener.map(_.close()) worker.close() } @@ -70,16 +84,20 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon } } - private def getConnection(executorId: String): UcpEndpoint = { - val workerAdresses = transport.executorIdToAddress + private[ucx] def getConnection(executorId: String): UcpEndpoint = { + val workerAddresses = if (ucxConf.useSockAddr) { + transport.executorIdToSockAddress + } else { + transport.executorIdToAddress + } - if (!workerAdresses.contains(executorId)) { + if (!workerAddresses.contains(executorId)) { // Block until a worker address for this BlockManagerId becomes available val startTime = System.currentTimeMillis() val timeout = ucxConf.conf.getTimeAsMs("spark.network.timeout", "100") - workerAdresses.synchronized { - while (workerAdresses.get(executorId) == null) { - workerAdresses.wait(timeout) + workerAddresses.synchronized { + while (workerAddresses.get(executorId) == null) { + workerAddresses.wait(timeout) if (System.currentTimeMillis() - startTime > timeout) { throw new UcxException(s"Didn't get worker address for $executorId within $timeout ms") } @@ -90,13 +108,20 @@ 
class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon connections.getOrElseUpdate(executorId, { logInfo(s"Worker from thread ${Thread.currentThread().getName} connecting to $executorId") val endpointParams = new UcpEndpointParams() - .setUcpAddress(workerAdresses.get(executorId)) + if (ucxConf.useSockAddr) { + val sockAddr = workerAddresses.get(executorId).asInstanceOf[InetSocketAddress] + logInfo(s"Connecting worker to $executorId at $sockAddr") + endpointParams.setPeerErrorHandlingMode().setSocketAddress(sockAddr) + } else { + endpointParams.setUcpAddress(workerAddresses.get(executorId).asInstanceOf[ByteBuffer]) + } + worker.newEndpoint(endpointParams) }) } private[ucx] def prefetchBlocks(executorId: String, blockIds: Seq[BlockId]): Unit = { - logInfo(s"Sending prefetch ${blockIds.length} blocks to $executorId") + logDebug(s"Sending prefetch ${blockIds.length} blocks to $executorId") val mem = memoryPool.get(transport.ucxShuffleConf.rpcMessageSize) val buffer = UcxUtils.getByteBufferView(mem.address, transport.ucxShuffleConf.rpcMessageSize.toInt) @@ -106,7 +131,15 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon Utils.tryWithResource(new ByteBufferBackedOutputStream(buffer)) { bos => val out = new ObjectOutputStream(bos) - out.writeObject(message) + try { + out.writeObject(message) + } catch { + case _: BufferOverflowException => + throw new UcxException(s"Prefetch blocks message size > " + + s"${transport.ucxShuffleConf.RPC_MESSAGE_SIZE.key}:${transport.ucxShuffleConf.rpcMessageSize}") + case ex: Exception => throw new UcxException(ex.getMessage) + } + out.flush() out.close() } @@ -142,16 +175,10 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon } val tag = ThreadLocalRandom.current().nextLong(Long.MinValue, 0) - logInfo(s"Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag") - ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize, tag, - new UcxCallback() { - override def onSuccess(request: UcpRequest): Unit = { - memoryPool.put(mem) - } - }) val requests = new Array[UcxRequest](blockIds.length) for (i <- blockIds.indices) { + logInfo(s"Receiving block ${blockIds(i)}") val stats = new UcxStats() val result = new UcxSuccessOperationResult(stats) val request = worker.recvTaggedNonBlocking(resultBuffer(i).address, resultBuffer(i).size, @@ -168,6 +195,8 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon override def onSuccess(request: UcpRequest): Unit = { stats.endTime = System.nanoTime() stats.receiveSize = request.getRecvSize + logInfo(s"Received block ${blockIds(i)} from ${executorId} " + + s"of size: ${stats.receiveSize} in ${Utils.getUsedTimeNs(stats.startTime)}") if (callbacks(i) != null) { callbacks(i).onComplete(result) } @@ -175,6 +204,14 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon }) requests(i) = new UcxRequest(request, stats) } + + logInfo(s"Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag") + ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize, tag, + new UcxCallback() { + override def onSuccess(request: UcpRequest): Unit = { + memoryPool.put(mem) + } + }) requests } @@ -198,16 +235,7 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon out.close() } - - logTrace(s"Sending message to $executorId to fetch $blockId on tag $tag," + - s"resultBuffer $resultBuffer") - 
ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize, tag, - new UcxCallback() { - override def onSuccess(request: UcpRequest): Unit = { - memoryPool.put(mem) - } - }) - + // To avoid unexpected messages, first posting recv val result = new UcxSuccessOperationResult(stats) val request = worker.recvTaggedNonBlocking(resultBuffer.address, resultBuffer.size, tag, -1L, new UcxCallback () { @@ -223,12 +251,24 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon override def onSuccess(request: UcpRequest): Unit = { stats.endTime = System.nanoTime() stats.receiveSize = request.getRecvSize + logInfo(s"Received block ${blockId} " + + s"of size: ${stats.receiveSize} in ${Utils.getUsedTimeNs(stats.startTime)}") if (cb != null) { cb.onComplete(result) } } }) - new UcxRequest(request, stats) + val recvRequest = new UcxRequest(request, stats) + + logInfo(s"Sending message to $executorId to fetch $blockId on tag $tag," + + s"resultBuffer $resultBuffer") + ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize, tag, + new UcxCallback() { + override def onSuccess(request: UcpRequest): Unit = { + memoryPool.put(mem) + } + }) + recvRequest } } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleExecutorComponents.scala b/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleExecutorComponents.scala new file mode 100755 index 00000000..b8e1f6ee --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleExecutorComponents.scala @@ -0,0 +1,47 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.ucx.io + +import java.util +import java.util.Optional + +import org.apache.spark.internal.Logging +import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, SingleSpillShuffleMapOutputWriter} +import org.apache.spark.shuffle.sort.io.{LocalDiskShuffleExecutorComponents, LocalDiskShuffleMapOutputWriter, LocalDiskSingleSpillMapOutputWriter} +import org.apache.spark.shuffle.ucx.{UcxShuffleBlockResolver, UcxShuffleManager, UcxShuffleTransport} +import org.apache.spark.{SparkConf, SparkEnv} + + +class UcxShuffleExecutorComponents(sparkConf: SparkConf) + extends LocalDiskShuffleExecutorComponents(sparkConf) with Logging { + + var ucxShuffleTransport: UcxShuffleTransport = _ + private var blockResolver: UcxShuffleBlockResolver = _ + + override def initializeExecutor(appId: String, execId: String, extraConfigs: util.Map[String, String]): Unit = { + val ucxShuffleManager = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager] + ucxShuffleManager.initTransport() + blockResolver = ucxShuffleManager.shuffleBlockResolver + } + + override def createMapOutputWriter(shuffleId: Int, mapTaskId: Long, numPartitions: Int): ShuffleMapOutputWriter = { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting writers.") + } + new LocalDiskShuffleMapOutputWriter( + shuffleId, mapTaskId, numPartitions, blockResolver, sparkConf) + } + + override def createSingleFileMapOutputWriter(shuffleId: Int, mapId: Long): + Optional[SingleSpillShuffleMapOutputWriter] = { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting writers.") + } + Optional.of(new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver)) + } + +} diff --git 
a/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleIO.scala b/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleIO.scala new file mode 100755 index 00000000..04f03b44 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleIO.scala @@ -0,0 +1,26 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.ucx.io + +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD +import org.apache.spark.shuffle.api.{ShuffleDriverComponents, ShuffleExecutorComponents} +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO +import org.apache.spark.shuffle.ucx.UcxShuffleManager + +class UcxShuffleIO(sparkConf: SparkConf) extends LocalDiskShuffleDataIO(sparkConf) with Logging { + + sparkConf.set(SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key, "100000") + + override def driver(): ShuffleDriverComponents = { + SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager].initTransport() + super.driver() + } + + override def executor(): ShuffleExecutorComponents = { + new UcxShuffleExecutorComponents(sparkConf) + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala index c65dba23..17357907 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala @@ -21,7 +21,6 @@ class GlobalWorkerRpcThread(globalWorker: UcpWorker, memPool: MemoryPool, setDaemon(true) setName("Ucx Shuffle Transport Progress Thread") - override def run(): Unit = { val numRecvs = transport.ucxShuffleConf.recvQueueSize val msgSize = transport.ucxShuffleConf.rpcMessageSize @@ -36,9 +35,10 @@ class GlobalWorkerRpcThread(globalWorker: UcpWorker, memPool: MemoryPool, } while (!isInterrupted) { - if (globalWorker.progress() == 0) { - globalWorker.waitForEvents() - globalWorker.progress() + while (globalWorker.progress() == 0) { + if (transport.ucxShuffleConf.useWakeup) { + globalWorker.waitForEvents() + } } for (i <- 0 until numRecvs) { if (requests(i).isCompleted) { diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxDriverRpcEndpoint.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxDriverRpcEndpoint.scala new file mode 100755 index 00000000..50a9e2e2 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxDriverRpcEndpoint.scala @@ -0,0 +1,42 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. 
*/ +package org.apache.spark.shuffle.ucx.rpc + +import scala.collection.immutable.HashMap +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.rpc._ +import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{ExecutorAdded, IntroduceAllExecutors} +import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer + +class UcxDriverRpcEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + + private val endpoints: mutable.Set[RpcEndpointRef] = mutable.HashSet.empty + private var blockManagerToWorkerAddress = HashMap.empty[String, SerializableDirectBuffer] + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case message@ExecutorAdded(executorId: String, endpoint: RpcEndpointRef, + ucxWorkerAddress: SerializableDirectBuffer) => { + // Driver receives a message from an executor with its workerAddress + // 1. Introduce existing members of the cluster + logInfo(s"Received $message") + if (blockManagerToWorkerAddress.nonEmpty) { + val msg = IntroduceAllExecutors(blockManagerToWorkerAddress.keys.toSeq, + blockManagerToWorkerAddress.values.toList) + logInfo(s"Replying $msg to $executorId") + context.reply(msg) + } + blockManagerToWorkerAddress += executorId -> ucxWorkerAddress + // 2. Introduce the newly joined executor to each existing member. + endpoints.foreach(ep => { + logInfo(s"Sending $message to $ep") + ep.send(message) + }) + logInfo(s"Connecting back to address: ${context.senderAddress}") + endpoints.add(endpoint) + } + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala new file mode 100755 index 00000000..3b8dec20 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala @@ -0,0 +1,31 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. 
+*/ +package org.apache.spark.shuffle.ucx.rpc + +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.shuffle.ucx.UcxShuffleTransport +import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{ExecutorAdded, IntroduceAllExecutors} +import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer + +class UcxExecutorRpcEndpoint(override val rpcEnv: RpcEnv, transport: UcxShuffleTransport) + extends RpcEndpoint with Logging { + + + override def receive: PartialFunction[Any, Unit] = { + case ExecutorAdded(executorId: String, _: RpcEndpointRef, + ucxWorkerAddress: SerializableDirectBuffer) => { + logInfo(s"Received ExecutorAdded($executorId)") + transport.addExecutor(executorId, ucxWorkerAddress.value) + } + case IntroduceAllExecutors(executorIds: Seq[String], + ucxWorkerAddresses: Seq[SerializableDirectBuffer]) => { + logInfo(s"Received IntroduceAllExecutors(${executorIds.mkString(",")})") + executorIds.zip(ucxWorkerAddresses).foreach { + case (executorId, workerAddress) => transport.addExecutor(executorId, workerAddress.value) + } + } + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/utils/SerializableDirectBuffer.scala b/src/main/scala/org/apache/spark/shuffle/ucx/utils/SerializableDirectBuffer.scala index 009d7b08..4d453b38 100644 --- a/src/main/scala/org/apache/spark/shuffle/ucx/utils/SerializableDirectBuffer.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/utils/SerializableDirectBuffer.scala @@ -5,11 +5,12 @@ package org.apache.spark.shuffle.ucx.utils import java.io.{EOFException, ObjectInputStream, ObjectOutputStream} +import java.net.InetSocketAddress import java.nio.ByteBuffer import java.nio.channels.Channels import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} /** * A wrapper around a java.nio.ByteBuffer that is serializable through Java serialization, to make @@ -64,3 +65,28 @@ class DeserializableToExternalMemoryBuffer(@transient var buffer: ByteBuffer)() buffer.rewind() // Allow us to read it later } } + + +object SerializationUtils { + + def deserializeInetAddress(workerAddress: ByteBuffer): InetSocketAddress = { + workerAddress.rewind() + Utils.tryWithResource(new ByteBufferInputStream(workerAddress)) { bin => + val objIn = new ObjectInputStream(bin) + val obj = objIn.readObject().asInstanceOf[InetSocketAddress] + objIn.close() + obj + } + } + + def serializeInetAddress(address: InetSocketAddress): ByteBuffer = { + val hostAddress = new InetSocketAddress(Utils.localCanonicalHostName(), address.getPort) + Utils.tryWithResource(new ByteBufferOutputStream(100)) {bos => + val out = new ObjectOutputStream(bos) + out.writeObject(hostAddress) + out.flush() + out.close() + bos.toByteBuffer + } + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/utils/UcxHelperUtils.scala b/src/main/scala/org/apache/spark/shuffle/ucx/utils/UcxHelperUtils.scala new file mode 100755 index 00000000..ef342c79 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/utils/UcxHelperUtils.scala @@ -0,0 +1,28 @@ +package org.apache.spark.shuffle.ucx.utils + +import java.net.{BindException, InetSocketAddress} + +import scala.util.Random + +import org.openucx.jucx.UcxException +import org.openucx.jucx.ucp.{UcpListener, UcpListenerParams, UcpWorker} +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +object UcxHelperUtils 
extends Logging{ + def startListenerOnRandomPort(worker: UcpWorker, sparkConf: SparkConf): UcpListener = { + val ucpListenerParams = new UcpListenerParams() + val (listener, _) = Utils.startServiceOnPort(1024 + Random.nextInt(65535 - 1024), (port: Int) => { + ucpListenerParams.setSockAddr(new InetSocketAddress(port)) + val listener = try { + worker.newListener(ucpListenerParams) + } catch { + case ex:UcxException => throw new BindException(ex.getMessage) + } + (listener, listener.getAddress.getPort) + }, sparkConf) + logInfo(s"Started UcxListener on ${listener.getAddress}") + listener + } +} From 8b9d66b3f838048342272c4a76a22a019447a367 Mon Sep 17 00:00:00 2001 From: Peter Rudenko Date: Wed, 27 Jan 2021 17:00:03 +0200 Subject: [PATCH 2/4] Switch to Active messages RPC. --- pom.xml | 2 +- .../spark/shuffle/ucx/ShuffleTransport.scala | 2 +- .../spark/shuffle/ucx/UcxShuffleClient.scala | 10 +- .../shuffle/ucx/UcxShuffleTransport.scala | 74 +++++----- .../spark/shuffle/ucx/UcxWorkerWrapper.scala | 136 +++++++++++------- .../ucx/rpc/GlobalWorkerRpcThread.scala | 124 ++++++++-------- .../shuffle/ucx/rpc/UcxRpcMessages.scala | 16 +-- 7 files changed, 191 insertions(+), 173 deletions(-) diff --git a/pom.xml b/pom.xml index aafd36e0..dffbb8ea 100755 --- a/pom.xml +++ b/pom.xml @@ -34,7 +34,7 @@ See file LICENSE for terms. 3.0.0 2.12.12 2.12 - 1.10.0-SNAPSHOT + 1.11.1-SNAPSHOT diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala index 0bd92891..bc48e956 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala @@ -151,7 +151,7 @@ trait ShuffleTransport { * Fetch remote blocks by blockIds. */ def fetchBlockByBlockId(executorId: String, blockId: BlockId, - resultBuffer: MemoryBlock, cb: OperationCallback): Request + resultBufferAllocator: MemoryBlock, cb: OperationCallback): Request /** * Progress outstanding operations. This routine is blocking (though may poll for event). 
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala index 25c00daf..18ce27f7 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala @@ -24,20 +24,12 @@ class UcxShuffleClient(transport: UcxShuffleTransport, .withFilter { case (blockManagerId, _) => blockManagerId != SparkEnv.get.blockManager.blockManagerId } .flatMap { case (blockManagerId, blocks) => - val blockIds = blocks.map { - case (blockId, _, _) => - val sparkBlockId = blockId.asInstanceOf[ShuffleBlockId] - UcxShuffleBlockId(sparkBlockId.shuffleId, sparkBlockId.mapId, sparkBlockId.reduceId) - } - if (!transport.ucxShuffleConf.pinMemory) { - transport.prefetchBlocks(blockManagerId.executorId, blockIds) - } blocks.map { case (blockId, length, _) => if (length > accurateThreshold) { (blockId, (length * 1.2).toLong) } else { - (blockId, accurateThreshold) + (blockId, accurateThreshold * 2) } } }.toMap diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index 495afe4b..651c4904 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -17,7 +17,7 @@ import org.openucx.jucx.{UcxCallback, UcxException} import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.shuffle.ucx.memory.{MemoryPool, UcxHostBounceBuffersPool} -import org.apache.spark.shuffle.ucx.rpc.GlobalWorkerRpcThread +import org.apache.spark.shuffle.ucx.rpc.{GlobalWorkerRpcThread, UcxRpcMessages} import org.apache.spark.shuffle.ucx.utils.{SerializationUtils, UcxHelperUtils} import org.apache.spark.util.Utils @@ -50,13 +50,18 @@ class UcxStats extends OperationStats { override def recvSize: Long = receiveSize } -class UcxRequest(request: UcpRequest, stats: OperationStats) extends Request { +class UcxRequest(private var request: UcpRequest, stats: OperationStats, private val worker: UcpWorker) + extends Request { - override def isCompleted: Boolean = request.isCompleted + override def isCompleted: Boolean = (request != null) && request.isCompleted - override def cancel(): Unit = request.close() + override def cancel(): Unit = if (request != null) worker.cancelRequest(request) override def getStats: Option[OperationStats] = Some(stats) + + private[ucx] def setRequest(request: UcpRequest) = { + this.request = request + } } /** @@ -104,7 +109,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo memMapParams.nonBlocking() } - val params = new UcpParams().requestTagFeature() + val params = new UcpParams().requestTagFeature().requestAmFeature() if (ucxShuffleConf.useWakeup) { params.requestWakeupFeature() @@ -207,21 +212,8 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo /** * On a sender side process request of fetchBlockByBlockId */ - private[ucx] def replyFetchBlockRequest(workerId: String, workerAddress: ByteBuffer, - blockId: BlockId, tag: Long): Unit = { - val ep = clientConnections.getOrElseUpdate(workerId, { - val epParams = new UcpEndpointParams() - if (ucxShuffleConf.useSockAddr) { - epParams.setPeerErrorHandlingMode().setSocketAddress( - SerializationUtils.deserializeInetAddress(workerAddress)) - } else { - epParams.setUcpAddress(workerAddress) - } - 
globalWorker.newEndpoint(epParams) - } - - ) - + private[ucx] def replyFetchBlockRequest(blockId: BlockId, ep: UcpEndpoint, + replyTag: Int = UcxRpcMessages.FETCH_SINGLE_BLOCK_TAG): Unit = { val block = registeredBlocks.get(blockId) if (block == null) { throw new UcxException(s"Block $blockId not registered") @@ -230,33 +222,37 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo lock.lock() val blockMemory = block.getMemoryBlock - logInfo(s"Sending $blockId of size ${blockMemory.size} to $workerId tag: $tag") - ep.sendTaggedNonBlocking(blockMemory.address, blockMemory.size, tag, new UcxCallback { - private val startTime = System.nanoTime() - override def onSuccess(request: UcpRequest): Unit = { - logInfo(s"Sent $blockId of size ${blockMemory.size} to $workerId " + - s"tag: $tag in ${Utils.getUsedTimeNs(startTime)}") - if (block.isInstanceOf[UcxPinnedBlock]) { - val pinnedBlock = block.asInstanceOf[UcxPinnedBlock] - if (pinnedBlock.prefetched) { - registeredBlocks.put(blockId, pinnedBlock.block) - pinnedBlock.ucpMemory.deregister() + logInfo(s"Sending $blockId of size ${blockMemory.size}") + ep.sendAmNonBlocking(replyTag, 0l, 0l, blockMemory.address, blockMemory.size, + UcpConstants.UCP_AM_SEND_FLAG_RNDV, new UcxCallback { + private val startTime = System.nanoTime() + override def onSuccess(request: UcpRequest): Unit = { + logInfo(s"Sent $blockId of size ${blockMemory.size}" + + s" in ${Utils.getUsedTimeNs(startTime)}") + if (block.isInstanceOf[UcxPinnedBlock]) { + val pinnedBlock = block.asInstanceOf[UcxPinnedBlock] + if (pinnedBlock.prefetched) { + registeredBlocks.put(blockId, pinnedBlock.block) + pinnedBlock.ucpMemory.deregister() + } } + lock.unlock() } - lock.unlock() - } - override def onError(ucsStatus: Int, errorMsg: String): Unit = { - logError(s"Failed to send $blockId: $errorMsg") - lock.unlock() - } - }) + override def onError(ucsStatus: Int, errorMsg: String): Unit = { + logError(s"Failed to send $blockId: $errorMsg") + lock.unlock() + } + } ) } private def pinMemory(block: Block): UcpMemory = { + val startTime = System.nanoTime() val blockMemory = block.getMemoryBlock - ucxContext.memoryMap( + val result = ucxContext.memoryMap( memMapParams.setAddress(blockMemory.address).setLength(blockMemory.size)) + logInfo(s"Pinning memory of size: ${blockMemory.size} took: ${Utils.getUsedTimeNs(startTime)}") + result } /** diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index e137a313..e28ca003 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -12,7 +12,8 @@ import java.util.concurrent.ThreadLocalRandom import scala.collection.mutable import com.fasterxml.jackson.databind.util.ByteBufferBackedOutputStream -import org.openucx.jucx.ucp.{UcpEndpoint, UcpEndpointParams, UcpListener, UcpRequest, UcpWorker} +import org.openucx.jucx.ucp.{UcpAmData, UcpAmRecvCallback, UcpConstants, UcpEndpoint, UcpEndpointParams, UcpListener, UcpRequest, UcpWorker} +import org.openucx.jucx.ucs.UcsConstants import org.openucx.jucx.{UcxCallback, UcxException, UcxUtils} import org.apache.spark.internal.Logging import org.apache.spark.shuffle.ucx.memory.MemoryPool @@ -66,11 +67,11 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon override def close(): Unit = { - connections.foreach{ + connections.foreach { case (_, endpoint) => endpoint.close() } 
connections.clear() - listener.map(_.close()) + listener.foreach(_.close()) worker.close() } @@ -116,7 +117,7 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon endpointParams.setUcpAddress(workerAddresses.get(executorId).asInstanceOf[ByteBuffer]) } - worker.newEndpoint(endpointParams) + worker.newEndpoint(endpointParams) }) } @@ -127,7 +128,7 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon val buffer = UcxUtils.getByteBufferView(mem.address, transport.ucxShuffleConf.rpcMessageSize.toInt) workerAddress.rewind() - val message = PrefetchBlockIds(id, new SerializableDirectBuffer(workerAddress), blockIds) + val message = PrefetchBlockIds(blockIds) Utils.tryWithResource(new ByteBufferBackedOutputStream(buffer)) { bos => val out = new ObjectOutputStream(bos) @@ -148,46 +149,47 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize, UcxRpcMessages.PREFETCH_TAG, new UcxCallback() { - override def onSuccess(request: UcpRequest): Unit = { - logTrace(s"Sent prefetch ${blockIds.length} blocks to $executorId") - memoryPool.put(mem) - } - }) + override def onSuccess(request: UcpRequest): Unit = { + logTrace(s"Sent prefetch ${blockIds.length} blocks to $executorId") + memoryPool.put(mem) + } + }) } private[ucx] def fetchBlocksByBlockIds(executorId: String, blockIds: Seq[BlockId], - resultBuffer: Seq[MemoryBlock], - callbacks: Seq[OperationCallback]): Seq[Request] = { + resultBuffer: Seq[MemoryBlock], + callbacks: Seq[OperationCallback]): Seq[Request] = { val ep = getConnection(executorId) val mem = memoryPool.get(transport.ucxShuffleConf.rpcMessageSize) val buffer = UcxUtils.getByteBufferView(mem.address, transport.ucxShuffleConf.rpcMessageSize.toInt) - workerAddress.rewind() - val message = FetchBlocksByBlockIdsRequest(id, new SerializableDirectBuffer(workerAddress), - blockIds) + val tag = ThreadLocalRandom.current().nextInt() + val message = FetchBlocksByBlockIdsRequest(tag, blockIds) + buffer.put(tag.toByte) Utils.tryWithResource(new ByteBufferBackedOutputStream(buffer)) { bos => val out = new ObjectOutputStream(bos) out.writeObject(message) out.flush() out.close() } + val msgSize = buffer.position() - val tag = ThreadLocalRandom.current().nextLong(Long.MinValue, 0) - - val requests = new Array[UcxRequest](blockIds.length) + val requests = new Array[UcxRequest](blockIds.size) for (i <- blockIds.indices) { - logInfo(s"Receiving block ${blockIds(i)}") val stats = new UcxStats() val result = new UcxSuccessOperationResult(stats) - val request = worker.recvTaggedNonBlocking(resultBuffer(i).address, resultBuffer(i).size, - tag + i, -1L, new UcxCallback () { + requests(i) = new UcxRequest(null, stats, worker) + worker.setAmRecvHandler(tag + i, (headerAddress: Long, headerSize: Long, amData: UcpAmData, + replyEp: UcpEndpoint) => { + require(amData.getLength <= resultBuffer(i).size, s"${amData.getLength} < ${resultBuffer(i).size}") + val request = amData.receive(resultBuffer(i).address, new UcxCallback() { override def onError(ucsStatus: Int, errorMsg: String): Unit = { logError(s"Failed to receive blockId ${blockIds(i)} on tag: $tag, from executorId: $executorId " + s" of size: ${resultBuffer.size}: $errorMsg") - if (callbacks(i) != null ) { + if (callbacks(i) != null) { callbacks(i).onComplete(new UcxFailureOperationResult(errorMsg)) } } @@ -195,19 +197,33 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, 
ucxCon override def onSuccess(request: UcpRequest): Unit = { stats.endTime = System.nanoTime() stats.receiveSize = request.getRecvSize - logInfo(s"Received block ${blockIds(i)} from ${executorId} " + + logInfo(s"Received block ${blockIds(i)} from $executorId " + s"of size: ${stats.receiveSize} in ${Utils.getUsedTimeNs(stats.startTime)}") if (callbacks(i) != null) { callbacks(i).onComplete(result) } } }) - requests(i) = new UcxRequest(request, stats) + requests(i).setRequest(request) + UcsConstants.STATUS.UCS_OK + }) } logInfo(s"Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag") - ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize, tag, - new UcxCallback() { + var headerAddress = 0L + var headerSize = 0L + var dataAddress = 0L + var dataSize = 0L + + if (msgSize <= worker.getMaxAmHeaderSize) { + headerAddress = mem.address + headerSize = msgSize + } else { + dataAddress = mem.address + dataSize = msgSize + } + ep.sendAmNonBlocking(0, headerAddress, headerSize, dataAddress, dataSize, + UcpConstants.UCP_AM_SEND_FLAG_REPLY, new UcxCallback() { override def onSuccess(request: UcpRequest): Unit = { memoryPool.put(mem) } @@ -216,18 +232,19 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon } private[ucx] def fetchBlockByBlockId(executorId: String, blockId: BlockId, - resultBuffer: MemoryBlock, cb: OperationCallback): UcxRequest = { + resultBuffer: MemoryBlock, cb: OperationCallback): UcxRequest = { val stats = new UcxStats() val ep = getConnection(executorId) val mem = memoryPool.get(transport.ucxShuffleConf.rpcMessageSize) val buffer = UcxUtils.getByteBufferView(mem.address, transport.ucxShuffleConf.rpcMessageSize.toInt) - val tag = ThreadLocalRandom.current().nextLong(2, Long.MaxValue) + val tag = ThreadLocalRandom.current().nextInt() workerAddress.rewind() - val message = FetchBlockByBlockIdRequest(id, new SerializableDirectBuffer(workerAddress), blockId) + val message = FetchBlockByBlockIdRequest(tag, blockId) + buffer.put(tag.toByte) Utils.tryWithResource(new ByteBufferBackedOutputStream(buffer)) { bos => val out = new ObjectOutputStream(bos) out.writeObject(message) @@ -235,34 +252,56 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon out.close() } + val msgSize = buffer.position() + val recvRequest = new UcxRequest(null, stats, worker) + // To avoid unexpected messages, first posting recv val result = new UcxSuccessOperationResult(stats) - val request = worker.recvTaggedNonBlocking(resultBuffer.address, resultBuffer.size, - tag, -1L, new UcxCallback () { - - override def onError(ucsStatus: Int, errorMsg: String): Unit = { - logError(s"Failed to receive blockId $blockId on tag: $tag, from executorId: $executorId " + - s" of size: ${resultBuffer.size}: $errorMsg") - if (cb != null ) { - cb.onComplete(new UcxFailureOperationResult(errorMsg)) + worker.setAmRecvHandler(tag, (headerAddress: Long, headerSize: Long, amData: UcpAmData, replyEp: UcpEndpoint) => { + require(amData.getLength <= resultBuffer.size) + val request = amData.receive(resultBuffer.address, new UcxCallback() { + + override def onError(ucsStatus: Int, errorMsg: String): Unit = { + logError(s"Failed to receive blockId $blockId on tag: $tag, from executorId: $executorId " + + s" of size: ${resultBuffer.size}: $errorMsg") + if (cb != null) { + cb.onComplete(new UcxFailureOperationResult(errorMsg)) + } } - } - override def onSuccess(request: UcpRequest): Unit = { - stats.endTime = System.nanoTime() - 
stats.receiveSize = request.getRecvSize - logInfo(s"Received block ${blockId} " + - s"of size: ${stats.receiveSize} in ${Utils.getUsedTimeNs(stats.startTime)}") - if (cb != null) { - cb.onComplete(result) + override def onSuccess(request: UcpRequest): Unit = { + stats.endTime = System.nanoTime() + stats.receiveSize = request.getRecvSize + logInfo(s"Received block $blockId " + + s"of size: ${stats.receiveSize} in ${Utils.getUsedTimeNs(stats.startTime)}") + if (cb != null) { + cb.onComplete(result) + } } - } + }) + + recvRequest.setRequest(request) + UcsConstants.STATUS.UCS_OK }) - val recvRequest = new UcxRequest(request, stats) + logInfo(s"Sending message to $executorId to fetch $blockId on tag $tag," + s"resultBuffer $resultBuffer") - ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize, tag, + + var headerAddress = 0L + var headerSize = 0L + var dataAddress = 0L + var dataSize = 0L + + if (msgSize <= worker.getMaxAmHeaderSize) { + headerAddress = mem.address + headerSize = msgSize + } else { + dataAddress = mem.address + dataSize = msgSize + } + + ep.sendAmNonBlocking(0, headerAddress, headerSize, dataAddress, dataSize, UcpConstants.UCP_AM_SEND_FLAG_REPLY, new UcxCallback() { override def onSuccess(request: UcpRequest): Unit = { memoryPool.put(mem) @@ -270,5 +309,4 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon }) recvRequest } - } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala index 17357907..759d50e6 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala @@ -5,14 +5,16 @@ package org.apache.spark.shuffle.ucx.rpc import java.io.ObjectInputStream +import java.nio.ByteBuffer import com.fasterxml.jackson.databind.util.ByteBufferBackedInputStream -import org.openucx.jucx.ucp.{UcpRequest, UcpWorker} -import org.openucx.jucx.{UcxException, UcxUtils} +import org.openucx.jucx.ucp._ +import org.openucx.jucx.ucs.UcsConstants +import org.openucx.jucx.{UcxCallback, UcxUtils} import org.apache.spark.internal.Logging +import org.apache.spark.shuffle.ucx.UcxShuffleTransport import org.apache.spark.shuffle.ucx.memory.MemoryPool -import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{FetchBlockByBlockIdRequest, FetchBlocksByBlockIdsRequest, PrefetchBlockIds} -import org.apache.spark.shuffle.ucx.{MemoryBlock, UcxShuffleTransport} +import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{FetchBlockByBlockIdRequest, FetchBlocksByBlockIdsRequest} import org.apache.spark.util.Utils class GlobalWorkerRpcThread(globalWorker: UcpWorker, memPool: MemoryPool, @@ -21,75 +23,69 @@ class GlobalWorkerRpcThread(globalWorker: UcpWorker, memPool: MemoryPool, setDaemon(true) setName("Ucx Shuffle Transport Progress Thread") - override def run(): Unit = { - val numRecvs = transport.ucxShuffleConf.recvQueueSize - val msgSize = transport.ucxShuffleConf.rpcMessageSize - val requests: Array[UcpRequest] = Array.ofDim[UcpRequest](numRecvs) - val recvMemory = memPool.get(msgSize * numRecvs) - val memoryBlocks = (0 until numRecvs).map(i => - MemoryBlock(recvMemory.address + i * msgSize, msgSize)) - - for (i <- 0 until numRecvs) { - requests(i) = globalWorker.recvTaggedNonBlocking(memoryBlocks(i).address, msgSize, - UcxRpcMessages.WILDCARD_TAG, UcxRpcMessages.WILDCARD_TAG_MASK, null) + private def handleFetchBlockRequest(buffer: 
ByteBuffer, ep: UcpEndpoint): Unit = { + val fetchSingleBlock = buffer.get() + if (fetchSingleBlock == UcxRpcMessages.FETCH_SINGLE_BLOCK_TAG.toByte) { + val msg = Utils.tryWithResource(new ByteBufferBackedInputStream(buffer)) { bin => + val objIn = new ObjectInputStream(bin) + val obj = objIn.readObject().asInstanceOf[FetchBlockByBlockIdRequest] + objIn.close() + obj + } + logInfo(s"Requested single block msg: $msg") + transport.replyFetchBlockRequest(msg.blockId, ep, msg.msgId) + } else { + val msg = Utils.tryWithResource(new ByteBufferBackedInputStream(buffer)) { bin => + val objIn = new ObjectInputStream(bin) + val obj = objIn.readObject().asInstanceOf[FetchBlocksByBlockIdsRequest] + objIn.close() + obj + } + logInfo(s"Requested blocks msg: ${msg.blockIds.mkString(",")}") + for (i <- msg.blockIds.indices) { + transport.replyFetchBlockRequest(msg.blockIds(i), ep, msg.startTag + i) + } } - while (!isInterrupted) { - while (globalWorker.progress() == 0) { - if (transport.ucxShuffleConf.useWakeup) { - globalWorker.waitForEvents() - } - } - for (i <- 0 until numRecvs) { - if (requests(i).isCompleted) { - val buffer = UcxUtils.getByteBufferView(memoryBlocks(i).address, msgSize.toInt) - val senderTag = requests(i).getSenderTag - if (senderTag >= 2L) { - // Fetch Single block - val msg = Utils.tryWithResource(new ByteBufferBackedInputStream(buffer)) { bin => - val objIn = new ObjectInputStream(bin) - val obj = objIn.readObject().asInstanceOf[FetchBlockByBlockIdRequest] - objIn.close() - obj - } - transport.replyFetchBlockRequest(msg.executorId, msg.workerAddress.value, msg.blockId, senderTag) - } else if (senderTag < 0 ) { - // Batch fetch request - val msg = Utils.tryWithResource(new ByteBufferBackedInputStream(buffer)) { bin => - val objIn = new ObjectInputStream(bin) - val obj = objIn.readObject().asInstanceOf[FetchBlocksByBlockIdsRequest] - objIn.close() - obj - } - for (j <- msg.blockIds.indices) { - transport.replyFetchBlockRequest(msg.executorId, msg.workerAddress.value, msg.blockIds(j), - senderTag + j) - } + } - } else if (senderTag == UcxRpcMessages.PREFETCH_TAG) { - val msg = Utils.tryWithResource(new ByteBufferBackedInputStream(buffer)) { bin => - val objIn = new ObjectInputStream(bin) - val obj = objIn.readObject().asInstanceOf[PrefetchBlockIds] - objIn.close() - obj + override def run(): Unit = { + val processCallback: UcpAmRecvCallback = (headerAddress: Long, headerSize: Long, amData: UcpAmData, + replyEp: UcpEndpoint) => { + if (headerSize > 0) { + logInfo(s"Received AM in header") + val header = UcxUtils.getByteBufferView(headerAddress, headerSize.toInt) + handleFetchBlockRequest(header, replyEp) + UcsConstants.STATUS.UCS_OK + } else { + if (amData.isDataValid) { + logInfo(s"Received AM in eager") + val data = UcxUtils.getByteBufferView(amData.getDataAddress, amData.getLength.toInt) + handleFetchBlockRequest(data, replyEp) + UcsConstants.STATUS.UCS_OK + } else { + val recvData = memPool.get(amData.getLength) + amData.receive(recvData.address, new UcxCallback { + override def onSuccess(request: UcpRequest): Unit = { + logInfo(s"Received AM in rndv") + val data = UcxUtils.getByteBufferView(recvData.address, + request.getRecvSize.toInt) + handleFetchBlockRequest(data, replyEp) + memPool.put(recvData) } - transport.handlePrefetchRequest(msg.executorId, msg.workerAddress.value, msg.blockIds) - } - requests(i) = globalWorker.recvTaggedNonBlocking(memoryBlocks(i).address, msgSize, - UcxRpcMessages.WILDCARD_TAG, UcxRpcMessages.WILDCARD_TAG_MASK, null) + }) } + 
UcsConstants.STATUS.UCS_OK } } - - memPool.put(recvMemory) - for (i <- 0 until numRecvs) { - if (!requests(i).isCompleted) { - try { - globalWorker.cancelRequest(requests(i)) - } catch { - case _: UcxException => + globalWorker.setAmRecvHandler(0, processCallback) + while (!isInterrupted) { + if (globalWorker.progress() == 0) { + if (transport.ucxShuffleConf.useWakeup) { + globalWorker.waitForEvents() } } } + } } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala index ebdc47f4..2a883b78 100644 --- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala @@ -10,9 +10,9 @@ import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer object UcxRpcMessages { - val PREFETCH_TAG = 1L - val WILDCARD_TAG = -1L - val WILDCARD_TAG_MASK = 0L + val PREFETCH_TAG = 1 + val FETCH_SINGLE_BLOCK_TAG = 2 + val FETCH_MULTIPLE_BLOCKS_TAG = 3 /** * Called from executor to driver, to introduce ucx worker address. @@ -27,13 +27,9 @@ object UcxRpcMessages { case class IntroduceAllExecutors(executorIds: Seq[String], ucxWorkerAddresses: Seq[SerializableDirectBuffer]) - case class FetchBlockByBlockIdRequest(executorId: String, workerAddress: SerializableDirectBuffer, - blockId: BlockId) + case class FetchBlockByBlockIdRequest(msgId: Int, blockId: BlockId) - case class FetchBlocksByBlockIdsRequest(executorId: String, - workerAddress: SerializableDirectBuffer, - blockIds: Seq[BlockId]) + case class FetchBlocksByBlockIdsRequest(startTag: Int, blockIds: Seq[BlockId]) - case class PrefetchBlockIds(executorId: String, workerAddress: SerializableDirectBuffer, - blockIds: Seq[BlockId]) + case class PrefetchBlockIds(blockIds: Seq[BlockId]) } From 5c79fbad17178de0e682c2a4e3a534676c88f9f7 Mon Sep 17 00:00:00 2001 From: Peter Rudenko Date: Fri, 7 Aug 2020 15:04:35 +0300 Subject: [PATCH 3/4] Performance benchmark --- pom.xml | 19 ++ .../shuffle/ucx/UcxShuffleTransport.scala | 5 +- .../spark/shuffle/ucx/UcxWorkerWrapper.scala | 17 +- .../perf/UcxShuffleTransportPerfTool.scala | 224 ++++++++++++++++++ .../spark/shuffle/ucx/GpuMemoryPool.scala | 30 +++ .../ucx/UcxShuffleTransportTestSuite.scala | 2 +- 6 files changed, 289 insertions(+), 8 deletions(-) create mode 100755 src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxShuffleTransportPerfTool.scala create mode 100755 src/test/scala/org/apache/spark/shuffle/ucx/GpuMemoryPool.scala diff --git a/pom.xml b/pom.xml index dffbb8ea..fa2d88ce 100755 --- a/pom.xml +++ b/pom.xml @@ -35,6 +35,7 @@ See file LICENSE for terms. 2.12.12 2.12 1.11.1-SNAPSHOT + 0.16 @@ -55,6 +56,12 @@ See file LICENSE for terms. 3.2.1 test + + ai.rapids + cudf + ${cudf.version} + test + @@ -137,6 +144,18 @@ See file LICENSE for terms. 
+ + org.apache.maven.plugins + maven-jar-plugin + 3.2.0 + + + + test-jar + + + + diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index 651c4904..bddc7f29 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -196,6 +196,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo blockIds: Seq[BlockId]) { val startTime = System.nanoTime() + logDebug(s"Prefetching blocks: ${blockIds.mkString(",")}") clientConnections.getOrElseUpdate(workerId, globalWorker.newEndpoint(new UcpEndpointParams().setUcpAddress(workerAddress)) ) @@ -235,8 +236,8 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo registeredBlocks.put(blockId, pinnedBlock.block) pinnedBlock.ucpMemory.deregister() } - } - lock.unlock() + } + lock.unlock() } override def onError(ucsStatus: Int, errorMsg: String): Unit = { diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index e28ca003..2cd4e1ba 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -117,7 +117,7 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon endpointParams.setUcpAddress(workerAddresses.get(executorId).asInstanceOf[ByteBuffer]) } - worker.newEndpoint(endpointParams) + worker.newEndpoint(endpointParams) }) } @@ -170,9 +170,16 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon buffer.put(tag.toByte) Utils.tryWithResource(new ByteBufferBackedOutputStream(buffer)) { bos => val out = new ObjectOutputStream(bos) - out.writeObject(message) - out.flush() - out.close() + try { + out.writeObject(message) + out.flush() + out.close() + } catch { + case _: BufferOverflowException => + throw new UcxException(s"Prefetch blocks message size > " + + s"${transport.ucxShuffleConf.RPC_MESSAGE_SIZE.key}:${transport.ucxShuffleConf.rpcMessageSize}") + case ex: Exception => throw new UcxException(ex.getMessage) + } } val msgSize = buffer.position() @@ -308,5 +315,5 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon } }) recvRequest - } + } } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxShuffleTransportPerfTool.scala b/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxShuffleTransportPerfTool.scala new file mode 100755 index 00000000..039b3219 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxShuffleTransportPerfTool.scala @@ -0,0 +1,224 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. +* See file LICENSE for terms. 
+*/ +package org.apache.spark.shuffle.ucx.perf + +import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket} +import java.nio.ByteBuffer +import java.util.concurrent.atomic.{AtomicInteger, AtomicLong} +import java.util.concurrent.{CountDownLatch, Executors, TimeUnit} + +import org.apache.commons.cli.{GnuParser, HelpFormatter, Options} +import org.apache.spark.SparkConf +import org.apache.spark.shuffle.ucx._ +import org.apache.spark.shuffle.ucx.memory.MemoryPool +import org.apache.spark.util.Utils + +object UcxShuffleTransportPerfTool { + private val HELP_OPTION = "h" + private val ADDRESS_OPTION = "a" + private val NUM_BLOCKS_OPTION = "n" + private val SIZE_OPTION = "s" + private val PORT_OPTION = "p" + private val ITER_OPTION = "i" + private val MEMORY_TYPE_OPTION = "m" + private val NUM_THREADS_OPTION = "t" + + private val ucxShuffleConf = new UcxShuffleConf(new SparkConf()) + private val transport = new UcxShuffleTransport(ucxShuffleConf, "e") + private val workerAddress = transport.init() + private var memoryPool: MemoryPool = transport.memoryPool + + case class TestBlockId(id: Int) extends BlockId + + case class PerfOptions(remoteAddress: InetSocketAddress, numBlocks: Int, blockSize: Long, + serverPort: Int, numIterations: Int, numThreads: Int) + + private def initOptions(): Options = { + val options = new Options() + options.addOption(HELP_OPTION, "help", false, + "display help message") + options.addOption(ADDRESS_OPTION, "address", true, + "address of the remote host") + options.addOption(NUM_BLOCKS_OPTION, "num-blocks", true, + "number of blocks to transfer. Default: 1") + options.addOption(SIZE_OPTION, "block-size", true, + "size of block to transfer. Default: 4m") + options.addOption(PORT_OPTION, "server-port", true, + "server port. Default: 12345") + options.addOption(ITER_OPTION, "num-iterations", true, + "number of iterations. Default: 5") + options.addOption(NUM_THREADS_OPTION, "num-threads", true, + "number of threads. 
Default: 1") + options.addOption(MEMORY_TYPE_OPTION, "memory-type", true, + "memory type: host (default), cuda") + } + + private def parseOptions(args: Array[String]): PerfOptions = { + val parser = new GnuParser() + val options = initOptions() + val cmd = parser.parse(options, args) + + if (cmd.hasOption(HELP_OPTION)) { + new HelpFormatter().printHelp("UcxShufflePerfTool", options) + System.exit(0) + } + + val inetAddress = if (cmd.hasOption(ADDRESS_OPTION)) { + val Array(host, port) = cmd.getOptionValue(ADDRESS_OPTION).split(":") + new InetSocketAddress(host, Integer.parseInt(port)) + } else { + null + } + + val serverPort = Integer.parseInt(cmd.getOptionValue(PORT_OPTION, "12345")) + + val numIterations = Integer.parseInt(cmd.getOptionValue(ITER_OPTION, "5")) + + val threadsNumber = Integer.parseInt(cmd.getOptionValue(NUM_THREADS_OPTION, "1")) + + if (cmd.hasOption(MEMORY_TYPE_OPTION) && cmd.getOptionValue(MEMORY_TYPE_OPTION) == "cuda") { + val className = "org.apache.spark.shuffle.ucx.GpuMemoryPool" + val cls = Utils.classForName(className) + memoryPool = cls.getConstructor().newInstance().asInstanceOf[MemoryPool] + } + + PerfOptions(inetAddress, + Integer.parseInt(cmd.getOptionValue(NUM_BLOCKS_OPTION, "1")), + Utils.byteStringAsBytes(cmd.getOptionValue(SIZE_OPTION, "4m")), + serverPort, numIterations, threadsNumber) + } + + private def startServer(perfOptions: PerfOptions): Unit = { + val blocks: Seq[Block] = (0 until perfOptions.numBlocks).map { _ => + val block = memoryPool.get(perfOptions.blockSize) + new Block { + override def getMemoryBlock: MemoryBlock = + block + } + } + + val blockIds = (0 until perfOptions.numBlocks).map(i => TestBlockId(i)) + blockIds.zip(blocks).foreach { + case (blockId, block) => transport.register(blockId, block) + } + + val serverSocket = new ServerSocket(perfOptions.serverPort) + + println(s"Waiting for connections on " + + s"${InetAddress.getLocalHost.getHostName}:${perfOptions.serverPort} ") + + val clientSocket = serverSocket.accept() + val out = clientSocket.getOutputStream + val in = clientSocket.getInputStream + + val buf = ByteBuffer.allocate(workerAddress.capacity()) + buf.put(workerAddress) + buf.flip() + + out.write(buf.array()) + out.flush() + + println(s"Sending worker address to ${clientSocket.getInetAddress}") + + buf.flip() + + in.read(buf.array()) + clientSocket.close() + serverSocket.close() + + blocks.foreach(block => memoryPool.put(block.getMemoryBlock)) + blockIds.foreach(transport.unregister) + transport.close() + } + + private def startClient(perfOptions: PerfOptions): Unit = { + val socket = new Socket(perfOptions.remoteAddress.getHostName, + perfOptions.remoteAddress.getPort) + + val buf = new Array[Byte](4096) + val readSize = socket.getInputStream.read(buf) + val executorId = "1" + val workerAddress = ByteBuffer.allocateDirect(readSize) + + workerAddress.put(buf, 0, readSize) + println("Received worker address") + + transport.addExecutor(executorId, workerAddress) + + val resultSize = perfOptions.numBlocks * perfOptions.blockSize + val resultMemory = memoryPool.get(resultSize) + + val blockIds = (0 until perfOptions.numBlocks).map(i => TestBlockId(i)) + + val threadPool = Executors.newFixedThreadPool(perfOptions.numThreads) + + for (i <- 1 to perfOptions.numIterations) { + val elapsedTime = new AtomicLong(0) + val countDownLatch = new CountDownLatch(perfOptions.numThreads) + + for (_ <- 0 until perfOptions.numThreads) { + threadPool.execute(() => { + val completed = new AtomicInteger(0) + + val mem = new 
Array[MemoryBlock](perfOptions.numBlocks) + val callbacks = new Array[OperationCallback](perfOptions.numBlocks) + + for (j <- 0 until perfOptions.numBlocks) { + mem(j) = MemoryBlock(resultMemory.address + j * perfOptions.blockSize, perfOptions.blockSize) + callbacks(j) = (result: OperationResult) => { + elapsedTime.addAndGet(result.getStats.get.getElapsedTimeNs) + completed.incrementAndGet() + } + } + + transport.fetchBlocksByBlockIds(executorId, blockIds, mem, callbacks) + + while (completed.get() != perfOptions.numBlocks) { + transport.progress() + } + countDownLatch.countDown() + }) + } + + countDownLatch.await() + + val totalTime = if (elapsedTime.get() < TimeUnit.MILLISECONDS.toNanos(1)) { + s"$elapsedTime ns" + } else { + s"${TimeUnit.NANOSECONDS.toMillis(elapsedTime.get())} ms" + } + val throughput: Double = (resultSize * perfOptions.numThreads / 1024.0D / 1024.0D / 1024.0D) / + (elapsedTime.get() / 1e9D) + + println(f"${s"[$i/${perfOptions.numIterations}]"}%12s" + + s" numBlocks: ${perfOptions.numBlocks}" + + s" numThreads: ${perfOptions.numThreads}" + + s" size: ${Utils.bytesToString(perfOptions.blockSize)}," + + s" total size: ${Utils.bytesToString(resultSize * perfOptions.numThreads)}," + + f" time: $totalTime%3s" + + f" throughput: $throughput%.5f GB/s") + } + + val out = socket.getOutputStream + out.write(buf) + out.flush() + out.close() + socket.close() + + memoryPool.put(resultMemory) + transport.close() + } + + def main(args: Array[String]): Unit = { + val perfOptions = parseOptions(args) + + if (perfOptions.remoteAddress == null) { + startServer(perfOptions) + } else { + startClient(perfOptions) + } + } + +} diff --git a/src/test/scala/org/apache/spark/shuffle/ucx/GpuMemoryPool.scala b/src/test/scala/org/apache/spark/shuffle/ucx/GpuMemoryPool.scala new file mode 100755 index 00000000..6d3e633a --- /dev/null +++ b/src/test/scala/org/apache/spark/shuffle/ucx/GpuMemoryPool.scala @@ -0,0 +1,30 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.ucx + +import ai.rapids.cudf.DeviceMemoryBuffer +import org.apache.spark.shuffle.ucx.memory.MemoryPool + + +/** + * Test GPU mempool to run with [[ UcxShuffleTransportPerfTool ]] + */ +class GpuMemoryPool extends MemoryPool { + + class GpuMemoryBlock(val deviceBuffer: DeviceMemoryBuffer, + override val address: Long, override val size: Long) + extends MemoryBlock(address, size, isHostMemory = false) + + override def get(size: Long): MemoryBlock = { + val deviceBuffer = DeviceMemoryBuffer.allocate(size) + new GpuMemoryBlock(deviceBuffer, deviceBuffer.getAddress, size) + } + + override def put(mem: MemoryBlock): Unit = { + mem.asInstanceOf[GpuMemoryBlock].deviceBuffer.close() + } + + override def close(): Unit = ??? +} diff --git a/src/test/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransportTestSuite.scala b/src/test/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransportTestSuite.scala index f391f3cd..7997042a 100755 --- a/src/test/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransportTestSuite.scala +++ b/src/test/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransportTestSuite.scala @@ -1,5 +1,5 @@ /* -* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. * See file LICENSE for terms. 
*/ package org.apache.spark.shuffle.ucx From 73fc7626977123f092bd3b0d2547328692561a00 Mon Sep 17 00:00:00 2001 From: Peter Rudenko Date: Fri, 5 Feb 2021 16:16:01 +0200 Subject: [PATCH 4/4] SparkUCX: AM backend + benchmark implementation. Signed-off-by: Peter Rudenko --- pom.xml | 18 +- .../spark/shuffle/ucx/ShuffleTransport.scala | 13 +- .../shuffle/ucx/UcxShuffleBlockResolver.scala | 2 +- .../spark/shuffle/ucx/UcxShuffleClient.scala | 6 +- .../spark/shuffle/ucx/UcxShuffleConf.scala | 25 +- .../shuffle/ucx/UcxShuffleTransport.scala | 289 ++++++++++------ .../spark/shuffle/ucx/UcxWorkerWrapper.scala | 318 +++++++++--------- .../spark/shuffle/ucx/memory/MemoryPool.scala | 126 ++++++- .../ucx/memory/UcxHostBounceBuffersPool.scala | 116 ------- .../perf/UcxShuffleTransportPerfTool.scala | 123 ++++--- .../ucx/rpc/GlobalWorkerRpcThread.scala | 70 ++-- .../shuffle/ucx/rpc/UcxRpcMessages.scala | 12 +- .../shuffle/ucx/utils/UcxHelperUtils.scala | 14 +- .../spark/shuffle/ucx/GpuMemoryPool.scala | 30 -- .../ucx/UcxShuffleTransportTestSuite.scala | 26 +- 15 files changed, 627 insertions(+), 561 deletions(-) delete mode 100755 src/main/scala/org/apache/spark/shuffle/ucx/memory/UcxHostBounceBuffersPool.scala delete mode 100755 src/test/scala/org/apache/spark/shuffle/ucx/GpuMemoryPool.scala diff --git a/pom.xml b/pom.xml index fa2d88ce..1f5c22b9 100755 --- a/pom.xml +++ b/pom.xml @@ -34,8 +34,8 @@ See file LICENSE for terms. 3.0.0 2.12.12 2.12 - 1.11.1-SNAPSHOT - 0.16 + 1.11.0-rc3 + 0.18.1 @@ -60,7 +60,7 @@ See file LICENSE for terms. ai.rapids cudf ${cudf.version} - test + provided @@ -71,17 +71,19 @@ See file LICENSE for terms. org.apache.maven.plugins maven-compiler-plugin 3.8.1 - - 1.8 - 1.8 - net.alchim31.maven scala-maven-plugin - 4.4.0 + 4.4.1 all + + -source + ${maven.compiler.source} + -target + ${maven.compiler.target} + -Xexperimental -Xfatal-warnings diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala index bc48e956..bd8adaef 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala @@ -135,11 +135,6 @@ trait ShuffleTransport { */ def unregister(blockId: BlockId) - /** - * Hint for a transport that these blocks would needed soon. - */ - def prefetchBlocks(executorId: String, blockIds: Seq[BlockId]) - /** * Batch version of [[ fetchBlocksByBlockIds ]]. */ @@ -147,11 +142,9 @@ trait ShuffleTransport { resultBuffer: Seq[MemoryBlock], callbacks: Seq[OperationCallback]): Seq[Request] - /** - * Fetch remote blocks by blockIds. - */ - def fetchBlockByBlockId(executorId: String, blockId: BlockId, - resultBufferAllocator: MemoryBlock, cb: OperationCallback): Request + def fetchBlocksByBlockIds(executorId: String, blockIds: Seq[BlockId], + resultBuffer: MemoryBlock, + callbacks: OperationCallback): Request /** * Progress outstanding operations. This routine is blocking (though may poll for event). 
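Again as an illustration rather than part of the patch: a short sketch of the new single-buffer fetchBlocksByBlockIds overload declared in the trait above, where one contiguous host bounce buffer receives the replies and the signature pairs a single Request and a single callback with the whole batch. The executor id, block ids, block count and sizes are example values, and the buffer is assumed large enough to hold every reply.

    // Usage sketch only; ids and sizes are illustrative assumptions.
    val blockIds = (0 until 4).map(i => UcxShuffleBlockId(0, 0L, i))
    val blockSize = 1024L * 1024
    val resultBuffer = transport.hostMemoryPool.get(blockSize * blockIds.length)
    val request = transport.fetchBlocksByBlockIds("1", blockIds, resultBuffer,
      (result: OperationResult) =>
        result.getStats.foreach(s => println(s"batch fetched in ${s.getElapsedTimeNs} ns")))
    while (!request.isCompleted) {
      transport.progress()   // same-thread progress, as with the per-block variant
    }
    transport.hostMemoryPool.put(resultBuffer)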
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleBlockResolver.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleBlockResolver.scala index 0d0f32bc..420d161f 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleBlockResolver.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleBlockResolver.scala @@ -5,7 +5,7 @@ package org.apache.spark.shuffle.ucx import java.io.File -import java.nio.{ByteBuffer, MappedByteBuffer} +import java.nio.ByteBuffer import java.nio.channels.FileChannel import java.nio.file.StandardOpenOption import java.util.concurrent.ConcurrentHashMap diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala index 18ce27f7..2fc9cbb1 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala @@ -45,7 +45,7 @@ class UcxShuffleClient(transport: UcxShuffleTransport, if (!blockSizes.contains(blockId)) { throw new UcxException(s"No $blockId found in MapOutput blocks: ${blockSizes.keys.mkString(",")}") } - val resultMemory = transport.memoryPool.get(blockSizes(blockId)) + val resultMemory = transport.hostMemoryPool.get(blockSizes(blockId)) ucxBlockIds(i) = UcxShuffleBlockId(blockId.shuffleId, blockId.mapId, blockId.reduceId) memoryBlocks(i) = MemoryBlock(resultMemory.address, blockSizes(blockId)) callbacks(i) = (result: OperationResult) => { @@ -57,14 +57,14 @@ class UcxShuffleClient(transport: UcxShuffleTransport, val buffer = UcxUtils.getByteBufferView(resultMemory.address, result.getStats.get.recvSize.toInt) listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) { override def release: ManagedBuffer = { - transport.memoryPool.put(resultMemory) + transport.hostMemoryPool.put(resultMemory) this } }) } else { logError(s"Error fetching block $blockId of size ${blockSizes(blockId)}:" + s" ${result.getError.getMessage}") - throw new UcxException(result.getError.getMessage) + listener.onBlockFetchFailure(blockIds(i), new UcxException(result.getError.getMessage)) } } } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala index f73f0970..56867492 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala @@ -37,14 +37,6 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf { .bytesConf(ByteUnit.BYTE) .createWithDefault(1024) - lazy val RPC_MESSAGE_SIZE: ConfigEntry[Long] = - ConfigBuilder(getUcxConf("rpcMessageSize")) - .doc("Size of RPC message to send from fetchBlockByBlockId. 
Must contain ") - .bytesConf(ByteUnit.BYTE) - .checkValue(size => size > maxWorkerAddressSize, - "Rpc message must contain at least workerAddress") - .createWithDefault(2000) - // Memory Pool private lazy val PREALLOCATE_BUFFERS = ConfigBuilder(getUcxConf("memory.preAllocateBuffers")) @@ -57,12 +49,6 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf { .booleanConf .createWithDefault(false) - private lazy val RECV_QUEUE_SIZE = - ConfigBuilder(getUcxConf("recvQueueSize")) - .doc("The number of submitted receive requests.") - .intConf - .createWithDefault(5) - private lazy val USE_SOCKADDR = ConfigBuilder(getUcxConf("useSockAddr")) .doc("Whether to use socket address to connect executors.") @@ -87,24 +73,23 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf { lazy val protocol: PROTOCOL.Value = PROTOCOL.withName( conf.get(PROTOCOL_CONF.key, PROTOCOL_CONF.defaultValueString)) - lazy val useOdp: Boolean = conf.getBoolean(getUcxConf("memory.useOdp"), defaultValue = false) + lazy val useOdp: Boolean = conf.getBoolean(USE_ODP.key, USE_ODP.defaultValue.get) lazy val pinMemory: Boolean = conf.getBoolean(MEMORY_PINNING.key, MEMORY_PINNING.defaultValue.get) lazy val maxWorkerAddressSize: Long = conf.getSizeAsBytes(WORKER_ADDRESS_SIZE.key, WORKER_ADDRESS_SIZE.defaultValueString) - lazy val rpcMessageSize: Long = conf.getSizeAsBytes(RPC_MESSAGE_SIZE.key, - RPC_MESSAGE_SIZE.defaultValueString) + lazy val maxMetadataSize: Long = conf.getSizeAsBytes("spark.rapids.shuffle.maxMetadataSize", + "1024") - lazy val useWakeup: Boolean = conf.getBoolean(WAKEUP_FEATURE.key, WAKEUP_FEATURE.defaultValue.get) - lazy val recvQueueSize: Int = conf.getInt(RECV_QUEUE_SIZE.key, RECV_QUEUE_SIZE.defaultValue.get) + lazy val useWakeup: Boolean = conf.getBoolean(WAKEUP_FEATURE.key, WAKEUP_FEATURE.defaultValue.get) lazy val useSockAddr: Boolean = conf.getBoolean(USE_SOCKADDR.key, USE_SOCKADDR.defaultValue.get) lazy val preallocateBuffersMap: Map[Long, Int] = { - conf.get(PREALLOCATE_BUFFERS).split(",").withFilter(s => !s.isEmpty) + conf.get(PREALLOCATE_BUFFERS.key, "").split(",").withFilter(s => s.nonEmpty) .map(entry => entry.split(":") match { case Array(bufferSize, bufferCount) => (Utils.byteStringAsBytes(bufferSize.trim), bufferCount.toInt) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index bddc7f29..c4dff189 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -3,28 +3,28 @@ * See file LICENSE for terms. 
*/ package org.apache.spark.shuffle.ucx + import java.net.InetSocketAddress import java.nio.ByteBuffer -import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} +import java.util.concurrent.locks.Lock +import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.Future -import scala.util.{Failure, Success} import org.openucx.jucx.ucp._ -import org.openucx.jucx.{UcxCallback, UcxException} +import org.openucx.jucx.ucs.UcsConstants +import org.openucx.jucx.{UcxCallback, UcxException, UcxUtils} import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging -import org.apache.spark.shuffle.ucx.memory.{MemoryPool, UcxHostBounceBuffersPool} -import org.apache.spark.shuffle.ucx.rpc.{GlobalWorkerRpcThread, UcxRpcMessages} +import org.apache.spark.shuffle.ucx.memory.{UcxGpuBounceBuffersPool, UcxHostBounceBuffersPool} +import org.apache.spark.shuffle.ucx.rpc.GlobalWorkerRpcThread import org.apache.spark.shuffle.ucx.utils.{SerializationUtils, UcxHelperUtils} import org.apache.spark.util.Utils /** - * Special type of [[ Block ]] interface backed by UcpMemory, - * it may not be actually pinned if used with [[ UcxShuffleConf.useOdp ]] flag. + * Special type of [[ Block ]] interface backed by UcpMemory. */ case class UcxPinnedBlock(block: Block, ucpMemory: UcpMemory, prefetched: Boolean = false) extends Block { @@ -33,6 +33,7 @@ case class UcxPinnedBlock(block: Block, ucpMemory: UcpMemory, prefetched: Boolea class UcxStats extends OperationStats { private[ucx] val startTime = System.nanoTime() + private[ucx] var amHandleTime = 0L private[ucx] var endTime: Long = 0L private[ucx] var receiveSize: Long = 0L @@ -44,8 +45,7 @@ class UcxStats extends OperationStats { override def getElapsedTimeNs: Long = endTime - startTime /** - * Indicates number of valid bytes in receive memory when using - * [[ ShuffleTransport.fetchBlockByBlockId()]] + * Indicates number of valid bytes in receive memory */ override def recvSize: Long = receiveSize } @@ -53,13 +53,15 @@ class UcxStats extends OperationStats { class UcxRequest(private var request: UcpRequest, stats: OperationStats, private val worker: UcpWorker) extends Request { - override def isCompleted: Boolean = (request != null) && request.isCompleted + private[ucx] var completed = false + + override def isCompleted: Boolean = completed || ((request != null) && request.isCompleted) override def cancel(): Unit = if (request != null) worker.cancelRequest(request) override def getStats: Option[OperationStats] = Some(stats) - private[ucx] def setRequest(request: UcpRequest) = { + private[ucx] def setRequest(request: UcpRequest): Unit = { this.request = request } } @@ -71,9 +73,9 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo extends ShuffleTransport with Logging { // UCX entities - private var ucxContext: UcpContext = _ + private[ucx] var ucxContext: UcpContext = _ private var globalWorker: UcpWorker = _ - private val ucpWorkerParams = new UcpWorkerParams() + private val ucpWorkerParams = new UcpWorkerParams().requestThreadSafety() // TODO: reimplement as workerPool, since spark may create/destroy threads dynamically private var threadLocalWorker: ThreadLocal[UcxWorkerWrapper] = _ @@ -86,14 +88,16 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo // Mapping between executorId and it's address private[ucx] val executorIdToAddress = new 
ConcurrentHashMap[String, ByteBuffer]() private[ucx] val executorIdToSockAddress = new ConcurrentHashMap[String, InetSocketAddress]() - private[ucx] val clientConnections = mutable.Map.empty[String, UcpEndpoint] + private[ucx] val clientConnections = mutable.HashMap.empty[UcpEndpoint, (UcpEndpoint, InetSocketAddress)] // Need host ucx bounce buffer memory pool to send fetchBlockByBlockId request - var memoryPool: MemoryPool = _ + var hostMemoryPool: UcxHostBounceBuffersPool = _ + var deviceMemoryPool: UcxGpuBounceBuffersPool = _ @volatile private var initialized: Boolean = false private var workerAddress: ByteBuffer = _ + private var listener: Option[UcpListener] = None /** * Initialize transport resources. This function should get called after ensuring that SparkConf @@ -128,8 +132,8 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo globalWorker = ucxContext.newWorker(workerParams) workerAddress = if (ucxShuffleConf.useSockAddr) { - val listener = UcxHelperUtils.startListenerOnRandomPort(globalWorker, ucxShuffleConf.conf) - val buffer = SerializationUtils.serializeInetAddress(listener.getAddress) + listener = Some(UcxHelperUtils.startListenerOnRandomPort(globalWorker, ucxShuffleConf.conf, clientConnections)) + val buffer = SerializationUtils.serializeInetAddress(listener.get.getAddress) buffer } else { val workerAddress = globalWorker.getAddress @@ -138,14 +142,31 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo workerAddress } - memoryPool = new UcxHostBounceBuffersPool(ucxShuffleConf, ucxContext) - progressThread = new GlobalWorkerRpcThread(globalWorker, memoryPool, this) + hostMemoryPool = new UcxHostBounceBuffersPool(ucxShuffleConf, ucxContext) + deviceMemoryPool = new UcxGpuBounceBuffersPool(ucxShuffleConf, ucxContext) + progressThread = new GlobalWorkerRpcThread(globalWorker, this) - threadLocalWorker = ThreadLocal.withInitial(() => { + val numThreads = ucxShuffleConf.getInt("spark.executor.cores", 4) + val allocateWorker = () => { val localWorker = ucxContext.newWorker(ucpWorkerParams) - val workerWrapper = new UcxWorkerWrapper(localWorker, this, ucxShuffleConf, memoryPool) + val workerWrapper = new UcxWorkerWrapper(localWorker, this, ucxShuffleConf, hostMemoryPool) allocatedWorkers.add(workerWrapper) + logInfo(s"Pre-connections on a new worker to ${executorIdToAddress.size()} " + + s"executors") + executorIdToAddress.keys.asScala.foreach(e => { + workerWrapper.getConnection(e) + }) workerWrapper + } + val preAllocatedWorkers = (0 until numThreads).map(_ => allocateWorker()).toList.asJava + val workersPool = new ConcurrentLinkedQueue[UcxWorkerWrapper](preAllocatedWorkers) + threadLocalWorker = ThreadLocal.withInitial(() => { + if (workersPool.isEmpty) { + logWarning(s"Allocating new worker") + allocateWorker() + } else { + workersPool.poll() + } }) progressThread.start() @@ -165,17 +186,18 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo } try { progressThread.join() + hostMemoryPool.close() + deviceMemoryPool.close() + clientConnections.keys.foreach(ep => ep.close()) + registeredBlocks.forEachKey(100, blockId => unregister(blockId)) + allocatedWorkers.forEach(_.close()) + listener.foreach(_.close()) + globalWorker.close() + ucxContext.close() } catch { case _:InterruptedException => case e:Throwable => logWarning(e.getLocalizedMessage) } - - memoryPool.close() - clientConnections.values.foreach(ep => ep.close()) - registeredBlocks.forEachKey(100, blockId => unregister(blockId)) 
- allocatedWorkers.forEach(_.close()) - globalWorker.close() - ucxContext.close() } } @@ -189,69 +211,144 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo } else { executorIdToAddress.put(executorId, workerAddress) } - allocatedWorkers.forEach(w => w.getConnection(executorId)) - } - - private[ucx] def handlePrefetchRequest(workerId: String, workerAddress: ByteBuffer, - blockIds: Seq[BlockId]) { - - val startTime = System.nanoTime() - logDebug(s"Prefetching blocks: ${blockIds.mkString(",")}") - clientConnections.getOrElseUpdate(workerId, - globalWorker.newEndpoint(new UcpEndpointParams().setUcpAddress(workerAddress)) - ) - - blockIds.par.foreach(blockId => { - val block = registeredBlocks.get(blockId) - if (!block.isInstanceOf[UcxPinnedBlock]) { - registeredBlocks.put(blockId, UcxPinnedBlock(block, pinMemory(block), prefetched = true)) - } + logInfo(s"Pre connection on ${allocatedWorkers.size()} workers") + allocatedWorkers.forEach(w => { + val ep = w.getConnection(executorId) + w.preConnect(ep) }) - logInfo(s"Prefetched ${blockIds.length} for $workerId in ${Utils.getUsedTimeNs(startTime)}") } /** * On a sender side process request of fetchBlockByBlockId */ - private[ucx] def replyFetchBlockRequest(blockId: BlockId, ep: UcpEndpoint, - replyTag: Int = UcxRpcMessages.FETCH_SINGLE_BLOCK_TAG): Unit = { - val block = registeredBlocks.get(blockId) - if (block == null) { - throw new UcxException(s"Block $blockId not registered") - } - val lock = block.lock.readLock() - lock.lock() - val blockMemory = block.getMemoryBlock - - logInfo(s"Sending $blockId of size ${blockMemory.size}") - ep.sendAmNonBlocking(replyTag, 0l, 0l, blockMemory.address, blockMemory.size, - UcpConstants.UCP_AM_SEND_FLAG_RNDV, new UcxCallback { - private val startTime = System.nanoTime() - override def onSuccess(request: UcpRequest): Unit = { - logInfo(s"Sent $blockId of size ${blockMemory.size}" + - s" in ${Utils.getUsedTimeNs(startTime)}") - if (block.isInstanceOf[UcxPinnedBlock]) { - val pinnedBlock = block.asInstanceOf[UcxPinnedBlock] - if (pinnedBlock.prefetched) { - registeredBlocks.put(blockId, pinnedBlock.block) - pinnedBlock.ucpMemory.deregister() + private[ucx] def replyFetchBlocksRequest(blockIds: Seq[BlockId], isRecvToHostMemory: Seq[Boolean], + ep: UcpEndpoint, startTag: Int, singleReply: Boolean): Unit = { + if (singleReply) { + var blockAddress = 0L + var blockSize = 0L + var responseBlock: MemoryBlock = null + var headerMem: MemoryBlock = null + var responseBlockBuff: ByteBuffer = null + + val totalSize = ucxShuffleConf.maxMetadataSize * blockIds.length + if (totalSize + 4 < globalWorker.getMaxAmHeaderSize) { + headerMem = hostMemoryPool.get(totalSize + 4L) + responseBlockBuff = UcxUtils.getByteBufferView(headerMem.address, headerMem.size) + responseBlockBuff.putInt(startTag) + logInfo(s"Sending ${blockIds.mkString(",")} in header") + } else { + headerMem = hostMemoryPool.get(4L) + val headerBuf = UcxUtils.getByteBufferView(headerMem.address, headerMem.size) + headerBuf.putInt(startTag) + responseBlock = hostMemoryPool.get(totalSize) + responseBlockBuff = UcxUtils.getByteBufferView(responseBlock.address, responseBlock.size) + blockAddress = responseBlock.address + blockSize = responseBlock.size + logInfo(s"Sending ${blockIds.length} blocks in data") + } + val locks = new Array[Lock](blockIds.size) + for (i <- blockIds.indices) { + val block = registeredBlocks.get(blockIds(i)) + locks(i) = block.lock.readLock() + locks(i).lock() + val blockMemory = block.getMemoryBlock + 
require(blockMemory.isHostMemory && blockMemory.size <= ucxShuffleConf.maxMetadataSize) + val blockBuffer = UcxUtils.getByteBufferView(blockMemory.address, blockMemory.size) + responseBlockBuff.putInt(blockMemory.size.toInt) + responseBlockBuff.put(blockBuffer) + } + ep.sendAmNonBlocking(0, headerMem.address, headerMem.size, blockAddress, blockSize, 0L, + new UcxCallback { + private val startTime = System.nanoTime() + override def onSuccess(request: UcpRequest): Unit = { + logInfo(s"Sent ${blockIds.length} blocks of size $totalSize" + + s" in ${Utils.getUsedTimeNs(startTime)}") + locks.foreach(_.unlock()) + hostMemoryPool.put(headerMem) + if (responseBlock != null) { + hostMemoryPool.put(responseBlock) + } + } + + override def onError(ucsStatus: Int, errorMsg: String): Unit = { + logError(s"Failed to send ${blockIds.mkString(",")}: $errorMsg") + locks.foreach(_.unlock()) + hostMemoryPool.put(headerMem) + if (responseBlock != null) { + hostMemoryPool.put(responseBlock) } - } - lock.unlock() + } + }, UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) + + } else { + for (i <- blockIds.indices) { + val blockId = blockIds(i) + val block = registeredBlocks.get(blockId) + + if (block == null) { + throw new UcxException(s"Block $blockId is not registered") } - override def onError(ucsStatus: Int, errorMsg: String): Unit = { - logError(s"Failed to send $blockId: $errorMsg") - lock.unlock() + val lock = block.lock.readLock() + lock.lock() + + val blockMemory = block.getMemoryBlock + + var blockAddress = 0L + var blockSize = 0L + var headerMem: MemoryBlock = null + + if (blockMemory.isHostMemory && // The block itself in host memory + (blockMemory.size + 4 < globalWorker.getMaxAmHeaderSize) && // and can fit to header + isRecvToHostMemory(i) ) { + headerMem = hostMemoryPool.get(4L + blockMemory.size) + val buf = UcxUtils.getByteBufferView(headerMem.address, headerMem.size) + buf.putInt(startTag + i) + val blockBuf = UcxUtils.getByteBufferView(blockMemory.address, blockMemory.size) + buf.put(blockBuf) + } else { + headerMem = hostMemoryPool.get(4L) + val buf = UcxUtils.getByteBufferView(headerMem.address, headerMem.size) + buf.putInt(startTag + i) + blockAddress = blockMemory.address + blockSize = blockMemory.size } - } ) + + ep.sendAmNonBlocking(0, headerMem.address, headerMem.size, blockAddress, blockSize, 0L, + new UcxCallback { + private val startTime = System.nanoTime() + override def onSuccess(request: UcpRequest): Unit = { + logInfo(s"Sent $blockId of size ${blockMemory.size} " + + s"memType: ${if (blockMemory.isHostMemory) "host" else "gpu"}" + + s" in ${Utils.getUsedTimeNs(startTime)}") + lock.unlock() + hostMemoryPool.put(headerMem) + } + + override def onError(ucsStatus: Int, errorMsg: String): Unit = { + logError(s"Failed to send $blockId-$blockMemory: $errorMsg") + lock.unlock() + hostMemoryPool.put(headerMem) + } + }, + if (blockMemory.isHostMemory) UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST + else UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_CUDA) + } + } + } private def pinMemory(block: Block): UcpMemory = { val startTime = System.nanoTime() val blockMemory = block.getMemoryBlock + val memType = if (blockMemory.isHostMemory) { + UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST + } else { + UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_CUDA + } val result = ucxContext.memoryMap( - memMapParams.setAddress(blockMemory.address).setLength(blockMemory.size)) + memMapParams.setAddress(blockMemory.address).setLength(blockMemory.size) + .setMemoryType(memType)) logInfo(s"Pinning memory of size: 
${blockMemory.size} took: ${Utils.getUsedTimeNs(startTime)}") result } @@ -260,7 +357,7 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo * Registers blocks using blockId on SERVER side. */ override def register(blockId: BlockId, block: Block): Unit = { - logTrace(s"Registering $blockId") + logTrace(s"Registering $blockId of size: ${block.getMemoryBlock.size}") val registeredBock: Block = if (ucxShuffleConf.pinMemory) { UcxPinnedBlock(block, pinMemory(block)) } else { @@ -273,23 +370,16 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo * Change location of underlying blockId in memory */ override def mutate(blockId: BlockId, block: Block, callback: OperationCallback): Unit = { - Future { - unregister(blockId) - register(blockId, block) - } andThen { - case Failure(t) => if (callback != null) { - callback.onComplete(new UcxFailureOperationResult(t.getMessage)) - } - case Success(_) => if (callback != null) { - callback.onComplete(new UcxSuccessOperationResult(new UcxStats)) - } - } + unregister(blockId) + register(blockId, block) + callback.onComplete(new UcxSuccessOperationResult(new UcxStats)) } /** * Indicate that this blockId is not needed any more by an application */ override def unregister(blockId: BlockId): Unit = { + logInfo(s"Unregistering $blockId") val block = registeredBlocks.remove(blockId) if (block != null) { block.lock.writeLock().lock() @@ -303,15 +393,6 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo } } - /** - * Fetch remote blocks by blockIds. - */ - override def fetchBlockByBlockId(executorId: String, blockId: BlockId, - resultBuffer: MemoryBlock, - cb: OperationCallback): UcxRequest = { - threadLocalWorker.get().fetchBlockByBlockId(executorId, blockId, resultBuffer, cb) - } - /** * Progress outstanding operations. This routine is blocking. It's important to call this routine * within same thread that submitted requests. @@ -325,13 +406,6 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo executorIdToAddress.remove(executorId) } - /** - * Hint for a transport that these blocks would needed soon. - */ - override def prefetchBlocks(executorId: String, blockIds: Seq[BlockId]): Unit = { - threadLocalWorker.get().prefetchBlocks(executorId, blockIds) - } - /** * Batch version of [[ fetchBlocksByBlockIds ]]. 
*/ @@ -340,4 +414,9 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executo callbacks: Seq[OperationCallback]): Seq[Request] = { threadLocalWorker.get().fetchBlocksByBlockIds(executorId, blockIds, resultBuffer, callbacks) } + + override def fetchBlocksByBlockIds(executorId: String, blockIds: Seq[BlockId], resultBuffer: MemoryBlock, + callback: OperationCallback): Request = { + threadLocalWorker.get().fetchBlocksByBlockIds(executorId, blockIds, resultBuffer, callback) + } } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index 2cd4e1ba..e22439bd 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -6,21 +6,19 @@ package org.apache.spark.shuffle.ucx import java.io.{Closeable, ObjectOutputStream} import java.net.InetSocketAddress -import java.nio.{BufferOverflowException, ByteBuffer} -import java.util.concurrent.ThreadLocalRandom +import java.nio.ByteBuffer +import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable -import com.fasterxml.jackson.databind.util.ByteBufferBackedOutputStream -import org.openucx.jucx.ucp.{UcpAmData, UcpAmRecvCallback, UcpConstants, UcpEndpoint, UcpEndpointParams, UcpListener, UcpRequest, UcpWorker} +import ai.rapids.cudf.DeviceMemoryBuffer +import org.openucx.jucx.ucp._ import org.openucx.jucx.ucs.UcsConstants import org.openucx.jucx.{UcxCallback, UcxException, UcxUtils} import org.apache.spark.internal.Logging import org.apache.spark.shuffle.ucx.memory.MemoryPool -import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages -import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{FetchBlockByBlockIdRequest, FetchBlocksByBlockIdsRequest, PrefetchBlockIds} -import org.apache.spark.shuffle.ucx.utils.{SerializableDirectBuffer, SerializationUtils, UcxHelperUtils} -import org.apache.spark.util.Utils +import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.FetchBlocksByBlockIdsRequest +import org.apache.spark.util.{ByteBufferOutputStream, Utils} /** * Success operation result subclass that has operation stats. @@ -49,21 +47,49 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon memoryPool: MemoryPool) extends Closeable with Logging { - // To keep connection map on a remote side by id rather then by worker address, which could be big. - // Would not need when migrate to active messages. 
- private val id: String = transport.executorId + s"_${Thread.currentThread().getId}" + private var tag: Int = 1 private final val connections = mutable.Map.empty[String, UcpEndpoint] - private val listener: Option[UcpListener] = if (ucxConf.useSockAddr) { - Some(UcxHelperUtils.startListenerOnRandomPort(worker, ucxConf.conf)) - } else { - None - } - - private val workerAddress = if (ucxConf.useSockAddr) { - SerializationUtils.serializeInetAddress(listener.get.getAddress) - } else { - worker.getAddress - } + private val requestData = new ConcurrentHashMap[Int, (MemoryBlock, UcxCallback, UcxRequest)] + + worker.setAmRecvHandler(0, (headerAddress: Long, headerSize: Long, amData: UcpAmData, + replyEp: UcpEndpoint) => { + val headerBuffer = UcxUtils.getByteBufferView(headerAddress, headerSize) + val i = headerBuffer.getInt + val data = requestData.remove(i) + if (data == null) { + throw new UcxException(s"No data for tag $i") + } + val (resultMemory, callback, ucxRequest) = data + logDebug(s"Received message for tag $i with headerSize $headerSize. " + + s" AmData: ${amData} resultMemory(isHost): ${resultMemory.isHostMemory}") + ucxRequest.getStats.foreach(s => s.asInstanceOf[UcxStats].amHandleTime = System.nanoTime()) + if (headerSize > 4L) { + val resultBuffer = UcxUtils.getByteBufferView(resultMemory.address, resultMemory.size) + resultBuffer.put(headerBuffer) + ucxRequest.completed = true + if (callback != null) { + ucxRequest.getStats.get.asInstanceOf[UcxStats].receiveSize = headerSize - 4 + callback.onSuccess(null) + } + } else if (amData.isDataValid && resultMemory.isHostMemory) { + val resultBuffer = UcxUtils.getByteBufferView(resultMemory.address, resultMemory.size) + resultBuffer.put(UcxUtils.getByteBufferView(amData.getDataAddress, amData.getLength)) + ucxRequest.completed = true + if (callback != null) { + ucxRequest.getStats.get.asInstanceOf[UcxStats].receiveSize = amData.getLength + callback.onSuccess(null) + } + } else { + require(amData.getLength <= resultMemory.size, s"${amData.getLength} < ${resultMemory.size}") + val request = worker.recvAmDataNonBlocking(amData.getDataHandle, resultMemory.address, + amData.getLength, callback, + if (resultMemory.isHostMemory) UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST else + UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_CUDA) + ucxRequest.getStats.get.asInstanceOf[UcxStats].receiveSize = amData.getLength + ucxRequest.setRequest(request) + } + UcsConstants.STATUS.UCS_OK + }) override def close(): Unit = { @@ -71,21 +97,20 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon case (_, endpoint) => endpoint.close() } connections.clear() - listener.foreach(_.close()) worker.close() } /** * The only place for worker progress */ - private[ucx] def progress(): Unit = { + private[ucx] def progress(): Unit = this.synchronized { if ((worker.progress() == 0) && ucxConf.useWakeup) { worker.waitForEvents() worker.progress() } } - private[ucx] def getConnection(executorId: String): UcpEndpoint = { + private[ucx] def getConnection(executorId: String): UcpEndpoint = synchronized { val workerAddresses = if (ucxConf.useSockAddr) { transport.executorIdToSockAddress } else { @@ -98,6 +123,7 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon val timeout = ucxConf.conf.getTimeAsMs("spark.network.timeout", "100") workerAddresses.synchronized { while (workerAddresses.get(executorId) == null) { + logWarning(s"No workerAddress for executor $executorId") workerAddresses.wait(timeout) if 
(System.currentTimeMillis() - startTime > timeout) { throw new UcxException(s"Didn't get worker address for $executorId during $timeout") @@ -108,115 +134,90 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon connections.getOrElseUpdate(executorId, { logInfo(s"Worker from thread ${Thread.currentThread().getName} connecting to $executorId") - val endpointParams = new UcpEndpointParams() + val endpointParams = new UcpEndpointParams().setPeerErrorHandlingMode() + .setErrorHandler(new UcpEndpointErrorHandler() { + override def onError(ep: UcpEndpoint, status: Int, errorMsg: String): Unit = { + logError(errorMsg) + } + }) if (ucxConf.useSockAddr) { val sockAddr = workerAddresses.get(executorId).asInstanceOf[InetSocketAddress] - logInfo(s"Connecting worker to $executorId at $sockAddr") - endpointParams.setPeerErrorHandlingMode().setSocketAddress(sockAddr) + endpointParams.setSocketAddress(sockAddr) } else { endpointParams.setUcpAddress(workerAddresses.get(executorId).asInstanceOf[ByteBuffer]) } - worker.newEndpoint(endpointParams) + worker.newEndpoint(endpointParams) }) } - private[ucx] def prefetchBlocks(executorId: String, blockIds: Seq[BlockId]): Unit = { - logDebug(s"Sending prefetch ${blockIds.length} blocks to $executorId") - val mem = memoryPool.get(transport.ucxShuffleConf.rpcMessageSize) - val buffer = UcxUtils.getByteBufferView(mem.address, transport.ucxShuffleConf.rpcMessageSize.toInt) + def preConnect(ep: UcpEndpoint): Unit = this.synchronized { + val hostBuffer = memoryPool.get(4) + val gpuBuffer = DeviceMemoryBuffer.allocate(1) + val req = ep.sendAmNonBlocking(-1, hostBuffer.address, hostBuffer.size, + gpuBuffer.getAddress, gpuBuffer.getLength, UcpConstants.UCP_AM_SEND_FLAG_REPLY, null) + while (!req.isCompleted) { + progress() + } + gpuBuffer.close() + memoryPool.put(hostBuffer) + } - workerAddress.rewind() - val message = PrefetchBlockIds(blockIds) + private[ucx] def fetchBlocksByBlockIds(executorId: String, blockIds: Seq[BlockId], + resultBuffers: Seq[MemoryBlock], + callbacks: Seq[OperationCallback]): Seq[Request] = this.synchronized { - Utils.tryWithResource(new ByteBufferBackedOutputStream(buffer)) { bos => - val out = new ObjectOutputStream(bos) - try { - out.writeObject(message) - } catch { - case _: BufferOverflowException => - throw new UcxException(s"Prefetch blocks message size > " + - s"${transport.ucxShuffleConf.RPC_MESSAGE_SIZE.key}:${transport.ucxShuffleConf.rpcMessageSize}") - case ex: Exception => throw new UcxException(ex.getMessage) - } + val startTime = System.nanoTime() + val ep = getConnection(executorId) + val message = FetchBlocksByBlockIdsRequest(tag, blockIds, resultBuffers.map(_.isHostMemory).toArray) + + val bos = new ByteBufferOutputStream(1000) + val out = new ObjectOutputStream(bos) + try { + out.writeObject(message) out.flush() out.close() + } catch { + case ex: Exception => throw new UcxException(ex.getMessage) } - val ep = getConnection(executorId) - - ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize, - UcxRpcMessages.PREFETCH_TAG, new UcxCallback() { - override def onSuccess(request: UcpRequest): Unit = { - logTrace(s"Sent prefetch ${blockIds.length} blocks to $executorId") - memoryPool.put(mem) - } - }) - } + val msgSize = bos.getCount() + val mem = memoryPool.get(msgSize) + val buffer = UcxUtils.getByteBufferView(mem.address, msgSize) - private[ucx] def fetchBlocksByBlockIds(executorId: String, blockIds: Seq[BlockId], - resultBuffer: Seq[MemoryBlock], - callbacks: 
Seq[OperationCallback]): Seq[Request] = { - val ep = getConnection(executorId) - val mem = memoryPool.get(transport.ucxShuffleConf.rpcMessageSize) - val buffer = UcxUtils.getByteBufferView(mem.address, - transport.ucxShuffleConf.rpcMessageSize.toInt) - - val tag = ThreadLocalRandom.current().nextInt() - val message = FetchBlocksByBlockIdsRequest(tag, blockIds) - - buffer.put(tag.toByte) - Utils.tryWithResource(new ByteBufferBackedOutputStream(buffer)) { bos => - val out = new ObjectOutputStream(bos) - try { - out.writeObject(message) - out.flush() - out.close() - } catch { - case _: BufferOverflowException => - throw new UcxException(s"Prefetch blocks message size > " + - s"${transport.ucxShuffleConf.RPC_MESSAGE_SIZE.key}:${transport.ucxShuffleConf.rpcMessageSize}") - case ex: Exception => throw new UcxException(ex.getMessage) - } - } - val msgSize = buffer.position() + buffer.put(bos.toByteBuffer) val requests = new Array[UcxRequest](blockIds.size) for (i <- blockIds.indices) { val stats = new UcxStats() val result = new UcxSuccessOperationResult(stats) requests(i) = new UcxRequest(null, stats, worker) - worker.setAmRecvHandler(tag + i, (headerAddress: Long, headerSize: Long, amData: UcpAmData, - replyEp: UcpEndpoint) => { - require(amData.getLength <= resultBuffer(i).size, s"${amData.getLength} < ${resultBuffer(i).size}") - val request = amData.receive(resultBuffer(i).address, new UcxCallback() { - - override def onError(ucsStatus: Int, errorMsg: String): Unit = { - logError(s"Failed to receive blockId ${blockIds(i)} on tag: $tag, from executorId: $executorId " + - s" of size: ${resultBuffer.size}: $errorMsg") - if (callbacks(i) != null) { - callbacks(i).onComplete(new UcxFailureOperationResult(errorMsg)) - } + val callback = if (callbacks.isEmpty) null else new UcxCallback() { + + override def onError(ucsStatus: Int, errorMsg: String): Unit = { + logError(s"Failed to receive blockId ${blockIds(i)} on tag: $tag, from executorId: $executorId " + + s" of size: ${resultBuffers(i).size}: $errorMsg") + if (callbacks(i) != null) { + callbacks(i).onComplete(new UcxFailureOperationResult(errorMsg)) } + } - override def onSuccess(request: UcpRequest): Unit = { - stats.endTime = System.nanoTime() - stats.receiveSize = request.getRecvSize - logInfo(s"Received block ${blockIds(i)} from $executorId " + - s"of size: ${stats.receiveSize} in ${Utils.getUsedTimeNs(stats.startTime)}") - if (callbacks(i) != null) { - callbacks(i).onComplete(result) - } + override def onSuccess(request: UcpRequest): Unit = { + stats.endTime = System.nanoTime() + logInfo(s"Received block ${blockIds(i)} from $executorId " + + s"of size: ${stats.receiveSize} in ${Utils.getUsedTimeNs(stats.startTime)}. 
Time from amHanlde to recv: " + + s"${Utils.getUsedTimeNs(stats.amHandleTime)}") + if (callbacks(i) != null) { + callbacks(i).onComplete(result) } - }) - requests(i).setRequest(request) - UcsConstants.STATUS.UCS_OK - }) + } + } + requestData.put(tag + i, (resultBuffers(i), callback, requests(i))) } - logInfo(s"Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag") + logDebug(s"Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag") var headerAddress = 0L var headerSize = 0L var dataAddress = 0L @@ -229,72 +230,78 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon dataAddress = mem.address dataSize = msgSize } - ep.sendAmNonBlocking(0, headerAddress, headerSize, dataAddress, dataSize, + worker.progressRequest(ep.sendAmNonBlocking(0, headerAddress, headerSize, dataAddress, dataSize, UcpConstants.UCP_AM_SEND_FLAG_REPLY, new UcxCallback() { override def onSuccess(request: UcpRequest): Unit = { + logInfo(s"Sending RPC message of size: ${headerSize + dataSize} " + + s"to $executorId to fetch ${blockIds.length} blocks on starting tag tag $tag took: " + + s"${Utils.getUsedTimeNs(startTime)}") memoryPool.put(mem) } - }) + })) + + for (i <- blockIds.indices) { + while (requestData.contains(tag + i)) { + progress() + } + } + + logInfo(s"FetchBlocksByBlockIds data took: ${Utils.getUsedTimeNs(startTime)}") + tag += blockIds.length requests } - private[ucx] def fetchBlockByBlockId(executorId: String, blockId: BlockId, - resultBuffer: MemoryBlock, cb: OperationCallback): UcxRequest = { + private[ucx] def fetchBlocksByBlockIds(executorId: String, blockIds: Seq[BlockId], + resultBuffer: MemoryBlock, + callback: OperationCallback): Request = this.synchronized { val stats = new UcxStats() val ep = getConnection(executorId) - val mem = memoryPool.get(transport.ucxShuffleConf.rpcMessageSize) - val buffer = UcxUtils.getByteBufferView(mem.address, - transport.ucxShuffleConf.rpcMessageSize.toInt) - val tag = ThreadLocalRandom.current().nextInt() - workerAddress.rewind() - val message = FetchBlockByBlockIdRequest(tag, blockId) + val message = FetchBlocksByBlockIdsRequest(tag, blockIds, Seq.empty[Boolean], singleReply = true) - buffer.put(tag.toByte) - Utils.tryWithResource(new ByteBufferBackedOutputStream(buffer)) { bos => - val out = new ObjectOutputStream(bos) + val bos = new ByteBufferOutputStream(1000) + val out = new ObjectOutputStream(bos) + try { out.writeObject(message) out.flush() out.close() + } catch { + case ex: Exception => throw new UcxException(ex.getMessage) } - val msgSize = buffer.position() - val recvRequest = new UcxRequest(null, stats, worker) + val msgSize = bos.getCount() - // To avoid unexpected messages, first posting recv + val mem = memoryPool.get(msgSize) + val buffer = UcxUtils.getByteBufferView(mem.address, msgSize) + + buffer.put(bos.toByteBuffer) + + val request = new UcxRequest(null, stats, worker) val result = new UcxSuccessOperationResult(stats) - worker.setAmRecvHandler(tag, (headerAddress: Long, headerSize: Long, amData: UcpAmData, replyEp: UcpEndpoint) => { - require(amData.getLength <= resultBuffer.size) - val request = amData.receive(resultBuffer.address, new UcxCallback() { + val requestCallback = new UcxCallback() { - override def onError(ucsStatus: Int, errorMsg: String): Unit = { - logError(s"Failed to receive blockId $blockId on tag: $tag, from executorId: $executorId " + - s" of size: ${resultBuffer.size}: $errorMsg") - if (cb != null) { - cb.onComplete(new 
UcxFailureOperationResult(errorMsg)) - } + override def onError(ucsStatus: Int, errorMsg: String): Unit = { + logError(s"Failed to receive blockId ${blockIds.mkString(",")} on tag: $tag, from executorId: $executorId " + + s" of size: ${resultBuffer.size}: $errorMsg") + if (callback != null) { + callback.onComplete(new UcxFailureOperationResult(errorMsg)) } + } - override def onSuccess(request: UcpRequest): Unit = { - stats.endTime = System.nanoTime() - stats.receiveSize = request.getRecvSize - logInfo(s"Received block $blockId " + - s"of size: ${stats.receiveSize} in ${Utils.getUsedTimeNs(stats.startTime)}") - if (cb != null) { - cb.onComplete(result) - } + override def onSuccess(request: UcpRequest): Unit = { + stats.endTime = System.nanoTime() + logInfo(s"Received ${blockIds.length} metadata blocks from $executorId " + + s"of size: ${stats.receiveSize} in ${Utils.getUsedTimeNs(stats.startTime)}") + if (callback != null) { + callback.onComplete(result) } - }) - - recvRequest.setRequest(request) - UcsConstants.STATUS.UCS_OK - }) - + } - logInfo(s"Sending message to $executorId to fetch $blockId on tag $tag," + - s"resultBuffer $resultBuffer") + } + requestData.put(tag, (resultBuffer, requestCallback, request)) + logDebug(s"Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag") var headerAddress = 0L var headerSize = 0L var dataAddress = 0L @@ -307,13 +314,20 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon dataAddress = mem.address dataSize = msgSize } - - ep.sendAmNonBlocking(0, headerAddress, headerSize, dataAddress, dataSize, UcpConstants.UCP_AM_SEND_FLAG_REPLY, - new UcxCallback() { + worker.progressRequest(ep.sendAmNonBlocking(0, headerAddress, headerSize, dataAddress, dataSize, + UcpConstants.UCP_AM_SEND_FLAG_REPLY, new UcxCallback() { override def onSuccess(request: UcpRequest): Unit = { memoryPool.put(mem) } - }) - recvRequest - } + })) + + while (requestData.contains(tag)) { + progress() + } + + logInfo(s"FetchBlocksByBlockIds metadata took: ${Utils.getUsedTimeNs(stats.startTime)}") + tag += 1 + request + } + } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/memory/MemoryPool.scala b/src/main/scala/org/apache/spark/shuffle/ucx/memory/MemoryPool.scala index 72c18fc7..65da481f 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/memory/MemoryPool.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/memory/MemoryPool.scala @@ -5,14 +5,132 @@ package org.apache.spark.shuffle.ucx.memory import java.io.Closeable +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedDeque} + +import org.openucx.jucx.ucp.{UcpContext, UcpMemMapParams, UcpMemory} +import org.openucx.jucx.ucs.UcsConstants +import org.apache.spark.internal.Logging +import org.apache.spark.shuffle.ucx.{MemoryBlock, UcxShuffleConf} +import org.apache.spark.util.Utils + +class UcxBounceBufferMemoryBlock(private[ucx] val memory: UcpMemory, private[ucx] val refCount: AtomicInteger, + override val address: Long, override val size: Long) + extends MemoryBlock(address, size, memory.getMemType == UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) -import org.apache.spark.shuffle.ucx.MemoryBlock /** * Base class to implement memory pool */ -abstract class MemoryPool extends Closeable { - def get(size: Long): MemoryBlock +case class MemoryPool(ucxShuffleConf: UcxShuffleConf, ucxContext: UcpContext, memoryType: Int) + extends Closeable with Logging { + + protected def roundUpToTheNextPowerOf2(size: Long): 
Long = { + // Round up length to the nearest power of two + var length = size + length -= 1 + length |= length >> 1 + length |= length >> 2 + length |= length >> 4 + length |= length >> 8 + length |= length >> 16 + length += 1 + length + } + + protected val allocatorMap = new ConcurrentHashMap[Long, AllocatorStack]() + + protected case class AllocatorStack(length: Long, memType: Int) extends Closeable { + logInfo(s"Allocator stack of memType: $memType and size $length") + private val stack = new ConcurrentLinkedDeque[UcxBounceBufferMemoryBlock] + private val numAllocs = new AtomicInteger(0) + private val memMapParams = new UcpMemMapParams().allocate().setMemoryType(memType).setLength(length) + + private[memory] def get: UcxBounceBufferMemoryBlock = { + var result = stack.pollFirst() + if (result == null) { + numAllocs.incrementAndGet() + if (length < ucxShuffleConf.minRegistrationSize) { + preallocate((ucxShuffleConf.minRegistrationSize / length).toInt) + result = stack.pollFirst() + } else { + logInfo(s"Allocating buffer of size $length") + val memory = ucxContext.memoryMap(memMapParams) + result = new UcxBounceBufferMemoryBlock(memory, new AtomicInteger(1), + memory.getAddress, length) + } + } + result + } + + private[memory] def put(block: UcxBounceBufferMemoryBlock): Unit = { + stack.add(block) + } + + private[memory] def preallocate(numBuffers: Int): Unit = { + logInfo(s"PreAllocating $numBuffers of size $length, " + + s"totalSize: ${Utils.bytesToString(length * numBuffers) }") + val memory = ucxContext.memoryMap( + new UcpMemMapParams().allocate().setMemoryType(memType).setLength(length * numBuffers)) + val refCount = new AtomicInteger(numBuffers) + var offset = 0L + (0 until numBuffers).foreach(_ => { + stack.add(new UcxBounceBufferMemoryBlock(memory, refCount, memory.getAddress + offset, length)) + offset += length + }) + } - def put(mem: MemoryBlock): Unit + override def close(): Unit = { + var numBuffers = 0 + stack.forEach(block => { + block.refCount.decrementAndGet() + if (block.memory.getNativeId != null) { + block.memory.deregister() + } + numBuffers += 1 + }) + logInfo(s"Closing $numBuffers buffers of size $length." 
+ + s"Number of allocations: ${numAllocs.get()}") + stack.clear() + } + } + + override def close(): Unit = { + allocatorMap.values.forEach(allocator => allocator.close()) + allocatorMap.clear() + } + + def get(size: Long): MemoryBlock = { + val roundedSize = roundUpToTheNextPowerOf2(size) + val allocatorStack = allocatorMap.computeIfAbsent(roundedSize, + s => AllocatorStack(s, memoryType)) + val result = allocatorStack.get + new UcxBounceBufferMemoryBlock(result.memory, result.refCount, result.address, size) + } + + def put(mem: MemoryBlock): Unit = { + mem match { + case m: UcxBounceBufferMemoryBlock => + val allocatorStack = allocatorMap.get(roundUpToTheNextPowerOf2(mem.size)) + allocatorStack.put(m) + case _ => + } + } + + def preAllocate(size: Long, numBuffers: Int): Unit = { + val roundedSize = roundUpToTheNextPowerOf2(size) + val allocatorStack = allocatorMap.computeIfAbsent(roundedSize, + s => AllocatorStack(s, memoryType)) + allocatorStack.preallocate(numBuffers) + } } + +class UcxHostBounceBuffersPool(ucxShuffleConf: UcxShuffleConf, ucxContext: UcpContext) + extends MemoryPool(ucxShuffleConf, ucxContext, UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) { + ucxShuffleConf.preallocateBuffersMap.foreach{ + case (size, numBuffers) => preAllocate(size, numBuffers) + } +} + +class UcxGpuBounceBuffersPool(ucxShuffleConf: UcxShuffleConf, ucxContext: UcpContext) + extends MemoryPool(ucxShuffleConf, ucxContext, UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_CUDA) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/memory/UcxHostBounceBuffersPool.scala b/src/main/scala/org/apache/spark/shuffle/ucx/memory/UcxHostBounceBuffersPool.scala deleted file mode 100755 index 90e2f6aa..00000000 --- a/src/main/scala/org/apache/spark/shuffle/ucx/memory/UcxHostBounceBuffersPool.scala +++ /dev/null @@ -1,116 +0,0 @@ -/* -* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. -* See file LICENSE for terms. -*/ -package org.apache.spark.shuffle.ucx.memory - -import java.io.Closeable -import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedDeque} - -import org.openucx.jucx.ucp.{UcpContext, UcpMemMapParams, UcpMemory} -import org.apache.spark.internal.Logging -import org.apache.spark.shuffle.ucx.{MemoryBlock, UcxShuffleConf} - -/** - * Pre-registered host bounce buffers. 
- * TODO: support reclamation - */ -class UcxHostBounceBufferMemoryBlock(private[ucx] val memory: UcpMemory, private[ucx] val refCount: AtomicInteger, - override val address: Long, override val size: Long) - extends MemoryBlock(address, size) - -class UcxHostBounceBuffersPool(ucxShuffleConf: UcxShuffleConf, ucxContext: UcpContext) - extends MemoryPool with Logging { - - private val allocatorMap = new ConcurrentHashMap[Long, AllocatorStack]() - - ucxShuffleConf.preallocateBuffersMap.foreach { - case (size, numBuffers) => - val roundedSize = roundUpToTheNextPowerOf2(size) - logDebug(s"Pre allocating $numBuffers buffers of size $roundedSize") - val allocatorStack = allocatorMap.computeIfAbsent(roundedSize, s => AllocatorStack(s)) - allocatorStack.preallocate(numBuffers) - } - - private case class AllocatorStack(length: Long) extends Closeable { - private val stack = new ConcurrentLinkedDeque[UcxHostBounceBufferMemoryBlock] - private val numAllocs = new AtomicInteger(0) - - private[UcxHostBounceBuffersPool] def get: UcxHostBounceBufferMemoryBlock = { - var result = stack.pollFirst() - if (result == null) { - numAllocs.incrementAndGet() - if (length < ucxShuffleConf.minRegistrationSize) { - preallocate((ucxShuffleConf.minRegistrationSize / length).toInt) - result = stack.pollFirst() - } else { - val memory = ucxContext.memoryMap(new UcpMemMapParams().allocate().setLength(length)) - result = new UcxHostBounceBufferMemoryBlock(memory, new AtomicInteger(1), - memory.getAddress, length) - } - } - result - } - - private[UcxHostBounceBuffersPool] def put(block: UcxHostBounceBufferMemoryBlock): Unit = { - stack.add(block) - } - - private[ucx] def preallocate(numBuffers: Int): Unit = { - val memory = ucxContext.memoryMap( - new UcpMemMapParams().allocate().setLength(length * numBuffers)) - val refCount = new AtomicInteger(numBuffers) - var offset = 0L - (0 until numBuffers).foreach(_ => { - stack.add(new UcxHostBounceBufferMemoryBlock(memory, refCount, memory.getAddress + offset, length)) - offset += length - }) - } - - override def close(): Unit = { - var numBuffers = 0 - stack.forEach(block => { - block.refCount.decrementAndGet() - if (block.memory.getNativeId != null) { - block.memory.deregister() - } - numBuffers += 1 - }) - logInfo(s"Closing $numBuffers buffers of size $length." 
+ - s"Number of allocations: ${numAllocs.get()}") - stack.clear() - } - } - - private def roundUpToTheNextPowerOf2(size: Long): Long = { - // Round up length to the nearest power of two - var length = size - length -= 1 - length |= length >> 1 - length |= length >> 2 - length |= length >> 4 - length |= length >> 8 - length |= length >> 16 - length += 1 - length - } - - override def get(size: Long): MemoryBlock = { - val roundedSize = roundUpToTheNextPowerOf2(size) - val allocatorStack = allocatorMap.computeIfAbsent(roundedSize, s => AllocatorStack(s)) - allocatorStack.get - } - - override def put(mem: MemoryBlock): Unit = { - val allocatorStack = allocatorMap.computeIfAbsent(roundUpToTheNextPowerOf2(mem.size), - s => AllocatorStack(s)) - allocatorStack.put(mem.asInstanceOf[UcxHostBounceBufferMemoryBlock]) - } - - override def close(): Unit = { - allocatorMap.values.forEach(allocator => allocator.close()) - allocatorMap.clear() - } - -} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxShuffleTransportPerfTool.scala b/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxShuffleTransportPerfTool.scala index 039b3219..c111f2a9 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxShuffleTransportPerfTool.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxShuffleTransportPerfTool.scala @@ -6,13 +6,17 @@ package org.apache.spark.shuffle.ucx.perf import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket} import java.nio.ByteBuffer -import java.util.concurrent.atomic.{AtomicInteger, AtomicLong} import java.util.concurrent.{CountDownLatch, Executors, TimeUnit} +import ai.rapids.cudf.DeviceMemoryBuffer import org.apache.commons.cli.{GnuParser, HelpFormatter, Options} +import org.openucx.jucx.UcxUtils +import org.openucx.jucx.ucp.UcpMemMapParams +import org.openucx.jucx.ucs.UcsConstants import org.apache.spark.SparkConf import org.apache.spark.shuffle.ucx._ import org.apache.spark.shuffle.ucx.memory.MemoryPool +import org.apache.spark.unsafe.Platform import org.apache.spark.util.Utils object UcxShuffleTransportPerfTool { @@ -24,16 +28,17 @@ object UcxShuffleTransportPerfTool { private val ITER_OPTION = "i" private val MEMORY_TYPE_OPTION = "m" private val NUM_THREADS_OPTION = "t" + private val REUSE_ADDRESS_OPTION = "r" private val ucxShuffleConf = new UcxShuffleConf(new SparkConf()) private val transport = new UcxShuffleTransport(ucxShuffleConf, "e") private val workerAddress = transport.init() - private var memoryPool: MemoryPool = transport.memoryPool + private var memoryPool: MemoryPool = transport.hostMemoryPool case class TestBlockId(id: Int) extends BlockId case class PerfOptions(remoteAddress: InetSocketAddress, numBlocks: Int, blockSize: Long, - serverPort: Int, numIterations: Int, numThreads: Int) + serverPort: Int, numIterations: Int, numThreads: Int, memoryType: Int) private def initOptions(): Options = { val options = new Options() @@ -78,28 +83,50 @@ object UcxShuffleTransportPerfTool { val threadsNumber = Integer.parseInt(cmd.getOptionValue(NUM_THREADS_OPTION, "1")) + var memoryType = UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST if (cmd.hasOption(MEMORY_TYPE_OPTION) && cmd.getOptionValue(MEMORY_TYPE_OPTION) == "cuda") { - val className = "org.apache.spark.shuffle.ucx.GpuMemoryPool" - val cls = Utils.classForName(className) - memoryPool = cls.getConstructor().newInstance().asInstanceOf[MemoryPool] + memoryPool = transport.deviceMemoryPool + memoryType = UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_CUDA } PerfOptions(inetAddress, 
Integer.parseInt(cmd.getOptionValue(NUM_BLOCKS_OPTION, "1")), Utils.byteStringAsBytes(cmd.getOptionValue(SIZE_OPTION, "4m")), - serverPort, numIterations, threadsNumber) + serverPort, numIterations, threadsNumber, memoryType) } private def startServer(perfOptions: PerfOptions): Unit = { - val blocks: Seq[Block] = (0 until perfOptions.numBlocks).map { _ => - val block = memoryPool.get(perfOptions.blockSize) - new Block { - override def getMemoryBlock: MemoryBlock = - block + val blocks: Seq[Block] = (0 until perfOptions.numBlocks * perfOptions.numThreads).map { _ => + if (ucxShuffleConf.pinMemory) { + val block = memoryPool.get(perfOptions.blockSize) + new Block { + override def getMemoryBlock: MemoryBlock = + block + } + } else if (perfOptions.memoryType == UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_CUDA) { + new Block { + var oldBlock: DeviceMemoryBuffer = _ + override def getMemoryBlock: MemoryBlock = { + if (oldBlock != null) { + oldBlock.close() + } + val block = DeviceMemoryBuffer.allocate(perfOptions.blockSize) + oldBlock = block + MemoryBlock(block.getAddress, block.getLength, isHostMemory = false) + } + } + } else { + new Block { + override def getMemoryBlock: MemoryBlock = { + val buf = ByteBuffer.allocateDirect(perfOptions.blockSize.toInt) + MemoryBlock(UcxUtils.getAddress(buf), perfOptions.blockSize) + } + } } + } - val blockIds = (0 until perfOptions.numBlocks).map(i => TestBlockId(i)) + val blockIds = (0 until perfOptions.numBlocks * perfOptions.numThreads).map(i => TestBlockId(i)) blockIds.zip(blocks).foreach { case (blockId, block) => transport.register(blockId, block) } @@ -147,58 +174,74 @@ object UcxShuffleTransportPerfTool { transport.addExecutor(executorId, workerAddress) - val resultSize = perfOptions.numBlocks * perfOptions.blockSize + val resultSize = perfOptions.numThreads * perfOptions.numBlocks * perfOptions.blockSize val resultMemory = memoryPool.get(resultSize) - - val blockIds = (0 until perfOptions.numBlocks).map(i => TestBlockId(i)) + transport.register(TestBlockId(-1), new Block { + override def getMemoryBlock: MemoryBlock = resultMemory + }) val threadPool = Executors.newFixedThreadPool(perfOptions.numThreads) for (i <- 1 to perfOptions.numIterations) { - val elapsedTime = new AtomicLong(0) + val startTime = System.nanoTime() val countDownLatch = new CountDownLatch(perfOptions.numThreads) - - for (_ <- 0 until perfOptions.numThreads) { + val deviceMemoryBuffers: Array[DeviceMemoryBuffer] = new Array(perfOptions.numThreads * perfOptions.numBlocks) + for (tid <- 0 until perfOptions.numThreads) { threadPool.execute(() => { - val completed = new AtomicInteger(0) + val blocksOffset = tid * perfOptions.numBlocks + val blockIds = (blocksOffset until blocksOffset + perfOptions.numBlocks).map(i => TestBlockId(i)) val mem = new Array[MemoryBlock](perfOptions.numBlocks) - val callbacks = new Array[OperationCallback](perfOptions.numBlocks) for (j <- 0 until perfOptions.numBlocks) { - mem(j) = MemoryBlock(resultMemory.address + j * perfOptions.blockSize, perfOptions.blockSize) - callbacks(j) = (result: OperationResult) => { - elapsedTime.addAndGet(result.getStats.get.getElapsedTimeNs) - completed.incrementAndGet() + mem(j) = MemoryBlock(resultMemory.address + + (tid * perfOptions.numBlocks * perfOptions.blockSize) + j * perfOptions.blockSize, perfOptions.blockSize) + + if (!ucxShuffleConf.pinMemory && perfOptions.memoryType == UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) { + val buffer = Platform.allocateDirectBuffer(perfOptions.blockSize.toInt) + mem(j) = 
MemoryBlock(UcxUtils.getAddress(buffer), perfOptions.blockSize) + } else if (!ucxShuffleConf.pinMemory && + perfOptions.memoryType == UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_CUDA) { + val buffer = DeviceMemoryBuffer.allocate(perfOptions.blockSize) + deviceMemoryBuffers(tid * perfOptions.numBlocks + j) = buffer + mem(j) = MemoryBlock(buffer.getAddress, buffer.getLength, isHostMemory = false) + } else { + mem(j) = MemoryBlock(resultMemory.address + + (tid * perfOptions.numBlocks * perfOptions.blockSize) + j * perfOptions.blockSize, perfOptions.blockSize) } } - transport.fetchBlocksByBlockIds(executorId, blockIds, mem, callbacks) + val requests = transport.fetchBlocksByBlockIds(executorId, blockIds, mem, Seq.empty[OperationCallback]) - while (completed.get() != perfOptions.numBlocks) { + while (!requests.forall(_.isCompleted)) { transport.progress() } + countDownLatch.countDown() }) } countDownLatch.await() + val elapsedTime = System.nanoTime() - startTime - val totalTime = if (elapsedTime.get() < TimeUnit.MILLISECONDS.toNanos(1)) { + deviceMemoryBuffers.foreach(d => if (d != null) d.close()) + val totalTime = if (elapsedTime < TimeUnit.MILLISECONDS.toNanos(1)) { s"$elapsedTime ns" } else { - s"${TimeUnit.NANOSECONDS.toMillis(elapsedTime.get())} ms" + s"${TimeUnit.NANOSECONDS.toMillis(elapsedTime)} ms" + } + val throughput: Double = (resultSize / 1024.0D / 1024.0D / 1024.0D) / + (elapsedTime / 1e9D) + + if ((i % 100 == 0) || i == perfOptions.numIterations) { + println(f"${s"[$i/${perfOptions.numIterations}]"}%12s" + + s" numBlocks: ${perfOptions.numBlocks}" + + s" numThreads: ${perfOptions.numThreads}" + + s" size: ${Utils.bytesToString(perfOptions.blockSize)}," + + s" total size: ${Utils.bytesToString(resultSize * perfOptions.numThreads)}," + + f" time: $totalTime%3s" + + f" throughput: $throughput%.5f GB/s") } - val throughput: Double = (resultSize * perfOptions.numThreads / 1024.0D / 1024.0D / 1024.0D) / - (elapsedTime.get() / 1e9D) - - println(f"${s"[$i/${perfOptions.numIterations}]"}%12s" + - s" numBlocks: ${perfOptions.numBlocks}" + - s" numThreads: ${perfOptions.numThreads}" + - s" size: ${Utils.bytesToString(perfOptions.blockSize)}," + - s" total size: ${Utils.bytesToString(resultSize * perfOptions.numThreads)}," + - f" time: $totalTime%3s" + - f" throughput: $throughput%.5f GB/s") } val out = socket.getOutputStream @@ -207,8 +250,10 @@ object UcxShuffleTransportPerfTool { out.close() socket.close() + transport.unregister(TestBlockId(-1)) memoryPool.put(resultMemory) transport.close() + threadPool.shutdown() } def main(args: Array[String]): Unit = { diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala index 759d50e6..aec6c17e 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala @@ -8,77 +8,55 @@ import java.io.ObjectInputStream import java.nio.ByteBuffer import com.fasterxml.jackson.databind.util.ByteBufferBackedInputStream +import org.openucx.jucx.UcxUtils import org.openucx.jucx.ucp._ import org.openucx.jucx.ucs.UcsConstants -import org.openucx.jucx.{UcxCallback, UcxUtils} import org.apache.spark.internal.Logging import org.apache.spark.shuffle.ucx.UcxShuffleTransport -import org.apache.spark.shuffle.ucx.memory.MemoryPool -import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{FetchBlockByBlockIdRequest, FetchBlocksByBlockIdsRequest} +import 
org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.FetchBlocksByBlockIdsRequest import org.apache.spark.util.Utils -class GlobalWorkerRpcThread(globalWorker: UcpWorker, memPool: MemoryPool, - transport: UcxShuffleTransport) extends Thread with Logging { +class GlobalWorkerRpcThread(globalWorker: UcpWorker, transport: UcxShuffleTransport) + extends Thread with Logging { setDaemon(true) setName("Ucx Shuffle Transport Progress Thread") private def handleFetchBlockRequest(buffer: ByteBuffer, ep: UcpEndpoint): Unit = { - val fetchSingleBlock = buffer.get() - if (fetchSingleBlock == UcxRpcMessages.FETCH_SINGLE_BLOCK_TAG.toByte) { - val msg = Utils.tryWithResource(new ByteBufferBackedInputStream(buffer)) { bin => - val objIn = new ObjectInputStream(bin) - val obj = objIn.readObject().asInstanceOf[FetchBlockByBlockIdRequest] - objIn.close() - obj - } - logInfo(s"Requested single block msg: $msg") - transport.replyFetchBlockRequest(msg.blockId, ep, msg.msgId) - } else { - val msg = Utils.tryWithResource(new ByteBufferBackedInputStream(buffer)) { bin => - val objIn = new ObjectInputStream(bin) - val obj = objIn.readObject().asInstanceOf[FetchBlocksByBlockIdsRequest] - objIn.close() - obj - } - logInfo(s"Requested blocks msg: ${msg.blockIds.mkString(",")}") - for (i <- msg.blockIds.indices) { - transport.replyFetchBlockRequest(msg.blockIds(i), ep, msg.startTag + i) - } + val msg = Utils.tryWithResource(new ByteBufferBackedInputStream(buffer)) { bin => + val objIn = new ObjectInputStream(bin) + val obj = objIn.readObject().asInstanceOf[FetchBlocksByBlockIdsRequest] + objIn.close() + obj } - + val startTime = System.nanoTime() + transport.replyFetchBlocksRequest(msg.blockIds, msg.isHostMemory, ep, msg.startTag, msg.singleReply) + logInfo(s"Sent reply for ${msg.blockIds.length} blocks in ${Utils.getUsedTimeNs(startTime)}") } override def run(): Unit = { val processCallback: UcpAmRecvCallback = (headerAddress: Long, headerSize: Long, amData: UcpAmData, replyEp: UcpEndpoint) => { if (headerSize > 0) { - logInfo(s"Received AM in header") - val header = UcxUtils.getByteBufferView(headerAddress, headerSize.toInt) + logDebug(s"Received AM in header on ${transport.clientConnections.get(replyEp)}") + val header = UcxUtils.getByteBufferView(headerAddress, headerSize) handleFetchBlockRequest(header, replyEp) UcsConstants.STATUS.UCS_OK } else { - if (amData.isDataValid) { - logInfo(s"Received AM in eager") - val data = UcxUtils.getByteBufferView(amData.getDataAddress, amData.getLength.toInt) - handleFetchBlockRequest(data, replyEp) - UcsConstants.STATUS.UCS_OK - } else { - val recvData = memPool.get(amData.getLength) - amData.receive(recvData.address, new UcxCallback { - override def onSuccess(request: UcpRequest): Unit = { - logInfo(s"Received AM in rndv") - val data = UcxUtils.getByteBufferView(recvData.address, - request.getRecvSize.toInt) - handleFetchBlockRequest(data, replyEp) - memPool.put(recvData) - } - }) - } + assert(amData.isDataValid) + logDebug(s"Received AM in eager on ${transport.clientConnections.get(replyEp)}") + val data = UcxUtils.getByteBufferView(amData.getDataAddress, amData.getLength) + handleFetchBlockRequest(data, replyEp) UcsConstants.STATUS.UCS_OK } } globalWorker.setAmRecvHandler(0, processCallback) + globalWorker.setAmRecvHandler(-1, new UcpAmRecvCallback() { + override def onReceive(headerAddress: Long, headerSize: Long, amData: UcpAmData, replyEp: UcpEndpoint): Int = { + logTrace(s"Hello") + UcsConstants.STATUS.UCS_OK + } + }) while (!isInterrupted) { if (globalWorker.progress() == 0) 
{ if (transport.ucxShuffleConf.useWakeup) { diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala index 2a883b78..432d6e32 100644 --- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala @@ -9,11 +9,6 @@ import org.apache.spark.shuffle.ucx.BlockId import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer object UcxRpcMessages { - - val PREFETCH_TAG = 1 - val FETCH_SINGLE_BLOCK_TAG = 2 - val FETCH_MULTIPLE_BLOCKS_TAG = 3 - /** * Called from executor to driver, to introduce ucx worker address. */ @@ -27,9 +22,6 @@ object UcxRpcMessages { case class IntroduceAllExecutors(executorIds: Seq[String], ucxWorkerAddresses: Seq[SerializableDirectBuffer]) - case class FetchBlockByBlockIdRequest(msgId: Int, blockId: BlockId) - - case class FetchBlocksByBlockIdsRequest(startTag: Int, blockIds: Seq[BlockId]) - - case class PrefetchBlockIds(blockIds: Seq[BlockId]) + case class FetchBlocksByBlockIdsRequest(startTag: Int, blockIds: Seq[BlockId], + isHostMemory: Seq[Boolean], singleReply: Boolean = false) } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/utils/UcxHelperUtils.scala b/src/main/scala/org/apache/spark/shuffle/ucx/utils/UcxHelperUtils.scala index ef342c79..37c7d34e 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/utils/UcxHelperUtils.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/utils/UcxHelperUtils.scala @@ -2,17 +2,25 @@ package org.apache.spark.shuffle.ucx.utils import java.net.{BindException, InetSocketAddress} +import scala.collection.mutable import scala.util.Random import org.openucx.jucx.UcxException -import org.openucx.jucx.ucp.{UcpListener, UcpListenerParams, UcpWorker} +import org.openucx.jucx.ucp._ import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.util.Utils object UcxHelperUtils extends Logging{ - def startListenerOnRandomPort(worker: UcpWorker, sparkConf: SparkConf): UcpListener = { - val ucpListenerParams = new UcpListenerParams() + def startListenerOnRandomPort(worker: UcpWorker, sparkConf: SparkConf, + endpoints: mutable.HashMap[UcpEndpoint, (UcpEndpoint, InetSocketAddress)]): + UcpListener = { + val ucpListenerParams = new UcpListenerParams().setConnectionHandler( + (ucpConnectionRequest: UcpConnectionRequest) => { + val repyEp = worker.newEndpoint(new UcpEndpointParams().setPeerErrorHandlingMode() + .setConnectionRequest(ucpConnectionRequest)) + endpoints += repyEp -> (repyEp, ucpConnectionRequest.getClientAddress) + }) val (listener, _) = Utils.startServiceOnPort(1024 + Random.nextInt(65535 - 1024), (port: Int) => { ucpListenerParams.setSockAddr(new InetSocketAddress(port)) val listener = try { diff --git a/src/test/scala/org/apache/spark/shuffle/ucx/GpuMemoryPool.scala b/src/test/scala/org/apache/spark/shuffle/ucx/GpuMemoryPool.scala deleted file mode 100755 index 6d3e633a..00000000 --- a/src/test/scala/org/apache/spark/shuffle/ucx/GpuMemoryPool.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* -* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. -* See file LICENSE for terms. 
-*/ -package org.apache.spark.shuffle.ucx - -import ai.rapids.cudf.DeviceMemoryBuffer -import org.apache.spark.shuffle.ucx.memory.MemoryPool - - -/** - * Test GPU mempool to run with [[ UcxShuffleTransportPerfTool ]] - */ -class GpuMemoryPool extends MemoryPool { - - class GpuMemoryBlock(val deviceBuffer: DeviceMemoryBuffer, - override val address: Long, override val size: Long) - extends MemoryBlock(address, size, isHostMemory = false) - - override def get(size: Long): MemoryBlock = { - val deviceBuffer = DeviceMemoryBuffer.allocate(size) - new GpuMemoryBlock(deviceBuffer, deviceBuffer.getAddress, size) - } - - override def put(mem: MemoryBlock): Unit = { - mem.asInstanceOf[GpuMemoryBlock].deviceBuffer.close() - } - - override def close(): Unit = ??? -} diff --git a/src/test/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransportTestSuite.scala b/src/test/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransportTestSuite.scala index 7997042a..6c107d6b 100755 --- a/src/test/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransportTestSuite.scala +++ b/src/test/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransportTestSuite.scala @@ -76,20 +76,16 @@ class UcxShuffleTransportTestSuite extends AnyFunSuite { val resultMemory2 = MemoryBlock(UcxUtils.getAddress(resultBuffer2), blockSize) val completed = new AtomicInteger(0) + val callback: OperationCallback = (result: OperationResult) => { + completed.incrementAndGet() + assert(result.getStats.get.recvSize == blockSize) + assert(result.getStatus == OperationStatus.SUCCESS) + } - transport.fetchBlockByBlockId(server.ID, TestBlockId(1), resultMemory1, - (result: OperationResult) => { - completed.incrementAndGet() - assert(result.getStats.get.recvSize == blockSize) - assert(result.getStatus == OperationStatus.SUCCESS) - }) - - transport.fetchBlockByBlockId(server.ID, TestBlockId(2), resultMemory2, - (result: OperationResult) => { - completed.incrementAndGet() - assert(result.getStats.get.recvSize == blockSize) - assert(result.getStatus == OperationStatus.SUCCESS) - }) + transport.fetchBlocksByBlockIds(server.ID, + Array(TestBlockId(1), TestBlockId(2)), + Array(resultMemory1, resultMemory2), + Array(callback, callback)) while (completed.get() != 2) { transport.progress() @@ -107,7 +103,9 @@ class UcxShuffleTransportTestSuite extends AnyFunSuite { while (!mutated.get()) {} - val request = transport.fetchBlockByBlockId(server.ID, TestBlockId(2), resultMemory2, null) + val request = transport.fetchBlocksByBlockIds(server.ID, + Array(TestBlockId(2)), Array(resultMemory2), Seq.empty)(0) + while (!request.isCompleted) { transport.progress() }
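
A minimal usage sketch of the batch-fetch path exercised by the perf tool and test suite above, assuming a server side that has already registered the requested blocks and whose worker address has been exchanged (normally via the driver RPC endpoint). SketchBlockId, the executor id "server", and the 4 KB block size are illustrative only and not part of the patch:

import java.nio.ByteBuffer
import org.openucx.jucx.UcxUtils
import org.apache.spark.SparkConf
import org.apache.spark.shuffle.ucx._

object BatchFetchSketch {
  // Illustrative block id; the patch's own test and perf tool use a similar TestBlockId case class.
  case class SketchBlockId(id: Int) extends BlockId

  def main(args: Array[String]): Unit = {
    val transport = new UcxShuffleTransport(new UcxShuffleConf(new SparkConf()), "client")
    val workerAddress = transport.init()
    // In a real deployment the remote worker address arrives through the executor/driver RPC
    // endpoints; here it is assumed to have been obtained out of band:
    // transport.addExecutor("server", remoteWorkerAddress)

    val blockSize = 4096L
    val blockIds = (0 until 2).map(i => SketchBlockId(i))

    // Host bounce buffers that will receive the fetched blocks.
    val buffers = blockIds.map { _ =>
      val buf = ByteBuffer.allocateDirect(blockSize.toInt)
      MemoryBlock(UcxUtils.getAddress(buf), blockSize)
    }

    // One active-message RPC fetches the whole batch; poll progress until every
    // per-block request completes, as the perf tool does.
    val requests = transport.fetchBlocksByBlockIds("server", blockIds, buffers,
      Seq.empty[OperationCallback])
    while (!requests.forall(_.isCompleted)) {
      transport.progress()
    }
    transport.close()
  }
}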