diff --git a/.fernignore b/.fernignore index 16abecd79..8fcc59fed 100644 --- a/.fernignore +++ b/.fernignore @@ -15,6 +15,7 @@ src/cohere/manually_maintained/__init__.py src/cohere/bedrock_client.py src/cohere/aws_client.py src/cohere/sagemaker_client.py +src/cohere/oci_client.py src/cohere/client_v2.py mypy.ini src/cohere/aliases.py \ No newline at end of file diff --git a/README.md b/README.md index c474bb632..7df774b8f 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,111 @@ for event in response: print(event.delta.message.content.text, end="") ``` +## Oracle Cloud Infrastructure (OCI) + +The SDK supports Oracle Cloud Infrastructure (OCI) Generative AI service. First, install the OCI SDK: + +``` +pip install 'cohere[oci]' +``` + +Then use the `OciClient` or `OciClientV2`: + +```Python +import cohere + +# Using OCI config file authentication (default: ~/.oci/config) +co = cohere.OciClient( + oci_region="us-chicago-1", + oci_compartment_id="ocid1.compartment.oc1...", +) + +response = co.embed( + model="embed-english-v3.0", + texts=["Hello world"], + input_type="search_document", +) + +print(response.embeddings) +``` + +### OCI Authentication Methods + +**1. Config File (Default)** +```Python +co = cohere.OciClient( + oci_region="us-chicago-1", + oci_compartment_id="ocid1.compartment.oc1...", + # Uses ~/.oci/config with DEFAULT profile +) +``` + +**2. Custom Profile** +```Python +co = cohere.OciClient( + oci_profile="MY_PROFILE", + oci_region="us-chicago-1", + oci_compartment_id="ocid1.compartment.oc1...", +) +``` + +**3. Session-based Authentication (Security Token)** +```Python +# Works with OCI CLI session tokens +co = cohere.OciClient( + oci_profile="MY_SESSION_PROFILE", # Profile with security_token_file + oci_region="us-chicago-1", + oci_compartment_id="ocid1.compartment.oc1...", +) +``` + +**4. 
Direct Credentials** +```Python +co = cohere.OciClient( + oci_user_id="ocid1.user.oc1...", + oci_fingerprint="xx:xx:xx:...", + oci_tenancy_id="ocid1.tenancy.oc1...", + oci_private_key_path="~/.oci/key.pem", + oci_region="us-chicago-1", + oci_compartment_id="ocid1.compartment.oc1...", +) +``` + +**5. Instance Principal (for OCI Compute instances)** +```Python +co = cohere.OciClient( + auth_type="instance_principal", + oci_region="us-chicago-1", + oci_compartment_id="ocid1.compartment.oc1...", +) +``` + +### Supported OCI APIs + +The OCI client supports the following Cohere APIs: +- **Embed**: Full support for all embedding models +- **Chat**: Full support with both V1 (`OciClient`) and V2 (`OciClientV2`) APIs + - Streaming available via `chat_stream()` + - Supports Command-R and Command-A model families + +### OCI Model Availability and Limitations + +**Available on OCI On-Demand Inference:** +- ✅ **Embed models**: available on OCI Generative AI +- ✅ **Chat models**: available via `OciClient` (V1) and `OciClientV2` (V2) + +**Not Available on OCI On-Demand Inference:** +- ❌ **Generate API**: OCI TEXT_GENERATION models are base models that require fine-tuning before deployment +- ❌ **Rerank API**: OCI TEXT_RERANK models are base models that require fine-tuning before deployment +- ❌ **Multiple Embedding Types**: OCI on-demand models only support single embedding type per request (cannot request both `float` and `int8` simultaneously) + +**Note**: To use Generate or Rerank models on OCI, you need to: +1. Fine-tune the base model using OCI's fine-tuning service +2. Deploy the fine-tuned model to a dedicated endpoint +3. Update your code to use the deployed model endpoint + +For the latest model availability, see the [OCI Generative AI documentation](https://docs.oracle.com/en-us/iaas/Content/generative-ai/home.htm). + ## Contributing While we value open-source contributions to this SDK, the code is generated programmatically. 
Additions made directly would have to be moved over to our generation code, otherwise they would be overwritten upon the next generated release. Feel free to open a PR as a proof of concept, but know that we will not be able to merge it as-is. We suggest opening an issue first to discuss with us! diff --git a/pyproject.toml b/pyproject.toml index 8abcedd7f..a55466399 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,10 @@ requests = "^2.0.0" tokenizers = ">=0.15,<1" types-requests = "^2.0.0" typing_extensions = ">= 4.0.0" +oci = { version = "^2.165.0", optional = true } + +[tool.poetry.extras] +oci = ["oci"] [tool.poetry.group.dev.dependencies] mypy = "==1.13.0" diff --git a/src/cohere/__init__.py b/src/cohere/__init__.py index fbadbfd07..6048798d1 100644 --- a/src/cohere/__init__.py +++ b/src/cohere/__init__.py @@ -518,6 +518,8 @@ "NotFoundError": ".errors", "NotImplementedError": ".errors", "OAuthAuthorizeResponse": ".types", + "OciClient": ".oci_client", + "OciClientV2": ".oci_client", "ParseInfo": ".types", "RerankDocument": ".types", "RerankRequestDocumentsItem": ".types", @@ -852,6 +854,8 @@ def __dir__(): "NotFoundError", "NotImplementedError", "OAuthAuthorizeResponse", + "OciClient", + "OciClientV2", "ParseInfo", "RerankDocument", "RerankRequestDocumentsItem", diff --git a/src/cohere/manually_maintained/lazy_oci_deps.py b/src/cohere/manually_maintained/lazy_oci_deps.py new file mode 100644 index 000000000..072d028b8 --- /dev/null +++ b/src/cohere/manually_maintained/lazy_oci_deps.py @@ -0,0 +1,30 @@ +"""Lazy loading for optional OCI SDK dependency.""" + +from typing import Any + +OCI_INSTALLATION_MESSAGE = """ +The OCI SDK is required to use OciClient or OciClientV2. + +Install it with: + pip install oci + +Or with the optional dependency group: + pip install cohere[oci] +""" + + +def lazy_oci() -> Any: + """ + Lazily import the OCI SDK. 
+ + Returns: + The oci module + + Raises: + ImportError: If the OCI SDK is not installed + """ + try: + import oci + return oci + except ImportError: + raise ImportError(OCI_INSTALLATION_MESSAGE) diff --git a/src/cohere/manually_maintained/streaming.py b/src/cohere/manually_maintained/streaming.py new file mode 100644 index 000000000..88e513a34 --- /dev/null +++ b/src/cohere/manually_maintained/streaming.py @@ -0,0 +1,15 @@ +import typing + +from httpx import SyncByteStream + + +class Streamer(SyncByteStream): + """Wrap an iterator of bytes for httpx streaming responses.""" + + lines: typing.Iterator[bytes] + + def __init__(self, lines: typing.Iterator[bytes]): + self.lines = lines + + def __iter__(self) -> typing.Iterator[bytes]: + return self.lines diff --git a/src/cohere/oci_client.py b/src/cohere/oci_client.py new file mode 100644 index 000000000..796223c52 --- /dev/null +++ b/src/cohere/oci_client.py @@ -0,0 +1,1181 @@ +"""Oracle Cloud Infrastructure (OCI) client for Cohere API.""" + +import configparser +import email.utils +import json +import os +import typing +import uuid + +import httpx +import requests +from .client import Client, ClientEnvironment +from .client_v2 import ClientV2 +from .manually_maintained.lazy_oci_deps import lazy_oci +from .manually_maintained.streaming import Streamer +from httpx import URL, ByteStream + + +class OciClient(Client): + """ + Cohere V1 API client for Oracle Cloud Infrastructure (OCI) Generative AI service. + + Use this client for V1 API models (Command R family) and embeddings. + For V2 API models (Command A family), use OciClientV2 instead. 
+ + Supported APIs on OCI: + - embed(): Full support for all embedding models + - chat(): Full support with Command-R models + - chat_stream(): Streaming chat support + + Supports all authentication methods: + - Config file (default): Uses ~/.oci/config + - Session-based: Uses OCI CLI session tokens + - Direct credentials: Pass OCI credentials directly + - Instance principal: For OCI compute instances + - Resource principal: For OCI functions + + Example: + ```python + import cohere + + client = cohere.OciClient( + oci_region="us-chicago-1", + oci_compartment_id="ocid1.compartment.oc1...", + ) + + response = client.chat( + model="command-r-08-2024", + message="Hello!", + ) + print(response.text) + ``` + """ + + def __init__( + self, + *, + oci_config_path: typing.Optional[str] = None, + oci_profile: typing.Optional[str] = None, + oci_user_id: typing.Optional[str] = None, + oci_fingerprint: typing.Optional[str] = None, + oci_tenancy_id: typing.Optional[str] = None, + oci_private_key_path: typing.Optional[str] = None, + oci_private_key_content: typing.Optional[str] = None, + auth_type: typing.Literal["api_key", "instance_principal", "resource_principal"] = "api_key", + oci_region: typing.Optional[str] = None, + oci_compartment_id: str, + timeout: typing.Optional[float] = None, + ): + oci_config = _load_oci_config( + auth_type=auth_type, + config_path=oci_config_path, + profile=oci_profile, + user_id=oci_user_id, + fingerprint=oci_fingerprint, + tenancy_id=oci_tenancy_id, + private_key_path=oci_private_key_path, + private_key_content=oci_private_key_content, + ) + + if oci_region is None: + oci_region = oci_config.get("region") + if oci_region is None: + raise ValueError("oci_region must be provided either directly or in OCI config file") + + Client.__init__( + self, + base_url="https://api.cohere.com", + environment=ClientEnvironment.PRODUCTION, + client_name="n/a", + timeout=timeout, + api_key="n/a", + httpx_client=httpx.Client( + event_hooks=get_event_hooks( + 
oci_config=oci_config, + oci_region=oci_region, + oci_compartment_id=oci_compartment_id, + is_v2_client=False, + ), + timeout=timeout, + ), + ) + + +class OciClientV2(ClientV2): + """ + Cohere V2 API client for Oracle Cloud Infrastructure (OCI) Generative AI service. + + Supported APIs on OCI: + - embed(): Full support for all embedding models (returns embeddings as dict) + - chat(): Full support with Command-A models (command-a-03-2025) + - chat_stream(): Streaming chat with proper V2 event format + + Note: rerank() requires fine-tuned models deployed to dedicated endpoints. + OCI on-demand inference does not support the rerank API. + + Supports all authentication methods: + - Config file (default): Uses ~/.oci/config + - Session-based: Uses OCI CLI session tokens + - Direct credentials: Pass OCI credentials directly + - Instance principal: For OCI compute instances + - Resource principal: For OCI functions + + Example using config file: + ```python + import cohere + + client = cohere.OciClientV2( + oci_region="us-chicago-1", + oci_compartment_id="ocid1.compartment.oc1...", + ) + + response = client.embed( + model="embed-english-v3.0", + texts=["Hello world"], + input_type="search_document", + ) + print(response.embeddings.float_) + + response = client.chat( + model="command-a-03-2025", + messages=[{"role": "user", "content": "Hello!"}], + ) + print(response.message) + ``` + + Example using direct credentials: + ```python + client = cohere.OciClientV2( + oci_user_id="ocid1.user.oc1...", + oci_fingerprint="xx:xx:xx:...", + oci_tenancy_id="ocid1.tenancy.oc1...", + oci_private_key_path="~/.oci/key.pem", + oci_region="us-chicago-1", + oci_compartment_id="ocid1.compartment.oc1...", + ) + ``` + + Example using instance principal: + ```python + client = cohere.OciClientV2( + auth_type="instance_principal", + oci_region="us-chicago-1", + oci_compartment_id="ocid1.compartment.oc1...", + ) + ``` + """ + + def __init__( + self, + *, + # Authentication - Config file (default) 
+ oci_config_path: typing.Optional[str] = None, + oci_profile: typing.Optional[str] = None, + # Authentication - Direct credentials + oci_user_id: typing.Optional[str] = None, + oci_fingerprint: typing.Optional[str] = None, + oci_tenancy_id: typing.Optional[str] = None, + oci_private_key_path: typing.Optional[str] = None, + oci_private_key_content: typing.Optional[str] = None, + # Authentication - Instance principal + auth_type: typing.Literal["api_key", "instance_principal", "resource_principal"] = "api_key", + # Required for OCI Generative AI + oci_region: typing.Optional[str] = None, + oci_compartment_id: str, + # Standard parameters + timeout: typing.Optional[float] = None, + ): + # Load OCI config based on auth_type + oci_config = _load_oci_config( + auth_type=auth_type, + config_path=oci_config_path, + profile=oci_profile, + user_id=oci_user_id, + fingerprint=oci_fingerprint, + tenancy_id=oci_tenancy_id, + private_key_path=oci_private_key_path, + private_key_content=oci_private_key_content, + ) + + # Get region from config if not provided + if oci_region is None: + oci_region = oci_config.get("region") + if oci_region is None: + raise ValueError("oci_region must be provided either directly or in OCI config file") + + # Create httpx client with OCI event hooks + ClientV2.__init__( + self, + base_url="https://api.cohere.com", # Unused, OCI URL set in hooks + environment=ClientEnvironment.PRODUCTION, + client_name="n/a", + timeout=timeout, + api_key="n/a", + httpx_client=httpx.Client( + event_hooks=get_event_hooks( + oci_config=oci_config, + oci_region=oci_region, + oci_compartment_id=oci_compartment_id, + is_v2_client=True, + ), + timeout=timeout, + ), + ) + + +EventHook = typing.Callable[..., typing.Any] + + +def _load_oci_config( + auth_type: str, + config_path: typing.Optional[str], + profile: typing.Optional[str], + **kwargs: typing.Any, +) -> typing.Dict[str, typing.Any]: + """ + Load OCI configuration based on authentication type. 
+ + Args: + auth_type: Authentication method (api_key, instance_principal, resource_principal) + config_path: Path to OCI config file (for api_key auth) + profile: Profile name in config file (for api_key auth) + **kwargs: Direct credentials (user_id, fingerprint, etc.) + + Returns: + Dictionary containing OCI configuration + """ + oci = lazy_oci() + + if auth_type == "instance_principal": + signer = oci.auth.signers.InstancePrincipalsSecurityTokenSigner() + return {"signer": signer, "auth_type": "instance_principal"} + + elif auth_type == "resource_principal": + signer = oci.auth.signers.get_resource_principals_signer() + return {"signer": signer, "auth_type": "resource_principal"} + + elif kwargs.get("user_id"): + # Direct credentials provided - validate required fields + required_fields = ["fingerprint", "tenancy_id"] + missing = [f for f in required_fields if not kwargs.get(f)] + if missing: + raise ValueError( + f"When providing oci_user_id, you must also provide: {', '.join('oci_' + f for f in missing)}" + ) + if not kwargs.get("private_key_path") and not kwargs.get("private_key_content"): + raise ValueError( + "When providing oci_user_id, you must also provide either " + "oci_private_key_path or oci_private_key_content" + ) + config = { + "user": kwargs["user_id"], + "fingerprint": kwargs["fingerprint"], + "tenancy": kwargs["tenancy_id"], + } + if kwargs.get("private_key_path"): + config["key_file"] = kwargs["private_key_path"] + if kwargs.get("private_key_content"): + config["key_content"] = kwargs["private_key_content"] + return config + + else: + # Load from config file + oci_config = oci.config.from_file( + file_location=config_path or "~/.oci/config", profile_name=profile or "DEFAULT" + ) + _remove_inherited_session_auth(oci_config, config_path=config_path, profile=profile) + return oci_config + + +def _remove_inherited_session_auth( + oci_config: typing.Dict[str, typing.Any], + *, + config_path: typing.Optional[str], + profile: typing.Optional[str], +) 
-> None: + """Drop session auth fields inherited from the OCI config DEFAULT section.""" + profile_name = profile or "DEFAULT" + if profile_name == "DEFAULT" or "security_token_file" not in oci_config: + return + + config_file = os.path.expanduser(config_path or "~/.oci/config") + parser = configparser.ConfigParser(interpolation=None) + if not parser.read(config_file): + return + + if not parser.has_section(profile_name): + oci_config.pop("security_token_file", None) + return + + explicit_security_token = False + current_section: typing.Optional[str] = None + with open(config_file, encoding="utf-8") as handle: + for raw_line in handle: + line = raw_line.strip() + if not line or line.startswith(("#", ";")): + continue + if line.startswith("[") and line.endswith("]"): + current_section = line[1:-1].strip() + continue + if current_section == profile_name and line.split("=", 1)[0].strip() == "security_token_file": + explicit_security_token = True + break + + if not explicit_security_token: + oci_config.pop("security_token_file", None) + + +def _usage_from_oci(usage_data: typing.Optional[typing.Dict[str, typing.Any]]) -> typing.Dict[str, typing.Any]: + usage_data = usage_data or {} + input_tokens = usage_data.get("inputTokens", 0) + output_tokens = usage_data.get("completionTokens", usage_data.get("outputTokens", 0)) + + return { + "tokens": { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + }, + "billed_units": { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + } + } + + +def get_event_hooks( + oci_config: typing.Dict[str, typing.Any], + oci_region: str, + oci_compartment_id: str, + is_v2_client: bool = False, +) -> typing.Dict[str, typing.List[EventHook]]: + """ + Create httpx event hooks for OCI request/response transformation. 
+ + Args: + oci_config: OCI configuration dictionary + oci_region: OCI region (e.g., "us-chicago-1") + oci_compartment_id: OCI compartment OCID + is_v2_client: Whether this is for OciClientV2 (True) or OciClient (False) + + Returns: + Dictionary of event hooks for httpx + """ + return { + "request": [ + map_request_to_oci( + oci_config=oci_config, + oci_region=oci_region, + oci_compartment_id=oci_compartment_id, + is_v2_client=is_v2_client, + ), + ], + "response": [map_response_from_oci()], + } + + +def map_request_to_oci( + oci_config: typing.Dict[str, typing.Any], + oci_region: str, + oci_compartment_id: str, + is_v2_client: bool = False, +) -> EventHook: + """ + Create event hook that transforms Cohere requests to OCI format and signs them. + + Args: + oci_config: OCI configuration dictionary + oci_region: OCI region + oci_compartment_id: OCI compartment OCID + is_v2_client: Whether this is for OciClientV2 (True) or OciClient (False) + + Returns: + Event hook function for httpx + """ + oci = lazy_oci() + + # Create OCI signer based on config type + # Priority order: instance/resource principal > session-based auth > API key auth + if "signer" in oci_config: + signer = oci_config["signer"] # Instance/resource principal + elif "security_token_file" in oci_config: + # Session-based authentication with security token. + # The token file is re-read on every request so that OCI CLI token refreshes + # (e.g. `oci session refresh`) are picked up without restarting the client. + key_file = oci_config.get("key_file") + if not key_file: + raise ValueError( + "OCI config profile is missing 'key_file'. " + "Session-based auth requires a key_file entry in your OCI config profile." 
+ ) + token_file_path = os.path.expanduser(oci_config["security_token_file"]) + private_key = oci.signer.load_private_key_from_file(os.path.expanduser(key_file)) + + class _RefreshingSecurityTokenSigner: + """Wraps SecurityTokenSigner and re-reads the token file before each signing call.""" + + def __init__(self) -> None: + self._token_file = token_file_path + self._private_key = private_key + self._refresh() + + def _refresh(self) -> None: + with open(self._token_file, "r") as _f: + _token = _f.read().strip() + self._signer = oci.auth.signers.SecurityTokenSigner( + token=_token, + private_key=self._private_key, + ) + + # Delegate all attribute access to the inner signer, refreshing first. + def __call__(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Any: + self._refresh() + return self._signer(*args, **kwargs) + + def __getattr__(self, name: str) -> typing.Any: + if name.startswith("_"): + raise AttributeError(name) + self._refresh() + return getattr(self._signer, name) + + signer = _RefreshingSecurityTokenSigner() + elif "user" in oci_config: + signer = oci.signer.Signer( + tenancy=oci_config["tenancy"], + user=oci_config["user"], + fingerprint=oci_config["fingerprint"], + private_key_file_location=oci_config.get("key_file"), + private_key_content=oci_config.get("key_content"), + ) + else: + # Config doesn't have user or security token - unsupported + raise ValueError( + "OCI config is missing 'user' field and no security_token_file found. " + "Please use a profile with standard API key authentication, " + "session-based authentication, or provide direct credentials via oci_user_id parameter." 
+ ) + + def _event_hook(request: httpx.Request) -> None: + # Extract Cohere API details + path_parts = request.url.path.split("/") + endpoint = path_parts[-1] + body = json.loads(request.read()) + + # Build OCI URL + url = get_oci_url( + region=oci_region, + endpoint=endpoint, + ) + + # Transform request body to OCI format + oci_body = transform_request_to_oci( + endpoint=endpoint, + cohere_body=body, + compartment_id=oci_compartment_id, + is_v2=is_v2_client, + ) + + # Prepare request for signing + oci_body_bytes = json.dumps(oci_body).encode("utf-8") + + # Build headers for signing + headers = { + "content-type": "application/json", + "date": email.utils.formatdate(usegmt=True), + } + + # Create a requests.PreparedRequest for OCI signing + oci_request = requests.Request( + method=request.method, + url=url, + headers=headers, + data=oci_body_bytes, + ) + prepped_request = oci_request.prepare() + + # Sign the request using OCI signer (modifies headers in place) + signer.do_request_sign(prepped_request) + + # Update httpx request with signed headers + request.url = URL(url) + request.headers = httpx.Headers(prepped_request.headers) + request.stream = ByteStream(oci_body_bytes) + request._content = oci_body_bytes + request.extensions["endpoint"] = endpoint + request.extensions["is_stream"] = body.get("stream", False) + request.extensions["is_v2"] = is_v2_client + + return _event_hook + + +def map_response_from_oci() -> EventHook: + """ + Create event hook that transforms OCI responses to Cohere format. 
+ + Returns: + Event hook function for httpx + """ + + def _hook(response: httpx.Response) -> None: + endpoint = response.request.extensions["endpoint"] + is_stream = response.request.extensions.get("is_stream", False) + is_v2 = response.request.extensions.get("is_v2", False) + + output: typing.Iterator[bytes] + + # Only transform successful responses (200-299) + # Let error responses pass through unchanged so SDK error handling works + if not (200 <= response.status_code < 300): + return + + # For streaming responses, wrap the stream with a transformer + if is_stream: + original_stream = response.stream + transformed_stream = transform_oci_stream_wrapper(original_stream, endpoint, is_v2) + response.stream = Streamer(transformed_stream) + # Reset consumption flags + if hasattr(response, "_content"): + del response._content + response.is_stream_consumed = False + response.is_closed = False + return + + # Handle non-streaming responses + oci_response = json.loads(response.read()) + cohere_response = transform_oci_response_to_cohere(endpoint, oci_response, is_v2) + output = iter([json.dumps(cohere_response).encode("utf-8")]) + + response.stream = Streamer(output) + + # Reset response for re-reading + if hasattr(response, "_content"): + del response._content + response.is_stream_consumed = False + response.is_closed = False + + return _hook + + +def get_oci_url( + region: str, + endpoint: str, +) -> str: + """ + Map Cohere endpoints to OCI Generative AI endpoints. + + Args: + region: OCI region (e.g., "us-chicago-1") + endpoint: Cohere endpoint name + Returns: + Full OCI Generative AI endpoint URL + """ + base = f"https://inference.generativeai.{region}.oci.oraclecloud.com" + api_version = "20231130" + + # Map Cohere endpoints to OCI actions + action_map = { + "embed": "embedText", + "chat": "chat", + } + + action = action_map.get(endpoint) + if action is None: + raise ValueError( + f"Endpoint '{endpoint}' is not supported by OCI Generative AI. 
" + f"Supported endpoints: {list(action_map.keys())}" + ) + return f"{base}/{api_version}/actions/{action}" + + +def normalize_model_for_oci(model: str) -> str: + """ + Normalize model name for OCI. + + OCI accepts model names in the format "cohere.model-name" or full OCIDs. + This function ensures proper formatting for all regions. + + Args: + model: Model name (e.g., "command-r-08-2024") or full OCID + + Returns: + Normalized model identifier (e.g., "cohere.command-r-08-2024" or OCID) + + Examples: + >>> normalize_model_for_oci("command-a-03-2025") + "cohere.command-a-03-2025" + >>> normalize_model_for_oci("cohere.embed-english-v3.0") + "cohere.embed-english-v3.0" + >>> normalize_model_for_oci("ocid1.generativeaimodel.oc1...") + "ocid1.generativeaimodel.oc1..." + """ + if not model: + raise ValueError("OCI requests require a non-empty model name") + + # If it's already an OCID, return as-is (works across all regions) + if model.startswith("ocid1."): + return model + + # Add "cohere." prefix if not present + if not model.startswith("cohere."): + return f"cohere.{model}" + + return model + + +def transform_request_to_oci( + endpoint: str, + cohere_body: typing.Dict[str, typing.Any], + compartment_id: str, + is_v2: bool = False, +) -> typing.Dict[str, typing.Any]: + """ + Transform Cohere request body to OCI format. 
+ + Args: + endpoint: Cohere endpoint name + cohere_body: Original Cohere request body + compartment_id: OCI compartment OCID + is_v2: Whether this request comes from OciClientV2 (True) or OciClient (False) + + Returns: + Transformed request body in OCI format + """ + model = normalize_model_for_oci(cohere_body.get("model")) + + if endpoint == "embed": + if "texts" in cohere_body: + inputs = cohere_body["texts"] + elif "inputs" in cohere_body: + inputs = cohere_body["inputs"] + elif "images" in cohere_body: + raise ValueError("OCI embed does not support the top-level 'images' parameter; use 'inputs' instead") + else: + raise ValueError("OCI embed requires either 'texts' or 'inputs'") + + oci_body = { + "inputs": inputs, + "servingMode": { + "servingType": "ON_DEMAND", + "modelId": model, + }, + "compartmentId": compartment_id, + } + + # Add optional fields only if provided + if "input_type" in cohere_body: + oci_body["inputType"] = cohere_body["input_type"].upper() + + if "truncate" in cohere_body: + oci_body["truncate"] = cohere_body["truncate"].upper() + + if "embedding_types" in cohere_body: + oci_body["embeddingTypes"] = [et.upper() for et in cohere_body["embedding_types"]] + if "max_tokens" in cohere_body: + oci_body["maxTokens"] = cohere_body["max_tokens"] + if "output_dimension" in cohere_body: + oci_body["outputDimension"] = cohere_body["output_dimension"] + if "priority" in cohere_body: + oci_body["priority"] = cohere_body["priority"] + + return oci_body + + elif endpoint == "chat": + # Validate that the request body matches the client type + has_messages = "messages" in cohere_body + has_message = "message" in cohere_body + if is_v2 and not has_messages: + raise ValueError( + "OciClientV2 requires the V2 API format ('messages' array). " + "Got a V1-style request with 'message' string. " + "Use OciClient for V1 models like Command R, " + "or switch to the V2 messages format." 
+ ) + if not is_v2 and has_messages and not has_message: + raise ValueError( + "OciClient uses the V1 API format (single 'message' string). " + "Got a V2-style request with 'messages' array. " + "Use OciClientV2 for V2 models like Command A." + ) + + chat_request: typing.Dict[str, typing.Any] = { + "apiFormat": "COHEREV2" if is_v2 else "COHERE", + } + + if is_v2: + # V2: Transform Cohere V2 messages to OCI V2 format + # Cohere sends: [{"role": "user", "content": "text"}] + # OCI expects: [{"role": "USER", "content": [{"type": "TEXT", "text": "..."}]}] + oci_messages = [] + for msg in cohere_body["messages"]: + oci_msg: typing.Dict[str, typing.Any] = { + "role": msg["role"].upper(), + } + + # Transform content + if isinstance(msg.get("content"), str): + oci_msg["content"] = [{"type": "TEXT", "text": msg["content"]}] + elif isinstance(msg.get("content"), list): + transformed_content = [] + for item in msg["content"]: + if isinstance(item, dict) and "type" in item: + transformed_item = item.copy() + transformed_item["type"] = item["type"].upper() + transformed_content.append(transformed_item) + else: + transformed_content.append(item) + oci_msg["content"] = transformed_content + else: + oci_msg["content"] = msg.get("content") or [] + + if "tool_calls" in msg: + oci_msg["toolCalls"] = msg["tool_calls"] + if "tool_call_id" in msg: + oci_msg["toolCallId"] = msg["tool_call_id"] + if "tool_plan" in msg: + oci_msg["toolPlan"] = msg["tool_plan"] + + oci_messages.append(oci_msg) + + chat_request["messages"] = oci_messages + + # V2 optional parameters + if "max_tokens" in cohere_body: + chat_request["maxTokens"] = cohere_body["max_tokens"] + if "temperature" in cohere_body: + chat_request["temperature"] = cohere_body["temperature"] + if "k" in cohere_body: + chat_request["topK"] = cohere_body["k"] + if "p" in cohere_body: + chat_request["topP"] = cohere_body["p"] + if "seed" in cohere_body: + chat_request["seed"] = cohere_body["seed"] + if "frequency_penalty" in cohere_body: + 
chat_request["frequencyPenalty"] = cohere_body["frequency_penalty"] + if "presence_penalty" in cohere_body: + chat_request["presencePenalty"] = cohere_body["presence_penalty"] + if "stop_sequences" in cohere_body: + chat_request["stopSequences"] = cohere_body["stop_sequences"] + if "tools" in cohere_body: + chat_request["tools"] = cohere_body["tools"] + if "strict_tools" in cohere_body: + chat_request["strictTools"] = cohere_body["strict_tools"] + if "documents" in cohere_body: + chat_request["documents"] = cohere_body["documents"] + if "citation_options" in cohere_body: + chat_request["citationOptions"] = cohere_body["citation_options"] + if "response_format" in cohere_body: + chat_request["responseFormat"] = cohere_body["response_format"] + if "safety_mode" in cohere_body: + chat_request["safetyMode"] = cohere_body["safety_mode"] + if "logprobs" in cohere_body: + chat_request["logprobs"] = cohere_body["logprobs"] + if "tool_choice" in cohere_body: + chat_request["toolChoice"] = cohere_body["tool_choice"] + if "priority" in cohere_body: + chat_request["priority"] = cohere_body["priority"] + # Thinking parameter for Command A Reasoning models + if "thinking" in cohere_body and cohere_body["thinking"] is not None: + thinking = cohere_body["thinking"] + oci_thinking: typing.Dict[str, typing.Any] = {} + if "type" in thinking: + oci_thinking["type"] = thinking["type"].upper() + if "token_budget" in thinking and thinking["token_budget"] is not None: + oci_thinking["tokenBudget"] = thinking["token_budget"] + if oci_thinking: + chat_request["thinking"] = oci_thinking + else: + # V1: single message string + chat_request["message"] = cohere_body["message"] + + if "temperature" in cohere_body: + chat_request["temperature"] = cohere_body["temperature"] + if "max_tokens" in cohere_body: + chat_request["maxTokens"] = cohere_body["max_tokens"] + if "k" in cohere_body: + chat_request["topK"] = cohere_body["k"] + if "p" in cohere_body: + chat_request["topP"] = cohere_body["p"] + 
if "seed" in cohere_body: + chat_request["seed"] = cohere_body["seed"] + if "stop_sequences" in cohere_body: + chat_request["stopSequences"] = cohere_body["stop_sequences"] + if "frequency_penalty" in cohere_body: + chat_request["frequencyPenalty"] = cohere_body["frequency_penalty"] + if "presence_penalty" in cohere_body: + chat_request["presencePenalty"] = cohere_body["presence_penalty"] + if "preamble" in cohere_body: + chat_request["preambleOverride"] = cohere_body["preamble"] + if "chat_history" in cohere_body: + chat_request["chatHistory"] = cohere_body["chat_history"] + if "documents" in cohere_body: + chat_request["documents"] = cohere_body["documents"] + if "tools" in cohere_body: + chat_request["tools"] = cohere_body["tools"] + if "tool_results" in cohere_body: + chat_request["toolResults"] = cohere_body["tool_results"] + if "response_format" in cohere_body: + chat_request["responseFormat"] = cohere_body["response_format"] + if "safety_mode" in cohere_body: + chat_request["safetyMode"] = cohere_body["safety_mode"] + if "priority" in cohere_body: + chat_request["priority"] = cohere_body["priority"] + + # Handle streaming for both versions + if cohere_body.get("stream"): + chat_request["isStream"] = True + + # Top level OCI request structure + oci_body = { + "servingMode": { + "servingType": "ON_DEMAND", + "modelId": model, + }, + "compartmentId": compartment_id, + "chatRequest": chat_request, + } + + return oci_body + + raise ValueError( + f"Endpoint '{endpoint}' is not supported by OCI Generative AI on-demand inference. " + "Supported endpoints: ['embed', 'chat']" + ) + + +def transform_oci_response_to_cohere( + endpoint: str, oci_response: typing.Dict[str, typing.Any], is_v2: bool = False, +) -> typing.Dict[str, typing.Any]: + """ + Transform OCI response to Cohere format. 
+ + Args: + endpoint: Cohere endpoint name + oci_response: OCI response body + is_v2: Whether this is a V2 API response + + Returns: + Transformed response in Cohere format + """ + if endpoint == "embed": + embeddings_data = oci_response.get("embeddings", {}) + + if isinstance(embeddings_data, dict): + normalized_embeddings = {str(key).lower(): value for key, value in embeddings_data.items()} + else: + normalized_embeddings = {"float": embeddings_data} + + if is_v2: + embeddings = normalized_embeddings + else: + embeddings = normalized_embeddings.get("float", []) + + meta = { + "api_version": {"version": "1"}, + } + usage = _usage_from_oci(oci_response.get("usage")) + if "tokens" in usage: + meta["tokens"] = usage["tokens"] + if "billed_units" in usage: + meta["billed_units"] = usage["billed_units"] + + return { + "id": oci_response.get("id", str(uuid.uuid4())), + "embeddings": embeddings, + "texts": [], + "meta": meta, + } + + elif endpoint == "chat": + chat_response = oci_response.get("chatResponse", {}) + + if is_v2: + usage = _usage_from_oci(chat_response.get("usage")) + message = chat_response.get("message", {}) + + if "role" in message: + message = {**message, "role": message["role"].lower()} + + if "content" in message and isinstance(message["content"], list): + transformed_content = [] + for item in message["content"]: + if isinstance(item, dict): + transformed_item = item.copy() + if "type" in transformed_item: + transformed_item["type"] = transformed_item["type"].lower() + transformed_content.append(transformed_item) + else: + transformed_content.append(item) + message = {**message, "content": transformed_content} + + if "toolCalls" in message: + tool_calls = message["toolCalls"] + message = {k: v for k, v in message.items() if k != "toolCalls"} + message["tool_calls"] = tool_calls + if "toolPlan" in message: + tool_plan = message["toolPlan"] + message = {k: v for k, v in message.items() if k != "toolPlan"} + message["tool_plan"] = tool_plan + + return { 
+ "id": chat_response.get("id", str(uuid.uuid4())), + "message": message, + "finish_reason": chat_response.get("finishReason", "COMPLETE"), + "usage": usage, + } + + # V1 response + meta = { + "api_version": {"version": "1"}, + } + usage = _usage_from_oci(chat_response.get("usage")) + if "tokens" in usage: + meta["tokens"] = usage["tokens"] + if "billed_units" in usage: + meta["billed_units"] = usage["billed_units"] + + return { + "text": chat_response.get("text", ""), + "generation_id": str(uuid.uuid4()), + "chat_history": chat_response.get("chatHistory", []), + "finish_reason": chat_response.get("finishReason", "COMPLETE"), + "citations": chat_response.get("citations", []), + "documents": chat_response.get("documents", []), + "search_queries": chat_response.get("searchQueries", []), + "meta": meta, + } + + return oci_response + + +def transform_oci_stream_wrapper( + stream: typing.Iterator[bytes], endpoint: str, is_v2: bool = False, +) -> typing.Iterator[bytes]: + """ + Wrap OCI stream and transform events to Cohere format. 
+ + Args: + stream: Original OCI stream iterator + endpoint: Cohere endpoint name + is_v2: Whether this is a V2 API stream + + Yields: + Bytes of transformed streaming events + """ + generation_id = str(uuid.uuid4()) + emitted_start = False + emitted_content_end = False + current_content_type: typing.Optional[str] = None + current_content_index = 0 + final_finish_reason = "COMPLETE" + final_usage: typing.Optional[typing.Dict[str, typing.Any]] = None + full_v1_text = "" + final_v1_finish_reason = "COMPLETE" + buffer = b"" + + def _emit_v2_event(event: typing.Dict[str, typing.Any]) -> bytes: + return b"data: " + json.dumps(event).encode("utf-8") + b"\n\n" + + def _emit_v1_event(event: typing.Dict[str, typing.Any]) -> bytes: + return json.dumps(event).encode("utf-8") + b"\n" + + def _current_content_type(oci_event: typing.Dict[str, typing.Any]) -> typing.Optional[str]: + message = oci_event.get("message") + if isinstance(message, dict): + content_list = message.get("content") + if content_list and isinstance(content_list, list) and len(content_list) > 0: + oci_type = content_list[0].get("type", "TEXT").upper() + return "thinking" if oci_type == "THINKING" else "text" + return None # finish-only or non-content event — don't trigger a type transition + + def _transform_v2_event(oci_event: typing.Dict[str, typing.Any]) -> typing.Iterator[bytes]: + nonlocal emitted_start, emitted_content_end, current_content_type, current_content_index + nonlocal final_finish_reason, final_usage + + event_content_type = _current_content_type(oci_event) + open_type = event_content_type or "text" + + if not emitted_start: + yield _emit_v2_event( + { + "type": "message-start", + "id": generation_id, + "delta": {"message": {"role": "assistant"}}, + } + ) + yield _emit_v2_event( + { + "type": "content-start", + "index": current_content_index, + "delta": {"message": {"content": {"type": open_type}}}, + } + ) + emitted_start = True + current_content_type = open_type + elif event_content_type is 
not None and current_content_type != event_content_type: + yield _emit_v2_event({"type": "content-end", "index": current_content_index}) + current_content_index += 1 + yield _emit_v2_event( + { + "type": "content-start", + "index": current_content_index, + "delta": {"message": {"content": {"type": event_content_type}}}, + } + ) + current_content_type = event_content_type + emitted_content_end = False + + for cohere_event in typing.cast( + typing.List[typing.Dict[str, typing.Any]], transform_stream_event(endpoint, oci_event, is_v2=True) + ): + if "index" in cohere_event: + cohere_event = {**cohere_event, "index": current_content_index} + if cohere_event["type"] == "content-end": + emitted_content_end = True + final_finish_reason = oci_event.get("finishReason", final_finish_reason) + final_usage = _usage_from_oci(oci_event.get("usage")) + yield _emit_v2_event(cohere_event) + + def _transform_v1_event(oci_event: typing.Dict[str, typing.Any]) -> bytes: + nonlocal full_v1_text, final_v1_finish_reason + event = transform_stream_event(endpoint, oci_event, is_v2=False) + if isinstance(event, dict): + if event.get("event_type") == "text-generation" and event.get("text"): + full_v1_text += typing.cast(str, event["text"]) + if "finishReason" in oci_event: + final_v1_finish_reason = oci_event.get("finishReason", final_v1_finish_reason) + return _emit_v1_event(event) + return b"" + + def _process_line(line: str) -> typing.Iterator[bytes]: + if not line.startswith("data: "): + return + + data_str = line[6:] + if data_str.strip() == "[DONE]": + if is_v2: + if emitted_start: + if not emitted_content_end: + yield _emit_v2_event({"type": "content-end", "index": current_content_index}) + message_end_event: typing.Dict[str, typing.Any] = { + "type": "message-end", + "id": generation_id, + "delta": {"finish_reason": final_finish_reason}, + } + if final_usage: + message_end_event["delta"]["usage"] = final_usage + yield _emit_v2_event(message_end_event) + else: + yield _emit_v1_event( + 
{ + "event_type": "stream-end", + "finish_reason": final_v1_finish_reason, + "response": { + "text": full_v1_text, + "generation_id": generation_id, + "finish_reason": final_v1_finish_reason, + }, + } + ) + return + + try: + oci_event = json.loads(data_str) + except json.JSONDecodeError: + return + + try: + if is_v2: + for event_bytes in _transform_v2_event(oci_event): + yield event_bytes + else: + yield _transform_v1_event(oci_event) + except Exception as exc: + raise RuntimeError(f"OCI stream event transformation failed for endpoint '{endpoint}': {exc}") from exc + + for chunk in stream: + buffer += chunk + while b"\n" in buffer: + line_bytes, buffer = buffer.split(b"\n", 1) + line = line_bytes.decode("utf-8").strip() + for event_bytes in _process_line(line): + yield event_bytes + + if buffer.strip(): + line = buffer.decode("utf-8").strip() + for event_bytes in _process_line(line): + yield event_bytes + + +def transform_stream_event( + endpoint: str, oci_event: typing.Dict[str, typing.Any], is_v2: bool = False, +) -> typing.Union[typing.Dict[str, typing.Any], typing.List[typing.Dict[str, typing.Any]]]: + """ + Transform individual OCI stream event to Cohere format. + + Args: + endpoint: Cohere endpoint name + oci_event: OCI stream event + is_v2: Whether this is a V2 API stream + + Returns: + V2: List of transformed events. V1: Single transformed event dict. 
+ """ + if endpoint == "chat": + if is_v2: + content_type = "text" + content_value = "" + message = oci_event.get("message") + + if "message" in oci_event and not isinstance(message, dict): + raise TypeError("OCI V2 stream event message must be an object") + + if isinstance(message, dict) and "content" in message: + content_list = message["content"] + if content_list and isinstance(content_list, list) and len(content_list) > 0: + first_content = content_list[0] + oci_type = first_content.get("type", "TEXT").upper() + if oci_type == "THINKING": + content_type = "thinking" + content_value = first_content.get("thinking", "") + else: + content_type = "text" + content_value = first_content.get("text", "") + + events: typing.List[typing.Dict[str, typing.Any]] = [] + if content_value: + delta_content: typing.Dict[str, typing.Any] = {} + if content_type == "thinking": + delta_content["thinking"] = content_value + else: + delta_content["text"] = content_value + + events.append( + { + "type": "content-delta", + "index": 0, + "delta": { + "message": { + "content": delta_content, + } + }, + } + ) + + if "finishReason" in oci_event: + events.append( + { + "type": "content-end", + "index": 0, + } + ) + + return events + + # V1 stream event + return { + "event_type": "text-generation", + "text": oci_event.get("text", ""), + "is_finished": oci_event.get("isFinished", False), + } + + return [] if is_v2 else {} diff --git a/tests/test_oci_client.py b/tests/test_oci_client.py new file mode 100644 index 000000000..45a682213 --- /dev/null +++ b/tests/test_oci_client.py @@ -0,0 +1,1194 @@ +"""Integration and unit tests for OCI Generative AI client. + +All integration tests are validated against the live OCI Generative AI inference +layer (us-chicago-1). 
The OciClientV2 uses the V2 Cohere API format (COHEREV2) +and communicates with the OCI inference endpoint at: + https://inference.generativeai.{region}.oci.oraclecloud.com + +Integration test coverage: + + V1 API (OciClient — Command R family): + Test Model What it proves + ------------------------------- -------------------------- ------------------------------------------ + test_embed embed-english-v3.0 V1 embed returns 2x 1024-dim float vectors + test_chat command-r-08-2024 V1 chat returns text with COHERE apiFormat + test_chat_stream command-r-08-2024 V1 streaming with text-generation events + + V2 API (OciClientV2 — Command A family): + Test Model What it proves + ------------------------------- -------------------------- ------------------------------------------ + test_embed_v2 embed-english-v3.0 V2 embed returns dict with float_ key + test_embed_with_model_prefix_v2 cohere.embed-english-v3.0 Model normalization works + test_chat_v2 command-a-03-2025 V2 chat returns message with COHEREV2 format + test_chat_stream_v2 command-a-03-2025 V2 SSE streaming with content-delta events + test_command_a_chat command-a-03-2025 Command A chat via V2 + + Cross-cutting: + Test Model What it proves + ------------------------------- -------------------------- ------------------------------------------ + test_config_file_auth embed-english-v3.0 API key auth from config file + test_custom_profile_auth embed-english-v3.0 Custom OCI profile auth + test_embed_english_v3 embed-english-v3.0 1024-dim embeddings + test_embed_multilingual_v3 embed-multilingual-v3.0 Multilingual model works + test_invalid_model invalid-model-name Error handling works + test_missing_compartment_id -- Raises TypeError + +Requirements: +1. OCI SDK installed: pip install oci +2. OCI credentials configured in ~/.oci/config +3. TEST_OCI environment variable set to run +4. OCI_COMPARTMENT_ID environment variable with valid OCI compartment OCID +5. 
OCI_REGION environment variable (optional, defaults to us-chicago-1) + +Run with: + TEST_OCI=1 OCI_COMPARTMENT_ID=ocid1.compartment.oc1... pytest tests/test_oci_client.py +""" + +import os +import sys +import tempfile +import types +import unittest +from unittest.mock import MagicMock, mock_open, patch + +import cohere + +if "tokenizers" not in sys.modules: + tokenizers_stub = types.ModuleType("tokenizers") + tokenizers_stub.Tokenizer = object + sys.modules["tokenizers"] = tokenizers_stub + +if "fastavro" not in sys.modules: + fastavro_stub = types.ModuleType("fastavro") + fastavro_stub.parse_schema = lambda schema: schema + fastavro_stub.reader = lambda *args, **kwargs: iter(()) + fastavro_stub.writer = lambda *args, **kwargs: None + sys.modules["fastavro"] = fastavro_stub + +if "httpx_sse" not in sys.modules: + httpx_sse_stub = types.ModuleType("httpx_sse") + httpx_sse_stub.connect_sse = lambda *args, **kwargs: None + sys.modules["httpx_sse"] = httpx_sse_stub + + +@unittest.skipIf(os.getenv("TEST_OCI") is None, "TEST_OCI not set") +class TestOciClient(unittest.TestCase): + """Test OciClient (V1 API) with OCI Generative AI.""" + + def setUp(self): + compartment_id = os.getenv("OCI_COMPARTMENT_ID") + if not compartment_id: + self.skipTest("OCI_COMPARTMENT_ID not set") + + region = os.getenv("OCI_REGION", "us-chicago-1") + profile = os.getenv("OCI_PROFILE", "DEFAULT") + + self.client = cohere.OciClient( + oci_region=region, + oci_compartment_id=compartment_id, + oci_profile=profile, + ) + + def test_embed(self): + """Test embedding with V1 client.""" + response = self.client.embed( + model="embed-english-v3.0", + texts=["Hello world", "Cohere on OCI"], + input_type="search_document", + ) + self.assertIsNotNone(response) + self.assertIsNotNone(response.embeddings) + self.assertEqual(len(response.embeddings), 2) + self.assertEqual(len(response.embeddings[0]), 1024) + + def test_chat(self): + """Test V1 chat with Command R.""" + response = self.client.chat( + 
model="command-r-08-2024", + message="What is 2+2? Answer with just the number.", + ) + self.assertIsNotNone(response) + self.assertIsNotNone(response.text) + self.assertIn("4", response.text) + + def test_chat_stream(self): + """Test V1 streaming chat.""" + events = [] + for event in self.client.chat_stream( + model="command-r-08-2024", + message="Count from 1 to 3.", + ): + events.append(event) + + self.assertTrue(len(events) > 0) + text_events = [e for e in events if hasattr(e, "text") and e.text] + self.assertTrue(len(text_events) > 0) + + +@unittest.skipIf(os.getenv("TEST_OCI") is None, "TEST_OCI not set") +class TestOciClientV2(unittest.TestCase): + """Test OciClientV2 (v2 API) with OCI Generative AI.""" + + def setUp(self): + """Set up OCI v2 client for each test.""" + compartment_id = os.getenv("OCI_COMPARTMENT_ID") + if not compartment_id: + self.skipTest("OCI_COMPARTMENT_ID not set") + + region = os.getenv("OCI_REGION", "us-chicago-1") + profile = os.getenv("OCI_PROFILE", "DEFAULT") + + self.client = cohere.OciClientV2( + oci_region=region, + oci_compartment_id=compartment_id, + oci_profile=profile, + ) + + def test_embed_v2(self): + """Test embedding with v2 client.""" + response = self.client.embed( + model="embed-english-v3.0", + texts=["Hello from v2", "Second text"], + input_type="search_document", + ) + + self.assertIsNotNone(response) + self.assertIsNotNone(response.embeddings) + # V2 returns embeddings as a dict with "float" key + self.assertIsNotNone(response.embeddings.float_) + self.assertEqual(len(response.embeddings.float_), 2) + # Verify embedding dimensions (1024 for embed-english-v3.0) + self.assertEqual(len(response.embeddings.float_[0]), 1024) + + def test_embed_with_model_prefix_v2(self): + """Test embedding with 'cohere.' 
model prefix on v2 client.""" + response = self.client.embed( + model="cohere.embed-english-v3.0", + texts=["Test with prefix"], + input_type="search_document", + ) + + self.assertIsNotNone(response) + self.assertIsNotNone(response.embeddings) + self.assertIsNotNone(response.embeddings.float_) + self.assertEqual(len(response.embeddings.float_), 1) + + def test_chat_v2(self): + """Test chat with v2 client.""" + response = self.client.chat( + model="command-a-03-2025", + messages=[{"role": "user", "content": "Say hello"}], + ) + + self.assertIsNotNone(response) + self.assertIsNotNone(response.message) + + def test_chat_stream_v2(self): + """Test streaming chat with v2 client.""" + events = [] + for event in self.client.chat_stream( + model="command-a-03-2025", + messages=[{"role": "user", "content": "Count from 1 to 3"}], + ): + events.append(event) + + self.assertTrue(len(events) > 0) + # Verify we received content-delta events with text + content_delta_events = [e for e in events if hasattr(e, "type") and e.type == "content-delta"] + self.assertTrue(len(content_delta_events) > 0) + + # Verify we can extract text from events + full_text = "" + for event in events: + if ( + hasattr(event, "delta") + and event.delta + and hasattr(event.delta, "message") + and event.delta.message + and hasattr(event.delta.message, "content") + and event.delta.message.content + and hasattr(event.delta.message.content, "text") + and event.delta.message.content.text is not None + ): + full_text += event.delta.message.content.text + + # Should have received some text + self.assertTrue(len(full_text) > 0) + +@unittest.skipIf(os.getenv("TEST_OCI") is None, "TEST_OCI not set") +class TestOciClientAuthentication(unittest.TestCase): + """Test different OCI authentication methods.""" + + def test_config_file_auth(self): + """Test authentication using OCI config file.""" + compartment_id = os.getenv("OCI_COMPARTMENT_ID") + if not compartment_id: + self.skipTest("OCI_COMPARTMENT_ID not set") + + 
profile = os.getenv("OCI_PROFILE", "DEFAULT") + client = cohere.OciClientV2( + oci_region="us-chicago-1", + oci_compartment_id=compartment_id, + oci_profile=profile, + ) + + # Test with a simple embed call + response = client.embed( + model="embed-english-v3.0", + texts=["Auth test"], + input_type="search_document", + ) + + self.assertIsNotNone(response) + self.assertIsNotNone(response.embeddings) + + def test_custom_profile_auth(self): + """Test authentication using custom OCI profile.""" + compartment_id = os.getenv("OCI_COMPARTMENT_ID") + profile = os.getenv("OCI_PROFILE", "DEFAULT") + + if not compartment_id: + self.skipTest("OCI_COMPARTMENT_ID not set") + + client = cohere.OciClientV2( + oci_profile=profile, + oci_region="us-chicago-1", + oci_compartment_id=compartment_id, + ) + + response = client.embed( + model="embed-english-v3.0", + texts=["Profile auth test"], + input_type="search_document", + ) + + self.assertIsNotNone(response) + + +@unittest.skipIf(os.getenv("TEST_OCI") is None, "TEST_OCI not set") +class TestOciClientErrors(unittest.TestCase): + """Test error handling in OCI client.""" + + def test_missing_compartment_id(self): + """Test error when compartment ID is missing.""" + with self.assertRaises(TypeError): + cohere.OciClientV2( + oci_region="us-chicago-1", + # Missing oci_compartment_id + ) + + def test_invalid_model(self): + """Test error handling with invalid model.""" + compartment_id = os.getenv("OCI_COMPARTMENT_ID") + if not compartment_id: + self.skipTest("OCI_COMPARTMENT_ID not set") + + profile = os.getenv("OCI_PROFILE", "DEFAULT") + client = cohere.OciClientV2( + oci_region="us-chicago-1", + oci_compartment_id=compartment_id, + oci_profile=profile, + ) + + # OCI should return an error for invalid model + with self.assertRaises(Exception): + client.embed( + model="invalid-model-name", + texts=["Test"], + input_type="search_document", + ) + + +@unittest.skipIf(os.getenv("TEST_OCI") is None, "TEST_OCI not set") +class 
TestOciClientModels(unittest.TestCase): + """Test different Cohere models on OCI.""" + + def setUp(self): + """Set up OCI client for each test.""" + compartment_id = os.getenv("OCI_COMPARTMENT_ID") + if not compartment_id: + self.skipTest("OCI_COMPARTMENT_ID not set") + + region = os.getenv("OCI_REGION", "us-chicago-1") + profile = os.getenv("OCI_PROFILE", "DEFAULT") + + self.client = cohere.OciClientV2( + oci_region=region, + oci_compartment_id=compartment_id, + oci_profile=profile, + ) + + def test_embed_english_v3(self): + """Test embed-english-v3.0 model.""" + response = self.client.embed( + model="embed-english-v3.0", + texts=["Test"], + input_type="search_document", + ) + self.assertIsNotNone(response.embeddings) + self.assertIsNotNone(response.embeddings.float_) + self.assertEqual(len(response.embeddings.float_[0]), 1024) + + def test_embed_multilingual_v3(self): + """Test embed-multilingual-v3.0 model.""" + response = self.client.embed( + model="embed-multilingual-v3.0", + texts=["Test"], + input_type="search_document", + ) + self.assertIsNotNone(response.embeddings) + self.assertIsNotNone(response.embeddings.float_) + self.assertEqual(len(response.embeddings.float_[0]), 1024) + + def test_command_a_chat(self): + """Test command-a-03-2025 model for chat.""" + response = self.client.chat( + model="command-a-03-2025", + messages=[{"role": "user", "content": "Hello"}], + ) + self.assertIsNotNone(response.message) + + def test_embed_english_light_v3(self): + """Test embed-english-light-v3.0 returns 384-dim vectors.""" + response = self.client.embed( + model="embed-english-light-v3.0", + texts=["Hello world"], + input_type="search_document", + ) + self.assertIsNotNone(response.embeddings.float_) + self.assertEqual(len(response.embeddings.float_[0]), 384) + + def test_embed_multilingual_light_v3(self): + """Test embed-multilingual-light-v3.0 returns 384-dim vectors.""" + response = self.client.embed( + model="embed-multilingual-light-v3.0", + texts=["Bonjour le 
monde"], + input_type="search_document", + ) + self.assertIsNotNone(response.embeddings.float_) + self.assertEqual(len(response.embeddings.float_[0]), 384) + + def test_embed_search_query_input_type(self): + """Test embed with search_query input_type (distinct from search_document).""" + response = self.client.embed( + model="embed-english-v3.0", + texts=["What is the capital of France?"], + input_type="search_query", + ) + self.assertIsNotNone(response.embeddings.float_) + self.assertEqual(len(response.embeddings.float_[0]), 1024) + + def test_command_r_plus_chat(self): + """Test command-r-plus-08-2024 via V1 client.""" + v1_client = cohere.OciClient( + oci_region=os.getenv("OCI_REGION", "us-chicago-1"), + oci_compartment_id=os.getenv("OCI_COMPARTMENT_ID"), + oci_profile=os.getenv("OCI_PROFILE", "DEFAULT"), + ) + response = v1_client.chat( + model="command-r-plus-08-2024", + message="What is 2+2? Answer with just the number.", + ) + self.assertIsNotNone(response.text) + self.assertIn("4", response.text) + + def test_v2_multi_turn_chat(self): + """Test V2 chat with conversation history (multi-turn).""" + response = self.client.chat( + model="command-a-03-2025", + messages=[ + {"role": "user", "content": "My name is Alice."}, + {"role": "assistant", "content": "Nice to meet you, Alice!"}, + {"role": "user", "content": "What is my name?"}, + ], + ) + self.assertIsNotNone(response.message) + content = response.message.content[0].text + self.assertIn("Alice", content) + + def test_v2_system_message(self): + """Test V2 chat with a system message.""" + response = self.client.chat( + model="command-a-03-2025", + messages=[ + {"role": "system", "content": "You are a helpful assistant. 
Always respond in exactly 3 words."}, + {"role": "user", "content": "Say hello."}, + ], + ) + self.assertIsNotNone(response.message) + self.assertIsNotNone(response.message.content[0].text) + + +class TestOciClientTransformations(unittest.TestCase): + """Unit tests for OCI request/response transformations (no OCI credentials required).""" + + def test_thinking_parameter_transformation(self): + """Test that thinking parameter is correctly transformed to OCI format.""" + from cohere.oci_client import transform_request_to_oci + + cohere_body = { + "model": "command-a-reasoning-08-2025", + "messages": [{"role": "user", "content": "What is 2+2?"}], + "thinking": { + "type": "enabled", + "token_budget": 10000, + }, + } + + result = transform_request_to_oci("chat", cohere_body, "compartment-123", is_v2=True) + + # Verify thinking parameter is transformed with camelCase for OCI API + chat_request = result["chatRequest"] + self.assertIn("thinking", chat_request) + self.assertEqual(chat_request["thinking"]["type"], "ENABLED") + self.assertEqual(chat_request["thinking"]["tokenBudget"], 10000) # camelCase for OCI + + def test_thinking_parameter_disabled(self): + """Test that disabled thinking is correctly transformed.""" + from cohere.oci_client import transform_request_to_oci + + cohere_body = { + "model": "command-a-reasoning-08-2025", + "messages": [{"role": "user", "content": "Hello"}], + "thinking": { + "type": "disabled", + }, + } + + result = transform_request_to_oci("chat", cohere_body, "compartment-123", is_v2=True) + + chat_request = result["chatRequest"] + self.assertIn("thinking", chat_request) + self.assertEqual(chat_request["thinking"]["type"], "DISABLED") + self.assertNotIn("token_budget", chat_request["thinking"]) + + def test_thinking_response_transformation(self): + """Test that thinking content in response is correctly transformed.""" + from cohere.oci_client import transform_oci_response_to_cohere + + oci_response = { + "chatResponse": { + "id": "test-id", 
+ "message": { + "role": "ASSISTANT", + "content": [ + {"type": "THINKING", "thinking": "Let me think about this..."}, + {"type": "TEXT", "text": "The answer is 4."}, + ], + }, + "finishReason": "COMPLETE", + "usage": {"inputTokens": 10, "completionTokens": 20}, + } + } + + result = transform_oci_response_to_cohere("chat", oci_response, is_v2=True) + + # Verify content types are lowercased + self.assertEqual(result["message"]["content"][0]["type"], "thinking") + self.assertEqual(result["message"]["content"][1]["type"], "text") + + def test_stream_event_thinking_transformation(self): + """Test that thinking content in stream events is correctly transformed.""" + from cohere.oci_client import transform_stream_event + + # OCI thinking event + oci_event = { + "message": { + "content": [{"type": "THINKING", "thinking": "Reasoning step..."}] + } + } + + result = transform_stream_event("chat", oci_event, is_v2=True) + + self.assertEqual(result[0]["type"], "content-delta") + self.assertIn("thinking", result[0]["delta"]["message"]["content"]) + self.assertEqual(result[0]["delta"]["message"]["content"]["thinking"], "Reasoning step...") + + def test_stream_event_text_transformation(self): + """Test that text content in stream events is correctly transformed.""" + from cohere.oci_client import transform_stream_event + + # OCI text event + oci_event = { + "message": { + "content": [{"type": "TEXT", "text": "The answer is..."}] + } + } + + result = transform_stream_event("chat", oci_event, is_v2=True) + + self.assertEqual(result[0]["type"], "content-delta") + self.assertIn("text", result[0]["delta"]["message"]["content"]) + self.assertEqual(result[0]["delta"]["message"]["content"]["text"], "The answer is...") + + def test_thinking_parameter_none(self): + """Test that thinking=None does not crash (issue: null guard).""" + from cohere.oci_client import transform_request_to_oci + + cohere_body = { + "model": "command-a-03-2025", + "messages": [{"role": "user", "content": "Hello"}], 
+ "thinking": None, # Explicitly set to None + } + + # Should not crash with TypeError + result = transform_request_to_oci("chat", cohere_body, "compartment-123", is_v2=True) + + chat_request = result["chatRequest"] + # thinking should not be in request when None + self.assertNotIn("thinking", chat_request) + + def test_v2_response_role_lowercased(self): + """Test that V2 response message role is lowercased.""" + from cohere.oci_client import transform_oci_response_to_cohere + + oci_response = { + "chatResponse": { + "id": "test-id", + "message": { + "role": "ASSISTANT", + "content": [{"type": "TEXT", "text": "Hello"}], + }, + "finishReason": "COMPLETE", + "usage": {"inputTokens": 10, "completionTokens": 20}, + } + } + + result = transform_oci_response_to_cohere("chat", oci_response, is_v2=True) + + # Role should be lowercased + self.assertEqual(result["message"]["role"], "assistant") + + def test_v2_response_finish_reason_uppercase(self): + """Test that V2 response finish_reason stays uppercase.""" + from cohere.oci_client import transform_oci_response_to_cohere + + oci_response = { + "chatResponse": { + "id": "test-id", + "message": { + "role": "ASSISTANT", + "content": [{"type": "TEXT", "text": "Hello"}], + }, + "finishReason": "MAX_TOKENS", + "usage": {"inputTokens": 10, "completionTokens": 20}, + } + } + + result = transform_oci_response_to_cohere("chat", oci_response, is_v2=True) + + # V2 finish_reason should stay uppercase + self.assertEqual(result["finish_reason"], "MAX_TOKENS") + + def test_v2_response_tool_calls_conversion(self): + """Test that V2 response converts toolCalls to tool_calls.""" + from cohere.oci_client import transform_oci_response_to_cohere + + oci_response = { + "chatResponse": { + "id": "test-id", + "message": { + "role": "ASSISTANT", + "content": [{"type": "TEXT", "text": "I'll help with that."}], + "toolCalls": [ + { + "id": "call_123", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"city": "London"}'}, + } 
+ ], + }, + "finishReason": "TOOL_CALL", + "usage": {"inputTokens": 10, "completionTokens": 20}, + } + } + + result = transform_oci_response_to_cohere("chat", oci_response, is_v2=True) + + # toolCalls should be converted to tool_calls + self.assertIn("tool_calls", result["message"]) + self.assertNotIn("toolCalls", result["message"]) + self.assertEqual(len(result["message"]["tool_calls"]), 1) + self.assertEqual(result["message"]["tool_calls"][0]["id"], "call_123") + + def test_normalize_model_for_oci(self): + """Test model name normalization for OCI.""" + from cohere.oci_client import normalize_model_for_oci + + # Plain model name gets cohere. prefix + self.assertEqual(normalize_model_for_oci("command-a-03-2025"), "cohere.command-a-03-2025") + # Already prefixed passes through + self.assertEqual(normalize_model_for_oci("cohere.embed-english-v3.0"), "cohere.embed-english-v3.0") + # OCID passes through + self.assertEqual( + normalize_model_for_oci("ocid1.generativeaimodel.oc1.us-chicago-1.abc"), + "ocid1.generativeaimodel.oc1.us-chicago-1.abc", + ) + + def test_transform_embed_request(self): + """Test embed request transformation to OCI format.""" + from cohere.oci_client import transform_request_to_oci + + body = { + "model": "embed-english-v3.0", + "texts": ["hello", "world"], + "input_type": "search_document", + "truncate": "end", + "embedding_types": ["float", "int8"], + } + result = transform_request_to_oci("embed", body, "compartment-123") + + self.assertEqual(result["inputs"], ["hello", "world"]) + self.assertEqual(result["inputType"], "SEARCH_DOCUMENT") + self.assertEqual(result["truncate"], "END") + self.assertEqual(result["embeddingTypes"], ["FLOAT", "INT8"]) + self.assertEqual(result["compartmentId"], "compartment-123") + self.assertEqual(result["servingMode"]["modelId"], "cohere.embed-english-v3.0") + + def test_transform_embed_request_with_optional_params(self): + """Test embed request forwards optional params.""" + from cohere.oci_client import 
transform_request_to_oci + + body = { + "model": "embed-english-v3.0", + "inputs": [{"content": [{"type": "text", "text": "hello"}]}], + "input_type": "classification", + "max_tokens": 256, + "output_dimension": 512, + "priority": 42, + } + result = transform_request_to_oci("embed", body, "compartment-123") + + self.assertEqual(result["inputs"], body["inputs"]) + self.assertEqual(result["maxTokens"], 256) + self.assertEqual(result["outputDimension"], 512) + self.assertEqual(result["priority"], 42) + + def test_transform_embed_request_rejects_images(self): + """Test embed request fails clearly for unsupported top-level images.""" + from cohere.oci_client import transform_request_to_oci + + with self.assertRaises(ValueError) as ctx: + transform_request_to_oci( + "embed", + { + "model": "embed-english-v3.0", + "images": ["data:image/png;base64,abc"], + "input_type": "classification", + }, + "compartment-123", + ) + + self.assertIn("top-level 'images' parameter", str(ctx.exception)) + + def test_transform_chat_request_optional_params(self): + """Test chat request transformation includes optional params.""" + from cohere.oci_client import transform_request_to_oci + + body = { + "model": "command-a-03-2025", + "messages": [{"role": "user", "content": "Hi"}], + "max_tokens": 100, + "temperature": 0.7, + "stop_sequences": ["END"], + "frequency_penalty": 0.5, + "strict_tools": True, + "response_format": {"type": "json_object"}, + "logprobs": True, + "tool_choice": "REQUIRED", + "priority": 7, + } + result = transform_request_to_oci("chat", body, "compartment-123", is_v2=True) + + chat_req = result["chatRequest"] + self.assertEqual(chat_req["maxTokens"], 100) + self.assertEqual(chat_req["temperature"], 0.7) + self.assertEqual(chat_req["stopSequences"], ["END"]) + self.assertEqual(chat_req["frequencyPenalty"], 0.5) + self.assertTrue(chat_req["strictTools"]) + self.assertEqual(chat_req["responseFormat"], {"type": "json_object"}) + self.assertTrue(chat_req["logprobs"]) + 
self.assertEqual(chat_req["toolChoice"], "REQUIRED") + self.assertEqual(chat_req["priority"], 7) + + def test_v2_client_rejects_v1_request(self): + """Test OciClientV2 fails when given V1-style 'message' string.""" + from cohere.oci_client import transform_request_to_oci + + with self.assertRaises(ValueError) as ctx: + transform_request_to_oci( + "chat", + {"model": "command-a-03-2025", "message": "Hello"}, + "compartment-123", + is_v2=True, + ) + self.assertIn("OciClientV2", str(ctx.exception)) + + def test_v1_client_rejects_v2_request(self): + """Test OciClient fails when given V2-style 'messages' array.""" + from cohere.oci_client import transform_request_to_oci + + with self.assertRaises(ValueError) as ctx: + transform_request_to_oci( + "chat", + {"model": "command-r-08-2024", "messages": [{"role": "user", "content": "Hi"}]}, + "compartment-123", + is_v2=False, + ) + self.assertIn("OciClient ", str(ctx.exception)) + + def test_unsupported_endpoint_raises(self): + """Test that transform_request_to_oci raises for unsupported endpoints.""" + from cohere.oci_client import transform_request_to_oci + + with self.assertRaises(ValueError) as ctx: + transform_request_to_oci("rerank", {"model": "rerank-v3.5"}, "compartment-123") + self.assertIn("rerank", str(ctx.exception)) + self.assertIn("not supported", str(ctx.exception)) + + def test_v1_chat_request_optional_params(self): + """Test V1 chat request forwards supported optional params.""" + from cohere.oci_client import transform_request_to_oci + + body = { + "model": "command-r-08-2024", + "message": "Hi", + "max_tokens": 100, + "temperature": 0.7, + "k": 10, + "p": 0.8, + "seed": 123, + "stop_sequences": ["END"], + "frequency_penalty": 0.5, + "presence_penalty": 0.2, + "documents": [{"title": "Doc", "text": "Body"}], + "tools": [{"name": "lookup"}], + "tool_results": [{"call": {"name": "lookup"}}], + "response_format": {"type": "json_object"}, + "safety_mode": "NONE", + "priority": 4, + } + result = 
transform_request_to_oci("chat", body, "compartment-123", is_v2=False) + + chat_req = result["chatRequest"] + self.assertEqual(chat_req["apiFormat"], "COHERE") + self.assertEqual(chat_req["message"], "Hi") + self.assertEqual(chat_req["maxTokens"], 100) + self.assertEqual(chat_req["temperature"], 0.7) + self.assertEqual(chat_req["topK"], 10) + self.assertEqual(chat_req["topP"], 0.8) + self.assertEqual(chat_req["seed"], 123) + self.assertEqual(chat_req["frequencyPenalty"], 0.5) + self.assertEqual(chat_req["presencePenalty"], 0.2) + self.assertEqual(chat_req["priority"], 4) + + def test_v1_stream_wrapper_preserves_finish_reason(self): + """Test V1 stream-end uses the OCI finish reason from the final event.""" + import json + from cohere.oci_client import transform_oci_stream_wrapper + + chunks = [ + b'data: {"text": "Hello", "isFinished": false}\n', + b'data: {"text": " world", "isFinished": true, "finishReason": "MAX_TOKENS"}\n', + b"data: [DONE]\n", + ] + + events = [ + json.loads(raw.decode("utf-8")) + for raw in transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=False) + ] + + self.assertEqual(events[2]["event_type"], "stream-end") + self.assertEqual(events[2]["finish_reason"], "MAX_TOKENS") + self.assertEqual(events[2]["response"]["text"], "Hello world") + + def test_transform_chat_request_tool_message_fields(self): + """Test tool message fields are converted to OCI names.""" + from cohere.oci_client import transform_request_to_oci + + body = { + "model": "command-a-03-2025", + "messages": [ + { + "role": "assistant", + "content": [{"type": "text", "text": "Use tool"}], + "tool_calls": [{"id": "call_1"}], + "tool_plan": "Plan", + }, + { + "role": "tool", + "tool_call_id": "call_1", + "content": [{"type": "text", "text": "Result"}], + }, + ], + } + + result = transform_request_to_oci("chat", body, "compartment-123", is_v2=True) + assistant_message, tool_message = result["chatRequest"]["messages"] + self.assertEqual(assistant_message["toolCalls"], [{"id": 
"call_1"}]) + self.assertEqual(assistant_message["toolPlan"], "Plan") + self.assertEqual(tool_message["toolCallId"], "call_1") + + def test_get_oci_url_known_endpoints(self): + """Test URL generation for known endpoints.""" + from cohere.oci_client import get_oci_url + + url = get_oci_url("us-chicago-1", "embed") + self.assertIn("/actions/embedText", url) + + url = get_oci_url("us-chicago-1", "chat") + self.assertIn("/actions/chat", url) + + + def test_get_oci_url_unknown_endpoint_raises(self): + """Test that unknown endpoints raise ValueError instead of producing bad URLs.""" + from cohere.oci_client import get_oci_url + + with self.assertRaises(ValueError) as ctx: + get_oci_url("us-chicago-1", "unknown_endpoint") + self.assertIn("not supported", str(ctx.exception)) + + def test_load_oci_config_missing_private_key_raises(self): + """Test that direct credentials without private key raises clear error.""" + from cohere.oci_client import _load_oci_config + + with patch("cohere.oci_client.lazy_oci", return_value=MagicMock()): + with self.assertRaises(ValueError) as ctx: + _load_oci_config( + auth_type="api_key", + config_path=None, + profile=None, + user_id="ocid1.user.oc1...", + fingerprint="xx:xx:xx", + tenancy_id="ocid1.tenancy.oc1...", + # No private_key_path or private_key_content + ) + self.assertIn("oci_private_key_path", str(ctx.exception)) + + def test_load_oci_config_ignores_inherited_session_auth(self): + """Test that named API-key profiles do not inherit DEFAULT session auth fields.""" + from cohere.oci_client import _load_oci_config + + config_text = """ +[DEFAULT] +security_token_file=/tmp/default-token + +[API_KEY_AUTH] +user=ocid1.user.oc1..test +fingerprint=aa:bb +key_file=/tmp/test.pem +tenancy=ocid1.tenancy.oc1..test +region=us-chicago-1 +""".strip() + + with tempfile.NamedTemporaryFile("w", delete=False) as config_file: + config_file.write(config_text) + config_path = config_file.name + + try: + mock_oci = MagicMock() + 
mock_oci.config.from_file.return_value = { + "user": "ocid1.user.oc1..test", + "fingerprint": "aa:bb", + "key_file": "/tmp/test.pem", + "tenancy": "ocid1.tenancy.oc1..test", + "region": "us-chicago-1", + "security_token_file": "/tmp/default-token", + } + + with patch("cohere.oci_client.lazy_oci", return_value=mock_oci): + config = _load_oci_config( + auth_type="api_key", + config_path=config_path, + profile="API_KEY_AUTH", + ) + finally: + os.unlink(config_path) + + self.assertNotIn("security_token_file", config) + + def test_session_auth_prefers_security_token_signer(self): + """Test session-based auth uses SecurityTokenSigner before API key signer.""" + from cohere.oci_client import map_request_to_oci + + mock_oci = MagicMock() + mock_security_signer = MagicMock() + mock_oci.signer.load_private_key_from_file.return_value = "private-key" + mock_oci.auth.signers.SecurityTokenSigner.return_value = mock_security_signer + + with patch("cohere.oci_client.lazy_oci", return_value=mock_oci), patch( + "builtins.open", mock_open(read_data="session-token") + ): + hook = map_request_to_oci( + oci_config={ + "user": "ocid1.user.oc1..example", + "fingerprint": "xx:xx", + "tenancy": "ocid1.tenancy.oc1..example", + "security_token_file": "~/.oci/token", + "key_file": "~/.oci/key.pem", + }, + oci_region="us-chicago-1", + oci_compartment_id="ocid1.compartment.oc1..example", + ) + + request = MagicMock() + request.url.path = "/v2/embed" + request.read.return_value = b'{"model":"embed-english-v3.0","texts":["hello"]}' + request.method = "POST" + request.extensions = {} + + hook(request) + + # SecurityTokenSigner is called at least once (init) and again per request + # (token file is re-read on each signing call to pick up refreshed tokens). 
+ mock_oci.auth.signers.SecurityTokenSigner.assert_called_with( + token="session-token", + private_key="private-key", + ) + self.assertGreaterEqual(mock_oci.auth.signers.SecurityTokenSigner.call_count, 1) + mock_oci.signer.Signer.assert_not_called() + + def test_session_token_refreshed_on_subsequent_requests(self): + """Verify the refreshing signer picks up a new token written to the token file.""" + import tempfile + import os + from cohere.oci_client import map_request_to_oci + + mock_oci = MagicMock() + mock_oci.signer.load_private_key_from_file.return_value = "private-key" + + # Write initial token to a real temp file so we can overwrite it later. + with tempfile.NamedTemporaryFile("w", suffix=".token", delete=False) as tf: + tf.write("token-v1") + token_path = tf.name + + try: + with patch("cohere.oci_client.lazy_oci", return_value=mock_oci): + hook = map_request_to_oci( + oci_config={ + "security_token_file": token_path, + "key_file": "/irrelevant.pem", + }, + oci_region="us-chicago-1", + oci_compartment_id="ocid1.compartment.oc1..example", + ) + + def _make_request(): + req = MagicMock() + req.url.path = "/v2/embed" + req.read.return_value = b'{"model":"embed-english-v3.0","texts":["hi"]}' + req.method = "POST" + req.extensions = {} + return req + + # First request uses token-v1 + hook(_make_request()) + calls_after_first = mock_oci.auth.signers.SecurityTokenSigner.call_count + + # Simulate token refresh by overwriting the file + with open(token_path, "w") as _f: + _f.write("token-v2") + + # Second request — should re-read and use token-v2 + hook(_make_request()) + self.assertGreater( + mock_oci.auth.signers.SecurityTokenSigner.call_count, + calls_after_first, + "SecurityTokenSigner should be re-instantiated after token file update", + ) + # Verify the latest call used the refreshed token + all_calls = mock_oci.auth.signers.SecurityTokenSigner.call_args_list + last_call = all_calls[-1] + last_token = last_call.kwargs.get("token") or (last_call.args[0] if 
last_call.args else None) + self.assertEqual(last_token, "token-v2", "Last signing call must use the refreshed token") + finally: + os.unlink(token_path) + + def test_embed_response_lowercases_embedding_keys(self): + """Test embed response uses lowercase keys expected by the SDK model.""" + from cohere.oci_client import transform_oci_response_to_cohere + + result = transform_oci_response_to_cohere( + "embed", + { + "id": "embed-id", + "embeddings": {"FLOAT": [[0.1, 0.2]], "INT8": [[1, 2]]}, + "usage": {"inputTokens": 3, "completionTokens": 7}, + }, + is_v2=True, + ) + + self.assertIn("float", result["embeddings"]) + self.assertIn("int8", result["embeddings"]) + self.assertNotIn("FLOAT", result["embeddings"]) + self.assertEqual(result["meta"]["tokens"]["output_tokens"], 7) + + def test_normalize_model_for_oci_rejects_empty_model(self): + """Test model normalization fails clearly for empty model names.""" + from cohere.oci_client import normalize_model_for_oci + + with self.assertRaises(ValueError) as ctx: + normalize_model_for_oci("") + self.assertIn("non-empty model", str(ctx.exception)) + + def test_stream_wrapper_emits_full_event_lifecycle(self): + """Test that stream emits message-start, content-start, content-delta, content-end, message-end.""" + import json + from cohere.oci_client import transform_oci_stream_wrapper + + chunks = [ + b'data: {"message": {"content": [{"type": "TEXT", "text": "Hello"}]}}\n', + b'data: {"message": {"content": [{"type": "TEXT", "text": " world"}]}, "finishReason": "COMPLETE"}\n', + b'data: [DONE]\n', + ] + + events = [] + for raw in transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=True): + line = raw.decode("utf-8").strip() + if line.startswith("data: "): + events.append(json.loads(line[6:])) + + event_types = [e["type"] for e in events] + self.assertEqual(event_types[0], "message-start") + self.assertEqual(event_types[1], "content-start") + self.assertEqual(event_types[2], "content-delta") + 
self.assertEqual(event_types[3], "content-delta") + self.assertEqual(event_types[4], "content-end") + self.assertEqual(event_types[5], "message-end") + + # Verify message-start has id and role + self.assertIn("id", events[0]) + self.assertEqual(events[0]["delta"]["message"]["role"], "assistant") + + # Verify content-start has index and type + self.assertEqual(events[1]["index"], 0) + self.assertEqual(events[1]["delta"]["message"]["content"]["type"], "text") + self.assertEqual(events[5]["delta"]["finish_reason"], "COMPLETE") + + def test_stream_wrapper_emits_new_content_block_on_thinking_transition(self): + """Test streams emit a new content block when transitioning from thinking to text.""" + import json + from cohere.oci_client import transform_oci_stream_wrapper + + chunks = [ + b'data: {"message": {"content": [{"type": "THINKING", "thinking": "Reasoning..."}]}}\n', + b'data: {"message": {"content": [{"type": "TEXT", "text": "Answer"}]}, "finishReason": "COMPLETE"}\n', + b"data: [DONE]\n", + ] + + events = [] + for raw in transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=True): + line = raw.decode("utf-8").strip() + if line.startswith("data: "): + events.append(json.loads(line[6:])) + + self.assertEqual(events[1]["type"], "content-start") + self.assertEqual(events[1]["delta"]["message"]["content"]["type"], "thinking") + self.assertEqual(events[2]["type"], "content-delta") + self.assertEqual(events[2]["index"], 0) + self.assertEqual(events[3], {"type": "content-end", "index": 0}) + self.assertEqual(events[4]["type"], "content-start") + self.assertEqual(events[4]["index"], 1) + self.assertEqual(events[4]["delta"]["message"]["content"]["type"], "text") + self.assertEqual(events[5]["type"], "content-delta") + self.assertEqual(events[5]["index"], 1) + + def test_stream_wrapper_no_spurious_block_on_finish_only_event(self): + """Finish-only event after thinking block must not open a spurious empty text block.""" + import json + from cohere.oci_client import 
transform_oci_stream_wrapper
+
+        chunks = [
+            b'data: {"message": {"content": [{"type": "THINKING", "thinking": "Reasoning..."}]}}\n',
+            b'data: {"finishReason": "COMPLETE"}\n',
+            b"data: [DONE]\n",
+        ]
+
+        events = []
+        for raw in transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=True):
+            line = raw.decode("utf-8").strip()
+            if line.startswith("data: "):
+                events.append(json.loads(line[6:]))
+
+        types = [e["type"] for e in events]
+        # Must not contain two content-start events
+        self.assertEqual(types.count("content-start"), 1)
+        # The single content block must be thinking
+        cs = next(e for e in events if e["type"] == "content-start")
+        self.assertEqual(cs["delta"]["message"]["content"]["type"], "thinking")
+        # Must end cleanly
+        self.assertEqual(events[-1]["type"], "message-end")
+
+    def test_stream_wrapper_skips_malformed_json_with_warning(self):
+        """Test that malformed JSON SSE lines are skipped (a warning may be logged; not asserted here)."""
+        from cohere.oci_client import transform_oci_stream_wrapper
+
+        chunks = [
+            b'data: not-valid-json\n',
+            b'data: {"message": {"content": [{"type": "TEXT", "text": "hello"}]}}\n',
+            b'data: [DONE]\n',
+        ]
+        events = list(transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=True))
+        # Should get message-start + content-start + content-delta + content-end + message-end.
+ self.assertEqual(len(events), 5) + + def test_stream_wrapper_skips_message_end_for_empty_stream(self): + """Test empty streams do not emit message-end without a preceding message-start.""" + from cohere.oci_client import transform_oci_stream_wrapper + + events = list(transform_oci_stream_wrapper(iter([b"data: [DONE]\n"]), "chat", is_v2=True)) + + self.assertEqual(events, []) + + def test_stream_wrapper_done_uses_current_content_index_after_transition(self): + """Test fallback content-end uses the latest content index after type transitions.""" + import json + from cohere.oci_client import transform_oci_stream_wrapper + + chunks = [ + b'data: {"message": {"content": [{"type": "THINKING", "thinking": "Reasoning..."}]}}\n', + b'data: {"message": {"content": [{"type": "TEXT", "text": "Answer"}]}}\n', + b"data: [DONE]\n", + ] + + events = [] + for raw in transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=True): + line = raw.decode("utf-8").strip() + if line.startswith("data: "): + events.append(json.loads(line[6:])) + + self.assertEqual(events[-2], {"type": "content-end", "index": 1}) + self.assertEqual(events[-1]["type"], "message-end") + + def test_stream_wrapper_raises_on_transform_error(self): + """Test that transform errors in stream produce OCI-specific error.""" + from cohere.oci_client import transform_oci_stream_wrapper + + # Event with structure that will cause transform_stream_event to fail + # (message is None, causing TypeError on "content" in None) + chunks = [ + b'data: {"message": null}\n', + ] + with self.assertRaises(RuntimeError) as ctx: + list(transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=True)) + self.assertIn("OCI stream event transformation failed", str(ctx.exception)) + + def test_stream_event_finish_reason_keeps_final_text(self): + """Test finish events keep final text before content-end.""" + from cohere.oci_client import transform_stream_event + + events = transform_stream_event( + "chat", + { + "message": {"content": 
[{"type": "TEXT", "text": " world"}]}, + "finishReason": "COMPLETE", + }, + is_v2=True, + ) + + self.assertEqual(events[0]["type"], "content-delta") + self.assertEqual(events[0]["delta"]["message"]["content"]["text"], " world") + self.assertEqual(events[1]["type"], "content-end") + +if __name__ == "__main__": + unittest.main()