johnbridges committed
Commit d279e64 · 1 Parent(s): 849364d
Files changed (1)
  1. hf_backend.py +40 -12
hf_backend.py CHANGED
@@ -33,11 +33,10 @@ except Exception as e:
 
 # ---------------- helpers ----------------
 def _pick_cpu_dtype() -> torch.dtype:
-    # Prefer BF16 if CPU supports it
     if hasattr(torch, "cpu") and hasattr(torch.cpu, "is_bf16_supported"):
         try:
             if torch.cpu.is_bf16_supported():
-                logger.info("CPU BF16 supported, using torch.bfloat16")
+                logger.info("CPU BF16 supported, will attempt torch.bfloat16")
                 return torch.bfloat16
         except Exception:
             pass
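For context, a minimal sketch of how a dtype picker in this style typically bottoms out; the FP32 fallback at the end is not shown in this hunk and is an assumption here:

```python
import torch

def pick_cpu_dtype_sketch() -> torch.dtype:
    # Hypothetical stand-in for _pick_cpu_dtype: prefer BF16 when the CPU
    # probe exists and reports support, otherwise fall back to FP32.
    if hasattr(torch, "cpu") and hasattr(torch.cpu, "is_bf16_supported"):
        try:
            if torch.cpu.is_bf16_supported():
                return torch.bfloat16  # BF16 kernels available on this CPU
        except Exception:
            pass  # older torch builds may not expose or support this probe
    return torch.float32  # assumed safe default when BF16 is unavailable
```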
@@ -57,17 +56,32 @@ def _get_model(device: str, dtype: torch.dtype):
     cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
     if hasattr(cfg, "quantization_config"):
         logger.warning("Removing quantization_config from model config")
-        delattr(cfg, "quantization_config")  # delete instead of setting None
+        delattr(cfg, "quantization_config")
+
+    try:
+        model = AutoModelForCausalLM.from_pretrained(
+            MODEL_ID,
+            config=cfg,
+            torch_dtype=dtype,
+            trust_remote_code=True,
+            device_map="auto" if device != "cpu" else {"": "cpu"},
+        )
+    except Exception as e:
+        if device == "cpu" and dtype == torch.bfloat16:
+            logger.warning(f"BF16 load failed on CPU: {e}. Retrying with FP32.")
+            model = AutoModelForCausalLM.from_pretrained(
+                MODEL_ID,
+                config=cfg,
+                torch_dtype=torch.float32,
+                trust_remote_code=True,
+                device_map={"": "cpu"},
+            )
+            dtype = torch.float32
+        else:
+            raise
 
-    model = AutoModelForCausalLM.from_pretrained(
-        MODEL_ID,
-        config=cfg,
-        torch_dtype=dtype,
-        trust_remote_code=True,
-        device_map="auto" if device != "cpu" else {"": "cpu"},
-    )
     model.eval()
-    _MODEL_CACHE[key] = model
+    _MODEL_CACHE[(device, dtype)] = model
     return model
 
 
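A hedged usage sketch of the new fallback path, assuming `_get_model` and `_pick_cpu_dtype` as defined in this file (hf_backend.py) and `MODEL_ID` at module scope; the inspection line is illustrative only:

```python
import torch
from hf_backend import _get_model, _pick_cpu_dtype  # assumes this module is importable

# Ask for the preferred CPU dtype; if loading BF16 weights fails on this CPU,
# _get_model retries in FP32 and caches the model under the dtype that
# actually loaded, so the cache key reflects the effective precision.
dtype = _pick_cpu_dtype()                # torch.bfloat16 or torch.float32
model = _get_model("cpu", dtype)         # may transparently fall back to FP32
print(next(model.parameters()).dtype)    # shows the effective parameter dtype
```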
 
@@ -78,7 +92,6 @@ class HFChatBackend(ChatBackend):
             raise RuntimeError(load_error)
 
         messages = request.get("messages", [])
-        prompt = messages[-1]["content"] if messages else "(empty)"
         temperature = float(request.get("temperature", settings.LlmTemp or 0.7))
         max_tokens = int(request.get("max_tokens", settings.LlmOpenAICtxSize or 512))
 
@@ -91,6 +104,21 @@ class HFChatBackend(ChatBackend):
             zero_client.HEADERS["X-IP-Token"] = x_ip_token
             logger.debug("Injected X-IP-Token into ZeroGPU headers")
 
+        # Build prompt using chat template if available
+        if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
+            try:
+                prompt = tokenizer.apply_chat_template(
+                    messages,
+                    tokenize=False,
+                    add_generation_prompt=True,
+                )
+                logger.debug("Applied chat template for prompt")
+            except Exception as e:
+                logger.warning(f"Failed to apply chat template: {e}, using fallback")
+                prompt = messages[-1]["content"] if messages else "(empty)"
+        else:
+            prompt = messages[-1]["content"] if messages else "(empty)"
+
         def _run_once(prompt: str, device: str, dtype: torch.dtype) -> str:
             model = _get_model(device, dtype)
             inputs = tokenizer(prompt, return_tensors="pt").to(device)
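For reference, a standalone sketch of the chat-template path added above, using OpenAI-style messages; the tokenizer id below is a placeholder, not this backend's MODEL_ID:

```python
from transformers import AutoTokenizer

# Placeholder checkpoint for illustration; the backend loads MODEL_ID instead.
tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", trust_remote_code=True)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Say hello."},
]

if getattr(tok, "chat_template", None):
    # Same call as in the diff: render the conversation into one prompt string
    # and append the assistant generation header.
    prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
else:
    # Fallback mirrors the diff's else branch: last user message only.
    prompt = messages[-1]["content"]

print(prompt)
```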
 