Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

export ids during tokenisation #30

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 28 additions & 9 deletions colbert/modeling/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,30 +28,49 @@ def doc(self, *args, to_cpu=False, **kw_args):
D = self.colbert.doc(*args, **kw_args)
return D.cpu() if to_cpu else D

def queryFromText(self, queries, bsize=None, to_cpu=False, with_ids=False):
    """Tokenize *queries* and encode them with ``self.query``.

    Args:
        queries: Passed straight to ``self.query_tokenizer.tensorize``.
        bsize: When truthy, the tokenizer is asked for batches of this size
            and the per-batch results are concatenated with ``torch.cat``.
        to_cpu: Forwarded to ``self.query`` on both the batched and the
            un-batched path.
        with_ids: When true, also return the token ids and attention masks
            produced by the tokenizer.

    Returns:
        The query embeddings, or the tuple
        ``(embeddings, input_ids, attention_mask)`` when ``with_ids`` is true.
    """
    if bsize:
        # Materialise the batches: they are walked once for the embeddings
        # and, when with_ids is set, twice more for the ids and the masks.
        batches = list(self.query_tokenizer.tensorize(queries, bsize=bsize))
        embeddings = torch.cat([
            self.query(input_ids, attention_mask, to_cpu=to_cpu)
            for input_ids, attention_mask in batches
        ])
        if not with_ids:
            return embeddings

        all_ids = torch.cat([ids for ids, _ in batches])
        all_masks = torch.cat([masks for _, masks in batches])
        return embeddings, all_ids, all_masks

    input_ids, attention_mask = self.query_tokenizer.tensorize(queries)
    # Honour to_cpu here as well so both paths behave the same way.
    Q = self.query(input_ids, attention_mask, to_cpu=to_cpu)
    if with_ids:
        return Q, input_ids, attention_mask
    return Q

def docFromText(self, docs, bsize=None, keep_dims=True, to_cpu=False, with_ids=False):
    """Tokenize *docs* and encode them with ``self.doc``.

    Args:
        docs: Passed straight to ``self.doc_tokenizer.tensorize``.
        bsize: When truthy, the tokenizer is asked for batches of this size
            together with ``reverse_indices``, which is used to reorder the
            results (presumably back to the caller's document order --
            confirm against the tokenizer).
        keep_dims: Forwarded to ``self.doc``. On the batched path a true
            value yields one stacked tensor, a false value a flat list with
            one entry per document.
        to_cpu: Forwarded to ``self.doc`` on both the batched and the
            un-batched path.
        with_ids: When true, also return the token ids of the documents.

    Returns:
        The document embeddings, or ``(embeddings, ids)`` when ``with_ids``
        is true. On the batched path the ids are reordered with the same
        ``reverse_indices`` as the embeddings: a zero-padded 2-D tensor when
        ``keep_dims`` is true, otherwise a list of 1-D tensors holding only
        the attended, non-zero ids. On the un-batched path the ids are
        whatever the tokenizer returned.
    """
    if bsize:
        batch_ids, reverse_indices = self.doc_tokenizer.tensorize(docs, bsize=bsize)
        # Each batch is a 2-tuple: (ids of each document, masks of each document).
        # Materialise because the batches are walked more than once below.
        batch_ids = list(batch_ids)

        batches = [self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)
                   for input_ids, attention_mask in batch_ids]

        if keep_dims:
            D = _stack_3D_tensors(batches)[reverse_indices]
            if not with_ids:
                return D

            # The ids are 2-D (documents x tokens) and arrive as (ids, mask)
            # tuples, so they cannot go through the 3-D stacking helper.
            # Right-pad every batch with zeros to the widest batch, join
            # them, and apply the same reordering as the embeddings.
            widest = max(input_ids.size(1) for input_ids, _ in batch_ids)
            Dids = torch.cat([
                torch.nn.functional.pad(input_ids, (0, widest - input_ids.size(1)))
                for input_ids, _ in batch_ids
            ])
            return D, Dids[reverse_indices]

        D = [d for batch in batches for d in batch]
        order = reverse_indices.tolist()
        if with_ids:
            # The masking below assumes that args.mask_punctuation is false,
            # i.e. that no token is dropped because of the skiplist.
            assert len(self.colbert.skiplist) == 0, \
                "with_ids requires an empty skiplist (mask_punctuation off)"

            # Keep only the ids that are attended to and are not zero.
            D_i = [row[(mask > 0) & (row != 0)]
                   for input_ids, attention_mask in batch_ids
                   for row, mask in zip(input_ids, attention_mask)]

            return [D[idx] for idx in order], [D_i[idx] for idx in order]
        return [D[idx] for idx in order]

    input_ids, attention_mask = self.doc_tokenizer.tensorize(docs)
    # Honour to_cpu here as well so both paths behave the same way.
    D = self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)
    if with_ids:
        return D, input_ids
    return D

def score(self, Q, D, mask=None, lengths=None, explain=False):
Expand Down