johnbridges committed
Commit 1175344 · 1 Parent(s): b416f51
Files changed (1)
  1. hf_backend.py +24 -6
hf_backend.py CHANGED
@@ -1,4 +1,3 @@
-# hf_backend.py
 import time, logging, json, asyncio
 from contextlib import nullcontext
 from typing import Any, Dict, AsyncIterable, Tuple
@@ -33,16 +32,38 @@ except Exception as e:
     load_error = f"Failed to load tokenizer: {e}"
     logger.exception(load_error)
 
+
+def probe_bf16_runtime() -> bool:
+    """Check if BF16 is both reported and actually used in ops on CPU."""
+    if not (hasattr(torch, "cpu") and hasattr(torch.cpu, "is_bf16_supported")):
+        return False
+    if not torch.cpu.is_bf16_supported():
+        return False
+    try:
+        a = torch.randn(16, 16, dtype=torch.bfloat16)
+        b = torch.randn(16, 16, dtype=torch.bfloat16)
+        c = a @ b
+        return c.dtype == torch.bfloat16
+    except Exception:
+        return False
+
+
 def _pick_cpu_dtype() -> torch.dtype:
     try:
-        if hasattr(torch, "cpu") and hasattr(torch.cpu, "is_bf16_supported") and torch.cpu.is_bf16_supported():
-            logger.info("[dtype] CPU BF16 supported -> torch.bfloat16")
+        if probe_bf16_runtime():
+            logger.info("[dtype] Verified BF16 execution on CPU -> torch.bfloat16")
             return torch.bfloat16
     except Exception as e:
         logger.warning(f"[dtype] BF16 probe failed: {e}")
     logger.info("[dtype] fallback -> torch.float32")
     return torch.float32
 
+
+# Log CPU dtype capability at startup
+CPU_DTYPE = _pick_cpu_dtype()
+logger.info(f"[init] Default CPU dtype = {CPU_DTYPE}")
+
+
 _MODEL_CACHE: Dict[tuple[str, torch.dtype], AutoModelForCausalLM] = {}
 
 def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, torch.dtype]:
@@ -153,12 +174,10 @@ class HFChatBackend(ChatBackend):
             zero_client.HEADERS["X-IP-Token"] = x_ip_token
             logger.info("[req] injected X-IP-Token into ZeroGPU headers")
 
-        # Build prompt (pass tools to template)
         if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None):
            try:
                prompt = tokenizer.apply_chat_template(
                    messages,
-                   #tools=tools,
                    tokenize=False,
                    add_generation_prompt=True,
                )
@@ -212,7 +231,6 @@ class HFChatBackend(ChatBackend):
             logger.info(f"[gen] text len={len(text)}\n{_snippet(text, 1200)}")
             return text
 
-        # Offload heavy work to a worker thread so asyncio heartbeats continue
         if spaces:
             @spaces.GPU(duration=120)
             def run_once_sync(prompt: str) -> str:
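
For reference, the probing idea this commit introduces can be exercised outside the Space. Below is a minimal, self-contained sketch based on the added code; it assumes a PyTorch build that exposes torch.cpu.is_bf16_supported(), and names like pick_cpu_dtype and bf16_probe_demo are illustrative, not part of hf_backend.py:

# Standalone sketch of the BF16 runtime probe added in this commit.
# Assumes torch.cpu.is_bf16_supported() exists (recent PyTorch builds);
# naming outside the torch calls is illustrative.
import logging

import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("bf16_probe_demo")


def probe_bf16_runtime() -> bool:
    """Return True only if CPU BF16 is reported AND a small BF16 matmul stays in BF16."""
    if not (hasattr(torch, "cpu") and hasattr(torch.cpu, "is_bf16_supported")):
        return False
    if not torch.cpu.is_bf16_supported():
        return False
    try:
        a = torch.randn(16, 16, dtype=torch.bfloat16)
        b = torch.randn(16, 16, dtype=torch.bfloat16)
        return (a @ b).dtype == torch.bfloat16
    except Exception:
        return False


def pick_cpu_dtype() -> torch.dtype:
    """Prefer BF16 when the probe passes; otherwise fall back to FP32."""
    try:
        if probe_bf16_runtime():
            logger.info("[dtype] verified BF16 execution on CPU -> torch.bfloat16")
            return torch.bfloat16
    except Exception as e:
        logger.warning(f"[dtype] BF16 probe failed: {e}")
    logger.info("[dtype] fallback -> torch.float32")
    return torch.float32


if __name__ == "__main__":
    # On machines without native BF16 support this prints torch.float32.
    print(pick_cpu_dtype())

As the new docstring puts it, the extra matmul check confirms BF16 is "actually used in ops" rather than merely reported as supported, which is why the dtype of the result is inspected instead of trusting the capability flag alone.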