johnbridges committed on
Commit bf707d2 · 1 Parent(s): 6195aba
Files changed (2)
  1. oa_server.py +3 -30
  2. openai_backend.py +104 -0
oa_server.py CHANGED
@@ -4,6 +4,7 @@ import json, time, uuid, logging
 from typing import Any, Dict, List, AsyncIterable, Optional
 
 from rabbit_repo import RabbitRepo
+from openai_backend import OpenAICompatChatBackend, OpenAIImagesBackend
 
 logger = logging.getLogger(__name__)
 
@@ -22,34 +23,6 @@ def _last_user_text(messages: List[Dict[str, Any]]) -> str:
         return " ".join([t for t in texts if t])
     return ""
 
-# ------------------ backends (replace later) ------------------
-class ChatBackend:
-    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
-        raise NotImplementedError
-
-class DummyChatBackend(ChatBackend):
-    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
-        rid = f"chatcmpl-{uuid.uuid4().hex[:12]}"
-        model = request.get("model", "gpt-4o-mini")
-        text = _last_user_text(request.get("messages", [])) or "(empty)"
-        out = f"Echo (Rabbit): {text}"
-        now = _now()
-
-        # role delta
-        yield {"id": rid, "object":"chat.completion.chunk", "created": now, "model": model,
-               "choices":[{"index":0,"delta":{"role":"assistant"},"finish_reason":None}]}
-        # content deltas
-        for piece in _chunk_text(out, 140):
-            yield {"id": rid, "object":"chat.completion.chunk", "created": now, "model": model,
-                   "choices":[{"index":0,"delta":{"content":piece},"finish_reason":None}]}
-        # final delta
-        yield {"id": rid, "object":"chat.completion.chunk", "created": now, "model": model,
-               "choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}
-
-class ImagesBackend:
-    async def generate_b64(self, request: Dict[str, Any]) -> str:
-        # 1x1 transparent PNG (stub)
-        return "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
 
 # ------------------ handler class ------------------
 class OpenAIServers:
@@ -63,8 +36,8 @@ class OpenAIServers:
                  *, chat_backend: Optional[ChatBackend] = None,
                  images_backend: Optional[ImagesBackend] = None):
         self._pub = publisher
-        self._chat = chat_backend or DummyChatBackend()
-        self._img = images_backend or ImagesBackend()
+        self._chat = chat_backend or OpenAICompatChatBackend()
+        self._img = images_backend or OpenAIImagesBackend()
 
     # -------- Chat Completions --------
     async def handle_chat_create(self, data: Dict[str, Any]) -> None:
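
For reference, the backend contract that openai_backend.py now fills is just the two async interfaces deleted above. A minimal sketch follows, with signatures copied from the removed stubs; EchoChatBackend is a hypothetical test double (not in this commit) that stands in for the deleted DummyChatBackend when the wiring needs testing without a provider:

# Sketch only: interface signatures come from the stubs removed above;
# EchoChatBackend is a hypothetical test double, not part of this commit.
import asyncio, time, uuid
from typing import Any, AsyncIterable, Dict

class ChatBackend:
    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        raise NotImplementedError

class ImagesBackend:
    async def generate_b64(self, request: Dict[str, Any]) -> str:
        raise NotImplementedError

class EchoChatBackend(ChatBackend):
    # Emits the same role / content / stop chunk sequence the old dummy produced.
    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        base = {"id": f"chatcmpl-{uuid.uuid4().hex[:12]}", "object": "chat.completion.chunk",
                "created": int(time.time()), "model": request.get("model", "test")}
        yield {**base, "choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}]}
        yield {**base, "choices": [{"index": 0, "delta": {"content": "echo"}, "finish_reason": None}]}
        yield {**base, "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}

async def _demo() -> None:
    async for chunk in EchoChatBackend().stream({"model": "test"}):
        print(chunk["choices"][0])

if __name__ == "__main__":
    asyncio.run(_demo())
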
openai_backend.py ADDED
@@ -0,0 +1,104 @@
+# openai_backend.py
+from __future__ import annotations
+import os, json, base64, logging, asyncio
+from typing import Any, AsyncIterable, Dict, Optional
+
+from openai import AsyncOpenAI
+from openai._types import NOT_GIVEN
+
+from config import settings
+from oa_server import ChatBackend, ImagesBackend  # reuse your ABCs
+
+log = logging.getLogger(__name__)
+
+def _pick_api_key() -> str:
+    # Use the HF/OpenAI-compatible key first (from appsettings → env), else OPENAI_API_KEY.
+    return (
+        (getattr(settings, "LlmHFKey", None) or os.getenv("LlmHFKey")) or
+        (getattr(settings, "OpenAIApiKey", None) or os.getenv("OpenAIApiKey")) or
+        os.getenv("OPENAI_API_KEY", "")
+    )
+
+def _pick_base_url() -> Optional[str]:
+    # If a custom OpenAI-compatible endpoint (e.g. Novita) is configured, use it.
+    url = getattr(settings, "LlmHFUrl", None) or os.getenv("LlmHFUrl")
+    return url or None
+
+def _pick_default_model(incoming: Dict[str, Any]) -> str:
+    # Honor request.model, else prefer the HF model id, else the OpenAI model from config.
+    return (
+        incoming.get("model")
+        or getattr(settings, "LlmHFModelID", None)
+        or getattr(settings, "LlmGptModel", "gpt-4o-mini")
+    )
+
+class OpenAICompatChatBackend(ChatBackend):
+    """
+    Streams Chat Completions from any OpenAI-compatible server.
+    - Passes 'tools'/'tool_choice' straight through (function calling).
+    - Accepts multimodal 'messages[*].content' with text + image_url parts.
+    - Streams ChatCompletionChunk objects; we convert them to plain dicts.
+    """
+    def __init__(self):
+        api_key = _pick_api_key()
+        base_url = _pick_base_url()
+        if not api_key:
+            log.warning("No API key found; requests will fail.")
+        self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+
+    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
+        # Strip our internal fields; forward only the OpenAI payload.
+        req = dict(request)
+        req.pop("reply_key", None)  # handled by the caller
+        # Ensure streaming on the provider, even if the caller omitted it.
+        req.setdefault("stream", True)
+        req.setdefault("model", _pick_default_model(req))
+
+        # Some providers reject unknown keys; drop obviously non-OpenAI keys defensively.
+        for k in ("ExchangeName", "FuncName", "MessageTimeout", "RoutingKeys"):
+            req.pop(k, None)
+
+        # The OpenAI SDK returns an async iterator of ChatCompletionChunk objects.
+        stream = await self.client.chat.completions.create(**req)  # stream=True set above
+        async for chunk in stream:
+            # Convert to a plain dict for serialization over MQ.
+            if hasattr(chunk, "model_dump_json"):
+                yield json.loads(chunk.model_dump_json())
+            elif hasattr(chunk, "to_dict"):
+                yield chunk.to_dict()
+            else:
+                yield chunk  # already a dict
+
+class OpenAIImagesBackend(ImagesBackend):
+    """
+    Generates base64 images via an OpenAI-compatible Images API.
+    """
+    def __init__(self):
+        api_key = _pick_api_key()
+        base_url = _pick_base_url()
+        self.client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+
+    async def generate_b64(self, request: Dict[str, Any]) -> str:
+        # Expect OpenAI 'images.generate'-style fields:
+        # - model (required by most providers)
+        # - prompt
+        # - size, like '1024x1024'
+        model = request.get("model") or getattr(settings, "LlmHFModelID", None) or "gpt-image-1"
+        size = request.get("size", "1024x1024")
+        n = int(request.get("n", 1))
+        resp = await self.client.images.generate(
+            model=model,
+            prompt=request.get("prompt", ""),
+            size=size,
+            n=n,
+            # If upstream sends 'background' or 'transparent_background', pass them through if supported:
+            background=request.get("background") if "background" in request else NOT_GIVEN,
+            transparent_background=request.get("transparent_background") if "transparent_background" in request else NOT_GIVEN,
+        )
+        # Return the first image's b64 payload.
+        data0 = resp.data[0]
+        if hasattr(data0, "b64_json") and data0.b64_json:
+            return data0.b64_json
+        # Some providers return URLs; fetching them is out of scope here, so return a placeholder.
+        log.warning("Images API returned URL instead of b64; returning 1x1 transparent pixel.")
+        return "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="