From f8496e646da4d8662e337350ede318508b712154 Mon Sep 17 00:00:00 2001 From: Sylwester Arabas Date: Tue, 4 Jun 2024 17:54:54 +0200 Subject: [PATCH 1/5] add consistency checks for allow halving/doubling flags --- src/aero_state.hpp | 12 +++++++++- src/run_part.cpp | 14 +++++++++++ src/run_part_opt.hpp | 3 +++ tests/test_aero_state.py | 33 ++++++++++++++++++++++++++ tests/test_output.py | 4 ++-- tests/test_run_part.py | 50 +++++++++++++++++++++++++++++++++++++++- 6 files changed, 112 insertions(+), 4 deletions(-) diff --git a/src/aero_state.hpp b/src/aero_state.hpp index a3c86399..0a0f37a0 100644 --- a/src/aero_state.hpp +++ b/src/aero_state.hpp @@ -224,6 +224,7 @@ auto pointer_vec_magic(arr_t &data_vec, const arg_t &arg) { struct AeroState { PMCResource ptr; std::shared_ptr aero_data; + int allow_halving = -1, allow_doubling = -1; AeroState( std::shared_ptr aero_data, @@ -572,7 +573,7 @@ struct AeroState { } static int dist_sample( - const AeroState &self, + AeroState &self, const AeroDist &aero_dist, const double &sample_prop, const double &create_time, @@ -581,6 +582,15 @@ struct AeroState { ) { int n_part_add = 0; + if ( + (self.allow_doubling != -1 && self.allow_doubling != allow_doubling) || + (self.allow_halving != -1 && self.allow_halving != allow_halving) + ) + throw std::runtime_error("dist_sample() called with different halving/doubling settings then in last call"); + + self.allow_doubling = allow_doubling; + self.allow_halving = allow_halving; + f_aero_state_add_aero_dist_sample( self.ptr.f_arg(), self.aero_data->ptr.f_arg(), diff --git a/src/run_part.cpp b/src/run_part.cpp index 55a0f6ff..624c7ba6 100644 --- a/src/run_part.cpp +++ b/src/run_part.cpp @@ -7,6 +7,17 @@ #include "run_part.hpp" #include "pybind11/stl.h" +void check_allow_flags( + const AeroState &aero_state, + const RunPartOpt &run_part_opt +) { + if ( + (aero_state.allow_halving != -1 && run_part_opt.allow_halving != aero_state.allow_halving) || + (aero_state.allow_doubling != -1 && run_part_opt.allow_doubling != aero_state.allow_doubling) + ) + throw std::runtime_error("allow halving/doubling flags set differently then while sampling"); +} + void run_part( const Scenario &scenario, EnvState &env_state, @@ -18,6 +29,7 @@ void run_part( const CampCore &camp_core, const Photolysis &photolysis ) { + check_allow_flags(aero_state, run_part_opt); f_run_part( scenario.ptr.f_arg(), env_state.ptr.f_arg_non_const(), @@ -47,6 +59,7 @@ std::tuple run_part_timestep( double &last_progress_time, int &i_output ) { + check_allow_flags(aero_state, run_part_opt); f_run_part_timestep( scenario.ptr.f_arg(), env_state.ptr.f_arg_non_const(), @@ -84,6 +97,7 @@ std::tuple run_part_timeblock( double &last_progress_time, int &i_output ) { + check_allow_flags(aero_state, run_part_opt); f_run_part_timeblock( scenario.ptr.f_arg(), env_state.ptr.f_arg_non_const(), diff --git a/src/run_part_opt.hpp b/src/run_part_opt.hpp index 23e8dce5..f657dbe6 100644 --- a/src/run_part_opt.hpp +++ b/src/run_part_opt.hpp @@ -18,6 +18,7 @@ extern "C" void f_run_part_opt_del_t(const void *ptr, double *del_t) noexcept; struct RunPartOpt { PMCResource ptr; + bool allow_halving, allow_doubling; RunPartOpt(const nlohmann::json &json) : ptr(f_run_part_opt_ctor, f_run_part_opt_dtor) @@ -39,6 +40,8 @@ struct RunPartOpt { })) if (json_copy.find(key) == json_copy.end()) json_copy[key] = true; + allow_halving = json_copy["allow_halving"]; + allow_doubling = json_copy["allow_doubling"]; for (auto key : std::set({ "t_output", "t_progress", "rand_init" diff --git a/tests/test_aero_state.py b/tests/test_aero_state.py index 74813d36..a4ad9b42 100644 --- a/tests/test_aero_state.py +++ b/tests/test_aero_state.py @@ -556,3 +556,36 @@ def test_dist_sample_mono(): # assert assert np.isclose(np.array(sut.diameters()), diam).all() + + @staticmethod + @pytest.mark.parametrize( + "args", + ( + ((True, True), (True, False)), + ((True, True), (False, True)), + ((True, True), (False, False)), + ((False, False), (True, False)), + ((False, False), (False, True)), + ((False, False), (True, True)), + ((True, False), (False, False)), + ((True, False), (False, True)), + ((False, True), (False, False)), + ((False, True), (True, False)), + ), + ) + def test_dist_sample_different_halving(args): + # arrange + aero_data = ppmc.AeroData(AERO_DATA_CTOR_ARG_MINIMAL) + aero_dist = ppmc.AeroDist(aero_data, [AERO_MODE_CTOR_SAMPLED]) + sut = ppmc.AeroState(aero_data, *AERO_STATE_CTOR_ARG_MINIMAL) + + # act + with pytest.raises(RuntimeError) as excinfo: + _ = sut.dist_sample(aero_dist, 1.0, 0.0, *args[0]) + _ = sut.dist_sample(aero_dist, 1.0, 0.0, *args[1]) + + # assert + assert ( + str(excinfo.value) + == f"dist_sample() called with different halving/doubling settings then in last call" + ) diff --git a/tests/test_output.py b/tests/test_output.py index df06854a..77b9c8da 100644 --- a/tests/test_output.py +++ b/tests/test_output.py @@ -68,8 +68,8 @@ def test_input_netcdf(tmp_path): aero_dist, sample_prop=1.0, create_time=0.0, - allow_doubling=True, - allow_halving=True, + allow_doubling=False, + allow_halving=False, ) num_concs = aero_state.num_concs diff --git a/tests/test_run_part.py b/tests/test_run_part.py index 811196a8..beaf6a32 100644 --- a/tests/test_run_part.py +++ b/tests/test_run_part.py @@ -94,8 +94,56 @@ def test_run_part_do_condensation(common_args, tmp_path): "do_condensation": True, } ) - aero_state.dist_sample(aero_dist, 1.0, 0.0, True, True) + aero_state.dist_sample(aero_dist, 1.0, 0.0, False, False) ppmc.condense_equilib_particles(env_state, aero_data, aero_state) ppmc.run_part(*args) assert np.sum(aero_state.masses(include=["H2O"])) > 0.0 + + @staticmethod + @pytest.mark.parametrize( + "flags", + ( + ((True, True), (True, False)), + ((True, True), (False, True)), + ((True, True), (False, False)), + ((False, False), (True, False)), + ((False, False), (False, True)), + ((False, False), (True, True)), + ((True, False), (False, False)), + ((True, False), (False, True)), + ((False, True), (False, False)), + ((False, True), (True, False)), + ), + ) + def test_run_part_allow_flag_mimatch(common_args, tmp_path, flags): + # arrange + filename = tmp_path / "test" + env_state = ppmc.EnvState(ENV_STATE_CTOR_ARG_HIGH_RH) + aero_data = ppmc.AeroData(AERO_DATA_CTOR_ARG_FULL) + aero_dist = ppmc.AeroDist(aero_data, AERO_DIST_CTOR_ARG_FULL) + aero_state = ppmc.AeroState(aero_data, *AERO_STATE_CTOR_ARG_MINIMAL) + args = list(common_args) + args[0].init_env_state(env_state, 0.0) + args[1] = env_state + args[2] = aero_data + args[3] = aero_state + args[6] = ppmc.RunPartOpt( + { + **RUN_PART_OPT_CTOR_ARG_SIMULATION, + "output_prefix": str(filename), + "allow_doubling": flags[0][0], + "allow_halving": flags[0][1], + } + ) + aero_state.dist_sample(aero_dist, 1.0, 0.0, flags[1][0], flags[1][1]) + + # act + with pytest.raises(RuntimeError) as excinfo: + ppmc.run_part(*args) + + # assert + assert ( + str(excinfo.value) + == f"allow halving/doubling flags set differently then while sampling" + ) From 1122ad751fa92902f0eac5c7f9cd9de4e7c7bd8a Mon Sep 17 00:00:00 2001 From: Sylwester Arabas Date: Tue, 4 Jun 2024 18:08:11 +0200 Subject: [PATCH 2/5] address pylint hints --- tests/test_aero_state.py | 2 +- tests/test_run_part.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_aero_state.py b/tests/test_aero_state.py index a4ad9b42..61d90433 100644 --- a/tests/test_aero_state.py +++ b/tests/test_aero_state.py @@ -587,5 +587,5 @@ def test_dist_sample_different_halving(args): # assert assert ( str(excinfo.value) - == f"dist_sample() called with different halving/doubling settings then in last call" + == "dist_sample() called with different halving/doubling settings then in last call" ) diff --git a/tests/test_run_part.py b/tests/test_run_part.py index beaf6a32..a894ca40 100644 --- a/tests/test_run_part.py +++ b/tests/test_run_part.py @@ -145,5 +145,5 @@ def test_run_part_allow_flag_mimatch(common_args, tmp_path, flags): # assert assert ( str(excinfo.value) - == f"allow halving/doubling flags set differently then while sampling" + == "allow halving/doubling flags set differently then while sampling" ) From 7df0fd08a2ed26e90be98798580d6e06ca16aa12 Mon Sep 17 00:00:00 2001 From: Sylwester Arabas Date: Tue, 4 Jun 2024 18:11:35 +0200 Subject: [PATCH 3/5] fix typo --- tests/test_run_part.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_run_part.py b/tests/test_run_part.py index a894ca40..e12b4299 100644 --- a/tests/test_run_part.py +++ b/tests/test_run_part.py @@ -116,7 +116,7 @@ def test_run_part_do_condensation(common_args, tmp_path): ((False, True), (True, False)), ), ) - def test_run_part_allow_flag_mimatch(common_args, tmp_path, flags): + def test_run_part_allow_flag_mismatch(common_args, tmp_path, flags): # arrange filename = tmp_path / "test" env_state = ppmc.EnvState(ENV_STATE_CTOR_ARG_HIGH_RH) From 18bc61c37a37768bb0db0c204f223777d8d16091 Mon Sep 17 00:00:00 2001 From: Sylwester Arabas Date: Tue, 4 Jun 2024 19:52:28 +0200 Subject: [PATCH 4/5] cover _timestep and _timeblock variants as well --- tests/test_run_part.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/tests/test_run_part.py b/tests/test_run_part.py index e12b4299..20f0ce4c 100644 --- a/tests/test_run_part.py +++ b/tests/test_run_part.py @@ -52,7 +52,7 @@ def test_run_part(common_args): @staticmethod def test_run_part_timestep(common_args): - (last_output_time, last_progress_time, i_output) = ppmc.run_part_timestep( + last_output_time, last_progress_time, i_output = ppmc.run_part_timestep( *common_args, 1, 0, 0, 0, 1 ) @@ -63,14 +63,18 @@ def test_run_part_timestep(common_args): @staticmethod def test_run_part_timeblock(common_args): + # arrange num_times = int( RUN_PART_OPT_CTOR_ARG_SIMULATION["t_output"] / RUN_PART_OPT_CTOR_ARG_SIMULATION["del_t"] ) - (last_output_time, last_progress_time, i_output) = ppmc.run_part_timeblock( + + # act + last_output_time, last_progress_time, i_output = ppmc.run_part_timeblock( *common_args, 1, num_times, 0, 0, 0, 1 ) + # assert assert last_output_time == RUN_PART_OPT_CTOR_ARG_SIMULATION["t_output"] assert last_progress_time == 0.0 assert i_output == 2 @@ -116,7 +120,15 @@ def test_run_part_do_condensation(common_args, tmp_path): ((False, True), (True, False)), ), ) - def test_run_part_allow_flag_mismatch(common_args, tmp_path, flags): + @pytest.mark.parametrize( + "fun_args", + ( + ("run_part", []), + ("run_part_timestep", [0, 0, 0, 0, 0]), + ("run_part_timeblock", [0, 0, 0, 0, 0, 0]), + ), + ) + def test_run_part_allow_flag_mismatch(common_args, tmp_path, fun_args, flags): # arrange filename = tmp_path / "test" env_state = ppmc.EnvState(ENV_STATE_CTOR_ARG_HIGH_RH) @@ -140,7 +152,7 @@ def test_run_part_allow_flag_mismatch(common_args, tmp_path, flags): # act with pytest.raises(RuntimeError) as excinfo: - ppmc.run_part(*args) + getattr(ppmc, fun_args[0])(*args, *fun_args[1]) # assert assert ( From c9b40e4aceb1c4c1e6265b6559a66a36a7766eef Mon Sep 17 00:00:00 2001 From: Sylwester Arabas Date: Wed, 5 Jun 2024 10:00:21 +0200 Subject: [PATCH 5/5] disable new exception-checking tests on Apple Silicon --- tests/test_aero_state.py | 1 + tests/test_run_part.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/tests/test_aero_state.py b/tests/test_aero_state.py index 61d90433..61c8c8d9 100644 --- a/tests/test_aero_state.py +++ b/tests/test_aero_state.py @@ -573,6 +573,7 @@ def test_dist_sample_mono(): ((False, True), (True, False)), ), ) + @pytest.mark.skipif(platform.machine() == "arm64", reason="TODO #348") def test_dist_sample_different_halving(args): # arrange aero_data = ppmc.AeroData(AERO_DATA_CTOR_ARG_MINIMAL) diff --git a/tests/test_run_part.py b/tests/test_run_part.py index 20f0ce4c..df95a11b 100644 --- a/tests/test_run_part.py +++ b/tests/test_run_part.py @@ -4,6 +4,8 @@ # Authors: https://github.com/open-atmos/PyPartMC/graphs/contributors # #################################################################################################### +import platform + import numpy as np import pytest @@ -128,6 +130,7 @@ def test_run_part_do_condensation(common_args, tmp_path): ("run_part_timeblock", [0, 0, 0, 0, 0, 0]), ), ) + @pytest.mark.skipif(platform.machine() == "arm64", reason="TODO #348") def test_run_part_allow_flag_mismatch(common_args, tmp_path, fun_args, flags): # arrange filename = tmp_path / "test"