# Copyright (C) 2023 National Research Council Canada.
#
# This file is part of vardial-2023.
#
# vardial-2023 is free software: you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# vardial-2023 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# vardial-2023. If not, see https://www.gnu.org/licenses/.
import argparse

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import DataCollatorWithPadding
from datasets import Dataset
from tqdm.auto import tqdm

from utils import load_lines, write_preds, CLASS_NAMES

DOC = """
Predict labels of texts using a pre-trained classifier.
"""


def load_data(path):
    texts = load_lines(path)
    data = Dataset.from_dict({"text": texts})
    return data
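
# load_lines, write_preds and CLASS_NAMES come from the local utils module,
# which is not part of this file. As a rough sketch (an assumption about its
# shape, not the actual implementation), load_lines could be:
#
#   def load_lines(path):
#       with open(path, encoding="utf-8") as f:
#           return [line.rstrip("\n") for line in f]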


def main(args):
    def tokenize(examples):
        # `tokenizer` is bound later in main(), before data.map() runs, so
        # this closure resolves correctly; padding is deferred to the collator.
        return tokenizer(examples["text"], padding=False, truncation=True)

    # Load data
    data = load_data(args.path_test_texts)
    nb_examples = len(data["text"])
    label_list = CLASS_NAMES
    label2id = {x: i for i, x in enumerate(label_list)}
    id2label = {i: x for i, x in enumerate(label_list)}

    # Load model
    model = AutoModelForSequenceClassification.from_pretrained(args.path_checkpoint)
    if model.config.problem_type == "single_label_classification":
        mode = "single"
    elif model.config.problem_type == "multi_label_classification":
        mode = "multi"
    else:
        msg = f"Unrecognized problem type '{model.config.problem_type}'"
        raise RuntimeError(msg)
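    # Background (not from this file): problem_type is stored on the model
    # config, typically set at training time. "single_label_classification"
    # means exactly one class per text (cross-entropy training);
    # "multi_label_classification" means any subset of classes may apply
    # (per-class BCE). The mode flag is passed through to write_preds below.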

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.path_tokenizer)

    # Tokenize data
    print("Tokenizing...")
    data = data.map(tokenize, batched=True, num_proc=1)
    data = data.remove_columns("text")
    data.set_format("torch")
    collator = DataCollatorWithPadding(
        tokenizer=tokenizer,
        padding="longest",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    )
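    # With padding="longest", DataCollatorWithPadding pads each batch only to
    # that batch's longest sequence rather than to a fixed global length,
    # which keeps batches of short texts cheap.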
    data_loader = DataLoader(data, shuffle=False, batch_size=args.batch_size, collate_fn=collator)
    print(f"Nb examples: {nb_examples}")
    print(f"Batch size: {args.batch_size}")
    print(f"Nb batches: {len(data_loader)}")

    # Use GPU if available
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)

    # Run prediction
    progress = tqdm(range(len(data_loader)), desc="Batches")
    model.eval()
    all_logits = []
    for batch in data_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)
        logits = outputs.logits
        all_logits.append(logits.detach().cpu())
        progress.update(1)

    # Write predictions: stack per-batch logits into one (nb_examples, nb_classes) tensor
    all_logits = torch.vstack(all_logits)
    write_preds(all_logits, label_list, args.path_preds, mode)
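
# write_preds lives in utils and is not shown here. A hedged sketch of what it
# plausibly does with each mode (an assumption, not the repo's actual code):
#
#   def write_preds(logits, label_list, path, mode):
#       if mode == "single":
#           # one label per line: the argmax class
#           labels = [label_list[i] for i in logits.argmax(dim=-1).tolist()]
#       else:  # "multi"
#           # all classes whose sigmoid score clears 0.5, space-separated
#           rows = (torch.sigmoid(logits) > 0.5).long().tolist()
#           labels = [" ".join(c for c, y in zip(label_list, row) if y) for row in rows]
#       with open(path, "w", encoding="utf-8") as f:
#           f.write("\n".join(labels) + "\n")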


if __name__ == "__main__":
    p = argparse.ArgumentParser(description=DOC)
    p.add_argument("path_checkpoint", help="Path of directory containing binary model file and config")
    p.add_argument("path_tokenizer", help="Path of directory containing the tokenizer files")
    p.add_argument("path_test_texts", help="Path of text file containing test texts (one per line)")
    p.add_argument("path_preds", help="Path of output text file containing predicted labels")
    p.add_argument("--batch_size", type=int, default=32)
    args = p.parse_args()
    main(args)