# hf_backend.py
import time, logging, contextlib
from typing import Any, Dict, AsyncIterable
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from backends_base import ChatBackend, ImagesBackend
from config import settings

logger = logging.getLogger(__name__)

# ZeroGPU support is optional; fall back to plain CPU when `spaces` is unavailable.
try:
    import spaces
    from spaces.zero import client as zero_client
except ImportError:
    spaces, zero_client = None, None
# --- Model setup ---
MODEL_ID = settings.LlmHFModelID or "Qwen/Qwen2.5-1.5B-Instruct"
logger.info(f"Preloading tokenizer for {MODEL_ID} on CPU...")
tokenizer, load_error = None, None
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        use_fast=False,
    )
except Exception as e:
    load_error = f"Failed to load tokenizer: {e}"
    logger.exception(load_error)
# ---------------- Chat Backend ----------------
class HFChatBackend(ChatBackend):
    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        if load_error:
            raise RuntimeError(load_error)
        messages = request.get("messages", [])
        prompt = messages[-1]["content"] if messages else "(empty)"
        temperature = float(request.get("temperature", settings.LlmTemp or 0.7))
        max_tokens = int(request.get("max_tokens", settings.LlmOpenAICtxSize or 512))
        rid = f"chatcmpl-hf-{int(time.time())}"
        now = int(time.time())
        # --- Inject X-IP-Token into global headers if ZeroGPU is used ---
        x_ip_token = request.get("x_ip_token")
        if x_ip_token and zero_client:
            zero_client.HEADERS["X-IP-Token"] = x_ip_token
            logger.debug("Injected X-IP-Token into ZeroGPU headers")
        def _run_once(prompt: str, device: str, dtype: torch.dtype) -> str:
            # Load the model inside the call so the CUDA work happens within the
            # (optionally @spaces.GPU-decorated) function on ZeroGPU.
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                torch_dtype=dtype,
                trust_remote_code=True,
                device_map="auto" if device != "cpu" else {"": "cpu"},
            )
            model.eval()
            inputs = tokenizer(prompt, return_tensors="pt").to(device)
            with torch.inference_mode():
                if device != "cpu":
                    autocast_ctx = torch.autocast(device_type=device, dtype=dtype)
                else:
                    # fp32 CPU inference gains nothing from autocast; use a no-op context.
                    autocast_ctx = contextlib.nullcontext()
                with autocast_ctx:
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=max_tokens,
                        temperature=temperature,
                        do_sample=True,
                    )
            return tokenizer.decode(outputs[0], skip_special_tokens=True)
        if spaces:
            # --- GPU path with ZeroGPU ---
            @spaces.GPU(duration=120)
            def run_once(prompt: str) -> str:
                return _run_once(prompt, device="cuda", dtype=torch.float16)
            text = run_once(prompt)
        else:
            # --- CPU-only fallback ---
            text = _run_once(prompt, device="cpu", dtype=torch.float32)
        yield {
            "id": rid,
            "object": "chat.completion.chunk",
            "created": now,
            "model": MODEL_ID,
            "choices": [
                {"index": 0, "delta": {"content": text}, "finish_reason": "stop"}
            ],
        }
# ---------------- Stub Images Backend ----------------
class StubImagesBackend(ImagesBackend):
    """
    Stub backend for images since HFChatBackend is text-only.
    Returns a transparent 1x1 PNG placeholder.
    """
    async def generate_b64(self, request: Dict[str, Any]) -> str:
        logger.warning("Image generation not supported in HF backend.")
        return (
            "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
        )
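# --- Illustrative usage (assumption, not part of the original backend wiring) ---
# A minimal sketch of how HFChatBackend.stream could be driven locally, assuming
# the request dict mirrors the OpenAI-style chat payload handled above. The
# asyncio driver below is only an example harness, not an API this module exposes.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        backend = HFChatBackend()
        request = {
            "messages": [{"role": "user", "content": "Hello!"}],
            "temperature": 0.7,
            "max_tokens": 64,
        }
        async for chunk in backend.stream(request):
            print(chunk["choices"][0]["delta"]["content"])

    asyncio.run(_demo())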