feat: support multiclass classification (#1629)
axiomofjoy authored Oct 16, 2023
1 parent 2c68368 commit 7d9ae34
Showing 13 changed files with 127 additions and 124 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -216,7 +216,7 @@ from phoenix.experimental.evals import (
RAG_RELEVANCY_PROMPT_RAILS_MAP,
OpenAIModel,
download_benchmark_dataset,
llm_eval_binary,
llm_classify,
)
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix, ConfusionMatrixDisplay

@@ -237,7 +237,7 @@ model = OpenAIModel(
temperature=0.0,
)
rails = list(RAG_RELEVANCY_PROMPT_RAILS_MAP.values())
df["eval_relevance"] = llm_eval_binary(df, model, RAG_RELEVANCY_PROMPT_TEMPLATE_STR, rails)
df["eval_relevance"] = llm_classify(df, model, RAG_RELEVANCY_PROMPT_TEMPLATE_STR, rails)
# The golden dataset's True/False labels map to "relevant"/"irrelevant",
# so the template output can be scored with scikit-learn (same label format)
y_true = df["relevant"].map({True: "relevant", False: "irrelevant"})
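
To make the commit's headline change concrete: llm_classify accepts any number of rails, not just two. The following is a minimal sketch of a three-class evaluation; the sentiment template, column name, and labels are illustrative assumptions, not part of this diff.

import pandas as pd

from phoenix.experimental.evals import OpenAIModel, llm_classify

# Hypothetical three-class template; any template whose variables match the
# dataframe's columns works, and the rails list enumerates the allowed classes.
SENTIMENT_PROMPT_TEMPLATE_STR = (
    "Classify the sentiment of the following text as positive, negative, or neutral.\n"
    "Text: {text}\n"
    "Answer with exactly one word: positive, negative, or neutral."
)

df = pd.DataFrame({"text": ["I love it", "It broke on day one", "It arrived on Tuesday"]})
model = OpenAIModel(temperature=0.0)
rails = ["positive", "negative", "neutral"]

# Each response is snapped to one of the rails, or NOT_PARSABLE if no single
# rail can be extracted from the model's output.
df["sentiment"] = llm_classify(df, model, SENTIMENT_PROMPT_TEMPLATE_STR, rails)
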
4 changes: 2 additions & 2 deletions scripts/rag/llama_index_w_evals_and_qa.py
@@ -32,7 +32,7 @@
from llama_index.node_parser import SimpleNodeParser
from llama_index.query_engine.multistep_query_engine import MultiStepQueryEngine
from llama_index.query_engine.transform_query_engine import TransformQueryEngine
from phoenix.experimental.evals import OpenAIModel, llm_eval_binary, run_relevance_eval
from phoenix.experimental.evals import OpenAIModel, llm_classify, run_relevance_eval
from phoenix.experimental.evals.functions.processing import concatenate_and_truncate_chunks
from phoenix.experimental.evals.models import BaseEvalModel
from phoenix.experimental.evals.templates import NOT_PARSABLE
@@ -302,7 +302,7 @@ def df_evals(

df = df.rename(columns={"query": "question", "response": "sampled_answer"})
# Q&A Eval: Did the LLM get the answer right? Checking the LLM
Q_and_A_classifications = llm_eval_binary(
Q_and_A_classifications = llm_classify(
dataframe=df,
template=template,
model=model,
3 changes: 2 additions & 1 deletion src/phoenix/experimental/evals/__init__.py
@@ -1,4 +1,4 @@
from .functions import llm_eval_binary, llm_generate, run_relevance_eval
from .functions import llm_classify, llm_eval_binary, llm_generate, run_relevance_eval
from .models import OpenAIModel, VertexAIModel
from .retrievals import compute_precisions_at_k
from .templates import (
@@ -18,6 +18,7 @@
__all__ = [
"compute_precisions_at_k",
"download_benchmark_dataset",
"llm_classify",
"llm_eval_binary",
"llm_generate",
"OpenAIModel",
4 changes: 2 additions & 2 deletions src/phoenix/experimental/evals/functions/__init__.py
@@ -1,4 +1,4 @@
from .binary import llm_eval_binary, run_relevance_eval
from .classify import llm_classify, llm_eval_binary, run_relevance_eval
from .generate import llm_generate

__all__ = ["llm_eval_binary", "run_relevance_eval", "llm_generate"]
__all__ = ["llm_classify", "llm_eval_binary", "run_relevance_eval", "llm_generate"]
src/phoenix/experimental/evals/functions/{binary.py → classify.py}
@@ -1,5 +1,6 @@
import logging
from typing import Any, Iterable, List, Optional, Set, Union, cast
import warnings
from typing import Any, Iterable, List, Optional, Union, cast

import pandas as pd

@@ -22,15 +23,15 @@
OPENINFERENCE_DOCUMENT_COLUMN_NAME = "attributes." + RETRIEVAL_DOCUMENTS


def llm_eval_binary(
def llm_classify(
dataframe: pd.DataFrame,
model: BaseEvalModel,
template: Union[PromptTemplate, str],
rails: List[str],
system_instruction: Optional[str] = None,
verbose: bool = False,
) -> List[str]:
"""Runs binary classifications using an LLM.
"""Classifies each input row of the dataframe using an LLM.
Args:
dataframe (pandas.DataFrame): A pandas dataframe in which each row represents a record to be
@@ -62,9 +63,62 @@ def llm_eval_binary(
eval_template = normalize_template(template)
prompts = map_template(dataframe, eval_template)
responses = verbose_model.generate(prompts.to_list(), instruction=system_instruction)
rails_set = set(rails)
printif(verbose, f"Snapping {len(responses)} responses to rails: {rails_set}")
return [_snap_to_rail(response, rails_set, verbose=verbose) for response in responses]
printif(verbose, f"Snapping {len(responses)} responses to rails: {rails}")
return [_snap_to_rail(response, rails, verbose=verbose) for response in responses]


def llm_eval_binary(
dataframe: pd.DataFrame,
model: BaseEvalModel,
template: Union[PromptTemplate, str],
rails: List[str],
system_instruction: Optional[str] = None,
verbose: bool = False,
) -> List[str]:
"""Performs a binary classification on the rows of the input dataframe using an LLM.
Args:
dataframe (pandas.DataFrame): A pandas dataframe in which each row represents a record to be
classified. All template variable names must appear as column names in the dataframe (extra
columns unrelated to the template are permitted).
template (Union[PromptTemplate, str]): The prompt template as either an instance of
PromptTemplate or a string. If the latter, the variable names should be surrounded by
curly braces so that a call to `.format` can be made to substitute variable values.
model (BaseEvalModel): An LLM model class.
rails (List[str]): A list of strings representing the possible output classes of the model's
predictions.
system_instruction (Optional[str], optional): An optional system message.
verbose (bool, optional): If True, prints detailed info to stdout such as model invocation
parameters and details about retries and snapping to rails. Default False.
Returns:
List[str]: A list of strings representing the predicted class for each record in the
dataframe. The list should have the same length as the input dataframe and its values should
be the entries in the rails argument or "NOT_PARSABLE" if the model's prediction could not
be parsed.
"""

warnings.warn(
"This function will soon be deprecated. "
"Use llm_classify instead, which has the same function signature "
"and provides support for multi-class classification "
"in addition to binary classification.",
category=DeprecationWarning,
stacklevel=2,
)
return llm_classify(
dataframe=dataframe,
model=model,
template=template,
rails=rails,
system_instruction=system_instruction,
verbose=verbose,
)
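
For callers migrating off the old name, here is a small sketch of what the shim above implies; df, model, template, and rails are placeholders for objects like those in the README example.

import warnings

# The old entry point still works, but it emits a DeprecationWarning and
# simply forwards its arguments to llm_classify.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    labels = llm_eval_binary(dataframe=df, model=model, template=template, rails=rails)
assert any(issubclass(w.category, DeprecationWarning) for w in caught)

# Preferred going forward: same signature, plus support for more than two rails.
labels = llm_classify(dataframe=df, model=model, template=template, rails=rails)
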


def run_relevance_eval(
@@ -161,7 +215,7 @@ def run_relevance_eval(
indexes.append(index)
expanded_queries.append(query)
expanded_documents.append(document)
predictions = llm_eval_binary(
predictions = llm_classify(
dataframe=pd.DataFrame(
{
query_column_name: expanded_queries,
@@ -188,92 +242,33 @@ def _get_contents_from_openinference_documents(documents: Iterable[Any]) -> List
return [doc.get(DOCUMENT_CONTENT) if isinstance(doc, dict) else None for doc in documents]


def _snap_to_rail(string: str, rails: Set[str], verbose: bool = False) -> str:
def _snap_to_rail(raw_string: str, rails: List[str], verbose: bool = False) -> str:
"""
Snaps a string to the nearest rail, or returns None if the string cannot be snapped to a
rail.
Snaps a string to the nearest rail, or returns NOT_PARSABLE if the string
cannot be snapped to a rail.
Args:
string (str): An input to be snapped to a rail.
raw_string (str): An input to be snapped to a rail.
rails (Set[str]): The target set of strings to snap to.
rails (List[str]): The list of rail strings to snap to.
Returns:
str: A string from the rails argument or None if the input string could not be snapped.
str: A string from the rails argument or NOT_PARSABLE if the input
string could not be snapped.
"""

processed_string = string.strip()
rails_list = list(rails)
rail = _extract_rail(processed_string, rails_list[0], rails_list[1])
if not rail:
printif(verbose, f"- Cannot snap {repr(string)} to rails: {rails}")
logger.warning(
f"LLM output cannot be snapped to rails {list(rails)}, returning {NOT_PARSABLE}. "
f'Output: "{string}"'
)
snap_string = raw_string.lower()
rails = list(set(rails))
rails = [rail.lower() for rail in rails]
rails.sort(key=len, reverse=True)
found_rails = set()
for rail in rails:
if rail in snap_string:
found_rails.add(rail)
snap_string = snap_string.replace(rail, "")
if len(found_rails) != 1:
printif(verbose, f"- Cannot snap {repr(raw_string)} to rails")
return NOT_PARSABLE
else:
printif(verbose, f"- Snapped {repr(string)} to rail: {rail}")
rail = list(found_rails)[0]
printif(verbose, f"- Snapped {repr(raw_string)} to rail: {rail}")
return rail
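
For readers skimming the diff, the new snapping logic boils down to a longest-match-first search over the rails. Here is a self-contained sketch of the same idea, for illustration only; NOT_PARSABLE stands in for the constant imported from phoenix.experimental.evals.templates.

from typing import List

NOT_PARSABLE = "NOT_PARSABLE"  # stand-in for the library constant

def snap_to_rail(raw_string: str, rails: List[str]) -> str:
    """Return the single rail contained in raw_string, else NOT_PARSABLE."""
    remaining = raw_string.lower()
    # Longest rails first, so "irrelevant" is matched before its substring "relevant".
    ordered_rails = sorted({rail.lower() for rail in rails}, key=len, reverse=True)
    found = set()
    for rail in ordered_rails:
        if rail in remaining:
            found.add(rail)
            remaining = remaining.replace(rail, "")
    return found.pop() if len(found) == 1 else NOT_PARSABLE

assert snap_to_rail("The answer is: irrelevant.", ["relevant", "irrelevant"]) == "irrelevant"
assert snap_to_rail("positive!!", ["positive", "negative", "neutral"]) == "positive"
assert snap_to_rail("relevant ... irrelevant", ["relevant", "irrelevant"]) == NOT_PARSABLE
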


def _extract_rail(string: str, positive_rail: str, negative_rail: str) -> Optional[str]:
"""
Extracts the right rails text from the llm output. If the rails have overlapping characters,
(e.x. "regular" and "irregular"), it also ensures that the correct rail is returned.
Args:
string (str): An input to be snapped to a rail.
positive_rail (str): The positive rail (e.x. toxic)
negative_rail (str): The negative rail. (e.x. non-toxic)
Returns:
str: A string from the rails or None if the input string could not be extracted.
Examples:
given: positive_rail = "irregular", negative_rail = "regular"
string = "irregular"
Output: "irregular"
string = "regular"
Output: "regular"
string = "regular,:....random"
Output: "regular"
string = "regular..irregular" - contains both rails
Output: None
string = "Irregular"
Output: "irregular"
"""

# Convert the inputs to lowercase for case-insensitive matching
string = string.lower()
positive_rail = positive_rail.lower()
negative_rail = negative_rail.lower()

positive_pos, negative_pos = string.find(positive_rail), string.find(negative_rail)

# If both positive and negative rails are in the string
if positive_pos != -1 and negative_pos != -1:
# If either one is a substring of the other, return the longer one
# e.x. "regular" and "irregular"
if positive_pos < negative_pos < positive_pos + len(
positive_rail
) or negative_pos < positive_pos < negative_pos + len(negative_rail):
# Return the longer of the rails since it means the LLM returned the longer one
return max(positive_rail, negative_rail, key=len)
else:
# If both rails values are in the string, we cannot determine which to return
return None
# If only positive is in string
elif positive_pos != -1:
return positive_rail
# If only negative is in the string
elif negative_pos != -1:
return negative_rail
return None
@@ -9,15 +9,15 @@
NOT_PARSABLE,
RAG_RELEVANCY_PROMPT_TEMPLATE_STR,
OpenAIModel,
llm_eval_binary,
llm_classify,
run_relevance_eval,
)
from phoenix.experimental.evals.functions.binary import _snap_to_rail
from phoenix.experimental.evals.functions.classify import _snap_to_rail
from phoenix.experimental.evals.models.openai import OPENAI_API_KEY_ENVVAR_NAME


@responses.activate
def test_llm_eval_binary(monkeypatch: pytest.MonkeyPatch):
def test_llm_classify(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv(OPENAI_API_KEY_ENVVAR_NAME, "sk-0123456789")
dataframe = pd.DataFrame(
[
@@ -61,7 +61,7 @@ def test_llm_eval_binary(monkeypatch: pytest.MonkeyPatch):
with patch.object(OpenAIModel, "_init_tiktoken", return_value=None):
model = OpenAIModel()

relevance_classifications = llm_eval_binary(
relevance_classifications = llm_classify(
dataframe=dataframe,
template=RAG_RELEVANCY_PROMPT_TEMPLATE_STR,
model=model,
@@ -72,7 +72,7 @@


@responses.activate
def test_llm_eval_binary_prints_to_stdout_with_verbose_flag(monkeypatch: pytest.MonkeyPatch, capfd):
def test_llm_classify_prints_to_stdout_with_verbose_flag(monkeypatch: pytest.MonkeyPatch, capfd):
monkeypatch.setenv(OPENAI_API_KEY_ENVVAR_NAME, "sk-0123456789")
dataframe = pd.DataFrame(
[
@@ -116,7 +116,7 @@ def test_llm_eval_binary_prints_to_stdout_with_verbose_flag(monkeypatch: pytest.
with patch.object(OpenAIModel, "_init_tiktoken", return_value=None):
model = OpenAIModel()

llm_eval_binary(
llm_classify(
dataframe=dataframe,
template=RAG_RELEVANCY_PROMPT_TEMPLATE_STR,
model=model,
@@ -134,7 +134,7 @@ def test_llm_eval_binary_prints_to_stdout_with_verbose_flag(monkeypatch: pytest.
assert "sk-0123456789" not in out, "Credentials should not be printed out in cleartext"


def test_llm_eval_binary_shows_retry_info_with_verbose_flag(monkeypatch: pytest.MonkeyPatch, capfd):
def test_llm_classify_shows_retry_info_with_verbose_flag(monkeypatch: pytest.MonkeyPatch, capfd):
monkeypatch.setenv(OPENAI_API_KEY_ENVVAR_NAME, "sk-0123456789")
dataframe = pd.DataFrame(
[
@@ -163,7 +163,7 @@ def test_llm_eval_binary_shows_retry_info_with_verbose_flag(monkeypatch: pytest.
stack.enter_context(patch.object(OpenAIModel, "_init_tiktoken", return_value=None))
stack.enter_context(patch.object(model._openai.ChatCompletion, "create", mock_openai))
stack.enter_context(pytest.raises(model._openai_error.ServiceUnavailableError))
llm_eval_binary(
llm_classify(
dataframe=dataframe,
template=RAG_RELEVANCY_PROMPT_TEMPLATE_STR,
model=model,
@@ -183,7 +183,7 @@ def test_llm_eval_binary_shows_retry_info_with_verbose_flag(monkeypatch: pytest.
assert "Failed attempt 5" not in out, "Maximum retries should not be exceeded"


def test_llm_eval_binary_does_not_persist_verbose_flag(monkeypatch: pytest.MonkeyPatch, capfd):
def test_llm_classify_does_not_persist_verbose_flag(monkeypatch: pytest.MonkeyPatch, capfd):
monkeypatch.setenv(OPENAI_API_KEY_ENVVAR_NAME, "sk-0123456789")
dataframe = pd.DataFrame(
[
@@ -209,7 +209,7 @@ def test_llm_eval_binary_does_not_persist_verbose_flag(monkeypatch: pytest.Monke
stack.enter_context(patch.object(OpenAIModel, "_init_tiktoken", return_value=None))
stack.enter_context(patch.object(model._openai.ChatCompletion, "create", mock_openai))
stack.enter_context(pytest.raises(model._openai_error.APIError))
llm_eval_binary(
llm_classify(
dataframe=dataframe,
template=RAG_RELEVANCY_PROMPT_TEMPLATE_STR,
model=model,
@@ -231,7 +231,7 @@ def test_llm_eval_binary_does_not_persist_verbose_flag(monkeypatch: pytest.Monke
stack.enter_context(patch.object(OpenAIModel, "_init_tiktoken", return_value=None))
stack.enter_context(patch.object(model._openai.ChatCompletion, "create", mock_openai))
stack.enter_context(pytest.raises(model._openai_error.APIError))
llm_eval_binary(
llm_classify(
dataframe=dataframe,
template=RAG_RELEVANCY_PROMPT_TEMPLATE_STR,
model=model,
@@ -409,3 +409,10 @@ def test_overlapping_rails():
# Both rails are present, cannot parse
assert _snap_to_rail("relevant...irrelevant", ["irrelevant", "relevant"]) is NOT_PARSABLE
assert _snap_to_rail("Irrelevant", ["relevant", "irrelevant"]) == "irrelevant"
# One rail appears twice
assert _snap_to_rail("relevant...relevant", ["irrelevant", "relevant"]) == "relevant"
assert _snap_to_rail("b b", ["a", "b", "c"]) == "b"
# More than two rails
assert _snap_to_rail("a", ["a", "b", "c"]) == "a"
assert _snap_to_rail(" abc", ["a", "ab", "abc"]) == "abc"
assert _snap_to_rail("abc", ["abc", "a", "ab"]) == "abc"
6 changes: 3 additions & 3 deletions tutorials/evals/evaluate_QA_classifications.ipynb
@@ -74,7 +74,7 @@
"from phoenix.experimental.evals import (\n",
" OpenAIModel,\n",
" download_benchmark_dataset,\n",
" llm_eval_binary,\n",
" llm_classify,\n",
")\n",
"from pycm import ConfusionMatrix\n",
"from sklearn.metrics import classification_report\n",
@@ -670,7 +670,7 @@
"# It will remove text such as \",,,\" or \"...\", anything not the\n",
"# binary value expected from the template\n",
"rails = list(templates.QA_PROMPT_RAILS_MAP.values())\n",
"Q_and_A_classifications = llm_eval_binary(\n",
"Q_and_A_classifications = llm_classify(\n",
" dataframe=df_sample,\n",
" template=templates.QA_PROMPT_TEMPLATE_STR,\n",
" model=model,\n",
@@ -776,7 +776,7 @@
}
],
"source": [
"Q_and_A_classifications = llm_eval_binary(\n",
"Q_and_A_classifications = llm_classify(\n",
" dataframe=df_sample,\n",
" template=templates.QA_PROMPT_TEMPLATE_STR,\n",
" model=model,\n",
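
Because outputs may now span more than two classes, downstream scoring in these tutorials can lean on standard multiclass metrics. A brief sketch follows; the label names are hypothetical, not taken from the notebook.

from sklearn.metrics import classification_report, confusion_matrix

# true_labels would come from a golden dataset and predicted_labels from
# llm_classify; "correct" / "incorrect" / "unsure" are made-up rails.
true_labels = ["correct", "incorrect", "unsure", "correct"]
predicted_labels = ["correct", "incorrect", "correct", "correct"]
labels = ["correct", "incorrect", "unsure"]

print(classification_report(true_labels, predicted_labels, labels=labels))
print(confusion_matrix(true_labels, predicted_labels, labels=labels))
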