From 9561ec9edfe5cbe49a2ad2ea6064e87428ef65f0 Mon Sep 17 00:00:00 2001 From: Yifan Mai Date: Thu, 19 Dec 2024 17:29:25 -0800 Subject: [PATCH] Update CzechBankQA (#3227) --- .../benchmark/annotation/czech_bank_qa_annotator.py | 13 ++++++++++++- .../benchmark/run_specs/experimental_run_specs.py | 5 +++-- .../benchmark/scenarios/czech_bank_qa_scenario.py | 8 +++++++- .../scenarios/test_czech_bank_qa_scenario.py | 2 +- 4 files changed, 23 insertions(+), 5 deletions(-) diff --git a/src/helm/benchmark/annotation/czech_bank_qa_annotator.py b/src/helm/benchmark/annotation/czech_bank_qa_annotator.py index bc7052fb1b..67efe11fae 100644 --- a/src/helm/benchmark/annotation/czech_bank_qa_annotator.py +++ b/src/helm/benchmark/annotation/czech_bank_qa_annotator.py @@ -1,9 +1,11 @@ +import os import sqlite3 import threading from typing import Any, Optional, Tuple from helm.benchmark.adaptation.request_state import RequestState from helm.benchmark.annotation.annotator import Annotator +from helm.common.general import ensure_directory_exists, ensure_file_downloaded class CzechBankQAAnnotator(Annotator): @@ -13,8 +15,18 @@ class CzechBankQAAnnotator(Annotator): name = "czech_bank_qa" + DATABASE_SOURCE_URL = ( + "https://huggingface.co/datasets/yifanmai/czech_bank_qa/resolve/main/czech_bank.db?download=true" + ) + def __init__(self, file_storage_path: str): super().__init__() + + cache_dir = os.path.join(file_storage_path, "data") + ensure_directory_exists(cache_dir) + file_name = "czech_bank.db" + file_path = os.path.join(cache_dir, file_name) + ensure_file_downloaded(source_url=CzechBankQAAnnotator.DATABASE_SOURCE_URL, target_path=file_path) database = sqlite3.connect("/home/yifanmai/oss/helm/czech_bank/czech_bank.db") # csv_files_dir = "/home/yifanmai/oss/helm-scenarios/1999-czech-bank" @@ -38,7 +50,6 @@ def __init__(self, file_storage_path: str): self.database = database self.lock = threading.Lock() - print("done iwth init") def get_result(self, query: str) -> Tuple[Optional[str], Optional[str]]: result: Optional[str] = None diff --git a/src/helm/benchmark/run_specs/experimental_run_specs.py b/src/helm/benchmark/run_specs/experimental_run_specs.py index 0f142981c2..c242b9db48 100644 --- a/src/helm/benchmark/run_specs/experimental_run_specs.py +++ b/src/helm/benchmark/run_specs/experimental_run_specs.py @@ -167,11 +167,12 @@ def get_autobencher_safety_spec() -> RunSpec: @run_spec_function("czech_bank_qa") -def get_czech_bank_qa_spec() -> RunSpec: +def get_czech_bank_qa_spec(config_name: str = "berka_queries_1024_2024_12_18") -> RunSpec: from helm.benchmark.scenarios.czech_bank_qa_scenario import CzechBankQAScenario scenario_spec = ScenarioSpec( - class_name="helm.benchmark.scenarios.czech_bank_qa_scenario.CzechBankQAScenario", args={} + class_name="helm.benchmark.scenarios.czech_bank_qa_scenario.CzechBankQAScenario", + args={"config_name": config_name}, ) adapter_spec = get_generation_adapter_spec( diff --git a/src/helm/benchmark/scenarios/czech_bank_qa_scenario.py b/src/helm/benchmark/scenarios/czech_bank_qa_scenario.py index 05338dd951..f54cb83505 100644 --- a/src/helm/benchmark/scenarios/czech_bank_qa_scenario.py +++ b/src/helm/benchmark/scenarios/czech_bank_qa_scenario.py @@ -111,10 +111,16 @@ class CzechBankQAScenario(Scenario): description = "This is a list of SQL queries for a text-to-SQL task over the Czech Bank 1999 dataset." tags = ["text_to_sql"] + def __init__(self, config_name: str): + super().__init__() + self.config_name = config_name + def get_instances(self, output_path: str) -> List[Instance]: cache_dir = os.path.join(output_path, "data") ensure_directory_exists(cache_dir) - dataset = datasets.load_dataset("yifanmai/czech_bank_qa", split="test", cache_dir=cache_dir) + dataset = datasets.load_dataset( + "yifanmai/czech_bank_qa", name=self.config_name, split="test", cache_dir=cache_dir + ) instances: List[Instance] = [] for row in dataset: input = Input(text=row["description"]) diff --git a/src/helm/benchmark/scenarios/test_czech_bank_qa_scenario.py b/src/helm/benchmark/scenarios/test_czech_bank_qa_scenario.py index 211677b494..7215debff5 100644 --- a/src/helm/benchmark/scenarios/test_czech_bank_qa_scenario.py +++ b/src/helm/benchmark/scenarios/test_czech_bank_qa_scenario.py @@ -7,7 +7,7 @@ @pytest.mark.scenarios def test_czech_bank_qa_scenario_get_instances(): - scenario = CzechBankQAScenario() + scenario = CzechBankQAScenario(config_name="default") with TemporaryDirectory() as tmpdir: actual_instances = scenario.get_instances(tmpdir) assert len(actual_instances) == 30