grad_norm 0.0 while finetuning using group_by_label batch sampler #3130

Open
AmoghM opened this issue Dec 10, 2024 · 2 comments

@AmoghM

AmoghM commented Dec 10, 2024

I am currently training a Sentence Transformer on my dataset using triplet loss, but the gradient norm (grad_norm) is consistently 0.0 during training. The problem occurs when using the group_by_label batch sampler that is recommended for triplet loss.

Details

  • Current Setup:
    • Model: Alibaba-NLP/gte-base-en-v1.5
    • Loss Function: Triplet Loss
    • Batch Sampler: group_by_label (recommended for triplet loss)

Observations

  • When I switch the batch sampler to either batch_sampler or no_duplicate, the training logs improve and the grad_norm values become non-zero (the sampler options are sketched below).
  • However, I want to use the group_by_label sampler since it is recommended for triplet loss, and I need help understanding why this specific configuration causes the issue.
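
For reference, the samplers compared above are selected through the batch_sampler argument; a minimal sketch (the BatchSamplers enum comes from sentence-transformers, "output" is just a placeholder directory):

from sentence_transformers.training_args import BatchSamplers, SentenceTransformerTrainingArguments

# Only batch_sampler changes between the runs compared above.
args = SentenceTransformerTrainingArguments(
    output_dir="output",  # placeholder
    batch_sampler=BatchSamplers.GROUP_BY_LABEL,
    # alternatives: BatchSamplers.NO_DUPLICATES, BatchSamplers.BATCH_SAMPLER
)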

Below is the sample code:

from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments, losses

training_args = SentenceTransformerTrainingArguments(
    num_train_epochs=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=4,
    warmup_steps=200,
    weight_decay=0.01,
    logging_steps=1,
    output_dir=output_dir,
    learning_rate=2e-5,
    max_grad_norm=1.0,
    dataloader_drop_last=True,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,
    batch_sampler='group_by_label',
    evaluation_strategy="steps",
    logging_strategy="steps",
    eval_steps=50,
)

trainer = SentenceTransformerTrainer(
    model=model,
    loss=losses.BatchHardSoftMarginTripletLoss(model),
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

Tensorboard viz of training with the different batch samplers. Orange line is no_duplicate, blue line is group_by_label, red line is batch_sampler.

[Tensorboard screenshots]

Training logs for no_duplicate batch sampler:

{'loss': 6.174, 'grad_norm': 38.07846450805664, 'learning_rate': 1.0000000000000001e-07, 'epoch': 0.08}                                                                                                                             
{'loss': 6.8544, 'grad_norm': 44.4666748046875, 'learning_rate': 2.0000000000000002e-07, 'epoch': 0.15}                                                                                                                             
{'loss': 5.7911, 'grad_norm': 37.91443634033203, 'learning_rate': 3.0000000000000004e-07, 'epoch': 0.23}                                                                                                                            
{'loss': 5.8593, 'grad_norm': 41.3128662109375, 'learning_rate': 4.0000000000000003e-07, 'epoch': 0.31}                                                                                                                             
{'loss': 6.1478, 'grad_norm': 40.226253509521484, 'learning_rate': 5.000000000000001e-07, 'epoch': 0.38}                                                                                                                            
{'loss': 6.2663, 'grad_norm': 37.63628005981445, 'learning_rate': 6.000000000000001e-07, 'epoch': 0.46}                                                                                                                             
{'loss': 6.5116, 'grad_norm': 45.362548828125, 'learning_rate': 7.000000000000001e-07, 'epoch': 0.54}                                                                                                                               
{'loss': 6.0732, 'grad_norm': 39.056190490722656, 'learning_rate': 8.000000000000001e-07, 'epoch': 0.62}                                                                                                                            
{'loss': 6.1131, 'grad_norm': 37.20143508911133, 'learning_rate': 9.000000000000001e-07, 'epoch': 0.69}                                                                                                                             
{'loss': 6.2785, 'grad_norm': 42.78799057006836, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.77}                                                                                                                            
{'loss': 6.2814, 'grad_norm': 38.738624572753906, 'learning_rate': 1.1e-06, 'epoch': 0.85}                                                                                                                                          
{'loss': 6.2216, 'grad_norm': 40.94490051269531, 'learning_rate': 1.2000000000000002e-06, 'epoch': 0.92}                                                                                                                            
{'loss': 5.776, 'grad_norm': 38.426063537597656, 'learning_rate': 1.3e-06, 'epoch': 1.0}

vs training logs for group_by_label batch sampler:

{'loss': 0.6931, 'grad_norm': 0.0, 'learning_rate': 1.0000000000000001e-07, 'epoch': 0.08}                                                                                                                                          
{'loss': 0.6931, 'grad_norm': 0.0, 'learning_rate': 2.0000000000000002e-07, 'epoch': 0.15}                                                                                                                                          
{'loss': 0.6931, 'grad_norm': 0.0, 'learning_rate': 3.0000000000000004e-07, 'epoch': 0.23}                                                                                                                                          
{'loss': 1.3355, 'grad_norm': 22.390493392944336, 'learning_rate': 4.0000000000000003e-07, 'epoch': 0.31}                                                                                                                           
{'loss': 0.6931, 'grad_norm': 0.0, 'learning_rate': 5.000000000000001e-07, 'epoch': 0.38}                                                                                                                                           
{'loss': 0.6931, 'grad_norm': 0.0, 'learning_rate': 6.000000000000001e-07, 'epoch': 0.46}                                                                                                                                           
{'loss': 0.6931, 'grad_norm': 0.0, 'learning_rate': 7.000000000000001e-07, 'epoch': 0.54}                                                                                                                                           
{'loss': 0.6931, 'grad_norm': 0.0, 'learning_rate': 8.000000000000001e-07, 'epoch': 0.62}                                                                                                                                           
{'loss': 0.6931, 'grad_norm': 0.0, 'learning_rate': 9.000000000000001e-07, 'epoch': 0.69}                                                                                                                                           
{'loss': 2.8415, 'grad_norm': 35.88393783569336, 'learning_rate': 1.0000000000000002e-06, 'epoch': 0.77}                                                                                                                            
{'loss': 0.6931, 'grad_norm': 0.0, 'learning_rate': 1.1e-06, 'epoch': 0.85}                                                                                                                                                         
{'loss': 0.6931, 'grad_norm': 0.0, 'learning_rate': 1.2000000000000002e-06, 'epoch': 0.92} 

Questions

  1. What could be causing the grad_norm to be 0.0 when using the group_by_label sampler?
  2. Are there any adjustments or configurations you would recommend to resolve this issue while still using the recommended batch sampler?

Thank you!
AmoghM changed the title from "grad_norm 0.0 while finetuning sentence transformer" to "grad_norm 0.0 while finetuning using group_by_label batch sampler" on Dec 10, 2024
@tomaarsen
Collaborator

tomaarsen commented Dec 23, 2024

Hello!
Apologies for the delay, I've been busy on a release.

Are you using the TripletLoss or the Batch...TripletLoss? Apologies for the confusion here, but there are fairly sizable differences:

  1. TripletLoss: Given (anchor, positive, negative) triplets, train such that anchor and positive get at least margin closer than anchor and negative.
  2. Batch...TripletLoss: Given text with class labels, the loss automatically finds pairs that should be more similar: texts from the same class, as well as pairs that should be less similar: texts from other classes. As you can expect, to get pairs that should be more similar, we need at least 2 texts from the same class in each batch. That's what the group_by_label batch sampler ensures.

In short, the latter benefits from group_by_label (at least in theory), whereas the former does not.
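
To make the difference concrete, here is a minimal sketch of the two input formats (the texts, labels, and column names below are illustrative placeholders, not taken from this issue):

from datasets import Dataset

# 1. TripletLoss: explicit (anchor, positive, negative) triplets.
triplet_dataset = Dataset.from_dict({
    "anchor": ["How do I reset my password?"],
    "positive": ["Steps to reset a forgotten password"],
    "negative": ["How do I close my account?"],
})

# 2. Batch...TripletLoss (e.g. BatchHardSoftMarginTripletLoss): single texts
#    with class labels. Positives and negatives are mined inside each batch
#    from the labels, so every batch needs at least 2 texts per label, which
#    is what the group_by_label batch sampler ensures.
labeled_dataset = Dataset.from_dict({
    "sentence": [
        "How do I reset my password?",
        "Steps to reset a forgotten password",
        "How do I close my account?",
        "Can I delete my profile permanently?",
    ],
    "label": [0, 0, 1, 1],
})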

Could you let me know which of the two you are using?

  • Tom Aarsen

@AmoghM
Author

AmoghM commented Dec 24, 2024

@tomaarsen No worries, thanks for responding. The loss is BatchHardSoftMarginTripletLoss, which is shown in the code snippet above. I'm pasting it here to avoid any further confusion.

trainer = SentenceTransformerTrainer(
    model=model,
    loss=losses.BatchHardSoftMarginTripletLoss(model),
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
