Model_2_With_SIL.py
from transformers import AutoModelForSequenceClassification
import numpy as np
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import get_scheduler
import torch
from tqdm.auto import tqdm
from torch.nn.functional import softmax
from Config_Manager import get_dataset, SEED, LEARNING_RATE, BATCH_SIZE, DEVICE, GENERATIONS, STUDENT_EPOCHS, TEACHER_EPOCHS
import sys
"""
HYPER PARAMS FROM CONFIG FILE
"""
seed = SEED
learning_rate = LEARNING_RATE
batch_size = BATCH_SIZE
device = DEVICE
generations = GENERATIONS
student_epochs = STUDENT_EPOCHS
teacher_epochs = TEACHER_EPOCHS
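
# Note: SEED is read from the config but not otherwise used below; applying it here is an
# optional addition for reproducibility rather than part of the original training flow.
torch.manual_seed(seed)
np.random.seed(seed)
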
def custom_loss(predictions, labels):
    loss_fn = torch.nn.CrossEntropyLoss()
    return loss_fn(predictions, labels)
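
# Note on custom_loss: recent PyTorch versions (>= 1.10) allow CrossEntropyLoss to take
# class-probability targets as well as class indices, which is what the student update
# below relies on when the softmaxed teacher logits are passed in as "labels".
# Illustrative call (example values only):
#   probs = softmax(torch.tensor([[2.0, 0.5]]), dim=1)
#   custom_loss(torch.tensor([[1.2, 0.3]]), probs)  # scalar soft-label cross entropy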
#get data sets and splits
dataset = get_dataset("masakhane")
train_dataset = dataset["train"]
val_dataset = dataset["val"]
val_dataset_d1 = get_dataset("naija")["val"]
del dataset
#create the data loaders: a training loader plus validation loaders for data sets 1 and 2
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)
val_dataloader_d1 = DataLoader(val_dataset_d1, batch_size=batch_size)
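
# These DataLoaders rely on the default collate function, so get_dataset is assumed to return
# splits that are already tokenised and set to PyTorch tensor format (input_ids, attention_mask,
# labels) that can be batched directly.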
#Make the student and teacher models from the afriberta model fine-tuned in step 1 (the teacher is re-created from the same checkpoint inside the generation loop below)
student_model = AutoModelForSequenceClassification.from_pretrained("Saved_Models/model_1").to(device)
student_model.config.loss_name = "cross_entropy" #use cross entropy loss function
student_optimizer = AdamW(student_model.parameters(), lr=learning_rate)
student_model.train()
training_steps = student_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=student_optimizer,
    num_training_steps=training_steps,
    num_warmup_steps=0
)
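
# Note: this scheduler is stepped on every student update across all generations, but its
# horizon only covers student_epochs * len(train_dataloader) steps, so with generations > 1
# the student learning rate has fully decayed to zero by the end of the first generation.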
prog_bar = tqdm(range(generations * teacher_epochs * len(train_dataloader))) #one tick per teacher training step
train_epoch_loss = []
val_epoch_loss = []
val_epoch_loss_d1 = []
for gen in range(generations):
    teacher_model = AutoModelForSequenceClassification.from_pretrained("Saved_Models/model_1").to(device)
    teacher_model.config.loss_name = "cross_entropy" #use cross entropy loss function
    #copy the student model weights to be the teacher model
    teacher_model.load_state_dict(student_model.state_dict())
    #define optimizer & scheduler for the teacher model
    teacher_optimizer = AdamW(teacher_model.parameters(), lr=learning_rate)
    teacher_optimizer.zero_grad()
    training_steps_1 = teacher_epochs * len(train_dataloader)
    lr_scheduler_1 = get_scheduler(
        name="linear",
        optimizer=teacher_optimizer,
        num_training_steps=training_steps_1,
        num_warmup_steps=0
    )
    new_batch = []
    for te in range(teacher_epochs):
        new_batch = [] #Empty it so we only keep the set of pseudo labels from the last teacher epoch
        for batch in train_dataloader:
            #First train the teacher model on the ground-truth labels
            batch = {k : v.to(device) for k, v in batch.items()}
            teacher_model.train()
            outputs = teacher_model(**batch) #get outputs from teacher model
            loss = outputs.loss
            loss.backward()
            teacher_optimizer.step()
            lr_scheduler_1.step()
            teacher_optimizer.zero_grad()
            prog_bar.update(1)
            #Then get teacher model predictions for the same inputs (the new labels)
            with torch.no_grad():
                teacher_logits = teacher_model(**batch).logits
            #softmax the teacher logits into soft pseudo-labels of shape (batch_size, num_labels)
            pseudo_labels = softmax(teacher_logits, dim=1)
            temp_batch = batch.copy()
            temp_batch["labels"] = pseudo_labels
            new_batch.append(temp_batch)
    se_loss = []
    for se in range(student_epochs):
        student_model.train()
        step_loss = []
        for batch in new_batch:
            batch = {k : v.to(device) for k, v in batch.items()}
            # Train the student model using the teacher pseudo-labels (soft labels);
            # the model's internally computed loss is ignored and custom_loss applies
            # cross entropy against the soft labels instead
            student_optimizer.zero_grad()
            student_logits = student_model(**batch).logits
            loss = custom_loss(student_logits, batch["labels"])
            loss.backward()
            student_optimizer.step()
            lr_scheduler.step()
            step_loss.append(loss.item())
        se_loss.append(np.mean(step_loss)) #keep track of loss
    #Loss for 1 generation, averaged over the student epochs
    train_epoch_loss.append(np.mean(se_loss))
    #Evaluate the student model on the validation set for dataset 2
    student_model.eval()
    step_loss = []
    for batch in val_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad(): #no gradients needed during validation
            outputs = student_model(**batch)
        loss = outputs.loss
        step_loss.append(loss.item())
    val_epoch_loss.append(np.mean(step_loss))
    #Evaluate the student model on the validation set for dataset 1
    student_model.eval()
    step_loss = []
    for batch in val_dataloader_d1:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad(): #no gradients needed during validation
            outputs = student_model(**batch)
        loss = outputs.loss
        step_loss.append(loss.item())
    val_epoch_loss_d1.append(np.mean(step_loss))
loss_data = [
    train_epoch_loss,
    val_epoch_loss,
    val_epoch_loss_d1
]
loss_data = np.array(loss_data)
#Save the student model
student_model.save_pretrained("Saved_Models/model_2_SIL")
np.save(f"Saved_Models/model_2_SIL/model_2_SIL_Loss_{sys.argv[1]}.npy", loss_data)
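
# The saved artifacts can later be reloaded along these lines (illustrative sketch; the loss
# array has one row per curve - train, validation on dataset 2, validation on dataset 1 - and
# one column per generation):
#   model = AutoModelForSequenceClassification.from_pretrained("Saved_Models/model_2_SIL")
#   losses = np.load("Saved_Models/model_2_SIL/model_2_SIL_Loss_<run>.npy")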