Skip to content

Commit

Permalink
Update lowering.py
Browse files Browse the repository at this point in the history
  • Loading branch information
shangz-ai authored Dec 17, 2024
1 parent 45598f2 commit 472cda8
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jax/_src/pallas/triton/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1620,9 +1620,9 @@ def _broadcast_in_dim_lowering_rule(


@register_lowering(lax.squeeze_p)
def _squeeze_lowering_rule(ctx: LoweringRuleContext, a, *, dimensions, sharding=None):
def _squeeze_lowering_rule(ctx: LoweringRuleContext, a, *, dimensions):
del dimensions
return _reshape_lowering_rule(ctx, a, new_sizes=None, dimensions=None, sharding=sharding)
return _reshape_lowering_rule(ctx, a, new_sizes=None, dimensions=None, sharding=None)


@register_lowering(lax.reshape_p)
Expand Down

0 comments on commit 472cda8

Please sign in to comment.