diff --git a/internal/agh/agh.go b/internal/agh/agh.go new file mode 100644 index 00000000..8fbd740f --- /dev/null +++ b/internal/agh/agh.go @@ -0,0 +1,22 @@ +// Package agh contains common entities and interfaces of AdGuard Home. +package agh + +import ( + "context" +) + +// ConfigModifier defines an interface for updating the global configuration. +type ConfigModifier interface { + // Apply applies changes to the global configuration. + Apply(ctx context.Context) +} + +// EmptyConfigModifier is an empty [ConfigModifier] implementation that does +// nothing. +type EmptyConfigModifier struct{} + +// type check +var _ ConfigModifier = EmptyConfigModifier{} + +// Apply implements the [ConfigModifier] for EmptyConfigModifier. +func (em EmptyConfigModifier) Apply(ctx context.Context) {} diff --git a/internal/aghtest/interface.go b/internal/aghtest/interface.go index 3b07257d..17385a15 100644 --- a/internal/aghtest/interface.go +++ b/internal/aghtest/interface.go @@ -6,8 +6,9 @@ import ( "net/netip" "time" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/AdGuardHome/internal/aghos" - "github.com/AdguardTeam/AdGuardHome/internal/next/agh" + nextagh "github.com/AdguardTeam/AdGuardHome/internal/next/agh" "github.com/AdguardTeam/AdGuardHome/internal/rdns" "github.com/AdguardTeam/AdGuardHome/internal/whois" "github.com/AdguardTeam/dnsproxy/upstream" @@ -53,9 +54,10 @@ func (w *FSWatcher) Add(name string) (err error) { return w.OnAdd(name) } -// Package agh +// Package nextagh -// ServiceWithConfig is a fake [agh.ServiceWithConfig] implementation for tests. +// ServiceWithConfig is a fake [nextagh.ServiceWithConfig] implementation for +// tests. type ServiceWithConfig[ConfigType any] struct { OnStart func(ctx context.Context) (err error) OnShutdown func(ctx context.Context) (err error) @@ -63,21 +65,21 @@ type ServiceWithConfig[ConfigType any] struct { } // type check -var _ agh.ServiceWithConfig[struct{}] = (*ServiceWithConfig[struct{}])(nil) +var _ nextagh.ServiceWithConfig[struct{}] = (*ServiceWithConfig[struct{}])(nil) -// Start implements the [agh.ServiceWithConfig] interface for +// Start implements the [nextagh.ServiceWithConfig] interface for // *ServiceWithConfig. func (s *ServiceWithConfig[_]) Start(ctx context.Context) (err error) { return s.OnStart(ctx) } -// Shutdown implements the [agh.ServiceWithConfig] interface for +// Shutdown implements the [nextagh.ServiceWithConfig] interface for // *ServiceWithConfig. func (s *ServiceWithConfig[_]) Shutdown(ctx context.Context) (err error) { return s.OnShutdown(ctx) } -// Config implements the [agh.ServiceWithConfig] interface for +// Config implements the [nextagh.ServiceWithConfig] interface for // *ServiceWithConfig. func (s *ServiceWithConfig[ConfigType]) Config() (c ConfigType) { return s.OnConfig() @@ -178,3 +180,16 @@ func (u *UpstreamMock) Exchange(req *dns.Msg) (resp *dns.Msg, err error) { func (u *UpstreamMock) Close() (err error) { return u.OnClose() } + +// ConfigModifier is a fake [agh.ConfigModifier] implementation for tests. +type ConfigModifier struct { + OnApply func(ctx context.Context) +} + +// type check +var _ agh.ConfigModifier = (*ConfigModifier)(nil) + +// Apply implements the [ConfigModifier] interface for *ConfigModifier. +func (m *ConfigModifier) Apply(ctx context.Context) { + m.OnApply(ctx) +} diff --git a/internal/dhcpd/config.go b/internal/dhcpd/config.go index d11d9342..4cef310d 100644 --- a/internal/dhcpd/config.go +++ b/internal/dhcpd/config.go @@ -6,6 +6,7 @@ import ( "net/netip" "time" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" @@ -15,8 +16,9 @@ import ( // ServerConfig is the configuration for the DHCP server. The order of YAML // fields is important, since the YAML configuration file follows it. type ServerConfig struct { - // Called when the configuration is changed by HTTP request - ConfigModified func() `yaml:"-"` + // ConfModifier is used to update the global configuration. It must not be + // nil. + ConfModifier agh.ConfigModifier `yaml:"-"` // Register an HTTP handler HTTPRegister aghhttp.RegisterFunc `yaml:"-"` diff --git a/internal/dhcpd/dhcpd.go b/internal/dhcpd/dhcpd.go index edc3d3a4..5da089d3 100644 --- a/internal/dhcpd/dhcpd.go +++ b/internal/dhcpd/dhcpd.go @@ -107,7 +107,7 @@ var _ Interface = (*server)(nil) func Create(conf *ServerConfig) (s *server, err error) { s = &server{ conf: &ServerConfig{ - ConfigModified: conf.ConfigModified, + ConfModifier: conf.ConfModifier, HTTPRegister: conf.HTTPRegister, diff --git a/internal/dhcpd/http_unix.go b/internal/dhcpd/http_unix.go index db81bafc..6d6226bd 100644 --- a/internal/dhcpd/http_unix.go +++ b/internal/dhcpd/http_unix.go @@ -335,7 +335,7 @@ func (s *server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) { } s.setConfFromJSON(conf, srv4, srv6) - s.conf.ConfigModified() + s.conf.ConfModifier.Apply(r.Context()) err = s.dbLoad() if err != nil { @@ -679,7 +679,7 @@ func (s *server) handleReset(w http.ResponseWriter, r *http.Request) { } s.conf = &ServerConfig{ - ConfigModified: s.conf.ConfigModified, + ConfModifier: s.conf.ConfModifier, HTTPRegister: s.conf.HTTPRegister, @@ -702,7 +702,7 @@ func (s *server) handleReset(w http.ResponseWriter, r *http.Request) { } s.srv6, _ = v6Create(v6conf) - s.conf.ConfigModified() + s.conf.ConfModifier.Apply(r.Context()) } func (s *server) handleResetLeases(w http.ResponseWriter, r *http.Request) { diff --git a/internal/dhcpd/http_unix_internal_test.go b/internal/dhcpd/http_unix_internal_test.go index 80d37050..23d11322 100644 --- a/internal/dhcpd/http_unix_internal_test.go +++ b/internal/dhcpd/http_unix_internal_test.go @@ -10,6 +10,7 @@ import ( "net/netip" "testing" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -84,10 +85,10 @@ func TestServer_handleDHCPStatus(t *testing.T) { } s, err := Create(&ServerConfig{ - Enabled: true, - Conf4: *defaultV4ServerConf(), - DataDir: t.TempDir(), - ConfigModified: func() {}, + Enabled: true, + Conf4: *defaultV4ServerConf(), + DataDir: t.TempDir(), + ConfModifier: agh.EmptyConfigModifier{}, }) require.NoError(t, err) @@ -178,11 +179,11 @@ func TestServer_HandleUpdateStaticLease(t *testing.T) { } s, err := Create(&ServerConfig{ - Enabled: true, - Conf4: *defaultV4ServerConf(), - Conf6: V6ServerConf{}, - DataDir: t.TempDir(), - ConfigModified: func() {}, + Enabled: true, + Conf4: *defaultV4ServerConf(), + Conf6: V6ServerConf{}, + DataDir: t.TempDir(), + ConfModifier: agh.EmptyConfigModifier{}, }) require.NoError(t, err) @@ -266,11 +267,11 @@ func TestServer_HandleUpdateStaticLease_validation(t *testing.T) { }} s, err := Create(&ServerConfig{ - Enabled: true, - Conf4: *defaultV4ServerConf(), - Conf6: V6ServerConf{}, - DataDir: t.TempDir(), - ConfigModified: func() {}, + Enabled: true, + Conf4: *defaultV4ServerConf(), + Conf6: V6ServerConf{}, + DataDir: t.TempDir(), + ConfModifier: agh.EmptyConfigModifier{}, }) require.NoError(t, err) diff --git a/internal/dnsforward/access.go b/internal/dnsforward/access.go index c5535d30..ef58dff9 100644 --- a/internal/dnsforward/access.go +++ b/internal/dnsforward/access.go @@ -12,7 +12,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/golibs/container" - "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/stringutil" "github.com/AdguardTeam/urlfilter" "github.com/AdguardTeam/urlfilter/filterlist" @@ -230,6 +229,8 @@ func validateStrUniq(clients []string) (uc aghalg.UniqChecker[string], err error // handleAccessSet handles requests to the POST /control/access/set endpoint. func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + list := &accessListJSON{} err := json.NewDecoder(r.Body).Decode(&list) if err != nil { @@ -253,14 +254,15 @@ func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) { return } - defer log.Debug( - "access: updated lists: %d, %d, %d", - len(list.AllowedClients), - len(list.DisallowedClients), - len(list.BlockedHosts), + defer s.logger.DebugContext( + ctx, + "updated access lists", + "allowed", len(list.AllowedClients), + "disallowed", len(list.DisallowedClients), + "blocked_hosts", len(list.BlockedHosts), ) - defer s.conf.ConfigModified() + defer s.conf.ConfModifier.Apply(ctx) s.serverLock.Lock() defer s.serverLock.Unlock() diff --git a/internal/dnsforward/beforerequest.go b/internal/dnsforward/beforerequest.go index 952ea253..1d0a138e 100644 --- a/internal/dnsforward/beforerequest.go +++ b/internal/dnsforward/beforerequest.go @@ -1,13 +1,13 @@ package dnsforward import ( + "context" "encoding/binary" "fmt" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" ) @@ -41,7 +41,13 @@ func (s *Server) HandleBefore( qt := q.Qtype host := aghnet.NormalizeDomain(q.Name) if s.access.isBlockedHost(host, qt) { - log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host) + // TODO(s.chzhen): Pass context. + s.logger.DebugContext( + context.TODO(), + "request is in access blocklist", + "dns_type", dns.Type(qt), + "host", host, + ) return s.preBlockedResponse(pctx) } diff --git a/internal/dnsforward/beforerequest_internal_test.go b/internal/dnsforward/beforerequest_internal_test.go index 35a1157b..d9b872b8 100644 --- a/internal/dnsforward/beforerequest_internal_test.go +++ b/internal/dnsforward/beforerequest_internal_test.go @@ -9,6 +9,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -132,7 +133,7 @@ func TestServer_HandleBefore_tls(t *testing.T) { s.conf.DisallowedClients = tc.disallowedClients s.conf.BlockedHosts = tc.blockedHosts - err := s.Prepare(&s.conf) + err := s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &s.conf) require.NoError(t, err) startDeferStop(t, s) diff --git a/internal/dnsforward/clientid_internal_test.go b/internal/dnsforward/clientid_internal_test.go index ec110f60..f095c448 100644 --- a/internal/dnsforward/clientid_internal_test.go +++ b/internal/dnsforward/clientid_internal_test.go @@ -201,6 +201,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) { srv := &Server{ conf: ServerConfig{TLSConf: tlsConf}, baseLogger: testLogger, + logger: testLogger, } var ( diff --git a/internal/dnsforward/config.go b/internal/dnsforward/config.go index a5d83c6c..123174ba 100644 --- a/internal/dnsforward/config.go +++ b/internal/dnsforward/config.go @@ -1,6 +1,7 @@ package dnsforward import ( + "context" "crypto/tls" "crypto/x509" "fmt" @@ -11,6 +12,7 @@ import ( "strings" "time" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghslog" @@ -259,8 +261,9 @@ type ServerConfig struct { // TLSCiphers are the IDs of TLS cipher suites to use. TLSCiphers []uint16 - // Called when the configuration is changed by HTTP request - ConfigModified func() + // ConfModifier is used to update the global configuration. It must not be + // nil. + ConfModifier agh.ConfigModifier // Register an HTTP handler HTTPRegister aghhttp.RegisterFunc @@ -307,7 +310,7 @@ const ( ) // newProxyConfig creates and validates configuration for the main proxy. -func (s *Server) newProxyConfig() (conf *proxy.Config, err error) { +func (s *Server) newProxyConfig(ctx context.Context) (conf *proxy.Config, err error) { srvConf := s.conf trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies) @@ -355,12 +358,12 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) { return nil, fmt.Errorf("bogus_nxdomain: %w", err) } - err = s.prepareTLS(conf) + err = s.prepareTLS(ctx, conf) if err != nil { return nil, fmt.Errorf("validating tls: %w", err) } - err = s.preparePlain(conf) + err = s.preparePlain(ctx, conf) if err != nil { return nil, fmt.Errorf("validating plain: %w", err) } @@ -444,7 +447,7 @@ func (s *Server) initDefaultSettings() { // prepareIpsetListSettings reads and prepares the ipset configuration either // from a file or from the data in the configuration file. -func (s *Server) prepareIpsetListSettings() (ipsets []string, err error) { +func (s *Server) prepareIpsetListSettings(ctx context.Context) (ipsets []string, err error) { fn := s.conf.IpsetListFileName if fn == "" { return s.conf.IpsetList, nil @@ -459,7 +462,7 @@ func (s *Server) prepareIpsetListSettings() (ipsets []string, err error) { ipsets = stringutil.SplitTrimmed(string(data), "\n") ipsets = slices.DeleteFunc(ipsets, aghnet.IsCommentOrEmpty) - log.Debug("dns: using %d ipset rules from file %q", len(ipsets), fn) + s.logger.DebugContext(ctx, "using ipset rules from file", "num", len(ipsets), "file", fn) return ipsets, nil } @@ -629,7 +632,7 @@ func (s *Server) prepareDNSCrypt(proxyConf *proxy.Config) { } // prepareTLS sets up the TLS configuration for the DNS proxy. -func (s *Server) prepareTLS(proxyConf *proxy.Config) (err error) { +func (s *Server) prepareTLS(ctx context.Context, proxyConf *proxy.Config) (err error) { s.prepareDNSCrypt(proxyConf) if s.conf.TLSConf.Cert == nil { @@ -653,11 +656,20 @@ func (s *Server) prepareTLS(proxyConf *proxy.Config) (err error) { if s.conf.TLSConf.StrictSNICheck { if len(cert.DNSNames) != 0 { s.dnsNames = cert.DNSNames - log.Debug("dns: using certificate's SAN as DNS names: %v", cert.DNSNames) + s.logger.DebugContext( + ctx, + "using certificate's SAN as DNS names", + "dns_names", cert.DNSNames, + ) slices.Sort(s.dnsNames) } else { s.dnsNames = []string{cert.Subject.CommonName} - log.Debug("dns: using certificate's CN as DNS name: %s", cert.Subject.CommonName) + s.logger.DebugContext( + ctx, + "using certificate's CN as DNS name", + "common_name", + cert.Subject.CommonName, + ) } } @@ -706,15 +718,22 @@ func anyNameMatches(dnsNames []string, sni string) (ok bool) { // If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake. func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) { if s.conf.TLSConf.StrictSNICheck && !anyNameMatches(s.dnsNames, ch.ServerName) { - log.Info("dns: tls: unknown SNI in Client Hello: %s", ch.ServerName) + // TODO(s.chzhen): Pass context. + s.logger.WarnContext( + context.TODO(), + "unknown SNI in Client Hello", + "server_name", ch.ServerName, + ) + return nil, fmt.Errorf("invalid SNI") } + return s.conf.TLSConf.Cert, nil } // preparePlain prepares the plain-DNS configuration for the DNS proxy. // preparePlain assumes that prepareTLS has already been called. -func (s *Server) preparePlain(proxyConf *proxy.Config) (err error) { +func (s *Server) preparePlain(ctx context.Context, proxyConf *proxy.Config) (err error) { if s.conf.ServePlainDNS { proxyConf.UDPListenAddr = s.conf.UDPListenAddrs proxyConf.TCPListenAddr = s.conf.TCPListenAddrs @@ -732,14 +751,16 @@ func (s *Server) preparePlain(proxyConf *proxy.Config) (err error) { return errors.Error("disabling plain dns requires at least one encrypted protocol") } - log.Info("dnsforward: warning: plain dns is disabled") + s.logger.WarnContext(ctx, "plain dns is disabled") return nil } // UpdatedProtectionStatus updates protection state, if the protection was // disabled temporarily. Returns the updated state of protection. -func (s *Server) UpdatedProtectionStatus() (enabled bool, disabledUntil *time.Time) { +func (s *Server) UpdatedProtectionStatus( + ctx context.Context, +) (enabled bool, disabledUntil *time.Time) { s.serverLock.RLock() defer s.serverLock.RUnlock() @@ -759,7 +780,7 @@ func (s *Server) UpdatedProtectionStatus() (enabled bool, disabledUntil *time.Ti // // See https://github.com/AdguardTeam/AdGuardHome/issues/5661. if s.protectionUpdateInProgress.CompareAndSwap(false, true) { - go s.enableProtectionAfterPause() + go s.enableProtectionAfterPause(ctx) } return true, nil @@ -767,19 +788,19 @@ func (s *Server) UpdatedProtectionStatus() (enabled bool, disabledUntil *time.Ti // enableProtectionAfterPause sets the protection configuration to enabled // values. It is intended to be used as a goroutine. -func (s *Server) enableProtectionAfterPause() { - defer log.OnPanic("dns: enabling protection after pause") +func (s *Server) enableProtectionAfterPause(ctx context.Context) { + defer slogutil.RecoverAndLog(ctx, s.logger) defer s.protectionUpdateInProgress.Store(false) - defer s.conf.ConfigModified() + defer s.conf.ConfModifier.Apply(ctx) s.serverLock.Lock() defer s.serverLock.Unlock() s.dnsFilter.SetProtectionStatus(true, nil) - log.Info("dns: protection is restarted after pause") + s.logger.InfoContext(ctx, "protection is restarted after pause") } // validateCacheTTL returns an error if the configuration of the cache TTL diff --git a/internal/dnsforward/dialcontext.go b/internal/dnsforward/dialcontext.go index 0ed91fb8..669f6c58 100644 --- a/internal/dnsforward/dialcontext.go +++ b/internal/dnsforward/dialcontext.go @@ -9,7 +9,6 @@ import ( "time" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" ) @@ -17,7 +16,7 @@ import ( // addr should be a valid host:port address, where host could be a domain name // or an IP address. func (s *Server) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) { - log.Debug("dnsforward: dialing %q for network %q", addr, network) + s.logger.DebugContext(ctx, "dialing", "addr", addr, "network", network) host, portStr, err := net.SplitHostPort(addr) if err != nil { @@ -45,7 +44,7 @@ func (s *Server) DialContext(ctx context.Context, network, addr string) (conn ne return nil, fmt.Errorf("no addresses for host %q", host) } - log.Debug("dnsforward: resolved %q: %v", host, ips) + s.logger.DebugContext(ctx, "resolved", "host", host, "ips", ips) var dialErrs []error for _, ip := range ips { diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 05814288..9b2e91f6 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -145,6 +145,10 @@ type Server struct { // have a prefix and must not be nil. baseLogger *slog.Logger + // logger is used to log the operation of the DNS server. It is created + // during initialization in [NewServer]. + logger *slog.Logger + // dnsFilter is the DNS filter for filtering client's DNS requests and // responses. dnsFilter *filtering.DNSFilter @@ -254,6 +258,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { queryLog: p.QueryLog, privateNets: p.PrivateNets, baseLogger: p.Logger, + logger: p.Logger.With(slogutil.KeyPrefix, "dnsforward"), // TODO(e.burkov): Use some case-insensitive string comparison. localDomainSuffix: strings.ToLower(localDomainSuffix), etcHosts: etcHosts, @@ -286,7 +291,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) { // its workers finished. But it would require the upstream.Upstream to have the // Close method to prevent from hanging while waiting for unresponsive server to // respond. -func (s *Server) Close() { +func (s *Server) Close(ctx context.Context) { s.serverLock.Lock() defer s.serverLock.Unlock() @@ -296,7 +301,7 @@ func (s *Server) Close() { s.dnsProxy = nil if err := s.ipset.close(); err != nil { - log.Error("dnsforward: closing ipset: %s", err) + s.logger.ErrorContext(ctx, "closing ipset", slogutil.KeyError, err) } } @@ -461,18 +466,17 @@ func hostFromPTR(resp *dns.Msg) (host string, ttl time.Duration, err error) { } // Start starts the DNS server. It must only be called after [Server.Prepare]. -func (s *Server) Start() error { +func (s *Server) Start(ctx context.Context) error { s.serverLock.Lock() defer s.serverLock.Unlock() - return s.startLocked() + return s.startLocked(ctx) } // startLocked starts the DNS server without locking. s.serverLock is expected // to be locked. -func (s *Server) startLocked() error { - // TODO(e.burkov): Use context properly. - err := s.dnsProxy.Start(context.Background()) +func (s *Server) startLocked(ctx context.Context) error { + err := s.dnsProxy.Start(ctx) if err == nil { s.isRunning = true } @@ -482,7 +486,7 @@ func (s *Server) startLocked() error { // Prepare initializes parameters of s using data from conf. conf must not be // nil. -func (s *Server) Prepare(conf *ServerConfig) (err error) { +func (s *Server) Prepare(ctx context.Context, conf *ServerConfig) (err error) { s.conf = *conf // dnsFilter can be nil during application update. @@ -496,13 +500,13 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) { s.initDefaultSettings() - err = s.prepareInternalDNS() + err = s.prepareInternalDNS(ctx) if err != nil { // Don't wrap the error, because it's informative enough as is. return err } - proxyConfig, err := s.newProxyConfig() + proxyConfig, err := s.newProxyConfig(ctx) if err != nil { return fmt.Errorf("preparing proxy: %w", err) } @@ -633,8 +637,8 @@ func (s *Server) prepareLocalResolvers() (uc *proxy.UpstreamConfig, err error) { // prepareInternalDNS initializes the internal state of s before initializing // the primary DNS proxy instance. It assumes s.serverLock is locked or the // Server not running. -func (s *Server) prepareInternalDNS() (err error) { - ipsetList, err := s.prepareIpsetListSettings() +func (s *Server) prepareInternalDNS(ctx context.Context) (err error) { + ipsetList, err := s.prepareIpsetListSettings(ctx) if err != nil { return fmt.Errorf("preparing ipset settings: %w", err) } @@ -779,27 +783,26 @@ func (s *Server) prepareInternalProxy() (err error) { } // Stop stops the DNS server. -func (s *Server) Stop() error { +func (s *Server) Stop(ctx context.Context) error { s.serverLock.Lock() defer s.serverLock.Unlock() - s.stopLocked() + s.stopLocked(ctx) return nil } // stopLocked stops the DNS server without locking. s.serverLock is expected to // be locked. -func (s *Server) stopLocked() { +func (s *Server) stopLocked(ctx context.Context) { // TODO(e.burkov, a.garipov): Return critical errors, not just log them. // This will require filtering all the non-critical errors in // [upstream.Upstream] implementations. if s.dnsProxy != nil { - // TODO(e.burkov): Use context properly. - err := s.dnsProxy.Shutdown(context.Background()) + err := s.dnsProxy.Shutdown(ctx) if err != nil { - log.Error("dnsforward: closing primary resolvers: %s", err) + s.logger.ErrorContext(ctx, "closing primary resolvers", slogutil.KeyError, err) } } @@ -848,14 +851,14 @@ func (s *Server) proxy() (p *proxy.Proxy) { // Reconfigure applies the new configuration to the DNS server. // // TODO(a.garipov): This whole piece of API is weird and needs to be remade. -func (s *Server) Reconfigure(conf *ServerConfig) error { +func (s *Server) Reconfigure(ctx context.Context, conf *ServerConfig) error { s.serverLock.Lock() defer s.serverLock.Unlock() - log.Info("dnsforward: starting reconfiguring server") - defer log.Info("dnsforward: finished reconfiguring server") + s.logger.InfoContext(ctx, "starting reconfiguring server") + defer s.logger.InfoContext(ctx, "finished reconfiguring server") - s.stopLocked() + s.stopLocked(ctx) // It seems that net.Listener.Close() doesn't close file descriptors right away. // We wait for some time and hope that this fd will be closed. @@ -864,7 +867,7 @@ func (s *Server) Reconfigure(conf *ServerConfig) error { if s.addrProc != nil { err := s.addrProc.Close() if err != nil { - log.Error("dnsforward: closing address processor: %s", err) + s.logger.ErrorContext(ctx, "closing address processor", slogutil.KeyError, err) } } @@ -874,12 +877,12 @@ func (s *Server) Reconfigure(conf *ServerConfig) error { // TODO(e.burkov): It seems an error here brings the server down, which is // not reliable enough. - err := s.Prepare(conf) + err := s.Prepare(ctx, conf) if err != nil { return fmt.Errorf("could not reconfigure the server: %w", err) } - err = s.startLocked() + err = s.startLocked(ctx) if err != nil { return fmt.Errorf("could not reconfigure the server: %w", err) } @@ -908,16 +911,29 @@ func (s *Server) IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, allowlistMode := s.access.allowlistMode() blockedByClientID := s.access.isBlockedClientID(clientID) + // TODO(s.chzhen): Pass context. + ctx := context.TODO() + // Allow if at least one of the checks allows in allowlist mode, but block // if at least one of the checks blocks in blocklist mode. if allowlistMode && blockedByIP && blockedByClientID { - log.Debug("dnsforward: client %v (id %q) is not in access allowlist", ip, clientID) + s.logger.DebugContext( + ctx, + "client is not in access allowlist", + "ip", ip, + "client_id", clientID, + ) // Return now without substituting the empty rule for the // clientID because the rule can't be empty here. return true, rule } else if !allowlistMode && (blockedByIP || blockedByClientID) { - log.Debug("dnsforward: client %v (id %q) is in access blocklist", ip, clientID) + s.logger.DebugContext( + ctx, + "client is in access blocklist", + "ip", ip, + "client_id", clientID, + ) blocked = true } diff --git a/internal/dnsforward/dnsforward_internal_test.go b/internal/dnsforward/dnsforward_internal_test.go index adb08537..e05fa7a0 100644 --- a/internal/dnsforward/dnsforward_internal_test.go +++ b/internal/dnsforward/dnsforward_internal_test.go @@ -22,6 +22,7 @@ import ( "testing/fstest" "time" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/client" @@ -106,9 +107,11 @@ func (c *clientsContainer) ClearUpstreamCache() { func startDeferStop(t *testing.T, s *Server) { t.Helper() - err := s.Start() + err := s.Start(testutil.ContextWithTimeout(t, testTimeout)) require.NoError(t, err) - testutil.CleanupAndRequireSuccess(t, s.Stop) + testutil.CleanupAndRequireSuccess(t, func() (err error) { + return s.Stop(testutil.ContextWithTimeout(t, testTimeout)) + }) } // applyEmptyClientFiltering is a helper function for tests with @@ -169,7 +172,7 @@ func createTestServer( }) require.NoError(t, err) - err = s.Prepare(&forwardConf) + err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &forwardConf) require.NoError(t, err) return s @@ -244,7 +247,7 @@ func createTestTLS(t *testing.T, tlsConf *TLSConfig) (s *Server, certPem []byte) ServePlainDNS: true, }) - err = s.Prepare(&s.conf) + err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &s.conf) require.NoErrorf(t, err, "failed to prepare server: %s", err) return s, certPem @@ -420,7 +423,7 @@ func TestServer_timeout(t *testing.T) { }) require.NoError(t, err) - err = s.Prepare(srvConf) + err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), srvConf) require.NoError(t, err) assert.Equal(t, testTimeout, s.conf.UpstreamTimeout) @@ -439,7 +442,7 @@ func TestServer_timeout(t *testing.T) { Enabled: false, } s.conf.Config.ClientsContainer = EmptyClientsContainer{} - err = s.Prepare(&s.conf) + err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &s.conf) require.NoError(t, err) assert.Equal(t, DefaultTimeout, s.conf.UpstreamTimeout) @@ -466,7 +469,7 @@ func TestServer_Prepare_fallbacks(t *testing.T) { }) require.NoError(t, err) - err = s.Prepare(srvConf) + err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), srvConf) require.NoError(t, err) require.NotNil(t, s.dnsProxy.Fallbacks) @@ -566,8 +569,8 @@ func TestServerRace(t *testing.T) { UpstreamMode: UpstreamModeLoadBalance, UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"}, }, - ConfigModified: func() {}, - ServePlainDNS: true, + ConfModifier: agh.EmptyConfigModifier{}, + ServePlainDNS: true, } s := createTestServer(t, filterConf, forwardConf) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()} @@ -1104,7 +1107,7 @@ func TestBlockedCustomIP(t *testing.T) { } // Invalid BlockingIPv4. - err = s.Prepare(conf) + err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), conf) assert.Error(t, err) s.dnsFilter.SetBlockingMode( @@ -1112,7 +1115,7 @@ func TestBlockedCustomIP(t *testing.T) { netip.AddrFrom4([4]byte{0, 0, 0, 1}), netip.MustParseAddr("::1")) - err = s.Prepare(conf) + err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), conf) require.NoError(t, err) f.SetEnabled(true) @@ -1264,7 +1267,7 @@ func TestRewrite(t *testing.T) { }) require.NoError(t, err) - assert.NoError(t, s.Prepare(&ServerConfig{ + assert.NoError(t, s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &ServerConfig{ UDPListenAddrs: []*net.UDPAddr{{}}, TCPListenAddrs: []*net.TCPAddr{{}}, TLSConf: &TLSConfig{}, @@ -1334,7 +1337,7 @@ func TestRewrite(t *testing.T) { for _, protect := range []bool{true, false} { val := protect - conf := s.getDNSConfig() + conf := s.getDNSConfig(testutil.ContextWithTimeout(t, testTimeout)) conf.ProtectionEnabled = &val s.setConfig(conf) @@ -1408,12 +1411,12 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) { s.conf.Config.ClientsContainer = EmptyClientsContainer{} s.conf.Config.UpstreamMode = UpstreamModeLoadBalance - err = s.Prepare(&s.conf) + err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &s.conf) require.NoError(t, err) - err = s.Start() + err = s.Start(testutil.ContextWithTimeout(t, testTimeout)) require.NoError(t, err) - t.Cleanup(s.Close) + t.Cleanup(func() { s.Close(testutil.ContextWithTimeout(t, testTimeout)) }) addr := s.dnsProxy.Addr(proxy.ProtoUDP) req := createTestMessageWithType("34.12.168.192.in-addr.arpa.", dns.TypePTR) @@ -1498,12 +1501,12 @@ func TestPTRResponseFromHosts(t *testing.T) { s.conf.Config.ClientsContainer = EmptyClientsContainer{} s.conf.Config.UpstreamMode = UpstreamModeLoadBalance - err = s.Prepare(&s.conf) + err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &s.conf) require.NoError(t, err) - err = s.Start() + err = s.Start(testutil.ContextWithTimeout(t, testTimeout)) require.NoError(t, err) - t.Cleanup(s.Close) + t.Cleanup(func() { s.Close(testutil.ContextWithTimeout(t, testTimeout)) }) subTestFunc := func(t *testing.T) { addr := s.dnsProxy.Addr(proxy.ProtoUDP) @@ -1524,7 +1527,7 @@ func TestPTRResponseFromHosts(t *testing.T) { for _, protect := range []bool{true, false} { val := protect - conf := s.getDNSConfig() + conf := s.getDNSConfig(testutil.ContextWithTimeout(t, testTimeout)) conf.ProtectionEnabled = &val s.setConfig(conf) diff --git a/internal/dnsforward/dnsrewrite.go b/internal/dnsforward/dnsrewrite.go index 7d9fde72..00e4cd56 100644 --- a/internal/dnsforward/dnsrewrite.go +++ b/internal/dnsforward/dnsrewrite.go @@ -1,13 +1,13 @@ package dnsforward import ( + "context" "fmt" "net/netip" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" ) @@ -15,6 +15,7 @@ import ( // filterDNSRewriteResponse handles a single DNS rewrite response entry. It // returns the properly constructed answer resource record. func (s *Server) filterDNSRewriteResponse( + ctx context.Context, req *dns.Msg, rr rules.RRType, v rules.RRValue, @@ -27,11 +28,11 @@ func (s *Server) filterDNSRewriteResponse( case dns.TypeMX: return s.ansFromDNSRewriteMX(v, rr, req) case dns.TypeHTTPS, dns.TypeSVCB: - return s.ansFromDNSRewriteSVCB(v, rr, req) + return s.ansFromDNSRewriteSVCB(ctx, v, rr, req) case dns.TypeSRV: return s.ansFromDNSRewriteSRV(v, rr, req) default: - log.Debug("don't know how to handle dns rr type %d, skipping", rr) + s.logger.DebugContext(ctx, "unsupported dns rr type, skipping", "res_record", rr) return nil, nil } @@ -97,6 +98,7 @@ func (s *Server) ansFromDNSRewriteMX( // ansFromDNSRewriteSVCB creates a new answer resource record from the // SVCB/HTTPS dnsrewrite rule data. func (s *Server) ansFromDNSRewriteSVCB( + ctx context.Context, v rules.RRValue, rr rules.RRType, req *dns.Msg, @@ -111,10 +113,10 @@ func (s *Server) ansFromDNSRewriteSVCB( } if rr == dns.TypeHTTPS { - return s.genAnswerHTTPS(req, svcb), nil + return s.genAnswerHTTPS(ctx, req, svcb), nil } - return s.genAnswerSVCB(req, svcb), nil + return s.genAnswerSVCB(ctx, req, svcb), nil } // ansFromDNSRewriteSRV creates a new answer resource record from the SRV @@ -139,6 +141,7 @@ func (s *Server) ansFromDNSRewriteSRV( // filterDNSRewrite handles dnsrewrite filters. It constructs a DNS response // and sets it into pctx.Res. All parameters must not be nil. func (s *Server) filterDNSRewrite( + ctx context.Context, req *dns.Msg, res *filtering.Result, pctx *proxy.DNSContext, @@ -164,7 +167,7 @@ func (s *Server) filterDNSRewrite( values := dnsrr.Response[qtype] for i, v := range values { var ans dns.RR - ans, err = s.filterDNSRewriteResponse(req, qtype, v) + ans, err = s.filterDNSRewriteResponse(ctx, req, qtype, v) if err != nil { return fmt.Errorf("dns rewrite response for %s[%d]: %w", dns.Type(qtype), i, err) } diff --git a/internal/dnsforward/dnsrewrite_internal_test.go b/internal/dnsforward/dnsrewrite_internal_test.go index f30c661b..ed523fb5 100644 --- a/internal/dnsforward/dnsrewrite_internal_test.go +++ b/internal/dnsforward/dnsrewrite_internal_test.go @@ -7,6 +7,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -71,7 +72,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) { res := makeRes(dns.RcodeNameError, 0, nil) d := &proxy.DNSContext{} - err := srv.filterDNSRewrite(req, res, d) + err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d) require.NoError(t, err) assert.Equal(t, dns.RcodeNameError, d.Res.Rcode) @@ -82,7 +83,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) { res := makeRes(dns.RcodeSuccess, 0, nil) d := &proxy.DNSContext{} - err := srv.filterDNSRewrite(req, res, d) + err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d) require.NoError(t, err) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) @@ -94,7 +95,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) { res := makeRes(dns.RcodeSuccess, dns.TypeA, ip4) d := &proxy.DNSContext{} - err := srv.filterDNSRewrite(req, res, d) + err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d) require.NoError(t, err) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) @@ -108,7 +109,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) { res := makeRes(dns.RcodeSuccess, dns.TypeAAAA, ip6) d := &proxy.DNSContext{} - err := srv.filterDNSRewrite(req, res, d) + err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d) require.NoError(t, err) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) @@ -122,7 +123,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) { res := makeRes(dns.RcodeSuccess, dns.TypePTR, domain) d := &proxy.DNSContext{} - err := srv.filterDNSRewrite(req, res, d) + err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d) require.NoError(t, err) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) @@ -136,7 +137,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) { res := makeRes(dns.RcodeSuccess, dns.TypeTXT, domain) d := &proxy.DNSContext{} - err := srv.filterDNSRewrite(req, res, d) + err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d) require.NoError(t, err) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) @@ -150,7 +151,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) { res := makeRes(dns.RcodeSuccess, dns.TypeMX, mxVal) d := &proxy.DNSContext{} - err := srv.filterDNSRewrite(req, res, d) + err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d) require.NoError(t, err) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) @@ -168,7 +169,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) { res := makeRes(dns.RcodeSuccess, dns.TypeSVCB, svcbVal) d := &proxy.DNSContext{} - err := srv.filterDNSRewrite(req, res, d) + err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d) require.NoError(t, err) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) @@ -198,7 +199,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) { res := makeRes(dns.RcodeSuccess, dns.TypeHTTPS, svcbVal) d := &proxy.DNSContext{} - err := srv.filterDNSRewrite(req, res, d) + err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d) require.NoError(t, err) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) @@ -228,7 +229,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) { res := makeRes(dns.RcodeSuccess, dns.TypeSRV, srvVal) d := &proxy.DNSContext{} - err := srv.filterDNSRewrite(req, res, d) + err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d) require.NoError(t, err) assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode) diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index 6cfd7bea..ebb08f84 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -1,13 +1,13 @@ package dnsforward import ( + "context" "fmt" "net" "slices" "strings" "github.com/AdguardTeam/AdGuardHome/internal/filtering" - "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" ) @@ -24,7 +24,10 @@ func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filter // filterDNSRequest applies the dnsFilter and sets dctx.proxyCtx.Res if the // request was filtered. -func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err error) { +func (s *Server) filterDNSRequest( + ctx context.Context, + dctx *dnsContext, +) (res *filtering.Result, err error) { pctx := dctx.proxyCtx req := pctx.Req q := req.Question[0] @@ -44,12 +47,12 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err dctx.origQuestion = q req.Question[0].Name = dns.Fqdn(res.CanonName) case res.IsFiltered: - log.Debug("dnsforward: host %q is filtered, reason: %q", host, res.Reason) - pctx.Res = s.genDNSFilterMessage(pctx, res) + s.logger.DebugContext(ctx, "host is filtered", "host", host, "reason", res.Reason) + pctx.Res = s.genDNSFilterMessage(ctx, pctx, res) case res.Reason.In(filtering.Rewritten, filtering.FilteredSafeSearch): - pctx.Res = s.getCNAMEWithIPs(req, res.IPList, res.CanonName) + pctx.Res = s.getCNAMEWithIPs(ctx, req, res.IPList, res.CanonName) case res.Reason.In(filtering.RewrittenRule, filtering.RewrittenAutoHosts): - if err = s.filterDNSRewrite(req, res, pctx); err != nil { + if err = s.filterDNSRewrite(ctx, req, res, pctx); err != nil { return nil, err } } @@ -90,7 +93,7 @@ func (s *Server) checkHostRules( // dctx.proxyCtx.Res. It sets dctx.result and dctx.origResp if at least one of // canonical names, IP addresses, or HTTPS RR hints in it matches the filtering // rules, as well as sets dctx.proxyCtx.Res to the filtered response. -func (s *Server) filterDNSResponse(dctx *dnsContext) (err error) { +func (s *Server) filterDNSResponse(ctx context.Context, dctx *dnsContext) (err error) { setts := dctx.setts if !setts.FilteringEnabled { return nil @@ -123,16 +126,27 @@ func (s *Server) filterDNSResponse(dctx *dnsContext) (err error) { continue } - log.Debug("dnsforward: checked %s %s for %s", dns.Type(rrtype), host, a.Header().Name) + s.logger.DebugContext( + ctx, + "checked", + "dns_type", dns.Type(rrtype), + "host", host, + "name", a.Header().Name, + ) if err != nil { return fmt.Errorf("filtering answer at index %d: %w", i, err) } else if res != nil && res.IsFiltered { dctx.result = res dctx.origResp = pctx.Res - pctx.Res = s.genDNSFilterMessage(pctx, res) + pctx.Res = s.genDNSFilterMessage(ctx, pctx, res) - log.Debug("dnsforward: matched %q by response: %q", pctx.Req.Question[0].Name, host) + s.logger.DebugContext( + ctx, + "matched by response", + "name", pctx.Req.Question[0].Name, + "host", host, + ) break } diff --git a/internal/dnsforward/filter_internal_test.go b/internal/dnsforward/filter_internal_test.go index 8a07b4f1..cd6da731 100644 --- a/internal/dnsforward/filter_internal_test.go +++ b/internal/dnsforward/filter_internal_test.go @@ -10,6 +10,7 @@ import ( "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/netutil" + "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -66,7 +67,7 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) { }) require.NoError(t, err) - err = s.Prepare(&forwardConf) + err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &forwardConf) require.NoError(t, err) s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{ @@ -347,7 +348,7 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) { }, } - fltErr := s.filterDNSResponse(dctx) + fltErr := s.filterDNSResponse(testutil.ContextWithTimeout(t, testTimeout), dctx) require.NoError(t, fltErr) res := dctx.result diff --git a/internal/dnsforward/http.go b/internal/dnsforward/http.go index 59f3fde8..1612081f 100644 --- a/internal/dnsforward/http.go +++ b/internal/dnsforward/http.go @@ -2,6 +2,7 @@ package dnsforward import ( "cmp" + "context" "encoding/json" "fmt" "io" @@ -17,7 +18,6 @@ import ( "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/stringutil" @@ -138,8 +138,8 @@ const ( jsonUpstreamModeFastestAddr jsonUpstreamMode = "fastest_addr" ) -func (s *Server) getDNSConfig() (c *jsonDNSConfig) { - protectionEnabled, protectionDisabledUntil := s.UpdatedProtectionStatus() +func (s *Server) getDNSConfig(ctx context.Context) (c *jsonDNSConfig) { + protectionEnabled, protectionDisabledUntil := s.UpdatedProtectionStatus(ctx) s.serverLock.RLock() defer s.serverLock.RUnlock() @@ -184,7 +184,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) { defPTRUps, err := s.defaultLocalPTRUpstreams() if err != nil { - log.Error("dnsforward: %s", err) + s.logger.ErrorContext(ctx, "getting local ptr upstreams", slogutil.KeyError, err) } return &jsonDNSConfig{ @@ -240,7 +240,7 @@ func (s *Server) defaultLocalPTRUpstreams() (ups []string, err error) { // handleGetConfig handles requests to the GET /control/dns_info endpoint. func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) { - resp := s.getDNSConfig() + resp := s.getDNSConfig(r.Context()) aghhttp.WriteJSONResponseOK(w, r, resp) } @@ -486,6 +486,8 @@ func checkInclusion(ptr *int, minN, maxN int) (err error) { // handleSetConfig handles requests to the POST /control/dns_config endpoint. func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + req := &jsonDNSConfig{} err := json.NewDecoder(r.Body).Decode(req) if err != nil { @@ -511,10 +513,10 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) { } restart := s.setConfig(req) - s.conf.ConfigModified() + s.conf.ConfModifier.Apply(ctx) if restart { - err = s.Reconfigure(nil) + err = s.Reconfigure(ctx, nil) if err != nil { aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err) } @@ -726,7 +728,7 @@ func (s *Server) handleSetProtection(w http.ResponseWriter, r *http.Request) { s.dnsFilter.SetProtectionStatus(protectionReq.Enabled, disabledUntil) }() - s.conf.ConfigModified() + s.conf.ConfModifier.Apply(r.Context()) aghhttp.OK(w) } diff --git a/internal/dnsforward/http_internal_test.go b/internal/dnsforward/http_internal_test.go index 2b076281..47272adb 100644 --- a/internal/dnsforward/http_internal_test.go +++ b/internal/dnsforward/http_internal_test.go @@ -17,6 +17,7 @@ import ( "testing/fstest" "time" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" @@ -87,14 +88,16 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) { EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, ClientsContainer: EmptyClientsContainer{}, }, - ConfigModified: func() {}, - ServePlainDNS: true, + ConfModifier: agh.EmptyConfigModifier{}, + ServePlainDNS: true, } s := createTestServer(t, filterConf, forwardConf) s.sysResolvers = &emptySysResolvers{} - require.NoError(t, s.Start()) - testutil.CleanupAndRequireSuccess(t, s.Stop) + require.NoError(t, s.Start(testutil.ContextWithTimeout(t, testTimeout))) + testutil.CleanupAndRequireSuccess(t, func() (err error) { + return s.Stop(testutil.ContextWithTimeout(t, testTimeout)) + }) defaultConf := s.conf @@ -137,7 +140,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) { t.Cleanup(w.Body.Reset) s.conf = tc.conf() - s.handleGetConfig(w, nil) + s.handleGetConfig(w, httptest.NewRequest(http.MethodGet, "/", nil)) cType := w.Header().Get(httphdr.ContentType) assert.Equal(t, aghhttp.HdrValApplicationJSON, cType) @@ -170,17 +173,19 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) { EDNSClientSubnet: &EDNSClientSubnet{Enabled: false}, ClientsContainer: EmptyClientsContainer{}, }, - ConfigModified: func() {}, - ServePlainDNS: true, + ConfModifier: agh.EmptyConfigModifier{}, + ServePlainDNS: true, } s := createTestServer(t, filterConf, forwardConf) s.sysResolvers = &emptySysResolvers{} defaultConf := s.conf - err := s.Start() + err := s.Start(testutil.ContextWithTimeout(t, testTimeout)) assert.NoError(t, err) - testutil.CleanupAndRequireSuccess(t, s.Stop) + testutil.CleanupAndRequireSuccess(t, func() (err error) { + return s.Stop(testutil.ContextWithTimeout(t, testTimeout)) + }) w := httptest.NewRecorder() @@ -297,7 +302,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) { assert.Equal(t, tc.wantSet, strings.TrimSuffix(w.Body.String(), "\n")) w.Body.Reset() - s.handleGetConfig(w, nil) + s.handleGetConfig(w, httptest.NewRequest(http.MethodGet, "/", nil)) assert.JSONEq(t, string(caseData.Want), w.Body.String()) w.Body.Reset() }) diff --git a/internal/dnsforward/ipset.go b/internal/dnsforward/ipset.go index 7347890a..4204fbb3 100644 --- a/internal/dnsforward/ipset.go +++ b/internal/dnsforward/ipset.go @@ -121,9 +121,7 @@ func ipsFromAnswer(ans []dns.RR) (ip4s, ip6s []net.IP) { } // process adds the resolved IP addresses to the domain's ipsets, if any. -func (h *ipsetHandler) process(dctx *dnsContext) (rc resultCode) { - // TODO(s.chzhen): Use passed context. - ctx := context.TODO() +func (h *ipsetHandler) process(ctx context.Context, dctx *dnsContext) (rc resultCode) { h.logger.DebugContext(ctx, "started processing") defer h.logger.DebugContext(ctx, "finished processing") diff --git a/internal/dnsforward/ipset_internal_test.go b/internal/dnsforward/ipset_internal_test.go index 90c200d0..d7735b91 100644 --- a/internal/dnsforward/ipset_internal_test.go +++ b/internal/dnsforward/ipset_internal_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/AdguardTeam/dnsproxy/proxy" + "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" ) @@ -62,7 +63,7 @@ func TestIpsetCtx_process(t *testing.T) { ictx := &ipsetHandler{ logger: testLogger, } - rc := ictx.process(dctx) + rc := ictx.process(testutil.ContextWithTimeout(t, testTimeout), dctx) assert.Equal(t, resultCodeSuccess, rc) err := ictx.close() @@ -85,7 +86,7 @@ func TestIpsetCtx_process(t *testing.T) { logger: testLogger, } - rc := ictx.process(dctx) + rc := ictx.process(testutil.ContextWithTimeout(t, testTimeout), dctx) assert.Equal(t, resultCodeSuccess, rc) assert.Equal(t, []net.IP{ip4}, m.ip4s) assert.Empty(t, m.ip6s) @@ -110,7 +111,7 @@ func TestIpsetCtx_process(t *testing.T) { logger: testLogger, } - rc := ictx.process(dctx) + rc := ictx.process(testutil.ContextWithTimeout(t, testTimeout), dctx) assert.Equal(t, resultCodeSuccess, rc) assert.Empty(t, m.ip4s) assert.Equal(t, []net.IP{ip6}, m.ip6s) diff --git a/internal/dnsforward/msg.go b/internal/dnsforward/msg.go index e9f1f2d7..6fd8580b 100644 --- a/internal/dnsforward/msg.go +++ b/internal/dnsforward/msg.go @@ -1,12 +1,13 @@ package dnsforward import ( + "context" "net/netip" "slices" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" ) @@ -47,6 +48,7 @@ func ipsFromRules(resRules []*filtering.ResultRule) (ips []netip.Addr) { // genDNSFilterMessage generates a filtered response to req for the filtering // result res. func (s *Server) genDNSFilterMessage( + ctx context.Context, dctx *proxy.DNSContext, res *filtering.Result, ) (resp *dns.Msg) { @@ -63,22 +65,27 @@ func (s *Server) genDNSFilterMessage( switch res.Reason { case filtering.FilteredSafeBrowsing: - return s.genBlockedHost(req, s.dnsFilter.SafeBrowsingBlockHost(), dctx) + return s.genBlockedHost(ctx, req, s.dnsFilter.SafeBrowsingBlockHost(), dctx) case filtering.FilteredParental: - return s.genBlockedHost(req, s.dnsFilter.ParentalBlockHost(), dctx) + return s.genBlockedHost(ctx, req, s.dnsFilter.ParentalBlockHost(), dctx) case filtering.FilteredSafeSearch: // If Safe Search generated the necessary IP addresses, use them. // Otherwise, if there were no errors, there are no addresses for the // requested IP version, so produce a NODATA response. - return s.getCNAMEWithIPs(req, ipsFromRules(res.Rules), res.CanonName) + return s.getCNAMEWithIPs(ctx, req, ipsFromRules(res.Rules), res.CanonName) default: - return s.genForBlockingMode(req, ipsFromRules(res.Rules)) + return s.genForBlockingMode(ctx, req, ipsFromRules(res.Rules)) } } // getCNAMEWithIPs generates a filtered response to req for with CNAME record // and provided ips. -func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) (resp *dns.Msg) { +func (s *Server) getCNAMEWithIPs( + ctx context.Context, + req *dns.Msg, + ips []netip.Addr, + cname string, +) (resp *dns.Msg) { resp = s.replyCompressed(req) originalName := req.Question[0].Name @@ -94,7 +101,7 @@ func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) ( switch req.Question[0].Qtype { case dns.TypeA: - ans = append(ans, s.genAnswersWithIPv4s(req, ips)...) + ans = append(ans, s.genAnswersWithIPv4s(ctx, req, ips)...) case dns.TypeAAAA: for _, ip := range ips { if ip.Is6() { @@ -112,24 +119,28 @@ func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) ( // genForBlockingMode generates a filtered response to req based on the server's // blocking mode. -func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.Msg) { +func (s *Server) genForBlockingMode( + ctx context.Context, + req *dns.Msg, + ips []netip.Addr, +) (resp *dns.Msg) { switch mode, bIPv4, bIPv6 := s.dnsFilter.BlockingMode(); mode { case filtering.BlockingModeCustomIP: - return s.makeResponseCustomIP(req, bIPv4, bIPv6) + return s.makeResponseCustomIP(ctx, req, bIPv4, bIPv6) case filtering.BlockingModeDefault: if len(ips) > 0 { - return s.genResponseWithIPs(req, ips) + return s.genResponseWithIPs(ctx, req, ips) } - return s.makeResponseNullIP(req) + return s.makeResponseNullIP(ctx, req) case filtering.BlockingModeNullIP: - return s.makeResponseNullIP(req) + return s.makeResponseNullIP(ctx, req) case filtering.BlockingModeNXDOMAIN: return s.NewMsgNXDOMAIN(req) case filtering.BlockingModeREFUSED: return s.makeResponseREFUSED(req) default: - log.Error("dnsforward: invalid blocking mode %q", mode) + s.logger.ErrorContext(ctx, "invalid blocking mode", "mode", mode) return s.replyCompressed(req) } @@ -138,6 +149,7 @@ func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.M // makeResponseCustomIP generates a DNS response message for Custom IP blocking // mode with the provided IP addresses and an appropriate resource record type. func (s *Server) makeResponseCustomIP( + ctx context.Context, req *dns.Msg, bIPv4 netip.Addr, bIPv6 netip.Addr, @@ -150,7 +162,11 @@ func (s *Server) makeResponseCustomIP( default: // Generally shouldn't happen, since the types are checked in // genDNSFilterMessage. - log.Error("dnsforward: invalid msg type %s for custom IP blocking mode", dns.Type(qt)) + s.logger.ErrorContext( + ctx, + "invalid message type for custom IP blocking mode", + "dns_type", dns.Type(qt), + ) return s.replyCompressed(req) } @@ -234,11 +250,15 @@ func (s *Server) genAnswerTXT(req *dns.Msg, strs []string) (ans *dns.TXT) { // addresses and an appropriate resource record type. If any of the IPs cannot // be converted to the correct protocol, genResponseWithIPs returns an empty // response. -func (s *Server) genResponseWithIPs(req *dns.Msg, ips []netip.Addr) (resp *dns.Msg) { +func (s *Server) genResponseWithIPs( + ctx context.Context, + req *dns.Msg, + ips []netip.Addr, +) (resp *dns.Msg) { var ans []dns.RR switch req.Question[0].Qtype { case dns.TypeA: - ans = s.genAnswersWithIPv4s(req, ips) + ans = s.genAnswersWithIPv4s(ctx, req, ips) case dns.TypeAAAA: for _, ip := range ips { if ip.Is6() { @@ -258,10 +278,14 @@ func (s *Server) genResponseWithIPs(req *dns.Msg, ips []netip.Addr) (resp *dns.M // genAnswersWithIPv4s generates DNS A answers provided IPv4 addresses. If any // of the IPs isn't an IPv4 address, genAnswersWithIPv4s logs a warning and // returns nil, -func (s *Server) genAnswersWithIPv4s(req *dns.Msg, ips []netip.Addr) (ans []dns.RR) { +func (s *Server) genAnswersWithIPv4s( + ctx context.Context, + req *dns.Msg, + ips []netip.Addr, +) (ans []dns.RR) { for _, ip := range ips { if !ip.Is4() { - log.Info("dnsforward: warning: ip %s is not ipv4 address", ip) + s.logger.WarnContext(ctx, "ip is not an ipv4 address", "ip", ip) return nil } @@ -274,16 +298,16 @@ func (s *Server) genAnswersWithIPv4s(req *dns.Msg, ips []netip.Addr) (ans []dns. // makeResponseNullIP creates a response with 0.0.0.0 for A requests, :: for // AAAA requests, and an empty response for other types. -func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) { +func (s *Server) makeResponseNullIP(ctx context.Context, req *dns.Msg) (resp *dns.Msg) { // Respond with the corresponding zero IP type as opposed to simply // using one or the other in both cases, because the IPv4 zero IP is // converted to a IPV6-mapped IPv4 address, while the IPv6 zero IP is // converted into an empty slice instead of the zero IPv4. switch req.Question[0].Qtype { case dns.TypeA: - resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv4Unspecified()}) + resp = s.genResponseWithIPs(ctx, req, []netip.Addr{netip.IPv4Unspecified()}) case dns.TypeAAAA: - resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv6Unspecified()}) + resp = s.genResponseWithIPs(ctx, req, []netip.Addr{netip.IPv6Unspecified()}) default: resp = s.replyCompressed(req) } @@ -291,16 +315,21 @@ func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) { return resp } -func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSContext) *dns.Msg { +func (s *Server) genBlockedHost( + ctx context.Context, + request *dns.Msg, + newAddr string, + d *proxy.DNSContext, +) (msg *dns.Msg) { if newAddr == "" { - log.Info("dnsforward: block host is not specified") + s.logger.InfoContext(ctx, "block host not specified") return s.NewMsgSERVFAIL(request) } ip, err := netip.ParseAddr(newAddr) if err == nil { - return s.genResponseWithIPs(request, []netip.Addr{ip}) + return s.genResponseWithIPs(ctx, request, []netip.Addr{ip}) } // look up the hostname, TODO: cache @@ -316,14 +345,19 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo prx := s.proxy() if prx == nil { - log.Debug("dnsforward: %s", srvClosedErr) + s.logger.DebugContext(ctx, "getting current proxy", slogutil.KeyError, srvClosedErr) return s.NewMsgSERVFAIL(request) } err = prx.Resolve(newContext) if err != nil { - log.Info("dnsforward: looking up replacement host %q: %s", newAddr, err) + s.logger.ErrorContext( + ctx, + "looking up replacement host", + "host", newAddr, + slogutil.KeyError, err, + ) return s.NewMsgSERVFAIL(request) } diff --git a/internal/dnsforward/process.go b/internal/dnsforward/process.go index 623edff0..d2d85295 100644 --- a/internal/dnsforward/process.go +++ b/internal/dnsforward/process.go @@ -10,7 +10,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/dnsproxy/proxy" - "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/netutil" "github.com/miekg/dns" ) @@ -83,13 +82,16 @@ const ddrHostFQDN = "_dns.resolver.arpa." // handleDNSRequest filters the incoming DNS requests and writes them to the query log func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error { + // TODO(s.chzhen): Pass context. + ctx := context.TODO() + dctx := &dnsContext{ proxyCtx: pctx, result: &filtering.Result{}, startTime: time.Now(), } - type modProcessFunc func(ctx *dnsContext) (rc resultCode) + type modProcessFunc func(ctx context.Context, dctx *dnsContext) (rc resultCode) // Since (*dnsforward.Server).handleDNSRequest(...) is used as // proxy.(Config).RequestHandler, there is no need for additional index @@ -108,7 +110,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error s.processQueryLogsAndStats, } for _, process := range mods { - r := process(dctx) + r := process(ctx, dctx) switch r { case resultCodeSuccess: // continue: call the next filter @@ -149,12 +151,12 @@ const healthcheckFQDN = "healthcheck.adguardhome.test." // needed and enriches dctx with some client-specific information. // // TODO(e.burkov): Decompose into less general processors. -func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) { - log.Debug("dnsforward: started processing initial") - defer log.Debug("dnsforward: finished processing initial") +func (s *Server) processInitial(ctx context.Context, dctx *dnsContext) (rc resultCode) { + s.logger.DebugContext(ctx, "started processing initial") + defer s.logger.DebugContext(ctx, "finished processing initial") pctx := dctx.proxyCtx - s.processClientIP(pctx.Addr.Addr()) + s.processClientIP(ctx, pctx.Addr.Addr()) q := pctx.Req.Question[0] qt := q.Qtype @@ -184,16 +186,16 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) { dctx.clientID = string(s.clientIDCache.Get(key[:])) // Get the client-specific filtering settings. - dctx.protectionEnabled, _ = s.UpdatedProtectionStatus() + dctx.protectionEnabled, _ = s.UpdatedProtectionStatus(ctx) dctx.setts = s.clientRequestFilteringSettings(dctx) return resultCodeSuccess } // processClientIP sends the client IP address to s.addrProc, if needed. -func (s *Server) processClientIP(addr netip.Addr) { +func (s *Server) processClientIP(ctx context.Context, addr netip.Addr) { if !addr.IsValid() { - log.Info("dnsforward: warning: bad client addr %q", addr) + s.logger.WarnContext(ctx, "bad client address", "addr", addr) return } @@ -203,8 +205,7 @@ func (s *Server) processClientIP(addr netip.Addr) { s.serverLock.RLock() defer s.serverLock.RUnlock() - // TODO(s.chzhen): Pass context. - s.addrProc.Process(context.TODO(), addr) + s.addrProc.Process(ctx, addr) } // processDDRQuery responds to Discovery of Designated Resolvers (DDR) SVCB @@ -212,9 +213,9 @@ func (s *Server) processClientIP(addr netip.Addr) { // current user configuration. // // See https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html. -func (s *Server) processDDRQuery(dctx *dnsContext) (rc resultCode) { - log.Debug("dnsforward: started processing ddr") - defer log.Debug("dnsforward: finished processing ddr") +func (s *Server) processDDRQuery(ctx context.Context, dctx *dnsContext) (rc resultCode) { + s.logger.DebugContext(ctx, "started processing ddr") + defer s.logger.DebugContext(ctx, "finished processing ddr") if !s.conf.HandleDDR { return resultCodeSuccess @@ -311,9 +312,9 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) { // the request is for AAAA. // // TODO(a.garipov): Adapt to AAAA as well. -func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) { - log.Debug("dnsforward: started processing dhcp hosts") - defer log.Debug("dnsforward: finished processing dhcp hosts") +func (s *Server) processDHCPHosts(ctx context.Context, dctx *dnsContext) (rc resultCode) { + s.logger.DebugContext(ctx, "started processing dhcp hosts") + defer s.logger.DebugContext(ctx, "finished processing dhcp hosts") pctx := dctx.proxyCtx req := pctx.Req @@ -325,7 +326,12 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) { } if !pctx.IsPrivateClient { - log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, dhcpHost) + s.logger.DebugContext( + ctx, + "requests for dhcp host", + "addr", pctx.Addr, + "dhcp_host", dhcpHost, + ) pctx.Res = s.NewMsgNXDOMAIN(req) // Do not even put into query log. @@ -336,12 +342,12 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) { if ip == (netip.Addr{}) { // Go on and process them with filters, including dnsrewrite ones, and // possibly route them to a domain-specific upstream. - log.Debug("dnsforward: no dhcp record for %q", dhcpHost) + s.logger.DebugContext(ctx, "no dhcp record", "dhcp_host", dhcpHost) return resultCodeSuccess } - log.Debug("dnsforward: dhcp record for %q is %s", dhcpHost, ip) + s.logger.DebugContext(ctx, "dhcp record for", "dhcp_host", dhcpHost, "ip", ip) resp := s.replyCompressed(req) switch q.Qtype { @@ -372,9 +378,9 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) { // processDHCPAddrs responds to PTR requests if the target IP is leased by the // DHCP server. -func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) { - log.Debug("dnsforward: started processing dhcp addrs") - defer log.Debug("dnsforward: finished processing dhcp addrs") +func (s *Server) processDHCPAddrs(ctx context.Context, dctx *dnsContext) (rc resultCode) { + s.logger.DebugContext(ctx, "started processing dhcp addrs") + defer s.logger.DebugContext(ctx, "finished processing dhcp addrs") pctx := dctx.proxyCtx if pctx.Res != nil { @@ -396,7 +402,7 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) { return resultCodeSuccess } - log.Debug("dnsforward: dhcp client %s is %q", addr, host) + s.logger.DebugContext(ctx, "dhcp client", "addr", addr, "host", host) resp := s.replyCompressed(req) ptr := &dns.PTR{ @@ -417,9 +423,12 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) { } // Apply filtering logic -func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode) { - log.Debug("dnsforward: started processing filtering before req") - defer log.Debug("dnsforward: finished processing filtering before req") +func (s *Server) processFilteringBeforeRequest( + ctx context.Context, + dctx *dnsContext, +) (rc resultCode) { + s.logger.DebugContext(ctx, "started processing filtering before request") + defer s.logger.DebugContext(ctx, "finished processing filtering before request") if dctx.proxyCtx.RequestedPrivateRDNS != (netip.Prefix{}) { // There is no need to filter request for locally served ARPA hostname @@ -439,7 +448,7 @@ func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode) defer s.serverLock.RUnlock() var err error - if dctx.result, err = s.filterDNSRequest(dctx); err != nil { + if dctx.result, err = s.filterDNSRequest(ctx, dctx); err != nil { dctx.err = err return resultCodeError @@ -458,9 +467,9 @@ func ipStringFromAddr(addr net.Addr) (ipStr string) { } // processUpstream passes request to upstream servers and handles the response. -func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) { - log.Debug("dnsforward: started processing upstream") - defer log.Debug("dnsforward: finished processing upstream") +func (s *Server) processUpstream(ctx context.Context, dctx *dnsContext) (rc resultCode) { + s.logger.DebugContext(ctx, "started processing upstream") + defer s.logger.DebugContext(ctx, "finished processing upstream") pctx := dctx.proxyCtx req := pctx.Req @@ -475,13 +484,17 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) { // TODO(a.garipov): Route such queries to a custom upstream for the // local domain name if there is one. name := req.Question[0].Name - log.Debug("dnsforward: dhcp client hostname %q was not filtered", name[:len(name)-1]) + s.logger.DebugContext( + ctx, + "dhcp client hostname was not filtered", + "hostname", name[:len(name)-1], + ) pctx.Res = s.NewMsgNXDOMAIN(req) return resultCodeFinish } - s.setCustomUpstream(pctx, dctx.clientID) + s.setCustomUpstream(ctx, pctx, dctx.clientID) reqWantsDNSSEC := s.setReqAD(req) @@ -571,7 +584,7 @@ func (s *Server) dhcpHostFromRequest(q *dns.Question) (reqHost string) { } // setCustomUpstream sets custom upstream settings in pctx, if necessary. -func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) { +func (s *Server) setCustomUpstream(ctx context.Context, pctx *proxy.DNSContext, clientID string) { if !pctx.Addr.IsValid() || s.conf.ClientsContainer == nil { return } @@ -579,10 +592,11 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) { cliAddr := pctx.Addr.Addr() upsConf := s.conf.ClientsContainer.CustomUpstreamConfig(clientID, cliAddr) if upsConf != nil { - log.Debug( - "dnsforward: using custom upstreams for client with ip %s and clientid %q", - cliAddr, - clientID, + s.logger.DebugContext( + ctx, + "using custom upstreams for client with", + "ip", cliAddr, + "client_id", clientID, ) pctx.CustomUpstreamConfig = upsConf @@ -590,9 +604,9 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) { } // Apply filtering logic after we have received response from upstream servers -func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode) { - log.Debug("dnsforward: started processing filtering after resp") - defer log.Debug("dnsforward: finished processing filtering after resp") +func (s *Server) processFilteringAfterResponse(ctx context.Context, dctx *dnsContext) (rc resultCode) { + s.logger.DebugContext(ctx, "started processing filtering after response") + defer s.logger.DebugContext(ctx, "finished processing filtering after response") switch res := dctx.result; res.Reason { case filtering.NotFilteredAllowList: @@ -617,13 +631,13 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode) return resultCodeSuccess default: - return s.filterAfterResponse(dctx) + return s.filterAfterResponse(ctx, dctx) } } // filterAfterResponse returns the result of filtering the response that wasn't // explicitly allowed or rewritten. -func (s *Server) filterAfterResponse(dctx *dnsContext) (res resultCode) { +func (s *Server) filterAfterResponse(ctx context.Context, dctx *dnsContext) (res resultCode) { // Check the response only if it's from an upstream. Don't check the // response if the protection is disabled since dnsrewrite rules aren't // applied to it anyway. @@ -631,7 +645,7 @@ func (s *Server) filterAfterResponse(dctx *dnsContext) (res resultCode) { return resultCodeSuccess } - err := s.filterDNSResponse(dctx) + err := s.filterDNSResponse(ctx, dctx) if err != nil { dctx.err = err diff --git a/internal/dnsforward/process_internal_test.go b/internal/dnsforward/process_internal_test.go index 8b335832..b5c54469 100644 --- a/internal/dnsforward/process_internal_test.go +++ b/internal/dnsforward/process_internal_test.go @@ -105,7 +105,7 @@ func TestServer_ProcessInitial(t *testing.T) { }, } - gotRC := s.processInitial(dctx) + gotRC := s.processInitial(testutil.ContextWithTimeout(t, testTimeout), dctx) assert.Equal(t, tc.wantRC, gotRC) assert.Equal(t, testClientAddrPort.Addr(), gotAddr) @@ -208,8 +208,8 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) { Addr: testClientAddrPort, }, } - - gotRC := s.processFilteringAfterResponse(dctx) + ctx := testutil.ContextWithTimeout(t, testTimeout) + gotRC := s.processFilteringAfterResponse(ctx, dctx) assert.Equal(t, tc.wantRC, gotRC) assert.Equal(t, newResp(dns.RcodeSuccess, tc.req, tc.wantRespAns), dctx.proxyCtx.Res) }) @@ -353,7 +353,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) { }, } - res := s.processDDRQuery(dctx) + res := s.processDDRQuery(testutil.ContextWithTimeout(t, testTimeout), dctx) require.Equal(t, tc.wantRes, res) if tc.wantRes != resultCodeFinish { @@ -440,6 +440,7 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) { dhcpServer: dhcp, localDomainSuffix: localDomainSuffix, baseLogger: testLogger, + logger: testLogger, } req := &dns.Msg{ @@ -460,7 +461,7 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) { }, } - res := s.processDHCPHosts(dctx) + res := s.processDHCPHosts(testutil.ContextWithTimeout(t, testTimeout), dctx) pctx := dctx.proxyCtx if !tc.isLocalCli { @@ -592,6 +593,7 @@ func TestServer_ProcessDHCPHosts(t *testing.T) { dhcpServer: testDHCP, localDomainSuffix: tc.suffix, baseLogger: testLogger, + logger: testLogger, } req := (&dns.Msg{}).SetQuestion(dns.Fqdn(tc.host), tc.qtyp) @@ -604,7 +606,7 @@ func TestServer_ProcessDHCPHosts(t *testing.T) { } t.Run(tc.name, func(t *testing.T) { - res := s.processDHCPHosts(dctx) + res := s.processDHCPHosts(testutil.ContextWithTimeout(t, testTimeout), dctx) pctx := dctx.proxyCtx assert.Equal(t, tc.wantRes, res) require.NoError(t, dctx.err) @@ -812,9 +814,9 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) { ServePlainDNS: true, }, ) + ctx := testutil.ContextWithTimeout(t, testTimeout) pctx := newPrxCtx() - - rc := s.processUpstream(&dnsContext{proxyCtx: pctx}) + rc := s.processUpstream(ctx, &dnsContext{proxyCtx: pctx}) require.Equal(t, resultCodeSuccess, rc) require.NotEmpty(t, pctx.Res.Answer) ptr := testutil.RequireTypeAssert[*dns.PTR](t, pctx.Res.Answer[0]) @@ -844,7 +846,8 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) { ) pctx := newPrxCtx() - rc := s.processUpstream(&dnsContext{proxyCtx: pctx}) + ctx := testutil.ContextWithTimeout(t, testTimeout) + rc := s.processUpstream(ctx, &dnsContext{proxyCtx: pctx}) require.Equal(t, resultCodeError, rc) require.Empty(t, pctx.Res.Answer) }) diff --git a/internal/dnsforward/stats.go b/internal/dnsforward/stats.go index 50818b40..90590455 100644 --- a/internal/dnsforward/stats.go +++ b/internal/dnsforward/stats.go @@ -1,6 +1,7 @@ package dnsforward import ( + "context" "net" "time" @@ -9,14 +10,13 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/querylog" "github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/dnsproxy/proxy" - "github.com/AdguardTeam/golibs/log" "github.com/miekg/dns" ) // Write Stats data and logs -func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) { - log.Debug("dnsforward: started processing querylog and stats") - defer log.Debug("dnsforward: finished processing querylog and stats") +func (s *Server) processQueryLogsAndStats(ctx context.Context, dctx *dnsContext) (rc resultCode) { + s.logger.DebugContext(ctx, "started processing querylog and stats") + defer s.logger.DebugContext(ctx, "finished processing querylog and stats") pctx := dctx.proxyCtx q := pctx.Req.Question[0] @@ -27,7 +27,7 @@ func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) { s.anonymizer.Load()(ip) ipStr := net.IP(ip).String() - log.Debug("dnsforward: client ip for stats and querylog: %s", ipStr) + s.logger.DebugContext(ctx, "client ip for stats and querylog", "ip", ipStr) ids := []string{ipStr} if dctx.clientID != "" { @@ -47,24 +47,26 @@ func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) { if s.shouldLog(host, qt, cl, ids) { s.logQuery(dctx, ip, processingTime) } else { - log.Debug( - "dnsforward: request %s %s %q from %s ignored; not adding to querylog", - dns.Class(cl), - dns.Type(qt), - host, - ipStr, + s.logger.DebugContext( + ctx, + "not adding to querylog", + "dns_class", dns.Class(cl), + "dns_type", dns.Type(qt), + "host", host, + "ip", ipStr, ) } if s.shouldCountStat(host, qt, cl, ids) { s.updateStats(dctx, ipStr, processingTime) } else { - log.Debug( - "dnsforward: request %s %s %q from %s ignored; not counting in stats", - dns.Class(cl), - dns.Type(qt), - host, - ipStr, + s.logger.DebugContext( + ctx, + "not counting in stats", + "dns_class", dns.Class(cl), + "dns_type", dns.Type(qt), + "host", host, + "ip", ipStr, ) } diff --git a/internal/dnsforward/stats_internal_test.go b/internal/dnsforward/stats_internal_test.go index 301f8c8b..3bce993a 100644 --- a/internal/dnsforward/stats_internal_test.go +++ b/internal/dnsforward/stats_internal_test.go @@ -11,6 +11,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/stats" "github.com/AdguardTeam/dnsproxy/proxy" "github.com/AdguardTeam/dnsproxy/upstream" + "github.com/AdguardTeam/golibs/testutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -203,6 +204,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) { st := &testStats{} srv := &Server{ baseLogger: testLogger, + logger: testLogger, queryLog: ql, stats: st, anonymizer: aghnet.NewIPMut(nil), @@ -229,7 +231,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) { clientID: tc.clientID, } - code := srv.processQueryLogsAndStats(dctx) + code := srv.processQueryLogsAndStats(testutil.ContextWithTimeout(t, testTimeout), dctx) assert.Equal(t, tc.wantCode, code) assert.Equal(t, tc.wantLogProto, ql.lastParams.ClientProto) assert.Equal(t, tc.wantStatClient, st.lastEntry.Client) diff --git a/internal/dnsforward/svcbmsg.go b/internal/dnsforward/svcbmsg.go index 96983dee..07c171aa 100644 --- a/internal/dnsforward/svcbmsg.go +++ b/internal/dnsforward/svcbmsg.go @@ -1,6 +1,7 @@ package dnsforward import ( + "context" "encoding/base64" "net" "strconv" @@ -14,9 +15,9 @@ import ( // // See the comment on genAnswerSVCB for a list of current restrictions on // parameter values. -func (s *Server) genAnswerHTTPS(req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.HTTPS) { +func (s *Server) genAnswerHTTPS(ctx context.Context, req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.HTTPS) { ans = &dns.HTTPS{ - SVCB: *s.genAnswerSVCB(req, svcb), + SVCB: *s.genAnswerSVCB(ctx, req, svcb), } ans.Hdr.Rrtype = dns.TypeHTTPS @@ -163,7 +164,11 @@ var svcbKeyHandlers = map[string]svcbKeyHandler{ // ipv4hint="127.0.0.1,127.0.0.2" // Unsupported. // // TODO(a.garipov): Support all of these. -func (s *Server) genAnswerSVCB(req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.SVCB) { +func (s *Server) genAnswerSVCB( + ctx context.Context, + req *dns.Msg, + svcb *rules.DNSSVCB, +) (ans *dns.SVCB) { ans = &dns.SVCB{ Hdr: s.hdr(req, dns.TypeSVCB), Priority: svcb.Priority, @@ -177,7 +182,7 @@ func (s *Server) genAnswerSVCB(req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.SVCB for k, valStr := range svcb.Params { handler, ok := svcbKeyHandlers[k] if !ok { - log.Debug("unknown svcb/https key %q, ignoring", k) + s.logger.DebugContext(ctx, "unknown svcb/https key, ignoring", "key", k) continue } diff --git a/internal/dnsforward/svcbmsg_internal_test.go b/internal/dnsforward/svcbmsg_internal_test.go index 7de96da8..9a2e6d93 100644 --- a/internal/dnsforward/svcbmsg_internal_test.go +++ b/internal/dnsforward/svcbmsg_internal_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/AdguardTeam/AdGuardHome/internal/filtering" + "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -152,14 +153,14 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) { want := &dns.HTTPS{SVCB: *tc.want} want.Hdr.Rrtype = dns.TypeHTTPS - got := s.genAnswerHTTPS(req, tc.svcb) + got := s.genAnswerHTTPS(testutil.ContextWithTimeout(t, testTimeout), req, tc.svcb) assert.Equal(t, want, got) }) }) t.Run("svcb", func(t *testing.T) { t.Run(tc.name, func(t *testing.T) { - got := s.genAnswerSVCB(req, tc.svcb) + got := s.genAnswerSVCB(testutil.ContextWithTimeout(t, testTimeout), req, tc.svcb) assert.Equal(t, tc.want, got) }) }) diff --git a/internal/filtering/blocked.go b/internal/filtering/blocked.go index 8150f309..855db7d6 100644 --- a/internal/filtering/blocked.go +++ b/internal/filtering/blocked.go @@ -159,6 +159,8 @@ func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Req // // Deprecated: Use handleBlockedServicesUpdate. func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + list := []string{} err := json.NewDecoder(r.Body).Decode(&list) if err != nil { @@ -172,10 +174,10 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ defer d.confMu.Unlock() d.conf.BlockedServices.IDs = list - d.logger.DebugContext(r.Context(), "updated blocked services list", "len", len(list)) + d.logger.DebugContext(ctx, "updated blocked services list", "len", len(list)) }() - d.conf.ConfigModified() + d.conf.ConfModifier.Apply(ctx) } // handleBlockedServicesGet is the handler for the GET @@ -195,6 +197,8 @@ func (d *DNSFilter) handleBlockedServicesGet(w http.ResponseWriter, r *http.Requ // handleBlockedServicesUpdate is the handler for the PUT // /control/blocked_services/update HTTP API. func (d *DNSFilter) handleBlockedServicesUpdate(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + bsvc := &BlockedServices{} err := json.NewDecoder(r.Body).Decode(bsvc) if err != nil { @@ -221,7 +225,7 @@ func (d *DNSFilter) handleBlockedServicesUpdate(w http.ResponseWriter, r *http.R d.conf.BlockedServices = bsvc }() - d.logger.DebugContext(r.Context(), "updated blocked services schedule", "len", len(bsvc.IDs)) + d.logger.DebugContext(ctx, "updated blocked services schedule", "len", len(bsvc.IDs)) - d.conf.ConfigModified() + d.conf.ConfModifier.Apply(ctx) } diff --git a/internal/filtering/filtering.go b/internal/filtering/filtering.go index 874283ec..ac4e2f94 100644 --- a/internal/filtering/filtering.go +++ b/internal/filtering/filtering.go @@ -19,6 +19,7 @@ import ( "sync/atomic" "time" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" @@ -104,8 +105,9 @@ type Config struct { // TODO(e.burkov): Move it to dnsforward entirely. EtcHosts hostsfile.Storage `yaml:"-"` - // Called when the configuration is changed by HTTP request - ConfigModified func() `yaml:"-"` + // ConfModifier is used to update the global configuration. It must not be + // nil. + ConfModifier agh.ConfigModifier `yaml:"-"` // Register an HTTP handler HTTPRegister aghhttp.RegisterFunc `yaml:"-"` diff --git a/internal/filtering/http.go b/internal/filtering/http.go index dca039a0..55b8d4ac 100644 --- a/internal/filtering/http.go +++ b/internal/filtering/http.go @@ -133,7 +133,7 @@ func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request return } - d.conf.ConfigModified() + d.conf.ConfModifier.Apply(r.Context()) d.EnableFilters(true) _, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount) @@ -202,7 +202,7 @@ func (d *DNSFilter) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ d.logger.InfoContext(ctx, "deleted filter", "id", deleted.ID) }() - d.conf.ConfigModified() + d.conf.ConfModifier.Apply(ctx) d.EnableFilters(true) // NOTE: The old files "filter.txt.old" aren't deleted. It's not really @@ -264,7 +264,7 @@ func (d *DNSFilter) handleFilteringSetURL(w http.ResponseWriter, r *http.Request return } - d.conf.ConfigModified() + d.conf.ConfModifier.Apply(r.Context()) if restart { d.EnableFilters(true) } @@ -289,7 +289,7 @@ func (d *DNSFilter) handleFilteringSetRules(w http.ResponseWriter, r *http.Reque } d.conf.UserRules = req.Rules - d.conf.ConfigModified() + d.conf.ConfModifier.Apply(r.Context()) d.EnableFilters(true) } @@ -403,7 +403,7 @@ func (d *DNSFilter) handleFilteringConfig(w http.ResponseWriter, r *http.Request d.conf.FiltersUpdateIntervalHours = req.Interval }() - d.conf.ConfigModified() + d.conf.ConfModifier.Apply(r.Context()) d.EnableFilters(true) } @@ -571,14 +571,14 @@ func protectedBool(mu *sync.RWMutex, ptr *bool) (val bool) { // /control/safebrowsing/enable HTTP API. func (d *DNSFilter) handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) { setProtectedBool(d.confMu, &d.conf.SafeBrowsingEnabled, true) - d.conf.ConfigModified() + d.conf.ConfModifier.Apply(r.Context()) } // handleSafeBrowsingDisable is the handler for the POST // /control/safebrowsing/disable HTTP API. func (d *DNSFilter) handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Request) { setProtectedBool(d.confMu, &d.conf.SafeBrowsingEnabled, false) - d.conf.ConfigModified() + d.conf.ConfModifier.Apply(r.Context()) } // handleSafeBrowsingStatus is the handler for the GET @@ -597,14 +597,14 @@ func (d *DNSFilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Requ // HTTP API. func (d *DNSFilter) handleParentalEnable(w http.ResponseWriter, r *http.Request) { setProtectedBool(d.confMu, &d.conf.ParentalEnabled, true) - d.conf.ConfigModified() + d.conf.ConfModifier.Apply(r.Context()) } // handleParentalDisable is the handler for the POST /control/parental/disable // HTTP API. func (d *DNSFilter) handleParentalDisable(w http.ResponseWriter, r *http.Request) { setProtectedBool(d.confMu, &d.conf.ParentalEnabled, false) - d.conf.ConfigModified() + d.conf.ConfModifier.Apply(r.Context()) } // handleParentalStatus is the handler for the GET /control/parental/status diff --git a/internal/filtering/http_internal_test.go b/internal/filtering/http_internal_test.go index 4d45e254..326d91ff 100644 --- a/internal/filtering/http_internal_test.go +++ b/internal/filtering/http_internal_test.go @@ -2,6 +2,7 @@ package filtering import ( "bytes" + "context" "encoding/json" "fmt" "net/http" @@ -11,6 +12,7 @@ import ( "testing" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/schedule" "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" @@ -103,6 +105,10 @@ func TestDNSFilter_handleFilteringSetURL(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { confModifiedCalled := false + confModifier := &aghtest.ConfigModifier{} + confModifier.OnApply = func(_ context.Context) { + confModifiedCalled = true + } d, err := New(&Config{ Logger: slogutil.NewDiscardLogger(), FilteringEnabled: true, @@ -110,8 +116,8 @@ func TestDNSFilter_handleFilteringSetURL(t *testing.T) { HTTPClient: &http.Client{ Timeout: 5 * time.Second, }, - ConfigModified: func() { confModifiedCalled = true }, - DataDir: filtersDir, + ConfModifier: confModifier, + DataDir: filtersDir, }, nil) require.NoError(t, err) t.Cleanup(d.Close) @@ -183,13 +189,15 @@ func TestDNSFilter_handleSafeBrowsingStatus(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { handlers := make(map[string]http.Handler) + confModifier := &aghtest.ConfigModifier{} + confModifier.OnApply = func(_ context.Context) { + testutil.RequireSend(testutil.PanicT{}, confModCh, struct{}{}, testTimeout) + } d, err := New(&Config{ - Logger: slogutil.NewDiscardLogger(), - ConfigModified: func() { - testutil.RequireSend(testutil.PanicT{}, confModCh, struct{}{}, testTimeout) - }, - DataDir: filtersDir, + Logger: slogutil.NewDiscardLogger(), + ConfModifier: confModifier, + DataDir: filtersDir, HTTPRegister: func(_, url string, handler http.HandlerFunc) { handlers[url] = handler }, @@ -268,13 +276,15 @@ func TestDNSFilter_handleParentalStatus(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { handlers := make(map[string]http.Handler) + confModifier := &aghtest.ConfigModifier{} + confModifier.OnApply = func(_ context.Context) { + testutil.RequireSend(testutil.PanicT{}, confModCh, struct{}{}, testTimeout) + } d, err := New(&Config{ - Logger: slogutil.NewDiscardLogger(), - ConfigModified: func() { - testutil.RequireSend(testutil.PanicT{}, confModCh, struct{}{}, testTimeout) - }, - DataDir: filtersDir, + Logger: slogutil.NewDiscardLogger(), + ConfModifier: confModifier, + DataDir: filtersDir, HTTPRegister: func(_, url string, handler http.HandlerFunc) { handlers[url] = handler }, diff --git a/internal/filtering/rewritehttp.go b/internal/filtering/rewritehttp.go index d6415a05..d685b27d 100644 --- a/internal/filtering/rewritehttp.go +++ b/internal/filtering/rewritehttp.go @@ -36,6 +36,8 @@ func (d *DNSFilter) handleRewriteList(w http.ResponseWriter, r *http.Request) { // handleRewriteAdd is the handler for the POST /control/rewrite/add HTTP API. func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + rwJSON := rewriteEntryJSON{} err := json.NewDecoder(r.Body).Decode(&rwJSON) if err != nil { @@ -49,7 +51,7 @@ func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) { Answer: rwJSON.Answer, } - err = rw.normalize(r.Context(), d.logger) + err = rw.normalize(ctx, d.logger) if err != nil { // Shouldn't happen currently, since normalize only returns a non-nil // error when a rewrite is nil, but be change-proof. @@ -64,7 +66,7 @@ func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) { d.conf.Rewrites = append(d.conf.Rewrites, rw) d.logger.DebugContext( - r.Context(), + ctx, "added rewrite element", "domain", rw.Domain, "answer", rw.Answer, @@ -72,12 +74,14 @@ func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) { ) }() - d.conf.ConfigModified() + d.conf.ConfModifier.Apply(ctx) } // handleRewriteDelete is the handler for the POST /control/rewrite/delete HTTP // API. func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + jsent := rewriteEntryJSON{} err := json.NewDecoder(r.Body).Decode(&jsent) if err != nil { @@ -92,28 +96,27 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request) } arr := []*LegacyRewrite{} - func() { - d.confMu.Lock() - defer d.confMu.Unlock() + defer d.conf.ConfModifier.Apply(ctx) - for _, ent := range d.conf.Rewrites { - if ent.equal(entDel) { - d.logger.DebugContext( - r.Context(), - "removed rewrite element", - "domain", ent.Domain, - "answer", ent.Answer, - ) - - continue - } + d.confMu.Lock() + defer d.confMu.Unlock() + for _, ent := range d.conf.Rewrites { + if !ent.equal(entDel) { arr = append(arr, ent) - } - d.conf.Rewrites = arr - }() - d.conf.ConfigModified() + continue + } + + d.logger.DebugContext( + ctx, + "removed rewrite element", + "domain", ent.Domain, + "answer", ent.Answer, + ) + } + + d.conf.Rewrites = arr } // rewriteUpdateJSON is a struct for JSON object with rewrite rule update info. @@ -125,6 +128,8 @@ type rewriteUpdateJSON struct { // handleRewriteUpdate is the handler for the PUT /control/rewrite/update HTTP // API. func (d *DNSFilter) handleRewriteUpdate(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + updateJSON := rewriteUpdateJSON{} err := json.NewDecoder(r.Body).Decode(&updateJSON) if err != nil { @@ -143,7 +148,7 @@ func (d *DNSFilter) handleRewriteUpdate(w http.ResponseWriter, r *http.Request) Answer: updateJSON.Update.Answer, } - err = rwAdd.normalize(r.Context(), d.logger) + err = rwAdd.normalize(ctx, d.logger) if err != nil { // Shouldn't happen currently, since normalize only returns a non-nil // error when a rewrite is nil, but be change-proof. @@ -155,7 +160,7 @@ func (d *DNSFilter) handleRewriteUpdate(w http.ResponseWriter, r *http.Request) index := -1 defer func() { if index >= 0 { - d.conf.ConfigModified() + d.conf.ConfModifier.Apply(ctx) } }() @@ -171,7 +176,6 @@ func (d *DNSFilter) handleRewriteUpdate(w http.ResponseWriter, r *http.Request) d.conf.Rewrites = slices.Replace(d.conf.Rewrites, index, index+1, rwAdd) - ctx := r.Context() d.logger.DebugContext( ctx, "removed rewrite element", diff --git a/internal/filtering/rewritehttp_test.go b/internal/filtering/rewritehttp_test.go index b95435b8..4448a4c0 100644 --- a/internal/filtering/rewritehttp_test.go +++ b/internal/filtering/rewritehttp_test.go @@ -2,6 +2,7 @@ package filtering_test import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -9,6 +10,7 @@ import ( "testing" "time" + "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" @@ -148,20 +150,17 @@ func TestDNSFilter_handleRewriteHTTP(t *testing.T) { }} for _, tc := range testCases { - onConfModified := func() { - if !tc.wantConfMod { - panic("config modified has been fired") - } - - testutil.RequireSend(testutil.PanicT{}, confModCh, struct{}{}, testTimeout) - } - t.Run(tc.name, func(t *testing.T) { handlers := make(map[string]http.Handler) + confModifier := &aghtest.ConfigModifier{} + confModifier.OnApply = func(_ context.Context) { + require.Truef(t, tc.wantConfMod, "config modified has been fired") + testutil.RequireSend(testutil.PanicT{}, confModCh, struct{}{}, testTimeout) + } d, err := filtering.New(&filtering.Config{ - Logger: slogutil.NewDiscardLogger(), - ConfigModified: onConfModified, + Logger: slogutil.NewDiscardLogger(), + ConfModifier: confModifier, HTTPRegister: func(_, url string, handler http.HandlerFunc) { handlers[url] = handler }, diff --git a/internal/filtering/safesearchhttp.go b/internal/filtering/safesearchhttp.go index 8790b297..b7a6a4f3 100644 --- a/internal/filtering/safesearchhttp.go +++ b/internal/filtering/safesearchhttp.go @@ -13,7 +13,7 @@ import ( // Deprecated: Use handleSafeSearchSettings. func (d *DNSFilter) handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) { setProtectedBool(d.confMu, &d.conf.SafeSearchConf.Enabled, true) - d.conf.ConfigModified() + d.conf.ConfModifier.Apply(r.Context()) } // handleSafeSearchDisable is the handler for POST /control/safesearch/disable @@ -22,7 +22,7 @@ func (d *DNSFilter) handleSafeSearchEnable(w http.ResponseWriter, r *http.Reques // Deprecated: Use handleSafeSearchSettings. func (d *DNSFilter) handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) { setProtectedBool(d.confMu, &d.conf.SafeSearchConf.Enabled, false) - d.conf.ConfigModified() + d.conf.ConfModifier.Apply(r.Context()) } // handleSafeSearchStatus is the handler for GET /control/safesearch/status @@ -42,6 +42,8 @@ func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Reques // handleSafeSearchSettings is the handler for PUT /control/safesearch/settings // HTTP API. func (d *DNSFilter) handleSafeSearchSettings(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + req := &SafeSearchConfig{} err := json.NewDecoder(r.Body).Decode(req) if err != nil { @@ -51,7 +53,7 @@ func (d *DNSFilter) handleSafeSearchSettings(w http.ResponseWriter, r *http.Requ } conf := *req - err = d.safeSearch.Update(r.Context(), conf) + err = d.safeSearch.Update(ctx, conf) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "updating: %s", err) @@ -65,7 +67,7 @@ func (d *DNSFilter) handleSafeSearchSettings(w http.ResponseWriter, r *http.Requ d.conf.SafeSearchConf = conf }() - d.conf.ConfigModified() + d.conf.ConfModifier.Apply(ctx) aghhttp.OK(w) } diff --git a/internal/home/authhttp_internal_test.go b/internal/home/authhttp_internal_test.go index 93ffa8e8..8215bd28 100644 --- a/internal/home/authhttp_internal_test.go +++ b/internal/home/authhttp_internal_test.go @@ -19,6 +19,7 @@ import ( "testing" "time" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghuser" "github.com/AdguardTeam/golibs/httphdr" @@ -327,7 +328,17 @@ func TestAuth_ServeHTTP_firstRun(t *testing.T) { globalContext.mux = mux ctx := testutil.ContextWithTimeout(t, testTimeout) - web, err := initWeb(ctx, options{}, nil, nil, testLogger, nil, nil, false) + web, err := initWeb( + ctx, + options{}, + nil, + nil, + testLogger, + nil, + nil, + agh.EmptyConfigModifier{}, + false, + ) require.NoError(t, err) globalContext.web = web @@ -477,13 +488,23 @@ func TestAuth_ServeHTTP_auth(t *testing.T) { globalContext.mux = http.NewServeMux() tlsMgr, err := newTLSManager(testutil.ContextWithTimeout(t, testTimeout), &tlsManagerConfig{ - logger: testLogger, - configModified: func() {}, + logger: testLogger, + confModifier: agh.EmptyConfigModifier{}, }) require.NoError(t, err) ctx := testutil.ContextWithTimeout(t, testTimeout) - web, err := initWeb(ctx, options{}, nil, nil, testLogger, tlsMgr, auth, false) + web, err := initWeb( + ctx, + options{}, + nil, + nil, + testLogger, + tlsMgr, + auth, + agh.EmptyConfigModifier{}, + false, + ) require.NoError(t, err) globalContext.web = web @@ -625,7 +646,16 @@ func TestAuth_ServeHTTP_logout(t *testing.T) { globalContext.mux = http.NewServeMux() ctx := testutil.ContextWithTimeout(t, testTimeout) - web, err := initWeb(ctx, options{}, nil, nil, testLogger, nil, auth, false) + web, err := initWeb(ctx, + options{}, + nil, + nil, + testLogger, + nil, + auth, + agh.EmptyConfigModifier{}, + false, + ) require.NoError(t, err) globalContext.web = web diff --git a/internal/home/clients.go b/internal/home/clients.go index e6459342..a6f37ba1 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/arpdb" "github.com/AdguardTeam/AdGuardHome/internal/client" @@ -39,6 +40,10 @@ type clientsContainer struct { // settings. clientChecker BlockedClientChecker + // confModifier is used to update the global configuration. It must not be + // nil. + confModifier agh.ConfigModifier + // lock protects all fields. // // TODO(a.garipov): Use a pointer and describe which fields are protected in @@ -52,11 +57,6 @@ type clientsContainer struct { // safeSearchCacheTTL is the TTL of the safe search cache to use for // persistent clients. safeSearchCacheTTL time.Duration - - // testing is a flag that disables some features for internal tests. - // - // TODO(a.garipov): Awful. Remove. - testing bool } // BlockedClientChecker checks if a client is blocked by the current access @@ -78,6 +78,7 @@ func (clients *clientsContainer) Init( arpDB arpdb.Interface, filteringConf *filtering.Config, sigHdlr *signalHandler, + confModifier agh.ConfigModifier, ) (err error) { // TODO(s.chzhen): Refactor it. if clients.storage != nil { @@ -88,6 +89,7 @@ func (clients *clientsContainer) Init( clients.logger = baseLogger.With(slogutil.KeyPrefix, "client_container") clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime) + clients.confModifier = confModifier confClients := make([]*client.Persistent, 0, len(objects)) for i, o := range objects { @@ -141,10 +143,6 @@ var webHandlersRegistered = false // Start starts the clients container. func (clients *clientsContainer) Start(ctx context.Context) (err error) { - if clients.testing { - return - } - if !webHandlersRegistered { webHandlersRegistered = true clients.registerWebHandlers() diff --git a/internal/home/clients_internal_test.go b/internal/home/clients_internal_test.go index 917b20e0..adf84870 100644 --- a/internal/home/clients_internal_test.go +++ b/internal/home/clients_internal_test.go @@ -3,6 +3,7 @@ package home import ( "testing" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/golibs/testutil" @@ -14,9 +15,7 @@ import ( func newClientsContainer(t *testing.T) (c *clientsContainer) { t.Helper() - c = &clientsContainer{ - testing: true, - } + c = &clientsContainer{} ctx := testutil.ContextWithTimeout(t, testTimeout) err := c.Init( @@ -30,6 +29,7 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) { Logger: testLogger, }, newSignalHandler(testLogger, nil, nil), + agh.EmptyConfigModifier{}, ) require.NoError(t, err) diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index f8260dbd..f1691166 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -326,6 +326,8 @@ func clientToJSON(c *client.Persistent) (cj *clientJSON) { // handleAddClient is the handler for POST /control/clients/add HTTP API. func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + cj := clientJSON{} err := json.NewDecoder(r.Body).Decode(&cj) if err != nil { @@ -334,27 +336,27 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http. return } - c, err := clients.jsonToClient(r.Context(), cj, nil) + c, err := clients.jsonToClient(ctx, cj, nil) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) return } - err = clients.storage.Add(r.Context(), c) + err = clients.storage.Add(ctx, c) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) return } - if !clients.testing { - onConfigModified() - } + clients.confModifier.Apply(ctx) } // handleDelClient is the handler for POST /control/clients/delete HTTP API. func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + cj := clientJSON{} err := json.NewDecoder(r.Body).Decode(&cj) if err != nil { @@ -369,15 +371,13 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http. return } - if !clients.storage.RemoveByName(r.Context(), cj.Name) { + if !clients.storage.RemoveByName(ctx, cj.Name) { aghhttp.Error(r, w, http.StatusBadRequest, "Client not found") return } - if !clients.testing { - onConfigModified() - } + clients.confModifier.Apply(ctx) } // updateJSON contains the name and data of the updated persistent client. @@ -390,6 +390,8 @@ type updateJSON struct { // // TODO(s.chzhen): Accept updated parameters instead of whole structure. func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + dj := updateJSON{} err := json.NewDecoder(r.Body).Decode(&dj) if err != nil { @@ -404,23 +406,21 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht return } - c, err := clients.jsonToClient(r.Context(), dj.Data, nil) + c, err := clients.jsonToClient(ctx, dj.Data, nil) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) return } - err = clients.storage.Update(r.Context(), dj.Name, c) + err = clients.storage.Update(ctx, dj.Name, c) if err != nil { aghhttp.Error(r, w, http.StatusBadRequest, "%s", err) return } - if !clients.testing { - onConfigModified() - } + clients.confModifier.Apply(ctx) } // handleFindClient is the handler for GET /control/clients/find HTTP API. diff --git a/internal/home/config.go b/internal/home/config.go index ac6636d7..607f14ee 100644 --- a/internal/home/config.go +++ b/internal/home/config.go @@ -4,12 +4,14 @@ import ( "bytes" "context" "fmt" + "log/slog" "net/netip" "os" "path/filepath" "slices" "sync" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/AdGuardHome/internal/aghtls" @@ -23,6 +25,7 @@ import ( "github.com/AdguardTeam/dnsproxy/fastip" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/timeutil" "github.com/google/go-cmp/cmp" @@ -838,3 +841,47 @@ func validateTLSCipherIDs(cipherIDs []string) (err error) { return nil } + +// defaultConfigModifier is a default [agh.ConfigModifier] implementation. +type defaultConfigModifier struct { + auth *auth + config *configuration + logger *slog.Logger + tlsMgr *tlsManager +} + +// newDefaultConfigModifier returns the new properly initialized +// *defaultConfigModifier. All arguments must not be nil. +// +// TODO(s.chzhen): Consider using configuration struct. +func newDefaultConfigModifier( + conf *configuration, + l *slog.Logger, +) (cm *defaultConfigModifier) { + return &defaultConfigModifier{ + config: conf, + logger: l, + } +} + +// type check +var _ agh.ConfigModifier = (*defaultConfigModifier)(nil) + +// Apply implements the [agh.ConfigModifier] interface for +// *defaultConfigModifier. +func (cm *defaultConfigModifier) Apply(ctx context.Context) { + err := cm.config.write(cm.tlsMgr, cm.auth) + if err != nil { + cm.logger.ErrorContext(ctx, "writing config", slogutil.KeyError, err) + } +} + +// setAuth sets the auth parameters used by Apply. +func (cm *defaultConfigModifier) setAuth(a *auth) { + cm.auth = a +} + +// setTLSManager sets the TLS manager used by Apply. +func (cm *defaultConfigModifier) setTLSManager(m *tlsManager) { + cm.tlsMgr = m +} diff --git a/internal/home/control.go b/internal/home/control.go index 45168805..75c82e5d 100644 --- a/internal/home/control.go +++ b/internal/home/control.go @@ -115,6 +115,8 @@ type statusResponse struct { } func (web *webAPI) handleStatus(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + dnsAddrs, err := collectDNSAddresses(web.tlsManager) if err != nil { // Don't add a lot of formatting, since the error is already @@ -125,14 +127,14 @@ func (web *webAPI) handleStatus(w http.ResponseWriter, r *http.Request) { } var ( - fltConf *dnsforward.Config - protectionDisabledUntil *time.Time - protectionEnabled bool + fltConf *dnsforward.Config + protDisabledUntil *time.Time + protEnabled bool ) if globalContext.dnsServer != nil { fltConf = &dnsforward.Config{} globalContext.dnsServer.WriteDiskConfig(fltConf) - protectionEnabled, protectionDisabledUntil = globalContext.dnsServer.UpdatedProtectionStatus() + protEnabled, protDisabledUntil = globalContext.dnsServer.UpdatedProtectionStatus(ctx) } var resp statusResponse @@ -141,11 +143,11 @@ func (web *webAPI) handleStatus(w http.ResponseWriter, r *http.Request) { defer config.RUnlock() var protectionDisabledDuration int64 - if protectionDisabledUntil != nil { + if protDisabledUntil != nil { // Make sure that we don't send negative numbers to the frontend, // since enough time might have passed to make the difference less // than zero. - protectionDisabledDuration = max(0, time.Until(*protectionDisabledUntil).Milliseconds()) + protectionDisabledDuration = max(0, time.Until(*protDisabledUntil).Milliseconds()) } resp = statusResponse{ @@ -155,7 +157,7 @@ func (web *webAPI) handleStatus(w http.ResponseWriter, r *http.Request) { DNSPort: config.DNS.Port, HTTPPort: config.HTTPConfig.Address.Port(), ProtectionDisabledDuration: protectionDisabledDuration, - ProtectionEnabled: protectionEnabled, + ProtectionEnabled: protEnabled, IsRunning: isRunning(), } }() @@ -175,10 +177,10 @@ func registerControlHandlers(web *webAPI) { httpRegister(http.MethodPost, "/control/update", web.handleUpdate) httpRegister(http.MethodGet, "/control/status", web.handleStatus) - httpRegister(http.MethodPost, "/control/i18n/change_language", handleI18nChangeLanguage) + httpRegister(http.MethodPost, "/control/i18n/change_language", web.handleI18nChangeLanguage) httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage) httpRegister(http.MethodGet, "/control/profile", web.handleGetProfile) - httpRegister(http.MethodPut, "/control/profile/update", handlePutProfile) + httpRegister(http.MethodPut, "/control/profile/update", web.handlePutProfile) // No auth is necessary for DoH/DoT configurations globalContext.mux.HandleFunc("/apple/doh.mobileconfig", postInstall(handleMobileConfigDoH)) diff --git a/internal/home/controlinstall.go b/internal/home/controlinstall.go index 50d36a91..8ea1cfa2 100644 --- a/internal/home/controlinstall.go +++ b/internal/home/controlinstall.go @@ -15,6 +15,7 @@ import ( "time" "unicode/utf8" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" @@ -455,7 +456,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request // moment we'll allow setting up TLS in the initial configuration or the // configuration itself will use HTTPS protocol, because the underlying // functions potentially restart the HTTPS server. - err = startMods(ctx, web.baseLogger, web.tlsManager) + err = startMods(ctx, web.baseLogger, web.tlsManager, web.confModifier) if err != nil { globalContext.firstRun = true copyInstallSettings(config, curConfig) @@ -532,13 +533,18 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e // startMods initializes and starts the DNS server after installation. // baseLogger and tlsMgr must not be nil. -func startMods(ctx context.Context, baseLogger *slog.Logger, tlsMgr *tlsManager) (err error) { +func startMods( + ctx context.Context, + baseLogger *slog.Logger, + tlsMgr *tlsManager, + confModifier agh.ConfigModifier, +) (err error) { statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&globalContext, config) if err != nil { return err } - err = initDNS(baseLogger, tlsMgr, statsDir, querylogDir) + err = initDNS(ctx, baseLogger, tlsMgr, confModifier, statsDir, querylogDir) if err != nil { return err } @@ -547,7 +553,7 @@ func startMods(ctx context.Context, baseLogger *slog.Logger, tlsMgr *tlsManager) err = startDNSServer() if err != nil { - closeDNSServer() + closeDNSServer(ctx) return err } diff --git a/internal/home/dns.go b/internal/home/dns.go index 79fcbb4e..07372efd 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -12,6 +12,7 @@ import ( "path/filepath" "time" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" @@ -38,23 +39,15 @@ const ( defaultPortTLS uint16 = 853 ) -// Called by other modules when configuration is changed -// -// TODO(s.chzhen): Remove this after refactoring. -func onConfigModified() { - err := config.write(globalContext.tls, globalContext.auth) - if err != nil { - log.Error("writing config: %s", err) - } -} - // initDNS updates all the fields of the [globalContext] needed to initialize // the DNS server and initializes it at last. It also must not be called unless -// [config] and [globalContext] are initialized. baseLogger and tlsMgr must not -// be nil. +// [config] and [globalContext] are initialized. baseLogger, tlsMgr and +// confModfier must not be nil. func initDNS( + ctx context.Context, baseLogger *slog.Logger, tlsMgr *tlsManager, + confModifier agh.ConfigModifier, statsDir string, querylogDir string, ) (err error) { @@ -64,7 +57,7 @@ func initDNS( Logger: baseLogger.With(slogutil.KeyPrefix, "stats"), Filename: filepath.Join(statsDir, "stats.db"), Limit: time.Duration(config.Stats.Interval), - ConfigModified: onConfigModified, + ConfigModifier: confModifier, HTTPRegister: httpRegister, Enabled: config.Stats.Enabled, ShouldCountClient: globalContext.clients.shouldCountClient, @@ -84,7 +77,7 @@ func initDNS( conf := querylog.Config{ Logger: baseLogger.With(slogutil.KeyPrefix, "querylog"), Anonymizer: anonymizer, - ConfigModified: onConfigModified, + ConfigModifier: confModifier, HTTPRegister: httpRegister, FindClient: globalContext.clients.findMultiple, BaseDir: querylogDir, @@ -113,6 +106,7 @@ func initDNS( } return initDNSServer( + ctx, globalContext.filters, globalContext.stats, globalContext.queryLog, @@ -121,6 +115,7 @@ func initDNS( httpRegister, tlsMgr, baseLogger, + confModifier, ) } @@ -131,6 +126,7 @@ func initDNS( // // TODO(e.burkov): Use [dnsforward.DNSCreateParams] as a parameter. func initDNSServer( + ctx context.Context, filters *filtering.DNSFilter, sts stats.Interface, qlog querylog.QueryLog, @@ -139,6 +135,7 @@ func initDNSServer( httpReg aghhttp.RegisterFunc, tlsMgr *tlsManager, l *slog.Logger, + confModifier agh.ConfigModifier, ) (err error) { globalContext.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{ Logger: l, @@ -153,7 +150,7 @@ func initDNSServer( }) defer func() { if err != nil { - closeDNSServer() + closeDNSServer(ctx) } }() if err != nil { @@ -169,6 +166,7 @@ func initDNSServer( tlsMgr, httpReg, globalContext.clients.storage, + confModifier, ) if err != nil { return fmt.Errorf("newServerConfig: %w", err) @@ -176,12 +174,12 @@ func initDNSServer( // Try to prepare the server with disabled private RDNS resolution if it // failed to prepare as is. See TODO on [dnsforward.PrivateRDNSError]. - err = globalContext.dnsServer.Prepare(dnsConf) + err = globalContext.dnsServer.Prepare(ctx, dnsConf) if privRDNSErr := (&dnsforward.PrivateRDNSError{}); errors.As(err, &privRDNSErr) { log.Info("WARNING: %s; trying to disable private RDNS resolution", err) dnsConf.UsePrivateRDNS = false - err = globalContext.dnsServer.Prepare(dnsConf) + err = globalContext.dnsServer.Prepare(ctx, dnsConf) } if err != nil { @@ -245,6 +243,7 @@ func newServerConfig( tlsMgr *tlsManager, httpReg aghhttp.RegisterFunc, clientsContainer dnsforward.ClientsContainer, + confModifier agh.ConfigModifier, ) (newConf *dnsforward.ServerConfig, err error) { hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()}) @@ -264,7 +263,7 @@ func newServerConfig( TLSAllowUnencryptedDoH: tlsConf.AllowUnencryptedDoH, UpstreamTimeout: time.Duration(dnsConf.UpstreamTimeout), TLSv12Roots: tlsMgr.rootCerts, - ConfigModified: onConfigModified, + ConfModifier: confModifier, HTTPRegister: httpReg, LocalPTRResolvers: dnsConf.PrivateRDNSResolvers, UseDNS64: dnsConf.UseDNS64, @@ -454,7 +453,7 @@ func startDNSServer() error { return fmt.Errorf("starting clients container: %w", err) } - err = globalContext.dnsServer.Start() + err = globalContext.dnsServer.Start(ctx) if err != nil { return fmt.Errorf("starting dns server: %w", err) } @@ -470,30 +469,30 @@ func startDNSServer() error { return nil } -func stopDNSServer() (err error) { +func stopDNSServer(ctx context.Context) (err error) { if !isRunning() { return nil } - err = globalContext.dnsServer.Stop() + err = globalContext.dnsServer.Stop(ctx) if err != nil { return fmt.Errorf("stopping forwarding dns server: %w", err) } - err = globalContext.clients.close(context.TODO()) + err = globalContext.clients.close(ctx) if err != nil { return fmt.Errorf("closing clients container: %w", err) } - closeDNSServer() + closeDNSServer(ctx) return nil } -func closeDNSServer() { +func closeDNSServer(ctx context.Context) { // DNS forward module must be closed BEFORE stats or queryLog because it depends on them if globalContext.dnsServer != nil { - globalContext.dnsServer.Close() + globalContext.dnsServer.Close(ctx) globalContext.dnsServer = nil } @@ -509,8 +508,7 @@ func closeDNSServer() { } if globalContext.queryLog != nil { - // TODO(s.chzhen): Pass context. - err := globalContext.queryLog.Shutdown(context.TODO()) + err := globalContext.queryLog.Shutdown(ctx) if err != nil { log.Error("closing query log: %s", err) } diff --git a/internal/home/home.go b/internal/home/home.go index 7acd29ad..9536219d 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -18,6 +18,7 @@ import ( "syscall" "time" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghos" @@ -297,12 +298,13 @@ func initContextClients( ctx context.Context, logger *slog.Logger, sigHdlr *signalHandler, + confModifier agh.ConfigModifier, ) (err error) { //lint:ignore SA1019 Migration is not over. config.DHCP.WorkDir = globalContext.workDir config.DHCP.DataDir = globalContext.getDataDir() config.DHCP.HTTPRegister = httpRegister - config.DHCP.ConfigModified = onConfigModified + config.DHCP.ConfModifier = confModifier globalContext.dhcpServer, err = dhcpd.Create(config.DHCP) if globalContext.dhcpServer == nil || err != nil { @@ -327,6 +329,7 @@ func initContextClients( arpDB, config.Filtering, sigHdlr, + confModifier, ) } @@ -377,6 +380,7 @@ func setupDNSFilteringConf( baseLogger *slog.Logger, conf *filtering.Config, tlsMgr *tlsManager, + confModifier agh.ConfigModifier, ) (err error) { const ( dnsTimeout = 3 * time.Second @@ -398,7 +402,7 @@ func setupDNSFilteringConf( conf.EtcHosts = nil } - conf.ConfigModified = onConfigModified + conf.ConfModifier = confModifier conf.HTTPRegister = httpRegister conf.DataDir = globalContext.getDataDir() conf.Filters = slices.Clone(config.Filters) @@ -564,6 +568,7 @@ func initWeb( baseLogger *slog.Logger, tlsMgr *tlsManager, auth *auth, + confModifier agh.ConfigModifier, isCustomUpdURL bool, ) (web *webAPI, err error) { logger := baseLogger.With(slogutil.KeyPrefix, "webapi") @@ -583,11 +588,12 @@ func initWeb( disableUpdate := !isUpdateEnabled(ctx, baseLogger, &opts, isCustomUpdURL) webConf := &webConfig{ - updater: upd, - logger: logger, - baseLogger: baseLogger, - tlsManager: tlsMgr, - auth: auth, + updater: upd, + logger: logger, + baseLogger: baseLogger, + confModifier: confModifier, + tlsManager: tlsMgr, + auth: auth, clientFS: clientFS, @@ -661,25 +667,31 @@ func run( // data first, but also to avoid relying on automatic Go init() function. filtering.InitModule(ctx, slogLogger) - err = initContextClients(ctx, slogLogger, sigHdlr) + confModifier := newDefaultConfigModifier( + config, + slogLogger.With(slogutil.KeyPrefix, "config_modifier"), + ) + + err = initContextClients(ctx, slogLogger, sigHdlr, confModifier) fatalOnError(err) tlsMgrLogger := slogLogger.With(slogutil.KeyPrefix, "tls_manager") tlsMgr, err := newTLSManager(ctx, &tlsManagerConfig{ - logger: tlsMgrLogger, - configModified: onConfigModified, - tlsSettings: config.TLS, - servePlainDNS: config.DNS.ServePlainDNS, + logger: tlsMgrLogger, + confModifier: confModifier, + tlsSettings: config.TLS, + servePlainDNS: config.DNS.ServePlainDNS, }) if err != nil { tlsMgrLogger.ErrorContext(ctx, "initializing", slogutil.KeyError, err) - onConfigModified() + confModifier.Apply(ctx) } globalContext.tls = tlsMgr + confModifier.setTLSManager(tlsMgr) - err = setupDNSFilteringConf(ctx, slogLogger, config.Filtering, tlsMgr) + err = setupDNSFilteringConf(ctx, slogLogger, config.Filtering, tlsMgr, confModifier) fatalOnError(err) err = setupOpts(opts) @@ -715,8 +727,19 @@ func run( fatalOnError(err) globalContext.auth = auth + confModifier.setAuth(auth) - web, err := initWeb(ctx, opts, clientBuildFS, upd, slogLogger, tlsMgr, auth, isCustomURL) + web, err := initWeb( + ctx, + opts, + clientBuildFS, + upd, + slogLogger, + tlsMgr, + auth, + confModifier, + isCustomURL, + ) fatalOnError(err) globalContext.web = web @@ -728,7 +751,7 @@ func run( fatalOnError(err) if !globalContext.firstRun { - err = initDNS(slogLogger, tlsMgr, statsDir, querylogDir) + err = initDNS(ctx, slogLogger, tlsMgr, confModifier, statsDir, querylogDir) fatalOnError(err) tlsMgr.start(ctx) @@ -736,7 +759,7 @@ func run( go func() { startErr := startDNSServer() if startErr != nil { - closeDNSServer() + closeDNSServer(ctx) fatalOnError(startErr) } }() @@ -972,7 +995,7 @@ func cleanup(ctx context.Context) { globalContext.web = nil } - err := stopDNSServer() + err := stopDNSServer(ctx) if err != nil { log.Error("stopping dns server: %s", err) } @@ -1130,7 +1153,7 @@ func cmdlineUpdate( // // TODO(e.burkov): We could probably initialize the internal resolver // separately. - err := initDNSServer(nil, nil, nil, nil, nil, nil, tlsMgr, l) + err := initDNSServer(ctx, nil, nil, nil, nil, nil, nil, tlsMgr, l, agh.EmptyConfigModifier{}) fatalOnError(err) l.InfoContext(ctx, "performing update via cli") diff --git a/internal/home/i18n.go b/internal/home/i18n.go index d49ca2fa..4b616605 100644 --- a/internal/home/i18n.go +++ b/internal/home/i18n.go @@ -64,7 +64,9 @@ func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) { } // TODO(d.kolyshev): Deprecated, remove it later. -func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) { +func (web *webAPI) handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if aghhttp.WriteTextPlainDeprecated(w, r) { return } @@ -89,9 +91,10 @@ func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) { defer config.Unlock() config.Language = lang - log.Printf("home: language is set to %s", lang) + web.logger.InfoContext(ctx, "language is updated", "lang", lang) }() - onConfigModified() + web.confModifier.Apply(ctx) + aghhttp.OK(w) } diff --git a/internal/home/profilehttp.go b/internal/home/profilehttp.go index 8c3d6ef0..41d2c18e 100644 --- a/internal/home/profilehttp.go +++ b/internal/home/profilehttp.go @@ -6,7 +6,6 @@ import ( "net/http" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" - "github.com/AdguardTeam/golibs/log" ) // Theme is an enum of all allowed UI themes. @@ -75,7 +74,9 @@ func (web *webAPI) handleGetProfile(w http.ResponseWriter, r *http.Request) { } // handlePutProfile is the handler for PUT /control/profile/update endpoint. -func handlePutProfile(w http.ResponseWriter, r *http.Request) { +func (web *webAPI) handlePutProfile(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if aghhttp.WriteTextPlainDeprecated(w, r) { return } @@ -103,10 +104,10 @@ func handlePutProfile(w http.ResponseWriter, r *http.Request) { config.Language = lang config.Theme = theme - log.Printf("home: language is set to %s", lang) - log.Printf("home: theme is set to %s", theme) + web.logger.InfoContext(ctx, "profile updated", "lang", lang, "theme", theme) }() - onConfigModified() + web.confModifier.Apply(ctx) + aghhttp.OK(w) } diff --git a/internal/home/tls.go b/internal/home/tls.go index 25ce29da..7f1d6316 100644 --- a/internal/home/tls.go +++ b/internal/home/tls.go @@ -20,6 +20,7 @@ import ( "sync" "time" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" @@ -56,9 +57,8 @@ type tlsManager struct { // conf contains the TLS configuration settings. It must not be nil. conf *tlsConfigSettings - // configModified is called when the TLS configuration is changed via an - // HTTP request. - configModified func() + // confModifier is used to update the global configuration. + confModifier agh.ConfigModifier // customCipherIDs are the ID of the cipher suites that AdGuard Home must use. customCipherIDs []uint16 @@ -73,9 +73,9 @@ type tlsManagerConfig struct { // be nil. logger *slog.Logger - // configModified is called when the TLS configuration is changed via an - // HTTP request. It must not be nil. - configModified func() + // confModifier is used to update the global configuration. It must not be + // nil. + confModifier agh.ConfigModifier // tlsSettings contains the TLS configuration settings. tlsSettings tlsConfigSettings @@ -91,12 +91,12 @@ type tlsManagerConfig struct { // [tlsManager.setWebAPI]. func newTLSManager(ctx context.Context, conf *tlsManagerConfig) (m *tlsManager, err error) { m = &tlsManager{ - logger: conf.logger, - mu: &sync.Mutex{}, - configModified: conf.configModified, - status: &tlsConfigStatus{}, - conf: &conf.tlsSettings, - servePlainDNS: conf.servePlainDNS, + logger: conf.logger, + mu: &sync.Mutex{}, + confModifier: conf.confModifier, + status: &tlsConfigStatus{}, + conf: &conf.tlsSettings, + servePlainDNS: conf.servePlainDNS, } m.rootCerts = aghtls.SystemRootCAs(ctx, conf.logger) @@ -232,7 +232,7 @@ func (m *tlsManager) reload(ctx context.Context) { m.certLastMod = fi.ModTime().UTC() - err = m.reconfigureDNSServer() + err = m.reconfigureDNSServer(ctx) if err != nil { m.logger.ErrorContext(ctx, "reconfiguring dns server", slogutil.KeyError, err) } @@ -245,7 +245,7 @@ func (m *tlsManager) reload(ctx context.Context) { // reconfigureDNSServer updates the DNS server configuration using the stored // TLS settings. m.mu is expected to be locked. -func (m *tlsManager) reconfigureDNSServer() (err error) { +func (m *tlsManager) reconfigureDNSServer(ctx context.Context) (err error) { newConf, err := newServerConfig( &config.DNS, config.Clients.Sources, @@ -253,12 +253,13 @@ func (m *tlsManager) reconfigureDNSServer() (err error) { m, httpRegister, globalContext.clients.storage, + m.confModifier, ) if err != nil { return fmt.Errorf("generating forwarding dns server config: %w", err) } - err = globalContext.dnsServer.Reconfigure(newConf) + err = globalContext.dnsServer.Reconfigure(ctx, newConf) if err != nil { return fmt.Errorf("starting forwarding dns server: %w", err) } @@ -515,7 +516,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) var restartHTTPS bool defer func() { if restartHTTPS { - m.configModified() + m.confModifier.Apply(ctx) } }() @@ -557,7 +558,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request) }() } - err = m.reconfigureDNSServer() + err = m.reconfigureDNSServer(ctx) if err != nil { m.logger.ErrorContext(ctx, "reconfiguring dns server", slogutil.KeyError, err) diff --git a/internal/home/tls_internal_test.go b/internal/home/tls_internal_test.go index ad05fd44..6fc13d06 100644 --- a/internal/home/tls_internal_test.go +++ b/internal/home/tls_internal_test.go @@ -20,6 +20,7 @@ import ( "testing" "time" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/client" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" @@ -66,9 +67,9 @@ func TestValidateCertificates(t *testing.T) { ctx := testutil.ContextWithTimeout(t, testTimeout) m, err := newTLSManager(ctx, &tlsManagerConfig{ - logger: testLogger, - configModified: func() {}, - servePlainDNS: false, + logger: testLogger, + confModifier: agh.EmptyConfigModifier{}, + servePlainDNS: false, }) require.NoError(t, err) @@ -246,8 +247,8 @@ func TestTLSManager_Reload(t *testing.T) { writeCertAndKey(t, certDER, certPath, key, keyPath) m, err := newTLSManager(ctx, &tlsManagerConfig{ - logger: testLogger, - configModified: func() {}, + logger: testLogger, + confModifier: agh.EmptyConfigModifier{}, tlsSettings: tlsConfigSettings{ Enabled: true, CertificatePath: certPath, @@ -257,7 +258,7 @@ func TestTLSManager_Reload(t *testing.T) { }) require.NoError(t, err) - web, err := initWeb(ctx, options{}, nil, nil, testLogger, nil, nil, false) + web, err := initWeb(ctx, options{}, nil, nil, testLogger, nil, nil, agh.EmptyConfigModifier{}, false) require.NoError(t, err) m.setWebAPI(web) @@ -272,7 +273,9 @@ func TestTLSManager_Reload(t *testing.T) { // The [tlsManager.reload] method will start the DNS server and it should be // stopped after the test ends. - testutil.CleanupAndRequireSuccess(t, globalContext.dnsServer.Stop) + testutil.CleanupAndRequireSuccess(t, func() (err error) { + return globalContext.dnsServer.Stop(testutil.ContextWithTimeout(t, testTimeout)) + }) conf = m.config() assertCertSerialNumber(t, conf, snAfter) @@ -285,8 +288,8 @@ func TestTLSManager_HandleTLSStatus(t *testing.T) { ) m, err := newTLSManager(ctx, &tlsManagerConfig{ - logger: testLogger, - configModified: func() {}, + logger: testLogger, + confModifier: agh.EmptyConfigModifier{}, tlsSettings: tlsConfigSettings{ Enabled: true, CertificateChain: string(testCertChainData), @@ -321,13 +324,13 @@ func TestValidateTLSSettings(t *testing.T) { ) m, err := newTLSManager(ctx, &tlsManagerConfig{ - logger: testLogger, - configModified: func() {}, - servePlainDNS: false, + logger: testLogger, + confModifier: agh.EmptyConfigModifier{}, + servePlainDNS: false, }) require.NoError(t, err) - web, err := initWeb(ctx, options{}, nil, nil, testLogger, nil, nil, false) + web, err := initWeb(ctx, options{}, nil, nil, testLogger, nil, nil, agh.EmptyConfigModifier{}, false) require.NoError(t, err) m.setWebAPI(web) @@ -420,8 +423,8 @@ func TestTLSManager_HandleTLSValidate(t *testing.T) { ) m, err := newTLSManager(ctx, &tlsManagerConfig{ - logger: testLogger, - configModified: func() {}, + logger: testLogger, + confModifier: agh.EmptyConfigModifier{}, tlsSettings: tlsConfigSettings{ Enabled: true, CertificateChain: string(testCertChainData), @@ -431,7 +434,7 @@ func TestTLSManager_HandleTLSValidate(t *testing.T) { }) require.NoError(t, err) - web, err := initWeb(ctx, options{}, nil, nil, testLogger, nil, nil, false) + web, err := initWeb(ctx, options{}, nil, nil, testLogger, nil, nil, agh.EmptyConfigModifier{}, false) require.NoError(t, err) m.setWebAPI(web) @@ -476,15 +479,17 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) { }) require.NoError(t, err) - err = globalContext.dnsServer.Prepare(&dnsforward.ServerConfig{ - TLSConf: &dnsforward.TLSConfig{}, - Config: dnsforward.Config{ - UpstreamMode: dnsforward.UpstreamModeLoadBalance, - EDNSClientSubnet: &dnsforward.EDNSClientSubnet{Enabled: false}, - ClientsContainer: dnsforward.EmptyClientsContainer{}, - }, - ServePlainDNS: true, - }) + err = globalContext.dnsServer.Prepare( + testutil.ContextWithTimeout(t, testTimeout), + &dnsforward.ServerConfig{ + TLSConf: &dnsforward.TLSConfig{}, + Config: dnsforward.Config{ + UpstreamMode: dnsforward.UpstreamModeLoadBalance, + EDNSClientSubnet: &dnsforward.EDNSClientSubnet{Enabled: false}, + ClientsContainer: dnsforward.EmptyClientsContainer{}, + }, + ServePlainDNS: true, + }) require.NoError(t, err) globalContext.clients.storage, err = client.NewStorage(ctx, &client.StorageConfig{ @@ -511,8 +516,8 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) { // Initialize the TLS manager and assert its configuration. m, err := newTLSManager(ctx, &tlsManagerConfig{ - logger: testLogger, - configModified: func() {}, + logger: testLogger, + confModifier: agh.EmptyConfigModifier{}, tlsSettings: tlsConfigSettings{ Enabled: true, CertificatePath: certPath, @@ -522,7 +527,7 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) { }) require.NoError(t, err) - web, err := initWeb(ctx, options{}, nil, nil, testLogger, nil, nil, false) + web, err := initWeb(ctx, options{}, nil, nil, testLogger, nil, nil, agh.EmptyConfigModifier{}, false) require.NoError(t, err) m.setWebAPI(web) @@ -551,7 +556,9 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) { // The [tlsManager.handleTLSConfigure] method will start the DNS server and // it should be stopped after the test ends. - testutil.CleanupAndRequireSuccess(t, globalContext.dnsServer.Stop) + testutil.CleanupAndRequireSuccess(t, func() (err error) { + return globalContext.dnsServer.Stop(testutil.ContextWithTimeout(t, testTimeout)) + }) res := &tlsConfig{ tlsConfigStatus: &tlsConfigStatus{}, diff --git a/internal/home/web.go b/internal/home/web.go index 96d4852f..40f59230 100644 --- a/internal/home/web.go +++ b/internal/home/web.go @@ -12,6 +12,7 @@ import ( "sync" "time" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/AdGuardHome/internal/updater" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/logutil/slogutil" @@ -47,6 +48,9 @@ type webConfig struct { // nil. baseLogger *slog.Logger + // confModifier is used to update the global configuration. + confModifier agh.ConfigModifier + // tlsManager contains the current configuration and state of TLS // encryption. It must not be nil. tlsManager *tlsManager @@ -104,6 +108,9 @@ type httpsServer struct { type webAPI struct { conf *webConfig + // confModifier is used to update the global configuration. + confModifier agh.ConfigModifier + // TODO(a.garipov): Refactor all these servers. httpServer *http.Server @@ -134,11 +141,12 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) { conf.logger.InfoContext(ctx, "initializing") w = &webAPI{ - conf: conf, - logger: conf.logger, - baseLogger: conf.baseLogger, - tlsManager: conf.tlsManager, - auth: conf.auth, + conf: conf, + confModifier: conf.confModifier, + logger: conf.logger, + baseLogger: conf.baseLogger, + tlsManager: conf.tlsManager, + auth: conf.auth, } clientFS := http.FileServer(http.FS(conf.clientFS)) diff --git a/internal/querylog/http.go b/internal/querylog/http.go index fb878e04..727be326 100644 --- a/internal/querylog/http.go +++ b/internal/querylog/http.go @@ -186,7 +186,7 @@ func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request) return } - defer l.conf.ConfigModified() + defer l.conf.ConfigModifier.Apply(r.Context()) l.confMu.Lock() defer l.confMu.Unlock() @@ -250,7 +250,7 @@ func (l *queryLog) handlePutQueryLogConfig(w http.ResponseWriter, r *http.Reques return } - defer l.conf.ConfigModified() + defer l.conf.ConfigModifier.Apply(r.Context()) l.confMu.Lock() defer l.confMu.Unlock() diff --git a/internal/querylog/querylog.go b/internal/querylog/querylog.go index c7350f70..c4ed53cc 100644 --- a/internal/querylog/querylog.go +++ b/internal/querylog/querylog.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/filtering" @@ -47,9 +48,9 @@ type Config struct { // Anonymizer processes the IP addresses to anonymize those if needed. Anonymizer *aghnet.IPMut - // ConfigModified is called when the configuration is changed, for example - // by HTTP requests. - ConfigModified func() + // ConfigModifier is used to update the global configuration. It must not + // be nil. + ConfigModifier agh.ConfigModifier // HTTPRegister registers an HTTP handler. HTTPRegister aghhttp.RegisterFunc diff --git a/internal/stats/http.go b/internal/stats/http.go index c2ea01d0..67b78a97 100644 --- a/internal/stats/http.go +++ b/internal/stats/http.go @@ -166,7 +166,7 @@ func (s *StatsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) { limit := time.Duration(reqData.IntervalDays) * timeutil.Day - defer s.configModified() + defer s.configModifier.Apply(ctx) s.confMu.Lock() defer s.confMu.Unlock() @@ -216,7 +216,7 @@ func (s *StatsCtx) handlePutStatsConfig(w http.ResponseWriter, r *http.Request) return } - defer s.configModified() + defer s.configModifier.Apply(ctx) s.confMu.Lock() defer s.confMu.Unlock() diff --git a/internal/stats/http_internal_test.go b/internal/stats/http_internal_test.go index b53668d6..9a53c01a 100644 --- a/internal/stats/http_internal_test.go +++ b/internal/stats/http_internal_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" @@ -27,7 +28,7 @@ func TestHandleStatsConfig(t *testing.T) { conf := Config{ Logger: slogutil.NewDiscardLogger(), UnitID: func() (id uint32) { return 0 }, - ConfigModified: func() {}, + ConfigModifier: agh.EmptyConfigModifier{}, ShouldCountClient: func([]string) bool { return true }, Filename: filepath.Join(t.TempDir(), "stats.db"), Limit: time.Hour * 24, diff --git a/internal/stats/stats.go b/internal/stats/stats.go index 3b63df5b..a1506cf4 100644 --- a/internal/stats/stats.go +++ b/internal/stats/stats.go @@ -13,6 +13,7 @@ import ( "sync/atomic" "time" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/aghos" @@ -55,9 +56,9 @@ type Config struct { // nil, the default function is used, see newUnitID. UnitID UnitIDGenFunc - // ConfigModified will be called each time the configuration changed via web - // interface. - ConfigModified func() + // ConfigModifier is used to update the global configuration. It must not + // be nil. + ConfigModifier agh.ConfigModifier // ShouldCountClient returns client's ignore setting. ShouldCountClient func([]string) bool @@ -123,9 +124,8 @@ type StatsCtx struct { // httpRegister is used to set HTTP handlers. httpRegister aghhttp.RegisterFunc - // configModified is called whenever the configuration is modified via web - // interface. - configModified func() + // configModifier is used to update the global configuration. + configModifier agh.ConfigModifier // confMu protects ignored, limit, and enabled. confMu *sync.RWMutex @@ -165,7 +165,7 @@ func New(conf Config) (s *StatsCtx, err error) { logger: conf.Logger, currMu: &sync.RWMutex{}, httpRegister: conf.HTTPRegister, - configModified: conf.ConfigModified, + configModifier: conf.ConfigModifier, filename: conf.Filename, confMu: &sync.RWMutex{},