johnbridges committed on
Commit 11cacc3 · 1 Parent(s): d76b941
Files changed (1)
  1. hf_backend.py +9 -10
hf_backend.py CHANGED

@@ -34,7 +34,6 @@ except Exception as e:
 
 # ---------------- helpers ----------------
 def _pick_cpu_dtype() -> torch.dtype:
-    # Prefer BF16 if CPU supports it
     if hasattr(torch, "cpu") and hasattr(torch.cpu, "is_bf16_supported"):
         try:
             if torch.cpu.is_bf16_supported():
@@ -51,16 +50,14 @@ _MODEL_CACHE: Dict[tuple[str, torch.dtype], AutoModelForCausalLM] = {}
 
 
 def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, torch.dtype]:
-    # Return model and the effective dtype actually loaded with
-    # (handles CPU BF16 -> FP32 fallback)
-    effective_key = (device, dtype)
-    if effective_key in _MODEL_CACHE:
-        return _MODEL_CACHE[effective_key], dtype
+    key = (device, dtype)
+    if key in _MODEL_CACHE:
+        return _MODEL_CACHE[key], dtype
 
     cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
     if hasattr(cfg, "quantization_config"):
         logger.warning("Removing quantization_config from model config")
-        delattr(cfg, "quantization_config")  # delete instead of setting None
+        delattr(cfg, "quantization_config")
 
     eff_dtype = dtype
     try:
@@ -70,6 +67,7 @@ def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, t
             torch_dtype=dtype,
             trust_remote_code=True,
             device_map="auto" if device != "cpu" else {"": "cpu"},
+            low_cpu_mem_usage=False,  # ensure full load before casting
         )
     except Exception as e:
         if device == "cpu" and dtype == torch.bfloat16:
@@ -81,10 +79,14 @@ def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, t
                 torch_dtype=eff_dtype,
                 trust_remote_code=True,
                 device_map={"": "cpu"},
+                low_cpu_mem_usage=False,
             )
         else:
             raise
 
+    # --- Force recast to target dtype/device (fixes FP8 leftovers) ---
+    model = model.to(device=device, dtype=eff_dtype)
+
     model.eval()
     _MODEL_CACHE[(device, eff_dtype)] = model
     return model, eff_dtype
@@ -151,17 +153,14 @@ class HFChatBackend(ChatBackend):
             return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
         if spaces:
-            # Always dispatch via ZeroGPU decorator if available.
             @spaces.GPU(duration=120)
             def run_once(prompt: str) -> str:
                 if torch.cuda.is_available():
                     return _run_once(prompt, device="cuda", req_dtype=torch.float16)
-                # Fallback to CPU inside the GPU context if CUDA is unavailable
                 return _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())
 
             text = run_once(prompt)
         else:
-            # CPU-only runtime
             text = _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())
 
         yield {
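For context beyond what the hunks show, below is a minimal, self-contained sketch of the load-then-recast pattern this commit settles on: cache models per (device, dtype), fall back from BF16 to FP32 when a CPU BF16 load fails, and finish with an explicit model.to(device, dtype) so no parameters are left in a stray dtype (e.g. FP8 remnants from a quantized checkpoint). It is a simplified reconstruction, not the actual hf_backend.py: the names MODEL_ID, _CACHE, pick_cpu_dtype and get_model are illustrative stand-ins, and passing config=cfg to from_pretrained is an assumption made here so that deleting quantization_config has an effect (the real call's first arguments are not visible in the diff).

# Sketch only; mirrors the commit's approach, not the file itself.
from typing import Dict, Tuple

import torch
from transformers import AutoConfig, AutoModelForCausalLM

MODEL_ID = "gpt2"  # placeholder checkpoint id, not the one used by the Space
_CACHE: Dict[Tuple[str, torch.dtype], AutoModelForCausalLM] = {}


def pick_cpu_dtype() -> torch.dtype:
    # Prefer BF16 when the CPU reports support; otherwise fall back to FP32.
    check = getattr(getattr(torch, "cpu", None), "is_bf16_supported", None)
    if callable(check):
        try:
            if check():
                return torch.bfloat16
        except Exception:
            pass
    return torch.float32


def get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, torch.dtype]:
    key = (device, dtype)
    if key in _CACHE:
        return _CACHE[key], dtype

    cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
    if hasattr(cfg, "quantization_config"):
        # Drop any baked-in quantization config so the checkpoint loads as plain weights.
        delattr(cfg, "quantization_config")

    eff_dtype = dtype
    try:
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            config=cfg,  # assumption: real code may pass the patched config like this
            torch_dtype=dtype,
            trust_remote_code=True,
            device_map="auto" if device != "cpu" else {"": "cpu"},
            low_cpu_mem_usage=False,  # load fully before casting
        )
    except Exception:
        if device == "cpu" and dtype == torch.bfloat16:
            eff_dtype = torch.float32  # BF16 load failed on CPU; retry in FP32
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                config=cfg,
                torch_dtype=eff_dtype,
                trust_remote_code=True,
                device_map={"": "cpu"},
                low_cpu_mem_usage=False,
            )
        else:
            raise

    # Explicit recast: any parameters still in an unexpected dtype (e.g. FP8
    # leftovers from a quantized checkpoint) are converted to eff_dtype.
    model = model.to(device=device, dtype=eff_dtype)
    model.eval()
    _CACHE[key] = model  # the real file keys the cache on eff_dtype instead
    return model, eff_dtype

The final model.to(...) is the substantive change: from_pretrained alone honors torch_dtype only for weights it can cast, so an explicit recast after loading guarantees a uniform dtype before inference.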