Add datasets: MGSM and XNLI #248

Merged 9 commits on May 26, 2024
60 changes: 60 additions & 0 deletions utilization/dataset/mgsm.py
@@ -0,0 +1,60 @@
import re
from functools import cached_property

from ..metric import Accuracy
from .generation_dataset import GenerationDataset


class Mgsm(GenerationDataset):
    """The dataset of MGSM.

    Multilingual Grade School Math (MGSM) is a benchmark of grade-school math word problems, built by translating a subset of GSM8K into ten typologically diverse languages.

    Examples:
        'question': 'Question: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?',
        'answer': 'Step-by-Step Answer: Roger started with 5 balls. 2 cans of 3 tennis balls each is 6 tennis balls. 5 + 6 = 11. The answer is 11.',
        'answer_number': 11,
        'equation_solution': '5 + 6 = 11.'
    """

instruction = "Answer the following question in {{lang}}.\n\nQuestion: {{question.replace('\n', ' ')}}\nAnswer:"

evaluation_set = "test"
example_set = "train"
load_args = ("juletxara/mgsm",)
metrics = [Accuracy()]
extra_model_args = dict(temperature=0)

_decimal_separator = re.compile(r"(\d),(\d)")
_extract_numbers = re.compile(r"[-+]?\d*\.\d+|\d+")

    def init_arguments(self):
        # base (non-chat) models keep generating past the answer, so stop at the first newline
        if self.model_type == "base":
            self.extra_model_args["stop"] = ["\n"]

        # resolve the subset code (e.g. "de") to an English language name (e.g. "German")
        from langcodes import Language
        self.language = Language(self.subset_name).language_name("en")

    def post_processing(self, predictions):
        new_predictions = []
        for pred in predictions:
            # strip thousands separators so numbers like `1,234` are read as `1234`
            pred = self._decimal_separator.sub(r"\1\2", pred)
            numbers = self._extract_numbers.findall(pred)
            if numbers:
                # take the last number in the completion as the final answer
                new_predictions.append(numbers[-1])
            else:
                new_predictions.append(pred)
        return new_predictions

    def format_instance(self, instance):
        instance["lang"] = self.language
        # the numeric answer is what gets scored; the full step-by-step answer is the few-shot target
        instance["short_answer"] = str(instance["answer_number"])
        instance["target"] = instance["answer"]
        return instance

    @cached_property
    def references(self):
        return [instance["short_answer"] for instance in self.evaluation_data]
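
To make the extraction step concrete, here is a minimal standalone sketch of what post_processing does to a typical chain-of-thought completion (the sample string below is hypothetical):

import re

# same patterns as in Mgsm above
_decimal_separator = re.compile(r"(\d),(\d)")
_extract_numbers = re.compile(r"[-+]?\d*\.\d+|\d+")

pred = "2 cans of 3 balls each is 6 balls. 5 + 6 = 11. The answer is 1,100."
pred = _decimal_separator.sub(r"\1\2", pred)  # drop thousands separators: "1,100" -> "1100"
numbers = _extract_numbers.findall(pred)      # ["2", "3", "6", "5", "6", "11", "1100"]
print(numbers[-1] if numbers else pred)       # "1100": the last number is taken as the answer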
37 changes: 37 additions & 0 deletions utilization/dataset/xnli.py
@@ -0,0 +1,37 @@
from functools import cached_property

from .multiple_choice_dataset import MultipleChoiceDataset


class Xnli(MultipleChoiceDataset):
    """The dataset of XNLI.

    XNLI (Conneau et al. 2018) is a subset of a few thousand examples from MNLI that has been translated into 14 different languages (some of them low-resource).

    Example:
        "hypothesis": "Man verliert die Dinge auf die folgende Ebene , wenn sich die Leute erinnern .",
        "label": 0,
        "premise": "\"Du weißt , während der Saison und ich schätze , auf deiner Ebene verlierst du sie auf die nächste Ebene , wenn sie sich entschl..."
    """

instruction = "Given the premise sentence in '{{lang}}': '{{premise.strip()}}', does the hypothesis sentence '{{hypothesis.strip()}}' entail, contradict, or neither (neutral) with respect to the premise?{{'\n'+options if options}}\nAnswer:"
evaluation_set = "validation"
example_set = "train"
load_args = ("xnli",)
banned_subsets = ["all_languages"]

    def init_arguments(self):
        # resolve the subset code (e.g. "de") to an English language name (e.g. "German")
        from langcodes import Language
        self.language = Language(self.subset_name).language_name("en")

    def format_instance(self, instance):
        instance["lang"] = self.language
        # integer labels 0, 1, 2 index into this options list
        instance["options"] = ["entailment", "neutral", "contradiction"]
        return instance

    @cached_property
    def references(self):
        return [instance["label"] for instance in self.evaluation_data]
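
Both classes resolve the Hugging Face subset name to an English language name with langcodes. A quick sketch of that call, and of how XNLI's integer labels line up with the options list (assuming the language_data package that langcodes uses for display names is installed):

from langcodes import Language

# subset codes such as "de" or "sw" resolve to English display names
print(Language("de").language_name("en"))  # "German"
print(Language("sw").language_name("en"))  # "Swahili"

# XNLI labels are integers 0-2 that index the options list set in format_instance
options = ["entailment", "neutral", "contradiction"]
print(options[0])  # "entailment", matching label 0 in the docstring example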