Skip to content

Commit

Permalink
update fit_model for new prediction method with X_prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
diegomarvid committed Jul 15, 2024
1 parent 792a669 commit 38d81c5
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 7 deletions.
78 changes: 72 additions & 6 deletions ml_garden/core/model_registry.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,71 @@
import importlib
import logging
import pkgutil
from typing import Dict, Type

from ml_garden.core.model import Model


class ModelClassNotFoundError(Exception):
    """Signal that a requested model class is absent from the registry.

    The class adds nothing to ``Exception``; it exists so callers can catch
    this specific failure by type.
    """


class ModelRegistry:
def __init__(self):
self._model_registry = {}
"""
Initialize a new ModelRegistry instance.
Attributes
----------
_model_registry : dict
A dictionary mapping model names to model classes.
logger : logging.Logger
Logger for the class.
"""
self._model_registry: Dict[str, Type[Model]] = {}
self.logger = logging.getLogger(__name__)

def register_model(self, model_class: type):
model_name = model_class.__name__
def register_model(self, model_class: Type[Model]) -> None:
"""
Register a model class in the registry.
Parameters
----------
model_class : Type[Model]
The model class to be registered.
Raises
------
ValueError
If the model_class is not a subclass of Model.
"""
model_name = model_class.__name__.lower()
if not issubclass(model_class, Model):
raise ValueError(f"{model_class} must be a subclass of Model")
self._model_registry[model_name] = model_class

def get_model_class(self, model_name: str) -> type:
def get_model_class(self, model_name: str) -> Type[Model]:
"""
Retrieve a model class from the registry.
Parameters
----------
model_name : str
The name of the model class to retrieve.
Returns
-------
Type[Model]
The model class.
Raises
------
ModelClassNotFoundError
If the model class is not found in the registry.
"""
model_name = model_name.lower()
if model_name in self._model_registry:
return self._model_registry[model_name]
else:
Expand All @@ -29,10 +74,31 @@ def get_model_class(self, model_name: str) -> type:
f" {list(self._model_registry.keys())}"
)

def get_all_model_classes(self) -> dict:
def get_all_model_classes(self) -> Dict[str, Type[Model]]:
"""
Get all registered model classes.
Returns
-------
dict
A dictionary of all registered model classes.
"""
return self._model_registry

def auto_register_models_from_package(self, package_name: str):
def auto_register_models_from_package(self, package_name: str) -> None:
"""
Automatically register all model classes from a given package.
Parameters
----------
package_name : str
The name of the package to search for model classes.
Raises
------
ImportError
If the package cannot be imported.
"""
try:
package = importlib.import_module(package_name)
prefix = package.__name__ + "."
Expand Down
1 change: 1 addition & 0 deletions ml_garden/core/steps/fit_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def predict(self, data: DataContainer) -> DataContainer:
The updated data container
"""
self.logger.info(f"Predicting with {self.model_class.__name__} model")
data.X_prediction = data.flow.drop(columns=data.columns_to_ignore_for_training)
data.flow[data.prediction_column] = data.model.predict(data.X_prediction)
data.predictions = data.flow[data.prediction_column]
return data
3 changes: 2 additions & 1 deletion ml_garden/implementation/tabular/autogluon/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import pandas as pd
from autogluon.tabular import TabularPredictor

from ml_garden.core.constants import Task
from ml_garden.core.model import Model

logger = logging.getLogger(__file__)


class AutoGluon(Model):
TASKS = ["regression", "classification"]
TASKS = [Task.REGRESSION, Task.CLASSIFICATION]

def __init__(self, **params):
self.params = params
Expand Down

0 comments on commit 38d81c5

Please sign in to comment.