| """Постобработка SQL: чистка вывода модели, валидация и нормализация. |
| |
| Соответствует разделу 2.5 пояснительной записки. Pipeline: |
| raw_output ──► strip_model_artifacts ──► is_valid_sql ──► sql | "" |
| |
| Дополнительно модуль предоставляет: |
| is_select_only(sql) — AST-уровневый гвардейл против DDL/DML |
| перед выполнением сгенерированного запроса; |
| normalize_sql(sql) — каноническая форма для расчёта Exact Match |
| (совместима с evaluate_pauq.py). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| import re |
|
|
| import sqlglot |
| from sqlglot import exp |
| from sqlglot.errors import ParseError |
|
|
| logger = logging.getLogger(__name__) |
|
|
| |
| _SQL_START_KEYWORDS = ("SELECT", "WITH", "INSERT", "UPDATE", "DELETE") |
| _SQL_START_REGEX = re.compile( |
| r"\b(" + "|".join(_SQL_START_KEYWORDS) + r")\b", |
| flags=re.IGNORECASE, |
| ) |
| _FENCE_REGEX = re.compile(r"```(?:sql)?\s*(.*?)```", flags=re.DOTALL | re.IGNORECASE) |
| _PREFIX_REGEX = re.compile(r"^\s*(?:SQL|Ответ|Answer)\s*:\s*", flags=re.IGNORECASE) |
|
|
| |
| |
| |
| _VALID_ROOT_TYPES: tuple[type[exp.Expression], ...] = ( |
| exp.Select, |
| exp.With, |
| exp.Insert, |
| exp.Update, |
| exp.Delete, |
| exp.Union, |
| exp.Intersect, |
| exp.Except, |
| ) |
|
|
|
|
| def strip_model_artifacts(text: str) -> str: |
| """Очищает вывод модели от markdown и пояснений до начала SQL-запроса. |
| |
| Шаги: |
| 1. Если ответ обёрнут в ```sql ... ``` — извлекается содержимое. |
| 2. Удаляются префиксы вида «SQL:», «Ответ:», «Answer:». |
| 3. Ищется первое вхождение SQL-ключевого слова, всё до него отбрасывается. |
| 4. Берётся первый statement до первой точки с запятой включительно. |
| """ |
| fence = _FENCE_REGEX.search(text) |
| if fence: |
| text = fence.group(1) |
|
|
| text = _PREFIX_REGEX.sub("", text) |
|
|
| keyword_match = _SQL_START_REGEX.search(text) |
| if keyword_match: |
| text = text[keyword_match.start():] |
|
|
| text = text.strip() |
| if ";" in text: |
| head, _, _ = text.partition(";") |
| text = head.strip() + ";" |
|
|
| return text.strip() |
|
|
|
|
| def is_valid_sql(sql: str, dialect: str = "sqlite") -> bool: |
| """Проверяет, что строка — это валидный SQL-запрос. |
| |
| Парсится через sqlglot и дополнительно проверяется, что корень AST — |
| это один из «осмысленных» типов запроса (SELECT/WITH/INSERT/UPDATE/ |
| DELETE/UNION). Без проверки типа sqlglot принимает за SQL даже |
| случайные идентификаторы, потому что он лояльный парсер. |
| """ |
| if not sql or not sql.strip(): |
| return False |
| try: |
| parsed = sqlglot.parse_one(sql, dialect=dialect) |
| except (ParseError, ValueError, TypeError) as e: |
| logger.debug("sqlglot не смог разобрать SQL: %s", e) |
| return False |
| if parsed is None: |
| return False |
| return isinstance(parsed, _VALID_ROOT_TYPES) |
|
|
|
|
| def is_select_only(sql: str, dialect: str = "sqlite") -> bool: |
| """Возвращает True, если SQL — это SELECT (в т. ч. внутри WITH-CTE). |
| |
| Используется как guardrail перед выполнением сгенерированного запроса |
| на реальной базе данных: модель не должна получить возможность вызвать |
| DROP/UPDATE/DELETE/INSERT, даже если такие конструкции синтаксически |
| корректны. |
| """ |
| if not sql or not sql.strip(): |
| return False |
| try: |
| parsed = sqlglot.parse_one(sql, dialect=dialect) |
| except (ParseError, ValueError, TypeError): |
| return False |
| if parsed is None: |
| return False |
| if isinstance(parsed, exp.Select): |
| return True |
| if isinstance(parsed, exp.With): |
| return isinstance(parsed.this, exp.Select) |
| if isinstance(parsed, exp.Subquery): |
| return isinstance(parsed.this, exp.Select) |
| return False |
|
|
|
|
| def normalize_sql(sql: str, dialect: str = "sqlite") -> str: |
| """Каноническая форма для расчёта Exact Match. |
| |
| Использует sqlglot с флагом ``normalize=True`` — это нормализует регистр |
| ключевых слов и идентификаторов. Результат приводится к верхнему регистру, |
| чтобы EM считался идентично эталонной реализации в ``evaluate_pauq.py``. |
| """ |
| try: |
| parsed = sqlglot.parse_one(sql, dialect=dialect) |
| return parsed.sql(dialect=dialect, normalize=True).upper() |
| except (ParseError, ValueError, TypeError): |
| return re.sub(r"\s+", " ", sql.upper()).strip().rstrip(";") |
|
|
|
|
| def postprocess(raw_output: str) -> str: |
| """Полный pipeline постобработки вывода модели. |
| |
| 1. Чистка артефактов через :func:`strip_model_artifacts`. |
| 2. Валидация через :func:`is_valid_sql`. |
| 3. Возврат пустой строки при провале валидации. |
| |
| Соответствует разделу 2.5 пояснительной записки. |
| """ |
| sql = strip_model_artifacts(raw_output) |
| if not is_valid_sql(sql): |
| logger.warning("postprocess отбросил невалидный SQL: %r", sql[:120]) |
| return "" |
| return sql |
|
|