Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 37 additions & 8 deletions src/acp/agent/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from typing import Any

from pydantic import BaseModel

from ..exceptions import RequestError
from ..interfaces import Agent
from ..meta import AGENT_METHODS
from ..router import MessageRouter
from ..router import MessageRouter, Route, _resolve_handler, _warn_legacy_handler
from ..schema import (
AuthenticateRequest,
CancelNotification,
Expand All @@ -17,15 +19,41 @@
NewSessionRequest,
PromptRequest,
ResumeSessionRequest,
SetSessionConfigOptionBooleanRequest,
SetSessionConfigOptionSelectRequest,
SetSessionModelRequest,
SetSessionModeRequest,
)
from ..utils import normalize_result
from ..utils import model_to_kwargs, normalize_result

__all__ = ["build_agent_router"]


_SET_CONFIG_OPTION_MODELS = (SetSessionConfigOptionBooleanRequest, SetSessionConfigOptionSelectRequest)


def _validate_set_config_option_request(params: Any) -> BaseModel:
if isinstance(params, dict) and params.get("type") == "boolean":
return SetSessionConfigOptionBooleanRequest.model_validate(params)
return SetSessionConfigOptionSelectRequest.model_validate(params)


def _make_set_config_option_handler(agent: Agent) -> Any:
func, attr, legacy_api = _resolve_handler(agent, "set_config_option")
if func is None:
return None

async def wrapper(params: Any) -> Any:
if legacy_api:
_warn_legacy_handler(agent, attr)
request = _validate_set_config_option_request(params)
if legacy_api:
return await func(request)
return await func(**model_to_kwargs(request, _SET_CONFIG_OPTION_MODELS))

return wrapper


def build_agent_router(agent: Agent, use_unstable_protocol: bool = False) -> MessageRouter:
router = MessageRouter(use_unstable_protocol=use_unstable_protocol)

Expand Down Expand Up @@ -63,12 +91,13 @@ def build_agent_router(agent: Agent, use_unstable_protocol: bool = False) -> Mes
adapt_result=normalize_result,
unstable=True,
)
router.route_request(
AGENT_METHODS["session_set_config_option"],
SetSessionConfigOptionSelectRequest,
agent,
"set_config_option",
adapt_result=normalize_result,
router.add_route(
Route(
method=AGENT_METHODS["session_set_config_option"],
func=_make_set_config_option_handler(agent),
kind="request",
adapt_result=normalize_result,
)
)
router.route_request(
AGENT_METHODS["authenticate"],
Expand Down
34 changes: 27 additions & 7 deletions src/acp/client/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
ResourceContentBlock,
ResumeSessionRequest,
ResumeSessionResponse,
SetSessionConfigOptionBooleanRequest,
SetSessionConfigOptionResponse,
SetSessionConfigOptionSelectRequest,
SetSessionModelRequest,
Expand All @@ -44,7 +45,7 @@
SseMcpServer,
TextContentBlock,
)
from ..utils import compatible_class, notify_model, param_model, request_model, request_model_from_dict
from ..utils import compatible_class, notify_model, param_model, param_models, request_model, request_model_from_dict
from .router import build_client_router

__all__ = ["ClientSideConnection"]
Expand Down Expand Up @@ -154,16 +155,30 @@ async def set_session_model(self, model_id: str, session_id: str, **kwargs: Any)
SetSessionModelResponse,
)

@param_model(SetSessionConfigOptionSelectRequest)
@param_models(SetSessionConfigOptionBooleanRequest, SetSessionConfigOptionSelectRequest)
async def set_config_option(
self, config_id: str, session_id: str, value: str, **kwargs: Any
self, config_id: str, session_id: str, value: str | bool, **kwargs: Any
) -> SetSessionConfigOptionResponse:
request = (
SetSessionConfigOptionBooleanRequest(
config_id=config_id,
session_id=session_id,
type="boolean",
value=value,
field_meta=kwargs or None,
)
if isinstance(value, bool)
else SetSessionConfigOptionSelectRequest(
config_id=config_id,
session_id=session_id,
value=value,
field_meta=kwargs or None,
)
)
return await request_model_from_dict(
self._conn,
AGENT_METHODS["session_set_config_option"],
SetSessionConfigOptionSelectRequest(
config_id=config_id, session_id=session_id, value=value, field_meta=kwargs or None
),
request,
SetSessionConfigOptionResponse,
)

Expand Down Expand Up @@ -193,7 +208,12 @@ async def prompt(
return await request_model(
self._conn,
AGENT_METHODS["session_prompt"],
PromptRequest(prompt=prompt, session_id=session_id, field_meta=kwargs or None),
PromptRequest(
prompt=prompt,
session_id=session_id,
message_id=message_id,
field_meta=kwargs or None,
),
PromptResponse,
)

Expand Down
7 changes: 4 additions & 3 deletions src/acp/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
ResumeSessionResponse,
SessionInfoUpdate,
SessionNotification,
SetSessionConfigOptionBooleanRequest,
SetSessionConfigOptionResponse,
SetSessionConfigOptionSelectRequest,
SetSessionModelRequest,
Expand All @@ -70,7 +71,7 @@
WriteTextFileRequest,
WriteTextFileResponse,
)
from .utils import param_model
from .utils import param_model, param_models

__all__ = ["Agent", "Client"]

Expand Down Expand Up @@ -181,9 +182,9 @@ async def set_session_model(
self, model_id: str, session_id: str, **kwargs: Any
) -> SetSessionModelResponse | None: ...

@param_model(SetSessionConfigOptionSelectRequest)
@param_models(SetSessionConfigOptionBooleanRequest, SetSessionConfigOptionSelectRequest)
async def set_config_option(
self, config_id: str, session_id: str, value: str, **kwargs: Any
self, config_id: str, session_id: str, value: str | bool, **kwargs: Any
) -> SetSessionConfigOptionResponse | None: ...

@param_model(AuthenticateRequest)
Expand Down
52 changes: 31 additions & 21 deletions src/acp/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,34 @@
HandlerT = TypeVar("HandlerT", bound=RequestHandler)


def _warn_legacy_handler(obj: Any, attr: str) -> None:
warnings.warn(
f"The old style method {type(obj).__name__}.{attr} is deprecated, please update to the snake-cased form.",
DeprecationWarning,
stacklevel=3,
)


def _resolve_handler(obj: Any, attr: str) -> tuple[AsyncHandler | None, str, bool]:
legacy_api = False
func = getattr(obj, attr, None)
if func is None and "_" in attr:
attr = to_camel_case(attr)
func = getattr(obj, attr, None)
legacy_api = True
elif callable(func) and "_" not in attr:
original_func = func
if hasattr(func, "__func__"):
original_func = func.__func__
parameters = inspect.signature(original_func).parameters
if len(parameters) == 2 and "params" in parameters:
legacy_api = True

if func is None or not callable(func):
return None, attr, legacy_api
return func, attr, legacy_api


@dataclass(slots=True)
class Route:
method: str
Expand Down Expand Up @@ -63,31 +91,13 @@ def add_route(self, route: Route) -> None:
self._notifications[route.method] = route

def _make_func(self, model: type[BaseModel], obj: Any, attr: str) -> AsyncHandler | None:
legacy_api = False
func = getattr(obj, attr, None)
if func is None and "_" in attr:
attr = to_camel_case(attr)
func = getattr(obj, attr, None)
legacy_api = True
elif callable(func) and "_" not in attr:
original_func = func
if hasattr(func, "__func__"):
original_func = func.__func__
parameters = inspect.signature(original_func).parameters
if len(parameters) == 2 and "params" in parameters:
legacy_api = True

if func is None or not callable(func):
func, attr, legacy_api = _resolve_handler(obj, attr)
if func is None:
return None

async def wrapper(params: Any) -> Any:
if legacy_api:
warnings.warn(
f"The old style method {type(obj).__name__}.{attr} is deprecated, "
"please update to the snake-cased form.",
DeprecationWarning,
stacklevel=3,
)
_warn_legacy_handler(obj, attr)
model_obj = model.model_validate(params)
if legacy_api:
return await func(model_obj) # type: ignore[arg-type]
Expand Down
104 changes: 99 additions & 5 deletions src/acp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,29 @@
MethodT = TypeVar("MethodT", bound=Callable)
ClassT = TypeVar("ClassT", bound=type)
T = TypeVar("T")
MultiParamModelSpec = tuple[type[BaseModel], ...]


def _param_models_name(models: MultiParamModelSpec) -> str:
return " | ".join(model_type.__name__ for model_type in models)


def _param_models_field_names(models: MultiParamModelSpec) -> tuple[str, ...]:
shared_fields = set(models[0].model_fields)
for model_type in models[1:]:
shared_fields &= set(model_type.model_fields)
return tuple(field_name for field_name in models[0].model_fields if field_name in shared_fields)


def model_to_kwargs(model_obj: BaseModel, models: MultiParamModelSpec) -> dict[str, Any]:
kwargs = {
field_name: getattr(model_obj, field_name)
for field_name in _param_models_field_names(models)
if field_name != "field_meta"
}
if meta := getattr(model_obj, "field_meta", None):
kwargs.update(meta)
return kwargs


def serialize_params(params: BaseModel) -> dict[str, Any]:
Expand Down Expand Up @@ -114,6 +137,18 @@ def decorator(func: MethodT) -> MethodT:
return decorator


def param_models(*param_cls: type[BaseModel]) -> Callable[[MethodT], MethodT]:
"""Decorator to mark a method as accepting multiple legacy parameter models."""
if not param_cls:
raise ValueError("param_models() requires at least one model class")

def decorator(func: MethodT) -> MethodT:
func.__param_models__ = param_cls # type: ignore[attr-defined]
return func

return decorator


def to_camel_case(snake_str: str) -> str:
"""Convert snake_case strings to camelCase."""
components = snake_str.split("_")
Expand All @@ -129,7 +164,9 @@ def wrapped(self, params: BaseModel) -> T:
DeprecationWarning,
stacklevel=3,
)
kwargs = {k: getattr(params, k) for k in model.model_fields if k != "field_meta"}
kwargs = {
field_name: getattr(params, field_name) for field_name in model.model_fields if field_name != "field_meta"
}
if meta := getattr(params, "field_meta", None):
kwargs.update(meta)
return func(self, **kwargs) # type: ignore[arg-type]
Expand All @@ -152,7 +189,11 @@ def wrapped(self, *args: Any, **kwargs: Any) -> T:
DeprecationWarning,
stacklevel=3,
)
kwargs = {k: getattr(param, k) for k in model.model_fields if k != "field_meta"}
kwargs = {
field_name: getattr(param, field_name)
for field_name in model.model_fields
if field_name != "field_meta"
}
if meta := getattr(param, "field_meta", None):
kwargs.update(meta)
return func(self, **kwargs) # type: ignore[arg-type]
Expand All @@ -161,14 +202,67 @@ def wrapped(self, *args: Any, **kwargs: Any) -> T:
return wrapped


def _make_multi_legacy_func(func: Callable[..., T], models: MultiParamModelSpec) -> Callable[[Any, BaseModel], T]:
model_name = _param_models_name(models)

@functools.wraps(func)
def wrapped(self, params: BaseModel) -> T:
warnings.warn(
f"Calling {func.__name__} with {model_name} parameter is " # type: ignore[attr-defined]
"deprecated, please update to the new API style.",
DeprecationWarning,
stacklevel=3,
)
return func(self, **model_to_kwargs(params, models)) # type: ignore[arg-type]

return wrapped


def _make_multi_compatible_func(func: Callable[..., T], models: MultiParamModelSpec) -> Callable[..., T]:
model_name = _param_models_name(models)

@functools.wraps(func)
def wrapped(self, *args: Any, **kwargs: Any) -> T:
param = None
if not kwargs and len(args) == 1:
param = args[0]
elif not args and len(kwargs) == 1:
param = kwargs.get("params")
if isinstance(param, models):
warnings.warn(
f"Calling {func.__name__} with {model_name} parameter " # type: ignore[attr-defined]
"is deprecated, please update to the new API style.",
DeprecationWarning,
stacklevel=3,
)
return func(self, **model_to_kwargs(param, models)) # type: ignore[arg-type]
return func(self, *args, **kwargs)

return wrapped


def compatible_class(cls: ClassT) -> ClassT:
"""Mark a class as backward compatible with old API style."""
for attr in dir(cls):
func = getattr(cls, attr)
if not callable(func) or (model := getattr(func, "__param_model__", None)) is None:
if not callable(func):
continue
model = getattr(func, "__param_model__", None)
models = getattr(func, "__param_models__", None)
if model is None and models is None:
continue
if "_" in attr:
setattr(cls, to_camel_case(attr), _make_legacy_func(func, model))
if models is not None:
setattr(cls, to_camel_case(attr), _make_multi_legacy_func(func, models))
else:
if model is None:
continue
setattr(cls, to_camel_case(attr), _make_legacy_func(func, model))
else:
setattr(cls, attr, _make_compatible_func(func, model))
if models is not None:
setattr(cls, attr, _make_multi_compatible_func(func, models))
else:
if model is None:
continue
setattr(cls, attr, _make_compatible_func(func, model))
return cls
Loading
Loading