Skip to content

Commit

Permalink
Add torax_mesh to StaticRuntimeParamsSlice.
Browse files Browse the repository at this point in the history
torax_mesh is constant for a simulation and any modification implies a change in array sizes and hence recompilation. It is thus safe to have as a static variable and useful to have concrete values of the mesh in JAX jitted functions for simplifying various calculations.

PiperOrigin-RevId: 709045823
  • Loading branch information
jcitrin authored and Torax team committed Dec 24, 2024
1 parent a287db8 commit e998dcc
Show file tree
Hide file tree
Showing 29 changed files with 172 additions and 81 deletions.
30 changes: 26 additions & 4 deletions torax/config/runtime_params_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ class StaticRuntimeParamsSlice:
stepper: stepper_params.StaticRuntimeParams
# Mapping of source name to source-specific static runtime params.
sources: Mapping[str, sources_params.StaticRuntimeParams]
# Torax mesh used to construct the geometry.
torax_mesh: geometry.Grid1D
# Solve the ion heat equation (ion temperature evolves over time)
ion_heat_eq: bool
# Solve the electron heat equation (electron temperature evolves over time)
Expand All @@ -133,7 +135,8 @@ class StaticRuntimeParamsSlice:
def __hash__(self):
return hash((
self.stepper,
tuple(sorted(self.sources.items())), # Hashable version of sources.
tuple(sorted(self.sources.items())), # Hashable version of sources
hash(self.torax_mesh), # Grid1D has a hash method defined.
self.ion_heat_eq,
self.el_heat_eq,
self.current_eq,
Expand All @@ -156,19 +159,38 @@ def _build_dynamic_sources(


def build_static_runtime_params_slice(
*,
runtime_params: general_runtime_params_lib.GeneralRuntimeParams,
source_runtime_params: dict[str, sources_params.RuntimeParams],
torax_mesh: geometry.Grid1D,
stepper: stepper_params.RuntimeParams | None = None,
) -> StaticRuntimeParamsSlice:
"""Builds a StaticRuntimeParamsSlice."""
# t set to None because there shouldnt be time-dependent params in the static
# config.
"""Builds a StaticRuntimeParamsSlice.
Args:
runtime_params: General runtime params from which static params are taken,
which are the choices on equations being solved, and adaptive dt.
source_runtime_params: data from which the source related static variables
are taken, which are the explicit/implicit toggle and calculation mode for
each source.
torax_mesh: The torax mesh, e.g. the grid used to construct the geometry.
This is static for the entire simulation and any modification implies
changed array sizes, and hence would require a recompilation. Useful to
have a static (concrete) mesh for various internal calculations.
stepper: stepper runtime params from which stepper static variables are
extracted, related to solver methods. If None, defaults to the
default stepper runtime params.
Returns:
A StaticRuntimeParamsSlice.
"""
stepper = stepper or stepper_params.RuntimeParams()
return StaticRuntimeParamsSlice(
sources={
source_name: specific_source_runtime_params.build_static_params()
for source_name, specific_source_runtime_params in source_runtime_params.items()
},
torax_mesh=torax_mesh,
stepper=stepper.build_static_params(),
ion_heat_eq=runtime_params.numerics.ion_heat_eq,
el_heat_eq=runtime_params.numerics.el_heat_eq,
Expand Down
13 changes: 9 additions & 4 deletions torax/config/tests/runtime_params_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,15 +608,18 @@ def test_update_dynamic_slice_provider_updates_transport(
self.assertEqual(dcs.transport.De_inner, 2.0)

def test_static_runtime_params_slice_hash_same_for_same_params(self):
"""Tests that the hash is the same for the same static params."""
runtime_params = general_runtime_params.GeneralRuntimeParams()
source_models_builder = default_sources.get_default_sources_builder()
static_slice1 = runtime_params_slice_lib.build_static_runtime_params_slice(
runtime_params,
runtime_params=runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=self._geo.torax_mesh,
)
static_slice2 = runtime_params_slice_lib.build_static_runtime_params_slice(
runtime_params,
runtime_params=runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=self._geo.torax_mesh,
)
self.assertEqual(hash(static_slice1), hash(static_slice2))

Expand All @@ -627,8 +630,9 @@ def test_static_runtime_params_slice_hash_different_for_different_params(
runtime_params = general_runtime_params.GeneralRuntimeParams()
source_models_builder = default_sources.get_default_sources_builder()
static_slice1 = runtime_params_slice_lib.build_static_runtime_params_slice(
runtime_params,
runtime_params=runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=self._geo.torax_mesh,
)
runtime_params_mod = dataclasses.replace(
runtime_params,
Expand All @@ -638,8 +642,9 @@ def test_static_runtime_params_slice_hash_different_for_different_params(
),
)
static_slice2 = runtime_params_slice_lib.build_static_runtime_params_slice(
runtime_params_mod,
runtime_params=runtime_params_mod,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=self._geo.torax_mesh,
)
self.assertNotEqual(hash(static_slice1), hash(static_slice2))

Expand Down
3 changes: 2 additions & 1 deletion torax/fvm/tests/calc_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ def test_calc_coeffs_smoke_test(
)
static_runtime_params_slice = (
runtime_params_slice_lib.build_static_runtime_params_slice(
runtime_params,
runtime_params=runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=geo.torax_mesh,
stepper=stepper_params,
)
)
Expand Down
15 changes: 9 additions & 6 deletions torax/fvm/tests/fvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,9 +430,10 @@ def test_nonlinear_solve_block_loss_minimum(
)
static_runtime_params_slice = (
runtime_params_slice.build_static_runtime_params_slice(
runtime_params,
stepper=stepper_params,
runtime_params=runtime_params,
torax_mesh=geo.torax_mesh,
source_runtime_params=source_models_builder.runtime_params,
stepper=stepper_params,
)
)
core_profiles = core_profile_setters.initial_core_profiles(
Expand Down Expand Up @@ -572,9 +573,10 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self):
)
static_runtime_params_slice = (
runtime_params_slice.build_static_runtime_params_slice(
runtime_params,
stepper=stepper_params,
runtime_params=runtime_params,
torax_mesh=geo.torax_mesh,
source_runtime_params=source_models_builder.runtime_params,
stepper=stepper_params,
)
)
geo = geometry.build_circular_geometry(n_rho=num_cells)
Expand Down Expand Up @@ -717,9 +719,10 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
)
static_runtime_params_slice_theta0 = (
runtime_params_slice.build_static_runtime_params_slice(
runtime_params,
stepper=stepper_params,
runtime_params=runtime_params,
torax_mesh=geo.torax_mesh,
source_runtime_params=source_models_builder.runtime_params,
stepper=stepper_params,
)
)
static_runtime_params_slice_theta05 = dataclasses.replace(
Expand Down
3 changes: 2 additions & 1 deletion torax/pedestal_model/tests/set_tped_nped.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ def test_build_and_call_pedestal_model(
pedestal_model = builder()
static_runtime_params_slice = (
runtime_params_slice.build_static_runtime_params_slice(
runtime_params,
runtime_params=runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=geo.torax_mesh,
)
)
core_profiles = core_profile_setters.initial_core_profiles(
Expand Down
4 changes: 3 additions & 1 deletion torax/sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,11 +932,13 @@ def build_sim_object(
transport_model = transport_model_builder()
pedestal_model = pedestal_model_builder()

# TODO(b/385788907): Clearly document all changes that lead to recompilations.
static_runtime_params_slice = (
runtime_params_slice.build_static_runtime_params_slice(
runtime_params=runtime_params,
stepper=stepper_builder.runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=geometry_provider.torax_mesh,
stepper=stepper_builder.runtime_params,
)
)
dynamic_runtime_params_slice_provider = (
Expand Down
5 changes: 3 additions & 2 deletions torax/simulation_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,10 @@ def update_sim(
_update_source_params(sim, source_runtime_params)
static_runtime_params_slice = (
runtime_params_slice.build_static_runtime_params_slice(
runtime_params,
stepper=stepper_runtime_params,
runtime_params=runtime_params,
source_runtime_params=source_runtime_params,
torax_mesh=geo_provider.torax_mesh,
stepper=stepper_runtime_params,
)
)
dynamic_runtime_params_slice_provider = (
Expand Down
5 changes: 3 additions & 2 deletions torax/sources/tests/bremsstrahlung_heat_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ def test_compare_against_known(
)
static_runtime_params_slice = (
runtime_params_slice.build_static_runtime_params_slice(
runtime_params,
stepper=stepper_runtime_params.RuntimeParams(),
runtime_params=runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=geo.torax_mesh,
stepper=stepper_runtime_params.RuntimeParams(),
)
)
source_models = source_models_builder()
Expand Down
15 changes: 9 additions & 6 deletions torax/sources/tests/electron_cyclotron_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,10 @@ def test_source_value(self):
)
static_runtime_params_slice = (
runtime_params_slice.build_static_runtime_params_slice(
runtime_params,
stepper=stepper_runtime_params.RuntimeParams(),
runtime_params=runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=geo.torax_mesh,
stepper=stepper_runtime_params.RuntimeParams(),
)
)
core_profiles = core_profile_setters.initial_core_profiles(
Expand Down Expand Up @@ -119,9 +120,10 @@ def test_invalid_source_types_raise_errors(self):
)
static_runtime_params_slice = (
runtime_params_slice.build_static_runtime_params_slice(
runtime_params,
stepper=stepper_runtime_params.RuntimeParams(),
runtime_params=runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=geo.torax_mesh,
stepper=stepper_runtime_params.RuntimeParams(),
)
)
core_profiles = core_profile_setters.initial_core_profiles(
Expand All @@ -136,9 +138,10 @@ def test_invalid_source_types_raise_errors(self):
# Construct a new slice with the given mode
static_runtime_params_slice = (
runtime_params_slice.build_static_runtime_params_slice(
runtime_params,
stepper=stepper_runtime_params.RuntimeParams(),
runtime_params=runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=geo.torax_mesh,
stepper=stepper_runtime_params.RuntimeParams(),
)
)
with self.subTest(unsupported_mode.name):
Expand Down
3 changes: 2 additions & 1 deletion torax/sources/tests/fusion_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,9 @@ def test_calc_fusion(
)
)
static_slice = runtime_params_slice.build_static_runtime_params_slice(
runtime_params,
runtime_params=runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=geo.torax_mesh,
)
source_models = source_models_builder()
core_profiles = core_profile_setters.initial_core_profiles(
Expand Down
6 changes: 4 additions & 2 deletions torax/sources/tests/generic_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,13 @@ def test_generic_current_hires(self):
t=runtime_params.numerics.t_initial,
)
static_slice = runtime_params_slice.build_static_runtime_params_slice(
runtime_params,
runtime_params=runtime_params,
source_runtime_params={
generic_current_source.GenericCurrentSource.SOURCE_NAME: (
source_builder.runtime_params
),
},
torax_mesh=geo.torax_mesh,
)
self.assertIsInstance(source, generic_current_source.GenericCurrentSource)
self.assertIsNotNone(
Expand Down Expand Up @@ -103,12 +104,13 @@ def test_profile_is_on_face_grid(self):
t=runtime_params.numerics.t_initial,
)
static_slice = runtime_params_slice.build_static_runtime_params_slice(
runtime_params,
runtime_params=runtime_params,
source_runtime_params={
generic_current_source.GenericCurrentSource.SOURCE_NAME: (
source_builder.runtime_params
),
},
torax_mesh=geo.torax_mesh,
)
self.assertEqual(
source.get_value(
Expand Down
6 changes: 4 additions & 2 deletions torax/sources/tests/impurity_radiation_heat_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,9 @@ def test_source_value(self):
)
)
static_slice = runtime_params_slice.build_static_runtime_params_slice(
runtime_params,
runtime_params=runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=geo.torax_mesh,
)
core_profiles = core_profile_setters.initial_core_profiles(
static_runtime_params_slice=static_slice,
Expand Down Expand Up @@ -180,8 +181,9 @@ def test_invalid_source_types_raise_errors(self):
)
static_runtime_params_slice = (
runtime_params_slice.build_static_runtime_params_slice(
runtime_params,
runtime_params=runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=geo.torax_mesh,
)
)
core_profiles = core_profile_setters.initial_core_profiles(
Expand Down
6 changes: 4 additions & 2 deletions torax/sources/tests/ion_cyclotron_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,9 @@ def test_icrh_output_matches_total_power(
)
)
static_slice = runtime_params_slice.build_static_runtime_params_slice(
runtime_params,
runtime_params=runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=geo.torax_mesh,
)
core_profiles = core_profile_setters.initial_core_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
Expand Down Expand Up @@ -233,8 +234,9 @@ def test_source_value(self, mock_path):
)
)
static_slice = runtime_params_slice.build_static_runtime_params_slice(
runtime_params,
runtime_params=runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=geo.torax_mesh,
)
core_profiles = core_profile_setters.initial_core_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
Expand Down
9 changes: 6 additions & 3 deletions torax/sources/tests/qei_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ def test_source_value(self):
runtime_params = general_runtime_params.GeneralRuntimeParams()
geo = geometry.build_circular_geometry()
static_slice = runtime_params_slice.build_static_runtime_params_slice(
runtime_params,
runtime_params=runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=geo.torax_mesh,
)
dynamic_slice = runtime_params_slice.DynamicRuntimeParamsSliceProvider(
runtime_params,
Expand Down Expand Up @@ -86,8 +87,9 @@ def test_invalid_source_types_raise_errors(self):
runtime_params = general_runtime_params.GeneralRuntimeParams()
geo = geometry.build_circular_geometry()
static_slice = runtime_params_slice.build_static_runtime_params_slice(
runtime_params,
runtime_params=runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=geo.torax_mesh,
)
dynamic_slice = runtime_params_slice.DynamicRuntimeParamsSliceProvider(
runtime_params,
Expand All @@ -106,13 +108,14 @@ def test_invalid_source_types_raise_errors(self):
with self.subTest(unsupported_mode.name):
with self.assertRaises(ValueError):
static_slice = runtime_params_slice.build_static_runtime_params_slice(
runtime_params,
runtime_params=runtime_params,
source_runtime_params={
'qei_source': dataclasses.replace(
source_builder.runtime_params,
mode=unsupported_mode,
)
},
torax_mesh=geo.torax_mesh,
)
# Force pytype to recognize `source` has `get_qei`
assert isinstance(source, qei_source.QeiSource)
Expand Down
Loading

0 comments on commit e998dcc

Please sign in to comment.