Commit · 213e916
Parent(s): 11cacc3
- hf_backend.py +11 -11

hf_backend.py CHANGED
@@ -67,7 +67,7 @@ def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, t
             torch_dtype=dtype,
             trust_remote_code=True,
             device_map="auto" if device != "cpu" else {"": "cpu"},
-            low_cpu_mem_usage=False,
+            low_cpu_mem_usage=False,
         )
     except Exception as e:
         if device == "cpu" and dtype == torch.bfloat16:

@@ -84,8 +84,10 @@ def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, t
         else:
             raise
 
-
-
+    if device == "cpu":
+        model = model.to(device=device, dtype=eff_dtype)
+    else:
+        model = model.to(device=device)
 
     model.eval()
     _MODEL_CACHE[(device, eff_dtype)] = model

@@ -105,13 +107,11 @@ class HFChatBackend(ChatBackend):
         rid = f"chatcmpl-hf-{int(time.time())}"
         now = int(time.time())
 
-        # --- Inject X-IP-Token into global headers if ZeroGPU is used ---
         x_ip_token = request.get("x_ip_token")
         if x_ip_token and zero_client:
             zero_client.HEADERS["X-IP-Token"] = x_ip_token
             logger.debug("Injected X-IP-Token into ZeroGPU headers")
 
-        # Build prompt using chat template if available
         if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None):
             try:
                 prompt = tokenizer.apply_chat_template(

@@ -150,7 +150,11 @@ class HFChatBackend(ChatBackend):
                 use_cache=True,
             )
 
-
+            # Slice: keep only newly generated tokens
+            input_len = inputs["input_ids"].shape[-1]
+            generated_ids = outputs[0][input_len:]
+            text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
+            return text
 
         if spaces:
             @spaces.GPU(duration=120)

@@ -169,17 +173,13 @@ class HFChatBackend(ChatBackend):
             "created": now,
             "model": MODEL_ID,
             "choices": [
-                {"index": 0, "delta": {"content": text}, "finish_reason": "stop"}
+                {"index": 0, "delta": {"role": "assistant", "content": text}, "finish_reason": "stop"}
             ],
         }
 
 
 # ---------------- Stub Images Backend ----------------
 class StubImagesBackend(ImagesBackend):
-    """
-    Stub backend for images since HFChatBackend is text-only.
-    Returns a transparent 1x1 PNG placeholder.
-    """
     async def generate_b64(self, request: Dict[str, Any]) -> str:
         logger.warning("Image generation not supported in HF backend.")
         return (
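The slicing added in the @@ -150,7 +150,11 @@ hunk decodes only the tokens emitted after the prompt, so the returned text no longer echoes the input. A minimal standalone sketch of the same pattern, assuming a causal LM loaded through transformers; the model id below is a small placeholder for illustration, not the MODEL_ID used by this backend:

# Sketch: decode only the newly generated tokens after generate().
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "sshleifer/tiny-gpt2"  # placeholder model, not the backend's MODEL_ID
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model.eval()

inputs = tokenizer("Hello, how are", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=8, use_cache=True)

input_len = inputs["input_ids"].shape[-1]      # number of prompt tokens
generated_ids = outputs[0][input_len:]         # drop the echoed prompt
text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
print(text)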
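The delta change in the last hunk adds the assistant role to the single streamed chunk, in line with the OpenAI-style chat.completion.chunk convention of carrying the role in the first delta. A hedged sketch of assembling and framing such a chunk for an SSE response; the helper name and field values are illustrative, not taken from this repository:

# Sketch: build one OpenAI-style chat.completion.chunk and frame it for SSE.
import json
import time

def make_chunk(text: str, model_id: str) -> str:
    chunk = {
        "id": f"chatcmpl-hf-{int(time.time())}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model_id,
        "choices": [
            {
                "index": 0,
                # the first (and here only) delta carries the role and the content
                "delta": {"role": "assistant", "content": text},
                "finish_reason": "stop",
            }
        ],
    }
    # SSE framing: a "data: <json>" line followed by a blank line
    return f"data: {json.dumps(chunk)}\n\n"

print(make_chunk("Hello!", "example-org/example-model"))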