Pull request 2496: AGDNS-3224-aghhttp-register-slog

Squashed commit of the following:

commit 9324a0066202f1677bfd033d40d3a82fa9756ed9
Merge: 8a1b5cad4 f9da40e39
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Oct 23 17:48:01 2025 +0300

    Merge branch 'master' into AGDNS-3224-aghhttp-register-slog

commit 8a1b5cad4c
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Oct 21 15:51:48 2025 +0300

    filtering: imp code

commit fe569166ef
Merge: 9a101a2f5 9be4ca90e
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Oct 21 15:45:42 2025 +0300

    Merge branch 'master' into AGDNS-3224-aghhttp-register-slog

commit 9a101a2f5f
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Oct 15 18:52:22 2025 +0300

    home: imp code

commit 727e1663ba
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Oct 15 10:19:56 2025 +0300

    all: imp code

commit 113a9017df
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Oct 13 23:10:06 2025 +0300

    home: fix typo

commit 6588dd2dad
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Mon Oct 13 22:46:28 2025 +0300

    all: imp naming

commit 44278505a9
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Fri Oct 10 16:20:17 2025 +0300

    home: fix typo

commit 7b4b57628b
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Fri Oct 10 15:58:07 2025 +0300

    all: web mw

commit 93168142cb
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Oct 8 22:20:07 2025 +0300

    all: aghhttp slog

commit 9155edef67
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Oct 8 15:38:01 2025 +0300

    aghhttp: registrar

commit a356473855
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Oct 7 15:32:30 2025 +0300

    all: http registrar
This commit is contained in:
Stanislav Chzhen 2025-10-23 18:05:39 +03:00
parent f9da40e393
commit 5c9fef62f1
46 changed files with 531 additions and 337 deletions

View File

@ -10,7 +10,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
@ -27,15 +26,6 @@ func OK(ctx context.Context, l *slog.Logger, w http.ResponseWriter) {
}
}
// Error writes formatted message to w and also logs it.
//
// TODO(s.chzhen): Remove it.
func Error(r *http.Request, w http.ResponseWriter, code int, format string, args ...any) {
text := fmt.Sprintf(format, args...)
log.Error("%s %s %s: %s", r.Method, r.Host, r.URL, text)
http.Error(w, text, code)
}
// ErrorAndLog writes a formatted HTTP error response and logs it at
// [slog.LevelError] level. l, r, and w must not be nil.
func ErrorAndLog(

View File

@ -1,14 +1,16 @@
package aghhttp
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"strconv"
"time"
"github.com/AdguardTeam/golibs/httphdr"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// JSON Utilities
@ -88,8 +90,15 @@ func (t *JSONTime) UnmarshalJSON(b []byte) (err error) {
// WriteJSONResponse writes headers with the code, encodes resp into w, and logs
// any errors it encounters. r is used to get additional information from the
// request.
func WriteJSONResponse(w http.ResponseWriter, r *http.Request, code int, resp any) {
// request. l, w, and r must not be nil.
func WriteJSONResponse(
ctx context.Context,
l *slog.Logger,
w http.ResponseWriter,
r *http.Request,
code int,
resp any,
) {
h := w.Header()
h.Set(httphdr.ContentType, HdrValApplicationJSON)
h.Set(httphdr.Server, UserAgent())
@ -98,15 +107,27 @@ func WriteJSONResponse(w http.ResponseWriter, r *http.Request, code int, resp an
err := json.NewEncoder(w).Encode(resp)
if err != nil {
log.Error("aghhttp: writing json resp to %s %s: %s", r.Method, r.URL.Path, err)
l.ErrorContext(
ctx,
"writing json response",
"method", r.Method,
"path", r.URL.Path,
slogutil.KeyError, err,
)
}
}
// WriteJSONResponseOK writes headers with the code 200 OK, encodes v into w,
// and logs any errors it encounters. r is used to get additional information
// from the request.
func WriteJSONResponseOK(w http.ResponseWriter, r *http.Request, v any) {
WriteJSONResponse(w, r, http.StatusOK, v)
// from the request. l, w, and r must not be nil.
func WriteJSONResponseOK(
ctx context.Context,
l *slog.Logger,
w http.ResponseWriter,
r *http.Request,
v any,
) {
WriteJSONResponse(ctx, l, w, r, http.StatusOK, v)
}
// ErrorCode is the error code as used by the HTTP API. See the ErrorCode
@ -131,11 +152,23 @@ type HTTPAPIErrorResp struct {
// WriteJSONResponseError encodes err as a JSON error into w, and logs any
// errors it encounters. r is used to get additional information from the
// request.
func WriteJSONResponseError(w http.ResponseWriter, r *http.Request, err error) {
log.Error("aghhttp: writing json error to %s %s: %s", r.Method, r.URL.Path, err)
// request. l, w, and r must not be nil.
func WriteJSONResponseError(
ctx context.Context,
l *slog.Logger,
w http.ResponseWriter,
r *http.Request,
err error,
) {
l.ErrorContext(
ctx,
"writing json error",
"method", r.Method,
"path", r.URL.Path,
slogutil.KeyError, err,
)
WriteJSONResponse(w, r, http.StatusUnprocessableEntity, &HTTPAPIErrorResp{
WriteJSONResponse(ctx, l, w, r, http.StatusUnprocessableEntity, &HTTPAPIErrorResp{
Code: ErrorCodeTMP000,
Msg: err.Error(),
})

View File

@ -0,0 +1,49 @@
package aghhttp
import (
"net/http"
)
// Registrar registers an HTTP handler for a method and path.
//
// TODO(s.chzhen): Implement [httputil.Router].
type Registrar interface {
Register(method, path string, h http.HandlerFunc)
}
// EmptyRegistrar is an implementation of [Registrar] that does nothing.
type EmptyRegistrar struct{}
// type check
var _ Registrar = EmptyRegistrar{}
// Register implements the [Registrar] interface.
func (EmptyRegistrar) Register(_, _ string, _ http.HandlerFunc) {}
// WrapFunc is a wrapper function that builds an HTTP handler for a route.
type WrapFunc func(method string, h http.HandlerFunc) (wrapped http.Handler)
// DefaultRegistrar is an implementation of [Registrar] that registers handlers
// after applying a user-provided wrapper function.
type DefaultRegistrar struct {
mux *http.ServeMux
wrapFn WrapFunc
}
// NewDefaultRegistrar returns a new properly initialized *DefaultRegistrar.
// mux and wrap must not be nil.
func NewDefaultRegistrar(mux *http.ServeMux, wrap WrapFunc) (r *DefaultRegistrar) {
return &DefaultRegistrar{
mux: mux,
wrapFn: wrap,
}
}
// type check
var _ Registrar = (*DefaultRegistrar)(nil)
// Register implements the [Registrar] interface.
func (r *DefaultRegistrar) Register(method, path string, h http.HandlerFunc) {
wrapped := r.wrapFn(method, h)
r.mux.Handle(path, wrapped)
}

View File

@ -2,10 +2,12 @@ package aghtest
import (
"context"
"net/http"
"net/netip"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agh"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
nextagh "github.com/AdguardTeam/AdGuardHome/internal/next/agh"
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
@ -158,7 +160,20 @@ type ConfigModifier struct {
// type check
var _ agh.ConfigModifier = (*ConfigModifier)(nil)
// Apply implements the [ConfigModifier] interface for *ConfigModifier.
// Apply implements the [agh.ConfigModifier] interface for *ConfigModifier.
func (m *ConfigModifier) Apply(ctx context.Context) {
m.OnApply(ctx)
}
// Registrar is a fake [aghhttp.Registrar] implementation for tests.
type Registrar struct {
OnRegister func(method, path string, h http.HandlerFunc)
}
// type check
var _ aghhttp.Registrar = (*Registrar)(nil)
// Register implements the [aghhttp.Registrar] interface for *Registrar.
func (m *Registrar) Register(method, path string, h http.HandlerFunc) {
m.OnRegister(method, path, h)
}

View File

@ -31,7 +31,7 @@ type ServerConfig struct {
ConfModifier agh.ConfigModifier `yaml:"-"`
// Register an HTTP handler
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
HTTPReg aghhttp.Registrar `yaml:"-"`
Enabled bool `yaml:"enabled"`
InterfaceName string `yaml:"interface_name"`

View File

@ -112,7 +112,7 @@ func Create(ctx context.Context, conf *ServerConfig) (s *server, err error) {
CommandConstructor: conf.CommandConstructor,
ConfModifier: conf.ConfModifier,
HTTPRegister: conf.HTTPRegister,
HTTPReg: conf.HTTPReg,
Enabled: conf.Enabled,
InterfaceName: conf.InterfaceName,

View File

@ -171,7 +171,7 @@ func (s *server) handleDHCPStatus(w http.ResponseWriter, r *http.Request) {
status.Leases = leasesToDynamic(leases[dynamicIdx:])
status.StaticLeases = leasesToStatic(leases[:dynamicIdx])
aghhttp.WriteJSONResponseOK(w, r, status)
aghhttp.WriteJSONResponseOK(r.Context(), s.conf.Logger, w, r, status)
}
func (s *server) enableDHCP(ctx context.Context, ifaceName string) (code int, err error) {
@ -451,7 +451,7 @@ func (s *server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
}
}
aghhttp.WriteJSONResponseOK(w, r, resp)
aghhttp.WriteJSONResponseOK(ctx, l, w, r, resp)
}
// newNetInterfaceJSON creates a JSON object from a [net.Interface] iface.
@ -613,7 +613,7 @@ func (s *server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
s.setOtherDHCPResult(ctx, ifaceName, result)
aghhttp.WriteJSONResponseOK(w, r, result)
aghhttp.WriteJSONResponseOK(ctx, l, w, r, result)
}
// setOtherDHCPResult sets the results of the check for another DHCP server in
@ -741,7 +741,7 @@ func (s *server) handleReset(w http.ResponseWriter, r *http.Request) {
s.conf = &ServerConfig{
ConfModifier: s.conf.ConfModifier,
HTTPRegister: s.conf.HTTPRegister,
HTTPReg: s.conf.HTTPReg,
LocalDomainName: s.conf.LocalDomainName,
@ -778,17 +778,17 @@ func (s *server) handleResetLeases(w http.ResponseWriter, r *http.Request) {
}
func (s *server) registerHandlers() {
if s.conf.HTTPRegister == nil {
if s.conf.HTTPReg == nil {
return
}
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/status", s.handleDHCPStatus)
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/interfaces", s.handleDHCPInterfaces)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/set_config", s.handleDHCPSetConfig)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/find_active_dhcp", s.handleDHCPFindActiveServer)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/add_static_lease", s.handleDHCPAddStaticLease)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/remove_static_lease", s.handleDHCPRemoveStaticLease)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/update_static_lease", s.handleDHCPUpdateStaticLease)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", s.handleReset)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", s.handleResetLeases)
s.conf.HTTPReg.Register(http.MethodGet, "/control/dhcp/status", s.handleDHCPStatus)
s.conf.HTTPReg.Register(http.MethodGet, "/control/dhcp/interfaces", s.handleDHCPInterfaces)
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/set_config", s.handleDHCPSetConfig)
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/find_active_dhcp", s.handleDHCPFindActiveServer)
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/add_static_lease", s.handleDHCPAddStaticLease)
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/remove_static_lease", s.handleDHCPRemoveStaticLease)
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/update_static_lease", s.handleDHCPUpdateStaticLease)
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/reset", s.handleReset)
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/reset_leases", s.handleResetLeases)
}

View File

@ -24,7 +24,9 @@ type jsonError struct {
// TODO(a.garipov): Either take the logger from the server after we've
// refactored logging or make this not a method of *Server.
func (s *server) notImplemented(w http.ResponseWriter, r *http.Request) {
aghhttp.WriteJSONResponse(w, r, http.StatusNotImplemented, &jsonError{
ctx := r.Context()
aghhttp.WriteJSONResponse(ctx, s.conf.Logger, w, r, http.StatusNotImplemented, &jsonError{
Message: aghos.Unsupported("dhcp").Error(),
})
}
@ -37,13 +39,13 @@ func (s *server) notImplemented(w http.ResponseWriter, r *http.Request) {
// interconnected parts--such as HTTP handlers and frontend--to make that work
// properly.
func (s *server) registerHandlers() {
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/status", s.notImplemented)
s.conf.HTTPRegister(http.MethodGet, "/control/dhcp/interfaces", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/set_config", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/find_active_dhcp", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/add_static_lease", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/remove_static_lease", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/update_static_lease", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset", s.notImplemented)
s.conf.HTTPRegister(http.MethodPost, "/control/dhcp/reset_leases", s.notImplemented)
s.conf.HTTPReg.Register(http.MethodGet, "/control/dhcp/status", s.notImplemented)
s.conf.HTTPReg.Register(http.MethodGet, "/control/dhcp/interfaces", s.notImplemented)
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/set_config", s.notImplemented)
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/find_active_dhcp", s.notImplemented)
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/add_static_lease", s.notImplemented)
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/remove_static_lease", s.notImplemented)
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/update_static_lease", s.notImplemented)
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/reset", s.notImplemented)
s.conf.HTTPReg.Register(http.MethodPost, "/control/dhcp/reset_leases", s.notImplemented)
}

View File

@ -186,7 +186,7 @@ func (s *Server) accessListJSON() (j accessListJSON) {
// handleAccessList handles requests to the GET /control/access/list endpoint.
func (s *Server) handleAccessList(w http.ResponseWriter, r *http.Request) {
aghhttp.WriteJSONResponseOK(w, r, s.accessListJSON())
aghhttp.WriteJSONResponseOK(r.Context(), s.logger, w, r, s.accessListJSON())
}
// validateAccessSet checks the internal accessListJSON lists. To search for

View File

@ -270,7 +270,7 @@ type ServerConfig struct {
ConfModifier agh.ConfigModifier
// Register an HTTP handler
HTTPRegister aghhttp.RegisterFunc
HTTPReg aghhttp.Registrar
// LocalPTRResolvers is a slice of addresses to be used as upstreams for
// resolving PTR queries for local addresses.

View File

@ -245,8 +245,10 @@ func (s *Server) defaultLocalPTRUpstreams() (ups []string, err error) {
// handleGetConfig handles requests to the GET /control/dns_info endpoint.
func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
resp := s.getDNSConfig(r.Context())
aghhttp.WriteJSONResponseOK(w, r, resp)
ctx := r.Context()
resp := s.getDNSConfig(ctx)
aghhttp.WriteJSONResponseOK(ctx, s.logger, w, r, resp)
}
// checkBlockingMode returns an error if blocking mode is invalid.
@ -739,7 +741,7 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
cv.check()
cv.close()
aghhttp.WriteJSONResponseOK(w, r, cv.status())
aghhttp.WriteJSONResponseOK(ctx, l, w, r, cv.status())
}
// handleCacheClear is the handler for the POST /control/cache_clear HTTP API.
@ -797,7 +799,7 @@ func (s *Server) handleSetProtection(w http.ResponseWriter, r *http.Request) {
s.conf.ConfModifier.Apply(ctx)
aghhttp.OK(ctx, s.logger, w)
aghhttp.OK(ctx, l, w)
}
// handleDoH is the DNS-over-HTTPs handler.
@ -837,19 +839,19 @@ func (s *Server) handleDoH(w http.ResponseWriter, r *http.Request) {
}
func (s *Server) registerHandlers() {
if webRegistered || s.conf.HTTPRegister == nil {
if webRegistered || s.conf.HTTPReg == nil {
return
}
s.conf.HTTPRegister(http.MethodGet, "/control/dns_info", s.handleGetConfig)
s.conf.HTTPRegister(http.MethodPost, "/control/dns_config", s.handleSetConfig)
s.conf.HTTPRegister(http.MethodPost, "/control/test_upstream_dns", s.handleTestUpstreamDNS)
s.conf.HTTPRegister(http.MethodPost, "/control/protection", s.handleSetProtection)
s.conf.HTTPReg.Register(http.MethodGet, "/control/dns_info", s.handleGetConfig)
s.conf.HTTPReg.Register(http.MethodPost, "/control/dns_config", s.handleSetConfig)
s.conf.HTTPReg.Register(http.MethodPost, "/control/test_upstream_dns", s.handleTestUpstreamDNS)
s.conf.HTTPReg.Register(http.MethodPost, "/control/protection", s.handleSetProtection)
s.conf.HTTPRegister(http.MethodGet, "/control/access/list", s.handleAccessList)
s.conf.HTTPRegister(http.MethodPost, "/control/access/set", s.handleAccessSet)
s.conf.HTTPReg.Register(http.MethodGet, "/control/access/list", s.handleAccessList)
s.conf.HTTPReg.Register(http.MethodPost, "/control/access/set", s.handleAccessSet)
s.conf.HTTPRegister(http.MethodPost, "/control/cache_clear", s.handleCacheClear)
s.conf.HTTPReg.Register(http.MethodPost, "/control/cache_clear", s.handleCacheClear)
// Register both versions, with and without the trailing slash, to
// prevent a 301 Moved Permanently redirect when clients request the
@ -858,8 +860,8 @@ func (s *Server) registerHandlers() {
// See go doc net/http.ServeMux.
//
// See also https://github.com/AdguardTeam/AdGuardHome/issues/2628.
s.conf.HTTPRegister("", "/dns-query", s.handleDoH)
s.conf.HTTPRegister("", "/dns-query/", s.handleDoH)
s.conf.HTTPReg.Register("", "/dns-query", s.handleDoH)
s.conf.HTTPReg.Register("", "/dns-query/", s.handleDoH)
webRegistered = true
}

View File

@ -127,11 +127,11 @@ func (d *DNSFilter) ApplyBlockedServicesList(setts *Settings, list []string) {
}
func (d *DNSFilter) handleBlockedServicesIDs(w http.ResponseWriter, r *http.Request) {
aghhttp.WriteJSONResponseOK(w, r, serviceIDs)
aghhttp.WriteJSONResponseOK(r.Context(), d.logger, w, r, serviceIDs)
}
func (d *DNSFilter) handleBlockedServicesAll(w http.ResponseWriter, r *http.Request) {
aghhttp.WriteJSONResponseOK(w, r, struct {
aghhttp.WriteJSONResponseOK(r.Context(), d.logger, w, r, struct {
BlockedServices []blockedService `json:"blocked_services"`
ServiceGroups []serviceGroup `json:"groups"`
}{
@ -153,7 +153,7 @@ func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Req
list = d.conf.BlockedServices.IDs
}()
aghhttp.WriteJSONResponseOK(w, r, list)
aghhttp.WriteJSONResponseOK(r.Context(), d.logger, w, r, list)
}
// handleBlockedServicesSet is the handler for the POST
@ -193,7 +193,7 @@ func (d *DNSFilter) handleBlockedServicesGet(w http.ResponseWriter, r *http.Requ
bsvc = d.conf.BlockedServices.Clone()
}()
aghhttp.WriteJSONResponseOK(w, r, bsvc)
aghhttp.WriteJSONResponseOK(r.Context(), d.logger, w, r, bsvc)
}
// handleBlockedServicesUpdate is the handler for the PUT

View File

@ -109,8 +109,8 @@ type Config struct {
// nil.
ConfModifier agh.ConfigModifier `yaml:"-"`
// Register an HTTP handler
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
// HTTPReg registers HTTP handlers. It must not be nil.
HTTPReg aghhttp.Registrar `yaml:"-"`
// HTTPClient is the client to use for updating the remote filters.
HTTPClient *http.Client `yaml:"-"`

View File

@ -370,11 +370,12 @@ func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
var err error
ctx := r.Context()
l := d.logger
req := Req{}
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
aghhttp.ErrorAndLog(ctx, d.logger, r, w, http.StatusBadRequest, "json decode: %s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "json decode: %s", err)
return
}
@ -387,7 +388,7 @@ func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
if !ok {
aghhttp.ErrorAndLog(
ctx,
d.logger,
l,
r,
w,
http.StatusInternalServerError,
@ -397,7 +398,7 @@ func (d *DNSFilter) handleFilteringRefresh(w http.ResponseWriter, r *http.Reques
return
}
aghhttp.WriteJSONResponseOK(w, r, resp)
aghhttp.WriteJSONResponseOK(ctx, l, w, r, resp)
}
type filterJSON struct {
@ -454,7 +455,7 @@ func (d *DNSFilter) handleFilteringStatus(w http.ResponseWriter, r *http.Request
resp.UserRules = d.conf.UserRules
d.conf.filtersMu.RUnlock()
aghhttp.WriteJSONResponseOK(w, r, resp)
aghhttp.WriteJSONResponseOK(r.Context(), d.logger, w, r, resp)
}
// Set filtering configuration
@ -521,13 +522,14 @@ type checkHostResp struct {
// API.
func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := d.logger
query := r.URL.Query()
host := query.Get("name")
if host == "" {
aghhttp.ErrorAndLog(
ctx,
d.logger,
l,
r,
w,
http.StatusBadRequest,
@ -543,7 +545,7 @@ func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
if err != nil {
aghhttp.ErrorAndLog(
ctx,
d.logger,
l,
r,
w,
http.StatusUnprocessableEntity,
@ -572,7 +574,7 @@ func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
if err != nil {
aghhttp.ErrorAndLog(
ctx,
d.logger,
l,
r,
w,
http.StatusInternalServerError,
@ -605,7 +607,7 @@ func (d *DNSFilter) handleCheckHost(w http.ResponseWriter, r *http.Request) {
}
}
aghhttp.WriteJSONResponseOK(w, r, resp)
aghhttp.WriteJSONResponseOK(ctx, l, w, r, resp)
}
// stringToDNSType is a helper function that converts a string to DNS type. If
@ -680,7 +682,7 @@ func (d *DNSFilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Requ
Enabled: protectedBool(d.confMu, &d.conf.SafeBrowsingEnabled),
}
aghhttp.WriteJSONResponseOK(w, r, resp)
aghhttp.WriteJSONResponseOK(r.Context(), d.logger, w, r, resp)
}
// handleParentalEnable is the handler for the POST /control/parental/enable
@ -706,15 +708,12 @@ func (d *DNSFilter) handleParentalStatus(w http.ResponseWriter, r *http.Request)
Enabled: protectedBool(d.confMu, &d.conf.ParentalEnabled),
}
aghhttp.WriteJSONResponseOK(w, r, resp)
aghhttp.WriteJSONResponseOK(r.Context(), d.logger, w, r, resp)
}
// RegisterFilteringHandlers - register handlers
func (d *DNSFilter) RegisterFilteringHandlers() {
registerHTTP := d.conf.HTTPRegister
if registerHTTP == nil {
return
}
registerHTTP := d.conf.HTTPReg.Register
registerHTTP(http.MethodPost, "/control/safebrowsing/enable", d.handleSafeBrowsingEnable)
registerHTTP(http.MethodPost, "/control/safebrowsing/disable", d.handleSafeBrowsingDisable)

View File

@ -12,6 +12,7 @@ import (
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/golibs/testutil"
@ -116,6 +117,7 @@ func TestDNSFilter_handleFilteringSetURL(t *testing.T) {
Timeout: 5 * time.Second,
},
ConfModifier: confModifier,
HTTPReg: aghhttp.EmptyRegistrar{},
DataDir: filtersDir,
}, nil)
require.NoError(t, err)
@ -197,8 +199,10 @@ func TestDNSFilter_handleSafeBrowsingStatus(t *testing.T) {
Logger: testLogger,
ConfModifier: confModifier,
DataDir: filtersDir,
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
handlers[url] = handler
HTTPReg: &aghtest.Registrar{
OnRegister: func(_, url string, handler http.HandlerFunc) {
handlers[url] = handler
},
},
SafeBrowsingEnabled: tc.enabled,
}, nil)
@ -284,8 +288,10 @@ func TestDNSFilter_handleParentalStatus(t *testing.T) {
Logger: testLogger,
ConfModifier: confModifier,
DataDir: filtersDir,
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
handlers[url] = handler
HTTPReg: &aghtest.Registrar{
OnRegister: func(_, url string, handler http.HandlerFunc) {
handlers[url] = handler
},
},
ParentalEnabled: tc.enabled,
}, nil)

View File

@ -45,7 +45,7 @@ func (d *DNSFilter) handleRewriteList(w http.ResponseWriter, r *http.Request) {
}
}()
aghhttp.WriteJSONResponseOK(w, r, arr)
aghhttp.WriteJSONResponseOK(r.Context(), d.logger, w, r, arr)
}
// handleRewriteAdd is the handler for the POST /control/rewrite/add HTTP API.
@ -229,20 +229,22 @@ func (d *DNSFilter) handleRewriteSettings(w http.ResponseWriter, r *http.Request
Enabled: protectedBool(d.confMu, &d.conf.RewritesEnabled),
}
aghhttp.WriteJSONResponseOK(w, r, resp)
aghhttp.WriteJSONResponseOK(r.Context(), d.logger, w, r, resp)
}
// handleRewriteSettingsUpdate is the handler for the PUT
// /control/rewrite/settings/update HTTP API.
func (d *DNSFilter) handleRewriteSettingsUpdate(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
req := &rewriteSettings{}
err := json.NewDecoder(r.Body).Decode(req)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json.Decode: %s", err)
aghhttp.ErrorAndLog(ctx, d.logger, r, w, http.StatusBadRequest, "json.Decode: %s", err)
return
}
setProtectedBool(d.confMu, &d.conf.RewritesEnabled, req.Enabled)
d.conf.ConfModifier.Apply(r.Context())
d.conf.ConfModifier.Apply(ctx)
}

View File

@ -261,8 +261,10 @@ func TestDNSFilter_HandleRewriteHTTP(t *testing.T) {
d, err := filtering.New(&filtering.Config{
Logger: testLogger,
ConfModifier: confModifier,
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
handlers[url] = handler
HTTPReg: &aghtest.Registrar{
OnRegister: func(_, url string, handler http.HandlerFunc) {
handlers[url] = handler
},
},
Rewrites: rewriteEntriesToLegacyRewrites(testRewrites),
}, nil)
@ -362,8 +364,10 @@ func TestDNSFilter_HandleRewriteSettings(t *testing.T) {
d, err := filtering.New(&filtering.Config{
Logger: testLogger,
ConfModifier: confModifier,
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
handlers[url] = handler
HTTPReg: &aghtest.Registrar{
OnRegister: func(_, url string, handler http.HandlerFunc) {
handlers[url] = handler
},
},
RewritesEnabled: false,
}, nil)

View File

@ -36,7 +36,7 @@ func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Reques
resp = d.conf.SafeSearchConf
}()
aghhttp.WriteJSONResponseOK(w, r, resp)
aghhttp.WriteJSONResponseOK(r.Context(), d.logger, w, r, resp)
}
// handleSafeSearchSettings is the handler for PUT /control/safesearch/settings

View File

@ -267,10 +267,10 @@ func (web *webAPI) handleLogout(w http.ResponseWriter, r *http.Request) {
// registerAuthHandlers registers authentication handlers.
func (web *webAPI) registerAuthHandlers() {
web.conf.mux.Handle(
"/control/login",
web.postInstallHandler(ensure(http.MethodPost, web.handleLogin)),
http.MethodPost+" "+"/control/login",
web.postInstallHandler(http.HandlerFunc(web.handleLogin)),
)
httpRegister(http.MethodGet, "/control/logout", web.handleLogout)
web.httpReg.Register(http.MethodGet, "/control/logout", web.handleLogout)
}
// isPublicResource returns true if p is a path to a public resource.

View File

@ -321,8 +321,9 @@ func authRequest(path string, c *http.Cookie, user, pass string) (r *http.Reques
func TestAuth_ServeHTTP_firstRun(t *testing.T) {
storeGlobals(t)
mw := &webMw{}
mux := http.NewServeMux()
globalContext.mux = mux
httpReg := aghhttp.NewDefaultRegistrar(mux, mw.wrap)
ctx := testutil.ContextWithTimeout(t, testTimeout)
web, err := initWeb(
@ -335,12 +336,14 @@ func TestAuth_ServeHTTP_firstRun(t *testing.T) {
nil,
mux,
agh.EmptyConfigModifier{},
httpReg,
false,
true,
)
require.NoError(t, err)
globalContext.web = web
mw.set(web)
testCases := []struct {
name string
@ -484,8 +487,9 @@ func TestAuth_ServeHTTP_auth(t *testing.T) {
t.Cleanup(func() { auth.close(testutil.ContextWithTimeout(t, testTimeout)) })
mw := &webMw{}
baseMux := http.NewServeMux()
globalContext.mux = baseMux
httpReg := aghhttp.NewDefaultRegistrar(baseMux, mw.wrap)
tlsMgr, err := newTLSManager(testutil.ContextWithTimeout(t, testTimeout), &tlsManagerConfig{
logger: testLogger,
@ -504,12 +508,14 @@ func TestAuth_ServeHTTP_auth(t *testing.T) {
auth,
baseMux,
agh.EmptyConfigModifier{},
httpReg,
false,
false,
)
require.NoError(t, err)
globalContext.web = web
mw.set(web)
mux := auth.middleware().Wrap(baseMux)
@ -645,8 +651,9 @@ func TestAuth_ServeHTTP_logout(t *testing.T) {
t.Cleanup(func() { auth.close(testutil.ContextWithTimeout(t, testTimeout)) })
mw := &webMw{}
baseMux := http.NewServeMux()
globalContext.mux = baseMux
httpReg := aghhttp.NewDefaultRegistrar(baseMux, mw.wrap)
ctx := testutil.ContextWithTimeout(t, testTimeout)
web, err := initWeb(
@ -659,12 +666,14 @@ func TestAuth_ServeHTTP_logout(t *testing.T) {
auth,
baseMux,
agh.EmptyConfigModifier{},
httpReg,
false,
false,
)
require.NoError(t, err)
globalContext.web = web
mw.set(web)
mux := auth.middleware().Wrap(baseMux)

View File

@ -10,6 +10,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agh"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
"github.com/AdguardTeam/AdGuardHome/internal/client"
@ -44,6 +45,9 @@ type clientsContainer struct {
// nil.
confModifier agh.ConfigModifier
// httpReg registers HTTP handlers. It must not be nil.
httpReg aghhttp.Registrar
// lock protects all fields.
//
// TODO(a.garipov): Use a pointer and describe which fields are protected in
@ -66,9 +70,12 @@ type BlockedClientChecker interface {
IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string)
}
// Init initializes clients container
// dhcpServer: optional
// Note: this function must be called only once
// Init initializes the clients container. All arguments must not be nil except
// for objects.
//
// NOTE: This function must be called only once.
//
// TODO(s.chzhen): Use a configuration structure.
func (clients *clientsContainer) Init(
ctx context.Context,
baseLogger *slog.Logger,
@ -79,6 +86,7 @@ func (clients *clientsContainer) Init(
filteringConf *filtering.Config,
sigHdlr *signalHandler,
confModifier agh.ConfigModifier,
httpReg aghhttp.Registrar,
) (err error) {
// TODO(s.chzhen): Refactor it.
if clients.storage != nil {
@ -90,6 +98,7 @@ func (clients *clientsContainer) Init(
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
clients.confModifier = confModifier
clients.httpReg = httpReg
confClients := make([]*client.Persistent, 0, len(objects))
for i, o := range objects {

View File

@ -4,6 +4,7 @@ import (
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/agh"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/testutil"
@ -30,6 +31,7 @@ func newClientsContainer(tb testing.TB) (c *clientsContainer) {
},
newSignalHandler(testLogger, nil, nil),
agh.EmptyConfigModifier{},
aghhttp.EmptyRegistrar{},
)
require.NoError(tb, err)

View File

@ -94,6 +94,7 @@ func whoisOrEmpty(r *client.Runtime) (info *whois.Info) {
// handleGetClients is the handler for GET /control/clients HTTP API.
func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
data := clientListJSON{}
clients.lock.Lock()
@ -106,7 +107,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
return true
})
clients.storage.UpdateDHCP(r.Context())
clients.storage.UpdateDHCP(ctx)
clients.storage.RangeRuntime(func(rc *client.Runtime) (cont bool) {
src, host := rc.Info()
@ -124,7 +125,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http
data.Tags = clients.storage.AllowedTags()
aghhttp.WriteJSONResponseOK(w, r, data)
aghhttp.WriteJSONResponseOK(ctx, clients.logger, w, r, data)
}
// initPrev initializes the persistent client with the default or previous
@ -454,6 +455,9 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
//
// Deprecated: Remove it when migration to the new API is over.
func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := clients.logger
q := r.URL.Query()
data := make([]map[string]*clientJSON, 0, len(q))
params := &client.FindParams{}
@ -467,12 +471,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
err = params.Set(idStr)
if err != nil {
clients.logger.DebugContext(
r.Context(),
"finding client",
"id", idStr,
slogutil.KeyError, err,
)
l.DebugContext(ctx, "finding client", "id", idStr, slogutil.KeyError, err)
continue
}
@ -482,7 +481,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http
})
}
aghhttp.WriteJSONResponseOK(w, r, data)
aghhttp.WriteJSONResponseOK(ctx, l, w, r, data)
}
// findClient returns available information about a client by params from the
@ -530,13 +529,14 @@ type searchClientJSON struct {
// API.
func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := clients.logger
q := searchQueryJSON{}
err := json.NewDecoder(r.Body).Decode(&q)
if err != nil {
aghhttp.ErrorAndLog(
ctx,
clients.logger,
l,
r,
w,
http.StatusBadRequest,
@ -554,12 +554,7 @@ func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *ht
idStr := c.ID
err = params.Set(idStr)
if err != nil {
clients.logger.DebugContext(
ctx,
"searching client",
"id", idStr,
slogutil.KeyError, err,
)
l.DebugContext(ctx, "searching client", "id", idStr, slogutil.KeyError, err)
continue
}
@ -569,7 +564,7 @@ func (clients *clientsContainer) handleSearchClient(w http.ResponseWriter, r *ht
})
}
aghhttp.WriteJSONResponseOK(w, r, data)
aghhttp.WriteJSONResponseOK(ctx, l, w, r, data)
}
// findRuntime looks up the IP in runtime and temporary storages, like
@ -613,14 +608,14 @@ func (clients *clientsContainer) findRuntime(
}
}
// RegisterClientsHandlers registers HTTP handlers
// registerWebHandlers registers HTTP handlers.
func (clients *clientsContainer) registerWebHandlers() {
httpRegister(http.MethodGet, "/control/clients", clients.handleGetClients)
httpRegister(http.MethodPost, "/control/clients/add", clients.handleAddClient)
httpRegister(http.MethodPost, "/control/clients/delete", clients.handleDelClient)
httpRegister(http.MethodPost, "/control/clients/update", clients.handleUpdateClient)
httpRegister(http.MethodPost, "/control/clients/search", clients.handleSearchClient)
clients.httpReg.Register(http.MethodGet, "/control/clients", clients.handleGetClients)
clients.httpReg.Register(http.MethodPost, "/control/clients/add", clients.handleAddClient)
clients.httpReg.Register(http.MethodPost, "/control/clients/delete", clients.handleDelClient)
clients.httpReg.Register(http.MethodPost, "/control/clients/update", clients.handleUpdateClient)
clients.httpReg.Register(http.MethodPost, "/control/clients/search", clients.handleSearchClient)
// Deprecated handler.
httpRegister(http.MethodGet, "/control/clients/find", clients.handleFindClient)
clients.httpReg.Register(http.MethodGet, "/control/clients/find", clients.handleFindClient)
}

View File

@ -116,12 +116,13 @@ type statusResponse struct {
func (web *webAPI) handleStatus(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := web.logger
dnsAddrs, err := collectDNSAddresses(web.tlsManager)
if err != nil {
// Don't add a lot of formatting, since the error is already
// wrapped by collectDNSAddresses.
aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusInternalServerError, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusInternalServerError, "%s", err)
return
}
@ -167,7 +168,7 @@ func (web *webAPI) handleStatus(w http.ResponseWriter, r *http.Request) {
resp.IsDHCPAvailable = globalContext.dhcpServer != nil
}
aghhttp.WriteJSONResponseOK(w, r, resp)
aghhttp.WriteJSONResponseOK(ctx, l, w, r, resp)
}
// registerControlHandlers sets up HTTP handlers for various control endpoints.
@ -178,13 +179,21 @@ func (web *webAPI) registerControlHandlers() {
"/control/version.json",
web.postInstallHandler(http.HandlerFunc(web.handleVersionJSON)),
)
httpRegister(http.MethodPost, "/control/update", web.handleUpdate)
web.httpReg.Register(http.MethodPost, "/control/update", web.handleUpdate)
httpRegister(http.MethodGet, "/control/status", web.handleStatus)
httpRegister(http.MethodPost, "/control/i18n/change_language", web.handleI18nChangeLanguage)
httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage)
httpRegister(http.MethodGet, "/control/profile", web.handleGetProfile)
httpRegister(http.MethodPut, "/control/profile/update", web.handlePutProfile)
web.httpReg.Register(http.MethodGet, "/control/status", web.handleStatus)
web.httpReg.Register(
http.MethodPost,
"/control/i18n/change_language",
web.handleI18nChangeLanguage,
)
web.httpReg.Register(
http.MethodGet,
"/control/i18n/current_language",
web.handleI18nCurrentLanguage,
)
web.httpReg.Register(http.MethodGet, "/control/profile", web.handleGetProfile)
web.httpReg.Register(http.MethodPut, "/control/profile/update", web.handlePutProfile)
// No authentication is required for DoH/DoT configuration endpoints.
mux.Handle(
@ -199,38 +208,71 @@ func (web *webAPI) registerControlHandlers() {
web.registerAuthHandlers()
}
// httpRegister registers an HTTP handler.
//
// TODO(s.chzhen): Do not use [globalContext.mux].
func httpRegister(method, url string, handler http.HandlerFunc) {
if method == "" {
// "/dns-query" handler doesn't need auth, gzip and isn't restricted by 1 HTTP method
globalContext.mux.Handle(url, postInstallHandler(handler))
return
}
// webMw provides middleware for route handlers. The set method must be called
// to initialize the middleware.
type webMw struct {
// postInstallMw is middleware that verifies that AdGuard Home is not
// running for the first time.
postInstallMw func(h http.Handler) (wrapped http.Handler)
globalContext.mux.Handle(
url,
postInstallHandler(gziphandler.GzipHandler(ensure(method, handler))),
)
// ensureMw is like postInstallMw, but also applies gzip and enforces the
// HTTP method.
ensureMw aghhttp.WrapFunc
}
// ensure returns a wrapped handler that makes sure that the request has the
// correct method as well as additional method and header checks.
func ensure(
// set sets the middleware functions used to build handler chains.
func (mw *webMw) set(web *webAPI) {
mw.postInstallMw = web.postInstallHandler
mw.ensureMw = func(method string, h http.HandlerFunc) (wrapped http.Handler) {
return web.postInstallHandler(gziphandler.GzipHandler(web.ensure(method, h)))
}
}
// wrap returns a wrapped HTTP handler for the given route.
//
// TODO(s.chzhen): Implement [httputil.Middleware].
func (mw *webMw) wrap(method string, h http.HandlerFunc) (wrapped http.Handler) {
f := func(w http.ResponseWriter, r *http.Request) {
var handler http.Handler
if method == "" {
// The "/dns-query" handler doesn't require authentication or gzip,
// and it isn't restricted to a single HTTP method.
handler = mw.postInstallMw(h)
} else {
handler = mw.ensureMw(method, h)
}
handler.ServeHTTP(w, r)
}
return http.HandlerFunc(f)
}
// ensure returns a wrapped handler that verifies the request method. It also
// performs additional method and header checks.
func (web *webAPI) ensure(
method string,
handler func(http.ResponseWriter, *http.Request),
) (wrapped http.HandlerFunc) {
return func(w http.ResponseWriter, r *http.Request) {
m := r.Method
if m != method {
aghhttp.Error(r, w, http.StatusMethodNotAllowed, "only method %s is allowed", method)
aghhttp.ErrorAndLog(
r.Context(),
web.logger,
r,
w,
http.StatusMethodNotAllowed,
"only method %s is allowed",
method,
)
return
}
if modifiesData(m) {
if !ensureContentType(w, r) {
if !web.ensureContentType(w, r) {
return
}
@ -250,9 +292,11 @@ func modifiesData(m string) (ok bool) {
// ensureContentType makes sure that the content type of a data-modifying
// request is set correctly. If it is not, ensureContentType writes a response
// to w, and ok is false.
func ensureContentType(w http.ResponseWriter, r *http.Request) (ok bool) {
func (web *webAPI) ensureContentType(w http.ResponseWriter, r *http.Request) (ok bool) {
const statusUnsup = http.StatusUnsupportedMediaType
ctx := r.Context()
cType := r.Header.Get(httphdr.ContentType)
if r.ContentLength == 0 {
if cType == "" {
@ -262,7 +306,15 @@ func ensureContentType(w http.ResponseWriter, r *http.Request) (ok bool) {
// Assume that browsers always send a content type when submitting HTML
// forms and require no content type for requests with no body to make
// sure that the request comes from JavaScript.
aghhttp.Error(r, w, statusUnsup, "empty body with content-type %q not allowed", cType)
aghhttp.ErrorAndLog(
ctx,
web.logger,
r,
w,
statusUnsup,
"empty body with content-type %q not allowed",
cType,
)
return false
@ -273,7 +325,15 @@ func ensureContentType(w http.ResponseWriter, r *http.Request) (ok bool) {
return true
}
aghhttp.Error(r, w, statusUnsup, "only content-type %s is allowed", wantCType)
aghhttp.ErrorAndLog(
ctx,
web.logger,
r,
w,
statusUnsup,
"only content-type %s is allowed",
wantCType,
)
return false
}
@ -297,15 +357,16 @@ func (web *webAPI) preInstallHandler(handler http.Handler) (wrapped http.Handler
// handleHTTPSRedirect redirects the request to HTTPS, if needed, and adds some
// HTTPS-related headers. If proceed is true, the middleware must continue
// handling the request.
func handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (proceed bool) {
web := globalContext.web
func (web *webAPI) handleHTTPSRedirect(w http.ResponseWriter, r *http.Request) (proceed bool) {
if web.httpsServer.server == nil {
return true
}
ctx := r.Context()
host, err := netutil.SplitHost(r.Host)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "bad host: %s", err)
aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusBadRequest, "bad host: %s", err)
return false
}
@ -381,36 +442,6 @@ func httpsURL(u *url.URL, host string, portHTTPS uint16) (redirectURL *url.URL)
}
}
// postInstallHandler lets the handler to run only if firstRun is false.
// Otherwise, it redirects to /install.html. It also enforces HTTPS if it is
// enabled and configured and sets appropriate access control headers.
//
// TODO(s.chzhen): Replace with [web.postInstall] after fixing its usage in
// [httpRegister], which is called by [dhcpd.Create] before [web] is
// initialized.
func postInstallHandler(handler http.Handler) (wrapped http.Handler) {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if globalContext.web == nil {
aghhttp.Error(r, w, http.StatusTooEarly, "it is not initialized yet")
return
}
path := r.URL.Path
if globalContext.web.conf.firstRun &&
!strings.HasPrefix(path, "/install.") &&
!strings.HasPrefix(path, "/assets/") {
http.Redirect(w, r, "install.html", http.StatusFound)
return
}
if handleHTTPSRedirect(w, r) {
handler.ServeHTTP(w, r)
}
})
}
// postInstallHandler lets the handler to run only if firstRun is false.
// Otherwise, it redirects to /install.html. It also enforces HTTPS if it is
// enabled and configured and sets appropriate access control headers.
@ -425,7 +456,7 @@ func (web *webAPI) postInstallHandler(handler http.Handler) (wrapped http.Handle
return
}
if handleHTTPSRedirect(w, r) {
if web.handleHTTPSRedirect(w, r) {
handler.ServeHTTP(w, r)
}
})

View File

@ -44,6 +44,7 @@ type getAddrsResponse struct {
// handleInstallGetAddresses is the handler for /install/get_addresses endpoint.
func (web *webAPI) handleInstallGetAddresses(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := web.logger
data := getAddrsResponse{
Version: version.Version(),
@ -56,7 +57,7 @@ func (web *webAPI) handleInstallGetAddresses(w http.ResponseWriter, r *http.Requ
if err != nil {
aghhttp.ErrorAndLog(
ctx,
web.logger,
l,
r,
w,
http.StatusInternalServerError,
@ -72,7 +73,7 @@ func (web *webAPI) handleInstallGetAddresses(w http.ResponseWriter, r *http.Requ
data.Interfaces[iface.Name] = iface
}
aghhttp.WriteJSONResponseOK(w, r, data)
aghhttp.WriteJSONResponseOK(ctx, l, w, r, data)
}
type checkConfReqEnt struct {
@ -186,20 +187,13 @@ func (req *checkConfReq) validateDNS(
// handleInstallCheckConfig handles the /check_config endpoint.
func (web *webAPI) handleInstallCheckConfig(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := web.logger
req := &checkConfReq{}
err := json.NewDecoder(r.Body).Decode(req)
if err != nil {
aghhttp.ErrorAndLog(
ctx,
web.logger,
r,
w,
http.StatusBadRequest,
"decoding the request: %s",
err,
)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "decoding the request: %s", err)
return
}
@ -210,14 +204,14 @@ func (web *webAPI) handleInstallCheckConfig(w http.ResponseWriter, r *http.Reque
resp.Web.Status = err.Error()
}
resp.DNS.CanAutofix, err = req.validateDNS(ctx, web.logger, tcpPorts, web.cmdCons)
resp.DNS.CanAutofix, err = req.validateDNS(ctx, l, tcpPorts, web.cmdCons)
if err != nil {
resp.DNS.Status = err.Error()
} else if !req.DNS.IP.IsUnspecified() {
resp.StaticIP = handleStaticIP(ctx, web.logger, req.DNS.IP, req.SetStaticIP, web.cmdCons)
resp.StaticIP = handleStaticIP(ctx, l, req.DNS.IP, req.SetStaticIP, web.cmdCons)
}
aghhttp.WriteJSONResponseOK(w, r, resp)
aghhttp.WriteJSONResponseOK(ctx, l, w, r, resp)
}
// handleStaticIP checks and optionally sets a static IP on the interface that
@ -430,10 +424,11 @@ const PasswordMinRunes = 8
// Apply new configuration, start DNS server, restart Web server
func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := web.logger
req, restartHTTP, err := decodeApplyConfigReq(r.Body)
if err != nil {
aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusBadRequest, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "%s", err)
return
}
@ -441,7 +436,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
if utf8.RuneCountInString(req.Password) < PasswordMinRunes {
aghhttp.ErrorAndLog(
ctx,
web.logger,
l,
r,
w,
http.StatusUnprocessableEntity,
@ -454,14 +449,14 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
err = aghnet.CheckPort("udp", netip.AddrPortFrom(req.DNS.IP, req.DNS.Port))
if err != nil {
aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusBadRequest, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "%s", err)
return
}
err = aghnet.CheckPort("tcp", netip.AddrPortFrom(req.DNS.IP, req.DNS.Port))
if err != nil {
aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusBadRequest, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "%s", err)
return
}
@ -478,6 +473,8 @@ func (web *webAPI) finalizeInstall(
req *applyConfigReq,
restartHTTP bool,
) {
l := web.logger
var err error
curConfig := &configuration{}
copyInstallSettings(curConfig, config)
@ -501,7 +498,7 @@ func (web *webAPI) finalizeInstall(
}
err = web.auth.addUser(ctx, u, req.Password)
if err != nil {
aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusUnprocessableEntity, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusUnprocessableEntity, "%s", err)
return
}
@ -510,9 +507,9 @@ func (web *webAPI) finalizeInstall(
// moment we'll allow setting up TLS in the initial configuration or the
// configuration itself will use HTTPS protocol, because the underlying
// functions potentially restart the HTTPS server.
err = startMods(ctx, web.baseLogger, web.tlsManager, web.confModifier)
err = startMods(ctx, web.baseLogger, web.tlsManager, web.confModifier, web.httpReg)
if err != nil {
aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusInternalServerError, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusInternalServerError, "%s", err)
return
}
@ -521,7 +518,7 @@ func (web *webAPI) finalizeInstall(
if err != nil {
aghhttp.ErrorAndLog(
ctx,
web.logger,
l,
r,
w,
http.StatusInternalServerError,
@ -537,12 +534,12 @@ func (web *webAPI) finalizeInstall(
web.registerControlHandlers()
aghhttp.OK(ctx, web.logger, w)
aghhttp.OK(ctx, l, w)
rc := http.NewResponseController(w)
err = rc.Flush()
if err != nil {
web.logger.WarnContext(ctx, "flushing response", slogutil.KeyError, err)
l.WarnContext(ctx, "flushing response", slogutil.KeyError, err)
}
if !restartHTTP {
@ -554,10 +551,10 @@ func (web *webAPI) finalizeInstall(
// and will be blocked by it's own caller.
go func(timeout time.Duration) {
shutdownCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), timeout)
defer slogutil.RecoverAndLog(shutdownCtx, web.logger)
defer slogutil.RecoverAndLog(shutdownCtx, l)
defer cancel()
shutdownSrv(shutdownCtx, web.logger, web.httpServer)
shutdownSrv(shutdownCtx, l, web.httpServer)
}(shutdownTimeout)
}
@ -593,19 +590,20 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
}
// startMods initializes and starts the DNS server after installation.
// baseLogger and tlsMgr must not be nil.
// baseLogger, tlsMgr, confModifier, and httpReg must not be nil.
func startMods(
ctx context.Context,
baseLogger *slog.Logger,
tlsMgr *tlsManager,
confModifier agh.ConfigModifier,
httpReg aghhttp.Registrar,
) (err error) {
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&globalContext, config)
if err != nil {
return err
}
err = initDNS(ctx, baseLogger, tlsMgr, confModifier, statsDir, querylogDir)
err = initDNS(ctx, baseLogger, tlsMgr, confModifier, httpReg, statsDir, querylogDir)
if err != nil {
return err
}
@ -627,15 +625,15 @@ func (web *webAPI) registerInstallHandlers() {
mux := web.conf.mux
mux.Handle(
"/control/install/get_addresses",
web.preInstallHandler(ensure(http.MethodGet, web.handleInstallGetAddresses)),
http.MethodGet+" "+"/control/install/get_addresses",
web.preInstallHandler(http.HandlerFunc(web.handleInstallGetAddresses)),
)
mux.Handle(
"/control/install/check_config",
web.preInstallHandler(ensure(http.MethodPost, web.handleInstallCheckConfig)),
web.preInstallHandler(web.ensure(http.MethodPost, web.handleInstallCheckConfig)),
)
mux.Handle(
"/control/install/configure",
web.preInstallHandler(ensure(http.MethodPost, web.handleInstallConfigure)),
web.preInstallHandler(web.ensure(http.MethodPost, web.handleInstallConfigure)),
)
}

View File

@ -33,11 +33,12 @@ type temporaryError interface {
// TODO(a.garipov): Find out if this API used with a GET method by anyone.
func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := web.logger
resp := &versionResponse{}
if web.conf.disableUpdate {
resp.Disabled = true
aghhttp.WriteJSONResponseOK(w, r, resp)
aghhttp.WriteJSONResponseOK(ctx, l, w, r, resp)
return
}
@ -50,15 +51,7 @@ func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) {
if r.ContentLength != 0 {
err = json.NewDecoder(r.Body).Decode(req)
if err != nil {
aghhttp.ErrorAndLog(
ctx,
web.logger,
r,
w,
http.StatusBadRequest,
"parsing request: %s",
err,
)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadRequest, "parsing request: %s", err)
return
}
@ -67,20 +60,20 @@ func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) {
err = web.requestVersionInfo(ctx, resp, req.Recheck)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusBadGateway, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusBadGateway, "%s", err)
return
}
err = resp.setAllowedToAutoUpdate(ctx, web.logger, web.tlsManager)
err = resp.setAllowedToAutoUpdate(ctx, l, web.tlsManager)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
aghhttp.ErrorAndLog(ctx, web.logger, r, w, http.StatusInternalServerError, "%s", err)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusInternalServerError, "%s", err)
return
}
aghhttp.WriteJSONResponseOK(w, r, resp)
aghhttp.WriteJSONResponseOK(ctx, l, w, r, resp)
}
// requestVersionInfo sets the VersionInfo field of resp if it can reach the

View File

@ -41,13 +41,14 @@ const (
// initDNS updates all the fields of the [globalContext] needed to initialize
// the DNS server and initializes it at last. It also must not be called unless
// [config] and [globalContext] are initialized. baseLogger, tlsMgr and
// confModfier must not be nil.
// [config] and [globalContext] are initialized. baseLogger, tlsMgr,
// confModifier, and httpReg must not be nil.
func initDNS(
ctx context.Context,
baseLogger *slog.Logger,
tlsMgr *tlsManager,
confModifier agh.ConfigModifier,
httpReg aghhttp.Registrar,
statsDir string,
querylogDir string,
) (err error) {
@ -58,7 +59,7 @@ func initDNS(
Filename: filepath.Join(statsDir, "stats.db"),
Limit: time.Duration(config.Stats.Interval),
ConfigModifier: confModifier,
HTTPRegister: httpRegister,
HTTPReg: httpReg,
Enabled: config.Stats.Enabled,
ShouldCountClient: globalContext.clients.shouldCountClient,
}
@ -78,7 +79,7 @@ func initDNS(
Logger: baseLogger.With(slogutil.KeyPrefix, "querylog"),
Anonymizer: anonymizer,
ConfigModifier: confModifier,
HTTPRegister: httpRegister,
HTTPReg: httpReg,
FindClient: globalContext.clients.findMultiple,
BaseDir: querylogDir,
AnonymizeClientIP: config.DNS.AnonymizeClientIP,
@ -112,7 +113,7 @@ func initDNS(
globalContext.queryLog,
globalContext.dhcpServer,
anonymizer,
httpRegister,
httpReg,
tlsMgr,
baseLogger,
confModifier,
@ -132,7 +133,7 @@ func initDNSServer(
qlog querylog.QueryLog,
dhcpSrv dnsforward.DHCP,
anonymizer *aghnet.IPMut,
httpReg aghhttp.RegisterFunc,
httpReg aghhttp.Registrar,
tlsMgr *tlsManager,
l *slog.Logger,
confModifier agh.ConfigModifier,
@ -235,13 +236,13 @@ func ipsToUDPAddrs(ips []netip.Addr, port uint16) (udpAddrs []*net.UDPAddr) {
}
// newServerConfig converts values from the configuration file into the internal
// DNS server configuration. All arguments must not be nil, except for httpReg.
// DNS server configuration. All arguments must not be nil.
func newServerConfig(
dnsConf *dnsConfig,
clientSrcConf *clientSourcesConfig,
tlsConf *tlsConfigSettings,
tlsMgr *tlsManager,
httpReg aghhttp.RegisterFunc,
httpReg aghhttp.Registrar,
clientsContainer dnsforward.ClientsContainer,
confModifier agh.ConfigModifier,
) (newConf *dnsforward.ServerConfig, err error) {
@ -264,7 +265,7 @@ func newServerConfig(
UpstreamTimeout: time.Duration(dnsConf.UpstreamTimeout),
TLSv12Roots: tlsMgr.rootCerts,
ConfModifier: confModifier,
HTTPRegister: httpReg,
HTTPReg: httpReg,
LocalPTRResolvers: dnsConf.PrivateRDNSResolvers,
UseDNS64: dnsConf.UseDNS64,
DNS64Prefixes: dnsConf.DNS64Prefixes,

View File

@ -21,6 +21,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/agh"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/AdGuardHome/internal/aghslog"
@ -65,9 +66,6 @@ type homeContext struct {
// configuration files, for example /etc/hosts.
etcHosts *aghnet.HostsContainer
// mux is our custom http.ServeMux.
mux *http.ServeMux
// Runtime properties
// --
@ -150,8 +148,6 @@ func setupContext(
opts options,
isFirstRun bool,
) (err error) {
globalContext.mux = http.NewServeMux()
if !opts.noEtcHosts {
err = setupHostsContainer(ctx, baseLogger)
if err != nil {
@ -302,11 +298,12 @@ func initContextClients(
logger *slog.Logger,
sigHdlr *signalHandler,
confModifier agh.ConfigModifier,
httpReg aghhttp.Registrar,
) (err error) {
//lint:ignore SA1019 Migration is not over.
config.DHCP.WorkDir = globalContext.workDir
config.DHCP.DataDir = globalContext.getDataDir()
config.DHCP.HTTPRegister = httpRegister
config.DHCP.HTTPReg = httpReg
config.DHCP.CommandConstructor = executil.SystemCommandConstructor{}
config.DHCP.Logger = logger.With(slogutil.KeyPrefix, "dhcpd")
config.DHCP.ConfModifier = confModifier
@ -335,6 +332,7 @@ func initContextClients(
config.Filtering,
sigHdlr,
confModifier,
httpReg,
)
}
@ -386,6 +384,7 @@ func setupDNSFilteringConf(
conf *filtering.Config,
tlsMgr *tlsManager,
confModifier agh.ConfigModifier,
httpReg aghhttp.Registrar,
) (err error) {
const (
dnsTimeout = 3 * time.Second
@ -408,7 +407,7 @@ func setupDNSFilteringConf(
}
conf.ConfModifier = confModifier
conf.HTTPRegister = httpRegister
conf.HTTPReg = httpReg
conf.DataDir = globalContext.getDataDir()
conf.Filters = slices.Clone(config.Filters)
conf.WhitelistFilters = slices.Clone(config.WhitelistFilters)
@ -563,8 +562,7 @@ func isUpdateEnabled(
}
}
// initWeb initializes the web module. upd, baseLogger, tlsMgr, auth, and mux
// must not be nil.
// initWeb initializes the web module. All arguments must not be nil.
func initWeb(
ctx context.Context,
opts options,
@ -575,6 +573,7 @@ func initWeb(
auth *auth,
mux *http.ServeMux,
confModifier agh.ConfigModifier,
httpReg aghhttp.Registrar,
isCustomUpdURL bool,
isFirstRun bool,
) (web *webAPI, err error) {
@ -602,6 +601,7 @@ func initWeb(
logger: logger,
baseLogger: baseLogger,
confModifier: confModifier,
httpReg: httpReg,
tlsManager: tlsMgr,
auth: auth,
mux: mux,
@ -718,7 +718,11 @@ func run(
slogLogger.With(slogutil.KeyPrefix, "config_modifier"),
)
err = initContextClients(ctx, slogLogger, sigHdlr, confModifier)
mw := &webMw{}
mux := http.NewServeMux()
httpReg := aghhttp.NewDefaultRegistrar(mux, mw.wrap)
err = initContextClients(ctx, slogLogger, sigHdlr, confModifier, httpReg)
fatalOnError(err)
tlsMgrLogger := slogLogger.With(slogutil.KeyPrefix, "tls_manager")
@ -726,6 +730,7 @@ func run(
tlsMgr, err := newTLSManager(ctx, &tlsManagerConfig{
logger: tlsMgrLogger,
confModifier: confModifier,
httpReg: httpReg,
tlsSettings: config.TLS,
servePlainDNS: config.DNS.ServePlainDNS,
})
@ -736,7 +741,14 @@ func run(
confModifier.setTLSManager(tlsMgr)
err = setupDNSFilteringConf(ctx, slogLogger, config.Filtering, tlsMgr, confModifier)
err = setupDNSFilteringConf(
ctx,
slogLogger,
config.Filtering,
tlsMgr,
confModifier,
httpReg,
)
fatalOnError(err)
err = setupOpts(opts)
@ -788,13 +800,16 @@ func run(
slogLogger,
tlsMgr,
auth,
globalContext.mux,
mux,
confModifier,
httpReg,
isCustomURL,
isFirstRun,
)
fatalOnError(err)
mw.set(web)
globalContext.web = web
tlsMgr.setWebAPI(web)
@ -804,7 +819,7 @@ func run(
fatalOnError(err)
if !isFirstRun {
err = initDNS(ctx, slogLogger, tlsMgr, confModifier, statsDir, querylogDir)
err = initDNS(ctx, slogLogger, tlsMgr, confModifier, httpReg, statsDir, querylogDir)
fatalOnError(err)
tlsMgr.start(ctx)

View File

@ -6,7 +6,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/log"
)
// TODO(a.garipov): Get rid of a global or generate from .twosky.json.
@ -54,15 +53,24 @@ type languageJSON struct {
Language string `json:"language"`
}
// handleI18nCurrentLanguage is the handler for the GET
// /control/i18n/current_language HTTP API.
//
// TODO(d.kolyshev): Deprecated, remove it later.
func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
log.Printf("home: language is %s", config.Language)
func (web *webAPI) handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := web.logger
aghhttp.WriteJSONResponseOK(w, r, &languageJSON{
l.InfoContext(ctx, "current language", "lang", config.Language)
aghhttp.WriteJSONResponseOK(ctx, l, w, r, &languageJSON{
Language: config.Language,
})
}
// handleI18nChangeLanguage is the handler for the POST
// /control/i18n/change_language HTTP API.
//
// TODO(d.kolyshev): Deprecated, remove it later.
func (web *webAPI) handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()

View File

@ -46,10 +46,12 @@ type profileJSON struct {
// handleGetProfile is the handler for GET /control/profile endpoint.
func (web *webAPI) handleGetProfile(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var name string
if !web.auth.isGLiNet && !web.auth.isUserless {
u, ok := webUserFromContext(r.Context())
u, ok := webUserFromContext(ctx)
if !ok {
w.WriteHeader(http.StatusUnauthorized)
@ -71,7 +73,7 @@ func (web *webAPI) handleGetProfile(w http.ResponseWriter, r *http.Request) {
}
}()
aghhttp.WriteJSONResponseOK(w, r, resp)
aghhttp.WriteJSONResponseOK(ctx, web.logger, w, r, resp)
}
// handlePutProfile is the handler for PUT /control/profile/update endpoint.

View File

@ -72,7 +72,6 @@ func TestWeb_HandleGetProfile(t *testing.T) {
t.Cleanup(func() { auth.close(testutil.ContextWithTimeout(t, testTimeout)) })
baseMux := http.NewServeMux()
globalContext.mux = baseMux
tlsMgr, err := newTLSManager(testutil.ContextWithTimeout(t, testTimeout), &tlsManagerConfig{
logger: testLogger,
@ -90,6 +89,7 @@ func TestWeb_HandleGetProfile(t *testing.T) {
auth,
baseMux,
agh.EmptyConfigModifier{},
aghhttp.EmptyRegistrar{},
false,
false,
)
@ -132,8 +132,9 @@ func TestWeb_HandleGetProfile(t *testing.T) {
func TestWeb_HandlePutProfile(t *testing.T) {
storeGlobals(t)
mw := &webMw{}
mux := http.NewServeMux()
globalContext.mux = mux
httpReg := aghhttp.NewDefaultRegistrar(mux, mw.wrap)
isConfigChanged := false
confModifier := &aghtest.ConfigModifier{
@ -151,12 +152,14 @@ func TestWeb_HandlePutProfile(t *testing.T) {
nil,
mux,
confModifier,
httpReg,
false,
false,
)
require.NoError(t, err)
globalContext.web = web
mw.set(web)
var (
dataValid = errors.Must(json.Marshal(&profileJSON{

View File

@ -60,7 +60,11 @@ type tlsManager struct {
// confModifier is used to update the global configuration.
confModifier agh.ConfigModifier
// customCipherIDs are the ID of the cipher suites that AdGuard Home must use.
// httpReg registers HTTP handlers. It must not be nil.
httpReg aghhttp.Registrar
// customCipherIDs are the IDs of the cipher suites that AdGuard Home must
// use.
customCipherIDs []uint16
// servePlainDNS defines if plain DNS is allowed for incoming requests.
@ -77,6 +81,8 @@ type tlsManagerConfig struct {
// nil.
confModifier agh.ConfigModifier
httpReg aghhttp.Registrar
// tlsSettings contains the TLS configuration settings.
tlsSettings tlsConfigSettings
@ -94,6 +100,7 @@ func newTLSManager(ctx context.Context, conf *tlsManagerConfig) (m *tlsManager,
logger: conf.logger,
mu: &sync.Mutex{},
confModifier: conf.confModifier,
httpReg: conf.httpReg,
status: &tlsConfigStatus{},
conf: &conf.tlsSettings,
servePlainDNS: conf.servePlainDNS,
@ -251,7 +258,7 @@ func (m *tlsManager) reconfigureDNSServer(ctx context.Context) (err error) {
config.Clients.Sources,
m.conf,
m,
httpRegister,
m.httpReg,
globalContext.clients.storage,
m.confModifier,
)
@ -426,7 +433,7 @@ func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
servePlainDNS = m.servePlainDNS
}()
data := tlsConfig{
data := &tlsConfig{
tlsConfigSettingsExt: tlsConfigSettingsExt{
tlsConfigSettings: *tlsConf,
ServePlainDNS: aghalg.BoolToNullBool(servePlainDNS),
@ -434,7 +441,7 @@ func (m *tlsManager) handleTLSStatus(w http.ResponseWriter, r *http.Request) {
tlsConfigStatus: m.status,
}
marshalTLS(w, r, data)
m.marshalTLS(r.Context(), w, r, data)
}
// handleTLSValidate is the handler for the POST /control/tls/validate HTTP API.
@ -471,12 +478,12 @@ func (m *tlsManager) handleTLSValidate(w http.ResponseWriter, r *http.Request) {
// status.WarningValidation.
status := &tlsConfigStatus{}
_ = m.loadTLSConfig(ctx, &setts.tlsConfigSettings, status)
resp := tlsConfig{
resp := &tlsConfig{
tlsConfigSettingsExt: setts,
tlsConfigStatus: status,
}
marshalTLS(w, r, resp)
m.marshalTLS(ctx, w, r, resp)
}
// setConfig updates manager TLS configuration with the given one. m.mu is
@ -548,12 +555,12 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
status := &tlsConfigStatus{}
err = m.loadTLSConfig(ctx, &req.tlsConfigSettings, status)
if err != nil {
resp := tlsConfig{
resp := &tlsConfig{
tlsConfigSettingsExt: req,
tlsConfigStatus: status,
}
marshalTLS(w, r, resp)
m.marshalTLS(ctx, w, r, resp)
return
}
@ -579,12 +586,12 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
return
}
resp := tlsConfig{
resp := &tlsConfig{
tlsConfigSettingsExt: req,
tlsConfigStatus: m.status,
}
marshalTLS(w, r, resp)
m.marshalTLS(ctx, w, r, resp)
rc := http.NewResponseController(w)
err = rc.Flush()
if err != nil {
@ -1027,7 +1034,14 @@ func unmarshalTLS(r *http.Request) (tlsConfigSettingsExt, error) {
return data, nil
}
func marshalTLS(w http.ResponseWriter, r *http.Request, data tlsConfig) {
// marshalTLS encodes sensitive fields and writes data as JSON. All arguments
// must not be nil.
func (m *tlsManager) marshalTLS(
ctx context.Context,
w http.ResponseWriter,
r *http.Request,
data *tlsConfig,
) {
if data.CertificateChain != "" {
encoded := base64.StdEncoding.EncodeToString([]byte(data.CertificateChain))
data.CertificateChain = encoded
@ -1038,12 +1052,12 @@ func marshalTLS(w http.ResponseWriter, r *http.Request, data tlsConfig) {
data.PrivateKey = ""
}
aghhttp.WriteJSONResponseOK(w, r, data)
aghhttp.WriteJSONResponseOK(ctx, m.logger, w, r, *data)
}
// registerWebHandlers registers HTTP handlers for TLS configuration.
func (m *tlsManager) registerWebHandlers() {
httpRegister(http.MethodGet, "/control/tls/status", m.handleTLSStatus)
httpRegister(http.MethodPost, "/control/tls/configure", m.handleTLSConfigure)
httpRegister(http.MethodPost, "/control/tls/validate", m.handleTLSValidate)
m.httpReg.Register(http.MethodGet, "/control/tls/status", m.handleTLSStatus)
m.httpReg.Register(http.MethodPost, "/control/tls/configure", m.handleTLSConfigure)
m.httpReg.Register(http.MethodPost, "/control/tls/validate", m.handleTLSValidate)
}

View File

@ -22,6 +22,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/agh"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/client"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/golibs/testutil"
@ -153,7 +154,6 @@ func storeGlobals(tb testing.TB) {
prefGLFilePrefix := glFilePrefix
storage := globalContext.clients.storage
dnsServer := globalContext.dnsServer
mux := globalContext.mux
web := globalContext.web
tb.Cleanup(func() {
@ -161,7 +161,6 @@ func storeGlobals(tb testing.TB) {
glFilePrefix = prefGLFilePrefix
globalContext.clients.storage = storage
globalContext.dnsServer = dnsServer
globalContext.mux = mux
globalContext.web = web
})
}
@ -307,6 +306,7 @@ func initEmptyWeb(tb testing.TB) (web *webAPI) {
nil,
http.NewServeMux(),
agh.EmptyConfigModifier{},
aghhttp.EmptyRegistrar{},
false,
false,
)
@ -337,8 +337,6 @@ func TestTLSManager_Reload(t *testing.T) {
})
require.NoError(t, err)
globalContext.mux = http.NewServeMux()
const (
snBefore int64 = 1
snAfter int64 = 2
@ -419,8 +417,6 @@ func TestTLSManager_HandleTLSStatus(t *testing.T) {
func TestValidateTLSSettings(t *testing.T) {
storeGlobals(t)
globalContext.mux = http.NewServeMux()
var (
ctx = testutil.ContextWithTimeout(t, testTimeout)
err error
@ -516,8 +512,6 @@ func TestValidateTLSSettings(t *testing.T) {
func TestTLSManager_HandleTLSValidate(t *testing.T) {
storeGlobals(t)
globalContext.mux = http.NewServeMux()
var (
ctx = testutil.ContextWithTimeout(t, testTimeout)
err error
@ -598,8 +592,6 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
})
require.NoError(t, err)
globalContext.mux = http.NewServeMux()
config.DNS.BindHosts = []netip.Addr{netip.MustParseAddr("127.0.0.1")}
config.DNS.Port = 0

View File

@ -13,6 +13,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agh"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/updater"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
@ -55,6 +56,9 @@ type webConfig struct {
// confModifier is used to update the global configuration.
confModifier agh.ConfigModifier
// httpReg registers HTTP handlers. It must not be nil.
httpReg aghhttp.Registrar
// tlsManager contains the current configuration and state of TLS
// encryption. It must not be nil.
tlsManager *tlsManager
@ -125,6 +129,9 @@ type webAPI struct {
// cmdCons is used to run external commands.
cmdCons executil.CommandConstructor
// httpReg registers HTTP handlers.
httpReg aghhttp.Registrar
// TODO(a.garipov): Refactor all these servers.
httpServer *http.Server
@ -157,6 +164,7 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
w = &webAPI{
conf: conf,
confModifier: conf.confModifier,
httpReg: conf.httpReg,
cmdCons: conf.CommandConstructor,
logger: conf.logger,
baseLogger: conf.baseLogger,

View File

@ -60,11 +60,13 @@ type HTTPAPIDNSSettings struct {
// handlePatchSettingsDNS is the handler for the PATCH /api/v1/settings/dns HTTP
// API.
func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
l := svc.logger
req := &ReqPatchSettingsDNS{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("decoding: %w", err))
aghhttp.WriteJSONResponseError(ctx, l, w, r, fmt.Errorf("decoding: %w", err))
return
}
@ -93,10 +95,9 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques
req.RefuseAny.Set(&newConf.RefuseAny)
req.UseDNS64.Set(&newConf.UseDNS64)
ctx := r.Context()
err = svc.confMgr.UpdateDNS(ctx, newConf)
if err != nil {
aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("updating: %w", err))
aghhttp.WriteJSONResponseError(ctx, l, w, r, fmt.Errorf("updating: %w", err))
return
}
@ -104,12 +105,12 @@ func (svc *Service) handlePatchSettingsDNS(w http.ResponseWriter, r *http.Reques
newSvc := svc.confMgr.DNS()
err = newSvc.Start(ctx)
if err != nil {
aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("starting new service: %w", err))
aghhttp.WriteJSONResponseError(ctx, l, w, r, fmt.Errorf("starting new service: %w", err))
return
}
aghhttp.WriteJSONResponseOK(w, r, &HTTPAPIDNSSettings{
aghhttp.WriteJSONResponseOK(ctx, l, w, r, &HTTPAPIDNSSettings{
UpstreamMode: newConf.UpstreamMode,
Addresses: newConf.Addresses,
BootstrapServers: newConf.BootstrapServers,

View File

@ -43,11 +43,13 @@ type HTTPAPIHTTPSettings struct {
// handlePatchSettingsHTTP is the handler for the PATCH /api/v1/settings/http
// HTTP API.
func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
req := &ReqPatchSettingsHTTP{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
aghhttp.WriteJSONResponseError(w, r, fmt.Errorf("decoding: %w", err))
aghhttp.WriteJSONResponseError(ctx, svc.logger, w, r, fmt.Errorf("decoding: %w", err))
return
}
@ -61,7 +63,7 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque
req.Timeout.Set((*aghhttp.JSONDuration)(&newConf.Timeout))
req.ForceHTTPS.Set(&newConf.ForceHTTPS)
aghhttp.WriteJSONResponseOK(w, r, &HTTPAPIHTTPSettings{
aghhttp.WriteJSONResponseOK(ctx, svc.logger, w, r, &HTTPAPIHTTPSettings{
Addresses: newConf.Addresses,
SecureAddresses: newConf.SecureAddresses,
Timeout: aghhttp.JSONDuration(newConf.Timeout),
@ -71,7 +73,6 @@ func (svc *Service) handlePatchSettingsHTTP(w http.ResponseWriter, r *http.Reque
cancelUpd := func() {}
updCtx := context.Background()
ctx := r.Context()
if deadline, ok := ctx.Deadline(); ok {
updCtx, cancelUpd = context.WithDeadline(updCtx, deadline)
}

View File

@ -27,7 +27,7 @@ func (svc *Service) handleGetSettingsAll(w http.ResponseWriter, r *http.Request)
httpConf := webSvc.Config()
// TODO(a.garipov): Add all currently supported parameters.
aghhttp.WriteJSONResponseOK(w, r, &RespGetV1SettingsAll{
aghhttp.WriteJSONResponseOK(r.Context(), svc.logger, w, r, &RespGetV1SettingsAll{
DNS: &HTTPAPIDNSSettings{
UpstreamMode: dnsConf.UpstreamMode,
Addresses: dnsConf.Addresses,

View File

@ -24,7 +24,7 @@ type RespGetV1SystemInfo struct {
// handleGetV1SystemInfo is the handler for the GET /api/v1/system/info HTTP
// API.
func (svc *Service) handleGetV1SystemInfo(w http.ResponseWriter, r *http.Request) {
aghhttp.WriteJSONResponseOK(w, r, &RespGetV1SystemInfo{
aghhttp.WriteJSONResponseOK(r.Context(), svc.logger, w, r, &RespGetV1SystemInfo{
Arch: runtime.GOARCH,
Channel: version.Channel(),
OS: runtime.GOOS,

View File

@ -59,18 +59,18 @@ type getConfigResp struct {
// Register web handlers
func (l *queryLog) initWeb() {
l.conf.HTTPRegister(http.MethodGet, "/control/querylog", l.handleQueryLog)
l.conf.HTTPRegister(http.MethodPost, "/control/querylog_clear", l.handleQueryLogClear)
l.conf.HTTPRegister(http.MethodGet, "/control/querylog/config", l.handleGetQueryLogConfig)
l.conf.HTTPRegister(
l.conf.HTTPReg.Register(http.MethodGet, "/control/querylog", l.handleQueryLog)
l.conf.HTTPReg.Register(http.MethodPost, "/control/querylog_clear", l.handleQueryLogClear)
l.conf.HTTPReg.Register(http.MethodGet, "/control/querylog/config", l.handleGetQueryLogConfig)
l.conf.HTTPReg.Register(
http.MethodPut,
"/control/querylog/config/update",
l.handlePutQueryLogConfig,
)
// Deprecated handlers.
l.conf.HTTPRegister(http.MethodGet, "/control/querylog_info", l.handleQueryLogInfo)
l.conf.HTTPRegister(http.MethodPost, "/control/querylog_config", l.handleQueryLogConfig)
l.conf.HTTPReg.Register(http.MethodGet, "/control/querylog_info", l.handleQueryLogInfo)
l.conf.HTTPReg.Register(http.MethodPost, "/control/querylog_config", l.handleQueryLogConfig)
}
// handleQueryLog is the handler for the GET /control/querylog HTTP API.
@ -94,7 +94,7 @@ func (l *queryLog) handleQueryLog(w http.ResponseWriter, r *http.Request) {
resp := l.entriesToJSON(ctx, entries, oldest, l.anonymizer.Load())
aghhttp.WriteJSONResponseOK(w, r, resp)
aghhttp.WriteJSONResponseOK(ctx, l.logger, w, r, resp)
}
// handleQueryLogClear is the handler for the POST /control/querylog/clear HTTP
@ -119,7 +119,7 @@ func (l *queryLog) handleQueryLogInfo(w http.ResponseWriter, r *http.Request) {
ivl = timeutil.Day * 90
}
aghhttp.WriteJSONResponseOK(w, r, configJSON{
aghhttp.WriteJSONResponseOK(r.Context(), l.logger, w, r, configJSON{
Enabled: aghalg.BoolToNullBool(l.conf.Enabled),
Interval: ivl.Hours() / 24,
AnonymizeClientIP: aghalg.BoolToNullBool(l.conf.AnonymizeClientIP),
@ -142,7 +142,7 @@ func (l *queryLog) handleGetQueryLogConfig(w http.ResponseWriter, r *http.Reques
}
}()
aghhttp.WriteJSONResponseOK(w, r, resp)
aghhttp.WriteJSONResponseOK(r.Context(), l.logger, w, r, resp)
}
// AnonymizeIP masks ip to anonymize the client if the ip is a valid one.

View File

@ -87,7 +87,7 @@ var _ QueryLog = (*queryLog)(nil)
// Start implements the [QueryLog] interface for *queryLog.
func (l *queryLog) Start(ctx context.Context) (err error) {
if l.conf.HTTPRegister != nil {
if l.conf.HTTPReg != nil {
l.initWeb()
}

View File

@ -53,7 +53,7 @@ type Config struct {
ConfigModifier agh.ConfigModifier
// HTTPRegister registers an HTTP handler.
HTTPRegister aghhttp.RegisterFunc
HTTPReg aghhttp.Registrar
// FindClient returns client information by their IDs.
FindClient func(ids []string) (c *Client, err error)

View File

@ -51,6 +51,7 @@ func (s *StatsCtx) handleStats(w http.ResponseWriter, r *http.Request) {
start := time.Now()
ctx := r.Context()
l := s.logger
var (
resp *StatsResp
@ -63,18 +64,18 @@ func (s *StatsCtx) handleStats(w http.ResponseWriter, r *http.Request) {
resp, ok = s.getData(uint32(s.limit.Hours()))
}()
s.logger.DebugContext(ctx, "prepared data", "elapsed", time.Since(start))
l.DebugContext(ctx, "prepared data", "elapsed", time.Since(start))
if !ok {
// Don't bring the message to the lower case since it's a part of UI
// text for the moment.
const msg = "Couldn't get statistics data"
aghhttp.ErrorAndLog(ctx, s.logger, r, w, http.StatusInternalServerError, msg)
aghhttp.ErrorAndLog(ctx, l, r, w, http.StatusInternalServerError, msg)
return
}
aghhttp.WriteJSONResponseOK(w, r, resp)
aghhttp.WriteJSONResponseOK(ctx, l, w, r, resp)
}
// configResp is the response to the GET /control/stats_info.
@ -123,7 +124,7 @@ func (s *StatsCtx) handleStatsInfo(w http.ResponseWriter, r *http.Request) {
resp.IntervalDays = 0
}
aghhttp.WriteJSONResponseOK(w, r, resp)
aghhttp.WriteJSONResponseOK(r.Context(), s.logger, w, r, resp)
}
// handleGetStatsConfig is the handler for the GET /control/stats/config HTTP
@ -141,7 +142,7 @@ func (s *StatsCtx) handleGetStatsConfig(w http.ResponseWriter, r *http.Request)
}
}()
aghhttp.WriteJSONResponseOK(w, r, resp)
aghhttp.WriteJSONResponseOK(r.Context(), s.logger, w, r, resp)
}
// handleStatsConfig is the handler for the POST /control/stats_config HTTP API.
@ -244,16 +245,12 @@ func (s *StatsCtx) handleStatsReset(w http.ResponseWriter, r *http.Request) {
// initWeb registers the handlers for web endpoints of statistics module.
func (s *StatsCtx) initWeb() {
if s.httpRegister == nil {
return
}
s.httpRegister(http.MethodGet, "/control/stats", s.handleStats)
s.httpRegister(http.MethodPost, "/control/stats_reset", s.handleStatsReset)
s.httpRegister(http.MethodGet, "/control/stats/config", s.handleGetStatsConfig)
s.httpRegister(http.MethodPut, "/control/stats/config/update", s.handlePutStatsConfig)
s.httpReg.Register(http.MethodGet, "/control/stats", s.handleStats)
s.httpReg.Register(http.MethodPost, "/control/stats_reset", s.handleStatsReset)
s.httpReg.Register(http.MethodGet, "/control/stats/config", s.handleGetStatsConfig)
s.httpReg.Register(http.MethodPut, "/control/stats/config/update", s.handlePutStatsConfig)
// Deprecated handlers.
s.httpRegister(http.MethodGet, "/control/stats_info", s.handleStatsInfo)
s.httpRegister(http.MethodPost, "/control/stats_config", s.handleStatsConfig)
s.httpReg.Register(http.MethodGet, "/control/stats_info", s.handleStatsInfo)
s.httpReg.Register(http.MethodPost, "/control/stats_config", s.handleStatsConfig)
}

View File

@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/agh"
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil"
@ -30,6 +31,7 @@ func TestHandleStatsConfig(t *testing.T) {
UnitID: func() (id uint32) { return 0 },
ConfigModifier: agh.EmptyConfigModifier{},
ShouldCountClient: func([]string) bool { return true },
HTTPReg: aghhttp.EmptyRegistrar{},
Filename: filepath.Join(t.TempDir(), "stats.db"),
Limit: time.Hour * 24,
Enabled: true,

View File

@ -65,7 +65,7 @@ type Config struct {
// HTTPRegister is the function that registers handlers for the stats
// endpoints.
HTTPRegister aghhttp.RegisterFunc
HTTPReg aghhttp.Registrar
// Ignored contains the list of host names, which should not be counted,
// and matches them.
@ -121,8 +121,8 @@ type StatsCtx struct {
// unit. It's here for only testing purposes.
unitIDGen UnitIDGenFunc
// httpRegister is used to set HTTP handlers.
httpRegister aghhttp.RegisterFunc
// httpReg registers HTTP handlers. It must not be nil.
httpReg aghhttp.Registrar
// configModifier is used to update the global configuration.
configModifier agh.ConfigModifier
@ -164,7 +164,7 @@ func New(conf Config) (s *StatsCtx, err error) {
s = &StatsCtx{
logger: conf.Logger,
currMu: &sync.RWMutex{},
httpRegister: conf.HTTPRegister,
httpReg: conf.HTTPReg,
configModifier: conf.ConfigModifier,
filename: conf.Filename,

View File

@ -8,6 +8,7 @@ import (
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/timeutil"
@ -21,6 +22,7 @@ func TestStats_races(t *testing.T) {
conf := Config{
Logger: slogutil.NewDiscardLogger(),
ShouldCountClient: func([]string) bool { return true },
HTTPReg: aghhttp.EmptyRegistrar{},
UnitID: idGen,
Filename: filepath.Join(t.TempDir(), "./stats.db"),
Limit: timeutil.Day,

View File

@ -11,7 +11,9 @@ import (
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/logutil/slogutil"
@ -59,8 +61,10 @@ func TestStats(t *testing.T) {
Limit: timeutil.Day,
Enabled: true,
UnitID: constUnitID,
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
handlers[url] = handler
HTTPReg: &aghtest.Registrar{
OnRegister: func(_, url string, handler http.HandlerFunc) {
handlers[url] = handler
},
},
}
@ -180,7 +184,11 @@ func TestLargeNumbers(t *testing.T) {
Limit: timeutil.Day,
Enabled: true,
UnitID: func() (id uint32) { return atomic.LoadUint32(&curHour) },
HTTPRegister: func(_, url string, handler http.HandlerFunc) { handlers[url] = handler },
HTTPReg: &aghtest.Registrar{
OnRegister: func(_, url string, handler http.HandlerFunc) {
handlers[url] = handler
},
},
}
s, err := stats.New(conf)
@ -234,6 +242,7 @@ func TestShouldCount(t *testing.T) {
ShouldCountClient: func(ids []string) (a bool) {
return ids[0] != "no_count"
},
HTTPReg: aghhttp.EmptyRegistrar{},
})
require.NoError(t, err)