-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Simplify implementation of nn.relu. #25331
base: main
Are you sure you want to change the base?
Conversation
162f364
to
b41e069
Compare
This is one of those changes that's probably OK, but has a small chance of causing some strange unintended numerical regression deep within models that use RELU. This simplification saves one line of code: I'm inclined to fall back on the "if it's not broken don't fix it" principle. What do you think? |
@jakevdp I'd consider that 4 lines of code. 🙂 The test should be considered separately, since it should've existed in the first place. And I think tests in general should be treated as a separate category for the purposes of codebase length, since it's not part of JAX in and of itself, but rather a test of it, and we typically want extensive test coverage. I think we should always take the opportunity to simplify JAX's codebase, unless there's a very strong reason not to. Otherwise we'd be stuck with suboptimal choices forever. Many small improvements add up. That eases the future maintenance burden. In this particular case, there's no reason why we should have to bring in the If any issues do pop up, I'll handle them. |
I would approve a PR adding that test without any comment 😁 Note however, that your change does modify numerical results. For example, try |
@jakevdp If you're taking the grad of relu at nan, you've got bigger problems. 😁 |
I'm truly not trying to be difficult here, it's just that many times I've seen changes like this cause headaches down the road, and saving one line of code doesn't strike me as worth the risk. |
IMO we can always undo it if it does cause issues (though I don't see how it could). But if you'd rather not, that's fine too. |
b41e069
to
107911e
Compare
@jakevdp Just in case this addresses your concern, I edited it so that it now produces the same grad at nan. I also created a separate PR to isolate the other changes. If that one is merged, I'll rebase this one. |
Thanks for splitting the uncontroversial changes into another PR! Let's go ahead and try this: can you rebase on the current main branch? Thanks! |
107911e
to
4eb7ce2
Compare
@jakevdp Done. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
We're seeing some failures on
The test is here: Line 4640 in 5fe8bcc
I'm not sure how to best address this. |
Maybe use |
This change is also breaking one flax test, on this line: https://github.com/google/flax/blob/554b690bab07920860acbdb1d4fae03cc516d385/tests/linen/summary_test.py#L676
|
@jakevdp I'll try that. But does this expose a broader issue or lack of optimization in the compiler or sharding machinery? |
I think it's due to the fact that the broadcasting machinery in |
4eb7ce2
to
879a599
Compare
Updated.
Is this considered a current shortcoming? If so, should we open an issue about it (and link back to here)? |
There are some other failures as well that I'm not sure how to deal with: for example, an internal rematerialization test that's looking for a particular pattern in the emitted StableHLO for some complicated model. The flax test is concerning as well, as it points to the fact that the compiler is generating a quantitatively different program with your PR than with the existing implementation. Given the importance of Honestly, I don't think it's worth the effort to track down these fixes and do the analysis necessary to land this change. |
#25423 should make |
No description provided.