Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update #162

Merged
merged 1 commit into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 104 additions & 46 deletions src/ib_plugin.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pthread_mutex_t ncclIbLock = PTHREAD_MUTEX_INITIALIZER;
int ncclIbRelaxedOrderingEnabled = 0;

NCCL_PARAM(IbGidIndex, "IB_GID_INDEX", -1);
NCCL_PARAM(IbRoutableFlidIbGidIndex, "IB_ROUTABLE_FLID_GID_INDEX", 1);
NCCL_PARAM(IbRoceVersionNum, "IB_ROCE_VERSION_NUM", 2);
NCCL_PARAM(IbIsGlobal, "IB_IS_GLOBAL", 0);
NCCL_PARAM(IbTimeout, "IB_TIMEOUT", 18);
Expand All @@ -46,6 +47,7 @@ NCCL_PARAM(IbSl, "IB_SL", 0);
NCCL_PARAM(IbTc, "IB_TC", 0);
NCCL_PARAM(IbArThreshold, "IB_AR_THRESHOLD", 8192);
NCCL_PARAM(IbPciRelaxedOrdering, "IB_PCI_RELAXED_ORDERING", 2);
NCCL_PARAM(IbFifoTc, "IB_FIFO_TC", 0);

static pthread_t ncclIbAsyncThread;

Expand Down Expand Up @@ -249,7 +251,38 @@ static ncclResult_t ncclUpdateGidIndex(struct ibv_context* context, uint8_t port
return ncclSuccess;
}

static ncclResult_t ncclIbGetGidIndex(struct ibv_context *context, uint8_t portNum, int gidTblLen, int *gidIndex) {
// GID Format
// global: | 64b - subnet-prefix | 64b - EUI |
// raw : | 10b fixed | 22b 0 | 16b FLID | 16b subnet-prefix | 64b - EUI |
/*
 * Extract the local subnet prefix from a GID's 64-bit subnet-prefix word.
 *
 * subnet_prefix is converted with be64toh() (i.e. it is treated as a
 * big-endian value) and the least-significant 16 bits of the resulting
 * host-order value are returned.
 */
static uint16_t ncclIbExtractLocalSubnetPrefix(uint64_t subnet_prefix)
{
  const uint64_t hostOrder = be64toh(subnet_prefix);
  return (uint16_t)(hostOrder & 0xffffu);
}

/*
 * Extract the 16-bit FLID field from a GID.
 *
 * The field occupies bytes 4 and 5 of gid->raw, most-significant byte first.
 * The return value is that field in host byte order (range 0..65535).
 *
 * The two bytes are read individually instead of dereferencing a casted
 * uint16_t pointer: this avoids accessing the GID through an incompatible
 * lvalue type and makes the result independent of host endianness without
 * needing a byte-order conversion call.
 */
static int ncclIbExtractFlid (union ibv_gid *gid)
{
  const unsigned int hi = gid->raw[4];
  const unsigned int lo = gid->raw[5];
  return (int)((hi << 8) | lo);
}

static ncclResult_t ncclIbGetGidIndex(struct ibv_context *context, uint8_t portNum, struct ibv_port_attr* portAttr, int *gidIndex) {
int gidTblLen = portAttr->gid_tbl_len;

//for IB, choose GID Index that will have routable FLID if present
if (portAttr->link_layer == IBV_LINK_LAYER_INFINIBAND) {
union ibv_gid gid;
int routableGidIndex = ncclParamIbRoutableFlidIbGidIndex();
if (routableGidIndex < gidTblLen) {
NCCLCHECK(wrap_ibv_query_gid(context, portNum, routableGidIndex, &gid));
if (ncclIbExtractFlid(&gid) != 0) {
*gidIndex = routableGidIndex;
return ncclSuccess;
}
}
*gidIndex = 0;
return ncclSuccess;
}

//for ROCE
*gidIndex = ncclParamIbGidIndex();
if (*gidIndex >= 0) {
return ncclSuccess;
Expand Down Expand Up @@ -342,12 +375,13 @@ typedef struct ncclIbDevInfo {
uint8_t link_layer;
uint8_t is_global;

// For RoCE and IB GRH
uint64_t spn;
uint64_t iid;
// For RoCE and IB GRH & Router
union ibv_gid gid;

// FIFO RDMA info
uint32_t fifoRkey;

//remote dev info
union ibv_gid remoteGid;
} ncclIbDevInfo;

Expand Down Expand Up @@ -579,29 +613,53 @@ ncclResult_t ncclIbCreateQp(uint8_t ib_port, struct ncclIbNetCommDevBase* base,
return ncclSuccess;
}

ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, uint8_t sGidIndex, uint32_t dest_qp_num, struct ncclIbDevInfo* info) {
ncclResult_t ncclIbRtrQp(struct ibv_qp* qp, struct ncclIbGidInfo* sGidInfo, uint32_t dest_qp_num, struct ncclIbDevInfo* info, bool override_tc) {
struct ibv_qp_attr qpAttr;
int same_subnet;
memset(&qpAttr, 0, sizeof(struct ibv_qp_attr));
qpAttr.qp_state = IBV_QPS_RTR;
qpAttr.path_mtu = info->mtu;
qpAttr.dest_qp_num = dest_qp_num;
qpAttr.rq_psn = 0;
qpAttr.max_dest_rd_atomic = 1;
qpAttr.min_rnr_timer = 12;
qpAttr.ah_attr.is_global = 0;
qpAttr.ah_attr.dlid = info->lid;
qpAttr.ah_attr.sl = ncclParamIbSl();
qpAttr.ah_attr.src_path_bits = 0;
qpAttr.ah_attr.port_num = info->ib_port;
if (info->link_layer == IBV_LINK_LAYER_ETHERNET || info->is_global) {
if (info->link_layer == IBV_LINK_LAYER_ETHERNET) {
qpAttr.ah_attr.is_global = 1;
qpAttr.ah_attr.grh.dgid.global.subnet_prefix = info->spn;
qpAttr.ah_attr.grh.dgid.global.interface_id = info->iid;
qpAttr.ah_attr.grh.dgid.global.subnet_prefix = info->gid.global.subnet_prefix;
qpAttr.ah_attr.grh.dgid.global.interface_id = info->gid.global.interface_id;
qpAttr.ah_attr.grh.flow_label = 0;
qpAttr.ah_attr.grh.sgid_index = sGidIndex;
qpAttr.ah_attr.grh.sgid_index = sGidInfo->localGidIndex;
qpAttr.ah_attr.grh.hop_limit = 255;
qpAttr.ah_attr.grh.traffic_class = ncclParamIbTc();
if(ncclParamIbFifoTc() && override_tc) {
qpAttr.ah_attr.grh.traffic_class = ncclParamIbFifoTc();
} else {
qpAttr.ah_attr.grh.traffic_class = ncclParamIbTc();
}
} else {
same_subnet = (ncclIbExtractLocalSubnetPrefix(sGidInfo->localGid.global.subnet_prefix) ==
ncclIbExtractLocalSubnetPrefix(info->gid.global.subnet_prefix));
qpAttr.ah_attr.is_global = 0;
qpAttr.ah_attr.dlid = info->lid;
if (!same_subnet || info->is_global) {
if (!same_subnet) {
uint16_t flid = ncclIbExtractFlid(&info->gid);
if (flid == 0) {
WARN("Warning: remote FLID configured as zero even when endpoints are on different subnets, using dlid as fallback");
qpAttr.ah_attr.dlid = info->lid;
} else {
qpAttr.ah_attr.dlid = ncclIbExtractFlid(&info->gid);
}
}
qpAttr.ah_attr.is_global = 1;
qpAttr.ah_attr.grh.dgid.global.subnet_prefix = info->gid.global.subnet_prefix;
qpAttr.ah_attr.grh.dgid.global.interface_id = info->gid.global.interface_id;
qpAttr.ah_attr.grh.sgid_index = sGidInfo->localGidIndex;
qpAttr.ah_attr.grh.hop_limit = 255;
}
}
qpAttr.ah_attr.sl = ncclParamIbSl();
qpAttr.ah_attr.src_path_bits = 0;
qpAttr.ah_attr.port_num = info->ib_port;
NCCLCHECK(wrap_ibv_modify_qp(qp, &qpAttr, IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER));
return ncclSuccess;
}
Expand Down Expand Up @@ -711,29 +769,28 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet
NCCLCHECK(wrap_ibv_reg_mr(&commDev->fifoMr, commDev->base.pd, comm->fifo, sizeof(struct ncclIbSendFifo)*MAX_REQUESTS*NCCL_NET_IB_MAX_RECVS, IBV_ACCESS_LOCAL_WRITE|IBV_ACCESS_REMOTE_WRITE|IBV_ACCESS_REMOTE_READ));
devInfo->fifoRkey = commDev->fifoMr->rkey;

// RoCE support
devInfo->link_layer = commDev->base.gidInfo.link_layer = ibDev->portAttr.link_layer;
devInfo->is_global = (ncclParamIbIsGlobal()
#if HAVE_DECL_IBV_QPF_GRH_REQUIRED
|| (ibDev->portAttr.flags & IBV_QPF_GRH_REQUIRED)
|| (ibDev->portAttr.flags & IBV_QPF_GRH_REQUIRED)
#endif
);

// Pack local GID info
devInfo->link_layer = commDev->base.gidInfo.link_layer = ibDev->portAttr.link_layer;
NCCLCHECK(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, &ibDev->portAttr, &commDev->base.gidInfo.localGidIndex));
NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, commDev->base.gidInfo.localGidIndex, &commDev->base.gidInfo.localGid));
devInfo->gid.global.subnet_prefix = commDev->base.gidInfo.localGid.global.subnet_prefix;
devInfo->gid.global.interface_id = commDev->base.gidInfo.localGid.global.interface_id;

if (devInfo->link_layer == IBV_LINK_LAYER_ETHERNET || devInfo->is_global) {

NCCLCHECK(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, ibDev->portAttr.gid_tbl_len, &commDev->base.gidInfo.localGidIndex));
NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, commDev->base.gidInfo.localGidIndex, &commDev->base.gidInfo.localGid));
devInfo->spn = commDev->base.gidInfo.localGid.global.subnet_prefix;
devInfo->iid = commDev->base.gidInfo.localGid.global.interface_id;
}

// info logging
if (devInfo->link_layer == IBV_LINK_LAYER_INFINIBAND) { // IB
for (int q = 0; q < comm->base.nqps; q++) {
// Print just the QPs for this dev
if (comm->base.qps[q].devIndex == i)
INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d LID %d fifoRkey=0x%x fifoLkey=0x%x",
INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d LID %d subnet-prefix %lu FLID %d fifoRkey=0x%x fifoLkey=0x%x",
comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev",
dev, commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, devInfo->lid, devInfo->fifoRkey, commDev->fifoMr->lkey);
dev, commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, devInfo->lid,
devInfo->gid.global.subnet_prefix, ncclIbExtractFlid(&devInfo->gid), devInfo->fifoRkey, commDev->fifoMr->lkey);
}
} else { // RoCE
for (int q = 0; q < comm->base.nqps; q++) {
Expand All @@ -742,7 +799,7 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet
INFO(NCCL_NET,"NET/IB: %s %d IbDev %d Port %d qpn %d mtu %d query_ece={supported=%d, vendor_id=0x%x, options=0x%x, comp_mask=0x%x} GID %ld (%lX/%lX) fifoRkey=0x%x fifoLkey=0x%x",
comm->base.ndevs > 2 ? "NCCL MergedDev" : "NCCL Dev", dev,
commDev->base.ibDevN, ibDev->portNum, meta.qpInfo[q].qpn, devInfo->mtu, meta.qpInfo[q].ece_supported, meta.qpInfo[q].ece.vendor_id, meta.qpInfo[q].ece.options, meta.qpInfo[q].ece.comp_mask, (int64_t)commDev->base.gidInfo.localGidIndex,
devInfo->spn, devInfo->iid, devInfo->fifoRkey, commDev->fifoMr->lkey);
devInfo->gid.global.subnet_prefix, devInfo->gid.global.interface_id, devInfo->fifoRkey, commDev->fifoMr->lkey);
}
}
}
Expand Down Expand Up @@ -792,8 +849,8 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet
// Copy remDevInfo for things like remGidInfo, remFifoAddr, etc.
for (int i = 0; i < remMeta.ndevs; i++) {
comm->base.remDevs[i] = remMeta.devs[i];
comm->base.remDevs[i].remoteGid.global.interface_id = comm->base.remDevs[i].iid;
comm->base.remDevs[i].remoteGid.global.subnet_prefix = comm->base.remDevs[i].spn;
comm->base.remDevs[i].remoteGid.global.interface_id = comm->base.remDevs[i].gid.global.interface_id;
comm->base.remDevs[i].remoteGid.global.subnet_prefix = comm->base.remDevs[i].gid.global.subnet_prefix;

// Retain remote sizes fifo info and prepare RDMA ops
comm->remSizesFifo.rkeys[i] = remMeta.devs[i].fifoRkey;
Expand All @@ -812,13 +869,12 @@ ncclResult_t ncclIbConnect(int dev, void* opaqueHandle, void** sendComm, ncclNet
comm->base.qps[q].remDevIdx = remQpInfo->devIndex;
int devIndex = comm->base.qps[q].devIndex;
ncclIbSendCommDev* commDev = comm->devs + devIndex;
uint8_t gidIndex = commDev->base.gidInfo.localGidIndex;

struct ibv_qp* qp = comm->base.qps[q].qp;
if (remQpInfo->ece_supported && remQpInfo->ece_supported)
NCCLCHECK(wrap_ibv_set_ece(qp, &remQpInfo->ece, &remQpInfo->ece_supported));

NCCLCHECK(ncclIbRtrQp(qp, gidIndex, remQpInfo->qpn, remDevInfo));
NCCLCHECK(ncclIbRtrQp(qp, &commDev->base.gidInfo, remQpInfo->qpn, remDevInfo, false));
NCCLCHECK(ncclIbRtsQp(qp));
}

Expand Down Expand Up @@ -918,15 +974,15 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl
ibDevN = mergedDev->devs[i];
NCCLCHECK(ncclIbInitCommDevBase(ibDevN, &rCommDev->base));
ibDev = ncclIbDevs + ibDevN;
NCCLCHECK(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, ibDev->portAttr.gid_tbl_len, &rCommDev->base.gidInfo.localGidIndex));
NCCLCHECK(ncclIbGetGidIndex(ibDev->context, ibDev->portNum, &ibDev->portAttr, &rCommDev->base.gidInfo.localGidIndex));
NCCLCHECK(wrap_ibv_query_gid(ibDev->context, ibDev->portNum, rCommDev->base.gidInfo.localGidIndex, &rCommDev->base.gidInfo.localGid));
}

// Copy remDevInfo for things like remGidInfo, remFifoAddr, etc.
for (int i = 0; i < remMeta.ndevs; i++) {
rComm->base.remDevs[i] = remMeta.devs[i];
rComm->base.remDevs[i].remoteGid.global.interface_id = rComm->base.remDevs[i].iid;
rComm->base.remDevs[i].remoteGid.global.subnet_prefix = rComm->base.remDevs[i].spn;
rComm->base.remDevs[i].remoteGid.global.interface_id = rComm->base.remDevs[i].gid.global.interface_id;
rComm->base.remDevs[i].remoteGid.global.subnet_prefix = rComm->base.remDevs[i].gid.global.subnet_prefix;
}

// Stripe QP creation across merged devs
Expand Down Expand Up @@ -957,7 +1013,8 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl
if (meta.qpInfo[q].ece_supported)
NCCLCHECK(wrap_ibv_query_ece(qp->qp, &meta.qpInfo[q].ece, &meta.qpInfo[q].ece_supported));
}
NCCLCHECK(ncclIbRtrQp(qp->qp, rCommDev->base.gidInfo.localGidIndex, remMeta.qpInfo[q].qpn, remDevInfo));
bool override_tc = (q == 0) ? true : false;
NCCLCHECK(ncclIbRtrQp(qp->qp, &rCommDev->base.gidInfo, remMeta.qpInfo[q].qpn, remDevInfo, override_tc));
NCCLCHECK(ncclIbRtsQp(qp->qp));
}

Expand Down Expand Up @@ -987,24 +1044,24 @@ ncclResult_t ncclIbAccept(void* listenComm, void** recvComm, ncclNetDeviceHandl
devInfo.lid = ibDev->portAttr.lid;
devInfo.link_layer = ibDev->portAttr.link_layer;
devInfo.ib_port = ibDev->portNum;
devInfo.spn = rCommDev->base.gidInfo.localGid.global.subnet_prefix;
devInfo.iid = rCommDev->base.gidInfo.localGid.global.interface_id;
devInfo.gid.global.subnet_prefix = rCommDev->base.gidInfo.localGid.global.subnet_prefix;
devInfo.gid.global.interface_id = rCommDev->base.gidInfo.localGid.global.interface_id;
devInfo.is_global = (ncclParamIbIsGlobal()
#if HAVE_DECL_IBV_QPF_GRH_REQUIRED
|| (ibDev->portAttr.flags & IBV_QPF_GRH_REQUIRED)
#endif
);
devInfo.mtu = ibDev->portAttr.active_mtu;
NCCLCHECK(ncclIbRtrQp(rCommDev->gpuFlush.qp.qp, rCommDev->base.gidInfo.localGidIndex, rCommDev->gpuFlush.qp.qp->qp_num, &devInfo));
NCCLCHECK(ncclIbRtsQp(rCommDev->gpuFlush.qp.qp));
NCCLCHECK(ncclIbRtrQp(rCommDev->gpuFlush.qp.qp, &rCommDev->base.gidInfo, rCommDev->gpuFlush.qp.qp->qp_num, &devInfo, false));
NCCLCHECK(ncclIbRtsQp(rCommDev->gpuFlush.qp.qp));
}

// Fill Handle
meta.devs[i].lid = ibDev->portAttr.lid;
meta.devs[i].link_layer = rCommDev->base.gidInfo.link_layer = ibDev->portAttr.link_layer;
meta.devs[i].ib_port = ibDev->portNum;
meta.devs[i].spn = rCommDev->base.gidInfo.localGid.global.subnet_prefix;
meta.devs[i].iid = rCommDev->base.gidInfo.localGid.global.interface_id;
meta.devs[i].gid.global.subnet_prefix = rCommDev->base.gidInfo.localGid.global.subnet_prefix;
meta.devs[i].gid.global.interface_id = rCommDev->base.gidInfo.localGid.global.interface_id;
meta.devs[i].is_global = (ncclParamIbIsGlobal()
#if HAVE_DECL_IBV_QPF_GRH_REQUIRED
|| (ibDev->portAttr.flags & IBV_QPF_GRH_REQUIRED)
Expand Down Expand Up @@ -1612,9 +1669,10 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) {
}

char line[SOCKET_NAME_MAXLEN+1];
WARN("NET/IB : Got completion from peer %s with status=%d opcode=%d len=%d vendor err %d (%s)%s%s%s%s",
char *hcaName = r->devBases[i]->pd->context->device->name;
WARN("NET/IB: Got completion from peer %s with status=%d opcode=%d len=%d vendor err %d (%s)%s%s%s%s hca %s",
ncclSocketToString(&addr, line, 1), wc->status, wc->opcode, wc->byte_len, wc->vendor_err, reqTypeStr[r->type],
localGidStr ? " localGid ":"", localGidString, remoteGidStr ? " remoteGids":"", remoteGidString);
localGidStr ? " localGid ":"", localGidString, remoteGidStr ? " remoteGids":"", remoteGidString, hcaName);
return ncclRemoteError;
}

Expand All @@ -1624,7 +1682,7 @@ ncclResult_t ncclIbTest(void* request, int* done, int* sizes) {

#ifdef ENABLE_TRACE
char line[SOCKET_NAME_MAXLEN+1];
TRACE(NCCL_NET, "Got completion from peer %s with status=%d opcode=%d len=%d wr_id=%d r=%p type=%d events={%d,%d}, i=%d",
TRACE(NCCL_NET, "Got completion from peer %s with status=%d opcode=%d len=%d wr_id=%ld r=%p type=%d events={%d,%d}, i=%d",
ncclSocketToString(&addr, line, 1), wc->status, wc->opcode,wc->byte_len, wc->wr_id, req, req->type, req->events[0], req->events[1], i);
#endif
if (req->type == NCCL_NET_IB_REQ_SEND) {
Expand Down
22 changes: 20 additions & 2 deletions src/p2p_plugin.c
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,10 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb
int nUserIfs = parseStringList(userIbEnv, userIfs, MAX_IB_DEVS);

if (ncclSuccess != wrap_ibv_get_device_list(&devices, &nIbDevs)) { ret = ncclInternalError; goto fail; }

// Should NCCL merge multi-port devices into one?
int mergeNics;
mergeNics = ncclParamIbMergeNics();
build_ib_list:
for (int d=0; d<nIbDevs; d++) {
struct ibv_context * context;
if (ncclSuccess != wrap_ibv_open_device(&context, devices[d]) || context == NULL) {
Expand Down Expand Up @@ -398,7 +401,7 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb
}

int mergedDev = ncclNMergedIbDevs;
if (ncclParamIbMergeNics()) {
if (mergeNics) {
mergedDev = ncclIbFindMatchingDev(ncclNIbDevs);
}

Expand All @@ -425,6 +428,21 @@ ncclResult_t nccl_p2p_ib_init(int *num_devs, ncclIbDev *ncclIbDevs, char *ncclIb
}
if (nPorts == 0 && ncclSuccess != wrap_ibv_close_device(context)) { ret = ncclInternalError; goto fail; }
}

// Detect if there are both multi-port and single-port NICs in the system. If so, disable port merging and build the list again
if (mergeNics) {
for (int d = 0; d < ncclNMergedIbDevs; d++) {
if (ncclIbMergedDevs[d].ndevs != ncclIbMergedDevs[0].ndevs) {
INFO(NCCL_NET, "Detected a mix of single and multiple-port NICs. Force-disabling NCCL_IB_MERGE_NICS");
mergeNics = 0;
ncclNIbDevs = 0;
ncclNMergedIbDevs = 0;
memset(ncclIbMergedDevs, 0, sizeof(ncclIbMergedDevs));
goto build_ib_list;
}
}
}

if (nIbDevs && (ncclSuccess != wrap_ibv_free_device_list(devices))) { ret = ncclInternalError; goto fail; };
}
if (ncclNIbDevs == 0) {
Expand Down
10 changes: 5 additions & 5 deletions src/socket.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ static ncclResult_t socketProgressOpt(int op, struct ncclSocket* sock, void* ptr
}
}
(*offset) += bytes;
if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED)) {
if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_ACQUIRE)) {
INFO(NCCL_NET, "socketProgressOpt: abort called");
return ncclInternalError;
}
Expand Down Expand Up @@ -624,12 +624,12 @@ ncclResult_t ncclSocketConnect(struct ncclSocket* sock) {
do {
NCCLCHECK(socketProgressState(sock));
} while (sock->asyncFlag == 0 &&
(sock->abortFlag == NULL || __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED) == 0) &&
(sock->abortFlag == NULL || __atomic_load_n(sock->abortFlag, __ATOMIC_ACQUIRE) == 0) &&
(sock->state == ncclSocketStateConnecting ||
sock->state == ncclSocketStateConnectPolling ||
sock->state == ncclSocketStateConnected));

if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED)) return ncclInternalError;
if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_ACQUIRE)) return ncclInternalError;

switch (sock->state) {
case ncclSocketStateConnecting:
Expand Down Expand Up @@ -671,11 +671,11 @@ ncclResult_t ncclSocketAccept(struct ncclSocket* sock, struct ncclSocket* listen
do {
NCCLCHECKGOTO(socketProgressState(sock), ret, exit);
} while (sock->asyncFlag == 0 &&
(sock->abortFlag == NULL || __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED) == 0) &&
(sock->abortFlag == NULL || __atomic_load_n(sock->abortFlag, __ATOMIC_ACQUIRE) == 0) &&
(sock->state == ncclSocketStateAccepting ||
sock->state == ncclSocketStateAccepted));

if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_RELAXED)) return ncclInternalError;
if (sock->abortFlag && __atomic_load_n(sock->abortFlag, __ATOMIC_ACQUIRE)) return ncclInternalError;

switch (sock->state) {
case ncclSocketStateAccepting:
Expand Down
Loading