diff --git a/middleware.go b/middleware.go index 89dcd3f7..3d79b639 100644 --- a/middleware.go +++ b/middleware.go @@ -19,6 +19,8 @@ type JWTMiddleware struct { validateOnOptions bool } +type JWTMiddlewares []*JWTMiddleware + // ValidateToken takes in a string JWT and makes sure it is valid and // returns the valid token. If it is not valid it will return nil and // an error message describing why validation failed. @@ -90,3 +92,66 @@ func (m *JWTMiddleware) CheckJWT(next http.Handler) http.Handler { next.ServeHTTP(w, r) }) } + +// CheckJWTMulti is the main JWTMiddleware function which performs the main logic. It +// is passed a http.Handler which will be called if the JWT passes validation for one +// of the JWTMiddleware configs in a slice. +func (mm JWTMiddlewares) CheckJWTMulti(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for i := 0; i < len(mm); i++ { + m := mm[i] + isLast := true + if (i + 1) == len(mm) { + isLast = true + } else { + isLast = false + } + // If we don't validate on OPTIONS and this is OPTIONS + // then continue onto next without validating. + if !m.validateOnOptions && r.Method == http.MethodOptions { + next.ServeHTTP(w, r) + return + } + + token, err := m.tokenExtractor(r) + if err != nil { + // This is not ErrJWTMissing because an error here means that the + // tokenExtractor had an error and _not_ that the token was missing. + m.errorHandler(w, r, fmt.Errorf("error extracting token: %w", err)) + return + } + + if token == "" { + // If credentials are optional continue + // onto next without validating. + if m.credentialsOptional { + next.ServeHTTP(w, r) + return + } + + if !isLast { + continue + } + // Credentials were not optional so we error. + m.errorHandler(w, r, ErrJWTMissing) + return + } + + // Validate the token using the token validator. + validToken, err := m.validateToken(r.Context(), token) + if err != nil { + if !isLast { + continue + } + m.errorHandler(w, r, &invalidError{details: err}) + return + } + + // No err means we have a valid token, so set + // it into the context and continue onto next. + r = r.Clone(context.WithValue(r.Context(), ContextKey{}, validToken)) + next.ServeHTTP(w, r) + return + } + }) +}