diff --git a/internal/agh/agh.go b/internal/agh/agh.go index 8fbd740f..e77c91bb 100644 --- a/internal/agh/agh.go +++ b/internal/agh/agh.go @@ -3,8 +3,18 @@ package agh import ( "context" + "fmt" + "strings" + + "github.com/AdguardTeam/golibs/osutil" + "github.com/AdguardTeam/golibs/osutil/executil" + "github.com/AdguardTeam/golibs/testutil/fakeos/fakeexec" ) +// DefaultOutputLimit is the default limit of bytes for commands' standard +// output and standard error. +const DefaultOutputLimit = 512 + // ConfigModifier defines an interface for updating the global configuration. type ConfigModifier interface { // Apply applies changes to the global configuration. @@ -20,3 +30,142 @@ var _ ConfigModifier = EmptyConfigModifier{} // Apply implements the [ConfigModifier] for EmptyConfigModifier. func (em EmptyConfigModifier) Apply(ctx context.Context) {} + +// exitErr implements [executil.ExitCodeError] for tests to simulate non-zero +// process exit codes. +// +// TODO(s.chzhen): Consider constructing an [exec.ExitError] instead. +type exitErr struct { + code osutil.ExitCode +} + +// newExitErr returns a properly initialized exitErr with the provided code. +func newExitErr(code osutil.ExitCode) (err exitErr) { + return exitErr{code: code} +} + +// type check +var _ executil.ExitCodeError = exitErr{} + +// Error implements the [executil.ExitCodeError] for exitErr. +func (e exitErr) Error() (s string) { + return fmt.Sprintf("exit code %d", e.code) +} + +// ExitCode implements the [executil.ExitCodeError] for exitErr. +func (e exitErr) ExitCode() (code osutil.ExitCode) { + return e.code +} + +// ExternalCommand is a fake command used by [NewMultipleCommandConstructor]. +type ExternalCommand struct { + // Err is the error returned, if non-nil. + Err error + + // Cmd contains the command path and arguments. + Cmd string + + // Out is written to stdout if non-empty. + Out string + + // Code is returned as the exit code if non-zero. + Code osutil.ExitCode +} + +// keyCommand builds a key for a command lookup. +func keyCommand(path string, args []string) (k string) { + if len(args) == 0 { + return path + } + + return path + " " + strings.Join(args, " ") +} + +// parseCommand splits a command string into the executable path and args. +func parseCommand(s string) (path string, args []string) { + f := strings.Fields(s) + if len(f) == 0 { + return "", nil + } + + return f[0], f[1:] +} + +// NewMultipleCommandConstructor is a helper function that returns a mock +// [executil.CommandConstructor] for tests that supports multiple commands. +// +// TODO(s.chzhen): Move to aghtest once the import cycle is resolved, since it +// will be called from the aghnet package, which imports the whois package, +// which in turn imports aghnet. +func NewMultipleCommandConstructor(cmds ...ExternalCommand) (cs executil.CommandConstructor) { + table := make(map[string]ExternalCommand, len(cmds)) + for _, ec := range cmds { + p, a := parseCommand(ec.Cmd) + table[keyCommand(p, a)] = ec + } + + onNew := func(_ context.Context, conf *executil.CommandConfig) (c executil.Command, err error) { + ec := table[keyCommand(conf.Path, conf.Args)] + + cmd := fakeexec.NewCommand() + cmd.OnStart = func(_ context.Context) (err error) { + if ec.Out != "" { + _, _ = conf.Stdout.Write([]byte(ec.Out)) + } + + return nil + } + + cmd.OnWait = func(_ context.Context) (err error) { + if ec.Err != nil { + return ec.Err + } + + if ec.Code != 0 { + return newExitErr(ec.Code) + } + + return nil + } + + return cmd, nil + } + + return &fakeexec.CommandConstructor{OnNew: onNew} +} + +// NewCommandConstructor is a helper function that returns a mock +// [executil.CommandConstructor] for tests. +func NewCommandConstructor( + _ string, + code osutil.ExitCode, + stdout string, + cmdErr error, +) (cs executil.CommandConstructor) { + onNew := func(_ context.Context, conf *executil.CommandConfig) (c executil.Command, err error) { + cmd := fakeexec.NewCommand() + cmd.OnStart = func(_ context.Context) (err error) { + if conf.Stdout != nil { + _, _ = conf.Stdout.Write([]byte(stdout)) + } + + return nil + } + + cmd.OnWait = func(_ context.Context) (err error) { + if cmdErr != nil { + return cmdErr + } + + if code != 0 { + return newExitErr(code) + } + + return nil + } + + return cmd, nil + } + + return &fakeexec.CommandConstructor{OnNew: onNew} +} diff --git a/internal/aghnet/net.go b/internal/aghnet/net.go index 71e81b02..9eb32604 100644 --- a/internal/aghnet/net.go +++ b/internal/aghnet/net.go @@ -18,6 +18,7 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" "github.com/AdguardTeam/golibs/osutil" + "github.com/AdguardTeam/golibs/osutil/executil" ) // DialContextFunc is the semantic alias for dialing functions, such as @@ -27,7 +28,16 @@ type DialContextFunc = func(ctx context.Context, network, addr string) (conn net // Variables and functions to substitute in tests. var ( // aghosRunCommand is the function to run shell commands. - aghosRunCommand = aghos.RunCommand + // + // TODO(s.chzhen): Use [aghos.RunCommand] directly. + aghosRunCommand = (func() func(string, ...string) (int, []byte, error) { + ctx := context.TODO() + cmdCons := executil.SystemCommandConstructor{} + + return func(command string, arguments ...string) (int, []byte, error) { + return aghos.RunCommand(ctx, cmdCons, command, arguments...) + } + })() // netInterfaces is the function to get the available network interfaces. netInterfaceAddrs = net.InterfaceAddrs diff --git a/internal/aghos/os.go b/internal/aghos/os.go index 029df885..0eb5971a 100644 --- a/internal/aghos/os.go +++ b/internal/aghos/os.go @@ -5,13 +5,13 @@ package aghos import ( "bufio" + "bytes" "context" "fmt" "io" "io/fs" "log/slog" "os" - "os/exec" "path" "runtime" "slices" @@ -19,6 +19,9 @@ import ( "strings" "github.com/AdguardTeam/golibs/errors" + "github.com/AdguardTeam/golibs/ioutil" + "github.com/AdguardTeam/golibs/osutil" + "github.com/AdguardTeam/golibs/osutil/executil" ) // Default file, binary, and directory permissions. @@ -50,23 +53,53 @@ func HaveAdminRights() (bool, error) { const MaxCmdOutputSize = 64 * 1024 // RunCommand runs shell command. -func RunCommand(command string, arguments ...string) (code int, output []byte, err error) { - cmd := exec.Command(command, arguments...) - out, err := cmd.Output() +// +// TODO(s.chzhen): Consider removing this after addressing the current behavior +// where a non-zero exit code is returned together with a nil error. +func RunCommand( + ctx context.Context, + cmdCons executil.CommandConstructor, + command string, + arguments ...string, +) (code int, output []byte, err error) { + stdoutBuf := bytes.Buffer{} + stderrBuf := bytes.Buffer{} - out = out[:min(len(out), MaxCmdOutputSize)] + err = executil.Run( + ctx, + cmdCons, + &executil.CommandConfig{ + Path: command, + Args: arguments, + Stdout: ioutil.NewTruncatedWriter(&stdoutBuf, MaxCmdOutputSize), + Stderr: &stderrBuf, + }, + ) - if err != nil { - if eerr := new(exec.ExitError); errors.As(err, &eerr) { - return eerr.ExitCode(), eerr.Stderr, nil - } - - return 1, nil, fmt.Errorf("command %q failed: %w: %s", command, err, out) + if err == nil { + return osutil.ExitCodeSuccess, stdoutBuf.Bytes(), nil } - return cmd.ProcessState.ExitCode(), out, nil + code, ok := executil.ExitCodeFromError(err) + if ok { + // Mirror the old behavior and return a nil-error on non-zero code + // status. + return code, stderrBuf.Bytes(), nil + } + + code = osutil.ExitCodeFailure + + return code, nil, fmt.Errorf("command %q failed: %w: %s", command, err, &stdoutBuf) } +// psArgs holds the default ps arguments to avoid per-call slice allocations. +// +// Don't use -C flag here since it's a feature of linux's ps +// implementation. Use POSIX-compatible flags instead. +// +// See https://github.com/AdguardTeam/AdGuardHome/issues/3457. +var psArgs = []string{"-A", "-o", "pid=", "-o", "comm="} + // PIDByCommand searches for process named command and returns its PID ignoring // the PIDs from except. If no processes found, the error returned. l must not // be nil. @@ -76,31 +109,35 @@ func PIDByCommand( command string, except ...int, ) (pid int, err error) { - // Don't use -C flag here since it's a feature of linux's ps - // implementation. Use POSIX-compatible flags instead. + const psCmd = "ps" + + l.DebugContext(ctx, "executing", "cmd", psCmd, "args", psArgs) + + stdoutBuf := bytes.Buffer{} + + // TODO(s.chzhen): Catch stderr. // - // See https://github.com/AdguardTeam/AdGuardHome/issues/3457. - cmd := exec.Command("ps", "-A", "-o", "pid=", "-o", "comm=") - - var stdout io.ReadCloser - if stdout, err = cmd.StdoutPipe(); err != nil { - return 0, fmt.Errorf("getting the command's stdout pipe: %w", err) - } - - if err = cmd.Start(); err != nil { - return 0, fmt.Errorf("start command executing: %w", err) - } + // TODO(s.chzhen): Consider streaming the output if needed. Using + // [io.Pipe] here is unnecessary; it complicates lifecycle management + // because the output must be read concurrently, and the PipeWriter must be + // explicitly closed to signal EOF. Since this command's output is small, a + // bytes.Buffer via executil.Run is sufficient. + runErr := executil.Run( + ctx, + executil.SystemCommandConstructor{}, + &executil.CommandConfig{ + Path: psCmd, + Args: psArgs, + Stdout: &stdoutBuf, + }, + ) var instNum int - pid, instNum, err = parsePSOutput(stdout, command, except) + pid, instNum, err = parsePSOutput(&stdoutBuf, command, except) if err != nil { return 0, err } - if err = cmd.Wait(); err != nil { - return 0, fmt.Errorf("executing the command: %w", err) - } - switch instNum { case 0: // TODO(e.burkov): Use constant error. @@ -111,8 +148,12 @@ func PIDByCommand( l.WarnContext(ctx, "instances found", "num", instNum, "command", command) } - if code := cmd.ProcessState.ExitCode(); code != 0 { - return 0, fmt.Errorf("ps finished with code %d", code) + if runErr != nil { + if code, ok := executil.ExitCodeFromError(runErr); ok { + return 0, fmt.Errorf("ps finished with code %d", code) + } + + return 0, fmt.Errorf("executing the command: %w", runErr) } return pid, nil diff --git a/internal/arpdb/arpdb.go b/internal/arpdb/arpdb.go index 13cabd0c..a2a08238 100644 --- a/internal/arpdb/arpdb.go +++ b/internal/arpdb/arpdb.go @@ -4,6 +4,7 @@ package arpdb import ( "bufio" "bytes" + "context" "fmt" "log/slog" "net" @@ -11,18 +12,16 @@ import ( "slices" "sync" - "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/osutil" + "github.com/AdguardTeam/golibs/osutil/executil" + "github.com/AdguardTeam/golibs/service" ) // Variables and functions to substitute in tests. var ( - // aghosRunCommand is the function to run shell commands. - aghosRunCommand = aghos.RunCommand - // rootDirFS is the filesystem pointing to the root directory. rootDirFS = osutil.RootDirFS() ) @@ -30,17 +29,17 @@ var ( // Interface stores and refreshes the network neighborhood reported by ARP // (Address Resolution Protocol). type Interface interface { - // Refresh updates the stored data. It must be safe for concurrent use. - Refresh() (err error) + // Refresher updates the stored data. It must be safe for concurrent use. + service.Refresher - // Neighbors returnes the last set of data reported by ARP. Both the method + // Neighbors returns the last set of data reported by ARP. Both the method // and it's result must be safe for concurrent use. Neighbors() (ns []Neighbor) } // New returns the [Interface] properly initialized for the OS. func New(logger *slog.Logger) (arp Interface) { - return newARPDB(logger) + return newARPDB(logger, executil.SystemCommandConstructor{}) } // Empty is the [Interface] implementation that does nothing. @@ -51,7 +50,7 @@ var _ Interface = Empty{} // Refresh implements the [Interface] interface for EmptyARPContainer. It does // nothing and always returns nil error. -func (Empty) Refresh() (err error) { return nil } +func (Empty) Refresh(_ context.Context) (err error) { return nil } // Neighbors implements the [Interface] interface for EmptyARPContainer. It // always returns nil. @@ -164,28 +163,40 @@ type parseNeighsFunc func(logger *slog.Logger, sc *bufio.Scanner, lenHint int) ( // cmdARPDB is the implementation of the [Interface] that uses command line to // retrieve data. type cmdARPDB struct { - logger *slog.Logger - parse parseNeighsFunc - ns *neighs - cmd string - args []string + logger *slog.Logger + cmdCons executil.CommandConstructor + parse parseNeighsFunc + ns *neighs + cmd string + args []string } // type check var _ Interface = (*cmdARPDB)(nil) // Refresh implements the [Interface] interface for *cmdARPDB. -func (arp *cmdARPDB) Refresh() (err error) { +func (arp *cmdARPDB) Refresh(ctx context.Context) (err error) { defer func() { err = errors.Annotate(err, "cmd arpdb: %w") }() - code, out, err := aghosRunCommand(arp.cmd, arp.args...) + var stdout bytes.Buffer + err = executil.Run( + ctx, + arp.cmdCons, + &executil.CommandConfig{ + Path: arp.cmd, + Args: arp.args, + Stdout: &stdout, + }, + ) if err != nil { + if code, ok := executil.ExitCodeFromError(err); ok { + return fmt.Errorf("running command: unexpected exit code %d", code) + } + return fmt.Errorf("running command: %w", err) - } else if code != 0 { - return fmt.Errorf("running command: unexpected exit code %d", code) } - sc := bufio.NewScanner(bytes.NewReader(out)) + sc := bufio.NewScanner(&stdout) ns := arp.parse(arp.logger, sc, arp.ns.len()) if err = sc.Err(); err != nil { // TODO(e.burkov): This error seems unreachable. Investigate. @@ -226,11 +237,11 @@ func newARPDBs(arps ...Interface) (arp *arpdbs) { var _ Interface = (*arpdbs)(nil) // Refresh implements the [Interface] interface for *arpdbs. -func (arp *arpdbs) Refresh() (err error) { +func (arp *arpdbs) Refresh(ctx context.Context) (err error) { var errs []error for _, a := range arp.arps { - err = a.Refresh() + err = a.Refresh(ctx) if err != nil { errs = append(errs, err) diff --git a/internal/arpdb/arpdb_bsd.go b/internal/arpdb/arpdb_bsd.go index e7357603..922b8c99 100644 --- a/internal/arpdb/arpdb_bsd.go +++ b/internal/arpdb/arpdb_bsd.go @@ -9,12 +9,14 @@ import ( "sync" "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/osutil/executil" ) -func newARPDB(logger *slog.Logger) (arp *cmdARPDB) { +func newARPDB(logger *slog.Logger, cmdCons executil.CommandConstructor) (arp *cmdARPDB) { return &cmdARPDB{ - logger: logger, - parse: parseArpA, + logger: logger, + cmdCons: cmdCons, + parse: parseArpA, ns: &neighs{ mu: &sync.RWMutex{}, ns: make([]Neighbor, 0), diff --git a/internal/arpdb/arpdb_internal_test.go b/internal/arpdb/arpdb_internal_test.go index 39dc666e..e9f307ae 100644 --- a/internal/arpdb/arpdb_internal_test.go +++ b/internal/arpdb/arpdb_internal_test.go @@ -1,15 +1,16 @@ package arpdb import ( - "fmt" + "context" "io/fs" "net" "net/netip" "os" - "strings" "sync" "testing" + "time" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/testutil" @@ -17,49 +18,15 @@ import ( "github.com/stretchr/testify/require" ) +// testTimeout is a common timeout for tests. +const testTimeout = 1 * time.Second + // testdata is the filesystem containing data for testing the package. var testdata fs.FS = os.DirFS("./testdata") // RunCmdFunc is the signature of aghos.RunCommand function. type RunCmdFunc func(cmd string, args ...string) (code int, out []byte, err error) -// substShell replaces the the aghos.RunCommand function used throughout the -// package with rc for tests ran under tb. -func substShell(tb testing.TB, rc RunCmdFunc) { - tb.Helper() - - prev := aghosRunCommand - tb.Cleanup(func() { aghosRunCommand = prev }) - aghosRunCommand = rc -} - -// mapShell is a substitution of aghos.RunCommand that maps the command to it's -// execution result. It's only needed to simplify testing. -// -// TODO(e.burkov): Perhaps put all the shell interactions behind an interface. -type mapShell map[string]struct { - err error - out string - code int -} - -// theOnlyCmd returns mapShell that only handles a single command and arguments -// combination from cmd. -func theOnlyCmd(cmd string, code int, out string, err error) (s mapShell) { - return mapShell{cmd: {code: code, out: out, err: err}} -} - -// RunCmd is a RunCmdFunc handled by s. -func (s mapShell) RunCmd(cmd string, args ...string) (code int, out []byte, err error) { - key := strings.Join(append([]string{cmd}, args...), " ") - ret, ok := s[key] - if !ok { - return 0, nil, fmt.Errorf("unexpected shell command %q", key) - } - - return ret.code, []byte(ret.out), ret.err -} - func Test_New(t *testing.T) { var a Interface require.NotPanics(t, func() { a = New(slogutil.NewDiscardLogger()) }) @@ -71,7 +38,7 @@ func Test_New(t *testing.T) { // TestARPDB is the mock implementation of [Interface] to use in tests. type TestARPDB struct { - OnRefresh func() (err error) + OnRefresh func(ctx context.Context) (err error) OnNeighbors func() (ns []Neighbor) } @@ -79,8 +46,8 @@ type TestARPDB struct { var _ Interface = (*TestARPDB)(nil) // Refresh implements the [Interface] interface for *TestARPDB. -func (arp *TestARPDB) Refresh() (err error) { - return arp.OnRefresh() +func (arp *TestARPDB) Refresh(ctx context.Context) (err error) { + return arp.OnRefresh(ctx) } // Neighbors implements the [Interface] interface for *TestARPDB. @@ -98,13 +65,17 @@ func Test_NewARPDBs(t *testing.T) { } succDB := &TestARPDB{ - OnRefresh: func() (err error) { succRefrCount++; return nil }, + OnRefresh: func(_ context.Context) (err error) { succRefrCount++; return nil }, OnNeighbors: func() (ns []Neighbor) { return []Neighbor{{Name: "abc", IP: knownIP, MAC: knownMAC}} }, } failDB := &TestARPDB{ - OnRefresh: func() (err error) { failRefrCount++; return errors.Error("refresh failed") }, + OnRefresh: func(_ context.Context) (err error) { + failRefrCount++ + + return errors.Error("refresh failed") + }, OnNeighbors: func() (ns []Neighbor) { return nil }, } @@ -112,7 +83,7 @@ func Test_NewARPDBs(t *testing.T) { t.Cleanup(clnp) a := newARPDBs(succDB, failDB) - err := a.Refresh() + err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout)) require.NoError(t, err) assert.Equal(t, 1, succRefrCount) @@ -124,7 +95,7 @@ func Test_NewARPDBs(t *testing.T) { t.Cleanup(clnp) a := newARPDBs(failDB, succDB) - err := a.Refresh() + err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout)) require.NoError(t, err) assert.Equal(t, 1, succRefrCount) @@ -138,7 +109,7 @@ func Test_NewARPDBs(t *testing.T) { wantMsg := "each arpdb failed: refresh failed\nrefresh failed" a := newARPDBs(failDB, failDB) - err := a.Refresh() + err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout)) require.Error(t, err) testutil.AssertErrorMsg(t, wantMsg, err) @@ -152,7 +123,7 @@ func Test_NewARPDBs(t *testing.T) { shouldFail := false unstableDB := &TestARPDB{ - OnRefresh: func() (err error) { + OnRefresh: func(_ context.Context) (err error) { if shouldFail { err = errors.Error("unstable failed") } @@ -171,21 +142,21 @@ func Test_NewARPDBs(t *testing.T) { a := newARPDBs(unstableDB, succDB) // Unstable ARPDB should refresh successfully. - err := a.Refresh() + err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout)) require.NoError(t, err) assert.Zero(t, succRefrCount) assert.NotEmpty(t, a.Neighbors()) // Unstable ARPDB should fail and the succDB should be used. - err = a.Refresh() + err = a.Refresh(testutil.ContextWithTimeout(t, testTimeout)) require.NoError(t, err) assert.Equal(t, 1, succRefrCount) assert.NotEmpty(t, a.Neighbors()) // Unstable ARPDB should refresh successfully again. - err = a.Refresh() + err = a.Refresh(testutil.ContextWithTimeout(t, testTimeout)) require.NoError(t, err) assert.Equal(t, 1, succRefrCount) @@ -194,7 +165,7 @@ func Test_NewARPDBs(t *testing.T) { t.Run("empty", func(t *testing.T) { a := newARPDBs() - require.NoError(t, a.Refresh()) + require.NoError(t, a.Refresh(testutil.ContextWithTimeout(t, testTimeout))) assert.Empty(t, a.Neighbors()) }) @@ -212,36 +183,36 @@ func TestCmdARPDB_arpa(t *testing.T) { } t.Run("arp_a", func(t *testing.T) { - sh := theOnlyCmd("cmd", 0, arpAOutput, nil) - substShell(t, sh.RunCmd) + a.cmdCons = agh.NewCommandConstructor("cmd", 0, arpAOutput, nil) - err := a.Refresh() + err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout)) require.NoError(t, err) assert.Equal(t, wantNeighs, a.Neighbors()) }) t.Run("runcmd_error", func(t *testing.T) { - sh := theOnlyCmd("cmd", 0, "", errors.Error("can't run")) - substShell(t, sh.RunCmd) + a.cmdCons = agh.NewCommandConstructor("cmd", 0, "", errors.Error("can't run")) - err := a.Refresh() - testutil.AssertErrorMsg(t, "cmd arpdb: running command: can't run", err) + err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout)) + testutil.AssertErrorMsg(t, "cmd arpdb: running command: running: can't run", err) }) t.Run("bad_code", func(t *testing.T) { - sh := theOnlyCmd("cmd", 1, "", nil) - substShell(t, sh.RunCmd) + a.cmdCons = agh.NewCommandConstructor("cmd", 1, "", nil) - err := a.Refresh() - testutil.AssertErrorMsg(t, "cmd arpdb: running command: unexpected exit code 1", err) + err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout)) + testutil.AssertErrorMsg( + t, + "cmd arpdb: running command: unexpected exit code 1", + err, + ) }) t.Run("empty", func(t *testing.T) { - sh := theOnlyCmd("cmd", 0, "", nil) - substShell(t, sh.RunCmd) + a.cmdCons = agh.NewCommandConstructor("cmd", 0, "", nil) - err := a.Refresh() + err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout)) require.NoError(t, err) assert.Empty(t, a.Neighbors()) @@ -254,7 +225,7 @@ func TestEmptyARPDB(t *testing.T) { t.Run("refresh", func(t *testing.T) { var err error require.NotPanics(t, func() { - err = a.Refresh() + err = a.Refresh(testutil.ContextWithTimeout(t, testTimeout)) }) assert.NoError(t, err) diff --git a/internal/arpdb/arpdb_linux.go b/internal/arpdb/arpdb_linux.go index 364d2ce9..a455411d 100644 --- a/internal/arpdb/arpdb_linux.go +++ b/internal/arpdb/arpdb_linux.go @@ -4,6 +4,7 @@ package arpdb import ( "bufio" + "context" "fmt" "io/fs" "log/slog" @@ -14,10 +15,11 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/osutil/executil" "github.com/AdguardTeam/golibs/stringutil" ) -func newARPDB(logger *slog.Logger) (arp *arpdbs) { +func newARPDB(logger *slog.Logger, cmdCons executil.CommandConstructor) (arp *arpdbs) { // Use the common storage among the implementations. ns := &neighs{ mu: &sync.RWMutex{}, @@ -40,10 +42,11 @@ func newARPDB(logger *slog.Logger) (arp *arpdbs) { }, // Then, try "arp -a -n". &cmdARPDB{ - logger: logger, - parse: parseF, - ns: ns, - cmd: "arp", + logger: logger, + cmdCons: cmdCons, + parse: parseF, + ns: ns, + cmd: "arp", // Use -n flag to avoid resolving the hostnames of the neighbors. // By default ARP attempts to resolve the hostnames via DNS. See // man 8 arp. @@ -53,11 +56,12 @@ func newARPDB(logger *slog.Logger) (arp *arpdbs) { }, // Finally, try "ip neigh". &cmdARPDB{ - logger: logger, - parse: parseIPNeigh, - ns: ns, - cmd: "ip", - args: []string{"neigh"}, + logger: logger, + cmdCons: cmdCons, + parse: parseIPNeigh, + ns: ns, + cmd: "ip", + args: []string{"neigh"}, }, ) } @@ -73,7 +77,7 @@ type fsysARPDB struct { var _ Interface = (*fsysARPDB)(nil) // Refresh implements the [Interface] interface for *fsysARPDB. -func (arp *fsysARPDB) Refresh() (err error) { +func (arp *fsysARPDB) Refresh(_ context.Context) (err error) { var f fs.File f, err = arp.fsys.Open(arp.filename) if err != nil { diff --git a/internal/arpdb/arpdb_linux_internal_test.go b/internal/arpdb/arpdb_linux_internal_test.go index cb0cb1d3..4d9c0083 100644 --- a/internal/arpdb/arpdb_linux_internal_test.go +++ b/internal/arpdb/arpdb_linux_internal_test.go @@ -9,7 +9,9 @@ import ( "testing" "testing/fstest" + "github.com/AdguardTeam/AdGuardHome/internal/agh" "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -54,7 +56,7 @@ func TestFSysARPDB(t *testing.T) { filename: "proc_net_arp", } - err := a.Refresh() + err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout)) require.NoError(t, err) ns := a.Neighbors() @@ -62,25 +64,20 @@ func TestFSysARPDB(t *testing.T) { } func TestCmdARPDB_linux(t *testing.T) { - sh := mapShell{ - "arp -a": {err: nil, out: arpAOutputWrt, code: 0}, - "ip neigh": {err: nil, out: ipNeighOutput, code: 0}, - } - substShell(t, sh.RunCmd) - t.Run("wrt", func(t *testing.T) { a := &cmdARPDB{ - logger: slogutil.NewDiscardLogger(), - parse: parseArpAWrt, - cmd: "arp", - args: []string{"-a"}, + logger: slogutil.NewDiscardLogger(), + cmdCons: agh.NewCommandConstructor("arp -a", 0, arpAOutputWrt, nil), + parse: parseArpAWrt, + cmd: "arp", + args: []string{"-a"}, ns: &neighs{ mu: &sync.RWMutex{}, ns: make([]Neighbor, 0), }, } - err := a.Refresh() + err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout)) require.NoError(t, err) assert.Equal(t, wantNeighs, a.Neighbors()) @@ -88,16 +85,17 @@ func TestCmdARPDB_linux(t *testing.T) { t.Run("ip_neigh", func(t *testing.T) { a := &cmdARPDB{ - logger: slogutil.NewDiscardLogger(), - parse: parseIPNeigh, - cmd: "ip", - args: []string{"neigh"}, + logger: slogutil.NewDiscardLogger(), + cmdCons: agh.NewCommandConstructor("ip neigh", 0, ipNeighOutput, nil), + parse: parseIPNeigh, + cmd: "ip", + args: []string{"neigh"}, ns: &neighs{ mu: &sync.RWMutex{}, ns: make([]Neighbor, 0), }, } - err := a.Refresh() + err := a.Refresh(testutil.ContextWithTimeout(t, testTimeout)) require.NoError(t, err) assert.Equal(t, wantNeighs, a.Neighbors()) diff --git a/internal/arpdb/arpdb_openbsd.go b/internal/arpdb/arpdb_openbsd.go index 8d6ee657..40b45271 100644 --- a/internal/arpdb/arpdb_openbsd.go +++ b/internal/arpdb/arpdb_openbsd.go @@ -9,12 +9,14 @@ import ( "sync" "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/osutil/executil" ) -func newARPDB(logger *slog.Logger) (arp *cmdARPDB) { +func newARPDB(logger *slog.Logger, cmdCons executil.CommandConstructor) (arp *cmdARPDB) { return &cmdARPDB{ - logger: logger, - parse: parseArpA, + logger: logger, + cmdCons: cmdCons, + parse: parseArpA, ns: &neighs{ mu: &sync.RWMutex{}, ns: make([]Neighbor, 0), diff --git a/internal/arpdb/arpdb_windows.go b/internal/arpdb/arpdb_windows.go index 878d1d31..f226d65b 100644 --- a/internal/arpdb/arpdb_windows.go +++ b/internal/arpdb/arpdb_windows.go @@ -9,12 +9,14 @@ import ( "sync" "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/osutil/executil" ) -func newARPDB(logger *slog.Logger) (arp *cmdARPDB) { +func newARPDB(logger *slog.Logger, cmdCons executil.CommandConstructor) (arp *cmdARPDB) { return &cmdARPDB{ - logger: logger, - parse: parseArpA, + logger: logger, + cmdCons: cmdCons, + parse: parseArpA, ns: &neighs{ mu: &sync.RWMutex{}, ns: make([]Neighbor, 0), diff --git a/internal/client/storage.go b/internal/client/storage.go index 4ec327ad..42953d90 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -249,7 +249,7 @@ func (s *Storage) addFromSystemARP(ctx context.Context) { s.mu.Lock() defer s.mu.Unlock() - if err := s.arpDB.Refresh(); err != nil { + if err := s.arpDB.Refresh(ctx); err != nil { s.arpDB = arpdb.Empty{} s.logger.ErrorContext(ctx, "refreshing arp container", slogutil.KeyError, err) diff --git a/internal/client/storage_test.go b/internal/client/storage_test.go index 658163b8..38bc4f56 100644 --- a/internal/client/storage_test.go +++ b/internal/client/storage_test.go @@ -1,6 +1,7 @@ package client_test import ( + "context" "net" "net/netip" "runtime" @@ -18,6 +19,7 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/hostsfile" "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/service" "github.com/AdguardTeam/golibs/testutil" "github.com/AdguardTeam/golibs/testutil/faketime" "github.com/AdguardTeam/golibs/timeutil" @@ -63,17 +65,17 @@ func (c *testHostsContainer) Upd() (updates <-chan *hostsfile.DefaultStorage) { // Interface stores and refreshes the network neighborhood reported by ARP // (Address Resolution Protocol). type Interface interface { - // Refresh updates the stored data. It must be safe for concurrent use. - Refresh() (err error) + // Refresher updates the stored data. It must be safe for concurrent use. + service.Refresher - // Neighbors returnes the last set of data reported by ARP. Both the method + // Neighbors returns the last set of data reported by ARP. Both the method // and it's result must be safe for concurrent use. Neighbors() (ns []arpdb.Neighbor) } // testARPDB is a mock implementation of the [arpdb.Interface]. type testARPDB struct { - onRefresh func() (err error) + onRefresh func(ctx context.Context) (err error) onNeighbors func() (ns []arpdb.Neighbor) } @@ -81,8 +83,8 @@ type testARPDB struct { var _ arpdb.Interface = (*testARPDB)(nil) // Refresh implements the [arpdb.Interface] interface for *testARP. -func (c *testARPDB) Refresh() (err error) { - return c.onRefresh() +func (c *testARPDB) Refresh(ctx context.Context) (err error) { + return c.onRefresh(ctx) } // Neighbors implements the [arpdb.Interface] interface for *testARP. @@ -218,7 +220,7 @@ func TestStorage_Add_arp(t *testing.T) { ) a := &testARPDB{ - onRefresh: func() (err error) { return nil }, + onRefresh: func(_ context.Context) (err error) { return nil }, onNeighbors: func() (ns []arpdb.Neighbor) { mu.Lock() defer mu.Unlock() @@ -392,7 +394,7 @@ func TestClientsDHCP(t *testing.T) { arpCh := make(chan []arpdb.Neighbor, 1) arpDB := &testARPDB{ - onRefresh: func() (err error) { return nil }, + onRefresh: func(_ context.Context) (err error) { return nil }, onNeighbors: func() (ns []arpdb.Neighbor) { select { case ns = <-arpCh: diff --git a/internal/home/controlinstall.go b/internal/home/controlinstall.go index 8ea1cfa2..f214158a 100644 --- a/internal/home/controlinstall.go +++ b/internal/home/controlinstall.go @@ -9,7 +9,6 @@ import ( "net/http" "net/netip" "os" - "os/exec" "path/filepath" "runtime" "time" @@ -20,8 +19,10 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/version" + "github.com/AdguardTeam/golibs/container" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/osutil/executil" "github.com/quic-go/quic-go/http3" ) @@ -129,6 +130,7 @@ func (req *checkConfReq) validateDNS( ctx context.Context, l *slog.Logger, tcpPorts aghalg.UniqChecker[tcpPort], + cmdCons executil.CommandConstructor, ) (canAutofix bool, err error) { defer func() { err = errors.Annotate(err, "validating ports: %w") }() @@ -160,7 +162,7 @@ func (req *checkConfReq) validateDNS( // Try to fix automatically. canAutofix = checkDNSStubListener(ctx, l) if canAutofix && req.DNS.Autofix { - if derr := disableDNSStubListener(ctx, l); derr != nil { + if derr := disableDNSStubListener(ctx, l, cmdCons); derr != nil { l.ErrorContext(ctx, "disabling DNSStubListener", slogutil.KeyError, err) } @@ -188,7 +190,8 @@ func (web *webAPI) handleInstallCheckConfig(w http.ResponseWriter, r *http.Reque resp.Web.Status = err.Error() } - if resp.DNS.CanAutofix, err = req.validateDNS(r.Context(), web.logger, tcpPorts); err != nil { + resp.DNS.CanAutofix, err = req.validateDNS(r.Context(), web.logger, tcpPorts, web.cmdCons) + if err != nil { resp.DNS.Status = err.Error() } else if !req.DNS.IP.IsUnspecified() { resp.StaticIP = handleStaticIP(req.DNS.IP, req.SetStaticIP) @@ -243,34 +246,29 @@ func checkDNSStubListener(ctx context.Context, l *slog.Logger) (ok bool) { return false } - cmd := exec.Command("systemctl", "is-enabled", "systemd-resolved") - l.DebugContext(ctx, "executing", "cmd", cmd.Path, "args", cmd.Args) - _, err := cmd.Output() - if err != nil || cmd.ProcessState.ExitCode() != 0 { - l.InfoContext( + cmds := container.KeyValues[string, []string]{{ + Key: "systemctl", + Value: []string{"is-enabled", "systemd-resolved"}, + }, { + Key: "grep", + Value: []string{"-E", "#?DNSStubListener=yes", "/etc/systemd/resolved.conf"}, + }} + + for _, cmd := range cmds { + l.DebugContext(ctx, "executing", "cmd", cmd.Key, "args", cmd.Value) + + err := executil.RunWithPeek( ctx, - "execution failed", - "cmd", cmd.Path, - "code", cmd.ProcessState.ExitCode(), - slogutil.KeyError, err, + executil.SystemCommandConstructor{}, + agh.DefaultOutputLimit, + cmd.Key, + cmd.Value..., ) + if err != nil { + l.InfoContext(ctx, "execution failed", "cmd", cmd.Key, slogutil.KeyError, err) - return false - } - - cmd = exec.Command("grep", "-E", "#?DNSStubListener=yes", "/etc/systemd/resolved.conf") - l.DebugContext(ctx, "executing", "cmd", cmd.Path, "args", cmd.Args) - _, err = cmd.Output() - if err != nil || cmd.ProcessState.ExitCode() != 0 { - l.InfoContext( - ctx, - "execution failed", - "cmd", cmd.Path, - "code", cmd.ProcessState.ExitCode(), - slogutil.KeyError, err, - ) - - return false + return false + } } return true @@ -287,7 +285,11 @@ const resolvConfPath = "/etc/resolv.conf" // disableDNSStubListener deactivates DNSStubListerner and returns an error, if // any. -func disableDNSStubListener(ctx context.Context, l *slog.Logger) (err error) { +func disableDNSStubListener( + ctx context.Context, + l *slog.Logger, + cmdCons executil.CommandConstructor, +) (err error) { dir := filepath.Dir(resolvedConfPath) err = os.MkdirAll(dir, 0o755) if err != nil { @@ -306,15 +308,21 @@ func disableDNSStubListener(ctx context.Context, l *slog.Logger) (err error) { return fmt.Errorf("os.Symlink: %s: %w", resolvConfPath, err) } - cmd := exec.Command("systemctl", "reload-or-restart", "systemd-resolved") - l.DebugContext(ctx, "executing", "cmd", cmd.Path, "args", cmd.Args) - _, err = cmd.Output() + const systemctlCmd = "systemctl" + + systemctlArgs := []string{"reload-or-restart", "systemd-resolved"} + + l.DebugContext(ctx, "executing", "cmd", systemctlCmd, "args", systemctlArgs) + + err = executil.RunWithPeek( + ctx, + cmdCons, + agh.DefaultOutputLimit, + systemctlCmd, + systemctlArgs..., + ) if err != nil { - return err - } - if cmd.ProcessState.ExitCode() != 0 { - return fmt.Errorf("process %s exited with an error: %d", - cmd.Path, cmd.ProcessState.ExitCode()) + return fmt.Errorf("executing cmd: %w", err) } return nil diff --git a/internal/home/controlupdate.go b/internal/home/controlupdate.go index 4d390ec9..e791633d 100644 --- a/internal/home/controlupdate.go +++ b/internal/home/controlupdate.go @@ -7,7 +7,6 @@ import ( "log/slog" "net/http" "os" - "os/exec" "runtime" "syscall" "time" @@ -19,6 +18,7 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/osutil" + "github.com/AdguardTeam/golibs/osutil/executil" ) // temporaryError is the interface for temporary errors from the Go standard @@ -148,7 +148,13 @@ func (web *webAPI) handleUpdate(w http.ResponseWriter, r *http.Request) { // The background context is used because the underlying functions wrap it // with timeout and shut down the server, which handles current request. It // also should be done in a separate goroutine for the same reason. - go finishUpdate(context.Background(), web.logger, execPath, web.conf.runningAsService) + go finishUpdate( + context.Background(), + web.logger, + web.cmdCons, + execPath, + web.conf.runningAsService, + ) } // versionResponse is the response for /control/version.json endpoint. @@ -187,7 +193,13 @@ func tlsConfUsesPrivilegedPorts(c *tlsConfigSettings) (ok bool) { // finishUpdate completes an update procedure. It is intended to be used as a // goroutine. -func finishUpdate(ctx context.Context, l *slog.Logger, execPath string, runningAsService bool) { +func finishUpdate( + ctx context.Context, + l *slog.Logger, + cmdCons executil.CommandConstructor, + execPath string, + runningAsService bool, +) { defer slogutil.RecoverAndExit(ctx, l, osutil.ExitCodeFailure) l.InfoContext(ctx, "stopping all tasks") @@ -203,8 +215,16 @@ func finishUpdate(ctx context.Context, l *slog.Logger, execPath string, runningA // instance, because Windows doesn't allow it. // // TODO(a.garipov): Recheck the claim above. - cmd := exec.Command("cmd", "/c", "net stop AdGuardHome & net start AdGuardHome") - err = cmd.Start() + var cmd executil.Command + cmd, err = cmdCons.New(ctx, &executil.CommandConfig{ + Path: "cmd", + Args: []string{"/c", "net stop AdGuardHome & net start AdGuardHome"}, + }) + if err != nil { + panic(fmt.Errorf("constructing cmd: %w", err)) + } + + err = cmd.Start(ctx) if err != nil { panic(fmt.Errorf("restarting service: %w", err)) } @@ -212,12 +232,21 @@ func finishUpdate(ctx context.Context, l *slog.Logger, execPath string, runningA os.Exit(osutil.ExitCodeSuccess) } - cmd := exec.Command(execPath, os.Args[1:]...) l.InfoContext(ctx, "restarting", "exec_path", execPath, "args", os.Args[1:]) - cmd.Stdin = os.Stdin - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - err = cmd.Start() + + var cmd executil.Command + cmd, err = cmdCons.New(ctx, &executil.CommandConfig{ + Path: execPath, + Args: os.Args[1:], + Stdin: os.Stdin, + Stdout: os.Stdout, + Stderr: os.Stderr, + }) + if err != nil { + panic(fmt.Errorf("constructing cmd: %w", err)) + } + + err = cmd.Start(ctx) if err != nil { panic(fmt.Errorf("restarting: %w", err)) } diff --git a/internal/home/home.go b/internal/home/home.go index ac180ff7..f147403b 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -43,6 +43,7 @@ import ( "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/netutil/urlutil" "github.com/AdguardTeam/golibs/osutil" + "github.com/AdguardTeam/golibs/osutil/executil" ) // Global context @@ -589,12 +590,13 @@ func initWeb( disableUpdate := !isUpdateEnabled(ctx, baseLogger, &opts, isCustomUpdURL) webConf := &webConfig{ - updater: upd, - logger: logger, - baseLogger: baseLogger, - confModifier: confModifier, - tlsManager: tlsMgr, - auth: auth, + CommandConstructor: executil.SystemCommandConstructor{}, + updater: upd, + logger: logger, + baseLogger: baseLogger, + confModifier: confModifier, + tlsManager: tlsMgr, + auth: auth, clientFS: clientFS, @@ -821,18 +823,19 @@ func newUpdater( l.DebugContext(ctx, "creating updater", "config_path", confPath) return updater.NewUpdater(&updater.Config{ - Client: conf.Filtering.HTTPClient, - Logger: l, - Version: version.Version(), - Channel: version.Channel(), - GOARCH: runtime.GOARCH, - GOOS: runtime.GOOS, - GOARM: version.GOARM(), - GOMIPS: version.GOMIPS(), - WorkDir: workDir, - ConfName: confPath, - ExecPath: execPath, - VersionCheckURL: versionURL, + Client: conf.Filtering.HTTPClient, + Logger: l, + CommandConstructor: executil.SystemCommandConstructor{}, + Version: version.Version(), + Channel: version.Channel(), + GOARCH: runtime.GOARCH, + GOOS: runtime.GOOS, + GOARM: version.GOARM(), + GOMIPS: version.GOMIPS(), + WorkDir: workDir, + ConfName: confPath, + ExecPath: execPath, + VersionCheckURL: versionURL, }), isCustomURL } diff --git a/internal/home/service.go b/internal/home/service.go index af3d4a70..e63b015f 100644 --- a/internal/home/service.go +++ b/internal/home/service.go @@ -18,6 +18,7 @@ import ( "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil/urlutil" "github.com/AdguardTeam/golibs/osutil" + "github.com/AdguardTeam/golibs/osutil/executil" "github.com/kardianos/service" ) @@ -76,11 +77,11 @@ func (p *program) Stop(_ service.Service) (err error) { // // On OpenWrt, the service utility may not exist. We use our service script // directly in this case. -func svcStatus(s service.Service) (status service.Status, err error) { +func svcStatus(ctx context.Context, s service.Service) (status service.Status, err error) { status, err = s.Status() if err != nil && service.Platform() == "unix-systemv" { var code int - code, err = runInitdCommand("status") + code, err = runInitdCommand(ctx, "status") if err != nil || code != 0 { return service.StatusStopped, nil } @@ -105,7 +106,7 @@ func svcAction(ctx context.Context, l *slog.Logger, s service.Service, action st err = service.Control(s, action) if err != nil && service.Platform() == "unix-systemv" && (action == "start" || action == "stop" || action == "restart") { - _, err = runInitdCommand(action) + _, err = runInitdCommand(ctx, action) } return err @@ -326,7 +327,7 @@ func handleServiceStatusCommand( l *slog.Logger, s service.Service, ) { - status, errSt := svcStatus(s) + status, errSt := svcStatus(ctx, s) if errSt != nil { l.ErrorContext(ctx, "failed to get service status", slogutil.KeyError, errSt) os.Exit(osutil.ExitCodeFailure) @@ -356,7 +357,7 @@ func handleServiceInstallCommand(ctx context.Context, l *slog.Logger, s service. // On OpenWrt it is important to run enable after the service // installation. Otherwise, the service won't start on the system // startup. - _, err = runInitdCommand("enable") + _, err = runInitdCommand(ctx, "enable") if err != nil { l.ErrorContext(ctx, "running init enable", slogutil.KeyError, err) os.Exit(osutil.ExitCodeFailure) @@ -386,7 +387,7 @@ func handleServiceUninstallCommand(ctx context.Context, l *slog.Logger, s servic if aghos.IsOpenWrt() { // On OpenWrt it is important to run disable command first // as it will remove the symlink - _, err := runInitdCommand("disable") + _, err := runInitdCommand(ctx, "disable") if err != nil { l.ErrorContext(ctx, "running init disable", slogutil.KeyError, err) os.Exit(osutil.ExitCodeFailure) @@ -458,10 +459,11 @@ func configureService(c *service.Config) { // runInitdCommand runs init.d service command // returns command code or error if any -func runInitdCommand(action string) (int, error) { +func runInitdCommand(ctx context.Context, action string) (int, error) { confPath := "/etc/init.d/" + serviceName // Pass the script and action as a single string argument. - code, _, err := aghos.RunCommand("sh", "-c", confPath+" "+action) + cmdCons := executil.SystemCommandConstructor{} + code, _, err := aghos.RunCommand(ctx, cmdCons, "sh", "-c", confPath+" "+action) return code, err } diff --git a/internal/home/service_linux.go b/internal/home/service_linux.go index 7f5a0c53..a201877b 100644 --- a/internal/home/service_linux.go +++ b/internal/home/service_linux.go @@ -4,12 +4,14 @@ package home import ( "bufio" + "bytes" + "context" "fmt" "io" - "os/exec" "strings" "github.com/AdguardTeam/AdGuardHome/internal/aghos" + "github.com/AdguardTeam/golibs/osutil/executil" "github.com/kardianos/service" ) @@ -58,6 +60,7 @@ func (sys *sysvSystem) New(i service.Interface, c *service.Config) (s service.Se } return &sysvService{ + cmdCons: executil.SystemCommandConstructor{}, Service: s, name: c.Name, }, nil @@ -66,6 +69,9 @@ func (sys *sysvSystem) New(i service.Interface, c *service.Config) (s service.Se // sysvService is a wrapper for a SysV [service.Service] that supplements the // installation and uninstallation. type sysvService struct { + // cmdCons is used to run external commands. It must not be nil. + cmdCons executil.CommandConstructor + // Service must have an unexported type *service.sysv. service.Service @@ -85,7 +91,8 @@ func (svc *sysvService) Install() (err error) { return err } - _, _, err = aghos.RunCommand("update-rc.d", svc.name, "defaults") + // TODO(s.chzhen): Pass context. + _, _, err = aghos.RunCommand(context.TODO(), svc.cmdCons, "update-rc.d", svc.name, "defaults") // Don't wrap an error since it's informative enough as is. return err @@ -100,7 +107,8 @@ func (svc *sysvService) Uninstall() (err error) { return err } - _, _, err = aghos.RunCommand("update-rc.d", svc.name, "remove") + // TODO(s.chzhen): Pass context. + _, _, err = aghos.RunCommand(context.TODO(), svc.cmdCons, "update-rc.d", svc.name, "remove") // Don't wrap an error since it's informative enough as is. return err @@ -127,6 +135,7 @@ func (sys *systemdSystem) New(i service.Interface, c *service.Config) (s service } return &systemdService{ + cmdCons: executil.SystemCommandConstructor{}, Service: s, unitName: fmt.Sprintf("%s.service", c.Name), }, nil @@ -138,6 +147,9 @@ var _ service.Service = (*systemdService)(nil) // systemdService is a wrapper for a systemd [service.Service] that enriches the // service status information. type systemdService struct { + // cmdCons is used to run external commands. It must not be nil. + cmdCons executil.CommandConstructor + // Service is expected to have an unexported type *service.systemd. service.Service @@ -150,26 +162,37 @@ var _ service.Service = (*systemdService)(nil) // Status implements the [service.Service] interface for *systemdService. func (s *systemdService) Status() (status service.Status, err error) { - cmd := exec.Command("systemctl", "show", s.unitName) - stdout, err := cmd.StdoutPipe() - if err != nil { - return service.StatusUnknown, fmt.Errorf("connecting to command stdout: %w", err) - } + const systemctlCmd = "systemctl" - if err = cmd.Start(); err != nil { - return service.StatusUnknown, fmt.Errorf("start command executing: %w", err) - } + var ( + systemctlArgs = []string{"show", s.unitName} + systemctlStdout bytes.Buffer + ) - status, err = parseSystemctlShow(stdout) - if err != nil { - return service.StatusUnknown, fmt.Errorf("parsing command output: %w", err) - } - - err = cmd.Wait() + // TODO(s.chzhen): Consider streaming the output if needed. Using + // [io.Pipe] here is unnecessary; it complicates lifecycle management + // because the output must be read concurrently, and the PipeWriter must be + // explicitly closed to signal EOF. Since this command's output is small, a + // bytes.Buffer via executil.Run is sufficient. + err = executil.Run( + // TODO(s.chzhen): Pass context. + context.TODO(), + s.cmdCons, + &executil.CommandConfig{ + Path: systemctlCmd, + Args: systemctlArgs, + Stdout: &systemctlStdout, + }, + ) if err != nil { return service.StatusUnknown, fmt.Errorf("executing command: %w", err) } + status, err = parseSystemctlShow(&systemctlStdout) + if err != nil { + return service.StatusUnknown, fmt.Errorf("parsing command output: %w", err) + } + return status, nil } diff --git a/internal/home/service_openbsd.go b/internal/home/service_openbsd.go index b4f29c7a..6d5fecc7 100644 --- a/internal/home/service_openbsd.go +++ b/internal/home/service_openbsd.go @@ -4,6 +4,7 @@ package home import ( "cmp" + "context" "fmt" "os" "os/signal" @@ -15,6 +16,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghos" "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/log" + "github.com/AdguardTeam/golibs/osutil/executil" "github.com/kardianos/service" ) @@ -57,16 +59,18 @@ func (openbsdSystem) Interactive() (ok bool) { // New implements service.System interface for openbsdSystem. func (openbsdSystem) New(i service.Interface, c *service.Config) (s service.Service, err error) { return &openbsdRunComService{ - i: i, - cfg: c, + cmdCons: executil.SystemCommandConstructor{}, + i: i, + cfg: c, }, nil } // openbsdRunComService is the RunCom-based service.Service to be used on the // OpenBSD. type openbsdRunComService struct { - i service.Interface - cfg *service.Config + cmdCons executil.CommandConstructor + i service.Interface + cfg *service.Config } // Platform implements service.Service interface for *openbsdRunComService. @@ -210,8 +214,10 @@ func (s *openbsdRunComService) configureSysStartup(enable bool) (err error) { cmd = "disable" } + // TODO(s.chzhen): Pass context. + ctx := context.TODO() var code int - code, _, err = aghos.RunCommand("rcctl", cmd, s.cfg.Name) + code, _, err = aghos.RunCommand(ctx, s.cmdCons, "rcctl", cmd, s.cfg.Name) if err != nil { return err } else if code != 0 { @@ -312,11 +318,14 @@ func (s *openbsdRunComService) runCom(cmd string) (out string, err error) { return "", err } + // TODO(s.chzhen): Pass context. + ctx := context.TODO() + // TODO(e.burkov): It's possible that os.ErrNotExist is caused by // something different than the service script's non-existence. Keep it // in mind, when replace the aghos.RunCommand. var outData []byte - _, outData, err = aghos.RunCommand(scriptPath, cmd) + _, outData, err = aghos.RunCommand(ctx, s.cmdCons, scriptPath, cmd) if errors.Is(err, os.ErrNotExist) { return "", service.ErrNotInstalled } diff --git a/internal/home/web.go b/internal/home/web.go index 40f59230..276449b3 100644 --- a/internal/home/web.go +++ b/internal/home/web.go @@ -20,6 +20,7 @@ import ( "github.com/AdguardTeam/golibs/netutil/httputil" "github.com/AdguardTeam/golibs/netutil/urlutil" "github.com/AdguardTeam/golibs/osutil" + "github.com/AdguardTeam/golibs/osutil/executil" "github.com/NYTimes/gziphandler" "github.com/quic-go/quic-go/http3" "golang.org/x/net/http2" @@ -39,6 +40,9 @@ const ( ) type webConfig struct { + // CommandConstructor is used to run external commands. It must not be nil. + CommandConstructor executil.CommandConstructor + updater *updater.Updater // logger is a slog logger used in webAPI. It must not be nil. @@ -111,6 +115,9 @@ type webAPI struct { // confModifier is used to update the global configuration. confModifier agh.ConfigModifier + // cmdCons is used to run external commands. + cmdCons executil.CommandConstructor + // TODO(a.garipov): Refactor all these servers. httpServer *http.Server @@ -143,6 +150,7 @@ func newWebAPI(ctx context.Context, conf *webConfig) (w *webAPI) { w = &webAPI{ conf: conf, confModifier: conf.confModifier, + cmdCons: conf.CommandConstructor, logger: conf.logger, baseLogger: conf.baseLogger, tlsManager: conf.tlsManager, diff --git a/internal/updater/check_test.go b/internal/updater/check_test.go index 4da1c876..a94ed5a2 100644 --- a/internal/updater/check_test.go +++ b/internal/updater/check_test.go @@ -10,6 +10,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/AdGuardHome/internal/updater" "github.com/AdguardTeam/AdGuardHome/internal/version" + "github.com/AdguardTeam/golibs/osutil/executil" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -58,13 +59,14 @@ func TestUpdater_VersionInfo(t *testing.T) { fakeURL := srvURL.JoinPath("adguardhome", version.ChannelBeta, "version.json") u := updater.NewUpdater(&updater.Config{ - Client: srv.Client(), - Logger: testLogger, - Version: "v0.103.0-beta.1", - Channel: version.ChannelBeta, - GOARCH: "arm", - GOOS: "linux", - VersionCheckURL: fakeURL, + Client: srv.Client(), + Logger: testLogger, + CommandConstructor: executil.EmptyCommandConstructor{}, + Version: "v0.103.0-beta.1", + Channel: version.ChannelBeta, + GOARCH: "arm", + GOOS: "linux", + VersionCheckURL: fakeURL, }) ctx := testutil.ContextWithTimeout(t, testTimeout) @@ -132,15 +134,16 @@ func TestUpdater_VersionInfo_others(t *testing.T) { for _, tc := range testCases { u := updater.NewUpdater(&updater.Config{ - Client: fakeClient, - Logger: testLogger, - Version: "v0.103.0-beta.1", - Channel: version.ChannelBeta, - GOOS: "linux", - GOARCH: tc.arch, - GOARM: tc.arm, - GOMIPS: tc.mips, - VersionCheckURL: fakeURL, + Client: fakeClient, + Logger: testLogger, + CommandConstructor: executil.EmptyCommandConstructor{}, + Version: "v0.103.0-beta.1", + Channel: version.ChannelBeta, + GOOS: "linux", + GOARCH: tc.arch, + GOARM: tc.arm, + GOMIPS: tc.mips, + VersionCheckURL: fakeURL, }) ctx := testutil.ContextWithTimeout(t, testTimeout) diff --git a/internal/updater/updater.go b/internal/updater/updater.go index 22c44e31..d62b7867 100644 --- a/internal/updater/updater.go +++ b/internal/updater/updater.go @@ -4,6 +4,7 @@ package updater import ( "archive/tar" "archive/zip" + "bytes" "compress/gzip" "context" "fmt" @@ -13,7 +14,6 @@ import ( "net/http" "net/url" "os" - "os/exec" "path" "path/filepath" "strings" @@ -26,6 +26,7 @@ import ( "github.com/AdguardTeam/golibs/ioutil" "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/netutil/urlutil" + "github.com/AdguardTeam/golibs/osutil/executil" ) // Updater is the AdGuard Home updater. @@ -33,6 +34,8 @@ type Updater struct { client *http.Client logger *slog.Logger + cmdCons executil.CommandConstructor + version string channel string goarch string @@ -88,6 +91,9 @@ type Config struct { // be nil, see [DefaultVersionURL]. VersionCheckURL *url.URL + // CommandConstructor is used to run external commands. It must not be nil. + CommandConstructor executil.CommandConstructor + // Version is the current AdGuard Home version. It must not be empty. Version string @@ -129,6 +135,8 @@ func NewUpdater(conf *Config) *Updater { client: conf.Client, logger: conf.Logger, + cmdCons: conf.CommandConstructor, + version: conf.Version, channel: conf.Channel, goarch: conf.GOARCH, @@ -286,11 +294,27 @@ func (u *Updater) check(ctx context.Context) (err error) { "%s" + "end of the output" - cmd := exec.Command(u.updateExeName, "--check-config") - out, err := cmd.CombinedOutput() - code := cmd.ProcessState.ExitCode() - if err != nil || code != 0 { - return fmt.Errorf(format, err, code, out) + var ( + args = []string{"--check-config"} + buf bytes.Buffer + ) + + u.logger.DebugContext(ctx, "executing", "cmd", u.updateExeName, "args", args) + + err = executil.Run( + ctx, + u.cmdCons, + &executil.CommandConfig{ + Path: u.updateExeName, + Args: args, + Stdout: &buf, + Stderr: &buf, + }, + ) + if err != nil { + code, _ := executil.ExitCodeFromError(err) + + return fmt.Errorf(format, err, code, buf.Bytes()) } return nil diff --git a/internal/updater/updater_internal_test.go b/internal/updater/updater_internal_test.go index 3d96a8ff..195d86b2 100644 --- a/internal/updater/updater_internal_test.go +++ b/internal/updater/updater_internal_test.go @@ -10,6 +10,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/aghtest" "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/osutil/executil" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -58,13 +59,14 @@ func TestUpdater_internal(t *testing.T) { fakeURL = fakeURL.JoinPath(tc.archiveName) u := NewUpdater(&Config{ - Client: fakeClient, - Logger: slogutil.NewDiscardLogger(), - GOOS: tc.os, - Version: "v0.103.0", - ExecPath: exePath, - WorkDir: wd, - ConfName: yamlPath, + Client: fakeClient, + Logger: slogutil.NewDiscardLogger(), + CommandConstructor: executil.EmptyCommandConstructor{}, + GOOS: tc.os, + Version: "v0.103.0", + ExecPath: exePath, + WorkDir: wd, + ConfName: yamlPath, // TODO(e.burkov): Rewrite the test to use a fake version check // URL with a fake URLs for the package files. VersionCheckURL: &url.URL{}, diff --git a/internal/updater/updater_test.go b/internal/updater/updater_test.go index dfef0b10..f52fad8d 100644 --- a/internal/updater/updater_test.go +++ b/internal/updater/updater_test.go @@ -15,6 +15,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/updater" "github.com/AdguardTeam/AdGuardHome/internal/version" "github.com/AdguardTeam/golibs/logutil/slogutil" + "github.com/AdguardTeam/golibs/osutil/executil" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -76,15 +77,16 @@ func TestUpdater_Update(t *testing.T) { require.NoError(t, err) u := updater.NewUpdater(&updater.Config{ - Client: srv.Client(), - Logger: testLogger, - GOARCH: "amd64", - GOOS: "linux", - Version: "v0.103.0", - ConfName: yamlPath, - WorkDir: wd, - ExecPath: exePath, - VersionCheckURL: versionCheckURL, + Client: srv.Client(), + Logger: testLogger, + CommandConstructor: executil.EmptyCommandConstructor{}, + GOARCH: "amd64", + GOOS: "linux", + Version: "v0.103.0", + ConfName: yamlPath, + WorkDir: wd, + ExecPath: exePath, + VersionCheckURL: versionCheckURL, }) ctx := testutil.ContextWithTimeout(t, testTimeout) diff --git a/scripts/translations/main.go b/scripts/translations/main.go index e8dc0473..f21f5af0 100644 --- a/scripts/translations/main.go +++ b/scripts/translations/main.go @@ -13,7 +13,6 @@ import ( "maps" "net/url" "os" - "os/exec" "path/filepath" "slices" "strings" @@ -23,6 +22,7 @@ import ( "github.com/AdguardTeam/golibs/errors" "github.com/AdguardTeam/golibs/logutil/slogutil" "github.com/AdguardTeam/golibs/osutil" + "github.com/AdguardTeam/golibs/osutil/executil" ) const ( @@ -96,7 +96,7 @@ func main() { errors.Check(cli.upload()) case "auto-add": - err := autoAdd(conf.LocalizableFiles[0]) + err := autoAdd(ctx, l, conf.LocalizableFiles[0]) errors.Check(err) default: usage("unknown command") @@ -395,10 +395,12 @@ func findUnused(fileNames []string, loc locales) (err error) { // autoAdd adds locales with additions to the git and restores locales with // deletions. -func autoAdd(basePath string) (err error) { +func autoAdd(ctx context.Context, l *slog.Logger, basePath string) (err error) { defer func() { err = errors.Annotate(err, "auto add: %w") }() - adds, dels, err := changedLocales() + cmdCons := executil.SystemCommandConstructor{} + + adds, dels, err := changedLocales(ctx, l, cmdCons) if err != nil { // Don't wrap the error since it's informative enough as is. return err @@ -408,13 +410,13 @@ func autoAdd(basePath string) (err error) { return errors.Error("base locale contains deletions") } - err = handleAdds(adds) + err = handleAdds(ctx, l, cmdCons, adds) if err != nil { // Don't wrap the error since it's informative enough as is. return nil } - err = handleDels(dels) + err = handleDels(ctx, l, cmdCons, dels) if err != nil { // Don't wrap the error since it's informative enough as is. return nil @@ -423,14 +425,24 @@ func autoAdd(basePath string) (err error) { return nil } +// gitCmd is the shell command for Git. +const gitCmd = "git" + // handleAdds adds locales with additions to the git. -func handleAdds(locales []string) (err error) { +func handleAdds( + ctx context.Context, + l *slog.Logger, + cmdCons executil.CommandConstructor, + locales []string, +) (err error) { if len(locales) == 0 { return nil } - args := append([]string{"add"}, locales...) - code, out, err := aghos.RunCommand("git", args...) + gitArgs := append([]string{"add"}, locales...) + l.DebugContext(ctx, "executing", "cmd", gitCmd, "args", gitArgs) + + code, out, err := aghos.RunCommand(ctx, cmdCons, gitCmd, gitArgs...) if err != nil || code != 0 { return fmt.Errorf("git add exited with code %d output %q: %w", code, out, err) @@ -440,13 +452,20 @@ func handleAdds(locales []string) (err error) { } // handleDels restores locales with deletions. -func handleDels(locales []string) (err error) { +func handleDels( + ctx context.Context, + l *slog.Logger, + cmdCons executil.CommandConstructor, + locales []string, +) (err error) { if len(locales) == 0 { return nil } - args := append([]string{"restore"}, locales...) - code, out, err := aghos.RunCommand("git", args...) + gitArgs := append([]string{"restore"}, locales...) + l.DebugContext(ctx, "executing", "cmd", gitCmd, "args", gitArgs) + + code, out, err := aghos.RunCommand(ctx, cmdCons, gitCmd, gitArgs...) if err != nil || code != 0 { return fmt.Errorf("git restore exited with code %d output %q: %w", code, out, err) @@ -458,22 +477,32 @@ func handleDels(locales []string) (err error) { // changedLocales returns cleaned paths of locales with changes or error. adds // is the list of locales with only additions. dels is the list of locales // with only deletions. -func changedLocales() (adds, dels []string, err error) { +func changedLocales( + ctx context.Context, + l *slog.Logger, + cmdCons executil.CommandConstructor, +) (adds, dels []string, err error) { defer func() { err = errors.Annotate(err, "getting changes: %w") }() - cmd := exec.Command("git", "diff", "--numstat", localesDir) + gitArgs := []string{"diff", "--numstat", localesDir} + l.DebugContext(ctx, "executing", "cmd", gitCmd, "args", gitArgs) - stdout, err := cmd.StdoutPipe() + // TODO(s.chzhen): Consider streaming the output if needed. Using + // [io.Pipe] here is unnecessary; it complicates lifecycle management + // because the output must be read concurrently, and the PipeWriter must be + // explicitly closed to signal EOF. Since this command's output is small, a + // bytes.Buffer via executil.Run is sufficient. + var out bytes.Buffer + err = executil.Run(ctx, cmdCons, &executil.CommandConfig{ + Path: gitCmd, + Args: gitArgs, + Stdout: &out, + }) if err != nil { - return nil, nil, fmt.Errorf("piping: %w", err) + return nil, nil, fmt.Errorf("executing cmd: %w", err) } - err = cmd.Start() - if err != nil { - return nil, nil, fmt.Errorf("starting: %w", err) - } - - scanner := bufio.NewScanner(stdout) + scanner := bufio.NewScanner(&out) for scanner.Scan() { line := scanner.Text() @@ -497,10 +526,5 @@ func changedLocales() (adds, dels []string, err error) { return nil, nil, fmt.Errorf("scanning: %w", err) } - err = cmd.Wait() - if err != nil { - return nil, nil, fmt.Errorf("waiting: %w", err) - } - return adds, dels, nil }