Pull request 2453: AGDNS-3086-arpdb-use-executil

Squashed commit of the following:

commit 856439ef7c974d67b9afd19b4e4e9bed5b782de0
Merge: 65d593d09 2e5005d7d
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Tue Sep 2 15:44:52 2025 +0300

    Merge branch 'master' into AGDNS-3086-arpdb-use-executil

commit 65d593d091
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Aug 28 19:55:42 2025 +0300

    all: imp docs

commit 3b569a94d7
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Aug 28 15:52:54 2025 +0300

    agh: imp code

commit 803f79d144
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Aug 28 14:22:48 2025 +0300

    all: imp code

commit 20a0702a0c
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Aug 27 14:57:35 2025 +0300

    aghos: add todo

commit e400e1757d
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Aug 21 22:37:22 2025 +0300

    all: imp docs

commit 971b5bc1b1
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Thu Aug 21 22:00:47 2025 +0300

    all: imp code

commit 4fc73dca76
Author: Stanislav Chzhen <s.chzhen@adguard.com>
Date:   Wed Aug 20 15:55:19 2025 +0300

    all: use executil
This commit is contained in:
Stanislav Chzhen 2025-09-02 15:58:30 +03:00
parent 2e5005d7df
commit cd79a4ac72
24 changed files with 657 additions and 328 deletions

View File

@ -3,8 +3,18 @@ package agh
import (
"context"
"fmt"
"strings"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/AdguardTeam/golibs/testutil/fakeos/fakeexec"
)
// DefaultOutputLimit is the default limit of bytes for commands' standard
// output and standard error.
const DefaultOutputLimit = 512
// ConfigModifier defines an interface for updating the global configuration.
type ConfigModifier interface {
// Apply applies changes to the global configuration.
@ -20,3 +30,142 @@ var _ ConfigModifier = EmptyConfigModifier{}
// Apply implements the [ConfigModifier] for EmptyConfigModifier.
func (em EmptyConfigModifier) Apply(ctx context.Context) {}
// exitErr implements [executil.ExitCodeError] for tests to simulate non-zero
// process exit codes.
//
// TODO(s.chzhen): Consider constructing an [exec.ExitError] instead.
type exitErr struct {
code osutil.ExitCode
}
// newExitErr returns a properly initialized exitErr with the provided code.
func newExitErr(code osutil.ExitCode) (err exitErr) {
return exitErr{code: code}
}
// type check
var _ executil.ExitCodeError = exitErr{}
// Error implements the [executil.ExitCodeError] for exitErr.
func (e exitErr) Error() (s string) {
return fmt.Sprintf("exit code %d", e.code)
}
// ExitCode implements the [executil.ExitCodeError] for exitErr.
func (e exitErr) ExitCode() (code osutil.ExitCode) {
return e.code
}
// ExternalCommand is a fake command used by [NewMultipleCommandConstructor].
type ExternalCommand struct {
// Err is the error returned, if non-nil.
Err error
// Cmd contains the command path and arguments.
Cmd string
// Out is written to stdout if non-empty.
Out string
// Code is returned as the exit code if non-zero.
Code osutil.ExitCode
}
// keyCommand builds a key for a command lookup.
func keyCommand(path string, args []string) (k string) {
if len(args) == 0 {
return path
}
return path + " " + strings.Join(args, " ")
}
// parseCommand splits a command string into the executable path and args.
func parseCommand(s string) (path string, args []string) {
f := strings.Fields(s)
if len(f) == 0 {
return "", nil
}
return f[0], f[1:]
}
// NewMultipleCommandConstructor is a helper function that returns a mock
// [executil.CommandConstructor] for tests that supports multiple commands.
//
// TODO(s.chzhen): Move to aghtest once the import cycle is resolved, since it
// will be called from the aghnet package, which imports the whois package,
// which in turn imports aghnet.
func NewMultipleCommandConstructor(cmds ...ExternalCommand) (cs executil.CommandConstructor) {
table := make(map[string]ExternalCommand, len(cmds))
for _, ec := range cmds {
p, a := parseCommand(ec.Cmd)
table[keyCommand(p, a)] = ec
}
onNew := func(_ context.Context, conf *executil.CommandConfig) (c executil.Command, err error) {
ec := table[keyCommand(conf.Path, conf.Args)]
cmd := fakeexec.NewCommand()
cmd.OnStart = func(_ context.Context) (err error) {
if ec.Out != "" {
_, _ = conf.Stdout.Write([]byte(ec.Out))
}
return nil
}
cmd.OnWait = func(_ context.Context) (err error) {
if ec.Err != nil {
return ec.Err
}
if ec.Code != 0 {
return newExitErr(ec.Code)
}
return nil
}
return cmd, nil
}
return &fakeexec.CommandConstructor{OnNew: onNew}
}
// NewCommandConstructor is a helper function that returns a mock
// [executil.CommandConstructor] for tests.
func NewCommandConstructor(
_ string,
code osutil.ExitCode,
stdout string,
cmdErr error,
) (cs executil.CommandConstructor) {
onNew := func(_ context.Context, conf *executil.CommandConfig) (c executil.Command, err error) {
cmd := fakeexec.NewCommand()
cmd.OnStart = func(_ context.Context) (err error) {
if conf.Stdout != nil {
_, _ = conf.Stdout.Write([]byte(stdout))
}
return nil
}
cmd.OnWait = func(_ context.Context) (err error) {
if cmdErr != nil {
return cmdErr
}
if code != 0 {
return newExitErr(code)
}
return nil
}
return cmd, nil
}
return &fakeexec.CommandConstructor{OnNew: onNew}
}

View File

@ -18,6 +18,7 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/osutil/executil"
)
// DialContextFunc is the semantic alias for dialing functions, such as
@ -27,7 +28,16 @@ type DialContextFunc = func(ctx context.Context, network, addr string) (conn net
// Variables and functions to substitute in tests.
var (
// aghosRunCommand is the function to run shell commands.
aghosRunCommand = aghos.RunCommand
//
// TODO(s.chzhen): Use [aghos.RunCommand] directly.
aghosRunCommand = (func() func(string, ...string) (int, []byte, error) {
ctx := context.TODO()
cmdCons := executil.SystemCommandConstructor{}
return func(command string, arguments ...string) (int, []byte, error) {
return aghos.RunCommand(ctx, cmdCons, command, arguments...)
}
})()
// netInterfaces is the function to get the available network interfaces.
netInterfaceAddrs = net.InterfaceAddrs

View File

@ -5,13 +5,13 @@ package aghos
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"io/fs"
"log/slog"
"os"
"os/exec"
"path"
"runtime"
"slices"
@ -19,6 +19,9 @@ import (
"strings"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/ioutil"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/osutil/executil"
)
// Default file, binary, and directory permissions.
@ -50,23 +53,53 @@ func HaveAdminRights() (bool, error) {
const MaxCmdOutputSize = 64 * 1024
// RunCommand runs shell command.
func RunCommand(command string, arguments ...string) (code int, output []byte, err error) {
cmd := exec.Command(command, arguments...)
out, err := cmd.Output()
//
// TODO(s.chzhen): Consider removing this after addressing the current behavior
// where a non-zero exit code is returned together with a nil error.
func RunCommand(
ctx context.Context,
cmdCons executil.CommandConstructor,
command string,
arguments ...string,
) (code int, output []byte, err error) {
stdoutBuf := bytes.Buffer{}
stderrBuf := bytes.Buffer{}
out = out[:min(len(out), MaxCmdOutputSize)]
err = executil.Run(
ctx,
cmdCons,
&executil.CommandConfig{
Path: command,
Args: arguments,
Stdout: ioutil.NewTruncatedWriter(&stdoutBuf, MaxCmdOutputSize),
Stderr: &stderrBuf,
},
)
if err != nil {
if eerr := new(exec.ExitError); errors.As(err, &eerr) {
return eerr.ExitCode(), eerr.Stderr, nil
}
return 1, nil, fmt.Errorf("command %q failed: %w: %s", command, err, out)
if err == nil {
return osutil.ExitCodeSuccess, stdoutBuf.Bytes(), nil
}
return cmd.ProcessState.ExitCode(), out, nil
code, ok := executil.ExitCodeFromError(err)
if ok {
// Mirror the old behavior and return a nil-error on non-zero code
// status.
return code, stderrBuf.Bytes(), nil
}
code = osutil.ExitCodeFailure
return code, nil, fmt.Errorf("command %q failed: %w: %s", command, err, &stdoutBuf)
}
// psArgs holds the default ps arguments to avoid per-call slice allocations.
//
// Don't use -C flag here since it's a feature of linux's ps
// implementation. Use POSIX-compatible flags instead.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/3457.
var psArgs = []string{"-A", "-o", "pid=", "-o", "comm="}
// PIDByCommand searches for process named command and returns its PID ignoring
// the PIDs from except. If no processes found, the error returned. l must not
// be nil.
@ -76,31 +109,35 @@ func PIDByCommand(
command string,
except ...int,
) (pid int, err error) {
// Don't use -C flag here since it's a feature of linux's ps
// implementation. Use POSIX-compatible flags instead.
const psCmd = "ps"
l.DebugContext(ctx, "executing", "cmd", psCmd, "args", psArgs)
stdoutBuf := bytes.Buffer{}
// TODO(s.chzhen): Catch stderr.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/3457.
cmd := exec.Command("ps", "-A", "-o", "pid=", "-o", "comm=")
var stdout io.ReadCloser
if stdout, err = cmd.StdoutPipe(); err != nil {
return 0, fmt.Errorf("getting the command's stdout pipe: %w", err)
}
if err = cmd.Start(); err != nil {
return 0, fmt.Errorf("start command executing: %w", err)
}
// TODO(s.chzhen): Consider streaming the output if needed. Using
// [io.Pipe] here is unnecessary; it complicates lifecycle management
// because the output must be read concurrently, and the PipeWriter must be
// explicitly closed to signal EOF. Since this command's output is small, a
// bytes.Buffer via executil.Run is sufficient.
runErr := executil.Run(
ctx,
executil.SystemCommandConstructor{},
&executil.CommandConfig{
Path: psCmd,
Args: psArgs,
Stdout: &stdoutBuf,
},
)
var instNum int
pid, instNum, err = parsePSOutput(stdout, command, except)
pid, instNum, err = parsePSOutput(&stdoutBuf, command, except)
if err != nil {
return 0, err
}
if err = cmd.Wait(); err != nil {
return 0, fmt.Errorf("executing the command: %w", err)
}
switch instNum {
case 0:
// TODO(e.burkov): Use constant error.
@ -111,8 +148,12 @@ func PIDByCommand(
l.WarnContext(ctx, "instances found", "num", instNum, "command", command)
}
if code := cmd.ProcessState.ExitCode(); code != 0 {
return 0, fmt.Errorf("ps finished with code %d", code)
if runErr != nil {
if code, ok := executil.ExitCodeFromError(runErr); ok {
return 0, fmt.Errorf("ps finished with code %d", code)
}
return 0, fmt.Errorf("executing the command: %w", runErr)
}
return pid, nil

View File

@ -4,6 +4,7 @@ package arpdb
import (
"bufio"
"bytes"
"context"
"fmt"
"log/slog"
"net"
@ -11,18 +12,16 @@ import (
"slices"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/AdguardTeam/golibs/service"
)
// Variables and functions to substitute in tests.
var (
// aghosRunCommand is the function to run shell commands.
aghosRunCommand = aghos.RunCommand
// rootDirFS is the filesystem pointing to the root directory.
rootDirFS = osutil.RootDirFS()
)
@ -30,17 +29,17 @@ var (
// Interface stores and refreshes the network neighborhood reported by ARP
// (Address Resolution Protocol).
type Interface interface {
// Refresh updates the stored data. It must be safe for concurrent use.
Refresh() (err error)
// Refresher updates the stored data. It must be safe for concurrent use.
service.Refresher
// Neighbors returnes the last set of data reported by ARP. Both the method
// Neighbors returns the last set of data reported by ARP. Both the method
// and it's result must be safe for concurrent use.
Neighbors() (ns []Neighbor)
}
// New returns the [Interface] properly initialized for the OS.
func New(logger *slog.Logger) (arp Interface) {
return newARPDB(logger)
return newARPDB(logger, executil.SystemCommandConstructor{})
}
// Empty is the [Interface] implementation that does nothing.
@ -51,7 +50,7 @@ var _ Interface = Empty{}
// Refresh implements the [Interface] interface for EmptyARPContainer. It does
// nothing and always returns nil error.
func (Empty) Refresh() (err error) { return nil }
func (Empty) Refresh(_ context.Context) (err error) { return nil }
// Neighbors implements the [Interface] interface for EmptyARPContainer. It
// always returns nil.
@ -164,28 +163,40 @@ type parseNeighsFunc func(logger *slog.Logger, sc *bufio.Scanner, lenHint int) (
// cmdARPDB is the implementation of the [Interface] that uses command line to
// retrieve data.
type cmdARPDB struct {
logger *slog.Logger
parse parseNeighsFunc
ns *neighs
cmd string
args []string
logger *slog.Logger
cmdCons executil.CommandConstructor
parse parseNeighsFunc
ns *neighs
cmd string
args []string
}
// type check
var _ Interface = (*cmdARPDB)(nil)
// Refresh implements the [Interface] interface for *cmdARPDB.
func (arp *cmdARPDB) Refresh() (err error) {
func (arp *cmdARPDB) Refresh(ctx context.Context) (err error) {
defer func() { err = errors.Annotate(err, "cmd arpdb: %w") }()
code, out, err := aghosRunCommand(arp.cmd, arp.args...)
var stdout bytes.Buffer
err = executil.Run(
ctx,
arp.cmdCons,
&executil.CommandConfig{
Path: arp.cmd,
Args: arp.args,
Stdout: &stdout,
},
)
if err != nil {
if code, ok := executil.ExitCodeFromError(err); ok {
return fmt.Errorf("running command: unexpected exit code %d", code)
}
return fmt.Errorf("running command: %w", err)
} else if code != 0 {
return fmt.Errorf("running command: unexpected exit code %d", code)
}
sc := bufio.NewScanner(bytes.NewReader(out))
sc := bufio.NewScanner(&stdout)
ns := arp.parse(arp.logger, sc, arp.ns.len())
if err = sc.Err(); err != nil {
// TODO(e.burkov): This error seems unreachable. Investigate.
@ -226,11 +237,11 @@ func newARPDBs(arps ...Interface) (arp *arpdbs) {
var _ Interface = (*arpdbs)(nil)
// Refresh implements the [Interface] interface for *arpdbs.
func (arp *arpdbs) Refresh() (err error) {
func (arp *arpdbs) Refresh(ctx context.Context) (err error) {
var errs []error
for _, a := range arp.arps {
err = a.Refresh()
err = a.Refresh(ctx)
if err != nil {
errs = append(errs, err)

View File

@ -9,12 +9,14 @@ import (
"sync"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil/executil"
)
func newARPDB(logger *slog.Logger) (arp *cmdARPDB) {
func newARPDB(logger *slog.Logger, cmdCons executil.CommandConstructor) (arp *cmdARPDB) {
return &cmdARPDB{
logger: logger,
parse: parseArpA,
logger: logger,
cmdCons: cmdCons,
parse: parseArpA,
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),

View File

@ -1,15 +1,16 @@
package arpdb
import (
"fmt"
"context"
"io/fs"
"net"
"net/netip"
"os"
"strings"
"sync"
"testing"
"time"
"github.com/AdguardTeam/AdGuardHome/internal/agh"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
@ -17,49 +18,15 @@ import (
"github.com/stretchr/testify/require"
)
// testTimeout is a common timeout for tests.
const testTimeout = 1 * time.Second
// testdata is the filesystem containing data for testing the package.
var testdata fs.FS = os.DirFS("./testdata")
// RunCmdFunc is the signature of aghos.RunCommand function.
type RunCmdFunc func(cmd string, args ...string) (code int, out []byte, err error)
// substShell replaces the the aghos.RunCommand function used throughout the
// package with rc for tests ran under tb.
func substShell(tb testing.TB, rc RunCmdFunc) {
tb.Helper()
prev := aghosRunCommand
tb.Cleanup(func() { aghosRunCommand = prev })
aghosRunCommand = rc
}
// mapShell is a substitution of aghos.RunCommand that maps the command to it's
// execution result. It's only needed to simplify testing.
//
// TODO(e.burkov): Perhaps put all the shell interactions behind an interface.
type mapShell map[string]struct {
err error
out string
code int
}
// theOnlyCmd returns mapShell that only handles a single command and arguments
// combination from cmd.
func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) {
return mapShell{cmd: {code: code, out: out, err: err}}
}
// RunCmd is a RunCmdFunc handled by s.
func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err error) {
key := strings.Join(append([]string{cmd}, args...), " ")
ret, ok := s[key]
if !ok {
return 0, nil, fmt.Errorf("unexpected shell command %q", key)
}
return ret.code, []byte(ret.out), ret.err
}
func Test_New(t *testing.T) {
var a Interface
require.NotPanics(t, func() { a = New(slogutil.NewDiscardLogger()) })
@ -71,7 +38,7 @@ func Test_New(t *testing.T) {
// TestARPDB is the mock implementation of [Interface] to use in tests.
type TestARPDB struct {
OnRefresh func() (err error)
OnRefresh func(ctx context.Context) (err error)
OnNeighbors func() (ns []Neighbor)
}
@ -79,8 +46,8 @@ type TestARPDB struct {
var _ Interface = (*TestARPDB)(nil)
// Refresh implements the [Interface] interface for *TestARPDB.
func (arp *TestARPDB) Refresh() (err error) {
return arp.OnRefresh()
func (arp *TestARPDB) Refresh(ctx context.Context) (err error) {
return arp.OnRefresh(ctx)
}
// Neighbors implements the [Interface] interface for *TestARPDB.
@ -98,13 +65,17 @@ func Test_NewARPDBs(t *testing.T) {
}
succDB := &TestARPDB{
OnRefresh: func() (err error) { succRefrCount++; return nil },
OnRefresh: func(_ context.Context) (err error) { succRefrCount++; return nil },
OnNeighbors: func() (ns []Neighbor) {
return []Neighbor{{Name: "abc", IP: knownIP, MAC: knownMAC}}
},
}
failDB := &TestARPDB{
OnRefresh: func() (err error) { failRefrCount++; return errors.Error("refresh failed") },
OnRefresh: func(_ context.Context) (err error) {
failRefrCount++
return errors.Error("refresh failed")
},
OnNeighbors: func() (ns []Neighbor) { return nil },
}
@ -112,7 +83,7 @@ func Test_NewARPDBs(t *testing.T) {
t.Cleanup(clnp)
a := newARPDBs(succDB, failDB)
err := a.Refresh()
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
require.NoError(t, err)
assert.Equal(t, 1, succRefrCount)
@ -124,7 +95,7 @@ func Test_NewARPDBs(t *testing.T) {
t.Cleanup(clnp)
a := newARPDBs(failDB, succDB)
err := a.Refresh()
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
require.NoError(t, err)
assert.Equal(t, 1, succRefrCount)
@ -138,7 +109,7 @@ func Test_NewARPDBs(t *testing.T) {
wantMsg := "each arpdb failed: refresh failed\nrefresh failed"
a := newARPDBs(failDB, failDB)
err := a.Refresh()
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
require.Error(t, err)
testutil.AssertErrorMsg(t, wantMsg, err)
@ -152,7 +123,7 @@ func Test_NewARPDBs(t *testing.T) {
shouldFail := false
unstableDB := &TestARPDB{
OnRefresh: func() (err error) {
OnRefresh: func(_ context.Context) (err error) {
if shouldFail {
err = errors.Error("unstable failed")
}
@ -171,21 +142,21 @@ func Test_NewARPDBs(t *testing.T) {
a := newARPDBs(unstableDB, succDB)
// Unstable ARPDB should refresh successfully.
err := a.Refresh()
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
require.NoError(t, err)
assert.Zero(t, succRefrCount)
assert.NotEmpty(t, a.Neighbors())
// Unstable ARPDB should fail and the succDB should be used.
err = a.Refresh()
err = a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
require.NoError(t, err)
assert.Equal(t, 1, succRefrCount)
assert.NotEmpty(t, a.Neighbors())
// Unstable ARPDB should refresh successfully again.
err = a.Refresh()
err = a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
require.NoError(t, err)
assert.Equal(t, 1, succRefrCount)
@ -194,7 +165,7 @@ func Test_NewARPDBs(t *testing.T) {
t.Run("empty", func(t *testing.T) {
a := newARPDBs()
require.NoError(t, a.Refresh())
require.NoError(t, a.Refresh(testutil.ContextWithTimeout(t, testTimeout)))
assert.Empty(t, a.Neighbors())
})
@ -212,36 +183,36 @@ func TestCmdARPDB_arpa(t *testing.T) {
}
t.Run("arp_a", func(t *testing.T) {
sh := theOnlyCmd("cmd", 0, arpAOutput, nil)
substShell(t, sh.RunCmd)
a.cmdCons = agh.NewCommandConstructor("cmd", 0, arpAOutput, nil)
err := a.Refresh()
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
require.NoError(t, err)
assert.Equal(t, wantNeighs, a.Neighbors())
})
t.Run("runcmd_error", func(t *testing.T) {
sh := theOnlyCmd("cmd", 0, "", errors.Error("can't run"))
substShell(t, sh.RunCmd)
a.cmdCons = agh.NewCommandConstructor("cmd", 0, "", errors.Error("can't run"))
err := a.Refresh()
testutil.AssertErrorMsg(t, "cmd arpdb: running command: can't run", err)
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
testutil.AssertErrorMsg(t, "cmd arpdb: running command: running: can't run", err)
})
t.Run("bad_code", func(t *testing.T) {
sh := theOnlyCmd("cmd", 1, "", nil)
substShell(t, sh.RunCmd)
a.cmdCons = agh.NewCommandConstructor("cmd", 1, "", nil)
err := a.Refresh()
testutil.AssertErrorMsg(t, "cmd arpdb: running command: unexpected exit code 1", err)
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
testutil.AssertErrorMsg(
t,
"cmd arpdb: running command: unexpected exit code 1",
err,
)
})
t.Run("empty", func(t *testing.T) {
sh := theOnlyCmd("cmd", 0, "", nil)
substShell(t, sh.RunCmd)
a.cmdCons = agh.NewCommandConstructor("cmd", 0, "", nil)
err := a.Refresh()
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
require.NoError(t, err)
assert.Empty(t, a.Neighbors())
@ -254,7 +225,7 @@ func TestEmptyARPDB(t *testing.T) {
t.Run("refresh", func(t *testing.T) {
var err error
require.NotPanics(t, func() {
err = a.Refresh()
err = a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
})
assert.NoError(t, err)

View File

@ -4,6 +4,7 @@ package arpdb
import (
"bufio"
"context"
"fmt"
"io/fs"
"log/slog"
@ -14,10 +15,11 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/AdguardTeam/golibs/stringutil"
)
func newARPDB(logger *slog.Logger) (arp *arpdbs) {
func newARPDB(logger *slog.Logger, cmdCons executil.CommandConstructor) (arp *arpdbs) {
// Use the common storage among the implementations.
ns := &neighs{
mu: &sync.RWMutex{},
@ -40,10 +42,11 @@ func newARPDB(logger *slog.Logger) (arp *arpdbs) {
},
// Then, try "arp -a -n".
&cmdARPDB{
logger: logger,
parse: parseF,
ns: ns,
cmd: "arp",
logger: logger,
cmdCons: cmdCons,
parse: parseF,
ns: ns,
cmd: "arp",
// Use -n flag to avoid resolving the hostnames of the neighbors.
// By default ARP attempts to resolve the hostnames via DNS. See
// man 8 arp.
@ -53,11 +56,12 @@ func newARPDB(logger *slog.Logger) (arp *arpdbs) {
},
// Finally, try "ip neigh".
&cmdARPDB{
logger: logger,
parse: parseIPNeigh,
ns: ns,
cmd: "ip",
args: []string{"neigh"},
logger: logger,
cmdCons: cmdCons,
parse: parseIPNeigh,
ns: ns,
cmd: "ip",
args: []string{"neigh"},
},
)
}
@ -73,7 +77,7 @@ type fsysARPDB struct {
var _ Interface = (*fsysARPDB)(nil)
// Refresh implements the [Interface] interface for *fsysARPDB.
func (arp *fsysARPDB) Refresh() (err error) {
func (arp *fsysARPDB) Refresh(_ context.Context) (err error) {
var f fs.File
f, err = arp.fsys.Open(arp.filename)
if err != nil {

View File

@ -9,7 +9,9 @@ import (
"testing"
"testing/fstest"
"github.com/AdguardTeam/AdGuardHome/internal/agh"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@ -54,7 +56,7 @@ func TestFSysARPDB(t *testing.T) {
filename: "proc_net_arp",
}
err := a.Refresh()
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
require.NoError(t, err)
ns := a.Neighbors()
@ -62,25 +64,20 @@ func TestFSysARPDB(t *testing.T) {
}
func TestCmdARPDB_linux(t *testing.T) {
sh := mapShell{
"arp -a": {err: nil, out: arpAOutputWrt, code: 0},
"ip neigh": {err: nil, out: ipNeighOutput, code: 0},
}
substShell(t, sh.RunCmd)
t.Run("wrt", func(t *testing.T) {
a := &cmdARPDB{
logger: slogutil.NewDiscardLogger(),
parse: parseArpAWrt,
cmd: "arp",
args: []string{"-a"},
logger: slogutil.NewDiscardLogger(),
cmdCons: agh.NewCommandConstructor("arp -a", 0, arpAOutputWrt, nil),
parse: parseArpAWrt,
cmd: "arp",
args: []string{"-a"},
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
},
}
err := a.Refresh()
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
require.NoError(t, err)
assert.Equal(t, wantNeighs, a.Neighbors())
@ -88,16 +85,17 @@ func TestCmdARPDB_linux(t *testing.T) {
t.Run("ip_neigh", func(t *testing.T) {
a := &cmdARPDB{
logger: slogutil.NewDiscardLogger(),
parse: parseIPNeigh,
cmd: "ip",
args: []string{"neigh"},
logger: slogutil.NewDiscardLogger(),
cmdCons: agh.NewCommandConstructor("ip neigh", 0, ipNeighOutput, nil),
parse: parseIPNeigh,
cmd: "ip",
args: []string{"neigh"},
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
},
}
err := a.Refresh()
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
require.NoError(t, err)
assert.Equal(t, wantNeighs, a.Neighbors())

View File

@ -9,12 +9,14 @@ import (
"sync"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil/executil"
)
func newARPDB(logger *slog.Logger) (arp *cmdARPDB) {
func newARPDB(logger *slog.Logger, cmdCons executil.CommandConstructor) (arp *cmdARPDB) {
return &cmdARPDB{
logger: logger,
parse: parseArpA,
logger: logger,
cmdCons: cmdCons,
parse: parseArpA,
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),

View File

@ -9,12 +9,14 @@ import (
"sync"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil/executil"
)
func newARPDB(logger *slog.Logger) (arp *cmdARPDB) {
func newARPDB(logger *slog.Logger, cmdCons executil.CommandConstructor) (arp *cmdARPDB) {
return &cmdARPDB{
logger: logger,
parse: parseArpA,
logger: logger,
cmdCons: cmdCons,
parse: parseArpA,
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),

View File

@ -249,7 +249,7 @@ func (s *Storage) addFromSystemARP(ctx context.Context) {
s.mu.Lock()
defer s.mu.Unlock()
if err := s.arpDB.Refresh(); err != nil {
if err := s.arpDB.Refresh(ctx); err != nil {
s.arpDB = arpdb.Empty{}
s.logger.ErrorContext(ctx, "refreshing arp container", slogutil.KeyError, err)

View File

@ -1,6 +1,7 @@
package client_test
import (
"context"
"net"
"net/netip"
"runtime"
@ -18,6 +19,7 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/hostsfile"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/service"
"github.com/AdguardTeam/golibs/testutil"
"github.com/AdguardTeam/golibs/testutil/faketime"
"github.com/AdguardTeam/golibs/timeutil"
@ -63,17 +65,17 @@ func (c *testHostsContainer) Upd() (updates <-chan *hostsfile.DefaultStorage) {
// Interface stores and refreshes the network neighborhood reported by ARP
// (Address Resolution Protocol).
type Interface interface {
// Refresh updates the stored data. It must be safe for concurrent use.
Refresh() (err error)
// Refresher updates the stored data. It must be safe for concurrent use.
service.Refresher
// Neighbors returnes the last set of data reported by ARP. Both the method
// Neighbors returns the last set of data reported by ARP. Both the method
// and it's result must be safe for concurrent use.
Neighbors() (ns []arpdb.Neighbor)
}
// testARPDB is a mock implementation of the [arpdb.Interface].
type testARPDB struct {
onRefresh func() (err error)
onRefresh func(ctx context.Context) (err error)
onNeighbors func() (ns []arpdb.Neighbor)
}
@ -81,8 +83,8 @@ type testARPDB struct {
var _ arpdb.Interface = (*testARPDB)(nil)
// Refresh implements the [arpdb.Interface] interface for *testARP.
func (c *testARPDB) Refresh() (err error) {
return c.onRefresh()
func (c *testARPDB) Refresh(ctx context.Context) (err error) {
return c.onRefresh(ctx)
}
// Neighbors implements the [arpdb.Interface] interface for *testARP.
@ -218,7 +220,7 @@ func TestStorage_Add_arp(t *testing.T) {
)
a := &testARPDB{
onRefresh: func() (err error) { return nil },
onRefresh: func(_ context.Context) (err error) { return nil },
onNeighbors: func() (ns []arpdb.Neighbor) {
mu.Lock()
defer mu.Unlock()
@ -392,7 +394,7 @@ func TestClientsDHCP(t *testing.T) {
arpCh := make(chan []arpdb.Neighbor, 1)
arpDB := &testARPDB{
onRefresh: func() (err error) { return nil },
onRefresh: func(_ context.Context) (err error) { return nil },
onNeighbors: func() (ns []arpdb.Neighbor) {
select {
case ns = <-arpCh:

View File

@ -9,7 +9,6 @@ import (
"net/http"
"net/netip"
"os"
"os/exec"
"path/filepath"
"runtime"
"time"
@ -20,8 +19,10 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/quic-go/quic-go/http3"
)
@ -129,6 +130,7 @@ func (req *checkConfReq) validateDNS(
ctx context.Context,
l *slog.Logger,
tcpPorts aghalg.UniqChecker[tcpPort],
cmdCons executil.CommandConstructor,
) (canAutofix bool, err error) {
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
@ -160,7 +162,7 @@ func (req *checkConfReq) validateDNS(
// Try to fix automatically.
canAutofix = checkDNSStubListener(ctx, l)
if canAutofix && req.DNS.Autofix {
if derr := disableDNSStubListener(ctx, l); derr != nil {
if derr := disableDNSStubListener(ctx, l, cmdCons); derr != nil {
l.ErrorContext(ctx, "disabling DNSStubListener", slogutil.KeyError, err)
}
@ -188,7 +190,8 @@ func (web *webAPI) handleInstallCheckConfig(w http.ResponseWriter, r *http.Reque
resp.Web.Status = err.Error()
}
if resp.DNS.CanAutofix, err = req.validateDNS(r.Context(), web.logger, tcpPorts); err != nil {
resp.DNS.CanAutofix, err = req.validateDNS(r.Context(), web.logger, tcpPorts, web.cmdCons)
if err != nil {
resp.DNS.Status = err.Error()
} else if !req.DNS.IP.IsUnspecified() {
resp.StaticIP = handleStaticIP(req.DNS.IP, req.SetStaticIP)
@ -243,34 +246,29 @@ func checkDNSStubListener(ctx context.Context, l *slog.Logger) (ok bool) {
return false
}
cmd := exec.Command("systemctl", "is-enabled", "systemd-resolved")
l.DebugContext(ctx, "executing", "cmd", cmd.Path, "args", cmd.Args)
_, err := cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
l.InfoContext(
cmds := container.KeyValues[string, []string]{{
Key: "systemctl",
Value: []string{"is-enabled", "systemd-resolved"},
}, {
Key: "grep",
Value: []string{"-E", "#?DNSStubListener=yes", "/etc/systemd/resolved.conf"},
}}
for _, cmd := range cmds {
l.DebugContext(ctx, "executing", "cmd", cmd.Key, "args", cmd.Value)
err := executil.RunWithPeek(
ctx,
"execution failed",
"cmd", cmd.Path,
"code", cmd.ProcessState.ExitCode(),
slogutil.KeyError, err,
executil.SystemCommandConstructor{},
agh.DefaultOutputLimit,
cmd.Key,
cmd.Value...,
)
if err != nil {
l.InfoContext(ctx, "execution failed", "cmd", cmd.Key, slogutil.KeyError, err)
return false
}
cmd = exec.Command("grep", "-E", "#?DNSStubListener=yes", "/etc/systemd/resolved.conf")
l.DebugContext(ctx, "executing", "cmd", cmd.Path, "args", cmd.Args)
_, err = cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
l.InfoContext(
ctx,
"execution failed",
"cmd", cmd.Path,
"code", cmd.ProcessState.ExitCode(),
slogutil.KeyError, err,
)
return false
return false
}
}
return true
@ -287,7 +285,11 @@ const resolvConfPath = "/etc/resolv.conf"
// disableDNSStubListener deactivates DNSStubListerner and returns an error, if
// any.
func disableDNSStubListener(ctx context.Context, l *slog.Logger) (err error) {
func disableDNSStubListener(
ctx context.Context,
l *slog.Logger,
cmdCons executil.CommandConstructor,
) (err error) {
dir := filepath.Dir(resolvedConfPath)
err = os.MkdirAll(dir, 0o755)
if err != nil {
@ -306,15 +308,21 @@ func disableDNSStubListener(ctx context.Context, l *slog.Logger) (err error) {
return fmt.Errorf("os.Symlink: %s: %w", resolvConfPath, err)
}
cmd := exec.Command("systemctl", "reload-or-restart", "systemd-resolved")
l.DebugContext(ctx, "executing", "cmd", cmd.Path, "args", cmd.Args)
_, err = cmd.Output()
const systemctlCmd = "systemctl"
systemctlArgs := []string{"reload-or-restart", "systemd-resolved"}
l.DebugContext(ctx, "executing", "cmd", systemctlCmd, "args", systemctlArgs)
err = executil.RunWithPeek(
ctx,
cmdCons,
agh.DefaultOutputLimit,
systemctlCmd,
systemctlArgs...,
)
if err != nil {
return err
}
if cmd.ProcessState.ExitCode() != 0 {
return fmt.Errorf("process %s exited with an error: %d",
cmd.Path, cmd.ProcessState.ExitCode())
return fmt.Errorf("executing cmd: %w", err)
}
return nil

View File

@ -7,7 +7,6 @@ import (
"log/slog"
"net/http"
"os"
"os/exec"
"runtime"
"syscall"
"time"
@ -19,6 +18,7 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/osutil/executil"
)
// temporaryError is the interface for temporary errors from the Go standard
@ -148,7 +148,13 @@ func (web *webAPI) handleUpdate(w http.ResponseWriter, r *http.Request) {
// The background context is used because the underlying functions wrap it
// with timeout and shut down the server, which handles current request. It
// also should be done in a separate goroutine for the same reason.
go finishUpdate(context.Background(), web.logger, execPath, web.conf.runningAsService)
go finishUpdate(
context.Background(),
web.logger,
web.cmdCons,
execPath,
web.conf.runningAsService,
)
}
// versionResponse is the response for /control/version.json endpoint.
@ -187,7 +193,13 @@ func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) {
// finishUpdate completes an update procedure. It is intended to be used as a
// goroutine.
func finishUpdate(ctx context.Context, l *slog.Logger, execPath string, runningAsService bool) {
func finishUpdate(
ctx context.Context,
l *slog.Logger,
cmdCons executil.CommandConstructor,
execPath string,
runningAsService bool,
) {
defer slogutil.RecoverAndExit(ctx, l, osutil.ExitCodeFailure)
l.InfoContext(ctx, "stopping all tasks")
@ -203,8 +215,16 @@ func finishUpdate(ctx context.Context, l *slog.Logger, execPath string, runningA
// instance, because Windows doesn't allow it.
//
// TODO(a.garipov): Recheck the claim above.
cmd := exec.Command("cmd", "/c", "net stop AdGuardHome & net start AdGuardHome")
err = cmd.Start()
var cmd executil.Command
cmd, err = cmdCons.New(ctx, &executil.CommandConfig{
Path: "cmd",
Args: []string{"/c", "net stop AdGuardHome & net start AdGuardHome"},
})
if err != nil {
panic(fmt.Errorf("constructing cmd: %w", err))
}
err = cmd.Start(ctx)
if err != nil {
panic(fmt.Errorf("restarting service: %w", err))
}
@ -212,12 +232,21 @@ func finishUpdate(ctx context.Context, l *slog.Logger, execPath string, runningA
os.Exit(osutil.ExitCodeSuccess)
}
cmd := exec.Command(execPath, os.Args[1:]...)
l.InfoContext(ctx, "restarting", "exec_path", execPath, "args", os.Args[1:])
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err = cmd.Start()
var cmd executil.Command
cmd, err = cmdCons.New(ctx, &executil.CommandConfig{
Path: execPath,
Args: os.Args[1:],
Stdin: os.Stdin,
Stdout: os.Stdout,
Stderr: os.Stderr,
})
if err != nil {
panic(fmt.Errorf("constructing cmd: %w", err))
}
err = cmd.Start(ctx)
if err != nil {
panic(fmt.Errorf("restarting: %w", err))
}

View File

@ -43,6 +43,7 @@ import (
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/osutil/executil"
)
// Global context
@ -589,12 +590,13 @@ func initWeb(
disableUpdate := !isUpdateEnabled(ctx, baseLogger, &opts, isCustomUpdURL)
webConf := &webConfig{
updater: upd,
logger: logger,
baseLogger: baseLogger,
confModifier: confModifier,
tlsManager: tlsMgr,
auth: auth,
CommandConstructor: executil.SystemCommandConstructor{},
updater: upd,
logger: logger,
baseLogger: baseLogger,
confModifier: confModifier,
tlsManager: tlsMgr,
auth: auth,
clientFS: clientFS,
@ -821,18 +823,19 @@ func newUpdater(
l.DebugContext(ctx, "creating updater", "config_path", confPath)
return updater.NewUpdater(&updater.Config{
Client: conf.Filtering.HTTPClient,
Logger: l,
Version: version.Version(),
Channel: version.Channel(),
GOARCH: runtime.GOARCH,
GOOS: runtime.GOOS,
GOARM: version.GOARM(),
GOMIPS: version.GOMIPS(),
WorkDir: workDir,
ConfName: confPath,
ExecPath: execPath,
VersionCheckURL: versionURL,
Client: conf.Filtering.HTTPClient,
Logger: l,
CommandConstructor: executil.SystemCommandConstructor{},
Version: version.Version(),
Channel: version.Channel(),
GOARCH: runtime.GOARCH,
GOOS: runtime.GOOS,
GOARM: version.GOARM(),
GOMIPS: version.GOMIPS(),
WorkDir: workDir,
ConfName: confPath,
ExecPath: execPath,
VersionCheckURL: versionURL,
}), isCustomURL
}

View File

@ -18,6 +18,7 @@ import (
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/kardianos/service"
)
@ -76,11 +77,11 @@ func (p *program) Stop(_ service.Service) (err error) {
//
// On OpenWrt, the service utility may not exist. We use our service script
// directly in this case.
func svcStatus(s service.Service) (status service.Status, err error) {
func svcStatus(ctx context.Context, s service.Service) (status service.Status, err error) {
status, err = s.Status()
if err != nil && service.Platform() == "unix-systemv" {
var code int
code, err = runInitdCommand("status")
code, err = runInitdCommand(ctx, "status")
if err != nil || code != 0 {
return service.StatusStopped, nil
}
@ -105,7 +106,7 @@ func svcAction(ctx context.Context, l *slog.Logger, s service.Service, action st
err = service.Control(s, action)
if err != nil && service.Platform() == "unix-systemv" &&
(action == "start" || action == "stop" || action == "restart") {
_, err = runInitdCommand(action)
_, err = runInitdCommand(ctx, action)
}
return err
@ -326,7 +327,7 @@ func handleServiceStatusCommand(
l *slog.Logger,
s service.Service,
) {
status, errSt := svcStatus(s)
status, errSt := svcStatus(ctx, s)
if errSt != nil {
l.ErrorContext(ctx, "failed to get service status", slogutil.KeyError, errSt)
os.Exit(osutil.ExitCodeFailure)
@ -356,7 +357,7 @@ func handleServiceInstallCommand(ctx context.Context, l *slog.Logger, s service.
// On OpenWrt it is important to run enable after the service
// installation. Otherwise, the service won't start on the system
// startup.
_, err = runInitdCommand("enable")
_, err = runInitdCommand(ctx, "enable")
if err != nil {
l.ErrorContext(ctx, "running init enable", slogutil.KeyError, err)
os.Exit(osutil.ExitCodeFailure)
@ -386,7 +387,7 @@ func handleServiceUninstallCommand(ctx context.Context, l *slog.Logger, s servic
if aghos.IsOpenWrt() {
// On OpenWrt it is important to run disable command first
// as it will remove the symlink
_, err := runInitdCommand("disable")
_, err := runInitdCommand(ctx, "disable")
if err != nil {
l.ErrorContext(ctx, "running init disable", slogutil.KeyError, err)
os.Exit(osutil.ExitCodeFailure)
@ -458,10 +459,11 @@ func configureService(c *service.Config) {
// runInitdCommand runs init.d service command
// returns command code or error if any
func runInitdCommand(action string) (int, error) {
func runInitdCommand(ctx context.Context, action string) (int, error) {
confPath := "/etc/init.d/" + serviceName
// Pass the script and action as a single string argument.
code, _, err := aghos.RunCommand("sh", "-c", confPath+" "+action)
cmdCons := executil.SystemCommandConstructor{}
code, _, err := aghos.RunCommand(ctx, cmdCons, "sh", "-c", confPath+" "+action)
return code, err
}

View File

@ -4,12 +4,14 @@ package home
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"os/exec"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/kardianos/service"
)
@ -58,6 +60,7 @@ func (sys *sysvSystem) New(i service.Interface, c *service.Config) (s service.Se
}
return &sysvService{
cmdCons: executil.SystemCommandConstructor{},
Service: s,
name: c.Name,
}, nil
@ -66,6 +69,9 @@ func (sys *sysvSystem) New(i service.Interface, c *service.Config) (s service.Se
// sysvService is a wrapper for a SysV [service.Service] that supplements the
// installation and uninstallation.
type sysvService struct {
// cmdCons is used to run external commands. It must not be nil.
cmdCons executil.CommandConstructor
// Service must have an unexported type *service.sysv.
service.Service
@ -85,7 +91,8 @@ func (svc *sysvService) Install() (err error) {
return err
}
_, _, err = aghos.RunCommand("update-rc.d", svc.name, "defaults")
// TODO(s.chzhen): Pass context.
_, _, err = aghos.RunCommand(context.TODO(), svc.cmdCons, "update-rc.d", svc.name, "defaults")
// Don't wrap an error since it's informative enough as is.
return err
@ -100,7 +107,8 @@ func (svc *sysvService) Uninstall() (err error) {
return err
}
_, _, err = aghos.RunCommand("update-rc.d", svc.name, "remove")
// TODO(s.chzhen): Pass context.
_, _, err = aghos.RunCommand(context.TODO(), svc.cmdCons, "update-rc.d", svc.name, "remove")
// Don't wrap an error since it's informative enough as is.
return err
@ -127,6 +135,7 @@ func (sys *systemdSystem) New(i service.Interface, c *service.Config) (s service
}
return &systemdService{
cmdCons: executil.SystemCommandConstructor{},
Service: s,
unitName: fmt.Sprintf("%s.service", c.Name),
}, nil
@ -138,6 +147,9 @@ var _ service.Service = (*systemdService)(nil)
// systemdService is a wrapper for a systemd [service.Service] that enriches the
// service status information.
type systemdService struct {
// cmdCons is used to run external commands. It must not be nil.
cmdCons executil.CommandConstructor
// Service is expected to have an unexported type *service.systemd.
service.Service
@ -150,26 +162,37 @@ var _ service.Service = (*systemdService)(nil)
// Status implements the [service.Service] interface for *systemdService.
func (s *systemdService) Status() (status service.Status, err error) {
cmd := exec.Command("systemctl", "show", s.unitName)
stdout, err := cmd.StdoutPipe()
if err != nil {
return service.StatusUnknown, fmt.Errorf("connecting to command stdout: %w", err)
}
const systemctlCmd = "systemctl"
if err = cmd.Start(); err != nil {
return service.StatusUnknown, fmt.Errorf("start command executing: %w", err)
}
var (
systemctlArgs = []string{"show", s.unitName}
systemctlStdout bytes.Buffer
)
status, err = parseSystemctlShow(stdout)
if err != nil {
return service.StatusUnknown, fmt.Errorf("parsing command output: %w", err)
}
err = cmd.Wait()
// TODO(s.chzhen): Consider streaming the output if needed. Using
// [io.Pipe] here is unnecessary; it complicates lifecycle management
// because the output must be read concurrently, and the PipeWriter must be
// explicitly closed to signal EOF. Since this command's output is small, a
// bytes.Buffer via executil.Run is sufficient.
err = executil.Run(
// TODO(s.chzhen): Pass context.
context.TODO(),
s.cmdCons,
&executil.CommandConfig{
Path: systemctlCmd,
Args: systemctlArgs,
Stdout: &systemctlStdout,
},
)
if err != nil {
return service.StatusUnknown, fmt.Errorf("executing command: %w", err)
}
status, err = parseSystemctlShow(&systemctlStdout)
if err != nil {
return service.StatusUnknown, fmt.Errorf("parsing command output: %w", err)
}
return status, nil
}

View File

@ -4,6 +4,7 @@ package home
import (
"cmp"
"context"
"fmt"
"os"
"os/signal"
@ -15,6 +16,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/kardianos/service"
)
@ -57,16 +59,18 @@ func (openbsdSystem) Interactive() (ok bool) {
// New implements service.System interface for openbsdSystem.
func (openbsdSystem) New(i service.Interface, c *service.Config) (s service.Service, err error) {
return &openbsdRunComService{
i: i,
cfg: c,
cmdCons: executil.SystemCommandConstructor{},
i: i,
cfg: c,
}, nil
}
// openbsdRunComService is the RunCom-based service.Service to be used on the
// OpenBSD.
type openbsdRunComService struct {
i service.Interface
cfg *service.Config
cmdCons executil.CommandConstructor
i service.Interface
cfg *service.Config
}
// Platform implements service.Service interface for *openbsdRunComService.
@ -210,8 +214,10 @@ func (s *openbsdRunComService) configureSysStartup(enable bool) (err error) {
cmd = "disable"
}
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
var code int
code, _, err = aghos.RunCommand("rcctl", cmd, s.cfg.Name)
code, _, err = aghos.RunCommand(ctx, s.cmdCons, "rcctl", cmd, s.cfg.Name)
if err != nil {
return err
} else if code != 0 {
@ -312,11 +318,14 @@ func (s *openbsdRunComService) runCom(cmd string) (out string, err error) {
return "", err
}
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
// TODO(e.burkov): It's possible that os.ErrNotExist is caused by
// something different than the service script's non-existence. Keep it
// in mind, when replace the aghos.RunCommand.
var outData []byte
_, outData, err = aghos.RunCommand(scriptPath, cmd)
_, outData, err = aghos.RunCommand(ctx, s.cmdCons, scriptPath, cmd)
if errors.Is(err, os.ErrNotExist) {
return "", service.ErrNotInstalled
}

View File

@ -20,6 +20,7 @@ import (
"github.com/AdguardTeam/golibs/netutil/httputil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/NYTimes/gziphandler"
"github.com/quic-go/quic-go/http3"
"golang.org/x/net/http2"
@ -39,6 +40,9 @@ const (
)
type webConfig struct {
// CommandConstructor is used to run external commands. It must not be nil.
CommandConstructor executil.CommandConstructor
updater *updater.Updater
// logger is a slog logger used in webAPI. It must not be nil.
@ -111,6 +115,9 @@ type webAPI struct {
// confModifier is used to update the global configuration.
confModifier agh.ConfigModifier
// cmdCons is used to run external commands.
cmdCons executil.CommandConstructor
// TODO(a.garipov): Refactor all these servers.
httpServer *http.Server
@ -143,6 +150,7 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
w = &webAPI{
conf: conf,
confModifier: conf.confModifier,
cmdCons: conf.CommandConstructor,
logger: conf.logger,
baseLogger: conf.baseLogger,
tlsManager: conf.tlsManager,

View File

@ -10,6 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/updater"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -58,13 +59,14 @@ func TestUpdater_VersionInfo(t *testing.T) {
fakeURL := srvURL.JoinPath("adguardhome", version.ChannelBeta, "version.json")
u := updater.NewUpdater(&updater.Config{
Client: srv.Client(),
Logger: testLogger,
Version: "v0.103.0-beta.1",
Channel: version.ChannelBeta,
GOARCH: "arm",
GOOS: "linux",
VersionCheckURL: fakeURL,
Client: srv.Client(),
Logger: testLogger,
CommandConstructor: executil.EmptyCommandConstructor{},
Version: "v0.103.0-beta.1",
Channel: version.ChannelBeta,
GOARCH: "arm",
GOOS: "linux",
VersionCheckURL: fakeURL,
})
ctx := testutil.ContextWithTimeout(t, testTimeout)
@ -132,15 +134,16 @@ func TestUpdater_VersionInfo_others(t *testing.T) {
for _, tc := range testCases {
u := updater.NewUpdater(&updater.Config{
Client: fakeClient,
Logger: testLogger,
Version: "v0.103.0-beta.1",
Channel: version.ChannelBeta,
GOOS: "linux",
GOARCH: tc.arch,
GOARM: tc.arm,
GOMIPS: tc.mips,
VersionCheckURL: fakeURL,
Client: fakeClient,
Logger: testLogger,
CommandConstructor: executil.EmptyCommandConstructor{},
Version: "v0.103.0-beta.1",
Channel: version.ChannelBeta,
GOOS: "linux",
GOARCH: tc.arch,
GOARM: tc.arm,
GOMIPS: tc.mips,
VersionCheckURL: fakeURL,
})
ctx := testutil.ContextWithTimeout(t, testTimeout)

View File

@ -4,6 +4,7 @@ package updater
import (
"archive/tar"
"archive/zip"
"bytes"
"compress/gzip"
"context"
"fmt"
@ -13,7 +14,6 @@ import (
"net/http"
"net/url"
"os"
"os/exec"
"path"
"path/filepath"
"strings"
@ -26,6 +26,7 @@ import (
"github.com/AdguardTeam/golibs/ioutil"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/AdguardTeam/golibs/osutil/executil"
)
// Updater is the AdGuard Home updater.
@ -33,6 +34,8 @@ type Updater struct {
client *http.Client
logger *slog.Logger
cmdCons executil.CommandConstructor
version string
channel string
goarch string
@ -88,6 +91,9 @@ type Config struct {
// be nil, see [DefaultVersionURL].
VersionCheckURL *url.URL
// CommandConstructor is used to run external commands. It must not be nil.
CommandConstructor executil.CommandConstructor
// Version is the current AdGuard Home version. It must not be empty.
Version string
@ -129,6 +135,8 @@ func NewUpdater(conf *Config) *Updater {
client: conf.Client,
logger: conf.Logger,
cmdCons: conf.CommandConstructor,
version: conf.Version,
channel: conf.Channel,
goarch: conf.GOARCH,
@ -286,11 +294,27 @@ func (u *Updater) check(ctx context.Context) (err error) {
"%s" +
"end of the output"
cmd := exec.Command(u.updateExeName, "--check-config")
out, err := cmd.CombinedOutput()
code := cmd.ProcessState.ExitCode()
if err != nil || code != 0 {
return fmt.Errorf(format, err, code, out)
var (
args = []string{"--check-config"}
buf bytes.Buffer
)
u.logger.DebugContext(ctx, "executing", "cmd", u.updateExeName, "args", args)
err = executil.Run(
ctx,
u.cmdCons,
&executil.CommandConfig{
Path: u.updateExeName,
Args: args,
Stdout: &buf,
Stderr: &buf,
},
)
if err != nil {
code, _ := executil.ExitCodeFromError(err)
return fmt.Errorf(format, err, code, buf.Bytes())
}
return nil

View File

@ -10,6 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -58,13 +59,14 @@ func TestUpdater_internal(t *testing.T) {
fakeURL = fakeURL.JoinPath(tc.archiveName)
u := NewUpdater(&Config{
Client: fakeClient,
Logger: slogutil.NewDiscardLogger(),
GOOS: tc.os,
Version: "v0.103.0",
ExecPath: exePath,
WorkDir: wd,
ConfName: yamlPath,
Client: fakeClient,
Logger: slogutil.NewDiscardLogger(),
CommandConstructor: executil.EmptyCommandConstructor{},
GOOS: tc.os,
Version: "v0.103.0",
ExecPath: exePath,
WorkDir: wd,
ConfName: yamlPath,
// TODO(e.burkov): Rewrite the test to use a fake version check
// URL with a fake URLs for the package files.
VersionCheckURL: &url.URL{},

View File

@ -15,6 +15,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/updater"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -76,15 +77,16 @@ func TestUpdater_Update(t *testing.T) {
require.NoError(t, err)
u := updater.NewUpdater(&updater.Config{
Client: srv.Client(),
Logger: testLogger,
GOARCH: "amd64",
GOOS: "linux",
Version: "v0.103.0",
ConfName: yamlPath,
WorkDir: wd,
ExecPath: exePath,
VersionCheckURL: versionCheckURL,
Client: srv.Client(),
Logger: testLogger,
CommandConstructor: executil.EmptyCommandConstructor{},
GOARCH: "amd64",
GOOS: "linux",
Version: "v0.103.0",
ConfName: yamlPath,
WorkDir: wd,
ExecPath: exePath,
VersionCheckURL: versionCheckURL,
})
ctx := testutil.ContextWithTimeout(t, testTimeout)

View File

@ -13,7 +13,6 @@ import (
"maps"
"net/url"
"os"
"os/exec"
"path/filepath"
"slices"
"strings"
@ -23,6 +22,7 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/osutil/executil"
)
const (
@ -96,7 +96,7 @@ func main() {
errors.Check(cli.upload())
case "auto-add":
err := autoAdd(conf.LocalizableFiles[0])
err := autoAdd(ctx, l, conf.LocalizableFiles[0])
errors.Check(err)
default:
usage("unknown command")
@ -395,10 +395,12 @@ func findUnused(fileNames []string, loc locales) (err error) {
// autoAdd adds locales with additions to the git and restores locales with
// deletions.
func autoAdd(basePath string) (err error) {
func autoAdd(ctx context.Context, l *slog.Logger, basePath string) (err error) {
defer func() { err = errors.Annotate(err, "auto add: %w") }()
adds, dels, err := changedLocales()
cmdCons := executil.SystemCommandConstructor{}
adds, dels, err := changedLocales(ctx, l, cmdCons)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
@ -408,13 +410,13 @@ func autoAdd(basePath string) (err error) {
return errors.Error("base locale contains deletions")
}
err = handleAdds(adds)
err = handleAdds(ctx, l, cmdCons, adds)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil
}
err = handleDels(dels)
err = handleDels(ctx, l, cmdCons, dels)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil
@ -423,14 +425,24 @@ func autoAdd(basePath string) (err error) {
return nil
}
// gitCmd is the shell command for Git.
const gitCmd = "git"
// handleAdds adds locales with additions to the git.
func handleAdds(locales []string) (err error) {
func handleAdds(
ctx context.Context,
l *slog.Logger,
cmdCons executil.CommandConstructor,
locales []string,
) (err error) {
if len(locales) == 0 {
return nil
}
args := append([]string{"add"}, locales...)
code, out, err := aghos.RunCommand("git", args...)
gitArgs := append([]string{"add"}, locales...)
l.DebugContext(ctx, "executing", "cmd", gitCmd, "args", gitArgs)
code, out, err := aghos.RunCommand(ctx, cmdCons, gitCmd, gitArgs...)
if err != nil || code != 0 {
return fmt.Errorf("git add exited with code %d output %q: %w", code, out, err)
@ -440,13 +452,20 @@ func handleAdds(locales []string) (err error) {
}
// handleDels restores locales with deletions.
func handleDels(locales []string) (err error) {
func handleDels(
ctx context.Context,
l *slog.Logger,
cmdCons executil.CommandConstructor,
locales []string,
) (err error) {
if len(locales) == 0 {
return nil
}
args := append([]string{"restore"}, locales...)
code, out, err := aghos.RunCommand("git", args...)
gitArgs := append([]string{"restore"}, locales...)
l.DebugContext(ctx, "executing", "cmd", gitCmd, "args", gitArgs)
code, out, err := aghos.RunCommand(ctx, cmdCons, gitCmd, gitArgs...)
if err != nil || code != 0 {
return fmt.Errorf("git restore exited with code %d output %q: %w", code, out, err)
@ -458,22 +477,32 @@ func handleDels(locales []string) (err error) {
// changedLocales returns cleaned paths of locales with changes or error. adds
// is the list of locales with only additions. dels is the list of locales
// with only deletions.
func changedLocales() (adds, dels []string, err error) {
func changedLocales(
ctx context.Context,
l *slog.Logger,
cmdCons executil.CommandConstructor,
) (adds, dels []string, err error) {
defer func() { err = errors.Annotate(err, "getting changes: %w") }()
cmd := exec.Command("git", "diff", "--numstat", localesDir)
gitArgs := []string{"diff", "--numstat", localesDir}
l.DebugContext(ctx, "executing", "cmd", gitCmd, "args", gitArgs)
stdout, err := cmd.StdoutPipe()
// TODO(s.chzhen): Consider streaming the output if needed. Using
// [io.Pipe] here is unnecessary; it complicates lifecycle management
// because the output must be read concurrently, and the PipeWriter must be
// explicitly closed to signal EOF. Since this command's output is small, a
// bytes.Buffer via executil.Run is sufficient.
var out bytes.Buffer
err = executil.Run(ctx, cmdCons, &executil.CommandConfig{
Path: gitCmd,
Args: gitArgs,
Stdout: &out,
})
if err != nil {
return nil, nil, fmt.Errorf("piping: %w", err)
return nil, nil, fmt.Errorf("executing cmd: %w", err)
}
err = cmd.Start()
if err != nil {
return nil, nil, fmt.Errorf("starting: %w", err)
}
scanner := bufio.NewScanner(stdout)
scanner := bufio.NewScanner(&out)
for scanner.Scan() {
line := scanner.Text()
@ -497,10 +526,5 @@ func changedLocales() (adds, dels []string, err error) {
return nil, nil, fmt.Errorf("scanning: %w", err)
}
err = cmd.Wait()
if err != nil {
return nil, nil, fmt.Errorf("waiting: %w", err)
}
return adds, dels, nil
}