Skip to content

Commit

Permalink
- Adding support in the graph scanner for Haiku & Flax normalization …
Browse files Browse the repository at this point in the history
…layers without learnable shift/offset params.

- Changing "Graph parameter registrations" logging message to include information about the graph match type (and not just the tag variant).

PiperOrigin-RevId: 696987909
  • Loading branch information
james-martens authored and KfacJaxDev committed Nov 15, 2024
1 parent b14fb6a commit defe7eb
Showing 1 changed file with 31 additions and 12 deletions.
43 changes: 31 additions & 12 deletions kfac_jax/_src/tag_graph_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1482,31 +1482,48 @@ def _normalization_haiku_preprocessor(
return (normalized_inputs_var, *param_vars), [normalized_inputs_eqn]


def _make_normalization_haiku_pattern(
def _make_normalization_haiku_flax_pattern(
broadcast_ndim: int,
has_reshape: bool,
p_dim: int = 13,
):
has_shift: bool = True,
) -> GraphPattern:
"""Creates a pattern for a Haiku/Flax normalization layer."""

assert broadcast_ndim >= 0

x_shape = [i + 2 for i in range(broadcast_ndim)] + [p_dim]

example_params = [np.zeros([p_dim])]
if has_shift:
example_params.append(np.zeros([p_dim]))

return GraphPattern(
name=f"normalization_haiku_broadcast_{broadcast_ndim}",
tag_primitive=tags.layer_tag,
compute_func=functools.partial(
_normalization_haiku_flax,
has_scale=True,
has_shift=True,
has_shift=has_shift,
has_reshape=has_reshape),
parameters_extractor_func=_scale_and_shift_parameter_extractor,
example_args=[[np.zeros(x_shape), np.zeros(x_shape)],
[np.zeros([p_dim]), np.zeros([p_dim])]],
example_args=[[np.zeros(x_shape), np.zeros(x_shape)], example_params],
in_values_preprocessor=_normalization_haiku_preprocessor
)


NORMALIZATION_GRAPH_PATTERNS = tuple(
_make_normalization_haiku_flax_pattern(
broadcast_ndim=n,
has_reshape=r,
has_shift=s)
for n, r, s in itertools.product(
range(2),
(False, True),
(False, True),
)
)

DEFAULT_GRAPH_PATTERNS = (
_make_general_dense_pattern(True, False, 0),
_make_general_dense_pattern(True, False, 1),
Expand All @@ -1521,11 +1538,12 @@ def _make_normalization_haiku_pattern(
_make_conv2d_pattern(True, True),
_make_conv2d_pattern(False, False),
_make_scale_and_shift_pattern(1, True, True),
_make_scale_and_shift_pattern(0, True, True),
_make_normalization_haiku_pattern(1, False),
_make_normalization_haiku_pattern(1, True),
_make_normalization_haiku_pattern(0, False),
_make_normalization_haiku_pattern(0, True),
_make_scale_and_shift_pattern(0, True, True)
)

DEFAULT_GRAPH_PATTERNS += NORMALIZATION_GRAPH_PATTERNS

DEFAULT_GRAPH_PATTERNS += (
_make_scale_and_shift_pattern(1, True, False),
_make_scale_and_shift_pattern(0, True, False),
_make_scale_and_shift_pattern(1, False, True),
Expand Down Expand Up @@ -1956,7 +1974,7 @@ def _auto_register_tags(
if meta.name is None:
n = pattern_counters.get(meta.variant, 0)
pattern_counters[meta.variant] = n + 1
meta.name = f"Manual[{meta.variant}|{n}]"
meta.name = f"Manual[{meta.variant}({n})]"

tag_locations.append(TagLocation(eqn))

Expand All @@ -1976,7 +1994,8 @@ def _auto_register_tags(
assert meta.name is None
n = pattern_counters.get(meta.variant, 0)
pattern_counters[meta.variant] = n + 1
meta.name = f"Auto[{meta.variant}|{n}]"
meta.name = (f"Auto[tag_variant={meta.variant}({n})|"
f"match_type={match.name}]")
tag_locations.append(TagLocation(eqns[-1]))

final_outvars = [env.get(v, v) if isinstance(v, Var) else v
Expand Down

0 comments on commit defe7eb

Please sign in to comment.