From 4d88e6f53d2b1889e6f1132006df35af3a7a9a4e Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Thu, 26 Mar 2026 02:50:18 +0800 Subject: [PATCH] fix: complete schema v0.11.2 follow-ups --- src/acp/agent/router.py | 45 ++++++++++++--- src/acp/client/connection.py | 34 +++++++++--- src/acp/interfaces.py | 7 ++- src/acp/router.py | 52 +++++++++++------- src/acp/utils.py | 104 +++++++++++++++++++++++++++++++++-- tests/conftest.py | 14 ++++- tests/test_compatibility.py | 43 +++++++++++++++ tests/test_rpc.py | 24 ++++++++ tests/test_utils.py | 41 +++++++++++++- 9 files changed, 317 insertions(+), 47 deletions(-) diff --git a/src/acp/agent/router.py b/src/acp/agent/router.py index 26db13f..2a27bcd 100644 --- a/src/acp/agent/router.py +++ b/src/acp/agent/router.py @@ -2,10 +2,12 @@ from typing import Any +from pydantic import BaseModel + from ..exceptions import RequestError from ..interfaces import Agent from ..meta import AGENT_METHODS -from ..router import MessageRouter +from ..router import MessageRouter, Route, _resolve_handler, _warn_legacy_handler from ..schema import ( AuthenticateRequest, CancelNotification, @@ -17,15 +19,41 @@ NewSessionRequest, PromptRequest, ResumeSessionRequest, + SetSessionConfigOptionBooleanRequest, SetSessionConfigOptionSelectRequest, SetSessionModelRequest, SetSessionModeRequest, ) -from ..utils import normalize_result +from ..utils import model_to_kwargs, normalize_result __all__ = ["build_agent_router"] +_SET_CONFIG_OPTION_MODELS = (SetSessionConfigOptionBooleanRequest, SetSessionConfigOptionSelectRequest) + + +def _validate_set_config_option_request(params: Any) -> BaseModel: + if isinstance(params, dict) and params.get("type") == "boolean": + return SetSessionConfigOptionBooleanRequest.model_validate(params) + return SetSessionConfigOptionSelectRequest.model_validate(params) + + +def _make_set_config_option_handler(agent: Agent) -> Any: + func, attr, legacy_api = _resolve_handler(agent, "set_config_option") + if func is None: + return None + + async def wrapper(params: Any) -> Any: + if legacy_api: + _warn_legacy_handler(agent, attr) + request = _validate_set_config_option_request(params) + if legacy_api: + return await func(request) + return await func(**model_to_kwargs(request, _SET_CONFIG_OPTION_MODELS)) + + return wrapper + + def build_agent_router(agent: Agent, use_unstable_protocol: bool = False) -> MessageRouter: router = MessageRouter(use_unstable_protocol=use_unstable_protocol) @@ -63,12 +91,13 @@ def build_agent_router(agent: Agent, use_unstable_protocol: bool = False) -> Mes adapt_result=normalize_result, unstable=True, ) - router.route_request( - AGENT_METHODS["session_set_config_option"], - SetSessionConfigOptionSelectRequest, - agent, - "set_config_option", - adapt_result=normalize_result, + router.add_route( + Route( + method=AGENT_METHODS["session_set_config_option"], + func=_make_set_config_option_handler(agent), + kind="request", + adapt_result=normalize_result, + ) ) router.route_request( AGENT_METHODS["authenticate"], diff --git a/src/acp/client/connection.py b/src/acp/client/connection.py index d3471b6..bb8fdc2 100644 --- a/src/acp/client/connection.py +++ b/src/acp/client/connection.py @@ -35,6 +35,7 @@ ResourceContentBlock, ResumeSessionRequest, ResumeSessionResponse, + SetSessionConfigOptionBooleanRequest, SetSessionConfigOptionResponse, SetSessionConfigOptionSelectRequest, SetSessionModelRequest, @@ -44,7 +45,7 @@ SseMcpServer, TextContentBlock, ) -from ..utils import compatible_class, notify_model, param_model, request_model, request_model_from_dict +from ..utils import compatible_class, notify_model, param_model, param_models, request_model, request_model_from_dict from .router import build_client_router __all__ = ["ClientSideConnection"] @@ -154,16 +155,30 @@ async def set_session_model(self, model_id: str, session_id: str, **kwargs: Any) SetSessionModelResponse, ) - @param_model(SetSessionConfigOptionSelectRequest) + @param_models(SetSessionConfigOptionBooleanRequest, SetSessionConfigOptionSelectRequest) async def set_config_option( - self, config_id: str, session_id: str, value: str, **kwargs: Any + self, config_id: str, session_id: str, value: str | bool, **kwargs: Any ) -> SetSessionConfigOptionResponse: + request = ( + SetSessionConfigOptionBooleanRequest( + config_id=config_id, + session_id=session_id, + type="boolean", + value=value, + field_meta=kwargs or None, + ) + if isinstance(value, bool) + else SetSessionConfigOptionSelectRequest( + config_id=config_id, + session_id=session_id, + value=value, + field_meta=kwargs or None, + ) + ) return await request_model_from_dict( self._conn, AGENT_METHODS["session_set_config_option"], - SetSessionConfigOptionSelectRequest( - config_id=config_id, session_id=session_id, value=value, field_meta=kwargs or None - ), + request, SetSessionConfigOptionResponse, ) @@ -193,7 +208,12 @@ async def prompt( return await request_model( self._conn, AGENT_METHODS["session_prompt"], - PromptRequest(prompt=prompt, session_id=session_id, field_meta=kwargs or None), + PromptRequest( + prompt=prompt, + session_id=session_id, + message_id=message_id, + field_meta=kwargs or None, + ), PromptResponse, ) diff --git a/src/acp/interfaces.py b/src/acp/interfaces.py index f33f403..2decdb2 100644 --- a/src/acp/interfaces.py +++ b/src/acp/interfaces.py @@ -50,6 +50,7 @@ ResumeSessionResponse, SessionInfoUpdate, SessionNotification, + SetSessionConfigOptionBooleanRequest, SetSessionConfigOptionResponse, SetSessionConfigOptionSelectRequest, SetSessionModelRequest, @@ -70,7 +71,7 @@ WriteTextFileRequest, WriteTextFileResponse, ) -from .utils import param_model +from .utils import param_model, param_models __all__ = ["Agent", "Client"] @@ -181,9 +182,9 @@ async def set_session_model( self, model_id: str, session_id: str, **kwargs: Any ) -> SetSessionModelResponse | None: ... - @param_model(SetSessionConfigOptionSelectRequest) + @param_models(SetSessionConfigOptionBooleanRequest, SetSessionConfigOptionSelectRequest) async def set_config_option( - self, config_id: str, session_id: str, value: str, **kwargs: Any + self, config_id: str, session_id: str, value: str | bool, **kwargs: Any ) -> SetSessionConfigOptionResponse | None: ... @param_model(AuthenticateRequest) diff --git a/src/acp/router.py b/src/acp/router.py index 2aa3c24..3069deb 100644 --- a/src/acp/router.py +++ b/src/acp/router.py @@ -20,6 +20,34 @@ HandlerT = TypeVar("HandlerT", bound=RequestHandler) +def _warn_legacy_handler(obj: Any, attr: str) -> None: + warnings.warn( + f"The old style method {type(obj).__name__}.{attr} is deprecated, please update to the snake-cased form.", + DeprecationWarning, + stacklevel=3, + ) + + +def _resolve_handler(obj: Any, attr: str) -> tuple[AsyncHandler | None, str, bool]: + legacy_api = False + func = getattr(obj, attr, None) + if func is None and "_" in attr: + attr = to_camel_case(attr) + func = getattr(obj, attr, None) + legacy_api = True + elif callable(func) and "_" not in attr: + original_func = func + if hasattr(func, "__func__"): + original_func = func.__func__ + parameters = inspect.signature(original_func).parameters + if len(parameters) == 2 and "params" in parameters: + legacy_api = True + + if func is None or not callable(func): + return None, attr, legacy_api + return func, attr, legacy_api + + @dataclass(slots=True) class Route: method: str @@ -63,31 +91,13 @@ def add_route(self, route: Route) -> None: self._notifications[route.method] = route def _make_func(self, model: type[BaseModel], obj: Any, attr: str) -> AsyncHandler | None: - legacy_api = False - func = getattr(obj, attr, None) - if func is None and "_" in attr: - attr = to_camel_case(attr) - func = getattr(obj, attr, None) - legacy_api = True - elif callable(func) and "_" not in attr: - original_func = func - if hasattr(func, "__func__"): - original_func = func.__func__ - parameters = inspect.signature(original_func).parameters - if len(parameters) == 2 and "params" in parameters: - legacy_api = True - - if func is None or not callable(func): + func, attr, legacy_api = _resolve_handler(obj, attr) + if func is None: return None async def wrapper(params: Any) -> Any: if legacy_api: - warnings.warn( - f"The old style method {type(obj).__name__}.{attr} is deprecated, " - "please update to the snake-cased form.", - DeprecationWarning, - stacklevel=3, - ) + _warn_legacy_handler(obj, attr) model_obj = model.model_validate(params) if legacy_api: return await func(model_obj) # type: ignore[arg-type] diff --git a/src/acp/utils.py b/src/acp/utils.py index 1be9c19..3d62496 100644 --- a/src/acp/utils.py +++ b/src/acp/utils.py @@ -26,6 +26,29 @@ MethodT = TypeVar("MethodT", bound=Callable) ClassT = TypeVar("ClassT", bound=type) T = TypeVar("T") +MultiParamModelSpec = tuple[type[BaseModel], ...] + + +def _param_models_name(models: MultiParamModelSpec) -> str: + return " | ".join(model_type.__name__ for model_type in models) + + +def _param_models_field_names(models: MultiParamModelSpec) -> tuple[str, ...]: + shared_fields = set(models[0].model_fields) + for model_type in models[1:]: + shared_fields &= set(model_type.model_fields) + return tuple(field_name for field_name in models[0].model_fields if field_name in shared_fields) + + +def model_to_kwargs(model_obj: BaseModel, models: MultiParamModelSpec) -> dict[str, Any]: + kwargs = { + field_name: getattr(model_obj, field_name) + for field_name in _param_models_field_names(models) + if field_name != "field_meta" + } + if meta := getattr(model_obj, "field_meta", None): + kwargs.update(meta) + return kwargs def serialize_params(params: BaseModel) -> dict[str, Any]: @@ -114,6 +137,18 @@ def decorator(func: MethodT) -> MethodT: return decorator +def param_models(*param_cls: type[BaseModel]) -> Callable[[MethodT], MethodT]: + """Decorator to mark a method as accepting multiple legacy parameter models.""" + if not param_cls: + raise ValueError("param_models() requires at least one model class") + + def decorator(func: MethodT) -> MethodT: + func.__param_models__ = param_cls # type: ignore[attr-defined] + return func + + return decorator + + def to_camel_case(snake_str: str) -> str: """Convert snake_case strings to camelCase.""" components = snake_str.split("_") @@ -129,7 +164,9 @@ def wrapped(self, params: BaseModel) -> T: DeprecationWarning, stacklevel=3, ) - kwargs = {k: getattr(params, k) for k in model.model_fields if k != "field_meta"} + kwargs = { + field_name: getattr(params, field_name) for field_name in model.model_fields if field_name != "field_meta" + } if meta := getattr(params, "field_meta", None): kwargs.update(meta) return func(self, **kwargs) # type: ignore[arg-type] @@ -152,7 +189,11 @@ def wrapped(self, *args: Any, **kwargs: Any) -> T: DeprecationWarning, stacklevel=3, ) - kwargs = {k: getattr(param, k) for k in model.model_fields if k != "field_meta"} + kwargs = { + field_name: getattr(param, field_name) + for field_name in model.model_fields + if field_name != "field_meta" + } if meta := getattr(param, "field_meta", None): kwargs.update(meta) return func(self, **kwargs) # type: ignore[arg-type] @@ -161,14 +202,67 @@ def wrapped(self, *args: Any, **kwargs: Any) -> T: return wrapped +def _make_multi_legacy_func(func: Callable[..., T], models: MultiParamModelSpec) -> Callable[[Any, BaseModel], T]: + model_name = _param_models_name(models) + + @functools.wraps(func) + def wrapped(self, params: BaseModel) -> T: + warnings.warn( + f"Calling {func.__name__} with {model_name} parameter is " # type: ignore[attr-defined] + "deprecated, please update to the new API style.", + DeprecationWarning, + stacklevel=3, + ) + return func(self, **model_to_kwargs(params, models)) # type: ignore[arg-type] + + return wrapped + + +def _make_multi_compatible_func(func: Callable[..., T], models: MultiParamModelSpec) -> Callable[..., T]: + model_name = _param_models_name(models) + + @functools.wraps(func) + def wrapped(self, *args: Any, **kwargs: Any) -> T: + param = None + if not kwargs and len(args) == 1: + param = args[0] + elif not args and len(kwargs) == 1: + param = kwargs.get("params") + if isinstance(param, models): + warnings.warn( + f"Calling {func.__name__} with {model_name} parameter " # type: ignore[attr-defined] + "is deprecated, please update to the new API style.", + DeprecationWarning, + stacklevel=3, + ) + return func(self, **model_to_kwargs(param, models)) # type: ignore[arg-type] + return func(self, *args, **kwargs) + + return wrapped + + def compatible_class(cls: ClassT) -> ClassT: """Mark a class as backward compatible with old API style.""" for attr in dir(cls): func = getattr(cls, attr) - if not callable(func) or (model := getattr(func, "__param_model__", None)) is None: + if not callable(func): + continue + model = getattr(func, "__param_model__", None) + models = getattr(func, "__param_models__", None) + if model is None and models is None: continue if "_" in attr: - setattr(cls, to_camel_case(attr), _make_legacy_func(func, model)) + if models is not None: + setattr(cls, to_camel_case(attr), _make_multi_legacy_func(func, models)) + else: + if model is None: + continue + setattr(cls, to_camel_case(attr), _make_legacy_func(func, model)) else: - setattr(cls, attr, _make_compatible_func(func, model)) + if models is not None: + setattr(cls, attr, _make_multi_compatible_func(func, models)) + else: + if model is None: + continue + setattr(cls, attr, _make_compatible_func(func, model)) return cls diff --git a/tests/conftest.py b/tests/conftest.py index 825610a..f154167 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -231,6 +231,7 @@ class TestAgent: def __init__(self) -> None: self.prompts: list[PromptRequest] = [] self.cancellations: list[str] = [] + self.config_option_calls: list[tuple[str, str, str | bool]] = [] self.ext_calls: list[tuple[str, dict]] = [] self.ext_notes: list[tuple[str, dict]] = [] @@ -267,9 +268,17 @@ async def prompt( | EmbeddedResourceContentBlock ], session_id: str, + message_id: str | None = None, **kwargs: Any, ) -> PromptResponse: - self.prompts.append(PromptRequest(prompt=prompt, session_id=session_id, field_meta=kwargs or None)) + self.prompts.append( + PromptRequest( + prompt=prompt, + session_id=session_id, + message_id=message_id, + field_meta=kwargs or None, + ) + ) return PromptResponse(stop_reason="end_turn") async def cancel(self, session_id: str, **kwargs: Any) -> None: @@ -284,8 +293,9 @@ async def set_session_mode(self, mode_id: str, session_id: str, **kwargs: Any) - return SetSessionModeResponse() async def set_config_option( - self, config_id: str, session_id: str, value: str, **kwargs: Any + self, config_id: str, session_id: str, value: str | bool, **kwargs: Any ) -> SetSessionConfigOptionResponse | None: + self.config_option_calls.append((config_id, session_id, value)) return SetSessionConfigOptionResponse(config_options=[]) async def ext_method(self, method: str, params: dict) -> dict: diff --git a/tests/test_compatibility.py b/tests/test_compatibility.py index 013427e..cdca9ad 100644 --- a/tests/test_compatibility.py +++ b/tests/test_compatibility.py @@ -11,6 +11,7 @@ RequestError, RequestPermissionResponse, SessionNotification, + SetSessionConfigOptionResponse, SetSessionModelResponse, SetSessionModeResponse, WriteTextFileResponse, @@ -25,6 +26,8 @@ NewSessionRequest, ReadTextFileRequest, RequestPermissionRequest, + SetSessionConfigOptionBooleanRequest, + SetSessionConfigOptionSelectRequest, SetSessionModelRequest, SetSessionModeRequest, WriteTextFileRequest, @@ -34,6 +37,9 @@ class LegacyAgent: def __init__(self) -> None: self.prompts: list[PromptRequest] = [] + self.config_option_requests: list[ + SetSessionConfigOptionBooleanRequest | SetSessionConfigOptionSelectRequest + ] = [] self.cancellations: list[str] = [] self.ext_calls: list[tuple[str, dict]] = [] self.ext_notes: list[tuple[str, dict]] = [] @@ -64,6 +70,12 @@ async def setSessionMode(self, params: SetSessionModeRequest) -> SetSessionModeR async def setSessionModel(self, params: SetSessionModelRequest) -> SetSessionModelResponse | None: return SetSessionModelResponse() + async def setConfigOption( + self, params: SetSessionConfigOptionBooleanRequest | SetSessionConfigOptionSelectRequest + ) -> SetSessionConfigOptionResponse | None: + self.config_option_requests.append(params) + return SetSessionConfigOptionResponse(config_options=[]) + async def extMethod(self, method: str, params: dict) -> dict: self.ext_calls.append((method, params)) if method == "example.com/echo": @@ -167,3 +179,34 @@ async def test_initialize_and_new_session_compat(connect, client): assert len(record) == 1 assert resp.content == "Hello, World!" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("agent,client", [(LegacyAgent(), LegacyClient())]) +async def test_set_config_option_boolean_compat(connect, agent): + _, agent_conn = connect() + + with pytest.warns(DeprecationWarning) as record: + resp = await agent_conn.setConfigOption( + SetSessionConfigOptionBooleanRequest( + config_id="brave_mode", + session_id="test-session-123", + type="boolean", + value=True, + ) + ) + + assert len(record) == 2 + assert "SetSessionConfigOptionBooleanRequest | SetSessionConfigOptionSelectRequest parameter is deprecated" in str( + record[0].message + ) + assert "The old style method LegacyAgent.setConfigOption is deprecated" in str(record[1].message) + assert isinstance(resp, SetSessionConfigOptionResponse) + assert agent.config_option_requests == [ + SetSessionConfigOptionBooleanRequest( + config_id="brave_mode", + session_id="test-session-123", + type="boolean", + value=True, + ) + ] diff --git a/tests/test_rpc.py b/tests/test_rpc.py index bcf068e..0d3bb75 100644 --- a/tests/test_rpc.py +++ b/tests/test_rpc.py @@ -262,6 +262,30 @@ async def test_set_config_option(connect, agent, client): resp = await agent_conn.set_config_option(session_id="sess", config_id="theme", value="dark") assert isinstance(resp, SetSessionConfigOptionResponse) assert resp.config_options == [] + assert agent.config_option_calls == [("theme", "sess", "dark")] + + +@pytest.mark.asyncio +async def test_set_config_option_boolean(connect, agent, client): + _, agent_conn = connect() + + resp = await agent_conn.set_config_option(session_id="sess", config_id="brave_mode", value=True) + assert isinstance(resp, SetSessionConfigOptionResponse) + assert resp.config_options == [] + assert agent.config_option_calls == [("brave_mode", "sess", True)] + + +@pytest.mark.asyncio +async def test_prompt_message_id_roundtrip(connect, agent, client): + _, agent_conn = connect() + + resp = await agent_conn.prompt( + session_id="sess", + prompt=[TextContentBlock(type="text", text="hello")], + message_id="123e4567-e89b-12d3-a456-426614174000", + ) + assert isinstance(resp, PromptResponse) + assert agent.prompts[-1].message_id == "123e4567-e89b-12d3-a456-426614174000" @pytest.mark.asyncio diff --git a/tests/test_utils.py b/tests/test_utils.py index 47706d9..bf00257 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,11 @@ import pytest -from acp.schema import AgentMessageChunk, TextContentBlock +from acp.schema import ( + AgentMessageChunk, + SetSessionConfigOptionBooleanRequest, + SetSessionConfigOptionSelectRequest, + TextContentBlock, +) from acp.utils import serialize_params @@ -40,6 +45,40 @@ def test_field_meta_can_be_set_by_name_on_models() -> None: assert chunk.content.field_meta == {"inner": "value"} +def test_serialize_params_uses_boolean_config_variant() -> None: + request = SetSessionConfigOptionBooleanRequest( + config_id="brave_mode", + session_id="sess", + type="boolean", + value=True, + ) + + payload = serialize_params(request) + + assert payload == { + "configId": "brave_mode", + "sessionId": "sess", + "type": "boolean", + "value": True, + } + + +def test_serialize_params_uses_select_config_variant() -> None: + request = SetSessionConfigOptionSelectRequest( + config_id="theme", + session_id="sess", + value="dark", + ) + + payload = serialize_params(request) + + assert payload == { + "configId": "theme", + "sessionId": "sess", + "value": "dark", + } + + @pytest.mark.parametrize( "original, expected", [