Addressed Comments on GPQA Metric and MMLU Pro Non-COT Repo #3161

Closed
wants to merge 19 commits
93 changes: 93 additions & 0 deletions src/helm/benchmark/metrics/chain_of_thought_metric.py
@@ -0,0 +1,93 @@
import re
from typing import List, Optional

from helm.benchmark.adaptation.adapter_spec import AdapterSpec
from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.metrics.metric import Metric
from helm.benchmark.metrics.metric_name import MetricName
from helm.benchmark.metrics.metric_service import MetricService
from helm.benchmark.metrics.statistic import Stat


def extract_answer(output_text: str) -> Optional[str]:
    """
    Extracts the answer from the output text using two regex patterns.
    Returns None if no valid answer is found.

    Args:
        output_text (str): The text from which to extract the answer.

    Returns:
        Optional[str]: The extracted answer (A-J) if found, otherwise None.
    """
    # First regex: Matches "answer is (A-J)" with optional parentheses
    match = re.search(r"answer is \(?([A-J])\)?", output_text)
    if match:
        return match.group(1)

    # Second regex: Matches "answer: (A-J)" or "Answer: (A-J)", with optional
    # leading periods and optional parentheses around the letter
    match = re.search(r"\.*[aA]nswer:\s*\(?([A-J])\)?", output_text)
    if match:
        return match.group(1)

    # If neither regex matches, return None
    return None


class ChainOfThoughtMetric(Metric):
    """
    This metric focuses on structured reasoning and the accuracy of extracted answers.
    It compares model outputs against correct answers provided in a multiple-choice
    format and returns a score indicating the correctness of the generated response.
    """

    def evaluate_generation(
        self,
        adapter_spec: AdapterSpec,
        request_state: RequestState,
        metric_service: MetricService,
        eval_cache_path: str,
    ) -> List[Stat]:
        """
        Evaluate the generated output for chain-of-thought reasoning accuracy.

        The method extracts the model's output, determines the correct answer
        from the provided references, and compares the two to compute a binary score.

        Args:
            adapter_spec (AdapterSpec): Specification of the adapter used for the evaluation.
            request_state (RequestState): The state of the current request, including
                the input instance, output results, and references.
            metric_service (MetricService): A service used to compute metrics if needed.
            eval_cache_path (str): Path to the evaluation cache for storing or retrieving data.

        Returns:
            List[Stat]: A list containing a single `Stat` object with the correctness
                score (1 for correct, 0 for incorrect) under the metric
                name "chain_of_thought_correct".
        """
        # The result must exist and contain at least one completion
        assert (
            request_state.result is not None and request_state.result.completions
        ), "Request state result must have completions."

        # Set output_text if the assertion passes
        output_text = request_state.result.completions[0].text

        # Extract the answer using the updated logic
        extracted_answer = extract_answer(output_text)

        # Find the correct answer from references by translating index to letter
        correct_answer = None
        for index, option in enumerate(request_state.instance.references):
            if option.is_correct:
                correct_answer = chr(65 + index)  # Translate index (0 -> A, 1 -> B, etc.)
                break

        # Raise an exception if no correct answer is found
        if correct_answer is None:
            raise ValueError(f"No correct answer found for instance ID {request_state.instance.id}")

        # Compare extracted answer with the correct answer and compute the score
        score = 1 if extracted_answer == correct_answer else 0
        return [Stat(MetricName("chain_of_thought_correct")).add(score)]
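For a quick sanity check of the extraction logic above, here is a minimal, self-contained sketch. It duplicates extract_answer with the corrected second regex; the sample outputs are invented for illustration and are not taken from GPQA or MMLU-Pro.

import re
from typing import Optional


def extract_answer(output_text: str) -> Optional[str]:
    # "answer is X", with optional parentheses around the letter
    match = re.search(r"answer is \(?([A-J])\)?", output_text)
    if match:
        return match.group(1)
    # "answer: X" or "Answer: X", with optional leading periods
    match = re.search(r"\.*[aA]nswer:\s*\(?([A-J])\)?", output_text)
    if match:
        return match.group(1)
    return None


# Illustrative model outputs (hypothetical):
assert extract_answer("Step by step, the answer is (B).") == "B"
assert extract_answer("Answer: C") == "C"
assert extract_answer("I cannot decide between the options.") is None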
7 changes: 5 additions & 2 deletions src/helm/benchmark/run_specs/lite_run_specs.py
@@ -21,11 +21,11 @@
get_generative_harms_metric_specs,
get_generic_metric_specs,
get_open_ended_generation_metric_specs,
MetricSpec,
)
from helm.benchmark.run_spec import RunSpec, run_spec_function
from helm.benchmark.runner import get_benchmark_output_path
from helm.benchmark.scenarios.scenario import ScenarioSpec, get_scenario_cache_path
from helm.benchmark.metrics.metric import MetricSpec


@run_spec_function("narrative_qa")
@@ -414,7 +414,10 @@ def get_gpqa_spec(subset: str, use_chain_of_thought: str = "False", use_few_shot
name=f"gpqa:subset={subset},use_chain_of_thought={use_chain_of_thought_bool}",
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=get_exact_match_metric_specs(), # TODO: update this after cot metric is ready
metric_specs=get_exact_match_metric_specs()
+ [
MetricSpec(class_name="helm.benchmark.metrics.chain_of_thought_metric.ChainOfThoughtMetric", args={}),
],
groups=["gpqa"],
)

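The net effect of this hunk is that GPQA runs now report the exact-match metrics plus the new chain-of-thought metric. A minimal sketch of the combined metric_specs list follows; the module path for get_exact_match_metric_specs is assumed from the imports above the hunk and is not shown in the diff.

from helm.benchmark.metrics.common_metric_specs import get_exact_match_metric_specs  # assumed path
from helm.benchmark.metrics.metric import MetricSpec

# Exact-match specs plus the new CoT metric; both are computed for every GPQA run.
metric_specs = get_exact_match_metric_specs() + [
    MetricSpec(
        class_name="helm.benchmark.metrics.chain_of_thought_metric.ChainOfThoughtMetric",
        args={},
    ),
]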
32 changes: 27 additions & 5 deletions src/helm/benchmark/scenarios/mmlu_pro.py
@@ -1,8 +1,17 @@
from typing import Dict, List
from datasets import load_dataset
from datasets import Dataset, load_dataset

from helm.common.hierarchical_logger import hlog
from .scenario import Scenario, Instance, Reference, TRAIN_SPLIT, TEST_SPLIT, CORRECT_TAG, Input, Output
from helm.benchmark.scenarios.scenario import (
Scenario,
Instance,
Reference,
TRAIN_SPLIT,
TEST_SPLIT,
CORRECT_TAG,
Input,
Output,
)


class MMLUProScenario(Scenario):
@@ -33,7 +42,14 @@ def __init__(self, subject: str):
super().__init__()
self.subject: str = subject

def process_csv(self, data, split: str) -> List[Instance]:
def process_dataset(self, data: Dataset, split: str) -> List[Instance]:
"""
Process the dataset to create instances.

:param data: Hugging Face `Dataset` containing the data for a specific split.
:param split: The data split (e.g., "train", "test").
:return: A list of processed `Instance` objects.
"""
instances: List[Instance] = []
hlog(f"Processing data for {split} split")
for row in data:
@@ -55,8 +71,14 @@ def answer_to_reference(answer: str) -> Reference:
return instances

def get_instances(self, output_path: str) -> List[Instance]:
"""
Load and process the MMLU-Pro dataset to create instances.

:param output_path: Path to save or output the processed instances.
:return: A list of all processed `Instance` objects.
"""
# Load the MMLU-Pro dataset from Hugging Face
dataset = load_dataset("TIGER-Lab/MMLU-Pro")
dataset = load_dataset("TIGER-Lab/MMLU-Pro", revision="3373e0b")

# Process all the instances
instances: List[Instance] = []
@@ -66,6 +88,6 @@ def get_instances(self, output_path: str) -> List[Instance]:
}
for hf_split, split in splits.items():
data = dataset[hf_split].filter(lambda x: x["category"] == self.subject)
instances.extend(self.process_csv(data, split))
instances.extend(self.process_dataset(data, split))

return instances
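To see what the pinned dataset contains before it is converted into Instance objects, a standalone snippet along these lines can be used. Field names follow the TIGER-Lab/MMLU-Pro dataset card; the chosen subject and printed fields are illustrative only.

from datasets import load_dataset

# Load the same pinned revision as the scenario and keep a single subject,
# mirroring the filtering step in get_instances() above.
dataset = load_dataset("TIGER-Lab/MMLU-Pro", revision="3373e0b")
law_test = dataset["test"].filter(lambda x: x["category"] == "law")

row = law_test[0]
print(row["question"][:80])  # question text
print(row["options"])        # up to 10 answer options, mapped to letters A-J
print(row["answer"])         # correct option letter, e.g. "B"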
25 changes: 6 additions & 19 deletions src/helm/benchmark/static/schema_lite_v2.yaml
@@ -93,6 +93,11 @@ metrics:
short_display_name: IFEval Strict Acc
description: Fraction of instructions in the instance that are correctly followed.
lower_is_better: false
- name: chain_of_thought_correct
display_name: COT correct
short_display_name: COT correct
description: Fraction of instances where the answer extracted from the model's chain-of-thought output matches the correct reference.
lower_is_better: false

############################################################
perturbations: []
@@ -135,7 +140,6 @@ run_groups:
subgroups:
- mmlu_pro
- gpqa
- ifeval

- name: mmlu_pro
display_name: MMLU-Pro
@@ -162,24 +166,7 @@
- efficiency
- general_information
environment:
main_name: exact_match # non-CoT
main_split: test
taxonomy:
task: "?"
what: "?"
who: "?"
when: "?"
language: English

- name: ifeval
display_name: IFEval
description: IFEval
metric_groups:
- accuracy
- efficiency
- general_information
environment:
main_name: ifeval_strict_accuracy
main_name: chain_of_thought_correct # non-CoT
main_split: test
taxonomy:
task: "?"