"""FastAPI application.

Run:
    uvicorn src.api.main:app --reload
    # Swagger UI: http://127.0.0.1:8000/docs

Endpoints:
    GET  /health        — status of the service and of the loaded model
    POST /warmup        — warm the model up with one short generation
    GET  /databases     — list of databases from data/databases (PAUQ layout)
    POST /generate-sql  — SQL generation for a PAUQ db_id
    POST /schema        — schema of an arbitrary database by connection string
    POST /query         — full pipeline for an arbitrary database
                          (generation + optional execution + guardrail)
"""
|
|
| from __future__ import annotations |
|
|
| import logging |
| import sqlite3 |
| import time |
|
|
| from fastapi import Depends, FastAPI, HTTPException |
| from fastapi.concurrency import run_in_threadpool |
|
|
| from src.api.dependencies import get_engine, get_schema_retriever, lifespan |
| from src.api.schemas import ( |
| DatabaseInfo, |
| ExecutionResult, |
| GenerateRequest, |
| GenerateResponse, |
| HealthResponse, |
| QueryRequest, |
| QueryResponse, |
| SchemaRequest, |
| SchemaResponse, |
| TablePayload, |
| ColumnPayload, |
| ) |
| from src.business.vocabulary import BusinessVocabulary |
| from src.config import settings |
| from src.data.schema import SchemaRetriever |
| from src.data.schema_provider import ConnectionSchemaProvider |
| from src.db.executor import SqlExecutor |
| from src.models.inference import InferenceEngine |
| from src.models.postprocess import is_select_only, is_valid_sql |
|
|
# Module-level logger; its name follows this module's import path.
logger = logging.getLogger(__name__)

# The ASGI application object served by uvicorn as ``src.api.main:app``.
# Startup/shutdown handling is delegated to the imported ``lifespan`` callable.
app = FastAPI(
    lifespan=lifespan,
    title="ru2sql",
    version="0.2.0",
    description="Преобразование вопросов на русском в SQL-запросы",
)
|
|
|
|
| |
| |
| |
|
|
@app.get("/health", response_model=HealthResponse)
def health(engine: InferenceEngine = Depends(get_engine)):
    """Report service liveness together with the inference engine's state.

    Args:
        engine: Inference engine injected by FastAPI through ``get_engine``.

    Returns:
        A ``HealthResponse`` whose ``status`` is always ``"ok"``; the
        ``model_loaded`` and ``base_model`` fields are copied from the
        engine's ``loaded`` and ``base_model_name`` attributes.
    """
    report = {
        "status": "ok",
        "model_loaded": engine.loaded,
        "base_model": engine.base_model_name,
    }
    return HealthResponse(**report)
|
|
|
|
@app.post("/warmup")
async def warmup(engine: InferenceEngine = Depends(get_engine)):
    """Warm the model up with a single short generation.

    The first inference of a cold model on CPU is much slower than the
    following ones: LoRA layers get loaded, the computation graph is built
    and the KV cache is filled. Calling /warmup performs one small pass with
    a minimal max_new_tokens so that production /query calls run against an
    already warmed-up model.

    Args:
        engine: Inference engine injected by FastAPI through ``get_engine``.

    Returns:
        A dict with a single key, ``"warmup_seconds"``: the elapsed time of
        the warm-up generation in seconds, rounded to two decimals.
    """
    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # (or turn negative) when the system clock is adjusted mid-request.
    started = time.perf_counter()
    schema = "CREATE TABLE t (id INT);"
    # Blocking inference runs in the thread pool to keep the event loop free.
    # Trailing positional arguments: no vocabulary (None) and 16, presumably
    # the max_new_tokens cap — confirm against InferenceEngine.generate.
    await run_in_threadpool(engine.generate, schema, "SELECT id", None, 16)
    return {"warmup_seconds": round(time.perf_counter() - started, 2)}
|
|
|
|
@app.get("/databases", response_model=list[DatabaseInfo])
def list_databases(retriever: SchemaRetriever = Depends(get_schema_retriever)):
    """List the databases known to the schema retriever with their tables.

    Databases whose tables cannot be read because of a ``FileNotFoundError``
    are left out of the result; each such skip is logged so that a missing
    database does not disappear from the listing unnoticed.

    Args:
        retriever: Schema retriever injected through ``get_schema_retriever``.

    Returns:
        One ``DatabaseInfo`` (``db_id`` plus table names) per readable
        database, in the order produced by ``retriever.list_databases()``.
    """
    databases: list[DatabaseInfo] = []
    for db_id in retriever.list_databases():
        try:
            # n_sample_rows=0: only the table names are used here.
            table_names = [
                table.name
                for table in retriever.get_tables(db_id, n_sample_rows=0)
            ]
        except FileNotFoundError as exc:
            # Best-effort listing: skip the entry, but leave a trace in logs.
            logger.warning("БД %r пропущена в /databases: %s", db_id, exc)
            continue
        databases.append(DatabaseInfo(db_id=db_id, tables=table_names))
    return databases
|
|
|
|
| |
| |
| |
|
|
@app.post("/generate-sql", response_model=GenerateResponse)
async def generate_sql(
    req: GenerateRequest,
    engine: InferenceEngine = Depends(get_engine),
    retriever: SchemaRetriever = Depends(get_schema_retriever),
):
    """Generate SQL for a database from the PAUQ layout (data/databases/{db_id}).

    Args:
        req: Request carrying ``db_id``, ``question``, an optional
            ``vocabulary`` and the ``execute`` flag.
        engine: Inference engine injected through ``get_engine``.
        retriever: Schema retriever injected through ``get_schema_retriever``.

    Returns:
        A ``GenerateResponse`` with the generated SQL, the raw model output
        and the validity flag. When ``req.execute`` is set and the SQL is
        valid, either ``execution`` (query result) or ``error`` (guardrail
        rejection or SQLite failure) is filled in as well.

    Raises:
        HTTPException: 404 when rendering the schema raises
            ``FileNotFoundError``.
    """
    try:
        schema_text = retriever.render_schema(req.db_id)
    except FileNotFoundError as e:
        raise HTTPException(status_code=404, detail=str(e)) from e

    vocab = None
    if req.vocabulary:
        vocab = BusinessVocabulary.from_dict(req.vocabulary.model_dump())

    # Blocking inference runs in the thread pool to keep the event loop free.
    generation = await run_in_threadpool(
        engine.generate, schema_text, req.question, vocab
    )
    generated_sql = generation.sql
    sql_is_valid = is_valid_sql(generated_sql)

    response = GenerateResponse(
        sql=generated_sql,
        raw_output=generation.raw_output,
        is_valid_sql=sql_is_valid,
    )

    # Execution is attempted only on explicit request and only for valid SQL.
    if not (req.execute and sql_is_valid):
        return response

    # Guardrail: anything that is not a read-only statement is never executed.
    if not is_select_only(generated_sql):
        response.error = (
            "SQL отклонён гвардейлом: разрешены только запросы SELECT и WITH."
        )
        logger.warning("Guardrail отклонил SQL: %r", generated_sql[:120])
        return response

    try:
        response.execution = await run_in_threadpool(
            _execute_sql_pauq, req.db_id, generated_sql, retriever
        )
    except sqlite3.Error as e:
        # SQLite failures are reported in the response body, not as HTTP errors.
        response.error = f"SQL execution error: {e}"

    return response
|
|
|
|
def _execute_sql_pauq(
    db_id: str,
    sql: str,
    retriever: SchemaRetriever,
    max_rows: int = 100,
) -> ExecutionResult:
    """Run ``sql`` against the SQLite file of ``db_id`` in read-only mode.

    Args:
        db_id: Database identifier, resolved to a file via
            ``retriever.db_path``.
        sql: SQL statement to execute.
        retriever: Object providing ``db_path(db_id)``.
        max_rows: Upper bound on the number of rows fetched and returned
            (defaults to the previously hard-coded 100).

    Returns:
        An ``ExecutionResult`` with the column names, at most ``max_rows``
        rows (each converted to a list) and ``row_count`` equal to the number
        of rows actually returned, not the total the query could produce.

    Raises:
        sqlite3.Error: If the database cannot be opened or the statement
            fails.
    """
    from pathlib import Path

    # Build the URI with Path.as_uri() so that characters with a special
    # meaning in URIs ('?', '#', '%', spaces) inside the path are
    # percent-encoded instead of being parsed as a query string or fragment.
    db_uri = Path(retriever.db_path(db_id)).resolve().as_uri() + "?mode=ro"
    conn = sqlite3.connect(db_uri, uri=True)
    try:
        # Decode TEXT values leniently: invalid UTF-8 bytes become U+FFFD
        # instead of raising while rows are being fetched.
        conn.text_factory = lambda raw: raw.decode("utf-8", errors="replace")
        cur = conn.cursor()
        cur.execute(sql)
        rows = cur.fetchmany(max_rows)
        # cursor.description is None for statements that return no result set.
        columns = [col[0] for col in cur.description] if cur.description else []
        return ExecutionResult(
            columns=columns,
            rows=[list(row) for row in rows],
            row_count=len(rows),
        )
    finally:
        conn.close()
|
|
|
|
| |
| |
| |
|
|
@app.post("/schema", response_model=SchemaResponse)
async def get_schema(req: SchemaRequest):
    """Return the schema of an arbitrary database for display in a client.

    Args:
        req: Request carrying ``connection_string`` and the
            ``include_samples`` flag (2 sample rows per table when set,
            none otherwise).

    Returns:
        A ``SchemaResponse`` with one ``TablePayload`` per table: its name,
        columns, sample rows and DDL text.

    Raises:
        HTTPException: 400 when creating the provider or reading the tables
            raises any exception.
    """
    try:
        provider = ConnectionSchemaProvider(req.connection_string)
        # Schema reading is blocking, so it runs in the thread pool.
        tables = await run_in_threadpool(
            provider.get_tables, 2 if req.include_samples else 0
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Ошибка чтения схемы: {e}") from e

    payload = []
    for table in tables:
        column_payloads = [
            ColumnPayload(
                name=column.name,
                type=column.type,
                nullable=column.nullable,
                primary_key=column.primary_key,
            )
            for column in table.columns
        ]
        payload.append(
            TablePayload(
                name=table.name,
                columns=column_payloads,
                sample_rows=[list(row) for row in table.sample_rows],
                ddl=table.to_ddl(),
            )
        )
    return SchemaResponse(tables=payload)
|
|
|
|
@app.post("/query", response_model=QueryResponse)
async def query(
    req: QueryRequest,
    engine: InferenceEngine = Depends(get_engine),
):
    """Full pipeline: question -> SQL -> optional execution on the database.

    Unlike /generate-sql, this works with an arbitrary database given by a
    connection string. It is used by the Streamlit client and third-party
    integrations. Before execution the SQL goes through the is_select_only
    check (referred to as section 4.4 in the original documentation).

    Args:
        req: Request carrying ``connection_string``, ``question``, an
            optional ``vocabulary`` and the ``execute`` flag.
        engine: Inference engine injected through ``get_engine``.

    Returns:
        A ``QueryResponse`` with the generated SQL, the raw model output, the
        validity flag and the generation time in seconds (two decimals).
        When ``req.execute`` is set and the SQL is valid, either
        ``execution`` or ``error`` is filled in as well.

    Raises:
        HTTPException: 400 when creating the schema provider or rendering
            the schema raises any exception.
    """
    try:
        provider = ConnectionSchemaProvider(req.connection_string)
        schema_text = await run_in_threadpool(provider.render_schema, True)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Ошибка подключения к БД: {e}") from e

    vocab = (
        BusinessVocabulary.from_dict(req.vocabulary.model_dump())
        if req.vocabulary
        else None
    )

    # perf_counter() is monotonic: the reported latency cannot be distorted
    # (or turn negative) by system clock adjustments during generation.
    started = time.perf_counter()
    result = await run_in_threadpool(engine.generate, schema_text, req.question, vocab)
    gen_time = time.perf_counter() - started

    valid = is_valid_sql(result.sql)
    response = QueryResponse(
        sql=result.sql,
        raw_output=result.raw_output,
        is_valid_sql=valid,
        gen_time_seconds=round(gen_time, 2),
    )

    # Execution is attempted only on explicit request and only for valid SQL.
    if req.execute and valid:
        # Guardrail: anything that is not a read-only statement is never run.
        if not is_select_only(result.sql):
            response.error = (
                "SQL отклонён гвардейлом: разрешены только запросы SELECT и WITH."
            )
            logger.warning("Guardrail отклонил SQL: %r", result.sql[:120])
            return response
        try:
            executor = SqlExecutor(req.connection_string)
            qr = await run_in_threadpool(executor.run, result.sql)
            if qr.success:
                response.execution = ExecutionResult(
                    columns=qr.columns, rows=qr.rows, row_count=qr.row_count,
                )
            else:
                response.error = f"SQL execution error: {qr.error}"
        except Exception as e:
            # Execution failures are reported in the response body so that the
            # generated SQL is still returned to the caller.
            response.error = f"SQL execution error: {e}"

    return response
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn, taking host and port
    # from the project settings, with auto-reload switched on.
    import uvicorn

    server_options = {
        "host": settings.api_host,
        "port": settings.api_port,
        "reload": True,
    }
    uvicorn.run("src.api.main:app", **server_options)
|
|