Simplify implementation of nn.relu. #25331

Open · wants to merge 2 commits into main
Conversation

@carlosgmartin (Contributor)

No description provided.

@jakevdp (Collaborator) commented Dec 9, 2024

This is one of those changes that's probably OK, but has a small chance of causing some strange unintended numerical regression deep within models that use ReLU. This simplification saves one line of code: I'm inclined to fall back on the "if it's not broken, don't fix it" principle. What do you think?

@carlosgmartin (Contributor, Author) commented Dec 9, 2024

@jakevdp I'd consider that 4 lines of code. 🙂 The test should be considered separately, since it should've existed in the first place. And I think tests in general should be treated as a separate category for the purposes of codebase length, since they're not part of JAX in and of itself but rather tests of it, and we typically want extensive test coverage.

I think we should always take the opportunity to simplify JAX's codebase, unless there's a very strong reason not to. Otherwise we'd be stuck with suboptimal choices forever. Many small improvements add up, and that eases the future maintenance burden.

In this particular case, there's no reason why we should have to bring in the custom_jvp machinery to define ReLU.
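Roughly, the comparison is between something like the following (a sketch for illustration, not the exact jax.nn source):

```python
import jax
import jax.numpy as jnp
from jax import custom_jvp, lax

# Sketch of the existing approach: relu with a hand-written custom JVP rule.
@custom_jvp
def relu_custom(x):
  return jnp.maximum(x, 0)

# The custom JVP passes the tangent through only where x > 0.
relu_custom.defjvps(lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))

# Sketch of the simplified definition: let autodiff differentiate `where` itself.
def relu_simple(x):
  return jnp.where(x > 0, x, 0)

x = jnp.array([-1.0, 0.0, 2.0])
print(relu_custom(x), relu_simple(x))  # identical forward values
print(jax.grad(lambda v: relu_custom(v).sum())(x))  # compare gradients
print(jax.grad(lambda v: relu_simple(v).sum())(x))
```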

If any issues do pop up, I'll handle them.

@jakevdp (Collaborator) commented Dec 9, 2024

I would approve a PR adding that test without any comment 😁

Note, however, that your change does modify numerical results. For example, try grad(relu)(nan) with the old and new definitions. Will that make a difference downstream? Hard to say.
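Concretely, one can probe it like this (a sketch; `jnp.where(x > 0, x, 0)` is used here as a stand-in for the proposed definition, and the printed values depend on which implementation is installed):

```python
import jax
import jax.numpy as jnp

# Check the gradient of relu at NaN under the current definition and under
# a plain jnp.where-based one; the two may disagree at this input.
candidates = {
    "current jax.nn.relu": jax.nn.relu,
    "where-based": lambda x: jnp.where(x > 0, x, 0),
}
for name, fn in candidates.items():
    print(name, jax.grad(fn)(jnp.float32(jnp.nan)))
```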

@carlosgmartin (Contributor, Author) commented

@jakevdp If you're taking the grad of relu at nan, you've got bigger problems. 😁

@jakevdp (Collaborator) commented Dec 9, 2024

I'm truly not trying to be difficult here, it's just that many times I've seen changes like this cause headaches down the road, and saving one line of code doesn't strike me as worth the risk.

@carlosgmartin (Contributor, Author) commented

IMO we can always undo it if it does cause issues (though I don't see how it could).

But if you'd rather not, that's fine too.

@carlosgmartin (Contributor, Author) commented Dec 11, 2024

@jakevdp Just in case this addresses your concern, I edited it so that it now produces the same grad at nan.

I also created a separate PR to isolate the other changes. If that one is merged, I'll rebase this one.

@jakevdp (Collaborator) commented Dec 11, 2024

Thanks for splitting the uncontroversial changes into another PR! Let's go ahead and try this: can you rebase on the current main branch? Thanks!

@carlosgmartin (Contributor, Author) commented

@jakevdp Done.

@jakevdp (Collaborator) left a review comment:

Thanks!

@google-ml-butler bot added the kokoro:force-run and pull ready (Ready for copybara import and testing) labels on Dec 11, 2024
@jakevdp (Collaborator) commented Dec 11, 2024

We're seeing some failures on ShardingInTypesTest.test_scan when run on all backends:

TypeError: select cases must have the same shardings, got [NamedSharding(mesh=Mesh('x': 2, 'y': 2), spec=PartitionSpec(None, None)), NamedSharding(mesh=Mesh('x': 2, 'y': 2), spec=PartitionSpec(None, 'y'))].

The test is here:

class ShardingInTypesTest(jtu.JaxTestCase):

I'm not sure how to best address this.

@jakevdp (Collaborator) commented Dec 11, 2024

Maybe use lax.full_like(x, 0) instead of 0 for the last argument of where?
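i.e. something along these lines (a sketch of the suggestion; the intent is that both branches of the `where` then carry matching shardings instead of relying on broadcasting a Python scalar 0):

```python
import jax.numpy as jnp
from jax import lax

def relu(x):
  # Use a zeros array with x's shape and dtype (and, ideally, its sharding)
  # rather than the scalar 0, so the two select branches match.
  return jnp.where(x > 0, x, lax.full_like(x, 0))
```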

@jakevdp (Collaborator) commented Dec 11, 2024

This change is also breaking one flax test, on this line: https://github.com/google/flax/blob/554b690bab07920860acbdb1d4fae03cc516d385/tests/linen/summary_test.py#L676

AssertionError: '1628954' not found in '│        │ Classifier │ float32[1,28,28,1] │ float32[1,10]   │ 1629979 │ 5698101   │                                         │'

@carlosgmartin (Contributor, Author) commented

@jakevdp I'll try that. But does this expose a broader issue or lack of optimization in the compiler or sharding machinery?

@jakevdp (Collaborator) commented Dec 11, 2024

> @jakevdp I'll try that. But does this expose a broader issue or lack of optimization in the compiler or sharding machinery?

I think it's due to the fact that the broadcasting machinery in jax.numpy does not have any logic around sharding.

@carlosgmartin (Contributor, Author) commented

Updated.

> I think it's due to the fact that the broadcasting machinery in jax.numpy does not have any logic around sharding.

Is this considered a current shortcoming? If so, should we open an issue about it (and link back to here)?

@jakevdp (Collaborator) commented Dec 11, 2024

There are some other failures as well that I'm not sure how to deal with: for example, an internal rematerialization test that's looking for a particular pattern in the emitted StableHLO for some complicated model.

The flax test is concerning as well, as it points to the fact that the compiler is generating a quantitatively different program with your PR than with the existing implementation. Given the importance of relu to production models, we'd probably want some comprehensive benchmarks of realistic models to convince ourselves that this change is going in the right direction.

Honestly, I don't think it's worth the effort to track down these fixes and do the analysis necessary to land this change.

@yashk2810 (Collaborator) commented Dec 11, 2024

#25423 should make jnp.where(x > 0, x, 0) work properly with shardings.
