From f3ab011bc9220b34733ec092c1feb5c3f968180a Mon Sep 17 00:00:00 2001 From: xm-gui Date: Wed, 11 May 2022 21:33:42 -0700 Subject: [PATCH 1/4] initial impl for lookup_codes --- .../indexing/codecs/residual_embeddings.py | 35 ++++++++++++------- colbert/search/index_storage.py | 1 + 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/colbert/indexing/codecs/residual_embeddings.py b/colbert/indexing/codecs/residual_embeddings.py index 270dfb99..36110faa 100644 --- a/colbert/indexing/codecs/residual_embeddings.py +++ b/colbert/indexing/codecs/residual_embeddings.py @@ -1,6 +1,7 @@ import os import torch import ujson +from collections import defaultdict from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided @@ -8,7 +9,7 @@ class ResidualEmbeddings: Strided = ResidualEmbeddingsStrided - def __init__(self, codes, residuals, mmap_index=False): + def __init__(self, codes, residuals, mmap_index=False, pid_to_chunk_metadata=None): """ Supply the already compressed residuals. """ @@ -17,6 +18,7 @@ def __init__(self, codes, residuals, mmap_index=False): if self.mmap_index: self.codes = codes self.residuals = residuals + self.pid_to_chunk_metadata = pid_to_chunk_metadata # assert isinstance(residuals, bitarray), type(residuals) assert codes.size(0) == residuals.size(0), (codes.size(), residuals.size()) @@ -45,20 +47,22 @@ def load_chunks(cls, index_path, chunk_idxs, num_embeddings, mmap_index=False): codes_offset = 0 pid_offset = 0 - per_chunk_doclens = {} - pid_to_chunk_idx = {} + pid_to_chunk_metadata = {} # pid -> [chunk id, passage doclen, passage offset in the chunk] for chunk_idx in chunk_idxs: with open(os.path.join(index_path, f'{chunk_idx}.metadata.json')) as f: metadata = ujson.load(f) + with open(os.path.join(index_path, f'{chunk_idx}.doclens.json')) as f: + chunk_doclens = ujson.load(f) + + pid_offset_in_chunk = 0 for pid in range(pid_offset, pid_offset + metadata["num_passages"]): - pid_to_chunk_idx[pid] = chunk_idx + pid_doclen = chunk_doclens[pid - pid_offset] + pid_to_chunk_metadata[pid] = [chunk_idx, pid_doclen, pid_offset_in_chunk] + pid_offset_in_chunk += pid_doclen pid_offset += metadata["num_passages"] - with open(os.path.join(index_path, f'{chunk_idx}.doclens.json')) as f: - per_chunk_doclens[chunk_idx] = ujson.load(f) - codes_endpos = codes_offset + metadata["num_embeddings"] chunk = cls.load(index_path, chunk_idx, codes_offset, codes_endpos, packed_dim, mmap_index) @@ -77,7 +81,7 @@ def load_chunks(cls, index_path, chunk_idxs, num_embeddings, mmap_index=False): # codes, residuals = codes.cuda(), residuals.cuda() # FIXME: REMOVE THIS LINE! - return cls(codes, residuals) + return cls(codes, residuals, mmap_index=mmap_index, pid_to_chunk_metadata=pid_to_chunk_metadata) @classmethod def load(cls, index_path, chunk_idx, offset, endpos, packed_dim, mmap_index=False): @@ -119,17 +123,22 @@ def save(self, path_prefix): def lookup_codes(self, pids): assert self.mmap_index - codes = torch.zeros((sum(self.doclens[pid] for pid in pids])) + codes = torch.zeros(sum([self.pid_to_chunk_metadata[pid][1] for pid in pids])) + pids_per_chunk = defaultdict(list) for pid in pids: - chunk_idx = self.pid_to_chunk_idx[pid] + chunk_idx = self.pid_to_chunk_metadata[pid][0] pids_per_chunk[chunk_idx].append(pid) + offset = 0 - for chunk_idx in sorted(chunks.keys()): + for chunk_idx in sorted(pids_per_chunk.keys()): pids_ = pids_per_chunk[chunk_idx] for pid in pids_: - codes[offset:offset + self.doclens[pid]] = self.codes[chunk_idx][self.chunk_offsets[pid]:self.chunk_offsets[pid] + doclens[pid]] - offset += doclens[pid] + pid_doclen = self.pid_to_chunk_metadata[pid][1] + pid_offset_in_chunk = self.pid_to_chunk_metadata[pid][2] + codes[offset:offset + pid_doclen] = \ + self.codes[chunk_idx][pid_offset_in_chunk:pid_offset_in_chunk + pid_doclen] + offset += pid_doclen return codes diff --git a/colbert/search/index_storage.py b/colbert/search/index_storage.py index a433f1f3..c2b48e0c 100644 --- a/colbert/search/index_storage.py +++ b/colbert/search/index_storage.py @@ -4,6 +4,7 @@ from colbert.indexing.loaders import load_doclens from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided +from colbert.indexing.codecs import residual_embeddings from colbert.search.strided_tensor import StridedTensor from colbert.search.candidate_generation import CandidateGeneration From e6f8e52cd59450efa0fc512cee88cb77c01198bf Mon Sep 17 00:00:00 2001 From: xm-gui Date: Sat, 14 May 2022 13:19:43 -0700 Subject: [PATCH 2/4] initial impl for lookup pids --- .../indexing/codecs/residual_embeddings.py | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/colbert/indexing/codecs/residual_embeddings.py b/colbert/indexing/codecs/residual_embeddings.py index 36110faa..29b12bcd 100644 --- a/colbert/indexing/codecs/residual_embeddings.py +++ b/colbert/indexing/codecs/residual_embeddings.py @@ -144,7 +144,26 @@ def lookup_codes(self, pids): def lookup_pids(self, pids): assert self.mmap_index - pass + print(f"mei-test residuals shape {self.residuals.shape}") + packed_dim = self.residuals.shape[2] + residuals = torch.zeros(sum([self.pid_to_chunk_metadata[pid][1] for pid in pids]), packed_dim) + + pids_per_chunk = defaultdict(list) + for pid in pids: + chunk_idx = self.pid_to_chunk_metadata[pid][0] + pids_per_chunk[chunk_idx].append(pid) + + offset = 0 + for chunk_idx in sorted(pids_per_chunk.keys()): + pids_ = pids_per_chunk[chunk_idx] + for pid in pids_: + pid_doclen = self.pid_to_chunk_metadata[pid][1] + pid_offset_in_chunk = self.pid_to_chunk_metadata[pid][2] + residuals[offset:offset + pid_doclen, :packed_dim] = \ + self.residuals[chunk_idx][pid_offset_in_chunk:pid_offset_in_chunk + pid_doclen, :packed_dim] + offset += pid_doclen + + return residuals def __len__(self): return self.codes.size(0) From e56aa0bc5daeb496f968846abeee2f3f8f741edd Mon Sep 17 00:00:00 2001 From: xm-gui Date: Sun, 15 May 2022 14:09:24 -0700 Subject: [PATCH 3/4] fixes in progress --- .../indexing/codecs/residual_embeddings.py | 56 ++++++++++++------- 1 file changed, 37 insertions(+), 19 deletions(-) diff --git a/colbert/indexing/codecs/residual_embeddings.py b/colbert/indexing/codecs/residual_embeddings.py index 29b12bcd..c54c557a 100644 --- a/colbert/indexing/codecs/residual_embeddings.py +++ b/colbert/indexing/codecs/residual_embeddings.py @@ -1,9 +1,10 @@ import os import torch import ujson -from collections import defaultdict +from collections import defaultdict, namedtuple from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided +from colbert.utils.utils import print_message class ResidualEmbeddings: @@ -19,6 +20,7 @@ def __init__(self, codes, residuals, mmap_index=False, pid_to_chunk_metadata=Non self.codes = codes self.residuals = residuals self.pid_to_chunk_metadata = pid_to_chunk_metadata + return # assert isinstance(residuals, bitarray), type(residuals) assert codes.size(0) == residuals.size(0), (codes.size(), residuals.size()) @@ -47,19 +49,20 @@ def load_chunks(cls, index_path, chunk_idxs, num_embeddings, mmap_index=False): codes_offset = 0 pid_offset = 0 + ChunkMetadata = namedtuple('ChunkMetadata', 'chunk_id, passage_doclen, passage_offset') pid_to_chunk_metadata = {} # pid -> [chunk id, passage doclen, passage offset in the chunk] for chunk_idx in chunk_idxs: with open(os.path.join(index_path, f'{chunk_idx}.metadata.json')) as f: metadata = ujson.load(f) - with open(os.path.join(index_path, f'{chunk_idx}.doclens.json')) as f: + with open(os.path.join(index_path, f'doclens.{chunk_idx}.json')) as f: chunk_doclens = ujson.load(f) pid_offset_in_chunk = 0 for pid in range(pid_offset, pid_offset + metadata["num_passages"]): pid_doclen = chunk_doclens[pid - pid_offset] - pid_to_chunk_metadata[pid] = [chunk_idx, pid_doclen, pid_offset_in_chunk] + pid_to_chunk_metadata[pid] = ChunkMetadata(chunk_idx, pid_doclen, pid_offset_in_chunk) pid_offset_in_chunk += pid_doclen pid_offset += metadata["num_passages"] @@ -91,7 +94,7 @@ def load(cls, index_path, chunk_idx, offset, endpos, packed_dim, mmap_index=Fals return cls(codes, residuals) @classmethod - def load_codes(self, index_path, chunk_idx, offset, endpos, packed_dim, mmap_index=False): + def load_codes(self, index_path, chunk_idx, offset=None, endpos=None, packed_dim=None, mmap_index=False): codes_path = os.path.join(index_path, f'{chunk_idx}.codes.pt') if mmap_index: @@ -123,47 +126,62 @@ def save(self, path_prefix): def lookup_codes(self, pids): assert self.mmap_index - codes = torch.zeros(sum([self.pid_to_chunk_metadata[pid][1] for pid in pids])) + # prev_pid = 0 + # for pid in pids: + # if pid.item() < prev_pid: + # print_message("not in order") + # prev_pid = pid.item() pids_per_chunk = defaultdict(list) - for pid in pids: - chunk_idx = self.pid_to_chunk_metadata[pid][0] - pids_per_chunk[chunk_idx].append(pid) + codes_lengths = torch.zeros(len(pids)) + codes_size = 0 + for idx, pid in enumerate(pids): + # print_message(f"pid shape: {pid.shape}, {len(pid.shape)}") + pid_ = pid.item() + chunk_idx, pid_doclen, _ = self.pid_to_chunk_metadata[pid_] + pids_per_chunk[chunk_idx].append(pid_) + codes_lengths[idx] = pid_doclen + codes_size += pid_doclen + codes = torch.zeros(codes_size) offset = 0 for chunk_idx in sorted(pids_per_chunk.keys()): pids_ = pids_per_chunk[chunk_idx] for pid in pids_: - pid_doclen = self.pid_to_chunk_metadata[pid][1] - pid_offset_in_chunk = self.pid_to_chunk_metadata[pid][2] + _, pid_doclen, pid_offset_in_chunk = self.pid_to_chunk_metadata[pid] codes[offset:offset + pid_doclen] = \ self.codes[chunk_idx][pid_offset_in_chunk:pid_offset_in_chunk + pid_doclen] offset += pid_doclen - return codes + return codes, codes_lengths def lookup_pids(self, pids): assert self.mmap_index - print(f"mei-test residuals shape {self.residuals.shape}") + print_message(f"mei-test residuals shape {self.residuals.shape}") packed_dim = self.residuals.shape[2] - residuals = torch.zeros(sum([self.pid_to_chunk_metadata[pid][1] for pid in pids]), packed_dim) pids_per_chunk = defaultdict(list) - for pid in pids: - chunk_idx = self.pid_to_chunk_metadata[pid][0] - pids_per_chunk[chunk_idx].append(pid) + residuals_lengths = torch.zeros(len(pids)) + residuals_size = 0 + for idx, pid in enumerate(pids): + print_message(f"pid shape: {pid.shape}, {len(pid.shape)}") + pid_ = pid.item() + chunk_idx, pid_doclen, _ = self.pid_to_chunk_metadata[pid_] + pids_per_chunk[chunk_idx].append(pid_) + residuals_lengths[idx] = pid_doclen + residuals_size += pid_doclen + residuals = torch.zeros(residuals_size) offset = 0 for chunk_idx in sorted(pids_per_chunk.keys()): pids_ = pids_per_chunk[chunk_idx] for pid in pids_: - pid_doclen = self.pid_to_chunk_metadata[pid][1] - pid_offset_in_chunk = self.pid_to_chunk_metadata[pid][2] + _, pid_doclen, pid_offset_in_chunk = self.pid_to_chunk_metadata[pid] residuals[offset:offset + pid_doclen, :packed_dim] = \ self.residuals[chunk_idx][pid_offset_in_chunk:pid_offset_in_chunk + pid_doclen, :packed_dim] offset += pid_doclen - return residuals + return residuals, residuals_lengths def __len__(self): return self.codes.size(0) From 83bf7a46f66a072ebdb04b98c383af9a92935d7a Mon Sep 17 00:00:00 2001 From: xm-gui Date: Sat, 28 May 2022 17:12:55 -0700 Subject: [PATCH 4/4] changes for mmap --- colbert/indexing/codecs/residual.py | 3 +- .../indexing/codecs/residual_embeddings.py | 36 ++++++++++++------- colbert/indexing/collection_indexer.py | 2 +- colbert/indexing/index_saver.py | 3 +- colbert/search/index_storage.py | 2 +- 5 files changed, 29 insertions(+), 17 deletions(-) diff --git a/colbert/indexing/codecs/residual.py b/colbert/indexing/codecs/residual.py index 60fd2655..d2a0b648 100644 --- a/colbert/indexing/codecs/residual.py +++ b/colbert/indexing/codecs/residual.py @@ -40,6 +40,7 @@ class ResidualCodec: Embeddings = ResidualEmbeddings def __init__(self, config, centroids, avg_residual=None, bucket_cutoffs=None, bucket_weights=None): + self.mmap_index = config.mmap_index self.use_gpu = config.total_visible_gpus > 0 if self.use_gpu > 0: self.centroids = centroids.cuda().half() @@ -159,7 +160,7 @@ def compress(self, embs): codes = torch.cat(codes) residuals = torch.cat(residuals) - return ResidualCodec.Embeddings(codes, residuals) + return ResidualCodec.Embeddings(codes, residuals, mmap_index=self.mmap_index) def binarize(self, residuals): residuals = torch.bucketize(residuals.float(), self.bucket_cutoffs).to(dtype=torch.uint8) diff --git a/colbert/indexing/codecs/residual_embeddings.py b/colbert/indexing/codecs/residual_embeddings.py index c54c557a..575db1fb 100644 --- a/colbert/indexing/codecs/residual_embeddings.py +++ b/colbert/indexing/codecs/residual_embeddings.py @@ -83,6 +83,7 @@ def load_chunks(cls, index_path, chunk_idxs, num_embeddings, mmap_index=False): codes_offset = codes_endpos # codes, residuals = codes.cuda(), residuals.cuda() # FIXME: REMOVE THIS LINE! + print(f"code is {codes}") return cls(codes, residuals, mmap_index=mmap_index, pid_to_chunk_metadata=pid_to_chunk_metadata) @@ -116,12 +117,26 @@ def load_residuals(self, index_path, chunk_idx, offset, endpos, packed_dim, mmap return torch.load(residuals_path, map_location='cpu') - def save(self, path_prefix): + def save(self, index_path, chunk_idx): + path_prefix = os.path.join(index_path, str(chunk_idx)) codes_path = f'{path_prefix}.codes.pt' residuals_path = f'{path_prefix}.residuals.pt' # f'{path_prefix}.residuals.bn' - torch.save(self.codes, codes_path) - torch.save(self.residuals, residuals_path) + print(f"saving code {self.codes}, {self.codes.shape[0]}") + if self.mmap_index: + print("using mmap") + codes_size = self.codes.shape[0] + storage = torch.IntStorage.from_file(codes_path, True, codes_size) + torch.IntTensor(storage).copy_(self.codes) + + dim, nbits = get_dim_and_nbits(index_path) + packed_dim = dim // 8 * nbits + residuals_size = codes_size * packed_dim + storage = torch.ByteStorage.from_file(residuals_path, True, residuals_size) + torch.ByteTensor(storage).copy_(self.residuals) + else: + torch.save(self.codes, codes_path) + torch.save(self.residuals, residuals_path) # _save_bitarray(self.residuals, residuals_path) def lookup_codes(self, pids): @@ -136,13 +151,12 @@ def lookup_codes(self, pids): codes_lengths = torch.zeros(len(pids)) codes_size = 0 for idx, pid in enumerate(pids): - # print_message(f"pid shape: {pid.shape}, {len(pid.shape)}") pid_ = pid.item() chunk_idx, pid_doclen, _ = self.pid_to_chunk_metadata[pid_] pids_per_chunk[chunk_idx].append(pid_) codes_lengths[idx] = pid_doclen codes_size += pid_doclen - codes = torch.zeros(codes_size) + codes = torch.zeros(codes_size, dtype=torch.int32) offset = 0 for chunk_idx in sorted(pids_per_chunk.keys()): @@ -153,32 +167,30 @@ def lookup_codes(self, pids): self.codes[chunk_idx][pid_offset_in_chunk:pid_offset_in_chunk + pid_doclen] offset += pid_doclen - return codes, codes_lengths + return codes, codes_lengths.long() def lookup_pids(self, pids): assert self.mmap_index - print_message(f"mei-test residuals shape {self.residuals.shape}") - packed_dim = self.residuals.shape[2] + packed_dim = self.residuals[0].shape[1] pids_per_chunk = defaultdict(list) residuals_lengths = torch.zeros(len(pids)) residuals_size = 0 for idx, pid in enumerate(pids): - print_message(f"pid shape: {pid.shape}, {len(pid.shape)}") pid_ = pid.item() chunk_idx, pid_doclen, _ = self.pid_to_chunk_metadata[pid_] pids_per_chunk[chunk_idx].append(pid_) residuals_lengths[idx] = pid_doclen residuals_size += pid_doclen - residuals = torch.zeros(residuals_size) + residuals = torch.zeros(residuals_size, packed_dim, dtype=torch.uint8) offset = 0 for chunk_idx in sorted(pids_per_chunk.keys()): pids_ = pids_per_chunk[chunk_idx] for pid in pids_: _, pid_doclen, pid_offset_in_chunk = self.pid_to_chunk_metadata[pid] - residuals[offset:offset + pid_doclen, :packed_dim] = \ - self.residuals[chunk_idx][pid_offset_in_chunk:pid_offset_in_chunk + pid_doclen, :packed_dim] + residuals[offset:offset + pid_doclen] = \ + self.residuals[chunk_idx][pid_offset_in_chunk:pid_offset_in_chunk + pid_doclen] offset += pid_doclen return residuals, residuals_lengths diff --git a/colbert/indexing/collection_indexer.py b/colbert/indexing/collection_indexer.py index a78721c0..dcacec78 100644 --- a/colbert/indexing/collection_indexer.py +++ b/colbert/indexing/collection_indexer.py @@ -314,7 +314,7 @@ def _build_ivf(self): for chunk_idx in range(self.num_chunks): offset = self.embedding_offsets[chunk_idx] - chunk_codes = ResidualCodec.Embeddings.load_codes(self.config.index_path_, chunk_idx) + chunk_codes = ResidualCodec.Embeddings.load_codes(self.config.index_path_, chunk_idx, mmap_index=self.config.mmap_index) codes[offset:offset+chunk_codes.size(0)] = chunk_codes diff --git a/colbert/indexing/index_saver.py b/colbert/indexing/index_saver.py index 4899606e..f5574eff 100644 --- a/colbert/indexing/index_saver.py +++ b/colbert/indexing/index_saver.py @@ -48,8 +48,7 @@ def _saver_thread(self): self._write_chunk_to_disk(*args) def _write_chunk_to_disk(self, chunk_idx, offset, compressed_embs, doclens): - path_prefix = os.path.join(self.config.index_path_, str(chunk_idx)) - compressed_embs.save(path_prefix) + compressed_embs.save(self.config.index_path_, chunk_idx) doclens_path = os.path.join(self.config.index_path_, f'doclens.{chunk_idx}.json') with open(doclens_path, 'w') as output_doclens: diff --git a/colbert/search/index_storage.py b/colbert/search/index_storage.py index c2b48e0c..ae1f786e 100644 --- a/colbert/search/index_storage.py +++ b/colbert/search/index_storage.py @@ -125,7 +125,7 @@ def score_pids(self, config, Q, pids, centroid_scores): pids = pids[torch.topk(approx_scores, k=config.ndocs).indices] # Filter docs using full centroid scores - codes_packed, codes_lengths = self.lookup_codes(pids_) + codes_packed, codes_lengths = self.lookup_codes(pids) approx_scores = centroid_scores[codes_packed.long()] approx_scores_strided = StridedTensor(approx_scores, codes_lengths, use_gpu=self.use_gpu) approx_scores_padded, approx_scores_mask = approx_scores_strided.as_padded_tensor()