How to use float8 for training? #2201

Open · vgoklani opened this issue Dec 23, 2024 · 7 comments

@vgoklani

Are there any examples for training the MLP blocks using float8 from torchao?

Thanks!

@calvinpelletier
Contributor

Hi @vgoklani, we don't currently support this, but you could modify a recipe to call torchao.float8.convert_to_float8_training on your model at the end of this function.
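
Roughly, a minimal sketch of that call (not a supported torchtune path; only_mlp_linears and model are placeholder names):

from torchao.float8 import convert_to_float8_training

# Placeholder filter: only convert linear layers whose fully qualified name
# contains "mlp"; adjust to match the model's actual module names.
def only_mlp_linears(module, fqn):
    return "mlp" in fqn

# "model" is whatever the recipe's model setup returns.
convert_to_float8_training(model, module_filter_fn=only_mlp_linears)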

However, I recommend using QLoRA, where frozen base model params are quantized to a lower precision (NF4), while the trainable adapter params are kept in a higher precision.

Here's an example config. The QLoRA model builder replaces the model's linear layers with LoRALinear(..., quantize_base=True) layers. If you want to use float8 instead of NF4, you can modify the LoRALinear class.
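
For illustration, a single QLoRA-style layer constructed directly (a sketch only; the dimensions are made up and the LoRALinear signature should be checked against your torchtune version):

from torchtune.modules.peft import LoRALinear

# Sketch: the frozen base weight is stored in NF4, while the trainable LoRA
# adapter stays in higher precision. Dimensions are illustrative only.
layer = LoRALinear(
    in_dim=4096,
    out_dim=4096,
    rank=8,
    alpha=16,
    dropout=0.0,
    quantize_base=True,  # quantize the frozen base weight to NF4
)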

Let me know if you have any questions!

@vgoklani
Author

Thanks @calvinpelletier. We are using the full-finetune scripts, and since the hardware already supports FP8, we are just leaving a lot of performance on the table... We can add it to our internal version, but I would imagine that there are other groups that want this included too.

@calvinpelletier
Contributor

We would definitely appreciate a PR if full-finetuning in FP8 works out well for you all!

@gau-nernst
Contributor

I was working on adding INT8 training to torchtune #1552, and FP8 was also part of that discussion. Once the INT8 PR is merged, we can make another one for FP8 too, since it follows a similar design.
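
For reference, the torchao prototype INT8 training API looks roughly like this (a sketch; these are prototype names and may change between torchao releases):

from torchao.quantization import quantize_
from torchao.prototype.quantized_training import int8_mixed_precision_training

# Sketch: swap linear layers for INT8 mixed-precision training variants.
# "model" is a placeholder for the model being finetuned.
quantize_(model, int8_mixed_precision_training())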

@vgoklani
Author

vgoklani commented Dec 24, 2024

Thank you @calvinpelletier and @gau-nernst

Using dynamic scaling with the torchao API was trivial, and gave a ~30% boost in tokens per second.

We're running on 4x NVIDIA A6000 Ada cards (SM89)

from torchao.float8 import (
    CastConfig,
    Float8LinearConfig,
    ScalingType,
    convert_to_float8_training,
)

config = Float8LinearConfig(
    enable_fsdp_float8_all_gather=True,
    force_recompute_fp8_weight_in_bwd=True,
    cast_config_input=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_weight=CastConfig(scaling_type=ScalingType.DYNAMIC),
    cast_config_grad_output=CastConfig(scaling_type=ScalingType.DYNAMIC),
)

convert_to_float8_training(mlp, config=config)

Strangely enough, using DELAYED scaling crashed torch.compile... will need to dig into that further.
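
A minimal sketch of the convert-then-compile ordering with dynamic scaling (model stands in for the full finetune model):

import torch

# Sketch: convert to float8 first, then compile the converted model.
convert_to_float8_training(model, config=config)
model = torch.compile(model)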

@gau-nernst
Contributor

gau-nernst commented Dec 25, 2024

@vgoklani Delayed scaling is not as well supported as dynamic scaling, I think. It should be fine to stick with dynamic scaling.

Curious, do you observe any convergence issues?

@vgoklani
Author

vgoklani commented Dec 25, 2024

@gau-nernst The loss was very close to the bfloat16 loss! I'm looking forward to int8 training :)
