NaN when running official KD code on a different dataset, with packing + compile #2198

Open
thnkinbtfly opened this issue Dec 22, 2024 · 2 comments

thnkinbtfly commented Dec 22, 2024

Hi, thanks for this great work!

With the official code, I get NaN when I switch to a different dataset. Could anyone help with this?

What's happening?

  • I get NaN during training (after roughly 3500-3600 steps in my setting).
  • If I turn off FlexAttention (by manually setting _SUPPORTS_FLEX_ATTENTION to False), training runs without NaN for a fairly long time.
  • I also tried a nightly torch build (20241019) with CUDA 12.6, but the same thing happens.
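
With packing enabled, the FlexAttention path restricts each token to earlier tokens from the same packed sample (a document-causal mask). The pure-Python sketch below builds such a mask densely from hypothetical per-sample lengths and checks for query rows with nothing to attend to, because a softmax over a row where every score is -inf evaluates 0/0, i.e. NaN. It only illustrates the idea and is not torchtune's mask code; `block_causal_mask` and `fully_masked_rows` are hypothetical helpers. Since a well-formed document-causal mask always allows the diagonal, it should never contain such rows, which rules out that particular NaN source if the mask is built correctly.

```python
from typing import List


def block_causal_mask(seq_lens: List[int]) -> List[List[bool]]:
    """Build a dense document-causal mask for one packed sequence.

    mask[q][kv] is True when query position q may attend to key position kv:
    both positions come from the same packed sample and kv <= q.
    """
    # Label every token position with the index of the sample it came from.
    doc_ids = [doc for doc, n in enumerate(seq_lens) for _ in range(n)]
    total = len(doc_ids)
    return [
        [doc_ids[q] == doc_ids[kv] and kv <= q for kv in range(total)]
        for q in range(total)
    ]


def fully_masked_rows(mask: List[List[bool]]) -> List[int]:
    """Return query positions that cannot attend to any key (softmax -> NaN)."""
    return [q for q, row in enumerate(mask) if not any(row)]


if __name__ == "__main__":
    mask = block_causal_mask([3, 2])  # two samples packed into 5 tokens
    for row in mask:
        print("".join("x" if allowed else "." for allowed in row))
    # x....
    # xx...
    # xxx..
    # ...x.
    # ...xx
    print("fully masked rows:", fully_masked_rows(mask))  # fully masked rows: []
```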

Command

 tune run --nnodes 1 --nproc_per_node 8 knowledge_distillation_distributed --config llama3_2/8B_to_3B_KD_lora_distributed

Modified Config

# Config for multi-device knowledge distillation in knowledge_distillation_distributed.py
# using a teacher and student model
#
# This config assumes that you've run the following commands before launching KD:
# First download the student and teacher models
#   tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth"
#   tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# You get better results using KD if the teacher model has already been fine-tuned on the target dataset:
#   tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama3_1/8B_lora
#
# To launch on 2 devices, run the following command from root:
#   tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config llama3_2/8B_to_1B_KD_lora_distributed
#
# This config works best for distilling on 2+ devices.


output_dir: llama3_2_8B_to_3B/KD_lora_distributed # /tmp may be deleted by your system. Change it to your preference.

# Model Arguments
model:
  _component_: torchtune.models.llama3_2.lora_llama3_2_3b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 256  # higher increases accuracy and memory
  lora_alpha: 512  # usually alpha=2*rank
  lora_dropout: 0.0

teacher_model:
  _component_: torchtune.models.llama3_1.llama3_1_8b

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: ../Llama-3.2-3B-Instruct/original/tokenizer.model
  max_seq_len: 8192

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: ../Llama-3.2-3B-Instruct/
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False

# Teacher checkpoint
teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: ../Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
    model-00001-of-00004.safetensors,
    model-00002-of-00004.safetensors,
    model-00003-of-00004.safetensors,
    model-00004-of-00004.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  model_type: LLAMA3

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.chat_dataset
  source: arcee-ai/The-Tome
  packed: True  # True increases speed
  conversation_column: conversations
  conversation_style: sharegpt
seed: 42
shuffle: True
batch_size: 1

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

kd_loss:
  _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss
kd_ratio: 0.5

# Training
epochs: 1
max_steps_per_epoch: null
compile: True  # torch.compile the model + loss, True increases speed + decreases memory
gradient_accumulation_steps: 2  # Use to increase effective batch size

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: False  # True reduces memory
enable_activation_offloading: False  # True reduces memory

# Showcase the usage of the PyTorch profiler
# Set enabled to False as it's only needed for debugging training
profiler:
  _component_: torchtune.training.setup_torch_profiler

  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 5
  active_steps: 2
  num_cycles: 1
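
The config pairs CEWithChunkedOutputLoss with ForwardKLWithChunkedOutputLoss and sets kd_ratio: 0.5. The pure-Python sketch below shows what a forward-KL distillation term computes per token (-Σ p_teacher · log q_student, which differs from KL(p‖q) only by the teacher entropy, a constant with respect to the student), assuming the two losses are blended as (1 - kd_ratio) · CE + kd_ratio · KD. It also illustrates two generic ways such a computation can produce NaN: `0 * -inf` when a student log-prob is -inf, and a 0/0 mean when every label in a batch is ignored. It is not torchtune's implementation (which works on chunked tensors), it is not a diagnosis of this issue, and `IGNORE_INDEX = -100` plus the zero-token guard are assumptions.

```python
import math
from typing import List, Sequence

IGNORE_INDEX = -100  # assumption: the ignore index used for masked labels


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # exp(-inf) == 0.0
    z = sum(exps)
    return [e / z for e in exps]


def log_softmax(logits: Sequence[float]) -> List[float]:
    finite = [x for x in logits if math.isfinite(x)]
    m = max(finite)
    log_z = m + math.log(sum(math.exp(x - m) for x in finite))
    return [x - log_z for x in logits]  # -inf logits map to -inf log-probs


def token_cross_term(student_logits: Sequence[float], teacher_logits: Sequence[float]) -> float:
    """sum_v p_teacher(v) * log q_student(v), skipping -inf student logits."""
    p = softmax(teacher_logits)
    log_q = log_softmax(student_logits)
    total = 0.0
    for p_v, log_q_v, s_v in zip(p, log_q, student_logits):
        if math.isinf(s_v):
            # Naively, p_v * -inf is nan when p_v == 0 (and -inf otherwise).
            continue
        total += p_v * log_q_v
    return total


def forward_kl_loss(student, teacher, labels) -> float:
    """Mean of -sum_v p log q over tokens whose label is not ignored."""
    terms = [
        token_cross_term(s, t)
        for s, t, y in zip(student, teacher, labels)
        if y != IGNORE_INDEX
    ]
    if not terms:
        # Hypothetical guard: a tensor sum / count here would be 0/0 = NaN.
        return 0.0
    return -sum(terms) / len(terms)


def kd_total_loss(ce_loss: float, kd_loss: float, kd_ratio: float = 0.5) -> float:
    """Blend the two losses; a NaN in either term makes the total NaN."""
    return (1 - kd_ratio) * ce_loss + kd_ratio * kd_loss
```

The design choice worth noting is the masking: skipping -inf student entries keeps one masked vocabulary slot from turning the whole token's loss into NaN or -inf.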

Environment

torchtune@f2bd4bc
torch: 2.5.1+cu121
CUDA: 12.5.40 (nvidia pytorch:24.06-py3)
torchao: 0.7.0+cu121
GPUs: H100x8

felipemello1 (Contributor) commented Dec 22, 2024

@thnkinbtfly, it's hard to say, since it happens in the middle of training. Is it within the first epoch? If so, it could be a bad sample. If not, looking at the loss curves and gradient norms could show whether the loss is diverging.
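
One way to act on the loss-curve suggestion is to scan the logged values for the first non-finite value or sudden spike. The pure-Python sketch below assumes you have already extracted per-step losses from your metric logs (parsing is omitted because it depends on the logger). `first_suspicious_step`, the window size, and the spike factor are hypothetical choices, not torchtune settings. If you also log gradient norms, the same check can run over that sequence.

```python
import math
from statistics import mean
from typing import List, Optional, Sequence


def first_suspicious_step(
    values: Sequence[float],
    window: int = 50,
    spike_factor: float = 3.0,
) -> Optional[int]:
    """Return the index of the first non-finite value or spike, else None.

    A "spike" is a value greater than spike_factor times the mean of the
    previous `window` finite values (a heuristic; tune it for your run).
    """
    history: List[float] = []
    for step, value in enumerate(values):
        if not math.isfinite(value):
            return step
        if len(history) >= window and value > spike_factor * mean(history[-window:]):
            return step
        history.append(value)
    return None
```

A spike flagged well before the first NaN would point toward divergence rather than a single bad sample.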

Would you be able to run it again and try to capture the sample it failed on? That way, we could try to make it reproducible, and if it is a FlexAttention issue in torch core or torchtune, it should be easier to fix.
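
Capturing the failing sample could look like the sketch below: wrap the training step, and on the first non-finite loss write the batch to disk and stop. `run_until_nonfinite` and `train_step` are hypothetical; in a real run the batch would hold tensors that need converting (for example with `.tolist()`) before `json.dump`, and in a multi-GPU run each rank would need its own `dump_path`. With a fixed seed (the config sets `seed: 42`) the data order should be repeatable, which makes rerunning to the failure point practical.

```python
import json
import math
from typing import Any, Callable, Iterable, Mapping


def run_until_nonfinite(
    batches: Iterable[Mapping[str, Any]],
    train_step: Callable[[Mapping[str, Any]], float],
    dump_path: str = "nan_batch.json",
) -> int:
    """Run train_step over batches; on the first non-finite loss, save the batch.

    Returns the number of batches that finished with a finite loss. Raises
    RuntimeError after writing the offending batch (JSON-serializable) to dump_path.
    """
    completed = 0
    for idx, batch in enumerate(batches):
        loss = train_step(batch)
        if not math.isfinite(loss):
            with open(dump_path, "w") as f:
                json.dump({"batch_idx": idx, "loss": repr(loss), "batch": batch}, f)
            raise RuntimeError(
                f"non-finite loss {loss!r} at batch {idx}; saved to {dump_path}"
            )
        completed += 1
    return completed
```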

By the way, I highly recommend giving the Weights & Biases logger a try and checking how much memory you are using. With 8xH100 and LoRA, you are probably using only a small part of your memory, so you can increase your batch size (bsz) and/or max_seq_len.
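
As a quick sense of scale, the arithmetic below counts tokens per optimizer step for the config in this issue. It assumes that with packing each pack holds max_seq_len tokens (padding included), so the result is an upper bound on real tokens. Raising bsz or max_seq_len scales this count, and roughly the activation memory, linearly. Turning on the config's `log_peak_memory_stats` flag (currently False) is one way to see how much headroom is left. `tokens_per_optimizer_step` is a hypothetical helper.

```python
def tokens_per_optimizer_step(
    batch_size: int, max_seq_len: int, grad_accum_steps: int, world_size: int
) -> int:
    """Upper bound on tokens per optimizer step when every pack is max_seq_len long."""
    return batch_size * max_seq_len * grad_accum_steps * world_size


if __name__ == "__main__":
    # batch_size=1, max_seq_len=8192, gradient_accumulation_steps=2, 8 GPUs
    print(tokens_per_optimizer_step(1, 8192, 2, 8))  # 131072
    # Same run with a hypothetical batch_size=4
    print(tokens_per_optimizer_step(4, 8192, 2, 8))  # 524288
```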

felipemello1 (Contributor)

Edited my comment above
