johnbridges committed
Commit 6e98acb · Parent: 7e2c46b
Files changed (2)
  1. app.py +3 -2
  2. transformers_backend.py +88 -0
app.py CHANGED
@@ -7,7 +7,8 @@ from rabbit_base import RabbitBase
 from listener import RabbitListenerBase
 from rabbit_repo import RabbitRepo
 from oa_server import OpenAIServers
-from vllm_backend import VLLMChatBackend, StubImagesBackend
+#from vllm_backend import VLLMChatBackend, StubImagesBackend
+from transformers_backend import TransformersChatBackend, StubImagesBackend
 
 logging.basicConfig(
     level=logging.INFO,
@@ -34,7 +35,7 @@ base = RabbitBase(exchange_type_resolver=resolver)
 
 servers = OpenAIServers(
     publisher,
-    chat_backend=VLLMChatBackend(),
+    chat_backend=TransformersChatBackend(),
     images_backend=StubImagesBackend()
 )
 
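Note: the ChatBackend / ImagesBackend base classes imported from backends_base are not part of this commit. For context, here is a minimal sketch of what backends_base.py presumably declares, inferred from the imports above and from how transformers_backend.py implements them (hypothetical, not the actual module):

# backends_base.py -- hypothetical sketch; the real module is not in this diff.
from abc import ABC, abstractmethod
from typing import Any, AsyncIterable, Dict

class ChatBackend(ABC):
    @abstractmethod
    def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        """Yield OpenAI-style chat.completion.chunk dicts for one request."""
        ...

class ImagesBackend(ABC):
    @abstractmethod
    async def generate_b64(self, request: Dict[str, Any]) -> str:
        """Return a base64-encoded image for an images request."""
        ...

Swapping chat_backend at this single construction site is the whole migration: OpenAIServers only depends on these two interfaces, so the vLLM and Transformers backends are interchangeable.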
transformers_backend.py ADDED
@@ -0,0 +1,88 @@
+# transformers_backend.py
+import time, logging
+from typing import Any, Dict, AsyncIterable
+
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from backends_base import ChatBackend, ImagesBackend
+from config import settings
+
+logger = logging.getLogger(__name__)
+
+try:
+    import spaces
+except ImportError:
+    spaces = None
+
+
+class TransformersChatBackend(ChatBackend):
+    """
+    Lightweight backend for Hugging Face Spaces (ZeroGPU).
+    Reloads the model on every request using Transformers, not vLLM.
+    """
+
+    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
+        messages = request.get("messages", [])
+        prompt = messages[-1]["content"] if messages else "(empty)"
+
+        # Config-driven defaults
+        model_id = request.get("model") or settings.LlmHFModelID
+        temperature = float(request.get("temperature", settings.LlmTemp or 0.7))
+        max_tokens = int(request.get("max_tokens", settings.LlmOpenAICtxSize or 512))
+
+        rid = f"chatcmpl-transformers-{int(time.time())}"
+        now = int(time.time())
+
+        # Run inside ZeroGPU lease
+        if spaces:
+            @spaces.GPU(duration=60)
+            def run_once(prompt: str) -> str:
+                tokenizer = AutoTokenizer.from_pretrained(model_id)
+                model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
+
+                inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+                outputs = model.generate(
+                    **inputs,
+                    max_new_tokens=max_tokens,
+                    temperature=temperature,
+                    do_sample=True,
+                )
+                return tokenizer.decode(outputs[0], skip_special_tokens=True)
+        else:
+            def run_once(prompt: str) -> str:
+                tokenizer = AutoTokenizer.from_pretrained(model_id)
+                model = AutoModelForCausalLM.from_pretrained(model_id)
+
+                inputs = tokenizer(prompt, return_tensors="pt")
+                outputs = model.generate(
+                    **inputs,
+                    max_new_tokens=max_tokens,
+                    temperature=temperature,
+                    do_sample=True,
+                )
+                return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        try:
+            text = run_once(prompt)
+            yield {
+                "id": rid,
+                "object": "chat.completion.chunk",
+                "created": now,
+                "model": model_id,
+                "choices": [
+                    {"index": 0, "delta": {"content": text}, "finish_reason": "stop"}
+                ],
+            }
+        except Exception:
+            logger.exception("Transformers inference failed")
+            raise
+
+
+class StubImagesBackend(ImagesBackend):
+    """
+    Image generation stub — returns a transparent PNG placeholder.
+    """
+    async def generate_b64(self, request: Dict[str, Any]) -> str:
+        logger.warning("Image generation not supported in Transformers backend.")
+        return (
+            "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
+        )