Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Warm start a strategy based on config-provided seed conditions #487

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions aepsych/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,46 @@ def to_dict(self, deduplicate: bool = True) -> Dict[str, Any]:
_dict[section][setting] = self[section][setting]
return _dict

# Turn the metadata section into JSON.
def jsonifyMetadata(self) -> str:
"""Turn the metadata section into JSON.
def get_metadata(self, only_extra: bool = False) -> Dict[Any, Any]:
"""Return a dictionary of the metadata section.

Args:
only_extra (bool, optional): Only gather the extra metadata. Defaults to False.

Returns:
str: JSON representation of the metadata section.
Dict[Any, Any]: a collection of the metadata stored in this config.
"""
configdict = self.to_dict()
return json.dumps(configdict["metadata"])
metadata = configdict["metadata"].copy()

if only_extra:
default_metadata = [
"experiment_name",
"experiment_description",
"experiment_id",
"participant_id",
]
for name in default_metadata:
metadata.pop(name, None)

return metadata

# Turn the metadata section into JSON.
def jsonifyMetadata(self, only_extra: bool = False) -> str:
    """Return the metadata section serialized as a JSON string.

    Args:
        only_extra (bool): Forwarded to ``get_metadata``; when True, only the
            extra metadata is serialized. Defaults to False.

    Returns:
        str: A JSON string representing the metadata dictionary, or an empty
            string if there is no metadata to return.
    """
    metadata = self.get_metadata(only_extra)
    # An empty mapping yields "" instead of the JSON text "{}".
    if not metadata:
        return ""
    return json.dumps(metadata)

# Turn the entire config into JSON format.
def jsonifyAll(self) -> str:
Expand Down
478 changes: 478 additions & 0 deletions aepsych/database/data_fetcher.py

Large diffs are not rendered by default.

103 changes: 51 additions & 52 deletions aepsych/database/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import logging
import os
import uuid
import warnings
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional
Expand Down Expand Up @@ -127,7 +128,7 @@ def execute_sql_query(self, query: str, vals: Dict[str, str]) -> List[Any]:
List[Any]: The results of the query.
"""
with self.session_scope() as session:
return session.execute(query, vals).fetchall()
return session.execute(query, vals).all()

def get_master_records(self) -> List[tables.DBMasterTable]:
"""Grab the list of master records.
Expand All @@ -138,18 +139,18 @@ def get_master_records(self) -> List[tables.DBMasterTable]:
records = self._session.query(tables.DBMasterTable).all()
return records

def get_master_record(self, experiment_id: int) -> Optional[tables.DBMasterTable]:
"""Grab the list of master record for a specific experiment (master) id.
def get_master_record(self, master_id: int) -> Optional[tables.DBMasterTable]:
"""Grab the master record for a specific master id (the unique_id of the master table).

Args:
experiment_id (int): The experiment id.
master_id (int): The master id, which is the primary key (unique_id) of the master table.

Returns:
tables.DBMasterTable or None: The master record or None if it doesn't exist.
"""
records = (
self._session.query(tables.DBMasterTable)
.filter(tables.DBMasterTable.experiment_id == experiment_id)
.filter(tables.DBMasterTable.unique_id == master_id)
.all()
)

Expand All @@ -162,7 +163,7 @@ def get_replay_for(self, master_id: int) -> Optional[List[tables.DbReplayTable]]
"""Get the replay records for a specific master row.

Args:
master_id (int): The master id.
master_id (int): The unique id for the master row (its primary key).

Returns:
List[tables.DbReplayTable] or None: The replay records or None if they don't exist.
Expand All @@ -178,7 +179,7 @@ def get_strats_for(self, master_id: int = 0) -> Optional[List[Any]]:
"""Get the strat records for a specific master row.

Args:
master_id (int): The master id. Defaults to 0.
master_id (int): The master table unique ID. Defaults to 0.

Returns:
List[Any] or None: The strat records or None if they don't exist.
Expand Down Expand Up @@ -238,15 +239,22 @@ def get_raw_for(self, master_id: int) -> Optional[List[tables.DbRawTable]]:

return None

def get_all_params_for(self, master_id: int) -> Optional[List[tables.DbRawTable]]:
def get_all_params_for(self, master_id: int) -> Optional[List[tables.DbParamTable]]:
"""Get the parameters for all the iterations of a specific experiment.

Args:
master_id (int): The master id.

Returns:
List[tables.DbRawTable] or None: The parameters or None if they don't exist.
List[tables.DbParamTable] or None: The parameters or None if they don't exist.
"""
warnings.warn(
"get_all_params_for is the same as get_param_for since there can only be one instance of any master_id",
DeprecationWarning,
)
return self.get_param_for(master_id=master_id)

# TODO: This function should change to being able to get params for all experiments given specific metadata
raw_record = self.get_raw_for(master_id)
params = []

Expand All @@ -258,36 +266,40 @@ def get_all_params_for(self, master_id: int) -> Optional[List[tables.DbRawTable]

return None

def get_param_for(
self, master_id: int, iteration_id: int
) -> Optional[List[tables.DbRawTable]]:
def get_param_for(self, master_id: int) -> Optional[List[tables.DbParamTable]]:
"""Get the parameters for a specific iteration of a specific experiment.

Args:
master_id (int): The master id.
iteration_id (int): The iteration id.

Returns:
List[tables.DbRawTable] or None: The parameters or None if they don't exist.
List[tables.DbParamTable] or None: The parameters or None if they don't exist.
"""
raw_record = self.get_raw_for(master_id)

if raw_record is not None:
for raw in raw_record:
if raw.unique_id == iteration_id:
if raw.unique_id == master_id:
return raw.children_param

return None

def get_all_outcomes_for(self, master_id: int) -> Optional[List[tables.DbRawTable]]:
def get_all_outcomes_for(self, master_id: int) -> Optional[List[tables.DbOutcomeTable]]:
"""Get the outcomes for all the iterations of a specific experiment.

Args:
master_id (int): The master id.

Returns:
List[tables.DbRawTable] or None: The outcomes or None if they don't exist.
List[tables.DbOutcomeTable] or None: The outcomes or None if they don't exist.
"""
warnings.warn(
"get_all_outcomes_for is the same as get_outcome_for since there can only be one instance of any master_id",
DeprecationWarning,
)
return self.get_outcome_for(master_id=master_id)

# TODO: This function should change to being able to get outcomes for all experiments given specific metadata
raw_record = self.get_raw_for(master_id)
outcomes = []

Expand All @@ -299,72 +311,59 @@ def get_all_outcomes_for(self, master_id: int) -> Optional[List[tables.DbRawTabl

return None

def get_outcome_for(
self, master_id: int, iteration_id: int
) -> Optional[List[tables.DbRawTable]]:
def get_outcome_for(self, master_id: int) -> Optional[List[tables.DbOutcomeTable]]:
"""Get the outcomes for a specific iteration of a specific experiment.

Args:
master_id (int): The master id.
iteration_id (int): The iteration id.

Returns:
List[tables.DbRawTable] or None: The outcomes or None if they don't exist.
List[tables.DbOutcomeTable] or None: The outcomes or None if they don't exist.
"""
raw_record = self.get_raw_for(master_id)

if raw_record is not None:
for raw in raw_record:
if raw.unique_id == iteration_id:
if raw.unique_id == master_id:
return raw.children_outcome

return None

def record_setup(
self,
description: str,
name: str,
description: str = None,
name: str = None,
extra_metadata: Optional[str] = None,
id: Optional[str] = None,
exp_id: Optional[str] = None,
request: Dict[str, Any] = None,
participant_id: Optional[int] = None,
par_id: Optional[int] = None,
) -> str:
"""Record the setup of an experiment.

Args:
description (str): The description of the experiment.
name (str): The name of the experiment.
description (str, optional): The description of the experiment, defaults to None.
name (str, optional): The name of the experiment, defaults to None.
extra_metadata (str, optional): Extra metadata. Defaults to None.
id (str, optional): The id of the experiment. Defaults to None.
request (Dict[str, Any]): The request. Defaults to None.
participant_id (int, optional): The participant id. Defaults to None.
exp_id (str, optional): The id of the experiment. Defaults to a generated uuid.
request (Dict[str, Any], optional): The request. Defaults to None.
par_id (int, optional): The participant id. Defaults to generated uuid.

Returns:
str: The experiment id.
"""
self.get_engine()

if id is None:
master_table = tables.DBMasterTable()
master_table.experiment_description = description
master_table.experiment_name = name
master_table.experiment_id = str(uuid.uuid4())
if participant_id is not None:
master_table.participant_id = participant_id
else:
master_table.participant_id = str(
uuid.uuid4()
) # no p_id specified will result in a generated UUID

master_table.extra_metadata = extra_metadata

self._session.add(master_table)
master_table = tables.DBMasterTable()
master_table.experiment_description = description
master_table.experiment_name = name
master_table.experiment_id = exp_id if exp_id is not None else str(uuid.uuid4())
master_table.participant_id = (
par_id if par_id is not None else str(uuid.uuid4())
)
master_table.extra_metadata = extra_metadata
self._session.add(master_table)

logger.debug(f"record_setup = [{master_table}]")
else:
master_table = self.get_master_record(id)
if master_table is None:
raise RuntimeError(f"experiment id {id} doesn't exist in the db.")
logger.debug(f"record_setup = [{master_table}]")

record = tables.DbReplayTable()
record.message_type = "setup"
Expand Down
42 changes: 17 additions & 25 deletions aepsych/database/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,6 @@

Base = declarative_base()

"""
Original Schema
CREATE TABLE master (
unique_id INTEGER NOT NULL,
experiment_name VARCHAR(256),
experiment_description VARCHAR(2048),
experiment_id VARCHAR(10),
PRIMARY KEY (unique_id),
UNIQUE (experiment_id)
);
CREATE TABLE replay_data (
unique_id INTEGER NOT NULL,
timestamp DATETIME,
message_type VARCHAR(64),
message_contents BLOB,
master_table_id INTEGER,
PRIMARY KEY (unique_id),
FOREIGN KEY(master_table_id) REFERENCES master (unique_id)
);
"""


class DBMasterTable(Base):
"""
Expand All @@ -62,10 +41,10 @@ class DBMasterTable(Base):
__tablename__ = "master"

unique_id = Column(Integer, primary_key=True, autoincrement=True)
experiment_name = Column(String(256))
experiment_description = Column(String(2048))
experiment_id = Column(String(10), unique=True)
participant_id = Column(String(50), unique=True)
experiment_name = Column(String(256), nullable=True)
experiment_description = Column(String(2048), nullable=True)
experiment_id = Column(String(10))
participant_id = Column(String(50))

extra_metadata = Column(String(4096)) # JSON-formatted metadata

Expand Down Expand Up @@ -176,6 +155,19 @@ def _add_column(engine: Engine, column: str) -> None:
except Exception as e:
logger.debug(f"Column already exists, no need to alter. [{e}]")

@staticmethod
def _update_column(engine: Engine, column: str, spec: str) -> None:
    """Alter an existing column of the master table to a new spec.

    Args:
        engine (Engine): The sqlalchemy engine.
        column (str): The column name.
        spec (str): The new column spec.
    """
    logger.debug(f"Altering the master table column: {column} to this spec {spec}")
    # The statement is executed directly on the engine. No explicit commit
    # follows: sqlalchemy's Engine does not expose a commit() method, so
    # calling one here would raise AttributeError.
    # NOTE(review): "ALTER TABLE ... MODIFY" is MySQL-style syntax -- confirm
    # the target database accepts it.
    # NOTE(review): column and spec are interpolated verbatim into the SQL;
    # presumably they only ever come from trusted, internal callers.
    engine.execute(f"ALTER TABLE master MODIFY {column} {spec}")


class DbReplayTable(Base):
__tablename__ = "replay_data"
Expand Down
Loading