Skip to content
This repository has been archived by the owner on Jan 26, 2024. It is now read-only.

Commit

Permalink
SWDEV-372153 - Add hipStreamGetDevice Implementation
Browse files Browse the repository at this point in the history
Change-Id: Ifd1f13e311e8221ca6d94cf27f9131eb97678067
  • Loading branch information
cjatin committed Mar 1, 2023
1 parent 9b42cc5 commit 7f83be5
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 1 deletion.
26 changes: 25 additions & 1 deletion include/hip/amd_detail/hip_prof_str.h
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,8 @@ enum hip_api_id_t {
HIP_API_ID_hipArray3DGetDescriptor = 360,
HIP_API_ID_hipArrayGetDescriptor = 361,
HIP_API_ID_hipArrayGetInfo = 362,
HIP_API_ID_LAST = 362,
HIP_API_ID_hipStreamGetDevice = 363,
HIP_API_ID_LAST = 363,

HIP_API_ID_hipBindTexture = HIP_API_ID_NONE,
HIP_API_ID_hipBindTexture2D = HIP_API_ID_NONE,
Expand Down Expand Up @@ -743,6 +744,7 @@ static inline const char* hip_api_name(const uint32_t id) {
case HIP_API_ID_hipStreamEndCapture: return "hipStreamEndCapture";
case HIP_API_ID_hipStreamGetCaptureInfo: return "hipStreamGetCaptureInfo";
case HIP_API_ID_hipStreamGetCaptureInfo_v2: return "hipStreamGetCaptureInfo_v2";
case HIP_API_ID_hipStreamGetDevice: return "hipStreamGetDevice";
case HIP_API_ID_hipStreamGetFlags: return "hipStreamGetFlags";
case HIP_API_ID_hipStreamGetPriority: return "hipStreamGetPriority";
case HIP_API_ID_hipStreamIsCapturing: return "hipStreamIsCapturing";
Expand Down Expand Up @@ -1108,6 +1110,7 @@ static inline uint32_t hipApiIdByName(const char* name) {
if (strcmp("hipStreamEndCapture", name) == 0) return HIP_API_ID_hipStreamEndCapture;
if (strcmp("hipStreamGetCaptureInfo", name) == 0) return HIP_API_ID_hipStreamGetCaptureInfo;
if (strcmp("hipStreamGetCaptureInfo_v2", name) == 0) return HIP_API_ID_hipStreamGetCaptureInfo_v2;
if (strcmp("hipStreamGetDevice", name) == 0) return HIP_API_ID_hipStreamGetDevice;
if (strcmp("hipStreamGetFlags", name) == 0) return HIP_API_ID_hipStreamGetFlags;
if (strcmp("hipStreamGetPriority", name) == 0) return HIP_API_ID_hipStreamGetPriority;
if (strcmp("hipStreamIsCapturing", name) == 0) return HIP_API_ID_hipStreamIsCapturing;
Expand Down Expand Up @@ -3062,6 +3065,11 @@ typedef struct hip_api_data_s {
size_t* numDependencies_out;
size_t numDependencies_out__val;
} hipStreamGetCaptureInfo_v2;
struct {
hipStream_t stream;
hipDevice_t* device;
hipDevice_t device__val;
} hipStreamGetDevice;
struct {
hipStream_t stream;
unsigned int* flags;
Expand Down Expand Up @@ -5231,6 +5239,11 @@ typedef struct hip_api_data_s {
cb_data.args.hipStreamGetCaptureInfo_v2.dependencies_out = (const hipGraphNode_t**)dependencies_out; \
cb_data.args.hipStreamGetCaptureInfo_v2.numDependencies_out = (size_t*)numDependencies_out; \
};
// hipStreamGetDevice[('hipStream_t', 'stream'), ('hipDevice_t*', 'device')]
#define INIT_hipStreamGetDevice_CB_ARGS_DATA(cb_data) { \
cb_data.args.hipStreamGetDevice.stream = (hipStream_t)stream; \
cb_data.args.hipStreamGetDevice.device = (hipDevice_t*)device; \
};
// hipStreamGetFlags[('hipStream_t', 'stream'), ('unsigned int*', 'flags')]
#define INIT_hipStreamGetFlags_CB_ARGS_DATA(cb_data) { \
cb_data.args.hipStreamGetFlags.stream = (hipStream_t)stream; \
Expand Down Expand Up @@ -6765,6 +6778,10 @@ static inline void hipApiArgsInit(hip_api_id_t id, hip_api_data_t* data) {
if (data->args.hipStreamGetCaptureInfo_v2.dependencies_out) data->args.hipStreamGetCaptureInfo_v2.dependencies_out__val = *(data->args.hipStreamGetCaptureInfo_v2.dependencies_out);
if (data->args.hipStreamGetCaptureInfo_v2.numDependencies_out) data->args.hipStreamGetCaptureInfo_v2.numDependencies_out__val = *(data->args.hipStreamGetCaptureInfo_v2.numDependencies_out);
break;
// hipStreamGetDevice[('hipStream_t', 'stream'), ('hipDevice_t*', 'device')]
case HIP_API_ID_hipStreamGetDevice:
if (data->args.hipStreamGetDevice.device) data->args.hipStreamGetDevice.device__val = *(data->args.hipStreamGetDevice.device);
break;
// hipStreamGetFlags[('hipStream_t', 'stream'), ('unsigned int*', 'flags')]
case HIP_API_ID_hipStreamGetFlags:
if (data->args.hipStreamGetFlags.flags) data->args.hipStreamGetFlags.flags__val = *(data->args.hipStreamGetFlags.flags);
Expand Down Expand Up @@ -9491,6 +9508,13 @@ static inline const char* hipApiString(hip_api_id_t id, const hip_api_data_t* da
else { oss << ", numDependencies_out="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamGetCaptureInfo_v2.numDependencies_out__val); }
oss << ")";
break;
case HIP_API_ID_hipStreamGetDevice:
oss << "hipStreamGetDevice(";
oss << "stream="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamGetDevice.stream);
if (data->args.hipStreamGetDevice.device == NULL) oss << ", device=NULL";
else { oss << ", device="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamGetDevice.device__val); }
oss << ")";
break;
case HIP_API_ID_hipStreamGetFlags:
oss << "hipStreamGetFlags(";
oss << "stream="; roctracer::hip_support::detail::operator<<(oss, data->args.hipStreamGetFlags.stream);
Expand Down
14 changes: 14 additions & 0 deletions include/hip/nvidia_detail/nvidia_hip_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2507,6 +2507,20 @@ inline static hipError_t hipStreamAddCallback(hipStream_t stream, hipStreamCallb
cudaStreamAddCallback(stream, (cudaStreamCallback_t)callback, userData, flags));
}

inline static hipError_t hipStreamGetDevice(hipStream_t stream, hipDevice_t* device) {
hipCtx_t context;
auto err = hipCUResultTohipError(cuStreamGetCtx(stream, &context));
if (err != hipSuccess) return err;

err = hipCUResultTohipError(cuCtxPushCurrent(context));
if (err != hipSuccess) return err;

err = hipCUResultTohipError(cuCtxGetDevice(device));
if (err != hipSuccess) return err;

return hipCUResultTohipError(cuCtxPopCurrent(&context));
}

inline static hipError_t hipDriverGetVersion(int* driverVersion) {
return hipCUDAErrorTohipError(cudaDriverGetVersion(driverVersion));
}
Expand Down
1 change: 1 addition & 0 deletions src/amdhip.def
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ hipStreamCreate
hipStreamCreateWithFlags
hipStreamCreateWithPriority
hipStreamDestroy
hipStreamGetDevice
hipStreamGetFlags
hipStreamQuery
hipStreamSynchronize
Expand Down
1 change: 1 addition & 0 deletions src/hip_hcc.def.in
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ hipStreamCreate
hipStreamCreateWithFlags
hipStreamCreateWithPriority
hipStreamDestroy
hipStreamGetDevice
hipStreamGetFlags
hipStreamQuery
hipStreamSynchronize
Expand Down
1 change: 1 addition & 0 deletions src/hip_hcc.map.in
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ global:
hipStreamCreateWithFlags;
hipStreamCreateWithPriority;
hipStreamDestroy;
hipStreamGetDevice;
hipStreamGetFlags;
hipStreamQuery;
hipStreamSynchronize;
Expand Down
24 changes: 24 additions & 0 deletions src/hip_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -795,3 +795,27 @@ hipError_t hipExtStreamGetCUMask(hipStream_t stream, uint32_t cuMaskSize, uint32
}
HIP_RETURN(hipSuccess);
}

// ================================================================================================
hipError_t hipStreamGetDevice(hipStream_t stream, hipDevice_t* device) {
HIP_INIT_API(hipStreamGetDevice, stream, device);

if (device == nullptr) {
HIP_RETURN(hipErrorInvalidValue);
}

if (!hip::isValid(stream)) {
return HIP_RETURN(hipErrorContextIsDestroyed);
}

if (stream == nullptr) { // handle null stream
// null stream is associated with current device, return the device id associated with the
// current device
*device = hip::getCurrentDevice()->deviceId();
} else {
getStreamPerThread(stream);
*device = reinterpret_cast<hip::Stream*>(stream)->DeviceId();
}

HIP_RETURN(hipSuccess);
}

0 comments on commit 7f83be5

Please sign in to comment.