-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
118 lines (89 loc) · 4.04 KB
/
main.py
File metadata and controls
118 lines (89 loc) · 4.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import joblib
import numpy as np
import os
import logging
try:
from dotenv import load_dotenv
load_dotenv(override=False)
except ImportError:
pass
from utils import normalize_landmarks
# ── Structured logging ──────────────────────────────────────────────────────
# Root config applied at import time: every line carries severity, timestamp, message.
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(message)s")
logger = logging.getLogger("ml-api")
# FastAPI application instance; title/version surface in the generated OpenAPI docs.
app = FastAPI(title="ASL Classifier API", version="1.0.0")
# ── C.2.1: Restrict CORS to known frontend origins only ─────────────────────
def get_allowed_origins() -> list[str]:
    """Parse the comma-separated ALLOWED_ORIGINS env var into a list of origins.

    Whitespace around each entry is stripped and empty entries are dropped.
    Returns an empty list — and logs a warning — when nothing is configured,
    which causes the CORS middleware to reject all cross-origin requests.
    """
    configured = os.getenv("ALLOWED_ORIGINS", "")
    allowed = []
    for candidate in configured.split(","):
        candidate = candidate.strip()
        if candidate:
            allowed.append(candidate)
    if not allowed:
        logger.warning("No ALLOWED_ORIGINS configured. Cross-origin requests will be blocked.")
    return allowed
# Resolved once at import time; the middleware below captures this list.
ALLOWED_ORIGINS = get_allowed_origins()
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    # Only the methods/headers the frontend actually needs — no wildcards.
    allow_methods=["POST", "GET", "OPTIONS"],
    allow_headers=["Content-Type"],
)
# ── C.1.1: Lazy model loading — do NOT load at import time (cold-start safe) ─
# Path is relative to the process working directory; the file is produced by train.py.
MODEL_PATH = "sign_classifier.pkl"
# Module-level cache; populated on the first call to get_model().
_model = None
def get_model():
    """Return the cached classifier, loading it from disk on the first call.

    Raises:
        RuntimeError: the model file is absent (train.py has not been run).
    """
    global _model
    # Fast path: already loaded and cached from a previous request.
    if _model is not None:
        return _model
    if not os.path.exists(MODEL_PATH):
        logger.error("Model file not found: %s", MODEL_PATH)
        raise RuntimeError("Model file not found. Run train.py first.")
    logger.info("Loading model from %s …", MODEL_PATH)
    _model = joblib.load(MODEL_PATH)
    logger.info("Model loaded successfully.")
    return _model
# ── Pydantic schemas ─────────────────────────────────────────────────────────
class LandmarkItem(BaseModel):
    """A single hand landmark as a 3-D point.

    NOTE(review): presumably MediaPipe-style normalized coordinates — confirm
    against the frontend that produces them.
    """
    x: float
    y: float
    z: float
class PredictionRequest(BaseModel):
    """Request body for /classify; exactly 21 landmarks are enforced in the endpoint."""
    landmarks: list[LandmarkItem]
# ── Endpoints ────────────────────────────────────────────────────────────────
@app.get("/")
def health():
    """Liveness probe: reports the API is up and whether the model file exists on disk."""
    return {
        "status": "online",
        "model_ready": os.path.exists(MODEL_PATH),
        "version": "1.0.0",
    }
@app.post("/classify")
def classify(request: PredictionRequest):
    """Classify a 21-landmark hand pose and return the predicted sign with confidence.

    Returns:
        dict with "sign" (predicted label as str) and "confidence" (max class
        probability, rounded to 4 decimals).

    Raises:
        HTTPException 422: wrong landmark count.
        HTTPException 503: model file missing (train.py not run yet).
        HTTPException 400: normalization or inference failed on the given data.
    """
    # C.1.3: Validate landmark count before touching the model
    if len(request.landmarks) != 21:
        raise HTTPException(
            status_code=422,
            detail=f"Expected exactly 21 landmarks, received {len(request.landmarks)}."
        )
    # C.1.1: Lazy-load model (safe for serverless cold starts)
    try:
        model = get_model()
    except RuntimeError as exc:
        # Chain the cause (B904) so server-side tracebacks stay informative.
        raise HTTPException(status_code=503, detail=str(exc)) from exc
    try:
        raw_landmarks = [{"x": l.x, "y": l.y, "z": l.z} for l in request.landmarks]
        features = normalize_landmarks(raw_landmarks)
        prediction = model.predict([features])[0]
        probs = model.predict_proba([features])[0]
        confidence = float(max(probs))
        logger.info("Prediction: %s (confidence=%.4f)", prediction, confidence)
        return {"sign": str(prediction), "confidence": round(confidence, 4)}
    except Exception as exc:
        # C.2.3: Do NOT expose raw exception strings to the client
        logger.error("Prediction error: %s", exc, exc_info=True)
        raise HTTPException(
            status_code=400,
            detail="Prediction failed due to invalid input data."
        ) from exc
if __name__ == "__main__":
    # Local development entry point; in production run via an ASGI server / process manager.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)