diff --git a/pom.xml b/pom.xml
index d9330461..1f5c22b9 100755
--- a/pom.xml
+++ b/pom.xml
@@ -34,7 +34,8 @@ See file LICENSE for terms.
3.0.0
2.12.12
2.12
- 1.10.0-SNAPSHOT
+ 1.11.0-rc3
+ 0.18.1
@@ -55,26 +56,34 @@ See file LICENSE for terms.
3.2.1
test
+
+ ai.rapids
+ cudf
+ ${cudf.version}
+ provided
+
- ${project.artifactId}-${project.version}-for-spark-${spark.version}
+ ${project.artifactId}-${project.version}-for-spark-3.0
org.apache.maven.plugins
maven-compiler-plugin
3.8.1
-
-
- 1.8
-
net.alchim31.maven
scala-maven-plugin
- 4.4.0
+ 4.4.1
all
+
+ -source
+ ${maven.compiler.source}
+ -target
+ ${maven.compiler.target}
+
-Xexperimental
-Xfatal-warnings
@@ -137,6 +146,18 @@ See file LICENSE for terms.
+
+ org.apache.maven.plugins
+ maven-jar-plugin
+ 3.2.0
+
+
+
+ test-jar
+
+
+
+
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala
index 0bd92891..bd8adaef 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala
@@ -135,11 +135,6 @@ trait ShuffleTransport {
*/
def unregister(blockId: BlockId)
- /**
- * Hint for a transport that these blocks would needed soon.
- */
- def prefetchBlocks(executorId: String, blockIds: Seq[BlockId])
-
/**
* Batch version of [[ fetchBlocksByBlockIds ]].
*/
@@ -147,11 +142,9 @@ trait ShuffleTransport {
resultBuffer: Seq[MemoryBlock],
callbacks: Seq[OperationCallback]): Seq[Request]
- /**
- * Fetch remote blocks by blockIds.
- */
- def fetchBlockByBlockId(executorId: String, blockId: BlockId,
- resultBuffer: MemoryBlock, cb: OperationCallback): Request
+ def fetchBlocksByBlockIds(executorId: String, blockIds: Seq[BlockId],
+ resultBuffer: MemoryBlock,
+ callbacks: OperationCallback): Request
/**
* Progress outstanding operations. This routine is blocking (though may poll for event).
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleBlockResolver.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleBlockResolver.scala
new file mode 100755
index 00000000..420d161f
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleBlockResolver.scala
@@ -0,0 +1,101 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.ucx
+
+import java.io.File
+import java.nio.ByteBuffer
+import java.nio.channels.FileChannel
+import java.nio.file.StandardOpenOption
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.collection.JavaConverters._
+
+import org.openucx.jucx.UcxUtils
+import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.unsafe.Platform
+
+
+case class UcxShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId {
+
+ def this(shuffleBlockId: ShuffleBlockId) = {
+ this(shuffleBlockId.shuffleId, shuffleBlockId.mapId, shuffleBlockId.reduceId)
+ }
+
+ def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
+}
+
+case class BufferBackedBlock(buffer: ByteBuffer) extends Block {
+ override def getMemoryBlock: MemoryBlock = MemoryBlock(UcxUtils.getAddress(buffer), buffer.capacity())
+}
+
+class UcxShuffleBlockResolver(conf: UcxShuffleConf, transport: UcxShuffleTransport)
+ extends IndexShuffleBlockResolver(conf) {
+
+ type MapId = Long
+
+ private val numPartitionsForMapId = new ConcurrentHashMap[MapId, Int]
+
+ override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Long,
+ lengths: Array[Long], dataTmp: File): Unit = {
+ super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp)
+ val dataFile = getDataFile(shuffleId, mapId)
+ if (!dataFile.exists()) {
+ return
+ }
+ numPartitionsForMapId.put(mapId, lengths.length)
+ val fileChannel = FileChannel.open(dataFile.toPath, StandardOpenOption.READ,
+ StandardOpenOption.WRITE)
+ val mappedBuffer = fileChannel.map(FileChannel.MapMode.READ_WRITE, 0L, dataFile.length())
+
+ val baseAddress = UcxUtils.getAddress(mappedBuffer)
+ fileChannel.close()
+
+ // Register whole map output file as dummy block
+ transport.register(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.MAP_FILE),
+ BufferBackedBlock(mappedBuffer))
+
+ val offsetSize = 8 * (lengths.length + 1)
+ val indexBuf = Platform.allocateDirectBuffer(offsetSize)
+
+ var offset = 0L
+ indexBuf.putLong(offset)
+ for (reduceId <- lengths.indices) {
+ if (lengths(reduceId) > 0) {
+ transport.register(UcxShuffleBlockId(shuffleId, mapId, reduceId), new Block {
+ private val memoryBlock = MemoryBlock(baseAddress + offset, lengths(reduceId))
+ override def getMemoryBlock: MemoryBlock = memoryBlock
+ })
+ offset += lengths(reduceId)
+ indexBuf.putLong(offset)
+ }
+ }
+
+ if (transport.ucxShuffleConf.protocol == transport.ucxShuffleConf.PROTOCOL.ONE_SIDED) {
+ transport.register(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.INDEX_FILE), BufferBackedBlock(indexBuf))
+ }
+ }
+
+ override def removeDataByMap(shuffleId: ShuffleId, mapId: Long): Unit = {
+ transport.unregister(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.MAP_FILE))
+ transport.unregister(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.INDEX_FILE))
+
+ val numRegisteredBlocks = numPartitionsForMapId.get(mapId)
+ (0 until numRegisteredBlocks)
+ .foreach(reduceId => transport.unregister(UcxShuffleBlockId(shuffleId, mapId, reduceId)))
+ super.removeDataByMap(shuffleId, mapId)
+ }
+
+ override def stop(): Unit = {
+ numPartitionsForMapId.keys.asScala.foreach(mapId => removeDataByMap(0, mapId))
+ super.stop()
+ }
+
+}
+
+object BlocksConstants {
+ val MAP_FILE: Int = -1
+ val INDEX_FILE: Int = -2
+}
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala
new file mode 100755
index 00000000..2fc9cbb1
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala
@@ -0,0 +1,77 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.ucx
+
+import java.util.concurrent.TimeUnit
+
+import org.openucx.jucx.{UcxException, UcxUtils}
+import org.apache.spark.SparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient, DownloadFileManager}
+import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId, BlockId => SparkBlockId}
+
+class UcxShuffleClient(transport: UcxShuffleTransport,
+ blocksByAddress: Iterator[(BlockManagerId, Seq[(SparkBlockId, Long, Int)])])
+ extends BlockStoreClient with Logging {
+
+ private val accurateThreshold = transport.ucxShuffleConf.conf.getSizeAsBytes(SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key)
+
+ private val blockSizes: Map[SparkBlockId, Long] = blocksByAddress
+ .withFilter { case (blockManagerId, _) => blockManagerId != SparkEnv.get.blockManager.blockManagerId }
+ .flatMap {
+ case (blockManagerId, blocks) =>
+ blocks.map {
+ case (blockId, length, _) =>
+ if (length > accurateThreshold) {
+ (blockId, (length * 1.2).toLong)
+ } else {
+ (blockId, accurateThreshold * 2)
+ }
+ }
+ }.toMap
+
+ override def fetchBlocks(host: String, port: Int, execId: String,
+ blockIds: Array[String], listener: BlockFetchingListener,
+ downloadFileManager: DownloadFileManager): Unit = {
+ val ucxBlockIds = new Array[BlockId](blockIds.length)
+ val memoryBlocks = new Array[MemoryBlock](blockIds.length)
+ val callbacks = new Array[OperationCallback](blockIds.length)
+ for (i <- blockIds.indices) {
+ val blockId = SparkBlockId.apply(blockIds(i)).asInstanceOf[ShuffleBlockId]
+ if (!blockSizes.contains(blockId)) {
+ throw new UcxException(s"No $blockId found in MapOutput blocks: ${blockSizes.keys.mkString(",")}")
+ }
+ val resultMemory = transport.hostMemoryPool.get(blockSizes(blockId))
+ ucxBlockIds(i) = UcxShuffleBlockId(blockId.shuffleId, blockId.mapId, blockId.reduceId)
+ memoryBlocks(i) = MemoryBlock(resultMemory.address, blockSizes(blockId))
+ callbacks(i) = (result: OperationResult) => {
+ if (result.getStatus == OperationStatus.SUCCESS) {
+ val stats = result.getStats.get
+ logInfo(s" Received block ${ucxBlockIds(i)} " +
+ s"of size: ${stats.recvSize} " +
+ s"in ${TimeUnit.NANOSECONDS.toMillis(stats.getElapsedTimeNs)} ms")
+ val buffer = UcxUtils.getByteBufferView(resultMemory.address, result.getStats.get.recvSize.toInt)
+ listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) {
+ override def release: ManagedBuffer = {
+ transport.hostMemoryPool.put(resultMemory)
+ this
+ }
+ })
+ } else {
+ logError(s"Error fetching block $blockId of size ${blockSizes(blockId)}:" +
+ s" ${result.getError.getMessage}")
+ listener.onBlockFetchFailure(blockIds(i), new UcxException(result.getError.getMessage))
+ }
+ }
+ }
+ transport.fetchBlocksByBlockIds(execId, ucxBlockIds, memoryBlocks, callbacks)
+ }
+
+ override def close(): Unit = {
+
+ }
+}
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala
index 47c9fd73..56867492 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala
@@ -12,15 +12,20 @@ import org.apache.spark.util.Utils
class UcxShuffleConf(val conf: SparkConf) extends SparkConf {
private def getUcxConf(name: String) = s"spark.shuffle.ucx.$name"
- private val PROTOCOL =
+ object PROTOCOL extends Enumeration {
+ val ONE_SIDED, RNDV = Value
+ }
+
+ private lazy val PROTOCOL_CONF =
ConfigBuilder(getUcxConf("protocol"))
- .doc("Which protocol to use: rndv (default), one-sided")
+ .doc("Which protocol to use: RNDV (default), ONE-SIDED")
.stringConf
.checkValue(protocol => protocol == "rndv" || protocol == "one-sided",
"Invalid protocol. Valid options: rndv / one-sided.")
- .createWithDefault("rndv")
+ .transform(_.toUpperCase.replace("-", "_"))
+ .createWithDefault("RNDV")
- private val MEMORY_PINNING =
+ private lazy val MEMORY_PINNING =
ConfigBuilder(getUcxConf("memoryPinning"))
.doc("Whether to pin whole shuffle data in memory")
.booleanConf
@@ -30,15 +35,7 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf {
ConfigBuilder(getUcxConf("maxWorkerSize"))
.doc("Maximum size of worker address in bytes")
.bytesConf(ByteUnit.BYTE)
- .createWithDefault(1000)
-
- lazy val RPC_MESSAGE_SIZE: ConfigEntry[Long] =
- ConfigBuilder(getUcxConf("rpcMessageSize"))
- .doc("Size of RPC message to send from fetchBlockByBlockId. Must contain ")
- .bytesConf(ByteUnit.BYTE)
- .checkValue(size => size > maxWorkerAddressSize,
- "Rpc message must contain workerAddress")
- .createWithDefault(2000)
+ .createWithDefault(1024)
// Memory Pool
private lazy val PREALLOCATE_BUFFERS =
@@ -52,11 +49,11 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf {
.booleanConf
.createWithDefault(false)
- private lazy val RECV_QUEUE_SIZE =
- ConfigBuilder(getUcxConf("recvQueueSize"))
- .doc("The number of submitted receive requests.")
- .intConf
- .createWithDefault(5)
+ private lazy val USE_SOCKADDR =
+ ConfigBuilder(getUcxConf("useSockAddr"))
+ .doc("Whether to use socket address to connect executors.")
+ .booleanConf
+ .createWithDefault(true)
private lazy val MIN_REGISTRATION_SIZE =
ConfigBuilder(getUcxConf("memory.minAllocationSize"))
@@ -67,24 +64,32 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf {
lazy val minRegistrationSize: Int = conf.getSizeAsBytes(MIN_REGISTRATION_SIZE.key,
MIN_REGISTRATION_SIZE.defaultValueString).toInt
- lazy val protocol: String = conf.get(PROTOCOL.key, PROTOCOL.defaultValueString)
+ private lazy val USE_ODP =
+ ConfigBuilder(getUcxConf("useOdp"))
+ .doc("Whether to use on demand paging feature, to avoid memory pinning")
+ .booleanConf
+ .createWithDefault(false)
- lazy val useOdp: Boolean = conf.getBoolean(getUcxConf("memory.useOdp"), defaultValue = false)
+ lazy val protocol: PROTOCOL.Value = PROTOCOL.withName(
+ conf.get(PROTOCOL_CONF.key, PROTOCOL_CONF.defaultValueString))
+
+ lazy val useOdp: Boolean = conf.getBoolean(USE_ODP.key, USE_ODP.defaultValue.get)
lazy val pinMemory: Boolean = conf.getBoolean(MEMORY_PINNING.key, MEMORY_PINNING.defaultValue.get)
lazy val maxWorkerAddressSize: Long = conf.getSizeAsBytes(WORKER_ADDRESS_SIZE.key,
WORKER_ADDRESS_SIZE.defaultValueString)
- lazy val rpcMessageSize: Long = conf.getSizeAsBytes(RPC_MESSAGE_SIZE.key,
- RPC_MESSAGE_SIZE.defaultValueString)
+ lazy val maxMetadataSize: Long = conf.getSizeAsBytes("spark.rapids.shuffle.maxMetadataSize",
+ "1024")
+
lazy val useWakeup: Boolean = conf.getBoolean(WAKEUP_FEATURE.key, WAKEUP_FEATURE.defaultValue.get)
- lazy val recvQueueSize: Int = conf.getInt(RECV_QUEUE_SIZE.key, RECV_QUEUE_SIZE.defaultValue.get)
+ lazy val useSockAddr: Boolean = conf.getBoolean(USE_SOCKADDR.key, USE_SOCKADDR.defaultValue.get)
lazy val preallocateBuffersMap: Map[Long, Int] = {
- conf.get(PREALLOCATE_BUFFERS).split(",").withFilter(s => !s.isEmpty)
+ conf.get(PREALLOCATE_BUFFERS.key, "").split(",").withFilter(s => s.nonEmpty)
.map(entry => entry.split(":") match {
case Array(bufferSize, bufferCount) =>
(Utils.byteStringAsBytes(bufferSize.trim), bufferCount.toInt)
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleManager.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleManager.scala
new file mode 100755
index 00000000..3d957b0d
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleManager.scala
@@ -0,0 +1,102 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.ucx
+
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.util.Success
+
+import org.apache.spark.rpc.RpcEnv
+import org.apache.spark.shuffle._
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch
+import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{ExecutorAdded, IntroduceAllExecutors}
+import org.apache.spark.shuffle.ucx.rpc.{UcxDriverRpcEndpoint, UcxExecutorRpcEndpoint}
+import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer
+import org.apache.spark.util.RpcUtils
+import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, TaskContext}
+
+
+class UcxShuffleManager(conf: SparkConf, isDriver: Boolean) extends SortShuffleManager(conf) {
+
+ val ucxShuffleConf = new UcxShuffleConf(conf)
+
+ lazy val ucxShuffleTransport: UcxShuffleTransport = if (!isDriver) {
+ new UcxShuffleTransport(ucxShuffleConf, "init")
+ } else {
+ null
+ }
+
+ @volatile private var initialized: Boolean = false
+
+ override val shuffleBlockResolver =
+ new UcxShuffleBlockResolver(ucxShuffleConf, ucxShuffleTransport)
+
+ logInfo("Starting UcxShuffleManager")
+
+ def initTransport(): Unit = this.synchronized {
+ if (!initialized) {
+ val driverEndpointName = "ucx-shuffle-driver"
+ if (isDriver) {
+ val rpcEnv = SparkEnv.get.rpcEnv
+ val driverEndpoint = new UcxDriverRpcEndpoint(rpcEnv)
+ rpcEnv.setupEndpoint(driverEndpointName, driverEndpoint)
+ } else {
+ val blockManager = SparkEnv.get.blockManager.blockManagerId
+ ucxShuffleTransport.executorId = blockManager.executorId
+ val rpcEnv = RpcEnv.create("ucx-rpc-env", blockManager.host, blockManager.host,
+ blockManager.port, conf, new SecurityManager(conf), 1, clientMode=false)
+ logDebug("Initializing ucx transport")
+ val address = ucxShuffleTransport.init()
+ val executorEndpoint = new UcxExecutorRpcEndpoint(rpcEnv, ucxShuffleTransport)
+ val endpoint = rpcEnv.setupEndpoint(
+ s"ucx-shuffle-executor-${blockManager.executorId}",
+ executorEndpoint)
+
+ val driverEndpoint = RpcUtils.makeDriverRef(driverEndpointName, conf, rpcEnv)
+ driverEndpoint.ask[IntroduceAllExecutors](ExecutorAdded(blockManager.executorId,
+ endpoint, new SerializableDirectBuffer(address)))
+ .andThen{
+ case Success(msg) =>
+ logInfo(s"Receive reply $msg")
+ executorEndpoint.receive(msg)
+ }
+ }
+ initialized = true
+ }
+ }
+
+ override def getReader[K, C](handle: ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+ val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
+ handle.shuffleId, startPartition, endPartition)
+ new UcxShuffleReader(ucxShuffleTransport,
+ handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics,
+ shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context))
+ }
+
+ override def getReaderForRange[K, C]( handle: ShuffleHandle,
+ startMapIndex: Int,
+ endMapIndex: Int,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext,
+ metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+ val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByRange(
+ handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition)
+ new UcxShuffleReader(ucxShuffleTransport,
+ handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics,
+ shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context))
+ }
+
+ override def stop(): Unit = {
+ if (ucxShuffleTransport != null) {
+ ucxShuffleTransport.close()
+ }
+ super.stop()
+ }
+}
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleReader.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleReader.scala
new file mode 100755
index 00000000..6ee3e682
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleReader.scala
@@ -0,0 +1,173 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle
+
+import java.io.InputStream
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.util.Random
+
+import org.apache.spark._
+import org.apache.spark.internal.{Logging, config}
+import org.apache.spark.io.CompressionCodec
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.shuffle.ucx.{UcxShuffleClient, UcxShuffleTransport}
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator}
+import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.collection.ExternalSorter
+
+/**
+ * Fetches and reads the blocks from a shuffle by requesting them from other nodes' block stores.
+ */
+class UcxShuffleReader[K, C](transport: UcxShuffleTransport,
+ handle: BaseShuffleHandle[K, _, C],
+ blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+ context: TaskContext,
+ readMetrics: ShuffleReadMetricsReporter,
+ serializerManager: SerializerManager = SparkEnv.get.serializerManager,
+ blockManager: BlockManager = SparkEnv.get.blockManager,
+ mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker,
+ shouldBatchFetch: Boolean = false)
+ extends ShuffleReader[K, C] with Logging {
+
+ private val dep = handle.dependency
+
+ private def fetchContinuousBlocksInBatch: Boolean = {
+ val conf = SparkEnv.get.conf
+ val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects
+ val compressed = conf.get(config.SHUFFLE_COMPRESS)
+ val codecConcatenation = if (compressed) {
+ CompressionCodec.supportsConcatenationOfSerializedStreams(CompressionCodec.createCodec(conf))
+ } else {
+ true
+ }
+ val useOldFetchProtocol = conf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)
+
+ val doBatchFetch = shouldBatchFetch && serializerRelocatable &&
+ (!compressed || codecConcatenation) && !useOldFetchProtocol
+ if (shouldBatchFetch && !doBatchFetch) {
+ logDebug("The feature tag of continuous shuffle block fetching is set to true, but " +
+ "we can not enable the feature because other conditions are not satisfied. " +
+ s"Shuffle compress: $compressed, serializer relocatable: $serializerRelocatable, " +
+ s"codec concatenation: $codecConcatenation, use old shuffle fetch protocol: " +
+ s"$useOldFetchProtocol.")
+ }
+ doBatchFetch
+ }
+
+ /** Read the combined key-values for this reduce task */
+ override def read(): Iterator[Product2[K, C]] = {
+ val (blocksByAddress1, blocksByAddress2) = blocksByAddress.duplicate
+ val shuffleClient = new UcxShuffleClient(transport, Random.shuffle(blocksByAddress1))
+ val shuffleIterator = new ShuffleBlockFetcherIterator(
+ context,
+ shuffleClient,
+ blockManager,
+ Random.shuffle(blocksByAddress2),
+ serializerManager.wrapStream,
+ // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
+ SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024,
+ SparkEnv.get.conf.get(config.REDUCER_MAX_REQS_IN_FLIGHT),
+ SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
+ SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
+ SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
+ SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY),
+ readMetrics,
+ fetchContinuousBlocksInBatch)
+ val wrappedStreams = shuffleIterator.toCompletionIterator
+
+ // Ucx shuffle logic
+ // Java reflection to get access to private results queue
+ val queueField = classOf[ShuffleBlockFetcherIterator].getDeclaredField(
+ "org$apache$spark$storage$ShuffleBlockFetcherIterator$$results")
+ queueField.setAccessible(true)
+ val resultQueue = queueField.get(shuffleIterator).asInstanceOf[LinkedBlockingQueue[_]]
+
+ // Do progress if queue is empty before calling next on ShuffleIterator
+ val ucxWrappedStream = new Iterator[(BlockId, InputStream)] {
+ override def next(): (BlockId, InputStream) = {
+ val startTime = System.currentTimeMillis()
+ while (resultQueue.isEmpty) {
+ transport.progress()
+ }
+ readMetrics.incFetchWaitTime(System.currentTimeMillis() - startTime)
+ wrappedStreams.next()
+ }
+
+ override def hasNext: Boolean = {
+ val result = wrappedStreams.hasNext
+ if (!result) {
+ shuffleClient.close()
+ }
+ result
+ }
+ }
+ // End of ucx shuffle logic
+
+ val serializerInstance = dep.serializer.newInstance()
+
+ // Create a key/value iterator for each stream
+ val recordIter = ucxWrappedStream.flatMap { case (blockId, wrappedStream) =>
+ // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
+ // NextIterator. The NextIterator makes sure that close() is called on the
+ // underlying InputStream when all records have been read.
+ serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+ }
+
+ // Update the context task metrics for each record read.
+ val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
+ recordIter.map { record =>
+ readMetrics.incRecordsRead(1)
+ record
+ },
+ context.taskMetrics().mergeShuffleReadMetrics())
+
+ // An interruptible iterator must be used here in order to support task cancellation
+ val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
+
+ val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
+ if (dep.mapSideCombine) {
+ // We are reading values that are already combined
+ val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
+ dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
+ } else {
+ // We don't know the value type, but also don't care -- the dependency *should*
+ // have made sure its compatible w/ this aggregator, which will convert the value
+ // type to the combined type C
+ val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
+ dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
+ }
+ } else {
+ interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
+ }
+
+ // Sort the output if there is a sort ordering defined.
+ val resultIter = dep.keyOrdering match {
+ case Some(keyOrd: Ordering[K]) =>
+ // Create an ExternalSorter to sort the data.
+ val sorter =
+ new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
+ sorter.insertAll(aggregatedIter)
+ context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+ context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
+ context.taskMetrics().incPeakExecutionMemory(sorter.peakMemoryUsedBytes)
+ // Use completion callback to stop sorter if task was finished/cancelled.
+ context.addTaskCompletionListener[Unit](_ => {
+ sorter.stop()
+ })
+ CompletionIterator[Product2[K, C], Iterator[Product2[K, C]]](sorter.iterator, sorter.stop())
+ case None =>
+ aggregatedIter
+ }
+
+ resultIter match {
+ case _: InterruptibleIterator[Product2[K, C]] => resultIter
+ case _ =>
+ // Use another interruptible iterator here to support task cancellation as aggregator
+ // or(and) sorter may have consumed previous interruptible iterator.
+ new InterruptibleIterator[Product2[K, C]](context, resultIter)
+ }
+ }
+}
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala
index 10aab2f2..c4dff189 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransport.scala
@@ -3,25 +3,28 @@
* See file LICENSE for terms.
*/
package org.apache.spark.shuffle.ucx
+
+import java.net.InetSocketAddress
import java.nio.ByteBuffer
-import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue}
+import java.util.concurrent.locks.Lock
+import scala.collection.JavaConverters._
import scala.collection.mutable
-import scala.concurrent.ExecutionContext.Implicits.global
-import scala.concurrent.Future
-import scala.util.{Failure, Success}
import org.openucx.jucx.ucp._
-import org.openucx.jucx.{UcxCallback, UcxException}
+import org.openucx.jucx.ucs.UcsConstants
+import org.openucx.jucx.{UcxCallback, UcxException, UcxUtils}
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
-import org.apache.spark.shuffle.ucx.memory.{MemoryPool, UcxHostBounceBuffersPool}
+import org.apache.spark.shuffle.ucx.memory.{UcxGpuBounceBuffersPool, UcxHostBounceBuffersPool}
import org.apache.spark.shuffle.ucx.rpc.GlobalWorkerRpcThread
+import org.apache.spark.shuffle.ucx.utils.{SerializationUtils, UcxHelperUtils}
+import org.apache.spark.util.Utils
/**
- * Special type of [[ Block ]] interface backed by UcpMemory,
- * it may not be actually pinned if used with [[ UcxShuffleConf.useOdp ]] flag.
+ * Special type of [[ Block ]] interface backed by UcpMemory.
*/
case class UcxPinnedBlock(block: Block, ucpMemory: UcpMemory, prefetched: Boolean = false)
extends Block {
@@ -30,6 +33,7 @@ case class UcxPinnedBlock(block: Block, ucpMemory: UcpMemory, prefetched: Boolea
class UcxStats extends OperationStats {
private[ucx] val startTime = System.nanoTime()
+ private[ucx] var amHandleTime = 0L
private[ucx] var endTime: Long = 0L
private[ucx] var receiveSize: Long = 0L
@@ -41,31 +45,37 @@ class UcxStats extends OperationStats {
override def getElapsedTimeNs: Long = endTime - startTime
/**
- * Indicates number of valid bytes in receive memory when using
- * [[ ShuffleTransport.fetchBlockByBlockId()]]
+ * Indicates number of valid bytes in receive memory
*/
override def recvSize: Long = receiveSize
}
-class UcxRequest(request: UcpRequest, stats: OperationStats) extends Request {
+class UcxRequest(private var request: UcpRequest, stats: OperationStats, private val worker: UcpWorker)
+ extends Request {
+
+ private[ucx] var completed = false
- override def isCompleted: Boolean = request.isCompleted
+ override def isCompleted: Boolean = completed || ((request != null) && request.isCompleted)
- override def cancel(): Unit = request.close()
+ override def cancel(): Unit = if (request != null) worker.cancelRequest(request)
override def getStats: Option[OperationStats] = Some(stats)
+
+ private[ucx] def setRequest(request: UcpRequest): Unit = {
+ this.request = request
+ }
}
/**
* UCX implementation of [[ ShuffleTransport ]] API
*/
-class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executorId: String)
+class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, var executorId: String)
extends ShuffleTransport with Logging {
// UCX entities
- private var ucxContext: UcpContext = _
+ private[ucx] var ucxContext: UcpContext = _
private var globalWorker: UcpWorker = _
- private val ucpWorkerParams = new UcpWorkerParams()
+ private val ucpWorkerParams = new UcpWorkerParams().requestThreadSafety()
// TODO: reimplement as workerPool, since spark may create/destroy threads dynamically
private var threadLocalWorker: ThreadLocal[UcxWorkerWrapper] = _
@@ -77,69 +87,118 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo
// Mapping between executorId and it's address
private[ucx] val executorIdToAddress = new ConcurrentHashMap[String, ByteBuffer]()
- private[ucx] val clientConnections = mutable.Map.empty[String, UcpEndpoint]
+ private[ucx] val executorIdToSockAddress = new ConcurrentHashMap[String, InetSocketAddress]()
+ private[ucx] val clientConnections = mutable.HashMap.empty[UcpEndpoint, (UcpEndpoint, InetSocketAddress)]
// Need host ucx bounce buffer memory pool to send fetchBlockByBlockId request
- var memoryPool: MemoryPool = _
+ var hostMemoryPool: UcxHostBounceBuffersPool = _
+ var deviceMemoryPool: UcxGpuBounceBuffersPool = _
+
+ @volatile private var initialized: Boolean = false
+
+ private var workerAddress: ByteBuffer = _
+ private var listener: Option[UcpListener] = None
/**
* Initialize transport resources. This function should get called after ensuring that SparkConf
* has the correct configurations since it will use the spark configuration to configure itself.
*/
- override def init(): ByteBuffer = {
- if (ucxShuffleConf == null) {
- ucxShuffleConf = new UcxShuffleConf(SparkEnv.get.conf)
- }
+ override def init(): ByteBuffer = this.synchronized {
+ if (!initialized) {
+ if (ucxShuffleConf == null) {
+ ucxShuffleConf = new UcxShuffleConf(SparkEnv.get.conf)
+ }
- if (ucxShuffleConf.useOdp) {
- memMapParams.nonBlocking()
- }
+ if (ucxShuffleConf.useOdp) {
+ memMapParams.nonBlocking()
+ }
- val params = new UcpParams().requestTagFeature().requestWakeupFeature()
- if (ucxShuffleConf.protocol == "one-sided") {
- params.requestRmaFeature()
- }
- ucxContext = new UcpContext(params)
- globalWorker = ucxContext.newWorker(new UcpWorkerParams().requestWakeupTagRecv()
- .requestWakeupTagSend())
-
- val result = globalWorker.getAddress
- require(result.capacity <= ucxShuffleConf.maxWorkerAddressSize,
- s"${ucxShuffleConf.WORKER_ADDRESS_SIZE.key} < ${result.capacity}")
-
- memoryPool = new UcxHostBounceBuffersPool(ucxShuffleConf, ucxContext)
- progressThread = new GlobalWorkerRpcThread(globalWorker, memoryPool, this)
-
- threadLocalWorker = ThreadLocal.withInitial(() => {
- val localWorker = ucxContext.newWorker(ucpWorkerParams)
- val workerWrapper = new UcxWorkerWrapper(localWorker, this, ucxShuffleConf, memoryPool)
- allocatedWorkers.add(workerWrapper)
- workerWrapper
- })
+ val params = new UcpParams().requestTagFeature().requestAmFeature()
- progressThread.start()
- result
+ if (ucxShuffleConf.useWakeup) {
+ params.requestWakeupFeature()
+ }
+
+ if (ucxShuffleConf.protocol == ucxShuffleConf.PROTOCOL.ONE_SIDED) {
+ params.requestRmaFeature()
+ }
+ ucxContext = new UcpContext(params)
+
+ val workerParams = new UcpWorkerParams()
+
+ if (ucxShuffleConf.useWakeup) {
+ workerParams.requestWakeupTagRecv().requestWakeupTagSend()
+ }
+ globalWorker = ucxContext.newWorker(workerParams)
+
+ workerAddress = if (ucxShuffleConf.useSockAddr) {
+ listener = Some(UcxHelperUtils.startListenerOnRandomPort(globalWorker, ucxShuffleConf.conf, clientConnections))
+ val buffer = SerializationUtils.serializeInetAddress(listener.get.getAddress)
+ buffer
+ } else {
+ val workerAddress = globalWorker.getAddress
+ require(workerAddress.capacity <= ucxShuffleConf.maxWorkerAddressSize,
+ s"${ucxShuffleConf.WORKER_ADDRESS_SIZE.key} < ${workerAddress.capacity}")
+ workerAddress
+ }
+
+ hostMemoryPool = new UcxHostBounceBuffersPool(ucxShuffleConf, ucxContext)
+ deviceMemoryPool = new UcxGpuBounceBuffersPool(ucxShuffleConf, ucxContext)
+ progressThread = new GlobalWorkerRpcThread(globalWorker, this)
+
+ val numThreads = ucxShuffleConf.getInt("spark.executor.cores", 4)
+ val allocateWorker = () => {
+ val localWorker = ucxContext.newWorker(ucpWorkerParams)
+ val workerWrapper = new UcxWorkerWrapper(localWorker, this, ucxShuffleConf, hostMemoryPool)
+ allocatedWorkers.add(workerWrapper)
+ logInfo(s"Pre-connections on a new worker to ${executorIdToAddress.size()} " +
+ s"executors")
+ executorIdToAddress.keys.asScala.foreach(e => {
+ workerWrapper.getConnection(e)
+ })
+ workerWrapper
+ }
+ val preAllocatedWorkers = (0 until numThreads).map(_ => allocateWorker()).toList.asJava
+ val workersPool = new ConcurrentLinkedQueue[UcxWorkerWrapper](preAllocatedWorkers)
+ threadLocalWorker = ThreadLocal.withInitial(() => {
+ if (workersPool.isEmpty) {
+ logWarning(s"Allocating new worker")
+ allocateWorker()
+ } else {
+ workersPool.poll()
+ }
+ })
+
+ progressThread.start()
+ initialized = true
+ }
+ workerAddress
}
/**
* Close all transport resources
*/
override def close(): Unit = {
- progressThread.interrupt()
- globalWorker.signal()
- try {
- progressThread.join()
- } catch {
- case _:InterruptedException =>
- case e:Throwable => logWarning(e.getLocalizedMessage)
+ if (initialized) {
+ progressThread.interrupt()
+ if (ucxShuffleConf.useWakeup) {
+ globalWorker.signal()
+ }
+ try {
+ progressThread.join()
+ hostMemoryPool.close()
+ deviceMemoryPool.close()
+ clientConnections.keys.foreach(ep => ep.close())
+ registeredBlocks.forEachKey(100, blockId => unregister(blockId))
+ allocatedWorkers.forEach(_.close())
+ listener.foreach(_.close())
+ globalWorker.close()
+ ucxContext.close()
+ } catch {
+ case _:InterruptedException =>
+ case e:Throwable => logWarning(e.getLocalizedMessage)
+ }
}
-
- memoryPool.close()
- clientConnections.values.foreach(ep => ep.close())
- registeredBlocks.forEachKey(1, blockId => unregister(blockId))
- allocatedWorkers.forEach(_.close())
- globalWorker.close()
- ucxContext.close()
}
/**
@@ -147,73 +206,158 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo
* connection establishment outside of UcxShuffleManager.
*/
def addExecutor(executorId: String, workerAddress: ByteBuffer): Unit = {
- executorIdToAddress.put(executorId, workerAddress)
- }
-
- private[ucx] def handlePrefetchRequest(workerId: String, workerAddress: ByteBuffer,
- blockIds: Seq[BlockId]) {
-
- logInfo(s"Prefetching blocks: ${blockIds.mkString(",")}")
- clientConnections.getOrElseUpdate(workerId,
- globalWorker.newEndpoint(new UcpEndpointParams().setUcpAddress(workerAddress))
- )
-
- blockIds.foreach(blockId => {
- val block = registeredBlocks.get(blockId)
- if (!block.isInstanceOf[UcxPinnedBlock]) {
- registeredBlocks.put(blockId, UcxPinnedBlock(block, pinMemory(block), prefetched = true))
- }
+ if (ucxShuffleConf.useSockAddr) {
+ executorIdToSockAddress.put(executorId, SerializationUtils.deserializeInetAddress(workerAddress))
+ } else {
+ executorIdToAddress.put(executorId, workerAddress)
+ }
+ logInfo(s"Pre connection on ${allocatedWorkers.size()} workers")
+ allocatedWorkers.forEach(w => {
+ val ep = w.getConnection(executorId)
+ w.preConnect(ep)
})
}
/**
* On a sender side process request of fetchBlockByBlockId
*/
- private[ucx] def replyFetchBlockRequest(workerId: String, workerAddress: ByteBuffer,
- blockId: BlockId, tag: Long): Unit = {
+ private[ucx] def replyFetchBlocksRequest(blockIds: Seq[BlockId], isRecvToHostMemory: Seq[Boolean],
+ ep: UcpEndpoint, startTag: Int, singleReply: Boolean): Unit = {
+ if (singleReply) {
+ var blockAddress = 0L
+ var blockSize = 0L
+ var responseBlock: MemoryBlock = null
+ var headerMem: MemoryBlock = null
+ var responseBlockBuff: ByteBuffer = null
+
+ val totalSize = ucxShuffleConf.maxMetadataSize * blockIds.length
+ if (totalSize + 4 < globalWorker.getMaxAmHeaderSize) {
+ headerMem = hostMemoryPool.get(totalSize + 4L)
+ responseBlockBuff = UcxUtils.getByteBufferView(headerMem.address, headerMem.size)
+ responseBlockBuff.putInt(startTag)
+ logInfo(s"Sending ${blockIds.mkString(",")} in header")
+ } else {
+ headerMem = hostMemoryPool.get(4L)
+ val headerBuf = UcxUtils.getByteBufferView(headerMem.address, headerMem.size)
+ headerBuf.putInt(startTag)
+ responseBlock = hostMemoryPool.get(totalSize)
+ responseBlockBuff = UcxUtils.getByteBufferView(responseBlock.address, responseBlock.size)
+ blockAddress = responseBlock.address
+ blockSize = responseBlock.size
+ logInfo(s"Sending ${blockIds.length} blocks in data")
+ }
+ val locks = new Array[Lock](blockIds.size)
+ for (i <- blockIds.indices) {
+ val block = registeredBlocks.get(blockIds(i))
+ locks(i) = block.lock.readLock()
+ locks(i).lock()
+ val blockMemory = block.getMemoryBlock
+ require(blockMemory.isHostMemory && blockMemory.size <= ucxShuffleConf.maxMetadataSize)
+ val blockBuffer = UcxUtils.getByteBufferView(blockMemory.address, blockMemory.size)
+ responseBlockBuff.putInt(blockMemory.size.toInt)
+ responseBlockBuff.put(blockBuffer)
+ }
+ ep.sendAmNonBlocking(0, headerMem.address, headerMem.size, blockAddress, blockSize, 0L,
+ new UcxCallback {
+ private val startTime = System.nanoTime()
+ override def onSuccess(request: UcpRequest): Unit = {
+ logInfo(s"Sent ${blockIds.length} blocks of size $totalSize" +
+ s" in ${Utils.getUsedTimeNs(startTime)}")
+ locks.foreach(_.unlock())
+ hostMemoryPool.put(headerMem)
+ if (responseBlock != null) {
+ hostMemoryPool.put(responseBlock)
+ }
+ }
- val ep = clientConnections.getOrElseUpdate(workerId,
- globalWorker.newEndpoint(new UcpEndpointParams().setUcpAddress(workerAddress))
- )
+ override def onError(ucsStatus: Int, errorMsg: String): Unit = {
+ logError(s"Failed to send ${blockIds.mkString(",")}: $errorMsg")
+ locks.foreach(_.unlock())
+ hostMemoryPool.put(headerMem)
+ if (responseBlock != null) {
+ hostMemoryPool.put(responseBlock)
+ }
+ }
+ }, UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
- val block = registeredBlocks.get(blockId)
- if (block == null) {
- throw new UcxException(s"Block $blockId not registered")
- }
- val lock = block.lock.readLock()
- lock.lock()
- val blockMemory = block.getMemoryBlock
+ } else {
+ for (i <- blockIds.indices) {
+ val blockId = blockIds(i)
+ val block = registeredBlocks.get(blockId)
- logInfo(s"Sending $blockId of size ${blockMemory.size} to tag: $tag")
- ep.sendTaggedNonBlocking(blockMemory.address, blockMemory.size, tag, new UcxCallback {
- override def onSuccess(request: UcpRequest): Unit = {
- if (block.isInstanceOf[UcxPinnedBlock]) {
- val pinnedBlock = block.asInstanceOf[UcxPinnedBlock]
- if (pinnedBlock.prefetched) {
- registeredBlocks.put(blockId, pinnedBlock.block)
- pinnedBlock.ucpMemory.deregister()
- }
+ if (block == null) {
+ throw new UcxException(s"Block $blockId is not registered")
+ }
+
+ val lock = block.lock.readLock()
+ lock.lock()
+
+ val blockMemory = block.getMemoryBlock
+
+ var blockAddress = 0L
+ var blockSize = 0L
+ var headerMem: MemoryBlock = null
+
+ if (blockMemory.isHostMemory && // The block itself in host memory
+ (blockMemory.size + 4 < globalWorker.getMaxAmHeaderSize) && // and can fit to header
+ isRecvToHostMemory(i) ) {
+ headerMem = hostMemoryPool.get(4L + blockMemory.size)
+ val buf = UcxUtils.getByteBufferView(headerMem.address, headerMem.size)
+ buf.putInt(startTag + i)
+ val blockBuf = UcxUtils.getByteBufferView(blockMemory.address, blockMemory.size)
+ buf.put(blockBuf)
+ } else {
+ headerMem = hostMemoryPool.get(4L)
+ val buf = UcxUtils.getByteBufferView(headerMem.address, headerMem.size)
+ buf.putInt(startTag + i)
+ blockAddress = blockMemory.address
+ blockSize = blockMemory.size
}
- lock.unlock()
- }
- override def onError(ucsStatus: Int, errorMsg: String): Unit = {
- logError(s"Failed to send $blockId: $errorMsg")
- lock.unlock()
+ ep.sendAmNonBlocking(0, headerMem.address, headerMem.size, blockAddress, blockSize, 0L,
+ new UcxCallback {
+ private val startTime = System.nanoTime()
+ override def onSuccess(request: UcpRequest): Unit = {
+ logInfo(s"Sent $blockId of size ${blockMemory.size} " +
+ s"memType: ${if (blockMemory.isHostMemory) "host" else "gpu"}" +
+ s" in ${Utils.getUsedTimeNs(startTime)}")
+ lock.unlock()
+ hostMemoryPool.put(headerMem)
+ }
+
+ override def onError(ucsStatus: Int, errorMsg: String): Unit = {
+ logError(s"Failed to send $blockId-$blockMemory: $errorMsg")
+ lock.unlock()
+ hostMemoryPool.put(headerMem)
+ }
+ },
+ if (blockMemory.isHostMemory) UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST
+ else UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_CUDA)
}
- })
+ }
+
}
private def pinMemory(block: Block): UcpMemory = {
+ val startTime = System.nanoTime()
val blockMemory = block.getMemoryBlock
- ucxContext.memoryMap(
- memMapParams.setAddress(blockMemory.address).setLength(blockMemory.size))
+ val memType = if (blockMemory.isHostMemory) {
+ UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST
+ } else {
+ UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_CUDA
+ }
+ val result = ucxContext.memoryMap(
+ memMapParams.setAddress(blockMemory.address).setLength(blockMemory.size)
+ .setMemoryType(memType))
+ logInfo(s"Pinning memory of size: ${blockMemory.size} took: ${Utils.getUsedTimeNs(startTime)}")
+ result
}
/**
* Registers blocks using blockId on SERVER side.
*/
override def register(blockId: BlockId, block: Block): Unit = {
+ logTrace(s"Registering $blockId of size: ${block.getMemoryBlock.size}")
val registeredBock: Block = if (ucxShuffleConf.pinMemory) {
UcxPinnedBlock(block, pinMemory(block))
} else {
@@ -226,23 +370,16 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo
* Change location of underlying blockId in memory
*/
override def mutate(blockId: BlockId, block: Block, callback: OperationCallback): Unit = {
- Future {
- unregister(blockId)
- register(blockId, block)
- } andThen {
- case Failure(t) => if (callback != null) {
- callback.onComplete(new UcxFailureOperationResult(t.getMessage))
- }
- case Success(_) => if (callback != null) {
- callback.onComplete(new UcxSuccessOperationResult(new UcxStats))
- }
- }
+ unregister(blockId)
+ register(blockId, block)
+ callback.onComplete(new UcxSuccessOperationResult(new UcxStats))
}
/**
* Indicate that this blockId is not needed any more by an application
*/
override def unregister(blockId: BlockId): Unit = {
+ logInfo(s"Unregistering $blockId")
val block = registeredBlocks.remove(blockId)
if (block != null) {
block.lock.writeLock().lock()
@@ -256,15 +393,6 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo
}
}
- /**
- * Fetch remote blocks by blockIds.
- */
- override def fetchBlockByBlockId(executorId: String, blockId: BlockId,
- resultBuffer: MemoryBlock,
- cb: OperationCallback): UcxRequest = {
- threadLocalWorker.get().fetchBlockByBlockId(executorId, blockId, resultBuffer, cb)
- }
-
/**
* Progress outstanding operations. This routine is blocking. It's important to call this routine
* within same thread that submitted requests.
@@ -278,13 +406,6 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo
executorIdToAddress.remove(executorId)
}
- /**
- * Hint for a transport that these blocks would needed soon.
- */
- override def prefetchBlocks(executorId: String, blockIds: Seq[BlockId]): Unit = {
- threadLocalWorker.get().prefetchBlocks(executorId, blockIds)
- }
-
/**
* Batch version of [[ fetchBlocksByBlockIds ]].
*/
@@ -293,4 +414,9 @@ class UcxShuffleTransport(var ucxShuffleConf: UcxShuffleConf = null, val executo
callbacks: Seq[OperationCallback]): Seq[Request] = {
threadLocalWorker.get().fetchBlocksByBlockIds(executorId, blockIds, resultBuffer, callbacks)
}
+
+ override def fetchBlocksByBlockIds(executorId: String, blockIds: Seq[BlockId], resultBuffer: MemoryBlock,
+ callback: OperationCallback): Request = {
+ threadLocalWorker.get().fetchBlocksByBlockIds(executorId, blockIds, resultBuffer, callback)
+ }
}
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala
index 2b55eb06..e22439bd 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/UcxWorkerWrapper.scala
@@ -5,19 +5,20 @@
package org.apache.spark.shuffle.ucx
import java.io.{Closeable, ObjectOutputStream}
-import java.util.concurrent.ThreadLocalRandom
+import java.net.InetSocketAddress
+import java.nio.ByteBuffer
+import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable
-import com.fasterxml.jackson.databind.util.ByteBufferBackedOutputStream
-import org.openucx.jucx.ucp.{UcpEndpoint, UcpEndpointParams, UcpRequest, UcpWorker}
+import ai.rapids.cudf.DeviceMemoryBuffer
+import org.openucx.jucx.ucp._
+import org.openucx.jucx.ucs.UcsConstants
import org.openucx.jucx.{UcxCallback, UcxException, UcxUtils}
import org.apache.spark.internal.Logging
import org.apache.spark.shuffle.ucx.memory.MemoryPool
-import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages
-import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{FetchBlockByBlockIdRequest, FetchBlocksByBlockIdsRequest, PrefetchBlockIds}
-import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer
-import org.apache.spark.util.Utils
+import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.FetchBlocksByBlockIdsRequest
+import org.apache.spark.util.{ByteBufferOutputStream, Utils}
/**
* Success operation result subclass that has operation stats.
@@ -46,14 +47,53 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon
memoryPool: MemoryPool)
extends Closeable with Logging {
- // To keep connection map on a remote side by id rather then by worker address, which could be big.
- // Would not need when migrate to active messages.
- private val id: String = transport.executorId + s"_${Thread.currentThread().getId}"
+ private var tag: Int = 1
private final val connections = mutable.Map.empty[String, UcpEndpoint]
- private val workerAddress = worker.getAddress
+ private val requestData = new ConcurrentHashMap[Int, (MemoryBlock, UcxCallback, UcxRequest)]
+
+ worker.setAmRecvHandler(0, (headerAddress: Long, headerSize: Long, amData: UcpAmData,
+ replyEp: UcpEndpoint) => {
+ val headerBuffer = UcxUtils.getByteBufferView(headerAddress, headerSize)
+ val i = headerBuffer.getInt
+ val data = requestData.remove(i)
+ if (data == null) {
+ throw new UcxException(s"No data for tag $i")
+ }
+ val (resultMemory, callback, ucxRequest) = data
+ logDebug(s"Received message for tag $i with headerSize $headerSize. " +
+ s" AmData: ${amData} resultMemory(isHost): ${resultMemory.isHostMemory}")
+ ucxRequest.getStats.foreach(s => s.asInstanceOf[UcxStats].amHandleTime = System.nanoTime())
+ if (headerSize > 4L) {
+ val resultBuffer = UcxUtils.getByteBufferView(resultMemory.address, resultMemory.size)
+ resultBuffer.put(headerBuffer)
+ ucxRequest.completed = true
+ if (callback != null) {
+ ucxRequest.getStats.get.asInstanceOf[UcxStats].receiveSize = headerSize - 4
+ callback.onSuccess(null)
+ }
+ } else if (amData.isDataValid && resultMemory.isHostMemory) {
+ val resultBuffer = UcxUtils.getByteBufferView(resultMemory.address, resultMemory.size)
+ resultBuffer.put(UcxUtils.getByteBufferView(amData.getDataAddress, amData.getLength))
+ ucxRequest.completed = true
+ if (callback != null) {
+ ucxRequest.getStats.get.asInstanceOf[UcxStats].receiveSize = amData.getLength
+ callback.onSuccess(null)
+ }
+ } else {
+ require(amData.getLength <= resultMemory.size, s"${amData.getLength} < ${resultMemory.size}")
+ val request = worker.recvAmDataNonBlocking(amData.getDataHandle, resultMemory.address,
+ amData.getLength, callback,
+ if (resultMemory.isHostMemory) UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST else
+ UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_CUDA)
+ ucxRequest.getStats.get.asInstanceOf[UcxStats].receiveSize = amData.getLength
+ ucxRequest.setRequest(request)
+ }
+ UcsConstants.STATUS.UCS_OK
+ })
+
override def close(): Unit = {
- connections.foreach{
+ connections.foreach {
case (_, endpoint) => endpoint.close()
}
connections.clear()
@@ -63,23 +103,28 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon
/**
* The only place for worker progress
*/
- private[ucx] def progress(): Unit = {
+ private[ucx] def progress(): Unit = this.synchronized {
if ((worker.progress() == 0) && ucxConf.useWakeup) {
worker.waitForEvents()
worker.progress()
}
}
- private def getConnection(executorId: String): UcpEndpoint = {
- val workerAdresses = transport.executorIdToAddress
+ private[ucx] def getConnection(executorId: String): UcpEndpoint = synchronized {
+ val workerAddresses = if (ucxConf.useSockAddr) {
+ transport.executorIdToSockAddress
+ } else {
+ transport.executorIdToAddress
+ }
- if (!workerAdresses.contains(executorId)) {
+ if (!workerAddresses.contains(executorId)) {
// Block until there's no worker address for this BlockManagerID
val startTime = System.currentTimeMillis()
val timeout = ucxConf.conf.getTimeAsMs("spark.network.timeout", "100")
- workerAdresses.synchronized {
- while (workerAdresses.get(executorId) == null) {
- workerAdresses.wait(timeout)
+ workerAddresses.synchronized {
+ while (workerAddresses.get(executorId) == null) {
+ logWarning(s"No workerAddress for executor $executorId")
+ workerAddresses.wait(timeout)
if (System.currentTimeMillis() - startTime > timeout) {
throw new UcxException(s"Didn't get worker address for $executorId during $timeout")
}
@@ -89,146 +134,200 @@ class UcxWorkerWrapper(worker: UcpWorker, transport: UcxShuffleTransport, ucxCon
connections.getOrElseUpdate(executorId, {
logInfo(s"Worker from thread ${Thread.currentThread().getName} connecting to $executorId")
- val endpointParams = new UcpEndpointParams()
- .setUcpAddress(workerAdresses.get(executorId))
- worker.newEndpoint(endpointParams)
+ val endpointParams = new UcpEndpointParams().setPeerErrorHandlingMode()
+ .setErrorHandler(new UcpEndpointErrorHandler() {
+ override def onError(ep: UcpEndpoint, status: Int, errorMsg: String): Unit = {
+ logError(errorMsg)
+ }
+ })
+ if (ucxConf.useSockAddr) {
+ val sockAddr = workerAddresses.get(executorId).asInstanceOf[InetSocketAddress]
+ endpointParams.setSocketAddress(sockAddr)
+ } else {
+ endpointParams.setUcpAddress(workerAddresses.get(executorId).asInstanceOf[ByteBuffer])
+ }
+
+ worker.newEndpoint(endpointParams)
})
}
- private[ucx] def prefetchBlocks(executorId: String, blockIds: Seq[BlockId]): Unit = {
- logInfo(s"Sending prefetch ${blockIds.length} blocks to $executorId")
-
- val mem = memoryPool.get(transport.ucxShuffleConf.rpcMessageSize)
- val buffer = UcxUtils.getByteBufferView(mem.address, transport.ucxShuffleConf.rpcMessageSize.toInt)
-
- workerAddress.rewind()
- val message = PrefetchBlockIds(id, new SerializableDirectBuffer(workerAddress), blockIds)
- Utils.tryWithResource(new ByteBufferBackedOutputStream(buffer)) { bos =>
- val out = new ObjectOutputStream(bos)
- out.writeObject(message)
- out.flush()
- out.close()
+ def preConnect(ep: UcpEndpoint): Unit = this.synchronized {
+ val hostBuffer = memoryPool.get(4)
+ val gpuBuffer = DeviceMemoryBuffer.allocate(1)
+ val req = ep.sendAmNonBlocking(-1, hostBuffer.address, hostBuffer.size,
+ gpuBuffer.getAddress, gpuBuffer.getLength, UcpConstants.UCP_AM_SEND_FLAG_REPLY, null)
+ while (!req.isCompleted) {
+ progress()
}
-
- val ep = getConnection(executorId)
-
- ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize,
- UcxRpcMessages.PREFETCH_TAG, new UcxCallback() {
- override def onSuccess(request: UcpRequest): Unit = {
- logTrace(s"Sent prefetch ${blockIds.length} blocks to $executorId")
- memoryPool.put(mem)
- }
- })
+ gpuBuffer.close()
+ memoryPool.put(hostBuffer)
}
private[ucx] def fetchBlocksByBlockIds(executorId: String, blockIds: Seq[BlockId],
- resultBuffer: Seq[MemoryBlock],
- callbacks: Seq[OperationCallback]): Seq[Request] = {
+ resultBuffers: Seq[MemoryBlock],
+ callbacks: Seq[OperationCallback]): Seq[Request] = this.synchronized {
+
+ val startTime = System.nanoTime()
val ep = getConnection(executorId)
- val mem = memoryPool.get(transport.ucxShuffleConf.rpcMessageSize)
- val buffer = UcxUtils.getByteBufferView(mem.address,
- transport.ucxShuffleConf.rpcMessageSize.toInt)
- workerAddress.rewind()
- val message = FetchBlocksByBlockIdsRequest(id, new SerializableDirectBuffer(workerAddress),
- blockIds)
+ val message = FetchBlocksByBlockIdsRequest(tag, blockIds, resultBuffers.map(_.isHostMemory).toArray)
- Utils.tryWithResource(new ByteBufferBackedOutputStream(buffer)) { bos =>
- val out = new ObjectOutputStream(bos)
+ val bos = new ByteBufferOutputStream(1000)
+ val out = new ObjectOutputStream(bos)
+ try {
out.writeObject(message)
out.flush()
out.close()
+ } catch {
+ case ex: Exception => throw new UcxException(ex.getMessage)
}
- val tag = ThreadLocalRandom.current().nextLong(Long.MinValue, 0)
- logInfo(s"Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag")
- ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize, tag,
- new UcxCallback() {
- override def onSuccess(request: UcpRequest): Unit = {
- memoryPool.put(mem)
- }
- })
+ val msgSize = bos.getCount()
+ val mem = memoryPool.get(msgSize)
+ val buffer = UcxUtils.getByteBufferView(mem.address, msgSize)
+
+ buffer.put(bos.toByteBuffer)
- val requests = new Array[UcxRequest](blockIds.length)
+ val requests = new Array[UcxRequest](blockIds.size)
for (i <- blockIds.indices) {
val stats = new UcxStats()
val result = new UcxSuccessOperationResult(stats)
- val request = worker.recvTaggedNonBlocking(resultBuffer(i).address, resultBuffer(i).size,
- tag + i, -1L, new UcxCallback () {
-
- override def onError(ucsStatus: Int, errorMsg: String): Unit = {
- logError(s"Failed to receive blockId ${blockIds(i)} on tag: $tag, from executorId: $executorId " +
- s" of size: ${resultBuffer.size}: $errorMsg")
- if (callbacks(i) != null ) {
- callbacks(i).onComplete(new UcxFailureOperationResult(errorMsg))
- }
+ requests(i) = new UcxRequest(null, stats, worker)
+ val callback = if (callbacks.isEmpty) null else new UcxCallback() {
+
+ override def onError(ucsStatus: Int, errorMsg: String): Unit = {
+ logError(s"Failed to receive blockId ${blockIds(i)} on tag: $tag, from executorId: $executorId " +
+ s" of size: ${resultBuffers(i).size}: $errorMsg")
+ if (callbacks(i) != null) {
+ callbacks(i).onComplete(new UcxFailureOperationResult(errorMsg))
}
+ }
- override def onSuccess(request: UcpRequest): Unit = {
- stats.endTime = System.nanoTime()
- stats.receiveSize = request.getRecvSize
- if (callbacks(i) != null) {
- callbacks(i).onComplete(result)
- }
+ override def onSuccess(request: UcpRequest): Unit = {
+ stats.endTime = System.nanoTime()
+ logInfo(s"Received block ${blockIds(i)} from $executorId " +
+ s"of size: ${stats.receiveSize} in ${Utils.getUsedTimeNs(stats.startTime)}. Time from amHanlde to recv: " +
+ s"${Utils.getUsedTimeNs(stats.amHandleTime)}")
+ if (callbacks(i) != null) {
+ callbacks(i).onComplete(result)
}
- })
- requests(i) = new UcxRequest(request, stats)
+ }
+ }
+ requestData.put(tag + i, (resultBuffers(i), callback, requests(i)))
+ }
+
+ logDebug(s"Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag")
+ var headerAddress = 0L
+ var headerSize = 0L
+ var dataAddress = 0L
+ var dataSize = 0L
+
+ if (msgSize <= worker.getMaxAmHeaderSize) {
+ headerAddress = mem.address
+ headerSize = msgSize
+ } else {
+ dataAddress = mem.address
+ dataSize = msgSize
+ }
+ worker.progressRequest(ep.sendAmNonBlocking(0, headerAddress, headerSize, dataAddress, dataSize,
+ UcpConstants.UCP_AM_SEND_FLAG_REPLY, new UcxCallback() {
+ override def onSuccess(request: UcpRequest): Unit = {
+ logInfo(s"Sending RPC message of size: ${headerSize + dataSize} " +
+ s"to $executorId to fetch ${blockIds.length} blocks on starting tag tag $tag took: " +
+ s"${Utils.getUsedTimeNs(startTime)}")
+ memoryPool.put(mem)
+ }
+ }))
+
+ for (i <- blockIds.indices) {
+ while (requestData.contains(tag + i)) {
+ progress()
+ }
}
+
+ logInfo(s"FetchBlocksByBlockIds data took: ${Utils.getUsedTimeNs(startTime)}")
+ tag += blockIds.length
requests
}
- private[ucx] def fetchBlockByBlockId(executorId: String, blockId: BlockId,
- resultBuffer: MemoryBlock, cb: OperationCallback): UcxRequest = {
+ private[ucx] def fetchBlocksByBlockIds(executorId: String, blockIds: Seq[BlockId],
+ resultBuffer: MemoryBlock,
+ callback: OperationCallback): Request = this.synchronized {
val stats = new UcxStats()
val ep = getConnection(executorId)
- val mem = memoryPool.get(transport.ucxShuffleConf.rpcMessageSize)
- val buffer = UcxUtils.getByteBufferView(mem.address,
- transport.ucxShuffleConf.rpcMessageSize.toInt)
- val tag = ThreadLocalRandom.current().nextLong(2, Long.MaxValue)
- workerAddress.rewind()
- val message = FetchBlockByBlockIdRequest(id, new SerializableDirectBuffer(workerAddress), blockId)
+ val message = FetchBlocksByBlockIdsRequest(tag, blockIds, Seq.empty[Boolean], singleReply = true)
- Utils.tryWithResource(new ByteBufferBackedOutputStream(buffer)) { bos =>
- val out = new ObjectOutputStream(bos)
+ val bos = new ByteBufferOutputStream(1000)
+ val out = new ObjectOutputStream(bos)
+ try {
out.writeObject(message)
out.flush()
out.close()
+ } catch {
+ case ex: Exception => throw new UcxException(ex.getMessage)
}
+ val msgSize = bos.getCount()
- logTrace(s"Sending message to $executorId to fetch $blockId on tag $tag," +
- s"resultBuffer $resultBuffer")
- ep.sendTaggedNonBlocking(mem.address, transport.ucxShuffleConf.rpcMessageSize, tag,
- new UcxCallback() {
- override def onSuccess(request: UcpRequest): Unit = {
- memoryPool.put(mem)
- }
- })
+ val mem = memoryPool.get(msgSize)
+ val buffer = UcxUtils.getByteBufferView(mem.address, msgSize)
+
+ buffer.put(bos.toByteBuffer)
+ val request = new UcxRequest(null, stats, worker)
val result = new UcxSuccessOperationResult(stats)
- val request = worker.recvTaggedNonBlocking(resultBuffer.address, resultBuffer.size,
- tag, -1L, new UcxCallback () {
+ val requestCallback = new UcxCallback() {
override def onError(ucsStatus: Int, errorMsg: String): Unit = {
- logError(s"Failed to receive blockId $blockId on tag: $tag, from executorId: $executorId " +
+ logError(s"Failed to receive blockId ${blockIds.mkString(",")} on tag: $tag, from executorId: $executorId " +
s" of size: ${resultBuffer.size}: $errorMsg")
- if (cb != null ) {
- cb.onComplete(new UcxFailureOperationResult(errorMsg))
+ if (callback != null) {
+ callback.onComplete(new UcxFailureOperationResult(errorMsg))
}
}
override def onSuccess(request: UcpRequest): Unit = {
stats.endTime = System.nanoTime()
- stats.receiveSize = request.getRecvSize
- if (cb != null) {
- cb.onComplete(result)
+ logInfo(s"Received ${blockIds.length} metadata blocks from $executorId " +
+ s"of size: ${stats.receiveSize} in ${Utils.getUsedTimeNs(stats.startTime)}")
+ if (callback != null) {
+ callback.onComplete(result)
}
}
- })
- new UcxRequest(request, stats)
+
+ }
+ requestData.put(tag, (resultBuffer, requestCallback, request))
+
+ logDebug(s"Sending message to $executorId to fetch ${blockIds.length} blocks on tag $tag")
+ var headerAddress = 0L
+ var headerSize = 0L
+ var dataAddress = 0L
+ var dataSize = 0L
+
+ if (msgSize <= worker.getMaxAmHeaderSize) {
+ headerAddress = mem.address
+ headerSize = msgSize
+ } else {
+ dataAddress = mem.address
+ dataSize = msgSize
+ }
+ worker.progressRequest(ep.sendAmNonBlocking(0, headerAddress, headerSize, dataAddress, dataSize,
+ UcpConstants.UCP_AM_SEND_FLAG_REPLY, new UcxCallback() {
+ override def onSuccess(request: UcpRequest): Unit = {
+ memoryPool.put(mem)
+ }
+ }))
+
+ while (requestData.contains(tag)) {
+ progress()
+ }
+
+ logInfo(s"FetchBlocksByBlockIds metadata took: ${Utils.getUsedTimeNs(stats.startTime)}")
+ tag += 1
+ request
}
}
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleExecutorComponents.scala b/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleExecutorComponents.scala
new file mode 100755
index 00000000..b8e1f6ee
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleExecutorComponents.scala
@@ -0,0 +1,47 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.ucx.io
+
+import java.util
+import java.util.Optional
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, SingleSpillShuffleMapOutputWriter}
+import org.apache.spark.shuffle.sort.io.{LocalDiskShuffleExecutorComponents, LocalDiskShuffleMapOutputWriter, LocalDiskSingleSpillMapOutputWriter}
+import org.apache.spark.shuffle.ucx.{UcxShuffleBlockResolver, UcxShuffleManager, UcxShuffleTransport}
+import org.apache.spark.{SparkConf, SparkEnv}
+
+
+class UcxShuffleExecutorComponents(sparkConf: SparkConf)
+ extends LocalDiskShuffleExecutorComponents(sparkConf) with Logging {
+
+ var ucxShuffleTransport: UcxShuffleTransport = _
+ private var blockResolver: UcxShuffleBlockResolver = _
+
+ override def initializeExecutor(appId: String, execId: String, extraConfigs: util.Map[String, String]): Unit = {
+ val ucxShuffleManager = SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager]
+ ucxShuffleManager.initTransport()
+ blockResolver = ucxShuffleManager.shuffleBlockResolver
+ }
+
+ override def createMapOutputWriter(shuffleId: Int, mapTaskId: Long, numPartitions: Int): ShuffleMapOutputWriter = {
+ if (blockResolver == null) {
+ throw new IllegalStateException(
+ "Executor components must be initialized before getting writers.")
+ }
+ new LocalDiskShuffleMapOutputWriter(
+ shuffleId, mapTaskId, numPartitions, blockResolver, sparkConf)
+ }
+
+ override def createSingleFileMapOutputWriter(shuffleId: Int, mapId: Long):
+ Optional[SingleSpillShuffleMapOutputWriter] = {
+ if (blockResolver == null) {
+ throw new IllegalStateException(
+ "Executor components must be initialized before getting writers.")
+ }
+ Optional.of(new LocalDiskSingleSpillMapOutputWriter(shuffleId, mapId, blockResolver))
+ }
+
+}
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleIO.scala b/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleIO.scala
new file mode 100755
index 00000000..04f03b44
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/io/UcxShuffleIO.scala
@@ -0,0 +1,26 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.ucx.io
+
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD
+import org.apache.spark.shuffle.api.{ShuffleDriverComponents, ShuffleExecutorComponents}
+import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO
+import org.apache.spark.shuffle.ucx.UcxShuffleManager
+
+class UcxShuffleIO(sparkConf: SparkConf) extends LocalDiskShuffleDataIO(sparkConf) with Logging {
+
+ sparkConf.set(SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key, "100000")
+
+ override def driver(): ShuffleDriverComponents = {
+ SparkEnv.get.shuffleManager.asInstanceOf[UcxShuffleManager].initTransport()
+ super.driver()
+ }
+
+ override def executor(): ShuffleExecutorComponents = {
+ new UcxShuffleExecutorComponents(sparkConf)
+ }
+}
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/memory/MemoryPool.scala b/src/main/scala/org/apache/spark/shuffle/ucx/memory/MemoryPool.scala
index 72c18fc7..65da481f 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/memory/MemoryPool.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/memory/MemoryPool.scala
@@ -5,14 +5,132 @@
package org.apache.spark.shuffle.ucx.memory
import java.io.Closeable
+import java.util.concurrent.atomic.AtomicInteger
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedDeque}
+
+import org.openucx.jucx.ucp.{UcpContext, UcpMemMapParams, UcpMemory}
+import org.openucx.jucx.ucs.UcsConstants
+import org.apache.spark.internal.Logging
+import org.apache.spark.shuffle.ucx.{MemoryBlock, UcxShuffleConf}
+import org.apache.spark.util.Utils
+
+class UcxBounceBufferMemoryBlock(private[ucx] val memory: UcpMemory, private[ucx] val refCount: AtomicInteger,
+ override val address: Long, override val size: Long)
+ extends MemoryBlock(address, size, memory.getMemType == UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST)
-import org.apache.spark.shuffle.ucx.MemoryBlock
/**
* Base class to implement memory pool
*/
-abstract class MemoryPool extends Closeable {
- def get(size: Long): MemoryBlock
+case class MemoryPool(ucxShuffleConf: UcxShuffleConf, ucxContext: UcpContext, memoryType: Int)
+ extends Closeable with Logging {
+
+ protected def roundUpToTheNextPowerOf2(size: Long): Long = {
+ // Round up length to the nearest power of two
+ var length = size
+ length -= 1
+ length |= length >> 1
+ length |= length >> 2
+ length |= length >> 4
+ length |= length >> 8
+ length |= length >> 16
+ length += 1
+ length
+ }
+
+ protected val allocatorMap = new ConcurrentHashMap[Long, AllocatorStack]()
+
+ protected case class AllocatorStack(length: Long, memType: Int) extends Closeable {
+ logInfo(s"Allocator stack of memType: $memType and size $length")
+ private val stack = new ConcurrentLinkedDeque[UcxBounceBufferMemoryBlock]
+ private val numAllocs = new AtomicInteger(0)
+ private val memMapParams = new UcpMemMapParams().allocate().setMemoryType(memType).setLength(length)
+
+ private[memory] def get: UcxBounceBufferMemoryBlock = {
+ var result = stack.pollFirst()
+ if (result == null) {
+ numAllocs.incrementAndGet()
+ if (length < ucxShuffleConf.minRegistrationSize) {
+ preallocate((ucxShuffleConf.minRegistrationSize / length).toInt)
+ result = stack.pollFirst()
+ } else {
+ logInfo(s"Allocating buffer of size $length")
+ val memory = ucxContext.memoryMap(memMapParams)
+ result = new UcxBounceBufferMemoryBlock(memory, new AtomicInteger(1),
+ memory.getAddress, length)
+ }
+ }
+ result
+ }
+
+ private[memory] def put(block: UcxBounceBufferMemoryBlock): Unit = {
+ stack.add(block)
+ }
+
+ private[memory] def preallocate(numBuffers: Int): Unit = {
+ logInfo(s"PreAllocating $numBuffers of size $length, " +
+ s"totalSize: ${Utils.bytesToString(length * numBuffers) }")
+ val memory = ucxContext.memoryMap(
+ new UcpMemMapParams().allocate().setMemoryType(memType).setLength(length * numBuffers))
+ val refCount = new AtomicInteger(numBuffers)
+ var offset = 0L
+ (0 until numBuffers).foreach(_ => {
+ stack.add(new UcxBounceBufferMemoryBlock(memory, refCount, memory.getAddress + offset, length))
+ offset += length
+ })
+ }
- def put(mem: MemoryBlock): Unit
+ override def close(): Unit = {
+ var numBuffers = 0
+ stack.forEach(block => {
+ block.refCount.decrementAndGet()
+ if (block.memory.getNativeId != null) {
+ block.memory.deregister()
+ }
+ numBuffers += 1
+ })
+ logInfo(s"Closing $numBuffers buffers of size $length." +
+ s"Number of allocations: ${numAllocs.get()}")
+ stack.clear()
+ }
+ }
+
+ override def close(): Unit = {
+ allocatorMap.values.forEach(allocator => allocator.close())
+ allocatorMap.clear()
+ }
+
+ def get(size: Long): MemoryBlock = {
+ val roundedSize = roundUpToTheNextPowerOf2(size)
+ val allocatorStack = allocatorMap.computeIfAbsent(roundedSize,
+ s => AllocatorStack(s, memoryType))
+ val result = allocatorStack.get
+ new UcxBounceBufferMemoryBlock(result.memory, result.refCount, result.address, size)
+ }
+
+ def put(mem: MemoryBlock): Unit = {
+ mem match {
+ case m: UcxBounceBufferMemoryBlock =>
+ val allocatorStack = allocatorMap.get(roundUpToTheNextPowerOf2(mem.size))
+ allocatorStack.put(m)
+ case _ =>
+ }
+ }
+
+ def preAllocate(size: Long, numBuffers: Int): Unit = {
+ val roundedSize = roundUpToTheNextPowerOf2(size)
+ val allocatorStack = allocatorMap.computeIfAbsent(roundedSize,
+ s => AllocatorStack(s, memoryType))
+ allocatorStack.preallocate(numBuffers)
+ }
}
+
+class UcxHostBounceBuffersPool(ucxShuffleConf: UcxShuffleConf, ucxContext: UcpContext)
+ extends MemoryPool(ucxShuffleConf, ucxContext, UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) {
+ ucxShuffleConf.preallocateBuffersMap.foreach{
+ case (size, numBuffers) => preAllocate(size, numBuffers)
+ }
+}
+
+class UcxGpuBounceBuffersPool(ucxShuffleConf: UcxShuffleConf, ucxContext: UcpContext)
+ extends MemoryPool(ucxShuffleConf, ucxContext, UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_CUDA)
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/memory/UcxHostBounceBuffersPool.scala b/src/main/scala/org/apache/spark/shuffle/ucx/memory/UcxHostBounceBuffersPool.scala
deleted file mode 100755
index 90e2f6aa..00000000
--- a/src/main/scala/org/apache/spark/shuffle/ucx/memory/UcxHostBounceBuffersPool.scala
+++ /dev/null
@@ -1,116 +0,0 @@
-/*
-* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
-* See file LICENSE for terms.
-*/
-package org.apache.spark.shuffle.ucx.memory
-
-import java.io.Closeable
-import java.util.concurrent.atomic.AtomicInteger
-import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedDeque}
-
-import org.openucx.jucx.ucp.{UcpContext, UcpMemMapParams, UcpMemory}
-import org.apache.spark.internal.Logging
-import org.apache.spark.shuffle.ucx.{MemoryBlock, UcxShuffleConf}
-
-/**
- * Pre-registered host bounce buffers.
- * TODO: support reclamation
- */
-class UcxHostBounceBufferMemoryBlock(private[ucx] val memory: UcpMemory, private[ucx] val refCount: AtomicInteger,
- override val address: Long, override val size: Long)
- extends MemoryBlock(address, size)
-
-class UcxHostBounceBuffersPool(ucxShuffleConf: UcxShuffleConf, ucxContext: UcpContext)
- extends MemoryPool with Logging {
-
- private val allocatorMap = new ConcurrentHashMap[Long, AllocatorStack]()
-
- ucxShuffleConf.preallocateBuffersMap.foreach {
- case (size, numBuffers) =>
- val roundedSize = roundUpToTheNextPowerOf2(size)
- logDebug(s"Pre allocating $numBuffers buffers of size $roundedSize")
- val allocatorStack = allocatorMap.computeIfAbsent(roundedSize, s => AllocatorStack(s))
- allocatorStack.preallocate(numBuffers)
- }
-
- private case class AllocatorStack(length: Long) extends Closeable {
- private val stack = new ConcurrentLinkedDeque[UcxHostBounceBufferMemoryBlock]
- private val numAllocs = new AtomicInteger(0)
-
- private[UcxHostBounceBuffersPool] def get: UcxHostBounceBufferMemoryBlock = {
- var result = stack.pollFirst()
- if (result == null) {
- numAllocs.incrementAndGet()
- if (length < ucxShuffleConf.minRegistrationSize) {
- preallocate((ucxShuffleConf.minRegistrationSize / length).toInt)
- result = stack.pollFirst()
- } else {
- val memory = ucxContext.memoryMap(new UcpMemMapParams().allocate().setLength(length))
- result = new UcxHostBounceBufferMemoryBlock(memory, new AtomicInteger(1),
- memory.getAddress, length)
- }
- }
- result
- }
-
- private[UcxHostBounceBuffersPool] def put(block: UcxHostBounceBufferMemoryBlock): Unit = {
- stack.add(block)
- }
-
- private[ucx] def preallocate(numBuffers: Int): Unit = {
- val memory = ucxContext.memoryMap(
- new UcpMemMapParams().allocate().setLength(length * numBuffers))
- val refCount = new AtomicInteger(numBuffers)
- var offset = 0L
- (0 until numBuffers).foreach(_ => {
- stack.add(new UcxHostBounceBufferMemoryBlock(memory, refCount, memory.getAddress + offset, length))
- offset += length
- })
- }
-
- override def close(): Unit = {
- var numBuffers = 0
- stack.forEach(block => {
- block.refCount.decrementAndGet()
- if (block.memory.getNativeId != null) {
- block.memory.deregister()
- }
- numBuffers += 1
- })
- logInfo(s"Closing $numBuffers buffers of size $length." +
- s"Number of allocations: ${numAllocs.get()}")
- stack.clear()
- }
- }
-
- private def roundUpToTheNextPowerOf2(size: Long): Long = {
- // Round up length to the nearest power of two
- var length = size
- length -= 1
- length |= length >> 1
- length |= length >> 2
- length |= length >> 4
- length |= length >> 8
- length |= length >> 16
- length += 1
- length
- }
-
- override def get(size: Long): MemoryBlock = {
- val roundedSize = roundUpToTheNextPowerOf2(size)
- val allocatorStack = allocatorMap.computeIfAbsent(roundedSize, s => AllocatorStack(s))
- allocatorStack.get
- }
-
- override def put(mem: MemoryBlock): Unit = {
- val allocatorStack = allocatorMap.computeIfAbsent(roundUpToTheNextPowerOf2(mem.size),
- s => AllocatorStack(s))
- allocatorStack.put(mem.asInstanceOf[UcxHostBounceBufferMemoryBlock])
- }
-
- override def close(): Unit = {
- allocatorMap.values.forEach(allocator => allocator.close())
- allocatorMap.clear()
- }
-
-}
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxShuffleTransportPerfTool.scala b/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxShuffleTransportPerfTool.scala
new file mode 100755
index 00000000..c111f2a9
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/perf/UcxShuffleTransportPerfTool.scala
@@ -0,0 +1,269 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.ucx.perf
+
+import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket}
+import java.nio.ByteBuffer
+import java.util.concurrent.{CountDownLatch, Executors, TimeUnit}
+
+import ai.rapids.cudf.DeviceMemoryBuffer
+import org.apache.commons.cli.{GnuParser, HelpFormatter, Options}
+import org.openucx.jucx.UcxUtils
+import org.openucx.jucx.ucp.UcpMemMapParams
+import org.openucx.jucx.ucs.UcsConstants
+import org.apache.spark.SparkConf
+import org.apache.spark.shuffle.ucx._
+import org.apache.spark.shuffle.ucx.memory.MemoryPool
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.util.Utils
+
+object UcxShuffleTransportPerfTool {
+ private val HELP_OPTION = "h"
+ private val ADDRESS_OPTION = "a"
+ private val NUM_BLOCKS_OPTION = "n"
+ private val SIZE_OPTION = "s"
+ private val PORT_OPTION = "p"
+ private val ITER_OPTION = "i"
+ private val MEMORY_TYPE_OPTION = "m"
+ private val NUM_THREADS_OPTION = "t"
+ private val REUSE_ADDRESS_OPTION = "r"
+
+ private val ucxShuffleConf = new UcxShuffleConf(new SparkConf())
+ private val transport = new UcxShuffleTransport(ucxShuffleConf, "e")
+ private val workerAddress = transport.init()
+ private var memoryPool: MemoryPool = transport.hostMemoryPool
+
+ case class TestBlockId(id: Int) extends BlockId
+
+ case class PerfOptions(remoteAddress: InetSocketAddress, numBlocks: Int, blockSize: Long,
+ serverPort: Int, numIterations: Int, numThreads: Int, memoryType: Int)
+
+ private def initOptions(): Options = {
+ val options = new Options()
+ options.addOption(HELP_OPTION, "help", false,
+ "display help message")
+ options.addOption(ADDRESS_OPTION, "address", true,
+ "address of the remote host")
+ options.addOption(NUM_BLOCKS_OPTION, "num-blocks", true,
+ "number of blocks to transfer. Default: 1")
+ options.addOption(SIZE_OPTION, "block-size", true,
+ "size of block to transfer. Default: 4m")
+ options.addOption(PORT_OPTION, "server-port", true,
+ "server port. Default: 12345")
+ options.addOption(ITER_OPTION, "num-iterations", true,
+ "number of iterations. Default: 5")
+ options.addOption(NUM_THREADS_OPTION, "num-threads", true,
+ "number of threads. Default: 1")
+ options.addOption(MEMORY_TYPE_OPTION, "memory-type", true,
+ "memory type: host (default), cuda")
+ }
+
+ private def parseOptions(args: Array[String]): PerfOptions = {
+ val parser = new GnuParser()
+ val options = initOptions()
+ val cmd = parser.parse(options, args)
+
+ if (cmd.hasOption(HELP_OPTION)) {
+ new HelpFormatter().printHelp("UcxShufflePerfTool", options)
+ System.exit(0)
+ }
+
+ val inetAddress = if (cmd.hasOption(ADDRESS_OPTION)) {
+ val Array(host, port) = cmd.getOptionValue(ADDRESS_OPTION).split(":")
+ new InetSocketAddress(host, Integer.parseInt(port))
+ } else {
+ null
+ }
+
+ val serverPort = Integer.parseInt(cmd.getOptionValue(PORT_OPTION, "12345"))
+
+ val numIterations = Integer.parseInt(cmd.getOptionValue(ITER_OPTION, "5"))
+
+ val threadsNumber = Integer.parseInt(cmd.getOptionValue(NUM_THREADS_OPTION, "1"))
+
+ var memoryType = UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST
+ if (cmd.hasOption(MEMORY_TYPE_OPTION) && cmd.getOptionValue(MEMORY_TYPE_OPTION) == "cuda") {
+ memoryPool = transport.deviceMemoryPool
+ memoryType = UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_CUDA
+ }
+
+ PerfOptions(inetAddress,
+ Integer.parseInt(cmd.getOptionValue(NUM_BLOCKS_OPTION, "1")),
+ Utils.byteStringAsBytes(cmd.getOptionValue(SIZE_OPTION, "4m")),
+ serverPort, numIterations, threadsNumber, memoryType)
+ }
+
+ private def startServer(perfOptions: PerfOptions): Unit = {
+ val blocks: Seq[Block] = (0 until perfOptions.numBlocks * perfOptions.numThreads).map { _ =>
+ if (ucxShuffleConf.pinMemory) {
+ val block = memoryPool.get(perfOptions.blockSize)
+ new Block {
+ override def getMemoryBlock: MemoryBlock =
+ block
+ }
+ } else if (perfOptions.memoryType == UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_CUDA) {
+ new Block {
+ var oldBlock: DeviceMemoryBuffer = _
+ override def getMemoryBlock: MemoryBlock = {
+ if (oldBlock != null) {
+ oldBlock.close()
+ }
+ val block = DeviceMemoryBuffer.allocate(perfOptions.blockSize)
+ oldBlock = block
+ MemoryBlock(block.getAddress, block.getLength, isHostMemory = false)
+ }
+ }
+ } else {
+ new Block {
+ override def getMemoryBlock: MemoryBlock = {
+ val buf = ByteBuffer.allocateDirect(perfOptions.blockSize.toInt)
+ MemoryBlock(UcxUtils.getAddress(buf), perfOptions.blockSize)
+ }
+ }
+ }
+
+ }
+
+ val blockIds = (0 until perfOptions.numBlocks * perfOptions.numThreads).map(i => TestBlockId(i))
+ blockIds.zip(blocks).foreach {
+ case (blockId, block) => transport.register(blockId, block)
+ }
+
+ val serverSocket = new ServerSocket(perfOptions.serverPort)
+
+ println(s"Waiting for connections on " +
+ s"${InetAddress.getLocalHost.getHostName}:${perfOptions.serverPort} ")
+
+ val clientSocket = serverSocket.accept()
+ val out = clientSocket.getOutputStream
+ val in = clientSocket.getInputStream
+
+ val buf = ByteBuffer.allocate(workerAddress.capacity())
+ buf.put(workerAddress)
+ buf.flip()
+
+ out.write(buf.array())
+ out.flush()
+
+ println(s"Sending worker address to ${clientSocket.getInetAddress}")
+
+ buf.flip()
+
+ in.read(buf.array())
+ clientSocket.close()
+ serverSocket.close()
+
+ blocks.foreach(block => memoryPool.put(block.getMemoryBlock))
+ blockIds.foreach(transport.unregister)
+ transport.close()
+ }
+
+ private def startClient(perfOptions: PerfOptions): Unit = {
+ val socket = new Socket(perfOptions.remoteAddress.getHostName,
+ perfOptions.remoteAddress.getPort)
+
+ val buf = new Array[Byte](4096)
+ val readSize = socket.getInputStream.read(buf)
+ val executorId = "1"
+ val workerAddress = ByteBuffer.allocateDirect(readSize)
+
+ workerAddress.put(buf, 0, readSize)
+ println("Received worker address")
+
+ transport.addExecutor(executorId, workerAddress)
+
+ val resultSize = perfOptions.numThreads * perfOptions.numBlocks * perfOptions.blockSize
+ val resultMemory = memoryPool.get(resultSize)
+ transport.register(TestBlockId(-1), new Block {
+ override def getMemoryBlock: MemoryBlock = resultMemory
+ })
+
+ val threadPool = Executors.newFixedThreadPool(perfOptions.numThreads)
+
+ for (i <- 1 to perfOptions.numIterations) {
+ val startTime = System.nanoTime()
+ val countDownLatch = new CountDownLatch(perfOptions.numThreads)
+ val deviceMemoryBuffers: Array[DeviceMemoryBuffer] = new Array(perfOptions.numThreads * perfOptions.numBlocks)
+ for (tid <- 0 until perfOptions.numThreads) {
+ threadPool.execute(() => {
+ val blocksOffset = tid * perfOptions.numBlocks
+ val blockIds = (blocksOffset until blocksOffset + perfOptions.numBlocks).map(i => TestBlockId(i))
+
+ val mem = new Array[MemoryBlock](perfOptions.numBlocks)
+
+ for (j <- 0 until perfOptions.numBlocks) {
+ mem(j) = MemoryBlock(resultMemory.address +
+ (tid * perfOptions.numBlocks * perfOptions.blockSize) + j * perfOptions.blockSize, perfOptions.blockSize)
+
+ if (!ucxShuffleConf.pinMemory && perfOptions.memoryType == UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_HOST) {
+ val buffer = Platform.allocateDirectBuffer(perfOptions.blockSize.toInt)
+ mem(j) = MemoryBlock(UcxUtils.getAddress(buffer), perfOptions.blockSize)
+ } else if (!ucxShuffleConf.pinMemory &&
+ perfOptions.memoryType == UcsConstants.MEMORY_TYPE.UCS_MEMORY_TYPE_CUDA) {
+ val buffer = DeviceMemoryBuffer.allocate(perfOptions.blockSize)
+ deviceMemoryBuffers(tid * perfOptions.numBlocks + j) = buffer
+ mem(j) = MemoryBlock(buffer.getAddress, buffer.getLength, isHostMemory = false)
+ } else {
+ mem(j) = MemoryBlock(resultMemory.address +
+ (tid * perfOptions.numBlocks * perfOptions.blockSize) + j * perfOptions.blockSize, perfOptions.blockSize)
+ }
+ }
+
+ val requests = transport.fetchBlocksByBlockIds(executorId, blockIds, mem, Seq.empty[OperationCallback])
+
+ while (!requests.forall(_.isCompleted)) {
+ transport.progress()
+ }
+
+ countDownLatch.countDown()
+ })
+ }
+
+ countDownLatch.await()
+ val elapsedTime = System.nanoTime() - startTime
+
+ deviceMemoryBuffers.foreach(d => if (d != null) d.close())
+ val totalTime = if (elapsedTime < TimeUnit.MILLISECONDS.toNanos(1)) {
+ s"$elapsedTime ns"
+ } else {
+ s"${TimeUnit.NANOSECONDS.toMillis(elapsedTime)} ms"
+ }
+ val throughput: Double = (resultSize / 1024.0D / 1024.0D / 1024.0D) /
+ (elapsedTime / 1e9D)
+
+ if ((i % 100 == 0) || i == perfOptions.numIterations) {
+ println(f"${s"[$i/${perfOptions.numIterations}]"}%12s" +
+ s" numBlocks: ${perfOptions.numBlocks}" +
+ s" numThreads: ${perfOptions.numThreads}" +
+ s" size: ${Utils.bytesToString(perfOptions.blockSize)}," +
+ s" total size: ${Utils.bytesToString(resultSize * perfOptions.numThreads)}," +
+ f" time: $totalTime%3s" +
+ f" throughput: $throughput%.5f GB/s")
+ }
+ }
+
+ val out = socket.getOutputStream
+ out.write(buf)
+ out.flush()
+ out.close()
+ socket.close()
+
+ transport.unregister(TestBlockId(-1))
+ memoryPool.put(resultMemory)
+ transport.close()
+ threadPool.shutdown()
+ }
+
+ def main(args: Array[String]): Unit = {
+ val perfOptions = parseOptions(args)
+
+ if (perfOptions.remoteAddress == null) {
+ startServer(perfOptions)
+ } else {
+ startClient(perfOptions)
+ }
+ }
+
+}
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala
index c65dba23..aec6c17e 100755
--- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/GlobalWorkerRpcThread.scala
@@ -5,91 +5,65 @@
package org.apache.spark.shuffle.ucx.rpc
import java.io.ObjectInputStream
+import java.nio.ByteBuffer
import com.fasterxml.jackson.databind.util.ByteBufferBackedInputStream
-import org.openucx.jucx.ucp.{UcpRequest, UcpWorker}
-import org.openucx.jucx.{UcxException, UcxUtils}
+import org.openucx.jucx.UcxUtils
+import org.openucx.jucx.ucp._
+import org.openucx.jucx.ucs.UcsConstants
import org.apache.spark.internal.Logging
-import org.apache.spark.shuffle.ucx.memory.MemoryPool
-import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{FetchBlockByBlockIdRequest, FetchBlocksByBlockIdsRequest, PrefetchBlockIds}
-import org.apache.spark.shuffle.ucx.{MemoryBlock, UcxShuffleTransport}
+import org.apache.spark.shuffle.ucx.UcxShuffleTransport
+import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.FetchBlocksByBlockIdsRequest
import org.apache.spark.util.Utils
-class GlobalWorkerRpcThread(globalWorker: UcpWorker, memPool: MemoryPool,
- transport: UcxShuffleTransport) extends Thread with Logging {
+class GlobalWorkerRpcThread(globalWorker: UcpWorker, transport: UcxShuffleTransport)
+ extends Thread with Logging {
setDaemon(true)
setName("Ucx Shuffle Transport Progress Thread")
+ private def handleFetchBlockRequest(buffer: ByteBuffer, ep: UcpEndpoint): Unit = {
+ val msg = Utils.tryWithResource(new ByteBufferBackedInputStream(buffer)) { bin =>
+ val objIn = new ObjectInputStream(bin)
+ val obj = objIn.readObject().asInstanceOf[FetchBlocksByBlockIdsRequest]
+ objIn.close()
+ obj
+ }
+ val startTime = System.nanoTime()
+ transport.replyFetchBlocksRequest(msg.blockIds, msg.isHostMemory, ep, msg.startTag, msg.singleReply)
+ logInfo(s"Sent reply for ${msg.blockIds.length} blocks in ${Utils.getUsedTimeNs(startTime)}")
+ }
override def run(): Unit = {
- val numRecvs = transport.ucxShuffleConf.recvQueueSize
- val msgSize = transport.ucxShuffleConf.rpcMessageSize
- val requests: Array[UcpRequest] = Array.ofDim[UcpRequest](numRecvs)
- val recvMemory = memPool.get(msgSize * numRecvs)
- val memoryBlocks = (0 until numRecvs).map(i =>
- MemoryBlock(recvMemory.address + i * msgSize, msgSize))
-
- for (i <- 0 until numRecvs) {
- requests(i) = globalWorker.recvTaggedNonBlocking(memoryBlocks(i).address, msgSize,
- UcxRpcMessages.WILDCARD_TAG, UcxRpcMessages.WILDCARD_TAG_MASK, null)
+ val processCallback: UcpAmRecvCallback = (headerAddress: Long, headerSize: Long, amData: UcpAmData,
+ replyEp: UcpEndpoint) => {
+ if (headerSize > 0) {
+ logDebug(s"Received AM in header on ${transport.clientConnections.get(replyEp)}")
+ val header = UcxUtils.getByteBufferView(headerAddress, headerSize)
+ handleFetchBlockRequest(header, replyEp)
+ UcsConstants.STATUS.UCS_OK
+ } else {
+ assert(amData.isDataValid)
+ logDebug(s"Received AM in eager on ${transport.clientConnections.get(replyEp)}")
+ val data = UcxUtils.getByteBufferView(amData.getDataAddress, amData.getLength)
+ handleFetchBlockRequest(data, replyEp)
+ UcsConstants.STATUS.UCS_OK
+ }
}
-
+ globalWorker.setAmRecvHandler(0, processCallback)
+ globalWorker.setAmRecvHandler(-1, new UcpAmRecvCallback() {
+ override def onReceive(headerAddress: Long, headerSize: Long, amData: UcpAmData, replyEp: UcpEndpoint): Int = {
+ logTrace(s"Hello")
+ UcsConstants.STATUS.UCS_OK
+ }
+ })
while (!isInterrupted) {
if (globalWorker.progress() == 0) {
- globalWorker.waitForEvents()
- globalWorker.progress()
- }
- for (i <- 0 until numRecvs) {
- if (requests(i).isCompleted) {
- val buffer = UcxUtils.getByteBufferView(memoryBlocks(i).address, msgSize.toInt)
- val senderTag = requests(i).getSenderTag
- if (senderTag >= 2L) {
- // Fetch Single block
- val msg = Utils.tryWithResource(new ByteBufferBackedInputStream(buffer)) { bin =>
- val objIn = new ObjectInputStream(bin)
- val obj = objIn.readObject().asInstanceOf[FetchBlockByBlockIdRequest]
- objIn.close()
- obj
- }
- transport.replyFetchBlockRequest(msg.executorId, msg.workerAddress.value, msg.blockId, senderTag)
- } else if (senderTag < 0 ) {
- // Batch fetch request
- val msg = Utils.tryWithResource(new ByteBufferBackedInputStream(buffer)) { bin =>
- val objIn = new ObjectInputStream(bin)
- val obj = objIn.readObject().asInstanceOf[FetchBlocksByBlockIdsRequest]
- objIn.close()
- obj
- }
- for (j <- msg.blockIds.indices) {
- transport.replyFetchBlockRequest(msg.executorId, msg.workerAddress.value, msg.blockIds(j),
- senderTag + j)
- }
-
- } else if (senderTag == UcxRpcMessages.PREFETCH_TAG) {
- val msg = Utils.tryWithResource(new ByteBufferBackedInputStream(buffer)) { bin =>
- val objIn = new ObjectInputStream(bin)
- val obj = objIn.readObject().asInstanceOf[PrefetchBlockIds]
- objIn.close()
- obj
- }
- transport.handlePrefetchRequest(msg.executorId, msg.workerAddress.value, msg.blockIds)
- }
- requests(i) = globalWorker.recvTaggedNonBlocking(memoryBlocks(i).address, msgSize,
- UcxRpcMessages.WILDCARD_TAG, UcxRpcMessages.WILDCARD_TAG_MASK, null)
+ if (transport.ucxShuffleConf.useWakeup) {
+ globalWorker.waitForEvents()
}
}
}
- memPool.put(recvMemory)
- for (i <- 0 until numRecvs) {
- if (!requests(i).isCompleted) {
- try {
- globalWorker.cancelRequest(requests(i))
- } catch {
- case _: UcxException =>
- }
- }
- }
}
}
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxDriverRpcEndpoint.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxDriverRpcEndpoint.scala
new file mode 100755
index 00000000..50a9e2e2
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxDriverRpcEndpoint.scala
@@ -0,0 +1,42 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.ucx.rpc
+
+import scala.collection.immutable.HashMap
+import scala.collection.mutable
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc._
+import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{ExecutorAdded, IntroduceAllExecutors}
+import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer
+
+class UcxDriverRpcEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging {
+
+ private val endpoints: mutable.Set[RpcEndpointRef] = mutable.HashSet.empty
+ private var blockManagerToWorkerAddress = HashMap.empty[String, SerializableDirectBuffer]
+
+ override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
+ case message@ExecutorAdded(executorId: String, endpoint: RpcEndpointRef,
+ ucxWorkerAddress: SerializableDirectBuffer) => {
+ // Driver receives a message from executor with it's workerAddress
+ // 1. Introduce existing members of a cluster
+ logInfo(s"Received $message")
+ if (blockManagerToWorkerAddress.nonEmpty) {
+ val msg = IntroduceAllExecutors(blockManagerToWorkerAddress.keys.toSeq,
+ blockManagerToWorkerAddress.values.toList)
+ logInfo(s"replying $msg to $executorId")
+ context.reply(msg)
+ }
+ blockManagerToWorkerAddress += executorId -> ucxWorkerAddress
+ // 2. For each existing member introduce newly joined executor.
+ endpoints.foreach(ep => {
+ logInfo(s"Sending $message to $ep")
+ ep.send(message)
+ })
+ logInfo(s"Connecting back to address: ${context.senderAddress}")
+ endpoints.add(endpoint)
+ }
+ }
+}
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala
new file mode 100755
index 00000000..3b8dec20
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxExecutorRpcEndpoint.scala
@@ -0,0 +1,31 @@
+/*
+* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* See file LICENSE for terms.
+*/
+package org.apache.spark.shuffle.ucx.rpc
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}
+import org.apache.spark.shuffle.ucx.UcxShuffleTransport
+import org.apache.spark.shuffle.ucx.rpc.UcxRpcMessages.{ExecutorAdded, IntroduceAllExecutors}
+import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer
+
+class UcxExecutorRpcEndpoint(override val rpcEnv: RpcEnv, transport: UcxShuffleTransport)
+ extends RpcEndpoint with Logging {
+
+
+ override def receive: PartialFunction[Any, Unit] = {
+ case ExecutorAdded(executorId: String, _: RpcEndpointRef,
+ ucxWorkerAddress: SerializableDirectBuffer) => {
+ logInfo(s"Received ExecutorAdded($executorId)")
+ transport.addExecutor(executorId, ucxWorkerAddress.value)
+ }
+ case IntroduceAllExecutors(executorIds: Seq[String],
+ ucxWorkerAddresses: Seq[SerializableDirectBuffer]) => {
+ logInfo(s"Received IntroduceAllExecutors(${executorIds.mkString(",")})")
+ executorIds.zip(ucxWorkerAddresses).foreach {
+ case (executorId, workerAddress) => transport.addExecutor(executorId, workerAddress.value)
+ }
+ }
+ }
+}
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala
index ebdc47f4..432d6e32 100644
--- a/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/rpc/UcxRpcMessages.scala
@@ -9,11 +9,6 @@ import org.apache.spark.shuffle.ucx.BlockId
import org.apache.spark.shuffle.ucx.utils.SerializableDirectBuffer
object UcxRpcMessages {
-
- val PREFETCH_TAG = 1L
- val WILDCARD_TAG = -1L
- val WILDCARD_TAG_MASK = 0L
-
/**
* Called from executor to driver, to introduce ucx worker address.
*/
@@ -27,13 +22,6 @@ object UcxRpcMessages {
case class IntroduceAllExecutors(executorIds: Seq[String],
ucxWorkerAddresses: Seq[SerializableDirectBuffer])
- case class FetchBlockByBlockIdRequest(executorId: String, workerAddress: SerializableDirectBuffer,
- blockId: BlockId)
-
- case class FetchBlocksByBlockIdsRequest(executorId: String,
- workerAddress: SerializableDirectBuffer,
- blockIds: Seq[BlockId])
-
- case class PrefetchBlockIds(executorId: String, workerAddress: SerializableDirectBuffer,
- blockIds: Seq[BlockId])
+ case class FetchBlocksByBlockIdsRequest(startTag: Int, blockIds: Seq[BlockId],
+ isHostMemory: Seq[Boolean], singleReply: Boolean = false)
}
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/utils/SerializableDirectBuffer.scala b/src/main/scala/org/apache/spark/shuffle/ucx/utils/SerializableDirectBuffer.scala
index 009d7b08..4d453b38 100644
--- a/src/main/scala/org/apache/spark/shuffle/ucx/utils/SerializableDirectBuffer.scala
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/utils/SerializableDirectBuffer.scala
@@ -5,11 +5,12 @@
package org.apache.spark.shuffle.ucx.utils
import java.io.{EOFException, ObjectInputStream, ObjectOutputStream}
+import java.net.InetSocketAddress
import java.nio.ByteBuffer
import java.nio.channels.Channels
import org.apache.spark.internal.Logging
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils}
/**
* A wrapper around a java.nio.ByteBuffer that is serializable through Java serialization, to make
@@ -64,3 +65,28 @@ class DeserializableToExternalMemoryBuffer(@transient var buffer: ByteBuffer)()
buffer.rewind() // Allow us to read it later
}
}
+
+
+object SerializationUtils {
+
+ def deserializeInetAddress(workerAddress: ByteBuffer): InetSocketAddress = {
+ workerAddress.rewind()
+ Utils.tryWithResource(new ByteBufferInputStream(workerAddress)) { bin =>
+ val objIn = new ObjectInputStream(bin)
+ val obj = objIn.readObject().asInstanceOf[InetSocketAddress]
+ objIn.close()
+ obj
+ }
+ }
+
+ def serializeInetAddress(address: InetSocketAddress): ByteBuffer = {
+ val hostAddress = new InetSocketAddress(Utils.localCanonicalHostName(), address.getPort)
+ Utils.tryWithResource(new ByteBufferOutputStream(100)) {bos =>
+ val out = new ObjectOutputStream(bos)
+ out.writeObject(hostAddress)
+ out.flush()
+ out.close()
+ bos.toByteBuffer
+ }
+ }
+}
diff --git a/src/main/scala/org/apache/spark/shuffle/ucx/utils/UcxHelperUtils.scala b/src/main/scala/org/apache/spark/shuffle/ucx/utils/UcxHelperUtils.scala
new file mode 100755
index 00000000..37c7d34e
--- /dev/null
+++ b/src/main/scala/org/apache/spark/shuffle/ucx/utils/UcxHelperUtils.scala
@@ -0,0 +1,36 @@
+package org.apache.spark.shuffle.ucx.utils
+
+import java.net.{BindException, InetSocketAddress}
+
+import scala.collection.mutable
+import scala.util.Random
+
+import org.openucx.jucx.UcxException
+import org.openucx.jucx.ucp._
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.Utils
+
+object UcxHelperUtils extends Logging{
+ def startListenerOnRandomPort(worker: UcpWorker, sparkConf: SparkConf,
+ endpoints: mutable.HashMap[UcpEndpoint, (UcpEndpoint, InetSocketAddress)]):
+ UcpListener = {
+ val ucpListenerParams = new UcpListenerParams().setConnectionHandler(
+ (ucpConnectionRequest: UcpConnectionRequest) => {
+ val repyEp = worker.newEndpoint(new UcpEndpointParams().setPeerErrorHandlingMode()
+ .setConnectionRequest(ucpConnectionRequest))
+ endpoints += repyEp -> (repyEp, ucpConnectionRequest.getClientAddress)
+ })
+ val (listener, _) = Utils.startServiceOnPort(1024 + Random.nextInt(65535 - 1024), (port: Int) => {
+ ucpListenerParams.setSockAddr(new InetSocketAddress(port))
+ val listener = try {
+ worker.newListener(ucpListenerParams)
+ } catch {
+ case ex:UcxException => throw new BindException(ex.getMessage)
+ }
+ (listener, listener.getAddress.getPort)
+ }, sparkConf)
+ logInfo(s"Started UcxListener on ${listener.getAddress}")
+ listener
+ }
+}
diff --git a/src/test/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransportTestSuite.scala b/src/test/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransportTestSuite.scala
index f391f3cd..6c107d6b 100755
--- a/src/test/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransportTestSuite.scala
+++ b/src/test/scala/org/apache/spark/shuffle/ucx/UcxShuffleTransportTestSuite.scala
@@ -1,5 +1,5 @@
/*
-* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
+* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
package org.apache.spark.shuffle.ucx
@@ -76,20 +76,16 @@ class UcxShuffleTransportTestSuite extends AnyFunSuite {
val resultMemory2 = MemoryBlock(UcxUtils.getAddress(resultBuffer2), blockSize)
val completed = new AtomicInteger(0)
+ val callback: OperationCallback = (result: OperationResult) => {
+ completed.incrementAndGet()
+ assert(result.getStats.get.recvSize == blockSize)
+ assert(result.getStatus == OperationStatus.SUCCESS)
+ }
- transport.fetchBlockByBlockId(server.ID, TestBlockId(1), resultMemory1,
- (result: OperationResult) => {
- completed.incrementAndGet()
- assert(result.getStats.get.recvSize == blockSize)
- assert(result.getStatus == OperationStatus.SUCCESS)
- })
-
- transport.fetchBlockByBlockId(server.ID, TestBlockId(2), resultMemory2,
- (result: OperationResult) => {
- completed.incrementAndGet()
- assert(result.getStats.get.recvSize == blockSize)
- assert(result.getStatus == OperationStatus.SUCCESS)
- })
+ transport.fetchBlocksByBlockIds(server.ID,
+ Array(TestBlockId(1), TestBlockId(2)),
+ Array(resultMemory1, resultMemory2),
+ Array(callback, callback))
while (completed.get() != 2) {
transport.progress()
@@ -107,7 +103,9 @@ class UcxShuffleTransportTestSuite extends AnyFunSuite {
while (!mutated.get()) {}
- val request = transport.fetchBlockByBlockId(server.ID, TestBlockId(2), resultMemory2, null)
+ val request = transport.fetchBlocksByBlockIds(server.ID,
+ Array(TestBlockId(2)), Array(resultMemory2), Seq.empty)(0)
+
while (!request.isCompleted) {
transport.progress()
}