From a04367151383dfb32b8f2f601b08bb914b8661b2 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 10 Dec 2024 07:36:57 -0500
Subject: [PATCH 1/7] use DataCollatorWithFlattening when not sample packing

---
 src/axolotl/core/trainer_builder.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 0f30f511c..092adf06c 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -28,6 +28,7 @@
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import (
+    DataCollatorWithFlattening,
     EarlyStoppingCallback,
     Trainer,
     TrainerCallback,
@@ -1989,6 +1990,7 @@ def build_collator(
                 V2BatchSamplerDataCollatorForSeq2Seq,
                 BatchSamplerDataCollatorForSeq2Seq,
                 DataCollatorForSeq2Seq,
+                DataCollatorWithFlattening,
                 RewardDataCollatorWithPadding,
             ]
         ]
@@ -2011,6 +2013,8 @@ def build_collator(
             collator = MultiModalChatDataCollator
             kwargs["processor"] = self.processor
             kwargs["chat_template"] = training_args.chat_template
+        elif self.cfg.flash_attention:
+            collator = DataCollatorWithFlattening
         else:
             collator = DataCollatorForSeq2Seq
 

From a2d9d1f76c7ffc704c6ec62e11a77868b3528f09 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 10 Dec 2024 08:43:57 -0500
Subject: [PATCH 2/7] DataCollatorWithFlattening doesn't accept most args/kwargs

---
 src/axolotl/core/trainer_builder.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 092adf06c..c688071b6 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1994,6 +1994,7 @@ def build_collator(
                 RewardDataCollatorWithPadding,
             ]
         ]
+        collator_args = [self.tokenizer]
         if self.cfg.reward_model:
             collator = RewardDataCollatorWithPadding
             if "max_length" in kwargs:
@@ -2015,12 +2016,16 @@ def build_collator(
             kwargs["chat_template"] = training_args.chat_template
         elif self.cfg.flash_attention:
             collator = DataCollatorWithFlattening
+            collator_args.pop(0)
+            kwargs.pop("pad_to_multiple_of", None)
+            kwargs.pop("padding", None)
         else:
             collator = DataCollatorForSeq2Seq
 
+        kwargs["return_tensors"] = "pt"
+
         return collator(
-            self.tokenizer,
-            return_tensors="pt",
+            *collator_args,
             **kwargs,
         )

From 5528328f712f17d411b2025220991a60d57f2789 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 11 Dec 2024 20:15:46 -0500
Subject: [PATCH 3/7] restrict use cases for flattening

---
 src/axolotl/core/trainer_builder.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index c688071b6..ff04befea 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -2014,7 +2014,11 @@ def build_collator(
             collator = MultiModalChatDataCollator
             kwargs["processor"] = self.processor
             kwargs["chat_template"] = training_args.chat_template
-        elif self.cfg.flash_attention:
+        elif (
+            self.cfg.flash_attention
+            and self.cfg.micro_batch_size > 1
+            and not self.cfg.sample_packing
+        ):
             collator = DataCollatorWithFlattening
             collator_args.pop(0)
             kwargs.pop("pad_to_multiple_of", None)

From 9e0805c4b7db7dd2729b8fb4fcf2f5d60da4e8a8 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 12 Dec 2024 13:23:33 -0500
Subject: [PATCH 4/7] add validation for batch flattening

---
 src/axolotl/core/trainer_builder.py               |  6 +-----
 .../utils/config/models/input/v0_4_1/__init__.py  | 15 +++++++++++++++
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index ff04befea..54ee19536 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -2014,11 +2014,7 @@ def build_collator(
             collator = MultiModalChatDataCollator
             kwargs["processor"] = self.processor
             kwargs["chat_template"] = training_args.chat_template
-        elif (
-            self.cfg.flash_attention
-            and self.cfg.micro_batch_size > 1
-            and not self.cfg.sample_packing
-        ):
+        elif self.cfg.batch_flattening:
             collator = DataCollatorWithFlattening
             collator_args.pop(0)
             kwargs.pop("pad_to_multiple_of", None)
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 69baf9af2..ae4cca975 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -696,6 +696,8 @@ class Config:
     curriculum_sampling: Optional[bool] = None
     multipack_real_batches: Optional[bool] = None
 
+    batch_flattening: Optional[bool] = None
+
     # for PoSE context length extension
     use_pose: Optional[bool] = None
     pose_split_on_token_ids: Optional[List[int]] = None
@@ -924,6 +926,19 @@ def check_sample_packing_wo_flash(cls, data):
 
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_batch_flattening_fa(cls, data):
+        if data.get("batch_flattening"):
+            if not data.get("flash_attention"):
+                raise ValueError("batch_flattening requires flash attention")
+            if data.get("sample_packing"):
+                raise ValueError("batch_flattening not compatible with sample_packing")
+            if data.get("micro_batch_size") == 1:
+                LOG.warning("batch_flattening has no effect with micro_batch_size == 1")
+
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_sample_packing_w_rl(cls, data):

From 513dbcb9df97b283618e0c3a40961a31f887c414 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Tue, 17 Dec 2024 13:13:09 -0500
Subject: [PATCH 5/7] add auto support for batch flattening and add tests

---
 docs/config.qmd                                   |  3 +
 .../config/models/input/v0_4_1/__init__.py        | 19 +++--
 tests/patched/test_validation.py                  | 70 +++++++++++++++++++
 3 files changed, 88 insertions(+), 4 deletions(-)

diff --git a/docs/config.qmd b/docs/config.qmd
index d52170959..70679791e 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -245,6 +245,9 @@ sample_packing_group_size: 100000
 # The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
 sample_packing_bin_size: 200
 
+# Use batch flattening for speedups when not using sample_packing
+batch_flattening:
+
 # Passed through to transformers when loading the model when launched without accelerate
 # Use `sequential` when training w/ model parallelism to limit memory
 device_map:
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index ae4cca975..5ddf04811 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -696,7 +696,7 @@ class Config:
     curriculum_sampling: Optional[bool] = None
     multipack_real_batches: Optional[bool] = None
 
-    batch_flattening: Optional[bool] = None
+    batch_flattening: Optional[Union[Literal["auto"], bool]] = None
 
     # for PoSE context length extension
     use_pose: Optional[bool] = None
@@ -930,13 +930,24 @@ def check_sample_packing_wo_flash(cls, data):
     @classmethod
     def check_batch_flattening_fa(cls, data):
         if data.get("batch_flattening"):
-            if not data.get("flash_attention"):
+            batch_flattening_auto = data.get("batch_flattening") == "auto"
+            if not data.get("flash_attention") and not batch_flattening_auto:
                 raise ValueError("batch_flattening requires flash attention")
-            if data.get("sample_packing"):
+            if data.get("sample_packing") and not batch_flattening_auto:
                 raise ValueError("batch_flattening not compatible with sample_packing")
-            if data.get("micro_batch_size") == 1:
+            if data.get("micro_batch_size") == 1 and not batch_flattening_auto:
                 LOG.warning("batch_flattening has no effect with micro_batch_size == 1")
 
+            if (
+                batch_flattening_auto
+                and data.get("flash_attention")
+                and not data.get("sample_packing")
+                and data.get("micro_batch_size") > 1
+            ):
+                data["batch_flattening"] = True
+            elif batch_flattening_auto:
+                data["batch_flattening"] = False
+
         return data
 
     @model_validator(mode="before")
     @classmethod
diff --git a/tests/patched/test_validation.py b/tests/patched/test_validation.py
index 3d1b74789..9d41dac76 100644
--- a/tests/patched/test_validation.py
+++ b/tests/patched/test_validation.py
@@ -1236,6 +1236,76 @@ def test_torch_compile_auto(self, minimal_cfg):
         assert updated_cfg.torch_compile is False
 
+
+class TestSampleOptimConfigValidation(BaseValidation):
+    """
+    test configurations for sample optimizations like batch flattening
+    """
+
+    def test_batch_flattening_auto_enables(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "flash_attention": True,
+                    "sample_packing": None,
+                    "micro_batch_size": 2,
+                    "batch_flattening": "auto",
+                }
+            )
+            | minimal_cfg
+        )
+
+        new_cfg = validate_config(cfg)
+        assert new_cfg["batch_flattening"] is True
+
+    def test_batch_flattening_auto_no_fa(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "flash_attention": False,
+                    "sample_packing": None,
+                    "micro_batch_size": 2,
+                    "batch_flattening": "auto",
+                }
+            )
+            | minimal_cfg
+        )
+
+        new_cfg = validate_config(cfg)
+        assert new_cfg["batch_flattening"] is False
+
+    def test_batch_flattening_auto_mbsz_1(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "flash_attention": True,
+                    "sample_packing": None,
+                    "micro_batch_size": 1,
+                    "batch_flattening": "auto",
+                }
+            )
+            | minimal_cfg
+        )
+
+        new_cfg = validate_config(cfg)
+        assert new_cfg["batch_flattening"] is False
+
+    def test_batch_flattening_auto_packing(self, minimal_cfg):
+        cfg = (
+            DictDefault(
+                {
+                    "flash_attention": True,
+                    "sample_packing": True,
+                    "micro_batch_size": 2,
+                    "batch_flattening": "auto",
+                }
+            )
+            | minimal_cfg
+        )
+
+        new_cfg = validate_config(cfg)
+        assert new_cfg["batch_flattening"] is False
new_cfg["batch_flattening"] is False + + class TestValidationCheckModelConfig(BaseValidation): """ Test the validation for the config when the model config is available From fb0752a5e409a3462055099204eb584ff53df8f2 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 17 Dec 2024 13:55:45 -0500 Subject: [PATCH 6/7] add e2e for flattening --- tests/e2e/test_llama.py | 40 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py index 33d12157a..c7d000744 100644 --- a/tests/e2e/test_llama.py +++ b/tests/e2e/test_llama.py @@ -104,3 +104,43 @@ def test_fix_untrained_tokens(self, temp_dir): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert (Path(temp_dir) / "model.safetensors").exists() + + + def test_batch_flattening(self, temp_dir): + # pylint: disable=duplicate-code + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "trust_remote_code": True, + "sequence_len": 512, + "val_set_size": 0.01, + "special_tokens": { + "pad_token": "<|endoftext|>", + }, + "datasets": [ + { + "path": "mhenrichsen/alpaca_2k_test", + "type": "alpaca", + }, + ], + "num_epochs": 1, + "max_steps": 5, + "micro_batch_size": 4, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "adamw_8bit", + "lr_scheduler": "cosine", + "flash_attention": True, + "sample_packing": False, + "batch_flattening": True, + "bf16": True, + "save_safetensors": True, + } + ) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "model.safetensors").exists() From 838aae1761a344bb80afda89a032bad00d861fb2 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 17 Dec 2024 16:45:19 -0500 Subject: [PATCH 7/7] chore: lint - merge conflict --- tests/e2e/test_llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/e2e/test_llama.py b/tests/e2e/test_llama.py index c7d000744..1ce9d60b9 100644 --- a/tests/e2e/test_llama.py +++ b/tests/e2e/test_llama.py @@ -105,7 +105,6 @@ def test_fix_untrained_tokens(self, temp_dir): train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) assert (Path(temp_dir) / "model.safetensors").exists() - def test_batch_flattening(self, temp_dir): # pylint: disable=duplicate-code cfg = DictDefault(