johnbridges committed on
Commit 213e916 · 1 Parent(s): 11cacc3
Files changed (1)
  1. hf_backend.py +11 -11
hf_backend.py CHANGED
@@ -67,7 +67,7 @@ def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, t
             torch_dtype=dtype,
             trust_remote_code=True,
             device_map="auto" if device != "cpu" else {"": "cpu"},
-            low_cpu_mem_usage=False,  # ensure full load before casting
+            low_cpu_mem_usage=False,
         )
     except Exception as e:
         if device == "cpu" and dtype == torch.bfloat16:
@@ -84,8 +84,10 @@ def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, t
         else:
             raise
 
-    # --- Force recast to target dtype/device (fixes FP8 leftovers) ---
-    model = model.to(device=device, dtype=eff_dtype)
+    if device == "cpu":
+        model = model.to(device=device, dtype=eff_dtype)
+    else:
+        model = model.to(device=device)
 
     model.eval()
     _MODEL_CACHE[(device, eff_dtype)] = model
@@ -105,13 +107,11 @@ class HFChatBackend(ChatBackend):
         rid = f"chatcmpl-hf-{int(time.time())}"
         now = int(time.time())
 
-        # --- Inject X-IP-Token into global headers if ZeroGPU is used ---
         x_ip_token = request.get("x_ip_token")
         if x_ip_token and zero_client:
             zero_client.HEADERS["X-IP-Token"] = x_ip_token
             logger.debug("Injected X-IP-Token into ZeroGPU headers")
 
-        # Build prompt using chat template if available
         if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None):
             try:
                 prompt = tokenizer.apply_chat_template(
@@ -150,7 +150,11 @@ class HFChatBackend(ChatBackend):
                 use_cache=True,
             )
 
-            return tokenizer.decode(outputs[0], skip_special_tokens=True)
+            # Slice: keep only newly generated tokens
+            input_len = inputs["input_ids"].shape[-1]
+            generated_ids = outputs[0][input_len:]
+            text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
+            return text
 
         if spaces:
             @spaces.GPU(duration=120)
@@ -169,17 +173,13 @@ class HFChatBackend(ChatBackend):
                 "created": now,
                 "model": MODEL_ID,
                 "choices": [
-                    {"index": 0, "delta": {"content": text}, "finish_reason": "stop"}
+                    {"index": 0, "delta": {"role": "assistant", "content": text}, "finish_reason": "stop"}
                 ],
             }
 
 
 # ---------------- Stub Images Backend ----------------
 class StubImagesBackend(ImagesBackend):
-    """
-    Stub backend for images since HFChatBackend is text-only.
-    Returns a transparent 1x1 PNG placeholder.
-    """
     async def generate_b64(self, request: Dict[str, Any]) -> str:
         logger.warning("Image generation not supported in HF backend.")
         return (
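For context, a minimal self-contained sketch (not part of the commit) of the prompt-slicing pattern the new return path uses: model.generate returns the prompt tokens followed by the continuation, so slicing at the prompt length before decoding keeps only the newly generated text. The tiny model ID below is an assumption chosen so the snippet runs quickly; the real backend loads MODEL_ID with the options shown in the first hunk.

# Illustration only (assumption: "sshleifer/tiny-gpt2" as a small stand-in model)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model.eval()

inputs = tokenizer("Hello, how are", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=16)

# outputs[0] holds prompt tokens + generated tokens; drop the prompt before
# decoding so the returned string contains only the model's continuation.
input_len = inputs["input_ids"].shape[-1]
generated_ids = outputs[0][input_len:]
print(tokenizer.decode(generated_ids, skip_special_tokens=True).strip())

Decoding the full outputs[0], as the old return did, would echo the rendered chat-template prompt back to the client; the slice avoids that.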