diff --git a/Agent/Services/CommandExecutor.cs b/Agent/Services/CommandExecutor.cs index 3ee1a52a..0b905f88 100644 --- a/Agent/Services/CommandExecutor.cs +++ b/Agent/Services/CommandExecutor.cs @@ -40,7 +40,6 @@ namespace Remotely.Agent.Services } break; } - case "winps": if (OSUtils.IsWindows) { @@ -99,6 +98,50 @@ namespace Remotely.Agent.Services await hubConnection.InvokeAsync("DisplayMessage", "There was an error executing the command. It has been logged on the client device.", "Error executing command.", senderConnectionID); } } + + public async Task ExecuteCommandFromApi(string mode, string requestID, string command, string commandID, string senderUserName, HubConnection hubConnection) + { + try + { + switch (mode.ToLower()) + { + case "pscore": + var psCoreResult = PSCore.GetCurrent(senderUserName).WriteInput(command, commandID); + await SendResultsViaAjax("PSCore", psCoreResult); + break; + + case "winps": + if (OSUtils.IsWindows) + { + var result = WindowsPS.GetCurrent(senderUserName).WriteInput(command, commandID); + await SendResultsViaAjax("WinPS", result); + } + break; + case "cmd": + if (OSUtils.IsWindows) + { + var result = CMD.GetCurrent(senderUserName).WriteInput(command, commandID); + await SendResultsViaAjax("CMD", result); + } + break; + case "bash": + if (OSUtils.IsLinux) + { + var result = Bash.GetCurrent(senderUserName).WriteInput(command, commandID); + await SendResultsViaAjax("Bash", result); + } + break; + default: + break; + } + + await hubConnection.InvokeAsync("CommandResultViaApi", commandID, requestID); + } + catch (Exception ex) + { + Logger.Write(ex); + } + } private async Task SendResultsViaAjax(string resultType, object result) { var targetURL = ConfigService.GetConnectionInfo().Host + $"/API/Commands/{resultType}"; diff --git a/Agent/Services/DeviceSocket.cs b/Agent/Services/DeviceSocket.cs index aa16824b..d63c1771 100644 --- a/Agent/Services/DeviceSocket.cs +++ b/Agent/Services/DeviceSocket.cs @@ -114,6 +114,17 @@ namespace Remotely.Agent.Services await CommandExecutor.ExecuteCommand(mode, command, commandID, senderConnectionID, HubConnection); })); + HubConnection.On("ExecuteCommandFromApi", (async (string mode, string requestID, string command, string commandID, string senderUserName) => + { + if (!IsServerVerified) + { + Logger.Write($"Command attempted before server was verified. Mode: {mode}. Command: {command}. Sender: {senderUserName}"); + Uninstaller.UninstallAgent(); + return; + } + + await CommandExecutor.ExecuteCommandFromApi(mode, requestID, command, commandID, senderUserName, HubConnection); + })); HubConnection.On("TransferFiles", async (string transferID, List fileIDs, string requesterID) => { Logger.Write($"File transfer started by {requesterID}."); diff --git a/ScreenCast.Core/Communication/CasterSocket.cs b/ScreenCast.Core/Communication/CasterSocket.cs index 7a20f508..70f849b9 100644 --- a/ScreenCast.Core/Communication/CasterSocket.cs +++ b/ScreenCast.Core/Communication/CasterSocket.cs @@ -211,6 +211,16 @@ namespace Remotely.ScreenCast.Core.Communication } }); + Connection.On("Disconnect", async (string reason) => + { + Logger.Write($"Disconnecting caster socket. Reason: {reason}"); + foreach (var viewer in conductor.Viewers.Values.ToList()) + { + await Connection.InvokeAsync("ViewerDisconnected", viewer.ViewerConnectionID); + viewer.DisconnectRequested = true; + } + }); + Connection.On("GetScreenCast", (string viewerID, string requesterName) => { try @@ -302,11 +312,11 @@ namespace Remotely.ScreenCast.Core.Communication Connection.On("ViewerDisconnected", async (string viewerID) => { + await Connection.InvokeAsync("ViewerDisconnected", viewerID); if (conductor.Viewers.TryGetValue(viewerID, out var viewer)) { viewer.DisconnectRequested = true; } - await Connection.InvokeAsync("ViewerDisconnected", viewerID); conductor.InvokeViewerRemoved(viewerID); }); diff --git a/Server/API/LoginController.cs b/Server/API/LoginController.cs index f64a124e..5d77a7db 100644 --- a/Server/API/LoginController.cs +++ b/Server/API/LoginController.cs @@ -10,6 +10,7 @@ using Remotely.Shared.Models; using Remotely.Server.Data; using Remotely.Server.Models; using Remotely.Server.Services; +using Microsoft.AspNetCore.SignalR; // For more information on enabling Web API for empty projects, visit https://go.microsoft.com/fwlink/?LinkID=397860 @@ -19,16 +20,24 @@ namespace Remotely.Server.API [ApiController] public class LoginController : ControllerBase { - public LoginController(SignInManager signInManager, DataService dataService, ApplicationConfig appConfig) + public LoginController(SignInManager signInManager, + DataService dataService, + ApplicationConfig appConfig, + IHubContext rcDeviceHub, + IHubContext rcBrowserHub) { SignInManager = signInManager; DataService = dataService; AppConfig = appConfig; + RCDeviceHub = rcDeviceHub; + RCBrowserHub = rcBrowserHub; } private SignInManager SignInManager { get; } private DataService DataService { get; } public ApplicationConfig AppConfig { get; } + private IHubContext RCDeviceHub { get; } + private IHubContext RCBrowserHub { get; } [HttpPost] public async Task Post([FromBody]ApiLogin login) @@ -68,6 +77,12 @@ namespace Remotely.Server.API if (HttpContext?.User?.Identity?.IsAuthenticated == true) { orgId = DataService.GetUserByName(HttpContext.User.Identity.Name)?.OrganizationID; + var activeSessions = RCDeviceSocketHub.SessionInfoList.Where(x => x.Value.RequesterUserName == HttpContext.User.Identity.Name); + foreach (var session in activeSessions.ToList()) + { + await RCDeviceHub.Clients.Client(session.Value.RCDeviceSocketID).SendAsync("Disconnect", "User logged out."); + await RCBrowserHub.Clients.Client(session.Value.RequesterSocketID).SendAsync("ConnectionFailed"); + } } await SignInManager.SignOutAsync(); DataService.WriteEvent($"API logout successful for {HttpContext?.User?.Identity?.Name}.", orgId); diff --git a/Server/API/RemoteControlController.cs b/Server/API/RemoteControlController.cs index 1cb93301..73e62e9b 100644 --- a/Server/API/RemoteControlController.cs +++ b/Server/API/RemoteControlController.cs @@ -94,19 +94,19 @@ namespace Remotely.Server.API var stopWatch = Stopwatch.StartNew(); - while (!RCDeviceSocketHub.SessionInfoList.Values.Any(x=>x.DeviceID == targetDevice.Value.ID && !existingSessions.Any(y=>y.Key != x.RCSocketID)) && stopWatch.Elapsed.TotalSeconds < 5) + while (!RCDeviceSocketHub.SessionInfoList.Values.Any(x=>x.DeviceID == targetDevice.Value.ID && !existingSessions.Any(y=>y.Key != x.RCDeviceSocketID)) && stopWatch.Elapsed.TotalSeconds < 5) { await Task.Delay(10); } - if (!RCDeviceSocketHub.SessionInfoList.Values.Any(x => x.DeviceID == targetDevice.Value.ID && !existingSessions.Any(y => y.Key != x.RCSocketID))) + if (!RCDeviceSocketHub.SessionInfoList.Values.Any(x => x.DeviceID == targetDevice.Value.ID && !existingSessions.Any(y => y.Key != x.RCDeviceSocketID))) { return StatusCode(408, "The remote control process failed to start in time on the remote device."); } else { - var rcSession = RCDeviceSocketHub.SessionInfoList.Values.FirstOrDefault(x=>x.DeviceID == targetDevice.Value.ID && !existingSessions.Any(y=>y.Key != x.RCSocketID)); - return Ok($"{HttpContext.Request.Scheme}://{Request.Host}/RemoteControl?clientID={rcSession.RCSocketID}&serviceID={targetDevice.Key}&fromApi=true"); + var rcSession = RCDeviceSocketHub.SessionInfoList.Values.FirstOrDefault(x=>x.DeviceID == targetDevice.Value.ID && !existingSessions.Any(y=>y.Key != x.RCDeviceSocketID)); + return Ok($"{HttpContext.Request.Scheme}://{Request.Host}/RemoteControl?clientID={rcSession.RCDeviceSocketID}&serviceID={targetDevice.Key}&fromApi=true"); } } else diff --git a/Server/API/ScriptingController.cs b/Server/API/ScriptingController.cs new file mode 100644 index 00000000..83d5b443 --- /dev/null +++ b/Server/API/ScriptingController.cs @@ -0,0 +1,84 @@ +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Identity; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.SignalR; +using Remotely.Server.Services; +using Remotely.Shared.Models; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Remotely.Shared.Helpers; +using Microsoft.AspNetCore.Http; +using System.IO; + +namespace Remotely.Server.API +{ + [ApiController] + [Route("api/[controller]")] + public class ScriptingController : ControllerBase + { + public ScriptingController(DataService dataService, + UserManager userManager, + IHubContext deviceHub) + { + DataService = dataService; + UserManager = userManager; + DeviceHub = deviceHub; + } + + private DataService DataService { get; } + private IHubContext DeviceHub { get; } + private UserManager UserManager { get; } + + [Authorize] + [HttpPost("[action]/{mode}/{deviceID}")] + public async Task> ExecuteCommand(string mode, string deviceID) + { + var command = string.Empty; + using (var sr = new StreamReader(Request.Body)) + { + command = await sr.ReadToEndAsync(); + } + var username = Request.HttpContext.User.Identity.Name; + var user = await UserManager.FindByNameAsync(username); + if (!DataService.DoesUserHaveAccessToDevice(deviceID, user)) + { + return Unauthorized(); + } + + + KeyValuePair connection = DeviceSocketHub.ServiceConnections.FirstOrDefault(x => + x.Value.OrganizationID == user.OrganizationID && + x.Value.ID == deviceID); + + if (string.IsNullOrWhiteSpace(connection.Key)) + { + return NotFound(); + } + + var commandContext = new CommandContext() + { + CommandMode = "PSCore", + CommandText = command, + SenderConnectionID = string.Empty, + SenderUserID = user.Id, + TargetDeviceIDs = new string[] { deviceID }, + OrganizationID = user.OrganizationID + }; + DataService.AddOrUpdateCommandContext(commandContext); + var requestID = Guid.NewGuid().ToString(); + await DeviceHub.Clients.Client(connection.Key).SendAsync("ExecuteCommandFromApi", mode, requestID, command, commandContext.ID, username); + var success = await TaskHelper.DelayUntil(() => DeviceSocketHub.ApiScriptResults.TryGetValue(requestID, out _), TimeSpan.FromSeconds(30)); + if (!success) + { + return commandContext; + } + DeviceSocketHub.ApiScriptResults.TryGetValue(requestID, out var commandID); + DeviceSocketHub.ApiScriptResults.Remove(requestID); + DataService.DetachEntity(commandContext); + var result = DataService.GetCommandContext(commandID.ToString(), username); + return result; + } + } +} diff --git a/Server/Areas/Identity/IdentityHostingStartup.cs b/Server/Areas/Identity/IdentityHostingStartup.cs new file mode 100644 index 00000000..4d28f1e7 --- /dev/null +++ b/Server/Areas/Identity/IdentityHostingStartup.cs @@ -0,0 +1,22 @@ +using System; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Identity; +using Microsoft.AspNetCore.Identity.UI; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Remotely.Server.Data; +using Remotely.Shared.Models; + +[assembly: HostingStartup(typeof(Remotely.Server.Areas.Identity.IdentityHostingStartup))] +namespace Remotely.Server.Areas.Identity +{ + public class IdentityHostingStartup : IHostingStartup + { + public void Configure(IWebHostBuilder builder) + { + builder.ConfigureServices((context, services) => { + }); + } + } +} \ No newline at end of file diff --git a/Server/Areas/Identity/Pages/Account/Logout.cshtml b/Server/Areas/Identity/Pages/Account/Logout.cshtml new file mode 100644 index 00000000..cb864ef2 --- /dev/null +++ b/Server/Areas/Identity/Pages/Account/Logout.cshtml @@ -0,0 +1,10 @@ +@page +@model LogoutModel +@{ + ViewData["Title"] = "Log out"; +} + +
+

@ViewData["Title"]

+

You have successfully logged out of the application.

+
\ No newline at end of file diff --git a/Server/Areas/Identity/Pages/Account/Logout.cshtml.cs b/Server/Areas/Identity/Pages/Account/Logout.cshtml.cs new file mode 100644 index 00000000..6437de81 --- /dev/null +++ b/Server/Areas/Identity/Pages/Account/Logout.cshtml.cs @@ -0,0 +1,63 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Identity; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.RazorPages; +using Microsoft.AspNetCore.SignalR; +using Microsoft.Extensions.Logging; +using Remotely.Server.Services; +using Remotely.Shared.Models; + +namespace Remotely.Server.Areas.Identity.Pages.Account +{ + [AllowAnonymous] + public class LogoutModel : PageModel + { + private readonly ILogger _logger; + private readonly SignInManager _signInManager; + public LogoutModel(SignInManager signInManager, + ILogger logger, + IHubContext rcDeviceHub, + IHubContext rcBrowserHub) + { + _signInManager = signInManager; + _logger = logger; + RCDeviceHub = rcDeviceHub; + RCBrowserHub = rcBrowserHub; + } + + private IHubContext RCDeviceHub { get; } + private IHubContext RCBrowserHub { get; } + + public void OnGet() + { + } + + public async Task OnPost(string returnUrl = null) + { + if (HttpContext.User.Identity.IsAuthenticated) + { + var activeSessions = RCDeviceSocketHub.SessionInfoList.Where(x => x.Value.RequesterUserName == HttpContext.User.Identity.Name); + foreach (var session in activeSessions.ToList()) + { + await RCDeviceHub.Clients.Client(session.Value.RCDeviceSocketID).SendAsync("Disconnect", "User logged out."); + await RCBrowserHub.Clients.Client(session.Value.RequesterSocketID).SendAsync("ConnectionFailed"); + } + } + + await _signInManager.SignOutAsync(); + _logger.LogInformation("User logged out."); + if (returnUrl != null) + { + return LocalRedirect(returnUrl); + } + else + { + return RedirectToPage(); + } + } + } +} diff --git a/Server/Models/RCSessionInfo.cs b/Server/Models/RCSessionInfo.cs index 826a4397..992217e5 100644 --- a/Server/Models/RCSessionInfo.cs +++ b/Server/Models/RCSessionInfo.cs @@ -13,8 +13,10 @@ namespace Remotely.Server.Models public string MachineName { get; set; } public RemoteControlMode Mode { get; set; } public string OrganizationID { get; set; } - public string RCSocketID { get; set; } + public string RCDeviceSocketID { get; set; } public string RequesterName { get; set; } + public string RequesterSocketID { get; set; } + public string RequesterUserName { get; set; } public string ServiceID { get; set; } public DateTime StartTime { get; set; } } diff --git a/Server/Services/BrowserSocketHub.cs b/Server/Services/BrowserSocketHub.cs index fb8f2bf1..90cdaec3 100644 --- a/Server/Services/BrowserSocketHub.cs +++ b/Server/Services/BrowserSocketHub.cs @@ -102,8 +102,7 @@ namespace Remotely.Server.Services return Task.CompletedTask; } - - public override async Task OnConnectedAsync() + public override async Task OnConnectedAsync() { RemotelyUser = DataService.GetUserByID(Context.UserIdentifier); if (await IsConnectionValid() == false) diff --git a/Server/Services/DataService.cs b/Server/Services/DataService.cs index f0171fdf..f1c8aae5 100644 --- a/Server/Services/DataService.cs +++ b/Server/Services/DataService.cs @@ -227,6 +227,12 @@ namespace Remotely.Server.Services RemotelyContext.SaveChanges(); } + public void DetachEntity(object entity) + { + RemotelyContext.Entry(entity).State = EntityState.Detached; + } + + public void DeviceDisconnected(string deviceID) { var device = RemotelyContext.Devices.Find(deviceID); diff --git a/Server/Services/DeviceSocketHub.cs b/Server/Services/DeviceSocketHub.cs index 648f380b..5bb98609 100644 --- a/Server/Services/DeviceSocketHub.cs +++ b/Server/Services/DeviceSocketHub.cs @@ -9,13 +9,15 @@ using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; +using Microsoft.Extensions.Caching.Memory; +using Microsoft.Extensions.Internal; namespace Remotely.Server.Services { public class DeviceSocketHub : Hub { - public DeviceSocketHub(DataService dataService, - IHubContext browserHub, + public DeviceSocketHub(DataService dataService, + IHubContext browserHub, IHubContext rcBrowserHub) { DataService = dataService; @@ -23,7 +25,8 @@ namespace Remotely.Server.Services RCBrowserHub = rcBrowserHub; } - public static ConcurrentDictionary ServiceConnections { get; } = new ConcurrentDictionary(); + public static ConcurrentDictionary ServiceConnections { get; } = new ConcurrentDictionary(); + public static IMemoryCache ApiScriptResults { get; } = new MemoryCache(new MemoryCacheOptions()); public IHubContext RCBrowserHub { get; } private IHubContext BrowserHub { get; } private DataService DataService { get; } @@ -44,6 +47,7 @@ namespace Remotely.Server.Services var commandContext = DataService.GetCommandContext(commandID); return BrowserHub.Clients.Client(commandContext.SenderConnectionID).SendAsync("BashResultViaAjax", commandID, Device.ID); } + public Task Chat(string message, string senderConnectionID) { return BrowserHub.Clients.Client(senderConnectionID).SendAsync("Chat", Device.DeviceName, message); @@ -55,7 +59,7 @@ namespace Remotely.Server.Services return BrowserHub.Clients.Client(commandContext.SenderConnectionID).SendAsync("CMDResultViaAjax", commandID, Device.ID); } - public Task CommandResult(GenericCommandResult result) + public Task CommandResult(GenericCommandResult result) { result.DeviceID = Device.ID; var commandContext = DataService.GetCommandContext(result.CommandContextID); @@ -64,6 +68,11 @@ namespace Remotely.Server.Services return BrowserHub.Clients.Client(commandContext.SenderConnectionID).SendAsync("CommandResult", result); } + public void CommandResultViaApi(string commandID, string requestID) + { + ApiScriptResults.Set(requestID, commandID, DateTimeOffset.Now.AddHours(1)); + } + public Task DeviceCameOnline(Device device) { try diff --git a/Server/Services/RCBrowserSocketHub.cs b/Server/Services/RCBrowserSocketHub.cs index 1a0496f4..e14fbe0e 100644 --- a/Server/Services/RCBrowserSocketHub.cs +++ b/Server/Services/RCBrowserSocketHub.cs @@ -166,7 +166,7 @@ namespace Remotely.Server.Services return Clients.Caller.SendAsync("SessionIDNotFound"); } - screenCasterID = RCDeviceSocketHub.SessionInfoList.First(x => x.Value.AttendedSessionID == screenCasterID).Value.RCSocketID; + screenCasterID = RCDeviceSocketHub.SessionInfoList.First(x => x.Value.AttendedSessionID == screenCasterID).Value.RCDeviceSocketID; } RCDeviceSocketHub.SessionInfoList.TryGetValue(screenCasterID, out var sessionInfo); @@ -188,6 +188,8 @@ namespace Remotely.Server.Services return Task.CompletedTask; } sessionInfo.OrganizationID = orgId; + sessionInfo.RequesterUserName = Context.User.Identity.Name; + sessionInfo.RequesterSocketID = Context.ConnectionId; } DataService.WriteEvent(new EventLog() diff --git a/Server/Services/RCDeviceSocketHub.cs b/Server/Services/RCDeviceSocketHub.cs index 527a452e..80629380 100644 --- a/Server/Services/RCDeviceSocketHub.cs +++ b/Server/Services/RCDeviceSocketHub.cs @@ -121,7 +121,7 @@ namespace Remotely.Server.Services { SessionInfo = new RCSessionInfo() { - RCSocketID = Context.ConnectionId, + RCDeviceSocketID = Context.ConnectionId, StartTime = DateTime.Now }; SessionInfoList.AddOrUpdate(Context.ConnectionId, SessionInfo, (id, si) => SessionInfo); diff --git a/Shared/Helpers/TaskHelper.cs b/Shared/Helpers/TaskHelper.cs new file mode 100644 index 00000000..e32db1a8 --- /dev/null +++ b/Shared/Helpers/TaskHelper.cs @@ -0,0 +1,21 @@ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Text; +using System.Threading.Tasks; + +namespace Remotely.Shared.Helpers +{ + public static class TaskHelper + { + public static async Task DelayUntil(Func condition, TimeSpan timeout, int pollingMs = 10) + { + var sw = Stopwatch.StartNew(); + while (!condition() && sw.Elapsed < timeout) + { + await Task.Delay(pollingMs); + } + return condition(); + } + } +}