Merge branch 'thallysonjsa/enem_challenge_scenario' of https://github.com/thallysonjsa/helm into thallysonjsa/enem_challenge_scenario
thallysonjsa committed Dec 6, 2024
2 parents 48bf8e8 + 9461ad3 commit 506b020
Showing 8 changed files with 258 additions and 13 deletions.
4 changes: 4 additions & 0 deletions src/helm/benchmark/presentation/run_entries_speech.conf
@@ -6,6 +6,10 @@ entries: [
{description: "vocal_sound:model=audiolm", priority: 1}
{description: "audiocaps:model=audiolm", priority: 1}
{description: "voxceleb2:model=audiolm", priority: 1}
{description: "air_bench_chat:subject=speech,model=audiolm", priority: 1}
{description: "air_bench_chat:subject=sound,model=audiolm", priority: 1}
{description: "air_bench_chat:subject=music,model=audiolm", priority: 1}
{description: "air_bench_chat:subject=mix,model=audiolm", priority: 1}

####################################################################################################################
# Fairness
20 changes: 20 additions & 0 deletions src/helm/benchmark/run_specs/audio_run_specs.py
@@ -373,3 +373,23 @@ def get_casual_conversations2_run_spec(subject: str) -> RunSpec:
        metric_specs=metric_specs,
        groups=[run_spec_name],
    )


@run_spec_function("air_bench_chat")
def get_air_bench_chat_run_spec(subject: str) -> RunSpec:
scenario_spec = ScenarioSpec(
class_name="helm.benchmark.scenarios.audio_language.air_bench_chat_scenario." "AirBenchChatScenario",
args={"subject": subject},
)
adapter_spec = _get_generation_adapter_spec(
max_tokens=50,
)
metric_specs: List[MetricSpec] = _get_open_ended_generation_metric_specs()
run_spec_name: str = "air_bench_chat"
return RunSpec(
name=f"{run_spec_name}:subject={subject}",
scenario_spec=scenario_spec,
adapter_spec=adapter_spec,
metric_specs=metric_specs,
groups=[run_spec_name],
)
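For context, each description string added to run_entries_speech.conf above resolves to this factory through the @run_spec_function registration. A minimal sketch of that mapping, using a simplified parser rather than HELM's actual resolver (model expansion and priorities are omitted):

# Illustrative only: how a run entry such as the ones added to
# run_entries_speech.conf decomposes into a run-spec name and keyword
# arguments. HELM's real resolver also expands `model` into a deployment.
def parse_run_entry(description: str):
    name, _, arg_str = description.partition(":")
    args = dict(pair.split("=", 1) for pair in arg_str.split(",") if pair)
    return name, args

name, args = parse_run_entry("air_bench_chat:subject=speech,model=audiolm")
assert name == "air_bench_chat"
assert args == {"subject": "speech", "model": "audiolm"}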
117 changes: 117 additions & 0 deletions src/helm/benchmark/scenarios/audio_language/air_bench_chat_scenario.py
@@ -0,0 +1,117 @@
import json
import os
from typing import List

from tqdm import tqdm

from helm.benchmark.scenarios.scenario import (
    Scenario,
    Instance,
    Reference,
    TEST_SPLIT,
    CORRECT_TAG,
    Input,
    Output,
)
from helm.common.media_object import MediaObject, MultimediaObject
from helm.common.general import ensure_file_downloaded


class AirBenchChatScenario(Scenario):
    """AIR-Bench Chat

    AIR-Bench (Audio InstRuction Benchmark) is a benchmark designed to evaluate the ability of audio language
    models to understand various types of audio signals (including human speech, natural sounds, and music)
    and, furthermore, to interact with humans in textual format. AIR-Bench encompasses two dimensions:
    foundation and chat benchmarks. The former consists of 19 tasks with approximately 19k single-choice
    questions; the latter contains 2k instances of open-ended question-and-answer data. We consider the chat
    benchmark in this scenario.

    Paper: https://aclanthology.org/2024.acl-long.109.pdf
    Code: https://github.com/OFA-Sys/AIR-Bench

    Citation:
    @inproceedings{yang-etal-2024-air,
        title = "{AIR}-Bench: Benchmarking Large Audio-Language Models via Generative Comprehension",
        author = "Yang, Qian and
            Xu, Jin and
            Liu, Wenrui and
            Chu, Yunfei and
            Jiang, Ziyue and
            Zhou, Xiaohuan and
            Leng, Yichong and
            Lv, Yuanjun and
            Zhao, Zhou and
            Zhou, Chang and
            Zhou, Jingren",
        booktitle = "Proceedings of the 62nd Annual Meeting of the Association for Computational
            Linguistics (Volume 1: Long Papers)",
        year = "2024",
    }
    """

    HF_DATA_PATH_PREFIX = "https://huggingface.co/datasets/qyang1021/AIR-Bench-Dataset/resolve/main/Chat"
    META_DATA_FILE_PATH = "https://huggingface.co/datasets/qyang1021/AIR-Bench-Dataset/resolve/main/Chat/Chat_meta.json"
    SUBJECTS = ["music", "sound", "speech", "mix"]

    name = "air_bench_chat"
    description = (
        "A large-scale dataset of about 46K audio clips paired with human-written text "
        "([Yang et al, 2024](https://aclanthology.org/2024.acl-long.109.pdf))."
    )
    tags: List[str] = ["audio", "reasoning"]

    def __init__(self, subject: str) -> None:
        super().__init__()

        if subject not in AirBenchChatScenario.SUBJECTS:
            raise ValueError(f"Invalid subject. Valid subjects are: {AirBenchChatScenario.SUBJECTS}")

        self._subject: str = subject

    def _get_subject_indices(self, meta_data) -> List[int]:
        subject_indices = []
        for idx, line in enumerate(meta_data):
            task_parts = line["task_name"].split("_")
            if self._subject == "mix":
                # Mixed-audio instances have task names of the form "speech_and_*".
                if "_".join(task_parts[:2]) == "speech_and":
                    subject_indices.append(idx)
            elif task_parts[0] == self._subject and task_parts[1] != "and":
                subject_indices.append(idx)
        return subject_indices

    def _get_content_type(self, audio_file_name) -> str:
        if audio_file_name.endswith(".wav"):
            return "audio/wav"
        elif audio_file_name.endswith(".mp3"):
            return "audio/mp3"
        else:
            raise ValueError(f"Unsupported audio file format: {audio_file_name}")

    def get_instances(self, output_path: str) -> List[Instance]:
        instances: List[Instance] = []
        data_dir: str = os.path.join(output_path, "wav_files")
        meta_data_path: str = os.path.join(output_path, "Chat_meta.json")
        ensure_file_downloaded(source_url=AirBenchChatScenario.META_DATA_FILE_PATH, target_path=meta_data_path)
        with open(meta_data_path) as meta_data_file:
            meta_data = json.load(meta_data_file)
        subject_indices = self._get_subject_indices(meta_data)
        for row in tqdm(subject_indices):
            audio_meta_data = meta_data[row]
            hf_audio_file_path = os.path.join(
                self.HF_DATA_PATH_PREFIX,
                f'{audio_meta_data["task_name"]}_{audio_meta_data["dataset_name"]}/{audio_meta_data["path"]}',
            )
            local_audio_file_path = os.path.join(
                data_dir,
                f'{audio_meta_data["task_name"]}_{audio_meta_data["dataset_name"]}_{audio_meta_data["path"]}',
            )
            ensure_file_downloaded(source_url=hf_audio_file_path, target_path=local_audio_file_path)
            input = Input(
                multimedia_content=MultimediaObject(
                    [
                        MediaObject(
                            content_type=self._get_content_type(audio_meta_data["path"]),
                            location=local_audio_file_path,
                        ),
                        MediaObject(content_type="text/plain", text=audio_meta_data["question"]),
                    ]
                )
            )
            references = [Reference(Output(text=audio_meta_data["answer_gt"]), tags=[CORRECT_TAG])]
            instances.append(Instance(input=input, references=references, split=TEST_SPLIT))
        return instances
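A note on the subject routing above: _get_subject_indices keys entirely off task_name prefixes, with "speech_and_*" names treated as the mixed subject and the first underscore-separated token used otherwise. A minimal sketch of that rule on made-up task names (the values below are illustrative assumptions, not entries from Chat_meta.json):

# Made-up task names for illustration; only the prefix convention matters.
sample_task_names = [
    "speech_dialogue_QA",   # -> speech
    "music_instrument_QA",  # -> music
    "sound_event_QA",       # -> sound
    "speech_and_sound_QA",  # -> mix
]

def subject_of(task_name: str) -> str:
    parts = task_name.split("_")
    # Task names beginning with "speech_and" hold mixed-audio instances.
    if "_".join(parts[:2]) == "speech_and":
        return "mix"
    return parts[0]

for task_name in sample_task_names:
    print(f"{task_name} -> {subject_of(task_name)}")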
34 changes: 24 additions & 10 deletions src/helm/benchmark/static/schema_speech.yaml
@@ -195,7 +195,6 @@ run_groups:
      audio sample ([Becker et al, 2023](https://arxiv.org/abs/1807.03418)).
    metric_groups:
      - accuracy
      - efficiency
      - general_information
    environment:
      main_name: exact_match
@@ -219,7 +218,6 @@
      ([Wang et al, 2020](https://arxiv.org/abs/2007.10310)).
    metric_groups:
      - accuracy
      - efficiency
      - general_information
    environment:
      main_name: bleu
@@ -241,7 +239,6 @@
      age, gender, native language, country, and health condition ([Gong et al, 2022](https://arxiv.org/abs/2205.03433)).
    metric_groups:
      - accuracy
      - efficiency
      - general_information
    environment:
      main_name: exact_match
@@ -263,7 +260,6 @@
      Dutch, German, French, Spanish, Italian, Portuguese, Polish ([Pratap et al, 2022](https://arxiv.org/abs/2012.03411)).
    metric_groups:
      - accuracy
      - efficiency
      - general_information
    environment:
      main_name: f1_score
@@ -288,7 +284,6 @@
      South Asian, South East Asian, Chinese, Japanese, Korean ([Conneau et al, 2022](https://arxiv.org/abs/2205.12446)).
    metric_groups:
      - accuracy
      - efficiency
      - general_information
    environment:
      main_name: exact_match
@@ -353,7 +348,6 @@
      ([Ardila et al, 2020](https://arxiv.org/abs/1912.06670)).
    metric_groups:
      - accuracy
      - efficiency
      - general_information
    environment:
      main_name: word_accuracy
@@ -378,7 +372,6 @@
      ([Shah et al, 2024](https://arxiv.org/abs/2403.07937)).
    metric_groups:
      - accuracy
      - efficiency
      - general_information
    environment:
      main_name: word_accuracy
@@ -401,7 +394,6 @@
      The dataset contains the audio and question for three subsets: occupation, status, and potential_crime.
    metric_groups:
      - accuracy
      - efficiency
      - general_information
    environment:
      main_name: exact_match
@@ -427,7 +419,6 @@
      question answering task.
    metric_groups:
      - accuracy
      - efficiency
      - general_information
    environment:
      main_name: exact_match
@@ -437,4 +428,27 @@
      what: audio, spoken language, speaker's gender, age information of audio samples
      who: real speakers
      when: "2023"
      language: 10 languages

  - name: air_bench_chat
    display_name: AIR-Bench Chat
    description: >
      AIR-Bench (Yang et al, 2024) encompasses two dimensions: foundation and chat benchmarks. The former consists of 19 tasks with
      approximately 19k single-choice questions; the latter contains 2k instances of open-ended question-and-answer data.
      We consider the chat benchmark in this scenario.
      The dataset contains the audio question answering task in four subjects: sound, speech, music, and mixed
      ([Yang et al, 2024](https://aclanthology.org/2024.acl-long.109.pdf)).
    metric_groups:
      - accuracy
      - general_information
      - reasoning
    environment:
      main_name: f1_score
      main_split: test
    taxonomy:
      task: audio question answering
      what: audio, question, and answer of audio samples
      who: real speakers
      when: "2024"
      language: English
23 changes: 23 additions & 0 deletions src/helm/clients/upstage_client.py
@@ -0,0 +1,23 @@
from helm.clients.openai_client import OpenAIClient
from helm.common.cache import CacheConfig
from helm.tokenizers.tokenizer import Tokenizer


class UpstageChatClient(OpenAIClient):
    """Sends requests to an Upstage model using an OpenAI-compatible Chat API."""

    def __init__(
        self,
        tokenizer: Tokenizer,
        tokenizer_name: str,
        cache_config: CacheConfig,
        api_key: str,
    ):
        super().__init__(
            tokenizer=tokenizer,
            tokenizer_name=tokenizer_name,
            cache_config=cache_config,
            api_key=api_key,
            org_id=None,
            base_url="https://api.upstage.ai/v1/solar",
        )
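Since this client only repoints OpenAIClient at Upstage's endpoint, the wire format is the standard OpenAI Chat Completions API. A minimal sketch of a direct call with the openai SDK, assuming a hypothetical UPSTAGE_API_KEY environment variable and a "solar-pro" model alias (both are illustrative assumptions, not taken from this diff):

import os

from openai import OpenAI

# Point the stock OpenAI SDK at Upstage's OpenAI-compatible endpoint,
# mirroring the base_url wired into UpstageChatClient above.
client = OpenAI(
    api_key=os.environ["UPSTAGE_API_KEY"],  # hypothetical env var
    base_url="https://api.upstage.ai/v1/solar",
)
response = client.chat.completions.create(
    model="solar-pro",  # illustrative model alias
    messages=[{"role": "user", "content": "Summarize HELM in one sentence."}],
)
print(response.choices[0].message.content)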
23 changes: 23 additions & 0 deletions src/helm/config/model_deployments.yaml
@@ -980,7 +980,22 @@ model_deployments:
    max_sequence_length: 2048
    client_spec:
      class_name: "helm.clients.vision_language.huggingface_vlm_client.HuggingFaceVLMClient"

  ## NECTEC
  - name: huggingface/Pathumma-llm-text-1.0.0
    model_name: nectec/Pathumma-llm-text-1.0.0
    tokenizer_name: nectec/Pathumma-llm-text-1.0.0
    max_sequence_length: 8192
    client_spec:
      class_name: "helm.clients.huggingface_client.HuggingFaceClient"

  - name: huggingface/OpenThaiLLM-Prebuilt-7B
    model_name: nectec/OpenThaiLLM-Prebuilt-7B
    tokenizer_name: nectec/OpenThaiLLM-Prebuilt-7B
    max_sequence_length: 4096
    client_spec:
      class_name: "helm.clients.huggingface_client.HuggingFaceClient"

  ## KAIST AI
  - name: huggingface/prometheus-vision-13b-v1.0-hf
    model_name: kaistai/prometheus-vision-13b-v1.0-hf
@@ -2751,6 +2766,14 @@ model_deployments:
    client_spec:
      class_name: "helm.clients.reka_client.RekaClient"

  # Upstage
  - name: upstage/solar-pro-241126
    model_name: upstage/solar-pro-241126
    tokenizer_name: upstage/solar-pro-preview-instruct
    max_sequence_length: 32768
    client_spec:
      class_name: "helm.clients.upstage_client.UpstageChatClient"

  # Diva Llama
  - name: huggingface/diva-llama
    model_name: stanford/diva-llama
31 changes: 31 additions & 0 deletions src/helm/config/model_metadata.yaml
@@ -2144,6 +2144,27 @@ models:
    tags: [TEXT_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG]


  # NECTEC
  - name: nectec/Pathumma-llm-text-1.0.0
    display_name: Pathumma-llm-text-1.0.0 (7B)
    description: Pathumma-llm-text-1.0.0 (7B) is an instruction-tuned model built on OpenThaiLLM-Prebuilt-7B ([blog](https://medium.com/nectec/pathummallm-v-1-0-0-release-6a098ddfe276)).
    creator_organization_name: nectec
    access: open
    num_parameters: 7620000000
    release_date: 2024-10-28
    tags: [TEXT_MODEL_TAG, PARTIAL_FUNCTIONALITY_TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]

  - name: nectec/OpenThaiLLM-Prebuilt-7B
    display_name: OpenThaiLLM-Prebuilt-7B (7B)
    description: OpenThaiLLM-Prebuilt-7B (7B) is a pretrained Thai large language model with 7 billion parameters based on Qwen2.5-7B.
    creator_organization_name: nectec
    access: open
    num_parameters: 7620000000
    release_date: 2024-10-28
    tags: [TEXT_MODEL_TAG, PARTIAL_FUNCTIONALITY_TEXT_MODEL_TAG]


  # Neurips
  - name: neurips/local
@@ -3146,6 +3167,15 @@
    release_date: 2024-09-11
    tags: [TEXT_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG]

  - name: upstage/solar-pro-241126
    display_name: Solar Pro
    description: Solar Pro is an LLM designed for instruction-following and processing structured formats like HTML and Markdown. It supports English, Korean, and Japanese and has domain expertise in Finance, Healthcare, and Legal ([blog](https://www.upstage.ai/blog/press/solar-pro-aws)).
    creator_organization_name: Upstage
    access: limited
    num_parameters: 22000000000
    release_date: 2024-11-26
    tags: [TEXT_MODEL_TAG, LIMITED_FUNCTIONALITY_TEXT_MODEL_TAG]

  # Writer
  - name: writer/palmyra-base
@@ -3469,3 +3499,4 @@ models:
    num_parameters: 6740000000
    release_date: 2023-11-08
    tags: [TEXT_MODEL_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG]
