diff --git a/.ci/run_nccl_tests.sh b/.ci/run_nccl_tests.sh index c2724a4f..ba1d0d69 100755 --- a/.ci/run_nccl_tests.sh +++ b/.ci/run_nccl_tests.sh @@ -109,7 +109,7 @@ for TEST_EXE in ${NCCL_TEST_EXE[@]}; do #=================== # Enable ucx_rma tests once this is resolved: https://redmine.mellanox.com/issues/3037941 # for P2P_LAYER in ucx ucx_rma ib - for P2P_LAYER in ib ucx ucx_uct ucx_uct_read; do + for P2P_LAYER in ib ucx ucx_rma ucx_uct ucx_uct_read; do MPIRUN_OPTIONS_PLUGIN_P2P_LAYER="-x NCCL_PLUGIN_P2P=${P2P_LAYER}" #=================== diff --git a/src/p2p_plugin.c b/src/p2p_plugin.c index 500150ff..64bc166f 100644 --- a/src/p2p_plugin.c +++ b/src/p2p_plugin.c @@ -247,7 +247,7 @@ ncclResult_t nccl_p2p_ib_get_properties(ncclIbDev *devs, int dev, ncclNetPropert props->maxComms = ibDev->maxQp; if (p2p_plugin == NCCL_P2P_IB || p2p_plugin == NCCL_P2P_UCX || - nccl_p2p_is_uct_plugin(p2p_plugin)) { + p2p_plugin == NCCL_P2P_UCX_RMA || nccl_p2p_is_uct_plugin(p2p_plugin)) { props->maxRecvs = NCCL_NET_IB_MAX_RECVS; } else { props->maxRecvs = 1; diff --git a/src/ucx_rma_plugin.c b/src/ucx_rma_plugin.c index a23b498c..e82c3573 100644 --- a/src/ucx_rma_plugin.c +++ b/src/ucx_rma_plugin.c @@ -1,21 +1,18 @@ /************************************************************************* - * * Copyright (c) 2016-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * * - * * See LICENSE.txt for license information - * ************************************************************************/ + * Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE.txt for license information + ************************************************************************/ #include -#include -#include -#include -#include "core.h" -#include "ibvwrap.h" +#include "nccl.h" +#include "net.h" #include "p2p_plugin.h" -#include "param.h" -#include "socket.h" + #include "ucp/api/ucp.h" +#define NCCL_UCP_HANDLE_MAGIC 0x3fea0433 #define UCXCHECK(cmd) do { \ int e = cmd; \ @@ -26,1239 +23,1288 @@ } \ } while(0) -#define UCXCHECK_VOID(cmd) do { \ - int e = cmd; \ - if( UCS_OK != e ) { \ - WARN("Failed: UCX error %s:%d '%d' %s\n", \ - __FILE__,__LINE__, e, ucs_status_string(e)); \ - } \ -} while(0) - -NCCL_PARAM(UCXRMADisable, "UCX_RMA_DISABLE", 0); - -extern ncclDebugLogger_t pluginLogFunction; -static char nccl_ucx_rma_tls[32] = ""; -static char nccl_ucx_rma_zcopy_thresh[32] =""; -static int ncclNIbDevs = -1; - -#define MAX_UCX_RKEY_BUF_SIZE 128 -typedef struct nccl_ucx_rma_rkey_buf { - int index; - int id; - char buf[MAX_UCX_RKEY_BUF_SIZE]; - size_t rkey_buf_size; - int send; -} nccl_ucx_rma_rkey_buf_t; - -enum ncclUCXCommState { - ncclUCXCommStateStart = 0, - ncclUCXCommStateConnect = 1, - ncclUCXCommStateAccept = 3, -}; +NCCL_PARAM(UCXAckDelay, "UCX_PUT_ACK_DELAY", 1); +NCCL_PARAM(UCXAckSkip, "UCX_PUT_ACK_SKIP", 0); + +typedef enum { + NCCL_UCP_TYPE_IRECV, + NCCL_UCP_TYPE_ISEND, + NCCL_UCP_TYPE_IFLUSH +} nccl_ucp_req_type_t; + +/* Connection management state machine */ +typedef enum { + NCCL_UCP_START = 0, + NCCL_UCP_CONNECT, + NCCL_UCP_ACCEPT, + NCCL_UCP_RECEIVE_REMOTE, + NCCL_UCP_RX_READY, + NCCL_UCP_DONE +} nccl_ucp_state_t; + +typedef struct nccl_ucp_worker { + struct nccl_ucp_worker *next; + ucp_worker_h ucp_worker; + int dev; + int used; + ucp_context_h ucp_context; + void *address; + size_t address_length; +} nccl_ucp_worker_t; + +typedef struct { + int dev_count; + int listener_count; + char if_name[MAX_IF_NAME_SIZE]; + union ncclSocketAddress if_addr; + nccl_ucp_worker_t *workers; +} nccl_ucp_context_t; + +struct nccl_ucp_comm; + +#define NCCL_UCP_RKEY_SIZE 96 /* bytes */ +#define NCCL_UCP_WORKER_ADDR_SIZE 1024 +#define NCCL_UCP_RKEY_COUNT 128 /* Maximum number of mh */ +#define NCCL_UCP_MAX_RECV 8 /* Maximum chunks per ->irecv() */ -struct ncclUCXCommStage { - enum ncclUCXCommState state; - uint8_t iteration; - void* sock; - void* comm; -}; +/* + * Max send request in-flight 8*8 = 64 + * Ring must be: 64+63 available slots + */ -typedef struct ucx_rma_mhandle { - ucp_mem_h ucp_memh; - ucp_rkey_h rkey; - nccl_ucx_rma_rkey_buf_t rkey_buf; - int mem_type; -} ucx_rma_mhandle_t; +#define NCCL_UCP_RING_SIZE 256 +#define NCCL_UCP_RING_MASK (NCCL_UCP_RING_SIZE - 1) + +#define REG_ALIGN (1 << 12) /* 4kB-pages */ +#define REG_MASK (REG_ALIGN - 1) + +typedef struct nccl_ucp_packed { + unsigned short rkey_id_start; + int rkey_buf_size; + unsigned char rkey_buf[NCCL_UCP_RKEY_SIZE]; + unsigned short rkey_id_end; +} __attribute__((aligned(64))) nccl_ucp_packed_rkey_t; + +typedef struct nccl_ucp_chunk { + uint64_t data; + int size; + int tag; + unsigned short rkey_id; + unsigned short id; +} nccl_ucp_chunk_t; + +typedef struct nccl_ucp_rtr { + unsigned short id_start; /* Id of the RTR */ + unsigned char count; /* Total chunks (at least 1) */ + char avail; /* Chunk left to proceed */ + char ack; /* Set if an ATP will be needed */ + nccl_ucp_chunk_t chunk[NCCL_UCP_MAX_RECV]; +} __attribute__((aligned(64))) nccl_ucp_rtr_t; + +struct nccl_ucp_comm; + +typedef struct nccl_ucp_atp { + unsigned short id_start; /* Id of the original RTR */ + unsigned char count; /* Added entries, incremented when posting */ + char inflight; /* Chunk still being sent */ + char reqs; /* Count request alive */ + int sizes[NCCL_UCP_MAX_RECV]; + unsigned short id; /* Id of the origin RTR again */ +} __attribute__((aligned(64))) nccl_ucp_atp_t; + +typedef struct nccl_ucp_share { + nccl_ucp_packed_rkey_t packed_rkey[NCCL_UCP_RKEY_COUNT]; + nccl_ucp_rtr_t rtr[NCCL_UCP_RING_SIZE]; + nccl_ucp_atp_t atp[NCCL_UCP_RING_SIZE]; + unsigned dummy_mem; /* Read-flush into it */ +} nccl_ucp_share_t; + +/* Exchanged OOB to connect to the remote communicator */ +typedef struct nccl_ucp_address { + /* Remote communicator pointer */ + struct nccl_ucp_comm *comm; + + /* Key and address for shared memory area */ + size_t share_rkey_length; + uint8_t share_rkey[NCCL_UCP_RKEY_SIZE]; + nccl_ucp_share_t *share; + + /* Worker address */ + size_t address_length; + uint8_t address[NCCL_UCP_WORKER_ADDR_SIZE]; +} nccl_ucp_address_t; + +typedef struct { + unsigned short rkey_id; /* Shared key identifier */ + int mem_type; + ucp_mem_h ucp_memh; + void *rkey_buf; /* Packed key */ + size_t rkey_buf_size; + int sent; /* Set to 1 only when PUT has been started */ + ucp_rkey_h rkey; /* Set only for local read-based gpu flush */ +} nccl_ucp_memh_t; + +/* NCCL UCX plugin request */ +typedef struct nccl_ucp_req { + struct nccl_ucp_comm *comm; /* Owner communicator */ + nccl_ucp_req_type_t type; /* Type of the request */ + unsigned short rtr_id; /* Id of the RTR received */ + int inflight; /* Set to zero when completed, irecv side */ +} nccl_ucp_req_t; + +/* Unpacked rkeys */ +typedef struct { + unsigned short rkey_id; + ucp_rkey_h rkey; +} nccl_ucp_rkey_t; + +typedef struct nccl_ucp_comm { + struct ncclSocket sock; /* OOB connection descriptor */ + int dev; /* Device ID of the communicator */ + nccl_ucp_worker_t *worker; /* Worker for the communicator */ + int gpu_flush; /* True if enabled */ + nccl_ucp_req_type_t type; /* Isend or Irecv side */ + + unsigned short req_id; /* Next request id to use */ + unsigned short rtr_id; /* Next RTR id to use */ + unsigned short rkey_id; /* Next rkey identifier */ + + unsigned total; /* Current requests in progress */ + int inflight_rkey; /* Total remote keys being sent */ + int delay_atp; /* Send ATP after remote completion */ + + /* Connected endpoints */ + ucp_ep_h ucp_ep; /* Remote endpoint */ + ucp_ep_h ucp_flush_ep; /* Local flush endpoint */ + + /* In-flight NCCL-UCX requests (send/receive/flush) */ + nccl_ucp_req_t req[NCCL_UCP_RING_SIZE]; + + /* Unpacked received rkeys */ + nccl_ucp_rkey_t rkey[NCCL_UCP_RKEY_COUNT]; + + /* Local registered memory area */ + struct { + nccl_ucp_share_t share; /* Remotely accessible memory area */ + nccl_ucp_memh_t *share_mh; /* Local memory handle of the share */ + } local; + + /* Remote shared memory area */ + struct { + nccl_ucp_share_t *share; + ucp_rkey_h rkey; + } remote; + + /* Remote worker address */ + nccl_ucp_address_t peer; +} nccl_ucp_comm_t; + +typedef struct { + nccl_ucp_state_t state; + nccl_ucp_comm_t *comm; + + int offset; + int ready; +} nccl_ucp_stage_t; + +typedef struct { + int dev; + int id; + struct ncclSocket sock; + nccl_ucp_stage_t stage; +} nccl_ucp_listen_comm_t; + +typedef struct { + unsigned int magic; + struct { + int id; + union ncclSocketAddress addr; + } listener; + nccl_ucp_stage_t stage; +} nccl_ucp_listen_handle_t; + +static nccl_ucp_context_t context = {.dev_count = -1}; + +static pthread_mutex_t global_lock = PTHREAD_MUTEX_INITIALIZER; + +static ncclResult_t nccl_ucx_rma_init(ncclDebugLogger_t logFunction) { + return nccl_p2p_ib_init(&context.dev_count, ncclIbDevs, context.if_name, + &context.if_addr, NULL, logFunction); +} -ncclResult_t nccl_ucx_rma_devices(int* ndev) { - *ndev = ncclNIbDevs; +static ncclResult_t nccl_ucx_rma_devices(int *ndev) { + *ndev = context.dev_count; return ncclSuccess; } -ncclResult_t nccl_ucx_rma_get_properties(int dev, ncclNetProperties_t* props) -{ +static ncclResult_t nccl_ucx_rma_get_properties(int dev, + ncclNetProperties_t *props) { return nccl_p2p_ib_get_properties(ncclIbDevs, dev, props); } -ncclResult_t nccl_ucx_rma_get_properties_v7(int dev, ncclNetProperties_v7_t* props_v7) -{ - ncclNetProperties_t props; - ncclResult_t ret = nccl_ucx_rma_get_properties(dev, &props); - if (ret != ncclSuccess) return ret; - props_v7->name = props.name; - props_v7->pciPath = props.pciPath; - props_v7->guid = props.guid; - props_v7->ptrSupport = props.ptrSupport; - props_v7->speed = props.speed; - props_v7->latency = props.latency; - props_v7->port = props.port; - props_v7->maxComms = props.maxComms; - props_v7->maxRecvs = props.maxRecvs; - props_v7->netDeviceType = props.netDeviceType; - props_v7->netDeviceVersion = props.netDeviceVersion; - return ncclSuccess; +static ncclResult_t nccl_ucx_rma_listen(int dev, void *listen_handle, + void **listen_comm) { + nccl_ucp_listen_handle_t *handle = listen_handle; + nccl_ucp_listen_comm_t *l_comm; + union ncclSocketAddress addr; + + NCCL_STATIC_ASSERT(sizeof(nccl_ucp_listen_handle_t) < NCCL_NET_HANDLE_MAXSIZE, + "UCP listen handle is too big"); + + l_comm = calloc(1, sizeof(*l_comm)); + if (l_comm == NULL) { + return ncclSystemError; + } + + /* Prepare socket */ + NCCLCHECK(ncclSocketInit(&l_comm->sock, &context.if_addr, + NCCL_UCP_HANDLE_MAGIC, ncclSocketTypeNetIb, NULL, + 1)); + NCCLCHECK(ncclSocketListen(&l_comm->sock)); + NCCLCHECK(ncclSocketGetAddr(&l_comm->sock, &addr)); + + /* Prepare listen communicator */ + l_comm->dev = dev; + l_comm->id = context.listener_count++; + *listen_comm = l_comm; + + /* Prepare handle to send */ + memset(handle, 0, sizeof(*handle)); + handle->magic = NCCL_UCP_HANDLE_MAGIC; + handle->listener.id = l_comm->id; + handle->listener.addr = addr; + + INFO(NCCL_INIT | NCCL_NET, "Listening id=%d dev=%d l_comm=%p", l_comm->id, + dev, l_comm); + return ncclSuccess; } -ncclResult_t nccl_ucx_rma_get_properties_v6(int dev, ncclNetProperties_v6_t* props_v6) -{ - ncclNetProperties_t props; - ncclResult_t ret = nccl_ucx_rma_get_properties(dev, &props); - if (ret != ncclSuccess) return ret; - props_v6->name = props.name; - props_v6->pciPath = props.pciPath; - props_v6->guid = props.guid; - props_v6->ptrSupport = props.ptrSupport; - props_v6->speed = props.speed; - props_v6->latency = props.latency; - props_v6->port = props.port; - props_v6->maxComms = props.maxComms; - props_v6->maxRecvs = props.maxRecvs; + +static ncclResult_t nccl_ucx_rma_close_listen(void *listen_comm) { + nccl_ucp_listen_comm_t *comm = listen_comm; + + if (comm) { + NCCLCHECK(ncclSocketClose(&comm->sock)); + free(comm); + } + return ncclSuccess; } -pthread_mutex_t nccl_ucx_rma_lock = PTHREAD_MUTEX_INITIALIZER; - -typedef struct ucx_rma_listen_handle { - union ncclSocketAddress connectAddr; /* reciever socket address */ - uint64_t magic; /* random number to help debugging */ - ucp_tag_t tag; /* tag that is used to distiguish data that was sent to - this reciever. Required when shared worker is used. */ - struct ncclUCXCommStage stage; -} ucx_rma_listen_handle_t; - -typedef struct nccl_ucx_rma_listen_comm { - int dev; - struct ncclSocket sock;/* socket for OOB connection */ - struct ncclUCXCommStage stage; -} nccl_ucx_rma_listen_comm_t; - -struct ep_list { - struct ncclSocket *sock; - struct ep_list *next; -}; +static ncclResult_t nccl_ucp_worker_init(nccl_ucp_worker_t *w, int dev, + ucp_context_h ucp_context) { + ucp_worker_params_t params = {.field_mask = + UCP_WORKER_PARAM_FIELD_THREAD_MODE, + .thread_mode = UCS_THREAD_MODE_MULTI}; + ucp_worker_attr_t attr = {.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE | + UCP_WORKER_ATTR_FIELD_ADDRESS | + UCP_WORKER_ATTR_FIELD_ADDRESS_FLAGS, + .address_flags = UCP_WORKER_ADDRESS_FLAG_NET_ONLY}; + + w->dev = dev; + w->ucp_context = ucp_context; + w->next = context.workers; + context.workers = w; + + UCXCHECK(ucp_worker_create(w->ucp_context, ¶ms, &w->ucp_worker)); + UCXCHECK(ucp_worker_query(w->ucp_worker, &attr)); + + if (attr.thread_mode != UCS_THREAD_MODE_MULTI) { + INFO(NCCL_NET, "Thread mode multi is not supported"); + } -struct nccl_ucx_worker { - ucp_context_h ctx; - ucp_worker_h worker; - int count; - struct ep_list *eps; -}; -static struct nccl_ucx_worker workers[MAX_IB_DEVS]; - -typedef struct ucx_gpu_flush { - int enabled; - int hostMem; - ucp_ep_h flush_ep; -} ucx_gpu_flush_t; - -enum { - UCX_RMA_REQ_TYPE_SEND, - UCX_RMA_REQ_TYPE_RECV, - UCX_RMA_REQ_TYPE_FLUSH, -}; + w->address_length = attr.address_length; + w->address = malloc(attr.address_length); + if (w->address == NULL) { + WARN("Failed to allocate worker address"); + goto err; + } -#define MAX_UCX_REQ_SIZE 256 -typedef struct nccl_ucx_rma_request { - char ucx_req[MAX_UCX_REQ_SIZE]; - int used; - int type; - int done; - int size; - int free; - uint64_t am_msg; - int seq; - ucs_status_ptr_t st; - ucp_worker_h worker; -} nccl_ucx_rma_request_t; - -typedef struct ucx_rma_send_fifo { - uint64_t addr; - uint64_t addr_request; - int size; - uint32_t seq; - uint32_t ready; - int rkey_idx; - int rkey_id; - int req_id; -} ucx_rma_send_fifo_t; - -#define NCCL_UCX_RMA_MAX_MHANDLES 16 -typedef struct nccl_ucx_rma_ctx { - int id; - int ready; - struct ncclSocket sock; - ucs_status_ptr_t check_req; - ucp_context_h ctx; - ucp_worker_h worker; - ucx_gpu_flush_t gpuFlush; - uint64_t num_mh; - ucx_rma_mhandle_t *mh[NCCL_UCX_RMA_MAX_MHANDLES]; - nccl_ucx_rma_request_t reqs[MAX_REQUESTS]; -} nccl_ucx_rma_ctx_t; - -typedef struct nccl_ucx_rma_rkey { - ucp_rkey_h rkey; - int id; -} nccl_ucx_rma_rkey_t; - -typedef struct nccl_ucx_rma_send_comm { - nccl_ucx_rma_ctx_t super; - ucp_ep_h ep; - ucx_rma_send_fifo_t fifo[MAX_REQUESTS]; - uint32_t fifo_head; - ucp_mem_h fifo_memh; - nccl_ucx_rma_rkey_t rkeys[NCCL_UCX_RMA_MAX_MHANDLES]; - int rem_am_id; -} nccl_ucx_rma_send_comm_t; - -typedef struct ucx_rma_rem_fifo { - ucx_rma_send_fifo_t elems[MAX_REQUESTS]; - uint64_t addr; - ucp_rkey_h rkey; - uint32_t tail; -} ucx_rma_rem_fifo_t; - -typedef struct nccl_ucx_rma_recv_comm { - nccl_ucx_rma_ctx_t super; - ucp_ep_h ep; - ucx_rma_rem_fifo_t rem_fifo; - int rem_am_id; - void *rkey_bufs; -} nccl_ucx_rma_recv_comm_t; - - -static union ncclSocketAddress nccl_ucx_if_addr; -static char if_name[MAX_IF_NAME_SIZE]; - -typedef struct nccl_ucx_am_request { - nccl_ucx_rma_request_t *req; -} nccl_ucx_am_request_t; - -typedef nccl_ucx_am_request_t nccl_ucx_flush_request_t; - -static ncclResult_t nccl_ucx_rma_init_ucp(int dev, ucp_context_h *ctx) -{ - ucp_params_t ucp_params; + memcpy(w->address, attr.address, attr.address_length); + ucp_worker_release_address(w->ucp_worker, attr.address); + return ncclSuccess; + +err: + ucp_worker_release_address(w->ucp_worker, attr.address); + ucp_worker_destroy(w->ucp_worker); + return ncclSystemError; +} + +static ncclResult_t nccl_ucp_context_create(int dev, + ucp_context_h *ucp_context) { + ucp_params_t params; ucp_config_t *config; - char ucx_dev_name[PATH_MAX]; + char ucx_dev_name[128]; + ucs_status_t status; - snprintf(ucx_dev_name, PATH_MAX, "%s:%d", ncclIbDevs[dev].devName, + snprintf(ucx_dev_name, sizeof(ucx_dev_name), "%s:%d", ncclIbDevs[dev].devName, ncclIbDevs[dev].portNum); UCXCHECK(ucp_config_read("NCCL", NULL, &config)); - UCXCHECK(ucp_config_modify(config, "NET_DEVICES", ucx_dev_name)); - UCXCHECK(ucp_config_modify(config, "TLS", nccl_ucx_rma_tls)); - UCXCHECK(ucp_config_modify(config, "ZCOPY_THRESH", nccl_ucx_rma_zcopy_thresh)); + UCXCHECK(ucp_config_modify(config, "TLS", "rc_x")); - memset(&ucp_params, 0, sizeof(ucp_params)); - ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES | - UCP_PARAM_FIELD_REQUEST_SIZE; - ucp_params.features = UCP_FEATURE_RMA | - UCP_FEATURE_AM; - ucp_params.request_size = sizeof(nccl_ucx_am_request_t); + params.field_mask = UCP_PARAM_FIELD_FEATURES; + params.features = UCP_FEATURE_RMA | UCP_FEATURE_AM; - UCXCHECK(ucp_init(&ucp_params, config, ctx)); + status = ucp_init(¶ms, config, ucp_context); ucp_config_release(config); - + NCCLCHECK(status); return ncclSuccess; } -static ncclResult_t nccl_ucx_rma_init_worker(ucp_context_h ctx, - ucp_worker_h *worker) -{ - ucp_worker_params_t worker_params; - ucp_worker_attr_t worker_attr; - - memset(&worker_params, 0, sizeof(worker_params)); - worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; - worker_params.thread_mode = UCS_THREAD_MODE_MULTI; +static nccl_ucp_worker_t *nccl_ucp_worker_get(int dev) { + nccl_ucp_worker_t *w; + ucp_context_h ucp_context; - UCXCHECK(ucp_worker_create(ctx, &worker_params, worker)); - - worker_attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE; - ucp_worker_query(*worker, &worker_attr); - if (worker_attr.thread_mode != UCS_THREAD_MODE_MULTI) { - INFO(NCCL_NET, "Thread mode multi is not supported"); + pthread_mutex_lock(&global_lock); + w = calloc(1, sizeof(*w)); + if (w == NULL) { + goto fail; } - return ncclSuccess; -} - -#define UCX_RMA_USE_SHARED_WORKER -static ncclResult_t nccl_ucx_rma_init_comm_context(int dev, - nccl_ucx_rma_ctx_t *comm_ctx) -{ - pthread_mutex_lock(&nccl_ucx_rma_lock); -#ifdef UCX_RMA_USE_SHARED_WORKER - if (workers[dev].count == 0) { - nccl_ucx_rma_init_ucp(dev, &workers[dev].ctx); - nccl_ucx_rma_init_worker(workers[dev].ctx, &workers[dev].worker); - workers->count = 0; - workers->eps = NULL; + if (nccl_ucp_context_create(dev, &ucp_context) != ncclSuccess) { + goto fail; } - comm_ctx->ctx = workers[dev].ctx; - comm_ctx->worker = workers[dev].worker; - comm_ctx->id = workers[dev].count; - workers[dev].count++; -#else - nccl_ucx_rma_init_ucp(dev, &comm_ctx->ctx); - nccl_ucx_rma_init_worker(comm_ctx->ctx, &comm_ctx->worker); -#endif - pthread_mutex_unlock(&nccl_ucx_rma_lock); - return ncclSuccess; -} - -static ncclResult_t nccl_ucx_rma_send_worker_address(ucp_worker_h worker, struct ncclSocket *sock) -{ - ucp_worker_attr_t attr; - - attr.field_mask = UCP_WORKER_ATTR_FIELD_ADDRESS | - UCP_WORKER_ATTR_FIELD_ADDRESS_FLAGS; - attr.address_flags = UCP_WORKER_ADDRESS_FLAG_NET_ONLY; + if (nccl_ucp_worker_init(w, dev, ucp_context) != ncclSuccess) { + ucp_cleanup(ucp_context); + goto fail; + } - UCXCHECK(ucp_worker_query(worker, &attr)); - NCCLCHECK(ncclSocketSend(sock, &attr.address_length, sizeof(attr.address_length))); - NCCLCHECK(ncclSocketSend(sock, attr.address, attr.address_length)); + w->used++; + pthread_mutex_unlock(&global_lock); + return w; - free(attr.address); - return ncclSuccess; +fail: + free(w); + pthread_mutex_unlock(&global_lock); + return NULL; } -static ncclResult_t nccl_ucx_free_worker(ucp_worker_h worker) -{ - int i; - int dummy; - struct ep_list *ep, *cur; - - pthread_mutex_lock(&nccl_ucx_rma_lock); - for(i = 0; i < ncclNIbDevs; i++) { - if (worker == workers[i].worker) { - workers[i].count--; - if (workers[i].count == 0) { - ep = workers[i].eps; - while(ep) { - cur = ep; - NCCLCHECK(ncclSocketRecv(ep->sock, &dummy, sizeof(int))); - ep = ep->next; - close(cur->sock->fd); - free(cur); - } - ucp_worker_destroy(workers[i].worker); - ucp_cleanup(workers[i].ctx); - INFO(NCCL_NET, "worker destroy"); - workers[i].eps = NULL; - workers[i].worker = NULL; - workers[i].ctx = NULL; +static void nccl_ucp_worker_put(nccl_ucp_worker_t *worker) { + int found = 0; + nccl_ucp_worker_t **w; + (void)found; + + pthread_mutex_lock(&global_lock); + if (--worker->used < 1) { + for (w = &context.workers; *w != NULL; w = &(*w)->next) { + if (*w == worker) { + *w = worker->next; + found = 1; + break; } - break; } + + assert(found == 1); + assert(worker->used == 0); + free(worker->address); + ucp_worker_destroy(worker->ucp_worker); + ucp_cleanup(worker->ucp_context); + free(worker); } - pthread_mutex_unlock(&nccl_ucx_rma_lock); - return ncclSuccess; + pthread_mutex_unlock(&global_lock); } -static ncclResult_t nccl_ucx_add_ep(ucp_worker_h worker, struct ncclSocket *sock) -{ - ncclResult_t status = ncclSuccess; - int i; +static nccl_ucp_memh_t *nccl_ucp_mem_register(nccl_ucp_comm_t *comm, void *data, + size_t size, int type) { + uint64_t addr; + nccl_ucp_memh_t *mh; + ucp_mem_map_params_t params; + ucs_status_t status; - for(i = 0; i < ncclNIbDevs; i++) { - if (worker == workers[i].worker) { - struct ep_list *new_ep = (struct ep_list*)malloc(sizeof(struct ep_list)); + mh = calloc(1, sizeof(*mh)); + if (mh == NULL) { + return NULL; + } - if (new_ep == NULL) { - status = ncclSystemError; - break; - } + mh->mem_type = + (type == NCCL_PTR_HOST) ? UCS_MEMORY_TYPE_HOST : UCS_MEMORY_TYPE_CUDA; + addr = (uint64_t)data & ~REG_MASK; + size = ROUNDUP(size + ((uint64_t)data & REG_MASK), REG_ALIGN); - new_ep->sock = sock; - new_ep->next = workers[i].eps; - workers[i].eps = new_ep; - break; - } + params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | + UCP_MEM_MAP_PARAM_FIELD_LENGTH | + UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE; + params.address = (void*)addr; + params.length = size; + params.memory_type = mh->mem_type; + + status = ucp_mem_map(comm->worker->ucp_context, ¶ms, &mh->ucp_memh); + if (status != UCS_OK) { + WARN("Memory registration failed for comm=%p mem=%p/%zu", comm, (void*)addr, + size); + free(mh); + return NULL; } - return status; + status = ucp_rkey_pack(comm->worker->ucp_context, mh->ucp_memh, &mh->rkey_buf, + &mh->rkey_buf_size); + if (status != UCS_OK) { + WARN("Rkey packing failed comm=%p", comm); + ucp_mem_unmap(comm->worker->ucp_context, mh->ucp_memh); + free(mh); + return NULL; + } + + return mh; } -ncclResult_t nccl_ucx_rma_init(ncclDebugLogger_t logFunction) -{ - char *config_env; - if (ncclParamUCXRMADisable()) return ncclInternalError; - NCCLCHECK(nccl_p2p_ib_init(&ncclNIbDevs, ncclIbDevs, if_name, &nccl_ucx_if_addr, - NULL, logFunction)); - - if (strlen(nccl_ucx_rma_tls) == 0) { - config_env = getenv("NCCL_UCX_TLS"); - if (config_env != NULL) { - snprintf(nccl_ucx_rma_tls, 32, "%s", config_env); - } else { - snprintf(nccl_ucx_rma_tls, 32, "%s", "ib"); - } - INFO(NCCL_NET, "NET/UCX_RMA: using transports: %s", nccl_ucx_rma_tls); - } +static ncclResult_t nccl_ucx_rma_deregmr(void *dereg_comm, void *mhandle) { + nccl_ucp_comm_t *comm = dereg_comm; + nccl_ucp_memh_t *mh = mhandle; - if (strlen(nccl_ucx_rma_zcopy_thresh) == 0) { - config_env = getenv("NCCL_UCX_ZCOPY_THRESH"); - if (config_env != NULL) { - snprintf(nccl_ucx_rma_zcopy_thresh, 32, "%s", config_env); - } else { - snprintf(nccl_ucx_rma_zcopy_thresh, 32, "%s", "1"); - } - INFO(NCCL_NET, "NET/UCX_RMA: zero copy threshold: %s", nccl_ucx_rma_zcopy_thresh); + ucp_rkey_buffer_release(mh->rkey_buf); + if (mh->rkey != NULL) { + ucp_rkey_destroy(mh->rkey); } + ucp_mem_unmap(comm->worker->ucp_context, mh->ucp_memh); + free(mh); return ncclSuccess; } -ncclResult_t nccl_ucx_rma_listen(int dev, void *handle, void **listen_comm) -{ - ucx_rma_listen_handle_t *my_handle = (ucx_rma_listen_handle_t*)handle; - nccl_ucx_rma_listen_comm_t *comm; - - NCCL_STATIC_ASSERT(sizeof(ucx_rma_listen_handle_t) < NCCL_NET_HANDLE_MAXSIZE, - "UCX-RMA listen handle size too large"); +static ncclResult_t nccl_ucp_flush_ep_init(nccl_ucp_comm_t *comm) { + ucp_worker_attr_t attr = {.field_mask = UCP_WORKER_ATTR_FIELD_ADDRESS | + UCP_WORKER_ATTR_FIELD_ADDRESS_FLAGS, + .address_flags = UCP_WORKER_ADDRESS_FLAG_NET_ONLY}; + ucp_ep_params_t params; - my_handle->magic = NCCL_SOCKET_MAGIC; - NCCLCHECK(ncclIbMalloc((void**)&comm, sizeof(nccl_ucx_rma_listen_comm_t))); - NCCLCHECK(ncclSocketInit(&comm->sock, &nccl_ucx_if_addr, my_handle->magic, ncclSocketTypeNetIb, NULL, 1)); - NCCLCHECK(ncclSocketListen(&comm->sock)); - NCCLCHECK(ncclSocketGetAddr(&comm->sock, &my_handle->connectAddr)); + UCXCHECK(ucp_worker_query(comm->worker->ucp_worker, &attr)); - comm->dev = dev; - *listen_comm = comm; - + params.field_mask = + UCP_EP_PARAM_FIELD_REMOTE_ADDRESS | UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE; + params.address = attr.address; + params.err_mode = UCP_ERR_HANDLING_MODE_PEER; /* Mandatory with force close */ + UCXCHECK( + ucp_ep_create(comm->worker->ucp_worker, ¶ms, &comm->ucp_flush_ep)); + free(attr.address); return ncclSuccess; } -static ucs_status_t nccl_ucx_rma_am_rkey_cb(void *arg, void *data, size_t length, - ucp_ep_h reply_ep, unsigned flags) -{ - nccl_ucx_rma_send_comm_t *comm = (nccl_ucx_rma_send_comm_t*)arg; - nccl_ucx_rma_rkey_buf_t *rkey_buf = (nccl_ucx_rma_rkey_buf_t*)data; - ucs_status_t status; - - if (comm->rkeys[rkey_buf->index].rkey) { - ucp_rkey_destroy(comm->rkeys[rkey_buf->index].rkey); - } - comm->rkeys[rkey_buf->index].id = rkey_buf->id; - status = ucp_ep_rkey_unpack(comm->ep, rkey_buf->buf, - &comm->rkeys[rkey_buf->index].rkey); - if (status != UCS_OK) { - WARN("Failed: UCX am rkey cb: rkey unpack error %s", - ucs_status_string(status)); +static nccl_ucp_comm_t *nccl_ucp_comm_create(int dev, + nccl_ucp_req_type_t type) { + nccl_ucp_comm_t *comm = calloc(1, sizeof(*comm)); + if (comm == NULL) { + return comm; } - return UCS_OK; -} + comm->worker = nccl_ucp_worker_get(dev); + if (comm->worker == NULL) { + goto err; + } + comm->local.share_mh = nccl_ucp_mem_register( + comm, &comm->local.share, sizeof(comm->local.share), NCCL_PTR_HOST); + if (comm->local.share_mh == NULL) { + goto err; + } -ncclResult_t nccl_ucx_rma_connect(int dev, void *handle, void **send_comm, ncclNetDeviceHandle_t** sendDevComm) -{ - ucx_rma_listen_handle_t *recv_handle = (ucx_rma_listen_handle_t*)handle; - struct ncclUCXCommStage* stage = &recv_handle->stage; - nccl_ucx_rma_send_comm_t *comm = stage->comm; - ucp_mem_map_params_t mmap_params; - size_t rkey_buf_size; - void *rkey_buf; - uint64_t fifo_adr; - int i; - int ready; + comm->type = type; + comm->dev = dev; + comm->rtr_id = 1; + comm->req_id = 1; + comm->rkey_id = 1; + comm->delay_atp = !!ncclParamUCXAckDelay(); + comm->gpu_flush = (nccl_p2p_gdr_support(comm->dev) == ncclSuccess) || + (nccl_p2p_dmabuf_support(comm->dev) == ncclSuccess); + if (comm->gpu_flush && (nccl_ucp_flush_ep_init(comm) != ncclSuccess)) { + nccl_ucx_rma_deregmr(comm, comm->local.share_mh); + goto err; + } - *send_comm = NULL; + return comm; - if (stage->state == ncclUCXCommStateConnect) goto ucx_connect_check; - - NCCLCHECK(ncclIbMalloc((void**)&comm, sizeof(*comm))); - NCCLCHECK(ncclSocketInit(&comm->super.sock, &recv_handle->connectAddr, recv_handle->magic, ncclSocketTypeNetIb, NULL, 1)); - stage->comm = comm; - stage->state = ncclUCXCommStateConnect; - NCCLCHECK(ncclSocketConnect(&comm->super.sock)); - -ucx_connect_check: - /* since ncclSocketConnect is async, we must check if connection is complete */ - NCCLCHECK(ncclSocketReady(&comm->super.sock, &ready)); - if (!ready) return ncclSuccess; - - NCCLCHECK(nccl_ucx_rma_init_comm_context(dev, &comm->super)); - NCCLCHECK(nccl_ucx_rma_send_worker_address(comm->super.worker, &comm->super.sock)); - NCCLCHECK(nccl_ucx_add_ep(comm->super.worker, &comm->super.sock)); - UCXCHECK(ucp_worker_set_am_handler(comm->super.worker, comm->super.id, - nccl_ucx_rma_am_rkey_cb, comm, - UCP_AM_FLAG_WHOLE_MSG)); - for (i = 0; i < NCCL_UCX_RMA_MAX_MHANDLES; i++) { - comm->rkeys[i].id = -1; - } - fifo_adr = (uint64_t)comm->fifo; - mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | - UCP_MEM_MAP_PARAM_FIELD_LENGTH; - mmap_params.address = (void*)fifo_adr; - mmap_params.length = sizeof(ucx_rma_send_fifo_t) * - MAX_REQUESTS; - UCXCHECK(ucp_mem_map(comm->super.ctx, &mmap_params, &comm->fifo_memh)); - UCXCHECK(ucp_rkey_pack(comm->super.ctx, comm->fifo_memh, &rkey_buf, &rkey_buf_size)); - NCCLCHECK(ncclSocketSend(&comm->super.sock, &rkey_buf_size, sizeof(size_t))); - NCCLCHECK(ncclSocketSend(&comm->super.sock, rkey_buf, rkey_buf_size)); - NCCLCHECK(ncclSocketSend(&comm->super.sock, &fifo_adr, sizeof(uint64_t))); - NCCLCHECK(ncclSocketSend(&comm->super.sock, &comm->super.id, sizeof(comm->super.id))); - ucp_rkey_buffer_release(rkey_buf); - *send_comm = comm; +err: + free(comm); + return NULL; +} +static ncclResult_t nccl_ucp_ep_create(nccl_ucp_comm_t *comm) { + ucp_ep_params_t params = { + .field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS | + UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE, + .address = (void*)comm->peer.address, + .err_mode = UCP_ERR_HANDLING_MODE_PEER /* Mandatory with force close */ + }; + + UCXCHECK(ucp_ep_create(comm->worker->ucp_worker, ¶ms, &comm->ucp_ep)); + UCXCHECK(ucp_ep_rkey_unpack(comm->ucp_ep, comm->peer.share_rkey, + &comm->remote.rkey)); + comm->remote.share = comm->peer.share; return ncclSuccess; } -ncclResult_t nccl_ucx_rma_connect_v6(int dev, void *handle, void **send_comm) -{ - ncclNetDeviceHandle_v7_t* dev_handle = NULL; - return nccl_ucx_rma_connect(dev, handle, send_comm, &dev_handle); +static ncclResult_t nccl_ucx_rma_address_send(nccl_ucp_comm_t *comm) { + struct nccl_ucp_address peer; + + assert(comm->worker->address_length <= sizeof(peer.address)); + assert(comm->local.share_mh->rkey_buf_size <= sizeof(peer.share_rkey)); + + peer.comm = comm; + peer.address_length = comm->worker->address_length; + peer.share_rkey_length = comm->local.share_mh->rkey_buf_size; + peer.share = &comm->local.share; + memcpy(peer.address, comm->worker->address, comm->worker->address_length); + memcpy(peer.share_rkey, comm->local.share_mh->rkey_buf, + comm->local.share_mh->rkey_buf_size); + + return ncclSocketSend(&comm->sock, &peer, sizeof(peer)); } -enum { - NCCL_UCX_RMA_REQUEST_INPROGRESS = 0, - NCCL_UCX_RMA_REQUEST_PUT_DONE = 1, - NCCL_UCX_RMA_REQUEST_AM_DONE = 2, - NCCL_UCX_RMA_REQUEST_DONE = 3, -}; +static ncclResult_t nccl_ucx_rma_connect(int dev, void *listen_handle, + void **send_comm, + ncclNetDeviceHandle_t **sendDevComm) { + nccl_ucp_listen_handle_t *handle = listen_handle; + nccl_ucp_stage_t *stage = &handle->stage; + nccl_ucp_comm_t *comm = stage->comm; + int ready = 0; -static ucs_status_t nccl_ucx_rma_am_cb(void *arg, void *data, size_t length, - ucp_ep_h reply_ep, unsigned flags) -{ - nccl_ucx_rma_request_t *reqs = (nccl_ucx_rma_request_t*)arg; - uint64_t *header = data; - int size = *header & 0xFFFFFFFFFFFFFFFF; - int id = *header >>32 ; + *send_comm = NULL; - reqs[id].size = size; - reqs[id].done = NCCL_UCX_RMA_REQUEST_DONE; + switch (stage->state) { + case NCCL_UCP_START: + comm = nccl_ucp_comm_create(dev, NCCL_UCP_TYPE_ISEND); + stage->comm = comm; + if (stage->comm == NULL) { + return ncclSystemError; + } - return UCS_OK; -} + NCCLCHECK(ncclSocketInit(&stage->comm->sock, &handle->listener.addr, + handle->magic, ncclSocketTypeNetIb, NULL, 1)); + NCCLCHECK(ncclSocketConnect(&stage->comm->sock)); -static ncclResult_t nccl_ucx_rma_init_ep(struct ncclSocket *sock, ucp_worker_h worker, ucp_ep_h *ep, int blocking) -{ - int bytes = 0; - ucp_ep_params_t ep_params; - size_t peer_addr_len; - void *peer_addr; + stage->state = NCCL_UCP_CONNECT; + /* fallthrough */ - if (blocking) { - NCCLCHECK(ncclSocketRecv(sock, &peer_addr_len, sizeof(size_t))); - } else { - NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, sock, &peer_addr_len, - sizeof(size_t), &bytes)); - if (bytes == 0) { - ep = NULL; + case NCCL_UCP_CONNECT: + NCCLCHECK(ncclSocketReady(&stage->comm->sock, &ready)); + if (!ready) { return ncclSuccess; } - NCCLCHECK(ncclSocketWait(NCCL_SOCKET_RECV, sock, &peer_addr_len, - sizeof(size_t), &bytes)); - } - peer_addr = alloca(peer_addr_len); - NCCLCHECK(ncclSocketRecv(sock, peer_addr, peer_addr_len)); - ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; - ep_params.address = peer_addr; - UCXCHECK(ucp_ep_create(worker, &ep_params, ep)); + NCCLCHECK(nccl_ucx_rma_address_send(comm)); + + stage->offset = 0; + stage->state = NCCL_UCP_RECEIVE_REMOTE; + /* fallthrough */ + + case NCCL_UCP_RECEIVE_REMOTE: + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->sock, &comm->peer, + sizeof(comm->peer), &stage->offset)); + if (stage->offset != sizeof(comm->peer)) { + return ncclSuccess; + } + + NCCLCHECK(nccl_ucp_ep_create(comm)); + + ready = 1; + NCCLCHECK(ncclSocketSend(&comm->sock, &ready, sizeof(ready))); + + *send_comm = comm; + stage->ready = 0; + stage->offset = 0; + stage->state = NCCL_UCP_DONE; + INFO(NCCL_INIT | NCCL_NET, + "Connected comm=%p remote_comm=%p listener_id=%d " + "ack_delay=%d ack_skip=%d", + comm, comm->peer.comm, handle->listener.id, ncclParamUCXAckDelay(), + ncclParamUCXAckSkip()); + break; + + default: + break; + } return ncclSuccess; } -ncclResult_t nccl_ucx_rma_accept(void *listen_comm, void **recv_comm, ncclNetDeviceHandle_v7_t** recvDevComm) -{ - nccl_ucx_rma_listen_comm_t *l_comm = (nccl_ucx_rma_listen_comm_t *)listen_comm; - struct ncclUCXCommStage* stage = &l_comm->stage; - nccl_ucx_rma_recv_comm_t *r_comm = stage->comm; - void *rkey_buf; - size_t rkey_buf_size; - int ready; - +ncclResult_t nccl_ucx_rma_accept(void *listen_comm, void **recv_comm, + ncclNetDeviceHandle_v7_t **recvDevComm) { + nccl_ucp_listen_comm_t *l_comm = listen_comm; + nccl_ucp_stage_t *stage = &l_comm->stage; + nccl_ucp_comm_t *comm = stage->comm; + int ready = 0; + *recv_comm = NULL; - if (stage->state == ncclUCXCommStateAccept) goto ucx_accept_check; - NCCLCHECK(ncclIbMalloc((void**)&r_comm, sizeof(nccl_ucx_rma_recv_comm_t))); - stage->comm = r_comm; - stage->state = ncclUCXCommStateAccept; - l_comm->sock.asyncFlag = 1; - r_comm->super.sock.asyncFlag = 1; + switch (stage->state) { + case NCCL_UCP_START: + comm = nccl_ucp_comm_create(l_comm->dev, NCCL_UCP_TYPE_IRECV); + stage->comm = comm; + if (stage->comm == NULL) { + return ncclSystemError; + } - NCCLCHECK(ncclSocketInit(&r_comm->super.sock, NULL, NCCL_SOCKET_MAGIC, ncclSocketTypeUnknown, NULL, 0)); - NCCLCHECK(ncclSocketAccept(&r_comm->super.sock, &l_comm->sock)); + NCCLCHECK(ncclSocketInit(&comm->sock, NULL, NCCL_UCP_HANDLE_MAGIC, + ncclSocketTypeUnknown, NULL, 0)); + NCCLCHECK(ncclSocketAccept(&comm->sock, &l_comm->sock)); -ucx_accept_check: - NCCLCHECK(ncclSocketReady(&r_comm->super.sock, &ready)); - if (!ready) return ncclSuccess; + stage->state = NCCL_UCP_ACCEPT; + /* fallthrough */ - NCCLCHECK(nccl_ucx_rma_init_comm_context(l_comm->dev, &r_comm->super)); - UCXCHECK(ucp_worker_set_am_handler(r_comm->super.worker, r_comm->super.id, - nccl_ucx_rma_am_cb, r_comm->super.reqs, - UCP_AM_FLAG_WHOLE_MSG)); + case NCCL_UCP_ACCEPT: + NCCLCHECK(ncclSocketReady(&comm->sock, &ready)); + if (!ready) { + return ncclSuccess; + } - NCCLCHECK(nccl_ucx_rma_init_ep(&r_comm->super.sock, r_comm->super.worker, &r_comm->ep, 1)); - NCCLCHECK(nccl_ucx_add_ep(r_comm->super.worker, &r_comm->super.sock)); - NCCLCHECK(ncclSocketRecv(&r_comm->super.sock, &rkey_buf_size, sizeof(size_t))); + stage->offset = 0; + stage->state = NCCL_UCP_RECEIVE_REMOTE; + /* fallthrough */ - rkey_buf = malloc(rkey_buf_size); - if (rkey_buf == NULL) { - return ncclSystemError; - } - NCCLCHECK(ncclSocketRecv(&r_comm->super.sock, rkey_buf, rkey_buf_size)); - NCCLCHECK(ncclSocketRecv(&r_comm->super.sock, &r_comm->rem_fifo.addr, sizeof(uint64_t))); - NCCLCHECK(ncclSocketRecv(&r_comm->super.sock, &r_comm->rem_am_id, sizeof(int))); - UCXCHECK(ucp_ep_rkey_unpack(r_comm->ep, rkey_buf, &r_comm->rem_fifo.rkey)); - free(rkey_buf); - - if (nccl_p2p_gdr_support(l_comm->dev) == ncclSuccess) { - r_comm->super.gpuFlush.enabled = 1; - } + case NCCL_UCP_RECEIVE_REMOTE: + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->sock, &comm->peer, + sizeof(comm->peer), &stage->offset)); + if (stage->offset != sizeof(comm->peer)) { + return ncclSuccess; + } - if (r_comm->super.gpuFlush.enabled) { - ucp_worker_attr_t attr; - ucp_ep_params_t ep_params; + NCCLCHECK(nccl_ucp_ep_create(comm)); + NCCLCHECK(nccl_ucx_rma_address_send(comm)); - attr.field_mask = UCP_WORKER_ATTR_FIELD_ADDRESS | - UCP_WORKER_ATTR_FIELD_ADDRESS_FLAGS; - attr.address_flags = UCP_WORKER_ADDRESS_FLAG_NET_ONLY; + stage->ready = 0; + stage->offset = 0; + stage->state = NCCL_UCP_RX_READY; + /* fallthrough */ - UCXCHECK(ucp_worker_query(r_comm->super.worker, &attr)); - ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; - ep_params.address = attr.address; - UCXCHECK(ucp_ep_create(r_comm->super.worker, &ep_params, - &r_comm->super.gpuFlush.flush_ep)); + case NCCL_UCP_RX_READY: + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->sock, &stage->ready, + sizeof(stage->ready), &stage->offset)); + if (stage->offset != sizeof(stage->ready)) { + return ncclSuccess; /* In progress */ + } - free(attr.address); + assert(stage->ready == 1); + *recv_comm = comm; + stage->state = NCCL_UCP_DONE; + INFO(NCCL_INIT | NCCL_NET, + "Accepted comm=%p peer_comm=%p listener_id=%d ack_delay=%d " + "ack_skip=%d", + comm, comm->peer.comm, l_comm->id, ncclParamUCXAckDelay(), + ncclParamUCXAckSkip()); + break; + + default: + break; } - r_comm->super.num_mh = 0; - *recv_comm = r_comm; return ncclSuccess; } -ncclResult_t nccl_ucx_rma_accept_v6(void *listen_comm, void **recv_comm) -{ - ncclNetDeviceHandle_v7_t* dev_handle = NULL; - return nccl_ucx_rma_accept(listen_comm, recv_comm, &dev_handle); +static void nccl_ucp_rdma_callback(void *request, ucs_status_t status, + void *data) { + int *inflight = data; + assert(status == UCS_OK); + assert(*inflight > 0); + (*inflight)--; + ucp_request_free(request); } -#define REG_ALIGN (4096) -ncclResult_t nccl_ucx_rma_regmr(void* comm, void* data, size_t size, int type, - void** mhandle) -{ - nccl_ucx_rma_ctx_t *ctx = (nccl_ucx_rma_ctx_t*)comm; - uint64_t addr = (uint64_t)data; - ucp_mem_map_params_t mmap_params; - ucx_rma_mhandle_t *mh; - uint64_t reg_addr, reg_size; - void *rkey_buf; - int i; - - for (i = 0; i < NCCL_UCX_RMA_MAX_MHANDLES; i++) { - if (ctx->mh[i] == NULL) { - break; - } - } - if (i == NCCL_UCX_RMA_MAX_MHANDLES) { - WARN("NET UCX/RMA: too many mhandles"); - return ncclSystemError; - } +static void nccl_ucp_rdma_isend_callback(void *request, ucs_status_t status, + void *data) { + nccl_ucp_req_t *req = data; - NCCLCHECK(ncclIbMalloc((void**)&mh, sizeof(ucx_rma_mhandle_t))); - reg_addr = addr & (~(REG_ALIGN - 1)); - reg_size = addr + size - reg_addr; - reg_size = ROUNDUP(reg_size, REG_ALIGN); - - mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | - UCP_MEM_MAP_PARAM_FIELD_LENGTH; - mmap_params.address = (void*)reg_addr; - mmap_params.length = reg_size; - mh->mem_type = (type == NCCL_PTR_HOST)? UCS_MEMORY_TYPE_HOST: UCS_MEMORY_TYPE_CUDA; - mmap_params.field_mask |= UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE; - mmap_params.memory_type = mh->mem_type; - - UCXCHECK(ucp_mem_map(ctx->ctx, &mmap_params, &mh->ucp_memh)); - UCXCHECK(ucp_rkey_pack(ctx->ctx, mh->ucp_memh, &rkey_buf, &mh->rkey_buf.rkey_buf_size)); - if (mh->rkey_buf.rkey_buf_size > MAX_UCX_RKEY_BUF_SIZE) { - WARN("NET UCX/RMA: rkey_buf is too large"); - ucp_mem_unmap(ctx->ctx, mh->ucp_memh); - ucp_rkey_buffer_release(rkey_buf); - free(mh); - return ncclSystemError; - } - memcpy(mh->rkey_buf.buf, rkey_buf, mh->rkey_buf.rkey_buf_size); - - if (ctx->gpuFlush.enabled) { - UCXCHECK(ucp_ep_rkey_unpack(ctx->gpuFlush.flush_ep, rkey_buf, &mh->rkey)); - } - - mh->rkey_buf.index = i; - mh->rkey_buf.send = 0; - mh->rkey_buf.id = ctx->num_mh; - ctx->mh[i] = mh; - ctx->num_mh += 1; - *mhandle = mh; - ucp_rkey_buffer_release(rkey_buf); - - return ncclSuccess; + nccl_ucp_rdma_callback(request, status, &req->inflight); + req->comm->local.share.atp[req->rtr_id & NCCL_UCP_RING_MASK].inflight--; } -ncclResult_t nccl_ucx_rma_regmr_v7(void* comm, void* data, int size, int type, - void** mhandle) -{ - return nccl_ucx_rma_regmr(comm, data, (size_t)size, type, mhandle); +static ucs_status_t nccl_ucp_shared_put(nccl_ucp_comm_t *comm, void *va, + size_t size, void *rva, int *inflight) { + ucp_request_param_t param = {}; + ucs_status_ptr_t status_ptr; + + assert((rva >= (void*)comm->remote.share) && + (rva + size) <= + ((void*)comm->remote.share + sizeof(*comm->remote.share))); + + param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_USER_DATA | UCP_OP_ATTR_FIELD_MEMH | + UCP_OP_ATTR_FIELD_MEMORY_TYPE; + param.cb.send = nccl_ucp_rdma_callback; + param.user_data = inflight; + param.memh = comm->local.share_mh->ucp_memh; + param.memory_type = comm->local.share_mh->mem_type; + + status_ptr = ucp_put_nbx(comm->ucp_ep, va, size, (uint64_t)rva, + comm->remote.rkey, ¶m); + return UCS_PTR_STATUS(status_ptr); } -ncclResult_t nccl_ucx_rma_regmr_dmabuf(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle) { - return nccl_ucx_rma_regmr(comm, data, size, type, mhandle); -} +static ncclResult_t nccl_ucp_mh_update(nccl_ucp_comm_t *comm, + nccl_ucp_memh_t *mh) { + ucs_status_t status; + nccl_ucp_packed_rkey_t *packed, *remote; + + if (!mh->sent) { + packed = &comm->local.share.packed_rkey[mh->rkey_id]; + remote = &comm->remote.share->packed_rkey[mh->rkey_id]; -ncclResult_t nccl_ucx_rma_deregmr(void* comm, void* mhandle) -{ - nccl_ucx_rma_ctx_t *ctx = (nccl_ucx_rma_ctx_t*)comm; - ucx_rma_mhandle_t *mh = (ucx_rma_mhandle_t*)mhandle; + packed->rkey_buf_size = mh->rkey_buf_size; + packed->rkey_id_start = mh->rkey_id; + packed->rkey_id_end = mh->rkey_id; + memcpy(packed->rkey_buf, mh->rkey_buf, mh->rkey_buf_size); - ctx->mh[mh->rkey_buf.index] = NULL; - if (ctx->gpuFlush.enabled) { - ucp_rkey_destroy(mh->rkey); + status = nccl_ucp_shared_put(comm, packed, sizeof(*packed), remote, + &comm->inflight_rkey); + if (UCS_STATUS_IS_ERR(status)) { + WARN("Failed to send packed rkey"); + return ncclSystemError; + } + + comm->inflight_rkey += (status == UCS_INPROGRESS); + mh->sent = 1; } - ucp_mem_unmap(ctx->ctx, mh->ucp_memh); - free(mh); return ncclSuccess; } -ncclResult_t ucx_rma_get_request(nccl_ucx_rma_request_t* reqs, int* req_id) -{ - nccl_ucx_rma_request_t *r; - int i; +static ncclResult_t nccl_ucx_rma_regmr(void *reg_comm, void *data, size_t size, + int type, void **mhandle) { + nccl_ucp_comm_t *comm = reg_comm; + nccl_ucp_memh_t *mh; - for (i = 0; i < MAX_REQUESTS; i++) { - r = reqs + i; - if (r->used == 0) { - r->used = 1; - r->type = 0; - r->done = NCCL_UCX_RMA_REQUEST_INPROGRESS; - r->size = -1; - r->free = 0; - r->st = NULL; - *req_id = i; - return ncclSuccess; + mh = nccl_ucp_mem_register(comm, data, size, type); + if (mh) { + mh->rkey_id = comm->rkey_id++; + assert(mh->rkey_id < NCCL_UCP_RKEY_COUNT); + + if (comm->gpu_flush) { + UCXCHECK(ucp_ep_rkey_unpack(comm->ucp_flush_ep, mh->rkey_buf, &mh->rkey)); } } - WARN("NET/UCX_RMA: unable to allocate requests"); - *req_id = -1; - return ncclInternalError; + *mhandle = mh; + return *mhandle ? ncclSuccess : ncclSystemError; } -static void nccl_ucx_rma_ep_flush_cb(void *request, ucs_status_t status) -{ - return; +static ncclResult_t nccl_ucx_rma_regmr_dmabuf(void *comm, void *data, + size_t size, int type, + uint64_t offset, int fd, + void **mhandle) { + (void)fd; /* UCX performs the lookup automatically */ + assert(offset == 0); + return nccl_ucx_rma_regmr(comm, data, size, type, mhandle); } -static void nccl_ucx_rma_gdr_flush_cb(void *request, ucs_status_t status) -{ - nccl_ucx_flush_request_t *req = (nccl_ucx_flush_request_t*)request; +static ncclResult_t nccl_ucx_rma_irecv(void *recv_comm, int n, void **data, + int *sizes, int *tags, void **mhandle, + void **request) { + nccl_ucp_comm_t *comm = recv_comm; + nccl_ucp_memh_t **mh = (nccl_ucp_memh_t**)mhandle; + nccl_ucp_req_t *req; + nccl_ucp_rtr_t *rtr; + nccl_ucp_atp_t *atp; + int i; + void *remote; + ucs_status_t status; - req->req->done = NCCL_UCX_RMA_REQUEST_DONE; - return; -} + req = &comm->req[comm->req_id & NCCL_UCP_RING_MASK]; + rtr = &comm->local.share.rtr[comm->rtr_id & NCCL_UCP_RING_MASK]; + atp = &comm->local.share.atp[comm->rtr_id & NCCL_UCP_RING_MASK]; -/* - * nccl_ucx_rma_send_check prepeares send communictor to be used for actual data - * communication and consists of multiple stages: - */ -enum { - NCCL_UCX_RMA_SCOMM_NOT_READY = 0, /* initial comm state, only ucp worker is present - * wait for remote worker addr and create ep - * notify peer that endpoint has been created - */ - NCCL_UCX_RMA_SCOMM_EP_CREATED, /* endpoint is created but it's not gurantee that - * wireup is done. ucp_ep_flush is used to finish - * wireup process - */ - NCCL_UCX_RMA_SCOMM_EP_FLUSH_WAIT, /* ep flush is in progress */ - NCCL_UCX_RMA_SCOMM_READY /* communicator is ready, notify peer */ -}; + assert(n <= NCCL_UCP_MAX_RECV); + assert(req->comm == NULL); -static ncclResult_t nccl_ucx_rma_send_check(nccl_ucx_rma_send_comm_t *comm) -{ - ucs_status_t st; + rtr->id_start = comm->rtr_id; + rtr->count = n; + rtr->avail = n; + rtr->ack = !((*request == (void*)0x1) && ncclParamUCXAckSkip()); - ucp_worker_progress(comm->super.worker); - if (comm->super.ready == NCCL_UCX_RMA_SCOMM_NOT_READY) { - NCCLCHECK(nccl_ucx_rma_init_ep(&comm->super.sock, comm->super.worker, &comm->ep, 0)); - if (comm->ep == NULL) { - return ncclSuccess; - } - NCCLCHECK(ncclSocketRecv(&comm->super.sock, &comm->rem_am_id, sizeof(int))); - comm->super.ready = NCCL_UCX_RMA_SCOMM_EP_CREATED; - } + *request = NULL; - if (comm->super.ready == NCCL_UCX_RMA_SCOMM_EP_CREATED) { - comm->super.check_req = ucp_ep_flush_nb(comm->ep, 0, nccl_ucx_rma_ep_flush_cb); + for (i = 0; i < n; i++) { + NCCLCHECK(nccl_ucp_mh_update(comm, mh[i])); - if (comm->super.check_req == NULL) { - comm->super.ready = NCCL_UCX_RMA_SCOMM_READY; - NCCLCHECK(ncclSocketSend(&comm->super.sock, &comm->super.ready, sizeof(int))); - } else if (UCS_PTR_IS_ERR(comm->super.check_req)) { - return ncclSystemError; - } else { - comm->super.ready = NCCL_UCX_RMA_SCOMM_EP_FLUSH_WAIT; - } + rtr->chunk[i].data = (uint64_t)data[i]; + rtr->chunk[i].rkey_id = mh[i]->rkey_id; + rtr->chunk[i].size = sizes[i]; + rtr->chunk[i].tag = tags[i]; + rtr->chunk[i].id = comm->rtr_id; } - if (comm->super.ready == NCCL_UCX_RMA_SCOMM_EP_FLUSH_WAIT) { - st = ucp_request_check_status(comm->super.check_req); - if (st != UCS_INPROGRESS) { - ucp_request_free(comm->super.check_req); - comm->super.ready = NCCL_UCX_RMA_SCOMM_READY; - NCCLCHECK(ncclSocketSend(&comm->super.sock, &comm->super.ready, sizeof(int))); - } + if (!rtr->ack) { + atp->id_start = comm->rtr_id; + atp->count = n; + atp->inflight = 0; + atp->reqs = 0; + atp->id = comm->rtr_id; + memcpy(atp->sizes, sizes, sizeof(*sizes) * n); + } + + remote = &comm->remote.share->rtr[comm->rtr_id & NCCL_UCP_RING_MASK]; + status = nccl_ucp_shared_put( + comm, rtr, sizeof(*rtr) - (NCCL_UCP_MAX_RECV - n) * sizeof(*rtr->chunk), + remote, &req->inflight); + if (!UCS_STATUS_IS_ERR(status)) { + req->comm = comm; + req->type = NCCL_UCP_TYPE_IRECV; + req->rtr_id = comm->rtr_id; + req->inflight = (status == UCS_INPROGRESS); + + comm->rtr_id++; + comm->req_id++; + comm->total++; + + *request = req; } return ncclSuccess; } -/* - * nccl_ucx_rma_recv_check prepeares recv communictor to be used for actual data - * communication and consists of multiple stages: - */ -enum { - NCCL_UCX_RMA_RCOMM_SEND_CONN_INFO = 0, /* initial stage, send worker address to peer */ - NCCL_UCX_RMA_RCOMM_WAIT_SCOMM, /* wait for send communicator ready notification */ - NCCL_UCX_RMA_RCOMM_READY, /* recv comm ready */ -}; - -static ncclResult_t nccl_ucx_rma_recv_check(nccl_ucx_rma_recv_comm_t *comm) -{ - int bytes = 0; - int rem_comm_state; - - ucp_worker_progress(comm->super.worker); +static ucp_rkey_h nccl_ucp_rkey_get(nccl_ucp_comm_t *comm, + unsigned short rkey_id) { + nccl_ucp_rkey_t *nccl_rkey; + nccl_ucp_packed_rkey_t *packed; + ucs_status_t status; - if (comm->super.ready == NCCL_UCX_RMA_RCOMM_SEND_CONN_INFO) { - NCCLCHECK(nccl_ucx_rma_send_worker_address(comm->super.worker, &comm->super.sock)); - NCCLCHECK(ncclSocketSend(&comm->super.sock, &comm->super.id, sizeof(int))); - comm->super.ready = NCCL_UCX_RMA_RCOMM_WAIT_SCOMM; - } + assert(rkey_id < NCCL_UCP_RKEY_COUNT); + nccl_rkey = &comm->rkey[rkey_id]; + if (nccl_rkey->rkey_id != rkey_id) { + /* Try to unpack */ + __sync_synchronize(); + packed = &comm->local.share.packed_rkey[rkey_id]; + if ((packed->rkey_id_start != rkey_id) || + (packed->rkey_id_end != rkey_id)) { + return NULL; + } - if (comm->super.ready == NCCL_UCX_RMA_RCOMM_WAIT_SCOMM) { - NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->super.sock, &rem_comm_state, - sizeof(int), &bytes)); - if (bytes == 0) { - return ncclSuccess; + status = + ucp_ep_rkey_unpack(comm->ucp_ep, packed->rkey_buf, &nccl_rkey->rkey); + if (status != UCS_OK) { + return NULL; } - NCCLCHECK(ncclSocketWait(NCCL_SOCKET_RECV, &comm->super.sock, &rem_comm_state, - sizeof(int), &bytes)); - if (rem_comm_state == NCCL_UCX_RMA_SCOMM_READY) { - comm->super.ready = NCCL_UCX_RMA_RCOMM_READY; - } else { - WARN("Unexpected socket msg %d (%d)", rem_comm_state, NCCL_UCX_RMA_SCOMM_READY); - return ncclSystemError; - } + nccl_rkey->rkey_id = rkey_id; } - return ncclSuccess; -} - -static void nccl_ucx_rma_am_isend_cb(void *request, ucs_status_t status) -{ - nccl_ucx_am_request_t *req = (nccl_ucx_am_request_t*)request; - - req->req->done |= NCCL_UCX_RMA_REQUEST_AM_DONE; - return; + return nccl_rkey->rkey; } -static void nccl_ucx_rma_put_isend_cb(void *request, ucs_status_t status, void *data) -{ - nccl_ucx_rma_request_t *req = (nccl_ucx_rma_request_t*)data; - - req->done |= NCCL_UCX_RMA_REQUEST_PUT_DONE; - return; -} +static ncclResult_t nccl_ucp_send(nccl_ucp_comm_t *comm, unsigned short id, + unsigned i, void *data, int size, + nccl_ucp_memh_t *mh, void **request) { + nccl_ucp_req_t *req; + nccl_ucp_rtr_t *rtr; + nccl_ucp_atp_t *atp; + ucs_status_ptr_t status_ptr; + ucp_request_param_t param; + ucp_rkey_h rkey; -ncclResult_t nccl_ucx_rma_isend(void *send_comm, void *data, int size, int tag, - void *mhandle, void **request) -{ - nccl_ucx_rma_send_comm_t *comm = (nccl_ucx_rma_send_comm_t*)send_comm; - ucx_rma_mhandle_t *mh = (ucx_rma_mhandle_t*)mhandle; - volatile ucx_rma_send_fifo_t *slot; - volatile uint32_t *ready_ptr; - volatile int *rkey_id; - volatile int *rkey_index; - nccl_ucx_rma_request_t *req; - ucs_status_ptr_t st; - int req_id; - ucp_request_param_t req_param; - - if (comm->super.ready != NCCL_UCX_RMA_SCOMM_READY) { - NCCLCHECK(nccl_ucx_rma_send_check(comm)); - if (comm->super.ready != NCCL_UCX_RMA_SCOMM_READY) { - *request = NULL; - return ncclSuccess; - } + *request = NULL; + atp = &comm->local.share.atp[id & NCCL_UCP_RING_MASK]; + rtr = &comm->local.share.rtr[id & NCCL_UCP_RING_MASK]; + req = &comm->req[comm->req_id & NCCL_UCP_RING_MASK]; + assert(req->comm == NULL); + assert(rtr->avail > 0); + assert(rtr->id_start == id); + + rkey = nccl_ucp_rkey_get(comm, rtr->chunk[i].rkey_id); + if (rkey == NULL) { + return ncclSuccess; } - slot = comm->fifo + (comm->fifo_head % MAX_REQUESTS); - ready_ptr = &slot->ready; - rkey_id = &slot->rkey_id; - rkey_index = &slot->rkey_idx; - - if ((*ready_ptr == 0) || - (comm->rkeys[*rkey_index].id != *rkey_id)) { - ucp_worker_progress(comm->super.worker); - *request = NULL; + param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_USER_DATA | UCP_OP_ATTR_FIELD_MEMH | + UCP_OP_ATTR_FIELD_MEMORY_TYPE; + param.cb.send = nccl_ucp_rdma_isend_callback; + param.user_data = req; + param.memh = mh->ucp_memh; + param.memory_type = mh->mem_type; + + status_ptr = + ucp_put_nbx(comm->ucp_ep, data, size, rtr->chunk[i].data, rkey, ¶m); + if (UCS_PTR_IS_ERR(status_ptr)) { return ncclSuccess; } - NCCLCHECK(ucx_rma_get_request(comm->super.reqs, &req_id)); - req = &(comm->super.reqs[req_id]); - req->size = size; - - req_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | - UCP_OP_ATTR_FIELD_REQUEST | - UCP_OP_ATTR_FIELD_USER_DATA; - req_param.cb.send = nccl_ucx_rma_put_isend_cb; - req_param.user_data = req; - req_param.request = &req->used; - if (mh) { - req_param.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMORY_TYPE; - req_param.memory_type = mh->mem_type; - } - - st = ucp_put_nbx(comm->ep, data, size, slot->addr, - comm->rkeys[*rkey_index].rkey, - &req_param); - - if (UCS_PTR_IS_ERR(st)) { - WARN("NET/UCX_RMA: isend pub_nb failed"); - return ncclInternalError; - } else if (st == NULL) { - req->done |= NCCL_UCX_RMA_REQUEST_PUT_DONE; + if (rtr->avail == rtr->count) { + assert(atp->reqs == 0); + assert(atp->inflight == 0); + atp->id_start = rtr->id_start; + atp->count = 0; + memset(atp->sizes, 0, sizeof(atp->sizes)); + atp->id = rtr->id_start; } - ucp_worker_fence(comm->super.worker); - req->am_msg = (((uint64_t)slot->req_id) << 32) | ((uint64_t)size); - req->st = ucp_am_send_nb(comm->ep, comm->rem_am_id, &req->am_msg, 8, - ucp_dt_make_contig(1), nccl_ucx_rma_am_isend_cb, 0); + req->comm = comm; + req->type = NCCL_UCP_TYPE_ISEND; + req->rtr_id = rtr->id_start; + req->inflight = UCS_PTR_IS_PTR(status_ptr); + atp->inflight += req->inflight; + atp->sizes[i] = size; + atp->count++; + atp->reqs++; - if (req->st == NULL) { - req->done |= NCCL_UCX_RMA_REQUEST_AM_DONE; - } else if (UCS_PTR_IS_PTR(req->st)) { - nccl_ucx_am_request_t *am_req = (nccl_ucx_am_request_t*)req->st; - am_req->req = req; - } else { - WARN("NET/UCX_RMA: isend am_send_nb failed"); - } + rtr->avail--; + rtr->chunk[i].tag = INT_MAX; - req->seq = slot->seq; - slot->ready = 0; - slot->addr = 0ULL; - slot->size = 0; - slot->seq = 0; - comm->fifo_head++; - - req->worker = comm->super.worker; - req->type = UCX_RMA_REQ_TYPE_SEND; + comm->req_id++; + comm->total++; *request = req; return ncclSuccess; } -static void nccl_ucx_rma_dummy_am_cb(void *request, ucs_status_t status) -{ - return; -} +static ncclResult_t nccl_ucx_rma_isend(void *send_comm, void *data, int size, + int tag, void *mhandle, void **request) { + ncclResult_t result = ncclSuccess; + nccl_ucp_comm_t *comm = send_comm; + volatile nccl_ucp_rtr_t *rtr; + unsigned short id; + unsigned i; -ncclResult_t nccl_ucx_rma_post_fifo(nccl_ucx_rma_recv_comm_t *comm, - ucx_rma_mhandle_t *mh, - uint64_t addr, int size, int req_id) -{ - ucx_rma_send_fifo_t *local_elem; - nccl_ucx_rma_request_t *req; - uint64_t remote_addr; - ucs_status_t st; - - if (!mh->rkey_buf.send) { - req = &(comm->super.reqs[req_id]); - req->st = ucp_am_send_nb(comm->ep, comm->rem_am_id, &mh->rkey_buf, - sizeof(nccl_ucx_rma_rkey_buf_t), ucp_dt_make_contig(1), - nccl_ucx_rma_dummy_am_cb, 0); - if (UCS_PTR_IS_ERR(req->st)) { - WARN("NET/UCX_RMA: am_send_nb failed"); - return ncclInternalError; + *request = NULL; + + assert(tag != INT_MAX); + for (id = comm->rtr_id;; id++) { + rtr = &comm->local.share.rtr[id & NCCL_UCP_RING_MASK]; + if ((rtr->id_start != id) || (rtr->chunk->id != id)) { + break; } - mh->rkey_buf.send = 1; - } - local_elem = comm->rem_fifo.elems + (comm->rem_fifo.tail % MAX_REQUESTS); - local_elem->addr = addr; - local_elem->ready = 1; - local_elem->size = size; - local_elem->seq = comm->rem_fifo.tail; - local_elem->rkey_idx = mh->rkey_buf.index; - local_elem->rkey_id = mh->rkey_buf.id; - local_elem->req_id = req_id; - - remote_addr = comm->rem_fifo.addr + (comm->rem_fifo.tail % MAX_REQUESTS) * - sizeof(ucx_rma_send_fifo_t); - st = ucp_put_nbi(comm->ep, (void*)local_elem, sizeof(ucx_rma_send_fifo_t), - remote_addr, comm->rem_fifo.rkey); - if (st < 0) { - WARN("ucx_rma post_fifo pub_nbi failed %d", (int)st); - return ncclInternalError; + for (i = 0; i < rtr->count; i++) { + while (rtr->chunk[i].id != id) { + __sync_synchronize(); + } + } + + if (rtr->avail < 1) { + if (id == comm->rtr_id) { + comm->rtr_id++; + } + continue; + } + + for (i = 0; i < rtr->count; i++) { + if (rtr->chunk[i].tag == tag) { + result = nccl_ucp_send(comm, id, i, data, size, mhandle, request); + goto out; + } + } } - comm->rem_fifo.tail++; +out: + if ((*request == NULL) && (comm->total == 0)) { + ucp_worker_progress(comm->worker->ucp_worker); + } - return ncclSuccess; + return result; } -ncclResult_t nccl_ucx_rma_irecv(void *recv_comm, int n, void **data,int *tags, int *sizes, - void **mhandle, void **request) -{ - nccl_ucx_rma_recv_comm_t *comm = (nccl_ucx_rma_recv_comm_t*)recv_comm; - ucx_rma_mhandle_t *mh = (ucx_rma_mhandle_t*)mhandle[0]; - nccl_ucx_rma_request_t *req; - int req_id; +static int nccl_ucp_flush_index(nccl_ucp_comm_t *comm, int *sizes, int n) { + int i, last = -1; - if (comm->super.ready != NCCL_UCX_RMA_RCOMM_READY) { - NCCLCHECK(nccl_ucx_rma_recv_check(comm)); + if (comm->gpu_flush) { + for (i = 0; i < n; i++) { + if (sizes[i]) { + last = i; + } + } } - if (comm->super.ready != NCCL_UCX_RMA_RCOMM_READY) { + return last; +} + +static ncclResult_t nccl_ucx_rma_iflush(void *recv_comm, int n, void **data, + int *sizes, void **mhandle, + void **request) { + nccl_ucp_comm_t *comm = recv_comm; + nccl_ucp_memh_t **mh = (nccl_ucp_memh_t**)mhandle; + int last = nccl_ucp_flush_index(comm, sizes, n); + nccl_ucp_req_t *req; + ucs_status_ptr_t status_ptr; + ucp_request_param_t param; + + if (last == -1) { *request = NULL; return ncclSuccess; } - - NCCLCHECK(ucx_rma_get_request(comm->super.reqs, &req_id)); - req = &comm->super.reqs[req_id]; - - req->seq = comm->rem_fifo.tail; - NCCLCHECK(nccl_ucx_rma_post_fifo(comm, mh, (uint64_t)data[0], sizes[0], req_id)); - ucp_worker_progress(comm->super.worker); - req->worker = comm->super.worker; - req->type = UCX_RMA_REQ_TYPE_RECV; + + req = &comm->req[comm->req_id & NCCL_UCP_RING_MASK]; + param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_USER_DATA | UCP_OP_ATTR_FIELD_MEMH | + UCP_OP_ATTR_FIELD_MEMORY_TYPE; + param.cb.send = nccl_ucp_rdma_callback; + param.user_data = &req->inflight; + param.memh = comm->local.share_mh->ucp_memh; + param.memory_type = UCS_MEMORY_TYPE_HOST; + + status_ptr = ucp_get_nbx(comm->ucp_flush_ep, &comm->local.share.dummy_mem, 1, + (uint64_t)data[last], mh[last]->rkey, ¶m); + assert(!UCS_PTR_IS_ERR(status_ptr)); + assert(req->comm == NULL); + + req->type = NCCL_UCP_TYPE_IFLUSH; + req->inflight = (UCS_PTR_STATUS(status_ptr) == UCS_INPROGRESS); + req->comm = comm; + + comm->req_id++; + comm->total++; *request = req; return ncclSuccess; } -ncclResult_t nccl_ucx_rma_iflush(void* recv_comm, int n, void** data, int* sizes, - void** mhandle, void ** request) -{ - nccl_ucx_rma_recv_comm_t *comm = (nccl_ucx_rma_recv_comm_t*)recv_comm; - ucx_rma_mhandle_t *mh = (ucx_rma_mhandle_t*)mhandle[0]; - nccl_ucx_rma_request_t *req; - int req_id; - - *request = NULL; - int last = -1; - for (int i=0; isuper.gpuFlush.enabled == 0 || last == -1) return ncclSuccess; +static void nccl_ucx_rma_close_ep(ucp_ep_h ep) { + void *req; + ucp_request_param_t param = {.op_attr_mask = UCP_OP_ATTR_FIELD_FLAGS, + .flags = UCP_EP_CLOSE_FLAG_FORCE}; - NCCLCHECK(ucx_rma_get_request(comm->super.reqs, &req_id)); - req = &comm->super.reqs[req_id]; + req = ucp_ep_close_nbx(ep, ¶m); + (void)req; + assert(req == NULL); +} - req->st = ucp_get_nb(comm->super.gpuFlush.flush_ep, &comm->super.gpuFlush.hostMem, 1, - (uint64_t)data, mh->rkey, nccl_ucx_rma_gdr_flush_cb); - if (UCS_PTR_IS_ERR(req->st)) { - WARN("ucx_flush: unable to read data (%s)", ucs_status_string(UCS_PTR_STATUS(req))); - return ncclSystemError; - } else if (req->st == NULL) { - return ncclSuccess; - } - nccl_ucx_flush_request_t *flush_req = (nccl_ucx_flush_request_t*)req->st; - flush_req->req = req; +static ncclResult_t nccl_ucx_rma_close_comm(void *close_comm) { + int i; + nccl_ucp_comm_t *comm = close_comm; - req->worker = comm->super.worker; - req->type = UCX_RMA_REQ_TYPE_FLUSH; - *request = req; + assert(comm->total == 0); + assert(comm->inflight_rkey == 0); - return ncclSuccess; -} + nccl_ucx_rma_close_ep(comm->ucp_ep); + if (comm->ucp_flush_ep != NULL) { + assert(comm->gpu_flush); + nccl_ucx_rma_close_ep(comm->ucp_flush_ep); + } -ncclResult_t nccl_ucx_rma_test(void *request, int *done, int *size) -{ - nccl_ucx_rma_request_t *req = (nccl_ucx_rma_request_t*)request; - unsigned p; + for (i = 0; i < NCCL_UCP_RKEY_COUNT; i++) { + if (comm->rkey[i].rkey != NULL) { + ucp_rkey_destroy(comm->rkey[i].rkey); + } + } - *done = 0; - do { - if (req->done == NCCL_UCX_RMA_REQUEST_DONE) { - *done = 1; - if (size) { - *size = req->size; - } - if (req->st != NULL) { - ucp_request_free(req->st); - } - req->used = 0; - return ncclSuccess; + for (i = 0; i < NCCL_UCP_RING_SIZE; i++) { + if (comm->type == NCCL_UCP_TYPE_ISEND) { + assert(comm->local.share.rtr[i].avail < 1); + assert(comm->local.share.atp[i].reqs == 0); + assert(comm->local.share.atp[i].inflight == 0); } - p = ucp_worker_progress(req->worker); - } while (p); + assert(comm->req[i].comm == NULL); + } + + if (comm->remote.rkey != NULL) { + ucp_rkey_destroy(comm->remote.rkey); + } + if (comm->local.share_mh != NULL) { + nccl_ucx_rma_deregmr(comm, comm->local.share_mh); + } + ncclSocketClose(&comm->sock); + nccl_ucp_worker_put(comm->worker); + free(comm); return ncclSuccess; } -static void wait_close(ucp_worker_h worker, nccl_ucx_rma_request_t *req) -{ - ucs_status_t status; - - if (UCS_PTR_IS_PTR(req)) { - do { - ucp_worker_progress(worker); - status = ucp_request_check_status(req); - } while(status == UCS_INPROGRESS); - ucp_request_free(req); - } else if (req != NULL) { - WARN("Failed to close UCX endpoint"); - } +static void nccl_ucp_req_release(nccl_ucp_req_t *req) { + assert(req->comm->total > 0); + req->comm->total--; + req->comm = NULL; } -ncclResult_t nccl_ucx_rma_close_send(void *send_comm) -{ - nccl_ucx_rma_send_comm_t *comm = (nccl_ucx_rma_send_comm_t*) send_comm; - void *close_req; - int close = 1; - int i; +static ncclResult_t nccl_ucx_rma_test(void *request, int *done, int *sizes) { + nccl_ucp_req_t *req = request; + nccl_ucp_comm_t *comm = req->comm; + nccl_ucp_atp_t *atp; + nccl_ucp_rtr_t *rtr; + ucs_status_t status; + void *remote; - if (send_comm) { - ucp_mem_unmap(comm->super.ctx, comm->fifo_memh); + *done = 0; + while (ucp_worker_progress(comm->worker->ucp_worker) != 0) + ; /* nothing */ + + if (req->type == NCCL_UCP_TYPE_ISEND) { + rtr = &comm->local.share.rtr[req->rtr_id & NCCL_UCP_RING_MASK]; + atp = &comm->local.share.atp[req->rtr_id & NCCL_UCP_RING_MASK]; + remote = &comm->remote.share->atp[req->rtr_id & NCCL_UCP_RING_MASK]; + + assert(comm->type == NCCL_UCP_TYPE_ISEND); + assert(rtr->id_start == req->rtr_id); + assert(atp->id_start == req->rtr_id); + + if (rtr->avail == 0) { + if (rtr->ack) { + if (atp->inflight && + (comm->delay_atp || + (ucp_worker_fence(comm->worker->ucp_worker) != UCS_OK))) { + return ncclSuccess; + } - for (i = 0; i < comm->super.num_mh; i++) { - if (comm->rkeys[i].rkey) { - ucp_rkey_destroy(comm->rkeys[i].rkey); + status = nccl_ucp_shared_put(comm, atp, sizeof(*atp), remote, + &req->inflight); + req->inflight += (status == UCS_INPROGRESS); + rtr->avail -= !UCS_STATUS_IS_ERR(status); + } else { + rtr->avail--; } } - if (comm->ep) { - close_req = ucp_ep_close_nb(comm->ep, UCP_EP_CLOSE_MODE_FLUSH); - wait_close(comm->super.worker, close_req); + + *done = (req->inflight == 0) && ((atp->reqs > 1) || (rtr->avail < 0)); + if (*done) { + atp->reqs--; + assert((atp->reqs > 0) || (atp->inflight == 0)); + nccl_ucp_req_release(req); + } + } else if (req->type == NCCL_UCP_TYPE_IRECV) { + assert(comm->type == NCCL_UCP_TYPE_IRECV); + atp = &comm->local.share.atp[req->rtr_id & NCCL_UCP_RING_MASK]; + __sync_synchronize(); + *done = (req->inflight == 0) && (atp->id_start == req->rtr_id) && + (atp->id == req->rtr_id) && + ((comm->total > 1) || (comm->inflight_rkey == 0)); + if (*done) { + if (sizes != NULL) { + memcpy(sizes, atp->sizes, sizeof(*atp->sizes) * atp->count); + } + nccl_ucp_req_release(req); + } + } else { + assert(req->type == NCCL_UCP_TYPE_IFLUSH); + assert(comm->type == NCCL_UCP_TYPE_IRECV); + *done = (req->inflight == 0) && + ((comm->total > 1) || (comm->inflight_rkey == 0)); + if (*done) { + nccl_ucp_req_release(req); } - NCCLCHECK(ncclSocketSend(&comm->super.sock, &close, sizeof(int))); - nccl_ucx_free_worker(comm->super.worker); - free(comm); } return ncclSuccess; } -ncclResult_t nccl_ucx_rma_close_recv(void *recv_comm) -{ - nccl_ucx_rma_recv_comm_t *comm = (nccl_ucx_rma_recv_comm_t*)recv_comm; - void *close_req; - int close = 1; - - if (recv_comm) { - ucp_rkey_destroy(comm->rem_fifo.rkey); - if (comm->super.gpuFlush.enabled) { - close_req = ucp_ep_close_nb(comm->super.gpuFlush.flush_ep, UCP_EP_CLOSE_MODE_FLUSH); - wait_close(comm->super.worker, close_req); - } - if (comm->ep) { - close_req = ucp_ep_close_nb(comm->ep, UCP_EP_CLOSE_MODE_FLUSH); - wait_close(comm->super.worker, close_req); - } - NCCLCHECK(ncclSocketSend(&comm->super.sock, &close, sizeof(int))); - nccl_ucx_free_worker(comm->super.worker); - free(comm); +static ncclResult_t nccl_ucx_rma_regmr_v7(void *comm, void *data, int size, + int type, void **mhandle) { + return nccl_ucx_rma_regmr(comm, data, (size_t)size, type, mhandle); +} + +static ncclResult_t +nccl_ucx_rma_get_properties_v7(int dev, ncclNetProperties_v7_t *props_v7) { + ncclNetProperties_t props; + ncclResult_t ret = nccl_ucx_rma_get_properties(dev, &props); + if (ret != ncclSuccess) { + return ret; } - + props_v7->name = props.name; + props_v7->pciPath = props.pciPath; + props_v7->guid = props.guid; + props_v7->ptrSupport = props.ptrSupport; + props_v7->speed = props.speed; + props_v7->latency = props.latency; + props_v7->port = props.port; + props_v7->maxComms = props.maxComms; + props_v7->maxRecvs = props.maxRecvs; + props_v7->netDeviceType = props.netDeviceType; + props_v7->netDeviceVersion = props.netDeviceVersion; return ncclSuccess; } -ncclResult_t nccl_ucx_rma_close_listen(void *listen_comm) -{ - nccl_ucx_rma_listen_comm_t *comm = (nccl_ucx_rma_listen_comm_t *)listen_comm; - - if (comm) { - close(comm->sock.fd); - free(comm); +static ncclResult_t +nccl_ucx_rma_get_properties_v6(int dev, ncclNetProperties_v6_t *props_v6) { + ncclNetProperties_t props; + ncclResult_t ret = nccl_ucx_rma_get_properties(dev, &props); + if (ret != ncclSuccess) { + return ret; } - + props_v6->name = props.name; + props_v6->pciPath = props.pciPath; + props_v6->guid = props.guid; + props_v6->ptrSupport = props.ptrSupport; + props_v6->speed = props.speed; + props_v6->latency = props.latency; + props_v6->port = props.port; + props_v6->maxComms = props.maxComms; + props_v6->maxRecvs = props.maxRecvs; return ncclSuccess; } +static ncclResult_t nccl_ucx_rma_connect_v6(int dev, void *handle, + void **send_comm) { + ncclNetDeviceHandle_v7_t *dev_handle = NULL; + return nccl_ucx_rma_connect(dev, handle, send_comm, &dev_handle); +} + +static ncclResult_t nccl_ucx_rma_accept_v6(void *listen_comm, + void **recv_comm) { + ncclNetDeviceHandle_v7_t *dev_handle = NULL; + return nccl_ucx_rma_accept(listen_comm, recv_comm, &dev_handle); +} + +#define UCX_RMA_PLUGIN_NAME "UCX-RMA" ncclNet_v8_t ucxRmaPlugin_v8 = { - .name = "UCX-RMA", - .init = nccl_ucx_rma_init, - .devices = nccl_ucx_rma_devices, + .name = UCX_RMA_PLUGIN_NAME, + .init = nccl_ucx_rma_init, + .devices = nccl_ucx_rma_devices, .getProperties = nccl_ucx_rma_get_properties, - .listen = nccl_ucx_rma_listen, - .connect = nccl_ucx_rma_connect, - .accept = nccl_ucx_rma_accept, - .regMr = nccl_ucx_rma_regmr, - .regMrDmaBuf = nccl_ucx_rma_regmr_dmabuf, - .deregMr = nccl_ucx_rma_deregmr, - .isend = nccl_ucx_rma_isend, - .irecv = nccl_ucx_rma_irecv, - .iflush = nccl_ucx_rma_iflush, - .test = nccl_ucx_rma_test, - .closeSend = nccl_ucx_rma_close_send, - .closeRecv = nccl_ucx_rma_close_recv, - .closeListen = nccl_ucx_rma_close_listen, - NULL /* getDeviceMr */, - NULL /* irecvConsumed */ + .listen = nccl_ucx_rma_listen, + .connect = nccl_ucx_rma_connect, + .accept = nccl_ucx_rma_accept, + .regMr = nccl_ucx_rma_regmr, + .regMrDmaBuf = nccl_ucx_rma_regmr_dmabuf, + .deregMr = nccl_ucx_rma_deregmr, + .isend = nccl_ucx_rma_isend, + .irecv = nccl_ucx_rma_irecv, + .iflush = nccl_ucx_rma_iflush, + .test = nccl_ucx_rma_test, + .closeSend = nccl_ucx_rma_close_comm, + .closeRecv = nccl_ucx_rma_close_comm, + .closeListen = nccl_ucx_rma_close_listen, }; ncclNet_v7_t ucxRmaPlugin_v7 = { - .name = "UCX-RMA", - .init = nccl_ucx_rma_init, - .devices = nccl_ucx_rma_devices, + .name = UCX_RMA_PLUGIN_NAME, + .init = nccl_ucx_rma_init, + .devices = nccl_ucx_rma_devices, .getProperties = nccl_ucx_rma_get_properties_v7, - .listen = nccl_ucx_rma_listen, - .connect = nccl_ucx_rma_connect, - .accept = nccl_ucx_rma_accept, - .regMr = nccl_ucx_rma_regmr_v7, - .regMrDmaBuf = nccl_ucx_rma_regmr_dmabuf, - .deregMr = nccl_ucx_rma_deregmr, - .isend = nccl_ucx_rma_isend, - .irecv = nccl_ucx_rma_irecv, - .iflush = nccl_ucx_rma_iflush, - .test = nccl_ucx_rma_test, - .closeSend = nccl_ucx_rma_close_send, - .closeRecv = nccl_ucx_rma_close_recv, - .closeListen = nccl_ucx_rma_close_listen, - NULL /* getDeviceMr */, - NULL /* irecvConsumed */ + .listen = nccl_ucx_rma_listen, + .connect = nccl_ucx_rma_connect, + .accept = nccl_ucx_rma_accept, + .regMr = nccl_ucx_rma_regmr_v7, + .regMrDmaBuf = nccl_ucx_rma_regmr_dmabuf, + .deregMr = nccl_ucx_rma_deregmr, + .isend = nccl_ucx_rma_isend, + .irecv = nccl_ucx_rma_irecv, + .iflush = nccl_ucx_rma_iflush, + .test = nccl_ucx_rma_test, + .closeSend = nccl_ucx_rma_close_comm, + .closeRecv = nccl_ucx_rma_close_comm, + .closeListen = nccl_ucx_rma_close_listen, }; ncclNet_v6_t ucxRmaPlugin_v6 = { - .name = "UCX-RMA", - .init = nccl_ucx_rma_init, - .devices = nccl_ucx_rma_devices, - .getProperties = nccl_ucx_rma_get_properties_v6, - .listen = nccl_ucx_rma_listen, - .connect = nccl_ucx_rma_connect_v6, - .accept = nccl_ucx_rma_accept_v6, - .regMr = nccl_ucx_rma_regmr_v7, - .regMrDmaBuf = nccl_ucx_rma_regmr_dmabuf, - .deregMr = nccl_ucx_rma_deregmr, - .isend = nccl_ucx_rma_isend, - .irecv = nccl_ucx_rma_irecv, - .iflush = nccl_ucx_rma_iflush, - .test = nccl_ucx_rma_test, - .closeSend = nccl_ucx_rma_close_send, - .closeRecv = nccl_ucx_rma_close_recv, - .closeListen = nccl_ucx_rma_close_listen + .name = UCX_RMA_PLUGIN_NAME, + .init = nccl_ucx_rma_init, + .devices = nccl_ucx_rma_devices, + .getProperties = nccl_ucx_rma_get_properties_v6, + .listen = nccl_ucx_rma_listen, + .connect = nccl_ucx_rma_connect_v6, + .accept = nccl_ucx_rma_accept_v6, + .regMr = nccl_ucx_rma_regmr_v7, + .regMrDmaBuf = nccl_ucx_rma_regmr_dmabuf, + .deregMr = nccl_ucx_rma_deregmr, + .isend = nccl_ucx_rma_isend, + .irecv = nccl_ucx_rma_irecv, + .iflush = nccl_ucx_rma_iflush, + .test = nccl_ucx_rma_test, + .closeSend = nccl_ucx_rma_close_comm, + .closeRecv = nccl_ucx_rma_close_comm, + .closeListen = nccl_ucx_rma_close_listen }; ncclNet_v5_t ucxRmaPlugin_v5 = { - .name = "UCX-RMA", - .init = nccl_ucx_rma_init, - .devices = nccl_ucx_rma_devices, - .getProperties = nccl_ucx_rma_get_properties_v6, - .listen = nccl_ucx_rma_listen, - .connect = nccl_ucx_rma_connect_v6, - .accept = nccl_ucx_rma_accept_v6, - .regMr = nccl_ucx_rma_regmr_v7, - .deregMr = nccl_ucx_rma_deregmr, - .isend = nccl_ucx_rma_isend, - .irecv = nccl_ucx_rma_irecv, - .iflush = nccl_ucx_rma_iflush, - .test = nccl_ucx_rma_test, - .closeSend = nccl_ucx_rma_close_send, - .closeRecv = nccl_ucx_rma_close_recv, - .closeListen = nccl_ucx_rma_close_listen + .name = UCX_RMA_PLUGIN_NAME, + .init = nccl_ucx_rma_init, + .devices = nccl_ucx_rma_devices, + .getProperties = nccl_ucx_rma_get_properties_v6, + .listen = nccl_ucx_rma_listen, + .connect = nccl_ucx_rma_connect_v6, + .accept = nccl_ucx_rma_accept_v6, + .regMr = nccl_ucx_rma_regmr_v7, + .deregMr = nccl_ucx_rma_deregmr, + .isend = nccl_ucx_rma_isend, + .irecv = nccl_ucx_rma_irecv, + .iflush = nccl_ucx_rma_iflush, + .test = nccl_ucx_rma_test, + .closeSend = nccl_ucx_rma_close_comm, + .closeRecv = nccl_ucx_rma_close_comm, + .closeListen = nccl_ucx_rma_close_listen };