import os
from pathlib import Path
import io
import base64
import asyncio
import logging
import time
from typing import Optional

# Cache and thread-count environment variables must be set before the
# transformers/huggingface imports below read them.
CACHE_ROOT = "/app/.cache"
os.environ.setdefault("XDG_CACHE_HOME", CACHE_ROOT)
os.environ.setdefault("HF_HOME", f"{CACHE_ROOT}/hf")
os.environ.setdefault("TRANSFORMERS_CACHE", f"{CACHE_ROOT}/transformers")
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", f"{CACHE_ROOT}/hf")
os.environ.setdefault("HF_DATASETS_CACHE", f"{CACHE_ROOT}/datasets")

os.environ.setdefault("OMP_NUM_THREADS", "2")
os.environ.setdefault("MKL_NUM_THREADS", "2")

Path(CACHE_ROOT).mkdir(parents=True, exist_ok=True)
try:
    os.chmod(CACHE_ROOT, 0o777)
except Exception:
    pass

from fastapi import FastAPI, WebSocket, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
import numpy as np

# Optional ONNX Runtime import; USE_ONNX is finalized below once we know
# whether model.onnx exists on disk.
USE_ONNX = False
onnxruntime = None
try:
    import onnxruntime as ort
    onnxruntime = ort
except Exception:
    ort = None

try:
    import torch
    TORCH_AVAILABLE = True
    try:
        torch.set_num_threads(2)
    except Exception:
        pass
except Exception:
    TORCH_AVAILABLE = False

logger = logging.getLogger("uvicorn.error")
app = FastAPI()

app.mount("/static", StaticFiles(directory="static"), name="static")


@app.get("/", response_class=FileResponse)
async def index():
    p = Path("static/index.html")
    if p.exists():
        return FileResponse(p)
    return JSONResponse({"error": "static/index.html not found"}, status_code=404)

MODEL_ONNX_PATH = Path("model.onnx")
USE_ONNX = MODEL_ONNX_PATH.exists() and (onnxruntime is not None)

MODEL_ID = "apple/FastVLM-0.5B"

tok = None
model = None
device = None
PRE_IDS = None
POST_IDS = None
DEFAULT_PROMPT = "<image>\nDescribe this image in one complete sentence, nothing further than that."

ort_session = None

# At most one active WebSocket per client IP; a new connection evicts the
# previous one (see websocket_endpoint below).
active_ws_by_ip = {}


@app.on_event("startup")
def startup_load():
    global tok, model, device, PRE_IDS, POST_IDS, ort_session, USE_ONNX
    device = "cpu"
    logger.info("Startup: USE_ONNX=%s", USE_ONNX)

    # The tokenizer is needed on both paths: run_onnx_inference also
    # tokenizes the prompt and decodes the output ids, so load it up front.
    from transformers import AutoTokenizer
    tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, cache_dir=os.environ.get("TRANSFORMERS_CACHE"))

    if USE_ONNX:
        logger.info("Loading ONNX runtime session from %s", MODEL_ONNX_PATH)
        sess_opts = ort.SessionOptions()
        sess_opts.intra_op_num_threads = 2
        sess_opts.inter_op_num_threads = 1
        ort_session = ort.InferenceSession(str(MODEL_ONNX_PATH), sess_opts)
        logger.info("ONNX session ready.")
    else:
        if not TORCH_AVAILABLE:
            raise RuntimeError("Torch not installed and ONNX not available. Install torch or provide model.onnx.")
        from transformers import AutoModelForCausalLM

        logger.info("Loading PyTorch model (%s); this may take a while...", MODEL_ID)
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True, cache_dir=os.environ.get("TRANSFORMERS_CACHE"))
        model.to(device)
        model.eval()
        logger.info("PyTorch model loaded.")

    # Pre-tokenize the static parts of the default prompt so steady-state
    # inference can skip tokenization.
    try:
        messages = [{"role": "user", "content": DEFAULT_PROMPT}]
        try:
            rendered = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        except Exception:
            # DEFAULT_PROMPT already contains the <image> placeholder, so do
            # not append a second one.
            rendered = DEFAULT_PROMPT
        if "<image>" in rendered:
            pre, post = rendered.split("<image>", 1)
        else:
            pre, post = rendered, ""
        PRE_IDS = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
        POST_IDS = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
        logger.info("Cached token tensors PRE_IDS/POST_IDS")
    except Exception as e:
        logger.warning("Could not cache PRE_IDS/POST_IDS: %s", e)
        PRE_IDS = None
        POST_IDS = None


def prepare_image_tensor_for_torch(pil_image: Image.Image):
    """
    Basic fallback: resize to 224x224 and convert to a 1x3x224x224 float
    tensor scaled to [0, 1]. The model's own image processor is preferred
    when available; this keeps the torch path usable without it.
    """
    pil_small = pil_image.convert("RGB").resize((224, 224))
    arr = np.array(pil_small).astype(np.float32) / 255.0
    px = torch.tensor(arr).permute(2, 0, 1).unsqueeze(0)
    px = px.to(device=device, dtype=next(model.parameters()).dtype)
    return px
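
# Note: the fallback above only rescales to [0, 1]. CLIP-style image
# processors usually also apply per-channel mean/std normalization, so
# captions produced via the fallback path may be noticeably worse than with
# the model's own processor.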


async def run_torch_inference(pil_image: Image.Image, prompt: Optional[str] = None, timeout_s: float = 10.0) -> str:
    """
    Runs PyTorch inference in a worker thread with a timeout. Uses the
    cached PRE_IDS/POST_IDS tensors when available.
    """
    if prompt is None:
        prompt = DEFAULT_PROMPT

    def _sync():
        start = time.time()
        try:
            try:
                rendered = tok.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True, tokenize=False)
            except Exception:
                # Avoid inserting a second <image> placeholder if the prompt
                # already carries one.
                rendered = prompt if "<image>" in prompt else f"{prompt}\n<image>\n"
            if "<image>" in rendered:
                pre, post = rendered.split("<image>", 1)
            else:
                pre, post = rendered, ""

            try:
                pre_ids = PRE_IDS.clone() if PRE_IDS is not None else tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
                post_ids = POST_IDS.clone() if POST_IDS is not None else tok(post, return_tensors="pt", add_special_tokens=False).input_ids
            except Exception:
                pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
                post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids

            # -200 is the LLaVA-style IMAGE_TOKEN_INDEX sentinel that the
            # model replaces with vision features during generate().
            img_tok = torch.tensor([[-200]], dtype=pre_ids.dtype).to(pre_ids.device)
            input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(device)
            attention_mask = torch.ones_like(input_ids, device=device)

            # Prefer the model's own image processor; fall back to the basic
            # resize-and-scale helper above.
            try:
                px = model.get_vision_tower().image_processor(images=pil_image, return_tensors="pt")["pixel_values"]
                px = px.to(device=device, dtype=next(model.parameters()).dtype)
            except Exception:
                px = prepare_image_tensor_for_torch(pil_image)

            # Retry without max_time for transformers versions whose
            # generate() does not accept it.
            try:
                with torch.inference_mode():
                    out = model.generate(
                        inputs=input_ids,
                        attention_mask=attention_mask,
                        images=px,
                        max_new_tokens=32,
                        do_sample=False,
                        max_time=6.0,
                        num_beams=1,
                    )
            except TypeError:
                with torch.inference_mode():
                    out = model.generate(
                        inputs=input_ids,
                        attention_mask=attention_mask,
                        images=px,
                        max_new_tokens=32,
                        do_sample=False,
                        num_beams=1,
                    )
            try:
                decoded = tok.decode(out[0], skip_special_tokens=True)
            except Exception:
                decoded = str(out[0].tolist())
            elapsed = time.time() - start
            logger.info("torch inference done (%.2fs) -> %s", elapsed, decoded[:120])
            return decoded
        except Exception as e:
            logger.exception("torch inference exception: %s", e)
            return "Error: inference failed."

    try:
        return await asyncio.wait_for(asyncio.to_thread(_sync), timeout=timeout_s + 2.0)
    except asyncio.TimeoutError:
        logger.warning("torch inference timed out after %ss", timeout_s)
        return "Error: inference timed out."


async def run_onnx_inference(pil_image: Image.Image, prompt: Optional[str] = None, timeout_s: float = 6.0) -> str:
    """
    Runs ONNX inference. This assumes the ONNX export accepts:
      - input_ids (int64)
      - attention_mask (int64)
      - pixel_values (float32)
    and returns the generated token ids (or logits) in a known output.
    NOTE: the model must be exported accordingly; this function is a
    best-effort adapter.
    """
    if ort_session is None:
        return await run_torch_inference(pil_image, prompt=prompt, timeout_s=timeout_s)

    if prompt is None:
        prompt = DEFAULT_PROMPT
    try:
        try:
            rendered = tok.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True, tokenize=False)
        except Exception:
            # Avoid inserting a second <image> placeholder if the prompt
            # already carries one.
            rendered = prompt if "<image>" in prompt else f"{prompt}\n<image>\n"
        if "<image>" in rendered:
            pre, post = rendered.split("<image>", 1)
        else:
            pre, post = rendered, ""

        pre_ids = tok(pre, return_tensors="np", add_special_tokens=False).input_ids.astype("int64")
        post_ids = tok(post, return_tensors="np", add_special_tokens=False).input_ids.astype("int64")

        # Same -200 image-token sentinel as the torch path.
        img_token = np.array([[-200]], dtype="int64")
        input_ids = np.concatenate([pre_ids, img_token, post_ids], axis=1)
        attention_mask = np.ones_like(input_ids, dtype="int64")

        # The PyTorch model (and its image processor) is normally not loaded
        # on the ONNX path, so this usually falls through to the plain
        # resize fallback.
        try:
            px = model.get_vision_tower().image_processor(images=pil_image, return_tensors="np")["pixel_values"]
        except Exception:
            pil_small = pil_image.convert("RGB").resize((224, 224))
            arr = np.array(pil_small).astype(np.float32) / 255.0
            px = arr.transpose(2, 0, 1)[None, :, :, :].astype(np.float32)

        # Feed names must match the exported graph's input names.
        feeds = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pixel_values": px,
        }

        def _sync_run():
            start = time.time()
            outs = ort_session.run(None, feeds)
            elapsed = time.time() - start
            logger.info("onnx runtime inference done (%.2fs)", elapsed)

            try:
                # Assume the first output holds the generated token ids.
                out_ids = outs[0]
                if isinstance(out_ids, (list, tuple)):
                    out_ids = out_ids[0]
                decoded = tok.decode(out_ids[0], skip_special_tokens=True)
                return decoded
            except Exception:
                # Output layout did not match expectations; return a stub.
                return "OK (onnx) - result"

        try:
            return await asyncio.wait_for(asyncio.to_thread(_sync_run), timeout=timeout_s + 1.0)
        except asyncio.TimeoutError:
            logger.warning("onnx inference timed out")
            return "Error: onnx inference timed out."
    except Exception as e:
        logger.exception("onnx path error: %s", e)
        return "Error: onnx inference failed."


@app.get("/ping")
async def ping():
    return {"status": "ok", "onnx": bool(USE_ONNX)}


@app.post("/upload")
async def upload(file: UploadFile = File(...)):
    content = await file.read()
    try:
        img = Image.open(io.BytesIO(content)).convert("RGB")
    except Exception:
        raise HTTPException(status_code=400, detail="Uploaded file not a supported image")

    if USE_ONNX:
        caption = await run_onnx_inference(img, timeout_s=20.0)
    else:
        caption = await run_torch_inference(img, timeout_s=20.0)
    return JSONResponse({"filename": file.filename, "caption": caption})


# Single-flight guard: inference is CPU-bound, so overlapping requests would
# only slow each other down; frames that arrive while busy are skipped.
inference_lock = asyncio.Lock()


@app.websocket("/ws")
async def websocket_endpoint(ws: WebSocket):
    client_ip = None
    try:
        await ws.accept()
        try:
            client_ip = ws.client[0] if ws.client else None
        except Exception:
            client_ip = None

        # Enforce one live connection per client IP by evicting the previous
        # socket, if any.
        if client_ip:
            prev = active_ws_by_ip.get(client_ip)
            if prev is not None and prev is not ws:
                try:
                    await prev.close()
                except Exception:
                    pass
            active_ws_by_ip[client_ip] = ws

        await ws.send_json({"type": "connected", "msg": "Send frames as {type:'frame', data: dataURL} (send only when previous resolved)"})
        while True:
            try:
                data = await ws.receive_json()
            except Exception:
                break

            if isinstance(data, dict) and data.get("type") == "ping":
                try:
                    await ws.send_json({"type": "pong"})
                except Exception:
                    pass
                continue

            if not isinstance(data, dict) or data.get("type") != "frame":
                await ws.send_json({"type": "error", "msg": "invalid message type"})
                continue

            # Drop the frame instead of queueing when inference is busy.
            if inference_lock.locked():
                try:
                    await ws.send_json({"type": "skipped", "reason": "busy"})
                except Exception:
                    pass
                continue

            b64 = data.get("data", "")
            if not isinstance(b64, str):
                await ws.send_json({"type": "error", "msg": "missing frame data"})
                continue
            if b64.startswith("data:image"):
                b64 = b64.split(",", 1)[1]
            try:
                img_bytes = base64.b64decode(b64)
                pil_img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
            except Exception:
                await ws.send_json({"type": "error", "msg": "invalid image data"})
                continue

            async with inference_lock:
                start = time.time()
                if USE_ONNX:
                    caption = await run_onnx_inference(pil_img, timeout_s=20.0)
                else:
                    caption = await run_torch_inference(pil_img, timeout_s=20.0)
                elapsed = time.time() - start

            try:
                await ws.send_json({"type": "caption", "text": caption, "t": round(elapsed, 2)})
            except Exception:
                break

    except Exception as e:
        logger.exception("ws handler error: %s", e)
    finally:
        try:
            if client_ip and active_ws_by_ip.get(client_ip) is ws:
                del active_ws_by_ip[client_ip]
        except Exception:
            pass
        try:
            await ws.close()
        except Exception:
            pass
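

# Minimal local entry point: a sketch, not necessarily how this app is
# deployed; host and port are assumptions.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)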