# timesfm_backend.py
import time
import json
import logging
from typing import Any, Dict, List, Optional

import numpy as np
import torch

from backends_base import ChatBackend, ImagesBackend
from config import settings

logger = logging.getLogger(__name__)

# ---------------- TimesFM import (fallback-safe) ----------------
try:
    from timesfm import TimesFm  # Google TimesFM 2.5+
    _TIMESFM_AVAILABLE = True
except Exception as e:
    logger.warning("timesfm not available (%s) — using naive fallback.", e)
    TimesFm = None  # type: ignore
    _TIMESFM_AVAILABLE = False

# ---------------- helpers ----------------
def _parse_series(series: Any) -> np.ndarray:
    """
    Accepts: list[float|int], list[dict{'y'|'value'}], or dict with 'values'/'y'.
    Returns: 1D float32 numpy array.
    """
    if series is None:
        raise ValueError("series is required")
    if isinstance(series, dict):
        # allow {"values": [...]} or {"y": [...]}
        series = series.get("values") or series.get("y")
    vals: List[float] = []
    if isinstance(series, (list, tuple)):
        if series and isinstance(series[0], dict):
            for item in series:
                if "y" in item:
                    vals.append(float(item["y"]))
                elif "value" in item:
                    vals.append(float(item["value"]))
        else:
            vals = [float(x) for x in series]
    else:
        raise ValueError("series must be a list/tuple or dict with 'values'/'y'")
    if not vals:
        raise ValueError("series is empty")
    return np.asarray(vals, dtype=np.float32)
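
# A quick illustration of the shapes _parse_series accepts (values are made up):
#   _parse_series([1, 2, 3])                    -> array([1., 2., 3.], dtype=float32)
#   _parse_series([{"y": 1.0}, {"value": 2.0}]) -> array([1., 2.], dtype=float32)
#   _parse_series({"values": [1, 2, 3]})        -> array([1., 2., 3.], dtype=float32)
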
def _fallback_forecast(y: np.ndarray, horizon: int) -> np.ndarray:
    """
    Naive fallback: mean of the last 4 points (or all points if fewer than 4),
    repeated `horizon` times.
    """
    if horizon <= 0:
        return np.zeros((0,), dtype=np.float32)
    k = 4 if y.shape[0] >= 4 else y.shape[0]
    base = float(np.mean(y[-k:]))
    return np.full((horizon,), base, dtype=np.float32)
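
# Sketch of the fallback behaviour (illustrative numbers):
#   _fallback_forecast(np.array([1, 2, 3, 4, 5], dtype=np.float32), 3)
#   -> mean of [2, 3, 4, 5] is 3.5, so array([3.5, 3.5, 3.5], dtype=float32)
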
def _extract_json_from_text(s: str) -> Optional[Dict[str, Any]]:
    """
    Try to parse JSON from a plain string or a fenced ```json block.
    Returns dict or None.
    """
    s = s.strip()
    # whole-string JSON object/array
    if (s.startswith("{") and s.endswith("}")) or (s.startswith("[") and s.endswith("]")):
        try:
            obj = json.loads(s)
            return obj if isinstance(obj, dict) else None
        except Exception:
            pass
    # fenced code blocks
    if "```" in s:
        parts = s.split("```")
        for i in range(1, len(parts), 2):
            block = parts[i]
            if block.lstrip().lower().startswith("json"):
                block = block.split("\n", 1)[-1]
            try:
                obj = json.loads(block.strip())
                return obj if isinstance(obj, dict) else None
            except Exception:
                continue
    return None
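
# What this accepts, roughly (illustrative strings):
#   _extract_json_from_text('{"horizon": 3}')               -> {"horizon": 3}
#   _extract_json_from_text('```json\n{"horizon": 3}\n```') -> {"horizon": 3}
#   _extract_json_from_text("no json here")                 -> None
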
def _merge_openai_message_json(payload: Dict[str, Any]) -> Dict[str, Any]:
    """
    OpenAI chat-format compatibility:
    payload["messages"] may hold user JSON in the last user message.
    content can be a plain string or a list of parts [{"type": "text", "text": ...}].
    If a JSON object is found, merge its keys into payload.
    """
    msgs = payload.get("messages")
    if not isinstance(msgs, list):
        return payload
    for m in reversed(msgs):
        if not isinstance(m, dict) or m.get("role") != "user":
            continue
        content = m.get("content")
        texts: List[str] = []
        if isinstance(content, list):
            texts = [
                p.get("text")
                for p in content
                if isinstance(p, dict) and p.get("type") == "text" and isinstance(p.get("text"), str)
            ]
        elif isinstance(content, str):
            texts = [content]
        for t in reversed(texts):
            obj = _extract_json_from_text(t)
            if isinstance(obj, dict):
                return {**payload, **obj}
        break  # only inspect the last user message
    return payload
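
# For instance (hypothetical request): a payload whose last user message is
#   {"role": "user", "content": "{\"series\": [1, 2, 3], \"horizon\": 2}"}
# comes back with "series" and "horizon" merged into the top-level payload.
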
# ---------------- backend ----------------
class TimesFMBackend(ChatBackend):
    """
    Accepts OpenAI chat-completions requests.
    Pulls timeseries config from:
      - top-level keys, OR
      - payload['data'] (CloudEvents wrapper), OR
      - last user message JSON (OpenAI format).
    Keys:
      series: list[float|int|{y|value}]
      horizon: int (>0)
      freq: optional str
    """

    def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
        self.model_id = model_id or "google/timesfm-2.5-200m-pytorch"
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self._model: Optional[TimesFm] = None  # type: ignore

    def _ensure_model(self) -> None:
        if self._model is not None or not _TIMESFM_AVAILABLE:
            return
        try:
            # Set lengths compatible with the 2.5 checkpoints.
            self._model = TimesFm(
                context_len=512,
                horizon_len=128,
                input_patch_len=32,
            )
            self._model.load_from_checkpoint(self.model_id)
            try:
                self._model.to(self.device)  # type: ignore[attr-defined]
            except Exception:
                pass
            logger.info("TimesFM loaded from %s on %s", self.model_id, self.device)
        except Exception as e:
            logger.exception("TimesFM init failed; fallback only. %s", e)
            self._model = None

    async def forecast(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        # unwrap CloudEvents .data and nested .timeseries
        if isinstance(payload.get("data"), dict):
            payload = {**payload, **payload["data"]}
        if isinstance(payload.get("timeseries"), dict):
            payload = {**payload, **payload["timeseries"]}
        # merge JSON embedded in last user message (OpenAI format)
        payload = _merge_openai_message_json(payload)

        y = _parse_series(payload.get("series"))
        horizon = int(payload.get("horizon", 0))
        freq = payload.get("freq")
        if horizon <= 0:
            raise ValueError("horizon must be a positive integer")

        self._ensure_model()
        note = None
        if _TIMESFM_AVAILABLE and self._model is not None:
            try:
                x = torch.tensor(y, dtype=torch.float32, device=self.device).unsqueeze(0)  # [1, T]
                preds = self._model.forecast_on_batch(x, horizon)  # -> [1, H]
                fc = preds[0].detach().cpu().numpy().astype(float).tolist()
            except Exception as e:
                logger.exception("TimesFM forecast failed; fallback used. %s", e)
                fc = _fallback_forecast(y, horizon).tolist()
                note = "fallback_used_due_to_predict_error"
        else:
            fc = _fallback_forecast(y, horizon).tolist()
            note = "fallback_used_timesfm_missing"

        return {
            "model": self.model_id,
            "horizon": horizon,
            "freq": freq,
            "forecast": fc,
            "note": note,
        }
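
    # A minimal request forecast() understands (illustrative values; "freq" is optional):
    #   await TimesFMBackend().forecast({"series": [10, 12, 13, 15], "horizon": 4, "freq": "D"})
    # returns {"model": ..., "horizon": 4, "freq": "D", "forecast": [4 floats], "note": ...}.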

    async def stream(self, request: Dict[str, Any]):
        """
        OpenAI-compatible streaming shim:
        emits exactly one chat.completion.chunk with compact JSON content.
        """
        rid = f"chatcmpl-timesfm-{int(time.time())}"
        now = int(time.time())
        payload = dict(request) if isinstance(request, dict) else {}
        try:
            result = await self.forecast(payload)
        except Exception as e:
            content = json.dumps({"error": str(e)}, separators=(",", ":"), ensure_ascii=False)
            yield {
                "id": rid,
                "object": "chat.completion.chunk",
                "created": now,
                "model": self.model_id,
                "choices": [
                    {"index": 0, "delta": {"role": "assistant", "content": content}, "finish_reason": "stop"}
                ],
            }
            return

        content = json.dumps(
            {
                "model": result["model"],
                "horizon": result["horizon"],
                "freq": result["freq"],
                "forecast": result["forecast"],
                "note": result.get("note"),
                "backend": "timesfm",
            },
            separators=(",", ":"),
            ensure_ascii=False,
        )
        yield {
            "id": rid,
            "object": "chat.completion.chunk",
            "created": now,
            "model": self.model_id,
            "choices": [
                {"index": 0, "delta": {"role": "assistant", "content": content}, "finish_reason": "stop"}
            ],
        }
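
    # Consuming the shim (hypothetical driver code):
    #   async for chunk in backend.stream({"series": [1, 2, 3], "horizon": 2}):
    #       print(chunk["choices"][0]["delta"]["content"])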

# ---------------- images stub ----------------
class StubImagesBackend(ImagesBackend):
    async def generate_b64(self, request: Dict[str, Any]) -> str:
        logger.warning("Image generation not supported in TimesFM backend.")
        # Base64 of a tiny 1x1 transparent PNG placeholder.
        return "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="