Ru2SQL / src /evaluation /metrics.py
Tyycha's picture
initial commit
8871df9
"""Метрики Text-to-SQL: Exact Match и Execution Accuracy."""
from __future__ import annotations

import sqlite3
import time
from collections import Counter
from pathlib import Path
from urllib.parse import quote

from src.models.postprocess import normalize_sql
def exact_match(predicted: str, gold: str, dialect: str = "sqlite") -> bool:
    """Return True when both queries are identical after normalization.

    Each query is passed through ``normalize_sql`` with the given ``dialect``
    and the two normalized strings are compared for equality. A crude but
    honest metric.
    """
    canonical_prediction = normalize_sql(predicted, dialect)
    canonical_reference = normalize_sql(gold, dialect)
    return canonical_prediction == canonical_reference
def execution_accuracy(
    predicted_sql: str,
    gold_sql: str,
    db_path: Path | str,
    timeout_seconds: float = 5.0,
) -> bool:
    """Run both queries on a SQLite database and compare their results.

    The database is opened read-only. The predicted query runs first; if it
    fails, the gold query is not executed at all.

    Args:
        predicted_sql: Query produced by the model.
        gold_sql: Reference query.
        db_path: Path to the SQLite database file.
        timeout_seconds: Budget applied both to waiting on database locks and
            to the running time of each query individually.

    Returns:
        True if both queries execute successfully and return the same rows
        (compared by ``_rows_equal``); False if either query raises a
        ``sqlite3.Error`` or exceeds its time budget.

    Raises:
        sqlite3.Error: If the database itself cannot be opened.
    """
    # Percent-encode the path so characters that are special inside a URI
    # ("?", "#", "%") cannot be mistaken for the query string or a fragment.
    db_uri = f"file:{quote(str(Path(db_path)))}?mode=ro"
    conn = sqlite3.connect(db_uri, uri=True, timeout=timeout_seconds)
    try:
        conn.text_factory = lambda raw: raw.decode("utf-8", errors="replace")

        # How many SQLite VM instructions run between two deadline checks.
        progress_interval = 1000

        def run_with_deadline(sql: str) -> list[tuple]:
            # connect(timeout=...) only bounds lock waits; the progress
            # handler is what actually interrupts a long-running query.
            # A truthy return value makes SQLite abort the statement, which
            # surfaces as a sqlite3.Error subclass.
            deadline = time.monotonic() + timeout_seconds
            conn.set_progress_handler(
                lambda: 1 if time.monotonic() > deadline else 0,
                progress_interval,
            )
            return _run(conn, sql)

        try:
            pred_rows = run_with_deadline(predicted_sql)
            gold_rows = run_with_deadline(gold_sql)
        except sqlite3.Error:
            return False
        return _rows_equal(pred_rows, gold_rows)
    finally:
        conn.close()
def _run(conn: sqlite3.Connection, sql: str) -> list[tuple]:
    """Execute ``sql`` on ``conn`` and return all rows fetched from the cursor."""
    return conn.execute(sql).fetchall()
def _rows_equal(a: list[tuple], b: list[tuple]) -> bool:
    """Compare two result sets as multisets of rows.

    Row order is always ignored, so the comparison is only faithful for
    queries without ORDER BY. Duplicate rows are significant: each distinct
    row must occur the same number of times on both sides.

    Values are compared with their real types rather than their string
    forms, so NULL does not match the text ``'None'`` and ``1`` does not
    match ``'1'``, while numerically equal values such as ``1`` and ``1.0``
    do match.
    """
    if len(a) != len(b):
        return False
    # Counting rows avoids sorting, which would fail on rows mixing None
    # with other types and previously forced a lossy str() conversion.
    return Counter(a) == Counter(b)
def _row_key(row: tuple) -> tuple:
    """Return a copy of ``row`` with every value converted via ``str``."""
    return tuple(map(str, row))
def compute_metrics(
    predictions: list[str],
    golds: list[str],
    db_ids: list[str],
    databases_dir: Path | str,
) -> dict:
    """Evaluate a whole dataset and return aggregate metrics.

    Args:
        predictions: Predicted SQL queries, one per sample.
        golds: Reference SQL queries, aligned with ``predictions``.
        db_ids: Database identifier for each sample; the database file is
            looked up at ``<databases_dir>/<db_id>/<db_id>.sqlite``.
        databases_dir: Directory that contains the per-database folders.

    Returns:
        A dict with the keys ``"n"`` (number of samples), ``"exact_match"``
        and ``"execution_accuracy"`` (fractions of ``n``, 0.0 when there are
        no samples) and ``"parse_fail"`` (number of samples whose database
        file does not exist; these are skipped for execution accuracy but
        still counted in ``n``).

    Raises:
        ValueError: If the three input lists differ in length.
    """
    databases_dir = Path(databases_dir)
    n = len(predictions)
    # An explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, after which zip() would silently truncate the inputs
    # and the reported metrics would be wrong.
    if not (n == len(golds) == len(db_ids)):
        raise ValueError(
            f"Mismatched lengths: predictions={n}, "
            f"golds={len(golds)}, db_ids={len(db_ids)}"
        )

    em_count = 0
    ex_count = 0
    parse_fail = 0
    for pred, gold, db_id in zip(predictions, golds, db_ids):
        if exact_match(pred, gold):
            em_count += 1

        db_path = databases_dir / db_id / f"{db_id}.sqlite"
        if not db_path.exists():
            # Without a database there is nothing to execute against.
            parse_fail += 1
            continue

        if execution_accuracy(pred, gold, db_path):
            ex_count += 1

    return {
        "n": n,
        "exact_match": em_count / n if n else 0.0,
        "execution_accuracy": ex_count / n if n else 0.0,
        "parse_fail": parse_fail,
    }