Skip to content

Commit

Permalink
extend state class to cudastate class
Browse files Browse the repository at this point in the history
  • Loading branch information
sjsprecious committed Jun 22, 2024
1 parent fa1b8a6 commit cd5735e
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 22 deletions.
39 changes: 39 additions & 0 deletions include/micm/solver/cuda_state.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (C) 2023-2024 National Center for Atmospheric Research
// SPDX-License-Identifier: Apache-2.0
#pragma once

#include <micm/solver/state.hpp>
#include <micm/util/matrix.hpp>
#include <micm/util/sparse_matrix.hpp>
#include <micm/util/cuda_dense_matrix.hpp>

namespace micm
{
/// @brief Construct a state variable for CUDA tests
template<class DenseMatrixPolicy = StandardDenseMatrix, class SparseMatrixPolicy = StandardSparseMatrix>
struct CudaState : public State<DenseMatrixPolicy, SparseMatrixPolicy>
{
public:
CudaState(const CudaState&) = delete;
CudaState& operator=(const CudaState&) = delete;
CudaState(CudaState&&) = default;
CudaState& operator=(CudaState&&) = default;

/// @brief Copy input variables to the device
template<class DenseMatrixPolicy, class SparseMatrixPolicy>
requires(CudaMatrix<DenseMatrixPolicy> && VectorizableDense<DenseMatrixPolicy>)
void SyncInputsToDevice()
{
variables_.CopyToDevice();
rate_constants_.CopyToDevice();
}

/// @brief Copy output variables to the host
template<class DenseMatrixPolicy, class SparseMatrixPolicy>
requires(CudaMatrix<DenseMatrixPolicy> && VectorizableDense<DenseMatrixPolicy>)
void SyncOutputsToHost()
{
variables_.CopyToHost();
}
};
} // namespace micm
7 changes: 0 additions & 7 deletions include/micm/solver/state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
#include <micm/util/jacobian.hpp>
#include <micm/util/matrix.hpp>
#include <micm/util/sparse_matrix.hpp>
#include <micm/util/cuda_dense_matrix.hpp>

#include <algorithm>
#include <cstddef>
Expand Down Expand Up @@ -92,12 +91,6 @@ namespace micm
/// @brief Print state (concentrations) at the given time
/// @param time solving time
void PrintState(double time);

/// @brief Copy input variables to the device
void SyncInputsToDevice();

/// @brief Copy output variables to the host
void SyncOutputsToHost();
};

} // namespace micm
Expand Down
15 changes: 0 additions & 15 deletions include/micm/solver/state.inl
Original file line number Diff line number Diff line change
Expand Up @@ -241,19 +241,4 @@ namespace micm
std::cout.copyfmt(oldState);
}
}

template<class DenseMatrixPolicy, class SparseMatrixPolicy>
requires(CudaMatrix<DenseMatrixPolicy> && VectorizableDense<DenseMatrixPolicy>)
inline void State<DenseMatrixPolicy, SparseMatrixPolicy>::SyncInputsToDevice()
{
variables_.CopyToDevice();
rate_constants_.CopyToDevice();
}

template<class DenseMatrixPolicy, class SparseMatrixPolicy>
requires(CudaMatrix<DenseMatrixPolicy> && VectorizableDense<DenseMatrixPolicy>)
inline void State<DenseMatrixPolicy, SparseMatrixPolicy>::SyncOutputsToHost()
{
variables_.CopyToHost();
}
} // namespace micm
3 changes: 3 additions & 0 deletions test/integration/analytical_cuda_rosenbrock.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
// Copyright (C) 2023-2024 National Center for Atmospheric Research
// SPDX-License-Identifier: Apache-2.0

#include "analytical_policy.hpp"
#include "analytical_surface_rxn_policy.hpp"

Expand Down

0 comments on commit cd5735e

Please sign in to comment.