mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2025-10-26 11:27:18 +00:00
Pull request 2496: AGDNS-3224-aghhttp-register-slog
Squashed commit of the following: commit 9324a0066202f1677bfd033d40d3a82fa9756ed9 Merge:8a1b5cad4f9da40e39Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Oct 23 17:48:01 2025 +0300 Merge branch 'master' into AGDNS-3224-aghhttp-register-slog commit8a1b5cad4cAuthor: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Oct 21 15:51:48 2025 +0300 filtering: imp code commitfe569166efMerge:9a101a2f59be4ca90eAuthor: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Oct 21 15:45:42 2025 +0300 Merge branch 'master' into AGDNS-3224-aghhttp-register-slog commit9a101a2f5fAuthor: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Oct 15 18:52:22 2025 +0300 home: imp code commit727e1663baAuthor: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Oct 15 10:19:56 2025 +0300 all: imp code commit113a9017dfAuthor: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Oct 13 23:10:06 2025 +0300 home: fix typo commit6588dd2dadAuthor: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Oct 13 22:46:28 2025 +0300 all: imp naming commit44278505a9Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Oct 10 16:20:17 2025 +0300 home: fix typo commit7b4b57628bAuthor: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Oct 10 15:58:07 2025 +0300 all: web mw commit93168142cbAuthor: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Oct 8 22:20:07 2025 +0300 all: aghhttp slog commit9155edef67Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Oct 8 15:38:01 2025 +0300 aghhttp: registrar commita356473855Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Oct 7 15:32:30 2025 +0300 all: http registrar
This commit is contained in:
parent
f9da40e393
commit
5c9fef62f1
@ -10,7 +10,6 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
)
|
||||
|
||||
@ -27,15 +26,6 @@ func OK(ctx context.Context, l *slog.Logger, w http.ResponseWriter) {
|
||||
}
|
||||
}
|
||||
|
||||
// Error writes formatted message to w and also logs it.
|
||||
//
|
||||
// TODO(s.chzhen): Remove it.
|
||||
func Error(r *http.Request, w http.ResponseWriter, code int, format string, args ...any) {
|
||||
text := fmt.Sprintf(format, args...)
|
||||
log.Error("%s %s %s: %s", r.Method, r.Host, r.URL, text)
|
||||
http.Error(w, text, code)
|
||||
}
|
||||
|
||||
// ErrorAndLog writes a formatted HTTP error response and logs it at
|
||||
// [slog.LevelError] level. l, r, and w must not be nil.
|
||||
func ErrorAndLog(
|
||||
|
||||
@ -1,14 +1,16 @@
|
||||
package aghhttp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
)
|
||||
|
||||
// JSON Utilities
|
||||
@ -88,8 +90,15 @@ func (t *JSONTime) UnmarshalJSON(b []byte) (err error) {
|
||||
|
||||
// WriteJSONResponse writes headers with the code, encodes resp into w, and logs
|
||||
// any errors it encounters. r is used to get additional information from the
|
||||
// request.
|
||||
func WriteJSONResponse(w http.ResponseWriter, r *http.Request, code int, resp any) {
|
||||
// request. l, w, and r must not be nil.
|
||||
func WriteJSONResponse(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
code int,
|
||||
resp any,
|
||||
) {
|
||||
h := w.Header()
|
||||
h.Set(httphdr.ContentType, HdrValApplicationJSON)
|
||||
h.Set(httphdr.Server, UserAgent())
|
||||
@ -98,15 +107,27 @@ func WriteJSONResponse(w http.ResponseWriter, r *http.Request, code int, resp an
|
||||
|
||||
err := json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
log.Error("aghhttp: writing json resp to %s %s: %s", r.Method, r.URL.Path, err)
|
||||
l.ErrorContext(
|
||||
ctx,
|
||||
"writing json response",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
slogutil.KeyError, err,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteJSONResponseOK writes headers with the code 200 OK, encodes v into w,
|
||||
// and logs any errors it encounters. r is used to get additional information
|
||||
// from the request.
|
||||
func WriteJSONResponseOK(w http.ResponseWriter, r *http.Request, v any) {
|
||||
WriteJSONResponse(w, r, http.StatusOK, v)
|
||||
// from the request. l, w, and r must not be nil.
|
||||
func WriteJSONResponseOK(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
v any,
|
||||
) {
|
||||
WriteJSONResponse(ctx, l, w, r, http.StatusOK, v)
|
||||
}
|
||||
|
||||
// ErrorCode is the error code as used by the HTTP API. See the ErrorCode
|
||||
@ -131,11 +152,23 @@ type HTTPAPIErrorResp struct {
|
||||
|
||||
// WriteJSONResponseError encodes err as a JSON error into w, and logs any
|
||||
// errors it encounters. r is used to get additional information from the
|
||||
// request.
|
||||
func WriteJSONResponseError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
log.Error("aghhttp: writing json error to %s %s: %s", r.Method, r.URL.Path, err)
|
||||
// request. l, w, and r must not be nil.
|
||||
func WriteJSONResponseError(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
err error,
|
||||
) {
|
||||
l.ErrorContext(
|
||||
ctx,
|
||||
"writing json error",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
slogutil.KeyError, err,
|
||||
)
|
||||
|
||||
WriteJSONResponse(w, r, http.StatusUnprocessableEntity, &HTTPAPIErrorResp{
|
||||
WriteJSONResponse(ctx, l, w, r, http.StatusUnprocessableEntity, &HTTPAPIErrorResp{
|
||||
Code: ErrorCodeTMP000,
|
||||
Msg: err.Error(),
|
||||
})
|
||||
|
||||
49
internal/aghhttp/registrar.go
Normal file
49
internal/aghhttp/registrar.go
Normal file
@ -0,0 +1,49 @@
|
||||
package aghhttp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Registrar registers an HTTP handler for a method and path.
|
||||
//
|
||||
// TODO(s.chzhen): Implement [httputil.Router].
|
||||
type Registrar interface {
|
||||
Register(method, path string, h http.HandlerFunc)
|
||||
}
|
||||
|
||||
// EmptyRegistrar is an implementation of [Registrar] that does nothing.
|
||||
type EmptyRegistrar struct{}
|
||||
|
||||
// type check
|
||||
var _ Registrar = EmptyRegistrar{}
|
||||
|
||||
// Register implements the [Registrar] interface.
|
||||
func (EmptyRegistrar) Register(_, _ string, _ http.HandlerFunc) {}
|
||||
|
||||
// WrapFunc is a wrapper function that builds an HTTP handler for a route.
|
||||
type WrapFunc func(method string, h http.HandlerFunc) (wrapped http.Handler)
|
||||
|
||||
// DefaultRegistrar is an implementation of [Registrar] that registers handlers
|
||||
// after applying a user-provided wrapper function.
|
||||
type DefaultRegistrar struct {
|
||||
mux *http.ServeMux
|
||||
wrapFn WrapFunc
|
||||
}
|
||||
|
||||
// NewDefaultRegistrar returns a new properly initialized *DefaultRegistrar.
|
||||
// mux and wrap must not be nil.
|
||||
func NewDefaultRegistrar(mux *http.ServeMux, wrap WrapFunc) (r *DefaultRegistrar) {
|
||||
return &DefaultRegistrar{
|
||||
mux: mux,
|
||||
wrapFn: wrap,
|
||||
}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ Registrar = (*DefaultRegistrar)(nil)
|
||||
|
||||
// Register implements the [Registrar] interface.
|
||||
func (r *DefaultRegistrar) Register(method, path string, h http.HandlerFunc) {
|
||||
wrapped := r.wrapFn(method, h)
|
||||
r.mux.Handle(path, wrapped)
|
||||
}
|
||||
@ -2,10 +2,12 @@ package aghtest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
nextagh "github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
|
||||
@ -158,7 +160,20 @@ type ConfigModifier struct {
|
||||
// type check
|
||||
var _ agh.ConfigModifier = (*ConfigModifier)(nil)
|
||||
|
||||
// Apply implements the [ConfigModifier] interface for *ConfigModifier.
|
||||
// Apply implements the [agh.ConfigModifier] interface for *ConfigModifier.
|
||||
func (m *ConfigModifier) Apply(ctx context.Context) {
|
||||
m.OnApply(ctx)
|
||||
}
|
||||
|
||||
// Registrar is a fake [aghhttp.Registrar] implementation for tests.
|
||||
type Registrar struct {
|
||||
OnRegister func(method, path string, h http.HandlerFunc)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ aghhttp.Registrar = (*Registrar)(nil)
|
||||
|
||||
// Register implements the [aghhttp.Registrar] interface for *Registrar.
|
||||
func (m *Registrar) Register(method, path string, h http.HandlerFunc) {
|
||||
m.OnRegister(method, path, h)
|
||||
}
|
||||
|
||||
@ -31,7 +31,7 @@ type ServerConfig struct {
|
||||
ConfModifier agh.ConfigModifier `yaml:"-"`
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
|
||||
HTTPReg aghhttp.Registrar `yaml:"-"`
|
||||
|
||||
Enabled bool `yaml:"enabled"`
|
||||
InterfaceName string `yaml:"interface_name"`
|
||||
|
||||
@ -112,7 +112,7 @@ func Create(ctx context.Context, conf *ServerConfig) (s *server, err error) {
|
||||
CommandConstructor: conf.CommandConstructor,
|
||||
ConfModifier: conf.ConfModifier,
|
||||
|
||||
HTTPRegister: conf.HTTPRegister,
|
||||
HTTPReg: conf.HTTPReg,
|
||||
|
||||
Enabled: conf.Enabled,
|
||||
InterfaceName: conf.InterfaceName,
|
||||
|
||||
@ -171,7 +171,7 @@ func (s *server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
|
||||
status.Leases = leasesToDynamic(leases[dynamicIdx:])
|
||||
status.StaticLeases = leasesToStatic(leases[:dynamicIdx])
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, status)
|
||||
aghhttp.WriteJSONResponseOK(r.Context(), s.conf.Logger, w, r, status)
|
||||
}
|
||||
|
||||
func (s *server) enableDHCP(ctx context.Context, ifaceName string) (code int, err error) {
|
||||
@ -451,7 +451,7 @@ func (s *server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(ctx, l, w, r, resp)
|
||||
}
|
||||
|
||||
// newNetInterfaceJSON creates a JSON object from a [net.Interface] iface.
|
||||
@ -613,7 +613,7 @@ func (s *server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
|
||||
|
||||
s.setOtherDHCPResult(ctx, ifaceName, result)
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, result)
|
||||
aghhttp.WriteJSONResponseOK(ctx, l, w, r, result)
|
||||
}
|
||||
|
||||
// setOtherDHCPResult sets the results of the check for another DHCP server in
|
||||
@ -741,7 +741,7 @@ func (s *server) handleReset(w http.ResponseWriter, r *http.Request) {
|
||||
s.conf = &ServerConfig{
|
||||
ConfModifier: s.conf.ConfModifier,
|
||||
|
||||
HTTPRegister: s.conf.HTTPRegister,
|
||||
HTTPReg: s.conf.HTTPReg,
|
||||
|
||||
LocalDomainName: s.conf.LocalDomainName,
|
||||
|
||||
@ -778,17 +778,17 @@ func (s *server) handleResetLeases(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *server) registerHandlers() {
|
||||
if s.conf.HTTPRegister == nil {
|
||||
if s.conf.HTTPReg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/status", s.handleDHCPStatus)
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/interfaces", s.handleDHCPInterfaces)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/set_config", s.handleDHCPSetConfig)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/find_active_dhcp", s.handleDHCPFindActiveServer)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/add_static_lease", s.handleDHCPAddStaticLease)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/remove_static_lease", s.handleDHCPRemoveStaticLease)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/update_static_lease", s.handleDHCPUpdateStaticLease)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", s.handleReset)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", s.handleResetLeases)
|
||||
s.conf.HTTPReg.Register(http.MethodGet, "/control/dhcp/status", s.handleDHCPStatus)
|
||||
s.conf.HTTPReg.Register(http.MethodGet, "/control/dhcp/interfaces", s.handleDHCPInterfaces)
|
||||
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/set_config", s.handleDHCPSetConfig)
|
||||
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/find_active_dhcp", s.handleDHCPFindActiveServer)
|
||||
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/add_static_lease", s.handleDHCPAddStaticLease)
|
||||
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/remove_static_lease", s.handleDHCPRemoveStaticLease)
|
||||
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/update_static_lease", s.handleDHCPUpdateStaticLease)
|
||||
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/reset", s.handleReset)
|
||||
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/reset_leases", s.handleResetLeases)
|
||||
}
|
||||
|
||||
@ -24,7 +24,9 @@ type jsonError struct {
|
||||
// TODO(a.garipov): Either take the logger from the server after we've
|
||||
// refactored logging or make this not a method of *Server.
|
||||
func (s *server) notImplemented(w http.ResponseWriter, r *http.Request) {
|
||||
aghhttp.WriteJSONResponse(w, r, http.StatusNotImplemented, &jsonError{
|
||||
ctx := r.Context()
|
||||
|
||||
aghhttp.WriteJSONResponse(ctx, s.conf.Logger, w, r, http.StatusNotImplemented, &jsonError{
|
||||
Message: aghos.Unsupported("dhcp").Error(),
|
||||
})
|
||||
}
|
||||
@ -37,13 +39,13 @@ func (s *server) notImplemented(w http.ResponseWriter, r *http.Request) {
|
||||
// interconnected parts--such as HTTP handlers and frontend--to make that work
|
||||
// properly.
|
||||
func (s *server) registerHandlers() {
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/status", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/interfaces", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/set_config", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/find_active_dhcp", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/add_static_lease", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/remove_static_lease", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/update_static_lease", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", s.notImplemented)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", s.notImplemented)
|
||||
s.conf.HTTPReg.Register(http.MethodGet, "/control/dhcp/status", s.notImplemented)
|
||||
s.conf.HTTPReg.Register(http.MethodGet, "/control/dhcp/interfaces", s.notImplemented)
|
||||
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/set_config", s.notImplemented)
|
||||
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/find_active_dhcp", s.notImplemented)
|
||||
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/add_static_lease", s.notImplemented)
|
||||
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/remove_static_lease", s.notImplemented)
|
||||
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/update_static_lease", s.notImplemented)
|
||||
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/reset", s.notImplemented)
|
||||
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/reset_leases", s.notImplemented)
|
||||
}
|
||||
|
||||
@ -186,7 +186,7 @@ func (s *Server) accessListJSON() (j accessListJSON) {
|
||||
|
||||
// handleAccessList handles requests to the GET /control/access/list endpoint.
|
||||
func (s *Server) handleAccessList(w http.ResponseWriter, r *http.Request) {
|
||||
aghhttp.WriteJSONResponseOK(w, r, s.accessListJSON())
|
||||
aghhttp.WriteJSONResponseOK(r.Context(), s.logger, w, r, s.accessListJSON())
|
||||
}
|
||||
|
||||
// validateAccessSet checks the internal accessListJSON lists. To search for
|
||||
|
||||
@ -270,7 +270,7 @@ type ServerConfig struct {
|
||||
ConfModifier agh.ConfigModifier
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister aghhttp.RegisterFunc
|
||||
HTTPReg aghhttp.Registrar
|
||||
|
||||
// LocalPTRResolvers is a slice of addresses to be used as upstreams for
|
||||
// resolving PTR queries for local addresses.
|
||||
|
||||
@ -245,8 +245,10 @@ func (s *Server) defaultLocalPTRUpstreams() (ups []string, err error) {
|
||||
|
||||
// handleGetConfig handles requests to the GET /control/dns_info endpoint.
|
||||
func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
resp := s.getDNSConfig(r.Context())
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
ctx := r.Context()
|
||||
|
||||
resp := s.getDNSConfig(ctx)
|
||||
aghhttp.WriteJSONResponseOK(ctx, s.logger, w, r, resp)
|
||||
}
|
||||
|
||||
// checkBlockingMode returns an error if blocking mode is invalid.
|
||||
@ -739,7 +741,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
|
||||
cv.check()
|
||||
cv.close()
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, cv.status())
|
||||
aghhttp.WriteJSONResponseOK(ctx, l, w, r, cv.status())
|
||||
}
|
||||
|
||||
// handleCacheClear is the handler for the POST /control/cache_clear HTTP API.
|
||||
@ -797,7 +799,7 @@ func (s *Server) handleSetProtection(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
s.conf.ConfModifier.Apply(ctx)
|
||||
|
||||
aghhttp.OK(ctx, s.logger, w)
|
||||
aghhttp.OK(ctx, l, w)
|
||||
}
|
||||
|
||||
// handleDoH is the DNS-over-HTTPs handler.
|
||||
@ -837,19 +839,19 @@ func (s *Server) handleDoH(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (s *Server) registerHandlers() {
|
||||
if webRegistered || s.conf.HTTPRegister == nil {
|
||||
if webRegistered || s.conf.HTTPReg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/dns_info", s.handleGetConfig)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/dns_config", s.handleSetConfig)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/test_upstream_dns", s.handleTestUpstreamDNS)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/protection", s.handleSetProtection)
|
||||
s.conf.HTTPReg.Register(http.MethodGet, "/control/dns_info", s.handleGetConfig)
|
||||
s.conf.HTTPReg.Register(http.MethodPost, "/control/dns_config", s.handleSetConfig)
|
||||
s.conf.HTTPReg.Register(http.MethodPost, "/control/test_upstream_dns", s.handleTestUpstreamDNS)
|
||||
s.conf.HTTPReg.Register(http.MethodPost, "/control/protection", s.handleSetProtection)
|
||||
|
||||
s.conf.HTTPRegister(http.MethodGet, "/control/access/list", s.handleAccessList)
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/access/set", s.handleAccessSet)
|
||||
s.conf.HTTPReg.Register(http.MethodGet, "/control/access/list", s.handleAccessList)
|
||||
s.conf.HTTPReg.Register(http.MethodPost, "/control/access/set", s.handleAccessSet)
|
||||
|
||||
s.conf.HTTPRegister(http.MethodPost, "/control/cache_clear", s.handleCacheClear)
|
||||
s.conf.HTTPReg.Register(http.MethodPost, "/control/cache_clear", s.handleCacheClear)
|
||||
|
||||
// Register both versions, with and without the trailing slash, to
|
||||
// prevent a 301 Moved Permanently redirect when clients request the
|
||||
@ -858,8 +860,8 @@ func (s *Server) registerHandlers() {
|
||||
// See go doc net/http.ServeMux.
|
||||
//
|
||||
// See also https://github.com/AdguardTeam/AdGuardHome/issues/2628.
|
||||
s.conf.HTTPRegister("", "/dns-query", s.handleDoH)
|
||||
s.conf.HTTPRegister("", "/dns-query/", s.handleDoH)
|
||||
s.conf.HTTPReg.Register("", "/dns-query", s.handleDoH)
|
||||
s.conf.HTTPReg.Register("", "/dns-query/", s.handleDoH)
|
||||
|
||||
webRegistered = true
|
||||
}
|
||||
|
||||
@ -127,11 +127,11 @@ func (d *DNSFilter) ApplyBlockedServicesList(setts *Settings, list []string) {
|
||||
}
|
||||
|
||||
func (d *DNSFilter) handleBlockedServicesIDs(w http.ResponseWriter, r *http.Request) {
|
||||
aghhttp.WriteJSONResponseOK(w, r, serviceIDs)
|
||||
aghhttp.WriteJSONResponseOK(r.Context(), d.logger, w, r, serviceIDs)
|
||||
}
|
||||
|
||||
func (d *DNSFilter) handleBlockedServicesAll(w http.ResponseWriter, r *http.Request) {
|
||||
aghhttp.WriteJSONResponseOK(w, r, struct {
|
||||
aghhttp.WriteJSONResponseOK(r.Context(), d.logger, w, r, struct {
|
||||
BlockedServices []blockedService `json:"blocked_services"`
|
||||
ServiceGroups []serviceGroup `json:"groups"`
|
||||
}{
|
||||
@ -153,7 +153,7 @@ func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Req
|
||||
list = d.conf.BlockedServices.IDs
|
||||
}()
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, list)
|
||||
aghhttp.WriteJSONResponseOK(r.Context(), d.logger, w, r, list)
|
||||
}
|
||||
|
||||
// handleBlockedServicesSet is the handler for the POST
|
||||
@ -193,7 +193,7 @@ func (d *DNSFilter) handleBlockedServicesGet(w http.ResponseWriter, r *http.Requ
|
||||
bsvc = d.conf.BlockedServices.Clone()
|
||||
}()
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, bsvc)
|
||||
aghhttp.WriteJSONResponseOK(r.Context(), d.logger, w, r, bsvc)
|
||||
}
|
||||
|
||||
// handleBlockedServicesUpdate is the handler for the PUT
|
||||
|
||||
@ -109,8 +109,8 @@ type Config struct {
|
||||
// nil.
|
||||
ConfModifier agh.ConfigModifier `yaml:"-"`
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
|
||||
// HTTPReg registers HTTP handlers. It must not be nil.
|
||||
HTTPReg aghhttp.Registrar `yaml:"-"`
|
||||
|
||||
// HTTPClient is the client to use for updating the remote filters.
|
||||
HTTPClient *http.Client `yaml:"-"`
|
||||
|
||||
@ -370,11 +370,12 @@ func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
|
||||
var err error
|
||||
|
||||
ctx := r.Context()
|
||||
l := d.logger
|
||||
|
||||
req := Req{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
aghhttp.ErrorAndLog(ctx, d.logger, r, w, http.StatusBadRequest, "json decode: %s", err)
|
||||
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "json decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@ -387,7 +388,7 @@ func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
|
||||
if !ok {
|
||||
aghhttp.ErrorAndLog(
|
||||
ctx,
|
||||
d.logger,
|
||||
l,
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
@ -397,7 +398,7 @@ func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
|
||||
return
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(ctx, l, w, r, resp)
|
||||
}
|
||||
|
||||
type filterJSON struct {
|
||||
@ -454,7 +455,7 @@ func (d *DNSFilter) handleFilteringStatus(w http.ResponseWriter, r *http.Request
|
||||
resp.UserRules = d.conf.UserRules
|
||||
d.conf.filtersMu.RUnlock()
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(r.Context(), d.logger, w, r, resp)
|
||||
}
|
||||
|
||||
// Set filtering configuration
|
||||
@ -521,13 +522,14 @@ type checkHostResp struct {
|
||||
// API.
|
||||
func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
l := d.logger
|
||||
|
||||
query := r.URL.Query()
|
||||
host := query.Get("name")
|
||||
if host == "" {
|
||||
aghhttp.ErrorAndLog(
|
||||
ctx,
|
||||
d.logger,
|
||||
l,
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
@ -543,7 +545,7 @@ func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
aghhttp.ErrorAndLog(
|
||||
ctx,
|
||||
d.logger,
|
||||
l,
|
||||
r,
|
||||
w,
|
||||
http.StatusUnprocessableEntity,
|
||||
@ -572,7 +574,7 @@ func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
if err != nil {
|
||||
aghhttp.ErrorAndLog(
|
||||
ctx,
|
||||
d.logger,
|
||||
l,
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
@ -605,7 +607,7 @@ func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(ctx, l, w, r, resp)
|
||||
}
|
||||
|
||||
// stringToDNSType is a helper function that converts a string to DNS type. If
|
||||
@ -680,7 +682,7 @@ func (d *DNSFilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Requ
|
||||
Enabled: protectedBool(d.confMu, &d.conf.SafeBrowsingEnabled),
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(r.Context(), d.logger, w, r, resp)
|
||||
}
|
||||
|
||||
// handleParentalEnable is the handler for the POST /control/parental/enable
|
||||
@ -706,15 +708,12 @@ func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request)
|
||||
Enabled: protectedBool(d.confMu, &d.conf.ParentalEnabled),
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(r.Context(), d.logger, w, r, resp)
|
||||
}
|
||||
|
||||
// RegisterFilteringHandlers - register handlers
|
||||
func (d *DNSFilter) RegisterFilteringHandlers() {
|
||||
registerHTTP := d.conf.HTTPRegister
|
||||
if registerHTTP == nil {
|
||||
return
|
||||
}
|
||||
registerHTTP := d.conf.HTTPReg.Register
|
||||
|
||||
registerHTTP(http.MethodPost, "/control/safebrowsing/enable", d.handleSafeBrowsingEnable)
|
||||
registerHTTP(http.MethodPost, "/control/safebrowsing/disable", d.handleSafeBrowsingDisable)
|
||||
|
||||
@ -12,6 +12,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
@ -116,6 +117,7 @@ func TestDNSFilter_handleFilteringSetURL(t *testing.T) {
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
ConfModifier: confModifier,
|
||||
HTTPReg: aghhttp.EmptyRegistrar{},
|
||||
DataDir: filtersDir,
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
@ -197,8 +199,10 @@ func TestDNSFilter_handleSafeBrowsingStatus(t *testing.T) {
|
||||
Logger: testLogger,
|
||||
ConfModifier: confModifier,
|
||||
DataDir: filtersDir,
|
||||
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
|
||||
handlers[url] = handler
|
||||
HTTPReg: &aghtest.Registrar{
|
||||
OnRegister: func(_, url string, handler http.HandlerFunc) {
|
||||
handlers[url] = handler
|
||||
},
|
||||
},
|
||||
SafeBrowsingEnabled: tc.enabled,
|
||||
}, nil)
|
||||
@ -284,8 +288,10 @@ func TestDNSFilter_handleParentalStatus(t *testing.T) {
|
||||
Logger: testLogger,
|
||||
ConfModifier: confModifier,
|
||||
DataDir: filtersDir,
|
||||
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
|
||||
handlers[url] = handler
|
||||
HTTPReg: &aghtest.Registrar{
|
||||
OnRegister: func(_, url string, handler http.HandlerFunc) {
|
||||
handlers[url] = handler
|
||||
},
|
||||
},
|
||||
ParentalEnabled: tc.enabled,
|
||||
}, nil)
|
||||
|
||||
@ -45,7 +45,7 @@ func (d *DNSFilter) handleRewriteList(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}()
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, arr)
|
||||
aghhttp.WriteJSONResponseOK(r.Context(), d.logger, w, r, arr)
|
||||
}
|
||||
|
||||
// handleRewriteAdd is the handler for the POST /control/rewrite/add HTTP API.
|
||||
@ -229,20 +229,22 @@ func (d *DNSFilter) handleRewriteSettings(w http.ResponseWriter, r *http.Request
|
||||
Enabled: protectedBool(d.confMu, &d.conf.RewritesEnabled),
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(r.Context(), d.logger, w, r, resp)
|
||||
}
|
||||
|
||||
// handleRewriteSettingsUpdate is the handler for the PUT
|
||||
// /control/rewrite/settings/update HTTP API.
|
||||
func (d *DNSFilter) handleRewriteSettingsUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
req := &rewriteSettings{}
|
||||
err := json.NewDecoder(r.Body).Decode(req)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
aghhttp.ErrorAndLog(ctx, d.logger, r, w, http.StatusBadRequest, "json.Decode: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
setProtectedBool(d.confMu, &d.conf.RewritesEnabled, req.Enabled)
|
||||
d.conf.ConfModifier.Apply(r.Context())
|
||||
d.conf.ConfModifier.Apply(ctx)
|
||||
}
|
||||
|
||||
@ -261,8 +261,10 @@ func TestDNSFilter_HandleRewriteHTTP(t *testing.T) {
|
||||
d, err := filtering.New(&filtering.Config{
|
||||
Logger: testLogger,
|
||||
ConfModifier: confModifier,
|
||||
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
|
||||
handlers[url] = handler
|
||||
HTTPReg: &aghtest.Registrar{
|
||||
OnRegister: func(_, url string, handler http.HandlerFunc) {
|
||||
handlers[url] = handler
|
||||
},
|
||||
},
|
||||
Rewrites: rewriteEntriesToLegacyRewrites(testRewrites),
|
||||
}, nil)
|
||||
@ -362,8 +364,10 @@ func TestDNSFilter_HandleRewriteSettings(t *testing.T) {
|
||||
d, err := filtering.New(&filtering.Config{
|
||||
Logger: testLogger,
|
||||
ConfModifier: confModifier,
|
||||
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
|
||||
handlers[url] = handler
|
||||
HTTPReg: &aghtest.Registrar{
|
||||
OnRegister: func(_, url string, handler http.HandlerFunc) {
|
||||
handlers[url] = handler
|
||||
},
|
||||
},
|
||||
RewritesEnabled: false,
|
||||
}, nil)
|
||||
|
||||
@ -36,7 +36,7 @@ func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Reques
|
||||
resp = d.conf.SafeSearchConf
|
||||
}()
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(r.Context(), d.logger, w, r, resp)
|
||||
}
|
||||
|
||||
// handleSafeSearchSettings is the handler for PUT /control/safesearch/settings
|
||||
|
||||
@ -267,10 +267,10 @@ func (web *webAPI) handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
// registerAuthHandlers registers authentication handlers.
|
||||
func (web *webAPI) registerAuthHandlers() {
|
||||
web.conf.mux.Handle(
|
||||
"/control/login",
|
||||
web.postInstallHandler(ensure(http.MethodPost, web.handleLogin)),
|
||||
http.MethodPost+" "+"/control/login",
|
||||
web.postInstallHandler(http.HandlerFunc(web.handleLogin)),
|
||||
)
|
||||
httpRegister(http.MethodGet, "/control/logout", web.handleLogout)
|
||||
web.httpReg.Register(http.MethodGet, "/control/logout", web.handleLogout)
|
||||
}
|
||||
|
||||
// isPublicResource returns true if p is a path to a public resource.
|
||||
|
||||
@ -321,8 +321,9 @@ func authRequest(path string, c *http.Cookie, user, pass string) (r *http.Reques
|
||||
func TestAuth_ServeHTTP_firstRun(t *testing.T) {
|
||||
storeGlobals(t)
|
||||
|
||||
mw := &webMw{}
|
||||
mux := http.NewServeMux()
|
||||
globalContext.mux = mux
|
||||
httpReg := aghhttp.NewDefaultRegistrar(mux, mw.wrap)
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
web, err := initWeb(
|
||||
@ -335,12 +336,14 @@ func TestAuth_ServeHTTP_firstRun(t *testing.T) {
|
||||
nil,
|
||||
mux,
|
||||
agh.EmptyConfigModifier{},
|
||||
httpReg,
|
||||
false,
|
||||
true,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
globalContext.web = web
|
||||
mw.set(web)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@ -484,8 +487,9 @@ func TestAuth_ServeHTTP_auth(t *testing.T) {
|
||||
|
||||
t.Cleanup(func() { auth.close(testutil.ContextWithTimeout(t, testTimeout)) })
|
||||
|
||||
mw := &webMw{}
|
||||
baseMux := http.NewServeMux()
|
||||
globalContext.mux = baseMux
|
||||
httpReg := aghhttp.NewDefaultRegistrar(baseMux, mw.wrap)
|
||||
|
||||
tlsMgr, err := newTLSManager(testutil.ContextWithTimeout(t, testTimeout), &tlsManagerConfig{
|
||||
logger: testLogger,
|
||||
@ -504,12 +508,14 @@ func TestAuth_ServeHTTP_auth(t *testing.T) {
|
||||
auth,
|
||||
baseMux,
|
||||
agh.EmptyConfigModifier{},
|
||||
httpReg,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
globalContext.web = web
|
||||
mw.set(web)
|
||||
|
||||
mux := auth.middleware().Wrap(baseMux)
|
||||
|
||||
@ -645,8 +651,9 @@ func TestAuth_ServeHTTP_logout(t *testing.T) {
|
||||
|
||||
t.Cleanup(func() { auth.close(testutil.ContextWithTimeout(t, testTimeout)) })
|
||||
|
||||
mw := &webMw{}
|
||||
baseMux := http.NewServeMux()
|
||||
globalContext.mux = baseMux
|
||||
httpReg := aghhttp.NewDefaultRegistrar(baseMux, mw.wrap)
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
web, err := initWeb(
|
||||
@ -659,12 +666,14 @@ func TestAuth_ServeHTTP_logout(t *testing.T) {
|
||||
auth,
|
||||
baseMux,
|
||||
agh.EmptyConfigModifier{},
|
||||
httpReg,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
globalContext.web = web
|
||||
mw.set(web)
|
||||
|
||||
mux := auth.middleware().Wrap(baseMux)
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
@ -44,6 +45,9 @@ type clientsContainer struct {
|
||||
// nil.
|
||||
confModifier agh.ConfigModifier
|
||||
|
||||
// httpReg registers HTTP handlers. It must not be nil.
|
||||
httpReg aghhttp.Registrar
|
||||
|
||||
// lock protects all fields.
|
||||
//
|
||||
// TODO(a.garipov): Use a pointer and describe which fields are protected in
|
||||
@ -66,9 +70,12 @@ type BlockedClientChecker interface {
|
||||
IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string)
|
||||
}
|
||||
|
||||
// Init initializes clients container
|
||||
// dhcpServer: optional
|
||||
// Note: this function must be called only once
|
||||
// Init initializes the clients container. All arguments must not be nil except
|
||||
// for objects.
|
||||
//
|
||||
// NOTE: This function must be called only once.
|
||||
//
|
||||
// TODO(s.chzhen): Use a configuration structure.
|
||||
func (clients *clientsContainer) Init(
|
||||
ctx context.Context,
|
||||
baseLogger *slog.Logger,
|
||||
@ -79,6 +86,7 @@ func (clients *clientsContainer) Init(
|
||||
filteringConf *filtering.Config,
|
||||
sigHdlr *signalHandler,
|
||||
confModifier agh.ConfigModifier,
|
||||
httpReg aghhttp.Registrar,
|
||||
) (err error) {
|
||||
// TODO(s.chzhen): Refactor it.
|
||||
if clients.storage != nil {
|
||||
@ -90,6 +98,7 @@ func (clients *clientsContainer) Init(
|
||||
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
|
||||
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
|
||||
clients.confModifier = confModifier
|
||||
clients.httpReg = httpReg
|
||||
|
||||
confClients := make([]*client.Persistent, 0, len(objects))
|
||||
for i, o := range objects {
|
||||
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
@ -30,6 +31,7 @@ func newClientsContainer(tb testing.TB) (c *clientsContainer) {
|
||||
},
|
||||
newSignalHandler(testLogger, nil, nil),
|
||||
agh.EmptyConfigModifier{},
|
||||
aghhttp.EmptyRegistrar{},
|
||||
)
|
||||
|
||||
require.NoError(tb, err)
|
||||
|
||||
@ -94,6 +94,7 @@ func whoisOrEmpty(r *client.Runtime) (info *whois.Info) {
|
||||
|
||||
// handleGetClients is the handler for GET /control/clients HTTP API.
|
||||
func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
data := clientListJSON{}
|
||||
|
||||
clients.lock.Lock()
|
||||
@ -106,7 +107,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
return true
|
||||
})
|
||||
|
||||
clients.storage.UpdateDHCP(r.Context())
|
||||
clients.storage.UpdateDHCP(ctx)
|
||||
|
||||
clients.storage.RangeRuntime(func(rc *client.Runtime) (cont bool) {
|
||||
src, host := rc.Info()
|
||||
@ -124,7 +125,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
|
||||
|
||||
data.Tags = clients.storage.AllowedTags()
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, data)
|
||||
aghhttp.WriteJSONResponseOK(ctx, clients.logger, w, r, data)
|
||||
}
|
||||
|
||||
// initPrev initializes the persistent client with the default or previous
|
||||
@ -454,6 +455,9 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||
//
|
||||
// Deprecated: Remove it when migration to the new API is over.
|
||||
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
l := clients.logger
|
||||
|
||||
q := r.URL.Query()
|
||||
data := make([]map[string]*clientJSON, 0, len(q))
|
||||
params := &client.FindParams{}
|
||||
@ -467,12 +471,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
|
||||
err = params.Set(idStr)
|
||||
if err != nil {
|
||||
clients.logger.DebugContext(
|
||||
r.Context(),
|
||||
"finding client",
|
||||
"id", idStr,
|
||||
slogutil.KeyError, err,
|
||||
)
|
||||
l.DebugContext(ctx, "finding client", "id", idStr, slogutil.KeyError, err)
|
||||
|
||||
continue
|
||||
}
|
||||
@ -482,7 +481,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
|
||||
})
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, data)
|
||||
aghhttp.WriteJSONResponseOK(ctx, l, w, r, data)
|
||||
}
|
||||
|
||||
// findClient returns available information about a client by params from the
|
||||
@ -530,13 +529,14 @@ type searchClientJSON struct {
|
||||
// API.
|
||||
func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
l := clients.logger
|
||||
|
||||
q := searchQueryJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&q)
|
||||
if err != nil {
|
||||
aghhttp.ErrorAndLog(
|
||||
ctx,
|
||||
clients.logger,
|
||||
l,
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
@ -554,12 +554,7 @@ func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *ht
|
||||
idStr := c.ID
|
||||
err = params.Set(idStr)
|
||||
if err != nil {
|
||||
clients.logger.DebugContext(
|
||||
ctx,
|
||||
"searching client",
|
||||
"id", idStr,
|
||||
slogutil.KeyError, err,
|
||||
)
|
||||
l.DebugContext(ctx, "searching client", "id", idStr, slogutil.KeyError, err)
|
||||
|
||||
continue
|
||||
}
|
||||
@ -569,7 +564,7 @@ func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *ht
|
||||
})
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, data)
|
||||
aghhttp.WriteJSONResponseOK(ctx, l, w, r, data)
|
||||
}
|
||||
|
||||
// findRuntime looks up the IP in runtime and temporary storages, like
|
||||
@ -613,14 +608,14 @@ func (clients *clientsContainer) findRuntime(
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterClientsHandlers registers HTTP handlers
|
||||
// registerWebHandlers registers HTTP handlers.
|
||||
func (clients *clientsContainer) registerWebHandlers() {
|
||||
httpRegister(http.MethodGet, "/control/clients", clients.handleGetClients)
|
||||
httpRegister(http.MethodPost, "/control/clients/add", clients.handleAddClient)
|
||||
httpRegister(http.MethodPost, "/control/clients/delete", clients.handleDelClient)
|
||||
httpRegister(http.MethodPost, "/control/clients/update", clients.handleUpdateClient)
|
||||
httpRegister(http.MethodPost, "/control/clients/search", clients.handleSearchClient)
|
||||
clients.httpReg.Register(http.MethodGet, "/control/clients", clients.handleGetClients)
|
||||
clients.httpReg.Register(http.MethodPost, "/control/clients/add", clients.handleAddClient)
|
||||
clients.httpReg.Register(http.MethodPost, "/control/clients/delete", clients.handleDelClient)
|
||||
clients.httpReg.Register(http.MethodPost, "/control/clients/update", clients.handleUpdateClient)
|
||||
clients.httpReg.Register(http.MethodPost, "/control/clients/search", clients.handleSearchClient)
|
||||
|
||||
// Deprecated handler.
|
||||
httpRegister(http.MethodGet, "/control/clients/find", clients.handleFindClient)
|
||||
clients.httpReg.Register(http.MethodGet, "/control/clients/find", clients.handleFindClient)
|
||||
}
|
||||
|
||||
@ -116,12 +116,13 @@ type statusResponse struct {
|
||||
|
||||
func (web *webAPI) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
l := web.logger
|
||||
|
||||
dnsAddrs, err := collectDNSAddresses(web.tlsManager)
|
||||
if err != nil {
|
||||
// Don't add a lot of formatting, since the error is already
|
||||
// wrapped by collectDNSAddresses.
|
||||
aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusInternalServerError, "%s", err)
|
||||
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusInternalServerError, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@ -167,7 +168,7 @@ func (web *webAPI) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
resp.IsDHCPAvailable = globalContext.dhcpServer != nil
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(ctx, l, w, r, resp)
|
||||
}
|
||||
|
||||
// registerControlHandlers sets up HTTP handlers for various control endpoints.
|
||||
@ -178,13 +179,21 @@ func (web *webAPI) registerControlHandlers() {
|
||||
"/control/version.json",
|
||||
web.postInstallHandler(http.HandlerFunc(web.handleVersionJSON)),
|
||||
)
|
||||
httpRegister(http.MethodPost, "/control/update", web.handleUpdate)
|
||||
web.httpReg.Register(http.MethodPost, "/control/update", web.handleUpdate)
|
||||
|
||||
httpRegister(http.MethodGet, "/control/status", web.handleStatus)
|
||||
httpRegister(http.MethodPost, "/control/i18n/change_language", web.handleI18nChangeLanguage)
|
||||
httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage)
|
||||
httpRegister(http.MethodGet, "/control/profile", web.handleGetProfile)
|
||||
httpRegister(http.MethodPut, "/control/profile/update", web.handlePutProfile)
|
||||
web.httpReg.Register(http.MethodGet, "/control/status", web.handleStatus)
|
||||
web.httpReg.Register(
|
||||
http.MethodPost,
|
||||
"/control/i18n/change_language",
|
||||
web.handleI18nChangeLanguage,
|
||||
)
|
||||
web.httpReg.Register(
|
||||
http.MethodGet,
|
||||
"/control/i18n/current_language",
|
||||
web.handleI18nCurrentLanguage,
|
||||
)
|
||||
web.httpReg.Register(http.MethodGet, "/control/profile", web.handleGetProfile)
|
||||
web.httpReg.Register(http.MethodPut, "/control/profile/update", web.handlePutProfile)
|
||||
|
||||
// No authentication is required for DoH/DoT configuration endpoints.
|
||||
mux.Handle(
|
||||
@ -199,38 +208,71 @@ func (web *webAPI) registerControlHandlers() {
|
||||
web.registerAuthHandlers()
|
||||
}
|
||||
|
||||
// httpRegister registers an HTTP handler.
|
||||
//
|
||||
// TODO(s.chzhen): Do not use [globalContext.mux].
|
||||
func httpRegister(method, url string, handler http.HandlerFunc) {
|
||||
if method == "" {
|
||||
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
|
||||
globalContext.mux.Handle(url, postInstallHandler(handler))
|
||||
return
|
||||
}
|
||||
// webMw provides middleware for route handlers. The set method must be called
|
||||
// to initialize the middleware.
|
||||
type webMw struct {
|
||||
// postInstallMw is middleware that verifies that AdGuard Home is not
|
||||
// running for the first time.
|
||||
postInstallMw func(h http.Handler) (wrapped http.Handler)
|
||||
|
||||
globalContext.mux.Handle(
|
||||
url,
|
||||
postInstallHandler(gziphandler.GzipHandler(ensure(method, handler))),
|
||||
)
|
||||
// ensureMw is like postInstallMw, but also applies gzip and enforces the
|
||||
// HTTP method.
|
||||
ensureMw aghhttp.WrapFunc
|
||||
}
|
||||
|
||||
// ensure returns a wrapped handler that makes sure that the request has the
|
||||
// correct method as well as additional method and header checks.
|
||||
func ensure(
|
||||
// set sets the middleware functions used to build handler chains.
|
||||
func (mw *webMw) set(web *webAPI) {
|
||||
mw.postInstallMw = web.postInstallHandler
|
||||
|
||||
mw.ensureMw = func(method string, h http.HandlerFunc) (wrapped http.Handler) {
|
||||
return web.postInstallHandler(gziphandler.GzipHandler(web.ensure(method, h)))
|
||||
}
|
||||
}
|
||||
|
||||
// wrap returns a wrapped HTTP handler for the given route.
|
||||
//
|
||||
// TODO(s.chzhen): Implement [httputil.Middleware].
|
||||
func (mw *webMw) wrap(method string, h http.HandlerFunc) (wrapped http.Handler) {
|
||||
f := func(w http.ResponseWriter, r *http.Request) {
|
||||
var handler http.Handler
|
||||
if method == "" {
|
||||
// The "/dns-query" handler doesn't require authentication or gzip,
|
||||
// and it isn't restricted to a single HTTP method.
|
||||
handler = mw.postInstallMw(h)
|
||||
} else {
|
||||
handler = mw.ensureMw(method, h)
|
||||
}
|
||||
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
return http.HandlerFunc(f)
|
||||
}
|
||||
|
||||
// ensure returns a wrapped handler that verifies the request method. It also
|
||||
// performs additional method and header checks.
|
||||
func (web *webAPI) ensure(
|
||||
method string,
|
||||
handler func(http.ResponseWriter, *http.Request),
|
||||
) (wrapped http.HandlerFunc) {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
m := r.Method
|
||||
if m != method {
|
||||
aghhttp.Error(r, w, http.StatusMethodNotAllowed, "only method %s is allowed", method)
|
||||
aghhttp.ErrorAndLog(
|
||||
r.Context(),
|
||||
web.logger,
|
||||
r,
|
||||
w,
|
||||
http.StatusMethodNotAllowed,
|
||||
"only method %s is allowed",
|
||||
method,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if modifiesData(m) {
|
||||
if !ensureContentType(w, r) {
|
||||
if !web.ensureContentType(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
@ -250,9 +292,11 @@ func modifiesData(m string) (ok bool) {
|
||||
// ensureContentType makes sure that the content type of a data-modifying
|
||||
// request is set correctly. If it is not, ensureContentType writes a response
|
||||
// to w, and ok is false.
|
||||
func ensureContentType(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||
func (web *webAPI) ensureContentType(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||
const statusUnsup = http.StatusUnsupportedMediaType
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
cType := r.Header.Get(httphdr.ContentType)
|
||||
if r.ContentLength == 0 {
|
||||
if cType == "" {
|
||||
@ -262,7 +306,15 @@ func ensureContentType(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||
// Assume that browsers always send a content type when submitting HTML
|
||||
// forms and require no content type for requests with no body to make
|
||||
// sure that the request comes from JavaScript.
|
||||
aghhttp.Error(r, w, statusUnsup, "empty body with content-type %q not allowed", cType)
|
||||
aghhttp.ErrorAndLog(
|
||||
ctx,
|
||||
web.logger,
|
||||
r,
|
||||
w,
|
||||
statusUnsup,
|
||||
"empty body with content-type %q not allowed",
|
||||
cType,
|
||||
)
|
||||
|
||||
return false
|
||||
|
||||
@ -273,7 +325,15 @@ func ensureContentType(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||
return true
|
||||
}
|
||||
|
||||
aghhttp.Error(r, w, statusUnsup, "only content-type %s is allowed", wantCType)
|
||||
aghhttp.ErrorAndLog(
|
||||
ctx,
|
||||
web.logger,
|
||||
r,
|
||||
w,
|
||||
statusUnsup,
|
||||
"only content-type %s is allowed",
|
||||
wantCType,
|
||||
)
|
||||
|
||||
return false
|
||||
}
|
||||
@ -297,15 +357,16 @@ func (web *webAPI) preInstallHandler(handler http.Handler) (wrapped http.Handler
|
||||
// handleHTTPSRedirect redirects the request to HTTPS, if needed, and adds some
|
||||
// HTTPS-related headers. If proceed is true, the middleware must continue
|
||||
// handling the request.
|
||||
func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
web := globalContext.web
|
||||
func (web *webAPI) handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
if web.httpsServer.server == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
host, err := netutil.SplitHost(r.Host)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "bad host: %s", err)
|
||||
aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusBadRequest, "bad host: %s", err)
|
||||
|
||||
return false
|
||||
}
|
||||
@ -381,36 +442,6 @@ func httpsURL(u *url.URL, host string, portHTTPS uint16) (redirectURL *url.URL)
|
||||
}
|
||||
}
|
||||
|
||||
// postInstallHandler lets the handler to run only if firstRun is false.
|
||||
// Otherwise, it redirects to /install.html. It also enforces HTTPS if it is
|
||||
// enabled and configured and sets appropriate access control headers.
|
||||
//
|
||||
// TODO(s.chzhen): Replace with [web.postInstall] after fixing its usage in
|
||||
// [httpRegister], which is called by [dhcpd.Create] before [web] is
|
||||
// initialized.
|
||||
func postInstallHandler(handler http.Handler) (wrapped http.Handler) {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if globalContext.web == nil {
|
||||
aghhttp.Error(r, w, http.StatusTooEarly, "it is not initialized yet")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
path := r.URL.Path
|
||||
if globalContext.web.conf.firstRun &&
|
||||
!strings.HasPrefix(path, "/install.") &&
|
||||
!strings.HasPrefix(path, "/assets/") {
|
||||
http.Redirect(w, r, "install.html", http.StatusFound)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if handleHTTPSRedirect(w, r) {
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// postInstallHandler lets the handler to run only if firstRun is false.
|
||||
// Otherwise, it redirects to /install.html. It also enforces HTTPS if it is
|
||||
// enabled and configured and sets appropriate access control headers.
|
||||
@ -425,7 +456,7 @@ func (web *webAPI) postInstallHandler(handler http.Handler) (wrapped http.Handle
|
||||
return
|
||||
}
|
||||
|
||||
if handleHTTPSRedirect(w, r) {
|
||||
if web.handleHTTPSRedirect(w, r) {
|
||||
handler.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
|
||||
@ -44,6 +44,7 @@ type getAddrsResponse struct {
|
||||
// handleInstallGetAddresses is the handler for /install/get_addresses endpoint.
|
||||
func (web *webAPI) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
l := web.logger
|
||||
|
||||
data := getAddrsResponse{
|
||||
Version: version.Version(),
|
||||
@ -56,7 +57,7 @@ func (web *webAPI) handleInstallGetAddresses(w http.ResponseWriter, r *http.Requ
|
||||
if err != nil {
|
||||
aghhttp.ErrorAndLog(
|
||||
ctx,
|
||||
web.logger,
|
||||
l,
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
@ -72,7 +73,7 @@ func (web *webAPI) handleInstallGetAddresses(w http.ResponseWriter, r *http.Requ
|
||||
data.Interfaces[iface.Name] = iface
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, data)
|
||||
aghhttp.WriteJSONResponseOK(ctx, l, w, r, data)
|
||||
}
|
||||
|
||||
type checkConfReqEnt struct {
|
||||
@ -186,20 +187,13 @@ func (req *checkConfReq) validateDNS(
|
||||
// handleInstallCheckConfig handles the /check_config endpoint.
|
||||
func (web *webAPI) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
l := web.logger
|
||||
|
||||
req := &checkConfReq{}
|
||||
|
||||
err := json.NewDecoder(r.Body).Decode(req)
|
||||
if err != nil {
|
||||
aghhttp.ErrorAndLog(
|
||||
ctx,
|
||||
web.logger,
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
"decoding the request: %s",
|
||||
err,
|
||||
)
|
||||
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "decoding the request: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@ -210,14 +204,14 @@ func (web *webAPI) handleInstallCheckConfig(w http.ResponseWriter, r *http.Reque
|
||||
resp.Web.Status = err.Error()
|
||||
}
|
||||
|
||||
resp.DNS.CanAutofix, err = req.validateDNS(ctx, web.logger, tcpPorts, web.cmdCons)
|
||||
resp.DNS.CanAutofix, err = req.validateDNS(ctx, l, tcpPorts, web.cmdCons)
|
||||
if err != nil {
|
||||
resp.DNS.Status = err.Error()
|
||||
} else if !req.DNS.IP.IsUnspecified() {
|
||||
resp.StaticIP = handleStaticIP(ctx, web.logger, req.DNS.IP, req.SetStaticIP, web.cmdCons)
|
||||
resp.StaticIP = handleStaticIP(ctx, l, req.DNS.IP, req.SetStaticIP, web.cmdCons)
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(ctx, l, w, r, resp)
|
||||
}
|
||||
|
||||
// handleStaticIP checks and optionally sets a static IP on the interface that
|
||||
@ -430,10 +424,11 @@ const PasswordMinRunes = 8
|
||||
// Apply new configuration, start DNS server, restart Web server
|
||||
func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
l := web.logger
|
||||
|
||||
req, restartHTTP, err := decodeApplyConfigReq(r.Body)
|
||||
if err != nil {
|
||||
aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusBadRequest, "%s", err)
|
||||
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@ -441,7 +436,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
if utf8.RuneCountInString(req.Password) < PasswordMinRunes {
|
||||
aghhttp.ErrorAndLog(
|
||||
ctx,
|
||||
web.logger,
|
||||
l,
|
||||
r,
|
||||
w,
|
||||
http.StatusUnprocessableEntity,
|
||||
@ -454,14 +449,14 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
|
||||
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, req.DNS.Port))
|
||||
if err != nil {
|
||||
aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusBadRequest, "%s", err)
|
||||
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.DNS.IP, req.DNS.Port))
|
||||
if err != nil {
|
||||
aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusBadRequest, "%s", err)
|
||||
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@ -478,6 +473,8 @@ func (web *webAPI) finalizeInstall(
|
||||
req *applyConfigReq,
|
||||
restartHTTP bool,
|
||||
) {
|
||||
l := web.logger
|
||||
|
||||
var err error
|
||||
curConfig := &configuration{}
|
||||
copyInstallSettings(curConfig, config)
|
||||
@ -501,7 +498,7 @@ func (web *webAPI) finalizeInstall(
|
||||
}
|
||||
err = web.auth.addUser(ctx, u, req.Password)
|
||||
if err != nil {
|
||||
aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusUnprocessableEntity, "%s", err)
|
||||
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusUnprocessableEntity, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@ -510,9 +507,9 @@ func (web *webAPI) finalizeInstall(
|
||||
// moment we'll allow setting up TLS in the initial configuration or the
|
||||
// configuration itself will use HTTPS protocol, because the underlying
|
||||
// functions potentially restart the HTTPS server.
|
||||
err = startMods(ctx, web.baseLogger, web.tlsManager, web.confModifier)
|
||||
err = startMods(ctx, web.baseLogger, web.tlsManager, web.confModifier, web.httpReg)
|
||||
if err != nil {
|
||||
aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusInternalServerError, "%s", err)
|
||||
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusInternalServerError, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@ -521,7 +518,7 @@ func (web *webAPI) finalizeInstall(
|
||||
if err != nil {
|
||||
aghhttp.ErrorAndLog(
|
||||
ctx,
|
||||
web.logger,
|
||||
l,
|
||||
r,
|
||||
w,
|
||||
http.StatusInternalServerError,
|
||||
@ -537,12 +534,12 @@ func (web *webAPI) finalizeInstall(
|
||||
|
||||
web.registerControlHandlers()
|
||||
|
||||
aghhttp.OK(ctx, web.logger, w)
|
||||
aghhttp.OK(ctx, l, w)
|
||||
|
||||
rc := http.NewResponseController(w)
|
||||
err = rc.Flush()
|
||||
if err != nil {
|
||||
web.logger.WarnContext(ctx, "flushing response", slogutil.KeyError, err)
|
||||
l.WarnContext(ctx, "flushing response", slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
if !restartHTTP {
|
||||
@ -554,10 +551,10 @@ func (web *webAPI) finalizeInstall(
|
||||
// and will be blocked by it's own caller.
|
||||
go func(timeout time.Duration) {
|
||||
shutdownCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), timeout)
|
||||
defer slogutil.RecoverAndLog(shutdownCtx, web.logger)
|
||||
defer slogutil.RecoverAndLog(shutdownCtx, l)
|
||||
defer cancel()
|
||||
|
||||
shutdownSrv(shutdownCtx, web.logger, web.httpServer)
|
||||
shutdownSrv(shutdownCtx, l, web.httpServer)
|
||||
}(shutdownTimeout)
|
||||
}
|
||||
|
||||
@ -593,19 +590,20 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
|
||||
}
|
||||
|
||||
// startMods initializes and starts the DNS server after installation.
|
||||
// baseLogger and tlsMgr must not be nil.
|
||||
// baseLogger, tlsMgr, confModifier, and httpReg must not be nil.
|
||||
func startMods(
|
||||
ctx context.Context,
|
||||
baseLogger *slog.Logger,
|
||||
tlsMgr *tlsManager,
|
||||
confModifier agh.ConfigModifier,
|
||||
httpReg aghhttp.Registrar,
|
||||
) (err error) {
|
||||
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&globalContext, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = initDNS(ctx, baseLogger, tlsMgr, confModifier, statsDir, querylogDir)
|
||||
err = initDNS(ctx, baseLogger, tlsMgr, confModifier, httpReg, statsDir, querylogDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -627,15 +625,15 @@ func (web *webAPI) registerInstallHandlers() {
|
||||
mux := web.conf.mux
|
||||
|
||||
mux.Handle(
|
||||
"/control/install/get_addresses",
|
||||
web.preInstallHandler(ensure(http.MethodGet, web.handleInstallGetAddresses)),
|
||||
http.MethodGet+" "+"/control/install/get_addresses",
|
||||
web.preInstallHandler(http.HandlerFunc(web.handleInstallGetAddresses)),
|
||||
)
|
||||
mux.Handle(
|
||||
"/control/install/check_config",
|
||||
web.preInstallHandler(ensure(http.MethodPost, web.handleInstallCheckConfig)),
|
||||
web.preInstallHandler(web.ensure(http.MethodPost, web.handleInstallCheckConfig)),
|
||||
)
|
||||
mux.Handle(
|
||||
"/control/install/configure",
|
||||
web.preInstallHandler(ensure(http.MethodPost, web.handleInstallConfigure)),
|
||||
web.preInstallHandler(web.ensure(http.MethodPost, web.handleInstallConfigure)),
|
||||
)
|
||||
}
|
||||
|
||||
@ -33,11 +33,12 @@ type temporaryError interface {
|
||||
// TODO(a.garipov): Find out if this API used with a GET method by anyone.
|
||||
func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
l := web.logger
|
||||
|
||||
resp := &versionResponse{}
|
||||
if web.conf.disableUpdate {
|
||||
resp.Disabled = true
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(ctx, l, w, r, resp)
|
||||
|
||||
return
|
||||
}
|
||||
@ -50,15 +51,7 @@ func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
if r.ContentLength != 0 {
|
||||
err = json.NewDecoder(r.Body).Decode(req)
|
||||
if err != nil {
|
||||
aghhttp.ErrorAndLog(
|
||||
ctx,
|
||||
web.logger,
|
||||
r,
|
||||
w,
|
||||
http.StatusBadRequest,
|
||||
"parsing request: %s",
|
||||
err,
|
||||
)
|
||||
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "parsing request: %s", err)
|
||||
|
||||
return
|
||||
}
|
||||
@ -67,20 +60,20 @@ func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) {
|
||||
err = web.requestVersionInfo(ctx, resp, req.Recheck)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusBadGateway, "%s", err)
|
||||
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadGateway, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = resp.setAllowedToAutoUpdate(ctx, web.logger, web.tlsManager)
|
||||
err = resp.setAllowedToAutoUpdate(ctx, l, web.tlsManager)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusInternalServerError, "%s", err)
|
||||
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusInternalServerError, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(ctx, l, w, r, resp)
|
||||
}
|
||||
|
||||
// requestVersionInfo sets the VersionInfo field of resp if it can reach the
|
||||
|
||||
@ -41,13 +41,14 @@ const (
|
||||
|
||||
// initDNS updates all the fields of the [globalContext] needed to initialize
|
||||
// the DNS server and initializes it at last. It also must not be called unless
|
||||
// [config] and [globalContext] are initialized. baseLogger, tlsMgr and
|
||||
// confModfier must not be nil.
|
||||
// [config] and [globalContext] are initialized. baseLogger, tlsMgr,
|
||||
// confModifier, and httpReg must not be nil.
|
||||
func initDNS(
|
||||
ctx context.Context,
|
||||
baseLogger *slog.Logger,
|
||||
tlsMgr *tlsManager,
|
||||
confModifier agh.ConfigModifier,
|
||||
httpReg aghhttp.Registrar,
|
||||
statsDir string,
|
||||
querylogDir string,
|
||||
) (err error) {
|
||||
@ -58,7 +59,7 @@ func initDNS(
|
||||
Filename: filepath.Join(statsDir, "stats.db"),
|
||||
Limit: time.Duration(config.Stats.Interval),
|
||||
ConfigModifier: confModifier,
|
||||
HTTPRegister: httpRegister,
|
||||
HTTPReg: httpReg,
|
||||
Enabled: config.Stats.Enabled,
|
||||
ShouldCountClient: globalContext.clients.shouldCountClient,
|
||||
}
|
||||
@ -78,7 +79,7 @@ func initDNS(
|
||||
Logger: baseLogger.With(slogutil.KeyPrefix, "querylog"),
|
||||
Anonymizer: anonymizer,
|
||||
ConfigModifier: confModifier,
|
||||
HTTPRegister: httpRegister,
|
||||
HTTPReg: httpReg,
|
||||
FindClient: globalContext.clients.findMultiple,
|
||||
BaseDir: querylogDir,
|
||||
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
|
||||
@ -112,7 +113,7 @@ func initDNS(
|
||||
globalContext.queryLog,
|
||||
globalContext.dhcpServer,
|
||||
anonymizer,
|
||||
httpRegister,
|
||||
httpReg,
|
||||
tlsMgr,
|
||||
baseLogger,
|
||||
confModifier,
|
||||
@ -132,7 +133,7 @@ func initDNSServer(
|
||||
qlog querylog.QueryLog,
|
||||
dhcpSrv dnsforward.DHCP,
|
||||
anonymizer *aghnet.IPMut,
|
||||
httpReg aghhttp.RegisterFunc,
|
||||
httpReg aghhttp.Registrar,
|
||||
tlsMgr *tlsManager,
|
||||
l *slog.Logger,
|
||||
confModifier agh.ConfigModifier,
|
||||
@ -235,13 +236,13 @@ func ipsToUDPAddrs(ips []netip.Addr, port uint16) (udpAddrs []*net.UDPAddr) {
|
||||
}
|
||||
|
||||
// newServerConfig converts values from the configuration file into the internal
|
||||
// DNS server configuration. All arguments must not be nil, except for httpReg.
|
||||
// DNS server configuration. All arguments must not be nil.
|
||||
func newServerConfig(
|
||||
dnsConf *dnsConfig,
|
||||
clientSrcConf *clientSourcesConfig,
|
||||
tlsConf *tlsConfigSettings,
|
||||
tlsMgr *tlsManager,
|
||||
httpReg aghhttp.RegisterFunc,
|
||||
httpReg aghhttp.Registrar,
|
||||
clientsContainer dnsforward.ClientsContainer,
|
||||
confModifier agh.ConfigModifier,
|
||||
) (newConf *dnsforward.ServerConfig, err error) {
|
||||
@ -264,7 +265,7 @@ func newServerConfig(
|
||||
UpstreamTimeout: time.Duration(dnsConf.UpstreamTimeout),
|
||||
TLSv12Roots: tlsMgr.rootCerts,
|
||||
ConfModifier: confModifier,
|
||||
HTTPRegister: httpReg,
|
||||
HTTPReg: httpReg,
|
||||
LocalPTRResolvers: dnsConf.PrivateRDNSResolvers,
|
||||
UseDNS64: dnsConf.UseDNS64,
|
||||
DNS64Prefixes: dnsConf.DNS64Prefixes,
|
||||
|
||||
@ -21,6 +21,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghslog"
|
||||
@ -65,9 +66,6 @@ type homeContext struct {
|
||||
// configuration files, for example /etc/hosts.
|
||||
etcHosts *aghnet.HostsContainer
|
||||
|
||||
// mux is our custom http.ServeMux.
|
||||
mux *http.ServeMux
|
||||
|
||||
// Runtime properties
|
||||
// --
|
||||
|
||||
@ -150,8 +148,6 @@ func setupContext(
|
||||
opts options,
|
||||
isFirstRun bool,
|
||||
) (err error) {
|
||||
globalContext.mux = http.NewServeMux()
|
||||
|
||||
if !opts.noEtcHosts {
|
||||
err = setupHostsContainer(ctx, baseLogger)
|
||||
if err != nil {
|
||||
@ -302,11 +298,12 @@ func initContextClients(
|
||||
logger *slog.Logger,
|
||||
sigHdlr *signalHandler,
|
||||
confModifier agh.ConfigModifier,
|
||||
httpReg aghhttp.Registrar,
|
||||
) (err error) {
|
||||
//lint:ignore SA1019 Migration is not over.
|
||||
config.DHCP.WorkDir = globalContext.workDir
|
||||
config.DHCP.DataDir = globalContext.getDataDir()
|
||||
config.DHCP.HTTPRegister = httpRegister
|
||||
config.DHCP.HTTPReg = httpReg
|
||||
config.DHCP.CommandConstructor = executil.SystemCommandConstructor{}
|
||||
config.DHCP.Logger = logger.With(slogutil.KeyPrefix, "dhcpd")
|
||||
config.DHCP.ConfModifier = confModifier
|
||||
@ -335,6 +332,7 @@ func initContextClients(
|
||||
config.Filtering,
|
||||
sigHdlr,
|
||||
confModifier,
|
||||
httpReg,
|
||||
)
|
||||
}
|
||||
|
||||
@ -386,6 +384,7 @@ func setupDNSFilteringConf(
|
||||
conf *filtering.Config,
|
||||
tlsMgr *tlsManager,
|
||||
confModifier agh.ConfigModifier,
|
||||
httpReg aghhttp.Registrar,
|
||||
) (err error) {
|
||||
const (
|
||||
dnsTimeout = 3 * time.Second
|
||||
@ -408,7 +407,7 @@ func setupDNSFilteringConf(
|
||||
}
|
||||
|
||||
conf.ConfModifier = confModifier
|
||||
conf.HTTPRegister = httpRegister
|
||||
conf.HTTPReg = httpReg
|
||||
conf.DataDir = globalContext.getDataDir()
|
||||
conf.Filters = slices.Clone(config.Filters)
|
||||
conf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
|
||||
@ -563,8 +562,7 @@ func isUpdateEnabled(
|
||||
}
|
||||
}
|
||||
|
||||
// initWeb initializes the web module. upd, baseLogger, tlsMgr, auth, and mux
|
||||
// must not be nil.
|
||||
// initWeb initializes the web module. All arguments must not be nil.
|
||||
func initWeb(
|
||||
ctx context.Context,
|
||||
opts options,
|
||||
@ -575,6 +573,7 @@ func initWeb(
|
||||
auth *auth,
|
||||
mux *http.ServeMux,
|
||||
confModifier agh.ConfigModifier,
|
||||
httpReg aghhttp.Registrar,
|
||||
isCustomUpdURL bool,
|
||||
isFirstRun bool,
|
||||
) (web *webAPI, err error) {
|
||||
@ -602,6 +601,7 @@ func initWeb(
|
||||
logger: logger,
|
||||
baseLogger: baseLogger,
|
||||
confModifier: confModifier,
|
||||
httpReg: httpReg,
|
||||
tlsManager: tlsMgr,
|
||||
auth: auth,
|
||||
mux: mux,
|
||||
@ -718,7 +718,11 @@ func run(
|
||||
slogLogger.With(slogutil.KeyPrefix, "config_modifier"),
|
||||
)
|
||||
|
||||
err = initContextClients(ctx, slogLogger, sigHdlr, confModifier)
|
||||
mw := &webMw{}
|
||||
mux := http.NewServeMux()
|
||||
httpReg := aghhttp.NewDefaultRegistrar(mux, mw.wrap)
|
||||
|
||||
err = initContextClients(ctx, slogLogger, sigHdlr, confModifier, httpReg)
|
||||
fatalOnError(err)
|
||||
|
||||
tlsMgrLogger := slogLogger.With(slogutil.KeyPrefix, "tls_manager")
|
||||
@ -726,6 +730,7 @@ func run(
|
||||
tlsMgr, err := newTLSManager(ctx, &tlsManagerConfig{
|
||||
logger: tlsMgrLogger,
|
||||
confModifier: confModifier,
|
||||
httpReg: httpReg,
|
||||
tlsSettings: config.TLS,
|
||||
servePlainDNS: config.DNS.ServePlainDNS,
|
||||
})
|
||||
@ -736,7 +741,14 @@ func run(
|
||||
|
||||
confModifier.setTLSManager(tlsMgr)
|
||||
|
||||
err = setupDNSFilteringConf(ctx, slogLogger, config.Filtering, tlsMgr, confModifier)
|
||||
err = setupDNSFilteringConf(
|
||||
ctx,
|
||||
slogLogger,
|
||||
config.Filtering,
|
||||
tlsMgr,
|
||||
confModifier,
|
||||
httpReg,
|
||||
)
|
||||
fatalOnError(err)
|
||||
|
||||
err = setupOpts(opts)
|
||||
@ -788,13 +800,16 @@ func run(
|
||||
slogLogger,
|
||||
tlsMgr,
|
||||
auth,
|
||||
globalContext.mux,
|
||||
mux,
|
||||
confModifier,
|
||||
httpReg,
|
||||
isCustomURL,
|
||||
isFirstRun,
|
||||
)
|
||||
fatalOnError(err)
|
||||
|
||||
mw.set(web)
|
||||
|
||||
globalContext.web = web
|
||||
|
||||
tlsMgr.setWebAPI(web)
|
||||
@ -804,7 +819,7 @@ func run(
|
||||
fatalOnError(err)
|
||||
|
||||
if !isFirstRun {
|
||||
err = initDNS(ctx, slogLogger, tlsMgr, confModifier, statsDir, querylogDir)
|
||||
err = initDNS(ctx, slogLogger, tlsMgr, confModifier, httpReg, statsDir, querylogDir)
|
||||
fatalOnError(err)
|
||||
|
||||
tlsMgr.start(ctx)
|
||||
|
||||
@ -6,7 +6,6 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// TODO(a.garipov): Get rid of a global or generate from .twosky.json.
|
||||
@ -54,15 +53,24 @@ type languageJSON struct {
|
||||
Language string `json:"language"`
|
||||
}
|
||||
|
||||
// handleI18nCurrentLanguage is the handler for the GET
|
||||
// /control/i18n/current_language HTTP API.
|
||||
//
|
||||
// TODO(d.kolyshev): Deprecated, remove it later.
|
||||
func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("home: language is %s", config.Language)
|
||||
func (web *webAPI) handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
l := web.logger
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, &languageJSON{
|
||||
l.InfoContext(ctx, "current language", "lang", config.Language)
|
||||
|
||||
aghhttp.WriteJSONResponseOK(ctx, l, w, r, &languageJSON{
|
||||
Language: config.Language,
|
||||
})
|
||||
}
|
||||
|
||||
// handleI18nChangeLanguage is the handler for the POST
|
||||
// /control/i18n/change_language HTTP API.
|
||||
//
|
||||
// TODO(d.kolyshev): Deprecated, remove it later.
|
||||
func (web *webAPI) handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
@ -46,10 +46,12 @@ type profileJSON struct {
|
||||
|
||||
// handleGetProfile is the handler for GET /control/profile endpoint.
|
||||
func (web *webAPI) handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
var name string
|
||||
|
||||
if !web.auth.isGLiNet && !web.auth.isUserless {
|
||||
u, ok := webUserFromContext(r.Context())
|
||||
u, ok := webUserFromContext(ctx)
|
||||
if !ok {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
|
||||
@ -71,7 +73,7 @@ func (web *webAPI) handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}()
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(ctx, web.logger, w, r, resp)
|
||||
}
|
||||
|
||||
// handlePutProfile is the handler for PUT /control/profile/update endpoint.
|
||||
|
||||
@ -72,7 +72,6 @@ func TestWeb_HandleGetProfile(t *testing.T) {
|
||||
t.Cleanup(func() { auth.close(testutil.ContextWithTimeout(t, testTimeout)) })
|
||||
|
||||
baseMux := http.NewServeMux()
|
||||
globalContext.mux = baseMux
|
||||
|
||||
tlsMgr, err := newTLSManager(testutil.ContextWithTimeout(t, testTimeout), &tlsManagerConfig{
|
||||
logger: testLogger,
|
||||
@ -90,6 +89,7 @@ func TestWeb_HandleGetProfile(t *testing.T) {
|
||||
auth,
|
||||
baseMux,
|
||||
agh.EmptyConfigModifier{},
|
||||
aghhttp.EmptyRegistrar{},
|
||||
false,
|
||||
false,
|
||||
)
|
||||
@ -132,8 +132,9 @@ func TestWeb_HandleGetProfile(t *testing.T) {
|
||||
func TestWeb_HandlePutProfile(t *testing.T) {
|
||||
storeGlobals(t)
|
||||
|
||||
mw := &webMw{}
|
||||
mux := http.NewServeMux()
|
||||
globalContext.mux = mux
|
||||
httpReg := aghhttp.NewDefaultRegistrar(mux, mw.wrap)
|
||||
|
||||
isConfigChanged := false
|
||||
confModifier := &aghtest.ConfigModifier{
|
||||
@ -151,12 +152,14 @@ func TestWeb_HandlePutProfile(t *testing.T) {
|
||||
nil,
|
||||
mux,
|
||||
confModifier,
|
||||
httpReg,
|
||||
false,
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
globalContext.web = web
|
||||
mw.set(web)
|
||||
|
||||
var (
|
||||
dataValid = errors.Must(json.Marshal(&profileJSON{
|
||||
|
||||
@ -60,7 +60,11 @@ type tlsManager struct {
|
||||
// confModifier is used to update the global configuration.
|
||||
confModifier agh.ConfigModifier
|
||||
|
||||
// customCipherIDs are the ID of the cipher suites that AdGuard Home must use.
|
||||
// httpReg registers HTTP handlers. It must not be nil.
|
||||
httpReg aghhttp.Registrar
|
||||
|
||||
// customCipherIDs are the IDs of the cipher suites that AdGuard Home must
|
||||
// use.
|
||||
customCipherIDs []uint16
|
||||
|
||||
// servePlainDNS defines if plain DNS is allowed for incoming requests.
|
||||
@ -77,6 +81,8 @@ type tlsManagerConfig struct {
|
||||
// nil.
|
||||
confModifier agh.ConfigModifier
|
||||
|
||||
httpReg aghhttp.Registrar
|
||||
|
||||
// tlsSettings contains the TLS configuration settings.
|
||||
tlsSettings tlsConfigSettings
|
||||
|
||||
@ -94,6 +100,7 @@ func newTLSManager(ctx context.Context, conf *tlsManagerConfig) (m *tlsManager,
|
||||
logger: conf.logger,
|
||||
mu: &sync.Mutex{},
|
||||
confModifier: conf.confModifier,
|
||||
httpReg: conf.httpReg,
|
||||
status: &tlsConfigStatus{},
|
||||
conf: &conf.tlsSettings,
|
||||
servePlainDNS: conf.servePlainDNS,
|
||||
@ -251,7 +258,7 @@ func (m *tlsManager) reconfigureDNSServer(ctx context.Context) (err error) {
|
||||
config.Clients.Sources,
|
||||
m.conf,
|
||||
m,
|
||||
httpRegister,
|
||||
m.httpReg,
|
||||
globalContext.clients.storage,
|
||||
m.confModifier,
|
||||
)
|
||||
@ -426,7 +433,7 @@ func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
|
||||
servePlainDNS = m.servePlainDNS
|
||||
}()
|
||||
|
||||
data := tlsConfig{
|
||||
data := &tlsConfig{
|
||||
tlsConfigSettingsExt: tlsConfigSettingsExt{
|
||||
tlsConfigSettings: *tlsConf,
|
||||
ServePlainDNS: aghalg.BoolToNullBool(servePlainDNS),
|
||||
@ -434,7 +441,7 @@ func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
|
||||
tlsConfigStatus: m.status,
|
||||
}
|
||||
|
||||
marshalTLS(w, r, data)
|
||||
m.marshalTLS(r.Context(), w, r, data)
|
||||
}
|
||||
|
||||
// handleTLSValidate is the handler for the POST /control/tls/validate HTTP API.
|
||||
@ -471,12 +478,12 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
|
||||
// status.WarningValidation.
|
||||
status := &tlsConfigStatus{}
|
||||
_ = m.loadTLSConfig(ctx, &setts.tlsConfigSettings, status)
|
||||
resp := tlsConfig{
|
||||
resp := &tlsConfig{
|
||||
tlsConfigSettingsExt: setts,
|
||||
tlsConfigStatus: status,
|
||||
}
|
||||
|
||||
marshalTLS(w, r, resp)
|
||||
m.marshalTLS(ctx, w, r, resp)
|
||||
}
|
||||
|
||||
// setConfig updates manager TLS configuration with the given one. m.mu is
|
||||
@ -548,12 +555,12 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
||||
status := &tlsConfigStatus{}
|
||||
err = m.loadTLSConfig(ctx, &req.tlsConfigSettings, status)
|
||||
if err != nil {
|
||||
resp := tlsConfig{
|
||||
resp := &tlsConfig{
|
||||
tlsConfigSettingsExt: req,
|
||||
tlsConfigStatus: status,
|
||||
}
|
||||
|
||||
marshalTLS(w, r, resp)
|
||||
m.marshalTLS(ctx, w, r, resp)
|
||||
|
||||
return
|
||||
}
|
||||
@ -579,12 +586,12 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
resp := tlsConfig{
|
||||
resp := &tlsConfig{
|
||||
tlsConfigSettingsExt: req,
|
||||
tlsConfigStatus: m.status,
|
||||
}
|
||||
|
||||
marshalTLS(w, r, resp)
|
||||
m.marshalTLS(ctx, w, r, resp)
|
||||
rc := http.NewResponseController(w)
|
||||
err = rc.Flush()
|
||||
if err != nil {
|
||||
@ -1027,7 +1034,14 @@ func unmarshalTLS(r *http.Request) (tlsConfigSettingsExt, error) {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func marshalTLS(w http.ResponseWriter, r *http.Request, data tlsConfig) {
|
||||
// marshalTLS encodes sensitive fields and writes data as JSON. All arguments
|
||||
// must not be nil.
|
||||
func (m *tlsManager) marshalTLS(
|
||||
ctx context.Context,
|
||||
w http.ResponseWriter,
|
||||
r *http.Request,
|
||||
data *tlsConfig,
|
||||
) {
|
||||
if data.CertificateChain != "" {
|
||||
encoded := base64.StdEncoding.EncodeToString([]byte(data.CertificateChain))
|
||||
data.CertificateChain = encoded
|
||||
@ -1038,12 +1052,12 @@ func marshalTLS(w http.ResponseWriter, r *http.Request, data tlsConfig) {
|
||||
data.PrivateKey = ""
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, data)
|
||||
aghhttp.WriteJSONResponseOK(ctx, m.logger, w, r, *data)
|
||||
}
|
||||
|
||||
// registerWebHandlers registers HTTP handlers for TLS configuration.
|
||||
func (m *tlsManager) registerWebHandlers() {
|
||||
httpRegister(http.MethodGet, "/control/tls/status", m.handleTLSStatus)
|
||||
httpRegister(http.MethodPost, "/control/tls/configure", m.handleTLSConfigure)
|
||||
httpRegister(http.MethodPost, "/control/tls/validate", m.handleTLSValidate)
|
||||
m.httpReg.Register(http.MethodGet, "/control/tls/status", m.handleTLSStatus)
|
||||
m.httpReg.Register(http.MethodPost, "/control/tls/configure", m.handleTLSConfigure)
|
||||
m.httpReg.Register(http.MethodPost, "/control/tls/validate", m.handleTLSValidate)
|
||||
}
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
@ -153,7 +154,6 @@ func storeGlobals(tb testing.TB) {
|
||||
prefGLFilePrefix := glFilePrefix
|
||||
storage := globalContext.clients.storage
|
||||
dnsServer := globalContext.dnsServer
|
||||
mux := globalContext.mux
|
||||
web := globalContext.web
|
||||
|
||||
tb.Cleanup(func() {
|
||||
@ -161,7 +161,6 @@ func storeGlobals(tb testing.TB) {
|
||||
glFilePrefix = prefGLFilePrefix
|
||||
globalContext.clients.storage = storage
|
||||
globalContext.dnsServer = dnsServer
|
||||
globalContext.mux = mux
|
||||
globalContext.web = web
|
||||
})
|
||||
}
|
||||
@ -307,6 +306,7 @@ func initEmptyWeb(tb testing.TB) (web *webAPI) {
|
||||
nil,
|
||||
http.NewServeMux(),
|
||||
agh.EmptyConfigModifier{},
|
||||
aghhttp.EmptyRegistrar{},
|
||||
false,
|
||||
false,
|
||||
)
|
||||
@ -337,8 +337,6 @@ func TestTLSManager_Reload(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
globalContext.mux = http.NewServeMux()
|
||||
|
||||
const (
|
||||
snBefore int64 = 1
|
||||
snAfter int64 = 2
|
||||
@ -419,8 +417,6 @@ func TestTLSManager_HandleTLSStatus(t *testing.T) {
|
||||
func TestValidateTLSSettings(t *testing.T) {
|
||||
storeGlobals(t)
|
||||
|
||||
globalContext.mux = http.NewServeMux()
|
||||
|
||||
var (
|
||||
ctx = testutil.ContextWithTimeout(t, testTimeout)
|
||||
err error
|
||||
@ -516,8 +512,6 @@ func TestValidateTLSSettings(t *testing.T) {
|
||||
func TestTLSManager_HandleTLSValidate(t *testing.T) {
|
||||
storeGlobals(t)
|
||||
|
||||
globalContext.mux = http.NewServeMux()
|
||||
|
||||
var (
|
||||
ctx = testutil.ContextWithTimeout(t, testTimeout)
|
||||
err error
|
||||
@ -598,8 +592,6 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
globalContext.mux = http.NewServeMux()
|
||||
|
||||
config.DNS.BindHosts = []netip.Addr{netip.MustParseAddr("127.0.0.1")}
|
||||
config.DNS.Port = 0
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
@ -55,6 +56,9 @@ type webConfig struct {
|
||||
// confModifier is used to update the global configuration.
|
||||
confModifier agh.ConfigModifier
|
||||
|
||||
// httpReg registers HTTP handlers. It must not be nil.
|
||||
httpReg aghhttp.Registrar
|
||||
|
||||
// tlsManager contains the current configuration and state of TLS
|
||||
// encryption. It must not be nil.
|
||||
tlsManager *tlsManager
|
||||
@ -125,6 +129,9 @@ type webAPI struct {
|
||||
// cmdCons is used to run external commands.
|
||||
cmdCons executil.CommandConstructor
|
||||
|
||||
// httpReg registers HTTP handlers.
|
||||
httpReg aghhttp.Registrar
|
||||
|
||||
// TODO(a.garipov): Refactor all these servers.
|
||||
httpServer *http.Server
|
||||
|
||||
@ -157,6 +164,7 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
|
||||
w = &webAPI{
|
||||
conf: conf,
|
||||
confModifier: conf.confModifier,
|
||||
httpReg: conf.httpReg,
|
||||
cmdCons: conf.CommandConstructor,
|
||||
logger: conf.logger,
|
||||
baseLogger: conf.baseLogger,
|
||||
|
||||
@ -60,11 +60,13 @@ type HTTPAPIDNSSettings struct {
|
||||
// handlePatchSettingsDNS is the handler for the PATCH /api/v1/settings/dns HTTP
|
||||
// API.
|
||||
func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
l := svc.logger
|
||||
req := &ReqPatchSettingsDNS{}
|
||||
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("decoding: %w", err))
|
||||
aghhttp.WriteJSONResponseError(ctx, l, w, r, fmt.Errorf("decoding: %w", err))
|
||||
|
||||
return
|
||||
}
|
||||
@ -93,10 +95,9 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques
|
||||
req.RefuseAny.Set(&newConf.RefuseAny)
|
||||
req.UseDNS64.Set(&newConf.UseDNS64)
|
||||
|
||||
ctx := r.Context()
|
||||
err = svc.confMgr.UpdateDNS(ctx, newConf)
|
||||
if err != nil {
|
||||
aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("updating: %w", err))
|
||||
aghhttp.WriteJSONResponseError(ctx, l, w, r, fmt.Errorf("updating: %w", err))
|
||||
|
||||
return
|
||||
}
|
||||
@ -104,12 +105,12 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques
|
||||
newSvc := svc.confMgr.DNS()
|
||||
err = newSvc.Start(ctx)
|
||||
if err != nil {
|
||||
aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("starting new service: %w", err))
|
||||
aghhttp.WriteJSONResponseError(ctx, l, w, r, fmt.Errorf("starting new service: %w", err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, &HTTPAPIDNSSettings{
|
||||
aghhttp.WriteJSONResponseOK(ctx, l, w, r, &HTTPAPIDNSSettings{
|
||||
UpstreamMode: newConf.UpstreamMode,
|
||||
Addresses: newConf.Addresses,
|
||||
BootstrapServers: newConf.BootstrapServers,
|
||||
|
||||
@ -43,11 +43,13 @@ type HTTPAPIHTTPSettings struct {
|
||||
// handlePatchSettingsHTTP is the handler for the PATCH /api/v1/settings/http
|
||||
// HTTP API.
|
||||
func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
req := &ReqPatchSettingsHTTP{}
|
||||
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("decoding: %w", err))
|
||||
aghhttp.WriteJSONResponseError(ctx, svc.logger, w, r, fmt.Errorf("decoding: %w", err))
|
||||
|
||||
return
|
||||
}
|
||||
@ -61,7 +63,7 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque
|
||||
req.Timeout.Set((*aghhttp.JSONDuration)(&newConf.Timeout))
|
||||
req.ForceHTTPS.Set(&newConf.ForceHTTPS)
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, &HTTPAPIHTTPSettings{
|
||||
aghhttp.WriteJSONResponseOK(ctx, svc.logger, w, r, &HTTPAPIHTTPSettings{
|
||||
Addresses: newConf.Addresses,
|
||||
SecureAddresses: newConf.SecureAddresses,
|
||||
Timeout: aghhttp.JSONDuration(newConf.Timeout),
|
||||
@ -71,7 +73,6 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque
|
||||
cancelUpd := func() {}
|
||||
updCtx := context.Background()
|
||||
|
||||
ctx := r.Context()
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
updCtx, cancelUpd = context.WithDeadline(updCtx, deadline)
|
||||
}
|
||||
|
||||
@ -27,7 +27,7 @@ func (svc *Service) handleGetSettingsAll(w http.ResponseWriter, r *http.Request)
|
||||
httpConf := webSvc.Config()
|
||||
|
||||
// TODO(a.garipov): Add all currently supported parameters.
|
||||
aghhttp.WriteJSONResponseOK(w, r, &RespGetV1SettingsAll{
|
||||
aghhttp.WriteJSONResponseOK(r.Context(), svc.logger, w, r, &RespGetV1SettingsAll{
|
||||
DNS: &HTTPAPIDNSSettings{
|
||||
UpstreamMode: dnsConf.UpstreamMode,
|
||||
Addresses: dnsConf.Addresses,
|
||||
|
||||
@ -24,7 +24,7 @@ type RespGetV1SystemInfo struct {
|
||||
// handleGetV1SystemInfo is the handler for the GET /api/v1/system/info HTTP
|
||||
// API.
|
||||
func (svc *Service) handleGetV1SystemInfo(w http.ResponseWriter, r *http.Request) {
|
||||
aghhttp.WriteJSONResponseOK(w, r, &RespGetV1SystemInfo{
|
||||
aghhttp.WriteJSONResponseOK(r.Context(), svc.logger, w, r, &RespGetV1SystemInfo{
|
||||
Arch: runtime.GOARCH,
|
||||
Channel: version.Channel(),
|
||||
OS: runtime.GOOS,
|
||||
|
||||
@ -59,18 +59,18 @@ type getConfigResp struct {
|
||||
|
||||
// Register web handlers
|
||||
func (l *queryLog) initWeb() {
|
||||
l.conf.HTTPRegister(http.MethodGet, "/control/querylog", l.handleQueryLog)
|
||||
l.conf.HTTPRegister(http.MethodPost, "/control/querylog_clear", l.handleQueryLogClear)
|
||||
l.conf.HTTPRegister(http.MethodGet, "/control/querylog/config", l.handleGetQueryLogConfig)
|
||||
l.conf.HTTPRegister(
|
||||
l.conf.HTTPReg.Register(http.MethodGet, "/control/querylog", l.handleQueryLog)
|
||||
l.conf.HTTPReg.Register(http.MethodPost, "/control/querylog_clear", l.handleQueryLogClear)
|
||||
l.conf.HTTPReg.Register(http.MethodGet, "/control/querylog/config", l.handleGetQueryLogConfig)
|
||||
l.conf.HTTPReg.Register(
|
||||
http.MethodPut,
|
||||
"/control/querylog/config/update",
|
||||
l.handlePutQueryLogConfig,
|
||||
)
|
||||
|
||||
// Deprecated handlers.
|
||||
l.conf.HTTPRegister(http.MethodGet, "/control/querylog_info", l.handleQueryLogInfo)
|
||||
l.conf.HTTPRegister(http.MethodPost, "/control/querylog_config", l.handleQueryLogConfig)
|
||||
l.conf.HTTPReg.Register(http.MethodGet, "/control/querylog_info", l.handleQueryLogInfo)
|
||||
l.conf.HTTPReg.Register(http.MethodPost, "/control/querylog_config", l.handleQueryLogConfig)
|
||||
}
|
||||
|
||||
// handleQueryLog is the handler for the GET /control/querylog HTTP API.
|
||||
@ -94,7 +94,7 @@ func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
resp := l.entriesToJSON(ctx, entries, oldest, l.anonymizer.Load())
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(ctx, l.logger, w, r, resp)
|
||||
}
|
||||
|
||||
// handleQueryLogClear is the handler for the POST /control/querylog/clear HTTP
|
||||
@ -119,7 +119,7 @@ func (l *queryLog) handleQueryLogInfo(w http.ResponseWriter, r *http.Request) {
|
||||
ivl = timeutil.Day * 90
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, configJSON{
|
||||
aghhttp.WriteJSONResponseOK(r.Context(), l.logger, w, r, configJSON{
|
||||
Enabled: aghalg.BoolToNullBool(l.conf.Enabled),
|
||||
Interval: ivl.Hours() / 24,
|
||||
AnonymizeClientIP: aghalg.BoolToNullBool(l.conf.AnonymizeClientIP),
|
||||
@ -142,7 +142,7 @@ func (l *queryLog) handleGetQueryLogConfig(w http.ResponseWriter, r *http.Reques
|
||||
}
|
||||
}()
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(r.Context(), l.logger, w, r, resp)
|
||||
}
|
||||
|
||||
// AnonymizeIP masks ip to anonymize the client if the ip is a valid one.
|
||||
|
||||
@ -87,7 +87,7 @@ var _ QueryLog = (*queryLog)(nil)
|
||||
|
||||
// Start implements the [QueryLog] interface for *queryLog.
|
||||
func (l *queryLog) Start(ctx context.Context) (err error) {
|
||||
if l.conf.HTTPRegister != nil {
|
||||
if l.conf.HTTPReg != nil {
|
||||
l.initWeb()
|
||||
}
|
||||
|
||||
|
||||
@ -53,7 +53,7 @@ type Config struct {
|
||||
ConfigModifier agh.ConfigModifier
|
||||
|
||||
// HTTPRegister registers an HTTP handler.
|
||||
HTTPRegister aghhttp.RegisterFunc
|
||||
HTTPReg aghhttp.Registrar
|
||||
|
||||
// FindClient returns client information by their IDs.
|
||||
FindClient func(ids []string) (c *Client, err error)
|
||||
|
||||
@ -51,6 +51,7 @@ func (s *StatsCtx) handleStats(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
|
||||
ctx := r.Context()
|
||||
l := s.logger
|
||||
|
||||
var (
|
||||
resp *StatsResp
|
||||
@ -63,18 +64,18 @@ func (s *StatsCtx) handleStats(w http.ResponseWriter, r *http.Request) {
|
||||
resp, ok = s.getData(uint32(s.limit.Hours()))
|
||||
}()
|
||||
|
||||
s.logger.DebugContext(ctx, "prepared data", "elapsed", time.Since(start))
|
||||
l.DebugContext(ctx, "prepared data", "elapsed", time.Since(start))
|
||||
|
||||
if !ok {
|
||||
// Don't bring the message to the lower case since it's a part of UI
|
||||
// text for the moment.
|
||||
const msg = "Couldn't get statistics data"
|
||||
aghhttp.ErrorAndLog(ctx, s.logger, r, w, http.StatusInternalServerError, msg)
|
||||
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusInternalServerError, msg)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(ctx, l, w, r, resp)
|
||||
}
|
||||
|
||||
// configResp is the response to the GET /control/stats_info.
|
||||
@ -123,7 +124,7 @@ func (s *StatsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) {
|
||||
resp.IntervalDays = 0
|
||||
}
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(r.Context(), s.logger, w, r, resp)
|
||||
}
|
||||
|
||||
// handleGetStatsConfig is the handler for the GET /control/stats/config HTTP
|
||||
@ -141,7 +142,7 @@ func (s *StatsCtx) handleGetStatsConfig(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
}()
|
||||
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
aghhttp.WriteJSONResponseOK(r.Context(), s.logger, w, r, resp)
|
||||
}
|
||||
|
||||
// handleStatsConfig is the handler for the POST /control/stats_config HTTP API.
|
||||
@ -244,16 +245,12 @@ func (s *StatsCtx) handleStatsReset(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// initWeb registers the handlers for web endpoints of statistics module.
|
||||
func (s *StatsCtx) initWeb() {
|
||||
if s.httpRegister == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.httpRegister(http.MethodGet, "/control/stats", s.handleStats)
|
||||
s.httpRegister(http.MethodPost, "/control/stats_reset", s.handleStatsReset)
|
||||
s.httpRegister(http.MethodGet, "/control/stats/config", s.handleGetStatsConfig)
|
||||
s.httpRegister(http.MethodPut, "/control/stats/config/update", s.handlePutStatsConfig)
|
||||
s.httpReg.Register(http.MethodGet, "/control/stats", s.handleStats)
|
||||
s.httpReg.Register(http.MethodPost, "/control/stats_reset", s.handleStatsReset)
|
||||
s.httpReg.Register(http.MethodGet, "/control/stats/config", s.handleGetStatsConfig)
|
||||
s.httpReg.Register(http.MethodPut, "/control/stats/config/update", s.handlePutStatsConfig)
|
||||
|
||||
// Deprecated handlers.
|
||||
s.httpRegister(http.MethodGet, "/control/stats_info", s.handleStatsInfo)
|
||||
s.httpRegister(http.MethodPost, "/control/stats_config", s.handleStatsConfig)
|
||||
s.httpReg.Register(http.MethodGet, "/control/stats_info", s.handleStatsInfo)
|
||||
s.httpReg.Register(http.MethodPost, "/control/stats_config", s.handleStatsConfig)
|
||||
}
|
||||
|
||||
@ -11,6 +11,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
@ -30,6 +31,7 @@ func TestHandleStatsConfig(t *testing.T) {
|
||||
UnitID: func() (id uint32) { return 0 },
|
||||
ConfigModifier: agh.EmptyConfigModifier{},
|
||||
ShouldCountClient: func([]string) bool { return true },
|
||||
HTTPReg: aghhttp.EmptyRegistrar{},
|
||||
Filename: filepath.Join(t.TempDir(), "stats.db"),
|
||||
Limit: time.Hour * 24,
|
||||
Enabled: true,
|
||||
|
||||
@ -65,7 +65,7 @@ type Config struct {
|
||||
|
||||
// HTTPRegister is the function that registers handlers for the stats
|
||||
// endpoints.
|
||||
HTTPRegister aghhttp.RegisterFunc
|
||||
HTTPReg aghhttp.Registrar
|
||||
|
||||
// Ignored contains the list of host names, which should not be counted,
|
||||
// and matches them.
|
||||
@ -121,8 +121,8 @@ type StatsCtx struct {
|
||||
// unit. It's here for only testing purposes.
|
||||
unitIDGen UnitIDGenFunc
|
||||
|
||||
// httpRegister is used to set HTTP handlers.
|
||||
httpRegister aghhttp.RegisterFunc
|
||||
// httpReg registers HTTP handlers. It must not be nil.
|
||||
httpReg aghhttp.Registrar
|
||||
|
||||
// configModifier is used to update the global configuration.
|
||||
configModifier agh.ConfigModifier
|
||||
@ -164,7 +164,7 @@ func New(conf Config) (s *StatsCtx, err error) {
|
||||
s = &StatsCtx{
|
||||
logger: conf.Logger,
|
||||
currMu: &sync.RWMutex{},
|
||||
httpRegister: conf.HTTPRegister,
|
||||
httpReg: conf.HTTPReg,
|
||||
configModifier: conf.ConfigModifier,
|
||||
filename: conf.Filename,
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
@ -21,6 +22,7 @@ func TestStats_races(t *testing.T) {
|
||||
conf := Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
ShouldCountClient: func([]string) bool { return true },
|
||||
HTTPReg: aghhttp.EmptyRegistrar{},
|
||||
UnitID: idGen,
|
||||
Filename: filepath.Join(t.TempDir(), "./stats.db"),
|
||||
Limit: timeutil.Day,
|
||||
|
||||
@ -11,7 +11,9 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
@ -59,8 +61,10 @@ func TestStats(t *testing.T) {
|
||||
Limit: timeutil.Day,
|
||||
Enabled: true,
|
||||
UnitID: constUnitID,
|
||||
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
|
||||
handlers[url] = handler
|
||||
HTTPReg: &aghtest.Registrar{
|
||||
OnRegister: func(_, url string, handler http.HandlerFunc) {
|
||||
handlers[url] = handler
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@ -180,7 +184,11 @@ func TestLargeNumbers(t *testing.T) {
|
||||
Limit: timeutil.Day,
|
||||
Enabled: true,
|
||||
UnitID: func() (id uint32) { return atomic.LoadUint32(&curHour) },
|
||||
HTTPRegister: func(_, url string, handler http.HandlerFunc) { handlers[url] = handler },
|
||||
HTTPReg: &aghtest.Registrar{
|
||||
OnRegister: func(_, url string, handler http.HandlerFunc) {
|
||||
handlers[url] = handler
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s, err := stats.New(conf)
|
||||
@ -234,6 +242,7 @@ func TestShouldCount(t *testing.T) {
|
||||
ShouldCountClient: func(ids []string) (a bool) {
|
||||
return ids[0] != "no_count"
|
||||
},
|
||||
HTTPReg: aghhttp.EmptyRegistrar{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user