mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2025-10-26 11:27:18 +00:00
Pull request 2442: AGDNS-3061-config-modifier
Merge in DNS/adguard-home from AGDNS-3061-config-modifier to master Squashed commit of the following: commit a0068547bd0209d12e8dbf98ddd5e4ed2545cdd0 Merge:97b798af6451255675Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Aug 5 17:39:31 2025 +0300 Merge branch 'master' into AGDNS-3061-config-modifier commit97b798af6aMerge:96d21efc9b8043e4f0Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Jul 30 20:27:41 2025 +0300 Merge branch 'master' into AGDNS-3061-config-modifier commit96d21efc98Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Jul 30 20:18:43 2025 +0300 all: imp code commit67c5608b4bAuthor: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Jul 29 20:31:19 2025 +0300 all: imp code commit52f45a9f70Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Mon Jul 28 20:32:13 2025 +0300 all: use config modifier commitd389ffd286Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Jul 25 14:20:31 2025 +0300 bamboo-specs: fix ci commit3f303ac913Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Fri Jul 25 14:18:42 2025 +0300 home: config modifier
This commit is contained in:
parent
451255675e
commit
86de4e75f0
22
internal/agh/agh.go
Normal file
22
internal/agh/agh.go
Normal file
@ -0,0 +1,22 @@
|
||||
// Package agh contains common entities and interfaces of AdGuard Home.
|
||||
package agh
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// ConfigModifier defines an interface for updating the global configuration.
|
||||
type ConfigModifier interface {
|
||||
// Apply applies changes to the global configuration.
|
||||
Apply(ctx context.Context)
|
||||
}
|
||||
|
||||
// EmptyConfigModifier is an empty [ConfigModifier] implementation that does
|
||||
// nothing.
|
||||
type EmptyConfigModifier struct{}
|
||||
|
||||
// type check
|
||||
var _ ConfigModifier = EmptyConfigModifier{}
|
||||
|
||||
// Apply implements the [ConfigModifier] for EmptyConfigModifier.
|
||||
func (em EmptyConfigModifier) Apply(ctx context.Context) {}
|
||||
@ -6,8 +6,9 @@ import (
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
nextagh "github.com/AdguardTeam/AdGuardHome/internal/next/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/rdns"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/whois"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
@ -53,9 +54,10 @@ func (w *FSWatcher) Add(name string) (err error) {
|
||||
return w.OnAdd(name)
|
||||
}
|
||||
|
||||
// Package agh
|
||||
// Package nextagh
|
||||
|
||||
// ServiceWithConfig is a fake [agh.ServiceWithConfig] implementation for tests.
|
||||
// ServiceWithConfig is a fake [nextagh.ServiceWithConfig] implementation for
|
||||
// tests.
|
||||
type ServiceWithConfig[ConfigType any] struct {
|
||||
OnStart func(ctx context.Context) (err error)
|
||||
OnShutdown func(ctx context.Context) (err error)
|
||||
@ -63,21 +65,21 @@ type ServiceWithConfig[ConfigType any] struct {
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ agh.ServiceWithConfig[struct{}] = (*ServiceWithConfig[struct{}])(nil)
|
||||
var _ nextagh.ServiceWithConfig[struct{}] = (*ServiceWithConfig[struct{}])(nil)
|
||||
|
||||
// Start implements the [agh.ServiceWithConfig] interface for
|
||||
// Start implements the [nextagh.ServiceWithConfig] interface for
|
||||
// *ServiceWithConfig.
|
||||
func (s *ServiceWithConfig[_]) Start(ctx context.Context) (err error) {
|
||||
return s.OnStart(ctx)
|
||||
}
|
||||
|
||||
// Shutdown implements the [agh.ServiceWithConfig] interface for
|
||||
// Shutdown implements the [nextagh.ServiceWithConfig] interface for
|
||||
// *ServiceWithConfig.
|
||||
func (s *ServiceWithConfig[_]) Shutdown(ctx context.Context) (err error) {
|
||||
return s.OnShutdown(ctx)
|
||||
}
|
||||
|
||||
// Config implements the [agh.ServiceWithConfig] interface for
|
||||
// Config implements the [nextagh.ServiceWithConfig] interface for
|
||||
// *ServiceWithConfig.
|
||||
func (s *ServiceWithConfig[ConfigType]) Config() (c ConfigType) {
|
||||
return s.OnConfig()
|
||||
@ -178,3 +180,16 @@ func (u *UpstreamMock) Exchange(req *dns.Msg) (resp *dns.Msg, err error) {
|
||||
func (u *UpstreamMock) Close() (err error) {
|
||||
return u.OnClose()
|
||||
}
|
||||
|
||||
// ConfigModifier is a fake [agh.ConfigModifier] implementation for tests.
|
||||
type ConfigModifier struct {
|
||||
OnApply func(ctx context.Context)
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ agh.ConfigModifier = (*ConfigModifier)(nil)
|
||||
|
||||
// Apply implements the [ConfigModifier] interface for *ConfigModifier.
|
||||
func (m *ConfigModifier) Apply(ctx context.Context) {
|
||||
m.OnApply(ctx)
|
||||
}
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
|
||||
@ -15,8 +16,9 @@ import (
|
||||
// ServerConfig is the configuration for the DHCP server. The order of YAML
|
||||
// fields is important, since the YAML configuration file follows it.
|
||||
type ServerConfig struct {
|
||||
// Called when the configuration is changed by HTTP request
|
||||
ConfigModified func() `yaml:"-"`
|
||||
// ConfModifier is used to update the global configuration. It must not be
|
||||
// nil.
|
||||
ConfModifier agh.ConfigModifier `yaml:"-"`
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
|
||||
|
||||
@ -107,7 +107,7 @@ var _ Interface = (*server)(nil)
|
||||
func Create(conf *ServerConfig) (s *server, err error) {
|
||||
s = &server{
|
||||
conf: &ServerConfig{
|
||||
ConfigModified: conf.ConfigModified,
|
||||
ConfModifier: conf.ConfModifier,
|
||||
|
||||
HTTPRegister: conf.HTTPRegister,
|
||||
|
||||
|
||||
@ -335,7 +335,7 @@ func (s *server) handleDHCPSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
s.setConfFromJSON(conf, srv4, srv6)
|
||||
s.conf.ConfigModified()
|
||||
s.conf.ConfModifier.Apply(r.Context())
|
||||
|
||||
err = s.dbLoad()
|
||||
if err != nil {
|
||||
@ -679,7 +679,7 @@ func (s *server) handleReset(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
s.conf = &ServerConfig{
|
||||
ConfigModified: s.conf.ConfigModified,
|
||||
ConfModifier: s.conf.ConfModifier,
|
||||
|
||||
HTTPRegister: s.conf.HTTPRegister,
|
||||
|
||||
@ -702,7 +702,7 @@ func (s *server) handleReset(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
s.srv6, _ = v6Create(v6conf)
|
||||
|
||||
s.conf.ConfigModified()
|
||||
s.conf.ConfModifier.Apply(r.Context())
|
||||
}
|
||||
|
||||
func (s *server) handleResetLeases(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@ -10,6 +10,7 @@ import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -84,10 +85,10 @@ func TestServer_handleDHCPStatus(t *testing.T) {
|
||||
}
|
||||
|
||||
s, err := Create(&ServerConfig{
|
||||
Enabled: true,
|
||||
Conf4: *defaultV4ServerConf(),
|
||||
DataDir: t.TempDir(),
|
||||
ConfigModified: func() {},
|
||||
Enabled: true,
|
||||
Conf4: *defaultV4ServerConf(),
|
||||
DataDir: t.TempDir(),
|
||||
ConfModifier: agh.EmptyConfigModifier{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -178,11 +179,11 @@ func TestServer_HandleUpdateStaticLease(t *testing.T) {
|
||||
}
|
||||
|
||||
s, err := Create(&ServerConfig{
|
||||
Enabled: true,
|
||||
Conf4: *defaultV4ServerConf(),
|
||||
Conf6: V6ServerConf{},
|
||||
DataDir: t.TempDir(),
|
||||
ConfigModified: func() {},
|
||||
Enabled: true,
|
||||
Conf4: *defaultV4ServerConf(),
|
||||
Conf6: V6ServerConf{},
|
||||
DataDir: t.TempDir(),
|
||||
ConfModifier: agh.EmptyConfigModifier{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -266,11 +267,11 @@ func TestServer_HandleUpdateStaticLease_validation(t *testing.T) {
|
||||
}}
|
||||
|
||||
s, err := Create(&ServerConfig{
|
||||
Enabled: true,
|
||||
Conf4: *defaultV4ServerConf(),
|
||||
Conf6: V6ServerConf{},
|
||||
DataDir: t.TempDir(),
|
||||
ConfigModified: func() {},
|
||||
Enabled: true,
|
||||
Conf4: *defaultV4ServerConf(),
|
||||
Conf6: V6ServerConf{},
|
||||
DataDir: t.TempDir(),
|
||||
ConfModifier: agh.EmptyConfigModifier{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@ -12,7 +12,6 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
"github.com/AdguardTeam/urlfilter"
|
||||
"github.com/AdguardTeam/urlfilter/filterlist"
|
||||
@ -230,6 +229,8 @@ func validateStrUniq(clients []string) (uc aghalg.UniqChecker[string], err error
|
||||
|
||||
// handleAccessSet handles requests to the POST /control/access/set endpoint.
|
||||
func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
list := &accessListJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&list)
|
||||
if err != nil {
|
||||
@ -253,14 +254,15 @@ func (s *Server) handleAccessSet(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
defer log.Debug(
|
||||
"access: updated lists: %d, %d, %d",
|
||||
len(list.AllowedClients),
|
||||
len(list.DisallowedClients),
|
||||
len(list.BlockedHosts),
|
||||
defer s.logger.DebugContext(
|
||||
ctx,
|
||||
"updated access lists",
|
||||
"allowed", len(list.AllowedClients),
|
||||
"disallowed", len(list.DisallowedClients),
|
||||
"blocked_hosts", len(list.BlockedHosts),
|
||||
)
|
||||
|
||||
defer s.conf.ConfigModified()
|
||||
defer s.conf.ConfModifier.Apply(ctx)
|
||||
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
@ -41,7 +41,13 @@ func (s *Server) HandleBefore(
|
||||
qt := q.Qtype
|
||||
host := aghnet.NormalizeDomain(q.Name)
|
||||
if s.access.isBlockedHost(host, qt) {
|
||||
log.Debug("access: request %s %s is in access blocklist", dns.Type(qt), host)
|
||||
// TODO(s.chzhen): Pass context.
|
||||
s.logger.DebugContext(
|
||||
context.TODO(),
|
||||
"request is in access blocklist",
|
||||
"dns_type", dns.Type(qt),
|
||||
"host", host,
|
||||
)
|
||||
|
||||
return s.preBlockedResponse(pctx)
|
||||
}
|
||||
|
||||
@ -9,6 +9,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -132,7 +133,7 @@ func TestServer_HandleBefore_tls(t *testing.T) {
|
||||
s.conf.DisallowedClients = tc.disallowedClients
|
||||
s.conf.BlockedHosts = tc.blockedHosts
|
||||
|
||||
err := s.Prepare(&s.conf)
|
||||
err := s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &s.conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
startDeferStop(t, s)
|
||||
|
||||
@ -201,6 +201,7 @@ func TestServer_clientIDFromDNSContext(t *testing.T) {
|
||||
srv := &Server{
|
||||
conf: ServerConfig{TLSConf: tlsConf},
|
||||
baseLogger: testLogger,
|
||||
logger: testLogger,
|
||||
}
|
||||
|
||||
var (
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
@ -11,6 +12,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghslog"
|
||||
@ -259,8 +261,9 @@ type ServerConfig struct {
|
||||
// TLSCiphers are the IDs of TLS cipher suites to use.
|
||||
TLSCiphers []uint16
|
||||
|
||||
// Called when the configuration is changed by HTTP request
|
||||
ConfigModified func()
|
||||
// ConfModifier is used to update the global configuration. It must not be
|
||||
// nil.
|
||||
ConfModifier agh.ConfigModifier
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister aghhttp.RegisterFunc
|
||||
@ -307,7 +310,7 @@ const (
|
||||
)
|
||||
|
||||
// newProxyConfig creates and validates configuration for the main proxy.
|
||||
func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
|
||||
func (s *Server) newProxyConfig(ctx context.Context) (conf *proxy.Config, err error) {
|
||||
srvConf := s.conf
|
||||
trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies)
|
||||
|
||||
@ -355,12 +358,12 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
|
||||
return nil, fmt.Errorf("bogus_nxdomain: %w", err)
|
||||
}
|
||||
|
||||
err = s.prepareTLS(conf)
|
||||
err = s.prepareTLS(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validating tls: %w", err)
|
||||
}
|
||||
|
||||
err = s.preparePlain(conf)
|
||||
err = s.preparePlain(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("validating plain: %w", err)
|
||||
}
|
||||
@ -444,7 +447,7 @@ func (s *Server) initDefaultSettings() {
|
||||
|
||||
// prepareIpsetListSettings reads and prepares the ipset configuration either
|
||||
// from a file or from the data in the configuration file.
|
||||
func (s *Server) prepareIpsetListSettings() (ipsets []string, err error) {
|
||||
func (s *Server) prepareIpsetListSettings(ctx context.Context) (ipsets []string, err error) {
|
||||
fn := s.conf.IpsetListFileName
|
||||
if fn == "" {
|
||||
return s.conf.IpsetList, nil
|
||||
@ -459,7 +462,7 @@ func (s *Server) prepareIpsetListSettings() (ipsets []string, err error) {
|
||||
ipsets = stringutil.SplitTrimmed(string(data), "\n")
|
||||
ipsets = slices.DeleteFunc(ipsets, aghnet.IsCommentOrEmpty)
|
||||
|
||||
log.Debug("dns: using %d ipset rules from file %q", len(ipsets), fn)
|
||||
s.logger.DebugContext(ctx, "using ipset rules from file", "num", len(ipsets), "file", fn)
|
||||
|
||||
return ipsets, nil
|
||||
}
|
||||
@ -629,7 +632,7 @@ func (s *Server) prepareDNSCrypt(proxyConf *proxy.Config) {
|
||||
}
|
||||
|
||||
// prepareTLS sets up the TLS configuration for the DNS proxy.
|
||||
func (s *Server) prepareTLS(proxyConf *proxy.Config) (err error) {
|
||||
func (s *Server) prepareTLS(ctx context.Context, proxyConf *proxy.Config) (err error) {
|
||||
s.prepareDNSCrypt(proxyConf)
|
||||
|
||||
if s.conf.TLSConf.Cert == nil {
|
||||
@ -653,11 +656,20 @@ func (s *Server) prepareTLS(proxyConf *proxy.Config) (err error) {
|
||||
if s.conf.TLSConf.StrictSNICheck {
|
||||
if len(cert.DNSNames) != 0 {
|
||||
s.dnsNames = cert.DNSNames
|
||||
log.Debug("dns: using certificate's SAN as DNS names: %v", cert.DNSNames)
|
||||
s.logger.DebugContext(
|
||||
ctx,
|
||||
"using certificate's SAN as DNS names",
|
||||
"dns_names", cert.DNSNames,
|
||||
)
|
||||
slices.Sort(s.dnsNames)
|
||||
} else {
|
||||
s.dnsNames = []string{cert.Subject.CommonName}
|
||||
log.Debug("dns: using certificate's CN as DNS name: %s", cert.Subject.CommonName)
|
||||
s.logger.DebugContext(
|
||||
ctx,
|
||||
"using certificate's CN as DNS name",
|
||||
"common_name",
|
||||
cert.Subject.CommonName,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -706,15 +718,22 @@ func anyNameMatches(dnsNames []string, sni string) (ok bool) {
|
||||
// If the server name (from SNI) supplied by client is incorrect - we terminate the ongoing TLS handshake.
|
||||
func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||
if s.conf.TLSConf.StrictSNICheck && !anyNameMatches(s.dnsNames, ch.ServerName) {
|
||||
log.Info("dns: tls: unknown SNI in Client Hello: %s", ch.ServerName)
|
||||
// TODO(s.chzhen): Pass context.
|
||||
s.logger.WarnContext(
|
||||
context.TODO(),
|
||||
"unknown SNI in Client Hello",
|
||||
"server_name", ch.ServerName,
|
||||
)
|
||||
|
||||
return nil, fmt.Errorf("invalid SNI")
|
||||
}
|
||||
|
||||
return s.conf.TLSConf.Cert, nil
|
||||
}
|
||||
|
||||
// preparePlain prepares the plain-DNS configuration for the DNS proxy.
|
||||
// preparePlain assumes that prepareTLS has already been called.
|
||||
func (s *Server) preparePlain(proxyConf *proxy.Config) (err error) {
|
||||
func (s *Server) preparePlain(ctx context.Context, proxyConf *proxy.Config) (err error) {
|
||||
if s.conf.ServePlainDNS {
|
||||
proxyConf.UDPListenAddr = s.conf.UDPListenAddrs
|
||||
proxyConf.TCPListenAddr = s.conf.TCPListenAddrs
|
||||
@ -732,14 +751,16 @@ func (s *Server) preparePlain(proxyConf *proxy.Config) (err error) {
|
||||
return errors.Error("disabling plain dns requires at least one encrypted protocol")
|
||||
}
|
||||
|
||||
log.Info("dnsforward: warning: plain dns is disabled")
|
||||
s.logger.WarnContext(ctx, "plain dns is disabled")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdatedProtectionStatus updates protection state, if the protection was
|
||||
// disabled temporarily. Returns the updated state of protection.
|
||||
func (s *Server) UpdatedProtectionStatus() (enabled bool, disabledUntil *time.Time) {
|
||||
func (s *Server) UpdatedProtectionStatus(
|
||||
ctx context.Context,
|
||||
) (enabled bool, disabledUntil *time.Time) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
@ -759,7 +780,7 @@ func (s *Server) UpdatedProtectionStatus() (enabled bool, disabledUntil *time.Ti
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/5661.
|
||||
if s.protectionUpdateInProgress.CompareAndSwap(false, true) {
|
||||
go s.enableProtectionAfterPause()
|
||||
go s.enableProtectionAfterPause(ctx)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
@ -767,19 +788,19 @@ func (s *Server) UpdatedProtectionStatus() (enabled bool, disabledUntil *time.Ti
|
||||
|
||||
// enableProtectionAfterPause sets the protection configuration to enabled
|
||||
// values. It is intended to be used as a goroutine.
|
||||
func (s *Server) enableProtectionAfterPause() {
|
||||
defer log.OnPanic("dns: enabling protection after pause")
|
||||
func (s *Server) enableProtectionAfterPause(ctx context.Context) {
|
||||
defer slogutil.RecoverAndLog(ctx, s.logger)
|
||||
|
||||
defer s.protectionUpdateInProgress.Store(false)
|
||||
|
||||
defer s.conf.ConfigModified()
|
||||
defer s.conf.ConfModifier.Apply(ctx)
|
||||
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
s.dnsFilter.SetProtectionStatus(true, nil)
|
||||
|
||||
log.Info("dns: protection is restarted after pause")
|
||||
s.logger.InfoContext(ctx, "protection is restarted after pause")
|
||||
}
|
||||
|
||||
// validateCacheTTL returns an error if the configuration of the cache TTL
|
||||
|
||||
@ -9,7 +9,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
)
|
||||
|
||||
@ -17,7 +16,7 @@ import (
|
||||
// addr should be a valid host:port address, where host could be a domain name
|
||||
// or an IP address.
|
||||
func (s *Server) DialContext(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
||||
log.Debug("dnsforward: dialing %q for network %q", addr, network)
|
||||
s.logger.DebugContext(ctx, "dialing", "addr", addr, "network", network)
|
||||
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
@ -45,7 +44,7 @@ func (s *Server) DialContext(ctx context.Context, network, addr string) (conn ne
|
||||
return nil, fmt.Errorf("no addresses for host %q", host)
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: resolved %q: %v", host, ips)
|
||||
s.logger.DebugContext(ctx, "resolved", "host", host, "ips", ips)
|
||||
|
||||
var dialErrs []error
|
||||
for _, ip := range ips {
|
||||
|
||||
@ -145,6 +145,10 @@ type Server struct {
|
||||
// have a prefix and must not be nil.
|
||||
baseLogger *slog.Logger
|
||||
|
||||
// logger is used to log the operation of the DNS server. It is created
|
||||
// during initialization in [NewServer].
|
||||
logger *slog.Logger
|
||||
|
||||
// dnsFilter is the DNS filter for filtering client's DNS requests and
|
||||
// responses.
|
||||
dnsFilter *filtering.DNSFilter
|
||||
@ -254,6 +258,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
queryLog: p.QueryLog,
|
||||
privateNets: p.PrivateNets,
|
||||
baseLogger: p.Logger,
|
||||
logger: p.Logger.With(slogutil.KeyPrefix, "dnsforward"),
|
||||
// TODO(e.burkov): Use some case-insensitive string comparison.
|
||||
localDomainSuffix: strings.ToLower(localDomainSuffix),
|
||||
etcHosts: etcHosts,
|
||||
@ -286,7 +291,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
|
||||
// its workers finished. But it would require the upstream.Upstream to have the
|
||||
// Close method to prevent from hanging while waiting for unresponsive server to
|
||||
// respond.
|
||||
func (s *Server) Close() {
|
||||
func (s *Server) Close(ctx context.Context) {
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
@ -296,7 +301,7 @@ func (s *Server) Close() {
|
||||
s.dnsProxy = nil
|
||||
|
||||
if err := s.ipset.close(); err != nil {
|
||||
log.Error("dnsforward: closing ipset: %s", err)
|
||||
s.logger.ErrorContext(ctx, "closing ipset", slogutil.KeyError, err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -461,18 +466,17 @@ func hostFromPTR(resp *dns.Msg) (host string, ttl time.Duration, err error) {
|
||||
}
|
||||
|
||||
// Start starts the DNS server. It must only be called after [Server.Prepare].
|
||||
func (s *Server) Start() error {
|
||||
func (s *Server) Start(ctx context.Context) error {
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
return s.startLocked()
|
||||
return s.startLocked(ctx)
|
||||
}
|
||||
|
||||
// startLocked starts the DNS server without locking. s.serverLock is expected
|
||||
// to be locked.
|
||||
func (s *Server) startLocked() error {
|
||||
// TODO(e.burkov): Use context properly.
|
||||
err := s.dnsProxy.Start(context.Background())
|
||||
func (s *Server) startLocked(ctx context.Context) error {
|
||||
err := s.dnsProxy.Start(ctx)
|
||||
if err == nil {
|
||||
s.isRunning = true
|
||||
}
|
||||
@ -482,7 +486,7 @@ func (s *Server) startLocked() error {
|
||||
|
||||
// Prepare initializes parameters of s using data from conf. conf must not be
|
||||
// nil.
|
||||
func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
func (s *Server) Prepare(ctx context.Context, conf *ServerConfig) (err error) {
|
||||
s.conf = *conf
|
||||
|
||||
// dnsFilter can be nil during application update.
|
||||
@ -496,13 +500,13 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
|
||||
|
||||
s.initDefaultSettings()
|
||||
|
||||
err = s.prepareInternalDNS()
|
||||
err = s.prepareInternalDNS(ctx)
|
||||
if err != nil {
|
||||
// Don't wrap the error, because it's informative enough as is.
|
||||
return err
|
||||
}
|
||||
|
||||
proxyConfig, err := s.newProxyConfig()
|
||||
proxyConfig, err := s.newProxyConfig(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing proxy: %w", err)
|
||||
}
|
||||
@ -633,8 +637,8 @@ func (s *Server) prepareLocalResolvers() (uc *proxy.UpstreamConfig, err error) {
|
||||
// prepareInternalDNS initializes the internal state of s before initializing
|
||||
// the primary DNS proxy instance. It assumes s.serverLock is locked or the
|
||||
// Server not running.
|
||||
func (s *Server) prepareInternalDNS() (err error) {
|
||||
ipsetList, err := s.prepareIpsetListSettings()
|
||||
func (s *Server) prepareInternalDNS(ctx context.Context) (err error) {
|
||||
ipsetList, err := s.prepareIpsetListSettings(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("preparing ipset settings: %w", err)
|
||||
}
|
||||
@ -779,27 +783,26 @@ func (s *Server) prepareInternalProxy() (err error) {
|
||||
}
|
||||
|
||||
// Stop stops the DNS server.
|
||||
func (s *Server) Stop() error {
|
||||
func (s *Server) Stop(ctx context.Context) error {
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
s.stopLocked()
|
||||
s.stopLocked(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// stopLocked stops the DNS server without locking. s.serverLock is expected to
|
||||
// be locked.
|
||||
func (s *Server) stopLocked() {
|
||||
func (s *Server) stopLocked(ctx context.Context) {
|
||||
// TODO(e.burkov, a.garipov): Return critical errors, not just log them.
|
||||
// This will require filtering all the non-critical errors in
|
||||
// [upstream.Upstream] implementations.
|
||||
|
||||
if s.dnsProxy != nil {
|
||||
// TODO(e.burkov): Use context properly.
|
||||
err := s.dnsProxy.Shutdown(context.Background())
|
||||
err := s.dnsProxy.Shutdown(ctx)
|
||||
if err != nil {
|
||||
log.Error("dnsforward: closing primary resolvers: %s", err)
|
||||
s.logger.ErrorContext(ctx, "closing primary resolvers", slogutil.KeyError, err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -848,14 +851,14 @@ func (s *Server) proxy() (p *proxy.Proxy) {
|
||||
// Reconfigure applies the new configuration to the DNS server.
|
||||
//
|
||||
// TODO(a.garipov): This whole piece of API is weird and needs to be remade.
|
||||
func (s *Server) Reconfigure(conf *ServerConfig) error {
|
||||
func (s *Server) Reconfigure(ctx context.Context, conf *ServerConfig) error {
|
||||
s.serverLock.Lock()
|
||||
defer s.serverLock.Unlock()
|
||||
|
||||
log.Info("dnsforward: starting reconfiguring server")
|
||||
defer log.Info("dnsforward: finished reconfiguring server")
|
||||
s.logger.InfoContext(ctx, "starting reconfiguring server")
|
||||
defer s.logger.InfoContext(ctx, "finished reconfiguring server")
|
||||
|
||||
s.stopLocked()
|
||||
s.stopLocked(ctx)
|
||||
|
||||
// It seems that net.Listener.Close() doesn't close file descriptors right away.
|
||||
// We wait for some time and hope that this fd will be closed.
|
||||
@ -864,7 +867,7 @@ func (s *Server) Reconfigure(conf *ServerConfig) error {
|
||||
if s.addrProc != nil {
|
||||
err := s.addrProc.Close()
|
||||
if err != nil {
|
||||
log.Error("dnsforward: closing address processor: %s", err)
|
||||
s.logger.ErrorContext(ctx, "closing address processor", slogutil.KeyError, err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -874,12 +877,12 @@ func (s *Server) Reconfigure(conf *ServerConfig) error {
|
||||
|
||||
// TODO(e.burkov): It seems an error here brings the server down, which is
|
||||
// not reliable enough.
|
||||
err := s.Prepare(conf)
|
||||
err := s.Prepare(ctx, conf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not reconfigure the server: %w", err)
|
||||
}
|
||||
|
||||
err = s.startLocked()
|
||||
err = s.startLocked(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not reconfigure the server: %w", err)
|
||||
}
|
||||
@ -908,16 +911,29 @@ func (s *Server) IsBlockedClient(ip netip.Addr, clientID string) (blocked bool,
|
||||
allowlistMode := s.access.allowlistMode()
|
||||
blockedByClientID := s.access.isBlockedClientID(clientID)
|
||||
|
||||
// TODO(s.chzhen): Pass context.
|
||||
ctx := context.TODO()
|
||||
|
||||
// Allow if at least one of the checks allows in allowlist mode, but block
|
||||
// if at least one of the checks blocks in blocklist mode.
|
||||
if allowlistMode && blockedByIP && blockedByClientID {
|
||||
log.Debug("dnsforward: client %v (id %q) is not in access allowlist", ip, clientID)
|
||||
s.logger.DebugContext(
|
||||
ctx,
|
||||
"client is not in access allowlist",
|
||||
"ip", ip,
|
||||
"client_id", clientID,
|
||||
)
|
||||
|
||||
// Return now without substituting the empty rule for the
|
||||
// clientID because the rule can't be empty here.
|
||||
return true, rule
|
||||
} else if !allowlistMode && (blockedByIP || blockedByClientID) {
|
||||
log.Debug("dnsforward: client %v (id %q) is in access blocklist", ip, clientID)
|
||||
s.logger.DebugContext(
|
||||
ctx,
|
||||
"client is in access blocklist",
|
||||
"ip", ip,
|
||||
"client_id", clientID,
|
||||
)
|
||||
|
||||
blocked = true
|
||||
}
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
"testing/fstest"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
@ -106,9 +107,11 @@ func (c *clientsContainer) ClearUpstreamCache() {
|
||||
func startDeferStop(t *testing.T, s *Server) {
|
||||
t.Helper()
|
||||
|
||||
err := s.Start()
|
||||
err := s.Start(testutil.ContextWithTimeout(t, testTimeout))
|
||||
require.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, s.Stop)
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||
return s.Stop(testutil.ContextWithTimeout(t, testTimeout))
|
||||
})
|
||||
}
|
||||
|
||||
// applyEmptyClientFiltering is a helper function for tests with
|
||||
@ -169,7 +172,7 @@ func createTestServer(
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Prepare(&forwardConf)
|
||||
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &forwardConf)
|
||||
require.NoError(t, err)
|
||||
|
||||
return s
|
||||
@ -244,7 +247,7 @@ func createTestTLS(t *testing.T, tlsConf *TLSConfig) (s *Server, certPem []byte)
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
|
||||
err = s.Prepare(&s.conf)
|
||||
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &s.conf)
|
||||
require.NoErrorf(t, err, "failed to prepare server: %s", err)
|
||||
|
||||
return s, certPem
|
||||
@ -420,7 +423,7 @@ func TestServer_timeout(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Prepare(srvConf)
|
||||
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), srvConf)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, testTimeout, s.conf.UpstreamTimeout)
|
||||
@ -439,7 +442,7 @@ func TestServer_timeout(t *testing.T) {
|
||||
Enabled: false,
|
||||
}
|
||||
s.conf.Config.ClientsContainer = EmptyClientsContainer{}
|
||||
err = s.Prepare(&s.conf)
|
||||
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &s.conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, DefaultTimeout, s.conf.UpstreamTimeout)
|
||||
@ -466,7 +469,7 @@ func TestServer_Prepare_fallbacks(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Prepare(srvConf)
|
||||
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), srvConf)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, s.dnsProxy.Fallbacks)
|
||||
|
||||
@ -566,8 +569,8 @@ func TestServerRace(t *testing.T) {
|
||||
UpstreamMode: UpstreamModeLoadBalance,
|
||||
UpstreamDNS: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||
},
|
||||
ConfigModified: func() {},
|
||||
ServePlainDNS: true,
|
||||
ConfModifier: agh.EmptyConfigModifier{},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf)
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{newGoogleUpstream()}
|
||||
@ -1104,7 +1107,7 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||
}
|
||||
|
||||
// Invalid BlockingIPv4.
|
||||
err = s.Prepare(conf)
|
||||
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), conf)
|
||||
assert.Error(t, err)
|
||||
|
||||
s.dnsFilter.SetBlockingMode(
|
||||
@ -1112,7 +1115,7 @@ func TestBlockedCustomIP(t *testing.T) {
|
||||
netip.AddrFrom4([4]byte{0, 0, 0, 1}),
|
||||
netip.MustParseAddr("::1"))
|
||||
|
||||
err = s.Prepare(conf)
|
||||
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
f.SetEnabled(true)
|
||||
@ -1264,7 +1267,7 @@ func TestRewrite(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NoError(t, s.Prepare(&ServerConfig{
|
||||
assert.NoError(t, s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &ServerConfig{
|
||||
UDPListenAddrs: []*net.UDPAddr{{}},
|
||||
TCPListenAddrs: []*net.TCPAddr{{}},
|
||||
TLSConf: &TLSConfig{},
|
||||
@ -1334,7 +1337,7 @@ func TestRewrite(t *testing.T) {
|
||||
|
||||
for _, protect := range []bool{true, false} {
|
||||
val := protect
|
||||
conf := s.getDNSConfig()
|
||||
conf := s.getDNSConfig(testutil.ContextWithTimeout(t, testTimeout))
|
||||
conf.ProtectionEnabled = &val
|
||||
s.setConfig(conf)
|
||||
|
||||
@ -1408,12 +1411,12 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
|
||||
s.conf.Config.ClientsContainer = EmptyClientsContainer{}
|
||||
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
|
||||
|
||||
err = s.Prepare(&s.conf)
|
||||
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &s.conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Start()
|
||||
err = s.Start(testutil.ContextWithTimeout(t, testTimeout))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(s.Close)
|
||||
t.Cleanup(func() { s.Close(testutil.ContextWithTimeout(t, testTimeout)) })
|
||||
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
req := createTestMessageWithType("34.12.168.192.in-addr.arpa.", dns.TypePTR)
|
||||
@ -1498,12 +1501,12 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
s.conf.Config.ClientsContainer = EmptyClientsContainer{}
|
||||
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
|
||||
|
||||
err = s.Prepare(&s.conf)
|
||||
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &s.conf)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Start()
|
||||
err = s.Start(testutil.ContextWithTimeout(t, testTimeout))
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(s.Close)
|
||||
t.Cleanup(func() { s.Close(testutil.ContextWithTimeout(t, testTimeout)) })
|
||||
|
||||
subTestFunc := func(t *testing.T) {
|
||||
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
|
||||
@ -1524,7 +1527,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
|
||||
|
||||
for _, protect := range []bool{true, false} {
|
||||
val := protect
|
||||
conf := s.getDNSConfig()
|
||||
conf := s.getDNSConfig(testutil.ContextWithTimeout(t, testTimeout))
|
||||
conf.ProtectionEnabled = &val
|
||||
s.setConfig(conf)
|
||||
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@ -15,6 +15,7 @@ import (
|
||||
// filterDNSRewriteResponse handles a single DNS rewrite response entry. It
|
||||
// returns the properly constructed answer resource record.
|
||||
func (s *Server) filterDNSRewriteResponse(
|
||||
ctx context.Context,
|
||||
req *dns.Msg,
|
||||
rr rules.RRType,
|
||||
v rules.RRValue,
|
||||
@ -27,11 +28,11 @@ func (s *Server) filterDNSRewriteResponse(
|
||||
case dns.TypeMX:
|
||||
return s.ansFromDNSRewriteMX(v, rr, req)
|
||||
case dns.TypeHTTPS, dns.TypeSVCB:
|
||||
return s.ansFromDNSRewriteSVCB(v, rr, req)
|
||||
return s.ansFromDNSRewriteSVCB(ctx, v, rr, req)
|
||||
case dns.TypeSRV:
|
||||
return s.ansFromDNSRewriteSRV(v, rr, req)
|
||||
default:
|
||||
log.Debug("don't know how to handle dns rr type %d, skipping", rr)
|
||||
s.logger.DebugContext(ctx, "unsupported dns rr type, skipping", "res_record", rr)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
@ -97,6 +98,7 @@ func (s *Server) ansFromDNSRewriteMX(
|
||||
// ansFromDNSRewriteSVCB creates a new answer resource record from the
|
||||
// SVCB/HTTPS dnsrewrite rule data.
|
||||
func (s *Server) ansFromDNSRewriteSVCB(
|
||||
ctx context.Context,
|
||||
v rules.RRValue,
|
||||
rr rules.RRType,
|
||||
req *dns.Msg,
|
||||
@ -111,10 +113,10 @@ func (s *Server) ansFromDNSRewriteSVCB(
|
||||
}
|
||||
|
||||
if rr == dns.TypeHTTPS {
|
||||
return s.genAnswerHTTPS(req, svcb), nil
|
||||
return s.genAnswerHTTPS(ctx, req, svcb), nil
|
||||
}
|
||||
|
||||
return s.genAnswerSVCB(req, svcb), nil
|
||||
return s.genAnswerSVCB(ctx, req, svcb), nil
|
||||
}
|
||||
|
||||
// ansFromDNSRewriteSRV creates a new answer resource record from the SRV
|
||||
@ -139,6 +141,7 @@ func (s *Server) ansFromDNSRewriteSRV(
|
||||
// filterDNSRewrite handles dnsrewrite filters. It constructs a DNS response
|
||||
// and sets it into pctx.Res. All parameters must not be nil.
|
||||
func (s *Server) filterDNSRewrite(
|
||||
ctx context.Context,
|
||||
req *dns.Msg,
|
||||
res *filtering.Result,
|
||||
pctx *proxy.DNSContext,
|
||||
@ -164,7 +167,7 @@ func (s *Server) filterDNSRewrite(
|
||||
values := dnsrr.Response[qtype]
|
||||
for i, v := range values {
|
||||
var ans dns.RR
|
||||
ans, err = s.filterDNSRewriteResponse(req, qtype, v)
|
||||
ans, err = s.filterDNSRewriteResponse(ctx, req, qtype, v)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dns rewrite response for %s[%d]: %w", dns.Type(qtype), i, err)
|
||||
}
|
||||
|
||||
@ -7,6 +7,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -71,7 +72,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
res := makeRes(dns.RcodeNameError, 0, nil)
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dns.RcodeNameError, d.Res.Rcode)
|
||||
@ -82,7 +83,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
res := makeRes(dns.RcodeSuccess, 0, nil)
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
@ -94,7 +95,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
res := makeRes(dns.RcodeSuccess, dns.TypeA, ip4)
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
@ -108,7 +109,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
res := makeRes(dns.RcodeSuccess, dns.TypeAAAA, ip6)
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
@ -122,7 +123,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
res := makeRes(dns.RcodeSuccess, dns.TypePTR, domain)
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
@ -136,7 +137,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
res := makeRes(dns.RcodeSuccess, dns.TypeTXT, domain)
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
@ -150,7 +151,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
res := makeRes(dns.RcodeSuccess, dns.TypeMX, mxVal)
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
@ -168,7 +169,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
res := makeRes(dns.RcodeSuccess, dns.TypeSVCB, svcbVal)
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
@ -198,7 +199,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
res := makeRes(dns.RcodeSuccess, dns.TypeHTTPS, svcbVal)
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
@ -228,7 +229,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
|
||||
res := makeRes(dns.RcodeSuccess, dns.TypeSRV, srvVal)
|
||||
d := &proxy.DNSContext{}
|
||||
|
||||
err := srv.filterDNSRewrite(req, res, d)
|
||||
err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@ -24,7 +24,10 @@ func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filter
|
||||
|
||||
// filterDNSRequest applies the dnsFilter and sets dctx.proxyCtx.Res if the
|
||||
// request was filtered.
|
||||
func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err error) {
|
||||
func (s *Server) filterDNSRequest(
|
||||
ctx context.Context,
|
||||
dctx *dnsContext,
|
||||
) (res *filtering.Result, err error) {
|
||||
pctx := dctx.proxyCtx
|
||||
req := pctx.Req
|
||||
q := req.Question[0]
|
||||
@ -44,12 +47,12 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err
|
||||
dctx.origQuestion = q
|
||||
req.Question[0].Name = dns.Fqdn(res.CanonName)
|
||||
case res.IsFiltered:
|
||||
log.Debug("dnsforward: host %q is filtered, reason: %q", host, res.Reason)
|
||||
pctx.Res = s.genDNSFilterMessage(pctx, res)
|
||||
s.logger.DebugContext(ctx, "host is filtered", "host", host, "reason", res.Reason)
|
||||
pctx.Res = s.genDNSFilterMessage(ctx, pctx, res)
|
||||
case res.Reason.In(filtering.Rewritten, filtering.FilteredSafeSearch):
|
||||
pctx.Res = s.getCNAMEWithIPs(req, res.IPList, res.CanonName)
|
||||
pctx.Res = s.getCNAMEWithIPs(ctx, req, res.IPList, res.CanonName)
|
||||
case res.Reason.In(filtering.RewrittenRule, filtering.RewrittenAutoHosts):
|
||||
if err = s.filterDNSRewrite(req, res, pctx); err != nil {
|
||||
if err = s.filterDNSRewrite(ctx, req, res, pctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@ -90,7 +93,7 @@ func (s *Server) checkHostRules(
|
||||
// dctx.proxyCtx.Res. It sets dctx.result and dctx.origResp if at least one of
|
||||
// canonical names, IP addresses, or HTTPS RR hints in it matches the filtering
|
||||
// rules, as well as sets dctx.proxyCtx.Res to the filtered response.
|
||||
func (s *Server) filterDNSResponse(dctx *dnsContext) (err error) {
|
||||
func (s *Server) filterDNSResponse(ctx context.Context, dctx *dnsContext) (err error) {
|
||||
setts := dctx.setts
|
||||
if !setts.FilteringEnabled {
|
||||
return nil
|
||||
@ -123,16 +126,27 @@ func (s *Server) filterDNSResponse(dctx *dnsContext) (err error) {
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: checked %s %s for %s", dns.Type(rrtype), host, a.Header().Name)
|
||||
s.logger.DebugContext(
|
||||
ctx,
|
||||
"checked",
|
||||
"dns_type", dns.Type(rrtype),
|
||||
"host", host,
|
||||
"name", a.Header().Name,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("filtering answer at index %d: %w", i, err)
|
||||
} else if res != nil && res.IsFiltered {
|
||||
dctx.result = res
|
||||
dctx.origResp = pctx.Res
|
||||
pctx.Res = s.genDNSFilterMessage(pctx, res)
|
||||
pctx.Res = s.genDNSFilterMessage(ctx, pctx, res)
|
||||
|
||||
log.Debug("dnsforward: matched %q by response: %q", pctx.Req.Question[0].Name, host)
|
||||
s.logger.DebugContext(
|
||||
ctx,
|
||||
"matched by response",
|
||||
"name", pctx.Req.Question[0].Name,
|
||||
"host", host,
|
||||
)
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
@ -10,6 +10,7 @@ import (
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -66,7 +67,7 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Prepare(&forwardConf)
|
||||
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &forwardConf)
|
||||
require.NoError(t, err)
|
||||
|
||||
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
|
||||
@ -347,7 +348,7 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
fltErr := s.filterDNSResponse(dctx)
|
||||
fltErr := s.filterDNSResponse(testutil.ContextWithTimeout(t, testTimeout), dctx)
|
||||
require.NoError(t, fltErr)
|
||||
|
||||
res := dctx.result
|
||||
|
||||
@ -2,6 +2,7 @@ package dnsforward
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -17,7 +18,6 @@ import (
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
@ -138,8 +138,8 @@ const (
|
||||
jsonUpstreamModeFastestAddr jsonUpstreamMode = "fastest_addr"
|
||||
)
|
||||
|
||||
func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
protectionEnabled, protectionDisabledUntil := s.UpdatedProtectionStatus()
|
||||
func (s *Server) getDNSConfig(ctx context.Context) (c *jsonDNSConfig) {
|
||||
protectionEnabled, protectionDisabledUntil := s.UpdatedProtectionStatus(ctx)
|
||||
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
@ -184,7 +184,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
|
||||
|
||||
defPTRUps, err := s.defaultLocalPTRUpstreams()
|
||||
if err != nil {
|
||||
log.Error("dnsforward: %s", err)
|
||||
s.logger.ErrorContext(ctx, "getting local ptr upstreams", slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
return &jsonDNSConfig{
|
||||
@ -240,7 +240,7 @@ func (s *Server) defaultLocalPTRUpstreams() (ups []string, err error) {
|
||||
|
||||
// handleGetConfig handles requests to the GET /control/dns_info endpoint.
|
||||
func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
resp := s.getDNSConfig()
|
||||
resp := s.getDNSConfig(r.Context())
|
||||
aghhttp.WriteJSONResponseOK(w, r, resp)
|
||||
}
|
||||
|
||||
@ -486,6 +486,8 @@ func checkInclusion(ptr *int, minN, maxN int) (err error) {
|
||||
|
||||
// handleSetConfig handles requests to the POST /control/dns_config endpoint.
|
||||
func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
req := &jsonDNSConfig{}
|
||||
err := json.NewDecoder(r.Body).Decode(req)
|
||||
if err != nil {
|
||||
@ -511,10 +513,10 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
restart := s.setConfig(req)
|
||||
s.conf.ConfigModified()
|
||||
s.conf.ConfModifier.Apply(ctx)
|
||||
|
||||
if restart {
|
||||
err = s.Reconfigure(nil)
|
||||
err = s.Reconfigure(ctx, nil)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
|
||||
}
|
||||
@ -726,7 +728,7 @@ func (s *Server) handleSetProtection(w http.ResponseWriter, r *http.Request) {
|
||||
s.dnsFilter.SetProtectionStatus(protectionReq.Enabled, disabledUntil)
|
||||
}()
|
||||
|
||||
s.conf.ConfigModified()
|
||||
s.conf.ConfModifier.Apply(r.Context())
|
||||
|
||||
aghhttp.OK(w)
|
||||
}
|
||||
|
||||
@ -17,6 +17,7 @@ import (
|
||||
"testing/fstest"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
@ -87,14 +88,16 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ConfigModified: func() {},
|
||||
ServePlainDNS: true,
|
||||
ConfModifier: agh.EmptyConfigModifier{},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf)
|
||||
s.sysResolvers = &emptySysResolvers{}
|
||||
|
||||
require.NoError(t, s.Start())
|
||||
testutil.CleanupAndRequireSuccess(t, s.Stop)
|
||||
require.NoError(t, s.Start(testutil.ContextWithTimeout(t, testTimeout)))
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||
return s.Stop(testutil.ContextWithTimeout(t, testTimeout))
|
||||
})
|
||||
|
||||
defaultConf := s.conf
|
||||
|
||||
@ -137,7 +140,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
|
||||
t.Cleanup(w.Body.Reset)
|
||||
|
||||
s.conf = tc.conf()
|
||||
s.handleGetConfig(w, nil)
|
||||
s.handleGetConfig(w, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
|
||||
cType := w.Header().Get(httphdr.ContentType)
|
||||
assert.Equal(t, aghhttp.HdrValApplicationJSON, cType)
|
||||
@ -170,17 +173,19 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
EDNSClientSubnet: &EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: EmptyClientsContainer{},
|
||||
},
|
||||
ConfigModified: func() {},
|
||||
ServePlainDNS: true,
|
||||
ConfModifier: agh.EmptyConfigModifier{},
|
||||
ServePlainDNS: true,
|
||||
}
|
||||
s := createTestServer(t, filterConf, forwardConf)
|
||||
s.sysResolvers = &emptySysResolvers{}
|
||||
|
||||
defaultConf := s.conf
|
||||
|
||||
err := s.Start()
|
||||
err := s.Start(testutil.ContextWithTimeout(t, testTimeout))
|
||||
assert.NoError(t, err)
|
||||
testutil.CleanupAndRequireSuccess(t, s.Stop)
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||
return s.Stop(testutil.ContextWithTimeout(t, testTimeout))
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
@ -297,7 +302,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
|
||||
assert.Equal(t, tc.wantSet, strings.TrimSuffix(w.Body.String(), "\n"))
|
||||
w.Body.Reset()
|
||||
|
||||
s.handleGetConfig(w, nil)
|
||||
s.handleGetConfig(w, httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
assert.JSONEq(t, string(caseData.Want), w.Body.String())
|
||||
w.Body.Reset()
|
||||
})
|
||||
|
||||
@ -121,9 +121,7 @@ func ipsFromAnswer(ans []dns.RR) (ip4s, ip6s []net.IP) {
|
||||
}
|
||||
|
||||
// process adds the resolved IP addresses to the domain's ipsets, if any.
|
||||
func (h *ipsetHandler) process(dctx *dnsContext) (rc resultCode) {
|
||||
// TODO(s.chzhen): Use passed context.
|
||||
ctx := context.TODO()
|
||||
func (h *ipsetHandler) process(ctx context.Context, dctx *dnsContext) (rc resultCode) {
|
||||
h.logger.DebugContext(ctx, "started processing")
|
||||
defer h.logger.DebugContext(ctx, "finished processing")
|
||||
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
@ -62,7 +63,7 @@ func TestIpsetCtx_process(t *testing.T) {
|
||||
ictx := &ipsetHandler{
|
||||
logger: testLogger,
|
||||
}
|
||||
rc := ictx.process(dctx)
|
||||
rc := ictx.process(testutil.ContextWithTimeout(t, testTimeout), dctx)
|
||||
assert.Equal(t, resultCodeSuccess, rc)
|
||||
|
||||
err := ictx.close()
|
||||
@ -85,7 +86,7 @@ func TestIpsetCtx_process(t *testing.T) {
|
||||
logger: testLogger,
|
||||
}
|
||||
|
||||
rc := ictx.process(dctx)
|
||||
rc := ictx.process(testutil.ContextWithTimeout(t, testTimeout), dctx)
|
||||
assert.Equal(t, resultCodeSuccess, rc)
|
||||
assert.Equal(t, []net.IP{ip4}, m.ip4s)
|
||||
assert.Empty(t, m.ip6s)
|
||||
@ -110,7 +111,7 @@ func TestIpsetCtx_process(t *testing.T) {
|
||||
logger: testLogger,
|
||||
}
|
||||
|
||||
rc := ictx.process(dctx)
|
||||
rc := ictx.process(testutil.ContextWithTimeout(t, testTimeout), dctx)
|
||||
assert.Equal(t, resultCodeSuccess, rc)
|
||||
assert.Empty(t, m.ip4s)
|
||||
assert.Equal(t, []net.IP{ip6}, m.ip6s)
|
||||
|
||||
@ -1,12 +1,13 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"slices"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@ -47,6 +48,7 @@ func ipsFromRules(resRules []*filtering.ResultRule) (ips []netip.Addr) {
|
||||
// genDNSFilterMessage generates a filtered response to req for the filtering
|
||||
// result res.
|
||||
func (s *Server) genDNSFilterMessage(
|
||||
ctx context.Context,
|
||||
dctx *proxy.DNSContext,
|
||||
res *filtering.Result,
|
||||
) (resp *dns.Msg) {
|
||||
@ -63,22 +65,27 @@ func (s *Server) genDNSFilterMessage(
|
||||
|
||||
switch res.Reason {
|
||||
case filtering.FilteredSafeBrowsing:
|
||||
return s.genBlockedHost(req, s.dnsFilter.SafeBrowsingBlockHost(), dctx)
|
||||
return s.genBlockedHost(ctx, req, s.dnsFilter.SafeBrowsingBlockHost(), dctx)
|
||||
case filtering.FilteredParental:
|
||||
return s.genBlockedHost(req, s.dnsFilter.ParentalBlockHost(), dctx)
|
||||
return s.genBlockedHost(ctx, req, s.dnsFilter.ParentalBlockHost(), dctx)
|
||||
case filtering.FilteredSafeSearch:
|
||||
// If Safe Search generated the necessary IP addresses, use them.
|
||||
// Otherwise, if there were no errors, there are no addresses for the
|
||||
// requested IP version, so produce a NODATA response.
|
||||
return s.getCNAMEWithIPs(req, ipsFromRules(res.Rules), res.CanonName)
|
||||
return s.getCNAMEWithIPs(ctx, req, ipsFromRules(res.Rules), res.CanonName)
|
||||
default:
|
||||
return s.genForBlockingMode(req, ipsFromRules(res.Rules))
|
||||
return s.genForBlockingMode(ctx, req, ipsFromRules(res.Rules))
|
||||
}
|
||||
}
|
||||
|
||||
// getCNAMEWithIPs generates a filtered response to req for with CNAME record
|
||||
// and provided ips.
|
||||
func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) (resp *dns.Msg) {
|
||||
func (s *Server) getCNAMEWithIPs(
|
||||
ctx context.Context,
|
||||
req *dns.Msg,
|
||||
ips []netip.Addr,
|
||||
cname string,
|
||||
) (resp *dns.Msg) {
|
||||
resp = s.replyCompressed(req)
|
||||
|
||||
originalName := req.Question[0].Name
|
||||
@ -94,7 +101,7 @@ func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) (
|
||||
|
||||
switch req.Question[0].Qtype {
|
||||
case dns.TypeA:
|
||||
ans = append(ans, s.genAnswersWithIPv4s(req, ips)...)
|
||||
ans = append(ans, s.genAnswersWithIPv4s(ctx, req, ips)...)
|
||||
case dns.TypeAAAA:
|
||||
for _, ip := range ips {
|
||||
if ip.Is6() {
|
||||
@ -112,24 +119,28 @@ func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) (
|
||||
|
||||
// genForBlockingMode generates a filtered response to req based on the server's
|
||||
// blocking mode.
|
||||
func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.Msg) {
|
||||
func (s *Server) genForBlockingMode(
|
||||
ctx context.Context,
|
||||
req *dns.Msg,
|
||||
ips []netip.Addr,
|
||||
) (resp *dns.Msg) {
|
||||
switch mode, bIPv4, bIPv6 := s.dnsFilter.BlockingMode(); mode {
|
||||
case filtering.BlockingModeCustomIP:
|
||||
return s.makeResponseCustomIP(req, bIPv4, bIPv6)
|
||||
return s.makeResponseCustomIP(ctx, req, bIPv4, bIPv6)
|
||||
case filtering.BlockingModeDefault:
|
||||
if len(ips) > 0 {
|
||||
return s.genResponseWithIPs(req, ips)
|
||||
return s.genResponseWithIPs(ctx, req, ips)
|
||||
}
|
||||
|
||||
return s.makeResponseNullIP(req)
|
||||
return s.makeResponseNullIP(ctx, req)
|
||||
case filtering.BlockingModeNullIP:
|
||||
return s.makeResponseNullIP(req)
|
||||
return s.makeResponseNullIP(ctx, req)
|
||||
case filtering.BlockingModeNXDOMAIN:
|
||||
return s.NewMsgNXDOMAIN(req)
|
||||
case filtering.BlockingModeREFUSED:
|
||||
return s.makeResponseREFUSED(req)
|
||||
default:
|
||||
log.Error("dnsforward: invalid blocking mode %q", mode)
|
||||
s.logger.ErrorContext(ctx, "invalid blocking mode", "mode", mode)
|
||||
|
||||
return s.replyCompressed(req)
|
||||
}
|
||||
@ -138,6 +149,7 @@ func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.M
|
||||
// makeResponseCustomIP generates a DNS response message for Custom IP blocking
|
||||
// mode with the provided IP addresses and an appropriate resource record type.
|
||||
func (s *Server) makeResponseCustomIP(
|
||||
ctx context.Context,
|
||||
req *dns.Msg,
|
||||
bIPv4 netip.Addr,
|
||||
bIPv6 netip.Addr,
|
||||
@ -150,7 +162,11 @@ func (s *Server) makeResponseCustomIP(
|
||||
default:
|
||||
// Generally shouldn't happen, since the types are checked in
|
||||
// genDNSFilterMessage.
|
||||
log.Error("dnsforward: invalid msg type %s for custom IP blocking mode", dns.Type(qt))
|
||||
s.logger.ErrorContext(
|
||||
ctx,
|
||||
"invalid message type for custom IP blocking mode",
|
||||
"dns_type", dns.Type(qt),
|
||||
)
|
||||
|
||||
return s.replyCompressed(req)
|
||||
}
|
||||
@ -234,11 +250,15 @@ func (s *Server) genAnswerTXT(req *dns.Msg, strs []string) (ans *dns.TXT) {
|
||||
// addresses and an appropriate resource record type. If any of the IPs cannot
|
||||
// be converted to the correct protocol, genResponseWithIPs returns an empty
|
||||
// response.
|
||||
func (s *Server) genResponseWithIPs(req *dns.Msg, ips []netip.Addr) (resp *dns.Msg) {
|
||||
func (s *Server) genResponseWithIPs(
|
||||
ctx context.Context,
|
||||
req *dns.Msg,
|
||||
ips []netip.Addr,
|
||||
) (resp *dns.Msg) {
|
||||
var ans []dns.RR
|
||||
switch req.Question[0].Qtype {
|
||||
case dns.TypeA:
|
||||
ans = s.genAnswersWithIPv4s(req, ips)
|
||||
ans = s.genAnswersWithIPv4s(ctx, req, ips)
|
||||
case dns.TypeAAAA:
|
||||
for _, ip := range ips {
|
||||
if ip.Is6() {
|
||||
@ -258,10 +278,14 @@ func (s *Server) genResponseWithIPs(req *dns.Msg, ips []netip.Addr) (resp *dns.M
|
||||
// genAnswersWithIPv4s generates DNS A answers provided IPv4 addresses. If any
|
||||
// of the IPs isn't an IPv4 address, genAnswersWithIPv4s logs a warning and
|
||||
// returns nil,
|
||||
func (s *Server) genAnswersWithIPv4s(req *dns.Msg, ips []netip.Addr) (ans []dns.RR) {
|
||||
func (s *Server) genAnswersWithIPv4s(
|
||||
ctx context.Context,
|
||||
req *dns.Msg,
|
||||
ips []netip.Addr,
|
||||
) (ans []dns.RR) {
|
||||
for _, ip := range ips {
|
||||
if !ip.Is4() {
|
||||
log.Info("dnsforward: warning: ip %s is not ipv4 address", ip)
|
||||
s.logger.WarnContext(ctx, "ip is not an ipv4 address", "ip", ip)
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -274,16 +298,16 @@ func (s *Server) genAnswersWithIPv4s(req *dns.Msg, ips []netip.Addr) (ans []dns.
|
||||
|
||||
// makeResponseNullIP creates a response with 0.0.0.0 for A requests, :: for
|
||||
// AAAA requests, and an empty response for other types.
|
||||
func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) {
|
||||
func (s *Server) makeResponseNullIP(ctx context.Context, req *dns.Msg) (resp *dns.Msg) {
|
||||
// Respond with the corresponding zero IP type as opposed to simply
|
||||
// using one or the other in both cases, because the IPv4 zero IP is
|
||||
// converted to a IPV6-mapped IPv4 address, while the IPv6 zero IP is
|
||||
// converted into an empty slice instead of the zero IPv4.
|
||||
switch req.Question[0].Qtype {
|
||||
case dns.TypeA:
|
||||
resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv4Unspecified()})
|
||||
resp = s.genResponseWithIPs(ctx, req, []netip.Addr{netip.IPv4Unspecified()})
|
||||
case dns.TypeAAAA:
|
||||
resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv6Unspecified()})
|
||||
resp = s.genResponseWithIPs(ctx, req, []netip.Addr{netip.IPv6Unspecified()})
|
||||
default:
|
||||
resp = s.replyCompressed(req)
|
||||
}
|
||||
@ -291,16 +315,21 @@ func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) {
|
||||
return resp
|
||||
}
|
||||
|
||||
func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSContext) *dns.Msg {
|
||||
func (s *Server) genBlockedHost(
|
||||
ctx context.Context,
|
||||
request *dns.Msg,
|
||||
newAddr string,
|
||||
d *proxy.DNSContext,
|
||||
) (msg *dns.Msg) {
|
||||
if newAddr == "" {
|
||||
log.Info("dnsforward: block host is not specified")
|
||||
s.logger.InfoContext(ctx, "block host not specified")
|
||||
|
||||
return s.NewMsgSERVFAIL(request)
|
||||
}
|
||||
|
||||
ip, err := netip.ParseAddr(newAddr)
|
||||
if err == nil {
|
||||
return s.genResponseWithIPs(request, []netip.Addr{ip})
|
||||
return s.genResponseWithIPs(ctx, request, []netip.Addr{ip})
|
||||
}
|
||||
|
||||
// look up the hostname, TODO: cache
|
||||
@ -316,14 +345,19 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
|
||||
|
||||
prx := s.proxy()
|
||||
if prx == nil {
|
||||
log.Debug("dnsforward: %s", srvClosedErr)
|
||||
s.logger.DebugContext(ctx, "getting current proxy", slogutil.KeyError, srvClosedErr)
|
||||
|
||||
return s.NewMsgSERVFAIL(request)
|
||||
}
|
||||
|
||||
err = prx.Resolve(newContext)
|
||||
if err != nil {
|
||||
log.Info("dnsforward: looking up replacement host %q: %s", newAddr, err)
|
||||
s.logger.ErrorContext(
|
||||
ctx,
|
||||
"looking up replacement host",
|
||||
"host", newAddr,
|
||||
slogutil.KeyError, err,
|
||||
)
|
||||
|
||||
return s.NewMsgSERVFAIL(request)
|
||||
}
|
||||
|
||||
@ -10,7 +10,6 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
@ -83,13 +82,16 @@ const ddrHostFQDN = "_dns.resolver.arpa."
|
||||
|
||||
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
|
||||
func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error {
|
||||
// TODO(s.chzhen): Pass context.
|
||||
ctx := context.TODO()
|
||||
|
||||
dctx := &dnsContext{
|
||||
proxyCtx: pctx,
|
||||
result: &filtering.Result{},
|
||||
startTime: time.Now(),
|
||||
}
|
||||
|
||||
type modProcessFunc func(ctx *dnsContext) (rc resultCode)
|
||||
type modProcessFunc func(ctx context.Context, dctx *dnsContext) (rc resultCode)
|
||||
|
||||
// Since (*dnsforward.Server).handleDNSRequest(...) is used as
|
||||
// proxy.(Config).RequestHandler, there is no need for additional index
|
||||
@ -108,7 +110,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error
|
||||
s.processQueryLogsAndStats,
|
||||
}
|
||||
for _, process := range mods {
|
||||
r := process(dctx)
|
||||
r := process(ctx, dctx)
|
||||
switch r {
|
||||
case resultCodeSuccess:
|
||||
// continue: call the next filter
|
||||
@ -149,12 +151,12 @@ const healthcheckFQDN = "healthcheck.adguardhome.test."
|
||||
// needed and enriches dctx with some client-specific information.
|
||||
//
|
||||
// TODO(e.burkov): Decompose into less general processors.
|
||||
func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing initial")
|
||||
defer log.Debug("dnsforward: finished processing initial")
|
||||
func (s *Server) processInitial(ctx context.Context, dctx *dnsContext) (rc resultCode) {
|
||||
s.logger.DebugContext(ctx, "started processing initial")
|
||||
defer s.logger.DebugContext(ctx, "finished processing initial")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
s.processClientIP(pctx.Addr.Addr())
|
||||
s.processClientIP(ctx, pctx.Addr.Addr())
|
||||
|
||||
q := pctx.Req.Question[0]
|
||||
qt := q.Qtype
|
||||
@ -184,16 +186,16 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
|
||||
dctx.clientID = string(s.clientIDCache.Get(key[:]))
|
||||
|
||||
// Get the client-specific filtering settings.
|
||||
dctx.protectionEnabled, _ = s.UpdatedProtectionStatus()
|
||||
dctx.protectionEnabled, _ = s.UpdatedProtectionStatus(ctx)
|
||||
dctx.setts = s.clientRequestFilteringSettings(dctx)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
// processClientIP sends the client IP address to s.addrProc, if needed.
|
||||
func (s *Server) processClientIP(addr netip.Addr) {
|
||||
func (s *Server) processClientIP(ctx context.Context, addr netip.Addr) {
|
||||
if !addr.IsValid() {
|
||||
log.Info("dnsforward: warning: bad client addr %q", addr)
|
||||
s.logger.WarnContext(ctx, "bad client address", "addr", addr)
|
||||
|
||||
return
|
||||
}
|
||||
@ -203,8 +205,7 @@ func (s *Server) processClientIP(addr netip.Addr) {
|
||||
s.serverLock.RLock()
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
// TODO(s.chzhen): Pass context.
|
||||
s.addrProc.Process(context.TODO(), addr)
|
||||
s.addrProc.Process(ctx, addr)
|
||||
}
|
||||
|
||||
// processDDRQuery responds to Discovery of Designated Resolvers (DDR) SVCB
|
||||
@ -212,9 +213,9 @@ func (s *Server) processClientIP(addr netip.Addr) {
|
||||
// current user configuration.
|
||||
//
|
||||
// See https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html.
|
||||
func (s *Server) processDDRQuery(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing ddr")
|
||||
defer log.Debug("dnsforward: finished processing ddr")
|
||||
func (s *Server) processDDRQuery(ctx context.Context, dctx *dnsContext) (rc resultCode) {
|
||||
s.logger.DebugContext(ctx, "started processing ddr")
|
||||
defer s.logger.DebugContext(ctx, "finished processing ddr")
|
||||
|
||||
if !s.conf.HandleDDR {
|
||||
return resultCodeSuccess
|
||||
@ -311,9 +312,9 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
|
||||
// the request is for AAAA.
|
||||
//
|
||||
// TODO(a.garipov): Adapt to AAAA as well.
|
||||
func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing dhcp hosts")
|
||||
defer log.Debug("dnsforward: finished processing dhcp hosts")
|
||||
func (s *Server) processDHCPHosts(ctx context.Context, dctx *dnsContext) (rc resultCode) {
|
||||
s.logger.DebugContext(ctx, "started processing dhcp hosts")
|
||||
defer s.logger.DebugContext(ctx, "finished processing dhcp hosts")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
req := pctx.Req
|
||||
@ -325,7 +326,12 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
|
||||
}
|
||||
|
||||
if !pctx.IsPrivateClient {
|
||||
log.Debug("dnsforward: %q requests for dhcp host %q", pctx.Addr, dhcpHost)
|
||||
s.logger.DebugContext(
|
||||
ctx,
|
||||
"requests for dhcp host",
|
||||
"addr", pctx.Addr,
|
||||
"dhcp_host", dhcpHost,
|
||||
)
|
||||
pctx.Res = s.NewMsgNXDOMAIN(req)
|
||||
|
||||
// Do not even put into query log.
|
||||
@ -336,12 +342,12 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
|
||||
if ip == (netip.Addr{}) {
|
||||
// Go on and process them with filters, including dnsrewrite ones, and
|
||||
// possibly route them to a domain-specific upstream.
|
||||
log.Debug("dnsforward: no dhcp record for %q", dhcpHost)
|
||||
s.logger.DebugContext(ctx, "no dhcp record", "dhcp_host", dhcpHost)
|
||||
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: dhcp record for %q is %s", dhcpHost, ip)
|
||||
s.logger.DebugContext(ctx, "dhcp record for", "dhcp_host", dhcpHost, "ip", ip)
|
||||
|
||||
resp := s.replyCompressed(req)
|
||||
switch q.Qtype {
|
||||
@ -372,9 +378,9 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
|
||||
|
||||
// processDHCPAddrs responds to PTR requests if the target IP is leased by the
|
||||
// DHCP server.
|
||||
func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing dhcp addrs")
|
||||
defer log.Debug("dnsforward: finished processing dhcp addrs")
|
||||
func (s *Server) processDHCPAddrs(ctx context.Context, dctx *dnsContext) (rc resultCode) {
|
||||
s.logger.DebugContext(ctx, "started processing dhcp addrs")
|
||||
defer s.logger.DebugContext(ctx, "finished processing dhcp addrs")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
if pctx.Res != nil {
|
||||
@ -396,7 +402,7 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
log.Debug("dnsforward: dhcp client %s is %q", addr, host)
|
||||
s.logger.DebugContext(ctx, "dhcp client", "addr", addr, "host", host)
|
||||
|
||||
resp := s.replyCompressed(req)
|
||||
ptr := &dns.PTR{
|
||||
@ -417,9 +423,12 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
|
||||
}
|
||||
|
||||
// Apply filtering logic
|
||||
func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing filtering before req")
|
||||
defer log.Debug("dnsforward: finished processing filtering before req")
|
||||
func (s *Server) processFilteringBeforeRequest(
|
||||
ctx context.Context,
|
||||
dctx *dnsContext,
|
||||
) (rc resultCode) {
|
||||
s.logger.DebugContext(ctx, "started processing filtering before request")
|
||||
defer s.logger.DebugContext(ctx, "finished processing filtering before request")
|
||||
|
||||
if dctx.proxyCtx.RequestedPrivateRDNS != (netip.Prefix{}) {
|
||||
// There is no need to filter request for locally served ARPA hostname
|
||||
@ -439,7 +448,7 @@ func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode)
|
||||
defer s.serverLock.RUnlock()
|
||||
|
||||
var err error
|
||||
if dctx.result, err = s.filterDNSRequest(dctx); err != nil {
|
||||
if dctx.result, err = s.filterDNSRequest(ctx, dctx); err != nil {
|
||||
dctx.err = err
|
||||
|
||||
return resultCodeError
|
||||
@ -458,9 +467,9 @@ func ipStringFromAddr(addr net.Addr) (ipStr string) {
|
||||
}
|
||||
|
||||
// processUpstream passes request to upstream servers and handles the response.
|
||||
func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing upstream")
|
||||
defer log.Debug("dnsforward: finished processing upstream")
|
||||
func (s *Server) processUpstream(ctx context.Context, dctx *dnsContext) (rc resultCode) {
|
||||
s.logger.DebugContext(ctx, "started processing upstream")
|
||||
defer s.logger.DebugContext(ctx, "finished processing upstream")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
req := pctx.Req
|
||||
@ -475,13 +484,17 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
|
||||
// TODO(a.garipov): Route such queries to a custom upstream for the
|
||||
// local domain name if there is one.
|
||||
name := req.Question[0].Name
|
||||
log.Debug("dnsforward: dhcp client hostname %q was not filtered", name[:len(name)-1])
|
||||
s.logger.DebugContext(
|
||||
ctx,
|
||||
"dhcp client hostname was not filtered",
|
||||
"hostname", name[:len(name)-1],
|
||||
)
|
||||
pctx.Res = s.NewMsgNXDOMAIN(req)
|
||||
|
||||
return resultCodeFinish
|
||||
}
|
||||
|
||||
s.setCustomUpstream(pctx, dctx.clientID)
|
||||
s.setCustomUpstream(ctx, pctx, dctx.clientID)
|
||||
|
||||
reqWantsDNSSEC := s.setReqAD(req)
|
||||
|
||||
@ -571,7 +584,7 @@ func (s *Server) dhcpHostFromRequest(q *dns.Question) (reqHost string) {
|
||||
}
|
||||
|
||||
// setCustomUpstream sets custom upstream settings in pctx, if necessary.
|
||||
func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
|
||||
func (s *Server) setCustomUpstream(ctx context.Context, pctx *proxy.DNSContext, clientID string) {
|
||||
if !pctx.Addr.IsValid() || s.conf.ClientsContainer == nil {
|
||||
return
|
||||
}
|
||||
@ -579,10 +592,11 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
|
||||
cliAddr := pctx.Addr.Addr()
|
||||
upsConf := s.conf.ClientsContainer.CustomUpstreamConfig(clientID, cliAddr)
|
||||
if upsConf != nil {
|
||||
log.Debug(
|
||||
"dnsforward: using custom upstreams for client with ip %s and clientid %q",
|
||||
cliAddr,
|
||||
clientID,
|
||||
s.logger.DebugContext(
|
||||
ctx,
|
||||
"using custom upstreams for client with",
|
||||
"ip", cliAddr,
|
||||
"client_id", clientID,
|
||||
)
|
||||
|
||||
pctx.CustomUpstreamConfig = upsConf
|
||||
@ -590,9 +604,9 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
|
||||
}
|
||||
|
||||
// Apply filtering logic after we have received response from upstream servers
|
||||
func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing filtering after resp")
|
||||
defer log.Debug("dnsforward: finished processing filtering after resp")
|
||||
func (s *Server) processFilteringAfterResponse(ctx context.Context, dctx *dnsContext) (rc resultCode) {
|
||||
s.logger.DebugContext(ctx, "started processing filtering after response")
|
||||
defer s.logger.DebugContext(ctx, "finished processing filtering after response")
|
||||
|
||||
switch res := dctx.result; res.Reason {
|
||||
case filtering.NotFilteredAllowList:
|
||||
@ -617,13 +631,13 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
|
||||
|
||||
return resultCodeSuccess
|
||||
default:
|
||||
return s.filterAfterResponse(dctx)
|
||||
return s.filterAfterResponse(ctx, dctx)
|
||||
}
|
||||
}
|
||||
|
||||
// filterAfterResponse returns the result of filtering the response that wasn't
|
||||
// explicitly allowed or rewritten.
|
||||
func (s *Server) filterAfterResponse(dctx *dnsContext) (res resultCode) {
|
||||
func (s *Server) filterAfterResponse(ctx context.Context, dctx *dnsContext) (res resultCode) {
|
||||
// Check the response only if it's from an upstream. Don't check the
|
||||
// response if the protection is disabled since dnsrewrite rules aren't
|
||||
// applied to it anyway.
|
||||
@ -631,7 +645,7 @@ func (s *Server) filterAfterResponse(dctx *dnsContext) (res resultCode) {
|
||||
return resultCodeSuccess
|
||||
}
|
||||
|
||||
err := s.filterDNSResponse(dctx)
|
||||
err := s.filterDNSResponse(ctx, dctx)
|
||||
if err != nil {
|
||||
dctx.err = err
|
||||
|
||||
|
||||
@ -105,7 +105,7 @@ func TestServer_ProcessInitial(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
gotRC := s.processInitial(dctx)
|
||||
gotRC := s.processInitial(testutil.ContextWithTimeout(t, testTimeout), dctx)
|
||||
assert.Equal(t, tc.wantRC, gotRC)
|
||||
assert.Equal(t, testClientAddrPort.Addr(), gotAddr)
|
||||
|
||||
@ -208,8 +208,8 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
|
||||
Addr: testClientAddrPort,
|
||||
},
|
||||
}
|
||||
|
||||
gotRC := s.processFilteringAfterResponse(dctx)
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
gotRC := s.processFilteringAfterResponse(ctx, dctx)
|
||||
assert.Equal(t, tc.wantRC, gotRC)
|
||||
assert.Equal(t, newResp(dns.RcodeSuccess, tc.req, tc.wantRespAns), dctx.proxyCtx.Res)
|
||||
})
|
||||
@ -353,7 +353,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
res := s.processDDRQuery(dctx)
|
||||
res := s.processDDRQuery(testutil.ContextWithTimeout(t, testTimeout), dctx)
|
||||
require.Equal(t, tc.wantRes, res)
|
||||
|
||||
if tc.wantRes != resultCodeFinish {
|
||||
@ -440,6 +440,7 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||
dhcpServer: dhcp,
|
||||
localDomainSuffix: localDomainSuffix,
|
||||
baseLogger: testLogger,
|
||||
logger: testLogger,
|
||||
}
|
||||
|
||||
req := &dns.Msg{
|
||||
@ -460,7 +461,7 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
res := s.processDHCPHosts(dctx)
|
||||
res := s.processDHCPHosts(testutil.ContextWithTimeout(t, testTimeout), dctx)
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
if !tc.isLocalCli {
|
||||
@ -592,6 +593,7 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||
dhcpServer: testDHCP,
|
||||
localDomainSuffix: tc.suffix,
|
||||
baseLogger: testLogger,
|
||||
logger: testLogger,
|
||||
}
|
||||
|
||||
req := (&dns.Msg{}).SetQuestion(dns.Fqdn(tc.host), tc.qtyp)
|
||||
@ -604,7 +606,7 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
res := s.processDHCPHosts(dctx)
|
||||
res := s.processDHCPHosts(testutil.ContextWithTimeout(t, testTimeout), dctx)
|
||||
pctx := dctx.proxyCtx
|
||||
assert.Equal(t, tc.wantRes, res)
|
||||
require.NoError(t, dctx.err)
|
||||
@ -812,9 +814,9 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) {
|
||||
ServePlainDNS: true,
|
||||
},
|
||||
)
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
pctx := newPrxCtx()
|
||||
|
||||
rc := s.processUpstream(&dnsContext{proxyCtx: pctx})
|
||||
rc := s.processUpstream(ctx, &dnsContext{proxyCtx: pctx})
|
||||
require.Equal(t, resultCodeSuccess, rc)
|
||||
require.NotEmpty(t, pctx.Res.Answer)
|
||||
ptr := testutil.RequireTypeAssert[*dns.PTR](t, pctx.Res.Answer[0])
|
||||
@ -844,7 +846,8 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) {
|
||||
)
|
||||
pctx := newPrxCtx()
|
||||
|
||||
rc := s.processUpstream(&dnsContext{proxyCtx: pctx})
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
rc := s.processUpstream(ctx, &dnsContext{proxyCtx: pctx})
|
||||
require.Equal(t, resultCodeError, rc)
|
||||
require.Empty(t, pctx.Res.Answer)
|
||||
})
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
@ -9,14 +10,13 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Write Stats data and logs
|
||||
func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
|
||||
log.Debug("dnsforward: started processing querylog and stats")
|
||||
defer log.Debug("dnsforward: finished processing querylog and stats")
|
||||
func (s *Server) processQueryLogsAndStats(ctx context.Context, dctx *dnsContext) (rc resultCode) {
|
||||
s.logger.DebugContext(ctx, "started processing querylog and stats")
|
||||
defer s.logger.DebugContext(ctx, "finished processing querylog and stats")
|
||||
|
||||
pctx := dctx.proxyCtx
|
||||
q := pctx.Req.Question[0]
|
||||
@ -27,7 +27,7 @@ func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
|
||||
s.anonymizer.Load()(ip)
|
||||
ipStr := net.IP(ip).String()
|
||||
|
||||
log.Debug("dnsforward: client ip for stats and querylog: %s", ipStr)
|
||||
s.logger.DebugContext(ctx, "client ip for stats and querylog", "ip", ipStr)
|
||||
|
||||
ids := []string{ipStr}
|
||||
if dctx.clientID != "" {
|
||||
@ -47,24 +47,26 @@ func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
|
||||
if s.shouldLog(host, qt, cl, ids) {
|
||||
s.logQuery(dctx, ip, processingTime)
|
||||
} else {
|
||||
log.Debug(
|
||||
"dnsforward: request %s %s %q from %s ignored; not adding to querylog",
|
||||
dns.Class(cl),
|
||||
dns.Type(qt),
|
||||
host,
|
||||
ipStr,
|
||||
s.logger.DebugContext(
|
||||
ctx,
|
||||
"not adding to querylog",
|
||||
"dns_class", dns.Class(cl),
|
||||
"dns_type", dns.Type(qt),
|
||||
"host", host,
|
||||
"ip", ipStr,
|
||||
)
|
||||
}
|
||||
|
||||
if s.shouldCountStat(host, qt, cl, ids) {
|
||||
s.updateStats(dctx, ipStr, processingTime)
|
||||
} else {
|
||||
log.Debug(
|
||||
"dnsforward: request %s %s %q from %s ignored; not counting in stats",
|
||||
dns.Class(cl),
|
||||
dns.Type(qt),
|
||||
host,
|
||||
ipStr,
|
||||
s.logger.DebugContext(
|
||||
ctx,
|
||||
"not counting in stats",
|
||||
"dns_class", dns.Class(cl),
|
||||
"dns_type", dns.Type(qt),
|
||||
"host", host,
|
||||
"ip", ipStr,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/stats"
|
||||
"github.com/AdguardTeam/dnsproxy/proxy"
|
||||
"github.com/AdguardTeam/dnsproxy/upstream"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -203,6 +204,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
|
||||
st := &testStats{}
|
||||
srv := &Server{
|
||||
baseLogger: testLogger,
|
||||
logger: testLogger,
|
||||
queryLog: ql,
|
||||
stats: st,
|
||||
anonymizer: aghnet.NewIPMut(nil),
|
||||
@ -229,7 +231,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
|
||||
clientID: tc.clientID,
|
||||
}
|
||||
|
||||
code := srv.processQueryLogsAndStats(dctx)
|
||||
code := srv.processQueryLogsAndStats(testutil.ContextWithTimeout(t, testTimeout), dctx)
|
||||
assert.Equal(t, tc.wantCode, code)
|
||||
assert.Equal(t, tc.wantLogProto, ql.lastParams.ClientProto)
|
||||
assert.Equal(t, tc.wantStatClient, st.lastEntry.Client)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package dnsforward
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"net"
|
||||
"strconv"
|
||||
@ -14,9 +15,9 @@ import (
|
||||
//
|
||||
// See the comment on genAnswerSVCB for a list of current restrictions on
|
||||
// parameter values.
|
||||
func (s *Server) genAnswerHTTPS(req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.HTTPS) {
|
||||
func (s *Server) genAnswerHTTPS(ctx context.Context, req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.HTTPS) {
|
||||
ans = &dns.HTTPS{
|
||||
SVCB: *s.genAnswerSVCB(req, svcb),
|
||||
SVCB: *s.genAnswerSVCB(ctx, req, svcb),
|
||||
}
|
||||
|
||||
ans.Hdr.Rrtype = dns.TypeHTTPS
|
||||
@ -163,7 +164,11 @@ var svcbKeyHandlers = map[string]svcbKeyHandler{
|
||||
// ipv4hint="127.0.0.1,127.0.0.2" // Unsupported.
|
||||
//
|
||||
// TODO(a.garipov): Support all of these.
|
||||
func (s *Server) genAnswerSVCB(req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.SVCB) {
|
||||
func (s *Server) genAnswerSVCB(
|
||||
ctx context.Context,
|
||||
req *dns.Msg,
|
||||
svcb *rules.DNSSVCB,
|
||||
) (ans *dns.SVCB) {
|
||||
ans = &dns.SVCB{
|
||||
Hdr: s.hdr(req, dns.TypeSVCB),
|
||||
Priority: svcb.Priority,
|
||||
@ -177,7 +182,7 @@ func (s *Server) genAnswerSVCB(req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.SVCB
|
||||
for k, valStr := range svcb.Params {
|
||||
handler, ok := svcbKeyHandlers[k]
|
||||
if !ok {
|
||||
log.Debug("unknown svcb/https key %q, ignoring", k)
|
||||
s.logger.DebugContext(ctx, "unknown svcb/https key, ignoring", "key", k)
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/urlfilter/rules"
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -152,14 +153,14 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
|
||||
want := &dns.HTTPS{SVCB: *tc.want}
|
||||
want.Hdr.Rrtype = dns.TypeHTTPS
|
||||
|
||||
got := s.genAnswerHTTPS(req, tc.svcb)
|
||||
got := s.genAnswerHTTPS(testutil.ContextWithTimeout(t, testTimeout), req, tc.svcb)
|
||||
assert.Equal(t, want, got)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("svcb", func(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := s.genAnswerSVCB(req, tc.svcb)
|
||||
got := s.genAnswerSVCB(testutil.ContextWithTimeout(t, testTimeout), req, tc.svcb)
|
||||
assert.Equal(t, tc.want, got)
|
||||
})
|
||||
})
|
||||
|
||||
@ -159,6 +159,8 @@ func (d *DNSFilter) handleBlockedServicesList(w http.ResponseWriter, r *http.Req
|
||||
//
|
||||
// Deprecated: Use handleBlockedServicesUpdate.
|
||||
func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
list := []string{}
|
||||
err := json.NewDecoder(r.Body).Decode(&list)
|
||||
if err != nil {
|
||||
@ -172,10 +174,10 @@ func (d *DNSFilter) handleBlockedServicesSet(w http.ResponseWriter, r *http.Requ
|
||||
defer d.confMu.Unlock()
|
||||
|
||||
d.conf.BlockedServices.IDs = list
|
||||
d.logger.DebugContext(r.Context(), "updated blocked services list", "len", len(list))
|
||||
d.logger.DebugContext(ctx, "updated blocked services list", "len", len(list))
|
||||
}()
|
||||
|
||||
d.conf.ConfigModified()
|
||||
d.conf.ConfModifier.Apply(ctx)
|
||||
}
|
||||
|
||||
// handleBlockedServicesGet is the handler for the GET
|
||||
@ -195,6 +197,8 @@ func (d *DNSFilter) handleBlockedServicesGet(w http.ResponseWriter, r *http.Requ
|
||||
// handleBlockedServicesUpdate is the handler for the PUT
|
||||
// /control/blocked_services/update HTTP API.
|
||||
func (d *DNSFilter) handleBlockedServicesUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
bsvc := &BlockedServices{}
|
||||
err := json.NewDecoder(r.Body).Decode(bsvc)
|
||||
if err != nil {
|
||||
@ -221,7 +225,7 @@ func (d *DNSFilter) handleBlockedServicesUpdate(w http.ResponseWriter, r *http.R
|
||||
d.conf.BlockedServices = bsvc
|
||||
}()
|
||||
|
||||
d.logger.DebugContext(r.Context(), "updated blocked services schedule", "len", len(bsvc.IDs))
|
||||
d.logger.DebugContext(ctx, "updated blocked services schedule", "len", len(bsvc.IDs))
|
||||
|
||||
d.conf.ConfigModified()
|
||||
d.conf.ConfModifier.Apply(ctx)
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
|
||||
@ -104,8 +105,9 @@ type Config struct {
|
||||
// TODO(e.burkov): Move it to dnsforward entirely.
|
||||
EtcHosts hostsfile.Storage `yaml:"-"`
|
||||
|
||||
// Called when the configuration is changed by HTTP request
|
||||
ConfigModified func() `yaml:"-"`
|
||||
// ConfModifier is used to update the global configuration. It must not be
|
||||
// nil.
|
||||
ConfModifier agh.ConfigModifier `yaml:"-"`
|
||||
|
||||
// Register an HTTP handler
|
||||
HTTPRegister aghhttp.RegisterFunc `yaml:"-"`
|
||||
|
||||
@ -133,7 +133,7 @@ func (d *DNSFilter) handleFilteringAddURL(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
d.conf.ConfigModified()
|
||||
d.conf.ConfModifier.Apply(r.Context())
|
||||
d.EnableFilters(true)
|
||||
|
||||
_, err = fmt.Fprintf(w, "OK %d rules\n", filt.RulesCount)
|
||||
@ -202,7 +202,7 @@ func (d *DNSFilter) handleFilteringRemoveURL(w http.ResponseWriter, r *http.Requ
|
||||
d.logger.InfoContext(ctx, "deleted filter", "id", deleted.ID)
|
||||
}()
|
||||
|
||||
d.conf.ConfigModified()
|
||||
d.conf.ConfModifier.Apply(ctx)
|
||||
d.EnableFilters(true)
|
||||
|
||||
// NOTE: The old files "filter.txt.old" aren't deleted. It's not really
|
||||
@ -264,7 +264,7 @@ func (d *DNSFilter) handleFilteringSetURL(w http.ResponseWriter, r *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
d.conf.ConfigModified()
|
||||
d.conf.ConfModifier.Apply(r.Context())
|
||||
if restart {
|
||||
d.EnableFilters(true)
|
||||
}
|
||||
@ -289,7 +289,7 @@ func (d *DNSFilter) handleFilteringSetRules(w http.ResponseWriter, r *http.Reque
|
||||
}
|
||||
|
||||
d.conf.UserRules = req.Rules
|
||||
d.conf.ConfigModified()
|
||||
d.conf.ConfModifier.Apply(r.Context())
|
||||
d.EnableFilters(true)
|
||||
}
|
||||
|
||||
@ -403,7 +403,7 @@ func (d *DNSFilter) handleFilteringConfig(w http.ResponseWriter, r *http.Request
|
||||
d.conf.FiltersUpdateIntervalHours = req.Interval
|
||||
}()
|
||||
|
||||
d.conf.ConfigModified()
|
||||
d.conf.ConfModifier.Apply(r.Context())
|
||||
d.EnableFilters(true)
|
||||
}
|
||||
|
||||
@ -571,14 +571,14 @@ func protectedBool(mu *sync.RWMutex, ptr *bool) (val bool) {
|
||||
// /control/safebrowsing/enable HTTP API.
|
||||
func (d *DNSFilter) handleSafeBrowsingEnable(w http.ResponseWriter, r *http.Request) {
|
||||
setProtectedBool(d.confMu, &d.conf.SafeBrowsingEnabled, true)
|
||||
d.conf.ConfigModified()
|
||||
d.conf.ConfModifier.Apply(r.Context())
|
||||
}
|
||||
|
||||
// handleSafeBrowsingDisable is the handler for the POST
|
||||
// /control/safebrowsing/disable HTTP API.
|
||||
func (d *DNSFilter) handleSafeBrowsingDisable(w http.ResponseWriter, r *http.Request) {
|
||||
setProtectedBool(d.confMu, &d.conf.SafeBrowsingEnabled, false)
|
||||
d.conf.ConfigModified()
|
||||
d.conf.ConfModifier.Apply(r.Context())
|
||||
}
|
||||
|
||||
// handleSafeBrowsingStatus is the handler for the GET
|
||||
@ -597,14 +597,14 @@ func (d *DNSFilter) handleSafeBrowsingStatus(w http.ResponseWriter, r *http.Requ
|
||||
// HTTP API.
|
||||
func (d *DNSFilter) handleParentalEnable(w http.ResponseWriter, r *http.Request) {
|
||||
setProtectedBool(d.confMu, &d.conf.ParentalEnabled, true)
|
||||
d.conf.ConfigModified()
|
||||
d.conf.ConfModifier.Apply(r.Context())
|
||||
}
|
||||
|
||||
// handleParentalDisable is the handler for the POST /control/parental/disable
|
||||
// HTTP API.
|
||||
func (d *DNSFilter) handleParentalDisable(w http.ResponseWriter, r *http.Request) {
|
||||
setProtectedBool(d.confMu, &d.conf.ParentalEnabled, false)
|
||||
d.conf.ConfigModified()
|
||||
d.conf.ConfModifier.Apply(r.Context())
|
||||
}
|
||||
|
||||
// handleParentalStatus is the handler for the GET /control/parental/status
|
||||
|
||||
@ -2,6 +2,7 @@ package filtering
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@ -11,6 +12,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
@ -103,6 +105,10 @@ func TestDNSFilter_handleFilteringSetURL(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
confModifiedCalled := false
|
||||
confModifier := &aghtest.ConfigModifier{}
|
||||
confModifier.OnApply = func(_ context.Context) {
|
||||
confModifiedCalled = true
|
||||
}
|
||||
d, err := New(&Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
FilteringEnabled: true,
|
||||
@ -110,8 +116,8 @@ func TestDNSFilter_handleFilteringSetURL(t *testing.T) {
|
||||
HTTPClient: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
},
|
||||
ConfigModified: func() { confModifiedCalled = true },
|
||||
DataDir: filtersDir,
|
||||
ConfModifier: confModifier,
|
||||
DataDir: filtersDir,
|
||||
}, nil)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(d.Close)
|
||||
@ -183,13 +189,15 @@ func TestDNSFilter_handleSafeBrowsingStatus(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
handlers := make(map[string]http.Handler)
|
||||
confModifier := &aghtest.ConfigModifier{}
|
||||
confModifier.OnApply = func(_ context.Context) {
|
||||
testutil.RequireSend(testutil.PanicT{}, confModCh, struct{}{}, testTimeout)
|
||||
}
|
||||
|
||||
d, err := New(&Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
ConfigModified: func() {
|
||||
testutil.RequireSend(testutil.PanicT{}, confModCh, struct{}{}, testTimeout)
|
||||
},
|
||||
DataDir: filtersDir,
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
ConfModifier: confModifier,
|
||||
DataDir: filtersDir,
|
||||
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
|
||||
handlers[url] = handler
|
||||
},
|
||||
@ -268,13 +276,15 @@ func TestDNSFilter_handleParentalStatus(t *testing.T) {
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
handlers := make(map[string]http.Handler)
|
||||
confModifier := &aghtest.ConfigModifier{}
|
||||
confModifier.OnApply = func(_ context.Context) {
|
||||
testutil.RequireSend(testutil.PanicT{}, confModCh, struct{}{}, testTimeout)
|
||||
}
|
||||
|
||||
d, err := New(&Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
ConfigModified: func() {
|
||||
testutil.RequireSend(testutil.PanicT{}, confModCh, struct{}{}, testTimeout)
|
||||
},
|
||||
DataDir: filtersDir,
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
ConfModifier: confModifier,
|
||||
DataDir: filtersDir,
|
||||
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
|
||||
handlers[url] = handler
|
||||
},
|
||||
|
||||
@ -36,6 +36,8 @@ func (d *DNSFilter) handleRewriteList(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// handleRewriteAdd is the handler for the POST /control/rewrite/add HTTP API.
|
||||
func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
rwJSON := rewriteEntryJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&rwJSON)
|
||||
if err != nil {
|
||||
@ -49,7 +51,7 @@ func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
|
||||
Answer: rwJSON.Answer,
|
||||
}
|
||||
|
||||
err = rw.normalize(r.Context(), d.logger)
|
||||
err = rw.normalize(ctx, d.logger)
|
||||
if err != nil {
|
||||
// Shouldn't happen currently, since normalize only returns a non-nil
|
||||
// error when a rewrite is nil, but be change-proof.
|
||||
@ -64,7 +66,7 @@ func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
d.conf.Rewrites = append(d.conf.Rewrites, rw)
|
||||
d.logger.DebugContext(
|
||||
r.Context(),
|
||||
ctx,
|
||||
"added rewrite element",
|
||||
"domain", rw.Domain,
|
||||
"answer", rw.Answer,
|
||||
@ -72,12 +74,14 @@ func (d *DNSFilter) handleRewriteAdd(w http.ResponseWriter, r *http.Request) {
|
||||
)
|
||||
}()
|
||||
|
||||
d.conf.ConfigModified()
|
||||
d.conf.ConfModifier.Apply(ctx)
|
||||
}
|
||||
|
||||
// handleRewriteDelete is the handler for the POST /control/rewrite/delete HTTP
|
||||
// API.
|
||||
func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
jsent := rewriteEntryJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&jsent)
|
||||
if err != nil {
|
||||
@ -92,28 +96,27 @@ func (d *DNSFilter) handleRewriteDelete(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
arr := []*LegacyRewrite{}
|
||||
|
||||
func() {
|
||||
d.confMu.Lock()
|
||||
defer d.confMu.Unlock()
|
||||
defer d.conf.ConfModifier.Apply(ctx)
|
||||
|
||||
for _, ent := range d.conf.Rewrites {
|
||||
if ent.equal(entDel) {
|
||||
d.logger.DebugContext(
|
||||
r.Context(),
|
||||
"removed rewrite element",
|
||||
"domain", ent.Domain,
|
||||
"answer", ent.Answer,
|
||||
)
|
||||
|
||||
continue
|
||||
}
|
||||
d.confMu.Lock()
|
||||
defer d.confMu.Unlock()
|
||||
|
||||
for _, ent := range d.conf.Rewrites {
|
||||
if !ent.equal(entDel) {
|
||||
arr = append(arr, ent)
|
||||
}
|
||||
d.conf.Rewrites = arr
|
||||
}()
|
||||
|
||||
d.conf.ConfigModified()
|
||||
continue
|
||||
}
|
||||
|
||||
d.logger.DebugContext(
|
||||
ctx,
|
||||
"removed rewrite element",
|
||||
"domain", ent.Domain,
|
||||
"answer", ent.Answer,
|
||||
)
|
||||
}
|
||||
|
||||
d.conf.Rewrites = arr
|
||||
}
|
||||
|
||||
// rewriteUpdateJSON is a struct for JSON object with rewrite rule update info.
|
||||
@ -125,6 +128,8 @@ type rewriteUpdateJSON struct {
|
||||
// handleRewriteUpdate is the handler for the PUT /control/rewrite/update HTTP
|
||||
// API.
|
||||
func (d *DNSFilter) handleRewriteUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
updateJSON := rewriteUpdateJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&updateJSON)
|
||||
if err != nil {
|
||||
@ -143,7 +148,7 @@ func (d *DNSFilter) handleRewriteUpdate(w http.ResponseWriter, r *http.Request)
|
||||
Answer: updateJSON.Update.Answer,
|
||||
}
|
||||
|
||||
err = rwAdd.normalize(r.Context(), d.logger)
|
||||
err = rwAdd.normalize(ctx, d.logger)
|
||||
if err != nil {
|
||||
// Shouldn't happen currently, since normalize only returns a non-nil
|
||||
// error when a rewrite is nil, but be change-proof.
|
||||
@ -155,7 +160,7 @@ func (d *DNSFilter) handleRewriteUpdate(w http.ResponseWriter, r *http.Request)
|
||||
index := -1
|
||||
defer func() {
|
||||
if index >= 0 {
|
||||
d.conf.ConfigModified()
|
||||
d.conf.ConfModifier.Apply(ctx)
|
||||
}
|
||||
}()
|
||||
|
||||
@ -171,7 +176,6 @@ func (d *DNSFilter) handleRewriteUpdate(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
d.conf.Rewrites = slices.Replace(d.conf.Rewrites, index, index+1, rwAdd)
|
||||
|
||||
ctx := r.Context()
|
||||
d.logger.DebugContext(
|
||||
ctx,
|
||||
"removed rewrite element",
|
||||
|
||||
@ -2,6 +2,7 @@ package filtering_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
@ -9,6 +10,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
@ -148,20 +150,17 @@ func TestDNSFilter_handleRewriteHTTP(t *testing.T) {
|
||||
}}
|
||||
|
||||
for _, tc := range testCases {
|
||||
onConfModified := func() {
|
||||
if !tc.wantConfMod {
|
||||
panic("config modified has been fired")
|
||||
}
|
||||
|
||||
testutil.RequireSend(testutil.PanicT{}, confModCh, struct{}{}, testTimeout)
|
||||
}
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
handlers := make(map[string]http.Handler)
|
||||
confModifier := &aghtest.ConfigModifier{}
|
||||
confModifier.OnApply = func(_ context.Context) {
|
||||
require.Truef(t, tc.wantConfMod, "config modified has been fired")
|
||||
testutil.RequireSend(testutil.PanicT{}, confModCh, struct{}{}, testTimeout)
|
||||
}
|
||||
|
||||
d, err := filtering.New(&filtering.Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
ConfigModified: onConfModified,
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
ConfModifier: confModifier,
|
||||
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
|
||||
handlers[url] = handler
|
||||
},
|
||||
|
||||
@ -13,7 +13,7 @@ import (
|
||||
// Deprecated: Use handleSafeSearchSettings.
|
||||
func (d *DNSFilter) handleSafeSearchEnable(w http.ResponseWriter, r *http.Request) {
|
||||
setProtectedBool(d.confMu, &d.conf.SafeSearchConf.Enabled, true)
|
||||
d.conf.ConfigModified()
|
||||
d.conf.ConfModifier.Apply(r.Context())
|
||||
}
|
||||
|
||||
// handleSafeSearchDisable is the handler for POST /control/safesearch/disable
|
||||
@ -22,7 +22,7 @@ func (d *DNSFilter) handleSafeSearchEnable(w http.ResponseWriter, r *http.Reques
|
||||
// Deprecated: Use handleSafeSearchSettings.
|
||||
func (d *DNSFilter) handleSafeSearchDisable(w http.ResponseWriter, r *http.Request) {
|
||||
setProtectedBool(d.confMu, &d.conf.SafeSearchConf.Enabled, false)
|
||||
d.conf.ConfigModified()
|
||||
d.conf.ConfModifier.Apply(r.Context())
|
||||
}
|
||||
|
||||
// handleSafeSearchStatus is the handler for GET /control/safesearch/status
|
||||
@ -42,6 +42,8 @@ func (d *DNSFilter) handleSafeSearchStatus(w http.ResponseWriter, r *http.Reques
|
||||
// handleSafeSearchSettings is the handler for PUT /control/safesearch/settings
|
||||
// HTTP API.
|
||||
func (d *DNSFilter) handleSafeSearchSettings(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
req := &SafeSearchConfig{}
|
||||
err := json.NewDecoder(r.Body).Decode(req)
|
||||
if err != nil {
|
||||
@ -51,7 +53,7 @@ func (d *DNSFilter) handleSafeSearchSettings(w http.ResponseWriter, r *http.Requ
|
||||
}
|
||||
|
||||
conf := *req
|
||||
err = d.safeSearch.Update(r.Context(), conf)
|
||||
err = d.safeSearch.Update(ctx, conf)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "updating: %s", err)
|
||||
|
||||
@ -65,7 +67,7 @@ func (d *DNSFilter) handleSafeSearchSettings(w http.ResponseWriter, r *http.Requ
|
||||
d.conf.SafeSearchConf = conf
|
||||
}()
|
||||
|
||||
d.conf.ConfigModified()
|
||||
d.conf.ConfModifier.Apply(ctx)
|
||||
|
||||
aghhttp.OK(w)
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghuser"
|
||||
"github.com/AdguardTeam/golibs/httphdr"
|
||||
@ -327,7 +328,17 @@ func TestAuth_ServeHTTP_firstRun(t *testing.T) {
|
||||
globalContext.mux = mux
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
web, err := initWeb(ctx, options{}, nil, nil, testLogger, nil, nil, false)
|
||||
web, err := initWeb(
|
||||
ctx,
|
||||
options{},
|
||||
nil,
|
||||
nil,
|
||||
testLogger,
|
||||
nil,
|
||||
nil,
|
||||
agh.EmptyConfigModifier{},
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
globalContext.web = web
|
||||
@ -477,13 +488,23 @@ func TestAuth_ServeHTTP_auth(t *testing.T) {
|
||||
globalContext.mux = http.NewServeMux()
|
||||
|
||||
tlsMgr, err := newTLSManager(testutil.ContextWithTimeout(t, testTimeout), &tlsManagerConfig{
|
||||
logger: testLogger,
|
||||
configModified: func() {},
|
||||
logger: testLogger,
|
||||
confModifier: agh.EmptyConfigModifier{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
web, err := initWeb(ctx, options{}, nil, nil, testLogger, tlsMgr, auth, false)
|
||||
web, err := initWeb(
|
||||
ctx,
|
||||
options{},
|
||||
nil,
|
||||
nil,
|
||||
testLogger,
|
||||
tlsMgr,
|
||||
auth,
|
||||
agh.EmptyConfigModifier{},
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
globalContext.web = web
|
||||
@ -625,7 +646,16 @@ func TestAuth_ServeHTTP_logout(t *testing.T) {
|
||||
globalContext.mux = http.NewServeMux()
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
web, err := initWeb(ctx, options{}, nil, nil, testLogger, nil, auth, false)
|
||||
web, err := initWeb(ctx,
|
||||
options{},
|
||||
nil,
|
||||
nil,
|
||||
testLogger,
|
||||
nil,
|
||||
auth,
|
||||
agh.EmptyConfigModifier{},
|
||||
false,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
globalContext.web = web
|
||||
|
||||
@ -9,6 +9,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/arpdb"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
@ -39,6 +40,10 @@ type clientsContainer struct {
|
||||
// settings.
|
||||
clientChecker BlockedClientChecker
|
||||
|
||||
// confModifier is used to update the global configuration. It must not be
|
||||
// nil.
|
||||
confModifier agh.ConfigModifier
|
||||
|
||||
// lock protects all fields.
|
||||
//
|
||||
// TODO(a.garipov): Use a pointer and describe which fields are protected in
|
||||
@ -52,11 +57,6 @@ type clientsContainer struct {
|
||||
// safeSearchCacheTTL is the TTL of the safe search cache to use for
|
||||
// persistent clients.
|
||||
safeSearchCacheTTL time.Duration
|
||||
|
||||
// testing is a flag that disables some features for internal tests.
|
||||
//
|
||||
// TODO(a.garipov): Awful. Remove.
|
||||
testing bool
|
||||
}
|
||||
|
||||
// BlockedClientChecker checks if a client is blocked by the current access
|
||||
@ -78,6 +78,7 @@ func (clients *clientsContainer) Init(
|
||||
arpDB arpdb.Interface,
|
||||
filteringConf *filtering.Config,
|
||||
sigHdlr *signalHandler,
|
||||
confModifier agh.ConfigModifier,
|
||||
) (err error) {
|
||||
// TODO(s.chzhen): Refactor it.
|
||||
if clients.storage != nil {
|
||||
@ -88,6 +89,7 @@ func (clients *clientsContainer) Init(
|
||||
clients.logger = baseLogger.With(slogutil.KeyPrefix, "client_container")
|
||||
clients.safeSearchCacheSize = filteringConf.SafeSearchCacheSize
|
||||
clients.safeSearchCacheTTL = time.Minute * time.Duration(filteringConf.CacheTime)
|
||||
clients.confModifier = confModifier
|
||||
|
||||
confClients := make([]*client.Persistent, 0, len(objects))
|
||||
for i, o := range objects {
|
||||
@ -141,10 +143,6 @@ var webHandlersRegistered = false
|
||||
|
||||
// Start starts the clients container.
|
||||
func (clients *clientsContainer) Start(ctx context.Context) (err error) {
|
||||
if clients.testing {
|
||||
return
|
||||
}
|
||||
|
||||
if !webHandlersRegistered {
|
||||
webHandlersRegistered = true
|
||||
clients.registerWebHandlers()
|
||||
|
||||
@ -3,6 +3,7 @@ package home
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
@ -14,9 +15,7 @@ import (
|
||||
func newClientsContainer(t *testing.T) (c *clientsContainer) {
|
||||
t.Helper()
|
||||
|
||||
c = &clientsContainer{
|
||||
testing: true,
|
||||
}
|
||||
c = &clientsContainer{}
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
err := c.Init(
|
||||
@ -30,6 +29,7 @@ func newClientsContainer(t *testing.T) (c *clientsContainer) {
|
||||
Logger: testLogger,
|
||||
},
|
||||
newSignalHandler(testLogger, nil, nil),
|
||||
agh.EmptyConfigModifier{},
|
||||
)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -326,6 +326,8 @@ func clientToJSON(c *client.Persistent) (cj *clientJSON) {
|
||||
|
||||
// handleAddClient is the handler for POST /control/clients/add HTTP API.
|
||||
func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
cj := clientJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&cj)
|
||||
if err != nil {
|
||||
@ -334,27 +336,27 @@ func (clients *clientsContainer) handleAddClient(w http.ResponseWriter, r *http.
|
||||
return
|
||||
}
|
||||
|
||||
c, err := clients.jsonToClient(r.Context(), cj, nil)
|
||||
c, err := clients.jsonToClient(ctx, cj, nil)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = clients.storage.Add(r.Context(), c)
|
||||
err = clients.storage.Add(ctx, c)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !clients.testing {
|
||||
onConfigModified()
|
||||
}
|
||||
clients.confModifier.Apply(ctx)
|
||||
}
|
||||
|
||||
// handleDelClient is the handler for POST /control/clients/delete HTTP API.
|
||||
func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
cj := clientJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&cj)
|
||||
if err != nil {
|
||||
@ -369,15 +371,13 @@ func (clients *clientsContainer) handleDelClient(w http.ResponseWriter, r *http.
|
||||
return
|
||||
}
|
||||
|
||||
if !clients.storage.RemoveByName(r.Context(), cj.Name) {
|
||||
if !clients.storage.RemoveByName(ctx, cj.Name) {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "Client not found")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !clients.testing {
|
||||
onConfigModified()
|
||||
}
|
||||
clients.confModifier.Apply(ctx)
|
||||
}
|
||||
|
||||
// updateJSON contains the name and data of the updated persistent client.
|
||||
@ -390,6 +390,8 @@ type updateJSON struct {
|
||||
//
|
||||
// TODO(s.chzhen): Accept updated parameters instead of whole structure.
|
||||
func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
dj := updateJSON{}
|
||||
err := json.NewDecoder(r.Body).Decode(&dj)
|
||||
if err != nil {
|
||||
@ -404,23 +406,21 @@ func (clients *clientsContainer) handleUpdateClient(w http.ResponseWriter, r *ht
|
||||
return
|
||||
}
|
||||
|
||||
c, err := clients.jsonToClient(r.Context(), dj.Data, nil)
|
||||
c, err := clients.jsonToClient(ctx, dj.Data, nil)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
err = clients.storage.Update(r.Context(), dj.Name, c)
|
||||
err = clients.storage.Update(ctx, dj.Name, c)
|
||||
if err != nil {
|
||||
aghhttp.Error(r, w, http.StatusBadRequest, "%s", err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !clients.testing {
|
||||
onConfigModified()
|
||||
}
|
||||
clients.confModifier.Apply(ctx)
|
||||
}
|
||||
|
||||
// handleFindClient is the handler for GET /control/clients/find HTTP API.
|
||||
|
||||
@ -4,12 +4,14 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtls"
|
||||
@ -23,6 +25,7 @@ import (
|
||||
"github.com/AdguardTeam/dnsproxy/fastip"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
@ -838,3 +841,47 @@ func validateTLSCipherIDs(cipherIDs []string) (err error) {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// defaultConfigModifier is a default [agh.ConfigModifier] implementation.
|
||||
type defaultConfigModifier struct {
|
||||
auth *auth
|
||||
config *configuration
|
||||
logger *slog.Logger
|
||||
tlsMgr *tlsManager
|
||||
}
|
||||
|
||||
// newDefaultConfigModifier returns the new properly initialized
|
||||
// *defaultConfigModifier. All arguments must not be nil.
|
||||
//
|
||||
// TODO(s.chzhen): Consider using configuration struct.
|
||||
func newDefaultConfigModifier(
|
||||
conf *configuration,
|
||||
l *slog.Logger,
|
||||
) (cm *defaultConfigModifier) {
|
||||
return &defaultConfigModifier{
|
||||
config: conf,
|
||||
logger: l,
|
||||
}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ agh.ConfigModifier = (*defaultConfigModifier)(nil)
|
||||
|
||||
// Apply implements the [agh.ConfigModifier] interface for
|
||||
// *defaultConfigModifier.
|
||||
func (cm *defaultConfigModifier) Apply(ctx context.Context) {
|
||||
err := cm.config.write(cm.tlsMgr, cm.auth)
|
||||
if err != nil {
|
||||
cm.logger.ErrorContext(ctx, "writing config", slogutil.KeyError, err)
|
||||
}
|
||||
}
|
||||
|
||||
// setAuth sets the auth parameters used by Apply.
|
||||
func (cm *defaultConfigModifier) setAuth(a *auth) {
|
||||
cm.auth = a
|
||||
}
|
||||
|
||||
// setTLSManager sets the TLS manager used by Apply.
|
||||
func (cm *defaultConfigModifier) setTLSManager(m *tlsManager) {
|
||||
cm.tlsMgr = m
|
||||
}
|
||||
|
||||
@ -115,6 +115,8 @@ type statusResponse struct {
|
||||
}
|
||||
|
||||
func (web *webAPI) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
dnsAddrs, err := collectDNSAddresses(web.tlsManager)
|
||||
if err != nil {
|
||||
// Don't add a lot of formatting, since the error is already
|
||||
@ -125,14 +127,14 @@ func (web *webAPI) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
var (
|
||||
fltConf *dnsforward.Config
|
||||
protectionDisabledUntil *time.Time
|
||||
protectionEnabled bool
|
||||
fltConf *dnsforward.Config
|
||||
protDisabledUntil *time.Time
|
||||
protEnabled bool
|
||||
)
|
||||
if globalContext.dnsServer != nil {
|
||||
fltConf = &dnsforward.Config{}
|
||||
globalContext.dnsServer.WriteDiskConfig(fltConf)
|
||||
protectionEnabled, protectionDisabledUntil = globalContext.dnsServer.UpdatedProtectionStatus()
|
||||
protEnabled, protDisabledUntil = globalContext.dnsServer.UpdatedProtectionStatus(ctx)
|
||||
}
|
||||
|
||||
var resp statusResponse
|
||||
@ -141,11 +143,11 @@ func (web *webAPI) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
defer config.RUnlock()
|
||||
|
||||
var protectionDisabledDuration int64
|
||||
if protectionDisabledUntil != nil {
|
||||
if protDisabledUntil != nil {
|
||||
// Make sure that we don't send negative numbers to the frontend,
|
||||
// since enough time might have passed to make the difference less
|
||||
// than zero.
|
||||
protectionDisabledDuration = max(0, time.Until(*protectionDisabledUntil).Milliseconds())
|
||||
protectionDisabledDuration = max(0, time.Until(*protDisabledUntil).Milliseconds())
|
||||
}
|
||||
|
||||
resp = statusResponse{
|
||||
@ -155,7 +157,7 @@ func (web *webAPI) handleStatus(w http.ResponseWriter, r *http.Request) {
|
||||
DNSPort: config.DNS.Port,
|
||||
HTTPPort: config.HTTPConfig.Address.Port(),
|
||||
ProtectionDisabledDuration: protectionDisabledDuration,
|
||||
ProtectionEnabled: protectionEnabled,
|
||||
ProtectionEnabled: protEnabled,
|
||||
IsRunning: isRunning(),
|
||||
}
|
||||
}()
|
||||
@ -175,10 +177,10 @@ func registerControlHandlers(web *webAPI) {
|
||||
httpRegister(http.MethodPost, "/control/update", web.handleUpdate)
|
||||
|
||||
httpRegister(http.MethodGet, "/control/status", web.handleStatus)
|
||||
httpRegister(http.MethodPost, "/control/i18n/change_language", handleI18nChangeLanguage)
|
||||
httpRegister(http.MethodPost, "/control/i18n/change_language", web.handleI18nChangeLanguage)
|
||||
httpRegister(http.MethodGet, "/control/i18n/current_language", handleI18nCurrentLanguage)
|
||||
httpRegister(http.MethodGet, "/control/profile", web.handleGetProfile)
|
||||
httpRegister(http.MethodPut, "/control/profile/update", handlePutProfile)
|
||||
httpRegister(http.MethodPut, "/control/profile/update", web.handlePutProfile)
|
||||
|
||||
// No auth is necessary for DoH/DoT configurations
|
||||
globalContext.mux.HandleFunc("/apple/doh.mobileconfig", postInstall(handleMobileConfigDoH))
|
||||
|
||||
@ -15,6 +15,7 @@ import (
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
@ -455,7 +456,7 @@ func (web *webAPI) handleInstallConfigure(w http.ResponseWriter, r *http.Request
|
||||
// moment we'll allow setting up TLS in the initial configuration or the
|
||||
// configuration itself will use HTTPS protocol, because the underlying
|
||||
// functions potentially restart the HTTPS server.
|
||||
err = startMods(ctx, web.baseLogger, web.tlsManager)
|
||||
err = startMods(ctx, web.baseLogger, web.tlsManager, web.confModifier)
|
||||
if err != nil {
|
||||
globalContext.firstRun = true
|
||||
copyInstallSettings(config, curConfig)
|
||||
@ -532,13 +533,18 @@ func decodeApplyConfigReq(r io.Reader) (req *applyConfigReq, restartHTTP bool, e
|
||||
|
||||
// startMods initializes and starts the DNS server after installation.
|
||||
// baseLogger and tlsMgr must not be nil.
|
||||
func startMods(ctx context.Context, baseLogger *slog.Logger, tlsMgr *tlsManager) (err error) {
|
||||
func startMods(
|
||||
ctx context.Context,
|
||||
baseLogger *slog.Logger,
|
||||
tlsMgr *tlsManager,
|
||||
confModifier agh.ConfigModifier,
|
||||
) (err error) {
|
||||
statsDir, querylogDir, err := checkStatsAndQuerylogDirs(&globalContext, config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = initDNS(baseLogger, tlsMgr, statsDir, querylogDir)
|
||||
err = initDNS(ctx, baseLogger, tlsMgr, confModifier, statsDir, querylogDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -547,7 +553,7 @@ func startMods(ctx context.Context, baseLogger *slog.Logger, tlsMgr *tlsManager)
|
||||
|
||||
err = startDNSServer()
|
||||
if err != nil {
|
||||
closeDNSServer()
|
||||
closeDNSServer(ctx)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@ -12,6 +12,7 @@ import (
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
@ -38,23 +39,15 @@ const (
|
||||
defaultPortTLS uint16 = 853
|
||||
)
|
||||
|
||||
// Called by other modules when configuration is changed
|
||||
//
|
||||
// TODO(s.chzhen): Remove this after refactoring.
|
||||
func onConfigModified() {
|
||||
err := config.write(globalContext.tls, globalContext.auth)
|
||||
if err != nil {
|
||||
log.Error("writing config: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// initDNS updates all the fields of the [globalContext] needed to initialize
|
||||
// the DNS server and initializes it at last. It also must not be called unless
|
||||
// [config] and [globalContext] are initialized. baseLogger and tlsMgr must not
|
||||
// be nil.
|
||||
// [config] and [globalContext] are initialized. baseLogger, tlsMgr and
|
||||
// confModfier must not be nil.
|
||||
func initDNS(
|
||||
ctx context.Context,
|
||||
baseLogger *slog.Logger,
|
||||
tlsMgr *tlsManager,
|
||||
confModifier agh.ConfigModifier,
|
||||
statsDir string,
|
||||
querylogDir string,
|
||||
) (err error) {
|
||||
@ -64,7 +57,7 @@ func initDNS(
|
||||
Logger: baseLogger.With(slogutil.KeyPrefix, "stats"),
|
||||
Filename: filepath.Join(statsDir, "stats.db"),
|
||||
Limit: time.Duration(config.Stats.Interval),
|
||||
ConfigModified: onConfigModified,
|
||||
ConfigModifier: confModifier,
|
||||
HTTPRegister: httpRegister,
|
||||
Enabled: config.Stats.Enabled,
|
||||
ShouldCountClient: globalContext.clients.shouldCountClient,
|
||||
@ -84,7 +77,7 @@ func initDNS(
|
||||
conf := querylog.Config{
|
||||
Logger: baseLogger.With(slogutil.KeyPrefix, "querylog"),
|
||||
Anonymizer: anonymizer,
|
||||
ConfigModified: onConfigModified,
|
||||
ConfigModifier: confModifier,
|
||||
HTTPRegister: httpRegister,
|
||||
FindClient: globalContext.clients.findMultiple,
|
||||
BaseDir: querylogDir,
|
||||
@ -113,6 +106,7 @@ func initDNS(
|
||||
}
|
||||
|
||||
return initDNSServer(
|
||||
ctx,
|
||||
globalContext.filters,
|
||||
globalContext.stats,
|
||||
globalContext.queryLog,
|
||||
@ -121,6 +115,7 @@ func initDNS(
|
||||
httpRegister,
|
||||
tlsMgr,
|
||||
baseLogger,
|
||||
confModifier,
|
||||
)
|
||||
}
|
||||
|
||||
@ -131,6 +126,7 @@ func initDNS(
|
||||
//
|
||||
// TODO(e.burkov): Use [dnsforward.DNSCreateParams] as a parameter.
|
||||
func initDNSServer(
|
||||
ctx context.Context,
|
||||
filters *filtering.DNSFilter,
|
||||
sts stats.Interface,
|
||||
qlog querylog.QueryLog,
|
||||
@ -139,6 +135,7 @@ func initDNSServer(
|
||||
httpReg aghhttp.RegisterFunc,
|
||||
tlsMgr *tlsManager,
|
||||
l *slog.Logger,
|
||||
confModifier agh.ConfigModifier,
|
||||
) (err error) {
|
||||
globalContext.dnsServer, err = dnsforward.NewServer(dnsforward.DNSCreateParams{
|
||||
Logger: l,
|
||||
@ -153,7 +150,7 @@ func initDNSServer(
|
||||
})
|
||||
defer func() {
|
||||
if err != nil {
|
||||
closeDNSServer()
|
||||
closeDNSServer(ctx)
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
@ -169,6 +166,7 @@ func initDNSServer(
|
||||
tlsMgr,
|
||||
httpReg,
|
||||
globalContext.clients.storage,
|
||||
confModifier,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("newServerConfig: %w", err)
|
||||
@ -176,12 +174,12 @@ func initDNSServer(
|
||||
|
||||
// Try to prepare the server with disabled private RDNS resolution if it
|
||||
// failed to prepare as is. See TODO on [dnsforward.PrivateRDNSError].
|
||||
err = globalContext.dnsServer.Prepare(dnsConf)
|
||||
err = globalContext.dnsServer.Prepare(ctx, dnsConf)
|
||||
if privRDNSErr := (&dnsforward.PrivateRDNSError{}); errors.As(err, &privRDNSErr) {
|
||||
log.Info("WARNING: %s; trying to disable private RDNS resolution", err)
|
||||
|
||||
dnsConf.UsePrivateRDNS = false
|
||||
err = globalContext.dnsServer.Prepare(dnsConf)
|
||||
err = globalContext.dnsServer.Prepare(ctx, dnsConf)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@ -245,6 +243,7 @@ func newServerConfig(
|
||||
tlsMgr *tlsManager,
|
||||
httpReg aghhttp.RegisterFunc,
|
||||
clientsContainer dnsforward.ClientsContainer,
|
||||
confModifier agh.ConfigModifier,
|
||||
) (newConf *dnsforward.ServerConfig, err error) {
|
||||
hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()})
|
||||
|
||||
@ -264,7 +263,7 @@ func newServerConfig(
|
||||
TLSAllowUnencryptedDoH: tlsConf.AllowUnencryptedDoH,
|
||||
UpstreamTimeout: time.Duration(dnsConf.UpstreamTimeout),
|
||||
TLSv12Roots: tlsMgr.rootCerts,
|
||||
ConfigModified: onConfigModified,
|
||||
ConfModifier: confModifier,
|
||||
HTTPRegister: httpReg,
|
||||
LocalPTRResolvers: dnsConf.PrivateRDNSResolvers,
|
||||
UseDNS64: dnsConf.UseDNS64,
|
||||
@ -454,7 +453,7 @@ func startDNSServer() error {
|
||||
return fmt.Errorf("starting clients container: %w", err)
|
||||
}
|
||||
|
||||
err = globalContext.dnsServer.Start()
|
||||
err = globalContext.dnsServer.Start(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting dns server: %w", err)
|
||||
}
|
||||
@ -470,30 +469,30 @@ func startDNSServer() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func stopDNSServer() (err error) {
|
||||
func stopDNSServer(ctx context.Context) (err error) {
|
||||
if !isRunning() {
|
||||
return nil
|
||||
}
|
||||
|
||||
err = globalContext.dnsServer.Stop()
|
||||
err = globalContext.dnsServer.Stop(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("stopping forwarding dns server: %w", err)
|
||||
}
|
||||
|
||||
err = globalContext.clients.close(context.TODO())
|
||||
err = globalContext.clients.close(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("closing clients container: %w", err)
|
||||
}
|
||||
|
||||
closeDNSServer()
|
||||
closeDNSServer(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func closeDNSServer() {
|
||||
func closeDNSServer(ctx context.Context) {
|
||||
// DNS forward module must be closed BEFORE stats or queryLog because it depends on them
|
||||
if globalContext.dnsServer != nil {
|
||||
globalContext.dnsServer.Close()
|
||||
globalContext.dnsServer.Close(ctx)
|
||||
globalContext.dnsServer = nil
|
||||
}
|
||||
|
||||
@ -509,8 +508,7 @@ func closeDNSServer() {
|
||||
}
|
||||
|
||||
if globalContext.queryLog != nil {
|
||||
// TODO(s.chzhen): Pass context.
|
||||
err := globalContext.queryLog.Shutdown(context.TODO())
|
||||
err := globalContext.queryLog.Shutdown(ctx)
|
||||
if err != nil {
|
||||
log.Error("closing query log: %s", err)
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
@ -297,12 +298,13 @@ func initContextClients(
|
||||
ctx context.Context,
|
||||
logger *slog.Logger,
|
||||
sigHdlr *signalHandler,
|
||||
confModifier agh.ConfigModifier,
|
||||
) (err error) {
|
||||
//lint:ignore SA1019 Migration is not over.
|
||||
config.DHCP.WorkDir = globalContext.workDir
|
||||
config.DHCP.DataDir = globalContext.getDataDir()
|
||||
config.DHCP.HTTPRegister = httpRegister
|
||||
config.DHCP.ConfigModified = onConfigModified
|
||||
config.DHCP.ConfModifier = confModifier
|
||||
|
||||
globalContext.dhcpServer, err = dhcpd.Create(config.DHCP)
|
||||
if globalContext.dhcpServer == nil || err != nil {
|
||||
@ -327,6 +329,7 @@ func initContextClients(
|
||||
arpDB,
|
||||
config.Filtering,
|
||||
sigHdlr,
|
||||
confModifier,
|
||||
)
|
||||
}
|
||||
|
||||
@ -377,6 +380,7 @@ func setupDNSFilteringConf(
|
||||
baseLogger *slog.Logger,
|
||||
conf *filtering.Config,
|
||||
tlsMgr *tlsManager,
|
||||
confModifier agh.ConfigModifier,
|
||||
) (err error) {
|
||||
const (
|
||||
dnsTimeout = 3 * time.Second
|
||||
@ -398,7 +402,7 @@ func setupDNSFilteringConf(
|
||||
conf.EtcHosts = nil
|
||||
}
|
||||
|
||||
conf.ConfigModified = onConfigModified
|
||||
conf.ConfModifier = confModifier
|
||||
conf.HTTPRegister = httpRegister
|
||||
conf.DataDir = globalContext.getDataDir()
|
||||
conf.Filters = slices.Clone(config.Filters)
|
||||
@ -564,6 +568,7 @@ func initWeb(
|
||||
baseLogger *slog.Logger,
|
||||
tlsMgr *tlsManager,
|
||||
auth *auth,
|
||||
confModifier agh.ConfigModifier,
|
||||
isCustomUpdURL bool,
|
||||
) (web *webAPI, err error) {
|
||||
logger := baseLogger.With(slogutil.KeyPrefix, "webapi")
|
||||
@ -583,11 +588,12 @@ func initWeb(
|
||||
disableUpdate := !isUpdateEnabled(ctx, baseLogger, &opts, isCustomUpdURL)
|
||||
|
||||
webConf := &webConfig{
|
||||
updater: upd,
|
||||
logger: logger,
|
||||
baseLogger: baseLogger,
|
||||
tlsManager: tlsMgr,
|
||||
auth: auth,
|
||||
updater: upd,
|
||||
logger: logger,
|
||||
baseLogger: baseLogger,
|
||||
confModifier: confModifier,
|
||||
tlsManager: tlsMgr,
|
||||
auth: auth,
|
||||
|
||||
clientFS: clientFS,
|
||||
|
||||
@ -661,25 +667,31 @@ func run(
|
||||
// data first, but also to avoid relying on automatic Go init() function.
|
||||
filtering.InitModule(ctx, slogLogger)
|
||||
|
||||
err = initContextClients(ctx, slogLogger, sigHdlr)
|
||||
confModifier := newDefaultConfigModifier(
|
||||
config,
|
||||
slogLogger.With(slogutil.KeyPrefix, "config_modifier"),
|
||||
)
|
||||
|
||||
err = initContextClients(ctx, slogLogger, sigHdlr, confModifier)
|
||||
fatalOnError(err)
|
||||
|
||||
tlsMgrLogger := slogLogger.With(slogutil.KeyPrefix, "tls_manager")
|
||||
|
||||
tlsMgr, err := newTLSManager(ctx, &tlsManagerConfig{
|
||||
logger: tlsMgrLogger,
|
||||
configModified: onConfigModified,
|
||||
tlsSettings: config.TLS,
|
||||
servePlainDNS: config.DNS.ServePlainDNS,
|
||||
logger: tlsMgrLogger,
|
||||
confModifier: confModifier,
|
||||
tlsSettings: config.TLS,
|
||||
servePlainDNS: config.DNS.ServePlainDNS,
|
||||
})
|
||||
if err != nil {
|
||||
tlsMgrLogger.ErrorContext(ctx, "initializing", slogutil.KeyError, err)
|
||||
onConfigModified()
|
||||
confModifier.Apply(ctx)
|
||||
}
|
||||
|
||||
globalContext.tls = tlsMgr
|
||||
confModifier.setTLSManager(tlsMgr)
|
||||
|
||||
err = setupDNSFilteringConf(ctx, slogLogger, config.Filtering, tlsMgr)
|
||||
err = setupDNSFilteringConf(ctx, slogLogger, config.Filtering, tlsMgr, confModifier)
|
||||
fatalOnError(err)
|
||||
|
||||
err = setupOpts(opts)
|
||||
@ -715,8 +727,19 @@ func run(
|
||||
fatalOnError(err)
|
||||
|
||||
globalContext.auth = auth
|
||||
confModifier.setAuth(auth)
|
||||
|
||||
web, err := initWeb(ctx, opts, clientBuildFS, upd, slogLogger, tlsMgr, auth, isCustomURL)
|
||||
web, err := initWeb(
|
||||
ctx,
|
||||
opts,
|
||||
clientBuildFS,
|
||||
upd,
|
||||
slogLogger,
|
||||
tlsMgr,
|
||||
auth,
|
||||
confModifier,
|
||||
isCustomURL,
|
||||
)
|
||||
fatalOnError(err)
|
||||
|
||||
globalContext.web = web
|
||||
@ -728,7 +751,7 @@ func run(
|
||||
fatalOnError(err)
|
||||
|
||||
if !globalContext.firstRun {
|
||||
err = initDNS(slogLogger, tlsMgr, statsDir, querylogDir)
|
||||
err = initDNS(ctx, slogLogger, tlsMgr, confModifier, statsDir, querylogDir)
|
||||
fatalOnError(err)
|
||||
|
||||
tlsMgr.start(ctx)
|
||||
@ -736,7 +759,7 @@ func run(
|
||||
go func() {
|
||||
startErr := startDNSServer()
|
||||
if startErr != nil {
|
||||
closeDNSServer()
|
||||
closeDNSServer(ctx)
|
||||
fatalOnError(startErr)
|
||||
}
|
||||
}()
|
||||
@ -972,7 +995,7 @@ func cleanup(ctx context.Context) {
|
||||
globalContext.web = nil
|
||||
}
|
||||
|
||||
err := stopDNSServer()
|
||||
err := stopDNSServer(ctx)
|
||||
if err != nil {
|
||||
log.Error("stopping dns server: %s", err)
|
||||
}
|
||||
@ -1130,7 +1153,7 @@ func cmdlineUpdate(
|
||||
//
|
||||
// TODO(e.burkov): We could probably initialize the internal resolver
|
||||
// separately.
|
||||
err := initDNSServer(nil, nil, nil, nil, nil, nil, tlsMgr, l)
|
||||
err := initDNSServer(ctx, nil, nil, nil, nil, nil, nil, tlsMgr, l, agh.EmptyConfigModifier{})
|
||||
fatalOnError(err)
|
||||
|
||||
l.InfoContext(ctx, "performing update via cli")
|
||||
|
||||
@ -64,7 +64,9 @@ func handleI18nCurrentLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// TODO(d.kolyshev): Deprecated, remove it later.
|
||||
func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
func (web *webAPI) handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
if aghhttp.WriteTextPlainDeprecated(w, r) {
|
||||
return
|
||||
}
|
||||
@ -89,9 +91,10 @@ func handleI18nChangeLanguage(w http.ResponseWriter, r *http.Request) {
|
||||
defer config.Unlock()
|
||||
|
||||
config.Language = lang
|
||||
log.Printf("home: language is set to %s", lang)
|
||||
web.logger.InfoContext(ctx, "language is updated", "lang", lang)
|
||||
}()
|
||||
|
||||
onConfigModified()
|
||||
web.confModifier.Apply(ctx)
|
||||
|
||||
aghhttp.OK(w)
|
||||
}
|
||||
|
||||
@ -6,7 +6,6 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
)
|
||||
|
||||
// Theme is an enum of all allowed UI themes.
|
||||
@ -75,7 +74,9 @@ func (web *webAPI) handleGetProfile(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// handlePutProfile is the handler for PUT /control/profile/update endpoint.
|
||||
func handlePutProfile(w http.ResponseWriter, r *http.Request) {
|
||||
func (web *webAPI) handlePutProfile(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
if aghhttp.WriteTextPlainDeprecated(w, r) {
|
||||
return
|
||||
}
|
||||
@ -103,10 +104,10 @@ func handlePutProfile(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
config.Language = lang
|
||||
config.Theme = theme
|
||||
log.Printf("home: language is set to %s", lang)
|
||||
log.Printf("home: theme is set to %s", theme)
|
||||
web.logger.InfoContext(ctx, "profile updated", "lang", lang, "theme", theme)
|
||||
}()
|
||||
|
||||
onConfigModified()
|
||||
web.confModifier.Apply(ctx)
|
||||
|
||||
aghhttp.OK(w)
|
||||
}
|
||||
|
||||
@ -20,6 +20,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
@ -56,9 +57,8 @@ type tlsManager struct {
|
||||
// conf contains the TLS configuration settings. It must not be nil.
|
||||
conf *tlsConfigSettings
|
||||
|
||||
// configModified is called when the TLS configuration is changed via an
|
||||
// HTTP request.
|
||||
configModified func()
|
||||
// confModifier is used to update the global configuration.
|
||||
confModifier agh.ConfigModifier
|
||||
|
||||
// customCipherIDs are the ID of the cipher suites that AdGuard Home must use.
|
||||
customCipherIDs []uint16
|
||||
@ -73,9 +73,9 @@ type tlsManagerConfig struct {
|
||||
// be nil.
|
||||
logger *slog.Logger
|
||||
|
||||
// configModified is called when the TLS configuration is changed via an
|
||||
// HTTP request. It must not be nil.
|
||||
configModified func()
|
||||
// confModifier is used to update the global configuration. It must not be
|
||||
// nil.
|
||||
confModifier agh.ConfigModifier
|
||||
|
||||
// tlsSettings contains the TLS configuration settings.
|
||||
tlsSettings tlsConfigSettings
|
||||
@ -91,12 +91,12 @@ type tlsManagerConfig struct {
|
||||
// [tlsManager.setWebAPI].
|
||||
func newTLSManager(ctx context.Context, conf *tlsManagerConfig) (m *tlsManager, err error) {
|
||||
m = &tlsManager{
|
||||
logger: conf.logger,
|
||||
mu: &sync.Mutex{},
|
||||
configModified: conf.configModified,
|
||||
status: &tlsConfigStatus{},
|
||||
conf: &conf.tlsSettings,
|
||||
servePlainDNS: conf.servePlainDNS,
|
||||
logger: conf.logger,
|
||||
mu: &sync.Mutex{},
|
||||
confModifier: conf.confModifier,
|
||||
status: &tlsConfigStatus{},
|
||||
conf: &conf.tlsSettings,
|
||||
servePlainDNS: conf.servePlainDNS,
|
||||
}
|
||||
|
||||
m.rootCerts = aghtls.SystemRootCAs(ctx, conf.logger)
|
||||
@ -232,7 +232,7 @@ func (m *tlsManager) reload(ctx context.Context) {
|
||||
|
||||
m.certLastMod = fi.ModTime().UTC()
|
||||
|
||||
err = m.reconfigureDNSServer()
|
||||
err = m.reconfigureDNSServer(ctx)
|
||||
if err != nil {
|
||||
m.logger.ErrorContext(ctx, "reconfiguring dns server", slogutil.KeyError, err)
|
||||
}
|
||||
@ -245,7 +245,7 @@ func (m *tlsManager) reload(ctx context.Context) {
|
||||
|
||||
// reconfigureDNSServer updates the DNS server configuration using the stored
|
||||
// TLS settings. m.mu is expected to be locked.
|
||||
func (m *tlsManager) reconfigureDNSServer() (err error) {
|
||||
func (m *tlsManager) reconfigureDNSServer(ctx context.Context) (err error) {
|
||||
newConf, err := newServerConfig(
|
||||
&config.DNS,
|
||||
config.Clients.Sources,
|
||||
@ -253,12 +253,13 @@ func (m *tlsManager) reconfigureDNSServer() (err error) {
|
||||
m,
|
||||
httpRegister,
|
||||
globalContext.clients.storage,
|
||||
m.confModifier,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("generating forwarding dns server config: %w", err)
|
||||
}
|
||||
|
||||
err = globalContext.dnsServer.Reconfigure(newConf)
|
||||
err = globalContext.dnsServer.Reconfigure(ctx, newConf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("starting forwarding dns server: %w", err)
|
||||
}
|
||||
@ -515,7 +516,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
||||
var restartHTTPS bool
|
||||
defer func() {
|
||||
if restartHTTPS {
|
||||
m.configModified()
|
||||
m.confModifier.Apply(ctx)
|
||||
}
|
||||
}()
|
||||
|
||||
@ -557,7 +558,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
|
||||
}()
|
||||
}
|
||||
|
||||
err = m.reconfigureDNSServer()
|
||||
err = m.reconfigureDNSServer(ctx)
|
||||
if err != nil {
|
||||
m.logger.ErrorContext(ctx, "reconfiguring dns server", slogutil.KeyError, err)
|
||||
|
||||
|
||||
@ -20,6 +20,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/client"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
|
||||
@ -66,9 +67,9 @@ func TestValidateCertificates(t *testing.T) {
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
m, err := newTLSManager(ctx, &tlsManagerConfig{
|
||||
logger: testLogger,
|
||||
configModified: func() {},
|
||||
servePlainDNS: false,
|
||||
logger: testLogger,
|
||||
confModifier: agh.EmptyConfigModifier{},
|
||||
servePlainDNS: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -246,8 +247,8 @@ func TestTLSManager_Reload(t *testing.T) {
|
||||
writeCertAndKey(t, certDER, certPath, key, keyPath)
|
||||
|
||||
m, err := newTLSManager(ctx, &tlsManagerConfig{
|
||||
logger: testLogger,
|
||||
configModified: func() {},
|
||||
logger: testLogger,
|
||||
confModifier: agh.EmptyConfigModifier{},
|
||||
tlsSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
CertificatePath: certPath,
|
||||
@ -257,7 +258,7 @@ func TestTLSManager_Reload(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
web, err := initWeb(ctx, options{}, nil, nil, testLogger, nil, nil, false)
|
||||
web, err := initWeb(ctx, options{}, nil, nil, testLogger, nil, nil, agh.EmptyConfigModifier{}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
m.setWebAPI(web)
|
||||
@ -272,7 +273,9 @@ func TestTLSManager_Reload(t *testing.T) {
|
||||
|
||||
// The [tlsManager.reload] method will start the DNS server and it should be
|
||||
// stopped after the test ends.
|
||||
testutil.CleanupAndRequireSuccess(t, globalContext.dnsServer.Stop)
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||
return globalContext.dnsServer.Stop(testutil.ContextWithTimeout(t, testTimeout))
|
||||
})
|
||||
|
||||
conf = m.config()
|
||||
assertCertSerialNumber(t, conf, snAfter)
|
||||
@ -285,8 +288,8 @@ func TestTLSManager_HandleTLSStatus(t *testing.T) {
|
||||
)
|
||||
|
||||
m, err := newTLSManager(ctx, &tlsManagerConfig{
|
||||
logger: testLogger,
|
||||
configModified: func() {},
|
||||
logger: testLogger,
|
||||
confModifier: agh.EmptyConfigModifier{},
|
||||
tlsSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
CertificateChain: string(testCertChainData),
|
||||
@ -321,13 +324,13 @@ func TestValidateTLSSettings(t *testing.T) {
|
||||
)
|
||||
|
||||
m, err := newTLSManager(ctx, &tlsManagerConfig{
|
||||
logger: testLogger,
|
||||
configModified: func() {},
|
||||
servePlainDNS: false,
|
||||
logger: testLogger,
|
||||
confModifier: agh.EmptyConfigModifier{},
|
||||
servePlainDNS: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
web, err := initWeb(ctx, options{}, nil, nil, testLogger, nil, nil, false)
|
||||
web, err := initWeb(ctx, options{}, nil, nil, testLogger, nil, nil, agh.EmptyConfigModifier{}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
m.setWebAPI(web)
|
||||
@ -420,8 +423,8 @@ func TestTLSManager_HandleTLSValidate(t *testing.T) {
|
||||
)
|
||||
|
||||
m, err := newTLSManager(ctx, &tlsManagerConfig{
|
||||
logger: testLogger,
|
||||
configModified: func() {},
|
||||
logger: testLogger,
|
||||
confModifier: agh.EmptyConfigModifier{},
|
||||
tlsSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
CertificateChain: string(testCertChainData),
|
||||
@ -431,7 +434,7 @@ func TestTLSManager_HandleTLSValidate(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
web, err := initWeb(ctx, options{}, nil, nil, testLogger, nil, nil, false)
|
||||
web, err := initWeb(ctx, options{}, nil, nil, testLogger, nil, nil, agh.EmptyConfigModifier{}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
m.setWebAPI(web)
|
||||
@ -476,15 +479,17 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = globalContext.dnsServer.Prepare(&dnsforward.ServerConfig{
|
||||
TLSConf: &dnsforward.TLSConfig{},
|
||||
Config: dnsforward.Config{
|
||||
UpstreamMode: dnsforward.UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &dnsforward.EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: dnsforward.EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
err = globalContext.dnsServer.Prepare(
|
||||
testutil.ContextWithTimeout(t, testTimeout),
|
||||
&dnsforward.ServerConfig{
|
||||
TLSConf: &dnsforward.TLSConfig{},
|
||||
Config: dnsforward.Config{
|
||||
UpstreamMode: dnsforward.UpstreamModeLoadBalance,
|
||||
EDNSClientSubnet: &dnsforward.EDNSClientSubnet{Enabled: false},
|
||||
ClientsContainer: dnsforward.EmptyClientsContainer{},
|
||||
},
|
||||
ServePlainDNS: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
globalContext.clients.storage, err = client.NewStorage(ctx, &client.StorageConfig{
|
||||
@ -511,8 +516,8 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
|
||||
|
||||
// Initialize the TLS manager and assert its configuration.
|
||||
m, err := newTLSManager(ctx, &tlsManagerConfig{
|
||||
logger: testLogger,
|
||||
configModified: func() {},
|
||||
logger: testLogger,
|
||||
confModifier: agh.EmptyConfigModifier{},
|
||||
tlsSettings: tlsConfigSettings{
|
||||
Enabled: true,
|
||||
CertificatePath: certPath,
|
||||
@ -522,7 +527,7 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
web, err := initWeb(ctx, options{}, nil, nil, testLogger, nil, nil, false)
|
||||
web, err := initWeb(ctx, options{}, nil, nil, testLogger, nil, nil, agh.EmptyConfigModifier{}, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
m.setWebAPI(web)
|
||||
@ -551,7 +556,9 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
|
||||
|
||||
// The [tlsManager.handleTLSConfigure] method will start the DNS server and
|
||||
// it should be stopped after the test ends.
|
||||
testutil.CleanupAndRequireSuccess(t, globalContext.dnsServer.Stop)
|
||||
testutil.CleanupAndRequireSuccess(t, func() (err error) {
|
||||
return globalContext.dnsServer.Stop(testutil.ContextWithTimeout(t, testTimeout))
|
||||
})
|
||||
|
||||
res := &tlsConfig{
|
||||
tlsConfigStatus: &tlsConfigStatus{},
|
||||
|
||||
@ -12,6 +12,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
@ -47,6 +48,9 @@ type webConfig struct {
|
||||
// nil.
|
||||
baseLogger *slog.Logger
|
||||
|
||||
// confModifier is used to update the global configuration.
|
||||
confModifier agh.ConfigModifier
|
||||
|
||||
// tlsManager contains the current configuration and state of TLS
|
||||
// encryption. It must not be nil.
|
||||
tlsManager *tlsManager
|
||||
@ -104,6 +108,9 @@ type httpsServer struct {
|
||||
type webAPI struct {
|
||||
conf *webConfig
|
||||
|
||||
// confModifier is used to update the global configuration.
|
||||
confModifier agh.ConfigModifier
|
||||
|
||||
// TODO(a.garipov): Refactor all these servers.
|
||||
httpServer *http.Server
|
||||
|
||||
@ -134,11 +141,12 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
|
||||
conf.logger.InfoContext(ctx, "initializing")
|
||||
|
||||
w = &webAPI{
|
||||
conf: conf,
|
||||
logger: conf.logger,
|
||||
baseLogger: conf.baseLogger,
|
||||
tlsManager: conf.tlsManager,
|
||||
auth: conf.auth,
|
||||
conf: conf,
|
||||
confModifier: conf.confModifier,
|
||||
logger: conf.logger,
|
||||
baseLogger: conf.baseLogger,
|
||||
tlsManager: conf.tlsManager,
|
||||
auth: conf.auth,
|
||||
}
|
||||
|
||||
clientFS := http.FileServer(http.FS(conf.clientFS))
|
||||
|
||||
@ -186,7 +186,7 @@ func (l *queryLog) handleQueryLogConfig(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
defer l.conf.ConfigModified()
|
||||
defer l.conf.ConfigModifier.Apply(r.Context())
|
||||
|
||||
l.confMu.Lock()
|
||||
defer l.confMu.Unlock()
|
||||
@ -250,7 +250,7 @@ func (l *queryLog) handlePutQueryLogConfig(w http.ResponseWriter, r *http.Reques
|
||||
return
|
||||
}
|
||||
|
||||
defer l.conf.ConfigModified()
|
||||
defer l.conf.ConfigModifier.Apply(r.Context())
|
||||
|
||||
l.confMu.Lock()
|
||||
defer l.confMu.Unlock()
|
||||
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
|
||||
@ -47,9 +48,9 @@ type Config struct {
|
||||
// Anonymizer processes the IP addresses to anonymize those if needed.
|
||||
Anonymizer *aghnet.IPMut
|
||||
|
||||
// ConfigModified is called when the configuration is changed, for example
|
||||
// by HTTP requests.
|
||||
ConfigModified func()
|
||||
// ConfigModifier is used to update the global configuration. It must not
|
||||
// be nil.
|
||||
ConfigModifier agh.ConfigModifier
|
||||
|
||||
// HTTPRegister registers an HTTP handler.
|
||||
HTTPRegister aghhttp.RegisterFunc
|
||||
|
||||
@ -166,7 +166,7 @@ func (s *StatsCtx) handleStatsConfig(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
limit := time.Duration(reqData.IntervalDays) * timeutil.Day
|
||||
|
||||
defer s.configModified()
|
||||
defer s.configModifier.Apply(ctx)
|
||||
|
||||
s.confMu.Lock()
|
||||
defer s.confMu.Unlock()
|
||||
@ -216,7 +216,7 @@ func (s *StatsCtx) handlePutStatsConfig(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
defer s.configModified()
|
||||
defer s.configModifier.Apply(ctx)
|
||||
|
||||
s.confMu.Lock()
|
||||
defer s.confMu.Unlock()
|
||||
|
||||
@ -9,6 +9,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
@ -27,7 +28,7 @@ func TestHandleStatsConfig(t *testing.T) {
|
||||
conf := Config{
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
UnitID: func() (id uint32) { return 0 },
|
||||
ConfigModified: func() {},
|
||||
ConfigModifier: agh.EmptyConfigModifier{},
|
||||
ShouldCountClient: func([]string) bool { return true },
|
||||
Filename: filepath.Join(t.TempDir(), "stats.db"),
|
||||
Limit: time.Hour * 24,
|
||||
|
||||
@ -13,6 +13,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
@ -55,9 +56,9 @@ type Config struct {
|
||||
// nil, the default function is used, see newUnitID.
|
||||
UnitID UnitIDGenFunc
|
||||
|
||||
// ConfigModified will be called each time the configuration changed via web
|
||||
// interface.
|
||||
ConfigModified func()
|
||||
// ConfigModifier is used to update the global configuration. It must not
|
||||
// be nil.
|
||||
ConfigModifier agh.ConfigModifier
|
||||
|
||||
// ShouldCountClient returns client's ignore setting.
|
||||
ShouldCountClient func([]string) bool
|
||||
@ -123,9 +124,8 @@ type StatsCtx struct {
|
||||
// httpRegister is used to set HTTP handlers.
|
||||
httpRegister aghhttp.RegisterFunc
|
||||
|
||||
// configModified is called whenever the configuration is modified via web
|
||||
// interface.
|
||||
configModified func()
|
||||
// configModifier is used to update the global configuration.
|
||||
configModifier agh.ConfigModifier
|
||||
|
||||
// confMu protects ignored, limit, and enabled.
|
||||
confMu *sync.RWMutex
|
||||
@ -165,7 +165,7 @@ func New(conf Config) (s *StatsCtx, err error) {
|
||||
logger: conf.Logger,
|
||||
currMu: &sync.RWMutex{},
|
||||
httpRegister: conf.HTTPRegister,
|
||||
configModified: conf.ConfigModified,
|
||||
configModifier: conf.ConfigModifier,
|
||||
filename: conf.Filename,
|
||||
|
||||
confMu: &sync.RWMutex{},
|
||||
|
||||
Loading…
Reference in New Issue
Block a user