Skip to content

Commit

Permalink
post data to studio on end (#738)
Browse files Browse the repository at this point in the history
* post data to studio on end

* fix test for dvc studio config
  • Loading branch information
Dave Berenbaum authored Nov 14, 2023
1 parent d2f861e commit 7d5c088
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 6 deletions.
3 changes: 3 additions & 0 deletions src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,9 @@ def end(self):

self.save_dvc_exp()

# Post any data that hasn't been sent
self.post_to_studio("data")
# Mark experiment as done
self.post_to_studio("done")

cleanup_dvclive_step_completed()
Expand Down
38 changes: 32 additions & 6 deletions tests/test_post_to_studio.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,24 @@ def test_post_to_studio_failed_start_request(
assert mocked_post.call_count == 1


def test_post_to_studio_end_only_once(tmp_dir, mocked_dvc_repo, mocked_studio_post):
def test_post_to_studio_done_only_once(tmp_dir, mocked_dvc_repo, mocked_studio_post):
mocked_post, _ = mocked_studio_post
with Live() as live:
live.log_metric("foo", 1)
live.next_step()

assert mocked_post.call_count == 4
expected_done_calls = [
call
for call in mocked_post.call_args_list
if call.kwargs["json"]["type"] == "done"
]
live.end()
assert mocked_post.call_count == 4
actual_done_calls = [
call
for call in mocked_post.call_args_list
if call.kwargs["json"]["type"] == "done"
]
assert expected_done_calls == actual_done_calls


@pytest.mark.studio()
Expand All @@ -157,7 +166,9 @@ def test_post_to_studio_skip_start_and_done_on_env_var(
live.log_metric("foo", 1)
live.next_step()

assert mocked_post.call_count == 2
call_types = [call.kwargs["json"]["type"] for call in mocked_post.call_args_list]
assert "start" not in call_types
assert "done" not in call_types


@pytest.mark.studio()
Expand All @@ -169,14 +180,15 @@ def test_post_to_studio_dvc_studio_config(
monkeypatch.setenv(DVC_EXP_BASELINE_REV, "f" * 40)
monkeypatch.setenv(DVC_EXP_NAME, "bar")
monkeypatch.setenv(DVC_ROOT, tmp_dir)
monkeypatch.delenv(DVC_STUDIO_TOKEN)

mocked_dvc_repo.config = {"studio": {"token": "token"}}

with Live() as live:
live.log_metric("foo", 1)
live.next_step()

assert mocked_post.call_count == 2
assert mocked_post.call_args.kwargs["headers"]["Authorization"] == "token token"


@pytest.mark.studio()
Expand Down Expand Up @@ -236,7 +248,9 @@ def test_post_to_studio_inside_dvc_exp(
live.log_metric("foo", 1)
live.next_step()

assert mocked_post.call_count == 2
call_types = [call.kwargs["json"]["type"] for call in mocked_post.call_args_list]
assert "start" not in call_types
assert "done" not in call_types


@pytest.mark.studio()
Expand Down Expand Up @@ -370,3 +384,15 @@ def test_post_to_studio_name(tmp_dir, mocked_dvc_repo, mocked_studio_post):
"https://0.0.0.0/api/live",
**get_studio_call("start", exp_name="custom-name"),
)


def test_post_to_studio_if_done_skipped(tmp_dir, mocked_dvc_repo, mocked_studio_post):
live = Live()
live._studio_events_to_skip.add("start")
live._studio_events_to_skip.add("done")
live.log_metric("foo", 1)
live.end()

mocked_post, _ = mocked_studio_post
call_types = [call.kwargs["json"]["type"] for call in mocked_post.call_args_list]
assert "data" in call_types

0 comments on commit 7d5c088

Please sign in to comment.