Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies = [
"psutil>=7.2.1",
"python-dotenv>=1.2.1",
"eth-account>=0.13.0",
"pyhpke>=0.6.0",
]

[dependency-groups]
Expand Down
18 changes: 18 additions & 0 deletions tee_gateway/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
)
from tee_gateway.llm_backend import get_provider_config, set_provider_config
from tee_gateway.heartbeat import create_heartbeat_service
from tee_gateway.controllers.ohttp_controller import (
create_anonymous_chat_completion,
get_hpke_config,
)

from x402.http import FacilitatorConfig, HTTPFacilitatorClientSync, PaymentOption
from x402.http.middleware.flask import payment_middleware
Expand Down Expand Up @@ -437,6 +441,20 @@ def create_app():
"/heartbeat/status", "heartbeat-status", heartbeat_status, methods=["GET"]
)

# Anonymous inference (OHTTP-wrapped chat completions). Deliberately
# mounted via add_url_rule rather than the OpenAPI spec because the body
# is raw binary and connexion's request-validation pipeline would reject
# it as malformed JSON.
app.app.add_url_rule(
"/v1/ohttp",
"anonymous-chat",
create_anonymous_chat_completion,
methods=["POST"],
)
app.app.add_url_rule(
"/v1/ohttp/config", "ohttp-config", get_hpke_config, methods=["GET"]
)

# Initialize TEE here so it runs under both Gunicorn and direct execution.
# This is the single TEEKeyManager instance — the same key both registers
# with nitriding and signs all LLM responses.
Expand Down
141 changes: 141 additions & 0 deletions tee_gateway/controllers/ohttp_controller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""
Oblivious HTTP endpoint for anonymous inference.

This handler is intentionally minimal: it does HPKE decapsulation, dispatches
the inner request to the existing chat-completions handler in-process (no
network hop), and HPKE-encapsulates the response. The inner JSON request is
identical to the standard /v1/chat/completions body.

Threat model nuances:
* The relay in front of this endpoint sees the encapsulated ciphertext and
the client IP, but no request content.
* The enclave sees plaintext and the relay's IP, never the client's.
* If the client's payload contains identifiers (cookies, ``user`` field,
custom request IDs), unlinkability is broken at the application layer —
we strip the obvious ones below.
* Streaming is intentionally not supported on this endpoint. SSE would
create per-chunk side channels (timing, length) that defeat the point of
bundling everything into a single sealed response.
"""

from __future__ import annotations

import json
import logging
from typing import Any

import connexion
from flask import Response

from tee_gateway import ohttp
from tee_gateway.tee_manager import get_tee_keys

logger = logging.getLogger(__name__)

OHTTP_MEDIA_TYPE = "message/ohttp-req"
OHTTP_RESPONSE_MEDIA_TYPE = "message/ohttp-res"

# Fields that can re-identify a client and have no role in inference. We drop
# them before forwarding to the inner handler — keeping them inside the
# encrypted envelope would only protect them from the relay, not from us or
# the upstream LLM provider.
_IDENTIFYING_FIELDS = ("user", "metadata", "x-request-id", "request_id")


def _scrub(payload: dict[str, Any]) -> dict[str, Any]:
return {k: v for k, v in payload.items() if k not in _IDENTIFYING_FIELDS}


def create_anonymous_chat_completion():
"""POST /v1/ohttp — decrypt, dispatch, re-encrypt.

Body: raw bytes (OHTTP-encapsulated request).
Returns: raw bytes (OHTTP-encapsulated response) with Content-Type
``message/ohttp-res``.
"""
req = connexion.request
# Tolerate both Connexion's Flask request and a bare Flask request.
raw_body: bytes = req.get_data(cache=False)
if not raw_body:
return _error(400, "empty body")

tee = get_tee_keys()
if tee.hpke_private_key is None:
return _error(503, "anonymous inference not initialized")

try:
decap = ohttp.decapsulate_request(tee.hpke_private_key, raw_body)
except Exception as exc:
# Don't leak which step failed — clients can retry with a fresh
# encapsulation, all observable failures look identical.
logger.warning("OHTTP decapsulation failed: %s", type(exc).__name__)
return _error(400, "malformed encapsulated request")

try:
inner_body = json.loads(decap.plaintext.decode("utf-8"))
except (UnicodeDecodeError, json.JSONDecodeError):
return _error(400, "inner payload is not valid JSON")

if not isinstance(inner_body, dict):
return _error(400, "inner payload must be a JSON object")

if inner_body.get("stream"):
# Streaming is rejected on principle (see module docstring). Clients
# who want low TTFT under anonymity should use a shorter max_tokens.
return _error(400, "stream=true is not supported over OHTTP")

inner_body = _scrub(inner_body)

# Late import to avoid a circular dependency at module load (the chat
# controller pulls in models that import this package).
from tee_gateway.controllers.chat_controller import (
_create_non_streaming_response,
_parse_chat_request,
)

try:
chat_request = _parse_chat_request(inner_body)
inner_result = _create_non_streaming_response(chat_request)
except Exception as exc:
logger.error("inner inference failed under OHTTP: %s", exc, exc_info=True)
inner_result = ({"error": "inference failed"}, 500)

# _create_non_streaming_response returns either a dict or (body, status)
if isinstance(inner_result, tuple):
body_obj, status = inner_result
else:
body_obj, status = inner_result, 200

inner_json = json.dumps(
{"status": status, "body": body_obj},
separators=(",", ":"),
).encode("utf-8")

sealed = ohttp.encapsulate_response(
decap.response_key, decap.enc, inner_json
)
return Response(
sealed,
status=200,
mimetype=OHTTP_RESPONSE_MEDIA_TYPE,
)


def get_hpke_config():
"""GET /v1/ohttp/config — return the HPKE key configuration.

Returns both an OHTTP-compliant binary key_config (base64) and the
individual fields for clients that prefer to parse JSON. The same data is
embedded inside the attestation document at /signing-key for clients that
want to verify the binding to the enclave's PCRs in one step.
"""
try:
tee = get_tee_keys()
return tee.get_hpke_config(), 200
except Exception as exc:
logger.error("HPKE config error: %s", exc, exc_info=True)
return {"error": str(exc)}, 500


def _error(status: int, message: str) -> tuple[dict, int]:
return {"error": message}, status
166 changes: 166 additions & 0 deletions tee_gateway/ohttp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""
Oblivious HTTP encapsulation for anonymous inference.

Implements request/response encapsulation per RFC 9458 (Oblivious HTTP)
with a fixed HPKE ciphersuite:
- KEM: DHKEM(X25519, HKDF-SHA256) (0x0020)
- KDF: HKDF-SHA256 (0x0001)
- AEAD: ChaCha20-Poly1305 (0x0003)

The inner payload is application/json — we do not BHTTP-wrap the inference
request, since the enclave is the terminal endpoint and not a generic HTTP
proxy. This is a documented divergence from strict RFC 9458; the cryptographic
construction (HPKE base + exported response keying) is identical.

Trust model: the relay sees ciphertext + client IP; the enclave sees plaintext
+ relay IP. Unlinkability holds unless relay and enclave collude.
"""

from __future__ import annotations

import os
import struct
from dataclasses import dataclass

from cryptography.hazmat.primitives import hashes, hmac
from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305
from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand
from pyhpke import AEADId, CipherSuite, KDFId, KEMId
from pyhpke.kem_key_interface import KEMKeyInterface


# RFC 9180 / 9458 algorithm identifiers
KEM_ID_X25519 = 0x0020
KDF_ID_HKDF_SHA256 = 0x0001
AEAD_ID_CHACHA20_POLY1305 = 0x0003

# Single, stable key configuration ID. Bump when the keypair or suite changes
# so clients can refuse stale configs.
KEY_CONFIG_ID = 0x01

# AEAD parameters for ChaCha20-Poly1305
_NK = 32 # key length
_NN = 12 # nonce length

# Per RFC 9458 §4.1/4.2 — "info" labels for the HPKE context.
_LABEL_REQUEST = b"message/bhttp request"
_LABEL_RESPONSE = b"message/bhttp response"

_SUITE = CipherSuite.new(
KEMId.DHKEM_X25519_HKDF_SHA256,
KDFId.HKDF_SHA256,
AEADId.CHACHA20_POLY1305,
)


def _header_bytes(key_id: int = KEY_CONFIG_ID) -> bytes:
return bytes([key_id]) + struct.pack(
">HHH",
KEM_ID_X25519,
KDF_ID_HKDF_SHA256,
AEAD_ID_CHACHA20_POLY1305,
)


def key_config(public_key_raw: bytes, key_id: int = KEY_CONFIG_ID) -> bytes:
"""Build an OHTTP key configuration blob (RFC 9458 §3).

Format:
key_id(1) || kem_id(2) || public_key(Npk=32) ||
symmetric_algorithms_length(2) || (kdf_id(2) || aead_id(2))+
"""
if len(public_key_raw) != 32:
raise ValueError("X25519 public key must be 32 bytes")
symmetric = struct.pack(">HH", KDF_ID_HKDF_SHA256, AEAD_ID_CHACHA20_POLY1305)
return (
bytes([key_id])
+ struct.pack(">H", KEM_ID_X25519)
+ public_key_raw
+ struct.pack(">H", len(symmetric))
+ symmetric
)


@dataclass
class DecapsulatedRequest:
"""Result of decapsulating an OHTTP-wrapped request."""

plaintext: bytes
response_key: bytes # 32 bytes exported from the HPKE context
enc: bytes # client's ephemeral public key, used as salt for the response


def decapsulate_request(
private_key: KEMKeyInterface, encapsulated_request: bytes
) -> DecapsulatedRequest:
"""Decrypt an HPKE-wrapped request inside the enclave.

Raises ValueError on malformed input or unsupported ciphersuite. We
never echo the underlying exception text to clients — it can leak
timing/oracle info.
"""
if len(encapsulated_request) < 7 + 32:
raise ValueError("encapsulated request too short")

key_id = encapsulated_request[0]
kem_id, kdf_id, aead_id = struct.unpack(">HHH", encapsulated_request[1:7])
if (key_id, kem_id, kdf_id, aead_id) != (
KEY_CONFIG_ID,
KEM_ID_X25519,
KDF_ID_HKDF_SHA256,
AEAD_ID_CHACHA20_POLY1305,
):
raise ValueError("unsupported HPKE configuration")

enc = encapsulated_request[7 : 7 + 32]
aead_ct = encapsulated_request[7 + 32 :]

info = _LABEL_REQUEST + b"\x00" + _header_bytes(key_id)
recipient = _SUITE.create_recipient_context(enc, private_key, info=info)
plaintext = recipient.open(aead_ct, aad=b"")

# Export a fresh secret bound to this HPKE context, used to derive the
# response AEAD key. This is the OHTTP-defined separation between the
# request and response halves of the same exchange.
response_secret = recipient.export(_LABEL_RESPONSE, _NK)
return DecapsulatedRequest(
plaintext=plaintext, response_key=response_secret, enc=enc
)


def encapsulate_response(
response_secret: bytes, enc: bytes, plaintext: bytes
) -> bytes:
"""Seal a response under the per-request derived key (RFC 9458 §4.2).

Wire format: response_nonce(max(Nn, Nk)=Nk=32) || AEAD ciphertext
"""
response_nonce = os.urandom(max(_NN, _NK))
salt = enc + response_nonce

h = hmac.HMAC(salt, hashes.SHA256())
h.update(response_secret)
prk = h.finalize()

aead_key = HKDFExpand(
algorithm=hashes.SHA256(), length=_NK, info=b"key"
).derive(prk)
aead_nonce = HKDFExpand(
algorithm=hashes.SHA256(), length=_NN, info=b"nonce"
).derive(prk)

ct = ChaCha20Poly1305(aead_key).encrypt(aead_nonce, plaintext, b"")
return response_nonce + ct


def generate_keypair() -> tuple[KEMKeyInterface, bytes]:
"""Generate an X25519 keypair for HPKE. Returns (private_key, public_key_raw).

pyhpke 0.6 derives keys from random IKM via ``kem.derive_key_pair(ikm)``,
which returns a ``KEMKeyPair`` wrapper. We hold onto the private side for
decapsulation and serialize the public side to raw 32-byte form for the
key configuration blob.
"""
pair = _SUITE.kem.derive_key_pair(os.urandom(32))
pk_raw = pair.public_key.to_public_bytes()
return pair.private_key, pk_raw
Loading
Loading