Commit 11cacc3 · Parent(s): d76b941

hf_backend.py  (+9 −10)  CHANGED
@@ -34,7 +34,6 @@ except Exception as e:
 
 # ---------------- helpers ----------------
 def _pick_cpu_dtype() -> torch.dtype:
-    # Prefer BF16 if CPU supports it
     if hasattr(torch, "cpu") and hasattr(torch.cpu, "is_bf16_supported"):
         try:
             if torch.cpu.is_bf16_supported():
@@ -51,16 +50,14 @@ _MODEL_CACHE: Dict[tuple[str, torch.dtype], AutoModelForCausalLM] = {}
 
 
 def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, torch.dtype]:
-
-
-
-    if effective_key in _MODEL_CACHE:
-        return _MODEL_CACHE[effective_key], dtype
+    key = (device, dtype)
+    if key in _MODEL_CACHE:
+        return _MODEL_CACHE[key], dtype
 
     cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
     if hasattr(cfg, "quantization_config"):
         logger.warning("Removing quantization_config from model config")
-        delattr(cfg, "quantization_config")
+        delattr(cfg, "quantization_config")
 
     eff_dtype = dtype
     try:
@@ -70,6 +67,7 @@ def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, torch.dtype]:
             torch_dtype=dtype,
             trust_remote_code=True,
             device_map="auto" if device != "cpu" else {"": "cpu"},
+            low_cpu_mem_usage=False,  # ensure full load before casting
         )
     except Exception as e:
         if device == "cpu" and dtype == torch.bfloat16:
@@ -81,10 +79,14 @@ def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, torch.dtype]:
                 torch_dtype=eff_dtype,
                 trust_remote_code=True,
                 device_map={"": "cpu"},
+                low_cpu_mem_usage=False,
             )
         else:
             raise
 
+    # --- Force recast to target dtype/device (fixes FP8 leftovers) ---
+    model = model.to(device=device, dtype=eff_dtype)
+
     model.eval()
     _MODEL_CACHE[(device, eff_dtype)] = model
     return model, eff_dtype
@@ -151,17 +153,14 @@ class HFChatBackend(ChatBackend):
             return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
         if spaces:
-            # Always dispatch via ZeroGPU decorator if available.
             @spaces.GPU(duration=120)
             def run_once(prompt: str) -> str:
                 if torch.cuda.is_available():
                     return _run_once(prompt, device="cuda", req_dtype=torch.float16)
-                # Fallback to CPU inside the GPU context if CUDA is unavailable
                 return _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())
 
             text = run_once(prompt)
         else:
-            # CPU-only runtime
             text = _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())
 
         yield {
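For context, the substantive change in _get_model is that the cache is now keyed by the requested (device, dtype) pair, the model is loaded with low_cpu_mem_usage=False, and everything is explicitly recast with model.to(...) before caching, so parameters left in a quantized (e.g. FP8) dtype do not survive. Below is a minimal self-contained sketch of that pattern, not the backend's actual code: MODEL_ID is a placeholder, passing the patched config via config=cfg is an assumption (the visible hunks do not show that line), and the BF16-to-FP32 CPU fallback from the real function is omitted for brevity.

from typing import Dict, Tuple

import torch
from transformers import AutoConfig, AutoModelForCausalLM

MODEL_ID = "org/model"  # placeholder; the real id lives elsewhere in hf_backend.py

_MODEL_CACHE: Dict[Tuple[str, torch.dtype], AutoModelForCausalLM] = {}


def get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, torch.dtype]:
    key = (device, dtype)                        # cache key now includes the requested dtype
    if key in _MODEL_CACHE:
        return _MODEL_CACHE[key], dtype

    cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
    if hasattr(cfg, "quantization_config"):      # strip any baked-in (e.g. FP8) quantization
        delattr(cfg, "quantization_config")

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        config=cfg,                              # assumed: the patched config is passed through
        torch_dtype=dtype,
        trust_remote_code=True,
        device_map="auto" if device != "cpu" else {"": "cpu"},  # "auto" needs accelerate
        low_cpu_mem_usage=False,                 # load fully before casting
    )

    # Force-recast everything to the target device/dtype so no quantized leftovers survive.
    model = model.to(device=device, dtype=dtype)
    model.eval()
    _MODEL_CACHE[key] = model
    return model, dtype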
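The last hunk only drops comments, but the dispatch path it sits in is worth spelling out: when the spaces package imports, generation runs inside a ZeroGPU-decorated closure and falls back to CPU if CUDA is still unavailable; without spaces it runs directly on CPU with whatever dtype _pick_cpu_dtype selects. A rough standalone sketch of that flow, with _run_once reduced to a stub standing in for the backend's real generation helper:

import torch

try:
    import spaces                      # present on Hugging Face ZeroGPU Spaces
except Exception:
    spaces = None


def _pick_cpu_dtype() -> torch.dtype:
    # Prefer BF16 when the CPU build exposes and confirms support, else FP32.
    if hasattr(torch, "cpu") and hasattr(torch.cpu, "is_bf16_supported"):
        try:
            if torch.cpu.is_bf16_supported():
                return torch.bfloat16
        except Exception:
            pass
    return torch.float32


def _run_once(prompt: str, device: str, req_dtype: torch.dtype) -> str:
    # Stub: stands in for the real generation helper in hf_backend.py.
    return f"[{device}/{req_dtype}] {prompt}"


def generate(prompt: str) -> str:
    if spaces:
        @spaces.GPU(duration=120)      # request a ZeroGPU slot for up to 120 s
        def run_once(p: str) -> str:
            if torch.cuda.is_available():
                return _run_once(p, device="cuda", req_dtype=torch.float16)
            # CUDA can still be unavailable inside the decorated call; fall back to CPU.
            return _run_once(p, device="cpu", req_dtype=_pick_cpu_dtype())
        return run_once(prompt)
    return _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())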