Skip to content

Commit

Permalink
Implement flatten one level with keys in C++ and use it for the prefi…
Browse files Browse the repository at this point in the history
…x/equality error printing.

With this, we should be able to safely delete the python with-path registry after a new jaxlib release.

Also changed all `std::string_view` to `absl::string_view` per requirements of TF repository.

PiperOrigin-RevId: 705669465
  • Loading branch information
IvyZX authored and Google-ML-Automation committed Dec 13, 2024
1 parent eb3ea98 commit ef06607
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion jax/_src/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from jax._src import traceback_util
from jax._src.lib import pytree
from jax._src.lib import xla_extension_version
from jax._src.util import safe_zip, set_module
from jax._src.util import unzip2

Expand Down Expand Up @@ -607,6 +608,18 @@ def flatten_one_level(tree: Any) -> tuple[Iterable[Any], Hashable]:
return out


# flatten_one_level_with_keys is not exported.
def flatten_one_level_with_keys(
tree: Any,
) -> tuple[Iterable[KeyLeafPair], Hashable]:
"""Flatten the given pytree node by one level, with keys."""
out = default_registry.flatten_one_level_with_keys(tree)
if out is None:
raise ValueError(f"can't tree-flatten type: {type(tree)}")
else:
return out


# prefix_errors is not exported
def prefix_errors(prefix_tree: Any, full_tree: Any,
is_leaf: Callable[[Any], bool] | None = None,
Expand Down Expand Up @@ -728,7 +741,7 @@ def keystr(keys: KeyPath):
return ''.join(map(str, keys))


# TODO(ivyzheng): remove this after _child_keys() also moved to C++.
# TODO(ivyzheng): remove this after another jaxlib release.
class _RegistryWithKeypathsEntry(NamedTuple):
flatten_with_keys: Callable[..., Any]
unflatten_func: Callable[..., Any]
Expand Down Expand Up @@ -1146,6 +1159,8 @@ def tree_map_with_path(f: Callable[..., Any],

def _child_keys(pytree: Any) -> KeyPath:
assert not treedef_is_strict_leaf(tree_structure(pytree))
if xla_extension_version >= 301:
return tuple(k for k, _ in flatten_one_level_with_keys(pytree)[0])
handler = _registry_with_keypaths.get(type(pytree))
if handler:
return tuple(k for k, _ in handler.flatten_with_keys(pytree)[0])
Expand Down

0 comments on commit ef06607

Please sign in to comment.