using Remotely.Server.Models;
using Remotely.Shared.Services;
using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;
using System.Text.Json;
namespace Remotely.Server.Services;
/// <summary>
/// A cache containing all active remote control sessions, keyed by session ID.
/// </summary>
public interface IRemoteControlSessionCache
{
    /// <summary>
    /// Gets the sessions currently held in the cache.
    /// </summary>
    IEnumerable<RemoteControlSession> Sessions { get; }

    /// <summary>
    /// Stores <paramref name="session"/> under <paramref name="sessionId"/>,
    /// replacing any existing entry. A different session previously stored
    /// under that ID is disposed.
    /// </summary>
    /// <param name="sessionId">The ID to store the session under.</param>
    /// <param name="session">The session to store.</param>
    /// <returns>The session now stored under <paramref name="sessionId"/>.</returns>
    RemoteControlSession AddOrUpdate(string sessionId, RemoteControlSession session);

    /// <summary>
    /// Stores <paramref name="session"/> if no entry exists for
    /// <paramref name="sessionId"/>; otherwise stores the result of
    /// <paramref name="updateFactory"/>.
    /// </summary>
    /// <param name="sessionId">The ID to store the session under.</param>
    /// <param name="session">The session to store when no entry exists.</param>
    /// <param name="updateFactory">
    /// Invoked with the session ID and the existing session to produce the
    /// value that replaces it.
    /// </param>
    /// <returns>The session now stored under <paramref name="sessionId"/>.</returns>
    RemoteControlSession AddOrUpdate(
        string sessionId,
        RemoteControlSession session,
        Func<string, RemoteControlSession, RemoteControlSession> updateFactory);

    /// <summary>
    /// Returns the session stored under <paramref name="sessionId"/>, first adding
    /// the result of <paramref name="valueFactory"/> (called with the ID) if none exists.
    /// </summary>
    RemoteControlSession GetOrAdd(string sessionId, Func<string, RemoteControlSession> valueFactory);

    /// <summary>
    /// Removes and disposes sessions that are in <c>Unattended</c> or <c>Unknown</c>
    /// mode, have no viewers, and were created more than one minute ago.
    /// </summary>
    Task RemoveExpiredSessions();

    /// <summary>
    /// Adds <paramref name="session"/> only if no entry exists for <paramref name="sessionId"/>.
    /// </summary>
    /// <returns><c>true</c> if the session was added; otherwise <c>false</c>.</returns>
    bool TryAdd(string sessionId, RemoteControlSession session);

    /// <summary>
    /// Looks up the session stored under <paramref name="sessionId"/>.
    /// </summary>
    /// <returns><c>true</c> if a session was found; otherwise <c>false</c>.</returns>
    bool TryGetValue(string sessionId, [NotNullWhen(true)] out RemoteControlSession? session);

    /// <summary>
    /// Removes the session stored under <paramref name="sessionId"/> and disposes it.
    /// On success, <paramref name="session"/> receives the removed instance, which has
    /// already been disposed.
    /// </summary>
    /// <returns><c>true</c> if a session was removed; otherwise <c>false</c>.</returns>
    bool TryRemove(string sessionId, [NotNullWhen(true)] out RemoteControlSession? session);
}
internal class RemoteControlSessionCache : IRemoteControlSessionCache
{
    private readonly ConcurrentDictionary<string, RemoteControlSession> _sessions = new();

    // ConcurrentDictionary's own AddOrUpdate/GetOrAdd invoke their factory delegates
    // outside its internal locks (and may invoke them more than once), so the
    // check-then-act sequences below, including disposing a replaced session,
    // are serialized with this outer lock instead.
    private readonly object _sessionsLock = new();
    private readonly ILogger<RemoteControlSessionCache> _logger;
    private readonly ISystemTime _systemTime;

    public RemoteControlSessionCache(
        ISystemTime systemTime,
        ILogger<RemoteControlSessionCache> logger)
    {
        _systemTime = systemTime;
        _logger = logger;
    }

    /// <summary>
    /// Gets the sessions currently held in the cache.
    /// </summary>
    public IEnumerable<RemoteControlSession> Sessions => _sessions.Values;

    /// <summary>
    /// Stores <paramref name="session"/> under <paramref name="sessionId"/>, replacing
    /// any existing entry. A different session previously stored under that ID is
    /// disposed; re-storing the same instance leaves it undisposed.
    /// </summary>
    /// <returns>The session now stored under <paramref name="sessionId"/>.</returns>
    public RemoteControlSession AddOrUpdate(string sessionId, RemoteControlSession session)
    {
        // The three-argument overload takes _sessionsLock, so this factory runs
        // while the lock is held.
        return AddOrUpdate(sessionId, session, (id, existing) =>
        {
            // Never dispose the instance that is about to be stored.
            if (!ReferenceEquals(existing, session))
            {
                DisposeSession(id, existing);
            }
            return session;
        });
    }

    /// <summary>
    /// Stores <paramref name="session"/> if no entry exists for <paramref name="sessionId"/>;
    /// otherwise stores the result of <paramref name="updateFactory"/>, which is invoked
    /// under the cache lock with the ID and the existing session.
    /// </summary>
    /// <returns>The session now stored under <paramref name="sessionId"/>.</returns>
    public RemoteControlSession AddOrUpdate(
        string sessionId,
        RemoteControlSession session,
        Func<string, RemoteControlSession, RemoteControlSession> updateFactory)
    {
        lock (_sessionsLock)
        {
            // Single lookup instead of ContainsKey followed by the indexer.
            var newValue = _sessions.TryGetValue(sessionId, out var existing)
                ? updateFactory(sessionId, existing)
                : session;

            _sessions[sessionId] = newValue;
            return newValue;
        }
    }

    /// <summary>
    /// Returns the session stored under <paramref name="sessionId"/>, first adding the
    /// result of <paramref name="valueFactory"/> if none exists. The factory runs under
    /// the cache lock.
    /// </summary>
    public RemoteControlSession GetOrAdd(string sessionId, Func<string, RemoteControlSession> valueFactory)
    {
        lock (_sessionsLock)
        {
            return _sessions.GetOrAdd(sessionId, valueFactory);
        }
    }

    /// <summary>
    /// Removes and disposes sessions that are in Unattended or Unknown mode, have no
    /// viewers, and were created more than one minute ago. The work is done
    /// synchronously; the returned task is already completed.
    /// </summary>
    public Task RemoveExpiredSessions()
    {
        lock (_sessionsLock)
        {
            // Computed once so every session is judged against the same cutoff.
            var cutoff = _systemTime.Now.AddMinutes(-1);

            // ConcurrentDictionary's enumerator tolerates removals during iteration.
            foreach (var (sessionId, session) in _sessions)
            {
                if (session.Mode is RemoteControlMode.Unattended or RemoteControlMode.Unknown &&
                    session.ViewerList.Count == 0 &&
                    session.Created < cutoff)
                {
                    _logger.LogWarning("Removing expired session: {session}", JsonSerializer.Serialize(session));
                    if (_sessions.TryRemove(sessionId, out var expiredSession))
                    {
                        // Guarded so one faulty session cannot abort the whole sweep.
                        DisposeSession(sessionId, expiredSession);
                    }
                }
            }
        }
        return Task.CompletedTask;
    }

    /// <summary>
    /// Adds <paramref name="session"/> only if no entry exists for <paramref name="sessionId"/>.
    /// </summary>
    /// <returns><c>true</c> if the session was added; otherwise <c>false</c>.</returns>
    public bool TryAdd(string sessionId, RemoteControlSession session)
    {
        lock (_sessionsLock)
        {
            return _sessions.TryAdd(sessionId, session);
        }
    }

    /// <summary>
    /// Looks up the session stored under <paramref name="sessionId"/>.
    /// </summary>
    /// <returns><c>true</c> if a session was found; otherwise <c>false</c>.</returns>
    public bool TryGetValue(string sessionId, [NotNullWhen(true)] out RemoteControlSession? session)
    {
        lock (_sessionsLock)
        {
            return _sessions.TryGetValue(sessionId, out session);
        }
    }

    /// <summary>
    /// Removes the session stored under <paramref name="sessionId"/> and disposes it.
    /// On success, <paramref name="session"/> receives the removed, already-disposed instance.
    /// Disposal errors are logged rather than thrown.
    /// </summary>
    /// <returns><c>true</c> if a session was removed; otherwise <c>false</c>.</returns>
    public bool TryRemove(string sessionId, [NotNullWhen(true)] out RemoteControlSession? session)
    {
        lock (_sessionsLock)
        {
            if (!_sessions.TryRemove(sessionId, out session))
            {
                return false;
            }

            DisposeSession(sessionId, session);
            return true;
        }
    }

    // Disposes a session that was removed or replaced. Failures are logged instead
    // of thrown so a misbehaving session cannot break the cache operation.
    private void DisposeSession(string sessionId, RemoteControlSession session)
    {
        try
        {
            session.Dispose();
        }
        catch (Exception ex)
        {
            _logger.LogError(ex, "Error disposing RemoteControlSession ID {id}.", sessionId);
        }
    }
}