Merge pull request #418 from immense/app-db-factory

Create AppDbFactory.
This commit is contained in:
dkattan 2021-12-09 18:12:52 -06:00 committed by GitHub
commit d506c8da34
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 164 additions and 134 deletions

View File

@ -0,0 +1,35 @@
using Microsoft.Extensions.Configuration;
using Remotely.Server.Services;
using System;
namespace Remotely.Server.Data
{
public interface IAppDbFactory
{
AppDb GetContext();
}
public class AppDbFactory : IAppDbFactory
{
private readonly IApplicationConfig _appConfig;
private readonly IConfiguration _configuration;
public AppDbFactory(IApplicationConfig appConfig, IConfiguration configuration)
{
_appConfig = appConfig;
_configuration = configuration;
}
public AppDb GetContext()
{
return _appConfig.DBProvider.ToLower() switch
{
"sqlite" => new SqliteDbContext(_configuration),
"sqlserver" => new SqlServerDbContext(_configuration),
"postgresql" => new PostgreSqlDbContext(_configuration),
"inmemory" => new TestingDbContext(),
_ => throw new ArgumentException("Unknown DB provider."),
};
}
}
}

View File

@ -240,22 +240,22 @@ namespace Remotely.Server.Services
public class DataService : IDataService
{
private readonly IApplicationConfig _appConfig;
private readonly IConfiguration _configuration;
private readonly IHostEnvironment _hostEnvironment;
private readonly IAppDbFactory _appDbFactory;
public DataService(
IConfiguration configuration,
IApplicationConfig appConfig,
IHostEnvironment hostEnvironment)
IHostEnvironment hostEnvironment,
IAppDbFactory appDbFactory)
{
_configuration = configuration;
_appConfig = appConfig;
_hostEnvironment = hostEnvironment;
_appDbFactory = appDbFactory;
}
public async Task AddAlert(string deviceId, string organizationID, string alertMessage, string details = null)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var users = dbContext.Users
.Include(x => x.Alerts)
@ -289,7 +289,7 @@ namespace Remotely.Server.Services
public bool AddDeviceGroup(string orgID, DeviceGroup deviceGroup, out string deviceGroupID, out string errorMessage)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
deviceGroupID = null;
errorMessage = null;
@ -318,7 +318,7 @@ namespace Remotely.Server.Services
public InviteLink AddInvite(string orgID, InviteViewModel invite)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var organization = dbContext.Organizations
.Include(x => x.InviteLinks)
@ -340,7 +340,7 @@ namespace Remotely.Server.Services
public bool AddOrUpdateDevice(Device device, out Device updatedDevice)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var existingDevice = dbContext.Devices.Find(device.ID);
if (existingDevice != null)
@ -395,7 +395,7 @@ namespace Remotely.Server.Services
public async Task AddOrUpdateSavedScript(SavedScript script, string userId)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
dbContext.SavedScripts.Update(script);
script.CreatorId = userId;
@ -406,7 +406,7 @@ namespace Remotely.Server.Services
public void AddOrUpdateScriptResult(ScriptResult result)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var device = dbContext.Devices.Find(result.DeviceID);
@ -433,7 +433,7 @@ namespace Remotely.Server.Services
public async Task AddOrUpdateScriptSchedule(ScriptSchedule schedule)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var existingSchedule = await dbContext.ScriptSchedules
.Include(x => x.Creator)
@ -476,7 +476,7 @@ namespace Remotely.Server.Services
public async Task AddScriptResultToScriptRun(string scriptResultId, int scriptRunId)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var run = await dbContext.ScriptRuns
.Include(x => x.Results)
@ -505,7 +505,7 @@ namespace Remotely.Server.Services
public async Task AddScriptRun(ScriptRun scriptRun)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
dbContext.Attach(scriptRun);
dbContext.ScriptRuns.Add(scriptRun);
@ -539,7 +539,7 @@ namespace Remotely.Server.Services
public bool AddUserToDeviceGroup(string orgID, string groupID, string userName, out string resultMessage)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
resultMessage = string.Empty;
@ -587,7 +587,7 @@ namespace Remotely.Server.Services
public void ChangeUserIsAdmin(string organizationID, string targetUserID, bool isAdmin)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var targetUser = dbContext.Users.FirstOrDefault(x =>
x.OrganizationID == organizationID &&
@ -602,7 +602,7 @@ namespace Remotely.Server.Services
public void CleanupOldRecords()
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
if (_appConfig.DataRetentionInDays > -1)
{
@ -644,7 +644,7 @@ namespace Remotely.Server.Services
public async Task ClearLogs(string currentUserName)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var currentUser = await dbContext.Users.FirstOrDefaultAsync(x => x.UserName == currentUserName);
if (currentUser is null)
@ -674,7 +674,7 @@ namespace Remotely.Server.Services
public async Task<ApiToken> CreateApiToken(string userName, string tokenName, string secretHash)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var user = dbContext.Users.FirstOrDefault(x => x.UserName == userName);
@ -691,7 +691,7 @@ namespace Remotely.Server.Services
public async Task<Device> CreateDevice(DeviceSetupOptions options)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
try
{
@ -737,7 +737,7 @@ namespace Remotely.Server.Services
public async Task<bool> CreateUser(string userEmail, bool isAdmin, string organizationID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
try
{
@ -766,7 +766,7 @@ namespace Remotely.Server.Services
public async Task DeleteAlert(Alert alert)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
dbContext.Alerts.Remove(alert);
await dbContext.SaveChangesAsync();
@ -774,7 +774,7 @@ namespace Remotely.Server.Services
public async Task DeleteAllAlerts(string orgID, string userName = null)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var alerts = dbContext.Alerts.Where(x => x.OrganizationID == orgID);
@ -791,7 +791,7 @@ namespace Remotely.Server.Services
public async Task DeleteApiToken(string userName, string tokenId)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var user = dbContext.Users.FirstOrDefault(x => x.UserName == userName);
var token = dbContext.ApiTokens.FirstOrDefault(x =>
@ -804,7 +804,7 @@ namespace Remotely.Server.Services
public void DeleteDeviceGroup(string orgID, string deviceGroupID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var deviceGroup = dbContext.DeviceGroups
.Include(x => x.Devices)
@ -834,7 +834,7 @@ namespace Remotely.Server.Services
public void DeleteInvite(string orgID, string inviteID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var invite = dbContext.InviteLinks.FirstOrDefault(x =>
x.OrganizationID == orgID &&
@ -852,7 +852,7 @@ namespace Remotely.Server.Services
public async Task DeleteSavedScript(Guid scriptId)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var script = dbContext.SavedScripts.Find(scriptId);
if (script is not null)
@ -864,7 +864,7 @@ namespace Remotely.Server.Services
public async Task DeleteScriptSchedule(int scriptScheduleId)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var schedule = dbContext.ScriptSchedules.Find(scriptScheduleId);
if (schedule is not null)
@ -876,7 +876,7 @@ namespace Remotely.Server.Services
public async Task DeleteUser(string orgID, string targetUserID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var target = dbContext.Users
.Include(x => x.DeviceGroups)
@ -923,14 +923,14 @@ namespace Remotely.Server.Services
public void DetachEntity(object entity)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
dbContext.Entry(entity).State = EntityState.Detached;
}
public void DeviceDisconnected(string deviceID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var device = dbContext.Devices.Find(deviceID);
if (device != null)
@ -943,7 +943,7 @@ namespace Remotely.Server.Services
public bool DoesUserExist(string userName)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
if (string.IsNullOrWhiteSpace(userName))
{
@ -959,7 +959,7 @@ namespace Remotely.Server.Services
return false;
}
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.Devices
.Include(x => x.DeviceGroup)
@ -975,7 +975,7 @@ namespace Remotely.Server.Services
public bool DoesUserHaveAccessToDevice(string deviceID, string remotelyUserID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var remotelyUser = dbContext.Users.Find(remotelyUserID);
@ -984,7 +984,7 @@ namespace Remotely.Server.Services
public string[] FilterDeviceIDsByUserPermission(string[] deviceIDs, RemotelyUser remotelyUser)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.Devices
.Include(x => x.DeviceGroup)
@ -1003,14 +1003,14 @@ namespace Remotely.Server.Services
public string[] FilterUsersByDevicePermission(IEnumerable<string> userIDs, string deviceID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return FilterUsersByDevicePermissionInternal(dbContext, userIDs, deviceID);
}
public async Task<Alert> GetAlert(string alertID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return await dbContext.Alerts
.Include(x => x.Device)
@ -1020,7 +1020,7 @@ namespace Remotely.Server.Services
public Alert[] GetAlerts(string userID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.Alerts
.Include(x => x.Device)
@ -1032,7 +1032,7 @@ namespace Remotely.Server.Services
public ApiToken[] GetAllApiTokens(string userID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var user = dbContext.Users.FirstOrDefault(x => x.Id == userID);
@ -1044,7 +1044,7 @@ namespace Remotely.Server.Services
public ScriptResult[] GetAllCommandResults(string orgID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.ScriptResults
.Where(x => x.OrganizationID == orgID)
@ -1054,7 +1054,7 @@ namespace Remotely.Server.Services
public ScriptResult[] GetAllCommandResultsForUser(string orgId, string userName, string deviceId)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.ScriptResults
.Where(x => x.OrganizationID == orgId &&
@ -1066,14 +1066,14 @@ namespace Remotely.Server.Services
public Device[] GetAllDevices(string orgID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.Devices.Where(x => x.OrganizationID == orgID).ToArray();
}
public EventLog[] GetAllEventLogs(string orgID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.EventLogs
.Where(x => x.OrganizationID == orgID)
@ -1083,7 +1083,7 @@ namespace Remotely.Server.Services
public InviteLink[] GetAllInviteLinks(string organizationId)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.InviteLinks
.Where(x => x.OrganizationID == organizationId)
@ -1092,7 +1092,7 @@ namespace Remotely.Server.Services
public ScriptResult[] GetAllScriptResults(string orgId, string deviceId)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.ScriptResults
.Where(x => x.OrganizationID == orgId && x.DeviceID == deviceId)
@ -1102,7 +1102,7 @@ namespace Remotely.Server.Services
public ScriptResult[] GetAllScriptResultsForUser(string orgId, string userName)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.ScriptResults
.Where(x => x.OrganizationID == orgId && x.SenderUserName == userName)
@ -1112,7 +1112,7 @@ namespace Remotely.Server.Services
public RemotelyUser[] GetAllUsersForServer()
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.Users.ToArray();
}
@ -1124,7 +1124,7 @@ namespace Remotely.Server.Services
return Array.Empty<RemotelyUser>();
}
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var organization = await dbContext.Organizations
.Include(x => x.RemotelyUsers)
@ -1140,7 +1140,7 @@ namespace Remotely.Server.Services
return null;
}
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.ApiTokens.FirstOrDefault(x => x.ID == keyId);
}
@ -1152,7 +1152,7 @@ namespace Remotely.Server.Services
return null;
}
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var organization = await dbContext.Organizations
.Include(x => x.BrandingInfo)
@ -1173,14 +1173,14 @@ namespace Remotely.Server.Services
public async Task<Organization> GetDefaultOrganization()
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return await dbContext.Organizations.FirstOrDefaultAsync(x => x.IsDefaultOrganization);
}
public async Task<string> GetDefaultRelayCode()
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var relayCode = await dbContext.Organizations
.Where(x => x.IsDefaultOrganization)
@ -1192,7 +1192,7 @@ namespace Remotely.Server.Services
public Device GetDevice(string orgID, string deviceID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.Devices.FirstOrDefault(x =>
x.OrganizationID == orgID &&
@ -1201,21 +1201,21 @@ namespace Remotely.Server.Services
public Device GetDevice(string deviceID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.Devices.FirstOrDefault(x => x.ID == deviceID);
}
public int GetDeviceCount()
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.Devices.Count();
}
public int GetDeviceCount(RemotelyUser user)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.Devices
.Include(x => x.DeviceGroup)
@ -1232,13 +1232,13 @@ namespace Remotely.Server.Services
public async Task<DeviceGroup> GetDeviceGroup(string deviceGroupID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return await dbContext.DeviceGroups.FindAsync(deviceGroupID);
}
public DeviceGroup[] GetDeviceGroups(string username)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var user = dbContext.Users.FirstOrDefault(x => x.UserName == username);
@ -1274,7 +1274,7 @@ namespace Remotely.Server.Services
public DeviceGroup[] GetDeviceGroupsForOrganization(string organizationId)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.DeviceGroups
.Include(x => x.Users)
@ -1286,7 +1286,7 @@ namespace Remotely.Server.Services
public List<Device> GetDevices(IEnumerable<string> deviceIds)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.Devices
.Where(x => deviceIds.Contains(x.ID))
@ -1295,7 +1295,7 @@ namespace Remotely.Server.Services
public Device[] GetDevicesForUser(string userName)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
if (string.IsNullOrWhiteSpace(userName))
{
@ -1329,7 +1329,7 @@ namespace Remotely.Server.Services
public EventLog[] GetEventLogs(string userName, DateTimeOffset from, DateTimeOffset to, EventType? type, string message)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var user = dbContext.Users
.FirstOrDefault(x => x.UserName == userName);
@ -1368,14 +1368,14 @@ namespace Remotely.Server.Services
public Organization GetOrganizationById(string organizationID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.Organizations.Find(organizationID);
}
public async Task<Organization> GetOrganizationByRelayCode(string relayCode)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
if (string.IsNullOrWhiteSpace(relayCode))
{
@ -1387,7 +1387,7 @@ namespace Remotely.Server.Services
public async Task<Organization> GetOrganizationByUserName(string userName)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var user = await dbContext
.Users
@ -1399,21 +1399,21 @@ namespace Remotely.Server.Services
public int GetOrganizationCount()
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.Organizations.Count();
}
public string GetOrganizationNameById(string organizationID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.Organizations.FirstOrDefault(x => x.ID == organizationID)?.OrganizationName;
}
public string GetOrganizationNameByUserName(string userName)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.Users
.Include(x => x.Organization)
@ -1424,7 +1424,7 @@ namespace Remotely.Server.Services
public async Task<List<ScriptRun>> GetPendingScriptRuns(string deviceId)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var pendingRuns = new List<ScriptRun>();
@ -1460,7 +1460,7 @@ namespace Remotely.Server.Services
public async Task<List<SavedScript>> GetQuickScripts(string userId)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return await dbContext.SavedScripts
.Where(x => x.CreatorId == userId && x.IsQuickScript)
@ -1469,7 +1469,7 @@ namespace Remotely.Server.Services
public async Task<SavedScript> GetSavedScript(string userId, Guid scriptId)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return await dbContext.SavedScripts
.FirstOrDefaultAsync(x =>
@ -1479,13 +1479,13 @@ namespace Remotely.Server.Services
public async Task<SavedScript> GetSavedScript(Guid scriptId)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return await dbContext.SavedScripts.FirstOrDefaultAsync(x => x.Id == scriptId);
}
public async Task<List<SavedScript>> GetSavedScriptsWithoutContent(string userId, string organizationId)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var query = dbContext.SavedScripts
.Include(x => x.Creator)
@ -1507,7 +1507,7 @@ namespace Remotely.Server.Services
public ScriptResult GetScriptResult(string commandResultID, string orgID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.ScriptResults
.FirstOrDefault(x =>
@ -1517,14 +1517,14 @@ namespace Remotely.Server.Services
public ScriptResult GetScriptResult(string commandResultID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.ScriptResults.Find(commandResultID);
}
public async Task<List<ScriptSchedule>> GetScriptSchedules(string organizationId)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return await dbContext.ScriptSchedules
.Include(x => x.Creator)
.Include(x => x.Devices)
@ -1535,7 +1535,7 @@ namespace Remotely.Server.Services
public async Task<List<ScriptSchedule>> GetScriptSchedulesDue()
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var now = Time.Now;
@ -1549,7 +1549,7 @@ namespace Remotely.Server.Services
public List<string> GetServerAdmins()
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.Users
.Where(x => x.IsServerAdmin)
@ -1559,14 +1559,14 @@ namespace Remotely.Server.Services
public SharedFile GetSharedFiled(string fileID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.SharedFiles.Find(fileID);
}
public int GetTotalDevices()
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.Devices.Count();
}
@ -1577,7 +1577,7 @@ namespace Remotely.Server.Services
{
return null;
}
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return await dbContext.Users.FirstOrDefaultAsync(x => x.UserName == username);
}
@ -1588,7 +1588,7 @@ namespace Remotely.Server.Services
{
return null;
}
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.Users.FirstOrDefault(x => x.Id == userID);
}
@ -1600,7 +1600,7 @@ namespace Remotely.Server.Services
return null;
}
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.Users
.Include(x => x.Organization)
@ -1609,7 +1609,7 @@ namespace Remotely.Server.Services
public RemotelyUserOptions GetUserOptions(string userName)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
return dbContext.Users
.FirstOrDefault(x => x.UserName == userName)
@ -1618,7 +1618,7 @@ namespace Remotely.Server.Services
public bool JoinViaInvitation(string userName, string inviteID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var invite = dbContext.InviteLinks.FirstOrDefault(x =>
x.InvitedUser.ToLower() == userName.ToLower() &&
@ -1648,7 +1648,7 @@ namespace Remotely.Server.Services
public void PopulateRelayCodes()
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
foreach (var organization in dbContext.Organizations)
{
@ -1666,7 +1666,7 @@ namespace Remotely.Server.Services
public void RemoveDevices(string[] deviceIDs)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var devices = dbContext.Devices
.Include(x => x.ScriptResults)
@ -1683,7 +1683,7 @@ namespace Remotely.Server.Services
public async Task<bool> RemoveUserFromDeviceGroup(string orgID, string groupID, string userID)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var deviceGroup = dbContext.DeviceGroups
.Include(x => x.Users)
@ -1707,7 +1707,7 @@ namespace Remotely.Server.Services
public async Task RenameApiToken(string userName, string tokenId, string tokenName)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var user = dbContext.Users.FirstOrDefault(x => x.UserName == userName);
var token = dbContext.ApiTokens.FirstOrDefault(x =>
@ -1720,7 +1720,7 @@ namespace Remotely.Server.Services
public async Task ResetBranding(string organizationId)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var organization = await dbContext.Organizations
.Include(x => x.BrandingInfo)
@ -1738,7 +1738,7 @@ namespace Remotely.Server.Services
public void SetAllDevicesNotOnline()
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
dbContext.Devices.ForEachAsync(x =>
{
@ -1749,7 +1749,7 @@ namespace Remotely.Server.Services
public async Task SetDisplayName(RemotelyUser user, string displayName)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
dbContext.Attach(user);
user.UserOptions.DisplayName = displayName;
@ -1758,7 +1758,7 @@ namespace Remotely.Server.Services
public async Task SetIsDefaultOrganization(string orgID, bool isDefault)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var organization = await dbContext.Organizations.FindAsync(orgID);
if (organization is null)
@ -1777,7 +1777,7 @@ namespace Remotely.Server.Services
public async Task SetIsServerAdmin(string targetUserId, bool isServerAdmin, string callerUserId)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var caller = await dbContext.Users.FindAsync(callerUserId);
if (caller?.IsServerAdmin != true)
@ -1804,7 +1804,7 @@ namespace Remotely.Server.Services
public void SetServerVerificationToken(string deviceID, string verificationToken)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var device = dbContext.Devices.Find(deviceID);
if (device != null)
@ -1827,7 +1827,7 @@ namespace Remotely.Server.Services
{
return false;
}
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
if (user.TempPassword == password)
{
@ -1847,7 +1847,7 @@ namespace Remotely.Server.Services
ColorPickerModel titleBackground,
ColorPickerModel titleButtonForeground)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var organization = await dbContext.Organizations
.Include(x => x.BrandingInfo)
@ -1887,7 +1887,7 @@ namespace Remotely.Server.Services
public void UpdateDevice(string deviceID, string tag, string alias, string deviceGroupID, string notes, WebRtcSetting webRtcSetting)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var device = dbContext.Devices
.Include(x => x.DeviceGroup)
@ -1917,7 +1917,7 @@ namespace Remotely.Server.Services
public async Task<Device> UpdateDevice(DeviceSetupOptions deviceOptions, string organizationId)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var device = dbContext.Devices.Find(deviceOptions.DeviceID);
if (device == null || device.OrganizationID != organizationId)
@ -1937,7 +1937,7 @@ namespace Remotely.Server.Services
public void UpdateOrganizationName(string orgID, string organizationName)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
dbContext.Organizations
.FirstOrDefault(x => x.ID == orgID)
@ -1947,7 +1947,7 @@ namespace Remotely.Server.Services
public void UpdateTags(string deviceID, string tags)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var device = dbContext.Devices.Find(deviceID);
if (device == null)
@ -1961,7 +1961,7 @@ namespace Remotely.Server.Services
public void UpdateUserOptions(string userName, RemotelyUserOptions options)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
dbContext.Users.FirstOrDefault(x => x.UserName == userName).UserOptions = options;
dbContext.SaveChanges();
@ -1969,7 +1969,7 @@ namespace Remotely.Server.Services
public bool ValidateApiKey(string keyId, string apiSecret, string requestPath, string remoteIP)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var hasher = new PasswordHasher<RemotelyUser>();
var token = dbContext.ApiTokens.FirstOrDefault(x => x.ID == keyId);
@ -1992,7 +1992,7 @@ namespace Remotely.Server.Services
{
try
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
dbContext.EventLogs.Add(eventLog);
dbContext.SaveChanges();
@ -2004,7 +2004,7 @@ namespace Remotely.Server.Services
{
try
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
dbContext.EventLogs.Add(new EventLog()
{
@ -2029,7 +2029,7 @@ namespace Remotely.Server.Services
{
try
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
dbContext.EventLogs.Add(new EventLog()
{
@ -2054,7 +2054,7 @@ namespace Remotely.Server.Services
try
{
// TODO: Refactor EventLog to resemble these params. Replace WriteEvent with ILogger<T>.
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
EventType eventType = EventType.Debug;
switch (logLevel)
@ -2101,7 +2101,7 @@ namespace Remotely.Server.Services
string contentType,
string organizationId)
{
using var dbContext = GetDbContext();
using var dbContext = _appDbFactory.GetContext();
var expirationDate = DateTimeOffset.Now.AddDays(-_appConfig.DataRetentionInDays);
var expiredFiles = dbContext.SharedFiles.Where(x => x.Timestamp < expirationDate);
@ -2155,24 +2155,5 @@ namespace Remotely.Server.Services
.Select(x => x.Id)
.ToArray();
}
private AppDb GetDbContext()
{
switch (_appConfig.DBProvider.ToLower())
{
case "sqlite":
return new SqliteDbContext(_configuration);
case "sqlserver":
return new SqlServerDbContext(_configuration);
case "postgresql":
return new PostgreSqlDbContext(_configuration);
case "inmemory":
return new TestingDbContext();
default:
throw new ArgumentException("Unknown DB provider.");
}
}
}
}

View File

@ -53,7 +53,6 @@ namespace Remotely.Server
{
options.UseSqlite(Configuration.GetConnectionString("SQLite"));
});
}
else if (dbProvider == "sqlserver")
{
@ -158,6 +157,7 @@ namespace Remotely.Server
services.AddLogging();
services.AddScoped<IEmailSenderEx, EmailSenderEx>();
services.AddScoped<IEmailSender, EmailSender>();
services.AddScoped<IAppDbFactory, AppDbFactory>();
services.AddTransient<IDataService, DataService>();
services.AddScoped<IApplicationConfig, ApplicationConfig>();
services.AddScoped<ApiAuthorizationFilter>();

View File

@ -1,5 +1,6 @@
using Microsoft.Extensions.DependencyInjection;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Moq;
using Remotely.Agent.Interfaces;
using Remotely.Server.Services;
using Remotely.Shared.Models;
@ -15,8 +16,9 @@ namespace Remotely.Tests
public class DataServiceTests
{
private IDataService _dataService;
private IDeviceInformationService _deviceInfo;
private Mock<IDeviceInformationService> _deviceInfo;
private TestData _testData;
private string _newDeviceID = "NewDeviceName";
[TestMethod]
public async Task AddAlert()
@ -31,17 +33,16 @@ namespace Remotely.Tests
[TestMethod]
public async Task AddOrUpdateDevice()
{
var newDeviceID = "NewDeviceName";
var storedDevice = _dataService.GetDevice(newDeviceID);
var storedDevice = _dataService.GetDevice(_newDeviceID);
Assert.IsNull(storedDevice);
var newDevice = await _deviceInfo.CreateDevice(newDeviceID, _testData.OrganizationID);
var newDevice = await _deviceInfo.Object.CreateDevice(_newDeviceID, _testData.OrganizationID);
Assert.IsTrue(_dataService.AddOrUpdateDevice(newDevice, out _));
storedDevice = _dataService.GetDevice(newDeviceID);
storedDevice = _dataService.GetDevice(_newDeviceID);
Assert.AreEqual(newDeviceID, storedDevice.ID);
Assert.AreEqual(_newDeviceID, storedDevice.ID);
Assert.AreEqual(Environment.MachineName, storedDevice.DeviceName);
Assert.AreEqual(Environment.Is64BitOperatingSystem, storedDevice.Is64Bit);
}
@ -171,7 +172,19 @@ namespace Remotely.Tests
{
_testData = new TestData();
_dataService = IoCActivator.ServiceProvider.GetRequiredService<IDataService>();
_deviceInfo = IoCActivator.ServiceProvider.GetRequiredService<IDeviceInformationService>();
var newDevice = new Device()
{
ID = _newDeviceID,
DeviceName = Environment.MachineName,
Is64Bit = Environment.Is64BitOperatingSystem,
OrganizationID = _testData.OrganizationID
};
_deviceInfo = new Mock<IDeviceInformationService>();
_deviceInfo
.Setup(x => x.CreateDevice(_newDeviceID, _testData.OrganizationID))
.Returns(Task.FromResult(newDevice));
}
[TestMethod]

View File

@ -64,6 +64,7 @@ namespace Remotely.Tests
.AddDefaultUI()
.AddDefaultTokenProviders();
services.AddTransient<IAppDbFactory, AppDbFactory>();
services.AddTransient<IDataService, DataService>();
services.AddTransient<IApplicationConfig, ApplicationConfig>();
services.AddTransient<IEmailSenderEx, EmailSenderEx>();