# hf_backend.py
import time, logging
from typing import Any, Dict, AsyncIterable
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from backends_base import ChatBackend, ImagesBackend
from config import settings
logger = logging.getLogger(__name__)
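# Optional Hugging Face Spaces / ZeroGPU support; both names fall back to None
# when the `spaces` package is not installed.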
try:
    import spaces
    from spaces.zero import client as zero_client
except ImportError:
    spaces, zero_client = None, None

# --- Model setup ---
MODEL_ID = settings.LlmHFModelID or "Qwen/Qwen2.5-1.5B-Instruct"
logger.info(f"Preloading tokenizer for {MODEL_ID} on CPU...")
tokenizer, load_error = None, None
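# Load the tokenizer eagerly; on failure, record the error and surface it per request
# instead of crashing at import time.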
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        use_fast=False,
    )
except Exception as e:
    load_error = f"Failed to load tokenizer: {e}"
    logger.exception(load_error)

# ---------------- helpers ----------------
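# Prefer BF16 on CPUs that advertise support for it; otherwise fall back to FP32.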
def _pick_cpu_dtype() -> torch.dtype:
    if hasattr(torch, "cpu") and hasattr(torch.cpu, "is_bf16_supported"):
        try:
            if torch.cpu.is_bf16_supported():
                logger.info("CPU BF16 supported, will attempt torch.bfloat16")
                return torch.bfloat16
        except Exception:
            pass
    logger.info("Falling back to torch.float32 on CPU")
    return torch.float32

# ---------------- global cache ----------------
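# Loaded models are cached per (device, dtype) so repeated requests reuse the same weights.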
_MODEL_CACHE: dict[tuple[str, torch.dtype], AutoModelForCausalLM] = {}
def _get_model(device: str, dtype: torch.dtype):
    key = (device, dtype)
    if key in _MODEL_CACHE:
        return _MODEL_CACHE[key]
    cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
    if hasattr(cfg, "quantization_config"):
        logger.warning("Removing quantization_config from model config")
        delattr(cfg, "quantization_config")
    try:
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            config=cfg,
            torch_dtype=dtype,
            trust_remote_code=True,
            device_map="auto" if device != "cpu" else {"": "cpu"},
        )
    except Exception as e:
        if device == "cpu" and dtype == torch.bfloat16:
            logger.warning(f"BF16 load failed on CPU: {e}. Retrying with FP32.")
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                config=cfg,
                torch_dtype=torch.float32,
                trust_remote_code=True,
                device_map={"": "cpu"},
            )
            dtype = torch.float32
        else:
            raise
    model.eval()
    # Cache under the requested key as well, so a BF16 request that fell back to FP32
    # does not reload the model on every subsequent call.
    _MODEL_CACHE[key] = model
    _MODEL_CACHE[(device, dtype)] = model
    return model

# ---------------- Chat Backend ----------------
class HFChatBackend(ChatBackend):
    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        if load_error:
            raise RuntimeError(load_error)
        messages = request.get("messages", [])
        temperature = float(request.get("temperature", settings.LlmTemp or 0.7))
        max_tokens = int(request.get("max_tokens", settings.LlmOpenAICtxSize or 512))
        rid = f"chatcmpl-hf-{int(time.time())}"
        now = int(time.time())
        # --- Inject X-IP-Token into global headers if ZeroGPU is used ---
        x_ip_token = request.get("x_ip_token")
        if x_ip_token and zero_client:
            zero_client.HEADERS["X-IP-Token"] = x_ip_token
            logger.debug("Injected X-IP-Token into ZeroGPU headers")
        # Build prompt using chat template if available
        if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
            try:
                prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
                logger.debug("Applied chat template for prompt")
            except Exception as e:
                logger.warning(f"Failed to apply chat template: {e}, using fallback")
                prompt = messages[-1]["content"] if messages else "(empty)"
        else:
            prompt = messages[-1]["content"] if messages else "(empty)"

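        # Single-shot generation on the requested device/dtype; defined inline so it can
        # close over max_tokens and temperature from this request.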
        def _run_once(prompt: str, device: str, dtype: torch.dtype) -> str:
            model = _get_model(device, dtype)
            inputs = tokenizer(prompt, return_tensors="pt").to(device)
            with torch.inference_mode():
                if device != "cpu":
                    autocast_ctx = torch.autocast(device_type=device, dtype=dtype)
                else:
                    # torch.autocast also provides the CPU autocast context
                    # (torch.cpu.amp.autocast is deprecated).
                    autocast_ctx = torch.autocast(device_type="cpu", dtype=dtype)
                with autocast_ctx:
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=max_tokens,
                        temperature=temperature,
                        do_sample=True,
                    )
            return tokenizer.decode(outputs[0], skip_special_tokens=True)

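        # On Spaces, run generation inside a ZeroGPU allocation (up to `duration` seconds
        # per call); otherwise run locally on CPU with the detected dtype.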
        if spaces:
            # --- GPU path with ZeroGPU ---
            @spaces.GPU(duration=120)
            def run_once(prompt: str) -> str:
                return _run_once(prompt, device="cuda", dtype=torch.float16)

            text = run_once(prompt)
        else:
            # --- CPU-only fallback with auto dtype detection ---
            dtype = _pick_cpu_dtype()
            text = _run_once(prompt, device="cpu", dtype=dtype)

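        # Non-streaming generation: emit the whole completion as a single OpenAI-style chunk.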
        yield {
            "id": rid,
            "object": "chat.completion.chunk",
            "created": now,
            "model": MODEL_ID,
            "choices": [
                {"index": 0, "delta": {"content": text}, "finish_reason": "stop"}
            ],
        }

# ---------------- Stub Images Backend ----------------
class StubImagesBackend(ImagesBackend):
    """
    Stub backend for images since HFChatBackend is text-only.
    Returns a transparent 1x1 PNG placeholder.
    """

    async def generate_b64(self, request: Dict[str, Any]) -> str:
        logger.warning("Image generation not supported in HF backend.")
        return (
            "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
        )
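
# ---------------- Local smoke test ----------------
# Minimal usage sketch, not part of the backend API: it assumes HFChatBackend can be
# constructed with no arguments and exercises the async generator defined above.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        backend = HFChatBackend()
        request = {"messages": [{"role": "user", "content": "Say hello in one sentence."}]}
        async for chunk in backend.stream(request):
            print(chunk["choices"][0]["delta"]["content"])

    asyncio.run(_demo())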