diff --git a/pkg/store/postgres/schema.go b/pkg/store/postgres/schema.go index a351855a..c7a2939c 100644 --- a/pkg/store/postgres/schema.go +++ b/pkg/store/postgres/schema.go @@ -36,6 +36,15 @@ type SessionSchema struct { User *UserSchema `bun:"rel:belongs-to,join:user_id=user_id,on_delete:cascade" yaml:"-"` } +var _ bun.BeforeAppendModelHook = (*SessionSchema)(nil) + +func (s *SessionSchema) BeforeAppendModel(_ context.Context, query bun.Query) error { + if _, ok := query.(*bun.UpdateQuery); ok { + s.UpdatedAt = time.Now() + } + return nil +} + // BeforeCreateTable is a marker method to ensure uniform interface across all table models - used in table creation iterator func (s *SessionSchema) BeforeCreateTable( _ context.Context, @@ -47,8 +56,6 @@ func (s *SessionSchema) BeforeCreateTable( type MessageStoreSchema struct { bun.BaseModel `bun:"table:message,alias:m" yaml:"-"` - // TODO: replace UUIDs with sortable ULIDs or UUIDv7s to avoid having to have both a UUID and an ID. - // see https://blog.daveallie.com/ulid-primary-keys UUID uuid.UUID `bun:",pk,type:uuid,default:gen_random_uuid()" yaml:"uuid"` // ID is used only for sorting / slicing purposes as we can't sort by CreatedAt for messages created simultaneously ID int64 `bun:",autoincrement" yaml:"id,omitempty"` @@ -63,6 +70,15 @@ type MessageStoreSchema struct { Session *SessionSchema `bun:"rel:belongs-to,join:session_id=session_id,on_delete:cascade" yaml:"-"` } +var _ bun.BeforeAppendModelHook = (*MessageStoreSchema)(nil) + +func (s *MessageStoreSchema) BeforeAppendModel(_ context.Context, query bun.Query) error { + if _, ok := query.(*bun.UpdateQuery); ok { + s.UpdatedAt = time.Now() + } + return nil +} + func (s *MessageStoreSchema) BeforeCreateTable( _ context.Context, _ *bun.CreateTableQuery, @@ -86,6 +102,15 @@ type MessageVectorStoreSchema struct { Message *MessageStoreSchema `bun:"rel:belongs-to,join:message_uuid=uuid,on_delete:cascade"` } +var _ bun.BeforeAppendModelHook = (*MessageVectorStoreSchema)(nil) + +func (s *MessageVectorStoreSchema) BeforeAppendModel(_ context.Context, query bun.Query) error { + if _, ok := query.(*bun.UpdateQuery); ok { + s.UpdatedAt = time.Now() + } + return nil +} + func (s *MessageVectorStoreSchema) BeforeCreateTable( _ context.Context, _ *bun.CreateTableQuery, @@ -109,6 +134,15 @@ type SummaryStoreSchema struct { Message *MessageStoreSchema `bun:"rel:belongs-to,join:summary_point_uuid=uuid,on_delete:cascade"` } +var _ bun.BeforeAppendModelHook = (*SummaryStoreSchema)(nil) + +func (s *SummaryStoreSchema) BeforeAppendModel(_ context.Context, query bun.Query) error { + if _, ok := query.(*bun.UpdateQuery); ok { + s.UpdatedAt = time.Now() + } + return nil +} + func (s *SummaryStoreSchema) BeforeCreateTable( _ context.Context, _ *bun.CreateTableQuery, @@ -129,6 +163,15 @@ func (s *DocumentCollectionSchema) BeforeCreateTable( return nil } +var _ bun.BeforeAppendModelHook = (*DocumentCollectionSchema)(nil) + +func (s *DocumentCollectionSchema) BeforeAppendModel(_ context.Context, query bun.Query) error { + if _, ok := query.(*bun.UpdateQuery); ok { + s.UpdatedAt = time.Now() + } + return nil +} + // DocumentSchemaTemplate represents the schema template for Document tables. // MessageEmbedding is manually added when createDocumentTable is run in order to set the correct dimensions. // This means the embedding is not returned when querying using bun. @@ -152,6 +195,15 @@ type UserSchema struct { Metadata map[string]interface{} `bun:"type:jsonb,nullzero,json_use_number" yaml:"metadata,omitempty"` } +var _ bun.BeforeAppendModelHook = (*UserSchema)(nil) + +func (u *UserSchema) BeforeAppendModel(_ context.Context, query bun.Query) error { + if _, ok := query.(*bun.UpdateQuery); ok { + u.UpdatedAt = time.Now() + } + return nil +} + // BeforeCreateTable is a marker method to ensure uniform interface across all table models - used in table creation iterator func (u *UserSchema) BeforeCreateTable( _ context.Context, diff --git a/pkg/store/postgres/schema_test.go b/pkg/store/postgres/schema_test.go index bdefa112..e13a29c5 100644 --- a/pkg/store/postgres/schema_test.go +++ b/pkg/store/postgres/schema_test.go @@ -2,9 +2,12 @@ package postgres import ( "context" + "reflect" "testing" + "time" "github.com/stretchr/testify/assert" + "github.com/uptrace/bun" ) func TestEnsurePostgresSchemaSetup(t *testing.T) { @@ -26,13 +29,50 @@ func TestEnsurePostgresSchemaSetup(t *testing.T) { } func TestCreateDocumentTable(t *testing.T) { - ctx := context.Background() - collection := NewTestCollectionDAO(3) tableName, err := generateDocumentTableName(&collection) assert.NoError(t, err) - err = createDocumentTable(ctx, testDB, tableName, collection.EmbeddingDimensions) + err = createDocumentTable(testCtx, testDB, tableName, collection.EmbeddingDimensions) assert.NoError(t, err) } + +func TestUpdatedAtIsSetAfterUpdate(t *testing.T) { + // Define a list of all schemas + schemas := []bun.BeforeAppendModelHook{ + &SessionSchema{}, + &MessageStoreSchema{}, + &SummaryStoreSchema{}, + &MessageVectorStoreSchema{}, + &UserSchema{}, + &DocumentCollectionSchema{}, + } + + // Iterate over all schemas + for _, schema := range schemas { + // Create a new instance of the schema + instance := reflect.New(reflect.TypeOf(schema).Elem()).Interface().(bun.BeforeAppendModelHook) + + // Set the UpdatedAt field to a time far in the past + reflect.ValueOf(instance). + Elem(). + FieldByName("UpdatedAt"). + Set(reflect.ValueOf(time.Unix(0, 0))) + + // Create a dummy UpdateQuery + updateQuery := &bun.UpdateQuery{} + + // Call the BeforeAppendModel method, which should update the UpdatedAt field + err := instance.BeforeAppendModel(context.Background(), updateQuery) + assert.NoError(t, err) + + // Check that the UpdatedAt field was updated + assert.True( + t, + reflect.ValueOf(instance).Elem().FieldByName("UpdatedAt").Interface().(time.Time).After( + time.Now().Add(-time.Minute), + ), + ) + } +}