# hf_backend.py
import logging
import os
import time
from typing import Any, AsyncIterable, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from backends_base import ChatBackend, ImagesBackend
from config import settings

try:
    import spaces  # Hugging Face Spaces ZeroGPU helper; optional outside Spaces
except ImportError:
    spaces = None

logger = logging.getLogger(__name__)

# --- Load model/tokenizer on CPU at import time (ZeroGPU safe) ---
MODEL_ID = settings.LlmHFModelID or "Qwen/Qwen2.5-1.5B-Instruct"
logger.info(f"Loading {MODEL_ID} on CPU at startup (ZeroGPU safe)...")

tokenizer = None
model = None
load_error = None

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,  # CPU-safe default; re-cast later if a GPU is available
        trust_remote_code=True,
    )
    model.eval()
except Exception as e:
    load_error = f"Failed to load model/tokenizer: {e}"
    logger.exception(load_error)


# --- Device helpers ---
def pick_device() -> str:
    """Pick an inference device, honouring a FORCE_DEVICE override."""
    forced = os.getenv("FORCE_DEVICE", "").lower().strip()
    if forced in {"cpu", "cuda", "mps"}:
        return forced
    if torch.cuda.is_available():
        return "cuda"
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def pick_dtype(device: str) -> torch.dtype:
    """Pick a dtype the chosen device can run natively."""
    if device == "cuda":
        major, _ = torch.cuda.get_device_capability()
        return torch.bfloat16 if major >= 8 else torch.float16  # bf16 needs Ampere or newer
    if device == "mps":
        return torch.float16
    return torch.float32


# --- Backend class ---
class HFChatBackend(ChatBackend):
    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        if load_error:
            raise RuntimeError(load_error)

        # Only the latest message is used as the prompt; no chat template is applied.
        messages = request.get("messages", [])
        prompt = messages[-1]["content"] if messages else "(empty)"
        temperature = float(request.get("temperature", settings.LlmTemp or 0.7))
        max_tokens = int(request.get("max_tokens", settings.LlmOpenAICtxSize or 512))

        rid = f"chatcmpl-hf-{int(time.time())}"
        now = int(time.time())

        if spaces:
            @spaces.GPU(duration=120)  # allow a longer run than the default ZeroGPU window
            def run_once(prompt: str) -> str:
                device = pick_device()
                dtype = pick_dtype(device)
                # Move the globally loaded model onto the allocated GPU for this call.
                model.to(device=device, dtype=dtype).eval()
                inputs = tokenizer(prompt, return_tensors="pt").to(device)
                # Autocast only on accelerators; float32 autocast on CPU is not supported.
                with torch.inference_mode(), torch.autocast(
                    device_type=device, dtype=dtype, enabled=device != "cpu"
                ):
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=max_tokens,
                        temperature=temperature,
                        do_sample=True,
                    )
                # Decode only the newly generated tokens, not the echoed prompt.
                return tokenizer.decode(
                    outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
                )
        else:
            def run_once(prompt: str) -> str:
                inputs = tokenizer(prompt, return_tensors="pt")
                with torch.inference_mode():
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=max_tokens,
                        temperature=temperature,
                        do_sample=True,
                    )
                # Decode only the newly generated tokens, not the echoed prompt.
                return tokenizer.decode(
                    outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True
                )

        try:
            # Generation runs synchronously and blocks the event loop for the duration of
            # the call; the whole completion is emitted as a single OpenAI-style chunk.
            text = run_once(prompt)
            yield {
                "id": rid,
                "object": "chat.completion.chunk",
                "created": now,
                "model": MODEL_ID,
                "choices": [
                    {"index": 0, "delta": {"content": text}, "finish_reason": "stop"}
                ],
            }
        except Exception:
            logger.exception("HF inference failed")
            raise


class StubImagesBackend(ImagesBackend):
    async def generate_b64(self, request: Dict[str, Any]) -> str:
        # This backend does not generate images; return a tiny 1x1 PNG placeholder.
        logger.warning("Image generation not supported in HF backend.")
        return (
            "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
        )
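

# --- Local smoke test (illustrative sketch, not part of the backend API) ---
# A minimal way to exercise HFChatBackend outside the server, assuming the module can be
# run directly and that HFChatBackend takes no constructor arguments (neither is shown in
# this file). The request dict mirrors the keys read in stream() above ("messages",
# "temperature", "max_tokens"); the async generator yields OpenAI-style
# chat.completion.chunk dicts, so we just print each chunk's delta content.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        backend = HFChatBackend()
        request = {
            "messages": [{"role": "user", "content": "Say hello in one sentence."}],
            "temperature": 0.7,
            "max_tokens": 64,
        }
        async for chunk in backend.stream(request):
            print(chunk["choices"][0]["delta"]["content"])

    asyncio.run(_demo())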