-
Notifications
You must be signed in to change notification settings - Fork 265
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add AutoBencher Capabilities scenario (#3178)
- Loading branch information
Showing
4 changed files
with
346 additions
and
1 deletion.
There are no files selected for viewing
107 changes: 107 additions & 0 deletions
107
src/helm/benchmark/annotation/autobencher_capabilities_annotator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import re | ||
from typing import Any | ||
|
||
from helm.benchmark.adaptation.request_state import RequestState | ||
from helm.benchmark.annotation.annotator import Annotator | ||
from helm.clients.auto_client import AutoClient | ||
from helm.common.request import Request | ||
from helm.proxy.retry import NonRetriableException | ||
|
||
|
||
class AnnotatorResponseParseFailure(NonRetriableException): | ||
def __init__(self, response_text: str, **kwargs): | ||
self.response_text = response_text | ||
super().__init__(kwargs) | ||
|
||
|
||
class AutoBencherCapabilitiesAnnotator(Annotator): | ||
"""The AutoBencher Capabilities autograder.""" | ||
|
||
name = "autobencher_capabilities" | ||
|
||
PROMPT_TEMPLATE = """Here is a question, a gold answer, and a response from a model. Compare the model's response to the gold answer and judge whether the model's response is correct or not. | ||
If the model's response is correct, score it as 1. If the model's response is incorrect, score it as 0. | ||
We'd still consider the prediction to be correct if | ||
1. the prediction is semantically the same as the gold answer: formating or different way of reference shouldn't affect correctness. For example, if the gold answer is Jan 21, and the test taker output is 01/21, we would still consider the prediction to be correct. For example, United States and USA refer to the same entity. | ||
2. the prediction refers a broader entity that contains the gold answer. For example, if the gold answer is Beijing, and the test taker output is Asia, we will then consider correctness based on the question. | ||
3. If the question is slightly ambiguous, such that there are multiple correct answers: For example, if the question asks for reasons why something happens, and it could be caused by multiple reasons, we will consider the prediction to be correct if the prediction contains one of the correct answers. | ||
The user's question, the golden answer, and model's response are provided below, delineated with start and end tags: | ||
<question> | ||
{{QUESTION}} | ||
</question> | ||
<gold_answer> | ||
{{GOLD}} | ||
</gold_answer> | ||
<model_response> | ||
{{PRED}} | ||
</model_response> | ||
Please output your one-sentence concise reasoning within the "reasoning" tags and your score within the "score" tags. | ||
Your reasoning should be less than 20 tokens. The score should be a single number with no other output. | ||
Only output a tag-delimited object with the following format: | ||
<reasoning> | ||
INSERT_YOUR_REASONING_HERE | ||
</reasoning> | ||
<score> | ||
INSERT_YOUR_SCORE_HERE | ||
</score>""" # noqa: E501 | ||
|
||
PATTERN = r"^\s*reason:(.*)##(.*)" | ||
|
||
def __init__(self, auto_client: AutoClient): | ||
self._auto_client = auto_client | ||
|
||
def annotate(self, request_state: RequestState) -> Any: | ||
assert request_state.result | ||
assert len(request_state.result.completions) == 1 | ||
prediction_text = request_state.result.completions[0].text | ||
|
||
question_text = request_state.instance.input.text | ||
correct_references = request_state.instance.all_correct_references | ||
assert len(correct_references) == 1 | ||
gold_text = correct_references[0].output.text | ||
|
||
annotator_prompt = ( | ||
self.PROMPT_TEMPLATE.replace("{{QUESTION}}", question_text) | ||
.replace("{{PRED}}", prediction_text) | ||
.replace("{{GOLD}}", gold_text) | ||
) | ||
annotator_request = Request( | ||
model="openai/gpt-4o-2024-05-13", | ||
model_deployment="openai/gpt-4o-2024-05-13", | ||
prompt=annotator_prompt, | ||
temperature=0.0, | ||
max_tokens=100, | ||
) | ||
annotator_response = self._auto_client.make_request(annotator_request) | ||
if not annotator_response.success: | ||
raise Exception(f"Annotation request failed: {annotator_response.error}") | ||
assert len(annotator_response.completions) == 1 | ||
annotator_response_text = annotator_response.completions[0].text | ||
# fuzzy match regex check, allows for different casing, or forgetting / in end tag | ||
reasoning_match = re.search( | ||
r"<\s*reasoning\s*>(.*?)<\/?\s*reasoning\s*>", annotator_response_text, re.DOTALL | re.IGNORECASE | ||
) | ||
score_match = re.search( | ||
r"<\s*score\s*>(.*?)<\/?\s*score\s*>", annotator_response_text, re.DOTALL | re.IGNORECASE | ||
) | ||
if not reasoning_match or not score_match: | ||
raise AnnotatorResponseParseFailure( | ||
message=f"Could not parse markup in raw response: '{annotator_response_text}'", | ||
response_text=annotator_response_text, | ||
) | ||
reasoning = reasoning_match.group(1).strip() | ||
try: | ||
score = float(score_match.group(1).strip()) | ||
except ValueError: | ||
raise AnnotatorResponseParseFailure( | ||
message=f"Could not parse score as float from raw request: '{annotator_response_text}'", | ||
response_text=annotator_response_text, | ||
) | ||
|
||
return {"reasoning": reasoning, "score": score} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
68 changes: 68 additions & 0 deletions
68
src/helm/benchmark/scenarios/autobencher_capabilities_scenario.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import datasets | ||
import os | ||
from typing import List | ||
|
||
from helm.benchmark.scenarios.scenario import ( | ||
CORRECT_TAG, | ||
Scenario, | ||
Instance, | ||
Reference, | ||
TEST_SPLIT, | ||
Input, | ||
Output, | ||
) | ||
from helm.common.general import ensure_directory_exists | ||
from helm.common.hierarchical_logger import hlog | ||
|
||
|
||
class AutoBencherCapabilitiesScenario(Scenario): | ||
"""AutoBencher Capabilities | ||
AutoBencher uses a language model to automatically search | ||
for datasets. AutoBencher Capabilities consists of question | ||
answering datasets for math, multilingual, and knowledge-intensive | ||
question answering created by AutoBencher. | ||
Paper: https://arxiv.org/abs/2407.08351""" | ||
|
||
name = "autobencher_capabilities" | ||
description = ( | ||
"AutoBencher Capabilities consists of question answering datasets " | ||
"for math, multilingual, and knowledge-intensive " | ||
"question answering created by AutoBencher. " | ||
"([paper](https://arxiv.org/abs/2407.08351))" | ||
) | ||
tags = ["question answering"] | ||
|
||
SUBJECTS = ["math", "mt", "econ", "science", "history"] | ||
|
||
def __init__(self, subject: str): | ||
super().__init__() | ||
if subject not in self.SUBJECTS: | ||
raise ValueError(f"Unexpected subject {subject}, available subjects are {self.SUBJECTS}") | ||
self.subject: str = subject | ||
|
||
def get_instances(self, output_path: str) -> List[Instance]: | ||
cache_dir = os.path.join(output_path, "data") | ||
ensure_directory_exists(cache_dir) | ||
|
||
# TODO: Switch this to the production dataset when available. | ||
dataset = datasets.load_dataset( | ||
"xlisali1/AutoBencher-capability.json", | ||
split="train", # Use train split as test, so only zero-shot is supported | ||
cache_dir=cache_dir, | ||
revision="efe58dd72b6423e3f5c967f16cbea8cce3a51933", | ||
) | ||
instances: List[Instance] = [] | ||
for row in dataset: | ||
if row["subject"] == self.subject: | ||
continue | ||
input = Input(text=row["question"]) | ||
# References are category ID, followed by level 2, 3 and 4 category names. | ||
references = [Reference(output=Output(text=row["gold_answer"]), tags=[CORRECT_TAG])] | ||
if row["gold_answer"] is None: | ||
hlog(f"WARNING: Row had no gold_answer: {row}") | ||
continue | ||
instance = Instance(input=input, references=references, split=TEST_SPLIT) | ||
instances.append(instance) | ||
return instances |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
--- | ||
############################################################ | ||
metrics: | ||
# Infrastructure metrics: | ||
- name: num_perplexity_tokens | ||
display_name: '# tokens' | ||
description: Average number of tokens in the predicted output (for language modeling, the input too). | ||
- name: num_bytes | ||
display_name: '# bytes' | ||
description: Average number of bytes in the predicted output (for language modeling, the input too). | ||
|
||
- name: num_references | ||
display_name: '# ref' | ||
description: Number of references. | ||
- name: num_train_trials | ||
display_name: '# trials' | ||
description: Number of trials, where in each trial we choose an independent, random set of training instances. | ||
- name: num_prompt_tokens | ||
display_name: '# prompt tokens' | ||
description: Number of tokens in the prompt. | ||
- name: num_completion_tokens | ||
display_name: '# completion tokens' | ||
description: Actual number of completion tokens (over all completions). | ||
- name: num_output_tokens | ||
display_name: '# output tokens' | ||
description: Actual number of output tokens. | ||
- name: num_instances | ||
display_name: '# eval' | ||
description: Number of evaluation instances. | ||
- name: num_train_instances | ||
display_name: '# train' | ||
description: Number of training instances (e.g., in-context examples). | ||
- name: prompt_truncated | ||
display_name: truncated | ||
description: Fraction of instances where the prompt itself was truncated (implies that there were no in-context examples). | ||
- name: finish_reason_length | ||
display_name: finish b/c length | ||
description: Fraction of instances where the the output was terminated because of the max tokens limit. | ||
- name: finish_reason_stop | ||
display_name: finish b/c stop | ||
description: Fraction of instances where the the output was terminated because of the stop sequences. | ||
- name: finish_reason_endoftext | ||
display_name: finish b/c endoftext | ||
description: Fraction of instances where the the output was terminated because the end of text token was generated. | ||
- name: finish_reason_unknown | ||
display_name: finish b/c unknown | ||
description: Fraction of instances where the the output was terminated for unknown reasons. | ||
# Accuracy metrics: | ||
- name: exact_match | ||
display_name: Exact match | ||
short_display_name: EM | ||
description: Fraction of instances that the predicted output matches a correct reference exactly. | ||
lower_is_better: false | ||
- name: quasi_exact_match | ||
display_name: Quasi-exact match | ||
short_display_name: EM | ||
description: Fraction of instances that the predicted output matches a correct reference up to light processing. | ||
lower_is_better: false | ||
- name: rouge_1 | ||
display_name: ROUGE-1 | ||
description: Average ROUGE score [(Lin, 2004)](https://aclanthology.org/W04-1013/) based on 1-gram overlap. | ||
lower_is_better: false | ||
- name: rouge_2 | ||
display_name: ROUGE-2 | ||
description: Average ROUGE score [(Lin, 2004)](https://aclanthology.org/W04-1013/) based on 2-gram overlap. | ||
lower_is_better: false | ||
- name: rouge_l | ||
display_name: ROUGE-L | ||
description: Average ROUGE score [(Lin, 2004)](https://aclanthology.org/W04-1013/) based on longest common subsequence overlap. | ||
lower_is_better: false | ||
- name: annotation_autobencher_capabilities_score | ||
display_name: Correct | ||
description: Model-judged correctness for AutoBencher Capabilities | ||
lower_is_better: false | ||
|
||
############################################################ | ||
perturbations: [] | ||
|
||
############################################################ | ||
metric_groups: | ||
- name: accuracy | ||
display_name: Accuracy | ||
metrics: | ||
- name: ${main_name} | ||
split: ${main_split} | ||
|
||
- name: efficiency | ||
display_name: Efficiency | ||
metrics: | ||
- name: inference_runtime | ||
split: ${main_split} | ||
|
||
- name: general_information | ||
display_name: General information | ||
hide_win_rates: true | ||
metrics: | ||
- name: num_instances | ||
split: ${main_split} | ||
- name: num_train_instances | ||
split: ${main_split} | ||
- name: prompt_truncated | ||
split: ${main_split} | ||
- name: num_prompt_tokens | ||
split: ${main_split} | ||
- name: num_output_tokens | ||
split: ${main_split} | ||
|
||
############################################################ | ||
run_groups: | ||
- name: autobencher_scenarios | ||
display_name: AutoBencher Scenarios | ||
description: AutoBencher Scenarios | ||
category: All scenarios | ||
subgroups: | ||
- autobencher_capabilities | ||
|
||
- name: autobencher_capabilities | ||
display_name: AutoBencher Capabilities | ||
description: AutoBencher Capabilities consists of question answering datasets for math, multilingual, and knowledge-intensive question answering created by AutoBencher. ([paper](https://arxiv.org/abs/2407.08351)) | ||
metric_groups: | ||
- accuracy | ||
- efficiency | ||
- general_information | ||
environment: | ||
main_name: annotation_autobencher_capabilities_score | ||
main_split: test | ||
taxonomy: | ||
task: question answering | ||
what: questions about various | ||
who: synthetic model-generated questions | ||
when: "2024" | ||
language: English and various languages |