File size: 6,441 Bytes
cc2ed2f 8871df9 cc2ed2f 8871df9 cc2ed2f 8871df9 cc2ed2f 8871df9 cc2ed2f 8871df9 cc2ed2f 8871df9 cc2ed2f 8871df9 cc2ed2f 8871df9 cc2ed2f 8871df9 cc2ed2f 8871df9 cc2ed2f 8871df9 cc2ed2f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 | """Постобработка SQL: чистка вывода модели, валидация и нормализация.
Соответствует разделу 2.5 пояснительной записки. Pipeline:
raw_output ──► strip_model_artifacts ──► is_valid_sql ──► sql | ""
Дополнительно модуль предоставляет:
is_select_only(sql) — AST-уровневый гвардейл против DDL/DML
перед выполнением сгенерированного запроса;
normalize_sql(sql) — каноническая форма для расчёта Exact Match
(совместима с evaluate_pauq.py).
"""
from __future__ import annotations
import logging
import re
import sqlglot
from sqlglot import exp
from sqlglot.errors import ParseError
logger = logging.getLogger(__name__)
# Ключевые слова, с которых может начинаться корректный SQL-запрос.
_SQL_START_KEYWORDS = ("SELECT", "WITH", "INSERT", "UPDATE", "DELETE")
_SQL_START_REGEX = re.compile(
r"\b(" + "|".join(_SQL_START_KEYWORDS) + r")\b",
flags=re.IGNORECASE,
)
_FENCE_REGEX = re.compile(r"```(?:sql)?\s*(.*?)```", flags=re.DOTALL | re.IGNORECASE)
_PREFIX_REGEX = re.compile(r"^\s*(?:SQL|Ответ|Answer)\s*:\s*", flags=re.IGNORECASE)
# Типы AST-узлов, которые мы считаем «осмысленными» SQL-запросами.
# sqlglot — лояльный парсер: 'garbage text' он распарсит как Column/Table.
# Без проверки isinstance такие случаи будут проходить is_valid_sql.
_VALID_ROOT_TYPES: tuple[type[exp.Expression], ...] = (
exp.Select,
exp.With,
exp.Insert,
exp.Update,
exp.Delete,
exp.Union,
exp.Intersect,
exp.Except,
)
def strip_model_artifacts(text: str) -> str:
"""Очищает вывод модели от markdown и пояснений до начала SQL-запроса.
Шаги:
1. Если ответ обёрнут в ```sql ... ``` — извлекается содержимое.
2. Удаляются префиксы вида «SQL:», «Ответ:», «Answer:».
3. Ищется первое вхождение SQL-ключевого слова, всё до него отбрасывается.
4. Берётся первый statement до первой точки с запятой включительно.
"""
fence = _FENCE_REGEX.search(text)
if fence:
text = fence.group(1)
text = _PREFIX_REGEX.sub("", text)
keyword_match = _SQL_START_REGEX.search(text)
if keyword_match:
text = text[keyword_match.start():]
text = text.strip()
if ";" in text:
head, _, _ = text.partition(";")
text = head.strip() + ";"
return text.strip()
def is_valid_sql(sql: str, dialect: str = "sqlite") -> bool:
"""Проверяет, что строка — это валидный SQL-запрос.
Парсится через sqlglot и дополнительно проверяется, что корень AST —
это один из «осмысленных» типов запроса (SELECT/WITH/INSERT/UPDATE/
DELETE/UNION). Без проверки типа sqlglot принимает за SQL даже
случайные идентификаторы, потому что он лояльный парсер.
"""
if not sql or not sql.strip():
return False
try:
parsed = sqlglot.parse_one(sql, dialect=dialect)
except (ParseError, ValueError, TypeError) as e:
logger.debug("sqlglot не смог разобрать SQL: %s", e)
return False
if parsed is None:
return False
return isinstance(parsed, _VALID_ROOT_TYPES)
def is_select_only(sql: str, dialect: str = "sqlite") -> bool:
"""Возвращает True, если SQL — это SELECT (в т. ч. внутри WITH-CTE).
Используется как guardrail перед выполнением сгенерированного запроса
на реальной базе данных: модель не должна получить возможность вызвать
DROP/UPDATE/DELETE/INSERT, даже если такие конструкции синтаксически
корректны.
"""
if not sql or not sql.strip():
return False
try:
parsed = sqlglot.parse_one(sql, dialect=dialect)
except (ParseError, ValueError, TypeError):
return False
if parsed is None:
return False
if isinstance(parsed, exp.Select):
return True
if isinstance(parsed, exp.With):
return isinstance(parsed.this, exp.Select)
if isinstance(parsed, exp.Subquery):
return isinstance(parsed.this, exp.Select)
return False
def normalize_sql(sql: str, dialect: str = "sqlite") -> str:
"""Каноническая форма для расчёта Exact Match.
Использует sqlglot с флагом ``normalize=True`` — это нормализует регистр
ключевых слов и идентификаторов. Результат приводится к верхнему регистру,
чтобы EM считался идентично эталонной реализации в ``evaluate_pauq.py``.
"""
try:
parsed = sqlglot.parse_one(sql, dialect=dialect)
return parsed.sql(dialect=dialect, normalize=True).upper()
except (ParseError, ValueError, TypeError):
return re.sub(r"\s+", " ", sql.upper()).strip().rstrip(";")
def postprocess(raw_output: str) -> str:
"""Полный pipeline постобработки вывода модели.
1. Чистка артефактов через :func:`strip_model_artifacts`.
2. Валидация через :func:`is_valid_sql`.
3. Возврат пустой строки при провале валидации.
Соответствует разделу 2.5 пояснительной записки.
"""
sql = strip_model_artifacts(raw_output)
if not is_valid_sql(sql):
logger.warning("postprocess отбросил невалидный SQL: %r", sql[:120])
return ""
return sql
|