johnbridges committed on
Commit 60a9595 · 1 Parent(s): 5b32c71
Files changed (1)
  1. hf_backend.py +42 -63
hf_backend.py CHANGED
@@ -1,5 +1,5 @@
-# hf_backend.py (patched)
-import time, logging, os
+# hf_backend.py
+import time, logging
 from typing import Any, Dict, AsyncIterable
 
 import torch
@@ -7,23 +7,24 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 from backends_base import ChatBackend, ImagesBackend
 from config import settings
 
+logger = logging.getLogger(__name__)
+
 try:
     import spaces
-    from spaces.zero.client import SpaceZeroClient
+    from spaces.zero import client as zero_client
 except ImportError:
-    spaces, SpaceZeroClient = None, None
-
-logger = logging.getLogger(__name__)
+    spaces, zero_client = None, None
 
+# --- Model setup (CPU-safe load, real inference on GPU only) ---
 MODEL_ID = settings.LlmHFModelID or "Qwen/Qwen2.5-1.5B-Instruct"
-logger.info(f"Loading {MODEL_ID} on CPU at startup (ZeroGPU safe)...")
+logger.info(f"Preloading tokenizer for {MODEL_ID} on CPU (ZeroGPU safe)...")
 
 tokenizer, model, load_error = None, None, None
 try:
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=False)
     model = AutoModelForCausalLM.from_pretrained(
         MODEL_ID,
-        torch_dtype=torch.float32,
+        torch_dtype=torch.float32,  # dummy dtype for CPU preload
         trust_remote_code=True,
     )
     model.eval()
@@ -32,22 +33,7 @@ except Exception as e:
     logger.exception(load_error)
 
 
-def pick_device() -> str:
-    if torch.cuda.is_available():
-        return "cuda"
-    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
-        return "mps"
-    return "cpu"
-
-def pick_dtype(device: str) -> torch.dtype:
-    if device == "cuda":
-        major, _ = torch.cuda.get_device_capability()
-        return torch.bfloat16 if major >= 8 else torch.float16
-    if device == "mps":
-        return torch.float16
-    return torch.float32
-
-
+# ---------------- Chat Backend ----------------
 class HFChatBackend(ChatBackend):
     async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
         if load_error:
@@ -61,19 +47,24 @@ class HFChatBackend(ChatBackend):
         rid = f"chatcmpl-hf-{int(time.time())}"
         now = int(time.time())
 
-        # --- ✅ Extract X-IP-Token from RabbitMQ message
+        if not spaces:
+            raise RuntimeError("ZeroGPU (spaces) is required but not available!")
+
+        # --- Inject X-IP-Token into global headers ---
         x_ip_token = request.get("x_ip_token")
-        headers = {}
-        if x_ip_token:
-            headers["X-IP-Token"] = x_ip_token
-            logger.info("Using X-IP-Token from request for ZeroGPU attribution")
-
-        def _gpu_inference_fn(prompt: str) -> str:
-            device = pick_device()
-            dtype = pick_dtype(device)
-            model.to(device=device, dtype=dtype).eval()
+        if x_ip_token and zero_client:
+            zero_client.HEADERS["X-IP-Token"] = x_ip_token
+            logger.debug("Injected X-IP-Token into ZeroGPU headers")
 
+        # --- Define the GPU-only inference function ---
+        @spaces.GPU(duration=120)
+        def run_once(prompt: str) -> str:
+            device = "cuda"  # force CUDA
+            dtype = torch.float16
+
+            model.to(device=device, dtype=dtype).eval()
             inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
             with torch.inference_mode(), torch.autocast(device_type=device, dtype=dtype):
                 outputs = model.generate(
                     **inputs,
@@ -83,35 +74,23 @@ class HFChatBackend(ChatBackend):
                 )
             return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-        if spaces and SpaceZeroClient:
-            # Use a custom SpaceZeroClient with headers
-            client = SpaceZeroClient(headers=headers or None)
-            try:
-                text = await client.run(_gpu_inference_fn, args=[prompt], duration=120)
-            except Exception:
-                logger.exception("HF inference (ZeroGPU) failed")
-                raise
-        else:
-            # CPU fallback
-            inputs = tokenizer(prompt, return_tensors="pt")
-            with torch.inference_mode():
-                outputs = model.generate(
-                    **inputs,
-                    max_new_tokens=max_tokens,
-                    temperature=temperature,
-                    do_sample=True,
-                )
-            text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-        yield {
-            "id": rid,
-            "object": "chat.completion.chunk",
-            "created": now,
-            "model": MODEL_ID,
-            "choices": [
-                {"index": 0, "delta": {"content": text}, "finish_reason": "stop"}
-            ],
-        }
+        try:
+            text = run_once(prompt)
+            yield {
+                "id": rid,
+                "object": "chat.completion.chunk",
+                "created": now,
+                "model": MODEL_ID,
+                "choices": [
+                    {"index": 0, "delta": {"content": text}, "finish_reason": "stop"}
+                ],
+            }
+        except Exception:
+            logger.exception("HF inference failed")
+            raise
+
+
+# ---------------- Stub Images Backend ----------------
 class StubImagesBackend(ImagesBackend):
     """
     Stub backend for images since HFChatBackend is text-only.
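For orientation, here is a minimal sketch of how a caller might drive the patched backend. The request schema is an assumption: this diff only shows request.get("x_ip_token"), while prompt, max_tokens, and temperature are extracted in lines the hunks elide, so the field names below are illustrative.

# Hypothetical consumer sketch; the request field names are assumptions,
# not part of this commit.
import asyncio

from hf_backend import HFChatBackend

async def main() -> None:
    backend = HFChatBackend()
    request = {
        "prompt": "Explain ZeroGPU in one sentence.",  # assumed field name
        "max_tokens": 64,                              # assumed field name
        "temperature": 0.7,                            # assumed field name
        "x_ip_token": "token-from-upstream",           # read via request.get("x_ip_token")
    }
    # stream() yields OpenAI-style chat.completion.chunk dicts; this backend emits one.
    async for chunk in backend.stream(request):
        print(chunk["choices"][0]["delta"]["content"])

asyncio.run(main())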
 
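The pattern the commit settles on (preload on CPU in float32 at import time, then move and cast the model inside a spaces.GPU-decorated function, since ZeroGPU Spaces only attach a GPU while such a function runs) reduces to the standalone sketch below; the model ID and generation settings are illustrative.

# Minimal ZeroGPU sketch: CPU-safe preload at import time, GPU work only inside
# the decorated function. Settings are illustrative, not taken from this commit.
import torch
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"

# Import time: no GPU is attached yet on ZeroGPU, so load on CPU in float32.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float32).eval()

@spaces.GPU(duration=120)  # a GPU is attached only for the lifetime of this call
def generate(prompt: str) -> str:
    model.to(device="cuda", dtype=torch.float16)  # move/cast inside the GPU window
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    with torch.inference_mode():
        outputs = model.generate(**inputs, max_new_tokens=64)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

One side note: because run_once already casts the weights to float16, the extra torch.autocast context in the diff is arguably redundant; autocast matters most when the weights stay in float32.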