Commit: higher level abstraction wip

bursteratom committed Dec 22, 2024
1 parent 6ec2677 commit 76436ad
Showing 4 changed files with 276 additions and 220 deletions.
src/axolotl/core/trainer_builder.py (6 changes: 3 additions & 3 deletions)
```diff
@@ -75,6 +75,7 @@
     V2BatchSamplerDataCollatorForSeq2Seq,
 )
 from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
+from axolotl.utils.collators.mm_processing_strategies import get_processing_strategy
 from axolotl.utils.models import ensure_dtype
 from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.schedulers import (
@@ -2011,10 +2012,9 @@ def build_collator(
             collator = BatchSamplerDataCollatorForSeq2Seq
         else:
             if self.cfg.processor_type and self.processor:
 
                 collator = MultiModalChatDataCollator
-                kwargs["processor"] = self.processor
-                kwargs["chat_template"] = training_args.chat_template
-                kwargs["chat_template_type"] = self.cfg.chat_template
+                kwargs["processing_strategy"] = get_processing_strategy(self.processor, training_args.chat_template, self.cfg.chat_template)
             elif self.cfg.batch_flattening:
                 collator = DataCollatorWithFlattening
+                collator_args.pop(0)
```
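The builder no longer threads `processor`, `chat_template`, and `chat_template_type` through as separate collator kwargs; it resolves a single strategy object once, at build time. Below is a minimal sketch of what `get_processing_strategy` plausibly does, assuming it dispatches on the chat template type; only the call signature is visible in this diff, so the subclass names are hypothetical stand-ins for whatever `mm_processing_strategies.py` actually defines (see the sketch after the mm_chat.py diff below for the behavior they would override):

```python
from transformers import ProcessorMixin


class ProcessingStrategy:
    """Base strategy: pairs a processor with model-specific chat handling."""

    def __init__(self, processor: ProcessorMixin, chat_template=None):
        self.processor = processor
        self.chat_template = chat_template


# Hypothetical subclasses, one per model family that mm_chat.py used to
# special-case ("pixtral", "llava", "qwen2_vl").
class PixtralProcessingStrategy(ProcessingStrategy): ...
class LlavaProcessingStrategy(ProcessingStrategy): ...
class Qwen2VLProcessingStrategy(ProcessingStrategy): ...


def get_processing_strategy(processor, chat_template, chat_template_type):
    """Pick the model-specific strategy once instead of branching per batch."""
    strategies = {
        "pixtral": PixtralProcessingStrategy,
        "llava": LlavaProcessingStrategy,
        "qwen2_vl": Qwen2VLProcessingStrategy,
    }
    return strategies.get(chat_template_type, ProcessingStrategy)(
        processor, chat_template
    )
```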
src/axolotl/utils/collators/mm_chat.py (188 changes: 11 additions & 177 deletions)
```diff
@@ -2,14 +2,13 @@
 Collators for multi-modal chat messages and packing
 """
 
-from copy import deepcopy
 from dataclasses import dataclass
 from typing import Any, Optional, Union
 
-from PIL import Image
-from transformers import PreTrainedTokenizerBase, ProcessorMixin
+from transformers import PreTrainedTokenizerBase
 from transformers.data.data_collator import DataCollatorMixin
 from transformers.utils import PaddingStrategy
+from .mm_processing_strategies import ProcessingStrategy
 
 
 @dataclass
@@ -19,10 +18,8 @@ class MultiModalChatDataCollator(DataCollatorMixin):
     """
 
     tokenizer: PreTrainedTokenizerBase
-    processor: ProcessorMixin
+    processing_strategy: ProcessingStrategy
     return_tensors: str = "pt"
-    chat_template: Optional[str] = None
-    chat_template_type: Optional[str] = None
     packing: bool = False
     max_images: int = -1
     padding: Union[bool, str, PaddingStrategy] = True
@@ -38,154 +35,16 @@ def torch_call(
         # Handle dict or lists with proper padding and conversion to tensor.
         return self.__class__.process_rows(
             examples,
-            self.processor,
-            self.chat_template,
+            self.processing_strategy,
             self.max_images,
-            chat_template_type=self.chat_template_type,
         )
 
-    @staticmethod
-    def preprocess(examples: list[dict]) -> list[dict]:
-        """
-        Preprocess conversation examples to ensure consistent format.
-        Converts different conversation formats to OpenAI format with 'messages'.
-        Supports two formats:
-        1. OpenAI format with 'messages'
-        2. Legacy format with 'conversations'
-        Args:
-            examples: list of conversation dictionaries
-        Returns:
-            dict in OpenAI format with 'messages' key
-        Raises:
-            ValueError: If the conversation format is not supported
-        """
-        role_mapping = {
-            "human": "user",
-            "gpt": "assistant",
-        }
-
-        def normalize_role(role: str) -> str:
-            """Normalize role names to OpenAI format. Default to original role if not found."""
-            return role_mapping.get(role, role)
-
-        def convert_legacy_format(example: dict) -> dict:
-            """Convert legacy 'conversations' format to OpenAI 'messages' format."""
-            messages = [
-                {
-                    "role": normalize_role(convo["from"]),
-                    "content": convo["value"],
-                }
-                for convo in example["conversations"]
-            ]
-
-            # Create new dict without 'conversations' key
-            result = deepcopy(example)
-            result.pop("conversations")
-            result["messages"] = messages
-            return result
-
-        processed_examples = []
-        for example in examples:
-            # OpenAI format
-            if "messages" in example:
-                processed_examples.append(example)
-
-            # Legacy format
-            elif "conversations" in example:
-                processed_examples.append(convert_legacy_format(example))
-
-            else:
-                raise ValueError(
-                    "Only `messages` and `conversations` message keys are currently supported."
-                )
-
-        return processed_examples
-
-    @staticmethod
-    def process_images(examples, max_images):
-        """
-        Process images from examples, ensuring consistency in image presence and applying max_images limit.
-        Args:
-            examples: List of dictionaries that may contain 'images' key
-            max_images: Maximum number of images to keep per example (0 means no limit)
-        Returns:
-            Either None (if no images) or List[Image objects] (if all examples have images)
-        Raises:
-            ValueError: If there's a mix of None and non-None images
-        """
-
-        def get_image(example):
-            if "images" not in example:
-                return None
-            images = example["images"]
-            if isinstance(images, str):
-                return Image.open(images)
-            return images
-
-        images = [get_image(example) for example in examples]
-
-        # Count None and non-None images
-        none_count = sum(1 for img in images if img is None)
-
-        # All images are None
-        if none_count == len(images):
-            return None
-
-        # Mix of None and non-None images
-        if none_count > 0:
-            raise ValueError(
-                "All images should be either None or not None. "
-                "Please provide images for all examples or None."
-            )
-
-        # Apply max_images limit if specified
-        if max_images > 0:
-            images = [
-                (
-                    img_batch[:max_images]
-                    if isinstance(img_batch, (list, tuple))
-                    else img_batch
-                )
-                for img_batch in images
-            ]
-
-        return images
-
-    @staticmethod
-    def pixtral_chat_conversion(messages):
-        is_single_message = not isinstance(messages, list)
-        if is_single_message:
-            messages = [messages]
-
-        for i, message in enumerate(messages):
-            if message["role"] == "user":
-                for j, content in enumerate(message["content"]):
-                    if "type" in content and content["type"] == "text":
-                        messages[i]["content"][j] = {
-                            "type": "text",
-                            "content": content["text"],
-                        }
-
-            if message["role"] == "assistant":
-                messages[i]["content"] = message["content"][0]["text"]
-
-        if is_single_message:
-            return messages[0]
-        return messages
-
     @staticmethod
     def process_rows(
         examples,
-        processor,
-        chat_template,
+        processing_strategy: ProcessingStrategy,
         max_images,
         length_only=False,
-        chat_template_type=None,
     ):
         # HINT: use `_torch_collate_batch` to stack and pad tensors
         # see also DataCollatorWithFlattening and DefaultDataCollator
@@ -194,45 +53,20 @@ def process_rows(
         # use this as a starting point
 
         # Preprocess the examples
-        examples = __class__.preprocess(examples)
+        examples = processing_strategy.preprocess(examples)
 
         # Get the texts and images, and apply the chat template
-        if chat_template_type == "pixtral":
-            texts = [
-                processor.apply_chat_template(
-                    __class__.pixtral_chat_conversion(example["messages"]),
-                    chat_template=chat_template,
-                    tokenize=False,
-                )
-                for example in examples
-            ]
-        else:
-            texts = [
-                processor.apply_chat_template(
-                    example["messages"], chat_template=chat_template, tokenize=False
-                )
-                for example in examples
-            ]
-
-        images = __class__.process_images(examples, max_images=max_images)
-        if chat_template_type == "llava":
-            # LLava1.5 does not support multiple images
-            images = [image[0] for image in images]
+        texts = processing_strategy.process_texts(examples)
+        images = processing_strategy.process_images(examples, max_images)
 
         # Tokenize the texts and process the images
-        batch = processor(text=texts, images=images, return_tensors="pt", padding=True)
+        batch = processing_strategy.processor(text=texts, images=images, return_tensors="pt", padding=True)
 
         # The labels are the input_ids, and we mask the padding tokens in the loss computation
         labels = batch["input_ids"].clone()
-        labels[labels == processor.tokenizer.pad_token_id] = -100 #
+        labels[labels == processing_strategy.processor.tokenizer.pad_token_id] = -100 #
         # Ignore the image token index in the loss computation (model specific)
-        if chat_template_type == "qwen2_vl":
-            image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
-        else:
-            image_token_id = processor.tokenizer.convert_tokens_to_ids(
-                processor.image_token
-            )
-        labels[labels == image_token_id] = -100
+        labels[labels == processing_strategy.image_token_id] = -100
         batch["labels"] = labels
 
         if length_only:
```
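Read off the call sites that survive in `process_rows`, the contract the collator now depends on is small: `preprocess(examples)`, `process_texts(examples)`, `process_images(examples, max_images)`, a `.processor` attribute, and an `.image_token_id`. Here is a sketch of that contract, with method bodies condensed from the code this commit deletes and the model-specific branches recast as subclass overrides; the subclass names are assumptions, since `mm_processing_strategies.py` is not rendered in this diff:

```python
from copy import deepcopy

from PIL import Image
from transformers import ProcessorMixin


class ProcessingStrategy:
    """Everything the collator used to branch on, owned by one object."""

    def __init__(self, processor: ProcessorMixin, chat_template=None):
        self.processor = processor
        self.chat_template = chat_template
        # Generic image-token lookup, as in the old `else` branch above.
        self.image_token_id = processor.tokenizer.convert_tokens_to_ids(
            processor.image_token
        )

    def preprocess(self, examples: list[dict]) -> list[dict]:
        """Normalize legacy 'conversations' rows to OpenAI-style 'messages'."""
        role_mapping = {"human": "user", "gpt": "assistant"}
        processed = []
        for example in examples:
            if "messages" in example:
                processed.append(example)
            elif "conversations" in example:
                result = deepcopy(example)
                result["messages"] = [
                    {"role": role_mapping.get(c["from"], c["from"]), "content": c["value"]}
                    for c in result.pop("conversations")
                ]
                processed.append(result)
            else:
                raise ValueError(
                    "Only `messages` and `conversations` message keys are currently supported."
                )
        return processed

    def process_texts(self, examples: list[dict]) -> list[str]:
        """Apply the chat template without tokenizing (the old default path)."""
        return [
            self.processor.apply_chat_template(
                example["messages"], chat_template=self.chat_template, tokenize=False
            )
            for example in examples
        ]

    def process_images(self, examples: list[dict], max_images: int):
        """Load images per example and apply the max_images cap."""
        images = [
            Image.open(ex["images"]) if isinstance(ex.get("images"), str) else ex.get("images")
            for ex in examples
        ]
        if all(img is None for img in images):
            return None  # text-only batch
        if any(img is None for img in images):
            raise ValueError(
                "All images should be either None or not None. "
                "Please provide images for all examples or None."
            )
        if max_images > 0:
            images = [
                img[:max_images] if isinstance(img, (list, tuple)) else img
                for img in images
            ]
        return images


class LlavaProcessingStrategy(ProcessingStrategy):  # hypothetical name
    def process_images(self, examples, max_images):
        images = super().process_images(examples, max_images)
        # LLava1.5 does not support multiple images per example
        return [img[0] for img in images] if images else images


class Qwen2VLProcessingStrategy(ProcessingStrategy):  # hypothetical name
    def __init__(self, processor, chat_template=None):
        super().__init__(processor, chat_template)
        self.image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
```

A Pixtral subclass would likewise fold the deleted `pixtral_chat_conversion` rewrite into `process_texts`. The payoff is visible in the diff stats: `process_rows` shrinks to a model-agnostic pipeline, and supporting a new model family becomes a subclass rather than another `chat_template_type` branch inside the collator.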
[The remaining two changed files did not load; per the new import, they presumably include src/axolotl/utils/collators/mm_processing_strategies.py.]