Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

- Adding support in the graph scanner for Haiku & Flax normalization layers without learnable shift/offset params. #297

Merged
merged 1 commit into from
Nov 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 32 additions & 13 deletions kfac_jax/_src/tag_graph_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1490,27 +1490,32 @@ def _normalization_haiku_preprocessor(
return (normalized_inputs_var, *param_vars), [normalized_inputs_eqn]


def _make_normalization_haiku_pattern(
def _make_normalization_haiku_flax_pattern(
broadcast_ndim: int,
has_reshape: bool,
p_dim: int = 13,
):
has_shift: bool = True,
) -> GraphPattern:
"""Creates a pattern for a Haiku/Flax normalization layer."""

assert broadcast_ndim >= 0

x_shape = [i + 2 for i in range(broadcast_ndim)] + [p_dim]

example_params = [np.zeros([p_dim])]
if has_shift:
example_params.append(np.zeros([p_dim]))

return GraphPattern(
name=f"normalization_haiku_broadcast_{broadcast_ndim}",
tag_primitive=tags.layer_tag,
compute_func=functools.partial(
_normalization_haiku_flax,
has_scale=True,
has_shift=True,
has_shift=has_shift,
has_reshape=has_reshape),
parameters_extractor_func=_scale_and_shift_parameter_extractor,
example_args=[[np.zeros(x_shape), np.zeros(x_shape)],
[np.zeros([p_dim]), np.zeros([p_dim])]],
example_args=[[np.zeros(x_shape), np.zeros(x_shape)], example_params],
in_values_preprocessor=_normalization_haiku_preprocessor
)

Expand All @@ -1531,16 +1536,29 @@ def _make_normalization_haiku_pattern(
)
)

NORMALIZATION_GRAPH_PATTERNS = tuple(
_make_normalization_haiku_flax_pattern(
broadcast_ndim=n,
has_reshape=r,
has_shift=s)
for n, r, s in itertools.product(
range(2),
(False, True),
(False, True),
)
)

DEFAULT_GRAPH_PATTERNS = DENSE_GRAPH_PATTERNS + (
_make_conv2d_pattern(True, False),
_make_conv2d_pattern(True, True),
_make_conv2d_pattern(False, False),
_make_scale_and_shift_pattern(1, True, True),
_make_scale_and_shift_pattern(0, True, True),
_make_normalization_haiku_pattern(1, False),
_make_normalization_haiku_pattern(1, True),
_make_normalization_haiku_pattern(0, False),
_make_normalization_haiku_pattern(0, True),
_make_scale_and_shift_pattern(0, True, True)
)

DEFAULT_GRAPH_PATTERNS += NORMALIZATION_GRAPH_PATTERNS

DEFAULT_GRAPH_PATTERNS += (
_make_scale_and_shift_pattern(1, True, False),
_make_scale_and_shift_pattern(0, True, False),
_make_scale_and_shift_pattern(1, False, True),
Expand Down Expand Up @@ -1946,7 +1964,7 @@ def _auto_register_tags(
inputs_index=(),
outputs_index=(0,),
params_index=(0,),
name=f"Auto[generic|{n}]",
name=f"Auto[generic({n})]",
)),
effects=set(),
)
Expand All @@ -1971,7 +1989,7 @@ def _auto_register_tags(
if meta.name is None:
n = pattern_counters.get(meta.variant, 0)
pattern_counters[meta.variant] = n + 1
meta.name = f"Manual[{meta.variant}|{n}]"
meta.name = f"Manual[{meta.variant}({n})]"

tag_locations.append(TagLocation(eqn))

Expand All @@ -1991,7 +2009,8 @@ def _auto_register_tags(
assert meta.name is None
n = pattern_counters.get(meta.variant, 0)
pattern_counters[meta.variant] = n + 1
meta.name = f"Auto[{meta.variant}|{n}]"
meta.name = (f"Auto[tag_variant={meta.variant}({n})|"
f"match_type={match.name}]")
tag_locations.append(TagLocation(eqns[-1]))

final_outvars = [env.get(v, v) if isinstance(v, Var) else v
Expand Down
Loading