Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ex_app/lib/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from ex_app.lib.agent import react
from ex_app.lib.logger import log
from ex_app.lib.mcp_server import UserAuthMiddleware, ToolListMiddleware
from ex_app.lib.mcp_server import UserAuthMiddleware, ToolListMiddleware, MCPAuthHeaderMiddleware
from ex_app.lib.provider import provider
from ex_app.lib.tools import get_categories

Expand Down Expand Up @@ -72,6 +72,7 @@ async def exapp_lifespan(app: FastAPI):

APP = FastAPI(lifespan=lifespan)
APP.add_middleware(AppAPIAuthMiddleware) # set global AppAPI authentication middleware
APP.add_middleware(MCPAuthHeaderMiddleware) # captures Authorization into ContextVar for FastMCP tasks
categories=get_categories()

SETTINGS = SettingsForm(
Expand Down
63 changes: 59 additions & 4 deletions ex_app/lib/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import time
import asyncio
import inspect
import logging
from contextvars import ContextVar
from functools import wraps

from fastmcp.server.dependencies import get_context
Expand All @@ -13,6 +15,42 @@
from ex_app.lib.tools import get_tools
import requests

logger = logging.getLogger(__name__)

# ContextVar to propagate the Authorization header into FastMCP background tasks.
# asyncio.create_task() copies the current context snapshot at task-creation time,
# so a ContextVar set before the task is spawned is visible inside it — even after
# Starlette's request context has been torn down.
_mcp_auth_header: ContextVar[str | None] = ContextVar("_mcp_auth_header", default=None)


class MCPAuthHeaderMiddleware:
"""Pure-ASGI middleware that captures the Authorization header from every
HTTP request and stores it in a ContextVar.

Must be added to the outer FastAPI app (APP) so it runs before FastMCP takes
ownership of the request. Because asyncio copies the current context when
spawning tasks, the ContextVar value is available inside FastMCP's background-
task processing even after the original Starlette request context is gone.
"""

def __init__(self, app):
self.app = app

async def __call__(self, scope, receive, send):
if scope["type"] == "http":
headers = {k: v for k, v in scope.get("headers", [])}
raw_auth = headers.get(b"authorization", b"")
auth_value = raw_auth.decode("latin-1") if raw_auth else None
token = _mcp_auth_header.set(auth_value)
try:
await self.app(scope, receive, send)
finally:
_mcp_auth_header.reset(token)
else:
await self.app(scope, receive, send)


def get_user(authorization_header: str, nc: AsyncNextcloudApp) -> str:
response = requests.get(
f"{nc.app_cfg.endpoint}/ocs/v2.php/cloud/user",
Expand All @@ -29,10 +67,23 @@ def get_user(authorization_header: str, nc: AsyncNextcloudApp) -> str:

class UserAuthMiddleware(Middleware):
async def on_message(self, context: MiddlewareContext, call_next):
# Middleware stores user info in context state
authorization_header = context.fastmcp_context.request_context.request.headers.get("Authorization")
# 1. Primary path: read from ContextVar — always works, including inside
# FastMCP background tasks where the Starlette request context is gone.
authorization_header = _mcp_auth_header.get()

# 2. Fallback: try the live request context (works for non-background paths
# and keeps compatibility if the transport behaviour changes).
if authorization_header is None:
try:
authorization_header = (
context.fastmcp_context.request_context.request.headers.get("Authorization")
)
except Exception:
pass
Comment on lines +81 to +82

if not authorization_header:
raise Exception("Authorization header is missing/invalid")

nc = AsyncNextcloudApp()
user = get_user(authorization_header, nc)
Comment on lines +84 to 88
await nc.set_user(user)
Expand All @@ -53,8 +104,12 @@ async def on_message(
call_next: CallNext[mt.ListToolsRequest, list[Tool]],
) -> list[Tool]:
global LAST_MCP_TOOL_UPDATE
if LAST_MCP_TOOL_UPDATE + 60 < time.time():
safe, dangerous = await get_tools(context.fastmcp_context.get_state("nextcloud"))
nc = context.fastmcp_context.get_state("nextcloud")
# Guard: only refresh the tool list when auth succeeded (nc is set) and
# the cache has expired. Previously this would crash on every message
# type (including `initialize`) because nc was None after auth failures.
if nc is not None and LAST_MCP_TOOL_UPDATE + 60 < time.time():
safe, dangerous = await get_tools(nc)
tools = await self.mcp.get_tools()
if LAST_MCP_TOOL_UPDATE + 60 < time.time():
for tool in tools.keys():
Expand Down
Loading