Skip to content

Commit

Permalink
Update C++ XPU interface to handle multiple devices indices.
Browse files Browse the repository at this point in the history
  • Loading branch information
jatkinson1000 authored and ma595 committed Dec 16, 2024
1 parent 6a96d49 commit 84a4e5d
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions src/ctorch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,19 @@ const auto get_device(torch_device_t device_type, int device_index) {
}
return torch::Device(torch::kMPS);
case torch_kXPU:
if (device_index != -1) {
std::cerr << "[WARNING]: device index unused for XPU runs"
if (device_index == -1) {
std::cerr << "[WARNING]: device index unset, defaulting to 0"
<< std::endl;
device_index = 0;
}
if (device_index >= 0 && device_index < torch::xpu::device_count()) {
return torch::Device(torch::kXPU, device_index);
} else {
std::cerr << "[ERROR]: invalid device index " << device_index
<< " for XPU device count " << torch::xpu::device_count()
<< std::endl;
exit(EXIT_FAILURE);
}
return torch::Device(torch::kXPU);
default:
std::cerr << "[WARNING]: unknown device type, setting to torch_kCPU" << std::endl;
return torch::Device(torch::kCPU);
Expand Down

0 comments on commit 84a4e5d

Please sign in to comment.