Merge pull request #16 from bigdata-ustc/d2v
[FEATURE] Upgrade SIF and enable end2end vectorization
Showing 34 changed files with 1,693 additions and 31 deletions.
@@ -0,0 +1,5 @@
# coding: utf-8
# 2021/8/1 @ tongshiwei

from .i2v import I2V, get_pretrained_i2v
from .i2v import D2V
@@ -0,0 +1,109 @@
# coding: utf-8
# 2021/8/1 @ tongshiwei

import json
from EduNLP.constant import MODEL_DIR
from ..Vector import T2V, get_pretrained_t2v as get_t2v_pretrained_model
from ..Tokenizer import Tokenizer, get_tokenizer
from EduNLP import logger

__all__ = ["I2V", "D2V", "get_pretrained_i2v"]

class I2V(object):
    def __init__(self, tokenizer, t2v, *args, tokenizer_kwargs: dict = None, pretrained_t2v=False, **kwargs):
        """
        Parameters
        ----------
        tokenizer: str
            the tokenizer name
        t2v: str
            the name of token2vector model
        args:
            the parameters passed to t2v
        tokenizer_kwargs: dict
            the parameters passed to tokenizer
        pretrained_t2v: bool
            whether to load a pretrained t2v model instead of constructing a new one
        kwargs:
            the parameters passed to t2v
        """
        self.tokenizer: Tokenizer = get_tokenizer(tokenizer, **tokenizer_kwargs if tokenizer_kwargs is not None else {})
        if pretrained_t2v:
            logger.info("Use pretrained t2v model %s" % t2v)
            self.t2v = get_t2v_pretrained_model(t2v, kwargs.get("model_dir", MODEL_DIR))
        else:
            self.t2v = T2V(t2v, *args, **kwargs)
        self.params = {
            "tokenizer": tokenizer,
            "tokenizer_kwargs": tokenizer_kwargs,
            "t2v": t2v,
            "args": args,
            "kwargs": kwargs,
            "pretrained_t2v": pretrained_t2v
        }

    def __call__(self, items, *args, **kwargs):
        return self.infer_vector(items, *args, **kwargs)

    def tokenize(self, items, indexing=True, padding=False, *args, **kwargs) -> list:
        return self.tokenizer(items, *args, **kwargs)

    def infer_vector(self, items, tokenize=True, indexing=False, padding=False, *args, **kwargs) -> tuple:
        raise NotImplementedError

    def infer_item_vector(self, tokens, *args, **kwargs) -> ...:
        return self.infer_vector(tokens, *args, **kwargs)[0]

    def infer_token_vector(self, tokens, *args, **kwargs) -> ...:
        return self.infer_vector(tokens, *args, **kwargs)[1]

    def save(self, config_path, *args, **kwargs):
        with open(config_path, "w", encoding="utf-8") as wf:
            json.dump(self.params, wf, ensure_ascii=False, indent=2)

    @classmethod
    def load(cls, config_path, *args, **kwargs):
        with open(config_path, encoding="utf-8") as f:
            params: dict = json.load(f)
            tokenizer = params.pop("tokenizer")
            t2v = params.pop("t2v")
            args = params.pop("args")
            kwargs = params.pop("kwargs")
            params.update(kwargs)
            return cls(tokenizer, t2v, *args, **params)

    @classmethod
    def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
        raise NotImplementedError

    @property
    def vector_size(self):
        return self.t2v.vector_size


class D2V(I2V):
    def infer_vector(self, items, tokenize=True, indexing=False, padding=False, *args, **kwargs) -> tuple:
        tokens = self.tokenize(items, return_token=True) if tokenize is True else items
        return self.t2v(tokens, *args, **kwargs), None

    @classmethod
    def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
        return cls("text", name, pretrained_t2v=True, model_dir=model_dir)


MODELS = {
    "d2v_all_256": [D2V, "d2v_all_256"],
    "d2v_sci_256": [D2V, "d2v_sci_256"],
    "d2v_eng_256": [D2V, "d2v_eng_256"],
    "d2v_lit_256": [D2V, "d2v_lit_256"],
}


def get_pretrained_i2v(name, model_dir=MODEL_DIR):
    if name not in MODELS:
        raise KeyError(
            "Unknown model name %s, use one of the provided models: %s" % (name, ", ".join(MODELS.keys()))
        )
    _class, *params = MODELS[name]
    return _class.from_pretrained(*params, model_dir=model_dir)
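The file above defines the generic I2V ("item to vector") interface plus the D2V subclass that pairs the "text" tokenizer with a doc2vec-style T2V backend. A minimal usage sketch follows, assuming the package above is exposed as EduNLP.I2V and that one of the listed pretrained models ("d2v_sci_256") can be fetched into a local model directory; the item text and directory path are illustrative:

# Hedged sketch: end-to-end item-to-vector with the new D2V interface.
from EduNLP.I2V import get_pretrained_i2v

# load (or download) the pretrained model into ./model (path is an assumption)
i2v = get_pretrained_i2v("d2v_sci_256", model_dir="./model")
item_vectors, token_vectors = i2v(["If f(x) = x^2, compute f(2)."])
print(len(item_vectors), i2v.vector_size)  # D2V returns item vectors; token_vectors is None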
@@ -0,0 +1,4 @@
# coding: utf-8
# 2021/7/12 @ tongshiwei

from .utils import *
@@ -0,0 +1,4 @@
# coding: utf-8
# 2021/7/12 @ tongshiwei

from .rnn import LM
@@ -0,0 +1,74 @@
# coding: utf-8
# 2021/7/12 @ tongshiwei

import torch
from torch import nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence


class LM(nn.Module):
    """
    Examples
    --------
    >>> import torch
    >>> seq_idx = torch.LongTensor([[1, 2, 3], [1, 2, 0], [3, 0, 0]])
    >>> seq_len = torch.LongTensor([3, 2, 1])
    >>> lm = LM("RNN", 4, 3, 2)
    >>> output, hn = lm(seq_idx, seq_len)
    >>> output.shape
    torch.Size([3, 3, 2])
    >>> hn.shape
    torch.Size([1, 3, 2])
    >>> lm = LM("RNN", 4, 3, 2, num_layers=2)
    >>> output, hn = lm(seq_idx, seq_len)
    >>> output.shape
    torch.Size([3, 3, 2])
    >>> hn.shape
    torch.Size([2, 3, 2])
    """

    def __init__(self, rnn_type: str, vocab_size: int, embedding_dim: int, hidden_size: int, num_layers=1,
                 bidirectional=False, embedding=None, **kwargs):
        super(LM, self).__init__()
        rnn_type = rnn_type.upper()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim) if embedding is None else embedding
        self.c = False
        if rnn_type == "RNN":
            self.rnn = torch.nn.RNN(
                embedding_dim, hidden_size, num_layers, bidirectional=bidirectional, **kwargs
            )
        elif rnn_type == "LSTM":
            self.rnn = torch.nn.LSTM(
                embedding_dim, hidden_size, num_layers, bidirectional=bidirectional, **kwargs
            )
            self.c = True
        elif rnn_type == "GRU":
            self.rnn = torch.nn.GRU(
                embedding_dim, hidden_size, num_layers, bidirectional=bidirectional, **kwargs
            )
        elif rnn_type == "ELMO":
            bidirectional = True
            self.rnn = torch.nn.LSTM(
                embedding_dim, hidden_size, num_layers, bidirectional=bidirectional, **kwargs
            )
            self.c = True
        else:
            raise TypeError("Unknown rnn_type %s" % rnn_type)

        self.num_layers = num_layers
        self.bidirectional = bidirectional
        if bidirectional is True:
            self.num_layers *= 2
        self.hidden_size = hidden_size

    def forward(self, seq_idx, seq_len):
        seq = self.embedding(seq_idx)
        # pack_padded_sequence is used with its default enforce_sorted=True,
        # so seq_len must be sorted in descending order
        pack = pack_padded_sequence(seq, seq_len, batch_first=True)
        # the initial hidden (and cell) state is randomly sampled on every call
        h0 = torch.randn(self.num_layers, seq.shape[0], self.hidden_size)
        if self.c is True:
            c0 = torch.randn(self.num_layers, seq.shape[0], self.hidden_size)
            output, (hn, _) = self.rnn(pack, (h0, c0))
        else:
            output, hn = self.rnn(pack, h0)
        output, _ = pad_packed_sequence(output, batch_first=True)
        return output, hn
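Because forward packs the batch with pack_padded_sequence's default enforce_sorted=True, callers must order each batch by decreasing length before calling LM. A small sketch of that preprocessing step (the toy tensors are illustrative, not from the PR):

import torch

# hypothetical unsorted batch: lengths 2, 3, 1
seq_idx = torch.LongTensor([[1, 2, 0], [1, 2, 3], [3, 0, 0]])
seq_len = torch.LongTensor([2, 3, 1])

order = torch.argsort(seq_len, descending=True)     # sort indices, longest sequence first
lm = LM("LSTM", vocab_size=4, embedding_dim=3, hidden_size=2)
output, hn = lm(seq_idx[order], seq_len[order])      # output: [3, 3, 2], hn: [1, 3, 2]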
@@ -0,0 +1,5 @@
# coding: utf-8
# 2021/7/12 @ tongshiwei

from .padder import PadSequence, pad_sequence
from .device import set_device
@@ -0,0 +1,44 @@
# coding: utf-8
# 2021/8/2 @ tongshiwei
import logging
import torch
from torch.nn import DataParallel


def set_device(_net, ctx, *args, **kwargs):  # pragma: no cover
    """code from longling v1.3.26"""
    if ctx == "cpu":
        if not isinstance(_net, DataParallel):
            _net = DataParallel(_net)
        return _net.cpu()
    elif any(map(lambda x: x in ctx, ["cuda", "gpu"])):
        if not torch.cuda.is_available():
            try:
                torch.ones((1,), device=torch.device("cuda:0"))
            except AssertionError as e:
                raise TypeError("no cuda detected, only cpu is supported, the detailed error msg: %s" % str(e))
        if torch.cuda.device_count() >= 1:
            if ":" in ctx:
                ctx_name, device_ids = ctx.split(":")
                assert ctx_name in ["cuda", "gpu"], "the equipment should be 'cpu', 'cuda' or 'gpu', now is %s" % ctx
                device_ids = [int(i) for i in device_ids.strip().split(",")]
                try:
                    if not isinstance(_net, DataParallel):
                        return DataParallel(_net, device_ids).cuda()
                    return _net.cuda(device_ids[0])
                except AssertionError as e:
                    logging.error(device_ids)
                    raise e
            elif ctx in ["cuda", "gpu"]:
                if not isinstance(_net, DataParallel):
                    _net = DataParallel(_net)
                return _net.cuda()
            else:
                raise TypeError("the equipment should be 'cpu', 'cuda' or 'gpu', now is %s" % ctx)
        else:
            logging.error(torch.cuda.device_count())
            raise TypeError("0 gpu can be used, use cpu")
    else:
        if not isinstance(_net, DataParallel):
            return DataParallel(_net, device_ids=ctx).cuda()
        return _net.cuda(ctx)
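A brief, hedged usage sketch for set_device; the wrapped module is a placeholder and the ctx strings simply mirror the branches above:

from torch import nn

net = nn.Linear(4, 2)                  # stand-in model, not part of the PR
net = set_device(net, "cpu")           # wrapped in DataParallel and kept on CPU
# net = set_device(net, "cuda:0")      # single GPU, if CUDA is available
# net = set_device(net, "cuda:0,1")    # DataParallel over GPUs 0 and 1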
@@ -0,0 +1,76 @@
# coding: utf-8
# 2021/7/12 @ tongshiwei

__all__ = ["PadSequence", "pad_sequence"]


class PadSequence(object):
    """Pad the sequence.

    Pad the sequence to the given `length` by inserting `pad_val`. If `clip` is set,
    a sequence longer than `length` will be clipped.

    Parameters
    ----------
    length : int
        The maximum length to pad/clip the sequence
    pad_val : number
        The pad value. Default 0
    clip : bool
        Whether to clip a sequence longer than `length`. Default True
    """

    def __init__(self, length, pad_val=0, clip=True):
        self._length = length
        self._pad_val = pad_val
        self._clip = clip

    def __call__(self, sample: list):
        """
        Parameters
        ----------
        sample : list of number

        Returns
        -------
        ret : list of number
        """
        sample_length = len(sample)
        if sample_length >= self._length:
            if self._clip and sample_length > self._length:
                return sample[:self._length]
            else:
                return sample
        else:
            return sample + [
                self._pad_val for _ in range(self._length - sample_length)
            ]


def pad_sequence(sequence: list, max_length=None, pad_val=0, clip=True):
    """
    Parameters
    ----------
    sequence : list of list of number
        The sequences to be padded
    max_length : int or None
        The target length; defaults to the length of the longest sequence
    pad_val : number
        The pad value. Default 0
    clip : bool
        Whether to clip sequences longer than `max_length`. Default True

    Returns
    -------
    list of list of number

    Examples
    --------
    >>> seq = [[4, 3, 3], [2], [3, 3, 2]]
    >>> pad_sequence(seq)
    [[4, 3, 3], [2, 0, 0], [3, 3, 2]]
    >>> pad_sequence(seq, pad_val=1)
    [[4, 3, 3], [2, 1, 1], [3, 3, 2]]
    >>> pad_sequence(seq, max_length=2)
    [[4, 3], [2, 0], [3, 3]]
    >>> pad_sequence(seq, max_length=2, clip=False)
    [[4, 3, 3], [2, 0], [3, 3, 2]]
    """
    padder = PadSequence(max([len(seq) for seq in sequence]) if max_length is None else max_length, pad_val, clip)
    return [padder(seq) for seq in sequence]
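The PadSequence class can also be used on its own when the target length is fixed up front, e.g. for padding single samples in a data pipeline; a small sketch with illustrative values:

padder = PadSequence(length=4, pad_val=0)
padder([5, 6])           # -> [5, 6, 0, 0]
padder([1, 2, 3, 4, 5])  # -> [1, 2, 3, 4] (clipped, since clip=True by default)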