Skip to content

Commit

Permalink
Pass the memory handler and filter config to exclude cuda transport
Browse files Browse the repository at this point in the history
  • Loading branch information
tvegas1 committed Oct 27, 2023
1 parent 7845726 commit 977ec32
Showing 1 changed file with 57 additions and 6 deletions.
63 changes: 57 additions & 6 deletions src/ucx_plugin.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
} while(0)

NCCL_PARAM(UCXDisable, "UCX_DISABLE", 0);
/* Exclude cuda-related UCX transports */
NCCL_PARAM(UCXCudaDisable, "UCX_CUDA_DISABLE", 1);

extern ncclDebugLogger_t pluginLogFunction;
static const ucp_tag_t tag = 0x8a000000;
Expand Down Expand Up @@ -210,15 +212,58 @@ static ncclResult_t GetSocketAddr(union ncclSocketAddress *addr) {
return ncclSuccess;
}

static ncclResult_t ucx_config_no_cuda(ucp_config_t *config) {
char tmp[PATH_MAX];
const char *ucx_tls;
ssize_t n;

ucx_tls = getenv("NCCL_UCX_TLS");
if (ucx_tls == NULL) {
ucx_tls = getenv("UCX_TLS");
}

if (ucx_tls == NULL) {
ucx_tls = "^cuda";
} else if (ucx_tls[0] == '^') {
/* Negative expression, make sure to keep cuda excluded */
n = snprintf(tmp, sizeof(tmp), "^cuda,%s", &ucx_tls[1]);
if (n >= sizeof(tmp)) {
return ncclInternalError;
}

ucx_tls = tmp;
} else {
/* Positive expression cannot allow cuda-like transports */
if ((strstr(ucx_tls, "cuda") != NULL) || (strstr(ucx_tls, "gdr") != NULL)) {
WARN("Cannot use cuda/gdr transports as part of specified UCX_TLS");
return ncclInternalError;
}
}

UCXCHECK(ucp_config_modify(config, "TLS", ucx_tls));
UCXCHECK(ucp_config_modify(config, "RNDV_THRESH", "0"));
UCXCHECK(
ucp_config_modify(config, "MEMTYPE_REG_WHOLE_ALLOC_TYPES", "unknown"));
return ncclSuccess;
}

static ncclResult_t ucx_init_context(ucp_context_h *ctx, int dev) {
ucp_params_t ucp_params;
ucp_config_t *config;
char ucx_dev_name[PATH_MAX];
ncclResult_t result;

snprintf(ucx_dev_name, PATH_MAX, "%s:%d", ncclIbDevs[dev].devName, ncclIbDevs[dev].port);
UCXCHECK(ucp_config_read("NCCL", NULL, &config));
UCXCHECK(ucp_config_modify(config, "NET_DEVICES", ucx_dev_name));

if (ncclParamUCXCudaDisable()) {
result = ucx_config_no_cuda(config);
if (result != ncclSuccess) {
return result;
}
}

memset(&ucp_params, 0, sizeof(ucp_params));
ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES;
ucp_params.features = UCP_FEATURE_TAG | UCP_FEATURE_RMA;
Expand Down Expand Up @@ -275,14 +320,19 @@ static ncclResult_t ucx_get_ctx_and_worker(int dev, ucp_context_h *ctx,
ucp_worker_h *worker,
ucp_tag_t *newtag) {
pthread_mutex_lock(&nccl_ucx_lock);
ncclResult_t result;

#ifdef UCX_SHARED_WORKER
if (ncclNIbDevs < dev) {
WARN("Device index is too large");
return ncclSystemError;
}

if (workers[dev].count == 0) {
ucx_init_context(&workers[dev].ctx, dev);
result = ucx_init_context(&workers[dev].ctx, dev);
if (result != ncclSuccess) {
return result;
}
ucx_init_worker(workers[dev].ctx, &workers[dev].worker);
workers[dev].last_tag = tag;
}
Expand Down Expand Up @@ -723,10 +773,11 @@ static ncclResult_t nccl_ucx_isend(void *send_comm, void *data, int size,
params.cb.send = send_handler_nbx;
params.user_data = &req->pending;
if (mh) {
params.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMORY_TYPE;
params.memory_type = mh->mem_type;
params.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMH;
params.memh = mh->ucp_memh;
}


ucp_req = ucp_tag_send_nbx(comm->ep, data, size,
nccl_ucx_ucp_tag(comm->tag, tag), &params);
if (UCS_PTR_IS_ERR(ucp_req)) {
Expand Down Expand Up @@ -778,10 +829,10 @@ static ncclResult_t nccl_ucx_irecv(void *recv_comm, int n, void **data,
ucx_request_add(req, sizes[i]);

if (mh[i]) {
params.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMORY_TYPE;
params.memory_type = mh[i]->mem_type;
params.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMH;
params.memh = mh[i]->ucp_memh;
} else {
params.op_attr_mask &= ~UCP_OP_ATTR_FIELD_MEMORY_TYPE;
params.op_attr_mask &= ~UCP_OP_ATTR_FIELD_MEMH;
}

ucp_req = ucp_tag_recv_nbx(comm->worker, data[i], sizes[i],
Expand Down

0 comments on commit 977ec32

Please sign in to comment.