all: use executil

This commit is contained in:
Stanislav Chzhen 2025-08-20 15:55:19 +03:00
parent 5e2b4405fd
commit 4fc73dca76
21 changed files with 526 additions and 272 deletions

View File

@ -3,6 +3,12 @@ package agh
import (
"context"
"fmt"
"strings"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/AdguardTeam/golibs/testutil/fakeos/fakeexec"
)
// ConfigModifier defines an interface for updating the global configuration.
@ -20,3 +26,124 @@ var _ ConfigModifier = EmptyConfigModifier{}
// Apply implements the [ConfigModifier] for EmptyConfigModifier.
func (em EmptyConfigModifier) Apply(ctx context.Context) {}
// TODO(s.chzhen): !! Is there another way?
//
// TODO(s.chzhen): !! Docs, naming.
//
// TODO(s.chzhen): Move to aghtest once the import cycle is resolved.
type exitErr struct {
code osutil.ExitCode
}
// type check
var _ executil.ExitCodeError = exitErr{}
func (e exitErr) Error() (s string) {
return fmt.Sprintf("exit code %d", e.code)
}
func (e exitErr) ExitCode() (code osutil.ExitCode) {
return e.code
}
type ExternalCommand struct {
Err error
Cmd string
Out string
Code int
}
func keyCommand(path string, args []string) (k string) {
return path + " " + strings.Join(args, " ")
}
func parseCommand(s string) (path string, args []string) {
f := strings.Fields(s)
if len(f) == 0 {
return "", nil
}
return f[0], f[1:]
}
// NewMultipleCommandConstructor is a helper function that returns a mock
// [executil.CommandConstructor] for tests.
func NewMultipleCommandConstructor(cmds ...ExternalCommand) (cs executil.CommandConstructor) {
table := make(map[string]ExternalCommand, len(cmds))
for _, ec := range cmds {
p, a := parseCommand(ec.Cmd)
table[keyCommand(p, a)] = ec
}
return &fakeexec.CommandConstructor{
OnNew: func(
_ context.Context,
conf *executil.CommandConfig,
) (c executil.Command, err error) {
ec := table[keyCommand(conf.Path, conf.Args)]
cmd := fakeexec.NewCommand()
cmd.OnStart = func(_ context.Context) (err error) {
if ec.Out != "" {
_, _ = conf.Stdout.Write([]byte(ec.Out))
}
return nil
}
cmd.OnWait = func(_ context.Context) (err error) {
if ec.Err != nil {
return ec.Err
}
if ec.Code != 0 {
return exitErr{code: ec.Code}
}
return nil
}
return cmd, nil
},
}
}
// NewCommandConstructor is a helper function that returns a mock
// [executil.CommandConstructor] for tests.
func NewCommandConstructor(
_ string,
code int,
stdout string,
cmdErr error,
) (cs executil.CommandConstructor) {
return &fakeexec.CommandConstructor{
OnNew: func(
_ context.Context,
conf *executil.CommandConfig,
) (c executil.Command, err error) {
cmd := fakeexec.NewCommand()
cmd.OnStart = func(_ context.Context) (err error) {
if conf.Stdout != nil {
_, _ = conf.Stdout.Write([]byte(stdout))
}
return nil
}
cmd.OnWait = func(_ context.Context) (err error) {
if cmdErr != nil {
return cmdErr
}
if code != 0 {
return exitErr{code: code}
}
return nil
}
return cmd, nil
},
}
}

View File

@ -18,6 +18,7 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/osutil/executil"
)
// DialContextFunc is the semantic alias for dialing functions, such as
@ -27,7 +28,16 @@ type DialContextFunc = func(ctx context.Context, network, addr string) (conn net
// Variables and functions to substitute in tests.
var (
// aghosRunCommand is the function to run shell commands.
aghosRunCommand = aghos.RunCommand
//
// TODO(s.chzhen): Use [aghos.RunCommand] directly.
aghosRunCommand = (func() func(string, ...string) (int, []byte, error) {
ctx := context.TODO()
cmdCons := executil.SystemCommandConstructor{}
return func(command string, arguments ...string) (int, []byte, error) {
return aghos.RunCommand(ctx, cmdCons, command, arguments...)
}
})()
// netInterfaces is the function to get the available network interfaces.
netInterfaceAddrs = net.InterfaceAddrs

View File

@ -5,13 +5,13 @@ package aghos
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"io/fs"
"log/slog"
"os"
"os/exec"
"path"
"runtime"
"slices"
@ -19,6 +19,9 @@ import (
"strings"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/ioutil"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/osutil/executil"
)
// Default file, binary, and directory permissions.
@ -50,23 +53,48 @@ func HaveAdminRights() (bool, error) {
const MaxCmdOutputSize = 64 * 1024
// RunCommand runs shell command.
func RunCommand(command string, arguments ...string) (code int, output []byte, err error) {
cmd := exec.Command(command, arguments...)
out, err := cmd.Output()
//
// TODO(s.chzhen): Consider removing this after addressing the current behavior
// where a non-zero exit code is returned together with a nil error.
func RunCommand(
ctx context.Context,
cmdCons executil.CommandConstructor,
command string,
arguments ...string,
) (code int, output []byte, err error) {
stdoutBuf := bytes.Buffer{}
stderrBuf := bytes.Buffer{}
out = out[:min(len(out), MaxCmdOutputSize)]
err = executil.Run(
ctx,
cmdCons,
&executil.CommandConfig{
Path: command,
Args: arguments,
Stdout: ioutil.NewTruncatedWriter(&stdoutBuf, MaxCmdOutputSize),
Stderr: &stderrBuf,
},
)
if err != nil {
if eerr := new(exec.ExitError); errors.As(err, &eerr) {
return eerr.ExitCode(), eerr.Stderr, nil
}
return 1, nil, fmt.Errorf("command %q failed: %w: %s", command, err, out)
if err == nil {
return osutil.ExitCodeSuccess, stdoutBuf.Bytes(), nil
}
return cmd.ProcessState.ExitCode(), out, nil
code, ok := executil.ExitCodeFromError(err)
if ok {
// Mirror the old behavior and return a nil-error on non-zero code
// status.
return code, stderrBuf.Bytes(), nil
}
return osutil.ExitCodeFailure,
nil,
fmt.Errorf("command %q failed: %w: %s", command, err, stdoutBuf.Bytes())
}
// psArgs holds the default ps arguments to avoid per-call slice allocations.
var psArgs = []string{"-A", "-o", "pid=", "-o", "comm="}
// PIDByCommand searches for process named command and returns its PID ignoring
// the PIDs from except. If no processes found, the error returned. l must not
// be nil.
@ -76,31 +104,34 @@ func PIDByCommand(
command string,
except ...int,
) (pid int, err error) {
const psCmd = "ps"
l.DebugContext(ctx, "executing", "cmd", psCmd, "args", psArgs)
// Don't use -C flag here since it's a feature of linux's ps
// implementation. Use POSIX-compatible flags instead.
//
// See https://github.com/AdguardTeam/AdGuardHome/issues/3457.
cmd := exec.Command("ps", "-A", "-o", "pid=", "-o", "comm=")
var stdout io.ReadCloser
if stdout, err = cmd.StdoutPipe(); err != nil {
return 0, fmt.Errorf("getting the command's stdout pipe: %w", err)
}
if err = cmd.Start(); err != nil {
return 0, fmt.Errorf("start command executing: %w", err)
stdoutBuf := bytes.Buffer{}
err = executil.Run(
ctx,
executil.SystemCommandConstructor{},
&executil.CommandConfig{
Path: psCmd,
Args: psArgs,
Stdout: &stdoutBuf,
},
)
if err != nil {
return 0, fmt.Errorf("executing the command: %w", err)
}
var instNum int
pid, instNum, err = parsePSOutput(stdout, command, except)
pid, instNum, err = parsePSOutput(&stdoutBuf, command, except)
if err != nil {
return 0, err
}
if err = cmd.Wait(); err != nil {
return 0, fmt.Errorf("executing the command: %w", err)
}
switch instNum {
case 0:
// TODO(e.burkov): Use constant error.
@ -111,10 +142,6 @@ func PIDByCommand(
l.WarnContext(ctx, "instances found", "num", instNum, "command", command)
}
if code := cmd.ProcessState.ExitCode(); code != 0 {
return 0, fmt.Errorf("ps finished with code %d", code)
}
return pid, nil
}

View File

@ -4,6 +4,7 @@ package arpdb
import (
"bufio"
"bytes"
"context"
"fmt"
"log/slog"
"net"
@ -11,18 +12,15 @@ import (
"slices"
"sync"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/osutil/executil"
)
// Variables and functions to substitute in tests.
var (
// aghosRunCommand is the function to run shell commands.
aghosRunCommand = aghos.RunCommand
// rootDirFS is the filesystem pointing to the root directory.
rootDirFS = osutil.RootDirFS()
)
@ -40,7 +38,7 @@ type Interface interface {
// New returns the [Interface] properly initialized for the OS.
func New(logger *slog.Logger) (arp Interface) {
return newARPDB(logger)
return newARPDB(logger, executil.SystemCommandConstructor{})
}
// Empty is the [Interface] implementation that does nothing.
@ -164,11 +162,12 @@ type parseNeighsFunc func(logger *slog.Logger, sc *bufio.Scanner, lenHint int) (
// cmdARPDB is the implementation of the [Interface] that uses command line to
// retrieve data.
type cmdARPDB struct {
logger *slog.Logger
parse parseNeighsFunc
ns *neighs
cmd string
args []string
logger *slog.Logger
cmdCons executil.CommandConstructor
parse parseNeighsFunc
ns *neighs
cmd string
args []string
}
// type check
@ -178,14 +177,26 @@ var _ Interface = (*cmdARPDB)(nil)
func (arp *cmdARPDB) Refresh() (err error) {
defer func() { err = errors.Annotate(err, "cmd arpdb: %w") }()
code, out, err := aghosRunCommand(arp.cmd, arp.args...)
var stdout bytes.Buffer
err = executil.Run(
// TODO(s.chzhen): Pass context.
context.TODO(),
arp.cmdCons,
&executil.CommandConfig{
Path: arp.cmd,
Args: arp.args,
Stdout: &stdout,
},
)
if err != nil {
if code, ok := executil.ExitCodeFromError(err); ok {
return fmt.Errorf("running command: unexpected exit code %d", code)
}
return fmt.Errorf("running command: %w", err)
} else if code != 0 {
return fmt.Errorf("running command: unexpected exit code %d", code)
}
sc := bufio.NewScanner(bytes.NewReader(out))
sc := bufio.NewScanner(&stdout)
ns := arp.parse(arp.logger, sc, arp.ns.len())
if err = sc.Err(); err != nil {
// TODO(e.burkov): This error seems unreachable. Investigate.

View File

@ -9,12 +9,14 @@ import (
"sync"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil/executil"
)
func newARPDB(logger *slog.Logger) (arp *cmdARPDB) {
func newARPDB(logger *slog.Logger, cmdCons executil.CommandConstructor) (arp *cmdARPDB) {
return &cmdARPDB{
logger: logger,
parse: parseArpA,
logger: logger,
cmdCons: cmdCons,
parse: parseArpA,
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),

View File

@ -1,15 +1,14 @@
package arpdb
import (
"fmt"
"io/fs"
"net"
"net/netip"
"os"
"strings"
"sync"
"testing"
"github.com/AdguardTeam/AdGuardHome/internal/agh"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/testutil"
@ -23,43 +22,6 @@ var testdata fs.FS = os.DirFS("./testdata")
// RunCmdFunc is the signature of aghos.RunCommand function.
type RunCmdFunc func(cmd string, args ...string) (code int, out []byte, err error)
// substShell replaces the the aghos.RunCommand function used throughout the
// package with rc for tests ran under tb.
func substShell(tb testing.TB, rc RunCmdFunc) {
tb.Helper()
prev := aghosRunCommand
tb.Cleanup(func() { aghosRunCommand = prev })
aghosRunCommand = rc
}
// mapShell is a substitution of aghos.RunCommand that maps the command to it's
// execution result. It's only needed to simplify testing.
//
// TODO(e.burkov): Perhaps put all the shell interactions behind an interface.
type mapShell map[string]struct {
err error
out string
code int
}
// theOnlyCmd returns mapShell that only handles a single command and arguments
// combination from cmd.
func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) {
return mapShell{cmd: {code: code, out: out, err: err}}
}
// RunCmd is a RunCmdFunc handled by s.
func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err error) {
key := strings.Join(append([]string{cmd}, args...), " ")
ret, ok := s[key]
if !ok {
return 0, nil, fmt.Errorf("unexpected shell command %q", key)
}
return ret.code, []byte(ret.out), ret.err
}
func Test_New(t *testing.T) {
var a Interface
require.NotPanics(t, func() { a = New(slogutil.NewDiscardLogger()) })
@ -212,8 +174,7 @@ func TestCmdARPDB_arpa(t *testing.T) {
}
t.Run("arp_a", func(t *testing.T) {
sh := theOnlyCmd("cmd", 0, arpAOutput, nil)
substShell(t, sh.RunCmd)
a.cmdCons = agh.NewCommandConstructor("cmd", 0, arpAOutput, nil)
err := a.Refresh()
require.NoError(t, err)
@ -222,24 +183,25 @@ func TestCmdARPDB_arpa(t *testing.T) {
})
t.Run("runcmd_error", func(t *testing.T) {
sh := theOnlyCmd("cmd", 0, "", errors.Error("can't run"))
substShell(t, sh.RunCmd)
a.cmdCons = agh.NewCommandConstructor("cmd", 0, "", errors.Error("can't run"))
err := a.Refresh()
testutil.AssertErrorMsg(t, "cmd arpdb: running command: can't run", err)
testutil.AssertErrorMsg(t, "cmd arpdb: running command: running: can't run", err)
})
t.Run("bad_code", func(t *testing.T) {
sh := theOnlyCmd("cmd", 1, "", nil)
substShell(t, sh.RunCmd)
a.cmdCons = agh.NewCommandConstructor("cmd", 1, "", nil)
err := a.Refresh()
testutil.AssertErrorMsg(t, "cmd arpdb: running command: unexpected exit code 1", err)
testutil.AssertErrorMsg(
t,
"cmd arpdb: running command: unexpected exit code 1",
err,
)
})
t.Run("empty", func(t *testing.T) {
sh := theOnlyCmd("cmd", 0, "", nil)
substShell(t, sh.RunCmd)
a.cmdCons = agh.NewCommandConstructor("cmd", 0, "", nil)
err := a.Refresh()
require.NoError(t, err)

View File

@ -14,10 +14,11 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/AdguardTeam/golibs/stringutil"
)
func newARPDB(logger *slog.Logger) (arp *arpdbs) {
func newARPDB(logger *slog.Logger, cmdCons executil.CommandConstructor) (arp *arpdbs) {
// Use the common storage among the implementations.
ns := &neighs{
mu: &sync.RWMutex{},
@ -40,10 +41,11 @@ func newARPDB(logger *slog.Logger) (arp *arpdbs) {
},
// Then, try "arp -a -n".
&cmdARPDB{
logger: logger,
parse: parseF,
ns: ns,
cmd: "arp",
logger: logger,
cmdCons: cmdCons,
parse: parseF,
ns: ns,
cmd: "arp",
// Use -n flag to avoid resolving the hostnames of the neighbors.
// By default ARP attempts to resolve the hostnames via DNS. See
// man 8 arp.
@ -53,11 +55,12 @@ func newARPDB(logger *slog.Logger) (arp *arpdbs) {
},
// Finally, try "ip neigh".
&cmdARPDB{
logger: logger,
parse: parseIPNeigh,
ns: ns,
cmd: "ip",
args: []string{"neigh"},
logger: logger,
cmdCons: cmdCons,
parse: parseIPNeigh,
ns: ns,
cmd: "ip",
args: []string{"neigh"},
},
)
}

View File

@ -9,6 +9,7 @@ import (
"testing"
"testing/fstest"
"github.com/AdguardTeam/AdGuardHome/internal/agh"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -62,18 +63,13 @@ func TestFSysARPDB(t *testing.T) {
}
func TestCmdARPDB_linux(t *testing.T) {
sh := mapShell{
"arp -a": {err: nil, out: arpAOutputWrt, code: 0},
"ip neigh": {err: nil, out: ipNeighOutput, code: 0},
}
substShell(t, sh.RunCmd)
t.Run("wrt", func(t *testing.T) {
a := &cmdARPDB{
logger: slogutil.NewDiscardLogger(),
parse: parseArpAWrt,
cmd: "arp",
args: []string{"-a"},
logger: slogutil.NewDiscardLogger(),
cmdCons: agh.NewCommandConstructor("arp -a", 0, arpAOutputWrt, nil),
parse: parseArpAWrt,
cmd: "arp",
args: []string{"-a"},
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),
@ -88,10 +84,11 @@ func TestCmdARPDB_linux(t *testing.T) {
t.Run("ip_neigh", func(t *testing.T) {
a := &cmdARPDB{
logger: slogutil.NewDiscardLogger(),
parse: parseIPNeigh,
cmd: "ip",
args: []string{"neigh"},
logger: slogutil.NewDiscardLogger(),
cmdCons: agh.NewCommandConstructor("ip neigh", 0, ipNeighOutput, nil),
parse: parseIPNeigh,
cmd: "ip",
args: []string{"neigh"},
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),

View File

@ -9,12 +9,14 @@ import (
"sync"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil/executil"
)
func newARPDB(logger *slog.Logger) (arp *cmdARPDB) {
func newARPDB(logger *slog.Logger, cmdCons executil.CommandConstructor) (arp *cmdARPDB) {
return &cmdARPDB{
logger: logger,
parse: parseArpA,
logger: logger,
cmdCons: cmdCons,
parse: parseArpA,
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),

View File

@ -9,12 +9,14 @@ import (
"sync"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil/executil"
)
func newARPDB(logger *slog.Logger) (arp *cmdARPDB) {
func newARPDB(logger *slog.Logger, cmdCons executil.CommandConstructor) (arp *cmdARPDB) {
return &cmdARPDB{
logger: logger,
parse: parseArpA,
logger: logger,
cmdCons: cmdCons,
parse: parseArpA,
ns: &neighs{
mu: &sync.RWMutex{},
ns: make([]Neighbor, 0),

View File

@ -1,6 +1,7 @@
package home
import (
"bytes"
"context"
"encoding/json"
"fmt"
@ -9,7 +10,6 @@ import (
"net/http"
"net/netip"
"os"
"os/exec"
"path/filepath"
"runtime"
"time"
@ -20,8 +20,10 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghhttp"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/container"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/quic-go/quic-go/http3"
)
@ -243,34 +245,30 @@ func checkDNSStubListener(ctx context.Context, l *slog.Logger) (ok bool) {
return false
}
cmd := exec.Command("systemctl", "is-enabled", "systemd-resolved")
l.DebugContext(ctx, "executing", "cmd", cmd.Path, "args", cmd.Args)
_, err := cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
l.InfoContext(
cmds := container.KeyValues[string, []string]{{
Key: "systemctl",
Value: []string{"is-enabled", "systemd-resolved"},
}, {
Key: "grep",
Value: []string{"-E", "#?DNSStubListener=yes", "/etc/systemd/resolved.conf"},
}}
for _, cmd := range cmds {
l.DebugContext(ctx, "executing", "cmd", cmd.Key, "args", cmd.Value)
err := executil.Run(
ctx,
"execution failed",
"cmd", cmd.Path,
"code", cmd.ProcessState.ExitCode(),
slogutil.KeyError, err,
executil.SystemCommandConstructor{},
&executil.CommandConfig{
Path: cmd.Key,
Args: cmd.Value,
},
)
if err != nil {
l.InfoContext(ctx, "execution failed", "cmd", cmd.Key, slogutil.KeyError, err)
return false
}
cmd = exec.Command("grep", "-E", "#?DNSStubListener=yes", "/etc/systemd/resolved.conf")
l.DebugContext(ctx, "executing", "cmd", cmd.Path, "args", cmd.Args)
_, err = cmd.Output()
if err != nil || cmd.ProcessState.ExitCode() != 0 {
l.InfoContext(
ctx,
"execution failed",
"cmd", cmd.Path,
"code", cmd.ProcessState.ExitCode(),
slogutil.KeyError, err,
)
return false
return false
}
}
return true
@ -306,15 +304,30 @@ func disableDNSStubListener(ctx context.Context, l *slog.Logger) (err error) {
return fmt.Errorf("os.Symlink: %s: %w", resolvConfPath, err)
}
cmd := exec.Command("systemctl", "reload-or-restart", "systemd-resolved")
l.DebugContext(ctx, "executing", "cmd", cmd.Path, "args", cmd.Args)
_, err = cmd.Output()
const (
systemctlCmd = "systemctl"
)
var (
systemctlArgs = []string{"reload-or-restart", "systemd-resolved"}
systemctlStdout bytes.Buffer
systemctlStderr bytes.Buffer
)
l.DebugContext(ctx, "executing", "cmd", systemctlCmd, "args", systemctlArgs)
err = executil.Run(
ctx,
executil.SystemCommandConstructor{},
&executil.CommandConfig{
Path: systemctlCmd,
Args: systemctlArgs,
Stdout: &systemctlStdout,
Stderr: &systemctlStderr,
},
)
if err != nil {
return err
}
if cmd.ProcessState.ExitCode() != 0 {
return fmt.Errorf("process %s exited with an error: %d",
cmd.Path, cmd.ProcessState.ExitCode())
return fmt.Errorf("executing cmd: %w", err)
}
return nil

View File

@ -7,7 +7,6 @@ import (
"log/slog"
"net/http"
"os"
"os/exec"
"runtime"
"syscall"
"time"
@ -19,6 +18,7 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/osutil/executil"
)
// temporaryError is the interface for temporary errors from the Go standard
@ -195,6 +195,8 @@ func finishUpdate(ctx context.Context, l *slog.Logger, execPath string, runningA
cleanup(ctx)
cleanupAlways()
cons := executil.SystemCommandConstructor{}
var err error
if runtime.GOOS == "windows" {
if runningAsService {
@ -203,8 +205,16 @@ func finishUpdate(ctx context.Context, l *slog.Logger, execPath string, runningA
// instance, because Windows doesn't allow it.
//
// TODO(a.garipov): Recheck the claim above.
cmd := exec.Command("cmd", "/c", "net stop AdGuardHome & net start AdGuardHome")
err = cmd.Start()
var cmd executil.Command
cmd, err = cons.New(ctx, &executil.CommandConfig{
Path: "cmd",
Args: []string{"/c", "net stop AdGuardHome & net start AdGuardHome"},
})
if err != nil {
panic(fmt.Errorf("constructing cmd: %w", err))
}
err = cmd.Start(ctx)
if err != nil {
panic(fmt.Errorf("restarting service: %w", err))
}
@ -212,12 +222,21 @@ func finishUpdate(ctx context.Context, l *slog.Logger, execPath string, runningA
os.Exit(osutil.ExitCodeSuccess)
}
cmd := exec.Command(execPath, os.Args[1:]...)
l.InfoContext(ctx, "restarting", "exec_path", execPath, "args", os.Args[1:])
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err = cmd.Start()
var cmd executil.Command
cmd, err = cons.New(ctx, &executil.CommandConfig{
Path: execPath,
Args: os.Args[1:],
Stdin: os.Stdin,
Stdout: os.Stdout,
Stderr: os.Stderr,
})
if err != nil {
panic(fmt.Errorf("constructing cmd: %w", err))
}
err = cmd.Start(ctx)
if err != nil {
panic(fmt.Errorf("restarting: %w", err))
}

View File

@ -43,6 +43,7 @@ import (
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/osutil/executil"
)
// Global context
@ -820,18 +821,19 @@ func newUpdater(
l.DebugContext(ctx, "creating updater", "config_path", confPath)
return updater.NewUpdater(&updater.Config{
Client: conf.Filtering.HTTPClient,
Logger: l,
Version: version.Version(),
Channel: version.Channel(),
GOARCH: runtime.GOARCH,
GOOS: runtime.GOOS,
GOARM: version.GOARM(),
GOMIPS: version.GOMIPS(),
WorkDir: workDir,
ConfName: confPath,
ExecPath: execPath,
VersionCheckURL: versionURL,
Client: conf.Filtering.HTTPClient,
Logger: l,
CommandConstructor: executil.SystemCommandConstructor{},
Version: version.Version(),
Channel: version.Channel(),
GOARCH: runtime.GOARCH,
GOOS: runtime.GOOS,
GOARM: version.GOARM(),
GOMIPS: version.GOMIPS(),
WorkDir: workDir,
ConfName: confPath,
ExecPath: execPath,
VersionCheckURL: versionURL,
}), isCustomURL
}

View File

@ -18,6 +18,7 @@ import (
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/kardianos/service"
)
@ -76,11 +77,11 @@ func (p *program) Stop(_ service.Service) (err error) {
//
// On OpenWrt, the service utility may not exist. We use our service script
// directly in this case.
func svcStatus(s service.Service) (status service.Status, err error) {
func svcStatus(ctx context.Context, s service.Service) (status service.Status, err error) {
status, err = s.Status()
if err != nil && service.Platform() == "unix-systemv" {
var code int
code, err = runInitdCommand("status")
code, err = runInitdCommand(ctx, "status")
if err != nil || code != 0 {
return service.StatusStopped, nil
}
@ -105,7 +106,7 @@ func svcAction(ctx context.Context, l *slog.Logger, s service.Service, action st
err = service.Control(s, action)
if err != nil && service.Platform() == "unix-systemv" &&
(action == "start" || action == "stop" || action == "restart") {
_, err = runInitdCommand(action)
_, err = runInitdCommand(ctx, action)
}
return err
@ -326,7 +327,7 @@ func handleServiceStatusCommand(
l *slog.Logger,
s service.Service,
) {
status, errSt := svcStatus(s)
status, errSt := svcStatus(ctx, s)
if errSt != nil {
l.ErrorContext(ctx, "failed to get service status", slogutil.KeyError, errSt)
os.Exit(osutil.ExitCodeFailure)
@ -356,7 +357,7 @@ func handleServiceInstallCommand(ctx context.Context, l *slog.Logger, s service.
// On OpenWrt it is important to run enable after the service
// installation. Otherwise, the service won't start on the system
// startup.
_, err = runInitdCommand("enable")
_, err = runInitdCommand(ctx, "enable")
if err != nil {
l.ErrorContext(ctx, "running init enable", slogutil.KeyError, err)
os.Exit(osutil.ExitCodeFailure)
@ -386,7 +387,7 @@ func handleServiceUninstallCommand(ctx context.Context, l *slog.Logger, s servic
if aghos.IsOpenWrt() {
// On OpenWrt it is important to run disable command first
// as it will remove the symlink
_, err := runInitdCommand("disable")
_, err := runInitdCommand(ctx, "disable")
if err != nil {
l.ErrorContext(ctx, "running init disable", slogutil.KeyError, err)
os.Exit(osutil.ExitCodeFailure)
@ -458,10 +459,11 @@ func configureService(c *service.Config) {
// runInitdCommand runs init.d service command
// returns command code or error if any
func runInitdCommand(action string) (int, error) {
func runInitdCommand(ctx context.Context, action string) (int, error) {
confPath := "/etc/init.d/" + serviceName
// Pass the script and action as a single string argument.
code, _, err := aghos.RunCommand("sh", "-c", confPath+" "+action)
cmdCons := executil.SystemCommandConstructor{}
code, _, err := aghos.RunCommand(ctx, cmdCons, "sh", "-c", confPath+" "+action)
return code, err
}

View File

@ -4,12 +4,14 @@ package home
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"os/exec"
"strings"
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/kardianos/service"
)
@ -85,7 +87,9 @@ func (svc *sysvService) Install() (err error) {
return err
}
_, _, err = aghos.RunCommand("update-rc.d", svc.name, "defaults")
cmdCons := executil.SystemCommandConstructor{}
// TODO(s.chzhen): Pass context.
_, _, err = aghos.RunCommand(context.TODO(), cmdCons, "update-rc.d", svc.name, "defaults")
// Don't wrap an error since it's informative enough as is.
return err
@ -100,7 +104,9 @@ func (svc *sysvService) Uninstall() (err error) {
return err
}
_, _, err = aghos.RunCommand("update-rc.d", svc.name, "remove")
cmdCons := executil.SystemCommandConstructor{}
// TODO(s.chzhen): Pass context.
_, _, err = aghos.RunCommand(context.TODO(), cmdCons, "update-rc.d", svc.name, "remove")
// Don't wrap an error since it's informative enough as is.
return err
@ -150,26 +156,36 @@ var _ service.Service = (*systemdService)(nil)
// Status implements the [service.Service] interface for *systemdService.
func (s *systemdService) Status() (status service.Status, err error) {
cmd := exec.Command("systemctl", "show", s.unitName)
stdout, err := cmd.StdoutPipe()
if err != nil {
return service.StatusUnknown, fmt.Errorf("connecting to command stdout: %w", err)
}
const (
systemctlCmd = "systemctl"
)
if err = cmd.Start(); err != nil {
return service.StatusUnknown, fmt.Errorf("start command executing: %w", err)
}
var (
systemctlArgs = []string{"show", s.unitName}
systemctlStdout bytes.Buffer
)
status, err = parseSystemctlShow(stdout)
if err != nil {
return service.StatusUnknown, fmt.Errorf("parsing command output: %w", err)
}
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
err = cmd.Wait()
err = executil.Run(
ctx,
executil.SystemCommandConstructor{},
&executil.CommandConfig{
Path: systemctlCmd,
Args: systemctlArgs,
Stdout: &systemctlStdout,
},
)
if err != nil {
return service.StatusUnknown, fmt.Errorf("executing command: %w", err)
}
status, err = parseSystemctlShow(&systemctlStdout)
if err != nil {
return service.StatusUnknown, fmt.Errorf("parsing command output: %w", err)
}
return status, nil
}

View File

@ -4,6 +4,7 @@ package home
import (
"cmp"
"context"
"fmt"
"os"
"os/signal"
@ -15,6 +16,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghos"
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/log"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/kardianos/service"
)
@ -210,8 +212,11 @@ func (s *openbsdRunComService) configureSysStartup(enable bool) (err error) {
cmd = "disable"
}
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
cmdCons := executil.SystemCommandConstructor{}
var code int
code, _, err = aghos.RunCommand("rcctl", cmd, s.cfg.Name)
code, _, err = aghos.RunCommand(ctx, cmdCons, "rcctl", cmd, s.cfg.Name)
if err != nil {
return err
} else if code != 0 {
@ -312,11 +317,15 @@ func (s *openbsdRunComService) runCom(cmd string) (out string, err error) {
return "", err
}
// TODO(s.chzhen): Pass context.
ctx := context.TODO()
cmdCons := executil.SystemCommandConstructor{}
// TODO(e.burkov): It's possible that os.ErrNotExist is caused by
// something different than the service script's non-existence. Keep it
// in mind, when replace the aghos.RunCommand.
var outData []byte
_, outData, err = aghos.RunCommand(scriptPath, cmd)
_, outData, err = aghos.RunCommand(ctx, cmdCons, scriptPath, cmd)
if errors.Is(err, os.ErrNotExist) {
return "", service.ErrNotInstalled
}

View File

@ -10,6 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/AdGuardHome/internal/updater"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -58,13 +59,14 @@ func TestUpdater_VersionInfo(t *testing.T) {
fakeURL := srvURL.JoinPath("adguardhome", version.ChannelBeta, "version.json")
u := updater.NewUpdater(&updater.Config{
Client: srv.Client(),
Logger: testLogger,
Version: "v0.103.0-beta.1",
Channel: version.ChannelBeta,
GOARCH: "arm",
GOOS: "linux",
VersionCheckURL: fakeURL,
Client: srv.Client(),
Logger: testLogger,
CommandConstructor: executil.EmptyCommandConstructor{},
Version: "v0.103.0-beta.1",
Channel: version.ChannelBeta,
GOARCH: "arm",
GOOS: "linux",
VersionCheckURL: fakeURL,
})
ctx := testutil.ContextWithTimeout(t, testTimeout)
@ -132,15 +134,16 @@ func TestUpdater_VersionInfo_others(t *testing.T) {
for _, tc := range testCases {
u := updater.NewUpdater(&updater.Config{
Client: fakeClient,
Logger: testLogger,
Version: "v0.103.0-beta.1",
Channel: version.ChannelBeta,
GOOS: "linux",
GOARCH: tc.arch,
GOARM: tc.arm,
GOMIPS: tc.mips,
VersionCheckURL: fakeURL,
Client: fakeClient,
Logger: testLogger,
CommandConstructor: executil.EmptyCommandConstructor{},
Version: "v0.103.0-beta.1",
Channel: version.ChannelBeta,
GOOS: "linux",
GOARCH: tc.arch,
GOARM: tc.arm,
GOMIPS: tc.mips,
VersionCheckURL: fakeURL,
})
ctx := testutil.ContextWithTimeout(t, testTimeout)

View File

@ -4,6 +4,7 @@ package updater
import (
"archive/tar"
"archive/zip"
"bytes"
"compress/gzip"
"context"
"fmt"
@ -13,7 +14,6 @@ import (
"net/http"
"net/url"
"os"
"os/exec"
"path"
"path/filepath"
"strings"
@ -26,6 +26,7 @@ import (
"github.com/AdguardTeam/golibs/ioutil"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/netutil/urlutil"
"github.com/AdguardTeam/golibs/osutil/executil"
)
// Updater is the AdGuard Home updater.
@ -33,6 +34,8 @@ type Updater struct {
client *http.Client
logger *slog.Logger
cmdCons executil.CommandConstructor
version string
channel string
goarch string
@ -88,6 +91,9 @@ type Config struct {
// be nil, see [DefaultVersionURL].
VersionCheckURL *url.URL
// CommandConstructor is used to run external commands. It must not be nil.
CommandConstructor executil.CommandConstructor
// Version is the current AdGuard Home version. It must not be empty.
Version string
@ -129,6 +135,8 @@ func NewUpdater(conf *Config) *Updater {
client: conf.Client,
logger: conf.Logger,
cmdCons: conf.CommandConstructor,
version: conf.Version,
channel: conf.Channel,
goarch: conf.GOARCH,
@ -286,11 +294,27 @@ func (u *Updater) check(ctx context.Context) (err error) {
"%s" +
"end of the output"
cmd := exec.Command(u.updateExeName, "--check-config")
out, err := cmd.CombinedOutput()
code := cmd.ProcessState.ExitCode()
if err != nil || code != 0 {
return fmt.Errorf(format, err, code, out)
var (
args = []string{"--check-config"}
buf bytes.Buffer
)
u.logger.DebugContext(ctx, "executing", "cmd", u.updateExeName, "args", args)
err = executil.Run(
ctx,
u.cmdCons,
&executil.CommandConfig{
Path: u.updateExeName,
Args: args,
Stdout: &buf,
Stderr: &buf,
},
)
if err != nil {
code, _ := executil.ExitCodeFromError(err)
return fmt.Errorf(format, err, code, buf.Bytes())
}
return nil

View File

@ -10,6 +10,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/aghtest"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -58,13 +59,14 @@ func TestUpdater_internal(t *testing.T) {
fakeURL = fakeURL.JoinPath(tc.archiveName)
u := NewUpdater(&Config{
Client: fakeClient,
Logger: slogutil.NewDiscardLogger(),
GOOS: tc.os,
Version: "v0.103.0",
ExecPath: exePath,
WorkDir: wd,
ConfName: yamlPath,
Client: fakeClient,
Logger: slogutil.NewDiscardLogger(),
CommandConstructor: executil.EmptyCommandConstructor{},
GOOS: tc.os,
Version: "v0.103.0",
ExecPath: exePath,
WorkDir: wd,
ConfName: yamlPath,
// TODO(e.burkov): Rewrite the test to use a fake version check
// URL with a fake URLs for the package files.
VersionCheckURL: &url.URL{},

View File

@ -15,6 +15,7 @@ import (
"github.com/AdguardTeam/AdGuardHome/internal/updater"
"github.com/AdguardTeam/AdGuardHome/internal/version"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil/executil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -76,15 +77,16 @@ func TestUpdater_Update(t *testing.T) {
require.NoError(t, err)
u := updater.NewUpdater(&updater.Config{
Client: srv.Client(),
Logger: testLogger,
GOARCH: "amd64",
GOOS: "linux",
Version: "v0.103.0",
ConfName: yamlPath,
WorkDir: wd,
ExecPath: exePath,
VersionCheckURL: versionCheckURL,
Client: srv.Client(),
Logger: testLogger,
CommandConstructor: executil.EmptyCommandConstructor{},
GOARCH: "amd64",
GOOS: "linux",
Version: "v0.103.0",
ConfName: yamlPath,
WorkDir: wd,
ExecPath: exePath,
VersionCheckURL: versionCheckURL,
})
ctx := testutil.ContextWithTimeout(t, testTimeout)

View File

@ -13,7 +13,6 @@ import (
"maps"
"net/url"
"os"
"os/exec"
"path/filepath"
"slices"
"strings"
@ -23,6 +22,7 @@ import (
"github.com/AdguardTeam/golibs/errors"
"github.com/AdguardTeam/golibs/logutil/slogutil"
"github.com/AdguardTeam/golibs/osutil"
"github.com/AdguardTeam/golibs/osutil/executil"
)
const (
@ -96,7 +96,7 @@ func main() {
errors.Check(cli.upload())
case "auto-add":
err := autoAdd(conf.LocalizableFiles[0])
err := autoAdd(ctx, l, conf.LocalizableFiles[0])
errors.Check(err)
default:
usage("unknown command")
@ -395,10 +395,12 @@ func findUnused(fileNames []string, loc locales) (err error) {
// autoAdd adds locales with additions to the git and restores locales with
// deletions.
func autoAdd(basePath string) (err error) {
func autoAdd(ctx context.Context, l *slog.Logger, basePath string) (err error) {
defer func() { err = errors.Annotate(err, "auto add: %w") }()
adds, dels, err := changedLocales()
cmdCons := executil.SystemCommandConstructor{}
adds, dels, err := changedLocales(ctx, l, cmdCons)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return err
@ -408,13 +410,13 @@ func autoAdd(basePath string) (err error) {
return errors.Error("base locale contains deletions")
}
err = handleAdds(adds)
err = handleAdds(ctx, l, cmdCons, adds)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil
}
err = handleDels(dels)
err = handleDels(ctx, l, cmdCons, dels)
if err != nil {
// Don't wrap the error since it's informative enough as is.
return nil
@ -423,14 +425,24 @@ func autoAdd(basePath string) (err error) {
return nil
}
// gitCmd is the shell command for Git.
const gitCmd = "git"
// handleAdds adds locales with additions to the git.
func handleAdds(locales []string) (err error) {
func handleAdds(
ctx context.Context,
l *slog.Logger,
cmdCons executil.CommandConstructor,
locales []string,
) (err error) {
if len(locales) == 0 {
return nil
}
args := append([]string{"add"}, locales...)
code, out, err := aghos.RunCommand("git", args...)
gitArgs := append([]string{"add"}, locales...)
l.DebugContext(ctx, "executing", "cmd", gitCmd, "args", gitArgs)
code, out, err := aghos.RunCommand(ctx, cmdCons, gitCmd, gitArgs...)
if err != nil || code != 0 {
return fmt.Errorf("git add exited with code %d output %q: %w", code, out, err)
@ -440,13 +452,20 @@ func handleAdds(locales []string) (err error) {
}
// handleDels restores locales with deletions.
func handleDels(locales []string) (err error) {
func handleDels(
ctx context.Context,
l *slog.Logger,
cmdCons executil.CommandConstructor,
locales []string,
) (err error) {
if len(locales) == 0 {
return nil
}
args := append([]string{"restore"}, locales...)
code, out, err := aghos.RunCommand("git", args...)
gitArgs := append([]string{"restore"}, locales...)
l.DebugContext(ctx, "executing", "cmd", gitCmd, "args", gitArgs)
code, out, err := aghos.RunCommand(ctx, cmdCons, gitCmd, gitArgs...)
if err != nil || code != 0 {
return fmt.Errorf("git restore exited with code %d output %q: %w", code, out, err)
@ -458,22 +477,27 @@ func handleDels(locales []string) (err error) {
// changedLocales returns cleaned paths of locales with changes or error. adds
// is the list of locales with only additions. dels is the list of locales
// with only deletions.
func changedLocales() (adds, dels []string, err error) {
func changedLocales(
ctx context.Context,
l *slog.Logger,
cmdCons executil.CommandConstructor,
) (adds, dels []string, err error) {
defer func() { err = errors.Annotate(err, "getting changes: %w") }()
cmd := exec.Command("git", "diff", "--numstat", localesDir)
gitArgs := []string{"diff", "--numstat", localesDir}
l.DebugContext(ctx, "executing", "cmd", gitCmd, "args", gitArgs)
stdout, err := cmd.StdoutPipe()
var out bytes.Buffer
err = executil.Run(ctx, cmdCons, &executil.CommandConfig{
Path: gitCmd,
Args: gitArgs,
Stdout: &out,
})
if err != nil {
return nil, nil, fmt.Errorf("piping: %w", err)
return nil, nil, fmt.Errorf("executing cmd: %w", err)
}
err = cmd.Start()
if err != nil {
return nil, nil, fmt.Errorf("starting: %w", err)
}
scanner := bufio.NewScanner(stdout)
scanner := bufio.NewScanner(&out)
for scanner.Scan() {
line := scanner.Text()
@ -497,10 +521,5 @@ func changedLocales() (adds, dels []string, err error) {
return nil, nil, fmt.Errorf("scanning: %w", err)
}
err = cmd.Wait()
if err != nil {
return nil, nil, fmt.Errorf("waiting: %w", err)
}
return adds, dels, nil
}