Import error encountered in jax_triton #264
Hi @egg5154, which version of …
Hello, jax_triton is 0.1.4 and triton (triton-nightly) is 2.1.0.post20231216005823.
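Version mismatches come up repeatedly in this thread, so it can help to collect the installed versions in one go before reporting them. The snippet below is a minimal sketch using only the standard library's `importlib.metadata`. The helper name `installed_versions` is hypothetical, and the list of distribution names (`jax`, `jaxlib`, `jax-triton`, `triton`, `triton-nightly`) is an assumption about which packages matter here.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in the current environment.
            found[name] = None
    return found


if __name__ == "__main__":
    # Assumed set of packages relevant to a jax_triton import failure.
    names = ("jax", "jaxlib", "jax-triton", "triton", "triton-nightly")
    for name, ver in installed_versions(names).items():
        print(f"{name}: {ver or 'not installed'}")
```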
I suspect you might need the …
Hello @superbobry, I changed to the …
Ouch, sorry you have to deal with this. It is indeed quite tricky to find a working …

If you are open to using Pallas instead of Triton directly, jax-ml/jax#19890 changed how Pallas-produced Triton kernels are compiled. We no longer need either …
Thanks! Actually, I want to use flash-attention in …
Yeah, you could use Pallas, which lowers to Triton on GPU without going through the Triton Python APIs.
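To illustrate what "using Pallas instead of Triton directly" looks like, here is a minimal sketch of an element-wise Pallas kernel, assuming a recent JAX that ships `jax.experimental.pallas`. It is not the code from jax-ml/jax#19890 and not a flash-attention kernel; the names `add_kernel` and `add` are hypothetical. On a GPU backend `pallas_call` compiles the kernel itself, so no `import jax_triton` is involved; on other backends the sketch falls back to `interpret=True` so it still runs.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_kernel(x_ref, y_ref, o_ref):
    # Each ref points at the kernel's block of the operand; with no grid,
    # the block is the whole array. Read both inputs and write their sum.
    o_ref[...] = x_ref[...] + y_ref[...]


def add(x: jax.Array, y: jax.Array) -> jax.Array:
    # On GPU the kernel is lowered by Pallas itself; elsewhere we use the
    # Pallas interpreter so the example is runnable on CPU too.
    interpret = jax.default_backend() != "gpu"
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=interpret,
    )(x, y)


if __name__ == "__main__":
    # A power-of-two length keeps the single block Triton-friendly on GPU.
    x = jnp.arange(8, dtype=jnp.float32)
    print(add(x, x))
```

For attention specifically, the same pattern applies, just with a grid and `BlockSpec`s to tile the query, key, and value blocks.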
Hello, I was running jax_triton on an A100 with CUDA 12.2, but when I run the command

python -c 'import jax_triton as jt'

an error occurs. My jax_triton was installed following jax-ml/jax#18603.