diff --git a/App/App.xaml.cs b/App/App.xaml.cs index 2cdee97..ba6fa67 100644 --- a/App/App.xaml.cs +++ b/App/App.xaml.cs @@ -41,6 +41,7 @@ public partial class App : Application #endif private readonly ILogger<App> _logger; + private readonly IUriHandler _uriHandler; public App() { @@ -72,6 +73,8 @@ public App() .Bind(builder.Configuration.GetSection(MutagenControllerConfigSection)); services.AddSingleton<ISyncSessionController, MutagenController>(); services.AddSingleton<IUserNotifier, UserNotifier>(); + services.AddSingleton<IRdpConnector, RdpConnector>(); + services.AddSingleton<IUriHandler, UriHandler>(); // SignInWindow views and view models services.AddTransient<SignInViewModel>(); @@ -98,6 +101,7 @@ public App() _services = services.BuildServiceProvider(); _logger = (ILogger<App>)_services.GetService(typeof(ILogger<App>))!; + _uriHandler = (IUriHandler)_services.GetService(typeof(IUriHandler))!; InitializeComponent(); } @@ -190,7 +194,19 @@ public void OnActivated(object? sender, AppActivationArguments args) _logger.LogWarning("URI activation with null data"); return; } - HandleURIActivation(protoArgs.Uri); + + // don't need to wait for it to complete. + _uriHandler.HandleUri(protoArgs.Uri).ContinueWith(t => + { + if (t.Exception != null) + { + // don't log query params, as they contain secrets. + _logger.LogError(t.Exception, + "unhandled exception while processing URI coder://{authority}{path}", + protoArgs.Uri.Authority, protoArgs.Uri.AbsolutePath); + } + }); + break; case ExtendedActivationKind.AppNotification: @@ -204,12 +220,6 @@ public void OnActivated(object? sender, AppActivationArguments args) } } - public void HandleURIActivation(Uri uri) - { - // don't log the query string as that's where we include some sensitive information like passwords - _logger.LogInformation("handling URI activation for {path}", uri.AbsolutePath); - } - public void HandleNotification(AppNotificationManager? sender, AppNotificationActivatedEventArgs args) { // right now, we don't do anything other than log diff --git a/App/Services/CredentialManager.cs b/App/Services/CredentialManager.cs index a2f6567..280169c 100644 --- a/App/Services/CredentialManager.cs +++ b/App/Services/CredentialManager.cs @@ -307,7 +307,7 @@ public WindowsCredentialBackend(string credentialsTargetName) public Task<RawCredentials?> ReadCredentials(CancellationToken ct = default) { - var raw = NativeApi.ReadCredentials(_credentialsTargetName); + var raw = Wincred.ReadCredentials(_credentialsTargetName); if (raw == null) return Task.FromResult<RawCredentials?>(null); RawCredentials? credentials; @@ -326,115 +326,179 @@ public WindowsCredentialBackend(string credentialsTargetName) public Task WriteCredentials(RawCredentials credentials, CancellationToken ct = default) { var raw = JsonSerializer.Serialize(credentials, RawCredentialsJsonContext.Default.RawCredentials); - NativeApi.WriteCredentials(_credentialsTargetName, raw); + Wincred.WriteCredentials(_credentialsTargetName, raw); return Task.CompletedTask; } public Task DeleteCredentials(CancellationToken ct = default) { - NativeApi.DeleteCredentials(_credentialsTargetName); + Wincred.DeleteCredentials(_credentialsTargetName); return Task.CompletedTask; } - private static class NativeApi +} + +/// <summary> +/// Wincred provides relatively low level wrapped calls to the Wincred.h native API. +/// </summary> +internal static class Wincred +{ + private const int CredentialTypeGeneric = 1; + private const int CredentialTypeDomainPassword = 2; + private const int PersistenceTypeLocalComputer = 2; + private const int ErrorNotFound = 1168; + private const int CredMaxCredentialBlobSize = 5 * 512; + private const string PackageNTLM = "NTLM"; + + public static string? ReadCredentials(string targetName) { - private const int CredentialTypeGeneric = 1; - private const int PersistenceTypeLocalComputer = 2; - private const int ErrorNotFound = 1168; - private const int CredMaxCredentialBlobSize = 5 * 512; + if (!CredReadW(targetName, CredentialTypeGeneric, 0, out var credentialPtr)) + { + var error = Marshal.GetLastWin32Error(); + if (error == ErrorNotFound) return null; + throw new InvalidOperationException($"Failed to read credentials (Error {error})"); + } - public static string? ReadCredentials(string targetName) + try { - if (!CredReadW(targetName, CredentialTypeGeneric, 0, out var credentialPtr)) - { - var error = Marshal.GetLastWin32Error(); - if (error == ErrorNotFound) return null; - throw new InvalidOperationException($"Failed to read credentials (Error {error})"); - } + var cred = Marshal.PtrToStructure<CREDENTIALW>(credentialPtr); + return Marshal.PtrToStringUni(cred.CredentialBlob, cred.CredentialBlobSize / sizeof(char)); + } + finally + { + CredFree(credentialPtr); + } + } - try - { - var cred = Marshal.PtrToStructure<CREDENTIAL>(credentialPtr); - return Marshal.PtrToStringUni(cred.CredentialBlob, cred.CredentialBlobSize / sizeof(char)); - } - finally + public static void WriteCredentials(string targetName, string secret) + { + var byteCount = Encoding.Unicode.GetByteCount(secret); + if (byteCount > CredMaxCredentialBlobSize) + throw new ArgumentOutOfRangeException(nameof(secret), + $"The secret is greater than {CredMaxCredentialBlobSize} bytes"); + + var credentialBlob = Marshal.StringToHGlobalUni(secret); + var cred = new CREDENTIALW + { + Type = CredentialTypeGeneric, + TargetName = targetName, + CredentialBlobSize = byteCount, + CredentialBlob = credentialBlob, + Persist = PersistenceTypeLocalComputer, + }; + try + { + if (!CredWriteW(ref cred, 0)) { - CredFree(credentialPtr); + var error = Marshal.GetLastWin32Error(); + throw new InvalidOperationException($"Failed to write credentials (Error {error})"); } } - - public static void WriteCredentials(string targetName, string secret) + finally { - var byteCount = Encoding.Unicode.GetByteCount(secret); - if (byteCount > CredMaxCredentialBlobSize) - throw new ArgumentOutOfRangeException(nameof(secret), - $"The secret is greater than {CredMaxCredentialBlobSize} bytes"); + Marshal.FreeHGlobal(credentialBlob); + } + } - var credentialBlob = Marshal.StringToHGlobalUni(secret); - var cred = new CREDENTIAL - { - Type = CredentialTypeGeneric, - TargetName = targetName, - CredentialBlobSize = byteCount, - CredentialBlob = credentialBlob, - Persist = PersistenceTypeLocalComputer, - }; - try - { - if (!CredWriteW(ref cred, 0)) - { - var error = Marshal.GetLastWin32Error(); - throw new InvalidOperationException($"Failed to write credentials (Error {error})"); - } - } - finally - { - Marshal.FreeHGlobal(credentialBlob); - } + public static void DeleteCredentials(string targetName) + { + if (!CredDeleteW(targetName, CredentialTypeGeneric, 0)) + { + var error = Marshal.GetLastWin32Error(); + if (error == ErrorNotFound) return; + throw new InvalidOperationException($"Failed to delete credentials (Error {error})"); } + } + + public static void WriteDomainCredentials(string domainName, string serverName, string username, string password) + { + var targetName = $"{domainName}/{serverName}"; + var targetInfo = new CREDENTIAL_TARGET_INFORMATIONW + { + TargetName = targetName, + DnsServerName = serverName, + DnsDomainName = domainName, + PackageName = PackageNTLM, + }; + var byteCount = Encoding.Unicode.GetByteCount(password); + if (byteCount > CredMaxCredentialBlobSize) + throw new ArgumentOutOfRangeException(nameof(password), + $"The secret is greater than {CredMaxCredentialBlobSize} bytes"); - public static void DeleteCredentials(string targetName) + var credentialBlob = Marshal.StringToHGlobalUni(password); + var cred = new CREDENTIALW { - if (!CredDeleteW(targetName, CredentialTypeGeneric, 0)) + Type = CredentialTypeDomainPassword, + TargetName = targetName, + CredentialBlobSize = byteCount, + CredentialBlob = credentialBlob, + Persist = PersistenceTypeLocalComputer, + UserName = username, + }; + try + { + if (!CredWriteDomainCredentialsW(ref targetInfo, ref cred, 0)) { var error = Marshal.GetLastWin32Error(); - if (error == ErrorNotFound) return; - throw new InvalidOperationException($"Failed to delete credentials (Error {error})"); + throw new InvalidOperationException($"Failed to write credentials (Error {error})"); } } + finally + { + Marshal.FreeHGlobal(credentialBlob); + } + } - [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] - private static extern bool CredReadW(string target, int type, int reservedFlag, out IntPtr credentialPtr); + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + private static extern bool CredReadW(string target, int type, int reservedFlag, out IntPtr credentialPtr); - [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] - private static extern bool CredWriteW([In] ref CREDENTIAL userCredential, [In] uint flags); + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + private static extern bool CredWriteW([In] ref CREDENTIALW userCredential, [In] uint flags); - [DllImport("Advapi32.dll", SetLastError = true)] - private static extern void CredFree([In] IntPtr cred); + [DllImport("Advapi32.dll", SetLastError = true)] + private static extern void CredFree([In] IntPtr cred); - [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] - private static extern bool CredDeleteW(string target, int type, int flags); + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + private static extern bool CredDeleteW(string target, int type, int flags); - [StructLayout(LayoutKind.Sequential)] - private struct CREDENTIAL - { - public int Flags; - public int Type; + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + private static extern bool CredWriteDomainCredentialsW([In] ref CREDENTIAL_TARGET_INFORMATIONW target, [In] ref CREDENTIALW userCredential, [In] uint flags); - [MarshalAs(UnmanagedType.LPWStr)] public string TargetName; + [StructLayout(LayoutKind.Sequential)] + private struct CREDENTIALW + { + public int Flags; + public int Type; - [MarshalAs(UnmanagedType.LPWStr)] public string Comment; + [MarshalAs(UnmanagedType.LPWStr)] public string TargetName; - public long LastWritten; - public int CredentialBlobSize; - public IntPtr CredentialBlob; - public int Persist; - public int AttributeCount; - public IntPtr Attributes; + [MarshalAs(UnmanagedType.LPWStr)] public string Comment; - [MarshalAs(UnmanagedType.LPWStr)] public string TargetAlias; + public long LastWritten; + public int CredentialBlobSize; + public IntPtr CredentialBlob; + public int Persist; + public int AttributeCount; + public IntPtr Attributes; - [MarshalAs(UnmanagedType.LPWStr)] public string UserName; - } + [MarshalAs(UnmanagedType.LPWStr)] public string TargetAlias; + + [MarshalAs(UnmanagedType.LPWStr)] public string UserName; + } + + [StructLayout(LayoutKind.Sequential)] + private struct CREDENTIAL_TARGET_INFORMATIONW + { + [MarshalAs(UnmanagedType.LPWStr)] public string TargetName; + [MarshalAs(UnmanagedType.LPWStr)] public string NetbiosServerName; + [MarshalAs(UnmanagedType.LPWStr)] public string DnsServerName; + [MarshalAs(UnmanagedType.LPWStr)] public string NetbiosDomainName; + [MarshalAs(UnmanagedType.LPWStr)] public string DnsDomainName; + [MarshalAs(UnmanagedType.LPWStr)] public string DnsTreeName; + [MarshalAs(UnmanagedType.LPWStr)] public string PackageName; + + public uint Flags; + public uint CredTypeCount; + public IntPtr CredTypes; } } diff --git a/App/Services/RdpConnector.cs b/App/Services/RdpConnector.cs new file mode 100644 index 0000000..a48d0ac --- /dev/null +++ b/App/Services/RdpConnector.cs @@ -0,0 +1,76 @@ +using System; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; + +namespace Coder.Desktop.App.Services; + +public struct RdpCredentials(string username, string password) +{ + public readonly string Username = username; + public readonly string Password = password; +} + +public interface IRdpConnector +{ + public const int DefaultPort = 3389; + + public void WriteCredentials(string fqdn, RdpCredentials credentials); + + public Task Connect(string fqdn, int port = DefaultPort, CancellationToken ct = default); +} + +public class RdpConnector(ILogger<RdpConnector> logger) : IRdpConnector +{ + // Remote Desktop always uses TERMSRV as the domain; RDP is a part of Windows "Terminal Services". + private const string RdpDomain = "TERMSRV"; + + public void WriteCredentials(string fqdn, RdpCredentials credentials) + { + // writing credentials is idempotent for the same domain and server name. + Wincred.WriteDomainCredentials(RdpDomain, fqdn, credentials.Username, credentials.Password); + logger.LogDebug("wrote domain credential for {serverName} with username {username}", fqdn, + credentials.Username); + return; + } + + public Task Connect(string fqdn, int port = IRdpConnector.DefaultPort, CancellationToken ct = default) + { + // use mstsc to launch the RDP connection + var mstscProc = new Process(); + mstscProc.StartInfo.FileName = "mstsc"; + var args = $"/v {fqdn}"; + if (port != IRdpConnector.DefaultPort) + { + args = $"/v {fqdn}:{port}"; + } + + mstscProc.StartInfo.Arguments = args; + mstscProc.StartInfo.CreateNoWindow = true; + mstscProc.StartInfo.UseShellExecute = false; + try + { + if (!mstscProc.Start()) + throw new InvalidOperationException("Failed to start mstsc, Start returned false"); + } + catch (Exception e) + { + logger.LogWarning(e, "mstsc failed to start"); + + try + { + mstscProc.Kill(); + } + catch + { + // ignored, the process likely doesn't exist + } + + mstscProc.Dispose(); + throw; + } + + return mstscProc.WaitForExitAsync(ct); + } +} diff --git a/App/Services/UriHandler.cs b/App/Services/UriHandler.cs new file mode 100644 index 0000000..b0b0a9a --- /dev/null +++ b/App/Services/UriHandler.cs @@ -0,0 +1,152 @@ +using System; +using System.Collections.Specialized; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using System.Web; +using Coder.Desktop.App.Models; +using Coder.Desktop.Vpn.Proto; +using Microsoft.Extensions.Logging; + + +namespace Coder.Desktop.App.Services; + +public interface IUriHandler +{ + public Task HandleUri(Uri uri, CancellationToken ct = default); +} + +public class UriHandler( + ILogger<UriHandler> logger, + IRpcController rpcController, + IUserNotifier userNotifier, + IRdpConnector rdpConnector) : IUriHandler +{ + private const string OpenWorkspacePrefix = "/v0/open/ws/"; + + internal class UriException : Exception + { + internal readonly string Title; + internal readonly string Detail; + + internal UriException(string title, string detail) : base($"{title}: {detail}") + { + Title = title; + Detail = detail; + } + } + + public async Task HandleUri(Uri uri, CancellationToken ct = default) + { + try + { + await HandleUriThrowingErrors(uri, ct); + } + catch (UriException e) + { + await userNotifier.ShowErrorNotification(e.Title, e.Detail, ct); + } + } + + private async Task HandleUriThrowingErrors(Uri uri, CancellationToken ct = default) + { + if (uri.AbsolutePath.StartsWith(OpenWorkspacePrefix)) + { + await HandleOpenWorkspaceApp(uri, ct); + return; + } + + logger.LogWarning("unhandled URI path {path}", uri.AbsolutePath); + throw new UriException("URI handling error", + $"URI with path '{uri.AbsolutePath}' is unsupported or malformed"); + } + + public async Task HandleOpenWorkspaceApp(Uri uri, CancellationToken ct = default) + { + const string errTitle = "Open Workspace Application Error"; + var subpath = uri.AbsolutePath[OpenWorkspacePrefix.Length..]; + var components = subpath.Split("/"); + if (components.Length != 4 || components[1] != "agent") + { + logger.LogWarning("unsupported open workspace app format in URI {path}", uri.AbsolutePath); + throw new UriException(errTitle, $"Failed to open '{uri.AbsolutePath}' because the format is unsupported."); + } + + var workspaceName = components[0]; + var agentName = components[2]; + var appName = components[3]; + + var state = rpcController.GetState(); + if (state.VpnLifecycle != VpnLifecycle.Started) + { + logger.LogDebug("got URI to open workspace '{workspace}', but Coder Connect is not started", workspaceName); + throw new UriException(errTitle, + $"Failed to open application on '{workspaceName}' because Coder Connect is not started."); + } + + var workspace = state.Workspaces.FirstOrDefault(w => w.Name == workspaceName); + if (workspace == null) + { + logger.LogDebug("got URI to open workspace '{workspace}', but the workspace doesn't exist", workspaceName); + throw new UriException(errTitle, + $"Failed to open application on workspace '{workspaceName}' because it doesn't exist"); + } + + var agent = state.Agents.FirstOrDefault(a => a.WorkspaceId == workspace.Id && a.Name == agentName); + if (agent == null) + { + logger.LogDebug( + "got URI to open workspace/agent '{workspaceName}/{agentName}', but the agent doesn't exist", + workspaceName, agentName); + // If the workspace isn't running, that is almost certainly why we can't find the agent, so report that + // to the user. + if (workspace.Status != Workspace.Types.Status.Running) + { + throw new UriException(errTitle, + $"Failed to open application on workspace '{workspaceName}', because the workspace is not running."); + } + + throw new UriException(errTitle, + $"Failed to open application on workspace '{workspaceName}', because agent '{agentName}' doesn't exist."); + } + + if (appName != "rdp") + { + logger.LogWarning("unsupported agent application type {app}", appName); + throw new UriException(errTitle, + $"Failed to open agent in URI '{uri.AbsolutePath}' because application '{appName}' is unsupported"); + } + + await OpenRDP(agent.Fqdn.First(), uri.Query, ct); + } + + public async Task OpenRDP(string domainName, string queryString, CancellationToken ct = default) + { + const string errTitle = "Workspace Remote Desktop Error"; + NameValueCollection query; + try + { + query = HttpUtility.ParseQueryString(queryString); + } + catch (Exception ex) + { + // unfortunately, we can't safely write they query string to logs because it might contain + // sensitive info like a password. This is also why we don't log the exception directly + var trace = new System.Diagnostics.StackTrace(ex, false); + logger.LogWarning("failed to parse open RDP query string: {classMethod}", + trace?.GetFrame(0)?.GetMethod()?.ReflectedType?.FullName); + throw new UriException(errTitle, + "Failed to open remote desktop on a workspace because the URI was malformed"); + } + + var username = query.Get("username"); + var password = query.Get("password"); + if (!string.IsNullOrEmpty(username)) + { + password ??= string.Empty; + rdpConnector.WriteCredentials(domainName, new RdpCredentials(username, password)); + } + + await rdpConnector.Connect(domainName, ct: ct); + } +} diff --git a/App/Services/UserNotifier.cs b/App/Services/UserNotifier.cs index 9cdf6c1..9150f47 100644 --- a/App/Services/UserNotifier.cs +++ b/App/Services/UserNotifier.cs @@ -1,4 +1,5 @@ using System; +using System.Threading; using System.Threading.Tasks; using Microsoft.Windows.AppNotifications; using Microsoft.Windows.AppNotifications.Builder; @@ -7,7 +8,7 @@ namespace Coder.Desktop.App.Services; public interface IUserNotifier : IAsyncDisposable { - public Task ShowErrorNotification(string title, string message); + public Task ShowErrorNotification(string title, string message, CancellationToken ct = default); } public class UserNotifier : IUserNotifier @@ -19,7 +20,7 @@ public ValueTask DisposeAsync() return ValueTask.CompletedTask; } - public Task ShowErrorNotification(string title, string message) + public Task ShowErrorNotification(string title, string message, CancellationToken ct = default) { var builder = new AppNotificationBuilder().AddText(title).AddText(message); _notificationManager.Show(builder.BuildNotification()); diff --git a/Tests.App/Services/RdpConnectorTest.cs b/Tests.App/Services/RdpConnectorTest.cs new file mode 100644 index 0000000..b4a870e --- /dev/null +++ b/Tests.App/Services/RdpConnectorTest.cs @@ -0,0 +1,27 @@ +using Coder.Desktop.App.Services; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Serilog; + +namespace Coder.Desktop.Tests.App.Services; + +[TestFixture] +public class RdpConnectorTest +{ + [Test(Description = "Spawns RDP for real")] + [Ignore("Comment out to run manually")] + [CancelAfter(30_000)] + public async Task ConnectToRdp(CancellationToken ct) + { + var builder = Host.CreateApplicationBuilder(); + builder.Services.AddSerilog(); + builder.Services.AddSingleton<IRdpConnector, RdpConnector>(); + var services = builder.Services.BuildServiceProvider(); + + var rdpConnector = (RdpConnector)services.GetService<IRdpConnector>()!; + var creds = new RdpCredentials("Administrator", "coderRDP!"); + var workspace = "myworkspace.coder"; + rdpConnector.WriteCredentials(workspace, creds); + await rdpConnector.Connect(workspace, ct: ct); + } +} diff --git a/Tests.App/Services/UriHandlerTest.cs b/Tests.App/Services/UriHandlerTest.cs new file mode 100644 index 0000000..65c886c --- /dev/null +++ b/Tests.App/Services/UriHandlerTest.cs @@ -0,0 +1,178 @@ +using Coder.Desktop.App.Models; +using Coder.Desktop.App.Services; +using Coder.Desktop.Vpn.Proto; +using Google.Protobuf; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Moq; +using Serilog; + +namespace Coder.Desktop.Tests.App.Services; + +[TestFixture] +public class UriHandlerTest +{ + [SetUp] + public void SetupMocksAndUriHandler() + { + Serilog.Log.Logger = new LoggerConfiguration().MinimumLevel.Debug().WriteTo.NUnitOutput().CreateLogger(); + var builder = Host.CreateApplicationBuilder(); + builder.Services.AddSerilog(); + var logger = (ILogger<UriHandler>)builder.Build().Services.GetService(typeof(ILogger<UriHandler>))!; + + _mUserNotifier = new Mock<IUserNotifier>(MockBehavior.Strict); + _mRdpConnector = new Mock<IRdpConnector>(MockBehavior.Strict); + _mRpcController = new Mock<IRpcController>(MockBehavior.Strict); + + uriHandler = new UriHandler(logger, _mRpcController.Object, _mUserNotifier.Object, _mRdpConnector.Object); + } + + private Mock<IUserNotifier> _mUserNotifier; + private Mock<IRdpConnector> _mRdpConnector; + private Mock<IRpcController> _mRpcController; + private UriHandler uriHandler; // Unit under test. + + [SetUp] + public void AgentAndWorkspaceFixtures() + { + agent11 = new Agent(); + agent11.Fqdn.Add("workspace1.coder"); + agent11.Id = ByteString.CopyFrom(0x1, 0x1); + agent11.WorkspaceId = ByteString.CopyFrom(0x1, 0x0); + agent11.Name = "agent11"; + + workspace1 = new Workspace + { + Id = ByteString.CopyFrom(0x1, 0x0), + Name = "workspace1", + Status = Workspace.Types.Status.Running, + }; + + modelWithWorkspace1 = new RpcModel + { + VpnLifecycle = VpnLifecycle.Started, + Workspaces = [workspace1], + Agents = [agent11], + }; + } + + private Agent agent11; + private Workspace workspace1; + private RpcModel modelWithWorkspace1; + + [Test(Description = "Open RDP with username & password")] + [CancelAfter(30_000)] + public async Task Mainline(CancellationToken ct) + { + var input = new Uri("coder:/v0/open/ws/workspace1/agent/agent11/rdp?username=testy&password=sesame"); + + _mRpcController.Setup(m => m.GetState()).Returns(modelWithWorkspace1); + var expectedCred = new RdpCredentials("testy", "sesame"); + _ = _mRdpConnector.Setup(m => m.WriteCredentials(agent11.Fqdn[0], expectedCred)); + _ = _mRdpConnector.Setup(m => m.Connect(agent11.Fqdn[0], IRdpConnector.DefaultPort, ct)) + .Returns(Task.CompletedTask); + await uriHandler.HandleUri(input, ct); + } + + [Test(Description = "Open RDP with no credentials")] + [CancelAfter(30_000)] + public async Task NoCredentials(CancellationToken ct) + { + var input = new Uri("coder:/v0/open/ws/workspace1/agent/agent11/rdp"); + + _mRpcController.Setup(m => m.GetState()).Returns(modelWithWorkspace1); + _ = _mRdpConnector.Setup(m => m.Connect(agent11.Fqdn[0], IRdpConnector.DefaultPort, ct)) + .Returns(Task.CompletedTask); + await uriHandler.HandleUri(input, ct); + } + + [Test(Description = "Unknown app slug")] + [CancelAfter(30_000)] + public async Task UnknownApp(CancellationToken ct) + { + var input = new Uri("coder:/v0/open/ws/workspace1/agent/agent11/someapp"); + + _mRpcController.Setup(m => m.GetState()).Returns(modelWithWorkspace1); + _mUserNotifier.Setup(m => m.ShowErrorNotification(It.IsAny<string>(), It.IsRegex("someapp"), ct)) + .Returns(Task.CompletedTask); + await uriHandler.HandleUri(input, ct); + } + + [Test(Description = "Unknown agent name")] + [CancelAfter(30_000)] + public async Task UnknownAgent(CancellationToken ct) + { + var input = new Uri("coder:/v0/open/ws/workspace1/agent/wrongagent/rdp"); + + _mRpcController.Setup(m => m.GetState()).Returns(modelWithWorkspace1); + _mUserNotifier.Setup(m => m.ShowErrorNotification(It.IsAny<string>(), It.IsRegex("wrongagent"), ct)) + .Returns(Task.CompletedTask); + await uriHandler.HandleUri(input, ct); + } + + [Test(Description = "Unknown workspace name")] + [CancelAfter(30_000)] + public async Task UnknownWorkspace(CancellationToken ct) + { + var input = new Uri("coder:/v0/open/ws/wrongworkspace/agent/agent11/rdp"); + + _mRpcController.Setup(m => m.GetState()).Returns(modelWithWorkspace1); + _mUserNotifier.Setup(m => m.ShowErrorNotification(It.IsAny<string>(), It.IsRegex("wrongworkspace"), ct)) + .Returns(Task.CompletedTask); + await uriHandler.HandleUri(input, ct); + } + + [Test(Description = "Malformed Query String")] + [CancelAfter(30_000)] + public async Task MalformedQuery(CancellationToken ct) + { + // there might be some query string that gets the parser to throw an exception, but I could not find one. + var input = new Uri("coder:/v0/open/ws/workspace1/agent/agent11/rdp?%&##"); + + _mRpcController.Setup(m => m.GetState()).Returns(modelWithWorkspace1); + // treated the same as if we just didn't include credentials + _ = _mRdpConnector.Setup(m => m.Connect(agent11.Fqdn[0], IRdpConnector.DefaultPort, ct)) + .Returns(Task.CompletedTask); + await uriHandler.HandleUri(input, ct); + } + + [Test(Description = "VPN not started")] + [CancelAfter(30_000)] + public async Task VPNNotStarted(CancellationToken ct) + { + var input = new Uri("coder:/v0/open/ws/wrongworkspace/agent/agent11/rdp"); + + _mRpcController.Setup(m => m.GetState()).Returns(new RpcModel + { + VpnLifecycle = VpnLifecycle.Starting, + }); + // Coder Connect is the user facing name, so make sure the error mentions it. + _mUserNotifier.Setup(m => m.ShowErrorNotification(It.IsAny<string>(), It.IsRegex("Coder Connect"), ct)) + .Returns(Task.CompletedTask); + await uriHandler.HandleUri(input, ct); + } + + [Test(Description = "Wrong number of components")] + [CancelAfter(30_000)] + public async Task UnknownNumComponents(CancellationToken ct) + { + var input = new Uri("coder:/v0/open/ws/wrongworkspace/agent11/rdp"); + + _mRpcController.Setup(m => m.GetState()).Returns(modelWithWorkspace1); + _mUserNotifier.Setup(m => m.ShowErrorNotification(It.IsAny<string>(), It.IsAny<string>(), ct)) + .Returns(Task.CompletedTask); + await uriHandler.HandleUri(input, ct); + } + + [Test(Description = "Unknown prefix")] + [CancelAfter(30_000)] + public async Task UnknownPrefix(CancellationToken ct) + { + var input = new Uri("coder:/v300/open/ws/workspace1/agent/agent11/rdp"); + + _mRpcController.Setup(m => m.GetState()).Returns(modelWithWorkspace1); + _mUserNotifier.Setup(m => m.ShowErrorNotification(It.IsAny<string>(), It.IsAny<string>(), ct)) + .Returns(Task.CompletedTask); + await uriHandler.HandleUri(input, ct); + } +} diff --git a/Tests.App/Tests.App.csproj b/Tests.App/Tests.App.csproj index cc01512..e20eba1 100644 --- a/Tests.App/Tests.App.csproj +++ b/Tests.App/Tests.App.csproj @@ -26,6 +26,8 @@ <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> </PackageReference> <PackageReference Include="NUnit3TestAdapter" Version="4.6.0" /> + <PackageReference Include="Serilog.Extensions.Hosting" Version="9.0.0" /> + <PackageReference Include="Serilog.Sinks.NUnit" Version="1.0.3" /> </ItemGroup> <ItemGroup>