johnbridges committed on
Commit
2143c4b
·
1 Parent(s): ac7c5a8
Files changed (2) hide show
  1. app.py +28 -13
  2. requirements.txt +3 -2
app.py CHANGED
@@ -1,5 +1,5 @@
1
  # app.py
2
- import asyncio, logging, os
3
  import gradio as gr
4
 
5
  from config import settings
@@ -7,7 +7,7 @@ from rabbit_base import RabbitBase
7
  from listener import RabbitListenerBase
8
  from rabbit_repo import RabbitRepo
9
  from oa_server import OpenAIServers
10
- from vllm_backend import VLLMChatBackend, StubImagesBackend # ✅ our backend
11
  import state # holds vllm_engine reference
12
 
13
  # ---- vLLM imports ----
@@ -20,31 +20,46 @@ logging.basicConfig(
20
  )
21
  log = logging.getLogger("app")
22
 
 
23
  try:
24
  import spaces
 
25
  @spaces.GPU(duration=60)
26
- def gpu_entrypoint() -> str:
27
  return "gpu: ready"
 
 
 
 
 
 
 
 
 
 
28
  except Exception:
29
- def gpu_entrypoint() -> str:
30
  return "gpu: not available (CPU only)"
31
 
 
 
 
 
 
 
 
 
32
  # ----------------- vLLM init -----------------
33
  async def init_vllm():
34
  if state.vllm_engine is not None:
35
  return state.vllm_engine
36
 
37
  model_id = getattr(settings, "LlmHFModelID", "Qwen/Qwen2.5-7B-Instruct")
 
38
  log.info(f"Loading vLLM model: {model_id}")
39
 
40
- # Always use GPU (cuda) — Spaces provides GPU when @spaces.GPU is active
41
- args = AsyncEngineArgs(
42
- model=model_id,
43
- trust_remote_code=True,
44
- max_model_len=getattr(settings, "LlmOpenAICtxSize", 32768),
45
- device="cuda", # ✅ force GPU
46
- )
47
- state.vllm_engine = AsyncLLMEngine.from_engine_args(args)
48
  return state.vllm_engine
49
 
50
  # ----------------- RabbitMQ wiring -----------------
@@ -83,7 +98,7 @@ async def _startup_init():
83
  log.exception("Startup init failed")
84
  return f"ERROR: {e}"
85
 
86
- async def ping():
87
  return "ok"
88
 
89
  # ----------------- Gradio UI -----------------
 
1
  # app.py
2
+ import asyncio, logging
3
  import gradio as gr
4
 
5
  from config import settings
 
7
  from listener import RabbitListenerBase
8
  from rabbit_repo import RabbitRepo
9
  from oa_server import OpenAIServers
10
+ from vllm_backend import VLLMChatBackend, StubImagesBackend
11
  import state # holds vllm_engine reference
12
 
13
  # ---- vLLM imports ----
 
20
  )
21
  log = logging.getLogger("app")
22
 
23
+ # ----------------- Hugging Face Spaces helpers -----------------
24
  try:
25
  import spaces
26
+
27
  @spaces.GPU(duration=60)
28
+ def gpu_entrypoint() -> str:
29
  return "gpu: ready"
30
+
31
+ @spaces.GPU(duration=600)
32
+ def _build_vllm_engine_on_gpu(model_id: str, max_len: int):
33
+ args = AsyncEngineArgs(
34
+ model=model_id,
35
+ trust_remote_code=True,
36
+ max_model_len=max_len,
37
+ )
38
+ return AsyncLLMEngine.from_engine_args(args)
39
+
40
  except Exception:
41
+ def gpu_entrypoint() -> str:
42
  return "gpu: not available (CPU only)"
43
 
44
+ def _build_vllm_engine_on_gpu(model_id: str, max_len: int):
45
+ args = AsyncEngineArgs(
46
+ model=model_id,
47
+ trust_remote_code=True,
48
+ max_model_len=max_len,
49
+ )
50
+ return AsyncLLMEngine.from_engine_args(args)
51
+
52
  # ----------------- vLLM init -----------------
53
  async def init_vllm():
54
  if state.vllm_engine is not None:
55
  return state.vllm_engine
56
 
57
  model_id = getattr(settings, "LlmHFModelID", "Qwen/Qwen2.5-7B-Instruct")
58
+ max_len = int(getattr(settings, "LlmOpenAICtxSize", 32768))
59
  log.info(f"Loading vLLM model: {model_id}")
60
 
61
+ # Build inside a GPU context so Spaces ZeroGPU exposes CUDA
62
+ state.vllm_engine = _build_vllm_engine_on_gpu(model_id, max_len)
 
 
 
 
 
 
63
  return state.vllm_engine
64
 
65
  # ----------------- RabbitMQ wiring -----------------
 
98
  log.exception("Startup init failed")
99
  return f"ERROR: {e}"
100
 
101
+ async def ping():
102
  return "ok"
103
 
104
  # ----------------- Gradio UI -----------------
requirements.txt CHANGED
@@ -3,8 +3,9 @@ fastapi>=0.116.1
3
  uvicorn>=0.35.0
4
  aio-pika>=9.5.7
5
 
6
- pydantic==2.11.1
7
- pydantic-settings==2.10.1
 
8
  spaces
9
 
10
  vllm>=0.10.0
 
3
  uvicorn>=0.35.0
4
  aio-pika>=9.5.7
5
 
6
+ pydantic>=2.17.0
7
+ pydantic-settings>=2.6.0
8
+
9
  spaces
10
 
11
  vllm>=0.10.0