Initial version with improved server concurrency and better client retry

handling

align client service param names with master
This commit is contained in:
morestatic 2023-02-15 14:28:26 +07:00
parent 0e8ea01a65
commit 5bfab9fac7
60 changed files with 2155 additions and 1201 deletions

View File

@ -12,7 +12,7 @@ build-debug:
$(foreach BINARY,$(BINARIES),go build -race -gcflags "all=-N -l" -o $(BINARY) -v ./cmd/$(BINARY)/...;)
test:
go test -v ./...
go test -race -v ./...
bind-data:
cd db/migration/jobs/sql/ && go-bindata -o ../bindata.go -pkg jobs ./...

View File

@ -34,7 +34,16 @@ import (
"github.com/cloudradar-monitoring/rport/share/models"
)
const ConnectionTimeout = 10 * time.Second
const DialTimeout = 5 * 60 * time.Second
const AuthTimeout = 30 * time.Second
const MinConnectionBackoffWaitTime = 5 * time.Second
const MaxConnectionBackoffWaitTime = 10 * 60 * time.Second
const ServerReconnectRequestBackoffTime = 3 * 60 * time.Second
const InitialConnectionRequestSendDelayJitterMilliseconds = 10000
const SendRequestTimeout = 30 * time.Second
const MinSendRequestRetryWaitTime = 1 * time.Second
const BackoffOnServerTimeoutMaxDuration = 1 * time.Second
const MaxKeepAliveJitterMilliseconds = 5000
// Client represents a client instance
type Client struct {
@ -66,13 +75,14 @@ func NewClient(config *ClientConfigHolder, filesAPI files.FileAPI) (*Client, err
if err != nil {
return nil, fmt.Errorf("failed to create initial session id: %s", err)
}
cmdExec := system.NewCmdExecutor(logger.NewLogger("cmd executor", config.Logging.LogOutput, config.Logging.LogLevel))
logger := logger.NewLogger("client", config.Logging.LogOutput, config.Logging.LogLevel)
watchdog, err := NewWatchdog(config.Connection.WatchdogIntegration, config.Client.DataDir, logger)
if err != nil {
return nil, fmt.Errorf("failed to create watchdog: %s", err)
}
logger.Infof("Client started with sessionID %s", sessionID)
systemInfo := system.NewSystemInfo(cmdExec)
client := &Client{
SessionID: sessionID,
@ -93,9 +103,10 @@ func NewClient(config *ClientConfigHolder, filesAPI files.FileAPI) (*Client, err
Auth: []ssh.AuthMethod{ssh.Password(config.Client.AuthPass)},
ClientVersion: "SSH-" + chshare.ProtocolVersion + "-client",
HostKeyCallback: client.verifyServer,
Timeout: 30 * time.Second,
Timeout: AuthTimeout,
}
logger.Infof("New client instance with sessionID %s", sessionID)
return client, nil
}
@ -127,8 +138,9 @@ func (c *Client) Start(ctx context.Context) error {
c.Infof("Keepalive job (client to server ping) started with interval %s", c.configHolder.Connection.KeepAlive)
go c.keepAliveLoop()
}
//connection loop
go c.connectionLoop(ctx)
go c.connectionLoop(ctx, true)
c.updates.Start(ctx)
@ -137,19 +149,28 @@ func (c *Client) Start(ctx context.Context) error {
func (c *Client) keepAliveLoop() {
for c.running {
time.Sleep(c.configHolder.Connection.KeepAlive)
time.Sleep(c.configHolder.Connection.KeepAlive + (time.Duration(rand.Intn(MaxKeepAliveJitterMilliseconds)))*time.Millisecond)
c.mu.RLock()
conn := c.sshConn
c.mu.RUnlock()
if conn != nil {
ok, _, rtt, err := comm.PingConnectionWithTimeout(conn, c.configHolder.Connection.KeepAliveTimeout)
if err != nil || !ok {
res, err := comm.WithRetry(func() (res *sendResponse, err error) {
ok, _, rtt, err := comm.PingConnectionWithTimeout(conn, c.configHolder.Connection.KeepAliveTimeout, c.Logger)
return &sendResponse{
replyOk: ok,
rtt: rtt,
respBytes: nil,
}, err
}, canRetryFn, MinSendRequestRetryWaitTime, "ping", c.Logger)
if err != nil || !res.replyOk {
c.Errorf("Failed to send keepalive (client to server ping): %s", err)
c.sshConn.Close()
} else {
msg := fmt.Sprintf("ping to %s succeeded within %s", conn.RemoteAddr(), rtt)
msg := fmt.Sprintf("ping to %s succeeded within %s", conn.RemoteAddr(), res.rtt)
c.Debugf(msg)
c.watchdog.Ping(WatchdogStateConnected, msg)
}
@ -157,32 +178,28 @@ func (c *Client) keepAliveLoop() {
}
}
func (c *Client) connectionLoop(ctx context.Context) {
func (c *Client) connectionLoop(ctx context.Context, withInitialSendRequestDelay bool) {
//connection loop!
var connerr error
switchbackChan := make(chan *sshClientConn, 1)
b := &backoff.Backoff{Max: c.configHolder.Connection.MaxRetryInterval}
backoff := &backoff.Backoff{
Min: MinConnectionBackoffWaitTime + time.Duration(rand.Intn(60)),
Max: MaxConnectionBackoffWaitTime,
Jitter: true,
}
for c.running {
if connerr != nil {
attempt := int(b.Attempt())
var d = b.Duration()
c.showConnectionError(connerr, attempt)
if c.configHolder.Connection.MaxRetryCount >= 0 && attempt >= c.configHolder.Connection.MaxRetryCount {
break // Stop trying to connect if the user has set a max retry limit
stopRetrying := c.handleConnectionError(backoff, connerr)
if stopRetrying {
break
}
if _, ok := connerr.(comm.TimeoutError); ok {
// Timeout means the server is available. No need to wait up to 5 min to try again.
rand.Seed(time.Now().UnixNano())
d = time.Duration(rand.Intn(20)) * time.Second
b.Reset()
}
msg := fmt.Sprintf("Retrying in %s...", d)
c.Infof(msg)
c.watchdog.Ping(WatchdogStateReconnecting, msg)
connerr = nil
chshare.SleepSignal(d)
}
c.Logger.Debugf("conn loop attempt = %d", int(backoff.Attempt()))
// make the connection attempt
var sshConn *sshClientConn
var isPrimary bool
select {
@ -204,46 +221,36 @@ func (c *Client) connectionLoop(ctx context.Context) {
switchbackCtx, cancelSwitchback := context.WithCancel(ctx)
if !isPrimary {
go func() {
for {
switchbackTimer := time.NewTimer(c.configHolder.Client.ServerSwitchbackInterval)
select {
case <-switchbackCtx.Done():
switchbackTimer.Stop()
return
case <-switchbackTimer.C:
switchbackConn, err := c.connect(c.configHolder.Client.Server)
if err != nil {
c.Errorf("Switchback failed: %v", err.Error())
continue
}
c.Infof("Connected to main server, switching back.")
switchbackChan <- switchbackConn
sshConn.Connection.Close()
return
}
}
}()
go c.handleServerSwitchBack(switchbackCtx, switchbackChan, sshConn)
}
err := c.sendConnectionRequest(ctx, sshConn.Connection)
if withInitialSendRequestDelay {
delay := time.Duration(rand.Intn(InitialConnectionRequestSendDelayJitterMilliseconds)) * time.Millisecond
c.Logger.Debugf("waiting for %d milliseconds before sending connection request", delay/time.Millisecond)
time.Sleep(delay)
}
err := c.sendConnectionRequest(ctx, sshConn.Connection, MinSendRequestRetryWaitTime)
if err != nil {
// Connection request has failed, we try again
cancelSwitchback()
connerr = err
continue
}
// Connection request has succeeded
b.Reset()
// Connection request has succeeded
backoff.Reset()
// Hand over the open SSH connection to the client
c.mu.Lock()
c.sshConn = sshConn.Connection // Hand over the open SSH connection to the client
c.sshConn = sshConn.Connection
c.mu.Unlock()
c.updates.SetConn(sshConn.Connection)
c.monitor.SetConn(sshConn.Connection)
err = sshConn.Connection.Wait() // Block aka wait until the connection is closed
// now wait with the client handling SSH Requests and Channel Connections
err = sshConn.Connection.Wait()
c.mu.Lock()
//disconnected
@ -262,9 +269,57 @@ func (c *Client) connectionLoop(ctx context.Context) {
c.Infof("Disconnected\n")
}
close(c.runningc)
}
func (c *Client) handleConnectionError(backoff *backoff.Backoff, connerr error) (stopRetrying bool) {
attempt := int(backoff.Attempt())
c.showConnectionError(connerr, attempt)
// check if the user has set a max retry limit
if c.configHolder.Connection.MaxRetryCount >= 0 && attempt >= c.configHolder.Connection.MaxRetryCount {
return true // if so, stop trying
}
var d = backoff.Duration()
if _, ok := connerr.(comm.TimeoutError); ok {
// Timeout means the server isn't offline, so reset the backoff and use an initial short retry duration
backoff.Reset()
rand.Seed(time.Now().UnixNano())
d = time.Duration(rand.Intn(int(backoff.Attempt()))) * BackoffOnServerTimeoutMaxDuration
}
msg := fmt.Sprintf("Retrying in %s...", d)
c.Infof(msg)
// TODO: (rs): what is this watchdog ping?
c.watchdog.Ping(WatchdogStateReconnecting, msg)
chshare.SleepSignal(d)
return false
}
func (c *Client) handleServerSwitchBack(switchbackCtx context.Context, switchbackChan chan *sshClientConn, sshConn *sshClientConn) {
for {
switchbackTimer := time.NewTimer(c.configHolder.Client.ServerSwitchbackInterval)
select {
case <-switchbackCtx.Done():
switchbackTimer.Stop()
return
case <-switchbackTimer.C:
switchbackConn, err := c.connect(c.configHolder.Client.Server)
if err != nil {
c.Errorf("Switchback failed: %v", err.Error())
continue
}
c.Infof("Connected to main server, switching back.")
switchbackChan <- switchbackConn
sshConn.Connection.Close()
return
}
}
}
type sshClientConn struct {
Connection ssh.Conn
Channels <-chan ssh.NewChannel
@ -290,54 +345,24 @@ func (c *Client) connect(server string) (*sshClientConn, error) {
}
c.Infof("Trying to connect to %s%s ...\n", server, via)
netDialer := &net.Dialer{}
d := websocket.Dialer{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
HandshakeTimeout: 45 * time.Second,
Subprotocols: []string{chshare.ProtocolVersion},
NetDialContext: netDialer.DialContext,
d, netDialer, err := c.setupDialer()
if err != nil {
return nil, err
}
if c.configHolder.Client.BindInterface != "" {
laddr, err := c.localAddrForInterface(c.configHolder.Client.BindInterface)
//optionally proxy
if c.configHolder.Client.ProxyURL != nil {
err := c.addDialerProxySupport(d, netDialer)
if err != nil {
return nil, err
}
netDialer.LocalAddr = laddr
}
//optionally proxy
if c.configHolder.Client.ProxyURL != nil {
if strings.HasPrefix(c.configHolder.Client.ProxyURL.Scheme, "socks") {
// SOCKS5 proxy
if c.configHolder.Client.ProxyURL.Scheme != "socks" && c.configHolder.Client.ProxyURL.Scheme != "socks5h" {
return nil, fmt.Errorf(
"unsupported socks proxy type: %s:// (only socks5h:// or socks:// is supported)",
c.configHolder.Client.ProxyURL.Scheme)
}
var auth *proxy.Auth
if c.configHolder.Client.ProxyURL.User != nil {
pass, _ := c.configHolder.Client.ProxyURL.User.Password()
auth = &proxy.Auth{
User: c.configHolder.Client.ProxyURL.User.Username(),
Password: pass,
}
}
socksDialer, err := proxy.SOCKS5("tcp", c.configHolder.Client.ProxyURL.Host, auth, netDialer)
if err != nil {
return nil, err
}
d.NetDialContext = socksDialer.(proxy.ContextDialer).DialContext
} else {
// CONNECT proxy
d.Proxy = func(*http.Request) (*url.URL, error) {
return c.configHolder.Client.ProxyURL, nil
}
}
}
wsConn, _, err := d.Dial(server, c.configHolder.Connection.HTTPHeaders)
if err != nil {
return nil, ConnectionErrorHints(server, c.Logger, err)
}
conn := chshare.NewWebSocketConn(wsConn)
// perform SSH handshake on net.Conn
c.Debugf("Handshaking...")
@ -349,6 +374,7 @@ func (c *Client) connect(server string) (*sshClientConn, error) {
}
return nil, err
}
return &sshClientConn{
Connection: sshConn,
Requests: reqs,
@ -356,7 +382,69 @@ func (c *Client) connect(server string) (*sshClientConn, error) {
}, nil
}
func (c *Client) sendConnectionRequest(ctx context.Context, sshConn ssh.Conn) error {
func (c *Client) setupDialer() (d *websocket.Dialer, netDialer *net.Dialer, err error) {
netDialer = &net.Dialer{}
d = &websocket.Dialer{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
HandshakeTimeout: DialTimeout,
Subprotocols: []string{chshare.ProtocolVersion},
NetDialContext: netDialer.DialContext,
}
if c.configHolder.Client.BindInterface != "" {
laddr, err := c.localAddrForInterface(c.configHolder.Client.BindInterface)
if err != nil {
return nil, nil, err
}
netDialer.LocalAddr = laddr
}
return d, netDialer, err
}
func (c *Client) addDialerProxySupport(d *websocket.Dialer, netDialer *net.Dialer) (err error) {
if strings.HasPrefix(c.configHolder.Client.ProxyURL.Scheme, "socks") {
// SOCKS5 proxy
if c.configHolder.Client.ProxyURL.Scheme != "socks" && c.configHolder.Client.ProxyURL.Scheme != "socks5h" {
return fmt.Errorf(
"unsupported socks proxy type: %s:// (only socks5h:// or socks:// is supported)",
c.configHolder.Client.ProxyURL.Scheme)
}
var auth *proxy.Auth
if c.configHolder.Client.ProxyURL.User != nil {
pass, _ := c.configHolder.Client.ProxyURL.User.Password()
auth = &proxy.Auth{
User: c.configHolder.Client.ProxyURL.User.Username(),
Password: pass,
}
}
socksDialer, err := proxy.SOCKS5("tcp", c.configHolder.Client.ProxyURL.Host, auth, netDialer)
if err != nil {
return err
}
d.NetDialContext = socksDialer.(proxy.ContextDialer).DialContext
} else {
// CONNECT proxy
d.Proxy = func(*http.Request) (*url.URL, error) {
return c.configHolder.Client.ProxyURL, nil
}
}
return nil
}
type sendResponse struct {
replyOk bool
respBytes []byte
rtt time.Duration
}
func canRetryFn(err error) (can bool) {
// if a timeout err, retry on the existing connection
return strings.Contains(err.Error(), "timeout")
}
func (c *Client) sendConnectionRequest(ctx context.Context, sshConn ssh.Conn, minRetryWaitDuration time.Duration) error {
connReq, err := c.connectionRequest(ctx)
if err != nil {
return err
@ -366,19 +454,38 @@ func (c *Client) sendConnectionRequest(ctx context.Context, sshConn ssh.Conn) er
if err != nil {
return fmt.Errorf("could not encode connection request: %v", err)
}
c.Infof("Sending connection request.")
c.Debugf("Sending connection request with client details %s", string(req))
t0 := time.Now()
replyOk, respBytes, err := comm.SendRequestWithTimeout(sshConn, "new_connection", true, req, ConnectionTimeout)
res, err := comm.WithRetry(func() (res *sendResponse, err error) {
replyOk, respBytes, err := comm.SendRequestWithTimeout(sshConn, "new_connection", true, req, SendRequestTimeout, c.Logger)
return &sendResponse{
replyOk: replyOk,
respBytes: respBytes,
}, err
}, canRetryFn, minRetryWaitDuration, "Connection Request", c.Logger)
if err != nil {
if err2 := sshConn.Close(); err2 != nil {
c.Errorf("Failed to close connection: %s", err2)
c.Errorf("connection request err = %v", err)
if closeErr := sshConn.Close(); closeErr != nil {
c.Errorf("Failed to close connection: %s", closeErr)
}
reconnect := strings.Contains(err.Error(), "reconnect")
if reconnect {
reconnectDelay := ServerReconnectRequestBackoffTime + (time.Duration(rand.Intn(30)) * time.Second)
c.Debugf("waiting %d seconds before reconnect", reconnectDelay/time.Second)
// this probably means the server is too busy for us. wait quite a while
// before returning to the conn loop.
time.Sleep(reconnectDelay)
}
return err
}
c.Debugf("Connection request has been answered successfully within %s.", time.Since(t0))
if !replyOk {
msg := string(respBytes)
c.Debugf("Connection request has been answered within %s.", time.Since(t0))
if !res.replyOk {
msg := string(res.respBytes)
// if replied with client credentials already used - retry
if strings.Contains(msg, "client is already connected:") {
@ -390,14 +497,17 @@ func (c *Client) sendConnectionRequest(ctx context.Context, sshConn ssh.Conn) er
return errors.New(msg)
}
var remotes []*models.Remote
err = json.Unmarshal(respBytes, &remotes)
err = json.Unmarshal(res.respBytes, &remotes)
if err != nil {
return fmt.Errorf("can't decode reply payload: %s", err)
}
msg := fmt.Sprintf("Connected to %s within %s", sshConn.RemoteAddr().String(), time.Since(t0))
c.watchdog.Ping(WatchdogStateConnected, msg)
c.Infof(msg)
for _, r := range remotes {
c.Infof("New tunnel: %s", r.String())
@ -453,12 +563,12 @@ func (c *Client) handleSSHRequests(ctx context.Context, sshConn *sshClientConn)
sshConn.Connection,
system.SysUserProvider{},
)
resp, err = uploadManager.HandleUploadRequest(r.Payload)
case comm.RequestTypeCheckTunnelAllowed:
resp, err = c.checkTunnelAllowed(r.Payload)
case comm.RequestTypePing:
_ = r.Reply(true, nil)
continue
default:
c.Debugf("Unknown request: %q", r.Type)
comm.ReplyError(c.Logger, r, errors.New("unknown request"))

View File

@ -428,11 +428,11 @@ func (m *mockServer) IsConnected() bool {
}
func (m *mockServer) WaitForStatus(isConnected bool) error {
for i := 0; i < 1500; i++ {
for i := 0; i < 60; i++ {
if m.IsConnected() == isConnected {
return nil
}
time.Sleep(2 * time.Millisecond)
time.Sleep(200 * time.Millisecond)
}
return fmt.Errorf("timeout waiting for isConnected=%v", isConnected)
}
@ -492,7 +492,7 @@ func TestConnectionLoop(t *testing.T) {
c, err := NewClient(&config, test.NewFileAPIMock())
require.NoError(t, err)
go c.connectionLoop(context.Background())
go c.connectionLoop(context.Background(), false)
// connects to main server successfully
assert.NoError(t, mainServer.WaitForStatus(true))

View File

@ -27,7 +27,7 @@ func ConnectionErrorHints(server string, logger *logger.Logger, err error) error
if allHeaders != "" {
logger.Debugf("headers collected while detecting proxy: %s", allHeaders)
}
return fmt.Errorf("%s - Check your client credentials AND check for tranparent proxies", err)
return fmt.Errorf("%s - Server maybe busy. Also check your client credentials AND check for tranparent proxies", err)
default:
return err
}

View File

@ -26,21 +26,22 @@ import (
)
const (
DefaultKeepDisconnectedClients = time.Hour
DefaultPurgeDisconnectedClientsInterval = 1 * time.Minute
DefaultCheckClientsConnectionInterval = 5 * time.Minute
DefaultCheckClientsConnectionTimeout = 30 * time.Second
DefaultMaxRequestBytes = 10 * 1024 // 10 KB
DefaultMaxRequestBytesClient = 512 * 1024 // 512KB
DefaultMaxFilePushBytes = int64(10 << 20) // 10M
DefaultCheckPortTimeout = 2 * time.Second
DefaultUsedPorts = "20000-30000"
DefaultExcludedPorts = "1-1024"
DefaultServerAddress = "0.0.0.0:8080"
DefaultLogLevel = "info"
DefaultRunRemoteCmdTimeoutSec = 60
DefaultMonitoringDataStorageDuration = "7d"
DefaultPairingURL = "https://pairing.rport.io"
DefaultMaxConcurrentSSHConnectionHandshakes = 4
DefaultKeepDisconnectedClients = time.Hour
DefaultPurgeDisconnectedClientsInterval = 1 * time.Minute
DefaultCheckClientsConnectionInterval = 5 * time.Minute
DefaultCheckClientsConnectionTimeout = 30 * time.Second
DefaultMaxRequestBytes = 10 * 1024 // 10 KB
DefaultMaxRequestBytesClient = 512 * 1024 // 512KB
DefaultMaxFilePushBytes = int64(10 << 20) // 10M
DefaultCheckPortTimeout = 2 * time.Second
DefaultUsedPorts = "20000-30000"
DefaultExcludedPorts = "1-1024"
DefaultServerAddress = "0.0.0.0:8080"
DefaultLogLevel = "info"
DefaultRunRemoteCmdTimeoutSec = 60
DefaultMonitoringDataStorageDuration = "7d"
DefaultPairingURL = "https://pairing.rport.io"
)
var serverHelp = `
@ -314,6 +315,7 @@ func init() {
viperCfg.SetDefault("server.data_dir", chserver.DefaultDataDirectory)
viperCfg.SetDefault("server.sqlite_wal", true)
viperCfg.SetDefault("server.keep_disconnected_clients", DefaultKeepDisconnectedClients)
viperCfg.SetDefault("server.max_concurrent_ssh_handshakes", DefaultMaxConcurrentSSHConnectionHandshakes)
viperCfg.SetDefault("server.purge_disconnected_clients_interval", DefaultPurgeDisconnectedClientsInterval)
viperCfg.SetDefault("server.check_clients_connection_interval", DefaultCheckClientsConnectionInterval)
viperCfg.SetDefault("server.check_clients_connection_timeout", DefaultCheckClientsConnectionTimeout)

View File

@ -19,10 +19,12 @@ const (
WALEnabled = "_journal_mode=WAL"
defaultDelayBetweenAttempts = 200 * time.Millisecond
DefaultMaxAttempts = 5
DefaultMaxOpenConnections = 1
)
type DataSourceOptions struct {
WALEnabled bool
WALEnabled bool
MaxOpenConnections int
}
// New returns a new sqlite DB instance with migrated DB scheme to the latest version.
@ -42,7 +44,12 @@ func New(dataSourceName string, assetNames []string, asset func(name string) ([]
}
}
db.SetMaxOpenConns(1)
maxConns := dataSourceOptions.MaxOpenConnections
if maxConns == 0 {
maxConns = DefaultMaxOpenConnections
}
db.SetMaxOpenConns(maxConns)
s := bindata.Resource(assetNames,
func(name string) ([]byte, error) {
@ -70,27 +77,29 @@ func New(dataSourceName string, assetNames []string, asset func(name string) ([]
return db, nil
}
// TODO: (rs): we've moved to use single db connections. with potentially slower access to the sqlite
// volumes it seems there's too much concurrent contention for the dbs, so there's less need for this fn.
// not removing yet but check again in approx 6 months (from dec 22) and remove if no longer required.
func WithRetryWhenBusy[R any](retryAble func() (result R, err error), label string, l *logger.Logger) (result R, err error) {
func WithRetryWhenBusy[R any](retryAbleFn func() (result R, err error), label string, l *logger.Logger) (result R, err error) {
for r := 0; r < DefaultMaxAttempts; r++ {
result, err = retryAble()
if err != nil {
attempt := r + 1
if attempt > 1 && err != nil {
sqlErr, ok := err.(sql.Error)
if ok && sqlErr.Code == sql.ErrBusy {
l.Debugf("%s: attempt %d: source err = %+v\n", label, r+1, err)
l.Debugf("%s: attempt %d: source err = %+v\n", label, attempt, err)
jitter := time.Duration((rand.Intn(100))) * time.Millisecond
time.Sleep(defaultDelayBetweenAttempts + jitter)
continue
} else {
// a different error from database busy, so fail immediately
l.Debugf("%s: attempt %d: non-retryable err = %+v\n", label, attempt, err)
return result, err
}
// non retryable err
return result, sqlErr
}
// success
return result, nil
// make an attempt to complete the retryable fn
result, err = retryAbleFn()
// if no error then return immediately if with success result
if err == nil {
return result, nil
}
}
l.Debugf("%s: failed after max attempts: err = %+v\n", label, err)
l.Errorf("%s: failed after max attempts: err = %+v\n", label, err)
return result, err
}

View File

@ -124,6 +124,12 @@
## This is a performance enhancement. Do not turn off, unless you have good reasons.
#sqlite_wal = true
## Limits the number of ssh handshakes that the server will handle concurrently. Too many in progress SSH handshakes
## together will slow down the server's ability to perform other work. This can particularly impact server startup
## when many clients connect at similar times. A very slow server can also result in strange client reconnect issues.
## Default is 4.
#max_concurrent_ssh_handshakes = 4
## An optional parameter to define whether disconnected clients get purged from the database.
## By default, disconnected clients are NOT purged.
#purge_disconnected_clients = false

View File

@ -377,8 +377,8 @@ func TestHandleDeleteClientAuth(t *testing.T) {
mockConn := &mockConnection{}
initState := []*clientsauth.ClientAuth{cl1, cl2, cl3}
c1 := clients.New(t).ClientAuthID(cl1.ID).Connection(mockConn).Build()
c2 := clients.New(t).ClientAuthID(cl1.ID).DisconnectedDuration(5 * time.Minute).Build()
c1 := clients.New(t).ClientAuthID(cl1.ID).Connection(mockConn).Logger(testLog).Build()
c2 := clients.New(t).ClientAuthID(cl1.ID).DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
testCases := []struct {
descr string // Test Case Description
@ -557,8 +557,7 @@ func TestHandleDeleteClientAuth(t *testing.T) {
require.NoError(err)
assert.ElementsMatch(tc.wantClientsAuth, clients, "clients auth not as expected")
assert.Equal(tc.wantClosedConn, mockConn.closed)
allClients, err := al.clientService.GetAll()
require.NoError(err)
allClients := al.clientService.GetAll()
assert.ElementsMatch(tc.wantClients, allClients)
})
}

View File

@ -6,7 +6,6 @@ import (
"net"
"net/http"
"strconv"
"time"
"github.com/gorilla/mux"
"golang.org/x/crypto/ssh"
@ -19,133 +18,11 @@ import (
"github.com/cloudradar-monitoring/rport/server/ports"
"github.com/cloudradar-monitoring/rport/server/routes"
"github.com/cloudradar-monitoring/rport/server/validation"
"github.com/cloudradar-monitoring/rport/share/clientconfig"
"github.com/cloudradar-monitoring/rport/share/comm"
"github.com/cloudradar-monitoring/rport/share/models"
"github.com/cloudradar-monitoring/rport/share/query"
)
type ClientPayload struct {
ID *string `json:"id,omitempty"`
Name *string `json:"name,omitempty"`
Address *string `json:"address,omitempty"`
Hostname *string `json:"hostname,omitempty"`
OS *string `json:"os,omitempty"`
OSFullName *string `json:"os_full_name,omitempty"`
OSVersion *string `json:"os_version,omitempty"`
OSArch *string `json:"os_arch,omitempty"`
OSFamily *string `json:"os_family,omitempty"`
OSKernel *string `json:"os_kernel,omitempty"`
OSVirtualizationSystem *string `json:"os_virtualization_system,omitempty"`
OSVirtualizationRole *string `json:"os_virtualization_role,omitempty"`
NumCPUs *int `json:"num_cpus,omitempty"`
CPUFamily *string `json:"cpu_family,omitempty"`
CPUModel *string `json:"cpu_model,omitempty"`
CPUModelName *string `json:"cpu_model_name,omitempty"`
CPUVendor *string `json:"cpu_vendor,omitempty"`
MemoryTotal *uint64 `json:"mem_total,omitempty"`
Timezone *string `json:"timezone,omitempty"`
ClientAuthID *string `json:"client_auth_id,omitempty"`
Version *string `json:"version,omitempty"`
DisconnectedAt **time.Time `json:"disconnected_at,omitempty"`
LastHeartbeatAt **time.Time `json:"last_heartbeat_at,omitempty"`
ConnectionState *string `json:"connection_state,omitempty"`
IPv4 *[]string `json:"ipv4,omitempty"`
IPv6 *[]string `json:"ipv6,omitempty"`
Tags *[]string `json:"tags,omitempty"`
AllowedUserGroups *[]string `json:"allowed_user_groups,omitempty"`
Tunnels *[]*clienttunnel.Tunnel `json:"tunnels,omitempty"`
UpdatesStatus **models.UpdatesStatus `json:"updates_status,omitempty"`
ClientConfiguration **clientconfig.Config `json:"client_configuration,omitempty"`
Groups *[]string `json:"groups,omitempty"`
}
func convertToClientsPayload(clients []*clients.CalculatedClient, fields []query.FieldsOption) []ClientPayload {
r := make([]ClientPayload, 0, len(clients))
for _, cur := range clients {
r = append(r, convertToClientPayload(cur, fields))
}
return r
}
func convertToClientPayload(client *clients.CalculatedClient, fields []query.FieldsOption) ClientPayload { //nolint:gocyclo
requestedFields := query.RequestedFields(fields, "clients")
p := ClientPayload{}
for field := range clients.OptionsSupportedFields["clients"] {
if len(fields) > 0 && !requestedFields[field] {
continue
}
switch field {
case "id":
p.ID = &client.ID
case "name":
p.Name = &client.Name
case "os":
p.OS = &client.OS
case "os_arch":
p.OSArch = &client.OSArch
case "os_family":
p.OSFamily = &client.OSFamily
case "os_kernel":
p.OSKernel = &client.OSKernel
case "hostname":
p.Hostname = &client.Hostname
case "ipv4":
p.IPv4 = &client.IPv4
case "ipv6":
p.IPv6 = &client.IPv6
case "tags":
p.Tags = &client.Tags
case "version":
p.Version = &client.Version
case "address":
p.Address = &client.Address
case "tunnels":
p.Tunnels = &client.Tunnels
case "disconnected_at":
p.DisconnectedAt = &client.DisconnectedAt
case "last_heartbeat_at":
p.LastHeartbeatAt = &client.LastHeartbeatAt
case "connection_state":
connectionState := string(client.ConnectionState)
p.ConnectionState = &connectionState
case "client_auth_id":
p.ClientAuthID = &client.ClientAuthID
case "os_full_name":
p.OSFullName = &client.OSFullName
case "os_version":
p.OSVersion = &client.OSVersion
case "os_virtualization_system":
p.OSVirtualizationSystem = &client.OSVirtualizationSystem
case "os_virtualization_role":
p.OSVirtualizationRole = &client.OSVirtualizationRole
case "cpu_family":
p.CPUFamily = &client.CPUFamily
case "cpu_model":
p.CPUModel = &client.CPUModel
case "cpu_model_name":
p.CPUModelName = &client.CPUModelName
case "cpu_vendor":
p.CPUVendor = &client.CPUVendor
case "timezone":
p.Timezone = &client.Timezone
case "num_cpus":
p.NumCPUs = &client.NumCPUs
case "mem_total":
p.MemoryTotal = &client.MemoryTotal
case "allowed_user_groups":
p.AllowedUserGroups = &client.AllowedUserGroups
case "updates_status":
p.UpdatesStatus = &client.UpdatesStatus
case "client_configuration":
p.ClientConfiguration = &client.ClientConfiguration
case "groups":
p.Groups = &client.Groups
}
}
return p
}
func getCorrespondingSortFunc(sorts []query.SortOption) (sortFunc func(a []*clients.CalculatedClient, desc bool), desc bool, err error) {
if len(sorts) < 1 {
return clients.SortByID, false, nil
@ -200,7 +77,7 @@ func (al *APIListener) handleGetClient(w http.ResponseWriter, req *http.Request)
return
}
clientPayload := convertToClientPayload(client.ToCalculated(groups), options.Fields)
clientPayload := clients.ConvertToClientPayload(client.ToCalculated(groups), options.Fields)
al.writeJSONResponse(w, http.StatusOK, api.NewSuccessPayload(clientPayload))
}
@ -291,19 +168,20 @@ func (al *APIListener) handleGetClients(w http.ResponseWriter, req *http.Request
return
}
cls, err := al.clientService.GetFilteredUserClients(curUser, options.Filters, groups)
filteredClients, err := al.clientService.GetFilteredUserClients(curUser, options.Filters, groups)
if err != nil {
al.jsonError(w, err)
return
}
sortFunc(cls, desc)
sortFunc(filteredClients, desc)
totalCount := len(cls)
totalCount := len(filteredClients)
start, end := options.Pagination.GetStartEnd(totalCount)
cls = cls[start:end]
filteredClients = filteredClients[start:end]
clientsPayload := clients.ConvertToClientsPayload(filteredClients, options.Fields)
clientsPayload := convertToClientsPayload(cls, options.Fields)
al.writeJSONResponse(w, http.StatusOK, &api.SuccessPayload{
Data: clientsPayload,
Meta: api.NewMeta(totalCount),
@ -345,7 +223,7 @@ func (al *APIListener) handlePutClientTunnel(w http.ResponseWriter, req *http.Re
}
if client.IsPaused() {
al.jsonErrorResponseWithTitle(w, http.StatusNotFound, fmt.Sprintf("failed to start tunnel for client with id %s due to client being paused (reason = %s)", clientID, client.PausedReason))
al.jsonErrorResponseWithTitle(w, http.StatusNotFound, fmt.Sprintf("failed to start tunnel for client with id %s due to client being paused (reason = %s)", clientID, client.GetPausedReason()))
return
}
@ -411,7 +289,7 @@ func (al *APIListener) handlePutClientTunnel(w http.ResponseWriter, req *http.Re
remote.ACL = &aclStr
}
allowed, err := clienttunnel.IsAllowed(remote.Remote(), client.Connection)
allowed, err := clienttunnel.IsAllowed(remote.Remote(), client.GetConnection(), al.Log())
if err != nil {
al.jsonError(w, err)
return
@ -426,7 +304,7 @@ func (al *APIListener) handlePutClientTunnel(w http.ResponseWriter, req *http.Re
return
}
for _, t := range client.Tunnels {
for _, t := range client.GetTunnels() {
if t.Remote.Remote() == remote.Remote() && t.Remote.IsProtocol(remote.Protocol) && t.EqualACL(remote.ACL) {
al.jsonErrorResponseWithErrCode(w, http.StatusBadRequest, ErrCodeTunnelToPortExist, fmt.Sprintf("Tunnel to port %s already exists.", remote.RemotePort))
return
@ -434,17 +312,13 @@ func (al *APIListener) handlePutClientTunnel(w http.ResponseWriter, req *http.Re
}
if checkPortStr := req.URL.Query().Get("check_port"); checkPortStr != "0" && remote.IsProtocol(models.ProtocolTCP) {
err = al.checkRemotePort(*remote, client.Connection)
err = al.checkRemotePort(*remote, client.GetConnection())
if err != nil {
al.jsonError(w, err)
return
}
}
// make next steps thread-safe
client.Lock()
defer client.Unlock()
if remote.IsLocalSpecified() {
err = al.checkLocalPort(remote.LocalPort, remote.Protocol)
if err != nil {
@ -588,7 +462,7 @@ func (al *APIListener) checkRemotePort(remote models.Remote, conn ssh.Conn) (err
Timeout: al.config.Server.CheckPortTimeout,
}
resp := &comm.CheckPortResponse{}
err = comm.SendRequestAndGetResponse(conn, comm.RequestTypeCheckPort, req, resp)
err = comm.SendRequestAndGetResponse(conn, comm.RequestTypeCheckPort, req, resp, al.Log())
if err != nil {
if _, ok := err.(*comm.ClientError); ok {
err = apierrors.NewAPIError(http.StatusConflict, "", "", err)
@ -646,10 +520,6 @@ func (al *APIListener) handleDeleteClientTunnel(w http.ResponseWriter, req *http
return
}
// make next steps thread-safe
client.Lock()
defer client.Unlock()
tunnel := al.clientService.FindTunnel(client, tunnelID)
if tunnel == nil {
al.jsonErrorResponseWithTitle(w, http.StatusNotFound, "tunnel not found")

View File

@ -37,7 +37,7 @@ func (mockClientGroupProvider) GetAll(ctx context.Context) ([]*cgroups.ClientGro
}
func TestHandleGetClient(t *testing.T) {
c1 := clients.New(t).ID("client-1").ClientAuthID(cl1.ID).Build()
c1 := clients.New(t).ID("client-1").ClientAuthID(cl1.ID).Logger(testLog).Build()
al := APIListener{
insecureForTests: true,
Server: &Server{
@ -175,8 +175,12 @@ func TestHandleGetClients(t *testing.T) {
Username: "admin",
Groups: []string{users.Administrators},
}
c1 := clients.New(t).ID("client-1").ClientAuthID(cl1.ID).Build()
c2 := clients.New(t).ID("client-2").ClientAuthID(cl1.ID).DisconnectedDuration(5 * time.Minute).Build()
c1 := clients.New(t).ID("client-1").ClientAuthID(cl1.ID).Logger(testLog).Build()
c1.Logger = testLog
c2 := clients.New(t).ID("client-2").ClientAuthID(cl1.ID).DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
c2.Logger = testLog
al := APIListener{
insecureForTests: true,
Server: &Server{
@ -521,8 +525,8 @@ func TestHandlePutTunnelWithName(t *testing.T) {
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
c1 := clients.New(t).ID("client-1").ClientAuthID(cl1.ID).Build()
c1.Connection = connMock
c1 := clients.New(t).ID("client-1").ClientAuthID(cl1.ID).Logger(testLog).Build()
c1.SetConnection(connMock)
c1.Logger = testLog
mockClientService := &SimpleMockClientService{
@ -724,8 +728,8 @@ func TestHandlePutTunnelUsingCaddyProxies(t *testing.T) {
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
c1 := clients.New(t).ID("client-1").ClientAuthID(cl1.ID).Build()
c1.Connection = connMock
c1 := clients.New(t).ID("client-1").ClientAuthID(cl1.ID).Logger(testLog).Build()
c1.SetConnection(connMock)
c1.Logger = testLog
mockClientService := &SimpleMockClientService{

View File

@ -366,7 +366,7 @@ func (al *APIListener) handleExecuteCommand(ctx context.Context, w http.Response
}
if client.IsPaused() {
al.jsonErrorResponseWithTitle(w, http.StatusNotFound, fmt.Sprintf("failed to execute command/script for client with id %s due to client being paused (reason = %s)", client.ID, client.PausedReason))
al.jsonErrorResponseWithTitle(w, http.StatusNotFound, fmt.Sprintf("failed to execute command/script for client with id %s due to client being paused (reason = %s)", client.GetID(), client.GetPausedReason()))
return nil
}
@ -382,7 +382,7 @@ func (al *APIListener) handleExecuteCommand(ctx context.Context, w http.Response
JID: jid,
FinishedAt: nil,
ClientID: executeInput.ClientID,
ClientName: client.Name,
ClientName: client.GetName(),
Command: executeInput.Command,
Interpreter: executeInput.Interpreter,
CreatedBy: api.GetUser(ctx, al.Logger),
@ -393,7 +393,7 @@ func (al *APIListener) handleExecuteCommand(ctx context.Context, w http.Response
IsScript: executeInput.IsScript,
}
sshResp := &comm.RunCmdResponse{}
err = comm.SendRequestAndGetResponse(client.Connection, comm.RequestTypeRunCmd, curJob, sshResp)
err = comm.SendRequestAndGetResponse(client.GetConnection(), comm.RequestTypeRunCmd, curJob, sshResp, al.Log())
if err != nil {
if _, ok := err.(*comm.ClientError); ok {
al.jsonErrorResponseWithTitle(w, http.StatusConflict, err.Error())

View File

@ -107,8 +107,8 @@ func TestHandlePostCommand(t *testing.T) {
require.NoError(t, err)
connMock.ReturnResponsePayload = sshRespBytes
c1 := clients.New(t).Connection(connMock).Build()
c2 := clients.New(t).DisconnectedDuration(5 * time.Minute).Build()
c1 := clients.New(t).Connection(connMock).Logger(testLog).Build()
c2 := clients.New(t).DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
testCases := []struct {
name string
@ -132,7 +132,7 @@ func TestHandlePostCommand(t *testing.T) {
{
name: "valid cmd",
requestBody: validReqBody,
cid: c1.ID,
cid: c1.GetID(),
clients: []*clients.Client{c1},
wantStatusCode: http.StatusOK,
wantTimeout: gotCmdTimeoutSec,
@ -140,7 +140,7 @@ func TestHandlePostCommand(t *testing.T) {
{
name: "valid cmd with interpreter",
requestBody: `{"command": "` + gotCmd + `","interpreter": "powershell"}`,
cid: c1.ID,
cid: c1.GetID(),
clients: []*clients.Client{c1},
wantStatusCode: http.StatusOK,
wantTimeout: defaultTimeout,
@ -149,7 +149,7 @@ func TestHandlePostCommand(t *testing.T) {
{
name: "invalid interpreter",
requestBody: `{"command": "` + gotCmd + `","interpreter": "unsupported"}`,
cid: c1.ID,
cid: c1.GetID(),
clients: []*clients.Client{c1},
wantStatusCode: http.StatusBadRequest,
wantErrTitle: "Invalid interpreter.",
@ -158,7 +158,7 @@ func TestHandlePostCommand(t *testing.T) {
{
name: "valid cmd with no timeout",
requestBody: `{"command": "/bin/date;foo;whoami"}`,
cid: c1.ID,
cid: c1.GetID(),
clients: []*clients.Client{c1},
wantTimeout: defaultTimeout,
wantStatusCode: http.StatusOK,
@ -166,7 +166,7 @@ func TestHandlePostCommand(t *testing.T) {
{
name: "valid cmd with 0 timeout",
requestBody: `{"command": "/bin/date;foo;whoami", "timeout_sec": 0}`,
cid: c1.ID,
cid: c1.GetID(),
clients: []*clients.Client{c1},
wantTimeout: defaultTimeout,
wantStatusCode: http.StatusOK,
@ -174,7 +174,7 @@ func TestHandlePostCommand(t *testing.T) {
{
name: "empty cmd",
requestBody: `{"command": "", "timeout_sec": 30}`,
cid: c1.ID,
cid: c1.GetID(),
clients: []*clients.Client{c1},
wantStatusCode: http.StatusBadRequest,
wantErrTitle: "Command cannot be empty.",
@ -182,7 +182,7 @@ func TestHandlePostCommand(t *testing.T) {
{
name: "no cmd",
requestBody: `{"timeout_sec": 30}`,
cid: c1.ID,
cid: c1.GetID(),
clients: []*clients.Client{c1},
wantStatusCode: http.StatusBadRequest,
wantErrTitle: "Command cannot be empty.",
@ -190,7 +190,7 @@ func TestHandlePostCommand(t *testing.T) {
{
name: "empty body",
requestBody: "",
cid: c1.ID,
cid: c1.GetID(),
clients: []*clients.Client{c1},
wantStatusCode: http.StatusBadRequest,
wantErrTitle: "Missing body with json data.",
@ -198,7 +198,7 @@ func TestHandlePostCommand(t *testing.T) {
{
name: "invalid request body",
requestBody: "sdfn fasld fasdf sdlf jd",
cid: c1.ID,
cid: c1.GetID(),
clients: []*clients.Client{c1},
wantStatusCode: http.StatusBadRequest,
wantErrTitle: "Invalid JSON data.",
@ -207,7 +207,7 @@ func TestHandlePostCommand(t *testing.T) {
{
name: "invalid request body: unknown param",
requestBody: `{"command": "/bin/date;foo;whoami", "timeout": 30}`,
cid: c1.ID,
cid: c1.GetID(),
clients: []*clients.Client{c1},
wantStatusCode: http.StatusBadRequest,
wantErrTitle: "Invalid JSON data.",
@ -216,24 +216,24 @@ func TestHandlePostCommand(t *testing.T) {
{
name: "no active client",
requestBody: validReqBody,
cid: c1.ID,
cid: c1.GetID(),
clients: []*clients.Client{},
wantStatusCode: http.StatusNotFound,
wantErrTitle: fmt.Sprintf("Active client with id=%q not found.", c1.ID),
wantErrTitle: fmt.Sprintf("Active client with id=%q not found.", c1.GetID()),
},
{
name: "disconnected client",
requestBody: validReqBody,
cid: c2.ID,
cid: c2.GetID(),
clients: []*clients.Client{c1, c2},
wantStatusCode: http.StatusNotFound,
wantErrTitle: fmt.Sprintf("Active client with id=%q not found.", c2.ID),
wantErrTitle: fmt.Sprintf("Active client with id=%q not found.", c2.GetID()),
},
{
name: "error on save job",
requestBody: validReqBody,
jpReturnSaveErr: errors.New("save fake error"),
cid: c1.ID,
cid: c1.GetID(),
clients: []*clients.Client{c1},
wantStatusCode: http.StatusInternalServerError,
wantErrTitle: "Failed to persist a new job.",
@ -243,7 +243,7 @@ func TestHandlePostCommand(t *testing.T) {
name: "error on send request",
requestBody: validReqBody,
connReturnErr: errors.New("send fake error"),
cid: c1.ID,
cid: c1.GetID(),
clients: []*clients.Client{c1},
wantStatusCode: http.StatusInternalServerError,
wantErrTitle: "Failed to execute remote command.",
@ -253,7 +253,7 @@ func TestHandlePostCommand(t *testing.T) {
name: "invalid ssh response format",
requestBody: validReqBody,
connReturnResp: []byte("invalid ssh response data"),
cid: c1.ID,
cid: c1.GetID(),
clients: []*clients.Client{c1},
wantStatusCode: http.StatusConflict,
wantErrTitle: "invalid client response format: failed to decode response into *comm.RunCmdResponse: invalid character 'i' looking for beginning of value",
@ -263,7 +263,7 @@ func TestHandlePostCommand(t *testing.T) {
requestBody: validReqBody,
connReturnNotOk: true,
connReturnResp: []byte("fake failure msg"),
cid: c1.ID,
cid: c1.GetID(),
clients: []*clients.Client{c1},
wantStatusCode: http.StatusConflict,
wantErrTitle: "client error: fake failure msg",
@ -552,9 +552,9 @@ func TestHandlePostMultiClientCommand(t *testing.T) {
require.NoError(t, err)
connMock2.ReturnResponsePayload = sshRespBytes2
c1 := clients.New(t).ID("client-1").Connection(connMock1).Build()
c2 := clients.New(t).ID("client-2").Connection(connMock2).Build()
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Build()
c1 := clients.New(t).ID("client-1").Connection(connMock1).Logger(testLog).Build()
c2 := clients.New(t).ID("client-2").Connection(connMock2).Logger(testLog).Build()
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
c1.Logger = testLog
c2.Logger = testLog
@ -565,7 +565,7 @@ func TestHandlePostMultiClientCommand(t *testing.T) {
gotCmdTimeoutSec := 30
validReqBody := `{"command": "` + gotCmd +
`","timeout_sec": ` + strconv.Itoa(gotCmdTimeoutSec) +
`,"client_ids": ["` + c1.ID + `", "` + c2.ID + `"]` +
`,"client_ids": ["` + c1.GetID() + `", "` + c2.GetID() + `"]` +
`,"abort_on_error": false` +
`,"execute_concurrently": false` +
`}`
@ -762,11 +762,11 @@ func TestHandlePostMultiClientCommandWithPausedClient(t *testing.T) {
require.NoError(t, err)
connMock2.ReturnResponsePayload = sshRespBytes2
c1 := clients.New(t).ID("client-1").Connection(connMock1).Build()
c1 := clients.New(t).ID("client-1").Connection(connMock1).Logger(testLog).Build()
c1.Logger = testLog
c1.SetPaused(true, clients.PausedDueToMaxClientsExceeded)
c2 := clients.New(t).ID("client-2").Connection(connMock2).Build()
c2 := clients.New(t).ID("client-2").Connection(connMock2).Logger(testLog).Build()
c2.Logger = testLog
defaultTimeout := 60
@ -775,14 +775,14 @@ func TestHandlePostMultiClientCommandWithPausedClient(t *testing.T) {
c1ValidReqBody := `{"command": "` + gotCmd +
`","timeout_sec": ` + strconv.Itoa(gotCmdTimeoutSec) +
`,"client_ids": ["` + c1.ID + `"]` +
`,"client_ids": ["` + c1.GetID() + `"]` +
`,"abort_on_error": false` +
`,"execute_concurrently": false` +
`}`
c2ValidReqBody := `{"command": "` + gotCmd +
`","timeout_sec": ` + strconv.Itoa(gotCmdTimeoutSec) +
`,"client_ids": ["` + c2.ID + `"]` +
`,"client_ids": ["` + c2.GetID() + `"]` +
`,"abort_on_error": false` +
`,"execute_concurrently": false` +
`}`
@ -973,10 +973,10 @@ func TestHandlePostMultiClientCommandWithGroupIDs(t *testing.T) {
connMock2 := makeConnMock(t, 2, time.Date(2020, 10, 10, 10, 10, 2, 0, time.UTC))
connMock4 := makeConnMock(t, 4, time.Date(2020, 10, 10, 10, 10, 4, 0, time.UTC))
c1 := clients.New(t).ID("client-1").Connection(connMock1).Build()
c2 := clients.New(t).ID("client-2").Connection(connMock2).Build()
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Build()
c4 := clients.New(t).ID("client-4").Connection(connMock4).Build()
c1 := clients.New(t).ID("client-1").Connection(connMock1).Logger(testLog).Build()
c2 := clients.New(t).ID("client-2").Connection(connMock2).Logger(testLog).Build()
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
c4 := clients.New(t).ID("client-4").Connection(connMock4).Logger(testLog).Build()
g1 := makeClientGroup("group-1", &cgroups.ClientParams{
ClientID: &cgroups.ParamValues{"client-1", "client-2"},
@ -990,9 +990,9 @@ func TestHandlePostMultiClientCommandWithGroupIDs(t *testing.T) {
Version: &cgroups.ParamValues{"0.1.1*"},
})
c1.AllowedUserGroups = []string{"group-1"}
c2.AllowedUserGroups = []string{"group-1"}
c4.AllowedUserGroups = []string{"group-2"}
c1.SetAllowedUserGroups([]string{"group-1"})
c2.SetAllowedUserGroups([]string{"group-1"})
c4.SetAllowedUserGroups([]string{"group-2"})
al := makeAPIListener(curUser,
clients.NewClientRepository([]*clients.Client{c1, c2, c3, c4}, &hour, testLog),
@ -1187,15 +1187,15 @@ func TestHandlePostMultiClientCommandWithTags(t *testing.T) {
connMock2 := makeConnMock(t, 2, time.Date(2020, 10, 10, 10, 10, 2, 0, time.UTC))
connMock4 := makeConnMock(t, 4, time.Date(2020, 10, 10, 10, 10, 4, 0, time.UTC))
c1 := clients.New(t).ID("client-1").Connection(connMock1).Build()
c2 := clients.New(t).ID("client-2").Connection(connMock2).Build()
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Build()
c4 := clients.New(t).ID("client-4").Connection(connMock4).Build()
c1 := clients.New(t).ID("client-1").Connection(connMock1).Logger(testLog).Build()
c2 := clients.New(t).ID("client-2").Connection(connMock2).Logger(testLog).Build()
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
c4 := clients.New(t).ID("client-4").Connection(connMock4).Logger(testLog).Build()
c1.Tags = []string{"linux"}
c2.Tags = []string{"windows"}
c3.Tags = []string{"mac"}
c4.Tags = []string{"linux", "windows"}
c1.SetTags([]string{"linux"})
c2.SetTags([]string{"windows"})
c3.SetTags([]string{"mac"})
c4.SetTags([]string{"linux", "windows"})
g1 := makeClientGroup("group-1", &cgroups.ClientParams{
ClientID: &cgroups.ParamValues{"client-1", "client-2"},
@ -1209,9 +1209,9 @@ func TestHandlePostMultiClientCommandWithTags(t *testing.T) {
Version: &cgroups.ParamValues{"0.1.1*"},
})
c1.AllowedUserGroups = []string{"group-1"}
c2.AllowedUserGroups = []string{"group-1"}
c4.AllowedUserGroups = []string{"group-2"}
c1.SetAllowedUserGroups([]string{"group-1"})
c2.SetAllowedUserGroups([]string{"group-1"})
c4.SetAllowedUserGroups([]string{"group-2"})
clientList := []*clients.Client{c1, c2, c4}
@ -1425,15 +1425,15 @@ func TestHandlePostMultiClientWSCommandWithTags(t *testing.T) {
connMock2 := makeConnMock(t, 2, time.Date(2020, 10, 10, 10, 10, 2, 0, time.UTC))
connMock4 := makeConnMock(t, 4, time.Date(2020, 10, 10, 10, 10, 4, 0, time.UTC))
c1 := clients.New(t).ID("client-1").Connection(connMock1).Build()
c2 := clients.New(t).ID("client-2").Connection(connMock2).Build()
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Build()
c4 := clients.New(t).ID("client-4").Connection(connMock4).Build()
c1 := clients.New(t).ID("client-1").Connection(connMock1).Logger(testLog).Build()
c2 := clients.New(t).ID("client-2").Connection(connMock2).Logger(testLog).Build()
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
c4 := clients.New(t).ID("client-4").Connection(connMock4).Logger(testLog).Build()
c1.Tags = []string{"linux"}
c2.Tags = []string{"windows"}
c3.Tags = []string{"mac"}
c4.Tags = []string{"linux", "windows"}
c1.SetTags([]string{"linux"})
c2.SetTags([]string{"windows"})
c3.SetTags([]string{"mac"})
c4.SetTags([]string{"linux", "windows"})
g1 := makeClientGroup("group-1", &cgroups.ClientParams{
ClientID: &cgroups.ParamValues{"client-1", "client-2"},
@ -1447,9 +1447,9 @@ func TestHandlePostMultiClientWSCommandWithTags(t *testing.T) {
Version: &cgroups.ParamValues{"0.1.1*"},
})
c1.AllowedUserGroups = []string{"group-1"}
c2.AllowedUserGroups = []string{"group-1"}
c4.AllowedUserGroups = []string{"group-2"}
c1.SetAllowedUserGroups([]string{"group-1"})
c2.SetAllowedUserGroups([]string{"group-1"})
c4.SetAllowedUserGroups([]string{"group-2"})
clientList := []*clients.Client{c1, c2, c4}
@ -1659,15 +1659,15 @@ func TestHandlePostMultiClientScriptWithTags(t *testing.T) {
connMock2 := makeConnMock(t, 2, time.Date(2020, 10, 10, 10, 10, 2, 0, time.UTC))
connMock4 := makeConnMock(t, 4, time.Date(2020, 10, 10, 10, 10, 4, 0, time.UTC))
c1 := clients.New(t).ID("client-1").Connection(connMock1).Build()
c2 := clients.New(t).ID("client-2").Connection(connMock2).Build()
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Build()
c4 := clients.New(t).ID("client-4").Connection(connMock4).Build()
c1 := clients.New(t).ID("client-1").Connection(connMock1).Logger(testLog).Build()
c2 := clients.New(t).ID("client-2").Connection(connMock2).Logger(testLog).Build()
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
c4 := clients.New(t).ID("client-4").Connection(connMock4).Logger(testLog).Build()
c1.Tags = []string{"linux"}
c2.Tags = []string{"windows"}
c3.Tags = []string{"mac"}
c4.Tags = []string{"linux", "windows"}
c1.SetTags([]string{"linux"})
c2.SetTags([]string{"windows"})
c3.SetTags([]string{"mac"})
c4.SetTags([]string{"linux", "windows"})
g1 := makeClientGroup("group-1", &cgroups.ClientParams{
ClientID: &cgroups.ParamValues{"client-1", "client-2"},
@ -1681,9 +1681,9 @@ func TestHandlePostMultiClientScriptWithTags(t *testing.T) {
Version: &cgroups.ParamValues{"0.1.1*"},
})
c1.AllowedUserGroups = []string{"group-1"}
c2.AllowedUserGroups = []string{"group-1"}
c4.AllowedUserGroups = []string{"group-2"}
c1.SetAllowedUserGroups([]string{"group-1"})
c2.SetAllowedUserGroups([]string{"group-1"})
c4.SetAllowedUserGroups([]string{"group-2"})
clientList := []*clients.Client{c1, c2, c4}
@ -1899,15 +1899,15 @@ func TestHandlePostMultiClientWSScriptWithTags(t *testing.T) {
connMock2 := makeConnMock(t, 2, time.Date(2020, 10, 10, 10, 10, 2, 0, time.UTC))
connMock4 := makeConnMock(t, 4, time.Date(2020, 10, 10, 10, 10, 4, 0, time.UTC))
c1 := clients.New(t).ID("client-1").Connection(connMock1).Build()
c2 := clients.New(t).ID("client-2").Connection(connMock2).Build()
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Build()
c4 := clients.New(t).ID("client-4").Connection(connMock4).Build()
c1 := clients.New(t).ID("client-1").Connection(connMock1).Logger(testLog).Build()
c2 := clients.New(t).ID("client-2").Connection(connMock2).Logger(testLog).Build()
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
c4 := clients.New(t).ID("client-4").Connection(connMock4).Logger(testLog).Build()
c1.Tags = []string{"linux"}
c2.Tags = []string{"windows"}
c3.Tags = []string{"mac"}
c4.Tags = []string{"linux", "windows"}
c1.SetTags([]string{"linux"})
c2.SetTags([]string{"windows"})
c3.SetTags([]string{"mac"})
c4.SetTags([]string{"linux", "windows"})
g1 := makeClientGroup("group-1", &cgroups.ClientParams{
ClientID: &cgroups.ParamValues{"client-1", "client-2"},
@ -1921,9 +1921,9 @@ func TestHandlePostMultiClientWSScriptWithTags(t *testing.T) {
Version: &cgroups.ParamValues{"0.1.1*"},
})
c1.AllowedUserGroups = []string{"group-1"}
c2.AllowedUserGroups = []string{"group-1"}
c4.AllowedUserGroups = []string{"group-2"}
c1.SetAllowedUserGroups([]string{"group-1"})
c2.SetAllowedUserGroups([]string{"group-1"})
c4.SetAllowedUserGroups([]string{"group-2"})
clientList := []*clients.Client{c1, c2, c4}

View File

@ -32,7 +32,7 @@ func (al *APIListener) handleRefreshUpdatesStatus(w http.ResponseWriter, req *ht
return
}
err = comm.SendRequestAndGetResponse(client.Connection, comm.RequestTypeRefreshUpdatesStatus, nil, nil)
err = comm.SendRequestAndGetResponse(client.GetConnection(), comm.RequestTypeRefreshUpdatesStatus, nil, nil, al.Log())
if err != nil {
al.jsonErrorResponse(w, http.StatusInternalServerError, err)
return
@ -77,8 +77,10 @@ func (al *APIListener) handleGetClientGraphMetrics(w http.ResponseWriter, req *h
queryOptions := query.NewOptions(req, monitoring.ClientGraphMetricsSortDefault, monitoring.ClientGraphMetricsFilterDefault, monitoring.ClientGraphMetricsFieldsDefault)
requestInfo := query.ParseRequestInfo(req)
netLan := client.ClientConfiguration.Monitoring.LanCard != nil
netWan := client.ClientConfiguration.Monitoring.WanCard != nil
monitoringConfig := client.GetMonitoringConfig()
netLan := monitoringConfig.LanCard != nil
netWan := monitoringConfig.WanCard != nil
payload, err := al.monitoringService.ListClientGraphMetrics(req.Context(), clientID, queryOptions, requestInfo, netLan, netWan)
if err != nil {
@ -110,7 +112,9 @@ func (al *APIListener) handleGetClientGraphMetricsGraph(w http.ResponseWriter, r
return
}
payload, err := al.monitoringService.ListClientGraph(req.Context(), clientID, queryOptions, graph, client.ClientConfiguration.Monitoring.LanCard, client.ClientConfiguration.Monitoring.WanCard)
monitoringConfig := client.GetMonitoringConfig()
payload, err := al.monitoringService.ListClientGraph(req.Context(), clientID, queryOptions, graph, monitoringConfig.LanCard, monitoringConfig.WanCard)
if err != nil {
if err == sql.ErrNoRows {
al.jsonErrorResponseWithTitle(w, http.StatusNotFound, fmt.Sprintf("graph-metrics for client with id %q not found", clientID))

View File

@ -17,8 +17,8 @@ import (
)
func TestHandleRefreshUpdatesStatus(t *testing.T) {
c1 := clients.New(t).Build()
c2 := clients.New(t).DisconnectedDuration(5 * time.Minute).Build()
c1 := clients.New(t).Logger(testLog).Build()
c2 := clients.New(t).DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
testCases := []struct {
Name string
@ -29,13 +29,13 @@ func TestHandleRefreshUpdatesStatus(t *testing.T) {
}{
{
Name: "Connected client",
ClientID: c1.ID,
ClientID: c1.GetID(),
ExpectedStatus: http.StatusNoContent,
ExpectedRequestName: comm.RequestTypeRefreshUpdatesStatus,
},
{
Name: "Disconnected client",
ClientID: c2.ID,
ClientID: c2.GetID(),
ExpectedStatus: http.StatusNotFound,
},
{
@ -45,7 +45,7 @@ func TestHandleRefreshUpdatesStatus(t *testing.T) {
},
{
Name: "SSH error",
ClientID: c1.ID,
ClientID: c1.GetID(),
SSHError: true,
ExpectedRequestName: comm.RequestTypeRefreshUpdatesStatus,
ExpectedStatus: http.StatusInternalServerError,
@ -57,7 +57,7 @@ func TestHandleRefreshUpdatesStatus(t *testing.T) {
connMock := test.NewConnMock()
// by default set to return success
connMock.ReturnOk = !tc.SSHError
c1.Connection = connMock
c1.SetConnection(connMock)
clientService := clients.NewClientService(nil, nil, clients.NewClientRepository([]*clients.Client{c1, c2}, &hour, testLog), testLog)
al := APIListener{
insecureForTests: true,

View File

@ -153,15 +153,15 @@ func TestHandlePostScheduleMultiClientJobWithTags(t *testing.T) {
connMock2 := makeConnMock(t, 2, time.Date(2020, 10, 10, 10, 10, 2, 0, time.UTC))
connMock4 := makeConnMock(t, 4, time.Date(2020, 10, 10, 10, 10, 4, 0, time.UTC))
c1 := clients.New(t).ID("client-1").Connection(connMock1).Build()
c2 := clients.New(t).ID("client-2").Connection(connMock2).Build()
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Build()
c4 := clients.New(t).ID("client-4").Connection(connMock4).Build()
c1 := clients.New(t).ID("client-1").Connection(connMock1).Logger(testLog).Build()
c2 := clients.New(t).ID("client-2").Connection(connMock2).Logger(testLog).Build()
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
c4 := clients.New(t).ID("client-4").Connection(connMock4).Logger(testLog).Build()
c1.Tags = []string{"linux"}
c2.Tags = []string{"windows"}
c3.Tags = []string{"mac"}
c4.Tags = []string{"linux", "windows"}
c1.SetTags([]string{"linux"})
c2.SetTags([]string{"windows"})
c3.SetTags([]string{"mac"})
c4.SetTags([]string{"linux", "windows"})
g1 := makeClientGroup("group-1", &cgroups.ClientParams{
ClientID: &cgroups.ParamValues{"client-1", "client-2"},
@ -175,9 +175,9 @@ func TestHandlePostScheduleMultiClientJobWithTags(t *testing.T) {
Version: &cgroups.ParamValues{"0.1.1*"},
})
c1.AllowedUserGroups = []string{"group-1"}
c2.AllowedUserGroups = []string{"group-1"}
c4.AllowedUserGroups = []string{"group-2"}
c1.SetAllowedUserGroups([]string{"group-1"})
c2.SetAllowedUserGroups([]string{"group-1"})
c4.SetAllowedUserGroups([]string{"group-2"})
clientList := []*clients.Client{c1, c2, c4}
@ -382,15 +382,15 @@ func TestHandlePostUpdateScheduleMultiClientJobWithTags(t *testing.T) {
connMock2 := makeConnMock(t, 2, time.Date(2020, 10, 10, 10, 10, 2, 0, time.UTC))
connMock4 := makeConnMock(t, 4, time.Date(2020, 10, 10, 10, 10, 4, 0, time.UTC))
c1 := clients.New(t).ID("client-1").Connection(connMock1).Build()
c2 := clients.New(t).ID("client-2").Connection(connMock2).Build()
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Build()
c4 := clients.New(t).ID("client-4").Connection(connMock4).Build()
c1 := clients.New(t).ID("client-1").Connection(connMock1).Logger(testLog).Build()
c2 := clients.New(t).ID("client-2").Connection(connMock2).Logger(testLog).Build()
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
c4 := clients.New(t).ID("client-4").Connection(connMock4).Logger(testLog).Build()
c1.Tags = []string{"linux"}
c2.Tags = []string{"windows"}
c3.Tags = []string{"mac"}
c4.Tags = []string{"linux", "windows"}
c1.SetTags([]string{"linux"})
c2.SetTags([]string{"windows"})
c3.SetTags([]string{"mac"})
c4.SetTags([]string{"linux", "windows"})
g1 := makeClientGroup("group-1", &cgroups.ClientParams{
ClientID: &cgroups.ParamValues{"client-1", "client-2"},
@ -404,9 +404,9 @@ func TestHandlePostUpdateScheduleMultiClientJobWithTags(t *testing.T) {
Version: &cgroups.ParamValues{"0.1.1*"},
})
c1.AllowedUserGroups = []string{"group-1"}
c2.AllowedUserGroups = []string{"group-1"}
c4.AllowedUserGroups = []string{"group-2"}
c1.SetAllowedUserGroups([]string{"group-1"})
c2.SetAllowedUserGroups([]string{"group-1"})
c4.SetAllowedUserGroups([]string{"group-2"})
clientList := []*clients.Client{c1, c2, c4}

View File

@ -8,11 +8,7 @@ import (
)
func (al *APIListener) handleGetStatus(w http.ResponseWriter, req *http.Request) {
countActive, err := al.clientService.CountActive()
if err != nil {
al.jsonErrorResponse(w, http.StatusInternalServerError, err)
return
}
countActive := al.clientService.CountActive()
countDisconnected, err := al.clientService.CountDisconnected()
if err != nil {

View File

@ -28,7 +28,7 @@ func (al *APIListener) handleGetStoredTunnels(w http.ResponseWriter, req *http.R
}
options := query.GetListOptions(req)
result, err := al.storedTunnels.List(ctx, options, client.ID)
result, err := al.storedTunnels.List(ctx, options, client.GetID())
if err != nil {
al.jsonError(w, err)
return
@ -59,7 +59,7 @@ func (al *APIListener) handlePostStoredTunnels(w http.ResponseWriter, req *http.
return
}
result, err := al.storedTunnels.Create(ctx, client.ID, storedTunnel)
result, err := al.storedTunnels.Create(ctx, client.GetID(), storedTunnel)
if err != nil {
al.jsonError(w, err)
return
@ -84,7 +84,7 @@ func (al *APIListener) handleDeleteStoredTunnel(w http.ResponseWriter, req *http
return
}
err = al.storedTunnels.Delete(ctx, client.ID, tunnelID)
err = al.storedTunnels.Delete(ctx, client.GetID(), tunnelID)
if err != nil {
al.jsonError(w, err)
return
@ -117,7 +117,7 @@ func (al *APIListener) handlePutStoredTunnel(w http.ResponseWriter, req *http.Re
}
storedTunnel.ID = tunnelID
result, err := al.storedTunnels.Update(ctx, client.ID, storedTunnel)
result, err := al.storedTunnels.Update(ctx, client.GetID(), storedTunnel)
if err != nil {
al.jsonError(w, err)
return

View File

@ -44,12 +44,13 @@ func (al *APIListener) handleGetTunnels(w http.ResponseWriter, req *http.Request
tunnels := make([]TunnelPayload, 0)
for _, c := range clients {
if c.DisconnectedAt != nil {
clientID := c.GetID()
if !c.IsConnected() {
continue
}
for _, t := range c.Tunnels {
tunnels = append(tunnels, convertToTunnelPayload(t, c.ID))
for _, t := range c.GetTunnels() {
tunnels = append(tunnels, convertToTunnelPayload(t, clientID))
}
}

View File

@ -84,8 +84,9 @@ func (al *APIListener) getOrderedClients(
// append group clients
for _, groupClient := range groupClients {
if !usedClientIDs[groupClient.ID] {
usedClientIDs[groupClient.ID] = true
groupClientID := groupClient.GetID()
if !usedClientIDs[groupClientID] {
usedClientIDs[groupClientID] = true
orderedClients = append(orderedClients, groupClient)
}
}

View File

@ -3,6 +3,7 @@ package chserver
import (
"context"
"errors"
"fmt"
"time"
"github.com/gorilla/websocket"
@ -110,6 +111,13 @@ func (al *APIListener) handleCommandsExecutionWS(
}
for _, client := range inboundMsg.OrderedClients {
if client.IsPaused() {
msg := fmt.Sprintf("failed to execute command/script for client with id %s", client.GetID())
err := fmt.Errorf("client is paused (reason = %s)", client.GetPausedReason())
uiConnTS.WriteError(msg, err)
continue
}
curJID, err := generateNewJobID()
if err != nil {
uiConnTS.WriteError("Could not generate job id.", err)
@ -168,7 +176,14 @@ func (al *APIListener) handleCommandsExecutionWS(
} else {
client := inboundMsg.OrderedClients[0]
al.createAndRunJob( //nolint:errcheck // error is logged, nothing to act on here
if client.IsPaused() {
msg := fmt.Sprintf("failed to execute command/script for client with id %s", client.GetID())
err := fmt.Errorf("client is paused (reason = %s)", client.GetPausedReason())
uiConnTS.WriteError(msg, err)
return
}
al.createAndRunJob(
uiConnTS,
nil,
jid,

View File

@ -49,8 +49,8 @@ func (al *APIListener) createAndRunJob(
curJob := models.Job{
JID: jid,
StartedAt: time.Now(),
ClientID: client.ID,
ClientName: client.Name,
ClientID: client.GetID(),
ClientName: client.GetName(),
Command: cmd,
Cwd: cwd,
IsSudo: isSudo,
@ -69,13 +69,14 @@ func (al *APIListener) createAndRunJob(
var err error
if !client.IsPaused() {
if client.Connection != nil {
err = comm.SendRequestAndGetResponse(client.Connection, comm.RequestTypeRunCmd, curJob, sshResp)
err = comm.SendRequestAndGetResponse(client.GetConnection(), comm.RequestTypeRunCmd, curJob, sshResp, al.Log())
} else {
err = ErrClientNotConnected
}
} else {
err = fmt.Errorf("client is paused (reason = %s)", client.PausedReason)
}
if err != nil {
al.Errorf("%s, Error on execute remote command: %v", logPrefix, err)

View File

@ -10,6 +10,7 @@ import (
"os"
"path"
"strings"
"sync"
"time"
"golang.org/x/crypto/bcrypt"
@ -72,6 +73,15 @@ type APIListener struct {
tokenManager *authorization.Manager
commandManager *command.Manager
storedTunnels *storedtunnels.Manager
mu sync.RWMutex
}
func (al *APIListener) Log() (l *logger.Logger) {
al.mu.RLock()
defer al.mu.RUnlock()
return al.Logger
}
type UserService interface {

View File

@ -81,8 +81,8 @@ func (e *Entry) WithClient(c *clients.Client) *Entry {
return e
}
e.ClientID = c.ID
e.ClientHostName = c.Hostname
e.ClientID = c.GetID()
e.ClientHostName = c.GetHostname()
return e
}
@ -99,7 +99,7 @@ func (e *Entry) WithClientID(cid string) *Entry {
return e
}
if client != nil {
e.ClientHostName = client.Hostname
e.ClientHostName = client.GetHostname()
}
return e

View File

@ -84,11 +84,12 @@ func TestWithResponse(t *testing.T) {
}
func TestWithClient(t *testing.T) {
e := emptyEntry().WithClient(&clients.Client{
ID: "11236310-6cad-408e-b372-a0f04d68d2df",
Address: "127.0.0.1",
Hostname: "hostname",
})
c1 := clients.Client{}
c1.SetID("11236310-6cad-408e-b372-a0f04d68d2df")
c1.SetAddress("127.0.0.1")
c1.SetHostname("hostname")
e := emptyEntry().WithClient(&c1)
assert.Equal(t, "11236310-6cad-408e-b372-a0f04d68d2df", e.ClientID)
assert.Equal(t, "hostname", e.ClientHostName)
@ -136,18 +137,17 @@ func TestSaveForMultipleClients(t *testing.T) {
auditLog := enabledAuditLog()
auditLog.provider = mockProvider
auditLog.Entry("", "").SaveForMultipleClients([]*clients.Client{
{
ID: "c1",
Address: "c1.com",
Hostname: "hostname1",
},
{
ID: "c2",
Address: "c2.com",
Hostname: "hostname2",
},
})
c1 := clients.Client{}
c1.SetID("c1")
c1.SetAddress("c1.com")
c1.SetHostname("hostname1")
c2 := clients.Client{}
c2.SetID("c2")
c2.SetAddress("c2.com")
c2.SetHostname("hostname2")
auditLog.Entry("", "").SaveForMultipleClients([]*clients.Client{&c1, &c2})
assert.Len(t, mockProvider.entries, 2)
assert.Equal(t, "c1", mockProvider.entries[0].ClientID)
@ -172,12 +172,13 @@ type mockClientGetter struct {
}
func (mockClientGetter) GetByID(id string) (*clients.Client, error) {
c1 := clients.Client{}
c1.SetID("11236310-6cad-408e-b372-a0f04d68d2df")
c1.SetAddress("127.0.0.1")
c1.SetHostname("hostname")
if id == "11236310-6cad-408e-b372-a0f04d68d2df" {
return &clients.Client{
ID: "11236310-6cad-408e-b372-a0f04d68d2df",
Address: "127.0.0.1",
Hostname: "hostname",
}, nil
return &c1, nil
}
return nil, nil
}

View File

@ -1,8 +1,12 @@
//go:build linux
// +build linux
package caddy_test
import (
"context"
"net/http"
"os"
"testing"
"time"
@ -86,6 +90,9 @@ func setupNewCaddyServer(ctx context.Context, t *testing.T) (cs *caddy.Server) {
require.NoError(t, err)
caddy.HostDomainSocket = bc.GlobalSettings.AdminSocket
// ensure the no existing admin socket file
os.Remove(caddy.HostDomainSocket)
cs = caddy.NewCaddyServer(cfg, testLog)
err = cs.Start(ctx)
require.NoError(t, err)

View File

@ -15,6 +15,7 @@ import (
"path"
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
"time"
@ -116,39 +117,40 @@ type LogConfig struct {
}
type ServerConfig struct {
ListenAddress string `mapstructure:"address"`
URL []string `mapstructure:"url"`
PairingURL string `mapstructure:"pairing_url"`
TunnelHost string `mapstructure:"tunnel_host"`
KeySeed string `mapstructure:"key_seed"`
Auth string `mapstructure:"auth"`
AuthFile string `mapstructure:"auth_file"`
AuthTable string `mapstructure:"auth_table"`
Proxy string `mapstructure:"proxy"`
UsedPortsRaw []string `mapstructure:"used_ports"`
ExcludedPortsRaw []string `mapstructure:"excluded_ports"`
DataDir string `mapstructure:"data_dir"`
SqliteWAL bool `mapstructure:"sqlite_wal"`
PurgeDisconnectedClients bool `mapstructure:"purge_disconnected_clients"`
CleanupLostClients bool `mapstructure:"cleanup_lost_clients" replaced_by:"PurgeDisconnectedClients"`
KeepLostClients time.Duration `mapstructure:"keep_lost_clients" replaced_by:"KeepDisconnectedClients"`
KeepDisconnectedClients time.Duration `mapstructure:"keep_disconnected_clients"`
CleanupClientsInterval time.Duration `mapstructure:"cleanup_clients_interval" replaced_by:"PurgeDisconnectedClientsInterval"`
PurgeDisconnectedClientsInterval time.Duration `mapstructure:"purge_disconnected_clients_interval"`
CheckClientsConnectionInterval time.Duration `mapstructure:"check_clients_connection_interval"`
CheckClientsConnectionTimeout time.Duration `mapstructure:"check_clients_connection_timeout"`
MaxRequestBytesClient int64 `mapstructure:"max_request_bytes_client"`
CheckPortTimeout time.Duration `mapstructure:"check_port_timeout"`
RunRemoteCmdTimeoutSec int `mapstructure:"run_remote_cmd_timeout_sec"`
AuthWrite bool `mapstructure:"auth_write"`
AuthMultiuseCreds bool `mapstructure:"auth_multiuse_creds"`
EquateClientauthidClientid bool `mapstructure:"equate_clientauthid_clientid"`
AllowRoot bool `mapstructure:"allow_root"`
ClientLoginWait float32 `mapstructure:"client_login_wait"`
MaxFailedLogin int `mapstructure:"max_failed_login"`
BanTime int `mapstructure:"ban_time"`
InternalTunnelProxyConfig clienttunnel.InternalTunnelProxyConfig `mapstructure:",squash"`
JobsMaxResults int `mapstructure:"jobs_max_results"`
ListenAddress string `mapstructure:"address"`
URL []string `mapstructure:"url"`
PairingURL string `mapstructure:"pairing_url"`
TunnelHost string `mapstructure:"tunnel_host"`
KeySeed string `mapstructure:"key_seed"`
Auth string `mapstructure:"auth"`
AuthFile string `mapstructure:"auth_file"`
AuthTable string `mapstructure:"auth_table"`
Proxy string `mapstructure:"proxy"`
UsedPortsRaw []string `mapstructure:"used_ports"`
ExcludedPortsRaw []string `mapstructure:"excluded_ports"`
DataDir string `mapstructure:"data_dir"`
SqliteWAL bool `mapstructure:"sqlite_wal"`
MaxConcurrentSSHConnectionHandshakes int `mapstructure:"max_concurrent_ssh_handshakes"`
PurgeDisconnectedClients bool `mapstructure:"purge_disconnected_clients"`
CleanupLostClients bool `mapstructure:"cleanup_lost_clients" replaced_by:"PurgeDisconnectedClients"`
KeepLostClients time.Duration `mapstructure:"keep_lost_clients" replaced_by:"KeepDisconnectedClients"`
KeepDisconnectedClients time.Duration `mapstructure:"keep_disconnected_clients"`
CleanupClientsInterval time.Duration `mapstructure:"cleanup_clients_interval" replaced_by:"PurgeDisconnectedClientsInterval"`
PurgeDisconnectedClientsInterval time.Duration `mapstructure:"purge_disconnected_clients_interval"`
CheckClientsConnectionInterval time.Duration `mapstructure:"check_clients_connection_interval"`
CheckClientsConnectionTimeout time.Duration `mapstructure:"check_clients_connection_timeout"`
MaxRequestBytesClient int64 `mapstructure:"max_request_bytes_client"`
CheckPortTimeout time.Duration `mapstructure:"check_port_timeout"`
RunRemoteCmdTimeoutSec int `mapstructure:"run_remote_cmd_timeout_sec"`
AuthWrite bool `mapstructure:"auth_write"`
AuthMultiuseCreds bool `mapstructure:"auth_multiuse_creds"`
EquateClientauthidClientid bool `mapstructure:"equate_clientauthid_clientid"`
AllowRoot bool `mapstructure:"allow_root"`
ClientLoginWait float32 `mapstructure:"client_login_wait"`
MaxFailedLogin int `mapstructure:"max_failed_login"`
BanTime int `mapstructure:"ban_time"`
InternalTunnelProxyConfig clienttunnel.InternalTunnelProxyConfig `mapstructure:",squash"`
JobsMaxResults int `mapstructure:"jobs_max_results"`
// DEPRECATED, only here for backwards compatibility
MaxRequestBytes int64 `mapstructure:"max_request_bytes"`
@ -379,6 +381,11 @@ func (c *Config) ParseAndValidate(mLog *logger.MemLogger) error {
return err
}
maxProcs := runtime.GOMAXPROCS(0)
if c.Server.MaxConcurrentSSHConnectionHandshakes > (maxProcs * 2) {
mLog.Infof("warning: allowing too many concurrent ssh handhakes ('max_concurrent_ssh_handshakes') will slow down the server significantly. Please use a value less than or equal to the MAX_PROCS (%d)", maxProcs)
}
if c.Server.CheckClientsConnectionInterval < CheckClientsConnectionIntervalMinimum {
c.Server.CheckClientsConnectionInterval = CheckClientsConnectionIntervalMinimum
mLog.Errorf("'check_clients_status_interval' too fast. Using the minimum possible of %s", CheckClientsConnectionIntervalMinimum)

View File

@ -609,7 +609,7 @@ func TestParseAndValidatePorts(t *testing.T) {
UsedPortsRaw: []string{"45-50"},
ExcludedPortsRaw: []string{"1-10", "44", "51", "80-90"},
},
ExpectedAllowedPorts: mapset.NewThreadUnsafeSetFromSlice([]interface{}{45, 46, 47, 48, 49, 50}),
ExpectedAllowedPorts: mapset.NewSetFromSlice([]interface{}{45, 46, 47, 48, 49, 50}),
},
{
Name: "used ports and excluded ports",
@ -617,7 +617,7 @@ func TestParseAndValidatePorts(t *testing.T) {
UsedPortsRaw: []string{"100-200", "205", "250-300", "305", "400-500"},
ExcludedPortsRaw: []string{"80-110", "114-116", "118", "120-198", "200", "240-310", "305", "401-499"},
},
ExpectedAllowedPorts: mapset.NewThreadUnsafeSetFromSlice([]interface{}{111, 112, 113, 117, 119, 199, 205, 400, 500}),
ExpectedAllowedPorts: mapset.NewSetFromSlice([]interface{}{111, 112, 113, 117, 119, 199, 205, 400, 500}),
},
{
Name: "excluded ports empty",
@ -625,7 +625,7 @@ func TestParseAndValidatePorts(t *testing.T) {
UsedPortsRaw: []string{"45-46"},
ExcludedPortsRaw: []string{},
},
ExpectedAllowedPorts: mapset.NewThreadUnsafeSetFromSlice([]interface{}{45, 46}),
ExpectedAllowedPorts: mapset.NewSetFromSlice([]interface{}{45, 46}),
},
{
Name: "one allowed port",
@ -633,7 +633,7 @@ func TestParseAndValidatePorts(t *testing.T) {
UsedPortsRaw: []string{"20000"},
ExcludedPortsRaw: []string{},
},
ExpectedAllowedPorts: mapset.NewThreadUnsafeSetFromSlice([]interface{}{20000}),
ExpectedAllowedPorts: mapset.NewSetFromSlice([]interface{}{20000}),
},
{
Name: "both empty",

View File

@ -12,6 +12,7 @@ import (
"net/http/httputil"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"
@ -32,9 +33,17 @@ import (
"github.com/cloudradar-monitoring/rport/share/security"
)
const ConnectionRequestTimeOut = 5 * 60 * time.Second
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool { return true },
}
type ClientListener struct {
*logger.Logger
*Server
logger *logger.Logger
server *Server
connStats chshare.ConnStats
httpServer *chshare.HTTPServer
@ -45,33 +54,46 @@ type ClientListener struct {
bannedIPs *security.MaxBadAttemptsBanList
clientIndexAutoIncrement int32
// semaphore used to limit concurrent pending SSH connections
inprogressSSHHandshakes chan struct{}
mu sync.RWMutex
}
const SSHTimeOut = 90 * time.Second
func (cl *ClientListener) Log() (l *logger.Logger) {
cl.mu.RLock()
defer cl.mu.RUnlock()
return cl.logger
}
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool { return true },
func (cl *ClientListener) GetClientService() (cs clients.ClientService) {
cl.mu.RLock()
defer cl.mu.RUnlock()
return cl.server.clientService
}
func NewClientListener(server *Server, privateKey ssh.Signer) (*ClientListener, error) {
config := server.config
// semaphore to limit number of active pending SSH connections
inprogressSSHHandshakes := make(chan struct{}, config.Server.MaxConcurrentSSHConnectionHandshakes)
clog := logger.NewLogger("client-listener", config.Logging.LogOutput, config.Logging.LogLevel)
cl := &ClientListener{
Server: server,
httpServer: chshare.NewHTTPServer(int(config.Server.MaxRequestBytesClient), clog),
Logger: clog,
requestLogOptions: config.InitRequestLogOptions(),
bannedClientAuths: security.NewBanList(time.Duration(config.Server.ClientLoginWait) * time.Second),
server: server,
httpServer: chshare.NewHTTPServer(int(config.Server.MaxRequestBytesClient), clog),
requestLogOptions: config.InitRequestLogOptions(),
bannedClientAuths: security.NewBanList(time.Duration(config.Server.ClientLoginWait) * time.Second),
inprogressSSHHandshakes: inprogressSSHHandshakes,
logger: clog,
}
if config.Server.MaxFailedLogin > 0 && config.Server.BanTime > 0 {
cl.bannedIPs = security.NewMaxBadAttemptsBanList(
config.Server.MaxFailedLogin,
time.Duration(config.Server.BanTime)*time.Second,
cl.Logger,
cl.logger,
)
}
@ -80,7 +102,9 @@ func NewClientListener(server *Server, privateKey ssh.Signer) (*ClientListener,
ServerVersion: "SSH-" + chshare.ProtocolVersion + "-server",
PasswordCallback: cl.authUser,
}
cl.sshConfig.AddHostKey(privateKey)
//setup reverse proxy
if config.Server.Proxy != "" {
u, err := url.Parse(config.Server.Proxy)
@ -107,15 +131,15 @@ func (cl *ClientListener) authUser(c ssh.ConnMetadata, password []byte) (*ssh.Pe
clientAuthID := c.User()
if cl.bannedClientAuths.IsBanned(clientAuthID) {
cl.Infof("Failed login attempt for client auth id %q, forcing to wait for %vs (%s)",
cl.Log().Infof("Failed login attempt for client auth id %q, forcing to wait for %vs (%s)",
clientAuthID,
cl.config.Server.ClientLoginWait,
cl.server.config.Server.ClientLoginWait,
cl.getIP(c.RemoteAddr()),
)
return nil, ErrTooManyRequests
}
clientAuth, err := cl.clientAuthProvider.Get(clientAuthID)
clientAuth, err := cl.server.clientAuthProvider.Get(clientAuthID)
if err != nil {
return nil, err
}
@ -123,7 +147,7 @@ func (cl *ClientListener) authUser(c ssh.ConnMetadata, password []byte) (*ssh.Pe
ip := cl.getIP(c.RemoteAddr())
// constant time compare is used for security reasons
if clientAuth == nil || subtle.ConstantTimeCompare([]byte(clientAuth.Password), password) != 1 {
cl.Debugf("Login failed for client auth id: %s", clientAuthID)
cl.Log().Debugf("Login failed for client auth id: %s", clientAuthID)
cl.bannedClientAuths.Add(clientAuthID)
if cl.bannedIPs != nil {
cl.bannedIPs.AddBadAttempt(ip)
@ -141,20 +165,21 @@ func (cl *ClientListener) getIP(addr net.Addr) string {
addrStr := addr.String()
host, _, err := net.SplitHostPort(addrStr)
if err != nil {
cl.Errorf("failed to split host port for %q: %v", addr, err)
cl.Log().Errorf("failed to split host port for %q: %v", addr, err)
return addrStr
}
return host
}
func (cl *ClientListener) Start(ctx context.Context, listenAddr string) error {
cl.Debugf("Client listener starting...")
clLogger := cl.Log()
clLogger.Debugf("Client listener starting...")
if cl.reverseProxy != nil {
cl.Infof("Reverse proxy enabled")
clLogger.Infof("Reverse proxy enabled")
}
cl.Infof("Listening on %s...", listenAddr)
clLogger.Infof("Listening on %s...", listenAddr)
h := http.Handler(middleware.MaxBytes(http.HandlerFunc(cl.handleClient), cl.config.Server.MaxRequestBytesClient))
h := http.Handler(middleware.MaxBytes(http.HandlerFunc(cl.handleClient), cl.server.config.Server.MaxRequestBytesClient))
if cl.bannedIPs != nil {
h = security.RejectBannedIPs(cl.bannedIPs)(h)
}
@ -174,7 +199,7 @@ func (cl *ClientListener) Close() error {
}
func (cl *ClientListener) handleClient(w http.ResponseWriter, r *http.Request) {
cl.Debugf("Incoming client connection...")
cl.Log().Debugf("Incoming client connection...")
//websockets upgrade AND has rport prefix
upgrade := strings.ToLower(r.Header.Get("Upgrade"))
protocol := r.Header.Get("Sec-WebSocket-Protocol")
@ -184,7 +209,7 @@ func (cl *ClientListener) handleClient(w http.ResponseWriter, r *http.Request) {
return
}
//print into server logs and silently fall-through
cl.Infof("ignored client connection using protocol '%s', expected '%s'",
cl.Log().Infof("ignored client connection using protocol '%s', expected '%s'",
protocol, chshare.ProtocolVersion)
}
//proxy target was provided
@ -201,55 +226,90 @@ func (cl *ClientListener) nextClientIndex() int32 {
return atomic.AddInt32(&cl.clientIndexAutoIncrement, 1)
}
// handleWebsocket is responsible for handling the websocket connection
func (cl *ClientListener) handleWebsocket(w http.ResponseWriter, req *http.Request) {
ts := time.Now()
clog := cl.Fork("client#%d", cl.nextClientIndex())
func (cl *ClientListener) acceptSSHConnection(w http.ResponseWriter, req *http.Request) (sshConn *ssh.ServerConn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request, clog *logger.Logger, err error) {
// add to pending connections. will block if the chan is full
cl.inprogressSSHHandshakes <- struct{}{}
clog = cl.Log().Fork("client#%d", cl.nextClientIndex())
clog.Debugf("Handling inbound web socket connection...")
ts := time.Now()
wsConn, err := upgrader.Upgrade(w, req, nil)
if err != nil {
clog.Debugf("Failed to upgrade (%s)", err)
return
<-cl.inprogressSSHHandshakes
return nil, nil, nil, nil, err
}
conn := chshare.NewWebSocketConn(wsConn)
// perform SSH handshake on net.Conn
clog.Debugf("SSH Handshaking...")
sshConn, chans, reqs, err := ssh.NewServerConn(conn, cl.sshConfig)
sshConn, chans, reqs, err = ssh.NewServerConn(conn, cl.sshConfig)
if err != nil {
cl.Debugf("Failed to handshake (%s) from %s", err, conn.RemoteAddr().String())
return
if strings.Contains(err.Error(), "unexpected EOF") {
clog.Debugf("Failed to handshake (client closed connection? - %s) from %s", err, conn.RemoteAddr().String())
} else {
clog.Debugf("Failed to handshake (%s) from %s", err, conn.RemoteAddr().String())
}
<-cl.inprogressSSHHandshakes
return nil, nil, nil, nil, err
}
clog.Debugf("Handshake finished after %s", time.Since(ts))
clog.Debugf("SSH Handshake finished after %s", time.Since(ts))
//verify configuration
clog.Debugf("Verifying configuration...")
//wait for request, with timeout
var r *ssh.Request
// on handshake finished, remove from pending connections, which will allow another connection to take place
<-cl.inprogressSSHHandshakes
return sshConn, chans, reqs, clog, err
}
func (cl *ClientListener) receiveClientConnectionRequest(sshConn *ssh.ServerConn, reqs <-chan *ssh.Request, clog *logger.Logger) (connRequest *chshare.ConnectionRequest, r *ssh.Request, err error) {
pendingRequestTimer := time.NewTimer(ConnectionRequestTimeOut)
select {
case r = <-reqs:
case <-time.After(SSHTimeOut):
clog.Debugf("SSH connection timeout exceeded %s sec", SSHTimeOut.Seconds())
pendingRequestTimer.Stop()
case <-time.After(ConnectionRequestTimeOut):
clog.Debugf("SSH connection request timeout exceeded %s sec", ConnectionRequestTimeOut.Seconds())
err = sshConn.Close()
if err != nil {
clog.Debugf("error on SSH connection close: %s", err)
}
return
}
failed := func(err error) {
clog.Debugf("Failed: %s", err)
cl.replyConnectionError(r, err)
return nil, nil, err
}
if r.Type != "new_connection" {
failed(errors.New("expecting connection request"))
return
return nil, nil, errors.New("expecting connection request")
}
if len(r.Payload) > int(cl.config.Server.MaxRequestBytesClient) {
failed(fmt.Errorf("request data exceeds the limit of %d bytes, actual size: %d", cl.config.Server.MaxRequestBytesClient, len(r.Payload)))
return
if len(r.Payload) > int(cl.server.config.Server.MaxRequestBytesClient) {
return nil, nil, fmt.Errorf("request data exceeds the limit of %d bytes, actual size: %d", cl.server.config.Server.MaxRequestBytesClient, len(r.Payload))
}
connRequest, err := chshare.DecodeConnectionRequest(r.Payload)
connRequest, err = chshare.DecodeConnectionRequest(r.Payload)
if err != nil {
failed(fmt.Errorf("invalid connection request: %s", err))
return nil, nil, fmt.Errorf("invalid connection request: %s", err)
}
return connRequest, r, nil
}
// handleWebsocket is responsible for handling the websocket connection
func (cl *ClientListener) handleWebsocket(w http.ResponseWriter, req *http.Request) {
// keep the time from the initial client connection attempt
ts1 := time.Now()
sshConn, chans, reqs, clog, err := cl.acceptSSHConnection(w, req)
if err != nil {
return
}
// verify configuration
clog.Debugf("Verifying configuration...")
// first request to be received must be a connection request
connRequest, r, err := cl.receiveClientConnectionRequest(sshConn, reqs, clog)
if err != nil {
cl.replyConnectionError(r, err)
return
}
@ -260,40 +320,46 @@ func (cl *ClientListener) handleWebsocket(w http.ResponseWriter, req *http.Reque
// get the current client auth id
clientAuthID := sshConn.User()
// client id
cid, err := cl.getCID(connRequest.ID, cl.config, clientAuthID)
clientID, err := cl.getClientID(connRequest.ID, cl.server.config, clientAuthID)
if err != nil {
failed(fmt.Errorf("could not get cid: %s", err))
cl.replyConnectionError(r, fmt.Errorf("could not get clientID: %s", err))
return
}
// TODO: (rs): shouldn't this ctx be based on the server start ctx?
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ts = time.Now()
client, err := cl.clientService.StartClient(ctx, clientAuthID, cid, sshConn, cl.config.Server.AuthMultiuseCreds, connRequest, clog)
client, err := cl.GetClientService().StartClient(ctx, clientAuthID, clientID, sshConn, cl.server.config.Server.AuthMultiuseCreds, connRequest, clog)
if err != nil {
failed(err)
cl.replyConnectionError(r, err)
return
}
clog.Debugf("Client service started for %s (%s) within %s", client.ID, client.Name, time.Since(ts))
clog.Debugf("Client service started for %s (%s) within %s", client.GetID(), client.GetName(), time.Since(ts1))
ts2 := time.Now()
cl.replyConnectionSuccess(r, connRequest.Remotes)
cl.sendCapabilities(sshConn)
// Now the client is fully connected and ready to create tunnels and execute command and scripts
clientBanner := client.Banner()
clog.Debugf("opened %s within %s", clientBanner, time.Since(ts))
go cl.handleSSHRequests(clog, cid, reqs)
clog.Debugf("opened %s within %s", clientBanner, time.Since(ts2))
// now run handler for other client requests and connections
// TODO: (rs): shouldn't these also use the server start ctx
go cl.handleSSHRequests(clog, clientID, reqs)
go cl.handleSSHChannels(clog, chans)
// wait until we're disconnected from the client
if err = sshConn.Wait(); err != nil {
clog.Debugf("sshConn.Wait() error: %s", err)
}
clog.Debugf("close %s", clientBanner)
err = cl.clientService.Terminate(client)
err = cl.GetClientService().Terminate(client)
if err != nil {
cl.Errorf("could not terminate client: %s", err)
cl.Log().Errorf("could not terminate client: %s", err)
}
}
@ -311,7 +377,7 @@ func checkVersions(log *logger.Logger, clientVersion string) {
log.Infof("Client version (%s) differs from server version (%s)", v, chshare.BuildVersion)
}
func (cl *ClientListener) getCID(reqID string, config *chconfig.Config, clientAuthID string) (string, error) {
func (cl *ClientListener) getClientID(reqID string, config *chconfig.Config, clientAuthID string) (string, error) {
if reqID != "" {
return reqID, nil
}
@ -327,7 +393,7 @@ func (cl *ClientListener) getCID(reqID string, config *chconfig.Config, clientAu
func (cl *ClientListener) replyConnectionSuccess(r *ssh.Request, remotes []*models.Remote) {
replyPayload, err := json.Marshal(remotes)
if err != nil {
cl.Errorf("can't encode success reply payload")
cl.Log().Errorf("can't encode success reply payload")
cl.replyConnectionError(r, err)
return
}
@ -336,24 +402,68 @@ func (cl *ClientListener) replyConnectionSuccess(r *ssh.Request, remotes []*mode
}
func (cl *ClientListener) replyConnectionError(r *ssh.Request, err error) {
if r == nil {
cl.Log().Errorf("failed to send connection reply with error due to nil request")
return
}
if err == nil {
cl.Log().Debugf("sending connection reply with nil error: %s", r.Type)
_ = r.Reply(false, nil)
return
}
_ = r.Reply(false, []byte(err.Error()))
}
func (cl *ClientListener) handleSSHRequests(clientLog *logger.Logger, clientID string, reqs <-chan *ssh.Request) {
clientService := cl.GetClientService()
for r := range reqs {
if len(r.Payload) > int(cl.config.Server.MaxRequestBytesClient) {
clientLog.Errorf("%s:request data exceeds the limit of %d bytes, actual size: %d", comm.RequestTypeSaveMeasurement, cl.config.Server.MaxRequestBytesClient, len(r.Payload))
if len(r.Payload) > int(cl.server.config.Server.MaxRequestBytesClient) {
clientLog.Errorf("%s:request data exceeds the limit of %d bytes, actual size: %d", comm.RequestTypeSaveMeasurement, cl.server.config.Server.MaxRequestBytesClient, len(r.Payload))
continue
}
// clientLog.Debugf("received request: %s from %s", r.Type, clientID)
// TODO: (rs): these case handlers should be refactored into individual handling fns
switch r.Type {
// we shouldn't be receiving this. it means the client didn't receive the server's reply
// to the previous connection request, so ask the client to reconnect.
case "new_connection":
clientLog.Debugf("received connection request on existing connection. asking the client to reconnect.")
// IMPORTANT: the client is checking for the word "reconnect" in reply errors
cl.replyConnectionError(r, errors.New("unexpected connection request. please reconnect"))
client, err := clientService.GetRepo().GetActiveByID(clientID)
if err != nil {
clientLog.Debugf("unable to get client: %v", err)
continue
}
if client == nil {
clientLog.Debugf("client not found: %v", err)
continue
}
clientLog.Debugf("terminating client due for reconnect")
err = clientService.Terminate(client)
if err != nil {
clientLog.Debugf("failed to terminate client due for reconnect: %v", err)
}
continue
case comm.RequestTypePing:
// clientLog.Debugf("ping received from: %s", clientID)
// ts := time.Now()
_ = r.Reply(true, nil)
err := cl.clientService.SetLastHeartbeat(clientID, time.Now())
err := clientService.SetLastHeartbeat(clientID, time.Now())
if err != nil {
clientLog.Errorf("Failed to save heartbeat: %s", err)
continue
}
// clientLog.Debugf("ping for: %s done in %s", clientID, time.Since(ts))
case comm.RequestTypeCmdResult:
clientLog.Debugf("saving command result from: %s", clientID)
job, err := cl.saveCmdResult(r.Payload)
if err != nil {
clientLog.Errorf("Failed to save cmd result: %s", err)
@ -363,9 +473,9 @@ func (cl *ClientListener) handleSSHRequests(clientLog *logger.Logger, clientID s
var auditLogEntry *auditlog.Entry
if job.IsScript {
auditLogEntry = cl.auditLog.Entry(auditlog.ApplicationClientScript, auditlog.ActionExecuteDone)
auditLogEntry = cl.server.auditLog.Entry(auditlog.ApplicationClientScript, auditlog.ActionExecuteDone)
} else {
auditLogEntry = cl.auditLog.Entry(auditlog.ApplicationClientCommand, auditlog.ActionExecuteDone)
auditLogEntry = cl.server.auditLog.Entry(auditlog.ApplicationClientCommand, auditlog.ActionExecuteDone)
}
if job.MultiJobID != nil {
auditLogEntry.WithID(*job.MultiJobID)
@ -378,7 +488,7 @@ func (cl *ClientListener) handleSSHRequests(clientLog *logger.Logger, clientID s
Save()
if job.MultiJobID != nil {
done := cl.jobsDoneChannel.Get(*job.MultiJobID)
done := cl.server.jobsDoneChannel.Get(*job.MultiJobID)
if done != nil {
// to avoid blocking the exec - send job result in a new goroutine
go func(done2 chan *models.Job, job2 *models.Job) {
@ -386,21 +496,24 @@ func (cl *ClientListener) handleSSHRequests(clientLog *logger.Logger, clientID s
}(done, job)
}
}
case comm.RequestTypeUpdatesStatus:
clientLog.Debugf("setting updates status from: %s", clientID)
updatesStatus := &models.UpdatesStatus{}
err := json.Unmarshal(r.Payload, updatesStatus)
if err != nil {
clientLog.Errorf("Failed to unmarshal updates status: %s", err)
continue
}
err = cl.clientService.SetUpdatesStatus(clientID, updatesStatus)
err = clientService.SetUpdatesStatus(clientID, updatesStatus)
if err != nil {
clientLog.Errorf("Failed to save updates status: %s", err)
continue
}
case comm.RequestTypeSaveMeasurement:
// if server monitoring is disabled then do not save measurements even if received
if !cl.Server.config.Monitoring.Enabled {
if !cl.server.config.Monitoring.Enabled {
clientLog.Errorf("Received measurement when monitoring disabled. Measurement not saved.")
continue
}
@ -412,7 +525,7 @@ func (cl *ClientListener) handleSSHRequests(clientLog *logger.Logger, clientID s
continue
}
measurement.ClientID = clientID
err = cl.monitoringService.SaveMeasurement(context.Background(), measurement)
err = cl.server.monitoringService.SaveMeasurement(context.Background(), measurement)
if err != nil {
clientLog.Errorf("Failed to save measurement for client %s: %s", clientID, err)
continue
@ -437,18 +550,18 @@ func (cl *ClientListener) saveCmdResult(respBytes []byte) (*models.Job, error) {
} else {
wsJID = resp.JID
}
ws := cl.Server.uiJobWebSockets.Get(wsJID)
ws := cl.server.uiJobWebSockets.Get(wsJID)
if ws != nil {
err := ws.WriteMessage(websocket.TextMessage, respBytes)
if err != nil {
cl.Errorf("%s, failed to write message to UI Web Socket: %v", resp.LogPrefix(), err)
cl.Log().Errorf("%s, failed to write message to UI Web Socket: %v", resp.LogPrefix(), err)
// proceed further
}
} else {
cl.Debugf("%s, WS conn not found when saving command result. No active listeners connected", resp.LogPrefix())
cl.Log().Debugf("%s, WS conn not found when saving command result. No active listeners connected", resp.LogPrefix())
}
err = cl.jobProvider.SaveJob(&resp)
err = cl.server.jobProvider.SaveJob(&resp)
if err != nil {
return nil, fmt.Errorf("failed to save job result: %s", err)
}
@ -458,6 +571,7 @@ func (cl *ClientListener) saveCmdResult(respBytes []byte) (*models.Job, error) {
func (cl *ClientListener) handleSSHChannels(clientLog *logger.Logger, chans <-chan ssh.NewChannel) {
for ch := range chans {
ch := ch
extraData := string(ch.ExtraData())
stream, reqs, err := ch.Accept()
if err != nil {
@ -510,7 +624,7 @@ func (cl *ClientListener) handleOutputChannel(typ string, jobData []byte, client
wsJID = job.JID
}
ws := cl.Server.uiJobWebSockets.Get(wsJID)
ws := cl.server.uiJobWebSockets.Get(wsJID)
ocd := outputChannelData{
JID: job.JID,
@ -591,13 +705,13 @@ func (cl *ClientListener) handleSessionChannel(stream ssh.Channel, clientLog *lo
}
func (cl *ClientListener) sendCapabilities(conn *ssh.ServerConn) {
payload, err := json.Marshal(cl.Server.capabilities)
payload, err := json.Marshal(cl.server.capabilities)
if err != nil {
cl.Errorf("can't encode capabilities payload")
cl.Log().Errorf("can't encode capabilities payload")
return
}
if _, _, err = conn.SendRequest(comm.RequestTypePutCapabilities, false, payload); err != nil {
cl.Errorf("can't send capabilities: %v", err)
cl.Log().Errorf("can't send capabilities: %v", err)
}
}

View File

@ -18,10 +18,10 @@ import (
func TestHandleOutputChannel(t *testing.T) {
log := logger.NewLogger("client-listener-test", logger.LogOutput{File: os.Stdout}, logger.LogLevelDebug)
cl := &ClientListener{Server: &Server{uiJobWebSockets: ws.NewWebSocketCache()}}
cl := &ClientListener{server: &Server{uiJobWebSockets: ws.NewWebSocketCache()}}
mockConn := &connMock{}
ws := ws.NewConcurrentWebSocket(mockConn, log)
cl.Server.uiJobWebSockets.Set("test-jid", ws)
cl.server.uiJobWebSockets.Set("test-jid", ws)
testCases := []struct {
Name string

View File

@ -9,10 +9,12 @@ import (
"github.com/cloudradar-monitoring/rport/share/logger"
)
const DefaultMaxWorkers = 100
type ClientsStatusCheckTask struct {
log *logger.Logger
cr *clients.ClientRepository
th time.Duration // Threshold after which a client to server ping is considered outdated.
clientRepo *clients.ClientRepository
threshold time.Duration // Threshold after which a client to server ping is considered outdated.
pingTimeout time.Duration // Don't wait longer than pingTimeout for a response
}
@ -20,47 +22,52 @@ type ClientsStatusCheckTask struct {
func NewClientsStatusCheckTask(log *logger.Logger, cr *clients.ClientRepository, th time.Duration, pingTimeout time.Duration) *ClientsStatusCheckTask {
return &ClientsStatusCheckTask{
log: log.Fork("clients-status-check"),
cr: cr,
th: th,
clientRepo: cr,
threshold: th,
pingTimeout: pingTimeout,
}
}
func (t *ClientsStatusCheckTask) Run(ctx context.Context) error {
t.log.Debugf("running")
timerStart := time.Now()
var dueClients []*clients.Client
var confirmedClients = 0
var now = time.Now()
for _, c := range t.cr.GetAllActive() {
// Shorten the threshold aka make heartbeat older than it is because the ping response is stored after this check.
// Clients would get checked only every second time otherwise.
if c.LastHeartbeatAt != nil && now.Sub(*c.LastHeartbeatAt) < t.th-10*time.Second {
// Skip all clients having sent a heartbeat from client to server recently
confirmedClients++
continue
}
dueClients = append(dueClients, c)
}
dueClients := t.getDueClients()
if len(dueClients) == 0 {
// Nothing to do
t.log.Debugf("ended after %s, no clients to ping", time.Since(timerStart))
return nil
}
maxWorkers := 100
// make sure no more workers than clients and limit to max workers
maxWorkers := DefaultMaxWorkers
if maxWorkers > len(dueClients) {
maxWorkers = len(dueClients)
}
// make a channel that will receive all the clients to ping
clientsToPing := make(chan *clients.Client, len(dueClients))
results := make(chan bool, len(dueClients))
// create workers to ping clients
for w := 1; w <= maxWorkers; w++ {
go t.PingClients(clientsToPing, results)
go t.PingClients(w, clientsToPing, results)
}
// send the clients to ping to the workers
for _, dueClient := range dueClients {
clientsToPing <- dueClient
}
// we're done queuing clients for processing, so close the channel
close(clientsToPing)
// gather the results of pinged clients
var dead = 0
var alive = 0
// TODO: (rs): note this is fragile. any mismatch between actual and expected results will cause
// the task to block and essential hang. also there's no ctx checking.
for a := 0; a < len(dueClients); a++ {
if <-results {
alive++
@ -68,30 +75,73 @@ func (t *ClientsStatusCheckTask) Run(ctx context.Context) error {
dead++
}
}
t.log.Debugf("ended after %s, skipped: %d, pinged: %d, alive: %d, dead: %d", time.Since(timerStart), confirmedClients, len(dueClients), alive, dead)
return nil
}
func (t *ClientsStatusCheckTask) PingClients(clientsToPing <-chan *clients.Client, results chan<- bool) {
func (t *ClientsStatusCheckTask) getDueClients() (dueClients []*clients.Client) {
var confirmedClients = 0
var now = time.Now()
activeClients, _ := t.clientRepo.GetAllActiveClients()
for _, c := range activeClients {
// Shorten the threshold aka make heartbeat older than it is because the ping response is stored after this check.
// Clients would get checked only every second time otherwise.
if c.HasLastHeartbeatAt() {
lastHeartbeatAt := c.GetLastHeartbeatAtValue()
if now.Sub(lastHeartbeatAt) < t.threshold-(10*time.Second) {
// Skip all clients having sent a heartbeat from client to server recently
// t.log.Debugf("skipping client: %s, %s, %s", c.GetID(), lastHeartbeatAt, now.Sub(lastHeartbeatAt) < t.threshold-(10*time.Second))
confirmedClients++
continue
}
}
dueClients = append(dueClients, c)
}
return dueClients
}
func (t *ClientsStatusCheckTask) PingClients(workerNum int, clientsToPing <-chan *clients.Client, results chan<- bool) {
// while there are clients to ping
for cl := range clientsToPing {
ok, response, rtt, err := comm.PingConnectionWithTimeout(cl.Connection, t.pingTimeout)
clientName := cl.GetName()
clientID := cl.GetID()
ok, response, rtt, err := comm.PingConnectionWithTimeout(cl.GetConnection(), t.pingTimeout, cl.Log())
//t.log.Debugf("ok=%s, error=%s, response=%s", ok, err, response)
//Old clients cannot respond properly to a ping request yet
if !ok && err == nil && string(response) == "unknown request" {
t.log.Debugf("ping to %s [%s] succeeded in %s. client < 0.8.2", cl.Name, cl.ID, rtt)
// Old clients cannot respond properly to a ping request yet
if !ok && err == nil && t.isLegacyClientResponse(response) {
t.log.Debugf("ping to %s [%s] succeeded in %s. client < 0.8.2", clientName, clientID, rtt)
cl.SetHeartbeatNow()
results <- true
continue
}
// client versions from 0.9.2 to 0.9.6 can return "null" as a ping response. this is due to a bug
// in the client ping handling that cause 2 replies to be sent by the client. this breaks stuff.
// for the server, assume the null reply is a successful ping. unfortunately, the extra reply
// confuses the next send by the client which means it won't get a reply from the server and
// will ultimately disconnect and reconnect. the work around is to make sure that the client
// pings the server faster than the server pings the client. as the server has a recent heartbeat
// already (from the client) it won't ping the client again, meaning that the client won't get a
// chance to double reply to the server and cause ssh protocol confusion.
if ok && err == nil && string(response) == "null" {
t.log.Debugf("ping to %s [%s] succeeded in %s. client >= 0.8.2 *", clientName, clientID, rtt)
cl.SetHeartbeatNow()
results <- true
continue
}
// Only an empty response confirms the ping
if ok && err == nil && len(response) == 0 {
t.log.Debugf("ping to %s [%s] succeeded in %s. client >= 0.8.2", cl.Name, cl.ID, rtt)
t.log.Debugf("ping to %s [%s] succeeded in %s. client >= 0.8.2", clientName, clientID, rtt)
cl.SetHeartbeatNow()
results <- true
continue
}
// None of the above. Ping must have failed or timed out.
t.log.Infof("ping to %s [%s] failed: %s", cl.Name, cl.ID, err)
t.log.Infof("ping to %s [%s] failed: %s", clientName, clientID, err)
cl.SetDisconnectedNow()
@ -99,3 +149,7 @@ func (t *ClientsStatusCheckTask) PingClients(clientsToPing <-chan *clients.Clien
results <- false
}
}
func (t *ClientsStatusCheckTask) isLegacyClientResponse(response []byte) (isLegacy bool) {
return string(response) == "unknown request"
}

View File

@ -65,58 +65,58 @@ func TestClientsStatusDeterminationTask(t *testing.T) {
timeout = 1 * time.Millisecond
)
now := time.Now()
cr := clients.NewClientRepository([]*clients.Client{
{
ID: "1",
ClientAuthID: "1",
Connection: connSuccess,
},
{
ID: "2",
ClientAuthID: "2",
Connection: connSuccess,
LastHeartbeatAt: &now,
},
{
ID: "3",
ClientAuthID: "3",
Connection: connFailure,
},
{
ID: "4",
ClientAuthID: "4",
Connection: connTimeout,
},
}, nil, myTestLog)
c1 := clients.Client{}
c1.SetID("1")
c1.SetClientAuthID("1")
c1.SetConnection(connSuccess)
c2 := clients.Client{}
c2.SetID("2")
c2.SetClientAuthID("2")
c2.SetConnection(connSuccess)
c2.SetLastHeartbeatAt(&now)
c3 := clients.Client{}
c3.SetID("3")
c3.SetClientAuthID("3")
c3.SetConnection(connFailure)
c4 := clients.Client{}
c4.SetID("4")
c4.SetClientAuthID("4")
c4.SetConnection(connTimeout)
cr := clients.NewClientRepository([]*clients.Client{&c1, &c2, &c3, &c4}, nil, myTestLog)
task := NewClientsStatusCheckTask(myTestLog, cr, 120*time.Second, timeout)
// Check the last heartbeat of c1 has changed due to the ping sent
err = task.Run(context.Background())
assert.NoError(t, err)
c1, err := cr.GetByID("1")
tcl1, err := cr.GetByID("1")
assert.NoError(t, err)
assert.IsType(t, &time.Time{}, c1.LastHeartbeatAt)
t.Logf("c1: LastHeartbeatAt: %s", c1.LastHeartbeatAt)
assert.IsType(t, &time.Time{}, tcl1.GetLastHeartbeatAt())
t.Logf("tcl1: LastHeartbeatAt: %s", tcl1.GetLastHeartbeatAt())
// Check the last heartbeat of c2 has not changed because the task must skip this client
c2, err := cr.GetByID("2")
tcl2, err := cr.GetByID("2")
assert.NoError(t, err)
assert.Equal(t, &now, c2.LastHeartbeatAt, "LastHeartbeatAt of c2 must not change")
t.Logf("c2: LastHeartbeatAt: %s", c2.LastHeartbeatAt)
assert.Equal(t, &now, tcl2.GetLastHeartbeatAt(), "LastHeartbeatAt of tcl2 must not change")
t.Logf("tcl2: LastHeartbeatAt: %s", tcl2.GetLastHeartbeatAt())
// Check the status of c3 changed to disconnected
c3, err := cr.GetByID("3")
tcl3, err := cr.GetByID("3")
assert.NoError(t, err)
assert.NotNil(t, c3.DisconnectedAt)
assert.Equal(t, "disconnected", string(c3.CalculateConnectionState()))
t.Logf("c3: DisconnectedAt: %s", c3.DisconnectedAt)
assert.NotNil(t, tcl3.GetDisconnectedAt())
assert.Equal(t, "disconnected", string(tcl3.CalculateConnectionState()))
t.Logf("tcl3: GetDisconnectedAt(): %s", tcl3.GetDisconnectedAt())
// Check the status of c4 changed to disconnected caused by a timeout
c4, err := cr.GetByID("4")
tcl4, err := cr.GetByID("4")
assert.NoError(t, err)
assert.NotNil(t, c4.DisconnectedAt)
assert.Equal(t, "disconnected", string(c4.CalculateConnectionState()))
t.Logf("c4: DisconnectedAt: %s", c4.DisconnectedAt)
assert.NotNil(t, tcl4.GetDisconnectedAt())
assert.Equal(t, "disconnected", string(tcl4.CalculateConnectionState()))
t.Logf("tcl4: DisconnectedAt: %s", tcl4.GetDisconnectedAt())
log, err := os.ReadFile(logfile)
assert.NoError(t, err, "error reading log file")
assert.Contains(t, string(log), fmt.Sprintf("ping to [4] failed: conn.SendRequest(ping), timeout %s exceeded", timeout))

View File

@ -12,17 +12,20 @@ import (
func TestCleanup(t *testing.T) {
// given
ctx := context.Background()
c1 := New(t).Build() // active
c2 := New(t).DisconnectedDuration(5 * time.Minute).Build() // disconnected
c3 := New(t).DisconnectedDuration(time.Hour + time.Minute).Build() // obsolete
c1 := New(t).ID("client-1").Logger(testLog).Build() // active
c2 := New(t).ID("client-2").DisconnectedDuration(5 * time.Minute).Logger(testLog).Build() // disconnected
c3 := New(t).ID("client-3").DisconnectedDuration(time.Hour + time.Minute).Logger(testLog).Build() // obsolete
clients := []*Client{c1, c2, c3}
p := NewFakeClientProvider(t, &hour, c1, c2, c3)
defer p.Close()
clientsRepo := NewClientRepositoryWithDB(clients, &hour, p, testLog)
require.Len(t, clientsRepo.clients, 3)
gotObsolete, err := p.get(ctx, c3.ID)
require.Len(t, clientsRepo.clientState, 3)
gotObsolete, err := p.get(ctx, c3.GetID())
require.NoError(t, err)
c3.Logger = nil
// patch in the logger for the client
gotObsolete.Logger = testLog
require.EqualValues(t, c3, gotObsolete)
task := NewCleanupTask(testLog, clientsRepo)
@ -31,11 +34,16 @@ func TestCleanup(t *testing.T) {
// then
assert.NoError(t, err)
assert.ElementsMatch(t, getValues(clientsRepo.clients), []*Client{c1, c2})
assert.ElementsMatch(t, getValues(clientsRepo.clientState), []*Client{c1, c2})
gotClients, err := p.GetAll(ctx)
assert.NoError(t, err)
// patch in the logger for the clients
gotClients[0].Logger = testLog
gotClients[1].Logger = testLog
assert.ElementsMatch(t, []*Client{c1, c2}, gotClients)
gotObsolete, err = p.get(ctx, c3.ID)
gotObsolete, err = p.get(ctx, c3.GetID())
require.NoError(t, err)
require.Nil(t, gotObsolete)
}
@ -43,14 +51,14 @@ func TestCleanup(t *testing.T) {
func TestCleanupDisabled(t *testing.T) {
// given
ctx := context.Background()
c1 := New(t).Build() // active
c2 := New(t).DisconnectedDuration(5 * time.Minute).Build() // disconnected
c3 := New(t).DisconnectedDuration(365*24*time.Hour + time.Minute).Build() // disconnected longer
c1 := New(t).Logger(testLog).Build() // active
c2 := New(t).DisconnectedDuration(5 * time.Minute).Logger(testLog).Build() // disconnected
c3 := New(t).DisconnectedDuration(365*24*time.Hour + time.Minute).Logger(testLog).Build() // disconnected longer
clients := []*Client{c1, c2, c3}
p := NewFakeClientProvider(t, nil, c1, c2, c3)
defer p.Close()
clientsRepo := NewClientRepositoryWithDB(clients, nil, p, testLog)
require.Len(t, clientsRepo.clients, 3)
require.Len(t, clientsRepo.clientState, 3)
task := NewCleanupTask(testLog, clientsRepo)
@ -59,7 +67,7 @@ func TestCleanupDisabled(t *testing.T) {
// then
assert.NoError(t, err)
assert.ElementsMatch(t, getValues(clientsRepo.clients), []*Client{c1, c2, c3})
assert.ElementsMatch(t, getValues(clientsRepo.clientState), []*Client{c1, c2, c3})
}
func getValues(clients map[string]*Client) []*Client {

View File

@ -12,6 +12,7 @@ import (
"github.com/cloudradar-monitoring/rport/server/api/users"
"github.com/cloudradar-monitoring/rport/server/cgroups"
"github.com/cloudradar-monitoring/rport/server/clients/clienttunnel"
chshare "github.com/cloudradar-monitoring/rport/share"
"github.com/cloudradar-monitoring/rport/share/clientconfig"
"github.com/cloudradar-monitoring/rport/share/logger"
"github.com/cloudradar-monitoring/rport/share/models"
@ -58,6 +59,7 @@ type Client struct {
Version string `json:"version"`
Address string `json:"address"`
Tunnels []*clienttunnel.Tunnel `json:"tunnels"`
// DisconnectedAt is a time when a client was disconnected. If nil - it's connected.
DisconnectedAt *time.Time `json:"disconnected_at"`
LastHeartbeatAt *time.Time `json:"last_heartbeat_at"`
@ -67,11 +69,13 @@ type Client struct {
ClientConfiguration *clientconfig.Config `json:"client_configuration"`
Connection ssh.Conn `json:"-"`
Logger *logger.Logger `json:"-"`
Context context.Context `json:"-"`
Paused bool `json:"-"`
PausedReason string `json:"-"`
lock sync.Mutex
Logger *logger.Logger `json:"-"`
flock sync.RWMutex
}
// CalculatedClient contains additional fields and is calculated on each request
@ -81,47 +85,295 @@ type CalculatedClient struct {
ConnectionState ConnectionState `json:"connection_state"`
}
func NewCalculatedClient(c *Client, groups []string, connectionState ConnectionState) (cc *CalculatedClient) {
cc = &CalculatedClient{}
cc.Client = c
cc.Groups = groups
cc.ConnectionState = connectionState
return cc
}
func (cc *CalculatedClient) GetConnectionState() (cs ConnectionState) {
return cc.ConnectionState
}
func (c *Client) GetID() (id string) {
c.flock.RLock()
defer c.flock.RUnlock()
return c.ID
}
func (c *Client) GetName() (name string) {
c.flock.RLock()
defer c.flock.RUnlock()
return c.Name
}
func (c *Client) GetSessionID() (sessionID string) {
c.flock.RLock()
defer c.flock.RUnlock()
return c.SessionID
}
func (c *Client) GetOS() (os string) {
c.flock.RLock()
defer c.flock.RUnlock()
return c.OS
}
func (c *Client) GetHostname() (hostname string) {
c.flock.RLock()
defer c.flock.RUnlock()
return c.Hostname
}
func (c *Client) GetTags() (tags []string) {
c.flock.RLock()
defer c.flock.RUnlock()
if c.Tags == nil {
return nil
}
// make sure not to return reference to underlying array
tags = make([]string, len(c.Tags))
copy(tags, c.Tags)
return tags
}
func (c *Client) GetAllowedUserGroups() (groups []string) {
c.flock.RLock()
defer c.flock.RUnlock()
if c.AllowedUserGroups == nil {
return nil
}
// make sure not to return reference to underlying array
groups = make([]string, len(c.AllowedUserGroups))
copy(groups, c.AllowedUserGroups)
return groups
}
func (c *Client) GetVersion() (version string) {
c.flock.RLock()
defer c.flock.RUnlock()
return c.Version
}
func (c *Client) GetDisconnectedAt() (at *time.Time) {
c.flock.RLock()
defer c.flock.RUnlock()
return c.DisconnectedAt
}
func (c *Client) GetDisconnectedAtValue() (at time.Time) {
c.flock.RLock()
if c.DisconnectedAt != nil {
at = *c.DisconnectedAt
}
c.flock.RUnlock()
return at
}
func (c *Client) HasLastHeartbeatAt() (has bool) {
c.flock.RLock()
defer c.flock.RUnlock()
return c.LastHeartbeatAt != nil
}
func (c *Client) GetLastHeartbeatAt() (at *time.Time) {
c.flock.RLock()
defer c.flock.RUnlock()
return c.LastHeartbeatAt
}
func (c *Client) GetLastHeartbeatAtValue() (at time.Time) {
c.flock.RLock()
if c.LastHeartbeatAt != nil {
at = *c.LastHeartbeatAt
}
c.flock.RUnlock()
return at
}
func (c *Client) GetConnection() (conn ssh.Conn) {
c.flock.RLock()
defer c.flock.RUnlock()
return c.Connection
}
func (c *Client) GetPausedReason() (reason string) {
c.flock.RLock()
defer c.flock.RUnlock()
return c.PausedReason
}
func (c *Client) GetContext() (ctx context.Context) {
c.flock.RLock()
defer c.flock.RUnlock()
return c.Context
}
func (c *Client) GetClientAuthID() (authID string) {
c.flock.RLock()
defer c.flock.RUnlock()
return c.ClientAuthID
}
func (c *Client) GetTunnels() (tunnels []*clienttunnel.Tunnel) {
c.flock.RLock()
defer c.flock.RUnlock()
return c.Tunnels
}
func (c *Client) Log() (l *logger.Logger) {
c.flock.RLock()
defer c.flock.RUnlock()
return c.Logger
}
func (c *Client) IsPaused() (paused bool) {
paused = c.Paused
return paused
c.flock.RLock()
defer c.flock.RUnlock()
return c.Paused
}
func (c *Client) GetMonitoringConfig() (monitoringConfig *clientconfig.MonitoringConfig) {
c.flock.RLock()
defer c.flock.RUnlock()
if c.ClientConfiguration == nil {
return nil
}
return &c.ClientConfiguration.Monitoring
}
func (c *Client) GetFileReceptionConfig() (fileReceptionConfig *clientconfig.FileReceptionConfig) {
c.flock.RLock()
defer c.flock.RUnlock()
if c.ClientConfiguration == nil {
return nil
}
return &c.ClientConfiguration.FileReceptionConfig
}
// test only
func (c *Client) SetID(id string) {
c.flock.Lock()
c.ID = id
c.flock.Unlock()
}
// test only
func (c *Client) SetAddress(address string) {
c.flock.Lock()
c.Address = address
c.flock.Unlock()
}
// test only
func (c *Client) SetHostname(hostname string) {
c.flock.Lock()
c.Hostname = hostname
c.flock.Unlock()
}
// test only
func (c *Client) SetClientAuthID(authID string) {
c.flock.Lock()
c.ClientAuthID = authID
c.flock.Unlock()
}
// test only
func (c *Client) SetTags(tags []string) {
if c.Tags == nil {
return
}
c.flock.Lock()
// make sure not to just copy the tag reference
c.Tags = make([]string, len(c.Tags))
copy(c.Tags, tags)
c.flock.Unlock()
}
// test only
func (c *Client) SetConnection(conn ssh.Conn) {
c.flock.Lock()
c.Connection = conn
c.flock.Unlock()
}
func (c *Client) SetTunnels(tunnels []*clienttunnel.Tunnel) {
c.flock.Lock()
c.Tunnels = tunnels
c.flock.Unlock()
}
func (c *Client) SetAllowedUserGroups(groups []string) {
c.flock.Lock()
c.AllowedUserGroups = groups
c.flock.Unlock()
}
func (c *Client) SetUpdatesStatus(status *models.UpdatesStatus) {
c.flock.Lock()
c.UpdatesStatus = status
c.flock.Unlock()
}
func (c *Client) SetDisconnectedAt(at *time.Time) {
l := c.Log()
if l != nil && at != nil {
l.Debugf("%s: set to disconnected at %s", c.GetID(), at)
}
c.flock.Lock()
c.DisconnectedAt = at
c.flock.Unlock()
}
func (c *Client) SetLastHeartbeatAt(at *time.Time) {
c.SetDisconnectedAt(nil)
c.flock.Lock()
c.LastHeartbeatAt = at
c.flock.Unlock()
}
const PausedDueToMaxClientsExceeded = "unlicensed"
func (c *Client) SetPaused(paused bool, reason string) {
c.flock.Lock()
c.Paused = paused
c.PausedReason = reason
c.flock.Unlock()
if paused {
c.Logger.Infof("client %s is paused (reason = %s)", c.ID, c.PausedReason)
c.Log().Infof("client %s is paused (reason = %s)", c.GetID(), reason)
}
}
func (c *Client) IsConnected() bool {
return c.GetDisconnectedAt() == nil
}
func (c *Client) SetConnected() {
c.Logger.Debugf("%s: set to connected at %s", c.ID, time.Now())
c.DisconnectedAt = nil
}
func (c *Client) SetDisconnected(at *time.Time) {
if c.Logger != nil {
c.Logger.Debugf("%s: set to disconnected at %s", c.ID, at)
}
c.DisconnectedAt = at
c.Log().Debugf("%s: set to connected at %s", c.GetID(), time.Now())
c.SetDisconnectedAt(nil)
}
func (c *Client) SetDisconnectedNow() {
now := time.Now()
c.DisconnectedAt = &now
c.SetDisconnectedAt(&now)
}
func (c *Client) SetHeartbeatNow() {
now := time.Now()
c.DisconnectedAt = nil
c.LastHeartbeatAt = &now
}
func (c *Client) SetHeartbeat(at *time.Time) {
c.DisconnectedAt = nil
c.LastHeartbeatAt = at
c.SetLastHeartbeatAt(&now)
c.SetDisconnectedAt(nil)
}
func (c *Client) ToCalculated(allGroups []*cgroups.ClientGroup) *CalculatedClient {
@ -131,26 +383,16 @@ func (c *Client) ToCalculated(allGroups []*cgroups.ClientGroup) *CalculatedClien
clientGroups = append(clientGroups, group.ID)
}
}
return &CalculatedClient{
Client: c,
Groups: clientGroups,
ConnectionState: c.CalculateConnectionState(),
}
return NewCalculatedClient(c, clientGroups, c.CalculateConnectionState())
}
// Obsolete returns true if a given client was disconnected longer than a given duration.
// If a given duration is nil - returns false (never obsolete).
func (c *Client) Obsolete(duration *time.Duration) bool {
return duration != nil && c.DisconnectedAt != nil &&
c.DisconnectedAt.Add(*duration).Before(now())
}
func (c *Client) Lock() {
c.lock.Lock()
}
func (c *Client) Unlock() {
c.lock.Unlock()
disconnectedAt := c.GetDisconnectedAt()
return duration != nil && disconnectedAt != nil &&
disconnectedAt.Add(*duration).Before(now())
}
func (c *Client) NewTunnelID() (tunnelID string) {
@ -163,31 +405,37 @@ func (c *Client) generateNewTunnelID() int64 {
}
func (c *Client) RemoveTunnelByID(tunnelID string) {
result := make([]*clienttunnel.Tunnel, 0)
for _, curr := range c.Tunnels {
if curr.ID != tunnelID {
result = append(result, curr)
updatedTunnelList := make([]*clienttunnel.Tunnel, 0)
// TODO: (rs): not thread-safe
for _, tunnel := range c.GetTunnels() {
if tunnel.ID != tunnelID {
updatedTunnelList = append(updatedTunnelList, tunnel)
}
}
c.Tunnels = result
c.SetTunnels(updatedTunnelList)
}
func (c *Client) Banner() string {
banner := c.ID
if c.Name != "" {
banner += " (" + c.Name + ")"
clientID := c.GetID()
clientName := c.GetName()
tags := c.GetTags()
banner := clientID
if clientName != "" {
banner += " (" + clientName + ")"
}
if len(c.Tags) != 0 {
for _, t := range c.Tags {
if len(tags) != 0 {
for _, t := range tags {
banner += " #" + t
}
}
return banner
}
func (c *Client) Close() error {
// The tunnels are closed automatically when ssh connection is closed.
return c.Connection.Close()
return c.GetConnection().Close()
}
func (c *Client) BelongsToOneOf(groups []*cgroups.ClientGroup) bool {
@ -204,6 +452,10 @@ func (c *Client) BelongsTo(group *cgroups.ClientGroup) bool {
if p.HasNoParams() {
return false
}
c.flock.RLock()
defer c.flock.RUnlock()
if !p.ClientID.MatchesOneOf(c.ID) {
return false
}
@ -247,7 +499,7 @@ func (c *Client) BelongsTo(group *cgroups.ClientGroup) bool {
}
func (c *Client) CalculateConnectionState() ConnectionState {
if c.DisconnectedAt == nil {
if c.IsConnected() {
return Connected
}
return Disconnected
@ -259,7 +511,7 @@ func (c *Client) HasAccessViaUserGroups(userGroups []string) bool {
if curUserGroup == users.Administrators {
return true
}
for _, allowedGroup := range c.AllowedUserGroups {
for _, allowedGroup := range c.GetAllowedUserGroups() {
if allowedGroup == curUserGroup {
return true
}
@ -284,3 +536,48 @@ func (c *Client) UserGroupHasAccessViaClientGroup(userGroups []string, allClient
func NewClientID() (string, error) {
return random.UUID4()
}
func NewClientFromConnRequest(ctx context.Context, existingClient *Client, clientAuthID string, clientID string, req *chshare.ConnectionRequest, clientHost string, sshConn ssh.Conn, clog *logger.Logger) (client *Client) {
if existingClient == nil {
client = &Client{
ID: clientID,
}
} else {
client = existingClient
}
client.flock.Lock()
client.Name = req.Name
client.SessionID = req.SessionID
client.OS = req.OS
client.OSArch = req.OSArch
client.OSFamily = req.OSFamily
client.OSKernel = req.OSKernel
client.OSFullName = req.OSFullName
client.OSVersion = req.OSVersion
client.OSVirtualizationSystem = req.OSVirtualizationSystem
client.OSVirtualizationRole = req.OSVirtualizationRole
client.Hostname = req.Hostname
client.CPUFamily = req.CPUFamily
client.CPUModel = req.CPUModel
client.CPUModelName = req.CPUModelName
client.CPUVendor = req.CPUVendor
client.NumCPUs = req.NumCPUs
client.MemoryTotal = req.MemoryTotal
client.Timezone = req.Timezone
client.IPv4 = req.IPv4
client.IPv6 = req.IPv6
client.Tags = req.Tags
client.Version = req.Version
client.ClientConfiguration = req.ClientConfiguration
client.Address = clientHost
client.Tunnels = make([]*clienttunnel.Tunnel, 0)
client.DisconnectedAt = nil
client.ClientAuthID = clientAuthID
client.Connection = sshConn
client.Context = ctx
client.Logger = clog
client.flock.Unlock()
return client
}

View File

@ -14,6 +14,7 @@ import (
"time"
"github.com/cloudradar-monitoring/rport/share/clientconfig"
"github.com/cloudradar-monitoring/rport/share/logger"
"golang.org/x/crypto/ssh"
@ -39,6 +40,7 @@ type ClientBuilder struct {
disconnectedAt *time.Time
allowedUserGroups []string
conn ssh.Conn
logger *logger.Logger
cfg *clientconfig.Config
}
@ -58,6 +60,11 @@ func (b ClientBuilder) ID(id string) ClientBuilder {
return b
}
func (b ClientBuilder) Logger(l *logger.Logger) ClientBuilder {
b.logger = l
return b
}
func (b ClientBuilder) ClientAuthID(clientAuthID string) ClientBuilder {
b.clientAuthID = clientAuthID
return b
@ -139,6 +146,7 @@ func (b ClientBuilder) Build() *Client {
Connection: b.conn,
ClientConfiguration: b.cfg,
Logger: b.logger,
}
return cl

View File

@ -31,15 +31,15 @@ import (
type ClientService interface {
SetPlusLicenseInfoCap(licensecap licensecap.CapabilityEx)
Count() (int, error)
CountActive() (int, error)
Count() int
CountActive() int
CountDisconnected() (int, error)
GetByID(id string) (*Client, error)
GetActiveByID(id string) (*Client, error)
GetByGroups(groups []*cgroups.ClientGroup) ([]*Client, error)
GetClientsByTag(tags []string, operator string, allowDisconnected bool) (clients []*Client, err error)
GetAllByClientID(clientID string) []*Client
GetAll() ([]*Client, error)
GetAll() []*Client
GetUserClients(groups []*cgroups.ClientGroup, user User) ([]*Client, error)
GetFilteredUserClients(user User, filterOptions []query.FilterOption, groups []*cgroups.ClientGroup) ([]*CalculatedClient, error)
@ -81,7 +81,13 @@ type ClientServiceProvider struct {
licensecap licensecap.CapabilityEx
mu sync.Mutex
mu sync.RWMutex
}
func (s *ClientServiceProvider) log() (l *logger.Logger) {
s.mu.RLock()
defer s.mu.RUnlock()
return s.logger
}
var OptionsSupportedFilters = map[string]bool{
@ -216,9 +222,9 @@ func (s *ClientServiceProvider) UpdateClientStatus() {
return
}
s.logger.Debugf("updating client status")
s.log().Debugf("updating client status")
clientList := s.repo.GetAllActive()
clientList, _ := s.repo.GetAllActiveClients()
for i, client := range clientList {
if i < s.GetMaxClients() {
@ -227,14 +233,13 @@ func (s *ClientServiceProvider) UpdateClientStatus() {
client.SetPaused(true, PausedDueToMaxClientsExceeded)
}
}
}
func (s *ClientServiceProvider) Count() (int, error) {
func (s *ClientServiceProvider) Count() int {
return s.repo.Count()
}
func (s *ClientServiceProvider) CountActive() (int, error) {
func (s *ClientServiceProvider) CountActive() int {
return s.repo.CountActive()
}
@ -255,13 +260,10 @@ func (s *ClientServiceProvider) GetByGroups(groups []*cgroups.ClientGroup) ([]*C
return nil, nil
}
all, err := s.repo.GetAll()
if err != nil {
return nil, err
}
allClients := s.repo.GetAllClients()
var res []*Client
for _, cur := range all {
for _, cur := range allClients {
if cur.BelongsToOneOf(groups) {
res = append(res, cur)
}
@ -274,11 +276,12 @@ func (s *ClientServiceProvider) GetClientsByTag(tags []string, operator string,
}
func (s *ClientServiceProvider) PopulateGroupsWithUserClients(groups []*cgroups.ClientGroup, user User) {
all, _ := s.repo.GetUserClients(user, groups)
for _, curClient := range all {
availableClients, _ := s.repo.GetUserClients(user, groups)
for _, client := range availableClients {
clientID := client.GetID()
for _, curGroup := range groups {
if curClient.BelongsTo(curGroup) {
curGroup.ClientIDs = append(curGroup.ClientIDs, curClient.ID)
if client.BelongsTo(curGroup) {
curGroup.ClientIDs = append(curGroup.ClientIDs, clientID)
}
}
}
@ -291,8 +294,8 @@ func (s *ClientServiceProvider) GetAllByClientID(clientID string) []*Client {
return s.repo.GetAllByClientAuthID(clientID)
}
func (s *ClientServiceProvider) GetAll() ([]*Client, error) {
return s.repo.GetAll()
func (s *ClientServiceProvider) GetAll() []*Client {
return s.repo.GetAllClients()
}
func (s *ClientServiceProvider) GetUserClients(groups []*cgroups.ClientGroup, user User) ([]*Client, error) {
@ -308,40 +311,49 @@ func (s *ClientServiceProvider) StartClient(
req *chshare.ConnectionRequest, clog *logger.Logger,
) (*Client, error) {
clog.Debugf("Starting client session: %s", clientID)
repo := s.GetRepo()
s.mu.Lock()
defer s.mu.Unlock()
clientAddr := sshConn.RemoteAddr().String()
clientHost, _, err := net.SplitHostPort(clientAddr)
if err != nil {
return nil, fmt.Errorf("failed to get host for address %q: %v", clientAddr, err)
}
// if client id is in use, deny connection
client, err := s.repo.GetByID(clientID)
client, err := repo.GetByID(clientID)
if err != nil {
return nil, fmt.Errorf("failed to get client by id %q", clientID)
}
// if found existing client
if client != nil {
clog.Debugf("found existing client %s", clientID)
var sessionReUsed = false
if req.SessionID != "" && req.SessionID == client.SessionID {
if req.SessionID != "" && req.SessionID == client.GetSessionID() {
// Stored previous session id and the session id of the connection attempt are equal
sessionReUsed = true
clog.Debugf("resuming existing session %s for client %s [%s]", req.SessionID, client.Name, clientID)
}
if client.DisconnectedAt == nil && !sessionReUsed {
return nil, fmt.Errorf("client is already connected: %s [%s]", client.Name, clientID)
clog.Debugf("resuming existing session %s for client %s [%s]", req.SessionID, client.GetName(), clientID)
}
oldTunnels := GetTunnelsToReestablish(getRemotes(client.Tunnels), req.Remotes)
if client.IsConnected() && !sessionReUsed {
clog.Debugf("client is already connected: %s", clientID)
return nil, fmt.Errorf("client is already connected: %s [%s]", client.GetName(), clientID)
}
oldTunnels := getTunnelsToReestablish(getRemotes(client.GetTunnels()), req.Remotes)
clientVersion, err := version.NewVersion(req.Version)
if err != nil {
return nil, fmt.Errorf("failed to determine client version: %v", err)
}
requiredVersion, _ := version.NewVersion("0.6.4")
if clientVersion.GreaterThanOrEqual(requiredVersion) {
oldTunnels, err = ExcludeNotAllowedTunnels(clog, oldTunnels, sshConn)
oldTunnels, err = s.excludeNotAllowedTunnels(clog, oldTunnels, sshConn)
if err != nil {
return nil, fmt.Errorf("failed to filter tunnels: %v", err)
}
} else {
clog.Infof("client %s (%s) version %s does not support 'tunnel_allowed' policies. Consider upgrading.", client.ID, client.Name, client.Version)
clog.Infof("client %s (%s) version %s does not support 'tunnel_allowed' policies. Consider upgrading.", client.GetID(), client.GetName(), client.GetVersion())
}
clog.Infof("tunnels to create %d: %v", len(req.Remotes), req.Remotes)
@ -353,68 +365,33 @@ func (s *ClientServiceProvider) StartClient(
// check if client auth ID is already used by another client
if !authMultiuseCreds && s.isClientAuthIDInUse(clientAuthID, clientID) {
clog.Debugf("client auth ID is already in use: %s: %q: ", clientID, clientAuthID)
return nil, fmt.Errorf("client auth ID is already in use: %q", clientAuthID)
}
clientAddr := sshConn.RemoteAddr().String()
clientHost, _, err := net.SplitHostPort(clientAddr)
if err != nil {
return nil, fmt.Errorf("failed to get host for address %q: %v", clientAddr, err)
}
if client == nil {
client = &Client{
ID: clientID,
}
}
client.Name = req.Name
client.SessionID = req.SessionID
client.OS = req.OS
client.OSArch = req.OSArch
client.OSFamily = req.OSFamily
client.OSKernel = req.OSKernel
client.OSFullName = req.OSFullName
client.OSVersion = req.OSVersion
client.OSVirtualizationSystem = req.OSVirtualizationSystem
client.OSVirtualizationRole = req.OSVirtualizationRole
client.Hostname = req.Hostname
client.CPUFamily = req.CPUFamily
client.CPUModel = req.CPUModel
client.CPUModelName = req.CPUModelName
client.CPUVendor = req.CPUVendor
client.NumCPUs = req.NumCPUs
client.MemoryTotal = req.MemoryTotal
client.Timezone = req.Timezone
client.IPv4 = req.IPv4
client.IPv6 = req.IPv6
client.Tags = req.Tags
client.Version = req.Version
client.ClientConfiguration = req.ClientConfiguration
client.Address = clientHost
client.Tunnels = make([]*clienttunnel.Tunnel, 0)
client.DisconnectedAt = nil
client.ClientAuthID = clientAuthID
client.Connection = sshConn
client.Context = ctx
client.Logger = clog
client = NewClientFromConnRequest(ctx, client, clientAuthID, clientID, req, clientHost, sshConn, clog)
client.SetConnected()
s.UpdateClientStatus()
if !client.IsPaused() {
_, err = s.startClientTunnels(client, req.Remotes)
_, err = s.startClientTunnels(client, req.Remotes, clog)
if err != nil {
return nil, err
}
}
err = s.repo.Save(client)
err = repo.Save(client)
if err != nil {
return nil, err
}
// TODO: (rs): should we keep this?
_, totalClients := repo.GetAllActiveClients()
s.log().Debugf("total clients = %d (last: %s)", totalClients, client.GetName())
return client, nil
}
@ -426,8 +403,8 @@ func getRemotes(tunnels []*clienttunnel.Tunnel) []*models.Remote {
return r
}
// GetTunnelsToReestablish returns old tunnels that should be re-establish taking into account new tunnels.
func GetTunnelsToReestablish(old, new []*models.Remote) []*models.Remote {
// getTunnelsToReestablish returns old tunnels that should be re-establish taking into account new tunnels.
func getTunnelsToReestablish(old, new []*models.Remote) []*models.Remote {
if len(new) > len(old) {
return nil
}
@ -484,11 +461,9 @@ loop2:
// StartClientTunnels returns a new tunnel for each requested remote or nil if error occurred
func (s *ClientServiceProvider) StartClientTunnels(client *Client, remotes []*models.Remote) ([]*clienttunnel.Tunnel, error) {
s.logger.Debugf("starting client tunnels: %s", client.ID)
s.logger.Debugf("starting client tunnels: %s", client.GetID())
s.mu.Lock()
defer s.mu.Unlock()
newTunnels, err := s.startClientTunnels(client, remotes)
newTunnels, err := s.startClientTunnels(client, remotes, s.log())
if err != nil {
return nil, err
}
@ -501,7 +476,7 @@ func (s *ClientServiceProvider) StartClientTunnels(client *Client, remotes []*mo
return newTunnels, err
}
func (s *ClientServiceProvider) startClientTunnels(client *Client, remotes []*models.Remote) ([]*clienttunnel.Tunnel, error) {
func (s *ClientServiceProvider) startClientTunnels(client *Client, remotes []*models.Remote, clog *logger.Logger) ([]*clienttunnel.Tunnel, error) {
err := s.portDistributor.Refresh()
if err != nil {
return nil, err
@ -510,7 +485,7 @@ func (s *ClientServiceProvider) startClientTunnels(client *Client, remotes []*mo
tunnels := make([]*clienttunnel.Tunnel, 0, len(remotes))
for _, remote := range remotes {
if !remote.IsLocalSpecified() {
s.logger.Debugf("no local specified")
clog.Debugf("no local specified")
port, err := s.portDistributor.GetRandomPort(remote.Protocol)
if err != nil {
return nil, err
@ -518,15 +493,15 @@ func (s *ClientServiceProvider) startClientTunnels(client *Client, remotes []*mo
remote.LocalPort = strconv.Itoa(port)
remote.LocalHost = models.ZeroHost
remote.LocalPortRandom = true
s.logger.Debugf("using random port %s", remote.LocalPort)
clog.Debugf("using random port %s", remote.LocalPort)
} else {
s.logger.Debugf("checking local port %s", remote.LocalPort)
clog.Debugf("checking local port %s", remote.LocalPort)
if err := s.checkLocalPort(remote.Protocol, remote.LocalPort); err != nil {
return nil, err
}
}
s.logger.Debugf("initiating tunnel %+v", remote)
clog.Debugf("initiating tunnel %+v", remote)
var acl *clienttunnel.TunnelACL
if remote.ACL != nil {
@ -537,9 +512,10 @@ func (s *ClientServiceProvider) startClientTunnels(client *Client, remotes []*mo
}
}
s.logger.Debugf("starting tunnel: %s", remote)
clog.Debugf("starting tunnel: %s", remote)
t, err := s.StartTunnel(client, remote, acl)
if err != nil {
clog.Debugf("failed starting tunnel: %s: %v", remote, err)
return nil, apiErrors.APIError{
HTTPStatus: http.StatusConflict,
Err: fmt.Errorf("unable to start tunnel: %s", err),
@ -569,18 +545,16 @@ func (s *ClientServiceProvider) checkLocalPort(protocol, port string) error {
}
func (s *ClientServiceProvider) Terminate(client *Client) error {
s.logger.Infof("terminating client: %s", client.ID)
s.mu.Lock()
defer s.mu.Unlock()
if s.repo.KeepDisconnectedClients != nil && *s.repo.KeepDisconnectedClients == 0 {
s.log().Infof("terminating client: %s: %s", client.GetID(), client.GetName())
keepDisconnectedClientsDuration := s.repo.GetKeepDisconnectedClients()
if keepDisconnectedClientsDuration != nil && *keepDisconnectedClientsDuration == 0 {
return s.repo.Delete(client)
}
client.SetDisconnectedNow()
// Do not save if client doesn't exist in repo - it was force deleted
existing, err := s.repo.GetByID(client.ID)
existing, err := s.repo.GetByID(client.GetID())
if err != nil {
return err
}
@ -596,11 +570,9 @@ func (s *ClientServiceProvider) Terminate(client *Client) error {
// ForceDelete deletes client from repo regardless off KeepDisconnectedClients setting,
// if client is active it will be closed
func (s *ClientServiceProvider) ForceDelete(client *Client) error {
s.logger.Debugf("force deleting client: %s", client.ID)
s.logger.Debugf("force deleting client: %s", client.GetID())
s.mu.Lock()
defer s.mu.Unlock()
if client.DisconnectedAt == nil {
if client.IsConnected() {
if err := client.Close(); err != nil {
return err
}
@ -611,12 +583,12 @@ func (s *ClientServiceProvider) ForceDelete(client *Client) error {
func (s *ClientServiceProvider) DeleteOffline(clientID string) error {
s.logger.Debugf("deleting offline client: %s", clientID)
existing, err := s.getExistingByID(clientID)
existing, err := s.getExistingClientByID(clientID)
if err != nil {
return err
}
if existing.DisconnectedAt == nil {
if existing.IsConnected() {
return apiErrors.APIError{
Message: "Client is active, should be disconnected",
HTTPStatus: http.StatusBadRequest,
@ -628,8 +600,8 @@ func (s *ClientServiceProvider) DeleteOffline(clientID string) error {
// isClientAuthIDInUse returns true when the client with different id exists for the client auth
func (s *ClientServiceProvider) isClientAuthIDInUse(clientAuthID, clientID string) bool {
for _, s := range s.repo.GetAllByClientAuthID(clientAuthID) {
if s.ID != clientID {
for _, client := range s.repo.GetAllByClientAuthID(clientAuthID) {
if client.GetID() != clientID {
return true
}
}
@ -637,40 +609,40 @@ func (s *ClientServiceProvider) isClientAuthIDInUse(clientAuthID, clientID strin
}
func (s *ClientServiceProvider) SetACL(clientID string, allowedUserGroups []string) error {
existing, err := s.getExistingByID(clientID)
client, err := s.getExistingClientByID(clientID)
if err != nil {
return err
}
existing.AllowedUserGroups = allowedUserGroups
client.SetAllowedUserGroups(allowedUserGroups)
return s.repo.Save(existing)
return s.repo.Save(client)
}
func (s *ClientServiceProvider) SetUpdatesStatus(clientID string, updatesStatus *models.UpdatesStatus) error {
existing, err := s.getExistingByID(clientID)
client, err := s.getExistingClientByID(clientID)
if err != nil {
return err
}
existing.UpdatesStatus = updatesStatus
client.SetUpdatesStatus(updatesStatus)
return s.repo.Save(existing)
return s.repo.Save(client)
}
func (s *ClientServiceProvider) SetLastHeartbeat(clientID string, heartbeat time.Time) error {
existing, err := s.getExistingByID(clientID)
existing, err := s.getExistingClientByID(clientID)
if err != nil {
return err
}
existing.SetHeartbeat(&heartbeat)
existing.SetLastHeartbeatAt(&heartbeat)
return nil
}
// CheckClientAccess returns nil if a given user has an access to a given client.
// Otherwise, APIError with 403 is returned.
func (s *ClientServiceProvider) CheckClientAccess(clientID string, user User, groups []*cgroups.ClientGroup) error {
existing, err := s.getExistingByID(clientID)
existing, err := s.getExistingClientByID(clientID)
if err != nil {
return err
}
@ -687,11 +659,13 @@ func (s *ClientServiceProvider) CheckClientsAccess(clients []*Client, user User,
var clientsWithNoAccess []string
userGroups := user.GetGroups()
for _, curClient := range clients {
if curClient.HasAccessViaUserGroups(userGroups) || curClient.UserGroupHasAccessViaClientGroup(userGroups, clientGroups) {
for _, client := range clients {
clientID := client.GetID()
if client.HasAccessViaUserGroups(userGroups) || client.UserGroupHasAccessViaClientGroup(userGroups, clientGroups) {
continue
}
clientsWithNoAccess = append(clientsWithNoAccess, curClient.ID)
clientsWithNoAccess = append(clientsWithNoAccess, clientID)
}
if len(clientsWithNoAccess) > 0 {
@ -704,8 +678,8 @@ func (s *ClientServiceProvider) CheckClientsAccess(clients []*Client, user User,
return nil
}
// getExistingByID returns non-nil client by id. If not found or failed to get a client - an error is returned.
func (s *ClientServiceProvider) getExistingByID(clientID string) (*Client, error) {
// getExistingClientByID returns non-nil client by id. If not found or failed to get a client - an error is returned.
func (s *ClientServiceProvider) getExistingClientByID(clientID string) (*Client, error) {
if clientID == "" {
return nil, apiErrors.APIError{
Message: "Client id is empty",
@ -729,13 +703,15 @@ func (s *ClientServiceProvider) getExistingByID(clientID string) (*Client, error
}
func (s *ClientServiceProvider) GetRepo() *ClientRepository {
s.mu.RLock()
defer s.mu.RUnlock()
return s.repo
}
func ExcludeNotAllowedTunnels(clog *logger.Logger, tunnels []*models.Remote, conn ssh.Conn) ([]*models.Remote, error) {
func (s *ClientServiceProvider) excludeNotAllowedTunnels(clog *logger.Logger, tunnels []*models.Remote, conn ssh.Conn) ([]*models.Remote, error) {
filtered := make([]*models.Remote, 0, len(tunnels))
for _, t := range tunnels {
allowed, err := clienttunnel.IsAllowed(t.Remote(), conn)
allowed, err := clienttunnel.IsAllowed(t.Remote(), conn, s.log())
if err != nil {
if strings.Contains(err.Error(), "unknown request") {
return tunnels, nil
@ -751,19 +727,21 @@ func ExcludeNotAllowedTunnels(clog *logger.Logger, tunnels []*models.Remote, con
return filtered, nil
}
// TODO: (rs): can this move to the tunnel package?
func (s *ClientServiceProvider) FindTunnelByRemote(c *Client, r *models.Remote) *clienttunnel.Tunnel {
for _, curr := range c.Tunnels {
if curr.Equals(r) {
return curr
for _, tunnel := range c.GetTunnels() {
if tunnel.Equals(r) {
return tunnel
}
}
return nil
}
// TODO: (rs): can this move to the tunnel package?
func (s *ClientServiceProvider) FindTunnel(c *Client, id string) *clienttunnel.Tunnel {
for _, curr := range c.Tunnels {
if curr.ID == id {
return curr
for _, tunnel := range c.GetTunnels() {
if tunnel.ID == id {
return tunnel
}
}
return nil
@ -783,9 +761,9 @@ func (s *ClientServiceProvider) StartTunnel(
return tunnel, nil
}
s.logger.Debugf("starting tunnel: %s", remote)
s.log().Debugf("starting tunnel: %s", remote)
ctx := client.Context
ctx := client.GetContext()
if remote.AutoClose > 0 {
// no need to cancel the ctx since it will be canceled by parent ctx or after given timeout
ctx, _ = context.WithTimeout(ctx, remote.AutoClose) // nolint: govet
@ -800,9 +778,9 @@ func (s *ClientServiceProvider) StartTunnel(
if remote.HasSubdomainTunnel() {
err = s.startCaddyDownstreamProxy(ctx, client, remote, tunnel)
if err != nil {
tunnelStopErr := tunnel.InternalTunnelProxy.Stop(client.Context)
tunnelStopErr := tunnel.InternalTunnelProxy.Stop(client.GetContext())
if tunnelStopErr != nil {
client.Logger.Infof("unable to stop internal tunnel proxy after failing to create caddy downstream proxy: %s", tunnelStopErr)
client.Log().Infof("unable to stop internal tunnel proxy after failing to create caddy downstream proxy: %s", tunnelStopErr)
}
return nil, err
}
@ -824,7 +802,10 @@ func (s *ClientServiceProvider) StartTunnel(
go s.terminateTunnelOnIdleTimeout(ctx, tunnel, client)
}
client.Tunnels = append(client.Tunnels, tunnel)
existingTunnels := client.GetTunnels()
existingTunnels = append(existingTunnels, tunnel)
client.SetTunnels(existingTunnels)
return tunnel, nil
}
@ -834,9 +815,11 @@ func (s *ClientServiceProvider) startCaddyDownstreamProxy(
remote *models.Remote,
tunnel *clienttunnel.Tunnel,
) (err error) {
client.Logger.Infof("starting downstream caddy proxy at %s", remote.TunnelURL)
client.Logger.Debugf("tunnel = %#v", tunnel)
client.Logger.Debugf("remote = %#v", remote)
clientLogger := client.Log()
clientLogger.Infof("starting downstream caddy proxy at %s", remote.TunnelURL)
clientLogger.Debugf("tunnel = %#v", tunnel)
clientLogger.Debugf("remote = %#v", remote)
subdomain, basedomain, err := remote.GetTunnelDomains()
if err != nil {
@ -851,7 +834,7 @@ func (s *ClientServiceProvider) startCaddyDownstreamProxy(
DownstreamProxyBaseDomain: basedomain,
}
client.Logger.Debugf("requesting new caddy route = %+v", nrr)
clientLogger.Debugf("requesting new caddy route = %+v", nrr)
res, err := s.caddyAPI.AddRoute(ctx, nrr)
if err != nil {
@ -862,14 +845,14 @@ func (s *ClientServiceProvider) startCaddyDownstreamProxy(
return fmt.Errorf("failed to create downstream caddy proxy: status_code: %d", res.StatusCode)
}
client.Logger.Infof("started downstream caddy proxy at %s to %s:%s", remote.TunnelURL, tunnel.LocalHost, tunnel.LocalPort)
clientLogger.Infof("started downstream caddy proxy at %s to %s:%s", remote.TunnelURL, tunnel.LocalHost, tunnel.LocalPort)
return nil
}
func (s *ClientServiceProvider) startRegularTunnel(ctx context.Context, client *Client, remote *models.Remote, acl *clienttunnel.TunnelACL) (*clienttunnel.Tunnel, error) {
tunnelID := client.NewTunnelID()
tunnel, err := clienttunnel.NewTunnel(client.Logger, client.Connection, tunnelID, *remote, acl)
tunnel, err := clienttunnel.NewTunnel(client.Log(), client.GetConnection(), tunnelID, *remote, acl)
if err != nil {
return nil, err
}
@ -888,12 +871,14 @@ func (s *ClientServiceProvider) startTunnelWithProxy(
remote *models.Remote,
acl *clienttunnel.TunnelACL,
) (*clienttunnel.Tunnel, error) {
var proxyACL *clienttunnel.TunnelACL
proxyHost := ""
proxyPort := ""
var proxyACL *clienttunnel.TunnelACL
clientLogger := client.Log()
// assuming that we still want to log activity in the client log
client.Logger.Debugf("client %s will use tunnel proxy", client.ID)
client.Logger.Debugf("client %s will use tunnel proxy", client.GetID())
// get values for tunnel proxy local host addr from original remote
proxyHost = remote.LocalHost
@ -913,7 +898,7 @@ func (s *ClientServiceProvider) startTunnelWithProxy(
tunnelID := client.NewTunnelID()
// original tunnel will use the reconfigured original remote
t, err := clienttunnel.NewTunnel(client.Logger, client.Connection, tunnelID, *remote, acl)
t, err := clienttunnel.NewTunnel(client.Logger, client.GetConnection(), tunnelID, *remote, acl)
if err != nil {
return nil, err
}
@ -925,8 +910,8 @@ func (s *ClientServiceProvider) startTunnelWithProxy(
}
// create new proxy tunnel listening at the original tunnel local host addr
tProxy := clienttunnel.NewInternalTunnelProxy(t, client.Logger, s.tunnelProxyConfig, proxyHost, proxyPort, proxyACL)
client.Logger.Debugf("client %s starting tunnel proxy", client.ID)
tProxy := clienttunnel.NewInternalTunnelProxy(t, client.Log(), s.tunnelProxyConfig, proxyHost, proxyPort, proxyACL)
clientLogger.Debugf("client %s starting tunnel proxy", client.GetID())
if err := tProxy.Start(ctx); err != nil {
client.Logger.Debugf("tunnel proxy could not be started, tunnel must be terminated: %v", err)
if tErr := t.Terminate(true); tErr != nil {
@ -941,8 +926,8 @@ func (s *ClientServiceProvider) startTunnelWithProxy(
t.Remote.LocalHost = t.InternalTunnelProxy.Host
t.Remote.LocalPort = t.InternalTunnelProxy.Port
client.Logger.Debugf("client %s started tunnel with proxy: %#v", client.ID, t)
client.Logger.Debugf("internal tunnel proxy: %#v", t.InternalTunnelProxy)
clientLogger.Debugf("client %s started tunnel with proxy: %#v", client.GetID(), t)
clientLogger.Debugf("internal tunnel proxy: %#v", t.InternalTunnelProxy)
return t, nil
}
@ -966,7 +951,7 @@ func (s *ClientServiceProvider) terminateTunnelOnIdleTimeout(ctx context.Context
case <-timer.C:
sinceLastActive := time.Since(t.LastActive())
if sinceLastActive > idleTimeout {
c.Logger.Infof("Terminating... inactivity period is reached: %d minute(s)", t.IdleTimeoutMinutes)
c.Log().Infof("Terminating... inactivity period is reached: %d minute(s)", t.IdleTimeoutMinutes)
_ = t.Terminate(true)
s.cleanupAfterAutoClose(c, t)
return
@ -977,15 +962,14 @@ func (s *ClientServiceProvider) terminateTunnelOnIdleTimeout(ctx context.Context
}
func (s *ClientServiceProvider) cleanupAfterAutoClose(c *Client, t *clienttunnel.Tunnel) {
c.Lock()
defer c.Unlock()
clientLogger := c.Log()
c.Logger.Infof("Auto closing tunnel %s ...", t.ID)
clientLogger.Infof("Auto closing tunnel %s ...", t.ID)
//stop tunnel proxy
if t.InternalTunnelProxy != nil {
if err := t.InternalTunnelProxy.Stop(c.Context); err != nil {
c.Logger.Errorf("error while stopping tunnel proxy: %v", err)
if err := t.InternalTunnelProxy.Stop(c.GetContext()); err != nil {
clientLogger.Errorf("error while stopping tunnel proxy: %v", err)
}
if t.Remote.HasSubdomainTunnel() {
_ = s.removeCaddyDownstreamProxy(c, t)
@ -996,14 +980,16 @@ func (s *ClientServiceProvider) cleanupAfterAutoClose(c *Client, t *clienttunnel
err := s.repo.Save(c)
if err != nil {
c.Logger.Errorf("unable to save client after auto close cleanup: %v", err)
clientLogger.Errorf("unable to save client after auto close cleanup: %v", err)
}
c.Logger.Debugf("auto closed tunnel with id=%s removed", t.ID)
clientLogger.Debugf("auto closed tunnel with id=%s removed", t.ID)
}
func (s *ClientServiceProvider) TerminateTunnel(c *Client, t *clienttunnel.Tunnel, force bool) error {
c.Logger.Infof("Terminating tunnel %s (force: %v) ...", t.ID, force)
clientLogger := c.Log()
clientLogger.Infof("Terminating tunnel %s (force: %v) ...", t.ID, force)
err := t.Terminate(force)
if err != nil {
@ -1011,8 +997,8 @@ func (s *ClientServiceProvider) TerminateTunnel(c *Client, t *clienttunnel.Tunne
}
if t.InternalTunnelProxy != nil {
if err := t.InternalTunnelProxy.Stop(c.Context); err != nil {
c.Logger.Errorf("error while stopping tunnel proxy: %v", err)
if err := t.InternalTunnelProxy.Stop(c.GetContext()); err != nil {
clientLogger.Errorf("error while stopping tunnel proxy: %v", err)
}
if t.Remote.HasSubdomainTunnel() {
_ = s.removeCaddyDownstreamProxy(c, t)
@ -1026,10 +1012,10 @@ func (s *ClientServiceProvider) TerminateTunnel(c *Client, t *clienttunnel.Tunne
err = s.repo.Save(c)
if err != nil {
c.Logger.Errorf("unable to save client after auto close cleanup: %v", err)
clientLogger.Errorf("unable to save client after auto close cleanup: %v", err)
}
c.Logger.Debugf("terminated tunnel with id=%s removed", t.ID)
clientLogger.Debugf("terminated tunnel with id=%s removed", t.ID)
return nil
}
@ -1059,14 +1045,16 @@ func (s *ClientServiceProvider) SetTunnelACL(c *Client, t *clienttunnel.Tunnel,
}
func (s *ClientServiceProvider) removeCaddyDownstreamProxy(c *Client, t *clienttunnel.Tunnel) (err error) {
c.Logger.Infof("removing downstream caddy proxy at %s", t.Remote.TunnelURL)
clientLogger := c.Log()
clientLogger.Infof("removing downstream caddy proxy at %s", t.Remote.TunnelURL)
subdomain, _, err := t.Remote.GetTunnelDomains()
if err != nil {
return err
}
res, err := s.caddyAPI.DeleteRoute(c.Context, subdomain)
res, err := s.caddyAPI.DeleteRoute(c.GetContext(), subdomain)
if err != nil {
return err
}
@ -1075,6 +1063,6 @@ func (s *ClientServiceProvider) removeCaddyDownstreamProxy(c *Client, t *clientt
return fmt.Errorf("failed to delete downstream caddy proxy: status_code: %d", res.StatusCode)
}
c.Logger.Infof("removed downstream caddy proxy at %s", t.Remote.TunnelURL)
clientLogger.Infof("removed downstream caddy proxy at %s", t.Remote.TunnelURL)
return nil
}

View File

@ -4,23 +4,30 @@ import (
"context"
"errors"
"fmt"
"math/rand"
"net"
"net/http"
"os"
"path"
"runtime"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/cloudradar-monitoring/rport/server/caddy"
"github.com/cloudradar-monitoring/rport/server/cgroups"
"github.com/cloudradar-monitoring/rport/server/clients/clienttunnel"
"github.com/cloudradar-monitoring/rport/server/clientsauth"
mapset "github.com/deckarep/golang-set"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
errors2 "github.com/cloudradar-monitoring/rport/server/api/errors"
clientsmigration "github.com/cloudradar-monitoring/rport/db/migration/clients"
"github.com/cloudradar-monitoring/rport/db/sqlite"
"github.com/cloudradar-monitoring/rport/server/caddy"
"github.com/cloudradar-monitoring/rport/server/cgroups"
"github.com/cloudradar-monitoring/rport/server/clients/clienttunnel"
"github.com/cloudradar-monitoring/rport/server/clientsauth"
apiErrors "github.com/cloudradar-monitoring/rport/server/api/errors"
"github.com/cloudradar-monitoring/rport/server/api/users"
"github.com/cloudradar-monitoring/rport/server/ports"
chshare "github.com/cloudradar-monitoring/rport/share"
@ -93,7 +100,8 @@ func TestStartClient(t *testing.T) {
ID: "test-client",
ClientAuthID: "test-client-auth",
}}, nil, testLog),
portDistributor: ports.NewPortDistributor(mapset.NewThreadUnsafeSet()),
portDistributor: ports.NewPortDistributor(mapset.NewSet()),
logger: testLog,
}
_, err := cs.StartClient(
context.Background(), tc.ClientAuthID, tc.ClientID, connMock, tc.AuthMultiuseCreds,
@ -103,6 +111,90 @@ func TestStartClient(t *testing.T) {
}
}
// this is a fairly crude concurrency test for start client. currently excluded from the regular test runs as
// it consumes a moderate amount of memory and takes some time to run.
// go test -count=1 -race -v github.com/cloudradar-monitoring/rport/server/clients -run TestStartClientConcurrency
func TestStartClientConcurrency(t *testing.T) {
t.Skip()
// runtime.GOMAXPROCS(runtime.NumCPU() - 1)
runtime.GOMAXPROCS(8)
sourceOptions := sqlite.DataSourceOptions{
MaxOpenConnections: 250,
WALEnabled: false,
}
clientDB, err := sqlite.New(
path.Join("./", "clients.db"),
clientsmigration.AssetNames(),
clientsmigration.Asset,
sourceOptions,
)
require.NoError(t, err)
defer os.Remove("./clients.db")
defer os.Remove("./clients.db-shm")
pd := ports.NewPortDistributor(mapset.NewSet())
totalClients := 1500
clients := []*Client{{}}
for i := 0; i < totalClients; i++ {
client := Client{
Name: "test-name-" + strconv.Itoa(i),
ID: "test-id-" + strconv.Itoa(i),
ClientAuthID: "test-client-auth-" + strconv.Itoa(i),
Version: "0.6.4",
}
clients = append(clients, &client)
}
mockConns := []*test.ConnMock{}
for i := 0; i < totalClients; i++ {
mockConn := test.NewConnMock()
mockConn.ReturnRemoteAddr = &net.TCPAddr{IP: net.IPv4(192, 0, 2, 1), Port: 2000}
mockConns = append(mockConns, mockConn)
}
repo, err := InitClientRepository(context.Background(), clientDB, nil, testLog)
require.NoError(t, err)
cs := &ClientServiceProvider{
repo: repo,
portDistributor: pd,
logger: testLog,
}
wg := sync.WaitGroup{}
for i := 0; i < totalClients; i++ {
wg.Add(1)
go func(i int) {
time.Sleep(time.Duration(rand.Intn(10)) * time.Millisecond)
client := clients[i]
_, err := cs.StartClient(
context.Background(),
client.GetClientAuthID(),
client.GetID(),
mockConns[i],
false,
&chshare.ConnectionRequest{
Name: client.GetName(),
},
testLog,
)
assert.NoError(t, err)
wg.Done()
}(i)
}
wg.Wait()
}
func TestStartClientDisconnected(t *testing.T) {
connMock := test.NewConnMock()
connMock.ReturnRemoteAddr = &net.TCPAddr{IP: net.IPv4(192, 0, 2, 1), Port: 2345}
@ -115,25 +207,27 @@ func TestStartClientDisconnected(t *testing.T) {
AllowedUserGroups: []string{"test-group"},
UpdatesStatus: &models.UpdatesStatus{UpdatesAvailable: 13},
}}, nil, testLog),
portDistributor: ports.NewPortDistributor(mapset.NewThreadUnsafeSet()),
portDistributor: ports.NewPortDistributor(mapset.NewSet()),
logger: testLog,
}
client, err := cs.StartClient(
context.Background(), "test-client-auth", "disconnected-client", connMock, false,
&chshare.ConnectionRequest{Name: "new-connection", Version: "0.7.0"}, testLog)
assert.NoError(t, err)
assert.Nil(t, client.DisconnectedAt)
assert.Equal(t, "disconnected-client", client.ID)
assert.Equal(t, "new-connection", client.Name)
assert.Equal(t, []string{"test-group"}, client.AllowedUserGroups)
assert.Equal(t, "disconnected-client", client.GetID())
assert.Equal(t, "new-connection", client.GetName())
assert.Equal(t, []string{"test-group"}, client.GetAllowedUserGroups())
assert.Equal(t, 13, client.UpdatesStatus.UpdatesAvailable)
}
func TestDeleteOfflineClient(t *testing.T) {
c1Active := New(t).Build()
c2Active := New(t).Build()
c3Offline := New(t).DisconnectedDuration(5 * time.Minute).Build()
c4Offline := New(t).DisconnectedDuration(time.Minute).Build()
c1Active := New(t).Logger(testLog).Build()
c2Active := New(t).Logger(testLog).Build()
c3Offline := New(t).DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
c4Offline := New(t).DisconnectedDuration(time.Minute).Logger(testLog).Build()
testCases := []struct {
name string
@ -142,13 +236,13 @@ func TestDeleteOfflineClient(t *testing.T) {
}{
{
name: "delete offline client",
clientID: c3Offline.ID,
clientID: c3Offline.GetID(),
wantError: nil,
},
{
name: "delete active client",
clientID: c1Active.ID,
wantError: errors2.APIError{
clientID: c1Active.GetID(),
wantError: apiErrors.APIError{
Message: "Client is active, should be disconnected",
HTTPStatus: http.StatusBadRequest,
},
@ -156,7 +250,7 @@ func TestDeleteOfflineClient(t *testing.T) {
{
name: "delete unknown client",
clientID: "unknown-id",
wantError: errors2.APIError{
wantError: apiErrors.APIError{
Message: fmt.Sprintf("Client with id=%q not found.", "unknown-id"),
HTTPStatus: http.StatusNotFound,
},
@ -164,7 +258,7 @@ func TestDeleteOfflineClient(t *testing.T) {
{
name: "empty client ID",
clientID: "",
wantError: errors2.APIError{
wantError: apiErrors.APIError{
Message: "Client id is empty",
HTTPStatus: http.StatusBadRequest,
},
@ -175,8 +269,7 @@ func TestDeleteOfflineClient(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
// given
clientService := NewClientService(nil, nil, NewClientRepository([]*Client{c1Active, c2Active, c3Offline, c4Offline}, &hour, testLog), testLog)
before, err := clientService.Count()
require.NoError(t, err)
before := clientService.Count()
require.Equal(t, 4, before)
// when
@ -190,8 +283,7 @@ func TestDeleteOfflineClient(t *testing.T) {
} else {
wantAfter = before - 1
}
gotAfter, err := clientService.Count()
require.NoError(t, err)
gotAfter := clientService.Count()
assert.Equal(t, wantAfter, gotAfter)
})
}
@ -200,9 +292,9 @@ func TestDeleteOfflineClient(t *testing.T) {
func TestCheckLocalPort(t *testing.T) {
srv := ClientServiceProvider{
portDistributor: ports.NewPortDistributorForTests(
mapset.NewThreadUnsafeSetFromSlice([]interface{}{1, 2, 3, 4, 5}),
mapset.NewThreadUnsafeSetFromSlice([]interface{}{2, 3, 4}),
mapset.NewThreadUnsafeSetFromSlice([]interface{}{2, 3, 4, 5}),
mapset.NewSetFromSlice([]interface{}{1, 2, 3, 4, 5}),
mapset.NewSetFromSlice([]interface{}{2, 3, 4}),
mapset.NewSetFromSlice([]interface{}{2, 3, 4, 5}),
),
}
@ -223,7 +315,7 @@ func TestCheckLocalPort(t *testing.T) {
{
name: "invalid port",
port: invalidPort,
wantError: errors2.APIError{
wantError: apiErrors.APIError{
Message: "Invalid local port: 24563a.",
Err: invalidPortParseErr,
HTTPStatus: http.StatusBadRequest,
@ -232,7 +324,7 @@ func TestCheckLocalPort(t *testing.T) {
{
name: "not allowed port",
port: "6",
wantError: errors2.APIError{
wantError: apiErrors.APIError{
Message: "Local port 6 is not among allowed ports.",
HTTPStatus: http.StatusBadRequest,
},
@ -240,7 +332,7 @@ func TestCheckLocalPort(t *testing.T) {
{
name: "busy port tcp",
port: "5",
wantError: errors2.APIError{
wantError: apiErrors.APIError{
Message: "Local port 5 already in use.",
HTTPStatus: http.StatusConflict,
},
@ -255,7 +347,7 @@ func TestCheckLocalPort(t *testing.T) {
name: "tcp+udp port busy",
port: "5",
protocol: models.ProtocolTCPUDP,
wantError: errors2.APIError{
wantError: apiErrors.APIError{
Message: "Local port 5 already in use.",
HTTPStatus: http.StatusConflict,
},
@ -283,13 +375,13 @@ func TestCheckLocalPort(t *testing.T) {
}
func TestCheckClientsAccess(t *testing.T) {
c1 := New(t).Build() // no groups
c2 := New(t).AllowedUserGroups([]string{users.Administrators}).Build() // admin
c3 := New(t).AllowedUserGroups([]string{users.Administrators, "group1"}).Build() // admin + group1
c4 := New(t).AllowedUserGroups([]string{"group1"}).Build() // group1
c5 := New(t).AllowedUserGroups([]string{"group1", "group2"}).Build() // group1 + group2
c6 := New(t).AllowedUserGroups([]string{"group3"}).Build() // group3
c7 := New(t).Build()
c1 := New(t).Logger(testLog).Build() // no groups
c2 := New(t).AllowedUserGroups([]string{users.Administrators}).Logger(testLog).Build() // admin
c3 := New(t).AllowedUserGroups([]string{users.Administrators, "group1"}).Logger(testLog).Build() // admin + group1
c4 := New(t).AllowedUserGroups([]string{"group1"}).Logger(testLog).Build() // group1
c5 := New(t).AllowedUserGroups([]string{"group1", "group2"}).Logger(testLog).Build() // group1 + group2
c6 := New(t).AllowedUserGroups([]string{"group3"}).Logger(testLog).Build() // group3
c7 := New(t).Logger(testLog).Build()
allClients := []*Client{c1, c2, c3, c4, c5, c6}
clientGroups := []*cgroups.ClientGroup{
@ -297,7 +389,7 @@ func TestCheckClientsAccess(t *testing.T) {
ID: "1",
AllowedUserGroups: []string{"group4"},
Params: &cgroups.ClientParams{
ClientID: &cgroups.ParamValues{cgroups.Param(c7.ID)},
ClientID: &cgroups.ParamValues{cgroups.Param(c7.GetID())},
},
},
}
@ -312,7 +404,7 @@ func TestCheckClientsAccess(t *testing.T) {
name: "user with no groups has no access",
clients: allClients,
user: &users.User{Groups: nil},
wantClientIDsWithNoAccess: []string{c1.ID, c2.ID, c3.ID, c4.ID, c5.ID, c6.ID},
wantClientIDsWithNoAccess: []string{c1.GetID(), c2.GetID(), c3.GetID(), c4.GetID(), c5.GetID(), c6.GetID()},
},
{
name: "admin user has access to all",
@ -330,25 +422,25 @@ func TestCheckClientsAccess(t *testing.T) {
name: "non-admin user with no access to clients with no groups and with admin group",
clients: allClients,
user: &users.User{Groups: []string{"group1", "group2", "group3"}},
wantClientIDsWithNoAccess: []string{c1.ID, c2.ID},
wantClientIDsWithNoAccess: []string{c1.GetID(), c2.GetID()},
},
{
name: "non-admin user with access to one client",
clients: allClients,
user: &users.User{Groups: []string{"group3"}},
wantClientIDsWithNoAccess: []string{c1.ID, c2.ID, c3.ID, c4.ID, c5.ID},
wantClientIDsWithNoAccess: []string{c1.GetID(), c2.GetID(), c3.GetID(), c4.GetID(), c5.GetID()},
},
{
name: "non-admin user with access to few clients",
clients: allClients,
user: &users.User{Groups: []string{"group1"}},
wantClientIDsWithNoAccess: []string{c1.ID, c2.ID, c6.ID},
wantClientIDsWithNoAccess: []string{c1.GetID(), c2.GetID(), c6.GetID()},
},
{
name: "non-admin user that has unknown group",
clients: allClients,
user: &users.User{Groups: []string{"group4"}},
wantClientIDsWithNoAccess: []string{c1.ID, c2.ID, c3.ID, c4.ID, c5.ID, c6.ID},
wantClientIDsWithNoAccess: []string{c1.GetID(), c2.GetID(), c3.GetID(), c4.GetID(), c5.GetID(), c6.GetID()},
},
{
name: "non-admin user given access via client groups",
@ -368,7 +460,7 @@ func TestCheckClientsAccess(t *testing.T) {
// then
if len(tc.wantClientIDsWithNoAccess) > 0 {
wantErr := errors2.APIError{
wantErr := apiErrors.APIError{
Message: fmt.Sprintf("Access denied to client(s) with ID(s): %v", strings.Join(tc.wantClientIDsWithNoAccess, ", ")),
HTTPStatus: http.StatusForbidden,
}
@ -782,7 +874,7 @@ func TestGetTunnelsToReestablish(t *testing.T) {
}
// when
gotRes := GetTunnelsToReestablish(old, new)
gotRes := getTunnelsToReestablish(old, new)
var gotResStr []string
for _, r := range gotRes {
@ -843,7 +935,7 @@ func TestShouldStartTunnelsWithSubdomains(t *testing.T) {
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
c1 := New(t).ID("client-1").ClientAuthID(cl1.ID).Build()
c1 := New(t).ID("client-1").ClientAuthID(cl1.ID).Logger(testLog).Build()
c1.Connection = connMock
c1.Logger = testLog
c1.Context = context.Background()
@ -855,9 +947,9 @@ func TestShouldStartTunnelsWithSubdomains(t *testing.T) {
}
pd := ports.NewPortDistributorForTests(
mapset.NewThreadUnsafeSetFromSlice([]interface{}{4000, 4001, 4002, 4003, 4004, 4005}),
mapset.NewThreadUnsafeSetFromSlice([]interface{}{4000, 4002, 4003, 4004, 4005}),
mapset.NewThreadUnsafeSetFromSlice([]interface{}{5002, 5003, 5004, 5005}),
mapset.NewSetFromSlice([]interface{}{4000, 4001, 4002, 4003, 4004, 4005}),
mapset.NewSetFromSlice([]interface{}{4000, 4002, 4003, 4004, 4005}),
mapset.NewSetFromSlice([]interface{}{5002, 5003, 5004, 5005}),
)
mockCaddyAPI := &MockCaddyAPI{}

View File

@ -5,11 +5,23 @@ import (
"time"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ssh"
"github.com/cloudradar-monitoring/rport/server/api/users"
"github.com/cloudradar-monitoring/rport/server/cgroups"
)
func NewTestClient(id string, address string, hostname string, clientAuthID string, connection ssh.Conn) (c *Client) {
c = &Client{
ID: id,
Address: address,
Hostname: hostname,
ClientAuthID: clientAuthID,
Connection: connection,
}
return c
}
func TestClientBelongsToGroup(t *testing.T) {
c1 := &Client{
ID: "test-client-id-1",

View File

@ -6,14 +6,15 @@ import (
"golang.org/x/crypto/ssh"
"github.com/cloudradar-monitoring/rport/share/comm"
"github.com/cloudradar-monitoring/rport/share/logger"
)
func IsAllowed(remote string, conn ssh.Conn) (bool, error) {
func IsAllowed(remote string, conn ssh.Conn, l *logger.Logger) (bool, error) {
req := &comm.CheckTunnelAllowedRequest{
Remote: remote,
}
resp := &comm.CheckTunnelAllowedResponse{}
err := comm.SendRequestAndGetResponse(conn, comm.RequestTypeCheckTunnelAllowed, req, resp)
err := comm.SendRequestAndGetResponse(conn, comm.RequestTypeCheckTunnelAllowed, req, resp, l)
if err != nil {
if strings.Contains(err.Error(), "unknown request") {
return true, nil

View File

@ -15,13 +15,16 @@ import (
)
type ClientRepository struct {
// in-memory cache
clients map[string]*Client
mu sync.RWMutex
KeepDisconnectedClients *time.Duration
// storage
store ClientStore
// in-memory state
clientState map[string]*Client
// db based store
clientStore ClientStore
keepDisconnectedClients *time.Duration
logger *logger.Logger
mu sync.RWMutex
}
type User interface {
@ -40,14 +43,15 @@ func NewClientRepository(initClients []*Client, keepDisconnectedClients *time.Du
func NewClientRepositoryWithDB(initialClients []*Client, keepDisconnectedClients *time.Duration, store ClientStore, logger *logger.Logger) *ClientRepository {
clients := make(map[string]*Client)
for i := range initialClients {
clients[initialClients[i].ID] = initialClients[i]
newClientID := initialClients[i].GetID()
clients[newClientID] = initialClients[i]
}
return &ClientRepository{
clients: clients,
KeepDisconnectedClients: keepDisconnectedClients,
store: store,
clientState: clients,
clientStore: store,
logger: logger,
keepDisconnectedClients: keepDisconnectedClients,
}
}
@ -58,7 +62,7 @@ func InitClientRepository(
logger *logger.Logger,
) (*ClientRepository, error) {
provider := newSqliteProvider(db, keepDisconnectedClients)
initialClients, err := LoadInitialClients(ctx, provider, logger.Fork("client-loader"))
initialClients, err := LoadInitialClients(ctx, provider, logger)
if err != nil {
return nil, err
}
@ -66,52 +70,54 @@ func InitClientRepository(
return NewClientRepositoryWithDB(initialClients, keepDisconnectedClients, provider, logger), nil
}
func (s *ClientRepository) Save(client *Client) error {
func (r *ClientRepository) Save(client *Client) error {
ts := time.Now()
if s.store != nil {
err := s.store.Save(context.Background(), client)
store := r.getStore()
if store != nil {
err := store.Save(context.Background(), client)
if err != nil {
return fmt.Errorf("failed to save a client: %w", err)
return fmt.Errorf("failed to save client: %w", err)
}
}
s.mu.Lock()
defer s.mu.Unlock()
s.clients[client.ID] = client
s.logger.Debugf(
r.updateClient(client)
r.log().Debugf(
"saved client: %s status=%s, within %s",
client.ID,
FormatConnectionState(client.DisconnectedAt),
client.GetID(),
FormatConnectionState(client),
time.Since(ts),
)
return nil
}
func (s *ClientRepository) Delete(client *Client) error {
s.logger.Debugf("deleting client: %s status=%s", client.ID, FormatConnectionState(client.DisconnectedAt))
func (r *ClientRepository) Delete(client *Client) error {
clientID := client.GetID()
if s.store != nil {
err := s.store.Delete(context.Background(), client.ID)
r.log().Debugf("deleting client: %s status=%s", clientID, FormatConnectionState(client))
store := r.getStore()
if store != nil {
err := store.Delete(context.Background(), clientID, client.Log())
if err != nil {
return fmt.Errorf("failed to delete a client: %w", err)
}
}
s.mu.Lock()
defer s.mu.Unlock()
delete(s.clients, client.ID)
r.removeClient(clientID)
return nil
}
func (s *ClientRepository) GetClientsByTag(tags []string, operator string, allowDisconnected bool) (matchingClients []*Client, err error) {
func (r *ClientRepository) GetClientsByTag(tags []string, operator string, allowDisconnected bool) (matchingClients []*Client, err error) {
var availableClients []*Client
if allowDisconnected {
availableClients, err = s.GetAll()
if err != nil {
return nil, err
}
availableClients = r.GetAllClients()
} else {
availableClients = s.GetAllActive()
availableClients, _ = r.GetAllActiveClients()
}
if strings.EqualFold(operator, "AND") {
matchingClients = findMatchingANDClients(availableClients, tags)
@ -122,10 +128,13 @@ func (s *ClientRepository) GetClientsByTag(tags []string, operator string, allow
return matchingClients, nil
}
// this fn doesn't lock the availableClients. please make sure not to use the main clients array.
// the various GetXXXClient fns will return new client arrays. please use those fns to get a
// clients array copy for this fn to operate on.
func findMatchingANDClients(availableClients []*Client, tags []string) (matchingClients []*Client) {
matchingClients = make([]*Client, 0, 64)
for _, cl := range availableClients {
clientTags := cl.Tags
clientTags := cl.GetTags()
foundAllTags := true
for _, tag := range tags {
@ -144,14 +153,18 @@ func findMatchingANDClients(availableClients []*Client, tags []string) (matching
if foundAllTags {
matchingClients = append(matchingClients, cl)
}
}
return matchingClients
}
// this fn doesn't lock the availableClients. please make sure not to use the main clients array.
// the various GetXXXClient fns will return new client arrays. please use those fns to get a
// clients array copy for this fn to operate on.
func findMatchingORClients(availableClients []*Client, tags []string) (matchingClients []*Client) {
matchingClients = make([]*Client, 0, 64)
for _, cl := range availableClients {
clientTags := cl.Tags
clientTags := cl.GetTags()
nextClientForOR:
for _, clTag := range clientTags {
for _, tag := range tags {
@ -166,57 +179,56 @@ func findMatchingORClients(availableClients []*Client, tags []string) (matchingC
}
// DeleteObsolete deletes obsolete disconnected clients and returns them.
func (s *ClientRepository) DeleteObsolete() ([]*Client, error) {
s.logger.Debugf("deleting obsolete clients")
if s.store != nil {
err := s.store.DeleteObsolete(context.Background())
func (r *ClientRepository) DeleteObsolete() ([]*Client, error) {
r.log().Debugf("deleting obsolete clients")
store := r.getStore()
if store != nil {
err := store.DeleteObsolete(context.Background(), r.log())
if err != nil {
return nil, fmt.Errorf("failed to delete obsolete clients: %w", err)
}
}
s.mu.Lock()
defer s.mu.Unlock()
var deleted []*Client
for _, client := range s.clients {
if client.Obsolete(s.KeepDisconnectedClients) {
s.logger.Debugf("deleting obsolete client: %s status=%s", client.ID, FormatConnectionState(client.DisconnectedAt))
r.mu.RLock()
for _, client := range r.getClients() {
r.mu.RUnlock()
clientID := client.GetID()
if client.Obsolete(r.GetKeepDisconnectedClients()) {
r.log().Debugf("deleting obsolete client: %s status=%s", clientID, FormatConnectionState(client))
r.removeClient(clientID)
delete(s.clients, client.ID)
deleted = append(deleted, client)
}
r.mu.RLock()
}
r.mu.RUnlock()
return deleted, nil
}
// Count returns a number of non-obsolete active and disconnected clients.
func (s *ClientRepository) Count() (int, error) {
s.mu.RLock()
defer s.mu.RUnlock()
clients, err := s.getNonObsolete()
return len(clients), err
func (r *ClientRepository) Count() int {
_, count := r.getNonObsoleteClients()
return count
}
// CountActive returns a number of active clients.
func (s *ClientRepository) CountActive() (int, error) {
s.mu.RLock()
defer s.mu.RUnlock()
return len(s.GetAllActive()), nil
func (r *ClientRepository) CountActive() (count int) {
_, count = r.GetAllActiveClients()
return count
}
// CountDisconnected returns a number of disconnected clients.
func (s *ClientRepository) CountDisconnected() (int, error) {
s.mu.RLock()
defer s.mu.RUnlock()
all, err := s.getNonObsolete()
if err != nil {
return 0, err
}
func (r *ClientRepository) CountDisconnected() (int, error) {
availableClients, _ := r.getNonObsoleteClients()
var n int
for _, cur := range all {
if cur.DisconnectedAt != nil {
// uses copy of clients array returned by getNonObsoleteClients
for _, cur := range availableClients {
if cur.GetDisconnectedAt() != nil {
n++
}
}
@ -224,114 +236,182 @@ func (s *ClientRepository) CountDisconnected() (int, error) {
}
// GetByID returns non-obsolete active or disconnected client by a given id.
func (s *ClientRepository) GetByID(id string) (*Client, error) {
s.mu.RLock()
defer s.mu.RUnlock()
client := s.clients[id]
if client != nil && client.Obsolete(s.KeepDisconnectedClients) {
func (r *ClientRepository) GetByID(id string) (*Client, error) {
client := r.getClient(id)
if client != nil && client.Obsolete(r.GetKeepDisconnectedClients()) {
return nil, nil
}
return client, nil
}
// GetActiveByID returns an active client by a given id.
func (s *ClientRepository) GetActiveByID(id string) (*Client, error) {
s.mu.RLock()
defer s.mu.RUnlock()
client := s.clients[id]
if client != nil && client.DisconnectedAt != nil {
func (r *ClientRepository) GetActiveByID(id string) (*Client, error) {
client := r.getClient(id)
if client != nil && client.GetDisconnectedAt() != nil {
return nil, nil
}
return client, nil
}
// GetAllByClientAuthID @todo: make it consistent with others whether to return an error. In general it's just a cache, so should not return an err.
func (s *ClientRepository) GetAllByClientAuthID(clientAuthID string) []*Client {
all, _ := s.GetAll()
var res []*Client
for _, v := range all {
if v.ClientAuthID == clientAuthID {
res = append(res, v)
func (r *ClientRepository) GetAllByClientAuthID(clientAuthID string) []*Client {
availableClients := r.GetAllClients()
var matchingClients []*Client
// uses copy of clients array returned by GetAllClients
for _, c := range availableClients {
if c.GetClientAuthID() == clientAuthID {
matchingClients = append(matchingClients, c)
}
}
return res
return matchingClients
}
// GetAll returns all non-obsolete active and disconnected client clients.
func (s *ClientRepository) GetAll() ([]*Client, error) {
s.mu.RLock()
defer s.mu.RUnlock()
return s.getNonObsolete()
func (r *ClientRepository) GetAllClients() []*Client {
availableClients, _ := r.getNonObsoleteClients()
return availableClients
}
// GetUserClients returns all non-obsolete active and disconnected clients that current user has access to
func (s *ClientRepository) GetUserClients(user User, groups []*cgroups.ClientGroup) ([]*Client, error) {
s.mu.RLock()
defer s.mu.RUnlock()
return s.getNonObsoleteByUser(user, groups)
func (r *ClientRepository) GetUserClients(user User, groups []*cgroups.ClientGroup) ([]*Client, error) {
return r.getNonObsoleteClientsByUser(user, groups)
}
// GetFilteredUserClients returns all non-obsolete active and disconnected clients that current user has access to, filtered by parameters
func (s *ClientRepository) GetFilteredUserClients(user User, filterOptions []query.FilterOption, groups []*cgroups.ClientGroup) ([]*CalculatedClient, error) {
s.mu.RLock()
defer s.mu.RUnlock()
clients, err := s.getNonObsoleteByUser(user, groups)
func (r *ClientRepository) GetFilteredUserClients(user User, filterOptions []query.FilterOption, groups []*cgroups.ClientGroup) ([]*CalculatedClient, error) {
clients, err := r.getNonObsoleteClientsByUser(user, groups)
if err != nil {
return nil, err
}
result := make([]*CalculatedClient, 0, len(clients))
matchingClients := make([]*CalculatedClient, 0, len(clients))
// uses copy of clients array returned by getNonObsoleteClientsByUser
for _, client := range clients {
calculatedClient := client.ToCalculated(groups)
// we need to lock because MatchesFilters receives an interface and not a client,
// therefore we lose our ability to lock.
calculatedClient.flock.RLock()
matches, err := query.MatchesFilters(calculatedClient, filterOptions)
calculatedClient.flock.RUnlock()
if err != nil {
return result, err
return matchingClients, err
}
if matches {
result = append(result, calculatedClient)
matchingClients = append(matchingClients, calculatedClient)
}
}
return result, nil
return matchingClients, nil
}
func (s *ClientRepository) GetAllActive() []*Client {
s.mu.RLock()
defer s.mu.RUnlock()
var result []*Client
for _, client := range s.clients {
if client.DisconnectedAt == nil {
result = append(result, client)
// GetAllActiveClients returns a new client array that can be used without locks (assuming not shared)
func (r *ClientRepository) GetAllActiveClients() (matchingClients []*Client, count int) {
count = 0
clients := r.getClients()
r.mu.RLock()
for _, client := range clients {
r.mu.RUnlock()
if client.GetDisconnectedAt() == nil {
matchingClients = append(matchingClients, client)
count++
}
r.mu.RLock()
}
return result
r.mu.RUnlock()
return matchingClients, count
}
func (s *ClientRepository) getNonObsolete() ([]*Client, error) {
result := make([]*Client, 0, len(s.clients))
for _, client := range s.clients {
if !client.Obsolete(s.KeepDisconnectedClients) {
result = append(result, client)
// getNonObsoleteClients returns a new client array that can be used without locks (assuming not shared)
func (r *ClientRepository) getNonObsoleteClients() (matchingClients []*Client, count int) {
count = 0
clients := r.getClients()
r.mu.RLock()
matchingClients = make([]*Client, 0, len(clients))
for _, client := range clients {
r.mu.RUnlock()
if !client.Obsolete(r.GetKeepDisconnectedClients()) {
matchingClients = append(matchingClients, client)
count++
}
r.mu.RLock()
}
return result, nil
r.mu.RUnlock()
return matchingClients, count
}
// getNonObsoleteByUser return connected clients the user has access to either by user group or by client group
func (s *ClientRepository) getNonObsoleteByUser(user User, clientGroups []*cgroups.ClientGroup) ([]*Client, error) {
// getNonObsoleteByUser return connected clients the user has access to either by user group or by client group.
// returns a new client array that can be used without locks (assuming not shared)
func (r *ClientRepository) getNonObsoleteClientsByUser(user User, clientGroups []*cgroups.ClientGroup) ([]*Client, error) {
clients := r.getClients()
userGroups := user.GetGroups()
result := make([]*Client, 0, len(s.clients))
for _, client := range s.clients {
if client.Obsolete(s.KeepDisconnectedClients) {
continue
}
if user.IsAdmin() || client.HasAccessViaUserGroups(userGroups) || client.UserGroupHasAccessViaClientGroup(userGroups, clientGroups) {
result = append(result, client)
continue
r.mu.RLock()
matchingClients := make([]*Client, 0, len(clients))
for _, client := range clients {
r.mu.RUnlock()
if !client.Obsolete(r.GetKeepDisconnectedClients()) {
if user.IsAdmin() || client.HasAccessViaUserGroups(userGroups) || client.UserGroupHasAccessViaClientGroup(userGroups, clientGroups) {
matchingClients = append(matchingClients, client)
}
}
r.mu.RLock()
}
return result, nil
r.mu.RUnlock()
return matchingClients, nil
}
func (r *ClientRepository) getStore() (store ClientStore) {
r.mu.RLock()
defer r.mu.RUnlock()
return r.clientStore
}
func (r *ClientRepository) log() (l *logger.Logger) {
r.mu.RLock()
defer r.mu.RUnlock()
return r.logger
}
// getClients returns the primary map of clients. accessing the return map must use locks.
func (r *ClientRepository) getClients() (clients map[string]*Client) {
r.mu.RLock()
defer r.mu.RUnlock()
return r.clientState
}
func (r *ClientRepository) getClient(clientID string) (client *Client) {
r.mu.RLock()
client = r.clientState[clientID]
r.mu.RUnlock()
return client
}
func (r *ClientRepository) updateClient(client *Client) {
clientID := client.GetID()
r.mu.Lock()
r.clientState[clientID] = client
r.mu.Unlock()
}
func (r *ClientRepository) removeClient(clientID string) {
r.mu.Lock()
delete(r.clientState, clientID)
r.mu.Unlock()
}
func (r *ClientRepository) GetKeepDisconnectedClients() (keep *time.Duration) {
r.mu.RLock()
defer r.mu.RUnlock()
return r.keepDisconnectedClients
}

View File

@ -40,29 +40,26 @@ func TestCRWithExpiration(t *testing.T) {
assert.NoError(repo.Save(c3))
assert.NoError(repo.Save(c4))
gotCount, err := repo.Count()
assert.NoError(err)
gotCount := repo.Count()
assert.Equal(3, gotCount)
gotCountActive, err := repo.CountActive()
assert.NoError(err)
gotCountActive := repo.CountActive()
assert.Equal(1, gotCountActive)
gotCountDisconnected, err := repo.CountDisconnected()
assert.NoError(err)
assert.Equal(2, gotCountDisconnected)
gotClients, err := repo.GetAll()
assert.NoError(err)
gotClients := repo.GetAllClients()
assert.ElementsMatch([]*Client{c1, c2, c3}, gotClients)
// active
gotClient, err := repo.GetActiveByID(c1.ID)
gotClient, err := repo.GetActiveByID(c1.GetID())
assert.NoError(err)
assert.Equal(c1, gotClient)
// disconnected
gotClient, err = repo.GetActiveByID(c2.ID)
gotClient, err = repo.GetActiveByID(c2.GetID())
assert.NoError(err)
assert.Nil(gotClient)
@ -70,12 +67,11 @@ func TestCRWithExpiration(t *testing.T) {
assert.NoError(err)
require.Len(t, deleted, 1)
assert.Equal(c4, deleted[0])
gotClients, err = repo.GetAll()
assert.NoError(err)
gotClients = repo.GetAllClients()
assert.ElementsMatch([]*Client{c1, c2, c3}, gotClients)
assert.NoError(repo.Delete(c3))
gotClients, err = repo.GetAll()
gotClients = repo.GetAllClients()
assert.NoError(err)
assert.ElementsMatch([]*Client{c1, c2}, gotClients)
}
@ -90,29 +86,26 @@ func TestCRWithNoExpiration(t *testing.T) {
assert := assert.New(t)
assert.NoError(repo.Save(c4Active))
gotCount, err := repo.Count()
assert.NoError(err)
gotCount := repo.Count()
assert.Equal(4, gotCount)
gotCountActive, err := repo.CountActive()
assert.NoError(err)
gotCountActive := repo.CountActive()
assert.Equal(2, gotCountActive)
gotCountDisconnected, err := repo.CountDisconnected()
assert.NoError(err)
assert.Equal(2, gotCountDisconnected)
gotClients, err := repo.GetAll()
assert.NoError(err)
gotClients := repo.GetAllClients()
assert.ElementsMatch([]*Client{c1, c2, c3, c4Active}, gotClients)
// active
gotClient, err := repo.GetActiveByID(c1.ID)
gotClient, err := repo.GetActiveByID(c1.GetID())
assert.NoError(err)
assert.Equal(c1, gotClient)
// disconnected
gotClient, err = repo.GetActiveByID(c2.ID)
gotClient, err = repo.GetActiveByID(c2.GetID())
assert.NoError(err)
assert.Nil(gotClient)
@ -121,8 +114,7 @@ func TestCRWithNoExpiration(t *testing.T) {
assert.Len(deleted, 0)
assert.NoError(repo.Delete(c4Active))
gotClients, err = repo.GetAll()
assert.NoError(err)
gotClients = repo.GetAllClients()
assert.ElementsMatch([]*Client{c1, c2, c3}, gotClients)
}
@ -521,7 +513,7 @@ func TestCRWithFilter(t *testing.T) {
actualClientIDs := make([]string, 0, len(actualClients))
for _, actualClient := range actualClients {
actualClientIDs = append(actualClientIDs, actualClient.ID)
actualClientIDs = append(actualClientIDs, actualClient.GetID())
}
assert.ElementsMatch(t, tc.expectedClientIDs, actualClientIDs)
@ -543,15 +535,15 @@ func TestCRWithUnsupportedFilter(t *testing.T) {
}
func TestGetUserClients(t *testing.T) {
c1 := New(t).Build() // no groups
c2 := New(t).AllowedUserGroups([]string{users.Administrators}).Build() // admin
c3 := New(t).AllowedUserGroups([]string{users.Administrators, "group1"}).Build() // admin + group1
c4 := New(t).AllowedUserGroups([]string{"group1"}).Build() // group1
c5 := New(t).AllowedUserGroups([]string{"group1", "group2"}).Build() // group1 + group2
c6 := New(t).AllowedUserGroups([]string{"group2"}).Build() // group2
c7 := New(t).AllowedUserGroups([]string{"group3"}).Build() // group3
c8 := New(t).AllowedUserGroups([]string{"group2", "group3"}).Build() // group2 + group3
c9 := New(t).Build()
c1 := New(t).Logger(testLog).Build() // no groups
c2 := New(t).AllowedUserGroups([]string{users.Administrators}).Logger(testLog).Build() // admin
c3 := New(t).AllowedUserGroups([]string{users.Administrators, "group1"}).Logger(testLog).Build() // admin + group1
c4 := New(t).AllowedUserGroups([]string{"group1"}).Logger(testLog).Build() // group1
c5 := New(t).AllowedUserGroups([]string{"group1", "group2"}).Logger(testLog).Build() // group1 + group2
c6 := New(t).AllowedUserGroups([]string{"group2"}).Logger(testLog).Build() // group2
c7 := New(t).AllowedUserGroups([]string{"group3"}).Logger(testLog).Build() // group3
c8 := New(t).AllowedUserGroups([]string{"group2", "group3"}).Logger(testLog).Build() // group2 + group3
c9 := New(t).Logger(testLog).Build()
allClients := []*Client{c1, c2, c3, c4, c5, c6, c7, c8, c9}
clientGroups := []*cgroups.ClientGroup{
@ -559,7 +551,7 @@ func TestGetUserClients(t *testing.T) {
ID: "1",
AllowedUserGroups: []string{"group6"},
Params: &cgroups.ClientParams{
ClientID: &cgroups.ParamValues{cgroups.Param(c9.ID)},
ClientID: &cgroups.ParamValues{cgroups.Param(c9.GetID())},
},
},
}
@ -675,7 +667,7 @@ func TestGetClientByTag(t *testing.T) {
}
for idx, cl := range matchingClients {
assert.Equal(t, tc.expectedClientIDs[idx], cl.ID)
assert.Equal(t, tc.expectedClientIDs[idx], cl.GetID())
}
})
}

View File

@ -204,6 +204,8 @@ func shallowCopy(c *Client) *Client {
Address: c.Address,
Tunnels: append([]*clienttunnel.Tunnel{}, c.Tunnels...),
DisconnectedAt: c.DisconnectedAt,
LastHeartbeatAt: c.LastHeartbeatAt,
ClientAuthID: c.ClientAuthID,
Logger: c.Logger,
}
}

View File

@ -2,12 +2,11 @@ package clients
import (
"fmt"
"time"
)
func FormatConnectionState(disconnectedAt *time.Time) string {
if disconnectedAt != nil {
return fmt.Sprintf("disconnected since %s", disconnectedAt)
func FormatConnectionState(client *Client) string {
if !client.IsConnected() {
return fmt.Sprintf("disconnected since %s", client.GetDisconnectedAtValue())
}
return "connected"
}

View File

@ -9,24 +9,25 @@ import (
// LoadInitialClients returns an initial Client Repository state populated with clients from the internal storage.
func LoadInitialClients(ctx context.Context, p ClientStore, logger *logger.Logger) ([]*Client, error) {
if logger != nil {
logger.Debugf("loading existing clients")
}
logger.Debugf("loading existing clients")
all, err := p.GetAll(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get clients: %v", err)
}
if logger != nil {
logger.Debugf("loaded %d clients", len(all))
}
logger.Debugf("loaded %d clients", len(all))
// mark previously connected clients as disconnected with current time
now := now()
for _, cur := range all {
if cur.DisconnectedAt == nil {
cur.SetDisconnected(&now)
err := p.Save(ctx, cur)
// setup a logger for the clients
clientLogger := logger.Fork("client")
for _, client := range all {
client.Logger = clientLogger
if client.IsConnected() {
client.SetDisconnectedAt(&now)
err := p.Save(ctx, client)
if err != nil {
return nil, fmt.Errorf("failed to save client: %v", err)
}

View File

@ -10,11 +10,12 @@ import (
func TestGetInitState(t *testing.T) {
ctx := context.Background()
c1 := New(t).Build()
c1 := New(t).ID("client-1").Logger(testLog).Build()
wantC1 := shallowCopy(c1)
wantC1.DisconnectedAt = &nowMock
c2 := New(t).DisconnectedDuration(5 * time.Minute).Build()
c3 := New(t).DisconnectedDuration(2 * time.Hour).Build()
wantC1.SetDisconnectedAt(&nowMock)
c2 := New(t).ID("client-2").DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
c3 := New(t).ID("client-3").DisconnectedDuration(2 * time.Hour).Logger(testLog).Build()
testCases := []struct {
name string
@ -44,16 +45,18 @@ func TestGetInitState(t *testing.T) {
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// given
p := NewFakeClientProvider(t, &tc.expiration, tc.dbClients...)
defer p.Close()
// when
gotClients, gotErr := LoadInitialClients(ctx, p, nil)
// then
gotClients, gotErr := LoadInitialClients(ctx, p, testLog)
assert.NoError(t, gotErr)
assert.Len(t, gotClients, len(tc.wantRes))
// patch the client logger for the ElementsMatch check
for _, c := range gotClients {
c.Logger = testLog
}
assert.ElementsMatch(t, gotClients, tc.wantRes)
})
}

139
server/clients/payloads.go Normal file
View File

@ -0,0 +1,139 @@
package clients
import (
"time"
"github.com/cloudradar-monitoring/rport/server/clients/clienttunnel"
"github.com/cloudradar-monitoring/rport/share/clientconfig"
"github.com/cloudradar-monitoring/rport/share/models"
"github.com/cloudradar-monitoring/rport/share/query"
)
type ClientPayload struct {
ID *string `json:"id,omitempty"`
Name *string `json:"name,omitempty"`
Address *string `json:"address,omitempty"`
Hostname *string `json:"hostname,omitempty"`
OS *string `json:"os,omitempty"`
OSFullName *string `json:"os_full_name,omitempty"`
OSVersion *string `json:"os_version,omitempty"`
OSArch *string `json:"os_arch,omitempty"`
OSFamily *string `json:"os_family,omitempty"`
OSKernel *string `json:"os_kernel,omitempty"`
OSVirtualizationSystem *string `json:"os_virtualization_system,omitempty"`
OSVirtualizationRole *string `json:"os_virtualization_role,omitempty"`
NumCPUs *int `json:"num_cpus,omitempty"`
CPUFamily *string `json:"cpu_family,omitempty"`
CPUModel *string `json:"cpu_model,omitempty"`
CPUModelName *string `json:"cpu_model_name,omitempty"`
CPUVendor *string `json:"cpu_vendor,omitempty"`
MemoryTotal *uint64 `json:"mem_total,omitempty"`
Timezone *string `json:"timezone,omitempty"`
ClientAuthID *string `json:"client_auth_id,omitempty"`
Version *string `json:"version,omitempty"`
DisconnectedAt **time.Time `json:"disconnected_at,omitempty"`
LastHeartbeatAt **time.Time `json:"last_heartbeat_at,omitempty"`
ConnectionState *string `json:"connection_state,omitempty"`
IPv4 *[]string `json:"ipv4,omitempty"`
IPv6 *[]string `json:"ipv6,omitempty"`
Tags *[]string `json:"tags,omitempty"`
AllowedUserGroups *[]string `json:"allowed_user_groups,omitempty"`
Tunnels *[]*clienttunnel.Tunnel `json:"tunnels,omitempty"`
UpdatesStatus **models.UpdatesStatus `json:"updates_status,omitempty"`
ClientConfiguration **clientconfig.Config `json:"client_configuration,omitempty"`
Groups *[]string `json:"groups,omitempty"`
}
func ConvertToClientsPayload(clientsList []*CalculatedClient, fields []query.FieldsOption) []ClientPayload {
r := make([]ClientPayload, 0, len(clientsList))
for _, cur := range clientsList {
r = append(r, ConvertToClientPayload(cur, fields))
}
return r
}
func ConvertToClientPayload(client *CalculatedClient, fields []query.FieldsOption) ClientPayload { //nolint:gocyclo
requestedFields := query.RequestedFields(fields, "clients")
p := ClientPayload{}
for field := range OptionsSupportedFields["clients"] {
if len(fields) > 0 && !requestedFields[field] {
continue
}
client.flock.RLock()
defer client.flock.RUnlock()
switch field {
case "id":
id := client.ID
p.ID = &id
case "name":
name := client.Name
p.Name = &name
case "os":
p.OS = &client.OS
case "os_arch":
p.OSArch = &client.OSArch
case "os_family":
p.OSFamily = &client.OSFamily
case "os_kernel":
p.OSKernel = &client.OSKernel
case "hostname":
p.Hostname = &client.Hostname
case "ipv4":
p.IPv4 = &client.IPv4
case "ipv6":
p.IPv6 = &client.IPv6
case "tags":
p.Tags = &client.Tags
case "version":
p.Version = &client.Version
case "address":
p.Address = &client.Address
case "tunnels":
p.Tunnels = &client.Tunnels
case "disconnected_at":
disconnectedAt := client.DisconnectedAt
p.DisconnectedAt = &disconnectedAt
case "last_heartbeat_at":
lastHeartbeatAt := client.LastHeartbeatAt
p.LastHeartbeatAt = &lastHeartbeatAt
case "client_auth_id":
p.ClientAuthID = &client.ClientAuthID
case "os_full_name":
p.OSFullName = &client.OSFullName
case "os_version":
p.OSVersion = &client.OSVersion
case "os_virtualization_system":
p.OSVirtualizationSystem = &client.OSVirtualizationSystem
case "os_virtualization_role":
p.OSVirtualizationRole = &client.OSVirtualizationRole
case "cpu_family":
p.CPUFamily = &client.CPUFamily
case "cpu_model":
p.CPUModel = &client.CPUModel
case "cpu_model_name":
p.CPUModelName = &client.CPUModelName
case "cpu_vendor":
p.CPUVendor = &client.CPUVendor
case "timezone":
p.Timezone = &client.Timezone
case "num_cpus":
p.NumCPUs = &client.NumCPUs
case "mem_total":
p.MemoryTotal = &client.MemoryTotal
case "allowed_user_groups":
p.AllowedUserGroups = &client.AllowedUserGroups
case "updates_status":
p.UpdatesStatus = &client.UpdatesStatus
case "client_configuration":
p.ClientConfiguration = &client.ClientConfiguration
case "groups":
p.Groups = &client.Groups
case "connection_state":
connectionState := string(client.GetConnectionState())
p.ConnectionState = &connectionState
}
}
return p
}

View File

@ -7,7 +7,7 @@ import (
func SortByID(a []*CalculatedClient, desc bool) {
sort.Slice(a, func(i, j int) bool {
less := strings.ToLower(a[i].ID) < strings.ToLower(a[j].ID)
less := strings.ToLower(a[i].GetID()) < strings.ToLower(a[j].GetID())
if desc {
return !less
}
@ -17,9 +17,9 @@ func SortByID(a []*CalculatedClient, desc bool) {
func SortByName(a []*CalculatedClient, desc bool) {
sort.Slice(a, func(i, j int) bool {
aiName := strings.ToLower(a[i].Name)
ajName := strings.ToLower(a[j].Name)
less := aiName < ajName || aiName == ajName && strings.ToLower(a[i].ID) < strings.ToLower(a[j].ID)
aiName := strings.ToLower(a[i].GetName())
ajName := strings.ToLower(a[j].GetName())
less := aiName < ajName || aiName == ajName && strings.ToLower(a[i].GetID()) < strings.ToLower(a[j].GetID())
if desc {
return !less
}
@ -29,9 +29,9 @@ func SortByName(a []*CalculatedClient, desc bool) {
func SortByOS(a []*CalculatedClient, desc bool) {
sort.Slice(a, func(i, j int) bool {
aiOS := strings.ToLower(a[i].OS)
ajOS := strings.ToLower(a[j].OS)
less := aiOS < ajOS || aiOS == ajOS && strings.ToLower(a[i].ID) < strings.ToLower(a[j].ID)
aiOS := strings.ToLower(a[i].GetOS())
ajOS := strings.ToLower(a[j].GetOS())
less := aiOS < ajOS || aiOS == ajOS && strings.ToLower(a[i].GetID()) < strings.ToLower(a[j].GetID())
if desc {
return !less
}
@ -41,9 +41,9 @@ func SortByOS(a []*CalculatedClient, desc bool) {
func SortByHostname(a []*CalculatedClient, desc bool) {
sort.Slice(a, func(i, j int) bool {
aiHostname := strings.ToLower(a[i].Hostname)
ajHostname := strings.ToLower(a[j].Hostname)
less := aiHostname < ajHostname || aiHostname == ajHostname && strings.ToLower(a[i].ID) < strings.ToLower(a[j].ID)
aiHostname := strings.ToLower(a[i].GetHostname())
ajHostname := strings.ToLower(a[j].GetHostname())
less := aiHostname < ajHostname || aiHostname == ajHostname && strings.ToLower(a[i].GetID()) < strings.ToLower(a[j].GetID())
if desc {
return !less
}
@ -53,9 +53,9 @@ func SortByHostname(a []*CalculatedClient, desc bool) {
func SortByVersion(a []*CalculatedClient, desc bool) {
sort.Slice(a, func(i, j int) bool {
aiVersion := strings.ToLower(a[i].Version)
ajVersion := strings.ToLower(a[j].Version)
less := aiVersion < ajVersion || aiVersion == ajVersion && strings.ToLower(a[i].ID) < strings.ToLower(a[j].ID)
aiVersion := strings.ToLower(a[i].GetVersion())
ajVersion := strings.ToLower(a[j].GetVersion())
less := aiVersion < ajVersion || aiVersion == ajVersion && strings.ToLower(a[i].GetID()) < strings.ToLower(a[j].GetID())
if desc {
return !less
}

View File

@ -11,16 +11,18 @@ import (
"github.com/jmoiron/sqlx"
"github.com/cloudradar-monitoring/rport/db/sqlite"
"github.com/cloudradar-monitoring/rport/server/clients/clienttunnel"
chshare "github.com/cloudradar-monitoring/rport/share/clientconfig"
"github.com/cloudradar-monitoring/rport/share/logger"
"github.com/cloudradar-monitoring/rport/share/models"
)
type ClientStore interface {
GetAll(ctx context.Context) ([]*Client, error)
Save(ctx context.Context, client *Client) error
DeleteObsolete(ctx context.Context) error
Delete(ctx context.Context, id string) error
DeleteObsolete(ctx context.Context, l *logger.Logger) error
Delete(ctx context.Context, id string, l *logger.Logger) error
Close() error
}
@ -48,6 +50,7 @@ func (p *SqliteProvider) GetAll(ctx context.Context) ([]*Client, error) {
return convertClientList(res), nil
}
// test only
func (p *SqliteProvider) get(ctx context.Context, id string) (*Client, error) {
res := &clientSqlite{}
err := p.db.GetContext(ctx, res, "SELECT * FROM clients WHERE id = ?", id)
@ -61,29 +64,54 @@ func (p *SqliteProvider) get(ctx context.Context, id string) (*Client, error) {
}
func (p *SqliteProvider) Save(ctx context.Context, client *Client) error {
_, err := p.db.NamedExecContext(
ctx,
"INSERT OR REPLACE INTO clients (id, client_auth_id, disconnected_at, details) VALUES (:id, :client_auth_id, :disconnected_at, :details)",
convertToSqlite(client),
)
_, err := sqlite.WithRetryWhenBusy(func() (result sql.Result, err error) {
clientForSQL := convertToSqlite(client)
_, err = p.db.NamedExecContext(
ctx,
"INSERT OR REPLACE INTO clients (id, client_auth_id, disconnected_at, details) VALUES (:id, :client_auth_id, :disconnected_at, :details)",
clientForSQL,
)
return nil, err
}, "save", client.Log())
return err
}
func (p *SqliteProvider) DeleteObsolete(ctx context.Context) error {
_, err := p.db.ExecContext(
ctx,
"DELETE FROM clients WHERE disconnected_at IS NOT NULL AND DATETIME(disconnected_at) < DATETIME(?) AND ?",
p.keepDisconnectedClientsStart(),
p.keepDisconnectedClients != nil,
)
func (p *SqliteProvider) DeleteObsolete(ctx context.Context, l *logger.Logger) error {
_, err := sqlite.WithRetryWhenBusy(func() (result sql.Result, err error) {
_, err = p.db.ExecContext(
ctx,
"DELETE FROM clients WHERE disconnected_at IS NOT NULL AND DATETIME(disconnected_at) < DATETIME(?) AND ?",
p.keepDisconnectedClientsStart(),
p.keepDisconnectedClients != nil,
)
return nil, err
}, "delete obsolete", l)
return err
}
func (p *SqliteProvider) Delete(ctx context.Context, id string) error {
_, err := p.db.ExecContext(ctx, "DELETE FROM clients WHERE id = ?", id)
func (p *SqliteProvider) Delete(ctx context.Context, id string, l *logger.Logger) error {
_, err := sqlite.WithRetryWhenBusy(func() (result sql.Result, err error) {
_, err = p.db.ExecContext(ctx, "DELETE FROM clients WHERE id = ?", id)
return nil, err
}, "delete", l)
return err
}
func (p *SqliteProvider) Close() error {
return p.db.Close()
}
func (p *SqliteProvider) keepDisconnectedClientsStart() time.Time {
t := now()
if p.keepDisconnectedClients != nil {
@ -92,45 +120,49 @@ func (p *SqliteProvider) keepDisconnectedClientsStart() time.Time {
return t
}
func convertToSqlite(v *Client) *clientSqlite {
if v == nil {
func convertToSqlite(c *Client) (res *clientSqlite) {
if c == nil {
return nil
}
res := &clientSqlite{
ID: v.ID,
ClientAuthID: v.ClientAuthID,
c.flock.RLock()
res = &clientSqlite{
ID: c.ID,
ClientAuthID: c.ClientAuthID,
Details: &clientDetails{
Name: v.Name,
OS: v.OS,
OSArch: v.OSArch,
OSFamily: v.OSFamily,
OSKernel: v.OSKernel,
Hostname: v.Hostname,
Version: v.Version,
Address: v.Address,
OSFullName: v.OSFullName,
OSVersion: v.OSVersion,
OSVirtualizationSystem: v.OSVirtualizationSystem,
OSVirtualizationRole: v.OSVirtualizationRole,
CPUFamily: v.CPUFamily,
CPUModel: v.CPUModel,
CPUModelName: v.CPUModelName,
CPUVendor: v.CPUVendor,
NumCPUs: v.NumCPUs,
MemoryTotal: v.MemoryTotal,
Timezone: v.Timezone,
IPv4: v.IPv4,
IPv6: v.IPv6,
Tags: v.Tags,
Tunnels: v.Tunnels,
AllowedUserGroups: v.AllowedUserGroups,
UpdatesStatus: v.UpdatesStatus,
ClientConfig: v.ClientConfiguration,
Name: c.Name,
OS: c.OS,
OSArch: c.OSArch,
OSFamily: c.OSFamily,
OSKernel: c.OSKernel,
Hostname: c.Hostname,
Version: c.Version,
Address: c.Address,
OSFullName: c.OSFullName,
OSVersion: c.OSVersion,
OSVirtualizationSystem: c.OSVirtualizationSystem,
OSVirtualizationRole: c.OSVirtualizationRole,
CPUFamily: c.CPUFamily,
CPUModel: c.CPUModel,
CPUModelName: c.CPUModelName,
CPUVendor: c.CPUVendor,
NumCPUs: c.NumCPUs,
MemoryTotal: c.MemoryTotal,
Timezone: c.Timezone,
IPv4: c.IPv4,
IPv6: c.IPv6,
Tags: c.Tags,
Tunnels: c.Tunnels,
AllowedUserGroups: c.AllowedUserGroups,
UpdatesStatus: c.UpdatesStatus,
ClientConfig: c.ClientConfiguration,
},
}
if v.DisconnectedAt != nil {
res.DisconnectedAt = sql.NullTime{Time: *v.DisconnectedAt, Valid: true}
c.flock.RUnlock()
if !c.IsConnected() {
res.DisconnectedAt = sql.NullTime{Time: c.GetDisconnectedAtValue(), Valid: true}
}
return res
}
@ -196,9 +228,9 @@ func (d *clientDetails) Value() (driver.Value, error) {
return string(b), nil
}
func (s *clientSqlite) convert() *Client {
func (s *clientSqlite) convert() (res *Client) {
d := s.Details
res := &Client{
res = &Client{
ID: s.ID,
ClientAuthID: s.ClientAuthID,
Name: d.Name,
@ -229,15 +261,11 @@ func (s *clientSqlite) convert() *Client {
ClientConfiguration: d.ClientConfig,
}
if s.DisconnectedAt.Valid {
res.DisconnectedAt = &s.DisconnectedAt.Time
res.SetDisconnectedAt(&s.DisconnectedAt.Time)
}
return res
}
func (p *SqliteProvider) Close() error {
return p.db.Close()
}
func convertClientList(list []*clientSqlite) []*Client {
res := make([]*Client, 0, len(list))
for _, cur := range list {

View File

@ -39,12 +39,12 @@ func TestClientsSqliteProvider(t *testing.T) {
assert.ElementsMatch(t, []*Client{c1, c2, c3, c4, c5}, gotAll)
// verify delete obsolete clients
gotObsolete, err := p.get(ctx, c5.ID)
gotObsolete, err := p.get(ctx, c5.GetID())
require.NoError(t, err)
require.EqualValues(t, c5, gotObsolete)
require.NoError(t, p.DeleteObsolete(ctx))
gotObsolete, err = p.get(ctx, c5.ID)
require.NoError(t, p.DeleteObsolete(ctx, testLog))
gotObsolete, err = p.get(ctx, c5.GetID())
require.NoError(t, err)
require.Nil(t, gotObsolete)
@ -61,7 +61,7 @@ func TestClientsSqliteProvider(t *testing.T) {
d := time.Date(2020, 11, 5, 12, 11, 20, 0, time.UTC)
c1.DisconnectedAt = &d
require.NoError(t, p.Save(ctx, c1))
gotUpdated, err := p.get(ctx, c1.ID)
gotUpdated, err := p.get(ctx, c1.GetID())
require.NoError(t, err)
require.EqualValues(t, c1, gotUpdated)
gotAll, err = p.GetAll(ctx)

View File

@ -92,17 +92,19 @@ func (c *FileProvider) Get(id string) (*ClientAuth, error) {
return nil, nil
}
func (c *FileProvider) Add(client *ClientAuth) (bool, error) {
func (c *FileProvider) Add(clientAuth *ClientAuth) (bool, error) {
idPswdPairs, err := c.load()
if err != nil {
return false, fmt.Errorf("failed to decode rport clients auth file: %v", err)
}
if _, ok := idPswdPairs[client.ID]; ok {
clientID := clientAuth.ID
if _, ok := idPswdPairs[clientID]; ok {
return false, nil
}
idPswdPairs[client.ID] = client.Password
idPswdPairs[clientID] = clientAuth.Password
if err := c.save(idPswdPairs); err != nil {
return false, fmt.Errorf("failed to encode rport clients auth file: %v", err)

View File

@ -2,6 +2,7 @@ package ports
import (
"fmt"
"sync"
mapset "github.com/deckarep/golang-set"
"github.com/shirou/gopsutil/v3/net"
@ -13,6 +14,7 @@ type PortDistributor struct {
allowedPorts mapset.Set
portsPools map[string]mapset.Set
mu sync.RWMutex
}
func NewPortDistributor(allowedPorts mapset.Set) *PortDistributor {
@ -22,6 +24,19 @@ func NewPortDistributor(allowedPorts mapset.Set) *PortDistributor {
}
}
func (d *PortDistributor) getPoolFromMap(protocol string) (pool mapset.Set) {
d.mu.RLock()
pool = d.portsPools[protocol]
d.mu.RUnlock()
return pool
}
func (d *PortDistributor) setPool(protocol string, pool mapset.Set) {
d.mu.Lock()
d.portsPools[protocol] = pool
d.mu.Unlock()
}
// NewPortDistributorForTests is used only for unit-testing.
func NewPortDistributorForTests(allowedPorts, tcpPortsPool, udpPortsPool mapset.Set) *PortDistributor {
return &PortDistributor{
@ -39,7 +54,8 @@ func (d *PortDistributor) GetRandomPort(protocol string) (int, error) {
subProtocols = []string{models.ProtocolTCP, models.ProtocolUDP}
}
for _, p := range subProtocols {
if d.portsPools[p] == nil {
pool := d.getPoolFromMap(p)
if pool == nil {
err := d.refresh(p)
if err != nil {
return 0, err
@ -54,7 +70,8 @@ func (d *PortDistributor) GetRandomPort(protocol string) (int, error) {
// Make sure port is removed from all pools for tcp+udp protocol
for _, p := range subProtocols {
d.portsPools[p].Remove(port)
pool := d.getPoolFromMap(p)
pool.Remove(port)
}
return port.(int), nil
@ -69,7 +86,7 @@ func (d *PortDistributor) IsPortBusy(protocol string, port int) bool {
}
func (d *PortDistributor) getPool(protocol string) mapset.Set {
pool := d.portsPools[protocol]
pool := d.getPoolFromMap(protocol)
if protocol == models.ProtocolTCPUDP {
pool = d.portsPools[models.ProtocolTCP].Intersect(d.portsPools[models.ProtocolUDP])
}
@ -94,12 +111,14 @@ func (d *PortDistributor) refresh(protocol string) error {
return err
}
d.portsPools[protocol] = d.allowedPorts.Difference(busyPorts)
pool := d.allowedPorts.Difference(busyPorts)
d.setPool(protocol, pool)
return nil
}
func ListBusyPorts(protocol string) (mapset.Set, error) {
result := mapset.NewThreadUnsafeSet()
result := mapset.NewSet()
connections, err := net.Connections(protocol)
if err != nil {
return nil, err

View File

@ -15,9 +15,9 @@ func TestPortDistributor(t *testing.T) {
for _, protocol := range []string{models.ProtocolTCP, models.ProtocolUDP, models.ProtocolTCPUDP} {
t.Run(protocol, func(t *testing.T) {
pd := NewPortDistributorForTests(
mapset.NewThreadUnsafeSetFromSlice([]interface{}{1, 2, 3, 4, 5}),
mapset.NewThreadUnsafeSetFromSlice([]interface{}{2, 3, 4, 5}),
mapset.NewThreadUnsafeSetFromSlice([]interface{}{2, 3, 4, 5}),
mapset.NewSetFromSlice([]interface{}{1, 2, 3, 4, 5}),
mapset.NewSetFromSlice([]interface{}{2, 3, 4, 5}),
mapset.NewSetFromSlice([]interface{}{2, 3, 4, 5}),
)
assert.Equal(t, true, pd.IsPortBusy(protocol, 1))

View File

@ -10,7 +10,7 @@ import (
)
func TryParsePortRanges(portRanges []string) (mapset.Set, error) {
result := mapset.NewThreadUnsafeSet()
result := mapset.NewSet()
for _, rangeStr := range portRanges {
rangeParts := strings.Split(rangeStr, "-")
if len(rangeParts) == 1 {
@ -59,7 +59,7 @@ func tryParsePortNumberRange(rangeStart, rangeEnd string) (mapset.Set, error) {
}
func setFromRange(start, end int) mapset.Set {
s := mapset.NewThreadUnsafeSet()
s := mapset.NewSet()
for i := 0; i <= end-start; i++ {
s.Add(start + i)
}

View File

@ -13,16 +13,20 @@ type Task interface {
// Run runs the given task periodically with a given interval between executions.
func Run(ctx context.Context, log *logger.Logger, task Task, interval time.Duration) {
log.Debugf("started")
tick := time.NewTicker(interval)
defer tick.Stop()
for {
select {
case <-tick.C:
log.Debugf("running")
if err := task.Run(ctx); err != nil {
log.Errorf("Task %T finished with an error: %v.", task, err)
log.Errorf("finished with an error: %v.", err)
}
log.Debugf("finished")
case <-ctx.Done():
log.Debugf("%T: context canceled", task)
tick.Stop()
log.Debugf("context canceled", task)
log.Debugf("stopped")
return
}
}

View File

@ -49,6 +49,8 @@ const (
cleanupAPISessionsInterval = time.Hour
cleanupJobsInterval = time.Hour
LogNumGoRoutinesInterval = time.Minute * 2
DefaultMaxClientDBConnections = 50
)
// Server represents a rport service
@ -168,11 +170,18 @@ func NewServer(ctx context.Context, config *chconfig.Config, opts *ServerOpts) (
// even if monitoring disabled, always create the monitoring service to support queries of past data etc
s.monitoringService = monitoring.NewService(monitoringProvider)
sourceOptions := config.Server.GetSQLiteDataSourceOptions()
// particularly the client.db needs performant db access, so allow multi-threaded access
// and use the RetryWhenBusy fn to ensure writes succeed if we get a busy error due to
// concurrent thread access.
sourceOptions.MaxOpenConnections = DefaultMaxClientDBConnections
s.clientDB, err = sqlite.New(
path.Join(config.Server.DataDir, "clients.db"),
clientsmigration.AssetNames(),
clientsmigration.Asset,
config.Server.GetSQLiteDataSourceOptions(),
sourceOptions,
)
if err != nil {
return nil, fmt.Errorf("failed to create clients DB instance: %v", err)
@ -265,8 +274,8 @@ func NewServer(ctx context.Context, config *chconfig.Config, opts *ServerOpts) (
func (s *Server) HandlePlusLicenseInfoAvailable() {
s.Logger.Debugf("received license info from rport-plus")
if s.clientListener != nil && s.clientListener.clientService != nil {
s.clientListener.clientService.UpdateClientStatus()
if s.clientListener != nil && s.clientListener.server.clientService != nil {
s.clientListener.server.clientService.UpdateClientStatus()
}
}
@ -309,19 +318,20 @@ func (s *Server) Run(ctx context.Context) error {
// TODO(m-terel): add graceful shutdown of background task
if s.config.Server.PurgeDisconnectedClients {
s.Infof("Period to keep disconnected clients is set to %v", s.config.Server.KeepDisconnectedClients)
go scheduler.Run(ctx, s.Logger, clients.NewCleanupTask(s.Logger, s.clientListener.clientService.GetRepo()), s.config.Server.PurgeDisconnectedClientsInterval)
go scheduler.Run(ctx, s.Logger, clients.NewCleanupTask(s.Logger, s.clientListener.server.clientService.GetRepo()), s.config.Server.PurgeDisconnectedClientsInterval)
s.Infof("Task to purge disconnected clients will run with interval %v", s.config.Server.PurgeDisconnectedClientsInterval)
} else {
s.Debugf("Task to purge disconnected clients disabled")
}
//Run a task to Check the client connections status by sending and receiving pings
go scheduler.Run(ctx, s.Logger, NewClientsStatusCheckTask(
clientsStatusCheckTask := NewClientsStatusCheckTask(
s.Logger,
s.clientListener.clientService.GetRepo(),
s.clientListener.server.clientService.GetRepo(),
s.config.Server.CheckClientsConnectionInterval,
s.config.Server.CheckClientsConnectionTimeout,
), s.config.Server.CheckClientsConnectionInterval)
)
go scheduler.Run(ctx, s.Logger.Fork(fmt.Sprintf("task %T", clientsStatusCheckTask)), clientsStatusCheckTask, s.config.Server.CheckClientsConnectionInterval)
s.Infof("Task to check the clients connection status will run with interval %v", s.config.Server.CheckClientsConnectionInterval)
if s.config.Monitoring.Enabled {
@ -334,16 +344,19 @@ func (s *Server) Run(ctx context.Context) error {
cleaningPeriod = s.config.Monitoring.GetDataStorageDuration()
}
go scheduler.Run(ctx, s.Logger, monitoring.NewCleanupTask(s.Logger, s.monitoringService, cleaningPeriod), cleanupMeasurementsInterval)
monitoringCleanupTask := monitoring.NewCleanupTask(s.Logger, s.monitoringService, cleaningPeriod)
go scheduler.Run(ctx, s.Logger.Fork(fmt.Sprintf("task %T", monitoringCleanupTask)), monitoringCleanupTask, cleanupMeasurementsInterval)
s.Infof("Task to cleanup measurements will run with interval %v", cleanupMeasurementsInterval)
} else {
s.Infof("Measurement disabled")
}
go scheduler.Run(ctx, s.Logger, session.NewCleanupTask(s.apiListener.apiSessions), cleanupAPISessionsInterval)
sessionsCleanupTask := session.NewCleanupTask(s.apiListener.apiSessions)
go scheduler.Run(ctx, s.Logger.Fork(fmt.Sprintf("task %T", sessionsCleanupTask)), sessionsCleanupTask, cleanupAPISessionsInterval)
s.Infof("Task to cleanup expired api sessions will run with interval %v", cleanupAPISessionsInterval)
go scheduler.Run(ctx, s.Logger, jobs.NewCleanupTask(s.jobProvider, s.config.Server.JobsMaxResults), cleanupJobsInterval)
jobsCleanupTask := jobs.NewCleanupTask(s.jobProvider, s.config.Server.JobsMaxResults)
go scheduler.Run(ctx, s.Logger.Fork(fmt.Sprintf("task %T", jobsCleanupTask)), jobsCleanupTask, cleanupJobsInterval)
s.Infof("Task to cleanup jobs will run with interval %v", cleanupJobsInterval)
// Only on debug mode, log the number of running go routines

View File

@ -217,8 +217,9 @@ func (al *APIListener) sendFileToClients(uploadRequest *UploadRequest) {
func (al *APIListener) consumeUploadResults(resChan chan *uploadResult, uploadRequest *UploadRequest) {
for res := range resChan {
clientID := res.client.GetID()
output := &UploadOutput{
ClientID: res.client.ID,
ClientID: clientID,
UploadResponse: res.resp,
}
if res.err != nil {
@ -238,7 +239,7 @@ func (al *APIListener) consumeUploadResults(resChan chan *uploadResult, uploadRe
errTxt,
uploadRequest.ID,
uploadRequest.DestinationPath,
res.client.ID,
clientID,
)
al.auditLog.Entry(auditlog.ApplicationUploads, auditlog.ActionFailed).
WithRequest(uploadRequest.UploadedFile).
@ -251,7 +252,7 @@ func (al *APIListener) consumeUploadResults(resChan chan *uploadResult, uploadRe
"upload success, file id: %s, file path: %s, client %s",
uploadRequest.ID,
uploadRequest.DestinationPath,
res.client.ID,
clientID,
)
al.auditLog.Entry(auditlog.ApplicationUploads, auditlog.ActionSuccess).
WithRequest(uploadRequest.UploadedFile).
@ -268,7 +269,8 @@ func (al *APIListener) consumeUploadResults(resChan chan *uploadResult, uploadRe
func (al *APIListener) sendFileToClient(wg *sync.WaitGroup, file *models.UploadedFile, cl *clients.Client, resChan chan *uploadResult) {
defer wg.Done()
if cl.ClientConfiguration != nil && !cl.ClientConfiguration.FileReceptionConfig.Enabled {
fileReceptionConfig := cl.GetFileReceptionConfig()
if fileReceptionConfig != nil && !fileReceptionConfig.Enabled {
resChan <- &uploadResult{
err: errors3.ErrUploadsDisabled,
client: cl,
@ -277,7 +279,7 @@ func (al *APIListener) sendFileToClient(wg *sync.WaitGroup, file *models.Uploade
return
}
resp := &models.UploadResponse{}
err := comm.SendRequestAndGetResponse(cl.Connection, comm.RequestTypeUpload, file, resp)
err := comm.SendRequestAndGetResponse(cl.GetConnection(), comm.RequestTypeUpload, file, resp, al.Log())
resChan <- &uploadResult{
err: err,

View File

@ -85,7 +85,7 @@ func TestHandleFileUploads(t *testing.T) {
useFsCallback: true,
fileName: "file.txt",
fileContent: "some content",
cl: clients.New(t).ID("22114341234").Build(),
cl: clients.New(t).ID("22114341234").Logger(testLog).Build(),
formParts: map[string][]string{
"client_id": {
"22114341234",
@ -137,7 +137,7 @@ func TestHandleFileUploads(t *testing.T) {
useFsCallback: true,
fileName: "file.txt",
fileContent: "some content",
cl: clients.New(t).ID("22114341234").Build(),
cl: clients.New(t).ID("22114341234").Logger(testLog).Build(),
clientTags: []string{"linux"},
formParts: map[string][]string{
"tags": {
@ -189,7 +189,7 @@ func TestHandleFileUploads(t *testing.T) {
useFsCallback: true,
fileName: "file.txt",
fileContent: "some content",
cl: clients.New(t).ID("22114341234").Build(),
cl: clients.New(t).ID("22114341234").Logger(testLog).Build(),
formParts: map[string][]string{
"client_id": {
"22114341234",
@ -220,7 +220,7 @@ func TestHandleFileUploads(t *testing.T) {
useFsCallback: true,
fileName: "file.txt",
fileContent: "some content",
cl: clients.New(t).ID("22114341234").Build(),
cl: clients.New(t).ID("22114341234").Logger(testLog).Build(),
formParts: map[string][]string{
"dest": {
"/destination/myfile.txt",
@ -242,7 +242,7 @@ func TestHandleFileUploads(t *testing.T) {
useFsCallback: true,
fileName: "file.txt",
fileContent: "some content",
cl: clients.New(t).ID("22114341234").Build(),
cl: clients.New(t).ID("22114341234").Logger(testLog).Build(),
formParts: map[string][]string{
"tags": {
`{
@ -270,7 +270,7 @@ func TestHandleFileUploads(t *testing.T) {
useFsCallback: true,
fileName: "file.txt",
fileContent: "some content",
cl: clients.New(t).ID("22114341234").Build(),
cl: clients.New(t).ID("22114341234").Logger(testLog).Build(),
formParts: map[string][]string{
"client_id": {
"22114341234",
@ -296,7 +296,7 @@ func TestHandleFileUploads(t *testing.T) {
useFsCallback: true,
fileName: "file.txt",
fileContent: "some content",
cl: clients.New(t).ID("22114341234").Build(),
cl: clients.New(t).ID("22114341234").Logger(testLog).Build(),
formParts: map[string][]string{
"client_id": {
"22114341234",
@ -321,7 +321,7 @@ func TestHandleFileUploads(t *testing.T) {
cl := tc.cl
if tc.clientTags != nil {
cl.Tags = tc.clientTags
cl.SetTags(tc.clientTags)
}
connMock := test.NewConnMock()
@ -331,7 +331,7 @@ func TestHandleFileUploads(t *testing.T) {
done := make(chan bool)
connMock.DoneChannel = done
cl.Connection = connMock
cl.SetConnection(connMock)
fileAPIMock := test.NewFileAPIMock()
if tc.useFsCallback {

View File

@ -4,10 +4,12 @@ import (
"time"
"golang.org/x/crypto/ssh"
"github.com/cloudradar-monitoring/rport/share/logger"
)
func PingConnectionWithTimeout(conn ssh.Conn, timeout time.Duration) (ok bool, response []byte, rtt time.Duration, err error) {
func PingConnectionWithTimeout(conn ssh.Conn, timeout time.Duration, l *logger.Logger) (ok bool, response []byte, rtt time.Duration, err error) {
timerStart := time.Now()
ok, response, err = SendRequestWithTimeout(conn, RequestTypePing, true, nil, timeout)
ok, response, err = SendRequestWithTimeout(conn, RequestTypePing, true, nil, timeout, l)
return ok, response, time.Since(timerStart), err
}

View File

@ -4,7 +4,9 @@ package comm
import (
"context"
"encoding/json"
"errors"
"fmt"
"math/rand"
"time"
"golang.org/x/crypto/ssh"
@ -46,7 +48,7 @@ func ReplySuccessJSON(log *logger.Logger, req *ssh.Request, resp interface{}) {
// SendRequestAndGetResponse sends a given request, parses a returned response and stores a success result in a given destination value.
// Returns an error on a failure response or if an error happen. Error will be ClientError type if the error is a client error.
// Both request and response are expected to be JSON.
func SendRequestAndGetResponse(conn ssh.Conn, reqType string, req, successRespDest interface{}) error {
func SendRequestAndGetResponse(conn ssh.Conn, reqType string, req, successRespDest interface{}, l *logger.Logger) error {
reqBytes, err := json.Marshal(req)
if err != nil {
return fmt.Errorf("failed to encode request %T: %v", req, err)
@ -63,6 +65,8 @@ func SendRequestAndGetResponse(conn ssh.Conn, reqType string, req, successRespDe
if successRespDest != nil {
if err := json.Unmarshal(respBytes, successRespDest); err != nil {
l.Debugf("failed to unmarshal: respBytes: %s", string(respBytes))
l.Debugf("%#v", successRespDest)
return NewClientError(fmt.Errorf("invalid client response format: failed to decode response into %T: %v", successRespDest, err))
}
}
@ -85,30 +89,77 @@ func (e *ClientError) Error() string {
return e.err.Error()
}
func SendRequestWithTimeout(conn ssh.Conn, name string, wantReplay bool, payload []byte, timeout time.Duration) (bool, []byte, error) {
type requestResponse struct {
ok bool
response []byte
err error
}
func SendRequestWithTimeout(conn ssh.Conn, name string, wantReply bool, payload []byte, timeout time.Duration, l *logger.Logger) (bool, []byte, error) {
var (
ok bool
response []byte
err error
)
if conn == nil {
return false, nil, errors.New("cannot send request when conn is nil")
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ch := make(chan bool, 1)
ch := make(chan requestResponse, 1)
go func() {
ok, response, err = conn.SendRequest(name, wantReplay, payload)
ok, response, err = conn.SendRequest(name, wantReply, payload)
select {
default:
ch <- true
ch <- requestResponse{
ok: ok,
response: response,
err: err,
}
case <-ctx.Done():
l.Debugf("send canceled")
return
}
}()
reqTimeout := time.NewTimer(timeout)
defer reqTimeout.Stop()
select {
case <-ch:
return ok, response, err
case res := <-ch:
reqTimeout.Stop()
return res.ok, res.response, res.err
case <-reqTimeout.C:
return false, nil, TimeoutError{fmt.Errorf("conn.SendRequest(%s), timeout %s exceeded", name, timeout)}
}
}
const DefaultMaxRetryAttempts = 3
func WithRetry[R any](retryAbleFn func() (result R, err error), canRetry func(err error) (shouldRetry bool), minRetryWaitDuration time.Duration, label string, l *logger.Logger) (result R, err error) {
for r := 0; r < DefaultMaxRetryAttempts; r++ {
attempt := r + 1
// l.Debugf("%s: attempt %d", label, attempt)
if r > 0 {
// backoff with some jitter
delay := (minRetryWaitDuration * time.Duration(r*r)) + time.Duration(rand.Intn(1000))*time.Millisecond
l.Debugf("%s: attempt %d failed. will sleep for: %d seconds", label, attempt, delay/time.Second)
time.Sleep(delay)
}
result, err = retryAbleFn()
if err != nil {
l.Debugf("%s: attempt %d err = %+v\n", label, attempt, err)
if !canRetry(err) {
// non retryable err
l.Debugf("%s: attempt %d non-retryable err %v", label, attempt, err)
return result, err
}
continue
}
// success
return result, nil
}
return result, err
}

View File

@ -5,7 +5,7 @@ import (
)
func SetFromRange(start, end int) mapset.Set {
s := mapset.NewThreadUnsafeSet()
s := mapset.NewSet()
for i := 0; i <= end-start; i++ {
s.Add(start + i)
}