From de9fc42ea26016fa458ba4e40fef71a79b6e0a2e Mon Sep 17 00:00:00 2001 From: Jared Goodwin Date: Wed, 5 Feb 2020 08:24:59 -0800 Subject: [PATCH] Added ScriptingController logic. --- Agent/Services/CommandExecutor.cs | 11 ++-- Agent/Services/DeviceSocket.cs | 6 +- Server/API/ScriptingController.cs | 92 +++++++++++++++++++++++------- Server/Services/DataService.cs | 6 ++ Server/Services/DeviceSocketHub.cs | 12 ++-- Shared/Helpers/TaskHelper.cs | 21 +++++++ 6 files changed, 112 insertions(+), 36 deletions(-) create mode 100644 Shared/Helpers/TaskHelper.cs diff --git a/Agent/Services/CommandExecutor.cs b/Agent/Services/CommandExecutor.cs index d743c527..0b905f88 100644 --- a/Agent/Services/CommandExecutor.cs +++ b/Agent/Services/CommandExecutor.cs @@ -99,35 +99,35 @@ namespace Remotely.Agent.Services } } - public async Task ExecuteCommandFromApi(string mode, string requestID, string command, string commandID, string senderConnectionID, HubConnection hubConnection) + public async Task ExecuteCommandFromApi(string mode, string requestID, string command, string commandID, string senderUserName, HubConnection hubConnection) { try { switch (mode.ToLower()) { case "pscore": - var psCoreResult = PSCore.GetCurrent(senderConnectionID).WriteInput(command, commandID); + var psCoreResult = PSCore.GetCurrent(senderUserName).WriteInput(command, commandID); await SendResultsViaAjax("PSCore", psCoreResult); break; case "winps": if (OSUtils.IsWindows) { - var result = WindowsPS.GetCurrent(senderConnectionID).WriteInput(command, commandID); + var result = WindowsPS.GetCurrent(senderUserName).WriteInput(command, commandID); await SendResultsViaAjax("WinPS", result); } break; case "cmd": if (OSUtils.IsWindows) { - var result = CMD.GetCurrent(senderConnectionID).WriteInput(command, commandID); + var result = CMD.GetCurrent(senderUserName).WriteInput(command, commandID); await SendResultsViaAjax("CMD", result); } break; case "bash": if (OSUtils.IsLinux) { - var result = Bash.GetCurrent(senderConnectionID).WriteInput(command, commandID); + var result = Bash.GetCurrent(senderUserName).WriteInput(command, commandID); await SendResultsViaAjax("Bash", result); } break; @@ -140,7 +140,6 @@ namespace Remotely.Agent.Services catch (Exception ex) { Logger.Write(ex); - await hubConnection.InvokeAsync("DisplayMessage", "There was an error executing the command. It has been logged on the client device.", "Error executing command.", senderConnectionID); } } private async Task SendResultsViaAjax(string resultType, object result) diff --git a/Agent/Services/DeviceSocket.cs b/Agent/Services/DeviceSocket.cs index 8ef5955d..d63c1771 100644 --- a/Agent/Services/DeviceSocket.cs +++ b/Agent/Services/DeviceSocket.cs @@ -114,16 +114,16 @@ namespace Remotely.Agent.Services await CommandExecutor.ExecuteCommand(mode, command, commandID, senderConnectionID, HubConnection); })); - HubConnection.On("ExecuteCommandFromApi", (async (string mode, string requesterID, string command, string commandID, string senderConnectionID) => + HubConnection.On("ExecuteCommandFromApi", (async (string mode, string requestID, string command, string commandID, string senderUserName) => { if (!IsServerVerified) { - Logger.Write($"Command attempted before server was verified. Mode: {mode}. Command: {command}. Sender: {senderConnectionID}"); + Logger.Write($"Command attempted before server was verified. Mode: {mode}. Command: {command}. Sender: {senderUserName}"); Uninstaller.UninstallAgent(); return; } - await CommandExecutor.ExecuteCommandFromApi(mode, requesterID, command, commandID, senderConnectionID, HubConnection); + await CommandExecutor.ExecuteCommandFromApi(mode, requestID, command, commandID, senderUserName, HubConnection); })); HubConnection.On("TransferFiles", async (string transferID, List fileIDs, string requesterID) => { diff --git a/Server/API/ScriptingController.cs b/Server/API/ScriptingController.cs index be9410fa..83d5b443 100644 --- a/Server/API/ScriptingController.cs +++ b/Server/API/ScriptingController.cs @@ -1,36 +1,84 @@ -using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Identity; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.SignalR; +using Remotely.Server.Services; +using Remotely.Shared.Models; using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; +using Remotely.Shared.Helpers; +using Microsoft.AspNetCore.Http; +using System.IO; namespace Remotely.Server.API { [ApiController] + [Route("api/[controller]")] public class ScriptingController : ControllerBase { - //public Task ExecuteCommand(string mode, string command, string[] deviceIDs) - //{ - // deviceIDs = DataService.FilterDeviceIDsByUserPermission(deviceIDs, RemotelyUser); - // var connections = GetActiveClientConnections(deviceIDs); + public ScriptingController(DataService dataService, + UserManager userManager, + IHubContext deviceHub) + { + DataService = dataService; + UserManager = userManager; + DeviceHub = deviceHub; + } - // var commandContext = new CommandContext() - // { - // CommandMode = mode, - // CommandText = command, - // SenderConnectionID = Context.ConnectionId, - // SenderUserID = Context.UserIdentifier, - // TargetDeviceIDs = connections.Select(x => x.Value.ID).ToArray(), - // OrganizationID = RemotelyUser.OrganizationID - // }; - // DataService.AddOrUpdateCommandContext(commandContext); - // Clients.Caller.SendAsync("CommandContextCreated", commandContext); - // foreach (var connection in connections) - // { - // DeviceHub.Clients.Client(connection.Key).SendAsync("ExecuteCommand", mode, command, commandContext.ID, Context.ConnectionId); - // } + private DataService DataService { get; } + private IHubContext DeviceHub { get; } + private UserManager UserManager { get; } - // return Task.CompletedTask; - //} + [Authorize] + [HttpPost("[action]/{mode}/{deviceID}")] + public async Task> ExecuteCommand(string mode, string deviceID) + { + var command = string.Empty; + using (var sr = new StreamReader(Request.Body)) + { + command = await sr.ReadToEndAsync(); + } + var username = Request.HttpContext.User.Identity.Name; + var user = await UserManager.FindByNameAsync(username); + if (!DataService.DoesUserHaveAccessToDevice(deviceID, user)) + { + return Unauthorized(); + } + + + KeyValuePair connection = DeviceSocketHub.ServiceConnections.FirstOrDefault(x => + x.Value.OrganizationID == user.OrganizationID && + x.Value.ID == deviceID); + + if (string.IsNullOrWhiteSpace(connection.Key)) + { + return NotFound(); + } + + var commandContext = new CommandContext() + { + CommandMode = "PSCore", + CommandText = command, + SenderConnectionID = string.Empty, + SenderUserID = user.Id, + TargetDeviceIDs = new string[] { deviceID }, + OrganizationID = user.OrganizationID + }; + DataService.AddOrUpdateCommandContext(commandContext); + var requestID = Guid.NewGuid().ToString(); + await DeviceHub.Clients.Client(connection.Key).SendAsync("ExecuteCommandFromApi", mode, requestID, command, commandContext.ID, username); + var success = await TaskHelper.DelayUntil(() => DeviceSocketHub.ApiScriptResults.TryGetValue(requestID, out _), TimeSpan.FromSeconds(30)); + if (!success) + { + return commandContext; + } + DeviceSocketHub.ApiScriptResults.TryGetValue(requestID, out var commandID); + DeviceSocketHub.ApiScriptResults.Remove(requestID); + DataService.DetachEntity(commandContext); + var result = DataService.GetCommandContext(commandID.ToString(), username); + return result; + } } } diff --git a/Server/Services/DataService.cs b/Server/Services/DataService.cs index f0171fdf..f1c8aae5 100644 --- a/Server/Services/DataService.cs +++ b/Server/Services/DataService.cs @@ -227,6 +227,12 @@ namespace Remotely.Server.Services RemotelyContext.SaveChanges(); } + public void DetachEntity(object entity) + { + RemotelyContext.Entry(entity).State = EntityState.Detached; + } + + public void DeviceDisconnected(string deviceID) { var device = RemotelyContext.Devices.Find(deviceID); diff --git a/Server/Services/DeviceSocketHub.cs b/Server/Services/DeviceSocketHub.cs index 1a777f33..5bb98609 100644 --- a/Server/Services/DeviceSocketHub.cs +++ b/Server/Services/DeviceSocketHub.cs @@ -9,13 +9,15 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Internal; namespace Remotely.Server.Services { public class DeviceSocketHub : Hub { - public DeviceSocketHub(DataService dataService, - IHubContext browserHub, + public DeviceSocketHub(DataService dataService, + IHubContext browserHub, IHubContext rcBrowserHub) { DataService = dataService; @@ -23,8 +25,8 @@ namespace Remotely.Server.Services RCBrowserHub = rcBrowserHub; } - public static ConcurrentDictionary ServiceConnections { get; } = new ConcurrentDictionary(); - public static ConcurrentDictionary ApiScriptResults { get; } = new ConcurrentDictionary(); + public static ConcurrentDictionary ServiceConnections { get; } = new ConcurrentDictionary(); + public static IMemoryCache ApiScriptResults { get; } = new MemoryCache(new MemoryCacheOptions()); public IHubContext RCBrowserHub { get; } private IHubContext BrowserHub { get; } private DataService DataService { get; } @@ -68,7 +70,7 @@ namespace Remotely.Server.Services public void CommandResultViaApi(string commandID, string requestID) { - ApiScriptResults.AddOrUpdate(requestID, commandID, (k, v) => commandID); + ApiScriptResults.Set(requestID, commandID, DateTimeOffset.Now.AddHours(1)); } public Task DeviceCameOnline(Device device) diff --git a/Shared/Helpers/TaskHelper.cs b/Shared/Helpers/TaskHelper.cs new file mode 100644 index 00000000..e32db1a8 --- /dev/null +++ b/Shared/Helpers/TaskHelper.cs @@ -0,0 +1,21 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using System.Threading.Tasks; + +namespace Remotely.Shared.Helpers +{ + public static class TaskHelper + { + public static async Task DelayUntil(Func condition, TimeSpan timeout, int pollingMs = 10) + { + var sw = Stopwatch.StartNew(); + while (!condition() && sw.Elapsed < timeout) + { + await Task.Delay(pollingMs); + } + return condition(); + } + } +}