diff --git a/include/micm/solver/cuda_state.hpp b/include/micm/solver/cuda_state.hpp new file mode 100644 index 000000000..6b1fe8d4b --- /dev/null +++ b/include/micm/solver/cuda_state.hpp @@ -0,0 +1,39 @@ +// Copyright (C) 2023-2024 National Center for Atmospheric Research +// SPDX-License-Identifier: Apache-2.0 +#pragma once + +#include +#include +#include +#include + +namespace micm +{ + /// @brief Construct a state variable for CUDA tests + template + struct CudaState : public State + { + public: + CudaState(const CudaState&) = delete; + CudaState& operator=(const CudaState&) = delete; + CudaState(CudaState&&) = default; + CudaState& operator=(CudaState&&) = default; + + /// @brief Copy input variables to the device + template + requires(CudaMatrix && VectorizableDense) + void SyncInputsToDevice() + { + variables_.CopyToDevice(); + rate_constants_.CopyToDevice(); + } + + /// @brief Copy output variables to the host + template + requires(CudaMatrix && VectorizableDense) + void SyncOutputsToHost() + { + variables_.CopyToHost(); + } + }; +} // namespace micm \ No newline at end of file diff --git a/include/micm/solver/state.hpp b/include/micm/solver/state.hpp index ff0f1ebf1..86e4d053a 100644 --- a/include/micm/solver/state.hpp +++ b/include/micm/solver/state.hpp @@ -8,7 +8,6 @@ #include #include #include -#include #include #include @@ -92,12 +91,6 @@ namespace micm /// @brief Print state (concentrations) at the given time /// @param time solving time void PrintState(double time); - - /// @brief Copy input variables to the device - void SyncInputsToDevice(); - - /// @brief Copy output variables to the host - void SyncOutputsToHost(); }; } // namespace micm diff --git a/include/micm/solver/state.inl b/include/micm/solver/state.inl index d30fa6043..5185bcfdc 100644 --- a/include/micm/solver/state.inl +++ b/include/micm/solver/state.inl @@ -241,19 +241,4 @@ namespace micm std::cout.copyfmt(oldState); } } - - template - requires(CudaMatrix && VectorizableDense) - inline void State::SyncInputsToDevice() - { - variables_.CopyToDevice(); - rate_constants_.CopyToDevice(); - } - - template - requires(CudaMatrix && VectorizableDense) - inline void State::SyncOutputsToHost() - { - variables_.CopyToHost(); - } } // namespace micm diff --git a/test/integration/analytical_cuda_rosenbrock.cpp b/test/integration/analytical_cuda_rosenbrock.cpp index 11338a282..0641ae641 100644 --- a/test/integration/analytical_cuda_rosenbrock.cpp +++ b/test/integration/analytical_cuda_rosenbrock.cpp @@ -1,3 +1,6 @@ +// Copyright (C) 2023-2024 National Center for Atmospheric Research +// SPDX-License-Identifier: Apache-2.0 + #include "analytical_policy.hpp" #include "analytical_surface_rxn_policy.hpp"