From 4915852b9c8c030c086a259a18b1343b9f2e5c10 Mon Sep 17 00:00:00 2001 From: fatedier Date: Mon, 29 May 2023 00:27:27 +0800 Subject: [PATCH] use constant time comparison (#3452) --- client/admin.go | 2 +- pkg/auth/token.go | 12 ++++++------ pkg/nathole/controller.go | 2 +- pkg/plugin/client/http_proxy.go | 6 +++++- pkg/plugin/client/static_file.go | 3 ++- pkg/util/net/http.go | 32 ++++++++++++++++---------------- pkg/util/util/util.go | 5 +++++ server/dashboard.go | 2 +- test/e2e/framework/framework.go | 8 ++++---- test/e2e/framework/process.go | 4 ++-- test/e2e/pkg/port/port.go | 4 ++-- 11 files changed, 45 insertions(+), 35 deletions(-) diff --git a/client/admin.go b/client/admin.go index f96e1bc1..949ab8ad 100644 --- a/client/admin.go +++ b/client/admin.go @@ -48,7 +48,7 @@ func (svr *Service) RunAdminServer(address string) (err error) { subRouter := router.NewRoute().Subrouter() user, passwd := svr.cfg.AdminUser, svr.cfg.AdminPwd - subRouter.Use(frpNet.NewHTTPAuthMiddleware(user, passwd).Middleware) + subRouter.Use(frpNet.NewHTTPAuthMiddleware(user, passwd).SetAuthFailDelay(200 * time.Millisecond).Middleware) // api, see admin_api.go subRouter.HandleFunc("/api/reload", svr.apiReload).Methods("GET") diff --git a/pkg/auth/token.go b/pkg/auth/token.go index 1049174d..9b2e3f7c 100644 --- a/pkg/auth/token.go +++ b/pkg/auth/token.go @@ -73,30 +73,30 @@ func (auth *TokenAuthSetterVerifier) SetNewWorkConn(newWorkConnMsg *msg.NewWorkC return nil } -func (auth *TokenAuthSetterVerifier) VerifyLogin(loginMsg *msg.Login) error { - if util.GetAuthKey(auth.token, loginMsg.Timestamp) != loginMsg.PrivilegeKey { +func (auth *TokenAuthSetterVerifier) VerifyLogin(m *msg.Login) error { + if !util.ConstantTimeEqString(util.GetAuthKey(auth.token, m.Timestamp), m.PrivilegeKey) { return fmt.Errorf("token in login doesn't match token from configuration") } return nil } -func (auth *TokenAuthSetterVerifier) VerifyPing(pingMsg *msg.Ping) error { +func (auth *TokenAuthSetterVerifier) VerifyPing(m *msg.Ping) error { if !auth.AuthenticateHeartBeats { return nil } - if util.GetAuthKey(auth.token, pingMsg.Timestamp) != pingMsg.PrivilegeKey { + if !util.ConstantTimeEqString(util.GetAuthKey(auth.token, m.Timestamp), m.PrivilegeKey) { return fmt.Errorf("token in heartbeat doesn't match token from configuration") } return nil } -func (auth *TokenAuthSetterVerifier) VerifyNewWorkConn(newWorkConnMsg *msg.NewWorkConn) error { +func (auth *TokenAuthSetterVerifier) VerifyNewWorkConn(m *msg.NewWorkConn) error { if !auth.AuthenticateNewWorkConns { return nil } - if util.GetAuthKey(auth.token, newWorkConnMsg.Timestamp) != newWorkConnMsg.PrivilegeKey { + if !util.ConstantTimeEqString(util.GetAuthKey(auth.token, m.Timestamp), m.PrivilegeKey) { return fmt.Errorf("token in NewWorkConn doesn't match token from configuration") } return nil diff --git a/pkg/nathole/controller.go b/pkg/nathole/controller.go index 71feb1be..a04006b9 100644 --- a/pkg/nathole/controller.go +++ b/pkg/nathole/controller.go @@ -174,7 +174,7 @@ func (c *Controller) HandleVisitor(m *msg.NatHoleVisitor, transporter transport. if !ok { return fmt.Errorf("xtcp server for [%s] doesn't exist", m.ProxyName) } - if m.SignKey != util.GetAuthKey(clientCfg.sk, m.Timestamp) { + if !util.ConstantTimeEqString(m.SignKey, util.GetAuthKey(clientCfg.sk, m.Timestamp)) { return fmt.Errorf("xtcp connection of [%s] auth failed", m.ProxyName) } c.sessions[sid] = session diff --git a/pkg/plugin/client/http_proxy.go b/pkg/plugin/client/http_proxy.go index 78930e39..86542045 100644 --- a/pkg/plugin/client/http_proxy.go +++ b/pkg/plugin/client/http_proxy.go @@ -21,11 +21,13 @@ import ( "net" "net/http" "strings" + "time" frpIo "github.com/fatedier/golib/io" gnet "github.com/fatedier/golib/net" frpNet "github.com/fatedier/frp/pkg/util/net" + "github.com/fatedier/frp/pkg/util/util" ) const PluginHTTPProxy = "http_proxy" @@ -179,7 +181,9 @@ func (hp *HTTPProxy) Auth(req *http.Request) bool { return false } - if pair[0] != hp.AuthUser || pair[1] != hp.AuthPasswd { + if !util.ConstantTimeEqString(pair[0], hp.AuthUser) || + !util.ConstantTimeEqString(pair[1], hp.AuthPasswd) { + time.Sleep(200 * time.Millisecond) return false } return true diff --git a/pkg/plugin/client/static_file.go b/pkg/plugin/client/static_file.go index 4d31ea53..097af060 100644 --- a/pkg/plugin/client/static_file.go +++ b/pkg/plugin/client/static_file.go @@ -18,6 +18,7 @@ import ( "io" "net" "net/http" + "time" "github.com/gorilla/mux" @@ -64,7 +65,7 @@ func NewStaticFilePlugin(params map[string]string) (Plugin, error) { } router := mux.NewRouter() - router.Use(frpNet.NewHTTPAuthMiddleware(httpUser, httpPasswd).Middleware) + router.Use(frpNet.NewHTTPAuthMiddleware(httpUser, httpPasswd).SetAuthFailDelay(200 * time.Millisecond).Middleware) router.PathPrefix(prefix).Handler(frpNet.MakeHTTPGzipHandler(http.StripPrefix(prefix, http.FileServer(http.Dir(localPath))))).Methods("GET") sp.s = &http.Server{ Handler: router, diff --git a/pkg/util/net/http.go b/pkg/util/net/http.go index fa1c34af..1a7da23f 100644 --- a/pkg/util/net/http.go +++ b/pkg/util/net/http.go @@ -19,6 +19,9 @@ import ( "io" "net/http" "strings" + "time" + + "github.com/fatedier/frp/pkg/util/util" ) type HTTPAuthWraper struct { @@ -46,8 +49,9 @@ func (aw *HTTPAuthWraper) ServeHTTP(w http.ResponseWriter, r *http.Request) { } type HTTPAuthMiddleware struct { - user string - passwd string + user string + passwd string + authFailDelay time.Duration } func NewHTTPAuthMiddleware(user, passwd string) *HTTPAuthMiddleware { @@ -57,32 +61,28 @@ func NewHTTPAuthMiddleware(user, passwd string) *HTTPAuthMiddleware { } } +func (authMid *HTTPAuthMiddleware) SetAuthFailDelay(delay time.Duration) *HTTPAuthMiddleware { + authMid.authFailDelay = delay + return authMid +} + func (authMid *HTTPAuthMiddleware) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { reqUser, reqPasswd, hasAuth := r.BasicAuth() if (authMid.user == "" && authMid.passwd == "") || - (hasAuth && reqUser == authMid.user && reqPasswd == authMid.passwd) { + (hasAuth && util.ConstantTimeEqString(reqUser, authMid.user) && + util.ConstantTimeEqString(reqPasswd, authMid.passwd)) { next.ServeHTTP(w, r) } else { + if authMid.authFailDelay > 0 { + time.Sleep(authMid.authFailDelay) + } w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) } }) } -func HTTPBasicAuth(h http.HandlerFunc, user, passwd string) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - reqUser, reqPasswd, hasAuth := r.BasicAuth() - if (user == "" && passwd == "") || - (hasAuth && reqUser == user && reqPasswd == passwd) { - h.ServeHTTP(w, r) - } else { - w.Header().Set("WWW-Authenticate", `Basic realm="Restricted"`) - http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) - } - } -} - type HTTPGzipWraper struct { h http.Handler } diff --git a/pkg/util/util/util.go b/pkg/util/util/util.go index b437f3e3..d2562b01 100644 --- a/pkg/util/util/util.go +++ b/pkg/util/util/util.go @@ -17,6 +17,7 @@ package util import ( "crypto/md5" "crypto/rand" + "crypto/subtle" "encoding/hex" "fmt" mathrand "math/rand" @@ -139,3 +140,7 @@ func RandomSleep(duration time.Duration, minRatio, maxRatio float64) time.Durati time.Sleep(d) return d } + +func ConstantTimeEqString(a, b string) bool { + return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1 +} diff --git a/server/dashboard.go b/server/dashboard.go index 02c86325..ff941399 100644 --- a/server/dashboard.go +++ b/server/dashboard.go @@ -50,7 +50,7 @@ func (svr *Service) RunDashboardServer(address string) (err error) { subRouter := router.NewRoute().Subrouter() user, passwd := svr.cfg.DashboardUser, svr.cfg.DashboardPwd - subRouter.Use(frpNet.NewHTTPAuthMiddleware(user, passwd).Middleware) + subRouter.Use(frpNet.NewHTTPAuthMiddleware(user, passwd).SetAuthFailDelay(200 * time.Millisecond).Middleware) // metrics if svr.cfg.EnablePrometheus { diff --git a/test/e2e/framework/framework.go b/test/e2e/framework/framework.go index b2f65f2f..6a7a655f 100644 --- a/test/e2e/framework/framework.go +++ b/test/e2e/framework/framework.go @@ -66,8 +66,8 @@ func NewDefaultFramework() *Framework { options := Options{ TotalParallelNode: suiteConfig.ParallelTotal, CurrentNodeIndex: suiteConfig.ParallelProcess, - FromPortIndex: 20000, - ToPortIndex: 50000, + FromPortIndex: 10000, + ToPortIndex: 60000, } return NewFramework(options) } @@ -118,14 +118,14 @@ func (f *Framework) AfterEach() { // stop processor for _, p := range f.serverProcesses { _ = p.Stop() - if TestContext.Debug { + if TestContext.Debug || ginkgo.CurrentSpecReport().Failed() { fmt.Println(p.ErrorOutput()) fmt.Println(p.StdOutput()) } } for _, p := range f.clientProcesses { _ = p.Stop() - if TestContext.Debug { + if TestContext.Debug || ginkgo.CurrentSpecReport().Failed() { fmt.Println(p.ErrorOutput()) fmt.Println(p.StdOutput()) } diff --git a/test/e2e/framework/process.go b/test/e2e/framework/process.go index 6c0eeea4..dba809dd 100644 --- a/test/e2e/framework/process.go +++ b/test/e2e/framework/process.go @@ -38,7 +38,7 @@ func (f *Framework) RunProcesses(serverTemplates []string, clientTemplates []str err = p.Start() ExpectNoError(err) } - time.Sleep(2 * time.Second) + time.Sleep(1 * time.Second) currentClientProcesses := make([]*process.Process, 0, len(clientTemplates)) for i := range clientTemplates { @@ -56,7 +56,7 @@ func (f *Framework) RunProcesses(serverTemplates []string, clientTemplates []str ExpectNoError(err) time.Sleep(500 * time.Millisecond) } - time.Sleep(5 * time.Second) + time.Sleep(2 * time.Second) return currentServerProcesses, currentClientProcesses } diff --git a/test/e2e/pkg/port/port.go b/test/e2e/pkg/port/port.go index f1136942..b9bcccfc 100644 --- a/test/e2e/pkg/port/port.go +++ b/test/e2e/pkg/port/port.go @@ -58,7 +58,7 @@ func (pa *Allocator) GetByName(portName string) int { return 0 } - l, err := net.Listen("tcp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port))) + l, err := net.Listen("tcp", net.JoinHostPort("0.0.0.0", strconv.Itoa(port))) if err != nil { // Maybe not controlled by us, mark it used. pa.used.Insert(port) @@ -66,7 +66,7 @@ func (pa *Allocator) GetByName(portName string) int { } l.Close() - udpAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort("127.0.0.1", strconv.Itoa(port))) + udpAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort("0.0.0.0", strconv.Itoa(port))) if err != nil { continue }