johnbridges committed on
Commit
1d79762
·
1 Parent(s): aa096cd
Files changed (1)
  1. hf_backend.py +27 -4
hf_backend.py CHANGED
@@ -3,7 +3,7 @@ import time, logging
 from typing import Any, Dict, AsyncIterable
 
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
 from backends_base import ChatBackend, ImagesBackend
 from config import settings
 
@@ -24,13 +24,27 @@ try:
     tokenizer = AutoTokenizer.from_pretrained(
         MODEL_ID,
         trust_remote_code=True,
-        use_fast=False
+        use_fast=False,
     )
 except Exception as e:
     load_error = f"Failed to load tokenizer: {e}"
     logger.exception(load_error)
 
 
+# ---------------- helpers ----------------
+def _pick_cpu_dtype() -> torch.dtype:
+    # Prefer BF16 if CPU supports it
+    if hasattr(torch, "cpu") and hasattr(torch.cpu, "is_bf16_supported"):
+        try:
+            if torch.cpu.is_bf16_supported():
+                logger.info("CPU BF16 supported, using torch.bfloat16")
+                return torch.bfloat16
+        except Exception:
+            pass
+    logger.info("Falling back to torch.float32 on CPU")
+    return torch.float32
+
+
 # ---------------- Chat Backend ----------------
 class HFChatBackend(ChatBackend):
     async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
@@ -52,13 +66,21 @@ class HFChatBackend(ChatBackend):
             logger.debug("Injected X-IP-Token into ZeroGPU headers")
 
         def _run_once(prompt: str, device: str, dtype: torch.dtype) -> str:
+            # Load config and strip any quantization settings (fix FP8 issue)
+            cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
+            if hasattr(cfg, "quantization_config"):
+                logger.warning("Removing quantization_config from model config")
+                cfg.quantization_config = None
+
             model = AutoModelForCausalLM.from_pretrained(
                 MODEL_ID,
+                config=cfg,
                 torch_dtype=dtype,
                 trust_remote_code=True,
                 device_map="auto" if device != "cpu" else {"": "cpu"},
             )
             model.eval()
+
             inputs = tokenizer(prompt, return_tensors="pt").to(device)
 
             with torch.inference_mode():
@@ -85,8 +107,9 @@ class HFChatBackend(ChatBackend):
 
             text = run_once(prompt)
         else:
-            # --- CPU-only fallback ---
-            text = _run_once(prompt, device="cpu", dtype=torch.float32)
+            # --- CPU-only fallback with auto dtype detection ---
+            dtype = _pick_cpu_dtype()
+            text = _run_once(prompt, device="cpu", dtype=dtype)
 
         yield {
             "id": rid,