"""DbConnector — connect to an arbitrary database and read its schema.

Supported database types:
    SQLite     — "sqlite:///path/to/db.sqlite", or a bare file path with a
                 recognised SQLite extension (e.g. ".sqlite" or ".db")
    PostgreSQL — "postgresql://user:pass@host:port/dbname" (requires psycopg2)
    MySQL      — "mysql://user:pass@host:port/dbname" (requires pymysql)

Example:
    conn = DbConnector("sqlite:///data/demo/sales.sqlite")
    print(conn.render_schema())
    tables = conn.list_tables()
"""
|
|
from __future__ import annotations

import logging
import sqlite3
from dataclasses import dataclass, field
from pathlib import Path
from urllib.parse import quote, unquote, urlparse
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
@dataclass
class ColumnInfo:
    """Metadata describing a single column of a database table."""

    # Column name as reported by the database.
    name: str
    # Column type as reported by the database (kept as a plain string).
    type: str
    # True when the column accepts NULL values.
    nullable: bool = True
    # True when the column is part of the table's primary key.
    primary_key: bool = False
|
|
|
|
@dataclass
class TableInfo:
    """Metadata for one table: its name, its columns and optional sample rows."""

    # Table name as reported by the database.
    name: str
    # Column descriptions, in the order the reader produced them.
    columns: list[ColumnInfo] = field(default_factory=list)
    # A few raw rows from the table; empty when samples were not requested.
    sample_rows: list[tuple] = field(default_factory=list)

    def to_ddl(self) -> str:
        """Render the table metadata as a ``CREATE TABLE`` statement.

        A single-column primary key is written inline on its column. When
        several columns are flagged as primary key, one table-level
        ``PRIMARY KEY (a, b)`` constraint is emitted instead, because
        repeating ``PRIMARY KEY`` on more than one column is not valid DDL.

        Returns:
            The ``CREATE TABLE`` statement as a multi-line string.
        """
        pk_names = [col.name for col in self.columns if col.primary_key]
        inline_pk = len(pk_names) == 1

        col_parts = []
        for col in self.columns:
            line = f" {col.name} {col.type}"
            if inline_pk and col.primary_key:
                line += " PRIMARY KEY"
            if not col.nullable:
                line += " NOT NULL"
            col_parts.append(line)

        if len(pk_names) > 1:
            # Composite key: listed in column order.
            col_parts.append(f" PRIMARY KEY ({', '.join(pk_names)})")

        return f"CREATE TABLE {self.name} (\n" + ",\n".join(col_parts) + "\n);"
|
|
|
|
class DbConnector:
    """Connect to a SQLite, PostgreSQL or MySQL database and read its schema.

    The schema (tables, columns and optionally a few sample rows) is returned
    as ``TableInfo`` objects or rendered as DDL text for use in a prompt.
    Every schema read opens its own connection and closes it afterwards.
    """

    # Suffixes (matched case-insensitively) that mark a bare path as a
    # SQLite database file.
    _SQLITE_SUFFIXES = (".sqlite", ".sqlite3", ".db", ".db3")

    # Server URI schemes mapped to the internal database type.
    _SERVER_SCHEMES = {
        "postgresql": "postgresql",
        "postgres": "postgresql",
        "mysql": "mysql",
    }

    def __init__(self, connection_string: str, n_sample_rows: int = 2):
        """Store the normalized connection string and detect the DB type.

        Args:
            connection_string: A ``sqlite:``/``postgresql://``/``mysql://``
                URI, or a bare path to a SQLite file.
            n_sample_rows: How many sample rows to read per table when
                samples are requested.

        Raises:
            ValueError: If the database type cannot be determined.
        """
        self.connection_string = self._normalize(connection_string)
        self.n_sample_rows = n_sample_rows
        self._db_type = self._detect_type(self.connection_string)

    def list_tables(self) -> list[str]:
        """Return the names of all tables, without reading sample rows."""
        return [t.name for t in self._get_tables(n_sample_rows=0)]

    def get_schema(self, include_samples: bool = True) -> list[TableInfo]:
        """Return table metadata, with sample rows unless disabled."""
        return self._get_tables(
            n_sample_rows=self.n_sample_rows if include_samples else 0
        )

    def render_schema(self, include_samples: bool = True) -> str:
        """Render the schema as DDL text, one statement per table.

        When ``include_samples`` is true, each table that has sample rows is
        followed by SQL comment lines showing them.
        """
        tables = self.get_schema(include_samples=include_samples)
        parts: list[str] = []
        for t in tables:
            parts.append(t.to_ddl())
            if include_samples and t.sample_rows:
                parts.append(f"-- Примеры строк из {t.name}:")
                for row in t.sample_rows:
                    parts.append(f"-- {row}")
            parts.append("")
        return "\n".join(parts).strip()

    def test_connection(self) -> bool:
        """Return True if the schema can be read, False (with a warning) otherwise."""
        try:
            self._get_tables(n_sample_rows=0)
            return True
        except Exception as e:
            # Deliberately broad: this is a yes/no probe, any failure means "no".
            logger.warning("Подключение к БД не удалось: %s", e)
            return False

    def _get_tables(self, n_sample_rows: int) -> list[TableInfo]:
        """Dispatch to the reader that matches the detected database type.

        Raises:
            ValueError: If the stored database type is not supported.
        """
        if self._db_type == "sqlite":
            return self._get_tables_sqlite(n_sample_rows)
        elif self._db_type == "postgresql":
            return self._get_tables_postgres(n_sample_rows)
        elif self._db_type == "mysql":
            return self._get_tables_mysql(n_sample_rows)
        else:
            raise ValueError(f"Неизвестный тип БД: {self._db_type}")

    def _get_tables_sqlite(self, n_sample_rows: int) -> list[TableInfo]:
        """Read tables from SQLite through a read-only, immutable URI.

        The connection is opened with ``mode=ro&immutable=1`` (see
        ``_sqlite_uri``), so no copy of the database file is made and any
        modifying statement on this connection fails.
        """
        path = self._sqlite_path()
        conn = sqlite3.connect(self._sqlite_uri(path), uri=True)
        try:
            # Tolerate invalid UTF-8 in TEXT values instead of raising.
            conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
            cur = conn.cursor()
            cur.execute(
                "SELECT name FROM sqlite_master "
                "WHERE type='table' AND name NOT LIKE 'sqlite_%' "
                "ORDER BY name"
            )
            table_names = [r[0] for r in cur.fetchall()]
            limit = int(n_sample_rows)
            tables: list[TableInfo] = []
            for name in table_names:
                quoted = self._quote_identifier(name, '"')
                cur.execute(f"PRAGMA table_info({quoted})")
                # table_info row layout: cid, name, type, notnull, dflt_value, pk
                cols = [
                    ColumnInfo(
                        name=row[1],
                        type=row[2] or "TEXT",
                        nullable=not row[3],
                        primary_key=bool(row[5]),
                    )
                    for row in cur.fetchall()
                ]
                samples: list[tuple] = []
                if limit > 0:
                    try:
                        cur.execute(f"SELECT * FROM {quoted} LIMIT {limit}")
                        samples = cur.fetchall()
                    except sqlite3.Error as e:
                        # Samples are best-effort; keep the table's metadata.
                        logger.debug("Не удалось получить sample-строки для %s: %s",
                                     name, e)
                tables.append(TableInfo(name=name, columns=cols, sample_rows=samples))
            return tables
        finally:
            conn.close()

    def _get_tables_postgres(self, n_sample_rows: int) -> list[TableInfo]:
        """Read base tables of the ``public`` schema from PostgreSQL.

        Raises:
            ImportError: If ``psycopg2`` is not installed.
        """
        try:
            import psycopg2
        except ImportError as e:
            raise ImportError("Установи psycopg2: pip install psycopg2-binary") from e

        conn = psycopg2.connect(self.connection_string)
        try:
            cur = conn.cursor()
            cur.execute(
                "SELECT table_name FROM information_schema.tables "
                "WHERE table_schema = 'public' AND table_type = 'BASE TABLE' "
                "ORDER BY table_name"
            )
            table_names = [r[0] for r in cur.fetchall()]
            limit = int(n_sample_rows)
            tables: list[TableInfo] = []
            for name in table_names:
                # Primary-key columns, so ColumnInfo.primary_key is filled in
                # the same way as for SQLite and MySQL.
                cur.execute(
                    "SELECT kcu.column_name "
                    "FROM information_schema.table_constraints tc "
                    "JOIN information_schema.key_column_usage kcu "
                    "ON tc.constraint_name = kcu.constraint_name "
                    "AND tc.table_schema = kcu.table_schema "
                    "AND tc.table_name = kcu.table_name "
                    "WHERE tc.constraint_type = 'PRIMARY KEY' "
                    "AND tc.table_schema = 'public' AND tc.table_name = %s",
                    (name,),
                )
                pk_columns = {r[0] for r in cur.fetchall()}
                cur.execute(
                    "SELECT column_name, data_type, is_nullable "
                    "FROM information_schema.columns "
                    "WHERE table_name = %s AND table_schema = 'public' "
                    "ORDER BY ordinal_position",
                    (name,),
                )
                cols = [
                    ColumnInfo(
                        name=r[0],
                        type=r[1],
                        nullable=(r[2] == "YES"),
                        primary_key=(r[0] in pk_columns),
                    )
                    for r in cur.fetchall()
                ]
                samples: list[tuple] = []
                if limit > 0:
                    quoted = self._quote_identifier(name, '"')
                    try:
                        # Schema-qualified so search_path cannot redirect the
                        # query to a same-named table in another schema. No
                        # bound parameters here: a '%' in the table name would
                        # otherwise be read as a placeholder.
                        cur.execute(f"SELECT * FROM public.{quoted} LIMIT {limit}")
                        samples = list(cur.fetchall())
                    except psycopg2.Error as e:
                        # A failed statement aborts the transaction; roll back
                        # so the remaining tables can still be read.
                        conn.rollback()
                        logger.debug("Не удалось получить sample-строки для %s: %s",
                                     name, e)
                tables.append(TableInfo(name=name, columns=cols, sample_rows=samples))
            return tables
        finally:
            conn.close()

    def _get_tables_mysql(self, n_sample_rows: int) -> list[TableInfo]:
        """Read tables of the database named in the URI from MySQL.

        Raises:
            ImportError: If ``pymysql`` is not installed.
        """
        try:
            import pymysql
        except ImportError as e:
            raise ImportError("Установи pymysql: pip install pymysql") from e

        parsed = urlparse(self.connection_string)
        # urlparse leaves percent-escapes in place, so decode the credentials
        # and database name (e.g. "p%40ss" -> "p@ss").
        conn = pymysql.connect(
            host=parsed.hostname,
            port=parsed.port or 3306,
            user=self._unquote_optional(parsed.username),
            password=self._unquote_optional(parsed.password),
            database=unquote(parsed.path.lstrip("/")),
        )
        try:
            cur = conn.cursor()
            cur.execute("SHOW TABLES")
            table_names = [r[0] for r in cur.fetchall()]
            limit = int(n_sample_rows)
            tables: list[TableInfo] = []
            for name in table_names:
                quoted = self._quote_identifier(name, "`")
                cur.execute(f"DESCRIBE {quoted}")
                # DESCRIBE row layout: Field, Type, Null, Key, Default, Extra
                cols = [
                    ColumnInfo(
                        name=r[0], type=r[1],
                        nullable=(r[2] == "YES"),
                        primary_key=(r[3] == "PRI"),
                    )
                    for r in cur.fetchall()
                ]
                samples: list[tuple] = []
                if limit > 0:
                    try:
                        cur.execute(f"SELECT * FROM {quoted} LIMIT {limit}")
                        samples = list(cur.fetchall())
                    except pymysql.Error as e:
                        # Samples are best-effort; keep the table's metadata.
                        logger.debug("Не удалось получить sample-строки для %s: %s",
                                     name, e)
                tables.append(TableInfo(name=name, columns=cols, sample_rows=samples))
            return tables
        finally:
            conn.close()

    def _sqlite_path(self) -> Path:
        """Return the filesystem path encoded in the SQLite connection string."""
        cs = self.connection_string
        # Longest prefix first, so "sqlite:///x" is not cut as "sqlite:" + "///x".
        for prefix in ("sqlite:///", "sqlite://", "sqlite:"):
            if cs.startswith(prefix):
                return Path(cs[len(prefix):])
        return Path(cs)

    @staticmethod
    def _sqlite_uri(path: Path) -> str:
        """Build a read-only, immutable SQLite URI for ``path``.

        The path is percent-encoded so that characters with a meaning inside
        a URI ("?", "#", "%", spaces) cannot truncate or corrupt the file name.
        """
        return f"file:{quote(path.as_posix(), safe='/:')}?mode=ro&immutable=1"

    @staticmethod
    def _quote_identifier(name: str, quote_char: str) -> str:
        """Wrap ``name`` in ``quote_char``, doubling any embedded quote characters."""
        return quote_char + name.replace(quote_char, quote_char * 2) + quote_char

    @staticmethod
    def _unquote_optional(value: str | None) -> str | None:
        """Percent-decode ``value``; pass ``None`` through unchanged."""
        return None if value is None else unquote(value)

    @staticmethod
    def _server_type(cs: str) -> str | None:
        """Return the DB type for a ``scheme://`` server URI, else ``None``.

        A ``+driver`` suffix on the scheme (``mysql+pymysql://``) is ignored.
        Requiring ``://`` keeps bare file names such as ``mysql_dump.db`` from
        being mistaken for server URIs.
        """
        scheme, sep, _ = cs.partition("://")
        if not sep:
            return None
        return DbConnector._SERVER_SCHEMES.get(scheme.split("+", 1)[0])

    @staticmethod
    def _normalize(cs: str) -> str:
        """Turn a bare SQLite file path into a ``sqlite:///`` URI.

        Strings that already are a ``sqlite:`` URI or a PostgreSQL/MySQL
        ``scheme://`` URI are returned unchanged (apart from stripping
        surrounding whitespace), so a correct URI is never normalized twice.
        Anything else is returned as is and left for ``_detect_type`` to judge.
        """
        cs = cs.strip()
        if cs.startswith("sqlite:") or DbConnector._server_type(cs) is not None:
            return cs
        if cs.lower().endswith(DbConnector._SQLITE_SUFFIXES):
            return f"sqlite:///{cs}"
        return cs

    @staticmethod
    def _detect_type(cs: str) -> str:
        """Return ``"sqlite"``, ``"postgresql"`` or ``"mysql"`` for ``cs``.

        Raises:
            ValueError: If ``cs`` matches none of the supported types.
        """
        # Kept lenient on purpose: any string starting with "sqlite" has
        # always been treated as SQLite.
        if cs.startswith("sqlite"):
            return "sqlite"
        server = DbConnector._server_type(cs)
        if server is not None:
            return server
        raise ValueError(f"Не удалось определить тип БД: {cs}")
|
|