johnbridges committed on
Commit d74feed · 1 Parent(s): 3cce116
Files changed (1)
  1. app.py +87 -256
app.py CHANGED
@@ -1,259 +1,90 @@
- # timesfm_backend.py
- import time
- import json
- import logging
- from typing import Any, Dict, List, Optional, Tuple
 
- import numpy as np
-
- from backends_base import ChatBackend, ImagesBackend  # ChatBackend for OA server
  from config import settings
-
- logger = logging.getLogger(__name__)
-
- # Try to import TimesFM. If not present, we fall back to a naive forecaster.
- _TIMESFM_AVAILABLE = False
- _TFM = None
  try:
-     # google timesfm 2.5 requires `pip install timesfm`
-     # model class name can be TimesFm (library-dependent)
-     from timesfm import TimesFm  # type: ignore
-     _TIMESFM_AVAILABLE = True
- except Exception as e:
-     logger.warning("timesfm not available (%s) — will use naive fallback.", e)
-
-
- def _parse_series(series: Any) -> np.ndarray:
-     """
-     Accepts list[float], list[int], list[dict{value:..}], or dict with 'values'.
-     Returns a 1D float numpy array. Raises ValueError on empty/invalid.
-     """
-     if series is None:
-         raise ValueError("series is required")
-
-     if isinstance(series, dict):
-         if "values" in series:
-             series = series["values"]
-         elif "y" in series:
-             series = series["y"]
-
-     vals: List[float] = []
-     if isinstance(series, (list, tuple)):
-         if series and isinstance(series[0], dict):
-             # e.g. [{"t": "...", "y": 1.2}, ...] or {"value": ...}
-             for item in series:
-                 if "y" in item:
-                     vals.append(float(item["y"]))
-                 elif "value" in item:
-                     vals.append(float(item["value"]))
-         else:
-             # numeric list
-             vals = [float(x) for x in series]
-     else:
-         raise ValueError("series must be a list/tuple or dict with 'values'/'y'")
-
-     if not vals:
-         raise ValueError("series is empty")
-     return np.asarray(vals, dtype=np.float32)
-
-
- def _fallback_forecast(y: np.ndarray, horizon: int) -> np.ndarray:
-     """
-     Very small, dependency-free fallback:
-       - if length >= 4: mean of last 4 points
-       - else: mean of all points
-     """
-     if horizon <= 0:
-         return np.zeros((0,), dtype=np.float32)
-     k = 4 if y.shape[0] >= 4 else y.shape[0]
-     base = float(np.mean(y[-k:]))
-     return np.full((horizon,), base, dtype=np.float32)
-
-
- class TimesFMBackend(ChatBackend):
-     """
-     Chat-compatible backend (for oa_server) wrapping TimesFM (if installed).
-     If TimesFM is missing, uses a naive statistical fallback.
-     """
-
-     def __init__(self,
-                  model_id: Optional[str] = None,
-                  device: Optional[str] = None):
-         """
-         model_id: optional identifier for logs/metadata
-         device: 'cpu' or 'cuda' (passed to TimesFm if supported by installed lib)
-         """
-         self.model_id = model_id or "google/timesfm-2.5-200m-pytorch"
-         self.device = device or "cpu"
-         self._model = None  # lazy init
-
-     # ---------- internal ----------
-     def _ensure_model(self):
-         if self._model is not None or not _TIMESFM_AVAILABLE:
-             return
-         try:
-             # minimal init; adjust kwargs if your installed version needs different args
-             self._model = TimesFm()  # type: ignore
-             logger.info("TimesFM model initialized.")
-         except Exception as e:
-             logger.exception("Failed to initialize TimesFM; will use fallback. %s", e)
-             self._model = None
-
-     # ---------- public helpers ----------
-     async def forecast(self, payload: Dict[str, Any]) -> Dict[str, Any]:
-         """
-         Unified forecast entrypoint.
-         Expected keys (directly in payload OR nested under 'data' OR 'timeseries'):
-           - series: list of numbers (or list of dicts holding 'y'/'value')
-           - horizon: int (>0)
-           - freq: optional string for metadata only
-         Returns:
-           {
-             "model": "...",
-             "horizon": int,
-             "freq": str|None,
-             "forecast": [floats],
-             "note": str|None
-           }
-         """
-         # unwrap if nested
-         if "data" in payload and isinstance(payload["data"], dict):
-             payload = {**payload, **payload["data"]}
-         if "timeseries" in payload and isinstance(payload["timeseries"], dict):
-             payload = {**payload, **payload["timeseries"]}
-
-         series = payload.get("series")
-         horizon = int(payload.get("horizon", 0))
-         freq = payload.get("freq")
-
-         y = _parse_series(series)
-         if horizon <= 0:
-             raise ValueError("horizon must be a positive integer")
-
-         self._ensure_model()
-
-         if _TIMESFM_AVAILABLE and self._model is not None:
-             # Use real TimesFM
-             try:
-                 # Most TimesFM APIs are batch-oriented; we add a batch dim and remove it later
-                 # If your installed version differs (e.g., .predict with signature),
-                 # change these two lines accordingly:
-                 y_batch = y[None, :]
-                 preds = self._model.predict(y_batch, horizon=horizon)  # type: ignore
-                 # preds shape => (1, horizon)
-                 fc = np.asarray(preds).reshape(-1).tolist()
-                 note = None
-             except Exception as e:
-                 logger.exception("TimesFM predict failed; falling back. %s", e)
-                 fc = _fallback_forecast(y, horizon).tolist()
-                 note = "fallback_used_due_to_predict_error"
-         else:
-             # Fallback path
-             fc = _fallback_forecast(y, horizon).tolist()
-             note = "fallback_used_timesfm_missing"
-
-         return {
-             "model": self.model_id,
-             "horizon": horizon,
-             "freq": freq,
-             "forecast": fc,
-             "note": note,
-         }
-
-     # ---------- ChatBackend interface (for oa_server) ----------
-     async def stream(self, request: Dict[str, Any]):
-         """
-         OA-compatible streaming shim:
-           - Extracts forecast inputs from request (or from last user message JSON).
-           - Runs forecast() and yields ONE OpenAI-style chat chunk whose content
-             is a compact JSON string with the forecast result.
-         """
-         rid = f"chatcmpl-timesfm-{int(time.time())}"
-         now = int(time.time())
-
-         # try to gather payload
-         payload: Dict[str, Any] = {}
-
-         # 1) allow direct shape: {series, horizon, ...} / or under 'data'/'timeseries'
-         if isinstance(request, dict):
-             payload = dict(request)  # shallow copy
-
-         # 2) optionally parse last user message if it's JSON
-         try:
-             msgs = request.get("messages") if isinstance(request, dict) else None
-             if isinstance(msgs, list) and msgs:
-                 for m in reversed(msgs):
-                     if isinstance(m, dict) and m.get("role") == "user":
-                         c = m.get("content")
-                         if isinstance(c, str):
-                             c_str = c.strip()
-                             if (c_str.startswith("{") and c_str.endswith("}")) or (
-                                 c_str.startswith("[") and c_str.endswith("]")
-                             ):
-                                 # try parse JSON content
-                                 parsed = json.loads(c_str)
-                                 if isinstance(parsed, dict):
-                                     payload.update(parsed)
-                         break
-         except Exception:
-             # non-fatal: keep whatever we had
-             pass
-
-         # run forecast
-         try:
-             result = await self.forecast(payload)
-         except Exception as e:
-             # return an error chunk in OpenAI shape
-             err = {"error": str(e)}
-             content = json.dumps(err, separators=(",", ":"), ensure_ascii=False)
-             yield {
-                 "id": rid,
-                 "object": "chat.completion.chunk",
-                 "created": now,
-                 "model": self.model_id,
-                 "choices": [
-                     {
-                         "index": 0,
-                         "delta": {"role": "assistant", "content": content},
-                         "finish_reason": "stop",
-                     }
-                 ],
-             }
-             return
-
-         # success: compact JSON content so your .NET can parse
-         content = json.dumps(
-             {
-                 "model": result.get("model"),
-                 "horizon": result.get("horizon"),
-                 "freq": result.get("freq"),
-                 "forecast": result.get("forecast"),
-                 "note": result.get("note"),
-                 "backend": "timesfm",
-             },
-             separators=(",", ":"),
-             ensure_ascii=False,
-         )
-
-         yield {
-             "id": rid,
-             "object": "chat.completion.chunk",
-             "created": now,
-             "model": self.model_id,
-             "choices": [
-                 {
-                     "index": 0,
-                     "delta": {"role": "assistant", "content": content},
-                     "finish_reason": "stop",
-                 }
-             ],
-         }
-
-
- # Optional: keep an images stub to satisfy oa_server wiring if needed elsewhere
- class StubImagesBackend(ImagesBackend):
-     async def generate_b64(self, request: Dict[str, Any]) -> str:
-         logger.warning("Image generation not supported in TimesFM backend.")
-         return (
-             "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
-         )

+ # app.py
+ import asyncio, logging
+ import gradio as gr
 
  from config import settings
+ from rabbit_base import RabbitBase
+ from listener import RabbitListenerBase
+ from rabbit_repo import RabbitRepo
+ from oa_server import OpenAIServers
+ #from vllm_backend import VLLMChatBackend, StubImagesBackend
+ #from transformers_backend import TransformersChatBackend, StubImagesBackend
+ #from hf_backend import HFChatBackend, StubImagesBackend
+ from hf_backend import StubImagesBackend
+ from timesfm_backend import TimesFMBackend
+
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
+ )
+ log = logging.getLogger("app")
+
+ # ----------------- Hugging Face Spaces helpers -----------------
  try:
+     import spaces
+
+     @spaces.GPU(duration=60)
+     def gpu_entrypoint() -> str:
+         return "gpu: ready"
+
+ except Exception:
+     def gpu_entrypoint() -> str:
+         return "gpu: not available (CPU only)"
+
+ # ----------------- RabbitMQ wiring -----------------
+ publisher = RabbitRepo(external_source="openai.mq.server")
+ resolver = (lambda name: "direct" if name.startswith("oa.") else settings.RABBIT_EXCHANGE_TYPE)
+ base = RabbitBase(exchange_type_resolver=resolver)
+
+ servers = OpenAIServers(
+     publisher,
+     chat_backend=TimesFMBackend(),
+     images_backend=StubImagesBackend()
+ )
+
+ handlers = {
+     "oaChatCreate": servers.handle_chat_create,
+     "oaImagesGenerate": servers.handle_images_generate,
+ }
+
+ DECLS = [
+     {"ExchangeName": "oa.chat.create", "FuncName": "oaChatCreate",
+      "MessageTimeout": 600_000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
+     {"ExchangeName": "oa.images.generate", "FuncName": "oaImagesGenerate",
+      "MessageTimeout": 600_000, "RoutingKeys": [settings.RABBIT_ROUTING_KEY]},
+ ]
+
+ listener = RabbitListenerBase(base, instance_name=settings.RABBIT_INSTANCE_NAME, handlers=handlers)
+
+ # ----------------- Startup init -----------------
+ async def _startup_init():
+     try:
+         await base.connect()          # connect to RabbitMQ
+         await listener.start(DECLS)   # start queue listeners
+         return "OpenAI MQ + TimesFM: ready"
+     except Exception as e:
+         log.exception("Startup init failed")
+         return f"ERROR: {e}"
+
+ async def ping():
+     return "ok"
+
+ # ----------------- Gradio UI -----------------
+ with gr.Blocks(title="OpenAI over RabbitMQ (local TimesFM)", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("## OpenAI-compatible over RabbitMQ, served by TimesFM locally inside the Space")
+     with gr.Tabs():
+         with gr.Tab("Service"):
+             btn = gr.Button("Ping")
+             out = gr.Textbox(label="Ping result")
+             btn.click(ping, inputs=None, outputs=out)
+             init_status = gr.Textbox(label="Startup status", interactive=False)
+             demo.load(fn=_startup_init, inputs=None, outputs=init_status)
+
+         with gr.Tab("@spaces.GPU Probe"):
+             gpu_btn = gr.Button("GPU Ready Probe", variant="primary")
+             gpu_out = gr.Textbox(label="GPU Probe Result", interactive=False)
+             gpu_btn.click(gpu_entrypoint, inputs=None, outputs=gpu_out)
+
+ if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True, debug=True, mcp_server=True)
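
Note on usage: the TimesFMBackend wired above extracts forecast inputs from the last user message when that message body is JSON (see the stream() shim in the removed listing, whose contents this commit moves into timesfm_backend.py). Below is a minimal sketch of a request a client could publish through the "oa.chat.create" exchange declared in DECLS; the series values are illustrative, while the series/horizon/freq keys come from the forecast() docstring.

    import json

    # Sketch of an OpenAI-style chat request carrying forecast inputs.
    # TimesFMBackend.stream() json.loads() the user content and merges these
    # keys (series, horizon, optional freq) into its forecast payload.
    request = {
        "model": "google/timesfm-2.5-200m-pytorch",
        "messages": [
            {
                "role": "user",
                "content": json.dumps(
                    {"series": [112.0, 118.0, 132.0, 129.0, 121.0],
                     "horizon": 4,
                     "freq": "M"}
                ),
            }
        ],
    }

    # The backend replies with a single chat.completion.chunk whose delta
    # content is compact JSON, e.g.:
    # {"model":"google/timesfm-2.5-200m-pytorch","horizon":4,"freq":"M",
    #  "forecast":[...],"note":null,"backend":"timesfm"}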