
More Chat Loss Masking Strategies #2214

Open · EugenHotaj opened this issue Dec 30, 2024 · 4 comments

@EugenHotaj (Contributor)

Are there plans to add more loss masking strategies for chat data?

For example, a very common loss masking strategy for multi-turn conversations is to mask everything except the last assistant response. However, `train_on_input=False` currently computes the loss on all assistant turns, not just the last one. Is it possible to add this feature to torchtune?
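To make the difference concrete, here's a hypothetical two-round conversation annotated with where the loss is computed under each strategy (the conversation itself is made up for illustration):

```python
# Hypothetical multi-turn conversation illustrating the two strategies.
conversation = [
    {"role": "user", "content": "What's 2 + 2?"},
    {"role": "assistant", "content": "4."},   # train_on_input=False: loss computed here
    {"role": "user", "content": "And 3 + 3?"},
    {"role": "assistant", "content": "6."},   # loss computed here under both strategies
]
# train_on_input=False     -> loss on every assistant turn
# "last turn only" masking -> loss on the final assistant turn only
```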

@RdoubleA (Contributor)

If you are using a custom dataset with a custom message transform, you can manually mask the messages you need in the transform by setting the `masked` field on the `Message` dataclass. If you are using one of the dataset builders, you're right that this is not currently possible. Those builders are designed to be easily configurable from YAML, so something more flexible like a list of boolean loss masks is a bit tougher to expose. But if this is a common approach and other folks would like it for the built-in dataset builders, we could consider something like changing `train_on_input` to a string masking-strategy parameter, or something similar.
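For anyone landing here, a minimal sketch of such a transform, assuming the `masked` field on `torchtune.data.Message` described above; the class name and the `"conversation"` column are made up for illustration:

```python
from typing import Any, Mapping

from torchtune.data import Message


class LastTurnMessages:
    """Hypothetical message transform that masks everything except the
    last assistant response, so only that turn contributes to the loss."""

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        # Assumes the raw sample stores the conversation under a
        # "conversation" key as a list of {"role": ..., "content": ...} dicts.
        raw = sample["conversation"]
        last_assistant = max(
            i for i, turn in enumerate(raw) if turn["role"] == "assistant"
        )
        messages = [
            Message(
                role=turn["role"],
                content=turn["content"],
                # Mask (exclude from loss) every message except the
                # final assistant response.
                masked=(i != last_assistant),
            )
            for i, turn in enumerate(raw)
        ]
        return {"messages": messages}
```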

@RdoubleA (Contributor) commented Jan 1, 2025

I just saw a similar request in #2207, so this might be worth enabling.

@EugenHotaj (Contributor, Author)

Nice to "see" you again, Rafi! Thanks for the quick response.

> But if this is a common approach and other folks would like it for the built-in dataset builders, we could consider something like changing `train_on_input` to a string masking-strategy parameter, or something similar.

> I just saw a similar request in #2207, so this might be worth enabling.

Masking all but the last turn is a very common (maybe the most common?) masking strategy, so it could be a nice feature to provide users out of the box.

> If you are using a custom dataset with a custom message transform, you can manually mask the messages you need in the transform by setting the `masked` field on the `Message` dataclass.

Any pointers / examples for how to do this?

@RdoubleA (Contributor) commented Jan 2, 2025

Glad to see you on the torchtune repo Eugen :)

Yes, see this page for an example.

If your conversation is stored in a column, you can query that column in the custom message transform and manually create `Message` objects for the whole conversation, leaving only the last one unmasked. Then you'll need to write a custom dataset builder that you can specify in your config.
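A hedged sketch of such a builder, assuming torchtune's `SFTDataset` and the hypothetical `LastTurnMessages` transform from the earlier comment (the builder name, module path, and defaults are made up):

```python
from typing import Any

from torchtune.datasets import SFTDataset

# The transform sketched in the earlier comment (hypothetical module path).
from my_module.transforms import LastTurnMessages


def last_turn_dataset(
    model_transform,  # typically the model tokenizer instantiated from the config
    *,
    source: str = "json",
    **load_dataset_kwargs: Any,
) -> SFTDataset:
    """Hypothetical builder wiring LastTurnMessages into SFTDataset."""
    return SFTDataset(
        source=source,
        message_transform=LastTurnMessages(),
        model_transform=model_transform,
        **load_dataset_kwargs,
    )
```

You could then point the `dataset` field of your YAML config at this builder via its `_component_` path.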

Let me know if there's any confusion on this.
