Skip to content

Commit

Permalink
Allow event listners to take extra keyword arguments.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 569138957
  • Loading branch information
chansoo-google authored and jax authors committed Sep 28, 2023
1 parent 2d068a1 commit 79d0a83
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 7 deletions.
31 changes: 28 additions & 3 deletions jax/_src/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,31 @@
A typical listener callback is to send an event to a metrics collector for
aggregation/exporting.
"""
from typing import Callable
from typing import Callable, Mapping, Protocol, Union


class EventListenerWithMetadata(Protocol):

def __call__(
self, event: str, **kwargs: Mapping[str, Union[str, int]]
) -> None:
...


_event_listeners_with_metadata: list[EventListenerWithMetadata] = []
_event_listeners: list[Callable[[str], None]] = []
_event_duration_secs_listeners: list[Callable[[str, float], None]] = []

def record_event(event: str) -> None:

def record_event(event: str, **kwargs: Mapping[str, Union[str, int]]) -> None:
"""Record an event."""
for callback in _event_listeners:
callback(event)
if not kwargs:
callback(event)
for callback in _event_listeners_with_metadata:
if kwargs:
callback(event, **kwargs)


def record_event_duration_secs(event: str, duration: float) -> None:
"""Record an event duration in seconds (float)."""
Expand All @@ -39,6 +55,15 @@ def register_event_listener(callback: Callable[[str], None]) -> None:
"""Register a callback to be invoked during record_event()."""
_event_listeners.append(callback)


# TODO(b/301446522): Merge this function with register_event_listener.
def register_event_listener_with_kwargs(
callback: EventListenerWithMetadata,
) -> None:
"""Register a callback to be invoked during record_event()."""
_event_listeners_with_metadata.append(callback)


def register_event_duration_secs_listener(
callback : Callable[[str, float], None]) -> None:
"""Register a callback to be invoked during record_event_duration_secs()."""
Expand Down
9 changes: 5 additions & 4 deletions jax/monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
"""

from jax._src.monitoring import (
record_event as record_event,
record_event_duration_secs as record_event_duration_secs,
register_event_listener as register_event_listener,
register_event_duration_secs_listener as register_event_duration_secs_listener,
record_event_duration_secs as record_event_duration_secs,
record_event as record_event,
register_event_duration_secs_listener as register_event_duration_secs_listener,
register_event_listener_with_kwargs as register_event_listener_with_kwargs,
register_event_listener as register_event_listener,
)

0 comments on commit 79d0a83

Please sign in to comment.