-
Notifications
You must be signed in to change notification settings - Fork 5
/
middleware.go
262 lines (226 loc) · 9.77 KB
/
middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
/*
Copyright © 2024 Acronis International GmbH.
Released under MIT license.
*/
package authkit
import (
"context"
"errors"
"net/http"
"strings"
"github.com/acronis/go-appkit/httpserver/middleware"
"github.com/acronis/go-appkit/log"
"github.com/acronis/go-appkit/restapi"
"github.com/acronis/go-authkit/idptoken"
"github.com/acronis/go-authkit/internal/idputil"
"github.com/acronis/go-authkit/internal/metrics"
"github.com/acronis/go-authkit/jwt"
)
// HeaderAuthorization is the name of the HTTP header that carries the data used for authentication and authorization.
// GetBearerTokenFromRequest reads the bearer token from this header.
const HeaderAuthorization = "Authorization"
// Authentication and authorization error codes.
// They are declared with "var" (not "const") because some services may want to use different error codes.
var (
// ErrCodeBearerTokenMissing is the code of the 401 error returned when the request carries no bearer token.
ErrCodeBearerTokenMissing = "bearerTokenMissing"
// ErrCodeAuthenticationFailed is the code of the 401 error returned when token introspection fails,
// the introspected token is not active, or the token cannot be parsed as JWT.
ErrCodeAuthenticationFailed = "authenticationFailed"
// ErrCodeAuthorizationFailed is the code of the 403 error returned when the verifyAccess callback rejects the request.
ErrCodeAuthorizationFailed = "authorizationFailed"
)
// Authentication and authorization error messages.
// They are declared with "var" (not "const") because some services may want to use different error messages.
var (
// ErrMessageBearerTokenMissing is the message sent together with ErrCodeBearerTokenMissing.
ErrMessageBearerTokenMissing = "Authorization bearer token is missing."
// ErrMessageAuthenticationFailed is the message sent together with ErrCodeAuthenticationFailed.
ErrMessageAuthenticationFailed = "Authentication is failed."
// ErrMessageAuthorizationFailed is the message sent together with ErrCodeAuthorizationFailed.
ErrMessageAuthorizationFailed = "Authorization is failed."
)
// ctxKey is an unexported type for the context keys of this package.
// Using a dedicated type keeps these keys from colliding with context keys defined in other packages.
type ctxKey int
const (
// ctxKeyJWTClaims is the key under which NewContextWithJWTClaims stores JWT claims.
ctxKeyJWTClaims ctxKey = iota
// ctxKeyBearerToken is the key under which NewContextWithBearerToken stores the raw bearer token.
ctxKeyBearerToken
)
// JWTParser is an interface for parsing the string representation of a JWT.
// The auth handler calls it when no token introspector is configured
// or when introspection did not produce claims.
type JWTParser interface {
// Parse takes the raw token string and returns its claims, or an error if the token cannot be accepted.
Parse(ctx context.Context, token string) (jwt.Claims, error)
}
// CachingJWTParser does the same as JWTParser but stores parsed JWT claims in cache.
type CachingJWTParser interface {
JWTParser
// InvalidateCache presumably drops the cached claims — NOTE(review): confirm against the implementations.
InvalidateCache(ctx context.Context)
}
// TokenIntrospector is an interface for introspecting tokens.
// The auth handler treats the idptoken.ErrTokenIntrospectionNotNeeded and idptoken.ErrTokenNotIntrospectable
// errors as "fall back to JWT parsing"; any other error fails authentication.
type TokenIntrospector interface {
// IntrospectToken introspects the raw token string and returns the introspection result or an error.
IntrospectToken(ctx context.Context, token string) (idptoken.IntrospectionResult, error)
}
// jwtAuthHandler is the http.Handler produced by JWTAuthMiddleware.
// It authenticates a request by its bearer token and then delegates to next.
type jwtAuthHandler struct {
// next is the wrapped handler; it is called only when authentication (and authorization, if configured) succeeds.
next http.Handler
// errorDomain is the domain put into error responses.
errorDomain string
// jwtParser parses the bearer token when introspection is not configured or did not produce claims.
jwtParser JWTParser
// verifyAccess is an optional authorization callback; when it returns false the request is rejected with 403.
verifyAccess func(r *http.Request, claims jwt.Claims) bool
// tokenIntrospector is optional; when set, the token is introspected before falling back to jwtParser.
tokenIntrospector TokenIntrospector
// loggerProvider supplies the logger for a request from the request context.
loggerProvider func(ctx context.Context) log.FieldLogger
// promMetrics counts token introspection outcomes.
promMetrics *metrics.PrometheusMetrics
}
// jwtAuthMiddlewareOpts collects the optional settings of JWTAuthMiddleware.
// It is filled in by the JWTAuthMiddlewareOption functions.
type jwtAuthMiddlewareOpts struct {
// verifyAccess is the optional authorization callback (see WithJWTAuthMiddlewareVerifyAccess).
verifyAccess func(r *http.Request, claims jwt.Claims) bool
// tokenIntrospector is the optional token introspector (see WithJWTAuthMiddlewareTokenIntrospector).
tokenIntrospector TokenIntrospector
// loggerProvider returns the logger for a request context; JWTAuthMiddleware defaults it to middleware.GetLoggerFromContext.
loggerProvider func(ctx context.Context) log.FieldLogger
// prometheusLibInstanceLabel is passed to metrics.GetPrometheusMetrics when the handler is created.
prometheusLibInstanceLabel string
}
// JWTAuthMiddlewareOption is an option for JWTAuthMiddleware.
// Each option is a function that modifies the middleware settings; JWTAuthMiddleware applies
// the options in the order they are passed.
type JWTAuthMiddlewareOption func(options *jwtAuthMiddlewareOpts)
// WithJWTAuthMiddlewareVerifyAccess returns an option that installs verifyAccess
// as the access verification function of JWTAuthMiddleware.
func WithJWTAuthMiddlewareVerifyAccess(verifyAccess func(r *http.Request, claims jwt.Claims) bool) JWTAuthMiddlewareOption {
	return JWTAuthMiddlewareOption(func(o *jwtAuthMiddlewareOpts) {
		o.verifyAccess = verifyAccess
	})
}
// WithJWTAuthMiddlewareTokenIntrospector returns an option that makes JWTAuthMiddleware
// use the given token introspector.
func WithJWTAuthMiddlewareTokenIntrospector(tokenIntrospector TokenIntrospector) JWTAuthMiddlewareOption {
	apply := func(o *jwtAuthMiddlewareOpts) {
		o.tokenIntrospector = tokenIntrospector
	}
	return apply
}
// WithJWTAuthMiddlewareLoggerProvider returns an option that replaces the function
// JWTAuthMiddleware uses to obtain a logger from the request context.
func WithJWTAuthMiddlewareLoggerProvider(loggerProvider func(ctx context.Context) log.FieldLogger) JWTAuthMiddlewareOption {
	var opt JWTAuthMiddlewareOption = func(o *jwtAuthMiddlewareOpts) {
		o.loggerProvider = loggerProvider
	}
	return opt
}
// WithJWTAuthMiddlewarePrometheusLibInstanceLabel returns an option that sets the label
// for the Prometheus metrics used by JWTAuthMiddleware.
func WithJWTAuthMiddlewarePrometheusLibInstanceLabel(label string) JWTAuthMiddlewareOption {
	setLabel := func(o *jwtAuthMiddlewareOpts) {
		o.prometheusLibInstanceLabel = label
	}
	return setLabel
}
// JWTAuthMiddleware builds a middleware that authenticates incoming requests
// by the Access Token taken from the "Authorization" HTTP header.
//
// errorDomain is put into every error response produced by the middleware. It is usually the name
// of the service that uses the middleware, so that errors coming from different services can be told apart.
// For example, if the "Authorization" HTTP header is missing, the middleware responds with 401 and the body:
//
//	{"error": {"domain": "MyService", "code": "bearerTokenMissing", "message": "Authorization bearer token is missing."}}
//
// jwtParser parses the token when no token introspector is configured or introspection yields no claims.
// opts are applied in order on top of the defaults (the logger is taken from the request context
// via middleware.GetLoggerFromContext unless overridden).
func JWTAuthMiddleware(errorDomain string, jwtParser JWTParser, opts ...JWTAuthMiddlewareOption) func(next http.Handler) http.Handler {
	var cfg jwtAuthMiddlewareOpts
	cfg.loggerProvider = middleware.GetLoggerFromContext
	for i := range opts {
		opts[i](&cfg)
	}

	wrap := func(next http.Handler) http.Handler {
		handler := new(jwtAuthHandler)
		handler.next = next
		handler.errorDomain = errorDomain
		handler.jwtParser = jwtParser
		handler.verifyAccess = cfg.verifyAccess
		handler.tokenIntrospector = cfg.tokenIntrospector
		handler.loggerProvider = cfg.loggerProvider
		// The metrics object is looked up each time a handler is wrapped.
		handler.promMetrics = metrics.GetPrometheusMetrics(cfg.prometheusLibInstanceLabel, metrics.SourceHTTPMiddleware)
		return handler
	}
	return wrap
}
// ServeHTTP authenticates the request by its bearer token and, on success, calls the next handler.
//
// The flow is:
//  1. Extract the bearer token; respond with 401 (bearer token missing) if there is none.
//  2. If a token introspector is configured, introspect the token. "Not needed" and "not introspectable"
//     outcomes fall through to JWT parsing; any other error or an inactive token yields 401.
//  3. If introspection produced no claims, parse the token with jwtParser; a parsing error yields 401.
//  4. If verifyAccess is configured and returns false, respond with 403.
//
// The bearer token and the claims are stored in the request context before the next handler is called.
func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
	logger := idputil.GetLoggerFromProvider(r.Context(), h.loggerProvider)

	token := GetBearerTokenFromRequest(r)
	if token == "" {
		missingErr := restapi.NewError(h.errorDomain, ErrCodeBearerTokenMissing, ErrMessageBearerTokenMissing)
		restapi.RespondError(rw, http.StatusUnauthorized, missingErr, logger)
		return
	}

	// Expose the raw token to the rest of the request handling through the context.
	r = r.WithContext(NewContextWithBearerToken(r.Context(), token))

	var claims jwt.Claims
	if h.tokenIntrospector != nil {
		result, introspectErr := h.tokenIntrospector.IntrospectToken(r.Context(), token)
		switch {
		case introspectErr == nil:
			if !result.IsActive() {
				logger.Warn("token was successfully introspected, but it is not active")
				h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotActive)
				h.respondAuthNFailedError(rw, logger)
				return
			}
			claims = result.GetClaims()
			logger.AtLevel(log.LevelDebug, func(debug log.LogFunc) {
				debug("token was successfully introspected")
			})
			h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusActive)

		case errors.Is(introspectErr, idptoken.ErrTokenIntrospectionNotNeeded):
			// Nothing to do: the Access Token already carries everything required for authN/authZ.
			logger.AtLevel(log.LevelDebug, func(debug log.LogFunc) {
				debug("token's introspection is not needed")
			})
			h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotNeeded)

		case errors.Is(introspectErr, idptoken.ErrTokenNotIntrospectable):
			// The token cannot be introspected for some reason,
			// so it is parsed as JWT below and used for authZ as is.
			logger.Warn("token is not introspectable, it will be used for authentication and authorization as is",
				log.Error(introspectErr))
			h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotIntrospectable)

		case errors.Is(introspectErr, idptoken.ErrTokenIntrospectionInvalidClaims):
			logger.Error("token's introspection failed because of invalid claims", log.Error(introspectErr))
			h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusInvalidClaims)
			h.respondAuthNFailedError(rw, logger)
			return

		default:
			logger.Error("token's introspection failed", log.Error(introspectErr))
			h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusError)
			h.respondAuthNFailedError(rw, logger)
			return
		}
	}

	// Introspection was skipped or gave no claims: fall back to parsing the token as JWT.
	if claims == nil {
		parsed, parseErr := h.jwtParser.Parse(r.Context(), token)
		if parseErr != nil {
			logger.Error("authentication failed", log.Error(parseErr))
			h.respondAuthNFailedError(rw, logger)
			return
		}
		claims = parsed
	}

	// Expose the claims to the rest of the request handling through the context.
	r = r.WithContext(NewContextWithJWTClaims(r.Context(), claims))

	// verifyAccess receives the *http.Request (not just the context) so that its implementations
	// are able to inject new key/value pairs into the request context.
	if h.verifyAccess != nil && !h.verifyAccess(r, claims) {
		forbiddenErr := restapi.NewError(h.errorDomain, ErrCodeAuthorizationFailed, ErrMessageAuthorizationFailed)
		restapi.RespondError(rw, http.StatusForbidden, forbiddenErr, logger)
		return
	}

	h.next.ServeHTTP(rw, r)
}
// respondAuthNFailedError writes the 401 "authentication failed" error response
// using the handler's error domain.
func (h *jwtAuthHandler) respondAuthNFailedError(rw http.ResponseWriter, logger log.FieldLogger) {
	restapi.RespondError(
		rw,
		http.StatusUnauthorized,
		restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed),
		logger,
	)
}
// GetBearerTokenFromRequest extracts the bearer token from the "Authorization" header of the request.
//
// The header value is expected to have the form "Bearer <token>". The "Bearer" auth scheme is matched
// case-insensitively (HTTP auth schemes are case-insensitive), and any whitespace around the token,
// including extra spaces between the scheme and the token, is stripped.
//
// It returns an empty string when the header is absent, uses another auth scheme, or carries no token.
func GetBearerTokenFromRequest(r *http.Request) string {
	const bearerPrefix = "Bearer "

	authHeader := strings.TrimSpace(r.Header.Get(HeaderAuthorization))
	// After trimming, a bare "Bearer" (no token) is shorter than the prefix and is rejected here.
	if len(authHeader) < len(bearerPrefix) {
		return ""
	}
	if !strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
		return ""
	}
	// More than one space may separate the scheme from the token; do not leak them into the token.
	return strings.TrimSpace(authHeader[len(bearerPrefix):])
}
// NewContextWithJWTClaims returns a context derived from ctx that carries the given JWT claims.
// The claims can be read back with GetJWTClaimsFromContext.
func NewContextWithJWTClaims(ctx context.Context, jwtClaims jwt.Claims) context.Context {
	withClaims := context.WithValue(ctx, ctxKeyJWTClaims, jwtClaims)
	return withClaims
}
// GetJWTClaimsFromContext returns the JWT claims stored in ctx,
// or nil if the context carries no claims.
func GetJWTClaimsFromContext(ctx context.Context) jwt.Claims {
	if stored := ctx.Value(ctxKeyJWTClaims); stored != nil {
		return stored.(jwt.Claims)
	}
	return nil
}
// NewContextWithBearerToken returns a context derived from ctx that carries the given token.
// The token can be read back with GetBearerTokenFromContext.
func NewContextWithBearerToken(ctx context.Context, token string) context.Context {
	withToken := context.WithValue(ctx, ctxKeyBearerToken, token)
	return withToken
}
// GetBearerTokenFromContext returns the token stored in ctx,
// or an empty string if the context carries no token.
func GetBearerTokenFromContext(ctx context.Context) string {
	if stored := ctx.Value(ctxKeyBearerToken); stored != nil {
		return stored.(string)
	}
	return ""
}