Pallas broken after the "make JAX-Triton calls serializable" update #179
Could you provide your JAX-Triton and jaxlib versions?
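One way to gather the requested versions is a minimal sketch using only the standard-library `importlib.metadata`. It assumes the installed distribution names are `jax`, `jaxlib` and `jax-triton` (the PyPI names); a source checkout installed under a different name would report as not installed. For CUDA wheels, the jaxlib version string may also carry a local suffix identifying the CUDA build.

```python
from importlib.metadata import PackageNotFoundError, version


def collect_versions(packages=("jax", "jaxlib", "jax-triton")):
    """Map each distribution name to its installed version, or None if absent.

    The default names are an assumption (the PyPI distribution names).
    """
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # Not installed under this distribution name in the current environment.
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in collect_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```

Pasting the printed lines into the thread gives maintainers the exact builds in use, including whether jaxlib is a CUDA 11 or CUDA 12 wheel.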
OK, I went back and created a fresh install. jax/jaxlib 0.4.12 with CUDA 11 works with the latest commits. jax/jaxlib 0.4.13 is where I get the error reported in this issue (although these versions work with code from before the serialisation change, commit f947255).
I'm having the same issue.
I haven't been able to reproduce this, but I used jaxlib 0.4.14 from the nightlies (specifically the 0705 build): https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
I was using the CUDA 11 version of jax/jaxlib.
Are you able to try a nightly?
The new update has broken Pallas again, with the error shown in the attached file below. I have tried updating jax/jaxlib to HEAD and the issue persists.
pallas_error.txt