diff --git a/pyproject.toml b/pyproject.toml index 8572e99..81a492c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "starfile", "scipy", "mrcfile", + "pydantic", "mmdf", "requests", "torch-fourier-filter" diff --git a/src/ttsim3d/input_models.py b/src/ttsim3d/input_models.py new file mode 100644 index 0000000..f218447 --- /dev/null +++ b/src/ttsim3d/input_models.py @@ -0,0 +1,24 @@ +"""Pydantic models for input parameters.""" + +from pydantic import BaseModel + + +class SimulationParams(BaseModel): + """Parameters for the simulation.""" + + pdb_filename: str + output_filename: str + sim_volume_shape: tuple[int, int, int] = (400, 400, 400) + sim_pixel_spacing: float = 0.95 + num_frames: int = 50 + fluence_per_frame: float = 1 + beam_energy_kev: int = 300 + dose_weighting: bool = True + dose_B: float = -1 + apply_dqe: bool = True + mtf_filename: str + b_scaling: float = 0.5 + added_B: float = 0.0 + upsampling: int = -1 + gpu_id: int = -999 + modify_signal: int = 1 diff --git a/src/ttsim3d/run_ttsim3d.py b/src/ttsim3d/run_ttsim3d.py index ce5e80a..f2c3402 100644 --- a/src/ttsim3d/run_ttsim3d.py +++ b/src/ttsim3d/run_ttsim3d.py @@ -1,13 +1,15 @@ """Simple run script.""" +from ttsim3d.input_models import SimulationParams from ttsim3d.simulate3d import simulate3d def main() -> None: """A test function to run the simulate3d function from the ttsim3d package.""" - simulate3d( + params = SimulationParams( pdb_filename="/Users/josh/git/2dtm_tests/simulator/parsed_6Q8Y_whole_LSU_match3.pdb", output_filename="/Users/josh/git/2dtm_tests/simulator/simulated_6Q8Y_whole_LSU_match3.mrc", + mtf_filename="/Users/josh/git/2dtm_tests/simulator/mtf_k2_300kV.star", sim_volume_shape=(400, 400, 400), sim_pixel_spacing=0.95, num_frames=50, @@ -16,14 +18,15 @@ def main() -> None: dose_weighting=True, dose_B=-1, apply_dqe=True, - mtf_filename="/Users/josh/git/2dtm_tests/simulator/mtf_k2_300kV.star", b_scaling=0.5, added_B=0.0, upsampling=-1, gpu_id=-999, - modify_signal=1, # This is how to apply the dose weighting. + modify_signal=1, ) + simulate3d(**params.model_dump()) + if __name__ == "__main__": main()