Ru2SQL / tests /test_postprocess.py
Tyycha's picture
fix bugs
cc2ed2f
"""Tests for SQL post-processing and related functions.

Covers section 2.5 of the explanatory note: cleanup of model artifacts,
validation via sqlglot, normalization for Exact Match, and the AST-level
is_select_only guardrail.
"""
from src.models.postprocess import (
is_select_only,
is_valid_sql,
normalize_sql,
postprocess,
strip_model_artifacts,
)
# ──────────────────────────────────────────────────────────────────────
# strip_model_artifacts
# ──────────────────────────────────────────────────────────────────────
def test_strip_markdown_block_with_lang():
    """Output for a ```sql fenced block must begin with SELECT (any case)."""
    fenced = "```sql\nSELECT * FROM users;\n```"
    cleaned = strip_model_artifacts(fenced)
    assert cleaned.upper().startswith("SELECT")
def test_strip_markdown_block_without_lang():
    """Output for an untagged ``` fenced block must begin with SELECT (any case)."""
    fenced = "```\nSELECT id FROM t;\n```"
    cleaned = strip_model_artifacts(fenced)
    assert cleaned.upper().startswith("SELECT")
def test_strip_sql_prefix():
    """Output for input with a leading "SQL:" label must begin with SELECT."""
    labelled = "SQL: SELECT 1;"
    cleaned = strip_model_artifacts(labelled)
    assert cleaned.upper().startswith("SELECT")
def test_strip_russian_prefix():
    """Output for input with the Russian label "Ответ:" must begin with SELECT.

    The literal is real Cyrillic ("Ответ" means "Answer"); a mis-decoded
    variant of the label would not exercise the Russian-prefix case at all.
    """
    raw = "Ответ: SELECT name FROM students;"
    assert strip_model_artifacts(raw).upper().startswith("SELECT")
def test_strip_natural_language_before_select():
    """Output for a query preceded by a Russian sentence must start at SELECT.

    The preamble reads "Here is the SQL that answers the question:". The
    negative check looks for its first word ("Вот") to confirm the leading
    prose did not survive.
    """
    raw = "Вот SQL, который отвечает на вопрос: SELECT * FROM t WHERE id = 1;"
    out = strip_model_artifacts(raw)
    assert out.upper().startswith("SELECT")
    assert "Вот" not in out
def test_keeps_first_statement_of_two():
    """Of two semicolon-separated statements, only the first must remain."""
    cleaned = strip_model_artifacts("SELECT 1; SELECT 2;")
    assert "SELECT 1" in cleaned
    assert "SELECT 2" not in cleaned
def test_with_cte_is_preserved():
    """Output for a query opening with a CTE must still begin with WITH."""
    cte_query = "WITH agg AS (SELECT id FROM t) SELECT * FROM agg"
    cleaned = strip_model_artifacts(cte_query)
    assert cleaned.upper().startswith("WITH")
def test_strip_returns_empty_on_garbage():
    """Text without any SQL keyword is returned unchanged.

    Despite the test's name, the expected result is NOT an empty string:
    there is nothing to trim, and the model has not produced an empty answer
    either, so the input is passed through as is and is expected to be
    rejected by validation further down the pipeline.
    """
    # Plain Russian prose ("just text without a query") with no SQL keywords.
    raw = "просто текст без запроса"
    assert strip_model_artifacts(raw) == raw
# ──────────────────────────────────────────────────────────────────────
# is_valid_sql
# ──────────────────────────────────────────────────────────────────────
def test_valid_select():
    """A plain SELECT with a WHERE clause must be reported as valid."""
    query = "SELECT * FROM students WHERE id = 1"
    assert is_valid_sql(query)
def test_valid_with_cte():
    """A query that starts with a CTE must be reported as valid."""
    query = "WITH x AS (SELECT id FROM t) SELECT * FROM x"
    assert is_valid_sql(query)
def test_invalid_garbage():
    """Misspelled keywords with no real structure must be reported as invalid."""
    broken = "SELEC * FRM where"
    assert not is_valid_sql(broken)
def test_invalid_empty():
    """Empty and whitespace-only strings must be reported as invalid."""
    for blank in ("", " "):
        assert not is_valid_sql(blank)
# ──────────────────────────────────────────────────────────────────────
# is_select_only β€” guardrail
# ──────────────────────────────────────────────────────────────────────
def test_select_passes_guardrail():
    """A simple SELECT must be accepted by the guardrail."""
    query = "SELECT id FROM t"
    assert is_select_only(query)
def test_with_cte_passes_guardrail():
    """A CTE followed by a SELECT must be accepted by the guardrail."""
    query = "WITH x AS (SELECT id FROM t) SELECT * FROM x"
    assert is_select_only(query)
def test_drop_table_blocked():
    """DROP TABLE must be rejected by the guardrail."""
    statement = "DROP TABLE users"
    assert not is_select_only(statement)
def test_delete_blocked():
    """DELETE must be rejected by the guardrail."""
    statement = "DELETE FROM users WHERE id = 1"
    assert not is_select_only(statement)
def test_update_blocked():
    """UPDATE must be rejected by the guardrail."""
    statement = "UPDATE users SET name = 'a' WHERE id = 1"
    assert not is_select_only(statement)
def test_insert_blocked():
    """INSERT must be rejected by the guardrail."""
    statement = "INSERT INTO users (id, name) VALUES (1, 'a')"
    assert not is_select_only(statement)
def test_empty_blocked():
    """Empty and whitespace-only strings must be rejected by the guardrail."""
    for blank in ("", " "):
        assert not is_select_only(blank)
def test_invalid_sql_blocked_by_guardrail():
    """Non-SQL text must make the guardrail return a falsy value.

    On an invalid string is_select_only is expected to answer False
    honestly instead of raising an exception.
    """
    not_sql = "not a sql at all"
    assert not is_select_only(not_sql)
# ──────────────────────────────────────────────────────────────────────
# normalize_sql
# ──────────────────────────────────────────────────────────────────────
def test_normalize_collapses_whitespace():
    """Queries differing in spacing and letter case must normalize equally.

    The first input carries runs of several spaces so that whitespace
    collapsing is actually exercised; with single spaces on both sides the
    test would only cover case folding.
    """
    a = "SELECT   *   FROM    Users"
    b = "select * from users"
    assert normalize_sql(a) == normalize_sql(b)
def test_normalize_idempotent():
    """Normalizing an already-normalized query must not change it again."""
    query = "SELECT id FROM t WHERE x = 1"
    normalized_twice = normalize_sql(normalize_sql(query))
    normalized_once = normalize_sql(query)
    assert normalized_twice == normalized_once
def test_normalize_fallback_on_invalid():
    """Non-SQL text must still yield a string that equals its own upper-case form.

    On invalid SQL the function must not raise; its fallback is expected to
    kick in instead.
    """
    result = normalize_sql("not really sql")
    assert isinstance(result, str)
    # The fallback output is expected to be upper-cased already.
    assert result == result.upper()
# ──────────────────────────────────────────────────────────────────────
# postprocess β€” ΠΏΠΎΠ»Π½Ρ‹ΠΉ pipeline
# ──────────────────────────────────────────────────────────────────────
def test_postprocess_extracts_from_markdown():
    """Only the first statement of a fenced two-statement answer must remain.

    Both checks run on the upper-cased result, so they hold whether or not
    the pipeline changes keyword or identifier case. Comparing the second
    statement case-sensitively could pass vacuously if the pipeline changed
    its letter case.
    """
    raw = "```sql\nSELECT name FROM students WHERE group_id = 1;\nSELECT 2;\n```"
    out_upper = postprocess(raw).upper()
    assert out_upper.startswith("SELECT NAME")
    assert "SELECT 2" not in out_upper
def test_postprocess_returns_empty_on_invalid():
    """A refusal sentence with no valid SQL must produce an empty string.

    The text contains no valid SQL, so the pipeline must return an empty
    string, as described in section 2.5 of the explanatory note.
    """
    # Russian for "I cannot generate SQL for this question."
    raw = "Я не могу сгенерировать SQL для этого вопроса."
    assert postprocess(raw) == ""
def test_postprocess_returns_empty_on_truncated():
    """A query cut off after WHERE must produce an empty string.

    Models the case where generation stopped in the middle of a query,
    leaving syntactically invalid SQL.
    """
    truncated = "SELECT * FROM users WHERE"
    result = postprocess(truncated)
    assert result == ""
def test_postprocess_keeps_valid_with_cte():
    """Pipeline output for a valid CTE query must still begin with WITH."""
    cte_query = "WITH agg AS (SELECT id FROM t) SELECT * FROM agg"
    result = postprocess(cte_query)
    assert result.upper().startswith("WITH")