Refactor API controllers.

This commit is contained in:
Jared Goodwin 2023-07-25 11:34:35 -07:00
parent 8b341562bc
commit 28079f887a
27 changed files with 262 additions and 258 deletions

View File

@ -131,7 +131,6 @@ public class UpdaterLinux : IUpdater
_logger.LogInformation("Service Updater: Downloading install package.");
var downloadId = Guid.NewGuid().ToString();
var zipPath = Path.Combine(Path.GetTempPath(), "RemotelyUpdate.zip");
var installerPath = Path.Combine(Path.GetTempPath(), "RemotelyUpdate.sh");
@ -156,12 +155,9 @@ public class UpdaterLinux : IUpdater
installerPath);
await _updateDownloader.DownloadFile(
$"{serverUrl}/API/AgentUpdate/DownloadPackage/linux/{downloadId}",
$"{serverUrl}/API/AgentUpdate/DownloadPackage/linux",
zipPath);
using var httpClient = _httpClientFactory.CreateClient();
using var response = httpClient.GetAsync($"{serverUrl}/api/AgentUpdate/ClearDownload/{downloadId}");
_logger.LogInformation("Launching installer to perform update.");
Process.Start("sudo", $"chmod +x {installerPath}").WaitForExit();

View File

@ -132,7 +132,6 @@ public class UpdaterMac : IUpdater
_logger.LogInformation("Service Updater: Downloading install package.");
var downloadId = Guid.NewGuid().ToString();
var zipPath = Path.Combine(Path.GetTempPath(), "RemotelyUpdate.zip");
var installerPath = Path.Combine(Path.GetTempPath(), "RemotelyUpdate.sh");
@ -142,12 +141,9 @@ public class UpdaterMac : IUpdater
installerPath);
await _updateDownloader.DownloadFile(
$"{serverUrl}/API/AgentUpdate/DownloadPackage/macos-{_achitecture}/{downloadId}",
$"{serverUrl}/API/AgentUpdate/DownloadPackage/macos-{_achitecture}",
zipPath);
using var httpClient = _httpClientFactory.CreateClient();
using var response = httpClient.GetAsync($"{serverUrl}/api/AgentUpdate/ClearDownload/{downloadId}");
_logger.LogInformation("Launching installer to perform update.");
Process.Start("sudo", $"chmod +x {installerPath}").WaitForExit();

View File

@ -102,7 +102,7 @@ public class UpdaterWin : IUpdater
await InstallLatestVersion();
}
catch (WebException ex) when ((ex.Response as HttpWebResponse).StatusCode == HttpStatusCode.NotModified)
catch (WebException ex) when (ex.Response is HttpWebResponse response && response.StatusCode == HttpStatusCode.NotModified)
{
_logger.LogInformation("Service Updater: Version is current.");
return;
@ -128,7 +128,6 @@ public class UpdaterWin : IUpdater
_logger.LogInformation("Service Updater: Downloading install package.");
var downloadId = Guid.NewGuid().ToString();
var zipPath = Path.Combine(Path.GetTempPath(), "RemotelyUpdate.zip");
var installerPath = Path.Combine(Path.GetTempPath(), "Remotely_Installer.exe");
@ -139,12 +138,9 @@ public class UpdaterWin : IUpdater
installerPath);
await _updateDownloader.DownloadFile(
$"{serverUrl}/api/AgentUpdate/DownloadPackage/win-{platform}/{downloadId}",
$"{serverUrl}/api/AgentUpdate/DownloadPackage/win-{platform}",
zipPath);
using var httpClient = _httpClientFactory.CreateClient();
using var response = httpClient.GetAsync($"{serverUrl}/api/AgentUpdate/ClearDownload/{downloadId}");
foreach (var proc in Process.GetProcessesByName("Remotely_Installer"))
{
proc.Kill();
@ -170,7 +166,7 @@ public class UpdaterWin : IUpdater
}
}
private async void UpdateTimer_Elapsed(object sender, System.Timers.ElapsedEventArgs e)
private async void UpdateTimer_Elapsed(object? sender, System.Timers.ElapsedEventArgs e)
{
await CheckForUpdates();
}

View File

@ -1,19 +1,15 @@
using Immense.RemoteControl.Server.Abstractions;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.RateLimiting;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.Caching.Memory;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;
using Remotely.Server.Hubs;
using Remotely.Server.RateLimiting;
using Remotely.Server.Services;
using Remotely.Shared.Enums;
using System;
using System.IO;
using System.Linq;
using System.Net;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
namespace Remotely.Server.API;
@ -22,10 +18,6 @@ namespace Remotely.Server.API;
[ApiController]
public class AgentUpdateController : ControllerBase
{
private static readonly MemoryCache _downloadingAgents = new(new MemoryCacheOptions()
{ ExpirationScanFrequency = TimeSpan.FromSeconds(10) });
private readonly IHubContext<AgentHub> _agentHubContext;
private readonly ILogger<AgentUpdateController> _logger;
private readonly IApplicationConfig _appConfig;
@ -45,61 +37,20 @@ public class AgentUpdateController : ControllerBase
_logger = logger;
}
[HttpGet("[action]/{downloadId}")]
public ActionResult ClearDownload(string downloadId)
{
_logger.LogDebug("Clearing download ID {downloadId}.", downloadId);
_downloadingAgents.Remove(downloadId);
return Ok();
}
[HttpGet("[action]/{platform}/{downloadId}")]
public async Task<ActionResult> DownloadPackage(string platform, string downloadId)
[HttpGet("[action]/{platform}")]
[EnableRateLimiting(PolicyNames.AgentUpdateDownloads)]
public async Task<ActionResult> DownloadPackage(string platform)
{
try
{
var remoteIp = Request?.HttpContext?.Connection?.RemoteIpAddress.ToString();
var remoteIp = $"{Request?.HttpContext?.Connection?.RemoteIpAddress}";
if (await CheckForDeviceBan(remoteIp))
{
return BadRequest();
}
var startWait = DateTimeOffset.Now;
while (_downloadingAgents.Count >= _appConfig.MaxConcurrentUpdates)
{
await Task.Delay(new Random().Next(100, 10000));
// A get operation is necessary to evaluate item eviction.
_downloadingAgents.TryGetValue(string.Empty, out _);
}
var entryExpirationTime = TimeSpan.FromMinutes(3);
var tokenExpirationTime = entryExpirationTime.Add(TimeSpan.FromSeconds(15));
var expirationToken = new CancellationChangeToken(
new CancellationTokenSource(tokenExpirationTime).Token);
var cacheOptions = new MemoryCacheEntryOptions()
.SetAbsoluteExpiration(entryExpirationTime)
.AddExpirationToken(expirationToken);
_downloadingAgents.Set(downloadId, string.Empty, cacheOptions);
var waitTime = DateTimeOffset.Now - startWait;
_logger.LogDebug(
"Download started after wait time of {waitTime}. " +
"ID: {downloadId}. " +
"IP: {remoteIp}. " +
"Current Downloads: {_downloadingAgentsCount}. Max Allowed: {_appConfigMaxConcurrentUpdates}",
waitTime,
downloadId,
remoteIp,
_downloadingAgents.Count,
_appConfig.MaxConcurrentUpdates);
string filePath;
switch (platform.ToLower())
@ -133,7 +84,6 @@ public class AgentUpdateController : ControllerBase
}
catch (Exception ex)
{
_downloadingAgents.Remove(downloadId);
_logger.LogError(ex, "Error while downloading package.");
return StatusCode((int)HttpStatusCode.InternalServerError);
}

View File

@ -1,5 +1,8 @@
using Microsoft.AspNetCore.Http;
using Immense.RemoteControl.Shared.Extensions;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc;
using Microsoft.Build.Framework;
using Microsoft.Extensions.Logging;
using Remotely.Server.Services;
using Remotely.Shared.Models;
using System;
@ -14,27 +17,49 @@ namespace Remotely.Server.API;
public class BrandingController : ControllerBase
{
private readonly IDataService _dataService;
private readonly ILogger<BrandingController> _logger;
public BrandingController(IDataService dataService)
public BrandingController(
IDataService dataService,
ILogger<BrandingController> logger)
{
_dataService = dataService;
_logger = logger;
}
[HttpGet("{organizationId}")]
public async Task<BrandingInfo> Get(string organizationId)
public async Task<ActionResult<BrandingInfo>> Get(string organizationId)
{
return await _dataService.GetBrandingInfo(organizationId);
var result = await _dataService.GetBrandingInfo(organizationId);
_logger.LogResult(result);
if (!result.IsSuccess)
{
return NotFound();
}
return result.Value;
}
[HttpGet]
public async Task<BrandingInfo> GetDefault()
{
var defaultOrg = await _dataService.GetDefaultOrganization();
if (defaultOrg is null)
var orgResult = await _dataService.GetDefaultOrganization();
_logger.LogResult(orgResult);
if (!orgResult.IsSuccess)
{
return new BrandingInfo();
return new();
}
return await _dataService.GetBrandingInfo(defaultOrg.ID);
var brandingResult = await _dataService.GetBrandingInfo(orgResult.Value.ID);
_logger.LogResult(brandingResult);
if (!orgResult.IsSuccess ||
brandingResult.Value is null)
{
return new();
}
return brandingResult.Value;
}
}

View File

@ -1,22 +1,16 @@
using MailKit.Search;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Extensions;
using Microsoft.AspNetCore.Mvc;
using Microsoft.Build.Framework;
using Microsoft.Extensions.Logging;
using Remotely.Server.Auth;
using Remotely.Server.Extensions;
using Remotely.Server.Services;
using Remotely.Shared;
using Remotely.Shared.Models;
using Remotely.Shared.Services;
using Remotely.Shared.Utilities;
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
@ -147,7 +141,7 @@ public class ClientDownloadsController : ControllerBase
return File(fileBytes, "application/octet-stream", fileName);
}
private async Task<IActionResult> GetDesktopFile(string filePath, string organizationId = null)
private async Task<IActionResult> GetDesktopFile(string filePath, string? organizationId = null)
{
LogRequest(nameof(GetDesktopFile));
@ -158,7 +152,7 @@ public class ClientDownloadsController : ControllerBase
if (!result.IsSuccess)
{
throw result.Exception;
throw result.Exception ?? new Exception(result.Reason);
}
return File(result.Value, "application/octet-stream", Path.GetFileName(filePath));
@ -187,7 +181,7 @@ public class ClientDownloadsController : ControllerBase
if (!result.IsSuccess)
{
throw result.Exception;
throw result.Exception ?? new Exception(result.Reason);
}
return File(result.Value, "application/octet-stream", "Remotely_Installer.exe");

View File

@ -1,5 +1,7 @@
using Microsoft.AspNetCore.Http.Extensions;
using Immense.RemoteControl.Shared.Extensions;
using Microsoft.AspNetCore.Http.Extensions;
using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.Logging;
using Remotely.Server.Auth;
using Remotely.Server.Extensions;
using Remotely.Server.Services;
@ -8,20 +10,22 @@ using System;
using System.Collections.Generic;
using System.Threading.Tasks;
// For more information on enabling Web API for empty projects, visit https://go.microsoft.com/fwlink/?LinkID=397860
namespace Remotely.Server.API;
[ApiController]
[Route("api/[controller]")]
public class DevicesController : ControllerBase
{
private readonly IDataService _dataService;
private readonly ILogger<DevicesController> _logger;
public DevicesController(IDataService dataService)
public DevicesController(
IDataService dataService,
ILogger<DevicesController> logger)
{
DataService = dataService;
_dataService = dataService;
_logger = logger;
}
private IDataService DataService { get; set; }
[HttpGet]
@ -33,81 +37,103 @@ public class DevicesController : ControllerBase
return Array.Empty<Device>();
}
if (User.Identity?.IsAuthenticated == true &&
!string.IsNullOrWhiteSpace(User.Identity.Name))
if (User.Identity?.IsAuthenticated == true)
{
return DataService.GetDevicesForUser(User.Identity.Name);
return _dataService.GetDevicesForUser($"{User.Identity.Name}");
}
// Authorized with API key. Return all.
return DataService.GetAllDevices(orgId);
return _dataService.GetAllDevices(orgId);
}
[ServiceFilter(typeof(ApiAuthorizationFilter))]
[HttpGet("{id}")]
public ActionResult<Device> Get(string id)
public async Task<ActionResult<Device>> Get(string id)
{
if (!Request.Headers.TryGetOrganizationId(out var orgId))
{
return Unauthorized();
}
var device = DataService.GetDevice(orgId, id);
if (User.Identity?.IsAuthenticated == true &&
!string.IsNullOrWhiteSpace(User.Identity.Name) &&
!DataService.DoesUserHaveAccessToDevice(id, DataService.GetUserByNameWithOrg(User.Identity.Name)))
if (User.Identity?.IsAuthenticated == true)
{
return Unauthorized();
var userResult = await _dataService.GetUserByName($"{User.Identity.Name}");
_logger.LogResult(userResult);
if (!userResult.IsSuccess)
{
return Unauthorized();
}
if (!_dataService.DoesUserHaveAccessToDevice(id, userResult.Value))
{
return Unauthorized();
}
}
return device;
var deviceResult = await _dataService.GetDevice(orgId, id);
_logger.LogResult(deviceResult);
if (!deviceResult.IsSuccess)
{
return NotFound();
}
return deviceResult.Value;
}
[HttpPut]
[ServiceFilter(typeof(ApiAuthorizationFilter))]
public async Task<IActionResult> Update(
[FromBody] DeviceSetupOptions deviceOptions,
[FromHeader] string organizationId)
public async Task<IActionResult> Update([FromBody] DeviceSetupOptions deviceOptions)
{
if (string.IsNullOrWhiteSpace(deviceOptions?.DeviceID) ||
string.IsNullOrWhiteSpace(organizationId))
{
return BadRequest("DeviceOptions and OrganizationId are required.");
}
if (string.IsNullOrWhiteSpace(User.Identity?.Name))
if (!Request.Headers.TryGetOrganizationId(out var orgId))
{
return Unauthorized();
}
var user = DataService.GetUserByNameWithOrg(User.Identity.Name);
if (user is null)
if (string.IsNullOrWhiteSpace(deviceOptions?.DeviceID))
{
return Unauthorized();
return BadRequest("DeviceId is required.");
}
if (User.Identity?.IsAuthenticated == true &&
!DataService.DoesUserHaveAccessToDevice(deviceOptions.DeviceID, user))
if (User.Identity?.IsAuthenticated == true)
{
return Unauthorized();
var userResult = await _dataService.GetUserByName($"{User.Identity.Name}");
_logger.LogResult(userResult);
if (!userResult.IsSuccess)
{
return Unauthorized();
}
if (!_dataService.DoesUserHaveAccessToDevice(deviceOptions.DeviceID, userResult.Value))
{
return Unauthorized();
}
}
var device = await DataService.UpdateDevice(deviceOptions, organizationId);
if (device is null)
var deviceResult = await _dataService.UpdateDevice(deviceOptions, orgId);
_logger.LogResult(deviceResult);
if (!deviceResult.IsSuccess)
{
return BadRequest();
}
return Created(Request.GetDisplayUrl(), device);
return Created(Request.GetDisplayUrl(), deviceResult.Value);
}
[HttpPost]
public async Task<IActionResult> Create([FromBody] DeviceSetupOptions deviceOptions)
{
var device = await DataService.CreateDevice(deviceOptions);
if (device is null)
var result = await _dataService.CreateDevice(deviceOptions);
_logger.LogResult(result);
if (!result.IsSuccess)
{
return BadRequest("Device already exists. Use Put with authorization to update the device.");
}
return Created(Request.GetDisplayUrl(), device);
return Created(Request.GetDisplayUrl(), result.Value);
}
}

View File

@ -1,7 +1,9 @@
using Microsoft.AspNetCore.Mvc;
using Remotely.Server.Auth;
using Remotely.Server.Extensions;
using Remotely.Server.Services;
using Remotely.Shared;
using Remotely.Shared.Models;
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
@ -21,14 +23,18 @@ public class FileSharingController : ControllerBase
[HttpGet("{id}")]
[ServiceFilter(typeof(ExpiringTokenFilter))]
public ActionResult Get(string id)
public async Task<IActionResult> Get(string id)
{
var sharedFile = _dataService.GetSharedFiled(id);
if (sharedFile != null)
var sharedFileResult = await _dataService.GetSharedFiled(id);
if (!sharedFileResult.IsSuccess)
{
return File(sharedFile.FileContents, sharedFile.ContentType, sharedFile.FileName);
return NotFound();
}
return NotFound();
var sharedFile = sharedFileResult.Value;
var contentType = sharedFile.ContentType ?? "application/octet-stream";
return File(sharedFile.FileContents, contentType, sharedFile.FileName);
}
[HttpPost]
@ -36,18 +42,23 @@ public class FileSharingController : ControllerBase
[RequestSizeLimit(AppConstants.MaxUploadFileSize)]
public async Task<IEnumerable<string>> Post()
{
if (Request?.Form?.Files?.Count !> 0)
if (Request.Form.Files.Count !> 0)
{
return Array.Empty<string>();
}
var fileIDs = new List<string>();
var fileIds = new List<string>();
if (!Request.Headers.TryGetOrganizationId(out var orgId))
{
orgId = string.Empty;
}
foreach (var file in Request.Form.Files)
{
var orgID = User.Identity.IsAuthenticated ? _dataService.GetUserByNameWithOrg(User.Identity.Name).OrganizationID : null;
var id = await _dataService.AddSharedFile(file, orgID);
fileIDs.Add(id);
var id = await _dataService.AddSharedFile(file, orgId);
fileIds.Add(id);
}
return fileIDs;
return fileIds;
}
}

View File

@ -26,7 +26,7 @@ public class HealthCheckController : ControllerBase
[HttpGet]
public async Task<IActionResult> Get()
{
var orgCount = await _dataService.GetOrganizationCountAsync();
return Ok($"Organization Count: {orgCount}");
_ = await _dataService.GetOrganizationCountAsync();
return NoContent();
}
}

View File

@ -49,11 +49,15 @@ public class LoginController : ControllerBase
[HttpGet("Logout")]
public async Task<IActionResult> Logout()
{
string orgId = null;
if (HttpContext?.User?.Identity?.IsAuthenticated == true)
{
orgId = _dataService.GetUserByNameWithOrg(HttpContext.User.Identity.Name)?.OrganizationID;
var userResult = await _dataService.GetUserByName($"{HttpContext.User.Identity.Name}");
if (!userResult.IsSuccess)
{
return NotFound();
}
var activeSessions = _remoteControlSessionCache
.Sessions
.Where(x => x.RequesterUserName == HttpContext.User.Identity.Name);
@ -77,8 +81,6 @@ public class LoginController : ControllerBase
return NotFound();
}
var orgId = _dataService.GetUserByNameWithOrg(login.Email)?.OrganizationID;
var result = await _signInManager.PasswordSignInAsync(login.Email, login.Password, false, true);
if (result.Succeeded)
{

View File

@ -54,14 +54,14 @@ public class OrganizationManagementController : ControllerBase
if (User.Identity?.IsAuthenticated == true)
{
var userResult = await _dataService.GetUserByNameWithOrg($"{User.Identity.Name}");
var userResult = await _dataService.GetUserByName($"{User.Identity.Name}");
if (userResult.IsSuccess && userResult.Value.Id == userId)
{
return BadRequest("You can't remove administrator rights from yourself.");
}
}
_dataService.ChangeUserIsAdmin(orgId, userId, isAdmin);
await _dataService.ChangeUserIsAdmin(orgId, userId, isAdmin);
return NoContent();
}
@ -96,7 +96,7 @@ public class OrganizationManagementController : ControllerBase
if (User.Identity?.IsAuthenticated == true)
{
var userResult = await _dataService.GetUserByNameWithOrg($"{User.Identity.Name}");
var userResult = await _dataService.GetUserByName($"{User.Identity.Name}");
if (userResult.IsSuccess && userResult.Value.Id == userId)
{
return BadRequest("You can't delete yourself here. You must go to the Personal Data page to delete your own account.");

View File

@ -1,21 +1,17 @@
using Microsoft.AspNetCore.Identity;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.SignalR;
using Remotely.Server.Attributes;
using Remotely.Server.Hubs;
using Remotely.Server.Models;
using Remotely.Server.Services;
using Remotely.Shared.Utilities;
using Remotely.Shared.Models;
using System;
using System.Linq;
using System.Threading.Tasks;
using Remotely.Server.Auth;
using Immense.RemoteControl.Server.Services;
using Remotely.Server.Services.RcImplementations;
using Immense.RemoteControl.Server.Abstractions;
using Immense.RemoteControl.Shared.Helpers;
using Microsoft.Build.Framework;
using Microsoft.Extensions.Logging;
using Remotely.Server.Extensions;
@ -79,7 +75,7 @@ public class RemoteControlController : ControllerBase
return NotFound();
}
var userResult = await _dataService.GetUserByNameWithOrg(rcRequest.Email);
var userResult = await _dataService.GetUserByName(rcRequest.Email);
if (!userResult.IsSuccess)
{
return NotFound();
@ -123,7 +119,7 @@ public class RemoteControlController : ControllerBase
if (User.Identity?.IsAuthenticated == true)
{
var userResult = await _dataService.GetUserByNameWithOrg($"{User.Identity.Name}");
var userResult = await _dataService.GetUserByName($"{User.Identity.Name}");
if (!userResult.IsSuccess)
{

View File

@ -23,8 +23,14 @@ public class SavedScriptsController : ControllerBase
[ServiceFilter(typeof(ExpiringTokenFilter))]
[HttpGet("{scriptId}")]
public async Task<SavedScript> GetScript(Guid scriptId)
public async Task<ActionResult<SavedScript>> GetScript(Guid scriptId)
{
return await _dataService.GetSavedScript(scriptId);
var result = await _dataService.GetSavedScript(scriptId);
if (!result.IsSuccess)
{
return NotFound();
}
return result.Value;
}
}

View File

@ -8,8 +8,6 @@ using System;
using System.Text;
using System.Threading.Tasks;
// For more information on enabling Web API for empty projects, visit https://go.microsoft.com/fwlink/?LinkID=397860
namespace Remotely.Server.API;
[Route("api/[controller]")]
@ -25,8 +23,6 @@ public class ScriptResultsController : ControllerBase
_emailSender = emailSenderEx;
}
// GET: api/<controller>
[HttpGet]
[ServiceFilter(typeof(ApiAuthorizationFilter))]
public ActionResult DownloadAll()

View File

@ -6,15 +6,13 @@ using Remotely.Server.Services;
using Remotely.Shared.Utilities;
using Remotely.Shared.Models;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Remotely.Shared.Enums;
using Remotely.Server.Auth;
using Immense.RemoteControl.Server.Abstractions;
using Immense.RemoteControl.Shared.Helpers;
using Remotely.Shared;
using Remotely.Server.Extensions;
namespace Remotely.Server.API;
@ -47,6 +45,11 @@ public class ScriptingController : ControllerBase
[HttpPost("[action]/{mode}/{deviceID}")]
public async Task<ActionResult<ScriptResult>> ExecuteCommand(string mode, string deviceID)
{
if (!Request.Headers.TryGetOrganizationId(out var orgId))
{
return Unauthorized();
}
if (!Enum.TryParse<ScriptingShell>(mode, true, out var shell))
{
return BadRequest("Unable to parse shell type. Use either PSCore, WinPS, Bash, or CMD.");
@ -59,20 +62,22 @@ public class ScriptingController : ControllerBase
}
var userID = string.Empty;
if (Request.HttpContext.User.Identity.IsAuthenticated)
if (Request.HttpContext.User.Identity?.IsAuthenticated == true)
{
var username = Request.HttpContext.User.Identity.Name;
var user = await _userManager.FindByNameAsync(username);
userID = user.Id;
if (!_dataService.DoesUserHaveAccessToDevice(deviceID, user))
var userResult = await _dataService.GetUserByName($"{username}");
if (!userResult.IsSuccess)
{
return Unauthorized();
}
if (!_dataService.DoesUserHaveAccessToDevice(deviceID, userResult.Value))
{
return Unauthorized();
}
}
Request.Headers.TryGetValue("OrganizationID", out var orgID);
if (!_serviceSessionCache.TryGetByDeviceId(deviceID, out var device))
{
return NotFound();
@ -83,7 +88,7 @@ public class ScriptingController : ControllerBase
return NotFound();
}
if (device.OrganizationID != orgID)
if (device.OrganizationID != orgId)
{
return Unauthorized();
}
@ -99,9 +104,14 @@ public class ScriptingController : ControllerBase
{
return NotFound();
}
AgentHub.ApiScriptResults.TryGetValue(requestID, out var commandID);
AgentHub.ApiScriptResults.TryGetValue(requestID, out var commandId);
AgentHub.ApiScriptResults.Remove(requestID);
var result = _dataService.GetScriptResult(commandID.ToString(), orgID);
return result;
var scriptResult = await _dataService.GetScriptResult($"{commandId}", orgId);
if (!scriptResult.IsSuccess)
{
return NotFound();
}
return scriptResult.Value;
}
}

View File

@ -34,9 +34,15 @@ public class ServerLogsController : ControllerBase
HttpContext.Connection.RemoteIpAddress);
var zipFile = await _logsManager.ZipAllLogs();
Response.OnCompleted(() =>
{
Directory.Delete(zipFile.DirectoryName, true);
if (zipFile.Directory is null)
{
return Task.CompletedTask;
}
zipFile.Directory.Delete(true);
return Task.CompletedTask;
});

View File

@ -1,31 +0,0 @@
using Microsoft.AspNetCore.Mvc.Filters;
using Microsoft.Extensions.Caching.Memory;
using System;
using System.Net;
namespace Remotely.Server.Attributes;
[AttributeUsage(AttributeTargets.Method)]
public class ActionRateLimiterAttribute : ActionFilterAttribute
{
public string Action { get; set; }
public int TimeoutInSeconds { get; set; } = 5;
private static MemoryCache RequestCache { get; } = new MemoryCache(new MemoryCacheOptions());
public override void OnActionExecuting(ActionExecutingContext context)
{
var ip = context.HttpContext.Request.HttpContext.Connection.RemoteIpAddress;
var key = $"Action-{ip}";
if (!RequestCache.TryGetValue(key, out _))
{
RequestCache.Set(key, true, TimeSpan.FromSeconds(TimeoutInSeconds));
}
else
{
context.HttpContext.Response.StatusCode = (int)HttpStatusCode.TooManyRequests;
}
base.OnActionExecuting(context);
}
}

View File

@ -41,9 +41,19 @@ public class ExpiringTokenFilter : ActionFilterAttribute, IAsyncAuthorizationFil
private async Task Authorize(AuthorizationFilterContext context)
{
var http = context.HttpContext;
http.Request.Headers["OrganizationID"] = string.Empty;
if (http.User.Identity?.IsAuthenticated == true)
{
var userResult = await _dataService.GetUserByName($"{http.User.Identity.Name}");
if (!userResult.IsSuccess)
{
http.Response.StatusCode = (int)HttpStatusCode.Forbidden;
context.Result = new UnauthorizedResult();
return;
}
http.Request.Headers["OrganizationID"] = userResult.Value.OrganizationID;
return;
}

View File

@ -66,7 +66,8 @@ public class AppDb : IdentityDbContext
.WithOne(x => x.Organization);
builder.Entity<Organization>()
.HasMany(x => x.SharedFiles)
.WithOne(x => x.Organization);
.WithOne(x => x.Organization)
.IsRequired(false);
builder.Entity<Organization>()
.HasMany(x => x.ApiTokens)
.WithOne(x => x.Organization);

View File

@ -186,7 +186,7 @@
if (_isAuthenticated)
{
var currentUser = await DataService.GetUserAsync(authState.User.Identity.Name);
var currentUser = await DataService.GetUserByName(authState.User.Identity.Name);
_organizationId = currentUser.OrganizationID;
}
else

View File

@ -279,7 +279,7 @@ public partial class ManageOrganization : AuthComponentBase
var result = await DataService.CreateUser(_inviteEmail, _inviteAsAdmin, User.OrganizationID);
if (result)
{
var user = await DataService.GetUserAsync(_inviteEmail);
var user = await DataService.GetUserByName(_inviteEmail);
await UserManager.ConfirmEmailAsync(user, await UserManager.GenerateEmailConfirmationTokenAsync(user));
@ -337,7 +337,7 @@ public partial class ManageOrganization : AuthComponentBase
}
}
private void SetUserIsAdmin(ChangeEventArgs args, RemotelyUser orgUser)
private async Task SetUserIsAdmin(ChangeEventArgs args, RemotelyUser orgUser)
{
if (!User.IsAdministrator)
{
@ -349,7 +349,7 @@ public partial class ManageOrganization : AuthComponentBase
return;
}
DataService.ChangeUserIsAdmin(User.OrganizationID, orgUser.Id, isAdmin);
await DataService.ChangeUserIsAdmin(User.OrganizationID, orgUser.Id, isAdmin);
ToastService.ShowToast("Administrator value set.");
}

View File

@ -37,6 +37,8 @@ using System;
using Immense.RemoteControl.Server.Services;
using Serilog;
using Nihs.SimpleMessenger;
using Microsoft.AspNetCore.RateLimiting;
using RatePolicyNames = Remotely.Server.RateLimiting.PolicyNames;
var builder = WebApplication.CreateBuilder(args);
var configuration = builder.Configuration;
@ -53,7 +55,12 @@ if (OperatingSystem.IsWindows() &&
builder.Logging.AddEventLog();
}
var dbProvider = configuration["ApplicationOptions:DBProvider"].ToLower();
var dbProvider = configuration["ApplicationOptions:DBProvider"]?.ToLower();
if (string.IsNullOrWhiteSpace(dbProvider))
{
throw new InvalidOperationException("DBProvider is missing from appsettings.json.");
}
if (dbProvider == "sqlite")
{
services.AddDbContext<AppDb, SqliteDbContext>(options =>
@ -191,7 +198,21 @@ services.AddSwaggerGen(c =>
c.SwaggerDoc("v1", new OpenApiInfo { Title = "Remotely API", Version = "v1" });
});
services.AddRateLimiter(options =>
{
options.AddConcurrencyLimiter(RatePolicyNames.AgentUpdateDownloads, clOptions =>
{
clOptions.QueueLimit = int.MaxValue;
var concurrentPermits = configuration.GetSection("ApplicationOptions:MaxConcurrentUpdates").Get<int>();
if (concurrentPermits <= 0)
{
concurrentPermits = 10;
}
clOptions.PermitLimit = concurrentPermits;
});
});
services.AddHttpClient();
services.AddLogging();
services.AddScoped<IEmailSenderEx, EmailSenderEx>();
@ -233,6 +254,9 @@ services.AddRemoteControlServer(config =>
services.AddSingleton<IAgentHubSessionCache, AgentHubSessionCache>();
var app = builder.Build();
app.UseRateLimiter();
var appConfig = app.Services.GetRequiredService<IApplicationConfig>();
if (appConfig.UseHttpLogging)
@ -250,11 +274,11 @@ if (app.Environment.IsDevelopment())
else
{
app.UseExceptionHandler("/Error");
if (bool.Parse(app.Configuration["ApplicationOptions:UseHsts"]))
if (bool.TryParse(app.Configuration["ApplicationOptions:UseHsts"], out var hsts) && hsts)
{
app.UseHsts();
}
if (bool.Parse(app.Configuration["ApplicationOptions:RedirectToHttps"]))
if (bool.TryParse(app.Configuration["ApplicationOptions:RedirectToHttps"], out var redirect) && redirect)
{
app.UseHttpsRedirection();
}

View File

@ -0,0 +1,6 @@
namespace Remotely.Server.RateLimiting;
public static class PolicyNames
{
public const string AgentUpdateDownloads = nameof(AgentUpdateDownloads);
}

View File

@ -42,7 +42,7 @@ public class AuthService : IAuthService
if (principal?.User?.Identity?.IsAuthenticated == true)
{
return await _dataService.GetUserAsync($"{principal.User.Identity.Name}");
return await _dataService.GetUserByName($"{principal.User.Identity.Name}");
}
return Result.Fail<RemotelyUser>("Not authenticated.");

View File

@ -50,7 +50,7 @@ public interface IDataService
bool AddUserToDeviceGroup(string orgId, string groupId, string userName, out string resultMessage);
void ChangeUserIsAdmin(string organizationId, string targetUserId, bool isAdmin);
Task ChangeUserIsAdmin(string organizationId, string targetUserId, bool isAdmin);
Task CleanupOldRecords();
@ -174,11 +174,11 @@ public interface IDataService
int GetTotalDevices();
Task<Result<RemotelyUser>> GetUserAsync(string username);
Task<Result<RemotelyUser>> GetUserById(string userId);
Task<Result<RemotelyUser>> GetUserByNameWithOrg(string userName);
Task<Result<RemotelyUser>> GetUserByName(
string userName,
Action<IQueryable<RemotelyUser>>? includesBuilder = null);
Task<Result<RemotelyUserOptions>> GetUserOptions(string userName);
@ -562,7 +562,7 @@ public class DataService : IDataService
progressCallback.Invoke(1, file.Name);
return await AddSharedFileInternal(file.Name, fileContents, file.ContentType, organizationId);
return await AddSharedFileImpl(file.Name, fileContents, file.ContentType, organizationId);
}
public async Task<string> AddSharedFile(IFormFile file, string organizationId)
@ -571,7 +571,7 @@ public class DataService : IDataService
using var stream = file.OpenReadStream();
await stream.ReadAsync(fileContents.AsMemory(0, (int)file.Length));
return await AddSharedFileInternal(file.Name, fileContents, file.ContentType, organizationId);
return await AddSharedFileImpl(file.Name, fileContents, file.ContentType, organizationId);
}
public bool AddUserToDeviceGroup(string orgId, string groupId, string userName, out string resultMessage)
@ -622,11 +622,11 @@ public class DataService : IDataService
return true;
}
public void ChangeUserIsAdmin(string organizationId, string targetUserId, bool isAdmin)
public async Task ChangeUserIsAdmin(string organizationId, string targetUserId, bool isAdmin)
{
using var dbContext = _appDbFactory.GetContext();
var targetUser = dbContext.Users.FirstOrDefault(x =>
var targetUser = await dbContext.Users.FirstOrDefaultAsync(x =>
x.OrganizationID == organizationId &&
x.Id == targetUserId);
@ -795,7 +795,7 @@ public class DataService : IDataService
if (!string.IsNullOrWhiteSpace(userName))
{
var userResult = await GetUserByNameWithOrg(userName);
var userResult = await GetUserByName(userName);
if (userResult.IsSuccess)
{
@ -1277,10 +1277,7 @@ public class DataService : IDataService
using var dbContext = _appDbFactory.GetContext();
var query = dbContext.Devices.AsQueryable();
if (includesBuilder is not null)
{
includesBuilder(query);
}
includesBuilder?.Invoke(query);
var device = await query.FirstOrDefaultAsync(x => x.ID == deviceId);
if (device is null)
@ -1683,23 +1680,6 @@ public class DataService : IDataService
return dbContext.Devices.Count();
}
public async Task<Result<RemotelyUser>> GetUserAsync(string username)
{
if (string.IsNullOrWhiteSpace(username))
{
return Result.Fail<RemotelyUser>("Username cannot be empty.");
}
using var dbContext = _appDbFactory.GetContext();
var user = await dbContext.Users.FirstOrDefaultAsync(x => x.UserName == username);
if (user is null)
{
return Result.Fail<RemotelyUser>("User not found.");
}
return Result.Ok(user);
}
public async Task<Result<RemotelyUser>> GetUserById(string userId)
{
if (string.IsNullOrWhiteSpace(userId))
@ -1717,7 +1697,9 @@ public class DataService : IDataService
return Result.Ok(user);
}
public async Task<Result<RemotelyUser>> GetUserByNameWithOrg(string userName)
public async Task<Result<RemotelyUser>> GetUserByName(
string userName,
Action<IQueryable<RemotelyUser>>? includesBuilder = null)
{
if (string.IsNullOrWhiteSpace(userName))
{
@ -1726,9 +1708,11 @@ public class DataService : IDataService
using var dbContext = _appDbFactory.GetContext();
var user = await dbContext.Users
.Include(x => x.Organization)
.FirstOrDefaultAsync(x => x.UserName!.ToLower().Trim() == userName.ToLower().Trim());
var query = dbContext.Users.AsQueryable();
includesBuilder?.Invoke(query);
var user = await query.FirstOrDefaultAsync(x =>
x.UserName!.ToLower().Trim() == userName.ToLower().Trim());
if (user is null)
{
@ -1966,7 +1950,7 @@ public class DataService : IDataService
return false;
}
var userResult = await GetUserByNameWithOrg(email);
var userResult = await GetUserByName(email);
if (!userResult.IsSuccess)
{
@ -2152,7 +2136,7 @@ public class DataService : IDataService
return isValid;
}
private async Task<string> AddSharedFileInternal(
private async Task<string> AddSharedFileImpl(
string fileName,
byte[] fileContents,
string contentType,

View File

@ -14,7 +14,7 @@ public class EmbeddedServerData
{
[SerializationConstructor]
[JsonConstructor]
public EmbeddedServerData(Uri serverUrl, string organizationId)
public EmbeddedServerData(Uri serverUrl, string? organizationId)
{
ServerUrl = serverUrl;
OrganizationId = organizationId ?? string.Empty;

View File

@ -14,5 +14,5 @@ public class SharedFile
public byte[] FileContents { get; set; } = Array.Empty<byte>();
public DateTimeOffset Timestamp { get; set; } = DateTimeOffset.Now;
public Organization? Organization { get; set; }
public string OrganizationID { get; set; } = null!;
public string? OrganizationID { get; set; }
}