Skip to content

Commit

Permalink
Update CzechBankQA (#3227)
Browse files Browse the repository at this point in the history
  • Loading branch information
yifanmai authored Dec 20, 2024
1 parent a4f8b9e commit 9561ec9
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 5 deletions.
13 changes: 12 additions & 1 deletion src/helm/benchmark/annotation/czech_bank_qa_annotator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import sqlite3
import threading
from typing import Any, Optional, Tuple

from helm.benchmark.adaptation.request_state import RequestState
from helm.benchmark.annotation.annotator import Annotator
from helm.common.general import ensure_directory_exists, ensure_file_downloaded


class CzechBankQAAnnotator(Annotator):
Expand All @@ -13,8 +15,18 @@ class CzechBankQAAnnotator(Annotator):

name = "czech_bank_qa"

DATABASE_SOURCE_URL = (
"https://huggingface.co/datasets/yifanmai/czech_bank_qa/resolve/main/czech_bank.db?download=true"
)

def __init__(self, file_storage_path: str):
super().__init__()

cache_dir = os.path.join(file_storage_path, "data")
ensure_directory_exists(cache_dir)
file_name = "czech_bank.db"
file_path = os.path.join(cache_dir, file_name)
ensure_file_downloaded(source_url=CzechBankQAAnnotator.DATABASE_SOURCE_URL, target_path=file_path)
database = sqlite3.connect("/home/yifanmai/oss/helm/czech_bank/czech_bank.db")

# csv_files_dir = "/home/yifanmai/oss/helm-scenarios/1999-czech-bank"
Expand All @@ -38,7 +50,6 @@ def __init__(self, file_storage_path: str):

self.database = database
self.lock = threading.Lock()
print("done iwth init")

def get_result(self, query: str) -> Tuple[Optional[str], Optional[str]]:
result: Optional[str] = None
Expand Down
5 changes: 3 additions & 2 deletions src/helm/benchmark/run_specs/experimental_run_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,12 @@ def get_autobencher_safety_spec() -> RunSpec:


@run_spec_function("czech_bank_qa")
def get_czech_bank_qa_spec() -> RunSpec:
def get_czech_bank_qa_spec(config_name: str = "berka_queries_1024_2024_12_18") -> RunSpec:
from helm.benchmark.scenarios.czech_bank_qa_scenario import CzechBankQAScenario

scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.czech_bank_qa_scenario.CzechBankQAScenario", args={}
class_name="helm.benchmark.scenarios.czech_bank_qa_scenario.CzechBankQAScenario",
args={"config_name": config_name},
)

adapter_spec = get_generation_adapter_spec(
Expand Down
8 changes: 7 additions & 1 deletion src/helm/benchmark/scenarios/czech_bank_qa_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,16 @@ class CzechBankQAScenario(Scenario):
description = "This is a list of SQL queries for a text-to-SQL task over the Czech Bank 1999 dataset."
tags = ["text_to_sql"]

# Constructor: remembers which dataset configuration this scenario should load.
# `config_name` is stored on the instance and later passed as `name=` to
# datasets.load_dataset inside get_instances.
def __init__(self, config_name: str):
super().__init__()
self.config_name = config_name

def get_instances(self, output_path: str) -> List[Instance]:
cache_dir = os.path.join(output_path, "data")
ensure_directory_exists(cache_dir)
dataset = datasets.load_dataset("yifanmai/czech_bank_qa", split="test", cache_dir=cache_dir)
dataset = datasets.load_dataset(
"yifanmai/czech_bank_qa", name=self.config_name, split="test", cache_dir=cache_dir
)
instances: List[Instance] = []
for row in dataset:
input = Input(text=row["description"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

@pytest.mark.scenarios
def test_czech_bank_qa_scenario_get_instances():
scenario = CzechBankQAScenario()
scenario = CzechBankQAScenario(config_name="default")
with TemporaryDirectory() as tmpdir:
actual_instances = scenario.get_instances(tmpdir)
assert len(actual_instances) == 30
Expand Down

0 comments on commit 9561ec9

Please sign in to comment.