mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2025-10-26 11:27:18 +00:00
Pull request 2453: AGDNS-3086-arpdb-use-executil
Squashed commit of the following: commit 856439ef7c974d67b9afd19b4e4e9bed5b782de0 Merge:65d593d092e5005d7dAuthor: Stanislav Chzhen <s.chzhen@adguard.com> Date: Tue Sep 2 15:44:52 2025 +0300 Merge branch 'master' into AGDNS-3086-arpdb-use-executil commit65d593d091Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Aug 28 19:55:42 2025 +0300 all: imp docs commit3b569a94d7Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Aug 28 15:52:54 2025 +0300 agh: imp code commit803f79d144Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Aug 28 14:22:48 2025 +0300 all: imp code commit20a0702a0cAuthor: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Aug 27 14:57:35 2025 +0300 aghos: add todo commite400e1757dAuthor: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Aug 21 22:37:22 2025 +0300 all: imp docs commit971b5bc1b1Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Thu Aug 21 22:00:47 2025 +0300 all: imp code commit4fc73dca76Author: Stanislav Chzhen <s.chzhen@adguard.com> Date: Wed Aug 20 15:55:19 2025 +0300 all: use executil
This commit is contained in:
parent
2e5005d7df
commit
cd79a4ac72
@ -3,8 +3,18 @@ package agh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
"github.com/AdguardTeam/golibs/osutil/executil"
|
||||
"github.com/AdguardTeam/golibs/testutil/fakeos/fakeexec"
|
||||
)
|
||||
|
||||
// DefaultOutputLimit is the default limit of bytes for commands' standard
|
||||
// output and standard error.
|
||||
const DefaultOutputLimit = 512
|
||||
|
||||
// ConfigModifier defines an interface for updating the global configuration.
|
||||
type ConfigModifier interface {
|
||||
// Apply applies changes to the global configuration.
|
||||
@ -20,3 +30,142 @@ var _ ConfigModifier = EmptyConfigModifier{}
|
||||
|
||||
// Apply implements the [ConfigModifier] for EmptyConfigModifier.
|
||||
func (em EmptyConfigModifier) Apply(ctx context.Context) {}
|
||||
|
||||
// exitErr implements [executil.ExitCodeError] for tests to simulate non-zero
|
||||
// process exit codes.
|
||||
//
|
||||
// TODO(s.chzhen): Consider constructing an [exec.ExitError] instead.
|
||||
type exitErr struct {
|
||||
code osutil.ExitCode
|
||||
}
|
||||
|
||||
// newExitErr returns a properly initialized exitErr with the provided code.
|
||||
func newExitErr(code osutil.ExitCode) (err exitErr) {
|
||||
return exitErr{code: code}
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ executil.ExitCodeError = exitErr{}
|
||||
|
||||
// Error implements the [executil.ExitCodeError] for exitErr.
|
||||
func (e exitErr) Error() (s string) {
|
||||
return fmt.Sprintf("exit code %d", e.code)
|
||||
}
|
||||
|
||||
// ExitCode implements the [executil.ExitCodeError] for exitErr.
|
||||
func (e exitErr) ExitCode() (code osutil.ExitCode) {
|
||||
return e.code
|
||||
}
|
||||
|
||||
// ExternalCommand is a fake command used by [NewMultipleCommandConstructor].
|
||||
type ExternalCommand struct {
|
||||
// Err is the error returned, if non-nil.
|
||||
Err error
|
||||
|
||||
// Cmd contains the command path and arguments.
|
||||
Cmd string
|
||||
|
||||
// Out is written to stdout if non-empty.
|
||||
Out string
|
||||
|
||||
// Code is returned as the exit code if non-zero.
|
||||
Code osutil.ExitCode
|
||||
}
|
||||
|
||||
// keyCommand builds a key for a command lookup.
|
||||
func keyCommand(path string, args []string) (k string) {
|
||||
if len(args) == 0 {
|
||||
return path
|
||||
}
|
||||
|
||||
return path + " " + strings.Join(args, " ")
|
||||
}
|
||||
|
||||
// parseCommand splits a command string into the executable path and args.
|
||||
func parseCommand(s string) (path string, args []string) {
|
||||
f := strings.Fields(s)
|
||||
if len(f) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return f[0], f[1:]
|
||||
}
|
||||
|
||||
// NewMultipleCommandConstructor is a helper function that returns a mock
|
||||
// [executil.CommandConstructor] for tests that supports multiple commands.
|
||||
//
|
||||
// TODO(s.chzhen): Move to aghtest once the import cycle is resolved, since it
|
||||
// will be called from the aghnet package, which imports the whois package,
|
||||
// which in turn imports aghnet.
|
||||
func NewMultipleCommandConstructor(cmds ...ExternalCommand) (cs executil.CommandConstructor) {
|
||||
table := make(map[string]ExternalCommand, len(cmds))
|
||||
for _, ec := range cmds {
|
||||
p, a := parseCommand(ec.Cmd)
|
||||
table[keyCommand(p, a)] = ec
|
||||
}
|
||||
|
||||
onNew := func(_ context.Context, conf *executil.CommandConfig) (c executil.Command, err error) {
|
||||
ec := table[keyCommand(conf.Path, conf.Args)]
|
||||
|
||||
cmd := fakeexec.NewCommand()
|
||||
cmd.OnStart = func(_ context.Context) (err error) {
|
||||
if ec.Out != "" {
|
||||
_, _ = conf.Stdout.Write([]byte(ec.Out))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd.OnWait = func(_ context.Context) (err error) {
|
||||
if ec.Err != nil {
|
||||
return ec.Err
|
||||
}
|
||||
|
||||
if ec.Code != 0 {
|
||||
return newExitErr(ec.Code)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
return &fakeexec.CommandConstructor{OnNew: onNew}
|
||||
}
|
||||
|
||||
// NewCommandConstructor is a helper function that returns a mock
|
||||
// [executil.CommandConstructor] for tests.
|
||||
func NewCommandConstructor(
|
||||
_ string,
|
||||
code osutil.ExitCode,
|
||||
stdout string,
|
||||
cmdErr error,
|
||||
) (cs executil.CommandConstructor) {
|
||||
onNew := func(_ context.Context, conf *executil.CommandConfig) (c executil.Command, err error) {
|
||||
cmd := fakeexec.NewCommand()
|
||||
cmd.OnStart = func(_ context.Context) (err error) {
|
||||
if conf.Stdout != nil {
|
||||
_, _ = conf.Stdout.Write([]byte(stdout))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
cmd.OnWait = func(_ context.Context) (err error) {
|
||||
if cmdErr != nil {
|
||||
return cmdErr
|
||||
}
|
||||
|
||||
if code != 0 {
|
||||
return newExitErr(code)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
return &fakeexec.CommandConstructor{OnNew: onNew}
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
"github.com/AdguardTeam/golibs/osutil/executil"
|
||||
)
|
||||
|
||||
// DialContextFunc is the semantic alias for dialing functions, such as
|
||||
@ -27,7 +28,16 @@ type DialContextFunc = func(ctx context.Context, network, addr string) (conn net
|
||||
// Variables and functions to substitute in tests.
|
||||
var (
|
||||
// aghosRunCommand is the function to run shell commands.
|
||||
aghosRunCommand = aghos.RunCommand
|
||||
//
|
||||
// TODO(s.chzhen): Use [aghos.RunCommand] directly.
|
||||
aghosRunCommand = (func() func(string, ...string) (int, []byte, error) {
|
||||
ctx := context.TODO()
|
||||
cmdCons := executil.SystemCommandConstructor{}
|
||||
|
||||
return func(command string, arguments ...string) (int, []byte, error) {
|
||||
return aghos.RunCommand(ctx, cmdCons, command, arguments...)
|
||||
}
|
||||
})()
|
||||
|
||||
// netInterfaces is the function to get the available network interfaces.
|
||||
netInterfaceAddrs = net.InterfaceAddrs
|
||||
|
||||
@ -5,13 +5,13 @@ package aghos
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"runtime"
|
||||
"slices"
|
||||
@ -19,6 +19,9 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/ioutil"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
"github.com/AdguardTeam/golibs/osutil/executil"
|
||||
)
|
||||
|
||||
// Default file, binary, and directory permissions.
|
||||
@ -50,23 +53,53 @@ func HaveAdminRights() (bool, error) {
|
||||
const MaxCmdOutputSize = 64 * 1024
|
||||
|
||||
// RunCommand runs shell command.
|
||||
func RunCommand(command string, arguments ...string) (code int, output []byte, err error) {
|
||||
cmd := exec.Command(command, arguments...)
|
||||
out, err := cmd.Output()
|
||||
//
|
||||
// TODO(s.chzhen): Consider removing this after addressing the current behavior
|
||||
// where a non-zero exit code is returned together with a nil error.
|
||||
func RunCommand(
|
||||
ctx context.Context,
|
||||
cmdCons executil.CommandConstructor,
|
||||
command string,
|
||||
arguments ...string,
|
||||
) (code int, output []byte, err error) {
|
||||
stdoutBuf := bytes.Buffer{}
|
||||
stderrBuf := bytes.Buffer{}
|
||||
|
||||
out = out[:min(len(out), MaxCmdOutputSize)]
|
||||
err = executil.Run(
|
||||
ctx,
|
||||
cmdCons,
|
||||
&executil.CommandConfig{
|
||||
Path: command,
|
||||
Args: arguments,
|
||||
Stdout: ioutil.NewTruncatedWriter(&stdoutBuf, MaxCmdOutputSize),
|
||||
Stderr: &stderrBuf,
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
if eerr := new(exec.ExitError); errors.As(err, &eerr) {
|
||||
return eerr.ExitCode(), eerr.Stderr, nil
|
||||
}
|
||||
|
||||
return 1, nil, fmt.Errorf("command %q failed: %w: %s", command, err, out)
|
||||
if err == nil {
|
||||
return osutil.ExitCodeSuccess, stdoutBuf.Bytes(), nil
|
||||
}
|
||||
|
||||
return cmd.ProcessState.ExitCode(), out, nil
|
||||
code, ok := executil.ExitCodeFromError(err)
|
||||
if ok {
|
||||
// Mirror the old behavior and return a nil-error on non-zero code
|
||||
// status.
|
||||
return code, stderrBuf.Bytes(), nil
|
||||
}
|
||||
|
||||
code = osutil.ExitCodeFailure
|
||||
|
||||
return code, nil, fmt.Errorf("command %q failed: %w: %s", command, err, &stdoutBuf)
|
||||
}
|
||||
|
||||
// psArgs holds the default ps arguments to avoid per-call slice allocations.
|
||||
//
|
||||
// Don't use -C flag here since it's a feature of linux's ps
|
||||
// implementation. Use POSIX-compatible flags instead.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/3457.
|
||||
var psArgs = []string{"-A", "-o", "pid=", "-o", "comm="}
|
||||
|
||||
// PIDByCommand searches for process named command and returns its PID ignoring
|
||||
// the PIDs from except. If no processes found, the error returned. l must not
|
||||
// be nil.
|
||||
@ -76,31 +109,35 @@ func PIDByCommand(
|
||||
command string,
|
||||
except ...int,
|
||||
) (pid int, err error) {
|
||||
// Don't use -C flag here since it's a feature of linux's ps
|
||||
// implementation. Use POSIX-compatible flags instead.
|
||||
const psCmd = "ps"
|
||||
|
||||
l.DebugContext(ctx, "executing", "cmd", psCmd, "args", psArgs)
|
||||
|
||||
stdoutBuf := bytes.Buffer{}
|
||||
|
||||
// TODO(s.chzhen): Catch stderr.
|
||||
//
|
||||
// See https://github.com/AdguardTeam/AdGuardHome/issues/3457.
|
||||
cmd := exec.Command("ps", "-A", "-o", "pid=", "-o", "comm=")
|
||||
|
||||
var stdout io.ReadCloser
|
||||
if stdout, err = cmd.StdoutPipe(); err != nil {
|
||||
return 0, fmt.Errorf("getting the command's stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err = cmd.Start(); err != nil {
|
||||
return 0, fmt.Errorf("start command executing: %w", err)
|
||||
}
|
||||
// TODO(s.chzhen): Consider streaming the output if needed. Using
|
||||
// [io.Pipe] here is unnecessary; it complicates lifecycle management
|
||||
// because the output must be read concurrently, and the PipeWriter must be
|
||||
// explicitly closed to signal EOF. Since this command's output is small, a
|
||||
// bytes.Buffer via executil.Run is sufficient.
|
||||
runErr := executil.Run(
|
||||
ctx,
|
||||
executil.SystemCommandConstructor{},
|
||||
&executil.CommandConfig{
|
||||
Path: psCmd,
|
||||
Args: psArgs,
|
||||
Stdout: &stdoutBuf,
|
||||
},
|
||||
)
|
||||
|
||||
var instNum int
|
||||
pid, instNum, err = parsePSOutput(stdout, command, except)
|
||||
pid, instNum, err = parsePSOutput(&stdoutBuf, command, except)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err = cmd.Wait(); err != nil {
|
||||
return 0, fmt.Errorf("executing the command: %w", err)
|
||||
}
|
||||
|
||||
switch instNum {
|
||||
case 0:
|
||||
// TODO(e.burkov): Use constant error.
|
||||
@ -111,8 +148,12 @@ func PIDByCommand(
|
||||
l.WarnContext(ctx, "instances found", "num", instNum, "command", command)
|
||||
}
|
||||
|
||||
if code := cmd.ProcessState.ExitCode(); code != 0 {
|
||||
return 0, fmt.Errorf("ps finished with code %d", code)
|
||||
if runErr != nil {
|
||||
if code, ok := executil.ExitCodeFromError(runErr); ok {
|
||||
return 0, fmt.Errorf("ps finished with code %d", code)
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("executing the command: %w", runErr)
|
||||
}
|
||||
|
||||
return pid, nil
|
||||
|
||||
@ -4,6 +4,7 @@ package arpdb
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
@ -11,18 +12,16 @@ import (
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
"github.com/AdguardTeam/golibs/osutil/executil"
|
||||
"github.com/AdguardTeam/golibs/service"
|
||||
)
|
||||
|
||||
// Variables and functions to substitute in tests.
|
||||
var (
|
||||
// aghosRunCommand is the function to run shell commands.
|
||||
aghosRunCommand = aghos.RunCommand
|
||||
|
||||
// rootDirFS is the filesystem pointing to the root directory.
|
||||
rootDirFS = osutil.RootDirFS()
|
||||
)
|
||||
@ -30,17 +29,17 @@ var (
|
||||
// Interface stores and refreshes the network neighborhood reported by ARP
|
||||
// (Address Resolution Protocol).
|
||||
type Interface interface {
|
||||
// Refresh updates the stored data. It must be safe for concurrent use.
|
||||
Refresh() (err error)
|
||||
// Refresher updates the stored data. It must be safe for concurrent use.
|
||||
service.Refresher
|
||||
|
||||
// Neighbors returnes the last set of data reported by ARP. Both the method
|
||||
// Neighbors returns the last set of data reported by ARP. Both the method
|
||||
// and it's result must be safe for concurrent use.
|
||||
Neighbors() (ns []Neighbor)
|
||||
}
|
||||
|
||||
// New returns the [Interface] properly initialized for the OS.
|
||||
func New(logger *slog.Logger) (arp Interface) {
|
||||
return newARPDB(logger)
|
||||
return newARPDB(logger, executil.SystemCommandConstructor{})
|
||||
}
|
||||
|
||||
// Empty is the [Interface] implementation that does nothing.
|
||||
@ -51,7 +50,7 @@ var _ Interface = Empty{}
|
||||
|
||||
// Refresh implements the [Interface] interface for EmptyARPContainer. It does
|
||||
// nothing and always returns nil error.
|
||||
func (Empty) Refresh() (err error) { return nil }
|
||||
func (Empty) Refresh(_ context.Context) (err error) { return nil }
|
||||
|
||||
// Neighbors implements the [Interface] interface for EmptyARPContainer. It
|
||||
// always returns nil.
|
||||
@ -164,28 +163,40 @@ type parseNeighsFunc func(logger *slog.Logger, sc *bufio.Scanner, lenHint int) (
|
||||
// cmdARPDB is the implementation of the [Interface] that uses command line to
|
||||
// retrieve data.
|
||||
type cmdARPDB struct {
|
||||
logger *slog.Logger
|
||||
parse parseNeighsFunc
|
||||
ns *neighs
|
||||
cmd string
|
||||
args []string
|
||||
logger *slog.Logger
|
||||
cmdCons executil.CommandConstructor
|
||||
parse parseNeighsFunc
|
||||
ns *neighs
|
||||
cmd string
|
||||
args []string
|
||||
}
|
||||
|
||||
// type check
|
||||
var _ Interface = (*cmdARPDB)(nil)
|
||||
|
||||
// Refresh implements the [Interface] interface for *cmdARPDB.
|
||||
func (arp *cmdARPDB) Refresh() (err error) {
|
||||
func (arp *cmdARPDB) Refresh(ctx context.Context) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "cmd arpdb: %w") }()
|
||||
|
||||
code, out, err := aghosRunCommand(arp.cmd, arp.args...)
|
||||
var stdout bytes.Buffer
|
||||
err = executil.Run(
|
||||
ctx,
|
||||
arp.cmdCons,
|
||||
&executil.CommandConfig{
|
||||
Path: arp.cmd,
|
||||
Args: arp.args,
|
||||
Stdout: &stdout,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
if code, ok := executil.ExitCodeFromError(err); ok {
|
||||
return fmt.Errorf("running command: unexpected exit code %d", code)
|
||||
}
|
||||
|
||||
return fmt.Errorf("running command: %w", err)
|
||||
} else if code != 0 {
|
||||
return fmt.Errorf("running command: unexpected exit code %d", code)
|
||||
}
|
||||
|
||||
sc := bufio.NewScanner(bytes.NewReader(out))
|
||||
sc := bufio.NewScanner(&stdout)
|
||||
ns := arp.parse(arp.logger, sc, arp.ns.len())
|
||||
if err = sc.Err(); err != nil {
|
||||
// TODO(e.burkov): This error seems unreachable. Investigate.
|
||||
@ -226,11 +237,11 @@ func newARPDBs(arps ...Interface) (arp *arpdbs) {
|
||||
var _ Interface = (*arpdbs)(nil)
|
||||
|
||||
// Refresh implements the [Interface] interface for *arpdbs.
|
||||
func (arp *arpdbs) Refresh() (err error) {
|
||||
func (arp *arpdbs) Refresh(ctx context.Context) (err error) {
|
||||
var errs []error
|
||||
|
||||
for _, a := range arp.arps {
|
||||
err = a.Refresh()
|
||||
err = a.Refresh(ctx)
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
|
||||
|
||||
@ -9,12 +9,14 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/osutil/executil"
|
||||
)
|
||||
|
||||
func newARPDB(logger *slog.Logger) (arp *cmdARPDB) {
|
||||
func newARPDB(logger *slog.Logger, cmdCons executil.CommandConstructor) (arp *cmdARPDB) {
|
||||
return &cmdARPDB{
|
||||
logger: logger,
|
||||
parse: parseArpA,
|
||||
logger: logger,
|
||||
cmdCons: cmdCons,
|
||||
parse: parseArpA,
|
||||
ns: &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
ns: make([]Neighbor, 0),
|
||||
|
||||
@ -1,15 +1,16 @@
|
||||
package arpdb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"context"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
@ -17,49 +18,15 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// testTimeout is a common timeout for tests.
|
||||
const testTimeout = 1 * time.Second
|
||||
|
||||
// testdata is the filesystem containing data for testing the package.
|
||||
var testdata fs.FS = os.DirFS("./testdata")
|
||||
|
||||
// RunCmdFunc is the signature of aghos.RunCommand function.
|
||||
type RunCmdFunc func(cmd string, args ...string) (code int, out []byte, err error)
|
||||
|
||||
// substShell replaces the the aghos.RunCommand function used throughout the
|
||||
// package with rc for tests ran under tb.
|
||||
func substShell(tb testing.TB, rc RunCmdFunc) {
|
||||
tb.Helper()
|
||||
|
||||
prev := aghosRunCommand
|
||||
tb.Cleanup(func() { aghosRunCommand = prev })
|
||||
aghosRunCommand = rc
|
||||
}
|
||||
|
||||
// mapShell is a substitution of aghos.RunCommand that maps the command to it's
|
||||
// execution result. It's only needed to simplify testing.
|
||||
//
|
||||
// TODO(e.burkov): Perhaps put all the shell interactions behind an interface.
|
||||
type mapShell map[string]struct {
|
||||
err error
|
||||
out string
|
||||
code int
|
||||
}
|
||||
|
||||
// theOnlyCmd returns mapShell that only handles a single command and arguments
|
||||
// combination from cmd.
|
||||
func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) {
|
||||
return mapShell{cmd: {code: code, out: out, err: err}}
|
||||
}
|
||||
|
||||
// RunCmd is a RunCmdFunc handled by s.
|
||||
func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err error) {
|
||||
key := strings.Join(append([]string{cmd}, args...), " ")
|
||||
ret, ok := s[key]
|
||||
if !ok {
|
||||
return 0, nil, fmt.Errorf("unexpected shell command %q", key)
|
||||
}
|
||||
|
||||
return ret.code, []byte(ret.out), ret.err
|
||||
}
|
||||
|
||||
func Test_New(t *testing.T) {
|
||||
var a Interface
|
||||
require.NotPanics(t, func() { a = New(slogutil.NewDiscardLogger()) })
|
||||
@ -71,7 +38,7 @@ func Test_New(t *testing.T) {
|
||||
|
||||
// TestARPDB is the mock implementation of [Interface] to use in tests.
|
||||
type TestARPDB struct {
|
||||
OnRefresh func() (err error)
|
||||
OnRefresh func(ctx context.Context) (err error)
|
||||
OnNeighbors func() (ns []Neighbor)
|
||||
}
|
||||
|
||||
@ -79,8 +46,8 @@ type TestARPDB struct {
|
||||
var _ Interface = (*TestARPDB)(nil)
|
||||
|
||||
// Refresh implements the [Interface] interface for *TestARPDB.
|
||||
func (arp *TestARPDB) Refresh() (err error) {
|
||||
return arp.OnRefresh()
|
||||
func (arp *TestARPDB) Refresh(ctx context.Context) (err error) {
|
||||
return arp.OnRefresh(ctx)
|
||||
}
|
||||
|
||||
// Neighbors implements the [Interface] interface for *TestARPDB.
|
||||
@ -98,13 +65,17 @@ func Test_NewARPDBs(t *testing.T) {
|
||||
}
|
||||
|
||||
succDB := &TestARPDB{
|
||||
OnRefresh: func() (err error) { succRefrCount++; return nil },
|
||||
OnRefresh: func(_ context.Context) (err error) { succRefrCount++; return nil },
|
||||
OnNeighbors: func() (ns []Neighbor) {
|
||||
return []Neighbor{{Name: "abc", IP: knownIP, MAC: knownMAC}}
|
||||
},
|
||||
}
|
||||
failDB := &TestARPDB{
|
||||
OnRefresh: func() (err error) { failRefrCount++; return errors.Error("refresh failed") },
|
||||
OnRefresh: func(_ context.Context) (err error) {
|
||||
failRefrCount++
|
||||
|
||||
return errors.Error("refresh failed")
|
||||
},
|
||||
OnNeighbors: func() (ns []Neighbor) { return nil },
|
||||
}
|
||||
|
||||
@ -112,7 +83,7 @@ func Test_NewARPDBs(t *testing.T) {
|
||||
t.Cleanup(clnp)
|
||||
|
||||
a := newARPDBs(succDB, failDB)
|
||||
err := a.Refresh()
|
||||
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 1, succRefrCount)
|
||||
@ -124,7 +95,7 @@ func Test_NewARPDBs(t *testing.T) {
|
||||
t.Cleanup(clnp)
|
||||
|
||||
a := newARPDBs(failDB, succDB)
|
||||
err := a.Refresh()
|
||||
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 1, succRefrCount)
|
||||
@ -138,7 +109,7 @@ func Test_NewARPDBs(t *testing.T) {
|
||||
wantMsg := "each arpdb failed: refresh failed\nrefresh failed"
|
||||
|
||||
a := newARPDBs(failDB, failDB)
|
||||
err := a.Refresh()
|
||||
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
|
||||
require.Error(t, err)
|
||||
|
||||
testutil.AssertErrorMsg(t, wantMsg, err)
|
||||
@ -152,7 +123,7 @@ func Test_NewARPDBs(t *testing.T) {
|
||||
|
||||
shouldFail := false
|
||||
unstableDB := &TestARPDB{
|
||||
OnRefresh: func() (err error) {
|
||||
OnRefresh: func(_ context.Context) (err error) {
|
||||
if shouldFail {
|
||||
err = errors.Error("unstable failed")
|
||||
}
|
||||
@ -171,21 +142,21 @@ func Test_NewARPDBs(t *testing.T) {
|
||||
a := newARPDBs(unstableDB, succDB)
|
||||
|
||||
// Unstable ARPDB should refresh successfully.
|
||||
err := a.Refresh()
|
||||
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Zero(t, succRefrCount)
|
||||
assert.NotEmpty(t, a.Neighbors())
|
||||
|
||||
// Unstable ARPDB should fail and the succDB should be used.
|
||||
err = a.Refresh()
|
||||
err = a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 1, succRefrCount)
|
||||
assert.NotEmpty(t, a.Neighbors())
|
||||
|
||||
// Unstable ARPDB should refresh successfully again.
|
||||
err = a.Refresh()
|
||||
err = a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, 1, succRefrCount)
|
||||
@ -194,7 +165,7 @@ func Test_NewARPDBs(t *testing.T) {
|
||||
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
a := newARPDBs()
|
||||
require.NoError(t, a.Refresh())
|
||||
require.NoError(t, a.Refresh(testutil.ContextWithTimeout(t, testTimeout)))
|
||||
|
||||
assert.Empty(t, a.Neighbors())
|
||||
})
|
||||
@ -212,36 +183,36 @@ func TestCmdARPDB_arpa(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Run("arp_a", func(t *testing.T) {
|
||||
sh := theOnlyCmd("cmd", 0, arpAOutput, nil)
|
||||
substShell(t, sh.RunCmd)
|
||||
a.cmdCons = agh.NewCommandConstructor("cmd", 0, arpAOutput, nil)
|
||||
|
||||
err := a.Refresh()
|
||||
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, wantNeighs, a.Neighbors())
|
||||
})
|
||||
|
||||
t.Run("runcmd_error", func(t *testing.T) {
|
||||
sh := theOnlyCmd("cmd", 0, "", errors.Error("can't run"))
|
||||
substShell(t, sh.RunCmd)
|
||||
a.cmdCons = agh.NewCommandConstructor("cmd", 0, "", errors.Error("can't run"))
|
||||
|
||||
err := a.Refresh()
|
||||
testutil.AssertErrorMsg(t, "cmd arpdb: running command: can't run", err)
|
||||
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
|
||||
testutil.AssertErrorMsg(t, "cmd arpdb: running command: running: can't run", err)
|
||||
})
|
||||
|
||||
t.Run("bad_code", func(t *testing.T) {
|
||||
sh := theOnlyCmd("cmd", 1, "", nil)
|
||||
substShell(t, sh.RunCmd)
|
||||
a.cmdCons = agh.NewCommandConstructor("cmd", 1, "", nil)
|
||||
|
||||
err := a.Refresh()
|
||||
testutil.AssertErrorMsg(t, "cmd arpdb: running command: unexpected exit code 1", err)
|
||||
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
|
||||
testutil.AssertErrorMsg(
|
||||
t,
|
||||
"cmd arpdb: running command: unexpected exit code 1",
|
||||
err,
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("empty", func(t *testing.T) {
|
||||
sh := theOnlyCmd("cmd", 0, "", nil)
|
||||
substShell(t, sh.RunCmd)
|
||||
a.cmdCons = agh.NewCommandConstructor("cmd", 0, "", nil)
|
||||
|
||||
err := a.Refresh()
|
||||
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Empty(t, a.Neighbors())
|
||||
@ -254,7 +225,7 @@ func TestEmptyARPDB(t *testing.T) {
|
||||
t.Run("refresh", func(t *testing.T) {
|
||||
var err error
|
||||
require.NotPanics(t, func() {
|
||||
err = a.Refresh()
|
||||
err = a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
|
||||
})
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
@ -4,6 +4,7 @@ package arpdb
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
@ -14,10 +15,11 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/osutil/executil"
|
||||
"github.com/AdguardTeam/golibs/stringutil"
|
||||
)
|
||||
|
||||
func newARPDB(logger *slog.Logger) (arp *arpdbs) {
|
||||
func newARPDB(logger *slog.Logger, cmdCons executil.CommandConstructor) (arp *arpdbs) {
|
||||
// Use the common storage among the implementations.
|
||||
ns := &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
@ -40,10 +42,11 @@ func newARPDB(logger *slog.Logger) (arp *arpdbs) {
|
||||
},
|
||||
// Then, try "arp -a -n".
|
||||
&cmdARPDB{
|
||||
logger: logger,
|
||||
parse: parseF,
|
||||
ns: ns,
|
||||
cmd: "arp",
|
||||
logger: logger,
|
||||
cmdCons: cmdCons,
|
||||
parse: parseF,
|
||||
ns: ns,
|
||||
cmd: "arp",
|
||||
// Use -n flag to avoid resolving the hostnames of the neighbors.
|
||||
// By default ARP attempts to resolve the hostnames via DNS. See
|
||||
// man 8 arp.
|
||||
@ -53,11 +56,12 @@ func newARPDB(logger *slog.Logger) (arp *arpdbs) {
|
||||
},
|
||||
// Finally, try "ip neigh".
|
||||
&cmdARPDB{
|
||||
logger: logger,
|
||||
parse: parseIPNeigh,
|
||||
ns: ns,
|
||||
cmd: "ip",
|
||||
args: []string{"neigh"},
|
||||
logger: logger,
|
||||
cmdCons: cmdCons,
|
||||
parse: parseIPNeigh,
|
||||
ns: ns,
|
||||
cmd: "ip",
|
||||
args: []string{"neigh"},
|
||||
},
|
||||
)
|
||||
}
|
||||
@ -73,7 +77,7 @@ type fsysARPDB struct {
|
||||
var _ Interface = (*fsysARPDB)(nil)
|
||||
|
||||
// Refresh implements the [Interface] interface for *fsysARPDB.
|
||||
func (arp *fsysARPDB) Refresh() (err error) {
|
||||
func (arp *fsysARPDB) Refresh(_ context.Context) (err error) {
|
||||
var f fs.File
|
||||
f, err = arp.fsys.Open(arp.filename)
|
||||
if err != nil {
|
||||
|
||||
@ -9,7 +9,9 @@ import (
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/agh"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -54,7 +56,7 @@ func TestFSysARPDB(t *testing.T) {
|
||||
filename: "proc_net_arp",
|
||||
}
|
||||
|
||||
err := a.Refresh()
|
||||
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
|
||||
require.NoError(t, err)
|
||||
|
||||
ns := a.Neighbors()
|
||||
@ -62,25 +64,20 @@ func TestFSysARPDB(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCmdARPDB_linux(t *testing.T) {
|
||||
sh := mapShell{
|
||||
"arp -a": {err: nil, out: arpAOutputWrt, code: 0},
|
||||
"ip neigh": {err: nil, out: ipNeighOutput, code: 0},
|
||||
}
|
||||
substShell(t, sh.RunCmd)
|
||||
|
||||
t.Run("wrt", func(t *testing.T) {
|
||||
a := &cmdARPDB{
|
||||
logger: slogutil.NewDiscardLogger(),
|
||||
parse: parseArpAWrt,
|
||||
cmd: "arp",
|
||||
args: []string{"-a"},
|
||||
logger: slogutil.NewDiscardLogger(),
|
||||
cmdCons: agh.NewCommandConstructor("arp -a", 0, arpAOutputWrt, nil),
|
||||
parse: parseArpAWrt,
|
||||
cmd: "arp",
|
||||
args: []string{"-a"},
|
||||
ns: &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
ns: make([]Neighbor, 0),
|
||||
},
|
||||
}
|
||||
|
||||
err := a.Refresh()
|
||||
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, wantNeighs, a.Neighbors())
|
||||
@ -88,16 +85,17 @@ func TestCmdARPDB_linux(t *testing.T) {
|
||||
|
||||
t.Run("ip_neigh", func(t *testing.T) {
|
||||
a := &cmdARPDB{
|
||||
logger: slogutil.NewDiscardLogger(),
|
||||
parse: parseIPNeigh,
|
||||
cmd: "ip",
|
||||
args: []string{"neigh"},
|
||||
logger: slogutil.NewDiscardLogger(),
|
||||
cmdCons: agh.NewCommandConstructor("ip neigh", 0, ipNeighOutput, nil),
|
||||
parse: parseIPNeigh,
|
||||
cmd: "ip",
|
||||
args: []string{"neigh"},
|
||||
ns: &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
ns: make([]Neighbor, 0),
|
||||
},
|
||||
}
|
||||
err := a.Refresh()
|
||||
err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, wantNeighs, a.Neighbors())
|
||||
|
||||
@ -9,12 +9,14 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/osutil/executil"
|
||||
)
|
||||
|
||||
func newARPDB(logger *slog.Logger) (arp *cmdARPDB) {
|
||||
func newARPDB(logger *slog.Logger, cmdCons executil.CommandConstructor) (arp *cmdARPDB) {
|
||||
return &cmdARPDB{
|
||||
logger: logger,
|
||||
parse: parseArpA,
|
||||
logger: logger,
|
||||
cmdCons: cmdCons,
|
||||
parse: parseArpA,
|
||||
ns: &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
ns: make([]Neighbor, 0),
|
||||
|
||||
@ -9,12 +9,14 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/osutil/executil"
|
||||
)
|
||||
|
||||
func newARPDB(logger *slog.Logger) (arp *cmdARPDB) {
|
||||
func newARPDB(logger *slog.Logger, cmdCons executil.CommandConstructor) (arp *cmdARPDB) {
|
||||
return &cmdARPDB{
|
||||
logger: logger,
|
||||
parse: parseArpA,
|
||||
logger: logger,
|
||||
cmdCons: cmdCons,
|
||||
parse: parseArpA,
|
||||
ns: &neighs{
|
||||
mu: &sync.RWMutex{},
|
||||
ns: make([]Neighbor, 0),
|
||||
|
||||
@ -249,7 +249,7 @@ func (s *Storage) addFromSystemARP(ctx context.Context) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if err := s.arpDB.Refresh(); err != nil {
|
||||
if err := s.arpDB.Refresh(ctx); err != nil {
|
||||
s.arpDB = arpdb.Empty{}
|
||||
s.logger.ErrorContext(ctx, "refreshing arp container", slogutil.KeyError, err)
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package client_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
@ -18,6 +19,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/hostsfile"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/service"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/AdguardTeam/golibs/testutil/faketime"
|
||||
"github.com/AdguardTeam/golibs/timeutil"
|
||||
@ -63,17 +65,17 @@ func (c *testHostsContainer) Upd() (updates <-chan *hostsfile.DefaultStorage) {
|
||||
// Interface stores and refreshes the network neighborhood reported by ARP
|
||||
// (Address Resolution Protocol).
|
||||
type Interface interface {
|
||||
// Refresh updates the stored data. It must be safe for concurrent use.
|
||||
Refresh() (err error)
|
||||
// Refresher updates the stored data. It must be safe for concurrent use.
|
||||
service.Refresher
|
||||
|
||||
// Neighbors returnes the last set of data reported by ARP. Both the method
|
||||
// Neighbors returns the last set of data reported by ARP. Both the method
|
||||
// and it's result must be safe for concurrent use.
|
||||
Neighbors() (ns []arpdb.Neighbor)
|
||||
}
|
||||
|
||||
// testARPDB is a mock implementation of the [arpdb.Interface].
|
||||
type testARPDB struct {
|
||||
onRefresh func() (err error)
|
||||
onRefresh func(ctx context.Context) (err error)
|
||||
onNeighbors func() (ns []arpdb.Neighbor)
|
||||
}
|
||||
|
||||
@ -81,8 +83,8 @@ type testARPDB struct {
|
||||
var _ arpdb.Interface = (*testARPDB)(nil)
|
||||
|
||||
// Refresh implements the [arpdb.Interface] interface for *testARP.
|
||||
func (c *testARPDB) Refresh() (err error) {
|
||||
return c.onRefresh()
|
||||
func (c *testARPDB) Refresh(ctx context.Context) (err error) {
|
||||
return c.onRefresh(ctx)
|
||||
}
|
||||
|
||||
// Neighbors implements the [arpdb.Interface] interface for *testARP.
|
||||
@ -218,7 +220,7 @@ func TestStorage_Add_arp(t *testing.T) {
|
||||
)
|
||||
|
||||
a := &testARPDB{
|
||||
onRefresh: func() (err error) { return nil },
|
||||
onRefresh: func(_ context.Context) (err error) { return nil },
|
||||
onNeighbors: func() (ns []arpdb.Neighbor) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
@ -392,7 +394,7 @@ func TestClientsDHCP(t *testing.T) {
|
||||
|
||||
arpCh := make(chan []arpdb.Neighbor, 1)
|
||||
arpDB := &testARPDB{
|
||||
onRefresh: func() (err error) { return nil },
|
||||
onRefresh: func(_ context.Context) (err error) { return nil },
|
||||
onNeighbors: func() (ns []arpdb.Neighbor) {
|
||||
select {
|
||||
case ns = <-arpCh:
|
||||
|
||||
@ -9,7 +9,6 @@ import (
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
@ -20,8 +19,10 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/container"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/osutil/executil"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
)
|
||||
|
||||
@ -129,6 +130,7 @@ func (req *checkConfReq) validateDNS(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
tcpPorts aghalg.UniqChecker[tcpPort],
|
||||
cmdCons executil.CommandConstructor,
|
||||
) (canAutofix bool, err error) {
|
||||
defer func() { err = errors.Annotate(err, "validating ports: %w") }()
|
||||
|
||||
@ -160,7 +162,7 @@ func (req *checkConfReq) validateDNS(
|
||||
// Try to fix automatically.
|
||||
canAutofix = checkDNSStubListener(ctx, l)
|
||||
if canAutofix && req.DNS.Autofix {
|
||||
if derr := disableDNSStubListener(ctx, l); derr != nil {
|
||||
if derr := disableDNSStubListener(ctx, l, cmdCons); derr != nil {
|
||||
l.ErrorContext(ctx, "disabling DNSStubListener", slogutil.KeyError, err)
|
||||
}
|
||||
|
||||
@ -188,7 +190,8 @@ func (web *webAPI) handleInstallCheckConfig(w http.ResponseWriter, r *http.Reque
|
||||
resp.Web.Status = err.Error()
|
||||
}
|
||||
|
||||
if resp.DNS.CanAutofix, err = req.validateDNS(r.Context(), web.logger, tcpPorts); err != nil {
|
||||
resp.DNS.CanAutofix, err = req.validateDNS(r.Context(), web.logger, tcpPorts, web.cmdCons)
|
||||
if err != nil {
|
||||
resp.DNS.Status = err.Error()
|
||||
} else if !req.DNS.IP.IsUnspecified() {
|
||||
resp.StaticIP = handleStaticIP(req.DNS.IP, req.SetStaticIP)
|
||||
@ -243,34 +246,29 @@ func checkDNSStubListener(ctx context.Context, l *slog.Logger) (ok bool) {
|
||||
return false
|
||||
}
|
||||
|
||||
cmd := exec.Command("systemctl", "is-enabled", "systemd-resolved")
|
||||
l.DebugContext(ctx, "executing", "cmd", cmd.Path, "args", cmd.Args)
|
||||
_, err := cmd.Output()
|
||||
if err != nil || cmd.ProcessState.ExitCode() != 0 {
|
||||
l.InfoContext(
|
||||
cmds := container.KeyValues[string, []string]{{
|
||||
Key: "systemctl",
|
||||
Value: []string{"is-enabled", "systemd-resolved"},
|
||||
}, {
|
||||
Key: "grep",
|
||||
Value: []string{"-E", "#?DNSStubListener=yes", "/etc/systemd/resolved.conf"},
|
||||
}}
|
||||
|
||||
for _, cmd := range cmds {
|
||||
l.DebugContext(ctx, "executing", "cmd", cmd.Key, "args", cmd.Value)
|
||||
|
||||
err := executil.RunWithPeek(
|
||||
ctx,
|
||||
"execution failed",
|
||||
"cmd", cmd.Path,
|
||||
"code", cmd.ProcessState.ExitCode(),
|
||||
slogutil.KeyError, err,
|
||||
executil.SystemCommandConstructor{},
|
||||
agh.DefaultOutputLimit,
|
||||
cmd.Key,
|
||||
cmd.Value...,
|
||||
)
|
||||
if err != nil {
|
||||
l.InfoContext(ctx, "execution failed", "cmd", cmd.Key, slogutil.KeyError, err)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
cmd = exec.Command("grep", "-E", "#?DNSStubListener=yes", "/etc/systemd/resolved.conf")
|
||||
l.DebugContext(ctx, "executing", "cmd", cmd.Path, "args", cmd.Args)
|
||||
_, err = cmd.Output()
|
||||
if err != nil || cmd.ProcessState.ExitCode() != 0 {
|
||||
l.InfoContext(
|
||||
ctx,
|
||||
"execution failed",
|
||||
"cmd", cmd.Path,
|
||||
"code", cmd.ProcessState.ExitCode(),
|
||||
slogutil.KeyError, err,
|
||||
)
|
||||
|
||||
return false
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
@ -287,7 +285,11 @@ const resolvConfPath = "/etc/resolv.conf"
|
||||
|
||||
// disableDNSStubListener deactivates DNSStubListerner and returns an error, if
|
||||
// any.
|
||||
func disableDNSStubListener(ctx context.Context, l *slog.Logger) (err error) {
|
||||
func disableDNSStubListener(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
cmdCons executil.CommandConstructor,
|
||||
) (err error) {
|
||||
dir := filepath.Dir(resolvedConfPath)
|
||||
err = os.MkdirAll(dir, 0o755)
|
||||
if err != nil {
|
||||
@ -306,15 +308,21 @@ func disableDNSStubListener(ctx context.Context, l *slog.Logger) (err error) {
|
||||
return fmt.Errorf("os.Symlink: %s: %w", resolvConfPath, err)
|
||||
}
|
||||
|
||||
cmd := exec.Command("systemctl", "reload-or-restart", "systemd-resolved")
|
||||
l.DebugContext(ctx, "executing", "cmd", cmd.Path, "args", cmd.Args)
|
||||
_, err = cmd.Output()
|
||||
const systemctlCmd = "systemctl"
|
||||
|
||||
systemctlArgs := []string{"reload-or-restart", "systemd-resolved"}
|
||||
|
||||
l.DebugContext(ctx, "executing", "cmd", systemctlCmd, "args", systemctlArgs)
|
||||
|
||||
err = executil.RunWithPeek(
|
||||
ctx,
|
||||
cmdCons,
|
||||
agh.DefaultOutputLimit,
|
||||
systemctlCmd,
|
||||
systemctlArgs...,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cmd.ProcessState.ExitCode() != 0 {
|
||||
return fmt.Errorf("process %s exited with an error: %d",
|
||||
cmd.Path, cmd.ProcessState.ExitCode())
|
||||
return fmt.Errorf("executing cmd: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@ -7,7 +7,6 @@ import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"syscall"
|
||||
"time"
|
||||
@ -19,6 +18,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
"github.com/AdguardTeam/golibs/osutil/executil"
|
||||
)
|
||||
|
||||
// temporaryError is the interface for temporary errors from the Go standard
|
||||
@ -148,7 +148,13 @@ func (web *webAPI) handleUpdate(w http.ResponseWriter, r *http.Request) {
|
||||
// The background context is used because the underlying functions wrap it
|
||||
// with timeout and shut down the server, which handles current request. It
|
||||
// also should be done in a separate goroutine for the same reason.
|
||||
go finishUpdate(context.Background(), web.logger, execPath, web.conf.runningAsService)
|
||||
go finishUpdate(
|
||||
context.Background(),
|
||||
web.logger,
|
||||
web.cmdCons,
|
||||
execPath,
|
||||
web.conf.runningAsService,
|
||||
)
|
||||
}
|
||||
|
||||
// versionResponse is the response for /control/version.json endpoint.
|
||||
@ -187,7 +193,13 @@ func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) {
|
||||
|
||||
// finishUpdate completes an update procedure. It is intended to be used as a
|
||||
// goroutine.
|
||||
func finishUpdate(ctx context.Context, l *slog.Logger, execPath string, runningAsService bool) {
|
||||
func finishUpdate(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
cmdCons executil.CommandConstructor,
|
||||
execPath string,
|
||||
runningAsService bool,
|
||||
) {
|
||||
defer slogutil.RecoverAndExit(ctx, l, osutil.ExitCodeFailure)
|
||||
|
||||
l.InfoContext(ctx, "stopping all tasks")
|
||||
@ -203,8 +215,16 @@ func finishUpdate(ctx context.Context, l *slog.Logger, execPath string, runningA
|
||||
// instance, because Windows doesn't allow it.
|
||||
//
|
||||
// TODO(a.garipov): Recheck the claim above.
|
||||
cmd := exec.Command("cmd", "/c", "net stop AdGuardHome & net start AdGuardHome")
|
||||
err = cmd.Start()
|
||||
var cmd executil.Command
|
||||
cmd, err = cmdCons.New(ctx, &executil.CommandConfig{
|
||||
Path: "cmd",
|
||||
Args: []string{"/c", "net stop AdGuardHome & net start AdGuardHome"},
|
||||
})
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("constructing cmd: %w", err))
|
||||
}
|
||||
|
||||
err = cmd.Start(ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("restarting service: %w", err))
|
||||
}
|
||||
@ -212,12 +232,21 @@ func finishUpdate(ctx context.Context, l *slog.Logger, execPath string, runningA
|
||||
os.Exit(osutil.ExitCodeSuccess)
|
||||
}
|
||||
|
||||
cmd := exec.Command(execPath, os.Args[1:]...)
|
||||
l.InfoContext(ctx, "restarting", "exec_path", execPath, "args", os.Args[1:])
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
err = cmd.Start()
|
||||
|
||||
var cmd executil.Command
|
||||
cmd, err = cmdCons.New(ctx, &executil.CommandConfig{
|
||||
Path: execPath,
|
||||
Args: os.Args[1:],
|
||||
Stdin: os.Stdin,
|
||||
Stdout: os.Stdout,
|
||||
Stderr: os.Stderr,
|
||||
})
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("constructing cmd: %w", err))
|
||||
}
|
||||
|
||||
err = cmd.Start(ctx)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("restarting: %w", err))
|
||||
}
|
||||
|
||||
@ -43,6 +43,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/netutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
"github.com/AdguardTeam/golibs/osutil/executil"
|
||||
)
|
||||
|
||||
// Global context
|
||||
@ -589,12 +590,13 @@ func initWeb(
|
||||
disableUpdate := !isUpdateEnabled(ctx, baseLogger, &opts, isCustomUpdURL)
|
||||
|
||||
webConf := &webConfig{
|
||||
updater: upd,
|
||||
logger: logger,
|
||||
baseLogger: baseLogger,
|
||||
confModifier: confModifier,
|
||||
tlsManager: tlsMgr,
|
||||
auth: auth,
|
||||
CommandConstructor: executil.SystemCommandConstructor{},
|
||||
updater: upd,
|
||||
logger: logger,
|
||||
baseLogger: baseLogger,
|
||||
confModifier: confModifier,
|
||||
tlsManager: tlsMgr,
|
||||
auth: auth,
|
||||
|
||||
clientFS: clientFS,
|
||||
|
||||
@ -821,18 +823,19 @@ func newUpdater(
|
||||
l.DebugContext(ctx, "creating updater", "config_path", confPath)
|
||||
|
||||
return updater.NewUpdater(&updater.Config{
|
||||
Client: conf.Filtering.HTTPClient,
|
||||
Logger: l,
|
||||
Version: version.Version(),
|
||||
Channel: version.Channel(),
|
||||
GOARCH: runtime.GOARCH,
|
||||
GOOS: runtime.GOOS,
|
||||
GOARM: version.GOARM(),
|
||||
GOMIPS: version.GOMIPS(),
|
||||
WorkDir: workDir,
|
||||
ConfName: confPath,
|
||||
ExecPath: execPath,
|
||||
VersionCheckURL: versionURL,
|
||||
Client: conf.Filtering.HTTPClient,
|
||||
Logger: l,
|
||||
CommandConstructor: executil.SystemCommandConstructor{},
|
||||
Version: version.Version(),
|
||||
Channel: version.Channel(),
|
||||
GOARCH: runtime.GOARCH,
|
||||
GOOS: runtime.GOOS,
|
||||
GOARM: version.GOARM(),
|
||||
GOMIPS: version.GOMIPS(),
|
||||
WorkDir: workDir,
|
||||
ConfName: confPath,
|
||||
ExecPath: execPath,
|
||||
VersionCheckURL: versionURL,
|
||||
}), isCustomURL
|
||||
}
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
"github.com/AdguardTeam/golibs/osutil/executil"
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
@ -76,11 +77,11 @@ func (p *program) Stop(_ service.Service) (err error) {
|
||||
//
|
||||
// On OpenWrt, the service utility may not exist. We use our service script
|
||||
// directly in this case.
|
||||
func svcStatus(s service.Service) (status service.Status, err error) {
|
||||
func svcStatus(ctx context.Context, s service.Service) (status service.Status, err error) {
|
||||
status, err = s.Status()
|
||||
if err != nil && service.Platform() == "unix-systemv" {
|
||||
var code int
|
||||
code, err = runInitdCommand("status")
|
||||
code, err = runInitdCommand(ctx, "status")
|
||||
if err != nil || code != 0 {
|
||||
return service.StatusStopped, nil
|
||||
}
|
||||
@ -105,7 +106,7 @@ func svcAction(ctx context.Context, l *slog.Logger, s service.Service, action st
|
||||
err = service.Control(s, action)
|
||||
if err != nil && service.Platform() == "unix-systemv" &&
|
||||
(action == "start" || action == "stop" || action == "restart") {
|
||||
_, err = runInitdCommand(action)
|
||||
_, err = runInitdCommand(ctx, action)
|
||||
}
|
||||
|
||||
return err
|
||||
@ -326,7 +327,7 @@ func handleServiceStatusCommand(
|
||||
l *slog.Logger,
|
||||
s service.Service,
|
||||
) {
|
||||
status, errSt := svcStatus(s)
|
||||
status, errSt := svcStatus(ctx, s)
|
||||
if errSt != nil {
|
||||
l.ErrorContext(ctx, "failed to get service status", slogutil.KeyError, errSt)
|
||||
os.Exit(osutil.ExitCodeFailure)
|
||||
@ -356,7 +357,7 @@ func handleServiceInstallCommand(ctx context.Context, l *slog.Logger, s service.
|
||||
// On OpenWrt it is important to run enable after the service
|
||||
// installation. Otherwise, the service won't start on the system
|
||||
// startup.
|
||||
_, err = runInitdCommand("enable")
|
||||
_, err = runInitdCommand(ctx, "enable")
|
||||
if err != nil {
|
||||
l.ErrorContext(ctx, "running init enable", slogutil.KeyError, err)
|
||||
os.Exit(osutil.ExitCodeFailure)
|
||||
@ -386,7 +387,7 @@ func handleServiceUninstallCommand(ctx context.Context, l *slog.Logger, s servic
|
||||
if aghos.IsOpenWrt() {
|
||||
// On OpenWrt it is important to run disable command first
|
||||
// as it will remove the symlink
|
||||
_, err := runInitdCommand("disable")
|
||||
_, err := runInitdCommand(ctx, "disable")
|
||||
if err != nil {
|
||||
l.ErrorContext(ctx, "running init disable", slogutil.KeyError, err)
|
||||
os.Exit(osutil.ExitCodeFailure)
|
||||
@ -458,10 +459,11 @@ func configureService(c *service.Config) {
|
||||
|
||||
// runInitdCommand runs init.d service command
|
||||
// returns command code or error if any
|
||||
func runInitdCommand(action string) (int, error) {
|
||||
func runInitdCommand(ctx context.Context, action string) (int, error) {
|
||||
confPath := "/etc/init.d/" + serviceName
|
||||
// Pass the script and action as a single string argument.
|
||||
code, _, err := aghos.RunCommand("sh", "-c", confPath+" "+action)
|
||||
cmdCons := executil.SystemCommandConstructor{}
|
||||
code, _, err := aghos.RunCommand(ctx, cmdCons, "sh", "-c", confPath+" "+action)
|
||||
|
||||
return code, err
|
||||
}
|
||||
|
||||
@ -4,12 +4,14 @@ package home
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/osutil/executil"
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
@ -58,6 +60,7 @@ func (sys *sysvSystem) New(i service.Interface, c *service.Config) (s service.Se
|
||||
}
|
||||
|
||||
return &sysvService{
|
||||
cmdCons: executil.SystemCommandConstructor{},
|
||||
Service: s,
|
||||
name: c.Name,
|
||||
}, nil
|
||||
@ -66,6 +69,9 @@ func (sys *sysvSystem) New(i service.Interface, c *service.Config) (s service.Se
|
||||
// sysvService is a wrapper for a SysV [service.Service] that supplements the
|
||||
// installation and uninstallation.
|
||||
type sysvService struct {
|
||||
// cmdCons is used to run external commands. It must not be nil.
|
||||
cmdCons executil.CommandConstructor
|
||||
|
||||
// Service must have an unexported type *service.sysv.
|
||||
service.Service
|
||||
|
||||
@ -85,7 +91,8 @@ func (svc *sysvService) Install() (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
_, _, err = aghos.RunCommand("update-rc.d", svc.name, "defaults")
|
||||
// TODO(s.chzhen): Pass context.
|
||||
_, _, err = aghos.RunCommand(context.TODO(), svc.cmdCons, "update-rc.d", svc.name, "defaults")
|
||||
|
||||
// Don't wrap an error since it's informative enough as is.
|
||||
return err
|
||||
@ -100,7 +107,8 @@ func (svc *sysvService) Uninstall() (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
_, _, err = aghos.RunCommand("update-rc.d", svc.name, "remove")
|
||||
// TODO(s.chzhen): Pass context.
|
||||
_, _, err = aghos.RunCommand(context.TODO(), svc.cmdCons, "update-rc.d", svc.name, "remove")
|
||||
|
||||
// Don't wrap an error since it's informative enough as is.
|
||||
return err
|
||||
@ -127,6 +135,7 @@ func (sys *systemdSystem) New(i service.Interface, c *service.Config) (s service
|
||||
}
|
||||
|
||||
return &systemdService{
|
||||
cmdCons: executil.SystemCommandConstructor{},
|
||||
Service: s,
|
||||
unitName: fmt.Sprintf("%s.service", c.Name),
|
||||
}, nil
|
||||
@ -138,6 +147,9 @@ var _ service.Service = (*systemdService)(nil)
|
||||
// systemdService is a wrapper for a systemd [service.Service] that enriches the
|
||||
// service status information.
|
||||
type systemdService struct {
|
||||
// cmdCons is used to run external commands. It must not be nil.
|
||||
cmdCons executil.CommandConstructor
|
||||
|
||||
// Service is expected to have an unexported type *service.systemd.
|
||||
service.Service
|
||||
|
||||
@ -150,26 +162,37 @@ var _ service.Service = (*systemdService)(nil)
|
||||
|
||||
// Status implements the [service.Service] interface for *systemdService.
|
||||
func (s *systemdService) Status() (status service.Status, err error) {
|
||||
cmd := exec.Command("systemctl", "show", s.unitName)
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return service.StatusUnknown, fmt.Errorf("connecting to command stdout: %w", err)
|
||||
}
|
||||
const systemctlCmd = "systemctl"
|
||||
|
||||
if err = cmd.Start(); err != nil {
|
||||
return service.StatusUnknown, fmt.Errorf("start command executing: %w", err)
|
||||
}
|
||||
var (
|
||||
systemctlArgs = []string{"show", s.unitName}
|
||||
systemctlStdout bytes.Buffer
|
||||
)
|
||||
|
||||
status, err = parseSystemctlShow(stdout)
|
||||
if err != nil {
|
||||
return service.StatusUnknown, fmt.Errorf("parsing command output: %w", err)
|
||||
}
|
||||
|
||||
err = cmd.Wait()
|
||||
// TODO(s.chzhen): Consider streaming the output if needed. Using
|
||||
// [io.Pipe] here is unnecessary; it complicates lifecycle management
|
||||
// because the output must be read concurrently, and the PipeWriter must be
|
||||
// explicitly closed to signal EOF. Since this command's output is small, a
|
||||
// bytes.Buffer via executil.Run is sufficient.
|
||||
err = executil.Run(
|
||||
// TODO(s.chzhen): Pass context.
|
||||
context.TODO(),
|
||||
s.cmdCons,
|
||||
&executil.CommandConfig{
|
||||
Path: systemctlCmd,
|
||||
Args: systemctlArgs,
|
||||
Stdout: &systemctlStdout,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return service.StatusUnknown, fmt.Errorf("executing command: %w", err)
|
||||
}
|
||||
|
||||
status, err = parseSystemctlShow(&systemctlStdout)
|
||||
if err != nil {
|
||||
return service.StatusUnknown, fmt.Errorf("parsing command output: %w", err)
|
||||
}
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
|
||||
@ -4,6 +4,7 @@ package home
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
@ -15,6 +16,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/log"
|
||||
"github.com/AdguardTeam/golibs/osutil/executil"
|
||||
"github.com/kardianos/service"
|
||||
)
|
||||
|
||||
@ -57,16 +59,18 @@ func (openbsdSystem) Interactive() (ok bool) {
|
||||
// New implements service.System interface for openbsdSystem.
|
||||
func (openbsdSystem) New(i service.Interface, c *service.Config) (s service.Service, err error) {
|
||||
return &openbsdRunComService{
|
||||
i: i,
|
||||
cfg: c,
|
||||
cmdCons: executil.SystemCommandConstructor{},
|
||||
i: i,
|
||||
cfg: c,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// openbsdRunComService is the RunCom-based service.Service to be used on the
|
||||
// OpenBSD.
|
||||
type openbsdRunComService struct {
|
||||
i service.Interface
|
||||
cfg *service.Config
|
||||
cmdCons executil.CommandConstructor
|
||||
i service.Interface
|
||||
cfg *service.Config
|
||||
}
|
||||
|
||||
// Platform implements service.Service interface for *openbsdRunComService.
|
||||
@ -210,8 +214,10 @@ func (s *openbsdRunComService) configureSysStartup(enable bool) (err error) {
|
||||
cmd = "disable"
|
||||
}
|
||||
|
||||
// TODO(s.chzhen): Pass context.
|
||||
ctx := context.TODO()
|
||||
var code int
|
||||
code, _, err = aghos.RunCommand("rcctl", cmd, s.cfg.Name)
|
||||
code, _, err = aghos.RunCommand(ctx, s.cmdCons, "rcctl", cmd, s.cfg.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if code != 0 {
|
||||
@ -312,11 +318,14 @@ func (s *openbsdRunComService) runCom(cmd string) (out string, err error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// TODO(s.chzhen): Pass context.
|
||||
ctx := context.TODO()
|
||||
|
||||
// TODO(e.burkov): It's possible that os.ErrNotExist is caused by
|
||||
// something different than the service script's non-existence. Keep it
|
||||
// in mind, when replace the aghos.RunCommand.
|
||||
var outData []byte
|
||||
_, outData, err = aghos.RunCommand(scriptPath, cmd)
|
||||
_, outData, err = aghos.RunCommand(ctx, s.cmdCons, scriptPath, cmd)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return "", service.ErrNotInstalled
|
||||
}
|
||||
|
||||
@ -20,6 +20,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/netutil/httputil"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
"github.com/AdguardTeam/golibs/osutil/executil"
|
||||
"github.com/NYTimes/gziphandler"
|
||||
"github.com/quic-go/quic-go/http3"
|
||||
"golang.org/x/net/http2"
|
||||
@ -39,6 +40,9 @@ const (
|
||||
)
|
||||
|
||||
type webConfig struct {
|
||||
// CommandConstructor is used to run external commands. It must not be nil.
|
||||
CommandConstructor executil.CommandConstructor
|
||||
|
||||
updater *updater.Updater
|
||||
|
||||
// logger is a slog logger used in webAPI. It must not be nil.
|
||||
@ -111,6 +115,9 @@ type webAPI struct {
|
||||
// confModifier is used to update the global configuration.
|
||||
confModifier agh.ConfigModifier
|
||||
|
||||
// cmdCons is used to run external commands.
|
||||
cmdCons executil.CommandConstructor
|
||||
|
||||
// TODO(a.garipov): Refactor all these servers.
|
||||
httpServer *http.Server
|
||||
|
||||
@ -143,6 +150,7 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) {
|
||||
w = &webAPI{
|
||||
conf: conf,
|
||||
confModifier: conf.confModifier,
|
||||
cmdCons: conf.CommandConstructor,
|
||||
logger: conf.logger,
|
||||
baseLogger: conf.baseLogger,
|
||||
tlsManager: conf.tlsManager,
|
||||
|
||||
@ -10,6 +10,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/osutil/executil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -58,13 +59,14 @@ func TestUpdater_VersionInfo(t *testing.T) {
|
||||
fakeURL := srvURL.JoinPath("adguardhome", version.ChannelBeta, "version.json")
|
||||
|
||||
u := updater.NewUpdater(&updater.Config{
|
||||
Client: srv.Client(),
|
||||
Logger: testLogger,
|
||||
Version: "v0.103.0-beta.1",
|
||||
Channel: version.ChannelBeta,
|
||||
GOARCH: "arm",
|
||||
GOOS: "linux",
|
||||
VersionCheckURL: fakeURL,
|
||||
Client: srv.Client(),
|
||||
Logger: testLogger,
|
||||
CommandConstructor: executil.EmptyCommandConstructor{},
|
||||
Version: "v0.103.0-beta.1",
|
||||
Channel: version.ChannelBeta,
|
||||
GOARCH: "arm",
|
||||
GOOS: "linux",
|
||||
VersionCheckURL: fakeURL,
|
||||
})
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
@ -132,15 +134,16 @@ func TestUpdater_VersionInfo_others(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
u := updater.NewUpdater(&updater.Config{
|
||||
Client: fakeClient,
|
||||
Logger: testLogger,
|
||||
Version: "v0.103.0-beta.1",
|
||||
Channel: version.ChannelBeta,
|
||||
GOOS: "linux",
|
||||
GOARCH: tc.arch,
|
||||
GOARM: tc.arm,
|
||||
GOMIPS: tc.mips,
|
||||
VersionCheckURL: fakeURL,
|
||||
Client: fakeClient,
|
||||
Logger: testLogger,
|
||||
CommandConstructor: executil.EmptyCommandConstructor{},
|
||||
Version: "v0.103.0-beta.1",
|
||||
Channel: version.ChannelBeta,
|
||||
GOOS: "linux",
|
||||
GOARCH: tc.arch,
|
||||
GOARM: tc.arm,
|
||||
GOMIPS: tc.mips,
|
||||
VersionCheckURL: fakeURL,
|
||||
})
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
@ -4,6 +4,7 @@ package updater
|
||||
import (
|
||||
"archive/tar"
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"fmt"
|
||||
@ -13,7 +14,6 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@ -26,6 +26,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/ioutil"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/netutil/urlutil"
|
||||
"github.com/AdguardTeam/golibs/osutil/executil"
|
||||
)
|
||||
|
||||
// Updater is the AdGuard Home updater.
|
||||
@ -33,6 +34,8 @@ type Updater struct {
|
||||
client *http.Client
|
||||
logger *slog.Logger
|
||||
|
||||
cmdCons executil.CommandConstructor
|
||||
|
||||
version string
|
||||
channel string
|
||||
goarch string
|
||||
@ -88,6 +91,9 @@ type Config struct {
|
||||
// be nil, see [DefaultVersionURL].
|
||||
VersionCheckURL *url.URL
|
||||
|
||||
// CommandConstructor is used to run external commands. It must not be nil.
|
||||
CommandConstructor executil.CommandConstructor
|
||||
|
||||
// Version is the current AdGuard Home version. It must not be empty.
|
||||
Version string
|
||||
|
||||
@ -129,6 +135,8 @@ func NewUpdater(conf *Config) *Updater {
|
||||
client: conf.Client,
|
||||
logger: conf.Logger,
|
||||
|
||||
cmdCons: conf.CommandConstructor,
|
||||
|
||||
version: conf.Version,
|
||||
channel: conf.Channel,
|
||||
goarch: conf.GOARCH,
|
||||
@ -286,11 +294,27 @@ func (u *Updater) check(ctx context.Context) (err error) {
|
||||
"%s" +
|
||||
"end of the output"
|
||||
|
||||
cmd := exec.Command(u.updateExeName, "--check-config")
|
||||
out, err := cmd.CombinedOutput()
|
||||
code := cmd.ProcessState.ExitCode()
|
||||
if err != nil || code != 0 {
|
||||
return fmt.Errorf(format, err, code, out)
|
||||
var (
|
||||
args = []string{"--check-config"}
|
||||
buf bytes.Buffer
|
||||
)
|
||||
|
||||
u.logger.DebugContext(ctx, "executing", "cmd", u.updateExeName, "args", args)
|
||||
|
||||
err = executil.Run(
|
||||
ctx,
|
||||
u.cmdCons,
|
||||
&executil.CommandConfig{
|
||||
Path: u.updateExeName,
|
||||
Args: args,
|
||||
Stdout: &buf,
|
||||
Stderr: &buf,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
code, _ := executil.ExitCodeFromError(err)
|
||||
|
||||
return fmt.Errorf(format, err, code, buf.Bytes())
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/osutil/executil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -58,13 +59,14 @@ func TestUpdater_internal(t *testing.T) {
|
||||
fakeURL = fakeURL.JoinPath(tc.archiveName)
|
||||
|
||||
u := NewUpdater(&Config{
|
||||
Client: fakeClient,
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
GOOS: tc.os,
|
||||
Version: "v0.103.0",
|
||||
ExecPath: exePath,
|
||||
WorkDir: wd,
|
||||
ConfName: yamlPath,
|
||||
Client: fakeClient,
|
||||
Logger: slogutil.NewDiscardLogger(),
|
||||
CommandConstructor: executil.EmptyCommandConstructor{},
|
||||
GOOS: tc.os,
|
||||
Version: "v0.103.0",
|
||||
ExecPath: exePath,
|
||||
WorkDir: wd,
|
||||
ConfName: yamlPath,
|
||||
// TODO(e.burkov): Rewrite the test to use a fake version check
|
||||
// URL with a fake URLs for the package files.
|
||||
VersionCheckURL: &url.URL{},
|
||||
|
||||
@ -15,6 +15,7 @@ import (
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/updater"
|
||||
"github.com/AdguardTeam/AdGuardHome/internal/version"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/osutil/executil"
|
||||
"github.com/AdguardTeam/golibs/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -76,15 +77,16 @@ func TestUpdater_Update(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
u := updater.NewUpdater(&updater.Config{
|
||||
Client: srv.Client(),
|
||||
Logger: testLogger,
|
||||
GOARCH: "amd64",
|
||||
GOOS: "linux",
|
||||
Version: "v0.103.0",
|
||||
ConfName: yamlPath,
|
||||
WorkDir: wd,
|
||||
ExecPath: exePath,
|
||||
VersionCheckURL: versionCheckURL,
|
||||
Client: srv.Client(),
|
||||
Logger: testLogger,
|
||||
CommandConstructor: executil.EmptyCommandConstructor{},
|
||||
GOARCH: "amd64",
|
||||
GOOS: "linux",
|
||||
Version: "v0.103.0",
|
||||
ConfName: yamlPath,
|
||||
WorkDir: wd,
|
||||
ExecPath: exePath,
|
||||
VersionCheckURL: versionCheckURL,
|
||||
})
|
||||
|
||||
ctx := testutil.ContextWithTimeout(t, testTimeout)
|
||||
|
||||
@ -13,7 +13,6 @@ import (
|
||||
"maps"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
@ -23,6 +22,7 @@ import (
|
||||
"github.com/AdguardTeam/golibs/errors"
|
||||
"github.com/AdguardTeam/golibs/logutil/slogutil"
|
||||
"github.com/AdguardTeam/golibs/osutil"
|
||||
"github.com/AdguardTeam/golibs/osutil/executil"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -96,7 +96,7 @@ func main() {
|
||||
|
||||
errors.Check(cli.upload())
|
||||
case "auto-add":
|
||||
err := autoAdd(conf.LocalizableFiles[0])
|
||||
err := autoAdd(ctx, l, conf.LocalizableFiles[0])
|
||||
errors.Check(err)
|
||||
default:
|
||||
usage("unknown command")
|
||||
@ -395,10 +395,12 @@ func findUnused(fileNames []string, loc locales) (err error) {
|
||||
|
||||
// autoAdd adds locales with additions to the git and restores locales with
|
||||
// deletions.
|
||||
func autoAdd(basePath string) (err error) {
|
||||
func autoAdd(ctx context.Context, l *slog.Logger, basePath string) (err error) {
|
||||
defer func() { err = errors.Annotate(err, "auto add: %w") }()
|
||||
|
||||
adds, dels, err := changedLocales()
|
||||
cmdCons := executil.SystemCommandConstructor{}
|
||||
|
||||
adds, dels, err := changedLocales(ctx, l, cmdCons)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return err
|
||||
@ -408,13 +410,13 @@ func autoAdd(basePath string) (err error) {
|
||||
return errors.Error("base locale contains deletions")
|
||||
}
|
||||
|
||||
err = handleAdds(adds)
|
||||
err = handleAdds(ctx, l, cmdCons, adds)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil
|
||||
}
|
||||
|
||||
err = handleDels(dels)
|
||||
err = handleDels(ctx, l, cmdCons, dels)
|
||||
if err != nil {
|
||||
// Don't wrap the error since it's informative enough as is.
|
||||
return nil
|
||||
@ -423,14 +425,24 @@ func autoAdd(basePath string) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// gitCmd is the shell command for Git.
|
||||
const gitCmd = "git"
|
||||
|
||||
// handleAdds adds locales with additions to the git.
|
||||
func handleAdds(locales []string) (err error) {
|
||||
func handleAdds(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
cmdCons executil.CommandConstructor,
|
||||
locales []string,
|
||||
) (err error) {
|
||||
if len(locales) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
args := append([]string{"add"}, locales...)
|
||||
code, out, err := aghos.RunCommand("git", args...)
|
||||
gitArgs := append([]string{"add"}, locales...)
|
||||
l.DebugContext(ctx, "executing", "cmd", gitCmd, "args", gitArgs)
|
||||
|
||||
code, out, err := aghos.RunCommand(ctx, cmdCons, gitCmd, gitArgs...)
|
||||
|
||||
if err != nil || code != 0 {
|
||||
return fmt.Errorf("git add exited with code %d output %q: %w", code, out, err)
|
||||
@ -440,13 +452,20 @@ func handleAdds(locales []string) (err error) {
|
||||
}
|
||||
|
||||
// handleDels restores locales with deletions.
|
||||
func handleDels(locales []string) (err error) {
|
||||
func handleDels(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
cmdCons executil.CommandConstructor,
|
||||
locales []string,
|
||||
) (err error) {
|
||||
if len(locales) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
args := append([]string{"restore"}, locales...)
|
||||
code, out, err := aghos.RunCommand("git", args...)
|
||||
gitArgs := append([]string{"restore"}, locales...)
|
||||
l.DebugContext(ctx, "executing", "cmd", gitCmd, "args", gitArgs)
|
||||
|
||||
code, out, err := aghos.RunCommand(ctx, cmdCons, gitCmd, gitArgs...)
|
||||
|
||||
if err != nil || code != 0 {
|
||||
return fmt.Errorf("git restore exited with code %d output %q: %w", code, out, err)
|
||||
@ -458,22 +477,32 @@ func handleDels(locales []string) (err error) {
|
||||
// changedLocales returns cleaned paths of locales with changes or error. adds
|
||||
// is the list of locales with only additions. dels is the list of locales
|
||||
// with only deletions.
|
||||
func changedLocales() (adds, dels []string, err error) {
|
||||
func changedLocales(
|
||||
ctx context.Context,
|
||||
l *slog.Logger,
|
||||
cmdCons executil.CommandConstructor,
|
||||
) (adds, dels []string, err error) {
|
||||
defer func() { err = errors.Annotate(err, "getting changes: %w") }()
|
||||
|
||||
cmd := exec.Command("git", "diff", "--numstat", localesDir)
|
||||
gitArgs := []string{"diff", "--numstat", localesDir}
|
||||
l.DebugContext(ctx, "executing", "cmd", gitCmd, "args", gitArgs)
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
// TODO(s.chzhen): Consider streaming the output if needed. Using
|
||||
// [io.Pipe] here is unnecessary; it complicates lifecycle management
|
||||
// because the output must be read concurrently, and the PipeWriter must be
|
||||
// explicitly closed to signal EOF. Since this command's output is small, a
|
||||
// bytes.Buffer via executil.Run is sufficient.
|
||||
var out bytes.Buffer
|
||||
err = executil.Run(ctx, cmdCons, &executil.CommandConfig{
|
||||
Path: gitCmd,
|
||||
Args: gitArgs,
|
||||
Stdout: &out,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("piping: %w", err)
|
||||
return nil, nil, fmt.Errorf("executing cmd: %w", err)
|
||||
}
|
||||
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("starting: %w", err)
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(stdout)
|
||||
scanner := bufio.NewScanner(&out)
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
@ -497,10 +526,5 @@ func changedLocales() (adds, dels []string, err error) {
|
||||
return nil, nil, fmt.Errorf("scanning: %w", err)
|
||||
}
|
||||
|
||||
err = cmd.Wait()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("waiting: %w", err)
|
||||
}
|
||||
|
||||
return adds, dels, nil
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user