Spaces:
Sleeping
Sleeping
| """FastAPI application β AI SQL Analyst API and frontend server.""" | |
| import logging | |
| import threading | |
| from pathlib import Path | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
| logging.basicConfig(level=logging.INFO, format="%(asctime)s %(name)s %(message)s") | |
| logger = logging.getLogger("api") | |
| app = FastAPI(title="AI SQL Analyst", version="1.0.0") | |
def _warm_caches():
    """Pre-build schema, relationship, and data-profile caches at startup.

    Runs in a background thread so the server starts instantly. Any request
    that arrives before the profile is ready gets the static business rules
    immediately (non-blocking) and the full profile on the next request.
    """
    try:
        logger.info("Cache warm-up β starting background pre-load...")
        from db.schema import format_schema
        from db.relationships import format_relationships
        import db.profiler as _profiler

        format_schema()
        logger.info("Cache warm-up β schema loaded")
        format_relationships()
        logger.info("Cache warm-up β relationships loaded")

        # Fast path: restore the profile from the persistent DB cache
        # (milliseconds) instead of rebuilding it from scratch.
        if _profiler.load_profile_from_db_cache():
            logger.info("Cache warm-up β profile loaded from DB cache (instant)")
        else:
            # First-ever deploy: no DB cache exists yet, so build it now.
            logger.info("Cache warm-up β no DB cache found, building profile...")
            _profiler._do_build()
            logger.info("Cache warm-up β profile built and saved to DB")
    except Exception as exc:
        # Warm-up is best-effort: a failure here must never block the server.
        logger.warning("Cache warm-up failed (non-fatal): %s", exc)


# Kick off cache pre-loading as soon as the module is imported.
threading.Thread(target=_warm_caches, daemon=True).start()
# ── CORS ────────────────────────────────────────────────────────────────────
# NOTE(review): wildcard CORS is wide open — acceptable for a demo, but lock
# down allow_origins before exposing this API publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],   # any origin
    allow_methods=["*"],   # any HTTP method
    allow_headers=["*"],   # any request header
)
# ── Request / Response schemas ──────────────────────────────────────────────
class QuestionRequest(BaseModel):
    """Incoming question for SQL generation or full chat."""

    question: str
    provider: str = "groq"  # LLM backend: "groq" | "openai"
    conversation_id: str | None = None  # None falls back to "default"
class GenerateSQLResponse(BaseModel):
    """SQL generated for a question, returned without executing it."""

    sql: str
class ExecuteSQLRequest(BaseModel):
    """Raw SQL submitted by the client for execution."""

    sql: str
class ExecuteSQLResponse(BaseModel):
    """Outcome of executing a SQL query."""

    sql: str
    data: list  # result rows; empty on failure
    row_count: int
    error: str | None = None  # populated when validation or execution fails
class ChatResponse(BaseModel):
    """Full pipeline result: SQL, rows, natural-language answer, insights."""

    sql: str
    data: list
    row_count: int
    answer: str
    insights: str
# ── Endpoints ───────────────────────────────────────────────────────────────
def generate_sql_endpoint(req: QuestionRequest):
    """Generate SQL for a question without executing it.

    NOTE(review): no route decorator is visible in this chunk — confirm this
    handler is registered on the app elsewhere.
    """
    from ai.pipeline import SQLAnalystPipeline

    analyst = SQLAnalystPipeline(provider=req.provider)
    return GenerateSQLResponse(sql=analyst.generate_sql_only(req.question))
def execute_sql_endpoint(req: ExecuteSQLRequest):
    """Execute a raw SQL SELECT query and return the results.

    The query is first run through the safety validator; rejected or failed
    queries come back with an ``error`` message and an empty row set.
    """
    from ai.validator import validate_sql
    from db.executor import execute_sql

    def _failure(message: str) -> ExecuteSQLResponse:
        # Uniform shape for any rejected or failed query.
        return ExecuteSQLResponse(sql=req.sql, data=[], row_count=0, error=message)

    is_safe, reason = validate_sql(req.sql)
    if not is_safe:
        return _failure(f"Query rejected: {reason}")

    outcome = execute_sql(req.sql)
    if not outcome["success"]:
        return _failure(outcome["error"])

    rows = outcome["data"]
    return ExecuteSQLResponse(sql=req.sql, data=rows, row_count=len(rows))
def chat_endpoint(req: QuestionRequest):
    """Answer a question end-to-end: generate SQL, run it, explain results.

    Recent turns from the same conversation are prepended to the question so
    the pipeline can resolve follow-ups; the finished turn is then persisted
    for future context.
    """
    from ai.pipeline import SQLAnalystPipeline
    from db.memory import get_recent_history, add_turn

    logger.info(
        "CHAT request | provider=%s | conversation_id=%s | question=%s",
        req.provider,
        req.conversation_id or "default",
        req.question,
    )
    conversation_id = req.conversation_id or "default"
    history = get_recent_history(conversation_id, limit=5)

    if not history:
        logger.info(
            "CHAT context | conversation_id=%s | history_turns=0 (no prior context used)",
            conversation_id,
        )
        question_with_context = req.question
    else:
        logger.info(
            "CHAT context | conversation_id=%s | history_turns=%d",
            conversation_id,
            len(history),
        )
        # Prepend the recent exchanges so the model can resolve follow-ups.
        context_lines = ["You are in a multi-turn conversation. Here are the recent exchanges:"]
        for turn in history:
            context_lines.append(f"User: {turn['question']}")
            context_lines.append(f"Assistant: {turn['answer']}")
        context_lines.append(f"Now the user asks: {req.question}")
        question_with_context = "\n".join(context_lines)

    result = SQLAnalystPipeline(provider=req.provider).run(question_with_context)
    logger.info(
        "CHAT result | conversation_id=%s | used_context=%s | sql_preview=%s",
        conversation_id,
        "yes" if history else "no",
        (result.get("sql") or "").replace("\n", " ")[:200],
    )

    # Persist this turn for future context (store up to 200 rows so modal can show them)
    add_turn(
        conversation_id,
        req.question,
        result["answer"],
        result["sql"],
        query_result=(result["data"][:200] if result.get("data") else None),
    )
    return ChatResponse(
        sql=result["sql"],
        data=result["data"],
        row_count=len(result.get("data") or []),
        answer=result["answer"],
        insights=result["insights"],
    )
# ── History & schema info endpoints (for debugging / transparency) ──────────
| def history_endpoint(conversation_id: str = "default"): | |
| from db.memory import get_full_history | |
| return get_full_history(conversation_id) | |
def delete_turn_endpoint(turn_id: int):
    """Remove a single turn from the conversation history."""
    from db.memory import delete_turn

    delete_turn(turn_id)
    return {"ok": True}
def history_sql_endpoint(turn_id: int):
    """Return just the SQL query for a specific history turn."""
    from db.memory import get_turn_by_id

    turn = get_turn_by_id(turn_id)
    if turn:
        return {"turn_id": turn_id, "sql": turn.get("sql_query")}
    # Unknown turn id → 404.
    from fastapi import HTTPException
    raise HTTPException(status_code=404, detail="Turn not found")
def history_result_endpoint(turn_id: int):
    """Return just the query result data for a specific history turn."""
    from db.memory import get_turn_by_id

    turn = get_turn_by_id(turn_id)
    if not turn:
        from fastapi import HTTPException
        raise HTTPException(status_code=404, detail="Turn not found")
    # query_result may be absent/None for old turns — treat as no rows.
    rows = turn.get("query_result") or []
    return {"turn_id": turn_id, "data": rows, "row_count": len(rows)}
def history_answer_endpoint(turn_id: int):
    """Return just the AI answer/explanation for a specific history turn."""
    from db.memory import get_turn_by_id

    turn = get_turn_by_id(turn_id)
    if not turn:
        from fastapi import HTTPException
        raise HTTPException(status_code=404, detail="Turn not found")
    return {
        "turn_id": turn_id,
        "question": turn.get("question"),
        "answer": turn.get("answer"),
    }
def schema_endpoint():
    """Expose the discovered database schema."""
    from db.schema import get_schema

    return get_schema()
def relationships_endpoint():
    """Expose discovered table relationships as plain dicts."""
    from db.relationships import discover_relationships

    def _as_dict(rel):
        # Flatten the relationship object for JSON serialization.
        return {
            "table_a": rel.table_a, "column_a": rel.column_a,
            "table_b": rel.table_b, "column_b": rel.column_b,
            "confidence": rel.confidence, "source": rel.source,
        }

    return [_as_dict(rel) for rel in discover_relationships()]
# ── Frontend static files ───────────────────────────────────────────────────
# Directory holding the frontend assets, located next to this file.
FRONTEND_DIR = Path(__file__).parent / "frontend"
app.mount("/static", StaticFiles(directory=str(FRONTEND_DIR)), name="static")


def serve_frontend():
    """Serve the single-page frontend's entry document.

    NOTE(review): no route decorator is visible in this chunk — confirm this
    handler is registered (typically on "/") elsewhere.
    """
    return FileResponse(str(FRONTEND_DIR / "index.html"))
# ── Run ─────────────────────────────────────────────────────────────────────
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True) | |