Skip to content

Commit

Permalink
updates to compute stack
Browse files Browse the repository at this point in the history
  • Loading branch information
rpmcginty committed Apr 19, 2024
1 parent d82f195 commit 392da29
Showing 1 changed file with 78 additions and 21 deletions.
99 changes: 78 additions & 21 deletions src/aibs_informatics_cdk_lib/stacks/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from email.policy import default
from typing import Any, Iterable, List, Optional, TypeVar, Union
from urllib import request

Expand Down Expand Up @@ -36,7 +37,11 @@
from aibs_informatics_cdk_lib.constructs_.batch.launch_template import BatchLaunchTemplateBuilder
from aibs_informatics_cdk_lib.constructs_.batch.types import BatchEnvironmentDescriptor
from aibs_informatics_cdk_lib.constructs_.ec2 import EnvBaseVpc
from aibs_informatics_cdk_lib.constructs_.efs.file_system import EFSEcosystem, EnvBaseFileSystem
from aibs_informatics_cdk_lib.constructs_.efs.file_system import (
EFSEcosystem,
EnvBaseFileSystem,
MountPointConfiguration,
)
from aibs_informatics_cdk_lib.constructs_.s3 import EnvBaseBucket, LifecycleRuleGenerator
from aibs_informatics_cdk_lib.constructs_.sfn.fragments.batch import SubmitJobFragment
from aibs_informatics_cdk_lib.stacks.base import EnvBaseStack
Expand Down Expand Up @@ -97,7 +102,10 @@ def __init__(
env_base: EnvBase,
vpc: ec2.Vpc,
buckets: Optional[Iterable[s3.Bucket]] = None,
file_systems: Optional[Iterable[efs.FileSystem]] = None,
file_systems: Optional[Iterable[Union[efs.FileSystem, efs.IFileSystem]]] = None,
mount_point_configs: Optional[Iterable[MountPointConfiguration]] = None,
create_state_machine: bool = True,
state_machine_name: Optional[str] = "submit-job",
**kwargs,
) -> None:
super().__init__(scope, id, env_base, **kwargs)
Expand All @@ -110,13 +118,26 @@ def __init__(

file_system_list = list(file_systems or [])

if mount_point_configs:
mount_point_config_list = list(mount_point_configs)
file_system_list = self._update_file_systems_from_mount_point_configs(
file_system_list, mount_point_config_list
)
else:
mount_point_config_list = self._get_mount_point_configs(file_system_list)

# Validation to ensure that no mount point is claimed by conflicting configurations
self._validate_mount_point_configs(mount_point_config_list)

self.grant_storage_access(*bucket_list, *file_system_list)

self.create_step_functions(file_system=file_system_list[0] if file_system_list else None)
self.create_step_functions(
name=state_machine_name, mount_point_configs=mount_point_config_list
)

self.export_values()

def grant_storage_access(self, *resources: Union[s3.Bucket, efs.FileSystem]):
def grant_storage_access(self, *resources: Union[s3.Bucket, efs.FileSystem, efs.IFileSystem]):
self.batch.grant_instance_role_permissions(read_write_resources=list(resources))

for batch_environment in self.batch.environments:
Expand Down Expand Up @@ -162,9 +183,12 @@ def create_batch_environments(self):
launch_template_builder=lt_builder,
)

def create_step_functions(self, file_system: Optional[efs.FileSystem] = None):

state_machine_core_name = "submit-job"
def create_step_functions(
self,
name: Optional[str] = None,
mount_point_configs: Optional[list[MountPointConfiguration]] = None,
):
state_machine_core_name = name or "submit-job"
defaults: dict[str, Any] = {}
defaults["command"] = []
defaults["job_queue"] = self.on_demand_batch_environment.job_queue.job_queue_arn
Expand All @@ -174,23 +198,14 @@ def create_step_functions(self, file_system: Optional[efs.FileSystem] = None):
defaults["gpu"] = "0"
defaults["platform_capabilities"] = ["EC2"]

if file_system:
file_system.file_system_id
if mount_point_configs:
defaults["mount_points"] = [
convert_key_case(to_mount_point("/opt/efs", False, "efs-root-volume"), pascalcase)
convert_key_case(mpc.to_batch_mount_point(f"efs-vol{i}"), pascalcase)
for i, mpc in enumerate(mount_point_configs)
]
defaults["volumes"] = [
convert_key_case(
to_volume(
None,
"efs-root-volume",
{
"fileSystemId": file_system.file_system_id,
"rootDirectory": "/",
},
),
pascalcase,
)
convert_key_case(mpc.to_batch_volume(f"efs-vol{i}"), pascalcase)
for i, mpc in enumerate(mount_point_configs)
]

start = sfn.Pass(
Expand Down Expand Up @@ -273,3 +288,45 @@ def export_values(self) -> None:
self.export_value(self.on_demand_batch_environment.job_queue.job_queue_arn)
self.export_value(self.spot_batch_environment.job_queue.job_queue_arn)
self.export_value(self.fargate_batch_environment.job_queue.job_queue_arn)

## Private methods

def _validate_mount_point_configs(self, mount_point_configs: List[MountPointConfiguration]):
    """Reject conflicting configurations that target the same mount point.

    An equal configuration repeated for the same mount point is accepted;
    only an unequal configuration for an already-seen mount point is an error.

    Args:
        mount_point_configs: Configurations to check, processed in order.

    Raises:
        ValueError: If two unequal configurations share a ``mount_point``.
    """
    seen_by_mount_point = {}
    for config in mount_point_configs:
        mount_point = config.mount_point
        is_conflict = (
            mount_point in seen_by_mount_point and seen_by_mount_point[mount_point] != config
        )
        if is_conflict:
            raise ValueError(
                f"Mount point {mount_point} is duplicated. "
                "Cannot have multiple mount points configurations with the same name."
            )
        # Remember the latest configuration seen for this mount point.
        seen_by_mount_point[mount_point] = config

def _get_mount_point_configs(
    self, file_systems: Optional[List[Union[efs.FileSystem, efs.IFileSystem]]]
) -> List[MountPointConfiguration]:
    """Build one mount point configuration per file system.

    Args:
        file_systems: File systems to convert. ``None`` or an empty list
            yields an empty result.

    Returns:
        The result of ``MountPointConfiguration.from_file_system`` for each
        file system, in input order.
    """
    return [
        MountPointConfiguration.from_file_system(file_system)
        for file_system in (file_systems or [])
    ]

def _update_file_systems_from_mount_point_configs(
    self,
    file_systems: List[Union[efs.FileSystem, efs.IFileSystem]],
    mount_point_configs: List[MountPointConfiguration],
) -> List[Union[efs.FileSystem, efs.IFileSystem]]:
    """Extend the file system list with those referenced by mount point configs.

    File systems are keyed by ``file_system_id``. A configuration whose id is
    already known contributes nothing; otherwise its own ``file_system`` is
    used when set, falling back to the file system of its ``access_point``.

    Args:
        file_systems: File systems that are already known.
        mount_point_configs: Configurations that may reference more file systems.

    Returns:
        The known file systems followed by any newly discovered ones, one
        entry per distinct ``file_system_id``.

    Raises:
        ValueError: If a configuration with an unknown id has neither a file
            system nor an access point.
    """
    known: dict[str, Union[efs.FileSystem, efs.IFileSystem]] = {}
    for file_system in file_systems:
        known[file_system.file_system_id] = file_system

    for config in mount_point_configs:
        if config.file_system_id in known:
            continue

        if config.file_system:
            known[config.file_system_id] = config.file_system
        elif config.access_point:
            # No file system on the config itself; take it from the access point.
            known[config.file_system_id] = config.access_point.file_system
        else:
            raise ValueError(
                "Mount point configuration must have a file system or access point."
            )

    return list(known.values())

0 comments on commit 392da29

Please sign in to comment.