Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide client method for simulation dataset metadata #94

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 41 additions & 7 deletions src/aerie_cli/aerie_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Dict
from typing import List
from copy import deepcopy
from warnings import warn

import arrow

Expand All @@ -12,13 +13,15 @@
from .schemas.api import ApiMissionModelCreate
from .schemas.api import ApiMissionModelRead
from .schemas.api import ApiResourceSampleResults
from .schemas.api import ApiSimulationDatasetRead
from .schemas.client import Activity
from .schemas.client import ActivityPlanCreate
from .schemas.client import ActivityPlanRead
from .schemas.client import CommandDictionaryInfo
from .schemas.client import ExpansionRun
from .schemas.client import ExpansionRule
from .schemas.client import ExpansionSet
from .schemas.client import SimulationDataset
from .schemas.client import ResourceType
from .utils.serialization import postgres_interval_to_microseconds
from .aerie_host import AerieHost
Expand Down Expand Up @@ -355,7 +358,7 @@ def exec_sim_query():
return sim_dataset_id

def get_resource_timelines(self, plan_id: int):
samples = self.get_resource_samples(self.get_simulation_dataset_ids_by_plan_id(plan_id)[0])
samples = self.get_resource_samples(self.list_simulation_datasets_by_plan_id(plan_id)[0].id)
api_resource_timeline = ApiResourceSampleResults.from_dict(samples)
return api_resource_timeline

Expand Down Expand Up @@ -973,27 +976,58 @@ def get_rules_by_type(self) -> Dict[str, List[ExpansionRule]]:
return rules_by_type

def get_simulation_dataset_ids_by_plan_id(self, plan_id: int) -> List[int]:
"""Get the IDs of the simulation datasets generated from a given plan
warn("get_simulation_dataset_ids_by_plan_id is deprecated. "
"Use list_simulation_datasets_by_plan_id instead",
DeprecationWarning,
stacklevel=2)
return [s.id for s in self.list_simulation_datasets_by_plan_id(plan_id)]

# TODO: Change output type to sim dataset
def list_simulation_datasets_by_plan_id(self, plan_id: int) -> List[SimulationDataset]:
"""Get metadata for the simulation datasets generated from a given plan

Args:
plan_id (int): ID of parent plan

Returns:
List[int]: IDs of simulation datasets in descending order
List[SimulationDataset]: Simulation datasets in descending order by ID
"""

# Since GQL will group results by simulation, we have to sort client-side
get_simulation_dataset_query = """
query GetSimulationDatasetId($plan_id: Int!) {
simulation(where: {plan_id: {_eq: $plan_id}}, order_by: { id: desc }, limit: 1) {
simulation_datasets(order_by: { id: desc }) {
id
plan_by_pk(id: $plan_id) {
cartermak marked this conversation as resolved.
Show resolved Hide resolved
simulations {
simulation_datasets {
id
simulation_id
dataset_id
offset_from_plan_start
plan_revision
model_revision
simulation_template_revision
simulation_revision
dataset_revision
arguments
simulation_start_time
simulation_end_time
status
reason
canceled
requested_by
requested_at
}
}
}
}
"""
data = self.aerie_host.post_to_graphql(
get_simulation_dataset_query, plan_id=plan_id)
return [d["id"] for d in data[0]["simulation_datasets"]]
result = [SimulationDataset.from_api_read(ApiSimulationDatasetRead.from_dict(d))
for sim in data["simulations"]
for d in sim["simulation_datasets"]]
result.sort(key=lambda s: s.id, reverse=True)
return result

def expand_simulation(
self, simulation_dataset_id: int, expansion_set_id: int
Expand Down
8 changes: 4 additions & 4 deletions src/aerie_cli/commands/expansion.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ def list_expansion_runs(
client = CommandContext.get_client()

if simulation_dataset_id is None:
simulation_datasets = client.get_simulation_dataset_ids_by_plan_id(
plan_id)
simulation_datasets = [
d.id for d in client.list_simulation_datasets_by_plan_id(plan_id)]
table_caption = f'All runs for Plan ID {plan_id}'
else:
simulation_datasets = [simulation_dataset_id]
Expand Down Expand Up @@ -132,8 +132,8 @@ def list_sequences(
client = CommandContext.get_client()

if simulation_dataset_id is None:
simulation_datasets = client.get_simulation_dataset_ids_by_plan_id(
plan_id)
simulation_datasets = [
d.id for d in client.list_simulation_datasets_by_plan_id(plan_id)]
table_caption = f'All sequences for Plan ID {plan_id}'
else:
simulation_datasets = [simulation_dataset_id]
Expand Down
4 changes: 2 additions & 2 deletions src/aerie_cli/commands/plans.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,9 @@ def list():
table.add_column("Latest Sim. Dataset ID", no_wrap=True)
table.add_column("Model ID", no_wrap=True)
for activity_plan in resp:
sim_ids = client.get_simulation_dataset_ids_by_plan_id(activity_plan.id)
sim_ids = client.list_simulation_datasets_by_plan_id(activity_plan.id)
if len(sim_ids):
simulation_dataset_id = str(max(sim_ids))
simulation_dataset_id = str(sim_ids[0].id)
else:
simulation_dataset_id = ''

Expand Down
29 changes: 29 additions & 0 deletions src/aerie_cli/schemas/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,32 @@ class ApiMissionModelCreate(ApiSerialize):
@define
class ApiMissionModelRead(ApiMissionModelCreate):
id: int


@define
class ApiSimulationDatasetRead(ApiSerialize):
id: int
simulation_id: int
dataset_id: int
offset_from_plan_start: timedelta = field(
converter=convert_to_time_delta
)
plan_revision: int
model_revision: int
simulation_template_revision: int
simulation_revision: int
dataset_revision: int
arguments: Dict[str, Any]
simulation_start_time: Arrow = field(
converter=arrow.get
)
simulation_end_time: Arrow = field(
converter=arrow.get
)
status: str
reason: str
canceled: bool
requested_by: str
requested_at: Arrow = field(
converter=arrow.get
)
43 changes: 43 additions & 0 deletions src/aerie_cli/schemas/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from aerie_cli.schemas.api import ApiResourceSampleResults
from aerie_cli.schemas.api import ApiSimulatedResourceSample
from aerie_cli.schemas.api import ApiSimulationResults
from aerie_cli.schemas.api import ApiSimulationDatasetRead
from aerie_cli.schemas.api import ActivityBase

def parse_timedelta_str_converter(t) -> timedelta:
Expand Down Expand Up @@ -370,3 +371,45 @@ class ExpansionRule(ClientSerialize):
class ResourceType(ClientSerialize):
name: str
schema: Dict

@define
class SimulationDataset(ClientSerialize):
id: int
simulation_id: int
dataset_id: int
offset_from_plan_start: timedelta
plan_revision: int
model_revision: int
simulation_template_revision: int
simulation_revision: int
dataset_revision: int
arguments: Dict[str, Any]
simulation_start_time: Arrow
simulation_end_time: Arrow
status: str
reason: str
canceled: bool
requested_by: str
requested_at: Arrow

@classmethod
def from_api_read(cls, api_sim_dataset: ApiSimulationDatasetRead) -> "SimulationDataset":
return SimulationDataset(
id=api_sim_dataset.id,
simulation_id=api_sim_dataset.simulation_id,
dataset_id=api_sim_dataset.dataset_id,
offset_from_plan_start=api_sim_dataset.offset_from_plan_start,
plan_revision=api_sim_dataset.plan_revision,
model_revision=api_sim_dataset.model_revision,
simulation_template_revision=api_sim_dataset.simulation_template_revision,
simulation_revision=api_sim_dataset.simulation_revision,
dataset_revision=api_sim_dataset.dataset_revision,
arguments=api_sim_dataset.arguments,
simulation_start_time=api_sim_dataset.simulation_start_time,
simulation_end_time=api_sim_dataset.simulation_end_time,
status=api_sim_dataset.status,
reason=api_sim_dataset.reason,
canceled=api_sim_dataset.canceled,
requested_by=api_sim_dataset.requested_by,
requested_at=api_sim_dataset.requested_at
)
4 changes: 2 additions & 2 deletions tests/integration_tests/test_plans.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,9 @@ def test_delete_collaborators():

def test_plan_simulate():
result = cli_plan_simulate()
sim_ids = client.get_simulation_dataset_ids_by_plan_id(plan_id)
sim_ids = client.list_simulation_datasets_by_plan_id(plan_id)
global sim_id
sim_id = sim_ids[-1]
sim_id = sim_ids[0].id
assert result.exit_code == 0,\
f"{result.stdout}"\
f"{result.stderr}"
Expand Down
Loading