import time, logging, asyncio
from contextlib import nullcontext
from typing import Any, Dict, AsyncIterable, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from backends_base import ChatBackend, ImagesBackend
from config import settings

logger = logging.getLogger(__name__)

def _snippet(txt: str, n: int = 800) -> str:
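    """Return txt unchanged if short; otherwise the first n characters plus a note of how much was trimmed."""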
    if not isinstance(txt, str):
        return f"<non-str:{type(txt)}>"
    return txt if len(txt) <= n else txt[:n] + f"... <+{len(txt)-n} chars>"

try:
    import spaces
    from spaces.zero import client as zero_client
except ImportError:
    spaces, zero_client = None, None

MODEL_ID = settings.LlmHFModelID or "Qwen/Qwen2.5-1.5B-Instruct"
logger.info(f"[init] MODEL_ID={MODEL_ID}")

tokenizer, load_error = None, None
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=False)
    has_template = hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None)
    logger.info(f"[init] tokenizer loaded. chat_template={'yes' if has_template else 'no'}")
except Exception as e:
    load_error = f"Failed to load tokenizer: {e}"
    logger.exception(load_error)


def probe_bf16_runtime() -> bool:
    """Check if BF16 is both reported and actually used in ops on CPU."""
    if not (hasattr(torch, "cpu") and hasattr(torch.cpu, "is_bf16_supported")):
        return False
    if not torch.cpu.is_bf16_supported():
        return False
    try:
        a = torch.randn(16, 16, dtype=torch.bfloat16)
        b = torch.randn(16, 16, dtype=torch.bfloat16)
        c = a @ b
        return c.dtype == torch.bfloat16
    except Exception:
        return False


def _pick_cpu_dtype() -> torch.dtype:
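    """Pick the CPU compute dtype: bfloat16 if the runtime probe verifies real BF16 matmul support, else float32."""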
    try:
        if probe_bf16_runtime():
            logger.info("[dtype] Verified BF16 execution on CPU -> torch.bfloat16")
            return torch.bfloat16
    except Exception as e:
        logger.warning(f"[dtype] BF16 probe failed: {e}")
    logger.info("[dtype] fallback -> torch.float32")
    return torch.float32


# Log CPU dtype capability at startup
CPU_DTYPE = _pick_cpu_dtype()
logger.info(f"[init] Default CPU dtype = {CPU_DTYPE}")


# Cache loaded models keyed by (device, dtype) so repeat requests reuse already-loaded weights.
_MODEL_CACHE: Dict[Tuple[str, torch.dtype], AutoModelForCausalLM] = {}

def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, torch.dtype]:
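    """Return a cached or freshly loaded model for (device, dtype).

    On CPU, a failed BF16 load falls back to FP32; the effective dtype is returned alongside the model.
    """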
    key = (device, dtype)
    if key in _MODEL_CACHE:
        logger.info(f"[cache] hit model for device={device} dtype={dtype}")
        return _MODEL_CACHE[key], dtype

    logger.info(f"[load] begin from_pretrained device={device} dtype={dtype}")
    cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
    if hasattr(cfg, "quantization_config"):
        logger.warning("[load] removing quantization_config from config to avoid FP8 path")
        delattr(cfg, "quantization_config")

    eff_dtype = dtype
    try:
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            config=cfg,
            torch_dtype=dtype,
            trust_remote_code=True,
            device_map="auto" if device != "cpu" else {"": "cpu"},
            low_cpu_mem_usage=False,
        )
    except Exception as e:
        if device == "cpu" and dtype == torch.bfloat16:
            logger.warning(f"[load] BF16 load failed on CPU ({e}). retry FP32.")
            eff_dtype = torch.float32
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                config=cfg,
                torch_dtype=eff_dtype,
                trust_remote_code=True,
                device_map={"": "cpu"},
                low_cpu_mem_usage=False,
            )
        else:
            logger.exception("[load] from_pretrained failed")
            raise

    if device == "cpu":
        logger.info(f"[load] casting all weights to CPU dtype={eff_dtype}")
        model = model.to(device=device, dtype=eff_dtype)
    else:
        logger.info(f"[load] moving model to device={device} (no recast)")
        model = model.to(device=device)

    model.eval()
    try:
        first_dtype = next(model.parameters()).dtype
        logger.info(f"[load] ready. effective_dtype={eff_dtype} first_param_dtype={first_dtype}")
    except Exception:
        logger.info(f"[load] ready. effective_dtype={eff_dtype} (param dtype probe failed)")

    # Cache under the requested key as well, so a BF16 -> FP32 fallback is not re-loaded on the next call.
    _MODEL_CACHE[key] = model
    _MODEL_CACHE[(device, eff_dtype)] = model
    return model, eff_dtype

def _max_context(model, tokenizer) -> int:
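    """Best-effort context window: prefer the model config's max_position_embeddings, then the tokenizer's model_max_length, else a conservative default."""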
    mc = getattr(getattr(model, "config", None), "max_position_embeddings", None)
    if isinstance(mc, int) and mc > 0:
        return mc
    tk = getattr(tokenizer, "model_max_length", None)
    if isinstance(tk, int) and tk > 0 and tk < 10**12:
        return tk
    return 32768  # conservative default for recent Qwen models

def _build_inputs_with_truncation(prompt: str, device: str, max_new_tokens: int, model, tokenizer):
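    """Tokenize the prompt and left-truncate it so the input plus max_new_tokens fits inside the model's context window."""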
    toks = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    input_ids = toks["input_ids"]
    attn = toks.get("attention_mask", None)

    ctx = _max_context(model, tokenizer)
    limit = max(8, ctx - max_new_tokens)
    in_len = input_ids.shape[-1]
    if in_len > limit:
        cut = in_len - limit
        input_ids = input_ids[:, -limit:]
        if attn is not None:
            attn = attn[:, -limit:]
        logger.warning(f"[truncate] prompt_tokens={in_len} > limit={limit}. truncated_left_by={cut} to fit ctx={ctx}, new_input={input_ids.shape[-1]}, max_new={max_new_tokens}")

    inputs = {"input_ids": input_ids}
    if attn is not None:
        inputs["attention_mask"] = attn

    inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}
    return inputs, in_len, ctx, limit

class HFChatBackend(ChatBackend):
    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
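        """Generate a full completion and yield it as a single OpenAI-style chat.completion.chunk."""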
        if load_error:
            raise RuntimeError(load_error)

        messages = request.get("messages", [])
        tools = request.get("tools")
        temperature = float(request.get("temperature", settings.LlmTemp or 0.3))
        req_max_tokens = int(request.get("max_tokens", settings.LlmOpenAICtxSize or 32000))

        rid = f"chatcmpl-hf-{int(time.time())}"
        now = int(time.time())

        logger.info(f"[req] rid={rid} temp={temperature} req_max_tokens={req_max_tokens} "
                    f"msgs={len(messages)} tools={'yes' if tools else 'no'} "
                    f"spaces={'yes' if spaces else 'no'} cuda={'yes' if torch.cuda.is_available() else 'no'}")

        x_ip_token = request.get("x_ip_token")
        if x_ip_token and zero_client:
            zero_client.HEADERS["X-IP-Token"] = x_ip_token
            logger.info("[req] injected X-IP-Token into ZeroGPU headers")

        if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None):
            try:
                prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
                logger.info(f"[prompt] built via chat_template. len={len(prompt)}\n{_snippet(prompt, 800)}")
            except Exception as e:
                logger.warning(f"[prompt] chat_template failed -> fallback. err={e}")
                prompt = messages[-1]["content"] if messages else "(empty)"
                logger.info(f"[prompt] fallback content len={len(prompt)}\n{_snippet(prompt, 800)}")
        else:
            prompt = messages[-1]["content"] if messages else "(empty)"
            logger.info(f"[prompt] no template. using last user text len={len(prompt)}\n{_snippet(prompt, 800)}")

        def _run_once(prompt: str, device: str, req_dtype: torch.dtype) -> str:
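            """Blocking single-shot generation on the given device/dtype; returns the decoded newly generated text."""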
            model, eff_dtype = _get_model(device, req_dtype)
            max_new_tokens = req_max_tokens

            inputs, orig_in_len, ctx, limit = _build_inputs_with_truncation(prompt, device, max_new_tokens, model, tokenizer)

            logger.info(f"[gen] device={device} dtype={eff_dtype} input_tokens={inputs['input_ids'].shape[-1]} "
                        f"(orig={orig_in_len}) max_ctx={ctx} limit_for_input={limit} max_new_tokens={max_new_tokens}")

            do_sample = temperature > 1e-6
            # Greedy decoding ignores temperature; keep the default value to avoid generation-config warnings.
            temp = max(1e-5, temperature) if do_sample else 1.0

            eos_id = tokenizer.eos_token_id
            pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos_id

            with torch.inference_mode():
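                # Autocast on GPU matches the requested dtype; on CPU, BF16 autocast is used only when BF16 weights are active.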
                if device != "cpu":
                    autocast_ctx = torch.autocast(device_type=device, dtype=eff_dtype)
                else:
                    autocast_ctx = torch.autocast(device_type="cpu", dtype=torch.bfloat16) if eff_dtype == torch.bfloat16 else nullcontext()

                gen_kwargs = dict(
                    max_new_tokens=max_new_tokens,
                    temperature=temp,
                    do_sample=do_sample,
                    use_cache=True,
                    eos_token_id=eos_id,
                    pad_token_id=pad_id,
                )
                logger.info(f"[gen] kwargs={gen_kwargs}")

                with autocast_ctx:
                    outputs = model.generate(**inputs, **gen_kwargs)

            input_len = inputs["input_ids"].shape[-1]
            generated_ids = outputs[0][input_len:]
            logger.info(f"[gen] new_tokens={generated_ids.shape[-1]}")
            text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
            logger.info(f"[gen] text len={len(text)}\n{_snippet(text, 1200)}")
            return text

        if spaces:
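            # Hugging Face Spaces ZeroGPU: run the blocking generation inside a GPU-decorated function.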
            @spaces.GPU(duration=120)
            def run_once_sync(prompt: str) -> str:
                if torch.cuda.is_available():
                    logger.info("[path] ZeroGPU + CUDA")
                    return _run_once(prompt, device="cuda", req_dtype=torch.float16)
                logger.info("[path] ZeroGPU but no CUDA -> CPU fallback")
                return _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())
            text = await asyncio.to_thread(run_once_sync, prompt)
        else:
            logger.info("[path] CPU-only runtime")
            text = await asyncio.to_thread(_run_once, prompt, "cpu", _pick_cpu_dtype())

        chunk = {
            "id": rid,
            "object": "chat.completion.chunk",
            "created": now,
            "model": MODEL_ID,
            "choices": [
                {"index": 0, "delta": {"role": "assistant", "content": text}, "finish_reason": "stop"}
            ],
        }
        logger.info(f"[out] chunk summary -> id={rid} content_len={len(text)}")
        yield chunk


class StubImagesBackend(ImagesBackend):
    async def generate_b64(self, request: Dict[str, Any]) -> str:
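        """Image generation is not implemented for this backend; return a tiny placeholder PNG as base64."""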
        logger.warning("Image generation not supported in HF backend.")
        return "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="