run clang-format on C++ and C-header files.
jatkinson1000 committed Nov 8, 2024
1 parent 8e254d3 commit a11d643
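
A pass like this is normally produced by running clang-format in place over the affected sources. The exact invocation and style configuration are not recorded in the commit, so the following is an assumed example (the header path is a guess based on the include in src/ctorch.cpp):

    clang-format -i src/ctorch.cpp src/ctorch.h

The -i flag rewrites the files in place; clang-format picks up a .clang-format file from the repository if one is present, with its built-in LLVM style as the usual fallback.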
Showing 2 changed files with 139 additions and 154 deletions.
src/ctorch.cpp: 235 changes (110 additions & 125 deletions)
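
The reformatting below is consistent with clang-format's LLVM-based defaults: opening braces attached to function headers, `*` bound to the variable name rather than the type, spaces around `=` in default arguments and around comparison operators, and long expressions wrapped at the member-call chain. As a minimal before/after illustration on neutral code (not taken from the repository):

    // Before formatting: detached brace, left-bound '*', unspaced operators.
    int sum(const int* values, int n, int start=0)
    {
      int total = start;
      for (int i=0; i<n; ++i) {
        total += values[i];
      }
      return total;
    }

    // After clang-format with LLVM-style settings.
    int sum(const int *values, int n, int start = 0) {
      int total = start;
      for (int i = 0; i < n; ++i) {
        total += values[i];
      }
      return total;
    }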
@@ -3,8 +3,7 @@
 
 #include "ctorch.h"
 
-constexpr auto get_dtype(torch_data_t dtype)
-{
+constexpr auto get_dtype(torch_data_t dtype) {
   switch (dtype) {
   case torch_kUInt8:
     std::cerr << "[WARNING]: uint8 not supported in Fortran" << std::endl;
@@ -33,8 +32,7 @@ constexpr auto get_dtype(torch_data_t dtype)
   }
 }
 
-const auto get_device(torch_device_t device_type, int device_index)
-{
+const auto get_device(torch_device_t device_type, int device_index) {
   switch (device_type) {
   case torch_kCPU:
     if (device_index != -1) {
@@ -64,81 +62,77 @@ const auto get_device(torch_device_t device_type, int device_index)
 }
 
 void set_is_training(torch_jit_script_module_t module,
-                     const bool is_training=false)
-{
-  auto model = static_cast<torch::jit::script::Module*>(module);
+                     const bool is_training = false) {
+  auto model = static_cast<torch::jit::script::Module *>(module);
   if (is_training) {
     model->train();
   } else {
     model->eval();
   }
 }
 
-torch_tensor_t torch_zeros(int ndim, const int64_t* shape, torch_data_t dtype,
+torch_tensor_t torch_zeros(int ndim, const int64_t *shape, torch_data_t dtype,
                            torch_device_t device_type, int device_index = -1,
-                           const bool requires_grad=false)
-{
+                           const bool requires_grad = false) {
   torch::AutoGradMode enable_grad(requires_grad);
-  torch::Tensor* tensor = nullptr;
+  torch::Tensor *tensor = nullptr;
   try {
     // This doesn't throw if shape and dimensions are incompatible
     c10::IntArrayRef vshape(shape, ndim);
     tensor = new torch::Tensor;
-    *tensor = torch::zeros(
-        vshape, torch::dtype(get_dtype(dtype))).to(get_device(device_type, device_index));
-  } catch (const torch::Error& e) {
+    *tensor = torch::zeros(vshape, torch::dtype(get_dtype(dtype)))
+                  .to(get_device(device_type, device_index));
+  } catch (const torch::Error &e) {
     std::cerr << "[ERROR]: " << e.msg() << std::endl;
     delete tensor;
     exit(EXIT_FAILURE);
-  } catch (const std::exception& e) {
+  } catch (const std::exception &e) {
     std::cerr << "[ERROR]: " << e.what() << std::endl;
     delete tensor;
     exit(EXIT_FAILURE);
   }
   return tensor;
 }
 
-torch_tensor_t torch_ones(int ndim, const int64_t* shape, torch_data_t dtype,
+torch_tensor_t torch_ones(int ndim, const int64_t *shape, torch_data_t dtype,
                           torch_device_t device_type, int device_index = -1,
-                          const bool requires_grad=false)
-{
+                          const bool requires_grad = false) {
   torch::AutoGradMode enable_grad(requires_grad);
-  torch::Tensor* tensor = nullptr;
+  torch::Tensor *tensor = nullptr;
   try {
     // This doesn't throw if shape and dimensions are incompatible
     c10::IntArrayRef vshape(shape, ndim);
     tensor = new torch::Tensor;
-    *tensor = torch::ones(
-        vshape, torch::dtype(get_dtype(dtype))).to(get_device(device_type, device_index));
-  } catch (const torch::Error& e) {
+    *tensor = torch::ones(vshape, torch::dtype(get_dtype(dtype)))
+                  .to(get_device(device_type, device_index));
+  } catch (const torch::Error &e) {
     std::cerr << "[ERROR]: " << e.msg() << std::endl;
     delete tensor;
     exit(EXIT_FAILURE);
-  } catch (const std::exception& e) {
+  } catch (const std::exception &e) {
     std::cerr << "[ERROR]: " << e.what() << std::endl;
     delete tensor;
     exit(EXIT_FAILURE);
   }
   return tensor;
 }
 
-torch_tensor_t torch_empty(int ndim, const int64_t* shape, torch_data_t dtype,
+torch_tensor_t torch_empty(int ndim, const int64_t *shape, torch_data_t dtype,
                            torch_device_t device_type, int device_index = -1,
-                           const bool requires_grad=false)
-{
+                           const bool requires_grad = false) {
   torch::AutoGradMode enable_grad(requires_grad);
-  torch::Tensor* tensor = nullptr;
+  torch::Tensor *tensor = nullptr;
   try {
     // This doesn't throw if shape and dimensions are incompatible
     c10::IntArrayRef vshape(shape, ndim);
     tensor = new torch::Tensor;
-    *tensor = torch::empty(
-        vshape, torch::dtype(get_dtype(dtype))).to(get_device(device_type, device_index));
-  } catch (const torch::Error& e) {
+    *tensor = torch::empty(vshape, torch::dtype(get_dtype(dtype)))
+                  .to(get_device(device_type, device_index));
+  } catch (const torch::Error &e) {
     std::cerr << "[ERROR]: " << e.msg() << std::endl;
     delete tensor;
     exit(EXIT_FAILURE);
-  } catch (const std::exception& e) {
+  } catch (const std::exception &e) {
     std::cerr << "[ERROR]: " << e.what() << std::endl;
     delete tensor;
     exit(EXIT_FAILURE);
@@ -148,120 +142,113 @@ torch_tensor_t torch_empty(int ndim, const int64_t* shape, torch_data_t dtype,
 
 // Exposes the given data as a Tensor without taking ownership of the original
 // data
-torch_tensor_t torch_from_blob(void* data, int ndim, const int64_t* shape,
-                               const int64_t* strides, torch_data_t dtype,
-                               torch_device_t device_type, int device_index = -1,
-                               const bool requires_grad=false)
-{
+torch_tensor_t torch_from_blob(void *data, int ndim, const int64_t *shape,
+                               const int64_t *strides, torch_data_t dtype,
+                               torch_device_t device_type,
+                               int device_index = -1,
+                               const bool requires_grad = false) {
   torch::AutoGradMode enable_grad(requires_grad);
-  torch::Tensor* tensor = nullptr;
+  torch::Tensor *tensor = nullptr;
 
   try {
     // This doesn't throw if shape and dimensions are incompatible
     c10::IntArrayRef vshape(shape, ndim);
     c10::IntArrayRef vstrides(strides, ndim);
     tensor = new torch::Tensor;
-    *tensor = torch::from_blob(
-        data, vshape, vstrides,
-        torch::dtype(get_dtype(dtype))).to(get_device(device_type, device_index));
+    *tensor =
+        torch::from_blob(data, vshape, vstrides, torch::dtype(get_dtype(dtype)))
+            .to(get_device(device_type, device_index));
 
-  } catch (const torch::Error& e) {
+  } catch (const torch::Error &e) {
     std::cerr << "[ERROR]: " << e.msg() << std::endl;
     delete tensor;
     exit(EXIT_FAILURE);
-  } catch (const std::exception& e) {
+  } catch (const std::exception &e) {
     std::cerr << "[ERROR]: " << e.what() << std::endl;
     delete tensor;
     exit(EXIT_FAILURE);
   }
   return tensor;
 }
 
-void* torch_to_blob(const torch_tensor_t tensor, const torch_data_t dtype)
-{
-  auto t = reinterpret_cast<torch::Tensor* const>(tensor);
-  void* raw_ptr;
-  switch (dtype) {
-    case torch_kUInt8:
-      std::cerr << "[WARNING]: uint8 not supported" << std::endl;
-      exit(EXIT_FAILURE);
-    case torch_kInt8:
-      raw_ptr = (void*) t->data_ptr<int8_t>();
-      break;
-    case torch_kInt16:
-      raw_ptr = (void*) t->data_ptr<int16_t>();
-      break;
-    case torch_kInt32:
-      raw_ptr = (void*) t->data_ptr<int32_t>();
-      break;
-    case torch_kInt64:
-      raw_ptr = (void*) t->data_ptr<int64_t>();
-      break;
-    case torch_kFloat16:
-      std::cerr << "[WARNING]: float16 not supported" << std::endl;
-      // NOTE: std::float16_t is available but only with C++23
-      exit(EXIT_FAILURE);
-    case torch_kFloat32:
-      raw_ptr = (void*) t->data_ptr<float>();
-      // NOTE: std::float32_t is available but only with C++23
-      break;
-    case torch_kFloat64:
-      raw_ptr = (void*) t->data_ptr<double>();
-      // NOTE: std::float64_t is available but only with C++23
-      break;
-    default:
-      std::cerr << "[WARNING]: unknown data type" << std::endl;
-      exit(EXIT_FAILURE);
-  }
-  return raw_ptr;
+void *torch_to_blob(const torch_tensor_t tensor, const torch_data_t dtype) {
+  auto t = reinterpret_cast<torch::Tensor *const>(tensor);
+  void *raw_ptr;
+  switch (dtype) {
+  case torch_kUInt8:
+    std::cerr << "[WARNING]: uint8 not supported" << std::endl;
+    exit(EXIT_FAILURE);
+  case torch_kInt8:
+    raw_ptr = (void *)t->data_ptr<int8_t>();
+    break;
+  case torch_kInt16:
+    raw_ptr = (void *)t->data_ptr<int16_t>();
+    break;
+  case torch_kInt32:
+    raw_ptr = (void *)t->data_ptr<int32_t>();
+    break;
+  case torch_kInt64:
+    raw_ptr = (void *)t->data_ptr<int64_t>();
+    break;
+  case torch_kFloat16:
+    std::cerr << "[WARNING]: float16 not supported" << std::endl;
+    // NOTE: std::float16_t is available but only with C++23
+    exit(EXIT_FAILURE);
+  case torch_kFloat32:
+    raw_ptr = (void *)t->data_ptr<float>();
+    // NOTE: std::float32_t is available but only with C++23
+    break;
+  case torch_kFloat64:
+    raw_ptr = (void *)t->data_ptr<double>();
+    // NOTE: std::float64_t is available but only with C++23
+    break;
+  default:
+    std::cerr << "[WARNING]: unknown data type" << std::endl;
+    exit(EXIT_FAILURE);
+  }
+  return raw_ptr;
 }
 
-void torch_tensor_print(const torch_tensor_t tensor)
-{
-  auto t = reinterpret_cast<torch::Tensor*>(tensor);
+void torch_tensor_print(const torch_tensor_t tensor) {
+  auto t = reinterpret_cast<torch::Tensor *>(tensor);
   std::cout << *t << std::endl;
 }
 
-int torch_tensor_get_device_index(const torch_tensor_t tensor)
-{
-  auto t = reinterpret_cast<torch::Tensor*>(tensor);
+int torch_tensor_get_device_index(const torch_tensor_t tensor) {
+  auto t = reinterpret_cast<torch::Tensor *>(tensor);
   return t->device().index();
 }
 
-int torch_tensor_get_rank(const torch_tensor_t tensor)
-{
-  auto t = reinterpret_cast<torch::Tensor*>(tensor);
+int torch_tensor_get_rank(const torch_tensor_t tensor) {
+  auto t = reinterpret_cast<torch::Tensor *>(tensor);
  return t->sizes().size();
 }
 
-const long int* torch_tensor_get_sizes(const torch_tensor_t tensor)
-{
-  auto t = reinterpret_cast<torch::Tensor*>(tensor);
+const long int *torch_tensor_get_sizes(const torch_tensor_t tensor) {
+  auto t = reinterpret_cast<torch::Tensor *>(tensor);
   return t->sizes().data();
 }
 
-void torch_tensor_delete(torch_tensor_t tensor)
-{
-  auto t = reinterpret_cast<torch::Tensor*>(tensor);
+void torch_tensor_delete(torch_tensor_t tensor) {
+  auto t = reinterpret_cast<torch::Tensor *>(tensor);
   delete t;
 }
 
-torch_jit_script_module_t torch_jit_load(const char* filename,
-                                         const torch_device_t device_type = torch_kCPU,
-                                         const int device_index = -1,
-                                         const bool requires_grad=false,
-                                         const bool is_training=false)
-{
+torch_jit_script_module_t
+torch_jit_load(const char *filename,
+               const torch_device_t device_type = torch_kCPU,
+               const int device_index = -1, const bool requires_grad = false,
+               const bool is_training = false) {
   torch::AutoGradMode enable_grad(requires_grad);
-  torch::jit::script::Module* module = nullptr;
+  torch::jit::script::Module *module = nullptr;
   try {
     module = new torch::jit::script::Module;
     *module = torch::jit::load(filename, get_device(device_type, device_index));
-  } catch (const torch::Error& e) {
+  } catch (const torch::Error &e) {
     std::cerr << "[ERROR]: " << e.msg() << std::endl;
     delete module;
     exit(EXIT_FAILURE);
-  } catch (const std::exception& e) {
+  } catch (const std::exception &e) {
     std::cerr << "[ERROR]: " << e.what() << std::endl;
     delete module;
     exit(EXIT_FAILURE);
@@ -274,26 +261,26 @@ torch_jit_script_module_t torch_jit_load(const char* filename,
 void torch_jit_module_forward(const torch_jit_script_module_t module,
                               const torch_tensor_t *inputs, const int nin,
                               torch_tensor_t *outputs, const int nout,
-                              const bool requires_grad=false)
-{
+                              const bool requires_grad = false) {
   torch::AutoGradMode enable_grad(requires_grad);
   // Here we cast the pointers we recieved in to Tensor objects
-  auto model = static_cast<torch::jit::script::Module*>(module);
-  auto in = reinterpret_cast<torch::Tensor* const*>(inputs);
-  auto out = reinterpret_cast<torch::Tensor**>(outputs);
+  auto model = static_cast<torch::jit::script::Module *>(module);
+  auto in = reinterpret_cast<torch::Tensor *const *>(inputs);
+  auto out = reinterpret_cast<torch::Tensor **>(outputs);
   // Local IValue for checking we are passed types
   torch::jit::IValue LocalTensor;
   // Generate a vector of IValues (placeholders for various Torch types)
   std::vector<torch::jit::IValue> inputs_vec;
   // Populate with Tensors pointed at by pointers
   // For each IValue check it is of Tensor type
-  for (int i=0; i<nin; ++i) {
+  for (int i = 0; i < nin; ++i) {
     LocalTensor = *(in[i]);
     if (LocalTensor.isTensor()) {
       inputs_vec.push_back(LocalTensor);
-    }
-    else {
-      std::cerr << "[ERROR]: One of the inputs to torch_jit_module_forward is not a Tensor." << std::endl;
+    } else {
+      std::cerr << "[ERROR]: One of the inputs to torch_jit_module_forward is "
+                   "not a Tensor."
+                << std::endl;
       exit(EXIT_FAILURE);
     }
   }
@@ -302,29 +289,27 @@ void torch_jit_module_forward(const torch_jit_script_module_t module,
     if (model_out.isTensor()) {
       // Single output models will return a tensor directly.
       std::move(*out[0]) = model_out.toTensor();
-    }
-    else if (model_out.isTuple()) {
+    } else if (model_out.isTuple()) {
       // Multiple output models will return a tuple => cast to tensors.
-      for (int i=0; i<nout; ++i) {
+      for (int i = 0; i < nout; ++i) {
         std::move(*out[i]) = model_out.toTuple()->elements()[i].toTensor();
       }
-    }
-    else {
-      // If for some reason the forward method does not return a Tensor it should
-      // raise an error when trying to cast to a Tensor type
-      std::cerr << "[ERROR]: Model Output is neither Tensor nor Tuple." << std::endl;
+    } else {
+      // If for some reason the forward method does not return a Tensor it
+      // should raise an error when trying to cast to a Tensor type
+      std::cerr << "[ERROR]: Model Output is neither Tensor nor Tuple."
+                << std::endl;
     }
-  } catch (const torch::Error& e) {
+  } catch (const torch::Error &e) {
     std::cerr << "[ERROR]: " << e.msg() << std::endl;
     exit(EXIT_FAILURE);
-  } catch (const std::exception& e) {
+  } catch (const std::exception &e) {
     std::cerr << "[ERROR]: " << e.what() << std::endl;
     exit(EXIT_FAILURE);
   }
 }
 
-void torch_jit_module_delete(torch_jit_script_module_t module)
-{
-  auto m = reinterpret_cast<torch::jit::script::Module*>(module);
+void torch_jit_module_delete(torch_jit_script_module_t module) {
+  auto m = reinterpret_cast<torch::jit::script::Module *>(module);
   delete m;
 }
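
For context, the functions reformatted above form a C-style wrapper around LibTorch tensors and TorchScript modules. A minimal usage sketch in C++, assuming the declarations in ctorch.h match the definitions in this file:

    #include <cstdint>

    #include "ctorch.h" // assumed to declare the torch_* functions above

    int main() {
      // Build a 2x3 float32 tensor of zeros on the CPU; device index -1
      // selects the default device, and requires_grad is left disabled.
      const int64_t shape[] = {2, 3};
      torch_tensor_t t =
          torch_zeros(2, shape, torch_kFloat32, torch_kCPU, -1, false);
      torch_tensor_print(t);
      // Tensors are heap-allocated by the wrapper, so free them explicitly.
      torch_tensor_delete(t);
      return 0;
    }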