Skip to content

Commit

Permalink
Refactor model_registry and step_registry tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ovejabu committed Jun 5, 2024
1 parent f0da701 commit 44ec38d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 12 deletions.
6 changes: 3 additions & 3 deletions tests/core/test_model_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def test_get_all_model_classes():
assert AnotherDummyModel in all_models.values()


@patch("pipeline_lib.core.model_registry.pkgutil.walk_packages")
@patch("pipeline_lib.core.model_registry.importlib.import_module")
@patch("ml_garden.core.model_registry.pkgutil.walk_packages")
@patch("ml_garden.core.model_registry.importlib.import_module")
def test_auto_register_models_from_package(mock_import_module, mock_walk_packages):
mock_package = MagicMock()
mock_package.__name__ = "package"
Expand Down Expand Up @@ -79,7 +79,7 @@ def test_auto_register_models_from_package(mock_import_module, mock_walk_package
assert AnotherDummyModel in registry.get_all_model_classes().values()


@patch("pipeline_lib.core.model_registry.importlib.import_module")
@patch("ml_garden.core.model_registry.importlib.import_module")
def test_auto_register_models_import_error(mock_import_module):
mock_import_module.side_effect = ImportError

Expand Down
18 changes: 9 additions & 9 deletions tests/core/test_step_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def test_get_all_step_classes():
assert AnotherDummyStep in all_steps.values()


@patch("pipeline_lib.core.step_registry.pkgutil.walk_packages")
@patch("pipeline_lib.core.step_registry.importlib.import_module")
@patch("ml_garden.core.step_registry.pkgutil.walk_packages")
@patch("ml_garden.core.step_registry.importlib.import_module")
def test_auto_register_steps_from_package(mock_import_module, mock_walk_packages):
mock_package = MagicMock()
mock_package.__name__ = "package"
Expand Down Expand Up @@ -79,7 +79,7 @@ def test_auto_register_steps_from_package(mock_import_module, mock_walk_packages
assert AnotherDummyStep in registry.get_all_step_classes().values()


@patch("pipeline_lib.core.step_registry.importlib.import_module")
@patch("ml_garden.core.step_registry.importlib.import_module")
def test_auto_register_steps_import_error(mock_import_module):
mock_import_module.side_effect = ImportError

Expand All @@ -89,9 +89,9 @@ def test_auto_register_steps_import_error(mock_import_module):
assert len(registry.get_all_step_classes()) == 0


@patch("pipeline_lib.core.step_registry.os.listdir")
@patch("pipeline_lib.core.step_registry.importlib.util.spec_from_file_location")
@patch("pipeline_lib.core.step_registry.importlib.util.module_from_spec")
@patch("ml_garden.core.step_registry.os.listdir")
@patch("ml_garden.core.step_registry.importlib.util.spec_from_file_location")
@patch("ml_garden.core.step_registry.importlib.util.module_from_spec")
def test_load_and_register_custom_steps(
mock_module_from_spec, mock_spec_from_file_location, mock_listdir
):
Expand All @@ -111,9 +111,9 @@ def test_load_and_register_custom_steps(
assert DummyStep in registry.get_all_step_classes().values()


@patch("pipeline_lib.core.step_registry.os.listdir")
@patch("pipeline_lib.core.step_registry.importlib.util.spec_from_file_location")
@patch("pipeline_lib.core.step_registry.importlib.util.module_from_spec")
@patch("ml_garden.core.step_registry.os.listdir")
@patch("ml_garden.core.step_registry.importlib.util.spec_from_file_location")
@patch("ml_garden.core.step_registry.importlib.util.module_from_spec")
def test_load_and_register_custom_steps_exception(
mock_module_from_spec, mock_spec_from_file_location, mock_listdir
):
Expand Down

0 comments on commit 44ec38d

Please sign in to comment.