mirror of
https://github.com/openrport/openrport.git
synced 2025-10-26 11:27:11 +00:00
Initial version with improved server concurrency and better client retry
handling align client service param names with master
This commit is contained in:
parent
0e8ea01a65
commit
5bfab9fac7
2
Makefile
2
Makefile
@ -12,7 +12,7 @@ build-debug:
|
||||
$(foreach BINARY,$(BINARIES),go build -race -gcflags "all=-N -l" -o $(BINARY) -v ./cmd/$(BINARY)/...;)
|
||||
|
||||
test:
|
||||
go test -v ./...
|
||||
go test -race -v ./...
|
||||
|
||||
bind-data:
|
||||
cd db/migration/jobs/sql/ && go-bindata -o ../bindata.go -pkg jobs ./...
|
||||
|
||||
306
client/client.go
306
client/client.go
@ -34,7 +34,16 @@ import (
|
||||
"github.com/cloudradar-monitoring/rport/share/models"
|
||||
)
|
||||
|
||||
const ConnectionTimeout = 10 * time.Second
|
||||
const DialTimeout = 5 * 60 * time.Second
|
||||
const AuthTimeout = 30 * time.Second
|
||||
const MinConnectionBackoffWaitTime = 5 * time.Second
|
||||
const MaxConnectionBackoffWaitTime = 10 * 60 * time.Second
|
||||
const ServerReconnectRequestBackoffTime = 3 * 60 * time.Second
|
||||
const InitialConnectionRequestSendDelayJitterMilliseconds = 10000
|
||||
const SendRequestTimeout = 30 * time.Second
|
||||
const MinSendRequestRetryWaitTime = 1 * time.Second
|
||||
const BackoffOnServerTimeoutMaxDuration = 1 * time.Second
|
||||
const MaxKeepAliveJitterMilliseconds = 5000
|
||||
|
||||
// Client represents a client instance
|
||||
type Client struct {
|
||||
@ -66,13 +75,14 @@ func NewClient(config *ClientConfigHolder, filesAPI files.FileAPI) (*Client, err
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create initial session id: %s", err)
|
||||
}
|
||||
|
||||
cmdExec := system.NewCmdExecutor(logger.NewLogger("cmd executor", config.Logging.LogOutput, config.Logging.LogLevel))
|
||||
logger := logger.NewLogger("client", config.Logging.LogOutput, config.Logging.LogLevel)
|
||||
watchdog, err := NewWatchdog(config.Connection.WatchdogIntegration, config.Client.DataDir, logger)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create watchdog: %s", err)
|
||||
}
|
||||
logger.Infof("Client started with sessionID %s", sessionID)
|
||||
|
||||
systemInfo := system.NewSystemInfo(cmdExec)
|
||||
client := &Client{
|
||||
SessionID: sessionID,
|
||||
@ -93,9 +103,10 @@ func NewClient(config *ClientConfigHolder, filesAPI files.FileAPI) (*Client, err
|
||||
Auth: []ssh.AuthMethod{ssh.Password(config.Client.AuthPass)},
|
||||
ClientVersion: "SSH-" + chshare.ProtocolVersion + "-client",
|
||||
HostKeyCallback: client.verifyServer,
|
||||
Timeout: 30 * time.Second,
|
||||
Timeout: AuthTimeout,
|
||||
}
|
||||
|
||||
logger.Infof("New client instance with sessionID %s", sessionID)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
@ -127,8 +138,9 @@ func (c *Client) Start(ctx context.Context) error {
|
||||
c.Infof("Keepalive job (client to server ping) started with interval %s", c.configHolder.Connection.KeepAlive)
|
||||
go c.keepAliveLoop()
|
||||
}
|
||||
|
||||
//connection loop
|
||||
go c.connectionLoop(ctx)
|
||||
go c.connectionLoop(ctx, true)
|
||||
|
||||
c.updates.Start(ctx)
|
||||
|
||||
@ -137,19 +149,28 @@ func (c *Client) Start(ctx context.Context) error {
|
||||
|
||||
func (c *Client) keepAliveLoop() {
|
||||
for c.running {
|
||||
time.Sleep(c.configHolder.Connection.KeepAlive)
|
||||
time.Sleep(c.configHolder.Connection.KeepAlive + (time.Duration(rand.Intn(MaxKeepAliveJitterMilliseconds)))*time.Millisecond)
|
||||
|
||||
c.mu.RLock()
|
||||
conn := c.sshConn
|
||||
c.mu.RUnlock()
|
||||
|
||||
if conn != nil {
|
||||
ok, _, rtt, err := comm.PingConnectionWithTimeout(conn, c.configHolder.Connection.KeepAliveTimeout)
|
||||
if err != nil || !ok {
|
||||
|
||||
res, err := comm.WithRetry(func() (res *sendResponse, err error) {
|
||||
ok, _, rtt, err := comm.PingConnectionWithTimeout(conn, c.configHolder.Connection.KeepAliveTimeout, c.Logger)
|
||||
return &sendResponse{
|
||||
replyOk: ok,
|
||||
rtt: rtt,
|
||||
respBytes: nil,
|
||||
}, err
|
||||
}, canRetryFn, MinSendRequestRetryWaitTime, "ping", c.Logger)
|
||||
|
||||
if err != nil || !res.replyOk {
|
||||
c.Errorf("Failed to send keepalive (client to server ping): %s", err)
|
||||
c.sshConn.Close()
|
||||
} else {
|
||||
msg := fmt.Sprintf("ping to %s succeeded within %s", conn.RemoteAddr(), rtt)
|
||||
msg := fmt.Sprintf("ping to %s succeeded within %s", conn.RemoteAddr(), res.rtt)
|
||||
c.Debugf(msg)
|
||||
c.watchdog.Ping(WatchdogStateConnected, msg)
|
||||
}
|
||||
@ -157,32 +178,28 @@ func (c *Client) keepAliveLoop() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) connectionLoop(ctx context.Context) {
|
||||
func (c *Client) connectionLoop(ctx context.Context, withInitialSendRequestDelay bool) {
|
||||
//connection loop!
|
||||
var connerr error
|
||||
switchbackChan := make(chan *sshClientConn, 1)
|
||||
b := &backoff.Backoff{Max: c.configHolder.Connection.MaxRetryInterval}
|
||||
backoff := &backoff.Backoff{
|
||||
Min: MinConnectionBackoffWaitTime + time.Duration(rand.Intn(60)),
|
||||
Max: MaxConnectionBackoffWaitTime,
|
||||
Jitter: true,
|
||||
}
|
||||
|
||||
for c.running {
|
||||
if connerr != nil {
|
||||
attempt := int(b.Attempt())
|
||||
var d = b.Duration()
|
||||
c.showConnectionError(connerr, attempt)
|
||||
if c.configHolder.Connection.MaxRetryCount >= 0 && attempt >= c.configHolder.Connection.MaxRetryCount {
|
||||
break // Stop trying to connect if the user has set a max retry limit
|
||||
stopRetrying := c.handleConnectionError(backoff, connerr)
|
||||
if stopRetrying {
|
||||
break
|
||||
}
|
||||
if _, ok := connerr.(comm.TimeoutError); ok {
|
||||
// Timeout means the server is available. No need to wait up to 5 min to try again.
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
d = time.Duration(rand.Intn(20)) * time.Second
|
||||
b.Reset()
|
||||
}
|
||||
msg := fmt.Sprintf("Retrying in %s...", d)
|
||||
c.Infof(msg)
|
||||
c.watchdog.Ping(WatchdogStateReconnecting, msg)
|
||||
connerr = nil
|
||||
chshare.SleepSignal(d)
|
||||
}
|
||||
|
||||
c.Logger.Debugf("conn loop attempt = %d", int(backoff.Attempt()))
|
||||
|
||||
// make the connection attempt
|
||||
var sshConn *sshClientConn
|
||||
var isPrimary bool
|
||||
select {
|
||||
@ -204,46 +221,36 @@ func (c *Client) connectionLoop(ctx context.Context) {
|
||||
|
||||
switchbackCtx, cancelSwitchback := context.WithCancel(ctx)
|
||||
if !isPrimary {
|
||||
go func() {
|
||||
for {
|
||||
switchbackTimer := time.NewTimer(c.configHolder.Client.ServerSwitchbackInterval)
|
||||
select {
|
||||
case <-switchbackCtx.Done():
|
||||
switchbackTimer.Stop()
|
||||
return
|
||||
case <-switchbackTimer.C:
|
||||
switchbackConn, err := c.connect(c.configHolder.Client.Server)
|
||||
if err != nil {
|
||||
c.Errorf("Switchback failed: %v", err.Error())
|
||||
continue
|
||||
}
|
||||
c.Infof("Connected to main server, switching back.")
|
||||
switchbackChan <- switchbackConn
|
||||
sshConn.Connection.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
go c.handleServerSwitchBack(switchbackCtx, switchbackChan, sshConn)
|
||||
}
|
||||
|
||||
err := c.sendConnectionRequest(ctx, sshConn.Connection)
|
||||
if withInitialSendRequestDelay {
|
||||
delay := time.Duration(rand.Intn(InitialConnectionRequestSendDelayJitterMilliseconds)) * time.Millisecond
|
||||
c.Logger.Debugf("waiting for %d milliseconds before sending connection request", delay/time.Millisecond)
|
||||
time.Sleep(delay)
|
||||
}
|
||||
|
||||
err := c.sendConnectionRequest(ctx, sshConn.Connection, MinSendRequestRetryWaitTime)
|
||||
if err != nil {
|
||||
// Connection request has failed, we try again
|
||||
cancelSwitchback()
|
||||
connerr = err
|
||||
continue
|
||||
}
|
||||
// Connection request has succeeded
|
||||
b.Reset()
|
||||
|
||||
// Connection request has succeeded
|
||||
backoff.Reset()
|
||||
|
||||
// Hand over the open SSH connection to the client
|
||||
c.mu.Lock()
|
||||
c.sshConn = sshConn.Connection // Hand over the open SSH connection to the client
|
||||
c.sshConn = sshConn.Connection
|
||||
c.mu.Unlock()
|
||||
|
||||
c.updates.SetConn(sshConn.Connection)
|
||||
c.monitor.SetConn(sshConn.Connection)
|
||||
|
||||
err = sshConn.Connection.Wait() // Block aka wait until the connection is closed
|
||||
// now wait with the client handling SSH Requests and Channel Connections
|
||||
err = sshConn.Connection.Wait()
|
||||
|
||||
c.mu.Lock()
|
||||
//disconnected
|
||||
@ -262,9 +269,57 @@ func (c *Client) connectionLoop(ctx context.Context) {
|
||||
|
||||
c.Infof("Disconnected\n")
|
||||
}
|
||||
|
||||
close(c.runningc)
|
||||
}
|
||||
|
||||
func (c *Client) handleConnectionError(backoff *backoff.Backoff, connerr error) (stopRetrying bool) {
|
||||
attempt := int(backoff.Attempt())
|
||||
|
||||
c.showConnectionError(connerr, attempt)
|
||||
|
||||
// check if the user has set a max retry limit
|
||||
if c.configHolder.Connection.MaxRetryCount >= 0 && attempt >= c.configHolder.Connection.MaxRetryCount {
|
||||
return true // if so, stop trying
|
||||
}
|
||||
|
||||
var d = backoff.Duration()
|
||||
if _, ok := connerr.(comm.TimeoutError); ok {
|
||||
// Timeout means the server isn't offline, so reset the backoff and use an initial short retry duration
|
||||
backoff.Reset()
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
d = time.Duration(rand.Intn(int(backoff.Attempt()))) * BackoffOnServerTimeoutMaxDuration
|
||||
}
|
||||
msg := fmt.Sprintf("Retrying in %s...", d)
|
||||
c.Infof(msg)
|
||||
// TODO: (rs): what is this watchdog ping?
|
||||
c.watchdog.Ping(WatchdogStateReconnecting, msg)
|
||||
chshare.SleepSignal(d)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *Client) handleServerSwitchBack(switchbackCtx context.Context, switchbackChan chan *sshClientConn, sshConn *sshClientConn) {
|
||||
for {
|
||||
switchbackTimer := time.NewTimer(c.configHolder.Client.ServerSwitchbackInterval)
|
||||
select {
|
||||
case <-switchbackCtx.Done():
|
||||
switchbackTimer.Stop()
|
||||
return
|
||||
case <-switchbackTimer.C:
|
||||
switchbackConn, err := c.connect(c.configHolder.Client.Server)
|
||||
if err != nil {
|
||||
c.Errorf("Switchback failed: %v", err.Error())
|
||||
continue
|
||||
}
|
||||
c.Infof("Connected to main server, switching back.")
|
||||
switchbackChan <- switchbackConn
|
||||
sshConn.Connection.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type sshClientConn struct {
|
||||
Connection ssh.Conn
|
||||
Channels <-chan ssh.NewChannel
|
||||
@ -290,54 +345,24 @@ func (c *Client) connect(server string) (*sshClientConn, error) {
|
||||
}
|
||||
c.Infof("Trying to connect to %s%s ...\n", server, via)
|
||||
|
||||
netDialer := &net.Dialer{}
|
||||
d := websocket.Dialer{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
HandshakeTimeout: 45 * time.Second,
|
||||
Subprotocols: []string{chshare.ProtocolVersion},
|
||||
NetDialContext: netDialer.DialContext,
|
||||
d, netDialer, err := c.setupDialer()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if c.configHolder.Client.BindInterface != "" {
|
||||
laddr, err := c.localAddrForInterface(c.configHolder.Client.BindInterface)
|
||||
|
||||
//optionally proxy
|
||||
if c.configHolder.Client.ProxyURL != nil {
|
||||
err := c.addDialerProxySupport(d, netDialer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
netDialer.LocalAddr = laddr
|
||||
}
|
||||
//optionally proxy
|
||||
if c.configHolder.Client.ProxyURL != nil {
|
||||
if strings.HasPrefix(c.configHolder.Client.ProxyURL.Scheme, "socks") {
|
||||
// SOCKS5 proxy
|
||||
if c.configHolder.Client.ProxyURL.Scheme != "socks" && c.configHolder.Client.ProxyURL.Scheme != "socks5h" {
|
||||
return nil, fmt.Errorf(
|
||||
"unsupported socks proxy type: %s:// (only socks5h:// or socks:// is supported)",
|
||||
c.configHolder.Client.ProxyURL.Scheme)
|
||||
}
|
||||
var auth *proxy.Auth
|
||||
if c.configHolder.Client.ProxyURL.User != nil {
|
||||
pass, _ := c.configHolder.Client.ProxyURL.User.Password()
|
||||
auth = &proxy.Auth{
|
||||
User: c.configHolder.Client.ProxyURL.User.Username(),
|
||||
Password: pass,
|
||||
}
|
||||
}
|
||||
socksDialer, err := proxy.SOCKS5("tcp", c.configHolder.Client.ProxyURL.Host, auth, netDialer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.NetDialContext = socksDialer.(proxy.ContextDialer).DialContext
|
||||
} else {
|
||||
// CONNECT proxy
|
||||
d.Proxy = func(*http.Request) (*url.URL, error) {
|
||||
return c.configHolder.Client.ProxyURL, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
wsConn, _, err := d.Dial(server, c.configHolder.Connection.HTTPHeaders)
|
||||
if err != nil {
|
||||
return nil, ConnectionErrorHints(server, c.Logger, err)
|
||||
}
|
||||
|
||||
conn := chshare.NewWebSocketConn(wsConn)
|
||||
// perform SSH handshake on net.Conn
|
||||
c.Debugf("Handshaking...")
|
||||
@ -349,6 +374,7 @@ func (c *Client) connect(server string) (*sshClientConn, error) {
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &sshClientConn{
|
||||
Connection: sshConn,
|
||||
Requests: reqs,
|
||||
@ -356,7 +382,69 @@ func (c *Client) connect(server string) (*sshClientConn, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) sendConnectionRequest(ctx context.Context, sshConn ssh.Conn) error {
|
||||
func (c *Client) setupDialer() (d *websocket.Dialer, netDialer *net.Dialer, err error) {
|
||||
netDialer = &net.Dialer{}
|
||||
d = &websocket.Dialer{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
HandshakeTimeout: DialTimeout,
|
||||
Subprotocols: []string{chshare.ProtocolVersion},
|
||||
NetDialContext: netDialer.DialContext,
|
||||
}
|
||||
if c.configHolder.Client.BindInterface != "" {
|
||||
laddr, err := c.localAddrForInterface(c.configHolder.Client.BindInterface)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
netDialer.LocalAddr = laddr
|
||||
}
|
||||
|
||||
return d, netDialer, err
|
||||
}
|
||||
|
||||
func (c *Client) addDialerProxySupport(d *websocket.Dialer, netDialer *net.Dialer) (err error) {
|
||||
if strings.HasPrefix(c.configHolder.Client.ProxyURL.Scheme, "socks") {
|
||||
// SOCKS5 proxy
|
||||
if c.configHolder.Client.ProxyURL.Scheme != "socks" && c.configHolder.Client.ProxyURL.Scheme != "socks5h" {
|
||||
return fmt.Errorf(
|
||||
"unsupported socks proxy type: %s:// (only socks5h:// or socks:// is supported)",
|
||||
c.configHolder.Client.ProxyURL.Scheme)
|
||||
}
|
||||
var auth *proxy.Auth
|
||||
if c.configHolder.Client.ProxyURL.User != nil {
|
||||
pass, _ := c.configHolder.Client.ProxyURL.User.Password()
|
||||
auth = &proxy.Auth{
|
||||
User: c.configHolder.Client.ProxyURL.User.Username(),
|
||||
Password: pass,
|
||||
}
|
||||
}
|
||||
socksDialer, err := proxy.SOCKS5("tcp", c.configHolder.Client.ProxyURL.Host, auth, netDialer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.NetDialContext = socksDialer.(proxy.ContextDialer).DialContext
|
||||
} else {
|
||||
// CONNECT proxy
|
||||
d.Proxy = func(*http.Request) (*url.URL, error) {
|
||||
return c.configHolder.Client.ProxyURL, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type sendResponse struct {
|
||||
replyOk bool
|
||||
respBytes []byte
|
||||
rtt time.Duration
|
||||
}
|
||||
|
||||
func canRetryFn(err error) (can bool) {
|
||||
// if a timeout err, retry on the existing connection
|
||||
return strings.Contains(err.Error(), "timeout")
|
||||
}
|
||||
|
||||
func (c *Client) sendConnectionRequest(ctx context.Context, sshConn ssh.Conn, minRetryWaitDuration time.Duration) error {
|
||||
connReq, err := c.connectionRequest(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -366,19 +454,38 @@ func (c *Client) sendConnectionRequest(ctx context.Context, sshConn ssh.Conn) er
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not encode connection request: %v", err)
|
||||
}
|
||||
|
||||
c.Infof("Sending connection request.")
|
||||
c.Debugf("Sending connection request with client details %s", string(req))
|
||||
t0 := time.Now()
|
||||
replyOk, respBytes, err := comm.SendRequestWithTimeout(sshConn, "new_connection", true, req, ConnectionTimeout)
|
||||
|
||||
res, err := comm.WithRetry(func() (res *sendResponse, err error) {
|
||||
replyOk, respBytes, err := comm.SendRequestWithTimeout(sshConn, "new_connection", true, req, SendRequestTimeout, c.Logger)
|
||||
return &sendResponse{
|
||||
replyOk: replyOk,
|
||||
respBytes: respBytes,
|
||||
}, err
|
||||
}, canRetryFn, minRetryWaitDuration, "Connection Request", c.Logger)
|
||||
|
||||
if err != nil {
|
||||
if err2 := sshConn.Close(); err2 != nil {
|
||||
c.Errorf("Failed to close connection: %s", err2)
|
||||
c.Errorf("connection request err = %v", err)
|
||||
if closeErr := sshConn.Close(); closeErr != nil {
|
||||
c.Errorf("Failed to close connection: %s", closeErr)
|
||||
}
|
||||
reconnect := strings.Contains(err.Error(), "reconnect")
|
||||
if reconnect {
|
||||
reconnectDelay := ServerReconnectRequestBackoffTime + (time.Duration(rand.Intn(30)) * time.Second)
|
||||
c.Debugf("waiting %d seconds before reconnect", reconnectDelay/time.Second)
|
||||
// this probably means the server is too busy for us. wait quite a while
|
||||
// before returning to the conn loop.
|
||||
time.Sleep(reconnectDelay)
|
||||
}
|
||||
return err
|
||||
}
|
||||
c.Debugf("Connection request has been answered successfully within %s.", time.Since(t0))
|
||||
if !replyOk {
|
||||
msg := string(respBytes)
|
||||
|
||||
c.Debugf("Connection request has been answered within %s.", time.Since(t0))
|
||||
if !res.replyOk {
|
||||
msg := string(res.respBytes)
|
||||
|
||||
// if replied with client credentials already used - retry
|
||||
if strings.Contains(msg, "client is already connected:") {
|
||||
@ -390,14 +497,17 @@ func (c *Client) sendConnectionRequest(ctx context.Context, sshConn ssh.Conn) er
|
||||
|
||||
return errors.New(msg)
|
||||
}
|
||||
|
||||
var remotes []*models.Remote
|
||||
err = json.Unmarshal(respBytes, &remotes)
|
||||
err = json.Unmarshal(res.respBytes, &remotes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("can't decode reply payload: %s", err)
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf("Connected to %s within %s", sshConn.RemoteAddr().String(), time.Since(t0))
|
||||
c.watchdog.Ping(WatchdogStateConnected, msg)
|
||||
c.Infof(msg)
|
||||
|
||||
for _, r := range remotes {
|
||||
c.Infof("New tunnel: %s", r.String())
|
||||
|
||||
@ -453,12 +563,12 @@ func (c *Client) handleSSHRequests(ctx context.Context, sshConn *sshClientConn)
|
||||
sshConn.Connection,
|
||||
system.SysUserProvider{},
|
||||
)
|
||||
|
||||
resp, err = uploadManager.HandleUploadRequest(r.Payload)
|
||||
case comm.RequestTypeCheckTunnelAllowed:
|
||||
resp, err = c.checkTunnelAllowed(r.Payload)
|
||||
case comm.RequestTypePing:
|
||||
_ = r.Reply(true, nil)
|
||||
continue
|
||||
default:
|
||||
c.Debugf("Unknown request: %q", r.Type)
|
||||
comm.ReplyError(c.Logger, r, errors.New("unknown request"))
|
||||
|
||||
@ -428,11 +428,11 @@ func (m *mockServer) IsConnected() bool {
|
||||
}
|
||||
|
||||
func (m *mockServer) WaitForStatus(isConnected bool) error {
|
||||
for i := 0; i < 1500; i++ {
|
||||
for i := 0; i < 60; i++ {
|
||||
if m.IsConnected() == isConnected {
|
||||
return nil
|
||||
}
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
}
|
||||
return fmt.Errorf("timeout waiting for isConnected=%v", isConnected)
|
||||
}
|
||||
@ -492,7 +492,7 @@ func TestConnectionLoop(t *testing.T) {
|
||||
c, err := NewClient(&config, test.NewFileAPIMock())
|
||||
require.NoError(t, err)
|
||||
|
||||
go c.connectionLoop(context.Background())
|
||||
go c.connectionLoop(context.Background(), false)
|
||||
|
||||
// connects to main server successfully
|
||||
assert.NoError(t, mainServer.WaitForStatus(true))
|
||||
|
||||
@ -27,7 +27,7 @@ func ConnectionErrorHints(server string, logger *logger.Logger, err error) error
|
||||
if allHeaders != "" {
|
||||
logger.Debugf("headers collected while detecting proxy: %s", allHeaders)
|
||||
}
|
||||
return fmt.Errorf("%s - Check your client credentials AND check for tranparent proxies", err)
|
||||
return fmt.Errorf("%s - Server maybe busy. Also check your client credentials AND check for tranparent proxies", err)
|
||||
default:
|
||||
return err
|
||||
}
|
||||
@ -26,21 +26,22 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultKeepDisconnectedClients = time.Hour
|
||||
DefaultPurgeDisconnectedClientsInterval = 1 * time.Minute
|
||||
DefaultCheckClientsConnectionInterval = 5 * time.Minute
|
||||
DefaultCheckClientsConnectionTimeout = 30 * time.Second
|
||||
DefaultMaxRequestBytes = 10 * 1024 // 10 KB
|
||||
DefaultMaxRequestBytesClient = 512 * 1024 // 512KB
|
||||
DefaultMaxFilePushBytes = int64(10 << 20) // 10M
|
||||
DefaultCheckPortTimeout = 2 * time.Second
|
||||
DefaultUsedPorts = "20000-30000"
|
||||
DefaultExcludedPorts = "1-1024"
|
||||
DefaultServerAddress = "0.0.0.0:8080"
|
||||
DefaultLogLevel = "info"
|
||||
DefaultRunRemoteCmdTimeoutSec = 60
|
||||
DefaultMonitoringDataStorageDuration = "7d"
|
||||
DefaultPairingURL = "https://pairing.rport.io"
|
||||
DefaultMaxConcurrentSSHConnectionHandshakes = 4
|
||||
DefaultKeepDisconnectedClients = time.Hour
|
||||
DefaultPurgeDisconnectedClientsInterval = 1 * time.Minute
|
||||
DefaultCheckClientsConnectionInterval = 5 * time.Minute
|
||||
DefaultCheckClientsConnectionTimeout = 30 * time.Second
|
||||
DefaultMaxRequestBytes = 10 * 1024 // 10 KB
|
||||
DefaultMaxRequestBytesClient = 512 * 1024 // 512KB
|
||||
DefaultMaxFilePushBytes = int64(10 << 20) // 10M
|
||||
DefaultCheckPortTimeout = 2 * time.Second
|
||||
DefaultUsedPorts = "20000-30000"
|
||||
DefaultExcludedPorts = "1-1024"
|
||||
DefaultServerAddress = "0.0.0.0:8080"
|
||||
DefaultLogLevel = "info"
|
||||
DefaultRunRemoteCmdTimeoutSec = 60
|
||||
DefaultMonitoringDataStorageDuration = "7d"
|
||||
DefaultPairingURL = "https://pairing.rport.io"
|
||||
)
|
||||
|
||||
var serverHelp = `
|
||||
@ -314,6 +315,7 @@ func init() {
|
||||
viperCfg.SetDefault("server.data_dir", chserver.DefaultDataDirectory)
|
||||
viperCfg.SetDefault("server.sqlite_wal", true)
|
||||
viperCfg.SetDefault("server.keep_disconnected_clients", DefaultKeepDisconnectedClients)
|
||||
viperCfg.SetDefault("server.max_concurrent_ssh_handshakes", DefaultMaxConcurrentSSHConnectionHandshakes)
|
||||
viperCfg.SetDefault("server.purge_disconnected_clients_interval", DefaultPurgeDisconnectedClientsInterval)
|
||||
viperCfg.SetDefault("server.check_clients_connection_interval", DefaultCheckClientsConnectionInterval)
|
||||
viperCfg.SetDefault("server.check_clients_connection_timeout", DefaultCheckClientsConnectionTimeout)
|
||||
|
||||
@ -19,10 +19,12 @@ const (
|
||||
WALEnabled = "_journal_mode=WAL"
|
||||
defaultDelayBetweenAttempts = 200 * time.Millisecond
|
||||
DefaultMaxAttempts = 5
|
||||
DefaultMaxOpenConnections = 1
|
||||
)
|
||||
|
||||
type DataSourceOptions struct {
|
||||
WALEnabled bool
|
||||
WALEnabled bool
|
||||
MaxOpenConnections int
|
||||
}
|
||||
|
||||
// New returns a new sqlite DB instance with migrated DB scheme to the latest version.
|
||||
@ -42,7 +44,12 @@ func New(dataSourceName string, assetNames []string, asset func(name string) ([]
|
||||
}
|
||||
}
|
||||
|
||||
db.SetMaxOpenConns(1)
|
||||
maxConns := dataSourceOptions.MaxOpenConnections
|
||||
if maxConns == 0 {
|
||||
maxConns = DefaultMaxOpenConnections
|
||||
}
|
||||
|
||||
db.SetMaxOpenConns(maxConns)
|
||||
|
||||
s := bindata.Resource(assetNames,
|
||||
func(name string) ([]byte, error) {
|
||||
@ -70,27 +77,29 @@ func New(dataSourceName string, assetNames []string, asset func(name string) ([]
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// TODO: (rs): we've moved to use single db connections. with potentially slower access to the sqlite
|
||||
// volumes it seems there's too much concurrent contention for the dbs, so there's less need for this fn.
|
||||
// not removing yet but check again in approx 6 months (from dec 22) and remove if no longer required.
|
||||
func WithRetryWhenBusy[R any](retryAble func() (result R, err error), label string, l *logger.Logger) (result R, err error) {
|
||||
func WithRetryWhenBusy[R any](retryAbleFn func() (result R, err error), label string, l *logger.Logger) (result R, err error) {
|
||||
for r := 0; r < DefaultMaxAttempts; r++ {
|
||||
result, err = retryAble()
|
||||
if err != nil {
|
||||
attempt := r + 1
|
||||
if attempt > 1 && err != nil {
|
||||
sqlErr, ok := err.(sql.Error)
|
||||
if ok && sqlErr.Code == sql.ErrBusy {
|
||||
l.Debugf("%s: attempt %d: source err = %+v\n", label, r+1, err)
|
||||
l.Debugf("%s: attempt %d: source err = %+v\n", label, attempt, err)
|
||||
jitter := time.Duration((rand.Intn(100))) * time.Millisecond
|
||||
time.Sleep(defaultDelayBetweenAttempts + jitter)
|
||||
continue
|
||||
} else {
|
||||
// a different error from database busy, so fail immediately
|
||||
l.Debugf("%s: attempt %d: non-retryable err = %+v\n", label, attempt, err)
|
||||
return result, err
|
||||
}
|
||||
// non retryable err
|
||||
return result, sqlErr
|
||||
}
|
||||
// success
|
||||
return result, nil
|
||||
// make an attempt to complete the retryable fn
|
||||
result, err = retryAbleFn()
|
||||
// if no error then return immediately if with success result
|
||||
if err == nil {
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
l.Debugf("%s: failed after max attempts: err = %+v\n", label, err)
|
||||
l.Errorf("%s: failed after max attempts: err = %+v\n", label, err)
|
||||
return result, err
|
||||
}
|
||||
|
||||
@ -124,6 +124,12 @@
|
||||
## This is a performance enhancement. Do not turn off, unless you have good reasons.
|
||||
#sqlite_wal = true
|
||||
|
||||
## Limits the number of ssh handshakes that the server will handle concurrently. Too many in progress SSH handshakes
|
||||
## together will slow down the server's ability to perform other work. This can particularly impact server startup
|
||||
## when many clients connect at similar times. A very slow server can also result in strange client reconnect issues.
|
||||
## Default is 4.
|
||||
#max_concurrent_ssh_handshakes = 4
|
||||
|
||||
## An optional parameter to define whether disconnected clients get purged from the database.
|
||||
## By default, disconnected clients are NOT purged.
|
||||
#purge_disconnected_clients = false
|
||||
|
||||
@ -377,8 +377,8 @@ func TestHandleDeleteClientAuth(t *testing.T) {
|
||||
mockConn := &mockConnection{}
|
||||
initState := []*clientsauth.ClientAuth{cl1, cl2, cl3}
|
||||
|
||||
c1 := clients.New(t).ClientAuthID(cl1.ID).Connection(mockConn).Build()
|
||||
c2 := clients.New(t).ClientAuthID(cl1.ID).DisconnectedDuration(5 * time.Minute).Build()
|
||||
c1 := clients.New(t).ClientAuthID(cl1.ID).Connection(mockConn).Logger(testLog).Build()
|
||||
c2 := clients.New(t).ClientAuthID(cl1.ID).DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
|
||||
|
||||
testCases := []struct {
|
||||
descr string // Test Case Description
|
||||
@ -557,8 +557,7 @@ func TestHandleDeleteClientAuth(t *testing.T) {
|
||||
require.NoError(err)
|
||||
assert.ElementsMatch(tc.wantClientsAuth, clients, "clients auth not as expected")
|
||||
assert.Equal(tc.wantClosedConn, mockConn.closed)
|
||||
allClients, err := al.clientService.GetAll()
|
||||
require.NoError(err)
|
||||
allClients := al.clientService.GetAll()
|
||||
assert.ElementsMatch(tc.wantClients, allClients)
|
||||
})
|
||||
}
|
||||
|
||||
@ -6,7 +6,6 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"golang.org/x/crypto/ssh"
|
||||
@ -19,133 +18,11 @@ import (
|
||||
"github.com/cloudradar-monitoring/rport/server/ports"
|
||||
"github.com/cloudradar-monitoring/rport/server/routes"
|
||||
"github.com/cloudradar-monitoring/rport/server/validation"
|
||||
"github.com/cloudradar-monitoring/rport/share/clientconfig"
|
||||
"github.com/cloudradar-monitoring/rport/share/comm"
|
||||
"github.com/cloudradar-monitoring/rport/share/models"
|
||||
"github.com/cloudradar-monitoring/rport/share/query"
|
||||
)
|
||||
|
||||
type ClientPayload struct {
|
||||
ID *string `json:"id,omitempty"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
Address *string `json:"address,omitempty"`
|
||||
Hostname *string `json:"hostname,omitempty"`
|
||||
OS *string `json:"os,omitempty"`
|
||||
OSFullName *string `json:"os_full_name,omitempty"`
|
||||
OSVersion *string `json:"os_version,omitempty"`
|
||||
OSArch *string `json:"os_arch,omitempty"`
|
||||
OSFamily *string `json:"os_family,omitempty"`
|
||||
OSKernel *string `json:"os_kernel,omitempty"`
|
||||
OSVirtualizationSystem *string `json:"os_virtualization_system,omitempty"`
|
||||
OSVirtualizationRole *string `json:"os_virtualization_role,omitempty"`
|
||||
NumCPUs *int `json:"num_cpus,omitempty"`
|
||||
CPUFamily *string `json:"cpu_family,omitempty"`
|
||||
CPUModel *string `json:"cpu_model,omitempty"`
|
||||
CPUModelName *string `json:"cpu_model_name,omitempty"`
|
||||
CPUVendor *string `json:"cpu_vendor,omitempty"`
|
||||
MemoryTotal *uint64 `json:"mem_total,omitempty"`
|
||||
Timezone *string `json:"timezone,omitempty"`
|
||||
ClientAuthID *string `json:"client_auth_id,omitempty"`
|
||||
Version *string `json:"version,omitempty"`
|
||||
DisconnectedAt **time.Time `json:"disconnected_at,omitempty"`
|
||||
LastHeartbeatAt **time.Time `json:"last_heartbeat_at,omitempty"`
|
||||
ConnectionState *string `json:"connection_state,omitempty"`
|
||||
IPv4 *[]string `json:"ipv4,omitempty"`
|
||||
IPv6 *[]string `json:"ipv6,omitempty"`
|
||||
Tags *[]string `json:"tags,omitempty"`
|
||||
AllowedUserGroups *[]string `json:"allowed_user_groups,omitempty"`
|
||||
Tunnels *[]*clienttunnel.Tunnel `json:"tunnels,omitempty"`
|
||||
UpdatesStatus **models.UpdatesStatus `json:"updates_status,omitempty"`
|
||||
ClientConfiguration **clientconfig.Config `json:"client_configuration,omitempty"`
|
||||
Groups *[]string `json:"groups,omitempty"`
|
||||
}
|
||||
|
||||
func convertToClientsPayload(clients []*clients.CalculatedClient, fields []query.FieldsOption) []ClientPayload {
|
||||
r := make([]ClientPayload, 0, len(clients))
|
||||
for _, cur := range clients {
|
||||
r = append(r, convertToClientPayload(cur, fields))
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func convertToClientPayload(client *clients.CalculatedClient, fields []query.FieldsOption) ClientPayload { //nolint:gocyclo
|
||||
requestedFields := query.RequestedFields(fields, "clients")
|
||||
p := ClientPayload{}
|
||||
for field := range clients.OptionsSupportedFields["clients"] {
|
||||
if len(fields) > 0 && !requestedFields[field] {
|
||||
continue
|
||||
}
|
||||
switch field {
|
||||
case "id":
|
||||
p.ID = &client.ID
|
||||
case "name":
|
||||
p.Name = &client.Name
|
||||
case "os":
|
||||
p.OS = &client.OS
|
||||
case "os_arch":
|
||||
p.OSArch = &client.OSArch
|
||||
case "os_family":
|
||||
p.OSFamily = &client.OSFamily
|
||||
case "os_kernel":
|
||||
p.OSKernel = &client.OSKernel
|
||||
case "hostname":
|
||||
p.Hostname = &client.Hostname
|
||||
case "ipv4":
|
||||
p.IPv4 = &client.IPv4
|
||||
case "ipv6":
|
||||
p.IPv6 = &client.IPv6
|
||||
case "tags":
|
||||
p.Tags = &client.Tags
|
||||
case "version":
|
||||
p.Version = &client.Version
|
||||
case "address":
|
||||
p.Address = &client.Address
|
||||
case "tunnels":
|
||||
p.Tunnels = &client.Tunnels
|
||||
case "disconnected_at":
|
||||
p.DisconnectedAt = &client.DisconnectedAt
|
||||
case "last_heartbeat_at":
|
||||
p.LastHeartbeatAt = &client.LastHeartbeatAt
|
||||
case "connection_state":
|
||||
connectionState := string(client.ConnectionState)
|
||||
p.ConnectionState = &connectionState
|
||||
case "client_auth_id":
|
||||
p.ClientAuthID = &client.ClientAuthID
|
||||
case "os_full_name":
|
||||
p.OSFullName = &client.OSFullName
|
||||
case "os_version":
|
||||
p.OSVersion = &client.OSVersion
|
||||
case "os_virtualization_system":
|
||||
p.OSVirtualizationSystem = &client.OSVirtualizationSystem
|
||||
case "os_virtualization_role":
|
||||
p.OSVirtualizationRole = &client.OSVirtualizationRole
|
||||
case "cpu_family":
|
||||
p.CPUFamily = &client.CPUFamily
|
||||
case "cpu_model":
|
||||
p.CPUModel = &client.CPUModel
|
||||
case "cpu_model_name":
|
||||
p.CPUModelName = &client.CPUModelName
|
||||
case "cpu_vendor":
|
||||
p.CPUVendor = &client.CPUVendor
|
||||
case "timezone":
|
||||
p.Timezone = &client.Timezone
|
||||
case "num_cpus":
|
||||
p.NumCPUs = &client.NumCPUs
|
||||
case "mem_total":
|
||||
p.MemoryTotal = &client.MemoryTotal
|
||||
case "allowed_user_groups":
|
||||
p.AllowedUserGroups = &client.AllowedUserGroups
|
||||
case "updates_status":
|
||||
p.UpdatesStatus = &client.UpdatesStatus
|
||||
case "client_configuration":
|
||||
p.ClientConfiguration = &client.ClientConfiguration
|
||||
case "groups":
|
||||
p.Groups = &client.Groups
|
||||
}
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func getCorrespondingSortFunc(sorts []query.SortOption) (sortFunc func(a []*clients.CalculatedClient, desc bool), desc bool, err error) {
|
||||
if len(sorts) < 1 {
|
||||
return clients.SortByID, false, nil
|
||||
@ -200,7 +77,7 @@ func (al *APIListener) handleGetClient(w http.ResponseWriter, req *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
clientPayload := convertToClientPayload(client.ToCalculated(groups), options.Fields)
|
||||
clientPayload := clients.ConvertToClientPayload(client.ToCalculated(groups), options.Fields)
|
||||
al.writeJSONResponse(w, http.StatusOK, api.NewSuccessPayload(clientPayload))
|
||||
}
|
||||
|
||||
@ -291,19 +168,20 @@ func (al *APIListener) handleGetClients(w http.ResponseWriter, req *http.Request
|
||||
return
|
||||
}
|
||||
|
||||
cls, err := al.clientService.GetFilteredUserClients(curUser, options.Filters, groups)
|
||||
filteredClients, err := al.clientService.GetFilteredUserClients(curUser, options.Filters, groups)
|
||||
if err != nil {
|
||||
al.jsonError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
sortFunc(cls, desc)
|
||||
sortFunc(filteredClients, desc)
|
||||
|
||||
totalCount := len(cls)
|
||||
totalCount := len(filteredClients)
|
||||
start, end := options.Pagination.GetStartEnd(totalCount)
|
||||
cls = cls[start:end]
|
||||
filteredClients = filteredClients[start:end]
|
||||
|
||||
clientsPayload := clients.ConvertToClientsPayload(filteredClients, options.Fields)
|
||||
|
||||
clientsPayload := convertToClientsPayload(cls, options.Fields)
|
||||
al.writeJSONResponse(w, http.StatusOK, &api.SuccessPayload{
|
||||
Data: clientsPayload,
|
||||
Meta: api.NewMeta(totalCount),
|
||||
@ -345,7 +223,7 @@ func (al *APIListener) handlePutClientTunnel(w http.ResponseWriter, req *http.Re
|
||||
}
|
||||
|
||||
if client.IsPaused() {
|
||||
al.jsonErrorResponseWithTitle(w, http.StatusNotFound, fmt.Sprintf("failed to start tunnel for client with id %s due to client being paused (reason = %s)", clientID, client.PausedReason))
|
||||
al.jsonErrorResponseWithTitle(w, http.StatusNotFound, fmt.Sprintf("failed to start tunnel for client with id %s due to client being paused (reason = %s)", clientID, client.GetPausedReason()))
|
||||
return
|
||||
}
|
||||
|
||||
@ -411,7 +289,7 @@ func (al *APIListener) handlePutClientTunnel(w http.ResponseWriter, req *http.Re
|
||||
remote.ACL = &aclStr
|
||||
}
|
||||
|
||||
allowed, err := clienttunnel.IsAllowed(remote.Remote(), client.Connection)
|
||||
allowed, err := clienttunnel.IsAllowed(remote.Remote(), client.GetConnection(), al.Log())
|
||||
if err != nil {
|
||||
al.jsonError(w, err)
|
||||
return
|
||||
@ -426,7 +304,7 @@ func (al *APIListener) handlePutClientTunnel(w http.ResponseWriter, req *http.Re
|
||||
return
|
||||
}
|
||||
|
||||
for _, t := range client.Tunnels {
|
||||
for _, t := range client.GetTunnels() {
|
||||
if t.Remote.Remote() == remote.Remote() && t.Remote.IsProtocol(remote.Protocol) && t.EqualACL(remote.ACL) {
|
||||
al.jsonErrorResponseWithErrCode(w, http.StatusBadRequest, ErrCodeTunnelToPortExist, fmt.Sprintf("Tunnel to port %s already exists.", remote.RemotePort))
|
||||
return
|
||||
@ -434,17 +312,13 @@ func (al *APIListener) handlePutClientTunnel(w http.ResponseWriter, req *http.Re
|
||||
}
|
||||
|
||||
if checkPortStr := req.URL.Query().Get("check_port"); checkPortStr != "0" && remote.IsProtocol(models.ProtocolTCP) {
|
||||
err = al.checkRemotePort(*remote, client.Connection)
|
||||
err = al.checkRemotePort(*remote, client.GetConnection())
|
||||
if err != nil {
|
||||
al.jsonError(w, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// make next steps thread-safe
|
||||
client.Lock()
|
||||
defer client.Unlock()
|
||||
|
||||
if remote.IsLocalSpecified() {
|
||||
err = al.checkLocalPort(remote.LocalPort, remote.Protocol)
|
||||
if err != nil {
|
||||
@ -588,7 +462,7 @@ func (al *APIListener) checkRemotePort(remote models.Remote, conn ssh.Conn) (err
|
||||
Timeout: al.config.Server.CheckPortTimeout,
|
||||
}
|
||||
resp := &comm.CheckPortResponse{}
|
||||
err = comm.SendRequestAndGetResponse(conn, comm.RequestTypeCheckPort, req, resp)
|
||||
err = comm.SendRequestAndGetResponse(conn, comm.RequestTypeCheckPort, req, resp, al.Log())
|
||||
if err != nil {
|
||||
if _, ok := err.(*comm.ClientError); ok {
|
||||
err = apierrors.NewAPIError(http.StatusConflict, "", "", err)
|
||||
@ -646,10 +520,6 @@ func (al *APIListener) handleDeleteClientTunnel(w http.ResponseWriter, req *http
|
||||
return
|
||||
}
|
||||
|
||||
// make next steps thread-safe
|
||||
client.Lock()
|
||||
defer client.Unlock()
|
||||
|
||||
tunnel := al.clientService.FindTunnel(client, tunnelID)
|
||||
if tunnel == nil {
|
||||
al.jsonErrorResponseWithTitle(w, http.StatusNotFound, "tunnel not found")
|
||||
|
||||
@ -37,7 +37,7 @@ func (mockClientGroupProvider) GetAll(ctx context.Context) ([]*cgroups.ClientGro
|
||||
}
|
||||
|
||||
func TestHandleGetClient(t *testing.T) {
|
||||
c1 := clients.New(t).ID("client-1").ClientAuthID(cl1.ID).Build()
|
||||
c1 := clients.New(t).ID("client-1").ClientAuthID(cl1.ID).Logger(testLog).Build()
|
||||
al := APIListener{
|
||||
insecureForTests: true,
|
||||
Server: &Server{
|
||||
@ -175,8 +175,12 @@ func TestHandleGetClients(t *testing.T) {
|
||||
Username: "admin",
|
||||
Groups: []string{users.Administrators},
|
||||
}
|
||||
c1 := clients.New(t).ID("client-1").ClientAuthID(cl1.ID).Build()
|
||||
c2 := clients.New(t).ID("client-2").ClientAuthID(cl1.ID).DisconnectedDuration(5 * time.Minute).Build()
|
||||
c1 := clients.New(t).ID("client-1").ClientAuthID(cl1.ID).Logger(testLog).Build()
|
||||
c1.Logger = testLog
|
||||
|
||||
c2 := clients.New(t).ID("client-2").ClientAuthID(cl1.ID).DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
|
||||
c2.Logger = testLog
|
||||
|
||||
al := APIListener{
|
||||
insecureForTests: true,
|
||||
Server: &Server{
|
||||
@ -521,8 +525,8 @@ func TestHandlePutTunnelWithName(t *testing.T) {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c1 := clients.New(t).ID("client-1").ClientAuthID(cl1.ID).Build()
|
||||
c1.Connection = connMock
|
||||
c1 := clients.New(t).ID("client-1").ClientAuthID(cl1.ID).Logger(testLog).Build()
|
||||
c1.SetConnection(connMock)
|
||||
c1.Logger = testLog
|
||||
|
||||
mockClientService := &SimpleMockClientService{
|
||||
@ -724,8 +728,8 @@ func TestHandlePutTunnelUsingCaddyProxies(t *testing.T) {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
c1 := clients.New(t).ID("client-1").ClientAuthID(cl1.ID).Build()
|
||||
c1.Connection = connMock
|
||||
c1 := clients.New(t).ID("client-1").ClientAuthID(cl1.ID).Logger(testLog).Build()
|
||||
c1.SetConnection(connMock)
|
||||
c1.Logger = testLog
|
||||
|
||||
mockClientService := &SimpleMockClientService{
|
||||
|
||||
@ -366,7 +366,7 @@ func (al *APIListener) handleExecuteCommand(ctx context.Context, w http.Response
|
||||
}
|
||||
|
||||
if client.IsPaused() {
|
||||
al.jsonErrorResponseWithTitle(w, http.StatusNotFound, fmt.Sprintf("failed to execute command/script for client with id %s due to client being paused (reason = %s)", client.ID, client.PausedReason))
|
||||
al.jsonErrorResponseWithTitle(w, http.StatusNotFound, fmt.Sprintf("failed to execute command/script for client with id %s due to client being paused (reason = %s)", client.GetID(), client.GetPausedReason()))
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -382,7 +382,7 @@ func (al *APIListener) handleExecuteCommand(ctx context.Context, w http.Response
|
||||
JID: jid,
|
||||
FinishedAt: nil,
|
||||
ClientID: executeInput.ClientID,
|
||||
ClientName: client.Name,
|
||||
ClientName: client.GetName(),
|
||||
Command: executeInput.Command,
|
||||
Interpreter: executeInput.Interpreter,
|
||||
CreatedBy: api.GetUser(ctx, al.Logger),
|
||||
@ -393,7 +393,7 @@ func (al *APIListener) handleExecuteCommand(ctx context.Context, w http.Response
|
||||
IsScript: executeInput.IsScript,
|
||||
}
|
||||
sshResp := &comm.RunCmdResponse{}
|
||||
err = comm.SendRequestAndGetResponse(client.Connection, comm.RequestTypeRunCmd, curJob, sshResp)
|
||||
err = comm.SendRequestAndGetResponse(client.GetConnection(), comm.RequestTypeRunCmd, curJob, sshResp, al.Log())
|
||||
if err != nil {
|
||||
if _, ok := err.(*comm.ClientError); ok {
|
||||
al.jsonErrorResponseWithTitle(w, http.StatusConflict, err.Error())
|
||||
|
||||
@ -107,8 +107,8 @@ func TestHandlePostCommand(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
connMock.ReturnResponsePayload = sshRespBytes
|
||||
|
||||
c1 := clients.New(t).Connection(connMock).Build()
|
||||
c2 := clients.New(t).DisconnectedDuration(5 * time.Minute).Build()
|
||||
c1 := clients.New(t).Connection(connMock).Logger(testLog).Build()
|
||||
c2 := clients.New(t).DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@ -132,7 +132,7 @@ func TestHandlePostCommand(t *testing.T) {
|
||||
{
|
||||
name: "valid cmd",
|
||||
requestBody: validReqBody,
|
||||
cid: c1.ID,
|
||||
cid: c1.GetID(),
|
||||
clients: []*clients.Client{c1},
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantTimeout: gotCmdTimeoutSec,
|
||||
@ -140,7 +140,7 @@ func TestHandlePostCommand(t *testing.T) {
|
||||
{
|
||||
name: "valid cmd with interpreter",
|
||||
requestBody: `{"command": "` + gotCmd + `","interpreter": "powershell"}`,
|
||||
cid: c1.ID,
|
||||
cid: c1.GetID(),
|
||||
clients: []*clients.Client{c1},
|
||||
wantStatusCode: http.StatusOK,
|
||||
wantTimeout: defaultTimeout,
|
||||
@ -149,7 +149,7 @@ func TestHandlePostCommand(t *testing.T) {
|
||||
{
|
||||
name: "invalid interpreter",
|
||||
requestBody: `{"command": "` + gotCmd + `","interpreter": "unsupported"}`,
|
||||
cid: c1.ID,
|
||||
cid: c1.GetID(),
|
||||
clients: []*clients.Client{c1},
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
wantErrTitle: "Invalid interpreter.",
|
||||
@ -158,7 +158,7 @@ func TestHandlePostCommand(t *testing.T) {
|
||||
{
|
||||
name: "valid cmd with no timeout",
|
||||
requestBody: `{"command": "/bin/date;foo;whoami"}`,
|
||||
cid: c1.ID,
|
||||
cid: c1.GetID(),
|
||||
clients: []*clients.Client{c1},
|
||||
wantTimeout: defaultTimeout,
|
||||
wantStatusCode: http.StatusOK,
|
||||
@ -166,7 +166,7 @@ func TestHandlePostCommand(t *testing.T) {
|
||||
{
|
||||
name: "valid cmd with 0 timeout",
|
||||
requestBody: `{"command": "/bin/date;foo;whoami", "timeout_sec": 0}`,
|
||||
cid: c1.ID,
|
||||
cid: c1.GetID(),
|
||||
clients: []*clients.Client{c1},
|
||||
wantTimeout: defaultTimeout,
|
||||
wantStatusCode: http.StatusOK,
|
||||
@ -174,7 +174,7 @@ func TestHandlePostCommand(t *testing.T) {
|
||||
{
|
||||
name: "empty cmd",
|
||||
requestBody: `{"command": "", "timeout_sec": 30}`,
|
||||
cid: c1.ID,
|
||||
cid: c1.GetID(),
|
||||
clients: []*clients.Client{c1},
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
wantErrTitle: "Command cannot be empty.",
|
||||
@ -182,7 +182,7 @@ func TestHandlePostCommand(t *testing.T) {
|
||||
{
|
||||
name: "no cmd",
|
||||
requestBody: `{"timeout_sec": 30}`,
|
||||
cid: c1.ID,
|
||||
cid: c1.GetID(),
|
||||
clients: []*clients.Client{c1},
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
wantErrTitle: "Command cannot be empty.",
|
||||
@ -190,7 +190,7 @@ func TestHandlePostCommand(t *testing.T) {
|
||||
{
|
||||
name: "empty body",
|
||||
requestBody: "",
|
||||
cid: c1.ID,
|
||||
cid: c1.GetID(),
|
||||
clients: []*clients.Client{c1},
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
wantErrTitle: "Missing body with json data.",
|
||||
@ -198,7 +198,7 @@ func TestHandlePostCommand(t *testing.T) {
|
||||
{
|
||||
name: "invalid request body",
|
||||
requestBody: "sdfn fasld fasdf sdlf jd",
|
||||
cid: c1.ID,
|
||||
cid: c1.GetID(),
|
||||
clients: []*clients.Client{c1},
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
wantErrTitle: "Invalid JSON data.",
|
||||
@ -207,7 +207,7 @@ func TestHandlePostCommand(t *testing.T) {
|
||||
{
|
||||
name: "invalid request body: unknown param",
|
||||
requestBody: `{"command": "/bin/date;foo;whoami", "timeout": 30}`,
|
||||
cid: c1.ID,
|
||||
cid: c1.GetID(),
|
||||
clients: []*clients.Client{c1},
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
wantErrTitle: "Invalid JSON data.",
|
||||
@ -216,24 +216,24 @@ func TestHandlePostCommand(t *testing.T) {
|
||||
{
|
||||
name: "no active client",
|
||||
requestBody: validReqBody,
|
||||
cid: c1.ID,
|
||||
cid: c1.GetID(),
|
||||
clients: []*clients.Client{},
|
||||
wantStatusCode: http.StatusNotFound,
|
||||
wantErrTitle: fmt.Sprintf("Active client with id=%q not found.", c1.ID),
|
||||
wantErrTitle: fmt.Sprintf("Active client with id=%q not found.", c1.GetID()),
|
||||
},
|
||||
{
|
||||
name: "disconnected client",
|
||||
requestBody: validReqBody,
|
||||
cid: c2.ID,
|
||||
cid: c2.GetID(),
|
||||
clients: []*clients.Client{c1, c2},
|
||||
wantStatusCode: http.StatusNotFound,
|
||||
wantErrTitle: fmt.Sprintf("Active client with id=%q not found.", c2.ID),
|
||||
wantErrTitle: fmt.Sprintf("Active client with id=%q not found.", c2.GetID()),
|
||||
},
|
||||
{
|
||||
name: "error on save job",
|
||||
requestBody: validReqBody,
|
||||
jpReturnSaveErr: errors.New("save fake error"),
|
||||
cid: c1.ID,
|
||||
cid: c1.GetID(),
|
||||
clients: []*clients.Client{c1},
|
||||
wantStatusCode: http.StatusInternalServerError,
|
||||
wantErrTitle: "Failed to persist a new job.",
|
||||
@ -243,7 +243,7 @@ func TestHandlePostCommand(t *testing.T) {
|
||||
name: "error on send request",
|
||||
requestBody: validReqBody,
|
||||
connReturnErr: errors.New("send fake error"),
|
||||
cid: c1.ID,
|
||||
cid: c1.GetID(),
|
||||
clients: []*clients.Client{c1},
|
||||
wantStatusCode: http.StatusInternalServerError,
|
||||
wantErrTitle: "Failed to execute remote command.",
|
||||
@ -253,7 +253,7 @@ func TestHandlePostCommand(t *testing.T) {
|
||||
name: "invalid ssh response format",
|
||||
requestBody: validReqBody,
|
||||
connReturnResp: []byte("invalid ssh response data"),
|
||||
cid: c1.ID,
|
||||
cid: c1.GetID(),
|
||||
clients: []*clients.Client{c1},
|
||||
wantStatusCode: http.StatusConflict,
|
||||
wantErrTitle: "invalid client response format: failed to decode response into *comm.RunCmdResponse: invalid character 'i' looking for beginning of value",
|
||||
@ -263,7 +263,7 @@ func TestHandlePostCommand(t *testing.T) {
|
||||
requestBody: validReqBody,
|
||||
connReturnNotOk: true,
|
||||
connReturnResp: []byte("fake failure msg"),
|
||||
cid: c1.ID,
|
||||
cid: c1.GetID(),
|
||||
clients: []*clients.Client{c1},
|
||||
wantStatusCode: http.StatusConflict,
|
||||
wantErrTitle: "client error: fake failure msg",
|
||||
@ -552,9 +552,9 @@ func TestHandlePostMultiClientCommand(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
connMock2.ReturnResponsePayload = sshRespBytes2
|
||||
|
||||
c1 := clients.New(t).ID("client-1").Connection(connMock1).Build()
|
||||
c2 := clients.New(t).ID("client-2").Connection(connMock2).Build()
|
||||
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Build()
|
||||
c1 := clients.New(t).ID("client-1").Connection(connMock1).Logger(testLog).Build()
|
||||
c2 := clients.New(t).ID("client-2").Connection(connMock2).Logger(testLog).Build()
|
||||
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
|
||||
|
||||
c1.Logger = testLog
|
||||
c2.Logger = testLog
|
||||
@ -565,7 +565,7 @@ func TestHandlePostMultiClientCommand(t *testing.T) {
|
||||
gotCmdTimeoutSec := 30
|
||||
validReqBody := `{"command": "` + gotCmd +
|
||||
`","timeout_sec": ` + strconv.Itoa(gotCmdTimeoutSec) +
|
||||
`,"client_ids": ["` + c1.ID + `", "` + c2.ID + `"]` +
|
||||
`,"client_ids": ["` + c1.GetID() + `", "` + c2.GetID() + `"]` +
|
||||
`,"abort_on_error": false` +
|
||||
`,"execute_concurrently": false` +
|
||||
`}`
|
||||
@ -762,11 +762,11 @@ func TestHandlePostMultiClientCommandWithPausedClient(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
connMock2.ReturnResponsePayload = sshRespBytes2
|
||||
|
||||
c1 := clients.New(t).ID("client-1").Connection(connMock1).Build()
|
||||
c1 := clients.New(t).ID("client-1").Connection(connMock1).Logger(testLog).Build()
|
||||
c1.Logger = testLog
|
||||
c1.SetPaused(true, clients.PausedDueToMaxClientsExceeded)
|
||||
|
||||
c2 := clients.New(t).ID("client-2").Connection(connMock2).Build()
|
||||
c2 := clients.New(t).ID("client-2").Connection(connMock2).Logger(testLog).Build()
|
||||
c2.Logger = testLog
|
||||
|
||||
defaultTimeout := 60
|
||||
@ -775,14 +775,14 @@ func TestHandlePostMultiClientCommandWithPausedClient(t *testing.T) {
|
||||
|
||||
c1ValidReqBody := `{"command": "` + gotCmd +
|
||||
`","timeout_sec": ` + strconv.Itoa(gotCmdTimeoutSec) +
|
||||
`,"client_ids": ["` + c1.ID + `"]` +
|
||||
`,"client_ids": ["` + c1.GetID() + `"]` +
|
||||
`,"abort_on_error": false` +
|
||||
`,"execute_concurrently": false` +
|
||||
`}`
|
||||
|
||||
c2ValidReqBody := `{"command": "` + gotCmd +
|
||||
`","timeout_sec": ` + strconv.Itoa(gotCmdTimeoutSec) +
|
||||
`,"client_ids": ["` + c2.ID + `"]` +
|
||||
`,"client_ids": ["` + c2.GetID() + `"]` +
|
||||
`,"abort_on_error": false` +
|
||||
`,"execute_concurrently": false` +
|
||||
`}`
|
||||
@ -973,10 +973,10 @@ func TestHandlePostMultiClientCommandWithGroupIDs(t *testing.T) {
|
||||
connMock2 := makeConnMock(t, 2, time.Date(2020, 10, 10, 10, 10, 2, 0, time.UTC))
|
||||
connMock4 := makeConnMock(t, 4, time.Date(2020, 10, 10, 10, 10, 4, 0, time.UTC))
|
||||
|
||||
c1 := clients.New(t).ID("client-1").Connection(connMock1).Build()
|
||||
c2 := clients.New(t).ID("client-2").Connection(connMock2).Build()
|
||||
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Build()
|
||||
c4 := clients.New(t).ID("client-4").Connection(connMock4).Build()
|
||||
c1 := clients.New(t).ID("client-1").Connection(connMock1).Logger(testLog).Build()
|
||||
c2 := clients.New(t).ID("client-2").Connection(connMock2).Logger(testLog).Build()
|
||||
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
|
||||
c4 := clients.New(t).ID("client-4").Connection(connMock4).Logger(testLog).Build()
|
||||
|
||||
g1 := makeClientGroup("group-1", &cgroups.ClientParams{
|
||||
ClientID: &cgroups.ParamValues{"client-1", "client-2"},
|
||||
@ -990,9 +990,9 @@ func TestHandlePostMultiClientCommandWithGroupIDs(t *testing.T) {
|
||||
Version: &cgroups.ParamValues{"0.1.1*"},
|
||||
})
|
||||
|
||||
c1.AllowedUserGroups = []string{"group-1"}
|
||||
c2.AllowedUserGroups = []string{"group-1"}
|
||||
c4.AllowedUserGroups = []string{"group-2"}
|
||||
c1.SetAllowedUserGroups([]string{"group-1"})
|
||||
c2.SetAllowedUserGroups([]string{"group-1"})
|
||||
c4.SetAllowedUserGroups([]string{"group-2"})
|
||||
|
||||
al := makeAPIListener(curUser,
|
||||
clients.NewClientRepository([]*clients.Client{c1, c2, c3, c4}, &hour, testLog),
|
||||
@ -1187,15 +1187,15 @@ func TestHandlePostMultiClientCommandWithTags(t *testing.T) {
|
||||
connMock2 := makeConnMock(t, 2, time.Date(2020, 10, 10, 10, 10, 2, 0, time.UTC))
|
||||
connMock4 := makeConnMock(t, 4, time.Date(2020, 10, 10, 10, 10, 4, 0, time.UTC))
|
||||
|
||||
c1 := clients.New(t).ID("client-1").Connection(connMock1).Build()
|
||||
c2 := clients.New(t).ID("client-2").Connection(connMock2).Build()
|
||||
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Build()
|
||||
c4 := clients.New(t).ID("client-4").Connection(connMock4).Build()
|
||||
c1 := clients.New(t).ID("client-1").Connection(connMock1).Logger(testLog).Build()
|
||||
c2 := clients.New(t).ID("client-2").Connection(connMock2).Logger(testLog).Build()
|
||||
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
|
||||
c4 := clients.New(t).ID("client-4").Connection(connMock4).Logger(testLog).Build()
|
||||
|
||||
c1.Tags = []string{"linux"}
|
||||
c2.Tags = []string{"windows"}
|
||||
c3.Tags = []string{"mac"}
|
||||
c4.Tags = []string{"linux", "windows"}
|
||||
c1.SetTags([]string{"linux"})
|
||||
c2.SetTags([]string{"windows"})
|
||||
c3.SetTags([]string{"mac"})
|
||||
c4.SetTags([]string{"linux", "windows"})
|
||||
|
||||
g1 := makeClientGroup("group-1", &cgroups.ClientParams{
|
||||
ClientID: &cgroups.ParamValues{"client-1", "client-2"},
|
||||
@ -1209,9 +1209,9 @@ func TestHandlePostMultiClientCommandWithTags(t *testing.T) {
|
||||
Version: &cgroups.ParamValues{"0.1.1*"},
|
||||
})
|
||||
|
||||
c1.AllowedUserGroups = []string{"group-1"}
|
||||
c2.AllowedUserGroups = []string{"group-1"}
|
||||
c4.AllowedUserGroups = []string{"group-2"}
|
||||
c1.SetAllowedUserGroups([]string{"group-1"})
|
||||
c2.SetAllowedUserGroups([]string{"group-1"})
|
||||
c4.SetAllowedUserGroups([]string{"group-2"})
|
||||
|
||||
clientList := []*clients.Client{c1, c2, c4}
|
||||
|
||||
@ -1425,15 +1425,15 @@ func TestHandlePostMultiClientWSCommandWithTags(t *testing.T) {
|
||||
connMock2 := makeConnMock(t, 2, time.Date(2020, 10, 10, 10, 10, 2, 0, time.UTC))
|
||||
connMock4 := makeConnMock(t, 4, time.Date(2020, 10, 10, 10, 10, 4, 0, time.UTC))
|
||||
|
||||
c1 := clients.New(t).ID("client-1").Connection(connMock1).Build()
|
||||
c2 := clients.New(t).ID("client-2").Connection(connMock2).Build()
|
||||
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Build()
|
||||
c4 := clients.New(t).ID("client-4").Connection(connMock4).Build()
|
||||
c1 := clients.New(t).ID("client-1").Connection(connMock1).Logger(testLog).Build()
|
||||
c2 := clients.New(t).ID("client-2").Connection(connMock2).Logger(testLog).Build()
|
||||
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
|
||||
c4 := clients.New(t).ID("client-4").Connection(connMock4).Logger(testLog).Build()
|
||||
|
||||
c1.Tags = []string{"linux"}
|
||||
c2.Tags = []string{"windows"}
|
||||
c3.Tags = []string{"mac"}
|
||||
c4.Tags = []string{"linux", "windows"}
|
||||
c1.SetTags([]string{"linux"})
|
||||
c2.SetTags([]string{"windows"})
|
||||
c3.SetTags([]string{"mac"})
|
||||
c4.SetTags([]string{"linux", "windows"})
|
||||
|
||||
g1 := makeClientGroup("group-1", &cgroups.ClientParams{
|
||||
ClientID: &cgroups.ParamValues{"client-1", "client-2"},
|
||||
@ -1447,9 +1447,9 @@ func TestHandlePostMultiClientWSCommandWithTags(t *testing.T) {
|
||||
Version: &cgroups.ParamValues{"0.1.1*"},
|
||||
})
|
||||
|
||||
c1.AllowedUserGroups = []string{"group-1"}
|
||||
c2.AllowedUserGroups = []string{"group-1"}
|
||||
c4.AllowedUserGroups = []string{"group-2"}
|
||||
c1.SetAllowedUserGroups([]string{"group-1"})
|
||||
c2.SetAllowedUserGroups([]string{"group-1"})
|
||||
c4.SetAllowedUserGroups([]string{"group-2"})
|
||||
|
||||
clientList := []*clients.Client{c1, c2, c4}
|
||||
|
||||
@ -1659,15 +1659,15 @@ func TestHandlePostMultiClientScriptWithTags(t *testing.T) {
|
||||
connMock2 := makeConnMock(t, 2, time.Date(2020, 10, 10, 10, 10, 2, 0, time.UTC))
|
||||
connMock4 := makeConnMock(t, 4, time.Date(2020, 10, 10, 10, 10, 4, 0, time.UTC))
|
||||
|
||||
c1 := clients.New(t).ID("client-1").Connection(connMock1).Build()
|
||||
c2 := clients.New(t).ID("client-2").Connection(connMock2).Build()
|
||||
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Build()
|
||||
c4 := clients.New(t).ID("client-4").Connection(connMock4).Build()
|
||||
c1 := clients.New(t).ID("client-1").Connection(connMock1).Logger(testLog).Build()
|
||||
c2 := clients.New(t).ID("client-2").Connection(connMock2).Logger(testLog).Build()
|
||||
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
|
||||
c4 := clients.New(t).ID("client-4").Connection(connMock4).Logger(testLog).Build()
|
||||
|
||||
c1.Tags = []string{"linux"}
|
||||
c2.Tags = []string{"windows"}
|
||||
c3.Tags = []string{"mac"}
|
||||
c4.Tags = []string{"linux", "windows"}
|
||||
c1.SetTags([]string{"linux"})
|
||||
c2.SetTags([]string{"windows"})
|
||||
c3.SetTags([]string{"mac"})
|
||||
c4.SetTags([]string{"linux", "windows"})
|
||||
|
||||
g1 := makeClientGroup("group-1", &cgroups.ClientParams{
|
||||
ClientID: &cgroups.ParamValues{"client-1", "client-2"},
|
||||
@ -1681,9 +1681,9 @@ func TestHandlePostMultiClientScriptWithTags(t *testing.T) {
|
||||
Version: &cgroups.ParamValues{"0.1.1*"},
|
||||
})
|
||||
|
||||
c1.AllowedUserGroups = []string{"group-1"}
|
||||
c2.AllowedUserGroups = []string{"group-1"}
|
||||
c4.AllowedUserGroups = []string{"group-2"}
|
||||
c1.SetAllowedUserGroups([]string{"group-1"})
|
||||
c2.SetAllowedUserGroups([]string{"group-1"})
|
||||
c4.SetAllowedUserGroups([]string{"group-2"})
|
||||
|
||||
clientList := []*clients.Client{c1, c2, c4}
|
||||
|
||||
@ -1899,15 +1899,15 @@ func TestHandlePostMultiClientWSScriptWithTags(t *testing.T) {
|
||||
connMock2 := makeConnMock(t, 2, time.Date(2020, 10, 10, 10, 10, 2, 0, time.UTC))
|
||||
connMock4 := makeConnMock(t, 4, time.Date(2020, 10, 10, 10, 10, 4, 0, time.UTC))
|
||||
|
||||
c1 := clients.New(t).ID("client-1").Connection(connMock1).Build()
|
||||
c2 := clients.New(t).ID("client-2").Connection(connMock2).Build()
|
||||
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Build()
|
||||
c4 := clients.New(t).ID("client-4").Connection(connMock4).Build()
|
||||
c1 := clients.New(t).ID("client-1").Connection(connMock1).Logger(testLog).Build()
|
||||
c2 := clients.New(t).ID("client-2").Connection(connMock2).Logger(testLog).Build()
|
||||
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
|
||||
c4 := clients.New(t).ID("client-4").Connection(connMock4).Logger(testLog).Build()
|
||||
|
||||
c1.Tags = []string{"linux"}
|
||||
c2.Tags = []string{"windows"}
|
||||
c3.Tags = []string{"mac"}
|
||||
c4.Tags = []string{"linux", "windows"}
|
||||
c1.SetTags([]string{"linux"})
|
||||
c2.SetTags([]string{"windows"})
|
||||
c3.SetTags([]string{"mac"})
|
||||
c4.SetTags([]string{"linux", "windows"})
|
||||
|
||||
g1 := makeClientGroup("group-1", &cgroups.ClientParams{
|
||||
ClientID: &cgroups.ParamValues{"client-1", "client-2"},
|
||||
@ -1921,9 +1921,9 @@ func TestHandlePostMultiClientWSScriptWithTags(t *testing.T) {
|
||||
Version: &cgroups.ParamValues{"0.1.1*"},
|
||||
})
|
||||
|
||||
c1.AllowedUserGroups = []string{"group-1"}
|
||||
c2.AllowedUserGroups = []string{"group-1"}
|
||||
c4.AllowedUserGroups = []string{"group-2"}
|
||||
c1.SetAllowedUserGroups([]string{"group-1"})
|
||||
c2.SetAllowedUserGroups([]string{"group-1"})
|
||||
c4.SetAllowedUserGroups([]string{"group-2"})
|
||||
|
||||
clientList := []*clients.Client{c1, c2, c4}
|
||||
|
||||
|
||||
@ -32,7 +32,7 @@ func (al *APIListener) handleRefreshUpdatesStatus(w http.ResponseWriter, req *ht
|
||||
return
|
||||
}
|
||||
|
||||
err = comm.SendRequestAndGetResponse(client.Connection, comm.RequestTypeRefreshUpdatesStatus, nil, nil)
|
||||
err = comm.SendRequestAndGetResponse(client.GetConnection(), comm.RequestTypeRefreshUpdatesStatus, nil, nil, al.Log())
|
||||
if err != nil {
|
||||
al.jsonErrorResponse(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
@ -77,8 +77,10 @@ func (al *APIListener) handleGetClientGraphMetrics(w http.ResponseWriter, req *h
|
||||
|
||||
queryOptions := query.NewOptions(req, monitoring.ClientGraphMetricsSortDefault, monitoring.ClientGraphMetricsFilterDefault, monitoring.ClientGraphMetricsFieldsDefault)
|
||||
requestInfo := query.ParseRequestInfo(req)
|
||||
netLan := client.ClientConfiguration.Monitoring.LanCard != nil
|
||||
netWan := client.ClientConfiguration.Monitoring.WanCard != nil
|
||||
|
||||
monitoringConfig := client.GetMonitoringConfig()
|
||||
netLan := monitoringConfig.LanCard != nil
|
||||
netWan := monitoringConfig.WanCard != nil
|
||||
|
||||
payload, err := al.monitoringService.ListClientGraphMetrics(req.Context(), clientID, queryOptions, requestInfo, netLan, netWan)
|
||||
if err != nil {
|
||||
@ -110,7 +112,9 @@ func (al *APIListener) handleGetClientGraphMetricsGraph(w http.ResponseWriter, r
|
||||
return
|
||||
}
|
||||
|
||||
payload, err := al.monitoringService.ListClientGraph(req.Context(), clientID, queryOptions, graph, client.ClientConfiguration.Monitoring.LanCard, client.ClientConfiguration.Monitoring.WanCard)
|
||||
monitoringConfig := client.GetMonitoringConfig()
|
||||
|
||||
payload, err := al.monitoringService.ListClientGraph(req.Context(), clientID, queryOptions, graph, monitoringConfig.LanCard, monitoringConfig.WanCard)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
al.jsonErrorResponseWithTitle(w, http.StatusNotFound, fmt.Sprintf("graph-metrics for client with id %q not found", clientID))
|
||||
|
||||
@ -17,8 +17,8 @@ import (
|
||||
)
|
||||
|
||||
func TestHandleRefreshUpdatesStatus(t *testing.T) {
|
||||
c1 := clients.New(t).Build()
|
||||
c2 := clients.New(t).DisconnectedDuration(5 * time.Minute).Build()
|
||||
c1 := clients.New(t).Logger(testLog).Build()
|
||||
c2 := clients.New(t).DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
|
||||
|
||||
testCases := []struct {
|
||||
Name string
|
||||
@ -29,13 +29,13 @@ func TestHandleRefreshUpdatesStatus(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
Name: "Connected client",
|
||||
ClientID: c1.ID,
|
||||
ClientID: c1.GetID(),
|
||||
ExpectedStatus: http.StatusNoContent,
|
||||
ExpectedRequestName: comm.RequestTypeRefreshUpdatesStatus,
|
||||
},
|
||||
{
|
||||
Name: "Disconnected client",
|
||||
ClientID: c2.ID,
|
||||
ClientID: c2.GetID(),
|
||||
ExpectedStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
@ -45,7 +45,7 @@ func TestHandleRefreshUpdatesStatus(t *testing.T) {
|
||||
},
|
||||
{
|
||||
Name: "SSH error",
|
||||
ClientID: c1.ID,
|
||||
ClientID: c1.GetID(),
|
||||
SSHError: true,
|
||||
ExpectedRequestName: comm.RequestTypeRefreshUpdatesStatus,
|
||||
ExpectedStatus: http.StatusInternalServerError,
|
||||
@ -57,7 +57,7 @@ func TestHandleRefreshUpdatesStatus(t *testing.T) {
|
||||
connMock := test.NewConnMock()
|
||||
// by default set to return success
|
||||
connMock.ReturnOk = !tc.SSHError
|
||||
c1.Connection = connMock
|
||||
c1.SetConnection(connMock)
|
||||
clientService := clients.NewClientService(nil, nil, clients.NewClientRepository([]*clients.Client{c1, c2}, &hour, testLog), testLog)
|
||||
al := APIListener{
|
||||
insecureForTests: true,
|
||||
|
||||
@ -153,15 +153,15 @@ func TestHandlePostScheduleMultiClientJobWithTags(t *testing.T) {
|
||||
connMock2 := makeConnMock(t, 2, time.Date(2020, 10, 10, 10, 10, 2, 0, time.UTC))
|
||||
connMock4 := makeConnMock(t, 4, time.Date(2020, 10, 10, 10, 10, 4, 0, time.UTC))
|
||||
|
||||
c1 := clients.New(t).ID("client-1").Connection(connMock1).Build()
|
||||
c2 := clients.New(t).ID("client-2").Connection(connMock2).Build()
|
||||
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Build()
|
||||
c4 := clients.New(t).ID("client-4").Connection(connMock4).Build()
|
||||
c1 := clients.New(t).ID("client-1").Connection(connMock1).Logger(testLog).Build()
|
||||
c2 := clients.New(t).ID("client-2").Connection(connMock2).Logger(testLog).Build()
|
||||
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
|
||||
c4 := clients.New(t).ID("client-4").Connection(connMock4).Logger(testLog).Build()
|
||||
|
||||
c1.Tags = []string{"linux"}
|
||||
c2.Tags = []string{"windows"}
|
||||
c3.Tags = []string{"mac"}
|
||||
c4.Tags = []string{"linux", "windows"}
|
||||
c1.SetTags([]string{"linux"})
|
||||
c2.SetTags([]string{"windows"})
|
||||
c3.SetTags([]string{"mac"})
|
||||
c4.SetTags([]string{"linux", "windows"})
|
||||
|
||||
g1 := makeClientGroup("group-1", &cgroups.ClientParams{
|
||||
ClientID: &cgroups.ParamValues{"client-1", "client-2"},
|
||||
@ -175,9 +175,9 @@ func TestHandlePostScheduleMultiClientJobWithTags(t *testing.T) {
|
||||
Version: &cgroups.ParamValues{"0.1.1*"},
|
||||
})
|
||||
|
||||
c1.AllowedUserGroups = []string{"group-1"}
|
||||
c2.AllowedUserGroups = []string{"group-1"}
|
||||
c4.AllowedUserGroups = []string{"group-2"}
|
||||
c1.SetAllowedUserGroups([]string{"group-1"})
|
||||
c2.SetAllowedUserGroups([]string{"group-1"})
|
||||
c4.SetAllowedUserGroups([]string{"group-2"})
|
||||
|
||||
clientList := []*clients.Client{c1, c2, c4}
|
||||
|
||||
@ -382,15 +382,15 @@ func TestHandlePostUpdateScheduleMultiClientJobWithTags(t *testing.T) {
|
||||
connMock2 := makeConnMock(t, 2, time.Date(2020, 10, 10, 10, 10, 2, 0, time.UTC))
|
||||
connMock4 := makeConnMock(t, 4, time.Date(2020, 10, 10, 10, 10, 4, 0, time.UTC))
|
||||
|
||||
c1 := clients.New(t).ID("client-1").Connection(connMock1).Build()
|
||||
c2 := clients.New(t).ID("client-2").Connection(connMock2).Build()
|
||||
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Build()
|
||||
c4 := clients.New(t).ID("client-4").Connection(connMock4).Build()
|
||||
c1 := clients.New(t).ID("client-1").Connection(connMock1).Logger(testLog).Build()
|
||||
c2 := clients.New(t).ID("client-2").Connection(connMock2).Logger(testLog).Build()
|
||||
c3 := clients.New(t).ID("client-3").DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
|
||||
c4 := clients.New(t).ID("client-4").Connection(connMock4).Logger(testLog).Build()
|
||||
|
||||
c1.Tags = []string{"linux"}
|
||||
c2.Tags = []string{"windows"}
|
||||
c3.Tags = []string{"mac"}
|
||||
c4.Tags = []string{"linux", "windows"}
|
||||
c1.SetTags([]string{"linux"})
|
||||
c2.SetTags([]string{"windows"})
|
||||
c3.SetTags([]string{"mac"})
|
||||
c4.SetTags([]string{"linux", "windows"})
|
||||
|
||||
g1 := makeClientGroup("group-1", &cgroups.ClientParams{
|
||||
ClientID: &cgroups.ParamValues{"client-1", "client-2"},
|
||||
@ -404,9 +404,9 @@ func TestHandlePostUpdateScheduleMultiClientJobWithTags(t *testing.T) {
|
||||
Version: &cgroups.ParamValues{"0.1.1*"},
|
||||
})
|
||||
|
||||
c1.AllowedUserGroups = []string{"group-1"}
|
||||
c2.AllowedUserGroups = []string{"group-1"}
|
||||
c4.AllowedUserGroups = []string{"group-2"}
|
||||
c1.SetAllowedUserGroups([]string{"group-1"})
|
||||
c2.SetAllowedUserGroups([]string{"group-1"})
|
||||
c4.SetAllowedUserGroups([]string{"group-2"})
|
||||
|
||||
clientList := []*clients.Client{c1, c2, c4}
|
||||
|
||||
|
||||
@ -8,11 +8,7 @@ import (
|
||||
)
|
||||
|
||||
func (al *APIListener) handleGetStatus(w http.ResponseWriter, req *http.Request) {
|
||||
countActive, err := al.clientService.CountActive()
|
||||
if err != nil {
|
||||
al.jsonErrorResponse(w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
countActive := al.clientService.CountActive()
|
||||
|
||||
countDisconnected, err := al.clientService.CountDisconnected()
|
||||
if err != nil {
|
||||
|
||||
@ -28,7 +28,7 @@ func (al *APIListener) handleGetStoredTunnels(w http.ResponseWriter, req *http.R
|
||||
}
|
||||
|
||||
options := query.GetListOptions(req)
|
||||
result, err := al.storedTunnels.List(ctx, options, client.ID)
|
||||
result, err := al.storedTunnels.List(ctx, options, client.GetID())
|
||||
if err != nil {
|
||||
al.jsonError(w, err)
|
||||
return
|
||||
@ -59,7 +59,7 @@ func (al *APIListener) handlePostStoredTunnels(w http.ResponseWriter, req *http.
|
||||
return
|
||||
}
|
||||
|
||||
result, err := al.storedTunnels.Create(ctx, client.ID, storedTunnel)
|
||||
result, err := al.storedTunnels.Create(ctx, client.GetID(), storedTunnel)
|
||||
if err != nil {
|
||||
al.jsonError(w, err)
|
||||
return
|
||||
@ -84,7 +84,7 @@ func (al *APIListener) handleDeleteStoredTunnel(w http.ResponseWriter, req *http
|
||||
return
|
||||
}
|
||||
|
||||
err = al.storedTunnels.Delete(ctx, client.ID, tunnelID)
|
||||
err = al.storedTunnels.Delete(ctx, client.GetID(), tunnelID)
|
||||
if err != nil {
|
||||
al.jsonError(w, err)
|
||||
return
|
||||
@ -117,7 +117,7 @@ func (al *APIListener) handlePutStoredTunnel(w http.ResponseWriter, req *http.Re
|
||||
}
|
||||
storedTunnel.ID = tunnelID
|
||||
|
||||
result, err := al.storedTunnels.Update(ctx, client.ID, storedTunnel)
|
||||
result, err := al.storedTunnels.Update(ctx, client.GetID(), storedTunnel)
|
||||
if err != nil {
|
||||
al.jsonError(w, err)
|
||||
return
|
||||
|
||||
@ -44,12 +44,13 @@ func (al *APIListener) handleGetTunnels(w http.ResponseWriter, req *http.Request
|
||||
|
||||
tunnels := make([]TunnelPayload, 0)
|
||||
for _, c := range clients {
|
||||
if c.DisconnectedAt != nil {
|
||||
clientID := c.GetID()
|
||||
if !c.IsConnected() {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, t := range c.Tunnels {
|
||||
tunnels = append(tunnels, convertToTunnelPayload(t, c.ID))
|
||||
for _, t := range c.GetTunnels() {
|
||||
tunnels = append(tunnels, convertToTunnelPayload(t, clientID))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -84,8 +84,9 @@ func (al *APIListener) getOrderedClients(
|
||||
|
||||
// append group clients
|
||||
for _, groupClient := range groupClients {
|
||||
if !usedClientIDs[groupClient.ID] {
|
||||
usedClientIDs[groupClient.ID] = true
|
||||
groupClientID := groupClient.GetID()
|
||||
if !usedClientIDs[groupClientID] {
|
||||
usedClientIDs[groupClientID] = true
|
||||
orderedClients = append(orderedClients, groupClient)
|
||||
}
|
||||
}
|
||||
|
||||
@ -3,6 +3,7 @@ package chserver
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
@ -110,6 +111,13 @@ func (al *APIListener) handleCommandsExecutionWS(
|
||||
}
|
||||
|
||||
for _, client := range inboundMsg.OrderedClients {
|
||||
if client.IsPaused() {
|
||||
msg := fmt.Sprintf("failed to execute command/script for client with id %s", client.GetID())
|
||||
err := fmt.Errorf("client is paused (reason = %s)", client.GetPausedReason())
|
||||
uiConnTS.WriteError(msg, err)
|
||||
continue
|
||||
}
|
||||
|
||||
curJID, err := generateNewJobID()
|
||||
if err != nil {
|
||||
uiConnTS.WriteError("Could not generate job id.", err)
|
||||
@ -168,7 +176,14 @@ func (al *APIListener) handleCommandsExecutionWS(
|
||||
} else {
|
||||
client := inboundMsg.OrderedClients[0]
|
||||
|
||||
al.createAndRunJob( //nolint:errcheck // error is logged, nothing to act on here
|
||||
if client.IsPaused() {
|
||||
msg := fmt.Sprintf("failed to execute command/script for client with id %s", client.GetID())
|
||||
err := fmt.Errorf("client is paused (reason = %s)", client.GetPausedReason())
|
||||
uiConnTS.WriteError(msg, err)
|
||||
return
|
||||
}
|
||||
|
||||
al.createAndRunJob(
|
||||
uiConnTS,
|
||||
nil,
|
||||
jid,
|
||||
|
||||
@ -49,8 +49,8 @@ func (al *APIListener) createAndRunJob(
|
||||
curJob := models.Job{
|
||||
JID: jid,
|
||||
StartedAt: time.Now(),
|
||||
ClientID: client.ID,
|
||||
ClientName: client.Name,
|
||||
ClientID: client.GetID(),
|
||||
ClientName: client.GetName(),
|
||||
Command: cmd,
|
||||
Cwd: cwd,
|
||||
IsSudo: isSudo,
|
||||
@ -69,13 +69,14 @@ func (al *APIListener) createAndRunJob(
|
||||
var err error
|
||||
if !client.IsPaused() {
|
||||
if client.Connection != nil {
|
||||
err = comm.SendRequestAndGetResponse(client.Connection, comm.RequestTypeRunCmd, curJob, sshResp)
|
||||
err = comm.SendRequestAndGetResponse(client.GetConnection(), comm.RequestTypeRunCmd, curJob, sshResp, al.Log())
|
||||
} else {
|
||||
err = ErrClientNotConnected
|
||||
}
|
||||
} else {
|
||||
err = fmt.Errorf("client is paused (reason = %s)", client.PausedReason)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
al.Errorf("%s, Error on execute remote command: %v", logPrefix, err)
|
||||
|
||||
|
||||
@ -10,6 +10,7 @@ import (
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
@ -72,6 +73,15 @@ type APIListener struct {
|
||||
tokenManager *authorization.Manager
|
||||
commandManager *command.Manager
|
||||
storedTunnels *storedtunnels.Manager
|
||||
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func (al *APIListener) Log() (l *logger.Logger) {
|
||||
al.mu.RLock()
|
||||
defer al.mu.RUnlock()
|
||||
|
||||
return al.Logger
|
||||
}
|
||||
|
||||
type UserService interface {
|
||||
|
||||
@ -81,8 +81,8 @@ func (e *Entry) WithClient(c *clients.Client) *Entry {
|
||||
return e
|
||||
}
|
||||
|
||||
e.ClientID = c.ID
|
||||
e.ClientHostName = c.Hostname
|
||||
e.ClientID = c.GetID()
|
||||
e.ClientHostName = c.GetHostname()
|
||||
return e
|
||||
}
|
||||
|
||||
@ -99,7 +99,7 @@ func (e *Entry) WithClientID(cid string) *Entry {
|
||||
return e
|
||||
}
|
||||
if client != nil {
|
||||
e.ClientHostName = client.Hostname
|
||||
e.ClientHostName = client.GetHostname()
|
||||
}
|
||||
|
||||
return e
|
||||
|
||||
@ -84,11 +84,12 @@ func TestWithResponse(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWithClient(t *testing.T) {
|
||||
e := emptyEntry().WithClient(&clients.Client{
|
||||
ID: "11236310-6cad-408e-b372-a0f04d68d2df",
|
||||
Address: "127.0.0.1",
|
||||
Hostname: "hostname",
|
||||
})
|
||||
c1 := clients.Client{}
|
||||
c1.SetID("11236310-6cad-408e-b372-a0f04d68d2df")
|
||||
c1.SetAddress("127.0.0.1")
|
||||
c1.SetHostname("hostname")
|
||||
|
||||
e := emptyEntry().WithClient(&c1)
|
||||
|
||||
assert.Equal(t, "11236310-6cad-408e-b372-a0f04d68d2df", e.ClientID)
|
||||
assert.Equal(t, "hostname", e.ClientHostName)
|
||||
@ -136,18 +137,17 @@ func TestSaveForMultipleClients(t *testing.T) {
|
||||
auditLog := enabledAuditLog()
|
||||
auditLog.provider = mockProvider
|
||||
|
||||
auditLog.Entry("", "").SaveForMultipleClients([]*clients.Client{
|
||||
{
|
||||
ID: "c1",
|
||||
Address: "c1.com",
|
||||
Hostname: "hostname1",
|
||||
},
|
||||
{
|
||||
ID: "c2",
|
||||
Address: "c2.com",
|
||||
Hostname: "hostname2",
|
||||
},
|
||||
})
|
||||
c1 := clients.Client{}
|
||||
c1.SetID("c1")
|
||||
c1.SetAddress("c1.com")
|
||||
c1.SetHostname("hostname1")
|
||||
|
||||
c2 := clients.Client{}
|
||||
c2.SetID("c2")
|
||||
c2.SetAddress("c2.com")
|
||||
c2.SetHostname("hostname2")
|
||||
|
||||
auditLog.Entry("", "").SaveForMultipleClients([]*clients.Client{&c1, &c2})
|
||||
|
||||
assert.Len(t, mockProvider.entries, 2)
|
||||
assert.Equal(t, "c1", mockProvider.entries[0].ClientID)
|
||||
@ -172,12 +172,13 @@ type mockClientGetter struct {
|
||||
}
|
||||
|
||||
func (mockClientGetter) GetByID(id string) (*clients.Client, error) {
|
||||
c1 := clients.Client{}
|
||||
c1.SetID("11236310-6cad-408e-b372-a0f04d68d2df")
|
||||
c1.SetAddress("127.0.0.1")
|
||||
c1.SetHostname("hostname")
|
||||
|
||||
if id == "11236310-6cad-408e-b372-a0f04d68d2df" {
|
||||
return &clients.Client{
|
||||
ID: "11236310-6cad-408e-b372-a0f04d68d2df",
|
||||
Address: "127.0.0.1",
|
||||
Hostname: "hostname",
|
||||
}, nil
|
||||
return &c1, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@ -1,8 +1,12 @@
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package caddy_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -86,6 +90,9 @@ func setupNewCaddyServer(ctx context.Context, t *testing.T) (cs *caddy.Server) {
|
||||
require.NoError(t, err)
|
||||
caddy.HostDomainSocket = bc.GlobalSettings.AdminSocket
|
||||
|
||||
// ensure the no existing admin socket file
|
||||
os.Remove(caddy.HostDomainSocket)
|
||||
|
||||
cs = caddy.NewCaddyServer(cfg, testLog)
|
||||
err = cs.Start(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -15,6 +15,7 @@ import (
|
||||
"path"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@ -116,39 +117,40 @@ type LogConfig struct {
|
||||
}
|
||||
|
||||
type ServerConfig struct {
|
||||
ListenAddress string `mapstructure:"address"`
|
||||
URL []string `mapstructure:"url"`
|
||||
PairingURL string `mapstructure:"pairing_url"`
|
||||
TunnelHost string `mapstructure:"tunnel_host"`
|
||||
KeySeed string `mapstructure:"key_seed"`
|
||||
Auth string `mapstructure:"auth"`
|
||||
AuthFile string `mapstructure:"auth_file"`
|
||||
AuthTable string `mapstructure:"auth_table"`
|
||||
Proxy string `mapstructure:"proxy"`
|
||||
UsedPortsRaw []string `mapstructure:"used_ports"`
|
||||
ExcludedPortsRaw []string `mapstructure:"excluded_ports"`
|
||||
DataDir string `mapstructure:"data_dir"`
|
||||
SqliteWAL bool `mapstructure:"sqlite_wal"`
|
||||
PurgeDisconnectedClients bool `mapstructure:"purge_disconnected_clients"`
|
||||
CleanupLostClients bool `mapstructure:"cleanup_lost_clients" replaced_by:"PurgeDisconnectedClients"`
|
||||
KeepLostClients time.Duration `mapstructure:"keep_lost_clients" replaced_by:"KeepDisconnectedClients"`
|
||||
KeepDisconnectedClients time.Duration `mapstructure:"keep_disconnected_clients"`
|
||||
CleanupClientsInterval time.Duration `mapstructure:"cleanup_clients_interval" replaced_by:"PurgeDisconnectedClientsInterval"`
|
||||
PurgeDisconnectedClientsInterval time.Duration `mapstructure:"purge_disconnected_clients_interval"`
|
||||
CheckClientsConnectionInterval time.Duration `mapstructure:"check_clients_connection_interval"`
|
||||
CheckClientsConnectionTimeout time.Duration `mapstructure:"check_clients_connection_timeout"`
|
||||
MaxRequestBytesClient int64 `mapstructure:"max_request_bytes_client"`
|
||||
CheckPortTimeout time.Duration `mapstructure:"check_port_timeout"`
|
||||
RunRemoteCmdTimeoutSec int `mapstructure:"run_remote_cmd_timeout_sec"`
|
||||
AuthWrite bool `mapstructure:"auth_write"`
|
||||
AuthMultiuseCreds bool `mapstructure:"auth_multiuse_creds"`
|
||||
EquateClientauthidClientid bool `mapstructure:"equate_clientauthid_clientid"`
|
||||
AllowRoot bool `mapstructure:"allow_root"`
|
||||
ClientLoginWait float32 `mapstructure:"client_login_wait"`
|
||||
MaxFailedLogin int `mapstructure:"max_failed_login"`
|
||||
BanTime int `mapstructure:"ban_time"`
|
||||
InternalTunnelProxyConfig clienttunnel.InternalTunnelProxyConfig `mapstructure:",squash"`
|
||||
JobsMaxResults int `mapstructure:"jobs_max_results"`
|
||||
ListenAddress string `mapstructure:"address"`
|
||||
URL []string `mapstructure:"url"`
|
||||
PairingURL string `mapstructure:"pairing_url"`
|
||||
TunnelHost string `mapstructure:"tunnel_host"`
|
||||
KeySeed string `mapstructure:"key_seed"`
|
||||
Auth string `mapstructure:"auth"`
|
||||
AuthFile string `mapstructure:"auth_file"`
|
||||
AuthTable string `mapstructure:"auth_table"`
|
||||
Proxy string `mapstructure:"proxy"`
|
||||
UsedPortsRaw []string `mapstructure:"used_ports"`
|
||||
ExcludedPortsRaw []string `mapstructure:"excluded_ports"`
|
||||
DataDir string `mapstructure:"data_dir"`
|
||||
SqliteWAL bool `mapstructure:"sqlite_wal"`
|
||||
MaxConcurrentSSHConnectionHandshakes int `mapstructure:"max_concurrent_ssh_handshakes"`
|
||||
PurgeDisconnectedClients bool `mapstructure:"purge_disconnected_clients"`
|
||||
CleanupLostClients bool `mapstructure:"cleanup_lost_clients" replaced_by:"PurgeDisconnectedClients"`
|
||||
KeepLostClients time.Duration `mapstructure:"keep_lost_clients" replaced_by:"KeepDisconnectedClients"`
|
||||
KeepDisconnectedClients time.Duration `mapstructure:"keep_disconnected_clients"`
|
||||
CleanupClientsInterval time.Duration `mapstructure:"cleanup_clients_interval" replaced_by:"PurgeDisconnectedClientsInterval"`
|
||||
PurgeDisconnectedClientsInterval time.Duration `mapstructure:"purge_disconnected_clients_interval"`
|
||||
CheckClientsConnectionInterval time.Duration `mapstructure:"check_clients_connection_interval"`
|
||||
CheckClientsConnectionTimeout time.Duration `mapstructure:"check_clients_connection_timeout"`
|
||||
MaxRequestBytesClient int64 `mapstructure:"max_request_bytes_client"`
|
||||
CheckPortTimeout time.Duration `mapstructure:"check_port_timeout"`
|
||||
RunRemoteCmdTimeoutSec int `mapstructure:"run_remote_cmd_timeout_sec"`
|
||||
AuthWrite bool `mapstructure:"auth_write"`
|
||||
AuthMultiuseCreds bool `mapstructure:"auth_multiuse_creds"`
|
||||
EquateClientauthidClientid bool `mapstructure:"equate_clientauthid_clientid"`
|
||||
AllowRoot bool `mapstructure:"allow_root"`
|
||||
ClientLoginWait float32 `mapstructure:"client_login_wait"`
|
||||
MaxFailedLogin int `mapstructure:"max_failed_login"`
|
||||
BanTime int `mapstructure:"ban_time"`
|
||||
InternalTunnelProxyConfig clienttunnel.InternalTunnelProxyConfig `mapstructure:",squash"`
|
||||
JobsMaxResults int `mapstructure:"jobs_max_results"`
|
||||
|
||||
// DEPRECATED, only here for backwards compatibility
|
||||
MaxRequestBytes int64 `mapstructure:"max_request_bytes"`
|
||||
@ -379,6 +381,11 @@ func (c *Config) ParseAndValidate(mLog *logger.MemLogger) error {
|
||||
return err
|
||||
}
|
||||
|
||||
maxProcs := runtime.GOMAXPROCS(0)
|
||||
if c.Server.MaxConcurrentSSHConnectionHandshakes > (maxProcs * 2) {
|
||||
mLog.Infof("warning: allowing too many concurrent ssh handhakes ('max_concurrent_ssh_handshakes') will slow down the server significantly. Please use a value less than or equal to the MAX_PROCS (%d)", maxProcs)
|
||||
}
|
||||
|
||||
if c.Server.CheckClientsConnectionInterval < CheckClientsConnectionIntervalMinimum {
|
||||
c.Server.CheckClientsConnectionInterval = CheckClientsConnectionIntervalMinimum
|
||||
mLog.Errorf("'check_clients_status_interval' too fast. Using the minimum possible of %s", CheckClientsConnectionIntervalMinimum)
|
||||
|
||||
@ -609,7 +609,7 @@ func TestParseAndValidatePorts(t *testing.T) {
|
||||
UsedPortsRaw: []string{"45-50"},
|
||||
ExcludedPortsRaw: []string{"1-10", "44", "51", "80-90"},
|
||||
},
|
||||
ExpectedAllowedPorts: mapset.NewThreadUnsafeSetFromSlice([]interface{}{45, 46, 47, 48, 49, 50}),
|
||||
ExpectedAllowedPorts: mapset.NewSetFromSlice([]interface{}{45, 46, 47, 48, 49, 50}),
|
||||
},
|
||||
{
|
||||
Name: "used ports and excluded ports",
|
||||
@ -617,7 +617,7 @@ func TestParseAndValidatePorts(t *testing.T) {
|
||||
UsedPortsRaw: []string{"100-200", "205", "250-300", "305", "400-500"},
|
||||
ExcludedPortsRaw: []string{"80-110", "114-116", "118", "120-198", "200", "240-310", "305", "401-499"},
|
||||
},
|
||||
ExpectedAllowedPorts: mapset.NewThreadUnsafeSetFromSlice([]interface{}{111, 112, 113, 117, 119, 199, 205, 400, 500}),
|
||||
ExpectedAllowedPorts: mapset.NewSetFromSlice([]interface{}{111, 112, 113, 117, 119, 199, 205, 400, 500}),
|
||||
},
|
||||
{
|
||||
Name: "excluded ports empty",
|
||||
@ -625,7 +625,7 @@ func TestParseAndValidatePorts(t *testing.T) {
|
||||
UsedPortsRaw: []string{"45-46"},
|
||||
ExcludedPortsRaw: []string{},
|
||||
},
|
||||
ExpectedAllowedPorts: mapset.NewThreadUnsafeSetFromSlice([]interface{}{45, 46}),
|
||||
ExpectedAllowedPorts: mapset.NewSetFromSlice([]interface{}{45, 46}),
|
||||
},
|
||||
{
|
||||
Name: "one allowed port",
|
||||
@ -633,7 +633,7 @@ func TestParseAndValidatePorts(t *testing.T) {
|
||||
UsedPortsRaw: []string{"20000"},
|
||||
ExcludedPortsRaw: []string{},
|
||||
},
|
||||
ExpectedAllowedPorts: mapset.NewThreadUnsafeSetFromSlice([]interface{}{20000}),
|
||||
ExpectedAllowedPorts: mapset.NewSetFromSlice([]interface{}{20000}),
|
||||
},
|
||||
{
|
||||
Name: "both empty",
|
||||
|
||||
@ -12,6 +12,7 @@ import (
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@ -32,9 +33,17 @@ import (
|
||||
"github.com/cloudradar-monitoring/rport/share/security"
|
||||
)
|
||||
|
||||
const ConnectionRequestTimeOut = 5 * 60 * time.Second
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
}
|
||||
|
||||
type ClientListener struct {
|
||||
*logger.Logger
|
||||
*Server
|
||||
logger *logger.Logger
|
||||
server *Server
|
||||
|
||||
connStats chshare.ConnStats
|
||||
httpServer *chshare.HTTPServer
|
||||
@ -45,33 +54,46 @@ type ClientListener struct {
|
||||
bannedIPs *security.MaxBadAttemptsBanList
|
||||
|
||||
clientIndexAutoIncrement int32
|
||||
|
||||
// semaphore used to limit concurrent pending SSH connections
|
||||
inprogressSSHHandshakes chan struct{}
|
||||
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
const SSHTimeOut = 90 * time.Second
|
||||
func (cl *ClientListener) Log() (l *logger.Logger) {
|
||||
cl.mu.RLock()
|
||||
defer cl.mu.RUnlock()
|
||||
return cl.logger
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool { return true },
|
||||
func (cl *ClientListener) GetClientService() (cs clients.ClientService) {
|
||||
cl.mu.RLock()
|
||||
defer cl.mu.RUnlock()
|
||||
return cl.server.clientService
|
||||
}
|
||||
|
||||
func NewClientListener(server *Server, privateKey ssh.Signer) (*ClientListener, error) {
|
||||
config := server.config
|
||||
|
||||
// semaphore to limit number of active pending SSH connections
|
||||
inprogressSSHHandshakes := make(chan struct{}, config.Server.MaxConcurrentSSHConnectionHandshakes)
|
||||
|
||||
clog := logger.NewLogger("client-listener", config.Logging.LogOutput, config.Logging.LogLevel)
|
||||
cl := &ClientListener{
|
||||
Server: server,
|
||||
httpServer: chshare.NewHTTPServer(int(config.Server.MaxRequestBytesClient), clog),
|
||||
Logger: clog,
|
||||
requestLogOptions: config.InitRequestLogOptions(),
|
||||
bannedClientAuths: security.NewBanList(time.Duration(config.Server.ClientLoginWait) * time.Second),
|
||||
server: server,
|
||||
httpServer: chshare.NewHTTPServer(int(config.Server.MaxRequestBytesClient), clog),
|
||||
requestLogOptions: config.InitRequestLogOptions(),
|
||||
bannedClientAuths: security.NewBanList(time.Duration(config.Server.ClientLoginWait) * time.Second),
|
||||
inprogressSSHHandshakes: inprogressSSHHandshakes,
|
||||
logger: clog,
|
||||
}
|
||||
|
||||
if config.Server.MaxFailedLogin > 0 && config.Server.BanTime > 0 {
|
||||
cl.bannedIPs = security.NewMaxBadAttemptsBanList(
|
||||
config.Server.MaxFailedLogin,
|
||||
time.Duration(config.Server.BanTime)*time.Second,
|
||||
cl.Logger,
|
||||
cl.logger,
|
||||
)
|
||||
}
|
||||
|
||||
@ -80,7 +102,9 @@ func NewClientListener(server *Server, privateKey ssh.Signer) (*ClientListener,
|
||||
ServerVersion: "SSH-" + chshare.ProtocolVersion + "-server",
|
||||
PasswordCallback: cl.authUser,
|
||||
}
|
||||
|
||||
cl.sshConfig.AddHostKey(privateKey)
|
||||
|
||||
//setup reverse proxy
|
||||
if config.Server.Proxy != "" {
|
||||
u, err := url.Parse(config.Server.Proxy)
|
||||
@ -107,15 +131,15 @@ func (cl *ClientListener) authUser(c ssh.ConnMetadata, password []byte) (*ssh.Pe
|
||||
clientAuthID := c.User()
|
||||
|
||||
if cl.bannedClientAuths.IsBanned(clientAuthID) {
|
||||
cl.Infof("Failed login attempt for client auth id %q, forcing to wait for %vs (%s)",
|
||||
cl.Log().Infof("Failed login attempt for client auth id %q, forcing to wait for %vs (%s)",
|
||||
clientAuthID,
|
||||
cl.config.Server.ClientLoginWait,
|
||||
cl.server.config.Server.ClientLoginWait,
|
||||
cl.getIP(c.RemoteAddr()),
|
||||
)
|
||||
return nil, ErrTooManyRequests
|
||||
}
|
||||
|
||||
clientAuth, err := cl.clientAuthProvider.Get(clientAuthID)
|
||||
clientAuth, err := cl.server.clientAuthProvider.Get(clientAuthID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -123,7 +147,7 @@ func (cl *ClientListener) authUser(c ssh.ConnMetadata, password []byte) (*ssh.Pe
|
||||
ip := cl.getIP(c.RemoteAddr())
|
||||
// constant time compare is used for security reasons
|
||||
if clientAuth == nil || subtle.ConstantTimeCompare([]byte(clientAuth.Password), password) != 1 {
|
||||
cl.Debugf("Login failed for client auth id: %s", clientAuthID)
|
||||
cl.Log().Debugf("Login failed for client auth id: %s", clientAuthID)
|
||||
cl.bannedClientAuths.Add(clientAuthID)
|
||||
if cl.bannedIPs != nil {
|
||||
cl.bannedIPs.AddBadAttempt(ip)
|
||||
@ -141,20 +165,21 @@ func (cl *ClientListener) getIP(addr net.Addr) string {
|
||||
addrStr := addr.String()
|
||||
host, _, err := net.SplitHostPort(addrStr)
|
||||
if err != nil {
|
||||
cl.Errorf("failed to split host port for %q: %v", addr, err)
|
||||
cl.Log().Errorf("failed to split host port for %q: %v", addr, err)
|
||||
return addrStr
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
func (cl *ClientListener) Start(ctx context.Context, listenAddr string) error {
|
||||
cl.Debugf("Client listener starting...")
|
||||
clLogger := cl.Log()
|
||||
clLogger.Debugf("Client listener starting...")
|
||||
if cl.reverseProxy != nil {
|
||||
cl.Infof("Reverse proxy enabled")
|
||||
clLogger.Infof("Reverse proxy enabled")
|
||||
}
|
||||
cl.Infof("Listening on %s...", listenAddr)
|
||||
clLogger.Infof("Listening on %s...", listenAddr)
|
||||
|
||||
h := http.Handler(middleware.MaxBytes(http.HandlerFunc(cl.handleClient), cl.config.Server.MaxRequestBytesClient))
|
||||
h := http.Handler(middleware.MaxBytes(http.HandlerFunc(cl.handleClient), cl.server.config.Server.MaxRequestBytesClient))
|
||||
if cl.bannedIPs != nil {
|
||||
h = security.RejectBannedIPs(cl.bannedIPs)(h)
|
||||
}
|
||||
@ -174,7 +199,7 @@ func (cl *ClientListener) Close() error {
|
||||
}
|
||||
|
||||
func (cl *ClientListener) handleClient(w http.ResponseWriter, r *http.Request) {
|
||||
cl.Debugf("Incoming client connection...")
|
||||
cl.Log().Debugf("Incoming client connection...")
|
||||
//websockets upgrade AND has rport prefix
|
||||
upgrade := strings.ToLower(r.Header.Get("Upgrade"))
|
||||
protocol := r.Header.Get("Sec-WebSocket-Protocol")
|
||||
@ -184,7 +209,7 @@ func (cl *ClientListener) handleClient(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
//print into server logs and silently fall-through
|
||||
cl.Infof("ignored client connection using protocol '%s', expected '%s'",
|
||||
cl.Log().Infof("ignored client connection using protocol '%s', expected '%s'",
|
||||
protocol, chshare.ProtocolVersion)
|
||||
}
|
||||
//proxy target was provided
|
||||
@ -201,55 +226,90 @@ func (cl *ClientListener) nextClientIndex() int32 {
|
||||
return atomic.AddInt32(&cl.clientIndexAutoIncrement, 1)
|
||||
}
|
||||
|
||||
// handleWebsocket is responsible for handling the websocket connection
|
||||
func (cl *ClientListener) handleWebsocket(w http.ResponseWriter, req *http.Request) {
|
||||
ts := time.Now()
|
||||
clog := cl.Fork("client#%d", cl.nextClientIndex())
|
||||
func (cl *ClientListener) acceptSSHConnection(w http.ResponseWriter, req *http.Request) (sshConn *ssh.ServerConn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request, clog *logger.Logger, err error) {
|
||||
// add to pending connections. will block if the chan is full
|
||||
cl.inprogressSSHHandshakes <- struct{}{}
|
||||
|
||||
clog = cl.Log().Fork("client#%d", cl.nextClientIndex())
|
||||
|
||||
clog.Debugf("Handling inbound web socket connection...")
|
||||
ts := time.Now()
|
||||
|
||||
wsConn, err := upgrader.Upgrade(w, req, nil)
|
||||
if err != nil {
|
||||
clog.Debugf("Failed to upgrade (%s)", err)
|
||||
return
|
||||
<-cl.inprogressSSHHandshakes
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
conn := chshare.NewWebSocketConn(wsConn)
|
||||
// perform SSH handshake on net.Conn
|
||||
clog.Debugf("SSH Handshaking...")
|
||||
sshConn, chans, reqs, err := ssh.NewServerConn(conn, cl.sshConfig)
|
||||
sshConn, chans, reqs, err = ssh.NewServerConn(conn, cl.sshConfig)
|
||||
if err != nil {
|
||||
cl.Debugf("Failed to handshake (%s) from %s", err, conn.RemoteAddr().String())
|
||||
return
|
||||
if strings.Contains(err.Error(), "unexpected EOF") {
|
||||
clog.Debugf("Failed to handshake (client closed connection? - %s) from %s", err, conn.RemoteAddr().String())
|
||||
} else {
|
||||
clog.Debugf("Failed to handshake (%s) from %s", err, conn.RemoteAddr().String())
|
||||
}
|
||||
<-cl.inprogressSSHHandshakes
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
clog.Debugf("Handshake finished after %s", time.Since(ts))
|
||||
clog.Debugf("SSH Handshake finished after %s", time.Since(ts))
|
||||
|
||||
//verify configuration
|
||||
clog.Debugf("Verifying configuration...")
|
||||
//wait for request, with timeout
|
||||
var r *ssh.Request
|
||||
// on handshake finished, remove from pending connections, which will allow another connection to take place
|
||||
<-cl.inprogressSSHHandshakes
|
||||
|
||||
return sshConn, chans, reqs, clog, err
|
||||
}
|
||||
|
||||
func (cl *ClientListener) receiveClientConnectionRequest(sshConn *ssh.ServerConn, reqs <-chan *ssh.Request, clog *logger.Logger) (connRequest *chshare.ConnectionRequest, r *ssh.Request, err error) {
|
||||
pendingRequestTimer := time.NewTimer(ConnectionRequestTimeOut)
|
||||
select {
|
||||
case r = <-reqs:
|
||||
case <-time.After(SSHTimeOut):
|
||||
clog.Debugf("SSH connection timeout exceeded %s sec", SSHTimeOut.Seconds())
|
||||
pendingRequestTimer.Stop()
|
||||
|
||||
case <-time.After(ConnectionRequestTimeOut):
|
||||
clog.Debugf("SSH connection request timeout exceeded %s sec", ConnectionRequestTimeOut.Seconds())
|
||||
err = sshConn.Close()
|
||||
if err != nil {
|
||||
clog.Debugf("error on SSH connection close: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
failed := func(err error) {
|
||||
clog.Debugf("Failed: %s", err)
|
||||
cl.replyConnectionError(r, err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if r.Type != "new_connection" {
|
||||
failed(errors.New("expecting connection request"))
|
||||
return
|
||||
return nil, nil, errors.New("expecting connection request")
|
||||
}
|
||||
if len(r.Payload) > int(cl.config.Server.MaxRequestBytesClient) {
|
||||
failed(fmt.Errorf("request data exceeds the limit of %d bytes, actual size: %d", cl.config.Server.MaxRequestBytesClient, len(r.Payload)))
|
||||
return
|
||||
|
||||
if len(r.Payload) > int(cl.server.config.Server.MaxRequestBytesClient) {
|
||||
return nil, nil, fmt.Errorf("request data exceeds the limit of %d bytes, actual size: %d", cl.server.config.Server.MaxRequestBytesClient, len(r.Payload))
|
||||
}
|
||||
connRequest, err := chshare.DecodeConnectionRequest(r.Payload)
|
||||
|
||||
connRequest, err = chshare.DecodeConnectionRequest(r.Payload)
|
||||
if err != nil {
|
||||
failed(fmt.Errorf("invalid connection request: %s", err))
|
||||
return nil, nil, fmt.Errorf("invalid connection request: %s", err)
|
||||
}
|
||||
|
||||
return connRequest, r, nil
|
||||
}
|
||||
|
||||
// handleWebsocket is responsible for handling the websocket connection
|
||||
func (cl *ClientListener) handleWebsocket(w http.ResponseWriter, req *http.Request) {
|
||||
// keep the time from the initial client connection attempt
|
||||
ts1 := time.Now()
|
||||
|
||||
sshConn, chans, reqs, clog, err := cl.acceptSSHConnection(w, req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// verify configuration
|
||||
clog.Debugf("Verifying configuration...")
|
||||
|
||||
// first request to be received must be a connection request
|
||||
connRequest, r, err := cl.receiveClientConnectionRequest(sshConn, reqs, clog)
|
||||
if err != nil {
|
||||
cl.replyConnectionError(r, err)
|
||||
return
|
||||
}
|
||||
|
||||
@ -260,40 +320,46 @@ func (cl *ClientListener) handleWebsocket(w http.ResponseWriter, req *http.Reque
|
||||
// get the current client auth id
|
||||
clientAuthID := sshConn.User()
|
||||
|
||||
// client id
|
||||
cid, err := cl.getCID(connRequest.ID, cl.config, clientAuthID)
|
||||
clientID, err := cl.getClientID(connRequest.ID, cl.server.config, clientAuthID)
|
||||
if err != nil {
|
||||
failed(fmt.Errorf("could not get cid: %s", err))
|
||||
cl.replyConnectionError(r, fmt.Errorf("could not get clientID: %s", err))
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: (rs): shouldn't this ctx be based on the server start ctx?
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ts = time.Now()
|
||||
client, err := cl.clientService.StartClient(ctx, clientAuthID, cid, sshConn, cl.config.Server.AuthMultiuseCreds, connRequest, clog)
|
||||
client, err := cl.GetClientService().StartClient(ctx, clientAuthID, clientID, sshConn, cl.server.config.Server.AuthMultiuseCreds, connRequest, clog)
|
||||
if err != nil {
|
||||
failed(err)
|
||||
cl.replyConnectionError(r, err)
|
||||
return
|
||||
}
|
||||
clog.Debugf("Client service started for %s (%s) within %s", client.ID, client.Name, time.Since(ts))
|
||||
clog.Debugf("Client service started for %s (%s) within %s", client.GetID(), client.GetName(), time.Since(ts1))
|
||||
|
||||
ts2 := time.Now()
|
||||
|
||||
cl.replyConnectionSuccess(r, connRequest.Remotes)
|
||||
cl.sendCapabilities(sshConn)
|
||||
// Now the client is fully connected and ready to create tunnels and execute command and scripts
|
||||
|
||||
clientBanner := client.Banner()
|
||||
clog.Debugf("opened %s within %s", clientBanner, time.Since(ts))
|
||||
go cl.handleSSHRequests(clog, cid, reqs)
|
||||
clog.Debugf("opened %s within %s", clientBanner, time.Since(ts2))
|
||||
|
||||
// now run handler for other client requests and connections
|
||||
// TODO: (rs): shouldn't these also use the server start ctx
|
||||
go cl.handleSSHRequests(clog, clientID, reqs)
|
||||
go cl.handleSSHChannels(clog, chans)
|
||||
|
||||
// wait until we're disconnected from the client
|
||||
if err = sshConn.Wait(); err != nil {
|
||||
clog.Debugf("sshConn.Wait() error: %s", err)
|
||||
}
|
||||
clog.Debugf("close %s", clientBanner)
|
||||
|
||||
err = cl.clientService.Terminate(client)
|
||||
err = cl.GetClientService().Terminate(client)
|
||||
if err != nil {
|
||||
cl.Errorf("could not terminate client: %s", err)
|
||||
cl.Log().Errorf("could not terminate client: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -311,7 +377,7 @@ func checkVersions(log *logger.Logger, clientVersion string) {
|
||||
log.Infof("Client version (%s) differs from server version (%s)", v, chshare.BuildVersion)
|
||||
}
|
||||
|
||||
func (cl *ClientListener) getCID(reqID string, config *chconfig.Config, clientAuthID string) (string, error) {
|
||||
func (cl *ClientListener) getClientID(reqID string, config *chconfig.Config, clientAuthID string) (string, error) {
|
||||
if reqID != "" {
|
||||
return reqID, nil
|
||||
}
|
||||
@ -327,7 +393,7 @@ func (cl *ClientListener) getCID(reqID string, config *chconfig.Config, clientAu
|
||||
func (cl *ClientListener) replyConnectionSuccess(r *ssh.Request, remotes []*models.Remote) {
|
||||
replyPayload, err := json.Marshal(remotes)
|
||||
if err != nil {
|
||||
cl.Errorf("can't encode success reply payload")
|
||||
cl.Log().Errorf("can't encode success reply payload")
|
||||
cl.replyConnectionError(r, err)
|
||||
return
|
||||
}
|
||||
@ -336,24 +402,68 @@ func (cl *ClientListener) replyConnectionSuccess(r *ssh.Request, remotes []*mode
|
||||
}
|
||||
|
||||
func (cl *ClientListener) replyConnectionError(r *ssh.Request, err error) {
|
||||
if r == nil {
|
||||
cl.Log().Errorf("failed to send connection reply with error due to nil request")
|
||||
return
|
||||
}
|
||||
if err == nil {
|
||||
cl.Log().Debugf("sending connection reply with nil error: %s", r.Type)
|
||||
_ = r.Reply(false, nil)
|
||||
return
|
||||
}
|
||||
_ = r.Reply(false, []byte(err.Error()))
|
||||
}
|
||||
|
||||
func (cl *ClientListener) handleSSHRequests(clientLog *logger.Logger, clientID string, reqs <-chan *ssh.Request) {
|
||||
clientService := cl.GetClientService()
|
||||
|
||||
for r := range reqs {
|
||||
if len(r.Payload) > int(cl.config.Server.MaxRequestBytesClient) {
|
||||
clientLog.Errorf("%s:request data exceeds the limit of %d bytes, actual size: %d", comm.RequestTypeSaveMeasurement, cl.config.Server.MaxRequestBytesClient, len(r.Payload))
|
||||
if len(r.Payload) > int(cl.server.config.Server.MaxRequestBytesClient) {
|
||||
clientLog.Errorf("%s:request data exceeds the limit of %d bytes, actual size: %d", comm.RequestTypeSaveMeasurement, cl.server.config.Server.MaxRequestBytesClient, len(r.Payload))
|
||||
continue
|
||||
}
|
||||
|
||||
// clientLog.Debugf("received request: %s from %s", r.Type, clientID)
|
||||
|
||||
// TODO: (rs): these case handlers should be refactored into individual handling fns
|
||||
switch r.Type {
|
||||
|
||||
// we shouldn't be receiving this. it means the client didn't receive the server's reply
|
||||
// to the previous connection request, so ask the client to reconnect.
|
||||
case "new_connection":
|
||||
clientLog.Debugf("received connection request on existing connection. asking the client to reconnect.")
|
||||
// IMPORTANT: the client is checking for the word "reconnect" in reply errors
|
||||
cl.replyConnectionError(r, errors.New("unexpected connection request. please reconnect"))
|
||||
client, err := clientService.GetRepo().GetActiveByID(clientID)
|
||||
if err != nil {
|
||||
clientLog.Debugf("unable to get client: %v", err)
|
||||
continue
|
||||
}
|
||||
if client == nil {
|
||||
clientLog.Debugf("client not found: %v", err)
|
||||
continue
|
||||
}
|
||||
clientLog.Debugf("terminating client due for reconnect")
|
||||
err = clientService.Terminate(client)
|
||||
if err != nil {
|
||||
clientLog.Debugf("failed to terminate client due for reconnect: %v", err)
|
||||
}
|
||||
continue
|
||||
|
||||
case comm.RequestTypePing:
|
||||
// clientLog.Debugf("ping received from: %s", clientID)
|
||||
// ts := time.Now()
|
||||
_ = r.Reply(true, nil)
|
||||
err := cl.clientService.SetLastHeartbeat(clientID, time.Now())
|
||||
err := clientService.SetLastHeartbeat(clientID, time.Now())
|
||||
if err != nil {
|
||||
clientLog.Errorf("Failed to save heartbeat: %s", err)
|
||||
continue
|
||||
}
|
||||
// clientLog.Debugf("ping for: %s done in %s", clientID, time.Since(ts))
|
||||
|
||||
case comm.RequestTypeCmdResult:
|
||||
clientLog.Debugf("saving command result from: %s", clientID)
|
||||
|
||||
job, err := cl.saveCmdResult(r.Payload)
|
||||
if err != nil {
|
||||
clientLog.Errorf("Failed to save cmd result: %s", err)
|
||||
@ -363,9 +473,9 @@ func (cl *ClientListener) handleSSHRequests(clientLog *logger.Logger, clientID s
|
||||
|
||||
var auditLogEntry *auditlog.Entry
|
||||
if job.IsScript {
|
||||
auditLogEntry = cl.auditLog.Entry(auditlog.ApplicationClientScript, auditlog.ActionExecuteDone)
|
||||
auditLogEntry = cl.server.auditLog.Entry(auditlog.ApplicationClientScript, auditlog.ActionExecuteDone)
|
||||
} else {
|
||||
auditLogEntry = cl.auditLog.Entry(auditlog.ApplicationClientCommand, auditlog.ActionExecuteDone)
|
||||
auditLogEntry = cl.server.auditLog.Entry(auditlog.ApplicationClientCommand, auditlog.ActionExecuteDone)
|
||||
}
|
||||
if job.MultiJobID != nil {
|
||||
auditLogEntry.WithID(*job.MultiJobID)
|
||||
@ -378,7 +488,7 @@ func (cl *ClientListener) handleSSHRequests(clientLog *logger.Logger, clientID s
|
||||
Save()
|
||||
|
||||
if job.MultiJobID != nil {
|
||||
done := cl.jobsDoneChannel.Get(*job.MultiJobID)
|
||||
done := cl.server.jobsDoneChannel.Get(*job.MultiJobID)
|
||||
if done != nil {
|
||||
// to avoid blocking the exec - send job result in a new goroutine
|
||||
go func(done2 chan *models.Job, job2 *models.Job) {
|
||||
@ -386,21 +496,24 @@ func (cl *ClientListener) handleSSHRequests(clientLog *logger.Logger, clientID s
|
||||
}(done, job)
|
||||
}
|
||||
}
|
||||
|
||||
case comm.RequestTypeUpdatesStatus:
|
||||
clientLog.Debugf("setting updates status from: %s", clientID)
|
||||
updatesStatus := &models.UpdatesStatus{}
|
||||
err := json.Unmarshal(r.Payload, updatesStatus)
|
||||
if err != nil {
|
||||
clientLog.Errorf("Failed to unmarshal updates status: %s", err)
|
||||
continue
|
||||
}
|
||||
err = cl.clientService.SetUpdatesStatus(clientID, updatesStatus)
|
||||
err = clientService.SetUpdatesStatus(clientID, updatesStatus)
|
||||
if err != nil {
|
||||
clientLog.Errorf("Failed to save updates status: %s", err)
|
||||
continue
|
||||
}
|
||||
|
||||
case comm.RequestTypeSaveMeasurement:
|
||||
// if server monitoring is disabled then do not save measurements even if received
|
||||
if !cl.Server.config.Monitoring.Enabled {
|
||||
if !cl.server.config.Monitoring.Enabled {
|
||||
clientLog.Errorf("Received measurement when monitoring disabled. Measurement not saved.")
|
||||
continue
|
||||
}
|
||||
@ -412,7 +525,7 @@ func (cl *ClientListener) handleSSHRequests(clientLog *logger.Logger, clientID s
|
||||
continue
|
||||
}
|
||||
measurement.ClientID = clientID
|
||||
err = cl.monitoringService.SaveMeasurement(context.Background(), measurement)
|
||||
err = cl.server.monitoringService.SaveMeasurement(context.Background(), measurement)
|
||||
if err != nil {
|
||||
clientLog.Errorf("Failed to save measurement for client %s: %s", clientID, err)
|
||||
continue
|
||||
@ -437,18 +550,18 @@ func (cl *ClientListener) saveCmdResult(respBytes []byte) (*models.Job, error) {
|
||||
} else {
|
||||
wsJID = resp.JID
|
||||
}
|
||||
ws := cl.Server.uiJobWebSockets.Get(wsJID)
|
||||
ws := cl.server.uiJobWebSockets.Get(wsJID)
|
||||
if ws != nil {
|
||||
err := ws.WriteMessage(websocket.TextMessage, respBytes)
|
||||
if err != nil {
|
||||
cl.Errorf("%s, failed to write message to UI Web Socket: %v", resp.LogPrefix(), err)
|
||||
cl.Log().Errorf("%s, failed to write message to UI Web Socket: %v", resp.LogPrefix(), err)
|
||||
// proceed further
|
||||
}
|
||||
} else {
|
||||
cl.Debugf("%s, WS conn not found when saving command result. No active listeners connected", resp.LogPrefix())
|
||||
cl.Log().Debugf("%s, WS conn not found when saving command result. No active listeners connected", resp.LogPrefix())
|
||||
}
|
||||
|
||||
err = cl.jobProvider.SaveJob(&resp)
|
||||
err = cl.server.jobProvider.SaveJob(&resp)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save job result: %s", err)
|
||||
}
|
||||
@ -458,6 +571,7 @@ func (cl *ClientListener) saveCmdResult(respBytes []byte) (*models.Job, error) {
|
||||
|
||||
func (cl *ClientListener) handleSSHChannels(clientLog *logger.Logger, chans <-chan ssh.NewChannel) {
|
||||
for ch := range chans {
|
||||
ch := ch
|
||||
extraData := string(ch.ExtraData())
|
||||
stream, reqs, err := ch.Accept()
|
||||
if err != nil {
|
||||
@ -510,7 +624,7 @@ func (cl *ClientListener) handleOutputChannel(typ string, jobData []byte, client
|
||||
wsJID = job.JID
|
||||
}
|
||||
|
||||
ws := cl.Server.uiJobWebSockets.Get(wsJID)
|
||||
ws := cl.server.uiJobWebSockets.Get(wsJID)
|
||||
|
||||
ocd := outputChannelData{
|
||||
JID: job.JID,
|
||||
@ -591,13 +705,13 @@ func (cl *ClientListener) handleSessionChannel(stream ssh.Channel, clientLog *lo
|
||||
}
|
||||
|
||||
func (cl *ClientListener) sendCapabilities(conn *ssh.ServerConn) {
|
||||
payload, err := json.Marshal(cl.Server.capabilities)
|
||||
payload, err := json.Marshal(cl.server.capabilities)
|
||||
if err != nil {
|
||||
cl.Errorf("can't encode capabilities payload")
|
||||
cl.Log().Errorf("can't encode capabilities payload")
|
||||
return
|
||||
}
|
||||
|
||||
if _, _, err = conn.SendRequest(comm.RequestTypePutCapabilities, false, payload); err != nil {
|
||||
cl.Errorf("can't send capabilities: %v", err)
|
||||
cl.Log().Errorf("can't send capabilities: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -18,10 +18,10 @@ import (
|
||||
|
||||
func TestHandleOutputChannel(t *testing.T) {
|
||||
log := logger.NewLogger("client-listener-test", logger.LogOutput{File: os.Stdout}, logger.LogLevelDebug)
|
||||
cl := &ClientListener{Server: &Server{uiJobWebSockets: ws.NewWebSocketCache()}}
|
||||
cl := &ClientListener{server: &Server{uiJobWebSockets: ws.NewWebSocketCache()}}
|
||||
mockConn := &connMock{}
|
||||
ws := ws.NewConcurrentWebSocket(mockConn, log)
|
||||
cl.Server.uiJobWebSockets.Set("test-jid", ws)
|
||||
cl.server.uiJobWebSockets.Set("test-jid", ws)
|
||||
|
||||
testCases := []struct {
|
||||
Name string
|
||||
|
||||
@ -9,10 +9,12 @@ import (
|
||||
"github.com/cloudradar-monitoring/rport/share/logger"
|
||||
)
|
||||
|
||||
const DefaultMaxWorkers = 100
|
||||
|
||||
type ClientsStatusCheckTask struct {
|
||||
log *logger.Logger
|
||||
cr *clients.ClientRepository
|
||||
th time.Duration // Threshold after which a client to server ping is considered outdated.
|
||||
clientRepo *clients.ClientRepository
|
||||
threshold time.Duration // Threshold after which a client to server ping is considered outdated.
|
||||
pingTimeout time.Duration // Don't wait longer than pingTimeout for a response
|
||||
}
|
||||
|
||||
@ -20,47 +22,52 @@ type ClientsStatusCheckTask struct {
|
||||
func NewClientsStatusCheckTask(log *logger.Logger, cr *clients.ClientRepository, th time.Duration, pingTimeout time.Duration) *ClientsStatusCheckTask {
|
||||
return &ClientsStatusCheckTask{
|
||||
log: log.Fork("clients-status-check"),
|
||||
cr: cr,
|
||||
th: th,
|
||||
clientRepo: cr,
|
||||
threshold: th,
|
||||
pingTimeout: pingTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ClientsStatusCheckTask) Run(ctx context.Context) error {
|
||||
t.log.Debugf("running")
|
||||
timerStart := time.Now()
|
||||
var dueClients []*clients.Client
|
||||
var confirmedClients = 0
|
||||
var now = time.Now()
|
||||
for _, c := range t.cr.GetAllActive() {
|
||||
// Shorten the threshold aka make heartbeat older than it is because the ping response is stored after this check.
|
||||
// Clients would get checked only every second time otherwise.
|
||||
if c.LastHeartbeatAt != nil && now.Sub(*c.LastHeartbeatAt) < t.th-10*time.Second {
|
||||
// Skip all clients having sent a heartbeat from client to server recently
|
||||
confirmedClients++
|
||||
continue
|
||||
}
|
||||
dueClients = append(dueClients, c)
|
||||
}
|
||||
|
||||
dueClients := t.getDueClients()
|
||||
if len(dueClients) == 0 {
|
||||
// Nothing to do
|
||||
t.log.Debugf("ended after %s, no clients to ping", time.Since(timerStart))
|
||||
return nil
|
||||
}
|
||||
maxWorkers := 100
|
||||
|
||||
// make sure no more workers than clients and limit to max workers
|
||||
maxWorkers := DefaultMaxWorkers
|
||||
if maxWorkers > len(dueClients) {
|
||||
maxWorkers = len(dueClients)
|
||||
}
|
||||
|
||||
// make a channel that will receive all the clients to ping
|
||||
clientsToPing := make(chan *clients.Client, len(dueClients))
|
||||
results := make(chan bool, len(dueClients))
|
||||
|
||||
// create workers to ping clients
|
||||
for w := 1; w <= maxWorkers; w++ {
|
||||
go t.PingClients(clientsToPing, results)
|
||||
go t.PingClients(w, clientsToPing, results)
|
||||
}
|
||||
|
||||
// send the clients to ping to the workers
|
||||
for _, dueClient := range dueClients {
|
||||
clientsToPing <- dueClient
|
||||
}
|
||||
|
||||
// we're done queuing clients for processing, so close the channel
|
||||
close(clientsToPing)
|
||||
|
||||
// gather the results of pinged clients
|
||||
var dead = 0
|
||||
var alive = 0
|
||||
// TODO: (rs): note this is fragile. any mismatch between actual and expected results will cause
|
||||
// the task to block and essential hang. also there's no ctx checking.
|
||||
for a := 0; a < len(dueClients); a++ {
|
||||
if <-results {
|
||||
alive++
|
||||
@ -68,30 +75,73 @@ func (t *ClientsStatusCheckTask) Run(ctx context.Context) error {
|
||||
dead++
|
||||
}
|
||||
}
|
||||
|
||||
t.log.Debugf("ended after %s, skipped: %d, pinged: %d, alive: %d, dead: %d", time.Since(timerStart), confirmedClients, len(dueClients), alive, dead)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *ClientsStatusCheckTask) PingClients(clientsToPing <-chan *clients.Client, results chan<- bool) {
|
||||
func (t *ClientsStatusCheckTask) getDueClients() (dueClients []*clients.Client) {
|
||||
var confirmedClients = 0
|
||||
var now = time.Now()
|
||||
activeClients, _ := t.clientRepo.GetAllActiveClients()
|
||||
for _, c := range activeClients {
|
||||
// Shorten the threshold aka make heartbeat older than it is because the ping response is stored after this check.
|
||||
// Clients would get checked only every second time otherwise.
|
||||
if c.HasLastHeartbeatAt() {
|
||||
lastHeartbeatAt := c.GetLastHeartbeatAtValue()
|
||||
if now.Sub(lastHeartbeatAt) < t.threshold-(10*time.Second) {
|
||||
// Skip all clients having sent a heartbeat from client to server recently
|
||||
// t.log.Debugf("skipping client: %s, %s, %s", c.GetID(), lastHeartbeatAt, now.Sub(lastHeartbeatAt) < t.threshold-(10*time.Second))
|
||||
confirmedClients++
|
||||
continue
|
||||
}
|
||||
}
|
||||
dueClients = append(dueClients, c)
|
||||
}
|
||||
return dueClients
|
||||
}
|
||||
|
||||
func (t *ClientsStatusCheckTask) PingClients(workerNum int, clientsToPing <-chan *clients.Client, results chan<- bool) {
|
||||
// while there are clients to ping
|
||||
for cl := range clientsToPing {
|
||||
ok, response, rtt, err := comm.PingConnectionWithTimeout(cl.Connection, t.pingTimeout)
|
||||
clientName := cl.GetName()
|
||||
clientID := cl.GetID()
|
||||
ok, response, rtt, err := comm.PingConnectionWithTimeout(cl.GetConnection(), t.pingTimeout, cl.Log())
|
||||
//t.log.Debugf("ok=%s, error=%s, response=%s", ok, err, response)
|
||||
//Old clients cannot respond properly to a ping request yet
|
||||
if !ok && err == nil && string(response) == "unknown request" {
|
||||
t.log.Debugf("ping to %s [%s] succeeded in %s. client < 0.8.2", cl.Name, cl.ID, rtt)
|
||||
|
||||
// Old clients cannot respond properly to a ping request yet
|
||||
if !ok && err == nil && t.isLegacyClientResponse(response) {
|
||||
t.log.Debugf("ping to %s [%s] succeeded in %s. client < 0.8.2", clientName, clientID, rtt)
|
||||
cl.SetHeartbeatNow()
|
||||
results <- true
|
||||
continue
|
||||
}
|
||||
|
||||
// client versions from 0.9.2 to 0.9.6 can return "null" as a ping response. this is due to a bug
|
||||
// in the client ping handling that cause 2 replies to be sent by the client. this breaks stuff.
|
||||
// for the server, assume the null reply is a successful ping. unfortunately, the extra reply
|
||||
// confuses the next send by the client which means it won't get a reply from the server and
|
||||
// will ultimately disconnect and reconnect. the work around is to make sure that the client
|
||||
// pings the server faster than the server pings the client. as the server has a recent heartbeat
|
||||
// already (from the client) it won't ping the client again, meaning that the client won't get a
|
||||
// chance to double reply to the server and cause ssh protocol confusion.
|
||||
if ok && err == nil && string(response) == "null" {
|
||||
t.log.Debugf("ping to %s [%s] succeeded in %s. client >= 0.8.2 *", clientName, clientID, rtt)
|
||||
cl.SetHeartbeatNow()
|
||||
results <- true
|
||||
continue
|
||||
}
|
||||
|
||||
// Only an empty response confirms the ping
|
||||
if ok && err == nil && len(response) == 0 {
|
||||
t.log.Debugf("ping to %s [%s] succeeded in %s. client >= 0.8.2", cl.Name, cl.ID, rtt)
|
||||
t.log.Debugf("ping to %s [%s] succeeded in %s. client >= 0.8.2", clientName, clientID, rtt)
|
||||
cl.SetHeartbeatNow()
|
||||
results <- true
|
||||
continue
|
||||
}
|
||||
|
||||
// None of the above. Ping must have failed or timed out.
|
||||
t.log.Infof("ping to %s [%s] failed: %s", cl.Name, cl.ID, err)
|
||||
t.log.Infof("ping to %s [%s] failed: %s", clientName, clientID, err)
|
||||
|
||||
cl.SetDisconnectedNow()
|
||||
|
||||
@ -99,3 +149,7 @@ func (t *ClientsStatusCheckTask) PingClients(clientsToPing <-chan *clients.Clien
|
||||
results <- false
|
||||
}
|
||||
}
|
||||
|
||||
func (t *ClientsStatusCheckTask) isLegacyClientResponse(response []byte) (isLegacy bool) {
|
||||
return string(response) == "unknown request"
|
||||
}
|
||||
|
||||
@ -65,58 +65,58 @@ func TestClientsStatusDeterminationTask(t *testing.T) {
|
||||
timeout = 1 * time.Millisecond
|
||||
)
|
||||
now := time.Now()
|
||||
cr := clients.NewClientRepository([]*clients.Client{
|
||||
{
|
||||
ID: "1",
|
||||
ClientAuthID: "1",
|
||||
Connection: connSuccess,
|
||||
},
|
||||
{
|
||||
ID: "2",
|
||||
ClientAuthID: "2",
|
||||
Connection: connSuccess,
|
||||
LastHeartbeatAt: &now,
|
||||
},
|
||||
{
|
||||
ID: "3",
|
||||
ClientAuthID: "3",
|
||||
Connection: connFailure,
|
||||
},
|
||||
{
|
||||
ID: "4",
|
||||
ClientAuthID: "4",
|
||||
Connection: connTimeout,
|
||||
},
|
||||
}, nil, myTestLog)
|
||||
|
||||
c1 := clients.Client{}
|
||||
c1.SetID("1")
|
||||
c1.SetClientAuthID("1")
|
||||
c1.SetConnection(connSuccess)
|
||||
|
||||
c2 := clients.Client{}
|
||||
c2.SetID("2")
|
||||
c2.SetClientAuthID("2")
|
||||
c2.SetConnection(connSuccess)
|
||||
c2.SetLastHeartbeatAt(&now)
|
||||
|
||||
c3 := clients.Client{}
|
||||
c3.SetID("3")
|
||||
c3.SetClientAuthID("3")
|
||||
c3.SetConnection(connFailure)
|
||||
|
||||
c4 := clients.Client{}
|
||||
c4.SetID("4")
|
||||
c4.SetClientAuthID("4")
|
||||
c4.SetConnection(connTimeout)
|
||||
|
||||
cr := clients.NewClientRepository([]*clients.Client{&c1, &c2, &c3, &c4}, nil, myTestLog)
|
||||
task := NewClientsStatusCheckTask(myTestLog, cr, 120*time.Second, timeout)
|
||||
|
||||
// Check the last heartbeat of c1 has changed due to the ping sent
|
||||
err = task.Run(context.Background())
|
||||
assert.NoError(t, err)
|
||||
c1, err := cr.GetByID("1")
|
||||
tcl1, err := cr.GetByID("1")
|
||||
assert.NoError(t, err)
|
||||
assert.IsType(t, &time.Time{}, c1.LastHeartbeatAt)
|
||||
t.Logf("c1: LastHeartbeatAt: %s", c1.LastHeartbeatAt)
|
||||
assert.IsType(t, &time.Time{}, tcl1.GetLastHeartbeatAt())
|
||||
t.Logf("tcl1: LastHeartbeatAt: %s", tcl1.GetLastHeartbeatAt())
|
||||
|
||||
// Check the last heartbeat of c2 has not changed because the task must skip this client
|
||||
c2, err := cr.GetByID("2")
|
||||
tcl2, err := cr.GetByID("2")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &now, c2.LastHeartbeatAt, "LastHeartbeatAt of c2 must not change")
|
||||
t.Logf("c2: LastHeartbeatAt: %s", c2.LastHeartbeatAt)
|
||||
assert.Equal(t, &now, tcl2.GetLastHeartbeatAt(), "LastHeartbeatAt of tcl2 must not change")
|
||||
t.Logf("tcl2: LastHeartbeatAt: %s", tcl2.GetLastHeartbeatAt())
|
||||
|
||||
// Check the status of c3 changed to disconnected
|
||||
c3, err := cr.GetByID("3")
|
||||
tcl3, err := cr.GetByID("3")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, c3.DisconnectedAt)
|
||||
assert.Equal(t, "disconnected", string(c3.CalculateConnectionState()))
|
||||
t.Logf("c3: DisconnectedAt: %s", c3.DisconnectedAt)
|
||||
assert.NotNil(t, tcl3.GetDisconnectedAt())
|
||||
assert.Equal(t, "disconnected", string(tcl3.CalculateConnectionState()))
|
||||
t.Logf("tcl3: GetDisconnectedAt(): %s", tcl3.GetDisconnectedAt())
|
||||
|
||||
// Check the status of c4 changed to disconnected caused by a timeout
|
||||
c4, err := cr.GetByID("4")
|
||||
tcl4, err := cr.GetByID("4")
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, c4.DisconnectedAt)
|
||||
assert.Equal(t, "disconnected", string(c4.CalculateConnectionState()))
|
||||
t.Logf("c4: DisconnectedAt: %s", c4.DisconnectedAt)
|
||||
assert.NotNil(t, tcl4.GetDisconnectedAt())
|
||||
assert.Equal(t, "disconnected", string(tcl4.CalculateConnectionState()))
|
||||
t.Logf("tcl4: DisconnectedAt: %s", tcl4.GetDisconnectedAt())
|
||||
log, err := os.ReadFile(logfile)
|
||||
assert.NoError(t, err, "error reading log file")
|
||||
assert.Contains(t, string(log), fmt.Sprintf("ping to [4] failed: conn.SendRequest(ping), timeout %s exceeded", timeout))
|
||||
|
||||
@ -12,17 +12,20 @@ import (
|
||||
func TestCleanup(t *testing.T) {
|
||||
// given
|
||||
ctx := context.Background()
|
||||
c1 := New(t).Build() // active
|
||||
c2 := New(t).DisconnectedDuration(5 * time.Minute).Build() // disconnected
|
||||
c3 := New(t).DisconnectedDuration(time.Hour + time.Minute).Build() // obsolete
|
||||
c1 := New(t).ID("client-1").Logger(testLog).Build() // active
|
||||
c2 := New(t).ID("client-2").DisconnectedDuration(5 * time.Minute).Logger(testLog).Build() // disconnected
|
||||
c3 := New(t).ID("client-3").DisconnectedDuration(time.Hour + time.Minute).Logger(testLog).Build() // obsolete
|
||||
clients := []*Client{c1, c2, c3}
|
||||
p := NewFakeClientProvider(t, &hour, c1, c2, c3)
|
||||
defer p.Close()
|
||||
clientsRepo := NewClientRepositoryWithDB(clients, &hour, p, testLog)
|
||||
require.Len(t, clientsRepo.clients, 3)
|
||||
gotObsolete, err := p.get(ctx, c3.ID)
|
||||
require.Len(t, clientsRepo.clientState, 3)
|
||||
gotObsolete, err := p.get(ctx, c3.GetID())
|
||||
require.NoError(t, err)
|
||||
c3.Logger = nil
|
||||
|
||||
// patch in the logger for the client
|
||||
gotObsolete.Logger = testLog
|
||||
|
||||
require.EqualValues(t, c3, gotObsolete)
|
||||
task := NewCleanupTask(testLog, clientsRepo)
|
||||
|
||||
@ -31,11 +34,16 @@ func TestCleanup(t *testing.T) {
|
||||
|
||||
// then
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, getValues(clientsRepo.clients), []*Client{c1, c2})
|
||||
assert.ElementsMatch(t, getValues(clientsRepo.clientState), []*Client{c1, c2})
|
||||
gotClients, err := p.GetAll(ctx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// patch in the logger for the clients
|
||||
gotClients[0].Logger = testLog
|
||||
gotClients[1].Logger = testLog
|
||||
|
||||
assert.ElementsMatch(t, []*Client{c1, c2}, gotClients)
|
||||
gotObsolete, err = p.get(ctx, c3.ID)
|
||||
gotObsolete, err = p.get(ctx, c3.GetID())
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, gotObsolete)
|
||||
}
|
||||
@ -43,14 +51,14 @@ func TestCleanup(t *testing.T) {
|
||||
func TestCleanupDisabled(t *testing.T) {
|
||||
// given
|
||||
ctx := context.Background()
|
||||
c1 := New(t).Build() // active
|
||||
c2 := New(t).DisconnectedDuration(5 * time.Minute).Build() // disconnected
|
||||
c3 := New(t).DisconnectedDuration(365*24*time.Hour + time.Minute).Build() // disconnected longer
|
||||
c1 := New(t).Logger(testLog).Build() // active
|
||||
c2 := New(t).DisconnectedDuration(5 * time.Minute).Logger(testLog).Build() // disconnected
|
||||
c3 := New(t).DisconnectedDuration(365*24*time.Hour + time.Minute).Logger(testLog).Build() // disconnected longer
|
||||
clients := []*Client{c1, c2, c3}
|
||||
p := NewFakeClientProvider(t, nil, c1, c2, c3)
|
||||
defer p.Close()
|
||||
clientsRepo := NewClientRepositoryWithDB(clients, nil, p, testLog)
|
||||
require.Len(t, clientsRepo.clients, 3)
|
||||
require.Len(t, clientsRepo.clientState, 3)
|
||||
|
||||
task := NewCleanupTask(testLog, clientsRepo)
|
||||
|
||||
@ -59,7 +67,7 @@ func TestCleanupDisabled(t *testing.T) {
|
||||
|
||||
// then
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, getValues(clientsRepo.clients), []*Client{c1, c2, c3})
|
||||
assert.ElementsMatch(t, getValues(clientsRepo.clientState), []*Client{c1, c2, c3})
|
||||
}
|
||||
|
||||
func getValues(clients map[string]*Client) []*Client {
|
||||
|
||||
@ -12,6 +12,7 @@ import (
|
||||
"github.com/cloudradar-monitoring/rport/server/api/users"
|
||||
"github.com/cloudradar-monitoring/rport/server/cgroups"
|
||||
"github.com/cloudradar-monitoring/rport/server/clients/clienttunnel"
|
||||
chshare "github.com/cloudradar-monitoring/rport/share"
|
||||
"github.com/cloudradar-monitoring/rport/share/clientconfig"
|
||||
"github.com/cloudradar-monitoring/rport/share/logger"
|
||||
"github.com/cloudradar-monitoring/rport/share/models"
|
||||
@ -58,6 +59,7 @@ type Client struct {
|
||||
Version string `json:"version"`
|
||||
Address string `json:"address"`
|
||||
Tunnels []*clienttunnel.Tunnel `json:"tunnels"`
|
||||
|
||||
// DisconnectedAt is a time when a client was disconnected. If nil - it's connected.
|
||||
DisconnectedAt *time.Time `json:"disconnected_at"`
|
||||
LastHeartbeatAt *time.Time `json:"last_heartbeat_at"`
|
||||
@ -67,11 +69,13 @@ type Client struct {
|
||||
ClientConfiguration *clientconfig.Config `json:"client_configuration"`
|
||||
|
||||
Connection ssh.Conn `json:"-"`
|
||||
Logger *logger.Logger `json:"-"`
|
||||
Context context.Context `json:"-"`
|
||||
Paused bool `json:"-"`
|
||||
PausedReason string `json:"-"`
|
||||
lock sync.Mutex
|
||||
|
||||
Logger *logger.Logger `json:"-"`
|
||||
|
||||
flock sync.RWMutex
|
||||
}
|
||||
|
||||
// CalculatedClient contains additional fields and is calculated on each request
|
||||
@ -81,47 +85,295 @@ type CalculatedClient struct {
|
||||
ConnectionState ConnectionState `json:"connection_state"`
|
||||
}
|
||||
|
||||
func NewCalculatedClient(c *Client, groups []string, connectionState ConnectionState) (cc *CalculatedClient) {
|
||||
cc = &CalculatedClient{}
|
||||
cc.Client = c
|
||||
cc.Groups = groups
|
||||
cc.ConnectionState = connectionState
|
||||
return cc
|
||||
}
|
||||
|
||||
func (cc *CalculatedClient) GetConnectionState() (cs ConnectionState) {
|
||||
return cc.ConnectionState
|
||||
}
|
||||
|
||||
func (c *Client) GetID() (id string) {
|
||||
c.flock.RLock()
|
||||
defer c.flock.RUnlock()
|
||||
return c.ID
|
||||
}
|
||||
|
||||
func (c *Client) GetName() (name string) {
|
||||
c.flock.RLock()
|
||||
defer c.flock.RUnlock()
|
||||
return c.Name
|
||||
}
|
||||
|
||||
func (c *Client) GetSessionID() (sessionID string) {
|
||||
c.flock.RLock()
|
||||
defer c.flock.RUnlock()
|
||||
return c.SessionID
|
||||
}
|
||||
|
||||
func (c *Client) GetOS() (os string) {
|
||||
c.flock.RLock()
|
||||
defer c.flock.RUnlock()
|
||||
return c.OS
|
||||
}
|
||||
|
||||
func (c *Client) GetHostname() (hostname string) {
|
||||
c.flock.RLock()
|
||||
defer c.flock.RUnlock()
|
||||
return c.Hostname
|
||||
}
|
||||
|
||||
func (c *Client) GetTags() (tags []string) {
|
||||
c.flock.RLock()
|
||||
defer c.flock.RUnlock()
|
||||
|
||||
if c.Tags == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// make sure not to return reference to underlying array
|
||||
tags = make([]string, len(c.Tags))
|
||||
copy(tags, c.Tags)
|
||||
return tags
|
||||
}
|
||||
|
||||
func (c *Client) GetAllowedUserGroups() (groups []string) {
|
||||
c.flock.RLock()
|
||||
defer c.flock.RUnlock()
|
||||
|
||||
if c.AllowedUserGroups == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// make sure not to return reference to underlying array
|
||||
groups = make([]string, len(c.AllowedUserGroups))
|
||||
copy(groups, c.AllowedUserGroups)
|
||||
return groups
|
||||
}
|
||||
|
||||
func (c *Client) GetVersion() (version string) {
|
||||
c.flock.RLock()
|
||||
defer c.flock.RUnlock()
|
||||
return c.Version
|
||||
}
|
||||
|
||||
func (c *Client) GetDisconnectedAt() (at *time.Time) {
|
||||
c.flock.RLock()
|
||||
defer c.flock.RUnlock()
|
||||
return c.DisconnectedAt
|
||||
}
|
||||
|
||||
func (c *Client) GetDisconnectedAtValue() (at time.Time) {
|
||||
c.flock.RLock()
|
||||
if c.DisconnectedAt != nil {
|
||||
at = *c.DisconnectedAt
|
||||
}
|
||||
c.flock.RUnlock()
|
||||
return at
|
||||
}
|
||||
|
||||
func (c *Client) HasLastHeartbeatAt() (has bool) {
|
||||
c.flock.RLock()
|
||||
defer c.flock.RUnlock()
|
||||
return c.LastHeartbeatAt != nil
|
||||
}
|
||||
|
||||
func (c *Client) GetLastHeartbeatAt() (at *time.Time) {
|
||||
c.flock.RLock()
|
||||
defer c.flock.RUnlock()
|
||||
return c.LastHeartbeatAt
|
||||
}
|
||||
|
||||
func (c *Client) GetLastHeartbeatAtValue() (at time.Time) {
|
||||
c.flock.RLock()
|
||||
if c.LastHeartbeatAt != nil {
|
||||
at = *c.LastHeartbeatAt
|
||||
}
|
||||
c.flock.RUnlock()
|
||||
return at
|
||||
}
|
||||
|
||||
func (c *Client) GetConnection() (conn ssh.Conn) {
|
||||
c.flock.RLock()
|
||||
defer c.flock.RUnlock()
|
||||
return c.Connection
|
||||
}
|
||||
|
||||
func (c *Client) GetPausedReason() (reason string) {
|
||||
c.flock.RLock()
|
||||
defer c.flock.RUnlock()
|
||||
return c.PausedReason
|
||||
}
|
||||
|
||||
func (c *Client) GetContext() (ctx context.Context) {
|
||||
c.flock.RLock()
|
||||
defer c.flock.RUnlock()
|
||||
return c.Context
|
||||
}
|
||||
|
||||
func (c *Client) GetClientAuthID() (authID string) {
|
||||
c.flock.RLock()
|
||||
defer c.flock.RUnlock()
|
||||
return c.ClientAuthID
|
||||
}
|
||||
|
||||
func (c *Client) GetTunnels() (tunnels []*clienttunnel.Tunnel) {
|
||||
c.flock.RLock()
|
||||
defer c.flock.RUnlock()
|
||||
return c.Tunnels
|
||||
}
|
||||
|
||||
func (c *Client) Log() (l *logger.Logger) {
|
||||
c.flock.RLock()
|
||||
defer c.flock.RUnlock()
|
||||
return c.Logger
|
||||
}
|
||||
|
||||
func (c *Client) IsPaused() (paused bool) {
|
||||
paused = c.Paused
|
||||
return paused
|
||||
c.flock.RLock()
|
||||
defer c.flock.RUnlock()
|
||||
return c.Paused
|
||||
}
|
||||
|
||||
func (c *Client) GetMonitoringConfig() (monitoringConfig *clientconfig.MonitoringConfig) {
|
||||
c.flock.RLock()
|
||||
defer c.flock.RUnlock()
|
||||
|
||||
if c.ClientConfiguration == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &c.ClientConfiguration.Monitoring
|
||||
}
|
||||
|
||||
func (c *Client) GetFileReceptionConfig() (fileReceptionConfig *clientconfig.FileReceptionConfig) {
|
||||
c.flock.RLock()
|
||||
defer c.flock.RUnlock()
|
||||
|
||||
if c.ClientConfiguration == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &c.ClientConfiguration.FileReceptionConfig
|
||||
}
|
||||
|
||||
// test only
|
||||
func (c *Client) SetID(id string) {
|
||||
c.flock.Lock()
|
||||
c.ID = id
|
||||
c.flock.Unlock()
|
||||
}
|
||||
|
||||
// test only
|
||||
func (c *Client) SetAddress(address string) {
|
||||
c.flock.Lock()
|
||||
c.Address = address
|
||||
c.flock.Unlock()
|
||||
}
|
||||
|
||||
// test only
|
||||
func (c *Client) SetHostname(hostname string) {
|
||||
c.flock.Lock()
|
||||
c.Hostname = hostname
|
||||
c.flock.Unlock()
|
||||
}
|
||||
|
||||
// test only
|
||||
func (c *Client) SetClientAuthID(authID string) {
|
||||
c.flock.Lock()
|
||||
c.ClientAuthID = authID
|
||||
c.flock.Unlock()
|
||||
}
|
||||
|
||||
// test only
|
||||
func (c *Client) SetTags(tags []string) {
|
||||
if c.Tags == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.flock.Lock()
|
||||
// make sure not to just copy the tag reference
|
||||
c.Tags = make([]string, len(c.Tags))
|
||||
copy(c.Tags, tags)
|
||||
c.flock.Unlock()
|
||||
}
|
||||
|
||||
// test only
|
||||
func (c *Client) SetConnection(conn ssh.Conn) {
|
||||
c.flock.Lock()
|
||||
c.Connection = conn
|
||||
c.flock.Unlock()
|
||||
}
|
||||
|
||||
func (c *Client) SetTunnels(tunnels []*clienttunnel.Tunnel) {
|
||||
c.flock.Lock()
|
||||
c.Tunnels = tunnels
|
||||
c.flock.Unlock()
|
||||
}
|
||||
|
||||
func (c *Client) SetAllowedUserGroups(groups []string) {
|
||||
c.flock.Lock()
|
||||
c.AllowedUserGroups = groups
|
||||
c.flock.Unlock()
|
||||
}
|
||||
|
||||
func (c *Client) SetUpdatesStatus(status *models.UpdatesStatus) {
|
||||
c.flock.Lock()
|
||||
c.UpdatesStatus = status
|
||||
c.flock.Unlock()
|
||||
}
|
||||
|
||||
func (c *Client) SetDisconnectedAt(at *time.Time) {
|
||||
l := c.Log()
|
||||
if l != nil && at != nil {
|
||||
l.Debugf("%s: set to disconnected at %s", c.GetID(), at)
|
||||
}
|
||||
c.flock.Lock()
|
||||
c.DisconnectedAt = at
|
||||
c.flock.Unlock()
|
||||
}
|
||||
|
||||
func (c *Client) SetLastHeartbeatAt(at *time.Time) {
|
||||
c.SetDisconnectedAt(nil)
|
||||
c.flock.Lock()
|
||||
c.LastHeartbeatAt = at
|
||||
c.flock.Unlock()
|
||||
}
|
||||
|
||||
const PausedDueToMaxClientsExceeded = "unlicensed"
|
||||
|
||||
func (c *Client) SetPaused(paused bool, reason string) {
|
||||
c.flock.Lock()
|
||||
c.Paused = paused
|
||||
c.PausedReason = reason
|
||||
c.flock.Unlock()
|
||||
if paused {
|
||||
c.Logger.Infof("client %s is paused (reason = %s)", c.ID, c.PausedReason)
|
||||
c.Log().Infof("client %s is paused (reason = %s)", c.GetID(), reason)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) IsConnected() bool {
|
||||
return c.GetDisconnectedAt() == nil
|
||||
}
|
||||
|
||||
func (c *Client) SetConnected() {
|
||||
c.Logger.Debugf("%s: set to connected at %s", c.ID, time.Now())
|
||||
c.DisconnectedAt = nil
|
||||
}
|
||||
|
||||
func (c *Client) SetDisconnected(at *time.Time) {
|
||||
if c.Logger != nil {
|
||||
c.Logger.Debugf("%s: set to disconnected at %s", c.ID, at)
|
||||
}
|
||||
c.DisconnectedAt = at
|
||||
c.Log().Debugf("%s: set to connected at %s", c.GetID(), time.Now())
|
||||
c.SetDisconnectedAt(nil)
|
||||
}
|
||||
|
||||
func (c *Client) SetDisconnectedNow() {
|
||||
now := time.Now()
|
||||
c.DisconnectedAt = &now
|
||||
c.SetDisconnectedAt(&now)
|
||||
}
|
||||
|
||||
func (c *Client) SetHeartbeatNow() {
|
||||
now := time.Now()
|
||||
c.DisconnectedAt = nil
|
||||
c.LastHeartbeatAt = &now
|
||||
}
|
||||
|
||||
func (c *Client) SetHeartbeat(at *time.Time) {
|
||||
c.DisconnectedAt = nil
|
||||
c.LastHeartbeatAt = at
|
||||
c.SetLastHeartbeatAt(&now)
|
||||
c.SetDisconnectedAt(nil)
|
||||
}
|
||||
|
||||
func (c *Client) ToCalculated(allGroups []*cgroups.ClientGroup) *CalculatedClient {
|
||||
@ -131,26 +383,16 @@ func (c *Client) ToCalculated(allGroups []*cgroups.ClientGroup) *CalculatedClien
|
||||
clientGroups = append(clientGroups, group.ID)
|
||||
}
|
||||
}
|
||||
return &CalculatedClient{
|
||||
Client: c,
|
||||
Groups: clientGroups,
|
||||
ConnectionState: c.CalculateConnectionState(),
|
||||
}
|
||||
|
||||
return NewCalculatedClient(c, clientGroups, c.CalculateConnectionState())
|
||||
}
|
||||
|
||||
// Obsolete returns true if a given client was disconnected longer than a given duration.
|
||||
// If a given duration is nil - returns false (never obsolete).
|
||||
func (c *Client) Obsolete(duration *time.Duration) bool {
|
||||
return duration != nil && c.DisconnectedAt != nil &&
|
||||
c.DisconnectedAt.Add(*duration).Before(now())
|
||||
}
|
||||
|
||||
func (c *Client) Lock() {
|
||||
c.lock.Lock()
|
||||
}
|
||||
|
||||
func (c *Client) Unlock() {
|
||||
c.lock.Unlock()
|
||||
disconnectedAt := c.GetDisconnectedAt()
|
||||
return duration != nil && disconnectedAt != nil &&
|
||||
disconnectedAt.Add(*duration).Before(now())
|
||||
}
|
||||
|
||||
func (c *Client) NewTunnelID() (tunnelID string) {
|
||||
@ -163,31 +405,37 @@ func (c *Client) generateNewTunnelID() int64 {
|
||||
}
|
||||
|
||||
func (c *Client) RemoveTunnelByID(tunnelID string) {
|
||||
result := make([]*clienttunnel.Tunnel, 0)
|
||||
for _, curr := range c.Tunnels {
|
||||
if curr.ID != tunnelID {
|
||||
result = append(result, curr)
|
||||
updatedTunnelList := make([]*clienttunnel.Tunnel, 0)
|
||||
// TODO: (rs): not thread-safe
|
||||
for _, tunnel := range c.GetTunnels() {
|
||||
if tunnel.ID != tunnelID {
|
||||
updatedTunnelList = append(updatedTunnelList, tunnel)
|
||||
}
|
||||
}
|
||||
c.Tunnels = result
|
||||
c.SetTunnels(updatedTunnelList)
|
||||
}
|
||||
|
||||
func (c *Client) Banner() string {
|
||||
banner := c.ID
|
||||
if c.Name != "" {
|
||||
banner += " (" + c.Name + ")"
|
||||
clientID := c.GetID()
|
||||
clientName := c.GetName()
|
||||
tags := c.GetTags()
|
||||
|
||||
banner := clientID
|
||||
if clientName != "" {
|
||||
banner += " (" + clientName + ")"
|
||||
}
|
||||
if len(c.Tags) != 0 {
|
||||
for _, t := range c.Tags {
|
||||
if len(tags) != 0 {
|
||||
for _, t := range tags {
|
||||
banner += " #" + t
|
||||
}
|
||||
}
|
||||
|
||||
return banner
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
// The tunnels are closed automatically when ssh connection is closed.
|
||||
return c.Connection.Close()
|
||||
return c.GetConnection().Close()
|
||||
}
|
||||
|
||||
func (c *Client) BelongsToOneOf(groups []*cgroups.ClientGroup) bool {
|
||||
@ -204,6 +452,10 @@ func (c *Client) BelongsTo(group *cgroups.ClientGroup) bool {
|
||||
if p.HasNoParams() {
|
||||
return false
|
||||
}
|
||||
|
||||
c.flock.RLock()
|
||||
defer c.flock.RUnlock()
|
||||
|
||||
if !p.ClientID.MatchesOneOf(c.ID) {
|
||||
return false
|
||||
}
|
||||
@ -247,7 +499,7 @@ func (c *Client) BelongsTo(group *cgroups.ClientGroup) bool {
|
||||
}
|
||||
|
||||
func (c *Client) CalculateConnectionState() ConnectionState {
|
||||
if c.DisconnectedAt == nil {
|
||||
if c.IsConnected() {
|
||||
return Connected
|
||||
}
|
||||
return Disconnected
|
||||
@ -259,7 +511,7 @@ func (c *Client) HasAccessViaUserGroups(userGroups []string) bool {
|
||||
if curUserGroup == users.Administrators {
|
||||
return true
|
||||
}
|
||||
for _, allowedGroup := range c.AllowedUserGroups {
|
||||
for _, allowedGroup := range c.GetAllowedUserGroups() {
|
||||
if allowedGroup == curUserGroup {
|
||||
return true
|
||||
}
|
||||
@ -284,3 +536,48 @@ func (c *Client) UserGroupHasAccessViaClientGroup(userGroups []string, allClient
|
||||
func NewClientID() (string, error) {
|
||||
return random.UUID4()
|
||||
}
|
||||
|
||||
func NewClientFromConnRequest(ctx context.Context, existingClient *Client, clientAuthID string, clientID string, req *chshare.ConnectionRequest, clientHost string, sshConn ssh.Conn, clog *logger.Logger) (client *Client) {
|
||||
if existingClient == nil {
|
||||
client = &Client{
|
||||
ID: clientID,
|
||||
}
|
||||
} else {
|
||||
client = existingClient
|
||||
}
|
||||
|
||||
client.flock.Lock()
|
||||
client.Name = req.Name
|
||||
client.SessionID = req.SessionID
|
||||
client.OS = req.OS
|
||||
client.OSArch = req.OSArch
|
||||
client.OSFamily = req.OSFamily
|
||||
client.OSKernel = req.OSKernel
|
||||
client.OSFullName = req.OSFullName
|
||||
client.OSVersion = req.OSVersion
|
||||
client.OSVirtualizationSystem = req.OSVirtualizationSystem
|
||||
client.OSVirtualizationRole = req.OSVirtualizationRole
|
||||
client.Hostname = req.Hostname
|
||||
client.CPUFamily = req.CPUFamily
|
||||
client.CPUModel = req.CPUModel
|
||||
client.CPUModelName = req.CPUModelName
|
||||
client.CPUVendor = req.CPUVendor
|
||||
client.NumCPUs = req.NumCPUs
|
||||
client.MemoryTotal = req.MemoryTotal
|
||||
client.Timezone = req.Timezone
|
||||
client.IPv4 = req.IPv4
|
||||
client.IPv6 = req.IPv6
|
||||
client.Tags = req.Tags
|
||||
client.Version = req.Version
|
||||
client.ClientConfiguration = req.ClientConfiguration
|
||||
client.Address = clientHost
|
||||
client.Tunnels = make([]*clienttunnel.Tunnel, 0)
|
||||
client.DisconnectedAt = nil
|
||||
client.ClientAuthID = clientAuthID
|
||||
client.Connection = sshConn
|
||||
client.Context = ctx
|
||||
client.Logger = clog
|
||||
client.flock.Unlock()
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
@ -14,6 +14,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/cloudradar-monitoring/rport/share/clientconfig"
|
||||
"github.com/cloudradar-monitoring/rport/share/logger"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
@ -39,6 +40,7 @@ type ClientBuilder struct {
|
||||
disconnectedAt *time.Time
|
||||
allowedUserGroups []string
|
||||
conn ssh.Conn
|
||||
logger *logger.Logger
|
||||
cfg *clientconfig.Config
|
||||
}
|
||||
|
||||
@ -58,6 +60,11 @@ func (b ClientBuilder) ID(id string) ClientBuilder {
|
||||
return b
|
||||
}
|
||||
|
||||
func (b ClientBuilder) Logger(l *logger.Logger) ClientBuilder {
|
||||
b.logger = l
|
||||
return b
|
||||
}
|
||||
|
||||
func (b ClientBuilder) ClientAuthID(clientAuthID string) ClientBuilder {
|
||||
b.clientAuthID = clientAuthID
|
||||
return b
|
||||
@ -139,6 +146,7 @@ func (b ClientBuilder) Build() *Client {
|
||||
|
||||
Connection: b.conn,
|
||||
ClientConfiguration: b.cfg,
|
||||
Logger: b.logger,
|
||||
}
|
||||
|
||||
return cl
|
||||
|
||||
@ -31,15 +31,15 @@ import (
|
||||
type ClientService interface {
|
||||
SetPlusLicenseInfoCap(licensecap licensecap.CapabilityEx)
|
||||
|
||||
Count() (int, error)
|
||||
CountActive() (int, error)
|
||||
Count() int
|
||||
CountActive() int
|
||||
CountDisconnected() (int, error)
|
||||
GetByID(id string) (*Client, error)
|
||||
GetActiveByID(id string) (*Client, error)
|
||||
GetByGroups(groups []*cgroups.ClientGroup) ([]*Client, error)
|
||||
GetClientsByTag(tags []string, operator string, allowDisconnected bool) (clients []*Client, err error)
|
||||
GetAllByClientID(clientID string) []*Client
|
||||
GetAll() ([]*Client, error)
|
||||
GetAll() []*Client
|
||||
GetUserClients(groups []*cgroups.ClientGroup, user User) ([]*Client, error)
|
||||
GetFilteredUserClients(user User, filterOptions []query.FilterOption, groups []*cgroups.ClientGroup) ([]*CalculatedClient, error)
|
||||
|
||||
@ -81,7 +81,13 @@ type ClientServiceProvider struct {
|
||||
|
||||
licensecap licensecap.CapabilityEx
|
||||
|
||||
mu sync.Mutex
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func (s *ClientServiceProvider) log() (l *logger.Logger) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.logger
|
||||
}
|
||||
|
||||
var OptionsSupportedFilters = map[string]bool{
|
||||
@ -216,9 +222,9 @@ func (s *ClientServiceProvider) UpdateClientStatus() {
|
||||
return
|
||||
}
|
||||
|
||||
s.logger.Debugf("updating client status")
|
||||
s.log().Debugf("updating client status")
|
||||
|
||||
clientList := s.repo.GetAllActive()
|
||||
clientList, _ := s.repo.GetAllActiveClients()
|
||||
|
||||
for i, client := range clientList {
|
||||
if i < s.GetMaxClients() {
|
||||
@ -227,14 +233,13 @@ func (s *ClientServiceProvider) UpdateClientStatus() {
|
||||
client.SetPaused(true, PausedDueToMaxClientsExceeded)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (s *ClientServiceProvider) Count() (int, error) {
|
||||
func (s *ClientServiceProvider) Count() int {
|
||||
return s.repo.Count()
|
||||
}
|
||||
|
||||
func (s *ClientServiceProvider) CountActive() (int, error) {
|
||||
func (s *ClientServiceProvider) CountActive() int {
|
||||
return s.repo.CountActive()
|
||||
}
|
||||
|
||||
@ -255,13 +260,10 @@ func (s *ClientServiceProvider) GetByGroups(groups []*cgroups.ClientGroup) ([]*C
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
all, err := s.repo.GetAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allClients := s.repo.GetAllClients()
|
||||
|
||||
var res []*Client
|
||||
for _, cur := range all {
|
||||
for _, cur := range allClients {
|
||||
if cur.BelongsToOneOf(groups) {
|
||||
res = append(res, cur)
|
||||
}
|
||||
@ -274,11 +276,12 @@ func (s *ClientServiceProvider) GetClientsByTag(tags []string, operator string,
|
||||
}
|
||||
|
||||
func (s *ClientServiceProvider) PopulateGroupsWithUserClients(groups []*cgroups.ClientGroup, user User) {
|
||||
all, _ := s.repo.GetUserClients(user, groups)
|
||||
for _, curClient := range all {
|
||||
availableClients, _ := s.repo.GetUserClients(user, groups)
|
||||
for _, client := range availableClients {
|
||||
clientID := client.GetID()
|
||||
for _, curGroup := range groups {
|
||||
if curClient.BelongsTo(curGroup) {
|
||||
curGroup.ClientIDs = append(curGroup.ClientIDs, curClient.ID)
|
||||
if client.BelongsTo(curGroup) {
|
||||
curGroup.ClientIDs = append(curGroup.ClientIDs, clientID)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -291,8 +294,8 @@ func (s *ClientServiceProvider) GetAllByClientID(clientID string) []*Client {
|
||||
return s.repo.GetAllByClientAuthID(clientID)
|
||||
}
|
||||
|
||||
func (s *ClientServiceProvider) GetAll() ([]*Client, error) {
|
||||
return s.repo.GetAll()
|
||||
func (s *ClientServiceProvider) GetAll() []*Client {
|
||||
return s.repo.GetAllClients()
|
||||
}
|
||||
|
||||
func (s *ClientServiceProvider) GetUserClients(groups []*cgroups.ClientGroup, user User) ([]*Client, error) {
|
||||
@ -308,40 +311,49 @@ func (s *ClientServiceProvider) StartClient(
|
||||
req *chshare.ConnectionRequest, clog *logger.Logger,
|
||||
) (*Client, error) {
|
||||
clog.Debugf("Starting client session: %s", clientID)
|
||||
repo := s.GetRepo()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
clientAddr := sshConn.RemoteAddr().String()
|
||||
clientHost, _, err := net.SplitHostPort(clientAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get host for address %q: %v", clientAddr, err)
|
||||
}
|
||||
|
||||
// if client id is in use, deny connection
|
||||
client, err := s.repo.GetByID(clientID)
|
||||
client, err := repo.GetByID(clientID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get client by id %q", clientID)
|
||||
}
|
||||
|
||||
// if found existing client
|
||||
if client != nil {
|
||||
clog.Debugf("found existing client %s", clientID)
|
||||
var sessionReUsed = false
|
||||
if req.SessionID != "" && req.SessionID == client.SessionID {
|
||||
if req.SessionID != "" && req.SessionID == client.GetSessionID() {
|
||||
// Stored previous session id and the session id of the connection attempt are equal
|
||||
sessionReUsed = true
|
||||
clog.Debugf("resuming existing session %s for client %s [%s]", req.SessionID, client.Name, clientID)
|
||||
}
|
||||
if client.DisconnectedAt == nil && !sessionReUsed {
|
||||
return nil, fmt.Errorf("client is already connected: %s [%s]", client.Name, clientID)
|
||||
clog.Debugf("resuming existing session %s for client %s [%s]", req.SessionID, client.GetName(), clientID)
|
||||
}
|
||||
|
||||
oldTunnels := GetTunnelsToReestablish(getRemotes(client.Tunnels), req.Remotes)
|
||||
if client.IsConnected() && !sessionReUsed {
|
||||
clog.Debugf("client is already connected: %s", clientID)
|
||||
return nil, fmt.Errorf("client is already connected: %s [%s]", client.GetName(), clientID)
|
||||
}
|
||||
|
||||
oldTunnels := getTunnelsToReestablish(getRemotes(client.GetTunnels()), req.Remotes)
|
||||
|
||||
clientVersion, err := version.NewVersion(req.Version)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to determine client version: %v", err)
|
||||
}
|
||||
requiredVersion, _ := version.NewVersion("0.6.4")
|
||||
if clientVersion.GreaterThanOrEqual(requiredVersion) {
|
||||
oldTunnels, err = ExcludeNotAllowedTunnels(clog, oldTunnels, sshConn)
|
||||
oldTunnels, err = s.excludeNotAllowedTunnels(clog, oldTunnels, sshConn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to filter tunnels: %v", err)
|
||||
}
|
||||
} else {
|
||||
clog.Infof("client %s (%s) version %s does not support 'tunnel_allowed' policies. Consider upgrading.", client.ID, client.Name, client.Version)
|
||||
clog.Infof("client %s (%s) version %s does not support 'tunnel_allowed' policies. Consider upgrading.", client.GetID(), client.GetName(), client.GetVersion())
|
||||
}
|
||||
|
||||
clog.Infof("tunnels to create %d: %v", len(req.Remotes), req.Remotes)
|
||||
@ -353,68 +365,33 @@ func (s *ClientServiceProvider) StartClient(
|
||||
|
||||
// check if client auth ID is already used by another client
|
||||
if !authMultiuseCreds && s.isClientAuthIDInUse(clientAuthID, clientID) {
|
||||
clog.Debugf("client auth ID is already in use: %s: %q: ", clientID, clientAuthID)
|
||||
return nil, fmt.Errorf("client auth ID is already in use: %q", clientAuthID)
|
||||
}
|
||||
|
||||
clientAddr := sshConn.RemoteAddr().String()
|
||||
clientHost, _, err := net.SplitHostPort(clientAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get host for address %q: %v", clientAddr, err)
|
||||
}
|
||||
|
||||
if client == nil {
|
||||
client = &Client{
|
||||
ID: clientID,
|
||||
}
|
||||
}
|
||||
|
||||
client.Name = req.Name
|
||||
client.SessionID = req.SessionID
|
||||
client.OS = req.OS
|
||||
client.OSArch = req.OSArch
|
||||
client.OSFamily = req.OSFamily
|
||||
client.OSKernel = req.OSKernel
|
||||
client.OSFullName = req.OSFullName
|
||||
client.OSVersion = req.OSVersion
|
||||
client.OSVirtualizationSystem = req.OSVirtualizationSystem
|
||||
client.OSVirtualizationRole = req.OSVirtualizationRole
|
||||
client.Hostname = req.Hostname
|
||||
client.CPUFamily = req.CPUFamily
|
||||
client.CPUModel = req.CPUModel
|
||||
client.CPUModelName = req.CPUModelName
|
||||
client.CPUVendor = req.CPUVendor
|
||||
client.NumCPUs = req.NumCPUs
|
||||
client.MemoryTotal = req.MemoryTotal
|
||||
client.Timezone = req.Timezone
|
||||
client.IPv4 = req.IPv4
|
||||
client.IPv6 = req.IPv6
|
||||
client.Tags = req.Tags
|
||||
client.Version = req.Version
|
||||
client.ClientConfiguration = req.ClientConfiguration
|
||||
client.Address = clientHost
|
||||
client.Tunnels = make([]*clienttunnel.Tunnel, 0)
|
||||
client.DisconnectedAt = nil
|
||||
client.ClientAuthID = clientAuthID
|
||||
client.Connection = sshConn
|
||||
client.Context = ctx
|
||||
client.Logger = clog
|
||||
client = NewClientFromConnRequest(ctx, client, clientAuthID, clientID, req, clientHost, sshConn, clog)
|
||||
|
||||
client.SetConnected()
|
||||
|
||||
s.UpdateClientStatus()
|
||||
|
||||
if !client.IsPaused() {
|
||||
_, err = s.startClientTunnels(client, req.Remotes)
|
||||
_, err = s.startClientTunnels(client, req.Remotes, clog)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = s.repo.Save(client)
|
||||
err = repo.Save(client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: (rs): should we keep this?
|
||||
_, totalClients := repo.GetAllActiveClients()
|
||||
s.log().Debugf("total clients = %d (last: %s)", totalClients, client.GetName())
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
@ -426,8 +403,8 @@ func getRemotes(tunnels []*clienttunnel.Tunnel) []*models.Remote {
|
||||
return r
|
||||
}
|
||||
|
||||
// GetTunnelsToReestablish returns old tunnels that should be re-establish taking into account new tunnels.
|
||||
func GetTunnelsToReestablish(old, new []*models.Remote) []*models.Remote {
|
||||
// getTunnelsToReestablish returns old tunnels that should be re-establish taking into account new tunnels.
|
||||
func getTunnelsToReestablish(old, new []*models.Remote) []*models.Remote {
|
||||
if len(new) > len(old) {
|
||||
return nil
|
||||
}
|
||||
@ -484,11 +461,9 @@ loop2:
|
||||
|
||||
// StartClientTunnels returns a new tunnel for each requested remote or nil if error occurred
|
||||
func (s *ClientServiceProvider) StartClientTunnels(client *Client, remotes []*models.Remote) ([]*clienttunnel.Tunnel, error) {
|
||||
s.logger.Debugf("starting client tunnels: %s", client.ID)
|
||||
s.logger.Debugf("starting client tunnels: %s", client.GetID())
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
newTunnels, err := s.startClientTunnels(client, remotes)
|
||||
newTunnels, err := s.startClientTunnels(client, remotes, s.log())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -501,7 +476,7 @@ func (s *ClientServiceProvider) StartClientTunnels(client *Client, remotes []*mo
|
||||
return newTunnels, err
|
||||
}
|
||||
|
||||
func (s *ClientServiceProvider) startClientTunnels(client *Client, remotes []*models.Remote) ([]*clienttunnel.Tunnel, error) {
|
||||
func (s *ClientServiceProvider) startClientTunnels(client *Client, remotes []*models.Remote, clog *logger.Logger) ([]*clienttunnel.Tunnel, error) {
|
||||
err := s.portDistributor.Refresh()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -510,7 +485,7 @@ func (s *ClientServiceProvider) startClientTunnels(client *Client, remotes []*mo
|
||||
tunnels := make([]*clienttunnel.Tunnel, 0, len(remotes))
|
||||
for _, remote := range remotes {
|
||||
if !remote.IsLocalSpecified() {
|
||||
s.logger.Debugf("no local specified")
|
||||
clog.Debugf("no local specified")
|
||||
port, err := s.portDistributor.GetRandomPort(remote.Protocol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -518,15 +493,15 @@ func (s *ClientServiceProvider) startClientTunnels(client *Client, remotes []*mo
|
||||
remote.LocalPort = strconv.Itoa(port)
|
||||
remote.LocalHost = models.ZeroHost
|
||||
remote.LocalPortRandom = true
|
||||
s.logger.Debugf("using random port %s", remote.LocalPort)
|
||||
clog.Debugf("using random port %s", remote.LocalPort)
|
||||
} else {
|
||||
s.logger.Debugf("checking local port %s", remote.LocalPort)
|
||||
clog.Debugf("checking local port %s", remote.LocalPort)
|
||||
if err := s.checkLocalPort(remote.Protocol, remote.LocalPort); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Debugf("initiating tunnel %+v", remote)
|
||||
clog.Debugf("initiating tunnel %+v", remote)
|
||||
|
||||
var acl *clienttunnel.TunnelACL
|
||||
if remote.ACL != nil {
|
||||
@ -537,9 +512,10 @@ func (s *ClientServiceProvider) startClientTunnels(client *Client, remotes []*mo
|
||||
}
|
||||
}
|
||||
|
||||
s.logger.Debugf("starting tunnel: %s", remote)
|
||||
clog.Debugf("starting tunnel: %s", remote)
|
||||
t, err := s.StartTunnel(client, remote, acl)
|
||||
if err != nil {
|
||||
clog.Debugf("failed starting tunnel: %s: %v", remote, err)
|
||||
return nil, apiErrors.APIError{
|
||||
HTTPStatus: http.StatusConflict,
|
||||
Err: fmt.Errorf("unable to start tunnel: %s", err),
|
||||
@ -569,18 +545,16 @@ func (s *ClientServiceProvider) checkLocalPort(protocol, port string) error {
|
||||
}
|
||||
|
||||
func (s *ClientServiceProvider) Terminate(client *Client) error {
|
||||
s.logger.Infof("terminating client: %s", client.ID)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.repo.KeepDisconnectedClients != nil && *s.repo.KeepDisconnectedClients == 0 {
|
||||
s.log().Infof("terminating client: %s: %s", client.GetID(), client.GetName())
|
||||
keepDisconnectedClientsDuration := s.repo.GetKeepDisconnectedClients()
|
||||
if keepDisconnectedClientsDuration != nil && *keepDisconnectedClientsDuration == 0 {
|
||||
return s.repo.Delete(client)
|
||||
}
|
||||
|
||||
client.SetDisconnectedNow()
|
||||
|
||||
// Do not save if client doesn't exist in repo - it was force deleted
|
||||
existing, err := s.repo.GetByID(client.ID)
|
||||
existing, err := s.repo.GetByID(client.GetID())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -596,11 +570,9 @@ func (s *ClientServiceProvider) Terminate(client *Client) error {
|
||||
// ForceDelete deletes client from repo regardless off KeepDisconnectedClients setting,
|
||||
// if client is active it will be closed
|
||||
func (s *ClientServiceProvider) ForceDelete(client *Client) error {
|
||||
s.logger.Debugf("force deleting client: %s", client.ID)
|
||||
s.logger.Debugf("force deleting client: %s", client.GetID())
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if client.DisconnectedAt == nil {
|
||||
if client.IsConnected() {
|
||||
if err := client.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -611,12 +583,12 @@ func (s *ClientServiceProvider) ForceDelete(client *Client) error {
|
||||
func (s *ClientServiceProvider) DeleteOffline(clientID string) error {
|
||||
s.logger.Debugf("deleting offline client: %s", clientID)
|
||||
|
||||
existing, err := s.getExistingByID(clientID)
|
||||
existing, err := s.getExistingClientByID(clientID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if existing.DisconnectedAt == nil {
|
||||
if existing.IsConnected() {
|
||||
return apiErrors.APIError{
|
||||
Message: "Client is active, should be disconnected",
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
@ -628,8 +600,8 @@ func (s *ClientServiceProvider) DeleteOffline(clientID string) error {
|
||||
|
||||
// isClientAuthIDInUse returns true when the client with different id exists for the client auth
|
||||
func (s *ClientServiceProvider) isClientAuthIDInUse(clientAuthID, clientID string) bool {
|
||||
for _, s := range s.repo.GetAllByClientAuthID(clientAuthID) {
|
||||
if s.ID != clientID {
|
||||
for _, client := range s.repo.GetAllByClientAuthID(clientAuthID) {
|
||||
if client.GetID() != clientID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@ -637,40 +609,40 @@ func (s *ClientServiceProvider) isClientAuthIDInUse(clientAuthID, clientID strin
|
||||
}
|
||||
|
||||
func (s *ClientServiceProvider) SetACL(clientID string, allowedUserGroups []string) error {
|
||||
existing, err := s.getExistingByID(clientID)
|
||||
client, err := s.getExistingClientByID(clientID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
existing.AllowedUserGroups = allowedUserGroups
|
||||
client.SetAllowedUserGroups(allowedUserGroups)
|
||||
|
||||
return s.repo.Save(existing)
|
||||
return s.repo.Save(client)
|
||||
}
|
||||
|
||||
func (s *ClientServiceProvider) SetUpdatesStatus(clientID string, updatesStatus *models.UpdatesStatus) error {
|
||||
existing, err := s.getExistingByID(clientID)
|
||||
client, err := s.getExistingClientByID(clientID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
existing.UpdatesStatus = updatesStatus
|
||||
client.SetUpdatesStatus(updatesStatus)
|
||||
|
||||
return s.repo.Save(existing)
|
||||
return s.repo.Save(client)
|
||||
}
|
||||
|
||||
func (s *ClientServiceProvider) SetLastHeartbeat(clientID string, heartbeat time.Time) error {
|
||||
existing, err := s.getExistingByID(clientID)
|
||||
existing, err := s.getExistingClientByID(clientID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
existing.SetHeartbeat(&heartbeat)
|
||||
existing.SetLastHeartbeatAt(&heartbeat)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckClientAccess returns nil if a given user has an access to a given client.
|
||||
// Otherwise, APIError with 403 is returned.
|
||||
func (s *ClientServiceProvider) CheckClientAccess(clientID string, user User, groups []*cgroups.ClientGroup) error {
|
||||
existing, err := s.getExistingByID(clientID)
|
||||
existing, err := s.getExistingClientByID(clientID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -687,11 +659,13 @@ func (s *ClientServiceProvider) CheckClientsAccess(clients []*Client, user User,
|
||||
|
||||
var clientsWithNoAccess []string
|
||||
userGroups := user.GetGroups()
|
||||
for _, curClient := range clients {
|
||||
if curClient.HasAccessViaUserGroups(userGroups) || curClient.UserGroupHasAccessViaClientGroup(userGroups, clientGroups) {
|
||||
for _, client := range clients {
|
||||
clientID := client.GetID()
|
||||
if client.HasAccessViaUserGroups(userGroups) || client.UserGroupHasAccessViaClientGroup(userGroups, clientGroups) {
|
||||
continue
|
||||
}
|
||||
clientsWithNoAccess = append(clientsWithNoAccess, curClient.ID)
|
||||
|
||||
clientsWithNoAccess = append(clientsWithNoAccess, clientID)
|
||||
}
|
||||
|
||||
if len(clientsWithNoAccess) > 0 {
|
||||
@ -704,8 +678,8 @@ func (s *ClientServiceProvider) CheckClientsAccess(clients []*Client, user User,
|
||||
return nil
|
||||
}
|
||||
|
||||
// getExistingByID returns non-nil client by id. If not found or failed to get a client - an error is returned.
|
||||
func (s *ClientServiceProvider) getExistingByID(clientID string) (*Client, error) {
|
||||
// getExistingClientByID returns non-nil client by id. If not found or failed to get a client - an error is returned.
|
||||
func (s *ClientServiceProvider) getExistingClientByID(clientID string) (*Client, error) {
|
||||
if clientID == "" {
|
||||
return nil, apiErrors.APIError{
|
||||
Message: "Client id is empty",
|
||||
@ -729,13 +703,15 @@ func (s *ClientServiceProvider) getExistingByID(clientID string) (*Client, error
|
||||
}
|
||||
|
||||
func (s *ClientServiceProvider) GetRepo() *ClientRepository {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.repo
|
||||
}
|
||||
|
||||
func ExcludeNotAllowedTunnels(clog *logger.Logger, tunnels []*models.Remote, conn ssh.Conn) ([]*models.Remote, error) {
|
||||
func (s *ClientServiceProvider) excludeNotAllowedTunnels(clog *logger.Logger, tunnels []*models.Remote, conn ssh.Conn) ([]*models.Remote, error) {
|
||||
filtered := make([]*models.Remote, 0, len(tunnels))
|
||||
for _, t := range tunnels {
|
||||
allowed, err := clienttunnel.IsAllowed(t.Remote(), conn)
|
||||
allowed, err := clienttunnel.IsAllowed(t.Remote(), conn, s.log())
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "unknown request") {
|
||||
return tunnels, nil
|
||||
@ -751,19 +727,21 @@ func ExcludeNotAllowedTunnels(clog *logger.Logger, tunnels []*models.Remote, con
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
// TODO: (rs): can this move to the tunnel package?
|
||||
func (s *ClientServiceProvider) FindTunnelByRemote(c *Client, r *models.Remote) *clienttunnel.Tunnel {
|
||||
for _, curr := range c.Tunnels {
|
||||
if curr.Equals(r) {
|
||||
return curr
|
||||
for _, tunnel := range c.GetTunnels() {
|
||||
if tunnel.Equals(r) {
|
||||
return tunnel
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: (rs): can this move to the tunnel package?
|
||||
func (s *ClientServiceProvider) FindTunnel(c *Client, id string) *clienttunnel.Tunnel {
|
||||
for _, curr := range c.Tunnels {
|
||||
if curr.ID == id {
|
||||
return curr
|
||||
for _, tunnel := range c.GetTunnels() {
|
||||
if tunnel.ID == id {
|
||||
return tunnel
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@ -783,9 +761,9 @@ func (s *ClientServiceProvider) StartTunnel(
|
||||
return tunnel, nil
|
||||
}
|
||||
|
||||
s.logger.Debugf("starting tunnel: %s", remote)
|
||||
s.log().Debugf("starting tunnel: %s", remote)
|
||||
|
||||
ctx := client.Context
|
||||
ctx := client.GetContext()
|
||||
if remote.AutoClose > 0 {
|
||||
// no need to cancel the ctx since it will be canceled by parent ctx or after given timeout
|
||||
ctx, _ = context.WithTimeout(ctx, remote.AutoClose) // nolint: govet
|
||||
@ -800,9 +778,9 @@ func (s *ClientServiceProvider) StartTunnel(
|
||||
if remote.HasSubdomainTunnel() {
|
||||
err = s.startCaddyDownstreamProxy(ctx, client, remote, tunnel)
|
||||
if err != nil {
|
||||
tunnelStopErr := tunnel.InternalTunnelProxy.Stop(client.Context)
|
||||
tunnelStopErr := tunnel.InternalTunnelProxy.Stop(client.GetContext())
|
||||
if tunnelStopErr != nil {
|
||||
client.Logger.Infof("unable to stop internal tunnel proxy after failing to create caddy downstream proxy: %s", tunnelStopErr)
|
||||
client.Log().Infof("unable to stop internal tunnel proxy after failing to create caddy downstream proxy: %s", tunnelStopErr)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
@ -824,7 +802,10 @@ func (s *ClientServiceProvider) StartTunnel(
|
||||
go s.terminateTunnelOnIdleTimeout(ctx, tunnel, client)
|
||||
}
|
||||
|
||||
client.Tunnels = append(client.Tunnels, tunnel)
|
||||
existingTunnels := client.GetTunnels()
|
||||
existingTunnels = append(existingTunnels, tunnel)
|
||||
client.SetTunnels(existingTunnels)
|
||||
|
||||
return tunnel, nil
|
||||
}
|
||||
|
||||
@ -834,9 +815,11 @@ func (s *ClientServiceProvider) startCaddyDownstreamProxy(
|
||||
remote *models.Remote,
|
||||
tunnel *clienttunnel.Tunnel,
|
||||
) (err error) {
|
||||
client.Logger.Infof("starting downstream caddy proxy at %s", remote.TunnelURL)
|
||||
client.Logger.Debugf("tunnel = %#v", tunnel)
|
||||
client.Logger.Debugf("remote = %#v", remote)
|
||||
clientLogger := client.Log()
|
||||
|
||||
clientLogger.Infof("starting downstream caddy proxy at %s", remote.TunnelURL)
|
||||
clientLogger.Debugf("tunnel = %#v", tunnel)
|
||||
clientLogger.Debugf("remote = %#v", remote)
|
||||
|
||||
subdomain, basedomain, err := remote.GetTunnelDomains()
|
||||
if err != nil {
|
||||
@ -851,7 +834,7 @@ func (s *ClientServiceProvider) startCaddyDownstreamProxy(
|
||||
DownstreamProxyBaseDomain: basedomain,
|
||||
}
|
||||
|
||||
client.Logger.Debugf("requesting new caddy route = %+v", nrr)
|
||||
clientLogger.Debugf("requesting new caddy route = %+v", nrr)
|
||||
|
||||
res, err := s.caddyAPI.AddRoute(ctx, nrr)
|
||||
if err != nil {
|
||||
@ -862,14 +845,14 @@ func (s *ClientServiceProvider) startCaddyDownstreamProxy(
|
||||
return fmt.Errorf("failed to create downstream caddy proxy: status_code: %d", res.StatusCode)
|
||||
}
|
||||
|
||||
client.Logger.Infof("started downstream caddy proxy at %s to %s:%s", remote.TunnelURL, tunnel.LocalHost, tunnel.LocalPort)
|
||||
clientLogger.Infof("started downstream caddy proxy at %s to %s:%s", remote.TunnelURL, tunnel.LocalHost, tunnel.LocalPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ClientServiceProvider) startRegularTunnel(ctx context.Context, client *Client, remote *models.Remote, acl *clienttunnel.TunnelACL) (*clienttunnel.Tunnel, error) {
|
||||
tunnelID := client.NewTunnelID()
|
||||
|
||||
tunnel, err := clienttunnel.NewTunnel(client.Logger, client.Connection, tunnelID, *remote, acl)
|
||||
tunnel, err := clienttunnel.NewTunnel(client.Log(), client.GetConnection(), tunnelID, *remote, acl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -888,12 +871,14 @@ func (s *ClientServiceProvider) startTunnelWithProxy(
|
||||
remote *models.Remote,
|
||||
acl *clienttunnel.TunnelACL,
|
||||
) (*clienttunnel.Tunnel, error) {
|
||||
var proxyACL *clienttunnel.TunnelACL
|
||||
proxyHost := ""
|
||||
proxyPort := ""
|
||||
var proxyACL *clienttunnel.TunnelACL
|
||||
|
||||
clientLogger := client.Log()
|
||||
|
||||
// assuming that we still want to log activity in the client log
|
||||
client.Logger.Debugf("client %s will use tunnel proxy", client.ID)
|
||||
client.Logger.Debugf("client %s will use tunnel proxy", client.GetID())
|
||||
|
||||
// get values for tunnel proxy local host addr from original remote
|
||||
proxyHost = remote.LocalHost
|
||||
@ -913,7 +898,7 @@ func (s *ClientServiceProvider) startTunnelWithProxy(
|
||||
tunnelID := client.NewTunnelID()
|
||||
|
||||
// original tunnel will use the reconfigured original remote
|
||||
t, err := clienttunnel.NewTunnel(client.Logger, client.Connection, tunnelID, *remote, acl)
|
||||
t, err := clienttunnel.NewTunnel(client.Logger, client.GetConnection(), tunnelID, *remote, acl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -925,8 +910,8 @@ func (s *ClientServiceProvider) startTunnelWithProxy(
|
||||
}
|
||||
|
||||
// create new proxy tunnel listening at the original tunnel local host addr
|
||||
tProxy := clienttunnel.NewInternalTunnelProxy(t, client.Logger, s.tunnelProxyConfig, proxyHost, proxyPort, proxyACL)
|
||||
client.Logger.Debugf("client %s starting tunnel proxy", client.ID)
|
||||
tProxy := clienttunnel.NewInternalTunnelProxy(t, client.Log(), s.tunnelProxyConfig, proxyHost, proxyPort, proxyACL)
|
||||
clientLogger.Debugf("client %s starting tunnel proxy", client.GetID())
|
||||
if err := tProxy.Start(ctx); err != nil {
|
||||
client.Logger.Debugf("tunnel proxy could not be started, tunnel must be terminated: %v", err)
|
||||
if tErr := t.Terminate(true); tErr != nil {
|
||||
@ -941,8 +926,8 @@ func (s *ClientServiceProvider) startTunnelWithProxy(
|
||||
t.Remote.LocalHost = t.InternalTunnelProxy.Host
|
||||
t.Remote.LocalPort = t.InternalTunnelProxy.Port
|
||||
|
||||
client.Logger.Debugf("client %s started tunnel with proxy: %#v", client.ID, t)
|
||||
client.Logger.Debugf("internal tunnel proxy: %#v", t.InternalTunnelProxy)
|
||||
clientLogger.Debugf("client %s started tunnel with proxy: %#v", client.GetID(), t)
|
||||
clientLogger.Debugf("internal tunnel proxy: %#v", t.InternalTunnelProxy)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
@ -966,7 +951,7 @@ func (s *ClientServiceProvider) terminateTunnelOnIdleTimeout(ctx context.Context
|
||||
case <-timer.C:
|
||||
sinceLastActive := time.Since(t.LastActive())
|
||||
if sinceLastActive > idleTimeout {
|
||||
c.Logger.Infof("Terminating... inactivity period is reached: %d minute(s)", t.IdleTimeoutMinutes)
|
||||
c.Log().Infof("Terminating... inactivity period is reached: %d minute(s)", t.IdleTimeoutMinutes)
|
||||
_ = t.Terminate(true)
|
||||
s.cleanupAfterAutoClose(c, t)
|
||||
return
|
||||
@ -977,15 +962,14 @@ func (s *ClientServiceProvider) terminateTunnelOnIdleTimeout(ctx context.Context
|
||||
}
|
||||
|
||||
func (s *ClientServiceProvider) cleanupAfterAutoClose(c *Client, t *clienttunnel.Tunnel) {
|
||||
c.Lock()
|
||||
defer c.Unlock()
|
||||
clientLogger := c.Log()
|
||||
|
||||
c.Logger.Infof("Auto closing tunnel %s ...", t.ID)
|
||||
clientLogger.Infof("Auto closing tunnel %s ...", t.ID)
|
||||
|
||||
//stop tunnel proxy
|
||||
if t.InternalTunnelProxy != nil {
|
||||
if err := t.InternalTunnelProxy.Stop(c.Context); err != nil {
|
||||
c.Logger.Errorf("error while stopping tunnel proxy: %v", err)
|
||||
if err := t.InternalTunnelProxy.Stop(c.GetContext()); err != nil {
|
||||
clientLogger.Errorf("error while stopping tunnel proxy: %v", err)
|
||||
}
|
||||
if t.Remote.HasSubdomainTunnel() {
|
||||
_ = s.removeCaddyDownstreamProxy(c, t)
|
||||
@ -996,14 +980,16 @@ func (s *ClientServiceProvider) cleanupAfterAutoClose(c *Client, t *clienttunnel
|
||||
|
||||
err := s.repo.Save(c)
|
||||
if err != nil {
|
||||
c.Logger.Errorf("unable to save client after auto close cleanup: %v", err)
|
||||
clientLogger.Errorf("unable to save client after auto close cleanup: %v", err)
|
||||
}
|
||||
|
||||
c.Logger.Debugf("auto closed tunnel with id=%s removed", t.ID)
|
||||
clientLogger.Debugf("auto closed tunnel with id=%s removed", t.ID)
|
||||
}
|
||||
|
||||
func (s *ClientServiceProvider) TerminateTunnel(c *Client, t *clienttunnel.Tunnel, force bool) error {
|
||||
c.Logger.Infof("Terminating tunnel %s (force: %v) ...", t.ID, force)
|
||||
clientLogger := c.Log()
|
||||
|
||||
clientLogger.Infof("Terminating tunnel %s (force: %v) ...", t.ID, force)
|
||||
|
||||
err := t.Terminate(force)
|
||||
if err != nil {
|
||||
@ -1011,8 +997,8 @@ func (s *ClientServiceProvider) TerminateTunnel(c *Client, t *clienttunnel.Tunne
|
||||
}
|
||||
|
||||
if t.InternalTunnelProxy != nil {
|
||||
if err := t.InternalTunnelProxy.Stop(c.Context); err != nil {
|
||||
c.Logger.Errorf("error while stopping tunnel proxy: %v", err)
|
||||
if err := t.InternalTunnelProxy.Stop(c.GetContext()); err != nil {
|
||||
clientLogger.Errorf("error while stopping tunnel proxy: %v", err)
|
||||
}
|
||||
if t.Remote.HasSubdomainTunnel() {
|
||||
_ = s.removeCaddyDownstreamProxy(c, t)
|
||||
@ -1026,10 +1012,10 @@ func (s *ClientServiceProvider) TerminateTunnel(c *Client, t *clienttunnel.Tunne
|
||||
|
||||
err = s.repo.Save(c)
|
||||
if err != nil {
|
||||
c.Logger.Errorf("unable to save client after auto close cleanup: %v", err)
|
||||
clientLogger.Errorf("unable to save client after auto close cleanup: %v", err)
|
||||
}
|
||||
|
||||
c.Logger.Debugf("terminated tunnel with id=%s removed", t.ID)
|
||||
clientLogger.Debugf("terminated tunnel with id=%s removed", t.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -1059,14 +1045,16 @@ func (s *ClientServiceProvider) SetTunnelACL(c *Client, t *clienttunnel.Tunnel,
|
||||
}
|
||||
|
||||
func (s *ClientServiceProvider) removeCaddyDownstreamProxy(c *Client, t *clienttunnel.Tunnel) (err error) {
|
||||
c.Logger.Infof("removing downstream caddy proxy at %s", t.Remote.TunnelURL)
|
||||
clientLogger := c.Log()
|
||||
|
||||
clientLogger.Infof("removing downstream caddy proxy at %s", t.Remote.TunnelURL)
|
||||
|
||||
subdomain, _, err := t.Remote.GetTunnelDomains()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res, err := s.caddyAPI.DeleteRoute(c.Context, subdomain)
|
||||
res, err := s.caddyAPI.DeleteRoute(c.GetContext(), subdomain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -1075,6 +1063,6 @@ func (s *ClientServiceProvider) removeCaddyDownstreamProxy(c *Client, t *clientt
|
||||
return fmt.Errorf("failed to delete downstream caddy proxy: status_code: %d", res.StatusCode)
|
||||
}
|
||||
|
||||
c.Logger.Infof("removed downstream caddy proxy at %s", t.Remote.TunnelURL)
|
||||
clientLogger.Infof("removed downstream caddy proxy at %s", t.Remote.TunnelURL)
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -4,23 +4,30 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"path"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/cloudradar-monitoring/rport/server/caddy"
|
||||
"github.com/cloudradar-monitoring/rport/server/cgroups"
|
||||
"github.com/cloudradar-monitoring/rport/server/clients/clienttunnel"
|
||||
"github.com/cloudradar-monitoring/rport/server/clientsauth"
|
||||
|
||||
mapset "github.com/deckarep/golang-set"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
errors2 "github.com/cloudradar-monitoring/rport/server/api/errors"
|
||||
clientsmigration "github.com/cloudradar-monitoring/rport/db/migration/clients"
|
||||
"github.com/cloudradar-monitoring/rport/db/sqlite"
|
||||
"github.com/cloudradar-monitoring/rport/server/caddy"
|
||||
"github.com/cloudradar-monitoring/rport/server/cgroups"
|
||||
"github.com/cloudradar-monitoring/rport/server/clients/clienttunnel"
|
||||
"github.com/cloudradar-monitoring/rport/server/clientsauth"
|
||||
|
||||
apiErrors "github.com/cloudradar-monitoring/rport/server/api/errors"
|
||||
"github.com/cloudradar-monitoring/rport/server/api/users"
|
||||
"github.com/cloudradar-monitoring/rport/server/ports"
|
||||
chshare "github.com/cloudradar-monitoring/rport/share"
|
||||
@ -93,7 +100,8 @@ func TestStartClient(t *testing.T) {
|
||||
ID: "test-client",
|
||||
ClientAuthID: "test-client-auth",
|
||||
}}, nil, testLog),
|
||||
portDistributor: ports.NewPortDistributor(mapset.NewThreadUnsafeSet()),
|
||||
portDistributor: ports.NewPortDistributor(mapset.NewSet()),
|
||||
logger: testLog,
|
||||
}
|
||||
_, err := cs.StartClient(
|
||||
context.Background(), tc.ClientAuthID, tc.ClientID, connMock, tc.AuthMultiuseCreds,
|
||||
@ -103,6 +111,90 @@ func TestStartClient(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// this is a fairly crude concurrency test for start client. currently excluded from the regular test runs as
|
||||
// it consumes a moderate amount of memory and takes some time to run.
|
||||
// go test -count=1 -race -v github.com/cloudradar-monitoring/rport/server/clients -run TestStartClientConcurrency
|
||||
func TestStartClientConcurrency(t *testing.T) {
|
||||
t.Skip()
|
||||
|
||||
// runtime.GOMAXPROCS(runtime.NumCPU() - 1)
|
||||
runtime.GOMAXPROCS(8)
|
||||
|
||||
sourceOptions := sqlite.DataSourceOptions{
|
||||
MaxOpenConnections: 250,
|
||||
WALEnabled: false,
|
||||
}
|
||||
|
||||
clientDB, err := sqlite.New(
|
||||
path.Join("./", "clients.db"),
|
||||
clientsmigration.AssetNames(),
|
||||
clientsmigration.Asset,
|
||||
sourceOptions,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer os.Remove("./clients.db")
|
||||
defer os.Remove("./clients.db-shm")
|
||||
|
||||
pd := ports.NewPortDistributor(mapset.NewSet())
|
||||
|
||||
totalClients := 1500
|
||||
|
||||
clients := []*Client{{}}
|
||||
for i := 0; i < totalClients; i++ {
|
||||
client := Client{
|
||||
Name: "test-name-" + strconv.Itoa(i),
|
||||
ID: "test-id-" + strconv.Itoa(i),
|
||||
ClientAuthID: "test-client-auth-" + strconv.Itoa(i),
|
||||
Version: "0.6.4",
|
||||
}
|
||||
clients = append(clients, &client)
|
||||
}
|
||||
|
||||
mockConns := []*test.ConnMock{}
|
||||
for i := 0; i < totalClients; i++ {
|
||||
mockConn := test.NewConnMock()
|
||||
mockConn.ReturnRemoteAddr = &net.TCPAddr{IP: net.IPv4(192, 0, 2, 1), Port: 2000}
|
||||
mockConns = append(mockConns, mockConn)
|
||||
}
|
||||
|
||||
repo, err := InitClientRepository(context.Background(), clientDB, nil, testLog)
|
||||
require.NoError(t, err)
|
||||
|
||||
cs := &ClientServiceProvider{
|
||||
repo: repo,
|
||||
portDistributor: pd,
|
||||
logger: testLog,
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
for i := 0; i < totalClients; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
time.Sleep(time.Duration(rand.Intn(10)) * time.Millisecond)
|
||||
client := clients[i]
|
||||
|
||||
_, err := cs.StartClient(
|
||||
context.Background(),
|
||||
client.GetClientAuthID(),
|
||||
client.GetID(),
|
||||
mockConns[i],
|
||||
false,
|
||||
&chshare.ConnectionRequest{
|
||||
Name: client.GetName(),
|
||||
},
|
||||
testLog,
|
||||
)
|
||||
|
||||
assert.NoError(t, err)
|
||||
|
||||
wg.Done()
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestStartClientDisconnected(t *testing.T) {
|
||||
connMock := test.NewConnMock()
|
||||
connMock.ReturnRemoteAddr = &net.TCPAddr{IP: net.IPv4(192, 0, 2, 1), Port: 2345}
|
||||
@ -115,25 +207,27 @@ func TestStartClientDisconnected(t *testing.T) {
|
||||
AllowedUserGroups: []string{"test-group"},
|
||||
UpdatesStatus: &models.UpdatesStatus{UpdatesAvailable: 13},
|
||||
}}, nil, testLog),
|
||||
portDistributor: ports.NewPortDistributor(mapset.NewThreadUnsafeSet()),
|
||||
portDistributor: ports.NewPortDistributor(mapset.NewSet()),
|
||||
logger: testLog,
|
||||
}
|
||||
|
||||
client, err := cs.StartClient(
|
||||
context.Background(), "test-client-auth", "disconnected-client", connMock, false,
|
||||
&chshare.ConnectionRequest{Name: "new-connection", Version: "0.7.0"}, testLog)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Nil(t, client.DisconnectedAt)
|
||||
assert.Equal(t, "disconnected-client", client.ID)
|
||||
assert.Equal(t, "new-connection", client.Name)
|
||||
assert.Equal(t, []string{"test-group"}, client.AllowedUserGroups)
|
||||
assert.Equal(t, "disconnected-client", client.GetID())
|
||||
assert.Equal(t, "new-connection", client.GetName())
|
||||
assert.Equal(t, []string{"test-group"}, client.GetAllowedUserGroups())
|
||||
assert.Equal(t, 13, client.UpdatesStatus.UpdatesAvailable)
|
||||
}
|
||||
|
||||
func TestDeleteOfflineClient(t *testing.T) {
|
||||
c1Active := New(t).Build()
|
||||
c2Active := New(t).Build()
|
||||
c3Offline := New(t).DisconnectedDuration(5 * time.Minute).Build()
|
||||
c4Offline := New(t).DisconnectedDuration(time.Minute).Build()
|
||||
c1Active := New(t).Logger(testLog).Build()
|
||||
c2Active := New(t).Logger(testLog).Build()
|
||||
c3Offline := New(t).DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
|
||||
c4Offline := New(t).DisconnectedDuration(time.Minute).Logger(testLog).Build()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@ -142,13 +236,13 @@ func TestDeleteOfflineClient(t *testing.T) {
|
||||
}{
|
||||
{
|
||||
name: "delete offline client",
|
||||
clientID: c3Offline.ID,
|
||||
clientID: c3Offline.GetID(),
|
||||
wantError: nil,
|
||||
},
|
||||
{
|
||||
name: "delete active client",
|
||||
clientID: c1Active.ID,
|
||||
wantError: errors2.APIError{
|
||||
clientID: c1Active.GetID(),
|
||||
wantError: apiErrors.APIError{
|
||||
Message: "Client is active, should be disconnected",
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
},
|
||||
@ -156,7 +250,7 @@ func TestDeleteOfflineClient(t *testing.T) {
|
||||
{
|
||||
name: "delete unknown client",
|
||||
clientID: "unknown-id",
|
||||
wantError: errors2.APIError{
|
||||
wantError: apiErrors.APIError{
|
||||
Message: fmt.Sprintf("Client with id=%q not found.", "unknown-id"),
|
||||
HTTPStatus: http.StatusNotFound,
|
||||
},
|
||||
@ -164,7 +258,7 @@ func TestDeleteOfflineClient(t *testing.T) {
|
||||
{
|
||||
name: "empty client ID",
|
||||
clientID: "",
|
||||
wantError: errors2.APIError{
|
||||
wantError: apiErrors.APIError{
|
||||
Message: "Client id is empty",
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
},
|
||||
@ -175,8 +269,7 @@ func TestDeleteOfflineClient(t *testing.T) {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// given
|
||||
clientService := NewClientService(nil, nil, NewClientRepository([]*Client{c1Active, c2Active, c3Offline, c4Offline}, &hour, testLog), testLog)
|
||||
before, err := clientService.Count()
|
||||
require.NoError(t, err)
|
||||
before := clientService.Count()
|
||||
require.Equal(t, 4, before)
|
||||
|
||||
// when
|
||||
@ -190,8 +283,7 @@ func TestDeleteOfflineClient(t *testing.T) {
|
||||
} else {
|
||||
wantAfter = before - 1
|
||||
}
|
||||
gotAfter, err := clientService.Count()
|
||||
require.NoError(t, err)
|
||||
gotAfter := clientService.Count()
|
||||
assert.Equal(t, wantAfter, gotAfter)
|
||||
})
|
||||
}
|
||||
@ -200,9 +292,9 @@ func TestDeleteOfflineClient(t *testing.T) {
|
||||
func TestCheckLocalPort(t *testing.T) {
|
||||
srv := ClientServiceProvider{
|
||||
portDistributor: ports.NewPortDistributorForTests(
|
||||
mapset.NewThreadUnsafeSetFromSlice([]interface{}{1, 2, 3, 4, 5}),
|
||||
mapset.NewThreadUnsafeSetFromSlice([]interface{}{2, 3, 4}),
|
||||
mapset.NewThreadUnsafeSetFromSlice([]interface{}{2, 3, 4, 5}),
|
||||
mapset.NewSetFromSlice([]interface{}{1, 2, 3, 4, 5}),
|
||||
mapset.NewSetFromSlice([]interface{}{2, 3, 4}),
|
||||
mapset.NewSetFromSlice([]interface{}{2, 3, 4, 5}),
|
||||
),
|
||||
}
|
||||
|
||||
@ -223,7 +315,7 @@ func TestCheckLocalPort(t *testing.T) {
|
||||
{
|
||||
name: "invalid port",
|
||||
port: invalidPort,
|
||||
wantError: errors2.APIError{
|
||||
wantError: apiErrors.APIError{
|
||||
Message: "Invalid local port: 24563a.",
|
||||
Err: invalidPortParseErr,
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
@ -232,7 +324,7 @@ func TestCheckLocalPort(t *testing.T) {
|
||||
{
|
||||
name: "not allowed port",
|
||||
port: "6",
|
||||
wantError: errors2.APIError{
|
||||
wantError: apiErrors.APIError{
|
||||
Message: "Local port 6 is not among allowed ports.",
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
},
|
||||
@ -240,7 +332,7 @@ func TestCheckLocalPort(t *testing.T) {
|
||||
{
|
||||
name: "busy port tcp",
|
||||
port: "5",
|
||||
wantError: errors2.APIError{
|
||||
wantError: apiErrors.APIError{
|
||||
Message: "Local port 5 already in use.",
|
||||
HTTPStatus: http.StatusConflict,
|
||||
},
|
||||
@ -255,7 +347,7 @@ func TestCheckLocalPort(t *testing.T) {
|
||||
name: "tcp+udp port busy",
|
||||
port: "5",
|
||||
protocol: models.ProtocolTCPUDP,
|
||||
wantError: errors2.APIError{
|
||||
wantError: apiErrors.APIError{
|
||||
Message: "Local port 5 already in use.",
|
||||
HTTPStatus: http.StatusConflict,
|
||||
},
|
||||
@ -283,13 +375,13 @@ func TestCheckLocalPort(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCheckClientsAccess(t *testing.T) {
|
||||
c1 := New(t).Build() // no groups
|
||||
c2 := New(t).AllowedUserGroups([]string{users.Administrators}).Build() // admin
|
||||
c3 := New(t).AllowedUserGroups([]string{users.Administrators, "group1"}).Build() // admin + group1
|
||||
c4 := New(t).AllowedUserGroups([]string{"group1"}).Build() // group1
|
||||
c5 := New(t).AllowedUserGroups([]string{"group1", "group2"}).Build() // group1 + group2
|
||||
c6 := New(t).AllowedUserGroups([]string{"group3"}).Build() // group3
|
||||
c7 := New(t).Build()
|
||||
c1 := New(t).Logger(testLog).Build() // no groups
|
||||
c2 := New(t).AllowedUserGroups([]string{users.Administrators}).Logger(testLog).Build() // admin
|
||||
c3 := New(t).AllowedUserGroups([]string{users.Administrators, "group1"}).Logger(testLog).Build() // admin + group1
|
||||
c4 := New(t).AllowedUserGroups([]string{"group1"}).Logger(testLog).Build() // group1
|
||||
c5 := New(t).AllowedUserGroups([]string{"group1", "group2"}).Logger(testLog).Build() // group1 + group2
|
||||
c6 := New(t).AllowedUserGroups([]string{"group3"}).Logger(testLog).Build() // group3
|
||||
c7 := New(t).Logger(testLog).Build()
|
||||
|
||||
allClients := []*Client{c1, c2, c3, c4, c5, c6}
|
||||
clientGroups := []*cgroups.ClientGroup{
|
||||
@ -297,7 +389,7 @@ func TestCheckClientsAccess(t *testing.T) {
|
||||
ID: "1",
|
||||
AllowedUserGroups: []string{"group4"},
|
||||
Params: &cgroups.ClientParams{
|
||||
ClientID: &cgroups.ParamValues{cgroups.Param(c7.ID)},
|
||||
ClientID: &cgroups.ParamValues{cgroups.Param(c7.GetID())},
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -312,7 +404,7 @@ func TestCheckClientsAccess(t *testing.T) {
|
||||
name: "user with no groups has no access",
|
||||
clients: allClients,
|
||||
user: &users.User{Groups: nil},
|
||||
wantClientIDsWithNoAccess: []string{c1.ID, c2.ID, c3.ID, c4.ID, c5.ID, c6.ID},
|
||||
wantClientIDsWithNoAccess: []string{c1.GetID(), c2.GetID(), c3.GetID(), c4.GetID(), c5.GetID(), c6.GetID()},
|
||||
},
|
||||
{
|
||||
name: "admin user has access to all",
|
||||
@ -330,25 +422,25 @@ func TestCheckClientsAccess(t *testing.T) {
|
||||
name: "non-admin user with no access to clients with no groups and with admin group",
|
||||
clients: allClients,
|
||||
user: &users.User{Groups: []string{"group1", "group2", "group3"}},
|
||||
wantClientIDsWithNoAccess: []string{c1.ID, c2.ID},
|
||||
wantClientIDsWithNoAccess: []string{c1.GetID(), c2.GetID()},
|
||||
},
|
||||
{
|
||||
name: "non-admin user with access to one client",
|
||||
clients: allClients,
|
||||
user: &users.User{Groups: []string{"group3"}},
|
||||
wantClientIDsWithNoAccess: []string{c1.ID, c2.ID, c3.ID, c4.ID, c5.ID},
|
||||
wantClientIDsWithNoAccess: []string{c1.GetID(), c2.GetID(), c3.GetID(), c4.GetID(), c5.GetID()},
|
||||
},
|
||||
{
|
||||
name: "non-admin user with access to few clients",
|
||||
clients: allClients,
|
||||
user: &users.User{Groups: []string{"group1"}},
|
||||
wantClientIDsWithNoAccess: []string{c1.ID, c2.ID, c6.ID},
|
||||
wantClientIDsWithNoAccess: []string{c1.GetID(), c2.GetID(), c6.GetID()},
|
||||
},
|
||||
{
|
||||
name: "non-admin user that has unknown group",
|
||||
clients: allClients,
|
||||
user: &users.User{Groups: []string{"group4"}},
|
||||
wantClientIDsWithNoAccess: []string{c1.ID, c2.ID, c3.ID, c4.ID, c5.ID, c6.ID},
|
||||
wantClientIDsWithNoAccess: []string{c1.GetID(), c2.GetID(), c3.GetID(), c4.GetID(), c5.GetID(), c6.GetID()},
|
||||
},
|
||||
{
|
||||
name: "non-admin user given access via client groups",
|
||||
@ -368,7 +460,7 @@ func TestCheckClientsAccess(t *testing.T) {
|
||||
|
||||
// then
|
||||
if len(tc.wantClientIDsWithNoAccess) > 0 {
|
||||
wantErr := errors2.APIError{
|
||||
wantErr := apiErrors.APIError{
|
||||
Message: fmt.Sprintf("Access denied to client(s) with ID(s): %v", strings.Join(tc.wantClientIDsWithNoAccess, ", ")),
|
||||
HTTPStatus: http.StatusForbidden,
|
||||
}
|
||||
@ -782,7 +874,7 @@ func TestGetTunnelsToReestablish(t *testing.T) {
|
||||
}
|
||||
|
||||
// when
|
||||
gotRes := GetTunnelsToReestablish(old, new)
|
||||
gotRes := getTunnelsToReestablish(old, new)
|
||||
|
||||
var gotResStr []string
|
||||
for _, r := range gotRes {
|
||||
@ -843,7 +935,7 @@ func TestShouldStartTunnelsWithSubdomains(t *testing.T) {
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
c1 := New(t).ID("client-1").ClientAuthID(cl1.ID).Build()
|
||||
c1 := New(t).ID("client-1").ClientAuthID(cl1.ID).Logger(testLog).Build()
|
||||
c1.Connection = connMock
|
||||
c1.Logger = testLog
|
||||
c1.Context = context.Background()
|
||||
@ -855,9 +947,9 @@ func TestShouldStartTunnelsWithSubdomains(t *testing.T) {
|
||||
}
|
||||
|
||||
pd := ports.NewPortDistributorForTests(
|
||||
mapset.NewThreadUnsafeSetFromSlice([]interface{}{4000, 4001, 4002, 4003, 4004, 4005}),
|
||||
mapset.NewThreadUnsafeSetFromSlice([]interface{}{4000, 4002, 4003, 4004, 4005}),
|
||||
mapset.NewThreadUnsafeSetFromSlice([]interface{}{5002, 5003, 5004, 5005}),
|
||||
mapset.NewSetFromSlice([]interface{}{4000, 4001, 4002, 4003, 4004, 4005}),
|
||||
mapset.NewSetFromSlice([]interface{}{4000, 4002, 4003, 4004, 4005}),
|
||||
mapset.NewSetFromSlice([]interface{}{5002, 5003, 5004, 5005}),
|
||||
)
|
||||
|
||||
mockCaddyAPI := &MockCaddyAPI{}
|
||||
|
||||
@ -5,11 +5,23 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/cloudradar-monitoring/rport/server/api/users"
|
||||
"github.com/cloudradar-monitoring/rport/server/cgroups"
|
||||
)
|
||||
|
||||
func NewTestClient(id string, address string, hostname string, clientAuthID string, connection ssh.Conn) (c *Client) {
|
||||
c = &Client{
|
||||
ID: id,
|
||||
Address: address,
|
||||
Hostname: hostname,
|
||||
ClientAuthID: clientAuthID,
|
||||
Connection: connection,
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func TestClientBelongsToGroup(t *testing.T) {
|
||||
c1 := &Client{
|
||||
ID: "test-client-id-1",
|
||||
|
||||
@ -6,14 +6,15 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/cloudradar-monitoring/rport/share/comm"
|
||||
"github.com/cloudradar-monitoring/rport/share/logger"
|
||||
)
|
||||
|
||||
func IsAllowed(remote string, conn ssh.Conn) (bool, error) {
|
||||
func IsAllowed(remote string, conn ssh.Conn, l *logger.Logger) (bool, error) {
|
||||
req := &comm.CheckTunnelAllowedRequest{
|
||||
Remote: remote,
|
||||
}
|
||||
resp := &comm.CheckTunnelAllowedResponse{}
|
||||
err := comm.SendRequestAndGetResponse(conn, comm.RequestTypeCheckTunnelAllowed, req, resp)
|
||||
err := comm.SendRequestAndGetResponse(conn, comm.RequestTypeCheckTunnelAllowed, req, resp, l)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "unknown request") {
|
||||
return true, nil
|
||||
|
||||
@ -15,13 +15,16 @@ import (
|
||||
)
|
||||
|
||||
type ClientRepository struct {
|
||||
// in-memory cache
|
||||
clients map[string]*Client
|
||||
mu sync.RWMutex
|
||||
KeepDisconnectedClients *time.Duration
|
||||
// storage
|
||||
store ClientStore
|
||||
// in-memory state
|
||||
clientState map[string]*Client
|
||||
// db based store
|
||||
clientStore ClientStore
|
||||
|
||||
keepDisconnectedClients *time.Duration
|
||||
|
||||
logger *logger.Logger
|
||||
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
type User interface {
|
||||
@ -40,14 +43,15 @@ func NewClientRepository(initClients []*Client, keepDisconnectedClients *time.Du
|
||||
func NewClientRepositoryWithDB(initialClients []*Client, keepDisconnectedClients *time.Duration, store ClientStore, logger *logger.Logger) *ClientRepository {
|
||||
clients := make(map[string]*Client)
|
||||
for i := range initialClients {
|
||||
clients[initialClients[i].ID] = initialClients[i]
|
||||
newClientID := initialClients[i].GetID()
|
||||
clients[newClientID] = initialClients[i]
|
||||
}
|
||||
|
||||
return &ClientRepository{
|
||||
clients: clients,
|
||||
KeepDisconnectedClients: keepDisconnectedClients,
|
||||
store: store,
|
||||
clientState: clients,
|
||||
clientStore: store,
|
||||
logger: logger,
|
||||
keepDisconnectedClients: keepDisconnectedClients,
|
||||
}
|
||||
}
|
||||
|
||||
@ -58,7 +62,7 @@ func InitClientRepository(
|
||||
logger *logger.Logger,
|
||||
) (*ClientRepository, error) {
|
||||
provider := newSqliteProvider(db, keepDisconnectedClients)
|
||||
initialClients, err := LoadInitialClients(ctx, provider, logger.Fork("client-loader"))
|
||||
initialClients, err := LoadInitialClients(ctx, provider, logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -66,52 +70,54 @@ func InitClientRepository(
|
||||
return NewClientRepositoryWithDB(initialClients, keepDisconnectedClients, provider, logger), nil
|
||||
}
|
||||
|
||||
func (s *ClientRepository) Save(client *Client) error {
|
||||
func (r *ClientRepository) Save(client *Client) error {
|
||||
ts := time.Now()
|
||||
if s.store != nil {
|
||||
err := s.store.Save(context.Background(), client)
|
||||
|
||||
store := r.getStore()
|
||||
|
||||
if store != nil {
|
||||
err := store.Save(context.Background(), client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save a client: %w", err)
|
||||
return fmt.Errorf("failed to save client: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.clients[client.ID] = client
|
||||
s.logger.Debugf(
|
||||
r.updateClient(client)
|
||||
|
||||
r.log().Debugf(
|
||||
"saved client: %s status=%s, within %s",
|
||||
client.ID,
|
||||
FormatConnectionState(client.DisconnectedAt),
|
||||
client.GetID(),
|
||||
FormatConnectionState(client),
|
||||
time.Since(ts),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ClientRepository) Delete(client *Client) error {
|
||||
s.logger.Debugf("deleting client: %s status=%s", client.ID, FormatConnectionState(client.DisconnectedAt))
|
||||
func (r *ClientRepository) Delete(client *Client) error {
|
||||
clientID := client.GetID()
|
||||
|
||||
if s.store != nil {
|
||||
err := s.store.Delete(context.Background(), client.ID)
|
||||
r.log().Debugf("deleting client: %s status=%s", clientID, FormatConnectionState(client))
|
||||
|
||||
store := r.getStore()
|
||||
|
||||
if store != nil {
|
||||
err := store.Delete(context.Background(), clientID, client.Log())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete a client: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.clients, client.ID)
|
||||
r.removeClient(clientID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ClientRepository) GetClientsByTag(tags []string, operator string, allowDisconnected bool) (matchingClients []*Client, err error) {
|
||||
func (r *ClientRepository) GetClientsByTag(tags []string, operator string, allowDisconnected bool) (matchingClients []*Client, err error) {
|
||||
var availableClients []*Client
|
||||
if allowDisconnected {
|
||||
availableClients, err = s.GetAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
availableClients = r.GetAllClients()
|
||||
} else {
|
||||
availableClients = s.GetAllActive()
|
||||
availableClients, _ = r.GetAllActiveClients()
|
||||
}
|
||||
if strings.EqualFold(operator, "AND") {
|
||||
matchingClients = findMatchingANDClients(availableClients, tags)
|
||||
@ -122,10 +128,13 @@ func (s *ClientRepository) GetClientsByTag(tags []string, operator string, allow
|
||||
return matchingClients, nil
|
||||
}
|
||||
|
||||
// this fn doesn't lock the availableClients. please make sure not to use the main clients array.
|
||||
// the various GetXXXClient fns will return new client arrays. please use those fns to get a
|
||||
// clients array copy for this fn to operate on.
|
||||
func findMatchingANDClients(availableClients []*Client, tags []string) (matchingClients []*Client) {
|
||||
matchingClients = make([]*Client, 0, 64)
|
||||
for _, cl := range availableClients {
|
||||
clientTags := cl.Tags
|
||||
clientTags := cl.GetTags()
|
||||
|
||||
foundAllTags := true
|
||||
for _, tag := range tags {
|
||||
@ -144,14 +153,18 @@ func findMatchingANDClients(availableClients []*Client, tags []string) (matching
|
||||
if foundAllTags {
|
||||
matchingClients = append(matchingClients, cl)
|
||||
}
|
||||
|
||||
}
|
||||
return matchingClients
|
||||
}
|
||||
|
||||
// this fn doesn't lock the availableClients. please make sure not to use the main clients array.
|
||||
// the various GetXXXClient fns will return new client arrays. please use those fns to get a
|
||||
// clients array copy for this fn to operate on.
|
||||
func findMatchingORClients(availableClients []*Client, tags []string) (matchingClients []*Client) {
|
||||
matchingClients = make([]*Client, 0, 64)
|
||||
for _, cl := range availableClients {
|
||||
clientTags := cl.Tags
|
||||
clientTags := cl.GetTags()
|
||||
nextClientForOR:
|
||||
for _, clTag := range clientTags {
|
||||
for _, tag := range tags {
|
||||
@ -166,57 +179,56 @@ func findMatchingORClients(availableClients []*Client, tags []string) (matchingC
|
||||
}
|
||||
|
||||
// DeleteObsolete deletes obsolete disconnected clients and returns them.
|
||||
func (s *ClientRepository) DeleteObsolete() ([]*Client, error) {
|
||||
s.logger.Debugf("deleting obsolete clients")
|
||||
if s.store != nil {
|
||||
err := s.store.DeleteObsolete(context.Background())
|
||||
func (r *ClientRepository) DeleteObsolete() ([]*Client, error) {
|
||||
r.log().Debugf("deleting obsolete clients")
|
||||
store := r.getStore()
|
||||
|
||||
if store != nil {
|
||||
err := store.DeleteObsolete(context.Background(), r.log())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to delete obsolete clients: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
var deleted []*Client
|
||||
for _, client := range s.clients {
|
||||
if client.Obsolete(s.KeepDisconnectedClients) {
|
||||
s.logger.Debugf("deleting obsolete client: %s status=%s", client.ID, FormatConnectionState(client.DisconnectedAt))
|
||||
r.mu.RLock()
|
||||
for _, client := range r.getClients() {
|
||||
r.mu.RUnlock()
|
||||
clientID := client.GetID()
|
||||
|
||||
if client.Obsolete(r.GetKeepDisconnectedClients()) {
|
||||
r.log().Debugf("deleting obsolete client: %s status=%s", clientID, FormatConnectionState(client))
|
||||
|
||||
r.removeClient(clientID)
|
||||
|
||||
delete(s.clients, client.ID)
|
||||
deleted = append(deleted, client)
|
||||
}
|
||||
r.mu.RLock()
|
||||
}
|
||||
r.mu.RUnlock()
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
// Count returns a number of non-obsolete active and disconnected clients.
|
||||
func (s *ClientRepository) Count() (int, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
clients, err := s.getNonObsolete()
|
||||
return len(clients), err
|
||||
func (r *ClientRepository) Count() int {
|
||||
_, count := r.getNonObsoleteClients()
|
||||
return count
|
||||
}
|
||||
|
||||
// CountActive returns a number of active clients.
|
||||
func (s *ClientRepository) CountActive() (int, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return len(s.GetAllActive()), nil
|
||||
func (r *ClientRepository) CountActive() (count int) {
|
||||
_, count = r.GetAllActiveClients()
|
||||
return count
|
||||
}
|
||||
|
||||
// CountDisconnected returns a number of disconnected clients.
|
||||
func (s *ClientRepository) CountDisconnected() (int, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
all, err := s.getNonObsolete()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
func (r *ClientRepository) CountDisconnected() (int, error) {
|
||||
availableClients, _ := r.getNonObsoleteClients()
|
||||
|
||||
var n int
|
||||
for _, cur := range all {
|
||||
if cur.DisconnectedAt != nil {
|
||||
// uses copy of clients array returned by getNonObsoleteClients
|
||||
for _, cur := range availableClients {
|
||||
if cur.GetDisconnectedAt() != nil {
|
||||
n++
|
||||
}
|
||||
}
|
||||
@ -224,114 +236,182 @@ func (s *ClientRepository) CountDisconnected() (int, error) {
|
||||
}
|
||||
|
||||
// GetByID returns non-obsolete active or disconnected client by a given id.
|
||||
func (s *ClientRepository) GetByID(id string) (*Client, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
client := s.clients[id]
|
||||
if client != nil && client.Obsolete(s.KeepDisconnectedClients) {
|
||||
func (r *ClientRepository) GetByID(id string) (*Client, error) {
|
||||
client := r.getClient(id)
|
||||
|
||||
if client != nil && client.Obsolete(r.GetKeepDisconnectedClients()) {
|
||||
return nil, nil
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// GetActiveByID returns an active client by a given id.
|
||||
func (s *ClientRepository) GetActiveByID(id string) (*Client, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
client := s.clients[id]
|
||||
if client != nil && client.DisconnectedAt != nil {
|
||||
func (r *ClientRepository) GetActiveByID(id string) (*Client, error) {
|
||||
client := r.getClient(id)
|
||||
|
||||
if client != nil && client.GetDisconnectedAt() != nil {
|
||||
return nil, nil
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// GetAllByClientAuthID @todo: make it consistent with others whether to return an error. In general it's just a cache, so should not return an err.
|
||||
func (s *ClientRepository) GetAllByClientAuthID(clientAuthID string) []*Client {
|
||||
all, _ := s.GetAll()
|
||||
var res []*Client
|
||||
for _, v := range all {
|
||||
if v.ClientAuthID == clientAuthID {
|
||||
res = append(res, v)
|
||||
func (r *ClientRepository) GetAllByClientAuthID(clientAuthID string) []*Client {
|
||||
availableClients := r.GetAllClients()
|
||||
var matchingClients []*Client
|
||||
// uses copy of clients array returned by GetAllClients
|
||||
for _, c := range availableClients {
|
||||
if c.GetClientAuthID() == clientAuthID {
|
||||
matchingClients = append(matchingClients, c)
|
||||
}
|
||||
}
|
||||
return res
|
||||
return matchingClients
|
||||
}
|
||||
|
||||
// GetAll returns all non-obsolete active and disconnected client clients.
|
||||
func (s *ClientRepository) GetAll() ([]*Client, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.getNonObsolete()
|
||||
func (r *ClientRepository) GetAllClients() []*Client {
|
||||
availableClients, _ := r.getNonObsoleteClients()
|
||||
return availableClients
|
||||
}
|
||||
|
||||
// GetUserClients returns all non-obsolete active and disconnected clients that current user has access to
|
||||
func (s *ClientRepository) GetUserClients(user User, groups []*cgroups.ClientGroup) ([]*Client, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.getNonObsoleteByUser(user, groups)
|
||||
func (r *ClientRepository) GetUserClients(user User, groups []*cgroups.ClientGroup) ([]*Client, error) {
|
||||
return r.getNonObsoleteClientsByUser(user, groups)
|
||||
}
|
||||
|
||||
// GetFilteredUserClients returns all non-obsolete active and disconnected clients that current user has access to, filtered by parameters
|
||||
func (s *ClientRepository) GetFilteredUserClients(user User, filterOptions []query.FilterOption, groups []*cgroups.ClientGroup) ([]*CalculatedClient, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
clients, err := s.getNonObsoleteByUser(user, groups)
|
||||
func (r *ClientRepository) GetFilteredUserClients(user User, filterOptions []query.FilterOption, groups []*cgroups.ClientGroup) ([]*CalculatedClient, error) {
|
||||
clients, err := r.getNonObsoleteClientsByUser(user, groups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*CalculatedClient, 0, len(clients))
|
||||
matchingClients := make([]*CalculatedClient, 0, len(clients))
|
||||
|
||||
// uses copy of clients array returned by getNonObsoleteClientsByUser
|
||||
for _, client := range clients {
|
||||
calculatedClient := client.ToCalculated(groups)
|
||||
|
||||
// we need to lock because MatchesFilters receives an interface and not a client,
|
||||
// therefore we lose our ability to lock.
|
||||
calculatedClient.flock.RLock()
|
||||
matches, err := query.MatchesFilters(calculatedClient, filterOptions)
|
||||
calculatedClient.flock.RUnlock()
|
||||
|
||||
if err != nil {
|
||||
return result, err
|
||||
return matchingClients, err
|
||||
}
|
||||
|
||||
if matches {
|
||||
result = append(result, calculatedClient)
|
||||
matchingClients = append(matchingClients, calculatedClient)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return result, nil
|
||||
return matchingClients, nil
|
||||
}
|
||||
|
||||
func (s *ClientRepository) GetAllActive() []*Client {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
var result []*Client
|
||||
for _, client := range s.clients {
|
||||
if client.DisconnectedAt == nil {
|
||||
result = append(result, client)
|
||||
// GetAllActiveClients returns a new client array that can be used without locks (assuming not shared)
|
||||
func (r *ClientRepository) GetAllActiveClients() (matchingClients []*Client, count int) {
|
||||
count = 0
|
||||
clients := r.getClients()
|
||||
|
||||
r.mu.RLock()
|
||||
for _, client := range clients {
|
||||
r.mu.RUnlock()
|
||||
if client.GetDisconnectedAt() == nil {
|
||||
matchingClients = append(matchingClients, client)
|
||||
count++
|
||||
}
|
||||
r.mu.RLock()
|
||||
}
|
||||
return result
|
||||
r.mu.RUnlock()
|
||||
|
||||
return matchingClients, count
|
||||
}
|
||||
|
||||
func (s *ClientRepository) getNonObsolete() ([]*Client, error) {
|
||||
result := make([]*Client, 0, len(s.clients))
|
||||
for _, client := range s.clients {
|
||||
if !client.Obsolete(s.KeepDisconnectedClients) {
|
||||
result = append(result, client)
|
||||
// getNonObsoleteClients returns a new client array that can be used without locks (assuming not shared)
|
||||
func (r *ClientRepository) getNonObsoleteClients() (matchingClients []*Client, count int) {
|
||||
count = 0
|
||||
clients := r.getClients()
|
||||
|
||||
r.mu.RLock()
|
||||
matchingClients = make([]*Client, 0, len(clients))
|
||||
for _, client := range clients {
|
||||
r.mu.RUnlock()
|
||||
if !client.Obsolete(r.GetKeepDisconnectedClients()) {
|
||||
matchingClients = append(matchingClients, client)
|
||||
count++
|
||||
}
|
||||
r.mu.RLock()
|
||||
}
|
||||
return result, nil
|
||||
r.mu.RUnlock()
|
||||
return matchingClients, count
|
||||
}
|
||||
|
||||
// getNonObsoleteByUser return connected clients the user has access to either by user group or by client group
|
||||
func (s *ClientRepository) getNonObsoleteByUser(user User, clientGroups []*cgroups.ClientGroup) ([]*Client, error) {
|
||||
// getNonObsoleteByUser return connected clients the user has access to either by user group or by client group.
|
||||
// returns a new client array that can be used without locks (assuming not shared)
|
||||
func (r *ClientRepository) getNonObsoleteClientsByUser(user User, clientGroups []*cgroups.ClientGroup) ([]*Client, error) {
|
||||
clients := r.getClients()
|
||||
userGroups := user.GetGroups()
|
||||
result := make([]*Client, 0, len(s.clients))
|
||||
for _, client := range s.clients {
|
||||
if client.Obsolete(s.KeepDisconnectedClients) {
|
||||
continue
|
||||
}
|
||||
if user.IsAdmin() || client.HasAccessViaUserGroups(userGroups) || client.UserGroupHasAccessViaClientGroup(userGroups, clientGroups) {
|
||||
result = append(result, client)
|
||||
continue
|
||||
|
||||
r.mu.RLock()
|
||||
matchingClients := make([]*Client, 0, len(clients))
|
||||
for _, client := range clients {
|
||||
r.mu.RUnlock()
|
||||
if !client.Obsolete(r.GetKeepDisconnectedClients()) {
|
||||
if user.IsAdmin() || client.HasAccessViaUserGroups(userGroups) || client.UserGroupHasAccessViaClientGroup(userGroups, clientGroups) {
|
||||
matchingClients = append(matchingClients, client)
|
||||
}
|
||||
}
|
||||
r.mu.RLock()
|
||||
}
|
||||
return result, nil
|
||||
r.mu.RUnlock()
|
||||
return matchingClients, nil
|
||||
}
|
||||
|
||||
func (r *ClientRepository) getStore() (store ClientStore) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.clientStore
|
||||
}
|
||||
|
||||
func (r *ClientRepository) log() (l *logger.Logger) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.logger
|
||||
}
|
||||
|
||||
// getClients returns the primary map of clients. accessing the return map must use locks.
|
||||
func (r *ClientRepository) getClients() (clients map[string]*Client) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.clientState
|
||||
}
|
||||
|
||||
func (r *ClientRepository) getClient(clientID string) (client *Client) {
|
||||
r.mu.RLock()
|
||||
client = r.clientState[clientID]
|
||||
r.mu.RUnlock()
|
||||
return client
|
||||
}
|
||||
|
||||
func (r *ClientRepository) updateClient(client *Client) {
|
||||
clientID := client.GetID()
|
||||
|
||||
r.mu.Lock()
|
||||
r.clientState[clientID] = client
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
func (r *ClientRepository) removeClient(clientID string) {
|
||||
r.mu.Lock()
|
||||
delete(r.clientState, clientID)
|
||||
r.mu.Unlock()
|
||||
}
|
||||
|
||||
func (r *ClientRepository) GetKeepDisconnectedClients() (keep *time.Duration) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.keepDisconnectedClients
|
||||
}
|
||||
|
||||
@ -40,29 +40,26 @@ func TestCRWithExpiration(t *testing.T) {
|
||||
assert.NoError(repo.Save(c3))
|
||||
assert.NoError(repo.Save(c4))
|
||||
|
||||
gotCount, err := repo.Count()
|
||||
assert.NoError(err)
|
||||
gotCount := repo.Count()
|
||||
assert.Equal(3, gotCount)
|
||||
|
||||
gotCountActive, err := repo.CountActive()
|
||||
assert.NoError(err)
|
||||
gotCountActive := repo.CountActive()
|
||||
assert.Equal(1, gotCountActive)
|
||||
|
||||
gotCountDisconnected, err := repo.CountDisconnected()
|
||||
assert.NoError(err)
|
||||
assert.Equal(2, gotCountDisconnected)
|
||||
|
||||
gotClients, err := repo.GetAll()
|
||||
assert.NoError(err)
|
||||
gotClients := repo.GetAllClients()
|
||||
assert.ElementsMatch([]*Client{c1, c2, c3}, gotClients)
|
||||
|
||||
// active
|
||||
gotClient, err := repo.GetActiveByID(c1.ID)
|
||||
gotClient, err := repo.GetActiveByID(c1.GetID())
|
||||
assert.NoError(err)
|
||||
assert.Equal(c1, gotClient)
|
||||
|
||||
// disconnected
|
||||
gotClient, err = repo.GetActiveByID(c2.ID)
|
||||
gotClient, err = repo.GetActiveByID(c2.GetID())
|
||||
assert.NoError(err)
|
||||
assert.Nil(gotClient)
|
||||
|
||||
@ -70,12 +67,11 @@ func TestCRWithExpiration(t *testing.T) {
|
||||
assert.NoError(err)
|
||||
require.Len(t, deleted, 1)
|
||||
assert.Equal(c4, deleted[0])
|
||||
gotClients, err = repo.GetAll()
|
||||
assert.NoError(err)
|
||||
gotClients = repo.GetAllClients()
|
||||
assert.ElementsMatch([]*Client{c1, c2, c3}, gotClients)
|
||||
|
||||
assert.NoError(repo.Delete(c3))
|
||||
gotClients, err = repo.GetAll()
|
||||
gotClients = repo.GetAllClients()
|
||||
assert.NoError(err)
|
||||
assert.ElementsMatch([]*Client{c1, c2}, gotClients)
|
||||
}
|
||||
@ -90,29 +86,26 @@ func TestCRWithNoExpiration(t *testing.T) {
|
||||
assert := assert.New(t)
|
||||
assert.NoError(repo.Save(c4Active))
|
||||
|
||||
gotCount, err := repo.Count()
|
||||
assert.NoError(err)
|
||||
gotCount := repo.Count()
|
||||
assert.Equal(4, gotCount)
|
||||
|
||||
gotCountActive, err := repo.CountActive()
|
||||
assert.NoError(err)
|
||||
gotCountActive := repo.CountActive()
|
||||
assert.Equal(2, gotCountActive)
|
||||
|
||||
gotCountDisconnected, err := repo.CountDisconnected()
|
||||
assert.NoError(err)
|
||||
assert.Equal(2, gotCountDisconnected)
|
||||
|
||||
gotClients, err := repo.GetAll()
|
||||
assert.NoError(err)
|
||||
gotClients := repo.GetAllClients()
|
||||
assert.ElementsMatch([]*Client{c1, c2, c3, c4Active}, gotClients)
|
||||
|
||||
// active
|
||||
gotClient, err := repo.GetActiveByID(c1.ID)
|
||||
gotClient, err := repo.GetActiveByID(c1.GetID())
|
||||
assert.NoError(err)
|
||||
assert.Equal(c1, gotClient)
|
||||
|
||||
// disconnected
|
||||
gotClient, err = repo.GetActiveByID(c2.ID)
|
||||
gotClient, err = repo.GetActiveByID(c2.GetID())
|
||||
assert.NoError(err)
|
||||
assert.Nil(gotClient)
|
||||
|
||||
@ -121,8 +114,7 @@ func TestCRWithNoExpiration(t *testing.T) {
|
||||
assert.Len(deleted, 0)
|
||||
|
||||
assert.NoError(repo.Delete(c4Active))
|
||||
gotClients, err = repo.GetAll()
|
||||
assert.NoError(err)
|
||||
gotClients = repo.GetAllClients()
|
||||
assert.ElementsMatch([]*Client{c1, c2, c3}, gotClients)
|
||||
}
|
||||
|
||||
@ -521,7 +513,7 @@ func TestCRWithFilter(t *testing.T) {
|
||||
actualClientIDs := make([]string, 0, len(actualClients))
|
||||
|
||||
for _, actualClient := range actualClients {
|
||||
actualClientIDs = append(actualClientIDs, actualClient.ID)
|
||||
actualClientIDs = append(actualClientIDs, actualClient.GetID())
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, tc.expectedClientIDs, actualClientIDs)
|
||||
@ -543,15 +535,15 @@ func TestCRWithUnsupportedFilter(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetUserClients(t *testing.T) {
|
||||
c1 := New(t).Build() // no groups
|
||||
c2 := New(t).AllowedUserGroups([]string{users.Administrators}).Build() // admin
|
||||
c3 := New(t).AllowedUserGroups([]string{users.Administrators, "group1"}).Build() // admin + group1
|
||||
c4 := New(t).AllowedUserGroups([]string{"group1"}).Build() // group1
|
||||
c5 := New(t).AllowedUserGroups([]string{"group1", "group2"}).Build() // group1 + group2
|
||||
c6 := New(t).AllowedUserGroups([]string{"group2"}).Build() // group2
|
||||
c7 := New(t).AllowedUserGroups([]string{"group3"}).Build() // group3
|
||||
c8 := New(t).AllowedUserGroups([]string{"group2", "group3"}).Build() // group2 + group3
|
||||
c9 := New(t).Build()
|
||||
c1 := New(t).Logger(testLog).Build() // no groups
|
||||
c2 := New(t).AllowedUserGroups([]string{users.Administrators}).Logger(testLog).Build() // admin
|
||||
c3 := New(t).AllowedUserGroups([]string{users.Administrators, "group1"}).Logger(testLog).Build() // admin + group1
|
||||
c4 := New(t).AllowedUserGroups([]string{"group1"}).Logger(testLog).Build() // group1
|
||||
c5 := New(t).AllowedUserGroups([]string{"group1", "group2"}).Logger(testLog).Build() // group1 + group2
|
||||
c6 := New(t).AllowedUserGroups([]string{"group2"}).Logger(testLog).Build() // group2
|
||||
c7 := New(t).AllowedUserGroups([]string{"group3"}).Logger(testLog).Build() // group3
|
||||
c8 := New(t).AllowedUserGroups([]string{"group2", "group3"}).Logger(testLog).Build() // group2 + group3
|
||||
c9 := New(t).Logger(testLog).Build()
|
||||
allClients := []*Client{c1, c2, c3, c4, c5, c6, c7, c8, c9}
|
||||
|
||||
clientGroups := []*cgroups.ClientGroup{
|
||||
@ -559,7 +551,7 @@ func TestGetUserClients(t *testing.T) {
|
||||
ID: "1",
|
||||
AllowedUserGroups: []string{"group6"},
|
||||
Params: &cgroups.ClientParams{
|
||||
ClientID: &cgroups.ParamValues{cgroups.Param(c9.ID)},
|
||||
ClientID: &cgroups.ParamValues{cgroups.Param(c9.GetID())},
|
||||
},
|
||||
},
|
||||
}
|
||||
@ -675,7 +667,7 @@ func TestGetClientByTag(t *testing.T) {
|
||||
}
|
||||
|
||||
for idx, cl := range matchingClients {
|
||||
assert.Equal(t, tc.expectedClientIDs[idx], cl.ID)
|
||||
assert.Equal(t, tc.expectedClientIDs[idx], cl.GetID())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -204,6 +204,8 @@ func shallowCopy(c *Client) *Client {
|
||||
Address: c.Address,
|
||||
Tunnels: append([]*clienttunnel.Tunnel{}, c.Tunnels...),
|
||||
DisconnectedAt: c.DisconnectedAt,
|
||||
LastHeartbeatAt: c.LastHeartbeatAt,
|
||||
ClientAuthID: c.ClientAuthID,
|
||||
Logger: c.Logger,
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,12 +2,11 @@ package clients
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func FormatConnectionState(disconnectedAt *time.Time) string {
|
||||
if disconnectedAt != nil {
|
||||
return fmt.Sprintf("disconnected since %s", disconnectedAt)
|
||||
func FormatConnectionState(client *Client) string {
|
||||
if !client.IsConnected() {
|
||||
return fmt.Sprintf("disconnected since %s", client.GetDisconnectedAtValue())
|
||||
}
|
||||
return "connected"
|
||||
}
|
||||
|
||||
@ -9,24 +9,25 @@ import (
|
||||
|
||||
// LoadInitialClients returns an initial Client Repository state populated with clients from the internal storage.
|
||||
func LoadInitialClients(ctx context.Context, p ClientStore, logger *logger.Logger) ([]*Client, error) {
|
||||
if logger != nil {
|
||||
logger.Debugf("loading existing clients")
|
||||
}
|
||||
logger.Debugf("loading existing clients")
|
||||
all, err := p.GetAll(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get clients: %v", err)
|
||||
}
|
||||
|
||||
if logger != nil {
|
||||
logger.Debugf("loaded %d clients", len(all))
|
||||
}
|
||||
logger.Debugf("loaded %d clients", len(all))
|
||||
|
||||
// mark previously connected clients as disconnected with current time
|
||||
now := now()
|
||||
for _, cur := range all {
|
||||
if cur.DisconnectedAt == nil {
|
||||
cur.SetDisconnected(&now)
|
||||
err := p.Save(ctx, cur)
|
||||
|
||||
// setup a logger for the clients
|
||||
clientLogger := logger.Fork("client")
|
||||
|
||||
for _, client := range all {
|
||||
client.Logger = clientLogger
|
||||
if client.IsConnected() {
|
||||
client.SetDisconnectedAt(&now)
|
||||
err := p.Save(ctx, client)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save client: %v", err)
|
||||
}
|
||||
|
||||
@ -10,11 +10,12 @@ import (
|
||||
|
||||
func TestGetInitState(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
c1 := New(t).Build()
|
||||
c1 := New(t).ID("client-1").Logger(testLog).Build()
|
||||
wantC1 := shallowCopy(c1)
|
||||
wantC1.DisconnectedAt = &nowMock
|
||||
c2 := New(t).DisconnectedDuration(5 * time.Minute).Build()
|
||||
c3 := New(t).DisconnectedDuration(2 * time.Hour).Build()
|
||||
|
||||
wantC1.SetDisconnectedAt(&nowMock)
|
||||
c2 := New(t).ID("client-2").DisconnectedDuration(5 * time.Minute).Logger(testLog).Build()
|
||||
c3 := New(t).ID("client-3").DisconnectedDuration(2 * time.Hour).Logger(testLog).Build()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@ -44,16 +45,18 @@ func TestGetInitState(t *testing.T) {
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// given
|
||||
p := NewFakeClientProvider(t, &tc.expiration, tc.dbClients...)
|
||||
defer p.Close()
|
||||
|
||||
// when
|
||||
gotClients, gotErr := LoadInitialClients(ctx, p, nil)
|
||||
|
||||
// then
|
||||
gotClients, gotErr := LoadInitialClients(ctx, p, testLog)
|
||||
assert.NoError(t, gotErr)
|
||||
assert.Len(t, gotClients, len(tc.wantRes))
|
||||
|
||||
// patch the client logger for the ElementsMatch check
|
||||
for _, c := range gotClients {
|
||||
c.Logger = testLog
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, gotClients, tc.wantRes)
|
||||
})
|
||||
}
|
||||
|
||||
139
server/clients/payloads.go
Normal file
139
server/clients/payloads.go
Normal file
@ -0,0 +1,139 @@
|
||||
package clients
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/cloudradar-monitoring/rport/server/clients/clienttunnel"
|
||||
"github.com/cloudradar-monitoring/rport/share/clientconfig"
|
||||
"github.com/cloudradar-monitoring/rport/share/models"
|
||||
"github.com/cloudradar-monitoring/rport/share/query"
|
||||
)
|
||||
|
||||
type ClientPayload struct {
|
||||
ID *string `json:"id,omitempty"`
|
||||
Name *string `json:"name,omitempty"`
|
||||
Address *string `json:"address,omitempty"`
|
||||
Hostname *string `json:"hostname,omitempty"`
|
||||
OS *string `json:"os,omitempty"`
|
||||
OSFullName *string `json:"os_full_name,omitempty"`
|
||||
OSVersion *string `json:"os_version,omitempty"`
|
||||
OSArch *string `json:"os_arch,omitempty"`
|
||||
OSFamily *string `json:"os_family,omitempty"`
|
||||
OSKernel *string `json:"os_kernel,omitempty"`
|
||||
OSVirtualizationSystem *string `json:"os_virtualization_system,omitempty"`
|
||||
OSVirtualizationRole *string `json:"os_virtualization_role,omitempty"`
|
||||
NumCPUs *int `json:"num_cpus,omitempty"`
|
||||
CPUFamily *string `json:"cpu_family,omitempty"`
|
||||
CPUModel *string `json:"cpu_model,omitempty"`
|
||||
CPUModelName *string `json:"cpu_model_name,omitempty"`
|
||||
CPUVendor *string `json:"cpu_vendor,omitempty"`
|
||||
MemoryTotal *uint64 `json:"mem_total,omitempty"`
|
||||
Timezone *string `json:"timezone,omitempty"`
|
||||
ClientAuthID *string `json:"client_auth_id,omitempty"`
|
||||
Version *string `json:"version,omitempty"`
|
||||
DisconnectedAt **time.Time `json:"disconnected_at,omitempty"`
|
||||
LastHeartbeatAt **time.Time `json:"last_heartbeat_at,omitempty"`
|
||||
ConnectionState *string `json:"connection_state,omitempty"`
|
||||
IPv4 *[]string `json:"ipv4,omitempty"`
|
||||
IPv6 *[]string `json:"ipv6,omitempty"`
|
||||
Tags *[]string `json:"tags,omitempty"`
|
||||
AllowedUserGroups *[]string `json:"allowed_user_groups,omitempty"`
|
||||
Tunnels *[]*clienttunnel.Tunnel `json:"tunnels,omitempty"`
|
||||
UpdatesStatus **models.UpdatesStatus `json:"updates_status,omitempty"`
|
||||
ClientConfiguration **clientconfig.Config `json:"client_configuration,omitempty"`
|
||||
Groups *[]string `json:"groups,omitempty"`
|
||||
}
|
||||
|
||||
func ConvertToClientsPayload(clientsList []*CalculatedClient, fields []query.FieldsOption) []ClientPayload {
|
||||
r := make([]ClientPayload, 0, len(clientsList))
|
||||
for _, cur := range clientsList {
|
||||
r = append(r, ConvertToClientPayload(cur, fields))
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func ConvertToClientPayload(client *CalculatedClient, fields []query.FieldsOption) ClientPayload { //nolint:gocyclo
|
||||
requestedFields := query.RequestedFields(fields, "clients")
|
||||
p := ClientPayload{}
|
||||
for field := range OptionsSupportedFields["clients"] {
|
||||
if len(fields) > 0 && !requestedFields[field] {
|
||||
continue
|
||||
}
|
||||
|
||||
client.flock.RLock()
|
||||
defer client.flock.RUnlock()
|
||||
|
||||
switch field {
|
||||
case "id":
|
||||
id := client.ID
|
||||
p.ID = &id
|
||||
case "name":
|
||||
name := client.Name
|
||||
p.Name = &name
|
||||
case "os":
|
||||
p.OS = &client.OS
|
||||
case "os_arch":
|
||||
p.OSArch = &client.OSArch
|
||||
case "os_family":
|
||||
p.OSFamily = &client.OSFamily
|
||||
case "os_kernel":
|
||||
p.OSKernel = &client.OSKernel
|
||||
case "hostname":
|
||||
p.Hostname = &client.Hostname
|
||||
case "ipv4":
|
||||
p.IPv4 = &client.IPv4
|
||||
case "ipv6":
|
||||
p.IPv6 = &client.IPv6
|
||||
case "tags":
|
||||
p.Tags = &client.Tags
|
||||
case "version":
|
||||
p.Version = &client.Version
|
||||
case "address":
|
||||
p.Address = &client.Address
|
||||
case "tunnels":
|
||||
p.Tunnels = &client.Tunnels
|
||||
case "disconnected_at":
|
||||
disconnectedAt := client.DisconnectedAt
|
||||
p.DisconnectedAt = &disconnectedAt
|
||||
case "last_heartbeat_at":
|
||||
lastHeartbeatAt := client.LastHeartbeatAt
|
||||
p.LastHeartbeatAt = &lastHeartbeatAt
|
||||
case "client_auth_id":
|
||||
p.ClientAuthID = &client.ClientAuthID
|
||||
case "os_full_name":
|
||||
p.OSFullName = &client.OSFullName
|
||||
case "os_version":
|
||||
p.OSVersion = &client.OSVersion
|
||||
case "os_virtualization_system":
|
||||
p.OSVirtualizationSystem = &client.OSVirtualizationSystem
|
||||
case "os_virtualization_role":
|
||||
p.OSVirtualizationRole = &client.OSVirtualizationRole
|
||||
case "cpu_family":
|
||||
p.CPUFamily = &client.CPUFamily
|
||||
case "cpu_model":
|
||||
p.CPUModel = &client.CPUModel
|
||||
case "cpu_model_name":
|
||||
p.CPUModelName = &client.CPUModelName
|
||||
case "cpu_vendor":
|
||||
p.CPUVendor = &client.CPUVendor
|
||||
case "timezone":
|
||||
p.Timezone = &client.Timezone
|
||||
case "num_cpus":
|
||||
p.NumCPUs = &client.NumCPUs
|
||||
case "mem_total":
|
||||
p.MemoryTotal = &client.MemoryTotal
|
||||
case "allowed_user_groups":
|
||||
p.AllowedUserGroups = &client.AllowedUserGroups
|
||||
case "updates_status":
|
||||
p.UpdatesStatus = &client.UpdatesStatus
|
||||
case "client_configuration":
|
||||
p.ClientConfiguration = &client.ClientConfiguration
|
||||
case "groups":
|
||||
p.Groups = &client.Groups
|
||||
case "connection_state":
|
||||
connectionState := string(client.GetConnectionState())
|
||||
p.ConnectionState = &connectionState
|
||||
}
|
||||
}
|
||||
return p
|
||||
}
|
||||
@ -7,7 +7,7 @@ import (
|
||||
|
||||
func SortByID(a []*CalculatedClient, desc bool) {
|
||||
sort.Slice(a, func(i, j int) bool {
|
||||
less := strings.ToLower(a[i].ID) < strings.ToLower(a[j].ID)
|
||||
less := strings.ToLower(a[i].GetID()) < strings.ToLower(a[j].GetID())
|
||||
if desc {
|
||||
return !less
|
||||
}
|
||||
@ -17,9 +17,9 @@ func SortByID(a []*CalculatedClient, desc bool) {
|
||||
|
||||
func SortByName(a []*CalculatedClient, desc bool) {
|
||||
sort.Slice(a, func(i, j int) bool {
|
||||
aiName := strings.ToLower(a[i].Name)
|
||||
ajName := strings.ToLower(a[j].Name)
|
||||
less := aiName < ajName || aiName == ajName && strings.ToLower(a[i].ID) < strings.ToLower(a[j].ID)
|
||||
aiName := strings.ToLower(a[i].GetName())
|
||||
ajName := strings.ToLower(a[j].GetName())
|
||||
less := aiName < ajName || aiName == ajName && strings.ToLower(a[i].GetID()) < strings.ToLower(a[j].GetID())
|
||||
if desc {
|
||||
return !less
|
||||
}
|
||||
@ -29,9 +29,9 @@ func SortByName(a []*CalculatedClient, desc bool) {
|
||||
|
||||
func SortByOS(a []*CalculatedClient, desc bool) {
|
||||
sort.Slice(a, func(i, j int) bool {
|
||||
aiOS := strings.ToLower(a[i].OS)
|
||||
ajOS := strings.ToLower(a[j].OS)
|
||||
less := aiOS < ajOS || aiOS == ajOS && strings.ToLower(a[i].ID) < strings.ToLower(a[j].ID)
|
||||
aiOS := strings.ToLower(a[i].GetOS())
|
||||
ajOS := strings.ToLower(a[j].GetOS())
|
||||
less := aiOS < ajOS || aiOS == ajOS && strings.ToLower(a[i].GetID()) < strings.ToLower(a[j].GetID())
|
||||
if desc {
|
||||
return !less
|
||||
}
|
||||
@ -41,9 +41,9 @@ func SortByOS(a []*CalculatedClient, desc bool) {
|
||||
|
||||
func SortByHostname(a []*CalculatedClient, desc bool) {
|
||||
sort.Slice(a, func(i, j int) bool {
|
||||
aiHostname := strings.ToLower(a[i].Hostname)
|
||||
ajHostname := strings.ToLower(a[j].Hostname)
|
||||
less := aiHostname < ajHostname || aiHostname == ajHostname && strings.ToLower(a[i].ID) < strings.ToLower(a[j].ID)
|
||||
aiHostname := strings.ToLower(a[i].GetHostname())
|
||||
ajHostname := strings.ToLower(a[j].GetHostname())
|
||||
less := aiHostname < ajHostname || aiHostname == ajHostname && strings.ToLower(a[i].GetID()) < strings.ToLower(a[j].GetID())
|
||||
if desc {
|
||||
return !less
|
||||
}
|
||||
@ -53,9 +53,9 @@ func SortByHostname(a []*CalculatedClient, desc bool) {
|
||||
|
||||
func SortByVersion(a []*CalculatedClient, desc bool) {
|
||||
sort.Slice(a, func(i, j int) bool {
|
||||
aiVersion := strings.ToLower(a[i].Version)
|
||||
ajVersion := strings.ToLower(a[j].Version)
|
||||
less := aiVersion < ajVersion || aiVersion == ajVersion && strings.ToLower(a[i].ID) < strings.ToLower(a[j].ID)
|
||||
aiVersion := strings.ToLower(a[i].GetVersion())
|
||||
ajVersion := strings.ToLower(a[j].GetVersion())
|
||||
less := aiVersion < ajVersion || aiVersion == ajVersion && strings.ToLower(a[i].GetID()) < strings.ToLower(a[j].GetID())
|
||||
if desc {
|
||||
return !less
|
||||
}
|
||||
|
||||
@ -11,16 +11,18 @@ import (
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
|
||||
"github.com/cloudradar-monitoring/rport/db/sqlite"
|
||||
"github.com/cloudradar-monitoring/rport/server/clients/clienttunnel"
|
||||
chshare "github.com/cloudradar-monitoring/rport/share/clientconfig"
|
||||
"github.com/cloudradar-monitoring/rport/share/logger"
|
||||
"github.com/cloudradar-monitoring/rport/share/models"
|
||||
)
|
||||
|
||||
type ClientStore interface {
|
||||
GetAll(ctx context.Context) ([]*Client, error)
|
||||
Save(ctx context.Context, client *Client) error
|
||||
DeleteObsolete(ctx context.Context) error
|
||||
Delete(ctx context.Context, id string) error
|
||||
DeleteObsolete(ctx context.Context, l *logger.Logger) error
|
||||
Delete(ctx context.Context, id string, l *logger.Logger) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
@ -48,6 +50,7 @@ func (p *SqliteProvider) GetAll(ctx context.Context) ([]*Client, error) {
|
||||
return convertClientList(res), nil
|
||||
}
|
||||
|
||||
// test only
|
||||
func (p *SqliteProvider) get(ctx context.Context, id string) (*Client, error) {
|
||||
res := &clientSqlite{}
|
||||
err := p.db.GetContext(ctx, res, "SELECT * FROM clients WHERE id = ?", id)
|
||||
@ -61,29 +64,54 @@ func (p *SqliteProvider) get(ctx context.Context, id string) (*Client, error) {
|
||||
}
|
||||
|
||||
func (p *SqliteProvider) Save(ctx context.Context, client *Client) error {
|
||||
_, err := p.db.NamedExecContext(
|
||||
ctx,
|
||||
"INSERT OR REPLACE INTO clients (id, client_auth_id, disconnected_at, details) VALUES (:id, :client_auth_id, :disconnected_at, :details)",
|
||||
convertToSqlite(client),
|
||||
)
|
||||
|
||||
_, err := sqlite.WithRetryWhenBusy(func() (result sql.Result, err error) {
|
||||
|
||||
clientForSQL := convertToSqlite(client)
|
||||
|
||||
_, err = p.db.NamedExecContext(
|
||||
ctx,
|
||||
"INSERT OR REPLACE INTO clients (id, client_auth_id, disconnected_at, details) VALUES (:id, :client_auth_id, :disconnected_at, :details)",
|
||||
clientForSQL,
|
||||
)
|
||||
|
||||
return nil, err
|
||||
}, "save", client.Log())
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *SqliteProvider) DeleteObsolete(ctx context.Context) error {
|
||||
_, err := p.db.ExecContext(
|
||||
ctx,
|
||||
"DELETE FROM clients WHERE disconnected_at IS NOT NULL AND DATETIME(disconnected_at) < DATETIME(?) AND ?",
|
||||
p.keepDisconnectedClientsStart(),
|
||||
p.keepDisconnectedClients != nil,
|
||||
)
|
||||
func (p *SqliteProvider) DeleteObsolete(ctx context.Context, l *logger.Logger) error {
|
||||
_, err := sqlite.WithRetryWhenBusy(func() (result sql.Result, err error) {
|
||||
|
||||
_, err = p.db.ExecContext(
|
||||
ctx,
|
||||
"DELETE FROM clients WHERE disconnected_at IS NOT NULL AND DATETIME(disconnected_at) < DATETIME(?) AND ?",
|
||||
p.keepDisconnectedClientsStart(),
|
||||
p.keepDisconnectedClients != nil,
|
||||
)
|
||||
|
||||
return nil, err
|
||||
}, "delete obsolete", l)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *SqliteProvider) Delete(ctx context.Context, id string) error {
|
||||
_, err := p.db.ExecContext(ctx, "DELETE FROM clients WHERE id = ?", id)
|
||||
func (p *SqliteProvider) Delete(ctx context.Context, id string, l *logger.Logger) error {
|
||||
_, err := sqlite.WithRetryWhenBusy(func() (result sql.Result, err error) {
|
||||
|
||||
_, err = p.db.ExecContext(ctx, "DELETE FROM clients WHERE id = ?", id)
|
||||
|
||||
return nil, err
|
||||
}, "delete", l)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *SqliteProvider) Close() error {
|
||||
return p.db.Close()
|
||||
}
|
||||
|
||||
func (p *SqliteProvider) keepDisconnectedClientsStart() time.Time {
|
||||
t := now()
|
||||
if p.keepDisconnectedClients != nil {
|
||||
@ -92,45 +120,49 @@ func (p *SqliteProvider) keepDisconnectedClientsStart() time.Time {
|
||||
return t
|
||||
}
|
||||
|
||||
func convertToSqlite(v *Client) *clientSqlite {
|
||||
if v == nil {
|
||||
func convertToSqlite(c *Client) (res *clientSqlite) {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
res := &clientSqlite{
|
||||
ID: v.ID,
|
||||
ClientAuthID: v.ClientAuthID,
|
||||
|
||||
c.flock.RLock()
|
||||
res = &clientSqlite{
|
||||
ID: c.ID,
|
||||
ClientAuthID: c.ClientAuthID,
|
||||
Details: &clientDetails{
|
||||
Name: v.Name,
|
||||
OS: v.OS,
|
||||
OSArch: v.OSArch,
|
||||
OSFamily: v.OSFamily,
|
||||
OSKernel: v.OSKernel,
|
||||
Hostname: v.Hostname,
|
||||
Version: v.Version,
|
||||
Address: v.Address,
|
||||
OSFullName: v.OSFullName,
|
||||
OSVersion: v.OSVersion,
|
||||
OSVirtualizationSystem: v.OSVirtualizationSystem,
|
||||
OSVirtualizationRole: v.OSVirtualizationRole,
|
||||
CPUFamily: v.CPUFamily,
|
||||
CPUModel: v.CPUModel,
|
||||
CPUModelName: v.CPUModelName,
|
||||
CPUVendor: v.CPUVendor,
|
||||
NumCPUs: v.NumCPUs,
|
||||
MemoryTotal: v.MemoryTotal,
|
||||
Timezone: v.Timezone,
|
||||
IPv4: v.IPv4,
|
||||
IPv6: v.IPv6,
|
||||
Tags: v.Tags,
|
||||
Tunnels: v.Tunnels,
|
||||
AllowedUserGroups: v.AllowedUserGroups,
|
||||
UpdatesStatus: v.UpdatesStatus,
|
||||
ClientConfig: v.ClientConfiguration,
|
||||
Name: c.Name,
|
||||
OS: c.OS,
|
||||
OSArch: c.OSArch,
|
||||
OSFamily: c.OSFamily,
|
||||
OSKernel: c.OSKernel,
|
||||
Hostname: c.Hostname,
|
||||
Version: c.Version,
|
||||
Address: c.Address,
|
||||
OSFullName: c.OSFullName,
|
||||
OSVersion: c.OSVersion,
|
||||
OSVirtualizationSystem: c.OSVirtualizationSystem,
|
||||
OSVirtualizationRole: c.OSVirtualizationRole,
|
||||
CPUFamily: c.CPUFamily,
|
||||
CPUModel: c.CPUModel,
|
||||
CPUModelName: c.CPUModelName,
|
||||
CPUVendor: c.CPUVendor,
|
||||
NumCPUs: c.NumCPUs,
|
||||
MemoryTotal: c.MemoryTotal,
|
||||
Timezone: c.Timezone,
|
||||
IPv4: c.IPv4,
|
||||
IPv6: c.IPv6,
|
||||
Tags: c.Tags,
|
||||
Tunnels: c.Tunnels,
|
||||
AllowedUserGroups: c.AllowedUserGroups,
|
||||
UpdatesStatus: c.UpdatesStatus,
|
||||
ClientConfig: c.ClientConfiguration,
|
||||
},
|
||||
}
|
||||
if v.DisconnectedAt != nil {
|
||||
res.DisconnectedAt = sql.NullTime{Time: *v.DisconnectedAt, Valid: true}
|
||||
c.flock.RUnlock()
|
||||
if !c.IsConnected() {
|
||||
res.DisconnectedAt = sql.NullTime{Time: c.GetDisconnectedAtValue(), Valid: true}
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
@ -196,9 +228,9 @@ func (d *clientDetails) Value() (driver.Value, error) {
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
func (s *clientSqlite) convert() *Client {
|
||||
func (s *clientSqlite) convert() (res *Client) {
|
||||
d := s.Details
|
||||
res := &Client{
|
||||
res = &Client{
|
||||
ID: s.ID,
|
||||
ClientAuthID: s.ClientAuthID,
|
||||
Name: d.Name,
|
||||
@ -229,15 +261,11 @@ func (s *clientSqlite) convert() *Client {
|
||||
ClientConfiguration: d.ClientConfig,
|
||||
}
|
||||
if s.DisconnectedAt.Valid {
|
||||
res.DisconnectedAt = &s.DisconnectedAt.Time
|
||||
res.SetDisconnectedAt(&s.DisconnectedAt.Time)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (p *SqliteProvider) Close() error {
|
||||
return p.db.Close()
|
||||
}
|
||||
|
||||
func convertClientList(list []*clientSqlite) []*Client {
|
||||
res := make([]*Client, 0, len(list))
|
||||
for _, cur := range list {
|
||||
|
||||
@ -39,12 +39,12 @@ func TestClientsSqliteProvider(t *testing.T) {
|
||||
assert.ElementsMatch(t, []*Client{c1, c2, c3, c4, c5}, gotAll)
|
||||
|
||||
// verify delete obsolete clients
|
||||
gotObsolete, err := p.get(ctx, c5.ID)
|
||||
gotObsolete, err := p.get(ctx, c5.GetID())
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, c5, gotObsolete)
|
||||
|
||||
require.NoError(t, p.DeleteObsolete(ctx))
|
||||
gotObsolete, err = p.get(ctx, c5.ID)
|
||||
require.NoError(t, p.DeleteObsolete(ctx, testLog))
|
||||
gotObsolete, err = p.get(ctx, c5.GetID())
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, gotObsolete)
|
||||
|
||||
@ -61,7 +61,7 @@ func TestClientsSqliteProvider(t *testing.T) {
|
||||
d := time.Date(2020, 11, 5, 12, 11, 20, 0, time.UTC)
|
||||
c1.DisconnectedAt = &d
|
||||
require.NoError(t, p.Save(ctx, c1))
|
||||
gotUpdated, err := p.get(ctx, c1.ID)
|
||||
gotUpdated, err := p.get(ctx, c1.GetID())
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, c1, gotUpdated)
|
||||
gotAll, err = p.GetAll(ctx)
|
||||
|
||||
@ -92,17 +92,19 @@ func (c *FileProvider) Get(id string) (*ClientAuth, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *FileProvider) Add(client *ClientAuth) (bool, error) {
|
||||
func (c *FileProvider) Add(clientAuth *ClientAuth) (bool, error) {
|
||||
idPswdPairs, err := c.load()
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to decode rport clients auth file: %v", err)
|
||||
}
|
||||
|
||||
if _, ok := idPswdPairs[client.ID]; ok {
|
||||
clientID := clientAuth.ID
|
||||
|
||||
if _, ok := idPswdPairs[clientID]; ok {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
idPswdPairs[client.ID] = client.Password
|
||||
idPswdPairs[clientID] = clientAuth.Password
|
||||
|
||||
if err := c.save(idPswdPairs); err != nil {
|
||||
return false, fmt.Errorf("failed to encode rport clients auth file: %v", err)
|
||||
|
||||
@ -2,6 +2,7 @@ package ports
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
mapset "github.com/deckarep/golang-set"
|
||||
"github.com/shirou/gopsutil/v3/net"
|
||||
@ -13,6 +14,7 @@ type PortDistributor struct {
|
||||
allowedPorts mapset.Set
|
||||
|
||||
portsPools map[string]mapset.Set
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewPortDistributor(allowedPorts mapset.Set) *PortDistributor {
|
||||
@ -22,6 +24,19 @@ func NewPortDistributor(allowedPorts mapset.Set) *PortDistributor {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *PortDistributor) getPoolFromMap(protocol string) (pool mapset.Set) {
|
||||
d.mu.RLock()
|
||||
pool = d.portsPools[protocol]
|
||||
d.mu.RUnlock()
|
||||
return pool
|
||||
}
|
||||
|
||||
func (d *PortDistributor) setPool(protocol string, pool mapset.Set) {
|
||||
d.mu.Lock()
|
||||
d.portsPools[protocol] = pool
|
||||
d.mu.Unlock()
|
||||
}
|
||||
|
||||
// NewPortDistributorForTests is used only for unit-testing.
|
||||
func NewPortDistributorForTests(allowedPorts, tcpPortsPool, udpPortsPool mapset.Set) *PortDistributor {
|
||||
return &PortDistributor{
|
||||
@ -39,7 +54,8 @@ func (d *PortDistributor) GetRandomPort(protocol string) (int, error) {
|
||||
subProtocols = []string{models.ProtocolTCP, models.ProtocolUDP}
|
||||
}
|
||||
for _, p := range subProtocols {
|
||||
if d.portsPools[p] == nil {
|
||||
pool := d.getPoolFromMap(p)
|
||||
if pool == nil {
|
||||
err := d.refresh(p)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@ -54,7 +70,8 @@ func (d *PortDistributor) GetRandomPort(protocol string) (int, error) {
|
||||
|
||||
// Make sure port is removed from all pools for tcp+udp protocol
|
||||
for _, p := range subProtocols {
|
||||
d.portsPools[p].Remove(port)
|
||||
pool := d.getPoolFromMap(p)
|
||||
pool.Remove(port)
|
||||
}
|
||||
|
||||
return port.(int), nil
|
||||
@ -69,7 +86,7 @@ func (d *PortDistributor) IsPortBusy(protocol string, port int) bool {
|
||||
}
|
||||
|
||||
func (d *PortDistributor) getPool(protocol string) mapset.Set {
|
||||
pool := d.portsPools[protocol]
|
||||
pool := d.getPoolFromMap(protocol)
|
||||
if protocol == models.ProtocolTCPUDP {
|
||||
pool = d.portsPools[models.ProtocolTCP].Intersect(d.portsPools[models.ProtocolUDP])
|
||||
}
|
||||
@ -94,12 +111,14 @@ func (d *PortDistributor) refresh(protocol string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
d.portsPools[protocol] = d.allowedPorts.Difference(busyPorts)
|
||||
pool := d.allowedPorts.Difference(busyPorts)
|
||||
d.setPool(protocol, pool)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ListBusyPorts(protocol string) (mapset.Set, error) {
|
||||
result := mapset.NewThreadUnsafeSet()
|
||||
result := mapset.NewSet()
|
||||
connections, err := net.Connections(protocol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@ -15,9 +15,9 @@ func TestPortDistributor(t *testing.T) {
|
||||
for _, protocol := range []string{models.ProtocolTCP, models.ProtocolUDP, models.ProtocolTCPUDP} {
|
||||
t.Run(protocol, func(t *testing.T) {
|
||||
pd := NewPortDistributorForTests(
|
||||
mapset.NewThreadUnsafeSetFromSlice([]interface{}{1, 2, 3, 4, 5}),
|
||||
mapset.NewThreadUnsafeSetFromSlice([]interface{}{2, 3, 4, 5}),
|
||||
mapset.NewThreadUnsafeSetFromSlice([]interface{}{2, 3, 4, 5}),
|
||||
mapset.NewSetFromSlice([]interface{}{1, 2, 3, 4, 5}),
|
||||
mapset.NewSetFromSlice([]interface{}{2, 3, 4, 5}),
|
||||
mapset.NewSetFromSlice([]interface{}{2, 3, 4, 5}),
|
||||
)
|
||||
|
||||
assert.Equal(t, true, pd.IsPortBusy(protocol, 1))
|
||||
|
||||
@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func TryParsePortRanges(portRanges []string) (mapset.Set, error) {
|
||||
result := mapset.NewThreadUnsafeSet()
|
||||
result := mapset.NewSet()
|
||||
for _, rangeStr := range portRanges {
|
||||
rangeParts := strings.Split(rangeStr, "-")
|
||||
if len(rangeParts) == 1 {
|
||||
@ -59,7 +59,7 @@ func tryParsePortNumberRange(rangeStart, rangeEnd string) (mapset.Set, error) {
|
||||
}
|
||||
|
||||
func setFromRange(start, end int) mapset.Set {
|
||||
s := mapset.NewThreadUnsafeSet()
|
||||
s := mapset.NewSet()
|
||||
for i := 0; i <= end-start; i++ {
|
||||
s.Add(start + i)
|
||||
}
|
||||
|
||||
@ -13,16 +13,20 @@ type Task interface {
|
||||
|
||||
// Run runs the given task periodically with a given interval between executions.
|
||||
func Run(ctx context.Context, log *logger.Logger, task Task, interval time.Duration) {
|
||||
log.Debugf("started")
|
||||
tick := time.NewTicker(interval)
|
||||
defer tick.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-tick.C:
|
||||
log.Debugf("running")
|
||||
if err := task.Run(ctx); err != nil {
|
||||
log.Errorf("Task %T finished with an error: %v.", task, err)
|
||||
log.Errorf("finished with an error: %v.", err)
|
||||
}
|
||||
log.Debugf("finished")
|
||||
case <-ctx.Done():
|
||||
log.Debugf("%T: context canceled", task)
|
||||
tick.Stop()
|
||||
log.Debugf("context canceled", task)
|
||||
log.Debugf("stopped")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@ -49,6 +49,8 @@ const (
|
||||
cleanupAPISessionsInterval = time.Hour
|
||||
cleanupJobsInterval = time.Hour
|
||||
LogNumGoRoutinesInterval = time.Minute * 2
|
||||
|
||||
DefaultMaxClientDBConnections = 50
|
||||
)
|
||||
|
||||
// Server represents a rport service
|
||||
@ -168,11 +170,18 @@ func NewServer(ctx context.Context, config *chconfig.Config, opts *ServerOpts) (
|
||||
// even if monitoring disabled, always create the monitoring service to support queries of past data etc
|
||||
s.monitoringService = monitoring.NewService(monitoringProvider)
|
||||
|
||||
sourceOptions := config.Server.GetSQLiteDataSourceOptions()
|
||||
|
||||
// particularly the client.db needs performant db access, so allow multi-threaded access
|
||||
// and use the RetryWhenBusy fn to ensure writes succeed if we get a busy error due to
|
||||
// concurrent thread access.
|
||||
sourceOptions.MaxOpenConnections = DefaultMaxClientDBConnections
|
||||
|
||||
s.clientDB, err = sqlite.New(
|
||||
path.Join(config.Server.DataDir, "clients.db"),
|
||||
clientsmigration.AssetNames(),
|
||||
clientsmigration.Asset,
|
||||
config.Server.GetSQLiteDataSourceOptions(),
|
||||
sourceOptions,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create clients DB instance: %v", err)
|
||||
@ -265,8 +274,8 @@ func NewServer(ctx context.Context, config *chconfig.Config, opts *ServerOpts) (
|
||||
func (s *Server) HandlePlusLicenseInfoAvailable() {
|
||||
s.Logger.Debugf("received license info from rport-plus")
|
||||
|
||||
if s.clientListener != nil && s.clientListener.clientService != nil {
|
||||
s.clientListener.clientService.UpdateClientStatus()
|
||||
if s.clientListener != nil && s.clientListener.server.clientService != nil {
|
||||
s.clientListener.server.clientService.UpdateClientStatus()
|
||||
}
|
||||
}
|
||||
|
||||
@ -309,19 +318,20 @@ func (s *Server) Run(ctx context.Context) error {
|
||||
// TODO(m-terel): add graceful shutdown of background task
|
||||
if s.config.Server.PurgeDisconnectedClients {
|
||||
s.Infof("Period to keep disconnected clients is set to %v", s.config.Server.KeepDisconnectedClients)
|
||||
go scheduler.Run(ctx, s.Logger, clients.NewCleanupTask(s.Logger, s.clientListener.clientService.GetRepo()), s.config.Server.PurgeDisconnectedClientsInterval)
|
||||
go scheduler.Run(ctx, s.Logger, clients.NewCleanupTask(s.Logger, s.clientListener.server.clientService.GetRepo()), s.config.Server.PurgeDisconnectedClientsInterval)
|
||||
s.Infof("Task to purge disconnected clients will run with interval %v", s.config.Server.PurgeDisconnectedClientsInterval)
|
||||
} else {
|
||||
s.Debugf("Task to purge disconnected clients disabled")
|
||||
}
|
||||
|
||||
//Run a task to Check the client connections status by sending and receiving pings
|
||||
go scheduler.Run(ctx, s.Logger, NewClientsStatusCheckTask(
|
||||
clientsStatusCheckTask := NewClientsStatusCheckTask(
|
||||
s.Logger,
|
||||
s.clientListener.clientService.GetRepo(),
|
||||
s.clientListener.server.clientService.GetRepo(),
|
||||
s.config.Server.CheckClientsConnectionInterval,
|
||||
s.config.Server.CheckClientsConnectionTimeout,
|
||||
), s.config.Server.CheckClientsConnectionInterval)
|
||||
)
|
||||
go scheduler.Run(ctx, s.Logger.Fork(fmt.Sprintf("task %T", clientsStatusCheckTask)), clientsStatusCheckTask, s.config.Server.CheckClientsConnectionInterval)
|
||||
s.Infof("Task to check the clients connection status will run with interval %v", s.config.Server.CheckClientsConnectionInterval)
|
||||
|
||||
if s.config.Monitoring.Enabled {
|
||||
@ -334,16 +344,19 @@ func (s *Server) Run(ctx context.Context) error {
|
||||
cleaningPeriod = s.config.Monitoring.GetDataStorageDuration()
|
||||
}
|
||||
|
||||
go scheduler.Run(ctx, s.Logger, monitoring.NewCleanupTask(s.Logger, s.monitoringService, cleaningPeriod), cleanupMeasurementsInterval)
|
||||
monitoringCleanupTask := monitoring.NewCleanupTask(s.Logger, s.monitoringService, cleaningPeriod)
|
||||
go scheduler.Run(ctx, s.Logger.Fork(fmt.Sprintf("task %T", monitoringCleanupTask)), monitoringCleanupTask, cleanupMeasurementsInterval)
|
||||
s.Infof("Task to cleanup measurements will run with interval %v", cleanupMeasurementsInterval)
|
||||
} else {
|
||||
s.Infof("Measurement disabled")
|
||||
}
|
||||
|
||||
go scheduler.Run(ctx, s.Logger, session.NewCleanupTask(s.apiListener.apiSessions), cleanupAPISessionsInterval)
|
||||
sessionsCleanupTask := session.NewCleanupTask(s.apiListener.apiSessions)
|
||||
go scheduler.Run(ctx, s.Logger.Fork(fmt.Sprintf("task %T", sessionsCleanupTask)), sessionsCleanupTask, cleanupAPISessionsInterval)
|
||||
s.Infof("Task to cleanup expired api sessions will run with interval %v", cleanupAPISessionsInterval)
|
||||
|
||||
go scheduler.Run(ctx, s.Logger, jobs.NewCleanupTask(s.jobProvider, s.config.Server.JobsMaxResults), cleanupJobsInterval)
|
||||
jobsCleanupTask := jobs.NewCleanupTask(s.jobProvider, s.config.Server.JobsMaxResults)
|
||||
go scheduler.Run(ctx, s.Logger.Fork(fmt.Sprintf("task %T", jobsCleanupTask)), jobsCleanupTask, cleanupJobsInterval)
|
||||
s.Infof("Task to cleanup jobs will run with interval %v", cleanupJobsInterval)
|
||||
|
||||
// Only on debug mode, log the number of running go routines
|
||||
|
||||
@ -217,8 +217,9 @@ func (al *APIListener) sendFileToClients(uploadRequest *UploadRequest) {
|
||||
|
||||
func (al *APIListener) consumeUploadResults(resChan chan *uploadResult, uploadRequest *UploadRequest) {
|
||||
for res := range resChan {
|
||||
clientID := res.client.GetID()
|
||||
output := &UploadOutput{
|
||||
ClientID: res.client.ID,
|
||||
ClientID: clientID,
|
||||
UploadResponse: res.resp,
|
||||
}
|
||||
if res.err != nil {
|
||||
@ -238,7 +239,7 @@ func (al *APIListener) consumeUploadResults(resChan chan *uploadResult, uploadRe
|
||||
errTxt,
|
||||
uploadRequest.ID,
|
||||
uploadRequest.DestinationPath,
|
||||
res.client.ID,
|
||||
clientID,
|
||||
)
|
||||
al.auditLog.Entry(auditlog.ApplicationUploads, auditlog.ActionFailed).
|
||||
WithRequest(uploadRequest.UploadedFile).
|
||||
@ -251,7 +252,7 @@ func (al *APIListener) consumeUploadResults(resChan chan *uploadResult, uploadRe
|
||||
"upload success, file id: %s, file path: %s, client %s",
|
||||
uploadRequest.ID,
|
||||
uploadRequest.DestinationPath,
|
||||
res.client.ID,
|
||||
clientID,
|
||||
)
|
||||
al.auditLog.Entry(auditlog.ApplicationUploads, auditlog.ActionSuccess).
|
||||
WithRequest(uploadRequest.UploadedFile).
|
||||
@ -268,7 +269,8 @@ func (al *APIListener) consumeUploadResults(resChan chan *uploadResult, uploadRe
|
||||
func (al *APIListener) sendFileToClient(wg *sync.WaitGroup, file *models.UploadedFile, cl *clients.Client, resChan chan *uploadResult) {
|
||||
defer wg.Done()
|
||||
|
||||
if cl.ClientConfiguration != nil && !cl.ClientConfiguration.FileReceptionConfig.Enabled {
|
||||
fileReceptionConfig := cl.GetFileReceptionConfig()
|
||||
if fileReceptionConfig != nil && !fileReceptionConfig.Enabled {
|
||||
resChan <- &uploadResult{
|
||||
err: errors3.ErrUploadsDisabled,
|
||||
client: cl,
|
||||
@ -277,7 +279,7 @@ func (al *APIListener) sendFileToClient(wg *sync.WaitGroup, file *models.Uploade
|
||||
return
|
||||
}
|
||||
resp := &models.UploadResponse{}
|
||||
err := comm.SendRequestAndGetResponse(cl.Connection, comm.RequestTypeUpload, file, resp)
|
||||
err := comm.SendRequestAndGetResponse(cl.GetConnection(), comm.RequestTypeUpload, file, resp, al.Log())
|
||||
|
||||
resChan <- &uploadResult{
|
||||
err: err,
|
||||
|
||||
@ -85,7 +85,7 @@ func TestHandleFileUploads(t *testing.T) {
|
||||
useFsCallback: true,
|
||||
fileName: "file.txt",
|
||||
fileContent: "some content",
|
||||
cl: clients.New(t).ID("22114341234").Build(),
|
||||
cl: clients.New(t).ID("22114341234").Logger(testLog).Build(),
|
||||
formParts: map[string][]string{
|
||||
"client_id": {
|
||||
"22114341234",
|
||||
@ -137,7 +137,7 @@ func TestHandleFileUploads(t *testing.T) {
|
||||
useFsCallback: true,
|
||||
fileName: "file.txt",
|
||||
fileContent: "some content",
|
||||
cl: clients.New(t).ID("22114341234").Build(),
|
||||
cl: clients.New(t).ID("22114341234").Logger(testLog).Build(),
|
||||
clientTags: []string{"linux"},
|
||||
formParts: map[string][]string{
|
||||
"tags": {
|
||||
@ -189,7 +189,7 @@ func TestHandleFileUploads(t *testing.T) {
|
||||
useFsCallback: true,
|
||||
fileName: "file.txt",
|
||||
fileContent: "some content",
|
||||
cl: clients.New(t).ID("22114341234").Build(),
|
||||
cl: clients.New(t).ID("22114341234").Logger(testLog).Build(),
|
||||
formParts: map[string][]string{
|
||||
"client_id": {
|
||||
"22114341234",
|
||||
@ -220,7 +220,7 @@ func TestHandleFileUploads(t *testing.T) {
|
||||
useFsCallback: true,
|
||||
fileName: "file.txt",
|
||||
fileContent: "some content",
|
||||
cl: clients.New(t).ID("22114341234").Build(),
|
||||
cl: clients.New(t).ID("22114341234").Logger(testLog).Build(),
|
||||
formParts: map[string][]string{
|
||||
"dest": {
|
||||
"/destination/myfile.txt",
|
||||
@ -242,7 +242,7 @@ func TestHandleFileUploads(t *testing.T) {
|
||||
useFsCallback: true,
|
||||
fileName: "file.txt",
|
||||
fileContent: "some content",
|
||||
cl: clients.New(t).ID("22114341234").Build(),
|
||||
cl: clients.New(t).ID("22114341234").Logger(testLog).Build(),
|
||||
formParts: map[string][]string{
|
||||
"tags": {
|
||||
`{
|
||||
@ -270,7 +270,7 @@ func TestHandleFileUploads(t *testing.T) {
|
||||
useFsCallback: true,
|
||||
fileName: "file.txt",
|
||||
fileContent: "some content",
|
||||
cl: clients.New(t).ID("22114341234").Build(),
|
||||
cl: clients.New(t).ID("22114341234").Logger(testLog).Build(),
|
||||
formParts: map[string][]string{
|
||||
"client_id": {
|
||||
"22114341234",
|
||||
@ -296,7 +296,7 @@ func TestHandleFileUploads(t *testing.T) {
|
||||
useFsCallback: true,
|
||||
fileName: "file.txt",
|
||||
fileContent: "some content",
|
||||
cl: clients.New(t).ID("22114341234").Build(),
|
||||
cl: clients.New(t).ID("22114341234").Logger(testLog).Build(),
|
||||
formParts: map[string][]string{
|
||||
"client_id": {
|
||||
"22114341234",
|
||||
@ -321,7 +321,7 @@ func TestHandleFileUploads(t *testing.T) {
|
||||
cl := tc.cl
|
||||
|
||||
if tc.clientTags != nil {
|
||||
cl.Tags = tc.clientTags
|
||||
cl.SetTags(tc.clientTags)
|
||||
}
|
||||
|
||||
connMock := test.NewConnMock()
|
||||
@ -331,7 +331,7 @@ func TestHandleFileUploads(t *testing.T) {
|
||||
done := make(chan bool)
|
||||
connMock.DoneChannel = done
|
||||
|
||||
cl.Connection = connMock
|
||||
cl.SetConnection(connMock)
|
||||
|
||||
fileAPIMock := test.NewFileAPIMock()
|
||||
if tc.useFsCallback {
|
||||
|
||||
@ -4,10 +4,12 @@ import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/cloudradar-monitoring/rport/share/logger"
|
||||
)
|
||||
|
||||
func PingConnectionWithTimeout(conn ssh.Conn, timeout time.Duration) (ok bool, response []byte, rtt time.Duration, err error) {
|
||||
func PingConnectionWithTimeout(conn ssh.Conn, timeout time.Duration, l *logger.Logger) (ok bool, response []byte, rtt time.Duration, err error) {
|
||||
timerStart := time.Now()
|
||||
ok, response, err = SendRequestWithTimeout(conn, RequestTypePing, true, nil, timeout)
|
||||
ok, response, err = SendRequestWithTimeout(conn, RequestTypePing, true, nil, timeout, l)
|
||||
return ok, response, time.Since(timerStart), err
|
||||
}
|
||||
|
||||
@ -4,7 +4,9 @@ package comm
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
@ -46,7 +48,7 @@ func ReplySuccessJSON(log *logger.Logger, req *ssh.Request, resp interface{}) {
|
||||
// SendRequestAndGetResponse sends a given request, parses a returned response and stores a success result in a given destination value.
|
||||
// Returns an error on a failure response or if an error happen. Error will be ClientError type if the error is a client error.
|
||||
// Both request and response are expected to be JSON.
|
||||
func SendRequestAndGetResponse(conn ssh.Conn, reqType string, req, successRespDest interface{}) error {
|
||||
func SendRequestAndGetResponse(conn ssh.Conn, reqType string, req, successRespDest interface{}, l *logger.Logger) error {
|
||||
reqBytes, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode request %T: %v", req, err)
|
||||
@ -63,6 +65,8 @@ func SendRequestAndGetResponse(conn ssh.Conn, reqType string, req, successRespDe
|
||||
|
||||
if successRespDest != nil {
|
||||
if err := json.Unmarshal(respBytes, successRespDest); err != nil {
|
||||
l.Debugf("failed to unmarshal: respBytes: %s", string(respBytes))
|
||||
l.Debugf("%#v", successRespDest)
|
||||
return NewClientError(fmt.Errorf("invalid client response format: failed to decode response into %T: %v", successRespDest, err))
|
||||
}
|
||||
}
|
||||
@ -85,30 +89,77 @@ func (e *ClientError) Error() string {
|
||||
return e.err.Error()
|
||||
}
|
||||
|
||||
func SendRequestWithTimeout(conn ssh.Conn, name string, wantReplay bool, payload []byte, timeout time.Duration) (bool, []byte, error) {
|
||||
type requestResponse struct {
|
||||
ok bool
|
||||
response []byte
|
||||
err error
|
||||
}
|
||||
|
||||
func SendRequestWithTimeout(conn ssh.Conn, name string, wantReply bool, payload []byte, timeout time.Duration, l *logger.Logger) (bool, []byte, error) {
|
||||
var (
|
||||
ok bool
|
||||
response []byte
|
||||
err error
|
||||
)
|
||||
|
||||
if conn == nil {
|
||||
return false, nil, errors.New("cannot send request when conn is nil")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
ch := make(chan bool, 1)
|
||||
ch := make(chan requestResponse, 1)
|
||||
|
||||
go func() {
|
||||
ok, response, err = conn.SendRequest(name, wantReplay, payload)
|
||||
ok, response, err = conn.SendRequest(name, wantReply, payload)
|
||||
select {
|
||||
default:
|
||||
ch <- true
|
||||
ch <- requestResponse{
|
||||
ok: ok,
|
||||
response: response,
|
||||
err: err,
|
||||
}
|
||||
case <-ctx.Done():
|
||||
l.Debugf("send canceled")
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
reqTimeout := time.NewTimer(timeout)
|
||||
defer reqTimeout.Stop()
|
||||
select {
|
||||
case <-ch:
|
||||
return ok, response, err
|
||||
case res := <-ch:
|
||||
reqTimeout.Stop()
|
||||
return res.ok, res.response, res.err
|
||||
case <-reqTimeout.C:
|
||||
return false, nil, TimeoutError{fmt.Errorf("conn.SendRequest(%s), timeout %s exceeded", name, timeout)}
|
||||
}
|
||||
}
|
||||
|
||||
const DefaultMaxRetryAttempts = 3
|
||||
|
||||
func WithRetry[R any](retryAbleFn func() (result R, err error), canRetry func(err error) (shouldRetry bool), minRetryWaitDuration time.Duration, label string, l *logger.Logger) (result R, err error) {
|
||||
for r := 0; r < DefaultMaxRetryAttempts; r++ {
|
||||
attempt := r + 1
|
||||
// l.Debugf("%s: attempt %d", label, attempt)
|
||||
if r > 0 {
|
||||
// backoff with some jitter
|
||||
delay := (minRetryWaitDuration * time.Duration(r*r)) + time.Duration(rand.Intn(1000))*time.Millisecond
|
||||
l.Debugf("%s: attempt %d failed. will sleep for: %d seconds", label, attempt, delay/time.Second)
|
||||
time.Sleep(delay)
|
||||
}
|
||||
result, err = retryAbleFn()
|
||||
if err != nil {
|
||||
l.Debugf("%s: attempt %d err = %+v\n", label, attempt, err)
|
||||
if !canRetry(err) {
|
||||
// non retryable err
|
||||
l.Debugf("%s: attempt %d non-retryable err %v", label, attempt, err)
|
||||
return result, err
|
||||
}
|
||||
continue
|
||||
}
|
||||
// success
|
||||
return result, nil
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
@ -5,7 +5,7 @@ import (
|
||||
)
|
||||
|
||||
func SetFromRange(start, end int) mapset.Set {
|
||||
s := mapset.NewThreadUnsafeSet()
|
||||
s := mapset.NewSet()
|
||||
for i := 0; i <= end-start; i++ {
|
||||
s.Add(start + i)
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user