Added COT Metric and Adapter to MMLU Pro #3162

Closed
wants to merge 24 commits
Changes from 21 commits
93 changes: 93 additions & 0 deletions src/helm/benchmark/metrics/chain_of_thought_metric.py
@@ -0,0 +1,93 @@
import re
from typing import List, Optional

from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.metrics.metric import Metric
from helm.benchmark.metrics.metric_name import MetricName
from helm.benchmark.metrics.metric_service import MetricService
from helm.benchmark.metrics.statistic import Stat


def extract_answer(output_text: str) -> Optional[str]:
"""
Extracts the answer from the output text using two exact regex patterns.
Returns None if no valid answer is found.

Args:
output_text (str): The text from which to extract the answer.

Returns:
Optional[str]: The extracted answer (A-J) if found, otherwise None.
"""
# First regex: Matches "answer is (A-J)" with optional parentheses
match = re.search(r"answer is \(?([A-J])\)?", output_text)
if match:
return match.group(1)

# Second regex: Matches "[answer: (A-J)]" with optional leading characters like "."
match = re.search(r"\.*\[aA\]nswer:\s*\(?([A-J])\)?", output_text)
if match:
return match.group(1)

# If neither regex matches, return None
return None


class ChainOfThoughtMetric(Metric):
"""
This metric focuses on structured reasoning and the accuracy of extracted answers.
It compares model outputs against correct answers provided in a multiple-choice
format and returns a score indicating the correctness of the generated response.
"""

def evaluate_generation(
self,
adapter_spec: AdapterSpec,
request_state: RequestState,
metric_service: MetricService,
eval_cache_path: str,
) -> List[Stat]:
"""
Evaluate the generated output for chain-of-thought reasoning accuracy.

The method extracts the model's output, determines the correct answer
from the provided references, and compares the two to compute a binary score.

Args:
adapter_spec (AdapterSpec): Specification of the adapter used for the evaluation.
request_state (RequestState): The state of the current request, including
the input instance, output results, and references.
metric_service (MetricService): A service used to compute metrics if needed.
eval_cache_path (str): Path to the evaluation cache for storing or retrieving data.

Returns:
List[Stat]: A list containing a single `Stat` object with the correctness
score (1 for correct, 0 for incorrect) under the metric
name "chain_of_thought_correct".
"""
# Assert that completions exist if the result is not None
assert (
request_state.result is not None and request_state.result.completions
), "Request state result must have completions."

# Set output_text if the assertion passes
output_text = request_state.result.completions[0].text

# Extract the answer using the updated logic
extracted_answer = extract_answer(output_text)

# Find the correct answer from references by translating index to letter
correct_answer = None
for index, option in enumerate(request_state.instance.references):
if option.is_correct:
correct_answer = chr(65 + index) # Translate index (0 -> A, 1 -> B, etc.)
break

# Raise an exception if no correct answer is found
if correct_answer is None:
raise ValueError(f"No correct answer found for instance ID {request_state.instance.id}")

# Compare extracted answer with the correct answer and compute the score
score = 1 if extracted_answer == correct_answer else 0
return [Stat(MetricName("chain_of_thought_correct")).add(score)]
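
For orientation, a minimal usage sketch (not part of the diff) of how extract_answer behaves; the example strings are invented model outputs:

    from helm.benchmark.metrics.chain_of_thought_metric import extract_answer

    # Matches the first pattern, "answer is (A-J)", with or without parentheses.
    assert extract_answer("Let's think step by step... The answer is (C).") == "C"
    assert extract_answer("So the answer is B, because both premises hold.") == "B"

    # Returns None when neither pattern matches; the metric then scores the output as incorrect.
    assert extract_answer("I am not sure which option is right.") is None
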
94 changes: 83 additions & 11 deletions src/helm/benchmark/run_specs/lite_run_specs.py
@@ -21,11 +21,11 @@
get_generative_harms_metric_specs,
get_generic_metric_specs,
get_open_ended_generation_metric_specs,
MetricSpec,
)
from helm.benchmark.run_spec import RunSpec, run_spec_function
from helm.benchmark.runner import get_benchmark_output_path
from helm.benchmark.scenarios.scenario import ScenarioSpec, get_scenario_cache_path
from helm.benchmark.metrics.metric import MetricSpec


@run_spec_function("narrative_qa")
@@ -137,23 +137,92 @@ def get_mmlu_spec(subject: str, method: str = ADAPT_MULTIPLE_CHOICE_JOINT) -> Ru


@run_spec_function("mmlu_pro")
def get_mmlu_pro_spec(subject: str) -> RunSpec:
def get_mmlu_pro_spec(subject: str, use_chain_of_thought: str = "False", use_few_shot: str = "False") -> RunSpec:
# Convert to bools and remove the str versions
use_chain_of_thought_bool: bool = use_chain_of_thought == "True"
use_few_shot_bool: bool = use_few_shot == "True"
del use_chain_of_thought
del use_few_shot

scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.mmlu_pro.MMLUProScenario", args={"subject": subject}
)
max_train_instance_num = 5 if use_few_shot_bool else 0

adapter_spec = get_multiple_choice_adapter_spec(
method=ADAPT_MULTIPLE_CHOICE_JOINT,
instructions=f"The following are multiple choice questions (with answers) about {subject.replace('_', ' ')}.",
input_noun="Question",
output_noun="Answer",
)
if use_few_shot_bool:
Collaborator: Don't need this outer if. The rest of the adapter spec should be exactly the same for few shot and zero shot.

if use_chain_of_thought_bool:
adapter_spec = get_multiple_choice_adapter_spec(
method=ADAPT_MULTIPLE_CHOICE_JOINT_CHAIN_OF_THOUGHT,
max_tokens=1000, # following original repo
max_train_instances=max_train_instance_num,
instructions=(
Collaborator: Use the instructions from the paper.

"Here are some example questions from experts. "
"An explanation is given before the final answer. "
"Answer the final question yourself, giving your reasoning beforehand."
),
input_noun="Question",
input_suffix="\nChoices: \n",
reference_prefix="(A) ",
Collaborator: Delete reference_prefix (it defaults to "A. " when unspecified, which follows the paper).

chain_of_thought_prefix="Let's think step by step: ",
chain_of_thought_suffix="The correct answer is ",
Collaborator: I think this results in adding the answer twice to the prompt, e.g. "The answer is (A). The correct answer is A". We need to deal with this somehow, probably in the adapter. I'm okay with deferring this fix to another pull request.

output_noun="", # will be overwritten with output_prefix
output_prefix="",
global_suffix=(
Collaborator: Follow the paper - they don't use this suffix.

Contributor Author: Delete global_suffix

"Give step by step reasoning before you answer, and when you’re ready to answer, "
'please use the format "The correct answer is (insert answer here)":'
),
)
else:
adapter_spec = get_multiple_choice_adapter_spec(
method=ADAPT_MULTIPLE_CHOICE_JOINT,
max_train_instances=max_train_instance_num,
instructions=(
"Here are some example questions from experts. "
"An explanation is given before the final answer. "
"Answer the final question yourself, giving your reasoning beforehand."
),
input_noun="Question",
input_suffix="\nChoices: \n",
reference_prefix="(A) ",
output_noun="", # will be overwritten with output_prefix
output_prefix="The correct answer is ",
)
else:
if use_chain_of_thought_bool:
adapter_spec = AdapterSpec(
method=ADAPT_MULTIPLE_CHOICE_JOINT_CHAIN_OF_THOUGHT,
max_train_instances=max_train_instance_num,
max_tokens=1000,
input_prefix="What is the correct answer to this question: ",
input_suffix="\nChoices:\n",
output_prefix="",
reference_prefix="(A) ",
global_suffix=(
"Let’s think step by step. Based on your reasoning, what is the single, "
"most likely answer choice? Format your response as follows: "
'"The correct answer is (insert answer here)".'
),
)
else:
adapter_spec = AdapterSpec(
method=ADAPT_MULTIPLE_CHOICE_JOINT,
max_train_instances=max_train_instance_num,
max_tokens=1000,
input_prefix="What is the correct answer to this question: ",
input_suffix="\nChoices:\n",
output_prefix="",
reference_prefix="(A) ",
global_suffix=("Format your response as follows: " '"The correct answer is (insert answer here)".'),
)

return RunSpec(
name=f"mmlu_pro:subject={subject}",
name=f"gpqa:subset={subject},use_chain_of_thought={use_chain_of_thought_bool}",
Collaborator: change "gpqa" to "mmlu_pro"

scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=get_exact_match_metric_specs(),
metric_specs=get_exact_match_metric_specs()
+ [
MetricSpec(class_name="helm.benchmark.metrics.chain_of_thought_metric.ChainOfThoughtMetric", args={}),
Collaborator: Only add this metric if chain of thought is used.

Contributor Author: Address this in GPQA as well

],
groups=["mmlu_pro"],
)
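
One possible way to fold in both review requests above (renaming the spec to mmlu_pro and only attaching the chain-of-thought metric when it applies) is sketched below; this is an illustration against the names already used in this function, not the code in this pull request:

    metric_specs = get_exact_match_metric_specs()
    if use_chain_of_thought_bool:
        # Attach the chain-of-thought metric only for chain-of-thought runs.
        metric_specs = metric_specs + [
            MetricSpec(
                class_name="helm.benchmark.metrics.chain_of_thought_metric.ChainOfThoughtMetric",
                args={},
            ),
        ]

    return RunSpec(
        name=f"mmlu_pro:subject={subject},use_chain_of_thought={use_chain_of_thought_bool}",
        scenario_spec=scenario_spec,
        adapter_spec=adapter_spec,
        metric_specs=metric_specs,
        groups=["mmlu_pro"],
    )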

@@ -414,7 +483,10 @@ def get_gpqa_spec(subset: str, use_chain_of_thought: str = "False", use_few_shot
name=f"gpqa:subset={subset},use_chain_of_thought={use_chain_of_thought_bool}",
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=get_exact_match_metric_specs(), # TODO: update this after cot metric is ready
metric_specs=get_exact_match_metric_specs()
+ [
MetricSpec(class_name="helm.benchmark.metrics.chain_of_thought_metric.ChainOfThoughtMetric", args={}),
Collaborator: Only add this metric if chain of thought is used.

],
groups=["gpqa"],
)

32 changes: 27 additions & 5 deletions src/helm/benchmark/scenarios/mmlu_pro.py
Collaborator: Somehow didn't catch this before, but please rename this file to mmlu_pro_scenario.py to match the convention.

@@ -1,8 +1,17 @@
from typing import Dict, List
from datasets import load_dataset
from datasets import Dataset, load_dataset

from helm.common.hierarchical_logger import hlog
from .scenario import Scenario, Instance, Reference, TRAIN_SPLIT, TEST_SPLIT, CORRECT_TAG, Input, Output
from helm.benchmark.scenarios.scenario import (
Scenario,
Instance,
Reference,
TRAIN_SPLIT,
TEST_SPLIT,
CORRECT_TAG,
Input,
Output,
)


class MMLUProScenario(Scenario):
@@ -33,7 +42,14 @@ def __init__(self, subject: str):
super().__init__()
self.subject: str = subject

def process_csv(self, data, split: str) -> List[Instance]:
def process_dataset(self, data: Dataset, split: str) -> List[Instance]:
"""
Process the dataset to create instances.

:param data: Hugging Face `Dataset` containing the data for a specific split.
:param split: The data split (e.g., "train", "test").
:return: A list of processed `Instance` objects.
"""
instances: List[Instance] = []
hlog(f"Processing data for {split} split")
for row in data:
@@ -55,8 +71,14 @@ def answer_to_reference(answer: str) -> Reference:
return instances

def get_instances(self, output_path: str) -> List[Instance]:
"""
Load and process the MMLU-Pro dataset to create instances.

:param output_path: Path to save or output the processed instances.
:return: A list of all processed `Instance` objects.
"""
# Load the MMLU-Pro dataset from Hugging Face
dataset = load_dataset("TIGER-Lab/MMLU-Pro")
dataset = load_dataset("TIGER-Lab/MMLU-Pro", revision="3373e0b")

# Process all the instances
instances: List[Instance] = []
@@ -66,6 +88,6 @@ def get_instances(self, output_path: str) -> List[Instance]:
}
for hf_split, split in splits.items():
data = dataset[hf_split].filter(lambda x: x["category"] == self.subject)
instances.extend(self.process_csv(data, split))
instances.extend(self.process_dataset(data, split))

return instances
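
For context, a small standalone sketch (not part of the diff) of the data this scenario consumes, pinned to the same dataset revision; the subject "law" and the field names other than "category" are assumptions taken from the Hugging Face dataset card:

    from datasets import load_dataset

    # Load the same pinned revision that get_instances uses.
    dataset = load_dataset("TIGER-Lab/MMLU-Pro", revision="3373e0b")

    # get_instances filters each split down to a single category (subject).
    law_test = dataset["test"].filter(lambda x: x["category"] == "law")

    # Inspect one filtered row; the scenario maps these fields to an Instance
    # and its References.
    row = law_test[0]
    print(row["question"], row["options"], row["answer"])
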
25 changes: 6 additions & 19 deletions src/helm/benchmark/static/schema_lite_v2.yaml
@@ -93,6 +93,11 @@ metrics:
short_display_name: IFEval Strict Acc
description: Fraction of instructions in the instance that are correctly followed.
lower_is_better: false
- name: chain_of_thought_correct
display_name: COT correct
Collaborator: "Chain of thought correctness" or something more descriptive like that.

short_display_name: COT correct
description: TBD.
Collaborator: Add description.

lower_is_better: false

############################################################
perturbations: []
@@ -135,7 +140,6 @@
subgroups:
- mmlu_pro
- gpqa
- ifeval
Collaborator: Don't delete IFEval.


- name: mmlu_pro
display_name: MMLU-Pro
@@ -162,24 +166,7 @@
- efficiency
- general_information
environment:
main_name: exact_match # non-CoT
Collaborator: Don't delete the rest of the environment and taxonomy.

main_split: test
taxonomy:
task: "?"
what: "?"
who: "?"
when: "?"
language: English

- name: ifeval
Collaborator: Don't delete IFEval.

display_name: IFEval
description: IFEval
metric_groups:
- accuracy
- efficiency
- general_information
environment:
main_name: ifeval_strict_accuracy
main_name: chain_of_thought_correct # non-CoT
main_split: test
taxonomy:
task: "?"