all: aghnet slog

This commit is contained in:
Stanislav Chzhen 2025-09-24 15:12:13 +03:00
parent 2a81beeb46
commit ea69b773d7
41 changed files with 423 additions and 198 deletions

View File

@ -1,6 +1,16 @@
package aghnet
// CheckOtherDHCP tries to discover another DHCP server in the network.
func CheckOtherDHCP(ifaceName string) (ok4, ok6 bool, err4, err6 error) {
return checkOtherDHCP(ifaceName)
import (
"context"
"log/slog"
)
// CheckOtherDHCP tries to discover another DHCP server in the network. l must
// not be nil.
func CheckOtherDHCP(
ctx context.Context,
l *slog.Logger,
ifaceName string,
) (ok4, ok6 bool, err4, err6 error) {
return checkOtherDHCP(ctx, l, ifaceName)
}

View File

@ -4,14 +4,16 @@ package aghnet
import (
"bytes"
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"os"
"time"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/insomniacslk/dhcp/dhcpv4"
"github.com/insomniacslk/dhcp/dhcpv6"
@ -23,7 +25,11 @@ import (
// response.
const defaultDiscoverTime = 3 * time.Second
func checkOtherDHCP(ifaceName string) (ok4, ok6 bool, err4, err6 error) {
func checkOtherDHCP(
ctx context.Context,
l *slog.Logger,
ifaceName string,
) (ok4, ok6 bool, err4, err6 error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {
err = fmt.Errorf("couldn't find interface by name %s: %w", ifaceName, err)
@ -32,8 +38,8 @@ func checkOtherDHCP(ifaceName string) (ok4, ok6 bool, err4, err6 error) {
return false, false, err4, err6
}
ok4, err4 = checkOtherDHCPv4(iface)
ok6, err6 = checkOtherDHCPv6(iface)
ok4, err4 = checkOtherDHCPv4(ctx, l, iface)
ok6, err6 = checkOtherDHCPv6(ctx, l, iface)
return ok4, ok6, err4, err6
}
@ -68,8 +74,13 @@ func ifaceIPv4Subnet(iface *net.Interface) (subnet netip.Prefix, err error) {
}
// checkOtherDHCPv4 sends a DHCP request to the specified network interface, and
// waits for a response for a period defined by defaultDiscoverTime.
func checkOtherDHCPv4(iface *net.Interface) (ok bool, err error) {
// waits for a response for a period defined by defaultDiscoverTime. l must not
// be nil.
func checkOtherDHCPv4(
ctx context.Context,
l *slog.Logger,
iface *net.Interface,
) (ok bool, err error) {
var subnet netip.Prefix
if subnet, err = ifaceIPv4Subnet(iface); err != nil {
return false, err
@ -87,10 +98,16 @@ func checkOtherDHCPv4(iface *net.Interface) (ok bool, err error) {
return false, fmt.Errorf("couldn't get hostname: %w", err)
}
return discover4(iface, dstAddr, hostname)
return discover4(ctx, l, iface, dstAddr, hostname)
}
func discover4(iface *net.Interface, dstAddr *net.UDPAddr, hostname string) (ok bool, err error) {
func discover4(
ctx context.Context,
l *slog.Logger,
iface *net.Interface,
dstAddr *net.UDPAddr,
hostname string,
) (ok bool, err error) {
var req *dhcpv4.DHCPv4
if req, err = dhcpv4.NewDiscovery(iface.HardwareAddr); err != nil {
return false, fmt.Errorf("dhcpv4.NewDiscovery: %w", err)
@ -125,10 +142,10 @@ func discover4(iface *net.Interface, dstAddr *net.UDPAddr, hostname string) (ok
}
var next bool
ok, next, err = tryConn4(req, c, iface)
ok, next, err = tryConn4(ctx, l, req, c, iface)
if next {
if err != nil {
log.Debug("dhcpv4: trying a connection: %s", err)
l.DebugContext(ctx, "dhcpv4: trying a connection", slogutil.KeyError, err)
}
continue
@ -144,16 +161,22 @@ func discover4(iface *net.Interface, dstAddr *net.UDPAddr, hostname string) (ok
// TODO(a.garipov): Refactor further. Inspect error handling, remove parameter
// next, address the TODO, merge with tryConn6, etc.
func tryConn4(req *dhcpv4.DHCPv4, c net.PacketConn, iface *net.Interface) (ok, next bool, err error) {
func tryConn4(
ctx context.Context,
l *slog.Logger,
req *dhcpv4.DHCPv4,
c net.PacketConn,
iface *net.Interface,
) (ok, next bool, err error) {
// TODO: replicate dhclient's behavior of retrying several times with
// progressively longer timeouts.
log.Tracef("dhcpv4: waiting %v for an answer", defaultDiscoverTime)
l.Log(ctx, slogutil.LevelTrace, "dhcpv4: waiting for an answer", "timeout", defaultDiscoverTime)
b := make([]byte, 1500)
n, _, err := c.ReadFrom(b)
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
log.Debug("dhcpv4: didn't receive dhcp response")
l.DebugContext(ctx, "dhcpv4: didn't receive dhcp response")
return false, false, nil
}
@ -161,16 +184,16 @@ func tryConn4(req *dhcpv4.DHCPv4, c net.PacketConn, iface *net.Interface) (ok, n
return false, false, fmt.Errorf("receiving packet: %w", err)
}
log.Tracef("dhcpv4: received packet, %d bytes", n)
l.Log(ctx, slogutil.LevelTrace, "dhcpv4: received packet", "size", n)
response, err := dhcpv4.FromBytes(b[:n])
if err != nil {
log.Debug("dhcpv4: encoding: %s", err)
l.DebugContext(ctx, "dhcpv4: encoding", slogutil.KeyError, err)
return false, true, err
}
log.Debug("dhcpv4: received message from server: %s", response.Summary())
l.DebugContext(ctx, "dhcpv4: received message from server", "summary", response.Summary())
switch {
case
@ -179,19 +202,24 @@ func tryConn4(req *dhcpv4.DHCPv4, c net.PacketConn, iface *net.Interface) (ok, n
!bytes.Equal(response.ClientHWAddr, iface.HardwareAddr),
response.TransactionID != req.TransactionID,
!response.Options.Has(dhcpv4.OptionDHCPMessageType):
log.Debug("dhcpv4: received response doesn't match the request")
l.DebugContext(ctx, "dhcpv4: received response does not match the request")
return false, true, nil
default:
log.Tracef("dhcpv4: the packet is from an active dhcp server")
l.Log(ctx, slogutil.LevelTrace, "dhcpv4: the packet is from an active dhcp server")
return true, false, nil
}
}
// checkOtherDHCPv6 sends a DHCP request to the specified network interface, and
// waits for a response for a period defined by defaultDiscoverTime.
func checkOtherDHCPv6(iface *net.Interface) (ok bool, err error) {
// waits for a response for a period defined by defaultDiscoverTime. l must not
// be nil.
func checkOtherDHCPv6(
ctx context.Context,
l *slog.Logger,
iface *net.Interface,
) (ok bool, err error) {
ifaceIPNet, err := IfaceIPAddrs(iface, IPVersion6)
if err != nil {
return false, fmt.Errorf("getting ipv6 addrs for iface %s: %w", iface.Name, err)
@ -218,16 +246,22 @@ func checkOtherDHCPv6(iface *net.Interface) (ok bool, err error) {
return false, fmt.Errorf("dhcpv6: Couldn't resolve UDP address %s: %w", dst, err)
}
return discover6(iface, udpAddr, dstAddr)
return discover6(ctx, l, iface, udpAddr, dstAddr)
}
func discover6(iface *net.Interface, udpAddr, dstAddr *net.UDPAddr) (ok bool, err error) {
func discover6(
ctx context.Context,
l *slog.Logger,
iface *net.Interface,
udpAddr *net.UDPAddr,
dstAddr *net.UDPAddr,
) (ok bool, err error) {
req, err := dhcpv6.NewSolicit(iface.HardwareAddr)
if err != nil {
return false, fmt.Errorf("dhcpv6: dhcpv6.NewSolicit: %w", err)
}
log.Debug("DHCPv6: Listening to udp6 %+v", udpAddr)
l.DebugContext(ctx, "dhcpv6: listening to udp6", "addr", udpAddr)
c, err := nclient6.NewIPv6UDPConn(iface.Name, dhcpv6.DefaultClientPort)
if err != nil {
return false, fmt.Errorf("dhcpv6: Couldn't listen on :546: %w", err)
@ -241,10 +275,10 @@ func discover6(iface *net.Interface, udpAddr, dstAddr *net.UDPAddr) (ok bool, er
for {
var next bool
ok, next, err = tryConn6(req, c)
ok, next, err = tryConn6(ctx, l, req, c)
if next {
if err != nil {
log.Debug("dhcpv6: trying a connection: %s", err)
l.DebugContext(ctx, "dhcpv6: trying a connection", slogutil.KeyError, err)
}
continue
@ -259,10 +293,15 @@ func discover6(iface *net.Interface, udpAddr, dstAddr *net.UDPAddr) (ok bool, er
}
// TODO(a.garipov): See the comment on tryConn4. Sigh…
func tryConn6(req *dhcpv6.Message, c net.PacketConn) (ok, next bool, err error) {
func tryConn6(
ctx context.Context,
l *slog.Logger,
req *dhcpv6.Message,
c net.PacketConn,
) (ok, next bool, err error) {
// TODO: replicate dhclient's behavior of retrying several times with
// progressively longer timeouts.
log.Tracef("dhcpv6: waiting %v for an answer", defaultDiscoverTime)
l.Log(ctx, slogutil.LevelTrace, "dhcpv6: waiting for an answer", "timeout", defaultDiscoverTime)
b := make([]byte, 4096)
err = c.SetDeadline(time.Now().Add(defaultDiscoverTime))
@ -273,7 +312,7 @@ func tryConn6(req *dhcpv6.Message, c net.PacketConn) (ok, next bool, err error)
n, _, err := c.ReadFrom(b)
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
log.Debug("dhcpv6: didn't receive dhcp response")
l.DebugContext(ctx, "dhcpv6: didn't receive dhcp response")
return false, false, nil
}
@ -281,21 +320,21 @@ func tryConn6(req *dhcpv6.Message, c net.PacketConn) (ok, next bool, err error)
return false, false, fmt.Errorf("receiving packet: %w", err)
}
log.Tracef("dhcpv6: received packet, %d bytes", n)
l.Log(ctx, slogutil.LevelTrace, "dhcpv6: received packet", "size", n)
response, err := dhcpv6.FromBytes(b[:n])
if err != nil {
log.Debug("dhcpv6: encoding: %s", err)
l.DebugContext(ctx, "dhcpv6: encoding", slogutil.KeyError, err)
return false, true, err
}
log.Debug("dhcpv6: received message from server: %s", response.Summary())
l.DebugContext(ctx, "dhcpv6: received message from server", "summary", response.Summary())
cid := req.Options.ClientID()
msg, err := response.GetInnerMessage()
if err != nil {
log.Debug("dhcpv6: resp.GetInnerMessage(): %s", err)
l.DebugContext(ctx, "getting inner message", slogutil.KeyError, err)
return false, true, err
}
@ -306,12 +345,12 @@ func tryConn6(req *dhcpv6.Message, c net.PacketConn) (ok, next bool, err error)
rcid != nil &&
cid.Equal(rcid)) {
log.Debug("dhcpv6: received message from server doesn't match our request")
l.DebugContext(ctx, "dhcpv6: received message from server does not match our request")
return false, true, nil
}
log.Tracef("dhcpv6: the packet is from an active dhcp server")
l.Log(ctx, slogutil.LevelTrace, "dhcpv6: the packet is from an active dhcp server")
return true, false, nil
}

View File

@ -2,9 +2,18 @@
package aghnet
import "github.com/AdguardTeam/AdGuardHome/internal/aghos"
import (
"context"
"log/slog"
func checkOtherDHCP(ifaceName string) (ok4, ok6 bool, err4, err6 error) {
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
)
func checkOtherDHCP(
_ context.Context,
_ *slog.Logger,
ifaceName string,
) (ok4, ok6 bool, err4, err6 error) {
return false,
false,
aghos.Unsupported("CheckIfOtherDHCPServersPresentV4"),

View File

@ -5,6 +5,7 @@ import (
"fmt"
"io"
"io/fs"
"log/slog"
"net/netip"
"path"
"sync/atomic"
@ -12,17 +13,21 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// hostsContainerPrefix is a prefix for logging and wrapping errors in
// HostsContainer's methods.
// hostsContainerPrefix is a prefix for wrapping errors in HostsContainer's
// methods.
const hostsContainerPrefix = "hosts container"
// HostsContainer stores the relevant hosts database provided by the OS and
// processes both A/AAAA and PTR DNS requests for those.
type HostsContainer struct {
// done is the channel to sign closing the container.
// logger is used for logging the operation of the hosts container. It must
// not be nil.
logger *slog.Logger
// done is the channel to signal closing the container.
done chan struct{}
// updates is the channel for receiving updated hosts.
@ -31,10 +36,12 @@ type HostsContainer struct {
// current is the last set of hosts parsed.
current atomic.Pointer[hostsfile.DefaultStorage]
// fsys is the working file system to read hosts files from.
// fsys is the working file system to read hosts files from. It must not be
// nil.
fsys fs.FS
// watcher tracks the changes in specified files and directories.
// watcher tracks the changes in specified files and directories. It must
// not be nil.
watcher aghos.FSWatcher
// patterns stores specified paths in the fs.Glob-compatible form.
@ -45,11 +52,12 @@ type HostsContainer struct {
// the HostsContainer.
const ErrNoHostsPaths errors.Error = "no valid paths to hosts files provided"
// NewHostsContainer creates a container of hosts, that watches the paths with
// w. listID is used as an identifier of the underlying rules list. paths
// shouldn't be empty and each of paths should locate either a file or a
// directory in fsys. fsys and w must be non-nil.
// NewHostsContainer creates a container of hosts that watches the paths with w.
// paths shouldn't be empty, and each path should refer either a file or a
// directory in fsys. l, fsys, and w must be non-nil.
func NewHostsContainer(
ctx context.Context,
l *slog.Logger,
fsys fs.FS,
w aghos.FSWatcher,
paths ...string,
@ -69,6 +77,7 @@ func NewHostsContainer(
}
hc = &HostsContainer{
logger: l,
done: make(chan struct{}, 1),
updates: make(chan *hostsfile.DefaultStorage, 1),
fsys: fsys,
@ -76,10 +85,10 @@ func NewHostsContainer(
patterns: patterns,
}
log.Debug("%s: starting", hostsContainerPrefix)
l.DebugContext(ctx, "starting")
// Load initially.
if err = hc.refresh(); err != nil {
if err = hc.refresh(ctx); err != nil {
return nil, err
}
@ -89,11 +98,11 @@ func NewHostsContainer(
return nil, fmt.Errorf("adding path: %w", err)
}
log.Debug("%s: %s is expected to exist but doesn't", hostsContainerPrefix, p)
l.DebugContext(ctx, "expected path does not exist", "path", p)
}
}
go hc.handleEvents()
go hc.handleEvents(ctx)
return hc, nil
}
@ -101,10 +110,10 @@ func NewHostsContainer(
// Close implements the [io.Closer] interface for *HostsContainer. It closes
// both itself and its [aghos.FSWatcher]. Close must only be called once.
func (hc *HostsContainer) Close() (err error) {
log.Debug("%s: closing", hostsContainerPrefix)
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
hc.logger.DebugContext(ctx, "closing")
err = errors.Annotate(hc.watcher.Shutdown(ctx), "closing fs watcher: %w")
// Go on and close the container either way.
@ -159,8 +168,8 @@ func pathsToPatterns(fsys fs.FS, paths []string) (patterns []string, err error)
// handleEvents concurrently handles the file system events. It closes the
// update channel of HostsContainer when finishes. It is intended to be used as
// a goroutine.
func (hc *HostsContainer) handleEvents() {
defer log.OnPanic(fmt.Sprintf("%s: handling events", hostsContainerPrefix))
func (hc *HostsContainer) handleEvents(ctx context.Context) {
defer slogutil.RecoverAndLog(ctx, hc.logger)
defer close(hc.updates)
@ -170,13 +179,13 @@ func (hc *HostsContainer) handleEvents() {
select {
case _, ok = <-eventsCh:
if !ok {
log.Debug("%s: watcher closed the events channel", hostsContainerPrefix)
hc.logger.DebugContext(ctx, "watcher closed the events channel")
continue
}
if err := hc.refresh(); err != nil {
log.Error("%s: warning: refreshing: %s", hostsContainerPrefix, err)
if err := hc.refresh(ctx); err != nil {
hc.logger.ErrorContext(ctx, "refreshing", slogutil.KeyError, err)
}
case _, ok = <-hc.done:
// Go on.
@ -185,8 +194,8 @@ func (hc *HostsContainer) handleEvents() {
}
// sendUpd tries to send the parsed data to the ch.
func (hc *HostsContainer) sendUpd(recs *hostsfile.DefaultStorage) {
log.Debug("%s: sending upd", hostsContainerPrefix)
func (hc *HostsContainer) sendUpd(ctx context.Context, recs *hostsfile.DefaultStorage) {
hc.logger.DebugContext(ctx, "sending update")
ch := hc.updates
select {
@ -194,11 +203,11 @@ func (hc *HostsContainer) sendUpd(recs *hostsfile.DefaultStorage) {
// Updates are delivered. Go on.
case <-ch:
ch <- recs
log.Debug("%s: replaced the last update", hostsContainerPrefix)
hc.logger.DebugContext(ctx, "replaced the last update")
case ch <- recs:
// The previous update was just read and the next one pushed. Go on.
default:
log.Error("%s: the updates channel is broken", hostsContainerPrefix)
hc.logger.ErrorContext(ctx, "updates channel is broken")
}
}
@ -206,8 +215,8 @@ func (hc *HostsContainer) sendUpd(recs *hostsfile.DefaultStorage) {
// needed.
//
// TODO(e.burkov): Accept a parameter to specify the files to refresh.
func (hc *HostsContainer) refresh() (err error) {
log.Debug("%s: refreshing", hostsContainerPrefix)
func (hc *HostsContainer) refresh(ctx context.Context) (err error) {
hc.logger.DebugContext(ctx, "refreshing")
// The error is always nil here since no readers passed.
strg, _ := hostsfile.NewDefaultStorage()
@ -223,7 +232,7 @@ func (hc *HostsContainer) refresh() (err error) {
// TODO(e.burkov): Serialize updates using [time.Time].
if !hc.current.Load().Equal(strg) {
hc.current.Store(strg)
hc.sendUpd(strg)
hc.sendUpd(ctx, strg)
}
return nil

View File

@ -13,12 +13,19 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testTimeout is a common timeout for tests.
const testTimeout = 1 * time.Second
// testLogger is a logger used in tests.
var testLogger = slogutil.NewDiscardLogger()
func TestNewHostsContainer(t *testing.T) {
const dirname = "dir"
const filename = "file1"
@ -67,7 +74,8 @@ func TestNewHostsContainer(t *testing.T) {
return eventsCh
}
hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{
ctx := testutil.ContextWithTimeout(t, testTimeout)
hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, &aghtest.FSWatcher{
OnStart: func(ctx context.Context) (_ error) {
panic(testutil.UnexpectedCall(ctx))
},
@ -96,7 +104,8 @@ func TestNewHostsContainer(t *testing.T) {
t.Run("nil_fs", func(t *testing.T) {
require.Panics(t, func() {
_, _ = aghnet.NewHostsContainer(nil, &aghtest.FSWatcher{
ctx := testutil.ContextWithTimeout(t, testTimeout)
_, _ = aghnet.NewHostsContainer(ctx, testLogger, nil, &aghtest.FSWatcher{
OnStart: func(ctx context.Context) (_ error) {
panic(testutil.UnexpectedCall(ctx))
},
@ -110,7 +119,8 @@ func TestNewHostsContainer(t *testing.T) {
t.Run("nil_watcher", func(t *testing.T) {
require.Panics(t, func() {
_, _ = aghnet.NewHostsContainer(testFS, nil, p)
ctx := testutil.ContextWithTimeout(t, testTimeout)
_, _ = aghnet.NewHostsContainer(ctx, testLogger, testFS, nil, p)
})
})
@ -124,7 +134,8 @@ func TestNewHostsContainer(t *testing.T) {
OnShutdown: func(_ context.Context) (err error) { return nil },
}
hc, err := aghnet.NewHostsContainer(testFS, errWatcher, p)
ctx := testutil.ContextWithTimeout(t, testTimeout)
hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, errWatcher, p)
require.ErrorIs(t, err, errOnAdd)
assert.Nil(t, hc)
@ -173,7 +184,8 @@ func TestHostsContainer_refresh(t *testing.T) {
OnShutdown: func(_ context.Context) (err error) { return nil },
}
hc, err := aghnet.NewHostsContainer(testFS, w, "dir")
ctx := testutil.ContextWithTimeout(t, testTimeout)
hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, w, "dir")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, hc.Close)

View File

@ -1,11 +1,11 @@
package aghnet
import (
"context"
"fmt"
"log/slog"
"net"
"time"
"github.com/AdguardTeam/golibs/log"
)
// IPVersion is a alias for int for documentation purposes. Use it when the
@ -66,7 +66,7 @@ func IfaceIPAddrs(iface NetIface, ipv IPVersion) (ips []net.IP, err error) {
// IfaceDNSIPAddrs returns IP addresses of the interface suitable to send to
// clients as DNS addresses. If err is nil, addrs contains either no addresses
// or at least two.
// or at least two. l must not be nil.
//
// It makes up to maxAttempts attempts to get the addresses if there are none,
// each time using the provided backoff. Sometimes an interface needs a few
@ -74,6 +74,8 @@ func IfaceIPAddrs(iface NetIface, ipv IPVersion) (ips []net.IP, err error) {
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/2304.
func IfaceDNSIPAddrs(
ctx context.Context,
l *slog.Logger,
iface NetIface,
ipv IPVersion,
maxAttempts int,
@ -90,7 +92,7 @@ func IfaceDNSIPAddrs(
break
}
log.Debug("dhcpv%d: attempt %d: no ip addresses", ipv, n)
l.DebugContext(ctx, "no ip addresses", "attempt", n, "ipv", ipv)
time.Sleep(backoff)
}
@ -102,7 +104,7 @@ func IfaceDNSIPAddrs(
// Don't return errors in case the users want to try and enable the DHCP
// server later.
t := time.Duration(n) * backoff
log.Error("dhcpv%d: no ip for iface after %d attempts and %s", ipv, n, t)
l.ErrorContext(ctx, "no ip addresses for iface", "attempts", n, "duration", t, "ipv", ipv)
return nil, nil
case 1:
@ -111,13 +113,13 @@ func IfaceDNSIPAddrs(
// address.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/1708.
log.Debug("dhcpv%d: setting secondary dns ip to itself", ipv)
l.DebugContext(ctx, "setting secondary dns ip to itself", "ipv", ipv)
addrs = append(addrs, addrs[0])
default:
// Go on.
}
log.Debug("dhcpv%d: got addresses %s after %d attempts", ipv, addrs, n)
l.DebugContext(ctx, "got addresses", "addrs", addrs, "attempts", n, "ipv", ipv)
return addrs, nil
}

View File

@ -220,7 +220,8 @@ func TestIfaceDNSIPAddrs(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
got, err := aghnet.IfaceDNSIPAddrs(tc.iface, tc.ipv, 2, 0)
ctx := testutil.ContextWithTimeout(t, testTimeout)
got, err := aghnet.IfaceDNSIPAddrs(ctx, testLogger, tc.iface, tc.ipv, 2, 0)
require.ErrorIs(t, err, tc.wantErr)
assert.Equal(t, tc.want, got)

View File

@ -1,6 +1,4 @@
// Package aghnet contains networking utilities.
//
// TODO(s.chzhen): Use slog.
package aghnet
import (
@ -9,6 +7,7 @@ import (
"encoding/json"
"fmt"
"io"
"log/slog"
"net"
"net/netip"
"net/url"
@ -17,7 +16,7 @@ import (
"github.com/AdguardTeam/dnsproxy/upstream"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/osutil/executil"
)
@ -51,23 +50,25 @@ func IfaceHasStaticIP(
return ifaceHasStaticIP(ctx, cmdCons, ifaceName)
}
// IfaceSetStaticIP sets a static IP address for network interface. cmdCons
// must not be nil.
// IfaceSetStaticIP sets a static IP address for network interface. l and
// cmdCons must not be nil.
func IfaceSetStaticIP(
ctx context.Context,
l *slog.Logger,
cmdCons executil.CommandConstructor,
ifaceName string,
) (err error) {
return ifaceSetStaticIP(ctx, cmdCons, ifaceName)
return ifaceSetStaticIP(ctx, l, cmdCons, ifaceName)
}
// GatewayIP returns the gateway IP address for the interface. cmdCons must not
// be nil.
// GatewayIP returns the gateway IP address for the interface. l and cmdCons
// must not be nil.
//
// TODO(e.burkov): Investigate if the gateway address may be fetched in another
// way since not every machine has the software installed.
func GatewayIP(
ctx context.Context,
l *slog.Logger,
cmdCons executil.CommandConstructor,
ifaceName string,
) (ip netip.Addr) {
@ -83,9 +84,9 @@ func GatewayIP(
)
if err != nil {
if code, ok := executil.ExitCodeFromError(err); ok {
log.Debug("fetching gateway ip: unexpected exit code: %d", code)
l.DebugContext(ctx, "fetching gateway ip: unexpected exit code", "code", code)
} else {
log.Debug("%s", err)
l.DebugContext(ctx, "fetching gateway ip", slogutil.KeyError, err)
}
return netip.Addr{}
@ -104,9 +105,9 @@ func GatewayIP(
}
// CanBindPrivilegedPorts checks if current process can bind to privileged
// ports.
func CanBindPrivilegedPorts() (can bool, err error) {
return canBindPrivilegedPorts()
// ports. l must not be nil.
func CanBindPrivilegedPorts(ctx context.Context, l *slog.Logger) (can bool, err error) {
return canBindPrivilegedPorts(ctx, l)
}
// NetInterface represents an entry of network interfaces map.
@ -237,13 +238,13 @@ func InterfaceByIP(ip netip.Addr) (ifaceName string) {
}
// GetSubnet returns the subnet corresponding to the interface of zero prefix if
// the search fails.
// the search fails. l must not be nil.
//
// TODO(e.burkov): See TODO on GetValidNetInterfacesForWeb.
func GetSubnet(ifaceName string) (p netip.Prefix) {
func GetSubnet(ctx context.Context, l *slog.Logger, ifaceName string) (p netip.Prefix) {
netIfaces, err := GetValidNetInterfacesForWeb()
if err != nil {
log.Error("Could not get network interfaces info: %v", err)
l.ErrorContext(ctx, "could not get network interfaces info", slogutil.KeyError, err)
return p
}

View File

@ -2,8 +2,13 @@
package aghnet
import "github.com/AdguardTeam/AdGuardHome/internal/aghos"
import (
"context"
"log/slog"
func canBindPrivilegedPorts() (can bool, err error) {
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
)
func canBindPrivilegedPorts(_ context.Context, _ *slog.Logger) (can bool, err error) {
return aghos.HaveAdminRights()
}

View File

@ -8,6 +8,7 @@ import (
"context"
"fmt"
"io"
"log/slog"
"regexp"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
@ -119,6 +120,7 @@ func getHardwarePortInfo(
// ifaceSetStaticIP sets a static IP on ifaceName. cmdCons must not be nil.
func ifaceSetStaticIP(
ctx context.Context,
_ *slog.Logger,
cmdCons executil.CommandConstructor,
ifaceName string,
) (err error) {

View File

@ -249,7 +249,7 @@ func TestIfaceSetStaticIP(t *testing.T) {
substRootDirFS(t, tc.fsys)
ctx := testutil.ContextWithTimeout(t, testTimeout)
err := IfaceSetStaticIP(ctx, tc.cmdCons, "en0")
err := IfaceSetStaticIP(ctx, testLogger, tc.cmdCons, "en0")
testutil.AssertErrorMsg(t, tc.wantErrMsg, err)
})
}

View File

@ -7,6 +7,7 @@ import (
"context"
"fmt"
"io"
"log/slog"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
@ -58,6 +59,11 @@ func (n interfaceName) rcConfStaticConfig(r io.Reader) (_ []string, cont bool, e
return nil, true, s.Err()
}
func ifaceSetStaticIP(_ context.Context, _ executil.CommandConstructor, _ string) (err error) {
func ifaceSetStaticIP(
_ context.Context,
_ *slog.Logger,
_ executil.CommandConstructor,
_ string,
) (err error) {
return aghos.Unsupported("setting static ip")
}

View File

@ -11,6 +11,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/agh"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/AdguardTeam/golibs/testutil"
@ -24,6 +25,9 @@ const testTimeout = 1 * time.Second
// testCmdCons is the common command constructor for tests.
var testCmdCons = executil.EmptyCommandConstructor{}
// testLogger is a logger used in tests.
var testLogger = slogutil.NewDiscardLogger()
// substRootDirFS replaces the aghos.RootDirFS function used throughout the
// package with fsys for tests ran under t.
func substRootDirFS(tb testing.TB, fsys fs.FS) {
@ -87,7 +91,7 @@ func TestGatewayIP(t *testing.T) {
t.Parallel()
ctx := testutil.ContextWithTimeout(t, testTimeout)
assert.Equal(t, tc.want, GatewayIP(ctx, tc.cmdCons, ifaceName))
assert.Equal(t, tc.want, GatewayIP(ctx, testLogger, tc.cmdCons, ifaceName))
})
}
}

View File

@ -7,13 +7,14 @@ import (
"context"
"fmt"
"io"
"log/slog"
"net/netip"
"os"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/AdguardTeam/golibs/stringutil"
"github.com/google/renameio/v2/maybe"
@ -23,7 +24,7 @@ import (
// dhcpсdConf is the name of /etc/dhcpcd.conf file in the root filesystem.
const dhcpcdConf = "etc/dhcpcd.conf"
func canBindPrivilegedPorts() (can bool, err error) {
func canBindPrivilegedPorts(ctx context.Context, l *slog.Logger) (can bool, err error) {
res, err := unix.PrctlRetInt(
unix.PR_CAP_AMBIENT,
unix.PR_CAP_AMBIENT_IS_SET,
@ -35,7 +36,11 @@ func canBindPrivilegedPorts() (can bool, err error) {
if errors.Is(err, unix.EINVAL) {
// Older versions of Linux kernel do not support this. Print a
// warning and check admin rights.
log.Info("warning: cannot check capability cap_net_bind_service: %s", err)
l.WarnContext(
ctx,
"cannot check capability cap_net_bind_service",
slogutil.KeyError, err,
)
} else {
return false, err
}
@ -154,13 +159,14 @@ func findIfaceLine(s *bufio.Scanner, name string) (ok bool) {
}
// ifaceSetStaticIP configures the system to retain its current IP on the
// interface through dhcpcd.conf. cmdCons must not be nil.
// interface through dhcpcd.conf. l and cmdCons must not be nil.
func ifaceSetStaticIP(
ctx context.Context,
l *slog.Logger,
cmdCons executil.CommandConstructor,
ifaceName string,
) (err error) {
ipNet := GetSubnet(ifaceName)
ipNet := GetSubnet(ctx, l, ifaceName)
if !ipNet.Addr().IsValid() {
return errors.Error("can't get IP address")
}
@ -170,7 +176,7 @@ func ifaceSetStaticIP(
return err
}
gatewayIP := GatewayIP(ctx, cmdCons, ifaceName)
gatewayIP := GatewayIP(ctx, l, cmdCons, ifaceName)
add := dhcpcdConfIface(ifaceName, ipNet, gatewayIP)
body = append(body, []byte(add)...)

View File

@ -7,6 +7,7 @@ import (
"context"
"fmt"
"io"
"log/slog"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
@ -45,6 +46,11 @@ func hostnameIfStaticConfig(r io.Reader) (_ []string, ok bool, err error) {
return nil, true, s.Err()
}
func ifaceSetStaticIP(_ context.Context, _ executil.CommandConstructor, _ string) (err error) {
func ifaceSetStaticIP(
_ context.Context,
_ *slog.Logger,
_ executil.CommandConstructor,
_ string,
) (err error) {
return aghos.Unsupported("setting static ip")
}

View File

@ -5,6 +5,7 @@ package aghnet
import (
"context"
"io"
"log/slog"
"syscall"
"time"
@ -14,7 +15,7 @@ import (
"golang.org/x/sys/windows"
)
func canBindPrivilegedPorts() (can bool, err error) {
func canBindPrivilegedPorts(_ context.Context, _ *slog.Logger) (can bool, err error) {
return true, nil
}
@ -26,7 +27,12 @@ func ifaceHasStaticIP(
return false, aghos.Unsupported("checking static ip")
}
func ifaceSetStaticIP(_ context.Context, _ executil.CommandConstructor, _ string) (err error) {
func ifaceSetStaticIP(
_ context.Context,
_ *slog.Logger,
_ executil.CommandConstructor,
_ string,
) (err error) {
return aghos.Unsupported("setting static ip")
}

View File

@ -576,8 +576,10 @@ func TestClientsAddExisting(t *testing.T) {
// First, init a DHCP server with a single static lease.
config := &dhcpd.ServerConfig{
Enabled: true,
DataDir: t.TempDir(),
BaseLogger: testLogger,
Logger: testLogger,
Enabled: true,
DataDir: t.TempDir(),
Conf4: dhcpd.V4ServerConf{
Enabled: true,
GatewayIP: netip.MustParseAddr("1.2.3.1"),
@ -587,7 +589,8 @@ func TestClientsAddExisting(t *testing.T) {
},
}
dhcpServer, err := dhcpd.Create(config)
ctx = testutil.ContextWithTimeout(t, testTimeout)
dhcpServer, err := dhcpd.Create(ctx, config)
require.NoError(t, err)
storage, err := client.NewStorage(ctx, &client.StorageConfig{

View File

@ -1,7 +1,9 @@
package dhcpd
import (
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"time"
@ -17,6 +19,14 @@ import (
// ServerConfig is the configuration for the DHCP server. The order of YAML
// fields is important, since the YAML configuration file follows it.
type ServerConfig struct {
// BaseLogger is used for creating loggers for dhcpv4 and dhcpv6. It must
// not be nil.
BaseLogger *slog.Logger `yaml:"-"`
// Logger is used for logging the operation of the DHCP server. It must not
// be nil.
Logger *slog.Logger `yaml:"-"`
// CommandConstructor is used to run external commands. It must not be nil.
CommandConstructor executil.CommandConstructor `yaml:"-"`
@ -85,7 +95,7 @@ type DHCPServer interface {
WriteDiskConfig6(c *V6ServerConf)
// Start - start server
Start() (err error)
Start(ctx context.Context) (err error)
// Stop - stop server
Stop() (err error)
getLeasesRef() []*dhcpsvc.Lease
@ -93,6 +103,10 @@ type DHCPServer interface {
// V4ServerConf - server configuration
type V4ServerConf struct {
// Logger is used for logging the operation of the DHCPv4 server. It must
// not be nil.
Logger *slog.Logger `yaml:"-" json:"-"`
Enabled bool `yaml:"-" json:"-"`
InterfaceName string `yaml:"-" json:"-"`
@ -232,6 +246,10 @@ func (c *V4ServerConf) Validate() (err error) {
// V6ServerConf - server configuration
type V6ServerConf struct {
// Logger is used for logging the operation of the DHCPv6 server. It must
// not be nil.
Logger *slog.Logger `yaml:"-" json:"-"`
Enabled bool `yaml:"-" json:"-"`
InterfaceName string `yaml:"-" json:"-"`

View File

@ -2,6 +2,7 @@
package dhcpd
import (
"context"
"fmt"
"net"
"net/netip"
@ -9,7 +10,7 @@ import (
"time"
"github.com/AdguardTeam/AdGuardHome/internal/dhcpsvc"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/timeutil"
)
@ -53,7 +54,7 @@ const (
// Interface is the DHCP server that deals with both IP address families.
type Interface interface {
Start() (err error)
Start(ctx context.Context) (err error)
Stop() (err error)
// Enabled returns true if the DHCP server is running.
@ -104,9 +105,11 @@ var _ Interface = (*server)(nil)
// Create initializes and returns the DHCP server handling both address
// families. It also registers the corresponding HTTP API endpoints.
func Create(conf *ServerConfig) (s *server, err error) {
func Create(ctx context.Context, conf *ServerConfig) (s *server, err error) {
s = &server{
conf: &ServerConfig{
BaseLogger: conf.BaseLogger,
Logger: conf.Logger,
CommandConstructor: conf.CommandConstructor,
ConfModifier: conf.ConfModifier,
@ -125,7 +128,7 @@ func Create(conf *ServerConfig) (s *server, err error) {
// [aghhttp.RegisterFunc].
s.registerHandlers()
v4Enabled, v6Enabled, err := s.setServers(conf)
v4Enabled, v6Enabled, err := s.setServers(ctx, conf)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return nil, err
@ -158,8 +161,12 @@ func Create(conf *ServerConfig) (s *server, err error) {
// setServers updates DHCPv4 and DHCPv6 servers created from the provided
// configuration conf. It returns the status of both the DHCPv4 and the DHCPv6
// servers, which is always false for corresponding server on any error.
func (s *server) setServers(conf *ServerConfig) (v4Enabled, v6Enabled bool, err error) {
func (s *server) setServers(
ctx context.Context,
conf *ServerConfig,
) (v4Enabled, v6Enabled bool, err error) {
v4conf := conf.Conf4
v4conf.Logger = s.conf.BaseLogger.With(slogutil.KeyPrefix, "dhcpv4_server")
v4conf.InterfaceName = s.conf.InterfaceName
v4conf.notify = s.onNotify
v4conf.Enabled = s.conf.Enabled && v4conf.RangeStart.IsValid()
@ -170,10 +177,11 @@ func (s *server) setServers(conf *ServerConfig) (v4Enabled, v6Enabled bool, err
return false, false, fmt.Errorf("creating dhcpv4 srv: %w", err)
}
log.Debug("dhcpd: warning: creating dhcpv4 srv: %s", err)
s.conf.Logger.WarnContext(ctx, "creating dhcpv4 server", slogutil.KeyError, err)
}
v6conf := conf.Conf6
v6conf.Logger = s.conf.BaseLogger.With(slogutil.KeyPrefix, "dhcpv6_server")
v6conf.InterfaceName = s.conf.InterfaceName
v6conf.notify = s.onNotify
v6conf.Enabled = s.conf.Enabled && len(v6conf.RangeStart) != 0
@ -213,7 +221,9 @@ func (s *server) onNotify(flags uint32) {
if flags == LeaseChangedDBStore {
err := s.dbStore()
if err != nil {
log.Error("updating db: %s", err)
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
s.conf.Logger.ErrorContext(ctx, "updating db", slogutil.KeyError, err)
}
return
@ -239,13 +249,13 @@ func (s *server) WriteDiskConfig(c *ServerConfig) {
}
// Start will listen on port 67 and serve DHCP requests.
func (s *server) Start() (err error) {
err = s.srv4.Start()
func (s *server) Start(ctx context.Context) (err error) {
err = s.srv4.Start(ctx)
if err != nil {
return err
}
err = s.srv6.Start()
err = s.srv6.Start(ctx)
if err != nil {
return err
}

View File

@ -0,0 +1,13 @@
package dhcpd
import (
"time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// testTimeout is a common timeout for tests.
const testTimeout = 1 * time.Second
// testLogger is a logger used in tests.
var testLogger = slogutil.NewDiscardLogger()

View File

@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/netip"
@ -205,7 +206,7 @@ func (s *server) enableDHCP(ctx context.Context, ifaceName string) (code int, er
}
if !hasStaticIP {
err = aghnet.IfaceSetStaticIP(ctx, cmdCons, ifaceName)
err = aghnet.IfaceSetStaticIP(ctx, s.conf.Logger, cmdCons, ifaceName)
if err != nil {
err = fmt.Errorf("setting static ip: %w", err)
@ -213,7 +214,7 @@ func (s *server) enableDHCP(ctx context.Context, ifaceName string) (code int, er
}
}
err = s.Start()
err = s.Start(ctx)
if err != nil {
return http.StatusBadRequest, fmt.Errorf("starting dhcp server: %w", err)
}
@ -390,6 +391,7 @@ type netInterfaceJSON struct {
// handleDHCPInterfaces is the handler for the GET /control/dhcp/interfaces
// HTTP API.
func (s *server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
resp := map[string]*netInterfaceJSON{}
ifaces, err := net.Interfaces()
@ -410,7 +412,7 @@ func (s *server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
continue
}
jsonIface, iErr := newNetInterfaceJSON(r.Context(), iface, s.conf.CommandConstructor)
jsonIface, iErr := newNetInterfaceJSON(ctx, s.conf.Logger, iface, s.conf.CommandConstructor)
if iErr != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", iErr)
@ -426,9 +428,10 @@ func (s *server) handleDHCPInterfaces(w http.ResponseWriter, r *http.Request) {
}
// newNetInterfaceJSON creates a JSON object from a [net.Interface] iface.
// cmdCons must not be nil.
// l and cmdCons must not be nil.
func newNetInterfaceJSON(
ctx context.Context,
l *slog.Logger,
iface net.Interface,
cmdCons executil.CommandConstructor,
) (out *netInterfaceJSON, err error) {
@ -483,7 +486,7 @@ func newNetInterfaceJSON(
return nil, nil
}
out.GatewayIP = aghnet.GatewayIP(ctx, cmdCons, iface.Name)
out.GatewayIP = aghnet.GatewayIP(ctx, l, cmdCons, iface.Name)
return out, nil
}
@ -533,6 +536,8 @@ type findActiveServerReq struct {
// 2. check if a static IP is configured for the network interface;
// 3. responds with the results.
func (s *server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if aghhttp.WriteTextPlainDeprecated(w, r) {
return
}
@ -569,24 +574,28 @@ func (s *server) handleDHCPFindActiveServer(w http.ResponseWriter, r *http.Reque
}
cmdCons := s.conf.CommandConstructor
if isStaticIP, serr := aghnet.IfaceHasStaticIP(r.Context(), cmdCons, ifaceName); serr != nil {
if isStaticIP, serr := aghnet.IfaceHasStaticIP(ctx, cmdCons, ifaceName); serr != nil {
result.V4.StaticIP.Static = "error"
result.V4.StaticIP.Error = serr.Error()
} else if !isStaticIP {
result.V4.StaticIP.Static = "no"
// TODO(e.burkov): The returned IP should only be of version 4.
result.V4.StaticIP.IP = aghnet.GetSubnet(ifaceName).String()
result.V4.StaticIP.IP = aghnet.GetSubnet(ctx, s.conf.Logger, ifaceName).String()
}
setOtherDHCPResult(ifaceName, result)
s.setOtherDHCPResult(ctx, ifaceName, result)
aghhttp.WriteJSONResponseOK(w, r, result)
}
// setOtherDHCPResult sets the results of the check for another DHCP server in
// result.
func setOtherDHCPResult(ifaceName string, result *dhcpSearchResult) {
found4, found6, err4, err6 := aghnet.CheckOtherDHCP(ifaceName)
// result. result must not be nil.
func (s *server) setOtherDHCPResult(
ctx context.Context,
ifaceName string,
result *dhcpSearchResult,
) {
found4, found6, err4, err6 := aghnet.CheckOtherDHCP(ctx, s.conf.Logger, ifaceName)
if err4 != nil {
result.V4.OtherServer.Found = "error"
result.V4.OtherServer.Error = err4.Error()

View File

@ -11,6 +11,7 @@ import (
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/agh"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -84,7 +85,10 @@ func TestServer_handleDHCPStatus(t *testing.T) {
Hostname: staticName,
}
s, err := Create(&ServerConfig{
ctx := testutil.ContextWithTimeout(t, testTimeout)
s, err := Create(ctx, &ServerConfig{
BaseLogger: testLogger,
Logger: testLogger,
Enabled: true,
Conf4: *defaultV4ServerConf(),
DataDir: t.TempDir(),
@ -178,7 +182,10 @@ func TestServer_HandleUpdateStaticLease(t *testing.T) {
},
}
s, err := Create(&ServerConfig{
ctx := testutil.ContextWithTimeout(t, testTimeout)
s, err := Create(ctx, &ServerConfig{
BaseLogger: testLogger,
Logger: testLogger,
Enabled: true,
Conf4: *defaultV4ServerConf(),
Conf6: V6ServerConf{},
@ -266,7 +273,10 @@ func TestServer_HandleUpdateStaticLease_validation(t *testing.T) {
Hostname: anotherV4Name,
}}
s, err := Create(&ServerConfig{
ctx := testutil.ContextWithTimeout(t, testTimeout)
s, err := Create(ctx, &ServerConfig{
BaseLogger: testLogger,
Logger: testLogger,
Enabled: true,
Conf4: *defaultV4ServerConf(),
Conf6: V6ServerConf{},

View File

@ -5,6 +5,7 @@ package dhcpd
// 'u-root/u-root' package, a dependency of 'insomniacslk/dhcp' package, doesn't build on Windows
import (
"context"
"net"
"net/netip"
@ -25,7 +26,7 @@ func (winServer) UpdateStaticLease(_ *dhcpsvc.Lease) (err error) { return
func (winServer) FindMACbyIP(_ netip.Addr) (mac net.HardwareAddr) { return nil }
func (winServer) WriteDiskConfig4(_ *V4ServerConf) {}
func (winServer) WriteDiskConfig6(_ *V6ServerConf) {}
func (winServer) Start() (err error) { return nil }
func (winServer) Start(ctx context.Context) (err error) { return nil }
func (winServer) Stop() (err error) { return nil }
func (winServer) HostByIP(_ netip.Addr) (host string) { return "" }
func (winServer) IPByHost(_ string) (ip netip.Addr) { return netip.Addr{} }

View File

@ -4,6 +4,7 @@ package dhcpd
import (
"bytes"
"context"
"fmt"
"net"
"net/netip"
@ -1299,7 +1300,7 @@ func (s *v4Server) packetHandler(conn net.PacketConn, peer net.Addr, req *dhcpv4
}
// Start starts the IPv4 DHCP server.
func (s *v4Server) Start() (err error) {
func (s *v4Server) Start(ctx context.Context) (err error) {
defer func() { err = errors.Annotate(err, "dhcpv4: %w") }()
if !s.enabled() {
@ -1315,6 +1316,8 @@ func (s *v4Server) Start() (err error) {
log.Debug("dhcpv4: starting...")
dnsIPAddrs, err := aghnet.IfaceDNSIPAddrs(
ctx,
s.conf.Logger,
iface,
aghnet.IPVersion4,
defaultMaxAttempts,

View File

@ -4,6 +4,7 @@ package dhcpd
import (
"bytes"
"context"
"fmt"
"net"
"net/netip"
@ -657,8 +658,13 @@ func (s *v6Server) packetHandler(conn net.PacketConn, peer net.Addr, req dhcpv6.
// configureDNSIPAddrs updates v6Server configuration with the slice of DNS IP
// addresses of provided interface iface. Initializes RA module.
func (s *v6Server) configureDNSIPAddrs(iface *net.Interface) (ok bool, err error) {
func (s *v6Server) configureDNSIPAddrs(
ctx context.Context,
iface *net.Interface,
) (ok bool, err error) {
dnsIPAddrs, err := aghnet.IfaceDNSIPAddrs(
ctx,
s.conf.Logger,
iface,
aghnet.IPVersion6,
defaultMaxAttempts,
@ -700,7 +706,7 @@ func (s *v6Server) initRA(iface *net.Interface) (err error) {
}
// Start starts the IPv6 DHCP server.
func (s *v6Server) Start() (err error) {
func (s *v6Server) Start(ctx context.Context) (err error) {
defer func() { err = errors.Annotate(err, "dhcpv6: %w") }()
if !s.conf.Enabled {
@ -715,7 +721,7 @@ func (s *v6Server) Start() (err error) {
log.Debug("dhcpv6: starting...")
ok, err := s.configureDNSIPAddrs(iface)
ok, err := s.configureDNSIPAddrs(ctx, iface)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
return err

View File

@ -1465,7 +1465,8 @@ func TestPTRResponseFromHosts(t *testing.T) {
}
var eventsCalledCounter uint32
hc, err := aghnet.NewHostsContainer(testFS, &aghtest.FSWatcher{
ctx := testutil.ContextWithTimeout(t, testTimeout)
hc, err := aghnet.NewHostsContainer(ctx, testLogger, testFS, &aghtest.FSWatcher{
OnStart: func(ctx context.Context) (_ error) { panic(testutil.UnexpectedCall(ctx)) },
OnEvents: func() (e <-chan struct{}) {
assert.Equal(t, uint32(1), atomic.AddUint32(&eventsCalledCounter, 1))

View File

@ -362,7 +362,10 @@ func TestServer_HandleTestUpstreamDNS(t *testing.T) {
Host: netutil.JoinHostPort(upstreamHost, hostsListener.Port()),
}).String()
ctx := testutil.ContextWithTimeout(t, testTimeout)
hc, err := aghnet.NewHostsContainer(
ctx,
testLogger,
fstest.MapFS{
hostsFileName: &fstest.MapFile{
Data: []byte(hostsListener.Addr().String() + " " + upstreamHost),

View File

@ -6,7 +6,6 @@ import (
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@ -53,7 +52,7 @@ func TestDNSFilter_CheckHostRules_dnsrewrite(t *testing.T) {
`
conf := &filtering.Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
SafeBrowsingCacheSize: 10000,
ParentalCacheSize: 10000,
SafeSearchCacheSize: 1000,

View File

@ -8,18 +8,13 @@ import (
"os"
"path/filepath"
"testing"
"time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// testTimeout is the common timeout for tests.
const testTimeout = 5 * time.Second
// serveHTTPLocally starts a new HTTP server, that handles its index with h. It
// also gracefully closes the listener when the test under t finishes.
func serveHTTPLocally(tb testing.TB, h http.Handler) (urlStr string) {
@ -86,7 +81,7 @@ func newDNSFilter(tb testing.TB) (d *DNSFilter) {
tb.Helper()
dnsFilter, err := New(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
DataDir: tb.TempDir(),
HTTPClient: &http.Client{
Timeout: testTimeout,

View File

@ -6,6 +6,7 @@ import (
"fmt"
"net/netip"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/hashprefix"
@ -18,6 +19,9 @@ import (
"github.com/stretchr/testify/require"
)
// testTimeout is a common timeout for tests.
const testTimeout = 1 * time.Second
const (
sbBlocked = "wmconvirus.narod.ru"
pcBlocked = "pornhub.com"

View File

@ -0,0 +1,13 @@
package filtering_test
import (
"time"
"github.com/AdguardTeam/golibs/logutil/slogutil"
)
// testTimeout is a common timeout for tests.
const testTimeout = 1 * time.Second
// testLogger is a logger used in tests.
var testLogger = slogutil.NewDiscardLogger()

View File

@ -22,6 +22,9 @@ const (
cacheSize = 10000
)
// testLogger is a logger used in tests.
var testLogger = slogutil.NewDiscardLogger()
func TestChcker_getQuestion(t *testing.T) {
const suf = "sb.dns.adguard.com."
@ -45,7 +48,7 @@ func TestChcker_getQuestion(t *testing.T) {
assert.False(t, slices.Contains(hashes, hash))
c := New(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
TXTSuffix: suf,
})
@ -100,7 +103,7 @@ func TestChecker_storeInCache(t *testing.T) {
const testTimeout = 1 * time.Second
c := New(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
CacheTime: cacheTime,
})
@ -158,7 +161,7 @@ func TestChecker_storeInCache(t *testing.T) {
assert.True(t, ok)
c = New(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
CacheTime: cacheTime,
})
@ -195,7 +198,7 @@ func TestChecker_Check(t *testing.T) {
for _, tc := range testCases {
c := New(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
CacheTime: cacheTime,
CacheSize: cacheSize,
})

View File

@ -11,7 +11,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/urlfilter/rules"
"github.com/miekg/dns"
@ -49,12 +48,13 @@ func TestDNSFilter_CheckHost_hostsContainer(t *testing.T) {
OnAdd: func(name string) (err error) { return nil },
OnShutdown: func(_ context.Context) (err error) { return nil },
}
hc, err := aghnet.NewHostsContainer(files, watcher, "hosts")
ctx := testutil.ContextWithTimeout(t, testTimeout)
hc, err := aghnet.NewHostsContainer(ctx, testLogger, files, watcher, "hosts")
require.NoError(t, err)
testutil.CleanupAndRequireSuccess(t, hc.Close)
conf := &filtering.Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
EtcHosts: hc,
}
f, err := filtering.New(conf, nil)

View File

@ -14,7 +14,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/schedule"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -110,7 +109,7 @@ func TestDNSFilter_handleFilteringSetURL(t *testing.T) {
confModifiedCalled = true
}
d, err := New(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
FilteringEnabled: true,
Filters: tc.initial,
HTTPClient: &http.Client{
@ -195,7 +194,7 @@ func TestDNSFilter_handleSafeBrowsingStatus(t *testing.T) {
}
d, err := New(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
ConfModifier: confModifier,
DataDir: filtersDir,
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
@ -282,7 +281,7 @@ func TestDNSFilter_handleParentalStatus(t *testing.T) {
}
d, err := New(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
ConfModifier: confModifier,
DataDir: filtersDir,
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
@ -384,7 +383,7 @@ func TestDNSFilter_HandleCheckHost(t *testing.T) {
}
dnsFilter, err := New(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
BlockedServices: &BlockedServices{
Schedule: schedule.EmptyWeekly(),
},

View File

@ -5,7 +5,6 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghalg"
"github.com/AdguardTeam/AdGuardHome/internal/filtering/rulelist"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/stretchr/testify/assert"
)
@ -65,7 +64,7 @@ func TestIDGenerator_Fix(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
g := newIDGenerator(1, slogutil.NewDiscardLogger())
g := newIDGenerator(1, testLogger)
g.fix(tc.in)
assertUniqueIDs(t, tc.in)

View File

@ -13,6 +13,9 @@ import (
"github.com/stretchr/testify/require"
)
// testLogger is a logger used in tests.
var testLogger = slogutil.NewDiscardLogger()
func TestNewDefaultStorage(t *testing.T) {
items := []*Item{{
Domain: "example.com",
@ -20,7 +23,7 @@ func TestNewDefaultStorage(t *testing.T) {
}}
s, err := NewDefaultStorage(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
Rewrites: items,
ListID: -1,
})
@ -33,7 +36,7 @@ func TestDefaultStorage_CRUD(t *testing.T) {
var items []*Item
s, err := NewDefaultStorage(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
Rewrites: items,
ListID: -1,
})
@ -122,7 +125,7 @@ func TestDefaultStorage_MatchRequest(t *testing.T) {
}}
s, err := NewDefaultStorage(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
Rewrites: items,
ListID: -1,
})
@ -298,7 +301,7 @@ func TestDefaultStorage_MatchRequest_Levels(t *testing.T) {
}}
s, err := NewDefaultStorage(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
Rewrites: items,
ListID: -1,
})
@ -370,7 +373,7 @@ func TestDefaultStorage_MatchRequest_ExceptionCNAME(t *testing.T) {
}}
s, err := NewDefaultStorage(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
Rewrites: items,
ListID: -1,
})
@ -438,7 +441,7 @@ func TestDefaultStorage_MatchRequest_ExceptionIP(t *testing.T) {
}}
s, err := NewDefaultStorage(&Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
Rewrites: items,
ListID: -1,
})

View File

@ -8,11 +8,9 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/filtering"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -30,9 +28,6 @@ type rewriteUpdateJSON struct {
}
const (
// testTimeout is the common timeout for tests.
testTimeout = 100 * time.Millisecond
listURL = "/control/rewrite/list"
addURL = "/control/rewrite/add"
deleteURL = "/control/rewrite/delete"
@ -159,7 +154,7 @@ func TestDNSFilter_handleRewriteHTTP(t *testing.T) {
}
d, err := filtering.New(&filtering.Config{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
ConfModifier: confModifier,
HTTPRegister: func(_, url string, handler http.HandlerFunc) {
handlers[url] = handler

View File

@ -41,6 +41,9 @@ var testConf = filtering.SafeSearchConfig{
YouTube: true,
}
// testLogger is a logger used in tests.
var testLogger = slogutil.NewDiscardLogger()
// yandexIP is the expected IP address of Yandex safe search results. Keep in
// sync with the rules data.
var yandexIP = netip.AddrFrom4([4]byte{213, 180, 193, 56})
@ -49,7 +52,7 @@ func TestDefault_CheckHost_yandex(t *testing.T) {
conf := testConf
ctx := testutil.ContextWithTimeout(t, testTimeout)
ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
ServicesConfig: conf,
CacheSize: testCacheSize,
CacheTTL: testCacheTTL,
@ -111,7 +114,7 @@ func TestDefault_CheckHost_yandex(t *testing.T) {
func TestDefault_CheckHost_google(t *testing.T) {
ctx := testutil.ContextWithTimeout(t, testTimeout)
ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
ServicesConfig: testConf,
CacheSize: testCacheSize,
CacheTTL: testCacheTTL,
@ -163,7 +166,7 @@ func (r *testResolver) LookupIP(
func TestDefault_CheckHost_duckduckgoAAAA(t *testing.T) {
ctx := testutil.ContextWithTimeout(t, testTimeout)
ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
ServicesConfig: testConf,
CacheSize: testCacheSize,
CacheTTL: testCacheTTL,
@ -186,7 +189,7 @@ func TestDefault_Update(t *testing.T) {
conf := testConf
ctx := testutil.ContextWithTimeout(t, testTimeout)
ss, err := safesearch.NewDefault(ctx, &safesearch.DefaultConfig{
Logger: slogutil.NewDiscardLogger(),
Logger: testLogger,
ServicesConfig: conf,
CacheSize: testCacheSize,
CacheTTL: testCacheTTL,

View File

@ -196,7 +196,7 @@ func (web *webAPI) handleInstallCheckConfig(w http.ResponseWriter, r *http.Reque
if err != nil {
resp.DNS.Status = err.Error()
} else if !req.DNS.IP.IsUnspecified() {
resp.StaticIP = handleStaticIP(ctx, req.DNS.IP, req.SetStaticIP, web.cmdCons)
resp.StaticIP = handleStaticIP(ctx, web.logger, req.DNS.IP, req.SetStaticIP, web.cmdCons)
}
aghhttp.WriteJSONResponseOK(w, r, resp)
@ -206,6 +206,7 @@ func (web *webAPI) handleInstallCheckConfig(w http.ResponseWriter, r *http.Reque
// owns IP. cmdCons must not be nil.
func handleStaticIP(
ctx context.Context,
l *slog.Logger,
ip netip.Addr,
set bool,
cmdCons executil.CommandConstructor,
@ -222,7 +223,7 @@ func handleStaticIP(
if set {
// Try to set a static IP for the specified interface.
err := aghnet.IfaceSetStaticIP(ctx, cmdCons, interfaceName)
err := aghnet.IfaceSetStaticIP(ctx, l, cmdCons, interfaceName)
if err != nil {
ipResp.Static = "error"
ipResp.Error = err.Error()
@ -244,7 +245,7 @@ func handleStaticIP(
if isStaticIP {
ipResp.Static = "yes"
}
ipResp.IP = aghnet.GetSubnet(interfaceName).String()
ipResp.IP = aghnet.GetSubnet(ctx, l, interfaceName).String()
return ipResp
}

View File

@ -32,6 +32,8 @@ type temporaryError interface {
//
// TODO(a.garipov): Find out if this API used with a GET method by anyone.
func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
resp := &versionResponse{}
if web.conf.disableUpdate {
resp.Disabled = true
@ -54,7 +56,7 @@ func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) {
}
}
err = web.requestVersionInfo(r.Context(), resp, req.Recheck)
err = web.requestVersionInfo(ctx, resp, req.Recheck)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
aghhttp.Error(r, w, http.StatusBadGateway, "%s", err)
@ -62,7 +64,7 @@ func (web *webAPI) handleVersionJSON(w http.ResponseWriter, r *http.Request) {
return
}
err = resp.setAllowedToAutoUpdate(web.tlsManager)
err = resp.setAllowedToAutoUpdate(ctx, web.logger, web.tlsManager)
if err != nil {
// Don't wrap the error, because it's informative enough as is.
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
@ -164,8 +166,13 @@ type versionResponse struct {
}
// setAllowedToAutoUpdate sets CanAutoUpdate to true if AdGuard Home is actually
// allowed to perform an automatic update by the OS. tlsMgr must not be nil.
func (vr *versionResponse) setAllowedToAutoUpdate(tlsMgr *tlsManager) (err error) {
// allowed to perform an automatic update by the OS. l and tlsMgr must not be
// nil.
func (vr *versionResponse) setAllowedToAutoUpdate(
ctx context.Context,
l *slog.Logger,
tlsMgr *tlsManager,
) (err error) {
if vr.CanAutoUpdate != aghalg.NBTrue {
return nil
}
@ -174,7 +181,7 @@ func (vr *versionResponse) setAllowedToAutoUpdate(tlsMgr *tlsManager) (err error
if tlsConfUsesPrivilegedPorts(tlsMgr.config()) ||
config.HTTPConfig.Address.Port() < 1024 ||
config.DNS.Port < 1024 {
canUpdate, err = aghnet.CanBindPrivilegedPorts()
canUpdate, err = aghnet.CanBindPrivilegedPorts(ctx, l)
if err != nil {
return fmt.Errorf("checking ability to bind privileged ports: %w", err)
}
@ -192,7 +199,7 @@ func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) {
}
// finishUpdate completes an update procedure. It is intended to be used as a
// goroutine.
// goroutine. l and cmdCons must not be nil.
func finishUpdate(
ctx context.Context,
l *slog.Logger,

View File

@ -175,7 +175,7 @@ func setupContext(ctx context.Context, baseLogger *slog.Logger, opts options) (e
if globalContext.firstRun {
log.Info("This is the first time AdGuard Home is launched")
checkNetworkPermissions()
checkNetworkPermissions(ctx, baseLogger)
return nil
}
@ -265,7 +265,14 @@ func setupHostsContainer(ctx context.Context, baseLogger *slog.Logger) (err erro
return fmt.Errorf("getting default system hosts paths: %w", err)
}
globalContext.etcHosts, err = aghnet.NewHostsContainer(osutil.RootDirFS(), hostsWatcher, paths...)
l := baseLogger.With(slogutil.KeyPrefix, "hosts")
globalContext.etcHosts, err = aghnet.NewHostsContainer(
ctx,
l,
osutil.RootDirFS(),
hostsWatcher,
paths...,
)
if err != nil {
closeErr := hostsWatcher.Shutdown(ctx)
if errors.Is(err, aghnet.ErrNoHostsPaths) {
@ -308,9 +315,11 @@ func initContextClients(
config.DHCP.DataDir = globalContext.getDataDir()
config.DHCP.HTTPRegister = httpRegister
config.DHCP.CommandConstructor = executil.SystemCommandConstructor{}
config.DHCP.BaseLogger = logger
config.DHCP.Logger = logger.With(slogutil.KeyPrefix, "dhcp_server")
config.DHCP.ConfModifier = confModifier
globalContext.dhcpServer, err = dhcpd.Create(config.DHCP)
globalContext.dhcpServer, err = dhcpd.Create(ctx, config.DHCP)
if globalContext.dhcpServer == nil || err != nil {
// TODO(a.garipov): There are a lot of places in the code right
// now which assume that the DHCP server can be nil despite this
@ -805,7 +814,7 @@ func run(
}()
if globalContext.dhcpServer != nil {
err = globalContext.dhcpServer.Start()
err = globalContext.dhcpServer.Start(ctx)
if err != nil {
log.Error("starting dhcp server: %s", err)
}
@ -940,11 +949,11 @@ func (c *configuration) anonymizer() (ipmut *aghnet.IPMut) {
}
// checkNetworkPermissions checks if the current user permissions are enough to
// use the required networking functionality.
func checkNetworkPermissions() {
// use the required networking functionality. l must not be nil.
func checkNetworkPermissions(ctx context.Context, l *slog.Logger) {
log.Info("Checking if AdGuard Home has necessary permissions")
if ok, err := aghnet.CanBindPrivilegedPorts(); !ok || err != nil {
if ok, err := aghnet.CanBindPrivilegedPorts(ctx, l); !ok || err != nil {
log.Fatal("This is the first launch of AdGuard Home. You must run it as Administrator.")
}