diff --git a/src/client/sse.ts b/src/client/sse.ts index 5e9f0cf0..c152aa99 100644 --- a/src/client/sse.ts +++ b/src/client/sse.ts @@ -62,9 +62,33 @@ export class SSEClientTransport implements Transport { private _requestInit?: RequestInit; private _authProvider?: OAuthClientProvider; - onclose?: () => void; - onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage) => void; + protected _onclose?: Transport['onclose']; + protected _onerror?: Transport['onerror']; + protected _onmessage?: (message: JSONRPCMessage) => void; + + get onmessage() { + return this._onmessage; + } + + set onmessage(onmessage: SSEClientTransport['_onmessage']) { + this._onmessage = onmessage; + } + + set onerror(onerror: SSEClientTransport['_onerror']) { + this._onerror = onerror; + } + + get onerror() { + return this._onerror; + } + + set onclose(onclose: SSEClientTransport['_onclose']) { + this._onclose = onclose; + } + + get onclose() { + return this._onclose; + } constructor( url: URL, diff --git a/src/client/stdio.ts b/src/client/stdio.ts index 9e35293d..4bb74f11 100644 --- a/src/client/stdio.ts +++ b/src/client/stdio.ts @@ -95,9 +95,34 @@ export class StdioClientTransport implements Transport { private _serverParams: StdioServerParameters; private _stderrStream: PassThrough | null = null; - onclose?: () => void; - onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage) => void; + protected _onclose?: Transport['onclose']; + protected _onerror?: Transport['onerror']; + protected _onmessage?: (message: JSONRPCMessage) => void; + + get onmessage() { + return this._onmessage; + } + + set onmessage(onmessage: StdioClientTransport['_onmessage']) { + this._onmessage = onmessage; + } + + set onerror(onerror: StdioClientTransport['_onerror']) { + this._onerror = onerror; + } + + get onerror() { + return this._onerror; + } + + set onclose(onclose: StdioClientTransport['_onclose']) { + this._onclose = onclose; + } + + get onclose() { + return this._onclose; + } + constructor(server: StdioServerParameters) { this._serverParams = server; diff --git a/src/client/streamableHttp.ts b/src/client/streamableHttp.ts index 3462b2ab..f93503d2 100644 --- a/src/client/streamableHttp.ts +++ b/src/client/streamableHttp.ts @@ -124,9 +124,33 @@ export class StreamableHTTPClientTransport implements Transport { private _sessionId?: string; private _reconnectionOptions: StreamableHTTPReconnectionOptions; - onclose?: () => void; - onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage) => void; + protected _onclose?: Transport['onclose']; + protected _onerror?: Transport['onerror']; + protected _onmessage?: (message: JSONRPCMessage) => void; + + get onmessage() { + return this._onmessage; + } + + set onmessage(onmessage: StreamableHTTPClientTransport['_onmessage']) { + this._onmessage = onmessage; + } + + set onerror(onerror: StreamableHTTPClientTransport['_onerror']) { + this._onerror = onerror; + } + + get onerror() { + return this._onerror; + } + + set onclose(onclose: StreamableHTTPClientTransport['_onclose']) { + this._onclose = onclose; + } + + get onclose() { + return this._onclose; + } constructor( url: URL, diff --git a/src/client/websocket.ts b/src/client/websocket.ts index 3ca76082..2d5dbbcd 100644 --- a/src/client/websocket.ts +++ b/src/client/websocket.ts @@ -10,9 +10,33 @@ export class WebSocketClientTransport implements Transport { private _socket?: WebSocket; private _url: URL; - onclose?: () => void; - onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage) => void; + protected _onclose?: Transport['onclose']; + protected _onerror?: Transport['onerror']; + protected _onmessage?: (message: JSONRPCMessage) => void; + + get onmessage() { + return this._onmessage; + } + + set onmessage(onmessage: WebSocketClientTransport['_onmessage']) { + this._onmessage = onmessage; + } + + set onerror(onerror: WebSocketClientTransport['_onerror']) { + this._onerror = onerror; + } + + get onerror() { + return this._onerror; + } + + set onclose(onclose: WebSocketClientTransport['_onclose']) { + this._onclose = onclose; + } + + get onclose() { + return this._onclose; + } constructor(url: URL) { this._url = url; diff --git a/src/inMemory.ts b/src/inMemory.ts index 5dd6e81e..32d865a7 100644 --- a/src/inMemory.ts +++ b/src/inMemory.ts @@ -14,10 +14,43 @@ export class InMemoryTransport implements Transport { private _otherTransport?: InMemoryTransport; private _messageQueue: QueuedMessage[] = []; - onclose?: () => void; - onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void; - sessionId?: string; + protected _onclose?: Transport['onclose']; + protected _onerror?: Transport['onerror']; + protected _onmessage?: Transport['onmessage']; + protected _sessionId?: Transport['sessionId']; + + get onmessage() { + return this._onmessage; + } + + set onmessage(onmessage: InMemoryTransport['_onmessage']) { + this._onmessage = onmessage; + } + + set onerror(onerror: InMemoryTransport['_onerror']) { + this._onerror = onerror; + } + + get onerror() { + return this._onerror; + } + + set onclose(onclose: InMemoryTransport['_onclose']) { + this._onclose = onclose; + } + + get onclose() { + return this._onclose; + } + + set sessionId(sessionId: InMemoryTransport['_sessionId']) { + this._sessionId = sessionId; + } + + get sessionId() { + return this._sessionId; + } + /** * Creates a pair of linked in-memory transports that can communicate with each other. One should be passed to a Client and one to a Server. diff --git a/src/server/sse.ts b/src/server/sse.ts index 03f6fefc..055c9e7b 100644 --- a/src/server/sse.ts +++ b/src/server/sse.ts @@ -18,9 +18,33 @@ export class SSEServerTransport implements Transport { private _sseResponse?: ServerResponse; private _sessionId: string; - onclose?: () => void; - onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage, extra?: { authInfo?: AuthInfo }) => void; + protected _onclose?: Transport['onclose']; + protected _onerror?: Transport['onerror']; + protected _onmessage?: Transport['onmessage'];; + + get onmessage() { + return this._onmessage; + } + + set onmessage(onmessage: SSEServerTransport['_onmessage']) { + this._onmessage = onmessage; + } + + set onerror(onerror: SSEServerTransport['_onerror']) { + this._onerror = onerror; + } + + get onerror() { + return this._onerror; + } + + set onclose(onclose: SSEServerTransport['_onclose']) { + this._onclose = onclose; + } + + get onclose() { + return this._onclose; + } /** * Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL identified by `_endpoint`. diff --git a/src/server/stdio.ts b/src/server/stdio.ts index 30c80012..6339272f 100644 --- a/src/server/stdio.ts +++ b/src/server/stdio.ts @@ -18,18 +18,42 @@ export class StdioServerTransport implements Transport { private _stdout: Writable = process.stdout, ) {} - onclose?: () => void; - onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage) => void; + protected _onclose?: Transport['onclose']; + protected _onerror: NonNullable = () => {}; + protected _onmessage?: (message: JSONRPCMessage) => void; + + get onmessage() { + return this._onmessage; + } + + set onmessage(onmessage: StdioServerTransport['_onmessage']) { + this._onmessage = onmessage; + } + + set onerror(onerror: StdioServerTransport['_onerror']) { + this._onerror = onerror; + } + + get onerror() { + return this._onerror; + } + + set onclose(onclose: StdioServerTransport['_onclose']) { + this._onclose = onclose; + } + + get onclose() { + return this._onclose; + } // Arrow functions to bind `this` properly, while maintaining function identity. _ondata = (chunk: Buffer) => { this._readBuffer.append(chunk); this.processReadBuffer(); }; - _onerror = (error: Error) => { - this.onerror?.(error); - }; + // _onerror = (error: Error) => { + // this.onerror?.(error); + // }; /** * Starts listening for messages on stdin. diff --git a/src/server/streamableHttp.ts b/src/server/streamableHttp.ts index 7112c52a..62d41fb7 100644 --- a/src/server/streamableHttp.ts +++ b/src/server/streamableHttp.ts @@ -109,10 +109,43 @@ export class StreamableHTTPServerTransport implements Transport { private _eventStore?: EventStore; private _onsessioninitialized?: (sessionId: string) => void; - sessionId?: string | undefined; - onclose?: () => void; - onerror?: (error: Error) => void; - onmessage?: (message: JSONRPCMessage) => void; + protected _onclose?: Transport['onclose']; + protected _onerror?: Transport['onerror']; + protected _onmessage?: (message: JSONRPCMessage) => void; + protected _sessionId?: Transport['sessionId']; + + get onmessage() { + return this._onmessage; + } + + set onmessage(onmessage: StreamableHTTPServerTransport['_onmessage']) { + this._onmessage = onmessage; + } + + set onerror(onerror: StreamableHTTPServerTransport['_onerror']) { + this._onerror = onerror; + } + + get onerror() { + return this._onerror; + } + + set onclose(onclose: StreamableHTTPServerTransport['_onclose']) { + this._onclose = onclose; + } + + get onclose() { + return this._onclose; + } + + set sessionId(sessionId: StreamableHTTPServerTransport['_sessionId']) { + this._sessionId = sessionId; + } + + get sessionId() { + return this._sessionId; + } + constructor(options: StreamableHTTPServerTransportOptions) { this.sessionIdGenerator = options.sessionIdGenerator; diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index fb5ecd13..2a79b223 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -2,6 +2,7 @@ import { ZodType, z } from "zod"; import { ClientCapabilities, ErrorCode, + JSONRPCMessage, McpError, Notification, Request, @@ -10,6 +11,7 @@ import { } from "../types.js"; import { Protocol, mergeCapabilities } from "./protocol.js"; import { Transport } from "./transport.js"; +import { StreamableHTTPClientTransport } from "../client/streamableHttp.js"; // Mock Transport class class MockTransport implements Transport { @@ -24,6 +26,42 @@ class MockTransport implements Transport { async send(_message: unknown): Promise {} } +// Mock Extends Transport +class MockExtendsTransport extends StreamableHTTPClientTransport { + messageHandler?: ( + ...args: Parameters< + NonNullable + > + ) => void; + + constructor( + url: URL, + opts?: ConstructorParameters["1"] & { + messageHandler?: ( + ...args: Parameters< + NonNullable + > + ) => void; + } + ) { + super(url, opts); + this.messageHandler = opts?.messageHandler; + } + get onmessage() { + return this._onmessage; + } + set onmessage(onmessage) { + this._onmessage = ( + ...args: Parameters< + NonNullable + > + ) => { + this.messageHandler?.(...args); + onmessage?.(...args); + }; + } +} + describe("protocol tests", () => { let protocol: Protocol; let transport: MockTransport; @@ -83,9 +121,9 @@ describe("protocol tests", () => { resetTimeoutOnProgress: false, onprogress: onProgressMock, }); - + jest.advanceTimersByTime(800); - + if (transport.onmessage) { transport.onmessage({ jsonrpc: "2.0", @@ -98,14 +136,14 @@ describe("protocol tests", () => { }); } await Promise.resolve(); - + expect(onProgressMock).toHaveBeenCalledWith({ progress: 50, total: 100, }); - + jest.advanceTimersByTime(201); - + await expect(requestPromise).rejects.toThrow("Request timed out"); }); @@ -194,7 +232,9 @@ describe("protocol tests", () => { }, }); } - await expect(requestPromise).rejects.toThrow("Maximum total timeout exceeded"); + await expect(requestPromise).rejects.toThrow( + "Maximum total timeout exceeded" + ); expect(onProgressMock).toHaveBeenCalledTimes(1); }); @@ -255,6 +295,26 @@ describe("protocol tests", () => { await Promise.resolve(); await expect(requestPromise).resolves.toEqual({ result: "success" }); }); + + test("should extends transport success", async () => { + let handlerCount = 0; + const testMessageHandler = () => { + handlerCount++; + } + const transport = new MockExtendsTransport(new URL('http://localhost:3000/'), { messageHandler: testMessageHandler }); + const onmessage = () => {}; + transport.onmessage = onmessage; + transport.onmessage({ + jsonrpc: "2.0", + method: "notifications/progress", + params: { + progressToken: 0, + progress: 75, + total: 100, + }, + }); + expect(handlerCount === 1); + }); }); });