home: imp code

This commit is contained in:
Stanislav Chzhen 2025-07-02 23:52:16 +03:00
parent b90031495a
commit 7c2e5d41f0
8 changed files with 128 additions and 116 deletions

View File

@ -179,8 +179,8 @@ func (a *auth) addUser(ctx context.Context, u *webUser, password string) (err er
return nil
}
// Close closes the authentication database.
func (a *auth) Close(ctx context.Context) {
// close closes the authentication database.
func (a *auth) close(ctx context.Context) {
err := a.sessions.Close()
if err != nil {
a.logger.ErrorContext(ctx, "closing session storage", slogutil.KeyError, err)

View File

@ -319,6 +319,9 @@ type authMiddlewareDefaultConfig struct {
rateLimiter loginRaateLimiter
// trustedProxies is a set of subnets considered as trusted.
//
// TODO(s.chzhen): Use it not only to pass it to the middleware but also to
// log the work of the rate limiter.
trustedProxies netutil.SubnetSet
// sessions contains web user sessions. It must not be nil.
@ -332,9 +335,8 @@ type authMiddlewareDefaultConfig struct {
// for a web client using an authentication cookie or basic auth credentials and
// passes it with the context.
type authMiddlewareDefault struct {
logger *slog.Logger
rateLimiter loginRaateLimiter
// TODO(s.chzhen): !! Use it.
logger *slog.Logger
rateLimiter loginRaateLimiter
trustedProxies netutil.SubnetSet
sessions aghuser.SessionStorage
users aghuser.DB
@ -360,14 +362,13 @@ var _ httputil.Middleware = (*authMiddlewareDefault)(nil)
func (mw *authMiddlewareDefault) Wrap(h http.Handler) (wrapped http.Handler) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if !mw.needsAuthentication(ctx, r) {
h.ServeHTTP(w, r)
return
}
path := r.URL.Path
u, err := mw.userFromRequest(ctx, r)
if err != nil {
mw.logger.ErrorContext(ctx, "retrieving user from request", slogutil.KeyError, err)
}
if u != nil {
if path == "/login.html" {
http.Redirect(w, r, "/", http.StatusFound)
@ -380,31 +381,29 @@ func (mw *authMiddlewareDefault) Wrap(h http.Handler) (wrapped http.Handler) {
return
}
if path == "/" || path == "index.html" {
http.Redirect(w, r, "login.html", http.StatusFound)
if isPublicResource(path) {
h.ServeHTTP(w, r)
return
}
if err != nil {
mw.logger.ErrorContext(ctx, "retrieving user from request", slogutil.KeyError, err)
if path == "/" || path == "/index.html" {
http.Redirect(w, r, "login.html", http.StatusFound)
return
}
w.WriteHeader(http.StatusUnauthorized)
})
}
// needsAuthentication returns true if the current request requires
// authentication.
func (mw *authMiddlewareDefault) needsAuthentication(
// userFromRequest tries to retrieve a user based on the request. r must not be
// nil.
func (mw *authMiddlewareDefault) userFromRequest(
ctx context.Context,
r *http.Request,
) (ok bool) {
path := r.URL.Path
if isPublicResource(path) {
return false
}
) (u *aghuser.User, err error) {
defer func() { err = errors.Annotate(err, "getting user from request: %w") }()
users, err := mw.users.All(ctx)
if err != nil {
@ -413,47 +412,14 @@ func (mw *authMiddlewareDefault) needsAuthentication(
}
if len(users) == 0 {
return false
return nil, nil
}
return true
}
// userFromRequest tries to retrieve a user based on the request.
func (mw *authMiddlewareDefault) userFromRequest(
ctx context.Context,
r *http.Request,
) (u *aghuser.User, err error) {
defer func() { err = errors.Annotate(err, "getting user from request: %w") }()
cookie, err := r.Cookie(sessionCookieName)
if err == nil {
return mw.userFromCookie(ctx, cookie.Value)
}
var remoteIP string
// realIP cannot be used here without taking TrustedProxies into account due
// to security issues.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2799.
if remoteIP, err = netutil.SplitHost(r.RemoteAddr); err != nil {
return nil, fmt.Errorf("getting remote address: %w", err)
}
rateLimiter := mw.rateLimiter
if left := rateLimiter.check(remoteIP); left > 0 {
return nil, fmt.Errorf("login attempt blocked for %s", left)
}
rateLimiter.inc(remoteIP)
defer func() {
if err != nil {
return
}
rateLimiter.remove(remoteIP)
}()
return mw.userFromRequestBasicAuth(ctx, r)
}
@ -494,8 +460,7 @@ func sessionTokenFromHex(val string) (token aghuser.SessionToken, err error) {
l := aghuser.SessionTokenLength
// TODO(s.chzhen): Use validate.Len.
err = validate.InRange("token length", len(sess), l, l)
err = validate.Equal("token length", l, len(sess))
if err != nil {
// Don't wrap the error because it's informative enough as is.
return token, err
@ -504,16 +469,40 @@ func sessionTokenFromHex(val string) (token aghuser.SessionToken, err error) {
return aghuser.SessionToken(sess), nil
}
// userFromRequestBasicAuth searches for a user using Basic Auth credentials.
// userFromRequestBasicAuth searches for a user using Basic Auth credentials. r
// must not be nil.
func (mw *authMiddlewareDefault) userFromRequestBasicAuth(
ctx context.Context,
r *http.Request,
) (user *aghuser.User, err error) {
login, pass, ok := r.BasicAuth()
if !ok {
return nil, fmt.Errorf("credentials: %w", errors.ErrNoValue)
return nil, nil
}
var remoteIP string
// realIP cannot be used here without taking TrustedProxies into account due
// to security issues.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2799.
if remoteIP, err = netutil.SplitHost(r.RemoteAddr); err != nil {
return nil, fmt.Errorf("getting remote address: %w", err)
}
rateLimiter := mw.rateLimiter
if left := rateLimiter.check(remoteIP); left > 0 {
return nil, fmt.Errorf("login attempt blocked for %s", left)
}
rateLimiter.inc(remoteIP)
defer func() {
if err != nil {
return
}
rateLimiter.remove(remoteIP)
}()
user, _ = mw.users.ByLogin(ctx, aghuser.Login(login))
if user == nil {
return nil, errInvalidLogin

View File

@ -165,41 +165,18 @@ func (h *testAuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.user, _ = webUserFromContext(r.Context())
}
func TestAuthMiddlewareDefault_firstRun(t *testing.T) {
db := newTestUsersDB()
db.onAll = func(_ context.Context) (users []*aghuser.User, err error) {
return nil, nil
}
mw := newAuthMiddlewareDefault(&authMiddlewareDefaultConfig{
logger: testLogger,
rateLimiter: emptyRateLimiter{},
sessions: &testSessionStorage{},
users: db,
})
h := &testAuthHandler{}
wrapped := mw.Wrap(h)
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
wrapped.ServeHTTP(w, r)
assert.Equal(t, http.StatusOK, w.Code)
assert.True(t, h.called)
}
func TestAuthMiddlewareDefault(t *testing.T) {
t.Parallel()
const (
login aghuser.Login = "user_login"
loginStr = "user_login"
passwordStr = "user_password"
passwordRaw = "user_password"
login = aghuser.Login(loginStr)
)
passwordHash, err := bcrypt.GenerateFromPassword(
[]byte(passwordRaw),
[]byte(passwordStr),
bcrypt.DefaultCost,
)
require.NoError(t, err)
@ -244,17 +221,8 @@ func TestAuthMiddlewareDefault(t *testing.T) {
users: usersDB,
})
reqCookie := httptest.NewRequest(http.MethodGet, "/", nil)
reqCookie.AddCookie(&http.Cookie{Name: sessionCookieName, Value: tokenHex})
reqInvalidCookie := httptest.NewRequest(http.MethodGet, "/", nil)
reqInvalidCookie.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "invalid_cookie"})
reqBasicAuth := httptest.NewRequest(http.MethodGet, "/", nil)
reqBasicAuth.SetBasicAuth(string(login), passwordRaw)
reqInvalidPassBasicAuth := httptest.NewRequest(http.MethodGet, "/", nil)
reqInvalidPassBasicAuth.SetBasicAuth(string(login), "invalid_password")
cookie := &http.Cookie{Name: sessionCookieName, Value: tokenHex}
invalidCookie := &http.Cookie{Name: sessionCookieName, Value: "123"}
testCases := []struct {
req *http.Request
@ -264,25 +232,55 @@ func TestAuthMiddlewareDefault(t *testing.T) {
}{{
req: httptest.NewRequest(http.MethodGet, "/", nil),
wantUser: nil,
name: "no_auth_root",
wantCode: http.StatusFound,
}, {
req: httptest.NewRequest(http.MethodGet, "/index.html", nil),
wantUser: nil,
name: "no_auth",
wantCode: http.StatusFound,
}, {
req: reqCookie,
req: authRequest("/", invalidCookie, "", ""),
wantUser: nil,
name: "invalid_auth",
wantCode: http.StatusFound,
}, {
req: authRequest("/", cookie, "", ""),
wantUser: user,
name: "cookie",
wantCode: http.StatusOK,
}, {
req: reqBasicAuth,
req: authRequest("/login.html", cookie, "", ""),
wantUser: nil,
name: "redirect",
wantCode: http.StatusFound,
}, {
req: authRequest("/control/profile", cookie, "", ""),
wantUser: user,
name: "protected",
wantCode: http.StatusOK,
}, {
req: authRequest("/control/profile", invalidCookie, "", ""),
wantUser: nil,
name: "no_auth_protected",
wantCode: http.StatusUnauthorized,
}, {
req: httptest.NewRequest(http.MethodGet, "/control/login", nil),
wantUser: nil,
name: "public",
wantCode: http.StatusOK,
}, {
req: authRequest("/", nil, loginStr, passwordStr),
wantUser: user,
name: "basic_auth",
wantCode: http.StatusOK,
}, {
req: reqInvalidCookie,
req: authRequest("/", invalidCookie, "", ""),
wantUser: nil,
name: "invalid_cookie",
wantCode: http.StatusFound,
}, {
req: reqInvalidPassBasicAuth,
req: authRequest("/", nil, "invalid", "creds"),
wantUser: nil,
name: "invalid_basic_auth",
wantCode: http.StatusFound,
@ -304,6 +302,22 @@ func TestAuthMiddlewareDefault(t *testing.T) {
}
}
// authRequest is a test helper function that returns a GET request configured
// with the provided credentials and path.
func authRequest(path string, c *http.Cookie, user, pass string) (r *http.Request) {
r = httptest.NewRequest(http.MethodGet, path, nil)
if c != nil {
r.AddCookie(c)
}
if user != "" {
r.SetBasicAuth(user, pass)
}
return r
}
func TestAuth_ServeHTTP_firstRun(t *testing.T) {
storeGlobals(t)
@ -458,7 +472,7 @@ func TestAuth_ServeHTTP_auth(t *testing.T) {
})
require.NoError(t, err)
t.Cleanup(func() { auth.Close(testutil.ContextWithTimeout(t, testTimeout)) })
t.Cleanup(func() { auth.close(testutil.ContextWithTimeout(t, testTimeout)) })
globalContext.mux = http.NewServeMux()
@ -606,7 +620,7 @@ func TestAuth_ServeHTTP_logout(t *testing.T) {
})
require.NoError(t, err)
t.Cleanup(func() { auth.Close(testutil.ContextWithTimeout(t, testTimeout)) })
t.Cleanup(func() { auth.close(testutil.ContextWithTimeout(t, testTimeout)) })
globalContext.mux = http.NewServeMux()

View File

@ -177,7 +177,7 @@ func registerControlHandlers(web *webAPI) {
httpRegister(http.MethodGet, "/control/status", web.handleStatus)
httpRegister(http.MethodPost, "/control/i18n/change_language", handleI18nChangeLanguage)
httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage)
httpRegister(http.MethodGet, "/control/profile", handleGetProfile)
httpRegister(http.MethodGet, "/control/profile", web.handleGetProfile)
httpRegister(http.MethodPut, "/control/profile/update", handlePutProfile)
// No auth is necessary for DoH/DoT configurations

View File

@ -490,7 +490,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
// and with its own context, because it waits until all requests are handled
// and will be blocked by it's own caller.
go func(timeout time.Duration) {
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
shutdownCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), timeout)
defer slogutil.RecoverAndLog(shutdownCtx, web.logger)
defer cancel()

View File

@ -812,7 +812,7 @@ func initUsers(
blockDur := time.Duration(config.AuthBlockMin) * time.Minute
rateLimiter = newAuthRateLimiter(blockDur, config.AuthAttempts)
} else {
baseLogger.InfoContext(ctx, "authratelimiter is disabled")
baseLogger.WarnContext(ctx, "authratelimiter is disabled")
rateLimiter = emptyRateLimiter{}
}

View File

@ -46,10 +46,16 @@ type profileJSON struct {
}
// handleGetProfile is the handler for GET /control/profile endpoint.
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
func (web *webAPI) handleGetProfile(w http.ResponseWriter, r *http.Request) {
var name string
u, ok := webUserFromContext(r.Context())
if ok {
if !web.auth.isGLiNet {
u, ok := webUserFromContext(r.Context())
if !ok {
w.WriteHeader(http.StatusUnauthorized)
return
}
name = string(u.Login)
}

View File

@ -221,7 +221,10 @@ func (web *webAPI) start(ctx context.Context) {
errs := make(chan error, 2)
// Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies.
hdlr := h2c.NewHandler(withMiddlewares(globalContext.mux, limitRequestBody), &http2.Server{})
hdlr := h2c.NewHandler(
withMiddlewares(globalContext.mux, limitRequestBody),
&http2.Server{},
)
logger := web.baseLogger.With(loggerKeyServer, "plain")
@ -232,7 +235,7 @@ func (web *webAPI) start(ctx context.Context) {
// Create a new instance, because the Web is not usable after Shutdown.
web.httpServer = &http.Server{
Addr: web.conf.BindAddr.String(),
Handler: globalContext.auth.middleware().Wrap(hdlr),
Handler: web.auth.middleware().Wrap(hdlr),
ReadTimeout: web.conf.ReadTimeout,
ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
WriteTimeout: web.conf.WriteTimeout,
@ -274,7 +277,7 @@ func (web *webAPI) close(ctx context.Context) {
shutdownSrv(ctx, web.logger, web.httpServer)
if web.auth != nil {
web.auth.Close(ctx)
web.auth.close(ctx)
}
web.logger.InfoContext(ctx, "stopped http server")
@ -318,7 +321,7 @@ func (web *webAPI) tlsServerLoop(ctx context.Context) {
web.httpsServer.server = &http.Server{
Addr: addr,
Handler: globalContext.auth.middleware().Wrap(hdlr),
Handler: web.auth.middleware().Wrap(hdlr),
TLSConfig: &tls.Config{
Certificates: []tls.Certificate{web.httpsServer.cert},
RootCAs: web.tlsManager.rootCerts,
@ -359,7 +362,7 @@ func (web *webAPI) mustStartHTTP3(ctx context.Context, address string) {
CipherSuites: web.tlsManager.customCipherIDs,
MinVersion: tls.VersionTLS12,
},
Handler: globalContext.auth.middleware().Wrap(withMiddlewares(globalContext.mux, limitRequestBody)),
Handler: web.auth.middleware().Wrap(withMiddlewares(globalContext.mux, limitRequestBody)),
}
web.logger.DebugContext(ctx, "starting http/3 server")