diff --git a/internal/aghnet/dhcp.go b/internal/aghnet/dhcp.go index 327a5656..0c2e4d46 100644 --- a/internal/aghnet/dhcp.go +++ b/internal/aghnet/dhcp.go @@ -1,6 +1,16 @@ package aghnet -// CheckOtherDHCP tries to discover another DHCP server in the network. -func CheckOtherDHCP(ifaceName string) (ok4, ok6 bool, err4, err6 error) { - return checkOtherDHCP(ifaceName) +import ( + "context" + "log/slog" +) + +// CheckOtherDHCP tries to discover another DHCP server in the network. l must +// not be nil. +func CheckOtherDHCP( + ctx context.Context, + l *slog.Logger, + ifaceName string, +) (ok4, ok6 bool, err4, err6 error) { + return checkOtherDHCP(ctx, l, ifaceName) } diff --git a/internal/aghnet/dhcp_unix.go b/internal/aghnet/dhcp_unix.go index b75f40c4..06d421e0 100644 --- a/internal/aghnet/dhcp_unix.go +++ b/internal/aghnet/dhcp_unix.go @@ -4,14 +4,16 @@ package aghnet import ( "bytes" + "context" "fmt" + "log/slog" "net" "net/netip" "os" "time" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" "github.com/insomniacslk/dhcp/dhcpv4" "github.com/insomniacslk/dhcp/dhcpv6" @@ -23,7 +25,11 @@ import ( // response. const defaultDiscoverTime = 3 * time.Second -func checkOtherDHCP(ifaceName string) (ok4, ok6 bool, err4, err6 error) { +func checkOtherDHCP( + ctx context.Context, + l *slog.Logger, + ifaceName string, +) (ok4, ok6 bool, err4, err6 error) { iface, err := net.InterfaceByName(ifaceName) if err != nil { err = fmt.Errorf("couldn't find interface by name %s: %w", ifaceName, err) @@ -32,8 +38,8 @@ func checkOtherDHCP(ifaceName string) (ok4, ok6 bool, err4, err6 error) { return false, false, err4, err6 } - ok4, err4 = checkOtherDHCPv4(iface) - ok6, err6 = checkOtherDHCPv6(iface) + ok4, err4 = checkOtherDHCPv4(ctx, l, iface) + ok6, err6 = checkOtherDHCPv6(ctx, l, iface) return ok4, ok6, err4, err6 } @@ -68,8 +74,13 @@ func ifaceIPv4Subnet(iface *net.Interface) (subnet netip.Prefix, err error) { } // checkOtherDHCPv4 sends a DHCP request to the specified network interface, and -// waits for a response for a period defined by defaultDiscoverTime. -func checkOtherDHCPv4(iface *net.Interface) (ok bool, err error) { +// waits for a response for a period defined by defaultDiscoverTime. l must not +// be nil. +func checkOtherDHCPv4( + ctx context.Context, + l *slog.Logger, + iface *net.Interface, +) (ok bool, err error) { var subnet netip.Prefix if subnet, err = ifaceIPv4Subnet(iface); err != nil { return false, err @@ -87,10 +98,16 @@ func checkOtherDHCPv4(iface *net.Interface) (ok bool, err error) { return false, fmt.Errorf("couldn't get hostname: %w", err) } - return discover4(iface, dstAddr, hostname) + return discover4(ctx, l, iface, dstAddr, hostname) } -func discover4(iface *net.Interface, dstAddr *net.UDPAddr, hostname string) (ok bool, err error) { +func discover4( + ctx context.Context, + l *slog.Logger, + iface *net.Interface, + dstAddr *net.UDPAddr, + hostname string, +) (ok bool, err error) { var req *dhcpv4.DHCPv4 if req, err = dhcpv4.NewDiscovery(iface.HardwareAddr); err != nil { return false, fmt.Errorf("dhcpv4.NewDiscovery: %w", err) @@ -125,10 +142,10 @@ func discover4(iface *net.Interface, dstAddr *net.UDPAddr, hostname string) (ok } var next bool - ok, next, err = tryConn4(req, c, iface) + ok, next, err = tryConn4(ctx, l, req, c, iface) if next { if err != nil { - log.Debug("dhcpv4: trying a connection: %s", err) + l.DebugContext(ctx, "dhcpv4: trying a connection", slogutil.KeyError, err) } continue @@ -144,16 +161,22 @@ func discover4(iface *net.Interface, dstAddr *net.UDPAddr, hostname string) (ok // TODO(a.garipov): Refactor further. Inspect error handling, remove parameter // next, address the TODO, merge with tryConn6, etc. -func tryConn4(req *dhcpv4.DHCPv4, c net.PacketConn, iface *net.Interface) (ok, next bool, err error) { +func tryConn4( + ctx context.Context, + l *slog.Logger, + req *dhcpv4.DHCPv4, + c net.PacketConn, + iface *net.Interface, +) (ok, next bool, err error) { // TODO: replicate dhclient's behavior of retrying several times with // progressively longer timeouts. - log.Tracef("dhcpv4: waiting %v for an answer", defaultDiscoverTime) + l.Log(ctx, slogutil.LevelTrace, "dhcpv4: waiting for an answer", "timeout", defaultDiscoverTime) b := make([]byte, 1500) n, _, err := c.ReadFrom(b) if err != nil { if errors.Is(err, os.ErrDeadlineExceeded) { - log.Debug("dhcpv4: didn't receive dhcp response") + l.DebugContext(ctx, "dhcpv4: didn't receive dhcp response") return false, false, nil } @@ -161,16 +184,16 @@ func tryConn4(req *dhcpv4.DHCPv4, c net.PacketConn, iface *net.Interface) (ok, n return false, false, fmt.Errorf("receiving packet: %w", err) } - log.Tracef("dhcpv4: received packet, %d bytes", n) + l.Log(ctx, slogutil.LevelTrace, "dhcpv4: received packet", "size", n) response, err := dhcpv4.FromBytes(b[:n]) if err != nil { - log.Debug("dhcpv4: encoding: %s", err) + l.DebugContext(ctx, "dhcpv4: encoding", slogutil.KeyError, err) return false, true, err } - log.Debug("dhcpv4: received message from server: %s", response.Summary()) + l.DebugContext(ctx, "dhcpv4: received message from server", "summary", response.Summary()) switch { case @@ -179,19 +202,24 @@ func tryConn4(req *dhcpv4.DHCPv4, c net.PacketConn, iface *net.Interface) (ok, n !bytes.Equal(response.ClientHWAddr, iface.HardwareAddr), response.TransactionID != req.TransactionID, !response.Options.Has(dhcpv4.OptionDHCPMessageType): - log.Debug("dhcpv4: received response doesn't match the request") + l.DebugContext(ctx, "dhcpv4: received response does not match the request") return false, true, nil default: - log.Tracef("dhcpv4: the packet is from an active dhcp server") + l.Log(ctx, slogutil.LevelTrace, "dhcpv4: the packet is from an active dhcp server") return true, false, nil } } // checkOtherDHCPv6 sends a DHCP request to the specified network interface, and -// waits for a response for a period defined by defaultDiscoverTime. -func checkOtherDHCPv6(iface *net.Interface) (ok bool, err error) { +// waits for a response for a period defined by defaultDiscoverTime. l must not +// be nil. +func checkOtherDHCPv6( + ctx context.Context, + l *slog.Logger, + iface *net.Interface, +) (ok bool, err error) { ifaceIPNet, err := IfaceIPAddrs(iface, IPVersion6) if err != nil { return false, fmt.Errorf("getting ipv6 addrs for iface %s: %w", iface.Name, err) @@ -218,16 +246,22 @@ func checkOtherDHCPv6(iface *net.Interface) (ok bool, err error) { return false, fmt.Errorf("dhcpv6: Couldn't resolve UDP address %s: %w", dst, err) } - return discover6(iface, udpAddr, dstAddr) + return discover6(ctx, l, iface, udpAddr, dstAddr) } -func discover6(iface *net.Interface, udpAddr, dstAddr *net.UDPAddr) (ok bool, err error) { +func discover6( + ctx context.Context, + l *slog.Logger, + iface *net.Interface, + udpAddr *net.UDPAddr, + dstAddr *net.UDPAddr, +) (ok bool, err error) { req, err := dhcpv6.NewSolicit(iface.HardwareAddr) if err != nil { return false, fmt.Errorf("dhcpv6: dhcpv6.NewSolicit: %w", err) } - log.Debug("DHCPv6: Listening to udp6 %+v", udpAddr) + l.DebugContext(ctx, "dhcpv6: listening to udp6", "addr", udpAddr) c, err := nclient6.NewIPv6UDPConn(iface.Name, dhcpv6.DefaultClientPort) if err != nil { return false, fmt.Errorf("dhcpv6: Couldn't listen on :546: %w", err) @@ -241,10 +275,10 @@ func discover6(iface *net.Interface, udpAddr, dstAddr *net.UDPAddr) (ok bool, er for { var next bool - ok, next, err = tryConn6(req, c) + ok, next, err = tryConn6(ctx, l, req, c) if next { if err != nil { - log.Debug("dhcpv6: trying a connection: %s", err) + l.DebugContext(ctx, "dhcpv6: trying a connection", slogutil.KeyError, err) } continue @@ -259,10 +293,15 @@ func discover6(iface *net.Interface, udpAddr, dstAddr *net.UDPAddr) (ok bool, er } // TODO(a.garipov): See the comment on tryConn4. Sigh… -func tryConn6(req *dhcpv6.Message, c net.PacketConn) (ok, next bool, err error) { +func tryConn6( + ctx context.Context, + l *slog.Logger, + req *dhcpv6.Message, + c net.PacketConn, +) (ok, next bool, err error) { // TODO: replicate dhclient's behavior of retrying several times with // progressively longer timeouts. - log.Tracef("dhcpv6: waiting %v for an answer", defaultDiscoverTime) + l.Log(ctx, slogutil.LevelTrace, "dhcpv6: waiting for an answer", "timeout", defaultDiscoverTime) b := make([]byte, 4096) err = c.SetDeadline(time.Now().Add(defaultDiscoverTime)) @@ -273,7 +312,7 @@ func tryConn6(req *dhcpv6.Message, c net.PacketConn) (ok, next bool, err error) n, _, err := c.ReadFrom(b) if err != nil { if errors.Is(err, os.ErrDeadlineExceeded) { - log.Debug("dhcpv6: didn't receive dhcp response") + l.DebugContext(ctx, "dhcpv6: didn't receive dhcp response") return false, false, nil } @@ -281,21 +320,21 @@ func tryConn6(req *dhcpv6.Message, c net.PacketConn) (ok, next bool, err error) return false, false, fmt.Errorf("receiving packet: %w", err) } - log.Tracef("dhcpv6: received packet, %d bytes", n) + l.Log(ctx, slogutil.LevelTrace, "dhcpv6: received packet", "size", n) response, err := dhcpv6.FromBytes(b[:n]) if err != nil { - log.Debug("dhcpv6: encoding: %s", err) + l.DebugContext(ctx, "dhcpv6: encoding", slogutil.KeyError, err) return false, true, err } - log.Debug("dhcpv6: received message from server: %s", response.Summary()) + l.DebugContext(ctx, "dhcpv6: received message from server", "summary", response.Summary()) cid := req.Options.ClientID() msg, err := response.GetInnerMessage() if err != nil { - log.Debug("dhcpv6: resp.GetInnerMessage(): %s", err) + l.DebugContext(ctx, "getting inner message", slogutil.KeyError, err) return false, true, err } @@ -306,12 +345,12 @@ func tryConn6(req *dhcpv6.Message, c net.PacketConn) (ok, next bool, err error) rcid != nil && cid.Equal(rcid)) { - log.Debug("dhcpv6: received message from server doesn't match our request") + l.DebugContext(ctx, "dhcpv6: received message from server does not match our request") return false, true, nil } - log.Tracef("dhcpv6: the packet is from an active dhcp server") + l.Log(ctx, slogutil.LevelTrace, "dhcpv6: the packet is from an active dhcp server") return true, false, nil } diff --git a/internal/aghnet/dhcp_windows.go b/internal/aghnet/dhcp_windows.go index f8f6dbd2..e10c8154 100644 --- a/internal/aghnet/dhcp_windows.go +++ b/internal/aghnet/dhcp_windows.go @@ -2,9 +2,18 @@ package aghnet -import "github.com/AdguardTeam/AdGuardHome/internal/aghos" +import ( + "context" + "log/slog" -func checkOtherDHCP(ifaceName string) (ok4, ok6 bool, err4, err6 error) { + "github.com/AdguardTeam/AdGuardHome/internal/aghos" +) + +func checkOtherDHCP( + _ context.Context, + _ *slog.Logger, + ifaceName string, +) (ok4, ok6 bool, err4, err6 error) { return false, false, aghos.Unsupported("CheckIfOtherDHCPServersPresentV4"), diff --git a/internal/aghnet/hostscontainer.go b/internal/aghnet/hostscontainer.go index 14e2c5d4..b013b423 100644 --- a/internal/aghnet/hostscontainer.go +++ b/internal/aghnet/hostscontainer.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "io/fs" + "log/slog" "net/netip" "path" "sync/atomic" @@ -12,17 +13,21 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/hostsfile" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" ) -// hostsContainerPrefix is a prefix for logging and wrapping errors in -// HostsContainer's methods. +// hostsContainerPrefix is a prefix for wrapping errors in HostsContainer's +// methods. const hostsContainerPrefix = "hosts container" // HostsContainer stores the relevant hosts database provided by the OS and // processes both A/AAAA and PTR DNS requests for those. type HostsContainer struct { - // done is the channel to sign closing the container. + // logger is used for logging the operation of the hosts container. It must + // not be nil. + logger *slog.Logger + + // done is the channel to signal closing the container. done chan struct{} // updates is the channel for receiving updated hosts. @@ -31,10 +36,12 @@ type HostsContainer struct { // current is the last set of hosts parsed. current atomic.Pointer[hostsfile.DefaultStorage] - // fsys is the working file system to read hosts files from. + // fsys is the working file system to read hosts files from. It must not be + // nil. fsys fs.FS - // watcher tracks the changes in specified files and directories. + // watcher tracks the changes in specified files and directories. It must + // not be nil. watcher aghos.FSWatcher // patterns stores specified paths in the fs.Glob-compatible form. @@ -45,11 +52,12 @@ type HostsContainer struct { // the HostsContainer. const ErrNoHostsPaths errors.Error = "no valid paths to hosts files provided" -// NewHostsContainer creates a container of hosts, that watches the paths with -// w. listID is used as an identifier of the underlying rules list. paths -// shouldn't be empty and each of paths should locate either a file or a -// directory in fsys. fsys and w must be non-nil. +// NewHostsContainer creates a container of hosts that watches the paths with w. +// paths shouldn't be empty, and each path should refer either a file or a +// directory in fsys. l, fsys, and w must be non-nil. func NewHostsContainer( + ctx context.Context, + l *slog.Logger, fsys fs.FS, w aghos.FSWatcher, paths ...string, @@ -69,6 +77,7 @@ func NewHostsContainer( } hc = &HostsContainer{ + logger: l, done: make(chan struct{}, 1), updates: make(chan *hostsfile.DefaultStorage, 1), fsys: fsys, @@ -76,10 +85,10 @@ func NewHostsContainer( patterns: patterns, } - log.Debug("%s: starting", hostsContainerPrefix) + l.DebugContext(ctx, "starting") // Load initially. - if err = hc.refresh(); err != nil { + if err = hc.refresh(ctx); err != nil { return nil, err } @@ -89,11 +98,11 @@ func NewHostsContainer( return nil, fmt.Errorf("adding path: %w", err) } - log.Debug("%s: %s is expected to exist but doesn't", hostsContainerPrefix, p) + l.DebugContext(ctx, "expected path does not exist", "path", p) } } - go hc.handleEvents() + go hc.handleEvents(ctx) return hc, nil } @@ -101,10 +110,10 @@ func NewHostsContainer( // Close implements the [io.Closer] interface for *HostsContainer. It closes // both itself and its [aghos.FSWatcher]. Close must only be called once. func (hc *HostsContainer) Close() (err error) { - log.Debug("%s: closing", hostsContainerPrefix) - // TODO(s.chzhen): Pass context. ctx := context.TODO() + hc.logger.DebugContext(ctx, "closing") + err = errors.Annotate(hc.watcher.Shutdown(ctx), "closing fs watcher: %w") // Go on and close the container either way. @@ -159,8 +168,8 @@ func pathsToPatterns(fsys fs.FS, paths []string) (patterns []string, err error) // handleEvents concurrently handles the file system events. It closes the // update channel of HostsContainer when finishes. It is intended to be used as // a goroutine. -func (hc *HostsContainer) handleEvents() { - defer log.OnPanic(fmt.Sprintf("%s: handling events", hostsContainerPrefix)) +func (hc *HostsContainer) handleEvents(ctx context.Context) { + defer slogutil.RecoverAndLog(ctx, hc.logger) defer close(hc.updates) @@ -170,13 +179,13 @@ func (hc *HostsContainer) handleEvents() { select { case _, ok = <-eventsCh: if !ok { - log.Debug("%s: watcher closed the events channel", hostsContainerPrefix) + hc.logger.DebugContext(ctx, "watcher closed the events channel") continue } - if err := hc.refresh(); err != nil { - log.Error("%s: warning: refreshing: %s", hostsContainerPrefix, err) + if err := hc.refresh(ctx); err != nil { + hc.logger.ErrorContext(ctx, "refreshing", slogutil.KeyError, err) } case _, ok = <-hc.done: // Go on. @@ -185,8 +194,8 @@ func (hc *HostsContainer) handleEvents() { } // sendUpd tries to send the parsed data to the ch. -func (hc *HostsContainer) sendUpd(recs *hostsfile.DefaultStorage) { - log.Debug("%s: sending upd", hostsContainerPrefix) +func (hc *HostsContainer) sendUpd(ctx context.Context, recs *hostsfile.DefaultStorage) { + hc.logger.DebugContext(ctx, "sending update") ch := hc.updates select { @@ -194,11 +203,11 @@ func (hc *HostsContainer) sendUpd(recs *hostsfile.DefaultStorage) { // Updates are delivered. Go on. case <-ch: ch <- recs - log.Debug("%s: replaced the last update", hostsContainerPrefix) + hc.logger.DebugContext(ctx, "replaced the last update") case ch <- recs: // The previous update was just read and the next one pushed. Go on. default: - log.Error("%s: the updates channel is broken", hostsContainerPrefix) + hc.logger.ErrorContext(ctx, "updates channel is broken") } } @@ -206,8 +215,8 @@ func (hc *HostsContainer) sendUpd(recs *hostsfile.DefaultStorage) { // needed. // // TODO(e.burkov): Accept a parameter to specify the files to refresh. -func (hc *HostsContainer) refresh() (err error) { - log.Debug("%s: refreshing", hostsContainerPrefix) +func (hc *HostsContainer) refresh(ctx context.Context) (err error) { + hc.logger.DebugContext(ctx, "refreshing") // The error is always nil here since no readers passed. strg, _ := hostsfile.NewDefaultStorage() @@ -223,7 +232,7 @@ func (hc *HostsContainer) refresh() (err error) { // TODO(e.burkov): Serialize updates using [time.Time]. if !hc.current.Load().Equal(strg) { hc.current.Store(strg) - hc.sendUpd(strg) + hc.sendUpd(ctx, strg) } return nil diff --git a/internal/aghnet/hostscontainer_test.go b/internal/aghnet/hostscontainer_test.go index 47562f2e..09813200 100644 --- a/internal/aghnet/hostscontainer_test.go +++ b/internal/aghnet/hostscontainer_test.go @@ -13,12 +13,19 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/hostsfile" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +// testTimeout is a common timeout for tests. +const testTimeout = 1 * time.Second + +// testLogger is a logger used in tests. +var testLogger = slogutil.NewDiscardLogger() + func TestNewHostsContainer(t *testing.T) { const dirname = "dir" const filename = "file1" @@ -67,7 +74,8 @@ func TestNewHostsContainer(t *testing.T) { return eventsCh } - hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{ + ctx := testutil.ContextWithTimeout(t, testTimeout) + hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, &aghtest.FSWatcher{ OnStart: func(ctx context.Context) (_ error) { panic(testutil.UnexpectedCall(ctx)) }, @@ -96,7 +104,8 @@ func TestNewHostsContainer(t *testing.T) { t.Run("nil_fs", func(t *testing.T) { require.Panics(t, func() { - _, _ = aghnet.NewHostsContainer(nil, &aghtest.FSWatcher{ + ctx := testutil.ContextWithTimeout(t, testTimeout) + _, _ = aghnet.NewHostsContainer(ctx, testLogger, nil, &aghtest.FSWatcher{ OnStart: func(ctx context.Context) (_ error) { panic(testutil.UnexpectedCall(ctx)) }, @@ -110,7 +119,8 @@ func TestNewHostsContainer(t *testing.T) { t.Run("nil_watcher", func(t *testing.T) { require.Panics(t, func() { - _, _ = aghnet.NewHostsContainer(testFS, nil, p) + ctx := testutil.ContextWithTimeout(t, testTimeout) + _, _ = aghnet.NewHostsContainer(ctx, testLogger, testFS, nil, p) }) }) @@ -124,7 +134,8 @@ func TestNewHostsContainer(t *testing.T) { OnShutdown: func(_ context.Context) (err error) { return nil }, } - hc, err := aghnet.NewHostsContainer(testFS, errWatcher, p) + ctx := testutil.ContextWithTimeout(t, testTimeout) + hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, errWatcher, p) require.ErrorIs(t, err, errOnAdd) assert.Nil(t, hc) @@ -173,7 +184,8 @@ func TestHostsContainer_refresh(t *testing.T) { OnShutdown: func(_ context.Context) (err error) { return nil }, } - hc, err := aghnet.NewHostsContainer(testFS, w, "dir") + ctx := testutil.ContextWithTimeout(t, testTimeout) + hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, w, "dir") require.NoError(t, err) testutil.CleanupAndRequireSuccess(t, hc.Close) diff --git a/internal/aghnet/interfaces.go b/internal/aghnet/interfaces.go index a667a1f3..b6b925da 100644 --- a/internal/aghnet/interfaces.go +++ b/internal/aghnet/interfaces.go @@ -1,11 +1,11 @@ package aghnet import ( + "context" "fmt" + "log/slog" "net" "time" - - "github.com/AdguardTeam/golibs/log" ) // IPVersion is a alias for int for documentation purposes. Use it when the @@ -66,7 +66,7 @@ func IfaceIPAddrs(iface NetIface, ipv IPVersion) (ips []net.IP, err error) { // IfaceDNSIPAddrs returns IP addresses of the interface suitable to send to // clients as DNS addresses. If err is nil, addrs contains either no addresses -// or at least two. +// or at least two. l must not be nil. // // It makes up to maxAttempts attempts to get the addresses if there are none, // each time using the provided backoff. Sometimes an interface needs a few @@ -74,6 +74,8 @@ func IfaceIPAddrs(iface NetIface, ipv IPVersion) (ips []net.IP, err error) { // // See https://github.com/AdguardTeam/AdGuardHome/issues/2304. func IfaceDNSIPAddrs( + ctx context.Context, + l *slog.Logger, iface NetIface, ipv IPVersion, maxAttempts int, @@ -90,7 +92,7 @@ func IfaceDNSIPAddrs( break } - log.Debug("dhcpv%d: attempt %d: no ip addresses", ipv, n) + l.DebugContext(ctx, "no ip addresses", "attempt", n, "ipv", ipv) time.Sleep(backoff) } @@ -102,7 +104,7 @@ func IfaceDNSIPAddrs( // Don't return errors in case the users want to try and enable the DHCP // server later. t := time.Duration(n) * backoff - log.Error("dhcpv%d: no ip for iface after %d attempts and %s", ipv, n, t) + l.ErrorContext(ctx, "no ip addresses for iface", "attempts", n, "duration", t, "ipv", ipv) return nil, nil case 1: @@ -111,13 +113,13 @@ func IfaceDNSIPAddrs( // address. // // See https://github.com/AdguardTeam/AdGuardHome/issues/1708. - log.Debug("dhcpv%d: setting secondary dns ip to itself", ipv) + l.DebugContext(ctx, "setting secondary dns ip to itself", "ipv", ipv) addrs = append(addrs, addrs[0]) default: // Go on. } - log.Debug("dhcpv%d: got addresses %s after %d attempts", ipv, addrs, n) + l.DebugContext(ctx, "got addresses", "addrs", addrs, "attempts", n, "ipv", ipv) return addrs, nil } diff --git a/internal/aghnet/interfaces_test.go b/internal/aghnet/interfaces_test.go index 83bb81d5..7778103a 100644 --- a/internal/aghnet/interfaces_test.go +++ b/internal/aghnet/interfaces_test.go @@ -220,7 +220,8 @@ func TestIfaceDNSIPAddrs(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - got, err := aghnet.IfaceDNSIPAddrs(tc.iface, tc.ipv, 2, 0) + ctx := testutil.ContextWithTimeout(t, testTimeout) + got, err := aghnet.IfaceDNSIPAddrs(ctx, testLogger, tc.iface, tc.ipv, 2, 0) require.ErrorIs(t, err, tc.wantErr) assert.Equal(t, tc.want, got) diff --git a/internal/aghnet/net.go b/internal/aghnet/net.go index 5bfd0c89..81c51d5f 100644 --- a/internal/aghnet/net.go +++ b/internal/aghnet/net.go @@ -1,6 +1,4 @@ // Package aghnet contains networking utilities. -// -// TODO(s.chzhen): Use slog. package aghnet import ( @@ -9,6 +7,7 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "net" "net/netip" "net/url" @@ -17,7 +16,7 @@ import ( "github.com/AdguardTeam/dnsproxy/upstream" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/osutil" "github.com/AdguardTeam/golibs/osutil/executil" ) @@ -51,23 +50,25 @@ func IfaceHasStaticIP( return ifaceHasStaticIP(ctx, cmdCons, ifaceName) } -// IfaceSetStaticIP sets a static IP address for network interface. cmdCons -// must not be nil. +// IfaceSetStaticIP sets a static IP address for network interface. l and +// cmdCons must not be nil. func IfaceSetStaticIP( ctx context.Context, + l *slog.Logger, cmdCons executil.CommandConstructor, ifaceName string, ) (err error) { - return ifaceSetStaticIP(ctx, cmdCons, ifaceName) + return ifaceSetStaticIP(ctx, l, cmdCons, ifaceName) } -// GatewayIP returns the gateway IP address for the interface. cmdCons must not -// be nil. +// GatewayIP returns the gateway IP address for the interface. l and cmdCons +// must not be nil. // // TODO(e.burkov): Investigate if the gateway address may be fetched in another // way since not every machine has the software installed. func GatewayIP( ctx context.Context, + l *slog.Logger, cmdCons executil.CommandConstructor, ifaceName string, ) (ip netip.Addr) { @@ -83,9 +84,9 @@ func GatewayIP( ) if err != nil { if code, ok := executil.ExitCodeFromError(err); ok { - log.Debug("fetching gateway ip: unexpected exit code: %d", code) + l.DebugContext(ctx, "fetching gateway ip: unexpected exit code", "code", code) } else { - log.Debug("%s", err) + l.DebugContext(ctx, "fetching gateway ip", slogutil.KeyError, err) } return netip.Addr{} @@ -104,9 +105,9 @@ func GatewayIP( } // CanBindPrivilegedPorts checks if current process can bind to privileged -// ports. -func CanBindPrivilegedPorts() (can bool, err error) { - return canBindPrivilegedPorts() +// ports. l must not be nil. +func CanBindPrivilegedPorts(ctx context.Context, l *slog.Logger) (can bool, err error) { + return canBindPrivilegedPorts(ctx, l) } // NetInterface represents an entry of network interfaces map. @@ -237,13 +238,13 @@ func InterfaceByIP(ip netip.Addr) (ifaceName string) { } // GetSubnet returns the subnet corresponding to the interface of zero prefix if -// the search fails. +// the search fails. l must not be nil. // // TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb. -func GetSubnet(ifaceName string) (p netip.Prefix) { +func GetSubnet(ctx context.Context, l *slog.Logger, ifaceName string) (p netip.Prefix) { netIfaces, err := GetValidNetInterfacesForWeb() if err != nil { - log.Error("Could not get network interfaces info: %v", err) + l.ErrorContext(ctx, "could not get network interfaces info", slogutil.KeyError, err) return p } diff --git a/internal/aghnet/net_bsd.go b/internal/aghnet/net_bsd.go index 94a27a6d..cf9401c5 100644 --- a/internal/aghnet/net_bsd.go +++ b/internal/aghnet/net_bsd.go @@ -2,8 +2,13 @@ package aghnet -import "github.com/AdguardTeam/AdGuardHome/internal/aghos" +import ( + "context" + "log/slog" -func canBindPrivilegedPorts() (can bool, err error) { + "github.com/AdguardTeam/AdGuardHome/internal/aghos" +) + +func canBindPrivilegedPorts(_ context.Context, _ *slog.Logger) (can bool, err error) { return aghos.HaveAdminRights() } diff --git a/internal/aghnet/net_darwin.go b/internal/aghnet/net_darwin.go index de1968ca..82b382b9 100644 --- a/internal/aghnet/net_darwin.go +++ b/internal/aghnet/net_darwin.go @@ -8,6 +8,7 @@ import ( "context" "fmt" "io" + "log/slog" "regexp" "github.com/AdguardTeam/AdGuardHome/internal/aghos" @@ -119,6 +120,7 @@ func getHardwarePortInfo( // ifaceSetStaticIP sets a static IP on ifaceName. cmdCons must not be nil. func ifaceSetStaticIP( ctx context.Context, + _ *slog.Logger, cmdCons executil.CommandConstructor, ifaceName string, ) (err error) { diff --git a/internal/aghnet/net_darwin_internal_test.go b/internal/aghnet/net_darwin_internal_test.go index a8e19cdc..863dde47 100644 --- a/internal/aghnet/net_darwin_internal_test.go +++ b/internal/aghnet/net_darwin_internal_test.go @@ -249,7 +249,7 @@ func TestIfaceSetStaticIP(t *testing.T) { substRootDirFS(t, tc.fsys) ctx := testutil.ContextWithTimeout(t, testTimeout) - err := IfaceSetStaticIP(ctx, tc.cmdCons, "en0") + err := IfaceSetStaticIP(ctx, testLogger, tc.cmdCons, "en0") testutil.AssertErrorMsg(t, tc.wantErrMsg, err) }) } diff --git a/internal/aghnet/net_freebsd.go b/internal/aghnet/net_freebsd.go index d646b80f..a7ae478a 100644 --- a/internal/aghnet/net_freebsd.go +++ b/internal/aghnet/net_freebsd.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "io" + "log/slog" "strings" "github.com/AdguardTeam/AdGuardHome/internal/aghos" @@ -58,6 +59,11 @@ func (n interfaceName) rcConfStaticConfig(r io.Reader) (_ []string, cont bool, e return nil, true, s.Err() } -func ifaceSetStaticIP(_ context.Context, _ executil.CommandConstructor, _ string) (err error) { +func ifaceSetStaticIP( + _ context.Context, + _ *slog.Logger, + _ executil.CommandConstructor, + _ string, +) (err error) { return aghos.Unsupported("setting static ip") } diff --git a/internal/aghnet/net_internal_test.go b/internal/aghnet/net_internal_test.go index 7df6c1cd..de086456 100644 --- a/internal/aghnet/net_internal_test.go +++ b/internal/aghnet/net_internal_test.go @@ -11,6 +11,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/osutil/executil" "github.com/AdguardTeam/golibs/testutil" @@ -24,6 +25,9 @@ const testTimeout = 1 * time.Second // testCmdCons is the common command constructor for tests. var testCmdCons = executil.EmptyCommandConstructor{} +// testLogger is a logger used in tests. +var testLogger = slogutil.NewDiscardLogger() + // substRootDirFS replaces the aghos.RootDirFS function used throughout the // package with fsys for tests ran under t. func substRootDirFS(tb testing.TB, fsys fs.FS) { @@ -87,7 +91,7 @@ func TestGatewayIP(t *testing.T) { t.Parallel() ctx := testutil.ContextWithTimeout(t, testTimeout) - assert.Equal(t, tc.want, GatewayIP(ctx, tc.cmdCons, ifaceName)) + assert.Equal(t, tc.want, GatewayIP(ctx, testLogger, tc.cmdCons, ifaceName)) }) } } diff --git a/internal/aghnet/net_linux.go b/internal/aghnet/net_linux.go index e979d403..5d047ff2 100644 --- a/internal/aghnet/net_linux.go +++ b/internal/aghnet/net_linux.go @@ -7,13 +7,14 @@ import ( "context" "fmt" "io" + "log/slog" "net/netip" "os" "strings" "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/errors" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/osutil/executil" "github.com/AdguardTeam/golibs/stringutil" "github.com/google/renameio/v2/maybe" @@ -23,7 +24,7 @@ import ( // dhcpсdConf is the name of /etc/dhcpcd.conf file in the root filesystem. const dhcpcdConf = "etc/dhcpcd.conf" -func canBindPrivilegedPorts() (can bool, err error) { +func canBindPrivilegedPorts(ctx context.Context, l *slog.Logger) (can bool, err error) { res, err := unix.PrctlRetInt( unix.PR_CAP_AMBIENT, unix.PR_CAP_AMBIENT_IS_SET, @@ -35,7 +36,11 @@ func canBindPrivilegedPorts() (can bool, err error) { if errors.Is(err, unix.EINVAL) { // Older versions of Linux kernel do not support this. Print a // warning and check admin rights. - log.Info("warning: cannot check capability cap_net_bind_service: %s", err) + l.WarnContext( + ctx, + "cannot check capability cap_net_bind_service", + slogutil.KeyError, err, + ) } else { return false, err } @@ -154,13 +159,14 @@ func findIfaceLine(s *bufio.Scanner, name string) (ok bool) { } // ifaceSetStaticIP configures the system to retain its current IP on the -// interface through dhcpcd.conf. cmdCons must not be nil. +// interface through dhcpcd.conf. l and cmdCons must not be nil. func ifaceSetStaticIP( ctx context.Context, + l *slog.Logger, cmdCons executil.CommandConstructor, ifaceName string, ) (err error) { - ipNet := GetSubnet(ifaceName) + ipNet := GetSubnet(ctx, l, ifaceName) if !ipNet.Addr().IsValid() { return errors.Error("can't get IP address") } @@ -170,7 +176,7 @@ func ifaceSetStaticIP( return err } - gatewayIP := GatewayIP(ctx, cmdCons, ifaceName) + gatewayIP := GatewayIP(ctx, l, cmdCons, ifaceName) add := dhcpcdConfIface(ifaceName, ipNet, gatewayIP) body = append(body, []byte(add)...) diff --git a/internal/aghnet/net_openbsd.go b/internal/aghnet/net_openbsd.go index 3e405089..76df65c0 100644 --- a/internal/aghnet/net_openbsd.go +++ b/internal/aghnet/net_openbsd.go @@ -7,6 +7,7 @@ import ( "context" "fmt" "io" + "log/slog" "strings" "github.com/AdguardTeam/AdGuardHome/internal/aghos" @@ -45,6 +46,11 @@ func hostnameIfStaticConfig(r io.Reader) (_ []string, ok bool, err error) { return nil, true, s.Err() } -func ifaceSetStaticIP(_ context.Context, _ executil.CommandConstructor, _ string) (err error) { +func ifaceSetStaticIP( + _ context.Context, + _ *slog.Logger, + _ executil.CommandConstructor, + _ string, +) (err error) { return aghos.Unsupported("setting static ip") } diff --git a/internal/aghnet/net_windows.go b/internal/aghnet/net_windows.go index 164bb6fc..152dedb0 100644 --- a/internal/aghnet/net_windows.go +++ b/internal/aghnet/net_windows.go @@ -5,6 +5,7 @@ package aghnet import ( "context" "io" + "log/slog" "syscall" "time" @@ -14,7 +15,7 @@ import ( "golang.org/x/sys/windows" ) -func canBindPrivilegedPorts() (can bool, err error) { +func canBindPrivilegedPorts(_ context.Context, _ *slog.Logger) (can bool, err error) { return true, nil } @@ -26,7 +27,12 @@ func ifaceHasStaticIP( return false, aghos.Unsupported("checking static ip") } -func ifaceSetStaticIP(_ context.Context, _ executil.CommandConstructor, _ string) (err error) { +func ifaceSetStaticIP( + _ context.Context, + _ *slog.Logger, + _ executil.CommandConstructor, + _ string, +) (err error) { return aghos.Unsupported("setting static ip") } diff --git a/internal/client/storage_test.go b/internal/client/storage_test.go index f102cd6d..2033e2fe 100644 --- a/internal/client/storage_test.go +++ b/internal/client/storage_test.go @@ -576,8 +576,10 @@ func TestClientsAddExisting(t *testing.T) { // First, init a DHCP server with a single static lease. config := &dhcpd.ServerConfig{ - Enabled: true, - DataDir: t.TempDir(), + BaseLogger: testLogger, + Logger: testLogger, + Enabled: true, + DataDir: t.TempDir(), Conf4: dhcpd.V4ServerConf{ Enabled: true, GatewayIP: netip.MustParseAddr("1.2.3.1"), @@ -587,7 +589,8 @@ func TestClientsAddExisting(t *testing.T) { }, } - dhcpServer, err := dhcpd.Create(config) + ctx = testutil.ContextWithTimeout(t, testTimeout) + dhcpServer, err := dhcpd.Create(ctx, config) require.NoError(t, err) storage, err := client.NewStorage(ctx, &client.StorageConfig{ diff --git a/internal/dhcpd/config.go b/internal/dhcpd/config.go index af3930b9..1916dd7f 100644 --- a/internal/dhcpd/config.go +++ b/internal/dhcpd/config.go @@ -1,7 +1,9 @@ package dhcpd import ( + "context" "fmt" + "log/slog" "net" "net/netip" "time" @@ -17,6 +19,14 @@ import ( // ServerConfig is the configuration for the DHCP server. The order of YAML // fields is important, since the YAML configuration file follows it. type ServerConfig struct { + // BaseLogger is used for creating loggers for dhcpv4 and dhcpv6. It must + // not be nil. + BaseLogger *slog.Logger `yaml:"-"` + + // Logger is used for logging the operation of the DHCP server. It must not + // be nil. + Logger *slog.Logger `yaml:"-"` + // CommandConstructor is used to run external commands. It must not be nil. CommandConstructor executil.CommandConstructor `yaml:"-"` @@ -85,7 +95,7 @@ type DHCPServer interface { WriteDiskConfig6(c *V6ServerConf) // Start - start server - Start() (err error) + Start(ctx context.Context) (err error) // Stop - stop server Stop() (err error) getLeasesRef() []*dhcpsvc.Lease @@ -93,6 +103,10 @@ type DHCPServer interface { // V4ServerConf - server configuration type V4ServerConf struct { + // Logger is used for logging the operation of the DHCPv4 server. It must + // not be nil. + Logger *slog.Logger `yaml:"-" json:"-"` + Enabled bool `yaml:"-" json:"-"` InterfaceName string `yaml:"-" json:"-"` @@ -232,6 +246,10 @@ func (c *V4ServerConf) Validate() (err error) { // V6ServerConf - server configuration type V6ServerConf struct { + // Logger is used for logging the operation of the DHCPv6 server. It must + // not be nil. + Logger *slog.Logger `yaml:"-" json:"-"` + Enabled bool `yaml:"-" json:"-"` InterfaceName string `yaml:"-" json:"-"` diff --git a/internal/dhcpd/dhcpd.go b/internal/dhcpd/dhcpd.go index 20bd6cfb..ad59d814 100644 --- a/internal/dhcpd/dhcpd.go +++ b/internal/dhcpd/dhcpd.go @@ -2,6 +2,7 @@ package dhcpd import ( + "context" "fmt" "net" "net/netip" @@ -9,7 +10,7 @@ import ( "time" "github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc" - "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/timeutil" ) @@ -53,7 +54,7 @@ const ( // Interface is the DHCP server that deals with both IP address families. type Interface interface { - Start() (err error) + Start(ctx context.Context) (err error) Stop() (err error) // Enabled returns true if the DHCP server is running. @@ -104,9 +105,11 @@ var _ Interface = (*server)(nil) // Create initializes and returns the DHCP server handling both address // families. It also registers the corresponding HTTP API endpoints. -func Create(conf *ServerConfig) (s *server, err error) { +func Create(ctx context.Context, conf *ServerConfig) (s *server, err error) { s = &server{ conf: &ServerConfig{ + BaseLogger: conf.BaseLogger, + Logger: conf.Logger, CommandConstructor: conf.CommandConstructor, ConfModifier: conf.ConfModifier, @@ -125,7 +128,7 @@ func Create(conf *ServerConfig) (s *server, err error) { // [aghhttp.RegisterFunc]. s.registerHandlers() - v4Enabled, v6Enabled, err := s.setServers(conf) + v4Enabled, v6Enabled, err := s.setServers(ctx, conf) if err != nil { // Don't wrap the error, because it's informative enough as is. return nil, err @@ -158,8 +161,12 @@ func Create(conf *ServerConfig) (s *server, err error) { // setServers updates DHCPv4 and DHCPv6 servers created from the provided // configuration conf. It returns the status of both the DHCPv4 and the DHCPv6 // servers, which is always false for corresponding server on any error. -func (s *server) setServers(conf *ServerConfig) (v4Enabled, v6Enabled bool, err error) { +func (s *server) setServers( + ctx context.Context, + conf *ServerConfig, +) (v4Enabled, v6Enabled bool, err error) { v4conf := conf.Conf4 + v4conf.Logger = s.conf.BaseLogger.With(slogutil.KeyPrefix, "dhcpv4_server") v4conf.InterfaceName = s.conf.InterfaceName v4conf.notify = s.onNotify v4conf.Enabled = s.conf.Enabled && v4conf.RangeStart.IsValid() @@ -170,10 +177,11 @@ func (s *server) setServers(conf *ServerConfig) (v4Enabled, v6Enabled bool, err return false, false, fmt.Errorf("creating dhcpv4 srv: %w", err) } - log.Debug("dhcpd: warning: creating dhcpv4 srv: %s", err) + s.conf.Logger.WarnContext(ctx, "creating dhcpv4 server", slogutil.KeyError, err) } v6conf := conf.Conf6 + v6conf.Logger = s.conf.BaseLogger.With(slogutil.KeyPrefix, "dhcpv6_server") v6conf.InterfaceName = s.conf.InterfaceName v6conf.notify = s.onNotify v6conf.Enabled = s.conf.Enabled && len(v6conf.RangeStart) != 0 @@ -213,7 +221,9 @@ func (s *server) onNotify(flags uint32) { if flags == LeaseChangedDBStore { err := s.dbStore() if err != nil { - log.Error("updating db: %s", err) + // TODO(s.chzhen): Pass context. + ctx := context.TODO() + s.conf.Logger.ErrorContext(ctx, "updating db", slogutil.KeyError, err) } return @@ -239,13 +249,13 @@ func (s *server) WriteDiskConfig(c *ServerConfig) { } // Start will listen on port 67 and serve DHCP requests. -func (s *server) Start() (err error) { - err = s.srv4.Start() +func (s *server) Start(ctx context.Context) (err error) { + err = s.srv4.Start(ctx) if err != nil { return err } - err = s.srv6.Start() + err = s.srv6.Start(ctx) if err != nil { return err } diff --git a/internal/dhcpd/dhcpd_internal_test.go b/internal/dhcpd/dhcpd_internal_test.go new file mode 100644 index 00000000..877902db --- /dev/null +++ b/internal/dhcpd/dhcpd_internal_test.go @@ -0,0 +1,13 @@ +package dhcpd + +import ( + "time" + + "github.com/AdguardTeam/golibs/logutil/slogutil" +) + +// testTimeout is a common timeout for tests. +const testTimeout = 1 * time.Second + +// testLogger is a logger used in tests. +var testLogger = slogutil.NewDiscardLogger() diff --git a/internal/dhcpd/http_unix.go b/internal/dhcpd/http_unix.go index 9df36ea9..6967c32b 100644 --- a/internal/dhcpd/http_unix.go +++ b/internal/dhcpd/http_unix.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "net" "net/http" "net/netip" @@ -205,7 +206,7 @@ func (s *server) enableDHCP(ctx context.Context, ifaceName string) (code int, er } if !hasStaticIP { - err = aghnet.IfaceSetStaticIP(ctx, cmdCons, ifaceName) + err = aghnet.IfaceSetStaticIP(ctx, s.conf.Logger, cmdCons, ifaceName) if err != nil { err = fmt.Errorf("setting static ip: %w", err) @@ -213,7 +214,7 @@ func (s *server) enableDHCP(ctx context.Context, ifaceName string) (code int, er } } - err = s.Start() + err = s.Start(ctx) if err != nil { return http.StatusBadRequest, fmt.Errorf("starting dhcp server: %w", err) } @@ -390,6 +391,7 @@ type netInterfaceJSON struct { // handleDHCPInterfaces is the handler for the GET /control/dhcp/interfaces // HTTP API. func (s *server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() resp := map[string]*netInterfaceJSON{} ifaces, err := net.Interfaces() @@ -410,7 +412,7 @@ func (s *server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { continue } - jsonIface, iErr := newNetInterfaceJSON(r.Context(), iface, s.conf.CommandConstructor) + jsonIface, iErr := newNetInterfaceJSON(ctx, s.conf.Logger, iface, s.conf.CommandConstructor) if iErr != nil { aghhttp.Error(r, w, http.StatusInternalServerError, "%s", iErr) @@ -426,9 +428,10 @@ func (s *server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) { } // newNetInterfaceJSON creates a JSON object from a [net.Interface] iface. -// cmdCons must not be nil. +// l and cmdCons must not be nil. func newNetInterfaceJSON( ctx context.Context, + l *slog.Logger, iface net.Interface, cmdCons executil.CommandConstructor, ) (out *netInterfaceJSON, err error) { @@ -483,7 +486,7 @@ func newNetInterfaceJSON( return nil, nil } - out.GatewayIP = aghnet.GatewayIP(ctx, cmdCons, iface.Name) + out.GatewayIP = aghnet.GatewayIP(ctx, l, cmdCons, iface.Name) return out, nil } @@ -533,6 +536,8 @@ type findActiveServerReq struct { // 2. check if a static IP is configured for the network interface; // 3. responds with the results. func (s *server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if aghhttp.WriteTextPlainDeprecated(w, r) { return } @@ -569,24 +574,28 @@ func (s *server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque } cmdCons := s.conf.CommandConstructor - if isStaticIP, serr := aghnet.IfaceHasStaticIP(r.Context(), cmdCons, ifaceName); serr != nil { + if isStaticIP, serr := aghnet.IfaceHasStaticIP(ctx, cmdCons, ifaceName); serr != nil { result.V4.StaticIP.Static = "error" result.V4.StaticIP.Error = serr.Error() } else if !isStaticIP { result.V4.StaticIP.Static = "no" // TODO(e.burkov): The returned IP should only be of version 4. - result.V4.StaticIP.IP = aghnet.GetSubnet(ifaceName).String() + result.V4.StaticIP.IP = aghnet.GetSubnet(ctx, s.conf.Logger, ifaceName).String() } - setOtherDHCPResult(ifaceName, result) + s.setOtherDHCPResult(ctx, ifaceName, result) aghhttp.WriteJSONResponseOK(w, r, result) } // setOtherDHCPResult sets the results of the check for another DHCP server in -// result. -func setOtherDHCPResult(ifaceName string, result *dhcpSearchResult) { - found4, found6, err4, err6 := aghnet.CheckOtherDHCP(ifaceName) +// result. result must not be nil. +func (s *server) setOtherDHCPResult( + ctx context.Context, + ifaceName string, + result *dhcpSearchResult, +) { + found4, found6, err4, err6 := aghnet.CheckOtherDHCP(ctx, s.conf.Logger, ifaceName) if err4 != nil { result.V4.OtherServer.Found = "error" result.V4.OtherServer.Error = err4.Error() diff --git a/internal/dhcpd/http_unix_internal_test.go b/internal/dhcpd/http_unix_internal_test.go index 7f4c4390..1760dd18 100644 --- a/internal/dhcpd/http_unix_internal_test.go +++ b/internal/dhcpd/http_unix_internal_test.go @@ -11,6 +11,7 @@ import ( "testing" "github.com/AdguardTeam/AdGuardHome/internal/agh" + "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -84,7 +85,10 @@ func TestServer_handleDHCPStatus(t *testing.T) { Hostname: staticName, } - s, err := Create(&ServerConfig{ + ctx := testutil.ContextWithTimeout(t, testTimeout) + s, err := Create(ctx, &ServerConfig{ + BaseLogger: testLogger, + Logger: testLogger, Enabled: true, Conf4: *defaultV4ServerConf(), DataDir: t.TempDir(), @@ -178,7 +182,10 @@ func TestServer_HandleUpdateStaticLease(t *testing.T) { }, } - s, err := Create(&ServerConfig{ + ctx := testutil.ContextWithTimeout(t, testTimeout) + s, err := Create(ctx, &ServerConfig{ + BaseLogger: testLogger, + Logger: testLogger, Enabled: true, Conf4: *defaultV4ServerConf(), Conf6: V6ServerConf{}, @@ -266,7 +273,10 @@ func TestServer_HandleUpdateStaticLease_validation(t *testing.T) { Hostname: anotherV4Name, }} - s, err := Create(&ServerConfig{ + ctx := testutil.ContextWithTimeout(t, testTimeout) + s, err := Create(ctx, &ServerConfig{ + BaseLogger: testLogger, + Logger: testLogger, Enabled: true, Conf4: *defaultV4ServerConf(), Conf6: V6ServerConf{}, diff --git a/internal/dhcpd/v46_windows.go b/internal/dhcpd/v46_windows.go index 241429c6..3fc462b0 100644 --- a/internal/dhcpd/v46_windows.go +++ b/internal/dhcpd/v46_windows.go @@ -5,6 +5,7 @@ package dhcpd // 'u-root/u-root' package, a dependency of 'insomniacslk/dhcp' package, doesn't build on Windows import ( + "context" "net" "net/netip" @@ -25,7 +26,7 @@ func (winServer) UpdateStaticLease(_ *dhcpsvc.Lease) (err error) { return func (winServer) FindMACbyIP(_ netip.Addr) (mac net.HardwareAddr) { return nil } func (winServer) WriteDiskConfig4(_ *V4ServerConf) {} func (winServer) WriteDiskConfig6(_ *V6ServerConf) {} -func (winServer) Start() (err error) { return nil } +func (winServer) Start(ctx context.Context) (err error) { return nil } func (winServer) Stop() (err error) { return nil } func (winServer) HostByIP(_ netip.Addr) (host string) { return "" } func (winServer) IPByHost(_ string) (ip netip.Addr) { return netip.Addr{} } diff --git a/internal/dhcpd/v4_unix.go b/internal/dhcpd/v4_unix.go index 37da9f71..0350d7f3 100644 --- a/internal/dhcpd/v4_unix.go +++ b/internal/dhcpd/v4_unix.go @@ -4,6 +4,7 @@ package dhcpd import ( "bytes" + "context" "fmt" "net" "net/netip" @@ -1299,7 +1300,7 @@ func (s *v4Server) packetHandler(conn net.PacketConn, peer net.Addr, req *dhcpv4 } // Start starts the IPv4 DHCP server. -func (s *v4Server) Start() (err error) { +func (s *v4Server) Start(ctx context.Context) (err error) { defer func() { err = errors.Annotate(err, "dhcpv4: %w") }() if !s.enabled() { @@ -1315,6 +1316,8 @@ func (s *v4Server) Start() (err error) { log.Debug("dhcpv4: starting...") dnsIPAddrs, err := aghnet.IfaceDNSIPAddrs( + ctx, + s.conf.Logger, iface, aghnet.IPVersion4, defaultMaxAttempts, diff --git a/internal/dhcpd/v6_unix.go b/internal/dhcpd/v6_unix.go index 6d832cfc..c7e956f9 100644 --- a/internal/dhcpd/v6_unix.go +++ b/internal/dhcpd/v6_unix.go @@ -4,6 +4,7 @@ package dhcpd import ( "bytes" + "context" "fmt" "net" "net/netip" @@ -657,8 +658,13 @@ func (s *v6Server) packetHandler(conn net.PacketConn, peer net.Addr, req dhcpv6. // configureDNSIPAddrs updates v6Server configuration with the slice of DNS IP // addresses of provided interface iface. Initializes RA module. -func (s *v6Server) configureDNSIPAddrs(iface *net.Interface) (ok bool, err error) { +func (s *v6Server) configureDNSIPAddrs( + ctx context.Context, + iface *net.Interface, +) (ok bool, err error) { dnsIPAddrs, err := aghnet.IfaceDNSIPAddrs( + ctx, + s.conf.Logger, iface, aghnet.IPVersion6, defaultMaxAttempts, @@ -700,7 +706,7 @@ func (s *v6Server) initRA(iface *net.Interface) (err error) { } // Start starts the IPv6 DHCP server. -func (s *v6Server) Start() (err error) { +func (s *v6Server) Start(ctx context.Context) (err error) { defer func() { err = errors.Annotate(err, "dhcpv6: %w") }() if !s.conf.Enabled { @@ -715,7 +721,7 @@ func (s *v6Server) Start() (err error) { log.Debug("dhcpv6: starting...") - ok, err := s.configureDNSIPAddrs(iface) + ok, err := s.configureDNSIPAddrs(ctx, iface) if err != nil { // Don't wrap the error, because it's informative enough as is. return err diff --git a/internal/dnsforward/dnsforward_internal_test.go b/internal/dnsforward/dnsforward_internal_test.go index afcd3c71..04711dea 100644 --- a/internal/dnsforward/dnsforward_internal_test.go +++ b/internal/dnsforward/dnsforward_internal_test.go @@ -1465,7 +1465,8 @@ func TestPTRResponseFromHosts(t *testing.T) { } var eventsCalledCounter uint32 - hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{ + ctx := testutil.ContextWithTimeout(t, testTimeout) + hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, &aghtest.FSWatcher{ OnStart: func(ctx context.Context) (_ error) { panic(testutil.UnexpectedCall(ctx)) }, OnEvents: func() (e <-chan struct{}) { assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1)) diff --git a/internal/dnsforward/http_internal_test.go b/internal/dnsforward/http_internal_test.go index 2da75e87..2cc99afe 100644 --- a/internal/dnsforward/http_internal_test.go +++ b/internal/dnsforward/http_internal_test.go @@ -362,7 +362,10 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) { Host: netutil.JoinHostPort(upstreamHost, hostsListener.Port()), }).String() + ctx := testutil.ContextWithTimeout(t, testTimeout) hc, err := aghnet.NewHostsContainer( + ctx, + testLogger, fstest.MapFS{ hostsFileName: &fstest.MapFile{ Data: []byte(hostsListener.Addr().String() + " " + upstreamHost), diff --git a/internal/filtering/dnsrewrite_test.go b/internal/filtering/dnsrewrite_test.go index 58353a43..b95d8dc4 100644 --- a/internal/filtering/dnsrewrite_test.go +++ b/internal/filtering/dnsrewrite_test.go @@ -6,7 +6,6 @@ import ( "testing" "github.com/AdguardTeam/AdGuardHome/internal/filtering" - "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -53,7 +52,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) { ` conf := &filtering.Config{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, SafeBrowsingCacheSize: 10000, ParentalCacheSize: 10000, SafeSearchCacheSize: 1000, diff --git a/internal/filtering/filter_internal_test.go b/internal/filtering/filter_internal_test.go index 144ec061..6dbad685 100644 --- a/internal/filtering/filter_internal_test.go +++ b/internal/filtering/filter_internal_test.go @@ -8,18 +8,13 @@ import ( "os" "path/filepath" "testing" - "time" - "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil/urlutil" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// testTimeout is the common timeout for tests. -const testTimeout = 5 * time.Second - // serveHTTPLocally starts a new HTTP server, that handles its index with h. It // also gracefully closes the listener when the test under t finishes. func serveHTTPLocally(tb testing.TB, h http.Handler) (urlStr string) { @@ -86,7 +81,7 @@ func newDNSFilter(tb testing.TB) (d *DNSFilter) { tb.Helper() dnsFilter, err := New(&Config{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, DataDir: tb.TempDir(), HTTPClient: &http.Client{ Timeout: testTimeout, diff --git a/internal/filtering/filtering_internal_test.go b/internal/filtering/filtering_internal_test.go index d5b14aff..b3a8287c 100644 --- a/internal/filtering/filtering_internal_test.go +++ b/internal/filtering/filtering_internal_test.go @@ -6,6 +6,7 @@ import ( "fmt" "net/netip" "testing" + "time" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering/hashprefix" @@ -18,6 +19,9 @@ import ( "github.com/stretchr/testify/require" ) +// testTimeout is a common timeout for tests. +const testTimeout = 1 * time.Second + const ( sbBlocked = "wmconvirus.narod.ru" pcBlocked = "pornhub.com" diff --git a/internal/filtering/filtering_test.go b/internal/filtering/filtering_test.go new file mode 100644 index 00000000..2fb09c78 --- /dev/null +++ b/internal/filtering/filtering_test.go @@ -0,0 +1,13 @@ +package filtering_test + +import ( + "time" + + "github.com/AdguardTeam/golibs/logutil/slogutil" +) + +// testTimeout is a common timeout for tests. +const testTimeout = 1 * time.Second + +// testLogger is a logger used in tests. +var testLogger = slogutil.NewDiscardLogger() diff --git a/internal/filtering/hashprefix/hashprefix_internal_test.go b/internal/filtering/hashprefix/hashprefix_internal_test.go index f1ab3b86..a4a72635 100644 --- a/internal/filtering/hashprefix/hashprefix_internal_test.go +++ b/internal/filtering/hashprefix/hashprefix_internal_test.go @@ -22,6 +22,9 @@ const ( cacheSize = 10000 ) +// testLogger is a logger used in tests. +var testLogger = slogutil.NewDiscardLogger() + func TestChcker_getQuestion(t *testing.T) { const suf = "sb.dns.adguard.com." @@ -45,7 +48,7 @@ func TestChcker_getQuestion(t *testing.T) { assert.False(t, slices.Contains(hashes, hash)) c := New(&Config{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, TXTSuffix: suf, }) @@ -100,7 +103,7 @@ func TestChecker_storeInCache(t *testing.T) { const testTimeout = 1 * time.Second c := New(&Config{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, CacheTime: cacheTime, }) @@ -158,7 +161,7 @@ func TestChecker_storeInCache(t *testing.T) { assert.True(t, ok) c = New(&Config{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, CacheTime: cacheTime, }) @@ -195,7 +198,7 @@ func TestChecker_Check(t *testing.T) { for _, tc := range testCases { c := New(&Config{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, CacheTime: cacheTime, CacheSize: cacheSize, }) diff --git a/internal/filtering/hosts_test.go b/internal/filtering/hosts_test.go index 9385314a..01ef9b1e 100644 --- a/internal/filtering/hosts_test.go +++ b/internal/filtering/hosts_test.go @@ -11,7 +11,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering" "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" - "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/urlfilter/rules" "github.com/miekg/dns" @@ -49,12 +48,13 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) { OnAdd: func(name string) (err error) { return nil }, OnShutdown: func(_ context.Context) (err error) { return nil }, } - hc, err := aghnet.NewHostsContainer(files, watcher, "hosts") + ctx := testutil.ContextWithTimeout(t, testTimeout) + hc, err := aghnet.NewHostsContainer(ctx, testLogger, files, watcher, "hosts") require.NoError(t, err) testutil.CleanupAndRequireSuccess(t, hc.Close) conf := &filtering.Config{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, EtcHosts: hc, } f, err := filtering.New(conf, nil) diff --git a/internal/filtering/http_internal_test.go b/internal/filtering/http_internal_test.go index 326d91ff..504bb8ca 100644 --- a/internal/filtering/http_internal_test.go +++ b/internal/filtering/http_internal_test.go @@ -14,7 +14,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/schedule" - "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -110,7 +109,7 @@ func TestDNSFilter_handleFilteringSetURL(t *testing.T) { confModifiedCalled = true } d, err := New(&Config{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, FilteringEnabled: true, Filters: tc.initial, HTTPClient: &http.Client{ @@ -195,7 +194,7 @@ func TestDNSFilter_handleSafeBrowsingStatus(t *testing.T) { } d, err := New(&Config{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, ConfModifier: confModifier, DataDir: filtersDir, HTTPRegister: func(_, url string, handler http.HandlerFunc) { @@ -282,7 +281,7 @@ func TestDNSFilter_handleParentalStatus(t *testing.T) { } d, err := New(&Config{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, ConfModifier: confModifier, DataDir: filtersDir, HTTPRegister: func(_, url string, handler http.HandlerFunc) { @@ -384,7 +383,7 @@ func TestDNSFilter_HandleCheckHost(t *testing.T) { } dnsFilter, err := New(&Config{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, BlockedServices: &BlockedServices{ Schedule: schedule.EmptyWeekly(), }, diff --git a/internal/filtering/idgenerator_internal_test.go b/internal/filtering/idgenerator_internal_test.go index 195e9976..9a91c0e4 100644 --- a/internal/filtering/idgenerator_internal_test.go +++ b/internal/filtering/idgenerator_internal_test.go @@ -5,7 +5,6 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist" - "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/stretchr/testify/assert" ) @@ -65,7 +64,7 @@ func TestIDGenerator_Fix(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - g := newIDGenerator(1, slogutil.NewDiscardLogger()) + g := newIDGenerator(1, testLogger) g.fix(tc.in) assertUniqueIDs(t, tc.in) diff --git a/internal/filtering/rewrite/storage_internal_test.go b/internal/filtering/rewrite/storage_internal_test.go index 10df670c..84ca18bb 100644 --- a/internal/filtering/rewrite/storage_internal_test.go +++ b/internal/filtering/rewrite/storage_internal_test.go @@ -13,6 +13,9 @@ import ( "github.com/stretchr/testify/require" ) +// testLogger is a logger used in tests. +var testLogger = slogutil.NewDiscardLogger() + func TestNewDefaultStorage(t *testing.T) { items := []*Item{{ Domain: "example.com", @@ -20,7 +23,7 @@ func TestNewDefaultStorage(t *testing.T) { }} s, err := NewDefaultStorage(&Config{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, Rewrites: items, ListID: -1, }) @@ -33,7 +36,7 @@ func TestDefaultStorage_CRUD(t *testing.T) { var items []*Item s, err := NewDefaultStorage(&Config{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, Rewrites: items, ListID: -1, }) @@ -122,7 +125,7 @@ func TestDefaultStorage_MatchRequest(t *testing.T) { }} s, err := NewDefaultStorage(&Config{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, Rewrites: items, ListID: -1, }) @@ -298,7 +301,7 @@ func TestDefaultStorage_MatchRequest_Levels(t *testing.T) { }} s, err := NewDefaultStorage(&Config{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, Rewrites: items, ListID: -1, }) @@ -370,7 +373,7 @@ func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) { }} s, err := NewDefaultStorage(&Config{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, Rewrites: items, ListID: -1, }) @@ -438,7 +441,7 @@ func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) { }} s, err := NewDefaultStorage(&Config{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, Rewrites: items, ListID: -1, }) diff --git a/internal/filtering/rewritehttp_test.go b/internal/filtering/rewritehttp_test.go index 4b57c9db..c26c2b43 100644 --- a/internal/filtering/rewritehttp_test.go +++ b/internal/filtering/rewritehttp_test.go @@ -8,11 +8,9 @@ import ( "net/http" "net/http/httptest" "testing" - "time" "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/filtering" - "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -30,9 +28,6 @@ type rewriteUpdateJSON struct { } const ( - // testTimeout is the common timeout for tests. - testTimeout = 100 * time.Millisecond - listURL = "/control/rewrite/list" addURL = "/control/rewrite/add" deleteURL = "/control/rewrite/delete" @@ -159,7 +154,7 @@ func TestDNSFilter_handleRewriteHTTP(t *testing.T) { } d, err := filtering.New(&filtering.Config{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, ConfModifier: confModifier, HTTPRegister: func(_, url string, handler http.HandlerFunc) { handlers[url] = handler diff --git a/internal/filtering/safesearch/safesearch_test.go b/internal/filtering/safesearch/safesearch_test.go index 77eca2f1..aa1ad928 100644 --- a/internal/filtering/safesearch/safesearch_test.go +++ b/internal/filtering/safesearch/safesearch_test.go @@ -41,6 +41,9 @@ var testConf = filtering.SafeSearchConfig{ YouTube: true, } +// testLogger is a logger used in tests. +var testLogger = slogutil.NewDiscardLogger() + // yandexIP is the expected IP address of Yandex safe search results. Keep in // sync with the rules data. var yandexIP = netip.AddrFrom4([4]byte{213, 180, 193, 56}) @@ -49,7 +52,7 @@ func TestDefault_CheckHost_yandex(t *testing.T) { conf := testConf ctx := testutil.ContextWithTimeout(t, testTimeout) ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, ServicesConfig: conf, CacheSize: testCacheSize, CacheTTL: testCacheTTL, @@ -111,7 +114,7 @@ func TestDefault_CheckHost_yandex(t *testing.T) { func TestDefault_CheckHost_google(t *testing.T) { ctx := testutil.ContextWithTimeout(t, testTimeout) ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, ServicesConfig: testConf, CacheSize: testCacheSize, CacheTTL: testCacheTTL, @@ -163,7 +166,7 @@ func (r *testResolver) LookupIP( func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) { ctx := testutil.ContextWithTimeout(t, testTimeout) ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, ServicesConfig: testConf, CacheSize: testCacheSize, CacheTTL: testCacheTTL, @@ -186,7 +189,7 @@ func TestDefault_Update(t *testing.T) { conf := testConf ctx := testutil.ContextWithTimeout(t, testTimeout) ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{ - Logger: slogutil.NewDiscardLogger(), + Logger: testLogger, ServicesConfig: conf, CacheSize: testCacheSize, CacheTTL: testCacheTTL, diff --git a/internal/home/controlinstall.go b/internal/home/controlinstall.go index 77fb3ef4..e9719eba 100644 --- a/internal/home/controlinstall.go +++ b/internal/home/controlinstall.go @@ -196,7 +196,7 @@ func (web *webAPI) handleInstallCheckConfig(w http.ResponseWriter, r *http.Reque if err != nil { resp.DNS.Status = err.Error() } else if !req.DNS.IP.IsUnspecified() { - resp.StaticIP = handleStaticIP(ctx, req.DNS.IP, req.SetStaticIP, web.cmdCons) + resp.StaticIP = handleStaticIP(ctx, web.logger, req.DNS.IP, req.SetStaticIP, web.cmdCons) } aghhttp.WriteJSONResponseOK(w, r, resp) @@ -206,6 +206,7 @@ func (web *webAPI) handleInstallCheckConfig(w http.ResponseWriter, r *http.Reque // owns IP. cmdCons must not be nil. func handleStaticIP( ctx context.Context, + l *slog.Logger, ip netip.Addr, set bool, cmdCons executil.CommandConstructor, @@ -222,7 +223,7 @@ func handleStaticIP( if set { // Try to set a static IP for the specified interface. - err := aghnet.IfaceSetStaticIP(ctx, cmdCons, interfaceName) + err := aghnet.IfaceSetStaticIP(ctx, l, cmdCons, interfaceName) if err != nil { ipResp.Static = "error" ipResp.Error = err.Error() @@ -244,7 +245,7 @@ func handleStaticIP( if isStaticIP { ipResp.Static = "yes" } - ipResp.IP = aghnet.GetSubnet(interfaceName).String() + ipResp.IP = aghnet.GetSubnet(ctx, l, interfaceName).String() return ipResp } diff --git a/internal/home/controlupdate.go b/internal/home/controlupdate.go index e791633d..8f835f9d 100644 --- a/internal/home/controlupdate.go +++ b/internal/home/controlupdate.go @@ -32,6 +32,8 @@ type temporaryError interface { // // TODO(a.garipov): Find out if this API used with a GET method by anyone. func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + resp := &versionResponse{} if web.conf.disableUpdate { resp.Disabled = true @@ -54,7 +56,7 @@ func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) { } } - err = web.requestVersionInfo(r.Context(), resp, req.Recheck) + err = web.requestVersionInfo(ctx, resp, req.Recheck) if err != nil { // Don't wrap the error, because it's informative enough as is. aghhttp.Error(r, w, http.StatusBadGateway, "%s", err) @@ -62,7 +64,7 @@ func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) { return } - err = resp.setAllowedToAutoUpdate(web.tlsManager) + err = resp.setAllowedToAutoUpdate(ctx, web.logger, web.tlsManager) if err != nil { // Don't wrap the error, because it's informative enough as is. aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err) @@ -164,8 +166,13 @@ type versionResponse struct { } // setAllowedToAutoUpdate sets CanAutoUpdate to true if AdGuard Home is actually -// allowed to perform an automatic update by the OS. tlsMgr must not be nil. -func (vr *versionResponse) setAllowedToAutoUpdate(tlsMgr *tlsManager) (err error) { +// allowed to perform an automatic update by the OS. l and tlsMgr must not be +// nil. +func (vr *versionResponse) setAllowedToAutoUpdate( + ctx context.Context, + l *slog.Logger, + tlsMgr *tlsManager, +) (err error) { if vr.CanAutoUpdate != aghalg.NBTrue { return nil } @@ -174,7 +181,7 @@ func (vr *versionResponse) setAllowedToAutoUpdate(tlsMgr *tlsManager) (err error if tlsConfUsesPrivilegedPorts(tlsMgr.config()) || config.HTTPConfig.Address.Port() < 1024 || config.DNS.Port < 1024 { - canUpdate, err = aghnet.CanBindPrivilegedPorts() + canUpdate, err = aghnet.CanBindPrivilegedPorts(ctx, l) if err != nil { return fmt.Errorf("checking ability to bind privileged ports: %w", err) } @@ -192,7 +199,7 @@ func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) { } // finishUpdate completes an update procedure. It is intended to be used as a -// goroutine. +// goroutine. l and cmdCons must not be nil. func finishUpdate( ctx context.Context, l *slog.Logger, diff --git a/internal/home/home.go b/internal/home/home.go index b40d09b8..29e2f29f 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -175,7 +175,7 @@ func setupContext(ctx context.Context, baseLogger *slog.Logger, opts options) (e if globalContext.firstRun { log.Info("This is the first time AdGuard Home is launched") - checkNetworkPermissions() + checkNetworkPermissions(ctx, baseLogger) return nil } @@ -265,7 +265,14 @@ func setupHostsContainer(ctx context.Context, baseLogger *slog.Logger) (err erro return fmt.Errorf("getting default system hosts paths: %w", err) } - globalContext.etcHosts, err = aghnet.NewHostsContainer(osutil.RootDirFS(), hostsWatcher, paths...) + l := baseLogger.With(slogutil.KeyPrefix, "hosts") + globalContext.etcHosts, err = aghnet.NewHostsContainer( + ctx, + l, + osutil.RootDirFS(), + hostsWatcher, + paths..., + ) if err != nil { closeErr := hostsWatcher.Shutdown(ctx) if errors.Is(err, aghnet.ErrNoHostsPaths) { @@ -308,9 +315,11 @@ func initContextClients( config.DHCP.DataDir = globalContext.getDataDir() config.DHCP.HTTPRegister = httpRegister config.DHCP.CommandConstructor = executil.SystemCommandConstructor{} + config.DHCP.BaseLogger = logger + config.DHCP.Logger = logger.With(slogutil.KeyPrefix, "dhcp_server") config.DHCP.ConfModifier = confModifier - globalContext.dhcpServer, err = dhcpd.Create(config.DHCP) + globalContext.dhcpServer, err = dhcpd.Create(ctx, config.DHCP) if globalContext.dhcpServer == nil || err != nil { // TODO(a.garipov): There are a lot of places in the code right // now which assume that the DHCP server can be nil despite this @@ -805,7 +814,7 @@ func run( }() if globalContext.dhcpServer != nil { - err = globalContext.dhcpServer.Start() + err = globalContext.dhcpServer.Start(ctx) if err != nil { log.Error("starting dhcp server: %s", err) } @@ -940,11 +949,11 @@ func (c *configuration) anonymizer() (ipmut *aghnet.IPMut) { } // checkNetworkPermissions checks if the current user permissions are enough to -// use the required networking functionality. -func checkNetworkPermissions() { +// use the required networking functionality. l must not be nil. +func checkNetworkPermissions(ctx context.Context, l *slog.Logger) { log.Info("Checking if AdGuard Home has necessary permissions") - if ok, err := aghnet.CanBindPrivilegedPorts(); !ok || err != nil { + if ok, err := aghnet.CanBindPrivilegedPorts(ctx, l); !ok || err != nil { log.Fatal("This is the first launch of AdGuard Home. You must run it as Administrator.") }