diff --git a/Server/API/AgentUpdateController.cs b/Server/API/AgentUpdateController.cs index dcd9f26a..909ca4fa 100644 --- a/Server/API/AgentUpdateController.cs +++ b/Server/API/AgentUpdateController.cs @@ -148,7 +148,7 @@ namespace Remotely.Server.API if (_appConfig.BannedDevices.Contains(deviceIp)) { - _dataService.WriteEvent($"Device IP ({deviceIp}) is banned. Sending uninstall command.", null); + _logger.LogInformation("Device IP ({deviceIp}) is banned. Sending uninstall command.", deviceIp); var bannedDevices = _serviceSessionCache.GetAllDevices().Where(x => x.PublicIP == deviceIp); diff --git a/Server/API/RemoteControlController.cs b/Server/API/RemoteControlController.cs index 1f82b5de..69e0e901 100644 --- a/Server/API/RemoteControlController.cs +++ b/Server/API/RemoteControlController.cs @@ -15,6 +15,8 @@ using Immense.RemoteControl.Server.Services; using Remotely.Server.Services.RcImplementations; using Immense.RemoteControl.Server.Abstractions; using Immense.RemoteControl.Shared.Helpers; +using Microsoft.Build.Framework; +using Microsoft.Extensions.Logging; // For more information on enabling Web API for empty projects, visit https://go.microsoft.com/fwlink/?LinkID=397860 @@ -32,6 +34,7 @@ namespace Remotely.Server.API private readonly IHubEventHandler _hubEvents; private readonly IDataService _dataService; private readonly SignInManager _signInManager; + private readonly ILogger _logger; public RemoteControlController( SignInManager signInManager, @@ -41,7 +44,8 @@ namespace Remotely.Server.API IServiceHubSessionCache serviceSessionCache, IOtpProvider otpProvider, IHubEventHandler hubEvents, - IApplicationConfig appConfig) + IApplicationConfig appConfig, + ILogger logger) { _dataService = dataService; _serviceHub = serviceHub; @@ -51,6 +55,7 @@ namespace Remotely.Server.API _otpProvider = otpProvider; _hubEvents = hubEvents; _signInManager = signInManager; + _logger = logger; } [HttpGet("{deviceID}")] @@ -75,20 +80,20 @@ namespace Remotely.Server.API if (result.Succeeded && _dataService.DoesUserHaveAccessToDevice(rcRequest.DeviceID, _dataService.GetUserByNameWithOrg(rcRequest.Email))) { - _dataService.WriteEvent($"API login successful for {rcRequest.Email}.", orgId); + _logger.LogInformation("API login successful for {rcRequestEmail}.", rcRequest.Email); return await InitiateRemoteControl(rcRequest.DeviceID, orgId); } else if (result.IsLockedOut) { - _dataService.WriteEvent($"API login unsuccessful due to lockout for {rcRequest.Email}.", orgId); + _logger.LogInformation("API login successful for {rcRequestEmail}.", rcRequest.Email); return Unauthorized("Account is locked."); } else if (result.RequiresTwoFactor) { - _dataService.WriteEvent($"API login unsuccessful due to 2FA for {rcRequest.Email}.", orgId); + _logger.LogInformation("API login successful for {rcRequestEmail}.", rcRequest.Email); return Unauthorized("Account requires two-factor authentication."); } - _dataService.WriteEvent($"API login unsuccessful due to bad attempt for {rcRequest.Email}.", orgId); + _logger.LogInformation("API login unsuccessful due to bad attempt for {rcRequestEmail}.", rcRequest.Email); return BadRequest(); } diff --git a/Server/API/ServerLogsController.cs b/Server/API/ServerLogsController.cs index aec96757..3c1a6852 100644 --- a/Server/API/ServerLogsController.cs +++ b/Server/API/ServerLogsController.cs @@ -1,9 +1,12 @@ using Microsoft.AspNetCore.Mvc; using Remotely.Server.Auth; -using Remotely.Server.Services; using System.Text; using System.Text.Json; using System; +using Microsoft.Extensions.Logging; +using Remotely.Server.Services; +using System.IO; +using System.Threading.Tasks; namespace Remotely.Server.API { @@ -11,23 +14,34 @@ namespace Remotely.Server.API [ApiController] public class ServerLogsController : ControllerBase { - private readonly IDataService _dataService; - private readonly JsonSerializerOptions _jsonOptions = new JsonSerializerOptions() { WriteIndented = true }; + private readonly JsonSerializerOptions _jsonOptions = new() { WriteIndented = true }; + private readonly ILogsManager _logsManager; + private readonly ILogger _logger; - public ServerLogsController(IDataService dataService) + public ServerLogsController( + ILogsManager logsManager, + ILogger logger) { - _dataService = dataService; + _logsManager = logsManager; + _logger = logger; } [ServiceFilter(typeof(ApiAuthorizationFilter))] [HttpGet("Download")] - public ActionResult Download() + public async Task Download() { - Request.Headers.TryGetValue("OrganizationID", out var orgId); + _logger.LogInformation( + "Downloading server logs. Remote IP: {ip}", + HttpContext.Connection.RemoteIpAddress); - var logs = _dataService.GetAllEventLogs(User.Identity?.Name, orgId); - var fileBytes = Encoding.UTF8.GetBytes(JsonSerializer.Serialize(logs, _jsonOptions)); - return File(fileBytes, "application/octet-stream", "ServerLogs.json"); + var zipFile = await _logsManager.ZipAllLogs(); + Response.OnCompleted(() => + { + Directory.Delete(zipFile.DirectoryName, true); + return Task.CompletedTask; + }); + + return File(zipFile.OpenRead(), "application/octet-stream", zipFile.Name); } } } diff --git a/Server/Areas/Identity/Pages/Account/ForgotPassword.cshtml.cs b/Server/Areas/Identity/Pages/Account/ForgotPassword.cshtml.cs index 48eb78c9..56ecc341 100644 --- a/Server/Areas/Identity/Pages/Account/ForgotPassword.cshtml.cs +++ b/Server/Areas/Identity/Pages/Account/ForgotPassword.cshtml.cs @@ -12,6 +12,7 @@ using Microsoft.AspNetCore.Mvc.RazorPages; using Microsoft.AspNetCore.WebUtilities; using Remotely.Shared.Models; using Remotely.Server.Services; +using Microsoft.Extensions.Logging; namespace Remotely.Server.Areas.Identity.Pages.Account { @@ -21,14 +22,18 @@ namespace Remotely.Server.Areas.Identity.Pages.Account private readonly UserManager _userManager; private readonly IEmailSenderEx _emailSender; private readonly IDataService _dataService; + private readonly ILogger _logger; - public ForgotPasswordModel(UserManager userManager, + public ForgotPasswordModel( + UserManager userManager, IEmailSenderEx emailSender, - IDataService dataService) + IDataService dataService, + ILogger logger) { _userManager = userManager; _emailSender = emailSender; _dataService = dataService; + _logger = logger; } [BindProperty] @@ -62,7 +67,8 @@ namespace Remotely.Server.Areas.Identity.Pages.Account values: new { area = "Identity", code }, protocol: Request.Scheme); - _dataService.WriteEvent($"Sending password reset for user {user.UserName}. Reset URL: {callbackUrl}", user.OrganizationID); + _logger.LogInformation( + "Sending password reset for user {username}. Reset URL: {callbackUrl}", user.UserName, callbackUrl); var emailResult = await _emailSender.SendEmailAsync( Input.Email, diff --git a/Server/Hubs/AgentHub.cs b/Server/Hubs/AgentHub.cs index 5c9a1b3a..0658a714 100644 --- a/Server/Hubs/AgentHub.cs +++ b/Server/Hubs/AgentHub.cs @@ -23,6 +23,7 @@ namespace Remotely.Server.Hubs private readonly ICircuitManager _circuitManager; private readonly IDataService _dataService; private readonly IExpiringTokenService _expiringTokenService; + private readonly ILogger _logger; private readonly IServiceHubSessionCache _serviceSessionCache; private readonly IHubContext _viewerHubContext; @@ -31,7 +32,8 @@ namespace Remotely.Server.Hubs IServiceHubSessionCache serviceSessionCache, IHubContext viewerHubContext, ICircuitManager circuitManager, - IExpiringTokenService expiringTokenService) + IExpiringTokenService expiringTokenService, + ILogger logger) { _dataService = dataService; _serviceSessionCache = serviceSessionCache; @@ -39,6 +41,7 @@ namespace Remotely.Server.Hubs _appConfig = appConfig; _circuitManager = circuitManager; _expiringTokenService = expiringTokenService; + _logger = logger; } // TODO: Replace with new invoke capability in .NET 7 in ScriptingController. @@ -133,7 +136,7 @@ namespace Remotely.Server.Hubs } catch (Exception ex) { - _dataService.WriteEvent(ex, device?.OrganizationID); + _logger.LogError(ex, "Error while setting device to online status."); } Context.Abort(); @@ -286,7 +289,7 @@ namespace Remotely.Server.Hubs if (_appConfig.BannedDevices.Any(x => !string.IsNullOrWhiteSpace(x) && x.Equals(device, StringComparison.OrdinalIgnoreCase))) { - _dataService.WriteEvent($"Device ID/name/IP ({device}) is banned. Sending uninstall command.", null); + _logger.LogWarning("Device ID/name/IP ({device}) is banned. Sending uninstall command.", device); _ = Clients.Caller.SendAsync("UninstallAgent"); return true; diff --git a/Server/Hubs/CircuitConnection.cs b/Server/Hubs/CircuitConnection.cs index 40f063f4..307a472e 100644 --- a/Server/Hubs/CircuitConnection.cs +++ b/Server/Hubs/CircuitConnection.cs @@ -222,7 +222,10 @@ namespace Remotely.Server.Hubs if (!_dataService.DoesUserHaveAccessToDevice(deviceId, User)) { var device = _dataService.GetDevice(targetDevice.ID); - _dataService.WriteEvent($"Remote control attempted by unauthorized user. Device ID: {deviceId}. User Name: {User.UserName}.", EventType.Warning, device?.OrganizationID); + _logger.LogWarning( + "Remote control attempted by unauthorized user. Device ID: {deviceId}. User Name: {userName}.", + deviceId, + User.UserName); return Result.Fail("Unauthorized."); } @@ -414,13 +417,11 @@ namespace Remotely.Server.Hubs public Task UploadFiles(List fileIDs, string transferID, string[] deviceIDs) { - _dataService.WriteEvent(new EventLog() - { - EventType = EventType.Info, - Message = $"File transfer started by {User.UserName}. File transfer IDs: {string.Join(", ", fileIDs)}.", - TimeStamp = Time.Now, - OrganizationID = User.OrganizationID - }); + _logger.LogInformation( + "File transfer started by {userName}. File transfer IDs: {fileIds}.", + User.UserName, + string.Join(", ", fileIDs)); + deviceIDs = _dataService.FilterDeviceIDsByUserPermission(deviceIDs, User); var connections = GetActiveConnectionsForUserOrg(deviceIDs); foreach (var connection in connections) diff --git a/Server/Pages/ServerLogs.razor b/Server/Pages/ServerLogs.razor index cd6b4187..9cdb91db 100644 --- a/Server/Pages/ServerLogs.razor +++ b/Server/Pages/ServerLogs.razor @@ -5,6 +5,7 @@ @inject IDataService DataService @inject IToastService ToastService @inject IJsInterop JsInterop +@inject ILogsManager LogsManager

Server Logs

@@ -109,11 +110,12 @@ else { get { - return DataService.GetEventLogs(User.UserName, - _fromDate, - _toDate, - _eventType, - _messageFilter); + return Enumerable.Empty(); + //return DataService.GetEventLogs(User.UserName, + // _fromDate, + // _toDate, + // _eventType, + // _messageFilter); } } @@ -122,7 +124,7 @@ else var result = await JsInterop.Confirm("Are you sure you want to delete all logs?"); if (result) { - await DataService.ClearLogs(User.UserName); + await LogsManager.DeleteLogs(); ToastService.ShowToast("Logs deleted."); } } diff --git a/Server/Program.cs b/Server/Program.cs index 5305c55b..70b221eb 100644 --- a/Server/Program.cs +++ b/Server/Program.cs @@ -205,6 +205,7 @@ services.AddScoped(); services.AddScoped(); services.AddSingleton(); services.AddSingleton(); +services.AddSingleton(LogsManager.Default); services.AddRemoteControlServer(config => { @@ -329,8 +330,7 @@ void ConfigureSerilog(WebApplicationBuilder webAppBuilder) dataRetentionDays = retentionSetting; } - var logPath = Directory.Exists("/remotely-data") ? "/remotely-data/logs" : "logs"; - Directory.CreateDirectory(logPath); + var logPath = LogsManager.Default.GetLogsDirectory(); void ApplySharedLoggerConfig(LoggerConfiguration loggerConfiguration) { diff --git a/Server/Services/DbLogger.cs b/Server/Services/DbLogger.cs deleted file mode 100644 index 5bc91a66..00000000 --- a/Server/Services/DbLogger.cs +++ /dev/null @@ -1,83 +0,0 @@ -using Microsoft.AspNetCore.Hosting; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Hosting; -using Microsoft.Extensions.Logging; -using System; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Linq; -using System.Threading; - -namespace Remotely.Server.Services -{ - public class DbLogger : ILogger - { - private readonly string _categoryName; - private readonly IWebHostEnvironment _hostEnvironment; - private readonly IServiceProvider _serviceProvider; - - protected static ConcurrentStack ScopeStack { get; } = new ConcurrentStack(); - - public DbLogger(string categoryName, IWebHostEnvironment hostEnvironment, IServiceProvider serviceProvider) - { - _categoryName = categoryName; - _hostEnvironment = hostEnvironment; - _serviceProvider = serviceProvider; - } - - public IDisposable BeginScope(TState state) - { - ScopeStack.Push(state.ToString()); - return new NoopDisposable(); - } - - public bool IsEnabled(LogLevel logLevel) - { - switch (logLevel) - { - case LogLevel.Trace: - break; - case LogLevel.Debug: - case LogLevel.Information: - if (_hostEnvironment.IsDevelopment()) - { - return true; - } - break; - case LogLevel.Warning: - case LogLevel.Error: - case LogLevel.Critical: - return true; - case LogLevel.None: - break; - default: - break; - } - return false; - } - - public void Log(LogLevel logLevel, EventId eventId, TState state, Exception exception, Func formatter) - { - using var scope = _serviceProvider.CreateScope(); - var dataService = scope.ServiceProvider.GetRequiredService(); - - var scopeStack = ScopeStack.Any() ? - new string[] { ScopeStack.FirstOrDefault(), ScopeStack.LastOrDefault() } : - Array.Empty(); - - dataService.WriteLog(logLevel, _categoryName, eventId, state.ToString(), exception, scopeStack); - } - - - private class NoopDisposable : IDisposable - { - public void Dispose() - { - while (!ScopeStack.TryPop(out _)) - { - Thread.Sleep(100); - } - } - } - } -} diff --git a/Server/Services/DbLoggerProvider.cs b/Server/Services/DbLoggerProvider.cs deleted file mode 100644 index 739fa324..00000000 --- a/Server/Services/DbLoggerProvider.cs +++ /dev/null @@ -1,29 +0,0 @@ -using Microsoft.AspNetCore.Hosting; -using Microsoft.Extensions.Logging; -using System; - -namespace Remotely.Server.Services -{ - public class DbLoggerProvider : ILoggerProvider - { - private readonly IWebHostEnvironment _hostEnvironment; - private readonly IServiceProvider _serviceProvider; - - public DbLoggerProvider(IWebHostEnvironment hostEnvironment, IServiceProvider serviceProvider) - { - _hostEnvironment = hostEnvironment; - _serviceProvider = serviceProvider; - } - - - public ILogger CreateLogger(string categoryName) - { - return new DbLogger(categoryName, _hostEnvironment, _serviceProvider); - } - - public void Dispose() - { - GC.SuppressFinalize(this); - } - } -} diff --git a/Server/Services/EmailSender.cs b/Server/Services/EmailSender.cs index ed55d924..4417f061 100644 --- a/Server/Services/EmailSender.cs +++ b/Server/Services/EmailSender.cs @@ -1,6 +1,8 @@ using MailKit.Net.Smtp; using MailKit.Security; using Microsoft.AspNetCore.Identity.UI.Services; +using Microsoft.Build.Framework; +using Microsoft.Extensions.Logging; using MimeKit; using MimeKit.Text; using System; @@ -15,67 +17,6 @@ namespace Remotely.Server.Services Task SendEmailAsync(string email, string subject, string htmlMessage, string organizationID = null); } - public class EmailSenderEx : IEmailSenderEx - { - public EmailSenderEx(IApplicationConfig appConfig, IDataService dataService) - { - AppConfig = appConfig; - DataService = dataService; - } - - private IApplicationConfig AppConfig { get; } - private IDataService DataService { get; } - - public async Task SendEmailAsync(string toEmail, string replyTo, string subject, string htmlMessage, string organizationID = null) - { - try - { - var message = new MimeMessage(); - message.From.Add(new MailboxAddress(AppConfig.SmtpDisplayName, AppConfig.SmtpEmail)); - message.To.Add(MailboxAddress.Parse(toEmail)); - message.ReplyTo.Add(MailboxAddress.Parse(replyTo)); - message.Subject = subject; - message.Body = new TextPart(TextFormat.Html) - { - Text = htmlMessage - }; - - using var client = new SmtpClient(); - - if (!string.IsNullOrWhiteSpace(AppConfig.SmtpLocalDomain)) - { - client.LocalDomain = AppConfig.SmtpLocalDomain; - } - - client.CheckCertificateRevocation = AppConfig.SmtpCheckCertificateRevocation; - - await client.ConnectAsync(AppConfig.SmtpHost, AppConfig.SmtpPort); - - if (!string.IsNullOrWhiteSpace(AppConfig.SmtpUserName) && - !string.IsNullOrWhiteSpace(AppConfig.SmtpPassword)) - { - await client.AuthenticateAsync(AppConfig.SmtpUserName, AppConfig.SmtpPassword); - } - - await client.SendAsync(message); - await client.DisconnectAsync(true); - - DataService.WriteEvent($"Email successfully sent to {toEmail}. Subject: \"{subject}\".", organizationID); - - return true; - } - catch (Exception ex) - { - DataService.WriteEvent(ex, organizationID); - return false; - } - } - - public Task SendEmailAsync(string email, string subject, string htmlMessage, string organizationID = null) - { - return SendEmailAsync(email, AppConfig.SmtpEmail, subject, htmlMessage, organizationID); - } - } public class EmailSender : IEmailSender { public EmailSender(IEmailSenderEx emailSenderEx) @@ -91,4 +32,66 @@ namespace Remotely.Server.Services } } + public class EmailSenderEx : IEmailSenderEx + { + private readonly IApplicationConfig _appConfig; + private readonly ILogger _logger; + + public EmailSenderEx( + IApplicationConfig appConfig, + ILogger logger) + { + _appConfig = appConfig; + _logger = logger; + } + public async Task SendEmailAsync(string toEmail, string replyTo, string subject, string htmlMessage, string organizationID = null) + { + try + { + var message = new MimeMessage(); + message.From.Add(new MailboxAddress(_appConfig.SmtpDisplayName, _appConfig.SmtpEmail)); + message.To.Add(MailboxAddress.Parse(toEmail)); + message.ReplyTo.Add(MailboxAddress.Parse(replyTo)); + message.Subject = subject; + message.Body = new TextPart(TextFormat.Html) + { + Text = htmlMessage + }; + + using var client = new SmtpClient(); + + if (!string.IsNullOrWhiteSpace(_appConfig.SmtpLocalDomain)) + { + client.LocalDomain = _appConfig.SmtpLocalDomain; + } + + client.CheckCertificateRevocation = _appConfig.SmtpCheckCertificateRevocation; + + await client.ConnectAsync(_appConfig.SmtpHost, _appConfig.SmtpPort); + + if (!string.IsNullOrWhiteSpace(_appConfig.SmtpUserName) && + !string.IsNullOrWhiteSpace(_appConfig.SmtpPassword)) + { + await client.AuthenticateAsync(_appConfig.SmtpUserName, _appConfig.SmtpPassword); + } + + await client.SendAsync(message); + await client.DisconnectAsync(true); + + _logger.LogInformation("Email successfully sent to {toEmail}. Subject: \"{subject}\".", toEmail, subject); + + return true; + } + catch (Exception ex) + { + _logger.LogError(ex, "Error while sending email."); + return false; + } + } + + public Task SendEmailAsync(string email, string subject, string htmlMessage, string organizationID = null) + { + return SendEmailAsync(email, _appConfig.SmtpEmail, subject, htmlMessage, organizationID); + } + } } diff --git a/Server/Services/LogsManager.cs b/Server/Services/LogsManager.cs new file mode 100644 index 00000000..2fd9854e --- /dev/null +++ b/Server/Services/LogsManager.cs @@ -0,0 +1,80 @@ +using Remotely.Shared.Extensions; +using Serilog; +using System; +using System.IO; +using System.IO.Compression; +using System.Linq; +using System.Threading.Tasks; + +namespace Remotely.Server.Services +{ + public interface ILogsManager + { + string GetLogsDirectory(); + Task ZipAllLogs(); + Task DeleteLogs(); + } + + public class LogsManager : ILogsManager + { + public static LogsManager Default { get; } = new(); + + public string GetLogsDirectory() + { + var logsDir = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "logs"); + if (Directory.Exists("/remotely-data")) + { + logsDir = "/remotely-data/logs"; + } + return Directory.CreateDirectory(logsDir).FullName; + } + + public async Task ZipAllLogs() + { + var logsDir = GetLogsDirectory(); + var baseDir = AppDomain.CurrentDomain.BaseDirectory; + var tempDir = Directory.CreateDirectory(Path.Combine(baseDir, "temp", Guid.NewGuid().ToString())).FullName; + var zipFilePath = Path.Combine( + tempDir, + $"Remotely_Logs-{DateTimeOffset.Now:yyyy-MM-dd-HH-mm-ss}.zip"); + + using var zipArchive = ZipFile.Open(zipFilePath, ZipArchiveMode.Update); + + var files = Directory.GetFiles(logsDir); + + foreach (var file in files) + { + var entry = zipArchive.CreateEntry(Path.GetFileName(file)); + using var entryStream = entry.Open(); + using var fs = File.Open(file, FileMode.Open, FileAccess.Read, FileShare.ReadWrite); + await fs.CopyToAsync(entryStream); + } + + return new FileInfo(zipFilePath); + } + + public async Task DeleteLogs() + { + var logsDir = GetLogsDirectory(); + + var files = Directory.GetFiles(logsDir); + + if (!files.Any()) + { + return; + } + + await foreach (var file in files.ToAsyncEnumerable()) + { + try + { + File.Delete(file); + } + catch (Exception ex) + { + Console.WriteLine("Failed to delete log file: {filename}. Message: {exMessage}", file, ex.Message); + } + } + } + } +} diff --git a/Shared/Extensions/IEnumerableExtensions.cs b/Shared/Extensions/IEnumerableExtensions.cs new file mode 100644 index 00000000..83605cd2 --- /dev/null +++ b/Shared/Extensions/IEnumerableExtensions.cs @@ -0,0 +1,20 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Remotely.Shared.Extensions +{ + public static class IEnumerableExtensions + { + public static async IAsyncEnumerable ToAsyncEnumerable(this IEnumerable source) + { + foreach (var item in source) + { + yield return item; + await Task.Yield(); + } + } + } +} diff --git a/Shared/Primitives/CallbackDisposable.cs b/Shared/Primitives/CallbackDisposable.cs new file mode 100644 index 00000000..3e286e6d --- /dev/null +++ b/Shared/Primitives/CallbackDisposable.cs @@ -0,0 +1,52 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Remotely.Shared.Primitives; + + +/// +/// An implementation of that lets you provide a +/// callback, which will be invoked when the object is disposed. +/// +public sealed class CallbackDisposable : IDisposable +{ + private readonly Action _callback; + private readonly Action _exceptionHandler; + + /// + /// Create anew instance where exceptions will be caught and suppressed. + /// + /// + public CallbackDisposable(Action callback) + : this(callback, (_) => { }) + { + } + + /// + /// Create a new instance where exceptions will be caught and passed to the supplied handler. + /// + /// + public CallbackDisposable( + Action callback, + Action exceptionHandler) + { + _callback = callback; + _exceptionHandler = exceptionHandler; + } + + + public void Dispose() + { + try + { + _callback.Invoke(); + } + catch (Exception ex) + { + _exceptionHandler.Invoke(ex); + } + } +} diff --git a/Shared/Primitives/CallbackDisposableAsync.cs b/Shared/Primitives/CallbackDisposableAsync.cs new file mode 100644 index 00000000..fe01ca8f --- /dev/null +++ b/Shared/Primitives/CallbackDisposableAsync.cs @@ -0,0 +1,52 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Remotely.Shared.Primitives; + + +/// +/// An implementation of that lets you provide a +/// callback, which will be invoked when the object is disposed. +/// +public sealed class CallbackDisposableAsync : IAsyncDisposable +{ + private readonly Func _callback; + private readonly Func _exceptionHandler; + + /// + /// Create anew instance where exceptions will be caught and suppressed. + /// + /// + public CallbackDisposableAsync(Func callback) + : this(callback, (_) => ValueTask.CompletedTask) + { + } + + /// + /// Create a new instance where exceptions will be caught and passed to the supplied handler. + /// + /// + public CallbackDisposableAsync( + Func callback, + Func exceptionHandler) + { + _callback = callback; + _exceptionHandler = exceptionHandler; + } + + + public ValueTask DisposeAsync() + { + try + { + return _callback.Invoke(); + } + catch (Exception ex) + { + return _exceptionHandler.Invoke(ex); + } + } +}