diff --git a/ex_app/lib/main.py b/ex_app/lib/main.py index 63d9661..3ac0462 100644 --- a/ex_app/lib/main.py +++ b/ex_app/lib/main.py @@ -23,7 +23,7 @@ from ex_app.lib.agent import react from ex_app.lib.logger import log -from ex_app.lib.mcp_server import UserAuthMiddleware, ToolListMiddleware +from ex_app.lib.mcp_server import UserAuthMiddleware, ToolListMiddleware, MCPAuthHeaderMiddleware from ex_app.lib.provider import provider from ex_app.lib.tools import get_categories @@ -72,6 +72,7 @@ async def exapp_lifespan(app: FastAPI): APP = FastAPI(lifespan=lifespan) APP.add_middleware(AppAPIAuthMiddleware) # set global AppAPI authentication middleware +APP.add_middleware(MCPAuthHeaderMiddleware) # captures Authorization into ContextVar for FastMCP tasks categories=get_categories() SETTINGS = SettingsForm( diff --git a/ex_app/lib/mcp_server.py b/ex_app/lib/mcp_server.py index 4ccc1cc..153e071 100644 --- a/ex_app/lib/mcp_server.py +++ b/ex_app/lib/mcp_server.py @@ -3,6 +3,8 @@ import time import asyncio import inspect +import logging +from contextvars import ContextVar from functools import wraps from fastmcp.server.dependencies import get_context @@ -13,6 +15,42 @@ from ex_app.lib.tools import get_tools import requests +logger = logging.getLogger(__name__) + +# ContextVar to propagate the Authorization header into FastMCP background tasks. +# asyncio.create_task() copies the current context snapshot at task-creation time, +# so a ContextVar set before the task is spawned is visible inside it — even after +# Starlette's request context has been torn down. +_mcp_auth_header: ContextVar[str | None] = ContextVar("_mcp_auth_header", default=None) + + +class MCPAuthHeaderMiddleware: + """Pure-ASGI middleware that captures the Authorization header from every + HTTP request and stores it in a ContextVar. + + Must be added to the outer FastAPI app (APP) so it runs before FastMCP takes + ownership of the request. Because asyncio copies the current context when + spawning tasks, the ContextVar value is available inside FastMCP's background- + task processing even after the original Starlette request context is gone. + """ + + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + if scope["type"] == "http": + headers = {k: v for k, v in scope.get("headers", [])} + raw_auth = headers.get(b"authorization", b"") + auth_value = raw_auth.decode("latin-1") if raw_auth else None + token = _mcp_auth_header.set(auth_value) + try: + await self.app(scope, receive, send) + finally: + _mcp_auth_header.reset(token) + else: + await self.app(scope, receive, send) + + def get_user(authorization_header: str, nc: AsyncNextcloudApp) -> str: response = requests.get( f"{nc.app_cfg.endpoint}/ocs/v2.php/cloud/user", @@ -29,10 +67,23 @@ def get_user(authorization_header: str, nc: AsyncNextcloudApp) -> str: class UserAuthMiddleware(Middleware): async def on_message(self, context: MiddlewareContext, call_next): - # Middleware stores user info in context state - authorization_header = context.fastmcp_context.request_context.request.headers.get("Authorization") + # 1. Primary path: read from ContextVar — always works, including inside + # FastMCP background tasks where the Starlette request context is gone. + authorization_header = _mcp_auth_header.get() + + # 2. Fallback: try the live request context (works for non-background paths + # and keeps compatibility if the transport behaviour changes). if authorization_header is None: + try: + authorization_header = ( + context.fastmcp_context.request_context.request.headers.get("Authorization") + ) + except Exception: + pass + + if not authorization_header: raise Exception("Authorization header is missing/invalid") + nc = AsyncNextcloudApp() user = get_user(authorization_header, nc) await nc.set_user(user) @@ -53,8 +104,12 @@ async def on_message( call_next: CallNext[mt.ListToolsRequest, list[Tool]], ) -> list[Tool]: global LAST_MCP_TOOL_UPDATE - if LAST_MCP_TOOL_UPDATE + 60 < time.time(): - safe, dangerous = await get_tools(context.fastmcp_context.get_state("nextcloud")) + nc = context.fastmcp_context.get_state("nextcloud") + # Guard: only refresh the tool list when auth succeeded (nc is set) and + # the cache has expired. Previously this would crash on every message + # type (including `initialize`) because nc was None after auth failures. + if nc is not None and LAST_MCP_TOOL_UPDATE + 60 < time.time(): + safe, dangerous = await get_tools(nc) tools = await self.mcp.get_tools() if LAST_MCP_TOOL_UPDATE + 60 < time.time(): for tool in tools.keys():