feat: Enable phoenix.evals to handle multimodal message templates (#5522)

* Templates return lists of messages from `format` method

* Update model wrapper base class

* Migrate model wrappers

* Resolve type errors

* Fix more type issues

* Wire up full backwards compatibility

* Do not error on PromptMessage inputs to model wrappers

* Do not use list subscripting

* Refactor to use generic `parts` instead of `messages`

* Ruff 🐶
anticorrelator authored Dec 2, 2024
1 parent 12e1f33 commit 41a4fc2
Showing 12 changed files with 245 additions and 90 deletions.
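The `MultimodalPrompt` container itself lives in `phoenix/evals/templates.py`, which is not among the hunks shown below. The sketch that follows is a rough reconstruction inferred from the call sites in this diff (`from_string`, `parts`, `content_type`, `to_text_only_prompt`, `str(prompt)`); the field defaults, enum values, and join behavior are assumptions, not the committed implementation.

```python
# Hedged sketch of the prompt container implied by the call sites in this diff.
# Attribute and method names mirror how the model wrappers use the object; the
# rest is illustrative.
from dataclasses import dataclass, field
from enum import Enum
from typing import List


class PromptPartContentType(str, Enum):
    TEXT = "text"


@dataclass
class PromptPart:
    content_type: PromptPartContentType
    content: str


@dataclass
class MultimodalPrompt:
    parts: List[PromptPart] = field(default_factory=list)

    @classmethod
    def from_string(cls, text: str) -> "MultimodalPrompt":
        # Backwards compatibility: wrap a plain string prompt in a single text part.
        return cls(parts=[PromptPart(content_type=PromptPartContentType.TEXT, content=text)])

    def to_text_only_prompt(self) -> str:
        # Used by wrappers (e.g. Bedrock) whose APIs still expect a single string.
        return "\n\n".join(
            part.content
            for part in self.parts
            if part.content_type == PromptPartContentType.TEXT
        )

    def __str__(self) -> str:
        return self.to_text_only_prompt()
```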
7 changes: 4 additions & 3 deletions packages/phoenix-evals/src/phoenix/evals/classify.py
@@ -27,6 +27,7 @@
from phoenix.evals.models import BaseModel, OpenAIModel, set_verbosity
from phoenix.evals.templates import (
ClassificationTemplate,
MultimodalPrompt,
PromptOptions,
PromptTemplate,
normalize_classification_template,
@@ -177,7 +178,7 @@ def llm_classify(
if generation_info := model.verbose_generation_info():
printif(verbose, generation_info)

def _map_template(data: pd.Series[Any]) -> str:
def _map_template(data: pd.Series[Any]) -> MultimodalPrompt:
try:
variables = {var: data[var] for var in eval_template.variables}
empty_keys = [k for k, v in variables.items() if v is None]
@@ -217,7 +218,7 @@ async def _run_llm_classification_async(input_data: pd.Series[Any]) -> ParsedLLM
prompt, instruction=system_instruction, **model_kwargs
)
inference, explanation = _process_response(response)
return inference, explanation, response, prompt
return inference, explanation, response, str(prompt)

def _run_llm_classification_sync(input_data: pd.Series[Any]) -> ParsedLLMResponse:
with set_verbosity(model, verbose) as verbose_model:
@@ -226,7 +227,7 @@ def _run_llm_classification_sync(input_data: pd.Series[Any]) -> ParsedLLMRespons
prompt, instruction=system_instruction, **model_kwargs
)
inference, explanation = _process_response(response)
return inference, explanation, response, prompt
return inference, explanation, response, str(prompt)

fallback_return_value: ParsedLLMResponse = (None, None, "", "")

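Because `_map_template` now yields `MultimodalPrompt` objects, the parsed-response tuple stringifies the prompt before it is written back, so the optional prompt column in the output stays a plain string. A hedged usage sketch follows; the template text, rails, and model choice are illustrative and not taken from this commit.

```python
import pandas as pd

from phoenix.evals import OpenAIModel, llm_classify

# Toy data and template for illustration only.
df = pd.DataFrame(
    {
        "query": ["What is Phoenix?"],
        "document": ["Phoenix is an ML observability library."],
    }
)
template = (
    "Is the document relevant to the query?\n"
    "Query: {query}\nDocument: {document}\n"
    "Answer with a single word, 'relevant' or 'unrelated'."
)
results = llm_classify(
    dataframe=df,
    template=template,
    model=OpenAIModel(model="gpt-4o-mini"),
    rails=["relevant", "unrelated"],
    include_prompt=True,  # this column now holds str(MultimodalPrompt), still plain text
)
```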
15 changes: 10 additions & 5 deletions packages/phoenix-evals/src/phoenix/evals/generate.py
@@ -8,6 +8,7 @@
)
from phoenix.evals.models import BaseModel, set_verbosity
from phoenix.evals.templates import (
MultimodalPrompt,
PromptTemplate,
map_template,
normalize_prompt_template,
@@ -90,7 +91,9 @@ def llm_generate(
logger.info(f"Template variables: {template.variables}")
prompts = map_template(dataframe, template)

async def _run_llm_generation_async(enumerated_prompt: Tuple[int, str]) -> Dict[str, Any]:
async def _run_llm_generation_async(
enumerated_prompt: Tuple[int, MultimodalPrompt],
) -> Dict[str, Any]:
index, prompt = enumerated_prompt
with set_verbosity(model, verbose) as verbose_model:
response = await verbose_model._async_generate(
@@ -99,12 +102,14 @@ async def _run_llm_generation_async(enumerated_prompt: Tuple[int, str]) -> Dict[
)
parsed_response = output_parser(response, index)
if include_prompt:
parsed_response["prompt"] = prompt
parsed_response["prompt"] = str(prompt)
if include_response:
parsed_response["response"] = response
return parsed_response

def _run_llm_generation_sync(enumerated_prompt: Tuple[int, str]) -> Dict[str, Any]:
def _run_llm_generation_sync(
enumerated_prompt: Tuple[int, MultimodalPrompt],
) -> Dict[str, Any]:
index, prompt = enumerated_prompt
with set_verbosity(model, verbose) as verbose_model:
response = verbose_model._generate(
@@ -113,7 +118,7 @@ def _run_llm_generation_sync(enumerated_prompt: Tuple[int, str]) -> Dict[str, An
)
parsed_response = output_parser(response, index)
if include_prompt:
parsed_response["prompt"] = prompt
parsed_response["prompt"] = str(prompt)
if include_response:
parsed_response["response"] = response
return parsed_response
@@ -133,5 +138,5 @@ def _run_llm_generation_sync(enumerated_prompt: Tuple[int, str]) -> Dict[str, An
exit_on_error=True,
fallback_return_value=fallback_return_value,
)
results, _ = executor.run(list(enumerate(prompts.tolist())))
results, _ = executor.run(list(enumerate(prompts)))
return pd.DataFrame.from_records(results, index=dataframe.index)
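The same stringification applies in `llm_generate`: with `include_prompt=True` the recorded prompt is `str(prompt)` rather than the `MultimodalPrompt` object, and the executor now enumerates the pandas Series of prompts directly instead of a converted list. A brief illustrative call, with placeholder template and model:

```python
import pandas as pd

from phoenix.evals import OpenAIModel, llm_generate

df = pd.DataFrame({"topic": ["tracing", "evals"]})
generated = llm_generate(
    dataframe=df,
    template="Write one sentence about {topic}.",
    model=OpenAIModel(model="gpt-4o-mini"),
    include_prompt=True,    # stored as str(MultimodalPrompt)
    include_response=True,
)
```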
27 changes: 20 additions & 7 deletions packages/phoenix-evals/src/phoenix/evals/models/anthropic.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

from phoenix.evals.exceptions import PhoenixContextLimitExceeded
from phoenix.evals.models.base import BaseModel
from phoenix.evals.models.rate_limiters import RateLimiter
from phoenix.evals.templates import MultimodalPrompt, PromptPartContentType

MINIMUM_ANTHROPIC_VERSION = "0.18.0"

@@ -110,9 +111,12 @@ def invocation_parameters(self) -> Dict[str, Any]:
"top_k": self.top_k,
}

def _generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str:
def _generate(self, prompt: Union[str, MultimodalPrompt], **kwargs: Dict[str, Any]) -> str:
# instruction is an invalid input to Anthropic models, it is passed in by
# BaseEvalModel.__call__ and needs to be removed
if isinstance(prompt, str):
prompt = MultimodalPrompt.from_string(prompt)

kwargs.pop("instruction", None)
invocation_parameters = self.invocation_parameters()
invocation_parameters.update(kwargs)
@@ -138,9 +142,14 @@ def _completion(**kwargs: Any) -> Any:

return _completion(**kwargs)

async def _async_generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str:
async def _async_generate(
self, prompt: Union[str, MultimodalPrompt], **kwargs: Dict[str, Any]
) -> str:
# instruction is an invalid input to Anthropic models, it is passed in by
# BaseEvalModel.__call__ and needs to be removed
if isinstance(prompt, str):
prompt = MultimodalPrompt.from_string(prompt)

kwargs.pop("instruction", None)
invocation_parameters = self.invocation_parameters()
invocation_parameters.update(kwargs)
@@ -166,8 +175,12 @@ async def _async_completion(**kwargs: Any) -> Any:

return await _async_completion(**kwargs)

def _format_prompt_for_claude(self, prompt: str) -> List[Dict[str, str]]:
def _format_prompt_for_claude(self, prompt: MultimodalPrompt) -> List[Dict[str, str]]:
# the Anthropic messages API expects a list of messages
return [
{"role": "user", "content": prompt},
]
messages = []
for part in prompt.parts:
if part.content_type == PromptPartContentType.TEXT:
messages.append({"role": "user", "content": part.content})
else:
raise ValueError(f"Unsupported content type: {part.content_type}")
return messages
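Assuming `from_string` wraps a plain string in a single text part, the new formatter turns each text part into its own user message for the Anthropic messages API and raises on anything else instead of silently dropping it. A small illustrative check, not part of the commit:

```python
from phoenix.evals.templates import MultimodalPrompt, PromptPartContentType

# A plain string becomes a single text part...
prompt = MultimodalPrompt.from_string("Is this passage relevant to the question?")

# ...and each text part maps to one user message, mirroring the logic above.
messages = []
for part in prompt.parts:
    if part.content_type == PromptPartContentType.TEXT:
        messages.append({"role": "user", "content": part.content})
    else:
        raise ValueError(f"Unsupported content type: {part.content_type}")

assert messages == [{"role": "user", "content": "Is this passage relevant to the question?"}]
```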
16 changes: 10 additions & 6 deletions packages/phoenix-evals/src/phoenix/evals/models/base.py
@@ -4,9 +4,10 @@
from dataclasses import dataclass, field
from typing import Any, Generator, Optional, Sequence

from typing_extensions import TypeVar
from typing_extensions import TypeVar, Union

from phoenix.evals.models.rate_limiters import RateLimiter
from phoenix.evals.templates import MultimodalPrompt

T = TypeVar("T", bound=type)

@@ -61,11 +62,13 @@ def _model_name(self) -> str:
def reload_client(self) -> None:
pass

def __call__(self, prompt: str, instruction: Optional[str] = None, **kwargs: Any) -> str:
def __call__(
self, prompt: Union[str, MultimodalPrompt], instruction: Optional[str] = None, **kwargs: Any
) -> str:
"""Run the LLM on the given prompt."""
if not isinstance(prompt, str):
if not isinstance(prompt, (str, MultimodalPrompt)):
raise TypeError(
"Invalid type for argument `prompt`. Expected a string but found "
"Invalid type for argument `prompt`. Expected a string or PromptMessages but found "
f"{type(prompt)}. If you want to run the LLM on multiple prompts, use "
"`generate` instead."
)
@@ -74,6 +77,7 @@ def __call__(self, prompt: str, instruction: Optional[str] = None, **kwargs: Any
"Invalid type for argument `instruction`. Expected a string but found "
f"{type(instruction)}."
)

return self._generate(prompt=prompt, instruction=instruction, **kwargs)

def verbose_generation_info(self) -> str:
@@ -82,11 +86,11 @@ def verbose_generation_info(self) -> str:
return ""

@abstractmethod
async def _async_generate(self, prompt: str, **kwargs: Any) -> str:
async def _async_generate(self, prompt: Union[str, MultimodalPrompt], **kwargs: Any) -> str:
raise NotImplementedError

@abstractmethod
def _generate(self, prompt: str, **kwargs: Any) -> str:
def _generate(self, prompt: Union[str, MultimodalPrompt], **kwargs: Any) -> str:
raise NotImplementedError

@staticmethod
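With the widened `__call__` signature, existing string callers keep working and a `MultimodalPrompt` can now be passed in directly; any other type still raises a `TypeError` up front. An illustrative sketch follows; the model name is a placeholder and an API key would be required to actually run it.

```python
from phoenix.evals import OpenAIModel
from phoenix.evals.templates import MultimodalPrompt

model = OpenAIModel(model="gpt-4o-mini")  # any BaseModel subclass behaves the same way

# Both forms are accepted by BaseModel.__call__ after this change:
answer_from_str = model("Reply with the single word 'ok'.")
answer_from_prompt = model(MultimodalPrompt.from_string("Reply with the single word 'ok'."))

# Anything else is still rejected before reaching the provider client:
# model(["not", "a", "prompt"])  # -> TypeError
```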
31 changes: 23 additions & 8 deletions packages/phoenix-evals/src/phoenix/evals/models/bedrock.py
@@ -3,11 +3,12 @@
import logging
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from phoenix.evals.exceptions import PhoenixContextLimitExceeded
from phoenix.evals.models.base import BaseModel
from phoenix.evals.models.rate_limiters import RateLimiter
from phoenix.evals.templates import MultimodalPrompt

logger = logging.getLogger(__name__)

@@ -102,7 +103,10 @@ def _init_rate_limiter(self) -> None:
enforcement_window_minutes=1,
)

def _generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str:
def _generate(self, prompt: Union[str, MultimodalPrompt], **kwargs: Dict[str, Any]) -> str:
if isinstance(prompt, str):
prompt = MultimodalPrompt.from_string(prompt)

body = json.dumps(self._create_request_body(prompt))
accept = "application/json"
contentType = "application/json"
@@ -113,7 +117,7 @@ def _generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str:

return self._parse_output(response) or ""

async def _async_generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str:
async def _async_generate(
self, prompt: Union[str, MultimodalPrompt], **kwargs: Dict[str, Any]
) -> str:
if isinstance(prompt, str):
prompt = MultimodalPrompt.from_string(prompt)

loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, partial(self._generate, prompt, **kwargs))

@@ -148,13 +157,19 @@ def _format_prompt_for_claude(self, prompt: str) -> List[Dict[str, str]]:
{"role": "user", "content": prompt},
]

def _create_request_body(self, prompt: str) -> Dict[str, Any]:
def _create_request_body(self, prompt: MultimodalPrompt) -> Dict[str, Any]:
# The request formats for bedrock models differ
# see https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html

# TODO: Migrate to using the bedrock `converse` API
# https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-call.html

prompt_str = prompt.to_text_only_prompt()

if self.model_id.startswith("ai21"):
return {
**{
"prompt": prompt,
"prompt": prompt_str,
"temperature": self.temperature,
"topP": self.top_p,
"maxTokens": self.max_tokens,
@@ -166,7 +181,7 @@ def _create_request_body(self, prompt: str) -> Dict[str, Any]:
return {
**{
"anthropic_version": "bedrock-2023-05-31",
"messages": self._format_prompt_for_claude(prompt),
"messages": self._format_prompt_for_claude(prompt_str),
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,
@@ -178,7 +193,7 @@ def _create_request_body(self, prompt: str) -> Dict[str, Any]:
elif self.model_id.startswith("mistral"):
return {
**{
"prompt": prompt,
"prompt": prompt_str,
"max_tokens": self.max_tokens,
"temperature": self.temperature,
"stop": self.stop_sequences,
@@ -192,7 +207,7 @@ def _create_request_body(self, prompt: str) -> Dict[str, Any]:
logger.warn(f"Unknown format for model {self.model_id}, returning titan format...")
return {
**{
"inputText": prompt,
"inputText": prompt_str,
"textGenerationConfig": {
"temperature": self.temperature,
"topP": self.top_p,
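Until the TODO to adopt the Bedrock `converse` API is done, every Bedrock model family receives a flattened text prompt via `to_text_only_prompt()`, so non-text parts would not survive this path. Roughly, for the Titan fallback branch, using only the fields visible in this hunk and illustrative parameter values:

```python
from phoenix.evals.templates import MultimodalPrompt

prompt = MultimodalPrompt.from_string("Summarize the document in one sentence.")
prompt_str = prompt.to_text_only_prompt()

# Approximate shape of the request body for the Titan fallback branch (abridged):
body = {
    "inputText": prompt_str,
    "textGenerationConfig": {
        "temperature": 0.0,
        "topP": 1.0,
        # ...remaining generation parameters as configured on the model wrapper
    },
}
```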
26 changes: 20 additions & 6 deletions packages/phoenix-evals/src/phoenix/evals/models/litellm.py
@@ -1,9 +1,10 @@
import logging
import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from phoenix.evals.models.base import BaseModel
from phoenix.evals.templates import MultimodalPrompt, PromptPartContentType

logger = logging.getLogger(__name__)

@@ -103,10 +104,18 @@ def _init_environment(self) -> None:
package_name="litellm",
)

async def _async_generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str:
async def _async_generate(
self, prompt: Union[str, MultimodalPrompt], **kwargs: Dict[str, Any]
) -> str:
if isinstance(prompt, str):
prompt = MultimodalPrompt.from_string(prompt)

return self._generate(prompt, **kwargs)

def _generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str:
def _generate(self, prompt: Union[str, MultimodalPrompt], **kwargs: Dict[str, Any]) -> str:
if isinstance(prompt, str):
prompt = MultimodalPrompt.from_string(prompt)

messages = self._get_messages_from_prompt(prompt)
response = self._litellm.completion(
model=self.model,
@@ -120,7 +129,7 @@ def _generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str:
)
return str(response.choices[0].message.content)

def _get_messages_from_prompt(self, prompt: str) -> List[Dict[str, str]]:
def _get_messages_from_prompt(self, prompt: MultimodalPrompt) -> List[Dict[str, str]]:
# LiteLLM requires prompts in the format of messages
# messages=[{"content": "ABC?","role": "user"}]
return [{"content": prompt, "role": "user"}]
messages = []
for part in prompt.parts:
if part.content_type == PromptPartContentType.TEXT:
messages.append({"content": part.content, "role": "user"})
else:
raise ValueError(f"Unsupported content type: {part.content_type}")
return messages
25 changes: 20 additions & 5 deletions packages/phoenix-evals/src/phoenix/evals/models/mistralai.py
@@ -1,8 +1,9 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from phoenix.evals.models.base import BaseModel
from phoenix.evals.models.rate_limiters import RateLimiter
from phoenix.evals.templates import MultimodalPrompt, PromptPartContentType

if TYPE_CHECKING:
from mistralai.models.chat_completion import ChatMessage
@@ -106,9 +107,12 @@ def invocation_parameters(self) -> Dict[str, Any]:
# Mistral is strict about not passing None values to the API
return {k: v for k, v in params.items() if v is not None}

def _generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str:
def _generate(self, prompt: Union[str, MultimodalPrompt], **kwargs: Dict[str, Any]) -> str:
# instruction is an invalid input to Mistral models, it is passed in by
# BaseEvalModel.__call__ and needs to be removed
if isinstance(prompt, str):
prompt = MultimodalPrompt.from_string(prompt)

kwargs.pop("instruction", None)
invocation_parameters = self.invocation_parameters()
invocation_parameters.update(kwargs)
@@ -134,9 +138,14 @@ def _completion(**kwargs: Any) -> Any:

return _completion(**kwargs)

async def _async_generate(self, prompt: str, **kwargs: Dict[str, Any]) -> str:
async def _async_generate(
self, prompt: Union[str, MultimodalPrompt], **kwargs: Dict[str, Any]
) -> str:
# instruction is an invalid input to Mistral models, it is passed in by
# BaseEvalModel.__call__ and needs to be removed
if isinstance(prompt, str):
prompt = MultimodalPrompt.from_string(prompt)

kwargs.pop("instruction", None)
invocation_parameters = self.invocation_parameters()
invocation_parameters.update(kwargs)
@@ -163,6 +172,12 @@ async def _async_completion(**kwargs: Any) -> Any:

return await _async_completion(**kwargs)

def _format_prompt(self, prompt: str) -> List["ChatMessage"]:
def _format_prompt(self, prompt: MultimodalPrompt) -> List["ChatMessage"]:
ChatMessage = self._ChatMessage
return [ChatMessage(role="user", content=prompt)]
messages = []
for part in prompt.parts:
if part.content_type == PromptPartContentType.TEXT:
messages.append(ChatMessage(role="user", content=part.content))
else:
raise ValueError(f"Unsupported content type: {part.content_type}")
return messages