From 8606ff401a3e7a3fa01253cb90439ecaea9e032f Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 19 Dec 2024 00:28:02 -0500
Subject: [PATCH] make it work

---
 src/axolotl/core/trainers/kd.py             | 52 +++++++++------
 .../prompt_strategies/chat_template.py      | 21 +++++-
 src/axolotl/utils/collators/kd.py           | 66 ++++++++++++-------
 3 files changed, 94 insertions(+), 45 deletions(-)

diff --git a/src/axolotl/core/trainers/kd.py b/src/axolotl/core/trainers/kd.py
index e84036079..f1b47f50c 100644
--- a/src/axolotl/core/trainers/kd.py
+++ b/src/axolotl/core/trainers/kd.py
@@ -13,46 +13,52 @@ def kd_loss_function(
     student_logits,
     target_token_ids,
     target_logprobs,
+    target_mask,
     num_items_in_batch: Optional[int] = None,
-    **kwargs,  # pylint: disable=unused-argument
 ):
-    # student_logits: [B, seq_len, vocab_size] from the student's forward pass
-    # target_token_ids: [B, teacher_seq_len, K] top-K token IDs from teacher
-    # target_logprobs: [B, teacher_seq_len, K] teacher logprobs for these top-K tokens
+    # target_mask: [B, teacher_seq_len, K], where 1 indicates a valid token and 0 indicates padding
+    # Determine the teacher sequence length
     teacher_seq_len = target_token_ids.shape[1]
 
-    # Slice the student logits to match the teacher-provided seq length
+    # Slice student logits to match the teacher-provided sequence length
     student_logits_for_kd = student_logits[
         :, -teacher_seq_len:, :
-    ]  # Now [B, teacher_seq_len, vocab_size]
+    ]  # [B, teacher_seq_len, vocab_size]
 
     # Gather student logits for teacher's top-K tokens
     student_logits_topk = torch.gather(
         student_logits_for_kd, dim=-1, index=target_token_ids
     )  # [B, teacher_seq_len, K]
 
-    # Convert student top-K logits to logprobs
+    # Convert student top-k logits to logprobs
     student_logprobs_topk = student_logits_topk - torch.logsumexp(
         student_logits_topk, dim=-1, keepdim=True
-    )
+    )  # [B, seq_len, K]
 
-    # teacher_probs are simply exp of teacher_logprobs (already scaled)
+    # Convert target_mask to boolean for indexing
+    valid_mask = target_mask.bool()
+
+    # Prune tensors to only keep valid tokens
+    # This will result in 1D arrays of only valid positions
+    student_logprobs_topk = student_logprobs_topk[valid_mask]  # [N_valid_tokens]
+    target_logprobs = target_logprobs[valid_mask]  # [N_valid_tokens]
+
+    # Since teacher_logprobs are already normalized, just exponentiate to get probabilities
     teacher_probs = target_logprobs.exp()
 
-    # Compute forward KL
-    # L_kl = sum_k p^T_k (log p^T_k - log p^S_k)
-    kd_loss_per_position = (
-        teacher_probs * (target_logprobs - student_logprobs_topk)
-    ).sum(
-        dim=-1
-    )  # [B, teacher_seq_len]
-
-    # gradient accumulation fixes
-    if num_items_in_batch:
-        kd_loss = kd_loss_per_position.sum() / num_items_in_batch  # Scalar
+    # Compute forward KL:
+    # KL = sum_k p^T_k (log p^T_k - log p^S_k), summed over all valid tokens.
+    kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
+    kd_loss = kd_loss_per_token.sum()
+
+    # Normalize by number of items or mean over valid tokens
+    if num_items_in_batch is not None:
+        # If you know how many items should be considered in the batch
+        kd_loss = kd_loss / num_items_in_batch
     else:
-        kd_loss = kd_loss_per_position.mean()  # Scalar
+        # Otherwise, just average over all valid tokens
+        kd_loss = kd_loss / kd_loss_per_token.size(0)
 
     return kd_loss
 
 
@@ -70,6 +76,8 @@ def _set_signature_columns_if_needed(self):
                 columns_to_add.append("target_logprobs")
             if "target_token_ids" not in self._signature_columns:
                 columns_to_add.append("target_token_ids")
+            if "target_mask" not in self._signature_columns:
+                columns_to_add.append("target_mask")
             if columns_to_add:
                 self._signature_columns += columns_to_add
 
@@ -83,6 +91,7 @@ def compute_loss(
         """
         target_logprobs = inputs.pop("target_logprobs")
         target_token_ids = inputs.pop("target_token_ids")
+        target_mask = inputs.pop("target_mask")
 
         if self.model_accepts_loss_kwargs:
             loss_kwargs = {}
@@ -96,6 +105,7 @@ def compute_loss(
             student_logits,
             target_token_ids,
             target_logprobs,
+            target_mask,
             num_items_in_batch=num_items_in_batch,
         )
 
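
A minimal standalone sketch (illustrative only, not part of the patch) of the masked top-K forward KL that kd_loss_function now computes. The batch shape, the random teacher distribution, and the hand-made mask below are assumptions for the example:

# Illustrative sketch only: mirrors the masked top-K forward KL above on toy tensors.
import torch

B, seq_len, vocab_size, K = 1, 3, 10, 2
student_logits = torch.randn(B, seq_len, vocab_size)
target_token_ids = torch.randint(0, vocab_size, (B, seq_len, K))
# Teacher logprobs, already normalized over the top-K entries at each position
target_logprobs = torch.log_softmax(torch.randn(B, seq_len, K), dim=-1)
# First position is padding (mask 0), the rest are valid
target_mask = torch.tensor([[[0, 0], [1, 1], [1, 1]]])

# Student logprobs restricted to the teacher's top-K token ids
student_logits_topk = torch.gather(student_logits, dim=-1, index=target_token_ids)
student_logprobs_topk = student_logits_topk - torch.logsumexp(
    student_logits_topk, dim=-1, keepdim=True
)

# Keep only valid (position, k) entries, then form the forward-KL terms
valid = target_mask.bool()
teacher_probs = target_logprobs[valid].exp()
kl_terms = teacher_probs * (target_logprobs[valid] - student_logprobs_topk[valid])
print(kl_terms.sum() / kl_terms.size(0))  # mean over valid entries (no num_items_in_batch)
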
diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index 5780f8fe3..ea13f634b 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -490,8 +490,23 @@ def __init__(
 
     def transform_logprobs(self, sample):
         logprobs = sample.pop(self.logprobs_field)
+        target_seq_len = len(logprobs)
+        input_seq_len = len(sample["input_ids"])
+        padding_len = input_seq_len - target_seq_len
+        top_k = len(logprobs[0])
         target_logprobs = []
         target_token_ids = []
+        target_mask = []
+
+        # fill with -inf for padding_len tokens for top_k tokens
+        # extend target_logprobs with a padding_len x top_k 2D list filled with -inf
+        for _ in range(padding_len):
+            target_logprobs.append([-float("inf")] * top_k)
+            target_token_ids.append(list(range(top_k)))
+            target_mask.append([0] * top_k)
+
+        for _ in range(target_seq_len):
+            target_mask.append([1] * top_k)
 
         for _, token_pos_logprobs in enumerate(logprobs):
             # Initialize collections for logprobs and token_ids
@@ -519,6 +534,7 @@ def transform_logprobs(self, sample):
             # Apply temperature scaling at data load time
             # log p_k^(T) = (log p_k / T) - logsumexp(log p_j / T)
             position_logprobs_tensor = position_logprobs_tensor / self.temperature
+            # normalize to probabilities so they sum up to 1
             position_logprobs_tensor = position_logprobs_tensor - torch.logsumexp(
                 position_logprobs_tensor, dim=0, keepdim=True
             )
@@ -531,6 +547,7 @@ def transform_logprobs(self, sample):
         # Update sample with transformed logprobs
         sample["target_logprobs"] = target_logprobs
         sample["target_token_ids"] = target_token_ids
+        sample["target_mask"] = target_mask
 
         return sample
 
@@ -538,7 +555,9 @@ def tokenize_prompt(self, prompt):
         logprobs = prompt.pop(self.logprobs_field)
         tokenized_prompt = super().tokenize_prompt(prompt)
         tokenized_prompt[self.logprobs_field] = logprobs
-        return self.transform_logprobs(tokenized_prompt)
+        tokenized_prompt = self.transform_logprobs(tokenized_prompt)
+
+        return tokenized_prompt
 
 
 def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
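
A small sketch (illustrative only, not part of the patch) of what transform_logprobs does per position: temperature-scale the teacher's top-k logprobs, renormalize them so they sum to 1, and pad the rows for prompt tokens the teacher did not score while masking them out. The temperature, top_k, and logprob values below are made up:

# Illustrative sketch only: temperature scaling + renormalization as in transform_logprobs.
import torch

temperature = 1.5  # assumed value for the example
position_logprobs = torch.tensor([-0.1, -2.5, -3.0])  # made-up teacher top-k logprobs

scaled = position_logprobs / temperature
scaled = scaled - torch.logsumexp(scaled, dim=0, keepdim=True)
print(scaled.exp().sum())  # ~1.0: the top-k probabilities are renormalized to sum to 1

# Rows for tokens the teacher did not score get dummy entries and a zero mask
input_seq_len, target_seq_len, top_k = 5, 3, 3
padding_len = input_seq_len - target_seq_len
target_mask = [[0] * top_k for _ in range(padding_len)] + [
    [1] * top_k for _ in range(target_seq_len)
]
print(target_mask)  # the first padding_len rows are ignored by the KD loss
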
diff --git a/src/axolotl/utils/collators/kd.py b/src/axolotl/utils/collators/kd.py
index a210d221c..256dbe7a8 100644
--- a/src/axolotl/utils/collators/kd.py
+++ b/src/axolotl/utils/collators/kd.py
@@ -1,5 +1,6 @@
 """
-DataCollator for axolotl to handle KD fields
+DataCollator for axolotl to handle KD fields without using -inf for padding,
+and with a teacher_mask to identify padded positions.
 """
 
 from dataclasses import dataclass
@@ -17,6 +18,9 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
     """
     Data collator for KD, including handling KD-specific fields.
+
+    This version avoids using -inf and instead uses a large negative value for padding
+    target_logprobs. It also creates a teacher_mask to indicate which entries are valid.
     """
 
     tokenizer: PreTrainedTokenizerBase
@@ -32,7 +36,9 @@ def __call__(self, features, return_tensors=None):
         if return_tensors is None:
             return_tensors = self.return_tensors
 
-        # Extract labels and position_ids first (as in original code)
+        padding_side = self.tokenizer.padding_side
+
+        # Pad labels and position_ids first
         for feature_name, pad_token_id in [
             ("labels", self.label_pad_token_id),
             ("position_ids", self.position_pad_token_id),
@@ -46,7 +52,6 @@ def __call__(self, features, return_tensors=None):
                         // self.pad_to_multiple_of
                     ) * self.pad_to_multiple_of
 
-                padding_side = self.tokenizer.padding_side
                 for f in features:  # pylint: disable=invalid-name
                     remainder = [pad_token_id] * (max_len - len(f[feature_name]))
                     if isinstance(f[feature_name], list):
@@ -69,63 +74,77 @@ def __call__(self, features, return_tensors=None):
         # Handle target_logprobs and target_token_ids manually
         target_logprobs_list = []
         target_token_ids_list = []
+        target_mask_list = []
         has_teacher_data = ("target_logprobs" in features[0]) and (
             "target_token_ids" in features[0]
         )
 
         if has_teacher_data:
-            # Extract these fields
+            # Extract and remove from features
            for f in features:  # pylint: disable=invalid-name
                 target_logprobs_list.append(f.pop("target_logprobs"))
                 target_token_ids_list.append(f.pop("target_token_ids"))
+                target_mask_list.append(f.pop("target_mask"))
 
-            # Determine max lengths to pad
+            # Determine max lengths
             max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list)
             max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)
 
-            # Pad target_logprobs and target_token_ids
             padded_target_logprobs = []
             padded_target_token_ids = []
-            for t_logprobs, t_ids in zip(target_logprobs_list, target_token_ids_list):
-                # Pad seq dimension
+            padded_teacher_mask_list = []
+
+            for t_logprobs, t_ids, t_mask in zip(
+                target_logprobs_list, target_token_ids_list, target_mask_list
+            ):
                 t_logprobs_padded = []
                 t_ids_padded = []
-                for i in range(  # pylint: disable=consider-using-enumerate
-                    len(t_logprobs)
+                t_mask_padded = []
+
+                for lp, ids, mask in zip(  # pylint: disable=invalid-name
+                    t_logprobs, t_ids, t_mask
                 ):
-                    lp = t_logprobs[i]  # pylint: disable=invalid-name
-                    ids = t_ids[i]
-                    # Pad K dimension
                     lp_len = len(lp)
                     if lp_len < max_k:
-                        lp = lp + [-float("inf")] * (  # pylint: disable=invalid-name
-                            max_k - lp_len
-                        )  # or some pad value that won't break exp()
-                        ids = ids + [0] * (max_k - lp_len)
+                        # Use -1e9 for padding logprobs and 0 for token_ids
+                        pad_len = max_k - lp_len
+                        lp = lp + [-1e9] * pad_len  # pylint: disable=invalid-name
+                        ids = ids + [0] * pad_len
+                        mask = mask + [0] * pad_len
+                    else:
+                        lp = lp[:max_k]  # pylint: disable=invalid-name
+                        ids = ids[:max_k]
+                        mask = mask[:max_k]
+
                     t_logprobs_padded.append(lp)
                     t_ids_padded.append(ids)
+                    t_mask_padded.append(mask)
 
-                # If sequence is shorter than max_teacher_seq_len
                 seq_len_diff = max_teacher_seq_len - len(t_logprobs_padded)
                 if seq_len_diff > 0:
+                    # Pad sequences fully if needed
                     t_logprobs_padded.extend(
-                        [[-float("inf")] * max_k for _ in range(seq_len_diff)]
+                        [[-1e9] * max_k for _ in range(seq_len_diff)]
                     )
                     t_ids_padded.extend([[0] * max_k for _ in range(seq_len_diff)])
+                    t_mask_padded.extend([[0] * max_k for _ in range(seq_len_diff)])
 
                 padded_target_logprobs.append(t_logprobs_padded)
                 padded_target_token_ids.append(t_ids_padded)
+                padded_teacher_mask_list.append(t_mask_padded)
 
             # Convert to tensors
             padded_target_logprobs = torch.tensor(
                 padded_target_logprobs, dtype=torch.float
             )
-            # We can store token_ids as long tensor
             padded_target_token_ids = torch.tensor(
                 padded_target_token_ids, dtype=torch.long
             )
+            padded_teacher_mask_list = torch.tensor(
+                padded_teacher_mask_list, dtype=torch.int
+            )
 
-        # Now pad using tokenizer for the remaining fields (input_ids, attention_mask, etc.)
+        # Pad using tokenizer for regular fields
         features = self.tokenizer.pad(
             features,
             padding=self.padding,
@@ -134,12 +153,13 @@ def __call__(self, features, return_tensors=None):
             return_tensors=return_tensors,
         )
 
-        # Add back the teacher data if it exists
+        # Add back teacher data if present
         if has_teacher_data:
             features["target_logprobs"] = padded_target_logprobs
             features["target_token_ids"] = padded_target_token_ids
+            features["target_mask"] = padded_teacher_mask_list
 
-        # Prepare decoder_input_ids if applicable
+        # Prepare decoder_input_ids if the model supports it
         if (
             "labels" in features
             and self.model is not None
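
To round out the picture, a stripped-down sketch (illustrative only, not part of the patch) of how the collator pads the teacher fields across a batch with ragged teacher lengths: a large negative value instead of -inf for padded logprobs, and mask 0 for the padded entries. The batch contents are invented:

# Illustrative sketch only: teacher-field padding in the spirit of DataCollatorForKD.
import torch

batch_logprobs = [
    [[-0.1, -2.3], [-0.5, -1.2]],  # sample with 2 scored positions, K=2
    [[-0.7, -0.9]],                # sample with 1 scored position
]
batch_mask = [[[1, 1], [1, 1]], [[1, 1]]]

max_seq = max(len(seq) for seq in batch_logprobs)
max_k = max(len(row) for seq in batch_logprobs for row in seq)

for seq, mask in zip(batch_logprobs, batch_mask):
    pad_rows = max_seq - len(seq)
    seq.extend([[-1e9] * max_k for _ in range(pad_rows)])  # large negative, not -inf
    mask.extend([[0] * max_k for _ in range(pad_rows)])    # padded rows are masked out

target_logprobs = torch.tensor(batch_logprobs, dtype=torch.float)
target_mask = torch.tensor(batch_mask, dtype=torch.int)
print(target_logprobs.shape, int(target_mask.sum()))  # torch.Size([2, 2, 2]) 6
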