From 41a4fc2ef98b8a06747e5cdd07f07d9bbe5662c3 Mon Sep 17 00:00:00 2001 From: Dustin Ngo Date: Mon, 2 Dec 2024 16:43:52 -0500 Subject: [PATCH] feat: Enable `phoenix.evals` to handle multimodal message templates (#5522) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Templates return lists of messages from `format` method * Update model wrapper base class * Migrate model wrappers * Resolve type errors * Fix more type issues * Wire up full backwards compatibility * Do not error not on PromptMessage inputs to model wrappers * Do not use list subscripting * Refactor to use generic `parts` instead of `messages` * Ruff 🐶 --- .../src/phoenix/evals/classify.py | 7 +- .../src/phoenix/evals/generate.py | 15 ++- .../src/phoenix/evals/models/anthropic.py | 27 ++-- .../src/phoenix/evals/models/base.py | 16 ++- .../src/phoenix/evals/models/bedrock.py | 31 +++-- .../src/phoenix/evals/models/litellm.py | 26 +++- .../src/phoenix/evals/models/mistralai.py | 25 +++- .../src/phoenix/evals/models/openai.py | 20 ++- .../src/phoenix/evals/models/vertex.py | 28 +++-- .../src/phoenix/evals/models/vertexai.py | 18 ++- .../src/phoenix/evals/templates.py | 118 +++++++++++++----- .../phoenix/evals/templates/test_template.py | 4 +- 12 files changed, 245 insertions(+), 90 deletions(-) diff --git a/packages/phoenix-evals/src/phoenix/evals/classify.py b/packages/phoenix-evals/src/phoenix/evals/classify.py index 1dc58a8af1..8d47145da4 100644 --- a/packages/phoenix-evals/src/phoenix/evals/classify.py +++ b/packages/phoenix-evals/src/phoenix/evals/classify.py @@ -27,6 +27,7 @@ from phoenix.evals.models import BaseModel, OpenAIModel, set_verbosity from phoenix.evals.templates import ( ClassificationTemplate, + MultimodalPrompt, PromptOptions, PromptTemplate, normalize_classification_template, @@ -177,7 +178,7 @@ def llm_classify( if generation_info := model.verbose_generation_info(): printif(verbose, generation_info) - def _map_template(data: pd.Series[Any]) -> str: + def _map_template(data: pd.Series[Any]) -> MultimodalPrompt: try: variables = {var: data[var] for var in eval_template.variables} empty_keys = [k for k, v in variables.items() if v is None] @@ -217,7 +218,7 @@ async def _run_llm_classification_async(input_data: pd.Series[Any]) -> ParsedLLM prompt, instruction=system_instruction, **model_kwargs ) inference, explanation = _process_response(response) - return inference, explanation, response, prompt + return inference, explanation, response, str(prompt) def _run_llm_classification_sync(input_data: pd.Series[Any]) -> ParsedLLMResponse: with set_verbosity(model, verbose) as verbose_model: @@ -226,7 +227,7 @@ def _run_llm_classification_sync(input_data: pd.Series[Any]) -> ParsedLLMRespons prompt, instruction=system_instruction, **model_kwargs ) inference, explanation = _process_response(response) - return inference, explanation, response, prompt + return inference, explanation, response, str(prompt) fallback_return_value: ParsedLLMResponse = (None, None, "", "") diff --git a/packages/phoenix-evals/src/phoenix/evals/generate.py b/packages/phoenix-evals/src/phoenix/evals/generate.py index 1ce3e27113..4f8892cb43 100644 --- a/packages/phoenix-evals/src/phoenix/evals/generate.py +++ b/packages/phoenix-evals/src/phoenix/evals/generate.py @@ -8,6 +8,7 @@ ) from phoenix.evals.models import BaseModel, set_verbosity from phoenix.evals.templates import ( + MultimodalPrompt, PromptTemplate, map_template, normalize_prompt_template, @@ -90,7 +91,9 @@ def llm_generate( 
logger.info(f"Template variables: {template.variables}") prompts = map_template(dataframe, template) - async def _run_llm_generation_async(enumerated_prompt: Tuple[int, str]) -> Dict[str, Any]: + async def _run_llm_generation_async( + enumerated_prompt: Tuple[int, MultimodalPrompt], + ) -> Dict[str, Any]: index, prompt = enumerated_prompt with set_verbosity(model, verbose) as verbose_model: response = await verbose_model._async_generate( @@ -99,12 +102,14 @@ async def _run_llm_generation_async(enumerated_prompt: Tuple[int, str]) -> Dict[ ) parsed_response = output_parser(response, index) if include_prompt: - parsed_response["prompt"] = prompt + parsed_response["prompt"] = str(prompt) if include_response: parsed_response["response"] = response return parsed_response - def _run_llm_generation_sync(enumerated_prompt: Tuple[int, str]) -> Dict[str, Any]: + def _run_llm_generation_sync( + enumerated_prompt: Tuple[int, MultimodalPrompt], + ) -> Dict[str, Any]: index, prompt = enumerated_prompt with set_verbosity(model, verbose) as verbose_model: response = verbose_model._generate( @@ -113,7 +118,7 @@ def _run_llm_generation_sync(enumerated_prompt: Tuple[int, str]) -> Dict[str, An ) parsed_response = output_parser(response, index) if include_prompt: - parsed_response["prompt"] = prompt + parsed_response["prompt"] = str(prompt) if include_response: parsed_response["response"] = response return parsed_response @@ -133,5 +138,5 @@ def _run_llm_generation_sync(enumerated_prompt: Tuple[int, str]) -> Dict[str, An exit_on_error=True, fallback_return_value=fallback_return_value, ) - results, _ = executor.run(list(enumerate(prompts.tolist()))) + results, _ = executor.run(list(enumerate(prompts))) return pd.DataFrame.from_records(results, index=dataframe.index) diff --git a/packages/phoenix-evals/src/phoenix/evals/models/anthropic.py b/packages/phoenix-evals/src/phoenix/evals/models/anthropic.py index b16f4e168a..c36147682f 100644 --- a/packages/phoenix-evals/src/phoenix/evals/models/anthropic.py +++ b/packages/phoenix-evals/src/phoenix/evals/models/anthropic.py @@ -1,9 +1,10 @@ from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from phoenix.evals.exceptions import PhoenixContextLimitExceeded from phoenix.evals.models.base import BaseModel from phoenix.evals.models.rate_limiters import RateLimiter +from phoenix.evals.templates import MultimodalPrompt, PromptPartContentType MINIMUM_ANTHROPIC_VERSION = "0.18.0" @@ -110,9 +111,12 @@ def invocation_parameters(self) -> Dict[str, Any]: "top_k": self.top_k, } - def _generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str: + def _generate(self, prompt: Union[str, MultimodalPrompt], **kwargs: Dict[str, Any]) -> str: # instruction is an invalid input to Anthropic models, it is passed in by # BaseEvalModel.__call__ and needs to be removed + if isinstance(prompt, str): + prompt = MultimodalPrompt.from_string(prompt) + kwargs.pop("instruction", None) invocation_parameters = self.invocation_parameters() invocation_parameters.update(kwargs) @@ -138,9 +142,14 @@ def _completion(**kwargs: Any) -> Any: return _completion(**kwargs) - async def _async_generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str: + async def _async_generate( + self, prompt: Union[str, MultimodalPrompt], **kwargs: Dict[str, Any] + ) -> str: # instruction is an invalid input to Anthropic models, it is passed in by # BaseEvalModel.__call__ and needs to be removed + if isinstance(prompt, 
str): + prompt = MultimodalPrompt.from_string(prompt) + kwargs.pop("instruction", None) invocation_parameters = self.invocation_parameters() invocation_parameters.update(kwargs) @@ -166,8 +175,12 @@ async def _async_completion(**kwargs: Any) -> Any: return await _async_completion(**kwargs) - def _format_prompt_for_claude(self, prompt: str) -> List[Dict[str, str]]: + def _format_prompt_for_claude(self, prompt: MultimodalPrompt) -> List[Dict[str, str]]: # the Anthropic messages API expects a list of messages - return [ - {"role": "user", "content": prompt}, - ] + messages = [] + for part in prompt.parts: + if part.content_type == PromptPartContentType.TEXT: + messages.append({"role": "user", "content": part.content}) + else: + raise ValueError(f"Unsupported content type: {part.content_type}") + return messages diff --git a/packages/phoenix-evals/src/phoenix/evals/models/base.py b/packages/phoenix-evals/src/phoenix/evals/models/base.py index 5135a4ae7f..8e99cab320 100644 --- a/packages/phoenix-evals/src/phoenix/evals/models/base.py +++ b/packages/phoenix-evals/src/phoenix/evals/models/base.py @@ -4,9 +4,10 @@ from dataclasses import dataclass, field from typing import Any, Generator, Optional, Sequence -from typing_extensions import TypeVar +from typing_extensions import TypeVar, Union from phoenix.evals.models.rate_limiters import RateLimiter +from phoenix.evals.templates import MultimodalPrompt T = TypeVar("T", bound=type) @@ -61,11 +62,13 @@ def _model_name(self) -> str: def reload_client(self) -> None: pass - def __call__(self, prompt: str, instruction: Optional[str] = None, **kwargs: Any) -> str: + def __call__( + self, prompt: Union[str, MultimodalPrompt], instruction: Optional[str] = None, **kwargs: Any + ) -> str: """Run the LLM on the given prompt.""" - if not isinstance(prompt, str): + if not isinstance(prompt, (str, MultimodalPrompt)): raise TypeError( - "Invalid type for argument `prompt`. Expected a string but found " + "Invalid type for argument `prompt`. Expected a string or PromptMessages but found " f"{type(prompt)}. If you want to run the LLM on multiple prompts, use " "`generate` instead." ) @@ -74,6 +77,7 @@ def __call__(self, prompt: str, instruction: Optional[str] = None, **kwargs: Any "Invalid type for argument `instruction`. Expected a string but found " f"{type(instruction)}." 
) + return self._generate(prompt=prompt, instruction=instruction, **kwargs) def verbose_generation_info(self) -> str: @@ -82,11 +86,11 @@ def verbose_generation_info(self) -> str: return "" @abstractmethod - async def _async_generate(self, prompt: str, **kwargs: Any) -> str: + async def _async_generate(self, prompt: Union[str, MultimodalPrompt], **kwargs: Any) -> str: raise NotImplementedError @abstractmethod - def _generate(self, prompt: str, **kwargs: Any) -> str: + def _generate(self, prompt: Union[str, MultimodalPrompt], **kwargs: Any) -> str: raise NotImplementedError @staticmethod diff --git a/packages/phoenix-evals/src/phoenix/evals/models/bedrock.py b/packages/phoenix-evals/src/phoenix/evals/models/bedrock.py index 560e8fe352..297f19572d 100644 --- a/packages/phoenix-evals/src/phoenix/evals/models/bedrock.py +++ b/packages/phoenix-evals/src/phoenix/evals/models/bedrock.py @@ -3,11 +3,12 @@ import logging from dataclasses import dataclass, field from functools import partial -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from phoenix.evals.exceptions import PhoenixContextLimitExceeded from phoenix.evals.models.base import BaseModel from phoenix.evals.models.rate_limiters import RateLimiter +from phoenix.evals.templates import MultimodalPrompt logger = logging.getLogger(__name__) @@ -102,7 +103,10 @@ def _init_rate_limiter(self) -> None: enforcement_window_minutes=1, ) - def _generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str: + def _generate(self, prompt: Union[str, MultimodalPrompt], **kwargs: Dict[str, Any]) -> str: + if isinstance(prompt, str): + prompt = MultimodalPrompt.from_string(prompt) + body = json.dumps(self._create_request_body(prompt)) accept = "application/json" contentType = "application/json" @@ -113,7 +117,12 @@ def _generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str: return self._parse_output(response) or "" - async def _async_generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str: + async def _async_generate( + self, prompt: Union[str, MultimodalPrompt], **kwargs: Dict[str, Any] + ) -> str: + if isinstance(prompt, str): + prompt = MultimodalPrompt.from_string(prompt) + loop = asyncio.get_event_loop() return await loop.run_in_executor(None, partial(self._generate, prompt, **kwargs)) @@ -148,13 +157,19 @@ def _format_prompt_for_claude(self, prompt: str) -> List[Dict[str, str]]: {"role": "user", "content": prompt}, ] - def _create_request_body(self, prompt: str) -> Dict[str, Any]: + def _create_request_body(self, prompt: MultimodalPrompt) -> Dict[str, Any]: # The request formats for bedrock models differ # see https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html + + # TODO: Migrate to using the bedrock `converse` API + # https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-call.html + + prompt_str = prompt.to_text_only_prompt() + if self.model_id.startswith("ai21"): return { **{ - "prompt": prompt, + "prompt": prompt_str, "temperature": self.temperature, "topP": self.top_p, "maxTokens": self.max_tokens, @@ -166,7 +181,7 @@ def _create_request_body(self, prompt: str) -> Dict[str, Any]: return { **{ "anthropic_version": "bedrock-2023-05-31", - "messages": self._format_prompt_for_claude(prompt), + "messages": self._format_prompt_for_claude(prompt_str), "temperature": self.temperature, "top_p": self.top_p, "top_k": self.top_k, @@ -178,7 +193,7 @@ def _create_request_body(self, prompt: str) -> Dict[str, Any]: elif self.model_id.startswith("mistral"): 
return { **{ - "prompt": prompt, + "prompt": prompt_str, "max_tokens": self.max_tokens, "temperature": self.temperature, "stop": self.stop_sequences, @@ -192,7 +207,7 @@ def _create_request_body(self, prompt: str) -> Dict[str, Any]: logger.warn(f"Unknown format for model {self.model_id}, returning titan format...") return { **{ - "inputText": prompt, + "inputText": prompt_str, "textGenerationConfig": { "temperature": self.temperature, "topP": self.top_p, diff --git a/packages/phoenix-evals/src/phoenix/evals/models/litellm.py b/packages/phoenix-evals/src/phoenix/evals/models/litellm.py index c68d1b9335..505de536fc 100644 --- a/packages/phoenix-evals/src/phoenix/evals/models/litellm.py +++ b/packages/phoenix-evals/src/phoenix/evals/models/litellm.py @@ -1,9 +1,10 @@ import logging import warnings from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from phoenix.evals.models.base import BaseModel +from phoenix.evals.templates import MultimodalPrompt, PromptPartContentType logger = logging.getLogger(__name__) @@ -103,10 +104,18 @@ def _init_environment(self) -> None: package_name="litellm", ) - async def _async_generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str: + async def _async_generate( + self, prompt: Union[str, MultimodalPrompt], **kwargs: Dict[str, Any] + ) -> str: + if isinstance(prompt, str): + prompt = MultimodalPrompt.from_string(prompt) + return self._generate(prompt, **kwargs) - def _generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str: + def _generate(self, prompt: Union[str, MultimodalPrompt], **kwargs: Dict[str, Any]) -> str: + if isinstance(prompt, str): + prompt = MultimodalPrompt.from_string(prompt) + messages = self._get_messages_from_prompt(prompt) response = self._litellm.completion( model=self.model, @@ -120,7 +129,12 @@ def _generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str: ) return str(response.choices[0].message.content) - def _get_messages_from_prompt(self, prompt: str) -> List[Dict[str, str]]: + def _get_messages_from_prompt(self, prompt: MultimodalPrompt) -> List[Dict[str, str]]: # LiteLLM requires prompts in the format of messages - # messages=[{"content": "ABC?","role": "user"}] - return [{"content": prompt, "role": "user"}] + messages = [] + for part in prompt.parts: + if part.content_type == PromptPartContentType.TEXT: + messages.append({"content": part.content, "role": "user"}) + else: + raise ValueError(f"Unsupported content type: {part.content_type}") + return messages diff --git a/packages/phoenix-evals/src/phoenix/evals/models/mistralai.py b/packages/phoenix-evals/src/phoenix/evals/models/mistralai.py index 96744db758..2d29aeafbd 100644 --- a/packages/phoenix-evals/src/phoenix/evals/models/mistralai.py +++ b/packages/phoenix-evals/src/phoenix/evals/models/mistralai.py @@ -1,8 +1,9 @@ from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from phoenix.evals.models.base import BaseModel from phoenix.evals.models.rate_limiters import RateLimiter +from phoenix.evals.templates import MultimodalPrompt, PromptPartContentType if TYPE_CHECKING: from mistralai.models.chat_completion import ChatMessage @@ -106,9 +107,12 @@ def invocation_parameters(self) -> Dict[str, Any]: # Mistral is strict about not passing None values to the API return {k: v for k, v in params.items() if v is not None} - def _generate(self, prompt: str, **kwargs: Dict[str, 
Any]) -> str: + def _generate(self, prompt: Union[str, MultimodalPrompt], **kwargs: Dict[str, Any]) -> str: # instruction is an invalid input to Mistral models, it is passed in by # BaseEvalModel.__call__ and needs to be removed + if isinstance(prompt, str): + prompt = MultimodalPrompt.from_string(prompt) + kwargs.pop("instruction", None) invocation_parameters = self.invocation_parameters() invocation_parameters.update(kwargs) @@ -134,9 +138,14 @@ def _completion(**kwargs: Any) -> Any: return _completion(**kwargs) - async def _async_generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str: + async def _async_generate( + self, prompt: Union[str, MultimodalPrompt], **kwargs: Dict[str, Any] + ) -> str: # instruction is an invalid input to Mistral models, it is passed in by # BaseEvalModel.__call__ and needs to be removed + if isinstance(prompt, str): + prompt = MultimodalPrompt.from_string(prompt) + kwargs.pop("instruction", None) invocation_parameters = self.invocation_parameters() invocation_parameters.update(kwargs) @@ -163,6 +172,12 @@ async def _async_completion(**kwargs: Any) -> Any: return await _async_completion(**kwargs) - def _format_prompt(self, prompt: str) -> List["ChatMessage"]: + def _format_prompt(self, prompt: MultimodalPrompt) -> List["ChatMessage"]: ChatMessage = self._ChatMessage - return [ChatMessage(role="user", content=prompt)] + messages = [] + for part in prompt.parts: + if part.content_type == PromptPartContentType.TEXT: + messages.append(ChatMessage(role="user", content=part.content)) + else: + raise ValueError(f"Unsupported content type: {part.content_type}") + return messages diff --git a/packages/phoenix-evals/src/phoenix/evals/models/openai.py b/packages/phoenix-evals/src/phoenix/evals/models/openai.py index edc2b79b3d..c7efe16229 100644 --- a/packages/phoenix-evals/src/phoenix/evals/models/openai.py +++ b/packages/phoenix-evals/src/phoenix/evals/models/openai.py @@ -17,6 +17,7 @@ from phoenix.evals.exceptions import PhoenixContextLimitExceeded from phoenix.evals.models.base import BaseModel from phoenix.evals.models.rate_limiters import RateLimiter +from phoenix.evals.templates import MultimodalPrompt, PromptPartContentType MINIMUM_OPENAI_VERSION = "1.0.0" MODEL_TOKEN_LIMIT_MAPPING = { @@ -278,9 +279,14 @@ def _init_rate_limiter(self) -> None: ) def _build_messages( - self, prompt: str, system_instruction: Optional[str] = None + self, prompt: MultimodalPrompt, system_instruction: Optional[str] = None ) -> List[Dict[str, str]]: - messages = [{"role": "system", "content": prompt}] + messages = [] + for parts in prompt.parts: + if parts.content_type == PromptPartContentType.TEXT: + messages.append({"role": "system", "content": parts.content}) + else: + raise ValueError(f"Unsupported content type: {parts.content_type}") if system_instruction: messages.insert(0, {"role": "system", "content": str(system_instruction)}) return messages @@ -288,7 +294,10 @@ def _build_messages( def verbose_generation_info(self) -> str: return f"OpenAI invocation parameters: {self.public_invocation_params}" - async def _async_generate(self, prompt: str, **kwargs: Any) -> str: + async def _async_generate(self, prompt: Union[str, MultimodalPrompt], **kwargs: Any) -> str: + if isinstance(prompt, str): + prompt = MultimodalPrompt.from_string(prompt) + invoke_params = self.invocation_params messages = self._build_messages(prompt, kwargs.get("instruction")) if functions := kwargs.get("functions"): @@ -307,7 +316,10 @@ async def _async_generate(self, prompt: str, **kwargs: Any) -> str: 
return str(function_call.get("arguments") or "") return str(message["content"]) - def _generate(self, prompt: str, **kwargs: Any) -> str: + def _generate(self, prompt: Union[str, MultimodalPrompt], **kwargs: Any) -> str: + if isinstance(prompt, str): + prompt = MultimodalPrompt.from_string(prompt) + invoke_params = self.invocation_params messages = self._build_messages(prompt, kwargs.get("instruction")) if functions := kwargs.get("functions"): diff --git a/packages/phoenix-evals/src/phoenix/evals/models/vertex.py b/packages/phoenix-evals/src/phoenix/evals/models/vertex.py index b5417b8f2d..4937461319 100644 --- a/packages/phoenix-evals/src/phoenix/evals/models/vertex.py +++ b/packages/phoenix-evals/src/phoenix/evals/models/vertex.py @@ -1,9 +1,10 @@ import logging from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from phoenix.evals.models.base import BaseModel from phoenix.evals.models.rate_limiters import RateLimiter +from phoenix.evals.templates import MultimodalPrompt from phoenix.evals.utils import printif if TYPE_CHECKING: @@ -135,17 +136,21 @@ def _init_params(self) -> Dict[str, Any]: "credentials": self.credentials, } - def _generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str: + def _generate(self, prompt: Union[str, MultimodalPrompt], **kwargs: Dict[str, Any]) -> str: # instruction is an invalid input to Gemini models, it is passed in by # BaseEvalModel.__call__ and needs to be removed kwargs.pop("instruction", None) + if isinstance(prompt, str): + prompt = MultimodalPrompt.from_string(prompt) + @self._rate_limiter.limit def _rate_limited_completion( - prompt: str, generation_config: Dict[str, Any], **kwargs: Any + prompt: MultimodalPrompt, generation_config: Dict[str, Any], **kwargs: Any ) -> Any: + prompt_str = self._construct_prompt(prompt) response = self._model.generate_content( - contents=prompt, generation_config=generation_config, **kwargs + contents=prompt_str, generation_config=generation_config, **kwargs ) return self._parse_response_candidates(response) @@ -157,17 +162,23 @@ def _rate_limited_completion( return str(response) - async def _async_generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str: + async def _async_generate( + self, prompt: Union[str, MultimodalPrompt], **kwargs: Dict[str, Any] + ) -> str: # instruction is an invalid input to Gemini models, it is passed in by # BaseEvalModel.__call__ and needs to be removed kwargs.pop("instruction", None) + if isinstance(prompt, str): + prompt = MultimodalPrompt.from_string(prompt) + @self._rate_limiter.alimit async def _rate_limited_completion( - prompt: str, generation_config: Dict[str, Any], **kwargs: Any + prompt: MultimodalPrompt, generation_config: Dict[str, Any], **kwargs: Any ) -> Any: + prompt_str = self._construct_prompt(prompt) response = await self._model.generate_content_async( - contents=prompt, generation_config=generation_config, **kwargs + contents=prompt_str, generation_config=generation_config, **kwargs ) return self._parse_response_candidates(response) @@ -201,3 +212,6 @@ def _parse_response_candidates(self, response: Any) -> Any: printif(self._verbose, "The 'response' object does not have a 'candidates' attribute.") candidate = "" return candidate + + def _construct_prompt(self, prompt: MultimodalPrompt) -> str: + return prompt.to_text_only_prompt() diff --git a/packages/phoenix-evals/src/phoenix/evals/models/vertexai.py 
b/packages/phoenix-evals/src/phoenix/evals/models/vertexai.py index 6a5e6b157c..277f7db936 100644 --- a/packages/phoenix-evals/src/phoenix/evals/models/vertexai.py +++ b/packages/phoenix-evals/src/phoenix/evals/models/vertexai.py @@ -1,9 +1,10 @@ import logging import warnings from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional, Union from phoenix.evals.models.base import BaseModel +from phoenix.evals.templates import MultimodalPrompt if TYPE_CHECKING: from google.auth.credentials import Credentials @@ -149,13 +150,22 @@ def _instantiate_model(self) -> None: def verbose_generation_info(self) -> str: return f"VertexAI invocation parameters: {self.invocation_params}" - async def _async_generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str: + async def _async_generate( + self, prompt: Union[str, MultimodalPrompt], **kwargs: Dict[str, Any] + ) -> str: + if isinstance(prompt, str): + prompt = MultimodalPrompt.from_string(prompt) + return self._generate(prompt, **kwargs) - def _generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str: + def _generate(self, prompt: Union[str, MultimodalPrompt], **kwargs: Dict[str, Any]) -> str: + if isinstance(prompt, str): + prompt = MultimodalPrompt.from_string(prompt) + + prompt_str = prompt.to_text_only_prompt() invoke_params = self.invocation_params response = self._model.predict( - prompt=prompt, + prompt=prompt_str, **invoke_params, ) return str(response.text) diff --git a/packages/phoenix-evals/src/phoenix/evals/templates.py b/packages/phoenix-evals/src/phoenix/evals/templates.py index 5eecbcb4c4..4f2a578ed6 100644 --- a/packages/phoenix-evals/src/phoenix/evals/templates.py +++ b/packages/phoenix-evals/src/phoenix/evals/templates.py @@ -1,5 +1,6 @@ import re from dataclasses import dataclass +from enum import Enum from string import Formatter from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union @@ -28,53 +29,103 @@ def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, A return obj, field_name -class PromptTemplate: +class PromptPartContentType(str, Enum): + TEXT = "text" + AUDIO_URL = "audio_url" + + +@dataclass +class PromptPart: + content_type: PromptPartContentType + content: str + + +@dataclass +class PromptPartTemplate: + content_type: PromptPartContentType template: str + + +@dataclass +class MultimodalPrompt: + parts: List[PromptPart] + + @staticmethod + def from_string(string_prompt: str) -> "MultimodalPrompt": + return MultimodalPrompt( + parts=[PromptPart(content_type=PromptPartContentType.TEXT, content=string_prompt)] + ) + + def to_text_only_prompt(self) -> str: + return "\n\n".join( + [part.content for part in self.parts if part.content_type == PromptPartContentType.TEXT] + ) + + def __str__(self) -> str: + return "\n\n".join([part.content for part in self.parts]) + + +class PromptTemplate: + template: List[PromptPartTemplate] variables: List[str] def __init__( self, - template: str, + template: Union[str, List[PromptPartTemplate]], delimiters: Tuple[str, str] = (DEFAULT_START_DELIM, DEFAULT_END_DELIM), ): - self.template = template + self.template: List[PromptPartTemplate] = self._normalize_template(template) self._start_delim, self._end_delim = delimiters self.variables = self._parse_variables(self.template) - def prompt(self, options: Optional[PromptOptions] = None) -> str: + def prompt(self, options: Optional[PromptOptions] = None) -> List[PromptPartTemplate]: return self.template def 
format( self, variable_values: Mapping[str, Union[bool, int, float, str]], options: Optional[PromptOptions] = None, - ) -> str: + ) -> MultimodalPrompt: prompt = self.prompt(options) - if self._start_delim == "{" and self._end_delim == "}": - self.formatter = DotKeyFormatter() - prompt = self.formatter.format(prompt, **variable_values) - else: - for variable_name in self.variables: - prompt = prompt.replace( - self._start_delim + variable_name + self._end_delim, - str(variable_values[variable_name]), - ) - return prompt - - def _parse_variables(self, text: str) -> List[str]: + prompt_messages = [] + for template_message in prompt: + if self._start_delim == "{" and self._end_delim == "}": + self.formatter = DotKeyFormatter() + prompt_message = self.formatter.format(template_message.template, **variable_values) + else: + for variable_name in self.variables: + prompt_message = prompt_message.replace( + self._start_delim + variable_name + self._end_delim, + str(variable_values[variable_name]), + ) + prompt_messages.append( + PromptPart(content_type=template_message.content_type, content=prompt_message) + ) + return MultimodalPrompt(parts=prompt_messages) + + def _parse_variables(self, template: List[PromptPartTemplate]) -> List[str]: start = re.escape(self._start_delim) end = re.escape(self._end_delim) pattern = rf"{start}(.*?){end}" - variables = re.findall(pattern, text) + variables = [] + for template_message in template: + variables += re.findall(pattern, template_message.template) return variables + def _normalize_template( + self, template: Union[str, List[PromptPartTemplate]] + ) -> List[PromptPartTemplate]: + if isinstance(template, str): + return [PromptPartTemplate(content_type=PromptPartContentType.TEXT, template=template)] + return template + class ClassificationTemplate(PromptTemplate): def __init__( self, rails: List[str], - template: str, - explanation_template: Optional[str] = None, + template: Union[str, List[PromptPartTemplate]], + explanation_template: Optional[Union[str, List[PromptPartTemplate]]] = None, explanation_label_parser: Optional[Callable[[str], str]] = None, delimiters: Tuple[str, str] = (DEFAULT_START_DELIM, DEFAULT_END_DELIM), scores: Optional[List[float]] = None, @@ -85,20 +136,21 @@ def __init__( "(i.e., the length of both lists must be the same)." 
) self.rails = rails - self.template = template - self.explanation_template = explanation_template + self.template = self._normalize_template(template) + if explanation_template: + self.explanation_template = self._normalize_template(explanation_template) self.explanation_label_parser = explanation_label_parser self._start_delim, self._end_delim = delimiters self.variables: List[str] = [] - for text in [template, explanation_template]: - if text is not None: - self.variables += self._parse_variables(text) + for _template in [self.template, self.explanation_template]: + if _template: + self.variables.extend(self._parse_variables(template=_template)) self._scores = scores def __repr__(self) -> str: - return self.template + return "\n\n".join([template.template for template in self.template]) - def prompt(self, options: Optional[PromptOptions] = None) -> str: + def prompt(self, options: Optional[PromptOptions] = None) -> List[PromptPartTemplate]: if options is None: return self.template @@ -178,7 +230,7 @@ def map_template( dataframe: pd.DataFrame, template: PromptTemplate, options: Optional[PromptOptions] = None, -) -> "pd.Series[str]": +) -> List[MultimodalPrompt]: """ Maps over a dataframe to construct a list of prompts from a template and a dataframe. """ @@ -190,13 +242,13 @@ def map_template( prompt_options: PromptOptions = PromptOptions() if options is None else options try: - prompts = dataframe.apply( - lambda row: template.format( + prompts = [ + template.format( variable_values={var_name: row[var_name] for var_name in template.variables}, options=prompt_options, - ), - axis=1, - ) + ) + for _, row in dataframe.iterrows() + ] return prompts except KeyError as e: raise RuntimeError( diff --git a/packages/phoenix-evals/tests/phoenix/evals/templates/test_template.py b/packages/phoenix-evals/tests/phoenix/evals/templates/test_template.py index 8b4eac957c..b8a3df6a65 100644 --- a/packages/phoenix-evals/tests/phoenix/evals/templates/test_template.py +++ b/packages/phoenix-evals/tests/phoenix/evals/templates/test_template.py @@ -28,7 +28,7 @@ def test_classification_template_score_returns_zero_for_missing_rail(): def test_template_with_default_delimiters_uses_python_string_formatting(): template = PromptTemplate(template='Hello, {name}! Look at this JSON {{ "hello": "world" }}') assert ( - template.format(variable_values={"name": "world"}) + str(template.format(variable_values={"name": "world"})) == 'Hello, world! Look at this JSON { "hello": "world" }' ) @@ -36,6 +36,6 @@ def test_template_with_default_delimiters_uses_python_string_formatting(): def test_template_with_default_delimiters_accepts_keys_with_dots(): template = PromptTemplate(template='Hello, {my.name}! Look at this JSON {{ "hello": "world" }}') assert ( - template.format(variable_values={"my.name": "world"}) + str(template.format(variable_values={"my.name": "world"})) == 'Hello, world! Look at this JSON { "hello": "world" }' )
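
Illustrative usage sketch (reviewer-added, not part of the patch): how the new template classes in templates.py compose. The template text, variable values, and audio URL below are hypothetical.

from phoenix.evals.templates import (
    MultimodalPrompt,
    PromptPartContentType,
    PromptPartTemplate,
    PromptTemplate,
)

# A template may now be a list of typed parts instead of a single string.
template = PromptTemplate(
    template=[
        PromptPartTemplate(
            content_type=PromptPartContentType.TEXT,
            template="Judge the audio clip described by {description}.",
        ),
        PromptPartTemplate(
            content_type=PromptPartContentType.AUDIO_URL,
            template="{audio_url}",
        ),
    ]
)

# `format` now returns a MultimodalPrompt instead of a plain string.
prompt = template.format(
    variable_values={
        "description": "a customer support call",
        "audio_url": "https://example.com/clip.wav",
    }
)
assert isinstance(prompt, MultimodalPrompt)

# Plain-string templates are normalized to a single TEXT part, and str() on the
# formatted prompt recovers the old string output recorded in result dataframes.
legacy = PromptTemplate(template="Hello, {name}!").format(variable_values={"name": "world"})
assert str(legacy) == "Hello, world!"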
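
A second sketch (reviewer-added): the backwards-compatibility path in the model wrappers. Plain-string prompts are coerced with MultimodalPrompt.from_string, and TEXT parts are mapped onto the provider message format, with unsupported content types raising for now; this mirrors the pattern repeated in anthropic.py, litellm.py, and mistralai.py above. The helper name and prompt text here are hypothetical.

from typing import Dict, List

from phoenix.evals.templates import MultimodalPrompt, PromptPartContentType


def to_user_messages(prompt: MultimodalPrompt) -> List[Dict[str, str]]:
    # Same shape as _format_prompt_for_claude / _get_messages_from_prompt above.
    messages = []
    for part in prompt.parts:
        if part.content_type == PromptPartContentType.TEXT:
            messages.append({"role": "user", "content": part.content})
        else:
            raise ValueError(f"Unsupported content type: {part.content_type}")
    return messages


prompt = MultimodalPrompt.from_string("Is this response relevant to the query?")
print(to_user_messages(prompt))
# [{'role': 'user', 'content': 'Is this response relevant to the query?'}]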
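
Finally, providers that still take a single text prompt (the Bedrock and Vertex paths above) flatten the prompt first. A sketch of that behaviour, assuming the MultimodalPrompt definition in templates.py; the part contents are hypothetical:

from phoenix.evals.templates import MultimodalPrompt, PromptPart, PromptPartContentType

prompt = MultimodalPrompt(
    parts=[
        PromptPart(content_type=PromptPartContentType.TEXT, content="You are an impartial judge."),
        PromptPart(content_type=PromptPartContentType.TEXT, content="Rate the following answer."),
    ]
)

# TEXT parts are joined with blank lines; non-text parts are dropped, which is why
# these wrappers remain effectively text-only until they adopt multimodal APIs.
print(prompt.to_text_only_prompt())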