diff --git a/Coder-Desktop/Coder-Desktop/Preview Content/PreviewVPN.swift b/Coder-Desktop/Coder-Desktop/Preview Content/PreviewVPN.swift index 2c6e8d0..4d4e9f9 100644 --- a/Coder-Desktop/Coder-Desktop/Preview Content/PreviewVPN.swift +++ b/Coder-Desktop/Coder-Desktop/Preview Content/PreviewVPN.swift @@ -33,6 +33,8 @@ final class PreviewVPN: Coder_Desktop.VPNService { self.shouldFail = shouldFail } + @Published var progress: VPNProgress = .init(stage: .initial, downloadProgress: nil) + var startTask: Task? func start() async { if await startTask?.value != nil { diff --git a/Coder-Desktop/Coder-Desktop/VPN/VPNProgress.swift b/Coder-Desktop/Coder-Desktop/VPN/VPNProgress.swift new file mode 100644 index 0000000..1d2aa6b --- /dev/null +++ b/Coder-Desktop/Coder-Desktop/VPN/VPNProgress.swift @@ -0,0 +1,68 @@ +import SwiftUI +import VPNLib + +struct VPNProgress { + let stage: ProgressStage + let downloadProgress: DownloadProgress? +} + +struct VPNProgressView: View { + let state: VPNServiceState + let progress: VPNProgress + + var body: some View { + VStack { + CircularProgressView(value: value) + // We'll estimate that the last 25% takes 9 seconds + // so it doesn't appear stuck + .autoComplete(threshold: 0.75, duration: 9) + Text(progressMessage) + .multilineTextAlignment(.center) + } + .padding() + .progressViewStyle(.circular) + .foregroundStyle(.secondary) + } + + var progressMessage: String { + "\(progress.stage.description ?? defaultMessage)\(downloadProgressMessage)" + } + + var downloadProgressMessage: String { + progress.downloadProgress.flatMap { "\n\($0.description)" } ?? "" + } + + var defaultMessage: String { + state == .connecting ? "Starting Coder Connect..." : "Stopping Coder Connect..." + } + + var value: Float? { + guard state == .connecting else { + return nil + } + switch progress.stage { + case .initial: + return 0.10 + case .downloading: + guard let downloadProgress = progress.downloadProgress else { + // We can't make this illegal state unrepresentable because XPC + // doesn't support enums with associated values. + return 0.05 + } + // 40MB if the server doesn't give us the expected size + let totalBytes = downloadProgress.totalBytesToWrite ?? 40_000_000 + let downloadPercent = min(1.0, Float(downloadProgress.totalBytesWritten) / Float(totalBytes)) + return 0.10 + 0.6 * downloadPercent + case .validating: + return 0.71 + case .removingQuarantine: + return 0.72 + case .opening: + return 0.73 + case .settingUpTunnel: + return 0.74 + case .startingTunnel: + return 0.75 + } + } +} diff --git a/Coder-Desktop/Coder-Desktop/VPN/VPNService.swift b/Coder-Desktop/Coder-Desktop/VPN/VPNService.swift index c3c1773..224174a 100644 --- a/Coder-Desktop/Coder-Desktop/VPN/VPNService.swift +++ b/Coder-Desktop/Coder-Desktop/VPN/VPNService.swift @@ -7,6 +7,7 @@ import VPNLib protocol VPNService: ObservableObject { var state: VPNServiceState { get } var menuState: VPNMenuState { get } + var progress: VPNProgress { get } func start() async func stop() async func configureTunnelProviderProtocol(proto: NETunnelProviderProtocol?) @@ -55,7 +56,14 @@ final class CoderVPNService: NSObject, VPNService { var logger = Logger(subsystem: Bundle.main.bundleIdentifier!, category: "vpn") lazy var xpc: VPNXPCInterface = .init(vpn: self) - @Published var tunnelState: VPNServiceState = .disabled + @Published var tunnelState: VPNServiceState = .disabled { + didSet { + if tunnelState == .connecting { + progress = .init(stage: .initial, downloadProgress: nil) + } + } + } + @Published var sysExtnState: SystemExtensionState = .uninstalled @Published var neState: NetworkExtensionState = .unconfigured var state: VPNServiceState { @@ -72,6 +80,8 @@ final class CoderVPNService: NSObject, VPNService { return tunnelState } + @Published var progress: VPNProgress = .init(stage: .initial, downloadProgress: nil) + @Published var menuState: VPNMenuState = .init() // Whether the VPN should start as soon as possible @@ -155,6 +165,10 @@ final class CoderVPNService: NSObject, VPNService { } } + func onProgress(stage: ProgressStage, downloadProgress: DownloadProgress?) { + progress = .init(stage: stage, downloadProgress: downloadProgress) + } + func applyPeerUpdate(with update: Vpn_PeerUpdate) { // Delete agents update.deletedAgents.forEach { menuState.deleteAgent(withId: $0.id) } diff --git a/Coder-Desktop/Coder-Desktop/Views/CircularProgressView.swift b/Coder-Desktop/Coder-Desktop/Views/CircularProgressView.swift new file mode 100644 index 0000000..fc359e8 --- /dev/null +++ b/Coder-Desktop/Coder-Desktop/Views/CircularProgressView.swift @@ -0,0 +1,80 @@ +import SwiftUI + +struct CircularProgressView: View { + let value: Float? + + var strokeWidth: CGFloat = 4 + var diameter: CGFloat = 22 + var primaryColor: Color = .secondary + var backgroundColor: Color = .secondary.opacity(0.3) + + @State private var rotation = 0.0 + @State private var trimAmount: CGFloat = 0.15 + + var autoCompleteThreshold: Float? + var autoCompleteDuration: TimeInterval? + + var body: some View { + ZStack { + // Background circle + Circle() + .stroke(backgroundColor, style: StrokeStyle(lineWidth: strokeWidth, lineCap: .round)) + .frame(width: diameter, height: diameter) + Group { + if let value { + // Determinate gauge + Circle() + .trim(from: 0, to: CGFloat(displayValue(for: value))) + .stroke(primaryColor, style: StrokeStyle(lineWidth: strokeWidth, lineCap: .round)) + .frame(width: diameter, height: diameter) + .rotationEffect(.degrees(-90)) + .animation(autoCompleteAnimation(for: value), value: value) + } else { + // Indeterminate gauge + Circle() + .trim(from: 0, to: trimAmount) + .stroke(primaryColor, style: StrokeStyle(lineWidth: strokeWidth, lineCap: .round)) + .frame(width: diameter, height: diameter) + .rotationEffect(.degrees(rotation)) + } + } + } + .frame(width: diameter + strokeWidth * 2, height: diameter + strokeWidth * 2) + .onAppear { + if value == nil { + withAnimation(.linear(duration: 0.8).repeatForever(autoreverses: false)) { + rotation = 360 + } + } + } + } + + private func displayValue(for value: Float) -> Float { + if let threshold = autoCompleteThreshold, + value >= threshold, value < 1.0 + { + return 1.0 + } + return value + } + + private func autoCompleteAnimation(for value: Float) -> Animation? { + guard let threshold = autoCompleteThreshold, + let duration = autoCompleteDuration, + value >= threshold, value < 1.0 + else { + return .default + } + + return .easeOut(duration: duration) + } +} + +extension CircularProgressView { + func autoComplete(threshold: Float, duration: TimeInterval) -> CircularProgressView { + var view = self + view.autoCompleteThreshold = threshold + view.autoCompleteDuration = duration + return view + } +} diff --git a/Coder-Desktop/Coder-Desktop/Views/VPN/VPNState.swift b/Coder-Desktop/Coder-Desktop/Views/VPN/VPNState.swift index 2331902..e2aa1d8 100644 --- a/Coder-Desktop/Coder-Desktop/Views/VPN/VPNState.swift +++ b/Coder-Desktop/Coder-Desktop/Views/VPN/VPNState.swift @@ -28,9 +28,7 @@ struct VPNState: View { case (.connecting, _), (.disconnecting, _): HStack { Spacer() - ProgressView( - vpn.state == .connecting ? "Starting Coder Connect..." : "Stopping Coder Connect..." - ).padding() + VPNProgressView(state: vpn.state, progress: vpn.progress) Spacer() } case let (.failed(vpnErr), _): diff --git a/Coder-Desktop/Coder-Desktop/XPCInterface.swift b/Coder-Desktop/Coder-Desktop/XPCInterface.swift index e21be86..e6c78d6 100644 --- a/Coder-Desktop/Coder-Desktop/XPCInterface.swift +++ b/Coder-Desktop/Coder-Desktop/XPCInterface.swift @@ -71,6 +71,12 @@ import VPNLib } } + func onProgress(stage: ProgressStage, downloadProgress: DownloadProgress?) { + Task { @MainActor in + svc.onProgress(stage: stage, downloadProgress: downloadProgress) + } + } + // The NE has verified the dylib and knows better than Gatekeeper func removeQuarantine(path: String, reply: @escaping (Bool) -> Void) { let reply = CallbackWrapper(reply) diff --git a/Coder-Desktop/Coder-DesktopTests/Util.swift b/Coder-Desktop/Coder-DesktopTests/Util.swift index 6c7bc20..6075127 100644 --- a/Coder-Desktop/Coder-DesktopTests/Util.swift +++ b/Coder-Desktop/Coder-DesktopTests/Util.swift @@ -10,6 +10,7 @@ class MockVPNService: VPNService, ObservableObject { @Published var state: Coder_Desktop.VPNServiceState = .disabled @Published var baseAccessURL: URL = .init(string: "https://dev.coder.com")! @Published var menuState: VPNMenuState = .init() + @Published var progress: VPNProgress = .init(stage: .initial, downloadProgress: nil) var onStart: (() async -> Void)? var onStop: (() async -> Void)? diff --git a/Coder-Desktop/Coder-DesktopTests/VPNStateTests.swift b/Coder-Desktop/Coder-DesktopTests/VPNStateTests.swift index 92827cf..abad6ab 100644 --- a/Coder-Desktop/Coder-DesktopTests/VPNStateTests.swift +++ b/Coder-Desktop/Coder-DesktopTests/VPNStateTests.swift @@ -38,8 +38,7 @@ struct VPNStateTests { try await ViewHosting.host(view) { try await sut.inspection.inspect { view in - let progressView = try view.find(ViewType.ProgressView.self) - #expect(try progressView.labelView().text().string() == "Starting Coder Connect...") + _ = try view.find(text: "Starting Coder Connect...") } } } @@ -50,8 +49,7 @@ struct VPNStateTests { try await ViewHosting.host(view) { try await sut.inspection.inspect { view in - let progressView = try view.find(ViewType.ProgressView.self) - #expect(try progressView.labelView().text().string() == "Stopping Coder Connect...") + _ = try view.find(text: "Stopping Coder Connect...") } } } diff --git a/Coder-Desktop/VPN/Manager.swift b/Coder-Desktop/VPN/Manager.swift index bc441ac..4559e4d 100644 --- a/Coder-Desktop/VPN/Manager.swift +++ b/Coder-Desktop/VPN/Manager.swift @@ -35,10 +35,18 @@ actor Manager { // Timeout after 5 minutes, or if there's no data for 60 seconds sessionConfig.timeoutIntervalForRequest = 60 sessionConfig.timeoutIntervalForResource = 300 - try await download(src: dylibPath, dest: dest, urlSession: URLSession(configuration: sessionConfig)) + try await download( + src: dylibPath, + dest: dest, + urlSession: URLSession(configuration: sessionConfig) + ) { progress in + // TODO: Debounce, somehow + pushProgress(stage: .downloading, downloadProgress: progress) + } } catch { throw .download(error) } + pushProgress(stage: .validating) let client = Client(url: cfg.serverUrl) let buildInfo: BuildInfoResponse do { @@ -59,11 +67,13 @@ actor Manager { // so it's safe to execute. However, the SE must be sandboxed, so we defer to the app. try await removeQuarantine(dest) + pushProgress(stage: .opening) do { try tunnelHandle = TunnelHandle(dylibPath: dest) } catch { throw .tunnelSetup(error) } + pushProgress(stage: .settingUpTunnel) speaker = await Speaker( writeFD: tunnelHandle.writeHandle, readFD: tunnelHandle.readHandle @@ -158,6 +168,7 @@ actor Manager { } func startVPN() async throws(ManagerError) { + pushProgress(stage: .startingTunnel) logger.info("sending start rpc") guard let tunFd = ptp.tunnelFileDescriptor else { logger.error("no fd") @@ -234,6 +245,15 @@ actor Manager { } } +func pushProgress(stage: ProgressStage, downloadProgress: DownloadProgress? = nil) { + guard let conn = globalXPCListenerDelegate.conn else { + logger.warning("couldn't send progress message to app: no connection") + return + } + logger.debug("sending progress message to app") + conn.onProgress(stage: stage, downloadProgress: downloadProgress) +} + struct ManagerConfig { let apiToken: String let serverUrl: URL @@ -312,6 +332,7 @@ private func removeQuarantine(_ dest: URL) async throws(ManagerError) { let file = NSURL(fileURLWithPath: dest.path) try? file.getResourceValue(&flag, forKey: kCFURLQuarantinePropertiesKey as URLResourceKey) if flag != nil { + pushProgress(stage: .removingQuarantine) // Try the privileged helper first (it may not even be registered) if await globalHelperXPCSpeaker.tryRemoveQuarantine(path: dest.path) { // Success! diff --git a/Coder-Desktop/VPNLib/Download.swift b/Coder-Desktop/VPNLib/Download.swift index 559be37..d102209 100644 --- a/Coder-Desktop/VPNLib/Download.swift +++ b/Coder-Desktop/VPNLib/Download.swift @@ -125,47 +125,13 @@ public class SignatureValidator { } } -public func download(src: URL, dest: URL, urlSession: URLSession) async throws(DownloadError) { - var req = URLRequest(url: src) - if FileManager.default.fileExists(atPath: dest.path) { - if let existingFileData = try? Data(contentsOf: dest, options: .mappedIfSafe) { - req.setValue(etag(data: existingFileData), forHTTPHeaderField: "If-None-Match") - } - } - // TODO: Add Content-Length headers to coderd, add download progress delegate - let tempURL: URL - let response: URLResponse - do { - (tempURL, response) = try await urlSession.download(for: req) - } catch { - throw .networkError(error, url: src.absoluteString) - } - defer { - if FileManager.default.fileExists(atPath: tempURL.path) { - try? FileManager.default.removeItem(at: tempURL) - } - } - - guard let httpResponse = response as? HTTPURLResponse else { - throw .invalidResponse - } - guard httpResponse.statusCode != 304 else { - // We already have the latest dylib downloaded on disk - return - } - - guard httpResponse.statusCode == 200 else { - throw .unexpectedStatusCode(httpResponse.statusCode) - } - - do { - if FileManager.default.fileExists(atPath: dest.path) { - try FileManager.default.removeItem(at: dest) - } - try FileManager.default.moveItem(at: tempURL, to: dest) - } catch { - throw .fileOpError(error) - } +public func download( + src: URL, + dest: URL, + urlSession: URLSession, + progressUpdates: ((DownloadProgress) -> Void)? = nil +) async throws(DownloadError) { + try await DownloadManager().download(src: src, dest: dest, urlSession: urlSession, progressUpdates: progressUpdates) } func etag(data: Data) -> String { @@ -195,3 +161,124 @@ public enum DownloadError: Error { public var localizedDescription: String { description } } + +// The async `URLSession.download` api ignores the passed-in delegate, so we +// wrap the older delegate methods in an async adapter with a continuation. +private final class DownloadManager: NSObject, @unchecked Sendable { + private var continuation: CheckedContinuation! + private var progressHandler: ((DownloadProgress) -> Void)? + private var dest: URL! + + func download( + src: URL, + dest: URL, + urlSession: URLSession, + progressUpdates: ((DownloadProgress) -> Void)? + ) async throws(DownloadError) { + var req = URLRequest(url: src) + if FileManager.default.fileExists(atPath: dest.path) { + if let existingFileData = try? Data(contentsOf: dest, options: .mappedIfSafe) { + req.setValue(etag(data: existingFileData), forHTTPHeaderField: "If-None-Match") + } + } + + let downloadTask = urlSession.downloadTask(with: req) + progressHandler = progressUpdates + self.dest = dest + downloadTask.delegate = self + do { + try await withCheckedThrowingContinuation { continuation in + self.continuation = continuation + downloadTask.resume() + } + } catch let error as DownloadError { + throw error + } catch { + throw .networkError(error, url: src.absoluteString) + } + } +} + +extension DownloadManager: URLSessionDownloadDelegate { + // Progress + func urlSession( + _: URLSession, + downloadTask: URLSessionDownloadTask, + didWriteData _: Int64, + totalBytesWritten: Int64, + totalBytesExpectedToWrite _: Int64 + ) { + let maybeLength = (downloadTask.response as? HTTPURLResponse)? + .value(forHTTPHeaderField: "X-Original-Content-Length") + .flatMap(Int64.init) + progressHandler?(.init(totalBytesWritten: totalBytesWritten, totalBytesToWrite: maybeLength)) + } + + // Completion + func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didFinishDownloadingTo location: URL) { + guard let httpResponse = downloadTask.response as? HTTPURLResponse else { + continuation.resume(throwing: DownloadError.invalidResponse) + return + } + guard httpResponse.statusCode != 304 else { + // We already have the latest dylib downloaded in dest + continuation.resume() + return + } + + guard httpResponse.statusCode == 200 else { + continuation.resume(throwing: DownloadError.unexpectedStatusCode(httpResponse.statusCode)) + return + } + + do { + if FileManager.default.fileExists(atPath: dest.path) { + try FileManager.default.removeItem(at: dest) + } + try FileManager.default.moveItem(at: location, to: dest) + } catch { + continuation.resume(throwing: DownloadError.fileOpError(error)) + } + + continuation.resume() + } + + // Failure + func urlSession(_: URLSession, task _: URLSessionTask, didCompleteWithError error: Error?) { + if let error { + continuation.resume(throwing: error) + } + } +} + +@objc public final class DownloadProgress: NSObject, NSSecureCoding, @unchecked Sendable { + public static var supportsSecureCoding: Bool { true } + + public let totalBytesWritten: Int64 + public let totalBytesToWrite: Int64? + + public init(totalBytesWritten: Int64, totalBytesToWrite: Int64?) { + self.totalBytesWritten = totalBytesWritten + self.totalBytesToWrite = totalBytesToWrite + } + + public required convenience init?(coder: NSCoder) { + let written = coder.decodeInt64(forKey: "written") + let total = coder.containsValue(forKey: "total") ? coder.decodeInt64(forKey: "total") : nil + self.init(totalBytesWritten: written, totalBytesToWrite: total) + } + + public func encode(with coder: NSCoder) { + coder.encode(totalBytesWritten, forKey: "written") + if let total = totalBytesToWrite { + coder.encode(total, forKey: "total") + } + } + + override public var description: String { + let fmt = ByteCountFormatter() + let done = fmt.string(fromByteCount: totalBytesWritten) + let tot = totalBytesToWrite.map { fmt.string(fromByteCount: $0) } ?? "Unknown" + return "\(done) / \(tot)" + } +} diff --git a/Coder-Desktop/VPNLib/XPC.swift b/Coder-Desktop/VPNLib/XPC.swift index dc79651..9464ea2 100644 --- a/Coder-Desktop/VPNLib/XPC.swift +++ b/Coder-Desktop/VPNLib/XPC.swift @@ -10,5 +10,35 @@ import Foundation @objc public protocol VPNXPCClientCallbackProtocol { // data is a serialized `Vpn_PeerUpdate` func onPeerUpdate(_ data: Data) + func onProgress(stage: ProgressStage, downloadProgress: DownloadProgress?) func removeQuarantine(path: String, reply: @escaping (Bool) -> Void) } + +@objc public enum ProgressStage: Int, Sendable { + case initial + case downloading + case validating + case removingQuarantine + case opening + case settingUpTunnel + case startingTunnel + + public var description: String? { + switch self { + case .initial: + nil + case .downloading: + "Downloading library..." + case .validating: + "Validating library..." + case .removingQuarantine: + "Removing quarantine..." + case .opening: + "Opening library..." + case .settingUpTunnel: + "Setting up tunnel..." + case .startingTunnel: + nil + } + } +}