From 38d81c5fcdab9e4450d1996b8a1ecaa5b34cca80 Mon Sep 17 00:00:00 2001 From: Diego Marvid Date: Mon, 15 Jul 2024 16:54:40 -0300 Subject: [PATCH] update fit_model for new prediction method with X_prediction --- ml_garden/core/model_registry.py | 78 +++++++++++++++++-- ml_garden/core/steps/fit_model.py | 1 + .../implementation/tabular/autogluon/model.py | 3 +- 3 files changed, 75 insertions(+), 7 deletions(-) diff --git a/ml_garden/core/model_registry.py b/ml_garden/core/model_registry.py index 409f268..cb51eb9 100644 --- a/ml_garden/core/model_registry.py +++ b/ml_garden/core/model_registry.py @@ -1,26 +1,71 @@ import importlib import logging import pkgutil +from typing import Dict, Type from ml_garden.core.model import Model class ModelClassNotFoundError(Exception): + """Exception raised when a model class is not found in the registry.""" + pass class ModelRegistry: def __init__(self): - self._model_registry = {} + """ + Initialize a new ModelRegistry instance. + + Attributes + ---------- + _model_registry : dict + A dictionary mapping model names to model classes. + logger : logging.Logger + Logger for the class. + """ + self._model_registry: Dict[str, Type[Model]] = {} self.logger = logging.getLogger(__name__) - def register_model(self, model_class: type): - model_name = model_class.__name__ + def register_model(self, model_class: Type[Model]) -> None: + """ + Register a model class in the registry. + + Parameters + ---------- + model_class : Type[Model] + The model class to be registered. + + Raises + ------ + ValueError + If the model_class is not a subclass of Model. + """ + model_name = model_class.__name__.lower() if not issubclass(model_class, Model): raise ValueError(f"{model_class} must be a subclass of Model") self._model_registry[model_name] = model_class - def get_model_class(self, model_name: str) -> type: + def get_model_class(self, model_name: str) -> Type[Model]: + """ + Retrieve a model class from the registry. + + Parameters + ---------- + model_name : str + The name of the model class to retrieve. + + Returns + ------- + Type[Model] + The model class. + + Raises + ------ + ModelClassNotFoundError + If the model class is not found in the registry. + """ + model_name = model_name.lower() if model_name in self._model_registry: return self._model_registry[model_name] else: @@ -29,10 +74,31 @@ def get_model_class(self, model_name: str) -> type: f" {list(self._model_registry.keys())}" ) - def get_all_model_classes(self) -> dict: + def get_all_model_classes(self) -> Dict[str, Type[Model]]: + """ + Get all registered model classes. + + Returns + ------- + dict + A dictionary of all registered model classes. + """ return self._model_registry - def auto_register_models_from_package(self, package_name: str): + def auto_register_models_from_package(self, package_name: str) -> None: + """ + Automatically register all model classes from a given package. + + Parameters + ---------- + package_name : str + The name of the package to search for model classes. + + Raises + ------ + ImportError + If the package cannot be imported. + """ try: package = importlib.import_module(package_name) prefix = package.__name__ + "." diff --git a/ml_garden/core/steps/fit_model.py b/ml_garden/core/steps/fit_model.py index f57de08..883c131 100644 --- a/ml_garden/core/steps/fit_model.py +++ b/ml_garden/core/steps/fit_model.py @@ -298,6 +298,7 @@ def predict(self, data: DataContainer) -> DataContainer: The updated data container """ self.logger.info(f"Predicting with {self.model_class.__name__} model") + data.X_prediction = data.flow.drop(columns=data.columns_to_ignore_for_training) data.flow[data.prediction_column] = data.model.predict(data.X_prediction) data.predictions = data.flow[data.prediction_column] return data diff --git a/ml_garden/implementation/tabular/autogluon/model.py b/ml_garden/implementation/tabular/autogluon/model.py index bb58450..3164acd 100644 --- a/ml_garden/implementation/tabular/autogluon/model.py +++ b/ml_garden/implementation/tabular/autogluon/model.py @@ -3,13 +3,14 @@ import pandas as pd from autogluon.tabular import TabularPredictor +from ml_garden.core.constants import Task from ml_garden.core.model import Model logger = logging.getLogger(__file__) class AutoGluon(Model): - TASKS = ["regression", "classification"] + TASKS = [Task.REGRESSION, Task.CLASSIFICATION] def __init__(self, **params): self.params = params