Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use DataCollatorWithFlattening when not sample packing #2167

Merged
merged 7 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/config.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,9 @@ sample_packing_group_size: 100000
# The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
sample_packing_bin_size: 200

# Use batch flattening for speedups when not using sample_packing
batch_flattening:

# Passed through to transformers when loading the model when launched without accelerate
# Use `sequential` when training w/ model parallelism to limit memory
device_map:
Expand Down
13 changes: 11 additions & 2 deletions src/axolotl/core/trainer_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
DataCollatorWithFlattening,
EarlyStoppingCallback,
Trainer,
TrainerCallback,
Expand Down Expand Up @@ -1989,9 +1990,11 @@ def build_collator(
V2BatchSamplerDataCollatorForSeq2Seq,
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
DataCollatorWithFlattening,
RewardDataCollatorWithPadding,
]
]
collator_args = [self.tokenizer]
if self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
if "max_length" in kwargs:
Expand All @@ -2011,12 +2014,18 @@ def build_collator(
collator = MultiModalChatDataCollator
kwargs["processor"] = self.processor
kwargs["chat_template"] = training_args.chat_template
elif self.cfg.batch_flattening:
collator = DataCollatorWithFlattening
collator_args.pop(0)
kwargs.pop("pad_to_multiple_of", None)
kwargs.pop("padding", None)
else:
collator = DataCollatorForSeq2Seq

kwargs["return_tensors"] = "pt"

return collator(
self.tokenizer,
return_tensors="pt",
*collator_args,
**kwargs,
)

Expand Down
26 changes: 26 additions & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,8 @@ class Config:
curriculum_sampling: Optional[bool] = None
multipack_real_batches: Optional[bool] = None

batch_flattening: Optional[Union[Literal["auto"], bool]] = None

# for PoSE context length extension
use_pose: Optional[bool] = None
pose_split_on_token_ids: Optional[List[int]] = None
Expand Down Expand Up @@ -924,6 +926,30 @@ def check_sample_packing_wo_flash(cls, data):

return data

@model_validator(mode="before")
@classmethod
def check_batch_flattening_fa(cls, data):
if data.get("batch_flattening"):
batch_flattening_auto = data.get("batch_flattening") == "auto"
if not data.get("flash_attention") and not batch_flattening_auto:
raise ValueError("batch_flattening requires flash attention")
if data.get("sample_packing") and not batch_flattening_auto:
raise ValueError("batch_flattening not compatible with sample_packing")
if data.get("micro_batch_size") == 1 and not batch_flattening_auto:
LOG.warning("batch_flattening has no effect with micro_batch_size == 1")

if (
batch_flattening_auto
and data.get("flash_attention")
and not data.get("sample_packing")
and data.get("micro_batch_size") > 1
):
data["batch_flattening"] = True
elif batch_flattening_auto:
data["batch_flattening"] = False

return data

@model_validator(mode="before")
@classmethod
def check_sample_packing_w_rl(cls, data):
Expand Down
39 changes: 39 additions & 0 deletions tests/e2e/test_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,42 @@ def test_fix_untrained_tokens(self, temp_dir):

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()

def test_batch_flattening(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"trust_remote_code": True,
"sequence_len": 512,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 5,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": False,
"batch_flattening": True,
"bf16": True,
"save_safetensors": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
70 changes: 70 additions & 0 deletions tests/patched/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,76 @@ def test_torch_compile_auto(self, minimal_cfg):
assert updated_cfg.torch_compile is False


class TestSampleOptimConfigValidation(BaseValidation):
"""
test configurations for sample optimizations like batch flattening
"""

def test_batch_flattening_auto_enables(self, minimal_cfg):
cfg = (
DictDefault(
{
"flash_attention": True,
"sample_packing": None,
"micro_batch_size": 2,
"batch_flattening": "auto",
}
)
| minimal_cfg
)

new_cfg = validate_config(cfg)
assert new_cfg["batch_flattening"] is True

def test_batch_flattening_auto_no_fa(self, minimal_cfg):
cfg = (
DictDefault(
{
"flash_attention": False,
"sample_packing": None,
"micro_batch_size": 2,
"batch_flattening": "auto",
}
)
| minimal_cfg
)

new_cfg = validate_config(cfg)
assert new_cfg["batch_flattening"] is False

def test_batch_flattening_auto_mbsz_1(self, minimal_cfg):
cfg = (
DictDefault(
{
"flash_attention": True,
"sample_packing": None,
"micro_batch_size": 1,
"batch_flattening": "auto",
}
)
| minimal_cfg
)

new_cfg = validate_config(cfg)
assert new_cfg["batch_flattening"] is False

def test_batch_flattening_auto_packing(self, minimal_cfg):
cfg = (
DictDefault(
{
"flash_attention": True,
"sample_packing": True,
"micro_batch_size": 2,
"batch_flattening": "auto",
}
)
| minimal_cfg
)

new_cfg = validate_config(cfg)
assert new_cfg["batch_flattening"] is False


class TestValidationCheckModelConfig(BaseValidation):
"""
Test the validation for the config when the model config is available
Expand Down
Loading