Ru2SQL / src /data /schema_provider.py
Tyycha's picture
fix bugs
cc2ed2f
"""Единый интерфейс работы со схемами баз данных.
До рефакторинга в проекте существовали две независимые иерархии:
* ``SchemaRetriever`` (``src/data/schema.py``) — читал DDL из SQLite-файлов
в Spider/PAUQ-структуре ``{databases_dir}/{db_id}/{db_id}.sqlite``.
* ``DbConnector`` (``src/db/connector.py``) — подключался к произвольной БД
по строке подключения, умел SQLite/PostgreSQL/MySQL.
Они решали одну задачу, но по-разному оформляли результат
(``TableInfo`` в каждом был свой) и не имели общего интерфейса. Этот
модуль вводит единый протокол ``SchemaProvider`` и общий dataclass
``TableSchema``. Старые классы становятся тонкими фасадами поверх
новых реализаций.
Соответствует разделам 3.4 и 4.1 пояснительной записки.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterable, Protocol
@dataclass
class ColumnSchema:
"""Описание колонки таблицы."""
name: str
type: str
nullable: bool = True
primary_key: bool = False
@dataclass
class TableSchema:
"""Унифицированное описание таблицы независимо от источника схемы.
Поле ``create_sql`` хранит исходный CREATE TABLE statement, если он
доступен (актуально для SQLite — он его сам отдаёт из ``sqlite_master``).
Когда источник схемы — PostgreSQL/MySQL, DDL генерируется из
метаданных через :meth:`to_ddl`.
"""
name: str
columns: list[ColumnSchema] = field(default_factory=list)
sample_rows: list[tuple] = field(default_factory=list)
create_sql: str | None = None
def to_ddl(self) -> str:
"""CREATE TABLE для подстановки в промпт.
Если есть оригинальный ``create_sql`` — возвращаем его, чтобы
сохранить все нюансы (ограничения, FK, AUTOINCREMENT). Иначе
собираем из метаданных колонок.
"""
if self.create_sql:
return self.create_sql.rstrip(";") + ";"
col_parts: list[str] = []
for col in self.columns:
line = f" {col.name} {col.type}"
if col.primary_key:
line += " PRIMARY KEY"
if not col.nullable:
line += " NOT NULL"
col_parts.append(line)
return f"CREATE TABLE {self.name} (\n" + ",\n".join(col_parts) + "\n);"
class SchemaProvider(Protocol):
"""Протокол любого источника схемы базы данных.
Контракт минимальный: уметь перечислить таблицы и отрендерить схему
в текст для подстановки в промпт. Этого достаточно и для PAUQ-сценария
(``SpiderSchemaProvider``), и для подключения к боевой БД пользователя
(``ConnectionSchemaProvider``).
"""
def list_tables(self) -> list[str]: ...
def get_tables(self, n_sample_rows: int = 3) -> list[TableSchema]: ...
def render_schema(self, include_samples: bool = True) -> str: ...
# ──────────────────────────────────────────────────────────────────────
# Утилита рендеринга — общая для всех реализаций
# ──────────────────────────────────────────────────────────────────────
def render_tables(tables: Iterable[TableSchema], include_samples: bool = True) -> str:
"""Собирает текстовое представление списка таблиц для промпта."""
parts: list[str] = []
for t in tables:
parts.append(t.to_ddl())
if include_samples and t.sample_rows:
parts.append(f"-- Примеры строк из {t.name}:")
for row in t.sample_rows:
parts.append(f"-- {row}")
parts.append("")
return "\n".join(parts).strip()
# ──────────────────────────────────────────────────────────────────────
# Реализация 1 — Spider/PAUQ-структура
# ──────────────────────────────────────────────────────────────────────
class SpiderSchemaProvider:
"""SchemaProvider для каталога ``data/databases/{db_id}/{db_id}.sqlite``.
Используется при работе с PAUQ/Spider: каждая БД лежит в одноимённой
папке. Один экземпляр SpiderSchemaProvider обслуживает всю коллекцию
баз — конкретная БД выбирается по ``db_id`` в методах.
"""
def __init__(self, databases_dir: Path | str):
self.databases_dir = Path(databases_dir)
def list_databases(self) -> list[str]:
if not self.databases_dir.exists():
return []
return sorted(p.name for p in self.databases_dir.iterdir() if p.is_dir())
def db_path(self, db_id: str) -> Path:
path = self.databases_dir / db_id / f"{db_id}.sqlite"
if not path.exists():
raise FileNotFoundError(f"Database file not found: {path}")
return path
def for_database(self, db_id: str) -> "ConnectionSchemaProvider":
"""Возвращает SchemaProvider, привязанный к конкретной БД."""
return ConnectionSchemaProvider(f"sqlite:///{self.db_path(db_id)}")
# ── Совместимость с предыдущим API SchemaRetriever ────────────────
def get_tables(self, db_id: str, n_sample_rows: int = 3) -> list[TableSchema]:
return self.for_database(db_id).get_tables(n_sample_rows=n_sample_rows)
def render_schema(self, db_id: str, include_samples: bool = True) -> str:
return self.for_database(db_id).render_schema(include_samples=include_samples)
# ──────────────────────────────────────────────────────────────────────
# Реализация 2 — произвольная БД по connection string
# ──────────────────────────────────────────────────────────────────────
class ConnectionSchemaProvider:
"""SchemaProvider для произвольной БД (SQLite/PostgreSQL/MySQL).
Делегирует чтение DbConnector'у, но возвращает объекты единого типа
``TableSchema``. Это нужно, чтобы один и тот же код в API и Streamlit
мог работать как с PAUQ-структурой, так и с боевой БД пользователя.
"""
def __init__(self, connection_string: str, n_sample_rows: int = 2):
# Импорт здесь, чтобы избежать кольцевой зависимости
# (db.connector → data.schema_provider в случае фасада).
from src.db.connector import DbConnector
self._connector = DbConnector(connection_string, n_sample_rows=n_sample_rows)
self.connection_string = self._connector.connection_string
# ── Базовые операции SchemaProvider ───────────────────────────────
def list_tables(self) -> list[str]:
return self._connector.list_tables()
def get_tables(self, n_sample_rows: int = 3) -> list[TableSchema]:
# DbConnector в текущей реализации использует свой n_sample_rows из ctor;
# для совместимости с протоколом — игнорируем параметр здесь, доверяя
# настройке коннектора. При желании можно завести setter.
raw = self._connector.get_schema(include_samples=n_sample_rows > 0)
return [
TableSchema(
name=t.name,
columns=[
ColumnSchema(
name=c.name, type=c.type,
nullable=c.nullable, primary_key=c.primary_key,
)
for c in t.columns
],
sample_rows=list(t.sample_rows),
)
for t in raw
]
def render_schema(self, include_samples: bool = True) -> str:
return render_tables(self.get_tables(n_sample_rows=2 if include_samples else 0),
include_samples=include_samples)
def test_connection(self) -> bool:
return self._connector.test_connection()