DO NOT MERGE, wip audio evals #5616
base: main
phoenix/evals/classify.py

@@ -6,6 +6,7 @@
 from itertools import product
 from typing import (
     Any,
+    Callable,
     DefaultDict,
     Dict,
     Iterable,
@@ -34,6 +35,7 @@
 )
 from phoenix.evals.utils import (
     NOT_PARSABLE,
+    Audio,
     get_tqdm_progress_bar_formatter,
     openai_function_call_kwargs,
     parse_openai_function_call,
@@ -268,6 +270,126 @@ def _run_llm_classification_sync(input_data: pd.Series[Any]) -> ParsedLLMResponse
     )


+def audio_classify(
+    dataframe: Union[List[str], pd.DataFrame],
+    model: BaseModel,
+    template: Union[ClassificationTemplate, PromptTemplate, str],
+    rails: List[str],
+    data_fetcher: Optional[Callable[[str], Audio]],
+    system_instruction: Optional[str] = None,
+    verbose: bool = False,
+    use_function_calling_if_available: bool = True,
+    provide_explanation: bool = False,
+    include_prompt: bool = False,
+    include_response: bool = False,
+    include_exceptions: bool = False,
+    max_retries: int = 10,
+    exit_on_error: bool = True,
+    run_sync: bool = False,
+    concurrency: Optional[int] = None,
+    progress_bar_format: Optional[str] = get_tqdm_progress_bar_formatter("llm_classify"),
+) -> pd.DataFrame:
+    """
+    Classifies each audio record in the input using an LLM.
+    Returns a pandas.DataFrame whose first column, `label`, contains the
+    classification labels. An optional column named `explanation` is added when
+    `provide_explanation=True`.
+
+    Args:
+        dataframe (Union[List[str], pandas.DataFrame]): A pandas dataframe in which each row
+            represents a record to be classified, or a list of audio URLs. For a dataframe,
+            all template variable names must appear as column names (extra columns unrelated
+            to the template are permitted).
+
+        template (Union[ClassificationTemplate, PromptTemplate, str]): The prompt template
+            as either an instance of PromptTemplate, ClassificationTemplate, or a string.
+            If a string, the variable names should be surrounded by curly braces so that
+            a call to `.format` can be made to substitute variable values.
+
+        model (BaseModel): An LLM model class instance.
+
+        rails (List[str]): A list of strings representing the possible output classes
+            of the model's predictions.
+
+        data_fetcher (Optional[Callable[[str], Audio]]): A function that takes an audio URL
+            and returns an Audio object containing the base64-encoded audio data and its
+            format.
+
+        system_instruction (Optional[str], optional): An optional system message.
+
+        verbose (bool, optional): If True, prints detailed info to stdout such as
+            model invocation parameters and details about retries and snapping to rails.
+            Default False.
+
+        use_function_calling_if_available (bool, default=True): If True, use function
+            calling (if available) as a means to constrain the LLM outputs.
+            With function calling, the LLM is instructed to provide its response as a
+            structured JSON object, which is easier to parse.
+
+        provide_explanation (bool, default=False): If True, provides an explanation
+            for each classification label. A column named `explanation` is added to
+            the output dataframe.
+
+        include_prompt (bool, default=False): If True, includes a column named `prompt`
+            in the output dataframe containing the prompt used for each classification.
+
+        include_response (bool, default=False): If True, includes a column named `response`
+            in the output dataframe containing the raw response from the LLM.
+
+        include_exceptions (bool, default=False): If True, includes details of any
+            exceptions raised during processing in the output dataframe's `exceptions`
+            column.
+
+        max_retries (int, optional): The maximum number of times to retry on exceptions.
+            Defaults to 10.
+
+        exit_on_error (bool, default=True): If True, stops processing evals after all retries
+            are exhausted on a single eval attempt. If False, all evals are attempted before
+            returning, even if some fail.
+
+        run_sync (bool, default=False): If True, forces synchronous request submission.
+            Otherwise evaluations will be run asynchronously if possible.
+
+        concurrency (Optional[int], default=None): The number of concurrent evals if async
+            submission is possible. If not provided, a recommended default concurrency is
+            set on a per-model basis.
+
+        progress_bar_format (Optional[str]): An optional format for the progress bar. If not
+            specified, defaults to 'llm_classify |{bar}| {n_fmt}/{total_fmt} ({percentage:3.1f}%)
+            | ⏳ {elapsed}<{remaining} | {rate_fmt}{postfix}'. If None is passed explicitly,
+            the progress bar is disabled.
+
+    Returns:
+        pandas.DataFrame: A dataframe where the `label` column (at column position 0) contains
+            the classification labels. If provide_explanation=True, an additional column named
+            `explanation` contains the explanation for each label. The dataframe has the same
+            length and index as the input dataframe. The classification label values come from
+            the entries in the rails argument, or are "NOT_PARSABLE" if the model's output
+            could not be parsed. The output dataframe also includes three additional columns:
+            `exceptions`, `execution_status`, and `execution_seconds`, containing details
+            about execution errors that may have occurred during the classification as well
+            as the total runtime of each classification (in seconds).
+    """
+    if not isinstance(dataframe, pd.DataFrame):
+        dataframe = pd.DataFrame(dataframe, columns=["audio_url"])
Review comment: I don't think we should be hard-coding the column name like this, maybe? If users are passing in an iterable of URLs, we should map the column names to the template variable names.
+    if data_fetcher is None:
+        raise ValueError("data_fetcher is required to resolve audio URLs")
+
+    # Fetch each audio object once, then split out its data and format.
+    fetched_audio = dataframe["audio_url"].apply(data_fetcher)
+    dataframe["audio_bytes"] = fetched_audio.apply(lambda audio: audio.data)
+    dataframe["audio_format"] = fetched_audio.apply(lambda audio: audio.format.value)
+    return llm_classify(
+        dataframe=dataframe,
+        model=model,
+        template=template,
+        rails=rails,
+        system_instruction=system_instruction,
+        verbose=verbose,
+        use_function_calling_if_available=use_function_calling_if_available,
+        provide_explanation=provide_explanation,
+        include_prompt=include_prompt,
+        include_response=include_response,
+        include_exceptions=include_exceptions,
+        max_retries=max_retries,
+        exit_on_error=exit_on_error,
+        run_sync=run_sync,
+        concurrency=concurrency,
+        progress_bar_format=progress_bar_format,
+    )


 class RunEvalsPayload(NamedTuple):
     evaluator: LLMEvaluator
     record: Record
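For reviewers who want to try this end to end, here is a minimal usage sketch. The bucket URL, template text, rails, and model name are illustrative assumptions rather than part of this PR; fetch_gcloud_data is the helper added in phoenix/evals/utils.py below, and audio_classify is the new function above.

from phoenix.evals import OpenAIModel
from phoenix.evals.utils import fetch_gcloud_data

results = audio_classify(
    dataframe=["https://storage.googleapis.com/my-bucket/call-1.wav"],  # hypothetical audio file
    model=OpenAIModel(model="gpt-4o-audio-preview"),  # assumes an audio-capable model
    template="Classify the sentiment of the speaker in the attached audio.",  # hypothetical template
    rails=["positive", "negative", "neutral"],
    data_fetcher=fetch_gcloud_data,
)
print(results["label"].value_counts())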
phoenix/evals/models/openai.py

@@ -18,6 +18,7 @@
 from phoenix.evals.models.base import BaseModel
 from phoenix.evals.models.rate_limiters import RateLimiter
 from phoenix.evals.templates import MultimodalPrompt, PromptPartContentType
+from phoenix.evals.utils import AudioFormat

 MINIMUM_OPENAI_VERSION = "1.0.0"
 MODEL_TOKEN_LIMIT_MAPPING = {
@@ -281,12 +282,29 @@ def _init_rate_limiter(self) -> None:
     def _build_messages(
         self, prompt: MultimodalPrompt, system_instruction: Optional[str] = None
     ) -> List[Dict[str, str]]:
+        audio_format = None
         messages = []
-        for parts in prompt.parts:
-            if parts.content_type == PromptPartContentType.TEXT:
-                messages.append({"role": "system", "content": parts.content})
+        for part in prompt.parts:
+            if part.content_type == PromptPartContentType.TEXT:
+                messages.append({"role": "system", "content": part.content})
+            elif part.content_type == PromptPartContentType.AUDIO_FORMAT:
+                audio_format = AudioFormat(part.content)
+            elif part.content_type == PromptPartContentType.AUDIO_BYTES:
+                if not audio_format:
+                    raise ValueError("No audio format provided")
+                messages.append(
+                    {  # type: ignore
+                        "role": "user",
+                        "content": [
+                            {
+                                "type": "input_audio",
+                                "input_audio": {"data": part.content, "format": audio_format.value},

Review comment: Maybe we should be trying to infer the format from the file headers instead; we'll just do our best effort and fall back to something sensible if the inference doesn't work. We have the bytestring, so the headers should be there.

+                            }
+                        ],
+                    }
+                )
             else:
-                raise ValueError(f"Unsupported content type: {parts.content_type}")
+                raise ValueError(f"Unsupported content type: {part.content_type}")
         if system_instruction:
             messages.insert(0, {"role": "system", "content": str(system_instruction)})
         return messages
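Following up on the header-inference suggestion above, here is a minimal sketch of what best-effort detection could look like, assuming the base64-encoded payload the data fetcher produces. The RIFF/WAVE magic bytes are the standard WAV container header; since WAV is currently the only AudioFormat member, it also serves as the fallback.

import base64

def infer_audio_format(b64_data: str, fallback: AudioFormat = AudioFormat.WAV) -> AudioFormat:
    # Decode just enough base64 characters to cover the 12-byte container header.
    header = base64.b64decode(b64_data[:16])
    if header[:4] == b"RIFF" and header[8:12] == b"WAVE":
        return AudioFormat.WAV
    return fallback  # best effort failed; fall back to something sensible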
@@ -321,7 +339,7 @@ def _generate(self, prompt: Union[str, MultimodalPrompt], **kwargs: Any) -> str:
             prompt = MultimodalPrompt.from_string(prompt)

         invoke_params = self.invocation_params
-        messages = self._build_messages(prompt, kwargs.get("instruction"))
+        messages = self._build_messages(prompt=prompt, system_instruction=kwargs.get("instruction"))
         if functions := kwargs.get("functions"):
             invoke_params["functions"] = functions
         if function_call := kwargs.get("function_call"):
phoenix/evals/utils.py

@@ -1,11 +1,16 @@
+import base64
 import json
+import subprocess
+from dataclasses import dataclass
+from enum import Enum
 from io import BytesIO
 from typing import Any, Dict, List, Optional, Tuple
 from urllib.error import HTTPError
 from urllib.request import urlopen
 from zipfile import ZipFile

 import pandas as pd
+import requests
 from tqdm.auto import tqdm

 # Rather than returning None, we return this string to indicate that the LLM output could not be
@@ -20,6 +25,16 @@
 _FUNCTION_NAME = "record_response"


+class AudioFormat(Enum):
+    WAV = "wav"
+
+
+@dataclass
+class Audio:

Review comment: It feels pretty hard for a user to know that they have to wrap their output in this wrapper class.

+    data: str
+    format: AudioFormat


 def download_benchmark_dataset(task: str, dataset_name: str) -> pd.DataFrame:
     """Downloads an Arize evals benchmark dataset as a pandas dataframe.
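To make the expected return type concrete (per the comment above, it is easy to miss): a hypothetical data_fetcher that reads a local WAV file and wraps it in Audio. The base64 string matches what the OpenAI input_audio payload, and fetch_gcloud_data below, use.

import base64

def fetch_local_wav(path: str) -> Audio:
    # Read raw bytes and base64-encode them for the input_audio message payload.
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return Audio(data=encoded, format=AudioFormat.WAV)

Passing fetch_local_wav as the data_fetcher to audio_classify would then work the same way as the Google Cloud fetcher.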
@@ -174,3 +189,52 @@ def _default_openai_function(
 def printif(condition: bool, *args: Any, **kwargs: Any) -> None:
     if condition:
         tqdm.write(*args, **kwargs)
+
+
+def fetch_gcloud_data(url: str) -> Audio:
+    token = None
+    try:
+        # Execute the gcloud command to fetch the access token
+        output = subprocess.check_output(
+            ["gcloud", "auth", "print-access-token"], stderr=subprocess.STDOUT
+        )
+        token = output.decode("UTF-8").strip()
+
+        # Ensure the token is not empty or None
+        if not token:
+            raise ValueError("Failed to retrieve a valid access token. Token is empty.")
+
+    except subprocess.CalledProcessError as e:
+        # Handle errors in the subprocess call
+        print(f"Error executing gcloud command: {e.output.decode('UTF-8').strip()}")
+        raise RuntimeError("Failed to execute gcloud auth command. You may not be logged in.")
+    except Exception as e:
+        # Catch any other exceptions and re-raise them with additional context
+        raise RuntimeError(f"An unexpected error occurred: {str(e)}") from e
+
+    # Set the token in the header
+    gcloud_header = {"Authorization": f"Bearer {token}"}
+
+    # Normalize console and gs:// URLs to the storage.googleapis.com API host
+    G_API_HOST = "https://storage.googleapis.com/"
+    if url.startswith("https://storage.cloud.google.com/"):
+        g_api_url = url.replace("https://storage.cloud.google.com/", G_API_HOST, 1)
+    elif url.startswith("gs://"):
+        g_api_url = url.replace("gs://", G_API_HOST, 1)
+    else:
+        g_api_url = url
+
+    # Get a response back; raise on a non-2xx status
+    response = requests.get(g_api_url, headers=gcloud_header)
+    response.raise_for_status()
+
+    encoded_string = base64.b64encode(response.content).decode("utf-8")
+
+    # Get the content type from the response headers, e.g. "audio/wav"
+    content_type = response.headers.get("Content-Type", "")
+
+    # Strip the "audio/" prefix to recover the audio format
+    audio_format = AudioFormat(content_type.replace("audio/", ""))
+
+    return Audio(data=encoded_string, format=audio_format)
Review comment: I still don't think we should be requiring the user to return our internal type; it feels like a structure they need to learn and import, which is clunky.
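One hypothetical shape for that, combining this thread with the header-inference idea above (none of this is in the PR): let the user's fetcher return plain bytes, and construct the internal Audio wrapper on our side.

import base64

def wrap_raw_audio(raw: bytes) -> Audio:
    # Hypothetical adapter: callers return plain bytes and never import Audio.
    # A best-effort header sniff (e.g. RIFF/WAVE for wav) would pick the format
    # once more AudioFormat members exist; today WAV is the only option.
    return Audio(data=base64.b64encode(raw).decode("utf-8"), format=AudioFormat.WAV)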