Latest work done for unified API #31

Open
wants to merge 4 commits into base: unified-api
35 changes: 28 additions & 7 deletions pom.xml
@@ -34,7 +34,8 @@ See file LICENSE for terms.
<spark.version>3.0.0</spark.version>
<scala.version>2.12.12</scala.version>
<scala.compat.version>2.12</scala.compat.version>
<jucx.version>1.10.0-SNAPSHOT</jucx.version>
<jucx.version>1.11.0-rc3</jucx.version>
<cudf.version>0.18.1</cudf.version>
</properties>

<dependencies>
@@ -55,26 +56,34 @@ See file LICENSE for terms.
<version>3.2.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>ai.rapids</groupId>
<artifactId>cudf</artifactId>
<version>${cudf.version}</version>
<scope>provided</scope>
</dependency>
</dependencies>

<build>
<finalName>${project.artifactId}-${project.version}-for-spark-${spark.version}</finalName>
<finalName>${project.artifactId}-${project.version}-for-spark-3.0</finalName>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.8.1</version>
<configuration>
<source>1.8</source>
<target>1.8</target>
</configuration>
</plugin>
<plugin>
<groupId>net.alchim31.maven</groupId>
<artifactId>scala-maven-plugin</artifactId>
<version>4.4.0</version>
<version>4.4.1</version>
<configuration>
<recompileMode>all</recompileMode>
<javacArgs>
<javacArg>-source</javacArg>
<javacArg>${maven.compiler.source}</javacArg>
<javacArg>-target</javacArg>
<javacArg>${maven.compiler.target}</javacArg>
</javacArgs>
<args>
<arg>-Xexperimental</arg>
<arg>-Xfatal-warnings</arg>
@@ -137,6 +146,18 @@ See file LICENSE for terms.
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>3.2.0</version>
<executions>
<execution>
<goals>
<goal>test-jar</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</build>

13 changes: 3 additions & 10 deletions src/main/scala/org/apache/spark/shuffle/ucx/ShuffleTransport.scala
@@ -135,23 +135,16 @@ trait ShuffleTransport {
*/
def unregister(blockId: BlockId)

/**
* Hint for a transport that these blocks will be needed soon.
*/
def prefetchBlocks(executorId: String, blockIds: Seq[BlockId])

/**
* Batch version of [[ fetchBlockByBlockId ]].
*/
def fetchBlocksByBlockIds(executorId: String, blockIds: Seq[BlockId],
resultBuffer: Seq[MemoryBlock],
callbacks: Seq[OperationCallback]): Seq[Request]

/**
* Fetch remote blocks by blockIds.
*/
def fetchBlockByBlockId(executorId: String, blockId: BlockId,
resultBuffer: MemoryBlock, cb: OperationCallback): Request
def fetchBlocksByBlockIds(executorId: String, blockIds: Seq[BlockId],
resultBuffer: MemoryBlock,
callbacks: OperationCallback): Request

/**
* Progress outstanding operations. This routine is blocking (though it may poll for events).
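
For orientation, here is a minimal sketch of how a caller might drive the unified batch fetch after this change. The executor id, block ids, and buffer sizes are hypothetical, and it assumes `Request` exposes an `isCompleted` flag and that the trait's progress method is named `progress()`:

// Hypothetical usage of the unified fetchBlocksByBlockIds API (sketch only).
def fetchTwoBlocks(transport: UcxShuffleTransport): Unit = {
  val blockIds: Seq[BlockId] = Seq(
    UcxShuffleBlockId(shuffleId = 0, mapId = 0L, reduceId = 0),
    UcxShuffleBlockId(shuffleId = 0, mapId = 1L, reduceId = 0))

  // One result buffer per block, taken from the transport's host memory pool
  // (mirrors the pattern in UcxShuffleClient; 4096 bytes is an example size).
  val buffers: Seq[MemoryBlock] = blockIds.map { _ =>
    val mem = transport.hostMemoryPool.get(4096)
    MemoryBlock(mem.address, 4096)
  }

  // One callback per block, invoked when the corresponding fetch completes.
  val callbacks: Seq[OperationCallback] = buffers.map { buf =>
    (result: OperationResult) =>
      if (result.getStatus == OperationStatus.SUCCESS) {
        // `buf` now holds the fetched block contents.
      }
  }

  val requests = transport.fetchBlocksByBlockIds("executor-1", blockIds, buffers, callbacks)
  // Assumption: Request.isCompleted and transport.progress() exist as in this trait.
  while (!requests.forall(_.isCompleted)) {
    transport.progress()
  }
}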
101 changes: 101 additions & 0 deletions src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleBlockResolver.scala
@@ -0,0 +1,101 @@
/*
* Copyright (C) Mellanox Technologies Ltd. 2019. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
package org.apache.spark.shuffle.ucx

import java.io.File
import java.nio.ByteBuffer
import java.nio.channels.FileChannel
import java.nio.file.StandardOpenOption
import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._

import org.openucx.jucx.UcxUtils
import org.apache.spark.shuffle.IndexShuffleBlockResolver
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.unsafe.Platform


case class UcxShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId {

def this(shuffleBlockId: ShuffleBlockId) = {
this(shuffleBlockId.shuffleId, shuffleBlockId.mapId, shuffleBlockId.reduceId)
}

def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
}

case class BufferBackedBlock(buffer: ByteBuffer) extends Block {
override def getMemoryBlock: MemoryBlock = MemoryBlock(UcxUtils.getAddress(buffer), buffer.capacity())
}

class UcxShuffleBlockResolver(conf: UcxShuffleConf, transport: UcxShuffleTransport)
extends IndexShuffleBlockResolver(conf) {

type MapId = Long

private val numPartitionsForMapId = new ConcurrentHashMap[MapId, Int]

override def writeIndexFileAndCommit(shuffleId: ShuffleId, mapId: Long,
lengths: Array[Long], dataTmp: File): Unit = {
super.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp)
val dataFile = getDataFile(shuffleId, mapId)
if (!dataFile.exists()) {
return
}
numPartitionsForMapId.put(mapId, lengths.length)
val fileChannel = FileChannel.open(dataFile.toPath, StandardOpenOption.READ,
StandardOpenOption.WRITE)
val mappedBuffer = fileChannel.map(FileChannel.MapMode.READ_WRITE, 0L, dataFile.length())

val baseAddress = UcxUtils.getAddress(mappedBuffer)
fileChannel.close()

// Register whole map output file as dummy block
transport.register(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.MAP_FILE),
BufferBackedBlock(mappedBuffer))

val offsetSize = 8 * (lengths.length + 1)
val indexBuf = Platform.allocateDirectBuffer(offsetSize)

var offset = 0L
indexBuf.putLong(offset)
for (reduceId <- lengths.indices) {
if (lengths(reduceId) > 0) {
transport.register(UcxShuffleBlockId(shuffleId, mapId, reduceId), new Block {
private val memoryBlock = MemoryBlock(baseAddress + offset, lengths(reduceId))
override def getMemoryBlock: MemoryBlock = memoryBlock
})
offset += lengths(reduceId)
indexBuf.putLong(offset)
}
}

if (transport.ucxShuffleConf.protocol == transport.ucxShuffleConf.PROTOCOL.ONE_SIDED) {
transport.register(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.INDEX_FILE), BufferBackedBlock(indexBuf))
}
}

override def removeDataByMap(shuffleId: ShuffleId, mapId: Long): Unit = {
transport.unregister(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.MAP_FILE))
transport.unregister(UcxShuffleBlockId(shuffleId, mapId, BlocksConstants.INDEX_FILE))

val numRegisteredBlocks = numPartitionsForMapId.get(mapId)
(0 until numRegisteredBlocks)
.foreach(reduceId => transport.unregister(UcxShuffleBlockId(shuffleId, mapId, reduceId)))
super.removeDataByMap(shuffleId, mapId)
}

override def stop(): Unit = {
numPartitionsForMapId.keys.asScala.foreach(mapId => removeDataByMap(0, mapId))
super.stop()
}

}

object BlocksConstants {
val MAP_FILE: Int = -1
val INDEX_FILE: Int = -2
}
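
To make the offset bookkeeping above concrete: the resolver registers one block per non-empty partition at its cumulative offset inside the map output file, and writes the running offsets into the index buffer. A small self-contained sketch of that arithmetic, with hypothetical partition lengths:

// Sketch of the resolver's cumulative-offset computation; lengths are examples.
val lengths = Array(100L, 0L, 50L)
// Running offsets: empty partitions contribute nothing to the sum.
val offsets = lengths.scanLeft(0L)(_ + _) // Vector(0, 100, 100, 150)
// Blocks that would be registered: (reduceId, offset, length).
val registered = lengths.indices.collect {
  case reduceId if lengths(reduceId) > 0 =>
    (reduceId, offsets(reduceId), lengths(reduceId))
}
// registered == Vector((0, 0, 100), (2, 100, 50));
// the index buffer stores the running offsets 0, 100, 150.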
77 changes: 77 additions & 0 deletions src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleClient.scala
@@ -0,0 +1,77 @@
/*
* Copyright (C) Mellanox Technologies Ltd. 2020. ALL RIGHTS RESERVED.
* See file LICENSE for terms.
*/
package org.apache.spark.shuffle.ucx

import java.util.concurrent.TimeUnit

import org.openucx.jucx.{UcxException, UcxUtils}
import org.apache.spark.SparkEnv
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.SHUFFLE_ACCURATE_BLOCK_THRESHOLD
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient, DownloadFileManager}
import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId, BlockId => SparkBlockId}

class UcxShuffleClient(transport: UcxShuffleTransport,
blocksByAddress: Iterator[(BlockManagerId, Seq[(SparkBlockId, Long, Int)])])
extends BlockStoreClient with Logging {

private val accurateThreshold = transport.ucxShuffleConf.conf.getSizeAsBytes(SHUFFLE_ACCURATE_BLOCK_THRESHOLD.key)

private val blockSizes: Map[SparkBlockId, Long] = blocksByAddress
.withFilter { case (blockManagerId, _) => blockManagerId != SparkEnv.get.blockManager.blockManagerId }
.flatMap {
case (blockManagerId, blocks) =>
blocks.map {
case (blockId, length, _) =>
if (length > accurateThreshold) {
(blockId, (length * 1.2).toLong)
} else {
(blockId, accurateThreshold * 2)
}
}
}.toMap

override def fetchBlocks(host: String, port: Int, execId: String,
blockIds: Array[String], listener: BlockFetchingListener,
downloadFileManager: DownloadFileManager): Unit = {
val ucxBlockIds = new Array[BlockId](blockIds.length)
val memoryBlocks = new Array[MemoryBlock](blockIds.length)
val callbacks = new Array[OperationCallback](blockIds.length)
for (i <- blockIds.indices) {
val blockId = SparkBlockId.apply(blockIds(i)).asInstanceOf[ShuffleBlockId]
if (!blockSizes.contains(blockId)) {
throw new UcxException(s"No $blockId found in MapOutput blocks: ${blockSizes.keys.mkString(",")}")
}
val resultMemory = transport.hostMemoryPool.get(blockSizes(blockId))
ucxBlockIds(i) = UcxShuffleBlockId(blockId.shuffleId, blockId.mapId, blockId.reduceId)
memoryBlocks(i) = MemoryBlock(resultMemory.address, blockSizes(blockId))
callbacks(i) = (result: OperationResult) => {
if (result.getStatus == OperationStatus.SUCCESS) {
val stats = result.getStats.get
logInfo(s" Received block ${ucxBlockIds(i)} " +
s"of size: ${stats.recvSize} " +
s"in ${TimeUnit.NANOSECONDS.toMillis(stats.getElapsedTimeNs)} ms")
val buffer = UcxUtils.getByteBufferView(resultMemory.address, result.getStats.get.recvSize.toInt)
listener.onBlockFetchSuccess(blockIds(i), new NioManagedBuffer(buffer) {
override def release: ManagedBuffer = {
transport.hostMemoryPool.put(resultMemory)
this
}
})
} else {
logError(s"Error fetching block $blockId of size ${blockSizes(blockId)}:" +
s" ${result.getError.getMessage}")
listener.onBlockFetchFailure(blockIds(i), new UcxException(result.getError.getMessage))
}
}
}
transport.fetchBlocksByBlockIds(execId, ucxBlockIds, memoryBlocks, callbacks)
}

override def close(): Unit = {

}
}
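
For reference, the buffer-sizing rule from the constructor above, extracted as a standalone function. The threshold value here is illustrative; the real one comes from spark.shuffle.accurate.block.threshold:

// Sizing rule from UcxShuffleClient: oversize accurately-reported blocks by 20%;
// for blocks at or below the threshold (whose reported size may be an average),
// fall back to twice the threshold so the buffer cannot be too small.
val accurateThreshold: Long = 100L * 1024 * 1024 // 100 MiB, example only

def resultBufferSize(reportedLength: Long): Long =
  if (reportedLength > accurateThreshold) (reportedLength * 1.2).toLong
  else accurateThreshold * 2

// resultBufferSize(200 MiB) is about 240 MiB; resultBufferSize(1 KiB) is 200 MiB.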
53 changes: 29 additions & 24 deletions src/main/scala/org/apache/spark/shuffle/ucx/UcxShuffleConf.scala
@@ -12,15 +12,20 @@ import org.apache.spark.util.Utils
class UcxShuffleConf(val conf: SparkConf) extends SparkConf {
private def getUcxConf(name: String) = s"spark.shuffle.ucx.$name"

private val PROTOCOL =
object PROTOCOL extends Enumeration {
val ONE_SIDED, RNDV = Value
}

private lazy val PROTOCOL_CONF =
ConfigBuilder(getUcxConf("protocol"))
.doc("Which protocol to use: rndv (default), one-sided")
.doc("Which protocol to use: RNDV (default), ONE-SIDED")
.stringConf
.checkValue(protocol => protocol == "rndv" || protocol == "one-sided",
"Invalid protocol. Valid options: rndv / one-sided.")
.createWithDefault("rndv")
.transform(_.toUpperCase.replace("-", "_"))
.createWithDefault("RNDV")

private val MEMORY_PINNING =
private lazy val MEMORY_PINNING =
ConfigBuilder(getUcxConf("memoryPinning"))
.doc("Whether to pin whole shuffle data in memory")
.booleanConf
@@ -30,15 +35,7 @@ class UcxShuffleConf(val conf: SparkConf) extends SparkConf {
ConfigBuilder(getUcxConf("maxWorkerSize"))
.doc("Maximum size of worker address in bytes")
.bytesConf(ByteUnit.BYTE)
.createWithDefault(1000)

lazy val RPC_MESSAGE_SIZE: ConfigEntry[Long] =
ConfigBuilder(getUcxConf("rpcMessageSize"))
.doc("Size of RPC message to send from fetchBlockByBlockId. Must contain ")
.bytesConf(ByteUnit.BYTE)
.checkValue(size => size > maxWorkerAddressSize,
"Rpc message must contain workerAddress")
.createWithDefault(2000)
.createWithDefault(1024)

// Memory Pool
private lazy val PREALLOCATE_BUFFERS =
@@ -52,11 +49,11 @@
.booleanConf
.createWithDefault(false)

private lazy val RECV_QUEUE_SIZE =
ConfigBuilder(getUcxConf("recvQueueSize"))
.doc("The number of submitted receive requests.")
.intConf
.createWithDefault(5)
private lazy val USE_SOCKADDR =
ConfigBuilder(getUcxConf("useSockAddr"))
.doc("Whether to use socket address to connect executors.")
.booleanConf
.createWithDefault(true)

private lazy val MIN_REGISTRATION_SIZE =
ConfigBuilder(getUcxConf("memory.minAllocationSize"))
@@ -67,24 +64,32 @@
lazy val minRegistrationSize: Int = conf.getSizeAsBytes(MIN_REGISTRATION_SIZE.key,
MIN_REGISTRATION_SIZE.defaultValueString).toInt

lazy val protocol: String = conf.get(PROTOCOL.key, PROTOCOL.defaultValueString)
private lazy val USE_ODP =
ConfigBuilder(getUcxConf("useOdp"))
.doc("Whether to use on demand paging feature, to avoid memory pinning")
.booleanConf
.createWithDefault(false)

lazy val useOdp: Boolean = conf.getBoolean(getUcxConf("memory.useOdp"), defaultValue = false)
lazy val protocol: PROTOCOL.Value = PROTOCOL.withName(
conf.get(PROTOCOL_CONF.key, PROTOCOL_CONF.defaultValueString))

lazy val useOdp: Boolean = conf.getBoolean(USE_ODP.key, USE_ODP.defaultValue.get)

lazy val pinMemory: Boolean = conf.getBoolean(MEMORY_PINNING.key, MEMORY_PINNING.defaultValue.get)

lazy val maxWorkerAddressSize: Long = conf.getSizeAsBytes(WORKER_ADDRESS_SIZE.key,
WORKER_ADDRESS_SIZE.defaultValueString)

lazy val rpcMessageSize: Long = conf.getSizeAsBytes(RPC_MESSAGE_SIZE.key,
RPC_MESSAGE_SIZE.defaultValueString)
lazy val maxMetadataSize: Long = conf.getSizeAsBytes("spark.rapids.shuffle.maxMetadataSize",
"1024")


lazy val useWakeup: Boolean = conf.getBoolean(WAKEUP_FEATURE.key, WAKEUP_FEATURE.defaultValue.get)

lazy val recvQueueSize: Int = conf.getInt(RECV_QUEUE_SIZE.key, RECV_QUEUE_SIZE.defaultValue.get)
lazy val useSockAddr: Boolean = conf.getBoolean(USE_SOCKADDR.key, USE_SOCKADDR.defaultValue.get)

lazy val preallocateBuffersMap: Map[Long, Int] = {
conf.get(PREALLOCATE_BUFFERS).split(",").withFilter(s => !s.isEmpty)
conf.get(PREALLOCATE_BUFFERS.key, "").split(",").withFilter(s => s.nonEmpty)
.map(entry => entry.split(":") match {
case Array(bufferSize, bufferCount) =>
(Utils.byteStringAsBytes(bufferSize.trim), bufferCount.toInt)
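
Finally, a hedged example of wiring these options up when submitting a job. The key names come from the config builders above; every value here is illustrative, not a recommendation:

import org.apache.spark.SparkConf

// Illustrative configuration for the options defined above.
val conf = new SparkConf()
  .set("spark.shuffle.ucx.protocol", "ONE_SIDED")          // PROTOCOL enum value
  .set("spark.shuffle.ucx.memoryPinning", "true")          // pin shuffle data in memory
  .set("spark.shuffle.ucx.useSockAddr", "true")            // connect executors via socket address
  .set("spark.shuffle.ucx.useOdp", "false")                // on-demand paging off
  .set("spark.shuffle.ucx.preallocateBuffers", "4096:16,65536:8") // size:count pairs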