diff --git a/.github/workflows/mutation-testing.yml b/.github/workflows/mutation-testing.yml new file mode 100644 index 0000000..59493fd --- /dev/null +++ b/.github/workflows/mutation-testing.yml @@ -0,0 +1,53 @@ +name: Mutation Testing (Scheduled) + +# Run weekly on Sunday at 3am UTC, or manually +on: + schedule: + - cron: '0 3 * * 0' # Every Sunday at 3:00 AM UTC + workflow_dispatch: # Allow manual trigger; NOTE(review): the `packages` input below is not referenced by any step in this workflow - confirm whether it should be passed to test:mutate + inputs: + packages: + description: 'Package to test (leave empty for all)' + required: false + default: '' + +permissions: + contents: read + +jobs: + mutation-testing: + name: Full Mutation Testing + runs-on: ubuntu-latest + timeout-minutes: 180 # 3 hour timeout + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Bun + uses: oven-sh/setup-bun@v2 + with: + bun-version: latest + + - name: Install dependencies + run: bun install + + - name: Run Mutation Testing + run: bun run test:mutate + + - name: Upload Mutation Report + uses: actions/upload-artifact@v4 + if: always() + with: + name: mutation-report + path: reports/mutation/ + retention-days: 30 + + - name: Summary + if: always() + run: | + echo "## 🧬 Mutation Testing Complete" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "Report available in artifacts: mutation-report" >> $GITHUB_STEP_SUMMARY + echo "" >> $GITHUB_STEP_SUMMARY + echo "Run locally with: \`bun run test:mutate\`" >> $GITHUB_STEP_SUMMARY diff --git a/.github/workflows/test-quality.yml b/.github/workflows/test-quality.yml index 915a40b..c45d392 100644 --- a/.github/workflows/test-quality.yml +++ b/.github/workflows/test-quality.yml @@ -89,11 +89,9 @@ jobs: - name: Lint Check run: bun run lint - mutation-testing: - name: Mutation Testing + property-based-testing: + name: Property-Based Testing runs-on: ubuntu-latest - # Only run mutation testing on PRs to avoid slowing down every push - if: github.event_name == 'pull_request' steps: - name: Checkout @@ -107,14 +105,12 @@ jobs: - name: Install 
dependencies run: bun install - - name: Run Mutation Testing - run: bun run test:mutate + - name: Run Property-Based Tests + run: bun run test:property - - name: Check Mutation Score - run: | - # Stryker outputs score in reports/mutation/index.html - # The threshold is configured in stryker.config.mjs (break: 50) - echo "✅ Mutation testing passed (score >= threshold)" + # NOTE: Full mutation testing removed from CI (too slow: ~2hrs for 1309 mutants) + # Run manually with: bun run test:mutate + # Full runs are scheduled weekly in .github/workflows/mutation-testing.yml pr-comment: name: PR Quality Summary diff --git a/packages/mcp/package.json b/packages/mcp/package.json index 7f27b73..4240527 100644 --- a/packages/mcp/package.json +++ b/packages/mcp/package.json @@ -21,4 +21,4 @@ "@types/uuid": "^11.0.0", "typescript": "^5.9.3" } -} \ No newline at end of file +} diff --git a/packages/mcp/src/cancel/manager.test.ts b/packages/mcp/src/cancel/manager.test.ts index 24037c0..a2dd91b 100644 --- a/packages/mcp/src/cancel/manager.test.ts +++ b/packages/mcp/src/cancel/manager.test.ts @@ -1,118 +1,144 @@ import { beforeEach, describe, expect, mock, test } from "bun:test"; import { randomUUID } from "node:crypto"; +import { toolOperationStore } from "../store/operation-store"; import { CancellationManager } from "./manager"; describe("CancellationManager", () => { - let manager: CancellationManager; - let mockClient: any; - - beforeEach(() => { - manager = new CancellationManager(); - - // Mock MCP client with notification method - mockClient = { - notification: mock(() => Promise.resolve()), - }; - manager.setClient(mockClient); - }); - - test("register() starts timeout timer", () => { - const originalSetTimeout = global.setTimeout; - const setTimeoutMock = mock( - (fn: () => void, ms: number) => - originalSetTimeout(fn, ms) as unknown as NodeJS.Timeout, - ); - global.setTimeout = setTimeoutMock as any; - - try { - const requestId = "req-1"; - const operationId = 
randomUUID(); - - manager.register(requestId, operationId, 5000); - - expect(setTimeoutMock).toHaveBeenCalled(); - } finally { - global.setTimeout = originalSetTimeout; - } - }); - - test("cancel() sends notifications/cancelled notification", async () => { - const requestId = "req-2"; - const operationId = randomUUID(); - - manager.register(requestId, operationId, 30000); - await manager.cancel(operationId, "User requested cancellation"); - - expect(mockClient.notification).toHaveBeenCalledWith( - expect.objectContaining({ - method: "notifications/cancelled", - params: expect.objectContaining({ - requestId: requestId, - reason: "User requested cancellation", - }), - }), - ); - }); - - test("cancel() updates operation status to cancelled", async () => { - const requestId = "req-3"; - const operationId = randomUUID(); - - manager.register(requestId, operationId, 30000); - await manager.cancel(operationId); - - // Verification would require access to the operation store - // The implementation should update the store's operation status - // This test verifies the method doesn't throw - }); - - test("cancel() clears timeout timer", async () => { - const originalClearTimeout = global.clearTimeout; - const clearTimeoutMock = mock(() => { }); - global.clearTimeout = clearTimeoutMock as any; - - try { - const requestId = "req-4"; - const operationId = randomUUID(); - - manager.register(requestId, operationId, 30000); - await manager.cancel(operationId); - - expect(clearTimeoutMock).toHaveBeenCalled(); - } finally { - global.clearTimeout = originalClearTimeout; - } - }); - - test("onResponse() clears pending request", () => { - const requestId = "req-5"; - const operationId = randomUUID(); - - manager.register(requestId, operationId, 30000); - manager.onResponse(requestId); - - // Calling cancel after onResponse should not send notification - // because the request is no longer pending - }); - - test("onResponse() ignores unknown requestId", () => { - // Should not throw 
for unknown requestId - expect(() => manager.onResponse("unknown-id")).not.toThrow(); - }); - - test("timeout auto-cancels operation", async () => { - // Use fake timers or short timeout - const requestId = "req-6"; - const operationId = randomUUID(); - - // Register with very short timeout - manager.register(requestId, operationId, 50); - - // Wait for timeout to fire - await new Promise((resolve) => setTimeout(resolve, 100)); - - // The implementation should have auto-cancelled - // Verify via notification call or store state - // For now, we verify that the timeout mechanism is wired up - }); + let manager: CancellationManager; + let mockClient: any; + + beforeEach(() => { + manager = new CancellationManager(); + + // Mock MCP client with notification method + mockClient = { + notification: mock(() => Promise.resolve()), + }; + manager.setClient(mockClient); + }); + + test("register() starts timeout timer", () => { + const originalSetTimeout = global.setTimeout; + const setTimeoutMock = mock( + (fn: () => void, ms: number) => + originalSetTimeout(fn, ms) as unknown as NodeJS.Timeout, + ); + global.setTimeout = setTimeoutMock as any; + + try { + const requestId = "req-1"; + const operationId = randomUUID(); + + manager.register(requestId, operationId, 5000); + + expect(setTimeoutMock).toHaveBeenCalled(); + // Also verify a timer was scheduled with the registered duration (5000 ms) + const calls = setTimeoutMock.mock.calls; + expect(calls.some((call) => call[1] === 5000)).toBe(true); + } finally { + global.setTimeout = originalSetTimeout; + } + }); + + test("cancel() sends notifications/cancelled notification", async () => { + const requestId = "req-2"; + const operationId = randomUUID(); + + manager.register(requestId, operationId, 30000); + await manager.cancel(operationId, "User requested cancellation"); + + expect(mockClient.notification).toHaveBeenCalledWith( + expect.objectContaining({ + method: "notifications/cancelled", + params: expect.objectContaining({ + requestId: requestId, + reason: 
"User requested cancellation", + }), + }), + ); + }); + + test("cancel() updates operation status to cancelled", async () => { + const requestId = "req-3"; + const sessionId = "session-3"; + + // Create an operation first so we can verify status update + const toolRequest = { name: "echo", arguments: { message: "test" } }; + const operation = toolOperationStore.create( + sessionId, + toolRequest, + requestId, + ); + const testOpId = operation.id; + + manager.register(requestId, testOpId, 30000); + await manager.cancel(testOpId, "Test cancellation"); + + // Verify the operation store was updated + const updatedOperation = toolOperationStore.get(testOpId); + expect(updatedOperation).toBeDefined(); + expect(updatedOperation?.status).toBe("cancelled"); + expect(updatedOperation?.cancelReason).toBe("Test cancellation"); + }); + + test("cancel() clears timeout timer", async () => { + const originalClearTimeout = global.clearTimeout; + const clearTimeoutMock = mock(() => {}); + global.clearTimeout = clearTimeoutMock as any; + + try { + const requestId = "req-4"; + const operationId = randomUUID(); + + manager.register(requestId, operationId, 30000); + await manager.cancel(operationId); + + expect(clearTimeoutMock).toHaveBeenCalled(); + } finally { + global.clearTimeout = originalClearTimeout; + } + }); + + test("onResponse() clears pending request", async () => { + const requestId = "req-5"; + const operationId = randomUUID(); + + manager.register(requestId, operationId, 30000); + manager.onResponse(requestId); + + // Calling cancel after onResponse should not send notification + // because the request is no longer pending + await manager.cancel(operationId); // This should be a no-op + + // Verify no notification was sent (since onResponse already cleared it) + expect(mockClient.notification).not.toHaveBeenCalled(); + }); + + test("onResponse() ignores unknown requestId", () => { + // Should not throw for unknown requestId + expect(() => 
manager.onResponse("unknown-id")).not.toThrow(); + }); + + test("timeout auto-cancels operation", async () => { + const requestId = "req-6"; + const operationId = randomUUID(); + + // Register with very short timeout + manager.register(requestId, operationId, 50); + + // Wait for timeout to fire + await new Promise((resolve) => setTimeout(resolve, 150)); + + // The implementation should have auto-cancelled + // Verify the notification was sent with timeout reason + expect(mockClient.notification).toHaveBeenCalledWith( + expect.objectContaining({ + method: "notifications/cancelled", + params: expect.objectContaining({ + requestId: requestId, + reason: "Request timeout", + }), + }), + ); + }); }); diff --git a/packages/mcp/src/cancel/manager.ts b/packages/mcp/src/cancel/manager.ts index b5e677a..ae8bf3d 100644 --- a/packages/mcp/src/cancel/manager.ts +++ b/packages/mcp/src/cancel/manager.ts @@ -9,156 +9,156 @@ import type { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { toolOperationStore } from "../store/operation-store"; interface PendingRequest { - requestId: string; - operationId: string; - startedAt: Date; - timeoutMs: number; - timeoutHandle: ReturnType; - rejectFn?: (reason: Error) => void; + requestId: string; + operationId: string; + startedAt: Date; + timeoutMs: number; + timeoutHandle: ReturnType; + rejectFn?: (reason: Error) => void; } export class CancellationManager { - // Map: requestId → PendingRequest - private pendingRequests = new Map(); - // Reverse lookup: operationId → requestId - private operationToRequest = new Map(); - private defaultTimeoutMs = 30000; - private client: Client | null = null; - - /** - * Set the MCP client to use for sending notifications. - */ - setClient(client: Client): void { - this.client = client; - } - - /** - * Register a request for potential cancellation. - * Starts a timeout timer. 
- * @param requestId - The JSON-RPC request ID - * @param operationId - The operation ID - * @param timeoutMs - Timeout in milliseconds (default 30000) - * @param rejectFn - Optional reject function to abort pending promise - */ - register( - requestId: string, - operationId: string, - timeoutMs: number = this.defaultTimeoutMs, - rejectFn?: (reason: Error) => void, - ): void { - const timeoutHandle = setTimeout(() => { - this.onTimeout(requestId); - }, timeoutMs); - - this.pendingRequests.set(requestId, { - requestId, - operationId, - startedAt: new Date(), - timeoutMs, - timeoutHandle, - rejectFn, - }); - - // Reverse lookup for cancel by operationId - this.operationToRequest.set(operationId, requestId); - } - - /** - * Cancel an operation. - * Sends cancellation notification and updates store. - * @param operationId - The operation ID - * @param reason - Optional cancellation reason - */ - async cancel(operationId: string, reason?: string): Promise { - // Find pending request by operationId - const requestId = this.operationToRequest.get(operationId); - if (!requestId) { - // No pending request - already completed or unknown - return; - } - - const entry = this.pendingRequests.get(requestId); - if (!entry) { - return; - } - - // Clear timeout - clearTimeout(entry.timeoutHandle); - - // Update store first (before sending notification) - toolOperationStore.markCancelled(operationId, reason); - - // Reject pending promise to abort the callTool await - if (entry.rejectFn) { - entry.rejectFn(new Error(reason ?? "Operation cancelled")); - } - - // Send cancellation notification - await this.sendCancelNotification(requestId, reason ?? "User cancelled"); - - // Remove from pending - this.pendingRequests.delete(requestId); - this.operationToRequest.delete(operationId); - } - - /** - * Handle a response arriving for a request. - * Clears timeout and removes from pending list. 
- * @param requestId - The JSON-RPC request ID - */ - onResponse(requestId: string): void { - const entry = this.pendingRequests.get(requestId); - if (!entry) { - // Already cancelled or unknown — ignore response - return; - } - - // Clear timeout and remove - clearTimeout(entry.timeoutHandle); - this.pendingRequests.delete(requestId); - this.operationToRequest.delete(entry.operationId); - } - - /** - * Handle timeout for a request. - * @param requestId - The JSON-RPC request ID - */ - private onTimeout(requestId: string): void { - const entry = this.pendingRequests.get(requestId); - if (!entry) return; - - const reason = "Request timeout"; - - // Update store with cancelled status - toolOperationStore.markCancelled(entry.operationId, reason); - - // Reject pending promise to abort the callTool await - if (entry.rejectFn) { - entry.rejectFn(new Error(reason)); - } - - // Send cancel notification (fire and forget) - this.sendCancelNotification(requestId, reason); - - // Remove from pending - this.pendingRequests.delete(requestId); - this.operationToRequest.delete(entry.operationId); - } - - /** - * Send cancellation notification to the server. - */ - private async sendCancelNotification( - requestId: string, - reason: string, - ): Promise { - if (!this.client) return; - - await this.client.notification({ - method: "notifications/cancelled", - params: { requestId, reason }, - }); - } + // Map: requestId → PendingRequest + private pendingRequests = new Map(); + // Reverse lookup: operationId → requestId + private operationToRequest = new Map(); + private defaultTimeoutMs = 30000; + private client: Client | null = null; + + /** + * Set the MCP client to use for sending notifications. + */ + setClient(client: Client): void { + this.client = client; + } + + /** + * Register a request for potential cancellation. + * Starts a timeout timer. 
+ * @param requestId - The JSON-RPC request ID + * @param operationId - The operation ID + * @param timeoutMs - Timeout in milliseconds (default 30000) + * @param rejectFn - Optional reject function to abort pending promise + */ + register( + requestId: string, + operationId: string, + timeoutMs: number = this.defaultTimeoutMs, + rejectFn?: (reason: Error) => void, + ): void { + const timeoutHandle = setTimeout(() => { + this.onTimeout(requestId); + }, timeoutMs); + + this.pendingRequests.set(requestId, { + requestId, + operationId, + startedAt: new Date(), + timeoutMs, + timeoutHandle, + rejectFn, + }); + + // Reverse lookup for cancel by operationId + this.operationToRequest.set(operationId, requestId); + } + + /** + * Cancel an operation. + * Sends cancellation notification and updates store. + * @param operationId - The operation ID + * @param reason - Optional cancellation reason + */ + async cancel(operationId: string, reason?: string): Promise { + // Find pending request by operationId + const requestId = this.operationToRequest.get(operationId); + if (!requestId) { + // No pending request - already completed or unknown + return; + } + + const entry = this.pendingRequests.get(requestId); + if (!entry) { + return; + } + + // Clear timeout + clearTimeout(entry.timeoutHandle); + + // Update store first (before sending notification) + toolOperationStore.markCancelled(operationId, reason); + + // Reject pending promise to abort the callTool await + if (entry.rejectFn) { + entry.rejectFn(new Error(reason ?? "Operation cancelled")); + } + + // Remove from pending before awaiting, so a rejected notification or a concurrent cancel() cannot see stale entries + this.pendingRequests.delete(requestId); + this.operationToRequest.delete(operationId); + + // Send cancellation notification + await this.sendCancelNotification(requestId, reason ?? "User cancelled"); + } + + /** + * Handle a response arriving for a request. + * Clears timeout and removes from pending list. 
+ * @param requestId - The JSON-RPC request ID + */ + onResponse(requestId: string): void { + const entry = this.pendingRequests.get(requestId); + if (!entry) { + // Already cancelled or unknown — ignore response + return; + } + + // Clear timeout and remove + clearTimeout(entry.timeoutHandle); + this.pendingRequests.delete(requestId); + this.operationToRequest.delete(entry.operationId); + } + + /** + * Handle timeout for a request. + * @param requestId - The JSON-RPC request ID + */ + private onTimeout(requestId: string): void { + const entry = this.pendingRequests.get(requestId); + if (!entry) return; + + const reason = "Request timeout"; + + // Update store with cancelled status + toolOperationStore.markCancelled(entry.operationId, reason); + + // Reject pending promise to abort the callTool await + if (entry.rejectFn) { + entry.rejectFn(new Error(reason)); + } + + // Send cancel notification (fire and forget); log failures so a rejected send cannot become an unhandled rejection + void this.sendCancelNotification(requestId, reason).catch((error: unknown) => console.error("Failed to send timeout cancellation notification", error)); + + // Remove from pending + this.pendingRequests.delete(requestId); + this.operationToRequest.delete(entry.operationId); + } + + /** + * Send cancellation notification to the server. 
+ */ + private async sendCancelNotification( + requestId: string, + reason: string, + ): Promise { + if (!this.client) return; + + await this.client.notification({ + method: "notifications/cancelled", + params: { requestId, reason }, + }); + } } // Singleton instance diff --git a/packages/mcp/src/client/manager.ts b/packages/mcp/src/client/manager.ts index 2ee11a2..d434f15 100644 --- a/packages/mcp/src/client/manager.ts +++ b/packages/mcp/src/client/manager.ts @@ -19,18 +19,36 @@ import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"; import type { MiddlewarePipeline, SessionManager } from "@say2/core"; +import { z } from "zod"; +import { cancellationManager } from "../cancel/manager"; +import { ContentParser } from "../content/parser"; +import { progressTracker } from "../progress/tracker"; +import { toolOperationStore } from "../store"; +import { taskManager } from "../task/manager"; import { LoggingTransport } from "../transport"; -import type { McpClientRegistry } from "./registry"; +import type { ToolContent } from "../types/content"; +import { McpProgressNotificationSchema } from "../types/progress"; +import { + type CreateTaskResult, + CreateTaskResultSchema, + EmptyResultSchema, + type Task, + TaskGetResultSchema, + TaskListResultSchema, + type TaskMetadata, + TaskStatusNotificationSchema, +} from "../types/task"; import type { + CallToolOptions, ToolCallRequest, ToolOperation, - CallToolOptions, } from "../types/tool"; -import { toolOperationStore } from "../store"; -import { progressTracker } from "../progress/tracker"; -import { McpProgressNotificationSchema } from "../types/progress"; -import { cancellationManager } from "../cancel/manager"; -import { ContentParser } from "../content/parser"; +import { + applyAnnotationDefaults, + type Tool, + type ToolAnnotations, +} from "../types/tool-annotations"; +import type { McpClientRegistry } from "./registry"; export 
class McpClientManager { constructor( @@ -41,7 +59,7 @@ export class McpClientManager { clientInfo: { name: string; version: string }, options?: { capabilities: any }, ) => Client = (info, opts) => new Client(info, opts), - ) { } + ) {} /** * Connect to an MCP server for the given session. @@ -119,6 +137,14 @@ export class McpClientManager { }, ); + // 9. Set up task status notification handler (optional per spec) + client.setNotificationHandler( + TaskStatusNotificationSchema, + (notification) => { + taskManager.handleStatusNotification(notification.params); + }, + ); + // 9. Register in registry this.registry.register(sessionId, client, loggingTransport); @@ -220,6 +246,14 @@ export class McpClientManager { cursor = result.nextCursor; } while (cursor); + // Cache in session + const session = this.sessionManager.get(sessionId); + if (session?.serverCapabilities) { + const caps = session.serverCapabilities as any; + if (!caps.discovered) caps.discovered = {}; + caps.discovered.tools = tools; + } + return { tools }; } @@ -241,6 +275,14 @@ export class McpClientManager { cursor = result.nextCursor; } while (cursor); + // Cache in session + const session = this.sessionManager.get(sessionId); + if (session?.serverCapabilities) { + const caps = session.serverCapabilities as any; + if (!caps.discovered) caps.discovered = {}; + caps.discovered.resources = resources; + } + return { resources }; } @@ -285,9 +327,72 @@ export class McpClientManager { cursor = result.nextCursor; } while (cursor); + // Cache in session + const session = this.sessionManager.get(sessionId); + if (session?.serverCapabilities) { + const caps = session.serverCapabilities as any; + if (!caps.discovered) caps.discovered = {}; + caps.discovered.prompts = prompts; + } + return { prompts }; } + // ========================================================================= + // Tool Annotations (Phase 2a Task 06) + // ========================================================================= + + /** + * 
List all tools with full typing and annotations applied. + * Returns cached tools from session with defaults applied to annotations. + * + * @param sessionId - The session ID + * @returns Array of fully-typed Tool objects with annotation defaults + */ + listToolsTyped(sessionId: string): Tool[] { + const session = this.sessionManager.get(sessionId); + const discovered = session?.serverCapabilities?.discovered as + | { tools?: Tool[] } + | undefined; + const tools = discovered?.tools ?? []; + + return tools.map((tool) => ({ + ...tool, + annotations: applyAnnotationDefaults(tool.annotations), + })); + } + + /** + * Retrieve annotations for a specific tool. + * Tools are stored during Phase 1 capability discovery. + * + * @param sessionId - The session ID + * @param toolName - The name of the tool + * @returns ToolAnnotations with defaults applied, or undefined if not found + */ + getToolAnnotations( + sessionId: string, + toolName: string, + ): ToolAnnotations | undefined { + const session = this.sessionManager.get(sessionId); + const discovered = session?.serverCapabilities?.discovered as + | { tools?: Tool[] } + | undefined; + + if (!discovered?.tools) { + return undefined; + } + + const tool = discovered.tools.find((t) => t.name === toolName); + + if (!tool) { + return undefined; + } + + // Apply defaults to ensure all fields are present + return applyAnnotationDefaults(tool.annotations); + } + // ========================================================================= // Tool Operations // ========================================================================= @@ -334,11 +439,20 @@ export class McpClientManager { }); cancellationManager.setClient(entry.client); - cancellationManager.register(requestId, operation.id, options?.timeout, cancelReject); + cancellationManager.register( + requestId, + operation.id, + options?.timeout, + cancelReject, + ); try { // Build request params with optional progress token - const callParams: { name: string; arguments: Record; 
_meta?: { progressToken: string } } = { + const callParams: { + name: string; + arguments: Record; + _meta?: { progressToken: string }; + } = { name: request.name, arguments: request.arguments ?? {}, }; @@ -362,7 +476,7 @@ export class McpClientManager { // Parse and validate content via ContentParser const contentParser = new ContentParser(); - let parsedContent; + let parsedContent: ToolContent[]; try { parsedContent = contentParser.parseContent(result.content as unknown[]); } catch (parseError) { @@ -371,7 +485,10 @@ export class McpClientManager { status: "error", error: { code: -32602, // Invalid params - message: parseError instanceof Error ? parseError.message : String(parseError), + message: + parseError instanceof Error + ? parseError.message + : String(parseError), }, }); return toolOperationStore.get(operation.id)!; @@ -488,4 +605,224 @@ export class McpClientManager { await cancellationManager.cancel(operationId, reason); } + + // ========================================================================= + // Task Operations (Phase 2a Task 07) + // ========================================================================= + + /** + * List all active tasks for a session. + * @param sessionId - The session ID + * @returns Array of active tasks + */ + async listTasks(sessionId: string): Promise { + const client = this.getClient(sessionId); + if (!client) { + throw new Error(`Session ${sessionId} not connected`); + } + + const tasks: Task[] = []; + let cursor: string | undefined; + + do { + const result = await client.request( + { method: "tasks/list", params: { cursor } }, + TaskListResultSchema, + ); + tasks.push(...result.tasks); + cursor = result.nextCursor; + } while (cursor); + + return tasks; + } + + /** + * Get a specific task's status. 
+ * @param sessionId - The session ID + * @param taskId - The task identifier + * @returns Task metadata + */ + async getTask(sessionId: string, taskId: string): Promise { + const client = this.getClient(sessionId); + if (!client) { + throw new Error(`Session ${sessionId} not connected`); + } + + return await client.request( + { method: "tasks/get", params: { taskId } }, + TaskGetResultSchema, + ); + } + + /** + * Get the actual result of a completed task. + * @param sessionId - The session ID + * @param taskId - The task identifier + * @returns The tool call result + */ + async getTaskResult(sessionId: string, taskId: string): Promise { + const client = this.getClient(sessionId); + if (!client) { + throw new Error(`Session ${sessionId} not connected`); + } + + // Return unknown for now - will be typed when integrating with callTool + return await client.request( + { method: "tasks/result", params: { taskId } }, + z.unknown(), + ); + } + + /** + * Cancel a running task. + * @param sessionId - The session ID + * @param taskId - The task identifier + */ + async cancelTask(sessionId: string, taskId: string): Promise { + const client = this.getClient(sessionId); + if (!client) { + throw new Error(`Session ${sessionId} not connected`); + } + + await client.request( + { method: "tasks/cancel", params: { taskId } }, + EmptyResultSchema, + ); + + // Update local cache + taskManager.removeTask(taskId); + } + + /** + * Check if a tool supports task-augmented execution. 
+ * @param sessionId - The session ID + * @param toolName - The tool name + * @returns Task support level + */ + getToolTaskSupport( + sessionId: string, + toolName: string, + ): "forbidden" | "optional" | "required" { + const session = this.sessionManager.get(sessionId); + const discovered = session?.serverCapabilities?.discovered as + | { tools?: Tool[] } + | undefined; + const tool = discovered?.tools?.find((t) => t.name === toolName); + + // Default to 'forbidden' if not specified + return tool?.execution?.taskSupport ?? "forbidden"; + } + + /** + * Call a tool with task-augmented execution. + * Returns a task that can be polled for completion. + * + * @param sessionId - The session ID + * @param request - The tool call request + * @param taskOptions - Task metadata (e.g., ttl) + * @returns CreateTaskResult with the created task + * @throws Error if tool doesn't support tasks + */ + async callToolAsTask( + sessionId: string, + request: ToolCallRequest, + taskOptions: TaskMetadata = {}, + ): Promise { + const client = this.getClient(sessionId); + if (!client) { + throw new Error(`Session ${sessionId} not connected`); + } + + // 1. Check server-level capability first (per spec) + const session = this.sessionManager.get(sessionId); + const serverCaps = session?.serverCapabilities as + | { tasks?: { requests?: { tools?: { call?: boolean } } } } + | undefined; + + if (!serverCaps?.tasks?.requests?.tools?.call) { + throw new Error("Server does not support task-augmented tool execution"); + } + + // 2. 
Check tool-level support + const taskSupport = this.getToolTaskSupport(sessionId, request.name); + if (taskSupport === "forbidden") { + throw new Error( + `Tool "${request.name}" does not support task-augmented execution`, + ); + } + + // Call tool with task metadata + const result = await client.request( + { + method: "tools/call", + params: { + name: request.name, + arguments: request.arguments, + task: taskOptions, + }, + }, + CreateTaskResultSchema, + ); + + // Register task in local manager + taskManager.registerTask(result.task.taskId, sessionId, result.task); + + return result; + } + + /** + * Call a tool as a task and wait for completion. + * Combines callToolAsTask with polling. + * + * @param sessionId - The session ID + * @param request - The tool call request + * @param taskOptions - Task metadata + * @param onProgress - Optional progress callback + * @returns The final tool result + */ + async callToolAsTaskAndWait( + sessionId: string, + request: ToolCallRequest, + taskOptions: TaskMetadata = {}, + onProgress?: (task: Task) => void, + ): Promise { + // Start the task + const createResult = await this.callToolAsTask( + sessionId, + request, + taskOptions, + ); + + const taskId = createResult.task.taskId; + + // Poll until complete or input_required (which needs user action) + const finalTask = await taskManager.pollUntilComplete( + taskId, + () => this.getTask(sessionId, taskId), + onProgress, + // Stop early on input_required since this method can't provide input + (task) => task.status === "input_required", + ); + + // Handle terminal states + if (finalTask.status === "failed") { + throw new Error(finalTask.statusMessage ?? `Task ${taskId} failed`); + } + + if (finalTask.status === "cancelled") { + throw new Error( + finalTask.statusMessage ?? 
`Task ${taskId} was cancelled`, + ); + } + + // input_required requires user interaction - this method can't handle it + if (finalTask.status === "input_required") { + throw new Error( + `Task ${taskId} requires input: ${finalTask.statusMessage ?? "waiting for elicitation or sampling"}`, + ); + } + + // Get the actual result + return await this.getTaskResult(sessionId, taskId); + } } diff --git a/packages/mcp/src/content/parser.test.ts b/packages/mcp/src/content/parser.test.ts index 016797e..9104af2 100644 --- a/packages/mcp/src/content/parser.test.ts +++ b/packages/mcp/src/content/parser.test.ts @@ -2,350 +2,350 @@ import { beforeEach, describe, expect, test } from "bun:test"; import { ContentParser } from "./parser"; describe("ContentParser", () => { - let parser: ContentParser; - - beforeEach(() => { - parser = new ContentParser(); - }); - - describe("parseContent()", () => { - test("parses text content", () => { - const raw = [{ type: "text", text: "Hello world" }]; - const result = parser.parseContent(raw); - - expect(result).toHaveLength(1); - expect(result[0]?.type).toBe("text"); - if (result[0]?.type === "text") { - expect(result[0].text).toBe("Hello world"); - } - }); - - test("parses image content with base64 data", () => { - const raw = [ - { - type: "image", - data: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==", - mimeType: "image/png", - }, - ]; - const result = parser.parseContent(raw); - - expect(result).toHaveLength(1); - expect(result[0]?.type).toBe("image"); - if (result[0]?.type === "image") { - expect(result[0].data).toBeDefined(); - expect(result[0].mimeType).toBe("image/png"); - } - }); - - test("parses audio content", () => { - const raw = [ - { - type: "audio", - data: "UklGRiQAAABXQVZFZm10IBAAAAABAAEARKwAAIhYAQACABAAZGF0YQAAAAA=", - mimeType: "audio/wav", - }, - ]; - const result = parser.parseContent(raw); - - expect(result).toHaveLength(1); - expect(result[0]?.type).toBe("audio"); - if 
(result[0]?.type === "audio") { - expect(result[0].mimeType).toBe("audio/wav"); - } - }); - - test("parses resource_link content", () => { - const raw = [ - { - type: "resource_link", - uri: "file:///path/to/file.txt", - name: "My File", - mimeType: "text/plain", - }, - ]; - const result = parser.parseContent(raw); - - expect(result).toHaveLength(1); - expect(result[0]?.type).toBe("resource_link"); - if (result[0]?.type === "resource_link") { - expect(result[0].uri).toBe("file:///path/to/file.txt"); - expect(result[0].name).toBe("My File"); - } - }); - - test("parses embedded resource content", () => { - const raw = [ - { - type: "resource", - resource: { - uri: "file:///data.json", - text: '{"key": "value"}', - mimeType: "application/json", - }, - }, - ]; - const result = parser.parseContent(raw); - - expect(result).toHaveLength(1); - expect(result[0]?.type).toBe("resource"); - if (result[0]?.type === "resource") { - expect(result[0].resource.uri).toBe("file:///data.json"); - expect(result[0].resource.text).toBe('{"key": "value"}'); - } - }); - - test("parses mixed content types", () => { - const raw = [ - { type: "text", text: "Hello" }, - { type: "image", data: "abc123", mimeType: "image/jpeg" }, - { type: "resource_link", uri: "file:///test", name: "Test" }, - ]; - const result = parser.parseContent(raw); - - expect(result).toHaveLength(3); - expect(result[0]?.type).toBe("text"); - expect(result[1]?.type).toBe("image"); - expect(result[2]?.type).toBe("resource_link"); - }); - - test("throws on invalid content type", () => { - const raw = [{ type: "invalid_type", data: "foo" }]; - - expect(() => parser.parseContent(raw)).toThrow(); - }); - - test("throws for missing required fields", () => { - const raw = [{ type: "text" }]; // missing text field - - expect(() => parser.parseContent(raw)).toThrow(); - }); - - // SKIPPED: Requires implementation changes to enforce strict mime type validation - // SKIPPED: Requires implementation changes to enforce strict mime 
type validation - test("throws on invalid image mime type", () => { - const raw = [ - { - type: "image", - data: "abc", - mimeType: "image/x-unknown", - }, - ]; - expect(() => parser.parseContent(raw)).toThrow("Invalid image MIME type"); - }); - - test("throws on invalid audio mime type", () => { - const raw = [ - { - type: "audio", - data: "abc", - mimeType: "audio/unknown", - }, - ]; - expect(() => parser.parseContent(raw)).toThrow("Invalid audio MIME type"); - }); - - test("throws for non-array input", () => { - expect(() => parser.parseContent({} as any)).toThrow( - "Content must be an array", - ); - }); - - test("preserves annotations", () => { - const raw = [ - { - type: "text", - text: "User-only message", - annotations: { - audience: ["user"], - priority: 0.8, - }, - }, - ]; - const result = parser.parseContent(raw); - - expect(result[0]?.annotations).toBeDefined(); - expect(result[0]?.annotations?.audience).toContain("user"); - expect(result[0]?.annotations?.priority).toBe(0.8); - }); - }); - - describe("validateStructuredOutput()", () => { - test("returns valid for content matching schema", () => { - const content = { name: "test", count: 42 }; - const schema = { - type: "object", - properties: { - name: { type: "string" }, - count: { type: "number" }, - }, - required: ["name"], - }; - - const result = parser.validateStructuredOutput(content, schema); - expect(result.valid).toBe(true); - expect(result.errors).toBeUndefined(); - }); - - test("returns invalid with errors for mismatched content", () => { - const content = { name: 123 }; // name should be string - const schema = { - type: "object", - properties: { - name: { type: "string" }, - }, - required: ["name"], - }; - - const result = parser.validateStructuredOutput(content, schema); - expect(result.valid).toBe(false); - expect(result.errors).toBeDefined(); - expect(result.errors!.length).toBeGreaterThan(0); - }); - - test("returns valid when no schema provided", () => { - const content = { anything: 
"goes" }; - - const result = parser.validateStructuredOutput(content); - expect(result.valid).toBe(true); - }); - - test("validates nested objects", () => { - const content = { - user: { name: "John", age: 30 }, - }; - const schema = { - type: "object", - properties: { - user: { - type: "object", - properties: { - name: { type: "string" }, - age: { type: "number" }, - }, - }, - }, - }; - - const result = parser.validateStructuredOutput(content, schema); - expect(result.valid).toBe(true); - }); - - test("validates array schema", () => { - const schema = { - type: "array", - items: { type: "string" }, - }; - const validContent = ["a", "b", "c"]; - - const result = parser.validateStructuredOutput(validContent, schema); - expect(result.valid).toBe(true); - }); - }); - - describe("decodeBase64()", () => { - test("decodes valid base64 to Uint8Array", () => { - // "Hello" in base64 - const base64 = "SGVsbG8="; - const result = parser.decodeBase64(base64); - - expect(result).toBeInstanceOf(Uint8Array); - expect(result.length).toBe(5); - // H=72, e=101, l=108, l=108, o=111 - expect(result[0]).toBe(72); - expect(result[4]).toBe(111); - }); - - test("decodes empty base64", () => { - const result = parser.decodeBase64(""); - expect(result).toBeInstanceOf(Uint8Array); - expect(result.length).toBe(0); - }); - - test("handles base64 with padding", () => { - // "Hi" = "SGk=" (with padding) - const result = parser.decodeBase64("SGk="); - expect(result.length).toBe(2); - }); - - test("handles base64 without padding", () => { - // Some base64 implementations strip padding - const result = parser.decodeBase64("SGk"); - expect(result.length).toBe(2); - }); - }); - - describe("getContentSize()", () => { - test("returns text length for text content", () => { - const content = { type: "text" as const, text: "Hello" }; - - expect(parser.getContentSize(content)).toBe(5); - }); - - test("returns estimated size for image content", () => { - const content = { - type: "image" as const, - data: 
"1234567890", // 10 chars - mimeType: "image/png", - }; - - // 10 * 0.75 = 7.5, floor = 7 - expect(parser.getContentSize(content)).toBe(7); - }); - - test("returns estimated size for audio content", () => { - const content = { - type: "audio" as const, - data: "12345678901234567890", // 20 chars - mimeType: "audio/wav", - }; - - // 20 * 0.75 = 15 - expect(parser.getContentSize(content)).toBe(15); - }); - - test("returns 0 for resource_link", () => { - const content = { - type: "resource_link" as const, - uri: "file:///test.txt", - }; - - expect(parser.getContentSize(content)).toBe(0); - }); - - test("returns text length for embedded resource with text", () => { - const content = { - type: "resource" as const, - resource: { - uri: "file:///data.json", - text: "Hello World", - }, - }; - - expect(parser.getContentSize(content)).toBe(11); - }); - }); - - describe("validateMimeType()", () => { - test("validates exact match", () => { - expect( - parser.validateMimeType("image/png", ["image/png", "image/jpeg"]), - ).toBe(true); - }); - - test("rejects non-matching type", () => { - expect(parser.validateMimeType("image/gif", ["image/png"])).toBe(false); - }); - - test("validates prefix match with wildcard", () => { - expect(parser.validateMimeType("image/png", ["image/*"])).toBe(true); - expect(parser.validateMimeType("image/jpeg", ["image/*"])).toBe(true); - expect(parser.validateMimeType("audio/wav", ["image/*"])).toBe(false); - }); - - test("validates audio mime types", () => { - expect(parser.validateMimeType("audio/wav", ["audio/*"])).toBe(true); - expect(parser.validateMimeType("audio/mp3", ["audio/*"])).toBe(true); - }); - }); + let parser: ContentParser; + + beforeEach(() => { + parser = new ContentParser(); + }); + + describe("parseContent()", () => { + test("parses text content", () => { + const raw = [{ type: "text", text: "Hello world" }]; + const result = parser.parseContent(raw); + + expect(result).toHaveLength(1); + expect(result[0]?.type).toBe("text"); + if 
(result[0]?.type === "text") { + expect(result[0].text).toBe("Hello world"); + } + }); + + test("parses image content with base64 data", () => { + const raw = [ + { + type: "image", + data: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==", + mimeType: "image/png", + }, + ]; + const result = parser.parseContent(raw); + + expect(result).toHaveLength(1); + expect(result[0]?.type).toBe("image"); + if (result[0]?.type === "image") { + expect(result[0].data).toBeDefined(); + expect(result[0].mimeType).toBe("image/png"); + } + }); + + test("parses audio content", () => { + const raw = [ + { + type: "audio", + data: "UklGRiQAAABXQVZFZm10IBAAAAABAAEARKwAAIhYAQACABAAZGF0YQAAAAA=", + mimeType: "audio/wav", + }, + ]; + const result = parser.parseContent(raw); + + expect(result).toHaveLength(1); + expect(result[0]?.type).toBe("audio"); + if (result[0]?.type === "audio") { + expect(result[0].mimeType).toBe("audio/wav"); + } + }); + + test("parses resource_link content", () => { + const raw = [ + { + type: "resource_link", + uri: "file:///path/to/file.txt", + name: "My File", + mimeType: "text/plain", + }, + ]; + const result = parser.parseContent(raw); + + expect(result).toHaveLength(1); + expect(result[0]?.type).toBe("resource_link"); + if (result[0]?.type === "resource_link") { + expect(result[0].uri).toBe("file:///path/to/file.txt"); + expect(result[0].name).toBe("My File"); + } + }); + + test("parses embedded resource content", () => { + const raw = [ + { + type: "resource", + resource: { + uri: "file:///data.json", + text: '{"key": "value"}', + mimeType: "application/json", + }, + }, + ]; + const result = parser.parseContent(raw); + + expect(result).toHaveLength(1); + expect(result[0]?.type).toBe("resource"); + if (result[0]?.type === "resource") { + expect(result[0].resource.uri).toBe("file:///data.json"); + expect(result[0].resource.text).toBe('{"key": "value"}'); + } + }); + + test("parses mixed content types", () => { + 
const raw = [ + { type: "text", text: "Hello" }, + { type: "image", data: "abc123", mimeType: "image/jpeg" }, + { type: "resource_link", uri: "file:///test", name: "Test" }, + ]; + const result = parser.parseContent(raw); + + expect(result).toHaveLength(3); + expect(result[0]?.type).toBe("text"); + expect(result[1]?.type).toBe("image"); + expect(result[2]?.type).toBe("resource_link"); + }); + + test("throws on invalid content type", () => { + const raw = [{ type: "invalid_type", data: "foo" }]; + + expect(() => parser.parseContent(raw)).toThrow(); + }); + + test("throws for missing required fields", () => { + const raw = [{ type: "text" }]; // missing text field + + expect(() => parser.parseContent(raw)).toThrow(); + }); + + // SKIPPED: Requires implementation changes to enforce strict mime type validation + // SKIPPED: Requires implementation changes to enforce strict mime type validation + test("throws on invalid image mime type", () => { + const raw = [ + { + type: "image", + data: "abc", + mimeType: "image/x-unknown", + }, + ]; + expect(() => parser.parseContent(raw)).toThrow("Invalid image MIME type"); + }); + + test("throws on invalid audio mime type", () => { + const raw = [ + { + type: "audio", + data: "abc", + mimeType: "audio/unknown", + }, + ]; + expect(() => parser.parseContent(raw)).toThrow("Invalid audio MIME type"); + }); + + test("throws for non-array input", () => { + expect(() => parser.parseContent({} as any)).toThrow( + "Content must be an array", + ); + }); + + test("preserves annotations", () => { + const raw = [ + { + type: "text", + text: "User-only message", + annotations: { + audience: ["user"], + priority: 0.8, + }, + }, + ]; + const result = parser.parseContent(raw); + + expect(result[0]?.annotations).toBeDefined(); + expect(result[0]?.annotations?.audience).toContain("user"); + expect(result[0]?.annotations?.priority).toBe(0.8); + }); + }); + + describe("validateStructuredOutput()", () => { + test("returns valid for content matching 
schema", () => { + const content = { name: "test", count: 42 }; + const schema = { + type: "object", + properties: { + name: { type: "string" }, + count: { type: "number" }, + }, + required: ["name"], + }; + + const result = parser.validateStructuredOutput(content, schema); + expect(result.valid).toBe(true); + expect(result.errors).toBeUndefined(); + }); + + test("returns invalid with errors for mismatched content", () => { + const content = { name: 123 }; // name should be string + const schema = { + type: "object", + properties: { + name: { type: "string" }, + }, + required: ["name"], + }; + + const result = parser.validateStructuredOutput(content, schema); + expect(result.valid).toBe(false); + expect(result.errors).toBeDefined(); + expect(result.errors?.length).toBeGreaterThan(0); + }); + + test("returns valid when no schema provided", () => { + const content = { anything: "goes" }; + + const result = parser.validateStructuredOutput(content); + expect(result.valid).toBe(true); + }); + + test("validates nested objects", () => { + const content = { + user: { name: "John", age: 30 }, + }; + const schema = { + type: "object", + properties: { + user: { + type: "object", + properties: { + name: { type: "string" }, + age: { type: "number" }, + }, + }, + }, + }; + + const result = parser.validateStructuredOutput(content, schema); + expect(result.valid).toBe(true); + }); + + test("validates array schema", () => { + const schema = { + type: "array", + items: { type: "string" }, + }; + const validContent = ["a", "b", "c"]; + + const result = parser.validateStructuredOutput(validContent, schema); + expect(result.valid).toBe(true); + }); + }); + + describe("decodeBase64()", () => { + test("decodes valid base64 to Uint8Array", () => { + // "Hello" in base64 + const base64 = "SGVsbG8="; + const result = parser.decodeBase64(base64); + + expect(result).toBeInstanceOf(Uint8Array); + expect(result.length).toBe(5); + // H=72, e=101, l=108, l=108, o=111 + expect(result[0]).toBe(72); 
+ expect(result[4]).toBe(111); + }); + + test("decodes empty base64", () => { + const result = parser.decodeBase64(""); + expect(result).toBeInstanceOf(Uint8Array); + expect(result.length).toBe(0); + }); + + test("handles base64 with padding", () => { + // "Hi" = "SGk=" (with padding) + const result = parser.decodeBase64("SGk="); + expect(result.length).toBe(2); + }); + + test("handles base64 without padding", () => { + // Some base64 implementations strip padding + const result = parser.decodeBase64("SGk"); + expect(result.length).toBe(2); + }); + }); + + describe("getContentSize()", () => { + test("returns text length for text content", () => { + const content = { type: "text" as const, text: "Hello" }; + + expect(parser.getContentSize(content)).toBe(5); + }); + + test("returns estimated size for image content", () => { + const content = { + type: "image" as const, + data: "1234567890", // 10 chars + mimeType: "image/png", + }; + + // 10 * 0.75 = 7.5, floor = 7 + expect(parser.getContentSize(content)).toBe(7); + }); + + test("returns estimated size for audio content", () => { + const content = { + type: "audio" as const, + data: "12345678901234567890", // 20 chars + mimeType: "audio/wav", + }; + + // 20 * 0.75 = 15 + expect(parser.getContentSize(content)).toBe(15); + }); + + test("returns 0 for resource_link", () => { + const content = { + type: "resource_link" as const, + uri: "file:///test.txt", + }; + + expect(parser.getContentSize(content)).toBe(0); + }); + + test("returns text length for embedded resource with text", () => { + const content = { + type: "resource" as const, + resource: { + uri: "file:///data.json", + text: "Hello World", + }, + }; + + expect(parser.getContentSize(content)).toBe(11); + }); + }); + + describe("validateMimeType()", () => { + test("validates exact match", () => { + expect( + parser.validateMimeType("image/png", ["image/png", "image/jpeg"]), + ).toBe(true); + }); + + test("rejects non-matching type", () => { + 
expect(parser.validateMimeType("image/gif", ["image/png"])).toBe(false); + }); + + test("validates prefix match with wildcard", () => { + expect(parser.validateMimeType("image/png", ["image/*"])).toBe(true); + expect(parser.validateMimeType("image/jpeg", ["image/*"])).toBe(true); + expect(parser.validateMimeType("audio/wav", ["image/*"])).toBe(false); + }); + + test("validates audio mime types", () => { + expect(parser.validateMimeType("audio/wav", ["audio/*"])).toBe(true); + expect(parser.validateMimeType("audio/mp3", ["audio/*"])).toBe(true); + }); + }); }); diff --git a/packages/mcp/src/content/parser.ts b/packages/mcp/src/content/parser.ts index bd2306c..e93f271 100644 --- a/packages/mcp/src/content/parser.ts +++ b/packages/mcp/src/content/parser.ts @@ -7,172 +7,172 @@ import Ajv from "ajv"; import { - ToolContentSchema, - AudioContentSchema, - AudioMimeTypes, - ImageMimeTypes, - type ToolContent, - type AudioContent, + type AudioContent, + AudioContentSchema, + AudioMimeTypes, + ImageMimeTypes, + type ToolContent, + ToolContentSchema, } from "../types/content"; export interface ValidationResult { - valid: boolean; - errors?: string[]; + valid: boolean; + errors?: string[]; } export class ContentParser { - private ajv = new Ajv(); - - /** - * Parse raw content array into typed ToolContent objects. - * Validates types, base64 data, and mime types. 
- * @param rawContent - The raw content array from JSON-RPC result - * @throws Error if content is invalid - */ - parseContent(rawContent: unknown[]): ToolContent[] { - if (!Array.isArray(rawContent)) { - throw new Error("Content must be an array"); - } - - return rawContent.map((item, index) => { - const result = ToolContentSchema.safeParse(item); - if (!result.success) { - const issues = result.error.issues - .map((i) => `${i.path.join(".")}: ${i.message}`) - .join(", "); - throw new Error(`Invalid content at index ${index}: ${issues}`); - } - const content = result.data; - - // Enforce strict MIME type validation for image and audio - if (content.type === "image") { - if (!this.validateMimeType(content.mimeType, ImageMimeTypes)) { - throw new Error(`Invalid image MIME type: ${content.mimeType}`); - } - } - - if (content.type === "audio") { - if (!this.validateMimeType(content.mimeType, AudioMimeTypes)) { - throw new Error(`Invalid audio MIME type: ${content.mimeType}`); - } - } - - return content; - }); - } - - /** - * Parse audio content item. - * Validates that the item matches the AudioContent schema. - * @param item - The content item to parse - */ - parseAudio(item: unknown): AudioContent { - const result = AudioContentSchema.safeParse(item); - if (!result.success) { - const issues = result.error.issues - .map((i) => `${i.path.join(".")}: ${i.message}`) - .join(", "); - throw new Error(`Invalid audio content: ${issues}`); - } - return result.data; - } - - /** - * Validate structured content against a JSON schema. 
- * @param content - The structured content object - * @param schema - The JSON schema (outputSchema) - */ - validateStructuredOutput( - content: unknown, - schema?: object, - ): ValidationResult { - if (!schema) { - // No schema = always valid - return { valid: true }; - } - - try { - const validate = this.ajv.compile(schema); - const valid = validate(content); - - if (!valid) { - return { - valid: false, - errors: validate.errors?.map( - (e) => `${e.instancePath || "/"} ${e.message}`, - ), - }; - } - - return { valid: true }; - } catch (err: any) { - return { - valid: false, - errors: [`Schema compilation error: ${err.message}`], - }; - } - } - - /** - * Decode base64 data to Uint8Array. - * @param data - Base64 string - */ - decodeBase64(data: string): Uint8Array { - try { - // Use Buffer in Node.js environment for proper base64 decoding - if (typeof Buffer !== "undefined") { - return new Uint8Array(Buffer.from(data, "base64")); - } - // Fallback for browser-like environments - const binary = atob(data); - const bytes = new Uint8Array(binary.length); - for (let i = 0; i < binary.length; i++) { - bytes[i] = binary.charCodeAt(i); - } - return bytes; - } catch { - throw new Error("Invalid base64 data"); - } - } - - /** - * Get the estimated byte size of a content item. - * @param content - The tool content - * @returns Estimated byte size - */ - getContentSize(content: ToolContent): number { - switch (content.type) { - case "text": - return content.text.length; - case "image": - case "audio": - // Base64 is ~33% larger than binary, so multiply by 0.75 to get actual size - return Math.floor(content.data.length * 0.75); - case "resource_link": - return 0; - case "resource": - if (content.resource.text) return content.resource.text.length; - if (content.resource.blob) - return Math.floor(content.resource.blob.length * 0.75); - return 0; - } - } - - /** - * Validate a MIME type against allowed types. 
- * @param mimeType - The MIME type to validate - * @param allowedTypes - Array of allowed MIME types or prefixes - */ - validateMimeType(mimeType: string, allowedTypes: readonly string[]): boolean { - return allowedTypes.some((allowed) => { - if (allowed.endsWith("/*")) { - // Prefix match (e.g., "image/*") - const prefix = allowed.slice(0, -1); - return mimeType.startsWith(prefix); - } - return mimeType === allowed; - }); - } + private ajv = new Ajv(); + + /** + * Parse raw content array into typed ToolContent objects. + * Validates types, base64 data, and mime types. + * @param rawContent - The raw content array from JSON-RPC result + * @throws Error if content is invalid + */ + parseContent(rawContent: unknown[]): ToolContent[] { + if (!Array.isArray(rawContent)) { + throw new Error("Content must be an array"); + } + + return rawContent.map((item, index) => { + const result = ToolContentSchema.safeParse(item); + if (!result.success) { + const issues = result.error.issues + .map((i) => `${i.path.join(".")}: ${i.message}`) + .join(", "); + throw new Error(`Invalid content at index ${index}: ${issues}`); + } + const content = result.data; + + // Enforce strict MIME type validation for image and audio + if (content.type === "image") { + if (!this.validateMimeType(content.mimeType, ImageMimeTypes)) { + throw new Error(`Invalid image MIME type: ${content.mimeType}`); + } + } + + if (content.type === "audio") { + if (!this.validateMimeType(content.mimeType, AudioMimeTypes)) { + throw new Error(`Invalid audio MIME type: ${content.mimeType}`); + } + } + + return content; + }); + } + + /** + * Parse audio content item. + * Validates that the item matches the AudioContent schema. 
+ * @param item - The content item to parse + */ + parseAudio(item: unknown): AudioContent { + const result = AudioContentSchema.safeParse(item); + if (!result.success) { + const issues = result.error.issues + .map((i) => `${i.path.join(".")}: ${i.message}`) + .join(", "); + throw new Error(`Invalid audio content: ${issues}`); + } + return result.data; + } + + /** + * Validate structured content against a JSON schema. + * @param content - The structured content object + * @param schema - The JSON schema (outputSchema) + */ + validateStructuredOutput( + content: unknown, + schema?: object, + ): ValidationResult { + if (!schema) { + // No schema = always valid + return { valid: true }; + } + + try { + const validate = this.ajv.compile(schema); + const valid = validate(content); + + if (!valid) { + return { + valid: false, + errors: validate.errors?.map( + (e) => `${e.instancePath || "/"} ${e.message}`, + ), + }; + } + + return { valid: true }; + } catch (err: any) { + return { + valid: false, + errors: [`Schema compilation error: ${err.message}`], + }; + } + } + + /** + * Decode base64 data to Uint8Array. + * @param data - Base64 string + */ + decodeBase64(data: string): Uint8Array { + try { + // Use Buffer in Node.js environment for proper base64 decoding + if (typeof Buffer !== "undefined") { + return new Uint8Array(Buffer.from(data, "base64")); + } + // Fallback for browser-like environments + const binary = atob(data); + const bytes = new Uint8Array(binary.length); + for (let i = 0; i < binary.length; i++) { + bytes[i] = binary.charCodeAt(i); + } + return bytes; + } catch { + throw new Error("Invalid base64 data"); + } + } + + /** + * Get the estimated byte size of a content item. 
+ * @param content - The tool content + * @returns Estimated byte size + */ + getContentSize(content: ToolContent): number { + switch (content.type) { + case "text": + return content.text.length; + case "image": + case "audio": + // Base64 is ~33% larger than binary, so multiply by 0.75 to get actual size + return Math.floor(content.data.length * 0.75); + case "resource_link": + return 0; + case "resource": + if (content.resource.text) return content.resource.text.length; + if (content.resource.blob) + return Math.floor(content.resource.blob.length * 0.75); + return 0; + } + } + + /** + * Validate a MIME type against allowed types. + * @param mimeType - The MIME type to validate + * @param allowedTypes - Array of allowed MIME types or prefixes + */ + validateMimeType(mimeType: string, allowedTypes: readonly string[]): boolean { + return allowedTypes.some((allowed) => { + if (allowed.endsWith("/*")) { + // Prefix match (e.g., "image/*") + const prefix = allowed.slice(0, -1); + return mimeType.startsWith(prefix); + } + return mimeType === allowed; + }); + } } // Singleton instance diff --git a/packages/mcp/src/progress/tracker.test.ts b/packages/mcp/src/progress/tracker.test.ts index 8c1f7cf..9342fad 100644 --- a/packages/mcp/src/progress/tracker.test.ts +++ b/packages/mcp/src/progress/tracker.test.ts @@ -1,7 +1,7 @@ import { beforeEach, describe, expect, test } from "bun:test"; import { randomUUID } from "node:crypto"; -import { ProgressTracker } from "./tracker"; import { ToolOperationStore } from "../store/operation-store"; +import { ProgressTracker } from "./tracker"; describe("ProgressTracker", () => { let tracker: ProgressTracker; diff --git a/packages/mcp/src/progress/tracker.ts b/packages/mcp/src/progress/tracker.ts index 73c242f..b096d3a 100644 --- a/packages/mcp/src/progress/tracker.ts +++ b/packages/mcp/src/progress/tracker.ts @@ -6,99 +6,99 @@ */ import { v4 as uuidv4 } from "uuid"; -import type { ProgressNotification, ProgressUpdate } from 
"../types/progress"; import { toolOperationStore } from "../store/operation-store"; +import type { ProgressNotification, ProgressUpdate } from "../types/progress"; export class ProgressTracker { - /** Map: progressToken → operationId */ - private activeTokens = new Map(); + /** Map: progressToken → operationId */ + private activeTokens = new Map(); - /** - * Generate a unique progress token. - * Format: prog-{timestamp}-{uuid-prefix} - */ - generateToken(): string { - return `prog-${Date.now()}-${uuidv4().slice(0, 8)}`; - } + /** + * Generate a unique progress token. + * Format: prog-{timestamp}-{uuid-prefix} + */ + generateToken(): string { + return `prog-${Date.now()}-${uuidv4().slice(0, 8)}`; + } - /** - * Register an operation for progress tracking. - * @param token - The progress token - * @param operationId - The operation ID - */ - register(token: string, operationId: string): void { - this.activeTokens.set(token, operationId); - } + /** + * Register an operation for progress tracking. + * @param token - The progress token + * @param operationId - The operation ID + */ + register(token: string, operationId: string): void { + this.activeTokens.set(token, operationId); + } - /** - * Handle an incoming progress notification. - * Updates the associated operation in the store. - * @param notification - The progress notification - */ - handleNotification(notification: ProgressNotification): void { - const token = String(notification.progressToken); - const operationId = this.activeTokens.get(token); + /** + * Handle an incoming progress notification. + * Updates the associated operation in the store. 
+ * @param notification - The progress notification + */ + handleNotification(notification: ProgressNotification): void { + const token = String(notification.progressToken); + const operationId = this.activeTokens.get(token); - if (!operationId) { - // Ignore notifications for unknown tokens (could be from cancelled ops) - return; - } + if (!operationId) { + // Ignore notifications for unknown tokens (could be from cancelled ops) + return; + } - const update: ProgressUpdate = { - id: uuidv4(), - operationId, - progress: notification.progress, - total: notification.total, - message: notification.message, - timestamp: new Date(), - }; + const update: ProgressUpdate = { + id: uuidv4(), + operationId, + progress: notification.progress, + total: notification.total, + message: notification.message, + timestamp: new Date(), + }; - toolOperationStore.updateProgress(operationId, update); - } + toolOperationStore.updateProgress(operationId, update); + } - /** - * Unregister a token (cleanup). - * Called after tool call completes or is cancelled. - * @param token - The progress token - */ - unregister(token: string): void { - this.activeTokens.delete(token); - } + /** + * Unregister a token (cleanup). + * Called after tool call completes or is cancelled. + * @param token - The progress token + */ + unregister(token: string): void { + this.activeTokens.delete(token); + } - /** - * Get progress history for an operation. - * Delegates to the tool operation store. 
- * @param operationId - The operation ID - */ - getProgress(operationId: string): ProgressUpdate[] { - const operation = toolOperationStore.get(operationId); - if (!operation || !operation.progressUpdates) { - return []; - } - // Convert the stored progress to full ProgressUpdate objects - return operation.progressUpdates.map((p, index) => ({ - id: `${operationId}-progress-${index}`, - operationId, - progress: p.progress, - total: p.total, - message: p.message, - timestamp: p.timestamp, - })); - } + /** + * Get progress history for an operation. + * Delegates to the tool operation store. + * @param operationId - The operation ID + */ + getProgress(operationId: string): ProgressUpdate[] { + const operation = toolOperationStore.get(operationId); + if (!operation || !operation.progressUpdates) { + return []; + } + // Convert the stored progress to full ProgressUpdate objects + return operation.progressUpdates.map((p, index) => ({ + id: `${operationId}-progress-${index}`, + operationId, + progress: p.progress, + total: p.total, + message: p.message, + timestamp: p.timestamp, + })); + } - /** - * Check if a token is currently registered (for testing). - */ - isRegistered(token: string): boolean { - return this.activeTokens.has(token); - } + /** + * Check if a token is currently registered (for testing). + */ + isRegistered(token: string): boolean { + return this.activeTokens.has(token); + } - /** - * Get the number of active tokens (for testing). - */ - activeCount(): number { - return this.activeTokens.size; - } + /** + * Get the number of active tokens (for testing). 
+ */ + activeCount(): number { + return this.activeTokens.size; + } } // Singleton instance diff --git a/packages/mcp/src/store/operation-store.test.ts b/packages/mcp/src/store/operation-store.test.ts index b08f184..ba9d745 100644 --- a/packages/mcp/src/store/operation-store.test.ts +++ b/packages/mcp/src/store/operation-store.test.ts @@ -108,8 +108,8 @@ describe("ToolOperationStore", () => { const updated = store.get(op.id); expect(updated?.progressUpdates).toHaveLength(1); - expect(updated?.progressUpdates[0]!.progress).toBe(50); - expect(updated?.progressUpdates[0]!.message).toBe("Processing..."); + expect(updated?.progressUpdates[0]?.progress).toBe(50); + expect(updated?.progressUpdates[0]?.message).toBe("Processing..."); }); it("updateProgress throws for non-existent operation", () => { @@ -140,8 +140,8 @@ describe("ToolOperationStore", () => { const updates = store.getProgress(op.id); expect(updates).toHaveLength(2); - expect(updates[0]!.progress).toBe(25); - expect(updates[1]!.progress).toBe(50); + expect(updates[0]?.progress).toBe(25); + expect(updates[1]?.progress).toBe(50); }); it("getProgress returns empty array for non-existent operation", () => { @@ -211,7 +211,7 @@ describe("ToolOperationStore", () => { const updated = store.get(op.id); expect(updated?.completedAt).toBeDefined(); - expect(updated?.completedAt!.getTime()).toBeGreaterThanOrEqual( + expect(updated?.completedAt?.getTime()).toBeGreaterThanOrEqual( op.startedAt.getTime(), ); }); diff --git a/packages/mcp/src/store/operation-store.ts b/packages/mcp/src/store/operation-store.ts index fb36d1a..43acee4 100644 --- a/packages/mcp/src/store/operation-store.ts +++ b/packages/mcp/src/store/operation-store.ts @@ -10,209 +10,209 @@ */ import { v4 as uuidv4 } from "uuid"; +import type { ProgressUpdate } from "../types/progress"; import type { - ToolCallRequest, - ToolCallResult, - ToolOperation, - JsonRpcError, + JsonRpcError, + ToolCallRequest, + ToolCallResult, + ToolOperation, } from 
"../types/tool"; -import type { ProgressUpdate } from "../types/progress"; export class ToolOperationStore { - private operations = new Map(); - - /** - * Create a new pending tool operation. - * @param sessionId - The session this operation belongs to - * @param request - The tool call request - * @param requestId - The JSON-RPC request ID for correlation - * @returns The created ToolOperation in pending status - */ - create( - sessionId: string, - request: ToolCallRequest, - requestId: string, - ): ToolOperation { - const id = uuidv4(); - const operation: ToolOperation = { - id, - sessionId, - requestId, - request, - status: "pending", - progressUpdates: [], - cancelRequested: false, - startedAt: new Date(), - }; - - this.operations.set(operation.id, operation); - return operation; - } - - /** - * Update an existing operation with result, error, or other fields. - * @param id - The operation ID - * @param updates - Partial updates to apply - * @throws Error if operation not found - */ - update( - id: string, - updates: { - status?: ToolOperation["status"]; - result?: ToolCallResult; - error?: JsonRpcError; - progressToken?: string | number; - cancelReason?: string; - completedAt?: Date; - }, - ): void { - const operation = this.operations.get(id); - if (!operation) { - throw new Error(`Tool operation not found: ${id}`); - } - - if (updates.status) { - operation.status = updates.status; - } - - if (updates.result) { - operation.result = updates.result; - } - - if (updates.error) { - operation.error = updates.error; - } - - // Set completedAt for terminal states - if ( - updates.status === "completed" || - updates.status === "error" || - updates.status === "cancelled" - ) { - operation.completedAt = new Date(); - } - } - - /** - * Add a progress update to an operation. 
- * @param id - The operation ID - * @param update - The progress update - */ - updateProgress(id: string, update: ProgressUpdate): void { - const operation = this.operations.get(id); - if (!operation) { - throw new Error(`Tool operation not found: ${id}`); - } - - operation.progressUpdates.push({ - progress: update.progress, - total: update.total, - message: update.message, - timestamp: update.timestamp, - }); - } - - /** - * Mark an operation as cancelled. - * @param id - The operation ID - * @param reason - Optional cancellation reason - */ - markCancelled(id: string, reason?: string): void { - const operation = this.operations.get(id); - if (!operation) { - // Operation may have been cleared or never existed - silently ignore - return; - } - - // Only mark as cancelled if still pending - if (operation.status !== "pending") { - return; - } - - operation.status = "cancelled"; - operation.cancelRequested = true; - if (reason) { - operation.cancelReason = reason; - } - operation.completedAt = new Date(); - } - - /** - * Get all progress updates for an operation. - * @param operationId - The operation ID - * @returns Array of progress updates (empty if not found) - */ - getProgress(operationId: string): ToolOperation["progressUpdates"] { - const operation = this.operations.get(operationId); - if (!operation) { - return []; - } - return operation.progressUpdates; - } - - /** - * Get the most recent progress update. - * @param operationId - The operation ID - * @returns Latest update or undefined - */ - getLatestProgress( - operationId: string, - ): ToolOperation["progressUpdates"][number] | undefined { - const updates = this.getProgress(operationId); - return updates.length > 0 ? updates[updates.length - 1] : undefined; - } - - /** - * Get an operation by ID. 
- * @param id - The operation ID - * @returns The operation or undefined if not found - */ - get(id: string): ToolOperation | undefined { - return this.operations.get(id); - } - - /** - * Get all operations for a session. - * @param sessionId - The session ID - * @returns Array of operations for the session - */ - getBySession(sessionId: string): ToolOperation[] { - return Array.from(this.operations.values()).filter( - (op) => op.sessionId === sessionId, - ); - } - - /** - * Get an operation by its JSON-RPC request ID. - * Useful for correlating responses with pending operations. - * @param requestId - The JSON-RPC request ID - * @returns The operation or undefined if not found - */ - getByRequestId(requestId: string): ToolOperation | undefined { - return Array.from(this.operations.values()).find( - (op) => op.requestId === requestId, - ); - } - - /** - * Clear all operations for a session. - * Called when session is closed. - * @param sessionId - The session ID - */ - clear(sessionId: string): void { - for (const [id, op] of this.operations.entries()) { - if (op.sessionId === sessionId) { - this.operations.delete(id); - } - } - } - - /** - * Get count of operations (for testing). - */ - count(): number { - return this.operations.size; - } + private operations = new Map(); + + /** + * Create a new pending tool operation. 
+ * @param sessionId - The session this operation belongs to + * @param request - The tool call request + * @param requestId - The JSON-RPC request ID for correlation + * @returns The created ToolOperation in pending status + */ + create( + sessionId: string, + request: ToolCallRequest, + requestId: string, + ): ToolOperation { + const id = uuidv4(); + const operation: ToolOperation = { + id, + sessionId, + requestId, + request, + status: "pending", + progressUpdates: [], + cancelRequested: false, + startedAt: new Date(), + }; + + this.operations.set(operation.id, operation); + return operation; + } + + /** + * Update an existing operation with result, error, or other fields. + * @param id - The operation ID + * @param updates - Partial updates to apply + * @throws Error if operation not found + */ + update( + id: string, + updates: { + status?: ToolOperation["status"]; + result?: ToolCallResult; + error?: JsonRpcError; + progressToken?: string | number; + cancelReason?: string; + completedAt?: Date; + }, + ): void { + const operation = this.operations.get(id); + if (!operation) { + throw new Error(`Tool operation not found: ${id}`); + } + + if (updates.status) { + operation.status = updates.status; + } + + if (updates.result) { + operation.result = updates.result; + } + + if (updates.error) { + operation.error = updates.error; + } + + // Set completedAt for terminal states + if ( + updates.status === "completed" || + updates.status === "error" || + updates.status === "cancelled" + ) { + operation.completedAt = new Date(); + } + } + + /** + * Add a progress update to an operation. 
+ * @param id - The operation ID + * @param update - The progress update + */ + updateProgress(id: string, update: ProgressUpdate): void { + const operation = this.operations.get(id); + if (!operation) { + throw new Error(`Tool operation not found: ${id}`); + } + + operation.progressUpdates.push({ + progress: update.progress, + total: update.total, + message: update.message, + timestamp: update.timestamp, + }); + } + + /** + * Mark an operation as cancelled. + * @param id - The operation ID + * @param reason - Optional cancellation reason + */ + markCancelled(id: string, reason?: string): void { + const operation = this.operations.get(id); + if (!operation) { + // Operation may have been cleared or never existed - silently ignore + return; + } + + // Only mark as cancelled if still pending + if (operation.status !== "pending") { + return; + } + + operation.status = "cancelled"; + operation.cancelRequested = true; + if (reason) { + operation.cancelReason = reason; + } + operation.completedAt = new Date(); + } + + /** + * Get all progress updates for an operation. + * @param operationId - The operation ID + * @returns Array of progress updates (empty if not found) + */ + getProgress(operationId: string): ToolOperation["progressUpdates"] { + const operation = this.operations.get(operationId); + if (!operation) { + return []; + } + return operation.progressUpdates; + } + + /** + * Get the most recent progress update. + * @param operationId - The operation ID + * @returns Latest update or undefined + */ + getLatestProgress( + operationId: string, + ): ToolOperation["progressUpdates"][number] | undefined { + const updates = this.getProgress(operationId); + return updates.length > 0 ? updates[updates.length - 1] : undefined; + } + + /** + * Get an operation by ID. 
+ * @param id - The operation ID + * @returns The operation or undefined if not found + */ + get(id: string): ToolOperation | undefined { + return this.operations.get(id); + } + + /** + * Get all operations for a session. + * @param sessionId - The session ID + * @returns Array of operations for the session + */ + getBySession(sessionId: string): ToolOperation[] { + return Array.from(this.operations.values()).filter( + (op) => op.sessionId === sessionId, + ); + } + + /** + * Get an operation by its JSON-RPC request ID. + * Useful for correlating responses with pending operations. + * @param requestId - The JSON-RPC request ID + * @returns The operation or undefined if not found + */ + getByRequestId(requestId: string): ToolOperation | undefined { + return Array.from(this.operations.values()).find( + (op) => op.requestId === requestId, + ); + } + + /** + * Clear all operations for a session. + * Called when session is closed. + * @param sessionId - The session ID + */ + clear(sessionId: string): void { + for (const [id, op] of this.operations.entries()) { + if (op.sessionId === sessionId) { + this.operations.delete(id); + } + } + } + + /** + * Get count of operations (for testing). + */ + count(): number { + return this.operations.size; + } } // Singleton instance diff --git a/packages/mcp/src/task/manager.test.ts b/packages/mcp/src/task/manager.test.ts new file mode 100644 index 0000000..06bc7ce --- /dev/null +++ b/packages/mcp/src/task/manager.test.ts @@ -0,0 +1,405 @@ +/** + * TaskManager Unit Tests + * + * Tests for the TaskManager class that handles task lifecycle. 
+ * Task 07: Task-Augmented Execution - Phase 2 Unit Tests + */ + +import { afterEach, beforeEach, describe, expect, it } from "bun:test"; +import type { Task } from "../types/task"; +import { TaskManager } from "./manager"; + +describe("TaskManager", () => { + let manager: TaskManager; + + beforeEach(() => { + manager = new TaskManager({ pollIntervalMs: 10, maxPollAttempts: 5 }); + }); + + afterEach(() => { + manager.clear(); + }); + + // ========================================================================= + // registerTask + // ========================================================================= + + describe("registerTask", () => { + it("stores task in cache with defaults", () => { + manager.registerTask("task-1", "session-1"); + + const task = manager.getTask("task-1"); + expect(task).toBeDefined(); + expect(task?.taskId).toBe("task-1"); + expect(task?.status).toBe("working"); + expect(task?.ttl).toBeNull(); + }); + + it("merges initial task state", () => { + manager.registerTask("task-2", "session-1", { + status: "completed", + statusMessage: "Already done", + ttl: 60000, + }); + + const task = manager.getTask("task-2"); + expect(task?.status).toBe("completed"); + expect(task?.statusMessage).toBe("Already done"); + expect(task?.ttl).toBe(60000); + }); + + it("generates timestamps for createdAt and lastUpdatedAt", () => { + manager.registerTask("task-3", "session-1"); + + const task = manager.getTask("task-3"); + expect(task?.createdAt).toBeDefined(); + expect(task?.lastUpdatedAt).toBeDefined(); + // Should be ISO 8601 format + expect(() => new Date(task?.createdAt)).not.toThrow(); + }); + + it("overwrites existing task with same ID", () => { + manager.registerTask("task-dup", "session-1", { statusMessage: "First" }); + manager.registerTask("task-dup", "session-1", { + statusMessage: "Second", + }); + + const task = manager.getTask("task-dup"); + expect(task?.statusMessage).toBe("Second"); + }); + }); + + // 
========================================================================= + // getTask + // ========================================================================= + + describe("getTask", () => { + it("returns registered task", () => { + manager.registerTask("task-get", "session-1", { statusMessage: "test" }); + + const task = manager.getTask("task-get"); + expect(task).toBeDefined(); + expect(task?.statusMessage).toBe("test"); + }); + + it("returns undefined for unknown task", () => { + const task = manager.getTask("unknown-task"); + expect(task).toBeUndefined(); + }); + }); + + // ========================================================================= + // getTasksBySession + // ========================================================================= + + describe("getTasksBySession", () => { + it("returns all tasks (session filtering not yet implemented)", () => { + manager.registerTask("task-a", "session-1"); + manager.registerTask("task-b", "session-2"); + + const tasks = manager.getTasksBySession("session-1"); + // Currently returns all tasks regardless of session + expect(tasks.length).toBe(2); + }); + + it("returns empty array when no tasks", () => { + const tasks = manager.getTasksBySession("session-1"); + expect(tasks).toEqual([]); + }); + }); + + // ========================================================================= + // pollUntilComplete + // ========================================================================= + + describe("pollUntilComplete", () => { + it("returns immediately on completed status", async () => { + manager.registerTask("task-done", "session-1"); + let pollCount = 0; + + const result = await manager.pollUntilComplete("task-done", async () => { + pollCount++; + return { + taskId: "task-done", + status: "completed", + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString(), + ttl: null, + }; + }); + + expect(result.status).toBe("completed"); + expect(pollCount).toBe(1); + }); + + it("returns immediately 
on failed status", async () => { + manager.registerTask("task-fail", "session-1"); + + const result = await manager.pollUntilComplete("task-fail", async () => ({ + taskId: "task-fail", + status: "failed", + statusMessage: "Something went wrong", + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString(), + ttl: null, + })); + + expect(result.status).toBe("failed"); + expect(result.statusMessage).toBe("Something went wrong"); + }); + + it("returns immediately on cancelled status", async () => { + manager.registerTask("task-cancel", "session-1"); + + const result = await manager.pollUntilComplete( + "task-cancel", + async () => ({ + taskId: "task-cancel", + status: "cancelled", + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString(), + ttl: null, + }), + ); + + expect(result.status).toBe("cancelled"); + }); + + it("continues polling on working status", async () => { + manager.registerTask("task-working", "session-1"); + let pollCount = 0; + + const result = await manager.pollUntilComplete( + "task-working", + async () => { + pollCount++; + // Complete after 3 polls + return { + taskId: "task-working", + status: pollCount >= 3 ? 
"completed" : "working", + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString(), + ttl: null, + }; + }, + ); + + expect(result.status).toBe("completed"); + expect(pollCount).toBe(3); + }); + + it("continues polling on input_required status", async () => { + // Use fresh manager with higher limits to avoid flakiness + const testManager = new TaskManager({ + pollIntervalMs: 10, + maxPollAttempts: 10, + }); + testManager.registerTask("task-input", "session-1"); + let pollCount = 0; + + const result = await testManager.pollUntilComplete( + "task-input", + async () => { + pollCount++; + // Status: input_required → input_required → completed + if (pollCount >= 3) { + return { + taskId: "task-input", + status: "completed" as const, + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString(), + ttl: null, + }; + } + return { + taskId: "task-input", + status: "input_required" as const, + statusMessage: "Waiting for user input", + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString(), + ttl: null, + }; + }, + ); + + expect(result.status).toBe("completed"); + expect(pollCount).toBe(3); + }); + + it("respects task pollInterval when provided", async () => { + manager.registerTask("task-interval", "session-1"); + const startTime = Date.now(); + let pollCount = 0; + + await manager.pollUntilComplete("task-interval", async () => { + pollCount++; + return { + taskId: "task-interval", + status: pollCount >= 2 ? 
"completed" : "working", + pollInterval: 50, // 50ms suggested interval + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString(), + ttl: null, + }; + }); + + const elapsed = Date.now() - startTime; + // Should have waited ~50ms between polls (not the default 10ms) + expect(elapsed).toBeGreaterThanOrEqual(40); + }); + + it("throws error after max poll attempts", async () => { + manager.registerTask("task-timeout", "session-1"); + + await expect( + manager.pollUntilComplete("task-timeout", async () => ({ + taskId: "task-timeout", + status: "working", // Never completes + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString(), + ttl: null, + })), + ).rejects.toThrow("Task task-timeout did not complete within timeout"); + }); + + it("calls onProgress callback on each poll", async () => { + manager.registerTask("task-progress", "session-1"); + const progressCalls: Task[] = []; + let pollCount = 0; + + await manager.pollUntilComplete( + "task-progress", + async () => { + pollCount++; + return { + taskId: "task-progress", + status: pollCount >= 2 ? "completed" : "working", + statusMessage: `Poll ${pollCount}`, + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString(), + ttl: null, + }; + }, + (task) => { + progressCalls.push(task); + }, + ); + + expect(progressCalls.length).toBe(2); + expect(progressCalls[0]?.statusMessage).toBe("Poll 1"); + expect(progressCalls[1]?.statusMessage).toBe("Poll 2"); + }); + + it("updates cache on each poll", async () => { + manager.registerTask("task-cache", "session-1", { + statusMessage: "Initial", + }); + let pollCount = 0; + + await manager.pollUntilComplete("task-cache", async () => { + pollCount++; + return { + taskId: "task-cache", + status: pollCount >= 2 ? 
"completed" : "working", + statusMessage: `Updated ${pollCount}`, + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString(), + ttl: null, + }; + }); + + const cached = manager.getTask("task-cache"); + expect(cached?.statusMessage).toBe("Updated 2"); + }); + }); + + // ========================================================================= + // handleStatusNotification + // ========================================================================= + + describe("handleStatusNotification", () => { + it("updates cached task", () => { + manager.registerTask("task-notify", "session-1", { + statusMessage: "Original", + }); + + manager.handleStatusNotification({ + taskId: "task-notify", + status: "completed", + statusMessage: "Updated via notification", + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString(), + ttl: null, + }); + + const task = manager.getTask("task-notify"); + expect(task?.status).toBe("completed"); + expect(task?.statusMessage).toBe("Updated via notification"); + }); + + it("creates task if not exists", () => { + manager.handleStatusNotification({ + taskId: "task-new", + status: "working", + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString(), + ttl: null, + }); + + const task = manager.getTask("task-new"); + expect(task).toBeDefined(); + expect(task?.status).toBe("working"); + }); + }); + + // ========================================================================= + // removeTask / clear + // ========================================================================= + + describe("removeTask", () => { + it("removes task from cache", () => { + manager.registerTask("task-remove", "session-1"); + expect(manager.getTask("task-remove")).toBeDefined(); + + manager.removeTask("task-remove"); + expect(manager.getTask("task-remove")).toBeUndefined(); + }); + + it("does not throw for unknown task", () => { + expect(() => manager.removeTask("unknown")).not.toThrow(); + }); + }); + + 
describe("clear", () => { + it("removes all tasks", () => { + manager.registerTask("task-a", "session-1"); + manager.registerTask("task-b", "session-1"); + manager.registerTask("task-c", "session-2"); + + manager.clear(); + + expect(manager.getTask("task-a")).toBeUndefined(); + expect(manager.getTask("task-b")).toBeUndefined(); + expect(manager.getTask("task-c")).toBeUndefined(); + }); + }); + + // ========================================================================= + // Constructor Options + // ========================================================================= + + describe("constructor options", () => { + it("uses default pollIntervalMs of 1000", async () => { + const defaultManager = new TaskManager(); + // Can't easily test the default interval without waiting, + // but we can verify the manager was created successfully + expect(defaultManager).toBeDefined(); + }); + + it("uses default maxPollAttempts of 300", async () => { + const defaultManager = new TaskManager({ pollIntervalMs: 1 }); + // Just verify it was created - testing 300 attempts would be slow + expect(defaultManager).toBeDefined(); + }); + }); +}); diff --git a/packages/mcp/src/task/manager.ts b/packages/mcp/src/task/manager.ts new file mode 100644 index 0000000..a4f16b6 --- /dev/null +++ b/packages/mcp/src/task/manager.ts @@ -0,0 +1,161 @@ +/** + * TaskManager + * + * Manages task lifecycle for task-augmented tool execution. + * Handles registration, polling, caching, and status notifications. + */ + +import type { Task, TaskStatus } from "../types/task"; + +// ============================================================================= +// Types +// ============================================================================= + +export interface TaskManagerOptions { + /** Default polling interval in milliseconds. Default: 1000 */ + pollIntervalMs?: number; + /** Maximum polling attempts before timeout. 
Default: 300 (5 minutes at 1s) */ + maxPollAttempts?: number; +} + +// ============================================================================= +// TaskManager +// ============================================================================= + +export class TaskManager { + private tasks = new Map(); + private pollIntervalMs: number; + private maxPollAttempts: number; + + constructor(options: TaskManagerOptions = {}) { + this.pollIntervalMs = options.pollIntervalMs ?? 1000; + this.maxPollAttempts = options.maxPollAttempts ?? 300; + } + + /** + * Register a new task in the manager. + * @param taskId - The task identifier + * @param sessionId - The session that owns this task + * @param initialTask - Initial task state from server + */ + registerTask( + taskId: string, + _sessionId: string, + initialTask: Partial = {}, + ): void { + this.tasks.set(taskId, { + taskId, + status: "working", + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString(), + ttl: null, + ...initialTask, + }); + } + + /** + * Poll until task reaches a terminal status or shouldStop returns true. 
+ * @param taskId - The task to poll + * @param fetchStatus - Callback to fetch current task status from server + * @param onProgress - Optional callback for status updates + * @param shouldStop - Optional callback to stop polling early (e.g., for input_required) + * @returns The final task state + */ + async pollUntilComplete( + taskId: string, + fetchStatus: () => Promise, + onProgress?: (task: Task) => void, + shouldStop?: (task: Task) => boolean, + ): Promise { + let attempts = 0; + + while (attempts < this.maxPollAttempts) { + const task = await fetchStatus(); + this.tasks.set(taskId, task); + + if (onProgress) { + onProgress(task); + } + + if (this.isTerminalStatus(task.status)) { + return task; + } + + // Check early exit condition (e.g., input_required) + if (shouldStop?.(task)) { + return task; + } + + // Use task's suggested pollInterval if available + const interval = task.pollInterval ?? this.pollIntervalMs; + await this.sleep(interval); + attempts++; + } + + throw new Error(`Task ${taskId} did not complete within timeout`); + } + + /** + * Check if a status is terminal (task processing complete per MCP spec). + * Note: input_required is NOT terminal - task waits for input. + */ + private isTerminalStatus(status: TaskStatus): boolean { + return ["completed", "failed", "cancelled"].includes(status); + } + + /** + * Sleep for specified milliseconds. + */ + private sleep(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)); + } + + /** + * Get a task from the cache. + * @param taskId - The task identifier + * @returns The cached task or undefined + */ + getTask(taskId: string): Task | undefined { + return this.tasks.get(taskId); + } + + /** + * Get all tasks for a session. 
+ * @param sessionId - The session identifier + * @returns Array of tasks (note: sessionId filtering not yet implemented) + */ + getTasksBySession(_sessionId: string): Task[] { + // TODO: Add sessionId to Task and filter + return Array.from(this.tasks.values()); + } + + /** + * Handle incoming task status notification from server. + * Called by McpClientManager's notification handler. + * @param params - Task status notification payload + */ + handleStatusNotification(params: Task): void { + this.tasks.set(params.taskId, params); + } + + /** + * Remove a task from the cache. + * @param taskId - The task identifier + */ + removeTask(taskId: string): void { + this.tasks.delete(taskId); + } + + /** + * Clear all tasks from the cache. + */ + clear(): void { + this.tasks.clear(); + } +} + +// ============================================================================= +// Singleton Export +// ============================================================================= + +export const taskManager = new TaskManager(); diff --git a/packages/mcp/src/types/cancel.test.ts b/packages/mcp/src/types/cancel.test.ts index 4995475..2215ebc 100644 --- a/packages/mcp/src/types/cancel.test.ts +++ b/packages/mcp/src/types/cancel.test.ts @@ -2,72 +2,72 @@ import { describe, expect, test } from "bun:test"; import { CancelNotificationSchema, PendingRequestSchema } from "./cancel"; describe("Cancellation Schemas", () => { - describe("CancelNotificationSchema", () => { - test("validates notification with string requestId", () => { - const valid = { - requestId: "req-123", - reason: "User cancelled", - }; - const result = CancelNotificationSchema.safeParse(valid); - expect(result.success).toBe(true); - }); + describe("CancelNotificationSchema", () => { + test("validates notification with string requestId", () => { + const valid = { + requestId: "req-123", + reason: "User cancelled", + }; + const result = CancelNotificationSchema.safeParse(valid); + expect(result.success).toBe(true); + }); 
- test("validates notification with number requestId", () => { - const valid = { - requestId: 42, - }; - const result = CancelNotificationSchema.safeParse(valid); - expect(result.success).toBe(true); - }); + test("validates notification with number requestId", () => { + const valid = { + requestId: 42, + }; + const result = CancelNotificationSchema.safeParse(valid); + expect(result.success).toBe(true); + }); - test("validates notification without reason", () => { - const valid = { - requestId: "req-456", - }; - const result = CancelNotificationSchema.safeParse(valid); - expect(result.success).toBe(true); - }); + test("validates notification without reason", () => { + const valid = { + requestId: "req-456", + }; + const result = CancelNotificationSchema.safeParse(valid); + expect(result.success).toBe(true); + }); - test("rejects missing requestId", () => { - const invalid = { - reason: "Some reason", - }; - const result = CancelNotificationSchema.safeParse(invalid); - expect(result.success).toBe(false); - }); - }); + test("rejects missing requestId", () => { + const invalid = { + reason: "Some reason", + }; + const result = CancelNotificationSchema.safeParse(invalid); + expect(result.success).toBe(false); + }); + }); - describe("PendingRequestSchema", () => { - test("validates valid pending request", () => { - const valid = { - requestId: "req-789", - operationId: "123e4567-e89b-12d3-a456-426614174000", - startedAt: new Date(), - timeoutMs: 30000, - }; - const result = PendingRequestSchema.safeParse(valid); - expect(result.success).toBe(true); - }); + describe("PendingRequestSchema", () => { + test("validates valid pending request", () => { + const valid = { + requestId: "req-789", + operationId: "123e4567-e89b-12d3-a456-426614174000", + startedAt: new Date(), + timeoutMs: 30000, + }; + const result = PendingRequestSchema.safeParse(valid); + expect(result.success).toBe(true); + }); - test("rejects invalid operationId UUID", () => { - const invalid = { - requestId: 
"req-789", - operationId: "not-a-uuid", - startedAt: new Date(), - timeoutMs: 30000, - }; - const result = PendingRequestSchema.safeParse(invalid); - expect(result.success).toBe(false); - }); + test("rejects invalid operationId UUID", () => { + const invalid = { + requestId: "req-789", + operationId: "not-a-uuid", + startedAt: new Date(), + timeoutMs: 30000, + }; + const result = PendingRequestSchema.safeParse(invalid); + expect(result.success).toBe(false); + }); - test("rejects missing timeoutMs", () => { - const invalid = { - requestId: "req-789", - operationId: "123e4567-e89b-12d3-a456-426614174000", - startedAt: new Date(), - }; - const result = PendingRequestSchema.safeParse(invalid); - expect(result.success).toBe(false); - }); - }); + test("rejects missing timeoutMs", () => { + const invalid = { + requestId: "req-789", + operationId: "123e4567-e89b-12d3-a456-426614174000", + startedAt: new Date(), + }; + const result = PendingRequestSchema.safeParse(invalid); + expect(result.success).toBe(false); + }); + }); }); diff --git a/packages/mcp/src/types/cancel.ts b/packages/mcp/src/types/cancel.ts index 951b8be..ac5ad10 100644 --- a/packages/mcp/src/types/cancel.ts +++ b/packages/mcp/src/types/cancel.ts @@ -11,8 +11,8 @@ import { z } from "zod"; * Notification sent to cancel a request. */ export const CancelNotificationSchema = z.object({ - requestId: z.union([z.string(), z.number()]), - reason: z.string().optional(), + requestId: z.union([z.string(), z.number()]), + reason: z.string().optional(), }); export type CancelNotification = z.infer; @@ -21,10 +21,10 @@ export type CancelNotification = z.infer; * Tracks a pending request that can be cancelled. 
*/ export const PendingRequestSchema = z.object({ - requestId: z.string(), - operationId: z.string().uuid(), - startedAt: z.date(), - timeoutMs: z.number(), + requestId: z.string(), + operationId: z.string().uuid(), + startedAt: z.date(), + timeoutMs: z.number(), }); export type PendingRequest = z.infer; diff --git a/packages/mcp/src/types/content.test.ts b/packages/mcp/src/types/content.test.ts index d13d6bd..3df62d2 100644 --- a/packages/mcp/src/types/content.test.ts +++ b/packages/mcp/src/types/content.test.ts @@ -1,87 +1,87 @@ import { describe, expect, test } from "bun:test"; import { - AnnotationsSchema, - AudioContentSchema, - AudioMimeTypes, - ImageContentSchema, - ImageMimeTypes, - ResourceLinkContentSchema, - TextContentSchema, + AnnotationsSchema, + AudioContentSchema, + AudioMimeTypes, + ImageContentSchema, + ImageMimeTypes, + ResourceLinkContentSchema, + TextContentSchema, } from "./content"; describe("Content Schemas", () => { - describe("AnnotationsSchema", () => { - test("validates valid annotations", () => { - const valid = { - audience: ["user"], - priority: 0.5, - }; - const result = AnnotationsSchema.safeParse(valid); - expect(result.success).toBe(true); - }); + describe("AnnotationsSchema", () => { + test("validates valid annotations", () => { + const valid = { + audience: ["user"], + priority: 0.5, + }; + const result = AnnotationsSchema.safeParse(valid); + expect(result.success).toBe(true); + }); - test("allows missing optional fields", () => { - const valid = {}; - const result = AnnotationsSchema.safeParse(valid); - expect(result.success).toBe(true); - }); + test("allows missing optional fields", () => { + const valid = {}; + const result = AnnotationsSchema.safeParse(valid); + expect(result.success).toBe(true); + }); - test("validates audience values", () => { - const invalid = { audience: ["admin"] }; - const result = AnnotationsSchema.safeParse(invalid); - expect(result.success).toBe(false); - }); + test("validates audience values", () 
=> { + const invalid = { audience: ["admin"] }; + const result = AnnotationsSchema.safeParse(invalid); + expect(result.success).toBe(false); + }); - test("validates priority range", () => { - const invalid = { priority: 1.5 }; - const result = AnnotationsSchema.safeParse(invalid); - expect(result.success).toBe(false); - }); - }); + test("validates priority range", () => { + const invalid = { priority: 1.5 }; + const result = AnnotationsSchema.safeParse(invalid); + expect(result.success).toBe(false); + }); + }); - describe("Content Types", () => { - test("TextContentSchema validates", () => { - const valid = { type: "text", text: "Hello" }; - expect(TextContentSchema.safeParse(valid).success).toBe(true); - }); + describe("Content Types", () => { + test("TextContentSchema validates", () => { + const valid = { type: "text", text: "Hello" }; + expect(TextContentSchema.safeParse(valid).success).toBe(true); + }); - test("ImageContentSchema validates basic structure", () => { - const valid = { - type: "image", - data: "abc", - mimeType: "image/png", - }; - expect(ImageContentSchema.safeParse(valid).success).toBe(true); - }); + test("ImageContentSchema validates basic structure", () => { + const valid = { + type: "image", + data: "abc", + mimeType: "image/png", + }; + expect(ImageContentSchema.safeParse(valid).success).toBe(true); + }); - test("AudioContentSchema validates basic structure", () => { - const valid = { - type: "audio", - data: "abc", - mimeType: "audio/wav", - }; - expect(AudioContentSchema.safeParse(valid).success).toBe(true); - }); + test("AudioContentSchema validates basic structure", () => { + const valid = { + type: "audio", + data: "abc", + mimeType: "audio/wav", + }; + expect(AudioContentSchema.safeParse(valid).success).toBe(true); + }); - test("ResourceLinkContentSchema validates", () => { - const valid = { - type: "resource_link", - uri: "file:///test.txt", - name: "Test", - }; - expect(ResourceLinkContentSchema.safeParse(valid).success).toBe(true); 
- }); - }); + test("ResourceLinkContentSchema validates", () => { + const valid = { + type: "resource_link", + uri: "file:///test.txt", + name: "Test", + }; + expect(ResourceLinkContentSchema.safeParse(valid).success).toBe(true); + }); + }); - describe("Mime Types Lists", () => { - test("ImageMimeTypes contains standard types", () => { - expect(ImageMimeTypes).toContain("image/png"); - expect(ImageMimeTypes).toContain("image/jpeg"); - }); + describe("Mime Types Lists", () => { + test("ImageMimeTypes contains standard types", () => { + expect(ImageMimeTypes).toContain("image/png"); + expect(ImageMimeTypes).toContain("image/jpeg"); + }); - test("AudioMimeTypes contains standard types", () => { - expect(AudioMimeTypes).toContain("audio/wav"); - expect(AudioMimeTypes).toContain("audio/mp3"); - }); - }); + test("AudioMimeTypes contains standard types", () => { + expect(AudioMimeTypes).toContain("audio/wav"); + expect(AudioMimeTypes).toContain("audio/mp3"); + }); + }); }); diff --git a/packages/mcp/src/types/content.ts b/packages/mcp/src/types/content.ts index 48d8f77..6aaf4bc 100644 --- a/packages/mcp/src/types/content.ts +++ b/packages/mcp/src/types/content.ts @@ -11,43 +11,50 @@ import { z } from "zod"; * Supported audio MIME types per MCP spec. */ export const AudioMimeTypes = [ - "audio/wav", - "audio/mp3", - "audio/mpeg", - "audio/ogg", - "audio/webm", - "audio/flac", + "audio/wav", + "audio/mp3", + "audio/mpeg", + "audio/ogg", + "audio/webm", + "audio/flac", ] as const; /** * Supported image MIME types per MCP spec. */ export const ImageMimeTypes = [ - "image/png", - "image/jpeg", - "image/gif", - "image/webp", - "image/svg+xml", + "image/png", + "image/jpeg", + "image/gif", + "image/webp", + "image/svg+xml", ] as const; /** * Annotations for content items. * Used to indicate intended audience and priority. + * + * NOTE: This is ContentAnnotationsSchema (audience/priority for content). 
+ * For tool behavioral hints, see ToolAnnotationsSchema in tool-annotations.ts. */ -export const AnnotationsSchema = z.object({ - audience: z.array(z.enum(["user", "assistant"])).optional(), - priority: z.number().min(0).max(1).optional(), +export const ContentAnnotationsSchema = z.object({ + audience: z.array(z.enum(["user", "assistant"])).optional(), + priority: z.number().min(0).max(1).optional(), }); -export type Annotations = z.infer<typeof AnnotationsSchema>; +export type ContentAnnotations = z.infer<typeof ContentAnnotationsSchema>; + +// Backward compatibility alias +export const AnnotationsSchema = ContentAnnotationsSchema; +export type Annotations = ContentAnnotations; /** * Text content returned by a tool. */ export const TextContentSchema = z.object({ - type: z.literal("text"), - text: z.string(), - annotations: AnnotationsSchema.optional(), + type: z.literal("text"), + text: z.string(), + annotations: AnnotationsSchema.optional(), }); export type TextContent = z.infer<typeof TextContentSchema>; @@ -56,10 +63,10 @@ export type TextContent = z.infer<typeof TextContentSchema>; * Image content returned by a tool (base64 encoded). */ export const ImageContentSchema = z.object({ - type: z.literal("image"), - data: z.string(), // base64 - mimeType: z.string(), - annotations: AnnotationsSchema.optional(), + type: z.literal("image"), + data: z.string(), // base64 + mimeType: z.string(), + annotations: AnnotationsSchema.optional(), }); export type ImageContent = z.infer<typeof ImageContentSchema>; @@ -69,10 +76,10 @@ export type ImageContent = z.infer<typeof ImageContentSchema>; * Added in later MCP spec versions. */ export const AudioContentSchema = z.object({ - type: z.literal("audio"), - data: z.string(), // base64 - mimeType: z.string(), - annotations: AnnotationsSchema.optional(), + type: z.literal("audio"), + data: z.string(), // base64 + mimeType: z.string(), + annotations: AnnotationsSchema.optional(), }); export type AudioContent = z.infer<typeof AudioContentSchema>; @@ -81,11 +88,11 @@ export type AudioContent = z.infer<typeof AudioContentSchema>; * Resource link content - a reference to a resource.
*/ export const ResourceLinkContentSchema = z.object({ - type: z.literal("resource_link"), - uri: z.string(), - name: z.string().optional(), - mimeType: z.string().optional(), - annotations: AnnotationsSchema.optional(), + type: z.literal("resource_link"), + uri: z.string(), + name: z.string().optional(), + mimeType: z.string().optional(), + annotations: AnnotationsSchema.optional(), }); export type ResourceLinkContent = z.infer<typeof ResourceLinkContentSchema>; @@ -94,27 +101,29 @@ export type ResourceLinkContent = z.infer<typeof ResourceLinkContentSchema>; * Embedded resource content - inline resource data. */ export const EmbeddedResourceContentSchema = z.object({ - type: z.literal("resource"), - resource: z.object({ - uri: z.string(), - mimeType: z.string().optional(), - text: z.string().optional(), - blob: z.string().optional(), // base64 - }), - annotations: AnnotationsSchema.optional(), + type: z.literal("resource"), + resource: z.object({ + uri: z.string(), + mimeType: z.string().optional(), + text: z.string().optional(), + blob: z.string().optional(), // base64 + }), + annotations: AnnotationsSchema.optional(), }); -export type EmbeddedResourceContent = z.infer<typeof EmbeddedResourceContentSchema>; +export type EmbeddedResourceContent = z.infer< + typeof EmbeddedResourceContentSchema +>; /** * Helper schema for any tool content item.
*/ export const ToolContentSchema = z.discriminatedUnion("type", [ - TextContentSchema, - ImageContentSchema, - AudioContentSchema, - ResourceLinkContentSchema, - EmbeddedResourceContentSchema, + TextContentSchema, + ImageContentSchema, + AudioContentSchema, + ResourceLinkContentSchema, + EmbeddedResourceContentSchema, ]); export type ToolContent = z.infer<typeof ToolContentSchema>; diff --git a/packages/mcp/src/types/index.ts b/packages/mcp/src/types/index.ts index e302932..4d1445b 100644 --- a/packages/mcp/src/types/index.ts +++ b/packages/mcp/src/types/index.ts @@ -18,7 +18,9 @@ export interface McpClientEntry { // Forward reference - LoggingTransport is defined in transport module import type { LoggingTransport } from "../transport"; +export * from "./cancel"; +export * from "./progress"; +export * from "./task"; // Tool operation types (Phase 2a) export * from "./tool"; -export * from "./progress"; -export * from "./cancel"; +export * from "./tool-annotations"; diff --git a/packages/mcp/src/types/progress.ts b/packages/mcp/src/types/progress.ts index 72fa395..59c3f31 100644 --- a/packages/mcp/src/types/progress.ts +++ b/packages/mcp/src/types/progress.ts @@ -19,10 +19,10 @@ export type ProgressToken = z.infer<typeof ProgressTokenSchema>; * Progress notification params received from server. */ export const ProgressNotificationSchema = z.object({ - progressToken: ProgressTokenSchema, - progress: z.number(), - total: z.number().optional(), - message: z.string().optional(), + progressToken: ProgressTokenSchema, + progress: z.number(), + total: z.number().optional(), + message: z.string().optional(), }); export type ProgressNotification = z.infer<typeof ProgressNotificationSchema>; @@ -32,8 +32,8 @@ export type ProgressNotification = z.infer<typeof ProgressNotificationSchema>; * Used for setNotificationHandler to register progress notification handler.
*/ export const McpProgressNotificationSchema = z.object({ - method: z.literal("notifications/progress"), - params: ProgressNotificationSchema, + method: z.literal("notifications/progress"), + params: ProgressNotificationSchema, }); /** @@ -41,12 +41,12 @@ export const McpProgressNotificationSchema = z.object({ * Adds timestamp and ID to the raw notification data. */ export const ProgressUpdateSchema = z.object({ - id: z.string().uuid(), - operationId: z.string().uuid(), - progress: z.number(), - total: z.number().optional(), - message: z.string().optional(), - timestamp: z.date(), + id: z.string().uuid(), + operationId: z.string().uuid(), + progress: z.number(), + total: z.number().optional(), + message: z.string().optional(), + timestamp: z.date(), }); export type ProgressUpdate = z.infer<typeof ProgressUpdateSchema>; diff --git a/packages/mcp/src/types/task.test.ts b/packages/mcp/src/types/task.test.ts new file mode 100644 index 0000000..90eb3cb --- /dev/null +++ b/packages/mcp/src/types/task.test.ts @@ -0,0 +1,349 @@ +/** + * Task Schema Tests + * + * Unit tests for Task-related Zod schemas.
+ * Task 07: Task-Augmented Execution - Phase 1 Schema Tests + */ + +import { describe, expect, it } from "bun:test"; +import { + CreateTaskResultSchema, + EmptyResultSchema, + RelatedTaskMetadataSchema, + type Task, + TaskGetResultSchema, + TaskListResultSchema, + TaskMetadataSchema, + TaskSchema, + type TaskStatus, + TaskStatusNotificationSchema, + TaskStatusSchema, +} from "./task"; + +describe("TaskStatusSchema", () => { + it("parses 'working' status", () => { + expect(TaskStatusSchema.parse("working")).toBe("working"); + }); + + it("parses 'input_required' status", () => { + expect(TaskStatusSchema.parse("input_required")).toBe("input_required"); + }); + + it("parses 'completed' status", () => { + expect(TaskStatusSchema.parse("completed")).toBe("completed"); + }); + + it("parses 'failed' status", () => { + expect(TaskStatusSchema.parse("failed")).toBe("failed"); + }); + + it("parses 'cancelled' status", () => { + expect(TaskStatusSchema.parse("cancelled")).toBe("cancelled"); + }); + + it("rejects invalid status string", () => { + expect(() => TaskStatusSchema.parse("pending")).toThrow(); + }); + + it("rejects non-string values", () => { + expect(() => TaskStatusSchema.parse(123)).toThrow(); + expect(() => TaskStatusSchema.parse(null)).toThrow(); + }); +}); + +describe("TaskMetadataSchema", () => { + it("parses empty object (all optional)", () => { + const result = TaskMetadataSchema.parse({}); + expect(result.ttl).toBeUndefined(); + }); + + it("parses ttl as number", () => { + const result = TaskMetadataSchema.parse({ ttl: 300000 }); + expect(result.ttl).toBe(300000); + }); + + it("rejects non-number ttl", () => { + expect(() => TaskMetadataSchema.parse({ ttl: "300000" })).toThrow(); + }); +}); + +describe("RelatedTaskMetadataSchema", () => { + it("parses taskId", () => { + const result = RelatedTaskMetadataSchema.parse({ taskId: "task-123" }); + expect(result.taskId).toBe("task-123"); + }); + + it("requires taskId", () => { + expect(() => 
RelatedTaskMetadataSchema.parse({})).toThrow(); + }); +}); + +describe("TaskSchema", () => { + const validTask: Task = { + taskId: "task-abc-123", + status: "working", + createdAt: "2026-01-16T10:00:00.000Z", + lastUpdatedAt: "2026-01-16T10:00:05.000Z", + ttl: null, + }; + + it("parses minimal valid task", () => { + const result = TaskSchema.parse(validTask); + expect(result.taskId).toBe("task-abc-123"); + expect(result.status).toBe("working"); + expect(result.ttl).toBeNull(); + }); + + it("parses task with statusMessage", () => { + const task = { ...validTask, statusMessage: "Processing file..." }; + const result = TaskSchema.parse(task); + expect(result.statusMessage).toBe("Processing file..."); + }); + + it("parses task with pollInterval", () => { + const task = { ...validTask, pollInterval: 5000 }; + const result = TaskSchema.parse(task); + expect(result.pollInterval).toBe(5000); + }); + + it("parses task with numeric ttl", () => { + const task = { ...validTask, ttl: 600000 }; + const result = TaskSchema.parse(task); + expect(result.ttl).toBe(600000); + }); + + it("allows null ttl for unlimited retention", () => { + const result = TaskSchema.parse(validTask); + expect(result.ttl).toBeNull(); + }); + + it("validates ISO 8601 datetime for createdAt", () => { + const badTask = { ...validTask, createdAt: "not-a-date" }; + expect(() => TaskSchema.parse(badTask)).toThrow(); + }); + + it("validates ISO 8601 datetime for lastUpdatedAt", () => { + const badTask = { ...validTask, lastUpdatedAt: "2026/01/16" }; + expect(() => TaskSchema.parse(badTask)).toThrow(); + }); + + it("requires all mandatory fields", () => { + expect(() => TaskSchema.parse({ taskId: "test" })).toThrow(); + expect(() => TaskSchema.parse({ status: "working" })).toThrow(); + }); + + it("rejects invalid status", () => { + const badTask = { ...validTask, status: "running" }; + expect(() => TaskSchema.parse(badTask)).toThrow(); + }); + + it("parses complete task with all fields", () => { + const 
completeTask: Task = { + taskId: "task-full", + status: "completed", + statusMessage: "Done processing", + createdAt: "2026-01-16T10:00:00.000Z", + lastUpdatedAt: "2026-01-16T10:05:00.000Z", + ttl: 3600000, + pollInterval: 2000, + }; + const result = TaskSchema.parse(completeTask); + expect(result.taskId).toBe("task-full"); + expect(result.status).toBe("completed"); + expect(result.statusMessage).toBe("Done processing"); + expect(result.ttl).toBe(3600000); + expect(result.pollInterval).toBe(2000); + }); +}); + +describe("CreateTaskResultSchema", () => { + const validTask: Task = { + taskId: "task-new", + status: "working", + createdAt: "2026-01-16T10:00:00.000Z", + lastUpdatedAt: "2026-01-16T10:00:00.000Z", + ttl: null, + }; + + it("parses result with task only", () => { + const result = CreateTaskResultSchema.parse({ task: validTask }); + expect(result.task.taskId).toBe("task-new"); + expect(result._meta).toBeUndefined(); + }); + + it("parses result with _meta", () => { + const result = CreateTaskResultSchema.parse({ + task: validTask, + _meta: { custom: "data", version: 1 }, + }); + expect(result._meta?.custom).toBe("data"); + }); + + it("requires task field", () => { + expect(() => CreateTaskResultSchema.parse({})).toThrow(); + expect(() => CreateTaskResultSchema.parse({ _meta: {} })).toThrow(); + }); +}); + +describe("TaskListResultSchema", () => { + const task1: Task = { + taskId: "task-1", + status: "working", + createdAt: "2026-01-16T10:00:00.000Z", + lastUpdatedAt: "2026-01-16T10:00:00.000Z", + ttl: null, + }; + + const task2: Task = { + taskId: "task-2", + status: "completed", + createdAt: "2026-01-16T09:00:00.000Z", + lastUpdatedAt: "2026-01-16T09:30:00.000Z", + ttl: 3600000, + }; + + it("parses empty task list", () => { + const result = TaskListResultSchema.parse({ tasks: [] }); + expect(result.tasks).toEqual([]); + expect(result.nextCursor).toBeUndefined(); + }); + + it("parses task list with items", () => { + const result = TaskListResultSchema.parse({ 
tasks: [task1, task2] }); + expect(result.tasks).toHaveLength(2); + expect(result.tasks[0]?.taskId).toBe("task-1"); + expect(result.tasks[1]?.taskId).toBe("task-2"); + }); + + it("parses task list with pagination cursor", () => { + const result = TaskListResultSchema.parse({ + tasks: [task1], + nextCursor: "cursor-abc", + }); + expect(result.nextCursor).toBe("cursor-abc"); + }); + + it("requires tasks array", () => { + expect(() => TaskListResultSchema.parse({})).toThrow(); + expect(() => TaskListResultSchema.parse({ nextCursor: "abc" })).toThrow(); + }); +}); + +describe("TaskGetResultSchema", () => { + it("is equivalent to TaskSchema", () => { + const task: Task = { + taskId: "task-get", + status: "failed", + statusMessage: "Server error", + createdAt: "2026-01-16T10:00:00.000Z", + lastUpdatedAt: "2026-01-16T10:01:00.000Z", + ttl: null, + }; + const result = TaskGetResultSchema.parse(task); + expect(result.taskId).toBe("task-get"); + expect(result.status).toBe("failed"); + }); +}); + +describe("EmptyResultSchema", () => { + it("parses empty object", () => { + const result = EmptyResultSchema.parse({}); + expect(result).toEqual({}); + }); + + it("strips extra fields", () => { + const result = EmptyResultSchema.parse({ extra: "field" }); + // Zod rejects unknown keys in strict mode and preserves them in passthrough + // Default behavior: strips + expect((result as any).extra).toBeUndefined(); + }); +}); + +describe("Edge Cases", () => { + it("TaskStatus type inference is correct", () => { + const status: TaskStatus = "working"; + expect([ + "working", + "input_required", + "completed", + "failed", + "cancelled", + ]).toContain(status); + }); + + it("Task with all terminal statuses validates", () => { + const baseTask = { + taskId: "test", + createdAt: "2026-01-16T10:00:00.000Z", + lastUpdatedAt: "2026-01-16T10:00:00.000Z", + ttl: null, + }; + + expect(TaskSchema.parse({ ...baseTask, status: "completed" }).status).toBe( + "completed", + ); + expect(TaskSchema.parse({
...baseTask, status: "failed" }).status).toBe( + "failed", + ); + expect(TaskSchema.parse({ ...baseTask, status: "cancelled" }).status).toBe( + "cancelled", + ); + }); + + it("safeParse returns success false for invalid data", () => { + const result = TaskSchema.safeParse({ taskId: "test" }); + expect(result.success).toBe(false); + }); + + it("safeParse returns success true for valid data", () => { + const result = TaskSchema.safeParse({ + taskId: "test", + status: "working", + createdAt: "2026-01-16T10:00:00.000Z", + lastUpdatedAt: "2026-01-16T10:00:00.000Z", + ttl: null, + }); + expect(result.success).toBe(true); + }); +}); + +describe("TaskStatusNotificationSchema", () => { + it("parses valid notification with method and params", () => { + const notification = { + method: "notifications/tasks/status", + params: { + taskId: "task-123", + status: "completed", + createdAt: "2026-01-16T10:00:00.000Z", + lastUpdatedAt: "2026-01-16T10:05:00.000Z", + ttl: null, + }, + }; + const result = TaskStatusNotificationSchema.parse(notification); + expect(result.method).toBe("notifications/tasks/status"); + expect(result.params.taskId).toBe("task-123"); + expect(result.params.status).toBe("completed"); + }); + + it("rejects notification with wrong method", () => { + const notification = { + method: "notifications/progress", + params: { + taskId: "task-123", + status: "working", + createdAt: "2026-01-16T10:00:00.000Z", + lastUpdatedAt: "2026-01-16T10:00:00.000Z", + ttl: null, + }, + }; + expect(() => TaskStatusNotificationSchema.parse(notification)).toThrow(); + }); + + it("requires params to be valid Task", () => { + const notification = { + method: "notifications/tasks/status", + params: { taskId: "incomplete" }, // Missing required fields + }; + expect(() => TaskStatusNotificationSchema.parse(notification)).toThrow(); + }); +}); diff --git a/packages/mcp/src/types/task.ts b/packages/mcp/src/types/task.ts new file mode 100644 index 0000000..a52281f --- /dev/null +++ 
b/packages/mcp/src/types/task.ts @@ -0,0 +1,174 @@ +/** + * Task Types + * + * Zod schemas and TypeScript types for Task-Augmented Execution. + * Following MCP spec: https://modelcontextprotocol.io/specification/2025-11-25/basic/utilities/tasks + */ + +import { z } from "zod"; + +// ============================================================================= +// Task Status +// ============================================================================= + +/** + * Current state of a task. + */ +export const TaskStatusSchema = z.enum([ + "working", // Request currently being processed + "input_required", // Waiting for elicitation or sampling input + "completed", // Request completed successfully + "failed", // Request did not complete successfully + "cancelled", // Request was cancelled +]); + +export type TaskStatus = z.infer<typeof TaskStatusSchema>; + +// ============================================================================= +// Task Metadata (for request params) +// ============================================================================= + +/** + * Metadata to include in task-augmented requests. + */ +export const TaskMetadataSchema = z.object({ + /** + * Requested duration in milliseconds to retain task from creation. + */ + ttl: z.number().optional(), +}); + +export type TaskMetadata = z.infer<typeof TaskMetadataSchema>; + +// ============================================================================= +// Related Task Metadata (for _meta field) +// ============================================================================= + +/** + * Metadata linking a message to a task. + */ +export const RelatedTaskMetadataSchema = z.object({ + taskId: z.string(), +}); + +export type RelatedTaskMetadata = z.infer<typeof RelatedTaskMetadataSchema>; + +// ============================================================================= +// Task (from tasks/list or tasks/get response) +// ============================================================================= + +/** + * Task object representing a long-running operation.
+ */ +export const TaskSchema = z.object({ + /** + * The task identifier (receiver-generated). + */ + taskId: z.string(), + + /** + * Current task state. + */ + status: TaskStatusSchema, + + /** + * Optional human-readable message describing current state. + */ + statusMessage: z.string().optional(), + + /** + * ISO 8601 timestamp when task was created. + */ + createdAt: z.string().datetime(), + + /** + * ISO 8601 timestamp when task was last updated. + */ + lastUpdatedAt: z.string().datetime(), + + /** + * Actual retention duration in milliseconds, null for unlimited. + */ + ttl: z.number().nullable(), + + /** + * Suggested polling interval in milliseconds. + */ + pollInterval: z.number().optional(), +}); + +export type Task = z.infer<typeof TaskSchema>; + +// ============================================================================= +// CreateTaskResult (initial response to task-augmented request) +// ============================================================================= + +/** + * Result returned when a task-augmented call creates a task. + */ +export const CreateTaskResultSchema = z.object({ + /** + * The created task (returned instead of CallToolResult). + */ + task: TaskSchema, + + /** + * Optional metadata. + */ + _meta: z.record(z.string(), z.unknown()).optional(), +}); + +export type CreateTaskResult = z.infer<typeof CreateTaskResultSchema>; + +// ============================================================================= +// TaskListResult (paginated list response) +// ============================================================================= + +/** + * Result from tasks/list request.
+ */ +export const TaskListResultSchema = z.object({ + tasks: z.array(TaskSchema), + nextCursor: z.string().optional(), +}); + +export type TaskListResult = z.infer<typeof TaskListResultSchema>; + +// ============================================================================= +// TaskGetResult (single task response) +// ============================================================================= + +/** + * Result from tasks/get request (same as Task). + */ +export const TaskGetResultSchema = TaskSchema; + +export type TaskGetResult = z.infer<typeof TaskGetResultSchema>; + +// ============================================================================= +// EmptyResult (for cancel response) +// ============================================================================= + +/** + * Empty result from tasks/cancel. + */ +export const EmptyResultSchema = z.object({}); + +export type EmptyResult = z.infer<typeof EmptyResultSchema>; + +// ============================================================================= +// Task Status Notification (for setNotificationHandler) +// ============================================================================= + +/** + * MCP SDK-compatible notification schema with method field. + * Used for setNotificationHandler to register task status notification handler. + */ +export const TaskStatusNotificationSchema = z.object({ + method: z.literal("notifications/tasks/status"), + params: TaskSchema, +}); + +export type TaskStatusNotification = z.infer< + typeof TaskStatusNotificationSchema +>; diff --git a/packages/mcp/src/types/tool-annotations.test.ts b/packages/mcp/src/types/tool-annotations.test.ts new file mode 100644 index 0000000..3409c97 --- /dev/null +++ b/packages/mcp/src/types/tool-annotations.test.ts @@ -0,0 +1,448 @@ +/** + * Tool Annotations Schema Tests + * + * Unit tests for ToolAnnotationsSchema, ToolSchema, and helper functions.
+ * Task 06: Tool Annotations - Phase 1 Schema Tests + */ + +import { describe, expect, it } from "bun:test"; +import { + applyAnnotationDefaults, + getToolDisplayName, + IconSchema, + type Tool, + type ToolAnnotations, + ToolAnnotationsSchema, + ToolExecutionSchema, + ToolSchema, +} from "./tool-annotations"; + +describe("ToolAnnotationsSchema", () => { + describe("field parsing", () => { + it("parses title annotation", () => { + const annotations = { title: "My Tool Title" }; + const parsed = ToolAnnotationsSchema.parse(annotations); + expect(parsed.title).toBe("My Tool Title"); + }); + + it("parses readOnlyHint with default false", () => { + const parsed = ToolAnnotationsSchema.parse({}); + expect(parsed.readOnlyHint).toBe(false); + }); + + it("parses readOnlyHint when explicitly true", () => { + const parsed = ToolAnnotationsSchema.parse({ readOnlyHint: true }); + expect(parsed.readOnlyHint).toBe(true); + }); + + it("parses destructiveHint with default true", () => { + const parsed = ToolAnnotationsSchema.parse({}); + expect(parsed.destructiveHint).toBe(true); + }); + + it("parses destructiveHint when explicitly false", () => { + const parsed = ToolAnnotationsSchema.parse({ destructiveHint: false }); + expect(parsed.destructiveHint).toBe(false); + }); + + it("parses idempotentHint with default false", () => { + const parsed = ToolAnnotationsSchema.parse({}); + expect(parsed.idempotentHint).toBe(false); + }); + + it("parses idempotentHint when explicitly true", () => { + const parsed = ToolAnnotationsSchema.parse({ idempotentHint: true }); + expect(parsed.idempotentHint).toBe(true); + }); + + it("parses openWorldHint with default true", () => { + const parsed = ToolAnnotationsSchema.parse({}); + expect(parsed.openWorldHint).toBe(true); + }); + + it("parses openWorldHint when explicitly false", () => { + const parsed = ToolAnnotationsSchema.parse({ openWorldHint: false }); + expect(parsed.openWorldHint).toBe(false); + }); + }); + + describe("partial and empty 
annotations", () => { + it("handles empty annotations with all defaults", () => { + const parsed = ToolAnnotationsSchema.parse({}); + expect(parsed).toEqual({ + readOnlyHint: false, + destructiveHint: true, + idempotentHint: false, + openWorldHint: true, + }); + }); + + it("handles partial annotations - only title", () => { + const parsed = ToolAnnotationsSchema.parse({ title: "Just a title" }); + expect(parsed.title).toBe("Just a title"); + expect(parsed.readOnlyHint).toBe(false); + expect(parsed.destructiveHint).toBe(true); + }); + + it("handles partial annotations - only boolean hints", () => { + const parsed = ToolAnnotationsSchema.parse({ + readOnlyHint: true, + idempotentHint: true, + }); + expect(parsed.title).toBeUndefined(); + expect(parsed.readOnlyHint).toBe(true); + expect(parsed.destructiveHint).toBe(true); // default + expect(parsed.idempotentHint).toBe(true); + expect(parsed.openWorldHint).toBe(true); // default + }); + }); + + describe("invalid type rejection", () => { + it("rejects non-boolean readOnlyHint", () => { + expect(() => + ToolAnnotationsSchema.parse({ readOnlyHint: "yes" }), + ).toThrow(); + }); + + it("rejects non-boolean destructiveHint", () => { + expect(() => + ToolAnnotationsSchema.parse({ destructiveHint: 1 }), + ).toThrow(); + }); + + it("rejects non-string title", () => { + expect(() => ToolAnnotationsSchema.parse({ title: 123 })).toThrow(); + }); + + it("rejects non-boolean idempotentHint", () => { + expect(() => + ToolAnnotationsSchema.parse({ idempotentHint: null }), + ).toThrow(); + }); + + it("rejects non-boolean openWorldHint", () => { + expect(() => + ToolAnnotationsSchema.parse({ openWorldHint: {} }), + ).toThrow(); + }); + }); +}); + +describe("applyAnnotationDefaults", () => { + it("applies defaults to undefined", () => { + const result = applyAnnotationDefaults(undefined); + expect(result).toEqual({ + readOnlyHint: false, + destructiveHint: true, + idempotentHint: false, + openWorldHint: true, + }); + }); + + 
it("applies defaults to empty object", () => { + const result = applyAnnotationDefaults({}); + expect(result).toEqual({ + readOnlyHint: false, + destructiveHint: true, + idempotentHint: false, + openWorldHint: true, + }); + }); + + it("preserves provided values", () => { + const result = applyAnnotationDefaults({ + title: "Custom Title", + readOnlyHint: true, + }); + expect(result.title).toBe("Custom Title"); + expect(result.readOnlyHint).toBe(true); + expect(result.destructiveHint).toBe(true); // default + }); + + it("preserves all explicit values", () => { + const input: Partial<ToolAnnotations> = { + title: "Full Override", + readOnlyHint: true, + destructiveHint: false, + idempotentHint: true, + openWorldHint: false, + }; + const result = applyAnnotationDefaults(input); + expect(result).toEqual(input as ToolAnnotations); + }); +}); + +describe("getToolDisplayName", () => { + it("returns annotations.title when present", () => { + const tool = { + name: "my_tool", + annotations: { title: "My Tool Display Name" } as ToolAnnotations, + }; + expect(getToolDisplayName(tool)).toBe("My Tool Display Name"); + }); + + it("falls back to name when title is undefined", () => { + const tool = { + name: "fallback_tool", + annotations: {} as ToolAnnotations, + }; + expect(getToolDisplayName(tool)).toBe("fallback_tool"); + }); + + it("falls back to name when annotations is undefined", () => { + const tool = { name: "no_annotations_tool" }; + expect(getToolDisplayName(tool)).toBe("no_annotations_tool"); + }); + + it("prefers title over name", () => { + const tool = { + name: "internal_name", + annotations: { + title: "User-Friendly Title", + readOnlyHint: false, + destructiveHint: true, + idempotentHint: false, + openWorldHint: true, + }, + }; + expect(getToolDisplayName(tool)).toBe("User-Friendly Title"); + }); +}); + +describe("ToolExecutionSchema", () => { + it("validates taskSupport: forbidden", () => { + const parsed = ToolExecutionSchema.parse({ taskSupport: "forbidden" }); +
expect(parsed.taskSupport).toBe("forbidden"); + }); + + it("validates taskSupport: optional", () => { + const parsed = ToolExecutionSchema.parse({ taskSupport: "optional" }); + expect(parsed.taskSupport).toBe("optional"); + }); + + it("validates taskSupport: required", () => { + const parsed = ToolExecutionSchema.parse({ taskSupport: "required" }); + expect(parsed.taskSupport).toBe("required"); + }); + + it("allows empty object (all optional)", () => { + const parsed = ToolExecutionSchema.parse({}); + expect(parsed.taskSupport).toBeUndefined(); + }); + + it("rejects invalid taskSupport value", () => { + expect(() => + ToolExecutionSchema.parse({ taskSupport: "always" }), + ).toThrow(); + }); +}); + +describe("IconSchema", () => { + it("validates icon with src only", () => { + const parsed = IconSchema.parse({ src: "https://example.com/icon.png" }); + expect(parsed.src).toBe("https://example.com/icon.png"); + expect(parsed.mimeType).toBeUndefined(); + expect(parsed.sizes).toBeUndefined(); + }); + + it("validates icon with all fields", () => { + const icon = { + src: "data:image/png;base64,abc123", + mimeType: "image/png", + sizes: ["48x48", "96x96"], + }; + const parsed = IconSchema.parse(icon); + expect(parsed).toEqual(icon); + }); + + it("rejects missing src", () => { + expect(() => IconSchema.parse({ mimeType: "image/png" })).toThrow(); + }); + + it("rejects non-string src", () => { + expect(() => IconSchema.parse({ src: 123 })).toThrow(); + }); +}); + +describe("ToolSchema", () => { + const minimalTool = { + name: "test_tool", + inputSchema: { type: "object" as const }, + }; + + it("validates minimal tool definition", () => { + const parsed = ToolSchema.parse(minimalTool); + expect(parsed.name).toBe("test_tool"); + expect(parsed.inputSchema.type).toBe("object"); + }); + + it("validates tool with description", () => { + const tool = { ...minimalTool, description: "A test tool" }; + const parsed = ToolSchema.parse(tool); + expect(parsed.description).toBe("A test 
tool"); + }); + + it("validates tool with annotations", () => { + const tool = { + ...minimalTool, + annotations: { + title: "Test Tool", + readOnlyHint: true, + }, + }; + const parsed = ToolSchema.parse(tool); + expect(parsed.annotations?.title).toBe("Test Tool"); + expect(parsed.annotations?.readOnlyHint).toBe(true); + }); + + it("validates tool with outputSchema", () => { + const tool = { + ...minimalTool, + outputSchema: { + type: "object" as const, + properties: { result: { type: "string" } }, + }, + }; + const parsed = ToolSchema.parse(tool); + expect(parsed.outputSchema?.type).toBe("object"); + }); + + it("validates tool with execution config", () => { + const tool = { + ...minimalTool, + execution: { taskSupport: "optional" as const }, + }; + const parsed = ToolSchema.parse(tool); + expect(parsed.execution?.taskSupport).toBe("optional"); + }); + + it("validates tool with icons", () => { + const tool = { + ...minimalTool, + icons: [ + { src: "https://example.com/icon.svg", mimeType: "image/svg+xml" }, + ], + }; + const parsed = ToolSchema.parse(tool); + expect(parsed.icons).toHaveLength(1); + expect(parsed.icons?.[0]?.src).toBe("https://example.com/icon.svg"); + }); + + it("validates tool with _meta", () => { + const tool = { + ...minimalTool, + _meta: { version: "1.0", author: "test" }, + }; + const parsed = ToolSchema.parse(tool); + expect(parsed._meta?.version).toBe("1.0"); + }); + + it("validates complete tool with all fields", () => { + const completeTool: Tool = { + name: "complete_tool", + description: "A fully specified tool", + inputSchema: { + type: "object", + properties: { input: { type: "string" } }, + required: ["input"], + }, + outputSchema: { + type: "object", + properties: { output: { type: "number" } }, + }, + annotations: { + title: "Complete Tool", + readOnlyHint: true, + destructiveHint: false, + idempotentHint: true, + openWorldHint: false, + }, + execution: { taskSupport: "required" }, + icons: [{ src: "/icon.png" }], + _meta: { 
custom: "data" }, + }; + const parsed = ToolSchema.parse(completeTool); + expect(parsed.name).toBe("complete_tool"); + expect(parsed.annotations?.title).toBe("Complete Tool"); + }); + + it("rejects tool without name", () => { + expect(() => + ToolSchema.parse({ inputSchema: { type: "object" } }), + ).toThrow(); + }); + + it("rejects tool without inputSchema", () => { + expect(() => ToolSchema.parse({ name: "no_schema" })).toThrow(); + }); + + it("rejects tool with invalid inputSchema type", () => { + expect(() => + ToolSchema.parse({ name: "bad_schema", inputSchema: { type: "array" } }), + ).toThrow(); + }); +}); + +describe("Phase 3: Edge Cases & Validation", () => { + it("strips unknown annotation fields", () => { + const result = ToolAnnotationsSchema.parse({ + title: "Test", + unknownField: "should be stripped", + }); + // biome-ignore lint/suspicious/noExplicitAny: testing stripper + expect((result as any).unknownField).toBeUndefined(); + expect(result.title).toBe("Test"); + }); + + it("safeParse handles invalid types gracefully", () => { + const result = ToolAnnotationsSchema.safeParse({ + readOnlyHint: "not a boolean", + }); + expect(result.success).toBe(false); + if (!result.success) { + expect(result.error?.issues[0]?.code).toBe("invalid_type"); + expect(result.error?.issues[0]?.path).toContain("readOnlyHint"); + } + }); + + it("handles null vs undefined gracefully", () => { + // undefined -> uses default + const res1 = ToolAnnotationsSchema.parse({ + readOnlyHint: undefined, + }); + expect(res1.readOnlyHint).toBe(false); + + // null -> invalid type (Zod default behavior for boolean is strict) + const res2 = ToolAnnotationsSchema.safeParse({ + readOnlyHint: null, + }); + expect(res2.success).toBe(false); + }); + + it("validates complex real-world annotations combination", () => { + const complex = { + title: "Production Tool", + readOnlyHint: true, + destructiveHint: false, // Explicit override + idempotentHint: true, + openWorldHint: false, + 
extraMetadata: { + source: "registry", + verified: true, + }, + }; + + const result = ToolAnnotationsSchema.parse(complex); + + expect(result).toEqual({ + title: "Production Tool", + readOnlyHint: true, + destructiveHint: false, + idempotentHint: true, + openWorldHint: false, + }); + // Verify extra fields are stripped + // biome-ignore lint/suspicious/noExplicitAny: testing stripper + expect((result as any).extraMetadata).toBeUndefined(); + }); +}); diff --git a/packages/mcp/src/types/tool-annotations.ts b/packages/mcp/src/types/tool-annotations.ts new file mode 100644 index 0000000..016ba36 --- /dev/null +++ b/packages/mcp/src/types/tool-annotations.ts @@ -0,0 +1,178 @@ +/** + * Tool Annotations Types + * + * Zod schemas and TypeScript types for Tool Annotations. + * Following MCP spec: https://spec.modelcontextprotocol.io/specification/2025-11-05/server/tools/ + * + * NOTE: This is different from ContentAnnotationsSchema (audience/priority) in content.ts. + * ToolAnnotations are behavioral hints: readOnly, destructive, idempotent, openWorld. + */ + +import { z } from "zod"; + +// ============================================================================= +// Tool Annotations Schema (behavioral hints) +// ============================================================================= + +/** + * Tool annotations provide behavioral hints about tools. + * All properties are optional hints - not guaranteed to be accurate. + * + * @see https://spec.modelcontextprotocol.io/specification/2025-11-05/server/tools/#tool + */ +export const ToolAnnotationsSchema = z.object({ + /** + * A human-readable title for the tool. + */ + title: z.string().optional(), + + /** + * If true, the tool does not modify its environment. + * Default: false + */ + readOnlyHint: z.boolean().optional().default(false), + + /** + * If true, the tool may perform destructive updates to its environment. + * If false, the tool performs only additive updates. 
+ * (Meaningful only when readOnlyHint == false) + * Default: true + */ + destructiveHint: z.boolean().optional().default(true), + + /** + * If true, calling the tool repeatedly with the same arguments + * will have no additional effect on its environment. + * (Meaningful only when readOnlyHint == false) + * Default: false + */ + idempotentHint: z.boolean().optional().default(false), + + /** + * If true, this tool may interact with an "open world" of external entities. + * If false, the tool's domain of interaction is closed. + * Default: true + */ + openWorldHint: z.boolean().optional().default(true), +}); + +export type ToolAnnotations = z.infer<typeof ToolAnnotationsSchema>; + +// ============================================================================= +// Tool Execution Schema (for Task 07 - Augmented Tool Execution) +// ============================================================================= + +/** + * Tool execution configuration. + * Stub for Task 07 implementation. + */ +export const ToolExecutionSchema = z.object({ + /** + * Whether this tool supports task-based execution. + * - 'forbidden': Tool cannot be run as a task + * - 'optional': Tool can optionally run as a task + * - 'required': Tool must run as a task + */ + taskSupport: z.enum(["forbidden", "optional", "required"]).optional(), +}); + +export type ToolExecution = z.infer<typeof ToolExecutionSchema>; + +// ============================================================================= +// Icon Schema (optional UI hints) +// ============================================================================= + +/** + * Icon for tool display in UIs.
+ */ +export const IconSchema = z.object({ + /** URL or data URI of the icon */ + src: z.string(), + /** MIME type of the icon (e.g., "image/png") */ + mimeType: z.string().optional(), + /** Available sizes (e.g., ["48x48", "96x96"]) */ + sizes: z.array(z.string()).optional(), +}); + +export type Icon = z.infer<typeof IconSchema>; + +// ============================================================================= +// Tool Schema (complete Tool interface from MCP SDK) +// ============================================================================= + +/** + * Complete Tool definition from MCP. + * Includes all properties from tools/list response. + */ +export const ToolSchema = z.object({ + /** Unique identifier for the tool */ + name: z.string(), + + /** Human-readable description of functionality */ + description: z.string().optional(), + + /** JSON Schema defining expected parameters */ + inputSchema: z.object({ + type: z.literal("object"), + properties: z.record(z.string(), z.unknown()).optional(), + required: z.array(z.string()).optional(), + }), + + /** Optional JSON Schema defining expected output structure */ + outputSchema: z + .object({ + type: z.literal("object"), + properties: z.record(z.string(), z.unknown()).optional(), + required: z.array(z.string()).optional(), + }) + .optional(), + + /** Behavioral hints for the tool */ + annotations: ToolAnnotationsSchema.optional(), + + /** Execution configuration (Task 07) */ + execution: ToolExecutionSchema.optional(), + + /** Icons for UI display */ + icons: z.array(IconSchema).optional(), + + /** Additional metadata */ + _meta: z.record(z.string(), z.unknown()).optional(), +}); + +export type Tool = z.infer<typeof ToolSchema>; + +// ============================================================================= +// Helper: Apply annotation defaults +// ============================================================================= + +/** + * Apply spec-defined defaults to tool annotations.
+ * Ensures all hint fields are present with appropriate values. + * + * @param annotations - Partial annotations from server (may be undefined) + * @returns Complete ToolAnnotations with defaults applied + */ +export function applyAnnotationDefaults( + annotations?: Partial<ToolAnnotations>, +): ToolAnnotations { + return ToolAnnotationsSchema.parse(annotations ?? {}); +} + +// ============================================================================= +// Helper: Get display name +// ============================================================================= + +/** + * Get the display name for a tool following MCP precedence rules. + * Precedence: annotations.title > name + * + * @param tool - Tool object with name and optional annotations + * @returns The best display name for the tool + */ +export function getToolDisplayName(tool: { + name: string; + annotations?: Partial<ToolAnnotations>; +}): string { + return tool.annotations?.title ?? tool.name; +} diff --git a/packages/mcp/src/types/tool.test.ts b/packages/mcp/src/types/tool.test.ts index c1994e1..dba110b 100644 --- a/packages/mcp/src/types/tool.test.ts +++ b/packages/mcp/src/types/tool.test.ts @@ -1,10 +1,4 @@ import { describe, expect, it } from "bun:test"; -import { - ToolCallRequestSchema, - ToolCallResultSchema, - ToolContentSchema, - ToolOperationSchema, -} from "./tool"; import { AnnotationsSchema, AudioContentSchema, @@ -13,6 +7,12 @@ import { ResourceLinkContentSchema, TextContentSchema, } from "./content"; +import { + ToolCallRequestSchema, + ToolCallResultSchema, + ToolContentSchema, + ToolOperationSchema, +} from "./tool"; describe("Tool Types Schemas", () => { describe("ToolCallRequestSchema", () => { diff --git a/packages/mcp/src/types/tool.ts b/packages/mcp/src/types/tool.ts index f0f813f..2f401ce 100644 --- a/packages/mcp/src/types/tool.ts +++ b/packages/mcp/src/types/tool.ts @@ -12,6 +12,7 @@ import { z } from "zod"; // ============================================================================= import {
ToolContentSchema } from "./content"; + export * from "./content"; // ============================================================================= @@ -26,9 +27,9 @@ export * from "./content"; * Request to call a tool. */ export const ToolCallRequestSchema = z.object({ - name: z.string(), - arguments: z.record(z.string(), z.unknown()).optional(), - // _meta is used for progressToken, handled separately + name: z.string(), + arguments: z.record(z.string(), z.unknown()).optional(), + // _meta is used for progressToken, handled separately }); export type ToolCallRequest = z.infer<typeof ToolCallRequestSchema>; @@ -37,9 +38,9 @@ export type ToolCallRequest = z.infer<typeof ToolCallRequestSchema>; * Result returned from a tool call. */ export const ToolCallResultSchema = z.object({ - content: z.array(ToolContentSchema), - isError: z.boolean().optional(), - structuredContent: z.unknown().optional(), + content: z.array(ToolContentSchema), + isError: z.boolean().optional(), + structuredContent: z.unknown().optional(), }); export type ToolCallResult = z.infer<typeof ToolCallResultSchema>; @@ -52,22 +53,22 @@ export type ToolCallResult = z.infer<typeof ToolCallResultSchema>; * Status of a tool operation. */ export const ToolOperationStatus = { - PENDING: "pending", - COMPLETED: "completed", - ERROR: "error", - CANCELLED: "cancelled", + PENDING: "pending", + COMPLETED: "completed", + ERROR: "error", + CANCELLED: "cancelled", } as const; export type ToolOperationStatus = - (typeof ToolOperationStatus)[keyof typeof ToolOperationStatus]; + (typeof ToolOperationStatus)[keyof typeof ToolOperationStatus]; /** * JSON-RPC error structure. */ export const JsonRpcErrorSchema = z.object({ - code: z.number(), - message: z.string(), - data: z.unknown().optional(), + code: z.number(), + message: z.string(), + data: z.unknown().optional(), }); export type JsonRpcError = z.infer<typeof JsonRpcErrorSchema>; @@ -76,30 +77,30 @@ export type JsonRpcError = z.infer<typeof JsonRpcErrorSchema>; * A tool operation tracks the lifecycle of a single tools/call request.
*/ export const ToolOperationSchema = z.object({ - id: z.string().uuid(), - sessionId: z.string().uuid(), - requestId: z.string(), // JSON-RPC id for correlation - request: ToolCallRequestSchema, - status: z.enum(["pending", "completed", "error", "cancelled"]), - result: ToolCallResultSchema.optional(), - error: JsonRpcErrorSchema.optional(), - startedAt: z.date(), - completedAt: z.date().optional(), - // Progress tracking - progressToken: z.union([z.string(), z.number()]).optional(), - progressUpdates: z - .array( - z.object({ - progress: z.number(), - total: z.number().optional(), - message: z.string().optional(), - timestamp: z.date(), - }), - ) - .default([]), - // Cancellation - cancelRequested: z.boolean().default(false), - cancelReason: z.string().optional(), + id: z.string().uuid(), + sessionId: z.string().uuid(), + requestId: z.string(), // JSON-RPC id for correlation + request: ToolCallRequestSchema, + status: z.enum(["pending", "completed", "error", "cancelled"]), + result: ToolCallResultSchema.optional(), + error: JsonRpcErrorSchema.optional(), + startedAt: z.date(), + completedAt: z.date().optional(), + // Progress tracking + progressToken: z.union([z.string(), z.number()]).optional(), + progressUpdates: z + .array( + z.object({ + progress: z.number(), + total: z.number().optional(), + message: z.string().optional(), + timestamp: z.date(), + }), + ) + .default([]), + // Cancellation + cancelRequested: z.boolean().default(false), + cancelReason: z.string().optional(), }); export type ToolOperation = z.infer<typeof ToolOperationSchema>; @@ -112,10 +113,10 @@ export type ToolOperation = z.infer<typeof ToolOperationSchema>; * Options for calling a tool. */ export interface CallToolOptions { - /** Timeout in milliseconds. 0 = no timeout. */ - timeout?: number; - /** Whether to include a progress token for progress tracking. */ - includeProgress?: boolean; - /** JSON Schema for validating structuredContent. */ - outputSchema?: Record<string, unknown>; + /** Timeout in milliseconds. 0 = no timeout.
*/ + timeout?: number; + /** Whether to include a progress token for progress tracking. */ + includeProgress?: boolean; + /** JSON Schema for validating structuredContent. */ + outputSchema?: Record<string, unknown>; } diff --git a/packages/mcp/test/cancellation.test.ts b/packages/mcp/test/cancellation.test.ts index c1d2c3f..0b6c6c1 100644 --- a/packages/mcp/test/cancellation.test.ts +++ b/packages/mcp/test/cancellation.test.ts @@ -1,17 +1,17 @@ import { afterEach, beforeEach, describe, expect, test } from "bun:test"; import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { - createPipeline, - createStateMachineMiddleware, - LATEST_PROTOCOL_VERSION, - SessionManager, + createPipeline, + createStateMachineMiddleware, + LATEST_PROTOCOL_VERSION, + SessionManager, } from "@say2/core"; import { McpClientManager } from "../src/client/manager"; import { McpClientRegistry } from "../src/client/registry"; import { LoggingTransport } from "../src/transport"; import { - createMockServerTransport, - type MockServerTransport, + createMockServerTransport, + type MockServerTransport, } from "./fixtures/mock-server"; import { scenarioMockConfig } from "./fixtures/tool-scenarios"; @@ -25,292 +25,292 @@ import { scenarioMockConfig } from "./fixtures/tool-scenarios"; * 4.
Responses after cancellation are ignored */ describe("Cancellation Integration", () => { - let sessionManager: SessionManager; - let pipeline: ReturnType; - let registry: McpClientRegistry; - let clientManager: McpClientManager; - let mockTransport: MockServerTransport; - let sessionId: string; - let client: Client; - - beforeEach(async () => { - sessionManager = new SessionManager(); - pipeline = createPipeline(); - - // Mock Protocol Detector - const mockDetector = { - isInitializeRequest: (msg: any) => - msg.method === "initialize" && "id" in msg, - isInitializeResponse: (msg: any) => - "result" in msg && "protocolVersion" in msg.result, - isInitializedNotification: (msg: any) => - msg.method === "notifications/initialized", - extractCapabilities: (msg: any) => msg.result?.capabilities, - extractServerInfo: (msg: any) => msg.result?.serverInfo, - }; - - pipeline.use( - (createStateMachineMiddleware as any)(sessionManager, mockDetector), - ); - - registry = new McpClientRegistry(); - clientManager = new McpClientManager(registry, sessionManager, pipeline); - - // Setup session - const session = sessionManager.create({ - name: "cancel-test-session", - transport: "stdio", - command: "node", - }); - sessionId = session.id; - - // Setup Transport - slowTool is configured with 5s delay for cancellation tests - mockTransport = createMockServerTransport(scenarioMockConfig); - client = new Client( - { name: "test-client", version: "1.0.0" }, - { capabilities: {} }, - ); - - const loggingTransport = new LoggingTransport( - mockTransport, - session, - pipeline, - ); - - // Initialize connection - await client.connect(loggingTransport); - registry.register(sessionId, client, loggingTransport); - - // Manually transition to ACTIVE - sessionManager.connect(sessionId); - sessionManager.initialize(sessionId); - sessionManager.activate(sessionId, {}, {}, LATEST_PROTOCOL_VERSION); - }); - - afterEach(async () => { - if (mockTransport && !mockTransport.isClosed) { - await 
mockTransport.close(); - } - }); - - test("cancelOperation() sends notifications/cancelled to server", async () => { - // Start a slow tool call that we'll cancel - const toolCallPromise = clientManager.callTool(sessionId, { - name: "slowTool", - arguments: {}, - }); - - // Wait a moment for the request to be sent - await new Promise((resolve) => setTimeout(resolve, 100)); - - // Get the operation ID from the pending operation - const ops = clientManager.getToolOperations(sessionId); - const pendingOp = ops.find((op) => op.status === "pending"); - - if (pendingOp) { - await clientManager.cancelOperation(pendingOp.id, "Test cancellation"); - } - - // Verify mock server received cancellation - const cancelledRequests = mockTransport.getCancelledRequests(); - expect(cancelledRequests.length).toBeGreaterThan(0); - - // Clean up by waiting for the tool call to complete or fail - try { - await toolCallPromise; - } catch { - // Expected if cancelled - } - }); - - test("cancellation notification includes requestId", async () => { - // Capture sent messages - let cancelNotification: any = null; - const originalSend = mockTransport.send.bind(mockTransport); - mockTransport.send = async (msg: any) => { - if ( - "method" in msg && - msg.method === "notifications/cancelled" && - !("id" in msg) - ) { - cancelNotification = msg; - } - return originalSend(msg); - }; - - // Start slow tool and cancel - const toolCallPromise = clientManager.callTool(sessionId, { - name: "slowTool", - arguments: {}, - }); - - await new Promise((resolve) => setTimeout(resolve, 100)); - - const ops = clientManager.getToolOperations(sessionId); - const pendingOp = ops.find((op) => op.status === "pending"); - - if (pendingOp) { - await clientManager.cancelOperation(pendingOp.id); - } - - expect(cancelNotification).not.toBeNull(); - expect(cancelNotification?.params?.requestId).toBeDefined(); - - try { - await toolCallPromise; - } catch { - // Expected - } - }); - - test("cancellation notification 
includes reason when provided", async () => { - let cancelNotification: any = null; - const originalSend = mockTransport.send.bind(mockTransport); - mockTransport.send = async (msg: any) => { - if ( - "method" in msg && - msg.method === "notifications/cancelled" && - !("id" in msg) - ) { - cancelNotification = msg; - } - return originalSend(msg); - }; - - const toolCallPromise = clientManager.callTool(sessionId, { - name: "slowTool", - arguments: {}, - }); - - await new Promise((resolve) => setTimeout(resolve, 100)); - - const ops = clientManager.getToolOperations(sessionId); - const pendingOp = ops.find((op) => op.status === "pending"); - - if (pendingOp) { - await clientManager.cancelOperation(pendingOp.id, "User clicked cancel"); - } - - expect(cancelNotification?.params?.reason).toBe("User clicked cancel"); - - try { - await toolCallPromise; - } catch { - // Expected - } - }); - - test("cancelled operation has status 'cancelled'", async () => { - const toolCallPromise = clientManager.callTool(sessionId, { - name: "slowTool", - arguments: {}, - }); - - await new Promise((resolve) => setTimeout(resolve, 100)); - - const ops = clientManager.getToolOperations(sessionId); - const pendingOp = ops.find((op) => op.status === "pending"); - - if (pendingOp) { - await clientManager.cancelOperation(pendingOp.id); - } - - // Wait for the call to resolve/reject - try { - await toolCallPromise; - } catch { - // Expected - } - - // Verify status is cancelled - if (pendingOp) { - const finalOp = clientManager.getToolOperation(pendingOp.id); - expect(finalOp?.status).toBe("cancelled"); - } - }); - - test("response after cancellation is ignored", async () => { - const toolCallPromise = clientManager.callTool(sessionId, { - name: "slowTool", - arguments: {}, - }); - - await new Promise((resolve) => setTimeout(resolve, 100)); - - const ops = clientManager.getToolOperations(sessionId); - const pendingOp = ops.find((op) => op.status === "pending"); - - if (pendingOp) { - await 
clientManager.cancelOperation(pendingOp.id); - } - - // The mock server should not send response for cancelled requests - // (see mock-server.ts lines 476-480 and 514-516) - - try { - await toolCallPromise; - } catch { - // Expected - } - - if (pendingOp) { - const finalOp = clientManager.getToolOperation(pendingOp.id); - // Status should remain cancelled, not completed - expect(finalOp?.status).toBe("cancelled"); - // Result should not be set - expect(finalOp?.result).toBeUndefined(); - } - }); - - test("timeout auto-cancels long-running operation", async () => { - // This test requires the callTool to support timeout option - const toolCallPromise = clientManager.callTool( - sessionId, - { name: "verySlowTool", arguments: {} }, - { timeout: 500 }, // 500ms timeout for testing - ); - - // Wait for timeout to trigger - try { - await toolCallPromise; - } catch { - // Expected to throw or return error status - } - - const ops = clientManager.getToolOperations(sessionId); - const op = ops[ops.length - 1]; - - // Should be either cancelled or error due to timeout - expect(op).toBeDefined(); - expect(["cancelled", "error"]).toContain(op!.status); - }); - - test("completed operation cannot be cancelled", async () => { - // Call a fast tool that will complete quickly - const result = await clientManager.callTool(sessionId, { - name: "echo", - arguments: { message: "quick" }, - }); - - expect(result.status).toBe("completed"); - - // Attempt to cancel completed operation - await clientManager.cancelOperation(result.id, "Too late"); - - // Status should still be completed - const finalOp = clientManager.getToolOperation(result.id); - expect(finalOp?.status).toBe("completed"); - }); - - test("normal tool completion clears pending tracking", async () => { - const result = await clientManager.callTool(sessionId, { - name: "echo", - arguments: { message: "test" }, - }); - - expect(result.status).toBe("completed"); - - // Verify no stale pending requests - // (Implementation 
detail: onResponse should have been called) - }); + let sessionManager: SessionManager; + let pipeline: ReturnType; + let registry: McpClientRegistry; + let clientManager: McpClientManager; + let mockTransport: MockServerTransport; + let sessionId: string; + let client: Client; + + beforeEach(async () => { + sessionManager = new SessionManager(); + pipeline = createPipeline(); + + // Mock Protocol Detector + const mockDetector = { + isInitializeRequest: (msg: any) => + msg.method === "initialize" && "id" in msg, + isInitializeResponse: (msg: any) => + "result" in msg && "protocolVersion" in msg.result, + isInitializedNotification: (msg: any) => + msg.method === "notifications/initialized", + extractCapabilities: (msg: any) => msg.result?.capabilities, + extractServerInfo: (msg: any) => msg.result?.serverInfo, + }; + + pipeline.use( + (createStateMachineMiddleware as any)(sessionManager, mockDetector), + ); + + registry = new McpClientRegistry(); + clientManager = new McpClientManager(registry, sessionManager, pipeline); + + // Setup session + const session = sessionManager.create({ + name: "cancel-test-session", + transport: "stdio", + command: "node", + }); + sessionId = session.id; + + // Setup Transport - slowTool is configured with 5s delay for cancellation tests + mockTransport = createMockServerTransport(scenarioMockConfig); + client = new Client( + { name: "test-client", version: "1.0.0" }, + { capabilities: {} }, + ); + + const loggingTransport = new LoggingTransport( + mockTransport, + session, + pipeline, + ); + + // Initialize connection + await client.connect(loggingTransport); + registry.register(sessionId, client, loggingTransport); + + // Manually transition to ACTIVE + sessionManager.connect(sessionId); + sessionManager.initialize(sessionId); + sessionManager.activate(sessionId, {}, {}, LATEST_PROTOCOL_VERSION); + }); + + afterEach(async () => { + if (mockTransport && !mockTransport.isClosed) { + await mockTransport.close(); + } + }); + + 
test("cancelOperation() sends notifications/cancelled to server", async () => { + // Start a slow tool call that we'll cancel + const toolCallPromise = clientManager.callTool(sessionId, { + name: "slowTool", + arguments: {}, + }); + + // Wait a moment for the request to be sent + await new Promise((resolve) => setTimeout(resolve, 100)); + + // Get the operation ID from the pending operation + const ops = clientManager.getToolOperations(sessionId); + const pendingOp = ops.find((op) => op.status === "pending"); + + if (pendingOp) { + await clientManager.cancelOperation(pendingOp.id, "Test cancellation"); + } + + // Verify mock server received cancellation + const cancelledRequests = mockTransport.getCancelledRequests(); + expect(cancelledRequests.length).toBeGreaterThan(0); + + // Clean up by waiting for the tool call to complete or fail + try { + await toolCallPromise; + } catch { + // Expected if cancelled + } + }); + + test("cancellation notification includes requestId", async () => { + // Capture sent messages + let cancelNotification: any = null; + const originalSend = mockTransport.send.bind(mockTransport); + mockTransport.send = async (msg: any) => { + if ( + "method" in msg && + msg.method === "notifications/cancelled" && + !("id" in msg) + ) { + cancelNotification = msg; + } + return originalSend(msg); + }; + + // Start slow tool and cancel + const toolCallPromise = clientManager.callTool(sessionId, { + name: "slowTool", + arguments: {}, + }); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + const ops = clientManager.getToolOperations(sessionId); + const pendingOp = ops.find((op) => op.status === "pending"); + + if (pendingOp) { + await clientManager.cancelOperation(pendingOp.id); + } + + expect(cancelNotification).not.toBeNull(); + expect(cancelNotification?.params?.requestId).toBeDefined(); + + try { + await toolCallPromise; + } catch { + // Expected + } + }); + + test("cancellation notification includes reason when provided", async () => 
{ + let cancelNotification: any = null; + const originalSend = mockTransport.send.bind(mockTransport); + mockTransport.send = async (msg: any) => { + if ( + "method" in msg && + msg.method === "notifications/cancelled" && + !("id" in msg) + ) { + cancelNotification = msg; + } + return originalSend(msg); + }; + + const toolCallPromise = clientManager.callTool(sessionId, { + name: "slowTool", + arguments: {}, + }); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + const ops = clientManager.getToolOperations(sessionId); + const pendingOp = ops.find((op) => op.status === "pending"); + + if (pendingOp) { + await clientManager.cancelOperation(pendingOp.id, "User clicked cancel"); + } + + expect(cancelNotification?.params?.reason).toBe("User clicked cancel"); + + try { + await toolCallPromise; + } catch { + // Expected + } + }); + + test("cancelled operation has status 'cancelled'", async () => { + const toolCallPromise = clientManager.callTool(sessionId, { + name: "slowTool", + arguments: {}, + }); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + const ops = clientManager.getToolOperations(sessionId); + const pendingOp = ops.find((op) => op.status === "pending"); + + if (pendingOp) { + await clientManager.cancelOperation(pendingOp.id); + } + + // Wait for the call to resolve/reject + try { + await toolCallPromise; + } catch { + // Expected + } + + // Verify status is cancelled + if (pendingOp) { + const finalOp = clientManager.getToolOperation(pendingOp.id); + expect(finalOp?.status).toBe("cancelled"); + } + }); + + test("response after cancellation is ignored", async () => { + const toolCallPromise = clientManager.callTool(sessionId, { + name: "slowTool", + arguments: {}, + }); + + await new Promise((resolve) => setTimeout(resolve, 100)); + + const ops = clientManager.getToolOperations(sessionId); + const pendingOp = ops.find((op) => op.status === "pending"); + + if (pendingOp) { + await clientManager.cancelOperation(pendingOp.id); + } + 
+ // The mock server should not send response for cancelled requests + // (see mock-server.ts lines 476-480 and 514-516) + + try { + await toolCallPromise; + } catch { + // Expected + } + + if (pendingOp) { + const finalOp = clientManager.getToolOperation(pendingOp.id); + // Status should remain cancelled, not completed + expect(finalOp?.status).toBe("cancelled"); + // Result should not be set + expect(finalOp?.result).toBeUndefined(); + } + }); + + test("timeout auto-cancels long-running operation", async () => { + // This test requires the callTool to support timeout option + const toolCallPromise = clientManager.callTool( + sessionId, + { name: "verySlowTool", arguments: {} }, + { timeout: 500 }, // 500ms timeout for testing + ); + + // Wait for timeout to trigger + try { + await toolCallPromise; + } catch { + // Expected to throw or return error status + } + + const ops = clientManager.getToolOperations(sessionId); + const op = ops[ops.length - 1]; + + // Should be either cancelled or error due to timeout + expect(op).toBeDefined(); + expect(["cancelled", "error"]).toContain(op?.status); + }); + + test("completed operation cannot be cancelled", async () => { + // Call a fast tool that will complete quickly + const result = await clientManager.callTool(sessionId, { + name: "echo", + arguments: { message: "quick" }, + }); + + expect(result.status).toBe("completed"); + + // Attempt to cancel completed operation + await clientManager.cancelOperation(result.id, "Too late"); + + // Status should still be completed + const finalOp = clientManager.getToolOperation(result.id); + expect(finalOp?.status).toBe("completed"); + }); + + test("normal tool completion clears pending tracking", async () => { + const result = await clientManager.callTool(sessionId, { + name: "echo", + arguments: { message: "test" }, + }); + + expect(result.status).toBe("completed"); + + // Verify no stale pending requests + // (Implementation detail: onResponse should have been called) + }); }); 
diff --git a/packages/mcp/test/content-parsing.test.ts b/packages/mcp/test/content-parsing.test.ts index a95054d..3bab554 100644 --- a/packages/mcp/test/content-parsing.test.ts +++ b/packages/mcp/test/content-parsing.test.ts @@ -1,17 +1,17 @@ import { afterEach, beforeEach, describe, expect, test } from "bun:test"; import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { - createPipeline, - createStateMachineMiddleware, - LATEST_PROTOCOL_VERSION, - SessionManager, + createPipeline, + createStateMachineMiddleware, + LATEST_PROTOCOL_VERSION, + SessionManager, } from "@say2/core"; import { McpClientManager } from "../src/client/manager"; import { McpClientRegistry } from "../src/client/registry"; import { LoggingTransport } from "../src/transport"; import { - createMockServerTransport, - type MockServerTransport, + createMockServerTransport, + type MockServerTransport, } from "./fixtures/mock-server"; import { scenarioMockConfig } from "./fixtures/tool-scenarios"; @@ -25,198 +25,198 @@ import { scenarioMockConfig } from "./fixtures/tool-scenarios"; * 4. 
Structured output is handled */ describe("Content Parsing Integration", () => { - let sessionManager: SessionManager; - let pipeline: ReturnType; - let registry: McpClientRegistry; - let clientManager: McpClientManager; - let mockTransport: MockServerTransport; - let sessionId: string; - let client: Client; - - beforeEach(async () => { - sessionManager = new SessionManager(); - pipeline = createPipeline(); - - // Mock Protocol Detector - const mockDetector = { - isInitializeRequest: (msg: any) => - msg.method === "initialize" && "id" in msg, - isInitializeResponse: (msg: any) => - "result" in msg && "protocolVersion" in msg.result, - isInitializedNotification: (msg: any) => - msg.method === "notifications/initialized", - extractCapabilities: (msg: any) => msg.result?.capabilities, - extractServerInfo: (msg: any) => msg.result?.serverInfo, - }; - - pipeline.use( - (createStateMachineMiddleware as any)(sessionManager, mockDetector), - ); - - registry = new McpClientRegistry(); - clientManager = new McpClientManager(registry, sessionManager, pipeline); - - // Setup session - const session = sessionManager.create({ - name: "content-test-session", - transport: "stdio", - command: "node", - }); - sessionId = session.id; - - // Setup Transport with content-returning tools - mockTransport = createMockServerTransport(scenarioMockConfig); - client = new Client( - { name: "test-client", version: "1.0.0" }, - { capabilities: {} }, - ); - - const loggingTransport = new LoggingTransport( - mockTransport, - session, - pipeline, - ); - - // Initialize connection - await client.connect(loggingTransport); - registry.register(sessionId, client, loggingTransport); - - // Manually transition to ACTIVE - sessionManager.connect(sessionId); - sessionManager.initialize(sessionId); - sessionManager.activate(sessionId, {}, {}, LATEST_PROTOCOL_VERSION); - }); - - afterEach(async () => { - if (mockTransport && !mockTransport.isClosed) { - await mockTransport.close(); - } - }); - - test("tool 
returns audio content and it is parsed correctly", async () => { - const result = await clientManager.callTool(sessionId, { - name: "getAudio", - }); - - expect(result.status).toBe("completed"); - expect(result.result?.content).toHaveLength(1); - - const content = result.result!.content[0]; - expect(content?.type).toBe("audio"); - if (content?.type === "audio") { - expect(content.data).toBeDefined(); - expect(content.data.length).toBeGreaterThan(0); - expect(content.mimeType).toBe("audio/wav"); - } - }); - - test("tool returns structuredContent and it is available in result", async () => { - const result = await clientManager.callTool(sessionId, { - name: "getStructured", - }); - - expect(result.status).toBe("completed"); - expect(result.result?.structuredContent).toBeDefined(); - - const structured = result.result!.structuredContent as any; - expect(structured.result).toBe("success"); - expect(structured.count).toBe(42); - expect(structured.items).toEqual(["a", "b", "c"]); - }); - - test("annotations are preserved in parsed content", async () => { - const result = await clientManager.callTool(sessionId, { - name: "getAnnotated", - }); - - expect(result.status).toBe("completed"); - expect(result.result?.content).toHaveLength(1); - - const content = result.result!.content[0]; - expect(content?.annotations).toBeDefined(); - expect(content?.annotations?.audience).toContain("user"); - expect(content?.annotations?.priority).toBe(0.8); - }); - - test("large base64 content is handled without memory issues", async () => { - // getImage returns a small image, but this test verifies the mechanism works - const result = await clientManager.callTool(sessionId, { - name: "getImage", - }); - - expect(result.status).toBe("completed"); - - const content = result.result!.content[0]; - expect(content?.type).toBe("image"); - if (content?.type === "image") { - // Verify base64 data is present and valid - expect(content.data.length).toBeGreaterThan(10); - // Base64 should only contain 
valid characters - expect(content.data).toMatch(/^[A-Za-z0-9+/=]+$/); - } - }); - - test("embedded resource content is parsed correctly", async () => { - const result = await clientManager.callTool(sessionId, { - name: "getEmbeddedResource", - }); - - expect(result.status).toBe("completed"); - - const content = result.result!.content[0]; - expect(content?.type).toBe("resource"); - if (content?.type === "resource") { - expect(content.resource.uri).toBe("file:///path/to/data.json"); - expect(content.resource.text).toBe('{"key": "value"}'); - expect(content.resource.mimeType).toBe("application/json"); - } - }); - - test("resource_link content is parsed correctly", async () => { - const result = await clientManager.callTool(sessionId, { - name: "getResourceLink", - }); - - expect(result.status).toBe("completed"); - - const content = result.result!.content[0]; - expect(content?.type).toBe("resource_link"); - if (content?.type === "resource_link") { - expect(content.uri).toBe("file:///path/to/resource.txt"); - expect(content.name).toBe("Resource File"); - expect(content.mimeType).toBe("text/plain"); - } - }); - - test("mixed content types are all parsed correctly", async () => { - const result = await clientManager.callTool(sessionId, { - name: "getMixed", - }); - - expect(result.status).toBe("completed"); - expect(result.result?.content).toHaveLength(3); - - const types = result.result!.content.map((c) => c.type); - expect(types).toContain("text"); - expect(types).toContain("image"); - expect(types).toContain("resource_link"); - }); - - test("text content is parsed correctly", async () => { - const result = await clientManager.callTool(sessionId, { - name: "echo", - arguments: { message: "Hello World" }, - }); - - expect(result.status).toBe("completed"); - expect(result.result?.content).toHaveLength(1); - - const content = result.result!.content[0]; - expect(content?.type).toBe("text"); - if (content?.type === "text") { - expect(content.text).toContain("Hello World"); - 
} - }); + let sessionManager: SessionManager; + let pipeline: ReturnType; + let registry: McpClientRegistry; + let clientManager: McpClientManager; + let mockTransport: MockServerTransport; + let sessionId: string; + let client: Client; + + beforeEach(async () => { + sessionManager = new SessionManager(); + pipeline = createPipeline(); + + // Mock Protocol Detector + const mockDetector = { + isInitializeRequest: (msg: any) => + msg.method === "initialize" && "id" in msg, + isInitializeResponse: (msg: any) => + "result" in msg && "protocolVersion" in msg.result, + isInitializedNotification: (msg: any) => + msg.method === "notifications/initialized", + extractCapabilities: (msg: any) => msg.result?.capabilities, + extractServerInfo: (msg: any) => msg.result?.serverInfo, + }; + + pipeline.use( + (createStateMachineMiddleware as any)(sessionManager, mockDetector), + ); + + registry = new McpClientRegistry(); + clientManager = new McpClientManager(registry, sessionManager, pipeline); + + // Setup session + const session = sessionManager.create({ + name: "content-test-session", + transport: "stdio", + command: "node", + }); + sessionId = session.id; + + // Setup Transport with content-returning tools + mockTransport = createMockServerTransport(scenarioMockConfig); + client = new Client( + { name: "test-client", version: "1.0.0" }, + { capabilities: {} }, + ); + + const loggingTransport = new LoggingTransport( + mockTransport, + session, + pipeline, + ); + + // Initialize connection + await client.connect(loggingTransport); + registry.register(sessionId, client, loggingTransport); + + // Manually transition to ACTIVE + sessionManager.connect(sessionId); + sessionManager.initialize(sessionId); + sessionManager.activate(sessionId, {}, {}, LATEST_PROTOCOL_VERSION); + }); + + afterEach(async () => { + if (mockTransport && !mockTransport.isClosed) { + await mockTransport.close(); + } + }); + + test("tool returns audio content and it is parsed correctly", async () => { + const 
result = await clientManager.callTool(sessionId, { + name: "getAudio", + }); + + expect(result.status).toBe("completed"); + expect(result.result?.content).toHaveLength(1); + + const content = result.result?.content[0]; + expect(content?.type).toBe("audio"); + if (content?.type === "audio") { + expect(content.data).toBeDefined(); + expect(content.data.length).toBeGreaterThan(0); + expect(content.mimeType).toBe("audio/wav"); + } + }); + + test("tool returns structuredContent and it is available in result", async () => { + const result = await clientManager.callTool(sessionId, { + name: "getStructured", + }); + + expect(result.status).toBe("completed"); + expect(result.result?.structuredContent).toBeDefined(); + + const structured = result.result?.structuredContent as any; + expect(structured.result).toBe("success"); + expect(structured.count).toBe(42); + expect(structured.items).toEqual(["a", "b", "c"]); + }); + + test("annotations are preserved in parsed content", async () => { + const result = await clientManager.callTool(sessionId, { + name: "getAnnotated", + }); + + expect(result.status).toBe("completed"); + expect(result.result?.content).toHaveLength(1); + + const content = result.result?.content[0]; + expect(content?.annotations).toBeDefined(); + expect(content?.annotations?.audience).toContain("user"); + expect(content?.annotations?.priority).toBe(0.8); + }); + + test("large base64 content is handled without memory issues", async () => { + // getImage returns a small image, but this test verifies the mechanism works + const result = await clientManager.callTool(sessionId, { + name: "getImage", + }); + + expect(result.status).toBe("completed"); + + const content = result.result?.content[0]; + expect(content?.type).toBe("image"); + if (content?.type === "image") { + // Verify base64 data is present and valid + expect(content.data.length).toBeGreaterThan(10); + // Base64 should only contain valid characters + expect(content.data).toMatch(/^[A-Za-z0-9+/=]+$/); + } 
+ }); + + test("embedded resource content is parsed correctly", async () => { + const result = await clientManager.callTool(sessionId, { + name: "getEmbeddedResource", + }); + + expect(result.status).toBe("completed"); + + const content = result.result?.content[0]; + expect(content?.type).toBe("resource"); + if (content?.type === "resource") { + expect(content.resource.uri).toBe("file:///path/to/data.json"); + expect(content.resource.text).toBe('{"key": "value"}'); + expect(content.resource.mimeType).toBe("application/json"); + } + }); + + test("resource_link content is parsed correctly", async () => { + const result = await clientManager.callTool(sessionId, { + name: "getResourceLink", + }); + + expect(result.status).toBe("completed"); + + const content = result.result?.content[0]; + expect(content?.type).toBe("resource_link"); + if (content?.type === "resource_link") { + expect(content.uri).toBe("file:///path/to/resource.txt"); + expect(content.name).toBe("Resource File"); + expect(content.mimeType).toBe("text/plain"); + } + }); + + test("mixed content types are all parsed correctly", async () => { + const result = await clientManager.callTool(sessionId, { + name: "getMixed", + }); + + expect(result.status).toBe("completed"); + expect(result.result?.content).toHaveLength(3); + + const types = result.result?.content.map((c) => c.type); + expect(types).toContain("text"); + expect(types).toContain("image"); + expect(types).toContain("resource_link"); + }); + + test("text content is parsed correctly", async () => { + const result = await clientManager.callTool(sessionId, { + name: "echo", + arguments: { message: "Hello World" }, + }); + + expect(result.status).toBe("completed"); + expect(result.result?.content).toHaveLength(1); + + const content = result.result?.content[0]; + expect(content?.type).toBe("text"); + if (content?.type === "text") { + expect(content.text).toContain("Hello World"); + } + }); }); /** @@ -231,115 +231,115 @@ describe("Content Parsing 
Integration", () => { * - Invalid structuredContent → validation error stored */ describe("ContentParser Integration Gap Detection", () => { - let sessionManager: SessionManager; - let pipeline: ReturnType; - let registry: McpClientRegistry; - let clientManager: McpClientManager; - let mockTransport: MockServerTransport; - let sessionId: string; - let client: Client; - - beforeEach(async () => { - sessionManager = new SessionManager(); - pipeline = createPipeline(); - - const mockDetector = { - isInitializeRequest: (msg: any) => - msg.method === "initialize" && "id" in msg, - isInitializeResponse: (msg: any) => - "result" in msg && "protocolVersion" in msg.result, - isInitializedNotification: (msg: any) => - msg.method === "notifications/initialized", - extractCapabilities: (msg: any) => msg.result?.capabilities, - extractServerInfo: (msg: any) => msg.result?.serverInfo, - }; - - pipeline.use( - (createStateMachineMiddleware as any)(sessionManager, mockDetector), - ); - - registry = new McpClientRegistry(); - clientManager = new McpClientManager(registry, sessionManager, pipeline); - - const session = sessionManager.create({ - name: "gap-test-session", - transport: "stdio", - command: "node", - }); - sessionId = session.id; - - mockTransport = createMockServerTransport(scenarioMockConfig); - client = new Client( - { name: "test-client", version: "1.0.0" }, - { capabilities: {} }, - ); - - const loggingTransport = new LoggingTransport( - mockTransport, - session, - pipeline, - ); - - await client.connect(loggingTransport); - registry.register(sessionId, client, loggingTransport); - - sessionManager.connect(sessionId); - sessionManager.initialize(sessionId); - sessionManager.activate(sessionId, {}, {}, LATEST_PROTOCOL_VERSION); - }); - - afterEach(async () => { - if (mockTransport && !mockTransport.isClosed) { - await mockTransport.close(); - } - }); - - // contentParser is now integrated into callTool() - test("tool returning invalid audio MIME type should fail 
parsing", async () => { - // This tool returns audio with mimeType "audio/x-invalid-fake" - // If contentParser.parseContent() is integrated, it should throw - const result = await clientManager.callTool(sessionId, { - name: "getInvalidAudioMime", - }); - - // Expected: operation should have error status due to parsing failure - expect(result.status).toBe("error"); - expect(result.error?.message).toContain("Invalid audio MIME type"); - }); - - // contentParser is now integrated into callTool() - test("tool returning invalid image MIME type should fail parsing", async () => { - // This tool returns image with mimeType "image/x-invalid-fake" - const result = await clientManager.callTool(sessionId, { - name: "getInvalidImageMime", - }); - - // Expected: operation should have error status due to parsing failure - expect(result.status).toBe("error"); - expect(result.error?.message).toContain("Invalid image MIME type"); - }); - - // validateStructuredOutput is now integrated into callTool() - test("tool returning invalid structuredContent should fail validation", async () => { - // This tool returns structuredContent that doesn't match outputSchema - // The outputSchema requires 'result' field which is missing - const result = await clientManager.callTool( - sessionId, - { name: "getInvalidStructuredOutput" }, - { - // Provide the outputSchema in options - implementation validates against this - outputSchema: { - type: "object", - properties: { - result: { type: "string" }, - }, - required: ["result"], - }, - }, - ); - - // Expected: validation should fail and be recorded in operation - expect(result.status).toBe("error"); - expect(result.error?.message).toContain("Invalid structured output"); - }); + let sessionManager: SessionManager; + let pipeline: ReturnType; + let registry: McpClientRegistry; + let clientManager: McpClientManager; + let mockTransport: MockServerTransport; + let sessionId: string; + let client: Client; + + beforeEach(async () => { + sessionManager = 
new SessionManager(); + pipeline = createPipeline(); + + const mockDetector = { + isInitializeRequest: (msg: any) => + msg.method === "initialize" && "id" in msg, + isInitializeResponse: (msg: any) => + "result" in msg && "protocolVersion" in msg.result, + isInitializedNotification: (msg: any) => + msg.method === "notifications/initialized", + extractCapabilities: (msg: any) => msg.result?.capabilities, + extractServerInfo: (msg: any) => msg.result?.serverInfo, + }; + + pipeline.use( + (createStateMachineMiddleware as any)(sessionManager, mockDetector), + ); + + registry = new McpClientRegistry(); + clientManager = new McpClientManager(registry, sessionManager, pipeline); + + const session = sessionManager.create({ + name: "gap-test-session", + transport: "stdio", + command: "node", + }); + sessionId = session.id; + + mockTransport = createMockServerTransport(scenarioMockConfig); + client = new Client( + { name: "test-client", version: "1.0.0" }, + { capabilities: {} }, + ); + + const loggingTransport = new LoggingTransport( + mockTransport, + session, + pipeline, + ); + + await client.connect(loggingTransport); + registry.register(sessionId, client, loggingTransport); + + sessionManager.connect(sessionId); + sessionManager.initialize(sessionId); + sessionManager.activate(sessionId, {}, {}, LATEST_PROTOCOL_VERSION); + }); + + afterEach(async () => { + if (mockTransport && !mockTransport.isClosed) { + await mockTransport.close(); + } + }); + + // contentParser is now integrated into callTool() + test("tool returning invalid audio MIME type should fail parsing", async () => { + // This tool returns audio with mimeType "audio/x-invalid-fake" + // If contentParser.parseContent() is integrated, it should throw + const result = await clientManager.callTool(sessionId, { + name: "getInvalidAudioMime", + }); + + // Expected: operation should have error status due to parsing failure + expect(result.status).toBe("error"); + expect(result.error?.message).toContain("Invalid 
audio MIME type"); + }); + + // contentParser is now integrated into callTool() + test("tool returning invalid image MIME type should fail parsing", async () => { + // This tool returns image with mimeType "image/x-invalid-fake" + const result = await clientManager.callTool(sessionId, { + name: "getInvalidImageMime", + }); + + // Expected: operation should have error status due to parsing failure + expect(result.status).toBe("error"); + expect(result.error?.message).toContain("Invalid image MIME type"); + }); + + // validateStructuredOutput is now integrated into callTool() + test("tool returning invalid structuredContent should fail validation", async () => { + // This tool returns structuredContent that doesn't match outputSchema + // The outputSchema requires 'result' field which is missing + const result = await clientManager.callTool( + sessionId, + { name: "getInvalidStructuredOutput" }, + { + // Provide the outputSchema in options - implementation validates against this + outputSchema: { + type: "object", + properties: { + result: { type: "string" }, + }, + required: ["result"], + }, + }, + ); + + // Expected: validation should fail and be recorded in operation + expect(result.status).toBe("error"); + expect(result.error?.message).toContain("Invalid structured output"); + }); }); diff --git a/packages/mcp/test/fixtures/mock-server.ts b/packages/mcp/test/fixtures/mock-server.ts index 439c062..1ee2734 100644 --- a/packages/mcp/test/fixtures/mock-server.ts +++ b/packages/mcp/test/fixtures/mock-server.ts @@ -58,6 +58,12 @@ interface MockServerConfig { tools?: boolean; resources?: boolean; prompts?: boolean; + /** Task-augmented execution capability (per MCP spec) */ + tasks?: { + requests?: { + tools?: { call?: boolean }; + }; + }; }; tools?: Array<{ name: string; @@ -107,7 +113,6 @@ const defaultConfig: MockServerConfig = { strictToolValidation: true, // Default to strict for executed tests }; - /** * Process a JSON-RPC message and return the response. 
*/ @@ -189,6 +194,9 @@ function createInitializeResponse( ...(config.capabilities?.tools ? { tools: {} } : {}), ...(config.capabilities?.resources ? { resources: {} } : {}), ...(config.capabilities?.prompts ? { prompts: {} } : {}), + ...(config.capabilities?.tasks + ? { tasks: config.capabilities.tasks } + : {}), }, serverInfo: { name: config.name ?? "mock-mcp-server", @@ -413,7 +421,6 @@ function createToolCallResponse( }; } - /** * Create a mock transport that simulates MCP server behavior. * Use this in unit tests instead of spawning a real process. @@ -598,4 +605,3 @@ export function createMockServerTransport(config: MockServerConfig = {}) { } export type MockServerTransport = ReturnType; -
diff --git a/packages/mcp/test/fixtures/task-mock-server.ts b/packages/mcp/test/fixtures/task-mock-server.ts
new file mode 100644
index 0000000..2910464
--- /dev/null
+++ b/packages/mcp/test/fixtures/task-mock-server.ts
@@ -0,0 +1,335 @@
+/**
+ * Task Mock Server
+ *
+ * Mock MCP server with task-augmented execution support.
+ * Extends the base mock server with tasks/* methods.
+ */
+
+import type { JSONRPCMessage } from "@modelcontextprotocol/sdk/types.js";
+import { createMockServerTransport } from "./mock-server";
+
+/** Lifecycle states a mock task can be in. */
+type MockTaskStatus =
+  | "working"
+  | "input_required"
+  | "completed"
+  | "failed"
+  | "cancelled";
+
+/** A task tracked by this mock server. */
+interface MockTask {
+  taskId: string;
+  status: MockTaskStatus;
+  statusMessage?: string;
+  createdAt: string;
+  lastUpdatedAt: string;
+  ttl: number | null;
+  pollInterval?: number;
+  // For simulation (never sent to the client)
+  completionDelay?: number;
+  willFail?: boolean;
+}
+
+/** The request params this mock reads; anything else is ignored. */
+type TaskRequestParams = {
+  name?: string;
+  taskId?: unknown;
+  task?: { ttl?: number | null } | null;
+};
+
+/** Timings of the simulated task lifecycle. */
+export interface TaskMockServerOptions {
+  /** Milliseconds before a "working" task settles. Defaults to 200. */
+  completionDelayMs?: number;
+  /** pollInterval reported on every task. Defaults to 100 (fast for tests). */
+  pollIntervalMs?: number;
+}
+
+// Module-level state, shared by every transport created from this file:
+// clearTasks() on any one of them resets it for all.
+const mockTasks = new Map<string, MockTask>();
+const completionTimers = new Map<string, ReturnType<typeof setTimeout>>();
+let taskIdCounter = 0;
+
+/** Wire representation of a task (drops the simulation-only fields). */
+function toWireTask(task: MockTask) {
+  return {
+    taskId: task.taskId,
+    status: task.status,
+    statusMessage: task.statusMessage,
+    createdAt: task.createdAt,
+    lastUpdatedAt: task.lastUpdatedAt,
+    ttl: task.ttl,
+    pollInterval: task.pollInterval,
+  };
+}
+
+/** Build a fresh tools/list payload that includes execution metadata. */
+function listTaskTools() {
+  return [
+    {
+      name: "echo",
+      description: "Echo tool",
+      inputSchema: { type: "object" },
+      // No execution = forbidden
+    },
+    {
+      name: "longProcess",
+      description: "Long-running process",
+      inputSchema: { type: "object" },
+      execution: { taskSupport: "optional" },
+    },
+    {
+      name: "backgroundJob",
+      description: "Background job",
+      inputSchema: { type: "object" },
+      execution: { taskSupport: "required" },
+    },
+    {
+      name: "quickTask",
+      description: "Quick task",
+      inputSchema: { type: "object" },
+      execution: { taskSupport: "optional" },
+    },
+    {
+      name: "failingTask",
+      description: "Failing task",
+      inputSchema: { type: "object" },
+      execution: { taskSupport: "optional" },
+    },
+    {
+      name: "inputTask",
+      description: "Task requiring input",
+      inputSchema: { type: "object" },
+      execution: { taskSupport: "optional" },
+    },
+  ];
+}
+
+/**
+ * Create a mock server transport with task support.
+ *
+ * Wraps the base mock transport's send so that tools/list, task-augmented
+ * tools/call and the tasks/* requests are answered here; every other
+ * message is forwarded to the base transport.
+ *
+ * Simulated outcome by tool name: quickTask is created already completed,
+ * inputTask is created as input_required and is not settled by the timer,
+ * failingTask settles as failed, and any other tool settles as completed.
+ *
+ * @param options - Optional timing overrides; the defaults are the 200 ms
+ *   completion delay and 100 ms poll interval that used to be hard-coded.
+ * @returns The base transport, extended with clearTasks() for resetting the
+ *   shared task state and cancelling pending completion timers between tests.
+ */
+export function createTaskMockServerTransport(
+  options: TaskMockServerOptions = {},
+) {
+  const { completionDelayMs = 200, pollIntervalMs = 100 } = options;
+
+  // Base configuration with task-supporting tools
+  const baseTransport = createMockServerTransport({
+    name: "task-mock-server",
+    version: "1.0.0",
+    capabilities: {
+      tools: true,
+      resources: false,
+      prompts: false,
+      // Task-augmented execution capability (per spec line 72)
+      tasks: {
+        requests: {
+          tools: { call: true },
+        },
+      },
+    },
+    tools: [
+      { name: "echo", description: "Echo tool (no task support)" },
+      {
+        name: "longProcess",
+        description: "Long-running process (optional task support)",
+      },
+      {
+        name: "backgroundJob",
+        description: "Background job (required task support)",
+      },
+      {
+        name: "quickTask",
+        description: "Quick task that completes immediately",
+      },
+      {
+        name: "failingTask",
+        description: "Task that fails",
+      },
+      {
+        name: "inputTask",
+        description: "Task that requires input (elicitation/sampling)",
+      },
+    ],
+    toolBehaviors: {
+      echo: {
+        content: [{ type: "text", text: "Echo response" }],
+      },
+    },
+    strictToolValidation: false, // Allow unknown tools for testing
+  });
+
+  const originalSend = baseTransport.send.bind(baseTransport);
+
+  // Replies are delivered on a microtask; onmessage is read at delivery
+  // time, not when the reply is queued.
+  const reply = (payload: unknown) => {
+    queueMicrotask(() => {
+      (baseTransport as any).onmessage?.(payload);
+    });
+  };
+
+  const replyResult = (id: string | number, result: unknown) => {
+    reply({ jsonrpc: "2.0", id, result });
+  };
+
+  const replyUnknownTask = (id: string | number, taskId: unknown) => {
+    reply({
+      jsonrpc: "2.0",
+      id,
+      error: { code: -32602, message: `Unknown task: ${taskId}` },
+    });
+  };
+
+  // Task ids are always strings, so any other key can never match.
+  const findTask = (taskId: unknown) =>
+    typeof taskId === "string" ? mockTasks.get(taskId) : undefined;
+
+  // Register a task for a task-augmented tools/call and schedule its outcome.
+  const createTask = (params: TaskRequestParams) => {
+    const taskId = `task-${++taskIdCounter}-${Date.now()}`;
+    const now = new Date().toISOString();
+    const isQuick = params.name === "quickTask";
+    const needsInput = params.name === "inputTask";
+
+    let status: MockTaskStatus = "working";
+    if (isQuick) {
+      status = "completed";
+    } else if (needsInput) {
+      status = "input_required";
+    }
+
+    const task: MockTask = {
+      taskId,
+      status,
+      statusMessage: needsInput ? "Waiting for user input" : undefined,
+      createdAt: now,
+      lastUpdatedAt: now,
+      ttl: params.task?.ttl ?? null,
+      pollInterval: pollIntervalMs,
+      willFail: params.name === "failingTask",
+    };
+    mockTasks.set(taskId, task);
+
+    // Settle the task after the delay, but only if it is still "working" by
+    // then. The timer is tracked so clearTasks() can cancel it.
+    if (!isQuick) {
+      const timer = setTimeout(() => {
+        completionTimers.delete(taskId);
+        const current = mockTasks.get(taskId);
+        if (current === undefined || current.status !== "working") {
+          return;
+        }
+        if (current.willFail) {
+          current.status = "failed";
+          current.statusMessage = "Task failed intentionally";
+        } else {
+          current.status = "completed";
+        }
+        current.lastUpdatedAt = new Date().toISOString();
+      }, completionDelayMs);
+      completionTimers.set(taskId, timer);
+    }
+
+    return task;
+  };
+
+  // Wrap send to intercept task-related methods
+  baseTransport.send = async (message: JSONRPCMessage) => {
+    if ("method" in message && "id" in message) {
+      const { id } = message;
+      const params = message.params as TaskRequestParams | undefined;
+
+      switch (message.method) {
+        case "tools/list":
+          // Override to include execution metadata
+          replyResult(id, { tools: listTaskTools() });
+          return;
+
+        case "tools/call":
+          // Only task-augmented calls are answered here
+          if (params?.task !== undefined) {
+            replyResult(id, { task: toWireTask(createTask(params)) });
+            return;
+          }
+          // Non-task calls go to the base handler below
+          break;
+
+        case "tasks/list":
+          replyResult(id, {
+            tasks: Array.from(mockTasks.values(), toWireTask),
+          });
+          return;
+
+        case "tasks/get": {
+          const task = findTask(params?.taskId);
+          if (!task) {
+            replyUnknownTask(id, params?.taskId);
+            return;
+          }
+          replyResult(id, toWireTask(task));
+          return;
+        }
+
+        case "tasks/result": {
+          const task = findTask(params?.taskId);
+          if (!task) {
+            replyUnknownTask(id, params?.taskId);
+            return;
+          }
+          // NOTE(review): answers at once, whatever the task's status.
+          replyResult(id, {
+            content: [{ type: "text", text: `Task ${task.taskId} result` }],
+            isError: task.status === "failed",
+          });
+          return;
+        }
+
+        case "tasks/cancel": {
+          const task = findTask(params?.taskId);
+          if (!task) {
+            replyUnknownTask(id, params?.taskId);
+            return;
+          }
+          // NOTE(review): already-finished tasks are cancelled too and the
+          // result is empty; check both against the MCP tasks spec and the
+          // client under test before tightening.
+          task.status = "cancelled";
+          task.lastUpdatedAt = new Date().toISOString();
+          replyResult(id, {});
+          return;
+        }
+      }
+    }
+
+    // Fall through to base transport for other methods
+    return originalSend(message);
+  };
+
+  // Clear tasks between tests. Pending completion timers are cancelled so
+  // they cannot outlive the tasks they belong to.
+  const clearTasks = () => {
+    for (const timer of completionTimers.values()) {
+      clearTimeout(timer);
+    }
+    completionTimers.clear();
+    mockTasks.clear();
+    taskIdCounter = 0;
+  };
+
+  return Object.assign(baseTransport, { clearTasks });
+}
diff --git a/packages/mcp/test/fixtures/tool-scenarios.ts b/packages/mcp/test/fixtures/tool-scenarios.ts index c4f9a2d..a0bc3d4 100644 --- a/packages/mcp/test/fixtures/tool-scenarios.ts +++ b/packages/mcp/test/fixtures/tool-scenarios.ts @@ -12,50 +12,50 @@ import type { ToolBehavior, ToolContentConfig } from "./mock-server"; /** Sample text content */ export const sampleTextContent: ToolContentConfig = { - type: "text", - text: "Hello from the
tool!", + type: "text", + text: "Hello from the tool!", }; /** Sample image content (1x1 red PNG) */ export const sampleImageContent: ToolContentConfig = { - type: "image", - data: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==", - mimeType: "image/png", + type: "image", + data: "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg==", + mimeType: "image/png", }; /** Sample audio content (short WAV header) */ export const sampleAudioContent: ToolContentConfig = { - type: "audio", - data: "UklGRiQAAABXQVZFZm10IBAAAAABAAEARKwAAIhYAQACABAAZGF0YQAAAAA=", - mimeType: "audio/wav", + type: "audio", + data: "UklGRiQAAABXQVZFZm10IBAAAAABAAEARKwAAIhYAQACABAAZGF0YQAAAAA=", + mimeType: "audio/wav", }; /** Sample resource link */ export const sampleResourceLinkContent: ToolContentConfig = { - type: "resource_link", - uri: "file:///path/to/resource.txt", - name: "Resource File", - mimeType: "text/plain", + type: "resource_link", + uri: "file:///path/to/resource.txt", + name: "Resource File", + mimeType: "text/plain", }; /** Sample embedded resource */ export const sampleEmbeddedResourceContent: ToolContentConfig = { - type: "resource", - resource: { - uri: "file:///path/to/data.json", - text: '{"key": "value"}', - mimeType: "application/json", - }, + type: "resource", + resource: { + uri: "file:///path/to/data.json", + text: '{"key": "value"}', + mimeType: "application/json", + }, }; /** Sample content with annotations */ export const sampleAnnotatedContent: ToolContentConfig = { - type: "text", - text: "This is for the user only", - annotations: { - audience: ["user"], - priority: 0.8, - }, + type: "text", + text: "This is for the user only", + annotations: { + audience: ["user"], + priority: 0.8, + }, }; // ============================================================================= @@ -64,209 +64,205 @@ export const sampleAnnotatedContent: ToolContentConfig = { /** Default tool 
behaviors for scenarios */ export const scenarioToolBehaviors: Record = { - // Basic echo - uses default behavior - echo: {}, + // Basic echo - uses default behavior + echo: {}, - // Returns image content - getImage: { - content: [sampleImageContent], - }, + // Returns image content + getImage: { + content: [sampleImageContent], + }, - // Returns audio content - getAudio: { - content: [sampleAudioContent], - }, + // Returns audio content + getAudio: { + content: [sampleAudioContent], + }, - // Returns resource link - getResourceLink: { - content: [sampleResourceLinkContent], - }, + // Returns resource link + getResourceLink: { + content: [sampleResourceLinkContent], + }, - // Returns embedded resource - getEmbeddedResource: { - content: [sampleEmbeddedResourceContent], - }, + // Returns embedded resource + getEmbeddedResource: { + content: [sampleEmbeddedResourceContent], + }, - // Returns multiple content types - getMixed: { - content: [ - sampleTextContent, - sampleImageContent, - sampleResourceLinkContent, - ], - }, + // Returns multiple content types + getMixed: { + content: [sampleTextContent, sampleImageContent, sampleResourceLinkContent], + }, - // Returns with annotations - getAnnotated: { - content: [sampleAnnotatedContent], - }, + // Returns with annotations + getAnnotated: { + content: [sampleAnnotatedContent], + }, - // Returns isError: true - failingTool: { - content: [{ type: "text", text: "Something went wrong" }], - isError: true, - }, + // Returns isError: true + failingTool: { + content: [{ type: "text", text: "Something went wrong" }], + isError: true, + }, - // Returns structured output - getStructured: { - content: [{ type: "text", text: "Structured data available" }], - structuredContent: { - result: "success", - count: 42, - items: ["a", "b", "c"], - }, - }, + // Returns structured output + getStructured: { + content: [{ type: "text", text: "Structured data available" }], + structuredContent: { + result: "success", + count: 42, + items: ["a", 
"b", "c"], + }, + }, - // Simulates slow operation (for timeout/cancel tests) - slowTool: { - content: [{ type: "text", text: "Completed after delay" }], - delayMs: 5000, - }, + // Simulates slow operation (for timeout/cancel tests) + slowTool: { + content: [{ type: "text", text: "Completed after delay" }], + delayMs: 5000, + }, - // Slow with progress notifications - slowWithProgress: { - content: [{ type: "text", text: "All steps complete" }], - delayMs: 3000, - progressSteps: 3, - }, + // Slow with progress notifications + slowWithProgress: { + content: [{ type: "text", text: "All steps complete" }], + delayMs: 3000, + progressSteps: 3, + }, - // Very slow (for timeout) - verySlowTool: { - content: [{ type: "text", text: "Should timeout" }], - delayMs: 60000, - }, + // Very slow (for timeout) + verySlowTool: { + content: [{ type: "text", text: "Should timeout" }], + delayMs: 60000, + }, - // GAP DETECTION: Returns audio with invalid MIME type - // Should fail if contentParser.parseContent() is integrated - getInvalidAudioMime: { - content: [ - { - type: "audio", - data: "UklGRiQA", - mimeType: "audio/x-invalid-fake", - }, - ], - }, + // GAP DETECTION: Returns audio with invalid MIME type + // Should fail if contentParser.parseContent() is integrated + getInvalidAudioMime: { + content: [ + { + type: "audio", + data: "UklGRiQA", + mimeType: "audio/x-invalid-fake", + }, + ], + }, - // GAP DETECTION: Returns image with invalid MIME type - // Should fail if contentParser.parseContent() is integrated - getInvalidImageMime: { - content: [ - { - type: "image", - data: "iVBORw0KGgo=", - mimeType: "image/x-invalid-fake", - }, - ], - }, + // GAP DETECTION: Returns image with invalid MIME type + // Should fail if contentParser.parseContent() is integrated + getInvalidImageMime: { + content: [ + { + type: "image", + data: "iVBORw0KGgo=", + mimeType: "image/x-invalid-fake", + }, + ], + }, - // GAP DETECTION: Returns structuredContent that doesn't match outputSchema - // 
Should fail if validateStructuredOutput() is called - getInvalidStructuredOutput: { - content: [{ type: "text", text: "Data with bad schema" }], - structuredContent: { - wrongField: "should fail validation", - // Missing required 'result' field per outputSchema - }, - }, + // GAP DETECTION: Returns structuredContent that doesn't match outputSchema + // Should fail if validateStructuredOutput() is called + getInvalidStructuredOutput: { + content: [{ type: "text", text: "Data with bad schema" }], + structuredContent: { + wrongField: "should fail validation", + // Missing required 'result' field per outputSchema + }, + }, }; /** Tool definitions with full schema */ export const scenarioToolDefinitions = [ - { - name: "echo", - description: "Echoes input back", - inputSchema: { - type: "object", - properties: { - message: { type: "string" }, - }, - required: ["message"], - }, - }, - { - name: "greet", - description: "Returns a greeting", - inputSchema: { - type: "object", - properties: { - name: { type: "string" }, - }, - }, - }, - { - name: "getImage", - description: "Returns image content", - }, - { - name: "getAudio", - description: "Returns audio content", - }, - { - name: "getResourceLink", - description: "Returns resource link", - }, - { - name: "getEmbeddedResource", - description: "Returns embedded resource", - }, - { - name: "getMixed", - description: "Returns mixed content types", - }, - { - name: "getAnnotated", - description: "Returns annotated content", - }, - { - name: "failingTool", - description: "Always returns isError: true", - }, - { - name: "getStructured", - description: "Returns structured output", - outputSchema: { - type: "object", - properties: { - result: { type: "string" }, - count: { type: "number" }, - items: { type: "array", items: { type: "string" } }, - }, - required: ["result"], - }, - }, - { - name: "slowTool", - description: "Simulates 5 second delay", - }, - { - name: "slowWithProgress", - description: "Slow with progress updates", - 
}, - { - name: "verySlowTool", - description: "60 second delay for timeout testing", - }, - // GAP DETECTION: These tools return invalid data to test contentParser integration - { - name: "getInvalidAudioMime", - description: "Returns audio with invalid MIME type - should fail if parsed", - }, - { - name: "getInvalidImageMime", - description: "Returns image with invalid MIME type - should fail if parsed", - }, - { - name: "getInvalidStructuredOutput", - description: "Returns structuredContent that doesn't match outputSchema", - outputSchema: { - type: "object", - properties: { - result: { type: "string" }, - }, - required: ["result"], - }, - }, + { + name: "echo", + description: "Echoes input back", + inputSchema: { + type: "object", + properties: { + message: { type: "string" }, + }, + required: ["message"], + }, + }, + { + name: "greet", + description: "Returns a greeting", + inputSchema: { + type: "object", + properties: { + name: { type: "string" }, + }, + }, + }, + { + name: "getImage", + description: "Returns image content", + }, + { + name: "getAudio", + description: "Returns audio content", + }, + { + name: "getResourceLink", + description: "Returns resource link", + }, + { + name: "getEmbeddedResource", + description: "Returns embedded resource", + }, + { + name: "getMixed", + description: "Returns mixed content types", + }, + { + name: "getAnnotated", + description: "Returns annotated content", + }, + { + name: "failingTool", + description: "Always returns isError: true", + }, + { + name: "getStructured", + description: "Returns structured output", + outputSchema: { + type: "object", + properties: { + result: { type: "string" }, + count: { type: "number" }, + items: { type: "array", items: { type: "string" } }, + }, + required: ["result"], + }, + }, + { + name: "slowTool", + description: "Simulates 5 second delay", + }, + { + name: "slowWithProgress", + description: "Slow with progress updates", + }, + { + name: "verySlowTool", + description: "60 second 
delay for timeout testing", + }, + // GAP DETECTION: These tools return invalid data to test contentParser integration + { + name: "getInvalidAudioMime", + description: "Returns audio with invalid MIME type - should fail if parsed", + }, + { + name: "getInvalidImageMime", + description: "Returns image with invalid MIME type - should fail if parsed", + }, + { + name: "getInvalidStructuredOutput", + description: "Returns structuredContent that doesn't match outputSchema", + outputSchema: { + type: "object", + properties: { + result: { type: "string" }, + }, + required: ["result"], + }, + }, ]; // ============================================================================= @@ -275,24 +271,24 @@ export const scenarioToolDefinitions = [ /** Full mock config with all tools and behaviors */ export const scenarioMockConfig = { - name: "scenario-mock-server", - version: "1.0.0", - protocolVersion: "2024-11-05", - capabilities: { - tools: true, - resources: true, - prompts: true, - }, - tools: scenarioToolDefinitions, - toolBehaviors: scenarioToolBehaviors, - strictToolValidation: true, + name: "scenario-mock-server", + version: "1.0.0", + protocolVersion: "2024-11-05", + capabilities: { + tools: true, + resources: true, + prompts: true, + }, + tools: scenarioToolDefinitions, + toolBehaviors: scenarioToolBehaviors, + strictToolValidation: true, }; /** Minimal config for basic tests */ export const minimalMockConfig = { - tools: [ - { name: "echo", description: "Echo tool" }, - { name: "greet", description: "Greeting tool" }, - ], - strictToolValidation: true, + tools: [ + { name: "echo", description: "Echo tool" }, + { name: "greet", description: "Greeting tool" }, + ], + strictToolValidation: true, }; diff --git a/packages/mcp/test/manager.test.ts b/packages/mcp/test/manager.test.ts index 7f6f011..a5c8259 100644 --- a/packages/mcp/test/manager.test.ts +++ b/packages/mcp/test/manager.test.ts @@ -12,8 +12,8 @@ import { McpClientRegistry } from "../src/client/registry"; // 
Mock the MCP SDK modules // Mock the MCP SDK modules -const mockClientConnect = mock(async () => { }); -const mockClientClose = mock(async () => { }); +const mockClientConnect = mock(async () => {}); +const mockClientClose = mock(async () => {}); const mockClientListTools = mock(async () => ({ tools: [], nextCursor: undefined, @@ -28,7 +28,7 @@ const mockClientListPrompts = mock(async () => ({ })); // Client factory for dependency injection -const mockSetNotificationHandler = mock(() => { }); +const mockSetNotificationHandler = mock(() => {}); const mockClientFactory = (_info: any, _opts: any) => ({ connect: mockClientConnect, @@ -270,8 +270,8 @@ describe("McpClientManager", () => { // Pre-register a mock client entry // (This simulates a connected state) - const mockClient = { close: async () => { } } as any; - const mockTransport = { close: async () => { } } as any; + const mockClient = { close: async () => {} } as any; + const mockTransport = { close: async () => {} } as any; try { registry.register(session.id, mockClient, mockTransport); @@ -291,9 +291,9 @@ describe("McpClientManager", () => { command: "echo", }); - const mockClose = mock(async () => { }); + const mockClose = mock(async () => {}); const mockClient = { close: mockClose } as any; - const mockTransport = { close: async () => { } } as any; + const mockTransport = { close: async () => {} } as any; registry.register(session.id, mockClient, mockTransport); await clientManager.disconnect(session.id); diff --git a/packages/mcp/test/progress-tracking.test.ts b/packages/mcp/test/progress-tracking.test.ts index d4c76cc..1c2df05 100644 --- a/packages/mcp/test/progress-tracking.test.ts +++ b/packages/mcp/test/progress-tracking.test.ts @@ -8,14 +8,14 @@ import { } from "@say2/core"; import { McpClientManager } from "../src/client/manager"; import { McpClientRegistry } from "../src/client/registry"; +import { progressTracker } from "../src/progress/tracker"; import { LoggingTransport } from "../src/transport"; 
+import { McpProgressNotificationSchema } from "../src/types/progress"; import { createMockServerTransport, type MockServerTransport, } from "./fixtures/mock-server"; import { scenarioMockConfig } from "./fixtures/tool-scenarios"; -import { progressTracker } from "../src/progress/tracker"; -import { McpProgressNotificationSchema } from "../src/types/progress"; /** * Progress Tracking Integration Tests diff --git a/packages/mcp/test/task-augmented.test.ts b/packages/mcp/test/task-augmented.test.ts new file mode 100644 index 0000000..8c071d6 --- /dev/null +++ b/packages/mcp/test/task-augmented.test.ts @@ -0,0 +1,477 @@ +/** + * Task-Augmented Execution Integration Tests + * + * Tests for task-augmented tool execution flow. + * Task 07: Task-Augmented Execution - Phase 3 Integration Tests + */ + +import { afterEach, beforeEach, describe, expect, test } from "bun:test"; +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; +import { + createPipeline, + createStateMachineMiddleware, + LATEST_PROTOCOL_VERSION, + SessionManager, +} from "@say2/core"; +import { McpClientManager } from "../src/client/manager"; +import { McpClientRegistry } from "../src/client/registry"; +import { taskManager } from "../src/task/manager"; +import { LoggingTransport } from "../src/transport"; +import type { Task } from "../src/types/task"; +import type { MockServerTransport } from "./fixtures/mock-server.ts"; +import { createTaskMockServerTransport } from "./fixtures/task-mock-server.ts"; + +/** + * Task-Augmented Execution Integration Tests + * + * These tests verify the end-to-end flow of task-augmented tool calls: + * 1. callToolAsTask() creates a task and returns CreateTaskResult + * 2. listTasks(), getTask() retrieve task status + * 3. cancelTask() cancels running tasks + * 4. callToolAsTaskAndWait() polls until completion + * 5. 
getToolTaskSupport() correctly identifies task support levels + */ +describe("Task-Augmented Execution Integration", () => { + let sessionManager: SessionManager; + let pipeline: ReturnType; + let registry: McpClientRegistry; + let clientManager: McpClientManager; + let mockTransport: MockServerTransport; + let sessionId: string; + let client: Client; + + beforeEach(async () => { + sessionManager = new SessionManager(); + pipeline = createPipeline(); + + // Mock Protocol Detector + const mockDetector = { + isInitializeRequest: (msg: any) => + msg.method === "initialize" && "id" in msg, + isInitializeResponse: (msg: any) => + "result" in msg && "protocolVersion" in msg.result, + isInitializedNotification: (msg: any) => + msg.method === "notifications/initialized", + extractCapabilities: (msg: any) => msg.result?.capabilities, + extractServerInfo: (msg: any) => msg.result?.serverInfo, + }; + + pipeline.use( + (createStateMachineMiddleware as any)(sessionManager, mockDetector), + ); + + registry = new McpClientRegistry(); + clientManager = new McpClientManager(registry, sessionManager, pipeline); + + // Setup session + const session = sessionManager.create({ + name: "task-test-session", + transport: "stdio", + command: "node", + }); + sessionId = session.id; + + // Setup Transport with task support + mockTransport = createTaskMockServerTransport(); + client = new Client( + { name: "test-client", version: "1.0.0" }, + { capabilities: {} }, + ); + + const loggingTransport = new LoggingTransport( + mockTransport, + session, + pipeline, + ); + + // Initialize connection + await client.connect(loggingTransport); + registry.register(sessionId, client, loggingTransport); + + // Manually transition to ACTIVE with task capability + sessionManager.connect(sessionId); + sessionManager.initialize(sessionId); + sessionManager.activate( + sessionId, + {}, // clientCapabilities + { + // serverCapabilities - must include tasks for task-augmented execution + tools: {}, + tasks: { + 
requests: { + tools: { call: true }, + }, + }, + }, + LATEST_PROTOCOL_VERSION, + ); + + // Clear task manager from previous tests + taskManager.clear(); + }); + + afterEach(async () => { + taskManager.clear(); + if (mockTransport && !mockTransport.isClosed) { + await mockTransport.close(); + } + }); + + // ========================================================================= + // getToolTaskSupport + // ========================================================================= + + describe("getToolTaskSupport", () => { + test("returns 'forbidden' for tool without execution config", async () => { + // Discover capabilities to populate tools + await clientManager.listTools(sessionId); + + const support = clientManager.getToolTaskSupport(sessionId, "echo"); + expect(support).toBe("forbidden"); + }); + + test("returns 'optional' for tool with taskSupport: optional", async () => { + await clientManager.listTools(sessionId); + + const support = clientManager.getToolTaskSupport( + sessionId, + "longProcess", + ); + expect(support).toBe("optional"); + }); + + test("returns 'required' for tool with taskSupport: required", async () => { + await clientManager.listTools(sessionId); + + const support = clientManager.getToolTaskSupport( + sessionId, + "backgroundJob", + ); + expect(support).toBe("required"); + }); + + test("returns 'forbidden' for unknown tool", async () => { + await clientManager.listTools(sessionId); + + const support = clientManager.getToolTaskSupport( + sessionId, + "nonexistent", + ); + expect(support).toBe("forbidden"); + }); + }); + + // ========================================================================= + // listTasks + // ========================================================================= + + describe("listTasks", () => { + test("returns empty array when no tasks exist", async () => { + const tasks = await clientManager.listTasks(sessionId); + expect(tasks).toEqual([]); + }); + + test("returns tasks after creating them", async () => { + 
// Discover tools first to populate execution metadata + await clientManager.listTools(sessionId); + + // Create a task + await clientManager.callToolAsTask(sessionId, { + name: "longProcess", + arguments: { duration: 1000 }, + }); + + const tasks = await clientManager.listTasks(sessionId); + expect(tasks.length).toBeGreaterThanOrEqual(1); + }); + }); + + // ========================================================================= + // callToolAsTask + // ========================================================================= + + describe("callToolAsTask", () => { + test("returns CreateTaskResult with task object", async () => { + await clientManager.listTools(sessionId); // Discover tools first + + const result = await clientManager.callToolAsTask(sessionId, { + name: "longProcess", + arguments: { duration: 1000 }, + }); + + expect(result.task).toBeDefined(); + expect(result.task.taskId).toBeDefined(); + expect(result.task.status).toBe("working"); + expect(result.task.createdAt).toBeDefined(); + expect(result.task.lastUpdatedAt).toBeDefined(); + }); + + test("registers task in local TaskManager", async () => { + await clientManager.listTools(sessionId); + + const result = await clientManager.callToolAsTask(sessionId, { + name: "longProcess", + arguments: { duration: 500 }, + }); + + const cachedTask = taskManager.getTask(result.task.taskId); + expect(cachedTask).toBeDefined(); + expect(cachedTask?.taskId).toBe(result.task.taskId); + }); + + test("throws error for tool that doesn't support tasks", async () => { + await clientManager.listTools(sessionId); + + await expect( + clientManager.callToolAsTask(sessionId, { + name: "echo", + arguments: { message: "test" }, + }), + ).rejects.toThrow("does not support task-augmented execution"); + }); + + test("passes ttl option to server", async () => { + await clientManager.listTools(sessionId); + + const result = await clientManager.callToolAsTask( + sessionId, + { name: "longProcess", arguments: {} }, + { ttl: 600000 }, 
+ ); + + expect(result.task).toBeDefined(); + // Server may echo back ttl or set its own + }); + }); + + // ========================================================================= + // getTask + // ========================================================================= + + describe("getTask", () => { + test("retrieves task status by ID", async () => { + await clientManager.listTools(sessionId); + + const createResult = await clientManager.callToolAsTask(sessionId, { + name: "longProcess", + arguments: { duration: 500 }, + }); + + const task = await clientManager.getTask( + sessionId, + createResult.task.taskId, + ); + expect(task.taskId).toBe(createResult.task.taskId); + expect(["working", "completed"]).toContain(task.status); + }); + + test("throws for unknown task ID", async () => { + await expect( + clientManager.getTask(sessionId, "nonexistent-task"), + ).rejects.toThrow(); + }); + }); + + // ========================================================================= + // cancelTask + // ========================================================================= + + describe("cancelTask", () => { + test("cancels a running task", async () => { + await clientManager.listTools(sessionId); + + const createResult = await clientManager.callToolAsTask(sessionId, { + name: "longProcess", + arguments: { duration: 10000 }, // Long duration so we can cancel + }); + + // Cancel the task + await clientManager.cancelTask(sessionId, createResult.task.taskId); + + // Verify task was removed from local cache + const cachedTask = taskManager.getTask(createResult.task.taskId); + expect(cachedTask).toBeUndefined(); + }); + + test("throws for unknown task ID", async () => { + await expect( + clientManager.cancelTask(sessionId, "nonexistent-task"), + ).rejects.toThrow(); + }); + }); + + // ========================================================================= + // callToolAsTaskAndWait + // ========================================================================= + + 
describe("callToolAsTaskAndWait", () => { + test("polls until task completes and returns result", async () => { + await clientManager.listTools(sessionId); + + const result = await clientManager.callToolAsTaskAndWait(sessionId, { + name: "quickTask", + arguments: {}, + }); + + expect(result).toBeDefined(); + }); + + test("calls onProgress callback during polling", async () => { + await clientManager.listTools(sessionId); + + const progressUpdates: Task[] = []; + + await clientManager.callToolAsTaskAndWait( + sessionId, + { name: "quickTask", arguments: {} }, + {}, + (task) => { + progressUpdates.push(task); + }, + ); + + expect(progressUpdates.length).toBeGreaterThanOrEqual(1); + }); + + test("throws error when task fails", async () => { + await clientManager.listTools(sessionId); + + await expect( + clientManager.callToolAsTaskAndWait(sessionId, { + name: "failingTask", + arguments: {}, + }), + ).rejects.toThrow(); + }); + }); + + // ========================================================================= + // Task Status Notifications + // ========================================================================= + + describe("Task Status Notifications", () => { + test("handleStatusNotification updates TaskManager cache", () => { + // Register a task first + taskManager.registerTask("test-task", sessionId, { + status: "working", + }); + + // Directly call handleStatusNotification (simulating notification handler) + taskManager.handleStatusNotification({ + taskId: "test-task", + status: "completed", + statusMessage: "Done", + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString(), + ttl: null, + }); + + const task = taskManager.getTask("test-task"); + expect(task?.status).toBe("completed"); + expect(task?.statusMessage).toBe("Done"); + }); + + test("handleStatusNotification creates task if not exists", () => { + // Directly call handleStatusNotification for unknown task + taskManager.handleStatusNotification({ + taskId: 
"new-task-from-notification", + status: "working", + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString(), + ttl: null, + }); + + const task = taskManager.getTask("new-task-from-notification"); + expect(task).toBeDefined(); + expect(task?.status).toBe("working"); + }); + }); + + // ========================================================================= + // input_required Status (Spec line 111) + // ========================================================================= + + describe("input_required Status", () => { + test("callToolAsTask returns input_required status for inputTask tool", async () => { + await clientManager.listTools(sessionId); + + const result = await clientManager.callToolAsTask(sessionId, { + name: "inputTask", + arguments: {}, + }); + + expect(result.task.status).toBe("input_required"); + expect(result.task.statusMessage).toBe("Waiting for user input"); + }); + + test("getTask shows input_required task as waiting for input", async () => { + await clientManager.listTools(sessionId); + + const createResult = await clientManager.callToolAsTask(sessionId, { + name: "inputTask", + arguments: {}, + }); + + const task = await clientManager.getTask( + sessionId, + createResult.task.taskId, + ); + expect(task.status).toBe("input_required"); + }); + + test("polling continues on input_required until task completes", async () => { + // This tests that input_required is NOT a terminal status + // Use TaskManager directly with a mock that transitions: + // input_required -> input_required -> completed + const testManager = taskManager; + testManager.registerTask("poll-input-test", sessionId, { + status: "input_required", + }); + + let pollCount = 0; + const result = await testManager.pollUntilComplete( + "poll-input-test", + async () => { + pollCount++; + // Simulate: input_required x2, then completed + return { + taskId: "poll-input-test", + status: pollCount >= 3 ? 
"completed" : "input_required", + statusMessage: pollCount >= 3 ? "Done" : "Waiting for input", + createdAt: new Date().toISOString(), + lastUpdatedAt: new Date().toISOString(), + ttl: null, + } as Task; + }, + ); + + expect(result.status).toBe("completed"); + expect(pollCount).toBeGreaterThanOrEqual(3); + }); + }); + + // ========================================================================= + // Server Capability Check (Spec line 72, 78) + // ========================================================================= + + describe("Server Task Capability", () => { + test("mock server includes tasks capability in initialization", async () => { + // The task mock server should advertise tasks.requests.tools.call + // This is verified by the fact that our tests work, but let's be explicit + const session = sessionManager.get(sessionId); + // Session state should have captured the capabilities + expect(session).toBeDefined(); + // The server capabilities should include tasks + // Note: We can't directly access this from SessionManager, + // but the fact that task operations work proves the capability is there + }); + + test("getToolTaskSupport returns forbidden when tool lacks execution metadata", async () => { + await clientManager.listTools(sessionId); + + // echo tool has no execution.taskSupport + const support = clientManager.getToolTaskSupport(sessionId, "echo"); + expect(support).toBe("forbidden"); + }); + }); +}); diff --git a/packages/mcp/test/tool-annotations.test.ts b/packages/mcp/test/tool-annotations.test.ts new file mode 100644 index 0000000..641f5e7 --- /dev/null +++ b/packages/mcp/test/tool-annotations.test.ts @@ -0,0 +1,476 @@ +/** + * Tool Annotations Integration Tests + * + * Integration tests for tool annotations with mock MCP server. 
+ * Task 06: Tool Annotations - Phase 2 Integration Tests + */ + +import { beforeEach, describe, expect, test } from "bun:test"; +import { Client } from "@modelcontextprotocol/sdk/client/index.js"; +import { + createPipeline, + createStateMachineMiddleware, + LATEST_PROTOCOL_VERSION, + SessionManager, +} from "@say2/core"; +import { McpClientManager } from "../src/client/manager"; +import { McpClientRegistry } from "../src/client/registry"; +import { LoggingTransport } from "../src/transport"; +import { createMockServerTransport } from "./fixtures/mock-server"; + +describe("Tool Annotations Integration Tests", () => { + let sessionManager: SessionManager; + let pipeline: ReturnType; + let registry: McpClientRegistry; + let clientManager: McpClientManager; + + // Mock Protocol Detector for API compatibility + const mockDetector = { + // biome-ignore lint/suspicious/noExplicitAny: mock + isInitializeRequest: (msg: any) => + msg.method === "initialize" && "id" in msg, + // biome-ignore lint/suspicious/noExplicitAny: mock + isInitializeResponse: (msg: any) => + "result" in msg && "protocolVersion" in msg.result, + // biome-ignore lint/suspicious/noExplicitAny: mock + isInitializedNotification: (msg: any) => + msg.method === "notifications/initialized", + // biome-ignore lint/suspicious/noExplicitAny: mock + extractCapabilities: (msg: any) => msg.result?.capabilities, + // biome-ignore lint/suspicious/noExplicitAny: mock + extractServerInfo: (msg: any) => msg.result?.serverInfo, + }; + + beforeEach(() => { + sessionManager = new SessionManager(); + pipeline = createPipeline(); + // biome-ignore lint/suspicious/noExplicitAny: API mismatch fix + pipeline.use( + (createStateMachineMiddleware as any)(sessionManager, mockDetector), + ); + registry = new McpClientRegistry(); + clientManager = new McpClientManager(registry, sessionManager, pipeline); + }); + + /** + * Helper: Set up a connected client with the given server configuration + */ + // biome-ignore 
lint/suspicious/noExplicitAny: flexible config + async function setupConnectedClient(serverConfig: any) { + const session = sessionManager.create({ + name: "test", + transport: "stdio", + command: "node", + }); + + const mockTransport = createMockServerTransport(serverConfig); + const loggingTransport = new LoggingTransport( + mockTransport, + session, + pipeline, + ); + + // Manually transition session state + sessionManager.connect(session.id); + sessionManager.initialize(session.id); + sessionManager.activate( + session.id, + serverConfig.capabilities ?? {}, + {}, + LATEST_PROTOCOL_VERSION, + ); + + // Create Client and Register + const client = new Client( + { name: "client", version: "1.0.0" }, + { capabilities: {} }, + ); + await client.connect(loggingTransport); + registry.register(session.id, client, loggingTransport); + + return { session, client, mockTransport }; + } + + /** + * Helper: Store tools in session's discovered capabilities + * Uses updateCapabilities() to properly mutate the session state + */ + // biome-ignore lint/suspicious/noExplicitAny: flexible tools array + function storeToolsInSession(sessionId: string, tools: any[]) { + // Use updateCapabilities to store discovered tools in serverCapabilities + sessionManager.updateCapabilities(sessionId, undefined, { + tools: true, // Keep original capability flag + discovered: { tools }, + }); + } + + // ========================================================================= + // getToolAnnotations() Tests + // ========================================================================= + + describe("getToolAnnotations()", () => { + test("returns annotations for existing tool with full annotations", async () => { + const toolsWithAnnotations = [ + { + name: "read_file", + description: "Reads a file", + inputSchema: { type: "object" }, + annotations: { + title: "Read File", + readOnlyHint: true, + destructiveHint: false, + idempotentHint: true, + openWorldHint: false, + }, + }, + ]; + + const { 
session } = await setupConnectedClient({ + capabilities: { tools: true }, + tools: [{ name: "read_file", description: "Reads a file" }], + }); + + // Manually store tools with annotations + storeToolsInSession(session.id, toolsWithAnnotations); + + const annotations = clientManager.getToolAnnotations( + session.id, + "read_file", + ); + + expect(annotations).toBeDefined(); + expect(annotations?.title).toBe("Read File"); + expect(annotations?.readOnlyHint).toBe(true); + expect(annotations?.destructiveHint).toBe(false); + expect(annotations?.idempotentHint).toBe(true); + expect(annotations?.openWorldHint).toBe(false); + }); + + test("applies defaults for tool with partial annotations", async () => { + const toolsWithPartialAnnotations = [ + { + name: "search", + inputSchema: { type: "object" }, + annotations: { + title: "Search Tool", + // Other hints not specified - should get defaults + }, + }, + ]; + + const { session } = await setupConnectedClient({ + capabilities: { tools: true }, + tools: [{ name: "search", description: "Search tool" }], + }); + + storeToolsInSession(session.id, toolsWithPartialAnnotations); + + const annotations = clientManager.getToolAnnotations( + session.id, + "search", + ); + + expect(annotations).toBeDefined(); + expect(annotations?.title).toBe("Search Tool"); + // Defaults applied: + expect(annotations?.readOnlyHint).toBe(false); + expect(annotations?.destructiveHint).toBe(true); + expect(annotations?.idempotentHint).toBe(false); + expect(annotations?.openWorldHint).toBe(true); + }); + + test("applies all defaults for tool without annotations", async () => { + const toolsWithoutAnnotations = [ + { + name: "no_hints", + inputSchema: { type: "object" }, + // No annotations property + }, + ]; + + const { session } = await setupConnectedClient({ + capabilities: { tools: true }, + tools: [{ name: "no_hints", description: "Tool without annotations" }], + }); + + storeToolsInSession(session.id, toolsWithoutAnnotations); + + const annotations = 
clientManager.getToolAnnotations( + session.id, + "no_hints", + ); + + // Should return defaults (not undefined) because tool exists + expect(annotations).toBeDefined(); + expect(annotations?.readOnlyHint).toBe(false); + expect(annotations?.destructiveHint).toBe(true); + expect(annotations?.idempotentHint).toBe(false); + expect(annotations?.openWorldHint).toBe(true); + }); + + test("returns undefined for non-existent tool", async () => { + const tools = [ + { + name: "existing_tool", + inputSchema: { type: "object" }, + }, + ]; + + const { session } = await setupConnectedClient({ + capabilities: { tools: true }, + tools: [{ name: "existing_tool", description: "Existing" }], + }); + + storeToolsInSession(session.id, tools); + + const annotations = clientManager.getToolAnnotations( + session.id, + "non_existent_tool", + ); + + expect(annotations).toBeUndefined(); + }); + + test("returns undefined for non-existent session", () => { + const annotations = clientManager.getToolAnnotations( + "non-existent-session-id", + "any_tool", + ); + + expect(annotations).toBeUndefined(); + }); + + test("returns undefined when session has no discovered tools", async () => { + const { session } = await setupConnectedClient({ + capabilities: { tools: true }, + tools: [], + }); + + // Don't store any tools - discovered.tools will be undefined + + const annotations = clientManager.getToolAnnotations( + session.id, + "any_tool", + ); + + expect(annotations).toBeUndefined(); + }); + }); + + // ========================================================================= + // listToolsTyped() Tests + // ========================================================================= + + describe("listToolsTyped()", () => { + test("returns all tools with annotation defaults applied", async () => { + const tools = [ + { + name: "tool_with_annotations", + description: "Has annotations", + inputSchema: { type: "object" }, + annotations: { + title: "Annotated Tool", + readOnlyHint: true, + }, + }, + { + 
name: "tool_without_annotations", + description: "No annotations", + inputSchema: { type: "object" }, + }, + ]; + + const { session } = await setupConnectedClient({ + capabilities: { tools: true }, + tools: [ + { name: "tool_with_annotations", description: "Has annotations" }, + { name: "tool_without_annotations", description: "No annotations" }, + ], + }); + + storeToolsInSession(session.id, tools); + + const typedTools = clientManager.listToolsTyped(session.id); + + expect(typedTools).toHaveLength(2); + + // First tool - has explicit annotations + const annotatedTool = typedTools.find( + (t) => t.name === "tool_with_annotations", + ); + expect(annotatedTool?.annotations?.title).toBe("Annotated Tool"); + expect(annotatedTool?.annotations?.readOnlyHint).toBe(true); + expect(annotatedTool?.annotations?.destructiveHint).toBe(true); // default + expect(annotatedTool?.annotations?.idempotentHint).toBe(false); // default + expect(annotatedTool?.annotations?.openWorldHint).toBe(true); // default + + // Second tool - all defaults applied + const plainTool = typedTools.find( + (t) => t.name === "tool_without_annotations", + ); + expect(plainTool?.annotations?.readOnlyHint).toBe(false); + expect(plainTool?.annotations?.destructiveHint).toBe(true); + expect(plainTool?.annotations?.idempotentHint).toBe(false); + expect(plainTool?.annotations?.openWorldHint).toBe(true); + }); + + test("returns empty array for session with no tools", async () => { + const { session } = await setupConnectedClient({ + capabilities: { tools: true }, + tools: [], + }); + + // Empty discovered tools + storeToolsInSession(session.id, []); + + const typedTools = clientManager.listToolsTyped(session.id); + + expect(typedTools).toEqual([]); + }); + + test("returns empty array for non-existent session", () => { + const typedTools = clientManager.listToolsTyped("non-existent-session"); + + expect(typedTools).toEqual([]); + }); + + test("preserves all tool fields alongside annotations", async () => { + 
const tools = [ + { + name: "complete_tool", + description: "A tool with all fields", + inputSchema: { + type: "object", + properties: { input: { type: "string" } }, + required: ["input"], + }, + outputSchema: { + type: "object", + properties: { output: { type: "number" } }, + }, + annotations: { + title: "Complete Tool", + }, + execution: { + taskSupport: "optional", + }, + _meta: { + version: "1.0", + }, + }, + ]; + + const { session } = await setupConnectedClient({ + capabilities: { tools: true }, + tools: [{ name: "complete_tool", description: "Complete" }], + }); + + storeToolsInSession(session.id, tools); + + const typedTools = clientManager.listToolsTyped(session.id); + + expect(typedTools).toHaveLength(1); + const tool = typedTools[0]!; + + // Core fields preserved + expect(tool.name).toBe("complete_tool"); + expect(tool.description).toBe("A tool with all fields"); + expect(tool.inputSchema).toBeDefined(); + expect(tool.inputSchema.properties).toBeDefined(); + expect(tool.outputSchema).toBeDefined(); + + // Annotations with defaults + expect(tool.annotations?.title).toBe("Complete Tool"); + + // Other optional fields preserved + expect(tool.execution?.taskSupport).toBe("optional"); + expect(tool._meta?.version).toBe("1.0"); + }); + }); + + // ========================================================================= + // Edge Cases + // ========================================================================= + + describe("Edge Cases", () => { + test("handles tools with empty annotations object", async () => { + const tools = [ + { + name: "empty_annotations", + inputSchema: { type: "object" }, + annotations: {}, // Empty but present + }, + ]; + + const { session } = await setupConnectedClient({ + capabilities: { tools: true }, + tools: [{ name: "empty_annotations", description: "Empty" }], + }); + + storeToolsInSession(session.id, tools); + + const annotations = clientManager.getToolAnnotations( + session.id, + "empty_annotations", + ); + + // All defaults 
should be applied + expect(annotations).toBeDefined(); + expect(annotations?.readOnlyHint).toBe(false); + expect(annotations?.destructiveHint).toBe(true); + expect(annotations?.idempotentHint).toBe(false); + expect(annotations?.openWorldHint).toBe(true); + }); + + test("handles multiple tools with varying annotation coverage", async () => { + const tools = [ + { + name: "full", + inputSchema: { type: "object" }, + annotations: { + title: "Full", + readOnlyHint: true, + destructiveHint: false, + idempotentHint: true, + openWorldHint: false, + }, + }, + { + name: "partial", + inputSchema: { type: "object" }, + annotations: { title: "Partial" }, + }, + { + name: "none", + inputSchema: { type: "object" }, + }, + ]; + + const { session } = await setupConnectedClient({ + capabilities: { tools: true }, + tools: tools.map((t) => ({ name: t.name, description: t.name })), + }); + + storeToolsInSession(session.id, tools); + + const typedTools = clientManager.listToolsTyped(session.id); + + expect(typedTools).toHaveLength(3); + + // Full - all explicit + const full = typedTools.find((t) => t.name === "full"); + expect(full?.annotations?.readOnlyHint).toBe(true); + expect(full?.annotations?.destructiveHint).toBe(false); + + // Partial - mixed + const partial = typedTools.find((t) => t.name === "partial"); + expect(partial?.annotations?.title).toBe("Partial"); + expect(partial?.annotations?.readOnlyHint).toBe(false); // default + + // None - all defaults + const none = typedTools.find((t) => t.name === "none"); + expect(none?.annotations?.readOnlyHint).toBe(false); + expect(none?.annotations?.destructiveHint).toBe(true); + }); + }); +}); diff --git a/packages/mcp/test/tool-call.test.ts b/packages/mcp/test/tool-call.test.ts index f418883..6d05230 100644 --- a/packages/mcp/test/tool-call.test.ts +++ b/packages/mcp/test/tool-call.test.ts @@ -1,10 +1,10 @@ import { beforeEach, describe, expect, test } from "bun:test"; import { Client } from 
"@modelcontextprotocol/sdk/client/index.js"; import { - createPipeline, - createStateMachineMiddleware, - LATEST_PROTOCOL_VERSION, - SessionManager, + createPipeline, + createStateMachineMiddleware, + LATEST_PROTOCOL_VERSION, + SessionManager, } from "@say2/core"; import { McpClientManager } from "../src/client/manager"; import { McpClientRegistry } from "../src/client/registry"; @@ -13,241 +13,241 @@ import { createMockServerTransport } from "./fixtures/mock-server"; import { scenarioMockConfig } from "./fixtures/tool-scenarios"; describe("Tool Execution Integration", () => { - let sessionManager: SessionManager; - let pipeline: ReturnType; - let registry: McpClientRegistry; - let clientManager: McpClientManager; - let mockTransport: ReturnType; - let sessionId: string; - let client: Client; - - beforeEach(async () => { - sessionManager = new SessionManager(); - pipeline = createPipeline(); - - // Mock Protocol Detector - const mockDetector = { - isInitializeRequest: (msg: any) => - msg.method === "initialize" && "id" in msg, - isInitializeResponse: (msg: any) => - "result" in msg && "protocolVersion" in msg.result, - isInitializedNotification: (msg: any) => - msg.method === "notifications/initialized", - extractCapabilities: (msg: any) => msg.result?.capabilities, - extractServerInfo: (msg: any) => msg.result?.serverInfo, - }; - - pipeline.use( - (createStateMachineMiddleware as any)(sessionManager, mockDetector), - ); - - registry = new McpClientRegistry(); - clientManager = new McpClientManager(registry, sessionManager, pipeline); - - // Setup session - const session = sessionManager.create({ - name: "test-session", - transport: "stdio", - command: "node", - }); - sessionId = session.id; - - // Setup Transport and Client - mockTransport = createMockServerTransport(scenarioMockConfig); - client = new Client( - { name: "test-client", version: "1.0.0" }, - { capabilities: {} }, - ); - - const loggingTransport = new LoggingTransport( - mockTransport, - session, - 
pipeline, - ); - - // Initialize connection - await client.connect(loggingTransport); - registry.register(sessionId, client, loggingTransport); - - // Manually transition to ACTIVE - sessionManager.connect(sessionId); - sessionManager.initialize(sessionId); - sessionManager.activate(sessionId, {}, {}, LATEST_PROTOCOL_VERSION); - }); - - test("callTool() executes tool and returns result", async () => { - const result = await clientManager.callTool(sessionId, { - name: "echo", - arguments: { message: "hello" }, - }); - - expect(result).toBeDefined(); - expect(result.status).toBe("completed"); - expect(result.result).toBeDefined(); - if (result.result && result.result.content.length > 0) { - expect(result.result.content[0]?.type).toBe("text"); - } else { - throw new Error("Expected content result"); - } - }); - - test("callTool() handles image content", async () => { - const result = await clientManager.callTool(sessionId, { - name: "getImage", - }); - - expect(result.status).toBe("completed"); - const content = result.result?.content[0]; - expect(content?.type).toBe("image"); - if (content?.type === "image") { - expect(content.data).toBeDefined(); - expect(content.mimeType).toBe("image/png"); - } - }); - - test("callTool() handles unknown tool error (-32602)", async () => { - // Expect failure - // The client.callTool throws error if server returns error? - // Or does it return ToolOperation with status='error'? - // MCP SDK client throws. Manager should catch and update status to 'error'? - // Or manager propagates? - // Spec says "Status is updated to 'error'". - // Manager callTool returns Promise. - // So it should return the operation object with status='error'. 
- - const result = await clientManager.callTool(sessionId, { - name: "nonExistentTool", - }); - - expect(result.status).toBe("error"); - expect(result.error).toBeDefined(); - expect(result.error?.code).toBe(-32602); - }); - - test("callTool() tracks operation in store", async () => { - const result = await clientManager.callTool(sessionId, { - name: "echo", - arguments: { message: "test" }, - }); - - const validId = result.id; - const stored = clientManager.getToolOperation(validId); - expect(stored).toBeDefined(); - expect(stored?.id).toBe(validId); - expect(stored?.status).toBe("completed"); - }); - - test("getToolOperations() lists all session operations", async () => { - await clientManager.callTool(sessionId, { - name: "echo", - arguments: { message: "1" }, - }); - await clientManager.callTool(sessionId, { - name: "echo", - arguments: { message: "2" }, - }); - - const ops = clientManager.getToolOperations(sessionId); - expect(ops).toHaveLength(2); - }); - - test("callTool() validates request arguments", async () => { - // Valid request - const valid = await clientManager.callTool(sessionId, { - name: "echo", - arguments: { message: "ok" }, - }); - expect(valid.status).toBe("completed"); - - // Invalid request (missing required arg) - // Mock server 'echo' tool requires 'message'. - // If strictToolValidation is on, validation error might come from server? - // SDK might strictly validate if local definition used? No, validation happens on server. - // Server returns -32602 (Invalid Params). - - const invalid = await clientManager.callTool(sessionId, { - name: "echo", - arguments: {}, // missing message - }); - - expect(invalid.status).toBe("error"); - // error code for invalid params is -32602 usually - }); - - test("callTool() with isError:true maps to status:error", async () => { - // The failingTool is configured to return { isError: true, content: [...] 
} - const result = await clientManager.callTool(sessionId, { - name: "failingTool", - }); - - // Even though the tool "succeeded" at the protocol level, - // isError: true should map to status: "error" - expect(result.status).toBe("error"); - expect(result.result).toBeDefined(); - expect(result.result?.isError).toBe(true); - expect(result.result?.content).toHaveLength(1); - expect(result.result?.content[0]?.type).toBe("text"); - }); - - test("callTool() handles mixed content types", async () => { - // getMixed returns: [text, image, resource_link] - const result = await clientManager.callTool(sessionId, { - name: "getMixed", - }); - - expect(result.status).toBe("completed"); - expect(result.result?.content).toHaveLength(3); - - // Verify each content type - const content = result.result!.content; - expect(content[0]?.type).toBe("text"); - expect(content[1]?.type).toBe("image"); - expect(content[2]?.type).toBe("resource_link"); - - // Verify image has required fields - if (content[1]?.type === "image") { - expect(content[1].data).toBeDefined(); - expect(content[1].mimeType).toBe("image/png"); - } - - // Verify resource_link has required fields - if (content[2]?.type === "resource_link") { - expect(content[2].uri).toBe("file:///path/to/resource.txt"); - expect(content[2].name).toBe("Resource File"); - } - }); - - test("ToolOperation has correct timestamps (startedAt, completedAt)", async () => { - const beforeCall = new Date(); - - const result = await clientManager.callTool(sessionId, { - name: "echo", - arguments: { message: "timestamp test" }, - }); - - const afterCall = new Date(); - - // Verify startedAt is set and within bounds - expect(result.startedAt).toBeDefined(); - expect(result.startedAt!.getTime()).toBeGreaterThanOrEqual( - beforeCall.getTime(), - ); - expect(result.startedAt!.getTime()).toBeLessThanOrEqual( - afterCall.getTime(), - ); - - // Verify completedAt is set and after startedAt - expect(result.completedAt).toBeDefined(); - 
expect(result.completedAt!.getTime()).toBeGreaterThanOrEqual( - result.startedAt!.getTime(), - ); - expect(result.completedAt!.getTime()).toBeLessThanOrEqual( - afterCall.getTime(), - ); - - // Also verify via getToolOperation - const stored = clientManager.getToolOperation(result.id); - expect(stored?.startedAt).toEqual(result.startedAt); - expect(stored?.completedAt).toEqual(result.completedAt); - }); + let sessionManager: SessionManager; + let pipeline: ReturnType; + let registry: McpClientRegistry; + let clientManager: McpClientManager; + let mockTransport: ReturnType; + let sessionId: string; + let client: Client; + + beforeEach(async () => { + sessionManager = new SessionManager(); + pipeline = createPipeline(); + + // Mock Protocol Detector + const mockDetector = { + isInitializeRequest: (msg: any) => + msg.method === "initialize" && "id" in msg, + isInitializeResponse: (msg: any) => + "result" in msg && "protocolVersion" in msg.result, + isInitializedNotification: (msg: any) => + msg.method === "notifications/initialized", + extractCapabilities: (msg: any) => msg.result?.capabilities, + extractServerInfo: (msg: any) => msg.result?.serverInfo, + }; + + pipeline.use( + (createStateMachineMiddleware as any)(sessionManager, mockDetector), + ); + + registry = new McpClientRegistry(); + clientManager = new McpClientManager(registry, sessionManager, pipeline); + + // Setup session + const session = sessionManager.create({ + name: "test-session", + transport: "stdio", + command: "node", + }); + sessionId = session.id; + + // Setup Transport and Client + mockTransport = createMockServerTransport(scenarioMockConfig); + client = new Client( + { name: "test-client", version: "1.0.0" }, + { capabilities: {} }, + ); + + const loggingTransport = new LoggingTransport( + mockTransport, + session, + pipeline, + ); + + // Initialize connection + await client.connect(loggingTransport); + registry.register(sessionId, client, loggingTransport); + + // Manually transition to 
ACTIVE + sessionManager.connect(sessionId); + sessionManager.initialize(sessionId); + sessionManager.activate(sessionId, {}, {}, LATEST_PROTOCOL_VERSION); + }); + + test("callTool() executes tool and returns result", async () => { + const result = await clientManager.callTool(sessionId, { + name: "echo", + arguments: { message: "hello" }, + }); + + expect(result).toBeDefined(); + expect(result.status).toBe("completed"); + expect(result.result).toBeDefined(); + if (result.result && result.result.content.length > 0) { + expect(result.result.content[0]?.type).toBe("text"); + } else { + throw new Error("Expected content result"); + } + }); + + test("callTool() handles image content", async () => { + const result = await clientManager.callTool(sessionId, { + name: "getImage", + }); + + expect(result.status).toBe("completed"); + const content = result.result?.content[0]; + expect(content?.type).toBe("image"); + if (content?.type === "image") { + expect(content.data).toBeDefined(); + expect(content.mimeType).toBe("image/png"); + } + }); + + test("callTool() handles unknown tool error (-32602)", async () => { + // Expect failure + // The client.callTool throws error if server returns error? + // Or does it return ToolOperation with status='error'? + // MCP SDK client throws. Manager should catch and update status to 'error'? + // Or manager propagates? + // Spec says "Status is updated to 'error'". + // Manager callTool returns Promise. + // So it should return the operation object with status='error'. 
+ + const result = await clientManager.callTool(sessionId, { + name: "nonExistentTool", + }); + + expect(result.status).toBe("error"); + expect(result.error).toBeDefined(); + expect(result.error?.code).toBe(-32602); + }); + + test("callTool() tracks operation in store", async () => { + const result = await clientManager.callTool(sessionId, { + name: "echo", + arguments: { message: "test" }, + }); + + const validId = result.id; + const stored = clientManager.getToolOperation(validId); + expect(stored).toBeDefined(); + expect(stored?.id).toBe(validId); + expect(stored?.status).toBe("completed"); + }); + + test("getToolOperations() lists all session operations", async () => { + await clientManager.callTool(sessionId, { + name: "echo", + arguments: { message: "1" }, + }); + await clientManager.callTool(sessionId, { + name: "echo", + arguments: { message: "2" }, + }); + + const ops = clientManager.getToolOperations(sessionId); + expect(ops).toHaveLength(2); + }); + + test("callTool() validates request arguments", async () => { + // Valid request + const valid = await clientManager.callTool(sessionId, { + name: "echo", + arguments: { message: "ok" }, + }); + expect(valid.status).toBe("completed"); + + // Invalid request (missing required arg) + // Mock server 'echo' tool requires 'message'. + // If strictToolValidation is on, validation error might come from server? + // SDK might strictly validate if local definition used? No, validation happens on server. + // Server returns -32602 (Invalid Params). + + const invalid = await clientManager.callTool(sessionId, { + name: "echo", + arguments: {}, // missing message + }); + + expect(invalid.status).toBe("error"); + // error code for invalid params is -32602 usually + }); + + test("callTool() with isError:true maps to status:error", async () => { + // The failingTool is configured to return { isError: true, content: [...] 
} + const result = await clientManager.callTool(sessionId, { + name: "failingTool", + }); + + // Even though the tool "succeeded" at the protocol level, + // isError: true should map to status: "error" + expect(result.status).toBe("error"); + expect(result.result).toBeDefined(); + expect(result.result?.isError).toBe(true); + expect(result.result?.content).toHaveLength(1); + expect(result.result?.content[0]?.type).toBe("text"); + }); + + test("callTool() handles mixed content types", async () => { + // getMixed returns: [text, image, resource_link] + const result = await clientManager.callTool(sessionId, { + name: "getMixed", + }); + + expect(result.status).toBe("completed"); + expect(result.result?.content).toHaveLength(3); + + // Verify each content type + const content = result.result?.content; + expect(content[0]?.type).toBe("text"); + expect(content[1]?.type).toBe("image"); + expect(content[2]?.type).toBe("resource_link"); + + // Verify image has required fields + if (content[1]?.type === "image") { + expect(content[1].data).toBeDefined(); + expect(content[1].mimeType).toBe("image/png"); + } + + // Verify resource_link has required fields + if (content[2]?.type === "resource_link") { + expect(content[2].uri).toBe("file:///path/to/resource.txt"); + expect(content[2].name).toBe("Resource File"); + } + }); + + test("ToolOperation has correct timestamps (startedAt, completedAt)", async () => { + const beforeCall = new Date(); + + const result = await clientManager.callTool(sessionId, { + name: "echo", + arguments: { message: "timestamp test" }, + }); + + const afterCall = new Date(); + + // Verify startedAt is set and within bounds + expect(result.startedAt).toBeDefined(); + expect(result.startedAt?.getTime()).toBeGreaterThanOrEqual( + beforeCall.getTime(), + ); + expect(result.startedAt?.getTime()).toBeLessThanOrEqual( + afterCall.getTime(), + ); + + // Verify completedAt is set and after startedAt + expect(result.completedAt).toBeDefined(); + 
expect(result.completedAt?.getTime()).toBeGreaterThanOrEqual( + result.startedAt?.getTime(), + ); + expect(result.completedAt?.getTime()).toBeLessThanOrEqual( + afterCall.getTime(), + ); + + // Also verify via getToolOperation + const stored = clientManager.getToolOperation(result.id); + expect(stored?.startedAt).toEqual(result.startedAt); + expect(stored?.completedAt).toEqual(result.completedAt); + }); }); diff --git a/stryker.config.mjs b/stryker.config.mjs index f0b1ab8..c252cd6 100644 --- a/stryker.config.mjs +++ b/stryker.config.mjs @@ -9,7 +9,8 @@ const config = { // Use command runner since Bun doesn't have a native Stryker plugin yet testRunner: "command", commandRunner: { - command: "bun test", + // Only run tests from packages/ to avoid v0-docs reference project + command: "bun test packages/", }, // TypeScript checker disabled for now (Bun compatibility) @@ -25,6 +26,15 @@ const config = { "!packages/*/src/**/index.ts", // Skip barrel exports ], + // Exclude v0-docs from being copied to sandbox (has uninstalled deps) + ignorePatterns: [ + "v0-docs/**", + ".git/**", + "node_modules/**", + "reports/**", + "coverage/**", + ], + // Reporter configuration reporters: ["html", "clear-text", "progress"], htmlReporter: {