File size: 6,818 Bytes
60a9595
 
d76b941
 
2dcb7ad
 
1d79762
2dcb7ad
 
 
60a9595
 
2dcb7ad
 
60a9595
2dcb7ad
60a9595
2dcb7ad
2ad6a17
2dcb7ad
2ad6a17
2dcb7ad
2ad6a17
2dcb7ad
2ad6a17
2dcb7ad
 
1d79762
2dcb7ad
 
2ad6a17
2dcb7ad
 
 
1d79762
 
 
 
 
d279e64
1d79762
 
 
 
 
 
 
849364d
d76b941
849364d
 
d76b941
11cacc3
 
 
849364d
 
 
 
11cacc3
d279e64
d76b941
d279e64
 
 
 
 
 
 
11cacc3
d279e64
 
 
 
d76b941
d279e64
 
 
d76b941
d279e64
 
11cacc3
d279e64
 
 
849364d
11cacc3
 
 
849364d
d76b941
 
849364d
 
60a9595
2dcb7ad
 
 
 
 
 
 
 
 
 
 
 
2ad6a17
7471f75
60a9595
 
 
7471f75
d279e64
d76b941
d279e64
 
 
 
 
 
 
 
 
 
 
 
 
d76b941
 
 
 
 
60a9595
2ad6a17
 
d76b941
2ad6a17
d76b941
 
 
 
2ad6a17
 
 
 
 
 
 
d76b941
2ad6a17
 
7471f75
 
2ad6a17
 
 
d76b941
 
 
2ad6a17
60a9595
2ad6a17
d76b941
2ad6a17
 
 
 
 
 
 
 
 
 
60a9595
 
 
1b21789
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
# hf_backend.py
import time, logging
from contextlib import nullcontext
from typing import Any, Dict, AsyncIterable, Tuple

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from backends_base import ChatBackend, ImagesBackend
from config import settings

logger = logging.getLogger(__name__)

# Optional dependency: when the `spaces` package cannot be imported, both
# names are set to None. HFChatBackend.stream checks them before using the
# GPU decorator or touching the ZeroGPU client headers.
try:
    import spaces
    from spaces.zero import client as zero_client
except ImportError:
    spaces, zero_client = None, None

# --- Model setup ---
# The model id is taken from settings; a small instruct model is the default.
MODEL_ID = settings.LlmHFModelID or "Qwen/Qwen2.5-1.5B-Instruct"
logger.info(f"Preloading tokenizer for {MODEL_ID} on CPU...")

# `load_error` stays None on success; otherwise it holds the message that
# HFChatBackend.stream raises as a RuntimeError.
tokenizer = None
load_error = None
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=False)
except Exception as exc:
    load_error = f"Failed to load tokenizer: {exc}"
    logger.exception(load_error)


# ---------------- helpers ----------------
def _pick_cpu_dtype() -> torch.dtype:
    """Choose the dtype requested for CPU inference.

    Returns:
        ``torch.bfloat16`` when ``torch.cpu.is_bf16_supported`` exists and
        reports support, otherwise ``torch.float32``.
    """
    # Look the probe up defensively in case this torch build does not expose it.
    probe = getattr(getattr(torch, "cpu", None), "is_bf16_supported", None)
    if probe is not None:
        try:
            if probe():
                logger.info("CPU BF16 supported, will attempt torch.bfloat16")
                return torch.bfloat16
        except Exception:
            # Best-effort probe: a failure must not break dtype selection, but
            # leave a trace instead of swallowing it silently.
            logger.debug("CPU BF16 capability probe failed", exc_info=True)
    logger.info("Falling back to torch.float32 on CPU")
    return torch.float32


# ---------------- global cache ----------------
# Loaded models, keyed by (device, dtype the weights were actually cast to).
_MODEL_CACHE: Dict[tuple[str, torch.dtype], AutoModelForCausalLM] = {}
# For each *requested* (device, dtype), the dtype that was finally used.
# Without this, a CPU BF16 request that fell back to FP32 would miss the cache
# on every call (the model is stored under FP32) and reload the weights.
_RESOLVED_DTYPE: Dict[tuple[str, torch.dtype], torch.dtype] = {}


def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, torch.dtype]:
    """Return the model for ``device``/``dtype``, loading and caching it on first use.

    Args:
        device: Target device string; ``"cpu"`` selects an explicit CPU device
            map, anything else uses ``device_map="auto"``.
        dtype: Requested torch dtype for the weights.

    Returns:
        ``(model, effective_dtype)``. ``effective_dtype`` differs from
        ``dtype`` only when a CPU BF16 load failed and FP32 was used instead.

    Raises:
        Exception: Whatever ``from_pretrained`` raises, except for a failed
            BF16 load on CPU, which is retried once with FP32.
    """
    requested_key = (device, dtype)
    resolved_key = (device, _RESOLVED_DTYPE.get(requested_key, dtype))
    if resolved_key in _MODEL_CACHE:
        return _MODEL_CACHE[resolved_key], resolved_key[1]

    cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
    if hasattr(cfg, "quantization_config"):
        logger.warning("Removing quantization_config from model config")
        delattr(cfg, "quantization_config")

    eff_dtype = dtype
    try:
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            config=cfg,
            torch_dtype=dtype,
            trust_remote_code=True,
            device_map="auto" if device != "cpu" else {"": "cpu"},
            low_cpu_mem_usage=False,  # ensure full load before casting
        )
    except Exception as e:
        # Only the CPU BF16 case has a fallback; everything else propagates.
        if not (device == "cpu" and dtype == torch.bfloat16):
            raise
        logger.warning(f"BF16 load failed on CPU: {e}. Retrying with FP32.")
        eff_dtype = torch.float32
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            config=cfg,
            torch_dtype=eff_dtype,
            trust_remote_code=True,
            device_map={"": "cpu"},
            low_cpu_mem_usage=False,
        )

    # --- Force recast to target dtype/device (fixes FP8 leftovers) ---
    # NOTE(review): for non-CPU devices the model was loaded with
    # device_map="auto"; confirm this explicit .to() is safe on multi-device hosts.
    model = model.to(device=device, dtype=eff_dtype)

    model.eval()
    _MODEL_CACHE[(device, eff_dtype)] = model
    # Remember the outcome so the next identical request hits the cache.
    _RESOLVED_DTYPE[requested_key] = eff_dtype
    return model, eff_dtype


# ---------------- Chat Backend ----------------
class HFChatBackend(ChatBackend):
    """Chat backend that runs a Hugging Face causal LM in-process.

    The completion is generated in a single ``generate`` call and emitted as
    one ``chat.completion.chunk`` dict.
    """

    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        """Generate a completion for ``request`` and yield it as one chunk.

        Args:
            request: Request dict. Reads ``messages``, ``temperature``,
                ``max_tokens`` and ``x_ip_token``; missing or ``None`` values
                fall back to the configured defaults.

        Yields:
            A single chunk dict whose ``choices[0].delta.content`` holds the
            generated text (without the prompt) and whose ``finish_reason``
            is ``"stop"``.

        Raises:
            RuntimeError: If the tokenizer failed to load at import time.
        """
        if load_error:
            raise RuntimeError(load_error)

        messages = request.get("messages") or []

        # Clients may send explicit nulls; treat them like missing keys
        # instead of crashing in float()/int().
        raw_temperature = request.get("temperature")
        if raw_temperature is None:
            raw_temperature = settings.LlmTemp or 0.7
        temperature = float(raw_temperature)

        raw_max_tokens = request.get("max_tokens")
        if raw_max_tokens is None:
            raw_max_tokens = settings.LlmOpenAICtxSize or 512
        max_tokens = int(raw_max_tokens)

        # Sample the clock once so `id` and `created` always agree.
        now = int(time.time())
        rid = f"chatcmpl-hf-{now}"

        # --- Inject X-IP-Token into global headers if ZeroGPU is used ---
        # NOTE(review): HEADERS is shared module state on the client; confirm
        # this is safe when several requests run concurrently.
        x_ip_token = request.get("x_ip_token")
        if x_ip_token and zero_client:
            zero_client.HEADERS["X-IP-Token"] = x_ip_token
            logger.debug("Injected X-IP-Token into ZeroGPU headers")

        # Build prompt using chat template if available
        prompt = None
        if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None):
            try:
                prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True,
                )
                logger.debug("Applied chat template for prompt")
            except Exception as e:
                logger.warning(f"Failed to apply chat template: {e}, using fallback")
        if prompt is None:
            # Fallback: use the raw content of the last message.
            prompt = messages[-1]["content"] if messages else "(empty)"

        # Sampling requires a strictly positive temperature; a non-positive
        # value (e.g. temperature=0 from a client) selects greedy decoding.
        if temperature > 0:
            sampling_kwargs: Dict[str, Any] = {"do_sample": True, "temperature": temperature}
        else:
            sampling_kwargs = {"do_sample": False}

        def _run_once(prompt: str, device: str, req_dtype: torch.dtype) -> str:
            """Tokenize ``prompt``, generate on ``device`` and decode the new tokens."""
            model, eff_dtype = _get_model(device, req_dtype)

            inputs = tokenizer(prompt, return_tensors="pt")
            inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}

            # Autocast on accelerators always, on CPU only for BF16.
            if device != "cpu" or eff_dtype == torch.bfloat16:
                autocast_ctx = torch.autocast(device_type=device, dtype=eff_dtype)
            else:
                autocast_ctx = nullcontext()

            with torch.inference_mode(), autocast_ctx:
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    use_cache=True,
                    **sampling_kwargs,
                )

            # generate() returns prompt + continuation for causal LMs; drop the
            # prompt tokens so the reply does not echo the conversation.
            prompt_len = inputs["input_ids"].shape[-1]
            return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

        # NOTE(review): generation runs synchronously inside this coroutine, so
        # the event loop is blocked until it finishes.
        if spaces:
            @spaces.GPU(duration=120)
            def run_once(prompt: str) -> str:
                if torch.cuda.is_available():
                    return _run_once(prompt, device="cuda", req_dtype=torch.float16)
                return _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())

            text = run_once(prompt)
        else:
            text = _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())

        yield {
            "id": rid,
            "object": "chat.completion.chunk",
            "created": now,
            "model": MODEL_ID,
            "choices": [
                {"index": 0, "delta": {"content": text}, "finish_reason": "stop"}
            ],
        }


# ---------------- Stub Images Backend ----------------
class StubImagesBackend(ImagesBackend):
    """
    Placeholder images backend; the HF chat backend only produces text.
    Every request gets the same fixed base64 PNG (intended as a transparent
    1x1 pixel) instead of a generated image.
    """
    async def generate_b64(self, request: Dict[str, Any]) -> str:
        """Log that images are unsupported and return the placeholder payload."""
        placeholder_png_b64 = (
            "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
        )
        logger.warning("Image generation not supported in HF backend.")
        return placeholder_png_b64