From b7b699af8c3a67f4eefb33f429ca3270d1f990bb Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Tue, 31 Mar 2026 14:05:21 +0000 Subject: [PATCH 1/3] refactor: remove ClientTaskManager and related consumers from client components --- src/a2a/client/__init__.py | 4 - src/a2a/client/base_client.py | 49 +---- src/a2a/client/client.py | 31 +-- src/a2a/client/client_factory.py | 19 +- src/a2a/client/client_task_manager.py | 167 --------------- tests/client/test_base_client.py | 27 +-- tests/client/test_base_client_interceptors.py | 1 - tests/client/test_client_factory.py | 9 +- tests/client/test_client_task_manager.py | 191 ------------------ .../test_default_push_notification_support.py | 9 +- .../cross_version/client_server/client_1_0.py | 8 +- .../test_client_server_integration.py | 13 +- .../integration/test_copying_observability.py | 5 +- tests/integration/test_end_to_end.py | 20 +- .../test_stream_generator_cleanup.py | 2 +- 15 files changed, 54 insertions(+), 501 deletions(-) delete mode 100644 src/a2a/client/client_task_manager.py delete mode 100644 tests/client/test_client_task_manager.py diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index 26e35a4cb..188ab4c80 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -11,8 +11,6 @@ Client, ClientCallContext, ClientConfig, - ClientEvent, - Consumer, ) from a2a.client.client_factory import ClientFactory, minimal_agent_card from a2a.client.errors import ( @@ -35,9 +33,7 @@ 'ClientCallContext', 'ClientCallInterceptor', 'ClientConfig', - 'ClientEvent', 'ClientFactory', - 'Consumer', 'CredentialService', 'InMemoryContextCredentialStore', 'create_text_message_object', diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index a825ef50c..7d5a941e8 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -5,10 +5,7 @@ Client, ClientCallContext, ClientConfig, - ClientEvent, - Consumer, ) -from a2a.client.client_task_manager import ClientTaskManager from a2a.client.interceptors import ( AfterArgs, BeforeArgs, @@ -42,10 +39,9 @@ def __init__( card: AgentCard, config: ClientConfig, transport: ClientTransport, - consumers: list[Consumer], interceptors: list[ClientCallInterceptor], ): - super().__init__(consumers, interceptors) + super().__init__(interceptors) self._card = card self._config = config self._transport = transport @@ -56,7 +52,7 @@ async def send_message( request: SendMessageRequest, *, context: ClientCallContext | None = None, - ) -> AsyncIterator[ClientEvent]: + ) -> AsyncIterator[StreamResponse]: """Sends a message to the agent. This method handles both streaming and non-streaming (polling) interactions @@ -84,19 +80,14 @@ async def send_message( # In non-streaming case we convert to a StreamResponse so that the # client always sees the same iterator. stream_response = StreamResponse() - client_event: ClientEvent if response.HasField('task'): stream_response.task.CopyFrom(response.task) - client_event = (stream_response, response.task) elif response.HasField('message'): stream_response.message.CopyFrom(response.message) - client_event = (stream_response, None) else: - # Response must have either task or message raise ValueError('Response has neither task nor message') - await self.consume(client_event, self._card) - yield client_event + yield stream_response return async for event in self._execute_stream_with_interceptors( @@ -130,8 +121,7 @@ async def _process_stream( self, stream: AsyncIterator[StreamResponse], before_args: BeforeArgs, - ) -> AsyncGenerator[ClientEvent]: - tracker = ClientTaskManager() + ) -> AsyncGenerator[StreamResponse, None]: async for stream_response in stream: after_args = AfterArgs( result=stream_response, @@ -140,12 +130,8 @@ async def _process_stream( context=before_args.context, ) await self._intercept_after(after_args) - intercepted_response = after_args.result - client_event = await self._format_stream_event( - intercepted_response, tracker - ) - yield client_event - if intercepted_response.HasField('message'): + yield after_args.result + if after_args.result.HasField('message'): return async def get_task( @@ -318,7 +304,7 @@ async def subscribe( request: SubscribeToTaskRequest, *, context: ClientCallContext | None = None, - ) -> AsyncIterator[ClientEvent]: + ) -> AsyncIterator[StreamResponse]: """Resubscribes to a task's event stream. This is only available if both the client and server support streaming. @@ -436,7 +422,7 @@ async def _execute_stream_with_interceptors( transport_call: Callable[ [Any, ClientCallContext | None], AsyncIterator[StreamResponse] ], - ) -> AsyncIterator[ClientEvent]: + ) -> AsyncIterator[StreamResponse]: before_args = BeforeArgs( input=input_data, @@ -455,8 +441,7 @@ async def _execute_stream_with_interceptors( ) await self._intercept_after(after_args, before_result['executed']) - tracker = ClientTaskManager() - yield await self._format_stream_event(after_args.result, tracker) + yield after_args.result return stream = transport_call(before_args.input, before_args.context) @@ -495,19 +480,3 @@ async def _intercept_after( await interceptor.after(args) if args.early_return: return - - async def _format_stream_event( - self, stream_response: StreamResponse, tracker: ClientTaskManager - ) -> ClientEvent: - client_event: ClientEvent - if stream_response.HasField('message'): - client_event = (stream_response, None) - await self.consume(client_event, self._card) - return client_event - - await tracker.process(stream_response) - updated_task = tracker.get_task_or_raise() - client_event = (stream_response, updated_task) - - await self.consume(client_event, self._card) - return client_event diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index 291b3864c..556be9030 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -2,7 +2,7 @@ import logging from abc import ABC, abstractmethod -from collections.abc import AsyncIterator, Callable, Coroutine, MutableMapping +from collections.abc import AsyncIterator, Callable, MutableMapping from types import TracebackType from typing import Any @@ -77,13 +77,6 @@ class ClientConfig: """Push notification configurations to use for every request.""" -ClientEvent = tuple[StreamResponse, Task | None] - -# Alias for an event consuming callback. It takes either a (task, update) pair -# or a message as well as the agent card for the agent this came from. -Consumer = Callable[[ClientEvent, AgentCard], Coroutine[None, Any, Any]] - - class ClientCallContext(BaseModel): """A context passed with each client call, allowing for call-specific. @@ -106,16 +99,13 @@ class Client(ABC): def __init__( self, - consumers: list[Consumer] | None = None, interceptors: list[ClientCallInterceptor] | None = None, ): - """Initializes the client with consumers and interceptors. + """Initializes the client with interceptors. Args: - consumers: A list of callables to process events from the agent. interceptors: A list of interceptors to process requests and responses. """ - self._consumers = consumers or [] self._interceptors = interceptors or [] async def __aenter__(self) -> Self: @@ -137,7 +127,7 @@ async def send_message( request: SendMessageRequest, *, context: ClientCallContext | None = None, - ) -> AsyncIterator[ClientEvent]: + ) -> AsyncIterator[StreamResponse]: """Sends a message to the server. This will automatically use the streaming or non-streaming approach @@ -218,7 +208,7 @@ async def subscribe( request: SubscribeToTaskRequest, *, context: ClientCallContext | None = None, - ) -> AsyncIterator[ClientEvent]: + ) -> AsyncIterator[StreamResponse]: """Resubscribes to a task's event stream.""" return yield @@ -233,23 +223,10 @@ async def get_extended_agent_card( ) -> AgentCard: """Retrieves the agent's card.""" - async def add_event_consumer(self, consumer: Consumer) -> None: - """Attaches additional consumers to the `Client`.""" - self._consumers.append(consumer) - async def add_interceptor(self, interceptor: ClientCallInterceptor) -> None: """Attaches additional interceptors to the `Client`.""" self._interceptors.append(interceptor) - async def consume( - self, - event: ClientEvent, - card: AgentCard, - ) -> None: - """Processes the event via all the registered `Consumer`s.""" - for c in self._consumers: - await c(event, card) - @abstractmethod async def close(self) -> None: """Closes the client and releases any underlying resources.""" diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index 4aa1f88c7..5f07e7713 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -11,7 +11,7 @@ from a2a.client.base_client import BaseClient from a2a.client.card_resolver import A2ACardResolver -from a2a.client.client import Client, ClientConfig, Consumer +from a2a.client.client import Client, ClientConfig from a2a.client.transports.base import ClientTransport from a2a.client.transports.jsonrpc import JsonRpcTransport from a2a.client.transports.rest import RestTransport @@ -77,17 +77,12 @@ class ClientFactory: def __init__( self, config: ClientConfig, - consumers: list[Consumer] | None = None, ): - if consumers is None: - consumers = [] - client = config.httpx_client or httpx.AsyncClient() client.headers.setdefault(VERSION_HEADER, PROTOCOL_VERSION_CURRENT) config.httpx_client = client self._config = config - self._consumers = consumers self._registry: dict[str, TransportProducer] = {} self._register_defaults(config.supported_protocol_bindings) @@ -263,7 +258,6 @@ async def connect( # noqa: PLR0913 cls, agent: str | AgentCard, client_config: ClientConfig | None = None, - consumers: list[Consumer] | None = None, interceptors: list[ClientCallInterceptor] | None = None, relative_card_path: str | None = None, resolver_http_kwargs: dict[str, Any] | None = None, @@ -286,7 +280,7 @@ async def connect( # noqa: PLR0913 Args: agent: The base URL of the agent, or the AgentCard to connect to. client_config: The ClientConfig to use when connecting to the agent. - consumers: A list of `Consumer` methods to pass responses to. + interceptors: A list of interceptors to use for each request. These are used for things like attaching credentials or http headers to all outbound requests. @@ -325,7 +319,7 @@ async def connect( # noqa: PLR0913 factory = cls(client_config) for label, generator in (extra_transports or {}).items(): factory.register(label, generator) - return factory.create(card, consumers, interceptors) + return factory.create(card, interceptors) def register(self, label: str, generator: TransportProducer) -> None: """Register a new transport producer for a given transport label.""" @@ -334,14 +328,12 @@ def register(self, label: str, generator: TransportProducer) -> None: def create( self, card: AgentCard, - consumers: list[Consumer] | None = None, interceptors: list[ClientCallInterceptor] | None = None, ) -> Client: """Create a new `Client` for the provided `AgentCard`. Args: card: An `AgentCard` defining the characteristics of the agent. - consumers: A list of `Consumer` methods to pass responses to. interceptors: A list of interceptors to use for each request. These are used for things like attaching credentials or http headers to all outbound requests. @@ -381,10 +373,6 @@ def create( if transport_protocol not in self._registry: raise ValueError(f'no client available for {transport_protocol}') - all_consumers = self._consumers.copy() - if consumers: - all_consumers.extend(consumers) - transport = self._registry[transport_protocol]( card, selected_interface.url, self._config ) @@ -398,7 +386,6 @@ def create( card, self._config, transport, - all_consumers, interceptors or [], ) diff --git a/src/a2a/client/client_task_manager.py b/src/a2a/client/client_task_manager.py deleted file mode 100644 index e5a3267f1..000000000 --- a/src/a2a/client/client_task_manager.py +++ /dev/null @@ -1,167 +0,0 @@ -import logging - -from a2a.client.errors import A2AClientError -from a2a.types.a2a_pb2 import ( - Message, - StreamResponse, - Task, - TaskState, - TaskStatus, -) -from a2a.utils import append_artifact_to_task - - -logger = logging.getLogger(__name__) - - -class ClientTaskManager: - """Helps manage a task's lifecycle during execution of a request. - - Responsible for retrieving, saving, and updating the `Task` object based on - events received from the agent. - """ - - def __init__( - self, - ) -> None: - """Initializes the `ClientTaskManager`.""" - self._current_task: Task | None = None - self._task_id: str | None = None - self._context_id: str | None = None - - def get_task(self) -> Task | None: - """Retrieves the current task object, either from memory. - - If `task_id` is set, it returns `_current_task` otherwise None. - - Returns: - The `Task` object if found, otherwise `None`. - """ - if not self._task_id: - logger.debug('task_id is not set, cannot get task.') - return None - - return self._current_task - - def get_task_or_raise(self) -> Task: - """Retrieves the current task object. - - Returns: - The `Task` object. - - Raises: - A2AClientError: If there is no current known Task. - """ - if not (task := self.get_task()): - # Note: The source of this error is either from bad client usage - # or from the server sending invalid updates. It indicates that this - # task manager has not consumed any information about a task, yet - # the caller is attempting to retrieve the current state of the task - # it expects to be present. - raise A2AClientError('no current Task') - return task - - async def process( - self, - event: StreamResponse, - ) -> Task | None: - """Processes a task-related event (Task, Status, Artifact) and saves the updated task state. - - Ensures task and context IDs match or are set from the event. - - Args: - event: The task-related event (`Task`, `TaskStatusUpdateEvent`, or `TaskArtifactUpdateEvent`). - - Returns: - The updated `Task` object after processing the event. - - Raises: - A2AClientError: If the task ID in the event conflicts with the TaskManager's ID - when the TaskManager's ID is already set. - """ - if event.HasField('message'): - # Messages are not processed here. - return None - - if event.HasField('task'): - if self._current_task: - raise A2AClientError( - 'Task is already set, create new manager for new tasks.' - ) - await self._save_task(event.task) - return event.task - - task = self._current_task - - if event.HasField('status_update'): - status_update = event.status_update - if not task: - task = Task( - status=TaskStatus(state=TaskState.TASK_STATE_UNSPECIFIED), - id=status_update.task_id, - context_id=status_update.context_id, - ) - - logger.debug( - 'Updating task %s status to: %s', - status_update.task_id, - status_update.status.state, - ) - if status_update.status.HasField('message'): - # "Repeated" fields are merged by appending. - task.history.append(status_update.status.message) - - if status_update.metadata: - task.metadata.MergeFrom(status_update.metadata) - - task.status.CopyFrom(status_update.status) - await self._save_task(task) - - if event.HasField('artifact_update'): - artifact_update = event.artifact_update - if not task: - task = Task( - status=TaskStatus(state=TaskState.TASK_STATE_UNSPECIFIED), - id=artifact_update.task_id, - context_id=artifact_update.context_id, - ) - - logger.debug('Appending artifact to task %s', task.id) - append_artifact_to_task(task, artifact_update) - await self._save_task(task) - - return self._current_task - - async def _save_task(self, task: Task) -> None: - """Saves the given task to the `_current_task` and updated `_task_id` and `_context_id`. - - Args: - task: The `Task` object to save. - """ - logger.debug('Saving task with id: %s', task.id) - self._current_task = task - if not self._task_id: - logger.info('New task created with id: %s', task.id) - self._task_id = task.id - self._context_id = task.context_id - - def update_with_message(self, message: Message, task: Task) -> Task: - """Updates a task object adding a new message to its history. - - If the task has a message in its current status, that message is moved - to the history first. - - Args: - message: The new `Message` to add to the history. - task: The `Task` object to update. - - Returns: - The updated `Task` object (updated in-place). - """ - if task.status.HasField('message'): - task.history.append(task.status.message) - task.status.ClearField('message') - - task.history.append(message) - self._current_task = task - return task diff --git a/tests/client/test_base_client.py b/tests/client/test_base_client.py index 4aa243377..f782b33d8 100644 --- a/tests/client/test_base_client.py +++ b/tests/client/test_base_client.py @@ -72,7 +72,6 @@ def base_client( card=sample_agent_card, config=config, transport=mock_transport, - consumers=[], interceptors=[], ) @@ -152,10 +151,8 @@ async def create_stream(*args, **kwargs): assert not mock_transport.send_message.called assert len(events) == 1 # events[0] is (StreamResponse, Task) tuple - stream_response, tracked_task = events[0] - assert stream_response.task.id == 'task-123' - assert tracked_task is not None - assert tracked_task.id == 'task-123' + response = events[0] + assert response.task.id == 'task-123' @pytest.mark.asyncio async def test_send_message_non_streaming( @@ -183,10 +180,8 @@ async def test_send_message_non_streaming( assert mock_transport.send_message.call_args[0][0].metadata == meta assert not mock_transport.send_message_streaming.called assert len(events) == 1 - stream_response, tracked_task = events[0] - assert stream_response.task.id == 'task-456' - assert tracked_task is not None - assert tracked_task.id == 'task-456' + response = events[0] + assert response.task.id == 'task-456' @pytest.mark.asyncio async def test_send_message_non_streaming_agent_capability_false( @@ -211,10 +206,8 @@ async def test_send_message_non_streaming_agent_capability_false( mock_transport.send_message.assert_called_once() assert not mock_transport.send_message_streaming.called assert len(events) == 1 - stream_response, tracked_task = events[0] - assert stream_response is not None - assert tracked_task is not None - assert tracked_task.id == 'task-789' + response = events[0] + assert response.task.id == 'task-789' @pytest.mark.asyncio async def test_send_message_callsite_config_overrides_non_streaming( @@ -244,8 +237,8 @@ async def test_send_message_callsite_config_overrides_non_streaming( mock_transport.send_message.assert_called_once() assert not mock_transport.send_message_streaming.called assert len(events) == 1 - stream_response, _ = events[0] - assert stream_response.task.id == 'task-cfg-ns-1' + response = events[0] + assert response.task.id == 'task-cfg-ns-1' params = mock_transport.send_message.call_args[0][0] assert params.configuration.history_length == 2 @@ -286,8 +279,8 @@ async def create_stream(*args, **kwargs): mock_transport.send_message_streaming.assert_called_once() assert not mock_transport.send_message.called assert len(events) == 1 - stream_response, _ = events[0] - assert stream_response.task.id == 'task-cfg-s-1' + response = events[0] + assert response.task.id == 'task-cfg-s-1' params = mock_transport.send_message_streaming.call_args[0][0] assert params.configuration.history_length == 0 diff --git a/tests/client/test_base_client_interceptors.py b/tests/client/test_base_client_interceptors.py index 0e7328440..d7930062f 100644 --- a/tests/client/test_base_client_interceptors.py +++ b/tests/client/test_base_client_interceptors.py @@ -57,7 +57,6 @@ def base_client( card=sample_agent_card, config=config, transport=mock_transport, - consumers=[], interceptors=[mock_interceptor], ) diff --git a/tests/client/test_client_factory.py b/tests/client/test_client_factory.py index 1ad3c4c93..a5366e0d3 100644 --- a/tests/client/test_client_factory.py +++ b/tests/client/test_client_factory.py @@ -268,24 +268,21 @@ def custom_transport_producer(*args, **kwargs): @pytest.mark.asyncio -async def test_client_factory_connect_with_consumers_and_interceptors( +async def test_client_factory_connect_with_interceptors( base_agent_card: AgentCard, ): - """Verify consumers and interceptors are passed through correctly.""" - consumer1 = MagicMock() + """Verify interceptors are passed through correctly.""" interceptor1 = MagicMock() with patch('a2a.client.client_factory.BaseClient') as mock_base_client: await ClientFactory.connect( base_agent_card, - consumers=[consumer1], interceptors=[interceptor1], ) mock_base_client.assert_called_once() call_args = mock_base_client.call_args[0] - assert call_args[3] == [consumer1] - assert call_args[4] == [interceptor1] + assert call_args[3] == [interceptor1] def test_client_factory_applies_tenant_decorator(base_agent_card: AgentCard): diff --git a/tests/client/test_client_task_manager.py b/tests/client/test_client_task_manager.py deleted file mode 100644 index 24f2da69b..000000000 --- a/tests/client/test_client_task_manager.py +++ /dev/null @@ -1,191 +0,0 @@ -from unittest.mock import patch - -import pytest - -from a2a.client.client_task_manager import ClientTaskManager -from a2a.client.errors import A2AClientError -from a2a.types.a2a_pb2 import ( - Artifact, - Message, - Part, - Role, - StreamResponse, - Task, - TaskArtifactUpdateEvent, - TaskState, - TaskStatus, - TaskStatusUpdateEvent, -) - - -@pytest.fixture -def task_manager() -> ClientTaskManager: - return ClientTaskManager() - - -@pytest.fixture -def sample_task() -> Task: - return Task( - id='task123', - context_id='context456', - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - - -@pytest.fixture -def sample_message() -> Message: - return Message( - message_id='msg1', - role=Role.ROLE_USER, - parts=[Part(text='Hello')], - ) - - -def test_get_task_no_task_id_returns_none( - task_manager: ClientTaskManager, -) -> None: - assert task_manager.get_task() is None - - -def test_get_task_or_raise_no_task_raises_error( - task_manager: ClientTaskManager, -) -> None: - with pytest.raises(A2AClientError, match='no current Task'): - task_manager.get_task_or_raise() - - -@pytest.mark.asyncio -async def test_process_with_task( - task_manager: ClientTaskManager, sample_task: Task -) -> None: - """Test processing a StreamResponse containing a task.""" - event = StreamResponse(task=sample_task) - result = await task_manager.process(event) - assert result == sample_task - assert task_manager.get_task() == sample_task - assert task_manager._task_id == sample_task.id - assert task_manager._context_id == sample_task.context_id - - -@pytest.mark.asyncio -async def test_process_with_task_already_set_raises_error( - task_manager: ClientTaskManager, sample_task: Task -) -> None: - """Test that processing a second task raises an error.""" - event = StreamResponse(task=sample_task) - await task_manager.process(event) - with pytest.raises( - A2AClientError, - match='Task is already set, create new manager for new tasks.', - ): - await task_manager.process(event) - - -@pytest.mark.asyncio -async def test_process_with_status_update( - task_manager: ClientTaskManager, sample_task: Task, sample_message: Message -) -> None: - """Test processing a status update after a task has been set.""" - # First set the task - task_event = StreamResponse(task=sample_task) - await task_manager.process(task_event) - - # Now process a status update - status_update = TaskStatusUpdateEvent( - task_id=sample_task.id, - context_id=sample_task.context_id, - status=TaskStatus( - state=TaskState.TASK_STATE_COMPLETED, message=sample_message - ), - ) - status_event = StreamResponse(status_update=status_update) - updated_task = await task_manager.process(status_event) - - assert updated_task is not None - assert updated_task.status.state == TaskState.TASK_STATE_COMPLETED - assert len(updated_task.history) == 1 - assert updated_task.history[0].message_id == sample_message.message_id - - -@pytest.mark.asyncio -async def test_process_with_artifact_update( - task_manager: ClientTaskManager, sample_task: Task -) -> None: - """Test processing an artifact update after a task has been set.""" - # First set the task - task_event = StreamResponse(task=sample_task) - await task_manager.process(task_event) - - artifact = Artifact( - artifact_id='art1', parts=[Part(text='artifact content')] - ) - artifact_update = TaskArtifactUpdateEvent( - task_id=sample_task.id, - context_id=sample_task.context_id, - artifact=artifact, - ) - artifact_event = StreamResponse(artifact_update=artifact_update) - - with patch( - 'a2a.client.client_task_manager.append_artifact_to_task' - ) as mock_append: - updated_task = await task_manager.process(artifact_event) - mock_append.assert_called_once_with(updated_task, artifact_update) - - -@pytest.mark.asyncio -async def test_process_creates_task_if_not_exists_on_status_update( - task_manager: ClientTaskManager, -) -> None: - """Test that processing a status update creates a task if none exists.""" - status_update = TaskStatusUpdateEvent( - task_id='new_task', - context_id='new_context', - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ) - status_event = StreamResponse(status_update=status_update) - updated_task = await task_manager.process(status_event) - - assert updated_task is not None - assert updated_task.id == 'new_task' - assert updated_task.status.state == TaskState.TASK_STATE_WORKING - - -@pytest.mark.asyncio -async def test_process_with_message_returns_none( - task_manager: ClientTaskManager, sample_message: Message -) -> None: - """Test that processing a message event returns None.""" - event = StreamResponse(message=sample_message) - result = await task_manager.process(event) - assert result is None - - -def test_update_with_message( - task_manager: ClientTaskManager, sample_task: Task, sample_message: Message -) -> None: - """Test updating a task with a new message.""" - updated_task = task_manager.update_with_message(sample_message, sample_task) - assert len(updated_task.history) == 1 - assert updated_task.history[0].message_id == sample_message.message_id - - -def test_update_with_message_moves_status_message( - task_manager: ClientTaskManager, sample_task: Task, sample_message: Message -) -> None: - """Test that status message is moved to history when updating.""" - status_message = Message( - message_id='status_msg', - role=Role.ROLE_AGENT, - parts=[Part(text='Status')], - ) - sample_task.status.message.CopyFrom(status_message) - - updated_task = task_manager.update_with_message(sample_message, sample_task) - - # History should contain both status_message and sample_message - assert len(updated_task.history) == 2 - assert updated_task.history[0].message_id == status_message.message_id - assert updated_task.history[1].message_id == sample_message.message_id - # Status message should be cleared - assert not updated_task.status.HasField('message') diff --git a/tests/e2e/push_notifications/test_default_push_notification_support.py b/tests/e2e/push_notifications/test_default_push_notification_support.py index f7a3da457..053707d62 100644 --- a/tests/e2e/push_notifications/test_default_push_notification_support.py +++ b/tests/e2e/push_notifications/test_default_push_notification_support.py @@ -131,10 +131,7 @@ async def test_notification_triggering_with_in_message_config_e2e( ) ] assert len(responses) == 1 - assert isinstance(responses[0], tuple) - # ClientEvent is tuple[StreamResponse, Task | None] - # responses[0][0] is StreamResponse with task field - stream_response = responses[0][0] + stream_response = responses[0] assert stream_response.HasField('task') task = stream_response.task @@ -189,9 +186,7 @@ async def test_notification_triggering_after_config_change_e2e( ) ] assert len(responses) == 1 - assert isinstance(responses[0], tuple) - # ClientEvent is tuple[StreamResponse, Task | None] - stream_response = responses[0][0] + stream_response = responses[0] assert stream_response.HasField('task') task = stream_response.task assert task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED diff --git a/tests/integration/cross_version/client_server/client_1_0.py b/tests/integration/cross_version/client_server/client_1_0.py index 537a73602..5a5e192cf 100644 --- a/tests/integration/cross_version/client_server/client_1_0.py +++ b/tests/integration/cross_version/client_server/client_1_0.py @@ -54,8 +54,8 @@ async def test_send_message_stream(client): assert len(events) > 0, 'Expected at least one event' first_event = events[0] - # In v1.0 SDK, send_message returns tuple[StreamResponse, Task | None] - stream_response = first_event[0] + # In v1.0 SDK, send_message returns StreamResponse + stream_response = first_event # Try to find task_id in the oneof fields of StreamResponse task_id = 'unknown' @@ -92,7 +92,7 @@ async def test_send_message_sync(url, protocol_enum): request=SendMessageRequest(message=msg) ): assert event is not None - stream_response = event[0] + stream_response = event status = None if stream_response.HasField('task'): @@ -161,7 +161,7 @@ async def test_subscribe(client, task_id): request=SubscribeToTaskRequest(id=task_id) ): assert event is not None - stream_response = event[0] + stream_response = event if stream_response.HasField('artifact_update'): has_artifact = True artifact = stream_response.artifact_update.artifact diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 8884a5dd8..e00b53c02 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -384,7 +384,6 @@ def grpc_03_setup( card=agent_card, config=ClientConfig(), transport=transport, - consumers=[], interceptors=[], ) return TransportSetup(client=client, handler=handler) @@ -410,7 +409,8 @@ async def test_client_sends_message_streaming(transport_setups) -> None: events = [event async for event in stream] assert len(events) == 1 - _, task = events[0] + event = events[0] + task = event.task assert task is not None assert task.id == TASK_FROM_STREAM.id @@ -439,7 +439,8 @@ async def test_client_sends_message_blocking(transport_setups) -> None: events = [event async for event in client.send_message(request=params)] assert len(events) == 1 - _, task = events[0] + event = events[0] + task = event.task assert task is not None assert task.id == TASK_FROM_BLOCKING.id handler.on_message_send.assert_awaited_once_with(params, ANY) @@ -588,8 +589,7 @@ async def test_client_subscribe(transport_setups) -> None: stream = client.subscribe(request=params) first_event = await stream.__anext__() - _, task = first_event - assert task.id == RESUBSCRIBE_EVENT.task_id + assert first_event.status_update.task_id == RESUBSCRIBE_EVENT.task_id handler.on_subscribe_to_task.assert_called_once() await client.close() @@ -624,7 +624,6 @@ async def test_json_transport_base_client_send_message_with_extensions( card=agent_card, config=ClientConfig(streaming=False), transport=transport, - consumers=[], interceptors=[], ) @@ -797,7 +796,6 @@ async def test_client_get_signed_extended_card( card=agent_card, config=ClientConfig(streaming=False), transport=transport, - consumers=[], interceptors=[], ) @@ -888,7 +886,6 @@ async def test_client_get_signed_base_and_extended_cards( card=base_card, config=ClientConfig(streaming=False), transport=transport, - consumers=[], interceptors=[], ) diff --git a/tests/integration/test_copying_observability.py b/tests/integration/test_copying_observability.py index 9ef1c0483..a207c9b24 100644 --- a/tests/integration/test_copying_observability.py +++ b/tests/integration/test_copying_observability.py @@ -152,9 +152,8 @@ async def test_mutation_observability(agent_card: AgentCard, use_copying: bool): ) ] - task = events[-1][1] - assert task is not None - task_id = task.id + event = events[-1] + task_id = event.status_update.task_id # 2. Second message to mutate it message_to_send_2 = Message( diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index c2d22889b..4987acdb5 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -65,7 +65,7 @@ def assert_artifacts_match(artifacts, expected_artifacts): def assert_events_match(events, expected_events): assert len(events) == len(expected_events) - for (event, _), (expected_type, expected_val) in zip( + for event, (expected_type, expected_val) in zip( events, expected_events, strict=True ): assert event.HasField(expected_type) @@ -320,7 +320,7 @@ async def test_end_to_end_send_message_blocking(transport_setups): ) ] assert len(events) == 1 - response, _ = events[0] + response = events[0] assert response.task.id assert response.task.status.state == TaskState.TASK_STATE_COMPLETED assert_artifacts_match( @@ -358,7 +358,7 @@ async def test_end_to_end_send_message_non_blocking(transport_setups): ) ] assert len(events) == 1 - response, _ = events[0] + response = events[0] assert response.task.id assert response.task.status.state == TaskState.TASK_STATE_SUBMITTED assert_history_matches( @@ -396,7 +396,8 @@ async def test_end_to_end_send_message_streaming(transport_setups): ], ) - task = await client.get_task(request=GetTaskRequest(id=events[0][1].id)) + task_id = events[0].status_update.task_id + task = await client.get_task(request=GetTaskRequest(id=task_id)) assert_history_matches( task.history, [ @@ -424,8 +425,8 @@ async def test_end_to_end_get_task(transport_setups): request=SendMessageRequest(message=message_to_send) ) ] - _, task = events[-1] - task_id = task.id + response = events[0] + task_id = response.status_update.task_id get_request = GetTaskRequest(id=task_id) retrieved_task = await client.get_task(request=get_request) @@ -456,7 +457,7 @@ async def test_end_to_end_list_tasks(transport_setups): expected_task_ids = [] for i in range(total_items): # One event is enough to get the task ID - _, task = await anext( + response = await anext( client.send_message( request=SendMessageRequest( message=Message( @@ -467,7 +468,7 @@ async def test_end_to_end_list_tasks(transport_setups): ) ) ) - expected_task_ids.append(task.id) + expected_task_ids.append(response.status_update.task_id) list_request = ListTasksRequest(page_size=page_size) @@ -522,7 +523,8 @@ async def test_end_to_end_input_required(transport_setups): ], ) - task = await client.get_task(request=GetTaskRequest(id=events[0][1].id)) + task_id = events[0].status_update.task_id + task = await client.get_task(request=GetTaskRequest(id=task_id)) assert task.status.state == TaskState.TASK_STATE_INPUT_REQUIRED assert_history_matches( diff --git a/tests/integration/test_stream_generator_cleanup.py b/tests/integration/test_stream_generator_cleanup.py index 184bf6654..47ab5212f 100644 --- a/tests/integration/test_stream_generator_cleanup.py +++ b/tests/integration/test_stream_generator_cleanup.py @@ -119,7 +119,7 @@ async def test_stream_message_no_athrow(client: BaseClient) -> None: ) ] assert events - assert events[0][0].HasField('message') + assert events[0].HasField('message') gc.collect() await loop.shutdown_asyncgens() From e7a1c50fc3c72dcbef1830d903b45a742ee9af20 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Wed, 1 Apr 2026 12:58:25 +0000 Subject: [PATCH 2/3] refactor: simplify client factory interface and standardize streaming return types to StreamResponse --- src/a2a/client/base_client.py | 6 +++--- src/a2a/client/client.py | 4 +--- src/a2a/client/client_factory.py | 7 +++---- tests/client/test_base_client.py | 1 - 4 files changed, 7 insertions(+), 11 deletions(-) diff --git a/src/a2a/client/base_client.py b/src/a2a/client/base_client.py index 7d5a941e8..53fd38cdb 100644 --- a/src/a2a/client/base_client.py +++ b/src/a2a/client/base_client.py @@ -64,7 +64,7 @@ async def send_message( context: Optional client call context. Yields: - An async iterator of `ClientEvent` + An async iterator of `StreamResponse` """ self._apply_client_config(request) if not self._config.streaming or not self._card.capabilities.streaming: @@ -314,7 +314,7 @@ async def subscribe( context: Optional client call context. Yields: - An async iterator of `ClientEvent` objects. + An async iterator of `StreamResponse` objects. Raises: NotImplementedError: If streaming is not supported by the client or server. @@ -432,7 +432,7 @@ async def _execute_stream_with_interceptors( ) before_result = await self._intercept_before(before_args) - if before_result: + if before_result is not None: after_args = AfterArgs( result=before_result['early_return'], method=method, diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index 556be9030..1f94a4426 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -132,9 +132,7 @@ async def send_message( This will automatically use the streaming or non-streaming approach as supported by the server and the client config. Client will - aggregate update events and return an iterator of (`Task`,`Update`) - pairs, or a `Message`. Client will also send these values to any - configured `Consumer`s in the client. + aggregate update events and return an iterator of `StreamResponse`. """ return yield diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index 5f07e7713..c5d5e8aa4 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -63,12 +63,11 @@ class ClientFactory: .. code-block:: python - factory = ClientFactory(config, consumers) + factory = ClientFactory(config) # Optionally register custom client implementations factory.register('my_customer_transport', NewCustomTransportClient) - # Then with an agent card make a client with additional consumers and - # interceptors - client = factory.create(card, additional_consumers, interceptors) + # Then with an agent card make a client with additional interceptors + client = factory.create(card, interceptors) Now the client can be used consistently regardless of the transport. This aligns the client configuration with the server's capabilities. diff --git a/tests/client/test_base_client.py b/tests/client/test_base_client.py index f782b33d8..ed49469a7 100644 --- a/tests/client/test_base_client.py +++ b/tests/client/test_base_client.py @@ -150,7 +150,6 @@ async def create_stream(*args, **kwargs): ) assert not mock_transport.send_message.called assert len(events) == 1 - # events[0] is (StreamResponse, Task) tuple response = events[0] assert response.task.id == 'task-123' From ad6cdde3940a9abb84ee927378ac9db99d9d2bae Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Thu, 2 Apr 2026 07:31:46 +0000 Subject: [PATCH 3/3] fix itk tests --- itk/main.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/itk/main.py b/itk/main.py index 45a5ea159..fc5b7d876 100644 --- a/itk/main.py +++ b/itk/main.py @@ -138,17 +138,19 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]: nested_msg = wrap_instruction_to_request(call.instruction) request = SendMessageRequest(message=nested_msg) - results = [] + results: list[str] = [] async for event in client.send_message(request): - # Event is streaming response and task + # Event is StreamResponse logger.info('Event: %s', event) - stream_resp, task = event + stream_resp = event message = None if stream_resp.HasField('message'): message = stream_resp.message - elif task and task.status.HasField('message'): - message = task.status.message + elif stream_resp.HasField( + 'task' + ) and stream_resp.task.status.HasField('message'): + message = stream_resp.task.status.message elif stream_resp.HasField( 'status_update' ) and stream_resp.status_update.status.HasField('message'):