diff --git a/src/dvclive/error.py b/src/dvclive/error.py index 790e64c3..d2e040c2 100644 --- a/src/dvclive/error.py +++ b/src/dvclive/error.py @@ -29,9 +29,8 @@ def __init__(self, name): class InvalidParameterTypeError(DvcLiveError): - def __init__(self, val: Any): - self.val = val - super().__init__(f"Parameter type {type(val)} is not supported.") + def __init__(self, msg: Any): + super().__init__(msg) class InvalidReportModeError(DvcLiveError): diff --git a/src/dvclive/live.py b/src/dvclive/live.py index d1971601..1c9b725c 100644 --- a/src/dvclive/live.py +++ b/src/dvclive/live.py @@ -454,7 +454,7 @@ def _dump_params(self): try: dump_yaml(self._params, self.params_file) except RepresenterError as exc: - raise InvalidParameterTypeError(exc.args) from exc + raise InvalidParameterTypeError(exc.args[0]) from exc def log_params(self, params: Dict[str, ParamLike]): """Saves the given set of parameters (dict) to yaml""" diff --git a/tests/test_log_param.py b/tests/test_log_param.py index bb008146..3c4e5ac6 100644 --- a/tests/test_log_param.py +++ b/tests/test_log_param.py @@ -84,5 +84,6 @@ class Dummy: param_value = Dummy() - with pytest.raises(InvalidParameterTypeError): + with pytest.raises(InvalidParameterTypeError) as excinfo: dvclive.log_param("param_complex", param_value) + assert "Dummy" in excinfo.value.args[0]