From 5ed34dea7a56a7d4c4fcf5a824d151155c7778c2 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Thu, 21 Sep 2023 11:03:06 +1000 Subject: [PATCH] add tests for linear/smooth template --- src/dvc_render/vega.py | 24 ++-- tests/test_vega.py | 259 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 272 insertions(+), 11 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index bbe6a9c..9e35dea 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -198,7 +198,7 @@ def _fill_optional_anchors( skip_anchors, optional_anchors, "group_by", ["rev"] ) self._fill_optional_anchor( - skip_anchors, optional_anchors, "pivot_field", "rev" + skip_anchors, optional_anchors, "pivot_field", "datum.rev" ) for anchor in optional_anchors: self.template.fill_anchor(anchor, {}) @@ -230,6 +230,8 @@ def _fill_optional_anchors( filenameOrField = varied_keys[0] domain = list(variations[filenameOrField]) + domain.sort() + stroke_dash_scale = self._set_optional_anchor_scale( optional_anchors, concat_field, "stroke_dash", domain ) @@ -269,25 +271,25 @@ def _collect_variations( for key in ["filename", "field"]: variations[key].add(defn.get(key, None)) - valuesMatchVariations = [] - lessValuesThanVariations = [] + values_match_variations = [] + less_values_than_variations = [] for filenameOrField, valueSet in variations.items(): num_values = len(valueSet) if num_values == 1: continue if num_values == len(y_defn): - valuesMatchVariations.append(filenameOrField) + values_match_variations.append(filenameOrField) continue - lessValuesThanVariations.append(filenameOrField) + less_values_than_variations.append(filenameOrField) - if valuesMatchVariations: - valuesMatchVariations.extend(lessValuesThanVariations) - valuesMatchVariations.sort(reverse=True) - return valuesMatchVariations, variations + if values_match_variations: + values_match_variations.extend(less_values_than_variations) + values_match_variations.sort(reverse=True) + return values_match_variations, variations - lessValuesThanVariations.sort(reverse=True) - return lessValuesThanVariations, variations + less_values_than_variations.sort(reverse=True) + return less_values_than_variations, variations def _fill_optional_anchor( self, diff --git a/tests/test_vega.py b/tests/test_vega.py index 65c2ff0..6a4502f 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -225,3 +225,262 @@ def test_escape_special_characters(): assert filled["encoding"]["x"]["title"] == "foo.bar[0]" assert filled["encoding"]["y"]["field"] == "foo\\.bar\\[1\\]" assert filled["encoding"]["y"]["title"] == "foo.bar[1]" + + +@pytest.mark.parametrize( + ",".join( + [ + "datapoints", + "y", + "anchors_y_defn", + "expected_dp_keys", + "color_encoding", + "stroke_dash_encoding", + "pivot_field", + "group_by", + ] + ), + ( + ( + [ + { + "rev": "B", + "acc": "0.05", + "filename": "test", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "acc": "0.1", + "filename": "test", + "field": "acc", + "step": 2, + }, + ], + "acc", + [{"filename": "test", "field": "acc"}], + ["rev", "acc", "step"], + { + "field": "rev", + "scale": {"domain": ["B"], "range": ["#945dd6"]}, + }, + {}, + "datum.rev", + ["rev"], + ), + ( + [ + { + "rev": "B", + "acc": "0.05", + "filename": "test", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "acc": "0.1", + "filename": "test", + "field": "acc", + "step": 2, + }, + { + "rev": "B", + "acc": "0.04", + "filename": "train", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "acc": "0.09", + "filename": "train", + "field": "acc", + "step": 2, + }, + ], + "acc", + [ + {"filename": "test", "field": "acc"}, + {"filename": "train", "field": "acc"}, + ], + ["rev", "acc", "step", "filename"], + { + "field": "rev", + "scale": {"domain": ["B"], "range": ["#945dd6"]}, + }, + { + "field": "filename", + "scale": {"domain": ["test", "train"], "range": [[1, 0], [8, 8]]}, + "legend": { + "symbolFillColor": "transparent", + "symbolStrokeColor": "grey", + }, + }, + "datum.rev + '::' + datum.filename", + ["rev", "filename"], + ), + ( + [ + { + "rev": "B", + "dvc_inferred_y_value": "0.05", + "filename": "test", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.1", + "filename": "test", + "field": "acc", + "step": 2, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.04", + "filename": "train", + "field": "acc_norm", + "step": 1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.09", + "filename": "test", + "field": "acc_norm", + "step": 2, + }, + ], + "dvc_inferred_y_value", + [ + {"filename": "test", "field": "acc"}, + {"filename": "test", "field": "acc_norm"}, + ], + ["rev", "dvc_inferred_y_value", "step", "field"], + { + "field": "rev", + "scale": {"domain": ["B"], "range": ["#945dd6"]}, + }, + { + "field": "field", + "scale": {"domain": ["acc", "acc_norm"], "range": [[1, 0], [8, 8]]}, + "legend": { + "symbolFillColor": "transparent", + "symbolStrokeColor": "grey", + }, + }, + "datum.rev + '::' + datum.field", + ["rev", "field"], + ), + ( + [ + { + "rev": "B", + "dvc_inferred_y_value": "0.05", + "filename": "test", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.1", + "filename": "test", + "field": "acc", + "step": 2, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.04", + "filename": "train", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.09", + "filename": "train", + "field": "acc", + "step": 2, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.02", + "filename": "test", + "field": "acc_norm", + "step": 1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.07", + "filename": "test", + "field": "acc_norm", + "step": 2, + }, + ], + "dvc_inferred_y_value", + [ + {"filename": "test", "field": "acc_norm"}, + {"filename": "test", "field": "acc"}, + {"filename": "train", "field": "acc"}, + ], + ["rev", "dvc_inferred_y_value", "step", "filename::field"], + { + "field": "rev", + "scale": {"domain": ["B"], "range": ["#945dd6"]}, + }, + { + "field": "filename::field", + "scale": { + "domain": ["test::acc", "test::acc_norm", "train::acc"], + "range": [[1, 0], [8, 8], [8, 4]], + }, + "legend": { + "symbolFillColor": "transparent", + "symbolStrokeColor": "grey", + }, + }, + "datum.rev + '::' + datum.filename + '::' + datum.field", + ["rev", "filename", "field"], + ), + ), +) +def test_optional_anchors_linear( + datapoints, + y, + anchors_y_defn, + expected_dp_keys, + color_encoding, + stroke_dash_encoding, + pivot_field, + group_by, +): # pylint: disable=too-many-arguments + props = { + "template": "linear", + "x": "step", + "y": y, + "anchor_revs": ["B"], + "anchors_y_defn": anchors_y_defn, + } + + expected_datapoints = [] + for datapoint in datapoints: + expected_datapoint = {} + for key in expected_dp_keys: + if key == "filename::field": + expected_datapoint[ + key + ] = f"{datapoint['filename']}::{datapoint['field']}" + else: + expected_datapoint[key] = datapoint.get(key) + expected_datapoints.append(expected_datapoint) + + plot_content = VegaRenderer(datapoints, "foo", **props).get_filled_template( + as_string=False + ) + + assert plot_content["data"]["values"] == expected_datapoints + assert plot_content["encoding"]["color"] == color_encoding + assert plot_content["encoding"]["strokeDash"] == stroke_dash_encoding + assert plot_content["layer"][3]["transform"][0]["calculate"] == pivot_field + assert plot_content["layer"][0]["transform"][0]["groupby"] == group_by