diff --git a/pyproject.toml b/pyproject.toml index 80a9bc6..426ab33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dependencies = [ "psutil>=7.2.1", "python-dotenv>=1.2.1", "eth-account>=0.13.0", + "pyhpke>=0.6.0", ] [dependency-groups] diff --git a/tee_gateway/__main__.py b/tee_gateway/__main__.py index 165d5d1..ad087af 100644 --- a/tee_gateway/__main__.py +++ b/tee_gateway/__main__.py @@ -23,6 +23,10 @@ ) from tee_gateway.llm_backend import get_provider_config, set_provider_config from tee_gateway.heartbeat import create_heartbeat_service +from tee_gateway.controllers.ohttp_controller import ( + create_anonymous_chat_completion, + get_hpke_config, +) from x402.http import FacilitatorConfig, HTTPFacilitatorClientSync, PaymentOption from x402.http.middleware.flask import payment_middleware @@ -437,6 +441,20 @@ def create_app(): "/heartbeat/status", "heartbeat-status", heartbeat_status, methods=["GET"] ) + # Anonymous inference (OHTTP-wrapped chat completions). Deliberately + # mounted via add_url_rule rather than the OpenAPI spec because the body + # is raw binary and connexion's request-validation pipeline would reject + # it as malformed JSON. + app.app.add_url_rule( + "/v1/ohttp", + "anonymous-chat", + create_anonymous_chat_completion, + methods=["POST"], + ) + app.app.add_url_rule( + "/v1/ohttp/config", "ohttp-config", get_hpke_config, methods=["GET"] + ) + # Initialize TEE here so it runs under both Gunicorn and direct execution. # This is the single TEEKeyManager instance — the same key both registers # with nitriding and signs all LLM responses. diff --git a/tee_gateway/controllers/ohttp_controller.py b/tee_gateway/controllers/ohttp_controller.py new file mode 100644 index 0000000..688e862 --- /dev/null +++ b/tee_gateway/controllers/ohttp_controller.py @@ -0,0 +1,141 @@ +""" +Oblivious HTTP endpoint for anonymous inference. + +This handler is intentionally minimal: it does HPKE decapsulation, dispatches +the inner request to the existing chat-completions handler in-process (no +network hop), and HPKE-encapsulates the response. The inner JSON request is +identical to the standard /v1/chat/completions body. + +Threat model nuances: + * The relay in front of this endpoint sees the encapsulated ciphertext and + the client IP, but no request content. + * The enclave sees plaintext and the relay's IP, never the client's. + * If the client's payload contains identifiers (cookies, ``user`` field, + custom request IDs), unlinkability is broken at the application layer — + we strip the obvious ones below. + * Streaming is intentionally not supported on this endpoint. SSE would + create per-chunk side channels (timing, length) that defeat the point of + bundling everything into a single sealed response. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +import connexion +from flask import Response + +from tee_gateway import ohttp +from tee_gateway.tee_manager import get_tee_keys + +logger = logging.getLogger(__name__) + +OHTTP_MEDIA_TYPE = "message/ohttp-req" +OHTTP_RESPONSE_MEDIA_TYPE = "message/ohttp-res" + +# Fields that can re-identify a client and have no role in inference. We drop +# them before forwarding to the inner handler — keeping them inside the +# encrypted envelope would only protect them from the relay, not from us or +# the upstream LLM provider. +_IDENTIFYING_FIELDS = ("user", "metadata", "x-request-id", "request_id") + + +def _scrub(payload: dict[str, Any]) -> dict[str, Any]: + return {k: v for k, v in payload.items() if k not in _IDENTIFYING_FIELDS} + + +def create_anonymous_chat_completion(): + """POST /v1/ohttp — decrypt, dispatch, re-encrypt. + + Body: raw bytes (OHTTP-encapsulated request). + Returns: raw bytes (OHTTP-encapsulated response) with Content-Type + ``message/ohttp-res``. + """ + req = connexion.request + # Tolerate both Connexion's Flask request and a bare Flask request. + raw_body: bytes = req.get_data(cache=False) + if not raw_body: + return _error(400, "empty body") + + tee = get_tee_keys() + if tee.hpke_private_key is None: + return _error(503, "anonymous inference not initialized") + + try: + decap = ohttp.decapsulate_request(tee.hpke_private_key, raw_body) + except Exception as exc: + # Don't leak which step failed — clients can retry with a fresh + # encapsulation, all observable failures look identical. + logger.warning("OHTTP decapsulation failed: %s", type(exc).__name__) + return _error(400, "malformed encapsulated request") + + try: + inner_body = json.loads(decap.plaintext.decode("utf-8")) + except (UnicodeDecodeError, json.JSONDecodeError): + return _error(400, "inner payload is not valid JSON") + + if not isinstance(inner_body, dict): + return _error(400, "inner payload must be a JSON object") + + if inner_body.get("stream"): + # Streaming is rejected on principle (see module docstring). Clients + # who want low TTFT under anonymity should use a shorter max_tokens. + return _error(400, "stream=true is not supported over OHTTP") + + inner_body = _scrub(inner_body) + + # Late import to avoid a circular dependency at module load (the chat + # controller pulls in models that import this package). + from tee_gateway.controllers.chat_controller import ( + _create_non_streaming_response, + _parse_chat_request, + ) + + try: + chat_request = _parse_chat_request(inner_body) + inner_result = _create_non_streaming_response(chat_request) + except Exception as exc: + logger.error("inner inference failed under OHTTP: %s", exc, exc_info=True) + inner_result = ({"error": "inference failed"}, 500) + + # _create_non_streaming_response returns either a dict or (body, status) + if isinstance(inner_result, tuple): + body_obj, status = inner_result + else: + body_obj, status = inner_result, 200 + + inner_json = json.dumps( + {"status": status, "body": body_obj}, + separators=(",", ":"), + ).encode("utf-8") + + sealed = ohttp.encapsulate_response( + decap.response_key, decap.enc, inner_json + ) + return Response( + sealed, + status=200, + mimetype=OHTTP_RESPONSE_MEDIA_TYPE, + ) + + +def get_hpke_config(): + """GET /v1/ohttp/config — return the HPKE key configuration. + + Returns both an OHTTP-compliant binary key_config (base64) and the + individual fields for clients that prefer to parse JSON. The same data is + embedded inside the attestation document at /signing-key for clients that + want to verify the binding to the enclave's PCRs in one step. + """ + try: + tee = get_tee_keys() + return tee.get_hpke_config(), 200 + except Exception as exc: + logger.error("HPKE config error: %s", exc, exc_info=True) + return {"error": str(exc)}, 500 + + +def _error(status: int, message: str) -> tuple[dict, int]: + return {"error": message}, status diff --git a/tee_gateway/ohttp.py b/tee_gateway/ohttp.py new file mode 100644 index 0000000..d5369aa --- /dev/null +++ b/tee_gateway/ohttp.py @@ -0,0 +1,166 @@ +""" +Oblivious HTTP encapsulation for anonymous inference. + +Implements request/response encapsulation per RFC 9458 (Oblivious HTTP) +with a fixed HPKE ciphersuite: + - KEM: DHKEM(X25519, HKDF-SHA256) (0x0020) + - KDF: HKDF-SHA256 (0x0001) + - AEAD: ChaCha20-Poly1305 (0x0003) + +The inner payload is application/json — we do not BHTTP-wrap the inference +request, since the enclave is the terminal endpoint and not a generic HTTP +proxy. This is a documented divergence from strict RFC 9458; the cryptographic +construction (HPKE base + exported response keying) is identical. + +Trust model: the relay sees ciphertext + client IP; the enclave sees plaintext ++ relay IP. Unlinkability holds unless relay and enclave collude. +""" + +from __future__ import annotations + +import os +import struct +from dataclasses import dataclass + +from cryptography.hazmat.primitives import hashes, hmac +from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 +from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand +from pyhpke import AEADId, CipherSuite, KDFId, KEMId +from pyhpke.kem_key_interface import KEMKeyInterface + + +# RFC 9180 / 9458 algorithm identifiers +KEM_ID_X25519 = 0x0020 +KDF_ID_HKDF_SHA256 = 0x0001 +AEAD_ID_CHACHA20_POLY1305 = 0x0003 + +# Single, stable key configuration ID. Bump when the keypair or suite changes +# so clients can refuse stale configs. +KEY_CONFIG_ID = 0x01 + +# AEAD parameters for ChaCha20-Poly1305 +_NK = 32 # key length +_NN = 12 # nonce length + +# Per RFC 9458 §4.1/4.2 — "info" labels for the HPKE context. +_LABEL_REQUEST = b"message/bhttp request" +_LABEL_RESPONSE = b"message/bhttp response" + +_SUITE = CipherSuite.new( + KEMId.DHKEM_X25519_HKDF_SHA256, + KDFId.HKDF_SHA256, + AEADId.CHACHA20_POLY1305, +) + + +def _header_bytes(key_id: int = KEY_CONFIG_ID) -> bytes: + return bytes([key_id]) + struct.pack( + ">HHH", + KEM_ID_X25519, + KDF_ID_HKDF_SHA256, + AEAD_ID_CHACHA20_POLY1305, + ) + + +def key_config(public_key_raw: bytes, key_id: int = KEY_CONFIG_ID) -> bytes: + """Build an OHTTP key configuration blob (RFC 9458 §3). + + Format: + key_id(1) || kem_id(2) || public_key(Npk=32) || + symmetric_algorithms_length(2) || (kdf_id(2) || aead_id(2))+ + """ + if len(public_key_raw) != 32: + raise ValueError("X25519 public key must be 32 bytes") + symmetric = struct.pack(">HH", KDF_ID_HKDF_SHA256, AEAD_ID_CHACHA20_POLY1305) + return ( + bytes([key_id]) + + struct.pack(">H", KEM_ID_X25519) + + public_key_raw + + struct.pack(">H", len(symmetric)) + + symmetric + ) + + +@dataclass +class DecapsulatedRequest: + """Result of decapsulating an OHTTP-wrapped request.""" + + plaintext: bytes + response_key: bytes # 32 bytes exported from the HPKE context + enc: bytes # client's ephemeral public key, used as salt for the response + + +def decapsulate_request( + private_key: KEMKeyInterface, encapsulated_request: bytes +) -> DecapsulatedRequest: + """Decrypt an HPKE-wrapped request inside the enclave. + + Raises ValueError on malformed input or unsupported ciphersuite. We + never echo the underlying exception text to clients — it can leak + timing/oracle info. + """ + if len(encapsulated_request) < 7 + 32: + raise ValueError("encapsulated request too short") + + key_id = encapsulated_request[0] + kem_id, kdf_id, aead_id = struct.unpack(">HHH", encapsulated_request[1:7]) + if (key_id, kem_id, kdf_id, aead_id) != ( + KEY_CONFIG_ID, + KEM_ID_X25519, + KDF_ID_HKDF_SHA256, + AEAD_ID_CHACHA20_POLY1305, + ): + raise ValueError("unsupported HPKE configuration") + + enc = encapsulated_request[7 : 7 + 32] + aead_ct = encapsulated_request[7 + 32 :] + + info = _LABEL_REQUEST + b"\x00" + _header_bytes(key_id) + recipient = _SUITE.create_recipient_context(enc, private_key, info=info) + plaintext = recipient.open(aead_ct, aad=b"") + + # Export a fresh secret bound to this HPKE context, used to derive the + # response AEAD key. This is the OHTTP-defined separation between the + # request and response halves of the same exchange. + response_secret = recipient.export(_LABEL_RESPONSE, _NK) + return DecapsulatedRequest( + plaintext=plaintext, response_key=response_secret, enc=enc + ) + + +def encapsulate_response( + response_secret: bytes, enc: bytes, plaintext: bytes +) -> bytes: + """Seal a response under the per-request derived key (RFC 9458 §4.2). + + Wire format: response_nonce(max(Nn, Nk)=Nk=32) || AEAD ciphertext + """ + response_nonce = os.urandom(max(_NN, _NK)) + salt = enc + response_nonce + + h = hmac.HMAC(salt, hashes.SHA256()) + h.update(response_secret) + prk = h.finalize() + + aead_key = HKDFExpand( + algorithm=hashes.SHA256(), length=_NK, info=b"key" + ).derive(prk) + aead_nonce = HKDFExpand( + algorithm=hashes.SHA256(), length=_NN, info=b"nonce" + ).derive(prk) + + ct = ChaCha20Poly1305(aead_key).encrypt(aead_nonce, plaintext, b"") + return response_nonce + ct + + +def generate_keypair() -> tuple[KEMKeyInterface, bytes]: + """Generate an X25519 keypair for HPKE. Returns (private_key, public_key_raw). + + pyhpke 0.6 derives keys from random IKM via ``kem.derive_key_pair(ikm)``, + which returns a ``KEMKeyPair`` wrapper. We hold onto the private side for + decapsulation and serialize the public side to raw 32-byte form for the + key configuration blob. + """ + pair = _SUITE.kem.derive_key_pair(os.urandom(32)) + pk_raw = pair.public_key.to_public_bytes() + return pair.private_key, pk_raw diff --git a/tee_gateway/tee_manager.py b/tee_gateway/tee_manager.py index 3d98b12..402438e 100644 --- a/tee_gateway/tee_manager.py +++ b/tee_gateway/tee_manager.py @@ -17,6 +17,8 @@ from eth_account import Account from eth_hash.auto import keccak +from tee_gateway import ohttp + logger = logging.getLogger(__name__) NITRIDING_BASE_URL = "http://127.0.0.1:8080" @@ -36,6 +38,11 @@ def __init__(self, register=True): self.public_key_pem = None self.tee_id = None self.wallet_address = None + # HPKE keypair for OHTTP-style anonymous inference. Generated in the + # same enclave boot so the X25519 public key is covered by the same + # attestation that covers the RSA signing key. + self.hpke_private_key = None + self.hpke_public_key_raw: bytes | None = None self._generate_keys() if register: self.register_with_nitriding() @@ -70,19 +77,39 @@ def _generate_keys(self): wallet_account = Account.from_key(wallet_key_bytes) self.wallet_address = wallet_account.address + # HPKE X25519 keypair — never leaves the enclave; clients address it + # via the public-key fingerprint published with the attestation. + self.hpke_private_key, self.hpke_public_key_raw = ohttp.generate_keypair() + logger.info("TEE key pair generated successfully") logger.info(f"tee_id: 0x{self.tee_id}") logger.info(f"wallet_address: {self.wallet_address}") + logger.info( + f"hpke_public_key (X25519, raw, hex): {self.hpke_public_key_raw.hex()}" + ) def register_with_nitriding(self): - """Register public key hash with nitriding.""" + """Register public key hash with nitriding. + + The hash covers both the RSA signing key (DER-encoded SPKI) and the + raw X25519 HPKE public key. Including both in a single attested digest + means a verifier who validates the attestation document automatically + gets binding for the HPKE config used for anonymous inference — no + separate trust anchor required. + """ try: public_key_der = self.public_key.public_bytes( encoding=serialization.Encoding.DER, format=serialization.PublicFormat.SubjectPublicKeyInfo, ) - key_hash = hashlib.sha256(public_key_der).digest() + # Domain-separated transcript so a future addition of more keys + # can't be confused with the existing layout. + transcript = ( + b"og-tee-keys|v2|rsa-spki=" + public_key_der + + b"|hpke-x25519=" + (self.hpke_public_key_raw or b"") + ) + key_hash = hashlib.sha256(transcript).digest() key_hash_b64 = base64.b64encode(key_hash).decode("utf-8") logger.info(f"Public key DER length: {len(public_key_der)} bytes") @@ -149,12 +176,34 @@ def get_wallet_address(self) -> str: """Return the TEE-generated Ethereum wallet address (checksum).""" return self.wallet_address + def get_hpke_config(self) -> dict: + """Return the HPKE key configuration for anonymous inference. + + ``key_config`` is the RFC 9458 §3 binary key-config blob, base64-encoded. + Clients should treat this as authoritative only when fetched alongside + the Nitro attestation document (which commits to the same key hash via + nitriding registration). + """ + if self.hpke_public_key_raw is None: + raise RuntimeError("HPKE keypair not initialized") + return { + "key_id": ohttp.KEY_CONFIG_ID, + "kem_id": ohttp.KEM_ID_X25519, + "kdf_id": ohttp.KDF_ID_HKDF_SHA256, + "aead_id": ohttp.AEAD_ID_CHACHA20_POLY1305, + "public_key": self.hpke_public_key_raw.hex(), + "key_config": base64.b64encode( + ohttp.key_config(self.hpke_public_key_raw) + ).decode("ascii"), + } + def get_attestation_document(self) -> dict: """Return TEE attestation document.""" return { "public_key": self.public_key_pem, "tee_id": f"0x{self.tee_id}", "wallet_address": self.wallet_address, + "hpke": self.get_hpke_config() if self.hpke_public_key_raw else None, "timestamp": datetime.now(UTC).isoformat(), "enclave_info": { "platform": "aws-nitro", diff --git a/tee_gateway/test/test_ohttp.py b/tee_gateway/test/test_ohttp.py new file mode 100644 index 0000000..44342ce --- /dev/null +++ b/tee_gateway/test/test_ohttp.py @@ -0,0 +1,99 @@ +"""Tests for the OHTTP encapsulation module.""" + +from __future__ import annotations + +import json + +import pytest + +from tee_gateway import ohttp + + +def test_round_trip_request_and_response(): + sk, pk_raw = ohttp.generate_keypair() + + plaintext = json.dumps({"model": "gpt-4.1", "n": 1}).encode() + # Encapsulate using the same code paths a client would, since pyhpke is + # symmetric — we wire the request manually. + config = ohttp.key_config(pk_raw) + assert config[0] == ohttp.KEY_CONFIG_ID + + # Build a wire payload exactly as the SDK does. + import struct + hdr = ( + bytes([ohttp.KEY_CONFIG_ID]) + + struct.pack( + ">HHH", + ohttp.KEM_ID_X25519, + ohttp.KDF_ID_HKDF_SHA256, + ohttp.AEAD_ID_CHACHA20_POLY1305, + ) + ) + info = b"message/bhttp request" + b"\x00" + hdr + pkr = ohttp._SUITE.kem.deserialize_public_key(pk_raw) + enc, sender = ohttp._SUITE.create_sender_context(pkr, info=info) + ct = sender.seal(plaintext, aad=b"") + wire = hdr + enc + ct + + decap = ohttp.decapsulate_request(sk, wire) + assert decap.plaintext == plaintext + + response_secret = sender.export(b"message/bhttp response", 32) + assert decap.response_key == response_secret + assert decap.enc == enc + + response = b'{"ok":true}' + sealed = ohttp.encapsulate_response(decap.response_key, decap.enc, response) + + # Round-trip the response on the "client" side using the same primitives. + import os + from cryptography.hazmat.primitives import hashes, hmac + from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 + from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand + + response_nonce = sealed[:32] + aead_ct = sealed[32:] + salt = enc + response_nonce + h = hmac.HMAC(salt, hashes.SHA256()) + h.update(response_secret) + prk = h.finalize() + key = HKDFExpand(algorithm=hashes.SHA256(), length=32, info=b"key").derive(prk) + nonce = HKDFExpand(algorithm=hashes.SHA256(), length=12, info=b"nonce").derive(prk) + assert ChaCha20Poly1305(key).decrypt(nonce, aead_ct, b"") == response + + +def test_rejects_wrong_suite(): + sk, pk_raw = ohttp.generate_keypair() + # Build a wire with the wrong AEAD ID + import struct + hdr = bytes([ohttp.KEY_CONFIG_ID]) + struct.pack( + ">HHH", ohttp.KEM_ID_X25519, ohttp.KDF_ID_HKDF_SHA256, 0x0001 # AES-128-GCM + ) + fake_wire = hdr + b"\x00" * 32 + b"\x00" * 16 + with pytest.raises(ValueError, match="unsupported"): + ohttp.decapsulate_request(sk, fake_wire) + + +def test_rejects_short_input(): + sk, _ = ohttp.generate_keypair() + with pytest.raises(ValueError, match="too short"): + ohttp.decapsulate_request(sk, b"\x01") + + +def test_rejects_tampered_ciphertext(): + sk, pk_raw = ohttp.generate_keypair() + import struct + hdr = bytes([ohttp.KEY_CONFIG_ID]) + struct.pack( + ">HHH", + ohttp.KEM_ID_X25519, + ohttp.KDF_ID_HKDF_SHA256, + ohttp.AEAD_ID_CHACHA20_POLY1305, + ) + info = b"message/bhttp request" + b"\x00" + hdr + pkr = ohttp._SUITE.kem.deserialize_public_key(pk_raw) + enc, sender = ohttp._SUITE.create_sender_context(pkr, info=info) + ct = sender.seal(b"hello", aad=b"") + wire = bytearray(hdr + enc + ct) + wire[-1] ^= 0xFF + with pytest.raises(Exception): + ohttp.decapsulate_request(sk, bytes(wire))