Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 41 additions & 18 deletions tests/shared/test_streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import json
import multiprocessing
import socket
import threading
import time
import traceback
import warnings
from collections.abc import AsyncIterator, Generator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
Expand Down Expand Up @@ -1462,7 +1464,7 @@ async def sampling_callback(


# Context-aware server implementation for testing request context propagation
async def _handle_context_list_tools( # pragma: no cover
async def _handle_context_list_tools(
ctx: ServerRequestContext, params: PaginatedRequestParams | None
) -> ListToolsResult:
return ListToolsResult(
Expand All @@ -1487,15 +1489,13 @@ async def _handle_context_list_tools( # pragma: no cover
)


async def _handle_context_call_tool( # pragma: no cover
ctx: ServerRequestContext, params: CallToolRequestParams
) -> CallToolResult:
async def _handle_context_call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> CallToolResult:
name = params.name
args = params.arguments or {}

if name == "echo_headers":
headers_info: dict[str, Any] = {}
if ctx.request and isinstance(ctx.request, Request):
if ctx.request and isinstance(ctx.request, Request): # pragma: no branch
headers_info = dict(ctx.request.headers)
return CallToolResult(content=[TextContent(type="text", text=json.dumps(headers_info))])

Expand All @@ -1506,19 +1506,19 @@ async def _handle_context_call_tool( # pragma: no cover
"method": None,
"path": None,
}
if ctx.request and isinstance(ctx.request, Request):
if ctx.request and isinstance(ctx.request, Request): # pragma: no branch
request = ctx.request
context_data["headers"] = dict(request.headers)
context_data["method"] = request.method
context_data["path"] = request.url.path
return CallToolResult(content=[TextContent(type="text", text=json.dumps(context_data))])

return CallToolResult(content=[TextContent(type="text", text=f"Unknown tool: {name}")])
return CallToolResult(content=[TextContent(type="text", text=f"Unknown tool: {name}")]) # pragma: no cover


# Server runner for context-aware testing
def run_context_aware_server(port: int): # pragma: no cover
"""Run the context-aware test server."""
def _create_context_aware_server(port: int) -> uvicorn.Server:
"""Create the context-aware test server app and uvicorn.Server."""
server = Server(
"ContextAwareServer",
on_list_tools=_handle_context_list_tools,
Expand Down Expand Up @@ -1547,26 +1547,45 @@ def run_context_aware_server(port: int): # pragma: no cover
log_level="error",
)
)
server_instance.run()
return server_instance


@pytest.fixture
def context_aware_server(basic_server_port: int) -> Generator[None, None, None]:
"""Start the context-aware server in a separate process."""
proc = multiprocessing.Process(target=run_context_aware_server, args=(basic_server_port,), daemon=True)
proc.start()
"""Start the context-aware server on a background thread (in-process for coverage).

Unlike multiprocessing, threads share the host process's warning filters.
Uvicorn and the Windows ProactorEventLoop emit DeprecationWarning /
ResourceWarning during startup and teardown that pytest's
``filterwarnings = ["error"]`` would otherwise promote to hard failures.
We therefore run the server with all warnings suppressed (mirroring
the implicit isolation that multiprocessing provided).
"""
server_instance = _create_context_aware_server(basic_server_port)

def _run() -> None:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
server_instance.run()

thread = threading.Thread(target=_run, daemon=True)
thread.start()

# Wait for server to be running
wait_for_server(basic_server_port)

yield

proc.kill()
proc.join(timeout=2)
if proc.is_alive(): # pragma: no cover
print("Context-aware server process failed to terminate")
server_instance.should_exit = True
thread.join(timeout=5)


# Marker to suppress Windows ProactorEventLoop teardown warnings on threaded servers.
# When uvicorn runs in a thread (instead of a subprocess), transport finalizers fire
# during GC in the main process and trigger PytestUnraisableExceptionWarning.
_suppress_transport_teardown = pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning")


@_suppress_transport_teardown
@pytest.mark.anyio
async def test_streamablehttp_request_context_propagation(context_aware_server: None, basic_server_url: str) -> None:
"""Test that request context is properly propagated through StreamableHTTP."""
Expand Down Expand Up @@ -1600,6 +1619,7 @@ async def test_streamablehttp_request_context_propagation(context_aware_server:
assert headers_data.get("x-trace-id") == "trace-123"


@_suppress_transport_teardown
@pytest.mark.anyio
async def test_streamablehttp_request_context_isolation(context_aware_server: None, basic_server_url: str) -> None:
"""Test that request contexts are isolated between StreamableHTTP clients."""
Expand Down Expand Up @@ -1638,6 +1658,7 @@ async def test_streamablehttp_request_context_isolation(context_aware_server: No
assert ctx["headers"].get("authorization") == f"Bearer token-{i}"


@_suppress_transport_teardown
@pytest.mark.anyio
async def test_client_includes_protocol_version_header_after_init(context_aware_server: None, basic_server_url: str):
"""Test that client includes mcp-protocol-version header after initialization."""
Expand Down Expand Up @@ -2251,6 +2272,7 @@ async def test_streamable_http_client_does_not_mutate_provided_client(
assert custom_client.headers.get("Authorization") == "Bearer test-token"


@_suppress_transport_teardown
@pytest.mark.anyio
async def test_streamable_http_client_mcp_headers_override_defaults(
context_aware_server: None, basic_server_url: str
Expand Down Expand Up @@ -2282,6 +2304,7 @@ async def test_streamable_http_client_mcp_headers_override_defaults(
assert headers_data["content-type"] == "application/json"


@_suppress_transport_teardown
@pytest.mark.anyio
async def test_streamable_http_client_preserves_custom_with_mcp_headers(
context_aware_server: None, basic_server_url: str
Expand Down
Loading