Skip to content

Commit

Permalink
Move over remaining Source formula implementations to be defined as m…
Browse files Browse the repository at this point in the history
…odel_func instead.

PiperOrigin-RevId: 709314324
  • Loading branch information
Nush395 authored and Torax team committed Dec 24, 2024
1 parent ccdef07 commit e0b6036
Show file tree
Hide file tree
Showing 7 changed files with 17 additions and 23 deletions.
2 changes: 1 addition & 1 deletion torax/config/tests/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def test_adding_standard_source_via_config(self):
# pytype: enable=attribute-error
self.assertEqual(
source_models_builder.runtime_params['gas_puff_source'].mode,
source_runtime_params_lib.Mode.FORMULA_BASED, # On by default.
source_runtime_params_lib.Mode.MODEL_BASED, # On by default.
)
self.assertEqual(
source_models_builder.runtime_params['ohmic_heat_source'].mode,
Expand Down
4 changes: 2 additions & 2 deletions torax/fvm/tests/fvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ def test_implicit_solve_block_uses_updated_boundary_conditions(self):
)
)
geo = geometry.build_circular_geometry(n_rho=num_cells)
source_models = source_models_lib.SourceModels()
source_models = source_models_builder()
initial_core_profiles = core_profile_setters.initial_core_profiles(
static_runtime_params_slice,
dynamic_runtime_params_slice,
Expand Down Expand Up @@ -732,7 +732,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
),
)

source_models = source_models_lib.SourceModels()
source_models = source_models_builder()
pedestal_model = set_tped_nped.SetTemperatureDensityPedestalModel()
initial_core_profiles = core_profile_setters.initial_core_profiles(
static_runtime_params_slice_theta0,
Expand Down
9 changes: 3 additions & 6 deletions torax/sources/electron_density_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class GasPuffRuntimeParams(runtime_params_lib.RuntimeParams):
puff_decay_length: runtime_params_lib.TimeInterpolatedInput = 0.05
# total gas puff particles/s
S_puff_tot: runtime_params_lib.TimeInterpolatedInput = 1e22
mode: runtime_params_lib.Mode = runtime_params_lib.Mode.FORMULA_BASED
mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED

def make_provider(
self,
Expand Down Expand Up @@ -110,7 +110,6 @@ class GasPuffSource(source.Source):

SOURCE_NAME: ClassVar[str] = 'gas_puff_source'
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_puff_source'
formula: source.SourceProfileFunction = calc_puff_source
model_func: source.SourceProfileFunction = calc_puff_source

@property
Expand All @@ -132,7 +131,7 @@ class GenericParticleSourceRuntimeParams(runtime_params_lib.RuntimeParams):
deposition_location: runtime_params_lib.TimeInterpolatedInput = 0.0
# total particle source
S_tot: runtime_params_lib.TimeInterpolatedInput = 1e22
mode: runtime_params_lib.Mode = runtime_params_lib.Mode.FORMULA_BASED
mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED

def make_provider(
self,
Expand Down Expand Up @@ -208,7 +207,6 @@ class GenericParticleSource(source.Source):

SOURCE_NAME: ClassVar[str] = 'generic_particle_source'
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_generic_particle_source'
formula: source.SourceProfileFunction = calc_generic_particle_source
model_func: source.SourceProfileFunction = calc_generic_particle_source

@property
Expand All @@ -232,7 +230,7 @@ class PelletRuntimeParams(runtime_params_lib.RuntimeParams):
pellet_deposition_location: runtime_params_lib.TimeInterpolatedInput = 0.85
# total pellet particles/s (continuous pellet model)
S_pellet_tot: runtime_params_lib.TimeInterpolatedInput = 2e22
mode: runtime_params_lib.Mode = runtime_params_lib.Mode.FORMULA_BASED
mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED

def make_provider(
self,
Expand Down Expand Up @@ -298,7 +296,6 @@ class PelletSource(source.Source):

SOURCE_NAME: ClassVar[str] = 'pellet_source'
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_pellet_source'
formula: source.SourceProfileFunction = calc_pellet_source
model_func: source.SourceProfileFunction = calc_pellet_source

@property
Expand Down
11 changes: 3 additions & 8 deletions torax/sources/generic_current_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class RuntimeParams(runtime_params_lib.RuntimeParams):

# Toggles if external current is provided absolutely or as a fraction of Ip.
use_absolute_current: bool = False
mode: runtime_params_lib.Mode = runtime_params_lib.Mode.FORMULA_BASED
mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED

@property
def grid_type(self) -> base.GridType:
Expand Down Expand Up @@ -236,7 +236,6 @@ class GenericCurrentSource(source.Source):

SOURCE_NAME: ClassVar[str] = 'generic_current_source'
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'calc_generic_current_face'
formula: source.SourceProfileFunction = calculate_generic_current_face
hires_formula: source.SourceProfileFunction = _calculate_generic_current_hires
model_func: source.SourceProfileFunction = calculate_generic_current_face

Expand Down Expand Up @@ -305,12 +304,8 @@ def generic_current_source_hires(
geo=geo,
core_profiles=None,
# There is no model for this source.
model_func=(
lambda _0, _1, _2, _3, _4, _5: jnp.zeros_like(
geo.rho_hires_norm
)
),
formula=self.hires_formula,
model_func=self.hires_formula,
formula=None,
output_shape=geo.rho_hires_norm.shape,
prescribed_values=hires_prescribed_values,
source_models=getattr(self, 'source_models', None),
Expand Down
3 changes: 1 addition & 2 deletions torax/sources/generic_ion_el_heat_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class RuntimeParams(runtime_params_lib.RuntimeParams):
Ptot: runtime_params_lib.TimeInterpolatedInput = 120e6
# electron heating fraction
el_heat_fraction: runtime_params_lib.TimeInterpolatedInput = 0.66666
mode: runtime_params_lib.Mode = runtime_params_lib.Mode.FORMULA_BASED
mode: runtime_params_lib.Mode = runtime_params_lib.Mode.MODEL_BASED

def make_provider(
self,
Expand Down Expand Up @@ -151,7 +151,6 @@ class GenericIonElectronHeatSource(source.Source):

SOURCE_NAME: ClassVar[str] = 'generic_ion_el_heat_source'
DEFAULT_MODEL_FUNCTION_NAME: ClassVar[str] = 'default_formula'
formula: source.SourceProfileFunction = default_formula
model_func: source.SourceProfileFunction = default_formula

@property
Expand Down
2 changes: 1 addition & 1 deletion torax/sources/tests/test_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def test_source_value(self):
)
source_models = source_models_builder()
source = source_models.sources[self._source_name]
source_builder.runtime_params.mode = source.supported_modes[1]
source_builder.runtime_params.mode = runtime_params_lib.Mode.MODEL_BASED
self.assertIsInstance(source, source_lib.Source)
geo = geometry.build_circular_geometry()
dynamic_runtime_params_slice = (
Expand Down
9 changes: 6 additions & 3 deletions torax/tests/physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import dataclasses
from typing import Callable

from absl.testing import absltest
from absl.testing import parameterized
import jax
Expand All @@ -27,10 +28,12 @@
from torax.config import runtime_params_slice
from torax.fvm import cell_variable
from torax.geometry import geometry
from torax.sources import generic_current_source
from torax.sources import runtime_params as source_runtime_params
from torax.sources import source_models as source_models_lib
from torax.tests.test_lib import torax_refs


_trapz = jax.scipy.integrate.trapezoid


Expand Down Expand Up @@ -108,9 +111,9 @@ def test_update_psi_from_j(
runtime_params = references.runtime_params
source_models_builder = source_models_lib.SourceModelsBuilder()
# Turn on the external current source.
source_models_builder.runtime_params['generic_current_source'].mode = (
source_runtime_params.Mode.FORMULA_BASED
)
source_models_builder.runtime_params[
generic_current_source.GenericCurrentSource.SOURCE_NAME
].mode = source_runtime_params.Mode.MODEL_BASED
source_models = source_models_builder()
dynamic_runtime_params_slice, geo = (
torax_refs.build_consistent_dynamic_runtime_params_slice_and_geometry(
Expand Down

0 comments on commit e0b6036

Please sign in to comment.