diff --git a/pkg/models/memory.go b/pkg/models/memory.go index 10d7f967..d7c7cc8d 100644 --- a/pkg/models/memory.go +++ b/pkg/models/memory.go @@ -9,6 +9,7 @@ import ( type Message struct { UUID uuid.UUID `json:"uuid"` CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` Role string `json:"role"` Content string `json:"content"` Metadata map[string]interface{} `json:"metadata,omitempty"` diff --git a/pkg/store/postgres/documents_test.go b/pkg/store/postgres/documents_test.go index ffad3756..e3353328 100644 --- a/pkg/store/postgres/documents_test.go +++ b/pkg/store/postgres/documents_test.go @@ -87,6 +87,7 @@ func TestCollectionUpdate(t *testing.T) { err = collection.GetByName(ctx) assert.NoError(t, err) assert.Equal(t, expectedDimensions, collection.EmbeddingDimensions) + assert.Less(t, collection.CreatedAt, collection.UpdatedAt) } func TestCollectionGetByName(t *testing.T) { diff --git a/pkg/store/postgres/memorystore_test.go b/pkg/store/postgres/memorystore_test.go index a12868ed..d0699e98 100644 --- a/pkg/store/postgres/memorystore_test.go +++ b/pkg/store/postgres/memorystore_test.go @@ -207,6 +207,12 @@ func verifyMessagesInDB( resultMessages[i].Metadata, "Expected Metadata to be equal", ) + assert.Less( + t, + resultMessages[i].CreatedAt, + resultMessages[i].UpdatedAt, + "CreatedAt should be less than UpdatedAt", + ) } } diff --git a/pkg/store/postgres/message_metadata.go b/pkg/store/postgres/message_metadata.go index 99360de9..ea1c0a08 100644 --- a/pkg/store/postgres/message_metadata.go +++ b/pkg/store/postgres/message_metadata.go @@ -106,7 +106,7 @@ func putMessageMetadataTx( retrievedMessage.UUID = message.UUID _, err = tx.NewUpdate(). Model(&retrievedMessage). - Column("metadata"). + Column("metadata", "updated_at"). Where("session_id = ? AND uuid = ?", sessionID, message.UUID). Returning("*"). Exec(ctx) diff --git a/pkg/store/postgres/messages.go b/pkg/store/postgres/messages.go index 9799fd8c..fb3c9019 100644 --- a/pkg/store/postgres/messages.go +++ b/pkg/store/postgres/messages.go @@ -58,6 +58,7 @@ func putMessages( UUID: msg.UUID, SessionID: sessionID, CreatedAt: msg.CreatedAt, + UpdatedAt: msg.UpdatedAt, Role: msg.Role, Content: msg.Content, TokenCount: msg.TokenCount, @@ -68,7 +69,7 @@ func putMessages( // Insert messages _, err = db.NewInsert(). Model(&pgMessages). - Column("id", "created_at", "uuid", "session_id", "role", "content", "token_count"). + Column("id", "created_at", "uuid", "session_id", "role", "content", "token_count", "updated_at"). On("CONFLICT (uuid) DO UPDATE"). Exec(ctx) if err != nil { diff --git a/pkg/store/postgres/schema.go b/pkg/store/postgres/schema.go index a351855a..c7a2939c 100644 --- a/pkg/store/postgres/schema.go +++ b/pkg/store/postgres/schema.go @@ -36,6 +36,15 @@ type SessionSchema struct { User *UserSchema `bun:"rel:belongs-to,join:user_id=user_id,on_delete:cascade" yaml:"-"` } +var _ bun.BeforeAppendModelHook = (*SessionSchema)(nil) + +func (s *SessionSchema) BeforeAppendModel(_ context.Context, query bun.Query) error { + if _, ok := query.(*bun.UpdateQuery); ok { + s.UpdatedAt = time.Now() + } + return nil +} + // BeforeCreateTable is a marker method to ensure uniform interface across all table models - used in table creation iterator func (s *SessionSchema) BeforeCreateTable( _ context.Context, @@ -47,8 +56,6 @@ func (s *SessionSchema) BeforeCreateTable( type MessageStoreSchema struct { bun.BaseModel `bun:"table:message,alias:m" yaml:"-"` - // TODO: replace UUIDs with sortable ULIDs or UUIDv7s to avoid having to have both a UUID and an ID. - // see https://blog.daveallie.com/ulid-primary-keys UUID uuid.UUID `bun:",pk,type:uuid,default:gen_random_uuid()" yaml:"uuid"` // ID is used only for sorting / slicing purposes as we can't sort by CreatedAt for messages created simultaneously ID int64 `bun:",autoincrement" yaml:"id,omitempty"` @@ -63,6 +70,15 @@ type MessageStoreSchema struct { Session *SessionSchema `bun:"rel:belongs-to,join:session_id=session_id,on_delete:cascade" yaml:"-"` } +var _ bun.BeforeAppendModelHook = (*MessageStoreSchema)(nil) + +func (s *MessageStoreSchema) BeforeAppendModel(_ context.Context, query bun.Query) error { + if _, ok := query.(*bun.UpdateQuery); ok { + s.UpdatedAt = time.Now() + } + return nil +} + func (s *MessageStoreSchema) BeforeCreateTable( _ context.Context, _ *bun.CreateTableQuery, @@ -86,6 +102,15 @@ type MessageVectorStoreSchema struct { Message *MessageStoreSchema `bun:"rel:belongs-to,join:message_uuid=uuid,on_delete:cascade"` } +var _ bun.BeforeAppendModelHook = (*MessageVectorStoreSchema)(nil) + +func (s *MessageVectorStoreSchema) BeforeAppendModel(_ context.Context, query bun.Query) error { + if _, ok := query.(*bun.UpdateQuery); ok { + s.UpdatedAt = time.Now() + } + return nil +} + func (s *MessageVectorStoreSchema) BeforeCreateTable( _ context.Context, _ *bun.CreateTableQuery, @@ -109,6 +134,15 @@ type SummaryStoreSchema struct { Message *MessageStoreSchema `bun:"rel:belongs-to,join:summary_point_uuid=uuid,on_delete:cascade"` } +var _ bun.BeforeAppendModelHook = (*SummaryStoreSchema)(nil) + +func (s *SummaryStoreSchema) BeforeAppendModel(_ context.Context, query bun.Query) error { + if _, ok := query.(*bun.UpdateQuery); ok { + s.UpdatedAt = time.Now() + } + return nil +} + func (s *SummaryStoreSchema) BeforeCreateTable( _ context.Context, _ *bun.CreateTableQuery, @@ -129,6 +163,15 @@ func (s *DocumentCollectionSchema) BeforeCreateTable( return nil } +var _ bun.BeforeAppendModelHook = (*DocumentCollectionSchema)(nil) + +func (s *DocumentCollectionSchema) BeforeAppendModel(_ context.Context, query bun.Query) error { + if _, ok := query.(*bun.UpdateQuery); ok { + s.UpdatedAt = time.Now() + } + return nil +} + // DocumentSchemaTemplate represents the schema template for Document tables. // MessageEmbedding is manually added when createDocumentTable is run in order to set the correct dimensions. // This means the embedding is not returned when querying using bun. @@ -152,6 +195,15 @@ type UserSchema struct { Metadata map[string]interface{} `bun:"type:jsonb,nullzero,json_use_number" yaml:"metadata,omitempty"` } +var _ bun.BeforeAppendModelHook = (*UserSchema)(nil) + +func (u *UserSchema) BeforeAppendModel(_ context.Context, query bun.Query) error { + if _, ok := query.(*bun.UpdateQuery); ok { + u.UpdatedAt = time.Now() + } + return nil +} + // BeforeCreateTable is a marker method to ensure uniform interface across all table models - used in table creation iterator func (u *UserSchema) BeforeCreateTable( _ context.Context, diff --git a/pkg/store/postgres/schema_test.go b/pkg/store/postgres/schema_test.go index bdefa112..e13a29c5 100644 --- a/pkg/store/postgres/schema_test.go +++ b/pkg/store/postgres/schema_test.go @@ -2,9 +2,12 @@ package postgres import ( "context" + "reflect" "testing" + "time" "github.com/stretchr/testify/assert" + "github.com/uptrace/bun" ) func TestEnsurePostgresSchemaSetup(t *testing.T) { @@ -26,13 +29,50 @@ func TestEnsurePostgresSchemaSetup(t *testing.T) { } func TestCreateDocumentTable(t *testing.T) { - ctx := context.Background() - collection := NewTestCollectionDAO(3) tableName, err := generateDocumentTableName(&collection) assert.NoError(t, err) - err = createDocumentTable(ctx, testDB, tableName, collection.EmbeddingDimensions) + err = createDocumentTable(testCtx, testDB, tableName, collection.EmbeddingDimensions) assert.NoError(t, err) } + +func TestUpdatedAtIsSetAfterUpdate(t *testing.T) { + // Define a list of all schemas + schemas := []bun.BeforeAppendModelHook{ + &SessionSchema{}, + &MessageStoreSchema{}, + &SummaryStoreSchema{}, + &MessageVectorStoreSchema{}, + &UserSchema{}, + &DocumentCollectionSchema{}, + } + + // Iterate over all schemas + for _, schema := range schemas { + // Create a new instance of the schema + instance := reflect.New(reflect.TypeOf(schema).Elem()).Interface().(bun.BeforeAppendModelHook) + + // Set the UpdatedAt field to a time far in the past + reflect.ValueOf(instance). + Elem(). + FieldByName("UpdatedAt"). + Set(reflect.ValueOf(time.Unix(0, 0))) + + // Create a dummy UpdateQuery + updateQuery := &bun.UpdateQuery{} + + // Call the BeforeAppendModel method, which should update the UpdatedAt field + err := instance.BeforeAppendModel(context.Background(), updateQuery) + assert.NoError(t, err) + + // Check that the UpdatedAt field was updated + assert.True( + t, + reflect.ValueOf(instance).Elem().FieldByName("UpdatedAt").Interface().(time.Time).After( + time.Now().Add(-time.Minute), + ), + ) + } +} diff --git a/pkg/store/postgres/session.go b/pkg/store/postgres/session.go index 1073f276..112bd3f2 100644 --- a/pkg/store/postgres/session.go +++ b/pkg/store/postgres/session.go @@ -150,7 +150,7 @@ func (dao *SessionDAO) updateSession( Metadata: session.Metadata, DeletedAt: time.Time{}, // Intentionally overwrite soft-delete with zero value } - var columns = []string{"deleted_at"} + var columns = []string{"deleted_at", "updated_at"} if session.Metadata != nil { columns = append(columns, "metadata") } diff --git a/pkg/store/postgres/session_test.go b/pkg/store/postgres/session_test.go index a760d5f9..52564239 100644 --- a/pkg/store/postgres/session_test.go +++ b/pkg/store/postgres/session_test.go @@ -153,6 +153,7 @@ func TestSessionDAO_Update(t *testing.T) { assert.Equal(t, createdSession.SessionID, updatedSession.SessionID) assert.Equal(t, createdSession.UserID, updatedSession.UserID) assert.Equal(t, updateSession.Metadata, updatedSession.Metadata) + assert.Less(t, createdSession.UpdatedAt, updatedSession.UpdatedAt) } func TestSessionDAO_UpdateWithNilMetadata(t *testing.T) { diff --git a/pkg/store/postgres/userstore.go b/pkg/store/postgres/userstore.go index 58fa132b..412af487 100644 --- a/pkg/store/postgres/userstore.go +++ b/pkg/store/postgres/userstore.go @@ -125,7 +125,7 @@ func (dao *UserStoreDAO) updateUser( } r, err := dao.db.NewUpdate(). Model(&userDB). - Column("email", "first_name", "last_name", "metadata"). + Column("email", "first_name", "last_name", "metadata", "updated_at"). OmitZero(). Where("user_id = ?", user.UserID). Exec(ctx) diff --git a/pkg/store/postgres/userstore_test.go b/pkg/store/postgres/userstore_test.go index de287529..e426b393 100644 --- a/pkg/store/postgres/userstore_test.go +++ b/pkg/store/postgres/userstore_test.go @@ -63,7 +63,7 @@ func TestUserStoreDAO(t *testing.T) { }, Email: "email", } - _, err := userStore.Create(ctx, user) + createdUser, err := userStore.Create(ctx, user) assert.NoError(t, err) // Update the user with zero values @@ -77,10 +77,11 @@ func TestUserStoreDAO(t *testing.T) { assert.NoError(t, err) // Check that the updated user still has the original non-zero values - assert.Equal(t, user.Metadata, updatedUser.Metadata) - assert.Equal(t, user.Email, updatedUser.Email) + assert.Equal(t, createdUser.Metadata, updatedUser.Metadata) + assert.Equal(t, createdUser.Email, updatedUser.Email) // Bob should be the new first name assert.Equal(t, "bob", updatedUser.FirstName) + assert.Less(t, createdUser.UpdatedAt, updatedUser.UpdatedAt) }) t.Run("Update Non-Existant User should result in NotFoundError", func(t *testing.T) {