# GradLLM / hf_backend.py
import time, logging, json, asyncio
from contextlib import nullcontext
from typing import Any, Dict, AsyncIterable, Tuple
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from backends_base import ChatBackend, ImagesBackend
from config import settings
logger = logging.getLogger(__name__)
def _snippet(txt: str, n: int = 800) -> str:
if not isinstance(txt, str):
return f"<non-str:{type(txt)}>"
return txt if len(txt) <= n else txt[:n] + f"... <+{len(txt)-n} chars>"
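# Optional ZeroGPU support on Hugging Face Spaces: `spaces` provides the
# @spaces.GPU decorator and `zero_client` exposes the header dict used to
# forward the per-request X-IP-Token. Both stay None outside a Space.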
try:
import spaces
from spaces.zero import client as zero_client
except ImportError:
spaces, zero_client = None, None
MODEL_ID = settings.LlmHFModelID or "Qwen/Qwen2.5-1.5B-Instruct"
logger.info(f"[init] MODEL_ID={MODEL_ID}")
tokenizer, load_error = None, None
try:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=False)
has_template = hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None)
logger.info(f"[init] tokenizer loaded. chat_template={'yes' if has_template else 'no'}")
except Exception as e:
load_error = f"Failed to load tokenizer: {e}"
logger.exception(load_error)
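# Some CPU builds report BF16 support without reliable BF16 kernels, so the probe
# below runs a tiny bfloat16 matmul and checks that the result dtype stays bfloat16.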
def probe_bf16_runtime() -> bool:
"""Check if BF16 is both reported and actually used in ops on CPU."""
if not (hasattr(torch, "cpu") and hasattr(torch.cpu, "is_bf16_supported")):
return False
if not torch.cpu.is_bf16_supported():
return False
try:
a = torch.randn(16, 16, dtype=torch.bfloat16)
b = torch.randn(16, 16, dtype=torch.bfloat16)
c = a @ b
return c.dtype == torch.bfloat16
except Exception:
return False
def _pick_cpu_dtype() -> torch.dtype:
try:
if probe_bf16_runtime():
logger.info("[dtype] Verified BF16 execution on CPU -> torch.bfloat16")
return torch.bfloat16
except Exception as e:
logger.warning(f"[dtype] BF16 probe failed: {e}")
logger.info("[dtype] fallback -> torch.float32")
return torch.float32
# Log CPU dtype capability at startup
CPU_DTYPE = _pick_cpu_dtype()
logger.info(f"[init] Default CPU dtype = {CPU_DTYPE}")
_MODEL_CACHE: Dict[Tuple[str, torch.dtype], AutoModelForCausalLM] = {}
def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, torch.dtype]:
key = (device, dtype)
if key in _MODEL_CACHE:
logger.info(f"[cache] hit model for device={device} dtype={dtype}")
return _MODEL_CACHE[key], dtype
logger.info(f"[load] begin from_pretrained device={device} dtype={dtype}")
cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
if hasattr(cfg, "quantization_config"):
logger.warning("[load] removing quantization_config from config to avoid FP8 path")
delattr(cfg, "quantization_config")
eff_dtype = dtype
try:
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
config=cfg,
torch_dtype=dtype,
trust_remote_code=True,
device_map="auto" if device != "cpu" else {"": "cpu"},
low_cpu_mem_usage=False,
)
except Exception as e:
if device == "cpu" and dtype == torch.bfloat16:
logger.warning(f"[load] BF16 load failed on CPU ({e}). retry FP32.")
eff_dtype = torch.float32
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
config=cfg,
torch_dtype=eff_dtype,
trust_remote_code=True,
device_map={"": "cpu"},
low_cpu_mem_usage=False,
)
else:
logger.exception("[load] from_pretrained failed")
raise
if device == "cpu":
logger.info(f"[load] casting all weights to CPU dtype={eff_dtype}")
model = model.to(device=device, dtype=eff_dtype)
else:
logger.info(f"[load] moving model to device={device} (no recast)")
model = model.to(device=device)
model.eval()
try:
first_dtype = next(model.parameters()).dtype
logger.info(f"[load] ready. effective_dtype={eff_dtype} first_param_dtype={first_dtype}")
except Exception:
logger.info(f"[load] ready. effective_dtype={eff_dtype} (param dtype probe failed)")
_MODEL_CACHE[(device, eff_dtype)] = model
return model, eff_dtype
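# Resolve the usable context window: prefer the model config's
# max_position_embeddings, then the tokenizer's model_max_length (ignoring the
# huge sentinel values some tokenizers report), else fall back to 32768.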
def _max_context(model, tokenizer) -> int:
mc = getattr(getattr(model, "config", None), "max_position_embeddings", None)
if isinstance(mc, int) and mc > 0:
return mc
tk = getattr(tokenizer, "model_max_length", None)
if isinstance(tk, int) and tk > 0 and tk < 10**12:
return tk
    return 32768  # conservative default for Qwen-family models
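# Tokenize the prompt and, if it cannot fit alongside max_new_tokens, truncate
# from the left so the most recent context is kept.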
def _build_inputs_with_truncation(prompt: str, device: str, max_new_tokens: int, model, tokenizer):
toks = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
input_ids = toks["input_ids"]
attn = toks.get("attention_mask", None)
ctx = _max_context(model, tokenizer)
limit = max(8, ctx - max_new_tokens)
in_len = input_ids.shape[-1]
if in_len > limit:
cut = in_len - limit
input_ids = input_ids[:, -limit:]
if attn is not None:
attn = attn[:, -limit:]
logger.warning(f"[truncate] prompt_tokens={in_len} > limit={limit}. truncated_left_by={cut} to fit ctx={ctx}, new_input={input_ids.shape[-1]}, max_new={max_new_tokens}")
inputs = {"input_ids": input_ids}
if attn is not None:
inputs["attention_mask"] = attn
inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
return inputs, in_len, ctx, limit
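# Chat backend: renders OpenAI-style messages into a single prompt, runs one
# blocking generate() call in a worker thread, and yields the whole completion
# as a single chat.completion.chunk (no token-level streaming).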
class HFChatBackend(ChatBackend):
async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
if load_error:
raise RuntimeError(load_error)
messages = request.get("messages", [])
tools = request.get("tools")
temperature = float(request.get("temperature", settings.LlmTemp or 0.3))
req_max_tokens = int(request.get("max_tokens", settings.LlmOpenAICtxSize or 32000))
rid = f"chatcmpl-hf-{int(time.time())}"
now = int(time.time())
logger.info(f"[req] rid={rid} temp={temperature} req_max_tokens={req_max_tokens} "
f"msgs={len(messages)} tools={'yes' if tools else 'no'} "
f"spaces={'yes' if spaces else 'no'} cuda={'yes' if torch.cuda.is_available() else 'no'}")
x_ip_token = request.get("x_ip_token")
if x_ip_token and zero_client:
zero_client.HEADERS["X-IP-Token"] = x_ip_token
logger.info("[req] injected X-IP-Token into ZeroGPU headers")
if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None):
try:
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
logger.info(f"[prompt] built via chat_template. len={len(prompt)}\n{_snippet(prompt, 800)}")
except Exception as e:
logger.warning(f"[prompt] chat_template failed -> fallback. err={e}")
prompt = messages[-1]["content"] if messages else "(empty)"
logger.info(f"[prompt] fallback content len={len(prompt)}\n{_snippet(prompt, 800)}")
else:
prompt = messages[-1]["content"] if messages else "(empty)"
logger.info(f"[prompt] no template. using last user text len={len(prompt)}\n{_snippet(prompt, 800)}")
def _run_once(prompt: str, device: str, req_dtype: torch.dtype) -> str:
model, eff_dtype = _get_model(device, req_dtype)
max_new_tokens = req_max_tokens
inputs, orig_in_len, ctx, limit = _build_inputs_with_truncation(prompt, device, max_new_tokens, model, tokenizer)
logger.info(f"[gen] device={device} dtype={eff_dtype} input_tokens={inputs['input_ids'].shape[-1]} "
f"(orig={orig_in_len}) max_ctx={ctx} limit_for_input={limit} max_new_tokens={max_new_tokens}")
do_sample = temperature > 1e-6
temp = max(1e-5, temperature) if do_sample else 0.0
eos_id = tokenizer.eos_token_id
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos_id
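            # Run under inference_mode; on GPU, autocast in the effective dtype,
            # on CPU use BF16 autocast only when the verified-BF16 path was chosen.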
with torch.inference_mode():
if device != "cpu":
autocast_ctx = torch.autocast(device_type=device, dtype=eff_dtype)
else:
                    autocast_ctx = torch.autocast(device_type="cpu", dtype=torch.bfloat16) if eff_dtype == torch.bfloat16 else nullcontext()
gen_kwargs = dict(
max_new_tokens=max_new_tokens,
temperature=temp,
do_sample=do_sample,
use_cache=True,
eos_token_id=eos_id,
pad_token_id=pad_id,
)
logger.info(f"[gen] kwargs={gen_kwargs}")
with autocast_ctx:
outputs = model.generate(**inputs, **gen_kwargs)
input_len = inputs["input_ids"].shape[-1]
generated_ids = outputs[0][input_len:]
logger.info(f"[gen] new_tokens={generated_ids.shape[-1]}")
text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
logger.info(f"[gen] text len={len(text)}\n{_snippet(text, 1200)}")
return text
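        # ZeroGPU path: the @spaces.GPU wrapper attaches a GPU for up to 120 s;
        # if CUDA is still unavailable, generation falls back to the CPU path.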
if spaces:
@spaces.GPU(duration=120)
def run_once_sync(prompt: str) -> str:
if torch.cuda.is_available():
logger.info("[path] ZeroGPU + CUDA")
return _run_once(prompt, device="cuda", req_dtype=torch.float16)
logger.info("[path] ZeroGPU but no CUDA -> CPU fallback")
return _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())
text = await asyncio.to_thread(run_once_sync, prompt)
else:
logger.info("[path] CPU-only runtime")
text = await asyncio.to_thread(_run_once, prompt, "cpu", _pick_cpu_dtype())
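        # Emit the whole completion as one OpenAI-compatible chunk with
        # finish_reason="stop" rather than streaming token deltas.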
chunk = {
"id": rid,
"object": "chat.completion.chunk",
"created": now,
"model": MODEL_ID,
"choices": [
{"index": 0, "delta": {"role": "assistant", "content": text}, "finish_reason": "stop"}
],
}
logger.info(f"[out] chunk summary -> id={rid} content_len={len(text)}")
yield chunk
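# Placeholder images backend: always returns a tiny base64-encoded PNG so callers
# that expect an image payload still receive something decodable.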
class StubImagesBackend(ImagesBackend):
async def generate_b64(self, request: Dict[str, Any]) -> str:
logger.warning("Image generation not supported in HF backend.")
return "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="