johnbridges committed on
Commit 76f1775 · 1 Parent(s): 36d163c
Files changed (2):
  1. app.py +0 -36
  2. vllm_backend.py +49 -30
app.py CHANGED
@@ -8,11 +8,6 @@ from listener import RabbitListenerBase
 from rabbit_repo import RabbitRepo
 from oa_server import OpenAIServers
 from vllm_backend import VLLMChatBackend, StubImagesBackend
-import state  # holds vllm_engine reference
-
-# ---- vLLM imports ----
-from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.engine.arg_utils import AsyncEngineArgs
 
 logging.basicConfig(
     level=logging.INFO,
@@ -28,40 +23,10 @@ try:
     def gpu_entrypoint() -> str:
         return "gpu: ready"
 
-    @spaces.GPU(duration=60)
-    def _build_vllm_engine_on_gpu(model_id: str, max_len: int):
-        args = AsyncEngineArgs(
-            model=model_id,
-            trust_remote_code=True,
-            max_model_len=max_len,
-        )
-        return AsyncLLMEngine.from_engine_args(args)
-
 except Exception:
     def gpu_entrypoint() -> str:
         return "gpu: not available (CPU only)"
 
-    def _build_vllm_engine_on_gpu(model_id: str, max_len: int):
-        args = AsyncEngineArgs(
-            model=model_id,
-            trust_remote_code=True,
-            max_model_len=max_len,
-        )
-        return AsyncLLMEngine.from_engine_args(args)
-
-# ----------------- vLLM init -----------------
-async def init_vllm():
-    if state.vllm_engine is not None:
-        return state.vllm_engine
-
-    model_id = getattr(settings, "LlmHFModelID", "Qwen/Qwen2.5-7B-Instruct")
-    max_len = int(getattr(settings, "LlmOpenAICtxSize", 32768))
-    log.info(f"Loading vLLM model: {model_id}")
-
-    # Build inside a GPU context so Spaces ZeroGPU exposes CUDA
-    state.vllm_engine = _build_vllm_engine_on_gpu(model_id, max_len)
-    return state.vllm_engine
-
 # ----------------- RabbitMQ wiring -----------------
 publisher = RabbitRepo(external_source="openai.mq.server")
 resolver = (lambda name: "direct" if name.startswith("oa.") else settings.RABBIT_EXCHANGE_TYPE)
@@ -90,7 +55,6 @@ listener = RabbitListenerBase(base, instance_name=settings.RABBIT_INSTANCE_NAME,
 # ----------------- Startup init -----------------
 async def _startup_init():
     try:
-        await init_vllm()  # load vLLM model
         await base.connect()  # connect to RabbitMQ
         await listener.start(DECLS)  # start queue listeners
         return "OpenAI MQ + vLLM: ready"
vllm_backend.py CHANGED
@@ -4,57 +4,75 @@ from typing import Any, Dict, AsyncIterable
 
 from vllm.sampling_params import SamplingParams
 from backends_base import ChatBackend, ImagesBackend
-from state import vllm_engine  # ✅ the single source of truth
 
 logger = logging.getLogger(__name__)
 
+try:
+    import spaces
+except ImportError:
+    spaces = None
+
+
 class VLLMChatBackend(ChatBackend):
     """
-    Streams completions from a local vLLM engine.
-    Produces OpenAI-compatible ChatCompletionChunk dicts.
+    On ZeroGPU: build vLLM engine per request (no persistent state).
+    Returns a single ChatCompletionChunk with the full text.
     """
-    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
-        if vllm_engine is None:
-            raise RuntimeError("vLLM engine not initialized")
 
-        # For now: just grab the last user message
+    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
         messages = request.get("messages", [])
         prompt = messages[-1]["content"] if messages else "(empty)"
 
         params = SamplingParams(
             temperature=float(request.get("temperature", 0.7)),
             max_tokens=int(request.get("max_tokens", 512)),
-            stream=True,
+            stream=False,  # we want full text only
        )
 
        rid = f"chatcmpl-local-{int(time.time())}"
        now = int(time.time())
        model_name = request.get("model", "local-vllm")
 
+        # GPU wrapper for ZeroGPU
+        if spaces:
+            @spaces.GPU(duration=60)
+            def run_once(prompt: str) -> str:
+                from vllm.engine.async_llm_engine import AsyncLLMEngine
+                from vllm.engine.arg_utils import AsyncEngineArgs
+
+                args = AsyncEngineArgs(model=model_name, trust_remote_code=True)
+                engine = AsyncLLMEngine.from_engine_args(args)
+
+                # synchronous generate
+                outputs = list(engine.generate(prompt, params, request_id=rid))
+                return outputs[-1].outputs[0].text if outputs else ""
+
+        else:
+            def run_once(prompt: str) -> str:
+                from vllm.engine.async_llm_engine import AsyncLLMEngine
+                from vllm.engine.arg_utils import AsyncEngineArgs
+
+                args = AsyncEngineArgs(model=model_name, trust_remote_code=True)
+                engine = AsyncLLMEngine.from_engine_args(args)
+
+                outputs = list(engine.generate(prompt, params, request_id=rid))
+                return outputs[-1].outputs[0].text if outputs else ""
+
        try:
-            async for output in vllm_engine.generate(prompt, params, request_id=rid):
-                text_piece = output.outputs[0].text
-                yield {
-                    "id": rid,
-                    "object": "chat.completion.chunk",
-                    "created": now,
-                    "model": model_name,
-                    "choices": [
-                        {"index": 0, "delta": {"content": text_piece}, "finish_reason": None}
-                    ],
-                }
+            text = run_once(prompt)
+            yield {
+                "id": rid,
+                "object": "chat.completion.chunk",
+                "created": now,
+                "model": model_name,
+                "choices": [
+                    {"index": 0, "delta": {"content": text}, "finish_reason": "stop"}
+                ],
+            }
        except Exception:
-            logger.exception("vLLM generation failed")
+            logger.exception("vLLM inference failed")
            raise
 
-        # Final stop signal
-        yield {
-            "id": rid,
-            "object": "chat.completion.chunk",
-            "created": now,
-            "model": model_name,
-            "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
-        }
 
 class StubImagesBackend(ImagesBackend):
     """
@@ -63,5 +81,6 @@ class StubImagesBackend(ImagesBackend):
     """
     async def generate_b64(self, request: Dict[str, Any]) -> str:
         logger.warning("Image generation not supported in local vLLM backend.")
-        # 1x1 transparent PNG
-        return "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
+        return (
+            "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
+        )
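A caveat on the new run_once helpers: as the removed async-for loop shows, AsyncLLMEngine.generate() produces results asynchronously, so wrapping it in list(...) inside a plain synchronous function will not drain it the way the "# synchronous generate" comment suggests. One genuinely synchronous per-call alternative is vLLM's offline LLM API; the sketch below is not the commit's code, the model id is a placeholder, and on ZeroGPU the function would still sit under the @spaces.GPU wrapper shown in the diff.

# Hedged sketch: build a synchronous engine for a single call and return the
# full completion text.
from vllm import LLM, SamplingParams

def run_once_sync(prompt: str, model_id: str = "placeholder/model-id") -> str:
    params = SamplingParams(temperature=0.7, max_tokens=512)
    llm = LLM(model=model_id, trust_remote_code=True)   # loads the weights
    outputs = llm.generate([prompt], params)            # list of RequestOutput
    return outputs[0].outputs[0].text if outputs else ""

For callers the shape is unchanged: stream() is still an async generator of OpenAI-style chat.completion.chunk dicts, now yielding one chunk whose finish_reason is "stop". A hypothetical consumer collects the text like this:

# Hedged usage sketch; the request payload only mirrors the fields the backend reads.
import asyncio

async def collect_text(backend, request):
    parts = []
    async for chunk in backend.stream(request):
        delta = chunk["choices"][0]["delta"]
        if "content" in delta:
            parts.append(delta["content"])
    return "".join(parts)

# e.g. (assuming the backend constructs with no arguments):
# asyncio.run(collect_text(VLLMChatBackend(), {
#     "model": "local-vllm",
#     "messages": [{"role": "user", "content": "Hello"}],
# }))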