mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2025-10-26 11:27:18 +00:00
home: imp code
This commit is contained in:
parent
b90031495a
commit
7c2e5d41f0
@ -179,8 +179,8 @@ func (a *auth) addUser(ctx context.Context, u *webUser, password string) (err er
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the authentication database.
|
||||
func (a *auth) Close(ctx context.Context) {
|
||||
// close closes the authentication database.
|
||||
func (a *auth) close(ctx context.Context) {
|
||||
err := a.sessions.Close()
|
||||
if err != nil {
|
||||
a.logger.ErrorContext(ctx, "closing session storage", slogutil.KeyError, err)
|
||||
|
||||
@ -319,6 +319,9 @@ type authMiddlewareDefaultConfig struct {
|
||||
rateLimiter loginRaateLimiter
|
||||
|
||||
// trustedProxies is a set of subnets considered as trusted.
|
||||
//
|
||||
// TODO(s.chzhen): Use it not only to pass it to the middleware but also to
|
||||
// log the work of the rate limiter.
|
||||
trustedProxies netutil.SubnetSet
|
||||
|
||||
// sessions contains web user sessions. It must not be nil.
|
||||
@ -332,9 +335,8 @@ type authMiddlewareDefaultConfig struct {
|
||||
// for a web client using an authentication cookie or basic auth credentials and
|
||||
// passes it with the context.
|
||||
type authMiddlewareDefault struct {
|
||||
logger *slog.Logger
|
||||
rateLimiter loginRaateLimiter
|
||||
// TODO(s.chzhen): !! Use it.
|
||||
logger *slog.Logger
|
||||
rateLimiter loginRaateLimiter
|
||||
trustedProxies netutil.SubnetSet
|
||||
sessions aghuser.SessionStorage
|
||||
users aghuser.DB
|
||||
@ -360,14 +362,13 @@ var _ httputil.Middleware = (*authMiddlewareDefault)(nil)
|
||||
func (mw *authMiddlewareDefault) Wrap(h http.Handler) (wrapped http.Handler) {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !mw.needsAuthentication(ctx, r) {
|
||||
h.ServeHTTP(w, r)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
path := r.URL.Path
|
||||
u, err := mw.userFromRequest(ctx, r)
|
||||
if err != nil {
|
||||
mw.logger.ErrorContext(ctx, "retrieving user from request", slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
if u != nil {
|
||||
if path == "/login.html" {
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
@ -380,31 +381,29 @@ func (mw *authMiddlewareDefault) Wrap(h http.Handler) (wrapped http.Handler) {
|
||||
return
|
||||
}
|
||||
|
||||
if path == "/" || path == "index.html" {
|
||||
http.Redirect(w, r, "login.html", http.StatusFound)
|
||||
if isPublicResource(path) {
|
||||
h.ServeHTTP(w, r)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
mw.logger.ErrorContext(ctx, "retrieving user from request", slogutil.KeyError, err)
|
||||
if path == "/" || path == "/index.html" {
|
||||
http.Redirect(w, r, "login.html", http.StatusFound)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
})
|
||||
}
|
||||
|
||||
// needsAuthentication returns true if the current request requires
|
||||
// authentication.
|
||||
func (mw *authMiddlewareDefault) needsAuthentication(
|
||||
// userFromRequest tries to retrieve a user based on the request. r must not be
|
||||
// nil.
|
||||
func (mw *authMiddlewareDefault) userFromRequest(
|
||||
ctx context.Context,
|
||||
r *http.Request,
|
||||
) (ok bool) {
|
||||
path := r.URL.Path
|
||||
|
||||
if isPublicResource(path) {
|
||||
return false
|
||||
}
|
||||
) (u *aghuser.User, err error) {
|
||||
defer func() { err = errors.Annotate(err, "getting user from request: %w") }()
|
||||
|
||||
users, err := mw.users.All(ctx)
|
||||
if err != nil {
|
||||
@ -413,47 +412,14 @@ func (mw *authMiddlewareDefault) needsAuthentication(
|
||||
}
|
||||
|
||||
if len(users) == 0 {
|
||||
return false
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// userFromRequest tries to retrieve a user based on the request.
|
||||
func (mw *authMiddlewareDefault) userFromRequest(
|
||||
ctx context.Context,
|
||||
r *http.Request,
|
||||
) (u *aghuser.User, err error) {
|
||||
defer func() { err = errors.Annotate(err, "getting user from request: %w") }()
|
||||
|
||||
cookie, err := r.Cookie(sessionCookieName)
|
||||
if err == nil {
|
||||
return mw.userFromCookie(ctx, cookie.Value)
|
||||
}
|
||||
|
||||
var remoteIP string
|
||||
// realIP cannot be used here without taking TrustedProxies into account due
|
||||
// to security issues.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2799.
|
||||
if remoteIP, err = netutil.SplitHost(r.RemoteAddr); err != nil {
|
||||
return nil, fmt.Errorf("getting remote address: %w", err)
|
||||
}
|
||||
|
||||
rateLimiter := mw.rateLimiter
|
||||
if left := rateLimiter.check(remoteIP); left > 0 {
|
||||
return nil, fmt.Errorf("login attempt blocked for %s", left)
|
||||
}
|
||||
|
||||
rateLimiter.inc(remoteIP)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
rateLimiter.remove(remoteIP)
|
||||
}()
|
||||
|
||||
return mw.userFromRequestBasicAuth(ctx, r)
|
||||
}
|
||||
|
||||
@ -494,8 +460,7 @@ func sessionTokenFromHex(val string) (token aghuser.SessionToken, err error) {
|
||||
|
||||
l := aghuser.SessionTokenLength
|
||||
|
||||
// TODO(s.chzhen): Use validate.Len.
|
||||
err = validate.InRange("token length", len(sess), l, l)
|
||||
err = validate.Equal("token length", l, len(sess))
|
||||
if err != nil {
|
||||
// Don't wrap the error because it's informative enough as is.
|
||||
return token, err
|
||||
@ -504,16 +469,40 @@ func sessionTokenFromHex(val string) (token aghuser.SessionToken, err error) {
|
||||
return aghuser.SessionToken(sess), nil
|
||||
}
|
||||
|
||||
// userFromRequestBasicAuth searches for a user using Basic Auth credentials.
|
||||
// userFromRequestBasicAuth searches for a user using Basic Auth credentials. r
|
||||
// must not be nil.
|
||||
func (mw *authMiddlewareDefault) userFromRequestBasicAuth(
|
||||
ctx context.Context,
|
||||
r *http.Request,
|
||||
) (user *aghuser.User, err error) {
|
||||
login, pass, ok := r.BasicAuth()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("credentials: %w", errors.ErrNoValue)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var remoteIP string
|
||||
// realIP cannot be used here without taking TrustedProxies into account due
|
||||
// to security issues.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/2799.
|
||||
if remoteIP, err = netutil.SplitHost(r.RemoteAddr); err != nil {
|
||||
return nil, fmt.Errorf("getting remote address: %w", err)
|
||||
}
|
||||
|
||||
rateLimiter := mw.rateLimiter
|
||||
if left := rateLimiter.check(remoteIP); left > 0 {
|
||||
return nil, fmt.Errorf("login attempt blocked for %s", left)
|
||||
}
|
||||
|
||||
rateLimiter.inc(remoteIP)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
rateLimiter.remove(remoteIP)
|
||||
}()
|
||||
|
||||
user, _ = mw.users.ByLogin(ctx, aghuser.Login(login))
|
||||
if user == nil {
|
||||
return nil, errInvalidLogin
|
||||
|
||||
@ -165,41 +165,18 @@ func (h *testAuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
h.user, _ = webUserFromContext(r.Context())
|
||||
}
|
||||
|
||||
func TestAuthMiddlewareDefault_firstRun(t *testing.T) {
|
||||
db := newTestUsersDB()
|
||||
db.onAll = func(_ context.Context) (users []*aghuser.User, err error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
mw := newAuthMiddlewareDefault(&authMiddlewareDefaultConfig{
|
||||
logger: testLogger,
|
||||
rateLimiter: emptyRateLimiter{},
|
||||
sessions: &testSessionStorage{},
|
||||
users: db,
|
||||
})
|
||||
|
||||
h := &testAuthHandler{}
|
||||
wrapped := mw.Wrap(h)
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
wrapped.ServeHTTP(w, r)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
assert.True(t, h.called)
|
||||
}
|
||||
|
||||
func TestAuthMiddlewareDefault(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const (
|
||||
login aghuser.Login = "user_login"
|
||||
loginStr = "user_login"
|
||||
passwordStr = "user_password"
|
||||
|
||||
passwordRaw = "user_password"
|
||||
login = aghuser.Login(loginStr)
|
||||
)
|
||||
|
||||
passwordHash, err := bcrypt.GenerateFromPassword(
|
||||
[]byte(passwordRaw),
|
||||
[]byte(passwordStr),
|
||||
bcrypt.DefaultCost,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
@ -244,17 +221,8 @@ func TestAuthMiddlewareDefault(t *testing.T) {
|
||||
users: usersDB,
|
||||
})
|
||||
|
||||
reqCookie := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
reqCookie.AddCookie(&http.Cookie{Name: sessionCookieName, Value: tokenHex})
|
||||
|
||||
reqInvalidCookie := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
reqInvalidCookie.AddCookie(&http.Cookie{Name: sessionCookieName, Value: "invalid_cookie"})
|
||||
|
||||
reqBasicAuth := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
reqBasicAuth.SetBasicAuth(string(login), passwordRaw)
|
||||
|
||||
reqInvalidPassBasicAuth := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
reqInvalidPassBasicAuth.SetBasicAuth(string(login), "invalid_password")
|
||||
cookie := &http.Cookie{Name: sessionCookieName, Value: tokenHex}
|
||||
invalidCookie := &http.Cookie{Name: sessionCookieName, Value: "123"}
|
||||
|
||||
testCases := []struct {
|
||||
req *http.Request
|
||||
@ -264,25 +232,55 @@ func TestAuthMiddlewareDefault(t *testing.T) {
|
||||
}{{
|
||||
req: httptest.NewRequest(http.MethodGet, "/", nil),
|
||||
wantUser: nil,
|
||||
name: "no_auth_root",
|
||||
wantCode: http.StatusFound,
|
||||
}, {
|
||||
req: httptest.NewRequest(http.MethodGet, "/index.html", nil),
|
||||
wantUser: nil,
|
||||
name: "no_auth",
|
||||
wantCode: http.StatusFound,
|
||||
}, {
|
||||
req: reqCookie,
|
||||
req: authRequest("/", invalidCookie, "", ""),
|
||||
wantUser: nil,
|
||||
name: "invalid_auth",
|
||||
wantCode: http.StatusFound,
|
||||
}, {
|
||||
req: authRequest("/", cookie, "", ""),
|
||||
wantUser: user,
|
||||
name: "cookie",
|
||||
wantCode: http.StatusOK,
|
||||
}, {
|
||||
req: reqBasicAuth,
|
||||
req: authRequest("/login.html", cookie, "", ""),
|
||||
wantUser: nil,
|
||||
name: "redirect",
|
||||
wantCode: http.StatusFound,
|
||||
}, {
|
||||
req: authRequest("/control/profile", cookie, "", ""),
|
||||
wantUser: user,
|
||||
name: "protected",
|
||||
wantCode: http.StatusOK,
|
||||
}, {
|
||||
req: authRequest("/control/profile", invalidCookie, "", ""),
|
||||
wantUser: nil,
|
||||
name: "no_auth_protected",
|
||||
wantCode: http.StatusUnauthorized,
|
||||
}, {
|
||||
req: httptest.NewRequest(http.MethodGet, "/control/login", nil),
|
||||
wantUser: nil,
|
||||
name: "public",
|
||||
wantCode: http.StatusOK,
|
||||
}, {
|
||||
req: authRequest("/", nil, loginStr, passwordStr),
|
||||
wantUser: user,
|
||||
name: "basic_auth",
|
||||
wantCode: http.StatusOK,
|
||||
}, {
|
||||
req: reqInvalidCookie,
|
||||
req: authRequest("/", invalidCookie, "", ""),
|
||||
wantUser: nil,
|
||||
name: "invalid_cookie",
|
||||
wantCode: http.StatusFound,
|
||||
}, {
|
||||
req: reqInvalidPassBasicAuth,
|
||||
req: authRequest("/", nil, "invalid", "creds"),
|
||||
wantUser: nil,
|
||||
name: "invalid_basic_auth",
|
||||
wantCode: http.StatusFound,
|
||||
@ -304,6 +302,22 @@ func TestAuthMiddlewareDefault(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// authRequest is a test helper function that returns a GET request configured
|
||||
// with the provided credentials and path.
|
||||
func authRequest(path string, c *http.Cookie, user, pass string) (r *http.Request) {
|
||||
r = httptest.NewRequest(http.MethodGet, path, nil)
|
||||
|
||||
if c != nil {
|
||||
r.AddCookie(c)
|
||||
}
|
||||
|
||||
if user != "" {
|
||||
r.SetBasicAuth(user, pass)
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
func TestAuth_ServeHTTP_firstRun(t *testing.T) {
|
||||
storeGlobals(t)
|
||||
|
||||
@ -458,7 +472,7 @@ func TestAuth_ServeHTTP_auth(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() { auth.Close(testutil.ContextWithTimeout(t, testTimeout)) })
|
||||
t.Cleanup(func() { auth.close(testutil.ContextWithTimeout(t, testTimeout)) })
|
||||
|
||||
globalContext.mux = http.NewServeMux()
|
||||
|
||||
@ -606,7 +620,7 @@ func TestAuth_ServeHTTP_logout(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() { auth.Close(testutil.ContextWithTimeout(t, testTimeout)) })
|
||||
t.Cleanup(func() { auth.close(testutil.ContextWithTimeout(t, testTimeout)) })
|
||||
|
||||
globalContext.mux = http.NewServeMux()
|
||||
|
||||
|
||||
@ -177,7 +177,7 @@ func registerControlHandlers(web *webAPI) {
|
||||
httpRegister(http.MethodGet, "/control/status", web.handleStatus)
|
||||
httpRegister(http.MethodPost, "/control/i18n/change_language", handleI18nChangeLanguage)
|
||||
httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage)
|
||||
httpRegister(http.MethodGet, "/control/profile", handleGetProfile)
|
||||
httpRegister(http.MethodGet, "/control/profile", web.handleGetProfile)
|
||||
httpRegister(http.MethodPut, "/control/profile/update", handlePutProfile)
|
||||
|
||||
// No auth is necessary for DoH/DoT configurations
|
||||
|
||||
@ -490,7 +490,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
// and with its own context, because it waits until all requests are handled
|
||||
// and will be blocked by it's own caller.
|
||||
go func(timeout time.Duration) {
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), timeout)
|
||||
defer slogutil.RecoverAndLog(shutdownCtx, web.logger)
|
||||
defer cancel()
|
||||
|
||||
|
||||
@ -812,7 +812,7 @@ func initUsers(
|
||||
blockDur := time.Duration(config.AuthBlockMin) * time.Minute
|
||||
rateLimiter = newAuthRateLimiter(blockDur, config.AuthAttempts)
|
||||
} else {
|
||||
baseLogger.InfoContext(ctx, "authratelimiter is disabled")
|
||||
baseLogger.WarnContext(ctx, "authratelimiter is disabled")
|
||||
rateLimiter = emptyRateLimiter{}
|
||||
}
|
||||
|
||||
|
||||
@ -46,10 +46,16 @@ type profileJSON struct {
|
||||
}
|
||||
|
||||
// handleGetProfile is the handler for GET /control/profile endpoint.
|
||||
func handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
||||
func (web *webAPI) handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
||||
var name string
|
||||
u, ok := webUserFromContext(r.Context())
|
||||
if ok {
|
||||
if !web.auth.isGLiNet {
|
||||
u, ok := webUserFromContext(r.Context())
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
name = string(u.Login)
|
||||
}
|
||||
|
||||
|
||||
@ -221,7 +221,10 @@ func (web *webAPI) start(ctx context.Context) {
|
||||
errs := make(chan error, 2)
|
||||
|
||||
// Use an h2c handler to support unencrypted HTTP/2, e.g. for proxies.
|
||||
hdlr := h2c.NewHandler(withMiddlewares(globalContext.mux, limitRequestBody), &http2.Server{})
|
||||
hdlr := h2c.NewHandler(
|
||||
withMiddlewares(globalContext.mux, limitRequestBody),
|
||||
&http2.Server{},
|
||||
)
|
||||
|
||||
logger := web.baseLogger.With(loggerKeyServer, "plain")
|
||||
|
||||
@ -232,7 +235,7 @@ func (web *webAPI) start(ctx context.Context) {
|
||||
// Create a new instance, because the Web is not usable after Shutdown.
|
||||
web.httpServer = &http.Server{
|
||||
Addr: web.conf.BindAddr.String(),
|
||||
Handler: globalContext.auth.middleware().Wrap(hdlr),
|
||||
Handler: web.auth.middleware().Wrap(hdlr),
|
||||
ReadTimeout: web.conf.ReadTimeout,
|
||||
ReadHeaderTimeout: web.conf.ReadHeaderTimeout,
|
||||
WriteTimeout: web.conf.WriteTimeout,
|
||||
@ -274,7 +277,7 @@ func (web *webAPI) close(ctx context.Context) {
|
||||
shutdownSrv(ctx, web.logger, web.httpServer)
|
||||
|
||||
if web.auth != nil {
|
||||
web.auth.Close(ctx)
|
||||
web.auth.close(ctx)
|
||||
}
|
||||
|
||||
web.logger.InfoContext(ctx, "stopped http server")
|
||||
@ -318,7 +321,7 @@ func (web *webAPI) tlsServerLoop(ctx context.Context) {
|
||||
|
||||
web.httpsServer.server = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: globalContext.auth.middleware().Wrap(hdlr),
|
||||
Handler: web.auth.middleware().Wrap(hdlr),
|
||||
TLSConfig: &tls.Config{
|
||||
Certificates: []tls.Certificate{web.httpsServer.cert},
|
||||
RootCAs: web.tlsManager.rootCerts,
|
||||
@ -359,7 +362,7 @@ func (web *webAPI) mustStartHTTP3(ctx context.Context, address string) {
|
||||
CipherSuites: web.tlsManager.customCipherIDs,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
Handler: globalContext.auth.middleware().Wrap(withMiddlewares(globalContext.mux, limitRequestBody)),
|
||||
Handler: web.auth.middleware().Wrap(withMiddlewares(globalContext.mux, limitRequestBody)),
|
||||
}
|
||||
|
||||
web.logger.DebugContext(ctx, "starting http/3 server")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user