1 parent b771f30 · commit 6ec2677. Showing 1 changed file with 206 additions and 0 deletions.
@@ -0,0 +1,206 @@
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Optional, Union

from PIL import Image
from transformers import PreTrainedTokenizerBase, ProcessorMixin
from transformers.data.data_collator import DataCollatorMixin
from transformers.utils import PaddingStrategy


class ProcessingStrategies:
    def __init__(self, processor: ProcessorMixin):
        self.processor = processor

    @staticmethod
    def preprocess(examples: list[dict]) -> list[dict]:
        """
        Preprocess conversation examples to ensure a consistent format.

        Converts supported conversation formats to the OpenAI format with a
        'messages' key. Two formats are supported:

        1. OpenAI format with 'messages'
        2. Legacy format with 'conversations'

        Args:
            examples: List of conversation dictionaries.

        Returns:
            List of dicts in OpenAI format with a 'messages' key.

        Raises:
            ValueError: If the conversation format is not supported.
        """
        role_mapping = {
            "human": "user",
            "gpt": "assistant",
        }

        def normalize_role(role: str) -> str:
            """Normalize role names to OpenAI format. Default to the original role if not found."""
            return role_mapping.get(role, role)

        def convert_legacy_format(example: dict) -> dict:
            """Convert legacy 'conversations' format to OpenAI 'messages' format."""
            messages = [
                {
                    "role": normalize_role(convo["from"]),
                    "content": convo["value"],
                }
                for convo in example["conversations"]
            ]

            # Create a new dict without the 'conversations' key
            result = deepcopy(example)
            result.pop("conversations")
            result["messages"] = messages
            return result

        processed_examples = []
        for example in examples:
            # OpenAI format
            if "messages" in example:
                processed_examples.append(example)

            # Legacy format
            elif "conversations" in example:
                processed_examples.append(convert_legacy_format(example))

            else:
                raise ValueError(
                    "Only `messages` and `conversations` message keys are currently supported."
                )

        return processed_examples
    @staticmethod
    def process_images(examples, max_images):
        """
        Process images from examples, ensuring consistent image presence and
        applying the max_images limit.

        Args:
            examples: List of dictionaries that may contain an 'images' key.
            max_images: Maximum number of images to keep per example (0 means no limit).

        Returns:
            Either None (if no example has images) or a list of image objects
            (if all examples have images).

        Raises:
            ValueError: If there is a mix of None and non-None images.
        """
        def get_image(example):
            if "images" not in example:
                return None
            images = example["images"]
            if isinstance(images, str):
                return Image.open(images)
            return images

        images = [get_image(example) for example in examples]

        # Count examples without images
        none_count = sum(1 for img in images if img is None)

        # No example has images
        if none_count == len(images):
            return None

        # Mix of None and non-None images
        if none_count > 0:
            raise ValueError(
                "All images should be either None or not None. "
                "Please provide images for all examples or None."
            )

        # Apply the max_images limit if specified
        if max_images > 0:
            images = [
                (
                    img_batch[:max_images]
                    if isinstance(img_batch, (list, tuple))
                    else img_batch
                )
                for img_batch in images
            ]

        return images
    @staticmethod
    def process_texts(examples, processor, chat_template):
        """Apply the chat template to each example's messages without tokenizing."""
        texts = [
            processor.apply_chat_template(
                example["messages"], chat_template=chat_template, tokenize=False
            )
            for example in examples
        ]
        return texts

    @classmethod
    def process_rows(
        cls,
        examples,
        processor,
        chat_template,
        max_images,
        length_only=False,
        chat_template_type=None,
    ):
        # HINT: use `_torch_collate_batch` to stack and pad tensors;
        # see also DataCollatorWithFlattening and DefaultDataCollator.

        # *** This is COPIED from the trl example sft_vlm.py code ***
        # Use this as a starting point.

        # Preprocess the examples
        examples = cls.preprocess(examples)

        # Get the texts and images, and apply the chat template
        texts = cls.process_texts(examples, processor, chat_template)
        images = cls.process_images(examples, max_images=max_images)

        if chat_template_type == "llava" and images is not None:
            # LLaVA-1.5 does not support multiple images, so keep only the first
            images = [image[0] for image in images]
        # Tokenize the texts and process the images
        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

        # The labels are the input_ids; we mask the padding tokens in the loss computation
        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100
        # Ignore the image token index in the loss computation (model specific)
        if chat_template_type == "qwen2_vl":
            image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
        else:
            image_token_id = processor.tokenizer.convert_tokens_to_ids(
                processor.image_token
            )
        labels[labels == image_token_id] = -100
        batch["labels"] = labels
        if length_only:
            # Each row of input_ids is already a padded 1-D tensor of token ids
            return {"length": [len(sample) for sample in batch["input_ids"]]}
        return batch


class PixtralProcessingStrategies(ProcessingStrategies):
    @staticmethod
    def pixtral_chat_conversion(messages):
        """
        Convert OpenAI-style messages for the Pixtral chat template: user text
        blocks use a 'content' key instead of 'text', and assistant content is
        flattened to a plain string.
        """
        is_single_message = not isinstance(messages, list)
        if is_single_message:
            messages = [messages]
        for i, message in enumerate(messages):
            if message["role"] == "user":
                for j, content in enumerate(message["content"]):
                    if "type" in content and content["type"] == "text":
                        messages[i]["content"][j] = {
                            "type": "text",
                            "content": content["text"],
                        }

            if message["role"] == "assistant":
                messages[i]["content"] = message["content"][0]["text"]

        if is_single_message:
            return messages[0]
        return messages
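
A minimal usage sketch, not part of the commit: wiring `process_rows` into a collate function. The checkpoint name, placeholder image, and message text are illustrative assumptions, and it presumes a processor whose chat template accepts typed content blocks (as the trl sft_vlm example format does).

from functools import partial

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")  # assumed checkpoint
collate_fn = partial(
    ProcessingStrategies.process_rows,
    processor=processor,
    chat_template=None,  # fall back to the processor's built-in template
    max_images=0,
    chat_template_type="llava",
)
example = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Describe the image."},
            ],
        }
    ],
    "images": [Image.new("RGB", (64, 64))],  # placeholder image
}
batch = collate_fn([example])
print(batch["input_ids"].shape, batch["labels"].shape)

And a small illustration of what `pixtral_chat_conversion` produces, following the loop above directly:

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What is in this image?"},
        ],
    },
    {"role": "assistant", "content": [{"type": "text", "text": "A cat."}]},
]
converted = PixtralProcessingStrategies.pixtral_chat_conversion(messages)
# converted[0]["content"][1] == {"type": "text", "content": "What is in this image?"}
# converted[1]["content"] == "A cat."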