|
|
|
|
|
import json
import logging
import time
import uuid
from typing import Any, Dict, List, Optional

import numpy as np
import torch

from backends_base import ChatBackend, ImagesBackend
from config import settings
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
try: |
|
|
from timesfm import TimesFm |
|
|
_TIMESFM_AVAILABLE = True |
|
|
except Exception as e: |
|
|
logger.warning("timesfm not available (%s) — using naive fallback.", e) |
|
|
TimesFm = None |
|
|
_TIMESFM_AVAILABLE = False |
|
|
|
|
|
|
|
|
|
|
|
def _parse_series(series: Any) -> np.ndarray: |
|
|
""" |
|
|
Accepts: list[float|int], list[dict{'y'|'value'}], or dict with 'values'/'y'. |
|
|
Returns: 1D float32 numpy array. |
|
|
""" |
|
|
if series is None: |
|
|
raise ValueError("series is required") |
|
|
|
|
|
if isinstance(series, dict): |
|
|
|
|
|
series = series.get("values") or series.get("y") |
|
|
|
|
|
vals: List[float] = [] |
|
|
if isinstance(series, (list, tuple)): |
|
|
if series and isinstance(series[0], dict): |
|
|
for item in series: |
|
|
if "y" in item: |
|
|
vals.append(float(item["y"])) |
|
|
elif "value" in item: |
|
|
vals.append(float(item["value"])) |
|
|
else: |
|
|
vals = [float(x) for x in series] |
|
|
else: |
|
|
raise ValueError("series must be a list/tuple or dict with 'values'/'y'") |
|
|
|
|
|
if not vals: |
|
|
raise ValueError("series is empty") |
|
|
return np.asarray(vals, dtype=np.float32) |
|
|
|
|
|
|
|
|
def _fallback_forecast(y: np.ndarray, horizon: int) -> np.ndarray: |
|
|
""" |
|
|
Naive fallback: mean of last 4 (or all if <4), repeated H times. |
|
|
""" |
|
|
if horizon <= 0: |
|
|
return np.zeros((0,), dtype=np.float32) |
|
|
k = 4 if y.shape[0] >= 4 else y.shape[0] |
|
|
base = float(np.mean(y[-k:])) |
|
|
return np.full((horizon,), base, dtype=np.float32) |
|
|
|
|
|
|
|
|
def _extract_json_from_text(s: str) -> Optional[Dict[str, Any]]: |
|
|
""" |
|
|
Try to parse JSON from a plain string or a fenced ```json block. |
|
|
Returns dict or None. |
|
|
""" |
|
|
s = s.strip() |
|
|
|
|
|
if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")): |
|
|
try: |
|
|
obj = json.loads(s) |
|
|
return obj if isinstance(obj, dict) else None |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
if "```" in s: |
|
|
parts = s.split("```") |
|
|
for i in range(1, len(parts), 2): |
|
|
block = parts[i] |
|
|
if block.lstrip().lower().startswith("json"): |
|
|
block = block.split("\n", 1)[-1] |
|
|
try: |
|
|
obj = json.loads(block.strip()) |
|
|
return obj if isinstance(obj, dict) else None |
|
|
except Exception: |
|
|
continue |
|
|
return None |
|
|
|
|
|
|
|
|
def _merge_openai_message_json(payload: Dict[str, Any]) -> Dict[str, Any]: |
|
|
""" |
|
|
OpenAI chat format compatibility: |
|
|
payload["messages"] may hold user JSON in the last user message. |
|
|
content can be a plain string or a list of parts [{"type":"text","text":...}]. |
|
|
If a JSON object is found, merge its keys into payload. |
|
|
""" |
|
|
msgs = payload.get("messages") |
|
|
if not isinstance(msgs, list): |
|
|
return payload |
|
|
|
|
|
for m in reversed(msgs): |
|
|
if not isinstance(m, dict) or m.get("role") != "user": |
|
|
continue |
|
|
content = m.get("content") |
|
|
texts: List[str] = [] |
|
|
if isinstance(content, list): |
|
|
texts = [ |
|
|
p.get("text") |
|
|
for p in content |
|
|
if isinstance(p, dict) and p.get("type") == "text" and isinstance(p.get("text"), str) |
|
|
] |
|
|
elif isinstance(content, str): |
|
|
texts = [content] |
|
|
|
|
|
for t in reversed(texts): |
|
|
obj = _extract_json_from_text(t) |
|
|
if isinstance(obj, dict): |
|
|
return {**payload, **obj} |
|
|
break |
|
|
return payload |
|
|
|
|
|
|
|
|
|
|
|
class TimesFMBackend(ChatBackend):
    """
    Time-series forecasting backend that accepts OpenAI chat-completions requests.

    Pulls timeseries config from (later sources override earlier ones
    key-by-key):
      - top-level keys,
      - payload['data'] (CloudEvents wrapper),
      - payload['timeseries'],
      - last user message JSON (OpenAI format).

    Keys:
      series: list[float|int|{y|value}]
      horizon: int (>0)
      freq: optional str (echoed back in the response; not passed to the model)

    If timesfm is not importable, the model fails to load, or inference
    raises, a naive mean-of-recent-values forecast is returned instead. The
    response's ``note`` field then says why.
    """

    # Cooldown between checkpoint-load attempts after a failure. Without it,
    # a broken or missing checkpoint would be re-loaded on every request.
    _LOAD_RETRY_SECONDS = 300.0

    def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
        self.model_id = model_id or "google/timesfm-2.5-200m-pytorch"
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._model: Optional[TimesFm] = None
        # time.monotonic() value before which _ensure_model won't retry loading.
        self._load_retry_at: float = 0.0

    def _ensure_model(self) -> None:
        """Lazily build and load the TimesFM model; leave ``_model`` None on failure."""
        if self._model is not None or not _TIMESFM_AVAILABLE:
            return
        if time.monotonic() < self._load_retry_at:
            # A recent attempt failed; serve the fallback until the cooldown ends.
            return
        try:
            # NOTE(review): these hparams and load_from_checkpoint() look like the
            # older timesfm ``TimesFm`` API, while the default model_id names a
            # 2.5 checkpoint. Confirm they match the installed timesfm version.
            model = TimesFm(
                context_len=512,
                horizon_len=128,
                input_patch_len=32,
            )
            model.load_from_checkpoint(self.model_id)
            try:
                model.to(self.device)
            except Exception as e:
                # Best effort: keep the loaded model even if it can't be moved.
                logger.debug("TimesFM .to(%s) not applied: %s", self.device, e)
            self._model = model
            logger.info("TimesFM loaded from %s on %s", self.model_id, self.device)
        except Exception as e:
            logger.exception("TimesFM init failed; fallback only. %s", e)
            self._model = None
            self._load_retry_at = time.monotonic() + self._LOAD_RETRY_SECONDS

    @staticmethod
    def _resolve_payload(payload: Dict[str, Any]) -> Dict[str, Any]:
        """Flatten 'data' / 'timeseries' wrappers, then merge last-user-message JSON."""
        if isinstance(payload.get("data"), dict):
            payload = {**payload, **payload["data"]}
        if isinstance(payload.get("timeseries"), dict):
            payload = {**payload, **payload["timeseries"]}
        return _merge_openai_message_json(payload)

    @staticmethod
    def _parse_horizon(raw: Any) -> int:
        """Coerce ``raw`` with ``int()``; raise ValueError unless the result is > 0."""
        try:
            horizon = int(raw)
        except (TypeError, ValueError):
            raise ValueError("horizon must be a positive integer") from None
        if horizon <= 0:
            raise ValueError("horizon must be a positive integer")
        return horizon

    def _predict(self, y: np.ndarray, horizon: int) -> List[float]:
        """Run the loaded model on a single series and return ``horizon`` floats."""
        x = torch.tensor(y, dtype=torch.float32, device=self.device).unsqueeze(0)
        preds = self._model.forecast_on_batch(x, horizon)
        values = preds[0].detach().cpu().numpy().astype(float).tolist()
        if len(values) < horizon:
            raise ValueError(f"TimesFM returned {len(values)} points; expected {horizon}")
        # Trim in case the model emits more points than were requested.
        return values[:horizon]

    async def forecast(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        Produce a point forecast for the series described by ``payload``.

        Returns:
            dict with ``model``, ``horizon``, ``freq``, ``forecast`` (list of
            ``horizon`` floats) and ``note`` (None when the model produced
            the forecast, otherwise the reason the fallback was used).

        Raises:
            ValueError: if the series is missing or invalid, or horizon is
                not a positive integer.
        """
        # NOTE(review): model loading and inference run synchronously inside
        # this coroutine and block the event loop while they execute.
        payload = self._resolve_payload(payload)

        y = _parse_series(payload.get("series"))
        horizon = self._parse_horizon(payload.get("horizon", 0))
        freq = payload.get("freq")

        self._ensure_model()

        note = None
        if _TIMESFM_AVAILABLE and self._model is not None:
            try:
                fc = self._predict(y, horizon)
            except Exception as e:
                logger.exception("TimesFM forecast failed; fallback used. %s", e)
                fc = _fallback_forecast(y, horizon).tolist()
                note = "fallback_used_due_to_predict_error"
        else:
            fc = _fallback_forecast(y, horizon).tolist()
            note = "fallback_used_timesfm_missing"

        return {
            "model": self.model_id,
            "horizon": horizon,
            "freq": freq,
            "forecast": fc,
            "note": note,
        }

    def _chunk(self, rid: str, created: int, content: str) -> Dict[str, Any]:
        """Build one terminal chat.completion.chunk carrying ``content``."""
        return {
            "id": rid,
            "object": "chat.completion.chunk",
            "created": created,
            "model": self.model_id,
            "choices": [
                {"index": 0, "delta": {"role": "assistant", "content": content}, "finish_reason": "stop"}
            ],
        }

    async def stream(self, request: Dict[str, Any]):
        """
        OA-compatible streaming shim.

        Emits exactly one chat.completion.chunk. Its content is compact JSON:
        the forecast result plus ``"backend": "timesfm"`` on success, or
        ``{"error": "<message>"}`` if ``forecast`` raised.
        """
        now = int(time.time())
        # A bare second-resolution timestamp collides across concurrent
        # requests; the random suffix keeps ids unique.
        rid = f"chatcmpl-timesfm-{now}-{uuid.uuid4().hex[:12]}"
        payload = dict(request) if isinstance(request, dict) else {}

        try:
            result = await self.forecast(payload)
        except Exception as e:
            logger.warning("TimesFM request failed: %s", e)
            body: Dict[str, Any] = {"error": str(e)}
        else:
            body = {
                "model": result["model"],
                "horizon": result["horizon"],
                "freq": result["freq"],
                "forecast": result["forecast"],
                "note": result.get("note"),
                "backend": "timesfm",
            }

        content = json.dumps(body, separators=(",", ":"), ensure_ascii=False)
        yield self._chunk(rid, now, content)
|
|
|
|
|
|
|
|
|
|
|
class StubImagesBackend(ImagesBackend):
    """Placeholder images backend; this service does not generate images."""

    # Base64-encoded 1x1 PNG returned for every request.
    _PLACEHOLDER_PNG_B64 = (
        "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
    )

    async def generate_b64(self, request: Dict[str, Any]) -> str:
        """Log that image generation is unsupported and return the placeholder PNG."""
        logger.warning("Image generation not supported in TimesFM backend.")
        return self._PLACEHOLDER_PNG_B64
|
|
|