forked from gin-contrib/cors
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cors.go
198 lines (163 loc) · 5.71 KB
/
cors.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
package cors
import (
"errors"
"fmt"
"strings"
"time"
"github.com/gin-gonic/gin"
)
// Config represents all available options for the middleware.
type Config struct {
// AllowAllOrigins allows requests from every origin. Validate rejects a
// configuration that sets it together with AllowOrigins, AllowOriginFunc
// or AllowOriginWithContextFunc.
AllowAllOrigins bool
// AllowOrigins is a list of origins a cross-domain request can be executed from.
// If the special "*" value is present in the list, all origins will be allowed.
// Validate requires every entry to either contain '*' or start with one of
// the allowed schemas.
// Default value is []
AllowOrigins []string
// AllowOriginFunc is a custom function to validate the origin. It takes the origin
// as an argument and returns true if allowed or false otherwise. If this option is
// set, the content of AllowOrigins is ignored.
AllowOriginFunc func(origin string) bool
// AllowOriginWithContextFunc is the same as AllowOriginFunc except that it
// also receives the full request context.
// This function should use the context as a read-only source and must not
// have any side effects on the request, such as aborting or injecting
// values on the request.
AllowOriginWithContextFunc func(c *gin.Context, origin string) bool
// AllowMethods is a list of methods the client is allowed to use with
// cross-domain requests. DefaultConfig sets GET, POST, PUT, PATCH, DELETE,
// HEAD and OPTIONS.
AllowMethods []string
// AllowPrivateNetwork indicates whether the response should include the
// allow-private-network header.
AllowPrivateNetwork bool
// AllowHeaders is a list of non-simple headers the client is allowed to use
// with cross-domain requests. DefaultConfig sets Origin, Content-Length and
// Content-Type.
AllowHeaders []string
// AllowCredentials indicates whether the request can include user credentials like
// cookies, HTTP authentication or client side SSL certificates.
AllowCredentials bool
// ExposeHeaders indicates which response headers are safe to expose to the
// API of a CORS API specification.
ExposeHeaders []string
// MaxAge indicates how long (with second-precision) the results of a preflight request
// can be cached. DefaultConfig sets 12 hours.
MaxAge time.Duration
// AllowWildcard allows origins like http://some-domain/*, https://api.* or
// http://some.*.subdomain.com in AllowOrigins. Each origin may contain at
// most one '*'; parseWildcardRules panics otherwise.
AllowWildcard bool
// AllowBrowserExtensions allows usage of popular browser extension schemas
// (ExtensionSchemas is added to the allowed schemas).
AllowBrowserExtensions bool
// CustomSchemas adds custom schemas, like tauri://, to the allowed schemas.
CustomSchemas []string
// AllowWebSockets allows usage of the WebSocket protocol (WebSocketSchemas
// is added to the allowed schemas).
AllowWebSockets bool
// AllowFiles allows usage of the file:// schema (dangerous!). Use it only
// when you are 100% sure it is needed.
AllowFiles bool
// OptionsResponseStatusCode allows passing a custom OPTIONS response status
// code for old browsers / clients.
OptionsResponseStatusCode int
}
// AddAllowMethods appends the given methods to c.AllowMethods.
func (c *Config) AddAllowMethods(methods ...string) {
	if len(methods) == 0 {
		// Nothing to add; the list is left exactly as it was.
		return
	}
	c.AllowMethods = append(c.AllowMethods, methods...)
}
// AddAllowHeaders appends the given headers to c.AllowHeaders.
func (c *Config) AddAllowHeaders(headers ...string) {
	if len(headers) == 0 {
		// Nothing to add; the list is left exactly as it was.
		return
	}
	c.AllowHeaders = append(c.AllowHeaders, headers...)
}
// AddExposeHeaders appends the given headers to c.ExposeHeaders.
func (c *Config) AddExposeHeaders(headers ...string) {
	if len(headers) == 0 {
		// Nothing to add; the list is left exactly as it was.
		return
	}
	c.ExposeHeaders = append(c.ExposeHeaders, headers...)
}
// getAllowedSchemas returns the origin schemas accepted by this configuration:
// DefaultSchemas plus, when the corresponding option is enabled,
// ExtensionSchemas, WebSocketSchemas and FileSchemas, followed by any
// CustomSchemas.
//
// The result is always a freshly allocated slice. Appending directly to the
// package-level DefaultSchemas could write into its backing array whenever it
// has spare capacity, and returning it unchanged would let a caller mutate
// shared package state.
func (c Config) getAllowedSchemas() []string {
	allowedSchemas := make([]string, 0, len(DefaultSchemas)+len(c.CustomSchemas))
	allowedSchemas = append(allowedSchemas, DefaultSchemas...)
	if c.AllowBrowserExtensions {
		allowedSchemas = append(allowedSchemas, ExtensionSchemas...)
	}
	if c.AllowWebSockets {
		allowedSchemas = append(allowedSchemas, WebSocketSchemas...)
	}
	if c.AllowFiles {
		allowedSchemas = append(allowedSchemas, FileSchemas...)
	}
	// Appending a nil or empty CustomSchemas is a no-op, so no guard is needed.
	return append(allowedSchemas, c.CustomSchemas...)
}
// validateAllowedSchemas reports whether origin begins with any of the
// schemas returned by getAllowedSchemas.
func (c Config) validateAllowedSchemas(origin string) bool {
	matched := false
	for _, prefix := range c.getAllowedSchemas() {
		if strings.HasPrefix(origin, prefix) {
			matched = true
			break
		}
	}
	return matched
}
// Validate checks the user-supplied configuration for consistency.
//
// It returns an error when:
//   - AllowAllOrigins is set together with AllowOrigins, AllowOriginFunc or
//     AllowOriginWithContextFunc;
//   - AllowAllOrigins is unset and none of those three origin sources is set;
//   - an AllowOrigins entry neither contains '*' nor starts with one of the
//     schemas returned by getAllowedSchemas.
//
// It returns nil otherwise.
func (c Config) Validate() error {
	hasOriginFn := c.AllowOriginFunc != nil || c.AllowOriginWithContextFunc != nil
	hasOriginSource := hasOriginFn || len(c.AllowOrigins) > 0

	if c.AllowAllOrigins && hasOriginSource {
		// The names listed here must match the Config field names so the
		// message points the user at the right settings.
		originFields := strings.Join([]string{
			"AllowOriginFunc",
			"AllowOriginWithContextFunc",
			"AllowOrigins",
		}, " or ")
		return fmt.Errorf(
			"conflict settings: all origins enabled. %s is not needed",
			originFields,
		)
	}
	if !c.AllowAllOrigins && !hasOriginSource {
		return errors.New("conflict settings: all origins disabled")
	}

	for _, origin := range c.AllowOrigins {
		if strings.Contains(origin, "*") || c.validateAllowedSchemas(origin) {
			continue
		}
		return errors.New("bad origin: origins must contain '*' or include " + strings.Join(c.getAllowedSchemas(), ","))
	}
	return nil
}
// parseWildcardRules turns every AllowOrigins entry that contains '*' into a
// two-element rule {prefix, suffix}, split around the wildcard. A leading
// wildcard yields {"*", rest} and a trailing wildcard yields {rest, "*"}.
//
// It returns nil when AllowWildcard is false or no entry contains '*', and
// panics with "only one * is allowed" if an entry has more than one wildcard.
func (c Config) parseWildcardRules() [][]string {
	if !c.AllowWildcard {
		return nil
	}

	var rules [][]string
	for _, origin := range c.AllowOrigins {
		star := strings.Index(origin, "*")
		if star < 0 {
			continue
		}
		if strings.Count(origin, "*") > 1 {
			panic("only one * is allowed")
		}

		prefix, suffix := origin[:star], origin[star+1:]
		switch star {
		case 0:
			// Leading wildcard (this case also covers the bare "*" origin).
			prefix = "*"
		case len(origin) - 1:
			// Trailing wildcard.
			suffix = "*"
		}
		rules = append(rules, []string{prefix, suffix})
	}
	return rules
}
// DefaultConfig returns a baseline configuration: GET, POST, PUT, PATCH,
// DELETE, HEAD and OPTIONS as allowed methods; Origin, Content-Length and
// Content-Type as allowed headers; credentials disabled; and a MaxAge of
// 12 hours. It sets no origin option.
func DefaultConfig() Config {
	var cfg Config
	cfg.AllowMethods = []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"}
	cfg.AllowHeaders = []string{"Origin", "Content-Length", "Content-Type"}
	cfg.AllowCredentials = false
	cfg.MaxAge = 12 * time.Hour
	return cfg
}
// Default returns the CORS middleware built from DefaultConfig with
// AllowAllOrigins switched on.
func Default() gin.HandlerFunc {
	cfg := DefaultConfig()
	cfg.AllowAllOrigins = true

	return New(cfg)
}
// New returns the CORS middleware built from the user-defined configuration.
// The configuration is handed to newCors once, and the returned handler
// applies the result to each request.
func New(config Config) gin.HandlerFunc {
	handler := newCors(config)

	return func(ctx *gin.Context) {
		handler.applyCors(ctx)
	}
}