-
Notifications
You must be signed in to change notification settings - Fork 0
/
prefinetune_examples.py
42 lines (29 loc) · 1.08 KB
/
prefinetune_examples.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import argparse
import itertools
from pathlib import Path

import datasets
import pandas as pd
def parse_cla(argv=None):
    """Parse the command-line arguments for this script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy for tests and programmatic use).

    Returns:
        argparse.Namespace with ``save_path`` (a ``Path``) and ``num_ex``
        (an ``int``).
    """
    parser = argparse.ArgumentParser(
        description="Write summarization prompts to text files."
    )
    # Both options are required: leaving either unset would yield None, which
    # is neither a usable output directory nor a usable example count.
    parser.add_argument(
        "-save_path",
        type=Path,
        required=True,
        help="Directory to write the example_prompt<i>.txt files into.",
    )
    parser.add_argument(
        "-num_ex",
        type=int,
        required=True,
        help="Number of training examples to turn into prompts.",
    )
    return parser.parse_args(argv)
def load_dataset():
    """Load the ``webis/tldr-17`` dataset via ``datasets.load_dataset``.

    Returns whatever ``datasets.load_dataset`` returns for that dataset id;
    the caller in this script reads its ``"train"`` split.
    """
    dataset_id = "webis/tldr-17"
    return datasets.load_dataset(dataset_id)
def generate_prompt(text_body):
    """Wrap *text_body* in the summarization instruction template.

    The text is fenced in triple backticks and followed by a ``SUMMARY:``
    cue for the model to complete.
    """
    return f"### Instruction: Write a concise summary of the following text.\n```{text_body}```\nSUMMARY:"
def save_prompts(tldr_dataset, num_prompts, save_path, *, prompt_fn=None):
    """Write the first ``num_prompts`` training examples as prompt text files.

    Each example's ``"content"`` field is turned into a prompt and written to
    ``save_path / f"example_prompt{i}.txt"`` for ``i`` in
    ``0 .. num_prompts - 1``.

    Args:
        tldr_dataset: Mapping with a ``"train"`` entry that iterates over
            dicts carrying a ``"content"`` key.
        num_prompts: Maximum number of prompts to write. ``None`` writes every
            training example (as the original loop did when given ``None``).
        save_path: Output directory (``Path`` or ``str``); created if missing.
        prompt_fn: Optional callable mapping an example's text to the prompt
            string. Defaults to ``generate_prompt``.

    Returns:
        List of the file paths written, in order.

    Raises:
        ValueError: If ``num_prompts`` is negative (previously this silently
            wrote the whole split because the stop condition never matched).
    """
    if num_prompts is not None and num_prompts < 0:
        raise ValueError(f"num_prompts must be non-negative, got {num_prompts}")
    if prompt_fn is None:
        prompt_fn = generate_prompt
    out_dir = Path(save_path)
    # Create the directory up front so writing below cannot fail on a missing path.
    out_dir.mkdir(parents=True, exist_ok=True)
    written = []
    # islice stops after num_prompts items without a hand-rolled counter; None means "all".
    for i, text_dict in enumerate(itertools.islice(tldr_dataset["train"], num_prompts)):
        prompt = prompt_fn(text_dict["content"])
        out_file = out_dir / f"example_prompt{i}.txt"
        # Explicit UTF-8: the text may contain non-ASCII characters that the
        # platform default encoding (e.g. cp1252 on Windows) cannot encode.
        out_file.write_text(prompt, encoding="utf-8")
        written.append(out_file)
    return written
def main():
    """Script entry point: parse CLI args, load the dataset, write prompt files."""
    cli_args = parse_cla()
    save_prompts(
        tldr_dataset=load_dataset(),
        num_prompts=cli_args.num_ex,
        save_path=cli_args.save_path,
    )


if __name__ == "__main__":
    main()