Hi, thanks for this great work!

With the official code, I get NaN if I switch to a different dataset. Could anyone help with this?
What's happening?
I get NaN during training (after roughly 3500-3600 steps in my setting).
If I turn off FlexAttention (by manually setting `_SUPPORTS_FLEX_ATTENTION` to `False`), it runs without NaN for a fairly long time.
I also tried the nightly torch (20241019) with CUDA 12.6, but the same thing happens.
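For reference, this is roughly how I forced it off (just a sketch; I'm assuming the flag lives in `torchtune.modules.attention_utils`, which may differ across versions, and a runtime patch only takes effect if it runs before the attention layers are built):

```python
# Sketch of the workaround: force the plain SDPA path instead of FlexAttention.
# Assumption: _SUPPORTS_FLEX_ATTENTION is defined in torchtune.modules.attention_utils;
# the exact location may vary between torchtune versions. Editing the source directly
# is the surer route if the flag is read at import time.
import torchtune.modules.attention_utils as attention_utils

attention_utils._SUPPORTS_FLEX_ATTENTION = False
```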
Command
```bash
tune run --nnodes 1 --nproc_per_node 8 knowledge_distillation_distributed --config llama3_2/8B_to_3B_KD_lora_distributed
```
Modified Config
```yaml
# Config for multi-device knowledge distillation in knowledge_distillation_distributed.py
# using a teacher and student model
#
# This config assumes that you've ran the following commands before launching KD:
# First download the student and teacher models
# tune download meta-llama/Llama-3.2-1B-Instruct --output-dir /tmp/Llama-3.2-1B-Instruct --ignore-patterns "original/consolidated.00.pth"
# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# You get better results using KD if the teacher model has already been fine-tuned on the target dataset:
# tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama3_1/8B_lora
#
# To launch on 2 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config llama3_2/8B_to_1B_KD_lora_distributed
#
# This config works best for distilling on 2+ devices.

output_dir: llama3_2_8B_to_3B/KD_lora_distributed # /tmp may be deleted by your system. Change it to your preference.

# Model Arguments
model:
  _component_: torchtune.models.llama3_2.lora_llama3_2_3b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  apply_lora_to_output: False
  lora_rank: 256  # higher increases accuracy and memory
  lora_alpha: 512  # usually alpha=2*rank
  lora_dropout: 0.0

teacher_model:
  _component_: torchtune.models.llama3_1.llama3_1_8b

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: ../Llama-3.2-3B-Instruct/original/tokenizer.model
  max_seq_len: 8192

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: ../Llama-3.2-3B-Instruct/
  checkpoint_files: [
    model-00001-of-00002.safetensors,
    model-00002-of-00002.safetensors,
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: False
save_adapter_weights_only: False

# Teacher checkpoint
teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: ../Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
    model-00001-of-00004.safetensors,
    model-00002-of-00004.safetensors,
    model-00003-of-00004.safetensors,
    model-00004-of-00004.safetensors
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  model_type: LLAMA3

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.chat_dataset
  source: arcee-ai/The-Tome
  packed: True # True increases speed
  conversation_column: conversations
  conversation_style: sharegpt
seed: 42
shuffle: True
batch_size: 1

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.01
  lr: 3e-4

lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

kd_loss:
  _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss
kd_ratio: 0.5

# Training
epochs: 1
max_steps_per_epoch: null
compile: True # torch.compile the model + loss, True increases speed + decreases memory
gradient_accumulation_steps: 2 # Use to increase effective batch size

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: False

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: False # True reduces memory
enable_activation_offloading: False # True reduces memory

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 5
  active_steps: 2
  num_cycles: 1
```
@thnkinbtfly, it's hard to say, since it happens in the middle of training. Is it within the first epoch? If so, it could be a bad sample. If not, looking at the loss curves/gradient norms could show whether the loss is diverging.
Would you be able to run it again and try to capture the sample it failed on? That way we could make it reproducible, and if it is a FlexAttention issue in torch core or torchtune, it should be easier to fix.
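Something along these lines could work (just a sketch, not part of the recipe; you'd call it right after the loss is computed in the training loop, and `batch`/`step` are whatever names the loop already has in scope):

```python
import torch

def dump_if_nonfinite(loss: torch.Tensor, batch: dict, step: int, path: str = "bad_batch.pt") -> None:
    """Save the current batch to disk if the loss is NaN/Inf so it can be replayed in isolation."""
    if not torch.isfinite(loss).all():
        torch.save(
            {"step": step, "batch": {k: v.cpu() for k, v in batch.items() if torch.is_tensor(v)}},
            path,
        )
        raise RuntimeError(f"Non-finite loss at step {step}; batch saved to {path}")
```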
By the way, I highly recommend giving the Weights & Biases logger a try and checking how much memory you are actually using. If you have 8x H100s and you are doing LoRA, you are probably barely using your memory, so you can increase your bsz and/or max_seq_len.
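As a quick sanity check (illustrative only, using plain `torch.cuda` counters rather than anything recipe-specific), you can compare peak allocation against the device total after a few steps:

```python
import torch

# Compare peak allocated memory against device capacity to see how much headroom
# is left for a larger batch size or max_seq_len (run on each rank after a few steps).
peak = torch.cuda.max_memory_allocated()
total = torch.cuda.get_device_properties(0).total_memory
print(f"peak allocated: {peak / 1e9:.1f} GB of {total / 1e9:.1f} GB ({100 * peak / total:.0f}%)")
```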