# app.py
# End-to-end FastVLM service optimized for CPU latency.
# Uses ONNX Runtime if model.onnx is present; otherwise falls back to an optimized PyTorch path.
# IMPORTANT: do not import transformers/huggingface_hub BEFORE the env/cache setup below.
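# Local run (a sketch; on HF Spaces the container's own launcher starts the server,
# and 7860 is only the conventional Spaces port):
#   uvicorn app:app --host 0.0.0.0 --port 7860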
import os
from pathlib import Path
import io
import base64
import asyncio
import logging
import time
from typing import Optional
# -------------------------
# 1) Cache + threading envs BEFORE HF imports
# -------------------------
CACHE_ROOT = "/app/.cache"
os.environ.setdefault("XDG_CACHE_HOME", CACHE_ROOT)
os.environ.setdefault("HF_HOME", f"{CACHE_ROOT}/hf")
os.environ.setdefault("TRANSFORMERS_CACHE", f"{CACHE_ROOT}/transformers")
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", f"{CACHE_ROOT}/hf")
os.environ.setdefault("HF_DATASETS_CACHE", f"{CACHE_ROOT}/datasets")
# limit threads for 2 vCPU environment
os.environ.setdefault("OMP_NUM_THREADS", "2")
os.environ.setdefault("MKL_NUM_THREADS", "2")
Path(CACHE_ROOT).mkdir(parents=True, exist_ok=True)
try:
    os.chmod(CACHE_ROOT, 0o777)
except Exception:
    pass
# -------------------------
# Now safe to import heavy libs
# -------------------------
from fastapi import FastAPI, WebSocket, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
import numpy as np
# Try imports for ONNX Runtime and torch
USE_ONNX = False
onnxruntime = None
try:
    import onnxruntime as ort
    onnxruntime = ort
except Exception:
    onnxruntime = None
try:
    import torch
    TORCH_AVAILABLE = True
    # limit torch threads (double-safety)
    try:
        torch.set_num_threads(2)
    except Exception:
        pass
except Exception:
    TORCH_AVAILABLE = False
logger = logging.getLogger("uvicorn.error")
app = FastAPI()
# serve static UI from /static
app.mount("/static", StaticFiles(directory="static"), name="static")
# index route
@app.get("/", response_class=FileResponse)
async def index():
    p = Path("static/index.html")
    if p.exists():
        return FileResponse(p)
    return JSONResponse({"error": "static/index.html not found"}, status_code=404)
# Globals
MODEL_ONNX_PATH = Path("model.onnx")
USE_ONNX = MODEL_ONNX_PATH.exists() and (onnxruntime is not None)
# Model placeholders
tok = None
model = None
device = None
PRE_IDS = None
POST_IDS = None
DEFAULT_PROMPT = "<image>\nDescribe this image in one complete sentence, nothing further than that."
# ONNX session placeholder
ort_session = None
# active ws map (optional)
active_ws_by_ip = {}
# -------------------------
# Startup: load model & tokenizer (or ONNX session)
# -------------------------
@app.on_event("startup")
def startup_load():
    global tok, model, device, PRE_IDS, POST_IDS, ort_session, USE_ONNX
    device = "cpu"
    logger.info("Startup: USE_ONNX=%s", USE_ONNX)
    if USE_ONNX:
        logger.info("Loading ONNX Runtime session from %s", MODEL_ONNX_PATH)
        # create session with bounded thread counts to limit CPU usage
        sess_opts = ort.SessionOptions()
        sess_opts.intra_op_num_threads = 2
        sess_opts.inter_op_num_threads = 1
        # default providers; on HF Spaces CPU hardware this resolves to CPUExecutionProvider
        ort_session = ort.InferenceSession(str(MODEL_ONNX_PATH), sess_opts)
        logger.info("ONNX session ready.")
        # Note: the ONNX path assumes the exported graph accepts the same input names used at export.
        # The tokenizer is still needed to build input ids and decode outputs, so load it here as well.
        from transformers import AutoTokenizer
        tok = AutoTokenizer.from_pretrained(
            "apple/FastVLM-0.5B",
            trust_remote_code=True,
            cache_dir=os.environ.get("TRANSFORMERS_CACHE"),
        )
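        # Optional sanity check (a minimal sketch): log the exported graph's input/output
        # names so a mismatch with the feed dict built in run_onnx_inference
        # ("input_ids", "attention_mask", "pixel_values") shows up at startup
        # rather than as per-frame failures.
        try:
            logger.info(
                "ONNX graph inputs=%s outputs=%s",
                [i.name for i in ort_session.get_inputs()],
                [o.name for o in ort_session.get_outputs()],
            )
        except Exception:
            pass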
    else:
        # PyTorch fallback: load tokenizer and model (transformers)
        if not TORCH_AVAILABLE:
            raise RuntimeError("Torch not installed and ONNX not available. Install torch or provide model.onnx.")
        from transformers import AutoTokenizer, AutoModelForCausalLM
        MODEL_ID = "apple/FastVLM-0.5B"
        logger.info("Loading tokenizer and PyTorch model (%s) — this may take a while...", MODEL_ID)
        # load tokenizer/model to cache dir set by env
        tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, cache_dir=os.environ.get("TRANSFORMERS_CACHE"))
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True, cache_dir=os.environ.get("TRANSFORMERS_CACHE"))
        model.to(device)
        model.eval()
        logger.info("PyTorch model loaded.")
    # Precompute tokenized prompt tensors (to avoid tokenizing per-frame)
    try:
        messages = [{"role": "user", "content": DEFAULT_PROMPT}]
        try:
            rendered = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        except Exception:
            rendered = f"{DEFAULT_PROMPT}\n<image>\n"
        if "<image>" in rendered:
            pre, post = rendered.split("<image>", 1)
        else:
            pre, post = rendered, ""
        PRE_IDS = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
        POST_IDS = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
        logger.info("Cached token tensors PRE_IDS/POST_IDS")
    except Exception as e:
        logger.warning("Could not cache PRE_IDS/POST_IDS: %s", e)
        PRE_IDS = None
        POST_IDS = None
# -------------------------
# Helpers
# -------------------------
def prepare_image_tensor_for_torch(pil_image: Image.Image):
"""
Basic fallback: convert to 224x224 -> CHW tensor normalized 0-1
If model has a processor, we used it earlier when available (but for generality keep this).
"""
pil_small = pil_image.convert("RGB").resize((224, 224))
arr = np.array(pil_small).astype(np.float32) / 255.0
px = torch.tensor(arr).permute(2, 0, 1).unsqueeze(0)
px = px.to(device=device, dtype=next(model.parameters()).dtype)
return px
async def run_torch_inference(pil_image: Image.Image, prompt: Optional[str] = None, timeout_s: float = 10.0) -> str:
"""
Runs PyTorch inference in thread with timeout. Uses cached PRE_IDS/POST_IDS when available.
"""
def _sync():
start = time.time()
nonlocal prompt
if prompt is None:
prompt = DEFAULT_PROMPT
try:
# try using tokenizer's chat template
try:
rendered = tok.apply_chat_template([{"role":"user","content":prompt}], add_generation_prompt=True, tokenize=False)
except Exception:
rendered = f"{prompt}\n<image>\n"
if "<image>" in rendered:
pre, post = rendered.split("<image>", 1)
else:
pre, post = rendered, ""
# reuse cached token tensors if present
try:
pre_ids = PRE_IDS.clone() if PRE_IDS is not None else tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
post_ids = POST_IDS.clone() if POST_IDS is not None else tok(post, return_tensors="pt", add_special_tokens=False).input_ids
except Exception:
pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
img_tok = torch.tensor([[ -200 ]], dtype=pre_ids.dtype).to(pre_ids.device)
input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(device)
attention_mask = torch.ones_like(input_ids, device=device)
# prepare pixels
try:
# if model provides helper, use it (remote code may expose it); else fallback
px = model.get_vision_tower().image_processor(images=pil_image, return_tensors="pt")["pixel_values"]
px = px.to(device=device, dtype=next(model.parameters()).dtype)
except Exception:
px = prepare_image_tensor_for_torch(pil_image)
# inference: use inference_mode and bounds
try:
with torch.inference_mode():
out = model.generate(
inputs=input_ids,
attention_mask=attention_mask,
images=px,
max_new_tokens=32,
do_sample=False,
early_stopping=True,
max_time=6.0, # bound generation
num_beams=1,
)
except TypeError:
with torch.inference_mode():
out = model.generate(
inputs=input_ids,
attention_mask=attention_mask,
images=px,
max_new_tokens=32,
do_sample=False,
early_stopping=True,
num_beams=1,
)
try:
decoded = tok.decode(out[0], skip_special_tokens=True)
except Exception:
decoded = str(out[0].tolist())
elapsed = time.time() - start
logger.info("torch inference done (%.2fs) -> %s", elapsed, decoded[:120])
return decoded
except Exception as e:
logger.exception("torch inference exception: %s", e)
return "Error: inference failed."
# run in separate thread but with timeout to avoid hanging long requests
try:
return await asyncio.wait_for(asyncio.to_thread(_sync), timeout=timeout_s + 2.0)
except asyncio.TimeoutError:
logger.warning("torch inference timed out after %ss", timeout_s)
return "Error: inference timed out."
async def run_onnx_inference(pil_image: Image.Image, prompt: Optional[str] = None, timeout_s: float = 6.0) -> str:
"""
Runs ONNX inference. This assumes the ONNX export accepts:
- input_ids (int64)
- attention_mask (int64)
- pixel_values (float32)
and returns logits / output token ids in a known output name.
NOTE: You must export ONNX accordingly. This function is a best-effort adapter.
"""
if ort_session is None:
return await run_torch_inference(pil_image, prompt=prompt, timeout_s=timeout_s)
# For ONNX we still need tokenizer to produce input token ids and decode outputs
if prompt is None:
prompt = DEFAULT_PROMPT
try:
try:
rendered = tok.apply_chat_template([{"role":"user","content":prompt}], add_generation_prompt=True, tokenize=False)
except Exception:
rendered = f"{prompt}\n<image>\n"
if "<image>" in rendered:
pre, post = rendered.split("<image>", 1)
else:
pre, post = rendered, ""
# tokenize (small cost; PRE_IDS might be reusable but ONNX export must match tokens mapping)
pre_ids = tok(pre, return_tensors="np", add_special_tokens=False).input_ids.astype("int64")
post_ids = tok(post, return_tensors="np", add_special_tokens=False).input_ids.astype("int64")
# construct input sequence: concat pre + image token (-200) + post
img_token = np.array([[-200]], dtype="int64")
input_ids = np.concatenate([pre_ids, img_token, post_ids], axis=1)
attention_mask = np.ones_like(input_ids, dtype="int64")
# pixels - use tokenizer/processor if available else fallback
try:
px = model.get_vision_tower().image_processor(images=pil_image, return_tensors="np")["pixel_values"]
except Exception:
pil_small = pil_image.convert("RGB").resize((224, 224))
arr = np.array(pil_small).astype(np.float32) / 255.0
px = arr.transpose(2, 0, 1)[None, :, :, :].astype(np.float32)
# prepare feeds: names depend on export; here's a common naming convention
feeds = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"pixel_values": px,
}
# run inference synchronously in thread with timeout
def _sync_run():
start = time.time()
outs = ort_session.run(None, feeds)
elapsed = time.time() - start
logger.info("onnx runtime inference done (%.2fs)", elapsed)
# decode - this depends how you exported; if outputs are token ids:
try:
# many exports output logits; handle simple case where output is ids array
out_ids = outs[0]
# ensure 1D list of ints
if isinstance(out_ids, (list, tuple)):
out_ids = out_ids[0]
# convert numpy -> python list, then decode
decoded = tok.decode(out_ids[0], skip_special_tokens=True)
return decoded
except Exception:
# fallback string
return "OK (onnx) - result"
try:
return await asyncio.wait_for(asyncio.to_thread(_sync_run), timeout=timeout_s + 1.0)
except asyncio.TimeoutError:
logger.warning("onnx inference timed out")
return "Error: onnx inference timed out."
except Exception as e:
logger.exception("onnx path error: %s", e)
return "Error: onnx inference failed."
# -------------------------
# Endpoints: ping & upload
# -------------------------
@app.get("/ping")
async def ping():
return {"status": "ok", "onnx": bool(USE_ONNX)}
@app.post("/upload")
async def upload(file: UploadFile = File(...)):
    content = await file.read()
    try:
        img = Image.open(io.BytesIO(content)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Uploaded file is not a supported image")
    # route to ONNX or torch accordingly
    if USE_ONNX:
        caption = await run_onnx_inference(img, timeout_s=20.0)
    else:
        caption = await run_torch_inference(img, timeout_s=20.0)
    return JSONResponse({"filename": file.filename, "caption": caption})
# -------------------------
# WebSocket streaming (single inference at a time; drop frames while busy)
# -------------------------
inference_lock = asyncio.Lock()
@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
    client_ip = None
    try:
        await ws.accept()
        try:
            client_ip = ws.client[0] if ws.client else None
        except Exception:
            client_ip = None
        # optional: dedupe per client ip (closes previous conn)
        if client_ip:
            prev = active_ws_by_ip.get(client_ip)
            if prev is not None and prev is not ws:
                try:
                    await prev.close()
                except Exception:
                    pass
            active_ws_by_ip[client_ip] = ws
        await ws.send_json({"type": "connected", "msg": "Send frames as {type:'frame', data: dataURL} (send only when previous resolved)"})
        while True:
            try:
                data = await ws.receive_json()
            except Exception:
                break
            # allow keepalive
            if isinstance(data, dict) and data.get("type") == "ping":
                try:
                    await ws.send_json({"type": "pong"})
                except Exception:
                    pass
                continue
            if not isinstance(data, dict) or data.get("type") != "frame":
                await ws.send_json({"type": "error", "msg": "invalid message type"})
                continue
            # drop if busy
            if inference_lock.locked():
                try:
                    await ws.send_json({"type": "skipped", "reason": "busy"})
                except Exception:
                    pass
                continue
            b64 = data.get("data", "")
            if not isinstance(b64, str):
                await ws.send_json({"type": "error", "msg": "missing frame data"})
                continue
            if b64.startswith("data:image"):
                b64 = b64.split(",", 1)[1]
            try:
                img_bytes = base64.b64decode(b64)
                pil_img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
            except Exception:
                await ws.send_json({"type": "error", "msg": "invalid image data"})
                continue
            # run inference under lock so only one runs at a time
            async with inference_lock:
                start = time.time()
                if USE_ONNX:
                    caption = await run_onnx_inference(pil_img, timeout_s=20.0)
                else:
                    caption = await run_torch_inference(pil_img, timeout_s=20.0)
                elapsed = time.time() - start
            # send caption (or error text)
            try:
                await ws.send_json({"type": "caption", "text": caption, "t": round(elapsed, 2)})
            except Exception:
                break
    except Exception as e:
        logger.exception("ws handler error: %s", e)
    finally:
        try:
            if client_ip and active_ws_by_ip.get(client_ip) is ws:
                del active_ws_by_ip[client_ip]
        except Exception:
            pass
        try:
            await ws.close()
        except Exception:
            pass
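# Minimal client sketch for the /ws protocol above (assumes the third-party `websockets`
# package and a server on localhost:7860; send one frame, wait for the reply, then send
# the next, since the server answers "skipped" while an inference is still running):
#
#   import asyncio, base64, json
#   import websockets
#
#   async def caption_once(path: str) -> None:
#       async with websockets.connect("ws://localhost:7860/ws") as ws:
#           print(await ws.recv())  # "connected" greeting
#           data_url = "data:image/jpeg;base64," + base64.b64encode(open(path, "rb").read()).decode()
#           await ws.send(json.dumps({"type": "frame", "data": data_url}))
#           print(await ws.recv())  # caption / skipped / error
#
#   asyncio.run(caption_once("photo.jpg"))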