diff --git a/colbert/indexing/codecs/residual.py b/colbert/indexing/codecs/residual.py
index 60fd2655..d2a0b648 100644
--- a/colbert/indexing/codecs/residual.py
+++ b/colbert/indexing/codecs/residual.py
@@ -40,6 +40,7 @@ class ResidualCodec:
     Embeddings = ResidualEmbeddings
 
     def __init__(self, config, centroids, avg_residual=None, bucket_cutoffs=None, bucket_weights=None):
+        self.mmap_index = config.mmap_index
         self.use_gpu = config.total_visible_gpus > 0
         if self.use_gpu > 0:
             self.centroids = centroids.cuda().half()
@@ -159,7 +160,7 @@ def compress(self, embs):
         codes = torch.cat(codes)
         residuals = torch.cat(residuals)
 
-        return ResidualCodec.Embeddings(codes, residuals)
+        return ResidualCodec.Embeddings(codes, residuals, mmap_index=self.mmap_index)
 
     def binarize(self, residuals):
         residuals = torch.bucketize(residuals.float(), self.bucket_cutoffs).to(dtype=torch.uint8)
diff --git a/colbert/indexing/codecs/residual_embeddings.py b/colbert/indexing/codecs/residual_embeddings.py
index 270dfb99..575db1fb 100644
--- a/colbert/indexing/codecs/residual_embeddings.py
+++ b/colbert/indexing/codecs/residual_embeddings.py
@@ -1,14 +1,15 @@
 import os
 import torch
 import ujson
+from collections import defaultdict, namedtuple
 
 from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
 
 
 class ResidualEmbeddings:
     Strided = ResidualEmbeddingsStrided
 
-    def __init__(self, codes, residuals, mmap_index=False):
+    def __init__(self, codes, residuals, mmap_index=False, pid_to_chunk_metadata=None):
         """
         Supply the already compressed residuals.
         """
@@ -17,6 +18,8 @@ def __init__(self, codes, residuals, mmap_index=False):
         if self.mmap_index:
             self.codes = codes
             self.residuals = residuals
+            self.pid_to_chunk_metadata = pid_to_chunk_metadata
+            return
 
         # assert isinstance(residuals, bitarray), type(residuals)
         assert codes.size(0) == residuals.size(0), (codes.size(), residuals.size())
@@ -45,20 +48,23 @@ def load_chunks(cls, index_path, chunk_idxs, num_embeddings, mmap_index=False):
         codes_offset = 0
         pid_offset = 0
 
-        per_chunk_doclens = {}
-        pid_to_chunk_idx = {}
+        ChunkMetadata = namedtuple('ChunkMetadata', 'chunk_id, passage_doclen, passage_offset')
+        pid_to_chunk_metadata = {}  # pid -> [chunk id, passage doclen, passage offset in the chunk]
 
         for chunk_idx in chunk_idxs:
             with open(os.path.join(index_path, f'{chunk_idx}.metadata.json')) as f:
                 metadata = ujson.load(f)
 
+            with open(os.path.join(index_path, f'doclens.{chunk_idx}.json')) as f:
+                chunk_doclens = ujson.load(f)
+
+            pid_offset_in_chunk = 0
             for pid in range(pid_offset, pid_offset + metadata["num_passages"]):
-                pid_to_chunk_idx[pid] = chunk_idx
+                pid_doclen = chunk_doclens[pid - pid_offset]
+                pid_to_chunk_metadata[pid] = ChunkMetadata(chunk_idx, pid_doclen, pid_offset_in_chunk)
+                pid_offset_in_chunk += pid_doclen
             pid_offset += metadata["num_passages"]
 
-            with open(os.path.join(index_path, f'{chunk_idx}.doclens.json')) as f:
-                per_chunk_doclens[chunk_idx] = ujson.load(f)
-
             codes_endpos = codes_offset + metadata["num_embeddings"]
 
             chunk = cls.load(index_path, chunk_idx, codes_offset, codes_endpos, packed_dim, mmap_index)
@@ -76,8 +82,8 @@ def load_chunks(cls, index_path, chunk_idxs, num_embeddings, mmap_index=False):
             codes_offset = codes_endpos
 
         # codes, residuals = codes.cuda(), residuals.cuda() # FIXME: REMOVE THIS LINE!
 
-        return cls(codes, residuals)
+        return cls(codes, residuals, mmap_index=mmap_index, pid_to_chunk_metadata=pid_to_chunk_metadata)
 
     @classmethod
     def load(cls, index_path, chunk_idx, offset, endpos, packed_dim, mmap_index=False):
@@ -87,7 +93,7 @@ def load(cls, index_path, chunk_idx, offset, endpos, packed_dim, mmap_index=Fals
         return cls(codes, residuals)
 
     @classmethod
-    def load_codes(self, index_path, chunk_idx, offset, endpos, packed_dim, mmap_index=False):
+    def load_codes(self, index_path, chunk_idx, offset=None, endpos=None, packed_dim=None, mmap_index=False):
         codes_path = os.path.join(index_path, f'{chunk_idx}.codes.pt')
 
         if mmap_index:
@@ -109,33 +115,75 @@ def load_residuals(self, index_path, chunk_idx, offset, endpos, packed_dim, mmap
         return torch.load(residuals_path, map_location='cpu')
 
-    def save(self, path_prefix):
+    def save(self, index_path, chunk_idx):
+        path_prefix = os.path.join(index_path, str(chunk_idx))
         codes_path = f'{path_prefix}.codes.pt'
         residuals_path = f'{path_prefix}.residuals.pt'  # f'{path_prefix}.residuals.bn'
 
-        torch.save(self.codes, codes_path)
-        torch.save(self.residuals, residuals_path)
+        if self.mmap_index:
+            codes_size = self.codes.shape[0]
+            storage = torch.IntStorage.from_file(codes_path, True, codes_size)
+            torch.IntTensor(storage).copy_(self.codes)
+
+            dim, nbits = get_dim_and_nbits(index_path)
+            packed_dim = dim // 8 * nbits
+            residuals_size = codes_size * packed_dim
+            storage = torch.ByteStorage.from_file(residuals_path, True, residuals_size)
+            torch.ByteTensor(storage).copy_(self.residuals)
+        else:
+            torch.save(self.codes, codes_path)
+            torch.save(self.residuals, residuals_path)
         # _save_bitarray(self.residuals, residuals_path)
 
     def lookup_codes(self, pids):
         assert self.mmap_index
 
-        codes = torch.zeros((sum(self.doclens[pid] for pid in pids]))
         pids_per_chunk = defaultdict(list)
+        codes_lengths = torch.zeros(len(pids))
+        codes_size = 0
+        for idx, pid in enumerate(pids):
+            pid_ = pid.item()
+            chunk_idx, pid_doclen, _ = self.pid_to_chunk_metadata[pid_]
+            pids_per_chunk[chunk_idx].append(pid_)
+            codes_lengths[idx] = pid_doclen
+            codes_size += pid_doclen
-        for pid in pids:
-            chunk_idx = self.pid_to_chunk_idx[pid]
-            pids_per_chunk[chunk_idx].append(pid)
+        codes = torch.zeros(codes_size, dtype=torch.int32)
+
         offset = 0
-        for chunk_idx in sorted(chunks.keys()):
+        for chunk_idx in sorted(pids_per_chunk.keys()):
             pids_ = pids_per_chunk[chunk_idx]
             for pid in pids_:
-                codes[offset:offset + self.doclens[pid]] = self.codes[chunk_idx][self.chunk_offsets[pid]:self.chunk_offsets[pid] + doclens[pid]]
-                offset += doclens[pid]
+                _, pid_doclen, pid_offset_in_chunk = self.pid_to_chunk_metadata[pid]
+                codes[offset:offset + pid_doclen] = \
+                    self.codes[chunk_idx][pid_offset_in_chunk:pid_offset_in_chunk + pid_doclen]
+                offset += pid_doclen
 
-        return codes
+        return codes, codes_lengths.long()
 
     def lookup_pids(self, pids):
         assert self.mmap_index
-        pass
+
+        packed_dim = self.residuals[0].shape[1]
+        pids_per_chunk = defaultdict(list)
+        residuals_lengths = torch.zeros(len(pids))
+        residuals_size = 0
+        for idx, pid in enumerate(pids):
+            pid_ = pid.item()
+            chunk_idx, pid_doclen, _ = self.pid_to_chunk_metadata[pid_]
+            pids_per_chunk[chunk_idx].append(pid_)
+            residuals_lengths[idx] = pid_doclen
+            residuals_size += pid_doclen
+        residuals = torch.zeros(residuals_size, packed_dim, dtype=torch.uint8)
+
+        offset = 0
+        for chunk_idx in sorted(pids_per_chunk.keys()):
+            pids_ = pids_per_chunk[chunk_idx]
+            for pid in pids_:
+                _, pid_doclen, pid_offset_in_chunk = self.pid_to_chunk_metadata[pid]
+                residuals[offset:offset + pid_doclen] = \
+                    self.residuals[chunk_idx][pid_offset_in_chunk:pid_offset_in_chunk + pid_doclen]
+                offset += pid_doclen
+
+        return residuals, residuals_lengths.long()
 
     def __len__(self):
         return self.codes.size(0)
diff --git a/colbert/indexing/collection_indexer.py b/colbert/indexing/collection_indexer.py
index a78721c0..dcacec78 100644
--- a/colbert/indexing/collection_indexer.py
+++ b/colbert/indexing/collection_indexer.py
@@ -314,7 +314,7 @@ def _build_ivf(self):
 
         for chunk_idx in range(self.num_chunks):
             offset = self.embedding_offsets[chunk_idx]
-            chunk_codes = ResidualCodec.Embeddings.load_codes(self.config.index_path_, chunk_idx)
+            chunk_codes = ResidualCodec.Embeddings.load_codes(self.config.index_path_, chunk_idx, mmap_index=self.config.mmap_index)
 
             codes[offset:offset+chunk_codes.size(0)] = chunk_codes
 
diff --git a/colbert/indexing/index_saver.py b/colbert/indexing/index_saver.py
index 4899606e..f5574eff 100644
--- a/colbert/indexing/index_saver.py
+++ b/colbert/indexing/index_saver.py
@@ -48,8 +48,7 @@ def _saver_thread(self):
             self._write_chunk_to_disk(*args)
 
     def _write_chunk_to_disk(self, chunk_idx, offset, compressed_embs, doclens):
-        path_prefix = os.path.join(self.config.index_path_, str(chunk_idx))
-        compressed_embs.save(path_prefix)
+        compressed_embs.save(self.config.index_path_, chunk_idx)
 
         doclens_path = os.path.join(self.config.index_path_, f'doclens.{chunk_idx}.json')
         with open(doclens_path, 'w') as output_doclens:
diff --git a/colbert/search/index_storage.py b/colbert/search/index_storage.py
index a433f1f3..ae1f786e 100644
--- a/colbert/search/index_storage.py
+++ b/colbert/search/index_storage.py
@@ -4,6 +4,7 @@
 from colbert.indexing.loaders import load_doclens
 
 from colbert.indexing.codecs.residual_embeddings_strided import ResidualEmbeddingsStrided
+from colbert.indexing.codecs import residual_embeddings
 from colbert.search.strided_tensor import StridedTensor
 
 from colbert.search.candidate_generation import CandidateGeneration
@@ -124,7 +125,7 @@ def score_pids(self, config, Q, pids, centroid_scores):
             pids = pids[torch.topk(approx_scores, k=config.ndocs).indices]
 
         # Filter docs using full centroid scores
-        codes_packed, codes_lengths = self.lookup_codes(pids_)
+        codes_packed, codes_lengths = self.lookup_codes(pids)
         approx_scores = centroid_scores[codes_packed.long()]
         approx_scores_strided = StridedTensor(approx_scores, codes_lengths, use_gpu=self.use_gpu)
         approx_scores_padded, approx_scores_mask = approx_scores_strided.as_padded_tensor()