Skip to content

Commit

Permalink
Make send check non-blocking to allow other workers to be progressed
Browse files Browse the repository at this point in the history
  • Loading branch information
tvegas1 authored and bureddy committed Nov 6, 2023
1 parent 17b32e3 commit 35520a0
Showing 1 changed file with 38 additions and 13 deletions.
51 changes: 38 additions & 13 deletions src/ucx_plugin.c
Original file line number Diff line number Diff line change
Expand Up @@ -615,41 +615,65 @@ static ncclResult_t ucx_send_check(ucx_comm_t *comm) {
ucp_request_param_t params;
ucp_tag_message_h msg_tag;
ucp_tag_recv_info_t info_tag;
connect_msg_t *msg;
ucp_ep_params_t ep_params;
void *ucp_req;
ucs_status_t status;

ucp_worker_progress(comm->worker);

if (comm->connect_req != NULL) {
goto out_check_status;
}

msg_tag = ucp_tag_probe_nb(comm->worker, comm->ctag, tag_mask, 1, &info_tag);
if (msg_tag == NULL) {
return ncclSuccess;
}

msg = malloc(info_tag.length);
comm->msg = malloc(info_tag.length);
if (comm->msg == NULL) {
return ncclSystemError;
}

params.op_attr_mask = 0;
ucp_req = ucp_tag_msg_recv_nbx(comm->worker, msg, info_tag.length,
ucp_req = ucp_tag_msg_recv_nbx(comm->worker, comm->msg, info_tag.length,
msg_tag, &params);
if (UCS_PTR_IS_ERR(ucp_req)) {
WARN("Unable to receive connect msg (%s)",
ucs_status_string(UCS_PTR_STATUS(ucp_req)));
free(msg);
free(comm->msg);
comm->msg = NULL;
return ncclSystemError;
} else if (ucp_req != NULL) {
do {
ucp_worker_progress(comm->worker);
status = ucp_request_check_status(ucp_req);
} while (status == UCS_INPROGRESS);
assert(status == UCS_OK);
ucp_request_free(ucp_req);
} else if (ucp_req == NULL) {
goto out_set_ready;
}

assert(comm->connect_req == NULL);
comm->connect_req = ucp_req;

out_check_status:
status = ucp_request_check_status(comm->connect_req);
if (status == UCS_INPROGRESS) {
return ncclSuccess;
}

if (status != UCS_OK) {
free(comm->msg);
comm->msg = NULL;
WARN("Send check requested returned error (%s)", ucs_status_string(status));
return ncclSystemError;
}

ucp_request_free(comm->connect_req);
comm->connect_req = NULL;

out_set_ready:
ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS;
ep_params.address = (ucp_address_t*)(msg + 1);
ep_params.address = (ucp_address_t*)(comm->msg + 1);
UCXCHECK(ucp_ep_create(comm->worker, &ep_params, &comm->ep));
comm->ready = 1;
free(msg);
free(comm->msg);
comm->msg = NULL;

return ncclSuccess;
}
Expand Down Expand Up @@ -689,6 +713,7 @@ ncclResult_t ucx_recv_check(ucx_comm_t *comm) {
params.cb.send = check_handler;
params.user_data = comm;

assert(comm->connect_req == NULL);
comm->connect_req = ucp_tag_send_nbx(comm->ep, comm->msg, msg_len,
comm->ctag, &params);
if (UCS_PTR_IS_ERR(comm->connect_req)) {
Expand Down

0 comments on commit 35520a0

Please sign in to comment.