diff --git a/echo.go b/echo.go index 0bb64d214..7e440d37f 100644 --- a/echo.go +++ b/echo.go @@ -232,9 +232,12 @@ const ( HeaderXCorrelationID = "X-Correlation-Id" HeaderXRequestedWith = "X-Requested-With" HeaderServer = "Server" - HeaderOrigin = "Origin" - HeaderCacheControl = "Cache-Control" - HeaderConnection = "Connection" + + // HeaderOrigin request header indicates the origin (scheme, hostname, and port) that caused the request. + // See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin + HeaderOrigin = "Origin" + HeaderCacheControl = "Cache-Control" + HeaderConnection = "Connection" // Access control HeaderAccessControlRequestMethod = "Access-Control-Request-Method" @@ -255,6 +258,11 @@ const ( HeaderContentSecurityPolicyReportOnly = "Content-Security-Policy-Report-Only" HeaderXCSRFToken = "X-CSRF-Token" HeaderReferrerPolicy = "Referrer-Policy" + + // HeaderSecFetchSite fetch metadata request header indicates the relationship between a request initiator's + // origin and the origin of the requested resource. + // See: https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Sec-Fetch-Site + HeaderSecFetchSite = "Sec-Fetch-Site" ) const ( diff --git a/middleware/csrf.go b/middleware/csrf.go index 92f4019dc..f9d3293b0 100644 --- a/middleware/csrf.go +++ b/middleware/csrf.go @@ -6,6 +6,8 @@ package middleware import ( "crypto/subtle" "net/http" + "slices" + "strings" "time" "github.com/labstack/echo/v4" @@ -16,6 +18,22 @@ type CSRFConfig struct { // Skipper defines a function to skip middleware. Skipper Skipper + // TrustedOrigin permits any request with `Sec-Fetch-Site` header whose `Origin` header + // exactly matches the specified value. + // Values should be formated as Origin header "scheme://host[:port]". + // + // See [Origin]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin + // See [Sec-Fetch-Site]: https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers + TrustedOrigins []string + + // AllowSecFetchSameSite allows custom behaviour for `Sec-Fetch-Site` requests that are about to + // fail with CRSF error, to be allowed or replaced with custom error. + // This function applies to `Sec-Fetch-Site` values: + // - `same-site` same registrable domain (subdomain and/or different port) + // - `cross-site` request originates from different site + // See [Sec-Fetch-Site]: https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers + AllowSecFetchSiteFunc func(c echo.Context) (bool, error) + // TokenLength is the length of the generated token. TokenLength uint8 `yaml:"token_length"` // Optional. Default value 32. @@ -94,7 +112,11 @@ func CSRF() echo.MiddlewareFunc { // CSRFWithConfig returns a CSRF middleware with config. // See `CSRF()`. func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { - // Defaults + return toMiddlewareOrPanic(config) +} + +// ToMiddleware converts CSRFConfig to middleware or returns an error for invalid configuration +func (config CSRFConfig) ToMiddleware() (echo.MiddlewareFunc, error) { if config.Skipper == nil { config.Skipper = DefaultCSRFConfig.Skipper } @@ -117,10 +139,16 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { if config.CookieSameSite == http.SameSiteNoneMode { config.CookieSecure = true } + if len(config.TrustedOrigins) > 0 { + if vErr := validateOrigins(config.TrustedOrigins, "trusted origin"); vErr != nil { + return nil, vErr + } + config.TrustedOrigins = append([]string(nil), config.TrustedOrigins...) + } extractors, cErr := CreateExtractors(config.TokenLookup) if cErr != nil { - panic(cErr) + return nil, cErr } return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -129,6 +157,17 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { return next(c) } + // use the `Sec-Fetch-Site` header as part of a modern approach to CSRF protection + allow, err := config.checkSecFetchSiteRequest(c) + if err != nil { + return err + } + if allow { + return next(c) + } + + // Fallback to legacy token based CSRF protection + token := "" if k, err := c.Cookie(config.CookieName); err != nil { token = randomString(config.TokenLength) @@ -210,9 +249,55 @@ func CSRFWithConfig(config CSRFConfig) echo.MiddlewareFunc { return next(c) } - } + }, nil } func validateCSRFToken(token, clientToken string) bool { return subtle.ConstantTimeCompare([]byte(token), []byte(clientToken)) == 1 } + +var safeMethods = []string{http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodTrace} + +func (config CSRFConfig) checkSecFetchSiteRequest(c echo.Context) (bool, error) { + // https://cheatsheetseries.owasp.org/cheatsheets/Cross-Site_Request_Forgery_Prevention_Cheat_Sheet.html#fetch-metadata-headers + // Sec-Fetch-Site values are: + // - `same-origin` exact origin match - allow always + // - `same-site` same registrable domain (subdomain and/or different port) - block, unless explicitly trusted + // - `cross-site` request originates from different site - block, unless explicitly trusted + // - `none` direct navigation (URL bar, bookmark) - allow always + secFetchSite := c.Request().Header.Get(echo.HeaderSecFetchSite) + if secFetchSite == "" { + return false, nil + } + + if len(config.TrustedOrigins) > 0 { + // trusted sites ala OAuth callbacks etc. should be let through + origin := c.Request().Header.Get(echo.HeaderOrigin) + if origin != "" { + for _, trustedOrigin := range config.TrustedOrigins { + if strings.EqualFold(origin, trustedOrigin) { + return true, nil + } + } + } + } + isSafe := slices.Contains(safeMethods, c.Request().Method) + if !isSafe { // for state-changing request check SecFetchSite value + isSafe = secFetchSite == "same-origin" || secFetchSite == "none" + } + + if isSafe { + return true, nil + } + // we are here when request is state-changing and `cross-site` or `same-site` + + // Note: if you want to block `same-site` use config.TrustedOrigins or `config.AllowSecFetchSiteFunc` + if config.AllowSecFetchSiteFunc != nil { + return config.AllowSecFetchSiteFunc(c) + } + + if secFetchSite == "same-site" { + return false, nil // fall back to legacy token + } + return false, echo.NewHTTPError(http.StatusForbidden, "cross-site request blocked by CSRF") +} diff --git a/middleware/csrf_test.go b/middleware/csrf_test.go index 98e5d04f6..85b7f1077 100644 --- a/middleware/csrf_test.go +++ b/middleware/csrf_test.go @@ -4,6 +4,7 @@ package middleware import ( + "cmp" "net/http" "net/http/httptest" "net/url" @@ -16,15 +17,16 @@ import ( func TestCSRF_tokenExtractors(t *testing.T) { var testCases = []struct { - name string - whenTokenLookup string - whenCookieName string - givenCSRFCookie string - givenMethod string - givenQueryTokens map[string][]string - givenFormTokens map[string][]string - givenHeaderTokens map[string][]string - expectError string + name string + whenTokenLookup string + whenCookieName string + givenCSRFCookie string + givenMethod string + givenQueryTokens map[string][]string + givenFormTokens map[string][]string + givenHeaderTokens map[string][]string + expectError string + expectToMiddlewareError string }{ { name: "ok, multiple token lookups sources, succeeds on last one", @@ -146,6 +148,14 @@ func TestCSRF_tokenExtractors(t *testing.T) { givenQueryTokens: map[string][]string{}, expectError: "code=400, message=missing csrf token in the query string", }, + { + name: "nok, invalid TokenLookup", + whenTokenLookup: "q", + givenCSRFCookie: "token", + givenMethod: http.MethodPut, + givenQueryTokens: map[string][]string{}, + expectToMiddlewareError: "extractor source for lookup could not be split into needed parts: q", + }, } for _, tc := range testCases { @@ -188,16 +198,23 @@ func TestCSRF_tokenExtractors(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - csrf := CSRFWithConfig(CSRFConfig{ + config := CSRFConfig{ TokenLookup: tc.whenTokenLookup, CookieName: tc.whenCookieName, - }) + } + csrf, err := config.ToMiddleware() + if tc.expectToMiddlewareError != "" { + assert.EqualError(t, err, tc.expectToMiddlewareError) + return + } else if err != nil { + assert.NoError(t, err) + } h := csrf(func(c echo.Context) error { return c.String(http.StatusOK, "test") }) - err := h(c) + err = h(c) if tc.expectError != "" { assert.EqualError(t, err, tc.expectError) } else { @@ -207,6 +224,125 @@ func TestCSRF_tokenExtractors(t *testing.T) { } } +func TestCSRFWithConfig(t *testing.T) { + token := randomString(16) + + var testCases = []struct { + name string + givenConfig *CSRFConfig + whenMethod string + whenHeaders map[string]string + expectEmptyBody bool + expectMWError string + expectCookieContains string + expectErr string + }{ + { + name: "ok, GET", + whenMethod: http.MethodGet, + expectCookieContains: "_csrf", + }, + { + name: "ok, POST valid token", + whenHeaders: map[string]string{ + echo.HeaderCookie: "_csrf=" + token, + echo.HeaderXCSRFToken: token, + }, + whenMethod: http.MethodPost, + expectCookieContains: "_csrf", + }, + { + name: "nok, POST without token", + whenMethod: http.MethodPost, + expectEmptyBody: true, + expectErr: `code=400, message=missing csrf token in request header`, + }, + { + name: "nok, POST empty token", + whenHeaders: map[string]string{echo.HeaderXCSRFToken: ""}, + whenMethod: http.MethodPost, + expectEmptyBody: true, + expectErr: `code=403, message=invalid csrf token`, + }, + { + name: "nok, invalid trusted origin in Config", + givenConfig: &CSRFConfig{ + TrustedOrigins: []string{"http://example.com", "invalid"}, + }, + expectMWError: `trusted origin is missing scheme or host: invalid`, + }, + { + name: "ok, TokenLength", + givenConfig: &CSRFConfig{ + TokenLength: 16, + }, + whenMethod: http.MethodGet, + expectCookieContains: "_csrf", + }, + { + name: "ok, unsafe method + SecFetchSite=same-origin passes", + whenHeaders: map[string]string{ + echo.HeaderSecFetchSite: "same-origin", + }, + whenMethod: http.MethodPost, + }, + { + name: "nok, unsafe method + SecFetchSite=same-cross blocked", + whenHeaders: map[string]string{ + echo.HeaderSecFetchSite: "same-cross", + }, + whenMethod: http.MethodPost, + expectEmptyBody: true, + expectErr: `code=403, message=cross-site request blocked by CSRF`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + e := echo.New() + + req := httptest.NewRequest(cmp.Or(tc.whenMethod, http.MethodPost), "/", nil) + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + + for key, value := range tc.whenHeaders { + req.Header.Set(key, value) + } + + config := CSRFConfig{} + if tc.givenConfig != nil { + config = *tc.givenConfig + } + mw, err := config.ToMiddleware() + if tc.expectMWError != "" { + assert.EqualError(t, err, tc.expectMWError) + return + } + assert.NoError(t, err) + + h := mw(func(c echo.Context) error { + return c.String(http.StatusOK, "test") + }) + + err = h(c) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + + expect := "test" + if tc.expectEmptyBody { + expect = "" + } + assert.Equal(t, expect, rec.Body.String()) + + if tc.expectCookieContains != "" { + assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), tc.expectCookieContains) + } + }) + } +} + func TestCSRF(t *testing.T) { e := echo.New() req := httptest.NewRequest(http.MethodGet, "/", nil) @@ -221,26 +357,6 @@ func TestCSRF(t *testing.T) { h(c) assert.Contains(t, rec.Header().Get(echo.HeaderSetCookie), "_csrf") - // Without CSRF cookie - req = httptest.NewRequest(http.MethodPost, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - assert.Error(t, h(c)) - - // Empty/invalid CSRF token - req = httptest.NewRequest(http.MethodPost, "/", nil) - rec = httptest.NewRecorder() - c = e.NewContext(req, rec) - req.Header.Set(echo.HeaderXCSRFToken, "") - assert.Error(t, h(c)) - - // Valid CSRF token - token := randomString(32) - req.Header.Set(echo.HeaderCookie, "_csrf="+token) - req.Header.Set(echo.HeaderXCSRFToken, token) - if assert.NoError(t, h(c)) { - assert.Equal(t, http.StatusOK, rec.Code) - } } func TestCSRFSetSameSiteMode(t *testing.T) { @@ -304,9 +420,10 @@ func TestCSRFWithSameSiteModeNone(t *testing.T) { rec := httptest.NewRecorder() c := e.NewContext(req, rec) - csrf := CSRFWithConfig(CSRFConfig{ + csrf, err := CSRFConfig{ CookieSameSite: http.SameSiteNoneMode, - }) + }.ToMiddleware() + assert.NoError(t, err) h := csrf(func(c echo.Context) error { return c.String(http.StatusOK, "test") @@ -382,3 +499,354 @@ func TestCSRFErrorHandling(t *testing.T) { assert.Equal(t, http.StatusTeapot, res.Code) assert.Equal(t, "{\"message\":\"error_handler_executed\"}\n", res.Body.String()) } + +func TestCSRFConfig_checkSecFetchSiteRequest(t *testing.T) { + var testCases = []struct { + name string + givenConfig CSRFConfig + whenMethod string + whenSecFetchSite string + whenOrigin string + expectAllow bool + expectErr string + }{ + { + name: "ok, unsafe POST, no SecFetchSite is not blocked", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPost, + whenSecFetchSite: "", + expectAllow: false, // should fall back to token CSRF + }, + { + name: "ok, safe GET + same-origin passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodGet, + whenSecFetchSite: "same-origin", + expectAllow: true, + }, + { + name: "ok, safe GET + none passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodGet, + whenSecFetchSite: "none", + expectAllow: true, + }, + { + name: "ok, safe GET + same-site passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodGet, + whenSecFetchSite: "same-site", + expectAllow: true, + }, + { + name: "ok, safe GET + cross-site passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodGet, + whenSecFetchSite: "cross-site", + expectAllow: true, + }, + { + name: "nok, unsafe POST + cross-site is blocked", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPost, + whenSecFetchSite: "cross-site", + expectAllow: false, + expectErr: `code=403, message=cross-site request blocked by CSRF`, + }, + { + name: "nok, unsafe POST + same-site is blocked", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPost, + whenSecFetchSite: "same-site", + expectAllow: false, + expectErr: ``, + }, + { + name: "ok, unsafe POST + same-origin passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPost, + whenSecFetchSite: "same-origin", + expectAllow: true, + }, + { + name: "ok, unsafe POST + none passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPost, + whenSecFetchSite: "none", + expectAllow: true, + }, + { + name: "ok, unsafe PUT + same-origin passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPut, + whenSecFetchSite: "same-origin", + expectAllow: true, + }, + { + name: "ok, unsafe PUT + none passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPut, + whenSecFetchSite: "none", + expectAllow: true, + }, + { + name: "ok, unsafe DELETE + same-origin passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodDelete, + whenSecFetchSite: "same-origin", + expectAllow: true, + }, + { + name: "ok, unsafe PATCH + same-origin passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPatch, + whenSecFetchSite: "same-origin", + expectAllow: true, + }, + { + name: "nok, unsafe PUT + cross-site is blocked", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPut, + whenSecFetchSite: "cross-site", + expectAllow: false, + expectErr: `code=403, message=cross-site request blocked by CSRF`, + }, + { + name: "nok, unsafe PUT + same-site is blocked", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPut, + whenSecFetchSite: "same-site", + expectAllow: false, + expectErr: ``, + }, + { + name: "nok, unsafe DELETE + cross-site is blocked", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodDelete, + whenSecFetchSite: "cross-site", + expectAllow: false, + expectErr: `code=403, message=cross-site request blocked by CSRF`, + }, + { + name: "nok, unsafe DELETE + same-site is blocked", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodDelete, + whenSecFetchSite: "same-site", + expectAllow: false, + expectErr: ``, + }, + { + name: "nok, unsafe PATCH + cross-site is blocked", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPatch, + whenSecFetchSite: "cross-site", + expectAllow: false, + expectErr: `code=403, message=cross-site request blocked by CSRF`, + }, + { + name: "ok, safe HEAD + same-origin passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodHead, + whenSecFetchSite: "same-origin", + expectAllow: true, + }, + { + name: "ok, safe HEAD + cross-site passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodHead, + whenSecFetchSite: "cross-site", + expectAllow: true, + }, + { + name: "ok, safe OPTIONS + cross-site passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodOptions, + whenSecFetchSite: "cross-site", + expectAllow: true, + }, + { + name: "ok, safe TRACE + cross-site passes", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodTrace, + whenSecFetchSite: "cross-site", + expectAllow: true, + }, + { + name: "ok, unsafe POST + cross-site + matching trusted origin passes", + givenConfig: CSRFConfig{ + TrustedOrigins: []string{"https://trusted.example.com"}, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "cross-site", + whenOrigin: "https://trusted.example.com", + expectAllow: true, + }, + { + name: "ok, unsafe POST + same-site + matching trusted origin passes", + givenConfig: CSRFConfig{ + TrustedOrigins: []string{"https://trusted.example.com"}, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "same-site", + whenOrigin: "https://trusted.example.com", + expectAllow: true, + }, + { + name: "nok, unsafe POST + cross-site + non-matching origin is blocked", + givenConfig: CSRFConfig{ + TrustedOrigins: []string{"https://trusted.example.com"}, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "cross-site", + whenOrigin: "https://evil.example.com", + expectAllow: false, + expectErr: `code=403, message=cross-site request blocked by CSRF`, + }, + { + name: "ok, unsafe POST + cross-site + case-insensitive trusted origin match passes", + givenConfig: CSRFConfig{ + TrustedOrigins: []string{"https://trusted.example.com"}, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "cross-site", + whenOrigin: "https://TRUSTED.example.com", + expectAllow: true, + }, + { + name: "ok, unsafe POST + same-origin + trusted origins configured but not matched passes", + givenConfig: CSRFConfig{ + TrustedOrigins: []string{"https://trusted.example.com"}, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "same-origin", + whenOrigin: "https://different.example.com", + expectAllow: true, + }, + { + name: "nok, unsafe POST + cross-site + empty origin + trusted origins configured is blocked", + givenConfig: CSRFConfig{ + TrustedOrigins: []string{"https://trusted.example.com"}, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "cross-site", + whenOrigin: "", + expectAllow: false, + expectErr: `code=403, message=cross-site request blocked by CSRF`, + }, + { + name: "ok, unsafe POST + cross-site + multiple trusted origins, second one matches", + givenConfig: CSRFConfig{ + TrustedOrigins: []string{"https://first.example.com", "https://second.example.com"}, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "cross-site", + whenOrigin: "https://second.example.com", + expectAllow: true, + }, + { + name: "ok, unsafe POST + same-site + custom func allows", + givenConfig: CSRFConfig{ + AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { + return true, nil + }, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "same-site", + expectAllow: true, + }, + { + name: "ok, unsafe POST + cross-site + custom func allows", + givenConfig: CSRFConfig{ + AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { + return true, nil + }, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "cross-site", + expectAllow: true, + }, + { + name: "nok, unsafe POST + same-site + custom func returns custom error", + givenConfig: CSRFConfig{ + AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { + return false, echo.NewHTTPError(http.StatusTeapot, "custom error from func") + }, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "same-site", + expectAllow: false, + expectErr: `code=418, message=custom error from func`, + }, + { + name: "nok, unsafe POST + cross-site + custom func returns false with nil error", + givenConfig: CSRFConfig{ + AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { + return false, nil + }, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "cross-site", + expectAllow: false, + expectErr: "", // custom func returns nil error, so no error expected + }, + { + name: "nok, unsafe POST + invalid Sec-Fetch-Site value treated as cross-site", + givenConfig: CSRFConfig{}, + whenMethod: http.MethodPost, + whenSecFetchSite: "invalid-value", + expectAllow: false, + expectErr: `code=403, message=cross-site request blocked by CSRF`, + }, + { + name: "ok, unsafe POST + cross-site + trusted origin takes precedence over custom func", + givenConfig: CSRFConfig{ + TrustedOrigins: []string{"https://trusted.example.com"}, + AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { + return false, echo.NewHTTPError(http.StatusTeapot, "should not be called") + }, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "cross-site", + whenOrigin: "https://trusted.example.com", + expectAllow: true, + }, + { + name: "nok, unsafe POST + cross-site + trusted origin not matched, custom func blocks", + givenConfig: CSRFConfig{ + TrustedOrigins: []string{"https://trusted.example.com"}, + AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { + return false, echo.NewHTTPError(http.StatusTeapot, "custom block") + }, + }, + whenMethod: http.MethodPost, + whenSecFetchSite: "cross-site", + whenOrigin: "https://evil.example.com", + expectAllow: false, + expectErr: `code=418, message=custom block`, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest(tc.whenMethod, "/", nil) + if tc.whenSecFetchSite != "" { + req.Header.Set(echo.HeaderSecFetchSite, tc.whenSecFetchSite) + } + if tc.whenOrigin != "" { + req.Header.Set(echo.HeaderOrigin, tc.whenOrigin) + } + + res := httptest.NewRecorder() + e := echo.New() + c := e.NewContext(req, res) + + allow, err := tc.givenConfig.checkSecFetchSiteRequest(c) + + assert.Equal(t, tc.expectAllow, allow) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/middleware/middleware.go b/middleware/middleware.go index 6f33cc5c1..164e52b4c 100644 --- a/middleware/middleware.go +++ b/middleware/middleware.go @@ -88,3 +88,13 @@ func rewriteURL(rewriteRegex map[*regexp.Regexp]string, req *http.Request) error func DefaultSkipper(echo.Context) bool { return false } + +func toMiddlewareOrPanic(config interface { + ToMiddleware() (echo.MiddlewareFunc, error) +}) echo.MiddlewareFunc { + mw, err := config.ToMiddleware() + if err != nil { + panic(err) + } + return mw +} diff --git a/middleware/util.go b/middleware/util.go index 09428eb0b..5813990a5 100644 --- a/middleware/util.go +++ b/middleware/util.go @@ -6,7 +6,9 @@ package middleware import ( "bufio" "crypto/rand" + "fmt" "io" + "net/url" "strings" "sync" ) @@ -101,3 +103,26 @@ func randomString(length uint8) string { } } } + +func validateOrigins(origins []string, what string) error { + for _, o := range origins { + if err := validateOrigin(o, what); err != nil { + return err + } + } + return nil +} + +func validateOrigin(origin string, what string) error { + u, err := url.Parse(origin) + if err != nil { + return fmt.Errorf("can not parse %s: %w", what, err) + } + if u.Scheme == "" || u.Host == "" { + return fmt.Errorf("%s is missing scheme or host: %s", what, origin) + } + if u.Path != "" || u.RawQuery != "" || u.Fragment != "" { + return fmt.Errorf("%s can not have path, query, and fragments: %s", what, origin) + } + return nil +} diff --git a/middleware/util_test.go b/middleware/util_test.go index b54f12627..1c171f5a5 100644 --- a/middleware/util_test.go +++ b/middleware/util_test.go @@ -149,3 +149,209 @@ func TestRandomStringBias(t *testing.T) { } } } + +func TestValidateOrigins(t *testing.T) { + var testCases = []struct { + name string + givenOrigins []string + givenWhat string + expectErr string + }{ + // Valid cases + { + name: "ok, empty origins", + givenOrigins: []string{}, + }, + { + name: "ok, basic http", + givenOrigins: []string{"http://example.com"}, + }, + { + name: "ok, basic https", + givenOrigins: []string{"https://example.com"}, + }, + { + name: "ok, with port", + givenOrigins: []string{"http://localhost:8080"}, + }, + { + name: "ok, with subdomain", + givenOrigins: []string{"https://api.example.com"}, + }, + { + name: "ok, subdomain with port", + givenOrigins: []string{"https://api.example.com:8080"}, + }, + { + name: "ok, localhost", + givenOrigins: []string{"http://localhost"}, + }, + { + name: "ok, IPv4 address", + givenOrigins: []string{"http://192.168.1.1"}, + }, + { + name: "ok, IPv4 with port", + givenOrigins: []string{"http://192.168.1.1:8080"}, + }, + { + name: "ok, IPv6 loopback", + givenOrigins: []string{"http://[::1]"}, + }, + { + name: "ok, IPv6 with port", + givenOrigins: []string{"http://[::1]:8080"}, + }, + { + name: "ok, IPv6 full address", + givenOrigins: []string{"http://[2001:db8::1]"}, + }, + { + name: "ok, multiple valid origins", + givenOrigins: []string{"http://example.com", "https://api.example.com:8080"}, + }, + { + name: "ok, different schemes", + givenOrigins: []string{"http://example.com", "https://example.com", "ws://example.com"}, + }, + // Invalid - missing scheme + { + name: "nok, plain domain", + givenOrigins: []string{"example.com"}, + expectErr: "trusted origin is missing scheme or host: example.com", + }, + { + name: "nok, with slashes but no scheme", + givenOrigins: []string{"//example.com"}, + expectErr: "trusted origin is missing scheme or host: //example.com", + }, + { + name: "nok, www without scheme", + givenOrigins: []string{"www.example.com"}, + expectErr: "trusted origin is missing scheme or host: www.example.com", + }, + { + name: "nok, localhost without scheme", + givenOrigins: []string{"localhost:8080"}, + expectErr: "trusted origin is missing scheme or host: localhost:8080", + }, + // Invalid - missing host + { + name: "nok, scheme only http", + givenOrigins: []string{"http://"}, + expectErr: "trusted origin is missing scheme or host: http://", + }, + { + name: "nok, scheme only https", + givenOrigins: []string{"https://"}, + expectErr: "trusted origin is missing scheme or host: https://", + }, + // Invalid - has path + { + name: "nok, has simple path", + givenOrigins: []string{"http://example.com/path"}, + expectErr: "trusted origin can not have path, query, and fragments: http://example.com/path", + }, + { + name: "nok, has nested path", + givenOrigins: []string{"https://example.com/api/v1"}, + expectErr: "trusted origin can not have path, query, and fragments: https://example.com/api/v1", + }, + { + name: "nok, has root path", + givenOrigins: []string{"http://example.com/"}, + expectErr: "trusted origin can not have path, query, and fragments: http://example.com/", + }, + // Invalid - has query + { + name: "nok, has single query param", + givenOrigins: []string{"http://example.com?foo=bar"}, + expectErr: "trusted origin can not have path, query, and fragments: http://example.com?foo=bar", + }, + { + name: "nok, has multiple query params", + givenOrigins: []string{"https://example.com?foo=bar&baz=qux"}, + expectErr: "trusted origin can not have path, query, and fragments: https://example.com?foo=bar&baz=qux", + }, + // Invalid - has fragment + { + name: "nok, has simple fragment", + givenOrigins: []string{"http://example.com#section"}, + expectErr: "trusted origin can not have path, query, and fragments: http://example.com#section", + }, + // Invalid - combinations + { + name: "nok, has path and query", + givenOrigins: []string{"http://example.com/path?foo=bar"}, + expectErr: "trusted origin can not have path, query, and fragments: http://example.com/path?foo=bar", + }, + { + name: "nok, has path and fragment", + givenOrigins: []string{"http://example.com/path#section"}, + expectErr: "trusted origin can not have path, query, and fragments: http://example.com/path#section", + }, + { + name: "nok, has query and fragment", + givenOrigins: []string{"http://example.com?foo=bar#section"}, + expectErr: "trusted origin can not have path, query, and fragments: http://example.com?foo=bar#section", + }, + { + name: "nok, has path, query, and fragment", + givenOrigins: []string{"http://example.com/path?foo=bar#section"}, + expectErr: "trusted origin can not have path, query, and fragments: http://example.com/path?foo=bar#section", + }, + // Edge cases + { + name: "nok, empty string", + givenOrigins: []string{""}, + expectErr: "trusted origin is missing scheme or host: ", + }, + { + name: "nok, whitespace only", + givenOrigins: []string{" "}, + expectErr: "trusted origin is missing scheme or host: ", + }, + { + name: "nok, multiple origins - first invalid", + givenOrigins: []string{"example.com", "http://valid.com"}, + expectErr: "trusted origin is missing scheme or host: example.com", + }, + { + name: "nok, multiple origins - middle invalid", + givenOrigins: []string{"http://valid1.com", "invalid.com", "http://valid2.com"}, + expectErr: "trusted origin is missing scheme or host: invalid.com", + }, + { + name: "nok, multiple origins - last invalid", + givenOrigins: []string{"http://valid.com", "invalid.com"}, + expectErr: "trusted origin is missing scheme or host: invalid.com", + }, + // Different "what" parameter + { + name: "nok, custom what parameter - missing scheme", + givenOrigins: []string{"example.com"}, + givenWhat: "allowed origin", + expectErr: "allowed origin is missing scheme or host: example.com", + }, + { + name: "nok, custom what parameter - has path", + givenOrigins: []string{"http://example.com/path"}, + givenWhat: "cors origin", + expectErr: "cors origin can not have path, query, and fragments: http://example.com/path", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + what := tc.givenWhat + if what == "" { + what = "trusted origin" + } + err := validateOrigins(tc.givenOrigins, what) + if tc.expectErr != "" { + assert.EqualError(t, err, tc.expectErr) + } else { + assert.NoError(t, err) + } + }) + } +}