Commit 6ec2677: abstraction WIP

bursteratom committed Dec 21, 2024
1 parent b771f30
Showing 1 changed file with 206 additions and 0 deletions.
src/axolotl/utils/collators/processing_strategies.py
@@ -0,0 +1,206 @@
from copy import deepcopy

from PIL import Image
from transformers import ProcessorMixin

class ProcessingStrategies:
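    """Shared strategies for preprocessing multimodal chat examples and collating them into batches."""
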
def __init__(self, processor: ProcessorMixin):
self.processor = processor

@staticmethod
def preprocess(examples: list[dict]) -> list[dict]:
"""
Preprocess conversation examples to ensure consistent format.
Converts different conversation formats to OpenAI format with 'messages'.
Supports two formats:
1. OpenAI format with 'messages'
2. Legacy format with 'conversations'
Args:
examples: list of conversation dictionaries
Returns:
            list of dicts in OpenAI format with a 'messages' key
Raises:
ValueError: If the conversation format is not supported
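        Example (legacy 'conversations' converted to 'messages'):
            >>> ProcessingStrategies.preprocess(
            ...     [{"conversations": [{"from": "human", "value": "hi"}]}]
            ... )
            [{'messages': [{'role': 'user', 'content': 'hi'}]}]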
"""
role_mapping = {
"human": "user",
"gpt": "assistant",
}

def normalize_role(role: str) -> str:
"""Normalize role names to OpenAI format. Default to original role if not found."""
return role_mapping.get(role, role)

def convert_legacy_format(example: dict) -> dict:
"""Convert legacy 'conversations' format to OpenAI 'messages' format."""
messages = [
{
"role": normalize_role(convo["from"]),
"content": convo["value"],
}
for convo in example["conversations"]
]

# Create new dict without 'conversations' key
result = deepcopy(example)
result.pop("conversations")
result["messages"] = messages
return result

processed_examples = []
for example in examples:
# OpenAI format
if "messages" in example:
processed_examples.append(example)

# Legacy format
elif "conversations" in example:
processed_examples.append(convert_legacy_format(example))

else:
raise ValueError(
"Only `messages` and `conversations` message keys are currently supported."
)

return processed_examples

@staticmethod
def process_images(examples, max_images):
"""
Process images from examples, ensuring consistency in image presence and applying max_images limit.
Args:
examples: List of dictionaries that may contain 'images' key
max_images: Maximum number of images to keep per example (0 means no limit)
Returns:
Either None (if no images) or List[Image objects] (if all examples have images)
Raises:
ValueError: If there's a mix of None and non-None images
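        Example (list-valued images pass through, truncated to max_images):
            >>> ProcessingStrategies.process_images(
            ...     [{"images": ["a.png", "b.png", "c.png"]}], max_images=2
            ... )
            [['a.png', 'b.png']]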
"""

        def get_image(example):
            if "images" not in example:
                return None
            images = example["images"]
            # A single path is opened eagerly; lists and PIL images pass through as-is
            if isinstance(images, str):
                return Image.open(images)
            return images

images = [get_image(example) for example in examples]

# Count None and non-None images
none_count = sum(1 for img in images if img is None)

# All images are None
if none_count == len(images):
return None

# Mix of None and non-None images
if none_count > 0:
raise ValueError(
"All images should be either None or not None. "
"Please provide images for all examples or None."
)

# Apply max_images limit if specified
if max_images > 0:
images = [
(
img_batch[:max_images]
if isinstance(img_batch, (list, tuple))
else img_batch
)
for img_batch in images
]

return images

    @staticmethod
    def process_texts(examples, processor, chat_template):
        """Render each example's messages with the chat template, without tokenizing."""
        texts = [
            processor.apply_chat_template(
                example["messages"], chat_template=chat_template, tokenize=False
            )
            for example in examples
        ]
        return texts

@staticmethod
def process_rows(
examples,
processor,
chat_template,
max_images,
length_only=False,
chat_template_type=None,
):
# HINT: use `_torch_collate_batch` to stack and pad tensors
# see also DataCollatorWithFlattening and DefaultDataCollator

# *** This is COPIED from the trl example sft_vlm.py code ***
# use this as a starting point

# Preprocess the examples
examples = __class__.preprocess(examples)

        # Apply the chat template to the texts and gather the images
        texts = __class__.process_texts(examples, processor, chat_template)

        images = __class__.process_images(examples, max_images=max_images)
        if images is not None and chat_template_type == "llava":
            # LLaVA-1.5 does not support multiple images; keep only the first
            images = [image[0] for image in images]

# Tokenize the texts and process the images
batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

# The labels are the input_ids, and we mask the padding tokens in the loss computation
labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100
# Ignore the image token index in the loss computation (model specific)
if chat_template_type == "qwen2_vl":
image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
else:
image_token_id = processor.tokenizer.convert_tokens_to_ids(
processor.image_token
)
labels[labels == image_token_id] = -100
batch["labels"] = labels

        if length_only:
            return {"length": [len(sample) for sample in batch["input_ids"]]}
return batch
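
# Minimal usage sketch, assuming a HF processor and template are already loaded
# (the model id and variable names below are illustrative, not part of this diff):
#
#     from transformers import AutoProcessor
#
#     processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
#     batch = ProcessingStrategies.process_rows(
#         examples,
#         processor=processor,
#         chat_template=processor.chat_template,
#         max_images=1,
#         chat_template_type="llava",
#     )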

class PixtralProcessingStrategies(ProcessingStrategies):

@staticmethod
def pixtral_chat_conversion(messages):
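        """
        Convert OpenAI-style messages to the schema Pixtral's chat template
        expects: user text chunks use a ``content`` key instead of ``text``,
        and assistant content is flattened to a plain string. Accepts either a
        single message dict or a list of them.

        Example:
            >>> PixtralProcessingStrategies.pixtral_chat_conversion(
            ...     {"role": "assistant", "content": [{"type": "text", "text": "hi"}]}
            ... )
            {'role': 'assistant', 'content': 'hi'}
        """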
is_single_message = not isinstance(messages, list)
if is_single_message:
messages = [messages]

for i, message in enumerate(messages):
if message["role"] == "user":
for j, content in enumerate(message["content"]):
if "type" in content and content["type"] == "text":
messages[i]["content"][j] = {
"type": "text",
"content": content["text"],
}

if message["role"] == "assistant":
messages[i]["content"] = message["content"][0]["text"]

if is_single_message:
return messages[0]
return messages
