Skip to content

Commit

Permalink
Merge pull request #11 from AllenInstitute/feature/demand-exec-infra-…
Browse files Browse the repository at this point in the history
…updates

update demand exec infra with new features for configuration
  • Loading branch information
rpmcginty authored Jun 17, 2024
2 parents 510dc58 + 73e916b commit d74c575
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"BatchInvokedExecutorFragment",
"BatchInvokedLambdaFunction",
"DataSyncFragment",
"DistributedDataSyncFragment",
"DemandExecutionFragment",
"CleanFileSystemFragment",
"CleanFileSystemTriggerConfig",
Expand All @@ -13,6 +14,7 @@
)
from aibs_informatics_cdk_lib.constructs_.sfn.fragments.informatics.data_sync import (
DataSyncFragment,
DistributedDataSyncFragment,
)
from aibs_informatics_cdk_lib.constructs_.sfn.fragments.informatics.demand_execution import (
DemandExecutionFragment,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from aws_cdk import aws_iam as iam
from aws_cdk import aws_s3 as s3
from aws_cdk import aws_stepfunctions as sfn
from aws_cdk import aws_stepfunctions_tasks as sfn_tasks

from aibs_informatics_cdk_lib.common.aws.iam_utils import (
SFN_STATES_EXECUTION_ACTIONS,
Expand All @@ -15,6 +16,7 @@
)
from aibs_informatics_cdk_lib.constructs_.base import EnvBaseConstructMixins
from aibs_informatics_cdk_lib.constructs_.efs.file_system import MountPointConfiguration
from aibs_informatics_cdk_lib.constructs_.sfn.fragments.base import EnvBaseStateMachineFragment
from aibs_informatics_cdk_lib.constructs_.sfn.fragments.informatics.batch import (
BatchInvokedBaseFragment,
BatchInvokedLambdaFunction,
Expand Down Expand Up @@ -110,3 +112,88 @@ def required_inline_policy_statements(self) -> List[iam.PolicyStatement]:
actions=SFN_STATES_EXECUTION_ACTIONS + SFN_STATES_READ_ACCESS_ACTIONS,
),
]


class DistributedDataSyncFragment(BatchInvokedBaseFragment):
    """Step Functions fragment that fans data-sync work out across Batch jobs.

    The fragment's definition is a three-state chain:

    1. a ``Pass`` state that wraps the execution input under ``$.request``;
    2. a batch-invoked "prep" lambda that splits the request into a list of
       per-job sync requests (stored under
       ``$.tasks.prep-batch-data-sync-requests.response.requests``);
    3. a ``Map`` state that runs one batch-invoked "sync" lambda per prepared
       request in parallel, discarding the map result.
    """

    def __init__(
        self,
        scope: constructs.Construct,
        id: str,
        env_base: EnvBase,
        aibs_informatics_docker_asset: Union[ecr_assets.DockerImageAsset, str],
        batch_job_queue: Union[batch.JobQueue, str],
        scaffolding_bucket: s3.Bucket,
        mount_point_configs: Optional[Iterable[MountPointConfiguration]] = None,
    ) -> None:
        """
        Args:
            scope: Parent construct; nested batch fragments are attached here
                (NOT to ``self``) to preserve the existing construct tree.
            id: Construct id; also used as a prefix for state names.
            env_base: Environment base passed through to all sub-fragments.
            aibs_informatics_docker_asset: Docker image (asset or image-URI
                string) used for both the prep and sync batch jobs.
            batch_job_queue: Batch job queue (construct or queue-name string)
                both jobs are submitted to.
            scaffolding_bucket: S3 bucket used for request/response scaffolding.
            mount_point_configs: Optional EFS mount point configurations
                applied to both batch jobs. May be any iterable, including a
                one-shot iterator.
        """
        super().__init__(scope, id, env_base)

        # Resolve the union-typed arguments once up front instead of
        # repeating the isinstance checks for each sub-fragment.
        image_uri = (
            aibs_informatics_docker_asset
            if isinstance(aibs_informatics_docker_asset, str)
            else aibs_informatics_docker_asset.image_uri
        )
        job_queue_name = (
            batch_job_queue
            if isinstance(batch_job_queue, str)
            else batch_job_queue.job_queue_name
        )

        # BUG FIX: materialize the iterable exactly once. The previous code
        # evaluated ``list(mount_point_configs)`` twice; a one-shot iterator
        # (e.g. a generator) would be exhausted by the first evaluation,
        # silently leaving the Map-iterator sync jobs with no mount points.
        # An empty iterable is normalized to None, matching the original
        # falsy-check semantics for ``[]``.
        mount_point_config_list = (
            list(mount_point_configs) if mount_point_configs else None
        ) or None

        def fresh_mount_point_configs() -> Optional[List[MountPointConfiguration]]:
            # Hand each construct its own copy so neither can mutate the
            # other's list.
            if mount_point_config_list is None:
                return None
            return list(mount_point_config_list)

        start_pass_state = sfn.Pass(
            self,
            f"{id}: Start",
            parameters={
                "request": sfn.JsonPath.object_at("$"),
            },
        )
        prep_batch_sync_task_name = "prep-batch-data-sync-requests"

        # Splits the incoming request into per-job sync requests; its
        # response is stashed under $.tasks.<task name>.response.
        prep_batch_sync = BatchInvokedLambdaFunction(
            scope=scope,
            id=f"{id}: Prep Batch Data Sync",
            env_base=env_base,
            name=prep_batch_sync_task_name,
            payload_path="$.request",
            image=image_uri,
            handler="aibs_informatics_aws_lambda.handlers.data_sync.prepare_batch_data_sync_handler",
            job_queue=job_queue_name,
            bucket_name=scaffolding_bucket.bucket_name,
            memory=1024,
            vcpus=1,
            mount_point_configs=fresh_mount_point_configs(),
        ).enclose(result_path=f"$.tasks.{prep_batch_sync_task_name}.response")

        batch_sync_map_state = sfn.Map(
            self,
            f"{id}: Batch Data Sync: Map Start",
            comment="Runs requests for batch sync in parallel",
            items_path=f"$.tasks.{prep_batch_sync_task_name}.response.requests",
            # Map output is discarded; downstream states see the pre-map input.
            result_path=sfn.JsonPath.DISCARD,
        )

        # One sync job per prepared request. NOTE(review): payload_path is
        # "$.requests" relative to each map item — assumes each prepared item
        # is shaped {"requests": ...}; confirm against the prep handler.
        batch_sync_map_state.iterator(
            BatchInvokedLambdaFunction(
                scope=scope,
                id=f"{id}: Batch Data Sync",
                env_base=env_base,
                name="batch-data-sync",
                payload_path="$.requests",
                image=image_uri,
                handler="aibs_informatics_aws_lambda.handlers.data_sync.batch_data_sync_handler",
                job_queue=job_queue_name,
                bucket_name=scaffolding_bucket.bucket_name,
                memory=2048,
                vcpus=1,
                mount_point_configs=fresh_mount_point_configs(),
            )
        )
        # fmt: off
        self.definition = (
            start_pass_state
            .next(prep_batch_sync)
            .next(batch_sync_map_state)
        )
        # fmt: on
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Any, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import constructs
from aibs_informatics_core.env import EnvBase
from aibs_informatics_core.utils.tools.dicttools import remove_null_values
from aws_cdk import aws_batch_alpha as batch
from aws_cdk import aws_ecr_assets as ecr_assets
from aws_cdk import aws_iam as iam
Expand Down Expand Up @@ -35,6 +36,8 @@ def __init__(
data_sync_state_machine: sfn.StateMachine,
shared_mount_point_config: Optional[MountPointConfiguration],
scratch_mount_point_config: Optional[MountPointConfiguration],
tmp_mount_point_config: Optional[MountPointConfiguration] = None,
context_manager_configuration: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(scope, id, env_base)

Expand Down Expand Up @@ -72,7 +75,7 @@ def __init__(
file_system_configurations = {}

# Update arguments with mount points and volumes if provided
if shared_mount_point_config or scratch_mount_point_config:
if shared_mount_point_config or scratch_mount_point_config or tmp_mount_point_config:
mount_points = []
volumes = []
if shared_mount_point_config:
Expand Down Expand Up @@ -104,18 +107,34 @@ def __init__(
volumes.append(
scratch_mount_point_config.to_batch_volume("scratch", sfn_format=True)
)
if tmp_mount_point_config:
# update file system configurations for scaffolding function
file_system_configurations["tmp"] = {
"file_system": tmp_mount_point_config.file_system_id,
"access_point": tmp_mount_point_config.access_point_id,
"container_path": tmp_mount_point_config.mount_point,
}
# add to mount point and volumes list for batch invoked lambda functions
mount_points.append(
tmp_mount_point_config.to_batch_mount_point("tmp", sfn_format=True)
)
volumes.append(tmp_mount_point_config.to_batch_volume("tmp", sfn_format=True))

batch_invoked_lambda_kwargs["mount_points"] = mount_points
batch_invoked_lambda_kwargs["volumes"] = volumes

request = {
"demand_execution": sfn.JsonPath.object_at("$"),
"file_system_configurations": file_system_configurations,
}
if context_manager_configuration:
request["context_manager_configuration"] = context_manager_configuration

start_state = sfn.Pass(
self,
f"Start Demand Batch Task",
parameters={
"request": {
"demand_execution": sfn.JsonPath.object_at("$"),
"file_system_configurations": file_system_configurations,
}
"request": request,
},
)

Expand Down

0 comments on commit d74c575

Please sign in to comment.