diff --git a/pom.xml b/pom.xml index d9330461..1f5c22b9 100755 --- a/pom.xml +++ b/pom.xml @@ -34,7 +34,8 @@ See file LICENSE for terms. 3.0.0 2.12.12 2.12 - 1.10.0-SNAPSHOT + 1.11.0-rc3 + 0.18.1 @@ -55,26 +56,34 @@ See file LICENSE for terms. 3.2.1 test + + ai.rapids + cudf + ${cudf.version} + provided + - ${project.artifactId}-${project.version}-for-spark-${spark.version} + ${project.artifactId}-${project.version}-for-spark-3.0 org.apache.maven.plugins maven-compiler-plugin 3.8.1 - - 1.8 - 1.8 - net.alchim31.maven scala-maven-plugin - 4.4.0 + 4.4.1 all + + -source + ${maven.compiler.source} + -target + ${maven.compiler.target} + -Xexperimental -Xfatal-warnings @@ -137,6 +146,18 @@ See file LICENSE for terms. + + org.apache.maven.plugins + maven-jar-plugin + 3.2.0 + + + + test-jar + + + + diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala index 0bd92891..bd8adaef 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala @@ -135,11 +135,6 @@ trait ShuffleTransport { */ def unregister(blockId: BlockId) - /** - * Hint for a transport that these blocks would needed soon. - */ - def prefetchBlocks(executorId: String, blockIds: Seq[BlockId]) - /** * Batch version of [[ fetchBlocksByBlockIds ]]. */ @@ -147,11 +142,9 @@ trait ShuffleTransport { resultBuffer: Seq[MemoryBlock], callbacks: Seq[OperationCallback]): Seq[Request] - /** - * Fetch remote blocks by blockIds. - */ - def fetchBlockByBlockId(executorId: String, blockId: BlockId, - resultBuffer: MemoryBlock, cb: OperationCallback): Request + def fetchBlocksByBlockIds(executorId: String, blockIds: Seq[BlockId], + resultBuffer: MemoryBlock, + callbacks: OperationCallback): Request /** * Progress outstanding operations. This routine is blocking (though may poll for event). 
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleBlockResolver.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleBlockResolver.scala new file mode 100755 index 00000000..420d161f --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleBlockResolver.scala @@ -0,0 +1,120 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.ucx + +import java.io.File +import java.nio.ByteBuffer +import java.nio.channels.FileChannel +import java.nio.file.StandardOpenOption +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConverters._ + +import org.openucx.jucx.UcxUtils +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.unsafe.Platform + + +case class UcxShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId { + + def this(shuffleBlockId: ShuffleBlockId) = { + this(shuffleBlockId.shuffleId, shuffleBlockId.mapId, shuffleBlockId.reduceId) + } + + def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId +} + +case class BufferBackedBlock(buffer: ByteBuffer) extends Block { + override def getMemoryBlock: MemoryBlock = MemoryBlock(UcxUtils.getAddress(buffer), buffer.capacity()) +} + +/** + * Block resolver that, after committing a map output, memory-maps the data file and registers + * the whole file plus every non-empty partition with the UCX transport. + */ +class UcxShuffleBlockResolver(conf: UcxShuffleConf, transport: UcxShuffleTransport) + extends IndexShuffleBlockResolver(conf) { + + type MapId = Long + + // Partition count of every committed map output, keyed by (shuffleId, mapId) so that stop() + // can unregister each output under the shuffle it was written for. + private val numPartitionsForMapId = new ConcurrentHashMap[(ShuffleId, MapId), Int] + + /** + * Commits the map output through the parent resolver, then registers the mapped data file, + * each non-empty partition and, for the ONE_SIDED protocol, the offsets buffer. + */ + override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Long, + lengths: Array[Long], dataTmp: File): Unit = { + super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp) + val dataFile = getDataFile(shuffleId, mapId) + if (!dataFile.exists()) { + return + } + numPartitionsForMapId.put((shuffleId, mapId), lengths.length) + val fileChannel = FileChannel.open(dataFile.toPath, StandardOpenOption.READ, 
StandardOpenOption.WRITE) + val mappedBuffer = fileChannel.map(FileChannel.MapMode.READ_WRITE, 0L, dataFile.length()) + + val baseAddress = UcxUtils.getAddress(mappedBuffer) + fileChannel.close() + + // Register whole map output file as dummy block + transport.register(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.MAP_FILE), + BufferBackedBlock(mappedBuffer)) + + val offsetSize = 8 * (lengths.length + 1) + val indexBuf = Platform.allocateDirectBuffer(offsetSize) + + var offset = 0L + indexBuf.putLong(offset) + for (reduceId <- lengths.indices) { + if (lengths(reduceId) > 0) { + transport.register(UcxShuffleBlockId(shuffleId, mapId, reduceId), new Block { + private val memoryBlock = MemoryBlock(baseAddress + offset, lengths(reduceId)) + override def getMemoryBlock: MemoryBlock = memoryBlock + }) + offset += lengths(reduceId) + } + // Write an end offset for every partition, empty ones included, so that entry i of the + // buffer always belongs to partition i and all lengths.length + 1 slots are filled. + indexBuf.putLong(offset) + } + + if (transport.ucxShuffleConf.protocol == transport.ucxShuffleConf.PROTOCOL.ONE_SIDED) { + transport.register(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.INDEX_FILE), BufferBackedBlock(indexBuf)) + } + } + + /** + * Unregisters the map file, the index buffer and every partition block of one map output, + * forgets its partition count and then delegates to the parent resolver. + */ + override def removeDataByMap(shuffleId: ShuffleId, mapId: Long): Unit = { + transport.unregister(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.MAP_FILE)) + transport.unregister(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.INDEX_FILE)) + + // remove() rather than get(): the entry must not outlive the map output. A missing entry + // unboxes to 0, which makes the range below empty. + val numRegisteredBlocks = numPartitionsForMapId.remove((shuffleId, mapId)) + (0 until numRegisteredBlocks) + .foreach(reduceId => transport.unregister(UcxShuffleBlockId(shuffleId, mapId, reduceId))) + super.removeDataByMap(shuffleId, mapId) + } + + /** Unregisters every map output still tracked by this resolver, then stops the parent. */ + override def stop(): Unit = { + numPartitionsForMapId.keys.asScala.foreach { case (sId, mId) => removeDataByMap(sId, mId) } + super.stop() + } + +} + +object BlocksConstants { + val MAP_FILE: Int = -1 + val INDEX_FILE: Int = -2 +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala new file mode 100755 index 00000000..2fc9cbb1 --- 
/dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala @@ -0,0 +1,77 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.ucx + +import java.util.concurrent.TimeUnit + +import org.openucx.jucx.{UcxException, UcxUtils} +import org.apache.spark.SparkEnv +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient, DownloadFileManager} +import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId, BlockId => SparkBlockId} + +class UcxShuffleClient(transport: UcxShuffleTransport, + blocksByAddress: Iterator[(BlockManagerId, Seq[(SparkBlockId, Long, Int)])]) + extends BlockStoreClient with Logging { + + private val accurateThreshold = transport.ucxShuffleConf.conf.getSizeAsBytes(SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key) + + private val blockSizes: Map[SparkBlockId, Long] = blocksByAddress + .withFilter { case (blockManagerId, _) => blockManagerId != SparkEnv.get.blockManager.blockManagerId } + .flatMap { + case (blockManagerId, blocks) => + blocks.map { + case (blockId, length, _) => + if (length > accurateThreshold) { + (blockId, (length * 1.2).toLong) + } else { + (blockId, accurateThreshold * 2) + } + } + }.toMap + + override def fetchBlocks(host: String, port: Int, execId: String, + blockIds: Array[String], listener: BlockFetchingListener, + downloadFileManager: DownloadFileManager): Unit = { + val ucxBlockIds = new Array[BlockId](blockIds.length) + val memoryBlocks = new Array[MemoryBlock](blockIds.length) + val callbacks = new Array[OperationCallback](blockIds.length) + for (i <- blockIds.indices) { + val blockId = SparkBlockId.apply(blockIds(i)).asInstanceOf[ShuffleBlockId] + if (!blockSizes.contains(blockId)) { + throw new 
UcxException(s"No $blockId found in MapOutput blocks: ${blockSizes.keys.mkString(",")}") + } + val resultMemory = transport.hostMemoryPool.get(blockSizes(blockId)) + ucxBlockIds(i) = UcxShuffleBlockId(blockId.shuffleId, blockId.mapId, blockId.reduceId) + memoryBlocks(i) = MemoryBlock(resultMemory.address, blockSizes(blockId)) + // Completion callback for block i: on success the received bytes are handed to the listener + // as a buffer whose release() returns the bounce buffer to the pool. + callbacks(i) = (result: OperationResult) => { + if (result.getStatus == OperationStatus.SUCCESS) { + val stats = result.getStats.get + logInfo(s" Received block ${ucxBlockIds(i)} " + + s"of size: ${stats.recvSize} " + + s"in ${TimeUnit.NANOSECONDS.toMillis(stats.getElapsedTimeNs)} ms") + val buffer = UcxUtils.getByteBufferView(resultMemory.address, stats.recvSize.toInt) + listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) { + override def release: ManagedBuffer = { + transport.hostMemoryPool.put(resultMemory) + this + } + }) + } else { + logError(s"Error fetching block $blockId of size ${blockSizes(blockId)}:" + + s" ${result.getError.getMessage}") + // No ManagedBuffer is handed out on this path, so nothing else would ever return the + // bounce buffer to the pool. + // NOTE(review): assumes the transport no longer writes to the buffer once it reports failure. + transport.hostMemoryPool.put(resultMemory) + listener.onBlockFetchFailure(blockIds(i), new UcxException(result.getError.getMessage)) + } + } + } + transport.fetchBlocksByBlockIds(execId, ucxBlockIds, memoryBlocks, callbacks) + } + + override def close(): Unit = { + + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala index 47c9fd73..56867492 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala @@ -12,15 +12,20 @@ import org.apache.spark.util.Utils class UcxShuffleConf(val conf: SparkConf) extends SparkConf { private def getUcxConf(name: String) = s"spark.shuffle.ucx.$name" - private val PROTOCOL = + object PROTOCOL extends Enumeration { + val ONE_SIDED, RNDV = Value + } + + private lazy val PROTOCOL_CONF = ConfigBuilder(getUcxConf("protocol")) - .doc("Which protocol to use: rndv (default), one-sided") + .doc("Which 
protocol to use: RNDV (default), ONE-SIDED") .stringConf .checkValue(protocol => protocol == "rndv" || protocol == "one-sided", "Invalid protocol. Valid options: rndv / one-sided.") - .createWithDefault("rndv") + .transform(_.toUpperCase.replace("-", "_")) + .createWithDefault("RNDV") - private val MEMORY_PINNING = + private lazy val MEMORY_PINNING = ConfigBuilder(getUcxConf("memoryPinning")) .doc("Whether to pin whole shuffle data in memory") .booleanConf @@ -30,15 +35,7 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf { ConfigBuilder(getUcxConf("maxWorkerSize")) .doc("Maximum size of worker address in bytes") .bytesConf(ByteUnit.BYTE) - .createWithDefault(1000) - - lazy val RPC_MESSAGE_SIZE: ConfigEntry[Long] = - ConfigBuilder(getUcxConf("rpcMessageSize")) - .doc("Size of RPC message to send from fetchBlockByBlockId. Must contain ") - .bytesConf(ByteUnit.BYTE) - .checkValue(size => size > maxWorkerAddressSize, - "Rpc message must contain workerAddress") - .createWithDefault(2000) + .createWithDefault(1024) // Memory Pool private lazy val PREALLOCATE_BUFFERS = @@ -52,11 +49,11 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf { .booleanConf .createWithDefault(false) - private lazy val RECV_QUEUE_SIZE = - ConfigBuilder(getUcxConf("recvQueueSize")) - .doc("The number of submitted receive requests.") - .intConf - .createWithDefault(5) + private lazy val USE_SOCKADDR = + ConfigBuilder(getUcxConf("useSockAddr")) + .doc("Whether to use socket address to connect executors.") + .booleanConf + .createWithDefault(true) private lazy val MIN_REGISTRATION_SIZE = ConfigBuilder(getUcxConf("memory.minAllocationSize")) @@ -67,24 +64,32 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf { lazy val minRegistrationSize: Int = conf.getSizeAsBytes(MIN_REGISTRATION_SIZE.key, MIN_REGISTRATION_SIZE.defaultValueString).toInt - lazy val protocol: String = conf.get(PROTOCOL.key, PROTOCOL.defaultValueString) + private lazy val USE_ODP = + 
ConfigBuilder(getUcxConf("useOdp")) + .doc("Whether to use on demand paging feature, to avoid memory pinning") + .booleanConf + .createWithDefault(false) - lazy val useOdp: Boolean = conf.getBoolean(getUcxConf("memory.useOdp"), defaultValue = false) + // The raw string is read here, so the transform declared on PROTOCOL_CONF is not applied. + // Normalise explicitly so that "rndv" / "one-sided" (any case) resolve to the enumeration + // names; Locale.ROOT keeps the upper-casing independent of the JVM default locale. + lazy val protocol: PROTOCOL.Value = PROTOCOL.withName( + conf.get(PROTOCOL_CONF.key, PROTOCOL_CONF.defaultValueString) + .toUpperCase(java.util.Locale.ROOT).replace("-", "_")) + + // Falls back to the former "memory.useOdp" key so that existing configurations keep working. + lazy val useOdp: Boolean = conf.getBoolean(USE_ODP.key, + conf.getBoolean(getUcxConf("memory.useOdp"), USE_ODP.defaultValue.get)) lazy val pinMemory: Boolean = conf.getBoolean(MEMORY_PINNING.key, MEMORY_PINNING.defaultValue.get) lazy val maxWorkerAddressSize: Long = conf.getSizeAsBytes(WORKER_ADDRESS_SIZE.key, WORKER_ADDRESS_SIZE.defaultValueString) - lazy val rpcMessageSize: Long = conf.getSizeAsBytes(RPC_MESSAGE_SIZE.key, - RPC_MESSAGE_SIZE.defaultValueString) + lazy val maxMetadataSize: Long = conf.getSizeAsBytes("spark.rapids.shuffle.maxMetadataSize", + "1024") + lazy val useWakeup: Boolean = conf.getBoolean(WAKEUP_FEATURE.key, WAKEUP_FEATURE.defaultValue.get) - lazy val recvQueueSize: Int = conf.getInt(RECV_QUEUE_SIZE.key, RECV_QUEUE_SIZE.defaultValue.get) + lazy val useSockAddr: Boolean = conf.getBoolean(USE_SOCKADDR.key, USE_SOCKADDR.defaultValue.get) lazy val preallocateBuffersMap: Map[Long, Int] = { - conf.get(PREALLOCATE_BUFFERS).split(",").withFilter(s => !s.isEmpty) + conf.get(PREALLOCATE_BUFFERS.key, "").split(",").withFilter(s => s.nonEmpty) .map(entry => entry.split(":") match { case Array(bufferSize, bufferCount) => (Utils.byteStringAsBytes(bufferSize.trim), bufferCount.toInt) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleManager.scala new file mode 100755 index 00000000..3d957b0d --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleManager.scala @@ -0,0 +1,102 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. 
+*/ +package org.apache.spark.shuffle.ucx + +import scala.concurrent.ExecutionContext.Implicits.global +import scala.util.Success + +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch +import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{ExecutorAdded, IntroduceAllExecutors} +import org.apache.spark.shuffle.ucx.rpc.{UcxDriverRpcEndpoint, UcxExecutorRpcEndpoint} +import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer +import org.apache.spark.util.RpcUtils +import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, TaskContext} + + +class UcxShuffleManager(conf: SparkConf, isDriver: Boolean) extends SortShuffleManager(conf) { + + val ucxShuffleConf = new UcxShuffleConf(conf) + + lazy val ucxShuffleTransport: UcxShuffleTransport = if (!isDriver) { + new UcxShuffleTransport(ucxShuffleConf, "init") + } else { + null + } + + @volatile private var initialized: Boolean = false + + override val shuffleBlockResolver = + new UcxShuffleBlockResolver(ucxShuffleConf, ucxShuffleTransport) + + logInfo("Starting UcxShuffleManager") + + def initTransport(): Unit = this.synchronized { + if (!initialized) { + val driverEndpointName = "ucx-shuffle-driver" + if (isDriver) { + val rpcEnv = SparkEnv.get.rpcEnv + val driverEndpoint = new UcxDriverRpcEndpoint(rpcEnv) + rpcEnv.setupEndpoint(driverEndpointName, driverEndpoint) + } else { + val blockManager = SparkEnv.get.blockManager.blockManagerId + ucxShuffleTransport.executorId = blockManager.executorId + val rpcEnv = RpcEnv.create("ucx-rpc-env", blockManager.host, blockManager.host, + blockManager.port, conf, new SecurityManager(conf), 1, clientMode=false) + logDebug("Initializing ucx transport") + val address = ucxShuffleTransport.init() + val executorEndpoint = new UcxExecutorRpcEndpoint(rpcEnv, ucxShuffleTransport) + val endpoint = rpcEnv.setupEndpoint( + 
s"ucx-shuffle-executor-${blockManager.executorId}", + executorEndpoint) + + val driverEndpoint = RpcUtils.makeDriverRef(driverEndpointName, conf, rpcEnv) + driverEndpoint.ask[IntroduceAllExecutors](ExecutorAdded(blockManager.executorId, + endpoint, new SerializableDirectBuffer(address))) + .andThen{ + case Success(msg) => + logInfo(s"Receive reply $msg") + executorEndpoint.receive(msg) + // A failed ask would otherwise be dropped silently; log it so the problem is visible. + case scala.util.Failure(e) => + logError(s"Failed to register executor ${blockManager.executorId} with $driverEndpointName", e) + } + } + initialized = true + } + } + + /** + * Builds a [[UcxShuffleReader]] over the block locations the map output tracker reports for + * the requested partition range of this shuffle. + */ + override def getReader[K, C](handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( + handle.shuffleId, startPartition, endPartition) + new UcxShuffleReader(ucxShuffleTransport, + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics, + shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context)) + } + + /** + * Same as getReader, but the block locations are additionally restricted to the requested + * map index range. + */ + override def getReaderForRange[K, C]( handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByRange( + handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) + new UcxShuffleReader(ucxShuffleTransport, + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics, + shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context)) + } + + /** Closes the UCX transport when one exists (it is null on the driver), then stops the parent. */ + override def stop(): Unit = { + if (ucxShuffleTransport != null) { + ucxShuffleTransport.close() + } + super.stop() + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleReader.scala new file mode 100755 index 00000000..6ee3e682 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleReader.scala @@ -0,0 +1,173 @@ +/* +* 
Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle + +import java.io.InputStream +import java.util.concurrent.LinkedBlockingQueue + +import scala.util.Random + +import org.apache.spark._ +import org.apache.spark.internal.{Logging, config} +import org.apache.spark.io.CompressionCodec +import org.apache.spark.serializer.SerializerManager +import org.apache.spark.shuffle.ucx.{UcxShuffleClient, UcxShuffleTransport} +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator} +import org.apache.spark.util.CompletionIterator +import org.apache.spark.util.collection.ExternalSorter + +/** + * Fetches and reads the blocks from a shuffle by requesting them from other nodes' block stores. + */ +class UcxShuffleReader[K, C](transport: UcxShuffleTransport, + handle: BaseShuffleHandle[K, _, C], + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], + context: TaskContext, + readMetrics: ShuffleReadMetricsReporter, + serializerManager: SerializerManager = SparkEnv.get.serializerManager, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, + shouldBatchFetch: Boolean = false) + extends ShuffleReader[K, C] with Logging { + + private val dep = handle.dependency + + private def fetchContinuousBlocksInBatch: Boolean = { + val conf = SparkEnv.get.conf + val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects + val compressed = conf.get(config.SHUFFLE_COMPRESS) + val codecConcatenation = if (compressed) { + CompressionCodec.supportsConcatenationOfSerializedStreams(CompressionCodec.createCodec(conf)) + } else { + true + } + val useOldFetchProtocol = conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL) + + val doBatchFetch = shouldBatchFetch && serializerRelocatable && + (!compressed || codecConcatenation) && !useOldFetchProtocol + if (shouldBatchFetch 
&& !doBatchFetch) { + logDebug("The feature tag of continuous shuffle block fetching is set to true, but " + + "we can not enable the feature because other conditions are not satisfied. " + + s"Shuffle compress: $compressed, serializer relocatable: $serializerRelocatable, " + + s"codec concatenation: $codecConcatenation, use old shuffle fetch protocol: " + + s"$useOldFetchProtocol.") + } + doBatchFetch + } + + /** Read the combined key-values for this reduce task */ + override def read(): Iterator[Product2[K, C]] = { + val (blocksByAddress1, blocksByAddress2) = blocksByAddress.duplicate + val shuffleClient = new UcxShuffleClient(transport, Random.shuffle(blocksByAddress1)) + val shuffleIterator = new ShuffleBlockFetcherIterator( + context, + shuffleClient, + blockManager, + Random.shuffle(blocksByAddress2), + serializerManager.wrapStream, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, + SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), + SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), + SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), + readMetrics, + fetchContinuousBlocksInBatch) + val wrappedStreams = shuffleIterator.toCompletionIterator + + // Ucx shuffle logic + // Java reflection to get access to private results queue + val queueField = classOf[ShuffleBlockFetcherIterator].getDeclaredField( + "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results") + queueField.setAccessible(true) + val resultQueue = queueField.get(shuffleIterator).asInstanceOf[LinkedBlockingQueue[_]] + + // Do progress if queue is empty before calling next on ShuffleIterator + val ucxWrappedStream = new Iterator[(BlockId, InputStream)] { + override def next(): (BlockId, InputStream) = { + val startTime = 
System.currentTimeMillis() + while (resultQueue.isEmpty) { + transport.progress() + } + readMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime) + wrappedStreams.next() + } + + override def hasNext: Boolean = { + val result = wrappedStreams.hasNext + if (!result) { + shuffleClient.close() + } + result + } + } + // End of ucx shuffle logic + + val serializerInstance = dep.serializer.newInstance() + + // Create a key/value iterator for each stream + val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } + + // Update the context task metrics for each record read. + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map { record => + readMetrics.incRecordsRead(1) + record + }, + context.taskMetrics().mergeShuffleReadMetrics()) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) + + val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { + if (dep.mapSideCombine) { + // We are reading values that are already combined + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) + } else { + // We don't know the value type, but also don't care -- the dependency *should* + // have made sure its compatible w/ this aggregator, which will convert the value + // type to the combined type C + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) + } + } else { + 
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] + } + + // Sort the output if there is a sort ordering defined. + val resultIter = dep.keyOrdering match { + case Some(keyOrd: Ordering[K]) => + // Create an ExternalSorter to sort the data. + val sorter = + new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer) + sorter.insertAll(aggregatedIter) + context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled) + context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled) + context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes) + // Use completion callback to stop sorter if task was finished/cancelled. + context.addTaskCompletionListener[Unit](_ => { + sorter.stop() + }) + CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop()) + case None => + aggregatedIter + } + + resultIter match { + case _: InterruptibleIterator[Product2[K, C]] => resultIter + case _ => + // Use another interruptible iterator here to support task cancellation as aggregator + // or(and) sorter may have consumed previous interruptible iterator. + new InterruptibleIterator[Product2[K, C]](context, resultIter) + } + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala index 10aab2f2..c4dff189 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala @@ -3,25 +3,28 @@ * See file LICENSE for terms. 
*/ package org.apache.spark.shuffle.ucx + +import java.net.InetSocketAddress import java.nio.ByteBuffer -import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} +import java.util.concurrent.locks.Lock +import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.concurrent.ExecutionContext.Implicits.global -import scala.concurrent.Future -import scala.util.{Failure, Success} import org.openucx.jucx.ucp._ -import org.openucx.jucx.{UcxCallback, UcxException} +import org.openucx.jucx.ucs.UcsConstants +import org.openucx.jucx.{UcxCallback, UcxException, UcxUtils} import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging -import org.apache.spark.shuffle.ucx.memory.{MemoryPool, UcxHostBounceBuffersPool} +import org.apache.spark.shuffle.ucx.memory.{UcxGpuBounceBuffersPool, UcxHostBounceBuffersPool} import org.apache.spark.shuffle.ucx.rpc.GlobalWorkerRpcThread +import org.apache.spark.shuffle.ucx.utils.{SerializationUtils, UcxHelperUtils} +import org.apache.spark.util.Utils /** - * Special type of [[ Block ]] interface backed by UcpMemory, - * it may not be actually pinned if used with [[ UcxShuffleConf.useOdp ]] flag. + * Special type of [[ Block ]] interface backed by UcpMemory. 
*/ case class UcxPinnedBlock(block: Block, ucpMemory: UcpMemory, prefetched: Boolean = false) extends Block { @@ -30,6 +33,7 @@ case class UcxPinnedBlock(block: Block, ucpMemory: UcpMemory, prefetched: Boolea class UcxStats extends OperationStats { private[ucx] val startTime = System.nanoTime() + private[ucx] var amHandleTime = 0L private[ucx] var endTime: Long = 0L private[ucx] var receiveSize: Long = 0L @@ -41,31 +45,37 @@ class UcxStats extends OperationStats { override def getElapsedTimeNs: Long = endTime - startTime /** - * Indicates number of valid bytes in receive memory when using - * [[ ShuffleTransport.fetchBlockByBlockId()]] + * Indicates number of valid bytes in receive memory */ override def recvSize: Long = receiveSize } -class UcxRequest(request: UcpRequest, stats: OperationStats) extends Request { +class UcxRequest(private var request: UcpRequest, stats: OperationStats, private val worker: UcpWorker) + extends Request { + + private[ucx] var completed = false - override def isCompleted: Boolean = request.isCompleted + override def isCompleted: Boolean = completed || ((request != null) && request.isCompleted) - override def cancel(): Unit = request.close() + override def cancel(): Unit = if (request != null) worker.cancelRequest(request) override def getStats: Option[OperationStats] = Some(stats) + + private[ucx] def setRequest(request: UcpRequest): Unit = { + this.request = request + } } /** * UCX implementation of [[ ShuffleTransport ]] API */ -class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executorId: String) +class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executorId: String) extends ShuffleTransport with Logging { // UCX entities - private var ucxContext: UcpContext = _ + private[ucx] var ucxContext: UcpContext = _ private var globalWorker: UcpWorker = _ - private val ucpWorkerParams = new UcpWorkerParams() + private val ucpWorkerParams = new UcpWorkerParams().requestThreadSafety() // TODO: 
reimplement as workerPool, since spark may create/destroy threads dynamically private var threadLocalWorker: ThreadLocal[UcxWorkerWrapper] = _ @@ -77,69 +87,118 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo // Mapping between executorId and it's address private[ucx] val executorIdToAddress = new ConcurrentHashMap[String, ByteBuffer]() - private[ucx] val clientConnections = mutable.Map.empty[String, UcpEndpoint] + private[ucx] val executorIdToSockAddress = new ConcurrentHashMap[String, InetSocketAddress]() + private[ucx] val clientConnections = mutable.HashMap.empty[UcpEndpoint, (UcpEndpoint, InetSocketAddress)] // Need host ucx bounce buffer memory pool to send fetchBlockByBlockId request - var memoryPool: MemoryPool = _ + var hostMemoryPool: UcxHostBounceBuffersPool = _ + var deviceMemoryPool: UcxGpuBounceBuffersPool = _ + + @volatile private var initialized: Boolean = false + + private var workerAddress: ByteBuffer = _ + private var listener: Option[UcpListener] = None /** * Initialize transport resources. This function should get called after ensuring that SparkConf * has the correct configurations since it will use the spark configuration to configure itself. 
*/ - override def init(): ByteBuffer = { - if (ucxShuffleConf == null) { - ucxShuffleConf = new UcxShuffleConf(SparkEnv.get.conf) - } + override def init(): ByteBuffer = this.synchronized { + if (!initialized) { + if (ucxShuffleConf == null) { + ucxShuffleConf = new UcxShuffleConf(SparkEnv.get.conf) + } - if (ucxShuffleConf.useOdp) { - memMapParams.nonBlocking() - } + if (ucxShuffleConf.useOdp) { + memMapParams.nonBlocking() + } - val params = new UcpParams().requestTagFeature().requestWakeupFeature() - if (ucxShuffleConf.protocol == "one-sided") { - params.requestRmaFeature() - } - ucxContext = new UcpContext(params) - globalWorker = ucxContext.newWorker(new UcpWorkerParams().requestWakeupTagRecv() - .requestWakeupTagSend()) - - val result = globalWorker.getAddress - require(result.capacity <= ucxShuffleConf.maxWorkerAddressSize, - s"${ucxShuffleConf.WORKER_ADDRESS_SIZE.key} < ${result.capacity}") - - memoryPool = new UcxHostBounceBuffersPool(ucxShuffleConf, ucxContext) - progressThread = new GlobalWorkerRpcThread(globalWorker, memoryPool, this) - - threadLocalWorker = ThreadLocal.withInitial(() => { - val localWorker = ucxContext.newWorker(ucpWorkerParams) - val workerWrapper = new UcxWorkerWrapper(localWorker, this, ucxShuffleConf, memoryPool) - allocatedWorkers.add(workerWrapper) - workerWrapper - }) + val params = new UcpParams().requestTagFeature().requestAmFeature() - progressThread.start() - result + if (ucxShuffleConf.useWakeup) { + params.requestWakeupFeature() + } + + if (ucxShuffleConf.protocol == ucxShuffleConf.PROTOCOL.ONE_SIDED) { + params.requestRmaFeature() + } + ucxContext = new UcpContext(params) + + val workerParams = new UcpWorkerParams() + + if (ucxShuffleConf.useWakeup) { + workerParams.requestWakeupTagRecv().requestWakeupTagSend() + } + globalWorker = ucxContext.newWorker(workerParams) + + workerAddress = if (ucxShuffleConf.useSockAddr) { + listener = Some(UcxHelperUtils.startListenerOnRandomPort(globalWorker, ucxShuffleConf.conf, 
clientConnections)) + val buffer = SerializationUtils.serializeInetAddress(listener.get.getAddress) + buffer + } else { + val workerAddress = globalWorker.getAddress + require(workerAddress.capacity <= ucxShuffleConf.maxWorkerAddressSize, + s"${ucxShuffleConf.WORKER_ADDRESS_SIZE.key} < ${workerAddress.capacity}") + workerAddress + } + + hostMemoryPool = new UcxHostBounceBuffersPool(ucxShuffleConf, ucxContext) + deviceMemoryPool = new UcxGpuBounceBuffersPool(ucxShuffleConf, ucxContext) + progressThread = new GlobalWorkerRpcThread(globalWorker, this) + + val numThreads = ucxShuffleConf.getInt("spark.executor.cores", 4) + val allocateWorker = () => { + val localWorker = ucxContext.newWorker(ucpWorkerParams) + val workerWrapper = new UcxWorkerWrapper(localWorker, this, ucxShuffleConf, hostMemoryPool) + allocatedWorkers.add(workerWrapper) + logInfo(s"Pre-connections on a new worker to ${executorIdToAddress.size()} " + + s"executors") + executorIdToAddress.keys.asScala.foreach(e => { + workerWrapper.getConnection(e) + }) + workerWrapper + } + val preAllocatedWorkers = (0 until numThreads).map(_ => allocateWorker()).toList.asJava + val workersPool = new ConcurrentLinkedQueue[UcxWorkerWrapper](preAllocatedWorkers) + threadLocalWorker = ThreadLocal.withInitial(() => { + if (workersPool.isEmpty) { + logWarning(s"Allocating new worker") + allocateWorker() + } else { + workersPool.poll() + } + }) + + progressThread.start() + initialized = true + } + workerAddress } /** * Close all transport resources */ override def close(): Unit = { - progressThread.interrupt() - globalWorker.signal() - try { - progressThread.join() - } catch { - case _:InterruptedException => - case e:Throwable => logWarning(e.getLocalizedMessage) + if (initialized) { + progressThread.interrupt() + if (ucxShuffleConf.useWakeup) { + globalWorker.signal() + } + try { + progressThread.join() + hostMemoryPool.close() + deviceMemoryPool.close() + clientConnections.keys.foreach(ep => ep.close()) + 
registeredBlocks.forEachKey(100, blockId => unregister(blockId)) + allocatedWorkers.forEach(_.close()) + listener.foreach(_.close()) + globalWorker.close() + ucxContext.close() + } catch { + case _:InterruptedException => + case e:Throwable => logWarning(e.getLocalizedMessage) + } } - - memoryPool.close() - clientConnections.values.foreach(ep => ep.close()) - registeredBlocks.forEachKey(1, blockId => unregister(blockId)) - allocatedWorkers.forEach(_.close()) - globalWorker.close() - ucxContext.close() } /** @@ -147,73 +206,158 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo * connection establishment outside of UcxShuffleManager. */ def addExecutor(executorId: String, workerAddress: ByteBuffer): Unit = { - executorIdToAddress.put(executorId, workerAddress) - } - - private[ucx] def handlePrefetchRequest(workerId: String, workerAddress: ByteBuffer, - blockIds: Seq[BlockId]) { - - logInfo(s"Prefetching blocks: ${blockIds.mkString(",")}") - clientConnections.getOrElseUpdate(workerId, - globalWorker.newEndpoint(new UcpEndpointParams().setUcpAddress(workerAddress)) - ) - - blockIds.foreach(blockId => { - val block = registeredBlocks.get(blockId) - if (!block.isInstanceOf[UcxPinnedBlock]) { - registeredBlocks.put(blockId, UcxPinnedBlock(block, pinMemory(block), prefetched = true)) - } + if (ucxShuffleConf.useSockAddr) { + executorIdToSockAddress.put(executorId, SerializationUtils.deserializeInetAddress(workerAddress)) + } else { + executorIdToAddress.put(executorId, workerAddress) + } + logInfo(s"Pre connection on ${allocatedWorkers.size()} workers") + allocatedWorkers.forEach(w => { + val ep = w.getConnection(executorId) + w.preConnect(ep) }) } /** * On a sender side process request of fetchBlockByBlockId */ - private[ucx] def replyFetchBlockRequest(workerId: String, workerAddress: ByteBuffer, - blockId: BlockId, tag: Long): Unit = { + private[ucx] def replyFetchBlocksRequest(blockIds: Seq[BlockId], isRecvToHostMemory: Seq[Boolean], + 
ep: UcpEndpoint, startTag: Int, singleReply: Boolean): Unit = { + if (singleReply) { + var blockAddress = 0L + var blockSize = 0L + var responseBlock: MemoryBlock = null + var headerMem: MemoryBlock = null + var responseBlockBuff: ByteBuffer = null + + val totalSize = ucxShuffleConf.maxMetadataSize * blockIds.length + if (totalSize + 4 < globalWorker.getMaxAmHeaderSize) { + headerMem = hostMemoryPool.get(totalSize + 4L) + responseBlockBuff = UcxUtils.getByteBufferView(headerMem.address, headerMem.size) + responseBlockBuff.putInt(startTag) + logInfo(s"Sending ${blockIds.mkString(",")} in header") + } else { + headerMem = hostMemoryPool.get(4L) + val headerBuf = UcxUtils.getByteBufferView(headerMem.address, headerMem.size) + headerBuf.putInt(startTag) + responseBlock = hostMemoryPool.get(totalSize) + responseBlockBuff = UcxUtils.getByteBufferView(responseBlock.address, responseBlock.size) + blockAddress = responseBlock.address + blockSize = responseBlock.size + logInfo(s"Sending ${blockIds.length} blocks in data") + } + val locks = new Array[Lock](blockIds.size) + for (i <- blockIds.indices) { + val block = registeredBlocks.get(blockIds(i)) + locks(i) = block.lock.readLock() + locks(i).lock() + val blockMemory = block.getMemoryBlock + require(blockMemory.isHostMemory && blockMemory.size <= ucxShuffleConf.maxMetadataSize) + val blockBuffer = UcxUtils.getByteBufferView(blockMemory.address, blockMemory.size) + responseBlockBuff.putInt(blockMemory.size.toInt) + responseBlockBuff.put(blockBuffer) + } + ep.sendAmNonBlocking(0, headerMem.address, headerMem.size, blockAddress, blockSize, 0L, + new UcxCallback { + private val startTime = System.nanoTime() + override def onSuccess(request: UcpRequest): Unit = { + logInfo(s"Sent ${blockIds.length} blocks of size $totalSize" + + s" in ${Utils.getUsedTimeNs(startTime)}") + locks.foreach(_.unlock()) + hostMemoryPool.put(headerMem) + if (responseBlock != null) { + hostMemoryPool.put(responseBlock) + } + } - val ep = 
clientConnections.getOrElseUpdate(workerId, - globalWorker.newEndpoint(new UcpEndpointParams().setUcpAddress(workerAddress)) - ) + override def onError(ucsStatus: Int, errorMsg: String): Unit = { + logError(s"Failed to send ${blockIds.mkString(",")}: $errorMsg") + locks.foreach(_.unlock()) + hostMemoryPool.put(headerMem) + if (responseBlock != null) { + hostMemoryPool.put(responseBlock) + } + } + }, UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) - val block = registeredBlocks.get(blockId) - if (block == null) { - throw new UcxException(s"Block $blockId not registered") - } - val lock = block.lock.readLock() - lock.lock() - val blockMemory = block.getMemoryBlock + } else { + for (i <- blockIds.indices) { + val blockId = blockIds(i) + val block = registeredBlocks.get(blockId) - logInfo(s"Sending $blockId of size ${blockMemory.size} to tag: $tag") - ep.sendTaggedNonBlocking(blockMemory.address, blockMemory.size, tag, new UcxCallback { - override def onSuccess(request: UcpRequest): Unit = { - if (block.isInstanceOf[UcxPinnedBlock]) { - val pinnedBlock = block.asInstanceOf[UcxPinnedBlock] - if (pinnedBlock.prefetched) { - registeredBlocks.put(blockId, pinnedBlock.block) - pinnedBlock.ucpMemory.deregister() - } + if (block == null) { + throw new UcxException(s"Block $blockId is not registered") + } + + val lock = block.lock.readLock() + lock.lock() + + val blockMemory = block.getMemoryBlock + + var blockAddress = 0L + var blockSize = 0L + var headerMem: MemoryBlock = null + + if (blockMemory.isHostMemory && // The block itself in host memory + (blockMemory.size + 4 < globalWorker.getMaxAmHeaderSize) && // and can fit to header + isRecvToHostMemory(i) ) { + headerMem = hostMemoryPool.get(4L + blockMemory.size) + val buf = UcxUtils.getByteBufferView(headerMem.address, headerMem.size) + buf.putInt(startTag + i) + val blockBuf = UcxUtils.getByteBufferView(blockMemory.address, blockMemory.size) + buf.put(blockBuf) + } else { + headerMem = hostMemoryPool.get(4L) + val buf = 
UcxUtils.getByteBufferView(headerMem.address, headerMem.size) + buf.putInt(startTag + i) + blockAddress = blockMemory.address + blockSize = blockMemory.size } - lock.unlock() - } - override def onError(ucsStatus: Int, errorMsg: String): Unit = { - logError(s"Failed to send $blockId: $errorMsg") - lock.unlock() + ep.sendAmNonBlocking(0, headerMem.address, headerMem.size, blockAddress, blockSize, 0L, + new UcxCallback { + private val startTime = System.nanoTime() + override def onSuccess(request: UcpRequest): Unit = { + logInfo(s"Sent $blockId of size ${blockMemory.size} " + + s"memType: ${if (blockMemory.isHostMemory) "host" else "gpu"}" + + s" in ${Utils.getUsedTimeNs(startTime)}") + lock.unlock() + hostMemoryPool.put(headerMem) + } + + override def onError(ucsStatus: Int, errorMsg: String): Unit = { + logError(s"Failed to send $blockId-$blockMemory: $errorMsg") + lock.unlock() + hostMemoryPool.put(headerMem) + } + }, + if (blockMemory.isHostMemory) UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST + else UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_CUDA) } - }) + } + } private def pinMemory(block: Block): UcpMemory = { + val startTime = System.nanoTime() val blockMemory = block.getMemoryBlock - ucxContext.memoryMap( - memMapParams.setAddress(blockMemory.address).setLength(blockMemory.size)) + val memType = if (blockMemory.isHostMemory) { + UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST + } else { + UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_CUDA + } + val result = ucxContext.memoryMap( + memMapParams.setAddress(blockMemory.address).setLength(blockMemory.size) + .setMemoryType(memType)) + logInfo(s"Pinning memory of size: ${blockMemory.size} took: ${Utils.getUsedTimeNs(startTime)}") + result } /** * Registers blocks using blockId on SERVER side. 
*/ override def register(blockId: BlockId, block: Block): Unit = { + logTrace(s"Registering $blockId of size: ${block.getMemoryBlock.size}") val registeredBock: Block = if (ucxShuffleConf.pinMemory) { UcxPinnedBlock(block, pinMemory(block)) } else { @@ -226,23 +370,16 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo * Change location of underlying blockId in memory */ override def mutate(blockId: BlockId, block: Block, callback: OperationCallback): Unit = { - Future { - unregister(blockId) - register(blockId, block) - } andThen { - case Failure(t) => if (callback != null) { - callback.onComplete(new UcxFailureOperationResult(t.getMessage)) - } - case Success(_) => if (callback != null) { - callback.onComplete(new UcxSuccessOperationResult(new UcxStats)) - } - } + unregister(blockId) + register(blockId, block) + callback.onComplete(new UcxSuccessOperationResult(new UcxStats)) } /** * Indicate that this blockId is not needed any more by an application */ override def unregister(blockId: BlockId): Unit = { + logInfo(s"Unregistering $blockId") val block = registeredBlocks.remove(blockId) if (block != null) { block.lock.writeLock().lock() @@ -256,15 +393,6 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo } } - /** - * Fetch remote blocks by blockIds. - */ - override def fetchBlockByBlockId(executorId: String, blockId: BlockId, - resultBuffer: MemoryBlock, - cb: OperationCallback): UcxRequest = { - threadLocalWorker.get().fetchBlockByBlockId(executorId, blockId, resultBuffer, cb) - } - /** * Progress outstanding operations. This routine is blocking. It's important to call this routine * within same thread that submitted requests. @@ -278,13 +406,6 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo executorIdToAddress.remove(executorId) } - /** - * Hint for a transport that these blocks would needed soon. 
- */ - override def prefetchBlocks(executorId: String, blockIds: Seq[BlockId]): Unit = { - threadLocalWorker.get().prefetchBlocks(executorId, blockIds) - } - /** * Batch version of [[ fetchBlocksByBlockIds ]]. */ @@ -293,4 +414,9 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo callbacks: Seq[OperationCallback]): Seq[Request] = { threadLocalWorker.get().fetchBlocksByBlockIds(executorId, blockIds, resultBuffer, callbacks) } + + override def fetchBlocksByBlockIds(executorId: String, blockIds: Seq[BlockId], resultBuffer: MemoryBlock, + callback: OperationCallback): Request = { + threadLocalWorker.get().fetchBlocksByBlockIds(executorId, blockIds, resultBuffer, callback) + } } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala index 2b55eb06..e22439bd 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala @@ -5,19 +5,20 @@ package org.apache.spark.shuffle.ucx import java.io.{Closeable, ObjectOutputStream} -import java.util.concurrent.ThreadLocalRandom +import java.net.InetSocketAddress +import java.nio.ByteBuffer +import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable -import com.fasterxml.jackson.databind.util.ByteBufferBackedOutputStream -import org.openucx.jucx.ucp.{UcpEndpoint, UcpEndpointParams, UcpRequest, UcpWorker} +import ai.rapids.cudf.DeviceMemoryBuffer +import org.openucx.jucx.ucp._ +import org.openucx.jucx.ucs.UcsConstants import org.openucx.jucx.{UcxCallback, UcxException, UcxUtils} import org.apache.spark.internal.Logging import org.apache.spark.shuffle.ucx.memory.MemoryPool -import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages -import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{FetchBlockByBlockIdRequest, FetchBlocksByBlockIdsRequest, PrefetchBlockIds} -import 
org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer -import org.apache.spark.util.Utils +import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.FetchBlocksByBlockIdsRequest +import org.apache.spark.util.{ByteBufferOutputStream, Utils} /** * Success operation result subclass that has operation stats. @@ -46,14 +47,53 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon memoryPool: MemoryPool) extends Closeable with Logging { - // To keep connection map on a remote side by id rather then by worker address, which could be big. - // Would not need when migrate to active messages. - private val id: String = transport.executorId + s"_${Thread.currentThread().getId}" + private var tag: Int = 1 private final val connections = mutable.Map.empty[String, UcpEndpoint] - private val workerAddress = worker.getAddress + private val requestData = new ConcurrentHashMap[Int, (MemoryBlock, UcxCallback, UcxRequest)] + + worker.setAmRecvHandler(0, (headerAddress: Long, headerSize: Long, amData: UcpAmData, + replyEp: UcpEndpoint) => { + val headerBuffer = UcxUtils.getByteBufferView(headerAddress, headerSize) + val i = headerBuffer.getInt + val data = requestData.remove(i) + if (data == null) { + throw new UcxException(s"No data for tag $i") + } + val (resultMemory, callback, ucxRequest) = data + logDebug(s"Received message for tag $i with headerSize $headerSize. 
" + + s" AmData: ${amData} resultMemory(isHost): ${resultMemory.isHostMemory}") + ucxRequest.getStats.foreach(s => s.asInstanceOf[UcxStats].amHandleTime = System.nanoTime()) + if (headerSize > 4L) { + val resultBuffer = UcxUtils.getByteBufferView(resultMemory.address, resultMemory.size) + resultBuffer.put(headerBuffer) + ucxRequest.completed = true + if (callback != null) { + ucxRequest.getStats.get.asInstanceOf[UcxStats].receiveSize = headerSize - 4 + callback.onSuccess(null) + } + } else if (amData.isDataValid && resultMemory.isHostMemory) { + val resultBuffer = UcxUtils.getByteBufferView(resultMemory.address, resultMemory.size) + resultBuffer.put(UcxUtils.getByteBufferView(amData.getDataAddress, amData.getLength)) + ucxRequest.completed = true + if (callback != null) { + ucxRequest.getStats.get.asInstanceOf[UcxStats].receiveSize = amData.getLength + callback.onSuccess(null) + } + } else { + require(amData.getLength <= resultMemory.size, s"${amData.getLength} < ${resultMemory.size}") + val request = worker.recvAmDataNonBlocking(amData.getDataHandle, resultMemory.address, + amData.getLength, callback, + if (resultMemory.isHostMemory) UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST else + UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_CUDA) + ucxRequest.getStats.get.asInstanceOf[UcxStats].receiveSize = amData.getLength + ucxRequest.setRequest(request) + } + UcsConstants.STATUS.UCS_OK + }) + override def close(): Unit = { - connections.foreach{ + connections.foreach { case (_, endpoint) => endpoint.close() } connections.clear() @@ -63,23 +103,28 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon /** * The only place for worker progress */ - private[ucx] def progress(): Unit = { + private[ucx] def progress(): Unit = this.synchronized { if ((worker.progress() == 0) && ucxConf.useWakeup) { worker.waitForEvents() worker.progress() } } - private def getConnection(executorId: String): UcpEndpoint = { - val workerAdresses = 
transport.executorIdToAddress + private[ucx] def getConnection(executorId: String): UcpEndpoint = synchronized { + val workerAddresses = if (ucxConf.useSockAddr) { + transport.executorIdToSockAddress + } else { + transport.executorIdToAddress + } - if (!workerAdresses.contains(executorId)) { + if (!workerAddresses.contains(executorId)) { // Block until there's no worker address for this BlockManagerID val startTime = System.currentTimeMillis() val timeout = ucxConf.conf.getTimeAsMs("spark.network.timeout", "100") - workerAdresses.synchronized { - while (workerAdresses.get(executorId) == null) { - workerAdresses.wait(timeout) + workerAddresses.synchronized { + while (workerAddresses.get(executorId) == null) { + logWarning(s"No workerAddress for executor $executorId") + workerAddresses.wait(timeout) if (System.currentTimeMillis() - startTime > timeout) { throw new UcxException(s"Didn't get worker address for $executorId during $timeout") } @@ -89,146 +134,200 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon connections.getOrElseUpdate(executorId, { logInfo(s"Worker from thread ${Thread.currentThread().getName} connecting to $executorId") - val endpointParams = new UcpEndpointParams() - .setUcpAddress(workerAdresses.get(executorId)) - worker.newEndpoint(endpointParams) + val endpointParams = new UcpEndpointParams().setPeerErrorHandlingMode() + .setErrorHandler(new UcpEndpointErrorHandler() { + override def onError(ep: UcpEndpoint, status: Int, errorMsg: String): Unit = { + logError(errorMsg) + } + }) + if (ucxConf.useSockAddr) { + val sockAddr = workerAddresses.get(executorId).asInstanceOf[InetSocketAddress] + endpointParams.setSocketAddress(sockAddr) + } else { + endpointParams.setUcpAddress(workerAddresses.get(executorId).asInstanceOf[ByteBuffer]) + } + + worker.newEndpoint(endpointParams) }) } - private[ucx] def prefetchBlocks(executorId: String, blockIds: Seq[BlockId]): Unit = { - logInfo(s"Sending prefetch ${blockIds.length} 
blocks to $executorId") - - val mem = memoryPool.get(transport.ucxShuffleConf.rpcMessageSize) - val buffer = UcxUtils.getByteBufferView(mem.address, transport.ucxShuffleConf.rpcMessageSize.toInt) - - workerAddress.rewind() - val message = PrefetchBlockIds(id, new SerializableDirectBuffer(workerAddress), blockIds) - Utils.tryWithResource(new ByteBufferBackedOutputStream(buffer)) { bos => - val out = new ObjectOutputStream(bos) - out.writeObject(message) - out.flush() - out.close() + def preConnect(ep: UcpEndpoint): Unit = this.synchronized { + val hostBuffer = memoryPool.get(4) + val gpuBuffer = DeviceMemoryBuffer.allocate(1) + val req = ep.sendAmNonBlocking(-1, hostBuffer.address, hostBuffer.size, + gpuBuffer.getAddress, gpuBuffer.getLength, UcpConstants.UCP_AM_SEND_FLAG_REPLY, null) + while (!req.isCompleted) { + progress() } - - val ep = getConnection(executorId) - - ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize, - UcxRpcMessages.PREFETCH_TAG, new UcxCallback() { - override def onSuccess(request: UcpRequest): Unit = { - logTrace(s"Sent prefetch ${blockIds.length} blocks to $executorId") - memoryPool.put(mem) - } - }) + gpuBuffer.close() + memoryPool.put(hostBuffer) } private[ucx] def fetchBlocksByBlockIds(executorId: String, blockIds: Seq[BlockId], - resultBuffer: Seq[MemoryBlock], - callbacks: Seq[OperationCallback]): Seq[Request] = { + resultBuffers: Seq[MemoryBlock], + callbacks: Seq[OperationCallback]): Seq[Request] = this.synchronized { + + val startTime = System.nanoTime() val ep = getConnection(executorId) - val mem = memoryPool.get(transport.ucxShuffleConf.rpcMessageSize) - val buffer = UcxUtils.getByteBufferView(mem.address, - transport.ucxShuffleConf.rpcMessageSize.toInt) - workerAddress.rewind() - val message = FetchBlocksByBlockIdsRequest(id, new SerializableDirectBuffer(workerAddress), - blockIds) + val message = FetchBlocksByBlockIdsRequest(tag, blockIds, resultBuffers.map(_.isHostMemory).toArray) - 
Utils.tryWithResource(new ByteBufferBackedOutputStream(buffer)) { bos => - val out = new ObjectOutputStream(bos) + val bos = new ByteBufferOutputStream(1000) + val out = new ObjectOutputStream(bos) + try { out.writeObject(message) out.flush() out.close() + } catch { + case ex: Exception => throw new UcxException(ex.getMessage) } - val tag = ThreadLocalRandom.current().nextLong(Long.MinValue, 0) - logInfo(s"Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag") - ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize, tag, - new UcxCallback() { - override def onSuccess(request: UcpRequest): Unit = { - memoryPool.put(mem) - } - }) + val msgSize = bos.getCount() + val mem = memoryPool.get(msgSize) + val buffer = UcxUtils.getByteBufferView(mem.address, msgSize) + + buffer.put(bos.toByteBuffer) - val requests = new Array[UcxRequest](blockIds.length) + val requests = new Array[UcxRequest](blockIds.size) for (i <- blockIds.indices) { val stats = new UcxStats() val result = new UcxSuccessOperationResult(stats) - val request = worker.recvTaggedNonBlocking(resultBuffer(i).address, resultBuffer(i).size, - tag + i, -1L, new UcxCallback () { - - override def onError(ucsStatus: Int, errorMsg: String): Unit = { - logError(s"Failed to receive blockId ${blockIds(i)} on tag: $tag, from executorId: $executorId " + - s" of size: ${resultBuffer.size}: $errorMsg") - if (callbacks(i) != null ) { - callbacks(i).onComplete(new UcxFailureOperationResult(errorMsg)) - } + requests(i) = new UcxRequest(null, stats, worker) + val callback = if (callbacks.isEmpty) null else new UcxCallback() { + + override def onError(ucsStatus: Int, errorMsg: String): Unit = { + logError(s"Failed to receive blockId ${blockIds(i)} on tag: $tag, from executorId: $executorId " + + s" of size: ${resultBuffers(i).size}: $errorMsg") + if (callbacks(i) != null) { + callbacks(i).onComplete(new UcxFailureOperationResult(errorMsg)) } + } - override def onSuccess(request: 
UcpRequest): Unit = { - stats.endTime = System.nanoTime() - stats.receiveSize = request.getRecvSize - if (callbacks(i) != null) { - callbacks(i).onComplete(result) - } + override def onSuccess(request: UcpRequest): Unit = { + stats.endTime = System.nanoTime() + logInfo(s"Received block ${blockIds(i)} from $executorId " + + s"of size: ${stats.receiveSize} in ${Utils.getUsedTimeNs(stats.startTime)}. Time from amHanlde to recv: " + + s"${Utils.getUsedTimeNs(stats.amHandleTime)}") + if (callbacks(i) != null) { + callbacks(i).onComplete(result) } - }) - requests(i) = new UcxRequest(request, stats) + } + } + requestData.put(tag + i, (resultBuffers(i), callback, requests(i))) + } + + logDebug(s"Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag") + var headerAddress = 0L + var headerSize = 0L + var dataAddress = 0L + var dataSize = 0L + + if (msgSize <= worker.getMaxAmHeaderSize) { + headerAddress = mem.address + headerSize = msgSize + } else { + dataAddress = mem.address + dataSize = msgSize + } + worker.progressRequest(ep.sendAmNonBlocking(0, headerAddress, headerSize, dataAddress, dataSize, + UcpConstants.UCP_AM_SEND_FLAG_REPLY, new UcxCallback() { + override def onSuccess(request: UcpRequest): Unit = { + logInfo(s"Sending RPC message of size: ${headerSize + dataSize} " + + s"to $executorId to fetch ${blockIds.length} blocks on starting tag tag $tag took: " + + s"${Utils.getUsedTimeNs(startTime)}") + memoryPool.put(mem) + } + })) + + for (i <- blockIds.indices) { + while (requestData.contains(tag + i)) { + progress() + } } + + logInfo(s"FetchBlocksByBlockIds data took: ${Utils.getUsedTimeNs(startTime)}") + tag += blockIds.length requests } - private[ucx] def fetchBlockByBlockId(executorId: String, blockId: BlockId, - resultBuffer: MemoryBlock, cb: OperationCallback): UcxRequest = { + private[ucx] def fetchBlocksByBlockIds(executorId: String, blockIds: Seq[BlockId], + resultBuffer: MemoryBlock, + callback: OperationCallback): Request = 
this.synchronized { val stats = new UcxStats() val ep = getConnection(executorId) - val mem = memoryPool.get(transport.ucxShuffleConf.rpcMessageSize) - val buffer = UcxUtils.getByteBufferView(mem.address, - transport.ucxShuffleConf.rpcMessageSize.toInt) - val tag = ThreadLocalRandom.current().nextLong(2, Long.MaxValue) - workerAddress.rewind() - val message = FetchBlockByBlockIdRequest(id, new SerializableDirectBuffer(workerAddress), blockId) + val message = FetchBlocksByBlockIdsRequest(tag, blockIds, Seq.empty[Boolean], singleReply = true) - Utils.tryWithResource(new ByteBufferBackedOutputStream(buffer)) { bos => - val out = new ObjectOutputStream(bos) + val bos = new ByteBufferOutputStream(1000) + val out = new ObjectOutputStream(bos) + try { out.writeObject(message) out.flush() out.close() + } catch { + case ex: Exception => throw new UcxException(ex.getMessage) } + val msgSize = bos.getCount() - logTrace(s"Sending message to $executorId to fetch $blockId on tag $tag," + - s"resultBuffer $resultBuffer") - ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize, tag, - new UcxCallback() { - override def onSuccess(request: UcpRequest): Unit = { - memoryPool.put(mem) - } - }) + val mem = memoryPool.get(msgSize) + val buffer = UcxUtils.getByteBufferView(mem.address, msgSize) + + buffer.put(bos.toByteBuffer) + val request = new UcxRequest(null, stats, worker) val result = new UcxSuccessOperationResult(stats) - val request = worker.recvTaggedNonBlocking(resultBuffer.address, resultBuffer.size, - tag, -1L, new UcxCallback () { + val requestCallback = new UcxCallback() { override def onError(ucsStatus: Int, errorMsg: String): Unit = { - logError(s"Failed to receive blockId $blockId on tag: $tag, from executorId: $executorId " + + logError(s"Failed to receive blockId ${blockIds.mkString(",")} on tag: $tag, from executorId: $executorId " + s" of size: ${resultBuffer.size}: $errorMsg") - if (cb != null ) { - cb.onComplete(new 
UcxFailureOperationResult(errorMsg)) + if (callback != null) { + callback.onComplete(new UcxFailureOperationResult(errorMsg)) } } override def onSuccess(request: UcpRequest): Unit = { stats.endTime = System.nanoTime() - stats.receiveSize = request.getRecvSize - if (cb != null) { - cb.onComplete(result) + logInfo(s"Received ${blockIds.length} metadata blocks from $executorId " + + s"of size: ${stats.receiveSize} in ${Utils.getUsedTimeNs(stats.startTime)}") + if (callback != null) { + callback.onComplete(result) } } - }) - new UcxRequest(request, stats) + + } + requestData.put(tag, (resultBuffer, requestCallback, request)) + + logDebug(s"Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag") + var headerAddress = 0L + var headerSize = 0L + var dataAddress = 0L + var dataSize = 0L + + if (msgSize <= worker.getMaxAmHeaderSize) { + headerAddress = mem.address + headerSize = msgSize + } else { + dataAddress = mem.address + dataSize = msgSize + } + worker.progressRequest(ep.sendAmNonBlocking(0, headerAddress, headerSize, dataAddress, dataSize, + UcpConstants.UCP_AM_SEND_FLAG_REPLY, new UcxCallback() { + override def onSuccess(request: UcpRequest): Unit = { + memoryPool.put(mem) + } + })) + + while (requestData.contains(tag)) { + progress() + } + + logInfo(s"FetchBlocksByBlockIds metadata took: ${Utils.getUsedTimeNs(stats.startTime)}") + tag += 1 + request } } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleExecutorComponents.scala b/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleExecutorComponents.scala new file mode 100755 index 00000000..b8e1f6ee --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleExecutorComponents.scala @@ -0,0 +1,47 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. +* See file LICENSE for terms. 
+*/ +package org.apache.spark.shuffle.ucx.io + +import java.util +import java.util.Optional + +import org.apache.spark.internal.Logging +import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, SingleSpillShuffleMapOutputWriter} +import org.apache.spark.shuffle.sort.io.{LocalDiskShuffleExecutorComponents, LocalDiskShuffleMapOutputWriter, LocalDiskSingleSpillMapOutputWriter} +import org.apache.spark.shuffle.ucx.{UcxShuffleBlockResolver, UcxShuffleManager, UcxShuffleTransport} +import org.apache.spark.{SparkConf, SparkEnv} + + +class UcxShuffleExecutorComponents(sparkConf: SparkConf) + extends LocalDiskShuffleExecutorComponents(sparkConf) with Logging { + + var ucxShuffleTransport: UcxShuffleTransport = _ + private var blockResolver: UcxShuffleBlockResolver = _ + + override def initializeExecutor(appId: String, execId: String, extraConfigs: util.Map[String, String]): Unit = { + val ucxShuffleManager = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager] + ucxShuffleManager.initTransport() + blockResolver = ucxShuffleManager.shuffleBlockResolver + } + + override def createMapOutputWriter(shuffleId: Int, mapTaskId: Long, numPartitions: Int): ShuffleMapOutputWriter = { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting writers.") + } + new LocalDiskShuffleMapOutputWriter( + shuffleId, mapTaskId, numPartitions, blockResolver, sparkConf) + } + + override def createSingleFileMapOutputWriter(shuffleId: Int, mapId: Long): + Optional[SingleSpillShuffleMapOutputWriter] = { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting writers.") + } + Optional.of(new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver)) + } + +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleIO.scala b/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleIO.scala new file mode 100755 index 00000000..04f03b44 
--- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleIO.scala @@ -0,0 +1,26 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. +* See file LICENSE for terms. +*/ +package org.apache.spark.shuffle.ucx.io + +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD +import org.apache.spark.shuffle.api.{ShuffleDriverComponents, ShuffleExecutorComponents} +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO +import org.apache.spark.shuffle.ucx.UcxShuffleManager + +class UcxShuffleIO(sparkConf: SparkConf) extends LocalDiskShuffleDataIO(sparkConf) with Logging { + + sparkConf.set(SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key, "100000") + + override def driver(): ShuffleDriverComponents = { + SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager].initTransport() + super.driver() + } + + override def executor(): ShuffleExecutorComponents = { + new UcxShuffleExecutorComponents(sparkConf) + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/memory/MemoryPool.scala b/src/main/scala/org/apache/spark/shuffle/ucx/memory/MemoryPool.scala index 72c18fc7..65da481f 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/memory/MemoryPool.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/memory/MemoryPool.scala @@ -5,14 +5,132 @@ package org.apache.spark.shuffle.ucx.memory import java.io.Closeable +import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedDeque} + +import org.openucx.jucx.ucp.{UcpContext, UcpMemMapParams, UcpMemory} +import org.openucx.jucx.ucs.UcsConstants +import org.apache.spark.internal.Logging +import org.apache.spark.shuffle.ucx.{MemoryBlock, UcxShuffleConf} +import org.apache.spark.util.Utils + +class UcxBounceBufferMemoryBlock(private[ucx] val memory: UcpMemory, private[ucx] val refCount: AtomicInteger, + override val address: 
Long, override val size: Long) + extends MemoryBlock(address, size, memory.getMemType == UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) -import org.apache.spark.shuffle.ucx.MemoryBlock /** * Base class to implement memory pool */ -abstract class MemoryPool extends Closeable { - def get(size: Long): MemoryBlock +case class MemoryPool(ucxShuffleConf: UcxShuffleConf, ucxContext: UcpContext, memoryType: Int) + extends Closeable with Logging { + + protected def roundUpToTheNextPowerOf2(size: Long): Long = { + // Round up length to the nearest power of two + var length = size + length -= 1 + length |= length >> 1 + length |= length >> 2 + length |= length >> 4 + length |= length >> 8 + length |= length >> 16 + length += 1 + length + } + + protected val allocatorMap = new ConcurrentHashMap[Long, AllocatorStack]() + + protected case class AllocatorStack(length: Long, memType: Int) extends Closeable { + logInfo(s"Allocator stack of memType: $memType and size $length") + private val stack = new ConcurrentLinkedDeque[UcxBounceBufferMemoryBlock] + private val numAllocs = new AtomicInteger(0) + private val memMapParams = new UcpMemMapParams().allocate().setMemoryType(memType).setLength(length) + + private[memory] def get: UcxBounceBufferMemoryBlock = { + var result = stack.pollFirst() + if (result == null) { + numAllocs.incrementAndGet() + if (length < ucxShuffleConf.minRegistrationSize) { + preallocate((ucxShuffleConf.minRegistrationSize / length).toInt) + result = stack.pollFirst() + } else { + logInfo(s"Allocating buffer of size $length") + val memory = ucxContext.memoryMap(memMapParams) + result = new UcxBounceBufferMemoryBlock(memory, new AtomicInteger(1), + memory.getAddress, length) + } + } + result + } + + private[memory] def put(block: UcxBounceBufferMemoryBlock): Unit = { + stack.add(block) + } + + private[memory] def preallocate(numBuffers: Int): Unit = { + logInfo(s"PreAllocating $numBuffers of size $length, " + + s"totalSize: ${Utils.bytesToString(length * 
numBuffers) }") + val memory = ucxContext.memoryMap( + new UcpMemMapParams().allocate().setMemoryType(memType).setLength(length * numBuffers)) + val refCount = new AtomicInteger(numBuffers) + var offset = 0L + (0 until numBuffers).foreach(_ => { + stack.add(new UcxBounceBufferMemoryBlock(memory, refCount, memory.getAddress + offset, length)) + offset += length + }) + } - def put(mem: MemoryBlock): Unit + override def close(): Unit = { + var numBuffers = 0 + stack.forEach(block => { + block.refCount.decrementAndGet() + if (block.memory.getNativeId != null) { + block.memory.deregister() + } + numBuffers += 1 + }) + logInfo(s"Closing $numBuffers buffers of size $length." + + s"Number of allocations: ${numAllocs.get()}") + stack.clear() + } + } + + override def close(): Unit = { + allocatorMap.values.forEach(allocator => allocator.close()) + allocatorMap.clear() + } + + def get(size: Long): MemoryBlock = { + val roundedSize = roundUpToTheNextPowerOf2(size) + val allocatorStack = allocatorMap.computeIfAbsent(roundedSize, + s => AllocatorStack(s, memoryType)) + val result = allocatorStack.get + new UcxBounceBufferMemoryBlock(result.memory, result.refCount, result.address, size) + } + + def put(mem: MemoryBlock): Unit = { + mem match { + case m: UcxBounceBufferMemoryBlock => + val allocatorStack = allocatorMap.get(roundUpToTheNextPowerOf2(mem.size)) + allocatorStack.put(m) + case _ => + } + } + + def preAllocate(size: Long, numBuffers: Int): Unit = { + val roundedSize = roundUpToTheNextPowerOf2(size) + val allocatorStack = allocatorMap.computeIfAbsent(roundedSize, + s => AllocatorStack(s, memoryType)) + allocatorStack.preallocate(numBuffers) + } } + +class UcxHostBounceBuffersPool(ucxShuffleConf: UcxShuffleConf, ucxContext: UcpContext) + extends MemoryPool(ucxShuffleConf, ucxContext, UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) { + ucxShuffleConf.preallocateBuffersMap.foreach{ + case (size, numBuffers) => preAllocate(size, numBuffers) + } +} + +class 
UcxGpuBounceBuffersPool(ucxShuffleConf: UcxShuffleConf, ucxContext: UcpContext) + extends MemoryPool(ucxShuffleConf, ucxContext, UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_CUDA) diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/memory/UcxHostBounceBuffersPool.scala b/src/main/scala/org/apache/spark/shuffle/ucx/memory/UcxHostBounceBuffersPool.scala deleted file mode 100755 index 90e2f6aa..00000000 --- a/src/main/scala/org/apache/spark/shuffle/ucx/memory/UcxHostBounceBuffersPool.scala +++ /dev/null @@ -1,116 +0,0 @@ -/* -* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. -* See file LICENSE for terms. -*/ -package org.apache.spark.shuffle.ucx.memory - -import java.io.Closeable -import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedDeque} - -import org.openucx.jucx.ucp.{UcpContext, UcpMemMapParams, UcpMemory} -import org.apache.spark.internal.Logging -import org.apache.spark.shuffle.ucx.{MemoryBlock, UcxShuffleConf} - -/** - * Pre-registered host bounce buffers. 
- * TODO: support reclamation - */ -class UcxHostBounceBufferMemoryBlock(private[ucx] val memory: UcpMemory, private[ucx] val refCount: AtomicInteger, - override val address: Long, override val size: Long) - extends MemoryBlock(address, size) - -class UcxHostBounceBuffersPool(ucxShuffleConf: UcxShuffleConf, ucxContext: UcpContext) - extends MemoryPool with Logging { - - private val allocatorMap = new ConcurrentHashMap[Long, AllocatorStack]() - - ucxShuffleConf.preallocateBuffersMap.foreach { - case (size, numBuffers) => - val roundedSize = roundUpToTheNextPowerOf2(size) - logDebug(s"Pre allocating $numBuffers buffers of size $roundedSize") - val allocatorStack = allocatorMap.computeIfAbsent(roundedSize, s => AllocatorStack(s)) - allocatorStack.preallocate(numBuffers) - } - - private case class AllocatorStack(length: Long) extends Closeable { - private val stack = new ConcurrentLinkedDeque[UcxHostBounceBufferMemoryBlock] - private val numAllocs = new AtomicInteger(0) - - private[UcxHostBounceBuffersPool] def get: UcxHostBounceBufferMemoryBlock = { - var result = stack.pollFirst() - if (result == null) { - numAllocs.incrementAndGet() - if (length < ucxShuffleConf.minRegistrationSize) { - preallocate((ucxShuffleConf.minRegistrationSize / length).toInt) - result = stack.pollFirst() - } else { - val memory = ucxContext.memoryMap(new UcpMemMapParams().allocate().setLength(length)) - result = new UcxHostBounceBufferMemoryBlock(memory, new AtomicInteger(1), - memory.getAddress, length) - } - } - result - } - - private[UcxHostBounceBuffersPool] def put(block: UcxHostBounceBufferMemoryBlock): Unit = { - stack.add(block) - } - - private[ucx] def preallocate(numBuffers: Int): Unit = { - val memory = ucxContext.memoryMap( - new UcpMemMapParams().allocate().setLength(length * numBuffers)) - val refCount = new AtomicInteger(numBuffers) - var offset = 0L - (0 until numBuffers).foreach(_ => { - stack.add(new UcxHostBounceBufferMemoryBlock(memory, refCount, memory.getAddress + 
offset, length)) - offset += length - }) - } - - override def close(): Unit = { - var numBuffers = 0 - stack.forEach(block => { - block.refCount.decrementAndGet() - if (block.memory.getNativeId != null) { - block.memory.deregister() - } - numBuffers += 1 - }) - logInfo(s"Closing $numBuffers buffers of size $length." + - s"Number of allocations: ${numAllocs.get()}") - stack.clear() - } - } - - private def roundUpToTheNextPowerOf2(size: Long): Long = { - // Round up length to the nearest power of two - var length = size - length -= 1 - length |= length >> 1 - length |= length >> 2 - length |= length >> 4 - length |= length >> 8 - length |= length >> 16 - length += 1 - length - } - - override def get(size: Long): MemoryBlock = { - val roundedSize = roundUpToTheNextPowerOf2(size) - val allocatorStack = allocatorMap.computeIfAbsent(roundedSize, s => AllocatorStack(s)) - allocatorStack.get - } - - override def put(mem: MemoryBlock): Unit = { - val allocatorStack = allocatorMap.computeIfAbsent(roundUpToTheNextPowerOf2(mem.size), - s => AllocatorStack(s)) - allocatorStack.put(mem.asInstanceOf[UcxHostBounceBufferMemoryBlock]) - } - - override def close(): Unit = { - allocatorMap.values.forEach(allocator => allocator.close()) - allocatorMap.clear() - } - -} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxShuffleTransportPerfTool.scala b/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxShuffleTransportPerfTool.scala new file mode 100755 index 00000000..c111f2a9 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxShuffleTransportPerfTool.scala @@ -0,0 +1,269 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. +* See file LICENSE for terms. 
+*/ +package org.apache.spark.shuffle.ucx.perf + +import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket} +import java.nio.ByteBuffer +import java.util.concurrent.{CountDownLatch, Executors, TimeUnit} + +import ai.rapids.cudf.DeviceMemoryBuffer +import org.apache.commons.cli.{GnuParser, HelpFormatter, Options} +import org.openucx.jucx.UcxUtils +import org.openucx.jucx.ucp.UcpMemMapParams +import org.openucx.jucx.ucs.UcsConstants +import org.apache.spark.SparkConf +import org.apache.spark.shuffle.ucx._ +import org.apache.spark.shuffle.ucx.memory.MemoryPool +import org.apache.spark.unsafe.Platform +import org.apache.spark.util.Utils + +object UcxShuffleTransportPerfTool { + private val HELP_OPTION = "h" + private val ADDRESS_OPTION = "a" + private val NUM_BLOCKS_OPTION = "n" + private val SIZE_OPTION = "s" + private val PORT_OPTION = "p" + private val ITER_OPTION = "i" + private val MEMORY_TYPE_OPTION = "m" + private val NUM_THREADS_OPTION = "t" + private val REUSE_ADDRESS_OPTION = "r" + + private val ucxShuffleConf = new UcxShuffleConf(new SparkConf()) + private val transport = new UcxShuffleTransport(ucxShuffleConf, "e") + private val workerAddress = transport.init() + private var memoryPool: MemoryPool = transport.hostMemoryPool + + case class TestBlockId(id: Int) extends BlockId + + case class PerfOptions(remoteAddress: InetSocketAddress, numBlocks: Int, blockSize: Long, + serverPort: Int, numIterations: Int, numThreads: Int, memoryType: Int) + + private def initOptions(): Options = { + val options = new Options() + options.addOption(HELP_OPTION, "help", false, + "display help message") + options.addOption(ADDRESS_OPTION, "address", true, + "address of the remote host") + options.addOption(NUM_BLOCKS_OPTION, "num-blocks", true, + "number of blocks to transfer. Default: 1") + options.addOption(SIZE_OPTION, "block-size", true, + "size of block to transfer. Default: 4m") + options.addOption(PORT_OPTION, "server-port", true, + "server port. 
Default: 12345") + options.addOption(ITER_OPTION, "num-iterations", true, + "number of iterations. Default: 5") + options.addOption(NUM_THREADS_OPTION, "num-threads", true, + "number of threads. Default: 1") + options.addOption(MEMORY_TYPE_OPTION, "memory-type", true, + "memory type: host (default), cuda") + } + + private def parseOptions(args: Array[String]): PerfOptions = { + val parser = new GnuParser() + val options = initOptions() + val cmd = parser.parse(options, args) + + if (cmd.hasOption(HELP_OPTION)) { + new HelpFormatter().printHelp("UcxShufflePerfTool", options) + System.exit(0) + } + + val inetAddress = if (cmd.hasOption(ADDRESS_OPTION)) { + val Array(host, port) = cmd.getOptionValue(ADDRESS_OPTION).split(":") + new InetSocketAddress(host, Integer.parseInt(port)) + } else { + null + } + + val serverPort = Integer.parseInt(cmd.getOptionValue(PORT_OPTION, "12345")) + + val numIterations = Integer.parseInt(cmd.getOptionValue(ITER_OPTION, "5")) + + val threadsNumber = Integer.parseInt(cmd.getOptionValue(NUM_THREADS_OPTION, "1")) + + var memoryType = UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST + if (cmd.hasOption(MEMORY_TYPE_OPTION) && cmd.getOptionValue(MEMORY_TYPE_OPTION) == "cuda") { + memoryPool = transport.deviceMemoryPool + memoryType = UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_CUDA + } + + PerfOptions(inetAddress, + Integer.parseInt(cmd.getOptionValue(NUM_BLOCKS_OPTION, "1")), + Utils.byteStringAsBytes(cmd.getOptionValue(SIZE_OPTION, "4m")), + serverPort, numIterations, threadsNumber, memoryType) + } + + private def startServer(perfOptions: PerfOptions): Unit = { + val blocks: Seq[Block] = (0 until perfOptions.numBlocks * perfOptions.numThreads).map { _ => + if (ucxShuffleConf.pinMemory) { + val block = memoryPool.get(perfOptions.blockSize) + new Block { + override def getMemoryBlock: MemoryBlock = + block + } + } else if (perfOptions.memoryType == UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_CUDA) { + new Block { + var oldBlock: DeviceMemoryBuffer = _ 
+ override def getMemoryBlock: MemoryBlock = { + if (oldBlock != null) { + oldBlock.close() + } + val block = DeviceMemoryBuffer.allocate(perfOptions.blockSize) + oldBlock = block + MemoryBlock(block.getAddress, block.getLength, isHostMemory = false) + } + } + } else { + new Block { + override def getMemoryBlock: MemoryBlock = { + val buf = ByteBuffer.allocateDirect(perfOptions.blockSize.toInt) + MemoryBlock(UcxUtils.getAddress(buf), perfOptions.blockSize) + } + } + } + + } + + val blockIds = (0 until perfOptions.numBlocks * perfOptions.numThreads).map(i => TestBlockId(i)) + blockIds.zip(blocks).foreach { + case (blockId, block) => transport.register(blockId, block) + } + + val serverSocket = new ServerSocket(perfOptions.serverPort) + + println(s"Waiting for connections on " + + s"${InetAddress.getLocalHost.getHostName}:${perfOptions.serverPort} ") + + val clientSocket = serverSocket.accept() + val out = clientSocket.getOutputStream + val in = clientSocket.getInputStream + + val buf = ByteBuffer.allocate(workerAddress.capacity()) + buf.put(workerAddress) + buf.flip() + + out.write(buf.array()) + out.flush() + + println(s"Sending worker address to ${clientSocket.getInetAddress}") + + buf.flip() + + in.read(buf.array()) + clientSocket.close() + serverSocket.close() + + blocks.foreach(block => memoryPool.put(block.getMemoryBlock)) + blockIds.foreach(transport.unregister) + transport.close() + } + + private def startClient(perfOptions: PerfOptions): Unit = { + val socket = new Socket(perfOptions.remoteAddress.getHostName, + perfOptions.remoteAddress.getPort) + + val buf = new Array[Byte](4096) + val readSize = socket.getInputStream.read(buf) + val executorId = "1" + val workerAddress = ByteBuffer.allocateDirect(readSize) + + workerAddress.put(buf, 0, readSize) + println("Received worker address") + + transport.addExecutor(executorId, workerAddress) + + val resultSize = perfOptions.numThreads * perfOptions.numBlocks * perfOptions.blockSize + val resultMemory = 
memoryPool.get(resultSize) + transport.register(TestBlockId(-1), new Block { + override def getMemoryBlock: MemoryBlock = resultMemory + }) + + val threadPool = Executors.newFixedThreadPool(perfOptions.numThreads) + + for (i <- 1 to perfOptions.numIterations) { + val startTime = System.nanoTime() + val countDownLatch = new CountDownLatch(perfOptions.numThreads) + val deviceMemoryBuffers: Array[DeviceMemoryBuffer] = new Array(perfOptions.numThreads * perfOptions.numBlocks) + for (tid <- 0 until perfOptions.numThreads) { + threadPool.execute(() => { + val blocksOffset = tid * perfOptions.numBlocks + val blockIds = (blocksOffset until blocksOffset + perfOptions.numBlocks).map(i => TestBlockId(i)) + + val mem = new Array[MemoryBlock](perfOptions.numBlocks) + + for (j <- 0 until perfOptions.numBlocks) { + mem(j) = MemoryBlock(resultMemory.address + + (tid * perfOptions.numBlocks * perfOptions.blockSize) + j * perfOptions.blockSize, perfOptions.blockSize) + + if (!ucxShuffleConf.pinMemory && perfOptions.memoryType == UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) { + val buffer = Platform.allocateDirectBuffer(perfOptions.blockSize.toInt) + mem(j) = MemoryBlock(UcxUtils.getAddress(buffer), perfOptions.blockSize) + } else if (!ucxShuffleConf.pinMemory && + perfOptions.memoryType == UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_CUDA) { + val buffer = DeviceMemoryBuffer.allocate(perfOptions.blockSize) + deviceMemoryBuffers(tid * perfOptions.numBlocks + j) = buffer + mem(j) = MemoryBlock(buffer.getAddress, buffer.getLength, isHostMemory = false) + } else { + mem(j) = MemoryBlock(resultMemory.address + + (tid * perfOptions.numBlocks * perfOptions.blockSize) + j * perfOptions.blockSize, perfOptions.blockSize) + } + } + + val requests = transport.fetchBlocksByBlockIds(executorId, blockIds, mem, Seq.empty[OperationCallback]) + + while (!requests.forall(_.isCompleted)) { + transport.progress() + } + + countDownLatch.countDown() + }) + } + + countDownLatch.await() + val elapsedTime = 
System.nanoTime() - startTime + + deviceMemoryBuffers.foreach(d => if (d != null) d.close()) + val totalTime = if (elapsedTime < TimeUnit.MILLISECONDS.toNanos(1)) { + s"$elapsedTime ns" + } else { + s"${TimeUnit.NANOSECONDS.toMillis(elapsedTime)} ms" + } + val throughput: Double = (resultSize / 1024.0D / 1024.0D / 1024.0D) / + (elapsedTime / 1e9D) + + if ((i % 100 == 0) || i == perfOptions.numIterations) { + println(f"${s"[$i/${perfOptions.numIterations}]"}%12s" + + s" numBlocks: ${perfOptions.numBlocks}" + + s" numThreads: ${perfOptions.numThreads}" + + s" size: ${Utils.bytesToString(perfOptions.blockSize)}," + + s" total size: ${Utils.bytesToString(resultSize * perfOptions.numThreads)}," + + f" time: $totalTime%3s" + + f" throughput: $throughput%.5f GB/s") + } + } + + val out = socket.getOutputStream + out.write(buf) + out.flush() + out.close() + socket.close() + + transport.unregister(TestBlockId(-1)) + memoryPool.put(resultMemory) + transport.close() + threadPool.shutdown() + } + + def main(args: Array[String]): Unit = { + val perfOptions = parseOptions(args) + + if (perfOptions.remoteAddress == null) { + startServer(perfOptions) + } else { + startClient(perfOptions) + } + } + +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala index c65dba23..aec6c17e 100755 --- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala @@ -5,91 +5,65 @@ package org.apache.spark.shuffle.ucx.rpc import java.io.ObjectInputStream +import java.nio.ByteBuffer import com.fasterxml.jackson.databind.util.ByteBufferBackedInputStream -import org.openucx.jucx.ucp.{UcpRequest, UcpWorker} -import org.openucx.jucx.{UcxException, UcxUtils} +import org.openucx.jucx.UcxUtils +import org.openucx.jucx.ucp._ +import org.openucx.jucx.ucs.UcsConstants import 
org.apache.spark.internal.Logging -import org.apache.spark.shuffle.ucx.memory.MemoryPool -import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{FetchBlockByBlockIdRequest, FetchBlocksByBlockIdsRequest, PrefetchBlockIds} -import org.apache.spark.shuffle.ucx.{MemoryBlock, UcxShuffleTransport} +import org.apache.spark.shuffle.ucx.UcxShuffleTransport +import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.FetchBlocksByBlockIdsRequest import org.apache.spark.util.Utils -class GlobalWorkerRpcThread(globalWorker: UcpWorker, memPool: MemoryPool, - transport: UcxShuffleTransport) extends Thread with Logging { +class GlobalWorkerRpcThread(globalWorker: UcpWorker, transport: UcxShuffleTransport) + extends Thread with Logging { setDaemon(true) setName("Ucx Shuffle Transport Progress Thread") + private def handleFetchBlockRequest(buffer: ByteBuffer, ep: UcpEndpoint): Unit = { + val msg = Utils.tryWithResource(new ByteBufferBackedInputStream(buffer)) { bin => + val objIn = new ObjectInputStream(bin) + val obj = objIn.readObject().asInstanceOf[FetchBlocksByBlockIdsRequest] + objIn.close() + obj + } + val startTime = System.nanoTime() + transport.replyFetchBlocksRequest(msg.blockIds, msg.isHostMemory, ep, msg.startTag, msg.singleReply) + logInfo(s"Sent reply for ${msg.blockIds.length} blocks in ${Utils.getUsedTimeNs(startTime)}") + } override def run(): Unit = { - val numRecvs = transport.ucxShuffleConf.recvQueueSize - val msgSize = transport.ucxShuffleConf.rpcMessageSize - val requests: Array[UcpRequest] = Array.ofDim[UcpRequest](numRecvs) - val recvMemory = memPool.get(msgSize * numRecvs) - val memoryBlocks = (0 until numRecvs).map(i => - MemoryBlock(recvMemory.address + i * msgSize, msgSize)) - - for (i <- 0 until numRecvs) { - requests(i) = globalWorker.recvTaggedNonBlocking(memoryBlocks(i).address, msgSize, - UcxRpcMessages.WILDCARD_TAG, UcxRpcMessages.WILDCARD_TAG_MASK, null) + val processCallback: UcpAmRecvCallback = (headerAddress: Long, headerSize: Long, amData: 
UcpAmData, + replyEp: UcpEndpoint) => { + if (headerSize > 0) { + logDebug(s"Received AM in header on ${transport.clientConnections.get(replyEp)}") + val header = UcxUtils.getByteBufferView(headerAddress, headerSize) + handleFetchBlockRequest(header, replyEp) + UcsConstants.STATUS.UCS_OK + } else { + assert(amData.isDataValid) + logDebug(s"Received AM in eager on ${transport.clientConnections.get(replyEp)}") + val data = UcxUtils.getByteBufferView(amData.getDataAddress, amData.getLength) + handleFetchBlockRequest(data, replyEp) + UcsConstants.STATUS.UCS_OK + } } - + globalWorker.setAmRecvHandler(0, processCallback) + globalWorker.setAmRecvHandler(-1, new UcpAmRecvCallback() { + override def onReceive(headerAddress: Long, headerSize: Long, amData: UcpAmData, replyEp: UcpEndpoint): Int = { + logTrace(s"Hello") + UcsConstants.STATUS.UCS_OK + } + }) while (!isInterrupted) { if (globalWorker.progress() == 0) { - globalWorker.waitForEvents() - globalWorker.progress() - } - for (i <- 0 until numRecvs) { - if (requests(i).isCompleted) { - val buffer = UcxUtils.getByteBufferView(memoryBlocks(i).address, msgSize.toInt) - val senderTag = requests(i).getSenderTag - if (senderTag >= 2L) { - // Fetch Single block - val msg = Utils.tryWithResource(new ByteBufferBackedInputStream(buffer)) { bin => - val objIn = new ObjectInputStream(bin) - val obj = objIn.readObject().asInstanceOf[FetchBlockByBlockIdRequest] - objIn.close() - obj - } - transport.replyFetchBlockRequest(msg.executorId, msg.workerAddress.value, msg.blockId, senderTag) - } else if (senderTag < 0 ) { - // Batch fetch request - val msg = Utils.tryWithResource(new ByteBufferBackedInputStream(buffer)) { bin => - val objIn = new ObjectInputStream(bin) - val obj = objIn.readObject().asInstanceOf[FetchBlocksByBlockIdsRequest] - objIn.close() - obj - } - for (j <- msg.blockIds.indices) { - transport.replyFetchBlockRequest(msg.executorId, msg.workerAddress.value, msg.blockIds(j), - senderTag + j) - } - - } else if (senderTag 
== UcxRpcMessages.PREFETCH_TAG) { - val msg = Utils.tryWithResource(new ByteBufferBackedInputStream(buffer)) { bin => - val objIn = new ObjectInputStream(bin) - val obj = objIn.readObject().asInstanceOf[PrefetchBlockIds] - objIn.close() - obj - } - transport.handlePrefetchRequest(msg.executorId, msg.workerAddress.value, msg.blockIds) - } - requests(i) = globalWorker.recvTaggedNonBlocking(memoryBlocks(i).address, msgSize, - UcxRpcMessages.WILDCARD_TAG, UcxRpcMessages.WILDCARD_TAG_MASK, null) + if (transport.ucxShuffleConf.useWakeup) { + globalWorker.waitForEvents() } } } - memPool.put(recvMemory) - for (i <- 0 until numRecvs) { - if (!requests(i).isCompleted) { - try { - globalWorker.cancelRequest(requests(i)) - } catch { - case _: UcxException => - } - } - } } } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxDriverRpcEndpoint.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxDriverRpcEndpoint.scala new file mode 100755 index 00000000..50a9e2e2 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxDriverRpcEndpoint.scala @@ -0,0 +1,42 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. 
+*/ +package org.apache.spark.shuffle.ucx.rpc + +import scala.collection.immutable.HashMap +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.rpc._ +import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{ExecutorAdded, IntroduceAllExecutors} +import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer + +class UcxDriverRpcEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + + private val endpoints: mutable.Set[RpcEndpointRef] = mutable.HashSet.empty + private var blockManagerToWorkerAddress = HashMap.empty[String, SerializableDirectBuffer] + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case message@ExecutorAdded(executorId: String, endpoint: RpcEndpointRef, + ucxWorkerAddress: SerializableDirectBuffer) => { + // Driver receives a message from executor with it's workerAddress + // 1. Introduce existing members of a cluster + logInfo(s"Received $message") + if (blockManagerToWorkerAddress.nonEmpty) { + val msg = IntroduceAllExecutors(blockManagerToWorkerAddress.keys.toSeq, + blockManagerToWorkerAddress.values.toList) + logInfo(s"replying $msg to $executorId") + context.reply(msg) + } + blockManagerToWorkerAddress += executorId -> ucxWorkerAddress + // 2. For each existing member introduce newly joined executor. + endpoints.foreach(ep => { + logInfo(s"Sending $message to $ep") + ep.send(message) + }) + logInfo(s"Connecting back to address: ${context.senderAddress}") + endpoints.add(endpoint) + } + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala new file mode 100755 index 00000000..3b8dec20 --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala @@ -0,0 +1,31 @@ +/* +* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* See file LICENSE for terms. 
+*/ +package org.apache.spark.shuffle.ucx.rpc + +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.shuffle.ucx.UcxShuffleTransport +import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{ExecutorAdded, IntroduceAllExecutors} +import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer + +class UcxExecutorRpcEndpoint(override val rpcEnv: RpcEnv, transport: UcxShuffleTransport) + extends RpcEndpoint with Logging { + + + override def receive: PartialFunction[Any, Unit] = { + case ExecutorAdded(executorId: String, _: RpcEndpointRef, + ucxWorkerAddress: SerializableDirectBuffer) => { + logInfo(s"Received ExecutorAdded($executorId)") + transport.addExecutor(executorId, ucxWorkerAddress.value) + } + case IntroduceAllExecutors(executorIds: Seq[String], + ucxWorkerAddresses: Seq[SerializableDirectBuffer]) => { + logInfo(s"Received IntroduceAllExecutors(${executorIds.mkString(",")})") + executorIds.zip(ucxWorkerAddresses).foreach { + case (executorId, workerAddress) => transport.addExecutor(executorId, workerAddress.value) + } + } + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala index ebdc47f4..432d6e32 100644 --- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala @@ -9,11 +9,6 @@ import org.apache.spark.shuffle.ucx.BlockId import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer object UcxRpcMessages { - - val PREFETCH_TAG = 1L - val WILDCARD_TAG = -1L - val WILDCARD_TAG_MASK = 0L - /** * Called from executor to driver, to introduce ucx worker address. 
*/ @@ -27,13 +22,6 @@ object UcxRpcMessages { case class IntroduceAllExecutors(executorIds: Seq[String], ucxWorkerAddresses: Seq[SerializableDirectBuffer]) - case class FetchBlockByBlockIdRequest(executorId: String, workerAddress: SerializableDirectBuffer, - blockId: BlockId) - - case class FetchBlocksByBlockIdsRequest(executorId: String, - workerAddress: SerializableDirectBuffer, - blockIds: Seq[BlockId]) - - case class PrefetchBlockIds(executorId: String, workerAddress: SerializableDirectBuffer, - blockIds: Seq[BlockId]) + case class FetchBlocksByBlockIdsRequest(startTag: Int, blockIds: Seq[BlockId], + isHostMemory: Seq[Boolean], singleReply: Boolean = false) } diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/utils/SerializableDirectBuffer.scala b/src/main/scala/org/apache/spark/shuffle/ucx/utils/SerializableDirectBuffer.scala index 009d7b08..4d453b38 100644 --- a/src/main/scala/org/apache/spark/shuffle/ucx/utils/SerializableDirectBuffer.scala +++ b/src/main/scala/org/apache/spark/shuffle/ucx/utils/SerializableDirectBuffer.scala @@ -5,11 +5,12 @@ package org.apache.spark.shuffle.ucx.utils import java.io.{EOFException, ObjectInputStream, ObjectOutputStream} +import java.net.InetSocketAddress import java.nio.ByteBuffer import java.nio.channels.Channels import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils +import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} /** * A wrapper around a java.nio.ByteBuffer that is serializable through Java serialization, to make @@ -64,3 +65,28 @@ class DeserializableToExternalMemoryBuffer(@transient var buffer: ByteBuffer)() buffer.rewind() // Allow us to read it later } } + + +object SerializationUtils { + + def deserializeInetAddress(workerAddress: ByteBuffer): InetSocketAddress = { + workerAddress.rewind() + Utils.tryWithResource(new ByteBufferInputStream(workerAddress)) { bin => + val objIn = new ObjectInputStream(bin) + val obj = 
objIn.readObject().asInstanceOf[InetSocketAddress] + objIn.close() + obj + } + } + + def serializeInetAddress(address: InetSocketAddress): ByteBuffer = { + val hostAddress = new InetSocketAddress(Utils.localCanonicalHostName(), address.getPort) + Utils.tryWithResource(new ByteBufferOutputStream(100)) {bos => + val out = new ObjectOutputStream(bos) + out.writeObject(hostAddress) + out.flush() + out.close() + bos.toByteBuffer + } + } +} diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/utils/UcxHelperUtils.scala b/src/main/scala/org/apache/spark/shuffle/ucx/utils/UcxHelperUtils.scala new file mode 100755 index 00000000..37c7d34e --- /dev/null +++ b/src/main/scala/org/apache/spark/shuffle/ucx/utils/UcxHelperUtils.scala @@ -0,0 +1,36 @@ +package org.apache.spark.shuffle.ucx.utils + +import java.net.{BindException, InetSocketAddress} + +import scala.collection.mutable +import scala.util.Random + +import org.openucx.jucx.UcxException +import org.openucx.jucx.ucp._ +import org.apache.spark.SparkConf +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +object UcxHelperUtils extends Logging{ + def startListenerOnRandomPort(worker: UcpWorker, sparkConf: SparkConf, + endpoints: mutable.HashMap[UcpEndpoint, (UcpEndpoint, InetSocketAddress)]): + UcpListener = { + val ucpListenerParams = new UcpListenerParams().setConnectionHandler( + (ucpConnectionRequest: UcpConnectionRequest) => { + val repyEp = worker.newEndpoint(new UcpEndpointParams().setPeerErrorHandlingMode() + .setConnectionRequest(ucpConnectionRequest)) + endpoints += repyEp -> (repyEp, ucpConnectionRequest.getClientAddress) + }) + val (listener, _) = Utils.startServiceOnPort(1024 + Random.nextInt(65535 - 1024), (port: Int) => { + ucpListenerParams.setSockAddr(new InetSocketAddress(port)) + val listener = try { + worker.newListener(ucpListenerParams) + } catch { + case ex:UcxException => throw new BindException(ex.getMessage) + } + (listener, listener.getAddress.getPort) + }, 
sparkConf) + logInfo(s"Started UcxListener on ${listener.getAddress}") + listener + } +} diff --git a/src/test/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransportTestSuite.scala b/src/test/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransportTestSuite.scala index f391f3cd..6c107d6b 100755 --- a/src/test/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransportTestSuite.scala +++ b/src/test/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransportTestSuite.scala @@ -1,5 +1,5 @@ /* -* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED. +* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED. * See file LICENSE for terms. */ package org.apache.spark.shuffle.ucx @@ -76,20 +76,16 @@ class UcxShuffleTransportTestSuite extends AnyFunSuite { val resultMemory2 = MemoryBlock(UcxUtils.getAddress(resultBuffer2), blockSize) val completed = new AtomicInteger(0) + val callback: OperationCallback = (result: OperationResult) => { + completed.incrementAndGet() + assert(result.getStats.get.recvSize == blockSize) + assert(result.getStatus == OperationStatus.SUCCESS) + } - transport.fetchBlockByBlockId(server.ID, TestBlockId(1), resultMemory1, - (result: OperationResult) => { - completed.incrementAndGet() - assert(result.getStats.get.recvSize == blockSize) - assert(result.getStatus == OperationStatus.SUCCESS) - }) - - transport.fetchBlockByBlockId(server.ID, TestBlockId(2), resultMemory2, - (result: OperationResult) => { - completed.incrementAndGet() - assert(result.getStats.get.recvSize == blockSize) - assert(result.getStatus == OperationStatus.SUCCESS) - }) + transport.fetchBlocksByBlockIds(server.ID, + Array(TestBlockId(1), TestBlockId(2)), + Array(resultMemory1, resultMemory2), + Array(callback, callback)) while (completed.get() != 2) { transport.progress() @@ -107,7 +103,9 @@ class UcxShuffleTransportTestSuite extends AnyFunSuite { while (!mutated.get()) {} - val request = transport.fetchBlockByBlockId(server.ID, TestBlockId(2), resultMemory2, null) + 
val request = transport.fetchBlocksByBlockIds(server.ID, + Array(TestBlockId(2)), Array(resultMemory2), Seq.empty)(0) + while (!request.isCompleted) { transport.progress() }