"""Метрики Text-to-SQL: Exact Match и Execution Accuracy.""" from __future__ import annotations import sqlite3 from pathlib import Path from src.models.postprocess import normalize_sql def exact_match(predicted: str, gold: str, dialect: str = "sqlite") -> bool: """Сравнение нормализованных SQL посимвольно. Грубая, но честная метрика.""" return normalize_sql(predicted, dialect) == normalize_sql(gold, dialect) def execution_accuracy( predicted_sql: str, gold_sql: str, db_path: Path | str, timeout_seconds: float = 5.0, ) -> bool: """Прогон обоих SQL на SQLite. True если результаты совпадают как множества.""" db_path = Path(db_path) conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True, timeout=timeout_seconds) try: conn.text_factory = lambda b: b.decode("utf-8", errors="replace") try: pred_rows = _run(conn, predicted_sql) except sqlite3.Error: return False try: gold_rows = _run(conn, gold_sql) except sqlite3.Error: return False return _rows_equal(pred_rows, gold_rows) finally: conn.close() def _run(conn: sqlite3.Connection, sql: str) -> list[tuple]: cur = conn.cursor() cur.execute(sql) return cur.fetchall() def _rows_equal(a: list[tuple], b: list[tuple]) -> bool: """Сравнение как мультимножеств — порядок не важен (если в SQL нет ORDER BY).""" if len(a) != len(b): return False return sorted(map(_row_key, a)) == sorted(map(_row_key, b)) def _row_key(row: tuple) -> tuple: return tuple(str(x) for x in row) def compute_metrics( predictions: list[str], golds: list[str], db_ids: list[str], databases_dir: Path | str, ) -> dict: """Прогон по всему датасету. Возвращает dict с EM, EX, и счётчиками.""" databases_dir = Path(databases_dir) n = len(predictions) assert n == len(golds) == len(db_ids), "Mismatched lengths" em_count = 0 ex_count = 0 parse_fail = 0 for pred, gold, db_id in zip(predictions, golds, db_ids): if exact_match(pred, gold): em_count += 1 db_path = databases_dir / db_id / f"{db_id}.sqlite" if not db_path.exists(): parse_fail += 1 continue if execution_accuracy(pred, gold, db_path): ex_count += 1 return { "n": n, "exact_match": em_count / n if n else 0.0, "execution_accuracy": ex_count / n if n else 0.0, "parse_fail": parse_fail, }