Commit

make it work
winglian committed Dec 19, 2024
1 parent 45d708a commit 8606ff4
Showing 3 changed files with 94 additions and 45 deletions.
52 changes: 31 additions & 21 deletions src/axolotl/core/trainers/kd.py
@@ -13,46 +13,52 @@ def kd_loss_function(
student_logits,
target_token_ids,
target_logprobs,
target_mask,
num_items_in_batch: Optional[int] = None,
**kwargs, # pylint: disable=unused-argument
):
# student_logits: [B, seq_len, vocab_size] from the student's forward pass
# target_token_ids: [B, teacher_seq_len, K] top-K token IDs from teacher
# target_logprobs: [B, teacher_seq_len, K] teacher logprobs for these top-K tokens
# target_mask: [B, teacher_seq_len, K], where 1 indicates a valid token and 0 indicates padding

# Determine the teacher sequence length
teacher_seq_len = target_token_ids.shape[1]

# Slice the student logits to match the teacher-provided seq length
# Slice student logits to match the teacher-provided sequence length
student_logits_for_kd = student_logits[
:, -teacher_seq_len:, :
] # Now [B, teacher_seq_len, vocab_size]
] # [B, teacher_seq_len, vocab_size]

# Gather student logits for teacher's top-K tokens
student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, teacher_seq_len, K]

# Convert student top-K logits to logprobs
# Convert student top-k logits to logprobs
student_logprobs_topk = student_logits_topk - torch.logsumexp(
student_logits_topk, dim=-1, keepdim=True
)
) # [B, seq_len, K]

# teacher_probs are simply exp of teacher_logprobs (already scaled)
# Convert teacher_mask to boolean for indexing
valid_mask = target_mask.bool()

# Prune tensors to only keep valid tokens
# This will result in 1D arrays of only valid positions
student_logprobs_topk = student_logprobs_topk[valid_mask] # [N_valid_tokens]
target_logprobs = target_logprobs[valid_mask] # [N_valid_tokens]

# Since teacher_logprobs are already normalized, just exponentiate to get probabilities
teacher_probs = target_logprobs.exp()

# Compute forward KL
# L_kl = sum_k p^T_k (log p^T_k - log p^S_k)
kd_loss_per_position = (
teacher_probs * (target_logprobs - student_logprobs_topk)
).sum(
dim=-1
) # [B, teacher_seq_len]

# gradient accumulation fixes
if num_items_in_batch:
kd_loss = kd_loss_per_position.sum() / num_items_in_batch # Scalar
# Compute forward KL:
# KL = sum p^T_k (log p^T_k - log p^S_k), summed over all valid tokens.
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
kd_loss = kd_loss_per_token.sum()

# Normalize by number of items or mean over valid tokens
if num_items_in_batch is not None:
# If you know how many items should be considered in the batch
kd_loss = kd_loss / num_items_in_batch
else:
kd_loss = kd_loss_per_position.mean() # Scalar
# Otherwise, just average over all valid tokens
kd_loss = kd_loss / kd_loss_per_token.size(0)

return kd_loss

@@ -70,6 +76,8 @@ def _set_signature_columns_if_needed(self):
columns_to_add.append("target_logprobs")
if "target_token_ids" not in self._signature_columns:
columns_to_add.append("target_token_ids")
if "target_mask" not in self._signature_columns:
columns_to_add.append("target_mask")
if columns_to_add:
self._signature_columns += columns_to_add

@@ -83,6 +91,7 @@ def compute_loss(
"""
target_logprobs = inputs.pop("target_logprobs")
target_token_ids = inputs.pop("target_token_ids")
target_mask = inputs.pop("target_mask")

if self.model_accepts_loss_kwargs:
loss_kwargs = {}
Expand All @@ -96,6 +105,7 @@ def compute_loss(
student_logits,
target_token_ids,
target_logprobs,
target_mask,
num_items_in_batch=num_items_in_batch,
)

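Putting the pieces of this file together: the new loss is a forward KL over the teacher's stored top-K tokens, restricted by target_mask to valid positions. Below is a minimal, self-contained sketch of that computation (assuming PyTorch; the shapes and smoke-test values are made up, and this is an illustration rather than the trainer's exact code):

import torch

def masked_topk_forward_kl(student_logits, target_token_ids, target_logprobs,
                           target_mask, num_items_in_batch=None):
    # student_logits: [B, seq_len, V]; teacher tensors: [B, teacher_seq_len, K]
    teacher_seq_len = target_token_ids.shape[1]
    # Align student positions with the teacher-provided positions
    student_logits_for_kd = student_logits[:, -teacher_seq_len:, :]
    # Student logprobs over the teacher's top-K token ids (normalized within the top-K subset)
    student_topk = torch.gather(student_logits_for_kd, dim=-1, index=target_token_ids)
    student_logprobs_topk = student_topk - torch.logsumexp(student_topk, dim=-1, keepdim=True)
    # Keep only valid (unpadded) entries
    valid = target_mask.bool()
    s = student_logprobs_topk[valid]
    t = target_logprobs[valid]
    # Forward KL: sum_k p_T(k) * (log p_T(k) - log p_S(k))
    kl = (t.exp() * (t - s)).sum()
    if num_items_in_batch is not None:
        return kl / num_items_in_batch
    return kl / t.size(0)  # mean over valid (token, k) entries

# Made-up shapes for a quick smoke test
B, S, V, K = 2, 5, 32, 4
student_logits = torch.randn(B, S, V)
target_token_ids = torch.randint(0, V, (B, S, K))
raw = torch.randn(B, S, K)
target_logprobs = raw - torch.logsumexp(raw, dim=-1, keepdim=True)
target_mask = torch.ones(B, S, K, dtype=torch.int)
print(masked_topk_forward_kl(student_logits, target_token_ids, target_logprobs, target_mask))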
21 changes: 20 additions & 1 deletion src/axolotl/prompt_strategies/chat_template.py
@@ -490,8 +490,23 @@ def __init__(

def transform_logprobs(self, sample):
logprobs = sample.pop(self.logprobs_field)
target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"])
padding_len = input_seq_len - target_seq_len
top_k = len(logprobs[0])
target_logprobs = []
target_token_ids = []
target_mask = []

# fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
for _ in range(padding_len):
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)

for _ in range(target_seq_len):
target_mask.append([1] * top_k)

for _, token_pos_logprobs in enumerate(logprobs):
# Initialize collections for logprobs and token_ids
@@ -519,6 +534,7 @@ def transform_logprobs(self, sample):
# Apply temperature scaling at data load time
# log p_k^(T) = (log p_k / T) - logsumexp(log p_j / T)
position_logprobs_tensor = position_logprobs_tensor / self.temperature
# normalize to probabilities so they sum up to 1
position_logprobs_tensor = position_logprobs_tensor - torch.logsumexp(
position_logprobs_tensor, dim=0, keepdim=True
)
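The two lines above implement the comment's formula, log p_k^(T) = log p_k / T - logsumexp_j(log p_j / T): dividing by the temperature softens the distribution, and subtracting the logsumexp renormalizes it over the stored top-K entries. A tiny worked example with made-up numbers (assuming PyTorch):

import torch

temperature = 2.0
position_logprobs = torch.tensor([-0.2231, -1.6094])  # made-up top-2 teacher logprobs (~log 0.8, ~log 0.2)

scaled = position_logprobs / temperature
scaled = scaled - torch.logsumexp(scaled, dim=0, keepdim=True)

print(scaled.exp())        # ~tensor([0.6667, 0.3333]): softer than the original 0.8 / 0.2 split
print(scaled.exp().sum())  # ~1.0: still a proper distribution over the top-K tokens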
@@ -531,14 +547,17 @@ def transform_logprobs(self, sample):
# Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs
sample["target_token_ids"] = target_token_ids
sample["target_mask"] = target_mask

return sample

def tokenize_prompt(self, prompt):
logprobs = prompt.pop(self.logprobs_field)
tokenized_prompt = super().tokenize_prompt(prompt)
tokenized_prompt[self.logprobs_field] = logprobs
return self.transform_logprobs(tokenized_prompt)
tokenized_prompt = self.transform_logprobs(tokenized_prompt)

return tokenized_prompt


def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
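For context on the new fields this strategy emits: positions with no teacher logprobs (the padding region at the front of input_ids) get -inf logprob rows, placeholder token ids, and a zero mask, while real positions get a mask of ones. A small sketch of that layout with made-up lengths and values (not the exact prompt-strategy code):

input_seq_len, target_seq_len, top_k = 5, 3, 2   # made-up lengths
padding_len = input_seq_len - target_seq_len     # tokens with no teacher logprobs

target_logprobs, target_token_ids, target_mask = [], [], []

# Front padding: no teacher signal, so -inf logprobs, placeholder token ids, mask 0
for _ in range(padding_len):
    target_logprobs.append([-float("inf")] * top_k)
    target_token_ids.append(list(range(top_k)))
    target_mask.append([0] * top_k)

# Real positions: (already normalized) teacher logprobs, their token ids, mask 1
teacher_rows = [[-0.4, -1.1], [-0.5, -0.9], [-0.2, -1.7]]    # made-up logprobs
teacher_ids = [[11, 42], [7, 3], [99, 12]]                   # made-up token ids
for row, ids in zip(teacher_rows, teacher_ids):
    target_logprobs.append(row)
    target_token_ids.append(ids)
    target_mask.append([1] * top_k)

assert len(target_logprobs) == len(target_token_ids) == len(target_mask) == input_seq_len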
66 changes: 43 additions & 23 deletions src/axolotl/utils/collators/kd.py
@@ -1,5 +1,6 @@
"""
DataCollator for axolotl to handle KD fields
DataCollator for axolotl to handle KD fields without using -inf for padding,
and with a teacher_mask to identify padded positions.
"""

from dataclasses import dataclass
@@ -17,6 +18,9 @@
class DataCollatorForKD(DataCollatorForSeq2Seq):
"""
Data collator for KD, including handling KD-specific fields.
This version avoids using -inf and instead uses a large negative value for padding
target_logprobs. It also creates a teacher_mask to indicate which entries are valid.
"""

tokenizer: PreTrainedTokenizerBase
@@ -32,7 +36,9 @@ def __call__(self, features, return_tensors=None):
if return_tensors is None:
return_tensors = self.return_tensors

# Extract labels and position_ids first (as in original code)
padding_side = self.tokenizer.padding_side

# Pad labels and position_ids first
for feature_name, pad_token_id in [
("labels", self.label_pad_token_id),
("position_ids", self.position_pad_token_id),
@@ -46,7 +52,6 @@ def __call__(self, features, return_tensors=None):
// self.pad_to_multiple_of
) * self.pad_to_multiple_of

padding_side = self.tokenizer.padding_side
for f in features: # pylint: disable=invalid-name
remainder = [pad_token_id] * (max_len - len(f[feature_name]))
if isinstance(f[feature_name], list):
@@ -69,63 +74,77 @@ def __call__(self, features, return_tensors=None):
# Handle target_logprobs and target_token_ids manually
target_logprobs_list = []
target_token_ids_list = []
target_mask_list = []
has_teacher_data = ("target_logprobs" in features[0]) and (
"target_token_ids" in features[0]
)

if has_teacher_data:
# Extract these fields
# Extract and remove from features
for f in features: # pylint: disable=invalid-name
target_logprobs_list.append(f.pop("target_logprobs"))
target_token_ids_list.append(f.pop("target_token_ids"))
target_mask_list.append(f.pop("target_mask"))

# Determine max lengths to pad
# Determine max lengths
max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list)
max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)

# Pad target_logprobs and target_token_ids
padded_target_logprobs = []
padded_target_token_ids = []
for t_logprobs, t_ids in zip(target_logprobs_list, target_token_ids_list):
# Pad seq dimension
padded_teacher_mask_list = []

for t_logprobs, t_ids, t_mask in zip(
target_logprobs_list, target_token_ids_list, target_mask_list
):
t_logprobs_padded = []
t_ids_padded = []
for i in range( # pylint: disable=consider-using-enumerate
len(t_logprobs)
t_mask_padded = []

for lp, ids, mask in zip( # pylint: disable=invalid-name
t_logprobs, t_ids, t_mask
):
lp = t_logprobs[i] # pylint: disable=invalid-name
ids = t_ids[i]
# Pad K dimension
lp_len = len(lp)
if lp_len < max_k:
lp = lp + [-float("inf")] * ( # pylint: disable=invalid-name
max_k - lp_len
) # or some pad value that won't break exp()
ids = ids + [0] * (max_k - lp_len)
# Use -1e9 for padding logprobs and 0 for token_ids
pad_len = max_k - lp_len
lp = lp + [-1e9] * pad_len # pylint: disable=invalid-name
ids = ids + [0] * pad_len
mask = mask + [0] * pad_len
else:
lp = lp[:max_k] # pylint: disable=invalid-name
ids = ids[:max_k]
mask = mask[:max_k]

t_logprobs_padded.append(lp)
t_ids_padded.append(ids)
t_mask_padded.append(mask)

# If sequence is shorter than max_teacher_seq_len
seq_len_diff = max_teacher_seq_len - len(t_logprobs_padded)
if seq_len_diff > 0:
# Pad sequences fully if needed
t_logprobs_padded.extend(
[[-float("inf")] * max_k for _ in range(seq_len_diff)]
[[-1e9] * max_k for _ in range(seq_len_diff)]
)
t_ids_padded.extend([[0] * max_k for _ in range(seq_len_diff)])
t_mask_padded.extend([[0] * max_k for _ in range(seq_len_diff)])

padded_target_logprobs.append(t_logprobs_padded)
padded_target_token_ids.append(t_ids_padded)
padded_teacher_mask_list.append(t_mask_padded)

# Convert to tensors
padded_target_logprobs = torch.tensor(
padded_target_logprobs, dtype=torch.float
)
# We can store token_ids as long tensor
padded_target_token_ids = torch.tensor(
padded_target_token_ids, dtype=torch.long
)
padded_teacher_mask_list = torch.tensor(
padded_teacher_mask_list, dtype=torch.int
)

# Now pad using tokenizer for the remaining fields (input_ids, attention_mask, etc.)
# Pad using tokenizer for regular fields
features = self.tokenizer.pad(
features,
padding=self.padding,
@@ -134,12 +153,13 @@ def __call__(self, features, return_tensors=None):
return_tensors=return_tensors,
)

# Add back the teacher data if it exists
# Add back teacher data if present
if has_teacher_data:
features["target_logprobs"] = padded_target_logprobs
features["target_token_ids"] = padded_target_token_ids
features["target_mask"] = padded_teacher_mask_list

# Prepare decoder_input_ids if applicable
# Prepare decoder_input_ids if the model supports it
if (
"labels" in features
and self.model is not None
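To summarize what the updated collator hands to the trainer: each example's teacher fields are padded to the batch's max sequence length and max K, padded logprobs use the large negative value -1e9 (so exp() stays finite, unlike -inf), padded token ids use 0, and target_mask is 0 exactly on those padded entries. A condensed sketch with made-up features (assuming PyTorch; the real collator also pads labels, position_ids, and the regular tokenizer fields):

import torch

# Two made-up examples with different teacher sequence lengths and K
features = [
    {"target_logprobs": [[-0.1, -2.0]],   "target_token_ids": [[5, 9]],   "target_mask": [[1, 1]]},
    {"target_logprobs": [[-0.3], [-0.7]], "target_token_ids": [[4], [8]], "target_mask": [[1], [1]]},
]

max_seq = max(len(f["target_logprobs"]) for f in features)
max_k = max(len(row) for f in features for row in f["target_logprobs"])

batch_lp, batch_ids, batch_mask = [], [], []
for f in features:
    lps, ids, masks = [], [], []
    for lp, tid, m in zip(f["target_logprobs"], f["target_token_ids"], f["target_mask"]):
        pad = max_k - len(lp)
        lps.append(lp + [-1e9] * pad)      # large negative instead of -inf
        ids.append(tid + [0] * pad)
        masks.append(m + [0] * pad)
    while len(lps) < max_seq:              # pad out missing positions
        lps.append([-1e9] * max_k)
        ids.append([0] * max_k)
        masks.append([0] * max_k)
    batch_lp.append(lps)
    batch_ids.append(ids)
    batch_mask.append(masks)

target_logprobs = torch.tensor(batch_lp, dtype=torch.float)    # [2, 2, 2]
target_token_ids = torch.tensor(batch_ids, dtype=torch.long)   # [2, 2, 2]
target_mask = torch.tensor(batch_mask, dtype=torch.int)        # 4 valid entries in total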
