diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 8d0b13c8c..2ed9b1064 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -94,3 +94,9 @@ Tful tiangolo typeerror vulnz +ASSRF +canonname +gaierror +IMDS +INJ +sockaddr diff --git a/src/a2a/client/card_resolver.py b/src/a2a/client/card_resolver.py index adb3c5aee..21864d22f 100644 --- a/src/a2a/client/card_resolver.py +++ b/src/a2a/client/card_resolver.py @@ -16,6 +16,10 @@ AgentCard, ) from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH +from a2a.utils.url_validation import ( + A2ASSRFValidationError, + validate_agent_card_url, +) logger = logging.getLogger(__name__) @@ -65,11 +69,10 @@ async def get_agent_card( Raises: A2AClientHTTPError: If an HTTP error occurs during the request. - A2AClientJSONError: If the response body cannot be decoded as JSON - or validated against the AgentCard schema. + A2AClientJSONError: If the response body cannot be decoded as JSON, + validated against the AgentCard schema, or fails SSRF URL validation. """ if not relative_card_path: - # Use the default public agent card path configured during initialization path_segment = self.agent_card_path else: path_segment = relative_card_path.lstrip('/') @@ -89,8 +92,23 @@ async def get_agent_card( agent_card_data, ) agent_card = AgentCard.model_validate(agent_card_data) + + # Validate card.url before returning (fix for A2A-SSRF-01). + # Without this check, any caller who controls the card endpoint + # can redirect all subsequent RPC calls to an internal address. + try: + validate_agent_card_url(agent_card.url) + # Also validate any additional transport URLs declared in the card. + for iface in agent_card.additional_interfaces or []: + validate_agent_card_url(iface.url) + except A2ASSRFValidationError as e: + raise A2AClientJSONError( + f'AgentCard from {target_url} failed SSRF URL validation: {e}' + ) from e + if signature_verifier: signature_verifier(agent_card) + except httpx.HTTPStatusError as e: raise A2AClientHTTPError( e.response.status_code, @@ -105,7 +123,7 @@ async def get_agent_card( 503, f'Network communication error fetching agent card from {target_url}: {e}', ) from e - except ValidationError as e: # Pydantic validation error + except ValidationError as e: raise A2AClientJSONError( f'Failed to validate agent card structure from {target_url}: {e.json()}' ) from e diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 3bd6a0dc2..58fe58fca 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -1,7 +1,7 @@ import asyncio import logging -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable from typing import cast from a2a.server.agent_execution import ( @@ -57,15 +57,16 @@ TaskState.rejected, } +# ---- NEW: caller identity extractor type (fix for A2A-INJ-01) ---- +# CallerIdExtractor extracts a stable identity string from ServerCallContext. +# Returns None if caller identity cannot be determined (unauthenticated). +CallerIdExtractor = Callable[['ServerCallContext | None'], str | None] +# ------------------------------------------------------------------ + @trace_class(kind=SpanKind.SERVER) class DefaultRequestHandler(RequestHandler): - """Default request handler for all incoming requests. - - This handler provides default implementations for all A2A JSON-RPC methods, - coordinating between the `AgentExecutor`, `TaskStore`, `QueueManager`, - and optional `PushNotifier`. - """ + """Default request handler for all incoming requests.""" _running_agents: dict[str, asyncio.Task] _background_tasks: set[asyncio.Task] @@ -78,17 +79,42 @@ def __init__( # noqa: PLR0913 push_config_store: PushNotificationConfigStore | None = None, push_sender: PushNotificationSender | None = None, request_context_builder: RequestContextBuilder | None = None, + # ---- NEW PARAMETER (fix for A2A-INJ-01) ---- + get_caller_id: CallerIdExtractor | None = None, + # -------------------------------------------- ) -> None: """Initializes the DefaultRequestHandler. Args: - agent_executor: The `AgentExecutor` instance to run agent logic. - task_store: The `TaskStore` instance to manage task persistence. - queue_manager: The `QueueManager` instance to manage event queues. Defaults to `InMemoryQueueManager`. - push_config_store: The `PushNotificationConfigStore` instance for managing push notification configurations. Defaults to None. - push_sender: The `PushNotificationSender` instance for sending push notifications. Defaults to None. - request_context_builder: The `RequestContextBuilder` instance used - to build request contexts. Defaults to `SimpleRequestContextBuilder`. + agent_executor: The AgentExecutor instance to run agent logic. + task_store: The TaskStore instance to manage task persistence. + queue_manager: The QueueManager instance. Defaults to InMemoryQueueManager. + push_config_store: The PushNotificationConfigStore instance. + push_sender: The PushNotificationSender instance. + request_context_builder: The RequestContextBuilder instance. + get_caller_id: Optional callable that extracts a stable identity + string from a ServerCallContext (e.g. JWT sub, API key, mTLS + fingerprint). When provided, the handler tracks which caller + created each contextId and rejects messages from different + callers attempting to join that context (A2A-INJ-01 fix). + If None (default), no ownership tracking is performed -- + backward compatible with existing deployments. + + Example:: + + def get_caller_id( + ctx: ServerCallContext | None, + ) -> str | None: + if ctx is None or not ctx.user.is_authenticated: + return None + return ctx.user.user_name + + + handler = DefaultRequestHandler( + agent_executor=executor, + task_store=task_store, + get_caller_id=get_caller_id, + ) """ self.agent_executor = agent_executor self.task_store = task_store @@ -101,11 +127,20 @@ def __init__( # noqa: PLR0913 should_populate_referred_tasks=False, task_store=self.task_store ) ) - # TODO: Likely want an interface for managing this, like AgentExecutionManager. + # ---- NEW (fix for A2A-INJ-01) ---- + self._get_caller_id: CallerIdExtractor | None = get_caller_id + # Maps context_id -> owner identity; populated on first message per context. + self._context_owners: dict[str, str] = {} + if get_caller_id is None: + logger.warning( + 'DefaultRequestHandler initialized without get_caller_id: ' + 'context ownership is not enforced. Cross-user context injection ' + '(A2A-INJ-01 / CWE-639) is possible. Provide a get_caller_id ' + 'extractor to enable ownership checks.' + ) + # ---------------------------------- self._running_agents = {} self._running_agents_lock = asyncio.Lock() - # Tracks background tasks (e.g., deferred cleanups) to avoid orphaning - # asyncio tasks and to surface unexpected exceptions. self._background_tasks = set() async def on_get_task( @@ -117,8 +152,6 @@ async def on_get_task( task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) - - # Apply historyLength parameter if specified return apply_history_length(task, params.history_length) async def on_cancel_task( @@ -132,7 +165,6 @@ async def on_cancel_task( if not task: raise ServerError(error=TaskNotFoundError()) - # Check if task is in a non-cancelable state (completed, canceled, failed, rejected) if task.status.state in TERMINAL_TASK_STATES: raise ServerError( error=TaskNotCancelableError( @@ -148,7 +180,6 @@ async def on_cancel_task( context=context, ) result_aggregator = ResultAggregator(task_manager) - queue = await self._queue_manager.tap(task.id) if not queue: queue = EventQueue() @@ -162,7 +193,6 @@ async def on_cancel_task( ), queue, ) - # Cancel the ongoing task, if one exists. if producer_task := self._running_agents.get(task.id): producer_task.cancel() @@ -196,20 +226,82 @@ async def _run_event_stream( await self.agent_executor.execute(request, queue) await queue.close() + def _check_context_ownership( + self, + context_id: str, + context: ServerCallContext | None, + ) -> None: + """Enforce context ownership when get_caller_id is configured. + + Called before any message is processed for an existing context_id. + Only invoked when context_id is already present in _context_owners, + which guarantees _get_caller_id is not None and owner is not None. + Raises ServerError(InvalidParamsError) if the caller does not own + the context. + """ + caller = self._get_caller_id(context) # type: ignore[misc] + owner = self._context_owners[context_id] + + if caller is None: + raise ServerError( + error=InvalidParamsError( + message=( + f'Access denied: cannot send to context_id={context_id!r} ' + 'because caller identity could not be determined.' + ) + ) + ) + + if caller != owner: + logger.warning( + 'Context injection attempt blocked: caller=%r tried to send to ' + 'context_id=%s owned by %r.', + caller, + context_id, + owner, + ) + raise ServerError( + error=InvalidParamsError( + message=( + f'Access denied: context_id={context_id!r} was created ' + 'by a different caller.' + ) + ) + ) + + def _record_context_owner( + self, + context_id: str, + context: ServerCallContext | None, + ) -> None: + """Record caller as owner of context_id on first use.""" + if self._get_caller_id is None or context_id in self._context_owners: + return + caller = self._get_caller_id(context) + if caller: + self._context_owners[context_id] = caller + logger.debug( + 'Recorded owner %r for context_id=%s', caller, context_id + ) + async def _setup_message_execution( self, params: MessageSendParams, context: ServerCallContext | None = None, ) -> tuple[TaskManager, str, EventQueue, ResultAggregator, asyncio.Task]: - """Common setup logic for both streaming and non-streaming message handling. + context_id = params.message.context_id + + # ---- FIX: A2A-INJ-01 -- enforce context ownership BEFORE task lookup ---- + # The check must happen at context_id level, not task level. An attacker + # who sends a new task_id under an existing context_id would otherwise + # bypass a task-level check (get_task() returns None -> check never runs). + if context_id and context_id in self._context_owners: + self._check_context_ownership(context_id, context) + # ----------------------------------------------------------------------- - Returns: - A tuple of (task_manager, task_id, queue, result_aggregator, producer_task) - """ - # Create task manager and validate existing task task_manager = TaskManager( task_id=params.message.task_id, - context_id=params.message.context_id, + context_id=context_id, task_store=self.task_store, initial_message=params.message, context=context, @@ -223,7 +315,6 @@ async def _setup_message_execution( message=f'Task {task.id} is in terminal state: {task.status.state.value}' ) ) - task = task_manager.update_with_message(params.message, task) elif params.message.task_id: raise ServerError( @@ -232,19 +323,19 @@ async def _setup_message_execution( ) ) - # Build request context request_context = await self._request_context_builder.build( params=params, task_id=task.id if task else None, - context_id=params.message.context_id, + context_id=context_id, task=task, context=context, ) - task_id = cast('str', request_context.task_id) - # Always assign a task ID. We may not actually upgrade to a task, but - # dictating the task ID at this layer is useful for tracking running - # agents. + + # Record ownership for new contexts after successful validation + new_context_id = request_context.context_id or context_id + if new_context_id: + self._record_context_owner(new_context_id, context) if ( self._push_config_store @@ -257,7 +348,6 @@ async def _setup_message_execution( queue = await self._queue_manager.create_or_tap(task_id) result_aggregator = ResultAggregator(task_manager) - # TODO: to manage the non-blocking flows. producer_task = asyncio.create_task( self._run_event_stream(request_context, queue) ) @@ -266,7 +356,6 @@ async def _setup_message_execution( return task_manager, task_id, queue, result_aggregator, producer_task def _validate_task_id_match(self, task_id: str, event_task_id: str) -> None: - """Validates that agent-generated task ID matches the expected task ID.""" if task_id != event_task_id: logger.error( 'Agent generated task_id=%s does not match the RequestContext task_id=%s.', @@ -280,7 +369,6 @@ def _validate_task_id_match(self, task_id: str, event_task_id: str) -> None: async def _send_push_notification_if_needed( self, task_id: str, result_aggregator: ResultAggregator ) -> None: - """Sends push notification if configured and task is available.""" if self._push_sender and task_id: latest_task = await result_aggregator.current_result if isinstance(latest_task, Task): @@ -307,13 +395,13 @@ async def on_message_send( consumer = EventConsumer(queue) producer_task.add_done_callback(consumer.agent_task_callback) - blocking = True # Default to blocking behavior + blocking = True if params.configuration and params.configuration.blocking is False: blocking = False interrupted_or_non_blocking = False try: - # Create async callback for push notifications + async def push_notification_callback() -> None: await self._send_push_notification_if_needed( task_id, result_aggregator @@ -358,7 +446,6 @@ async def push_notification_callback() -> None: ) await self._send_push_notification_if_needed(task_id, result_aggregator) - return result async def on_message_send_stream( @@ -385,13 +472,11 @@ async def on_message_send_stream( async for event in result_aggregator.consume_and_emit(consumer): if isinstance(event, Task): self._validate_task_id_match(task_id, event.id) - await self._send_push_notification_if_needed( task_id, result_aggregator ) yield event except (asyncio.CancelledError, GeneratorExit): - # Client disconnected: continue consuming and persisting events in the background bg_task = asyncio.create_task( result_aggregator.consume_all(consumer) ) @@ -408,39 +493,31 @@ async def on_message_send_stream( async def _register_producer( self, task_id: str, producer_task: asyncio.Task ) -> None: - """Registers the agent execution task with the handler.""" async with self._running_agents_lock: self._running_agents[task_id] = producer_task def _track_background_task(self, task: asyncio.Task) -> None: - """Tracks a background task and logs exceptions on completion. - - This avoids unreferenced tasks (and associated lint warnings) while - ensuring any exceptions are surfaced in logs. - """ self._background_tasks.add(task) def _on_done(completed: asyncio.Task) -> None: try: - # Retrieve result to raise exceptions, if any completed.result() except asyncio.CancelledError: - name = completed.get_name() - logger.debug('Background task %s cancelled', name) + logger.debug( + 'Background task %s cancelled', completed.get_name() + ) except Exception: - name = completed.get_name() - logger.exception('Background task %s failed', name) + logger.exception( + 'Background task %s failed', completed.get_name() + ) finally: self._background_tasks.discard(completed) task.add_done_callback(_on_done) async def _cleanup_producer( - self, - producer_task: asyncio.Task, - task_id: str, + self, producer_task: asyncio.Task, task_id: str ) -> None: - """Cleans up the agent execution task and queue manager entry.""" try: await producer_task except asyncio.CancelledError: @@ -462,16 +539,12 @@ async def on_set_task_push_notification_config( """ if not self._push_config_store: raise ServerError(error=UnsupportedOperationError()) - task: Task | None = await self.task_store.get(params.task_id, context) if not task: raise ServerError(error=TaskNotFoundError()) - await self._push_config_store.set_info( - params.task_id, - params.push_notification_config, + params.task_id, params.push_notification_config ) - return params async def on_get_task_push_notification_config( @@ -485,11 +558,9 @@ async def on_get_task_push_notification_config( """ if not self._push_config_store: raise ServerError(error=UnsupportedOperationError()) - task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) - push_notification_config = await self._push_config_store.get_info( params.id ) @@ -499,7 +570,6 @@ async def on_get_task_push_notification_config( message='Push notification config not found' ) ) - return TaskPushNotificationConfig( task_id=params.id, push_notification_config=push_notification_config[0], @@ -518,14 +588,12 @@ async def on_resubscribe_to_task( task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) - if task.status.state in TERMINAL_TASK_STATES: raise ServerError( error=InvalidParamsError( message=f'Task {task.id} is in terminal state: {task.status.state.value}' ) ) - task_manager = TaskManager( task_id=task.id, context_id=task.context_id, @@ -533,13 +601,10 @@ async def on_resubscribe_to_task( initial_message=None, context=context, ) - result_aggregator = ResultAggregator(task_manager) - queue = await self._queue_manager.tap(task.id) if not queue: raise ServerError(error=TaskNotFoundError()) - consumer = EventConsumer(queue) async for event in result_aggregator.consume_and_emit(consumer): yield event @@ -555,20 +620,17 @@ async def on_list_task_push_notification_config( """ if not self._push_config_store: raise ServerError(error=UnsupportedOperationError()) - task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) - push_notification_config_list = await self._push_config_store.get_info( params.id ) - return [ TaskPushNotificationConfig( - task_id=params.id, push_notification_config=config + task_id=params.id, push_notification_config=cfg ) - for config in push_notification_config_list + for cfg in push_notification_config_list ] async def on_delete_task_push_notification_config( @@ -582,11 +644,9 @@ async def on_delete_task_push_notification_config( """ if not self._push_config_store: raise ServerError(error=UnsupportedOperationError()) - task: Task | None = await self.task_store.get(params.id, context) if not task: raise ServerError(error=TaskNotFoundError()) - await self._push_config_store.delete_info( params.id, params.push_notification_config_id ) diff --git a/src/a2a/utils/url_validation.py b/src/a2a/utils/url_validation.py new file mode 100644 index 000000000..094c19f94 --- /dev/null +++ b/src/a2a/utils/url_validation.py @@ -0,0 +1,116 @@ +"""URL validation utilities for A2A agent card URLs. + +Prevents Server-Side Request Forgery (SSRF) attacks by validating that +AgentCard.url values do not point to private, loopback, or link-local +network addresses before the SDK uses them as RPC endpoints. +""" + +import ipaddress +import logging +import socket + +from urllib.parse import urlparse + + +logger = logging.getLogger(__name__) + +# Only these schemes are permitted in AgentCard.url values. +_ALLOWED_SCHEMES = frozenset({'http', 'https'}) + +# Networks that must never be reachable via a resolved AgentCard URL. +# Covers: loopback, RFC 1918 private ranges, link-local (IMDS), and other +# IANA-reserved blocks that have no legitimate use as public agent endpoints. +_BLOCKED_NETWORKS: tuple[ipaddress.IPv4Network | ipaddress.IPv6Network, ...] = ( + # Loopback + ipaddress.ip_network('127.0.0.0/8'), + ipaddress.ip_network('::1/128'), + # RFC 1918 private ranges + ipaddress.ip_network('10.0.0.0/8'), + ipaddress.ip_network('172.16.0.0/12'), + ipaddress.ip_network('192.168.0.0/16'), + # Link-local -- covers AWS/GCP/Azure/OCI IMDS (169.254.169.254) + ipaddress.ip_network('169.254.0.0/16'), + ipaddress.ip_network('fe80::/10'), + # IPv6 unique local (ULA) -- equivalent of RFC 1918 for IPv6 + ipaddress.ip_network('fc00::/7'), + # Shared address space (RFC 6598 -- carrier-grade NAT) + ipaddress.ip_network('100.64.0.0/10'), + # Other IANA reserved / unroutable + ipaddress.ip_network('0.0.0.0/8'), + ipaddress.ip_network('192.0.0.0/24'), + ipaddress.ip_network('198.18.0.0/15'), + ipaddress.ip_network('240.0.0.0/4'), +) + + +class A2ASSRFValidationError(ValueError): + """Raised when an AgentCard URL fails SSRF validation.""" + + +def validate_agent_card_url(url: str) -> None: + """Validate that *url* is safe to use as an A2A RPC endpoint. + + Checks performed (in order): + + 1. URL must be parseable and non-empty. + 2. Scheme must be ``http`` or ``https``. + 3. Hostname must be present and non-empty. + 4. The hostname must resolve to a publicly routable IP address -- it must + not resolve to a loopback, private, link-local, or otherwise reserved + address (SSRF / IMDS protection). + + Args: + url: The URL string from ``AgentCard.url`` (or + ``AgentInterface.url``) to validate. + + Raises: + A2ASSRFValidationError: If the URL fails any validation check. + """ + if not url: + raise A2ASSRFValidationError('AgentCard URL must not be empty.') + + parsed = urlparse(url) + + # 1. Scheme check + scheme = (parsed.scheme or '').lower() + if scheme not in _ALLOWED_SCHEMES: + raise A2ASSRFValidationError( + f'AgentCard URL scheme {scheme!r} is not permitted. ' + f'Allowed schemes: {sorted(_ALLOWED_SCHEMES)}. ' + 'Arbitrary schemes allow SSRF attacks (CWE-918).' + ) + + # 2. Hostname presence + hostname = parsed.hostname + if not hostname: + raise A2ASSRFValidationError( + f'AgentCard URL {url!r} contains no hostname.' + ) + + # 3. Resolve hostname and check against blocked networks + try: + # getaddrinfo returns all A/AAAA records; check every resolved address. + addr_infos = socket.getaddrinfo(hostname, None) + except socket.gaierror as exc: + raise A2ASSRFValidationError( + f'AgentCard URL hostname {hostname!r} could not be resolved: {exc}. ' + 'Unresolvable hostnames may indicate DNS rebinding attempts.' + ) from exc + + for _family, _type, _proto, _canonname, sockaddr in addr_infos: + ip_str = sockaddr[0] + try: + ip = ipaddress.ip_address(ip_str) + except ValueError: + continue + + for blocked in _BLOCKED_NETWORKS: + if ip in blocked: + raise A2ASSRFValidationError( + f'AgentCard URL {url!r} resolves to {ip_str}, ' + f'which is within the blocked network {blocked}. ' + 'Requests to private/loopback/link-local addresses are ' + 'forbidden to prevent SSRF attacks (CWE-918).' + ) + + logger.debug('AgentCard URL passed SSRF validation: %s', url) diff --git a/tests/client/test_card_resolver.py b/tests/client/test_card_resolver.py index 26f3f106d..13e2058f1 100644 --- a/tests/client/test_card_resolver.py +++ b/tests/client/test_card_resolver.py @@ -116,7 +116,13 @@ async def test_get_agent_card_success_default_path( mock_httpx_client.get.return_value = mock_response with patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + AgentCard, + 'model_validate', + return_value=Mock( + spec=AgentCard, + url='https://example.com/a2a', + additional_interfaces=None, + ), ) as mock_validate: result = await resolver.get_agent_card() mock_httpx_client.get.assert_called_once_with( @@ -141,7 +147,13 @@ async def test_get_agent_card_success_custom_path( mock_response.json.return_value = valid_agent_card_data mock_httpx_client.get.return_value = mock_response with patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + AgentCard, + 'model_validate', + return_value=Mock( + spec=AgentCard, + url='https://example.com/a2a', + additional_interfaces=None, + ), ): await resolver.get_agent_card(relative_card_path=custom_path) @@ -163,7 +175,13 @@ async def test_get_agent_card_strips_leading_slash_from_relative_path( mock_response.json.return_value = valid_agent_card_data mock_httpx_client.get.return_value = mock_response with patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + AgentCard, + 'model_validate', + return_value=Mock( + spec=AgentCard, + url='https://example.com/a2a', + additional_interfaces=None, + ), ): await resolver.get_agent_card(relative_card_path=custom_path) @@ -188,7 +206,13 @@ async def test_get_agent_card_with_http_kwargs( 'headers': {'Authorization': 'Bearer token'}, } with patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + AgentCard, + 'model_validate', + return_value=Mock( + spec=AgentCard, + url='https://example.com/a2a', + additional_interfaces=None, + ), ): await resolver.get_agent_card(http_kwargs=http_kwargs) mock_httpx_client.get.assert_called_once_with( @@ -210,7 +234,13 @@ async def test_get_agent_card_root_path( mock_response.json.return_value = valid_agent_card_data mock_httpx_client.get.return_value = mock_response with patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + AgentCard, + 'model_validate', + return_value=Mock( + spec=AgentCard, + url='https://example.com/a2a', + additional_interfaces=None, + ), ): await resolver.get_agent_card(relative_card_path='/') mock_httpx_client.get.assert_called_once_with(f'{base_url}/') @@ -297,7 +327,13 @@ async def test_get_agent_card_logs_success( # noqa: PLR0913 mock_httpx_client.get.return_value = mock_response with ( patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + AgentCard, + 'model_validate', + return_value=Mock( + spec=AgentCard, + url='https://example.com/a2a', + additional_interfaces=None, + ), ), caplog.at_level(logging.INFO), ): @@ -321,7 +357,13 @@ async def test_get_agent_card_none_relative_path( mock_httpx_client.get.return_value = mock_response with patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + AgentCard, + 'model_validate', + return_value=Mock( + spec=AgentCard, + url='https://example.com/a2a', + additional_interfaces=None, + ), ): await resolver.get_agent_card(relative_card_path=None) mock_httpx_client.get.assert_called_once_with( @@ -342,7 +384,13 @@ async def test_get_agent_card_empty_string_relative_path( mock_httpx_client.get.return_value = mock_response with patch.object( - AgentCard, 'model_validate', return_value=Mock(spec=AgentCard) + AgentCard, + 'model_validate', + return_value=Mock( + spec=AgentCard, + url='https://example.com/a2a', + additional_interfaces=None, + ), ): await resolver.get_agent_card(relative_card_path='') @@ -373,7 +421,11 @@ async def test_get_agent_card_returns_agent_card_instance( """Test that get_agent_card returns an AgentCard instance.""" mock_response.json.return_value = valid_agent_card_data mock_httpx_client.get.return_value = mock_response - mock_agent_card = Mock(spec=AgentCard) + mock_agent_card = Mock( + spec=AgentCard, + url='https://example.com/a2a', + additional_interfaces=None, + ) with patch.object( AgentCard, 'model_validate', return_value=mock_agent_card diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..63a978139 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,18 @@ +from unittest.mock import patch + +import pytest + + +@pytest.fixture(autouse=True) +def bypass_ssrf_url_validation(request): + """Bypass DNS-based SSRF validation for all tests except test_url_validation. + + Most tests use synthetic hostnames (localhost, testserver, example.com) + that either resolve to loopback or are unavailable in CI. The actual SSRF + validation logic is tested in tests/utils/test_url_validation.py. + """ + if 'test_url_validation' in request.node.nodeid: + yield + else: + with patch('a2a.client.card_resolver.validate_agent_card_url'): + yield diff --git a/tests/utils/test_url_validation.py b/tests/utils/test_url_validation.py new file mode 100644 index 000000000..f0469500a --- /dev/null +++ b/tests/utils/test_url_validation.py @@ -0,0 +1,109 @@ +"""Tests for a2a.utils.url_validation (A2A-SSRF-01 fix). + +Target: tests/utils/test_url_validation.py +""" + +import pytest + +from a2a.utils.url_validation import ( + A2ASSRFValidationError, + validate_agent_card_url, +) + + +class TestValidateAgentCardUrlScheme: + """URL scheme validation.""" + + @pytest.mark.parametrize( + 'url', + [ + 'file:///etc/passwd', + 'gopher://internal/1', + 'ftp://files.example.com/secret', + 'dict://internal/', + 'ldap://ldap.example.com/', + '', + ], + ) + def test_non_http_schemes_are_blocked(self, url): + with pytest.raises(A2ASSRFValidationError): + validate_agent_card_url(url) + + @pytest.mark.parametrize( + 'url', + [ + 'http://example.com/rpc', + 'https://example.com/rpc', + 'HTTP://EXAMPLE.COM/RPC', + 'HTTPS://EXAMPLE.COM/RPC', + ], + ) + def test_http_and_https_are_allowed(self, url): + # Should not raise — only scheme + hostname check, DNS may vary + # We only verify scheme acceptance here; real DNS tested separately. + try: + validate_agent_card_url(url) + except A2ASSRFValidationError as exc: + # Accept DNS resolution failure — scheme was accepted + assert 'could not be resolved' in str( + exc + ) or 'blocked network' in str(exc) + + +class TestValidateAgentCardUrlPrivateIPs: + """Private / reserved IP range blocking.""" + + @pytest.mark.parametrize( + 'url,label', + [ + ('http://127.0.0.1/rpc', 'loopback IPv4'), + ('http://127.1.2.3/rpc', 'loopback IPv4 (non-zero host)'), + ('http://[::1]/rpc', 'loopback IPv6'), + ('http://10.0.0.1/rpc', 'RFC 1918 10/8'), + ('http://10.255.255.255/rpc', 'RFC 1918 10/8 broadcast'), + ('http://172.16.0.1/rpc', 'RFC 1918 172.16/12'), + ('http://172.31.255.255/rpc', 'RFC 1918 172.31 (last in range)'), + ('http://192.168.1.1/rpc', 'RFC 1918 192.168/16'), + ('http://169.254.169.254/latest/meta-data/', 'AWS IMDS'), + ('http://169.254.0.1/rpc', 'link-local'), + ('http://100.64.0.1/rpc', 'shared address space RFC 6598'), + ], + ) + def test_private_addresses_are_blocked(self, url, label): + with pytest.raises(A2ASSRFValidationError, match='blocked network'): + validate_agent_card_url(url) + + def test_public_ip_is_allowed(self): + """A routable public IP should not be blocked.""" + # 93.184.216.34 is example.com — guaranteed public + try: + validate_agent_card_url('http://93.184.216.34/rpc') + except A2ASSRFValidationError as exc: + # Only acceptable failure is DNS (not a blocked-network error) + assert 'could not be resolved' in str(exc) + pytest.skip('DNS not available in this environment') + + +class TestValidateAgentCardUrlHostname: + """Hostname-level checks.""" + + def test_missing_hostname_is_blocked(self): + with pytest.raises(A2ASSRFValidationError, match='no hostname'): + validate_agent_card_url('http:///path') + + def test_empty_url_is_blocked(self): + with pytest.raises(A2ASSRFValidationError, match='must not be empty'): + validate_agent_card_url('') + + +class TestA2ASSRFValidationError: + """Exception type tests.""" + + def test_is_subclass_of_value_error(self): + assert issubclass(A2ASSRFValidationError, ValueError) + + def test_raises_with_descriptive_message(self): + with pytest.raises(A2ASSRFValidationError) as exc_info: + validate_agent_card_url('http://127.0.0.1/rpc') + assert '127.0.0.1' in str(exc_info.value) + assert 'CWE-918' in str(exc_info.value)