Forgekit / forgekit /ai_advisor.py
AIencoder's picture
Update forgekit/ai_advisor.py
58a2c61 verified
"""AI-powered merge advisor using Groq API (free, fast inference)."""
import os
import requests
from typing import Optional
# Groq's OpenAI-compatible Chat Completions endpoint (note the /openai/v1 prefix).
GROQ_API_URL: str = (
    "https://api.groq.com/openai/v1"
    "/chat/completions"
)
# Model ID used by _query_groq when the caller does not pass one.
DEFAULT_MODEL: str = "llama-3.3-70b-versatile"
def _query_groq(
    prompt: str,
    system: str = "",
    model: str = DEFAULT_MODEL,
    api_key: Optional[str] = None,
    max_tokens: int = 1024,
    temperature: float = 0.7,
    timeout: float = 30,
) -> str:
    """Send one chat-completion request to Groq and return the reply text.

    Expected failures are never raised. A missing key, HTTP errors,
    network problems and malformed responses are all turned into a
    user-facing markdown message that the caller can display as-is.

    Args:
        prompt: User message.
        system: Optional system prompt; left out of the request when empty.
        model: Groq model ID.
        api_key: Groq API key (free at console.groq.com). When empty or None,
            the ``GROQ_API_KEY`` environment variable is used instead.
        max_tokens: Maximum response length in tokens.
        temperature: Sampling temperature sent with the request.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        The stripped assistant reply, or a human-readable error message.
    """
    # Strip both key sources: keys pasted into a textbox or exported into the
    # environment often carry a trailing newline, which is not a valid
    # HTTP header value.
    key = (api_key or "").strip() or os.environ.get("GROQ_API_KEY", "").strip()
    if not key:
        return (
            "**Groq API Key required** — the AI Advisor uses Groq for fast, free inference.\n\n"
            "1. Go to [console.groq.com](https://console.groq.com) and sign up (free, no credit card)\n"
            "2. Create an API key\n"
            "3. Paste it in the field above\n\n"
            "Groq gives you thousands of free requests per day with Llama 3.3 70B!"
        )

    messages = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": prompt})

    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
    }

    # Only the network call sits inside the broad handler. This helper feeds
    # a UI, so any transport failure is reported as text instead of raised.
    try:
        resp = requests.post(GROQ_API_URL, headers=headers, json=payload, timeout=timeout)
    except requests.exceptions.Timeout:
        return "Request timed out — try again."
    except Exception as e:
        return f"Error: {e}"

    if resp.status_code == 429:
        return "Rate limited — Groq free tier allows ~30 requests/min. Wait a moment and try again."
    if resp.status_code == 401:
        return "Invalid Groq API key. Get a free one at [console.groq.com](https://console.groq.com)."
    if resp.status_code != 200:
        # OpenAI-compatible APIs usually return {"error": {"message": ...}};
        # show that reason (e.g. an unknown model) when it is available.
        try:
            detail = resp.json().get("error", {}).get("message", "")
        except (ValueError, AttributeError):
            detail = ""
        if detail:
            return f"Groq API error (status {resp.status_code}): {detail}"
        return f"Groq API error (status {resp.status_code}). Try again."

    try:
        content = resp.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError):
        return "Groq returned an unexpected response. Try again."
    # The OpenAI-compatible schema allows a null content field; avoid
    # calling .strip() on None.
    if not isinstance(content, str):
        return "Groq returned an empty response. Try again."
    return content.strip()
# ===== SYSTEM PROMPT =====
# System prompt passed to every advisor request: it sets the assistant's persona,
# the domain knowledge it should draw on, and the expected answer format.
# Each tuple entry is one line of the final prompt.
ADVISOR_SYSTEM = "\n".join(
    (
        "You are ForgeKit AI, an expert assistant for merging large language "
        "models using mergekit. You have deep knowledge of:",
        "- Model architectures (LLaMA, Qwen, Mistral, Gemma, Phi)",
        "- Merge methods: DARE-TIES, TIES, SLERP, Linear, Task Arithmetic, "
        "Passthrough (Frankenmerge)",
        "- Optimal weight/density configurations for different use cases",
        "- Common pitfalls and best practices",
        "Be concise, practical, and specific. Always give concrete numbers "
        "for weights and densities.",
        "Format responses with markdown headers and bullet points for readability.",
    )
)
# ===== AI FEATURES =====
def merge_advisor(
    models_text: str,
    goal: str = "",
    api_key: Optional[str] = None,
) -> str:
    """Ask the AI advisor for a complete merge recipe.

    Args:
        models_text: Model IDs, one per line; blank lines are ignored.
            ``None`` is treated as empty.
        goal: Optional free-text description of what the merge is for.
            Surrounding whitespace is dropped; ``None`` means no goal.
        api_key: Groq API key forwarded to ``_query_groq``.

    Returns:
        The advisor's markdown answer, or a short hint if fewer than two
        models were supplied.
    """
    # UI widgets can hand over None for an untouched field; treat it as empty.
    models = [m.strip() for m in (models_text or "").split("\n") if m.strip()]
    if len(models) < 2:
        return "Add at least 2 models (one per line) to get a recommendation."
    models_str = "\n".join(f"- {m}" for m in models)
    goal = (goal or "").strip()
    goal_str = f"\n\nThe user's goal: {goal}" if goal else ""
    prompt = f"""I want to merge these models:
{models_str}
{goal_str}
Give me a specific recommendation:
1. **Best merge method** and why
2. **Exact weights** for each model
3. **Density values** (if applicable)
4. **Which model as base** and why
5. **Which tokenizer** to keep
6. **Warnings or tips** for these specific models
7. **The complete YAML config** ready for mergekit"""
    return _query_groq(prompt, system=ADVISOR_SYSTEM, api_key=api_key)
def model_describer(
    models_text: str,
    method: str = "",
    weights_text: str = "",
    api_key: Optional[str] = None,
) -> str:
    """Ask the AI advisor to predict the strengths and trade-offs of a merge.

    Args:
        models_text: Model IDs, one per line; blank lines are ignored.
            ``None`` is treated as empty.
        method: Optional merge-method name. Blank or whitespace-only values
            are treated as "not specified".
        weights_text: Optional free-text weights description. It is
            included only when non-blank.
        api_key: Groq API key forwarded to ``_query_groq``.

    Returns:
        The advisor's markdown answer, or a short hint if no models were given.
    """
    # UI widgets can hand over None for an untouched field; treat it as empty.
    models = [m.strip() for m in (models_text or "").split("\n") if m.strip()]
    if not models:
        return "Add models first."
    models_str = "\n".join(f"- {m}" for m in models)
    # Strip before testing so a whitespace-only method does not produce a
    # bogus " using **  **" clause. This matches how weights are handled.
    method = (method or "").strip()
    weights = (weights_text or "").strip()
    method_str = f" using **{method}**" if method else ""
    weights_str = f"\nWeights: {weights}" if weights else ""
    prompt = f"""I'm merging these models{method_str}:
{models_str}{weights_str}
Predict:
1. **What it will excel at** — specific tasks and benchmarks
2. **What it might lose** compared to individual source models
3. **Ideal use cases** for this merge
4. **Quality estimate** vs each source model
5. **A creative name suggestion** for this merged model"""
    return _query_groq(prompt, system=ADVISOR_SYSTEM, api_key=api_key)
def config_explainer(
    yaml_config: str,
    api_key: Optional[str] = None,
) -> str:
    """Ask the AI advisor to explain a mergekit YAML config in plain English.

    Args:
        yaml_config: The YAML text to explain. ``None``, blank text, or text
            starting with "# Add" (presumably the generator's placeholder;
            confirm against the UI) are rejected.
        api_key: Groq API key forwarded to ``_query_groq``.

    Returns:
        The advisor's markdown explanation, or a hint to supply a config first.
    """
    config_text = yaml_config or ""
    # Test the placeholder on stripped text so leading blank lines or spaces
    # cannot let it slip through to the API. The original text is still sent
    # unchanged below.
    stripped = config_text.strip()
    if not stripped or stripped.startswith("# Add"):
        return "Generate or paste a YAML config first."
    prompt = f"""Explain this mergekit config in plain English for a beginner:
```yaml
{config_text}
```
Cover:
1. **What this does** in simple terms
2. **Why these settings** — explain each parameter
3. **What the output will be like**
4. **Potential issues** to watch for
5. **Resource requirements** (RAM, time, Colab tier)"""
    return _query_groq(prompt, system=ADVISOR_SYSTEM, api_key=api_key)