-
Notifications
You must be signed in to change notification settings - Fork 8
/
main.py
141 lines (120 loc) · 5.24 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import argparse
import json
import logging
import os
from pathlib import Path
from lm_eval import evaluator, tasks, utils
from lm_eval.models import MODEL_REGISTRY
logging.getLogger("openai").setLevel(logging.WARNING)
def get_commit(repo_path):
git_folder = Path(repo_path, ".git")
if git_folder.is_file():
git_folder = Path(git_folder.parent, git_folder.read_text().split("\n")[0].split(" ")[-1])
if Path(git_folder, "HEAD").exists():
head_name = Path(git_folder, "HEAD").read_text().split("\n")[0].split(" ")[-1]
head_ref = Path(git_folder, head_name)
commit = head_ref.read_text().replace("\n", "")
else:
commit = ""
return commit
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", required=True, choices=MODEL_REGISTRY, help="Name of internal model class type.")
parser.add_argument(
"--model_args", default="", help="Comma separated string arguments for transformers model autoclass."
)
parser.add_argument(
"--tasks", default=None, choices=utils.MultiChoice(tasks.ALL_TASKS), help="Comma separated list of task names."
)
parser.add_argument("--num_fewshot", type=int, default=0, help="Number of examples in few-shot context.")
parser.add_argument("--batch_size", type=str, default=None, help="Batch size for model.")
parser.add_argument(
"--max_batch_size", type=int, default=None, help="Maximal batch size to try with --batch_size auto."
)
parser.add_argument("--device", type=str, default=None, help="PyTorch device string for running models.")
parser.add_argument("--output_path", default=None, help="Path to store results of task run")
parser.add_argument(
"--limit",
type=float,
default=None,
help="Limit the number of examples per task. If <1, limit is a percentage of the total number of examples.",
)
parser.add_argument("--no_cache", action="store_true", help="Set to not cache model files.")
parser.add_argument(
"--decontamination_ngrams_path",
default=None,
help="Directory with the ngram files and info.json for decontamination",
)
parser.add_argument("--description_dict_path", default=None, help="Path to dictionary of custom task descriptions.")
parser.add_argument(
"--check_integrity",
action="store_true",
help="Whether to run the relevant part of the test suite for the tasks.",
)
parser.add_argument(
"--write_out",
action="store_true",
default=False,
help="Write details about prompts and logits to json for all tasks.",
)
parser.add_argument(
"--output_base_path",
type=str,
default=None,
help="Directory to which detailed eval info will be written. Defaults to present working dir.",
)
parser.add_argument(
"--inference", action="store_true", default=False, help="Whether the procedure runs without labels."
)
return parser.parse_args()
def main():
args = parse_args()
if args.limit:
print("WARNING: --limit SHOULD ONLY BE USED FOR TESTING. REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT.")
if args.tasks is None:
task_names = tasks.ALL_TASKS
else:
task_names = utils.pattern_match(args.tasks.split(","), tasks.ALL_TASKS)
print(f"Selected Tasks: {task_names}")
description_dict = {}
if args.description_dict_path:
with open(args.description_dict_path, "r") as f:
description_dict = json.load(f)
results = evaluator.simple_evaluate(
model=args.model,
model_args=args.model_args,
tasks=task_names,
num_fewshot=args.num_fewshot,
batch_size=args.batch_size,
max_batch_size=args.max_batch_size,
device=args.device,
no_cache=args.no_cache,
limit=args.limit,
description_dict=description_dict,
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
write_out=args.write_out,
output_base_path=args.output_base_path,
inference=args.inference,
)
results["config"]["current_dir_commit"] = get_commit(os.getcwd()) # git hash of repo if exists
results["config"]["upper_dir_commit"] = get_commit(Path(os.getcwd(), "..")) # git hash of upper repo if exists
if args.inference:
output_base_path = Path(args.output_base_path) if args.output_base_path is not None else Path(".")
with open(output_base_path.joinpath("evaluation_config.json"), "w", encoding="utf8") as file:
json.dump(results["config"], file, indent=2, ensure_ascii=False)
dumped = json.dumps(results, indent=2, ensure_ascii=False)
print(dumped)
if args.output_path:
os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
with open(args.output_path, "w") as f:
f.write(dumped)
batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
print(
f"{args.model} ({args.model_args}), limit: {args.limit}, "
f"num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
)
if not args.inference:
print(evaluator.make_table(results))
if __name__ == "__main__":
main()