diff --git a/src/indexing.jl b/src/indexing.jl
index 305481e..bc504b0 100644
--- a/src/indexing.jl
+++ b/src/indexing.jl
@@ -145,3 +145,20 @@ function index(indexer::Indexer)
     # check if all relevant files are saved
     _check_all_files_are_saved(indexer.config.index_path)
 end
+
+"""
+    Base.show(io::IO, indexer::Indexer)
+
+Print a multi-line, human-readable summary of `indexer` to `io`.
+"""
+function Base.show(io::IO, indexer::Indexer)
+    config = indexer.config
+    println(io, "ColBERT Indexer:")
+    println(io, " Collection size: $(length(indexer.collection)) documents")
+    println(io, " Model: $(config.checkpoint)")
+    println(io, " Dimension: $(config.dim)")
+    println(io, " Index path: $(config.index_path)")
+    println(io, " Document maxlen: $(config.doc_maxlen)")
+    println(io, " Compression bits: $(config.nbits)")
+    return nothing
+end
diff --git a/src/indexing/codecs/residual.jl b/src/indexing/codecs/residual.jl
index b6a3d07..4c0d117 100644
--- a/src/indexing/codecs/residual.jl
+++ b/src/indexing/codecs/residual.jl
@@ -782,3 +782,30 @@ function decompress(
     end
     embeddings
 end
+
+"""
+    Base.show(io::IO, codec::Dict{String, Any})
+
+Print a summary of a residual codec stored as a `Dict{String, Any}`.
+
+The summary is printed only when `codec` has the keys `"centroids"`,
+`"avg_residual"`, `"bucket_cutoffs"` and `"bucket_weights"`. Any other
+`Dict{String, Any}` is forwarded to the generic `AbstractDict` method, so
+unrelated dictionaries keep their default display instead of raising a
+`KeyError`.
+"""
+function Base.show(io::IO, codec::Dict{String, Any})
+    # NOTE(review): this method is type piracy -- neither `show` nor
+    # `Dict{String, Any}` is owned by this package. A dedicated codec type would
+    # be the cleaner fix; the guard below only keeps other dictionaries working.
+    codec_keys = ("centroids", "avg_residual", "bucket_cutoffs", "bucket_weights")
+    if !all(key -> haskey(codec, key), codec_keys)
+        return invoke(show, Tuple{IO, AbstractDict}, io, codec)
+    end
+    println(io, "ColBERT Residual Codec:")
+    println(io, " Centroids: $(size(codec["centroids"], 2))")
+    println(io, " Average residual: $(round(codec["avg_residual"], digits = 4))")
+    println(io, " Bucket cutoffs: $(round.(codec["bucket_cutoffs"], digits = 4))")
+    println(io, " Bucket weights: $(round.(codec["bucket_weights"], digits = 4))")
+    return nothing
+end
diff --git a/src/infra/config.jl b/src/infra/config.jl
index 06fbc5d..d0824b7 100644
--- a/src/infra/config.jl
+++ b/src/infra/config.jl
@@ -63,7 +63,7 @@ Base.@kwdef struct ColBERTConfig
     query_token::String = "[Q]"
     doc_token::String = "[D]"
 
-    # resource settings 
+    # resource settings
     checkpoint::String = "colbert-ir/colbertv2.0"
     collection::Union{String, Vector{String}} = ""
 
@@ -88,3 +88,40 @@ Base.@kwdef struct ColBERTConfig
     nprobe::Int = 2
     ncandidates::Int = 8192
 end
+
+"""
+    Base.show(io::IO, config::ColBERTConfig)
+
+Print the settings of `config` to `io`, grouped by purpose (model, documents,
+queries, indexing, search and hardware).
+"""
+function Base.show(io::IO, config::ColBERTConfig)
+    # A collection given as a vector is summarised by its size rather than dumped.
+    collection_summary = config.collection isa String ? config.collection :
+                         "$(length(config.collection)) documents"
+    println(io, "ColBERTConfig:")
+    println(io, " Model:")
+    println(io, " checkpoint: $(config.checkpoint)")
+    println(io, " dim: $(config.dim)")
+    println(io, " Documents:")
+    println(io, " collection: $(collection_summary)")
+    println(io, " max length: $(config.doc_maxlen)")
+    println(io, " mask punctuation: $(config.mask_punctuation)")
+    println(io, " Queries:")
+    println(io, " max length: $(config.query_maxlen)")
+    println(io, " attend to mask: $(config.attend_to_mask_tokens)")
+    println(io, " Indexing:")
+    println(io, " path: $(config.index_path)")
+    println(io, " batch size: $(config.index_bsize)")
+    println(io, " chunk size: $(config.chunksize)")
+    println(io, " compression bits: $(config.nbits)")
+    println(io, " kmeans iterations: $(config.kmeans_niters)")
+    println(io, " Search:")
+    println(io, " nprobe: $(config.nprobe)")
+    println(io, " ncandidates: $(config.ncandidates)")
+    println(io, " Hardware:")
+    println(io, " GPU: $(config.use_gpu)")
+    println(io, " rank: $(config.rank)")
+    println(io, " nranks: $(config.nranks)")
+    return nothing
+end
diff --git a/src/searching.jl b/src/searching.jl
index e617931..ec5368a 100644
--- a/src/searching.jl
+++ b/src/searching.jl
@@ -126,3 +126,22 @@ function search(searcher::Searcher, query::String, k::Int)
     pids, scores = pids[indices], scores[indices]
     pids[1:k], scores[1:k]
 end
+
+"""
+    Base.show(io::IO, searcher::Searcher)
+
+Print a multi-line, human-readable summary of `searcher` to `io`.
+"""
+function Base.show(io::IO, searcher::Searcher)
+    config = searcher.config
+    println(io, "ColBERT Searcher:")
+    println(io, " Model: $(config.checkpoint)")
+    println(io, " Dimension: $(config.dim)")
+    println(io, " Index path: $(config.index_path)")
+    println(io, " nprobe: $(config.nprobe)")
+    println(io, " ncandidates: $(config.ncandidates)")
+    println(io, " Embeddings:")
+    println(io, " Total: $(sum(searcher.doclens))")
+    println(io, " Centroids: $(size(searcher.centroids, 2))")
+    return nothing
+end