Commit ·
8871df9
0
Parent(s):
initial commit
Browse files- .env.example +23 -0
- .gitignore +70 -0
- README.md +184 -0
- adapters/qwen-coder-pauq-lora/.gitattributes +36 -0
- adapters/qwen-coder-pauq-lora/README.md +199 -0
- adapters/qwen-coder-pauq-lora/adapter_config.json +48 -0
- adapters/qwen-coder-pauq-lora/chat_template.jinja +54 -0
- adapters/qwen-coder-pauq-lora/tokenizer.json +3 -0
- adapters/qwen-coder-pauq-lora/tokenizer_config.json +30 -0
- configs/example_vocabulary.yaml +33 -0
- data/demo/sales.sqlite +0 -0
- data/demo/sales.sqlite-journal +0 -0
- data/demo/test.db +0 -0
- data/demo/test.db-journal +0 -0
- data/pauq_repo +1 -0
- notebooks/kaggle_train_qwen_qlora.ipynb +428 -0
- plan_VKR_text2sql_ru.md +264 -0
- pyproject.toml +74 -0
- requirements.txt +14 -0
- src/__init__.py +3 -0
- src/api/__init__.py +0 -0
- src/api/dependencies.py +42 -0
- src/api/main.py +110 -0
- src/api/schemas.py +36 -0
- src/business/__init__.py +3 -0
- src/business/vocabulary.py +173 -0
- src/config.py +48 -0
- src/data/__init__.py +0 -0
- src/data/loader.py +52 -0
- src/data/prompt.py +29 -0
- src/data/schema.py +76 -0
- src/db/__init__.py +4 -0
- src/db/connector.py +238 -0
- src/db/executor.py +152 -0
- src/evaluation/__init__.py +0 -0
- src/evaluation/evaluate.py +72 -0
- src/evaluation/metrics.py +89 -0
- src/models/__init__.py +0 -0
- src/models/inference.py +94 -0
- src/models/postprocess.py +50 -0
- streamlit_app.py +375 -0
- tests/__init__.py +0 -0
- tests/test_metrics.py +56 -0
- tests/test_postprocess.py +46 -0
- tests/test_prompt.py +32 -0
- tests/test_schema.py +44 -0
.env.example
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Скопируй в .env и заполни. .env в git не уходит.
|
| 2 |
+
|
| 3 |
+
# API ключ для baseline-сравнения (выбери одного провайдера)
|
| 4 |
+
GIGACHAT_API_KEY=
|
| 5 |
+
OPENAI_API_KEY=
|
| 6 |
+
YANDEXGPT_API_KEY=
|
| 7 |
+
YANDEXGPT_FOLDER_ID=
|
| 8 |
+
|
| 9 |
+
# HuggingFace (нужен для скачивания приватных адаптеров)
|
| 10 |
+
HF_TOKEN=
|
| 11 |
+
|
| 12 |
+
# Локальная модель
|
| 13 |
+
BASE_MODEL_NAME=Qwen/Qwen2.5-Coder-3B-Instruct
|
| 14 |
+
LORA_ADAPTER_PATH=./checkpoints/qwen-coder-pauq-lora
|
| 15 |
+
DEVICE=cpu
|
| 16 |
+
|
| 17 |
+
# Пути
|
| 18 |
+
PAUQ_DATA_DIR=./data/pauq
|
| 19 |
+
DATABASES_DIR=./data/databases
|
| 20 |
+
|
| 21 |
+
# API
|
| 22 |
+
API_HOST=127.0.0.1
|
| 23 |
+
API_PORT=8000
|
.gitignore
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
*.egg-info/
|
| 8 |
+
.eggs/
|
| 9 |
+
dist/
|
| 10 |
+
build/
|
| 11 |
+
|
| 12 |
+
# Virtual environments
|
| 13 |
+
.venv/
|
| 14 |
+
venv/
|
| 15 |
+
env/
|
| 16 |
+
.python-version
|
| 17 |
+
|
| 18 |
+
# uv
|
| 19 |
+
.uv/
|
| 20 |
+
uv.lock
|
| 21 |
+
|
| 22 |
+
# IDE
|
| 23 |
+
.vscode/
|
| 24 |
+
.idea/
|
| 25 |
+
*.swp
|
| 26 |
+
*.swo
|
| 27 |
+
|
| 28 |
+
# Jupyter
|
| 29 |
+
.ipynb_checkpoints/
|
| 30 |
+
*.ipynb_checkpoints
|
| 31 |
+
|
| 32 |
+
# Environment variables
|
| 33 |
+
.env
|
| 34 |
+
.env.local
|
| 35 |
+
.env.*.local
|
| 36 |
+
|
| 37 |
+
# ML artifacts
|
| 38 |
+
checkpoints/
|
| 39 |
+
wandb/
|
| 40 |
+
*.bin
|
| 41 |
+
*.safetensors
|
| 42 |
+
*.gguf
|
| 43 |
+
|
| 44 |
+
# Data
|
| 45 |
+
data/pauq/
|
| 46 |
+
data/databases/
|
| 47 |
+
data/processed/
|
| 48 |
+
data/*.json
|
| 49 |
+
data/*.sqlite
|
| 50 |
+
data/*.db
|
| 51 |
+
# Демо-база нужна в репозитории
|
| 52 |
+
!data/demo/sales.sqlite
|
| 53 |
+
|
| 54 |
+
# Logs
|
| 55 |
+
*.log
|
| 56 |
+
logs/
|
| 57 |
+
|
| 58 |
+
# OS
|
| 59 |
+
.DS_Store
|
| 60 |
+
Thumbs.db
|
| 61 |
+
desktop.ini
|
| 62 |
+
|
| 63 |
+
# Test artifacts
|
| 64 |
+
.pytest_cache/
|
| 65 |
+
.coverage
|
| 66 |
+
htmlcov/
|
| 67 |
+
|
| 68 |
+
# Outputs
|
| 69 |
+
outputs/
|
| 70 |
+
results/
|
README.md
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Ru2SQL
|
| 3 |
+
emoji: 🗄️
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: streamlit
|
| 7 |
+
sdk_version: 1.35.0
|
| 8 |
+
app_file: streamlit_app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# ru2sql
|
| 13 |
+
|
| 14 |
+
Генеративная модель для преобразования вопросов на русском языке в SQL-запросы.
|
| 15 |
+
Практическая часть ВКР, направление «Программная инженерия», 4 курс.
|
| 16 |
+
|
| 17 |
+
**Стек:** Python 3.10+, PyTorch, transformers, PEFT (LoRA), FastAPI, sqlglot.
|
| 18 |
+
**Основная модель:** Qwen2.5-Coder-3B-Instruct, дообученная методом QLoRA на датасете PAUQ.
|
| 19 |
+
**Сравнение:** ruT5-base baseline + GigaChat API.
|
| 20 |
+
|
| 21 |
+
См. `plan_VKR_text2sql_ru.md` для полного плана работ на месяц.
|
| 22 |
+
|
| 23 |
+
---
|
| 24 |
+
|
| 25 |
+
## Быстрый старт (на десктопе)
|
| 26 |
+
|
| 27 |
+
### 1. Установка
|
| 28 |
+
|
| 29 |
+
```bash
|
| 30 |
+
# Установи uv (https://docs.astral.sh/uv/) если ещё нет
|
| 31 |
+
pip install uv
|
| 32 |
+
|
| 33 |
+
# Клонируй репозиторий и установи зависимости
|
| 34 |
+
git clone <твой-репо> ru2sql
|
| 35 |
+
cd ru2sql
|
| 36 |
+
uv venv
|
| 37 |
+
.venv\Scripts\activate # Windows
|
| 38 |
+
# source .venv/bin/activate # Linux/Mac
|
| 39 |
+
uv pip install -e ".[dev]"
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
### 2. Конфигурация
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
copy .env.example .env # Windows
|
| 46 |
+
# cp .env.example .env # Linux/Mac
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
Открой `.env` и заполни ключи (минимум `GIGACHAT_API_KEY` для baseline-сравнения, остальное опционально).
|
| 50 |
+
|
| 51 |
+
### 3. Скачай PAUQ
|
| 52 |
+
|
| 53 |
+
```bash
|
| 54 |
+
git clone https://github.com/ai-forever/pauq.git data/pauq_repo
|
| 55 |
+
# Затем разложи train.json/dev.json/test.json в data/pauq/
|
| 56 |
+
# и SQLite-файлы в data/databases/{db_id}/{db_id}.sqlite
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
### 4. Тесты
|
| 60 |
+
|
| 61 |
+
```bash
|
| 62 |
+
pytest -v
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
Тесты для модулей `prompt`, `postprocess`, `metrics`, `schema` должны проходить
|
| 66 |
+
без скачивания модели и датасета.
|
| 67 |
+
|
| 68 |
+
### 5. Запуск API
|
| 69 |
+
|
| 70 |
+
```bash
|
| 71 |
+
uvicorn src.api.main:app --reload
|
| 72 |
+
# Swagger UI: http://127.0.0.1:8000/docs
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
При первом запуске модель Qwen2.5-Coder-3B (~6 GB) скачается из HuggingFace Hub.
|
| 76 |
+
На CPU инференс занимает 15–30 секунд на запрос — это ожидаемо.
|
| 77 |
+
|
| 78 |
+
### 6. Запрос к API
|
| 79 |
+
|
| 80 |
+
```bash
|
| 81 |
+
curl -X POST http://127.0.0.1:8000/generate-sql \
|
| 82 |
+
-H "Content-Type: application/json" \
|
| 83 |
+
-d '{"question": "Сколько студентов на факультете ПИ?", "db_id": "university"}'
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
---
|
| 87 |
+
|
| 88 |
+
## Обучение модели
|
| 89 |
+
|
| 90 |
+
Тренировка идёт **в Kaggle Notebook** (бесплатный T4 GPU). Локально на CPU/AMD GPU
|
| 91 |
+
обучить 3B-модель не получится.
|
| 92 |
+
|
| 93 |
+
Шаги:
|
| 94 |
+
1. Открой `notebooks/kaggle_train_qwen_qlora.ipynb` на kaggle.com.
|
| 95 |
+
2. В Settings выбери Accelerator: GPU T4 x1 (или x2 для скорости).
|
| 96 |
+
3. Add-ons → Secrets → добавь `HF_TOKEN` и `WANDB_API_KEY`.
|
| 97 |
+
4. Запусти все ячейки. Тренировка ~4–6 часов.
|
| 98 |
+
5. По завершении адаптер пушится на твой приватный HF-репо.
|
| 99 |
+
6. Скачай его на десктоп:
|
| 100 |
+
```bash
|
| 101 |
+
huggingface-cli download your-username/qwen-coder-pauq-lora \
|
| 102 |
+
--local-dir checkpoints/qwen-coder-pauq-lora
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
После этого `LORA_ADAPTER_PATH` в `.env` укажет на скачанный адаптер,
|
| 106 |
+
и API будет использовать дообученную модель.
|
| 107 |
+
|
| 108 |
+
---
|
| 109 |
+
|
| 110 |
+
## Структура проекта
|
| 111 |
+
|
| 112 |
+
```
|
| 113 |
+
ru2sql/
|
| 114 |
+
├── pyproject.toml # зависимости (uv)
|
| 115 |
+
├── .env.example # шаблон конфигурации
|
| 116 |
+
├── plan_VKR_text2sql_ru.md # план работ на месяц
|
| 117 |
+
├── notebooks/
|
| 118 |
+
│ └── kaggle_train_qwen_qlora.ipynb
|
| 119 |
+
├── src/
|
| 120 |
+
│ ├── config.py # настройки через pydantic-settings
|
| 121 |
+
│ ├── data/
|
| 122 |
+
│ │ ├── loader.py # чтение PAUQ JSON
|
| 123 |
+
│ │ ├── schema.py # SchemaRetriever (DDL из SQLite)
|
| 124 |
+
│ │ └── prompt.py # PromptBuilder + chat-template
|
| 125 |
+
│ ├── models/
|
| 126 |
+
│ │ ├── inference.py # InferenceEngine (модель + LoRA)
|
| 127 |
+
│ │ └── postprocess.py # очистка SQL + sqlglot валидация
|
| 128 |
+
│ ├── evaluation/
|
| 129 |
+
│ │ ├── metrics.py # Exact Match + Execution Accuracy
|
| 130 |
+
│ │ └── evaluate.py # CLI для прогона на split'е
|
| 131 |
+
│ └── api/
|
| 132 |
+
│ ├── main.py # FastAPI app
|
| 133 |
+
│ ├── schemas.py # Pydantic-модели
|
| 134 |
+
│ └── dependencies.py # lifespan + DI
|
| 135 |
+
└── tests/
|
| 136 |
+
├── test_prompt.py
|
| 137 |
+
├── test_postprocess.py
|
| 138 |
+
├── test_metrics.py
|
| 139 |
+
└── test_schema.py
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
---
|
| 143 |
+
|
| 144 |
+
## Прогон оценки
|
| 145 |
+
|
| 146 |
+
```bash
|
| 147 |
+
# Полный прогон на dev split
|
| 148 |
+
python -m src.evaluation.evaluate --split dev
|
| 149 |
+
|
| 150 |
+
# Быстрая проверка на 50 примерах
|
| 151 |
+
python -m src.evaluation.evaluate --split dev --limit 50
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
Результат сохраняется в `results/predictions.jsonl`, метрики печатаются в stdout.
|
| 155 |
+
|
| 156 |
+
---
|
| 157 |
+
|
| 158 |
+
## Метрики (планируемые)
|
| 159 |
+
|
| 160 |
+
| Модель | EM | Execution Accuracy |
|
| 161 |
+
|---|---|---|
|
| 162 |
+
| ruT5-base (baseline) | 25–35% | 30–40% |
|
| 163 |
+
| **Qwen2.5-Coder-3B + QLoRA** | **50–60%** | **55–70%** |
|
| 164 |
+
| GigaChat API (zero-shot) | 55–70% | 65–80% |
|
| 165 |
+
|
| 166 |
+
---
|
| 167 |
+
|
| 168 |
+
## Что НЕ входит в MVP
|
| 169 |
+
|
| 170 |
+
Сознательно оставлено в раздел «направления дальнейшей работы»:
|
| 171 |
+
- Few-shot retrieval похожих примеров.
|
| 172 |
+
- Schema linking (автоматический отбор релевантных таблиц).
|
| 173 |
+
- Self-correction по ошибкам исполнения SQL.
|
| 174 |
+
- Constrained decoding (грамматика SQL).
|
| 175 |
+
- Дообучение на синтетических данных.
|
| 176 |
+
|
| 177 |
+
---
|
| 178 |
+
|
| 179 |
+
## Лицензия и атрибуция
|
| 180 |
+
|
| 181 |
+
Учебный проект. Использует:
|
| 182 |
+
- PAUQ — Apache 2.0, https://github.com/ai-forever/pauq
|
| 183 |
+
- Qwen2.5-Coder — Apache 2.0, https://huggingface.co/Qwen/Qwen2.5-Coder-3B-Instruct
|
| 184 |
+
- ruT5 — MIT, https://huggingface.co/ai-forever/ruT5-base
|
adapters/qwen-coder-pauq-lora/.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
adapters/qwen-coder-pauq-lora/README.md
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
library_name: transformers
|
| 3 |
+
tags: []
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# Model Card for Model ID
|
| 7 |
+
|
| 8 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## Model Details
|
| 13 |
+
|
| 14 |
+
### Model Description
|
| 15 |
+
|
| 16 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 17 |
+
|
| 18 |
+
This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
|
| 19 |
+
|
| 20 |
+
- **Developed by:** [More Information Needed]
|
| 21 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 22 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 23 |
+
- **Model type:** [More Information Needed]
|
| 24 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 25 |
+
- **License:** [More Information Needed]
|
| 26 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 27 |
+
|
| 28 |
+
### Model Sources [optional]
|
| 29 |
+
|
| 30 |
+
<!-- Provide the basic links for the model. -->
|
| 31 |
+
|
| 32 |
+
- **Repository:** [More Information Needed]
|
| 33 |
+
- **Paper [optional]:** [More Information Needed]
|
| 34 |
+
- **Demo [optional]:** [More Information Needed]
|
| 35 |
+
|
| 36 |
+
## Uses
|
| 37 |
+
|
| 38 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 39 |
+
|
| 40 |
+
### Direct Use
|
| 41 |
+
|
| 42 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 43 |
+
|
| 44 |
+
[More Information Needed]
|
| 45 |
+
|
| 46 |
+
### Downstream Use [optional]
|
| 47 |
+
|
| 48 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 49 |
+
|
| 50 |
+
[More Information Needed]
|
| 51 |
+
|
| 52 |
+
### Out-of-Scope Use
|
| 53 |
+
|
| 54 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 55 |
+
|
| 56 |
+
[More Information Needed]
|
| 57 |
+
|
| 58 |
+
## Bias, Risks, and Limitations
|
| 59 |
+
|
| 60 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 61 |
+
|
| 62 |
+
[More Information Needed]
|
| 63 |
+
|
| 64 |
+
### Recommendations
|
| 65 |
+
|
| 66 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 67 |
+
|
| 68 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 69 |
+
|
| 70 |
+
## How to Get Started with the Model
|
| 71 |
+
|
| 72 |
+
Use the code below to get started with the model.
|
| 73 |
+
|
| 74 |
+
[More Information Needed]
|
| 75 |
+
|
| 76 |
+
## Training Details
|
| 77 |
+
|
| 78 |
+
### Training Data
|
| 79 |
+
|
| 80 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 81 |
+
|
| 82 |
+
[More Information Needed]
|
| 83 |
+
|
| 84 |
+
### Training Procedure
|
| 85 |
+
|
| 86 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 87 |
+
|
| 88 |
+
#### Preprocessing [optional]
|
| 89 |
+
|
| 90 |
+
[More Information Needed]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
#### Training Hyperparameters
|
| 94 |
+
|
| 95 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 96 |
+
|
| 97 |
+
#### Speeds, Sizes, Times [optional]
|
| 98 |
+
|
| 99 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 100 |
+
|
| 101 |
+
[More Information Needed]
|
| 102 |
+
|
| 103 |
+
## Evaluation
|
| 104 |
+
|
| 105 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 106 |
+
|
| 107 |
+
### Testing Data, Factors & Metrics
|
| 108 |
+
|
| 109 |
+
#### Testing Data
|
| 110 |
+
|
| 111 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 112 |
+
|
| 113 |
+
[More Information Needed]
|
| 114 |
+
|
| 115 |
+
#### Factors
|
| 116 |
+
|
| 117 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 118 |
+
|
| 119 |
+
[More Information Needed]
|
| 120 |
+
|
| 121 |
+
#### Metrics
|
| 122 |
+
|
| 123 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 124 |
+
|
| 125 |
+
[More Information Needed]
|
| 126 |
+
|
| 127 |
+
### Results
|
| 128 |
+
|
| 129 |
+
[More Information Needed]
|
| 130 |
+
|
| 131 |
+
#### Summary
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
## Model Examination [optional]
|
| 136 |
+
|
| 137 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 138 |
+
|
| 139 |
+
[More Information Needed]
|
| 140 |
+
|
| 141 |
+
## Environmental Impact
|
| 142 |
+
|
| 143 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 144 |
+
|
| 145 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 146 |
+
|
| 147 |
+
- **Hardware Type:** [More Information Needed]
|
| 148 |
+
- **Hours used:** [More Information Needed]
|
| 149 |
+
- **Cloud Provider:** [More Information Needed]
|
| 150 |
+
- **Compute Region:** [More Information Needed]
|
| 151 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 152 |
+
|
| 153 |
+
## Technical Specifications [optional]
|
| 154 |
+
|
| 155 |
+
### Model Architecture and Objective
|
| 156 |
+
|
| 157 |
+
[More Information Needed]
|
| 158 |
+
|
| 159 |
+
### Compute Infrastructure
|
| 160 |
+
|
| 161 |
+
[More Information Needed]
|
| 162 |
+
|
| 163 |
+
#### Hardware
|
| 164 |
+
|
| 165 |
+
[More Information Needed]
|
| 166 |
+
|
| 167 |
+
#### Software
|
| 168 |
+
|
| 169 |
+
[More Information Needed]
|
| 170 |
+
|
| 171 |
+
## Citation [optional]
|
| 172 |
+
|
| 173 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 174 |
+
|
| 175 |
+
**BibTeX:**
|
| 176 |
+
|
| 177 |
+
[More Information Needed]
|
| 178 |
+
|
| 179 |
+
**APA:**
|
| 180 |
+
|
| 181 |
+
[More Information Needed]
|
| 182 |
+
|
| 183 |
+
## Glossary [optional]
|
| 184 |
+
|
| 185 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 186 |
+
|
| 187 |
+
[More Information Needed]
|
| 188 |
+
|
| 189 |
+
## More Information [optional]
|
| 190 |
+
|
| 191 |
+
[More Information Needed]
|
| 192 |
+
|
| 193 |
+
## Model Card Authors [optional]
|
| 194 |
+
|
| 195 |
+
[More Information Needed]
|
| 196 |
+
|
| 197 |
+
## Model Card Contact
|
| 198 |
+
|
| 199 |
+
[More Information Needed]
|
adapters/qwen-coder-pauq-lora/adapter_config.json
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alora_invocation_tokens": null,
|
| 3 |
+
"alpha_pattern": {},
|
| 4 |
+
"arrow_config": null,
|
| 5 |
+
"auto_mapping": null,
|
| 6 |
+
"base_model_name_or_path": "Qwen/Qwen2.5-Coder-3B-Instruct",
|
| 7 |
+
"bias": "none",
|
| 8 |
+
"corda_config": null,
|
| 9 |
+
"ensure_weight_tying": false,
|
| 10 |
+
"eva_config": null,
|
| 11 |
+
"exclude_modules": null,
|
| 12 |
+
"fan_in_fan_out": false,
|
| 13 |
+
"inference_mode": true,
|
| 14 |
+
"init_lora_weights": true,
|
| 15 |
+
"layer_replication": null,
|
| 16 |
+
"layers_pattern": null,
|
| 17 |
+
"layers_to_transform": null,
|
| 18 |
+
"loftq_config": {},
|
| 19 |
+
"lora_alpha": 32,
|
| 20 |
+
"lora_bias": false,
|
| 21 |
+
"lora_dropout": 0.05,
|
| 22 |
+
"lora_ga_config": null,
|
| 23 |
+
"megatron_config": null,
|
| 24 |
+
"megatron_core": "megatron.core",
|
| 25 |
+
"modules_to_save": null,
|
| 26 |
+
"peft_type": "LORA",
|
| 27 |
+
"peft_version": "0.19.1",
|
| 28 |
+
"qalora_group_size": 16,
|
| 29 |
+
"r": 16,
|
| 30 |
+
"rank_pattern": {},
|
| 31 |
+
"revision": null,
|
| 32 |
+
"target_modules": [
|
| 33 |
+
"up_proj",
|
| 34 |
+
"v_proj",
|
| 35 |
+
"gate_proj",
|
| 36 |
+
"k_proj",
|
| 37 |
+
"o_proj",
|
| 38 |
+
"q_proj",
|
| 39 |
+
"down_proj"
|
| 40 |
+
],
|
| 41 |
+
"target_parameters": null,
|
| 42 |
+
"task_type": "CAUSAL_LM",
|
| 43 |
+
"trainable_token_indices": null,
|
| 44 |
+
"use_bdlora": null,
|
| 45 |
+
"use_dora": false,
|
| 46 |
+
"use_qalora": false,
|
| 47 |
+
"use_rslora": false
|
| 48 |
+
}
|
adapters/qwen-coder-pauq-lora/chat_template.jinja
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{%- if tools %}
|
| 2 |
+
{{- '<|im_start|>system\n' }}
|
| 3 |
+
{%- if messages[0]['role'] == 'system' %}
|
| 4 |
+
{{- messages[0]['content'] }}
|
| 5 |
+
{%- else %}
|
| 6 |
+
{{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}
|
| 7 |
+
{%- endif %}
|
| 8 |
+
{{- "\n\n# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
|
| 9 |
+
{%- for tool in tools %}
|
| 10 |
+
{{- "\n" }}
|
| 11 |
+
{{- tool | tojson }}
|
| 12 |
+
{%- endfor %}
|
| 13 |
+
{{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
|
| 14 |
+
{%- else %}
|
| 15 |
+
{%- if messages[0]['role'] == 'system' %}
|
| 16 |
+
{{- '<|im_start|>system\n' + messages[0]['content'] + '<|im_end|>\n' }}
|
| 17 |
+
{%- else %}
|
| 18 |
+
{{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' }}
|
| 19 |
+
{%- endif %}
|
| 20 |
+
{%- endif %}
|
| 21 |
+
{%- for message in messages %}
|
| 22 |
+
{%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}
|
| 23 |
+
{{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
|
| 24 |
+
{%- elif message.role == "assistant" %}
|
| 25 |
+
{{- '<|im_start|>' + message.role }}
|
| 26 |
+
{%- if message.content %}
|
| 27 |
+
{{- '\n' + message.content }}
|
| 28 |
+
{%- endif %}
|
| 29 |
+
{%- for tool_call in message.tool_calls %}
|
| 30 |
+
{%- if tool_call.function is defined %}
|
| 31 |
+
{%- set tool_call = tool_call.function %}
|
| 32 |
+
{%- endif %}
|
| 33 |
+
{{- '\n<tool_call>\n{"name": "' }}
|
| 34 |
+
{{- tool_call.name }}
|
| 35 |
+
{{- '", "arguments": ' }}
|
| 36 |
+
{{- tool_call.arguments | tojson }}
|
| 37 |
+
{{- '}\n</tool_call>' }}
|
| 38 |
+
{%- endfor %}
|
| 39 |
+
{{- '<|im_end|>\n' }}
|
| 40 |
+
{%- elif message.role == "tool" %}
|
| 41 |
+
{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
|
| 42 |
+
{{- '<|im_start|>user' }}
|
| 43 |
+
{%- endif %}
|
| 44 |
+
{{- '\n<tool_response>\n' }}
|
| 45 |
+
{{- message.content }}
|
| 46 |
+
{{- '\n</tool_response>' }}
|
| 47 |
+
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
|
| 48 |
+
{{- '<|im_end|>\n' }}
|
| 49 |
+
{%- endif %}
|
| 50 |
+
{%- endif %}
|
| 51 |
+
{%- endfor %}
|
| 52 |
+
{%- if add_generation_prompt %}
|
| 53 |
+
{{- '<|im_start|>assistant\n' }}
|
| 54 |
+
{%- endif %}
|
adapters/qwen-coder-pauq-lora/tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3fd169731d2cbde95e10bf356d66d5997fd885dd8dbb6fb4684da3f23b2585d8
|
| 3 |
+
size 11421892
|
adapters/qwen-coder-pauq-lora/tokenizer_config.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_prefix_space": false,
|
| 3 |
+
"backend": "tokenizers",
|
| 4 |
+
"bos_token": null,
|
| 5 |
+
"clean_up_tokenization_spaces": false,
|
| 6 |
+
"eos_token": "<|im_end|>",
|
| 7 |
+
"errors": "replace",
|
| 8 |
+
"extra_special_tokens": [
|
| 9 |
+
"<|im_start|>",
|
| 10 |
+
"<|im_end|>",
|
| 11 |
+
"<|object_ref_start|>",
|
| 12 |
+
"<|object_ref_end|>",
|
| 13 |
+
"<|box_start|>",
|
| 14 |
+
"<|box_end|>",
|
| 15 |
+
"<|quad_start|>",
|
| 16 |
+
"<|quad_end|>",
|
| 17 |
+
"<|vision_start|>",
|
| 18 |
+
"<|vision_end|>",
|
| 19 |
+
"<|vision_pad|>",
|
| 20 |
+
"<|image_pad|>",
|
| 21 |
+
"<|video_pad|>"
|
| 22 |
+
],
|
| 23 |
+
"is_local": false,
|
| 24 |
+
"local_files_only": false,
|
| 25 |
+
"model_max_length": 32768,
|
| 26 |
+
"pad_token": "<|endoftext|>",
|
| 27 |
+
"split_special_tokens": false,
|
| 28 |
+
"tokenizer_class": "Qwen2Tokenizer",
|
| 29 |
+
"unk_token": null
|
| 30 |
+
}
|
configs/example_vocabulary.yaml
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Бизнес-словарь компании — пример заполнения
|
| 2 |
+
# Скопируй этот файл, переименуй под свою компанию и заполни своими терминами.
|
| 3 |
+
# Путь к файлу указывается при запуске утилиты.
|
| 4 |
+
|
| 5 |
+
company: "ООО Ромашка"
|
| 6 |
+
|
| 7 |
+
# Бизнес-термины и метрики
|
| 8 |
+
# Ключ — слово/фраза как говорит аналитик
|
| 9 |
+
# Значение — что это означает в терминах SQL / данных
|
| 10 |
+
terms:
|
| 11 |
+
выручка: "SUM(orders.amount) при условии orders.status = 'paid'"
|
| 12 |
+
оборот: "SUM(orders.amount) по всем заказам включая отменённые"
|
| 13 |
+
активный клиент: "клиент, совершивший хотя бы одну покупку за последние 90 дней"
|
| 14 |
+
новый клиент: "клиент, зарегистрированный менее 30 дней назад"
|
| 15 |
+
этот год: "YEAR(order_date) совпадает с текущим годом"
|
| 16 |
+
прошлый месяц: "месяц предшествующий текущему"
|
| 17 |
+
этот квартал: "текущий квартал календарного года (Q1=янв-март, Q2=апр-июн и т.д.)"
|
| 18 |
+
средний чек: "AVG(orders.amount) по оплаченным заказам"
|
| 19 |
+
конверсия: "доля оплаченных заказов от общего числа"
|
| 20 |
+
|
| 21 |
+
# Стандартные условия фильтрации (применяются по умолчанию если аналитик явно не указал иное)
|
| 22 |
+
filters:
|
| 23 |
+
только_оплаченные: "orders.status = 'paid'"
|
| 24 |
+
без_возвратов: "orders.is_return = 0 или orders.is_return IS NULL"
|
| 25 |
+
только_активные_товары: "products.is_active = 1"
|
| 26 |
+
|
| 27 |
+
# Дополнительные правила и особенности схемы
|
| 28 |
+
notes:
|
| 29 |
+
- "Таблица orders содержит все заказы. Колонка amount — сумма в рублях."
|
| 30 |
+
- "Клиенты хранятся в таблице customers, товары — в products."
|
| 31 |
+
- "Связь заказ-товар через таблицу order_items (order_id, product_id, quantity, price)."
|
| 32 |
+
- "Даты хранятся в формате YYYY-MM-DD в колонке order_date."
|
| 33 |
+
- "Менеджеры хранятся в таблице managers, связь с заказами через orders.manager_id."
|
data/demo/sales.sqlite
ADDED
|
Binary file (57.3 kB). View file
|
|
|
data/demo/sales.sqlite-journal
ADDED
|
Binary file (512 Bytes). View file
|
|
|
data/demo/test.db
ADDED
|
Binary file (8.19 kB). View file
|
|
|
data/demo/test.db-journal
ADDED
|
Binary file (512 Bytes). View file
|
|
|
data/pauq_repo
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Subproject commit 1c4a286e30c883f9b9bd5ca59b27cee76d4544ab
|
notebooks/kaggle_train_qwen_qlora.ipynb
ADDED
|
@@ -0,0 +1,428 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# Обучение Qwen2.5-Coder-3B на PAUQ через QLoRA\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"**Где запускать:** Kaggle Notebook с GPU T4 (Settings → Accelerator → GPU T4 x2 или T4 x1).\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"**Что делаем:**\n",
|
| 12 |
+
"1. Ставим зависимости.\n",
|
| 13 |
+
"2. Качаем PAUQ.\n",
|
| 14 |
+
"3. Загружаем Qwen2.5-Coder-3B в 4-bit.\n",
|
| 15 |
+
"4. Готовим датасет в chat-формате.\n",
|
| 16 |
+
"5. Дообучаем LoRA-адаптер через `SFTTrainer`.\n",
|
| 17 |
+
"6. Сохраняем адаптер локально и (опционально) пушим на HuggingFace Hub.\n",
|
| 18 |
+
"\n",
|
| 19 |
+
"**Время на T4:** ~2–3 часа на эпоху (если ~10к примеров, max_seq_length=1024)."
|
| 20 |
+
]
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"cell_type": "markdown",
|
| 24 |
+
"metadata": {},
|
| 25 |
+
"source": [
|
| 26 |
+
"## 1. Установка зависимостей"
|
| 27 |
+
]
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"cell_type": "code",
|
| 31 |
+
"execution_count": null,
|
| 32 |
+
"metadata": {},
|
| 33 |
+
"outputs": [],
|
| 34 |
+
"source": [
|
| 35 |
+
"!pip install -q -U \\\n",
|
| 36 |
+
" transformers==4.44.2 \\\n",
|
| 37 |
+
" peft==0.12.0 \\\n",
|
| 38 |
+
" accelerate==0.33.0 \\\n",
|
| 39 |
+
" bitsandbytes==0.43.3 \\\n",
|
| 40 |
+
" trl==0.10.1 \\\n",
|
| 41 |
+
" datasets==2.20.0 \\\n",
|
| 42 |
+
" sqlglot==25.5.1 \\\n",
|
| 43 |
+
" wandb"
|
| 44 |
+
]
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"cell_type": "markdown",
|
| 48 |
+
"metadata": {},
|
| 49 |
+
"source": [
|
| 50 |
+
"## 2. Авторизация HuggingFace и W&B\n",
|
| 51 |
+
"\n",
|
| 52 |
+
"В Kaggle добавь секреты: `HF_TOKEN` и `WANDB_API_KEY` через Add-ons → Secrets. Тогда они подхватятся автоматически."
|
| 53 |
+
]
|
| 54 |
+
},
|
| 55 |
+
{
|
| 56 |
+
"cell_type": "code",
|
| 57 |
+
"execution_count": null,
|
| 58 |
+
"metadata": {},
|
| 59 |
+
"outputs": [],
|
| 60 |
+
"source": [
|
| 61 |
+
"import os\n",
|
| 62 |
+
"from kaggle_secrets import UserSecretsClient\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"secrets = UserSecretsClient()\n",
|
| 65 |
+
"os.environ[\"HF_TOKEN\"] = secrets.get_secret(\"HF_TOKEN\")\n",
|
| 66 |
+
"os.environ[\"WANDB_API_KEY\"] = secrets.get_secret(\"WANDB_API_KEY\")\n",
|
| 67 |
+
"\n",
|
| 68 |
+
"from huggingface_hub import login\n",
|
| 69 |
+
"login(token=os.environ[\"HF_TOKEN\"])\n",
|
| 70 |
+
"\n",
|
| 71 |
+
"import wandb\n",
|
| 72 |
+
"wandb.login()"
|
| 73 |
+
]
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"cell_type": "markdown",
|
| 77 |
+
"metadata": {},
|
| 78 |
+
"source": [
|
| 79 |
+
"## 3. Скачиваем PAUQ\n",
|
| 80 |
+
"\n",
|
| 81 |
+
"Альтернатива: загрузить PAUQ как Kaggle Dataset и подключить через `/kaggle/input/`."
|
| 82 |
+
]
|
| 83 |
+
},
|
| 84 |
+
{
|
| 85 |
+
"cell_type": "code",
|
| 86 |
+
"execution_count": null,
|
| 87 |
+
"metadata": {},
|
| 88 |
+
"outputs": [],
|
| 89 |
+
"source": [
|
| 90 |
+
"!git clone https://github.com/ai-forever/pauq.git /kaggle/working/pauq_repo\n",
|
| 91 |
+
"!ls /kaggle/working/pauq_repo"
|
| 92 |
+
]
|
| 93 |
+
},
|
| 94 |
+
{
|
| 95 |
+
"cell_type": "code",
|
| 96 |
+
"execution_count": null,
|
| 97 |
+
"metadata": {},
|
| 98 |
+
"outputs": [],
|
| 99 |
+
"source": [
|
| 100 |
+
"import json\n",
|
| 101 |
+
"from pathlib import Path\n",
|
| 102 |
+
"\n",
|
| 103 |
+
"# Точные пути зависят от структуры репозитория. Найди train/dev/test файлы:\n",
|
| 104 |
+
"for p in Path(\"/kaggle/working/pauq_repo\").rglob(\"*.json\"):\n",
|
| 105 |
+
" print(p)"
|
| 106 |
+
]
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"cell_type": "code",
|
| 110 |
+
"execution_count": null,
|
| 111 |
+
"metadata": {},
|
| 112 |
+
"outputs": [],
|
| 113 |
+
"source": [
|
| 114 |
+
"# ОБНОВИ пути после `ls` выше\n",
|
| 115 |
+
"TRAIN_JSON = Path(\"/kaggle/working/pauq_repo/path/to/train.json\")\n",
|
| 116 |
+
"DEV_JSON = Path(\"/kaggle/working/pauq_repo/path/to/dev.json\")\n",
|
| 117 |
+
"DATABASES_DIR = Path(\"/kaggle/working/pauq_repo/path/to/databases\")\n",
|
| 118 |
+
"\n",
|
| 119 |
+
"with TRAIN_JSON.open() as f:\n",
|
| 120 |
+
" train_raw = json.load(f)\n",
|
| 121 |
+
"with DEV_JSON.open() as f:\n",
|
| 122 |
+
" dev_raw = json.load(f)\n",
|
| 123 |
+
"\n",
|
| 124 |
+
"print(f\"train: {len(train_raw)}, dev: {len(dev_raw)}\")\n",
|
| 125 |
+
"print(\"Пример:\", train_raw[0])"
|
| 126 |
+
]
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"cell_type": "markdown",
|
| 130 |
+
"metadata": {},
|
| 131 |
+
"source": [
|
| 132 |
+
"## 4. SchemaRetriever и PromptBuilder (инлайн)\n",
|
| 133 |
+
"\n",
|
| 134 |
+
"В Kaggle нет нашего пакета `src/`, поэтому копируем минимум нужного кода прямо сюда."
|
| 135 |
+
]
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"cell_type": "code",
|
| 139 |
+
"execution_count": null,
|
| 140 |
+
"metadata": {},
|
| 141 |
+
"outputs": [],
|
| 142 |
+
"source": [
|
| 143 |
+
"import sqlite3\n",
|
| 144 |
+
"from functools import lru_cache\n",
|
| 145 |
+
"\n",
|
| 146 |
+
"SYSTEM_PROMPT = (\n",
|
| 147 |
+
" \"Ты — ассистент, который преобразует вопросы на русском языке в корректные SQL-запросы. \"\n",
|
| 148 |
+
" \"Тебе даётся схема базы данных в виде CREATE TABLE statements и пример нескольких строк. \"\n",
|
| 149 |
+
" \"Сгенерируй один SQL-запрос, который отвечает на вопрос пользователя. \"\n",
|
| 150 |
+
" \"Возвращай ТОЛЬКО SQL без объяснений, без markdown, без префиксов.\"\n",
|
| 151 |
+
")\n",
|
| 152 |
+
"\n",
|
| 153 |
+
"@lru_cache(maxsize=512)\n",
|
| 154 |
+
"def render_schema(db_id: str, n_samples: int = 2) -> str:\n",
|
| 155 |
+
" db_path = DATABASES_DIR / db_id / f\"{db_id}.sqlite\"\n",
|
| 156 |
+
" if not db_path.exists():\n",
|
| 157 |
+
" return \"\"\n",
|
| 158 |
+
" conn = sqlite3.connect(f\"file:{db_path}?mode=ro\", uri=True)\n",
|
| 159 |
+
" conn.text_factory = lambda b: b.decode(\"utf-8\", errors=\"replace\")\n",
|
| 160 |
+
" cur = conn.cursor()\n",
|
| 161 |
+
" cur.execute(\"SELECT name, sql FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'\")\n",
|
| 162 |
+
" parts = []\n",
|
| 163 |
+
" for name, ddl in cur.fetchall():\n",
|
| 164 |
+
" if not ddl:\n",
|
| 165 |
+
" continue\n",
|
| 166 |
+
" parts.append(ddl.strip() + \";\")\n",
|
| 167 |
+
" try:\n",
|
| 168 |
+
" cur.execute(f'SELECT * FROM \"{name}\" LIMIT {n_samples}')\n",
|
| 169 |
+
" rows = cur.fetchall()\n",
|
| 170 |
+
" for r in rows:\n",
|
| 171 |
+
" parts.append(f\"-- {r}\")\n",
|
| 172 |
+
" except sqlite3.Error:\n",
|
| 173 |
+
" pass\n",
|
| 174 |
+
" parts.append(\"\")\n",
|
| 175 |
+
" conn.close()\n",
|
| 176 |
+
" return \"\\n\".join(parts).strip()\n",
|
| 177 |
+
"\n",
|
| 178 |
+
"def build_messages(schema: str, question: str, sql: str | None = None):\n",
|
| 179 |
+
" user = f\"### Schema:\\n{schema}\\n\\n### Question:\\n{question}\\n\\n### SQL:\\n\"\n",
|
| 180 |
+
" msgs = [\n",
|
| 181 |
+
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
|
| 182 |
+
" {\"role\": \"user\", \"content\": user},\n",
|
| 183 |
+
" ]\n",
|
| 184 |
+
" if sql is not None:\n",
|
| 185 |
+
" msgs.append({\"role\": \"assistant\", \"content\": sql.strip()})\n",
|
| 186 |
+
" return msgs"
|
| 187 |
+
]
|
| 188 |
+
},
|
| 189 |
+
{
|
| 190 |
+
"cell_type": "markdown",
|
| 191 |
+
"metadata": {},
|
| 192 |
+
"source": [
|
| 193 |
+
"## 5. Готовим датасет для SFT"
|
| 194 |
+
]
|
| 195 |
+
},
|
| 196 |
+
{
|
| 197 |
+
"cell_type": "code",
|
| 198 |
+
"execution_count": null,
|
| 199 |
+
"metadata": {},
|
| 200 |
+
"outputs": [],
|
| 201 |
+
"source": [
|
| 202 |
+
"from datasets import Dataset\n",
|
| 203 |
+
"\n",
|
| 204 |
+
"def to_record(item):\n",
|
| 205 |
+
" q = item.get(\"question\") or item.get(\"question_ru\") or \"\"\n",
|
| 206 |
+
" sql = item.get(\"query\") or item.get(\"sql_query\") or item.get(\"sql\") or \"\"\n",
|
| 207 |
+
" db_id = item.get(\"db_id\") or item.get(\"database\") or \"\"\n",
|
| 208 |
+
" if not (q and sql and db_id):\n",
|
| 209 |
+
" return None\n",
|
| 210 |
+
" schema = render_schema(db_id)\n",
|
| 211 |
+
" if not schema:\n",
|
| 212 |
+
" return None\n",
|
| 213 |
+
" return {\"messages\": build_messages(schema, q.strip(), sql.strip())}\n",
|
| 214 |
+
"\n",
|
| 215 |
+
"train_records = [r for r in (to_record(x) for x in train_raw) if r]\n",
|
| 216 |
+
"dev_records = [r for r in (to_record(x) for x in dev_raw) if r]\n",
|
| 217 |
+
"print(f\"train usable: {len(train_records)}, dev usable: {len(dev_records)}\")\n",
|
| 218 |
+
"\n",
|
| 219 |
+
"train_ds = Dataset.from_list(train_records)\n",
|
| 220 |
+
"dev_ds = Dataset.from_list(dev_records)"
|
| 221 |
+
]
|
| 222 |
+
},
|
| 223 |
+
{
|
| 224 |
+
"cell_type": "markdown",
|
| 225 |
+
"metadata": {},
|
| 226 |
+
"source": [
|
| 227 |
+
"## 6. Загружаем модель в 4-bit"
|
| 228 |
+
]
|
| 229 |
+
},
|
| 230 |
+
{
|
| 231 |
+
"cell_type": "code",
|
| 232 |
+
"execution_count": null,
|
| 233 |
+
"metadata": {},
|
| 234 |
+
"outputs": [],
|
| 235 |
+
"source": [
|
| 236 |
+
"import torch\n",
|
| 237 |
+
"from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig\n",
|
| 238 |
+
"\n",
|
| 239 |
+
"MODEL_NAME = \"Qwen/Qwen2.5-Coder-3B-Instruct\"\n",
|
| 240 |
+
"\n",
|
| 241 |
+
"bnb_config = BitsAndBytesConfig(\n",
|
| 242 |
+
" load_in_4bit=True,\n",
|
| 243 |
+
" bnb_4bit_quant_type=\"nf4\",\n",
|
| 244 |
+
" bnb_4bit_compute_dtype=torch.bfloat16,\n",
|
| 245 |
+
" bnb_4bit_use_double_quant=True,\n",
|
| 246 |
+
")\n",
|
| 247 |
+
"\n",
|
| 248 |
+
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
|
| 249 |
+
"if tokenizer.pad_token is None:\n",
|
| 250 |
+
" tokenizer.pad_token = tokenizer.eos_token\n",
|
| 251 |
+
"\n",
|
| 252 |
+
"model = AutoModelForCausalLM.from_pretrained(\n",
|
| 253 |
+
" MODEL_NAME,\n",
|
| 254 |
+
" quantization_config=bnb_config,\n",
|
| 255 |
+
" device_map=\"auto\",\n",
|
| 256 |
+
" torch_dtype=torch.bfloat16,\n",
|
| 257 |
+
")\n",
|
| 258 |
+
"model.config.use_cache = False"
|
| 259 |
+
]
|
| 260 |
+
},
|
| 261 |
+
{
|
| 262 |
+
"cell_type": "markdown",
|
| 263 |
+
"metadata": {},
|
| 264 |
+
"source": [
|
| 265 |
+
"## 7. Конфиг LoRA"
|
| 266 |
+
]
|
| 267 |
+
},
|
| 268 |
+
{
|
| 269 |
+
"cell_type": "code",
|
| 270 |
+
"execution_count": null,
|
| 271 |
+
"metadata": {},
|
| 272 |
+
"outputs": [],
|
| 273 |
+
"source": [
|
| 274 |
+
"from peft import LoraConfig, prepare_model_for_kbit_training\n",
|
| 275 |
+
"\n",
|
| 276 |
+
"model = prepare_model_for_kbit_training(model)\n",
|
| 277 |
+
"\n",
|
| 278 |
+
"lora_config = LoraConfig(\n",
|
| 279 |
+
" r=16,\n",
|
| 280 |
+
" lora_alpha=32,\n",
|
| 281 |
+
" lora_dropout=0.05,\n",
|
| 282 |
+
" bias=\"none\",\n",
|
| 283 |
+
" task_type=\"CAUSAL_LM\",\n",
|
| 284 |
+
" target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
|
| 285 |
+
" \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
|
| 286 |
+
")"
|
| 287 |
+
]
|
| 288 |
+
},
|
| 289 |
+
{
|
| 290 |
+
"cell_type": "markdown",
|
| 291 |
+
"metadata": {},
|
| 292 |
+
"source": [
|
| 293 |
+
"## 8. Тренировка через SFTTrainer"
|
| 294 |
+
]
|
| 295 |
+
},
|
| 296 |
+
{
|
| 297 |
+
"cell_type": "code",
|
| 298 |
+
"execution_count": null,
|
| 299 |
+
"metadata": {},
|
| 300 |
+
"outputs": [],
|
| 301 |
+
"source": [
|
| 302 |
+
"from trl import SFTConfig, SFTTrainer\n",
|
| 303 |
+
"\n",
|
| 304 |
+
"OUTPUT_DIR = \"/kaggle/working/qwen-coder-pauq-lora\"\n",
|
| 305 |
+
"\n",
|
| 306 |
+
"sft_config = SFTConfig(\n",
|
| 307 |
+
" output_dir=OUTPUT_DIR,\n",
|
| 308 |
+
" num_train_epochs=2,\n",
|
| 309 |
+
" per_device_train_batch_size=1,\n",
|
| 310 |
+
" gradient_accumulation_steps=8,\n",
|
| 311 |
+
" gradient_checkpointing=True,\n",
|
| 312 |
+
" learning_rate=2e-4,\n",
|
| 313 |
+
" lr_scheduler_type=\"cosine\",\n",
|
| 314 |
+
" warmup_ratio=0.03,\n",
|
| 315 |
+
" optim=\"paged_adamw_8bit\",\n",
|
| 316 |
+
" bf16=True,\n",
|
| 317 |
+
" logging_steps=20,\n",
|
| 318 |
+
" save_strategy=\"epoch\",\n",
|
| 319 |
+
" save_total_limit=2,\n",
|
| 320 |
+
" eval_strategy=\"no\", # eval делаем отдельно после тренировки\n",
|
| 321 |
+
" max_seq_length=1024,\n",
|
| 322 |
+
" packing=False,\n",
|
| 323 |
+
" report_to=\"wandb\",\n",
|
| 324 |
+
" run_name=\"qwen3b-pauq-qlora\",\n",
|
| 325 |
+
")\n",
|
| 326 |
+
"\n",
|
| 327 |
+
"trainer = SFTTrainer(\n",
|
| 328 |
+
" model=model,\n",
|
| 329 |
+
" tokenizer=tokenizer,\n",
|
| 330 |
+
" train_dataset=train_ds,\n",
|
| 331 |
+
" peft_config=lora_config,\n",
|
| 332 |
+
" args=sft_config,\n",
|
| 333 |
+
")\n",
|
| 334 |
+
"\n",
|
| 335 |
+
"trainer.train()"
|
| 336 |
+
]
|
| 337 |
+
},
|
| 338 |
+
{
|
| 339 |
+
"cell_type": "code",
|
| 340 |
+
"execution_count": null,
|
| 341 |
+
"metadata": {},
|
| 342 |
+
"outputs": [],
|
| 343 |
+
"source": [
|
| 344 |
+
"trainer.save_model(OUTPUT_DIR)\n",
|
| 345 |
+
"tokenizer.save_pretrained(OUTPUT_DIR)\n",
|
| 346 |
+
"print(\"Saved to\", OUTPUT_DIR)"
|
| 347 |
+
]
|
| 348 |
+
},
|
| 349 |
+
{
|
| 350 |
+
"cell_type": "markdown",
|
| 351 |
+
"metadata": {},
|
| 352 |
+
"source": [
|
| 353 |
+
"## 9. Быстрая проверка inference"
|
| 354 |
+
]
|
| 355 |
+
},
|
| 356 |
+
{
|
| 357 |
+
"cell_type": "code",
|
| 358 |
+
"execution_count": null,
|
| 359 |
+
"metadata": {},
|
| 360 |
+
"outputs": [],
|
| 361 |
+
"source": [
|
| 362 |
+
"model.config.use_cache = True\n",
|
| 363 |
+
"model.eval()\n",
|
| 364 |
+
"\n",
|
| 365 |
+
"ex = dev_records[0]\n",
|
| 366 |
+
"prompt_msgs = ex[\"messages\"][:2] # без assistant-ответа\n",
|
| 367 |
+
"prompt = tokenizer.apply_chat_template(prompt_msgs, tokenize=False, add_generation_prompt=True)\n",
|
| 368 |
+
"inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
|
| 369 |
+
"\n",
|
| 370 |
+
"with torch.no_grad():\n",
|
| 371 |
+
" out = model.generate(**inputs, max_new_tokens=256, do_sample=False,\n",
|
| 372 |
+
" pad_token_id=tokenizer.eos_token_id)\n",
|
| 373 |
+
"new_tokens = out[0][inputs[\"input_ids\"].shape[1]:]\n",
|
| 374 |
+
"print(\"Pred:\", tokenizer.decode(new_tokens, skip_special_tokens=True))\n",
|
| 375 |
+
"print(\"Gold:\", ex[\"messages\"][2][\"content\"])"
|
| 376 |
+
]
|
| 377 |
+
},
|
| 378 |
+
{
|
| 379 |
+
"cell_type": "markdown",
|
| 380 |
+
"metadata": {},
|
| 381 |
+
"source": [
|
| 382 |
+
"## 10. Загрузка адаптера на HuggingFace Hub (приватный репо)"
|
| 383 |
+
]
|
| 384 |
+
},
|
| 385 |
+
{
|
| 386 |
+
"cell_type": "code",
|
| 387 |
+
"execution_count": null,
|
| 388 |
+
"metadata": {},
|
| 389 |
+
"outputs": [],
|
| 390 |
+
"source": [
|
| 391 |
+
"HF_REPO = \"your-username/qwen-coder-pauq-lora\" # замени на свой\n",
|
| 392 |
+
"\n",
|
| 393 |
+
"trainer.model.push_to_hub(HF_REPO, private=True)\n",
|
| 394 |
+
"tokenizer.push_to_hub(HF_REPO, private=True)\n",
|
| 395 |
+
"print(\"Pushed to\", HF_REPO)"
|
| 396 |
+
]
|
| 397 |
+
},
|
| 398 |
+
{
|
| 399 |
+
"cell_type": "markdown",
|
| 400 |
+
"metadata": {},
|
| 401 |
+
"source": [
|
| 402 |
+
"## Дальше\n",
|
| 403 |
+
"\n",
|
| 404 |
+
"1. Скачай адаптер на десктоп: `huggingface-cli download your-username/qwen-coder-pauq-lora --local-dir checkpoints/qwen-coder-pauq-lora`.\n",
|
| 405 |
+
"2. Запусти `python -m src.evaluation.evaluate --split dev --limit 100` локально, либо запусти полный eval здесь же на Kaggle.\n",
|
| 406 |
+
"3. Если метрики низкие: проверь prompt format, увеличь эпохи, понизь learning rate."
|
| 407 |
+
]
|
| 408 |
+
}
|
| 409 |
+
],
|
| 410 |
+
"metadata": {
|
| 411 |
+
"kernelspec": {
|
| 412 |
+
"display_name": "Python 3",
|
| 413 |
+
"language": "python",
|
| 414 |
+
"name": "python3"
|
| 415 |
+
},
|
| 416 |
+
"language_info": {
|
| 417 |
+
"codemirror_mode": {"name": "ipython", "version": 3},
|
| 418 |
+
"file_extension": ".py",
|
| 419 |
+
"mimetype": "text/x-python",
|
| 420 |
+
"name": "python",
|
| 421 |
+
"nbconvert_exporter": "python",
|
| 422 |
+
"pygments_lexer": "ipython3",
|
| 423 |
+
"version": "3.10"
|
| 424 |
+
}
|
| 425 |
+
},
|
| 426 |
+
"nbformat": 4,
|
| 427 |
+
"nbformat_minor": 4
|
| 428 |
+
}
|
plan_VKR_text2sql_ru.md
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# План практической части ВКР: «Утилита Natural Language → SQL для бизнес-аналитики»
|
| 2 |
+
|
| 3 |
+
**Студент:** Danis, ПИ, 4 курс
|
| 4 |
+
**Срок:** 4 недели
|
| 5 |
+
**Дата:** 29 апреля 2026
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## 0. Контур решения
|
| 10 |
+
|
| 11 |
+
**Финальный продукт:** утилита, которая позволяет аналитику малого и среднего бизнеса задавать вопросы на русском языке и получать готовые данные из корпоративной базы данных — без знания SQL.
|
| 12 |
+
|
| 13 |
+
Система: вопрос на русском → бизнес-словарь компании → схема БД → SQL → выполнение → результат.
|
| 14 |
+
|
| 15 |
+
Подход: fine-tuning **Qwen2.5-Coder-3B-Instruct** методом QLoRA на датасете **PAUQ**, обёрнутый в **FastAPI** с дополнительными модулями подключения к произвольной БД, настраиваемым бизнес-словарём и веб-интерфейсом на Streamlit.
|
| 16 |
+
|
| 17 |
+
Для научного сравнения параллельно прогоняется **GigaChat API** (или OpenAI) и **ruT5-base** baseline.
|
| 18 |
+
|
| 19 |
+
Инфраструктура:
|
| 20 |
+
- Тренировка: **Kaggle Notebooks** (T4 16 GB бесплатно).
|
| 21 |
+
- Разработка кода и API: **десктоп** Ryzen 5 3600X + 16 GB RAM.
|
| 22 |
+
- Демо на защите: **ноутбук** Ryzen 5 5500U + 16 GB RAM, инференс на CPU.
|
| 23 |
+
|
| 24 |
+
Артефакты ВКР:
|
| 25 |
+
- Рабочая утилита с веб-интерфейсом (Streamlit)
|
| 26 |
+
- Модуль подключения к произвольной БД (SQLite / PostgreSQL / MySQL)
|
| 27 |
+
- Модуль бизнес-словаря (YAML-конфиг с определениями метрик компании)
|
| 28 |
+
- Сравнительная таблица метрик (EM, Execution Accuracy)
|
| 29 |
+
- Анализ ошибок на 30+ примерах
|
| 30 |
+
|
| 31 |
+
---
|
| 32 |
+
|
| 33 |
+
## 1. Технологический стек
|
| 34 |
+
|
| 35 |
+
### 1.1 Среда разработки
|
| 36 |
+
|
| 37 |
+
| Компонент | Выбор |
|
| 38 |
+
|---|---|
|
| 39 |
+
| Язык | Python 3.10+ |
|
| 40 |
+
| Менеджер пакетов | uv (быстрый, современный) |
|
| 41 |
+
| Контроль версий | Git + GitHub |
|
| 42 |
+
| IDE | VS Code |
|
| 43 |
+
|
| 44 |
+
### 1.2 ML и обучение
|
| 45 |
+
|
| 46 |
+
| Компонент | Выбор | Где используется |
|
| 47 |
+
|---|---|---|
|
| 48 |
+
| PyTorch 2.x | основа | Kaggle |
|
| 49 |
+
| transformers | модели и токенизация | Kaggle + десктоп |
|
| 50 |
+
| peft | LoRA/QLoRA | Kaggle |
|
| 51 |
+
| bitsandbytes | 4-bit квантизация | Kaggle (на CPU не нужен) |
|
| 52 |
+
| trl | SFTTrainer | Kaggle |
|
| 53 |
+
| datasets | работа с PAUQ | Kaggle + десктоп |
|
| 54 |
+
| W&B | логирование экспериментов | Kaggle |
|
| 55 |
+
|
| 56 |
+
### 1.3 Инференс на десктопе и ноутбуке
|
| 57 |
+
|
| 58 |
+
Для локального инференса без GPU есть два пути:
|
| 59 |
+
|
| 60 |
+
| Путь | Скорость | Сложность | Применение |
|
| 61 |
+
|---|---|---|---|
|
| 62 |
+
| transformers на CPU (int8) | 15–30 с/запрос | проще | разработка, отладка |
|
| 63 |
+
| llama.cpp (gguf int4) | 5–15 с/запрос | сложнее | финальное демо |
|
| 64 |
+
|
| 65 |
+
**Рекомендация:** для разработки — transformers, для защиты — llama.cpp.
|
| 66 |
+
|
| 67 |
+
### 1.4 API и SQL
|
| 68 |
+
|
| 69 |
+
| Компонент | Выбор |
|
| 70 |
+
|---|---|
|
| 71 |
+
| FastAPI + Uvicorn | REST API |
|
| 72 |
+
| Pydantic v2 | валидация |
|
| 73 |
+
| sqlite3 (stdlib) | работа с БД из PAUQ |
|
| 74 |
+
| sqlglot | парсинг и валидация SQL |
|
| 75 |
+
| pytest | тесты |
|
| 76 |
+
|
| 77 |
+
---
|
| 78 |
+
|
| 79 |
+
## 2. Архитектура
|
| 80 |
+
|
| 81 |
+
```
|
| 82 |
+
┌──────────────────────────────────────────────────────────────┐
|
| 83 |
+
│ Streamlit Web Interface │
|
| 84 |
+
│ Поле вопроса | Выбор БД | Редактор бизнес-словаря │
|
| 85 |
+
│ Таблица результатов | История запросов │
|
| 86 |
+
└──────────────────────────┬───────────────────────────────────┘
|
| 87 |
+
│ HTTP
|
| 88 |
+
┌──────────────────────────▼───────────────────────────────────┐
|
| 89 |
+
│ FastAPI REST API │
|
| 90 |
+
│ POST /query {question_ru, db_id} → {sql, result, ...} │
|
| 91 |
+
└──────┬──────────────┬───────────────┬─────────────────────���──┘
|
| 92 |
+
│ │ │
|
| 93 |
+
▼ ▼ ▼
|
| 94 |
+
┌────────────┐ ┌────────────┐ ┌─────────────────┐
|
| 95 |
+
│ DbConnector│ │ Business │ │ SchemaRetriever │
|
| 96 |
+
│ SQLite / │ │ Vocabulary │ │ (DDL из БД) │
|
| 97 |
+
│ Postgres / │ │ (YAML- │ └────────┬────────┘
|
| 98 |
+
│ MySQL │ │ конфиг) │ │
|
| 99 |
+
└─────┬──────┘ └─────┬──────┘ │
|
| 100 |
+
│ │ │
|
| 101 |
+
│ ┌────▼─────────────────▼──┐
|
| 102 |
+
│ │ PromptBuilder │
|
| 103 |
+
│ │ вопрос + схема + │
|
| 104 |
+
│ │ определения метрик │
|
| 105 |
+
│ └────────────┬────────────┘
|
| 106 |
+
│ ▼
|
| 107 |
+
│ ┌────────────────────────┐
|
| 108 |
+
│ │ InferenceEngine │
|
| 109 |
+
│ │ Qwen2.5-Coder-3B │
|
| 110 |
+
│ │ + LoRA adapter │
|
| 111 |
+
│ └────────────┬───────────┘
|
| 112 |
+
│ ▼
|
| 113 |
+
│ ┌────────────────────────┐
|
| 114 |
+
│ │ SqlPostProcessor │
|
| 115 |
+
│ │ (sqlglot validation) │
|
| 116 |
+
│ └────────────┬───────────┘
|
| 117 |
+
│ │
|
| 118 |
+
└──────────────────────┘
|
| 119 |
+
│ выполнить SQL
|
| 120 |
+
▼
|
| 121 |
+
┌─────────────────┐
|
| 122 |
+
│ SqlExecutor │
|
| 123 |
+
│ результат → │
|
| 124 |
+
│ аналитику │
|
| 125 |
+
└─────────────────┘
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
Структура проекта (см. файлы в репозитории):
|
| 129 |
+
```
|
| 130 |
+
ru2sql/
|
| 131 |
+
├── README.md
|
| 132 |
+
├── pyproject.toml
|
| 133 |
+
├── .gitignore
|
| 134 |
+
├── notebooks/
|
| 135 |
+
│ └── kaggle_train_qwen_qlora.ipynb
|
| 136 |
+
├── src/
|
| 137 |
+
│ ├── config.py
|
| 138 |
+
│ ├── data/ — loader, schema, prompt
|
| 139 |
+
│ ├── models/ — inference, postprocess
|
| 140 |
+
│ ├── evaluation/ — metrics, evaluate
|
| 141 |
+
│ └── api/ — main, schemas, dependencies
|
| 142 |
+
├── tests/
|
| 143 |
+
└── scripts/
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
---
|
| 147 |
+
|
| 148 |
+
## 3. Помесячный план
|
| 149 |
+
|
| 150 |
+
### Неделя 1. Окружение, данные, baseline
|
| 151 |
+
|
| 152 |
+
**Цель:** работающий pipeline от вопроса до SQL на маленькой модели.
|
| 153 |
+
|
| 154 |
+
| День | Задача |
|
| 155 |
+
|---|---|
|
| 156 |
+
| 1 | Установка Python 3.10+, uv, Git. Клонирование репозитория. `uv sync`. Проверка что FastAPI стартует. |
|
| 157 |
+
| 2 | Регистрация на Kaggle, HuggingFace, W&B. Скачивание PAUQ (https://github.com/ai-forever/pauq). |
|
| 158 |
+
| 3 | Анализ датасета в notebook: распределения, сложности, примеры. Реализация `SchemaRetriever`. |
|
| 159 |
+
| 4 | Реализация `PromptBuilder`. Тесты: `pytest tests/test_prompt.py`. |
|
| 160 |
+
| 5–6 | Kaggle-notebook: обучение **ruT5-base** на 2 эпохи. Сохранение чекпойнта. |
|
| 161 |
+
| 7 | Реализация `metrics.py` (EM + Execution Accuracy). Прогон ruT5 на dev. Запись в W&B. |
|
| 162 |
+
|
| 163 |
+
Контрольная точка недели: ruT5-base даёт 25–35% EM на PAUQ dev.
|
| 164 |
+
|
| 165 |
+
### Неделя 2. Главная модель (Qwen2.5-Coder-3B + QLoRA)
|
| 166 |
+
|
| 167 |
+
**Цель:** обученный LoRA-адаптер для Qwen с метриками выше baseline.
|
| 168 |
+
|
| 169 |
+
| День | Задача |
|
| 170 |
+
|---|---|
|
| 171 |
+
| 1 | Kaggle-notebook: загрузка Qwen2.5-Coder-3B в 4-bit, тестовый inference. |
|
| 172 |
+
| 2 | Подготовка PAUQ в chat-формате под модель. |
|
| 173 |
+
| 3–4 | SFTTrainer + LoRA (r=16, alpha=32). Прогон 2–3 эпохи (~4–6 часов суммарно). |
|
| 174 |
+
| 5 | Сохранение LoRA-адаптера на HuggingFace Hub (приватный репозиторий). |
|
| 175 |
+
| 6 | Скачивание адаптера на десктоп. Локальный инференс на CPU через transformers. |
|
| 176 |
+
| 7 | Прогон на dev split, метрики, error analysis на 30 примерах. |
|
| 177 |
+
|
| 178 |
+
Контрольная точка недели: Qwen+LoRA даёт 50–60% EM на PAUQ dev и работает на десктопе.
|
| 179 |
+
|
| 180 |
+
### Неделя 3. Бизнес-утилита: коннектор + словарь + исполнение SQL
|
| 181 |
+
|
| 182 |
+
**Цель:** превратить API в полноценную бизнес-утилиту — подключение к реальной БД, настройка под компанию, возврат данных.
|
| 183 |
+
|
| 184 |
+
| День | Задача |
|
| 185 |
+
|---|---|
|
| 186 |
+
| 1 | FastAPI: `/generate-sql`, `/query`, `/databases`, `/health`. Lifespan для загрузки модели. |
|
| 187 |
+
| 2 | Модуль `DbConnector` — подключение к SQLite/PostgreSQL/MySQL по строке подключения. Автоматическое чтение схемы (`INFORMATION_SCHEMA`). |
|
| 188 |
+
| 3 | Модуль `BusinessVocabulary` — загрузка YAML-конфига с определениями метрик. Подстановка определений в промпт перед генерацией SQL. Пример конфига: `выручка: "SUM(orders.amount) WHERE status='paid'"`. |
|
| 189 |
+
| 4 | Эндпоинт `/query` — принимает вопрос, генерирует SQL, выполняет на подключённой БД, возвращает результат в JSON (таблица строк). |
|
| 190 |
+
| 5 | Получение API-ключа GigaChat (или YandexGPT), скрипт прогона на тех же примерах. Сравнительная таблица: ruT5 vs Qwen+LoRA vs GigaChat по EM и EX. |
|
| 191 |
+
| 6 | `SqlPostProcessor` через sqlglot. Тесты pytest на все новые модули. |
|
| 192 |
+
| 7 | Создание демо-базы данных (SQLite) с реалистичными бизнес-данными: продажи, клиенты, товары. Написание бизнес-словаря под эту базу. |
|
| 193 |
+
|
| 194 |
+
Контрольная точка недели: аналитик вводит "Какая выручка за январь?" → утилита возвращает число из реальной БД.
|
| 195 |
+
|
| 196 |
+
### Неделя 4. Streamlit-интерфейс, демо, материалы для ВКР
|
| 197 |
+
|
| 198 |
+
**Цель:** красивый рабочий продукт для защиты + готовые материалы для текста ВКР.
|
| 199 |
+
|
| 200 |
+
| День | Задача |
|
| 201 |
+
|---|---|
|
| 202 |
+
| 1 | Streamlit-интерфейс: поле ввода вопроса, выбор БД, отображение сгенерированного SQL и таблицы результатов. |
|
| 203 |
+
| 2 | В интерфейсе: вкладка настройки бизнес-словаря (редактирование YAML прямо в браузере). История запросов. |
|
| 204 |
+
| 3 | Error analysis: разбор 30 ошибок Qwen+LoRA, классификация по категориям (неверный JOIN, неверное условие WHERE и т.д.). |
|
| 205 |
+
| 4 | Конвертация LoRA + базовой модели в gguf через llama.cpp для быстрого инференса на CPU. |
|
| 206 |
+
| 5 | Диаграммы архитектуры (draw.io), скриншоты интерфейса, графики метрик (matplotlib). |
|
| 207 |
+
| 6 | Глава «Реализация» и глава «Практическое применение» в тексте ВКР. |
|
| 208 |
+
| 7 | Прогон полного сценария на ноутбуке с демо-базой. Резервная копия чекпойнта на HuggingFace. |
|
| 209 |
+
|
| 210 |
+
---
|
| 211 |
+
|
| 212 |
+
## 4. Метрики качества
|
| 213 |
+
|
| 214 |
+
Стандарт для Text-to-SQL:
|
| 215 |
+
|
| 216 |
+
- **Exact Match (EM)** — нормализуем оба SQL и сравниваем посимвольно.
|
| 217 |
+
- **Execution Accuracy (EX)** — выполняем оба SQL на реальной SQLite, сравниваем результаты как множества кортежей.
|
| 218 |
+
|
| 219 |
+
EX важнее EM, потому что разные SQL могут дать одинаковый результат.
|
| 220 |
+
|
| 221 |
+
Целевые числа на PAUQ dev (ориентировочно):
|
| 222 |
+
- ruT5-base: 25–35% EM, 30–40% EX.
|
| 223 |
+
- Qwen2.5-Coder-3B + LoRA: 50–60% EM, 55–70% EX.
|
| 224 |
+
- GigaChat / GPT-4 (zero-shot, через API): 55–70% EM, 65–80% EX.
|
| 225 |
+
|
| 226 |
+
Ваш Qwen после QLoRA должен быть близок к API-моделям. Это и будет защищаемый результат.
|
| 227 |
+
|
| 228 |
+
---
|
| 229 |
+
|
| 230 |
+
## 5. Риски и план B
|
| 231 |
+
|
| 232 |
+
| Риск | План B |
|
| 233 |
+
|---|---|
|
| 234 |
+
| Kaggle квота закончилась | Переключиться на Google Colab Free или арендовать GPU на vast.ai (~$2 за обучение) |
|
| 235 |
+
| Qwen-3B плохо сходится | Понизить learning rate до 1e-4, увеличить эпохи до 5, проверить prompt format |
|
| 236 |
+
| llama.cpp не успеваю настроить к защите | Демо через transformers на CPU напрямую — медленнее, но работает |
|
| 237 |
+
| GigaChat недоступен | YandexGPT либо OpenAI через VPN — Pydantic-обёртка одна, провайдер меняется одной строчкой |
|
| 238 |
+
| Не хватает времени на error analysis | Минимум — 20 ошибок руками, простая классификация в Excel |
|
| 239 |
+
|
| 240 |
+
---
|
| 241 |
+
|
| 242 |
+
## 6. Что вынести в «направления дальнейшей работы»
|
| 243 |
+
|
| 244 |
+
Эти улучшения **не делаем** в рамках месяца, но упоминаем в ВКР:
|
| 245 |
+
- Few-shot retrieval (поиск похожих примеров через эмбеддинги).
|
| 246 |
+
- Schema linking (автоматический отбор таблиц).
|
| 247 |
+
- Self-correction (выполнение SQL, исправление по ошибке).
|
| 248 |
+
- Constrained decoding (ограничение токенов до валидной SQL-грамматики).
|
| 249 |
+
- Дообучение на синтетических данных от GPT-4.
|
| 250 |
+
|
| 251 |
+
---
|
| 252 |
+
|
| 253 |
+
## 7. Итоговый чек-лист на старте
|
| 254 |
+
|
| 255 |
+
- [ ] Установлены Python 3.10+, uv, Git, VS Code на десктопе
|
| 256 |
+
- [ ] Создан репозиторий ru2sql на GitHub
|
| 257 |
+
- [ ] Зарегистрированы аккаунты Kaggle, HuggingFace, W&B
|
| 258 |
+
- [ ] Получен ключ GigaChat (или OpenAI)
|
| 259 |
+
- [ ] Скачан PAUQ
|
| 260 |
+
- [ ] `uv sync` проходит без ошибок
|
| 261 |
+
- [ ] `uvicorn src.api.main:app --reload` стартует
|
| 262 |
+
- [ ] Прочитаны статьи: Spider (2018), QLoRA (2023), краткое описание Qwen2.5-Coder
|
| 263 |
+
|
| 264 |
+
После чек-листа можно стартовать День 3 первой недели.
|
pyproject.toml
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "ru2sql"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Russian-to-SQL generative model for graduation thesis"
|
| 5 |
+
authors = [{ name = "Danis", email = "[email protected]" }]
|
| 6 |
+
requires-python = ">=3.10,<3.13"
|
| 7 |
+
readme = "README.md"
|
| 8 |
+
|
| 9 |
+
dependencies = [
|
| 10 |
+
# API
|
| 11 |
+
"fastapi>=0.115.0",
|
| 12 |
+
"uvicorn[standard]>=0.30.0",
|
| 13 |
+
"pydantic>=2.7.0",
|
| 14 |
+
"pydantic-settings>=2.4.0",
|
| 15 |
+
|
| 16 |
+
# SQL parsing / validation
|
| 17 |
+
"sqlglot>=25.0.0",
|
| 18 |
+
|
| 19 |
+
# Data
|
| 20 |
+
"datasets>=2.20.0",
|
| 21 |
+
"pandas>=2.2.0",
|
| 22 |
+
|
| 23 |
+
# ML inference (CPU-friendly versions for desktop/laptop)
|
| 24 |
+
# Heavy training deps (bitsandbytes, peft, trl) live in [training] and run on Kaggle
|
| 25 |
+
"torch>=2.3.0",
|
| 26 |
+
"transformers>=4.44.0",
|
| 27 |
+
"accelerate>=0.33.0",
|
| 28 |
+
"peft>=0.12.0", # for loading LoRA adapter at inference time
|
| 29 |
+
|
| 30 |
+
# Misc
|
| 31 |
+
"python-dotenv>=1.0.0",
|
| 32 |
+
"httpx>=0.27.0", # for GigaChat/OpenAI API client
|
| 33 |
+
"tqdm>=4.66.0",
|
| 34 |
+
|
| 35 |
+
# Интерфейс
|
| 36 |
+
"streamlit>=1.35.0",
|
| 37 |
+
"pyyaml>=6.0",
|
| 38 |
+
]
|
| 39 |
+
|
| 40 |
+
[project.optional-dependencies]
|
| 41 |
+
dev = [
|
| 42 |
+
"pytest>=8.3.0",
|
| 43 |
+
"pytest-asyncio>=0.23.0",
|
| 44 |
+
"ruff>=0.6.0",
|
| 45 |
+
"ipykernel>=6.29.0",
|
| 46 |
+
"matplotlib>=3.9.0",
|
| 47 |
+
"seaborn>=0.13.0",
|
| 48 |
+
]
|
| 49 |
+
|
| 50 |
+
# Heavy GPU-only deps. Install on Kaggle: `pip install -e .[training]`
|
| 51 |
+
training = [
|
| 52 |
+
"bitsandbytes>=0.43.0",
|
| 53 |
+
"trl>=0.10.0",
|
| 54 |
+
"wandb>=0.17.0",
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
[build-system]
|
| 58 |
+
requires = ["hatchling"]
|
| 59 |
+
build-backend = "hatchling.build"
|
| 60 |
+
|
| 61 |
+
[tool.hatch.build.targets.wheel]
|
| 62 |
+
packages = ["src"]
|
| 63 |
+
|
| 64 |
+
[tool.ruff]
|
| 65 |
+
line-length = 100
|
| 66 |
+
target-version = "py310"
|
| 67 |
+
|
| 68 |
+
[tool.ruff.lint]
|
| 69 |
+
select = ["E", "F", "W", "I", "B", "UP"]
|
| 70 |
+
ignore = ["E501"]
|
| 71 |
+
|
| 72 |
+
[tool.pytest.ini_options]
|
| 73 |
+
testpaths = ["tests"]
|
| 74 |
+
pythonpath = ["."]
|
requirements.txt
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit>=1.35.0
|
| 2 |
+
torch>=2.3.0
|
| 3 |
+
transformers>=4.44.0
|
| 4 |
+
accelerate>=0.33.0
|
| 5 |
+
peft>=0.12.0
|
| 6 |
+
pydantic>=2.7.0
|
| 7 |
+
pydantic-settings>=2.4.0
|
| 8 |
+
sqlglot>=25.0.0
|
| 9 |
+
pandas>=2.2.0
|
| 10 |
+
python-dotenv>=1.0.0
|
| 11 |
+
huggingface_hub>=1.0.0
|
| 12 |
+
pyyaml>=6.0
|
| 13 |
+
tqdm>=4.66.0
|
| 14 |
+
httpx>=0.27.0
|
src/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ru2sql — Russian-to-SQL generative model."""
|
| 2 |
+
|
| 3 |
+
__version__ = "0.1.0"
|
src/api/__init__.py
ADDED
|
File without changes
|
src/api/dependencies.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI lifespan и DI: загрузка модели один раз при старте."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from contextlib import asynccontextmanager
|
| 6 |
+
|
| 7 |
+
from fastapi import FastAPI
|
| 8 |
+
|
| 9 |
+
from src.config import settings
|
| 10 |
+
from src.data.schema import SchemaRetriever
|
| 11 |
+
from src.models.inference import InferenceEngine
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class AppState:
|
| 15 |
+
engine: InferenceEngine | None = None
|
| 16 |
+
schema_retriever: SchemaRetriever | None = None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
state = AppState()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@asynccontextmanager
|
| 23 |
+
async def lifespan(app: FastAPI):
|
| 24 |
+
"""Грузим модель при старте, освобождаем при остановке."""
|
| 25 |
+
state.engine = InferenceEngine()
|
| 26 |
+
state.engine.load()
|
| 27 |
+
state.schema_retriever = SchemaRetriever(settings.databases_dir)
|
| 28 |
+
yield
|
| 29 |
+
state.engine = None
|
| 30 |
+
state.schema_retriever = None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_engine() -> InferenceEngine:
|
| 34 |
+
if state.engine is None:
|
| 35 |
+
raise RuntimeError("Inference engine not initialized")
|
| 36 |
+
return state.engine
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def get_schema_retriever() -> SchemaRetriever:
|
| 40 |
+
if state.schema_retriever is None:
|
| 41 |
+
raise RuntimeError("SchemaRetriever not initialized")
|
| 42 |
+
return state.schema_retriever
|
src/api/main.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI приложение.
|
| 2 |
+
|
| 3 |
+
Запуск:
|
| 4 |
+
uvicorn src.api.main:app --reload
|
| 5 |
+
# Swagger UI: http://127.0.0.1:8000/docs
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import sqlite3
|
| 11 |
+
|
| 12 |
+
from fastapi import Depends, FastAPI, HTTPException
|
| 13 |
+
from fastapi.concurrency import run_in_threadpool
|
| 14 |
+
|
| 15 |
+
from src.api.dependencies import get_engine, get_schema_retriever, lifespan
|
| 16 |
+
from src.api.schemas import (
|
| 17 |
+
DatabaseInfo,
|
| 18 |
+
ExecutionResult,
|
| 19 |
+
GenerateRequest,
|
| 20 |
+
GenerateResponse,
|
| 21 |
+
HealthResponse,
|
| 22 |
+
)
|
| 23 |
+
from src.config import settings
|
| 24 |
+
from src.data.schema import SchemaRetriever
|
| 25 |
+
from src.models.inference import InferenceEngine
|
| 26 |
+
from src.models.postprocess import is_valid_sql
|
| 27 |
+
|
| 28 |
+
app = FastAPI(
|
| 29 |
+
title="ru2sql",
|
| 30 |
+
description="Преобразование вопросов на русском в SQL-запросы",
|
| 31 |
+
version="0.1.0",
|
| 32 |
+
lifespan=lifespan,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@app.get("/health", response_model=HealthResponse)
|
| 37 |
+
def health(engine: InferenceEngine = Depends(get_engine)):
|
| 38 |
+
return HealthResponse(
|
| 39 |
+
status="ok",
|
| 40 |
+
model_loaded=engine._loaded,
|
| 41 |
+
base_model=engine.base_model_name,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@app.get("/databases", response_model=list[DatabaseInfo])
|
| 46 |
+
def list_databases(retriever: SchemaRetriever = Depends(get_schema_retriever)):
|
| 47 |
+
out: list[DatabaseInfo] = []
|
| 48 |
+
for db_id in retriever.list_databases():
|
| 49 |
+
try:
|
| 50 |
+
tables = [t.name for t in retriever.get_tables(db_id, n_sample_rows=0)]
|
| 51 |
+
out.append(DatabaseInfo(db_id=db_id, tables=tables))
|
| 52 |
+
except FileNotFoundError:
|
| 53 |
+
continue
|
| 54 |
+
return out
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@app.post("/generate-sql", response_model=GenerateResponse)
|
| 58 |
+
async def generate_sql(
|
| 59 |
+
req: GenerateRequest,
|
| 60 |
+
engine: InferenceEngine = Depends(get_engine),
|
| 61 |
+
retriever: SchemaRetriever = Depends(get_schema_retriever),
|
| 62 |
+
):
|
| 63 |
+
try:
|
| 64 |
+
schema_text = retriever.render_schema(req.db_id)
|
| 65 |
+
except FileNotFoundError as e:
|
| 66 |
+
raise HTTPException(status_code=404, detail=str(e)) from e
|
| 67 |
+
|
| 68 |
+
# Inference синхронный и тяжёлый — выносим в threadpool
|
| 69 |
+
result = await run_in_threadpool(engine.generate, schema_text, req.question)
|
| 70 |
+
|
| 71 |
+
valid = is_valid_sql(result.sql)
|
| 72 |
+
response = GenerateResponse(
|
| 73 |
+
sql=result.sql,
|
| 74 |
+
raw_output=result.raw_output,
|
| 75 |
+
is_valid_sql=valid,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
if req.execute and valid:
|
| 79 |
+
try:
|
| 80 |
+
response.execution = await run_in_threadpool(
|
| 81 |
+
_execute_sql, req.db_id, result.sql, retriever
|
| 82 |
+
)
|
| 83 |
+
except sqlite3.Error as e:
|
| 84 |
+
response.error = f"SQL execution error: {e}"
|
| 85 |
+
|
| 86 |
+
return response
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _execute_sql(db_id: str, sql: str, retriever: SchemaRetriever) -> ExecutionResult:
|
| 90 |
+
db_path = retriever.db_path(db_id)
|
| 91 |
+
conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)
|
| 92 |
+
try:
|
| 93 |
+
conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
|
| 94 |
+
cur = conn.cursor()
|
| 95 |
+
cur.execute(sql)
|
| 96 |
+
rows = cur.fetchmany(100)
|
| 97 |
+
cols = [d[0] for d in cur.description] if cur.description else []
|
| 98 |
+
return ExecutionResult(
|
| 99 |
+
columns=cols,
|
| 100 |
+
rows=[list(r) for r in rows],
|
| 101 |
+
row_count=len(rows),
|
| 102 |
+
)
|
| 103 |
+
finally:
|
| 104 |
+
conn.close()
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
if __name__ == "__main__":
|
| 108 |
+
import uvicorn
|
| 109 |
+
|
| 110 |
+
uvicorn.run("src.api.main:app", host=settings.api_host, port=settings.api_port, reload=True)
|
src/api/schemas.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic-модели для FastAPI endpoints."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from pydantic import BaseModel, Field
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class GenerateRequest(BaseModel):
|
| 9 |
+
question: str = Field(..., min_length=1, max_length=2000, description="Вопрос на русском")
|
| 10 |
+
db_id: str = Field(..., min_length=1, description="Идентификатор БД из PAUQ")
|
| 11 |
+
execute: bool = Field(default=False, description="Прогнать сгенерированный SQL на БД")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ExecutionResult(BaseModel):
|
| 15 |
+
columns: list[str]
|
| 16 |
+
rows: list[list]
|
| 17 |
+
row_count: int
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class GenerateResponse(BaseModel):
|
| 21 |
+
sql: str
|
| 22 |
+
raw_output: str
|
| 23 |
+
is_valid_sql: bool
|
| 24 |
+
execution: ExecutionResult | None = None
|
| 25 |
+
error: str | None = None
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class DatabaseInfo(BaseModel):
|
| 29 |
+
db_id: str
|
| 30 |
+
tables: list[str]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class HealthResponse(BaseModel):
|
| 34 |
+
status: str
|
| 35 |
+
model_loaded: bool
|
| 36 |
+
base_model: str
|
src/business/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .vocabulary import BusinessVocabulary
|
| 2 |
+
|
| 3 |
+
__all__ = ["BusinessVocabulary"]
|
src/business/vocabulary.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""BusinessVocabulary — настраиваемый бизнес-словарь компании.
|
| 2 |
+
|
| 3 |
+
Позволяет аналитику один раз описать бизнес-термины и метрики компании в YAML-файле,
|
| 4 |
+
после чего модель правильно интерпретирует их в SQL-запросах.
|
| 5 |
+
|
| 6 |
+
Пример YAML-конфига (configs/example_vocabulary.yaml):
|
| 7 |
+
company: "ООО Ромашка"
|
| 8 |
+
terms:
|
| 9 |
+
выручка: "SUM(orders.amount) WHERE orders.status = 'paid'"
|
| 10 |
+
активный клиент: "клиент, совершивший покупку за последние 90 дней"
|
| 11 |
+
этот год: "YEAR(order_date) = strftime('%Y', 'now')"
|
| 12 |
+
прошлый месяц: "strftime('%Y-%m', order_date) = strftime('%Y-%m', 'now', '-1 month')"
|
| 13 |
+
|
| 14 |
+
filters:
|
| 15 |
+
только_оплаченные: "orders.status = 'paid'"
|
| 16 |
+
без_возвратов: "orders.is_return = 0"
|
| 17 |
+
|
| 18 |
+
Пример использования:
|
| 19 |
+
vocab = BusinessVocabulary.from_yaml("configs/my_company.yaml")
|
| 20 |
+
enriched_prompt = vocab.enrich_prompt("Какая выручка за январь?")
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
+
from dataclasses import dataclass, field
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
try:
|
| 29 |
+
import yaml # type: ignore
|
| 30 |
+
_YAML_AVAILABLE = True
|
| 31 |
+
except ImportError:
|
| 32 |
+
_YAML_AVAILABLE = False
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
|
| 36 |
+
class BusinessVocabulary:
|
| 37 |
+
"""Хранит бизнес-термины и метрики компании, подставляет их в промпт модели."""
|
| 38 |
+
|
| 39 |
+
company: str = ""
|
| 40 |
+
terms: dict[str, str] = field(default_factory=dict)
|
| 41 |
+
filters: dict[str, str] = field(default_factory=dict)
|
| 42 |
+
notes: list[str] = field(default_factory=list)
|
| 43 |
+
|
| 44 |
+
# ------------------------------------------------------------------
|
| 45 |
+
# Загрузка
|
| 46 |
+
# ------------------------------------------------------------------
|
| 47 |
+
|
| 48 |
+
@classmethod
|
| 49 |
+
def from_yaml(cls, path: str | Path) -> "BusinessVocabulary":
|
| 50 |
+
"""Загружает словарь из YAML-файла."""
|
| 51 |
+
if not _YAML_AVAILABLE:
|
| 52 |
+
raise ImportError("Установи PyYAML: pip install pyyaml")
|
| 53 |
+
path = Path(path)
|
| 54 |
+
if not path.exists():
|
| 55 |
+
raise FileNotFoundError(f"Файл бизнес-словаря не найден: {path}")
|
| 56 |
+
with open(path, encoding="utf-8") as f:
|
| 57 |
+
data = yaml.safe_load(f) or {}
|
| 58 |
+
return cls(
|
| 59 |
+
company=data.get("company", ""),
|
| 60 |
+
terms=data.get("terms", {}),
|
| 61 |
+
filters=data.get("filters", {}),
|
| 62 |
+
notes=data.get("notes", []),
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
@classmethod
|
| 66 |
+
def from_dict(cls, data: dict) -> "BusinessVocabulary":
|
| 67 |
+
"""Создаёт словарь из словаря Python (удобно для API и Streamlit)."""
|
| 68 |
+
return cls(
|
| 69 |
+
company=data.get("company", ""),
|
| 70 |
+
terms=data.get("terms", {}),
|
| 71 |
+
filters=data.get("filters", {}),
|
| 72 |
+
notes=data.get("notes", []),
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
@classmethod
|
| 76 |
+
def empty(cls) -> "BusinessVocabulary":
|
| 77 |
+
"""Пустой словарь — для случая когда компания ещё не настроила термины."""
|
| 78 |
+
return cls()
|
| 79 |
+
|
| 80 |
+
# ------------------------------------------------------------------
|
| 81 |
+
# Использование
|
| 82 |
+
# ------------------------------------------------------------------
|
| 83 |
+
|
| 84 |
+
def enrich_prompt(self, question: str) -> str:
|
| 85 |
+
"""Добавляет к вопросу пользователя контекст из бизнес-словаря.
|
| 86 |
+
|
| 87 |
+
Если вопрос содержит известные термины — подставляет их определения.
|
| 88 |
+
Возвращает обогащённый вопрос для подстановки в промпт модели.
|
| 89 |
+
"""
|
| 90 |
+
if not self.terms and not self.filters and not self.notes:
|
| 91 |
+
return question
|
| 92 |
+
|
| 93 |
+
context_lines: list[str] = []
|
| 94 |
+
|
| 95 |
+
# Находим термины которые упоминаются в вопросе
|
| 96 |
+
question_lower = question.lower()
|
| 97 |
+
relevant_terms = {
|
| 98 |
+
term: definition
|
| 99 |
+
for term, definition in self.terms.items()
|
| 100 |
+
if term.lower() in question_lower
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
if relevant_terms:
|
| 104 |
+
context_lines.append("Определения терминов компании:")
|
| 105 |
+
for term, definition in relevant_terms.items():
|
| 106 |
+
context_lines.append(f" - {term}: {definition}")
|
| 107 |
+
|
| 108 |
+
if self.filters:
|
| 109 |
+
context_lines.append("Стандартные фильтры компании:")
|
| 110 |
+
for name, condition in self.filters.items():
|
| 111 |
+
context_lines.append(f" - {name}: {condition}")
|
| 112 |
+
|
| 113 |
+
if self.notes:
|
| 114 |
+
context_lines.append("Дополнительные правила:")
|
| 115 |
+
for note in self.notes:
|
| 116 |
+
context_lines.append(f" - {note}")
|
| 117 |
+
|
| 118 |
+
if not context_lines:
|
| 119 |
+
return question
|
| 120 |
+
|
| 121 |
+
context = "\n".join(context_lines)
|
| 122 |
+
return f"{question}\n\n[Контекст компании]\n{context}"
|
| 123 |
+
|
| 124 |
+
def render_system_context(self) -> str:
|
| 125 |
+
"""Текст для системного промпта — описывает все термины компании."""
|
| 126 |
+
if not self.terms and not self.filters and not self.notes:
|
| 127 |
+
return ""
|
| 128 |
+
|
| 129 |
+
lines: list[str] = []
|
| 130 |
+
if self.company:
|
| 131 |
+
lines.append(f"Компания: {self.company}")
|
| 132 |
+
lines.append("")
|
| 133 |
+
|
| 134 |
+
if self.terms:
|
| 135 |
+
lines.append("Бизнес-термины и метрики:")
|
| 136 |
+
for term, definition in self.terms.items():
|
| 137 |
+
lines.append(f" - «{term}» означает: {definition}")
|
| 138 |
+
|
| 139 |
+
if self.filters:
|
| 140 |
+
lines.append("")
|
| 141 |
+
lines.append("Стандартные условия фильтрации:")
|
| 142 |
+
for name, condition in self.filters.items():
|
| 143 |
+
lines.append(f" - {name}: {condition}")
|
| 144 |
+
|
| 145 |
+
if self.notes:
|
| 146 |
+
lines.append("")
|
| 147 |
+
lines.append("Важные правила:")
|
| 148 |
+
for note in self.notes:
|
| 149 |
+
lines.append(f" - {note}")
|
| 150 |
+
|
| 151 |
+
return "\n".join(lines)
|
| 152 |
+
|
| 153 |
+
def to_yaml_string(self) -> str:
|
| 154 |
+
"""Сериализует словарь обратно в YAML-строку (для редактора в Streamlit)."""
|
| 155 |
+
if not _YAML_AVAILABLE:
|
| 156 |
+
raise ImportError("Установи PyYAML: pip install pyyaml")
|
| 157 |
+
data = {
|
| 158 |
+
"company": self.company,
|
| 159 |
+
"terms": self.terms,
|
| 160 |
+
"filters": self.filters,
|
| 161 |
+
"notes": self.notes,
|
| 162 |
+
}
|
| 163 |
+
return yaml.dump(data, allow_unicode=True, sort_keys=False, default_flow_style=False)
|
| 164 |
+
|
| 165 |
+
def save_yaml(self, path: str | Path) -> None:
|
| 166 |
+
"""Сохраняет словарь в YAML-файл."""
|
| 167 |
+
path = Path(path)
|
| 168 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 169 |
+
with open(path, "w", encoding="utf-8") as f:
|
| 170 |
+
f.write(self.to_yaml_string())
|
| 171 |
+
|
| 172 |
+
def __bool__(self) -> bool:
|
| 173 |
+
return bool(self.terms or self.filters or self.notes)
|
src/config.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Конфигурация проекта. Читаем из .env через pydantic-settings."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 8 |
+
|
| 9 |
+
ROOT_DIR = Path(__file__).resolve().parent.parent
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class Settings(BaseSettings):
|
| 13 |
+
"""Все настройки приложения. Значения берутся из .env, переменных окружения, либо дефолтов."""
|
| 14 |
+
|
| 15 |
+
model_config = SettingsConfigDict(
|
| 16 |
+
env_file=str(ROOT_DIR / ".env"),
|
| 17 |
+
env_file_encoding="utf-8",
|
| 18 |
+
extra="ignore",
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
# Модель
|
| 22 |
+
base_model_name: str = "Qwen/Qwen2.5-Coder-3B-Instruct"
|
| 23 |
+
lora_adapter_path: str = str(ROOT_DIR / "checkpoints" / "qwen-coder-pauq-lora")
|
| 24 |
+
device: str = "cpu" # "cpu" | "cuda" | "mps"
|
| 25 |
+
|
| 26 |
+
# Данные
|
| 27 |
+
pauq_data_dir: Path = ROOT_DIR / "data" / "pauq"
|
| 28 |
+
databases_dir: Path = ROOT_DIR / "data" / "databases"
|
| 29 |
+
|
| 30 |
+
# API ключи (используется только тот, который заполнен)
|
| 31 |
+
gigachat_api_key: str = ""
|
| 32 |
+
openai_api_key: str = ""
|
| 33 |
+
yandexgpt_api_key: str = ""
|
| 34 |
+
yandexgpt_folder_id: str = ""
|
| 35 |
+
hf_token: str = ""
|
| 36 |
+
|
| 37 |
+
# FastAPI
|
| 38 |
+
api_host: str = "127.0.0.1"
|
| 39 |
+
api_port: int = 8000
|
| 40 |
+
|
| 41 |
+
# Inference defaults
|
| 42 |
+
max_new_tokens: int = 256
|
| 43 |
+
temperature: float = 0.0 # для SQL детерминизм лучше
|
| 44 |
+
do_sample: bool = False
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Singleton-инстанс. Импортируется по всему проекту: `from src.config import settings`
|
| 48 |
+
settings = Settings()
|
src/data/__init__.py
ADDED
|
File without changes
|
src/data/loader.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Загрузчик датасета PAUQ.
|
| 2 |
+
|
| 3 |
+
PAUQ распространяется в JSON-формате с полями question, query, db_id и т.д.
|
| 4 |
+
См. https://github.com/ai-forever/pauq
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Iterator
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
|
| 16 |
+
class PauqExample:
|
| 17 |
+
question: str
|
| 18 |
+
query: str # gold SQL
|
| 19 |
+
db_id: str
|
| 20 |
+
query_type: str | None = None # easy/medium/hard/extra если есть
|
| 21 |
+
raw: dict | None = None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def load_pauq_split(path: Path | str) -> list[PauqExample]:
|
| 25 |
+
"""Читает train.json / dev.json / test.json из PAUQ."""
|
| 26 |
+
path = Path(path)
|
| 27 |
+
with path.open("r", encoding="utf-8") as f:
|
| 28 |
+
raw = json.load(f)
|
| 29 |
+
|
| 30 |
+
examples: list[PauqExample] = []
|
| 31 |
+
for item in raw:
|
| 32 |
+
# PAUQ имеет несколько ревизий формата; пробуем самые частые поля
|
| 33 |
+
question = item.get("question") or item.get("question_ru") or ""
|
| 34 |
+
query = item.get("query") or item.get("sql_query") or item.get("sql") or ""
|
| 35 |
+
db_id = item.get("db_id") or item.get("database") or ""
|
| 36 |
+
if not (question and query and db_id):
|
| 37 |
+
continue
|
| 38 |
+
examples.append(
|
| 39 |
+
PauqExample(
|
| 40 |
+
question=question.strip(),
|
| 41 |
+
query=query.strip(),
|
| 42 |
+
db_id=db_id.strip(),
|
| 43 |
+
query_type=item.get("query_type") or item.get("hardness"),
|
| 44 |
+
raw=item,
|
| 45 |
+
)
|
| 46 |
+
)
|
| 47 |
+
return examples
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def iter_pauq_split(path: Path | str) -> Iterator[PauqExample]:
|
| 51 |
+
"""Удобно при больших датасетах — генератор."""
|
| 52 |
+
yield from load_pauq_split(path)
|
src/data/prompt.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PromptBuilder — формирует input для модели в формате chat-template."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
SYSTEM_PROMPT = (
|
| 6 |
+
"Ты — ассистент, который преобразует вопросы на русском языке в корректные SQL-запросы. "
|
| 7 |
+
"Тебе даётся схема базы данных в виде CREATE TABLE statements и пример нескольких строк. "
|
| 8 |
+
"Сгенерируй один SQL-запрос, который отвечает на вопрос пользователя. "
|
| 9 |
+
"Возвращай ТОЛЬКО SQL без объяснений, без markdown, без префиксов."
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def build_user_message(schema: str, question: str) -> str:
|
| 14 |
+
return f"### Schema:\n{schema}\n\n### Question:\n{question}\n\n### SQL:\n"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def build_chat_messages(schema: str, question: str) -> list[dict]:
|
| 18 |
+
"""Формат для tokenizer.apply_chat_template."""
|
| 19 |
+
return [
|
| 20 |
+
{"role": "system", "content": SYSTEM_PROMPT},
|
| 21 |
+
{"role": "user", "content": build_user_message(schema, question)},
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def build_training_example(schema: str, question: str, sql: str) -> list[dict]:
|
| 26 |
+
"""Полный диалог для SFT с ответом ассистента."""
|
| 27 |
+
msgs = build_chat_messages(schema, question)
|
| 28 |
+
msgs.append({"role": "assistant", "content": sql.strip()})
|
| 29 |
+
return msgs
|
src/data/schema.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SchemaRetriever — извлекает DDL и примеры строк из SQLite-файлов PAUQ/Spider."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import sqlite3
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
|
| 11 |
+
class TableInfo:
|
| 12 |
+
name: str
|
| 13 |
+
create_sql: str
|
| 14 |
+
sample_rows: list[tuple]
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class SchemaRetriever:
|
| 18 |
+
"""Читает структуру SQLite-БД для подачи в prompt модели."""
|
| 19 |
+
|
| 20 |
+
def __init__(self, databases_dir: Path | str):
|
| 21 |
+
self.databases_dir = Path(databases_dir)
|
| 22 |
+
|
| 23 |
+
def db_path(self, db_id: str) -> Path:
|
| 24 |
+
"""В Spider/PAUQ каждая БД лежит в databases_dir/{db_id}/{db_id}.sqlite."""
|
| 25 |
+
path = self.databases_dir / db_id / f"{db_id}.sqlite"
|
| 26 |
+
if not path.exists():
|
| 27 |
+
raise FileNotFoundError(f"Database file not found: {path}")
|
| 28 |
+
return path
|
| 29 |
+
|
| 30 |
+
def get_tables(self, db_id: str, n_sample_rows: int = 3) -> list[TableInfo]:
|
| 31 |
+
"""Возвращает список таблиц с CREATE-SQL и примером строк."""
|
| 32 |
+
path = self.db_path(db_id)
|
| 33 |
+
conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True)
|
| 34 |
+
try:
|
| 35 |
+
conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
|
| 36 |
+
cur = conn.cursor()
|
| 37 |
+
cur.execute(
|
| 38 |
+
"SELECT name, sql FROM sqlite_master "
|
| 39 |
+
"WHERE type='table' AND name NOT LIKE 'sqlite_%'"
|
| 40 |
+
)
|
| 41 |
+
rows = cur.fetchall()
|
| 42 |
+
|
| 43 |
+
tables: list[TableInfo] = []
|
| 44 |
+
for table_name, create_sql in rows:
|
| 45 |
+
if not create_sql:
|
| 46 |
+
continue
|
| 47 |
+
try:
|
| 48 |
+
cur.execute(f'SELECT * FROM "{table_name}" LIMIT {n_sample_rows}')
|
| 49 |
+
samples = cur.fetchall()
|
| 50 |
+
except sqlite3.Error:
|
| 51 |
+
samples = []
|
| 52 |
+
tables.append(
|
| 53 |
+
TableInfo(name=table_name, create_sql=create_sql.strip(), sample_rows=samples)
|
| 54 |
+
)
|
| 55 |
+
return tables
|
| 56 |
+
finally:
|
| 57 |
+
conn.close()
|
| 58 |
+
|
| 59 |
+
def render_schema(self, db_id: str, include_samples: bool = True) -> str:
|
| 60 |
+
"""Текстовое представление схемы для prompt'а."""
|
| 61 |
+
tables = self.get_tables(db_id)
|
| 62 |
+
parts: list[str] = []
|
| 63 |
+
for t in tables:
|
| 64 |
+
parts.append(t.create_sql + ";")
|
| 65 |
+
if include_samples and t.sample_rows:
|
| 66 |
+
parts.append(f"-- Примеры строк из {t.name}:")
|
| 67 |
+
for row in t.sample_rows:
|
| 68 |
+
parts.append(f"-- {row}")
|
| 69 |
+
parts.append("")
|
| 70 |
+
return "\n".join(parts).strip()
|
| 71 |
+
|
| 72 |
+
def list_databases(self) -> list[str]:
|
| 73 |
+
"""Список доступных db_id."""
|
| 74 |
+
if not self.databases_dir.exists():
|
| 75 |
+
return []
|
| 76 |
+
return sorted(p.name for p in self.databases_dir.iterdir() if p.is_dir())
|
src/db/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .connector import DbConnector
|
| 2 |
+
from .executor import SqlExecutor, QueryResult
|
| 3 |
+
|
| 4 |
+
__all__ = ["DbConnector", "SqlExecutor", "QueryResult"]
|
src/db/connector.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""DbConnector -- podklyuchenie k proizvolnoy baze dannykh i chtenie skhemy.
|
| 2 |
+
|
| 3 |
+
Podderzhivaemye tipy BD:
|
| 4 |
+
SQLite -- put k faylu: "sqlite:///path/to/db.sqlite" ili prosto put
|
| 5 |
+
PostgreSQL -- "postgresql://user:pass@host:port/dbname" (trebuet psycopg2)
|
| 6 |
+
MySQL -- "mysql://user:pass@host:port/dbname" (trebuet pymysql)
|
| 7 |
+
|
| 8 |
+
Primer:
|
| 9 |
+
conn = DbConnector("sqlite:///data/demo/sales.sqlite")
|
| 10 |
+
print(conn.render_schema())
|
| 11 |
+
tables = conn.list_tables()
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import sqlite3
|
| 17 |
+
from dataclasses import dataclass, field
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from urllib.parse import urlparse
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class ColumnInfo:
|
| 24 |
+
name: str
|
| 25 |
+
type: str
|
| 26 |
+
nullable: bool = True
|
| 27 |
+
primary_key: bool = False
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
|
| 31 |
+
class TableInfo:
|
| 32 |
+
name: str
|
| 33 |
+
columns: list[ColumnInfo] = field(default_factory=list)
|
| 34 |
+
sample_rows: list[tuple] = field(default_factory=list)
|
| 35 |
+
|
| 36 |
+
def to_ddl(self) -> str:
|
| 37 |
+
"""Generiruet CREATE TABLE statement iz metadannykh."""
|
| 38 |
+
col_parts = []
|
| 39 |
+
for col in self.columns:
|
| 40 |
+
line = f" {col.name} {col.type}"
|
| 41 |
+
if col.primary_key:
|
| 42 |
+
line += " PRIMARY KEY"
|
| 43 |
+
if not col.nullable:
|
| 44 |
+
line += " NOT NULL"
|
| 45 |
+
col_parts.append(line)
|
| 46 |
+
return f"CREATE TABLE {self.name} (\n" + ",\n".join(col_parts) + "\n);"
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class DbConnector:
|
| 50 |
+
"""Universalnyy konektor k BD. Umeet chitat skhemu dlya podstanovki v prompt."""
|
| 51 |
+
|
| 52 |
+
def __init__(self, connection_string: str, n_sample_rows: int = 2):
|
| 53 |
+
self.connection_string = self._normalize(connection_string)
|
| 54 |
+
self.n_sample_rows = n_sample_rows
|
| 55 |
+
self._db_type = self._detect_type(self.connection_string)
|
| 56 |
+
|
| 57 |
+
def list_tables(self) -> list[str]:
|
| 58 |
+
return [t.name for t in self._get_tables(n_sample_rows=0)]
|
| 59 |
+
|
| 60 |
+
def get_schema(self, include_samples: bool = True) -> list[TableInfo]:
|
| 61 |
+
return self._get_tables(n_sample_rows=self.n_sample_rows if include_samples else 0)
|
| 62 |
+
|
| 63 |
+
def render_schema(self, include_samples: bool = True) -> str:
|
| 64 |
+
tables = self.get_schema(include_samples=include_samples)
|
| 65 |
+
parts: list[str] = []
|
| 66 |
+
for t in tables:
|
| 67 |
+
parts.append(t.to_ddl())
|
| 68 |
+
if include_samples and t.sample_rows:
|
| 69 |
+
parts.append(f"-- Primery strok iz {t.name}:")
|
| 70 |
+
for row in t.sample_rows:
|
| 71 |
+
parts.append(f"-- {row}")
|
| 72 |
+
parts.append("")
|
| 73 |
+
return "\n".join(parts).strip()
|
| 74 |
+
|
| 75 |
+
def test_connection(self) -> bool:
|
| 76 |
+
try:
|
| 77 |
+
self._get_tables(n_sample_rows=0)
|
| 78 |
+
return True
|
| 79 |
+
except Exception:
|
| 80 |
+
return False
|
| 81 |
+
|
| 82 |
+
def _get_tables(self, n_sample_rows: int) -> list[TableInfo]:
|
| 83 |
+
if self._db_type == "sqlite":
|
| 84 |
+
return self._get_tables_sqlite(n_sample_rows)
|
| 85 |
+
elif self._db_type == "postgresql":
|
| 86 |
+
return self._get_tables_postgres(n_sample_rows)
|
| 87 |
+
elif self._db_type == "mysql":
|
| 88 |
+
return self._get_tables_mysql(n_sample_rows)
|
| 89 |
+
else:
|
| 90 |
+
raise ValueError(f"Neizvestnyy tip BD: {self._db_type}")
|
| 91 |
+
|
| 92 |
+
def _get_tables_sqlite(self, n_sample_rows: int) -> list[TableInfo]:
|
| 93 |
+
path = self._safe_sqlite_path(self._sqlite_path())
|
| 94 |
+
conn = sqlite3.connect(str(path))
|
| 95 |
+
conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
|
| 96 |
+
try:
|
| 97 |
+
cur = conn.cursor()
|
| 98 |
+
cur.execute(
|
| 99 |
+
"SELECT name FROM sqlite_master "
|
| 100 |
+
"WHERE type='table' AND name NOT LIKE 'sqlite_%' "
|
| 101 |
+
"ORDER BY name"
|
| 102 |
+
)
|
| 103 |
+
table_names = [r[0] for r in cur.fetchall()]
|
| 104 |
+
tables: list[TableInfo] = []
|
| 105 |
+
for name in table_names:
|
| 106 |
+
cur.execute(f'PRAGMA table_info("{name}")')
|
| 107 |
+
cols = [
|
| 108 |
+
ColumnInfo(
|
| 109 |
+
name=row[1],
|
| 110 |
+
type=row[2] or "TEXT",
|
| 111 |
+
nullable=not row[3],
|
| 112 |
+
primary_key=bool(row[5]),
|
| 113 |
+
)
|
| 114 |
+
for row in cur.fetchall()
|
| 115 |
+
]
|
| 116 |
+
samples: list[tuple] = []
|
| 117 |
+
if n_sample_rows > 0:
|
| 118 |
+
try:
|
| 119 |
+
cur.execute(f'SELECT * FROM "{name}" LIMIT {n_sample_rows}')
|
| 120 |
+
samples = cur.fetchall()
|
| 121 |
+
except sqlite3.Error:
|
| 122 |
+
pass
|
| 123 |
+
tables.append(TableInfo(name=name, columns=cols, sample_rows=samples))
|
| 124 |
+
return tables
|
| 125 |
+
finally:
|
| 126 |
+
conn.close()
|
| 127 |
+
|
| 128 |
+
def _get_tables_postgres(self, n_sample_rows: int) -> list[TableInfo]:
|
| 129 |
+
try:
|
| 130 |
+
import psycopg2 # type: ignore
|
| 131 |
+
except ImportError as e:
|
| 132 |
+
raise ImportError("Ustanovi psycopg2: pip install psycopg2-binary") from e
|
| 133 |
+
|
| 134 |
+
conn = psycopg2.connect(self.connection_string)
|
| 135 |
+
try:
|
| 136 |
+
cur = conn.cursor()
|
| 137 |
+
cur.execute(
|
| 138 |
+
"SELECT table_name FROM information_schema.tables "
|
| 139 |
+
"WHERE table_schema = 'public' AND table_type = 'BASE TABLE' "
|
| 140 |
+
"ORDER BY table_name"
|
| 141 |
+
)
|
| 142 |
+
table_names = [r[0] for r in cur.fetchall()]
|
| 143 |
+
tables: list[TableInfo] = []
|
| 144 |
+
for name in table_names:
|
| 145 |
+
cur.execute(
|
| 146 |
+
"SELECT column_name, data_type, is_nullable "
|
| 147 |
+
"FROM information_schema.columns "
|
| 148 |
+
"WHERE table_name = %s AND table_schema = 'public' "
|
| 149 |
+
"ORDER BY ordinal_position",
|
| 150 |
+
(name,),
|
| 151 |
+
)
|
| 152 |
+
cols = [
|
| 153 |
+
ColumnInfo(name=r[0], type=r[1], nullable=(r[2] == "YES"))
|
| 154 |
+
for r in cur.fetchall()
|
| 155 |
+
]
|
| 156 |
+
samples: list[tuple] = []
|
| 157 |
+
if n_sample_rows > 0:
|
| 158 |
+
cur.execute(f'SELECT * FROM "{name}" LIMIT {n_sample_rows}')
|
| 159 |
+
samples = cur.fetchall()
|
| 160 |
+
tables.append(TableInfo(name=name, columns=cols, sample_rows=samples))
|
| 161 |
+
return tables
|
| 162 |
+
finally:
|
| 163 |
+
conn.close()
|
| 164 |
+
|
| 165 |
+
def _get_tables_mysql(self, n_sample_rows: int) -> list[TableInfo]:
|
| 166 |
+
try:
|
| 167 |
+
import pymysql # type: ignore
|
| 168 |
+
except ImportError as e:
|
| 169 |
+
raise ImportError("Ustanovi pymysql: pip install pymysql") from e
|
| 170 |
+
|
| 171 |
+
parsed = urlparse(self.connection_string)
|
| 172 |
+
conn = pymysql.connect(
|
| 173 |
+
host=parsed.hostname,
|
| 174 |
+
port=parsed.port or 3306,
|
| 175 |
+
user=parsed.username,
|
| 176 |
+
password=parsed.password,
|
| 177 |
+
database=parsed.path.lstrip("/"),
|
| 178 |
+
)
|
| 179 |
+
try:
|
| 180 |
+
cur = conn.cursor()
|
| 181 |
+
cur.execute("SHOW TABLES")
|
| 182 |
+
table_names = [r[0] for r in cur.fetchall()]
|
| 183 |
+
tables: list[TableInfo] = []
|
| 184 |
+
for name in table_names:
|
| 185 |
+
cur.execute(f"DESCRIBE `{name}`")
|
| 186 |
+
cols = [
|
| 187 |
+
ColumnInfo(
|
| 188 |
+
name=r[0], type=r[1],
|
| 189 |
+
nullable=(r[2] == "YES"),
|
| 190 |
+
primary_key=(r[3] == "PRI"),
|
| 191 |
+
)
|
| 192 |
+
for r in cur.fetchall()
|
| 193 |
+
]
|
| 194 |
+
samples: list[tuple] = []
|
| 195 |
+
if n_sample_rows > 0:
|
| 196 |
+
cur.execute(f"SELECT * FROM `{name}` LIMIT {n_sample_rows}")
|
| 197 |
+
samples = cur.fetchall()
|
| 198 |
+
tables.append(TableInfo(name=name, columns=cols, sample_rows=samples))
|
| 199 |
+
return tables
|
| 200 |
+
finally:
|
| 201 |
+
conn.close()
|
| 202 |
+
|
| 203 |
+
def _sqlite_path(self) -> Path:
|
| 204 |
+
cs = self.connection_string
|
| 205 |
+
if cs.startswith("sqlite:///"):
|
| 206 |
+
return Path(cs[10:])
|
| 207 |
+
return Path(cs)
|
| 208 |
+
|
| 209 |
+
@staticmethod
|
| 210 |
+
def _safe_sqlite_path(path: Path) -> Path:
|
| 211 |
+
"""Esli ryadom s BD est journal-fayl, kopируем fayl vo vremennuyu direktoriu."""
|
| 212 |
+
import shutil
|
| 213 |
+
import tempfile
|
| 214 |
+
journal = Path(str(path) + "-journal")
|
| 215 |
+
wal = Path(str(path) + "-wal")
|
| 216 |
+
if journal.exists() or wal.exists():
|
| 217 |
+
tmp = Path(tempfile.mktemp(suffix=".sqlite"))
|
| 218 |
+
shutil.copy2(path, tmp)
|
| 219 |
+
return tmp
|
| 220 |
+
return path
|
| 221 |
+
|
| 222 |
+
@staticmethod
|
| 223 |
+
def _normalize(cs: str) -> str:
|
| 224 |
+
"""Esli peredan prosto put k faylu -- prevraschaem v sqlite:// URI."""
|
| 225 |
+
cs = cs.strip()
|
| 226 |
+
if cs.endswith(".sqlite") or cs.endswith(".db"):
|
| 227 |
+
return f"sqlite:///{cs}"
|
| 228 |
+
return cs
|
| 229 |
+
|
| 230 |
+
@staticmethod
|
| 231 |
+
def _detect_type(cs: str) -> str:
|
| 232 |
+
if cs.startswith("sqlite"):
|
| 233 |
+
return "sqlite"
|
| 234 |
+
if cs.startswith("postgresql") or cs.startswith("postgres"):
|
| 235 |
+
return "postgresql"
|
| 236 |
+
if cs.startswith("mysql"):
|
| 237 |
+
return "mysql"
|
| 238 |
+
raise ValueError(f"Ne udalos opredelit tip BD: {cs}")
|
src/db/executor.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SqlExecutor -- vypolnyaet SQL-zapros na podklyuchennoy BD i vozvraschaet rezultat.
|
| 2 |
+
|
| 3 |
+
Primer:
|
| 4 |
+
executor = SqlExecutor("sqlite:///data/demo/sales.sqlite")
|
| 5 |
+
result = executor.run("SELECT SUM(amount) FROM orders WHERE status='paid'")
|
| 6 |
+
print(result.columns)
|
| 7 |
+
print(result.rows)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import sqlite3
|
| 13 |
+
from dataclasses import dataclass, field
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from urllib.parse import urlparse
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
|
| 19 |
+
class QueryResult:
|
| 20 |
+
"""Rezultat vypolneniya SQL-zaprosa."""
|
| 21 |
+
columns: list[str]
|
| 22 |
+
rows: list[list]
|
| 23 |
+
row_count: int
|
| 24 |
+
sql: str
|
| 25 |
+
error: str | None = None
|
| 26 |
+
|
| 27 |
+
@property
|
| 28 |
+
def success(self) -> bool:
|
| 29 |
+
return self.error is None
|
| 30 |
+
|
| 31 |
+
def to_dict(self) -> dict:
|
| 32 |
+
return {
|
| 33 |
+
"columns": self.columns,
|
| 34 |
+
"rows": self.rows,
|
| 35 |
+
"row_count": self.row_count,
|
| 36 |
+
"sql": self.sql,
|
| 37 |
+
"error": self.error,
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
def to_markdown_table(self) -> str:
|
| 41 |
+
if self.error:
|
| 42 |
+
return f"Oshibka: {self.error}"
|
| 43 |
+
if not self.rows:
|
| 44 |
+
return "(pustoy rezultat)"
|
| 45 |
+
header = " | ".join(self.columns)
|
| 46 |
+
sep = " | ".join(["---"] * len(self.columns))
|
| 47 |
+
rows = "\n".join(" | ".join(str(v) for v in row) for row in self.rows)
|
| 48 |
+
return f"{header}\n{sep}\n{rows}"
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class SqlExecutor:
|
| 52 |
+
"""Vypolnyaet SQL na podklyuchennoy BD."""
|
| 53 |
+
|
| 54 |
+
MAX_ROWS = 500
|
| 55 |
+
|
| 56 |
+
def __init__(self, connection_string: str):
|
| 57 |
+
self.connection_string = connection_string.strip()
|
| 58 |
+
self._db_type = self._detect_type(self.connection_string)
|
| 59 |
+
|
| 60 |
+
def run(self, sql: str) -> QueryResult:
|
| 61 |
+
try:
|
| 62 |
+
if self._db_type == "sqlite":
|
| 63 |
+
return self._run_sqlite(sql)
|
| 64 |
+
elif self._db_type == "postgresql":
|
| 65 |
+
return self._run_postgres(sql)
|
| 66 |
+
elif self._db_type == "mysql":
|
| 67 |
+
return self._run_mysql(sql)
|
| 68 |
+
else:
|
| 69 |
+
return QueryResult(columns=[], rows=[], row_count=0, sql=sql,
|
| 70 |
+
error=f"Neizvestnyy tip BD: {self._db_type}")
|
| 71 |
+
except Exception as e:
|
| 72 |
+
return QueryResult(columns=[], rows=[], row_count=0, sql=sql, error=str(e))
|
| 73 |
+
|
| 74 |
+
def _run_sqlite(self, sql: str) -> QueryResult:
|
| 75 |
+
path = self._safe_sqlite_path(self._sqlite_path())
|
| 76 |
+
conn = sqlite3.connect(str(path))
|
| 77 |
+
conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
|
| 78 |
+
try:
|
| 79 |
+
cur = conn.cursor()
|
| 80 |
+
cur.execute(sql)
|
| 81 |
+
cols = [d[0] for d in (cur.description or [])]
|
| 82 |
+
rows = [list(r) for r in cur.fetchmany(self.MAX_ROWS)]
|
| 83 |
+
return QueryResult(columns=cols, rows=rows, row_count=len(rows), sql=sql)
|
| 84 |
+
finally:
|
| 85 |
+
conn.close()
|
| 86 |
+
|
| 87 |
+
def _run_postgres(self, sql: str) -> QueryResult:
|
| 88 |
+
try:
|
| 89 |
+
import psycopg2 # type: ignore
|
| 90 |
+
except ImportError as e:
|
| 91 |
+
raise ImportError("Ustanovi psycopg2: pip install psycopg2-binary") from e
|
| 92 |
+
|
| 93 |
+
conn = psycopg2.connect(self.connection_string)
|
| 94 |
+
try:
|
| 95 |
+
cur = conn.cursor()
|
| 96 |
+
cur.execute(sql)
|
| 97 |
+
cols = [d[0] for d in (cur.description or [])]
|
| 98 |
+
rows = [list(r) for r in cur.fetchmany(self.MAX_ROWS)]
|
| 99 |
+
return QueryResult(columns=cols, rows=rows, row_count=len(rows), sql=sql)
|
| 100 |
+
finally:
|
| 101 |
+
conn.close()
|
| 102 |
+
|
| 103 |
+
def _run_mysql(self, sql: str) -> QueryResult:
|
| 104 |
+
try:
|
| 105 |
+
import pymysql # type: ignore
|
| 106 |
+
except ImportError as e:
|
| 107 |
+
raise ImportError("Ustanovi pymysql: pip install pymysql") from e
|
| 108 |
+
|
| 109 |
+
parsed = urlparse(self.connection_string)
|
| 110 |
+
conn = pymysql.connect(
|
| 111 |
+
host=parsed.hostname,
|
| 112 |
+
port=parsed.port or 3306,
|
| 113 |
+
user=parsed.username,
|
| 114 |
+
password=parsed.password,
|
| 115 |
+
database=parsed.path.lstrip("/"),
|
| 116 |
+
)
|
| 117 |
+
try:
|
| 118 |
+
cur = conn.cursor()
|
| 119 |
+
cur.execute(sql)
|
| 120 |
+
cols = [d[0] for d in (cur.description or [])]
|
| 121 |
+
rows = [list(r) for r in cur.fetchmany(self.MAX_ROWS)]
|
| 122 |
+
return QueryResult(columns=cols, rows=rows, row_count=len(rows), sql=sql)
|
| 123 |
+
finally:
|
| 124 |
+
conn.close()
|
| 125 |
+
|
| 126 |
+
def _sqlite_path(self) -> Path:
|
| 127 |
+
cs = self.connection_string
|
| 128 |
+
if cs.startswith("sqlite:///"):
|
| 129 |
+
return Path(cs[10:])
|
| 130 |
+
return Path(cs)
|
| 131 |
+
|
| 132 |
+
@staticmethod
|
| 133 |
+
def _safe_sqlite_path(path: Path) -> Path:
|
| 134 |
+
import shutil
|
| 135 |
+
import tempfile
|
| 136 |
+
journal = Path(str(path) + "-journal")
|
| 137 |
+
wal = Path(str(path) + "-wal")
|
| 138 |
+
if journal.exists() or wal.exists():
|
| 139 |
+
tmp = Path(tempfile.mktemp(suffix=".sqlite"))
|
| 140 |
+
shutil.copy2(path, tmp)
|
| 141 |
+
return tmp
|
| 142 |
+
return path
|
| 143 |
+
|
| 144 |
+
@staticmethod
|
| 145 |
+
def _detect_type(cs: str) -> str:
|
| 146 |
+
if cs.startswith("sqlite") or cs.endswith(".sqlite") or cs.endswith(".db"):
|
| 147 |
+
return "sqlite"
|
| 148 |
+
if cs.startswith("postgresql") or cs.startswith("postgres"):
|
| 149 |
+
return "postgresql"
|
| 150 |
+
if cs.startswith("mysql"):
|
| 151 |
+
return "mysql"
|
| 152 |
+
raise ValueError(f"Ne udalos opredelit tip BD: {cs}")
|
src/evaluation/__init__.py
ADDED
|
File without changes
|
src/evaluation/evaluate.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Скрипт прогона модели на test-сплите PAUQ.
|
| 2 |
+
|
| 3 |
+
Использование:
|
| 4 |
+
python -m src.evaluation.evaluate --split dev --limit 50
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import argparse
|
| 10 |
+
import json
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
|
| 15 |
+
from src.config import settings
|
| 16 |
+
from src.data.loader import load_pauq_split
|
| 17 |
+
from src.data.schema import SchemaRetriever
|
| 18 |
+
from src.evaluation.metrics import compute_metrics
|
| 19 |
+
from src.models.inference import InferenceEngine
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def main():
|
| 23 |
+
parser = argparse.ArgumentParser()
|
| 24 |
+
parser.add_argument("--split", default="dev", choices=["train", "dev", "test"])
|
| 25 |
+
parser.add_argument("--limit", type=int, default=None, help="Ограничить число примеров")
|
| 26 |
+
parser.add_argument("--output", type=Path, default=Path("results/predictions.jsonl"))
|
| 27 |
+
args = parser.parse_args()
|
| 28 |
+
|
| 29 |
+
split_path = settings.pauq_data_dir / f"{args.split}.json"
|
| 30 |
+
examples = load_pauq_split(split_path)
|
| 31 |
+
if args.limit:
|
| 32 |
+
examples = examples[: args.limit]
|
| 33 |
+
|
| 34 |
+
schema_ret = SchemaRetriever(settings.databases_dir)
|
| 35 |
+
engine = InferenceEngine()
|
| 36 |
+
engine.load()
|
| 37 |
+
|
| 38 |
+
predictions: list[str] = []
|
| 39 |
+
golds: list[str] = []
|
| 40 |
+
db_ids: list[str] = []
|
| 41 |
+
rows = []
|
| 42 |
+
|
| 43 |
+
for ex in tqdm(examples, desc="Inference"):
|
| 44 |
+
try:
|
| 45 |
+
schema = schema_ret.render_schema(ex.db_id)
|
| 46 |
+
except FileNotFoundError:
|
| 47 |
+
continue
|
| 48 |
+
result = engine.generate(schema, ex.question)
|
| 49 |
+
predictions.append(result.sql)
|
| 50 |
+
golds.append(ex.query)
|
| 51 |
+
db_ids.append(ex.db_id)
|
| 52 |
+
rows.append(
|
| 53 |
+
{
|
| 54 |
+
"db_id": ex.db_id,
|
| 55 |
+
"question": ex.question,
|
| 56 |
+
"gold": ex.query,
|
| 57 |
+
"pred": result.sql,
|
| 58 |
+
"raw": result.raw_output,
|
| 59 |
+
}
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
args.output.parent.mkdir(parents=True, exist_ok=True)
|
| 63 |
+
with args.output.open("w", encoding="utf-8") as f:
|
| 64 |
+
for r in rows:
|
| 65 |
+
f.write(json.dumps(r, ensure_ascii=False) + "\n")
|
| 66 |
+
|
| 67 |
+
metrics = compute_metrics(predictions, golds, db_ids, settings.databases_dir)
|
| 68 |
+
print(json.dumps(metrics, indent=2, ensure_ascii=False))
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
if __name__ == "__main__":
|
| 72 |
+
main()
|
src/evaluation/metrics.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Метрики Text-to-SQL: Exact Match и Execution Accuracy."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import sqlite3
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
from src.models.postprocess import normalize_sql
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def exact_match(predicted: str, gold: str, dialect: str = "sqlite") -> bool:
|
| 12 |
+
"""Сравнение нормализованных SQL посимвольно. Грубая, но честная метрика."""
|
| 13 |
+
return normalize_sql(predicted, dialect) == normalize_sql(gold, dialect)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def execution_accuracy(
|
| 17 |
+
predicted_sql: str,
|
| 18 |
+
gold_sql: str,
|
| 19 |
+
db_path: Path | str,
|
| 20 |
+
timeout_seconds: float = 5.0,
|
| 21 |
+
) -> bool:
|
| 22 |
+
"""Прогон обоих SQL на SQLite. True если результаты совпадают как множества."""
|
| 23 |
+
db_path = Path(db_path)
|
| 24 |
+
conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True, timeout=timeout_seconds)
|
| 25 |
+
try:
|
| 26 |
+
conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
|
| 27 |
+
try:
|
| 28 |
+
pred_rows = _run(conn, predicted_sql)
|
| 29 |
+
except sqlite3.Error:
|
| 30 |
+
return False
|
| 31 |
+
try:
|
| 32 |
+
gold_rows = _run(conn, gold_sql)
|
| 33 |
+
except sqlite3.Error:
|
| 34 |
+
return False
|
| 35 |
+
return _rows_equal(pred_rows, gold_rows)
|
| 36 |
+
finally:
|
| 37 |
+
conn.close()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _run(conn: sqlite3.Connection, sql: str) -> list[tuple]:
|
| 41 |
+
cur = conn.cursor()
|
| 42 |
+
cur.execute(sql)
|
| 43 |
+
return cur.fetchall()
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _rows_equal(a: list[tuple], b: list[tuple]) -> bool:
|
| 47 |
+
"""Сравнение как мультимножеств — порядок не важен (если в SQL нет ORDER BY)."""
|
| 48 |
+
if len(a) != len(b):
|
| 49 |
+
return False
|
| 50 |
+
return sorted(map(_row_key, a)) == sorted(map(_row_key, b))
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _row_key(row: tuple) -> tuple:
|
| 54 |
+
return tuple(str(x) for x in row)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def compute_metrics(
|
| 58 |
+
predictions: list[str],
|
| 59 |
+
golds: list[str],
|
| 60 |
+
db_ids: list[str],
|
| 61 |
+
databases_dir: Path | str,
|
| 62 |
+
) -> dict:
|
| 63 |
+
"""Прогон по всему датасету. Возвращает dict с EM, EX, и счётчиками."""
|
| 64 |
+
databases_dir = Path(databases_dir)
|
| 65 |
+
n = len(predictions)
|
| 66 |
+
assert n == len(golds) == len(db_ids), "Mismatched lengths"
|
| 67 |
+
|
| 68 |
+
em_count = 0
|
| 69 |
+
ex_count = 0
|
| 70 |
+
parse_fail = 0
|
| 71 |
+
|
| 72 |
+
for pred, gold, db_id in zip(predictions, golds, db_ids):
|
| 73 |
+
if exact_match(pred, gold):
|
| 74 |
+
em_count += 1
|
| 75 |
+
|
| 76 |
+
db_path = databases_dir / db_id / f"{db_id}.sqlite"
|
| 77 |
+
if not db_path.exists():
|
| 78 |
+
parse_fail += 1
|
| 79 |
+
continue
|
| 80 |
+
|
| 81 |
+
if execution_accuracy(pred, gold, db_path):
|
| 82 |
+
ex_count += 1
|
| 83 |
+
|
| 84 |
+
return {
|
| 85 |
+
"n": n,
|
| 86 |
+
"exact_match": em_count / n if n else 0.0,
|
| 87 |
+
"execution_accuracy": ex_count / n if n else 0.0,
|
| 88 |
+
"parse_fail": parse_fail,
|
| 89 |
+
}
|
src/models/__init__.py
ADDED
|
File without changes
|
src/models/inference.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Загрузка модели + LoRA-адаптера и инференс.
|
| 2 |
+
|
| 3 |
+
На десктопе/ноутбуке без GPU работает на CPU. Медленно, но достаточно для разработки и демо.
|
| 4 |
+
На Kaggle/Colab — на GPU, быстрее.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 14 |
+
|
| 15 |
+
from src.config import settings
|
| 16 |
+
from src.data.prompt import build_chat_messages
|
| 17 |
+
from src.models.postprocess import postprocess
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
|
| 21 |
+
class GenerationResult:
|
| 22 |
+
sql: str
|
| 23 |
+
raw_output: str
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class InferenceEngine:
|
| 27 |
+
"""Singleton-обёртка над моделью. Загружается один раз при старте API."""
|
| 28 |
+
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
base_model_name: str | None = None,
|
| 32 |
+
lora_adapter_path: str | None = None,
|
| 33 |
+
device: str | None = None,
|
| 34 |
+
):
|
| 35 |
+
self.base_model_name = base_model_name or settings.base_model_name
|
| 36 |
+
self.lora_adapter_path = lora_adapter_path or settings.lora_adapter_path
|
| 37 |
+
self.device = device or settings.device
|
| 38 |
+
self.tokenizer = None
|
| 39 |
+
self.model = None
|
| 40 |
+
self._loaded = False
|
| 41 |
+
|
| 42 |
+
def load(self) -> None:
|
| 43 |
+
"""Лениво грузим модель. На CPU без квантизации."""
|
| 44 |
+
if self._loaded:
|
| 45 |
+
return
|
| 46 |
+
|
| 47 |
+
self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)
|
| 48 |
+
# bfloat16 вдвое меньше float32 (~6 ГБ vs ~12 ГБ) и поддерживается на CPU
|
| 49 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
| 50 |
+
self.base_model_name,
|
| 51 |
+
dtype=torch.bfloat16,
|
| 52 |
+
device_map=self.device if self.device != "cpu" else None,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
# Подцепляем LoRA-адаптер: сначала ищем локально, потом на HF Hub
|
| 56 |
+
adapter_path = Path(self.lora_adapter_path)
|
| 57 |
+
adapter_id = str(adapter_path) if adapter_path.exists() else self.lora_adapter_path
|
| 58 |
+
try:
|
| 59 |
+
from peft import PeftModel
|
| 60 |
+
self.model = PeftModel.from_pretrained(self.model, adapter_id)
|
| 61 |
+
except ImportError:
|
| 62 |
+
pass # peft не установлен — работаем на базовой модели
|
| 63 |
+
|
| 64 |
+
self.model.eval()
|
| 65 |
+
self._loaded = True
|
| 66 |
+
|
| 67 |
+
def generate(
|
| 68 |
+
self,
|
| 69 |
+
schema: str,
|
| 70 |
+
question: str,
|
| 71 |
+
max_new_tokens: int | None = None,
|
| 72 |
+
) -> GenerationResult:
|
| 73 |
+
"""Принимает schema (текст DDL) и вопрос, возвращает SQL."""
|
| 74 |
+
if not self._loaded:
|
| 75 |
+
self.load()
|
| 76 |
+
|
| 77 |
+
messages = build_chat_messages(schema, question)
|
| 78 |
+
prompt = self.tokenizer.apply_chat_template(
|
| 79 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 80 |
+
)
|
| 81 |
+
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
|
| 82 |
+
|
| 83 |
+
with torch.no_grad():
|
| 84 |
+
output_ids = self.model.generate(
|
| 85 |
+
**inputs,
|
| 86 |
+
max_new_tokens=max_new_tokens or settings.max_new_tokens,
|
| 87 |
+
do_sample=settings.do_sample,
|
| 88 |
+
temperature=settings.temperature if settings.do_sample else 1.0,
|
| 89 |
+
pad_token_id=self.tokenizer.eos_token_id,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
new_tokens = output_ids[0][inputs["input_ids"].shape[1] :]
|
| 93 |
+
raw = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
|
| 94 |
+
return GenerationResult(sql=postprocess(raw), raw_output=raw)
|
src/models/postprocess.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Постобработка SQL: чистка вывода модели и базовая валидация через sqlglot."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import re
|
| 6 |
+
|
| 7 |
+
import sqlglot
|
| 8 |
+
from sqlglot.errors import ParseError
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def strip_model_artifacts(text: str) -> str:
|
| 12 |
+
"""Убирает markdown-блоки, префиксы, лишний текст после SQL."""
|
| 13 |
+
# ```sql ... ```
|
| 14 |
+
m = re.search(r"```(?:sql)?\s*(.*?)```", text, re.DOTALL | re.IGNORECASE)
|
| 15 |
+
if m:
|
| 16 |
+
text = m.group(1)
|
| 17 |
+
|
| 18 |
+
# Убираем "SQL:", "Ответ:" и т.п. в начале
|
| 19 |
+
text = re.sub(r"^\s*(?:SQL|Ответ|Answer)\s*:\s*", "", text, flags=re.IGNORECASE)
|
| 20 |
+
|
| 21 |
+
# Если есть несколько SQL — берём первый до точки с запятой
|
| 22 |
+
text = text.strip()
|
| 23 |
+
if ";" in text:
|
| 24 |
+
head, _, _ = text.partition(";")
|
| 25 |
+
text = head.strip() + ";"
|
| 26 |
+
|
| 27 |
+
return text.strip()
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def is_valid_sql(sql: str, dialect: str = "sqlite") -> bool:
|
| 31 |
+
"""Парсится ли SQL через sqlglot."""
|
| 32 |
+
try:
|
| 33 |
+
sqlglot.parse_one(sql, dialect=dialect)
|
| 34 |
+
return True
|
| 35 |
+
except ParseError:
|
| 36 |
+
return False
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def normalize_sql(sql: str, dialect: str = "sqlite") -> str:
|
| 40 |
+
"""Нормализация для Exact Match: единый регистр ключевых слов, пробелы."""
|
| 41 |
+
try:
|
| 42 |
+
return sqlglot.parse_one(sql, dialect=dialect).sql(dialect=dialect, pretty=False).lower()
|
| 43 |
+
except ParseError:
|
| 44 |
+
# Если не парсится — просто нижний регистр и схлопывание пробелов
|
| 45 |
+
return re.sub(r"\s+", " ", sql.lower()).strip().rstrip(";")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def postprocess(raw_output: str) -> str:
|
| 49 |
+
"""Полный pipeline постобработки."""
|
| 50 |
+
return strip_model_artifacts(raw_output)
|
streamlit_app.py
ADDED
|
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Streamlit-интерфейс утилиты Ru2SQL.
|
| 2 |
+
|
| 3 |
+
Запуск:
|
| 4 |
+
streamlit run streamlit_app.py
|
| 5 |
+
|
| 6 |
+
Что умеет:
|
| 7 |
+
- Подключиться к любой SQLite/PostgreSQL/MySQL базе данных
|
| 8 |
+
- Загрузить бизнес-словарь компании из YAML-файла или редактировать прямо в браузере
|
| 9 |
+
- Принять вопрос на русском → сгенерировать SQL → выполнить → показать результат
|
| 10 |
+
- Хранить историю запросов в текущей сессии
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import sys
|
| 16 |
+
import time
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import streamlit as st
|
| 20 |
+
|
| 21 |
+
# Путь к src/
|
| 22 |
+
ROOT = Path(__file__).resolve().parent
|
| 23 |
+
sys.path.insert(0, str(ROOT))
|
| 24 |
+
|
| 25 |
+
# ──────────────────────────────────────────────
|
| 26 |
+
# Конфигурация страницы
|
| 27 |
+
# ──────────────────────────────────────────────
|
| 28 |
+
st.set_page_config(
|
| 29 |
+
page_title="Ru2SQL — Natural Language → SQL",
|
| 30 |
+
page_icon="🗄️",
|
| 31 |
+
layout="wide",
|
| 32 |
+
initial_sidebar_state="expanded",
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
# ──────────────────────────────────────────────
|
| 36 |
+
# CSS
|
| 37 |
+
# ──────────────────────────────────────────────
|
| 38 |
+
st.markdown("""
|
| 39 |
+
<style>
|
| 40 |
+
.sql-box {
|
| 41 |
+
background: #1e1e2e;
|
| 42 |
+
color: #cdd6f4;
|
| 43 |
+
font-family: 'Courier New', monospace;
|
| 44 |
+
font-size: 14px;
|
| 45 |
+
padding: 16px;
|
| 46 |
+
border-radius: 8px;
|
| 47 |
+
border-left: 4px solid #89b4fa;
|
| 48 |
+
white-space: pre-wrap;
|
| 49 |
+
margin: 8px 0;
|
| 50 |
+
}
|
| 51 |
+
.metric-card {
|
| 52 |
+
background: #313244;
|
| 53 |
+
padding: 12px 16px;
|
| 54 |
+
border-radius: 8px;
|
| 55 |
+
text-align: center;
|
| 56 |
+
}
|
| 57 |
+
.status-ok { color: #a6e3a1; font-weight: bold; }
|
| 58 |
+
.status-err { color: #f38ba8; font-weight: bold; }
|
| 59 |
+
.history-item {
|
| 60 |
+
border-left: 3px solid #89b4fa;
|
| 61 |
+
padding: 8px 12px;
|
| 62 |
+
margin: 6px 0;
|
| 63 |
+
background: #1e1e2e;
|
| 64 |
+
border-radius: 0 6px 6px 0;
|
| 65 |
+
}
|
| 66 |
+
</style>
|
| 67 |
+
""", unsafe_allow_html=True)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ──────────────────────────────────────────────
|
| 71 |
+
# Session state
|
| 72 |
+
# ──────────────────────────────────────────────
|
| 73 |
+
def _default_vocab_yaml() -> str:
|
| 74 |
+
example = ROOT / "configs" / "example_vocabulary.yaml"
|
| 75 |
+
if example.exists():
|
| 76 |
+
return example.read_text(encoding="utf-8")
|
| 77 |
+
return (
|
| 78 |
+
"company: Моя компания\n\n"
|
| 79 |
+
"terms:\n"
|
| 80 |
+
" выручка: SUM(orders.amount) WHERE status = 'paid'\n\n"
|
| 81 |
+
"filters:\n"
|
| 82 |
+
" только_оплаченные: orders.status = 'paid'\n\n"
|
| 83 |
+
"notes: []\n"
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _init_state():
|
| 88 |
+
defaults = {
|
| 89 |
+
"history": [],
|
| 90 |
+
"model_loaded": False,
|
| 91 |
+
"engine": None,
|
| 92 |
+
"db_connector": None,
|
| 93 |
+
"db_executor": None,
|
| 94 |
+
"vocabulary": None,
|
| 95 |
+
"db_connection_string": "",
|
| 96 |
+
"vocab_yaml": _default_vocab_yaml(),
|
| 97 |
+
}
|
| 98 |
+
for k, v in defaults.items():
|
| 99 |
+
if k not in st.session_state:
|
| 100 |
+
st.session_state[k] = v
|
| 101 |
+
|
| 102 |
+
_init_state()
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# ──────────────────────────────────────────────
|
| 106 |
+
# Вспомогательные функции
|
| 107 |
+
# ──────────────────────────────────────────────
|
| 108 |
+
@st.cache_resource(show_spinner="Загружаю модель… (~30 с на первый раз)")
|
| 109 |
+
def _load_engine():
|
| 110 |
+
from src.models.inference import InferenceEngine
|
| 111 |
+
engine = InferenceEngine()
|
| 112 |
+
engine.load()
|
| 113 |
+
return engine
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _connect_db(cs: str):
|
| 117 |
+
from src.db.connector import DbConnector
|
| 118 |
+
from src.db.executor import SqlExecutor
|
| 119 |
+
connector = DbConnector(cs)
|
| 120 |
+
executor = SqlExecutor(cs)
|
| 121 |
+
return connector, executor
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _load_vocab_from_yaml(yaml_text: str):
|
| 125 |
+
import tempfile
|
| 126 |
+
from src.business.vocabulary import BusinessVocabulary
|
| 127 |
+
tmp = Path(tempfile.mktemp(suffix=".yaml"))
|
| 128 |
+
tmp.write_text(yaml_text, encoding="utf-8")
|
| 129 |
+
vocab = BusinessVocabulary.from_yaml(tmp)
|
| 130 |
+
tmp.unlink(missing_ok=True)
|
| 131 |
+
return vocab
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# ──────────────────────────────────────────────
|
| 135 |
+
# Боковая панель
|
| 136 |
+
# ──────────────────────────────────────────────
|
| 137 |
+
with st.sidebar:
|
| 138 |
+
st.title("⚙️ Настройки")
|
| 139 |
+
|
| 140 |
+
# ── Модель — загружается автоматически при старте ──
|
| 141 |
+
st.subheader("🤖 Модель")
|
| 142 |
+
if not st.session_state.model_loaded:
|
| 143 |
+
with st.spinner("Загружаю модель…"):
|
| 144 |
+
try:
|
| 145 |
+
st.session_state.engine = _load_engine()
|
| 146 |
+
st.session_state.model_loaded = True
|
| 147 |
+
except Exception as e:
|
| 148 |
+
st.error(f"Ошибка загрузки модели: {e}")
|
| 149 |
+
|
| 150 |
+
if st.session_state.model_loaded:
|
| 151 |
+
st.markdown('<span class="status-ok">✅ Модель готова</span>', unsafe_allow_html=True)
|
| 152 |
+
else:
|
| 153 |
+
st.markdown('<span class="status-err">⚠️ Модель не загружена</span>', unsafe_allow_html=True)
|
| 154 |
+
|
| 155 |
+
st.divider()
|
| 156 |
+
|
| 157 |
+
# ── База данных ──
|
| 158 |
+
st.subheader("🗄️ База данных")
|
| 159 |
+
|
| 160 |
+
db_type = st.radio("Тип подключения", ["SQLite файл", "Строка подключения"],
|
| 161 |
+
horizontal=True)
|
| 162 |
+
|
| 163 |
+
if db_type == "SQLite файл":
|
| 164 |
+
uploaded = st.file_uploader("Загрузить .sqlite файл", type=["sqlite", "db"])
|
| 165 |
+
use_demo = st.checkbox("Использовать демо-базу", value=True)
|
| 166 |
+
|
| 167 |
+
if use_demo:
|
| 168 |
+
demo_path = ROOT / "data" / "demo" / "sales.sqlite"
|
| 169 |
+
cs = str(demo_path)
|
| 170 |
+
elif uploaded:
|
| 171 |
+
import tempfile
|
| 172 |
+
tmp_db = Path(tempfile.mktemp(suffix=".sqlite"))
|
| 173 |
+
tmp_db.write_bytes(uploaded.read())
|
| 174 |
+
cs = str(tmp_db)
|
| 175 |
+
else:
|
| 176 |
+
cs = ""
|
| 177 |
+
else:
|
| 178 |
+
cs = st.text_input(
|
| 179 |
+
"Строка подключения",
|
| 180 |
+
placeholder="postgresql://user:pass@localhost/mydb",
|
| 181 |
+
value=st.session_state.db_connection_string,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
if cs and st.button("Подключиться к БД", use_container_width=True):
|
| 185 |
+
try:
|
| 186 |
+
connector, executor = _connect_db(cs)
|
| 187 |
+
tables = connector.list_tables()
|
| 188 |
+
st.session_state.db_connector = connector
|
| 189 |
+
st.session_state.db_executor = executor
|
| 190 |
+
st.session_state.db_connection_string = cs
|
| 191 |
+
st.success(f"Подключено! Таблиц: {len(tables)}")
|
| 192 |
+
except Exception as e:
|
| 193 |
+
st.error(f"Ошибка подключения: {e}")
|
| 194 |
+
|
| 195 |
+
if st.session_state.db_connector:
|
| 196 |
+
tables = st.session_state.db_connector.list_tables()
|
| 197 |
+
st.markdown('<span class="status-ok">✅ БД подключена</span>', unsafe_allow_html=True)
|
| 198 |
+
with st.expander("Таблицы"):
|
| 199 |
+
for t in tables:
|
| 200 |
+
st.code(t)
|
| 201 |
+
|
| 202 |
+
st.divider()
|
| 203 |
+
|
| 204 |
+
# ── Бизнес-словарь ──
|
| 205 |
+
st.subheader("📖 Бизнес-словарь")
|
| 206 |
+
|
| 207 |
+
vocab_yaml = st.text_area(
|
| 208 |
+
"YAML-конфигурация",
|
| 209 |
+
value=st.session_state.vocab_yaml,
|
| 210 |
+
height=260,
|
| 211 |
+
help="Определите термины вашей компании — модель будет их учитывать при генерации SQL",
|
| 212 |
+
)
|
| 213 |
+
st.session_state.vocab_yaml = vocab_yaml
|
| 214 |
+
|
| 215 |
+
if st.button("Применить словарь", use_container_width=True):
|
| 216 |
+
try:
|
| 217 |
+
st.session_state.vocabulary = _load_vocab_from_yaml(vocab_yaml)
|
| 218 |
+
st.success("Словарь применён!")
|
| 219 |
+
except Exception as e:
|
| 220 |
+
st.error(f"Ошибка в YAML: {e}")
|
| 221 |
+
|
| 222 |
+
if st.session_state.vocabulary:
|
| 223 |
+
v = st.session_state.vocabulary
|
| 224 |
+
st.markdown(f'<span class="status-ok">✅ Словарь: {v.company or "загружен"}</span>',
|
| 225 |
+
unsafe_allow_html=True)
|
| 226 |
+
terms_count = len(v.terms)
|
| 227 |
+
if terms_count:
|
| 228 |
+
st.caption(f"{terms_count} терминов определено")
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# ──────────────────────────────────────────────
|
| 232 |
+
# Основная область
|
| 233 |
+
# ──────────────────────────────────────────────
|
| 234 |
+
st.title("🗄️ Ru2SQL — Бизнес-аналитика на русском языке")
|
| 235 |
+
st.caption("Задайте вопрос на русском → получите SQL и данные из вашей базы")
|
| 236 |
+
|
| 237 |
+
tab_query, tab_schema, tab_history = st.tabs(["💬 Запрос", "📐 Схема БД", "🕓 История"])
|
| 238 |
+
|
| 239 |
+
# ──────────── Вкладка: Запрос ────────────
|
| 240 |
+
with tab_query:
|
| 241 |
+
ready = st.session_state.model_loaded and st.session_state.db_connector is not None
|
| 242 |
+
|
| 243 |
+
if not ready:
|
| 244 |
+
cols = st.columns(2)
|
| 245 |
+
with cols[0]:
|
| 246 |
+
if not st.session_state.model_loaded:
|
| 247 |
+
st.warning("⚠️ Загрузите модель в левой панели")
|
| 248 |
+
with cols[1]:
|
| 249 |
+
if st.session_state.db_connector is None:
|
| 250 |
+
st.warning("⚠️ Подключитесь к базе данных в левой панели")
|
| 251 |
+
|
| 252 |
+
question = st.text_area(
|
| 253 |
+
"Ваш вопрос",
|
| 254 |
+
placeholder="Например: Какая выручка за январь этого года?",
|
| 255 |
+
height=100,
|
| 256 |
+
disabled=not ready,
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
col_btn, col_hint = st.columns([1, 4])
|
| 260 |
+
with col_btn:
|
| 261 |
+
run_btn = st.button("▶ Выполнить", type="primary",
|
| 262 |
+
disabled=not ready or not question.strip(),
|
| 263 |
+
use_container_width=True)
|
| 264 |
+
with col_hint:
|
| 265 |
+
if ready:
|
| 266 |
+
st.caption("Модель сгенерирует SQL и выполнит его на вашей БД")
|
| 267 |
+
|
| 268 |
+
# Быстрые примеры
|
| 269 |
+
if st.session_state.db_connection_string and "sales" in st.session_state.db_connection_string:
|
| 270 |
+
st.caption("💡 Попробуйте:")
|
| 271 |
+
example_cols = st.columns(3)
|
| 272 |
+
examples = [
|
| 273 |
+
"Какая выручка за 2026 год?",
|
| 274 |
+
"Топ-5 клиентов по сумме заказов",
|
| 275 |
+
"Сколько заказов по каждому менеджеру?",
|
| 276 |
+
]
|
| 277 |
+
for i, ex in enumerate(examples):
|
| 278 |
+
with example_cols[i]:
|
| 279 |
+
if st.button(ex, key=f"ex_{i}", use_container_width=True):
|
| 280 |
+
question = ex
|
| 281 |
+
run_btn = True
|
| 282 |
+
|
| 283 |
+
if run_btn and question.strip():
|
| 284 |
+
engine = st.session_state.engine
|
| 285 |
+
connector = st.session_state.db_connector
|
| 286 |
+
executor = st.session_state.db_executor
|
| 287 |
+
vocab = st.session_state.vocabulary
|
| 288 |
+
|
| 289 |
+
# Обогащаем вопрос бизнес-словарём
|
| 290 |
+
enriched_question = vocab.enrich_prompt(question) if vocab else question
|
| 291 |
+
|
| 292 |
+
# Получаем схему
|
| 293 |
+
schema = connector.render_schema(include_samples=True)
|
| 294 |
+
|
| 295 |
+
with st.spinner("Генерирую SQL…"):
|
| 296 |
+
t0 = time.time()
|
| 297 |
+
result = engine.generate(schema, enriched_question)
|
| 298 |
+
gen_time = time.time() - t0
|
| 299 |
+
|
| 300 |
+
st.subheader("Сгенерированный SQL")
|
| 301 |
+
st.markdown(f'<div class="sql-box">{result.sql}</div>', unsafe_allow_html=True)
|
| 302 |
+
|
| 303 |
+
col1, col2 = st.columns(2)
|
| 304 |
+
col1.metric("Время генерации", f"{gen_time:.1f} с")
|
| 305 |
+
|
| 306 |
+
# Выполняем SQL
|
| 307 |
+
if result.sql.strip():
|
| 308 |
+
with st.spinner("Выполняю запрос…"):
|
| 309 |
+
qr = executor.run(result.sql)
|
| 310 |
+
|
| 311 |
+
if qr.success:
|
| 312 |
+
col2.metric("Строк в результате", qr.row_count)
|
| 313 |
+
st.subheader("Результат")
|
| 314 |
+
if qr.rows:
|
| 315 |
+
import pandas as pd
|
| 316 |
+
df = pd.DataFrame(qr.rows, columns=qr.columns)
|
| 317 |
+
st.dataframe(df, use_container_width=True)
|
| 318 |
+
else:
|
| 319 |
+
st.info("Запрос выполнен успешно, результат пустой")
|
| 320 |
+
else:
|
| 321 |
+
col2.error("Ошибка выполнения")
|
| 322 |
+
st.error(f"SQL ошибка: {qr.error}")
|
| 323 |
+
|
| 324 |
+
# Добавляем в историю
|
| 325 |
+
st.session_state.history.append({
|
| 326 |
+
"question": question,
|
| 327 |
+
"sql": result.sql,
|
| 328 |
+
"success": qr.success if result.sql.strip() else False,
|
| 329 |
+
"rows": qr.row_count if result.sql.strip() and qr.success else 0,
|
| 330 |
+
"time": gen_time,
|
| 331 |
+
})
|
| 332 |
+
|
| 333 |
+
# ──────────── Вкладка: Схема БД ────────────
|
| 334 |
+
with tab_schema:
|
| 335 |
+
if st.session_state.db_connector is None:
|
| 336 |
+
st.info("Подключитесь к базе данных в левой панели")
|
| 337 |
+
else:
|
| 338 |
+
connector = st.session_state.db_connector
|
| 339 |
+
st.subheader("Структура базы данных")
|
| 340 |
+
|
| 341 |
+
show_samples = st.toggle("Показывать примеры строк", value=True)
|
| 342 |
+
schema_text = connector.render_schema(include_samples=show_samples)
|
| 343 |
+
|
| 344 |
+
for table in connector.get_schema(include_samples=show_samples):
|
| 345 |
+
with st.expander(f"📋 {table.name} ({len(table.columns)} колонок)"):
|
| 346 |
+
st.code(table.to_ddl(), language="sql")
|
| 347 |
+
if show_samples and table.sample_rows:
|
| 348 |
+
import pandas as pd
|
| 349 |
+
cols = [c.name for c in table.columns]
|
| 350 |
+
st.caption("Примеры строк:")
|
| 351 |
+
st.dataframe(
|
| 352 |
+
pd.DataFrame(table.sample_rows, columns=cols),
|
| 353 |
+
use_container_width=True,
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
# ──────────── Вкладка: История ────────────
|
| 357 |
+
with tab_history:
|
| 358 |
+
history = st.session_state.history
|
| 359 |
+
if not history:
|
| 360 |
+
st.info("История запросов пуста. Задайте первый вопрос на вкладке «Запрос».")
|
| 361 |
+
else:
|
| 362 |
+
st.subheader(f"История запросов ({len(history)})")
|
| 363 |
+
|
| 364 |
+
if st.button("Очистить историю"):
|
| 365 |
+
st.session_state.history = []
|
| 366 |
+
st.rerun()
|
| 367 |
+
|
| 368 |
+
for i, item in enumerate(reversed(history)):
|
| 369 |
+
status = "✅" if item["success"] else "❌"
|
| 370 |
+
with st.expander(f"{status} {item['question']}", expanded=(i == 0)):
|
| 371 |
+
st.markdown(f'<div class="sql-box">{item["sql"]}</div>', unsafe_allow_html=True)
|
| 372 |
+
cols = st.columns(3)
|
| 373 |
+
cols[0].metric("Время генерации", f"{item['time']:.1f} с")
|
| 374 |
+
cols[1].metric("Строк", item["rows"])
|
| 375 |
+
cols[2].metric("Статус", "OK" if item["success"] else "Ошибка")
|
tests/__init__.py
ADDED
|
File without changes
|
tests/test_metrics.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Тесты на метрики EM и EX."""
|
| 2 |
+
|
| 3 |
+
import sqlite3
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
|
| 8 |
+
from src.evaluation.metrics import exact_match, execution_accuracy
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def test_exact_match_simple():
|
| 12 |
+
assert exact_match("SELECT * FROM t", "select * from t")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def test_exact_match_whitespace():
|
| 16 |
+
assert exact_match("SELECT * FROM t", "SELECT * FROM t")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def test_exact_match_negative():
|
| 20 |
+
assert not exact_match("SELECT a FROM t", "SELECT b FROM t")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@pytest.fixture
|
| 24 |
+
def tmp_sqlite(tmp_path: Path) -> Path:
|
| 25 |
+
db = tmp_path / "tiny.sqlite"
|
| 26 |
+
conn = sqlite3.connect(db)
|
| 27 |
+
conn.execute("CREATE TABLE users (id INT, name TEXT)")
|
| 28 |
+
conn.executemany("INSERT INTO users VALUES (?, ?)", [(1, "a"), (2, "b")])
|
| 29 |
+
conn.commit()
|
| 30 |
+
conn.close()
|
| 31 |
+
return db
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def test_execution_accuracy_match(tmp_sqlite: Path):
|
| 35 |
+
pred = "SELECT id FROM users ORDER BY id"
|
| 36 |
+
gold = "SELECT id FROM users ORDER BY id"
|
| 37 |
+
assert execution_accuracy(pred, gold, tmp_sqlite)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def test_execution_accuracy_set_equal(tmp_sqlite: Path):
|
| 41 |
+
pred = "SELECT id FROM users ORDER BY id DESC"
|
| 42 |
+
gold = "SELECT id FROM users ORDER BY id ASC"
|
| 43 |
+
# Без ORDER BY проверки — как множества они равны
|
| 44 |
+
assert execution_accuracy(pred, gold, tmp_sqlite)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def test_execution_accuracy_mismatch(tmp_sqlite: Path):
|
| 48 |
+
pred = "SELECT id FROM users WHERE id = 1"
|
| 49 |
+
gold = "SELECT id FROM users WHERE id = 2"
|
| 50 |
+
assert not execution_accuracy(pred, gold, tmp_sqlite)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def test_execution_accuracy_invalid_pred(tmp_sqlite: Path):
|
| 54 |
+
pred = "SELEC bad sql"
|
| 55 |
+
gold = "SELECT id FROM users"
|
| 56 |
+
assert not execution_accuracy(pred, gold, tmp_sqlite)
|
tests/test_postprocess.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Тесты на постобработку SQL."""
|
| 2 |
+
|
| 3 |
+
from src.models.postprocess import (
|
| 4 |
+
is_valid_sql,
|
| 5 |
+
normalize_sql,
|
| 6 |
+
postprocess,
|
| 7 |
+
strip_model_artifacts,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def test_strip_markdown_block():
|
| 12 |
+
raw = "```sql\nSELECT * FROM users;\n```"
|
| 13 |
+
assert strip_model_artifacts(raw).startswith("SELECT")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def test_strip_sql_prefix():
|
| 17 |
+
raw = "SQL: SELECT 1;"
|
| 18 |
+
assert strip_model_artifacts(raw).startswith("SELECT")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def test_keeps_first_statement():
|
| 22 |
+
raw = "SELECT 1; SELECT 2;"
|
| 23 |
+
out = strip_model_artifacts(raw)
|
| 24 |
+
assert "SELECT 1" in out
|
| 25 |
+
assert "SELECT 2" not in out
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def test_valid_sql():
|
| 29 |
+
assert is_valid_sql("SELECT * FROM students WHERE id = 1")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def test_invalid_sql():
|
| 33 |
+
assert not is_valid_sql("SELEC * FRM where")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def test_normalize_em():
|
| 37 |
+
a = "SELECT * FROM Users"
|
| 38 |
+
b = "select * from users"
|
| 39 |
+
assert normalize_sql(a) == normalize_sql(b)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def test_postprocess_full():
|
| 43 |
+
raw = "```sql\nSELECT name FROM students WHERE group_id = 1;\nSELECT 2;\n```"
|
| 44 |
+
out = postprocess(raw)
|
| 45 |
+
assert out.startswith("SELECT name")
|
| 46 |
+
assert "SELECT 2" not in out
|
tests/test_prompt.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Тесты на PromptBuilder."""
|
| 2 |
+
|
| 3 |
+
from src.data.prompt import (
|
| 4 |
+
SYSTEM_PROMPT,
|
| 5 |
+
build_chat_messages,
|
| 6 |
+
build_training_example,
|
| 7 |
+
build_user_message,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def test_user_message_contains_parts():
|
| 12 |
+
msg = build_user_message("CREATE TABLE t (id INT);", "Покажи всё")
|
| 13 |
+
assert "Schema:" in msg
|
| 14 |
+
assert "Question:" in msg
|
| 15 |
+
assert "SQL:" in msg
|
| 16 |
+
assert "CREATE TABLE" in msg
|
| 17 |
+
assert "Покажи всё" in msg
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def test_chat_messages_have_system_and_user():
|
| 21 |
+
msgs = build_chat_messages("schema", "question")
|
| 22 |
+
assert len(msgs) == 2
|
| 23 |
+
assert msgs[0]["role"] == "system"
|
| 24 |
+
assert msgs[0]["content"] == SYSTEM_PROMPT
|
| 25 |
+
assert msgs[1]["role"] == "user"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def test_training_example_has_assistant():
|
| 29 |
+
msgs = build_training_example("schema", "question", "SELECT 1")
|
| 30 |
+
assert len(msgs) == 3
|
| 31 |
+
assert msgs[2]["role"] == "assistant"
|
| 32 |
+
assert msgs[2]["content"] == "SELECT 1"
|
tests/test_schema.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Тесты на SchemaRetriever."""
|
| 2 |
+
|
| 3 |
+
import sqlite3
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import pytest
|
| 7 |
+
|
| 8 |
+
from src.data.schema import SchemaRetriever
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@pytest.fixture
|
| 12 |
+
def fake_databases_dir(tmp_path: Path) -> Path:
|
| 13 |
+
"""Создаёт структуру databases/uni/uni.sqlite с двумя таблицами."""
|
| 14 |
+
db_id = "uni"
|
| 15 |
+
(tmp_path / db_id).mkdir()
|
| 16 |
+
db_path = tmp_path / db_id / f"{db_id}.sqlite"
|
| 17 |
+
conn = sqlite3.connect(db_path)
|
| 18 |
+
conn.execute("CREATE TABLE students (id INTEGER PRIMARY KEY, name TEXT)")
|
| 19 |
+
conn.execute("CREATE TABLE groups (id INTEGER PRIMARY KEY, faculty TEXT)")
|
| 20 |
+
conn.execute("INSERT INTO students VALUES (1, 'Иван')")
|
| 21 |
+
conn.execute("INSERT INTO groups VALUES (10, 'ПИ')")
|
| 22 |
+
conn.commit()
|
| 23 |
+
conn.close()
|
| 24 |
+
return tmp_path
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def test_list_databases(fake_databases_dir: Path):
|
| 28 |
+
r = SchemaRetriever(fake_databases_dir)
|
| 29 |
+
assert r.list_databases() == ["uni"]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def test_get_tables(fake_databases_dir: Path):
|
| 33 |
+
r = SchemaRetriever(fake_databases_dir)
|
| 34 |
+
tables = r.get_tables("uni")
|
| 35 |
+
names = sorted(t.name for t in tables)
|
| 36 |
+
assert names == ["groups", "students"]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def test_render_schema_contains_create(fake_databases_dir: Path):
|
| 40 |
+
r = SchemaRetriever(fake_databases_dir)
|
| 41 |
+
text = r.render_schema("uni")
|
| 42 |
+
assert "CREATE TABLE" in text
|
| 43 |
+
assert "students" in text
|
| 44 |
+
assert "groups" in text
|