Ru2SQL / src /db /connector.py
Tyycha's picture
fix bugs
cc2ed2f
"""DbConnector — подключение к произвольной БД и чтение схемы.
Поддерживаемые типы БД:
SQLite — путь к файлу: "sqlite:///path/to/db.sqlite" или просто путь
PostgreSQL — "postgresql://user:pass@host:port/dbname" (требует psycopg2)
MySQL — "mysql://user:pass@host:port/dbname" (требует pymysql)
Пример:
conn = DbConnector("sqlite:///data/demo/sales.sqlite")
print(conn.render_schema())
tables = conn.list_tables()
"""
from __future__ import annotations
import logging
import sqlite3
from dataclasses import dataclass, field
from pathlib import Path
from urllib.parse import urlparse
logger = logging.getLogger(__name__)
@dataclass
class ColumnInfo:
name: str
type: str
nullable: bool = True
primary_key: bool = False
@dataclass
class TableInfo:
name: str
columns: list[ColumnInfo] = field(default_factory=list)
sample_rows: list[tuple] = field(default_factory=list)
def to_ddl(self) -> str:
"""Генерирует CREATE TABLE statement из метаданных."""
col_parts = []
for col in self.columns:
line = f" {col.name} {col.type}"
if col.primary_key:
line += " PRIMARY KEY"
if not col.nullable:
line += " NOT NULL"
col_parts.append(line)
return f"CREATE TABLE {self.name} (\n" + ",\n".join(col_parts) + "\n);"
class DbConnector:
"""Универсальный коннектор к БД. Читает схему для подстановки в промпт."""
def __init__(self, connection_string: str, n_sample_rows: int = 2):
self.connection_string = self._normalize(connection_string)
self.n_sample_rows = n_sample_rows
self._db_type = self._detect_type(self.connection_string)
def list_tables(self) -> list[str]:
return [t.name for t in self._get_tables(n_sample_rows=0)]
def get_schema(self, include_samples: bool = True) -> list[TableInfo]:
return self._get_tables(n_sample_rows=self.n_sample_rows if include_samples else 0)
def render_schema(self, include_samples: bool = True) -> str:
tables = self.get_schema(include_samples=include_samples)
parts: list[str] = []
for t in tables:
parts.append(t.to_ddl())
if include_samples and t.sample_rows:
parts.append(f"-- Примеры строк из {t.name}:")
for row in t.sample_rows:
parts.append(f"-- {row}")
parts.append("")
return "\n".join(parts).strip()
def test_connection(self) -> bool:
try:
self._get_tables(n_sample_rows=0)
return True
except Exception as e: # noqa: BLE001
logger.warning("Подключение к БД не удалось: %s", e)
return False
def _get_tables(self, n_sample_rows: int) -> list[TableInfo]:
if self._db_type == "sqlite":
return self._get_tables_sqlite(n_sample_rows)
elif self._db_type == "postgresql":
return self._get_tables_postgres(n_sample_rows)
elif self._db_type == "mysql":
return self._get_tables_mysql(n_sample_rows)
else:
raise ValueError(f"Неизвестный тип БД: {self._db_type}")
def _get_tables_sqlite(self, n_sample_rows: int) -> list[TableInfo]:
"""SQLite-подключение в режиме read-only через URI.
immutable=1 говорит SQLite, что файл не изменяется во время сессии,
поэтому journal/WAL-файлы можно игнорировать. Это убирает прежнюю
логику с копированием БД во временную директорию и заодно даёт
guardrail-уровень безопасности: любая модифицирующая операция
на таком соединении завершится ошибкой.
"""
path = self._sqlite_path()
conn = sqlite3.connect(self._sqlite_uri(path), uri=True)
conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
try:
cur = conn.cursor()
cur.execute(
"SELECT name FROM sqlite_master "
"WHERE type='table' AND name NOT LIKE 'sqlite_%' "
"ORDER BY name"
)
table_names = [r[0] for r in cur.fetchall()]
tables: list[TableInfo] = []
for name in table_names:
cur.execute(f'PRAGMA table_info("{name}")')
cols = [
ColumnInfo(
name=row[1],
type=row[2] or "TEXT",
nullable=not row[3],
primary_key=bool(row[5]),
)
for row in cur.fetchall()
]
samples: list[tuple] = []
if n_sample_rows > 0:
try:
cur.execute(f'SELECT * FROM "{name}" LIMIT {n_sample_rows}')
samples = cur.fetchall()
except sqlite3.Error as e:
logger.debug("Не удалось получить sample-строки для %s: %s",
name, e)
tables.append(TableInfo(name=name, columns=cols, sample_rows=samples))
return tables
finally:
conn.close()
def _get_tables_postgres(self, n_sample_rows: int) -> list[TableInfo]:
try:
import psycopg2 # type: ignore
except ImportError as e:
raise ImportError("Установи psycopg2: pip install psycopg2-binary") from e
conn = psycopg2.connect(self.connection_string)
try:
cur = conn.cursor()
cur.execute(
"SELECT table_name FROM information_schema.tables "
"WHERE table_schema = 'public' AND table_type = 'BASE TABLE' "
"ORDER BY table_name"
)
table_names = [r[0] for r in cur.fetchall()]
tables: list[TableInfo] = []
for name in table_names:
cur.execute(
"SELECT column_name, data_type, is_nullable "
"FROM information_schema.columns "
"WHERE table_name = %s AND table_schema = 'public' "
"ORDER BY ordinal_position",
(name,),
)
cols = [
ColumnInfo(name=r[0], type=r[1], nullable=(r[2] == "YES"))
for r in cur.fetchall()
]
samples: list[tuple] = []
if n_sample_rows > 0:
cur.execute(f'SELECT * FROM "{name}" LIMIT {n_sample_rows}')
samples = cur.fetchall()
tables.append(TableInfo(name=name, columns=cols, sample_rows=samples))
return tables
finally:
conn.close()
def _get_tables_mysql(self, n_sample_rows: int) -> list[TableInfo]:
try:
import pymysql # type: ignore
except ImportError as e:
raise ImportError("Установи pymysql: pip install pymysql") from e
parsed = urlparse(self.connection_string)
conn = pymysql.connect(
host=parsed.hostname,
port=parsed.port or 3306,
user=parsed.username,
password=parsed.password,
database=parsed.path.lstrip("/"),
)
try:
cur = conn.cursor()
cur.execute("SHOW TABLES")
table_names = [r[0] for r in cur.fetchall()]
tables: list[TableInfo] = []
for name in table_names:
cur.execute(f"DESCRIBE `{name}`")
cols = [
ColumnInfo(
name=r[0], type=r[1],
nullable=(r[2] == "YES"),
primary_key=(r[3] == "PRI"),
)
for r in cur.fetchall()
]
samples: list[tuple] = []
if n_sample_rows > 0:
cur.execute(f"SELECT * FROM `{name}` LIMIT {n_sample_rows}")
samples = cur.fetchall()
tables.append(TableInfo(name=name, columns=cols, sample_rows=samples))
return tables
finally:
conn.close()
def _sqlite_path(self) -> Path:
cs = self.connection_string
if cs.startswith("sqlite:///"):
return Path(cs[10:])
return Path(cs)
@staticmethod
def _sqlite_uri(path: Path) -> str:
"""Read-only URI для SQLite с игнорированием journal/WAL."""
return f"file:{path}?mode=ro&immutable=1"
@staticmethod
def _normalize(cs: str) -> str:
"""Если передан просто путь к файлу — превращаем в sqlite:// URI.
Если строка уже выглядит как URI (sqlite/postgres/mysql) —
возвращаем как есть. Без этой проверки сценарий «передали
корректный sqlite:///path» приводил к двойной нормализации
и подключению к несуществующему пути.
"""
cs = cs.strip()
if cs.startswith(("sqlite:", "postgres", "mysql")):
return cs
if cs.endswith(".sqlite") or cs.endswith(".db"):
return f"sqlite:///{cs}"
return cs
@staticmethod
def _detect_type(cs: str) -> str:
if cs.startswith("sqlite"):
return "sqlite"
if cs.startswith("postgresql") or cs.startswith("postgres"):
return "postgresql"
if cs.startswith("mysql"):
return "mysql"
raise ValueError(f"Не удалось определить тип БД: {cs}")