From 91e887e42bd81d9d5c6854a6ebe9e79173a28088 Mon Sep 17 00:00:00 2001 From: Claude Date: Wed, 13 May 2026 01:32:34 +0000 Subject: [PATCH] Add OHTTP-style anonymous inference endpoint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implements RFC 9458 Oblivious HTTP encapsulation so clients can submit chat completions through an independent relay without exposing their IP to the enclave or their prompt to the relay. The HPKE X25519 keypair is generated alongside the existing RSA signing key and bound to the same nitriding registration digest, so the Nitro attestation document commits to both. - tee_gateway/ohttp.py: HPKE wrap/unwrap helpers (DHKEM(X25519)/HKDF-SHA256/ ChaCha20-Poly1305). Response keying derived per-context per RFC 9458 §4.2. - tee_gateway/tee_manager.py: HPKE keypair, key-config blob, attestation document now includes the HPKE public key. - tee_gateway/controllers/ohttp_controller.py: /v1/ohttp dispatches the decrypted request to the existing chat handler, scrubs identifying fields before forwarding upstream, refuses stream=true. - /v1/ohttp/config exposes the HPKE key config for client discovery. - Test coverage: round-trip, wrong-suite, truncated input, tampered ciphertext. Known limitation: payment gating is not yet wired for this endpoint; a blind-token layer will follow in a separate change. https://claude.ai/code/session_01WyddtSz2rtiP61LtVJbsJy --- pyproject.toml | 1 + tee_gateway/__main__.py | 18 +++ tee_gateway/controllers/ohttp_controller.py | 141 +++++++++++++++++ tee_gateway/ohttp.py | 166 ++++++++++++++++++++ tee_gateway/tee_manager.py | 53 ++++++- tee_gateway/test/test_ohttp.py | 99 ++++++++++++ 6 files changed, 476 insertions(+), 2 deletions(-) create mode 100644 tee_gateway/controllers/ohttp_controller.py create mode 100644 tee_gateway/ohttp.py create mode 100644 tee_gateway/test/test_ohttp.py diff --git a/pyproject.toml b/pyproject.toml index 80a9bc6..426ab33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "psutil>=7.2.1", "python-dotenv>=1.2.1", "eth-account>=0.13.0", + "pyhpke>=0.6.0", ] [dependency-groups] diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index 165d5d1..ad087af 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -23,6 +23,10 @@ ) from tee_gateway.llm_backend import get_provider_config, set_provider_config from tee_gateway.heartbeat import create_heartbeat_service +from tee_gateway.controllers.ohttp_controller import ( + create_anonymous_chat_completion, + get_hpke_config, +) from x402.http import FacilitatorConfig, HTTPFacilitatorClientSync, PaymentOption from x402.http.middleware.flask import payment_middleware @@ -437,6 +441,20 @@ def create_app(): "/heartbeat/status", "heartbeat-status", heartbeat_status, methods=["GET"] ) + # Anonymous inference (OHTTP-wrapped chat completions). Deliberately + # mounted via add_url_rule rather than the OpenAPI spec because the body + # is raw binary and connexion's request-validation pipeline would reject + # it as malformed JSON. + app.app.add_url_rule( + "/v1/ohttp", + "anonymous-chat", + create_anonymous_chat_completion, + methods=["POST"], + ) + app.app.add_url_rule( + "/v1/ohttp/config", "ohttp-config", get_hpke_config, methods=["GET"] + ) + # Initialize TEE here so it runs under both Gunicorn and direct execution. # This is the single TEEKeyManager instance — the same key both registers # with nitriding and signs all LLM responses. diff --git a/tee_gateway/controllers/ohttp_controller.py b/tee_gateway/controllers/ohttp_controller.py new file mode 100644 index 0000000..688e862 --- /dev/null +++ b/tee_gateway/controllers/ohttp_controller.py @@ -0,0 +1,141 @@ +""" +Oblivious HTTP endpoint for anonymous inference. + +This handler is intentionally minimal: it does HPKE decapsulation, dispatches +the inner request to the existing chat-completions handler in-process (no +network hop), and HPKE-encapsulates the response. The inner JSON request is +identical to the standard /v1/chat/completions body. + +Threat model nuances: + * The relay in front of this endpoint sees the encapsulated ciphertext and + the client IP, but no request content. + * The enclave sees plaintext and the relay's IP, never the client's. + * If the client's payload contains identifiers (cookies, ``user`` field, + custom request IDs), unlinkability is broken at the application layer — + we strip the obvious ones below. + * Streaming is intentionally not supported on this endpoint. SSE would + create per-chunk side channels (timing, length) that defeat the point of + bundling everything into a single sealed response. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +import connexion +from flask import Response + +from tee_gateway import ohttp +from tee_gateway.tee_manager import get_tee_keys + +logger = logging.getLogger(__name__) + +OHTTP_MEDIA_TYPE = "message/ohttp-req" +OHTTP_RESPONSE_MEDIA_TYPE = "message/ohttp-res" + +# Fields that can re-identify a client and have no role in inference. We drop +# them before forwarding to the inner handler — keeping them inside the +# encrypted envelope would only protect them from the relay, not from us or +# the upstream LLM provider. +_IDENTIFYING_FIELDS = ("user", "metadata", "x-request-id", "request_id") + + +def _scrub(payload: dict[str, Any]) -> dict[str, Any]: + return {k: v for k, v in payload.items() if k not in _IDENTIFYING_FIELDS} + + +def create_anonymous_chat_completion(): + """POST /v1/ohttp — decrypt, dispatch, re-encrypt. + + Body: raw bytes (OHTTP-encapsulated request). + Returns: raw bytes (OHTTP-encapsulated response) with Content-Type + ``message/ohttp-res``. + """ + req = connexion.request + # Tolerate both Connexion's Flask request and a bare Flask request. + raw_body: bytes = req.get_data(cache=False) + if not raw_body: + return _error(400, "empty body") + + tee = get_tee_keys() + if tee.hpke_private_key is None: + return _error(503, "anonymous inference not initialized") + + try: + decap = ohttp.decapsulate_request(tee.hpke_private_key, raw_body) + except Exception as exc: + # Don't leak which step failed — clients can retry with a fresh + # encapsulation, all observable failures look identical. + logger.warning("OHTTP decapsulation failed: %s", type(exc).__name__) + return _error(400, "malformed encapsulated request") + + try: + inner_body = json.loads(decap.plaintext.decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError): + return _error(400, "inner payload is not valid JSON") + + if not isinstance(inner_body, dict): + return _error(400, "inner payload must be a JSON object") + + if inner_body.get("stream"): + # Streaming is rejected on principle (see module docstring). Clients + # who want low TTFT under anonymity should use a shorter max_tokens. + return _error(400, "stream=true is not supported over OHTTP") + + inner_body = _scrub(inner_body) + + # Late import to avoid a circular dependency at module load (the chat + # controller pulls in models that import this package). + from tee_gateway.controllers.chat_controller import ( + _create_non_streaming_response, + _parse_chat_request, + ) + + try: + chat_request = _parse_chat_request(inner_body) + inner_result = _create_non_streaming_response(chat_request) + except Exception as exc: + logger.error("inner inference failed under OHTTP: %s", exc, exc_info=True) + inner_result = ({"error": "inference failed"}, 500) + + # _create_non_streaming_response returns either a dict or (body, status) + if isinstance(inner_result, tuple): + body_obj, status = inner_result + else: + body_obj, status = inner_result, 200 + + inner_json = json.dumps( + {"status": status, "body": body_obj}, + separators=(",", ":"), + ).encode("utf-8") + + sealed = ohttp.encapsulate_response( + decap.response_key, decap.enc, inner_json + ) + return Response( + sealed, + status=200, + mimetype=OHTTP_RESPONSE_MEDIA_TYPE, + ) + + +def get_hpke_config(): + """GET /v1/ohttp/config — return the HPKE key configuration. + + Returns both an OHTTP-compliant binary key_config (base64) and the + individual fields for clients that prefer to parse JSON. The same data is + embedded inside the attestation document at /signing-key for clients that + want to verify the binding to the enclave's PCRs in one step. + """ + try: + tee = get_tee_keys() + return tee.get_hpke_config(), 200 + except Exception as exc: + logger.error("HPKE config error: %s", exc, exc_info=True) + return {"error": str(exc)}, 500 + + +def _error(status: int, message: str) -> tuple[dict, int]: + return {"error": message}, status diff --git a/tee_gateway/ohttp.py b/tee_gateway/ohttp.py new file mode 100644 index 0000000..d5369aa --- /dev/null +++ b/tee_gateway/ohttp.py @@ -0,0 +1,166 @@ +""" +Oblivious HTTP encapsulation for anonymous inference. + +Implements request/response encapsulation per RFC 9458 (Oblivious HTTP) +with a fixed HPKE ciphersuite: + - KEM: DHKEM(X25519, HKDF-SHA256) (0x0020) + - KDF: HKDF-SHA256 (0x0001) + - AEAD: ChaCha20-Poly1305 (0x0003) + +The inner payload is application/json — we do not BHTTP-wrap the inference +request, since the enclave is the terminal endpoint and not a generic HTTP +proxy. This is a documented divergence from strict RFC 9458; the cryptographic +construction (HPKE base + exported response keying) is identical. + +Trust model: the relay sees ciphertext + client IP; the enclave sees plaintext ++ relay IP. Unlinkability holds unless relay and enclave collude. +""" + +from __future__ import annotations + +import os +import struct +from dataclasses import dataclass + +from cryptography.hazmat.primitives import hashes, hmac +from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 +from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand +from pyhpke import AEADId, CipherSuite, KDFId, KEMId +from pyhpke.kem_key_interface import KEMKeyInterface + + +# RFC 9180 / 9458 algorithm identifiers +KEM_ID_X25519 = 0x0020 +KDF_ID_HKDF_SHA256 = 0x0001 +AEAD_ID_CHACHA20_POLY1305 = 0x0003 + +# Single, stable key configuration ID. Bump when the keypair or suite changes +# so clients can refuse stale configs. +KEY_CONFIG_ID = 0x01 + +# AEAD parameters for ChaCha20-Poly1305 +_NK = 32 # key length +_NN = 12 # nonce length + +# Per RFC 9458 §4.1/4.2 — "info" labels for the HPKE context. +_LABEL_REQUEST = b"message/bhttp request" +_LABEL_RESPONSE = b"message/bhttp response" + +_SUITE = CipherSuite.new( + KEMId.DHKEM_X25519_HKDF_SHA256, + KDFId.HKDF_SHA256, + AEADId.CHACHA20_POLY1305, +) + + +def _header_bytes(key_id: int = KEY_CONFIG_ID) -> bytes: + return bytes([key_id]) + struct.pack( + ">HHH", + KEM_ID_X25519, + KDF_ID_HKDF_SHA256, + AEAD_ID_CHACHA20_POLY1305, + ) + + +def key_config(public_key_raw: bytes, key_id: int = KEY_CONFIG_ID) -> bytes: + """Build an OHTTP key configuration blob (RFC 9458 §3). + + Format: + key_id(1) || kem_id(2) || public_key(Npk=32) || + symmetric_algorithms_length(2) || (kdf_id(2) || aead_id(2))+ + """ + if len(public_key_raw) != 32: + raise ValueError("X25519 public key must be 32 bytes") + symmetric = struct.pack(">HH", KDF_ID_HKDF_SHA256, AEAD_ID_CHACHA20_POLY1305) + return ( + bytes([key_id]) + + struct.pack(">H", KEM_ID_X25519) + + public_key_raw + + struct.pack(">H", len(symmetric)) + + symmetric + ) + + +@dataclass +class DecapsulatedRequest: + """Result of decapsulating an OHTTP-wrapped request.""" + + plaintext: bytes + response_key: bytes # 32 bytes exported from the HPKE context + enc: bytes # client's ephemeral public key, used as salt for the response + + +def decapsulate_request( + private_key: KEMKeyInterface, encapsulated_request: bytes +) -> DecapsulatedRequest: + """Decrypt an HPKE-wrapped request inside the enclave. + + Raises ValueError on malformed input or unsupported ciphersuite. We + never echo the underlying exception text to clients — it can leak + timing/oracle info. + """ + if len(encapsulated_request) < 7 + 32: + raise ValueError("encapsulated request too short") + + key_id = encapsulated_request[0] + kem_id, kdf_id, aead_id = struct.unpack(">HHH", encapsulated_request[1:7]) + if (key_id, kem_id, kdf_id, aead_id) != ( + KEY_CONFIG_ID, + KEM_ID_X25519, + KDF_ID_HKDF_SHA256, + AEAD_ID_CHACHA20_POLY1305, + ): + raise ValueError("unsupported HPKE configuration") + + enc = encapsulated_request[7 : 7 + 32] + aead_ct = encapsulated_request[7 + 32 :] + + info = _LABEL_REQUEST + b"\x00" + _header_bytes(key_id) + recipient = _SUITE.create_recipient_context(enc, private_key, info=info) + plaintext = recipient.open(aead_ct, aad=b"") + + # Export a fresh secret bound to this HPKE context, used to derive the + # response AEAD key. This is the OHTTP-defined separation between the + # request and response halves of the same exchange. + response_secret = recipient.export(_LABEL_RESPONSE, _NK) + return DecapsulatedRequest( + plaintext=plaintext, response_key=response_secret, enc=enc + ) + + +def encapsulate_response( + response_secret: bytes, enc: bytes, plaintext: bytes +) -> bytes: + """Seal a response under the per-request derived key (RFC 9458 §4.2). + + Wire format: response_nonce(max(Nn, Nk)=Nk=32) || AEAD ciphertext + """ + response_nonce = os.urandom(max(_NN, _NK)) + salt = enc + response_nonce + + h = hmac.HMAC(salt, hashes.SHA256()) + h.update(response_secret) + prk = h.finalize() + + aead_key = HKDFExpand( + algorithm=hashes.SHA256(), length=_NK, info=b"key" + ).derive(prk) + aead_nonce = HKDFExpand( + algorithm=hashes.SHA256(), length=_NN, info=b"nonce" + ).derive(prk) + + ct = ChaCha20Poly1305(aead_key).encrypt(aead_nonce, plaintext, b"") + return response_nonce + ct + + +def generate_keypair() -> tuple[KEMKeyInterface, bytes]: + """Generate an X25519 keypair for HPKE. Returns (private_key, public_key_raw). + + pyhpke 0.6 derives keys from random IKM via ``kem.derive_key_pair(ikm)``, + which returns a ``KEMKeyPair`` wrapper. We hold onto the private side for + decapsulation and serialize the public side to raw 32-byte form for the + key configuration blob. + """ + pair = _SUITE.kem.derive_key_pair(os.urandom(32)) + pk_raw = pair.public_key.to_public_bytes() + return pair.private_key, pk_raw diff --git a/tee_gateway/tee_manager.py b/tee_gateway/tee_manager.py index 3d98b12..402438e 100644 --- a/tee_gateway/tee_manager.py +++ b/tee_gateway/tee_manager.py @@ -17,6 +17,8 @@ from eth_account import Account from eth_hash.auto import keccak +from tee_gateway import ohttp + logger = logging.getLogger(__name__) NITRIDING_BASE_URL = "http://127.0.0.1:8080" @@ -36,6 +38,11 @@ def __init__(self, register=True): self.public_key_pem = None self.tee_id = None self.wallet_address = None + # HPKE keypair for OHTTP-style anonymous inference. Generated in the + # same enclave boot so the X25519 public key is covered by the same + # attestation that covers the RSA signing key. + self.hpke_private_key = None + self.hpke_public_key_raw: bytes | None = None self._generate_keys() if register: self.register_with_nitriding() @@ -70,19 +77,39 @@ def _generate_keys(self): wallet_account = Account.from_key(wallet_key_bytes) self.wallet_address = wallet_account.address + # HPKE X25519 keypair — never leaves the enclave; clients address it + # via the public-key fingerprint published with the attestation. + self.hpke_private_key, self.hpke_public_key_raw = ohttp.generate_keypair() + logger.info("TEE key pair generated successfully") logger.info(f"tee_id: 0x{self.tee_id}") logger.info(f"wallet_address: {self.wallet_address}") + logger.info( + f"hpke_public_key (X25519, raw, hex): {self.hpke_public_key_raw.hex()}" + ) def register_with_nitriding(self): - """Register public key hash with nitriding.""" + """Register public key hash with nitriding. + + The hash covers both the RSA signing key (DER-encoded SPKI) and the + raw X25519 HPKE public key. Including both in a single attested digest + means a verifier who validates the attestation document automatically + gets binding for the HPKE config used for anonymous inference — no + separate trust anchor required. + """ try: public_key_der = self.public_key.public_bytes( encoding=serialization.Encoding.DER, format=serialization.PublicFormat.SubjectPublicKeyInfo, ) - key_hash = hashlib.sha256(public_key_der).digest() + # Domain-separated transcript so a future addition of more keys + # can't be confused with the existing layout. + transcript = ( + b"og-tee-keys|v2|rsa-spki=" + public_key_der + + b"|hpke-x25519=" + (self.hpke_public_key_raw or b"") + ) + key_hash = hashlib.sha256(transcript).digest() key_hash_b64 = base64.b64encode(key_hash).decode("utf-8") logger.info(f"Public key DER length: {len(public_key_der)} bytes") @@ -149,12 +176,34 @@ def get_wallet_address(self) -> str: """Return the TEE-generated Ethereum wallet address (checksum).""" return self.wallet_address + def get_hpke_config(self) -> dict: + """Return the HPKE key configuration for anonymous inference. + + ``key_config`` is the RFC 9458 §3 binary key-config blob, base64-encoded. + Clients should treat this as authoritative only when fetched alongside + the Nitro attestation document (which commits to the same key hash via + nitriding registration). + """ + if self.hpke_public_key_raw is None: + raise RuntimeError("HPKE keypair not initialized") + return { + "key_id": ohttp.KEY_CONFIG_ID, + "kem_id": ohttp.KEM_ID_X25519, + "kdf_id": ohttp.KDF_ID_HKDF_SHA256, + "aead_id": ohttp.AEAD_ID_CHACHA20_POLY1305, + "public_key": self.hpke_public_key_raw.hex(), + "key_config": base64.b64encode( + ohttp.key_config(self.hpke_public_key_raw) + ).decode("ascii"), + } + def get_attestation_document(self) -> dict: """Return TEE attestation document.""" return { "public_key": self.public_key_pem, "tee_id": f"0x{self.tee_id}", "wallet_address": self.wallet_address, + "hpke": self.get_hpke_config() if self.hpke_public_key_raw else None, "timestamp": datetime.now(UTC).isoformat(), "enclave_info": { "platform": "aws-nitro", diff --git a/tee_gateway/test/test_ohttp.py b/tee_gateway/test/test_ohttp.py new file mode 100644 index 0000000..44342ce --- /dev/null +++ b/tee_gateway/test/test_ohttp.py @@ -0,0 +1,99 @@ +"""Tests for the OHTTP encapsulation module.""" + +from __future__ import annotations + +import json + +import pytest + +from tee_gateway import ohttp + + +def test_round_trip_request_and_response(): + sk, pk_raw = ohttp.generate_keypair() + + plaintext = json.dumps({"model": "gpt-4.1", "n": 1}).encode() + # Encapsulate using the same code paths a client would, since pyhpke is + # symmetric — we wire the request manually. + config = ohttp.key_config(pk_raw) + assert config[0] == ohttp.KEY_CONFIG_ID + + # Build a wire payload exactly as the SDK does. + import struct + hdr = ( + bytes([ohttp.KEY_CONFIG_ID]) + + struct.pack( + ">HHH", + ohttp.KEM_ID_X25519, + ohttp.KDF_ID_HKDF_SHA256, + ohttp.AEAD_ID_CHACHA20_POLY1305, + ) + ) + info = b"message/bhttp request" + b"\x00" + hdr + pkr = ohttp._SUITE.kem.deserialize_public_key(pk_raw) + enc, sender = ohttp._SUITE.create_sender_context(pkr, info=info) + ct = sender.seal(plaintext, aad=b"") + wire = hdr + enc + ct + + decap = ohttp.decapsulate_request(sk, wire) + assert decap.plaintext == plaintext + + response_secret = sender.export(b"message/bhttp response", 32) + assert decap.response_key == response_secret + assert decap.enc == enc + + response = b'{"ok":true}' + sealed = ohttp.encapsulate_response(decap.response_key, decap.enc, response) + + # Round-trip the response on the "client" side using the same primitives. + import os + from cryptography.hazmat.primitives import hashes, hmac + from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 + from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand + + response_nonce = sealed[:32] + aead_ct = sealed[32:] + salt = enc + response_nonce + h = hmac.HMAC(salt, hashes.SHA256()) + h.update(response_secret) + prk = h.finalize() + key = HKDFExpand(algorithm=hashes.SHA256(), length=32, info=b"key").derive(prk) + nonce = HKDFExpand(algorithm=hashes.SHA256(), length=12, info=b"nonce").derive(prk) + assert ChaCha20Poly1305(key).decrypt(nonce, aead_ct, b"") == response + + +def test_rejects_wrong_suite(): + sk, pk_raw = ohttp.generate_keypair() + # Build a wire with the wrong AEAD ID + import struct + hdr = bytes([ohttp.KEY_CONFIG_ID]) + struct.pack( + ">HHH", ohttp.KEM_ID_X25519, ohttp.KDF_ID_HKDF_SHA256, 0x0001 # AES-128-GCM + ) + fake_wire = hdr + b"\x00" * 32 + b"\x00" * 16 + with pytest.raises(ValueError, match="unsupported"): + ohttp.decapsulate_request(sk, fake_wire) + + +def test_rejects_short_input(): + sk, _ = ohttp.generate_keypair() + with pytest.raises(ValueError, match="too short"): + ohttp.decapsulate_request(sk, b"\x01") + + +def test_rejects_tampered_ciphertext(): + sk, pk_raw = ohttp.generate_keypair() + import struct + hdr = bytes([ohttp.KEY_CONFIG_ID]) + struct.pack( + ">HHH", + ohttp.KEM_ID_X25519, + ohttp.KDF_ID_HKDF_SHA256, + ohttp.AEAD_ID_CHACHA20_POLY1305, + ) + info = b"message/bhttp request" + b"\x00" + hdr + pkr = ohttp._SUITE.kem.deserialize_public_key(pk_raw) + enc, sender = ohttp._SUITE.create_sender_context(pkr, info=info) + ct = sender.seal(b"hello", aad=b"") + wire = bytearray(hdr + enc + ct) + wire[-1] ^= 0xFF + with pytest.raises(Exception): + ohttp.decapsulate_request(sk, bytes(wire))