diff --git a/jhub_apps/config_utils.py b/jhub_apps/config_utils.py index de1cd772..e8ea378a 100644 --- a/jhub_apps/config_utils.py +++ b/jhub_apps/config_utils.py @@ -1,6 +1,64 @@ -from traitlets import Unicode, Union, List, Callable, Integer, Bool +import textwrap +import typing as t +from pydantic import BaseModel, ValidationError +from traitlets import Int, Unicode, Union, List, Callable, Integer, TraitType, TraitError from traitlets.config import SingletonConfigurable, Enum +from jhub_apps.service.models import StartupApp + + +class PydanticModelTrait(TraitType): + """A trait type for validating Pydantic models. + + This trait ensures that the input is an instance of a specific Pydantic model type. + """ + + def __init__(self, model_class: t.Type[BaseModel], *args, **kwargs): + """ + Initialize the trait with a specific Pydantic model class. + + Args: + model_class: The Pydantic model class to validate against + *args: Additional arguments for TraitType + **kwargs: Additional keyword arguments for TraitType + """ + super().__init__(*args, **kwargs) + self.model_class = model_class + self.info_text = f"an instance of {model_class.__name__}" + + def validate(self, obj: t.Any, value: t.Any) -> BaseModel: + """ + Validate that the input is an instance of the specified Pydantic model. + + Args: + obj: The object the trait is attached to + value: The value to validate + + Returns: + Validated Pydantic model instance + + Raises: + TraitError: If the value is not a valid instance of the model + """ + # If None is allowed and value is None, return None + if self.allow_none and value is None: + return None + + # Check if value is an instance of the specified model class + if isinstance(value, self.model_class): + return value + + # If not an instance, try to create an instance from a dict + if isinstance(value, dict): + try: + return self.model_class(**value) + except ValidationError as e: + # Convert Pydantic validation error to TraitError + raise TraitError(f'Could not parse input as a valid {self.model_class.__name__} Pydantic model:\n' + f'{textwrap.indent(str(e), prefix=" ")}') + + raise TraitError(f'Input must be a valid {self.model_class.__name__} Pydantic model or dict object, but got {value}.') + class JAppsConfig(SingletonConfigurable): apps_auth_type = Enum( @@ -49,12 +107,26 @@ class JAppsConfig(SingletonConfigurable): help="The number of workers to create for the JHub Apps FastAPI service", ).tag(config=True) - allowed_frameworks = Bool( + allowed_frameworks = List( None, help="Allow only a specific set of frameworks to spun up apps.", + allow_none=True, ).tag(config=True) - blocked_frameworks = Bool( + blocked_frameworks = List( None, help="Disallow a set of frameworks to avoid spinning up apps using those frameworks", + allow_none=True, + ).tag(config=True) + + # TODO: Remove this attribute + my_int_list = List( + trait=Int, + ).tag(config=True) + + startup_apps = List( + trait=PydanticModelTrait(StartupApp), + description="only add a server if it is not already created or edit an existing one to match the config, won't delete any servers", + default_value=[], + help="List of apps to start on JHub Apps Launcher startup", ).tag(config=True) diff --git a/jhub_apps/hub_client/hub_client.py b/jhub_apps/hub_client/hub_client.py index 66f19bc2..381c81b9 100644 --- a/jhub_apps/hub_client/hub_client.py +++ b/jhub_apps/hub_client/hub_client.py @@ -122,7 +122,7 @@ def get_user(self, user=None): return user @requires_user_token - def get_server(self, username, servername): + def get_server(self, username, servername=None): users = self.get_users() filter_given_user = [user for user in users if user["name"] == username] if not filter_given_user: @@ -130,11 +130,17 @@ def get_server(self, username, servername): return else: given_user = filter_given_user[0] - for name, server in given_user["servers"].items(): - if name == servername: - return server + + if servername: + for name, server in given_user["servers"].items(): + if name == servername: + return server + else: + # return all user servers + return given_user["servers"] - def normalize_server_name(self, servername): + @staticmethod + def normalize_server_name(servername): # Convert text to lowercase text = servername.lower() # Remove all special characters except spaces and hyphen diff --git a/jhub_apps/service/app.py b/jhub_apps/service/app.py index 1608b9f3..0b60a623 100644 --- a/jhub_apps/service/app.py +++ b/jhub_apps/service/app.py @@ -1,14 +1,22 @@ +from contextlib import asynccontextmanager import os from pathlib import Path +from itertools import groupby from fastapi import FastAPI from fastapi.staticfiles import StaticFiles +from typing import Any +from jhub_apps.hub_client.hub_client import HubClient from jhub_apps.service.japps_routes import router as japps_router from jhub_apps.service.logging_utils import setup_logging from jhub_apps.service.middlewares import create_middlewares from jhub_apps.service.routes import router +from jhub_apps.service.utils import get_jupyterhub_config from jhub_apps.version import get_version +import structlog + +logger = structlog.get_logger(__name__) ### When managed by Jupyterhub, the actual endpoints ### will be served out prefixed by /services/:name. @@ -17,6 +25,43 @@ STATIC_DIR = Path(__file__).parent.parent / "static" +@asynccontextmanager +async def lifespan(app: FastAPI): + config = get_jupyterhub_config() + startup_apps_list = config['JAppsConfig']['startup_apps'] + # group user options by username + grouped_user_options_list = groupby(startup_apps_list, lambda x: x.username) + for username, startup_apps_list in grouped_user_options_list: + instantiate_startup_apps(startup_apps_list, username=username) + + yield + +def instantiate_startup_apps(startup_apps_list: list[dict[str, Any]], username: str): + # TODO: Support defining app from git repo + hub_client = HubClient(username=username) + + existing_servers = hub_client.get_server(username=username) + + for startup_app in startup_apps_list: + user_options = startup_app.user_options + normalized_servername = startup_app.normalized_servername + if normalized_servername in existing_servers: + # update the server + logger.info(f"Updating server: {normalized_servername}") + hub_client.edit_server(username, normalized_servername, user_options) + else: + # create the server + logger.info(f"Creating server {normalized_servername}") + hub_client.create_server( + username=username, + servername=normalized_servername, + user_options=user_options, + ) + + # stop server after creation + hub_client.delete_server(username, normalized_servername, remove=False) + logger.info('Done instantiating apps') + app = FastAPI( title="JApps Service", version=str(get_version()), @@ -30,6 +75,7 @@ ### Default /docs/oauth2 redirect will cause Hub ### to raise oauth2 redirect uri mismatch errors # swagger_ui_oauth2_redirect_url=os.environ["JUPYTERHUB_OAUTH_CALLBACK_URL"], + lifespan=lifespan, ) static_files = StaticFiles(directory=STATIC_DIR) app.mount(f"{router.prefix}/static", static_files, name="static") diff --git a/jhub_apps/service/models.py b/jhub_apps/service/models.py index 42442b1f..321b1b5e 100644 --- a/jhub_apps/service/models.py +++ b/jhub_apps/service/models.py @@ -5,7 +5,6 @@ from pydantic import BaseModel - # https://jupyterhub.readthedocs.io/en/stable/_static/rest-api/index.html class Server(BaseModel): name: str @@ -89,3 +88,11 @@ class UserOptions(JHubAppConfig): class ServerCreation(BaseModel): servername: str user_options: UserOptions + + @property + def normalized_servername(self): + from jhub_apps.hub_client.hub_client import HubClient + return HubClient.normalize_server_name(self.servername) + +class StartupApp(ServerCreation): + username: str \ No newline at end of file diff --git a/jhub_apps/service/utils.py b/jhub_apps/service/utils.py index 4b67f3bd..011b2236 100644 --- a/jhub_apps/service/utils.py +++ b/jhub_apps/service/utils.py @@ -11,6 +11,7 @@ from jupyterhub.app import JupyterHub from traitlets.config import LazyConfigValue +from jhub_apps.config_utils import JAppsConfig from jhub_apps.hub_client.hub_client import HubClient from jhub_apps.service.models import UserOptions from jhub_apps.spawner.types import FrameworkConf, FRAMEWORKS_MAPPING, FRAMEWORKS @@ -28,6 +29,8 @@ def get_jupyterhub_config(): jhub_config_file_path = os.environ["JHUB_JUPYTERHUB_CONFIG"] logger.info(f"Getting JHub config from file: {jhub_config_file_path}") hub.load_config_file(jhub_config_file_path) + # hacky, but I couldn't figure out how to get validation of the config otherwise (In this case, validation converts the dict in the config to a Pydantic model) + hub.config.JAppsConfig.startup_apps = JAppsConfig(config=hub.config).startup_apps config = hub.config logger.info(f"JHub Apps config: {config.JAppsConfig}") return config @@ -102,7 +105,7 @@ async def get_spawner_profiles(config, auth_state=None): ) -def encode_file_to_data_url(filename, file_contents): +def encode_file_to_data_url(filename, file_contents) -> str: """Converts image file to data url to display in browser.""" base64_encoded = base64.b64encode(file_contents) filename_ = filename.lower() @@ -117,7 +120,7 @@ def encode_file_to_data_url(filename, file_contents): return data_url -def get_default_thumbnail(framework_name): +def get_default_thumbnail(framework_name) -> str: framework: FrameworkConf = FRAMEWORKS_MAPPING.get(framework_name) thumbnail_path = framework.logo_path return encode_file_to_data_url( diff --git a/jupyterhub_config.py b/jupyterhub_config.py index 700ebdea..17b43dde 100644 --- a/jupyterhub_config.py +++ b/jupyterhub_config.py @@ -17,6 +17,30 @@ c.JAppsConfig.jupyterhub_config_path = "jupyterhub_config.py" c.JAppsConfig.conda_envs = [] c.JAppsConfig.service_workers = 1 +c.JAppsConfig.my_int_list = [1] +c.JAppsConfig.startup_apps = [ + { + "username": "alice", + # TODO: Add a test case when servername is different from noralized_servername + "servername": "alice's-startup-server", + "user_options": { + "display_name": "Alice's Panel App", + "description": "description", + "thumbnail": "data:image/jpeg;base64,iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNkYPhfz0AEYBxVSF+FAP5FDvcfRYWgAAAAAElFTkSuQmCC", + "filepath": "", + "framework": "panel", + "custom_command": "", + "public": False, + "keep_alive": False, + "env": {"ENV_VAR_KEY_1": "ENV_VAR_KEY_1"}, + "repository": None, + "jhub_app": True, + "conda_env": "", + "profile": "", + "share_with": {"users": ["admin"], "groups": ["class-A"]}, + }, + } +] c.JupyterHub.default_url = "/hub/home" c = install_jhub_apps(