[FEATURE] Update D2V, AutoTokenizer, and pretraining scripts #155

Merged · 11 commits · Mar 4, 2024
2 changes: 1 addition & 1 deletion EduNLP/Formula/Formula.py
@@ -8,7 +8,7 @@

from .ast import str2ast, get_edges, link_variable

CONST_MATHORD = {r"\pi"}
CONST_MATHORD = {"\\pi"}

__all__ = ["Formula", "FormulaGroup", "CONST_MATHORD", "link_formulas"]

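Note on the change above: r"\pi" and "\\pi" spell the same three-character string, so the contents of CONST_MATHORD are unchanged; only the literal style differs. A minimal check (not part of the PR):

assert r"\pi" == "\\pi"          # raw-string and escaped spellings are identical
assert {r"\pi"} == {"\\pi"}      # so CONST_MATHORD holds the same single element
print(len("\\pi"))               # 3: backslash, 'p', 'i'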
12 changes: 6 additions & 6 deletions EduNLP/I2V/i2v.py
@@ -51,8 +51,8 @@ class I2V(object):
(...)
>>> path = path_append(path, os.path.basename(path) + '.bin', to_str=True)
>>> i2v = D2V("pure_text", "d2v", filepath=path, pretrained_t2v=False)
>>> i2v(item)
([array([ ...dtype=float32)], None)
>>> i2v(item) # doctest: +SKIP
([array([ ...dtype=float32)], [[array([ ...dtype=float32)]])

Returns
-------
@@ -189,8 +189,8 @@ class D2V(I2V):
(...)
>>> path = path_append(path, os.path.basename(path) + '.bin', to_str=True)
>>> i2v = D2V("pure_text","d2v",filepath=path, pretrained_t2v = False)
>>> i2v(item)
([array([ ...dtype=float32)], None)
>>> i2v(item) # doctest: +SKIP
# ([array([ ...dtype=float32)], [[array([ ...dtype=float32)]])

Returns
-------
@@ -221,7 +221,7 @@ def infer_vector(self, items, tokenize=True, key=lambda x: x, *args,
"""
tokens = self.tokenize(items, key=key) if tokenize is True else items
tokens = [token for token in tokens]
return self.t2v(tokens, *args, **kwargs), None
return self.t2v(tokens, *args, **kwargs), self.t2v.infer_tokens(tokens, *args, **kwargs)

@classmethod
def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
@@ -579,7 +579,7 @@ def get_pretrained_i2v(name, model_dir=MODEL_DIR, device='cpu'):
>>> (); i2v = get_pretrained_i2v("d2v_test_256", "examples/test_model/d2v"); () # doctest: +SKIP
(...)
>>> print(i2v(item)) # doctest: +SKIP
([array([ ...dtype=float32)], None)
([array([ ...dtype=float32)], [[array([ ...dtype=float32)]])
"""
pretrained_models = get_all_pretrained_models()
if name not in pretrained_models:
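Note on the infer_vector change above: the second element of the returned tuple is now self.t2v.infer_tokens(tokens, ...) (token-level vectors) rather than None, so callers unpack two usable values. A minimal sketch, assuming a local d2v model file; the file path and input text below are placeholders, not part of the PR:

from EduNLP.I2V import D2V

i2v = D2V("pure_text", "d2v", filepath="path/to/d2v_model.bin", pretrained_t2v=False)
item_vectors, token_vectors = i2v(["placeholder question text"])
print(len(item_vectors), len(token_vectors))   # one entry per input item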
1 change: 1 addition & 0 deletions EduNLP/ModelZoo/__init__.py
@@ -1,5 +1,6 @@
from .utils import *
from .bert import *
from .hf_model import *
from .rnn import *
from .disenqnet import *
from .quesnet import *
1 change: 1 addition & 0 deletions EduNLP/ModelZoo/hf_model/__init__.py
@@ -0,0 +1 @@
from .hf_model import *
165 changes: 165 additions & 0 deletions EduNLP/ModelZoo/hf_model/hf_model.py
@@ -0,0 +1,165 @@
import torch
from torch import nn
import json
import os
from transformers import AutoModel, PretrainedConfig, AutoConfig
from typing import List
from EduNLP.utils.log import logger
from ..base_model import BaseModel
from ..utils import PropertyPredictionOutput, KnowledgePredictionOutput
from ..rnn.harnn import HAM


__all__ = ["HfModelForPropertyPrediction", "HfModelForKnowledgePrediction"]


class HfModelForPropertyPrediction(BaseModel):
def __init__(self, pretrained_model_dir=None, head_dropout=0.5, init=True):
super(HfModelForPropertyPrediction, self).__init__()
bert_config = AutoConfig.from_pretrained(pretrained_model_dir)
if init:
logger.info(f'Load AutoModel from checkpoint: {pretrained_model_dir}')
self.bert = AutoModel.from_pretrained(pretrained_model_dir)
Collaborator review comment: change this to something like self.model? AutoModel should not be constrained to BERT.

else:
logger.info(f'Load AutoModel from config: {pretrained_model_dir}')
self.bert = AutoModel(bert_config)
self.hidden_size = self.bert.config.hidden_size
self.head_dropout = head_dropout
self.dropout = nn.Dropout(head_dropout)
self.classifier = nn.Linear(self.hidden_size, 1)
self.sigmoid = nn.Sigmoid()
self.criterion = nn.MSELoss()

self.config = {k: v for k, v in locals().items() if k not in ["self", "__class__", "bert_config"]}
self.config['architecture'] = 'HfModelForPropertyPrediction'
self.config = PretrainedConfig.from_dict(self.config)

def forward(self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
labels=None):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
item_embeds = outputs.last_hidden_state[:, 0, :]
item_embeds = self.dropout(item_embeds)

logits = self.sigmoid(self.classifier(item_embeds)).squeeze(1)
loss = None
if labels is not None:
loss = self.criterion(logits, labels) if labels is not None else None
return PropertyPredictionOutput(
loss=loss,
logits=logits,
)

@classmethod
def from_config(cls, config_path, **kwargs):
config_path = os.path.join(os.path.dirname(config_path), 'model_config.json')
with open(config_path, "r", encoding="utf-8") as rf:
model_config = json.load(rf)
model_config['pretrained_model_dir'] = os.path.dirname(config_path)
model_config.update(kwargs)
return cls(
pretrained_model_dir=model_config['pretrained_model_dir'],
head_dropout=model_config.get("head_dropout", 0.5),
init=model_config.get('init', False)
)

def save_config(self, config_dir):
config_path = os.path.join(config_dir, "model_config.json")
with open(config_path, "w", encoding="utf-8") as wf:
json.dump(self.config.to_dict(), wf, ensure_ascii=False, indent=2)
self.bert.config.save_pretrained(config_dir)


class HfModelForKnowledgePrediction(BaseModel):
def __init__(self,
pretrained_model_dir=None,
num_classes_list: List[int] = None,
num_total_classes: int = None,
head_dropout=0.5,
flat_cls_weight=0.5,
attention_unit_size=256,
fc_hidden_size=512,
beta=0.5,
init=True
):
super(HfModelForKnowledgePrediction, self).__init__()
bert_config = AutoConfig.from_pretrained(pretrained_model_dir)
if init:
logger.info(f'Load AutoModel from checkpoint: {pretrained_model_dir}')
self.bert = AutoModel.from_pretrained(pretrained_model_dir)
Collaborator review comment: same here (the self.bert attribute name should not be tied to BERT).

else:
logger.info(f'Load AutoModel from config: {pretrained_model_dir}')
self.bert = AutoModel(bert_config)
self.hidden_size = self.bert.config.hidden_size
self.head_dropout = head_dropout
self.dropout = nn.Dropout(head_dropout)
self.sigmoid = nn.Sigmoid()
self.criterion = nn.MSELoss()
self.flat_classifier = nn.Linear(self.hidden_size, num_total_classes)
self.ham_classifier = HAM(
num_classes_list=num_classes_list,
num_total_classes=num_total_classes,
sequence_model_hidden_size=self.bert.config.hidden_size,
attention_unit_size=attention_unit_size,
fc_hidden_size=fc_hidden_size,
beta=beta,
dropout_rate=head_dropout
)
self.flat_cls_weight = flat_cls_weight
self.num_classes_list = num_classes_list
self.num_total_classes = num_total_classes

self.config = {k: v for k, v in locals().items() if k not in ["self", "__class__", "bert_config"]}
self.config['architecture'] = 'HfModelForKnowledgePrediction'
self.config = PretrainedConfig.from_dict(self.config)

def forward(self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
labels=None):
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
item_embeds = outputs.last_hidden_state[:, 0, :]
item_embeds = self.dropout(item_embeds)
tokens_embeds = outputs.last_hidden_state
tokens_embeds = self.dropout(tokens_embeds)
flat_logits = self.sigmoid(self.flat_classifier(item_embeds))
ham_outputs = self.ham_classifier(tokens_embeds)
ham_logits = self.sigmoid(ham_outputs.scores)
logits = self.flat_cls_weight * flat_logits + (1 - self.flat_cls_weight) * ham_logits
loss = None
if labels is not None:
labels = torch.sum(torch.nn.functional.one_hot(labels, num_classes=self.num_total_classes), dim=1)
labels = labels.float()
loss = self.criterion(logits, labels) if labels is not None else None
return KnowledgePredictionOutput(
loss=loss,
logits=logits,
)

@classmethod
def from_config(cls, config_path, **kwargs):
config_path = os.path.join(os.path.dirname(config_path), 'model_config.json')
with open(config_path, "r", encoding="utf-8") as rf:
model_config = json.load(rf)
model_config['pretrained_model_dir'] = os.path.dirname(config_path)
model_config.update(kwargs)
return cls(
pretrained_model_dir=model_config['pretrained_model_dir'],
head_dropout=model_config.get("head_dropout", 0.5),
num_classes_list=model_config.get('num_classes_list'),
num_total_classes=model_config.get('num_total_classes'),
flat_cls_weight=model_config.get('flat_cls_weight', 0.5),
attention_unit_size=model_config.get('attention_unit_size', 256),
fc_hidden_size=model_config.get('fc_hidden_size', 512),
beta=model_config.get('beta', 0.5),
init=model_config.get('init', False)
)

def save_config(self, config_dir):
config_path = os.path.join(config_dir, "model_config.json")
with open(config_path, "w", encoding="utf-8") as wf:
json.dump(self.config.to_dict(), wf, ensure_ascii=False, indent=2)
self.bert.config.save_pretrained(config_dir)
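A minimal usage sketch for the new prediction heads above, assuming a local Hugging Face checkpoint directory with a matching tokenizer (neither is fixed by this PR); the checkpoint path and label value are placeholders. If the reviewer's suggested rename is adopted, self.bert becomes self.model inside the class, but this external call does not change.

import torch
from transformers import AutoTokenizer
from EduNLP.ModelZoo import HfModelForPropertyPrediction

checkpoint = "path/to/hf_checkpoint"                 # placeholder path
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = HfModelForPropertyPrediction(pretrained_model_dir=checkpoint)

batch = tokenizer(["an example question text"], return_tensors="pt", padding=True)
labels = torch.tensor([0.7])                         # difficulty-style target in [0, 1]
out = model(input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            token_type_ids=batch.get("token_type_ids"),
            labels=labels)
print(out.logits.shape, out.loss)                    # per-item score in (0, 1) and MSE loss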
79 changes: 37 additions & 42 deletions EduNLP/ModelZoo/quesnet/quesnet.py
@@ -95,16 +95,10 @@ def load_emb(self, emb):
self.we.weight.detach().copy_(torch.from_numpy(emb))

def load_img(self, img_layer: nn.Module):
if self.config.emb_size != img_layer.emb_size:
raise ValueError("Unmatched pre-trained ImageAE and embedding size")
else:
self.ie.load_state_dict(img_layer.state_dict())
self.ie.load_state_dict(img_layer)

def load_meta(self, meta_layer: nn.Module):
if self.config.emb_size != meta_layer.emb_size or self.meta_size != meta_layer.meta_size:
raise ValueError("Unmatched pre-trained MetaAE and embedding size or meta size")
else:
self.me.load_state_dict(meta_layer.state_dict())
self.me.load_state_dict(meta_layer)

def make_batch(self, data, device, pretrain=False):
"""Returns embeddings"""
@@ -122,16 +116,18 @@ def make_batch(self, data, device, pretrain=False):
for q in data:
meta = torch.zeros(len(self.stoi[self.meta])).to(device)
meta[q.labels.get(self.meta) or []] = 1
_lembs = [self.we(torch.tensor([0], device=device)),
self.we(torch.tensor([0], device=device)),
_lembs = [torch.zeros(1, self.emb_size).to(device),
torch.zeros(1, self.emb_size).to(device),
self.me.enc(meta.unsqueeze(0)) * self.lambda_input[2]]
_rembs = [self.me.enc(meta.unsqueeze(0)) * self.lambda_input[2]]
_embs = [self.we(torch.tensor([0], device=device)),
self.we(torch.tensor([0], device=device)),
_embs = [torch.zeros(1, self.emb_size).to(device),
torch.zeros(1, self.emb_size).to(device),
self.me.enc(meta.unsqueeze(0)) * self.lambda_input[2]]
_gt = [torch.tensor([0], device=device), meta]
for w in q.content:
if isinstance(w, int):
if w >= self.vocab_size:
w = self.vocab_size - 1
word = torch.tensor([w], device=device)
item = self.we(word) * self.lambda_input[0]
_lembs.append(item)
@@ -146,10 +142,10 @@ def make_batch(self, data, device, pretrain=False):
_embs.append(item)
_gt.append(im)
_gt.append(torch.tensor([0], device=device))
_rembs.append(self.we(torch.tensor([0], device=device)))
_rembs.append(self.we(torch.tensor([0], device=device)))
_embs.append(self.we(torch.tensor([0], device=device)))
_embs.append(self.we(torch.tensor([0], device=device)))
_rembs.append(torch.zeros(1, self.emb_size).to(device))
_rembs.append(torch.zeros(1, self.emb_size).to(device))
_embs.append(torch.zeros(1, self.emb_size).to(device))
_embs.append(torch.zeros(1, self.emb_size).to(device))

lembs.append(torch.cat(_lembs, dim=0))
rembs.append(torch.cat(_rembs, dim=0))
@@ -308,15 +304,15 @@ def __init__(self, _stoi=None, pretrained_embs: np.ndarray = None, pretrained_im
self.config = PretrainedConfig.from_dict(self.config)

def forward(self, batch):
left, right, words, ims, metas, wmask, imask, mmask, inputs, ans_input, ans_output, false_opt_input = batch[0]
left, right, words, ims, metas, wmask, imask, mmask, inputs, ans_input, ans_output, false_opt_input = batch

# high-level loss
outputs = self.quesnet(inputs)
embeded = outputs.embeded
h = outputs.hidden

x = ans_input.packed()

# (4,1,256), (4,3,256)
y, _ = self.ans_decode(PackedSequence(self.quesnet.we(x[0].data), x.batch_sizes),
h.repeat(self.config.layers, 1, 1))
floss = F.cross_entropy(self.ans_output(y.data),
@@ -333,45 +329,44 @@ def forward(self, batch):
torch.zeros_like(self.ans_judge(y.data)))
loss = floss * self.lambda_loss[1]
# low-level loss
left_hid = self.quesnet(left).pack_embeded.data[:, :self.rnn_size].clone()
right_hid = self.quesnet(right).pack_embeded.data[:, self.rnn_size:].clone()
left_hid = self.quesnet(left).pack_embeded.data[:, :self.rnn_size]
right_hid = self.quesnet(right).pack_embeded.data[:, self.rnn_size:]

wloss = iloss = mloss = None

if words is not None:
lwfea = torch.masked_select(left_hid.clone(), wmask.unsqueeze(1).bool()) \
.view(-1, self.rnn_size).clone()
lout = self.lwoutput(lwfea.clone())
rwfea = torch.masked_select(right_hid.clone(), wmask.unsqueeze(1).bool()) \
.view(-1, self.rnn_size).clone()
rout = self.rwoutput(rwfea.clone())
out = self.woutput(torch.cat([lwfea.clone(), rwfea.clone()], dim=1).clone())
lwfea = torch.masked_select(left_hid, wmask.unsqueeze(1).bool()) \
.view(-1, self.rnn_size)
lout = self.lwoutput(lwfea)
rwfea = torch.masked_select(right_hid, wmask.unsqueeze(1).bool()) \
.view(-1, self.rnn_size)
rout = self.rwoutput(rwfea)
out = self.woutput(torch.cat([lwfea, rwfea], dim=1))
wloss = (F.cross_entropy(out, words) + F.cross_entropy(lout, words) + F.
cross_entropy(rout, words)) * self.quesnet.lambda_input[0] / 3
wloss *= self.lambda_loss[0]
loss = loss + wloss

if ims is not None:
lifea = torch.masked_select(left_hid.clone(), imask.unsqueeze(1).bool()) \
.view(-1, self.rnn_size).clone()
lout = self.lioutput(lifea.clone())
rifea = torch.masked_select(right_hid.clone(), imask.unsqueeze(1).bool()) \
.view(-1, self.rnn_size).clone()
rout = self.rioutput(rifea.clone())
out = self.ioutput(torch.cat([lifea.clone(), rifea.clone()], dim=1).clone())
lifea = torch.masked_select(left_hid, imask.unsqueeze(1).bool()) \
.view(-1, self.rnn_size)
lout = self.lioutput(lifea)
rifea = torch.masked_select(right_hid, imask.unsqueeze(1).bool()) \
.view(-1, self.rnn_size)
rout = self.rioutput(rifea)
out = self.ioutput(torch.cat([lifea, rifea], dim=1))
iloss = (self.quesnet.ie.loss(ims, out) + self.quesnet.ie.loss(ims, lout) + self.quesnet.ie.
loss(ims, rout)) * self.quesnet.lambda_input[1] / 3
iloss *= self.lambda_loss[0]
loss = loss + iloss

if metas is not None:
lmfea = torch.masked_select(left_hid.clone(), mmask.unsqueeze(1).bool()) \
.view(-1, self.rnn_size).clone()
lout = self.lmoutput(lmfea.clone())
rmfea = torch.masked_select(right_hid.clone(), mmask.unsqueeze(1).bool()) \
.view(-1, self.rnn_size).clone()
rout = self.rmoutput(rmfea.clone())
out = self.moutput(torch.cat([lmfea.clone(), rmfea.clone()], dim=1).clone())
lmfea = torch.masked_select(left_hid, mmask.unsqueeze(1).bool()) \
.view(-1, self.rnn_size)
lout = self.lmoutput(lmfea)
rmfea = torch.masked_select(right_hid, mmask.unsqueeze(1).bool()) \
.view(-1, self.rnn_size)
rout = self.rmoutput(rmfea)
out = self.moutput(torch.cat([lmfea, rmfea], dim=1))
mloss = (self.quesnet.me.loss(metas, out) + self.quesnet.me.loss(metas, lout) + self.quesnet.me.
loss(metas, rout)) * self.quesnet.lambda_input[2] / 3
mloss *= self.lambda_loss[0]
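Standalone illustration (not the PR code itself) of two small changes in make_batch above: out-of-vocabulary token ids are clamped to the last vocabulary index, and boundary placeholders become true zero vectors instead of the embedding of token id 0. The sizes below are arbitrary:

import torch
from torch import nn

vocab_size, emb_size = 10, 4
we = nn.Embedding(vocab_size, emb_size)

w = 23                                    # id outside the vocabulary
if w >= vocab_size:
    w = vocab_size - 1                    # clamp instead of indexing out of range
word_emb = we(torch.tensor([w]))          # shape (1, emb_size)

pad = torch.zeros(1, emb_size)            # zero placeholder, independent of token id 0
seq = torch.cat([pad, word_emb, pad], dim=0)
print(seq.shape)                          # torch.Size([3, 4])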
3 changes: 2 additions & 1 deletion EduNLP/Pretrain/__init__.py
@@ -1,8 +1,9 @@
# coding: utf-8
# 2021/5/29 @ tongshiwei

from .gensim_vec import train_vector, GensimWordTokenizer, GensimSegTokenizer
from .gensim_vec import pretrain_vector, GensimWordTokenizer, GensimSegTokenizer
from .elmo_vec import *
from .auto_vec import *
from .bert_vec import *
from .quesnet_vec import QuesNetTokenizer, pretrain_quesnet, Question
from .disenqnet_vec import *
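With the rename above, callers import the Gensim entry point under its new name; only the re-export changes in this file:

from EduNLP.Pretrain import pretrain_vector, GensimWordTokenizer, GensimSegTokenizer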