@@ -10,7 +10,6 @@ import (
1010type CORSOption func (* cors ) error
1111
1212type cors struct {
13- h http.Handler
1413 allowedHeaders []string
1514 allowedMethods []string
1615 allowedOrigins []string
@@ -47,93 +46,95 @@ const (
4746 corsOriginMatchAll string = "*"
4847)
4948
50- func (ch * cors ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
51- origin := r .Header .Get (corsOriginHeader )
52- if ! ch .isOriginAllowed (origin ) {
53- if r .Method != corsOptionMethod || ch .ignoreOptions {
54- ch .h .ServeHTTP (w , r )
55- }
56-
57- return
58- }
59-
60- if r .Method == corsOptionMethod {
61- if ch .ignoreOptions {
62- ch .h .ServeHTTP (w , r )
63- return
64- }
49+ func (ch * cors ) wrap (h http.Handler ) http.Handler {
50+ return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
51+ origin := r .Header .Get (corsOriginHeader )
52+ if ! ch .isOriginAllowed (origin ) {
53+ if r .Method != corsOptionMethod || ch .ignoreOptions {
54+ h .ServeHTTP (w , r )
55+ }
6556
66- if _ , ok := r .Header [corsRequestMethodHeader ]; ! ok {
67- w .WriteHeader (http .StatusBadRequest )
6857 return
6958 }
7059
71- method := r . Header . Get ( corsRequestMethodHeader )
72- if ! ch .isMatch ( method , ch . allowedMethods ) {
73- w . WriteHeader ( http . StatusMethodNotAllowed )
74- return
75- }
60+ if r . Method == corsOptionMethod {
61+ if ch .ignoreOptions {
62+ h . ServeHTTP ( w , r )
63+ return
64+ }
7665
77- requestHeaders := strings .Split (r .Header .Get (corsRequestHeadersHeader ), "," )
78- allowedHeaders := []string {}
79- for _ , v := range requestHeaders {
80- canonicalHeader := http .CanonicalHeaderKey (strings .TrimSpace (v ))
81- if canonicalHeader == "" || ch .isMatch (canonicalHeader , defaultCorsHeaders ) {
82- continue
66+ if _ , ok := r .Header [corsRequestMethodHeader ]; ! ok {
67+ w .WriteHeader (http .StatusBadRequest )
68+ return
8369 }
8470
85- if ! ch .isMatch (canonicalHeader , ch .allowedHeaders ) {
86- w .WriteHeader (http .StatusForbidden )
71+ method := r .Header .Get (corsRequestMethodHeader )
72+ if ! ch .isMatch (method , ch .allowedMethods ) {
73+ w .WriteHeader (http .StatusMethodNotAllowed )
8774 return
8875 }
8976
90- allowedHeaders = append (allowedHeaders , canonicalHeader )
91- }
77+ requestHeaders := strings .Split (r .Header .Get (corsRequestHeadersHeader ), "," )
78+ allowedHeaders := []string {}
79+ for _ , v := range requestHeaders {
80+ canonicalHeader := http .CanonicalHeaderKey (strings .TrimSpace (v ))
81+ if canonicalHeader == "" || ch .isMatch (canonicalHeader , defaultCorsHeaders ) {
82+ continue
83+ }
9284
93- if len (allowedHeaders ) > 0 {
94- w .Header ().Set (corsAllowHeadersHeader , strings .Join (allowedHeaders , "," ))
95- }
85+ if ! ch .isMatch (canonicalHeader , ch .allowedHeaders ) {
86+ w .WriteHeader (http .StatusForbidden )
87+ return
88+ }
9689
97- if ch .maxAge > 0 {
98- w .Header ().Set (corsMaxAgeHeader , strconv .Itoa (ch .maxAge ))
99- }
90+ allowedHeaders = append (allowedHeaders , canonicalHeader )
91+ }
92+
93+ if len (allowedHeaders ) > 0 {
94+ w .Header ().Set (corsAllowHeadersHeader , strings .Join (allowedHeaders , "," ))
95+ }
96+
97+ if ch .maxAge > 0 {
98+ w .Header ().Set (corsMaxAgeHeader , strconv .Itoa (ch .maxAge ))
99+ }
100100
101- if ! ch .isMatch (method , defaultCorsMethods ) {
102- w .Header ().Set (corsAllowMethodsHeader , method )
101+ if ! ch .isMatch (method , defaultCorsMethods ) {
102+ w .Header ().Set (corsAllowMethodsHeader , method )
103+ }
104+ } else if len (ch .exposedHeaders ) > 0 {
105+ w .Header ().Set (corsExposeHeadersHeader , strings .Join (ch .exposedHeaders , "," ))
103106 }
104- } else if len (ch .exposedHeaders ) > 0 {
105- w .Header ().Set (corsExposeHeadersHeader , strings .Join (ch .exposedHeaders , "," ))
106- }
107107
108- if ch .allowCredentials {
109- w .Header ().Set (corsAllowCredentialsHeader , "true" )
110- }
108+ if ch .allowCredentials {
109+ w .Header ().Set (corsAllowCredentialsHeader , "true" )
110+ }
111111
112- if len (ch .allowedOrigins ) > 1 {
113- w .Header ().Set (corsVaryHeader , corsOriginHeader )
114- }
112+ if len (ch .allowedOrigins ) > 1 {
113+ w .Header ().Set (corsVaryHeader , corsOriginHeader )
114+ }
115115
116- returnOrigin := origin
117- if ch .allowedOriginValidator == nil && len (ch .allowedOrigins ) == 0 {
118- returnOrigin = "*"
119- } else {
120- for _ , o := range ch .allowedOrigins {
121- // A configuration of * is different than explicitly setting an allowed
122- // origin. Returning arbitrary origin headers in an access control allow
123- // origin header is unsafe and is not required by any use case.
124- if o == corsOriginMatchAll {
125- returnOrigin = "*"
126- break
116+ returnOrigin := origin
117+ if ch .allowedOriginValidator == nil && len (ch .allowedOrigins ) == 0 {
118+ returnOrigin = "*"
119+ } else {
120+ for _ , o := range ch .allowedOrigins {
121+ // A configuration of * is different than explicitly setting an allowed
122+ // origin. Returning arbitrary origin headers in an access control allow
123+ // origin header is unsafe and is not required by any use case.
124+ if o == corsOriginMatchAll {
125+ returnOrigin = "*"
126+ break
127+ }
127128 }
128129 }
129- }
130- w .Header ().Set (corsAllowOriginHeader , returnOrigin )
130+ w .Header ().Set (corsAllowOriginHeader , returnOrigin )
131131
132- if r .Method == corsOptionMethod {
133- w .WriteHeader (ch .optionStatusCode )
134- return
135- }
136- ch .h .ServeHTTP (w , r )
132+ if r .Method == corsOptionMethod {
133+ w .WriteHeader (ch .optionStatusCode )
134+ return
135+ }
136+ h .ServeHTTP (w , r )
137+ })
137138}
138139
139140// CORS provides Cross-Origin Resource Sharing middleware.
@@ -155,11 +156,7 @@ func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
155156// http.ListenAndServe(":8000", handlers.CORS()(r))
156157// }
157158func CORS (opts ... CORSOption ) func (http.Handler ) http.Handler {
158- return func (h http.Handler ) http.Handler {
159- ch := parseCORSOptions (opts ... )
160- ch .h = h
161- return ch
162- }
159+ return parseCORSOptions (opts ... ).wrap
163160}
164161
165162func parseCORSOptions (opts ... CORSOption ) * cors {
0 commit comments