diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index 73103208..1b2f4d4b 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -79,6 +79,99 @@ describe("McpServer", () => { } ]) }); + + /*** + * Test: Progress Notification with Message Field + */ + test("should send progress notifications with message field", async () => { + const mcpServer = new McpServer( + { + name: "test server", + version: "1.0", + } + ); + + // Create a tool that sends progress updates + mcpServer.tool( + "long-operation", + "A long running operation with progress updates", + { + steps: z.number().min(1).describe("Number of steps to perform"), + }, + async ({ steps }, { sendNotification, _meta }) => { + const progressToken = _meta?.progressToken; + + if (progressToken) { + // Send progress notification for each step + for (let i = 1; i <= steps; i++) { + await sendNotification({ + method: "notifications/progress", + params: { + progressToken, + progress: i, + total: steps, + message: `Completed step ${i} of ${steps}`, + }, + }); + } + } + + return { content: [{ type: "text" as const, text: `Operation completed with ${steps} steps` }] }; + } + ); + + const progressUpdates: Array<{ progress: number, total?: number, message?: string }> = []; + + const client = new Client({ + name: "test client", + version: "1.0", + }); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + mcpServer.server.connect(serverTransport), + ]); + + // Call the tool with progress tracking + await client.request( + { + method: "tools/call", + params: { + name: "long-operation", + arguments: { steps: 3 }, + _meta: { + progressToken: "progress-test-1" + } + } + }, + CallToolResultSchema, + { + onprogress: (progress) => { + progressUpdates.push(progress); + } + } + ); + + // Verify progress notifications were received with message field + expect(progressUpdates).toHaveLength(3); + expect(progressUpdates[0]).toMatchObject({ + progress: 1, + total: 3, + message: "Completed step 1 of 3", + }); + expect(progressUpdates[1]).toMatchObject({ + progress: 2, + total: 3, + message: "Completed step 2 of 3", + }); + expect(progressUpdates[2]).toMatchObject({ + progress: 3, + total: 3, + message: "Completed step 3 of 3", + }); + }); }); describe("ResourceTemplate", () => { diff --git a/src/shared/protocol.test.ts b/src/shared/protocol.test.ts index fb5ecd13..e0141da1 100644 --- a/src/shared/protocol.test.ts +++ b/src/shared/protocol.test.ts @@ -255,6 +255,74 @@ describe("protocol tests", () => { await Promise.resolve(); await expect(requestPromise).resolves.toEqual({ result: "success" }); }); + + test("should handle progress notifications with message field", async () => { + await protocol.connect(transport); + const request = { method: "example", params: {} }; + const mockSchema: ZodType<{ result: string }> = z.object({ + result: z.string(), + }); + const onProgressMock = jest.fn(); + + const requestPromise = protocol.request(request, mockSchema, { + timeout: 1000, + onprogress: onProgressMock, + }); + + jest.advanceTimersByTime(200); + + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: "2.0", + method: "notifications/progress", + params: { + progressToken: 0, + progress: 25, + total: 100, + message: "Initializing process...", + }, + }); + } + await Promise.resolve(); + + expect(onProgressMock).toHaveBeenCalledWith({ + progress: 25, + total: 100, + message: "Initializing process...", + }); + + jest.advanceTimersByTime(200); + + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: "2.0", + method: "notifications/progress", + params: { + progressToken: 0, + progress: 75, + total: 100, + message: "Processing data...", + }, + }); + } + await Promise.resolve(); + + expect(onProgressMock).toHaveBeenCalledWith({ + progress: 75, + total: 100, + message: "Processing data...", + }); + + if (transport.onmessage) { + transport.onmessage({ + jsonrpc: "2.0", + id: 0, + result: { result: "success" }, + }); + } + await Promise.resolve(); + await expect(requestPromise).resolves.toEqual({ result: "success" }); + }); }); }); diff --git a/src/types.ts b/src/types.ts index 2e31a6b5..bd299c8f 100644 --- a/src/types.ts +++ b/src/types.ts @@ -364,6 +364,10 @@ export const ProgressSchema = z * Total number of items to process (or total progress required), if known. */ total: z.optional(z.number()), + /** + * An optional message describing the current progress. + */ + message: z.optional(z.string()), }) .passthrough();