all: imp code

This commit is contained in:
Stanislav Chzhen 2025-07-30 20:18:43 +03:00
parent 67c5608b4b
commit 96d21efc98
26 changed files with 259 additions and 254 deletions

View File

@ -15,8 +15,8 @@ type ConfigModifier interface {
// nothing.
type EmptyConfigModifier struct{}
// Apply implements the [ConfigModifier] for EmptyConfigModifier.
func (em EmptyConfigModifier) Apply(ctx context.Context) {}
// type check
var _ ConfigModifier = EmptyConfigModifier{}
// Apply implements the [ConfigModifier] for EmptyConfigModifier.
func (em EmptyConfigModifier) Apply(ctx context.Context) {}

View File

@ -9,6 +9,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -132,7 +133,7 @@ func TestServer_HandleBefore_tls(t *testing.T) {
s.conf.DisallowedClients = tc.disallowedClients
s.conf.BlockedHosts = tc.blockedHosts
err := s.Prepare(&s.conf)
err := s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &s.conf)
require.NoError(t, err)
startDeferStop(t, s)

View File

@ -310,7 +310,7 @@ const (
)
// newProxyConfig creates and validates configuration for the main proxy.
func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
func (s *Server) newProxyConfig(ctx context.Context) (conf *proxy.Config, err error) {
srvConf := s.conf
trustedPrefixes := netutil.UnembedPrefixes(srvConf.TrustedProxies)
@ -358,12 +358,12 @@ func (s *Server) newProxyConfig() (conf *proxy.Config, err error) {
return nil, fmt.Errorf("bogus_nxdomain: %w", err)
}
err = s.prepareTLS(conf)
err = s.prepareTLS(ctx, conf)
if err != nil {
return nil, fmt.Errorf("validating tls: %w", err)
}
err = s.preparePlain(conf)
err = s.preparePlain(ctx, conf)
if err != nil {
return nil, fmt.Errorf("validating plain: %w", err)
}
@ -447,7 +447,7 @@ func (s *Server) initDefaultSettings() {
// prepareIpsetListSettings reads and prepares the ipset configuration either
// from a file or from the data in the configuration file.
func (s *Server) prepareIpsetListSettings() (ipsets []string, err error) {
func (s *Server) prepareIpsetListSettings(ctx context.Context) (ipsets []string, err error) {
fn := s.conf.IpsetListFileName
if fn == "" {
return s.conf.IpsetList, nil
@ -462,13 +462,7 @@ func (s *Server) prepareIpsetListSettings() (ipsets []string, err error) {
ipsets = stringutil.SplitTrimmed(string(data), "\n")
ipsets = slices.DeleteFunc(ipsets, aghnet.IsCommentOrEmpty)
// TODO(s.chzhen): Pass context.
s.logger.DebugContext(
context.TODO(),
"using ipset rules from file",
"num", len(ipsets),
"file", fn,
)
s.logger.DebugContext(ctx, "using ipset rules from file", "num", len(ipsets), "file", fn)
return ipsets, nil
}
@ -638,7 +632,7 @@ func (s *Server) prepareDNSCrypt(proxyConf *proxy.Config) {
}
// prepareTLS sets up the TLS configuration for the DNS proxy.
func (s *Server) prepareTLS(proxyConf *proxy.Config) (err error) {
func (s *Server) prepareTLS(ctx context.Context, proxyConf *proxy.Config) (err error) {
s.prepareDNSCrypt(proxyConf)
if s.conf.TLSConf.Cert == nil {
@ -659,8 +653,6 @@ func (s *Server) prepareTLS(proxyConf *proxy.Config) (err error) {
s.hasIPAddrs = aghtls.CertificateHasIP(cert)
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
if s.conf.TLSConf.StrictSNICheck {
if len(cert.DNSNames) != 0 {
s.dnsNames = cert.DNSNames
@ -741,7 +733,7 @@ func (s *Server) onGetCertificate(ch *tls.ClientHelloInfo) (*tls.Certificate, er
// preparePlain prepares the plain-DNS configuration for the DNS proxy.
// preparePlain assumes that prepareTLS has already been called.
func (s *Server) preparePlain(proxyConf *proxy.Config) (err error) {
func (s *Server) preparePlain(ctx context.Context, proxyConf *proxy.Config) (err error) {
if s.conf.ServePlainDNS {
proxyConf.UDPListenAddr = s.conf.UDPListenAddrs
proxyConf.TCPListenAddr = s.conf.TCPListenAddrs
@ -759,15 +751,16 @@ func (s *Server) preparePlain(proxyConf *proxy.Config) (err error) {
return errors.Error("disabling plain dns requires at least one encrypted protocol")
}
// TODO(s.chzhen): Pass context.
s.logger.WarnContext(context.TODO(), "plain dns is disabled")
s.logger.WarnContext(ctx, "plain dns is disabled")
return nil
}
// UpdatedProtectionStatus updates protection state, if the protection was
// disabled temporarily. Returns the updated state of protection.
func (s *Server) UpdatedProtectionStatus() (enabled bool, disabledUntil *time.Time) {
func (s *Server) UpdatedProtectionStatus(
ctx context.Context,
) (enabled bool, disabledUntil *time.Time) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()
@ -787,7 +780,7 @@ func (s *Server) UpdatedProtectionStatus() (enabled bool, disabledUntil *time.Ti
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/5661.
if s.protectionUpdateInProgress.CompareAndSwap(false, true) {
go s.enableProtectionAfterPause()
go s.enableProtectionAfterPause(ctx)
}
return true, nil
@ -795,15 +788,12 @@ func (s *Server) UpdatedProtectionStatus() (enabled bool, disabledUntil *time.Ti
// enableProtectionAfterPause sets the protection configuration to enabled
// values. It is intended to be used as a goroutine.
func (s *Server) enableProtectionAfterPause() {
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
func (s *Server) enableProtectionAfterPause(ctx context.Context) {
defer slogutil.RecoverAndLog(ctx, s.logger)
defer s.protectionUpdateInProgress.Store(false)
// TODO(s.chzhen): Pass context.
defer s.conf.ConfModifier.Apply(context.TODO())
defer s.conf.ConfModifier.Apply(ctx)
s.serverLock.Lock()
defer s.serverLock.Unlock()

View File

@ -291,7 +291,7 @@ func NewServer(p DNSCreateParams) (s *Server, err error) {
// its workers finished. But it would require the upstream.Upstream to have the
// Close method to prevent from hanging while waiting for unresponsive server to
// respond.
func (s *Server) Close() {
func (s *Server) Close(ctx context.Context) {
s.serverLock.Lock()
defer s.serverLock.Unlock()
@ -301,8 +301,7 @@ func (s *Server) Close() {
s.dnsProxy = nil
if err := s.ipset.close(); err != nil {
// TODO(s.chzhen): Pass context.
s.logger.ErrorContext(context.TODO(), "closing ipset", slogutil.KeyError, err)
s.logger.ErrorContext(ctx, "closing ipset", slogutil.KeyError, err)
}
}
@ -467,18 +466,17 @@ func hostFromPTR(resp *dns.Msg) (host string, ttl time.Duration, err error) {
}
// Start starts the DNS server. It must only be called after [Server.Prepare].
func (s *Server) Start() error {
func (s *Server) Start(ctx context.Context) error {
s.serverLock.Lock()
defer s.serverLock.Unlock()
return s.startLocked()
return s.startLocked(ctx)
}
// startLocked starts the DNS server without locking. s.serverLock is expected
// to be locked.
func (s *Server) startLocked() error {
// TODO(e.burkov): Use context properly.
err := s.dnsProxy.Start(context.Background())
func (s *Server) startLocked(ctx context.Context) error {
err := s.dnsProxy.Start(ctx)
if err == nil {
s.isRunning = true
}
@ -488,7 +486,7 @@ func (s *Server) startLocked() error {
// Prepare initializes parameters of s using data from conf. conf must not be
// nil.
func (s *Server) Prepare(conf *ServerConfig) (err error) {
func (s *Server) Prepare(ctx context.Context, conf *ServerConfig) (err error) {
s.conf = *conf
// dnsFilter can be nil during application update.
@ -502,13 +500,13 @@ func (s *Server) Prepare(conf *ServerConfig) (err error) {
s.initDefaultSettings()
err = s.prepareInternalDNS()
err = s.prepareInternalDNS(ctx)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err
}
proxyConfig, err := s.newProxyConfig()
proxyConfig, err := s.newProxyConfig(ctx)
if err != nil {
return fmt.Errorf("preparing proxy: %w", err)
}
@ -639,8 +637,8 @@ func (s *Server) prepareLocalResolvers() (uc *proxy.UpstreamConfig, err error) {
// prepareInternalDNS initializes the internal state of s before initializing
// the primary DNS proxy instance. It assumes s.serverLock is locked or the
// Server not running.
func (s *Server) prepareInternalDNS() (err error) {
ipsetList, err := s.prepareIpsetListSettings()
func (s *Server) prepareInternalDNS(ctx context.Context) (err error) {
ipsetList, err := s.prepareIpsetListSettings(ctx)
if err != nil {
return fmt.Errorf("preparing ipset settings: %w", err)
}
@ -785,32 +783,26 @@ func (s *Server) prepareInternalProxy() (err error) {
}
// Stop stops the DNS server.
func (s *Server) Stop() error {
func (s *Server) Stop(ctx context.Context) error {
s.serverLock.Lock()
defer s.serverLock.Unlock()
s.stopLocked()
s.stopLocked(ctx)
return nil
}
// stopLocked stops the DNS server without locking. s.serverLock is expected to
// be locked.
func (s *Server) stopLocked() {
func (s *Server) stopLocked(ctx context.Context) {
// TODO(e.burkov, a.garipov): Return critical errors, not just log them.
// This will require filtering all the non-critical errors in
// [upstream.Upstream] implementations.
if s.dnsProxy != nil {
// TODO(e.burkov): Use context properly.
err := s.dnsProxy.Shutdown(context.Background())
err := s.dnsProxy.Shutdown(ctx)
if err != nil {
// TODO(s.chzhen): Pass context.
s.logger.ErrorContext(
context.TODO(),
"closing primary resolvers",
slogutil.KeyError, err,
)
s.logger.ErrorContext(ctx, "closing primary resolvers", slogutil.KeyError, err)
}
}
@ -859,16 +851,14 @@ func (s *Server) proxy() (p *proxy.Proxy) {
// Reconfigure applies the new configuration to the DNS server.
//
// TODO(a.garipov): This whole piece of API is weird and needs to be remade.
func (s *Server) Reconfigure(conf *ServerConfig) error {
func (s *Server) Reconfigure(ctx context.Context, conf *ServerConfig) error {
s.serverLock.Lock()
defer s.serverLock.Unlock()
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
s.logger.InfoContext(ctx, "starting reconfiguring server")
defer s.logger.InfoContext(ctx, "finished reconfiguring server")
s.stopLocked()
s.stopLocked(ctx)
// It seems that net.Listener.Close() doesn't close file descriptors right away.
// We wait for some time and hope that this fd will be closed.
@ -887,12 +877,12 @@ func (s *Server) Reconfigure(conf *ServerConfig) error {
// TODO(e.burkov): It seems an error here brings the server down, which is
// not reliable enough.
err := s.Prepare(conf)
err := s.Prepare(ctx, conf)
if err != nil {
return fmt.Errorf("could not reconfigure the server: %w", err)
}
err = s.startLocked()
err = s.startLocked(ctx)
if err != nil {
return fmt.Errorf("could not reconfigure the server: %w", err)
}

View File

@ -106,9 +106,11 @@ func (c *clientsContainer) ClearUpstreamCache() {
func startDeferStop(t *testing.T, s *Server) {
t.Helper()
err := s.Start()
err := s.Start(testutil.ContextWithTimeout(t, testTimeout))
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, s.Stop)
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return s.Stop(testutil.ContextWithTimeout(t, testTimeout))
})
}
// applyEmptyClientFiltering is a helper function for tests with
@ -169,7 +171,7 @@ func createTestServer(
})
require.NoError(t, err)
err = s.Prepare(&forwardConf)
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &forwardConf)
require.NoError(t, err)
return s
@ -244,7 +246,7 @@ func createTestTLS(t *testing.T, tlsConf *TLSConfig) (s *Server, certPem []byte)
ServePlainDNS: true,
})
err = s.Prepare(&s.conf)
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &s.conf)
require.NoErrorf(t, err, "failed to prepare server: %s", err)
return s, certPem
@ -420,7 +422,7 @@ func TestServer_timeout(t *testing.T) {
})
require.NoError(t, err)
err = s.Prepare(srvConf)
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), srvConf)
require.NoError(t, err)
assert.Equal(t, testTimeout, s.conf.UpstreamTimeout)
@ -439,7 +441,7 @@ func TestServer_timeout(t *testing.T) {
Enabled: false,
}
s.conf.Config.ClientsContainer = EmptyClientsContainer{}
err = s.Prepare(&s.conf)
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &s.conf)
require.NoError(t, err)
assert.Equal(t, DefaultTimeout, s.conf.UpstreamTimeout)
@ -466,7 +468,7 @@ func TestServer_Prepare_fallbacks(t *testing.T) {
})
require.NoError(t, err)
err = s.Prepare(srvConf)
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), srvConf)
require.NoError(t, err)
require.NotNil(t, s.dnsProxy.Fallbacks)
@ -1104,7 +1106,7 @@ func TestBlockedCustomIP(t *testing.T) {
}
// Invalid BlockingIPv4.
err = s.Prepare(conf)
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), conf)
assert.Error(t, err)
s.dnsFilter.SetBlockingMode(
@ -1112,7 +1114,7 @@ func TestBlockedCustomIP(t *testing.T) {
netip.AddrFrom4([4]byte{0, 0, 0, 1}),
netip.MustParseAddr("::1"))
err = s.Prepare(conf)
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), conf)
require.NoError(t, err)
f.SetEnabled(true)
@ -1264,7 +1266,7 @@ func TestRewrite(t *testing.T) {
})
require.NoError(t, err)
assert.NoError(t, s.Prepare(&ServerConfig{
assert.NoError(t, s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &ServerConfig{
UDPListenAddrs: []*net.UDPAddr{{}},
TCPListenAddrs: []*net.TCPAddr{{}},
TLSConf: &TLSConfig{},
@ -1334,7 +1336,7 @@ func TestRewrite(t *testing.T) {
for _, protect := range []bool{true, false} {
val := protect
conf := s.getDNSConfig()
conf := s.getDNSConfig(testutil.ContextWithTimeout(t, testTimeout))
conf.ProtectionEnabled = &val
s.setConfig(conf)
@ -1408,12 +1410,12 @@ func TestPTRResponseFromDHCPLeases(t *testing.T) {
s.conf.Config.ClientsContainer = EmptyClientsContainer{}
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
err = s.Prepare(&s.conf)
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &s.conf)
require.NoError(t, err)
err = s.Start()
err = s.Start(testutil.ContextWithTimeout(t, testTimeout))
require.NoError(t, err)
t.Cleanup(s.Close)
t.Cleanup(func() { s.Close(testutil.ContextWithTimeout(t, testTimeout)) })
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
req := createTestMessageWithType("34.12.168.192.in-addr.arpa.", dns.TypePTR)
@ -1498,12 +1500,12 @@ func TestPTRResponseFromHosts(t *testing.T) {
s.conf.Config.ClientsContainer = EmptyClientsContainer{}
s.conf.Config.UpstreamMode = UpstreamModeLoadBalance
err = s.Prepare(&s.conf)
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &s.conf)
require.NoError(t, err)
err = s.Start()
err = s.Start(testutil.ContextWithTimeout(t, testTimeout))
require.NoError(t, err)
t.Cleanup(s.Close)
t.Cleanup(func() { s.Close(testutil.ContextWithTimeout(t, testTimeout)) })
subTestFunc := func(t *testing.T) {
addr := s.dnsProxy.Addr(proxy.ProtoUDP)
@ -1524,7 +1526,7 @@ func TestPTRResponseFromHosts(t *testing.T) {
for _, protect := range []bool{true, false} {
val := protect
conf := s.getDNSConfig()
conf := s.getDNSConfig(testutil.ContextWithTimeout(t, testTimeout))
conf.ProtectionEnabled = &val
s.setConfig(conf)

View File

@ -15,6 +15,7 @@ import (
// filterDNSRewriteResponse handles a single DNS rewrite response entry. It
// returns the properly constructed answer resource record.
func (s *Server) filterDNSRewriteResponse(
ctx context.Context,
req *dns.Msg,
rr rules.RRType,
v rules.RRValue,
@ -27,12 +28,11 @@ func (s *Server) filterDNSRewriteResponse(
case dns.TypeMX:
return s.ansFromDNSRewriteMX(v, rr, req)
case dns.TypeHTTPS, dns.TypeSVCB:
return s.ansFromDNSRewriteSVCB(v, rr, req)
return s.ansFromDNSRewriteSVCB(ctx, v, rr, req)
case dns.TypeSRV:
return s.ansFromDNSRewriteSRV(v, rr, req)
default:
// TODO(s.chzhen): Pass context.
s.logger.DebugContext(context.TODO(), "unsupported dns rr type, skipping", "res_record", rr)
s.logger.DebugContext(ctx, "unsupported dns rr type, skipping", "res_record", rr)
return nil, nil
}
@ -98,6 +98,7 @@ func (s *Server) ansFromDNSRewriteMX(
// ansFromDNSRewriteSVCB creates a new answer resource record from the
// SVCB/HTTPS dnsrewrite rule data.
func (s *Server) ansFromDNSRewriteSVCB(
ctx context.Context,
v rules.RRValue,
rr rules.RRType,
req *dns.Msg,
@ -112,10 +113,10 @@ func (s *Server) ansFromDNSRewriteSVCB(
}
if rr == dns.TypeHTTPS {
return s.genAnswerHTTPS(req, svcb), nil
return s.genAnswerHTTPS(ctx, req, svcb), nil
}
return s.genAnswerSVCB(req, svcb), nil
return s.genAnswerSVCB(ctx, req, svcb), nil
}
// ansFromDNSRewriteSRV creates a new answer resource record from the SRV
@ -140,6 +141,7 @@ func (s *Server) ansFromDNSRewriteSRV(
// filterDNSRewrite handles dnsrewrite filters. It constructs a DNS response
// and sets it into pctx.Res. All parameters must not be nil.
func (s *Server) filterDNSRewrite(
ctx context.Context,
req *dns.Msg,
res *filtering.Result,
pctx *proxy.DNSContext,
@ -165,7 +167,7 @@ func (s *Server) filterDNSRewrite(
values := dnsrr.Response[qtype]
for i, v := range values {
var ans dns.RR
ans, err = s.filterDNSRewriteResponse(req, qtype, v)
ans, err = s.filterDNSRewriteResponse(ctx, req, qtype, v)
if err != nil {
return fmt.Errorf("dns rewrite response for %s[%d]: %w", dns.Type(qtype), i, err)
}

View File

@ -7,6 +7,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@ -71,7 +72,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
res := makeRes(dns.RcodeNameError, 0, nil)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d)
require.NoError(t, err)
assert.Equal(t, dns.RcodeNameError, d.Res.Rcode)
@ -82,7 +83,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
res := makeRes(dns.RcodeSuccess, 0, nil)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d)
require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
@ -94,7 +95,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
res := makeRes(dns.RcodeSuccess, dns.TypeA, ip4)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d)
require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
@ -108,7 +109,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
res := makeRes(dns.RcodeSuccess, dns.TypeAAAA, ip6)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d)
require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
@ -122,7 +123,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
res := makeRes(dns.RcodeSuccess, dns.TypePTR, domain)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d)
require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
@ -136,7 +137,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
res := makeRes(dns.RcodeSuccess, dns.TypeTXT, domain)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d)
require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
@ -150,7 +151,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
res := makeRes(dns.RcodeSuccess, dns.TypeMX, mxVal)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d)
require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
@ -168,7 +169,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
res := makeRes(dns.RcodeSuccess, dns.TypeSVCB, svcbVal)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d)
require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
@ -198,7 +199,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
res := makeRes(dns.RcodeSuccess, dns.TypeHTTPS, svcbVal)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d)
require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)
@ -228,7 +229,7 @@ func TestServer_FilterDNSRewrite(t *testing.T) {
res := makeRes(dns.RcodeSuccess, dns.TypeSRV, srvVal)
d := &proxy.DNSContext{}
err := srv.filterDNSRewrite(req, res, d)
err := srv.filterDNSRewrite(testutil.ContextWithTimeout(t, testTimeout), req, res, d)
require.NoError(t, err)
assert.Equal(t, dns.RcodeSuccess, d.Res.Rcode)

View File

@ -24,7 +24,10 @@ func (s *Server) clientRequestFilteringSettings(dctx *dnsContext) (setts *filter
// filterDNSRequest applies the dnsFilter and sets dctx.proxyCtx.Res if the
// request was filtered.
func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err error) {
func (s *Server) filterDNSRequest(
ctx context.Context,
dctx *dnsContext,
) (res *filtering.Result, err error) {
pctx := dctx.proxyCtx
req := pctx.Req
q := req.Question[0]
@ -44,18 +47,12 @@ func (s *Server) filterDNSRequest(dctx *dnsContext) (res *filtering.Result, err
dctx.origQuestion = q
req.Question[0].Name = dns.Fqdn(res.CanonName)
case res.IsFiltered:
// TODO(s.chzhen): Pass context.
s.logger.DebugContext(
context.TODO(),
"host is filtered",
"host", host,
"reason", res.Reason,
)
pctx.Res = s.genDNSFilterMessage(pctx, res)
s.logger.DebugContext(ctx, "host is filtered", "host", host, "reason", res.Reason)
pctx.Res = s.genDNSFilterMessage(ctx, pctx, res)
case res.Reason.In(filtering.Rewritten, filtering.FilteredSafeSearch):
pctx.Res = s.getCNAMEWithIPs(req, res.IPList, res.CanonName)
pctx.Res = s.getCNAMEWithIPs(ctx, req, res.IPList, res.CanonName)
case res.Reason.In(filtering.RewrittenRule, filtering.RewrittenAutoHosts):
if err = s.filterDNSRewrite(req, res, pctx); err != nil {
if err = s.filterDNSRewrite(ctx, req, res, pctx); err != nil {
return nil, err
}
}
@ -96,15 +93,12 @@ func (s *Server) checkHostRules(
// dctx.proxyCtx.Res. It sets dctx.result and dctx.origResp if at least one of
// canonical names, IP addresses, or HTTPS RR hints in it matches the filtering
// rules, as well as sets dctx.proxyCtx.Res to the filtered response.
func (s *Server) filterDNSResponse(dctx *dnsContext) (err error) {
func (s *Server) filterDNSResponse(ctx context.Context, dctx *dnsContext) (err error) {
setts := dctx.setts
if !setts.FilteringEnabled {
return nil
}
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
var res *filtering.Result
pctx := dctx.proxyCtx
for i, a := range pctx.Res.Answer {
@ -145,7 +139,7 @@ func (s *Server) filterDNSResponse(dctx *dnsContext) (err error) {
} else if res != nil && res.IsFiltered {
dctx.result = res
dctx.origResp = pctx.Res
pctx.Res = s.genDNSFilterMessage(pctx, res)
pctx.Res = s.genDNSFilterMessage(ctx, pctx, res)
s.logger.DebugContext(
ctx,

View File

@ -10,6 +10,7 @@ import (
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -66,7 +67,7 @@ func TestHandleDNSRequest_handleDNSRequest(t *testing.T) {
})
require.NoError(t, err)
err = s.Prepare(&forwardConf)
err = s.Prepare(testutil.ContextWithTimeout(t, testTimeout), &forwardConf)
require.NoError(t, err)
s.conf.UpstreamConfig.Upstreams = []upstream.Upstream{
@ -347,7 +348,7 @@ func TestHandleDNSRequest_filterDNSResponse(t *testing.T) {
},
}
fltErr := s.filterDNSResponse(dctx)
fltErr := s.filterDNSResponse(testutil.ContextWithTimeout(t, testTimeout), dctx)
require.NoError(t, fltErr)
res := dctx.result

View File

@ -138,8 +138,8 @@ const (
jsonUpstreamModeFastestAddr jsonUpstreamMode = "fastest_addr"
)
func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
protectionEnabled, protectionDisabledUntil := s.UpdatedProtectionStatus()
func (s *Server) getDNSConfig(ctx context.Context) (c *jsonDNSConfig) {
protectionEnabled, protectionDisabledUntil := s.UpdatedProtectionStatus(ctx)
s.serverLock.RLock()
defer s.serverLock.RUnlock()
@ -184,8 +184,7 @@ func (s *Server) getDNSConfig() (c *jsonDNSConfig) {
defPTRUps, err := s.defaultLocalPTRUpstreams()
if err != nil {
// TODO(s.chzhen): Pass context.
s.logger.ErrorContext(context.TODO(), "getting local ptr upstreams", slogutil.KeyError, err)
s.logger.ErrorContext(ctx, "getting local ptr upstreams", slogutil.KeyError, err)
}
return &jsonDNSConfig{
@ -241,7 +240,7 @@ func (s *Server) defaultLocalPTRUpstreams() (ups []string, err error) {
// handleGetConfig handles requests to the GET /control/dns_info endpoint.
func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
resp := s.getDNSConfig()
resp := s.getDNSConfig(r.Context())
aghhttp.WriteJSONResponseOK(w, r, resp)
}
@ -487,6 +486,8 @@ func checkInclusion(ptr *int, minN, maxN int) (err error) {
// handleSetConfig handles requests to the POST /control/dns_config endpoint.
func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
req := &jsonDNSConfig{}
err := json.NewDecoder(r.Body).Decode(req)
if err != nil {
@ -512,10 +513,10 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
}
restart := s.setConfig(req)
s.conf.ConfModifier.Apply(r.Context())
s.conf.ConfModifier.Apply(ctx)
if restart {
err = s.Reconfigure(nil)
err = s.Reconfigure(ctx, nil)
if err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
}

View File

@ -93,8 +93,10 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
s := createTestServer(t, filterConf, forwardConf)
s.sysResolvers = &emptySysResolvers{}
require.NoError(t, s.Start())
testutil.CleanupAndRequireSuccess(t, s.Stop)
require.NoError(t, s.Start(testutil.ContextWithTimeout(t, testTimeout)))
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return s.Stop(testutil.ContextWithTimeout(t, testTimeout))
})
defaultConf := s.conf
@ -137,7 +139,7 @@ func TestDNSForwardHTTP_handleGetConfig(t *testing.T) {
t.Cleanup(w.Body.Reset)
s.conf = tc.conf()
s.handleGetConfig(w, nil)
s.handleGetConfig(w, httptest.NewRequest(http.MethodGet, "/", nil))
cType := w.Header().Get(httphdr.ContentType)
assert.Equal(t, aghhttp.HdrValApplicationJSON, cType)
@ -178,9 +180,11 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
defaultConf := s.conf
err := s.Start()
err := s.Start(testutil.ContextWithTimeout(t, testTimeout))
assert.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, s.Stop)
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return s.Stop(testutil.ContextWithTimeout(t, testTimeout))
})
w := httptest.NewRecorder()
@ -297,7 +301,7 @@ func TestDNSForwardHTTP_handleSetConfig(t *testing.T) {
assert.Equal(t, tc.wantSet, strings.TrimSuffix(w.Body.String(), "\n"))
w.Body.Reset()
s.handleGetConfig(w, nil)
s.handleGetConfig(w, httptest.NewRequest(http.MethodGet, "/", nil))
assert.JSONEq(t, string(caseData.Want), w.Body.String())
w.Body.Reset()
})

View File

@ -121,9 +121,7 @@ func ipsFromAnswer(ans []dns.RR) (ip4s, ip6s []net.IP) {
}
// process adds the resolved IP addresses to the domain's ipsets, if any.
func (h *ipsetHandler) process(dctx *dnsContext) (rc resultCode) {
// TODO(s.chzhen): Use passed context.
ctx := context.TODO()
func (h *ipsetHandler) process(ctx context.Context, dctx *dnsContext) (rc resultCode) {
h.logger.DebugContext(ctx, "started processing")
defer h.logger.DebugContext(ctx, "finished processing")

View File

@ -6,6 +6,7 @@ import (
"testing"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)
@ -62,7 +63,7 @@ func TestIpsetCtx_process(t *testing.T) {
ictx := &ipsetHandler{
logger: testLogger,
}
rc := ictx.process(dctx)
rc := ictx.process(testutil.ContextWithTimeout(t, testTimeout), dctx)
assert.Equal(t, resultCodeSuccess, rc)
err := ictx.close()
@ -85,7 +86,7 @@ func TestIpsetCtx_process(t *testing.T) {
logger: testLogger,
}
rc := ictx.process(dctx)
rc := ictx.process(testutil.ContextWithTimeout(t, testTimeout), dctx)
assert.Equal(t, resultCodeSuccess, rc)
assert.Equal(t, []net.IP{ip4}, m.ip4s)
assert.Empty(t, m.ip6s)
@ -110,7 +111,7 @@ func TestIpsetCtx_process(t *testing.T) {
logger: testLogger,
}
rc := ictx.process(dctx)
rc := ictx.process(testutil.ContextWithTimeout(t, testTimeout), dctx)
assert.Equal(t, resultCodeSuccess, rc)
assert.Empty(t, m.ip4s)
assert.Equal(t, []net.IP{ip6}, m.ip6s)

View File

@ -48,6 +48,7 @@ func ipsFromRules(resRules []*filtering.ResultRule) (ips []netip.Addr) {
// genDNSFilterMessage generates a filtered response to req for the filtering
// result res.
func (s *Server) genDNSFilterMessage(
ctx context.Context,
dctx *proxy.DNSContext,
res *filtering.Result,
) (resp *dns.Msg) {
@ -64,22 +65,27 @@ func (s *Server) genDNSFilterMessage(
switch res.Reason {
case filtering.FilteredSafeBrowsing:
return s.genBlockedHost(req, s.dnsFilter.SafeBrowsingBlockHost(), dctx)
return s.genBlockedHost(ctx, req, s.dnsFilter.SafeBrowsingBlockHost(), dctx)
case filtering.FilteredParental:
return s.genBlockedHost(req, s.dnsFilter.ParentalBlockHost(), dctx)
return s.genBlockedHost(ctx, req, s.dnsFilter.ParentalBlockHost(), dctx)
case filtering.FilteredSafeSearch:
// If Safe Search generated the necessary IP addresses, use them.
// Otherwise, if there were no errors, there are no addresses for the
// requested IP version, so produce a NODATA response.
return s.getCNAMEWithIPs(req, ipsFromRules(res.Rules), res.CanonName)
return s.getCNAMEWithIPs(ctx, req, ipsFromRules(res.Rules), res.CanonName)
default:
return s.genForBlockingMode(req, ipsFromRules(res.Rules))
return s.genForBlockingMode(ctx, req, ipsFromRules(res.Rules))
}
}
// getCNAMEWithIPs generates a filtered response to req for with CNAME record
// and provided ips.
func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) (resp *dns.Msg) {
func (s *Server) getCNAMEWithIPs(
ctx context.Context,
req *dns.Msg,
ips []netip.Addr,
cname string,
) (resp *dns.Msg) {
resp = s.replyCompressed(req)
originalName := req.Question[0].Name
@ -95,7 +101,7 @@ func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) (
switch req.Question[0].Qtype {
case dns.TypeA:
ans = append(ans, s.genAnswersWithIPv4s(req, ips)...)
ans = append(ans, s.genAnswersWithIPv4s(ctx, req, ips)...)
case dns.TypeAAAA:
for _, ip := range ips {
if ip.Is6() {
@ -113,25 +119,28 @@ func (s *Server) getCNAMEWithIPs(req *dns.Msg, ips []netip.Addr, cname string) (
// genForBlockingMode generates a filtered response to req based on the server's
// blocking mode.
func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.Msg) {
func (s *Server) genForBlockingMode(
ctx context.Context,
req *dns.Msg,
ips []netip.Addr,
) (resp *dns.Msg) {
switch mode, bIPv4, bIPv6 := s.dnsFilter.BlockingMode(); mode {
case filtering.BlockingModeCustomIP:
return s.makeResponseCustomIP(req, bIPv4, bIPv6)
return s.makeResponseCustomIP(ctx, req, bIPv4, bIPv6)
case filtering.BlockingModeDefault:
if len(ips) > 0 {
return s.genResponseWithIPs(req, ips)
return s.genResponseWithIPs(ctx, req, ips)
}
return s.makeResponseNullIP(req)
return s.makeResponseNullIP(ctx, req)
case filtering.BlockingModeNullIP:
return s.makeResponseNullIP(req)
return s.makeResponseNullIP(ctx, req)
case filtering.BlockingModeNXDOMAIN:
return s.NewMsgNXDOMAIN(req)
case filtering.BlockingModeREFUSED:
return s.makeResponseREFUSED(req)
default:
// TODO(s.chzhen): Pass context.
s.logger.ErrorContext(context.TODO(), "invalid blocking mode", "mode", mode)
s.logger.ErrorContext(ctx, "invalid blocking mode", "mode", mode)
return s.replyCompressed(req)
}
@ -140,6 +149,7 @@ func (s *Server) genForBlockingMode(req *dns.Msg, ips []netip.Addr) (resp *dns.M
// makeResponseCustomIP generates a DNS response message for Custom IP blocking
// mode with the provided IP addresses and an appropriate resource record type.
func (s *Server) makeResponseCustomIP(
ctx context.Context,
req *dns.Msg,
bIPv4 netip.Addr,
bIPv6 netip.Addr,
@ -152,10 +162,8 @@ func (s *Server) makeResponseCustomIP(
default:
// Generally shouldn't happen, since the types are checked in
// genDNSFilterMessage.
//
// TODO(s.chzhen): Pass context.
s.logger.ErrorContext(
context.TODO(),
ctx,
"invalid message type for custom IP blocking mode",
"dns_type", dns.Type(qt),
)
@ -242,11 +250,15 @@ func (s *Server) genAnswerTXT(req *dns.Msg, strs []string) (ans *dns.TXT) {
// addresses and an appropriate resource record type. If any of the IPs cannot
// be converted to the correct protocol, genResponseWithIPs returns an empty
// response.
func (s *Server) genResponseWithIPs(req *dns.Msg, ips []netip.Addr) (resp *dns.Msg) {
func (s *Server) genResponseWithIPs(
ctx context.Context,
req *dns.Msg,
ips []netip.Addr,
) (resp *dns.Msg) {
var ans []dns.RR
switch req.Question[0].Qtype {
case dns.TypeA:
ans = s.genAnswersWithIPv4s(req, ips)
ans = s.genAnswersWithIPv4s(ctx, req, ips)
case dns.TypeAAAA:
for _, ip := range ips {
if ip.Is6() {
@ -266,11 +278,14 @@ func (s *Server) genResponseWithIPs(req *dns.Msg, ips []netip.Addr) (resp *dns.M
// genAnswersWithIPv4s generates DNS A answers provided IPv4 addresses. If any
// of the IPs isn't an IPv4 address, genAnswersWithIPv4s logs a warning and
// returns nil,
func (s *Server) genAnswersWithIPv4s(req *dns.Msg, ips []netip.Addr) (ans []dns.RR) {
func (s *Server) genAnswersWithIPv4s(
ctx context.Context,
req *dns.Msg,
ips []netip.Addr,
) (ans []dns.RR) {
for _, ip := range ips {
if !ip.Is4() {
// TODO(s.chzhen): Pass context.
s.logger.WarnContext(context.TODO(), "ip is not an ipv4 address", "ip", ip)
s.logger.WarnContext(ctx, "ip is not an ipv4 address", "ip", ip)
return nil
}
@ -283,16 +298,16 @@ func (s *Server) genAnswersWithIPv4s(req *dns.Msg, ips []netip.Addr) (ans []dns.
// makeResponseNullIP creates a response with 0.0.0.0 for A requests, :: for
// AAAA requests, and an empty response for other types.
func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) {
func (s *Server) makeResponseNullIP(ctx context.Context, req *dns.Msg) (resp *dns.Msg) {
// Respond with the corresponding zero IP type as opposed to simply
// using one or the other in both cases, because the IPv4 zero IP is
// converted to a IPV6-mapped IPv4 address, while the IPv6 zero IP is
// converted into an empty slice instead of the zero IPv4.
switch req.Question[0].Qtype {
case dns.TypeA:
resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv4Unspecified()})
resp = s.genResponseWithIPs(ctx, req, []netip.Addr{netip.IPv4Unspecified()})
case dns.TypeAAAA:
resp = s.genResponseWithIPs(req, []netip.Addr{netip.IPv6Unspecified()})
resp = s.genResponseWithIPs(ctx, req, []netip.Addr{netip.IPv6Unspecified()})
default:
resp = s.replyCompressed(req)
}
@ -300,10 +315,12 @@ func (s *Server) makeResponseNullIP(req *dns.Msg) (resp *dns.Msg) {
return resp
}
func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSContext) *dns.Msg {
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
func (s *Server) genBlockedHost(
ctx context.Context,
request *dns.Msg,
newAddr string,
d *proxy.DNSContext,
) (msg *dns.Msg) {
if newAddr == "" {
s.logger.InfoContext(ctx, "block host not specified")
@ -312,7 +329,7 @@ func (s *Server) genBlockedHost(request *dns.Msg, newAddr string, d *proxy.DNSCo
ip, err := netip.ParseAddr(newAddr)
if err == nil {
return s.genResponseWithIPs(request, []netip.Addr{ip})
return s.genResponseWithIPs(ctx, request, []netip.Addr{ip})
}
// look up the hostname, TODO: cache

View File

@ -82,13 +82,16 @@ const ddrHostFQDN = "_dns.resolver.arpa."
// handleDNSRequest filters the incoming DNS requests and writes them to the query log
func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error {
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
dctx := &dnsContext{
proxyCtx: pctx,
result: &filtering.Result{},
startTime: time.Now(),
}
type modProcessFunc func(ctx *dnsContext) (rc resultCode)
type modProcessFunc func(ctx context.Context, dctx *dnsContext) (rc resultCode)
// Since (*dnsforward.Server).handleDNSRequest(...) is used as
// proxy.(Config).RequestHandler, there is no need for additional index
@ -107,7 +110,7 @@ func (s *Server) handleDNSRequest(_ *proxy.Proxy, pctx *proxy.DNSContext) error
s.processQueryLogsAndStats,
}
for _, process := range mods {
r := process(dctx)
r := process(ctx, dctx)
switch r {
case resultCodeSuccess:
// continue: call the next filter
@ -148,14 +151,12 @@ const healthcheckFQDN = "healthcheck.adguardhome.test."
// needed and enriches dctx with some client-specific information.
//
// TODO(e.burkov): Decompose into less general processors.
func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
func (s *Server) processInitial(ctx context.Context, dctx *dnsContext) (rc resultCode) {
s.logger.DebugContext(ctx, "started processing initial")
defer s.logger.DebugContext(ctx, "finished processing initial")
pctx := dctx.proxyCtx
s.processClientIP(pctx.Addr.Addr())
s.processClientIP(ctx, pctx.Addr.Addr())
q := pctx.Req.Question[0]
qt := q.Qtype
@ -185,16 +186,14 @@ func (s *Server) processInitial(dctx *dnsContext) (rc resultCode) {
dctx.clientID = string(s.clientIDCache.Get(key[:]))
// Get the client-specific filtering settings.
dctx.protectionEnabled, _ = s.UpdatedProtectionStatus()
dctx.protectionEnabled, _ = s.UpdatedProtectionStatus(ctx)
dctx.setts = s.clientRequestFilteringSettings(dctx)
return resultCodeSuccess
}
// processClientIP sends the client IP address to s.addrProc, if needed.
func (s *Server) processClientIP(addr netip.Addr) {
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
func (s *Server) processClientIP(ctx context.Context, addr netip.Addr) {
if !addr.IsValid() {
s.logger.WarnContext(ctx, "bad client address", "addr", addr)
@ -214,9 +213,7 @@ func (s *Server) processClientIP(addr netip.Addr) {
// current user configuration.
//
// See https://www.ietf.org/archive/id/draft-ietf-add-ddr-10.html.
func (s *Server) processDDRQuery(dctx *dnsContext) (rc resultCode) {
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
func (s *Server) processDDRQuery(ctx context.Context, dctx *dnsContext) (rc resultCode) {
s.logger.DebugContext(ctx, "started processing ddr")
defer s.logger.DebugContext(ctx, "finished processing ddr")
@ -315,9 +312,7 @@ func (s *Server) makeDDRResponse(req *dns.Msg) (resp *dns.Msg) {
// the request is for AAAA.
//
// TODO(a.garipov): Adapt to AAAA as well.
func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
func (s *Server) processDHCPHosts(ctx context.Context, dctx *dnsContext) (rc resultCode) {
s.logger.DebugContext(ctx, "started processing dhcp hosts")
defer s.logger.DebugContext(ctx, "finished processing dhcp hosts")
@ -383,9 +378,7 @@ func (s *Server) processDHCPHosts(dctx *dnsContext) (rc resultCode) {
// processDHCPAddrs responds to PTR requests if the target IP is leased by the
// DHCP server.
func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
func (s *Server) processDHCPAddrs(ctx context.Context, dctx *dnsContext) (rc resultCode) {
s.logger.DebugContext(ctx, "started processing dhcp addrs")
defer s.logger.DebugContext(ctx, "finished processing dhcp addrs")
@ -430,9 +423,10 @@ func (s *Server) processDHCPAddrs(dctx *dnsContext) (rc resultCode) {
}
// Apply filtering logic
func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode) {
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
func (s *Server) processFilteringBeforeRequest(
ctx context.Context,
dctx *dnsContext,
) (rc resultCode) {
s.logger.DebugContext(ctx, "started processing filtering before request")
defer s.logger.DebugContext(ctx, "finished processing filtering before request")
@ -454,7 +448,7 @@ func (s *Server) processFilteringBeforeRequest(dctx *dnsContext) (rc resultCode)
defer s.serverLock.RUnlock()
var err error
if dctx.result, err = s.filterDNSRequest(dctx); err != nil {
if dctx.result, err = s.filterDNSRequest(ctx, dctx); err != nil {
dctx.err = err
return resultCodeError
@ -473,9 +467,7 @@ func ipStringFromAddr(addr net.Addr) (ipStr string) {
}
// processUpstream passes request to upstream servers and handles the response.
func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
func (s *Server) processUpstream(ctx context.Context, dctx *dnsContext) (rc resultCode) {
s.logger.DebugContext(ctx, "started processing upstream")
defer s.logger.DebugContext(ctx, "finished processing upstream")
@ -502,7 +494,7 @@ func (s *Server) processUpstream(dctx *dnsContext) (rc resultCode) {
return resultCodeFinish
}
s.setCustomUpstream(pctx, dctx.clientID)
s.setCustomUpstream(ctx, pctx, dctx.clientID)
reqWantsDNSSEC := s.setReqAD(req)
@ -592,7 +584,7 @@ func (s *Server) dhcpHostFromRequest(q *dns.Question) (reqHost string) {
}
// setCustomUpstream sets custom upstream settings in pctx, if necessary.
func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
func (s *Server) setCustomUpstream(ctx context.Context, pctx *proxy.DNSContext, clientID string) {
if !pctx.Addr.IsValid() || s.conf.ClientsContainer == nil {
return
}
@ -600,9 +592,8 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
cliAddr := pctx.Addr.Addr()
upsConf := s.conf.ClientsContainer.CustomUpstreamConfig(clientID, cliAddr)
if upsConf != nil {
// TODO(s.chzhen): Pass context.
s.logger.DebugContext(
context.TODO(),
ctx,
"using custom upstreams for client with",
"ip", cliAddr,
"client_id", clientID,
@ -613,9 +604,7 @@ func (s *Server) setCustomUpstream(pctx *proxy.DNSContext, clientID string) {
}
// Apply filtering logic after we have received response from upstream servers
func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode) {
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
func (s *Server) processFilteringAfterResponse(ctx context.Context, dctx *dnsContext) (rc resultCode) {
s.logger.DebugContext(ctx, "started processing filtering after response")
defer s.logger.DebugContext(ctx, "finished processing filtering after response")
@ -642,13 +631,13 @@ func (s *Server) processFilteringAfterResponse(dctx *dnsContext) (rc resultCode)
return resultCodeSuccess
default:
return s.filterAfterResponse(dctx)
return s.filterAfterResponse(ctx, dctx)
}
}
// filterAfterResponse returns the result of filtering the response that wasn't
// explicitly allowed or rewritten.
func (s *Server) filterAfterResponse(dctx *dnsContext) (res resultCode) {
func (s *Server) filterAfterResponse(ctx context.Context, dctx *dnsContext) (res resultCode) {
// Check the response only if it's from an upstream. Don't check the
// response if the protection is disabled since dnsrewrite rules aren't
// applied to it anyway.
@ -656,7 +645,7 @@ func (s *Server) filterAfterResponse(dctx *dnsContext) (res resultCode) {
return resultCodeSuccess
}
err := s.filterDNSResponse(dctx)
err := s.filterDNSResponse(ctx, dctx)
if err != nil {
dctx.err = err

View File

@ -105,7 +105,7 @@ func TestServer_ProcessInitial(t *testing.T) {
},
}
gotRC := s.processInitial(dctx)
gotRC := s.processInitial(testutil.ContextWithTimeout(t, testTimeout), dctx)
assert.Equal(t, tc.wantRC, gotRC)
assert.Equal(t, testClientAddrPort.Addr(), gotAddr)
@ -208,8 +208,8 @@ func TestServer_ProcessFilteringAfterResponse(t *testing.T) {
Addr: testClientAddrPort,
},
}
gotRC := s.processFilteringAfterResponse(dctx)
ctx := testutil.ContextWithTimeout(t, testTimeout)
gotRC := s.processFilteringAfterResponse(ctx, dctx)
assert.Equal(t, tc.wantRC, gotRC)
assert.Equal(t, newResp(dns.RcodeSuccess, tc.req, tc.wantRespAns), dctx.proxyCtx.Res)
})
@ -353,7 +353,7 @@ func TestServer_ProcessDDRQuery(t *testing.T) {
},
}
res := s.processDDRQuery(dctx)
res := s.processDDRQuery(testutil.ContextWithTimeout(t, testTimeout), dctx)
require.Equal(t, tc.wantRes, res)
if tc.wantRes != resultCodeFinish {
@ -461,7 +461,7 @@ func TestServer_ProcessDHCPHosts_localRestriction(t *testing.T) {
},
}
res := s.processDHCPHosts(dctx)
res := s.processDHCPHosts(testutil.ContextWithTimeout(t, testTimeout), dctx)
pctx := dctx.proxyCtx
if !tc.isLocalCli {
@ -606,7 +606,7 @@ func TestServer_ProcessDHCPHosts(t *testing.T) {
}
t.Run(tc.name, func(t *testing.T) {
res := s.processDHCPHosts(dctx)
res := s.processDHCPHosts(testutil.ContextWithTimeout(t, testTimeout), dctx)
pctx := dctx.proxyCtx
assert.Equal(t, tc.wantRes, res)
require.NoError(t, dctx.err)
@ -814,9 +814,9 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) {
ServePlainDNS: true,
},
)
ctx := testutil.ContextWithTimeout(t, testTimeout)
pctx := newPrxCtx()
rc := s.processUpstream(&dnsContext{proxyCtx: pctx})
rc := s.processUpstream(ctx, &dnsContext{proxyCtx: pctx})
require.Equal(t, resultCodeSuccess, rc)
require.NotEmpty(t, pctx.Res.Answer)
ptr := testutil.RequireTypeAssert[*dns.PTR](t, pctx.Res.Answer[0])
@ -846,7 +846,8 @@ func TestServer_ProcessUpstream_localPTR(t *testing.T) {
)
pctx := newPrxCtx()
rc := s.processUpstream(&dnsContext{proxyCtx: pctx})
ctx := testutil.ContextWithTimeout(t, testTimeout)
rc := s.processUpstream(ctx, &dnsContext{proxyCtx: pctx})
require.Equal(t, resultCodeError, rc)
require.Empty(t, pctx.Res.Answer)
})

View File

@ -14,9 +14,7 @@ import (
)
// Write Stats data and logs
func (s *Server) processQueryLogsAndStats(dctx *dnsContext) (rc resultCode) {
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
func (s *Server) processQueryLogsAndStats(ctx context.Context, dctx *dnsContext) (rc resultCode) {
s.logger.DebugContext(ctx, "started processing querylog and stats")
defer s.logger.DebugContext(ctx, "finished processing querylog and stats")

View File

@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/stats"
"github.com/AdguardTeam/dnsproxy/proxy"
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/testutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -230,7 +231,7 @@ func TestServer_ProcessQueryLogsAndStats(t *testing.T) {
clientID: tc.clientID,
}
code := srv.processQueryLogsAndStats(dctx)
code := srv.processQueryLogsAndStats(testutil.ContextWithTimeout(t, testTimeout), dctx)
assert.Equal(t, tc.wantCode, code)
assert.Equal(t, tc.wantLogProto, ql.lastParams.ClientProto)
assert.Equal(t, tc.wantStatClient, st.lastEntry.Client)

View File

@ -15,9 +15,9 @@ import (
//
// See the comment on genAnswerSVCB for a list of current restrictions on
// parameter values.
func (s *Server) genAnswerHTTPS(req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.HTTPS) {
func (s *Server) genAnswerHTTPS(ctx context.Context, req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.HTTPS) {
ans = &dns.HTTPS{
SVCB: *s.genAnswerSVCB(req, svcb),
SVCB: *s.genAnswerSVCB(ctx, req, svcb),
}
ans.Hdr.Rrtype = dns.TypeHTTPS
@ -164,7 +164,11 @@ var svcbKeyHandlers = map[string]svcbKeyHandler{
// ipv4hint="127.0.0.1,127.0.0.2" // Unsupported.
//
// TODO(a.garipov): Support all of these.
func (s *Server) genAnswerSVCB(req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.SVCB) {
func (s *Server) genAnswerSVCB(
ctx context.Context,
req *dns.Msg,
svcb *rules.DNSSVCB,
) (ans *dns.SVCB) {
ans = &dns.SVCB{
Hdr: s.hdr(req, dns.TypeSVCB),
Priority: svcb.Priority,
@ -178,8 +182,7 @@ func (s *Server) genAnswerSVCB(req *dns.Msg, svcb *rules.DNSSVCB) (ans *dns.SVCB
for k, valStr := range svcb.Params {
handler, ok := svcbKeyHandlers[k]
if !ok {
// TODO(s.chzhen): Pass context.
s.logger.DebugContext(context.TODO(), "unknown svcb/https key, ignoring", "key", k)
s.logger.DebugContext(ctx, "unknown svcb/https key, ignoring", "key", k)
continue
}

View File

@ -5,6 +5,7 @@ import (
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@ -152,14 +153,14 @@ func TestGenAnswerHTTPS_andSVCB(t *testing.T) {
want := &dns.HTTPS{SVCB: *tc.want}
want.Hdr.Rrtype = dns.TypeHTTPS
got := s.genAnswerHTTPS(req, tc.svcb)
got := s.genAnswerHTTPS(testutil.ContextWithTimeout(t, testTimeout), req, tc.svcb)
assert.Equal(t, want, got)
})
})
t.Run("svcb", func(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
got := s.genAnswerSVCB(req, tc.svcb)
got := s.genAnswerSVCB(testutil.ContextWithTimeout(t, testTimeout), req, tc.svcb)
assert.Equal(t, tc.want, got)
})
})

View File

@ -115,6 +115,8 @@ type statusResponse struct {
}
func (web *webAPI) handleStatus(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
dnsAddrs, err := collectDNSAddresses(web.tlsManager)
if err != nil {
// Don't add a lot of formatting, since the error is already
@ -125,14 +127,14 @@ func (web *webAPI) handleStatus(w http.ResponseWriter, r *http.Request) {
}
var (
fltConf *dnsforward.Config
protectionDisabledUntil *time.Time
protectionEnabled bool
fltConf *dnsforward.Config
protDisabledUntil *time.Time
protEnabled bool
)
if globalContext.dnsServer != nil {
fltConf = &dnsforward.Config{}
globalContext.dnsServer.WriteDiskConfig(fltConf)
protectionEnabled, protectionDisabledUntil = globalContext.dnsServer.UpdatedProtectionStatus()
protEnabled, protDisabledUntil = globalContext.dnsServer.UpdatedProtectionStatus(ctx)
}
var resp statusResponse
@ -141,11 +143,11 @@ func (web *webAPI) handleStatus(w http.ResponseWriter, r *http.Request) {
defer config.RUnlock()
var protectionDisabledDuration int64
if protectionDisabledUntil != nil {
if protDisabledUntil != nil {
// Make sure that we don't send negative numbers to the frontend,
// since enough time might have passed to make the difference less
// than zero.
protectionDisabledDuration = max(0, time.Until(*protectionDisabledUntil).Milliseconds())
protectionDisabledDuration = max(0, time.Until(*protDisabledUntil).Milliseconds())
}
resp = statusResponse{
@ -155,7 +157,7 @@ func (web *webAPI) handleStatus(w http.ResponseWriter, r *http.Request) {
DNSPort: config.DNS.Port,
HTTPPort: config.HTTPConfig.Address.Port(),
ProtectionDisabledDuration: protectionDisabledDuration,
ProtectionEnabled: protectionEnabled,
ProtectionEnabled: protEnabled,
IsRunning: isRunning(),
}
}()

View File

@ -544,7 +544,7 @@ func startMods(
return err
}
err = initDNS(baseLogger, tlsMgr, confModifier, statsDir, querylogDir)
err = initDNS(ctx, baseLogger, tlsMgr, confModifier, statsDir, querylogDir)
if err != nil {
return err
}
@ -553,7 +553,7 @@ func startMods(
err = startDNSServer()
if err != nil {
closeDNSServer()
closeDNSServer(ctx)
return err
}

View File

@ -44,6 +44,7 @@ const (
// [config] and [globalContext] are initialized. baseLogger, tlsMgr and
// confModfier must not be nil.
func initDNS(
ctx context.Context,
baseLogger *slog.Logger,
tlsMgr *tlsManager,
confModifier agh.ConfigModifier,
@ -105,6 +106,7 @@ func initDNS(
}
return initDNSServer(
ctx,
globalContext.filters,
globalContext.stats,
globalContext.queryLog,
@ -124,6 +126,7 @@ func initDNS(
//
// TODO(e.burkov): Use [dnsforward.DNSCreateParams] as a parameter.
func initDNSServer(
ctx context.Context,
filters *filtering.DNSFilter,
sts stats.Interface,
qlog querylog.QueryLog,
@ -147,7 +150,7 @@ func initDNSServer(
})
defer func() {
if err != nil {
closeDNSServer()
closeDNSServer(ctx)
}
}()
if err != nil {
@ -171,12 +174,12 @@ func initDNSServer(
// Try to prepare the server with disabled private RDNS resolution if it
// failed to prepare as is. See TODO on [dnsforward.PrivateRDNSError].
err = globalContext.dnsServer.Prepare(dnsConf)
err = globalContext.dnsServer.Prepare(ctx, dnsConf)
if privRDNSErr := (&dnsforward.PrivateRDNSError{}); errors.As(err, &privRDNSErr) {
log.Info("WARNING: %s; trying to disable private RDNS resolution", err)
dnsConf.UsePrivateRDNS = false
err = globalContext.dnsServer.Prepare(dnsConf)
err = globalContext.dnsServer.Prepare(ctx, dnsConf)
}
if err != nil {
@ -450,7 +453,7 @@ func startDNSServer() error {
return fmt.Errorf("starting clients container: %w", err)
}
err = globalContext.dnsServer.Start()
err = globalContext.dnsServer.Start(ctx)
if err != nil {
return fmt.Errorf("starting dns server: %w", err)
}
@ -466,30 +469,30 @@ func startDNSServer() error {
return nil
}
func stopDNSServer() (err error) {
func stopDNSServer(ctx context.Context) (err error) {
if !isRunning() {
return nil
}
err = globalContext.dnsServer.Stop()
err = globalContext.dnsServer.Stop(ctx)
if err != nil {
return fmt.Errorf("stopping forwarding dns server: %w", err)
}
err = globalContext.clients.close(context.TODO())
err = globalContext.clients.close(ctx)
if err != nil {
return fmt.Errorf("closing clients container: %w", err)
}
closeDNSServer()
closeDNSServer(ctx)
return nil
}
func closeDNSServer() {
func closeDNSServer(ctx context.Context) {
// DNS forward module must be closed BEFORE stats or queryLog because it depends on them
if globalContext.dnsServer != nil {
globalContext.dnsServer.Close()
globalContext.dnsServer.Close(ctx)
globalContext.dnsServer = nil
}
@ -505,8 +508,7 @@ func closeDNSServer() {
}
if globalContext.queryLog != nil {
// TODO(s.chzhen): Pass context.
err := globalContext.queryLog.Shutdown(context.TODO())
err := globalContext.queryLog.Shutdown(ctx)
if err != nil {
log.Error("closing query log: %s", err)
}

View File

@ -731,7 +731,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH
fatalOnError(err)
if !globalContext.firstRun {
err = initDNS(slogLogger, tlsMgr, confModifier, statsDir, querylogDir)
err = initDNS(ctx, slogLogger, tlsMgr, confModifier, statsDir, querylogDir)
fatalOnError(err)
tlsMgr.start(ctx)
@ -739,7 +739,7 @@ func run(opts options, clientBuildFS fs.FS, done chan struct{}, sigHdlr *signalH
go func() {
startErr := startDNSServer()
if startErr != nil {
closeDNSServer()
closeDNSServer(ctx)
fatalOnError(startErr)
}
}()
@ -975,7 +975,7 @@ func cleanup(ctx context.Context) {
globalContext.web = nil
}
err := stopDNSServer()
err := stopDNSServer(ctx)
if err != nil {
log.Error("stopping dns server: %s", err)
}
@ -1133,7 +1133,7 @@ func cmdlineUpdate(
//
// TODO(e.burkov): We could probably initialize the internal resolver
// separately.
err := initDNSServer(nil, nil, nil, nil, nil, nil, tlsMgr, l, agh.EmptyConfigModifier{})
err := initDNSServer(ctx, nil, nil, nil, nil, nil, nil, tlsMgr, l, agh.EmptyConfigModifier{})
fatalOnError(err)
l.InfoContext(ctx, "performing update via cli")

View File

@ -232,7 +232,7 @@ func (m *tlsManager) reload(ctx context.Context) {
m.certLastMod = fi.ModTime().UTC()
err = m.reconfigureDNSServer()
err = m.reconfigureDNSServer(ctx)
if err != nil {
m.logger.ErrorContext(ctx, "reconfiguring dns server", slogutil.KeyError, err)
}
@ -245,7 +245,7 @@ func (m *tlsManager) reload(ctx context.Context) {
// reconfigureDNSServer updates the DNS server configuration using the stored
// TLS settings. m.mu is expected to be locked.
func (m *tlsManager) reconfigureDNSServer() (err error) {
func (m *tlsManager) reconfigureDNSServer(ctx context.Context) (err error) {
newConf, err := newServerConfig(
&config.DNS,
config.Clients.Sources,
@ -259,7 +259,7 @@ func (m *tlsManager) reconfigureDNSServer() (err error) {
return fmt.Errorf("generating forwarding dns server config: %w", err)
}
err = globalContext.dnsServer.Reconfigure(newConf)
err = globalContext.dnsServer.Reconfigure(ctx, newConf)
if err != nil {
return fmt.Errorf("starting forwarding dns server: %w", err)
}
@ -558,7 +558,7 @@ func (m *tlsManager) handleTLSConfigure(w http.ResponseWriter, r *http.Request)
}()
}
err = m.reconfigureDNSServer()
err = m.reconfigureDNSServer(ctx)
if err != nil {
m.logger.ErrorContext(ctx, "reconfiguring dns server", slogutil.KeyError, err)

View File

@ -273,7 +273,9 @@ func TestTLSManager_Reload(t *testing.T) {
// The [tlsManager.reload] method will start the DNS server and it should be
// stopped after the test ends.
testutil.CleanupAndRequireSuccess(t, globalContext.dnsServer.Stop)
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return globalContext.dnsServer.Stop(testutil.ContextWithTimeout(t, testTimeout))
})
conf = m.config()
assertCertSerialNumber(t, conf, snAfter)
@ -477,15 +479,17 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
})
require.NoError(t, err)
err = globalContext.dnsServer.Prepare(&dnsforward.ServerConfig{
TLSConf: &dnsforward.TLSConfig{},
Config: dnsforward.Config{
UpstreamMode: dnsforward.UpstreamModeLoadBalance,
EDNSClientSubnet: &dnsforward.EDNSClientSubnet{Enabled: false},
ClientsContainer: dnsforward.EmptyClientsContainer{},
},
ServePlainDNS: true,
})
err = globalContext.dnsServer.Prepare(
testutil.ContextWithTimeout(t, testTimeout),
&dnsforward.ServerConfig{
TLSConf: &dnsforward.TLSConfig{},
Config: dnsforward.Config{
UpstreamMode: dnsforward.UpstreamModeLoadBalance,
EDNSClientSubnet: &dnsforward.EDNSClientSubnet{Enabled: false},
ClientsContainer: dnsforward.EmptyClientsContainer{},
},
ServePlainDNS: true,
})
require.NoError(t, err)
globalContext.clients.storage, err = client.NewStorage(ctx, &client.StorageConfig{
@ -552,7 +556,9 @@ func TestTLSManager_HandleTLSConfigure(t *testing.T) {
// The [tlsManager.handleTLSConfigure] method will start the DNS server and
// it should be stopped after the test ends.
testutil.CleanupAndRequireSuccess(t, globalContext.dnsServer.Stop)
testutil.CleanupAndRequireSuccess(t, func() (err error) {
return globalContext.dnsServer.Stop(testutil.ContextWithTimeout(t, testTimeout))
})
res := &tlsConfig{
tlsConfigStatus: &tlsConfigStatus{},