import time, logging, json, asyncio

from contextlib import nullcontext
from typing import Any, Dict, AsyncIterable, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

from backends_base import ChatBackend, ImagesBackend
from config import settings

logger = logging.getLogger(__name__)


def _snippet(txt: str, n: int = 800) -> str:
    if not isinstance(txt, str):
        return f"<non-str:{type(txt)}>"
    return txt if len(txt) <= n else txt[:n] + f"... <+{len(txt) - n} chars>"


try:
    import spaces
    from spaces.zero import client as zero_client
except ImportError:
    spaces, zero_client = None, None

MODEL_ID = settings.LlmHFModelID or "Qwen/Qwen2.5-1.5B-Instruct"
logger.info(f"[init] MODEL_ID={MODEL_ID}")

tokenizer, load_error = None, None
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=False)
    has_template = hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None)
    logger.info(f"[init] tokenizer loaded. chat_template={'yes' if has_template else 'no'}")
except Exception as e:
    load_error = f"Failed to load tokenizer: {e}"
    logger.exception(load_error)


def probe_bf16_runtime() -> bool:
    """Check if BF16 is both reported and actually used in ops on CPU."""
    if not (hasattr(torch, "cpu") and hasattr(torch.cpu, "is_bf16_supported")):
        return False
    if not torch.cpu.is_bf16_supported():
        return False
    try:
        a = torch.randn(16, 16, dtype=torch.bfloat16)
        b = torch.randn(16, 16, dtype=torch.bfloat16)
        c = a @ b
        return c.dtype == torch.bfloat16
    except Exception:
        return False


def _pick_cpu_dtype() -> torch.dtype:
    try:
        if probe_bf16_runtime():
            logger.info("[dtype] Verified BF16 execution on CPU -> torch.bfloat16")
            return torch.bfloat16
    except Exception as e:
        logger.warning(f"[dtype] BF16 probe failed: {e}")
    logger.info("[dtype] fallback -> torch.float32")
    return torch.float32


CPU_DTYPE = _pick_cpu_dtype()
logger.info(f"[init] Default CPU dtype = {CPU_DTYPE}")


_MODEL_CACHE: Dict[Tuple[str, torch.dtype], AutoModelForCausalLM] = {}


def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, torch.dtype]:
    key = (device, dtype)
    if key in _MODEL_CACHE:
        logger.info(f"[cache] hit model for device={device} dtype={dtype}")
        return _MODEL_CACHE[key], dtype

    logger.info(f"[load] begin from_pretrained device={device} dtype={dtype}")
    cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
    if hasattr(cfg, "quantization_config"):
        logger.warning("[load] removing quantization_config from config to avoid FP8 path")
        delattr(cfg, "quantization_config")

    eff_dtype = dtype
    try:
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            config=cfg,
            torch_dtype=dtype,
            trust_remote_code=True,
            device_map="auto" if device != "cpu" else {"": "cpu"},
            low_cpu_mem_usage=False,
        )
    except Exception as e:
        if device == "cpu" and dtype == torch.bfloat16:
            logger.warning(f"[load] BF16 load failed on CPU ({e}). retry FP32.")
            eff_dtype = torch.float32
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                config=cfg,
                torch_dtype=eff_dtype,
                trust_remote_code=True,
                device_map={"": "cpu"},
                low_cpu_mem_usage=False,
            )
        else:
            logger.exception("[load] from_pretrained failed")
            raise

    if device == "cpu":
        logger.info(f"[load] casting all weights to CPU dtype={eff_dtype}")
        model = model.to(device=device, dtype=eff_dtype)
    else:
        logger.info(f"[load] moving model to device={device} (no recast)")
        model = model.to(device=device)

    model.eval()
    try:
        first_dtype = next(model.parameters()).dtype
        logger.info(f"[load] ready. effective_dtype={eff_dtype} first_param_dtype={first_dtype}")
    except Exception:
        logger.info(f"[load] ready. effective_dtype={eff_dtype} (param dtype probe failed)")

    # Cache under the effective dtype, which may differ from the requested one after a BF16 -> FP32 fallback.
    _MODEL_CACHE[(device, eff_dtype)] = model
    return model, eff_dtype


def _max_context(model, tokenizer) -> int:
    mc = getattr(getattr(model, "config", None), "max_position_embeddings", None)
    if isinstance(mc, int) and mc > 0:
        return mc
    tk = getattr(tokenizer, "model_max_length", None)
    if isinstance(tk, int) and 0 < tk < 10**12:
        return tk
    return 32768


def _build_inputs_with_truncation(prompt: str, device: str, max_new_tokens: int, model, tokenizer):
    toks = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = toks["input_ids"]
    attn = toks.get("attention_mask", None)

    ctx = _max_context(model, tokenizer)
    limit = max(8, ctx - max_new_tokens)
    in_len = input_ids.shape[-1]
    if in_len > limit:
        cut = in_len - limit
        input_ids = input_ids[:, -limit:]
        if attn is not None:
            attn = attn[:, -limit:]
        logger.warning(
            f"[truncate] prompt_tokens={in_len} > limit={limit}. truncated_left_by={cut} "
            f"to fit ctx={ctx}, new_input={input_ids.shape[-1]}, max_new={max_new_tokens}"
        )

    inputs = {"input_ids": input_ids}
    if attn is not None:
        inputs["attention_mask"] = attn

    inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
    return inputs, in_len, ctx, limit


class HFChatBackend(ChatBackend):
    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        if load_error:
            raise RuntimeError(load_error)

        messages = request.get("messages", [])
        tools = request.get("tools")
        temperature = float(request.get("temperature", settings.LlmTemp or 0.3))
        req_max_tokens = int(request.get("max_tokens", settings.LlmOpenAICtxSize or 32000))

        rid = f"chatcmpl-hf-{int(time.time())}"
        now = int(time.time())

        logger.info(f"[req] rid={rid} temp={temperature} req_max_tokens={req_max_tokens} "
                    f"msgs={len(messages)} tools={'yes' if tools else 'no'} "
                    f"spaces={'yes' if spaces else 'no'} cuda={'yes' if torch.cuda.is_available() else 'no'}")

        x_ip_token = request.get("x_ip_token")
        if x_ip_token and zero_client:
            zero_client.HEADERS["X-IP-Token"] = x_ip_token
            logger.info("[req] injected X-IP-Token into ZeroGPU headers")

        if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None):
            try:
                prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
                logger.info(f"[prompt] built via chat_template. len={len(prompt)}\n{_snippet(prompt, 800)}")
            except Exception as e:
                logger.warning(f"[prompt] chat_template failed -> fallback. err={e}")
                prompt = messages[-1]["content"] if messages else "(empty)"
                logger.info(f"[prompt] fallback content len={len(prompt)}\n{_snippet(prompt, 800)}")
        else:
            prompt = messages[-1]["content"] if messages else "(empty)"
            logger.info(f"[prompt] no template. using last user text len={len(prompt)}\n{_snippet(prompt, 800)}")

        def _run_once(prompt: str, device: str, req_dtype: torch.dtype) -> str:
            model, eff_dtype = _get_model(device, req_dtype)
            max_new_tokens = req_max_tokens

            inputs, orig_in_len, ctx, limit = _build_inputs_with_truncation(prompt, device, max_new_tokens, model, tokenizer)

            logger.info(f"[gen] device={device} dtype={eff_dtype} input_tokens={inputs['input_ids'].shape[-1]} "
                        f"(orig={orig_in_len}) max_ctx={ctx} limit_for_input={limit} max_new_tokens={max_new_tokens}")

            # temperature <= 0 means greedy decoding; generate() ignores temperature when do_sample=False.
            do_sample = temperature > 1e-6
            temp = max(1e-5, temperature) if do_sample else 0.0

            eos_id = tokenizer.eos_token_id
            pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos_id

            with torch.inference_mode():
                if device != "cpu":
                    autocast_ctx = torch.autocast(device_type=device, dtype=eff_dtype)
                else:
                    # torch.cpu.amp.autocast is deprecated; torch.autocast covers the CPU BF16 case.
                    autocast_ctx = torch.autocast(device_type="cpu", dtype=torch.bfloat16) if eff_dtype == torch.bfloat16 else nullcontext()

                gen_kwargs = dict(
                    max_new_tokens=max_new_tokens,
                    temperature=temp,
                    do_sample=do_sample,
                    use_cache=True,
                    eos_token_id=eos_id,
                    pad_token_id=pad_id,
                )
                logger.info(f"[gen] kwargs={gen_kwargs}")

                with autocast_ctx:
                    outputs = model.generate(**inputs, **gen_kwargs)

            input_len = inputs["input_ids"].shape[-1]
            generated_ids = outputs[0][input_len:]
            logger.info(f"[gen] new_tokens={generated_ids.shape[-1]}")
            text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
            logger.info(f"[gen] text len={len(text)}\n{_snippet(text, 1200)}")
            return text

        if spaces:
            @spaces.GPU(duration=120)
            def run_once_sync(prompt: str) -> str:
                if torch.cuda.is_available():
                    logger.info("[path] ZeroGPU + CUDA")
                    return _run_once(prompt, device="cuda", req_dtype=torch.float16)
                logger.info("[path] ZeroGPU but no CUDA -> CPU fallback")
                return _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())

            text = await asyncio.to_thread(run_once_sync, prompt)
        else:
            logger.info("[path] CPU-only runtime")
            text = await asyncio.to_thread(_run_once, prompt, "cpu", _pick_cpu_dtype())

        chunk = {
            "id": rid,
            "object": "chat.completion.chunk",
            "created": now,
            "model": MODEL_ID,
            "choices": [
                {"index": 0, "delta": {"role": "assistant", "content": text}, "finish_reason": "stop"}
            ],
        }
        logger.info(f"[out] chunk summary -> id={rid} content_len={len(text)}")
        yield chunk


class StubImagesBackend(ImagesBackend):
    async def generate_b64(self, request: Dict[str, Any]) -> str:
        logger.warning("Image generation not supported in HF backend.")
        # Return a 1x1 placeholder PNG (base64) so callers always receive valid image data.
        return "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
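

# --- Local smoke test -------------------------------------------------------
# Minimal sketch for exercising HFChatBackend.stream outside the server.
# The request dict below is an assumption for illustration: it mirrors the
# OpenAI-style keys this backend reads (messages, temperature, max_tokens);
# adjust it to whatever the surrounding router actually passes in.
if __name__ == "__main__":
    async def _smoke() -> None:
        backend = HFChatBackend()
        request = {
            "messages": [{"role": "user", "content": "Say hello in one short sentence."}],
            "temperature": 0.3,
            "max_tokens": 64,
        }
        async for chunk in backend.stream(request):
            print(chunk["choices"][0]["delta"]["content"])

    asyncio.run(_smoke())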