From 5378529b9869f3bd747343ec083bf610f441e20a Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 1 May 2025 16:50:07 +0400 Subject: [PATCH 1/3] feat: add support for RDP URIs --- App/App.xaml.cs | 23 ++- App/Services/CredentialManager.cs | 218 ++++++++++++++++--------- App/Services/RdpConnector.cs | 81 +++++++++ App/Services/UriHandler.cs | 150 +++++++++++++++++ App/Services/UserNotifier.cs | 5 +- Tests.App/Services/RdpConnectorTest.cs | 27 +++ Tests.App/Services/UriHandlerTest.cs | 184 +++++++++++++++++++++ Tests.App/Tests.App.csproj | 2 + 8 files changed, 604 insertions(+), 86 deletions(-) create mode 100644 App/Services/RdpConnector.cs create mode 100644 App/Services/UriHandler.cs create mode 100644 Tests.App/Services/RdpConnectorTest.cs create mode 100644 Tests.App/Services/UriHandlerTest.cs diff --git a/App/App.xaml.cs b/App/App.xaml.cs index 2cdee97..a984778 100644 --- a/App/App.xaml.cs +++ b/App/App.xaml.cs @@ -41,6 +41,7 @@ public partial class App : Application #endif private readonly ILogger _logger; + private readonly IUriHandler _uriHandler; public App() { @@ -72,6 +73,8 @@ public App() .Bind(builder.Configuration.GetSection(MutagenControllerConfigSection)); services.AddSingleton(); services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(); // SignInWindow views and view models services.AddTransient(); @@ -98,6 +101,7 @@ public App() _services = services.BuildServiceProvider(); _logger = (ILogger)_services.GetService(typeof(ILogger))!; + _uriHandler = (IUriHandler)_services.GetService(typeof(IUriHandler))!; InitializeComponent(); } @@ -190,7 +194,18 @@ public void OnActivated(object? sender, AppActivationArguments args) _logger.LogWarning("URI activation with null data"); return; } - HandleURIActivation(protoArgs.Uri); + + try + { + // don't need to wait for it to complete. + _ = _uriHandler.HandleUri(protoArgs.Uri); + } + catch (System.Exception e) + { + _logger.LogError(e, "unhandled exception while processing URI coder://{authority}{path}", + protoArgs.Uri.Authority, protoArgs.Uri.AbsolutePath); + } + break; case ExtendedActivationKind.AppNotification: @@ -204,12 +219,6 @@ public void OnActivated(object? sender, AppActivationArguments args) } } - public void HandleURIActivation(Uri uri) - { - // don't log the query string as that's where we include some sensitive information like passwords - _logger.LogInformation("handling URI activation for {path}", uri.AbsolutePath); - } - public void HandleNotification(AppNotificationManager? sender, AppNotificationActivatedEventArgs args) { // right now, we don't do anything other than log diff --git a/App/Services/CredentialManager.cs b/App/Services/CredentialManager.cs index a2f6567..280169c 100644 --- a/App/Services/CredentialManager.cs +++ b/App/Services/CredentialManager.cs @@ -307,7 +307,7 @@ public WindowsCredentialBackend(string credentialsTargetName) public Task ReadCredentials(CancellationToken ct = default) { - var raw = NativeApi.ReadCredentials(_credentialsTargetName); + var raw = Wincred.ReadCredentials(_credentialsTargetName); if (raw == null) return Task.FromResult(null); RawCredentials? credentials; @@ -326,115 +326,179 @@ public WindowsCredentialBackend(string credentialsTargetName) public Task WriteCredentials(RawCredentials credentials, CancellationToken ct = default) { var raw = JsonSerializer.Serialize(credentials, RawCredentialsJsonContext.Default.RawCredentials); - NativeApi.WriteCredentials(_credentialsTargetName, raw); + Wincred.WriteCredentials(_credentialsTargetName, raw); return Task.CompletedTask; } public Task DeleteCredentials(CancellationToken ct = default) { - NativeApi.DeleteCredentials(_credentialsTargetName); + Wincred.DeleteCredentials(_credentialsTargetName); return Task.CompletedTask; } - private static class NativeApi +} + +/// +/// Wincred provides relatively low level wrapped calls to the Wincred.h native API. +/// +internal static class Wincred +{ + private const int CredentialTypeGeneric = 1; + private const int CredentialTypeDomainPassword = 2; + private const int PersistenceTypeLocalComputer = 2; + private const int ErrorNotFound = 1168; + private const int CredMaxCredentialBlobSize = 5 * 512; + private const string PackageNTLM = "NTLM"; + + public static string? ReadCredentials(string targetName) { - private const int CredentialTypeGeneric = 1; - private const int PersistenceTypeLocalComputer = 2; - private const int ErrorNotFound = 1168; - private const int CredMaxCredentialBlobSize = 5 * 512; + if (!CredReadW(targetName, CredentialTypeGeneric, 0, out var credentialPtr)) + { + var error = Marshal.GetLastWin32Error(); + if (error == ErrorNotFound) return null; + throw new InvalidOperationException($"Failed to read credentials (Error {error})"); + } - public static string? ReadCredentials(string targetName) + try { - if (!CredReadW(targetName, CredentialTypeGeneric, 0, out var credentialPtr)) - { - var error = Marshal.GetLastWin32Error(); - if (error == ErrorNotFound) return null; - throw new InvalidOperationException($"Failed to read credentials (Error {error})"); - } + var cred = Marshal.PtrToStructure(credentialPtr); + return Marshal.PtrToStringUni(cred.CredentialBlob, cred.CredentialBlobSize / sizeof(char)); + } + finally + { + CredFree(credentialPtr); + } + } - try - { - var cred = Marshal.PtrToStructure(credentialPtr); - return Marshal.PtrToStringUni(cred.CredentialBlob, cred.CredentialBlobSize / sizeof(char)); - } - finally + public static void WriteCredentials(string targetName, string secret) + { + var byteCount = Encoding.Unicode.GetByteCount(secret); + if (byteCount > CredMaxCredentialBlobSize) + throw new ArgumentOutOfRangeException(nameof(secret), + $"The secret is greater than {CredMaxCredentialBlobSize} bytes"); + + var credentialBlob = Marshal.StringToHGlobalUni(secret); + var cred = new CREDENTIALW + { + Type = CredentialTypeGeneric, + TargetName = targetName, + CredentialBlobSize = byteCount, + CredentialBlob = credentialBlob, + Persist = PersistenceTypeLocalComputer, + }; + try + { + if (!CredWriteW(ref cred, 0)) { - CredFree(credentialPtr); + var error = Marshal.GetLastWin32Error(); + throw new InvalidOperationException($"Failed to write credentials (Error {error})"); } } - - public static void WriteCredentials(string targetName, string secret) + finally { - var byteCount = Encoding.Unicode.GetByteCount(secret); - if (byteCount > CredMaxCredentialBlobSize) - throw new ArgumentOutOfRangeException(nameof(secret), - $"The secret is greater than {CredMaxCredentialBlobSize} bytes"); + Marshal.FreeHGlobal(credentialBlob); + } + } - var credentialBlob = Marshal.StringToHGlobalUni(secret); - var cred = new CREDENTIAL - { - Type = CredentialTypeGeneric, - TargetName = targetName, - CredentialBlobSize = byteCount, - CredentialBlob = credentialBlob, - Persist = PersistenceTypeLocalComputer, - }; - try - { - if (!CredWriteW(ref cred, 0)) - { - var error = Marshal.GetLastWin32Error(); - throw new InvalidOperationException($"Failed to write credentials (Error {error})"); - } - } - finally - { - Marshal.FreeHGlobal(credentialBlob); - } + public static void DeleteCredentials(string targetName) + { + if (!CredDeleteW(targetName, CredentialTypeGeneric, 0)) + { + var error = Marshal.GetLastWin32Error(); + if (error == ErrorNotFound) return; + throw new InvalidOperationException($"Failed to delete credentials (Error {error})"); } + } + + public static void WriteDomainCredentials(string domainName, string serverName, string username, string password) + { + var targetName = $"{domainName}/{serverName}"; + var targetInfo = new CREDENTIAL_TARGET_INFORMATIONW + { + TargetName = targetName, + DnsServerName = serverName, + DnsDomainName = domainName, + PackageName = PackageNTLM, + }; + var byteCount = Encoding.Unicode.GetByteCount(password); + if (byteCount > CredMaxCredentialBlobSize) + throw new ArgumentOutOfRangeException(nameof(password), + $"The secret is greater than {CredMaxCredentialBlobSize} bytes"); - public static void DeleteCredentials(string targetName) + var credentialBlob = Marshal.StringToHGlobalUni(password); + var cred = new CREDENTIALW { - if (!CredDeleteW(targetName, CredentialTypeGeneric, 0)) + Type = CredentialTypeDomainPassword, + TargetName = targetName, + CredentialBlobSize = byteCount, + CredentialBlob = credentialBlob, + Persist = PersistenceTypeLocalComputer, + UserName = username, + }; + try + { + if (!CredWriteDomainCredentialsW(ref targetInfo, ref cred, 0)) { var error = Marshal.GetLastWin32Error(); - if (error == ErrorNotFound) return; - throw new InvalidOperationException($"Failed to delete credentials (Error {error})"); + throw new InvalidOperationException($"Failed to write credentials (Error {error})"); } } + finally + { + Marshal.FreeHGlobal(credentialBlob); + } + } - [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] - private static extern bool CredReadW(string target, int type, int reservedFlag, out IntPtr credentialPtr); + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + private static extern bool CredReadW(string target, int type, int reservedFlag, out IntPtr credentialPtr); - [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] - private static extern bool CredWriteW([In] ref CREDENTIAL userCredential, [In] uint flags); + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + private static extern bool CredWriteW([In] ref CREDENTIALW userCredential, [In] uint flags); - [DllImport("Advapi32.dll", SetLastError = true)] - private static extern void CredFree([In] IntPtr cred); + [DllImport("Advapi32.dll", SetLastError = true)] + private static extern void CredFree([In] IntPtr cred); - [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] - private static extern bool CredDeleteW(string target, int type, int flags); + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + private static extern bool CredDeleteW(string target, int type, int flags); - [StructLayout(LayoutKind.Sequential)] - private struct CREDENTIAL - { - public int Flags; - public int Type; + [DllImport("Advapi32.dll", CharSet = CharSet.Unicode, SetLastError = true)] + private static extern bool CredWriteDomainCredentialsW([In] ref CREDENTIAL_TARGET_INFORMATIONW target, [In] ref CREDENTIALW userCredential, [In] uint flags); - [MarshalAs(UnmanagedType.LPWStr)] public string TargetName; + [StructLayout(LayoutKind.Sequential)] + private struct CREDENTIALW + { + public int Flags; + public int Type; - [MarshalAs(UnmanagedType.LPWStr)] public string Comment; + [MarshalAs(UnmanagedType.LPWStr)] public string TargetName; - public long LastWritten; - public int CredentialBlobSize; - public IntPtr CredentialBlob; - public int Persist; - public int AttributeCount; - public IntPtr Attributes; + [MarshalAs(UnmanagedType.LPWStr)] public string Comment; - [MarshalAs(UnmanagedType.LPWStr)] public string TargetAlias; + public long LastWritten; + public int CredentialBlobSize; + public IntPtr CredentialBlob; + public int Persist; + public int AttributeCount; + public IntPtr Attributes; - [MarshalAs(UnmanagedType.LPWStr)] public string UserName; - } + [MarshalAs(UnmanagedType.LPWStr)] public string TargetAlias; + + [MarshalAs(UnmanagedType.LPWStr)] public string UserName; + } + + [StructLayout(LayoutKind.Sequential)] + private struct CREDENTIAL_TARGET_INFORMATIONW + { + [MarshalAs(UnmanagedType.LPWStr)] public string TargetName; + [MarshalAs(UnmanagedType.LPWStr)] public string NetbiosServerName; + [MarshalAs(UnmanagedType.LPWStr)] public string DnsServerName; + [MarshalAs(UnmanagedType.LPWStr)] public string NetbiosDomainName; + [MarshalAs(UnmanagedType.LPWStr)] public string DnsDomainName; + [MarshalAs(UnmanagedType.LPWStr)] public string DnsTreeName; + [MarshalAs(UnmanagedType.LPWStr)] public string PackageName; + + public uint Flags; + public uint CredTypeCount; + public IntPtr CredTypes; } } diff --git a/App/Services/RdpConnector.cs b/App/Services/RdpConnector.cs new file mode 100644 index 0000000..1692186 --- /dev/null +++ b/App/Services/RdpConnector.cs @@ -0,0 +1,81 @@ +using System; +using System.Diagnostics; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; + +namespace Coder.Desktop.App.Services; + +public struct RdpCredentials(string username, string password) +{ + public readonly string Username = username; + public readonly string Password = password; +} + +public interface IRdpConnector : IAsyncDisposable +{ + public const int DefaultPort = 3389; + + public Task WriteCredentials(string fqdn, RdpCredentials credentials, CancellationToken ct = default); + + public Task Connect(string fqdn, int port = DefaultPort, CancellationToken ct = default); +} + +public class RdpConnector(ILogger logger) : IRdpConnector +{ + // Remote Desktop always uses TERMSRV as the domain; RDP is a part of Windows "Terminal Services". + private const string RdpDomain = "TERMSRV"; + + public Task WriteCredentials(string fqdn, RdpCredentials credentials, CancellationToken ct = default) + { + // writing credentials is idempotent for the same domain and server name. + Wincred.WriteDomainCredentials(RdpDomain, fqdn, credentials.Username, credentials.Password); + logger.LogDebug("wrote domain credential for {serverName} with username {username}", fqdn, + credentials.Username); + return Task.CompletedTask; + } + + public Task Connect(string fqdn, int port = IRdpConnector.DefaultPort, CancellationToken ct = default) + { + // use mstsc to launch the RDP connection + var mstscProc = new Process(); + mstscProc.StartInfo.FileName = "mstsc"; + var args = $"/v {fqdn}"; + if (port != IRdpConnector.DefaultPort) + { + args = $"/v {fqdn}:{port}"; + } + + mstscProc.StartInfo.Arguments = args; + mstscProc.StartInfo.CreateNoWindow = true; + mstscProc.StartInfo.UseShellExecute = false; + try + { + if (!mstscProc.Start()) + throw new InvalidOperationException("Failed to start mstsc, Start returned false"); + } + catch (Exception e) + { + logger.LogWarning(e, "mstsc failed to start"); + + try + { + mstscProc.Kill(); + } + catch + { + // ignored, the process likely doesn't exist + } + + mstscProc.Dispose(); + throw; + } + + return mstscProc.WaitForExitAsync(ct); + } + + public ValueTask DisposeAsync() + { + return ValueTask.CompletedTask; + } +} diff --git a/App/Services/UriHandler.cs b/App/Services/UriHandler.cs new file mode 100644 index 0000000..eabcdff --- /dev/null +++ b/App/Services/UriHandler.cs @@ -0,0 +1,150 @@ +using System; +using System.Collections.Specialized; +using System.Linq; +using System.Threading; +using System.Threading.Tasks; +using System.Web; +using Coder.Desktop.App.Models; +using Coder.Desktop.Vpn.Proto; +using Microsoft.Extensions.Logging; + + +namespace Coder.Desktop.App.Services; + +public interface IUriHandler : IAsyncDisposable +{ + public Task HandleUri(Uri uri, CancellationToken ct = default); +} + +public class UriHandler( + ILogger logger, + IRpcController rpcController, + IUserNotifier userNotifier, + IRdpConnector rdpConnector) : IUriHandler +{ + private const string OpenWorkspacePrefix = "/v0/open/ws/"; + + internal class UriException(string title, string detail) : Exception + { + internal readonly string Title = title; + internal readonly string Detail = detail; + } + + public async Task HandleUri(Uri uri, CancellationToken ct = default) + { + try + { + await HandleUriThrowingErrors(uri, ct); + } + catch (UriException e) + { + await userNotifier.ShowErrorNotification(e.Title, e.Detail, ct); + } + } + + private async Task HandleUriThrowingErrors(Uri uri, CancellationToken ct = default) + { + if (uri.AbsolutePath.StartsWith(OpenWorkspacePrefix)) + { + await HandleOpenWorkspaceApp(uri, ct); + return; + } + + logger.LogWarning("unhandled URI path {path}", uri.AbsolutePath); + throw new UriException("URI handling error", + $"URI with path {uri.AbsolutePath} is unsupported or malformed"); + } + + public async Task HandleOpenWorkspaceApp(Uri uri, CancellationToken ct = default) + { + const string errTitle = "Open Workspace Application Error"; + var subpath = uri.AbsolutePath[OpenWorkspacePrefix.Length..]; + var components = subpath.Split("/"); + if (components.Length != 4 || components[1] != "agent") + { + logger.LogWarning("unsupported open workspace app format in URI {path}", uri.AbsolutePath); + throw new UriException(errTitle, $"Failed to open {uri.AbsolutePath} because the format is unsupported."); + } + + var workspaceName = components[0]; + var agentName = components[2]; + var appName = components[3]; + + var state = rpcController.GetState(); + if (state.VpnLifecycle != VpnLifecycle.Started) + { + logger.LogDebug("got URI to open workspace {workspace}, but Coder Connect is not started", workspaceName); + throw new UriException(errTitle, + $"Failed to open application on {workspaceName} because Coder Connect is not started."); + } + + Workspace workspace; + try + { + workspace = state.Workspaces.Single(w => w.Name == workspaceName); + } + catch (InvalidOperationException) // Single() throws this when nothing matches. + { + logger.LogDebug("got URI to open workspace {workspace}, but the workspace doesn't exist", workspaceName); + throw new UriException(errTitle, + $"Failed to open application on workspace {workspaceName} because it doesn't exist"); + } + + Agent agent; + try + { + agent = state.Agents.Single(a => a.WorkspaceId == workspace.Id && a.Name == agentName); + } + catch (InvalidOperationException) // Single() throws this when nothing matches. + { + logger.LogDebug("got URI to open workspace/agent {workspaceName}/{agentName}, but the agent doesn't exist", + workspaceName, agentName); + throw new UriException(errTitle, + $"Failed to open application on workspace {workspaceName}, agent {agentName} because it doesn't exist."); + } + + if (appName != "rdp") + { + logger.LogWarning("unsupported agent application type {app}", appName); + throw new UriException(errTitle, + $"Failed to open agent in URI {uri.AbsolutePath} because application {appName} is unsupported"); + } + + await OpenRDP(agent.Fqdn.First(), uri.Query, ct); + } + + public async Task OpenRDP(string domainName, string queryString, CancellationToken ct = default) + { + const string errTitle = "Workspace Remote Desktop Error"; + NameValueCollection query; + try + { + query = HttpUtility.ParseQueryString(queryString); + } + catch (Exception ex) + { + // unfortunately, we can't safely write they query string to logs because it might contain + // sensitive info like a password. This is also why we don't log the exception directly + var trace = new System.Diagnostics.StackTrace(ex, false); + logger.LogWarning("failed to parse open RDP query string: {classMethod}", + trace?.GetFrame(0)?.GetMethod()?.ReflectedType?.FullName); + throw new UriException(errTitle, + "Failed to open remote desktop on a workspace because the URI was malformed"); + } + + var username = query.Get("username"); + var password = query.Get("password"); + if (!string.IsNullOrEmpty(username)) + { + password ??= string.Empty; + await rdpConnector.WriteCredentials(domainName, new RdpCredentials(username, password), ct); + } + + await rdpConnector.Connect(domainName, ct: ct); + } + + public ValueTask DisposeAsync() + { + return ValueTask.CompletedTask; + } +} diff --git a/App/Services/UserNotifier.cs b/App/Services/UserNotifier.cs index 9cdf6c1..9150f47 100644 --- a/App/Services/UserNotifier.cs +++ b/App/Services/UserNotifier.cs @@ -1,4 +1,5 @@ using System; +using System.Threading; using System.Threading.Tasks; using Microsoft.Windows.AppNotifications; using Microsoft.Windows.AppNotifications.Builder; @@ -7,7 +8,7 @@ namespace Coder.Desktop.App.Services; public interface IUserNotifier : IAsyncDisposable { - public Task ShowErrorNotification(string title, string message); + public Task ShowErrorNotification(string title, string message, CancellationToken ct = default); } public class UserNotifier : IUserNotifier @@ -19,7 +20,7 @@ public ValueTask DisposeAsync() return ValueTask.CompletedTask; } - public Task ShowErrorNotification(string title, string message) + public Task ShowErrorNotification(string title, string message, CancellationToken ct = default) { var builder = new AppNotificationBuilder().AddText(title).AddText(message); _notificationManager.Show(builder.BuildNotification()); diff --git a/Tests.App/Services/RdpConnectorTest.cs b/Tests.App/Services/RdpConnectorTest.cs new file mode 100644 index 0000000..87bc59d --- /dev/null +++ b/Tests.App/Services/RdpConnectorTest.cs @@ -0,0 +1,27 @@ +using Coder.Desktop.App.Services; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; +using Serilog; + +namespace Coder.Desktop.Tests.App.Services; + +[TestFixture] +public class RdpConnectorTest +{ + [Test(Description = "Spawns RDP for real")] + [Ignore("Comment out to run manually")] + [CancelAfter(30_000)] + public async Task ConnectToRdp() + { + var builder = Host.CreateApplicationBuilder(); + builder.Services.AddSerilog(); + builder.Services.AddSingleton(); + var services = builder.Services.BuildServiceProvider(); + + var rdpConnector = (RdpConnector)services.GetService()!; + var creds = new RdpCredentials("Administrator", "coderRDP!"); + var workspace = "myworkspace.coder"; + await rdpConnector.WriteCredentials(workspace, creds); + await rdpConnector.Connect(workspace); + } +} diff --git a/Tests.App/Services/UriHandlerTest.cs b/Tests.App/Services/UriHandlerTest.cs new file mode 100644 index 0000000..9b24f02 --- /dev/null +++ b/Tests.App/Services/UriHandlerTest.cs @@ -0,0 +1,184 @@ +using Coder.Desktop.App.Models; +using Coder.Desktop.App.Services; +using Coder.Desktop.Vpn.Proto; +using Google.Protobuf; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; +using Moq; +using Serilog; + +namespace Coder.Desktop.Tests.App.Services; + +[TestFixture] +public class UriHandlerTest +{ + [SetUp] + public void SetupMocksAndUriHandler() + { + Serilog.Log.Logger = new LoggerConfiguration().MinimumLevel.Debug().WriteTo.NUnitOutput().CreateLogger(); + var builder = Host.CreateApplicationBuilder(); + builder.Services.AddSerilog(); + var logger = (ILogger)builder.Build().Services.GetService(typeof(ILogger))!; + + _mUserNotifier = new Mock(MockBehavior.Strict); + _mRdpConnector = new Mock(MockBehavior.Strict); + _mRpcController = new Mock(MockBehavior.Strict); + + uriHandler = new UriHandler(logger, _mRpcController.Object, _mUserNotifier.Object, _mRdpConnector.Object); + } + + [TearDown] + public async Task CleanupUriHandler() + { + await uriHandler.DisposeAsync(); + } + + private Mock _mUserNotifier; + private Mock _mRdpConnector; + private Mock _mRpcController; + private UriHandler uriHandler; // Unit under test. + + [SetUp] + public void AgentAndWorkspaceFixtures() + { + agent11 = new Agent(); + agent11.Fqdn.Add("workspace1.coder"); + agent11.Id = ByteString.CopyFrom(0x1, 0x1); + agent11.WorkspaceId = ByteString.CopyFrom(0x1, 0x0); + agent11.Name = "agent11"; + + workspace1 = new Workspace + { + Id = ByteString.CopyFrom(0x1, 0x0), + Name = "workspace1", + }; + + modelWithWorkspace1 = new RpcModel + { + VpnLifecycle = VpnLifecycle.Started, + Workspaces = [workspace1], + Agents = [agent11], + }; + } + + private Agent agent11; + private Workspace workspace1; + private RpcModel modelWithWorkspace1; + + [Test(Description = "Open RDP with username & password")] + [CancelAfter(30_000)] + public async Task Mainline(CancellationToken ct) + { + var input = new Uri("coder:/v0/open/ws/workspace1/agent/agent11/rdp?username=testy&password=sesame"); + + _mRpcController.Setup(m => m.GetState()).Returns(modelWithWorkspace1); + var expectedCred = new RdpCredentials("testy", "sesame"); + _ = _mRdpConnector.Setup(m => m.WriteCredentials(agent11.Fqdn[0], expectedCred, ct)) + .Returns(Task.CompletedTask); + _ = _mRdpConnector.Setup(m => m.Connect(agent11.Fqdn[0], IRdpConnector.DefaultPort, ct)) + .Returns(Task.CompletedTask); + await uriHandler.HandleUri(input, ct); + } + + [Test(Description = "Open RDP with no credentials")] + [CancelAfter(30_000)] + public async Task NoCredentials(CancellationToken ct) + { + var input = new Uri("coder:/v0/open/ws/workspace1/agent/agent11/rdp"); + + _mRpcController.Setup(m => m.GetState()).Returns(modelWithWorkspace1); + _ = _mRdpConnector.Setup(m => m.Connect(agent11.Fqdn[0], IRdpConnector.DefaultPort, ct)) + .Returns(Task.CompletedTask); + await uriHandler.HandleUri(input, ct); + } + + [Test(Description = "Unknown app slug")] + [CancelAfter(30_000)] + public async Task UnknownApp(CancellationToken ct) + { + var input = new Uri("coder:/v0/open/ws/workspace1/agent/agent11/someapp"); + + _mRpcController.Setup(m => m.GetState()).Returns(modelWithWorkspace1); + _mUserNotifier.Setup(m => m.ShowErrorNotification(It.IsAny(), It.IsRegex("someapp"), ct)) + .Returns(Task.CompletedTask); + await uriHandler.HandleUri(input, ct); + } + + [Test(Description = "Unknown agent name")] + [CancelAfter(30_000)] + public async Task UnknownAgent(CancellationToken ct) + { + var input = new Uri("coder:/v0/open/ws/workspace1/agent/wrongagent/rdp"); + + _mRpcController.Setup(m => m.GetState()).Returns(modelWithWorkspace1); + _mUserNotifier.Setup(m => m.ShowErrorNotification(It.IsAny(), It.IsRegex("wrongagent"), ct)) + .Returns(Task.CompletedTask); + await uriHandler.HandleUri(input, ct); + } + + [Test(Description = "Unknown workspace name")] + [CancelAfter(30_000)] + public async Task UnknownWorkspace(CancellationToken ct) + { + var input = new Uri("coder:/v0/open/ws/wrongworkspace/agent/agent11/rdp"); + + _mRpcController.Setup(m => m.GetState()).Returns(modelWithWorkspace1); + _mUserNotifier.Setup(m => m.ShowErrorNotification(It.IsAny(), It.IsRegex("wrongworkspace"), ct)) + .Returns(Task.CompletedTask); + await uriHandler.HandleUri(input, ct); + } + + [Test(Description = "Malformed Query String")] + [CancelAfter(30_000)] + public async Task MalformedQuery(CancellationToken ct) + { + // there might be some query string that gets the parser to throw an exception, but I could not find one. + var input = new Uri("coder:/v0/open/ws/workspace1/agent/agent11/rdp?%&##"); + + _mRpcController.Setup(m => m.GetState()).Returns(modelWithWorkspace1); + // treated the same as if we just didn't include credentials + _ = _mRdpConnector.Setup(m => m.Connect(agent11.Fqdn[0], IRdpConnector.DefaultPort, ct)) + .Returns(Task.CompletedTask); + await uriHandler.HandleUri(input, ct); + } + + [Test(Description = "VPN not started")] + [CancelAfter(30_000)] + public async Task VPNNotStarted(CancellationToken ct) + { + var input = new Uri("coder:/v0/open/ws/wrongworkspace/agent/agent11/rdp"); + + _mRpcController.Setup(m => m.GetState()).Returns(new RpcModel + { + VpnLifecycle = VpnLifecycle.Starting, + }); + // Coder Connect is the user facing name, so make sure the error mentions it. + _mUserNotifier.Setup(m => m.ShowErrorNotification(It.IsAny(), It.IsRegex("Coder Connect"), ct)) + .Returns(Task.CompletedTask); + await uriHandler.HandleUri(input, ct); + } + + [Test(Description = "Wrong number of components")] + [CancelAfter(30_000)] + public async Task UnknownNumComponents(CancellationToken ct) + { + var input = new Uri("coder:/v0/open/ws/wrongworkspace/agent11/rdp"); + + _mRpcController.Setup(m => m.GetState()).Returns(modelWithWorkspace1); + _mUserNotifier.Setup(m => m.ShowErrorNotification(It.IsAny(), It.IsAny(), ct)) + .Returns(Task.CompletedTask); + await uriHandler.HandleUri(input, ct); + } + + [Test(Description = "Unknown prefix")] + [CancelAfter(30_000)] + public async Task UnknownPrefix(CancellationToken ct) + { + var input = new Uri("coder:/v300/open/ws/workspace1/agent/agent11/rdp"); + + _mRpcController.Setup(m => m.GetState()).Returns(modelWithWorkspace1); + _mUserNotifier.Setup(m => m.ShowErrorNotification(It.IsAny(), It.IsAny(), ct)) + .Returns(Task.CompletedTask); + await uriHandler.HandleUri(input, ct); + } +} diff --git a/Tests.App/Tests.App.csproj b/Tests.App/Tests.App.csproj index cc01512..e20eba1 100644 --- a/Tests.App/Tests.App.csproj +++ b/Tests.App/Tests.App.csproj @@ -26,6 +26,8 @@ runtime; build; native; contentfiles; analyzers; buildtransitive + + From 8c5a07fa4b082ba4fa4a1bd5d5634421bc787cc0 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 8 May 2025 11:09:00 +0400 Subject: [PATCH 2/3] apply review suggestions --- App/App.xaml.cs | 16 +++---- App/Services/RdpConnector.cs | 13 ++---- App/Services/UriHandler.cs | 64 +++++++++++++------------- Tests.App/Services/RdpConnectorTest.cs | 6 +-- Tests.App/Services/UriHandlerTest.cs | 10 +--- 5 files changed, 48 insertions(+), 61 deletions(-) diff --git a/App/App.xaml.cs b/App/App.xaml.cs index a984778..7cd252e 100644 --- a/App/App.xaml.cs +++ b/App/App.xaml.cs @@ -195,16 +195,16 @@ public void OnActivated(object? sender, AppActivationArguments args) return; } - try - { // don't need to wait for it to complete. - _ = _uriHandler.HandleUri(protoArgs.Uri); - } - catch (System.Exception e) + _uriHandler.HandleUri(protoArgs.Uri).ContinueWith(t => { - _logger.LogError(e, "unhandled exception while processing URI coder://{authority}{path}", - protoArgs.Uri.Authority, protoArgs.Uri.AbsolutePath); - } + if (t.Exception != null) + { + _logger.LogError(t.Exception, + "unhandled exception while processing URI coder://{authority}{path}", + protoArgs.Uri.Authority, protoArgs.Uri.AbsolutePath); + } + }); break; diff --git a/App/Services/RdpConnector.cs b/App/Services/RdpConnector.cs index 1692186..a48d0ac 100644 --- a/App/Services/RdpConnector.cs +++ b/App/Services/RdpConnector.cs @@ -12,11 +12,11 @@ public struct RdpCredentials(string username, string password) public readonly string Password = password; } -public interface IRdpConnector : IAsyncDisposable +public interface IRdpConnector { public const int DefaultPort = 3389; - public Task WriteCredentials(string fqdn, RdpCredentials credentials, CancellationToken ct = default); + public void WriteCredentials(string fqdn, RdpCredentials credentials); public Task Connect(string fqdn, int port = DefaultPort, CancellationToken ct = default); } @@ -26,13 +26,13 @@ public class RdpConnector(ILogger logger) : IRdpConnector // Remote Desktop always uses TERMSRV as the domain; RDP is a part of Windows "Terminal Services". private const string RdpDomain = "TERMSRV"; - public Task WriteCredentials(string fqdn, RdpCredentials credentials, CancellationToken ct = default) + public void WriteCredentials(string fqdn, RdpCredentials credentials) { // writing credentials is idempotent for the same domain and server name. Wincred.WriteDomainCredentials(RdpDomain, fqdn, credentials.Username, credentials.Password); logger.LogDebug("wrote domain credential for {serverName} with username {username}", fqdn, credentials.Username); - return Task.CompletedTask; + return; } public Task Connect(string fqdn, int port = IRdpConnector.DefaultPort, CancellationToken ct = default) @@ -73,9 +73,4 @@ public Task Connect(string fqdn, int port = IRdpConnector.DefaultPort, Cancellat return mstscProc.WaitForExitAsync(ct); } - - public ValueTask DisposeAsync() - { - return ValueTask.CompletedTask; - } } diff --git a/App/Services/UriHandler.cs b/App/Services/UriHandler.cs index eabcdff..bea1c63 100644 --- a/App/Services/UriHandler.cs +++ b/App/Services/UriHandler.cs @@ -11,7 +11,7 @@ namespace Coder.Desktop.App.Services; -public interface IUriHandler : IAsyncDisposable +public interface IUriHandler { public Task HandleUri(Uri uri, CancellationToken ct = default); } @@ -24,10 +24,16 @@ public class UriHandler( { private const string OpenWorkspacePrefix = "/v0/open/ws/"; - internal class UriException(string title, string detail) : Exception + internal class UriException : Exception { - internal readonly string Title = title; - internal readonly string Detail = detail; + internal readonly string Title; + internal readonly string Detail; + + internal UriException(string title, string detail) : base($"{title}: {detail}") + { + Title = title; + Detail = detail; + } } public async Task HandleUri(Uri uri, CancellationToken ct = default) @@ -52,7 +58,7 @@ private async Task HandleUriThrowingErrors(Uri uri, CancellationToken ct = defau logger.LogWarning("unhandled URI path {path}", uri.AbsolutePath); throw new UriException("URI handling error", - $"URI with path {uri.AbsolutePath} is unsupported or malformed"); + $"URI with path '{uri.AbsolutePath}' is unsupported or malformed"); } public async Task HandleOpenWorkspaceApp(Uri uri, CancellationToken ct = default) @@ -63,7 +69,7 @@ public async Task HandleOpenWorkspaceApp(Uri uri, CancellationToken ct = default if (components.Length != 4 || components[1] != "agent") { logger.LogWarning("unsupported open workspace app format in URI {path}", uri.AbsolutePath); - throw new UriException(errTitle, $"Failed to open {uri.AbsolutePath} because the format is unsupported."); + throw new UriException(errTitle, $"Failed to open '{uri.AbsolutePath}' because the format is unsupported."); } var workspaceName = components[0]; @@ -73,41 +79,38 @@ public async Task HandleOpenWorkspaceApp(Uri uri, CancellationToken ct = default var state = rpcController.GetState(); if (state.VpnLifecycle != VpnLifecycle.Started) { - logger.LogDebug("got URI to open workspace {workspace}, but Coder Connect is not started", workspaceName); + logger.LogDebug("got URI to open workspace '{workspace}', but Coder Connect is not started", workspaceName); throw new UriException(errTitle, - $"Failed to open application on {workspaceName} because Coder Connect is not started."); + $"Failed to open application on '{workspaceName}' because Coder Connect is not started."); } - Workspace workspace; - try - { - workspace = state.Workspaces.Single(w => w.Name == workspaceName); - } - catch (InvalidOperationException) // Single() throws this when nothing matches. - { - logger.LogDebug("got URI to open workspace {workspace}, but the workspace doesn't exist", workspaceName); + var workspace = state.Workspaces.FirstOrDefault(w => w.Name == workspaceName); + if (workspace == null) { + logger.LogDebug("got URI to open workspace '{workspace}', but the workspace doesn't exist", workspaceName); throw new UriException(errTitle, - $"Failed to open application on workspace {workspaceName} because it doesn't exist"); + $"Failed to open application on workspace '{workspaceName}' because it doesn't exist"); } - Agent agent; - try - { - agent = state.Agents.Single(a => a.WorkspaceId == workspace.Id && a.Name == agentName); - } - catch (InvalidOperationException) // Single() throws this when nothing matches. - { - logger.LogDebug("got URI to open workspace/agent {workspaceName}/{agentName}, but the agent doesn't exist", + var agent = state.Agents.FirstOrDefault(a => a.WorkspaceId == workspace.Id && a.Name == agentName); + if (agent == null) { + logger.LogDebug("got URI to open workspace/agent '{workspaceName}/{agentName}', but the agent doesn't exist", workspaceName, agentName); + // If the workspace isn't running, that is almost certainly why we can't find the agent, so report that + // to the user. + if (workspace.Status != Workspace.Types.Status.Running) + { + throw new UriException(errTitle, + $"Failed to open application on workspace '{workspaceName}', because the workspace is not running."); + } throw new UriException(errTitle, - $"Failed to open application on workspace {workspaceName}, agent {agentName} because it doesn't exist."); + $"Failed to open application on workspace '{workspaceName}', because agent '{agentName}' doesn't exist."); } if (appName != "rdp") { logger.LogWarning("unsupported agent application type {app}", appName); throw new UriException(errTitle, - $"Failed to open agent in URI {uri.AbsolutePath} because application {appName} is unsupported"); + $"Failed to open agent in URI '{uri.AbsolutePath}' because application '{appName}' is unsupported"); } await OpenRDP(agent.Fqdn.First(), uri.Query, ct); @@ -137,14 +140,9 @@ public async Task OpenRDP(string domainName, string queryString, CancellationTok if (!string.IsNullOrEmpty(username)) { password ??= string.Empty; - await rdpConnector.WriteCredentials(domainName, new RdpCredentials(username, password), ct); + rdpConnector.WriteCredentials(domainName, new RdpCredentials(username, password)); } await rdpConnector.Connect(domainName, ct: ct); } - - public ValueTask DisposeAsync() - { - return ValueTask.CompletedTask; - } } diff --git a/Tests.App/Services/RdpConnectorTest.cs b/Tests.App/Services/RdpConnectorTest.cs index 87bc59d..b4a870e 100644 --- a/Tests.App/Services/RdpConnectorTest.cs +++ b/Tests.App/Services/RdpConnectorTest.cs @@ -11,7 +11,7 @@ public class RdpConnectorTest [Test(Description = "Spawns RDP for real")] [Ignore("Comment out to run manually")] [CancelAfter(30_000)] - public async Task ConnectToRdp() + public async Task ConnectToRdp(CancellationToken ct) { var builder = Host.CreateApplicationBuilder(); builder.Services.AddSerilog(); @@ -21,7 +21,7 @@ public async Task ConnectToRdp() var rdpConnector = (RdpConnector)services.GetService()!; var creds = new RdpCredentials("Administrator", "coderRDP!"); var workspace = "myworkspace.coder"; - await rdpConnector.WriteCredentials(workspace, creds); - await rdpConnector.Connect(workspace); + rdpConnector.WriteCredentials(workspace, creds); + await rdpConnector.Connect(workspace, ct: ct); } } diff --git a/Tests.App/Services/UriHandlerTest.cs b/Tests.App/Services/UriHandlerTest.cs index 9b24f02..65c886c 100644 --- a/Tests.App/Services/UriHandlerTest.cs +++ b/Tests.App/Services/UriHandlerTest.cs @@ -27,12 +27,6 @@ public void SetupMocksAndUriHandler() uriHandler = new UriHandler(logger, _mRpcController.Object, _mUserNotifier.Object, _mRdpConnector.Object); } - [TearDown] - public async Task CleanupUriHandler() - { - await uriHandler.DisposeAsync(); - } - private Mock _mUserNotifier; private Mock _mRdpConnector; private Mock _mRpcController; @@ -51,6 +45,7 @@ public void AgentAndWorkspaceFixtures() { Id = ByteString.CopyFrom(0x1, 0x0), Name = "workspace1", + Status = Workspace.Types.Status.Running, }; modelWithWorkspace1 = new RpcModel @@ -73,8 +68,7 @@ public async Task Mainline(CancellationToken ct) _mRpcController.Setup(m => m.GetState()).Returns(modelWithWorkspace1); var expectedCred = new RdpCredentials("testy", "sesame"); - _ = _mRdpConnector.Setup(m => m.WriteCredentials(agent11.Fqdn[0], expectedCred, ct)) - .Returns(Task.CompletedTask); + _ = _mRdpConnector.Setup(m => m.WriteCredentials(agent11.Fqdn[0], expectedCred)); _ = _mRdpConnector.Setup(m => m.Connect(agent11.Fqdn[0], IRdpConnector.DefaultPort, ct)) .Returns(Task.CompletedTask); await uriHandler.HandleUri(input, ct); From fe07b332ac8ce325e3eefac8f7c52c76a15d0068 Mon Sep 17 00:00:00 2001 From: Spike Curtis Date: Thu, 8 May 2025 13:47:37 +0400 Subject: [PATCH 3/3] more review suggestions, fmt --- App/App.xaml.cs | 3 ++- App/Services/UriHandler.cs | 10 +++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/App/App.xaml.cs b/App/App.xaml.cs index 7cd252e..ba6fa67 100644 --- a/App/App.xaml.cs +++ b/App/App.xaml.cs @@ -195,11 +195,12 @@ public void OnActivated(object? sender, AppActivationArguments args) return; } - // don't need to wait for it to complete. + // don't need to wait for it to complete. _uriHandler.HandleUri(protoArgs.Uri).ContinueWith(t => { if (t.Exception != null) { + // don't log query params, as they contain secrets. _logger.LogError(t.Exception, "unhandled exception while processing URI coder://{authority}{path}", protoArgs.Uri.Authority, protoArgs.Uri.AbsolutePath); diff --git a/App/Services/UriHandler.cs b/App/Services/UriHandler.cs index bea1c63..b0b0a9a 100644 --- a/App/Services/UriHandler.cs +++ b/App/Services/UriHandler.cs @@ -85,15 +85,18 @@ public async Task HandleOpenWorkspaceApp(Uri uri, CancellationToken ct = default } var workspace = state.Workspaces.FirstOrDefault(w => w.Name == workspaceName); - if (workspace == null) { + if (workspace == null) + { logger.LogDebug("got URI to open workspace '{workspace}', but the workspace doesn't exist", workspaceName); throw new UriException(errTitle, $"Failed to open application on workspace '{workspaceName}' because it doesn't exist"); } var agent = state.Agents.FirstOrDefault(a => a.WorkspaceId == workspace.Id && a.Name == agentName); - if (agent == null) { - logger.LogDebug("got URI to open workspace/agent '{workspaceName}/{agentName}', but the agent doesn't exist", + if (agent == null) + { + logger.LogDebug( + "got URI to open workspace/agent '{workspaceName}/{agentName}', but the agent doesn't exist", workspaceName, agentName); // If the workspace isn't running, that is almost certainly why we can't find the agent, so report that // to the user. @@ -102,6 +105,7 @@ public async Task HandleOpenWorkspaceApp(Uri uri, CancellationToken ct = default throw new UriException(errTitle, $"Failed to open application on workspace '{workspaceName}', because the workspace is not running."); } + throw new UriException(errTitle, $"Failed to open application on workspace '{workspaceName}', because agent '{agentName}' doesn't exist."); }