Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions itk/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,17 +138,19 @@ async def handle_call_agent(call: instruction_pb2.CallAgent) -> list[str]:
nested_msg = wrap_instruction_to_request(call.instruction)
request = SendMessageRequest(message=nested_msg)

results = []
results: list[str] = []
async for event in client.send_message(request):
# Event is streaming response and task
# Event is StreamResponse
logger.info('Event: %s', event)
stream_resp, task = event
stream_resp = event

message = None
if stream_resp.HasField('message'):
message = stream_resp.message
elif task and task.status.HasField('message'):
message = task.status.message
elif stream_resp.HasField(
'task'
) and stream_resp.task.status.HasField('message'):
message = stream_resp.task.status.message
elif stream_resp.HasField(
'status_update'
) and stream_resp.status_update.status.HasField('message'):
Expand Down
4 changes: 0 additions & 4 deletions src/a2a/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
Client,
ClientCallContext,
ClientConfig,
ClientEvent,
Consumer,
)
from a2a.client.client_factory import ClientFactory, minimal_agent_card
from a2a.client.errors import (
Expand All @@ -35,9 +33,7 @@
'ClientCallContext',
'ClientCallInterceptor',
'ClientConfig',
'ClientEvent',
'ClientFactory',
'Consumer',
'CredentialService',
'InMemoryContextCredentialStore',
'create_text_message_object',
Expand Down
55 changes: 12 additions & 43 deletions src/a2a/client/base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
Client,
ClientCallContext,
ClientConfig,
ClientEvent,
Consumer,
)
from a2a.client.client_task_manager import ClientTaskManager
from a2a.client.interceptors import (
AfterArgs,
BeforeArgs,
Expand Down Expand Up @@ -42,10 +39,9 @@ def __init__(
card: AgentCard,
config: ClientConfig,
transport: ClientTransport,
consumers: list[Consumer],
interceptors: list[ClientCallInterceptor],
):
super().__init__(consumers, interceptors)
super().__init__(interceptors)
self._card = card
self._config = config
self._transport = transport
Expand All @@ -56,7 +52,7 @@ async def send_message(
request: SendMessageRequest,
*,
context: ClientCallContext | None = None,
) -> AsyncIterator[ClientEvent]:
) -> AsyncIterator[StreamResponse]:
"""Sends a message to the agent.

This method handles both streaming and non-streaming (polling) interactions
Expand All @@ -68,7 +64,7 @@ async def send_message(
context: Optional client call context.

Yields:
An async iterator of `ClientEvent`
An async iterator of `StreamResponse`
"""
self._apply_client_config(request)
if not self._config.streaming or not self._card.capabilities.streaming:
Expand All @@ -84,19 +80,14 @@ async def send_message(
# In non-streaming case we convert to a StreamResponse so that the
# client always sees the same iterator.
stream_response = StreamResponse()
client_event: ClientEvent
if response.HasField('task'):
stream_response.task.CopyFrom(response.task)
client_event = (stream_response, response.task)
elif response.HasField('message'):
stream_response.message.CopyFrom(response.message)
client_event = (stream_response, None)
else:
# Response must have either task or message
raise ValueError('Response has neither task nor message')

await self.consume(client_event, self._card)
yield client_event
yield stream_response
return

async for event in self._execute_stream_with_interceptors(
Expand Down Expand Up @@ -130,8 +121,7 @@ async def _process_stream(
self,
stream: AsyncIterator[StreamResponse],
before_args: BeforeArgs,
) -> AsyncGenerator[ClientEvent]:
tracker = ClientTaskManager()
) -> AsyncGenerator[StreamResponse, None]:
async for stream_response in stream:
after_args = AfterArgs(
result=stream_response,
Expand All @@ -140,12 +130,8 @@ async def _process_stream(
context=before_args.context,
)
await self._intercept_after(after_args)
intercepted_response = after_args.result
client_event = await self._format_stream_event(
intercepted_response, tracker
)
yield client_event
if intercepted_response.HasField('message'):
yield after_args.result
if after_args.result.HasField('message'):
return

async def get_task(
Expand Down Expand Up @@ -318,7 +304,7 @@ async def subscribe(
request: SubscribeToTaskRequest,
*,
context: ClientCallContext | None = None,
) -> AsyncIterator[ClientEvent]:
) -> AsyncIterator[StreamResponse]:
"""Resubscribes to a task's event stream.

This is only available if both the client and server support streaming.
Expand All @@ -328,7 +314,7 @@ async def subscribe(
context: Optional client call context.

Yields:
An async iterator of `ClientEvent` objects.
An async iterator of `StreamResponse` objects.

Raises:
NotImplementedError: If streaming is not supported by the client or server.
Expand Down Expand Up @@ -436,7 +422,7 @@ async def _execute_stream_with_interceptors(
transport_call: Callable[
[Any, ClientCallContext | None], AsyncIterator[StreamResponse]
],
) -> AsyncIterator[ClientEvent]:
) -> AsyncIterator[StreamResponse]:

before_args = BeforeArgs(
input=input_data,
Expand All @@ -446,7 +432,7 @@ async def _execute_stream_with_interceptors(
)
before_result = await self._intercept_before(before_args)

if before_result:
if before_result is not None:
after_args = AfterArgs(
result=before_result['early_return'],
method=method,
Expand All @@ -455,8 +441,7 @@ async def _execute_stream_with_interceptors(
)
await self._intercept_after(after_args, before_result['executed'])

tracker = ClientTaskManager()
yield await self._format_stream_event(after_args.result, tracker)
yield after_args.result
return

stream = transport_call(before_args.input, before_args.context)
Expand Down Expand Up @@ -495,19 +480,3 @@ async def _intercept_after(
await interceptor.after(args)
if args.early_return:
return

async def _format_stream_event(
self, stream_response: StreamResponse, tracker: ClientTaskManager
) -> ClientEvent:
client_event: ClientEvent
if stream_response.HasField('message'):
client_event = (stream_response, None)
await self.consume(client_event, self._card)
return client_event

await tracker.process(stream_response)
updated_task = tracker.get_task_or_raise()
client_event = (stream_response, updated_task)

await self.consume(client_event, self._card)
return client_event
35 changes: 5 additions & 30 deletions src/a2a/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging

from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Callable, Coroutine, MutableMapping
from collections.abc import AsyncIterator, Callable, MutableMapping
from types import TracebackType
from typing import Any

Expand Down Expand Up @@ -77,13 +77,6 @@
"""Push notification configurations to use for every request."""


ClientEvent = tuple[StreamResponse, Task | None]

# Alias for an event consuming callback. It takes either a (task, update) pair
# or a message as well as the agent card for the agent this came from.
Consumer = Callable[[ClientEvent, AgentCard], Coroutine[None, Any, Any]]


class ClientCallContext(BaseModel):
"""A context passed with each client call, allowing for call-specific.

Expand All @@ -106,16 +99,13 @@

def __init__(
self,
consumers: list[Consumer] | None = None,
interceptors: list[ClientCallInterceptor] | None = None,
):
"""Initializes the client with consumers and interceptors.
"""Initializes the client with interceptors.

Args:
consumers: A list of callables to process events from the agent.
interceptors: A list of interceptors to process requests and responses.
"""
self._consumers = consumers or []
self._interceptors = interceptors or []

async def __aenter__(self) -> Self:
Expand All @@ -137,88 +127,86 @@
request: SendMessageRequest,
*,
context: ClientCallContext | None = None,
) -> AsyncIterator[ClientEvent]:
) -> AsyncIterator[StreamResponse]:
"""Sends a message to the server.

This will automatically use the streaming or non-streaming approach
as supported by the server and the client config. Client will
aggregate update events and return an iterator of (`Task`,`Update`)
pairs, or a `Message`. Client will also send these values to any
configured `Consumer`s in the client.
aggregate update events and return an iterator of `StreamResponse`.
"""
return
yield

@abstractmethod
async def get_task(
self,
request: GetTaskRequest,
*,
context: ClientCallContext | None = None,
) -> Task:
"""Retrieves the current state and history of a specific task."""

@abstractmethod
async def list_tasks(
self,
request: ListTasksRequest,
*,
context: ClientCallContext | None = None,
) -> ListTasksResponse:
"""Retrieves tasks for an agent."""

@abstractmethod
async def cancel_task(
self,
request: CancelTaskRequest,
*,
context: ClientCallContext | None = None,
) -> Task:
"""Requests the agent to cancel a specific task."""

@abstractmethod
async def create_task_push_notification_config(
self,
request: TaskPushNotificationConfig,
*,
context: ClientCallContext | None = None,
) -> TaskPushNotificationConfig:
"""Sets or updates the push notification configuration for a specific task."""

@abstractmethod
async def get_task_push_notification_config(
self,
request: GetTaskPushNotificationConfigRequest,
*,
context: ClientCallContext | None = None,
) -> TaskPushNotificationConfig:
"""Retrieves the push notification configuration for a specific task."""

@abstractmethod
async def list_task_push_notification_configs(
self,
request: ListTaskPushNotificationConfigsRequest,
*,
context: ClientCallContext | None = None,
) -> ListTaskPushNotificationConfigsResponse:
"""Lists push notification configurations for a specific task."""

@abstractmethod
async def delete_task_push_notification_config(
self,
request: DeleteTaskPushNotificationConfigRequest,
*,
context: ClientCallContext | None = None,
) -> None:
"""Deletes the push notification configuration for a specific task."""

@abstractmethod
async def subscribe(
self,
request: SubscribeToTaskRequest,
*,
context: ClientCallContext | None = None,
) -> AsyncIterator[ClientEvent]:
) -> AsyncIterator[StreamResponse]:

Check notice on line 209 in src/a2a/client/client.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Copy/pasted code

see src/a2a/client/transports/base.py (61-133)
"""Resubscribes to a task's event stream."""
return
yield
Expand All @@ -233,23 +221,10 @@
) -> AgentCard:
"""Retrieves the agent's card."""

async def add_event_consumer(self, consumer: Consumer) -> None:
"""Attaches additional consumers to the `Client`."""
self._consumers.append(consumer)

async def add_interceptor(self, interceptor: ClientCallInterceptor) -> None:
"""Attaches additional interceptors to the `Client`."""
self._interceptors.append(interceptor)

async def consume(
self,
event: ClientEvent,
card: AgentCard,
) -> None:
"""Processes the event via all the registered `Consumer`s."""
for c in self._consumers:
await c(event, card)

@abstractmethod
async def close(self) -> None:
"""Closes the client and releases any underlying resources."""
26 changes: 6 additions & 20 deletions src/a2a/client/client_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from a2a.client.base_client import BaseClient
from a2a.client.card_resolver import A2ACardResolver
from a2a.client.client import Client, ClientConfig, Consumer
from a2a.client.client import Client, ClientConfig
from a2a.client.transports.base import ClientTransport
from a2a.client.transports.jsonrpc import JsonRpcTransport
from a2a.client.transports.rest import RestTransport
Expand Down Expand Up @@ -63,12 +63,11 @@ class ClientFactory:

.. code-block:: python

factory = ClientFactory(config, consumers)
factory = ClientFactory(config)
# Optionally register custom client implementations
factory.register('my_customer_transport', NewCustomTransportClient)
# Then with an agent card make a client with additional consumers and
# interceptors
client = factory.create(card, additional_consumers, interceptors)
# Then with an agent card make a client with additional interceptors
client = factory.create(card, interceptors)

Now the client can be used consistently regardless of the transport. This
aligns the client configuration with the server's capabilities.
Expand All @@ -77,17 +76,12 @@ class ClientFactory:
def __init__(
self,
config: ClientConfig,
consumers: list[Consumer] | None = None,
):
if consumers is None:
consumers = []

client = config.httpx_client or httpx.AsyncClient()
client.headers.setdefault(VERSION_HEADER, PROTOCOL_VERSION_CURRENT)
config.httpx_client = client

self._config = config
self._consumers = consumers
self._registry: dict[str, TransportProducer] = {}
self._register_defaults(config.supported_protocol_bindings)

Expand Down Expand Up @@ -263,7 +257,6 @@ async def connect( # noqa: PLR0913
cls,
agent: str | AgentCard,
client_config: ClientConfig | None = None,
consumers: list[Consumer] | None = None,
interceptors: list[ClientCallInterceptor] | None = None,
relative_card_path: str | None = None,
resolver_http_kwargs: dict[str, Any] | None = None,
Expand All @@ -286,7 +279,7 @@ async def connect( # noqa: PLR0913
Args:
agent: The base URL of the agent, or the AgentCard to connect to.
client_config: The ClientConfig to use when connecting to the agent.
consumers: A list of `Consumer` methods to pass responses to.

interceptors: A list of interceptors to use for each request. These
are used for things like attaching credentials or http headers
to all outbound requests.
Expand Down Expand Up @@ -325,7 +318,7 @@ async def connect( # noqa: PLR0913
factory = cls(client_config)
for label, generator in (extra_transports or {}).items():
factory.register(label, generator)
return factory.create(card, consumers, interceptors)
return factory.create(card, interceptors)

def register(self, label: str, generator: TransportProducer) -> None:
"""Register a new transport producer for a given transport label."""
Expand All @@ -334,14 +327,12 @@ def register(self, label: str, generator: TransportProducer) -> None:
def create(
self,
card: AgentCard,
consumers: list[Consumer] | None = None,
interceptors: list[ClientCallInterceptor] | None = None,
) -> Client:
"""Create a new `Client` for the provided `AgentCard`.

Args:
card: An `AgentCard` defining the characteristics of the agent.
consumers: A list of `Consumer` methods to pass responses to.
interceptors: A list of interceptors to use for each request. These
are used for things like attaching credentials or http headers
to all outbound requests.
Expand Down Expand Up @@ -381,10 +372,6 @@ def create(
if transport_protocol not in self._registry:
raise ValueError(f'no client available for {transport_protocol}')

all_consumers = self._consumers.copy()
if consumers:
all_consumers.extend(consumers)

transport = self._registry[transport_protocol](
card, selected_interface.url, self._config
)
Expand All @@ -398,7 +385,6 @@ def create(
card,
self._config,
transport,
all_consumers,
interceptors or [],
)

Expand Down
Loading
Loading