Skip to content

Commit

Permalink
Merge pull request #187 from scipp/visualize-modes
Browse files Browse the repository at this point in the history
Add `mode` argument to `visualize` for more compact data or task graph display
  • Loading branch information
SimonHeybrock authored Nov 1, 2024
2 parents 8f40579 + 1e8432b commit 0d4edca
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 23 deletions.
36 changes: 33 additions & 3 deletions src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
from collections.abc import Callable, Hashable, Iterable, Sequence
from itertools import chain
from types import UnionType
from typing import TYPE_CHECKING, Any, TypeVar, get_args, get_type_hints, overload
from typing import (
TYPE_CHECKING,
Any,
Literal,
TypeVar,
get_args,
get_type_hints,
overload,
)

from ._provider import Provider, ToProvider
from ._utils import key_name
Expand Down Expand Up @@ -91,7 +99,13 @@ def compute(self, tp: type | Iterable[type] | UnionType, **kwargs: Any) -> Any:
return self.get(tp, **kwargs).compute()

def visualize(
self, tp: type | Iterable[type] | None = None, **kwargs: Any
self,
tp: type | Iterable[type] | None = None,
compact: bool = False,
mode: Literal['data', 'task', 'both'] = 'data',
cluster_generics: bool = True,
cluster_color: str | None = '#f0f0ff',
**kwargs: Any,
) -> graphviz.Digraph:
"""
Return a graphviz Digraph object representing the graph for the given keys.
Expand All @@ -103,12 +117,28 @@ def visualize(
tp:
Type to visualize the graph for.
Can be a single type or an iterable of types.
compact:
If True, parameter-table-dependent branches are collapsed into a single copy
of the branch. Recommended for large graphs with long parameter tables.
mode:
If 'data', only data nodes are shown. If 'task', only task nodes and input
data nodes are shown. If 'both', all nodes are shown.
cluster_generics:
If True, generic products are grouped into clusters.
cluster_color:
Background color of clusters. If None, clusters are dotted.
kwargs:
Keyword arguments passed to :py:class:`graphviz.Digraph`.
"""
if tp is None:
tp = self.output_keys()
return self.get(tp, handler=HandleAsComputeTimeException()).visualize(**kwargs)
return self.get(tp, handler=HandleAsComputeTimeException()).visualize(
compact=compact,
mode=mode,
cluster_generics=cluster_generics,
cluster_color=cluster_color,
**kwargs,
)

def get(
self,
Expand Down
30 changes: 27 additions & 3 deletions src/sciline/task_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from collections.abc import Generator, Hashable, Sequence
from html import escape
from typing import Any, TypeVar
from typing import Any, Literal, TypeVar

from ._utils import key_name
from .scheduler import DaskScheduler, NaiveScheduler, Scheduler
Expand Down Expand Up @@ -126,18 +126,42 @@ def keys(self) -> Generator[Key, None, None]:
"""
yield from self._graph.keys()

def visualize(self, **kwargs: Any) -> graphviz.Digraph: # type: ignore[name-defined] # noqa: F821
def visualize(
self,
compact: bool = False,
mode: Literal['data', 'task', 'both'] = 'data',
cluster_generics: bool = True,
cluster_color: str | None = '#f0f0ff',
**kwargs: Any,
) -> graphviz.Digraph: # type: ignore[name-defined] # noqa: F821
"""
Return a graphviz Digraph object representing the graph.
Parameters
----------
compact:
If True, parameter-table-dependent branches are collapsed into a single copy
of the branch. Recommended for large graphs with long parameter tables.
mode:
If 'data', only data nodes are shown. If 'task', only task nodes and input
data nodes are shown. If 'both', all nodes are shown.
cluster_generics:
If True, generic products are grouped into clusters.
cluster_color:
Background color of clusters. If None, clusters are dotted.
kwargs:
Keyword arguments passed to :py:class:`graphviz.Digraph`.
"""
from .visualize import to_graphviz

return to_graphviz(self._graph, **kwargs)
return to_graphviz(
self._graph,
compact=compact,
mode=mode,
cluster_generics=cluster_generics,
cluster_color=cluster_color,
**kwargs,
)

def serialize(self) -> dict[str, Json]:
"""Serialize the graph to JSON.
Expand Down
112 changes: 95 additions & 17 deletions src/sciline/visualize.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# SPDX-License-Identifier: BSD-3-Clause
# Copyright (c) 2023 Scipp contributors (https://github.com/scipp)
import html
from collections.abc import Hashable
from dataclasses import dataclass
from typing import Any, get_args, get_origin
from typing import Any, Literal, get_args, get_origin

import cyclebane
from graphviz import Digraph
Expand Down Expand Up @@ -31,6 +32,7 @@ class FormattedProvider:
def to_graphviz(
graph: Graph,
compact: bool = False,
mode: Literal['data', 'task', 'both'] = 'data',
cluster_generics: bool = True,
cluster_color: str | None = '#f0f0ff',
**kwargs: Any,
Expand All @@ -45,6 +47,9 @@ def to_graphviz(
compact:
If True, parameter-table-dependent branches are collapsed into a single copy
of the branch. Recommended for large graphs with long parameter tables.
mode:
If 'data', only data nodes are shown. If 'task', only task nodes and input data
nodes are shown. If 'both', all nodes are shown.
cluster_generics:
If True, generic products are grouped into clusters.
cluster_color:
Expand All @@ -53,6 +58,28 @@ def to_graphviz(
Keyword arguments passed to :py:class:`graphviz.Digraph`.
"""
dot = Digraph(strict=True, **kwargs)
if dot.graph_attr.get('rankdir', 'TB') == 'LR':
# Significant horizontal space helps distinguish edges
dot.graph_attr['ranksep'] = '1'
# Little vertical space
dot.graph_attr['nodesep'] = '0.05'
# Avoiding edges connecting to top/bottom reduces edge clutter in larger graphs
dot.edge_attr['tailport'] = 'e'
dot.edge_attr['headport'] = 'w'
else:
dot.graph_attr['ranksep'] = '0.5'
dot.graph_attr['nodesep'] = '0.1'
# With tailport='s' we get more curved edges, so we omit it. In larger graphs
# this still seems to happen though, may need revisiting.
# Nodes are wide in west-east direction, so *not* connecting to headport='n'
# looks better
dot.node_attr.update({'height': '0', 'width': '0'})
# Ensure user can override defaults
dot.node_attr.update(kwargs.get('node_attr', {}))
dot.edge_attr.update(kwargs.get('edge_attr', {}))
dot.graph_attr.update(kwargs.get('graph_attr', {}))
# Compound is required for connecting edges to clusters
dot.graph_attr['compound'] = 'true'
formatted_graph = _format_graph(graph, compact=compact)
ordered_graph = dict(
sorted(formatted_graph.items(), key=lambda item: item[1].ret.name)
Expand All @@ -69,7 +96,13 @@ def to_graphviz(
dot_subgraph.attr(style='dotted')
else:
dot_subgraph.attr(style='filled', color=cluster_color)
_add_subgraph(subgraph, dot, dot_subgraph)
# For keys such as MyType[int] we show MyType only once as the cluster
# label. The nodes within the cluster will only show the bit inside [].
# This saves a lot of horizontal space in the graph in LR mode and
# duplication and clutter in general.
origin = next(iter(subgraph.values())).ret.name.split('[')[0]
dot_subgraph.attr(label=f'{origin}')
_add_subgraph(subgraph, dot, dot_subgraph, mode=mode)
return dot


Expand All @@ -82,28 +115,73 @@ def _to_subgraphs(graph: FormattedGraph) -> dict[str, FormattedGraph]:
return subgraphs


def _add_subgraph(graph: FormattedGraph, dot: Digraph, subgraph: Digraph) -> None:
def _add_subgraph(
graph: FormattedGraph,
dot: Digraph,
subgraph: Digraph,
mode: Literal['data', 'task', 'both'],
) -> None:
cluster = subgraph.name is not None
cluster_connected = []
common_provider = len(graph) > 1 and len({v.name for v in graph.values()}) == 1
for p, formatted_p in graph.items():
ret_name = formatted_p.ret.name
if cluster:
# Remove the origin from the name if we are in a cluster, as it is shown
# as the cluster label
split = ret_name[ret_name.index('[') :]
# The nodes within the cluster use slightly smaller text.
name = f'<<font point-size="12">{split}</font>>'
else:
name = f'<{ret_name}>'
if mode == 'data' and formatted_p.kind == 'function':
# Show provider name in data mode
via_name = html.escape(formatted_p.name)
via = f'<font point-size="11">via:<i>{via_name}</i></font>'
if common_provider:
origin = ret_name.split('[')[0]
subgraph.attr(label=f'<{origin}<br/>{via}>')
else:
name = f'{name[:-1]}<br/>{via}>'
shape = 'box3d' if formatted_p.ret.collapsed else 'rectangle'
if formatted_p.kind == 'unsatisfied':
subgraph.node(
formatted_p.ret.name,
formatted_p.ret.name,
shape='box3d' if formatted_p.ret.collapsed else 'rectangle',
ret_name,
name,
shape=shape,
color='red',
fontcolor='red', # Set text color to red
fontcolor='red',
style='dashed',
)
else:
subgraph.node(
formatted_p.ret.name,
formatted_p.ret.name,
shape='box3d' if formatted_p.ret.collapsed else 'rectangle',
)
elif mode != 'task' or formatted_p.kind == 'parameter':
subgraph.node(ret_name, name, shape=shape)
if formatted_p.kind == 'function':
dot.node(p, formatted_p.name, shape='ellipse')
for arg in formatted_p.args:
dot.edge(arg.name, p)
dot.edge(p, formatted_p.ret.name)
if mode == 'both':
dot.node(p, formatted_p.name, shape='ellipse')
for arg in formatted_p.args:
dot.edge(arg.name, p)
dot.edge(p, ret_name)
elif mode == 'task':
p = ret_name
dot.node(p, formatted_p.name, shape='ellipse')
for arg in formatted_p.args:
dot.edge(arg.name, p)
elif mode == 'data':
for arg in formatted_p.args:
if cluster and common_provider and '[' not in arg.name:
# Avoid duplicate arrows to subnodes if all providers are the
# same and the argument is not a generic
if arg.name not in cluster_connected:
dot.edge(
arg.name,
ret_name,
lhead=subgraph.name,
# Thick pen to indicate multiple connections
penwidth='2.0',
)
cluster_connected.append(arg.name)
else:
dot.edge(arg.name, ret_name)
# else: Do not draw dummy providers created by Pipeline when setting instances


Expand Down

0 comments on commit 0d4edca

Please sign in to comment.