diff --git a/include/debug.h b/include/debug.h index 939db62..0db5c9c 100644 --- a/include/debug.h +++ b/include/debug.h @@ -35,4 +35,6 @@ extern ncclDebugLogger_t pluginLogFunction; void ncclSetThreadName(pthread_t thread, const char *fmt, ...); +void ncclResetDebugInit(); + #endif diff --git a/include/ibvwrap.h b/include/ibvwrap.h index 00aa457..55edbd1 100644 --- a/include/ibvwrap.h +++ b/include/ibvwrap.h @@ -13,6 +13,9 @@ #define NCCL_IBVWRAP_H_ #include "config.h" #include "core.h" +#include "utils.h" +#include +#include #include #if !HAVE_DECL_IBV_ACCESS_RELAXED_ORDERING @@ -82,4 +85,14 @@ ncclResult_t wrap_ibv_post_send(struct ibv_qp *qp, struct ibv_send_wr *wr, struc ncclResult_t wrap_ibv_post_recv(struct ibv_qp *qp, struct ibv_recv_wr *wr, struct ibv_recv_wr **bad_wr); ncclResult_t wrap_ibv_event_type_str(char **ret, enum ibv_event_type event); +// converts a GID into a readable string. On success, returns a non-null pointer to gidStr. +// NULL is returned if there was an error, with errno set to indicate the error. +// errno = ENOSPC if the converted string would exceed strLen. +static inline const char* ibvGetGidStr(union ibv_gid* gid, char* gidStr, size_t strLen) { + // GID is a 16B handle, to convert it to a readable form, we use inet_ntop + // sizeof(ibv_gid) == sizeof(struct in6_addr), so using AF_INET6 + NCCL_STATIC_ASSERT(sizeof(union ibv_gid) == sizeof(struct in6_addr), "the sizeof struct ibv_gid must be the size of struct in6_addr"); + return inet_ntop(AF_INET6, gid->raw, gidStr, strLen); +} + #endif //End include guard diff --git a/include/nccl.h b/include/nccl.h index a234af9..5d75ff4 100644 --- a/include/nccl.h +++ b/include/nccl.h @@ -12,6 +12,9 @@ #if CUDART_VERSION >= 11000 #include #endif +#if CUDART_VERSION >= 11080 +#include +#endif #define NCCL_MAJOR 2 #define NCCL_MINOR 20 @@ -146,6 +149,11 @@ const char* pncclGetErrorString(ncclResult_t result); const char* ncclGetLastError(ncclComm_t comm); const char* pncclGetLastError(ncclComm_t comm); +/* Reload environment variables that determine logging. */ +void ncclResetDebugInit(); +void pncclResetDebugInit(); + + /* Checks whether the comm has encountered any asynchronous errors */ ncclResult_t ncclCommGetAsyncError(ncclComm_t comm, ncclResult_t *asyncError); ncclResult_t pncclCommGetAsyncError(ncclComm_t comm, ncclResult_t *asyncError); @@ -201,12 +209,10 @@ typedef enum { ncclInt8 = 0, ncclChar = 0, ncclFloat16 = 6, ncclHalf = 6, ncclFloat32 = 7, ncclFloat = 7, ncclFloat64 = 8, ncclDouble = 8, -#if CUDART_VERSION >= 11000 ncclBfloat16 = 9, - ncclNumTypes = 10 -#else - ncclNumTypes = 9 -#endif + ncclFloat8e4m3 = 10, + ncclFloat8e5m2 = 11, + ncclNumTypes = 12 } ncclDataType_t; /* ncclScalarResidence_t: Location and dereferencing logic for scalar arguments. */ diff --git a/include/net.h b/include/net.h index 1e60bad..bdb37ff 100644 --- a/include/net.h +++ b/include/net.h @@ -9,6 +9,10 @@ #include #define NCCL_NET_HANDLE_MAXSIZE 128 +//Maximum value NCCL can accept for maxP2pBytes and maxCollBytes net properties +#define NCCL_MAX_NET_SIZE_BYTES (1*1024*1024*1024*1024L) +#define NCCL_NET_OPTIONAL_RECV_COMPLETION 0x1 + #define NCCL_PTR_HOST 0x1 #define NCCL_PTR_CUDA 0x2 @@ -22,6 +26,7 @@ typedef enum {NCCL_INIT=1, NCCL_COLL=2, NCCL_P2P=4, NCCL_SHM=8, NCCL_NET=16, NCC typedef void (*ncclDebugLogger_t)(ncclDebugLogLevel level, unsigned long flags, const char *file, int line, const char *fmt, ...); +#include "net_v9.h" #include "net_v8.h" #include "net_v7.h" #include "net_v6.h" diff --git a/include/net_device.h b/include/net_device.h index de914d3..4937b1c 100644 --- a/include/net_device.h +++ b/include/net_device.h @@ -25,6 +25,7 @@ typedef struct { } ncclNetDeviceHandle_v7_t; typedef ncclNetDeviceHandle_v7_t ncclNetDeviceHandle_v8_t; -typedef ncclNetDeviceHandle_v8_t ncclNetDeviceHandle_t; +typedef ncclNetDeviceHandle_v8_t ncclNetDeviceHandle_v9_t; +typedef ncclNetDeviceHandle_v9_t ncclNetDeviceHandle_t; #endif diff --git a/include/net_v8.h b/include/net_v8.h index f1bd56b..84164f2 100644 --- a/include/net_v8.h +++ b/include/net_v8.h @@ -22,8 +22,6 @@ typedef struct { int netDeviceVersion; // Version number for network offload } ncclNetProperties_v8_t; -typedef ncclNetProperties_v8_t ncclNetProperties_t; - typedef struct { // Name of the network (mainly for logs) const char* name; diff --git a/include/net_v9.h b/include/net_v9.h new file mode 100644 index 0000000..4c29cde --- /dev/null +++ b/include/net_v9.h @@ -0,0 +1,157 @@ +/* + * Copyright (c) 2017-2023, NVIDIA CORPORATION. All rights reserved. + */ + +#ifndef NCCL_NET_V9_H_ +#define NCCL_NET_V9_H_ +#include "net_device.h" + +// Max number of ncclNet objects which can live in the same process +#define NCCL_NET_MAX_PLUGINS 3 + +#define NCCL_NET_MAX_DEVS_PER_NIC_V9 4 +#define NCCL_NET_MAX_DEVS_PER_NIC NCCL_NET_MAX_DEVS_PER_NIC_V9 + +typedef struct { + int ndevs; + int devs[NCCL_NET_MAX_DEVS_PER_NIC_V9]; +} ncclNetVDeviceProps_v9_t; +typedef ncclNetVDeviceProps_v9_t ncclNetVDeviceProps_t; + + +typedef struct { + char* name; // Used mostly for logging. + char* pciPath; // Path to the PCI device in /sys. + uint64_t guid; // Unique identifier for the NIC chip. Important for + // cards with multiple PCI functions (Physical or virtual). + int ptrSupport; // [NCCL_PTR_HOST|NCCL_PTR_CUDA|NCCL_PTR_DMABUF] + int regIsGlobal; // regMr is not tied to a particular comm + int forceFlush; // Force a flush on receives + int speed; // Port speed in Mbps. + int port; // Port number. + float latency; // Network latency + int maxComms; // Maximum number of comms we can create + int maxRecvs; // Maximum number of grouped receives. + ncclNetDeviceType netDeviceType; // Network offload type + int netDeviceVersion; // Version number for network offload + ncclNetVDeviceProps_v9_t vProps; + size_t maxP2pBytes; // Max transfer size for point-to-point operations + size_t maxCollBytes; // Max transfer size for collective operations +} ncclNetProperties_v9_t; +typedef ncclNetProperties_v9_t ncclNetProperties_t; + +typedef struct { + // Name of the network (mainly for logs) + const char* name; + // Initialize the network. + ncclResult_t (*init)(ncclDebugLogger_t logFunction); + // Return the number of adapters. + ncclResult_t (*devices)(int* ndev); + // Get various device properties. + ncclResult_t (*getProperties)(int dev, ncclNetProperties_v9_t* props); + // Create a receiving object and provide a handle to connect to it. The + // handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged + // between ranks to create a connection. + ncclResult_t (*listen)(int dev, void* handle, void** listenComm); + // Connect to a handle and return a sending comm object for that peer. + // This call must not block for the connection to be established, and instead + // should return successfully with sendComm == NULL with the expectation that + // it will be called again until sendComm != NULL. + // If *sendDevComm points to a valid object, then NCCL is requesting device offload for this connection + ncclResult_t (*connect)(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_v8_t** sendDevComm); + // Finalize connection establishment after remote peer has called connect. + // This call must not block for the connection to be established, and instead + // should return successfully with recvComm == NULL with the expectation that + // it will be called again until recvComm != NULL. + // If *recvDevComm points to a valid object, then NCCL is requesting device offload for this connection + ncclResult_t (*accept)(void* listenComm, void** recvComm, ncclNetDeviceHandle_v8_t** recvDevComm); + // Register/Deregister memory. Comm can be either a sendComm or a recvComm. + // Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA. + ncclResult_t (*regMr)(void* comm, void* data, size_t size, int type, void** mhandle); + /* DMA-BUF support */ + ncclResult_t (*regMrDmaBuf)(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle); + ncclResult_t (*deregMr)(void* comm, void* mhandle); + // Asynchronous send to a peer. + // May return request == NULL if the call cannot be performed (or would block) + ncclResult_t (*isend)(void* sendComm, void* data, size_t size, int tag, void* mhandle, void** request); + // Asynchronous recv from a peer. + // May return request == NULL if the call cannot be performed (or would block) + ncclResult_t (*irecv)(void* recvComm, int n, void** data, size_t* sizes, int* tags, void** mhandles, void** request); + // Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is + // visible to the GPU + ncclResult_t (*iflush)(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request); + // Test whether a request is complete. If size is not NULL, it returns the + // number of bytes sent/received. + ncclResult_t (*test)(void* request, int* done, int* sizes); + // Close and free send/recv comm objects + ncclResult_t (*closeSend)(void* sendComm); + ncclResult_t (*closeRecv)(void* recvComm); + ncclResult_t (*closeListen)(void* listenComm); + + // Copy the given mhandle to a dptr in a format usable by this plugin's device code + ncclResult_t (*getDeviceMr)(void* comm, void* mhandle, void** dptr_mhandle); + + // Notify the plugin that a recv has completed by the device + ncclResult_t (*irecvConsumed)(void* recvComm, int n, void* request); + + // Create a virtual NIC given the specified properties, which can be accessed at device index d + ncclResult_t (*makeVDevice)(int* d, ncclNetVDeviceProps_t* props); +} ncclNet_v9_t; + +typedef struct { + void* mhandle; + void* address; + size_t size; +} ncclNetSGE_v9_t; + +typedef struct { + // Name of the collective network (mainly for logs) + const char* name; + // Initialize the collective network. + ncclResult_t (*init)(ncclDebugLogger_t logFunction); + // Return the number of adapters capable of doing collective operations. + // If ndev returns 0, all other functions might be set to NULL. + ncclResult_t (*devices)(int* ndev); + // Get various device properties. + ncclResult_t (*getProperties)(int dev, ncclNetProperties_v9_t* props); + // Create a receiving object and provide a handle to connect to it. The + // handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged + // between ranks to create connections. + ncclResult_t (*listen)(int dev, void* handle, void** listenComm); + // Create a group for collective operations. handles have been created + // using listen() above. rank indicates caller's rank in the collective network. + ncclResult_t (*connect)(void* handles[], int nranks, int rank, void* listenComm, void** collComm); + // Returns whether a reduction operation on a data type is supported. + // 1 for supported, 0 otherwise. + ncclResult_t (*reduceSupport)(ncclDataType_t dataType, ncclRedOp_t redOp, int* supported); + // Register/Deregister memory. Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA. + ncclResult_t (*regMr)(void* collComm, void* data, size_t size, int type, void** mhandle); + /* DMA-BUF support */ + ncclResult_t (*regMrDmaBuf)(void* collComm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle); + ncclResult_t (*deregMr)(void* collComm, void* mhandle); + // Performs an asynchronous allreduce operation on the collective group. + // May return request == NULL if the call cannot be performed (or would block). + ncclResult_t (*iallreduce)(void* collComm, void* sendData, void* recvData, size_t count, + ncclDataType_t dataType, ncclRedOp_t redOp, void* sendMhandle, void* recvMhandle, void** request); + ncclResult_t (*iallgather)(void* collComm, void* sendData, int nRecvParts, ncclNetSGE_v9_t* recvParts, + size_t bytesPerRank, size_t windowOffset, size_t windowBytes, + void* sendMhandle, void** request); + ncclResult_t (*ireducescatter)(void* collComm, int nSendParts, ncclNetSGE_v9_t* sendParts, void* recvData, + size_t bytesPerRank, size_t windowOffset, size_t windowBytes, + ncclDataType_t dataType, ncclRedOp_t redOp, + void* recvMhandle, void** request); + // Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is + // visible to the GPU + ncclResult_t (*iflush)(void* collComm, void* data, int size, void* mhandle, void** request); + // Test whether a request is complete. If size is not NULL, it returns the + // number of bytes sent/received. + ncclResult_t (*test)(void* request, int* done, int* size); + // Close and free collective comm objects + ncclResult_t (*closeColl)(void* collComm); + ncclResult_t (*closeListen)(void* listenComm); + + // Create a virtual NIC given the specified properties, which can be accessed at device index d + ncclResult_t (*makeVDevice)(int* d, ncclNetVDeviceProps_t* props); +} ncclCollNet_v9_t; + +#endif // end include guard diff --git a/include/p2p_plugin.h b/include/p2p_plugin.h index 26d5a93..201f6cf 100644 --- a/include/p2p_plugin.h +++ b/include/p2p_plugin.h @@ -46,15 +46,13 @@ struct ncclIbMrCache { int capacity, population; }; -#define NCCL_IB_MAX_DEVS_PER_NIC 2 +#define NCCL_IB_MAX_DEVS_PER_NIC 4 #define MAX_MERGED_DEV_NAME (MAXNAMESIZE*NCCL_IB_MAX_DEVS_PER_NIC)+NCCL_IB_MAX_DEVS_PER_NIC -struct ncclIbMergedDev { - int ndevs; - int devs[NCCL_IB_MAX_DEVS_PER_NIC]; // Points to an index in ncclIbDevs +typedef struct ncclIbMergedDev { + ncclNetVDeviceProps_t vProps; int speed; char devName[MAX_MERGED_DEV_NAME]; // Up to NCCL_IB_MAX_DEVS_PER_NIC * name size, and a character for each '+' - int dmaBufSupported; // 0 = uninit, 1 = yes, -1 = no -} __attribute__((aligned(64))); +} __attribute__((aligned(64))) ncclIbMergedDev; struct ncclIbStats { int fatalErrorCount; @@ -108,17 +106,21 @@ typedef struct ncclIbDev { struct ibv_pd* pd; char devName[MAXNAMESIZE]; char *pciPath; + char* virtualPciPath; int realPort; int maxQp; + float latency; struct ncclIbMrCache mrCache; int ar; // ADAPTIVE_ROUTING struct ibv_port_attr portAttr; struct ncclIbStats stats; + int dmaBufSupported; } __attribute__((aligned(64))) ncclIbDev; -#define MAX_IB_DEVS 32 -extern struct ncclIbMergedDev ncclIbMergedDevs[MAX_IB_DEVS]; +#define MAX_IB_DEVS 32 +#define MAX_IB_VDEVS MAX_IB_DEVS*8 +extern struct ncclIbMergedDev ncclIbMergedDevs[MAX_IB_VDEVS]; extern struct ncclIbDev ncclIbDevs[MAX_IB_DEVS]; /* Detect whether GDR can work on a given NIC with the current CUDA device * Returns : @@ -130,9 +132,10 @@ ncclResult_t nccl_p2p_dmabuf_support(int dev); ncclResult_t nccl_p2p_ib_pci_path(ncclIbDev *devs, int num_devs, char* dev_name, char** path, int* real_port); -ncclResult_t nccl_p2p_ib_get_properties(ncclIbDev *devs, int dev, ncclNetProperties_t* props); +ncclResult_t nccl_p2p_ib_get_properties(ncclIbDev *devs, int ncclNMergedIbDevs, int dev, ncclNetProperties_t* props); -ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIbIfName, union ncclSocketAddress *ncclIbIfAddr, pthread_t *ncclIbAsyncThread, ncclDebugLogger_t logFunction); +ncclResult_t nccl_p2p_ib_init(int *nDevs, int *nmDevs, ncclIbDev *ncclIbDevs, char *ncclIbIfName, union ncclSocketAddress *ncclIbIfAddr, + pthread_t *ncclIbAsyncThread, ncclDebugLogger_t logFunction); /* Convert value returtned by ibv_query_port to actual link width */ int nccl_p2p_ib_width(int width); @@ -152,4 +155,6 @@ nccl_p2p_plugin_t nccl_p2p_get_plugin_type(); ncclResult_t ncclIbStatsInit(struct ncclIbStats* stat); +ncclResult_t ncclIbMakeVDeviceInternal(int* d, ncclNetVDeviceProps_t* props, int nDevs, int *nmDevs); + #endif diff --git a/include/socket.h b/include/socket.h index 4c04ae1..0f162d4 100644 --- a/include/socket.h +++ b/include/socket.h @@ -19,9 +19,6 @@ #define MAX_IFS 16 #define MAX_IF_NAME_SIZE 16 -#define SLEEP_INT 1000 // connection retry sleep interval in usec -#define RETRY_REFUSED_TIMES 2e4 // connection refused retry times before reporting a timeout (20 sec) -#define RETRY_TIMEDOUT_TIMES 3 // connection timed out retry times (each one can take 20s) #define SOCKET_NAME_MAXLEN (NI_MAXHOST+NI_MAXSERV) #define NCCL_SOCKET_MAGIC 0x564ab9f2fc4b9d6cULL @@ -41,24 +38,25 @@ enum ncclSocketState { ncclSocketStateConnectPolling = 5, ncclSocketStateConnected = 6, ncclSocketStateReady = 7, - ncclSocketStateClosed = 8, - ncclSocketStateError = 9, - ncclSocketStateNum = 10 + ncclSocketStateTerminating = 8, + ncclSocketStateClosed = 9, + ncclSocketStateError = 10, + ncclSocketStateNum = 11 + }; enum ncclSocketType { ncclSocketTypeUnknown = 0, ncclSocketTypeBootstrap = 1, ncclSocketTypeProxy = 2, - ncclSocketTypeNetSocket = 3, - ncclSocketTypeNetIb = 4 + ncclSocketTypeNetIb = 4, + ncclSocketTypeRasNetwork = 5 }; struct ncclSocket { int fd; int acceptFd; - int timedOutRetries; - int refusedRetries; + int errorRetries; union ncclSocketAddress addr; volatile uint32_t* abortFlag; int asyncFlag; @@ -66,15 +64,18 @@ struct ncclSocket { int salen; uint64_t magic; enum ncclSocketType type; + int customRetry; + int finalizeCounter; // Used to keep track of initial handshake for async sockets. + char finalizeBuffer[sizeof(uint64_t)]; // Used to keep track of initial handshake for async sockets. }; -const char *ncclSocketToString(union ncclSocketAddress *addr, char *buf, const int numericHostForm); +const char *ncclSocketToString(const union ncclSocketAddress *addr, char *buf, const int numericHostForm); ncclResult_t ncclSocketGetAddrFromString(union ncclSocketAddress* ua, const char* ip_port_pair); int ncclFindInterfaceMatchSubnet(char* ifNames, union ncclSocketAddress* localAddrs, union ncclSocketAddress* remoteAddr, int ifNameMaxSize, int maxIfs); int ncclFindInterfaces(char* ifNames, union ncclSocketAddress *ifAddrs, int ifNameMaxSize, int maxIfs); // Initialize a socket -ncclResult_t ncclSocketInit(struct ncclSocket* sock, union ncclSocketAddress* addr, uint64_t magic, enum ncclSocketType type, volatile uint32_t* abortFlag, int asyncFlag); +ncclResult_t ncclSocketInit(struct ncclSocket* sock, const union ncclSocketAddress* addr, uint64_t magic, enum ncclSocketType type, volatile uint32_t* abortFlag, int asyncFlag, int customRetry); // Create a listening socket. sock->addr can be pre-filled with IP & port info. sock->fd is set after a successful call ncclResult_t ncclSocketListen(struct ncclSocket* sock); ncclResult_t ncclSocketGetAddr(struct ncclSocket* sock, union ncclSocketAddress* addr); @@ -90,7 +91,7 @@ ncclResult_t ncclSocketSetFd(int fd, struct ncclSocket* sock); #define NCCL_SOCKET_SEND 0 #define NCCL_SOCKET_RECV 1 -ncclResult_t ncclSocketProgress(int op, struct ncclSocket* sock, void* ptr, int size, int* offset); +ncclResult_t ncclSocketProgress(int op, struct ncclSocket* sock, void* ptr, int size, int* offset, int* closed); ncclResult_t ncclSocketWait(int op, struct ncclSocket* sock, void* ptr, int size, int* offset); ncclResult_t ncclSocketSend(struct ncclSocket* sock, void* ptr, int size); ncclResult_t ncclSocketRecv(struct ncclSocket* sock, void* ptr, int size); diff --git a/include/ucx_uct_lib.h b/include/ucx_uct_lib.h index 02566d6..459f627 100644 --- a/include/ucx_uct_lib.h +++ b/include/ucx_uct_lib.h @@ -162,6 +162,7 @@ typedef struct nccl_uct_context { /* IB devices available */ int dev_count; + int merge_dev_count; /* Use by common code to setup communicators */ struct nccl_uct_ops { @@ -230,6 +231,8 @@ ncclResult_t nccl_uct_reg_mr(void *reg_comm, void *data, size_t size, int type, ncclResult_t nccl_uct_dereg_mr(void *dereg_comm, void *mhandle); /* Compatibility callback */ +ncclResult_t nccl_uct_get_properties_v8(int dev, + ncclNetProperties_v8_t *props_v8); ncclResult_t nccl_uct_get_properties_v7(int dev, ncclNetProperties_v7_t *props_v7); ncclResult_t nccl_uct_reg_mr_v7(void *comm, void *data, int size, int type, @@ -242,7 +245,8 @@ ncclResult_t nccl_uct_get_properties(int dev, ncclNetProperties_t *props); #define NCCL_UCT_PLUGIN_BASE(plugin_name, prefix, get_properties_func, \ - connect_func, accept_func, reg_mr_func) \ + connect_func, accept_func, reg_mr_func, \ + isend_func, irecv_func) \ { \ .name = plugin_name, \ .init = prefix##_init, \ @@ -254,8 +258,8 @@ ncclResult_t nccl_uct_get_properties(int dev, ncclNetProperties_t *props); .regMr = reg_mr_func, \ .regMrDmaBuf = nccl_uct_reg_mr_dmabuf, \ .deregMr = nccl_uct_dereg_mr, \ - .isend = prefix##_isend, \ - .irecv = prefix##_irecv, \ + .isend = prefix##_##isend_func, \ + .irecv = prefix##_##irecv_func, \ .iflush = prefix##_iflush, \ .test = prefix##_test, \ .closeSend = prefix##_close, \ @@ -263,18 +267,25 @@ ncclResult_t nccl_uct_get_properties(int dev, ncclNetProperties_t *props); .closeListen = nccl_uct_close_listen \ } -#define NCCL_UCT_PLUGIN_V8(plugin_name, prefix) \ +#define NCCL_UCT_PLUGIN_V9(plugin_name, prefix) \ NCCL_UCT_PLUGIN_BASE(plugin_name, prefix, nccl_uct_get_properties, \ - nccl_uct_connect, nccl_uct_accept, nccl_uct_reg_mr) + nccl_uct_connect, nccl_uct_accept, nccl_uct_reg_mr, \ + isend, irecv) + +#define NCCL_UCT_PLUGIN_V8(plugin_name, prefix) \ + NCCL_UCT_PLUGIN_BASE(plugin_name, prefix, nccl_uct_get_properties_v8, \ + nccl_uct_connect, nccl_uct_accept, nccl_uct_reg_mr, \ + isend_v8, irecv_v8) #define NCCL_UCT_PLUGIN_V7(plugin_name, prefix) \ NCCL_UCT_PLUGIN_BASE(plugin_name, prefix, nccl_uct_get_properties_v7, \ - nccl_uct_connect, nccl_uct_accept, nccl_uct_reg_mr_v7) + nccl_uct_connect, nccl_uct_accept, nccl_uct_reg_mr_v7, \ + isend_v8, irecv_v8) #define NCCL_UCT_PLUGIN_V6(plugin_name, prefix) \ NCCL_UCT_PLUGIN_BASE(plugin_name, prefix, nccl_uct_get_properties_v6, \ nccl_uct_connect_v6, nccl_uct_accept_v6, \ - nccl_uct_reg_mr_v7) + nccl_uct_reg_mr_v7, isend_v8, irecv_v8) #define NCCL_UCT_PLUGIN_V5(plugin_name, prefix) \ { \ @@ -287,8 +298,8 @@ ncclResult_t nccl_uct_get_properties(int dev, ncclNetProperties_t *props); .accept = nccl_uct_accept_v6, \ .regMr = nccl_uct_reg_mr_v7, \ .deregMr = nccl_uct_dereg_mr, \ - .isend = prefix##_isend, \ - .irecv = prefix##_irecv, \ + .isend = prefix##_isend_v8, \ + .irecv = prefix##_irecv_v8, \ .iflush = prefix##_iflush, \ .test = prefix##_test, \ .closeSend = prefix##_close, \ diff --git a/src/ib_plugin.c b/src/ib_plugin.c index edccd51..96571ea 100644 --- a/src/ib_plugin.c +++ b/src/ib_plugin.c @@ -29,12 +29,12 @@ static char ncclIbIfName[MAX_IF_NAME_SIZE+1]; static union ncclSocketAddress ncclIbIfAddr; - -static int ncclNIbDevs = -1; - pthread_mutex_t ncclIbLock = PTHREAD_MUTEX_INITIALIZER; int ncclIbRelaxedOrderingEnabled = 0; +static int ncclNMergedIbDevs = -1; +static int ncclNIbDevs = -1; + NCCL_PARAM(IbGidIndex, "IB_GID_INDEX", -1); NCCL_PARAM(IbRoutableFlidIbGidIndex, "IB_ROUTABLE_FLID_GID_INDEX", 1); NCCL_PARAM(IbRoceVersionNum, "IB_ROCE_VERSION_NUM", 2); @@ -44,7 +44,7 @@ NCCL_PARAM(IbRetryCnt, "IB_RETRY_CNT", 7); NCCL_PARAM(IbPkey, "IB_PKEY", 0); NCCL_PARAM(IbUseInline, "IB_USE_INLINE", 0); NCCL_PARAM(IbSl, "IB_SL", 0); -NCCL_PARAM(IbTc, "IB_TC", 0); +NCCL_PARAM(IbTc, "IB_TC", -1); NCCL_PARAM(IbArThreshold, "IB_AR_THRESHOLD", 8192); NCCL_PARAM(IbPciRelaxedOrdering, "IB_PCI_RELAXED_ORDERING", 2); NCCL_PARAM(IbFifoTc, "IB_FIFO_TC", 0); @@ -117,17 +117,17 @@ static void* envIbAddrRange(sa_family_t af, int* mask) { *(maskStrPtr++) = '\0'; if (inet_pton(af, addrStrPtr, ret) == 0) { - WARN("NET/IB: Ip address '%s' is invalid for family %s, ignoring address", addrStrPtr, (af == AF_INET) ? "AF_INET" : "AF_INET6"); + INFO(NCCL_INIT|NCCL_NET, "NET/IB: Ip address '%s' is invalid for family %s, ignoring address", addrStrPtr, (af == AF_INET) ? "AF_INET" : "AF_INET6"); return NULL; } *mask = (int)strtol(maskStrPtr, NULL, 10); if (af == AF_INET && *mask > 32) { - WARN("NET/IB: Ip address mask '%d' is invalid for family %s, ignoring mask", *mask, (af == AF_INET) ? "AF_INET" : "AF_INET6"); + INFO(NCCL_INIT|NCCL_NET, "NET/IB: Ip address mask '%d' is invalid for family %s, ignoring mask", *mask, (af == AF_INET) ? "AF_INET" : "AF_INET6"); *mask = 0; ret = NULL; } else if (af == AF_INET6 && *mask > 128) { - WARN("NET/IB: Ip address mask '%d' is invalid for family %s, ignoring mask", *mask, (af == AF_INET) ? "AF_INET" : "AF_INET6"); + INFO(NCCL_INIT|NCCL_NET, "NET/IB: Ip address mask '%d' is invalid for family %s, ignoring mask", *mask, (af == AF_INET) ? "AF_INET" : "AF_INET6"); *mask = 0; ret = NULL; } @@ -208,7 +208,7 @@ static bool validGid(union ibv_gid* gid) { static ncclResult_t ncclIbRoceGetVersionNum(const char* deviceName, int portNum, int gidIndex, int* version) { char gidRoceVerStr[16] = { 0 }; char roceTypePath[PATH_MAX] = { 0 }; - sprintf(roceTypePath, "/sys/class/infiniband/%s/ports/%d/gid_attrs/types/%d", deviceName, portNum, gidIndex); + snprintf(roceTypePath, sizeof(roceTypePath), "/sys/class/infiniband/%s/ports/%d/gid_attrs/types/%d", deviceName, portNum, gidIndex); int fd = open(roceTypePath, O_RDONLY); if (fd == -1) { @@ -321,19 +321,39 @@ NCCL_PARAM(IbMergeNics, "IB_MERGE_NICS", 1); extern ncclDebugLogger_t pluginLogFunction; ncclResult_t ncclIbDevices(int* ndev) { - *ndev = ncclNIbDevs; + *ndev = ncclNMergedIbDevs; return ncclSuccess; } ncclResult_t ncclIbGetProperties(int dev, ncclNetProperties_t* props) { - return nccl_p2p_ib_get_properties(ncclIbDevs, dev, props); + return nccl_p2p_ib_get_properties(ncclIbDevs, ncclNMergedIbDevs, dev, props); +} + +ncclResult_t ncclIbGetProperties_v8(int dev, ncclNetProperties_v8_t* props_v8) +{ + ncclNetProperties_t props; + ncclResult_t ret = nccl_p2p_ib_get_properties(ncclIbDevs, ncclNMergedIbDevs, dev, &props); + if (ret != ncclSuccess) return ret; + props_v8->name = props.name; + props_v8->pciPath = props.pciPath; + props_v8->guid = props.guid; + props_v8->ptrSupport = props.ptrSupport; + props_v8->regIsGlobal = props.regIsGlobal; + props_v8->speed = props.speed; + props_v8->latency = props.latency; + props_v8->port = props.port; + props_v8->maxComms = props.maxComms; + props_v8->maxRecvs = props.maxRecvs; + props_v8->netDeviceType = props.netDeviceType; + props_v8->netDeviceVersion = props.netDeviceVersion; + return ncclSuccess; } ncclResult_t ncclIbGetProperties_v7(int dev, ncclNetProperties_v7_t* props_v7) { ncclNetProperties_t props; - ncclResult_t ret = nccl_p2p_ib_get_properties(ncclIbDevs, dev, &props); + ncclResult_t ret = nccl_p2p_ib_get_properties(ncclIbDevs, ncclNMergedIbDevs, dev, &props); if (ret != ncclSuccess) return ret; props_v7->name = props.name; props_v7->pciPath = props.pciPath; @@ -352,7 +372,7 @@ ncclResult_t ncclIbGetProperties_v7(int dev, ncclNetProperties_v7_t* props_v7) ncclResult_t ncclIbGetProperties_v6(int dev, ncclNetProperties_v6_t* props_v6) { ncclNetProperties_t props; - ncclResult_t ret = nccl_p2p_ib_get_properties(ncclIbDevs, dev, &props); + ncclResult_t ret = nccl_p2p_ib_get_properties(ncclIbDevs, ncclNMergedIbDevs, dev, &props); if (ret != ncclSuccess) return ret; props_v6->name = props.name; props_v6->pciPath = props.pciPath; @@ -416,6 +436,8 @@ enum ncclIbCommState { ncclIbCommStateConnecting = 6, ncclIbCommStateConnected = 7, ncclIbCommStatePendingReady = 8, + ncclIbCommStateSendDevList = 9, + ncclIbCommStateRecvDevList = 10, }; struct ncclIbCommStage { @@ -448,12 +470,12 @@ struct ncclIbListenComm { struct ncclIbSendFifo { uint64_t addr; - int size; + uint64_t size; uint32_t rkeys[NCCL_IB_MAX_DEVS_PER_NIC]; uint32_t nreqs; uint32_t tag; uint64_t idx; - char padding[24]; + char padding[16]; }; typedef struct ncclIbQp { @@ -485,7 +507,7 @@ struct ncclIbMrHandle { }; typedef struct ncclIbNetCommBase { - int ndevs; + ncclNetVDeviceProps_t vProps; bool isSend; struct ncclIbRequest reqs[MAX_REQUESTS]; struct ncclIbQp qps[NCCL_IB_MAX_QPS]; @@ -496,6 +518,7 @@ typedef struct ncclIbNetCommBase { int ready; // Track necessary remDevInfo here int nRemDevs; + int nDataQps; struct ncclIbDevInfo remDevs[NCCL_IB_MAX_DEVS_PER_NIC]; // statistics about the comm struct ncclIbStats stats; @@ -530,7 +553,6 @@ struct ncclIbRemFifo { struct ncclIbRecvCommDev { struct ncclIbNetCommDevBase base; struct ncclIbGpuFlush gpuFlush; - uint32_t fifoRkey; struct ibv_mr* fifoMr; struct ibv_sge fifoSge; struct ibv_mr* sizesFifoMr; @@ -538,7 +560,7 @@ struct ncclIbRecvCommDev { struct ncclIbRecvComm { struct ncclIbNetCommBase base; - struct ncclIbRecvCommDev devs[NCCL_IB_MAX_DEVS_PER_NIC]; + struct ncclIbRecvCommDev devs[NCCL_IB_MAX_DEVS_PER_NIC]; struct ncclIbRemFifo remFifo; int sizesFifo[MAX_REQUESTS][NCCL_NET_IB_MAX_RECVS]; int gpuFlushHostMem; @@ -559,7 +581,7 @@ ncclResult_t ncclIbInit(ncclDebugLogger_t logFunction) { NCCL_STATIC_ASSERT((offsetof(struct ncclIbRecvComm, remFifo) % 32) == 0, "ncclIbSendComm fifo must be 32-byte aligned"); - return nccl_p2p_ib_init(&ncclNIbDevs, ncclIbDevs, ncclIbIfName, &ncclIbIfAddr, &ncclIbAsyncThread, logFunction); + return nccl_p2p_ib_init(&ncclNIbDevs, &ncclNMergedIbDevs, ncclIbDevs, ncclIbIfName, &ncclIbIfAddr, &ncclIbAsyncThread, logFunction); } NCCL_PARAM(IbQpsPerConn, "IB_QPS_PER_CONNECTION", 1); @@ -626,10 +648,12 @@ ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base, qpAttr.port_num = ib_port; qpAttr.qp_access_flags = access_flags; NCCLCHECK(wrap_ibv_modify_qp(qp->qp, &qpAttr, IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS)); + TRACE(NCCL_NET, "NET/IB : ncclIbCreateQp port=%d dev=%d devName=%s ndevs=%d nmdevs=%d qpn=%u pkey=%u pd=%p", + ib_port, base->ibDevN, ncclIbDevs[base->ibDevN].devName, ncclNIbDevs, ncclNMergedIbDevs, qp->qp->qp_num, qpAttr.pkey_index, base->pd); return ncclSuccess; } -ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, struct ncclIbGidInfo* sGidInfo, uint32_t dest_qp_num, struct ncclIbDevInfo* info, bool override_tc) { +ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, struct ncclIbGidInfo* sGidInfo, uint32_t dest_qp_num, struct ncclIbDevInfo* info, bool fifoTc) { struct ibv_qp_attr qpAttr; int same_subnet; memset(&qpAttr, 0, sizeof(struct ibv_qp_attr)); @@ -646,11 +670,7 @@ ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, struct ncclIbGidInfo* sGidInfo, uint qpAttr.ah_attr.grh.flow_label = 0; qpAttr.ah_attr.grh.sgid_index = sGidInfo->localGidIndex; qpAttr.ah_attr.grh.hop_limit = 255; - if(ncclParamIbFifoTc() && override_tc) { - qpAttr.ah_attr.grh.traffic_class = ncclParamIbFifoTc(); - } else { - qpAttr.ah_attr.grh.traffic_class = ncclParamIbTc(); - } + qpAttr.ah_attr.grh.traffic_class = fifoTc && ncclParamIbFifoTc() != -1 ? ncclParamIbFifoTc() : ncclParamIbTc(); } else { same_subnet = (ncclIbExtractLocalSubnetPrefix(sGidInfo->localGid.global.subnet_prefix) == ncclIbExtractLocalSubnetPrefix(info->gid.global.subnet_prefix)); @@ -676,6 +696,7 @@ ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, struct ncclIbGidInfo* sGidInfo, uint qpAttr.ah_attr.sl = ncclParamIbSl(); qpAttr.ah_attr.src_path_bits = 0; qpAttr.ah_attr.port_num = info->ib_port; + TRACE(NCCL_NET, "NET/IB : ncclIbRtrQp qpn=%u mtu=%d dst=%u ll=%u port=%u", qp->qp_num, info->mtu, dest_qp_num, info->link_layer, info->ib_port); NCCLCHECK(wrap_ibv_modify_qp(qp, &qpAttr, IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER)); return ncclSuccess; } @@ -703,7 +724,7 @@ ncclResult_t ncclIbListen(int dev, void* opaqueHandle, void** listenComm) { memset(handle, 0, sizeof(struct ncclIbHandle)); comm->dev = dev; handle->magic = NCCL_SOCKET_MAGIC; - NCCLCHECKGOTO(ncclSocketInit(&comm->sock, &ncclIbIfAddr, handle->magic, ncclSocketTypeNetIb, NULL, 1), ret, fail); + NCCLCHECKGOTO(ncclSocketInit(&comm->sock, &ncclIbIfAddr, handle->magic, ncclSocketTypeNetIb, NULL, 1, 0), ret, fail); NCCLCHECKGOTO(ncclSocketListen(&comm->sock), ret, fail); NCCLCHECKGOTO(ncclSocketGetAddr(&comm->sock, &handle->connectAddr), ret, fail); *listenComm = comm; @@ -724,10 +745,12 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet int ready; *sendComm = NULL; - if (stage->state == ncclIbCommStateConnect) goto ib_connect_check; - if (stage->state == ncclIbCommStateSend) goto ib_send; - if (stage->state == ncclIbCommStateConnecting) goto ib_connect; - if (stage->state == ncclIbCommStateConnected) goto ib_send_ready; + if (stage->state == ncclIbCommStateConnect) goto ib_connect_check; + if (stage->state == ncclIbCommStateSendDevList) goto ib_send_dev_list; + if (stage->state == ncclIbCommStateRecvDevList) goto ib_recv_dev_list; + if (stage->state == ncclIbCommStateSend) goto ib_send; + if (stage->state == ncclIbCommStateConnecting) goto ib_connect; + if (stage->state == ncclIbCommStateConnected) goto ib_send_ready; if (stage->state != ncclIbCommStateStart) { WARN("Error: trying to connect already connected sendComm"); return ncclInternalError; @@ -736,7 +759,7 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet NCCLCHECK(ncclIbMalloc((void**)&comm, sizeof(struct ncclIbSendComm))); NCCLCHECKGOTO(ncclIbStatsInit(&comm->base.stats), ret, fail); - NCCLCHECKGOTO(ncclSocketInit(&comm->base.sock, &handle->connectAddr, handle->magic, ncclSocketTypeNetIb, NULL, 1), ret, fail); + NCCLCHECKGOTO(ncclSocketInit(&comm->base.sock, &handle->connectAddr, handle->magic, ncclSocketTypeNetIb, NULL, 1, 0), ret, fail); stage->comm = comm; stage->state = ncclIbCommStateConnect; NCCLCHECKGOTO(ncclSocketConnect(&comm->base.sock), ret, fail); @@ -748,22 +771,51 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet // IB Setup struct ncclIbMergedDev* mergedDev; + if (dev >= ncclNMergedIbDevs) { + WARN("NET/IB : Trying to use non-existant virtual device %d", dev); + return ncclInternalError; + } + mergedDev = ncclIbMergedDevs + dev; - comm->base.ndevs = mergedDev->ndevs; - comm->base.nqps = ncclParamIbQpsPerConn() * comm->base.ndevs; // We must have at least 1 qp per-device + comm->base.vProps = mergedDev->vProps; comm->base.isSend = true; + stage->state = ncclIbCommStateSendDevList; + stage->offset = 0; + struct ncclIbConnectionMetadata meta; + NCCLCHECKGOTO(ncclIbMalloc((void**)&stage->buffer, sizeof(meta)), ret, fail); + memcpy(stage->buffer, &mergedDev->vProps, sizeof(ncclNetVDeviceProps_t)); + +// In the case of mismatched nDevs, we will make sure that both sides of a logical connection have the same number of RC qps +ib_send_dev_list: + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_SEND, &comm->base.sock, stage->buffer, sizeof(ncclNetVDeviceProps_t), &stage->offset, NULL)); + if (stage->offset != sizeof(ncclNetVDeviceProps_t)) return ncclSuccess; + + stage->state = ncclIbCommStateRecvDevList; + stage->offset = 0; + +ib_recv_dev_list: + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->base.sock, stage->buffer, sizeof(ncclNetVDeviceProps_t), &stage->offset, NULL)); + if (stage->offset != sizeof(ncclNetVDeviceProps_t)) return ncclSuccess; + stage->offset = 0; + ncclNetVDeviceProps_t remoteVProps; + memcpy(&remoteVProps, stage->buffer, sizeof(ncclNetVDeviceProps_t)); + mergedDev = ncclIbMergedDevs + dev; + comm->base.vProps = mergedDev->vProps; + int localNqps, remoteNqps; + localNqps = ncclParamIbQpsPerConn() * comm->base.vProps.ndevs; // We must have at least 1 qp per-device + remoteNqps = ncclParamIbQpsPerConn() * remoteVProps.ndevs; + comm->base.nqps = remoteNqps > localNqps ? remoteNqps : localNqps; // Select max nqps (local or remote) // Init PD, Ctx for each IB device comm->ar = 1; // Set to 1 for logic - for (int i = 0; i < mergedDev->ndevs; i++) { - int ibDevN = mergedDev->devs[i]; + for (int i = 0; i < comm->base.vProps.ndevs; i++) { + int ibDevN = comm->base.vProps.devs[i]; NCCLCHECKGOTO(ncclIbInitCommDevBase(ibDevN, &comm->devs[i].base, &comm->base.stats), ret, fail); - comm->ar = comm->ar && ncclIbDevs[dev].ar; // ADAPTIVE_ROUTING - if all merged devs have it enabled + comm->ar = comm->ar && ncclIbDevs[ibDevN].ar; // ADAPTIVE_ROUTING - if all merged devs have it enabled } - struct ncclIbConnectionMetadata meta; memset(&meta, 0, sizeof(meta)); - meta.ndevs = comm->base.ndevs; + meta.ndevs = comm->base.vProps.ndevs; // Alternate QPs between devices int devIndex; @@ -782,10 +834,10 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet } else { meta.qpInfo[q].ece_supported = 0; } - devIndex = (devIndex + 1) % comm->base.ndevs; + devIndex = (devIndex + 1) % comm->base.vProps.ndevs; } - for (int i = 0; i < comm->base.ndevs; i++) { + for (int i = 0; i < comm->base.vProps.ndevs; i++) { ncclIbSendCommDev* commDev = comm->devs + i; ncclIbDev* ibDev = ncclIbDevs + commDev->base.ibDevN; @@ -818,7 +870,7 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet // Print just the QPs for this dev if (comm->base.qps[q].devIndex == i) INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d LID %d subnet-prefix %lu FLID %d fifoRkey=0x%x fifoLkey=0x%x", - comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", + comm->base.vProps.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", dev, commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, devInfo->lid, devInfo->gid.global.subnet_prefix, ncclIbExtractFlid(&devInfo->gid), devInfo->fifoRkey, commDev->fifoMr->lkey); } @@ -827,7 +879,7 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet // Print just the QPs for this dev if (comm->base.qps[q].devIndex == i) INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d query_ece={supported=%d, vendor_id=0x%x, options=0x%x, comp_mask=0x%x} GID %ld (%lX/%lX) fifoRkey=0x%x fifoLkey=0x%x", - comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", dev, + comm->base.vProps.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", dev, commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, meta.qpInfo[q].ece_supported, meta.qpInfo[q].ece.vendor_id, meta.qpInfo[q].ece.options, meta.qpInfo[q].ece.comp_mask, (int64_t)commDev->base.gidInfo.localGidIndex, devInfo->gid.global.subnet_prefix, devInfo->gid.global.interface_id, devInfo->fifoRkey, commDev->fifoMr->lkey); } @@ -839,12 +891,11 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet stage->state = ncclIbCommStateSend; stage->offset = 0; - NCCLCHECKGOTO(ncclIbMalloc((void**)&stage->buffer, sizeof(meta)), ret, fail); memcpy(stage->buffer, &meta, sizeof(meta)); ib_send: - NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_SEND, &comm->base.sock, stage->buffer, sizeof(meta), &stage->offset), ret, fail); + NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_SEND, &comm->base.sock, stage->buffer, sizeof(meta), &stage->offset, NULL), ret, fail); if (stage->offset != sizeof(meta)) return ncclSuccess; @@ -854,23 +905,18 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet memset(stage->buffer, 0, sizeof(meta)); ib_connect: - NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->base.sock, stage->buffer, sizeof(ncclIbConnectionMetadata), &stage->offset), ret, fail); + NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->base.sock, stage->buffer, sizeof(ncclIbConnectionMetadata), &stage->offset, NULL), ret, fail); if (stage->offset != sizeof(remMeta)) return ncclSuccess; memcpy(&remMeta, stage->buffer, sizeof(ncclIbConnectionMetadata)); comm->base.nRemDevs = remMeta.ndevs; - if (comm->base.nRemDevs != comm->base.ndevs) { - mergedDev = ncclIbMergedDevs + dev; - WARN("NET/IB : Local mergedDev=%s has a different number of devices=%d as remoteDev=%s nRemDevs=%d", - mergedDev->devName, comm->base.ndevs, remMeta.devName, comm->base.nRemDevs); - } int link_layer; link_layer = remMeta.devs[0].link_layer; for (int i = 1; i < remMeta.ndevs; i++) { if (remMeta.devs[i].link_layer != link_layer) { - WARN("NET/IB : Can't merge net devices with different link_layer. i=%d remMeta.ndevs=%d link_layer=%d rem_link_layer=%d", + WARN("NET/IB : Can't connect net devices with different link_layer. i=%d remMeta.ndevs=%d link_layer=%d rem_link_layer=%d", i, remMeta.ndevs, link_layer, remMeta.devs[i].link_layer); return ncclInternalError; } @@ -887,7 +933,7 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet comm->remSizesFifo.addr = remMeta.fifoAddr; } - for (int i=0; i < comm->base.ndevs; i++) { + for (int i=0; i < comm->base.vProps.ndevs; i++) { NCCLCHECKGOTO(wrap_ibv_reg_mr(comm->remSizesFifo.mrs+i, comm->devs[i].base.pd, &comm->remSizesFifo.elems, sizeof(int)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); } comm->base.nRemDevs = remMeta.ndevs; @@ -904,6 +950,8 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet if (remQpInfo->ece_supported && remQpInfo->ece_supported) NCCLCHECKGOTO(wrap_ibv_set_ece(qp, &remQpInfo->ece, &remQpInfo->ece_supported), ret, fail); + ncclIbDev* ibDev = ncclIbDevs + commDev->base.ibDevN; + remDevInfo->mtu = MIN(remDevInfo->mtu, ibDev->portAttr.active_mtu); NCCLCHECKGOTO(ncclIbRtrQp(qp, &commDev->base.gidInfo, remQpInfo->qpn, remDevInfo, false), ret, fail); NCCLCHECKGOTO(ncclIbRtsQp(qp), ret, fail); } @@ -917,12 +965,15 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet ibDevN, ibDev->portNum, remMeta.qpInfo[q].qpn, remMeta.qpInfo[q].ece_supported, remMeta.qpInfo[q].ece.vendor_id, remMeta.qpInfo[q].ece.options, remMeta.qpInfo[q].ece.comp_mask); } } + + comm->base.nDataQps = MAX(comm->base.vProps.ndevs, comm->base.nRemDevs); + comm->base.ready = 1; stage->state = ncclIbCommStateConnected; stage->offset = 0; ib_send_ready: - NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_SEND, &comm->base.sock, &comm->base.ready, sizeof(int), &stage->offset), ret, fail); + NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_SEND, &comm->base.sock, &comm->base.ready, sizeof(int), &stage->offset, NULL), ret, fail); if (stage->offset != sizeof(int)) return ncclSuccess; *sendComm = comm; @@ -940,6 +991,50 @@ ncclResult_t ncclIbConnect_v6(int dev, void* opaqueHandle, void** sendComm) { return ncclIbConnect(dev, opaqueHandle, sendComm, &handle); } +NCCL_PARAM(IbWarnRailLocal, "IB_WARN_RAIL_LOCAL", 0); + +ncclResult_t ncclIbCheckVProps(ncclNetVDeviceProps_t* vProps1, ncclNetVDeviceProps_t* vProps2) { + ncclNetVDeviceProps_t outVProps = {0}; + ncclNetVDeviceProps_t* minVProps = vProps2; + ncclNetVDeviceProps_t* maxVProps = vProps1; + if (vProps2->ndevs > vProps1->ndevs) { + minVProps = vProps1; + maxVProps = vProps2; + } + + // Find the intersection of devices + for (int i = 0; i < minVProps->ndevs; i++) { + int dev = minVProps->devs[i]; + for (int j = 0; j < maxVProps->ndevs; j++) { + // Found + if (maxVProps->devs[j] == dev) { + outVProps.devs[outVProps.ndevs++] = dev; + } + } + } + + // In the case that at least one side has a fused NIC but there are no matching physical NICs, we should check if the user wants this + if (ncclParamIbWarnRailLocal() && outVProps.ndevs < maxVProps->ndevs) { + char local[128]; + int cursor = 1; + snprintf(local, sizeof(local), "%d", vProps1->devs[0]); + for (int i = 1; i < vProps1->ndevs; i++) { + snprintf(local+cursor, sizeof(local)-cursor, ",%d", vProps1->devs[i]); + cursor += 2; + } + char remote[128]; + snprintf(remote, sizeof(remote), "%d", vProps2->devs[0]); + cursor = 1; + for (int i = 1; i < vProps2->ndevs; i++) { + snprintf(remote+cursor, sizeof(remote)-cursor, ",%d", vProps2->devs[i]); + cursor += 2; + } + INFO(NCCL_NET, "NET/IB : There are mismatched physical devices between local (%s) and remote (%s). To disable this warning, set NCCL_IB_WARN_RAIL_LOCAL=0", local, remote); + } + + return ncclSuccess; +} + NCCL_PARAM(IbGdrFlushDisable, "GDR_FLUSH_DISABLE", 0); ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandle_t** recvDevComm) { @@ -950,7 +1045,9 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl int ready; *recvComm = NULL; - if (stage->state == ncclIbCommStateAccept) goto ib_accept_check; + if (stage->state == ncclIbCommStateAccept) goto ib_accept_check; + if (stage->state == ncclIbCommStateRecvDevList) goto ib_recv_dev_list; + if (stage->state == ncclIbCommStateSendDevList) goto ib_send_dev_list; if (stage->state == ncclIbCommStateRecv) goto ib_recv; if (stage->state == ncclIbCommStateSend) goto ib_send; if (stage->state == ncclIbCommStatePendingReady) goto ib_recv_ready; @@ -963,20 +1060,55 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl NCCLCHECKGOTO(ncclIbStatsInit(&rComm->base.stats), ret, fail); stage->comm = rComm; stage->state = ncclIbCommStateAccept; - NCCLCHECKGOTO(ncclSocketInit(&rComm->base.sock, NULL, NCCL_SOCKET_MAGIC, ncclSocketTypeUnknown, NULL, 0), ret, fail); + NCCLCHECKGOTO(ncclSocketInit(&rComm->base.sock, NULL, NCCL_SOCKET_MAGIC, ncclSocketTypeUnknown, NULL, 0, 0), ret, fail); NCCLCHECKGOTO(ncclSocketAccept(&rComm->base.sock, &lComm->sock), ret, fail); + // Alloc stage->buffer here to be used for all following steps + struct ncclIbConnectionMetadata remMeta; + stage->offset = 0; + NCCLCHECK(ncclIbMalloc((void**)&stage->buffer, sizeof(remMeta))); ib_accept_check: NCCLCHECKGOTO(ncclSocketReady(&rComm->base.sock, &ready), ret, fail); if (!ready) return ncclSuccess; - struct ncclIbConnectionMetadata remMeta; - stage->state = ncclIbCommStateRecv; + stage->state = ncclIbCommStateRecvDevList; + stage->offset = 0; + +// In the case of mismatched nDevs, we will make sure that both sides of a logical connection have the same number of RC qps +ib_recv_dev_list: + NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->base.sock, stage->buffer, sizeof(ncclNetVDeviceProps_t), &stage->offset, NULL)); + if (stage->offset != sizeof(ncclNetVDeviceProps_t)) return ncclSuccess; + ncclNetVDeviceProps_t remoteVProps; + memcpy(&remoteVProps, stage->buffer, sizeof(ncclNetVDeviceProps_t)); + if (lComm->dev >= ncclNMergedIbDevs) { + WARN("NET/IB : Trying to use non-existant virtual device %d", lComm->dev); + return ncclInternalError; + } + + // Reduce the physical device list and store in the connection base + struct ncclIbMergedDev* mergedDev; + mergedDev = ncclIbMergedDevs + lComm->dev; + NCCLCHECK(ncclIbCheckVProps(&mergedDev->vProps, &remoteVProps)); + rComm->base.vProps = mergedDev->vProps; + memcpy(stage->buffer, &rComm->base.vProps, sizeof(ncclNetVDeviceProps_t)); + rComm->base.isSend = false; + int localNqps, remoteNqps; + localNqps = ncclParamIbQpsPerConn() * rComm->base.vProps.ndevs; // We must have at least 1 qp per-device + remoteNqps = ncclParamIbQpsPerConn() * remoteVProps.ndevs; + rComm->base.nqps = remoteNqps > localNqps ? remoteNqps : localNqps; // Select max nqps (local or remote) + stage->offset = 0; - NCCLCHECKGOTO(ncclIbMalloc((void**)&stage->buffer, sizeof(remMeta)), ret, fail);; + stage->state = ncclIbCommStateSendDevList; + +ib_send_dev_list: + NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_SEND, &rComm->base.sock, stage->buffer, sizeof(ncclNetVDeviceProps_t), &stage->offset, NULL), ret, fail); + if (stage->offset != sizeof(ncclNetVDeviceProps_t)) return ncclSuccess; + + stage->offset = 0; + stage->state = ncclIbCommStateRecv; ib_recv: - NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->base.sock, stage->buffer, sizeof(remMeta), &stage->offset), ret, fail); + NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->base.sock, stage->buffer, sizeof(remMeta), &stage->offset, NULL), ret, fail); if (stage->offset != sizeof(remMeta)) return ncclSuccess; /* copy back the received info */ @@ -984,7 +1116,6 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl // IB setup // Pre-declare variables because of goto - struct ncclIbMergedDev* mergedDev; struct ncclIbDev* ibDev; int ibDevN; struct ncclIbRecvCommDev* rCommDev; @@ -992,22 +1123,19 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl struct ncclIbQp* qp; mergedDev = ncclIbMergedDevs + lComm->dev; - rComm->base.ndevs = mergedDev->ndevs; - rComm->base.nqps = ncclParamIbQpsPerConn() * rComm->base.ndevs; // We must have at least 1 qp per-device - rComm->base.isSend = false; rComm->base.nRemDevs = remMeta.ndevs; - if (rComm->base.nRemDevs != rComm->base.ndevs) { - WARN("NET/IB : Local mergedDev %s has a different number of devices=%d as remote %s %d", - mergedDev->devName, rComm->base.ndevs, remMeta.devName, rComm->base.nRemDevs); + if (rComm->base.nRemDevs != rComm->base.vProps.ndevs) { + INFO(NCCL_NET, "NET/IB : Local mergedDev %s has a different number of devices=%d as remote %s %d", + mergedDev->devName, rComm->base.vProps.ndevs, remMeta.devName, rComm->base.nRemDevs); } // Metadata to send back to requestor (sender) struct ncclIbConnectionMetadata meta; memset(&meta, 0, sizeof(meta)); - for (int i = 0; i < rComm->base.ndevs; i++) { + for (int i = 0; i < rComm->base.vProps.ndevs; i++) { rCommDev = rComm->devs + i; - ibDevN = mergedDev->devs[i]; + ibDevN = rComm->base.vProps.devs[i]; NCCLCHECKGOTO(ncclIbInitCommDevBase(ibDevN, &rCommDev->base, &rComm->base.stats), ret, fail); ibDev = ncclIbDevs + ibDevN; NCCLCHECKGOTO(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, &ibDev->portAttr, &rCommDev->base.gidInfo.localGidIndex), ret, fail); @@ -1038,7 +1166,7 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl ibDev = ncclIbDevs + ibDevN; NCCLCHECKGOTO(ncclIbCreateQp(ibDev->portNum, &rCommDev->base, IBV_ACCESS_REMOTE_WRITE, &rComm->base.stats, qp), ret, fail); qp->devIndex = devIndex; - devIndex = (devIndex + 1) % rComm->base.ndevs; + devIndex = (devIndex + 1) % rComm->base.vProps.ndevs; // Set the ece (enhanced connection establishment) on this QP before RTR if (remMeta.qpInfo[q].ece_supported) { @@ -1055,21 +1183,18 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl } else { meta.qpInfo[q].ece_supported = 0; } - bool override_tc = (q == 0) ? true : false; - NCCLCHECKGOTO(ncclIbRtrQp(qp->qp, &rCommDev->base.gidInfo, remMeta.qpInfo[q].qpn, remDevInfo, override_tc), ret, fail); + NCCLCHECKGOTO(ncclIbRtrQp(qp->qp, &rCommDev->base.gidInfo, remMeta.qpInfo[q].qpn, remDevInfo, true), ret, fail); NCCLCHECKGOTO(ncclIbRtsQp(qp->qp), ret, fail); } rComm->flushEnabled = ((nccl_p2p_gdr_support() == ncclSuccess || nccl_p2p_dmabuf_support(lComm->dev) == ncclSuccess) && (ncclParamIbGdrFlushDisable() == 0)) ? 1 : 0; - for (int i = 0; i < mergedDev->ndevs; i++) { + for (int i = 0; i < rComm->base.vProps.ndevs; i++) { rCommDev = rComm->devs + i; - ibDevN = rCommDev->base.ibDevN; - ibDev = ncclIbDevs + ibDevN; + ibDev = ncclIbDevs + rCommDev->base.ibDevN; // Retain remote fifo info and prepare my RDMA ops - rCommDev->fifoRkey = remMeta.devs[i].fifoRkey; rComm->remFifo.addr = remMeta.fifoAddr; NCCLCHECKGOTO(wrap_ibv_reg_mr(&rCommDev->fifoMr, rCommDev->base.pd, &rComm->remFifo.elems, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); rCommDev->fifoSge.lkey = rCommDev->fifoMr->lkey; @@ -1099,9 +1224,9 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl } // Fill Handle - meta.devs[i].lid = ibDev->portAttr.lid; - meta.devs[i].link_layer = rCommDev->base.gidInfo.link_layer = ibDev->portAttr.link_layer; - meta.devs[i].ib_port = ibDev->portNum; + meta.devs[i].lid = ibDev->portAttr.lid; + meta.devs[i].link_layer = rCommDev->base.gidInfo.link_layer = ibDev->portAttr.link_layer; + meta.devs[i].ib_port = ibDev->portNum; meta.devs[i].gid.global.subnet_prefix = rCommDev->base.gidInfo.localGid.global.subnet_prefix; meta.devs[i].gid.global.interface_id = rCommDev->base.gidInfo.localGid.global.interface_id; meta.devs[i].is_global = (ncclParamIbIsGlobal() @@ -1110,9 +1235,8 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl #endif ); - // Adjust the MTU - remMeta.devs[i].mtu = (enum ibv_mtu)MIN(remMeta.devs[i].mtu, ibDev->portAttr.active_mtu); - meta.devs[i].mtu = remMeta.devs[i].mtu; + meta.devs[i].mtu = ibDev->portAttr.active_mtu; + // Prepare sizes fifo NCCLCHECKGOTO(wrap_ibv_reg_mr(&rComm->devs[i].sizesFifoMr, rComm->devs[i].base.pd, rComm->sizesFifo, sizeof(int)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ), ret, fail); @@ -1125,8 +1249,9 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl meta.qpInfo[q].devIndex = rComm->base.qps[q].devIndex; } - meta.ndevs = rComm->base.ndevs; + meta.ndevs = rComm->base.vProps.ndevs; strncpy(meta.devName, mergedDev->devName, MAX_MERGED_DEV_NAME); + rComm->base.nDataQps = MAX(rComm->base.vProps.ndevs, rComm->base.nRemDevs); stage->state = ncclIbCommStateSend; stage->offset = 0; @@ -1137,14 +1262,14 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl NCCLCHECKGOTO(ncclIbMalloc((void**)&stage->buffer, sizeof(struct ncclIbConnectionMetadata)), ret, fail); memcpy(stage->buffer, &meta, sizeof(struct ncclIbConnectionMetadata)); ib_send: - NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_SEND, &rComm->base.sock, stage->buffer, sizeof(struct ncclIbConnectionMetadata), &stage->offset), ret, fail); + NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_SEND, &rComm->base.sock, stage->buffer, sizeof(struct ncclIbConnectionMetadata), &stage->offset, NULL), ret, fail); if (stage->offset < sizeof(struct ncclIbConnectionMetadata)) return ncclSuccess; stage->offset = 0; stage->state = ncclIbCommStatePendingReady; ib_recv_ready: - NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->base.sock, &rComm->base.ready, sizeof(int), &stage->offset), ret, fail); + NCCLCHECKGOTO(ncclSocketProgress(NCCL_SOCKET_RECV, &rComm->base.sock, &rComm->base.ready, sizeof(int), &stage->offset, NULL), ret, fail); if (stage->offset != sizeof(int)) return ncclSuccess; *recvComm = rComm; @@ -1261,7 +1386,7 @@ ncclResult_t ncclIbRegMrDmaBuf(void* comm, void* data, size_t size, int type, ui assert(size > 0); struct ncclIbNetCommBase* base = (struct ncclIbNetCommBase*) comm; struct ncclIbMrHandle* mhandleWrapper = (struct ncclIbMrHandle*) malloc(sizeof(struct ncclIbMrHandle)); - for (int i = 0; i < base->ndevs; i++) { + for (int i = 0; i < base->vProps.ndevs; i++) { // Each ncclIbNetCommDevBase is at different offset in send and recv netComms struct ncclIbNetCommDevBase* devComm = ncclIbGetNetCommDevBase(base, i); NCCLCHECKGOTO(ncclIbRegMrDmaBufInternal(devComm, data, size, type, offset, fd, mhandleWrapper->mrs + i), ret, fail); @@ -1309,9 +1434,11 @@ ncclResult_t ncclIbDeregMrInternal(ncclIbNetCommDevBase* base, struct ibv_mr* mh } ncclResult_t ncclIbDeregMr(void* comm, void* mhandle) { + if (mhandle == NULL) return ncclSuccess; + struct ncclIbMrHandle* mhandleWrapper = (struct ncclIbMrHandle*) mhandle; struct ncclIbNetCommBase* base = (struct ncclIbNetCommBase*) comm; - for (int i = 0; i < base->ndevs; i++) { + for (int i = 0; i < base->vProps.ndevs; i++) { // Each ncclIbNetCommDevBase is at different offset in send and recv netComms struct ncclIbNetCommDevBase* devComm = ncclIbGetNetCommDevBase(base, i); NCCLCHECK(ncclIbDeregMrInternal(devComm, mhandleWrapper->mrs[i])); @@ -1377,7 +1504,7 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { // Multi-QP: make sure IB writes are multiples of 128B so that LL and LL128 protocols still work const int align = 128; - int nqps = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.ndevs; + int nqps = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.nDataQps; for (int i = 0; i < nqps; i++) { int qpIndex = comm->base.qpIndex; ncclIbQp* qp = comm->base.qps + qpIndex; @@ -1426,7 +1553,7 @@ ncclResult_t ncclIbMultiSend(struct ncclIbSendComm* comm, int slot) { return ncclSuccess; } -ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) { +ncclResult_t ncclIbIsend(void* sendComm, void* data, size_t size, int tag, void* mhandle, void** request) { struct ncclIbSendComm* comm = (struct ncclIbSendComm*)sendComm; if (comm->base.ready == 0) { WARN("NET/IB: ncclIbIsend() called when comm->base.ready == 0"); return ncclInternalError; } if (comm->base.ready == 0) { *request = NULL; return ncclSuccess; } @@ -1456,7 +1583,7 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int tag, void* mh char line[SOCKET_NAME_MAXLEN + 1]; union ncclSocketAddress addr; ncclSocketGetAddr(&comm->base.sock, &addr); - WARN("NET/IB : req %d/%d tag %x peer %s posted incorrect receive info: size %d addr %lx rkeys[0]=%x", + WARN("NET/IB : req %d/%d tag %x peer %s posted incorrect receive info: size %ld addr %lx rkeys[0]=%x", r, nreqs, tag, ncclSocketToString(&addr, line, 1), slots[r].size, slots[r].addr, slots[r].rkeys[0]); } struct ncclIbRequest* req; @@ -1470,7 +1597,7 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int tag, void* mh req->send.offset = 0; // Populate events - int nEvents = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.ndevs; + int nEvents = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.nDataQps; int qpIndex = comm->base.qpIndex; // Count down while (nEvents > 0) { @@ -1485,7 +1612,7 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int tag, void* mh } // Store all lkeys - for (int i = 0; i < comm->base.ndevs; i++) { + for (int i = 0; i < comm->base.vProps.ndevs; i++) { req->send.lkeys[i] = mhandleWrapper->mrs[i]->lkey; } @@ -1511,7 +1638,7 @@ ncclResult_t ncclIbIsend(void* sendComm, void* data, int size, int tag, void* mh return ncclSuccess; } -ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, int* sizes, int* tags, void** mhandles, struct ncclIbRequest* req) { +ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, size_t* sizes, int* tags, void** mhandles, struct ncclIbRequest* req) { struct ibv_send_wr wr; memset(&wr, 0, sizeof(wr)); @@ -1523,14 +1650,14 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, int // Select the next devIndex (local) and QP to use for posting this CTS message // Since QPs are initialized by striping across devIndex, we can simply assign this to the same value ncclIbQp* ctsQp = comm->base.qps + comm->base.devIndex; - comm->base.devIndex = (comm->base.devIndex + 1) % comm->base.ndevs; + comm->base.devIndex = (comm->base.devIndex + 1) % comm->base.vProps.ndevs; for (int i=0; ibase.ndevs; j++) + for (int j = 0; j < comm->base.vProps.ndevs; j++) localElem[i].rkeys[j] = mhandleWrapper->mrs[j]->rkey; localElem[i].nreqs = n; @@ -1591,7 +1718,7 @@ ncclResult_t ncclIbPostFifo(struct ncclIbRecvComm* comm, int n, void** data, int return ncclSuccess; } -ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) { +ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, size_t* sizes, int* tags, void** mhandles, void** request) { struct ncclIbRecvComm* comm = (struct ncclIbRecvComm*)recvComm; if (comm->base.ready == 0) { WARN("NET/IB: ncclIbIrecv() called when comm->base.ready == 0"); return ncclInternalError; } if (comm->base.ready == 0) { *request = NULL; return ncclSuccess; } @@ -1605,7 +1732,7 @@ ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, int* sizes, int* ta req->sock = &comm->base.sock; req->nreqs = n; - for (int i = 0; i < comm->base.ndevs; i++) { + for (int i = 0; i < comm->base.vProps.ndevs; i++) { req->devBases[i] = &comm->devs[i].base; } @@ -1618,7 +1745,7 @@ ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, int* sizes, int* ta TIME_START(1); // Select either all QPs, or one qp per-device - const int nqps = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.ndevs; + const int nqps = ncclParamIbSplitDataOnQps() ? comm->base.nqps : comm->base.nDataQps; // Post recvs struct ibv_recv_wr* bad_wr; @@ -1639,6 +1766,18 @@ ncclResult_t ncclIbIrecv(void* recvComm, int n, void** data, int* sizes, int* ta return ncclSuccess; } +ncclResult_t ncclIbIsend_v8(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) +{ + return ncclIbIsend(sendComm, data, (size_t)size, tag, mhandle, request); +} + +ncclResult_t ncclIbIrecv_v8(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) +{ + size_t sizesOut[NCCL_NET_IB_MAX_RECVS]; + for (int i=0; ibase.ndevs; i++) { + for (int i = 0; i < comm->base.vProps.ndevs; i++) { struct ibv_send_wr wr; memset(&wr, 0, sizeof(wr)); wr.wr_id = req - comm->base.reqs; @@ -1684,7 +1823,7 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { while (1) { NCCLCHECK(ncclIbStatsCheckFatalCount(&r->base->stats,__func__)); - if (r->events[0] == 0 && r->events[1] == 0) { + if (r->events[0] == 0 && r->events[1] == 0 && r->events[2] == 0 && r->events[3] == 0) { TRACE(NCCL_NET, "r=%p done", r); *done = 1; if (sizes && r->type == NCCL_NET_IB_REQ_RECV) { @@ -1722,13 +1861,13 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { char remoteGidString[INET6_ADDRSTRLEN] = ""; const char* localGidStr = NULL, *remoteGidStr = NULL; if (r->devBases[i]->gidInfo.link_layer == IBV_LINK_LAYER_ETHERNET) { - localGidStr = inet_ntop(AF_INET6, &r->devBases[i]->gidInfo.localGid, localGidString, sizeof(localGidString)); - remoteGidStr = inet_ntop(AF_INET6, &r->base->remDevs[i].remoteGid, remoteGidString, sizeof(remoteGidString)); + localGidStr = ibvGetGidStr(&r->devBases[i]->gidInfo.localGid, localGidString, sizeof(localGidString)); + remoteGidStr = ibvGetGidStr(&r->base->remDevs[i].remoteGid, remoteGidString, sizeof(remoteGidString)); } char line[SOCKET_NAME_MAXLEN+1]; - char *hcaName = r->devBases[i]->pd->context->device->name; - WARN("NET/IB: Got completion from peer %s with status=%d opcode=%d len=%d vendor err %d (%s)%s%s%s%s hca %s", + char *hcaName = r->devBases[i]->pd->context->device->name; + WARN("NET/IB: Got completion from peer %s with status=%d opcode=%d len=%u vendor err %u (%s)%s%s%s%s hca %s", ncclSocketToString(&addr, line, 1), wc->status, wc->opcode, wc->byte_len, wc->vendor_err, reqTypeStr[r->type], localGidStr ? " localGid ":"", localGidString, remoteGidStr ? " remoteGids":"", remoteGidString, hcaName); return ncclRemoteError; @@ -1740,7 +1879,7 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) { #ifdef ENABLE_TRACE char line[SOCKET_NAME_MAXLEN+1]; - TRACE(NCCL_NET, "Got completion from peer %s with status=%d opcode=%d len=%d wr_id=%ld r=%p type=%d events={%d,%d}, i=%d", + TTRACE(NCCL_NET, "Got completion from peer %s with status=%d opcode=%d len=%u wr_id=%lu r=%p type=%d events={%d,%d}, i=%d", ncclSocketToString(&addr, line, 1), wc->status, wc->opcode,wc->byte_len, wc->wr_id, req, req->type, req->events[0], req->events[1], i); #endif if (req && req->type == NCCL_NET_IB_REQ_SEND) { @@ -1784,7 +1923,7 @@ ncclResult_t ncclIbCloseSend(void* sendComm) { for (int q = 0; q < comm->base.nqps; q++) if (comm->base.qps[q].qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->base.qps[q].qp)); - for (int i = 0; i < comm->base.ndevs; i++) { + for (int i = 0; i < comm->base.vProps.ndevs; i++) { struct ncclIbSendCommDev* commDev = comm->devs + i; if (commDev->fifoMr != NULL) NCCLCHECK(wrap_ibv_dereg_mr(commDev->fifoMr)); if (comm->remSizesFifo.mrs[i] != NULL) NCCLCHECK(wrap_ibv_dereg_mr(comm->remSizesFifo.mrs[i])); @@ -1804,7 +1943,7 @@ ncclResult_t ncclIbCloseRecv(void* recvComm) { for (int q = 0; q < comm->base.nqps; q++) if (comm->base.qps[q].qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(comm->base.qps[q].qp)); - for (int i = 0; i < comm->base.ndevs; i++) { + for (int i = 0; i < comm->base.vProps.ndevs; i++) { struct ncclIbRecvCommDev* commDev = comm->devs + i; if (comm->flushEnabled) { if (commDev->gpuFlush.qp.qp != NULL) NCCLCHECK(wrap_ibv_destroy_qp(commDev->gpuFlush.qp.qp)); @@ -1828,8 +1967,15 @@ ncclResult_t ncclIbCloseListen(void* listenComm) { return ncclSuccess; } -const ncclNet_v8_t ibPlugin_v8 = { - .name = "IBext_v8", +ncclResult_t ncclIbMakeVDevice(int* d, ncclNetVDeviceProps_t* props) { + pthread_mutex_lock(&ncclIbLock); + ncclResult_t res = ncclIbMakeVDeviceInternal(d, props, ncclNIbDevs, &ncclNMergedIbDevs); + pthread_mutex_unlock(&ncclIbLock); + return res; +} + +const ncclNet_v9_t ibPlugin_v9 = { + .name = "IBext_v9", .init = ncclIbInit, .devices = ncclIbDevices, .getProperties = ncclIbGetProperties, @@ -1847,6 +1993,30 @@ const ncclNet_v8_t ibPlugin_v8 = { .closeRecv = ncclIbCloseRecv, .closeListen = ncclIbCloseListen, NULL /* getDeviceMr */, + NULL /* irecvConsumed */, + ncclIbMakeVDevice +}; + + +const ncclNet_v8_t ibPlugin_v8 = { + .name = "IBext_v8", + .init = ncclIbInit, + .devices = ncclIbDevices, + .getProperties = ncclIbGetProperties_v8, + .listen = ncclIbListen, + .connect = ncclIbConnect, + .accept = ncclIbAccept, + .regMr = ncclIbRegMr, + .regMrDmaBuf = ncclIbRegMrDmaBuf, + .deregMr = ncclIbDeregMr, + .isend = ncclIbIsend_v8, + .irecv = ncclIbIrecv_v8, + .iflush = ncclIbIflush, + .test = ncclIbTest, + .closeSend = ncclIbCloseSend, + .closeRecv = ncclIbCloseRecv, + .closeListen = ncclIbCloseListen, + NULL /* getDeviceMr */, NULL /* irecvConsumed */ }; @@ -1861,8 +2031,8 @@ const ncclNet_v7_t ibPlugin_v7 = { .regMr = ncclIbRegMr_v7, .regMrDmaBuf = ncclIbRegMrDmaBuf, .deregMr = ncclIbDeregMr, - .isend = ncclIbIsend, - .irecv = ncclIbIrecv, + .isend = ncclIbIsend_v8, + .irecv = ncclIbIrecv_v8, .iflush = ncclIbIflush, .test = ncclIbTest, .closeSend = ncclIbCloseSend, @@ -1883,8 +2053,8 @@ const ncclNet_v6_t ibPlugin_v6 = { .regMr = ncclIbRegMr_v7, .regMrDmaBuf = ncclIbRegMrDmaBuf, .deregMr = ncclIbDeregMr, - .isend = ncclIbIsend, - .irecv = ncclIbIrecv, + .isend = ncclIbIsend_v8, + .irecv = ncclIbIrecv_v8, .iflush = ncclIbIflush, .test = ncclIbTest, .closeSend = ncclIbCloseSend, @@ -1902,8 +2072,8 @@ const ncclNet_v5_t ibPlugin_v5 = { .accept = ncclIbAccept_v6, .regMr = ncclIbRegMr_v7, .deregMr = ncclIbDeregMr, - .isend = ncclIbIsend, - .irecv = ncclIbIrecv, + .isend = ncclIbIsend_v8, + .irecv = ncclIbIrecv_v8, .iflush = ncclIbIflush, .test = ncclIbTest, .closeSend = ncclIbCloseSend, diff --git a/src/ibvwrap.c b/src/ibvwrap.c index 4e4c771..eac086d 100644 --- a/src/ibvwrap.c +++ b/src/ibvwrap.c @@ -5,8 +5,12 @@ ************************************************************************/ #include +#include + #include "ibvwrap.h" +#include "utils.h" #include "nccl.h" +#include "param.h" #define IBV_PTR_CHECK_ERRNO(call, retval, error_retval, name) \ retval = call; \ @@ -27,7 +31,7 @@ #define IBV_INT_CHECK_RET_ERRNO_OPTIONAL(call, success_retval, name, supported) \ int ret = call; \ if (ret == ENOTSUP || ret == EOPNOTSUPP) { \ - INFO(NCCL_NET, "Call to " name " failed with error %s errno %d", strerror(ret), ret); \ + INFO(NCCL_NET, "Call to " name " not supported"); \ *supported = 0; \ return ncclSuccess; \ } else if (ret != success_retval) { \ @@ -58,6 +62,14 @@ call; \ return ncclSuccess; +NCCL_PARAM(IbMQpRetryAll, "IB_MQP_RETRY_ALL", 0); +NCCL_PARAM(IbMQpRetryCnt, "IB_MQP_RETRY_CNT", 34); +NCCL_PARAM(IbMQpRetryTimeout, "IB_MQP_RETRY_SLEEP_MSEC", 100); // in milliseconds + +#define IBV_ERR_EQ(e, code) (e == code || e == (-code)) +#define IBV_MQP_RETRY_ERRNO(e) (IBV_ERR_EQ(e, ETIMEDOUT)) +#define IBV_MQP_RETRY_ERRNO_ALL(e) (ncclParamIbMQpRetryAll() ? (e != 0) : IBV_MQP_RETRY_ERRNO(e)) + ncclResult_t wrap_ibv_fork_init() { IBV_INT_CHECK(ibv_fork_init(), -1, "ibv_fork_init"); } @@ -170,10 +182,87 @@ ncclResult_t wrap_ibv_create_qp(struct ibv_qp **ret, struct ibv_pd *pd, struct i IBV_PTR_CHECK_ERRNO(ibv_create_qp(pd, qp_init_attr), *ret, NULL, "ibv_create_qp"); } -ncclResult_t wrap_ibv_modify_qp(struct ibv_qp *qp, struct ibv_qp_attr *attr, int attr_mask) { /*returns 0 on success, or the value of errno on failure (which indicates the failure reason)*/ - IBV_INT_CHECK_RET_ERRNO(ibv_modify_qp(qp, attr, attr_mask), 0, "ibv_modify_qp"); +static void ibvQpStateName(enum ibv_qp_state state, char* msg, const size_t len) { + switch (state) { + case (IBV_QPS_RESET): snprintf(msg, len, "RESET"); break; + case (IBV_QPS_INIT): snprintf(msg, len, "INIT"); break; + case (IBV_QPS_RTR): snprintf(msg, len, "RTR"); break; + case (IBV_QPS_RTS): snprintf(msg, len, "RTS"); break; + case (IBV_QPS_SQD): snprintf(msg, len, "SQD"); break; + case (IBV_QPS_SQE): snprintf(msg, len, "SQE"); break; + case (IBV_QPS_ERR): snprintf(msg, len, "ERR"); break; + case (IBV_QPS_UNKNOWN): snprintf(msg, len, "UNKNOWN"); break; + default: snprintf(msg, len, "NOT RECOGNIZED (%d)", state); break; + } } +#define QP_ATTR(attr, userAttr, userFlag, mask) ((userFlag & mask) ? (userAttr) : (attr)) + +static void ibvModifyQpLog(struct ibv_qp* qp, enum ibv_qp_state qpState, struct ibv_qp_attr* userAttr, int userFlag, char* msg, size_t msgLen) { + ncclResult_t res; + int portNum = -1, gidIndex = -1; + char localGidName[INET6_ADDRSTRLEN], remoteGidName[INET6_ADDRSTRLEN]; + const char *localGidRes = NULL, *remoteGidRes = NULL; + + char nextState[32], currState[32]; + ibvQpStateName(qp->state, currState, sizeof(currState)); + ibvQpStateName(qpState, nextState, sizeof(nextState)); + char devName[IBV_SYSFS_NAME_MAX] = ""; + snprintf(devName, sizeof(devName), "%s", (qp->pd->context) ? wrap_ibv_get_device_name(qp->pd->context->device) : "N/A"); + + struct ibv_qp_attr attr; + struct ibv_qp_init_attr init_attr; + int attr_mask = IBV_QP_PORT | IBV_QP_AV; + res = wrap_ibv_query_qp(qp, &attr, attr_mask, &init_attr); + struct ibv_qp_attr *qpAttr = (res == ncclSuccess) ? &attr : NULL; + + // port info, portAttr can be NULL if not given by the user and query_qp failed + struct ibv_qp_attr *portAttr = QP_ATTR(qpAttr, userAttr, userFlag, IBV_QP_PORT); + portNum = portAttr ? portAttr->port_num : -1; + + // address info, avAttr can be NULL if not given by the user and query_qp failed + struct ibv_qp_attr *avAttr = QP_ATTR(qpAttr, userAttr, userFlag, IBV_QP_AV); + if (avAttr && avAttr->ah_attr.is_global) { + union ibv_gid *remoteGid = &avAttr->ah_attr.grh.dgid; + remoteGidRes = ibvGetGidStr(remoteGid, remoteGidName, sizeof(remoteGidName)); + // we need pd->context to retrieve local GID, skip if not there + if (!qp->pd->context) goto print; + gidIndex = avAttr->ah_attr.grh.sgid_index; + union ibv_gid localGid; + NCCLCHECKGOTO(wrap_ibv_query_gid(qp->pd->context, portNum, gidIndex, &localGid), res, print); + localGidRes = ibvGetGidStr(&localGid, localGidName, sizeof(localGidName)); + } +print: + snprintf(msg, msgLen, "on dev %s:%d, curr state %s, next state %s, local GID index %d, local GID %s, remote GID %s", + devName, portNum, currState, nextState, gidIndex, localGidRes ? localGidName : "N/A", remoteGidRes ? remoteGidName : "N/A"); + return; +} + +ncclResult_t wrap_ibv_modify_qp(struct ibv_qp* qp, struct ibv_qp_attr* attr, int attr_mask) { + char qpMsg[1024]; + int ret = 0, attempts = 0; + int maxCnt = (int)ncclParamIbMQpRetryCnt() + 1; // number of attempts = number of retry + 1 + int timeOut = (int)ncclParamIbMQpRetryTimeout(); + do { + if (attempts > 0) { + unsigned int sleepTime = timeOut * attempts; + ibvModifyQpLog(qp, attr->qp_state, attr, attr_mask, qpMsg, sizeof(qpMsg)); + INFO(NCCL_NET, "Call to ibv_modify_qp failed with %d %s, %s, retrying %d/%d after %u msec of sleep", ret, strerror(ret), qpMsg, attempts, maxCnt, sleepTime); + // sleep before retrying + struct timespec tv = {.tv_sec = sleepTime / 1000, .tv_nsec = (sleepTime % 1000) * ((long)1e6)}; + nanosleep(&tv, NULL); + } + ret = ibv_modify_qp(qp, attr, attr_mask); + attempts++; + } while (IBV_MQP_RETRY_ERRNO_ALL(ret) && attempts < maxCnt); + if (ret != 0) { + ibvModifyQpLog(qp, attr->qp_state, attr, attr_mask, qpMsg, sizeof(qpMsg)); + WARN("Call to ibv_modify_qp failed with %d %s, %s", ret, strerror(ret), qpMsg); + return ncclSystemError; + } + return ncclSuccess; + } + ncclResult_t wrap_ibv_post_send(struct ibv_qp *qp, struct ibv_send_wr *wr, struct ibv_send_wr **bad_wr) { IBV_INT_CHECK_RET_ERRNO(qp->context->ops.post_send(qp, wr, bad_wr), 0, "ibv_post_send"); } diff --git a/src/p2p_plugin.c b/src/p2p_plugin.c index 744f2cc..62bbbd8 100644 --- a/src/p2p_plugin.c +++ b/src/p2p_plugin.c @@ -15,24 +15,32 @@ #include "p2p_plugin.h" #ifdef HAVE_UCX_PLUGIN +extern ncclNet_v9_t ucxPlugin_v9; extern ncclNet_v8_t ucxPlugin_v8; extern ncclNet_v7_t ucxPlugin_v7; extern ncclNet_v6_t ucxPlugin_v6; extern ncclNet_v5_t ucxPlugin_v5; + +extern ncclNet_v9_t ucxRmaPlugin_v9; extern ncclNet_v8_t ucxRmaPlugin_v8; extern ncclNet_v7_t ucxRmaPlugin_v7; extern ncclNet_v6_t ucxRmaPlugin_v6; extern ncclNet_v5_t ucxRmaPlugin_v5; + +extern ncclNet_v9_t ucxUctPlugin_v9; extern ncclNet_v8_t ucxUctPlugin_v8; extern ncclNet_v7_t ucxUctPlugin_v7; extern ncclNet_v6_t ucxUctPlugin_v6; extern ncclNet_v5_t ucxUctPlugin_v5; + +extern ncclNet_v9_t ucxUctRdPlugin_v9; extern ncclNet_v8_t ucxUctRdPlugin_v8; extern ncclNet_v7_t ucxUctRdPlugin_v7; extern ncclNet_v6_t ucxUctRdPlugin_v6; extern ncclNet_v5_t ucxUctRdPlugin_v5; #endif +extern ncclNet_v9_t ibPlugin_v9; extern ncclNet_v8_t ibPlugin_v8; extern ncclNet_v7_t ibPlugin_v7; extern ncclNet_v6_t ibPlugin_v6; @@ -40,7 +48,7 @@ extern ncclNet_v5_t ibPlugin_v5; pthread_mutex_t nccl_p2p_lock = PTHREAD_MUTEX_INITIALIZER; ncclDebugLogger_t pluginLogFunction; -static int ncclNMergedIbDevs = -1; + #ifdef HAVE_SHARP_PLUGIN extern int ncclNSharpDevs; @@ -52,11 +60,17 @@ extern int ncclIbRelaxedOrderingEnabled; NCCL_PARAM(SharpMaxComms, "SHARP_MAX_COMMS", 1); NCCL_PARAM(IbAdaptiveRouting, "IB_ADAPTIVE_ROUTING", -2); +ncclResult_t pluginInit_v9(ncclDebugLogger_t logFunction); ncclResult_t pluginInit_v8(ncclDebugLogger_t logFunction); ncclResult_t pluginInit_v7(ncclDebugLogger_t logFunction); ncclResult_t pluginInit_v6(ncclDebugLogger_t logFunction); ncclResult_t pluginInit_v5(ncclDebugLogger_t logFunction); +ncclNet_v9_t ncclNetPlugin_v9 = { + "NCCL RDMA Plugin v9", + pluginInit_v9, +}; + ncclNet_v8_t ncclNetPlugin_v8 = { "NCCL RDMA Plugin v8", pluginInit_v8, @@ -109,24 +123,28 @@ static void pluginSetup() switch (p2p_plugin) { #ifdef HAVE_UCX_PLUGIN case NCCL_P2P_UCX: + ncclNetPlugin_v9 = ucxPlugin_v9; ncclNetPlugin_v8 = ucxPlugin_v8; ncclNetPlugin_v7 = ucxPlugin_v7; ncclNetPlugin_v6 = ucxPlugin_v6; ncclNetPlugin_v5 = ucxPlugin_v5; break; case NCCL_P2P_UCX_RMA: + ncclNetPlugin_v9 = ucxRmaPlugin_v9; ncclNetPlugin_v8 = ucxRmaPlugin_v8; ncclNetPlugin_v7 = ucxRmaPlugin_v7; ncclNetPlugin_v6 = ucxRmaPlugin_v6; ncclNetPlugin_v5 = ucxRmaPlugin_v5; break; case NCCL_P2P_UCX_UCT: + ncclNetPlugin_v9 = ucxUctPlugin_v9; ncclNetPlugin_v8 = ucxUctPlugin_v8; ncclNetPlugin_v7 = ucxUctPlugin_v7; ncclNetPlugin_v6 = ucxUctPlugin_v6; ncclNetPlugin_v5 = ucxUctPlugin_v5; break; case NCCL_P2P_UCX_UCT_RD: + ncclNetPlugin_v9 = ucxUctRdPlugin_v9; ncclNetPlugin_v8 = ucxUctRdPlugin_v8; ncclNetPlugin_v7 = ucxUctRdPlugin_v7; ncclNetPlugin_v6 = ucxUctRdPlugin_v6; @@ -134,6 +152,7 @@ static void pluginSetup() break; #endif default: + ncclNetPlugin_v9 = ibPlugin_v9; ncclNetPlugin_v8 = ibPlugin_v8; ncclNetPlugin_v7 = ibPlugin_v7; ncclNetPlugin_v6 = ibPlugin_v6; @@ -143,6 +162,14 @@ static void pluginSetup() } +ncclResult_t pluginInit_v9(ncclDebugLogger_t logFunction) { + pluginLogFunction = logFunction; + pluginSetup(); + INFO(NCCL_INIT|NCCL_NET, "P2P plugin v9 %s", ncclNetPlugin_v9.name); + return ncclNetPlugin_v9.init(logFunction); +} + + ncclResult_t pluginInit_v8(ncclDebugLogger_t logFunction) { pluginLogFunction = logFunction; pluginSetup(); @@ -196,27 +223,25 @@ ncclResult_t nccl_p2p_gdr_support() static __thread int ibDmaSupportInitDev; // which device to init, must be thread local static void ibDmaBufSupportInitOnce(){ ncclResult_t res; - // select the appropriate - struct ncclIbMergedDev* mergedDev = ncclIbMergedDevs + ibDmaSupportInitDev; - // Test each real devices int dev_fail = 0; - for (int i = 0; i < mergedDev->ndevs; i++) { - int ibDev = mergedDev->devs[i]; - struct ibv_pd* pd; - struct ibv_context* ctx = ncclIbDevs[ibDev].context; - NCCLCHECKGOTO(wrap_ibv_alloc_pd(&pd, ctx), res, failure); - // Test kernel DMA-BUF support with a dummy call (fd=-1) - (void)wrap_direct_ibv_reg_dmabuf_mr(pd, 0ULL /*offset*/, 0ULL /*len*/, 0ULL /*iova*/, -1 /*fd*/, 0 /*flags*/); - // ibv_reg_dmabuf_mr() will fail with EOPNOTSUPP/EPROTONOSUPPORT if not supported (EBADF otherwise) - dev_fail |= (errno == EOPNOTSUPP) || (errno == EPROTONOSUPPORT); - NCCLCHECKGOTO(wrap_ibv_dealloc_pd(pd), res, failure); - // stop the search and goto failure - if (dev_fail) goto failure; - } - mergedDev->dmaBufSupported = 1; + + // This is a physical device, not a virtual one, so select from ibDevs + ncclIbMergedDev* mergedDev = ncclIbMergedDevs + ibDmaSupportInitDev; + ncclIbDev* ibDev = ncclIbDevs + mergedDev->vProps.devs[0]; + struct ibv_pd* pd; + struct ibv_context* ctx = ibDev->context; + NCCLCHECKGOTO(wrap_ibv_alloc_pd(&pd, ctx), res, failure); + // Test kernel DMA-BUF support with a dummy call (fd=-1) + (void)wrap_direct_ibv_reg_dmabuf_mr(pd, 0ULL /*offset*/, 0ULL /*len*/, 0ULL /*iova*/, -1 /*fd*/, 0 /*flags*/); + // ibv_reg_dmabuf_mr() will fail with EOPNOTSUPP/EPROTONOSUPPORT if not supported (EBADF otherwise) + dev_fail |= (errno == EOPNOTSUPP) || (errno == EPROTONOSUPPORT); + NCCLCHECKGOTO(wrap_ibv_dealloc_pd(pd), res, failure); + // stop the search and goto failure + if (dev_fail) goto failure; + ibDev->dmaBufSupported = 1; return; failure: - mergedDev->dmaBufSupported = -1; + ibDev->dmaBufSupported = -1; return; } @@ -233,33 +258,31 @@ ncclResult_t nccl_p2p_dmabuf_support(int dev) { // init the device only once ibDmaSupportInitDev = dev; pthread_once(&onces[dev].once, ibDmaBufSupportInitOnce); - - int dmaBufSupported = ncclIbMergedDevs[dev].dmaBufSupported; + ncclIbMergedDev* mergedDev = ncclIbMergedDevs + ibDmaSupportInitDev; + ncclIbDev* ibDev = ncclIbDevs + mergedDev->vProps.devs[0]; + int dmaBufSupported = ibDev->dmaBufSupported; if (dmaBufSupported == 1) return ncclSuccess; return ncclSystemError; } -ncclResult_t nccl_p2p_ib_get_properties(ncclIbDev *devs, int dev, ncclNetProperties_t* props) -{ - struct ncclIbMergedDev* mergedDev = ncclIbMergedDevs+dev; - props->name = mergedDev->devName; - props->speed = mergedDev->speed; - - // Take the rest of the properties from an arbitrary sub-device (should be the same) - struct ncclIbDev* ibDev = ncclIbDevs + mergedDev->devs[0]; +ncclResult_t ncclIbGetPhysProperties(int dev, ncclNetProperties_t* props) { + struct ncclIbDev* ibDev = ncclIbDevs + dev; + pthread_mutex_lock(&ibDev->lock); + props->name = ibDev->devName; + props->speed = ibDev->speed; props->pciPath = ibDev->pciPath; props->guid = ibDev->guid; - props->ptrSupport = NCCL_PTR_HOST; if (nccl_p2p_gdr_support() == ncclSuccess) { props->ptrSupport |= NCCL_PTR_CUDA; // GDR support via nv_peermem - INFO(NCCL_NET,"NET/IB : GPU Direct RDMA (nvidia-peermem) enabled for HCA %d '%s", dev, devs[dev].devName); + INFO(NCCL_NET,"NET/IB : GPU Direct RDMA (nvidia-peermem) enabled for HCA %d '%s", dev, ibDev->devName); } props->regIsGlobal = 1; + props->forceFlush = 0; if ((nccl_p2p_is_uct_plugin(p2p_plugin) || (p2p_plugin == NCCL_P2P_IB)) && nccl_p2p_dmabuf_support(dev) == ncclSuccess) { props->ptrSupport |= NCCL_PTR_DMABUF; // GDR support via DMA-BUF - INFO(NCCL_NET,"NET/IB : GPU Direct RDMA (DMABUF) enabled for HCA %d '%s", dev, devs[dev].devName); + INFO(NCCL_NET,"NET/IB : GPU Direct RDMA (DMABUF) enabled for HCA %d '%s", dev, ibDev->devName); } props->latency = 0; // Not set @@ -272,9 +295,25 @@ ncclResult_t nccl_p2p_ib_get_properties(ncclIbDev *devs, int dev, ncclNetPropert } else { props->maxRecvs = 1; } - props->netDeviceType = NCCL_NET_DEVICE_HOST; + props->netDeviceType = NCCL_NET_DEVICE_HOST; props->netDeviceVersion = NCCL_NET_DEVICE_INVALID_VERSION; + props->maxP2pBytes = NCCL_MAX_NET_SIZE_BYTES; + pthread_mutex_unlock(&ibDev->lock); + return ncclSuccess; +} +ncclResult_t nccl_p2p_ib_get_properties(ncclIbDev *devs, int ncclNMergedIbDevs, int dev, ncclNetProperties_t* props) +{ + if (dev >= ncclNMergedIbDevs) { + WARN("NET/IB : Requested properties for vNic %d, only %d vNics have been created", dev, ncclNMergedIbDevs); + return ncclInvalidUsage; + } + struct ncclIbMergedDev* mergedDev = ncclIbMergedDevs + dev; + // Take the rest of the properties from an arbitrary sub-device (should be the same) + NCCLCHECK(ncclIbGetPhysProperties(mergedDev->vProps.devs[0], props)); + props->name = mergedDev->devName; + props->speed = mergedDev->speed; + memcpy(&props->vProps, &mergedDev->vProps, sizeof(ncclNetVDeviceProps_t)); return ncclSuccess; } @@ -364,29 +403,68 @@ int devSharpCompare(const void *a, const void *b) else { return 1; } } -// Compare ncclIbDev[dev] to all stored mergedIbDevs -int ncclIbFindMatchingDev(int dev) { - for (int i = 0; i < ncclNMergedIbDevs; i++) { - if (ncclIbMergedDevs[i].ndevs < NCCL_IB_MAX_DEVS_PER_NIC) { - int compareDev = ncclIbMergedDevs[i].devs[0]; - if (strcmp(ncclIbDevs[dev].pciPath, ncclIbDevs[compareDev].pciPath) == 0 && - (ncclIbDevs[dev].guid == ncclIbDevs[compareDev].guid) && - (ncclIbDevs[dev].link == ncclIbDevs[compareDev].link)) { - TRACE(NCCL_NET, "NET/IB: Matched name1=%s pciPath1=%s guid1=0x%lx link1=%u name2=%s pciPath2=%s guid2=0x%lx link2=%u", - ncclIbDevs[dev].devName, ncclIbDevs[dev].pciPath, ncclIbDevs[dev].guid, ncclIbDevs[dev].link, - ncclIbDevs[compareDev].devName, ncclIbDevs[compareDev].pciPath, ncclIbDevs[compareDev].guid, ncclIbDevs[compareDev].link); - return i; - } +ncclResult_t ncclIbMakeVDeviceInternal(int* d, ncclNetVDeviceProps_t* props, int ncclNIbDevs, int *ncclNMergedIbDevs) { + if ((ncclParamIbMergeNics() == 0) && props->ndevs > 1) { + WARN("NET/IB : Trying to merge multiple devices together when NCCL_IB_MERGE_NICS=0. Please enable it or disable device merging in NCCL."); + return ncclInvalidUsage; + } + + if (props->ndevs == 0) { + WARN("NET/IB : Can't make virtual NIC with 0 devices"); + return ncclInvalidUsage; + } + + if (*ncclNMergedIbDevs == MAX_IB_VDEVS) { + WARN("NET/IB : Cannot allocate any more virtual devices (%d)", MAX_IB_VDEVS); + return ncclInvalidUsage; + } + + // Always count up number of merged devices + ncclIbMergedDev* mDev = ncclIbMergedDevs + *ncclNMergedIbDevs; + mDev->vProps.ndevs = 0; + mDev->speed = 0; + + for (int i = 0; i < props->ndevs; i++) { + ncclIbDev* dev = ncclIbDevs + props->devs[i]; + if (mDev->vProps.ndevs == NCCL_IB_MAX_DEVS_PER_NIC) return ncclInvalidUsage; + mDev->vProps.devs[mDev->vProps.ndevs++] = props->devs[i]; + mDev->speed += dev->speed; + // Each successive time, copy the name '+' new name + if (mDev->vProps.ndevs > 1) { + snprintf(mDev->devName + strlen(mDev->devName), sizeof(mDev->devName) - strlen(mDev->devName), "+%s", dev->devName); + // First time, copy the plain name + } else { + strncpy(mDev->devName, dev->devName, MAXNAMESIZE); + } + } + + // Check link layers + ncclIbDev* dev0 = ncclIbDevs + props->devs[0]; + for (int i = 1; i < props->ndevs; i++) { + if (props->devs[i] >= ncclNIbDevs) { + WARN("NET/IB : Cannot use physical device %d, max %d", props->devs[i], ncclNIbDevs); + return ncclInvalidUsage; + } + ncclIbDev* dev = ncclIbDevs + props->devs[i]; + if (dev->link != dev0->link) { + WARN("NET/IB : Trying to merge multiple devices together with different link_layer properties %s -> %d, %s -> %d. Try only selecting NICs with one type of link using NCCL_IB_HCA", + dev0->devName, dev0->link, dev->devName, dev->link); + return ncclInvalidUsage; } } - return ncclNMergedIbDevs; + *d = *ncclNMergedIbDevs; + (*ncclNMergedIbDevs)++; + + INFO(NCCL_NET, "NET/IB : Made virtual device [%d] name=%s speed=%d ndevs=%d", *d, mDev->devName, mDev->speed, mDev->vProps.ndevs); + return ncclSuccess; } -ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIbIfName, union ncclSocketAddress *ncclIbIfAddr, pthread_t *ncclIbAsyncThread, ncclDebugLogger_t logFunction) +ncclResult_t nccl_p2p_ib_init(int *nDevs, int *nmDevs, ncclIbDev *ncclIbDevs, char *ncclIbIfName, union ncclSocketAddress *ncclIbIfAddr, pthread_t *ncclIbAsyncThread, ncclDebugLogger_t logFunction) { - int ncclNIbDevs = *num_devs; ncclResult_t ret = ncclSuccess; + int ncclNIbDevs = *nDevs; + int ncclNMergedIbDevs = *nmDevs; pluginLogFunction = logFunction; if (ncclNIbDevs == -1) { for (int i=0; i< MAX_IB_DEVS; i++) @@ -416,11 +494,7 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb int nUserIfs = parseStringList(userIbEnv, userIfs, MAX_IB_DEVS); if (ncclSuccess != wrap_ibv_get_device_list(&devices, &nIbDevs)) { ret = ncclInternalError; goto fail; } - // Should NCCL merge multi-port devices into one? - int mergeNics; - mergeNics = ncclParamIbMergeNics(); -build_ib_list: - for (int d=0; dname); @@ -485,94 +559,45 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb PTHREADCHECKGOTO(pthread_detach(*ncclIbAsyncThread), "pthread_detach", ret, fail); // will not be pthread_join()'d } - int mergedDev = ncclNMergedIbDevs; - if (mergeNics) { - mergedDev = ncclIbFindMatchingDev(ncclNIbDevs); - } + // Add this plain physical device to the list of virtual devices + int vDev; + ncclNetVDeviceProps_t vProps = {0}; + vProps.ndevs = 1; + vProps.devs[0] = ncclNIbDevs; + NCCLCHECK(ncclIbMakeVDeviceInternal(&vDev, &vProps, ncclNIbDevs, &ncclNMergedIbDevs)); - // No matching dev found, create new mergedDev entry (it's okay if there's only one dev inside) - if (mergedDev == ncclNMergedIbDevs) { - // Set ndevs to 1, assign first ibDevN to the current IB device - ncclIbMergedDevs[mergedDev].ndevs = 1; - ncclIbMergedDevs[mergedDev].devs[0] = ncclNIbDevs; - ncclNMergedIbDevs++; - strncpy(ncclIbMergedDevs[mergedDev].devName, ncclIbDevs[ncclNIbDevs].devName, MAXNAMESIZE); - // Matching dev found, edit name - } else { - // Set next device in this array to the current IB device - int ndevs = ncclIbMergedDevs[mergedDev].ndevs; - ncclIbMergedDevs[mergedDev].devs[ndevs] = ncclNIbDevs; - ncclIbMergedDevs[mergedDev].ndevs++; - snprintf(ncclIbMergedDevs[mergedDev].devName + strlen(ncclIbMergedDevs[mergedDev].devName), MAXNAMESIZE+1, "+%s", ncclIbDevs[ncclNIbDevs].devName); - } - - // Aggregate speed - ncclIbMergedDevs[mergedDev].speed += ncclIbDevs[ncclNIbDevs].speed; ncclNIbDevs++; nPorts++; } if (nPorts == 0 && ncclSuccess != wrap_ibv_close_device(context)) { ret = ncclInternalError; goto fail; } } - // Detect if there are both multi-port and single-port NICs in the system. If so, disable port merging and build the list again - if (mergeNics) { - for (int d = 0; d < ncclNMergedIbDevs; d++) { - if (ncclIbMergedDevs[d].ndevs != ncclIbMergedDevs[0].ndevs) { - INFO(NCCL_NET, "Detected a mix of single and multiple-port NICs. Force-disabling NCCL_IB_MERGE_NICS"); - mergeNics = 0; - ncclNIbDevs = 0; - ncclNMergedIbDevs = 0; - memset(ncclIbMergedDevs, 0, sizeof(ncclIbMergedDevs)); - goto build_ib_list; - } - } - } - if (nIbDevs && (ncclSuccess != wrap_ibv_free_device_list(devices))) { ret = ncclInternalError; goto fail; }; } if (ncclNIbDevs == 0) { INFO(NCCL_INIT|NCCL_NET, "NET/IB : No device found."); - } else { - // sort devices on sharp capable - if (ncclNSharpDevs && (ncclNSharpDevs != ncclNIbDevs)) { - qsort(ncclIbDevs, ncclNIbDevs, sizeof(struct ncclIbDev), devSharpCompare); - } + } - char line[2048]; - line[0] = '\0'; - // Determine whether RELAXED_ORDERING is enabled and possible - ncclIbRelaxedOrderingEnabled = ncclIbRelaxedOrderingCapable(); - for (int d = 0; d < ncclNMergedIbDevs; d++) { - struct ncclIbMergedDev* mergedDev = ncclIbMergedDevs + d; - if (mergedDev->ndevs > 1) { - // Print out merged dev info - snprintf(line+strlen(line), 2047-strlen(line), " [%d]={", d); - for (int i = 0; i < mergedDev->ndevs; i++) { - int ibDev = mergedDev->devs[i]; - snprintf(line+strlen(line), 2047-strlen(line), "[%d] %s:%d/%s%s", ibDev, ncclIbDevs[ibDev].devName, - ncclIbDevs[ibDev].portNum, ncclIbDevs[ibDev].link == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE", - // Insert comma to delineate - i == (mergedDev->ndevs - 1) ? "" : ", "); - } - snprintf(line+strlen(line), 2047-strlen(line), "}"); - } else { - int ibDev = mergedDev->devs[0]; + // Print out all net devices to the user (in the same format as before) + char line[2048]; + line[0] = '\0'; + // Determine whether RELAXED_ORDERING is enabled and possible + ncclIbRelaxedOrderingEnabled = ncclIbRelaxedOrderingCapable(); + for (int d = 0; d < ncclNIbDevs; d++) { #ifdef HAVE_SHARP_PLUGIN - snprintf(line+strlen(line), 2047-strlen(line), " [%d]%s:%d/%s%s", ibDev, ncclIbDevs[ibDev].devName, - ncclIbDevs[ibDev].portNum, ncclIbDevs[ibDev].link == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE", - ncclIbDevs[ibDev].isSharpDev ? "/SHARP" : ""); + snprintf(line+strlen(line), sizeof(line)-strlen(line), " [%d]%s:%d/%s%s", d, ncclIbDevs[d].devName, + ncclIbDevs[d].portNum, ncclIbDevs[d].link == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE", + ncclIbDevs[d].isSharpDev ? "/SHARP" : ""); #else - snprintf(line+strlen(line), 2047-strlen(line), " [%d]%s:%d/%s", ibDev, ncclIbDevs[ibDev].devName, - ncclIbDevs[ibDev].portNum, ncclIbDevs[ibDev].link == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE"); + snprintf(line+strlen(line), sizeof(line)-strlen(line), " [%d]%s:%d/%s", d, ncclIbDevs[d].devName, + ncclIbDevs[d].portNum, ncclIbDevs[d].link == IBV_LINK_LAYER_INFINIBAND ? "IB" : "RoCE"); #endif - } - } - line[2047] = '\0'; - char addrline[SOCKET_NAME_MAXLEN+1]; - INFO(NCCL_INIT|NCCL_NET, "NET/IB : Using%s %s; OOB %s:%s", line, ncclIbRelaxedOrderingEnabled ? "[RO]" : "", - ncclIbIfName, ncclSocketToString(ncclIbIfAddr, addrline, 1)); } - *num_devs = ncclNMergedIbDevs; + char addrline[SOCKET_NAME_MAXLEN+1]; + INFO(NCCL_INIT|NCCL_NET, "NET/IB : Using%s %s; OOB %s:%s", line, ncclIbRelaxedOrderingEnabled ? "[RO]" : "", + ncclIbIfName, ncclSocketToString(ncclIbIfAddr, addrline, 1)); + *nDevs = ncclNIbDevs; + *nmDevs = ncclNMergedIbDevs; pthread_mutex_unlock(&nccl_p2p_lock); } exit: @@ -582,6 +607,16 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb goto exit; } +// Returns 0 if this is the path of two VFs of the same physical device +static int ncclIbMatchVfPath(char* path1, char* path2) { + // Merge multi-port NICs into the same PCI device + if (ncclParamIbMergeVfs()) { + return strncmp(path1, path2, strlen(path1)-4) == 0; + } else { + return strncmp(path1, path2, strlen(path1)-1) == 0; + } +} + ncclResult_t nccl_p2p_ib_pci_path(ncclIbDev *devs, int num_devs, char* dev_name, char** path, int* real_port) { char device_path[PATH_MAX]; @@ -590,14 +625,10 @@ ncclResult_t nccl_p2p_ib_pci_path(ncclIbDev *devs, int num_devs, char* dev_name, if (p == NULL) { WARN("Could not find real path of %s", *device_path); } else { - // Merge multi-port NICs into the same PCI device - p[strlen(p)-1] = '0'; - // Also merge virtual functions (VF) into the same device - if (ncclParamIbMergeVfs()) p[strlen(p)-3] = p[strlen(p)-4] = '0'; - // And keep the real port aside (the ibv port is always 1 on recent cards) + // Keep the real port aside (the ibv port is always 1 on recent cards) *real_port = 0; for (int d=0; d ((3L<maxCollBytes = NCCL_MAX_NET_SIZE_BYTES; +#else + props->maxCollBytes = (512*1024*1024L); //limited to 512M in SHARP 3.6 or older +#endif + return ncclSuccess; +} + ncclResult_t ncclSharpGetProperties_v8(int dev, ncclNetProperties_v8_t* props) { return ncclNetPlugin_v8.getProperties(dev, props); } @@ -247,7 +259,7 @@ ncclResult_t ncclSharpListen(int dev, void* opaqueHandle, void** listenComm) { ncclResult_t status; NCCLCHECK(ncclIbMalloc((void**)&lComm, sizeof(struct ncclSharpListenComm))); - status = ncclNetPlugin_v8.listen(dev, opaqueHandle, &lComm->listenCommP2P); + status = ncclNetPlugin_v9.listen(dev, opaqueHandle, &lComm->listenCommP2P); lComm->dev = dev; *listenComm = lComm; return status; @@ -397,7 +409,7 @@ ncclResult_t ncclSharpRegMrDmaBuf(void* collComm, void* data, size_t size, int t } TRACE(NCCL_INIT,"sharpRegAddr %lx size %ld handle %x", data, size, mh->mr); - NCCLCHECK(ncclNetPlugin_v8.regMrDmaBuf(cComm->recvComm, data, size, type, offset, fd, &mh->ncclIbMr)); + NCCLCHECK(ncclNetPlugin_v9.regMrDmaBuf(cComm->recvComm, data, size, type, offset, fd, &mh->ncclIbMr)); *mhandle = mh; return ncclSuccess; @@ -419,7 +431,7 @@ ncclResult_t ncclSharpRegMr(void* collComm, void* data, size_t size, int type, v } TRACE(NCCL_INIT,"sharpRegAddr %lx size %ld handle %x", data, size, mh->mr); - NCCLCHECK(ncclNetPlugin_v8.regMr(cComm->recvComm, data, size, type, &mh->ncclIbMr)); + NCCLCHECK(ncclNetPlugin_v9.regMr(cComm->recvComm, data, size, type, &mh->ncclIbMr)); *mhandle = mh; return ncclSuccess; @@ -437,7 +449,7 @@ ncclResult_t ncclSharpDeregMr(void* collComm, void* mhandle) { WARN("SHARP deregmr failed\n"); } - NCCLCHECK(ncclNetPlugin_v8.deregMr(cComm->recvComm, mh->ncclIbMr)); + NCCLCHECK(ncclNetPlugin_v9.deregMr(cComm->recvComm, mh->ncclIbMr)); free(mh); return ncclSuccess; @@ -459,7 +471,7 @@ ncclResult_t ncclSharpGetRequest(struct ncclSharpRequest* reqs, struct ncclSharp return ncclInternalError; } -ncclResult_t ncclSharpIallreduce(void* collComm, void* sendData, void* recvData, int count, +ncclResult_t ncclSharpIallreduce(void* collComm, void* sendData, void* recvData, size_t count, ncclDataType_t dataType, ncclRedOp_t redOp, void* sendMhandle, void* recvMhandle, void** request) { struct ncclSharpCollComm* cComm = (struct ncclSharpCollComm*)collComm; @@ -521,7 +533,14 @@ ncclResult_t ncclSharpIallreduce(void* collComm, void* sendData, void* recvData, return ncclSuccess; } -ncclResult_t ncclSharpIallgather(void* collComm, void* sendData, int nRecvParts, ncclNetSGE_v8_t* recvParts, +ncclResult_t ncclSharpIallreduce_v8(void* collComm, void* sendData, void* recvData, int count, + ncclDataType_t dataType, ncclRedOp_t redOp, void* sendMhandle, void* recvMhandle, void** request) { + return ncclSharpIallreduce(collComm, sendData, recvData, (size_t)count, dataType, redOp, sendMhandle, + recvMhandle, request); +} + + +ncclResult_t ncclSharpIallgather(void* collComm, void* sendData, int nRecvParts, ncclNetSGE_v9_t* recvParts, size_t bytesPerRank, size_t windowOffset, size_t windowBytes, void* sendMhandle, void** request) { @@ -571,7 +590,20 @@ ncclResult_t ncclSharpIallgather(void* collComm, void* sendData, int nRecvParts, return ncclSuccess; } -ncclResult_t ncclSharpIreducescatter(void* collComm, int nSendParts, ncclNetSGE_v8_t* sendParts, void* recvData, +ncclResult_t ncclSharpIallgather_v8(void* collComm, void* sendData, int nRecvParts, ncclNetSGE_v8_t* recvParts, + size_t bytesPerRank, size_t windowOffset, size_t windowBytes, + void* sendMhandle, void** request) { + ncclNetSGE_v9_t recvParts_v9; + recvParts_v9.mhandle = recvParts[0].mhandle; + recvParts_v9.address = recvParts[0].address; + recvParts_v9.size =(size_t)recvParts[0].size; + + return ncclSharpIallgather(collComm, sendData, nRecvParts, &recvParts_v9, + bytesPerRank, windowOffset, windowBytes, sendMhandle, request); +} + + +ncclResult_t ncclSharpIreducescatter(void* collComm, int nSendParts, ncclNetSGE_v9_t* sendParts, void* recvData, size_t bytesPerRank, size_t windowOffset, size_t windowBytes, ncclDataType_t dataType, ncclRedOp_t redOp, void* recvMhandle, void** request) @@ -640,6 +672,21 @@ ncclResult_t ncclSharpIreducescatter(void* collComm, int nSendParts, ncclNetSGE_ return ncclSuccess; } + ncclResult_t ncclSharpIreducescatter_v8(void* collComm, int nSendParts, ncclNetSGE_v8_t* sendParts, void* recvData, + size_t bytesPerRank, size_t windowOffset, size_t windowBytes, + ncclDataType_t dataType, ncclRedOp_t redOp, + void* recvMhandle, void** request) { + ncclNetSGE_v9_t sendParts_v9; + sendParts_v9.mhandle = sendParts[0].mhandle; + sendParts_v9.address = sendParts[0].address; + sendParts_v9.size = (size_t)sendParts[0].size; + + return ncclSharpIreducescatter(collComm, nSendParts, &sendParts_v9, + recvData, bytesPerRank, windowOffset, windowBytes, dataType, redOp, + recvMhandle, request); +} + + ncclResult_t ncclSharpIflush(void* collComm, void* data, int size, void* mhandle, void **request) { struct ncclSharpCollComm *cComm = (struct ncclSharpCollComm*)collComm; struct ncclSharpMemHandle *mh = (struct ncclSharpMemHandle *)mhandle; @@ -647,7 +694,7 @@ ncclResult_t ncclSharpIflush(void* collComm, void* data, int size, void* mhandle NCCLCHECK(ncclSharpGetRequest(cComm->reqs, &req)); req->requestType = NCCL_SHARP_REQ_IFLUSH; - ncclNetPlugin_v8.iflush(cComm->recvComm, 1, &data, &size, &mh->ncclIbMr, &req->sharpRequest); + ncclNetPlugin_v9.iflush(cComm->recvComm, 1, &data, &size, &mh->ncclIbMr, &req->sharpRequest); if (!req->sharpRequest) { *request = NULL; req->used = 0; @@ -662,7 +709,7 @@ ncclResult_t ncclSharpTest(void* request, int* done, int* size) { struct ncclSharpRequest* req = (struct ncclSharpRequest*)request; if (req->requestType == NCCL_SHARP_REQ_IFLUSH) { - ncclNetPlugin_v8.test(req->sharpRequest, done, size); + ncclNetPlugin_v9.test(req->sharpRequest, done, size); if (*done == 1) { req->used = 0; } @@ -696,8 +743,8 @@ ncclResult_t ncclSharpCloseColl(void* collComm) { sharp_coll_comm_destroy(cComm->sharpCollComm); sharp_coll_finalize(cComm->sharpCollContext); - NCCLCHECK(ncclNetPlugin_v8.closeRecv(cComm->recvComm)); - NCCLCHECK(ncclNetPlugin_v8.closeSend(cComm->sendComm)); + NCCLCHECK(ncclNetPlugin_v9.closeRecv(cComm->recvComm)); + NCCLCHECK(ncclNetPlugin_v9.closeSend(cComm->sendComm)); free(cComm); return ncclSuccess; } @@ -706,16 +753,16 @@ ncclResult_t ncclSharpCloseListen(void* listenComm) { struct ncclSharpListenComm *lComm = (struct ncclSharpListenComm*)listenComm; ncclResult_t status; - status = ncclNetPlugin_v8.closeListen(lComm->listenCommP2P); + status = ncclNetPlugin_v9.closeListen(lComm->listenCommP2P); free(listenComm); return status; } -ncclCollNet_v8_t ncclCollNetPlugin_v8 = { +ncclCollNet_v9_t ncclCollNetPlugin_v9 = { "SHARP", ncclSharpInit, ncclSharpDevices, - ncclSharpGetProperties_v8, + ncclSharpGetProperties_v9, ncclSharpListen, ncclSharpConnect, ncclSharpReduceSupport, @@ -728,6 +775,27 @@ ncclCollNet_v8_t ncclCollNetPlugin_v8 = { ncclSharpIflush, ncclSharpTest, ncclSharpCloseColl, + ncclSharpCloseListen, + NULL +}; + +ncclCollNet_v8_t ncclCollNetPlugin_v8 = { + "SHARP", + ncclSharpInit, + ncclSharpDevices, + ncclSharpGetProperties_v8, + ncclSharpListen, + ncclSharpConnect, + ncclSharpReduceSupport, + ncclSharpRegMr, + ncclSharpRegMrDmaBuf, + ncclSharpDeregMr, + ncclSharpIallreduce_v8, + ncclSharpIallgather_v8, + ncclSharpIreducescatter_v8, + ncclSharpIflush, + ncclSharpTest, + ncclSharpCloseColl, ncclSharpCloseListen }; @@ -742,7 +810,7 @@ ncclCollNet_v7_t ncclCollNetPlugin_v7 = { ncclSharpRegMr_v7, ncclSharpRegMrDmaBuf, ncclSharpDeregMr, - ncclSharpIallreduce, + ncclSharpIallreduce_v8, ncclSharpIflush, ncclSharpTest, ncclSharpCloseColl, @@ -760,7 +828,7 @@ ncclCollNet_v6_t ncclCollNetPlugin_v6 = { ncclSharpRegMr_v7, ncclSharpRegMrDmaBuf, ncclSharpDeregMr, - ncclSharpIallreduce, + ncclSharpIallreduce_v8, ncclSharpIflush, ncclSharpTest, ncclSharpCloseColl, @@ -777,7 +845,7 @@ ncclCollNet_v5_t ncclCollNetPlugin_v5 = { ncclSharpReduceSupport, ncclSharpRegMr_v7, ncclSharpDeregMr, - ncclSharpIallreduce, + ncclSharpIallreduce_v8, ncclSharpIflush, ncclSharpTest, ncclSharpCloseColl, diff --git a/src/socket.c b/src/socket.c index 7c4a6fb..89365e3 100755 --- a/src/socket.c +++ b/src/socket.c @@ -14,6 +14,18 @@ #include #include #include "param.h" +#include + +NCCL_PARAM(RetryCnt, "SOCKET_RETRY_CNT", 34); +NCCL_PARAM(RetryTimeOut, "SOCKET_RETRY_SLEEP_MSEC", 100); +static void msleep(unsigned int time_msec) { + const long c_1e6 = 1e6; + struct timespec tv = (struct timespec){ + .tv_sec = time_msec / 1000, + .tv_nsec = (time_msec % 1000) * c_1e6, + }; + nanosleep(&tv, NULL); +} static ncclResult_t socketProgressOpt(int op, struct ncclSocket* sock, void* ptr, int size, int* offset, int block, int* closed) { int bytes = 0; @@ -28,8 +40,13 @@ static ncclResult_t socketProgressOpt(int op, struct ncclSocket* sock, void* ptr return ncclSuccess; } if (bytes == -1) { + if ((op == NCCL_SOCKET_SEND && errno == EPIPE) || (op == NCCL_SOCKET_RECV && errno == ECONNRESET)) { + *closed = 1; + return ncclSuccess; + } if (errno != EINTR && errno != EWOULDBLOCK && errno != EAGAIN) { - WARN("socketProgressOpt: Call to recv from %s failed : %s", ncclSocketToString(&sock->addr, line, 1), strerror(errno)); + WARN("socketProgressOpt: Call to %s %s failed : %s", (op == NCCL_SOCKET_RECV ? "recv from" : "send to"), + ncclSocketToString(&sock->addr, line, 1), strerror(errno)); return ncclRemoteError; } else { bytes = 0; @@ -40,24 +57,29 @@ static ncclResult_t socketProgressOpt(int op, struct ncclSocket* sock, void* ptr INFO(NCCL_NET, "socketProgressOpt: abort called"); return ncclInternalError; } - } while (bytes > 0 && (*offset) < size); + } while (sock->asyncFlag == 0 && bytes > 0 && (*offset) < size); return ncclSuccess; } -static ncclResult_t socketProgress(int op, struct ncclSocket* sock, void* ptr, int size, int* offset) { +static ncclResult_t socketProgress(int op, struct ncclSocket* sock, void* ptr, int size, int* offset, int* pclosed) { int closed; NCCLCHECK(socketProgressOpt(op, sock, ptr, size, offset, 0 /*block*/, &closed)); if (closed) { - char line[SOCKET_NAME_MAXLEN+1]; - WARN("socketProgress: Connection closed by remote peer %s", ncclSocketToString(&sock->addr, line, 0)); - return ncclRemoteError; + if (pclosed) { + *pclosed = closed; + return ncclSuccess; + } else { + char line[SOCKET_NAME_MAXLEN+1]; + WARN("socketProgress: Connection closed by remote peer %s", ncclSocketToString(&sock->addr, line, 0)); + return ncclRemoteError; + } } return ncclSuccess; } static ncclResult_t socketWait(int op, struct ncclSocket* sock, void* ptr, int size, int* offset) { while (*offset < size) - NCCLCHECK(socketProgress(op, sock, ptr, size, offset)); + NCCLCHECK(socketProgress(op, sock, ptr, size, offset, NULL)); return ncclSuccess; } @@ -65,9 +87,9 @@ static ncclResult_t socketWait(int op, struct ncclSocket* sock, void* ptr, int s * * Output: "IPv4/IPv6 address" */ -const char *ncclSocketToString(union ncclSocketAddress *addr, char *buf, const int numericHostForm /*= 1*/) { +const char *ncclSocketToString(const union ncclSocketAddress *addr, char *buf, const int numericHostForm /*= 1*/) { if (buf == NULL || addr == NULL) return NULL; - struct sockaddr *saddr = &addr->sa; + const struct sockaddr *saddr = &addr->sa; if (saddr->sa_family != AF_INET && saddr->sa_family != AF_INET6) { buf[0]='\0'; return buf; } char host[NI_MAXHOST], service[NI_MAXSERV]; /* NI_NUMERICHOST: If set, then the numeric form of the hostname is returned. @@ -374,10 +396,9 @@ ncclResult_t ncclSocketListen(struct ncclSocket* sock) { if (socketToPort(&sock->addr)) { // Port is forced by env. Make sure we get the port. int opt = 1; -#if defined(SO_REUSEPORT) - SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt"); -#else SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)), "setsockopt"); +#if defined(SO_REUSEPORT) + SYSCHECK(setsockopt(sock->fd, SOL_SOCKET, SO_REUSEPORT, &opt, sizeof(opt)), "setsockopt"); #endif } @@ -416,6 +437,15 @@ static ncclResult_t socketTryAccept(struct ncclSocket* sock) { sock->fd = accept(sock->acceptFd, (struct sockaddr*)&sock->addr, &socklen); if (sock->fd != -1) { sock->state = ncclSocketStateAccepted; + } else if (errno == ENETDOWN || errno == EPROTO || errno == ENOPROTOOPT || errno == EHOSTDOWN || + errno == ENONET || errno == EHOSTUNREACH || errno == EOPNOTSUPP || errno == ENETUNREACH) { + /* per accept's man page, for linux sockets, the following errors might be already pending errors + * and should be considered as EAGAIN. To avoid infinite loop in case of errors, we use the retry count*/ + if (++sock->errorRetries == ncclParamRetryCnt()) { + WARN("socketTryAccept: exceeded error retry count (%d), %s", sock->errorRetries, strerror(errno)); + return ncclSystemError; + } + INFO(NCCL_ALL, "Call to accept returned %s, retrying", strerror(errno)); } else if (errno != EAGAIN && errno != EWOULDBLOCK) { WARN("socketTryAccept: Accept failed: %s", strerror(errno)); return ncclSystemError; @@ -423,72 +453,119 @@ static ncclResult_t socketTryAccept(struct ncclSocket* sock) { return ncclSuccess; } -static ncclResult_t socketFinalizeAccept(struct ncclSocket* sock) { - uint64_t magic; - enum ncclSocketType type; - int received = 0; +static ncclResult_t socketSetFlags(struct ncclSocket* sock) { const int one = 1; + /* Set socket as non-blocking if async or if we need to be able to abort */ + if ((sock->asyncFlag || sock->abortFlag) && sock->fd >= 0) { + int flags; + SYSCHECK(flags = fcntl(sock->fd, F_GETFL), "fcntl"); + SYSCHECK(fcntl(sock->fd, F_SETFL, flags | O_NONBLOCK), "fcntl"); + } SYSCHECK(setsockopt(sock->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)), "setsockopt"); + return ncclSuccess; +} - NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, sock, &magic, sizeof(magic), &received)); - if (received == 0) return ncclSuccess; - NCCLCHECK(socketWait(NCCL_SOCKET_RECV, sock, &magic, sizeof(magic), &received)); - if (magic != sock->magic) { - WARN("socketFinalizeAccept: wrong magic %lx != %lx", magic, sock->magic); - close(sock->fd); - sock->fd = -1; - // Ignore spurious connection and accept again - sock->state = ncclSocketStateAccepting; - return ncclSuccess; - } else { - received = 0; - NCCLCHECK(socketWait(NCCL_SOCKET_RECV, sock, &type, sizeof(type), &received)); - if (type != sock->type) { - WARN("socketFinalizeAccept: wrong type %d != %d", type, sock->type); - sock->state = ncclSocketStateError; +static ncclResult_t socketFinalizeAccept(struct ncclSocket* sock) { + uint64_t magic; + enum ncclSocketType type; + int received; + // once accepted, linux sockets do NOT inherit file status flags such as O_NONBLOCK (BSD ones do) + NCCLCHECK(socketSetFlags(sock)); + + if (sock->asyncFlag == 0 || sock->finalizeCounter < sizeof(magic)) { + if (sock->asyncFlag == 0) { + received = 0; + NCCLCHECK(socketWait(NCCL_SOCKET_RECV, sock, &magic, sizeof(magic), &received)); + } else { + received = sock->finalizeCounter; + NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, sock, sock->finalizeBuffer, sizeof(magic), &received, NULL)); + sock->finalizeCounter = received; + if (received < sizeof(magic)) return ncclSuccess; + memcpy(&magic, sock->finalizeBuffer, sizeof(magic)); + } + if (magic != sock->magic) { + WARN("socketFinalizeAccept: wrong magic %lx != %lx", magic, sock->magic); close(sock->fd); sock->fd = -1; - return ncclInternalError; - } else { - sock->state = ncclSocketStateReady; + // Ignore spurious connection and accept again + sock->state = ncclSocketStateAccepting; + return ncclSuccess; } } + if (sock->asyncFlag == 0) { + received = 0; + NCCLCHECK(socketWait(NCCL_SOCKET_RECV, sock, &type, sizeof(type), &received)); + } else { + received = sock->finalizeCounter - sizeof(magic); + NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, sock, sock->finalizeBuffer, sizeof(type), &received, NULL)); + sock->finalizeCounter = received + sizeof(magic); + if (received < sizeof(type)) return ncclSuccess; + memcpy(&type, sock->finalizeBuffer, sizeof(type)); + } + if (type != sock->type) { + WARN("socketFinalizeAccept: wrong type %d != %d", type, sock->type); + sock->state = ncclSocketStateError; + close(sock->fd); + sock->fd = -1; + return ncclInternalError; + } else { + sock->state = ncclSocketStateReady; + } return ncclSuccess; } -static ncclResult_t socketStartConnect(struct ncclSocket* sock) { - /* blocking/non-blocking connect() is determined by asyncFlag. */ - int ret = connect(sock->fd, &sock->addr.sa, sock->salen); - - if (ret == 0) { +static ncclResult_t socketResetFd(struct ncclSocket* sock) { + ncclResult_t ret = ncclSuccess; + int fd = -1; + SYSCHECKGOTO(fd = socket(sock->addr.sa.sa_family, SOCK_STREAM, 0), "socket", ret, cleanup); + // if sock->fd is valid, close it and reuse its number + if (sock->fd != -1) { + SYSCHECKGOTO(dup2(fd, sock->fd), "dup2", ret, cleanup); + SYSCHECKGOTO(close(fd), "close", ret, cleanup); + } else { + sock->fd = fd; + } + NCCLCHECKGOTO(socketSetFlags(sock), ret, exit); +exit: + return ret; +cleanup: + // cleanup fd, leave sock->fd untouched + if (fd != -1) { + (void)close(fd); + } + goto exit; +} +static ncclResult_t socketConnectCheck(struct ncclSocket* sock, int errCode, const char funcName[]) { + if (errCode == 0) { sock->state = ncclSocketStateConnected; - return ncclSuccess; - } else if (errno == EINPROGRESS) { + } else if (errCode == EINPROGRESS) { sock->state = ncclSocketStateConnectPolling; - return ncclSuccess; - } else if (errno == ECONNREFUSED) { - if (++sock->refusedRetries == RETRY_REFUSED_TIMES) { - sock->state = ncclSocketStateError; - WARN("socketStartConnect: exceeded retries (%d)", sock->refusedRetries); - return ncclRemoteError; - } - usleep(SLEEP_INT); - if (sock->refusedRetries % 1000 == 0) INFO(NCCL_ALL, "Call to connect returned %s, retrying", strerror(errno)); - return ncclSuccess; - } else if (errno == ETIMEDOUT) { - if (++sock->timedOutRetries == RETRY_TIMEDOUT_TIMES) { - sock->state = ncclSocketStateError; - WARN("socketStartConnect: exceeded timeouts (%d)", sock->timedOutRetries); - return ncclRemoteError; + } else if (errCode == ETIMEDOUT || errCode == EHOSTUNREACH || errCode == ECONNREFUSED) { + if (sock->customRetry == 0) { + if (sock->errorRetries++ == ncclParamRetryCnt()) { + sock->state = ncclSocketStateError; + WARN("%s: connect returned %s, exceeded error retry count (%d)", funcName, strerror(errCode), sock->errorRetries); + return ncclRemoteError; + } + unsigned int sleepTime = sock->errorRetries * ncclParamRetryTimeOut(); + INFO(NCCL_ALL, "%s: connect returned %s, retrying (%d/%ld) after sleep for %u msec", funcName, strerror(errCode), sock->errorRetries, ncclParamRetryCnt(), sleepTime); + msleep(sleepTime); } - usleep(SLEEP_INT); - return ncclSuccess; + NCCLCHECK(socketResetFd(sock)); /* in case of failure in connect, socket state is unspecified */ + sock->state = ncclSocketStateConnecting; } else { char line[SOCKET_NAME_MAXLEN+1]; sock->state = ncclSocketStateError; - WARN("socketStartConnect: Connect to %s failed : %s", ncclSocketToString(&sock->addr, line, 1), strerror(errno)); + WARN("%s: Connect to %s failed : %s", funcName, ncclSocketToString(&sock->addr, line, 1), strerror(errCode)); return ncclSystemError; } + return ncclSuccess; +} + +static ncclResult_t socketStartConnect(struct ncclSocket* sock) { + /* blocking/non-blocking connect() is determined by asyncFlag. */ + int ret = connect(sock->fd, &sock->addr.sa, sock->salen); + return socketConnectCheck(sock, (ret == -1) ? errno : 0, __func__); } static ncclResult_t socketPollConnect(struct ncclSocket* sock) { @@ -508,38 +585,12 @@ static ncclResult_t socketPollConnect(struct ncclSocket* sock) { return ncclRemoteError; } else if (ret != 1 || (pfd.revents & POLLOUT) == 0) { WARN("socketPollConnect poll() returned %d%s", ret, (pfd.revents & POLLOUT) ? "" : ", no POLLOUT events"); - return ncclSystemError;; + return ncclSystemError; } /* check socket status */ SYSCHECK(getsockopt(sock->fd, SOL_SOCKET, SO_ERROR, (void*)&ret, &rlen), "getsockopt"); - - if (ret == 0) { - sock->state = ncclSocketStateConnected; - } else if (ret == ECONNREFUSED) { - if (++sock->refusedRetries == RETRY_REFUSED_TIMES) { - sock->state = ncclSocketStateError; - WARN("socketPollConnect: exceeded retries (%d)", sock->refusedRetries); - return ncclRemoteError; - } - if (sock->refusedRetries % 1000 == 0) INFO(NCCL_ALL, "Call to connect returned %s, retrying", strerror(errno)); - usleep(SLEEP_INT); - sock->state = ncclSocketStateConnecting; - } else if (ret == ETIMEDOUT) { - if (++sock->timedOutRetries == RETRY_TIMEDOUT_TIMES) { - sock->state = ncclSocketStateError; - WARN("socketPollConnect: exceeded timeouts (%d)", sock->timedOutRetries); - return ncclRemoteError; - } - usleep(SLEEP_INT); - sock->state = ncclSocketStateConnecting; - } else if (ret != EINPROGRESS) { - sock->state = ncclSocketStateError; - char line[SOCKET_NAME_MAXLEN+1]; - WARN("socketPollConnect: Connect to %s returned %d(%s) errno %d(%s)", ncclSocketToString(&sock->addr, line, 1), ret, strerror(ret), errno, strerror(errno)); - return ncclSystemError; - } - return ncclSuccess; + return socketConnectCheck(sock, ret, __func__); } ncclResult_t ncclSocketPollConnect(struct ncclSocket* sock) { @@ -552,12 +603,24 @@ ncclResult_t ncclSocketPollConnect(struct ncclSocket* sock) { } static ncclResult_t socketFinalizeConnect(struct ncclSocket* sock) { - int sent = 0; - NCCLCHECK(socketProgress(NCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent)); - if (sent == 0) return ncclSuccess; - NCCLCHECK(socketWait(NCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent)); - sent = 0; - NCCLCHECK(socketWait(NCCL_SOCKET_SEND, sock, &sock->type, sizeof(sock->type), &sent)); + int sent; + if (sock->asyncFlag == 0) { + sent = 0; + NCCLCHECK(socketWait(NCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent)); + sent = 0; + NCCLCHECK(socketWait(NCCL_SOCKET_SEND, sock, &sock->type, sizeof(sock->type), &sent)); + } else { + if (sock->finalizeCounter < sizeof(sock->magic)) { + sent = sock->finalizeCounter; + NCCLCHECK(socketProgress(NCCL_SOCKET_SEND, sock, &sock->magic, sizeof(sock->magic), &sent, NULL)); + sock->finalizeCounter = sent; + if (sent < sizeof(sock->magic)) return ncclSuccess; + } + sent = sock->finalizeCounter - sizeof(sock->magic); + NCCLCHECK(socketProgress(NCCL_SOCKET_SEND, sock, &sock->type, sizeof(sock->type), &sent, NULL)); + sock->finalizeCounter = sent + sizeof(sock->magic); + if (sent < sizeof(sock->type)) return ncclSuccess; + } sock->state = ncclSocketStateReady; return ncclSuccess; } @@ -602,7 +665,6 @@ ncclResult_t ncclSocketConnect(struct ncclSocket* sock) { #ifdef ENABLE_TRACE char line[SOCKET_NAME_MAXLEN+1]; #endif - const int one = 1; if (sock == NULL) { WARN("ncclSocketConnect: pass NULL socket"); @@ -620,9 +682,8 @@ ncclResult_t ncclSocketConnect(struct ncclSocket* sock) { } TRACE(NCCL_INIT|NCCL_NET,"Connecting to socket %s", ncclSocketToString(&sock->addr, line, 1)); - SYSCHECK(setsockopt(sock->fd, IPPROTO_TCP, TCP_NODELAY, (char*)&one, sizeof(int)), "setsockopt"); - sock->state = ncclSocketStateConnecting; + sock->finalizeCounter = 0; do { NCCLCHECK(socketProgressState(sock)); } while (sock->asyncFlag == 0 && @@ -668,6 +729,7 @@ ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listen memcpy(sock, listenSock, sizeof(struct ncclSocket)); sock->acceptFd = listenSock->fd; sock->state = ncclSocketStateAccepting; + sock->finalizeCounter = 0; } do { @@ -698,12 +760,11 @@ ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listen return ret; } -ncclResult_t ncclSocketInit(struct ncclSocket* sock, union ncclSocketAddress* addr, uint64_t magic, enum ncclSocketType type, volatile uint32_t* abortFlag, int asyncFlag) { +ncclResult_t ncclSocketInit(struct ncclSocket* sock, const union ncclSocketAddress* addr, uint64_t magic, enum ncclSocketType type, volatile uint32_t* abortFlag, int asyncFlag, int customRetry) { ncclResult_t ret = ncclSuccess; if (sock == NULL) goto exit; - sock->timedOutRetries = 0; - sock->refusedRetries = 0; + sock->errorRetries = 0; sock->abortFlag = abortFlag; sock->asyncFlag = asyncFlag; sock->state = ncclSocketStateInitialized; @@ -711,6 +772,7 @@ ncclResult_t ncclSocketInit(struct ncclSocket* sock, union ncclSocketAddress* ad sock->type = type; sock->fd = -1; sock->acceptFd = -1; + sock->customRetry = customRetry; if (addr) { /* IPv4/IPv6 support */ @@ -722,28 +784,16 @@ ncclResult_t ncclSocketInit(struct ncclSocket* sock, union ncclSocketAddress* ad WARN("ncclSocketInit: connecting to address %s with family %d is neither AF_INET(%d) nor AF_INET6(%d)", ncclSocketToString(&sock->addr, line, 1), family, AF_INET, AF_INET6); ret = ncclInternalError; - goto fail; + goto exit; } sock->salen = (family == AF_INET) ? sizeof(struct sockaddr_in) : sizeof(struct sockaddr_in6); - /* Connect to a hostname / port */ - sock->fd = socket(family, SOCK_STREAM, 0); - if (sock->fd == -1) { - WARN("ncclSocketInit: Socket creation failed : %s", strerror(errno)); - ret = ncclSystemError; - goto fail; - } + // in case of error, we close the fd before returning as it's unclear if the caller has to use ncclSocketClose for cleanup + NCCLCHECKGOTO(socketResetFd(sock), ret, fail); } else { memset(&sock->addr, 0, sizeof(union ncclSocketAddress)); } - /* Set socket as non-blocking if async or if we need to be able to abort */ - if ((sock->asyncFlag || sock->abortFlag) && sock->fd >= 0) { - int flags; - SYSCHECKGOTO(flags = fcntl(sock->fd, F_GETFL), "fcntl", ret, fail); - SYSCHECKGOTO(fcntl(sock->fd, F_SETFL, flags | O_NONBLOCK), "fcntl", ret, fail); - } - exit: return ret; fail: @@ -754,12 +804,12 @@ ncclResult_t ncclSocketInit(struct ncclSocket* sock, union ncclSocketAddress* ad goto exit; } -ncclResult_t ncclSocketProgress(int op, struct ncclSocket* sock, void* ptr, int size, int* offset) { +ncclResult_t ncclSocketProgress(int op, struct ncclSocket* sock, void* ptr, int size, int* offset, int* closed) { if (sock == NULL) { WARN("ncclSocketProgress: pass NULL socket"); return ncclInvalidArgument; } - NCCLCHECK(socketProgress(op, sock, ptr, size, offset)); + NCCLCHECK(socketProgress(op, sock, ptr, size, offset, closed)); return ncclSuccess; } @@ -792,7 +842,7 @@ ncclResult_t ncclSocketRecv(struct ncclSocket* sock, void* ptr, int size) { WARN("ncclSocketRecv: pass NULL socket"); return ncclInvalidArgument; } - if (sock->state != ncclSocketStateReady) { + if (sock->state != ncclSocketStateReady && sock->state != ncclSocketStateTerminating) { WARN("ncclSocketRecv: socket state (%d) is not ready", sock->state); return ncclInternalError; } @@ -800,6 +850,24 @@ ncclResult_t ncclSocketRecv(struct ncclSocket* sock, void* ptr, int size) { return ncclSuccess; } +ncclResult_t ncclSocketSendRecv(struct ncclSocket* sendSock, void* sendPtr, int sendSize, struct ncclSocket* recvSock, void* recvPtr, int recvSize) { + int sendOffset = 0, recvOffset = 0; + if (sendSock == NULL || recvSock == NULL) { + WARN("ncclSocketSendRecv: invalid socket %p/%p", sendSock, recvSock); + return ncclInternalError; + } + if (sendSock->state != ncclSocketStateReady || + (recvSock->state != ncclSocketStateReady && recvSock->state != ncclSocketStateTerminating)) { + WARN("ncclSocketSendRecv: socket state (%d/%d) is not ready", sendSock->state, recvSock->state); + return ncclInternalError; + } + while (sendOffset < sendSize || recvOffset < recvSize) { + if (sendOffset < sendSize) NCCLCHECK(socketProgress(NCCL_SOCKET_SEND, sendSock, sendPtr, sendSize, &sendOffset, NULL)); + if (recvOffset < recvSize) NCCLCHECK(socketProgress(NCCL_SOCKET_RECV, recvSock, recvPtr, recvSize, &recvOffset, NULL)); + } + return ncclSuccess; +} + // Receive or detect connection closed ncclResult_t ncclSocketTryRecv(struct ncclSocket* sock, void* ptr, int size, int* closed, bool blocking) { int offset = 0; @@ -832,9 +900,20 @@ ncclResult_t ncclSocketTryRecv(struct ncclSocket* sock, void* ptr, int size, int return ncclSuccess; } -ncclResult_t ncclSocketClose(struct ncclSocket* sock) { +// Make it possible to close just one part of a socket. +ncclResult_t ncclSocketShutdown(struct ncclSocket* sock, int how) { if (sock != NULL) { if (sock->fd >= 0) { + shutdown(sock->fd, how); + } + sock->state = ncclSocketStateTerminating; + } + return ncclSuccess; +} + +ncclResult_t ncclSocketClose(struct ncclSocket* sock) { + if (sock != NULL) { + if (sock->state > ncclSocketStateNone && sock->state < ncclSocketStateNum && sock->fd >= 0) { /* shutdown() is needed to send FIN packet to proxy thread; shutdown() is not affected * by refcount of fd, but close() is. close() won't close a fd and send FIN packet if * the fd is duplicated (e.g. fork()). So shutdown() guarantees the correct and graceful diff --git a/src/ucx_plugin.c b/src/ucx_plugin.c index 88c856b..1b05841 100644 --- a/src/ucx_plugin.c +++ b/src/ucx_plugin.c @@ -42,6 +42,7 @@ static const ucp_tag_t tag = 0x8a000000; static const ucp_tag_t tag_mask = (uint64_t)(-1); static int ncclNIbDevs = -1; +static int ncclNMergedIbDevs = -1; enum ncclUCXCommState { ncclUCXCommStateStart = 0, @@ -69,9 +70,30 @@ ncclResult_t nccl_ucx_devices(int* ndev) { ncclResult_t nccl_ucx_get_properties(int dev, ncclNetProperties_t* props) { - return nccl_p2p_ib_get_properties(ncclIbDevs, dev, props); + return nccl_p2p_ib_get_properties(ncclIbDevs, ncclNMergedIbDevs, dev, props); } +ncclResult_t nccl_ucx_get_properties_v8(int dev, ncclNetProperties_v8_t* props_v8) +{ + ncclNetProperties_t props; + ncclResult_t ret = nccl_ucx_get_properties(dev, &props); + if (ret != ncclSuccess) return ret; + props_v8->name = props.name; + props_v8->pciPath = props.pciPath; + props_v8->guid = props.guid; + props_v8->ptrSupport = props.ptrSupport; + props_v8->regIsGlobal = props.regIsGlobal; + props_v8->speed = props.speed; + props_v8->latency = props.latency; + props_v8->port = props.port; + props_v8->maxComms = props.maxComms; + props_v8->maxRecvs = props.maxRecvs; + props_v8->netDeviceType = props.netDeviceType; + props_v8->netDeviceVersion = props.netDeviceVersion; + return ncclSuccess; +} + + ncclResult_t nccl_ucx_get_properties_v7(int dev, ncclNetProperties_v7_t* props_v7) { ncclNetProperties_t props; @@ -469,7 +491,7 @@ ncclResult_t nccl_ucx_init(ncclDebugLogger_t logFunction) { worker_tags[i] = tag; } - return nccl_p2p_ib_init(&ncclNIbDevs, ncclIbDevs, if_name, + return nccl_p2p_ib_init(&ncclNIbDevs, &ncclNMergedIbDevs, ncclIbDevs, if_name, &nccl_ucx_if_addr, NULL, logFunction); } @@ -480,7 +502,7 @@ ncclResult_t nccl_ucx_listen(int dev, void *handle, void **listen_comm) { NCCL_STATIC_ASSERT(sizeof(ucx_listen_handle_t) < NCCL_NET_HANDLE_MAXSIZE, "UCX listen handle size too large"); my_handle->magic = NCCL_SOCKET_MAGIC; - NCCLCHECK(ncclSocketInit(&comm->sock, &nccl_ucx_if_addr, my_handle->magic, ncclSocketTypeNetIb, NULL, 1)); + NCCLCHECK(ncclSocketInit(&comm->sock, &nccl_ucx_if_addr, my_handle->magic, ncclSocketTypeNetIb, NULL, 1, 0)); NCCLCHECK(ncclSocketListen(&comm->sock)); NCCLCHECK(ncclSocketGetAddr(&comm->sock, &my_handle->connectAddr)); NCCLCHECK(ucx_get_ctx_and_worker(dev, &comm->ctx, &comm->ucx_worker, &comm->tag)); @@ -516,7 +538,7 @@ ncclResult_t nccl_ucx_connect(int dev, void *handle, void **send_comm, ncclNetDe if (stage->state == ncclUCXCommStateConnect) goto ucx_connect_check; NCCLCHECK(ncclIbMalloc((void**)&comm, sizeof(ucx_comm_t))); - NCCLCHECK(ncclSocketInit(&comm->sock, &recv_handle->connectAddr, recv_handle->magic, ncclSocketTypeNetIb, NULL, 1)); + NCCLCHECK(ncclSocketInit(&comm->sock, &recv_handle->connectAddr, recv_handle->magic, ncclSocketTypeNetIb, NULL, 1, 0)); stage->comm = comm; stage->state = ncclUCXCommStateConnect; NCCLCHECK(ncclSocketConnect(&comm->sock)); @@ -567,7 +589,7 @@ ncclResult_t nccl_ucx_accept(void *listen_comm, void **recv_comm, ncclNetDeviceH l_comm->sock.asyncFlag = 1; r_comm->sock.asyncFlag = 1; - NCCLCHECK(ncclSocketInit(&r_comm->sock, NULL, NCCL_SOCKET_MAGIC, ncclSocketTypeUnknown, NULL, 0)); + NCCLCHECK(ncclSocketInit(&r_comm->sock, NULL, NCCL_SOCKET_MAGIC, ncclSocketTypeUnknown, NULL, 0, 0)); NCCLCHECK(ncclSocketAccept(&r_comm->sock, &l_comm->sock)); ucx_accept_check: NCCLCHECK(ncclSocketReady(&r_comm->sock, &ready)); @@ -824,7 +846,7 @@ static ucp_tag_t nccl_ucx_ucp_tag(ucp_tag_t comm_tag, uint64_t tag) return comm_tag + (tag << 32); } -static ncclResult_t nccl_ucx_isend(void *send_comm, void *data, int size, +static ncclResult_t nccl_ucx_isend(void *send_comm, void *data, size_t size, int tag, void *mhandle, void **request) { ucx_comm_t *comm = (ucx_comm_t *)send_comm; @@ -872,8 +894,12 @@ static ncclResult_t nccl_ucx_isend(void *send_comm, void *data, int size, return ncclSuccess; } +ncclResult_t nccl_ucx_isend_v8(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) { + return nccl_ucx_isend(sendComm, data, (size_t)size, tag, mhandle, request); +} + static ncclResult_t nccl_ucx_irecv(void *recv_comm, int n, void **data, - int *sizes, int *tags, void **mhandle, + size_t *sizes, int *tags, void **mhandle, void **request) { ucx_comm_t *comm = (ucx_comm_t*)recv_comm; @@ -931,6 +957,12 @@ static ncclResult_t nccl_ucx_irecv(void *recv_comm, int n, void **data, return ncclSuccess; } +ncclResult_t nccl_ucx_irecv_v8(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) { + size_t sizes_sizet[NCCL_NET_IB_MAX_RECVS]; + for (int i=0; iname = props.name; + props_v8->pciPath = props.pciPath; + props_v8->guid = props.guid; + props_v8->ptrSupport = props.ptrSupport; + props_v8->regIsGlobal = props.regIsGlobal; + props_v8->speed = props.speed; + props_v8->latency = props.latency; + props_v8->port = props.port; + props_v8->maxComms = props.maxComms; + props_v8->maxRecvs = props.maxRecvs; + props_v8->netDeviceType = props.netDeviceType; + props_v8->netDeviceVersion = props.netDeviceVersion; + return ncclSuccess; } ncclResult_t nccl_ucx_rma_get_properties_v7(int dev, ncclNetProperties_v7_t* props_v7) @@ -386,7 +407,7 @@ ncclResult_t nccl_ucx_rma_init(ncclDebugLogger_t logFunction) { char *config_env; if (ncclParamUCXRMADisable()) return ncclInternalError; - NCCLCHECK(nccl_p2p_ib_init(&ncclNIbDevs, ncclIbDevs, if_name, &nccl_ucx_if_addr, + NCCLCHECK(nccl_p2p_ib_init(&ncclNIbDevs, &ncclNMergedIbDevs, ncclIbDevs, if_name, &nccl_ucx_if_addr, NULL, logFunction)); if (strlen(nccl_ucx_rma_tls) == 0) { @@ -422,7 +443,7 @@ ncclResult_t nccl_ucx_rma_listen(int dev, void *handle, void **listen_comm) my_handle->magic = NCCL_SOCKET_MAGIC; NCCLCHECK(ncclIbMalloc((void**)&comm, sizeof(nccl_ucx_rma_listen_comm_t))); - NCCLCHECK(ncclSocketInit(&comm->sock, &nccl_ucx_if_addr, my_handle->magic, ncclSocketTypeNetIb, NULL, 1)); + NCCLCHECK(ncclSocketInit(&comm->sock, &nccl_ucx_if_addr, my_handle->magic, ncclSocketTypeNetIb, NULL, 1, 0)); NCCLCHECK(ncclSocketListen(&comm->sock)); NCCLCHECK(ncclSocketGetAddr(&comm->sock, &my_handle->connectAddr)); @@ -471,7 +492,7 @@ ncclResult_t nccl_ucx_rma_connect(int dev, void *handle, void **send_comm, ncclN if (stage->state == ncclUCXCommStateConnect) goto ucx_connect_check; NCCLCHECK(ncclIbMalloc((void**)&comm, sizeof(*comm))); - NCCLCHECK(ncclSocketInit(&comm->super.sock, &recv_handle->connectAddr, recv_handle->magic, ncclSocketTypeNetIb, NULL, 1)); + NCCLCHECK(ncclSocketInit(&comm->super.sock, &recv_handle->connectAddr, recv_handle->magic, ncclSocketTypeNetIb, NULL, 1, 0)); stage->comm = comm; stage->state = ncclUCXCommStateConnect; NCCLCHECK(ncclSocketConnect(&comm->super.sock)); @@ -546,7 +567,7 @@ static ncclResult_t nccl_ucx_rma_init_ep(struct ncclSocket *sock, ucp_worker_h w NCCLCHECK(ncclSocketRecv(sock, &peer_addr_len, sizeof(size_t))); } else { NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, sock, &peer_addr_len, - sizeof(size_t), &bytes)); + sizeof(size_t), &bytes, NULL)); if (bytes == 0) { ep = NULL; return ncclSuccess; @@ -582,7 +603,7 @@ ncclResult_t nccl_ucx_rma_accept(void *listen_comm, void **recv_comm, ncclNetDev l_comm->sock.asyncFlag = 1; r_comm->super.sock.asyncFlag = 1; - NCCLCHECK(ncclSocketInit(&r_comm->super.sock, NULL, NCCL_SOCKET_MAGIC, ncclSocketTypeUnknown, NULL, 0)); + NCCLCHECK(ncclSocketInit(&r_comm->super.sock, NULL, NCCL_SOCKET_MAGIC, ncclSocketTypeUnknown, NULL, 0, 0)); NCCLCHECK(ncclSocketAccept(&r_comm->super.sock, &l_comm->sock)); ucx_accept_check: @@ -844,7 +865,7 @@ static ncclResult_t nccl_ucx_rma_recv_check(nccl_ucx_rma_recv_comm_t *comm) if (comm->super.ready == NCCL_UCX_RMA_RCOMM_WAIT_SCOMM) { NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->super.sock, &rem_comm_state, - sizeof(int), &bytes)); + sizeof(int), &bytes, NULL)); if (bytes == 0) { return ncclSuccess; } @@ -878,7 +899,7 @@ static void nccl_ucx_rma_put_isend_cb(void *request, ucs_status_t status, void * return; } -ncclResult_t nccl_ucx_rma_isend(void *send_comm, void *data, int size, int tag, +ncclResult_t nccl_ucx_rma_isend(void *send_comm, void *data, size_t size, int tag, void *mhandle, void **request) { nccl_ucx_rma_send_comm_t *comm = (nccl_ucx_rma_send_comm_t*)send_comm; @@ -965,6 +986,11 @@ ncclResult_t nccl_ucx_rma_isend(void *send_comm, void *data, int size, int tag, return ncclSuccess; } +ncclResult_t nccl_ucx_rma_isend_v8(void* sendComm, void* data, int size, int tag, void* mhandle, void** request) { + return nccl_ucx_rma_isend(sendComm, data, (size_t)size, tag, mhandle, request); +} + + static void nccl_ucx_rma_dummy_am_cb(void *request, ucs_status_t status) { return; @@ -1014,7 +1040,7 @@ ncclResult_t nccl_ucx_rma_post_fifo(nccl_ucx_rma_recv_comm_t *comm, return ncclSuccess; } -ncclResult_t nccl_ucx_rma_irecv(void *recv_comm, int n, void **data,int *tags, int *sizes, +ncclResult_t nccl_ucx_rma_irecv(void *recv_comm, int n, void **data, size_t *sizes, int *tags, void **mhandle, void **request) { nccl_ucx_rma_recv_comm_t *comm = (nccl_ucx_rma_recv_comm_t*)recv_comm; @@ -1043,6 +1069,12 @@ ncclResult_t nccl_ucx_rma_irecv(void *recv_comm, int n, void **data,int *tags, i return ncclSuccess; } +ncclResult_t nccl_ucx_rma_irecv_v8(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request) { + size_t sizes_sizet[NCCL_NET_IB_MAX_RECVS]; + for (int i=0; isock, &context.if_addr, NCCL_UCT_LISTEN_HANDLE_MAGIC, ncclSocketTypeNetIb, - NULL, 1)); + NULL, 1, 0)); NCCLCHECK(ncclSocketListen(&l_comm->sock)); NCCLCHECK(ncclSocketGetAddr(&l_comm->sock, &addr)); @@ -544,7 +545,7 @@ ncclResult_t nccl_uct_connect(int dev, void *listen_handle, void **send_comm, NCCLCHECK(context.ops.comm_alloc(&comm)); NCCLCHECK(context.ops.comm_init(comm, &context, NULL, dev, handle->comm)); NCCLCHECK(ncclSocketInit(&comm->sock, &handle->listener.addr, handle->magic, - ncclSocketTypeNetIb, NULL, 1)); + ncclSocketTypeNetIb, NULL, 1, 0)); NCCLCHECK(ncclSocketConnect(&comm->sock)); stage->comm = comm; @@ -568,7 +569,7 @@ ncclResult_t nccl_uct_connect(int dev, void *listen_handle, void **send_comm, case NCCL_UCT_RECEIVE_ADDR: NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->sock, &comm->remote.addr, sizeof(comm->remote.addr), - &stage->offset)); + &stage->offset, NULL)); if (stage->offset != sizeof(comm->remote.addr)) { return ncclSuccess; /* In progress */ } @@ -608,7 +609,7 @@ ncclResult_t nccl_uct_accept(void *listen_comm, void **recv_comm, comm = l_comm->comm; NCCLCHECK(ncclSocketInit(&comm->sock, NULL, NCCL_SOCKET_MAGIC, - ncclSocketTypeUnknown, NULL, 0)); + ncclSocketTypeUnknown, NULL, 0, 0)); NCCLCHECK(ncclSocketAccept(&comm->sock, &l_comm->sock)); NCCLCHECK(context.ops.comm_init(comm, l_comm->context, l_comm->uct_worker, l_comm->dev, NULL)); @@ -633,7 +634,7 @@ ncclResult_t nccl_uct_accept(void *listen_comm, void **recv_comm, case NCCL_UCT_RECEIVE_REMOTE: NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->sock, &comm->remote, - sizeof(comm->remote), &stage->offset)); + sizeof(comm->remote), &stage->offset, NULL)); if (stage->offset != sizeof(comm->remote)) { return ncclSuccess; } @@ -647,7 +648,7 @@ ncclResult_t nccl_uct_accept(void *listen_comm, void **recv_comm, case NCCL_UCT_RX_READY: NCCLCHECK(ncclSocketProgress(NCCL_SOCKET_RECV, &comm->sock, &stage->ready, - sizeof(stage->ready), &stage->offset)); + sizeof(stage->ready), &stage->offset, NULL)); if (stage->offset != sizeof(ready)) { return ncclSuccess; } @@ -807,7 +808,29 @@ void nccl_uct_comm_deinit(nccl_uct_comm_t *comm) { } ncclResult_t nccl_uct_get_properties(int dev, ncclNetProperties_t *props) { - return nccl_p2p_ib_get_properties(ncclIbDevs, dev, props); + return nccl_p2p_ib_get_properties(ncclIbDevs, context.merge_dev_count, dev, props); +} + +ncclResult_t nccl_uct_get_properties_v8(int dev, ncclNetProperties_v8_t *props_v8) { + ncclNetProperties_t props; + ncclResult_t ret = nccl_uct_get_properties(dev, &props); + if (ret != ncclSuccess) { + return ret; + } + + props_v8->name = props.name; + props_v8->pciPath = props.pciPath; + props_v8->guid = props.guid; + props_v8->ptrSupport = props.ptrSupport; + props_v8->regIsGlobal = props.regIsGlobal; + props_v8->speed = props.speed; + props_v8->latency = props.latency; + props_v8->port = props.port; + props_v8->maxComms = props.maxComms; + props_v8->maxRecvs = props.maxRecvs; + props_v8->netDeviceType = props.netDeviceType; + props_v8->netDeviceVersion = props.netDeviceVersion; + return ncclSuccess; } ncclResult_t nccl_uct_get_properties_v7(int dev, diff --git a/src/ucx_uct_plugin.c b/src/ucx_uct_plugin.c index f03361c..98f3495 100644 --- a/src/ucx_uct_plugin.c +++ b/src/ucx_uct_plugin.c @@ -107,7 +107,7 @@ static size_t nccl_uct_rdesc_size(int n) { /* Prepare a receive descriptor from irecv()/iflush() side */ static void nccl_uct_rdesc_set(nccl_uct_rdesc_t *rdesc, uint64_t id, int n, - void **data, int *sizes, int *tags, + void **data, size_t *sizes, int *tags, nccl_uct_memh_t **uct_memh) { nccl_uct_rdesc_hdr_t *desc = &rdesc->desc; int i; @@ -238,7 +238,7 @@ static ncclResult_t nccl_uct_wr_init(ncclDebugLogger_t logFunction) { context.am_short_size = nccl_uct_rdesc_size(NCCL_UCX_UCT_MAX_RECVS); context.rkey_size = sizeof(((nccl_uct_chunk_t*)0)->rkey); - return nccl_p2p_ib_init(&context.dev_count, ncclIbDevs, context.if_name, + return nccl_p2p_ib_init(&context.dev_count, &context.merge_dev_count, ncclIbDevs, context.if_name, &context.if_addr, NULL, logFunction); } @@ -315,7 +315,7 @@ static ncclResult_t nccl_uct_send(nccl_uct_wr_comm_t *comm, void *data, return ncclSuccess; } -static ncclResult_t nccl_uct_wr_isend(void *send_comm, void *data, int size, +static ncclResult_t nccl_uct_wr_isend(void *send_comm, void *data, size_t size, int tag, void *mhandle, void **request) { nccl_uct_wr_comm_t *comm = nccl_uct_wr_comm_get(send_comm); nccl_uct_rdesc_t *rdesc; @@ -338,8 +338,14 @@ static ncclResult_t nccl_uct_wr_isend(void *send_comm, void *data, int size, return ncclSuccess; } +static ncclResult_t nccl_uct_wr_isend_v8(void *send_comm, void *data, int size, + int tag, void *mhandle, void **request) { + return nccl_uct_wr_isend(send_comm, data, (size_t)size, tag, mhandle, request); +} + + static ncclResult_t nccl_uct_wr_irecv(void *recv_comm, int n, void **data, - int *sizes, int *tags, void **mhandles, + size_t *sizes, int *tags, void **mhandles, void **request) { nccl_uct_wr_comm_t *comm = nccl_uct_wr_comm_get(recv_comm); nccl_uct_memh_t **uct_memh = (nccl_uct_memh_t**)mhandles; @@ -369,6 +375,14 @@ static ncclResult_t nccl_uct_wr_irecv(void *recv_comm, int n, void **data, return ncclSuccess; } +static ncclResult_t nccl_uct_wr_irecv_v8(void *recv_comm, int n, void **data, + int *sizes, int *tags, void **mhandles, + void **request) { + size_t sizes_sizet[NCCL_NET_IB_MAX_RECVS]; + for (int i=0; icomm->req_count--; } -static ncclResult_t nccl_uct_rd_isend(void *send_comm, void *data, int size, +static ncclResult_t nccl_uct_rd_isend(void *send_comm, void *data, size_t size, int tag, void *mhandle, void **request) { nccl_uct_rd_comm_t *comm = nccl_uct_rd_comm_get(send_comm); @@ -302,8 +302,13 @@ static ncclResult_t nccl_uct_rd_isend(void *send_comm, void *data, int size, return ncclSuccess; } +static ncclResult_t nccl_uct_rd_isend_v8(void *send_comm, void *data, int size, + int tag, void *mhandle, void **request) { + return nccl_uct_rd_isend(send_comm, data, (size_t)size, tag, mhandle, request); +} + static ncclResult_t nccl_uct_rd_irecv(void *recv_comm, int n, void **data, - int *sizes, int *tags, void **mhandles, + size_t *sizes, int *tags, void **mhandles, void **request) { nccl_uct_rd_comm_t *comm = nccl_uct_rd_comm_get(recv_comm); nccl_uct_memh_t **uct_memh = (nccl_uct_memh_t**)mhandles; @@ -348,6 +353,14 @@ static ncclResult_t nccl_uct_rd_irecv(void *recv_comm, int n, void **data, return ncclSuccess; } +static ncclResult_t nccl_uct_rd_irecv_v8(void *recv_comm, int n, void **data, + int *sizes, int *tags, void **mhandles, + void **request) { + size_t sizes_sizet[NCCL_NET_IB_MAX_RECVS]; + for (int i=0; i