Skip to content

Commit

Permalink
Merge pull request #14 from codetalker7/scorer
Browse files Browse the repository at this point in the history
The `Searcher` component.
  • Loading branch information
codetalker7 authored Jul 28, 2024
2 parents 63755bb + 7b5bd41 commit 45b943b
Show file tree
Hide file tree
Showing 8 changed files with 486 additions and 3 deletions.
1 change: 1 addition & 0 deletions examples/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@ config = ColBERTConfig(
# create and run the indexer
indexer = Indexer(config)
index(indexer)
# persist the config after indexing so a Searcher can later be built from the
# index — NOTE(review): presumably writes into the configured index path; confirm against ColBERT.save
ColBERT.save(config)
25 changes: 25 additions & 0 deletions examples/searching.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using ColBERT

# create the config
# locate the collection the index was built from
dataroot = "downloads/lotte"
dataset = "lifestyle"
datasplit = "dev"
path = joinpath(dataroot, dataset, datasplit, "short_collection.tsv")   # NOTE(review): unused below — presumably the source collection of the index; confirm

nbits = 2 # encode each dimension with 2 bits

# the index directory follows the naming scheme used at indexing time
index_root = "experiments/notebook/indexes"
index_name = "short_$(dataset).$(datasplit).$(nbits)bits"
index_path = joinpath(index_root, index_name)

# build the searcher
searcher = Searcher(index_path)

# search for a query
query = "what are white spots on raspberries?"
pids, scores = search(searcher, query, 2)   # top-2 passage ids and their scores
print(searcher.config.resource_settings.collection.data[pids])  # show the matching passages

query = "are rabbits easy to housebreak?"
pids, scores = search(searcher, query, 9)   # top-9 results for a second query
print(searcher.config.resource_settings.collection.data[pids])
6 changes: 6 additions & 0 deletions src/ColBERT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,10 @@ include("indexing/index_saver.jl")
include("indexing/collection_indexer.jl")
export Indexer, CollectionIndexer, index

# searcher
include("search/strided_tensor.jl")
include("search/index_storage.jl")
include("searching.jl")
export Searcher, search

end
80 changes: 80 additions & 0 deletions src/indexing/codecs/residual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ mutable struct ResidualCodec
bucket_weights::Vector{Float64}
end

"""
# Examples
```julia-repl
julia> codec = load_codec(index_path);
```
"""
function load_codec(index_path::String)
config = load(joinpath(index_path, "config.jld2"), "config")
centroids = load(joinpath(index_path, "centroids.jld2"), "centroids")
avg_residual = load(joinpath(index_path, "avg_residual.jld2"), "avg_residual")
buckets = load(joinpath(index_path, "buckets.jld2"))
ResidualCodec(config, centroids, avg_residual, buckets["bucket_cutoffs"], buckets["bucket_weights"])
end

"""
compress_into_codes(codec::ResidualCodec, embs::Matrix{Float64})
Expand Down Expand Up @@ -130,6 +146,64 @@ function compress(codec::ResidualCodec, embs::Matrix{Float64})
codes, residuals
end

"""
    decompress_residuals(codec::ResidualCodec, binary_residuals::Array{UInt8})

Map bit-packed residuals back to real values via the codec's `bucket_weights`.
Expects a `((dim / 8) * nbits, num_embeddings)` byte matrix and returns a
`(dim, num_embeddings)` matrix; inverse of the binarization step.
"""
function decompress_residuals(codec::ResidualCodec, binary_residuals::Array{UInt8})
    dim = codec.config.doc_settings.dim
    nbits = codec.config.indexing_settings.nbits

    @assert ndims(binary_residuals) == 2
    @assert size(binary_residuals)[1] == (dim / 8) * nbits

    num_embeddings = size(binary_residuals)[2]

    # unpack every byte into its 8 bits, least-significant bit first
    bits = BitVector()
    for byte in vec(binary_residuals)
        for pos in 0:7
            push!(bits, byte & (0x1 << pos) != 0)
        end
    end

    # arrange as (nbits, dim, num_embeddings); inverse of what binarize does
    bits = reshape(bits, nbits, dim, num_embeddings)

    # the i-th bit of each nbits-wide group carries weight 2^(i - 1)
    weights = reshape([1 << (i - 1) for i in 1:nbits], nbits, 1, 1)

    # collapse each group back into its decimal bin; add 1 for 1-based indexing
    indices = dropdims(sum(bits .* weights, dims = 1), dims = 1) .+ 1

    # look up the real value associated with each bucket index
    codec.bucket_weights[indices]
end

"""
    decompress(codec::ResidualCodec, codes::Vector{Int}, residuals::Array{UInt8})

Reconstruct the embedding matrix from centroid `codes` and packed `residuals`.
Each embedding is its assigned centroid plus its decompressed residual,
re-normalized to unit length (all-zero columns are left untouched).
Returns a `(dim, length(codes))` matrix.
"""
function decompress(codec::ResidualCodec, codes::Vector{Int}, residuals::Array{UInt8})
    @assert ndims(codes) == 1
    @assert ndims(residuals) == 2
    @assert length(codes) == size(residuals)[2]

    # decompress in batches to bound peak memory
    D = Vector{Array{<:AbstractFloat}}()
    bsize = 1 << 15
    batch_offset = 1
    while batch_offset <= length(codes)
        # hoisted: the batch endpoint was computed twice in the original
        batch_end = min(batch_offset + bsize - 1, length(codes))
        batch_codes = codes[batch_offset:batch_end]
        batch_residuals = residuals[:, batch_offset:batch_end]

        # embedding = centroid of its code + its residual
        centroids_ = codec.centroids[:, batch_codes]
        residuals_ = decompress_residuals(codec, batch_residuals)

        batch_embeddings = centroids_ + residuals_
        # restore unit norm; skip zero columns, which `normalize` cannot handle
        batch_embeddings = mapslices(v -> iszero(v) ? v : normalize(v), batch_embeddings, dims = 1)
        push!(D, batch_embeddings)

        batch_offset += bsize
    end

    # `reduce(hcat, D)` instead of `cat(D..., dims = 2)`: avoids splatting an
    # unbounded number of batches into varargs
    reduce(hcat, D)
end

"""
load_codes(codec::ResidualCodec, chunk_idx::Int)
Expand All @@ -150,3 +224,9 @@ function load_codes(codec::ResidualCodec, chunk_idx::Int)
codes = JLD2.load(codes_path, "codes")
codes
end

"""
    load_residuals(codec::ResidualCodec, chunk_idx::Int)

Load the packed residuals of chunk `chunk_idx` from the index directory.
"""
function load_residuals(codec::ResidualCodec, chunk_idx::Int)
    path = joinpath(codec.config.indexing_settings.index_path, "$(chunk_idx).residuals.jld2")
    JLD2.load(path, "residuals")
end
5 changes: 2 additions & 3 deletions src/infra/settings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,6 @@ Base.@kwdef struct IndexingSettings
end

"""
Settings for the searching phase.

NOTE(review): this span comes from a rendered diff ("2 additions & 3
deletions") that does not mark which field lines survive the change —
verify the final field list against the repository.
"""
Base.@kwdef struct SearchSettings
    ncells::Union{Nothing, Int} = nothing
    centroid_score_threshold::Union{Nothing, Float64} = nothing
    ndocs::Union{Nothing, Int} = nothing
    nprobe::Int = 2          # nearest centroids probed per query embedding (see `retrieve`)
    ncandidates::Int = 8192
end
174 changes: 174 additions & 0 deletions src/search/index_storage.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""
Holds every artifact of a loaded index needed to retrieve and score passages.
"""
struct IndexScorer
    metadata::Dict              # index metadata parsed from metadata.json
    codec::ResidualCodec        # codec used to decompress embeddings
    ivf::Vector{Int}            # inverted file: embedding IDs, grouped centroid by centroid
    ivf_lengths::Vector{Int}    # number of embedding IDs per centroid segment of `ivf`
    doclens::Vector{Int}        # number of embeddings of each document
    codes::Vector{Int}          # centroid code of every embedding
    residuals::Matrix{UInt8}    # packed binary residual of every embedding
    emb2pid::Vector{Int}        # maps each embedding ID to its passage ID
end

"""
# Examples
```julia-repl
julia> IndexScorer(index_path)
```
"""
function IndexScorer(index_path::String)
@info "Loading the index from {index_path}."

# loading the config from the index path
config = JLD2.load(joinpath(index_path, "config.jld2"))["config"]

# the metadata
metadata_path = joinpath(index_path, "metadata.json")
metadata = JSON.parsefile(metadata_path)

# loading the codec
codec = load_codec(index_path)

# loading ivf into a StridedTensor
ivf_path = joinpath(index_path, "ivf.jld2")
ivf_dict = JLD2.load(ivf_path)
ivf, ivf_lengths = ivf_dict["ivf"], ivf_dict["ivf_lengths"]
# ivf = StridedTensor(ivf, ivf_lengths)

# loading all doclens
doclens = Vector{Int}()
for chunk_idx in 1:metadata["num_chunks"]
doclens_file = joinpath(index_path, "doclens.$(chunk_idx).jld2")
chunk_doclens = JLD2.load(doclens_file, "doclens")
append!(doclens, chunk_doclens)
end

# loading all embeddings
num_embeddings = metadata["num_embeddings"]
dim, nbits = config.doc_settings.dim, config.indexing_settings.nbits
@assert (dim * nbits) % 8 == 0
codes = zeros(Int, num_embeddings)
residuals = zeros(UInt8, Int((dim / 8) * nbits), num_embeddings)
codes_offset = 1
for chunk_idx in 1:metadata["num_chunks"]
chunk_codes = load_codes(codec, chunk_idx)
chunk_residuals = load_residuals(codec, chunk_idx)

codes_endpos = codes_offset + length(chunk_codes) - 1
codes[codes_offset:codes_endpos] = chunk_codes
residuals[:, codes_offset:codes_endpos] = chunk_residuals

codes_offset = codes_offset + length(chunk_codes)
end

# the emb2pid mapping
@info "Building the emb2pid mapping."
@assert isequal(sum(doclens), metadata["num_embeddings"])
emb2pid = zeros(Int, metadata["num_embeddings"])

offset_doclens = 1
for (pid, dlength) in enumerate(doclens)
emb2pid[offset_doclens:offset_doclens + dlength - 1] .= pid
offset_doclens += dlength
end

IndexScorer(
metadata,
codec,
ivf,
ivf_lengths,
doclens,
codes,
residuals,
emb2pid,
)
end

"""
Return a candidate set of `pids` for the query matrix `Q`. This is done as follows: the nearest `nprobe` centroids for each query embedding are found. This list is then flattened and the unique set of these centroids is built. Using the `ivf`, the list of all unique embedding IDs contained in these centroids is computed. Finally, these embedding IDs are converted to `pids` using `emb2pid`. This list of `pids` is the final candidate set.
"""
function retrieve(ranker::IndexScorer, config::ColBERTConfig, Q::Array{<:AbstractFloat})
@assert isequal(size(Q)[2], config.query_settings.query_maxlen) # Q: (128, 32, 1)

Q = reshape(Q, size(Q)[1:end .!= end]...) # squeeze out the last dimension
@assert isequal(length(size(Q)), 2)

# score of each query embedding with each centroid and take top nprobe centroids
cells = transpose(Q) * ranker.codec.centroids
cells = mapslices(row -> partialsortperm(row, 1:config.search_settings.nprobe, rev=true), cells, dims = 2) # take top nprobe centroids for each query
centroid_ids = sort(unique(vec(cells)))

# get all embedding IDs contained in centroid_ids using ivf
centroid_ivf_offsets = cat([1], 1 .+ cumsum(ranker.ivf_lengths)[1:end .!= end], dims = 1)
eids = Vector{Int}()
for centroid_id in centroid_ids
offset = centroid_ivf_offsets[centroid_id]
length = ranker.ivf_lengths[centroid_id]
append!(eids, ranker.ivf[offset:offset + length - 1])
end
@assert isequal(length(eids), sum(ranker.ivf_lengths[centroid_ids]))
eids = sort(unique(eids))

# get pids from the emb2pid mapping
pids = sort(unique(ranker.emb2pid[eids]))
pids
end

"""
- Get the decompressed embedding matrix for all embeddings in `pids`. Use `doclens` for this.
"""
function score_pids(ranker::IndexScorer, config::ColBERTConfig, Q::Array{<:AbstractFloat}, pids::Vector{Int})
# get codes and residuals for all embeddings across all pids
num_embs = sum(ranker.doclens[pids])
codes_packed = zeros(Int, num_embs)
residuals_packed = zeros(UInt8, size(ranker.residuals)[1], num_embs)
pid_offsets = cat([1], 1 .+ cumsum(ranker.doclens)[1:end .!= end], dims=1)

offset = 1
for pid in pids
pid_offset = pid_offsets[pid]
num_embs_pid = ranker.doclens[pid]
codes_packed[offset: offset + num_embs_pid - 1] = ranker.codes[pid_offset: pid_offset + num_embs_pid - 1]
residuals_packed[:, offset: offset + num_embs_pid - 1] = ranker.residuals[:, pid_offset: pid_offset + num_embs_pid - 1]
offset += num_embs_pid
end
@assert offset == num_embs + 1

# decompress these codes and residuals to get the original embeddings
D_packed = decompress(ranker.codec, codes_packed, residuals_packed)
@assert ndims(D_packed) == 2
@assert size(D_packed)[1] == config.doc_settings.dim
@assert size(D_packed)[2] == num_embs

# get the max-sim scores
if size(Q)[3] > 1
error("Only one query is supported at the moment!")
end
@assert size(Q)[3] == 1
Q = reshape(Q, size(Q)[1:2]...)

scores = Vector{Float64}()
query_doc_scores = transpose(Q) * D_packed # (num_query_tokens, num_embeddings)
offset = 1
for pid in pids
num_embs_pid = ranker.doclens[pid]
pid_scores = query_doc_scores[:, offset:min(num_embs, offset + num_embs_pid - 1)]
push!(scores, sum(maximum(pid_scores, dims = 2)))

offset += num_embs_pid
end
@assert offset == num_embs + 1

scores
end

"""
    rank(ranker::IndexScorer, config::ColBERTConfig, Q::Array{<:AbstractFloat})

Retrieve candidate passages for `Q`, score them, and return `(pids, scores)`
sorted by decreasing score.
"""
function rank(ranker::IndexScorer, config::ColBERTConfig, Q::Array{<:AbstractFloat})
    candidate_pids = retrieve(ranker, config, Q)
    candidate_scores = score_pids(ranker, config, Q, candidate_pids)
    order = sortperm(candidate_scores, rev=true)

    candidate_pids[order], candidate_scores[order]
end
Loading

0 comments on commit 45b943b

Please sign in to comment.