Merge #136168
136168: vecstore: store improvements and bug fixes r=drewkimball a=andy-kimball

1. When a new index is created, expect the store to create an empty root
partition. While the root partition can be updated, it can never be
deleted.
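
A minimal, self-contained sketch of this invariant (Store, Partition, NewStore,
and the numeric RootKey below are hypothetical stand-ins, not the real vecstore
API): the constructor seeds the partition map with an empty leaf-level root, and
DeletePartition refuses to remove it.

```go
package main

import (
	"errors"
	"fmt"
)

type PartitionKey uint64

// RootKey is a fixed, well-known key for the root partition.
const RootKey PartitionKey = 1

type Partition struct {
	level int // leaf partitions sit at level 1
}

type Store struct {
	partitions map[PartitionKey]*Partition
}

// NewStore seeds the map with an empty leaf-level root, so the root
// exists from the moment the store does.
func NewStore() *Store {
	return &Store{partitions: map[PartitionKey]*Partition{
		RootKey: {level: 1},
	}}
}

// DeletePartition removes any partition except the root, which may only
// ever be replaced, never deleted.
func (s *Store) DeletePartition(key PartitionKey) error {
	if key == RootKey {
		return errors.New("cannot delete the root partition")
	}
	if _, ok := s.partitions[key]; !ok {
		return errors.New("partition not found")
	}
	delete(s.partitions, key)
	return nil
}

func main() {
	s := NewStore()
	fmt.Println(s.DeletePartition(RootKey)) // cannot delete the root partition
}
```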

2. Replace the InMemoryStore UnmarshalBinary method with a
LoadInMemoryStore function in order to avoid bugs where unmarshaling is
attempted on an already-initialized store.
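
A minimal sketch of why a load function is safer than an unmarshal method
(MiniStore, LoadMiniStore, and the JSON encoding are illustrative assumptions;
the real store round-trips a protobuf via protoutil): the function constructs
and returns a fresh value, so there is never a previously initialized receiver
to clobber.

```go
package main

import (
	"encoding/json"
	"fmt"
)

type MiniStore struct {
	Dims int            `json:"dims"`
	Seed int64          `json:"seed"`
	Data map[string]int `json:"data"`
}

// LoadMiniStore builds a brand-new store from serialized bytes. Unlike a
// method such as (*MiniStore).UnmarshalBinary, it cannot be invoked on a
// store that already holds state, so there is nothing to overwrite.
func LoadMiniStore(data []byte) (*MiniStore, error) {
	s := &MiniStore{} // every field is populated exactly once, below
	if err := json.Unmarshal(data, s); err != nil {
		return nil, err
	}
	return s, nil
}

func main() {
	raw := []byte(`{"dims": 2, "seed": 42, "data": {"a": 1}}`)
	s, err := LoadMiniStore(raw)
	if err != nil {
		panic(err)
	}
	fmt.Println(s.Dims, s.Seed, s.Data["a"]) // 2 42 1
}
```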

3. Fix a bug where fixupProcessor.Wait can panic when there are multiple
goroutines waiting for all fixups to be processed. Switch from a
sync.WaitGroup to a sync.Cond to support that case.
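
A minimal sketch of the sync.Cond waiter pattern adopted here (queue, add, and
finishOne are hypothetical names, not the vecindex API): Wait loops on the
condition under the mutex, and Broadcast wakes every waiter, which is what makes
concurrent Wait calls safe where a reused sync.WaitGroup can panic.

```go
package main

import (
	"fmt"
	"sync"
)

// queue tracks a count of pending work items behind a mutex.
type queue struct {
	mu      sync.Mutex
	pending int
	done    sync.Cond // broadcast whenever pending drops to zero
}

func newQueue() *queue {
	q := &queue{}
	q.done.L = &q.mu // the Cond shares the queue's own mutex
	return q
}

func (q *queue) add() {
	q.mu.Lock()
	defer q.mu.Unlock()
	q.pending++
}

func (q *queue) finishOne() {
	q.mu.Lock()
	defer q.mu.Unlock()
	q.pending--
	if q.pending == 0 {
		// Broadcast (not Signal) wakes every waiter, so any number of
		// goroutines may block in wait concurrently.
		q.done.Broadcast()
	}
}

// wait blocks until no work is pending. The loop re-checks the condition
// under the lock, which also guards against spurious wakeups.
func (q *queue) wait() {
	q.mu.Lock()
	defer q.mu.Unlock()
	for q.pending > 0 {
		q.done.Wait() // atomically unlocks q.mu; re-locks before returning
	}
}

func main() {
	q := newQueue()
	q.add()
	var waiters sync.WaitGroup
	for i := 0; i < 3; i++ {
		waiters.Add(1)
		go func() {
			defer waiters.Done()
			q.wait() // all three block here together
		}()
	}
	q.finishOne()
	waiters.Wait()
	fmt.Println("all waiters released")
}
```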

Epic: CRDB-42943

Release note: None

Co-authored-by: Andrew Kimball <[email protected]>
craig[bot] and andy-kimball committed Dec 3, 2024
2 parents 4f654c7 + a01728b commit b3fec61
Showing 6 changed files with 75 additions and 89 deletions.
18 changes: 4 additions & 14 deletions pkg/cmd/vecbench/main.go
@@ -247,12 +247,12 @@ func downloadDataset(ctx context.Context, datasetName string) {

// Use progressWriter to track download progress
var buf bytes.Buffer
-progressWriter := &progressWriter{
+writer := &progressWriter{
Writer: &buf,
Total: attrs.Size,
}

-if _, err = io.Copy(writer, reader); err != nil {
+if _, err = io.Copy(writer, reader); err != nil {
log.Fatalf("Failed to copy object data: %v", err)
}

@@ -348,15 +348,6 @@ func buildIndex(ctx context.Context, datasetName string) {
panic(err)
}

-// Insert empty root partition.
-func() {
-txn := beginTransaction(ctx, store)
-defer commitTransaction(ctx, store, txn)
-if err := index.CreateRoot(ctx, txn); err != nil {
-panic(err)
-}
-}()

// Create unique primary key for each vector in a single large byte buffer.
primaryKeys := make([]byte, data.Train.Count*4)
for i := 0; i < data.Train.Count; i++ {
@@ -439,13 +430,12 @@ func loadStore(fileName string) *vecstore.InMemoryStore {
panic(err)
}

-var inMemStore vecstore.InMemoryStore
-err = inMemStore.UnmarshalBinary(data)
+inMemStore, err := vecstore.LoadInMemoryStore(data)
if err != nil {
panic(err)
}

-return &inMemStore
+return inMemStore
}

// loadDataset deserializes a dataset saved as a gob file.
25 changes: 14 additions & 11 deletions pkg/sql/vecindex/fixup_processor.go
@@ -89,6 +89,9 @@ type fixupProcessor struct {

// pendingVectors tracks pending fixups for deleting vectors.
pendingVectors map[string]bool

+// waitForFixups broadcasts to any waiters when all fixups are processed.
+waitForFixups sync.Cond
}

// --------------------------------------------------
@@ -101,10 +104,6 @@ type fixupProcessor struct {
// maxFixups limit has been reached.
fixupsLimitHit log.EveryN

-// pendingCount tracks the number of pending fixups that still need to be
-// processed.
-pendingCount sync.WaitGroup

// --------------------------------------------------
// The following fields should only be accessed on a single background
// goroutine (or a single foreground goroutine in deterministic tests).
@@ -135,6 +134,7 @@ func (fp *fixupProcessor) Init(index *VectorIndex, seed int64) {
}
fp.mu.pendingPartitions = make(map[partitionFixupKey]bool, maxFixups)
fp.mu.pendingVectors = make(map[string]bool, maxFixups)
+fp.mu.waitForFixups.L = &fp.mu
fp.fixups = make(chan fixup, maxFixups)
fp.fixupsLimitHit = log.Every(time.Second)
}
@@ -197,7 +197,11 @@ func (fp *fixupProcessor) Start(ctx context.Context) {
// Wait blocks until all pending fixups have been processed by the background
// goroutine. This is useful in testing.
func (fp *fixupProcessor) Wait() {
-fp.pendingCount.Wait()
+fp.mu.Lock()
+defer fp.mu.Unlock()
+for len(fp.mu.pendingVectors) > 0 || len(fp.mu.pendingPartitions) > 0 {
+fp.mu.waitForFixups.Wait()
+}
}

// runAll processes all fixups in the queue. This should only be called by tests
@@ -270,9 +274,6 @@ func (fp *fixupProcessor) run(ctx context.Context, wait bool) (ok bool, err error) {
fp.mu.Lock()
defer fp.mu.Unlock()

-// Decrement the number of pending fixups.
-fp.pendingCount.Done()

switch next.Type {
case splitFixup, mergeFixup:
key := partitionFixupKey{Type: next.Type, PartitionKey: next.PartitionKey}
@@ -282,6 +283,11 @@ func (fp *fixupProcessor) run(ctx context.Context, wait bool) (ok bool, err error) {
delete(fp.mu.pendingVectors, string(next.VectorKey))
}

+// If there are no more pending fixups, notify any waiters.
+if len(fp.mu.pendingPartitions) == 0 && len(fp.mu.pendingVectors) == 0 {
+fp.mu.waitForFixups.Broadcast()
+}

return true, err
}

@@ -319,9 +325,6 @@ func (fp *fixupProcessor) addFixup(ctx context.Context, fixup fixup) {
panic(errors.AssertionFailedf("unknown fixup %d", fixup.Type))
}

-// Increment the number of pending fixups.
-fp.pendingCount.Add(1)

// Note that the channel send operation should never block, since it has
// maxFixups capacity.
fp.fixups <- fixup
82 changes: 49 additions & 33 deletions pkg/sql/vecindex/vecstore/in_memory_store.go
@@ -58,10 +58,10 @@ type InMemoryStore struct {
txnLock syncutil.RWMutex
mu struct {
syncutil.Mutex
-index map[PartitionKey]*Partition
-nextKey PartitionKey
-vectors map[string]vector.T
-stats IndexStats
+partitions map[PartitionKey]*Partition
+nextKey PartitionKey
+vectors map[string]vector.T
+stats IndexStats
}
}

@@ -76,9 +76,21 @@ func NewInMemoryStore(dims int, seed int64) *InMemoryStore {
dims: dims,
seed: seed,
}
-st.mu.index = make(map[PartitionKey]*Partition)
+st.mu.partitions = make(map[PartitionKey]*Partition)

+// Create empty root partition.
+var empty vector.Set
+quantizer := quantize.NewUnQuantizer(dims)
+quantizedSet := quantizer.Quantize(context.Background(), &empty)
+st.mu.partitions[RootKey] = &Partition{
+quantizer: quantizer,
+quantizedSet: quantizedSet,
+level: LeafLevel,
+}

st.mu.nextKey = RootKey + 1
st.mu.vectors = make(map[string]vector.T)
+st.mu.stats.NumPartitions = 1
return st
}

Expand All @@ -98,7 +110,7 @@ func (s *InMemoryStore) CommitTransaction(ctx context.Context, txn Txn) error {
s.mu.Lock()
defer s.mu.Unlock()

-partition, ok := s.mu.index[partitionKey]
+partition, ok := s.mu.partitions[partitionKey]
if ok && partition.Count() == 0 && partition.Level() > LeafLevel {
panic(errors.AssertionFailedf(
"K-means tree is unbalanced, with empty non-leaf partition %d", inMemTxn.unbalancedKey))
@@ -135,7 +147,7 @@ func (s *InMemoryStore) GetPartition(
s.mu.Lock()
defer s.mu.Unlock()

-partition, ok := s.mu.index[partitionKey]
+partition, ok := s.mu.partitions[partitionKey]
if !ok {
return nil, ErrPartitionNotFound
}
@@ -152,9 +164,9 @@ func (s *InMemoryStore) SetRootPartition(ctx context.Context, txn Txn, partition *Partition) error {
s.mu.Lock()
defer s.mu.Unlock()

-_, ok := s.mu.index[RootKey]
+_, ok := s.mu.partitions[RootKey]
if !ok {
-s.mu.stats.NumPartitions++
+panic(errors.AssertionFailedf("the root partition cannot be found"))
}

// Grow or shrink CVStats slice if a new level is being added or removed.
@@ -164,7 +176,7 @@ func (s *InMemoryStore) SetRootPartition(ctx context.Context, txn Txn, partition *Partition) error {
}
s.mu.stats.CVStats = s.mu.stats.CVStats[:expectedLevels]

-s.mu.index[RootKey] = partition
+s.mu.partitions[RootKey] = partition
return nil
}

@@ -179,7 +191,7 @@ func (s *InMemoryStore) InsertPartition(

partitionKey := s.mu.nextKey
s.mu.nextKey++
-s.mu.index[partitionKey] = partition
+s.mu.partitions[partitionKey] = partition
s.mu.stats.NumPartitions++
return partitionKey, nil
}
@@ -193,11 +205,15 @@ func (s *InMemoryStore) DeletePartition(
s.mu.Lock()
defer s.mu.Unlock()

-_, ok := s.mu.index[partitionKey]
+if partitionKey == RootKey {
+panic(errors.AssertionFailedf("cannot delete the root partition"))
+}
+
+_, ok := s.mu.partitions[partitionKey]
if !ok {
return ErrPartitionNotFound
}
-delete(s.mu.index, partitionKey)
+delete(s.mu.partitions, partitionKey)
s.mu.stats.NumPartitions--
return nil
}
@@ -211,7 +227,7 @@ func (s *InMemoryStore) AddToPartition(
s.mu.Lock()
defer s.mu.Unlock()

-partition, ok := s.mu.index[partitionKey]
+partition, ok := s.mu.partitions[partitionKey]
if !ok {
return 0, ErrPartitionNotFound
}
@@ -230,7 +246,7 @@ func (s *InMemoryStore) RemoveFromPartition(
s.mu.Lock()
defer s.mu.Unlock()

-partition, ok := s.mu.index[partitionKey]
+partition, ok := s.mu.partitions[partitionKey]
if !ok {
return 0, ErrPartitionNotFound
}
@@ -265,7 +281,7 @@ func (s *InMemoryStore) SearchPartitions(
defer s.mu.Unlock()

for i := 0; i < len(partitionKeys); i++ {
-partition, ok := s.mu.index[partitionKeys[i]]
+partition, ok := s.mu.partitions[partitionKeys[i]]
if !ok {
return 0, ErrPartitionNotFound
}
@@ -295,7 +311,7 @@ func (s *InMemoryStore) GetFullVectors(ctx context.Context, txn Txn, refs []VectorWithKey) error {
ref := &refs[i]
if ref.Key.PartitionKey != InvalidKey {
// Return the partition's centroid.
-partition, ok := s.mu.index[ref.Key.PartitionKey]
+partition, ok := s.mu.partitions[ref.Key.PartitionKey]
if !ok {
return ErrPartitionNotFound
}
@@ -417,14 +433,14 @@ func (s *InMemoryStore) MarshalBinary() (data []byte, err error) {
storeProto := StoreProto{
Dims: s.dims,
Seed: s.seed,
-Partitions: make([]PartitionProto, 0, len(s.mu.index)),
+Partitions: make([]PartitionProto, 0, len(s.mu.partitions)),
NextKey: s.mu.nextKey,
Vectors: make([]VectorProto, 0, len(s.mu.vectors)),
Stats: s.mu.stats,
}

// Remap partitions to protobufs.
-for partitionKey, partition := range s.mu.index {
+for partitionKey, partition := range s.mu.partitions {
partitionProto := PartitionProto{
PartitionKey: partitionKey,
ChildKeys: partition.ChildKeys(),
Expand All @@ -451,24 +467,24 @@ func (s *InMemoryStore) MarshalBinary() (data []byte, err error) {
return protoutil.Marshal(&storeProto)
}

-// UnmarshalBinary loads the in-memory store from bytes that were previously
+// LoadInMemoryStore loads the in-memory store from bytes that were previously
// saved by MarshalBinary.
-func (s *InMemoryStore) UnmarshalBinary(data []byte) error {
-s.mu.Lock()
-defer s.mu.Unlock()
-
+func LoadInMemoryStore(data []byte) (*InMemoryStore, error) {
// Unmarshal bytes into a protobuf.
var storeProto StoreProto
if err := protoutil.Unmarshal(data, &storeProto); err != nil {
-return err
+return nil, err
}

// Construct the InMemoryStore object.
-s.seed = storeProto.Seed
-s.mu.index = make(map[PartitionKey]*Partition, len(storeProto.Partitions))
-s.mu.nextKey = storeProto.NextKey
-s.mu.vectors = make(map[string]vector.T, len(storeProto.Vectors))
-s.mu.stats = storeProto.Stats
+inMemStore := &InMemoryStore{
+dims: storeProto.Dims,
+seed: storeProto.Seed,
+}
+inMemStore.mu.partitions = make(map[PartitionKey]*Partition, len(storeProto.Partitions))
+inMemStore.mu.nextKey = storeProto.NextKey
+inMemStore.mu.vectors = make(map[string]vector.T, len(storeProto.Vectors))
+inMemStore.mu.stats = storeProto.Stats

raBitQuantizer := quantize.NewRaBitQuantizer(storeProto.Dims, storeProto.Seed)
unquantizer := quantize.NewUnQuantizer(storeProto.Dims)
@@ -487,16 +503,16 @@ func (s *InMemoryStore) UnmarshalBinary(data []byte) error {
partition.quantizer = unquantizer
partition.quantizedSet = partitionProto.UnQuantized
}
-s.mu.index[partitionProto.PartitionKey] = &partition
+inMemStore.mu.partitions[partitionProto.PartitionKey] = &partition
}

// Insert vectors into the in-memory store.
for i := range storeProto.Vectors {
vectorProto := storeProto.Vectors[i]
-s.mu.vectors[string(vectorProto.PrimaryKey)] = vectorProto.Vector
+inMemStore.mu.vectors[string(vectorProto.PrimaryKey)] = vectorProto.Vector
}

-return nil
+return inMemStore, nil
}

// acquireTxnLock acquires a data or partition lock within the scope of the
22 changes: 8 additions & 14 deletions pkg/sql/vecindex/vecstore/in_memory_store_test.go
@@ -65,15 +65,10 @@ func TestInMemoryStore(t *testing.T) {
}, vectors)
})

t.Run("insert empty root partition into the store", func(t *testing.T) {
t.Run("search empty root partition", func(t *testing.T) {
txn := beginTransaction(ctx, t, store)
defer commitTransaction(ctx, t, store, txn)

-vectors := vector.MakeSet(2)
-quantizedSet := quantizer.Quantize(ctx, &vectors)
-root := NewPartition(quantizer, quantizedSet, []ChildKey{}, LeafLevel)
-require.NoError(t, store.SetRootPartition(ctx, txn, root))

searchSet := SearchSet{MaxResults: 2}
partitionCounts := []int{0}
level, err := store.SearchPartitions(
@@ -373,7 +368,7 @@ func TestInMemoryStoreMarshalling(t *testing.T) {
dims: 2,
seed: 42,
}
-store.mu.index = map[PartitionKey]*Partition{
+store.mu.partitions = map[PartitionKey]*Partition{
10: {
quantizer: unquantizer,
quantizedSet: &quantize.UnQuantizedVectorSet{
@@ -418,15 +413,14 @@ func TestInMemoryStoreMarshalling(t *testing.T) {
data, err := store.MarshalBinary()
require.NoError(t, err)

-var store2 InMemoryStore
-err = store2.UnmarshalBinary(data)
+store2, err := LoadInMemoryStore(data)
require.NoError(t, err)

-require.Len(t, store2.mu.index, 2)
-require.Equal(t, Level(1), store2.mu.index[10].level)
-require.Equal(t, 3, store2.mu.index[10].quantizedSet.GetCount())
-require.Equal(t, 2, store2.mu.index[20].quantizer.GetOriginalDims())
-require.Len(t, store2.mu.index[20].childKeys, 3)
+require.Len(t, store2.mu.partitions, 2)
+require.Equal(t, Level(1), store2.mu.partitions[10].level)
+require.Equal(t, 3, store2.mu.partitions[10].quantizedSet.GetCount())
+require.Equal(t, 2, store2.mu.partitions[20].quantizer.GetOriginalDims())
+require.Len(t, store2.mu.partitions[20].childKeys, 3)
require.Equal(t, PartitionKey(100), store2.mu.nextKey)
require.Len(t, store2.mu.vectors, 2)
require.Equal(t, vector.T{12, 13}, store2.mu.vectors[string([]byte{3, 4})])