Skip to content

Commit

Permalink
Merge branch 'main' into patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
Dave Berenbaum authored Aug 5, 2024
2 parents fa28c30 + 0f5715b commit ceeef21
Show file tree
Hide file tree
Showing 10 changed files with 46 additions and 130 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,13 @@ repos:
.*\.ipynb
)$
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.4.8"
rev: "v0.5.2"
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
rev: v2.13.0
rev: v2.14.0
hooks:
- id: pretty-format-toml
args: [--autofix, --no-sort]
Expand Down
11 changes: 6 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,24 +55,25 @@ tests = [
"dvclive[image,plots,markdown]",
"ipython",
"pytest_voluptuous",
"dpath"
"dpath",
"transformers[torch]",
"tf-keras"
]
dev = [
"dvclive[all,tests]",
"mypy==1.10.0",
"mypy==1.11.0",
"types-PyYAML"
]
mmcv = ["mmcv"]
tf = ["tensorflow"]
xgb = ["xgboost"]
lgbm = ["lightgbm"]
huggingface = ["transformers", "datasets"]
catalyst = ["catalyst>22"]
fastai = ["fastai"]
lightning = ["lightning>=2.0", "torch", "jsonargparse[signatures]>=4.26.1"]
optuna = ["optuna"]
all = [
"dvclive[image,mmcv,tf,xgb,lgbm,huggingface,catalyst,fastai,lightning,optuna,plots,markdown]"
"dvclive[image,mmcv,tf,xgb,lgbm,huggingface,fastai,lightning,optuna,plots,markdown]"
]

[project.urls]
Expand Down Expand Up @@ -139,7 +140,7 @@ files = ["src", "tests"]
ignore-words-list = "fpr"

[tool.ruff.lint]
ignore = ["N818", "UP006", "UP007", "UP035", "UP038", "B905", "PGH003"]
ignore = ["N818", "UP006", "UP007", "UP035", "UP038", "B905", "PGH003", "SIM103"]
select = ["F", "E", "W", "C90", "N", "UP", "YTT", "S", "BLE", "B", "A", "C4", "T10", "EXE", "ISC", "INP", "PIE", "T20", "PT", "Q", "RSE", "RET", "SLF", "SIM", "TID", "TCH", "INT", "ARG", "PGH", "PL", "TRY", "NPY", "RUF"]

[tool.ruff.lint.per-file-ignores]
Expand Down
23 changes: 0 additions & 23 deletions src/dvclive/catalyst.py

This file was deleted.

4 changes: 3 additions & 1 deletion src/dvclive/fabric.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
}

# logging of argparse.Namespace is not supported, sanitize as string
params = {k: str(v) if type(v) == Namespace else v for k, v in params.items()}
params = {
k: str(v) if isinstance(v, Namespace) else v for k, v in params.items()
}

return params # noqa: RET504

Expand Down
5 changes: 3 additions & 2 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def log_plot(
name: str,
datapoints: Union[pd.DataFrame, np.ndarray, List[Dict]],
x: str,
y: str,
y: Union[str, list[str]],
template: Optional[str] = "linear",
title: Optional[str] = None,
x_label: Optional[str] = None,
Expand All @@ -579,7 +579,8 @@ def log_plot(
datapoints (pd.DataFrame | np.ndarray | List[Dict]): Pandas DataFrame, Numpy
Array or List of dictionaries containing the data for the plot.
x (str): name of the key (present in the dictionaries) to use as the x axis.
y (str): name of the key (present in the dictionaries) to use the y axis.
y (str | list[str]): name of the key or keys (present in the
dictionaries) to use the y axis.
template (str): name of the `DVC plots template` to use. Defaults to
`"linear"`.
title (str): title to be displayed. Defaults to
Expand Down
4 changes: 2 additions & 2 deletions src/dvclive/plots/custom.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Optional
from typing import Optional, Union

from dvclive.serialize import dump_json

Expand All @@ -15,7 +15,7 @@ def __init__(
name: str,
output_folder: str,
x: str,
y: str,
y: Union[str, list[str]],
template: Optional[str],
title: Optional[str] = None,
x_label: Optional[str] = None,
Expand Down
3 changes: 1 addition & 2 deletions src/dvclive/plots/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import json

NUMPY_INTS = [
"int_",
"intc",
"intp",
"int8",
Expand All @@ -13,7 +12,7 @@
"uint32",
"uint64",
]
NUMPY_FLOATS = ["float_", "float16", "float32", "float64"]
NUMPY_FLOATS = ["float16", "float32", "float64"]
NUMPY_SCALARS = NUMPY_INTS + NUMPY_FLOATS


Expand Down
92 changes: 0 additions & 92 deletions tests/frameworks/test_catalyst.py

This file was deleted.

3 changes: 2 additions & 1 deletion tests/frameworks/test_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def args():
num_train_epochs=2,
save_strategy="epoch",
report_to="none", # Disable auto-reporting to avoid duplication
use_cpu=True,
)


Expand Down Expand Up @@ -173,7 +174,7 @@ def test_huggingface_log_model(
trainer.train()

expected_call_count = {
"all": 2,
"all": 3,
True: 1,
False: 0,
None: 0,
Expand Down
27 changes: 27 additions & 0 deletions tests/plots/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,30 @@ def test_log_custom_plot(tmp_dir):
"x_label": "x_label",
"y_label": "y_label",
}


def test_log_custom_plot_multi_y(tmp_dir):
live = Live()
out = tmp_dir / live.plots_dir / CustomPlot.subfolder

datapoints = [{"x": 1, "y1": 2, "y2": 3}, {"x": 4, "y1": 5, "y2": 6}]
live.log_plot(
"custom_linear",
datapoints,
x="x",
y=["y1", "y2"],
template="linear",
title="custom_title",
x_label="x_label",
y_label="y_label",
)

assert json.loads((out / "custom_linear.json").read_text()) == datapoints
assert live._plots["custom_linear"].plot_config == {
"template": "linear",
"title": "custom_title",
"x": "x",
"y": ["y1", "y2"],
"x_label": "x_label",
"y_label": "y_label",
}

0 comments on commit ceeef21

Please sign in to comment.