Tyycha commited on
Commit
8871df9
·
0 Parent(s):

initial commit

Browse files
.env.example ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Скопируй в .env и заполни. .env в git не уходит.
2
+
3
+ # API ключ для baseline-сравнения (выбери одного провайдера)
4
+ GIGACHAT_API_KEY=
5
+ OPENAI_API_KEY=
6
+ YANDEXGPT_API_KEY=
7
+ YANDEXGPT_FOLDER_ID=
8
+
9
+ # HuggingFace (нужен для скачивания приватных адаптеров)
10
+ HF_TOKEN=
11
+
12
+ # Локальная модель
13
+ BASE_MODEL_NAME=Qwen/Qwen2.5-Coder-3B-Instruct
14
+ LORA_ADAPTER_PATH=./checkpoints/qwen-coder-pauq-lora
15
+ DEVICE=cpu
16
+
17
+ # Пути
18
+ PAUQ_DATA_DIR=./data/pauq
19
+ DATABASES_DIR=./data/databases
20
+
21
+ # API
22
+ API_HOST=127.0.0.1
23
+ API_PORT=8000
.gitignore ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ *.egg-info/
8
+ .eggs/
9
+ dist/
10
+ build/
11
+
12
+ # Virtual environments
13
+ .venv/
14
+ venv/
15
+ env/
16
+ .python-version
17
+
18
+ # uv
19
+ .uv/
20
+ uv.lock
21
+
22
+ # IDE
23
+ .vscode/
24
+ .idea/
25
+ *.swp
26
+ *.swo
27
+
28
+ # Jupyter
29
+ .ipynb_checkpoints/
30
+ *.ipynb_checkpoints
31
+
32
+ # Environment variables
33
+ .env
34
+ .env.local
35
+ .env.*.local
36
+
37
+ # ML artifacts
38
+ checkpoints/
39
+ wandb/
40
+ *.bin
41
+ *.safetensors
42
+ *.gguf
43
+
44
+ # Data
45
+ data/pauq/
46
+ data/databases/
47
+ data/processed/
48
+ data/*.json
49
+ data/*.sqlite
50
+ data/*.db
51
+ # Демо-база нужна в репозитории
52
+ !data/demo/sales.sqlite
53
+
54
+ # Logs
55
+ *.log
56
+ logs/
57
+
58
+ # OS
59
+ .DS_Store
60
+ Thumbs.db
61
+ desktop.ini
62
+
63
+ # Test artifacts
64
+ .pytest_cache/
65
+ .coverage
66
+ htmlcov/
67
+
68
+ # Outputs
69
+ outputs/
70
+ results/
README.md ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Ru2SQL
3
+ emoji: 🗄️
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: streamlit
7
+ sdk_version: 1.35.0
8
+ app_file: streamlit_app.py
9
+ pinned: false
10
+ ---
11
+
12
+ # ru2sql
13
+
14
+ Генеративная модель для преобразования вопросов на русском языке в SQL-запросы.
15
+ Практическая часть ВКР, направление «Программная инженерия», 4 курс.
16
+
17
+ **Стек:** Python 3.10+, PyTorch, transformers, PEFT (LoRA), FastAPI, sqlglot.
18
+ **Основная модель:** Qwen2.5-Coder-3B-Instruct, дообученная методом QLoRA на датасете PAUQ.
19
+ **Сравнение:** ruT5-base baseline + GigaChat API.
20
+
21
+ См. `plan_VKR_text2sql_ru.md` для полного плана работ на месяц.
22
+
23
+ ---
24
+
25
+ ## Быстрый старт (на десктопе)
26
+
27
+ ### 1. Установка
28
+
29
+ ```bash
30
+ # Установи uv (https://docs.astral.sh/uv/) если ещё нет
31
+ pip install uv
32
+
33
+ # Клонируй репозиторий и установи зависимости
34
+ git clone <твой-репо> ru2sql
35
+ cd ru2sql
36
+ uv venv
37
+ .venv\Scripts\activate # Windows
38
+ # source .venv/bin/activate # Linux/Mac
39
+ uv pip install -e ".[dev]"
40
+ ```
41
+
42
+ ### 2. Конфигурация
43
+
44
+ ```bash
45
+ copy .env.example .env # Windows
46
+ # cp .env.example .env # Linux/Mac
47
+ ```
48
+
49
+ Открой `.env` и заполни ключи (минимум `GIGACHAT_API_KEY` для baseline-сравнения, остальное опционально).
50
+
51
+ ### 3. Скачай PAUQ
52
+
53
+ ```bash
54
+ git clone https://github.com/ai-forever/pauq.git data/pauq_repo
55
+ # Затем разложи train.json/dev.json/test.json в data/pauq/
56
+ # и SQLite-файлы в data/databases/{db_id}/{db_id}.sqlite
57
+ ```
58
+
59
+ ### 4. Тесты
60
+
61
+ ```bash
62
+ pytest -v
63
+ ```
64
+
65
+ Тесты для модулей `prompt`, `postprocess`, `metrics`, `schema` должны проходить
66
+ без скачивания модели и датасета.
67
+
68
+ ### 5. Запуск API
69
+
70
+ ```bash
71
+ uvicorn src.api.main:app --reload
72
+ # Swagger UI: http://127.0.0.1:8000/docs
73
+ ```
74
+
75
+ При первом запуске модель Qwen2.5-Coder-3B (~6 GB) скачается из HuggingFace Hub.
76
+ На CPU инференс занимает 15–30 секунд на запрос — это ожидаемо.
77
+
78
+ ### 6. Запрос к API
79
+
80
+ ```bash
81
+ curl -X POST http://127.0.0.1:8000/generate-sql \
82
+ -H "Content-Type: application/json" \
83
+ -d '{"question": "Сколько студентов на факультете ПИ?", "db_id": "university"}'
84
+ ```
85
+
86
+ ---
87
+
88
+ ## Обучение модели
89
+
90
+ Тренировка идёт **в Kaggle Notebook** (бесплатный T4 GPU). Локально на CPU/AMD GPU
91
+ обучить 3B-модель не получится.
92
+
93
+ Шаги:
94
+ 1. Открой `notebooks/kaggle_train_qwen_qlora.ipynb` на kaggle.com.
95
+ 2. В Settings выбери Accelerator: GPU T4 x1 (или x2 для скорости).
96
+ 3. Add-ons → Secrets → добавь `HF_TOKEN` и `WANDB_API_KEY`.
97
+ 4. Запусти все ячейки. Тренировка ~4–6 часов.
98
+ 5. По завершении адаптер пушится на твой приватный HF-репо.
99
+ 6. Скачай его на десктоп:
100
+ ```bash
101
+ huggingface-cli download your-username/qwen-coder-pauq-lora \
102
+ --local-dir checkpoints/qwen-coder-pauq-lora
103
+ ```
104
+
105
+ После этого `LORA_ADAPTER_PATH` в `.env` укажет на скачанный адаптер,
106
+ и API будет использовать дообученную модель.
107
+
108
+ ---
109
+
110
+ ## Структура проекта
111
+
112
+ ```
113
+ ru2sql/
114
+ ├── pyproject.toml # зависимости (uv)
115
+ ├── .env.example # шаблон конфигурации
116
+ ├── plan_VKR_text2sql_ru.md # план работ на месяц
117
+ ├── notebooks/
118
+ │ └── kaggle_train_qwen_qlora.ipynb
119
+ ├── src/
120
+ │ ├── config.py # настройки через pydantic-settings
121
+ │ ├── data/
122
+ │ │ ├── loader.py # чтение PAUQ JSON
123
+ │ │ ├── schema.py # SchemaRetriever (DDL из SQLite)
124
+ │ │ └── prompt.py # PromptBuilder + chat-template
125
+ │ ├── models/
126
+ │ │ ├── inference.py # InferenceEngine (модель + LoRA)
127
+ │ │ └── postprocess.py # очистка SQL + sqlglot валидация
128
+ │ ├── evaluation/
129
+ │ │ ├── metrics.py # Exact Match + Execution Accuracy
130
+ │ │ └── evaluate.py # CLI для прогона на split'е
131
+ │ └── api/
132
+ │ ├── main.py # FastAPI app
133
+ │ ├── schemas.py # Pydantic-модели
134
+ │ └── dependencies.py # lifespan + DI
135
+ └── tests/
136
+ ├── test_prompt.py
137
+ ├── test_postprocess.py
138
+ ├── test_metrics.py
139
+ └── test_schema.py
140
+ ```
141
+
142
+ ---
143
+
144
+ ## Прогон оценки
145
+
146
+ ```bash
147
+ # Полный прогон на dev split
148
+ python -m src.evaluation.evaluate --split dev
149
+
150
+ # Быстрая проверка на 50 примерах
151
+ python -m src.evaluation.evaluate --split dev --limit 50
152
+ ```
153
+
154
+ Результат сохраняется в `results/predictions.jsonl`, метрики печатаются в stdout.
155
+
156
+ ---
157
+
158
+ ## Метрики (планируемые)
159
+
160
+ | Модель | EM | Execution Accuracy |
161
+ |---|---|---|
162
+ | ruT5-base (baseline) | 25–35% | 30–40% |
163
+ | **Qwen2.5-Coder-3B + QLoRA** | **50–60%** | **55–70%** |
164
+ | GigaChat API (zero-shot) | 55–70% | 65–80% |
165
+
166
+ ---
167
+
168
+ ## Что НЕ входит в MVP
169
+
170
+ Сознательно оставлено в раздел «направления дальнейшей работы»:
171
+ - Few-shot retrieval похожих примеров.
172
+ - Schema linking (автоматический отбор релевантных таблиц).
173
+ - Self-correction по ошибкам исполнения SQL.
174
+ - Constrained decoding (грамматика SQL).
175
+ - Дообучение на синтетических данных.
176
+
177
+ ---
178
+
179
+ ## Лицензия и атрибуция
180
+
181
+ Учебный проект. Использует:
182
+ - PAUQ — Apache 2.0, https://github.com/ai-forever/pauq
183
+ - Qwen2.5-Coder — Apache 2.0, https://huggingface.co/Qwen/Qwen2.5-Coder-3B-Instruct
184
+ - ruT5 — MIT, https://huggingface.co/ai-forever/ruT5-base
adapters/qwen-coder-pauq-lora/.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
adapters/qwen-coder-pauq-lora/README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
adapters/qwen-coder-pauq-lora/adapter_config.json ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alora_invocation_tokens": null,
3
+ "alpha_pattern": {},
4
+ "arrow_config": null,
5
+ "auto_mapping": null,
6
+ "base_model_name_or_path": "Qwen/Qwen2.5-Coder-3B-Instruct",
7
+ "bias": "none",
8
+ "corda_config": null,
9
+ "ensure_weight_tying": false,
10
+ "eva_config": null,
11
+ "exclude_modules": null,
12
+ "fan_in_fan_out": false,
13
+ "inference_mode": true,
14
+ "init_lora_weights": true,
15
+ "layer_replication": null,
16
+ "layers_pattern": null,
17
+ "layers_to_transform": null,
18
+ "loftq_config": {},
19
+ "lora_alpha": 32,
20
+ "lora_bias": false,
21
+ "lora_dropout": 0.05,
22
+ "lora_ga_config": null,
23
+ "megatron_config": null,
24
+ "megatron_core": "megatron.core",
25
+ "modules_to_save": null,
26
+ "peft_type": "LORA",
27
+ "peft_version": "0.19.1",
28
+ "qalora_group_size": 16,
29
+ "r": 16,
30
+ "rank_pattern": {},
31
+ "revision": null,
32
+ "target_modules": [
33
+ "up_proj",
34
+ "v_proj",
35
+ "gate_proj",
36
+ "k_proj",
37
+ "o_proj",
38
+ "q_proj",
39
+ "down_proj"
40
+ ],
41
+ "target_parameters": null,
42
+ "task_type": "CAUSAL_LM",
43
+ "trainable_token_indices": null,
44
+ "use_bdlora": null,
45
+ "use_dora": false,
46
+ "use_qalora": false,
47
+ "use_rslora": false
48
+ }
adapters/qwen-coder-pauq-lora/chat_template.jinja ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0]['role'] == 'system' %}
4
+ {{- messages[0]['content'] }}
5
+ {%- else %}
6
+ {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}
7
+ {%- endif %}
8
+ {{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
9
+ {%- for tool in tools %}
10
+ {{- "\n" }}
11
+ {{- tool | tojson }}
12
+ {%- endfor %}
13
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
14
+ {%- else %}
15
+ {%- if messages[0]['role'] == 'system' %}
16
+ {{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
17
+ {%- else %}
18
+ {{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' }}
19
+ {%- endif %}
20
+ {%- endif %}
21
+ {%- for message in messages %}
22
+ {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
23
+ {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
24
+ {%- elif message.role == "assistant" %}
25
+ {{- '<|im_start|>' + message.role }}
26
+ {%- if message.content %}
27
+ {{- '\n' + message.content }}
28
+ {%- endif %}
29
+ {%- for tool_call in message.tool_calls %}
30
+ {%- if tool_call.function is defined %}
31
+ {%- set tool_call = tool_call.function %}
32
+ {%- endif %}
33
+ {{- '\n<tool_call>\n{"name": "' }}
34
+ {{- tool_call.name }}
35
+ {{- '", "arguments": ' }}
36
+ {{- tool_call.arguments | tojson }}
37
+ {{- '}\n</tool_call>' }}
38
+ {%- endfor %}
39
+ {{- '<|im_end|>\n' }}
40
+ {%- elif message.role == "tool" %}
41
+ {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
42
+ {{- '<|im_start|>user' }}
43
+ {%- endif %}
44
+ {{- '\n<tool_response>\n' }}
45
+ {{- message.content }}
46
+ {{- '\n</tool_response>' }}
47
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
48
+ {{- '<|im_end|>\n' }}
49
+ {%- endif %}
50
+ {%- endif %}
51
+ {%- endfor %}
52
+ {%- if add_generation_prompt %}
53
+ {{- '<|im_start|>assistant\n' }}
54
+ {%- endif %}
adapters/qwen-coder-pauq-lora/tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3fd169731d2cbde95e10bf356d66d5997fd885dd8dbb6fb4684da3f23b2585d8
3
+ size 11421892
adapters/qwen-coder-pauq-lora/tokenizer_config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "backend": "tokenizers",
4
+ "bos_token": null,
5
+ "clean_up_tokenization_spaces": false,
6
+ "eos_token": "<|im_end|>",
7
+ "errors": "replace",
8
+ "extra_special_tokens": [
9
+ "<|im_start|>",
10
+ "<|im_end|>",
11
+ "<|object_ref_start|>",
12
+ "<|object_ref_end|>",
13
+ "<|box_start|>",
14
+ "<|box_end|>",
15
+ "<|quad_start|>",
16
+ "<|quad_end|>",
17
+ "<|vision_start|>",
18
+ "<|vision_end|>",
19
+ "<|vision_pad|>",
20
+ "<|image_pad|>",
21
+ "<|video_pad|>"
22
+ ],
23
+ "is_local": false,
24
+ "local_files_only": false,
25
+ "model_max_length": 32768,
26
+ "pad_token": "<|endoftext|>",
27
+ "split_special_tokens": false,
28
+ "tokenizer_class": "Qwen2Tokenizer",
29
+ "unk_token": null
30
+ }
configs/example_vocabulary.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Бизнес-словарь компании — пример заполнения
2
+ # Скопируй этот файл, переименуй под свою компанию и заполни своими терминами.
3
+ # Путь к файлу указывается при запуске утилиты.
4
+
5
+ company: "ООО Ромашка"
6
+
7
+ # Бизнес-термины и метрики
8
+ # Ключ — слово/фраза как говорит аналитик
9
+ # Значение — что это означает в терминах SQL / данных
10
+ terms:
11
+ выручка: "SUM(orders.amount) при условии orders.status = 'paid'"
12
+ оборот: "SUM(orders.amount) по всем заказам включая отменённые"
13
+ активный клиент: "клиент, совершивший хотя бы одну покупку за последние 90 дней"
14
+ новый клиент: "клиент, зарегистрированный менее 30 дней назад"
15
+ этот год: "YEAR(order_date) совпадает с текущим годом"
16
+ прошлый месяц: "месяц предшествующий текущему"
17
+ этот квартал: "текущий квартал календарного года (Q1=янв-март, Q2=апр-июн и т.д.)"
18
+ средний чек: "AVG(orders.amount) по оплаченным заказам"
19
+ конверсия: "доля оплаченных заказов от общего числа"
20
+
21
+ # Стандартные условия фильтрации (применяются по умолчанию если аналитик явно не указал иное)
22
+ filters:
23
+ только_оплаченные: "orders.status = 'paid'"
24
+ без_возвратов: "orders.is_return = 0 или orders.is_return IS NULL"
25
+ только_активные_товары: "products.is_active = 1"
26
+
27
+ # Дополнительные правила и особенности схемы
28
+ notes:
29
+ - "Таблица orders содержит все заказы. Колонка amount — сумма в рублях."
30
+ - "Клиенты хранятся в таблице customers, товары — в products."
31
+ - "Связь заказ-товар через таблицу order_items (order_id, product_id, quantity, price)."
32
+ - "Даты хранятся в формате YYYY-MM-DD в колонке order_date."
33
+ - "Менеджеры хранятся в таблице managers, связь с заказами через orders.manager_id."
data/demo/sales.sqlite ADDED
Binary file (57.3 kB). View file
 
data/demo/sales.sqlite-journal ADDED
Binary file (512 Bytes). View file
 
data/demo/test.db ADDED
Binary file (8.19 kB). View file
 
data/demo/test.db-journal ADDED
Binary file (512 Bytes). View file
 
data/pauq_repo ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 1c4a286e30c883f9b9bd5ca59b27cee76d4544ab
notebooks/kaggle_train_qwen_qlora.ipynb ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Обучение Qwen2.5-Coder-3B на PAUQ через QLoRA\n",
8
+ "\n",
9
+ "**Где запускать:** Kaggle Notebook с GPU T4 (Settings → Accelerator → GPU T4 x2 или T4 x1).\n",
10
+ "\n",
11
+ "**Что делаем:**\n",
12
+ "1. Ставим зависимости.\n",
13
+ "2. Качаем PAUQ.\n",
14
+ "3. Загружаем Qwen2.5-Coder-3B в 4-bit.\n",
15
+ "4. Готовим датасет в chat-формате.\n",
16
+ "5. Дообучаем LoRA-адаптер через `SFTTrainer`.\n",
17
+ "6. Сохраняем адаптер локально и (опционально) пушим на HuggingFace Hub.\n",
18
+ "\n",
19
+ "**Время на T4:** ~2–3 часа на эпоху (если ~10к примеров, max_seq_length=1024)."
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "markdown",
24
+ "metadata": {},
25
+ "source": [
26
+ "## 1. Установка зависимостей"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "metadata": {},
33
+ "outputs": [],
34
+ "source": [
35
+ "!pip install -q -U \\\n",
36
+ " transformers==4.44.2 \\\n",
37
+ " peft==0.12.0 \\\n",
38
+ " accelerate==0.33.0 \\\n",
39
+ " bitsandbytes==0.43.3 \\\n",
40
+ " trl==0.10.1 \\\n",
41
+ " datasets==2.20.0 \\\n",
42
+ " sqlglot==25.5.1 \\\n",
43
+ " wandb"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "markdown",
48
+ "metadata": {},
49
+ "source": [
50
+ "## 2. Авторизация HuggingFace и W&B\n",
51
+ "\n",
52
+ "В Kaggle добавь секреты: `HF_TOKEN` и `WANDB_API_KEY` через Add-ons → Secrets. Тогда они подхватятся автоматически."
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": null,
58
+ "metadata": {},
59
+ "outputs": [],
60
+ "source": [
61
+ "import os\n",
62
+ "from kaggle_secrets import UserSecretsClient\n",
63
+ "\n",
64
+ "secrets = UserSecretsClient()\n",
65
+ "os.environ[\"HF_TOKEN\"] = secrets.get_secret(\"HF_TOKEN\")\n",
66
+ "os.environ[\"WANDB_API_KEY\"] = secrets.get_secret(\"WANDB_API_KEY\")\n",
67
+ "\n",
68
+ "from huggingface_hub import login\n",
69
+ "login(token=os.environ[\"HF_TOKEN\"])\n",
70
+ "\n",
71
+ "import wandb\n",
72
+ "wandb.login()"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "markdown",
77
+ "metadata": {},
78
+ "source": [
79
+ "## 3. Скачиваем PAUQ\n",
80
+ "\n",
81
+ "Альтернатива: загрузить PAUQ как Kaggle Dataset и подключить через `/kaggle/input/`."
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": null,
87
+ "metadata": {},
88
+ "outputs": [],
89
+ "source": [
90
+ "!git clone https://github.com/ai-forever/pauq.git /kaggle/working/pauq_repo\n",
91
+ "!ls /kaggle/working/pauq_repo"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": null,
97
+ "metadata": {},
98
+ "outputs": [],
99
+ "source": [
100
+ "import json\n",
101
+ "from pathlib import Path\n",
102
+ "\n",
103
+ "# Точные пути зависят от структуры репозитория. Найди train/dev/test файлы:\n",
104
+ "for p in Path(\"/kaggle/working/pauq_repo\").rglob(\"*.json\"):\n",
105
+ " print(p)"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": null,
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "# ОБНОВИ пути после `ls` выше\n",
115
+ "TRAIN_JSON = Path(\"/kaggle/working/pauq_repo/path/to/train.json\")\n",
116
+ "DEV_JSON = Path(\"/kaggle/working/pauq_repo/path/to/dev.json\")\n",
117
+ "DATABASES_DIR = Path(\"/kaggle/working/pauq_repo/path/to/databases\")\n",
118
+ "\n",
119
+ "with TRAIN_JSON.open() as f:\n",
120
+ " train_raw = json.load(f)\n",
121
+ "with DEV_JSON.open() as f:\n",
122
+ " dev_raw = json.load(f)\n",
123
+ "\n",
124
+ "print(f\"train: {len(train_raw)}, dev: {len(dev_raw)}\")\n",
125
+ "print(\"Пример:\", train_raw[0])"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "markdown",
130
+ "metadata": {},
131
+ "source": [
132
+ "## 4. SchemaRetriever и PromptBuilder (инлайн)\n",
133
+ "\n",
134
+ "В Kaggle нет нашего пакета `src/`, поэтому копируем минимум нужного кода прямо сюда."
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": null,
140
+ "metadata": {},
141
+ "outputs": [],
142
+ "source": [
143
+ "import sqlite3\n",
144
+ "from functools import lru_cache\n",
145
+ "\n",
146
+ "SYSTEM_PROMPT = (\n",
147
+ " \"Ты — ассистент, который преобразует вопросы на русском языке в корректные SQL-запросы. \"\n",
148
+ " \"Тебе даётся схема базы данных в виде CREATE TABLE statements и пример нескольких строк. \"\n",
149
+ " \"Сгенерируй один SQL-запрос, который отвечает на вопрос пользователя. \"\n",
150
+ " \"Возвращай ТОЛЬКО SQL без объяснений, без markdown, без префиксов.\"\n",
151
+ ")\n",
152
+ "\n",
153
+ "@lru_cache(maxsize=512)\n",
154
+ "def render_schema(db_id: str, n_samples: int = 2) -> str:\n",
155
+ " db_path = DATABASES_DIR / db_id / f\"{db_id}.sqlite\"\n",
156
+ " if not db_path.exists():\n",
157
+ " return \"\"\n",
158
+ " conn = sqlite3.connect(f\"file:{db_path}?mode=ro\", uri=True)\n",
159
+ " conn.text_factory = lambda b: b.decode(\"utf-8\", errors=\"replace\")\n",
160
+ " cur = conn.cursor()\n",
161
+ " cur.execute(\"SELECT name, sql FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'\")\n",
162
+ " parts = []\n",
163
+ " for name, ddl in cur.fetchall():\n",
164
+ " if not ddl:\n",
165
+ " continue\n",
166
+ " parts.append(ddl.strip() + \";\")\n",
167
+ " try:\n",
168
+ " cur.execute(f'SELECT * FROM \"{name}\" LIMIT {n_samples}')\n",
169
+ " rows = cur.fetchall()\n",
170
+ " for r in rows:\n",
171
+ " parts.append(f\"-- {r}\")\n",
172
+ " except sqlite3.Error:\n",
173
+ " pass\n",
174
+ " parts.append(\"\")\n",
175
+ " conn.close()\n",
176
+ " return \"\\n\".join(parts).strip()\n",
177
+ "\n",
178
+ "def build_messages(schema: str, question: str, sql: str | None = None):\n",
179
+ " user = f\"### Schema:\\n{schema}\\n\\n### Question:\\n{question}\\n\\n### SQL:\\n\"\n",
180
+ " msgs = [\n",
181
+ " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
182
+ " {\"role\": \"user\", \"content\": user},\n",
183
+ " ]\n",
184
+ " if sql is not None:\n",
185
+ " msgs.append({\"role\": \"assistant\", \"content\": sql.strip()})\n",
186
+ " return msgs"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "markdown",
191
+ "metadata": {},
192
+ "source": [
193
+ "## 5. Готовим датасет для SFT"
194
+ ]
195
+ },
196
+ {
197
+ "cell_type": "code",
198
+ "execution_count": null,
199
+ "metadata": {},
200
+ "outputs": [],
201
+ "source": [
202
+ "from datasets import Dataset\n",
203
+ "\n",
204
+ "def to_record(item):\n",
205
+ " q = item.get(\"question\") or item.get(\"question_ru\") or \"\"\n",
206
+ " sql = item.get(\"query\") or item.get(\"sql_query\") or item.get(\"sql\") or \"\"\n",
207
+ " db_id = item.get(\"db_id\") or item.get(\"database\") or \"\"\n",
208
+ " if not (q and sql and db_id):\n",
209
+ " return None\n",
210
+ " schema = render_schema(db_id)\n",
211
+ " if not schema:\n",
212
+ " return None\n",
213
+ " return {\"messages\": build_messages(schema, q.strip(), sql.strip())}\n",
214
+ "\n",
215
+ "train_records = [r for r in (to_record(x) for x in train_raw) if r]\n",
216
+ "dev_records = [r for r in (to_record(x) for x in dev_raw) if r]\n",
217
+ "print(f\"train usable: {len(train_records)}, dev usable: {len(dev_records)}\")\n",
218
+ "\n",
219
+ "train_ds = Dataset.from_list(train_records)\n",
220
+ "dev_ds = Dataset.from_list(dev_records)"
221
+ ]
222
+ },
223
+ {
224
+ "cell_type": "markdown",
225
+ "metadata": {},
226
+ "source": [
227
+ "## 6. Загружаем модель в 4-bit"
228
+ ]
229
+ },
230
+ {
231
+ "cell_type": "code",
232
+ "execution_count": null,
233
+ "metadata": {},
234
+ "outputs": [],
235
+ "source": [
236
+ "import torch\n",
237
+ "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n",
238
+ "\n",
239
+ "MODEL_NAME = \"Qwen/Qwen2.5-Coder-3B-Instruct\"\n",
240
+ "\n",
241
+ "bnb_config = BitsAndBytesConfig(\n",
242
+ " load_in_4bit=True,\n",
243
+ " bnb_4bit_quant_type=\"nf4\",\n",
244
+ " bnb_4bit_compute_dtype=torch.bfloat16,\n",
245
+ " bnb_4bit_use_double_quant=True,\n",
246
+ ")\n",
247
+ "\n",
248
+ "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
249
+ "if tokenizer.pad_token is None:\n",
250
+ " tokenizer.pad_token = tokenizer.eos_token\n",
251
+ "\n",
252
+ "model = AutoModelForCausalLM.from_pretrained(\n",
253
+ " MODEL_NAME,\n",
254
+ " quantization_config=bnb_config,\n",
255
+ " device_map=\"auto\",\n",
256
+ " torch_dtype=torch.bfloat16,\n",
257
+ ")\n",
258
+ "model.config.use_cache = False"
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "markdown",
263
+ "metadata": {},
264
+ "source": [
265
+ "## 7. Конфиг LoRA"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "code",
270
+ "execution_count": null,
271
+ "metadata": {},
272
+ "outputs": [],
273
+ "source": [
274
+ "from peft import LoraConfig, prepare_model_for_kbit_training\n",
275
+ "\n",
276
+ "model = prepare_model_for_kbit_training(model)\n",
277
+ "\n",
278
+ "lora_config = LoraConfig(\n",
279
+ " r=16,\n",
280
+ " lora_alpha=32,\n",
281
+ " lora_dropout=0.05,\n",
282
+ " bias=\"none\",\n",
283
+ " task_type=\"CAUSAL_LM\",\n",
284
+ " target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
285
+ " \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
286
+ ")"
287
+ ]
288
+ },
289
+ {
290
+ "cell_type": "markdown",
291
+ "metadata": {},
292
+ "source": [
293
+ "## 8. Тренировка через SFTTrainer"
294
+ ]
295
+ },
296
+ {
297
+ "cell_type": "code",
298
+ "execution_count": null,
299
+ "metadata": {},
300
+ "outputs": [],
301
+ "source": [
302
+ "from trl import SFTConfig, SFTTrainer\n",
303
+ "\n",
304
+ "OUTPUT_DIR = \"/kaggle/working/qwen-coder-pauq-lora\"\n",
305
+ "\n",
306
+ "sft_config = SFTConfig(\n",
307
+ " output_dir=OUTPUT_DIR,\n",
308
+ " num_train_epochs=2,\n",
309
+ " per_device_train_batch_size=1,\n",
310
+ " gradient_accumulation_steps=8,\n",
311
+ " gradient_checkpointing=True,\n",
312
+ " learning_rate=2e-4,\n",
313
+ " lr_scheduler_type=\"cosine\",\n",
314
+ " warmup_ratio=0.03,\n",
315
+ " optim=\"paged_adamw_8bit\",\n",
316
+ " bf16=True,\n",
317
+ " logging_steps=20,\n",
318
+ " save_strategy=\"epoch\",\n",
319
+ " save_total_limit=2,\n",
320
+ " eval_strategy=\"no\", # eval делаем отдельно после тренировки\n",
321
+ " max_seq_length=1024,\n",
322
+ " packing=False,\n",
323
+ " report_to=\"wandb\",\n",
324
+ " run_name=\"qwen3b-pauq-qlora\",\n",
325
+ ")\n",
326
+ "\n",
327
+ "trainer = SFTTrainer(\n",
328
+ " model=model,\n",
329
+ " tokenizer=tokenizer,\n",
330
+ " train_dataset=train_ds,\n",
331
+ " peft_config=lora_config,\n",
332
+ " args=sft_config,\n",
333
+ ")\n",
334
+ "\n",
335
+ "trainer.train()"
336
+ ]
337
+ },
338
+ {
339
+ "cell_type": "code",
340
+ "execution_count": null,
341
+ "metadata": {},
342
+ "outputs": [],
343
+ "source": [
344
+ "trainer.save_model(OUTPUT_DIR)\n",
345
+ "tokenizer.save_pretrained(OUTPUT_DIR)\n",
346
+ "print(\"Saved to\", OUTPUT_DIR)"
347
+ ]
348
+ },
349
+ {
350
+ "cell_type": "markdown",
351
+ "metadata": {},
352
+ "source": [
353
+ "## 9. Быстрая проверка inference"
354
+ ]
355
+ },
356
+ {
357
+ "cell_type": "code",
358
+ "execution_count": null,
359
+ "metadata": {},
360
+ "outputs": [],
361
+ "source": [
362
+ "model.config.use_cache = True\n",
363
+ "model.eval()\n",
364
+ "\n",
365
+ "ex = dev_records[0]\n",
366
+ "prompt_msgs = ex[\"messages\"][:2] # без assistant-ответа\n",
367
+ "prompt = tokenizer.apply_chat_template(prompt_msgs, tokenize=False, add_generation_prompt=True)\n",
368
+ "inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
369
+ "\n",
370
+ "with torch.no_grad():\n",
371
+ " out = model.generate(**inputs, max_new_tokens=256, do_sample=False,\n",
372
+ " pad_token_id=tokenizer.eos_token_id)\n",
373
+ "new_tokens = out[0][inputs[\"input_ids\"].shape[1]:]\n",
374
+ "print(\"Pred:\", tokenizer.decode(new_tokens, skip_special_tokens=True))\n",
375
+ "print(\"Gold:\", ex[\"messages\"][2][\"content\"])"
376
+ ]
377
+ },
378
+ {
379
+ "cell_type": "markdown",
380
+ "metadata": {},
381
+ "source": [
382
+ "## 10. Загрузка адаптера на HuggingFace Hub (приватный репо)"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "code",
387
+ "execution_count": null,
388
+ "metadata": {},
389
+ "outputs": [],
390
+ "source": [
391
+ "HF_REPO = \"your-username/qwen-coder-pauq-lora\" # замени на свой\n",
392
+ "\n",
393
+ "trainer.model.push_to_hub(HF_REPO, private=True)\n",
394
+ "tokenizer.push_to_hub(HF_REPO, private=True)\n",
395
+ "print(\"Pushed to\", HF_REPO)"
396
+ ]
397
+ },
398
+ {
399
+ "cell_type": "markdown",
400
+ "metadata": {},
401
+ "source": [
402
+ "## Дальше\n",
403
+ "\n",
404
+ "1. Скачай адаптер на десктоп: `huggingface-cli download your-username/qwen-coder-pauq-lora --local-dir checkpoints/qwen-coder-pauq-lora`.\n",
405
+ "2. Запусти `python -m src.evaluation.evaluate --split dev --limit 100` локально, либо запусти полный eval здесь же на Kaggle.\n",
406
+ "3. Если метрики низкие: проверь prompt format, увеличь эпохи, понизь learning rate."
407
+ ]
408
+ }
409
+ ],
410
+ "metadata": {
411
+ "kernelspec": {
412
+ "display_name": "Python 3",
413
+ "language": "python",
414
+ "name": "python3"
415
+ },
416
+ "language_info": {
417
+ "codemirror_mode": {"name": "ipython", "version": 3},
418
+ "file_extension": ".py",
419
+ "mimetype": "text/x-python",
420
+ "name": "python",
421
+ "nbconvert_exporter": "python",
422
+ "pygments_lexer": "ipython3",
423
+ "version": "3.10"
424
+ }
425
+ },
426
+ "nbformat": 4,
427
+ "nbformat_minor": 4
428
+ }
plan_VKR_text2sql_ru.md ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # План практической части ВКР: «Утилита Natural Language → SQL для бизнес-аналитики»
2
+
3
+ **Студент:** Danis, ПИ, 4 курс
4
+ **Срок:** 4 недели
5
+ **Дата:** 29 апреля 2026
6
+
7
+ ---
8
+
9
+ ## 0. Контур решения
10
+
11
+ **Финальный продукт:** утилита, которая позволяет аналитику малого и среднего бизнеса задавать вопросы на русском языке и получать готовые данные из корпоративной базы данных — без знания SQL.
12
+
13
+ Система: вопрос на русском → бизнес-словарь компании → схема БД → SQL → выполнение → результат.
14
+
15
+ Подход: fine-tuning **Qwen2.5-Coder-3B-Instruct** методом QLoRA на датасете **PAUQ**, обёрнутый в **FastAPI** с дополнительными модулями подключения к произвольной БД, настраиваемым бизнес-словарём и веб-интерфейсом на Streamlit.
16
+
17
+ Для научного сравнения параллельно прогоняется **GigaChat API** (или OpenAI) и **ruT5-base** baseline.
18
+
19
+ Инфраструктура:
20
+ - Тренировка: **Kaggle Notebooks** (T4 16 GB бесплатно).
21
+ - Разработка кода и API: **десктоп** Ryzen 5 3600X + 16 GB RAM.
22
+ - Демо на защите: **ноутбук** Ryzen 5 5500U + 16 GB RAM, инференс на CPU.
23
+
24
+ Артефакты ВКР:
25
+ - Рабочая утилита с веб-интерфейсом (Streamlit)
26
+ - Модуль подключения к произвольной БД (SQLite / PostgreSQL / MySQL)
27
+ - Модуль бизнес-словаря (YAML-конфиг с определениями метрик компании)
28
+ - Сравнительная таблица метрик (EM, Execution Accuracy)
29
+ - Анализ ошибок на 30+ примерах
30
+
31
+ ---
32
+
33
+ ## 1. Технологический стек
34
+
35
+ ### 1.1 Среда разработки
36
+
37
+ | Компонент | Выбор |
38
+ |---|---|
39
+ | Язык | Python 3.10+ |
40
+ | Менеджер пакетов | uv (быстрый, современный) |
41
+ | Контроль версий | Git + GitHub |
42
+ | IDE | VS Code |
43
+
44
+ ### 1.2 ML и обучение
45
+
46
+ | Компонент | Выбор | Где используется |
47
+ |---|---|---|
48
+ | PyTorch 2.x | основа | Kaggle |
49
+ | transformers | модели и токенизация | Kaggle + десктоп |
50
+ | peft | LoRA/QLoRA | Kaggle |
51
+ | bitsandbytes | 4-bit квантизация | Kaggle (на CPU не нужен) |
52
+ | trl | SFTTrainer | Kaggle |
53
+ | datasets | работа с PAUQ | Kaggle + десктоп |
54
+ | W&B | логирование экспериментов | Kaggle |
55
+
56
+ ### 1.3 Инференс на десктопе и ноутбуке
57
+
58
+ Для локального инференса без GPU есть два пути:
59
+
60
+ | Путь | Скорость | Сложность | Применение |
61
+ |---|---|---|---|
62
+ | transformers на CPU (int8) | 15–30 с/запрос | проще | разработка, отладка |
63
+ | llama.cpp (gguf int4) | 5–15 с/запрос | сложнее | финальное демо |
64
+
65
+ **Рекомендация:** для разработки — transformers, для защиты — llama.cpp.
66
+
67
+ ### 1.4 API и SQL
68
+
69
+ | Компонент | Выбор |
70
+ |---|---|
71
+ | FastAPI + Uvicorn | REST API |
72
+ | Pydantic v2 | валидация |
73
+ | sqlite3 (stdlib) | работа с БД из PAUQ |
74
+ | sqlglot | парсинг и валидация SQL |
75
+ | pytest | тесты |
76
+
77
+ ---
78
+
79
+ ## 2. Архитектура
80
+
81
+ ```
82
+ ┌──────────────────────────────────────────────────────────────┐
83
+ │ Streamlit Web Interface │
84
+ │ Поле вопроса | Выбор БД | Редактор бизнес-словаря │
85
+ │ Таблица результатов | История запросов │
86
+ └──────────────────────────┬───────────────────────────────────┘
87
+ │ HTTP
88
+ ┌──────────────────────────▼───────────────────────────────────┐
89
+ │ FastAPI REST API │
90
+ │ POST /query {question_ru, db_id} → {sql, result, ...} │
91
+ └──────┬──────────────┬───────────────┬─────────────────────���──┘
92
+ │ │ │
93
+ ▼ ▼ ▼
94
+ ┌────────────┐ ┌────────────┐ ┌─────────────────┐
95
+ │ DbConnector│ │ Business │ │ SchemaRetriever │
96
+ │ SQLite / │ │ Vocabulary │ │ (DDL из БД) │
97
+ │ Postgres / │ │ (YAML- │ └────────┬────────┘
98
+ │ MySQL │ │ конфиг) │ │
99
+ └─────┬──────┘ └─────┬──────┘ │
100
+ │ │ │
101
+ │ ┌────▼─────────────────▼──┐
102
+ │ │ PromptBuilder │
103
+ │ │ вопрос + схема + │
104
+ │ │ определения метрик │
105
+ │ └────────────┬────────────┘
106
+ │ ▼
107
+ │ ┌────────────────────────┐
108
+ │ │ InferenceEngine │
109
+ │ │ Qwen2.5-Coder-3B │
110
+ │ │ + LoRA adapter │
111
+ │ └────────────┬───────────┘
112
+ │ ▼
113
+ │ ┌────────────────────────┐
114
+ │ │ SqlPostProcessor │
115
+ │ │ (sqlglot validation) │
116
+ │ └────────────┬───────────┘
117
+ │ │
118
+ └──────────────────────┘
119
+ │ выполнить SQL
120
+
121
+ ┌─────────────────┐
122
+ │ SqlExecutor │
123
+ │ результат → │
124
+ │ аналитику │
125
+ └─────────────────┘
126
+ ```
127
+
128
+ Структура проекта (см. файлы в репозитории):
129
+ ```
130
+ ru2sql/
131
+ ├── README.md
132
+ ├── pyproject.toml
133
+ ├── .gitignore
134
+ ├── notebooks/
135
+ │ └── kaggle_train_qwen_qlora.ipynb
136
+ ├── src/
137
+ │ ├── config.py
138
+ │ ├── data/ — loader, schema, prompt
139
+ │ ├── models/ — inference, postprocess
140
+ │ ├── evaluation/ — metrics, evaluate
141
+ │ └── api/ — main, schemas, dependencies
142
+ ├── tests/
143
+ └── scripts/
144
+ ```
145
+
146
+ ---
147
+
148
+ ## 3. Помесячный план
149
+
150
+ ### Неделя 1. Окружение, данные, baseline
151
+
152
+ **Цель:** работающий pipeline от вопроса до SQL на маленькой модели.
153
+
154
+ | День | Задача |
155
+ |---|---|
156
+ | 1 | Установка Python 3.10+, uv, Git. Клонирование репозитория. `uv sync`. Проверка что FastAPI стартует. |
157
+ | 2 | Регистрация на Kaggle, HuggingFace, W&B. Скачивание PAUQ (https://github.com/ai-forever/pauq). |
158
+ | 3 | Анализ датасета в notebook: распределения, сложности, примеры. Реализация `SchemaRetriever`. |
159
+ | 4 | Реализация `PromptBuilder`. Тесты: `pytest tests/test_prompt.py`. |
160
+ | 5–6 | Kaggle-notebook: обучение **ruT5-base** на 2 эпохи. Сохранение чекпойнта. |
161
+ | 7 | Реализация `metrics.py` (EM + Execution Accuracy). Прогон ruT5 на dev. Запись в W&B. |
162
+
163
+ Контрольная точка недели: ruT5-base даёт 25–35% EM на PAUQ dev.
164
+
165
+ ### Неделя 2. Главная модель (Qwen2.5-Coder-3B + QLoRA)
166
+
167
+ **Цель:** обученный LoRA-адаптер для Qwen с метриками выше baseline.
168
+
169
+ | День | Задача |
170
+ |---|---|
171
+ | 1 | Kaggle-notebook: загрузка Qwen2.5-Coder-3B в 4-bit, тестовый inference. |
172
+ | 2 | Подготовка PAUQ в chat-формате под модель. |
173
+ | 3–4 | SFTTrainer + LoRA (r=16, alpha=32). Прогон 2–3 эпохи (~4–6 часов суммарно). |
174
+ | 5 | Сохранение LoRA-адаптера на HuggingFace Hub (приватный репозиторий). |
175
+ | 6 | Скачивание адаптера на десктоп. Локальный инференс на CPU через transformers. |
176
+ | 7 | Прогон на dev split, метрики, error analysis на 30 примерах. |
177
+
178
+ Контрольная точка недели: Qwen+LoRA даёт 50–60% EM на PAUQ dev и работает на десктопе.
179
+
180
+ ### Неделя 3. Бизнес-утилита: коннектор + словарь + исполнение SQL
181
+
182
+ **Цель:** превратить API в полноценную бизнес-утилиту — подключение к реальной БД, настройка под компанию, возврат данных.
183
+
184
+ | День | Задача |
185
+ |---|---|
186
+ | 1 | FastAPI: `/generate-sql`, `/query`, `/databases`, `/health`. Lifespan для загрузки модели. |
187
+ | 2 | Модуль `DbConnector` — подключение к SQLite/PostgreSQL/MySQL по строке подключения. Автоматическое чтение схемы (`INFORMATION_SCHEMA`). |
188
+ | 3 | Модуль `BusinessVocabulary` — загрузка YAML-конфига с определениями метрик. Подстановка определений в промпт перед генерацией SQL. Пример конфига: `выручка: "SUM(orders.amount) WHERE status='paid'"`. |
189
+ | 4 | Эндпоинт `/query` — принимает вопрос, генерирует SQL, выполняет на подключённой БД, возвращает результат в JSON (таблица строк). |
190
+ | 5 | Получение API-ключа GigaChat (или YandexGPT), скрипт прогона на тех же примерах. Сравнительная таблица: ruT5 vs Qwen+LoRA vs GigaChat по EM и EX. |
191
+ | 6 | `SqlPostProcessor` через sqlglot. Тесты pytest на все новые модули. |
192
+ | 7 | Создание демо-базы данных (SQLite) с реалистичными бизнес-данными: продажи, клиенты, товары. Написание бизнес-словаря под эту базу. |
193
+
194
+ Контрольная точка недели: аналитик вводит "Какая выручка за январь?" → утилита возвращает число из реальной БД.
195
+
196
+ ### Неделя 4. Streamlit-интерфейс, демо, материалы для ВКР
197
+
198
+ **Цель:** красивый рабочий продукт для защиты + готовые материалы для текста ВКР.
199
+
200
+ | День | Задача |
201
+ |---|---|
202
+ | 1 | Streamlit-интерфейс: поле ввода вопроса, выбор БД, отображение сгенерированного SQL и таблицы результатов. |
203
+ | 2 | В интерфейсе: вкладка настройки бизнес-словаря (редактирование YAML прямо в браузере). История запросов. |
204
+ | 3 | Error analysis: разбор 30 ошибок Qwen+LoRA, классификация по категориям (неверный JOIN, неверное условие WHERE и т.д.). |
205
+ | 4 | Конвертация LoRA + базовой модели в gguf через llama.cpp для быстрого инференса на CPU. |
206
+ | 5 | Диаграммы архитектуры (draw.io), скриншоты интерфейса, графики метрик (matplotlib). |
207
+ | 6 | Глава «Реализация» и глава «Практическое применение» в тексте ВКР. |
208
+ | 7 | Прогон полного сценария на ноутбуке с демо-базой. Резервная копия чекпойнта на HuggingFace. |
209
+
210
+ ---
211
+
212
+ ## 4. Метрики качества
213
+
214
+ Стандарт для Text-to-SQL:
215
+
216
+ - **Exact Match (EM)** — нормализуем оба SQL и сравниваем посимвольно.
217
+ - **Execution Accuracy (EX)** — выполняем оба SQL на реальной SQLite, сравниваем результаты как множества кортежей.
218
+
219
+ EX важнее EM, потому что разные SQL могут дать одинаковый результат.
220
+
221
+ Целевые числа на PAUQ dev (ориентировочно):
222
+ - ruT5-base: 25–35% EM, 30–40% EX.
223
+ - Qwen2.5-Coder-3B + LoRA: 50–60% EM, 55–70% EX.
224
+ - GigaChat / GPT-4 (zero-shot, через API): 55–70% EM, 65–80% EX.
225
+
226
+ Ваш Qwen после QLoRA должен быть близок к API-моделям. Это и будет защищаемый результат.
227
+
228
+ ---
229
+
230
+ ## 5. Риски и план B
231
+
232
+ | Риск | План B |
233
+ |---|---|
234
+ | Kaggle квота закончилась | Переключиться на Google Colab Free или арендовать GPU на vast.ai (~$2 за обучение) |
235
+ | Qwen-3B плохо сходится | Понизить learning rate до 1e-4, увеличить эпохи до 5, проверить prompt format |
236
+ | llama.cpp не успеваю настроить к защите | Демо через transformers на CPU напрямую — медленнее, но работает |
237
+ | GigaChat недоступен | YandexGPT либо OpenAI через VPN — Pydantic-обёртка одна, провайдер меняется одной строчкой |
238
+ | Не хватает времени на error analysis | Минимум — 20 ошибок руками, простая классификация в Excel |
239
+
240
+ ---
241
+
242
+ ## 6. Что вынести в «направления дальнейшей работы»
243
+
244
+ Эти улучшения **не делаем** в рамках месяца, но упоминаем в ВКР:
245
+ - Few-shot retrieval (поиск похожих примеров через эмбеддинги).
246
+ - Schema linking (автоматический отбор таблиц).
247
+ - Self-correction (выполнение SQL, исправление по ошибке).
248
+ - Constrained decoding (ограничение токенов до валидной SQL-грамматики).
249
+ - Дообучение на синтетических данных от GPT-4.
250
+
251
+ ---
252
+
253
+ ## 7. Итоговый чек-лист на старте
254
+
255
+ - [ ] Установлены Python 3.10+, uv, Git, VS Code на десктопе
256
+ - [ ] Создан репозиторий ru2sql на GitHub
257
+ - [ ] Зарегистрированы аккаунты Kaggle, HuggingFace, W&B
258
+ - [ ] Получен ключ GigaChat (или OpenAI)
259
+ - [ ] Скачан PAUQ
260
+ - [ ] `uv sync` проходит без ошибок
261
+ - [ ] `uvicorn src.api.main:app --reload` стартует
262
+ - [ ] Прочитаны статьи: Spider (2018), QLoRA (2023), краткое описание Qwen2.5-Coder
263
+
264
+ После чек-листа можно стартовать День 3 первой недели.
pyproject.toml ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "ru2sql"
3
+ version = "0.1.0"
4
+ description = "Russian-to-SQL generative model for graduation thesis"
5
+ authors = [{ name = "Danis", email = "[email protected]" }]
6
+ requires-python = ">=3.10,<3.13"
7
+ readme = "README.md"
8
+
9
+ dependencies = [
10
+ # API
11
+ "fastapi>=0.115.0",
12
+ "uvicorn[standard]>=0.30.0",
13
+ "pydantic>=2.7.0",
14
+ "pydantic-settings>=2.4.0",
15
+
16
+ # SQL parsing / validation
17
+ "sqlglot>=25.0.0",
18
+
19
+ # Data
20
+ "datasets>=2.20.0",
21
+ "pandas>=2.2.0",
22
+
23
+ # ML inference (CPU-friendly versions for desktop/laptop)
24
+ # Heavy training deps (bitsandbytes, peft, trl) live in [training] and run on Kaggle
25
+ "torch>=2.3.0",
26
+ "transformers>=4.44.0",
27
+ "accelerate>=0.33.0",
28
+ "peft>=0.12.0", # for loading LoRA adapter at inference time
29
+
30
+ # Misc
31
+ "python-dotenv>=1.0.0",
32
+ "httpx>=0.27.0", # for GigaChat/OpenAI API client
33
+ "tqdm>=4.66.0",
34
+
35
+ # Интерфейс
36
+ "streamlit>=1.35.0",
37
+ "pyyaml>=6.0",
38
+ ]
39
+
40
+ [project.optional-dependencies]
41
+ dev = [
42
+ "pytest>=8.3.0",
43
+ "pytest-asyncio>=0.23.0",
44
+ "ruff>=0.6.0",
45
+ "ipykernel>=6.29.0",
46
+ "matplotlib>=3.9.0",
47
+ "seaborn>=0.13.0",
48
+ ]
49
+
50
+ # Heavy GPU-only deps. Install on Kaggle: `pip install -e .[training]`
51
+ training = [
52
+ "bitsandbytes>=0.43.0",
53
+ "trl>=0.10.0",
54
+ "wandb>=0.17.0",
55
+ ]
56
+
57
+ [build-system]
58
+ requires = ["hatchling"]
59
+ build-backend = "hatchling.build"
60
+
61
+ [tool.hatch.build.targets.wheel]
62
+ packages = ["src"]
63
+
64
+ [tool.ruff]
65
+ line-length = 100
66
+ target-version = "py310"
67
+
68
+ [tool.ruff.lint]
69
+ select = ["E", "F", "W", "I", "B", "UP"]
70
+ ignore = ["E501"]
71
+
72
+ [tool.pytest.ini_options]
73
+ testpaths = ["tests"]
74
+ pythonpath = ["."]
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit>=1.35.0
2
+ torch>=2.3.0
3
+ transformers>=4.44.0
4
+ accelerate>=0.33.0
5
+ peft>=0.12.0
6
+ pydantic>=2.7.0
7
+ pydantic-settings>=2.4.0
8
+ sqlglot>=25.0.0
9
+ pandas>=2.2.0
10
+ python-dotenv>=1.0.0
11
+ huggingface_hub>=1.0.0
12
+ pyyaml>=6.0
13
+ tqdm>=4.66.0
14
+ httpx>=0.27.0
src/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """ru2sql — Russian-to-SQL generative model."""
2
+
3
+ __version__ = "0.1.0"
src/api/__init__.py ADDED
File without changes
src/api/dependencies.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI lifespan и DI: загрузка модели один раз при старте."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from contextlib import asynccontextmanager
6
+
7
+ from fastapi import FastAPI
8
+
9
+ from src.config import settings
10
+ from src.data.schema import SchemaRetriever
11
+ from src.models.inference import InferenceEngine
12
+
13
+
14
+ class AppState:
15
+ engine: InferenceEngine | None = None
16
+ schema_retriever: SchemaRetriever | None = None
17
+
18
+
19
+ state = AppState()
20
+
21
+
22
+ @asynccontextmanager
23
+ async def lifespan(app: FastAPI):
24
+ """Грузим модель при старте, освобождаем при остановке."""
25
+ state.engine = InferenceEngine()
26
+ state.engine.load()
27
+ state.schema_retriever = SchemaRetriever(settings.databases_dir)
28
+ yield
29
+ state.engine = None
30
+ state.schema_retriever = None
31
+
32
+
33
+ def get_engine() -> InferenceEngine:
34
+ if state.engine is None:
35
+ raise RuntimeError("Inference engine not initialized")
36
+ return state.engine
37
+
38
+
39
+ def get_schema_retriever() -> SchemaRetriever:
40
+ if state.schema_retriever is None:
41
+ raise RuntimeError("SchemaRetriever not initialized")
42
+ return state.schema_retriever
src/api/main.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI приложение.
2
+
3
+ Запуск:
4
+ uvicorn src.api.main:app --reload
5
+ # Swagger UI: http://127.0.0.1:8000/docs
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import sqlite3
11
+
12
+ from fastapi import Depends, FastAPI, HTTPException
13
+ from fastapi.concurrency import run_in_threadpool
14
+
15
+ from src.api.dependencies import get_engine, get_schema_retriever, lifespan
16
+ from src.api.schemas import (
17
+ DatabaseInfo,
18
+ ExecutionResult,
19
+ GenerateRequest,
20
+ GenerateResponse,
21
+ HealthResponse,
22
+ )
23
+ from src.config import settings
24
+ from src.data.schema import SchemaRetriever
25
+ from src.models.inference import InferenceEngine
26
+ from src.models.postprocess import is_valid_sql
27
+
28
+ app = FastAPI(
29
+ title="ru2sql",
30
+ description="Преобразование вопросов на русском в SQL-запросы",
31
+ version="0.1.0",
32
+ lifespan=lifespan,
33
+ )
34
+
35
+
36
+ @app.get("/health", response_model=HealthResponse)
37
+ def health(engine: InferenceEngine = Depends(get_engine)):
38
+ return HealthResponse(
39
+ status="ok",
40
+ model_loaded=engine._loaded,
41
+ base_model=engine.base_model_name,
42
+ )
43
+
44
+
45
+ @app.get("/databases", response_model=list[DatabaseInfo])
46
+ def list_databases(retriever: SchemaRetriever = Depends(get_schema_retriever)):
47
+ out: list[DatabaseInfo] = []
48
+ for db_id in retriever.list_databases():
49
+ try:
50
+ tables = [t.name for t in retriever.get_tables(db_id, n_sample_rows=0)]
51
+ out.append(DatabaseInfo(db_id=db_id, tables=tables))
52
+ except FileNotFoundError:
53
+ continue
54
+ return out
55
+
56
+
57
+ @app.post("/generate-sql", response_model=GenerateResponse)
58
+ async def generate_sql(
59
+ req: GenerateRequest,
60
+ engine: InferenceEngine = Depends(get_engine),
61
+ retriever: SchemaRetriever = Depends(get_schema_retriever),
62
+ ):
63
+ try:
64
+ schema_text = retriever.render_schema(req.db_id)
65
+ except FileNotFoundError as e:
66
+ raise HTTPException(status_code=404, detail=str(e)) from e
67
+
68
+ # Inference синхронный и тяжёлый — выносим в threadpool
69
+ result = await run_in_threadpool(engine.generate, schema_text, req.question)
70
+
71
+ valid = is_valid_sql(result.sql)
72
+ response = GenerateResponse(
73
+ sql=result.sql,
74
+ raw_output=result.raw_output,
75
+ is_valid_sql=valid,
76
+ )
77
+
78
+ if req.execute and valid:
79
+ try:
80
+ response.execution = await run_in_threadpool(
81
+ _execute_sql, req.db_id, result.sql, retriever
82
+ )
83
+ except sqlite3.Error as e:
84
+ response.error = f"SQL execution error: {e}"
85
+
86
+ return response
87
+
88
+
89
+ def _execute_sql(db_id: str, sql: str, retriever: SchemaRetriever) -> ExecutionResult:
90
+ db_path = retriever.db_path(db_id)
91
+ conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
92
+ try:
93
+ conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
94
+ cur = conn.cursor()
95
+ cur.execute(sql)
96
+ rows = cur.fetchmany(100)
97
+ cols = [d[0] for d in cur.description] if cur.description else []
98
+ return ExecutionResult(
99
+ columns=cols,
100
+ rows=[list(r) for r in rows],
101
+ row_count=len(rows),
102
+ )
103
+ finally:
104
+ conn.close()
105
+
106
+
107
+ if __name__ == "__main__":
108
+ import uvicorn
109
+
110
+ uvicorn.run("src.api.main:app", host=settings.api_host, port=settings.api_port, reload=True)
src/api/schemas.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Pydantic-модели для FastAPI endpoints."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
+ class GenerateRequest(BaseModel):
9
+ question: str = Field(..., min_length=1, max_length=2000, description="Вопрос на русском")
10
+ db_id: str = Field(..., min_length=1, description="Идентификатор БД из PAUQ")
11
+ execute: bool = Field(default=False, description="Прогнать сгенерированный SQL на БД")
12
+
13
+
14
+ class ExecutionResult(BaseModel):
15
+ columns: list[str]
16
+ rows: list[list]
17
+ row_count: int
18
+
19
+
20
+ class GenerateResponse(BaseModel):
21
+ sql: str
22
+ raw_output: str
23
+ is_valid_sql: bool
24
+ execution: ExecutionResult | None = None
25
+ error: str | None = None
26
+
27
+
28
+ class DatabaseInfo(BaseModel):
29
+ db_id: str
30
+ tables: list[str]
31
+
32
+
33
+ class HealthResponse(BaseModel):
34
+ status: str
35
+ model_loaded: bool
36
+ base_model: str
src/business/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .vocabulary import BusinessVocabulary
2
+
3
+ __all__ = ["BusinessVocabulary"]
src/business/vocabulary.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """BusinessVocabulary — настраиваемый бизнес-словарь компании.
2
+
3
+ Позволяет аналитику один раз описать бизнес-термины и метрики компании в YAML-файле,
4
+ после чего модель правильно интерпретирует их в SQL-запросах.
5
+
6
+ Пример YAML-конфига (configs/example_vocabulary.yaml):
7
+ company: "ООО Ромашка"
8
+ terms:
9
+ выручка: "SUM(orders.amount) WHERE orders.status = 'paid'"
10
+ активный клиент: "клиент, совершивший покупку за последние 90 дней"
11
+ этот год: "YEAR(order_date) = strftime('%Y', 'now')"
12
+ прошлый месяц: "strftime('%Y-%m', order_date) = strftime('%Y-%m', 'now', '-1 month')"
13
+
14
+ filters:
15
+ только_оплаченные: "orders.status = 'paid'"
16
+ без_возвратов: "orders.is_return = 0"
17
+
18
+ Пример использования:
19
+ vocab = BusinessVocabulary.from_yaml("configs/my_company.yaml")
20
+ enriched_prompt = vocab.enrich_prompt("Какая выручка за январь?")
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ from dataclasses import dataclass, field
26
+ from pathlib import Path
27
+
28
+ try:
29
+ import yaml # type: ignore
30
+ _YAML_AVAILABLE = True
31
+ except ImportError:
32
+ _YAML_AVAILABLE = False
33
+
34
+
35
+ @dataclass
36
+ class BusinessVocabulary:
37
+ """Хранит бизнес-термины и метрики компании, подставляет их в промпт модели."""
38
+
39
+ company: str = ""
40
+ terms: dict[str, str] = field(default_factory=dict)
41
+ filters: dict[str, str] = field(default_factory=dict)
42
+ notes: list[str] = field(default_factory=list)
43
+
44
+ # ------------------------------------------------------------------
45
+ # Загрузка
46
+ # ------------------------------------------------------------------
47
+
48
+ @classmethod
49
+ def from_yaml(cls, path: str | Path) -> "BusinessVocabulary":
50
+ """Загружает словарь из YAML-файла."""
51
+ if not _YAML_AVAILABLE:
52
+ raise ImportError("Установи PyYAML: pip install pyyaml")
53
+ path = Path(path)
54
+ if not path.exists():
55
+ raise FileNotFoundError(f"Файл бизнес-словаря не найден: {path}")
56
+ with open(path, encoding="utf-8") as f:
57
+ data = yaml.safe_load(f) or {}
58
+ return cls(
59
+ company=data.get("company", ""),
60
+ terms=data.get("terms", {}),
61
+ filters=data.get("filters", {}),
62
+ notes=data.get("notes", []),
63
+ )
64
+
65
+ @classmethod
66
+ def from_dict(cls, data: dict) -> "BusinessVocabulary":
67
+ """Создаёт словарь из словаря Python (удобно для API и Streamlit)."""
68
+ return cls(
69
+ company=data.get("company", ""),
70
+ terms=data.get("terms", {}),
71
+ filters=data.get("filters", {}),
72
+ notes=data.get("notes", []),
73
+ )
74
+
75
+ @classmethod
76
+ def empty(cls) -> "BusinessVocabulary":
77
+ """Пустой словарь — для случая когда компания ещё не настроила термины."""
78
+ return cls()
79
+
80
+ # ------------------------------------------------------------------
81
+ # Использование
82
+ # ------------------------------------------------------------------
83
+
84
+ def enrich_prompt(self, question: str) -> str:
85
+ """Добавляет к вопросу пользователя контекст из бизнес-словаря.
86
+
87
+ Если вопрос содержит известные термины — подставляет их определения.
88
+ Возвращает обогащённый вопрос для подстановки в промпт модели.
89
+ """
90
+ if not self.terms and not self.filters and not self.notes:
91
+ return question
92
+
93
+ context_lines: list[str] = []
94
+
95
+ # Находим термины которые упоминаются в вопросе
96
+ question_lower = question.lower()
97
+ relevant_terms = {
98
+ term: definition
99
+ for term, definition in self.terms.items()
100
+ if term.lower() in question_lower
101
+ }
102
+
103
+ if relevant_terms:
104
+ context_lines.append("Определения терминов компании:")
105
+ for term, definition in relevant_terms.items():
106
+ context_lines.append(f" - {term}: {definition}")
107
+
108
+ if self.filters:
109
+ context_lines.append("Стандартные фильтры компании:")
110
+ for name, condition in self.filters.items():
111
+ context_lines.append(f" - {name}: {condition}")
112
+
113
+ if self.notes:
114
+ context_lines.append("Дополнительные правила:")
115
+ for note in self.notes:
116
+ context_lines.append(f" - {note}")
117
+
118
+ if not context_lines:
119
+ return question
120
+
121
+ context = "\n".join(context_lines)
122
+ return f"{question}\n\n[Контекст компании]\n{context}"
123
+
124
+ def render_system_context(self) -> str:
125
+ """Текст для системного промпта — описывает все термины компании."""
126
+ if not self.terms and not self.filters and not self.notes:
127
+ return ""
128
+
129
+ lines: list[str] = []
130
+ if self.company:
131
+ lines.append(f"Компания: {self.company}")
132
+ lines.append("")
133
+
134
+ if self.terms:
135
+ lines.append("Бизнес-термины и метрики:")
136
+ for term, definition in self.terms.items():
137
+ lines.append(f" - «{term}» означает: {definition}")
138
+
139
+ if self.filters:
140
+ lines.append("")
141
+ lines.append("Стандартные условия фильтрации:")
142
+ for name, condition in self.filters.items():
143
+ lines.append(f" - {name}: {condition}")
144
+
145
+ if self.notes:
146
+ lines.append("")
147
+ lines.append("Важные правила:")
148
+ for note in self.notes:
149
+ lines.append(f" - {note}")
150
+
151
+ return "\n".join(lines)
152
+
153
+ def to_yaml_string(self) -> str:
154
+ """Сериализует словарь обратно в YAML-строку (для редактора в Streamlit)."""
155
+ if not _YAML_AVAILABLE:
156
+ raise ImportError("Установи PyYAML: pip install pyyaml")
157
+ data = {
158
+ "company": self.company,
159
+ "terms": self.terms,
160
+ "filters": self.filters,
161
+ "notes": self.notes,
162
+ }
163
+ return yaml.dump(data, allow_unicode=True, sort_keys=False, default_flow_style=False)
164
+
165
+ def save_yaml(self, path: str | Path) -> None:
166
+ """Сохраняет словарь в YAML-файл."""
167
+ path = Path(path)
168
+ path.parent.mkdir(parents=True, exist_ok=True)
169
+ with open(path, "w", encoding="utf-8") as f:
170
+ f.write(self.to_yaml_string())
171
+
172
+ def __bool__(self) -> bool:
173
+ return bool(self.terms or self.filters or self.notes)
src/config.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Конфигурация проекта. Читаем из .env через pydantic-settings."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+
7
+ from pydantic_settings import BaseSettings, SettingsConfigDict
8
+
9
+ ROOT_DIR = Path(__file__).resolve().parent.parent
10
+
11
+
12
+ class Settings(BaseSettings):
13
+ """Все настройки приложения. Значения берутся из .env, переменных окружения, либо дефолтов."""
14
+
15
+ model_config = SettingsConfigDict(
16
+ env_file=str(ROOT_DIR / ".env"),
17
+ env_file_encoding="utf-8",
18
+ extra="ignore",
19
+ )
20
+
21
+ # Модель
22
+ base_model_name: str = "Qwen/Qwen2.5-Coder-3B-Instruct"
23
+ lora_adapter_path: str = str(ROOT_DIR / "checkpoints" / "qwen-coder-pauq-lora")
24
+ device: str = "cpu" # "cpu" | "cuda" | "mps"
25
+
26
+ # Данные
27
+ pauq_data_dir: Path = ROOT_DIR / "data" / "pauq"
28
+ databases_dir: Path = ROOT_DIR / "data" / "databases"
29
+
30
+ # API ключи (используется только тот, который заполнен)
31
+ gigachat_api_key: str = ""
32
+ openai_api_key: str = ""
33
+ yandexgpt_api_key: str = ""
34
+ yandexgpt_folder_id: str = ""
35
+ hf_token: str = ""
36
+
37
+ # FastAPI
38
+ api_host: str = "127.0.0.1"
39
+ api_port: int = 8000
40
+
41
+ # Inference defaults
42
+ max_new_tokens: int = 256
43
+ temperature: float = 0.0 # для SQL детерминизм лучше
44
+ do_sample: bool = False
45
+
46
+
47
+ # Singleton-инстанс. Импортируется по всему проекту: `from src.config import settings`
48
+ settings = Settings()
src/data/__init__.py ADDED
File without changes
src/data/loader.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Загрузчик датасета PAUQ.
2
+
3
+ PAUQ распространяется в JSON-формате с полями question, query, db_id и т.д.
4
+ См. https://github.com/ai-forever/pauq
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ from dataclasses import dataclass
11
+ from pathlib import Path
12
+ from typing import Iterator
13
+
14
+
15
+ @dataclass
16
+ class PauqExample:
17
+ question: str
18
+ query: str # gold SQL
19
+ db_id: str
20
+ query_type: str | None = None # easy/medium/hard/extra если есть
21
+ raw: dict | None = None
22
+
23
+
24
+ def load_pauq_split(path: Path | str) -> list[PauqExample]:
25
+ """Читает train.json / dev.json / test.json из PAUQ."""
26
+ path = Path(path)
27
+ with path.open("r", encoding="utf-8") as f:
28
+ raw = json.load(f)
29
+
30
+ examples: list[PauqExample] = []
31
+ for item in raw:
32
+ # PAUQ имеет несколько ревизий формата; пробуем самые частые поля
33
+ question = item.get("question") or item.get("question_ru") or ""
34
+ query = item.get("query") or item.get("sql_query") or item.get("sql") or ""
35
+ db_id = item.get("db_id") or item.get("database") or ""
36
+ if not (question and query and db_id):
37
+ continue
38
+ examples.append(
39
+ PauqExample(
40
+ question=question.strip(),
41
+ query=query.strip(),
42
+ db_id=db_id.strip(),
43
+ query_type=item.get("query_type") or item.get("hardness"),
44
+ raw=item,
45
+ )
46
+ )
47
+ return examples
48
+
49
+
50
+ def iter_pauq_split(path: Path | str) -> Iterator[PauqExample]:
51
+ """Удобно при больших датасетах — генератор."""
52
+ yield from load_pauq_split(path)
src/data/prompt.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PromptBuilder — формирует input для модели в формате chat-template."""
2
+
3
+ from __future__ import annotations
4
+
5
+ SYSTEM_PROMPT = (
6
+ "Ты — ассистент, который преобразует вопросы на русском языке в корректные SQL-запросы. "
7
+ "Тебе даётся схема базы данных в виде CREATE TABLE statements и пример нескольких строк. "
8
+ "Сгенерируй один SQL-запрос, который отвечает на вопрос пользователя. "
9
+ "Возвращай ТОЛЬКО SQL без объяснений, без markdown, без префиксов."
10
+ )
11
+
12
+
13
+ def build_user_message(schema: str, question: str) -> str:
14
+ return f"### Schema:\n{schema}\n\n### Question:\n{question}\n\n### SQL:\n"
15
+
16
+
17
+ def build_chat_messages(schema: str, question: str) -> list[dict]:
18
+ """Формат для tokenizer.apply_chat_template."""
19
+ return [
20
+ {"role": "system", "content": SYSTEM_PROMPT},
21
+ {"role": "user", "content": build_user_message(schema, question)},
22
+ ]
23
+
24
+
25
+ def build_training_example(schema: str, question: str, sql: str) -> list[dict]:
26
+ """Полный диалог для SFT с ответом ассистента."""
27
+ msgs = build_chat_messages(schema, question)
28
+ msgs.append({"role": "assistant", "content": sql.strip()})
29
+ return msgs
src/data/schema.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SchemaRetriever — извлекает DDL и примеры строк из SQLite-файлов PAUQ/Spider."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import sqlite3
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+
9
+
10
+ @dataclass
11
+ class TableInfo:
12
+ name: str
13
+ create_sql: str
14
+ sample_rows: list[tuple]
15
+
16
+
17
+ class SchemaRetriever:
18
+ """Читает структуру SQLite-БД для подачи в prompt модели."""
19
+
20
+ def __init__(self, databases_dir: Path | str):
21
+ self.databases_dir = Path(databases_dir)
22
+
23
+ def db_path(self, db_id: str) -> Path:
24
+ """В Spider/PAUQ каждая БД лежит в databases_dir/{db_id}/{db_id}.sqlite."""
25
+ path = self.databases_dir / db_id / f"{db_id}.sqlite"
26
+ if not path.exists():
27
+ raise FileNotFoundError(f"Database file not found: {path}")
28
+ return path
29
+
30
+ def get_tables(self, db_id: str, n_sample_rows: int = 3) -> list[TableInfo]:
31
+ """Возвращает список таблиц с CREATE-SQL и примером строк."""
32
+ path = self.db_path(db_id)
33
+ conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True)
34
+ try:
35
+ conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
36
+ cur = conn.cursor()
37
+ cur.execute(
38
+ "SELECT name, sql FROM sqlite_master "
39
+ "WHERE type='table' AND name NOT LIKE 'sqlite_%'"
40
+ )
41
+ rows = cur.fetchall()
42
+
43
+ tables: list[TableInfo] = []
44
+ for table_name, create_sql in rows:
45
+ if not create_sql:
46
+ continue
47
+ try:
48
+ cur.execute(f'SELECT * FROM "{table_name}" LIMIT {n_sample_rows}')
49
+ samples = cur.fetchall()
50
+ except sqlite3.Error:
51
+ samples = []
52
+ tables.append(
53
+ TableInfo(name=table_name, create_sql=create_sql.strip(), sample_rows=samples)
54
+ )
55
+ return tables
56
+ finally:
57
+ conn.close()
58
+
59
+ def render_schema(self, db_id: str, include_samples: bool = True) -> str:
60
+ """Текстовое представление схемы для prompt'а."""
61
+ tables = self.get_tables(db_id)
62
+ parts: list[str] = []
63
+ for t in tables:
64
+ parts.append(t.create_sql + ";")
65
+ if include_samples and t.sample_rows:
66
+ parts.append(f"-- Примеры строк из {t.name}:")
67
+ for row in t.sample_rows:
68
+ parts.append(f"-- {row}")
69
+ parts.append("")
70
+ return "\n".join(parts).strip()
71
+
72
+ def list_databases(self) -> list[str]:
73
+ """Список доступных db_id."""
74
+ if not self.databases_dir.exists():
75
+ return []
76
+ return sorted(p.name for p in self.databases_dir.iterdir() if p.is_dir())
src/db/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .connector import DbConnector
2
+ from .executor import SqlExecutor, QueryResult
3
+
4
+ __all__ = ["DbConnector", "SqlExecutor", "QueryResult"]
src/db/connector.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DbConnector -- podklyuchenie k proizvolnoy baze dannykh i chtenie skhemy.
2
+
3
+ Podderzhivaemye tipy BD:
4
+ SQLite -- put k faylu: "sqlite:///path/to/db.sqlite" ili prosto put
5
+ PostgreSQL -- "postgresql://user:pass@host:port/dbname" (trebuet psycopg2)
6
+ MySQL -- "mysql://user:pass@host:port/dbname" (trebuet pymysql)
7
+
8
+ Primer:
9
+ conn = DbConnector("sqlite:///data/demo/sales.sqlite")
10
+ print(conn.render_schema())
11
+ tables = conn.list_tables()
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import sqlite3
17
+ from dataclasses import dataclass, field
18
+ from pathlib import Path
19
+ from urllib.parse import urlparse
20
+
21
+
22
+ @dataclass
23
+ class ColumnInfo:
24
+ name: str
25
+ type: str
26
+ nullable: bool = True
27
+ primary_key: bool = False
28
+
29
+
30
+ @dataclass
31
+ class TableInfo:
32
+ name: str
33
+ columns: list[ColumnInfo] = field(default_factory=list)
34
+ sample_rows: list[tuple] = field(default_factory=list)
35
+
36
+ def to_ddl(self) -> str:
37
+ """Generiruet CREATE TABLE statement iz metadannykh."""
38
+ col_parts = []
39
+ for col in self.columns:
40
+ line = f" {col.name} {col.type}"
41
+ if col.primary_key:
42
+ line += " PRIMARY KEY"
43
+ if not col.nullable:
44
+ line += " NOT NULL"
45
+ col_parts.append(line)
46
+ return f"CREATE TABLE {self.name} (\n" + ",\n".join(col_parts) + "\n);"
47
+
48
+
49
+ class DbConnector:
50
+ """Universalnyy konektor k BD. Umeet chitat skhemu dlya podstanovki v prompt."""
51
+
52
+ def __init__(self, connection_string: str, n_sample_rows: int = 2):
53
+ self.connection_string = self._normalize(connection_string)
54
+ self.n_sample_rows = n_sample_rows
55
+ self._db_type = self._detect_type(self.connection_string)
56
+
57
+ def list_tables(self) -> list[str]:
58
+ return [t.name for t in self._get_tables(n_sample_rows=0)]
59
+
60
+ def get_schema(self, include_samples: bool = True) -> list[TableInfo]:
61
+ return self._get_tables(n_sample_rows=self.n_sample_rows if include_samples else 0)
62
+
63
+ def render_schema(self, include_samples: bool = True) -> str:
64
+ tables = self.get_schema(include_samples=include_samples)
65
+ parts: list[str] = []
66
+ for t in tables:
67
+ parts.append(t.to_ddl())
68
+ if include_samples and t.sample_rows:
69
+ parts.append(f"-- Primery strok iz {t.name}:")
70
+ for row in t.sample_rows:
71
+ parts.append(f"-- {row}")
72
+ parts.append("")
73
+ return "\n".join(parts).strip()
74
+
75
+ def test_connection(self) -> bool:
76
+ try:
77
+ self._get_tables(n_sample_rows=0)
78
+ return True
79
+ except Exception:
80
+ return False
81
+
82
+ def _get_tables(self, n_sample_rows: int) -> list[TableInfo]:
83
+ if self._db_type == "sqlite":
84
+ return self._get_tables_sqlite(n_sample_rows)
85
+ elif self._db_type == "postgresql":
86
+ return self._get_tables_postgres(n_sample_rows)
87
+ elif self._db_type == "mysql":
88
+ return self._get_tables_mysql(n_sample_rows)
89
+ else:
90
+ raise ValueError(f"Neizvestnyy tip BD: {self._db_type}")
91
+
92
+ def _get_tables_sqlite(self, n_sample_rows: int) -> list[TableInfo]:
93
+ path = self._safe_sqlite_path(self._sqlite_path())
94
+ conn = sqlite3.connect(str(path))
95
+ conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
96
+ try:
97
+ cur = conn.cursor()
98
+ cur.execute(
99
+ "SELECT name FROM sqlite_master "
100
+ "WHERE type='table' AND name NOT LIKE 'sqlite_%' "
101
+ "ORDER BY name"
102
+ )
103
+ table_names = [r[0] for r in cur.fetchall()]
104
+ tables: list[TableInfo] = []
105
+ for name in table_names:
106
+ cur.execute(f'PRAGMA table_info("{name}")')
107
+ cols = [
108
+ ColumnInfo(
109
+ name=row[1],
110
+ type=row[2] or "TEXT",
111
+ nullable=not row[3],
112
+ primary_key=bool(row[5]),
113
+ )
114
+ for row in cur.fetchall()
115
+ ]
116
+ samples: list[tuple] = []
117
+ if n_sample_rows > 0:
118
+ try:
119
+ cur.execute(f'SELECT * FROM "{name}" LIMIT {n_sample_rows}')
120
+ samples = cur.fetchall()
121
+ except sqlite3.Error:
122
+ pass
123
+ tables.append(TableInfo(name=name, columns=cols, sample_rows=samples))
124
+ return tables
125
+ finally:
126
+ conn.close()
127
+
128
+ def _get_tables_postgres(self, n_sample_rows: int) -> list[TableInfo]:
129
+ try:
130
+ import psycopg2 # type: ignore
131
+ except ImportError as e:
132
+ raise ImportError("Ustanovi psycopg2: pip install psycopg2-binary") from e
133
+
134
+ conn = psycopg2.connect(self.connection_string)
135
+ try:
136
+ cur = conn.cursor()
137
+ cur.execute(
138
+ "SELECT table_name FROM information_schema.tables "
139
+ "WHERE table_schema = 'public' AND table_type = 'BASE TABLE' "
140
+ "ORDER BY table_name"
141
+ )
142
+ table_names = [r[0] for r in cur.fetchall()]
143
+ tables: list[TableInfo] = []
144
+ for name in table_names:
145
+ cur.execute(
146
+ "SELECT column_name, data_type, is_nullable "
147
+ "FROM information_schema.columns "
148
+ "WHERE table_name = %s AND table_schema = 'public' "
149
+ "ORDER BY ordinal_position",
150
+ (name,),
151
+ )
152
+ cols = [
153
+ ColumnInfo(name=r[0], type=r[1], nullable=(r[2] == "YES"))
154
+ for r in cur.fetchall()
155
+ ]
156
+ samples: list[tuple] = []
157
+ if n_sample_rows > 0:
158
+ cur.execute(f'SELECT * FROM "{name}" LIMIT {n_sample_rows}')
159
+ samples = cur.fetchall()
160
+ tables.append(TableInfo(name=name, columns=cols, sample_rows=samples))
161
+ return tables
162
+ finally:
163
+ conn.close()
164
+
165
+ def _get_tables_mysql(self, n_sample_rows: int) -> list[TableInfo]:
166
+ try:
167
+ import pymysql # type: ignore
168
+ except ImportError as e:
169
+ raise ImportError("Ustanovi pymysql: pip install pymysql") from e
170
+
171
+ parsed = urlparse(self.connection_string)
172
+ conn = pymysql.connect(
173
+ host=parsed.hostname,
174
+ port=parsed.port or 3306,
175
+ user=parsed.username,
176
+ password=parsed.password,
177
+ database=parsed.path.lstrip("/"),
178
+ )
179
+ try:
180
+ cur = conn.cursor()
181
+ cur.execute("SHOW TABLES")
182
+ table_names = [r[0] for r in cur.fetchall()]
183
+ tables: list[TableInfo] = []
184
+ for name in table_names:
185
+ cur.execute(f"DESCRIBE `{name}`")
186
+ cols = [
187
+ ColumnInfo(
188
+ name=r[0], type=r[1],
189
+ nullable=(r[2] == "YES"),
190
+ primary_key=(r[3] == "PRI"),
191
+ )
192
+ for r in cur.fetchall()
193
+ ]
194
+ samples: list[tuple] = []
195
+ if n_sample_rows > 0:
196
+ cur.execute(f"SELECT * FROM `{name}` LIMIT {n_sample_rows}")
197
+ samples = cur.fetchall()
198
+ tables.append(TableInfo(name=name, columns=cols, sample_rows=samples))
199
+ return tables
200
+ finally:
201
+ conn.close()
202
+
203
+ def _sqlite_path(self) -> Path:
204
+ cs = self.connection_string
205
+ if cs.startswith("sqlite:///"):
206
+ return Path(cs[10:])
207
+ return Path(cs)
208
+
209
+ @staticmethod
210
+ def _safe_sqlite_path(path: Path) -> Path:
211
+ """Esli ryadom s BD est journal-fayl, kopируем fayl vo vremennuyu direktoriu."""
212
+ import shutil
213
+ import tempfile
214
+ journal = Path(str(path) + "-journal")
215
+ wal = Path(str(path) + "-wal")
216
+ if journal.exists() or wal.exists():
217
+ tmp = Path(tempfile.mktemp(suffix=".sqlite"))
218
+ shutil.copy2(path, tmp)
219
+ return tmp
220
+ return path
221
+
222
+ @staticmethod
223
+ def _normalize(cs: str) -> str:
224
+ """Esli peredan prosto put k faylu -- prevraschaem v sqlite:// URI."""
225
+ cs = cs.strip()
226
+ if cs.endswith(".sqlite") or cs.endswith(".db"):
227
+ return f"sqlite:///{cs}"
228
+ return cs
229
+
230
+ @staticmethod
231
+ def _detect_type(cs: str) -> str:
232
+ if cs.startswith("sqlite"):
233
+ return "sqlite"
234
+ if cs.startswith("postgresql") or cs.startswith("postgres"):
235
+ return "postgresql"
236
+ if cs.startswith("mysql"):
237
+ return "mysql"
238
+ raise ValueError(f"Ne udalos opredelit tip BD: {cs}")
src/db/executor.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """SqlExecutor -- vypolnyaet SQL-zapros na podklyuchennoy BD i vozvraschaet rezultat.
2
+
3
+ Primer:
4
+ executor = SqlExecutor("sqlite:///data/demo/sales.sqlite")
5
+ result = executor.run("SELECT SUM(amount) FROM orders WHERE status='paid'")
6
+ print(result.columns)
7
+ print(result.rows)
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import sqlite3
13
+ from dataclasses import dataclass, field
14
+ from pathlib import Path
15
+ from urllib.parse import urlparse
16
+
17
+
18
+ @dataclass
19
+ class QueryResult:
20
+ """Rezultat vypolneniya SQL-zaprosa."""
21
+ columns: list[str]
22
+ rows: list[list]
23
+ row_count: int
24
+ sql: str
25
+ error: str | None = None
26
+
27
+ @property
28
+ def success(self) -> bool:
29
+ return self.error is None
30
+
31
+ def to_dict(self) -> dict:
32
+ return {
33
+ "columns": self.columns,
34
+ "rows": self.rows,
35
+ "row_count": self.row_count,
36
+ "sql": self.sql,
37
+ "error": self.error,
38
+ }
39
+
40
+ def to_markdown_table(self) -> str:
41
+ if self.error:
42
+ return f"Oshibka: {self.error}"
43
+ if not self.rows:
44
+ return "(pustoy rezultat)"
45
+ header = " | ".join(self.columns)
46
+ sep = " | ".join(["---"] * len(self.columns))
47
+ rows = "\n".join(" | ".join(str(v) for v in row) for row in self.rows)
48
+ return f"{header}\n{sep}\n{rows}"
49
+
50
+
51
+ class SqlExecutor:
52
+ """Vypolnyaet SQL na podklyuchennoy BD."""
53
+
54
+ MAX_ROWS = 500
55
+
56
+ def __init__(self, connection_string: str):
57
+ self.connection_string = connection_string.strip()
58
+ self._db_type = self._detect_type(self.connection_string)
59
+
60
+ def run(self, sql: str) -> QueryResult:
61
+ try:
62
+ if self._db_type == "sqlite":
63
+ return self._run_sqlite(sql)
64
+ elif self._db_type == "postgresql":
65
+ return self._run_postgres(sql)
66
+ elif self._db_type == "mysql":
67
+ return self._run_mysql(sql)
68
+ else:
69
+ return QueryResult(columns=[], rows=[], row_count=0, sql=sql,
70
+ error=f"Neizvestnyy tip BD: {self._db_type}")
71
+ except Exception as e:
72
+ return QueryResult(columns=[], rows=[], row_count=0, sql=sql, error=str(e))
73
+
74
+ def _run_sqlite(self, sql: str) -> QueryResult:
75
+ path = self._safe_sqlite_path(self._sqlite_path())
76
+ conn = sqlite3.connect(str(path))
77
+ conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
78
+ try:
79
+ cur = conn.cursor()
80
+ cur.execute(sql)
81
+ cols = [d[0] for d in (cur.description or [])]
82
+ rows = [list(r) for r in cur.fetchmany(self.MAX_ROWS)]
83
+ return QueryResult(columns=cols, rows=rows, row_count=len(rows), sql=sql)
84
+ finally:
85
+ conn.close()
86
+
87
+ def _run_postgres(self, sql: str) -> QueryResult:
88
+ try:
89
+ import psycopg2 # type: ignore
90
+ except ImportError as e:
91
+ raise ImportError("Ustanovi psycopg2: pip install psycopg2-binary") from e
92
+
93
+ conn = psycopg2.connect(self.connection_string)
94
+ try:
95
+ cur = conn.cursor()
96
+ cur.execute(sql)
97
+ cols = [d[0] for d in (cur.description or [])]
98
+ rows = [list(r) for r in cur.fetchmany(self.MAX_ROWS)]
99
+ return QueryResult(columns=cols, rows=rows, row_count=len(rows), sql=sql)
100
+ finally:
101
+ conn.close()
102
+
103
+ def _run_mysql(self, sql: str) -> QueryResult:
104
+ try:
105
+ import pymysql # type: ignore
106
+ except ImportError as e:
107
+ raise ImportError("Ustanovi pymysql: pip install pymysql") from e
108
+
109
+ parsed = urlparse(self.connection_string)
110
+ conn = pymysql.connect(
111
+ host=parsed.hostname,
112
+ port=parsed.port or 3306,
113
+ user=parsed.username,
114
+ password=parsed.password,
115
+ database=parsed.path.lstrip("/"),
116
+ )
117
+ try:
118
+ cur = conn.cursor()
119
+ cur.execute(sql)
120
+ cols = [d[0] for d in (cur.description or [])]
121
+ rows = [list(r) for r in cur.fetchmany(self.MAX_ROWS)]
122
+ return QueryResult(columns=cols, rows=rows, row_count=len(rows), sql=sql)
123
+ finally:
124
+ conn.close()
125
+
126
+ def _sqlite_path(self) -> Path:
127
+ cs = self.connection_string
128
+ if cs.startswith("sqlite:///"):
129
+ return Path(cs[10:])
130
+ return Path(cs)
131
+
132
+ @staticmethod
133
+ def _safe_sqlite_path(path: Path) -> Path:
134
+ import shutil
135
+ import tempfile
136
+ journal = Path(str(path) + "-journal")
137
+ wal = Path(str(path) + "-wal")
138
+ if journal.exists() or wal.exists():
139
+ tmp = Path(tempfile.mktemp(suffix=".sqlite"))
140
+ shutil.copy2(path, tmp)
141
+ return tmp
142
+ return path
143
+
144
+ @staticmethod
145
+ def _detect_type(cs: str) -> str:
146
+ if cs.startswith("sqlite") or cs.endswith(".sqlite") or cs.endswith(".db"):
147
+ return "sqlite"
148
+ if cs.startswith("postgresql") or cs.startswith("postgres"):
149
+ return "postgresql"
150
+ if cs.startswith("mysql"):
151
+ return "mysql"
152
+ raise ValueError(f"Ne udalos opredelit tip BD: {cs}")
src/evaluation/__init__.py ADDED
File without changes
src/evaluation/evaluate.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Скрипт прогона модели на test-сплите PAUQ.
2
+
3
+ Использование:
4
+ python -m src.evaluation.evaluate --split dev --limit 50
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import argparse
10
+ import json
11
+ from pathlib import Path
12
+
13
+ from tqdm import tqdm
14
+
15
+ from src.config import settings
16
+ from src.data.loader import load_pauq_split
17
+ from src.data.schema import SchemaRetriever
18
+ from src.evaluation.metrics import compute_metrics
19
+ from src.models.inference import InferenceEngine
20
+
21
+
22
+ def main():
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument("--split", default="dev", choices=["train", "dev", "test"])
25
+ parser.add_argument("--limit", type=int, default=None, help="Ограничить число примеров")
26
+ parser.add_argument("--output", type=Path, default=Path("results/predictions.jsonl"))
27
+ args = parser.parse_args()
28
+
29
+ split_path = settings.pauq_data_dir / f"{args.split}.json"
30
+ examples = load_pauq_split(split_path)
31
+ if args.limit:
32
+ examples = examples[: args.limit]
33
+
34
+ schema_ret = SchemaRetriever(settings.databases_dir)
35
+ engine = InferenceEngine()
36
+ engine.load()
37
+
38
+ predictions: list[str] = []
39
+ golds: list[str] = []
40
+ db_ids: list[str] = []
41
+ rows = []
42
+
43
+ for ex in tqdm(examples, desc="Inference"):
44
+ try:
45
+ schema = schema_ret.render_schema(ex.db_id)
46
+ except FileNotFoundError:
47
+ continue
48
+ result = engine.generate(schema, ex.question)
49
+ predictions.append(result.sql)
50
+ golds.append(ex.query)
51
+ db_ids.append(ex.db_id)
52
+ rows.append(
53
+ {
54
+ "db_id": ex.db_id,
55
+ "question": ex.question,
56
+ "gold": ex.query,
57
+ "pred": result.sql,
58
+ "raw": result.raw_output,
59
+ }
60
+ )
61
+
62
+ args.output.parent.mkdir(parents=True, exist_ok=True)
63
+ with args.output.open("w", encoding="utf-8") as f:
64
+ for r in rows:
65
+ f.write(json.dumps(r, ensure_ascii=False) + "\n")
66
+
67
+ metrics = compute_metrics(predictions, golds, db_ids, settings.databases_dir)
68
+ print(json.dumps(metrics, indent=2, ensure_ascii=False))
69
+
70
+
71
+ if __name__ == "__main__":
72
+ main()
src/evaluation/metrics.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Метрики Text-to-SQL: Exact Match и Execution Accuracy."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import sqlite3
6
+ from pathlib import Path
7
+
8
+ from src.models.postprocess import normalize_sql
9
+
10
+
11
+ def exact_match(predicted: str, gold: str, dialect: str = "sqlite") -> bool:
12
+ """Сравнение нормализованных SQL посимвольно. Грубая, но честная метрика."""
13
+ return normalize_sql(predicted, dialect) == normalize_sql(gold, dialect)
14
+
15
+
16
+ def execution_accuracy(
17
+ predicted_sql: str,
18
+ gold_sql: str,
19
+ db_path: Path | str,
20
+ timeout_seconds: float = 5.0,
21
+ ) -> bool:
22
+ """Прогон обоих SQL на SQLite. True если результаты совпадают как множества."""
23
+ db_path = Path(db_path)
24
+ conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True, timeout=timeout_seconds)
25
+ try:
26
+ conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
27
+ try:
28
+ pred_rows = _run(conn, predicted_sql)
29
+ except sqlite3.Error:
30
+ return False
31
+ try:
32
+ gold_rows = _run(conn, gold_sql)
33
+ except sqlite3.Error:
34
+ return False
35
+ return _rows_equal(pred_rows, gold_rows)
36
+ finally:
37
+ conn.close()
38
+
39
+
40
+ def _run(conn: sqlite3.Connection, sql: str) -> list[tuple]:
41
+ cur = conn.cursor()
42
+ cur.execute(sql)
43
+ return cur.fetchall()
44
+
45
+
46
+ def _rows_equal(a: list[tuple], b: list[tuple]) -> bool:
47
+ """Сравнение как мультимножеств — порядок не важен (если в SQL нет ORDER BY)."""
48
+ if len(a) != len(b):
49
+ return False
50
+ return sorted(map(_row_key, a)) == sorted(map(_row_key, b))
51
+
52
+
53
+ def _row_key(row: tuple) -> tuple:
54
+ return tuple(str(x) for x in row)
55
+
56
+
57
+ def compute_metrics(
58
+ predictions: list[str],
59
+ golds: list[str],
60
+ db_ids: list[str],
61
+ databases_dir: Path | str,
62
+ ) -> dict:
63
+ """Прогон по всему датасету. Возвращает dict с EM, EX, и счётчиками."""
64
+ databases_dir = Path(databases_dir)
65
+ n = len(predictions)
66
+ assert n == len(golds) == len(db_ids), "Mismatched lengths"
67
+
68
+ em_count = 0
69
+ ex_count = 0
70
+ parse_fail = 0
71
+
72
+ for pred, gold, db_id in zip(predictions, golds, db_ids):
73
+ if exact_match(pred, gold):
74
+ em_count += 1
75
+
76
+ db_path = databases_dir / db_id / f"{db_id}.sqlite"
77
+ if not db_path.exists():
78
+ parse_fail += 1
79
+ continue
80
+
81
+ if execution_accuracy(pred, gold, db_path):
82
+ ex_count += 1
83
+
84
+ return {
85
+ "n": n,
86
+ "exact_match": em_count / n if n else 0.0,
87
+ "execution_accuracy": ex_count / n if n else 0.0,
88
+ "parse_fail": parse_fail,
89
+ }
src/models/__init__.py ADDED
File without changes
src/models/inference.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Загрузка модели + LoRA-адаптера и инференс.
2
+
3
+ На десктопе/ноутбуке без GPU работает на CPU. Медленно, но достаточно для разработки и демо.
4
+ На Kaggle/Colab — на GPU, быстрее.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from dataclasses import dataclass
10
+ from pathlib import Path
11
+
12
+ import torch
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer
14
+
15
+ from src.config import settings
16
+ from src.data.prompt import build_chat_messages
17
+ from src.models.postprocess import postprocess
18
+
19
+
20
+ @dataclass
21
+ class GenerationResult:
22
+ sql: str
23
+ raw_output: str
24
+
25
+
26
+ class InferenceEngine:
27
+ """Singleton-обёртка над моделью. Загружается один раз при старте API."""
28
+
29
+ def __init__(
30
+ self,
31
+ base_model_name: str | None = None,
32
+ lora_adapter_path: str | None = None,
33
+ device: str | None = None,
34
+ ):
35
+ self.base_model_name = base_model_name or settings.base_model_name
36
+ self.lora_adapter_path = lora_adapter_path or settings.lora_adapter_path
37
+ self.device = device or settings.device
38
+ self.tokenizer = None
39
+ self.model = None
40
+ self._loaded = False
41
+
42
+ def load(self) -> None:
43
+ """Лениво грузим модель. На CPU без квантизации."""
44
+ if self._loaded:
45
+ return
46
+
47
+ self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
48
+ # bfloat16 вдвое меньше float32 (~6 ГБ vs ~12 ГБ) и поддерживается на CPU
49
+ self.model = AutoModelForCausalLM.from_pretrained(
50
+ self.base_model_name,
51
+ dtype=torch.bfloat16,
52
+ device_map=self.device if self.device != "cpu" else None,
53
+ )
54
+
55
+ # Подцепляем LoRA-адаптер: сначала ищем локально, потом на HF Hub
56
+ adapter_path = Path(self.lora_adapter_path)
57
+ adapter_id = str(adapter_path) if adapter_path.exists() else self.lora_adapter_path
58
+ try:
59
+ from peft import PeftModel
60
+ self.model = PeftModel.from_pretrained(self.model, adapter_id)
61
+ except ImportError:
62
+ pass # peft не установлен — работаем на базовой модели
63
+
64
+ self.model.eval()
65
+ self._loaded = True
66
+
67
+ def generate(
68
+ self,
69
+ schema: str,
70
+ question: str,
71
+ max_new_tokens: int | None = None,
72
+ ) -> GenerationResult:
73
+ """Принимает schema (текст DDL) и вопрос, возвращает SQL."""
74
+ if not self._loaded:
75
+ self.load()
76
+
77
+ messages = build_chat_messages(schema, question)
78
+ prompt = self.tokenizer.apply_chat_template(
79
+ messages, tokenize=False, add_generation_prompt=True
80
+ )
81
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
82
+
83
+ with torch.no_grad():
84
+ output_ids = self.model.generate(
85
+ **inputs,
86
+ max_new_tokens=max_new_tokens or settings.max_new_tokens,
87
+ do_sample=settings.do_sample,
88
+ temperature=settings.temperature if settings.do_sample else 1.0,
89
+ pad_token_id=self.tokenizer.eos_token_id,
90
+ )
91
+
92
+ new_tokens = output_ids[0][inputs["input_ids"].shape[1] :]
93
+ raw = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
94
+ return GenerationResult(sql=postprocess(raw), raw_output=raw)
src/models/postprocess.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Постобработка SQL: чистка вывода модели и базовая валидация через sqlglot."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+
7
+ import sqlglot
8
+ from sqlglot.errors import ParseError
9
+
10
+
11
+ def strip_model_artifacts(text: str) -> str:
12
+ """Убирает markdown-блоки, префиксы, лишний текст после SQL."""
13
+ # ```sql ... ```
14
+ m = re.search(r"```(?:sql)?\s*(.*?)```", text, re.DOTALL | re.IGNORECASE)
15
+ if m:
16
+ text = m.group(1)
17
+
18
+ # Убираем "SQL:", "Ответ:" и т.п. в начале
19
+ text = re.sub(r"^\s*(?:SQL|Ответ|Answer)\s*:\s*", "", text, flags=re.IGNORECASE)
20
+
21
+ # Если есть несколько SQL — берём первый до точки с запятой
22
+ text = text.strip()
23
+ if ";" in text:
24
+ head, _, _ = text.partition(";")
25
+ text = head.strip() + ";"
26
+
27
+ return text.strip()
28
+
29
+
30
+ def is_valid_sql(sql: str, dialect: str = "sqlite") -> bool:
31
+ """Парсится ли SQL через sqlglot."""
32
+ try:
33
+ sqlglot.parse_one(sql, dialect=dialect)
34
+ return True
35
+ except ParseError:
36
+ return False
37
+
38
+
39
+ def normalize_sql(sql: str, dialect: str = "sqlite") -> str:
40
+ """Нормализация для Exact Match: единый регистр ключевых слов, пробелы."""
41
+ try:
42
+ return sqlglot.parse_one(sql, dialect=dialect).sql(dialect=dialect, pretty=False).lower()
43
+ except ParseError:
44
+ # Если не парсится — просто нижний регистр и схлопывание пробелов
45
+ return re.sub(r"\s+", " ", sql.lower()).strip().rstrip(";")
46
+
47
+
48
+ def postprocess(raw_output: str) -> str:
49
+ """Полный pipeline постобработки."""
50
+ return strip_model_artifacts(raw_output)
streamlit_app.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Streamlit-интерфейс утилиты Ru2SQL.
2
+
3
+ Запуск:
4
+ streamlit run streamlit_app.py
5
+
6
+ Что умеет:
7
+ - Подключиться к любой SQLite/PostgreSQL/MySQL базе данных
8
+ - Загрузить бизнес-словарь компании из YAML-файла или редактировать прямо в браузере
9
+ - Принять вопрос на русском → сгенерировать SQL → выполнить → показать результат
10
+ - Хранить историю запросов в текущей сессии
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import sys
16
+ import time
17
+ from pathlib import Path
18
+
19
+ import streamlit as st
20
+
21
+ # Путь к src/
22
+ ROOT = Path(__file__).resolve().parent
23
+ sys.path.insert(0, str(ROOT))
24
+
25
+ # ──────────────────────────────────────────────
26
+ # Конфигурация страницы
27
+ # ──────────────────────────────────────────────
28
+ st.set_page_config(
29
+ page_title="Ru2SQL — Natural Language → SQL",
30
+ page_icon="🗄️",
31
+ layout="wide",
32
+ initial_sidebar_state="expanded",
33
+ )
34
+
35
+ # ──────────────────────────────────────────────
36
+ # CSS
37
+ # ──────────────────────────────────────────────
38
+ st.markdown("""
39
+ <style>
40
+ .sql-box {
41
+ background: #1e1e2e;
42
+ color: #cdd6f4;
43
+ font-family: 'Courier New', monospace;
44
+ font-size: 14px;
45
+ padding: 16px;
46
+ border-radius: 8px;
47
+ border-left: 4px solid #89b4fa;
48
+ white-space: pre-wrap;
49
+ margin: 8px 0;
50
+ }
51
+ .metric-card {
52
+ background: #313244;
53
+ padding: 12px 16px;
54
+ border-radius: 8px;
55
+ text-align: center;
56
+ }
57
+ .status-ok { color: #a6e3a1; font-weight: bold; }
58
+ .status-err { color: #f38ba8; font-weight: bold; }
59
+ .history-item {
60
+ border-left: 3px solid #89b4fa;
61
+ padding: 8px 12px;
62
+ margin: 6px 0;
63
+ background: #1e1e2e;
64
+ border-radius: 0 6px 6px 0;
65
+ }
66
+ </style>
67
+ """, unsafe_allow_html=True)
68
+
69
+
70
+ # ──────────────────────────────────────────────
71
+ # Session state
72
+ # ──────────────────────────────────────────────
73
+ def _default_vocab_yaml() -> str:
74
+ example = ROOT / "configs" / "example_vocabulary.yaml"
75
+ if example.exists():
76
+ return example.read_text(encoding="utf-8")
77
+ return (
78
+ "company: Моя компания\n\n"
79
+ "terms:\n"
80
+ " выручка: SUM(orders.amount) WHERE status = 'paid'\n\n"
81
+ "filters:\n"
82
+ " только_оплаченные: orders.status = 'paid'\n\n"
83
+ "notes: []\n"
84
+ )
85
+
86
+
87
+ def _init_state():
88
+ defaults = {
89
+ "history": [],
90
+ "model_loaded": False,
91
+ "engine": None,
92
+ "db_connector": None,
93
+ "db_executor": None,
94
+ "vocabulary": None,
95
+ "db_connection_string": "",
96
+ "vocab_yaml": _default_vocab_yaml(),
97
+ }
98
+ for k, v in defaults.items():
99
+ if k not in st.session_state:
100
+ st.session_state[k] = v
101
+
102
+ _init_state()
103
+
104
+
105
+ # ──────────────────────────────────────────────
106
+ # Вспомогательные функции
107
+ # ──────────────────────────────────────────────
108
+ @st.cache_resource(show_spinner="Загружаю модель… (~30 с на первый раз)")
109
+ def _load_engine():
110
+ from src.models.inference import InferenceEngine
111
+ engine = InferenceEngine()
112
+ engine.load()
113
+ return engine
114
+
115
+
116
+ def _connect_db(cs: str):
117
+ from src.db.connector import DbConnector
118
+ from src.db.executor import SqlExecutor
119
+ connector = DbConnector(cs)
120
+ executor = SqlExecutor(cs)
121
+ return connector, executor
122
+
123
+
124
+ def _load_vocab_from_yaml(yaml_text: str):
125
+ import tempfile
126
+ from src.business.vocabulary import BusinessVocabulary
127
+ tmp = Path(tempfile.mktemp(suffix=".yaml"))
128
+ tmp.write_text(yaml_text, encoding="utf-8")
129
+ vocab = BusinessVocabulary.from_yaml(tmp)
130
+ tmp.unlink(missing_ok=True)
131
+ return vocab
132
+
133
+
134
+ # ──────────────────────────────────────────────
135
+ # Боковая панель
136
+ # ──────────────────────────────────────────────
137
+ with st.sidebar:
138
+ st.title("⚙️ Настройки")
139
+
140
+ # ── Модель — загружается автоматически при старте ──
141
+ st.subheader("🤖 Модель")
142
+ if not st.session_state.model_loaded:
143
+ with st.spinner("Загружаю модель…"):
144
+ try:
145
+ st.session_state.engine = _load_engine()
146
+ st.session_state.model_loaded = True
147
+ except Exception as e:
148
+ st.error(f"Ошибка загрузки модели: {e}")
149
+
150
+ if st.session_state.model_loaded:
151
+ st.markdown('<span class="status-ok">✅ Модель готова</span>', unsafe_allow_html=True)
152
+ else:
153
+ st.markdown('<span class="status-err">⚠️ Модель не загружена</span>', unsafe_allow_html=True)
154
+
155
+ st.divider()
156
+
157
+ # ── База данных ──
158
+ st.subheader("🗄️ База данных")
159
+
160
+ db_type = st.radio("Тип подключения", ["SQLite файл", "Строка подключения"],
161
+ horizontal=True)
162
+
163
+ if db_type == "SQLite файл":
164
+ uploaded = st.file_uploader("Загрузить .sqlite файл", type=["sqlite", "db"])
165
+ use_demo = st.checkbox("Использовать демо-базу", value=True)
166
+
167
+ if use_demo:
168
+ demo_path = ROOT / "data" / "demo" / "sales.sqlite"
169
+ cs = str(demo_path)
170
+ elif uploaded:
171
+ import tempfile
172
+ tmp_db = Path(tempfile.mktemp(suffix=".sqlite"))
173
+ tmp_db.write_bytes(uploaded.read())
174
+ cs = str(tmp_db)
175
+ else:
176
+ cs = ""
177
+ else:
178
+ cs = st.text_input(
179
+ "Строка подключения",
180
+ placeholder="postgresql://user:pass@localhost/mydb",
181
+ value=st.session_state.db_connection_string,
182
+ )
183
+
184
+ if cs and st.button("Подключиться к БД", use_container_width=True):
185
+ try:
186
+ connector, executor = _connect_db(cs)
187
+ tables = connector.list_tables()
188
+ st.session_state.db_connector = connector
189
+ st.session_state.db_executor = executor
190
+ st.session_state.db_connection_string = cs
191
+ st.success(f"Подключено! Таблиц: {len(tables)}")
192
+ except Exception as e:
193
+ st.error(f"Ошибка подключения: {e}")
194
+
195
+ if st.session_state.db_connector:
196
+ tables = st.session_state.db_connector.list_tables()
197
+ st.markdown('<span class="status-ok">✅ БД подключена</span>', unsafe_allow_html=True)
198
+ with st.expander("Таблицы"):
199
+ for t in tables:
200
+ st.code(t)
201
+
202
+ st.divider()
203
+
204
+ # ── Бизнес-словарь ──
205
+ st.subheader("📖 Бизнес-словарь")
206
+
207
+ vocab_yaml = st.text_area(
208
+ "YAML-конфигурация",
209
+ value=st.session_state.vocab_yaml,
210
+ height=260,
211
+ help="Определите термины вашей компании — модель будет их учитывать при генерации SQL",
212
+ )
213
+ st.session_state.vocab_yaml = vocab_yaml
214
+
215
+ if st.button("Применить словарь", use_container_width=True):
216
+ try:
217
+ st.session_state.vocabulary = _load_vocab_from_yaml(vocab_yaml)
218
+ st.success("Словарь применён!")
219
+ except Exception as e:
220
+ st.error(f"Ошибка в YAML: {e}")
221
+
222
+ if st.session_state.vocabulary:
223
+ v = st.session_state.vocabulary
224
+ st.markdown(f'<span class="status-ok">✅ Словарь: {v.company or "загружен"}</span>',
225
+ unsafe_allow_html=True)
226
+ terms_count = len(v.terms)
227
+ if terms_count:
228
+ st.caption(f"{terms_count} терминов определено")
229
+
230
+
231
+ # ──────────────────────────────────────────────
232
+ # Основная область
233
+ # ──────────────────────────────────────────────
234
+ st.title("🗄️ Ru2SQL — Бизнес-аналитика на русском языке")
235
+ st.caption("Задайте вопрос на русском → получите SQL и данные из вашей базы")
236
+
237
+ tab_query, tab_schema, tab_history = st.tabs(["💬 Запрос", "📐 Схема БД", "🕓 История"])
238
+
239
+ # ──────────── Вкладка: Запрос ────────────
240
+ with tab_query:
241
+ ready = st.session_state.model_loaded and st.session_state.db_connector is not None
242
+
243
+ if not ready:
244
+ cols = st.columns(2)
245
+ with cols[0]:
246
+ if not st.session_state.model_loaded:
247
+ st.warning("⚠️ Загрузите модель в левой панели")
248
+ with cols[1]:
249
+ if st.session_state.db_connector is None:
250
+ st.warning("⚠️ Подключитесь к базе данных в левой панели")
251
+
252
+ question = st.text_area(
253
+ "Ваш вопрос",
254
+ placeholder="Например: Какая выручка за январь этого года?",
255
+ height=100,
256
+ disabled=not ready,
257
+ )
258
+
259
+ col_btn, col_hint = st.columns([1, 4])
260
+ with col_btn:
261
+ run_btn = st.button("▶ Выполнить", type="primary",
262
+ disabled=not ready or not question.strip(),
263
+ use_container_width=True)
264
+ with col_hint:
265
+ if ready:
266
+ st.caption("Модель сгенерирует SQL и выполнит его на вашей БД")
267
+
268
+ # Быстрые примеры
269
+ if st.session_state.db_connection_string and "sales" in st.session_state.db_connection_string:
270
+ st.caption("💡 Попробуйте:")
271
+ example_cols = st.columns(3)
272
+ examples = [
273
+ "Какая выручка за 2026 год?",
274
+ "Топ-5 клиентов по сумме заказов",
275
+ "Сколько заказов по каждому менеджеру?",
276
+ ]
277
+ for i, ex in enumerate(examples):
278
+ with example_cols[i]:
279
+ if st.button(ex, key=f"ex_{i}", use_container_width=True):
280
+ question = ex
281
+ run_btn = True
282
+
283
+ if run_btn and question.strip():
284
+ engine = st.session_state.engine
285
+ connector = st.session_state.db_connector
286
+ executor = st.session_state.db_executor
287
+ vocab = st.session_state.vocabulary
288
+
289
+ # Обогащаем вопрос бизнес-словарём
290
+ enriched_question = vocab.enrich_prompt(question) if vocab else question
291
+
292
+ # Получаем схему
293
+ schema = connector.render_schema(include_samples=True)
294
+
295
+ with st.spinner("Генерирую SQL…"):
296
+ t0 = time.time()
297
+ result = engine.generate(schema, enriched_question)
298
+ gen_time = time.time() - t0
299
+
300
+ st.subheader("Сгенерированный SQL")
301
+ st.markdown(f'<div class="sql-box">{result.sql}</div>', unsafe_allow_html=True)
302
+
303
+ col1, col2 = st.columns(2)
304
+ col1.metric("Время генерации", f"{gen_time:.1f} с")
305
+
306
+ # Выполняем SQL
307
+ if result.sql.strip():
308
+ with st.spinner("Выполняю запрос…"):
309
+ qr = executor.run(result.sql)
310
+
311
+ if qr.success:
312
+ col2.metric("Строк в результате", qr.row_count)
313
+ st.subheader("Результат")
314
+ if qr.rows:
315
+ import pandas as pd
316
+ df = pd.DataFrame(qr.rows, columns=qr.columns)
317
+ st.dataframe(df, use_container_width=True)
318
+ else:
319
+ st.info("Запрос выполнен успешно, результат пустой")
320
+ else:
321
+ col2.error("Ошибка выполнения")
322
+ st.error(f"SQL ошибка: {qr.error}")
323
+
324
+ # Добавляем в историю
325
+ st.session_state.history.append({
326
+ "question": question,
327
+ "sql": result.sql,
328
+ "success": qr.success if result.sql.strip() else False,
329
+ "rows": qr.row_count if result.sql.strip() and qr.success else 0,
330
+ "time": gen_time,
331
+ })
332
+
333
+ # ──────────── Вкладка: Схема БД ────────────
334
+ with tab_schema:
335
+ if st.session_state.db_connector is None:
336
+ st.info("Подключитесь к базе данных в левой панели")
337
+ else:
338
+ connector = st.session_state.db_connector
339
+ st.subheader("Структура базы данных")
340
+
341
+ show_samples = st.toggle("Показывать примеры строк", value=True)
342
+ schema_text = connector.render_schema(include_samples=show_samples)
343
+
344
+ for table in connector.get_schema(include_samples=show_samples):
345
+ with st.expander(f"📋 {table.name} ({len(table.columns)} колонок)"):
346
+ st.code(table.to_ddl(), language="sql")
347
+ if show_samples and table.sample_rows:
348
+ import pandas as pd
349
+ cols = [c.name for c in table.columns]
350
+ st.caption("Примеры строк:")
351
+ st.dataframe(
352
+ pd.DataFrame(table.sample_rows, columns=cols),
353
+ use_container_width=True,
354
+ )
355
+
356
+ # ──────────── Вкладка: История ────────────
357
+ with tab_history:
358
+ history = st.session_state.history
359
+ if not history:
360
+ st.info("История запросов пуста. Задайте первый вопрос на вкладке «Запрос».")
361
+ else:
362
+ st.subheader(f"История запросов ({len(history)})")
363
+
364
+ if st.button("Очистить историю"):
365
+ st.session_state.history = []
366
+ st.rerun()
367
+
368
+ for i, item in enumerate(reversed(history)):
369
+ status = "✅" if item["success"] else "❌"
370
+ with st.expander(f"{status} {item['question']}", expanded=(i == 0)):
371
+ st.markdown(f'<div class="sql-box">{item["sql"]}</div>', unsafe_allow_html=True)
372
+ cols = st.columns(3)
373
+ cols[0].metric("Время генерации", f"{item['time']:.1f} с")
374
+ cols[1].metric("Строк", item["rows"])
375
+ cols[2].metric("Статус", "OK" if item["success"] else "Ошибка")
tests/__init__.py ADDED
File without changes
tests/test_metrics.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Тесты на метрики EM и EX."""
2
+
3
+ import sqlite3
4
+ from pathlib import Path
5
+
6
+ import pytest
7
+
8
+ from src.evaluation.metrics import exact_match, execution_accuracy
9
+
10
+
11
+ def test_exact_match_simple():
12
+ assert exact_match("SELECT * FROM t", "select * from t")
13
+
14
+
15
+ def test_exact_match_whitespace():
16
+ assert exact_match("SELECT * FROM t", "SELECT * FROM t")
17
+
18
+
19
+ def test_exact_match_negative():
20
+ assert not exact_match("SELECT a FROM t", "SELECT b FROM t")
21
+
22
+
23
+ @pytest.fixture
24
+ def tmp_sqlite(tmp_path: Path) -> Path:
25
+ db = tmp_path / "tiny.sqlite"
26
+ conn = sqlite3.connect(db)
27
+ conn.execute("CREATE TABLE users (id INT, name TEXT)")
28
+ conn.executemany("INSERT INTO users VALUES (?, ?)", [(1, "a"), (2, "b")])
29
+ conn.commit()
30
+ conn.close()
31
+ return db
32
+
33
+
34
+ def test_execution_accuracy_match(tmp_sqlite: Path):
35
+ pred = "SELECT id FROM users ORDER BY id"
36
+ gold = "SELECT id FROM users ORDER BY id"
37
+ assert execution_accuracy(pred, gold, tmp_sqlite)
38
+
39
+
40
+ def test_execution_accuracy_set_equal(tmp_sqlite: Path):
41
+ pred = "SELECT id FROM users ORDER BY id DESC"
42
+ gold = "SELECT id FROM users ORDER BY id ASC"
43
+ # Без ORDER BY проверки — как множества они равны
44
+ assert execution_accuracy(pred, gold, tmp_sqlite)
45
+
46
+
47
+ def test_execution_accuracy_mismatch(tmp_sqlite: Path):
48
+ pred = "SELECT id FROM users WHERE id = 1"
49
+ gold = "SELECT id FROM users WHERE id = 2"
50
+ assert not execution_accuracy(pred, gold, tmp_sqlite)
51
+
52
+
53
+ def test_execution_accuracy_invalid_pred(tmp_sqlite: Path):
54
+ pred = "SELEC bad sql"
55
+ gold = "SELECT id FROM users"
56
+ assert not execution_accuracy(pred, gold, tmp_sqlite)
tests/test_postprocess.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Тесты на постобработку SQL."""
2
+
3
+ from src.models.postprocess import (
4
+ is_valid_sql,
5
+ normalize_sql,
6
+ postprocess,
7
+ strip_model_artifacts,
8
+ )
9
+
10
+
11
+ def test_strip_markdown_block():
12
+ raw = "```sql\nSELECT * FROM users;\n```"
13
+ assert strip_model_artifacts(raw).startswith("SELECT")
14
+
15
+
16
+ def test_strip_sql_prefix():
17
+ raw = "SQL: SELECT 1;"
18
+ assert strip_model_artifacts(raw).startswith("SELECT")
19
+
20
+
21
+ def test_keeps_first_statement():
22
+ raw = "SELECT 1; SELECT 2;"
23
+ out = strip_model_artifacts(raw)
24
+ assert "SELECT 1" in out
25
+ assert "SELECT 2" not in out
26
+
27
+
28
+ def test_valid_sql():
29
+ assert is_valid_sql("SELECT * FROM students WHERE id = 1")
30
+
31
+
32
+ def test_invalid_sql():
33
+ assert not is_valid_sql("SELEC * FRM where")
34
+
35
+
36
+ def test_normalize_em():
37
+ a = "SELECT * FROM Users"
38
+ b = "select * from users"
39
+ assert normalize_sql(a) == normalize_sql(b)
40
+
41
+
42
+ def test_postprocess_full():
43
+ raw = "```sql\nSELECT name FROM students WHERE group_id = 1;\nSELECT 2;\n```"
44
+ out = postprocess(raw)
45
+ assert out.startswith("SELECT name")
46
+ assert "SELECT 2" not in out
tests/test_prompt.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Тесты на PromptBuilder."""
2
+
3
+ from src.data.prompt import (
4
+ SYSTEM_PROMPT,
5
+ build_chat_messages,
6
+ build_training_example,
7
+ build_user_message,
8
+ )
9
+
10
+
11
+ def test_user_message_contains_parts():
12
+ msg = build_user_message("CREATE TABLE t (id INT);", "Покажи всё")
13
+ assert "Schema:" in msg
14
+ assert "Question:" in msg
15
+ assert "SQL:" in msg
16
+ assert "CREATE TABLE" in msg
17
+ assert "Покажи всё" in msg
18
+
19
+
20
+ def test_chat_messages_have_system_and_user():
21
+ msgs = build_chat_messages("schema", "question")
22
+ assert len(msgs) == 2
23
+ assert msgs[0]["role"] == "system"
24
+ assert msgs[0]["content"] == SYSTEM_PROMPT
25
+ assert msgs[1]["role"] == "user"
26
+
27
+
28
+ def test_training_example_has_assistant():
29
+ msgs = build_training_example("schema", "question", "SELECT 1")
30
+ assert len(msgs) == 3
31
+ assert msgs[2]["role"] == "assistant"
32
+ assert msgs[2]["content"] == "SELECT 1"
tests/test_schema.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Тесты на SchemaRetriever."""
2
+
3
+ import sqlite3
4
+ from pathlib import Path
5
+
6
+ import pytest
7
+
8
+ from src.data.schema import SchemaRetriever
9
+
10
+
11
+ @pytest.fixture
12
+ def fake_databases_dir(tmp_path: Path) -> Path:
13
+ """Создаёт структуру databases/uni/uni.sqlite с двумя таблицами."""
14
+ db_id = "uni"
15
+ (tmp_path / db_id).mkdir()
16
+ db_path = tmp_path / db_id / f"{db_id}.sqlite"
17
+ conn = sqlite3.connect(db_path)
18
+ conn.execute("CREATE TABLE students (id INTEGER PRIMARY KEY, name TEXT)")
19
+ conn.execute("CREATE TABLE groups (id INTEGER PRIMARY KEY, faculty TEXT)")
20
+ conn.execute("INSERT INTO students VALUES (1, 'Иван')")
21
+ conn.execute("INSERT INTO groups VALUES (10, 'ПИ')")
22
+ conn.commit()
23
+ conn.close()
24
+ return tmp_path
25
+
26
+
27
+ def test_list_databases(fake_databases_dir: Path):
28
+ r = SchemaRetriever(fake_databases_dir)
29
+ assert r.list_databases() == ["uni"]
30
+
31
+
32
+ def test_get_tables(fake_databases_dir: Path):
33
+ r = SchemaRetriever(fake_databases_dir)
34
+ tables = r.get_tables("uni")
35
+ names = sorted(t.name for t in tables)
36
+ assert names == ["groups", "students"]
37
+
38
+
39
+ def test_render_schema_contains_create(fake_databases_dir: Path):
40
+ r = SchemaRetriever(fake_databases_dir)
41
+ text = r.render_schema("uni")
42
+ assert "CREATE TABLE" in text
43
+ assert "students" in text
44
+ assert "groups" in text