From 7c2e5d41f0feaae7b50ea0e69cf2767ba17d59df Mon Sep 17 00:00:00 2001 From: Stanislav Chzhen Date: Wed, 2 Jul 2025 23:52:16 +0300 Subject: [PATCH] home: imp code --- internal/home/auth.go | 4 +- internal/home/authhttp.go | 107 +++++++++++------------- internal/home/authhttp_internal_test.go | 102 ++++++++++++---------- internal/home/control.go | 2 +- internal/home/controlinstall.go | 2 +- internal/home/home.go | 2 +- internal/home/profilehttp.go | 12 ++- internal/home/web.go | 13 +-- 8 files changed, 128 insertions(+), 116 deletions(-) diff --git a/internal/home/auth.go b/internal/home/auth.go index 4f166e25..f0e074c0 100644 --- a/internal/home/auth.go +++ b/internal/home/auth.go @@ -179,8 +179,8 @@ func (a *auth) addUser(ctx context.Context, u *webUser, password string) (err er return nil } -// Close closes the authentication database. -func (a *auth) Close(ctx context.Context) { +// close closes the authentication database. +func (a *auth) close(ctx context.Context) { err := a.sessions.Close() if err != nil { a.logger.ErrorContext(ctx, "closing session storage", slogutil.KeyError, err) diff --git a/internal/home/authhttp.go b/internal/home/authhttp.go index efdc1f00..a3e0027a 100644 --- a/internal/home/authhttp.go +++ b/internal/home/authhttp.go @@ -319,6 +319,9 @@ type authMiddlewareDefaultConfig struct { rateLimiter loginRaateLimiter // trustedProxies is a set of subnets considered as trusted. + // + // TODO(s.chzhen): Use it not only to pass it to the middleware but also to + // log the work of the rate limiter. trustedProxies netutil.SubnetSet // sessions contains web user sessions. It must not be nil. @@ -332,9 +335,8 @@ type authMiddlewareDefaultConfig struct { // for a web client using an authentication cookie or basic auth credentials and // passes it with the context. type authMiddlewareDefault struct { - logger *slog.Logger - rateLimiter loginRaateLimiter - // TODO(s.chzhen): !! Use it. + logger *slog.Logger + rateLimiter loginRaateLimiter trustedProxies netutil.SubnetSet sessions aghuser.SessionStorage users aghuser.DB @@ -360,14 +362,13 @@ var _ httputil.Middleware = (*authMiddlewareDefault)(nil) func (mw *authMiddlewareDefault) Wrap(h http.Handler) (wrapped http.Handler) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - if !mw.needsAuthentication(ctx, r) { - h.ServeHTTP(w, r) - - return - } path := r.URL.Path u, err := mw.userFromRequest(ctx, r) + if err != nil { + mw.logger.ErrorContext(ctx, "retrieving user from request", slogutil.KeyError, err) + } + if u != nil { if path == "/login.html" { http.Redirect(w, r, "/", http.StatusFound) @@ -380,31 +381,29 @@ func (mw *authMiddlewareDefault) Wrap(h http.Handler) (wrapped http.Handler) { return } - if path == "/" || path == "index.html" { - http.Redirect(w, r, "login.html", http.StatusFound) + if isPublicResource(path) { + h.ServeHTTP(w, r) return } - if err != nil { - mw.logger.ErrorContext(ctx, "retrieving user from request", slogutil.KeyError, err) + if path == "/" || path == "/index.html" { + http.Redirect(w, r, "login.html", http.StatusFound) + + return } w.WriteHeader(http.StatusUnauthorized) }) } -// needsAuthentication returns true if the current request requires -// authentication. -func (mw *authMiddlewareDefault) needsAuthentication( +// userFromRequest tries to retrieve a user based on the request. r must not be +// nil. +func (mw *authMiddlewareDefault) userFromRequest( ctx context.Context, r *http.Request, -) (ok bool) { - path := r.URL.Path - - if isPublicResource(path) { - return false - } +) (u *aghuser.User, err error) { + defer func() { err = errors.Annotate(err, "getting user from request: %w") }() users, err := mw.users.All(ctx) if err != nil { @@ -413,47 +412,14 @@ func (mw *authMiddlewareDefault) needsAuthentication( } if len(users) == 0 { - return false + return nil, nil } - return true -} - -// userFromRequest tries to retrieve a user based on the request. -func (mw *authMiddlewareDefault) userFromRequest( - ctx context.Context, - r *http.Request, -) (u *aghuser.User, err error) { - defer func() { err = errors.Annotate(err, "getting user from request: %w") }() - cookie, err := r.Cookie(sessionCookieName) if err == nil { return mw.userFromCookie(ctx, cookie.Value) } - var remoteIP string - // realIP cannot be used here without taking TrustedProxies into account due - // to security issues. - // - // See https://github.com/AdguardTeam/AdGuardHome/issues/2799. - if remoteIP, err = netutil.SplitHost(r.RemoteAddr); err != nil { - return nil, fmt.Errorf("getting remote address: %w", err) - } - - rateLimiter := mw.rateLimiter - if left := rateLimiter.check(remoteIP); left > 0 { - return nil, fmt.Errorf("login attempt blocked for %s", left) - } - - rateLimiter.inc(remoteIP) - defer func() { - if err != nil { - return - } - - rateLimiter.remove(remoteIP) - }() - return mw.userFromRequestBasicAuth(ctx, r) } @@ -494,8 +460,7 @@ func sessionTokenFromHex(val string) (token aghuser.SessionToken, err error) { l := aghuser.SessionTokenLength - // TODO(s.chzhen): Use validate.Len. - err = validate.InRange("token length", len(sess), l, l) + err = validate.Equal("token length", l, len(sess)) if err != nil { // Don't wrap the error because it's informative enough as is. return token, err @@ -504,16 +469,40 @@ func sessionTokenFromHex(val string) (token aghuser.SessionToken, err error) { return aghuser.SessionToken(sess), nil } -// userFromRequestBasicAuth searches for a user using Basic Auth credentials. +// userFromRequestBasicAuth searches for a user using Basic Auth credentials. r +// must not be nil. func (mw *authMiddlewareDefault) userFromRequestBasicAuth( ctx context.Context, r *http.Request, ) (user *aghuser.User, err error) { login, pass, ok := r.BasicAuth() if !ok { - return nil, fmt.Errorf("credentials: %w", errors.ErrNoValue) + return nil, nil } + var remoteIP string + // realIP cannot be used here without taking TrustedProxies into account due + // to security issues. + // + // See https://github.com/AdguardTeam/AdGuardHome/issues/2799. + if remoteIP, err = netutil.SplitHost(r.RemoteAddr); err != nil { + return nil, fmt.Errorf("getting remote address: %w", err) + } + + rateLimiter := mw.rateLimiter + if left := rateLimiter.check(remoteIP); left > 0 { + return nil, fmt.Errorf("login attempt blocked for %s", left) + } + + rateLimiter.inc(remoteIP) + defer func() { + if err != nil { + return + } + + rateLimiter.remove(remoteIP) + }() + user, _ = mw.users.ByLogin(ctx, aghuser.Login(login)) if user == nil { return nil, errInvalidLogin diff --git a/internal/home/authhttp_internal_test.go b/internal/home/authhttp_internal_test.go index 2ececaee..93ffa8e8 100644 --- a/internal/home/authhttp_internal_test.go +++ b/internal/home/authhttp_internal_test.go @@ -165,41 +165,18 @@ func (h *testAuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.user, _ = webUserFromContext(r.Context()) } -func TestAuthMiddlewareDefault_firstRun(t *testing.T) { - db := newTestUsersDB() - db.onAll = func(_ context.Context) (users []*aghuser.User, err error) { - return nil, nil - } - - mw := newAuthMiddlewareDefault(&authMiddlewareDefaultConfig{ - logger: testLogger, - rateLimiter: emptyRateLimiter{}, - sessions: &testSessionStorage{}, - users: db, - }) - - h := &testAuthHandler{} - wrapped := mw.Wrap(h) - - w := httptest.NewRecorder() - r := httptest.NewRequest(http.MethodGet, "/", nil) - wrapped.ServeHTTP(w, r) - - assert.Equal(t, http.StatusOK, w.Code) - assert.True(t, h.called) -} - func TestAuthMiddlewareDefault(t *testing.T) { t.Parallel() const ( - login aghuser.Login = "user_login" + loginStr = "user_login" + passwordStr = "user_password" - passwordRaw = "user_password" + login = aghuser.Login(loginStr) ) passwordHash, err := bcrypt.GenerateFromPassword( - []byte(passwordRaw), + []byte(passwordStr), bcrypt.DefaultCost, ) require.NoError(t, err) @@ -244,17 +221,8 @@ func TestAuthMiddlewareDefault(t *testing.T) { users: usersDB, }) - reqCookie := httptest.NewRequest(http.MethodGet, "/", nil) - reqCookie.AddCookie(&http.Cookie{Name: sessionCookieName, Value: tokenHex}) - - reqInvalidCookie := httptest.NewRequest(http.MethodGet, "/", nil) - reqInvalidCookie.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "invalid_cookie"}) - - reqBasicAuth := httptest.NewRequest(http.MethodGet, "/", nil) - reqBasicAuth.SetBasicAuth(string(login), passwordRaw) - - reqInvalidPassBasicAuth := httptest.NewRequest(http.MethodGet, "/", nil) - reqInvalidPassBasicAuth.SetBasicAuth(string(login), "invalid_password") + cookie := &http.Cookie{Name: sessionCookieName, Value: tokenHex} + invalidCookie := &http.Cookie{Name: sessionCookieName, Value: "123"} testCases := []struct { req *http.Request @@ -264,25 +232,55 @@ func TestAuthMiddlewareDefault(t *testing.T) { }{{ req: httptest.NewRequest(http.MethodGet, "/", nil), wantUser: nil, + name: "no_auth_root", + wantCode: http.StatusFound, + }, { + req: httptest.NewRequest(http.MethodGet, "/index.html", nil), + wantUser: nil, name: "no_auth", wantCode: http.StatusFound, }, { - req: reqCookie, + req: authRequest("/", invalidCookie, "", ""), + wantUser: nil, + name: "invalid_auth", + wantCode: http.StatusFound, + }, { + req: authRequest("/", cookie, "", ""), wantUser: user, name: "cookie", wantCode: http.StatusOK, }, { - req: reqBasicAuth, + req: authRequest("/login.html", cookie, "", ""), + wantUser: nil, + name: "redirect", + wantCode: http.StatusFound, + }, { + req: authRequest("/control/profile", cookie, "", ""), + wantUser: user, + name: "protected", + wantCode: http.StatusOK, + }, { + req: authRequest("/control/profile", invalidCookie, "", ""), + wantUser: nil, + name: "no_auth_protected", + wantCode: http.StatusUnauthorized, + }, { + req: httptest.NewRequest(http.MethodGet, "/control/login", nil), + wantUser: nil, + name: "public", + wantCode: http.StatusOK, + }, { + req: authRequest("/", nil, loginStr, passwordStr), wantUser: user, name: "basic_auth", wantCode: http.StatusOK, }, { - req: reqInvalidCookie, + req: authRequest("/", invalidCookie, "", ""), wantUser: nil, name: "invalid_cookie", wantCode: http.StatusFound, }, { - req: reqInvalidPassBasicAuth, + req: authRequest("/", nil, "invalid", "creds"), wantUser: nil, name: "invalid_basic_auth", wantCode: http.StatusFound, @@ -304,6 +302,22 @@ func TestAuthMiddlewareDefault(t *testing.T) { } } +// authRequest is a test helper function that returns a GET request configured +// with the provided credentials and path. +func authRequest(path string, c *http.Cookie, user, pass string) (r *http.Request) { + r = httptest.NewRequest(http.MethodGet, path, nil) + + if c != nil { + r.AddCookie(c) + } + + if user != "" { + r.SetBasicAuth(user, pass) + } + + return r +} + func TestAuth_ServeHTTP_firstRun(t *testing.T) { storeGlobals(t) @@ -458,7 +472,7 @@ func TestAuth_ServeHTTP_auth(t *testing.T) { }) require.NoError(t, err) - t.Cleanup(func() { auth.Close(testutil.ContextWithTimeout(t, testTimeout)) }) + t.Cleanup(func() { auth.close(testutil.ContextWithTimeout(t, testTimeout)) }) globalContext.mux = http.NewServeMux() @@ -606,7 +620,7 @@ func TestAuth_ServeHTTP_logout(t *testing.T) { }) require.NoError(t, err) - t.Cleanup(func() { auth.Close(testutil.ContextWithTimeout(t, testTimeout)) }) + t.Cleanup(func() { auth.close(testutil.ContextWithTimeout(t, testTimeout)) }) globalContext.mux = http.NewServeMux() diff --git a/internal/home/control.go b/internal/home/control.go index dcaa7023..45168805 100644 --- a/internal/home/control.go +++ b/internal/home/control.go @@ -177,7 +177,7 @@ func registerControlHandlers(web *webAPI) { httpRegister(http.MethodGet, "/control/status", web.handleStatus) httpRegister(http.MethodPost, "/control/i18n/change_language", handleI18nChangeLanguage) httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage) - httpRegister(http.MethodGet, "/control/profile", handleGetProfile) + httpRegister(http.MethodGet, "/control/profile", web.handleGetProfile) httpRegister(http.MethodPut, "/control/profile/update", handlePutProfile) // No auth is necessary for DoH/DoT configurations diff --git a/internal/home/controlinstall.go b/internal/home/controlinstall.go index 590e40ad..2681e990 100644 --- a/internal/home/controlinstall.go +++ b/internal/home/controlinstall.go @@ -490,7 +490,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request // and with its own context, because it waits until all requests are handled // and will be blocked by it's own caller. go func(timeout time.Duration) { - shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout) + shutdownCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), timeout) defer slogutil.RecoverAndLog(shutdownCtx, web.logger) defer cancel() diff --git a/internal/home/home.go b/internal/home/home.go index d2ec6752..5baaf2ea 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -812,7 +812,7 @@ func initUsers( blockDur := time.Duration(config.AuthBlockMin) * time.Minute rateLimiter = newAuthRateLimiter(blockDur, config.AuthAttempts) } else { - baseLogger.InfoContext(ctx, "authratelimiter is disabled") + baseLogger.WarnContext(ctx, "authratelimiter is disabled") rateLimiter = emptyRateLimiter{} } diff --git a/internal/home/profilehttp.go b/internal/home/profilehttp.go index 50048f19..8c3d6ef0 100644 --- a/internal/home/profilehttp.go +++ b/internal/home/profilehttp.go @@ -46,10 +46,16 @@ type profileJSON struct { } // handleGetProfile is the handler for GET /control/profile endpoint. -func handleGetProfile(w http.ResponseWriter, r *http.Request) { +func (web *webAPI) handleGetProfile(w http.ResponseWriter, r *http.Request) { var name string - u, ok := webUserFromContext(r.Context()) - if ok { + if !web.auth.isGLiNet { + u, ok := webUserFromContext(r.Context()) + if !ok { + w.WriteHeader(http.StatusUnauthorized) + + return + } + name = string(u.Login) } diff --git a/internal/home/web.go b/internal/home/web.go index b453d840..96d4852f 100644 --- a/internal/home/web.go +++ b/internal/home/web.go @@ -221,7 +221,10 @@ func (web *webAPI) start(ctx context.Context) { errs := make(chan error, 2) // Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies. - hdlr := h2c.NewHandler(withMiddlewares(globalContext.mux, limitRequestBody), &http2.Server{}) + hdlr := h2c.NewHandler( + withMiddlewares(globalContext.mux, limitRequestBody), + &http2.Server{}, + ) logger := web.baseLogger.With(loggerKeyServer, "plain") @@ -232,7 +235,7 @@ func (web *webAPI) start(ctx context.Context) { // Create a new instance, because the Web is not usable after Shutdown. web.httpServer = &http.Server{ Addr: web.conf.BindAddr.String(), - Handler: globalContext.auth.middleware().Wrap(hdlr), + Handler: web.auth.middleware().Wrap(hdlr), ReadTimeout: web.conf.ReadTimeout, ReadHeaderTimeout: web.conf.ReadHeaderTimeout, WriteTimeout: web.conf.WriteTimeout, @@ -274,7 +277,7 @@ func (web *webAPI) close(ctx context.Context) { shutdownSrv(ctx, web.logger, web.httpServer) if web.auth != nil { - web.auth.Close(ctx) + web.auth.close(ctx) } web.logger.InfoContext(ctx, "stopped http server") @@ -318,7 +321,7 @@ func (web *webAPI) tlsServerLoop(ctx context.Context) { web.httpsServer.server = &http.Server{ Addr: addr, - Handler: globalContext.auth.middleware().Wrap(hdlr), + Handler: web.auth.middleware().Wrap(hdlr), TLSConfig: &tls.Config{ Certificates: []tls.Certificate{web.httpsServer.cert}, RootCAs: web.tlsManager.rootCerts, @@ -359,7 +362,7 @@ func (web *webAPI) mustStartHTTP3(ctx context.Context, address string) { CipherSuites: web.tlsManager.customCipherIDs, MinVersion: tls.VersionTLS12, }, - Handler: globalContext.auth.middleware().Wrap(withMiddlewares(globalContext.mux, limitRequestBody)), + Handler: web.auth.middleware().Wrap(withMiddlewares(globalContext.mux, limitRequestBody)), } web.logger.DebugContext(ctx, "starting http/3 server")