johnbridges committed
Commit 2ad6a17 · 1 Parent(s): ba9c53a

added cpu support to hf_backend

Files changed (1)
  1. hf_backend.py +48 -39
hf_backend.py CHANGED
@@ -15,21 +15,19 @@ try:
 except ImportError:
     spaces, zero_client = None, None
 
-# --- Model setup (CPU-safe load, real inference on GPU only) ---
+# --- Model setup ---
 MODEL_ID = settings.LlmHFModelID or "Qwen/Qwen2.5-1.5B-Instruct"
-logger.info(f"Preloading tokenizer for {MODEL_ID} on CPU (ZeroGPU safe)...")
+logger.info(f"Preloading tokenizer for {MODEL_ID} on CPU...")
 
-tokenizer, model, load_error = None, None, None
+tokenizer, load_error = None, None
 try:
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=False)
-    model = AutoModelForCausalLM.from_pretrained(
+    tokenizer = AutoTokenizer.from_pretrained(
         MODEL_ID,
-        torch_dtype=torch.float32,  # dummy dtype for CPU preload
         trust_remote_code=True,
+        use_fast=False
     )
-    model.eval()
 except Exception as e:
-    load_error = f"Failed to load model/tokenizer: {e}"
+    load_error = f"Failed to load tokenizer: {e}"
     logger.exception(load_error)
 
 
@@ -47,47 +45,58 @@ class HFChatBackend(ChatBackend):
         rid = f"chatcmpl-hf-{int(time.time())}"
         now = int(time.time())
 
-        if not spaces:
-            raise RuntimeError("ZeroGPU (spaces) is required but not available!")
-
-        # --- Inject X-IP-Token into global headers ---
+        # --- Inject X-IP-Token into global headers if ZeroGPU is used ---
         x_ip_token = request.get("x_ip_token")
         if x_ip_token and zero_client:
             zero_client.HEADERS["X-IP-Token"] = x_ip_token
             logger.debug("Injected X-IP-Token into ZeroGPU headers")
 
-        # --- Define the GPU-only inference function ---
-        @spaces.GPU(duration=120)
-        def run_once(prompt: str) -> str:
-            device = "cuda"  # force CUDA
-            dtype = torch.float16
-
-            model.to(device=device, dtype=dtype).eval()
+        def _run_once(prompt: str, device: str, dtype: torch.dtype) -> str:
+            model = AutoModelForCausalLM.from_pretrained(
+                MODEL_ID,
+                torch_dtype=dtype,
+                trust_remote_code=True,
+                device_map="auto" if device != "cpu" else {"": "cpu"},
+            )
+            model.eval()
             inputs = tokenizer(prompt, return_tensors="pt").to(device)
 
-            with torch.inference_mode(), torch.autocast(device_type=device, dtype=dtype):
-                outputs = model.generate(
-                    **inputs,
-                    max_new_tokens=max_tokens,
-                    temperature=temperature,
-                    do_sample=True,
-                )
+            with torch.inference_mode():
+                if device != "cpu":
+                    autocast_ctx = torch.autocast(device_type=device, dtype=dtype)
+                else:
+                    autocast_ctx = torch.cpu.amp.autocast(dtype=dtype)
+
+                with autocast_ctx:
+                    outputs = model.generate(
+                        **inputs,
+                        max_new_tokens=max_tokens,
+                        temperature=temperature,
+                        do_sample=True,
+                    )
+
             return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-        try:
+        if spaces:
+            # --- GPU path with ZeroGPU ---
+            @spaces.GPU(duration=120)
+            def run_once(prompt: str) -> str:
+                return _run_once(prompt, device="cuda", dtype=torch.float16)
+
             text = run_once(prompt)
-            yield {
-                "id": rid,
-                "object": "chat.completion.chunk",
-                "created": now,
-                "model": MODEL_ID,
-                "choices": [
-                    {"index": 0, "delta": {"content": text}, "finish_reason": "stop"}
-                ],
-            }
-        except Exception:
-            logger.exception("HF inference failed")
-            raise
+        else:
+            # --- CPU-only fallback ---
+            text = _run_once(prompt, device="cpu", dtype=torch.float32)
+
+        yield {
+            "id": rid,
+            "object": "chat.completion.chunk",
+            "created": now,
+            "model": MODEL_ID,
+            "choices": [
+                {"index": 0, "delta": {"content": text}, "finish_reason": "stop"}
+            ],
+        }
 
 
 # ---------------- Stub Images Backend ----------------
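
A minimal, standalone sketch of the CPU path this commit enables, useful for verifying generation locally outside the backend class. It reuses the same default MODEL_ID as hf_backend.py; the file name, prompt, max_new_tokens=64, temperature=0.7, and the use of contextlib.nullcontext() in place of torch.cpu.amp.autocast (which is effectively disabled for float32 on CPU) are illustrative assumptions, not part of this commit.

# cpu_smoke_test.py -- hypothetical helper, not part of this commit.
# Mirrors the _run_once() flow when ZeroGPU (spaces) is unavailable:
# load the model on CPU in float32, generate once, decode.
from contextlib import nullcontext

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"  # same default as hf_backend.py

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,   # full precision on CPU
    trust_remote_code=True,
    device_map={"": "cpu"},      # pin every module to CPU, as in _run_once
)
model.eval()

inputs = tokenizer("Say hello in one sentence.", return_tensors="pt").to("cpu")

# Autocast with float32 on CPU is effectively a no-op, so a nullcontext is the
# quieter equivalent of the torch.cpu.amp.autocast call used in the commit.
with torch.inference_mode(), nullcontext():
    outputs = model.generate(
        **inputs,
        max_new_tokens=64,       # illustrative token budget
        temperature=0.7,         # illustrative sampling temperature
        do_sample=True,
    )

print(tokenizer.decode(outputs[0], skip_special_tokens=True))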