Added COT Metric and Adapter to MMLU Pro #3162
@@ -0,0 +1,93 @@
import re
from typing import List, Optional

from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.metrics.metric import Metric
from helm.benchmark.metrics.metric_name import MetricName
from helm.benchmark.metrics.metric_service import MetricService
from helm.benchmark.metrics.statistic import Stat

def extract_answer(output_text: str) -> Optional[str]:
    """
    Extracts the answer from the output text using two exact regex patterns.
    Returns None if no valid answer is found.

    Args:
        output_text (str): The text from which to extract the answer.

    Returns:
        Optional[str]: The extracted answer (A-J) if found, otherwise None.
    """
    # First regex: Matches "answer is (A-J)" with optional parentheses
    match = re.search(r"answer is \(?([A-J])\)?", output_text)
    if match:
        return match.group(1)

    # Second regex: Matches "answer:" (with "a" or "A") followed by the letter (A-J),
    # allowing optional leading "." characters and optional parentheses around the letter
    match = re.search(r"\.*[aA]nswer:\s*\(?([A-J])\)?", output_text)
    if match:
        return match.group(1)

    # If neither regex matches, return None
    return None


class ChainOfThoughtMetric(Metric):
    """
    This metric focuses on structured reasoning and the accuracy of extracted answers.
    It compares model outputs against correct answers provided in a multiple-choice
    format and returns a score indicating the correctness of the generated response.
    """

    def evaluate_generation(
        self,
        adapter_spec: AdapterSpec,
        request_state: RequestState,
        metric_service: MetricService,
        eval_cache_path: str,
    ) -> List[Stat]:
        """
        Evaluate the generated output for chain-of-thought reasoning accuracy.

        The method extracts the model's output, determines the correct answer
        from the provided references, and compares the two to compute a binary score.

        Args:
            adapter_spec (AdapterSpec): Specification of the adapter used for the evaluation.
            request_state (RequestState): The state of the current request, including
                the input instance, output results, and references.
            metric_service (MetricService): A service used to compute metrics if needed.
            eval_cache_path (str): Path to the evaluation cache for storing or retrieving data.

        Returns:
            List[Stat]: A list containing a single `Stat` object with the correctness
                score (1 for correct, 0 for incorrect) under the metric
                name "chain_of_thought_correct".
        """
        # Assert that completions exist if the result is not None
        assert (
            request_state.result is not None and request_state.result.completions
        ), "Request state result must have completions."

        # Set output_text if the assertion passes
        output_text = request_state.result.completions[0].text

        # Extract the answer using the updated logic
        extracted_answer = extract_answer(output_text)

        # Find the correct answer from references by translating index to letter
        correct_answer = None
        for index, option in enumerate(request_state.instance.references):
            if option.is_correct:
                correct_answer = chr(65 + index)  # Translate index (0 -> A, 1 -> B, etc.)
                break

        # Raise an exception if no correct answer is found
        if correct_answer is None:
            raise ValueError(f"No correct answer found for instance ID {request_state.instance.id}")

        # Compare extracted answer with the correct answer and compute the score
        score = 1 if extracted_answer == correct_answer else 0
        return [Stat(MetricName("chain_of_thought_correct")).add(score)]
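
A quick usage sketch of extract_answer (illustrative only, not part of the diff; the example strings are made up and the second case assumes the character-class form of the second regex):

# Usage sketch (not in the PR): what each pattern is intended to catch.
assert extract_answer("Let's think step by step... the answer is (B).") == "B"  # first pattern
assert extract_answer("Reasoning omitted. Answer: C") == "C"  # second pattern
assert extract_answer("I am not sure.") is None  # no match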
@@ -21,11 +21,11 @@
    get_generative_harms_metric_specs,
    get_generic_metric_specs,
    get_open_ended_generation_metric_specs,
    MetricSpec,
)
from helm.benchmark.run_spec import RunSpec, run_spec_function
from helm.benchmark.runner import get_benchmark_output_path
from helm.benchmark.scenarios.scenario import ScenarioSpec, get_scenario_cache_path
from helm.benchmark.metrics.metric import MetricSpec


@run_spec_function("narrative_qa")
@@ -137,23 +137,92 @@ def get_mmlu_spec(subject: str, method: str = ADAPT_MULTIPLE_CHOICE_JOINT) -> Ru

@run_spec_function("mmlu_pro") | ||
def get_mmlu_pro_spec(subject: str) -> RunSpec: | ||
def get_mmlu_pro_spec(subject: str, use_chain_of_thought: str = "False", use_few_shot: str = "False") -> RunSpec: | ||
# Convert to bools and remove the str versions | ||
use_chain_of_thought_bool: bool = use_chain_of_thought == "True" | ||
use_few_shot_bool: bool = use_few_shot == "True" | ||
del use_chain_of_thought | ||
del use_few_shot | ||
|
||
scenario_spec = ScenarioSpec( | ||
class_name="helm.benchmark.scenarios.mmlu_pro.MMLUProScenario", args={"subject": subject} | ||
) | ||
max_train_instance_num = 5 if use_few_shot_bool else 0 | ||
|
||
adapter_spec = get_multiple_choice_adapter_spec( | ||
method=ADAPT_MULTIPLE_CHOICE_JOINT, | ||
instructions=f"The following are multiple choice questions (with answers) about {subject.replace('_', ' ')}.", | ||
input_noun="Question", | ||
output_noun="Answer", | ||
) | ||
    if use_few_shot_bool:
        if use_chain_of_thought_bool:
            adapter_spec = get_multiple_choice_adapter_spec(
                method=ADAPT_MULTIPLE_CHOICE_JOINT_CHAIN_OF_THOUGHT,
                max_tokens=1000,  # following original repo
                max_train_instances=max_train_instance_num,
                instructions=(

Review comment: Use the instructions from the paper.

                    "Here are some example questions from experts. "
                    "An explanation is given before the final answer. "
                    "Answer the final question yourself, giving your reasoning beforehand."
                ),
                input_noun="Question",
                input_suffix="\nChoices: \n",
                reference_prefix="(A) ",

Review comment: Delete.

                chain_of_thought_prefix="Let's think step by step: ",
                chain_of_thought_suffix="The correct answer is ",

Review comment: I think this results in adding the answer twice to the prompt, e.g. "The answer is (A). The correct answer is A". We need to deal with this somehow, probably in the adapter. I'm okay with deferring this fix to another pull request.

                output_noun="",  # will be overwritten with output_prefix
                output_prefix="",
                global_suffix=(

Review comment: Follow the paper - they don't use this suffix.
Review comment: Delete global_suffix.

                    "Give step by step reasoning before you answer, and when you’re ready to answer, "
                    'please use the format "The correct answer is (insert answer here)":'
                ),
            )
        else:
            adapter_spec = get_multiple_choice_adapter_spec(
                method=ADAPT_MULTIPLE_CHOICE_JOINT,
                max_train_instances=max_train_instance_num,
                instructions=(
                    "Here are some example questions from experts. "
                    "An explanation is given before the final answer. "
                    "Answer the final question yourself, giving your reasoning beforehand."
                ),
                input_noun="Question",
                input_suffix="\nChoices: \n",
                reference_prefix="(A) ",
                output_noun="",  # will be overwritten with output_prefix
                output_prefix="The correct answer is ",
            )
    else:
        if use_chain_of_thought_bool:
            adapter_spec = AdapterSpec(
                method=ADAPT_MULTIPLE_CHOICE_JOINT_CHAIN_OF_THOUGHT,
                max_train_instances=max_train_instance_num,
                max_tokens=1000,
                input_prefix="What is the correct answer to this question: ",
                input_suffix="\nChoices:\n",
                output_prefix="",
                reference_prefix="(A) ",
                global_suffix=(
                    "Let’s think step by step. Based on your reasoning, what is the single, "
                    "most likely answer choice? Format your response as follows: "
                    '"The correct answer is (insert answer here)".'
                ),
            )
        else:
            adapter_spec = AdapterSpec(
                method=ADAPT_MULTIPLE_CHOICE_JOINT,
                max_train_instances=max_train_instance_num,
                max_tokens=1000,
                input_prefix="What is the correct answer to this question: ",
                input_suffix="\nChoices:\n",
                output_prefix="",
                reference_prefix="(A) ",
                global_suffix=("Format your response as follows: " '"The correct answer is (insert answer here)".'),
            )

    return RunSpec(
        name=f"mmlu_pro:subject={subject}",
        name=f"gpqa:subset={subject},use_chain_of_thought={use_chain_of_thought_bool}",

Review comment: Change "gpqa" to "mmlu_pro".

        scenario_spec=scenario_spec,
        adapter_spec=adapter_spec,
        metric_specs=get_exact_match_metric_specs(),
        metric_specs=get_exact_match_metric_specs()
        + [
            MetricSpec(class_name="helm.benchmark.metrics.chain_of_thought_metric.ChainOfThoughtMetric", args={}),

Review comment: Only add this metric if chain of thought is used.
Review comment: Address this in GPQA as well.

        ],
        groups=["mmlu_pro"],
    )
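
For context, a hypothetical invocation sketch (not part of the diff): the flags arrive as strings from the run entry, e.g. "mmlu_pro:subject=law,use_chain_of_thought=True,use_few_shot=True", and are converted to bools at the top of the function. The subject "law" and the run-entry string are illustrative only.

# Hypothetical invocation sketch (not in the PR):
run_spec = get_mmlu_pro_spec(subject="law", use_chain_of_thought="True", use_few_shot="True")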
@@ -414,7 +483,10 @@ def get_gpqa_spec(subset: str, use_chain_of_thought: str = "False", use_few_shot

        name=f"gpqa:subset={subset},use_chain_of_thought={use_chain_of_thought_bool}",
        scenario_spec=scenario_spec,
        adapter_spec=adapter_spec,
        metric_specs=get_exact_match_metric_specs(),  # TODO: update this after cot metric is ready
        metric_specs=get_exact_match_metric_specs()
        + [
            MetricSpec(class_name="helm.benchmark.metrics.chain_of_thought_metric.ChainOfThoughtMetric", args={}),

Review comment: Only add this metric if chain of thought is used.

        ],
        groups=["gpqa"],
    )
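
One way the reviewer's request could be handled in both run specs (a sketch, not the PR's code) is to build the metric list conditionally before constructing the RunSpec:

# Sketch of the reviewer's suggestion (not in the PR): only attach the
# chain-of-thought metric when chain-of-thought prompting is in use.
metric_specs = get_exact_match_metric_specs()
if use_chain_of_thought_bool:
    metric_specs = metric_specs + [
        MetricSpec(
            class_name="helm.benchmark.metrics.chain_of_thought_metric.ChainOfThoughtMetric",
            args={},
        )
    ]

and then pass metric_specs=metric_specs into the RunSpec constructor.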

Review comment: Somehow didn't catch this before, but please rename this file to …
@@ -93,6 +93,11 @@ metrics:

    short_display_name: IFEval Strict Acc
    description: Fraction of instructions in the instance that are correctly followed.
    lower_is_better: false
  - name: chain_of_thought_correct
    display_name: COT correct

Review comment: "Chain of thought correctness" or something more descriptive like that.

    short_display_name: COT correct
    description: TBD.

Review comment: Add description.

    lower_is_better: false

############################################################
perturbations: []
@@ -135,7 +140,6 @@ run_groups:

    subgroups:
      - mmlu_pro
      - gpqa
      - ifeval

Review comment: Don't delete IFEval.

  - name: mmlu_pro
    display_name: MMLU-Pro
@@ -162,24 +166,7 @@ run_groups:

      - efficiency
      - general_information
    environment:
      main_name: exact_match # non-CoT

Review comment: Don't delete the rest of the environment and taxonomy.

      main_split: test
    taxonomy:
      task: "?"
      what: "?"
      who: "?"
      when: "?"
      language: English

  - name: ifeval

Review comment: Don't delete IFEval.

    display_name: IFEval
    description: IFEval
    metric_groups:
      - accuracy
      - efficiency
      - general_information
    environment:
      main_name: ifeval_strict_accuracy
      main_name: chain_of_thought_correct # non-CoT
      main_split: test
    taxonomy:
      task: "?"

Review comment: Don't need this outer if. The rest of the adapter spec should be exactly the same for few-shot and zero-shot.
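
Applied to the MMLU-Pro adapter code above, that suggestion would look roughly like this (a sketch, not the PR's code), reusing the zero-shot AdapterSpec settings verbatim and letting max_train_instance_num alone control few-shot versus zero-shot:

# Sketch of the reviewer's suggestion (not in the PR): drop the outer
# `if use_few_shot_bool` and keep a single CoT / non-CoT pair of specs.
if use_chain_of_thought_bool:
    adapter_spec = AdapterSpec(
        method=ADAPT_MULTIPLE_CHOICE_JOINT_CHAIN_OF_THOUGHT,
        max_train_instances=max_train_instance_num,  # 5 if few-shot, 0 if zero-shot
        max_tokens=1000,
        input_prefix="What is the correct answer to this question: ",
        input_suffix="\nChoices:\n",
        output_prefix="",
        reference_prefix="(A) ",
        global_suffix=(
            "Let's think step by step. Based on your reasoning, what is the single, "
            "most likely answer choice? Format your response as follows: "
            '"The correct answer is (insert answer here)".'
        ),
    )
else:
    adapter_spec = AdapterSpec(
        method=ADAPT_MULTIPLE_CHOICE_JOINT,
        max_train_instances=max_train_instance_num,
        max_tokens=1000,
        input_prefix="What is the correct answer to this question: ",
        input_suffix="\nChoices:\n",
        output_prefix="",
        reference_prefix="(A) ",
        global_suffix=('Format your response as follows: "The correct answer is (insert answer here)".'),
    )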