Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
mattseddon committed Sep 21, 2023
1 parent 5ed34de commit 38d7977
Show file tree
Hide file tree
Showing 2 changed files with 287 additions and 86 deletions.
210 changes: 133 additions & 77 deletions src/dvc_render/vega.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from .base import Renderer
from .utils import list_dict_to_dict_list
from .vega_templates import BadTemplateError, LinearTemplate, get_template
from .vega_templates import BadTemplateError, LinearTemplate, Template, get_template


class VegaRenderer(Renderer):
Expand Down Expand Up @@ -58,14 +58,12 @@ def __init__(self, datapoints: List, name: str, **properties):
],
"shape": ["square", "circle", "triangle", "diamond"],
}
self._optional_anchor_values: Dict[
str,
Dict[str, Dict[str, str]],
] = defaultdict(dict)

self._split_content: Dict[str, Any] = {}

def get_filled_template(
self,
skip_anchors: Optional[List[str]] = None,
split_anchors: Optional[List[str]] = None,
strict: bool = True,
as_string: bool = True,
) -> Union[str, Dict[str, Any]]:
Expand All @@ -74,8 +72,8 @@ def get_filled_template(
if not self.datapoints:
return {}

if skip_anchors is None:
skip_anchors = []
if split_anchors is None:
split_anchors = []

if strict:
if self.properties.get("x"):
Expand All @@ -91,15 +89,18 @@ def get_filled_template(
self.properties.setdefault("y_label", self.properties.get("y"))
self.properties.setdefault("data", self.datapoints)

self._process_optional_anchors(skip_anchors)
self._process_optional_anchors(split_anchors)

names = ["title", "x", "y", "x_label", "y_label", "data"]
for name in names:
if name in skip_anchors:
continue
value = self.properties.get(name)
if value is None:
continue

if name in split_anchors:
self._set_split_content(name, value)
continue

if name == "data":
if not self.template.has_anchor(name):
anchor = self.template.anchor(name)
Expand All @@ -116,6 +117,15 @@ def get_filled_template(

return self.template.content

def get_partial_filled_template(self):
    """
    Fill the template while splitting selected anchors out of it.

    The "data", "color", "stroke_dash" and "shape" anchors are passed to
    ``get_filled_template`` as ``split_anchors`` (with ``strict=True``).

    Returns a ``(content, split_content)`` tuple: the value returned by
    ``get_filled_template`` and the renderer's ``_split_content`` mapping.
    """
    anchors_to_split = ["data", "color", "stroke_dash", "shape"]
    filled = self.get_filled_template(split_anchors=anchors_to_split, strict=True)
    return filled, self._split_content

def partial_html(self, **kwargs) -> str:
    """
    Return the template filled by ``get_filled_template`` using its defaults.

    ``kwargs`` are accepted but not used by this implementation.
    """
    filled = self.get_filled_template()
    return filled  # type: ignore

Expand Down Expand Up @@ -164,7 +174,7 @@ def generate_markdown(self, report_path=None) -> str:

return ""

def _process_optional_anchors(self, skip_anchors: List[str]):
def _process_optional_anchors(self, split_anchors: List[str]):
optional_anchors = [
anchor
for anchor in [
Expand All @@ -177,79 +187,85 @@ def _process_optional_anchors(self, skip_anchors: List[str]):
]
if self.template.has_anchor(anchor)
]
if optional_anchors:
# split varied_keys out from _fill_optional_anchors to avoid bugs
# but first.... tests
varied_keys = self._fill_optional_anchors(skip_anchors, optional_anchors)
self._update_datapoints(varied_keys)

def _fill_optional_anchors(
self, skip_anchors: List[str], optional_anchors: List[str]
) -> List[str]:
self._fill_color(skip_anchors, optional_anchors)

if not optional_anchors:
return []
return

y_defn = self.properties.get("anchors_y_defn", [])
is_single_source = len(y_defn) <= 1

if len(y_defn) <= 1:
self._fill_optional_anchor(
skip_anchors, optional_anchors, "group_by", ["rev"]
)
self._fill_optional_anchor(
skip_anchors, optional_anchors, "pivot_field", "datum.rev"
)
for anchor in optional_anchors:
self.template.fill_anchor(anchor, {})
return []
if is_single_source:
self._process_single_source_plot(split_anchors, optional_anchors)
return

self._process_multi_source_plot(split_anchors, optional_anchors, y_defn)

def _process_single_source_plot(
self, split_anchors: List[str], optional_anchors: List[str]
):
self._fill_color(split_anchors, optional_anchors)
self._fill_optional_anchor(split_anchors, optional_anchors, "group_by", ["rev"])
self._fill_optional_anchor(
split_anchors, optional_anchors, "pivot_field", "datum.rev"
)
for anchor in optional_anchors:
self.template.fill_anchor(anchor, {})

self._update_datapoints([])

def _process_multi_source_plot(
self,
split_anchors: List[str],
optional_anchors: List[str],
y_defn: List[Dict[str, str]],
):
varied_keys, varied_values = self._collect_variations(y_defn)
domain = self._get_domain(varied_keys, varied_values, y_defn)

self._fill_optional_multi_source_anchors(
split_anchors, optional_anchors, varied_keys, domain
)
self._update_datapoints(varied_keys)

def _fill_optional_multi_source_anchors(
self,
split_anchors: List[str],
optional_anchors: List[str],
varied_keys: List[str],
domain: List[str],
):
self._fill_color(split_anchors, optional_anchors)

if not optional_anchors:
return

varied_keys, variations = self._collect_variations(y_defn)
grouped_keys = ["rev", *varied_keys]
concat_field = "::".join(varied_keys)
self._fill_optional_anchor(
skip_anchors, optional_anchors, "group_by", grouped_keys
split_anchors, optional_anchors, "group_by", grouped_keys
)
self._fill_optional_anchor(
skip_anchors,
split_anchors,
optional_anchors,
"pivot_field",
" + '::' + ".join([f"datum.{key}" for key in grouped_keys]),
)
# concatenate grouped_keys together
self._fill_optional_anchor(
skip_anchors, optional_anchors, "row", {"field": concat_field}
)

if not optional_anchors:
return varied_keys

if len(varied_keys) == 2:
domain = ["::".join([d.get("filename"), d.get("field")]) for d in y_defn]
else:
filenameOrField = varied_keys[0]
domain = list(variations[filenameOrField])

domain.sort()

stroke_dash_scale = self._set_optional_anchor_scale(
optional_anchors, concat_field, "stroke_dash", domain
)
concat_field = "::".join(varied_keys)
self._fill_optional_anchor(
skip_anchors, optional_anchors, "stroke_dash", stroke_dash_scale
split_anchors, optional_anchors, "row", {"field": concat_field}
)

shape_scale = self._set_optional_anchor_scale(
optional_anchors, concat_field, "shape", domain
)
self._fill_optional_anchor(skip_anchors, optional_anchors, "shape", shape_scale)
if not optional_anchors:
return

return varied_keys
for field in ["stroke_dash", "shape"]:
self._fill_optional_anchor_mapping(
split_anchors, optional_anchors, concat_field, field, domain
)

def _fill_color(self, skip_anchors: List[str], optional_anchors: List[str]):
def _fill_color(self, split_anchors: List[str], optional_anchors: List[str]):
all_revs = self.properties.get("anchor_revs", [])
self._fill_optional_anchor(
skip_anchors,
split_anchors,
optional_anchors,
"color",
{
Expand All @@ -266,15 +282,15 @@ def _fill_color(self, skip_anchors: List[str], optional_anchors: List[str]):
def _collect_variations(
self, y_defn: List[Dict[str, str]]
) -> Tuple[List[str], Dict[str, set]]:
variations = defaultdict(set)
varied_values = defaultdict(set)
for defn in y_defn:
for key in ["filename", "field"]:
variations[key].add(defn.get(key, None))
varied_values[key].add(defn.get(key, None))

values_match_variations = []
less_values_than_variations = []

for filenameOrField, valueSet in variations.items():
for filenameOrField, valueSet in varied_values.items():
num_values = len(valueSet)
if num_values == 1:
continue
Expand All @@ -286,14 +302,14 @@ def _collect_variations(
if values_match_variations:
values_match_variations.extend(less_values_than_variations)
values_match_variations.sort(reverse=True)
return values_match_variations, variations
return values_match_variations, varied_values

less_values_than_variations.sort(reverse=True)
return less_values_than_variations, variations
return less_values_than_variations, varied_values

def _fill_optional_anchor(
self,
skip_anchors: List[str],
split_anchors: List[str],
optional_anchors: List[str],
name: str,
value: Any,
Expand All @@ -303,26 +319,63 @@ def _fill_optional_anchor(

optional_anchors.remove(name)

if name in skip_anchors:
if name in split_anchors:
return

self.template.fill_anchor(name, value)

def _set_optional_anchor_scale(
self, optional_anchors: List[str], field: str, name: str, domain: List[str]
def _get_domain(
self,
varied_keys: List[str],
varied_values: Dict[str, set],
y_defn: List[Dict[str, str]],
):
if len(varied_keys) == 2:
domain = [
"::".join([d.get("filename", ""), d.get("field", "")]) for d in y_defn
]
else:
filenameOrField = varied_keys[0]
domain = list(varied_values[filenameOrField])

domain.sort()
return domain

def _fill_optional_anchor_mapping(
self,
split_anchors: List[str],
optional_anchors: List[str],
field: str,
name: str,
domain: List[str],
):
if name not in optional_anchors:
return {"field": field, "scale": {"domain": [], "range": []}}
return

optional_anchors.remove(name)

encoding = self._get_optional_anchor_mapping(field, name, domain)

if name in split_anchors:
self._set_split_content(name, encoding)
return

self.template.fill_anchor(name, encoding)

def _get_optional_anchor_mapping(
self,
field: str,
name: str,
domain: List[str],
):
full_range_values: List[Any] = self._optional_anchor_ranges.get(name, [])
anchor_range_values = full_range_values.copy()
anchor_range = []

for domain_value in domain:
anchor_range = []
for _ in range(len(domain)):
if not anchor_range_values:
anchor_range_values = full_range_values.copy()
range_value = anchor_range_values.pop(0)
self._optional_anchor_values[name][domain_value] = range_value
anchor_range.append(range_value)

return {
Expand All @@ -347,3 +400,6 @@ def _update_datapoints(self, varied_keys: List[str]):
)
for key in to_remove:
datapoint.pop(key, None)

def _set_split_content(self, name: str, value: Any):
    """
    Record ``value`` as split-out content for the anchor called ``name``.

    The entry is stored in ``self._split_content`` under the key returned
    by ``Template.anchor(name)``.
    """
    anchor_key = Template.anchor(name)
    self._split_content[anchor_key] = value
Loading

0 comments on commit 38d7977

Please sign in to comment.