How to use float8 for training? #2201

Are there any examples for training the MLP blocks using float8 from torchao? Thanks!
Hi @vgoklani, we don't currently support this, but you could modify a recipe to call `torchao.float8.convert_to_float8_training` on your model at the end of this function. However, I recommend using QLoRA, where the frozen base model params are quantized to a lower precision (NF4) while the trainable adapter params are kept in a higher precision. Here's an example config. The QLoRA model builder replaces the model's linear layers with LoRA linear layers whose frozen base weights are stored in NF4. Let me know if you have any questions!
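For context, a minimal sketch of that QLoRA path, assuming torchtune's `qlora_llama3_8b` builder; the model choice and the argument values here are illustrative, not the exact settings from the linked config:

```python
# Sketch only: assumes torchtune's qlora_llama3_8b builder; values are illustrative.
from torchtune.models.llama3 import qlora_llama3_8b

# The builder swaps in LoRA linear layers: the frozen base weights are stored
# in NF4 (via torchao), while the small LoRA adapter weights stay in a higher
# precision and are the only parameters that get trained.
model = qlora_llama3_8b(
    lora_attn_modules=["q_proj", "v_proj"],
    lora_rank=8,
    lora_alpha=16,
)
```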
Thanks @calvinpelletier. We are using the full-finetune scripts, and since the hardware already supports FP8, we are just leaving a lot of performance on the table... We can add it to our internal version, but I would imagine that there are other groups that want this included too.
We would definitely appreciate a PR if full-finetuning in FP8 works out well for you all!
I was working on adding INT8 training to torchtune in #1552, and FP8 also came up in that discussion. Once the INT8 PR is merged, we can make another one for FP8, since it follows a similar design.
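For reference, a rough sketch of what the INT8 training path looks like, assuming torchao's prototype `int8_mixed_precision_training` API is available (prototype APIs move quickly, so names and availability may differ across torchao versions):

```python
# Sketch only: assumes torchao's prototype quantized-training API.
import torch.nn as nn
from torchao.quantization import quantize_
from torchao.prototype.quantized_training import int8_mixed_precision_training

# A stand-in model; in torchtune this would be the full transformer.
model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()

# Swap the nn.Linear layers to INT8 mixed-precision training. An FP8 version
# would follow the same convert-the-linears design.
quantize_(model, int8_mixed_precision_training())
```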
Thank you @calvinpelletier and @gau-nernst. Using dynamic scaling with the torchao API was trivial and gave a ~30% performance boost in tokens per second. We're running on 4x NVIDIA A6000 Ada cards (SM89).

```python
from torchao.float8 import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
    convert_to_float8_training,
)

config = Float8LinearConfig(
    enable_fsdp_float8_all_gather=True,
    force_recompute_fp8_weight_in_bwd=True,
    cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DYNAMIC),
)

convert_to_float8_training(mlp, config=config)
```

Strangely enough, using delayed scaling […]
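For anyone wiring this into a full-finetune recipe, a minimal sketch of converting the whole model rather than just an MLP block, reusing the `config` above; `convert_to_float8_training` accepts a `module_filter_fn`, and the module name `"output"` used below is an assumption about how the final projection is named in your model:

```python
# Sketch: convert every nn.Linear except the final output projection to fp8 training.
import torch
from torchao.float8 import convert_to_float8_training

def module_filter_fn(module, fqn):
    # Skip the LM head: fp8 matmuls there tend to hurt accuracy, and fp8 gemms
    # require dims divisible by 16, which the vocab projection may not satisfy.
    return fqn != "output"

convert_to_float8_training(model, config=config, module_filter_fn=module_filter_fn)

# torch.compile is recommended so the fp8 casting/scaling kernels get fused.
model = torch.compile(model)
```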
@vgoklani Delayed scaling is not as well supported as dynamic scaling, I think; it should be fine to stick with dynamic scaling. Curious, do you observe any convergence issues?
@gau-nernst The loss was very close to bfloat16! I'm looking forward to int8 training :)