Skip to content

Commit

Permalink
DBRX Model Support (#1462)
Browse files Browse the repository at this point in the history
* wip for dbrx finetuning

* add fastcore for parallel loading of sharded weights

* fix dtype for load, use PartialState instead of accelerator to init process group, remove redundant wandb callback

* update to use v2 of the converted model

* more fixes for dbrx loras

* make sure to enable fsdp activation checkpointing

* fix support for 8bit loras too for dbrx

* apply z3 leaf moe fix for DBRX with deepspeed

* don't raise value error since child module searches could fail and be ok

* revert a previous change to fix fsdp

* update mistral/mistral qlora+fsdp yamls

* fix qlora+fsdp quant storage type

* more edge cases for qlora-fsdp

* fixes for fsdp+qlora w optimizer in 8bit

* add bigstral z3 config and make sure to use full_state_dict for fsdp
  • Loading branch information
winglian authored Apr 12, 2024
1 parent 5ed2939 commit 132eb74
Show file tree
Hide file tree
Showing 19 changed files with 859 additions and 29 deletions.
2 changes: 2 additions & 0 deletions deepspeed_configs/zero3_bf16_cpuoffload_all.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
{
"zero_force_ds_cpu_optimizer": false,
"zero_allow_untested_optimizer": true,
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
Expand Down
2 changes: 2 additions & 0 deletions deepspeed_configs/zero3_bf16_cpuoffload_params.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
{
"zero_force_ds_cpu_optimizer": false,
"zero_allow_untested_optimizer": true,
"zero_optimization": {
"stage": 3,
"offload_param": {
Expand Down
81 changes: 81 additions & 0 deletions examples/dbrx/16bit-lora.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out

sequence_len: 512
sample_packing: false
pad_to_sequence_len: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

adapter: lora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
# w1, w2, & v1 will hang the trainer
lora_target_modules:
- q_proj # attn
- k_proj # attn
- v_proj # attn
- out_proj # attn
- layer # router
# - w1
# - w2
# - v1

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: false
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: DbrxBlock
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_activation_checkpointing: true
81 changes: 81 additions & 0 deletions examples/dbrx/8bit-lora.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true

load_in_8bit: true
load_in_4bit: false
strict: false

datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out

sequence_len: 512
sample_packing: false
pad_to_sequence_len: false

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

adapter: lora
lora_model_dir:
lora_r: 8
lora_alpha: 16
lora_dropout: 0.05
# w1, w2, & v1 will hang the trainer
lora_target_modules:
- q_proj # attn
- k_proj # attn
- v_proj # attn
- out_proj # attn
- layer # router
# - w1
# - w2
# - v1

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: false # don't use with fsdp_activation_checkpointing
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: false
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: DbrxBlock
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_activation_checkpointing: true
26 changes: 26 additions & 0 deletions examples/dbrx/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# DBRX MoE

Currently, for LoRA, only the `q_proj`, `k_proj`, `v_proj` `out_proj` and `layer` Linear layers are trainable.

We are using the "converted" base models based on [this issue](https://huggingface.co/databricks/dbrx-instruct/discussions/10)
where the Experts are fused as an `nn.Parameter` rather than a `nn.Linear` layer. However, the implementation
is still a bit buggy and attempting to train a LoRA adapter over those `w1`, `w2` and `v1` layers
results in the trainer hanging.


### FSDP
We've tested using the [`LnL-AI/dbrx-base-converted-v2`](https://huggingface.co/LnL-AI/dbrx-base-converted-v2) model as the base model for FSDP.

The high memory usage seen w/ FSDP is due to FSDP not supporting 8bit optimizers.

- 16-bit LoRA w/ FSDP
- ✅ w/o CPU Offload - 8x80GB uses ~80GiB/gpu
- ❌ w/ CPU Offload - `paged_adamw_8bit` optimizer errors from being on cpu
- ✅ 8-bit LoRA w/ FSDP
- ❌ 4-bit QLoRA w/ FSDP - errors w/: `Error an illegal memory access was encountered at line 90 in file /src/csrc/ops.cu`
- ✅ bf16 full finetune w/ FSDP, freezing all but first 8 layers (8x80GB uses ~78GiB/gpu)


### Deepspeed

WIP
56 changes: 56 additions & 0 deletions examples/dbrx/fft-ds-zero3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
base_model: LnL-AI/dbrx-base-converted-v2
trust_remote_code: true

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./out

sequence_len: 512
sample_packing: false
pad_to_sequence_len: false

unfrozen_parameters:
- transformer.blocks.[0-7].

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch:
saves_per_epoch: 1
debug:
weight_decay: 0.0
deepspeed: deepspeed_configs/zero3_bf16.json
4 changes: 3 additions & 1 deletion examples/llama-2/qlora-fsdp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,14 @@ deepspeed:
weight_decay: 0.0
fsdp:
- full_shard
- auto_wrap
fsdp_config:
fsdp_limit_all_gathers: true
fsdp_sync_module_states: true
fsdp_offload_params: true
fsdp_use_orig_params: false
fsdp_cpu_ram_efficient_loading: true
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
fsdp_state_dict_type: SHARDED_STATE_DICT
fsdp_state_dict_type: FULL_STATE_DICT
special_tokens:
63 changes: 63 additions & 0 deletions examples/mistral/bigstral-ds-zero3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
base_model: mistral-community/Mixtral-8x22B-v0.1
model_type: AutoModelForCausalLM
tokenizer_type: LlamaTokenizer
trust_remote_code: true

load_in_8bit: false
load_in_4bit: false
strict: false

unfrozen_parameters:
- ^lm_head.weight$
- ^model.embed_tokens.weight$
- model.layers.4[4-9]+.block_sparse_moe.gate
- model.layers.4[4-9]+.block_sparse_moe.experts
- model.layers.5[0-5]+.block_sparse_moe.gate
- model.layers.5[0-5]+.block_sparse_moe.experts

model_config:
output_router_logits: true

datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./out

sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 3
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0001

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

save_total_limit: 1
save_steps:
debug:
deepspeed: deepspeed_configs/zero3_bf16_cpuoffload_params.json
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
eos_token: "<|im_end|>"
tokens:
- "<|im_start|>"
Loading

0 comments on commit 132eb74

Please sign in to comment.