diff --git a/examples/2_ResNet18/resnet_infer_fortran.f90 b/examples/2_ResNet18/resnet_infer_fortran.f90 index bc137e42..a23cf065 100644 --- a/examples/2_ResNet18/resnet_infer_fortran.f90 +++ b/examples/2_ResNet18/resnet_infer_fortran.f90 @@ -3,7 +3,7 @@ program inference use, intrinsic :: iso_fortran_env, only : sp => real32 ! Import our library for interfacing with PyTorch - use ftorch, only : torch_model, torch_tensor, torch_kCPU, torch_delete, & + use ftorch, only : torch_model, torch_tensor, torch_kXPU, torch_kCPU, torch_delete, & torch_tensor_from_array, torch_model_load, torch_model_forward ! Import our tools module for testing utils @@ -82,12 +82,12 @@ subroutine main() call load_data(filename, tensor_length, in_data) ! Create input/output tensors from the above arrays - call torch_tensor_from_array(in_tensors(1), in_data, in_layout, torch_kCPU) + call torch_tensor_from_array(in_tensors(1), in_data, in_layout, torch_kXPU, device_index=0) call torch_tensor_from_array(out_tensors(1), out_data, out_layout, torch_kCPU) ! Load ML model (edit this line to use different models) - call torch_model_load(model, args(1)) + call torch_model_load(model, args(1), device_index=0) ! Infer call torch_model_forward(model, in_tensors, out_tensors)