diff --git a/src/aerie_cli/aerie_client.py b/src/aerie_cli/aerie_client.py index a0c35bb4..b2cfdbe1 100644 --- a/src/aerie_cli/aerie_client.py +++ b/src/aerie_cli/aerie_client.py @@ -4,6 +4,7 @@ from typing import Dict from typing import List from copy import deepcopy +from warnings import warn import arrow @@ -12,6 +13,7 @@ from .schemas.api import ApiMissionModelCreate from .schemas.api import ApiMissionModelRead from .schemas.api import ApiResourceSampleResults +from .schemas.api import ApiSimulationDatasetRead from .schemas.client import Activity from .schemas.client import ActivityPlanCreate from .schemas.client import ActivityPlanRead @@ -19,6 +21,7 @@ from .schemas.client import ExpansionRun from .schemas.client import ExpansionRule from .schemas.client import ExpansionSet +from .schemas.client import SimulationDataset from .schemas.client import ResourceType from .utils.serialization import postgres_interval_to_microseconds from .aerie_host import AerieHost @@ -355,7 +358,7 @@ def exec_sim_query(): return sim_dataset_id def get_resource_timelines(self, plan_id: int): - samples = self.get_resource_samples(self.get_simulation_dataset_ids_by_plan_id(plan_id)[0]) + samples = self.get_resource_samples(self.list_simulation_datasets_by_plan_id(plan_id)[0].id) api_resource_timeline = ApiResourceSampleResults.from_dict(samples) return api_resource_timeline @@ -973,27 +976,58 @@ def get_rules_by_type(self) -> Dict[str, List[ExpansionRule]]: return rules_by_type def get_simulation_dataset_ids_by_plan_id(self, plan_id: int) -> List[int]: - """Get the IDs of the simulation datasets generated from a given plan + warn("get_simulation_dataset_ids_by_plan_id is deprecated. " + "Use list_simulation_datasets_by_plan_id instead", + DeprecationWarning, + stacklevel=2) + return [s.id for s in self.list_simulation_datasets_by_plan_id(plan_id)] + + # TODO: Change output type to sim dataset + def list_simulation_datasets_by_plan_id(self, plan_id: int) -> List[SimulationDataset]: + """Get metadata for the simulation datasets generated from a given plan Args: plan_id (int): ID of parent plan Returns: - List[int]: IDs of simulation datasets in descending order + List[SimulationDataset]: Simulation datasets in descending order by ID """ + # Since GQL will group results by simulation, we have to sort client-side get_simulation_dataset_query = """ query GetSimulationDatasetId($plan_id: Int!) { - simulation(where: {plan_id: {_eq: $plan_id}}, order_by: { id: desc }, limit: 1) { - simulation_datasets(order_by: { id: desc }) { - id + plan_by_pk(id: $plan_id) { + simulations { + simulation_datasets { + id + simulation_id + dataset_id + offset_from_plan_start + plan_revision + model_revision + simulation_template_revision + simulation_revision + dataset_revision + arguments + simulation_start_time + simulation_end_time + status + reason + canceled + requested_by + requested_at + } } } } """ data = self.aerie_host.post_to_graphql( get_simulation_dataset_query, plan_id=plan_id) - return [d["id"] for d in data[0]["simulation_datasets"]] + result = [SimulationDataset.from_api_read(ApiSimulationDatasetRead.from_dict(d)) + for sim in data["simulations"] + for d in sim["simulation_datasets"]] + result.sort(key=lambda s: s.id, reverse=True) + return result def expand_simulation( self, simulation_dataset_id: int, expansion_set_id: int diff --git a/src/aerie_cli/commands/expansion.py b/src/aerie_cli/commands/expansion.py index fff57453..753b4f35 100644 --- a/src/aerie_cli/commands/expansion.py +++ b/src/aerie_cli/commands/expansion.py @@ -72,8 +72,8 @@ def list_expansion_runs( client = CommandContext.get_client() if simulation_dataset_id is None: - simulation_datasets = client.get_simulation_dataset_ids_by_plan_id( - plan_id) + simulation_datasets = [ + d.id for d in client.list_simulation_datasets_by_plan_id(plan_id)] table_caption = f'All runs for Plan ID {plan_id}' else: simulation_datasets = [simulation_dataset_id] @@ -132,8 +132,8 @@ def list_sequences( client = CommandContext.get_client() if simulation_dataset_id is None: - simulation_datasets = client.get_simulation_dataset_ids_by_plan_id( - plan_id) + simulation_datasets = [ + d.id for d in client.list_simulation_datasets_by_plan_id(plan_id)] table_caption = f'All sequences for Plan ID {plan_id}' else: simulation_datasets = [simulation_dataset_id] diff --git a/src/aerie_cli/commands/plans.py b/src/aerie_cli/commands/plans.py index 91557576..6d2c3dfa 100644 --- a/src/aerie_cli/commands/plans.py +++ b/src/aerie_cli/commands/plans.py @@ -242,9 +242,9 @@ def list(): table.add_column("Latest Sim. Dataset ID", no_wrap=True) table.add_column("Model ID", no_wrap=True) for activity_plan in resp: - sim_ids = client.get_simulation_dataset_ids_by_plan_id(activity_plan.id) + sim_ids = client.list_simulation_datasets_by_plan_id(activity_plan.id) if len(sim_ids): - simulation_dataset_id = str(max(sim_ids)) + simulation_dataset_id = str(sim_ids[0].id) else: simulation_dataset_id = '' diff --git a/src/aerie_cli/schemas/api.py b/src/aerie_cli/schemas/api.py index 1d7a872e..b4794ead 100644 --- a/src/aerie_cli/schemas/api.py +++ b/src/aerie_cli/schemas/api.py @@ -191,3 +191,32 @@ class ApiMissionModelCreate(ApiSerialize): @define class ApiMissionModelRead(ApiMissionModelCreate): id: int + + +@define +class ApiSimulationDatasetRead(ApiSerialize): + id: int + simulation_id: int + dataset_id: int + offset_from_plan_start: timedelta = field( + converter=convert_to_time_delta + ) + plan_revision: int + model_revision: int + simulation_template_revision: int + simulation_revision: int + dataset_revision: int + arguments: Dict[str, Any] + simulation_start_time: Arrow = field( + converter=arrow.get + ) + simulation_end_time: Arrow = field( + converter=arrow.get + ) + status: str + reason: str + canceled: bool + requested_by: str + requested_at: Arrow = field( + converter=arrow.get + ) diff --git a/src/aerie_cli/schemas/client.py b/src/aerie_cli/schemas/client.py index a5b8a1a5..237d686c 100644 --- a/src/aerie_cli/schemas/client.py +++ b/src/aerie_cli/schemas/client.py @@ -26,6 +26,7 @@ from aerie_cli.schemas.api import ApiResourceSampleResults from aerie_cli.schemas.api import ApiSimulatedResourceSample from aerie_cli.schemas.api import ApiSimulationResults +from aerie_cli.schemas.api import ApiSimulationDatasetRead from aerie_cli.schemas.api import ActivityBase def parse_timedelta_str_converter(t) -> timedelta: @@ -370,3 +371,45 @@ class ExpansionRule(ClientSerialize): class ResourceType(ClientSerialize): name: str schema: Dict + +@define +class SimulationDataset(ClientSerialize): + id: int + simulation_id: int + dataset_id: int + offset_from_plan_start: timedelta + plan_revision: int + model_revision: int + simulation_template_revision: int + simulation_revision: int + dataset_revision: int + arguments: Dict[str, Any] + simulation_start_time: Arrow + simulation_end_time: Arrow + status: str + reason: str + canceled: bool + requested_by: str + requested_at: Arrow + + @classmethod + def from_api_read(cls, api_sim_dataset: ApiSimulationDatasetRead) -> "SimulationDataset": + return SimulationDataset( + id=api_sim_dataset.id, + simulation_id=api_sim_dataset.simulation_id, + dataset_id=api_sim_dataset.dataset_id, + offset_from_plan_start=api_sim_dataset.offset_from_plan_start, + plan_revision=api_sim_dataset.plan_revision, + model_revision=api_sim_dataset.model_revision, + simulation_template_revision=api_sim_dataset.simulation_template_revision, + simulation_revision=api_sim_dataset.simulation_revision, + dataset_revision=api_sim_dataset.dataset_revision, + arguments=api_sim_dataset.arguments, + simulation_start_time=api_sim_dataset.simulation_start_time, + simulation_end_time=api_sim_dataset.simulation_end_time, + status=api_sim_dataset.status, + reason=api_sim_dataset.reason, + canceled=api_sim_dataset.canceled, + requested_by=api_sim_dataset.requested_by, + requested_at=api_sim_dataset.requested_at + ) diff --git a/tests/integration_tests/test_plans.py b/tests/integration_tests/test_plans.py index 16585426..02476ab7 100644 --- a/tests/integration_tests/test_plans.py +++ b/tests/integration_tests/test_plans.py @@ -186,9 +186,9 @@ def test_delete_collaborators(): def test_plan_simulate(): result = cli_plan_simulate() - sim_ids = client.get_simulation_dataset_ids_by_plan_id(plan_id) + sim_ids = client.list_simulation_datasets_by_plan_id(plan_id) global sim_id - sim_id = sim_ids[-1] + sim_id = sim_ids[0].id assert result.exit_code == 0,\ f"{result.stdout}"\ f"{result.stderr}"