diff --git a/adapters/adapter.go b/adapters/adapter.go index 31cdbb2..a88d43b 100644 --- a/adapters/adapter.go +++ b/adapters/adapter.go @@ -16,6 +16,7 @@ func init() { gob.Register(&GothSession{}) gob.Register(&GothTeam{}) gob.Register(&GothVerificationToken{}) + gob.Register(&GothCsrfToken{}) } // CsrfTokenGenerator is a function that generates a CSRF token. diff --git a/adapters/gorm/gorm.go b/adapters/gorm/gorm.go index 481156c..bf6c048 100644 --- a/adapters/gorm/gorm.go +++ b/adapters/gorm/gorm.go @@ -85,8 +85,16 @@ func (a *gormAdapter) GetUser(ctx context.Context, id uuid.UUID) (adapters.GothU // CreateSession is a helper function to create a new session. func (a *gormAdapter) CreateSession(ctx context.Context, userID uuid.UUID, expires time.Time) (adapters.GothSession, error) { - session := adapters.GothSession{UserID: userID, SessionToken: uuid.NewString(), ExpiresAt: expires} - err := a.db.WithContext(ctx).Create(&session).Error + session := adapters.GothSession{ + UserID: userID, + SessionToken: uuid.NewString(), + ExpiresAt: expires, + CsrfToken: adapters.GothCsrfToken{ + Token: uuid.NewString(), // creates a token that is used to prevent CSRF attacks + ExpiresAt: time.Now().Add(24 * time.Hour), + }, + } + err := a.db.Session(&gorm.Session{FullSaveAssociations: true}).WithContext(ctx).Create(&session).Error if err != nil { return adapters.GothSession{}, goth.ErrBadSession }