diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 80acd024..5186f083 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,13 +29,13 @@ repos: .*\.ipynb )$ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: "v0.4.8" + rev: "v0.5.2" hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] - id: ruff-format - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks - rev: v2.13.0 + rev: v2.14.0 hooks: - id: pretty-format-toml args: [--autofix, --no-sort] diff --git a/pyproject.toml b/pyproject.toml index 9c9b671c..7a0dce0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,11 +55,13 @@ tests = [ "dvclive[image,plots,markdown]", "ipython", "pytest_voluptuous", - "dpath" + "dpath", + "transformers[torch]", + "tf-keras" ] dev = [ "dvclive[all,tests]", - "mypy==1.10.0", + "mypy==1.11.0", "types-PyYAML" ] mmcv = ["mmcv"] @@ -67,12 +69,11 @@ tf = ["tensorflow"] xgb = ["xgboost"] lgbm = ["lightgbm"] huggingface = ["transformers", "datasets"] -catalyst = ["catalyst>22"] fastai = ["fastai"] lightning = ["lightning>=2.0", "torch", "jsonargparse[signatures]>=4.26.1"] optuna = ["optuna"] all = [ - "dvclive[image,mmcv,tf,xgb,lgbm,huggingface,catalyst,fastai,lightning,optuna,plots,markdown]" + "dvclive[image,mmcv,tf,xgb,lgbm,huggingface,fastai,lightning,optuna,plots,markdown]" ] [project.urls] @@ -139,7 +140,7 @@ files = ["src", "tests"] ignore-words-list = "fpr" [tool.ruff.lint] -ignore = ["N818", "UP006", "UP007", "UP035", "UP038", "B905", "PGH003"] +ignore = ["N818", "UP006", "UP007", "UP035", "UP038", "B905", "PGH003", "SIM103"] select = ["F", "E", "W", "C90", "N", "UP", "YTT", "S", "BLE", "B", "A", "C4", "T10", "EXE", "ISC", "INP", "PIE", "T20", "PT", "Q", "RSE", "RET", "SLF", "SIM", "TID", "TCH", "INT", "ARG", "PGH", "PL", "TRY", "NPY", "RUF"] [tool.ruff.lint.per-file-ignores] diff --git a/src/dvclive/catalyst.py b/src/dvclive/catalyst.py deleted file mode 100644 index 532d31b3..00000000 --- 
a/src/dvclive/catalyst.py +++ /dev/null @@ -1,23 +0,0 @@ -# ruff: noqa: ARG002 -from typing import Optional - -from catalyst.core.callback import Callback, CallbackOrder - -from dvclive import Live - - -class DVCLiveCallback(Callback): - def __init__(self, live: Optional[Live] = None, **kwargs): - super().__init__(order=CallbackOrder.external) - self.live = live if live is not None else Live(**kwargs) - - def on_epoch_end(self, runner) -> None: - for loader_key, per_loader_metrics in runner.epoch_metrics.items(): - for key, value in per_loader_metrics.items(): - self.live.log_metric( - f"{loader_key}/{key.replace('/', '_')}", float(value) - ) - self.live.next_step() - - def on_experiment_end(self, runner): - self.live.end() diff --git a/src/dvclive/fabric.py b/src/dvclive/fabric.py index 1b99b097..0e0903df 100644 --- a/src/dvclive/fabric.py +++ b/src/dvclive/fabric.py @@ -132,7 +132,9 @@ def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: } # logging of argparse.Namespace is not supported, sanitize as string - params = {k: str(v) if type(v) == Namespace else v for k, v in params.items()} + params = { + k: str(v) if isinstance(v, Namespace) else v for k, v in params.items() + } return params # noqa: RET504 diff --git a/src/dvclive/live.py b/src/dvclive/live.py index 8f49fc99..c0b4aa81 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -562,7 +562,7 @@ def log_plot( name: str, datapoints: Union[pd.DataFrame, np.ndarray, List[Dict]], x: str, - y: str, + y: Union[str, list[str]], template: Optional[str] = "linear", title: Optional[str] = None, x_label: Optional[str] = None, @@ -579,7 +579,8 @@ def log_plot( datapoints (pd.DataFrame | np.ndarray | List[Dict]): Pandas DataFrame, Numpy Array or List of dictionaries containing the data for the plot. x (str): name of the key (present in the dictionaries) to use as the x axis. - y (str): name of the key (present in the dictionaries) to use the y axis. 
+ y (str | list[str]): name of the key or keys (present in the + dictionaries) to use as the y axis. template (str): name of the `DVC plots template` to use. Defaults to `"linear"`. title (str): title to be displayed. Defaults to diff --git a/src/dvclive/plots/custom.py b/src/dvclive/plots/custom.py index f7e55563..0ea15272 100644 --- a/src/dvclive/plots/custom.py +++ b/src/dvclive/plots/custom.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Optional +from typing import Optional, Union from dvclive.serialize import dump_json @@ -15,7 +15,7 @@ def __init__( name: str, output_folder: str, x: str, - y: str, + y: Union[str, list[str]], template: Optional[str], title: Optional[str] = None, x_label: Optional[str] = None, diff --git a/src/dvclive/plots/utils.py b/src/dvclive/plots/utils.py index 80675dc9..ed9db22c 100644 --- a/src/dvclive/plots/utils.py +++ b/src/dvclive/plots/utils.py @@ -1,7 +1,6 @@ import json NUMPY_INTS = [ - "int_", "intc", "intp", "int8", @@ -13,7 +12,7 @@ "uint32", "uint64", ] -NUMPY_FLOATS = ["float_", "float16", "float32", "float64"] +NUMPY_FLOATS = ["float16", "float32", "float64"] NUMPY_SCALARS = NUMPY_INTS + NUMPY_FLOATS diff --git a/tests/frameworks/test_catalyst.py b/tests/frameworks/test_catalyst.py deleted file mode 100644 index bb195986..00000000 --- a/tests/frameworks/test_catalyst.py +++ /dev/null @@ -1,92 +0,0 @@ -# ruff: noqa: N806 -import os - -import pytest - -from dvclive import Live -from dvclive.plots import Metric - -try: - import catalyst - import torch - from catalyst import dl - - from dvclive.catalyst import DVCLiveCallback -except ImportError: - pytest.skip("skipping catalyst tests", allow_module_level=True) - - -@pytest.fixture() -def runner(): - return dl.SupervisedRunner( - engine=catalyst.utils.torch.get_available_engine(cpu=True), - input_key="features", - output_key="logits", - target_key="targets", - loss_key="loss", - ) - - -# see: -# 
https://github.com/catalyst-team/catalyst/blob/e99f9/tests/catalyst/callbacks/test_batch_overfit.py -@pytest.fixture() -def runner_params(): - from torch.utils.data import DataLoader, TensorDataset - - catalyst.utils.set_global_seed(42) - num_samples, num_features = int(32e1), int(1e1) - X, y = torch.rand(num_samples, num_features), torch.rand(num_samples) - dataset = TensorDataset(X, y) - loader = DataLoader(dataset, batch_size=32, num_workers=0) - loaders = {"train": loader, "valid": loader} - - model = torch.nn.Linear(num_features, 1) - criterion = torch.nn.MSELoss() - optimizer = torch.optim.Adam(model.parameters()) - scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [3, 6]) - return { - "model": model, - "criterion": criterion, - "optimizer": optimizer, - "scheduler": scheduler, - "loaders": loaders, - } - - -def test_catalyst_callback(tmp_dir, runner, runner_params, mocker): - callback = DVCLiveCallback() - live = callback.live - spy = mocker.spy(live, "end") - - runner.train( - **runner_params, - num_epochs=2, - callbacks=[ - dl.AccuracyCallback(input_key="logits", target_key="targets"), - callback, - ], - logdir="./logs", - valid_loader="valid", - valid_metric="loss", - minimize_valid_metric=True, - verbose=True, - load_best_on_end=True, - ) - spy.assert_called_once() - - assert os.path.exists("dvclive") - - train_path = tmp_dir / "dvclive" / "plots" / Metric.subfolder / "train" - valid_path = tmp_dir / "dvclive" / "plots" / Metric.subfolder / "valid" - - assert train_path.is_dir() - assert valid_path.is_dir() - assert any("accuracy" in x.name for x in train_path.iterdir()) - assert any("accuracy" in x.name for x in valid_path.iterdir()) - - -def test_catalyst_pass_logger(): - logger = Live("train_logs") - - assert DVCLiveCallback().live is not logger - assert DVCLiveCallback(live=logger).live is logger diff --git a/tests/frameworks/test_huggingface.py b/tests/frameworks/test_huggingface.py index ac824588..18b39901 100644 --- 
a/tests/frameworks/test_huggingface.py +++ b/tests/frameworks/test_huggingface.py @@ -103,6 +103,7 @@ def args(): num_train_epochs=2, save_strategy="epoch", report_to="none", # Disable auto-reporting to avoid duplication + use_cpu=True, ) @@ -173,7 +174,7 @@ def test_huggingface_log_model( trainer.train() expected_call_count = { - "all": 2, + "all": 3, True: 1, False: 0, None: 0, diff --git a/tests/plots/test_custom.py b/tests/plots/test_custom.py index 17b18f2b..349c726a 100644 --- a/tests/plots/test_custom.py +++ b/tests/plots/test_custom.py @@ -29,3 +29,30 @@ def test_log_custom_plot(tmp_dir): "x_label": "x_label", "y_label": "y_label", } + + +def test_log_custom_plot_multi_y(tmp_dir): + live = Live() + out = tmp_dir / live.plots_dir / CustomPlot.subfolder + + datapoints = [{"x": 1, "y1": 2, "y2": 3}, {"x": 4, "y1": 5, "y2": 6}] + live.log_plot( + "custom_linear", + datapoints, + x="x", + y=["y1", "y2"], + template="linear", + title="custom_title", + x_label="x_label", + y_label="y_label", + ) + + assert json.loads((out / "custom_linear.json").read_text()) == datapoints + assert live._plots["custom_linear"].plot_config == { + "template": "linear", + "title": "custom_title", + "x": "x", + "y": ["y1", "y2"], + "x_label": "x_label", + "y_label": "y_label", + }