diff --git a/src/helm/benchmark/metrics/chain_of_thought_metric.py b/src/helm/benchmark/metrics/chain_of_thought_metric.py
new file mode 100644
index 0000000000..32cfd880f3
--- /dev/null
+++ b/src/helm/benchmark/metrics/chain_of_thought_metric.py
@@ -0,0 +1,93 @@
+import re
+from typing import List, Optional
+
+from helm.benchmark.adaptation.adapter_spec import AdapterSpec
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.metrics.metric import Metric
+from helm.benchmark.metrics.metric_name import MetricName
+from helm.benchmark.metrics.metric_service import MetricService
+from helm.benchmark.metrics.statistic import Stat
+
+
+def extract_answer(output_text: str) -> Optional[str]:
+    """
+    Extracts the answer from the output text using two exact regex patterns.
+    Returns None if no valid answer is found.
+
+    Args:
+        output_text (str): The text from which to extract the answer.
+
+    Returns:
+        Optional[str]: The extracted answer (A-J) if found, otherwise None.
+    """
+    # First regex: Matches "answer is (A-J)" with optional parentheses
+    match = re.search(r"answer is \(?([A-J])\)?", output_text)
+    if match:
+        return match.group(1)
+
+    # Second regex: Matches "answer: (A-J)" or "Answer: (A-J)" with optional leading characters like "."
+    match = re.search(r"\.*[aA]nswer:\s*\(?([A-J])\)?", output_text)
+    if match:
+        return match.group(1)
+
+    # If neither regex matches, return None
+    return None
+
+
+class ChainOfThoughtMetric(Metric):
+    """
+    This metric focuses on structured reasoning and the accuracy of extracted answers.
+    It compares model outputs against correct answers provided in a multiple-choice
+    format and returns a score indicating the correctness of the generated response.
+    """
+
+    def evaluate_generation(
+        self,
+        adapter_spec: AdapterSpec,
+        request_state: RequestState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+    ) -> List[Stat]:
+        """
+        Evaluate the generated output for chain-of-thought reasoning accuracy.
+
+        The method extracts the model's output, determines the correct answer
+        from the provided references, and compares the two to compute a binary score.
+
+        Args:
+            adapter_spec (AdapterSpec): Specification of the adapter used for the evaluation.
+            request_state (RequestState): The state of the current request, including
+                the input instance, output results, and references.
+            metric_service (MetricService): A service used to compute metrics if needed.
+            eval_cache_path (str): Path to the evaluation cache for storing or retrieving data.
+
+        Returns:
+            List[Stat]: A list containing a single `Stat` object with the correctness
+                score (1 for correct, 0 for incorrect) under the metric
+                name "chain_of_thought_correct".
+        """
+        # Assert that completions exist if the result is not None
+        assert (
+            request_state.result is not None and request_state.result.completions
+        ), "Request state result must have completions."
+
+        # Set output_text if the assertion passes
+        output_text = request_state.result.completions[0].text
+
+        # Extract the answer using the updated logic
+        extracted_answer = extract_answer(output_text)
+
+        # Find the correct answer from references by translating index to letter
+        correct_answer = None
+        for index, option in enumerate(request_state.instance.references):
+            if option.is_correct:
+                correct_answer = chr(65 + index)  # Translate index (0 -> A, 1 -> B, etc.)
+            break
+
+        # Raise an exception if no correct answer is found
+        if correct_answer is None:
+            raise ValueError(f"No correct answer found for instance ID {request_state.instance.id}")
+
+        # Compare extracted answer with the correct answer and compute the score
+        score = 1 if extracted_answer == correct_answer else 0
+        return [Stat(MetricName("chain_of_thought_correct")).add(score)]
diff --git a/src/helm/benchmark/run_specs/lite_run_specs.py b/src/helm/benchmark/run_specs/lite_run_specs.py
index 38fbd4ecaa..e7c5ea8a83 100644
--- a/src/helm/benchmark/run_specs/lite_run_specs.py
+++ b/src/helm/benchmark/run_specs/lite_run_specs.py
@@ -21,11 +21,11 @@
     get_generative_harms_metric_specs,
     get_generic_metric_specs,
     get_open_ended_generation_metric_specs,
-    MetricSpec,
 )
 from helm.benchmark.run_spec import RunSpec, run_spec_function
 from helm.benchmark.runner import get_benchmark_output_path
 from helm.benchmark.scenarios.scenario import ScenarioSpec, get_scenario_cache_path
+from helm.benchmark.metrics.metric import MetricSpec
 
 
 @run_spec_function("narrative_qa")
@@ -414,7 +414,10 @@ def get_gpqa_spec(subset: str, use_chain_of_thought: str = "False", use_few_shot
         name=f"gpqa:subset={subset},use_chain_of_thought={use_chain_of_thought_bool}",
         scenario_spec=scenario_spec,
         adapter_spec=adapter_spec,
-        metric_specs=get_exact_match_metric_specs(),  # TODO: update this after cot metric is ready
+        metric_specs=get_exact_match_metric_specs()
+        + [
+            MetricSpec(class_name="helm.benchmark.metrics.chain_of_thought_metric.ChainOfThoughtMetric", args={}),
+        ],
         groups=["gpqa"],
     )
 
diff --git a/src/helm/benchmark/scenarios/mmlu_pro.py b/src/helm/benchmark/scenarios/mmlu_pro.py
index a091387dc2..5d08d4f9d1 100644
--- a/src/helm/benchmark/scenarios/mmlu_pro.py
+++ b/src/helm/benchmark/scenarios/mmlu_pro.py
@@ -1,8 +1,17 @@
 from typing import Dict, List
 
-from datasets import load_dataset
+from datasets import Dataset, load_dataset
 from helm.common.hierarchical_logger import hlog
-from .scenario import Scenario, Instance, Reference, TRAIN_SPLIT, TEST_SPLIT, CORRECT_TAG, Input, Output
+from helm.benchmark.scenarios.scenario import (
+    Scenario,
+    Instance,
+    Reference,
+    TRAIN_SPLIT,
+    TEST_SPLIT,
+    CORRECT_TAG,
+    Input,
+    Output,
+)
 
 
 class MMLUProScenario(Scenario):
@@ -33,7 +42,14 @@ def __init__(self, subject: str):
         super().__init__()
         self.subject: str = subject
 
-    def process_csv(self, data, split: str) -> List[Instance]:
+    def process_dataset(self, data: Dataset, split: str) -> List[Instance]:
+        """
+        Process the dataset to create instances.
+
+        :param data: Hugging Face `Dataset` containing the data for a specific split.
+        :param split: The data split (e.g., "train", "test").
+        :return: A list of processed `Instance` objects.
+        """
         instances: List[Instance] = []
         hlog(f"Processing data for {split} split")
         for row in data:
@@ -55,8 +71,14 @@ def answer_to_reference(answer: str) -> Reference:
         return instances
 
     def get_instances(self, output_path: str) -> List[Instance]:
+        """
+        Load and process the MMLU-Pro dataset to create instances.
+
+        :param output_path: Path to save or output the processed instances.
+        :return: A list of all processed `Instance` objects.
+ """ # Load the MMLU-Pro dataset from Hugging Face - dataset = load_dataset("TIGER-Lab/MMLU-Pro") + dataset = load_dataset("TIGER-Lab/MMLU-Pro", revision="3373e0b") # Process all the instances instances: List[Instance] = [] @@ -66,6 +88,6 @@ def get_instances(self, output_path: str) -> List[Instance]: } for hf_split, split in splits.items(): data = dataset[hf_split].filter(lambda x: x["category"] == self.subject) - instances.extend(self.process_csv(data, split)) + instances.extend(self.process_dataset(data, split)) return instances diff --git a/src/helm/benchmark/static/schema_lite_v2.yaml b/src/helm/benchmark/static/schema_lite_v2.yaml index 17d6fd5834..b00b87e76f 100644 --- a/src/helm/benchmark/static/schema_lite_v2.yaml +++ b/src/helm/benchmark/static/schema_lite_v2.yaml @@ -93,6 +93,11 @@ metrics: short_display_name: IFEval Strict Acc description: Fraction of instructions in the instance that are correctly followed. lower_is_better: false + - name: chain_of_thought_correct + display_name: COT correct + short_display_name: COT correct + description: TBD. + lower_is_better: false ############################################################ perturbations: [] @@ -135,7 +140,6 @@ run_groups: subgroups: - mmlu_pro - gpqa - - ifeval - name: mmlu_pro display_name: MMLU-Pro @@ -162,24 +166,7 @@ run_groups: - efficiency - general_information environment: - main_name: exact_match # non-CoT - main_split: test - taxonomy: - task: "?" - what: "?" - who: "?" - when: "?" - language: English - - - name: ifeval - display_name: IFEval - description: IFEval - metric_groups: - - accuracy - - efficiency - - general_information - environment: - main_name: ifeval_strict_accuracy + main_name: chain_of_thought_correct # non-CoT main_split: test taxonomy: task: "?"