import contextlib
import logging
import time
from typing import Any, AsyncIterable, Dict

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from backends_base import ChatBackend, ImagesBackend
from config import settings

logger = logging.getLogger(__name__)
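
# ZeroGPU ("spaces") is only available when running on Hugging Face Spaces; treat it
# as an optional dependency and fall back to local execution when it is missing.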
try:
    import spaces
    from spaces.zero import client as zero_client
except ImportError:
    spaces, zero_client = None, None
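
# The tokenizer is loaded eagerly at import time so configuration problems surface
# early; the model weights are loaded lazily per (device, dtype) in _get_model().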
MODEL_ID = settings.LlmHFModelID or "Qwen/Qwen2.5-1.5B-Instruct"
logger.info(f"Preloading tokenizer for {MODEL_ID} on CPU...")

tokenizer, load_error = None, None
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        use_fast=False,
    )
except Exception as e:
    load_error = f"Failed to load tokenizer: {e}"
    logger.exception(load_error)
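
# Prefer BF16 on CPUs that advertise support for it; otherwise fall back to FP32.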
def _pick_cpu_dtype() -> torch.dtype:
    if hasattr(torch, "cpu") and hasattr(torch.cpu, "is_bf16_supported"):
        try:
            if torch.cpu.is_bf16_supported():
                logger.info("CPU BF16 supported, will attempt torch.bfloat16")
                return torch.bfloat16
        except Exception:
            pass
    logger.info("Falling back to torch.float32 on CPU")
    return torch.float32
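
# Loaded models keyed by (device, dtype) so repeated requests reuse the same weights.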
_MODEL_CACHE: dict[tuple[str, torch.dtype], AutoModelForCausalLM] = {}
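
# Load (or fetch from the cache) the model for the given device/dtype, retrying in
# FP32 when a BF16 load fails on CPU.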
def _get_model(device: str, dtype: torch.dtype):
    key = (device, dtype)
    if key in _MODEL_CACHE:
        return _MODEL_CACHE[key]

    cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
    if hasattr(cfg, "quantization_config"):
        # Strip any quantization_config so loading does not require quantization backends.
        logger.warning("Removing quantization_config from model config")
        delattr(cfg, "quantization_config")

    try:
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            config=cfg,
            torch_dtype=dtype,
            trust_remote_code=True,
            device_map="auto" if device != "cpu" else {"": "cpu"},
        )
    except Exception as e:
        if device == "cpu" and dtype == torch.bfloat16:
            logger.warning(f"BF16 load failed on CPU: {e}. Retrying with FP32.")
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                config=cfg,
                torch_dtype=torch.float32,
                trust_remote_code=True,
                device_map={"": "cpu"},
            )
            dtype = torch.float32
        else:
            raise

    model.eval()
    # Cache under the originally requested key as well as the dtype actually loaded,
    # so a failed BF16 load is not retried on every call.
    _MODEL_CACHE[key] = model
    _MODEL_CACHE[(device, dtype)] = model
    return model
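
# Chat backend that runs a local Hugging Face model. Each request is answered in a
# single pass and emitted as one OpenAI-style "chat.completion.chunk".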
class HFChatBackend(ChatBackend):
    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        if load_error:
            raise RuntimeError(load_error)

        messages = request.get("messages", [])
        temperature = float(request.get("temperature", settings.LlmTemp or 0.7))
        max_tokens = int(request.get("max_tokens", settings.LlmOpenAICtxSize or 512))

        now = int(time.time())
        rid = f"chatcmpl-hf-{now}"

        # Forward the caller's X-IP-Token to the ZeroGPU client so GPU quota is
        # attributed to the end user rather than to the Space itself.
        x_ip_token = request.get("x_ip_token")
        if x_ip_token and zero_client:
            zero_client.HEADERS["X-IP-Token"] = x_ip_token
            logger.debug("Injected X-IP-Token into ZeroGPU headers")

        # Build the prompt with the model's chat template when one is defined,
        # otherwise fall back to the raw content of the last message.
        if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
            try:
                prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
                logger.debug("Applied chat template for prompt")
            except Exception as e:
                logger.warning(f"Failed to apply chat template: {e}, using fallback")
                prompt = messages[-1]["content"] if messages else "(empty)"
        else:
            prompt = messages[-1]["content"] if messages else "(empty)"

        def _run_once(prompt: str, device: str, dtype: torch.dtype) -> str:
            model = _get_model(device, dtype)
            inputs = tokenizer(prompt, return_tensors="pt").to(device)

            with torch.inference_mode():
                if device != "cpu":
                    autocast_ctx = torch.autocast(device_type=device, dtype=dtype)
                elif dtype == torch.bfloat16:
                    # torch.cpu.amp.autocast is deprecated; use the generic torch.autocast API.
                    autocast_ctx = torch.autocast(device_type="cpu", dtype=dtype)
                else:
                    # Plain FP32 on CPU needs no autocast.
                    autocast_ctx = contextlib.nullcontext()

                with autocast_ctx:
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=max_tokens,
                        temperature=temperature,
                        do_sample=True,
                    )

            # Decode only the newly generated tokens, not the echoed prompt.
            generated = outputs[0][inputs["input_ids"].shape[-1]:]
            return tokenizer.decode(generated, skip_special_tokens=True)

        if spaces:
            # On ZeroGPU, the decorated call is executed on an on-demand GPU.
            @spaces.GPU(duration=120)
            def run_once(prompt: str) -> str:
                return _run_once(prompt, device="cuda", dtype=torch.float16)

            text = run_once(prompt)
        else:
            dtype = _pick_cpu_dtype()
            text = _run_once(prompt, device="cpu", dtype=dtype)

        yield {
            "id": rid,
            "object": "chat.completion.chunk",
            "created": now,
            "model": MODEL_ID,
            "choices": [
                {"index": 0, "delta": {"content": text}, "finish_reason": "stop"}
            ],
        }


class StubImagesBackend(ImagesBackend):
    """
    Stub backend for images since HFChatBackend is text-only.
    Returns a transparent 1x1 PNG placeholder.
    """

    async def generate_b64(self, request: Dict[str, Any]) -> str:
        logger.warning("Image generation not supported in HF backend.")
        return (
            "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
        )