import re
import os
from typing import List, Dict, Any, Optional
from datetime import datetime

from vanna import Agent, AgentConfig
from vanna.core.registry import ToolRegistry, ToolSchema
from vanna.core.user import UserResolver, User, RequestContext
from vanna.core.system_prompt import SystemPromptBuilder
from vanna.tools import RunSqlTool
from vanna.tools.agent_memory import SaveQuestionToolArgsTool, SearchSavedCorrectToolUsesTool
from vanna.integrations.postgres import PostgresRunner
from vanna.integrations.local.agent_memory import DemoAgentMemory

from .vanna_huggingface_llm_service import VannaHuggingFaceLlmService


class CustomSQLSystemPromptBuilder(SystemPromptBuilder):
    """Complete system prompt builder for the Vanna SQL assistant (v2)."""

    VERSION = "2.2.0"

    def __init__(self, company_name: str = "CoJournalist", sql_runner: Optional[PostgresRunner] = None):
        self.company_name = company_name
        self.sql_runner = sql_runner

    async def build_system_prompt(
        self,
        user: User,
        tool_schemas: List[ToolSchema],
        context: Optional[Dict[str, Any]] = None
    ) -> str:
        today = datetime.now().strftime("%Y-%m-%d")
        username = getattr(user, "username", user.id)

        # ======================
        # BASE PROMPT
        # ======================
        prompt = f"[System Prompt v{self.VERSION}]\n\n"
        prompt += f"You are an expert SQL assistant for the company {self.company_name}.\n"
        prompt += f"Date: {today}\nUser: {username}\nGroups: {', '.join(user.group_memberships)}\n\n"
        prompt += (
            "Your role: generate correct and efficient SQL queries from natural language.\n"
            "You always respond in **raw CSV format**, with no explanation or extra text.\n"
            "You have full access to all tables and relationships described in the schema.\n"
        )

        # ======================
        # SQL DIRECTIVES
        # ======================
        prompt += (
            "\n## SQL Directives\n"
            "- Always use table aliases in JOINs\n"
            "- Never use SELECT *\n"
            "- Prefer window functions over subqueries when possible\n"
            "- Always include a LIMIT for exploratory queries\n"
            "- Format dates and numbers for readability\n"
        )

        # ======================
        # DATABASE SCHEMA
        # ======================
        if context and "database_schema" in context:
            prompt += "\n## Database Schema\n"
            prompt += context["database_schema"]
        else:
            prompt += (
                "\n## Database Schema\n"
                "Tables:\n"
                "- posts (id, title, source_url, author, published_date, image_url, type, provider_id, created_at, updated_at, dead)\n"
                "- providers (id, name)\n"
                "- provider_attributes (id, provider_id, type, name)\n"
                "- post_provider_attributes (post_id, attribute_id)\n"
                "- tags (id, name)\n"
                "- post_tags (post_id, tag_id, weight)\n"
                "\nRelationships:\n"
                " - posts.provider_id → providers.id\n"
                " - post_provider_attributes.post_id → posts.id\n"
                " - post_provider_attributes.attribute_id → provider_attributes.id\n"
                " - provider_attributes.provider_id → providers.id\n"
                " - post_tags.post_id → posts.id\n"
                " - post_tags.tag_id → tags.id\n"
            )

        # ======================
        # SEMANTIC INFORMATION
        # ======================
        prompt += (
            "\n## Semantic Information\n"
            "- `posts.title`: title of the content (often descriptive, may contain keywords).\n"
            "- `posts.source_url`: external link to the article or resource.\n"
            "- `posts.author`: author, journalist, or organization name (e.g., 'The New York Times').\n"
            "- `posts.published_date`: publication date.\n"
            "- `posts.type`: content type ENUM ('spotlight', 'resource', 'insight').\n"
            "- `providers.name`: name of the publishing organization (e.g., 'Nuanced', 'SND').\n"
            "- `tags.name`: thematic keyword or topic (e.g., '3D', 'AI', 'Design').\n"
            "- `post_tags.weight`: relevance score between a post and a tag.\n"
            "- `posts.dead`: boolean flag indicating whether the post is dead/removed (true = dead, false = active).\n"
        )

        # ======================
        # BUSINESS LOGIC
        # ======================
        prompt += (
            "\n## Business Logic\n"
            "- **ALWAYS filter out dead posts**: Include `WHERE p.dead = false` (or `AND p.dead = false`) in every query. Never return posts where dead = true.\n"
            "- A query mentioning an organization (e.g., 'New York Times') should search both `posts.author` and `providers.name`.\n"
            "- Return all post types (spotlight, resource, insight) unless the user specifies otherwise.\n"
            "- Tags link posts to specific themes or disciplines.\n"
            "- A single post may have multiple tags, awards, or categories.\n"
            "- If the user mentions a year (e.g., 'in 2021'), filter with `EXTRACT(YEAR FROM published_date) = 2021`.\n"
            "- If the user says 'recently', filter posts from the last 90 days.\n"
            "- Always limit exploratory results to 9 rows.\n"
            "\n"
            "## CRITICAL: Search Strategy\n"
            "**IMPORTANT**: Only 3 posts currently have tags. Most posts (7,245+) are NOT tagged yet.\n"
            "\n"
            "**Hybrid Search Approach (RECOMMENDED)**:\n"
            "- ALWAYS use a hybrid approach combining tag search AND keyword search with OR logic.\n"
            "- Use LEFT JOINs for tags (not INNER JOIN) so untagged posts are included.\n"
            "\n"
            "**Keyword Matching - Use PostgreSQL Regex for Exact Word Boundaries**:\n"
            "- Use ~* operator for case-insensitive regex matching\n"
            "- Use \\m and \\M for word boundaries (start and end of word)\n"
            "- Pattern: column ~* '\\\\mkeyword\\\\M'\n"
            "- Example: p.title ~* '\\\\mf1\\\\M' matches 'F1' but NOT 'profile' or 'if'\n"
            "- This ensures exact word matching, not substring matching\n"
            "\n"
            "**When to use tag-only search**: Only if the user explicitly mentions 'tagged with' or 'tag:'.\n"
            "**When to use keyword-only search**: For author/organization names, or when tags are not relevant.\n"
            "\n"
            "This ensures maximum result coverage while the database is being enriched with tags.\n"
        )

        # ======================
        # AVAILABLE TOOLS
        # ======================
        if tool_schemas:
            prompt += "\n## Available Tools\n"
            for tool in tool_schemas:
                prompt += f"- {tool.name}: {getattr(tool, 'description', 'No description')}\n"
                prompt += f"  Parameters: {getattr(tool, 'parameters', 'N/A')}\n"

        # ======================
        # MEMORY SYSTEM
        # ======================
        tool_names = [t.name for t in tool_schemas]
        has_search = "search_saved_correct_tool_uses" in tool_names
        has_save = "save_question_tool_args" in tool_names
        if has_search or has_save:
            prompt += "\n## Memory System\n"
            if has_search:
                prompt += "- Use `search_saved_correct_tool_uses` to detect past patterns.\n"
            if has_save:
                prompt += "- Use `save_question_tool_args` to store successful pairs.\n"

        # ======================
        # EXAMPLES
        # ======================
        prompt += (
            "\n## Example Interactions\n"
            "User: 'F1' or 'Show me F1 content'\n"
            "Assistant: [call run_sql with \"SELECT DISTINCT p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type "
            "FROM posts p "
            "LEFT JOIN post_tags pt ON p.id = pt.post_id "
            "LEFT JOIN tags t ON pt.tag_id = t.id "
            "LEFT JOIN providers pr ON p.provider_id = pr.id "
            "WHERE p.dead = false AND (t.name ~* '\\\\mf1\\\\M' OR t.name ~* '\\\\mformula\\\\M' "
            "OR p.title ~* '\\\\mf1\\\\M' OR p.title ~* '\\\\mformula\\\\M' "
            "OR p.author ~* '\\\\mf1\\\\M') "
            "ORDER BY p.published_date DESC NULLS LAST LIMIT 9;\"]\n"
            "\nUser: 'Show me posts from The New York Times'\n"
            "Assistant: [call run_sql with \"SELECT DISTINCT p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type "
            "FROM posts p "
            "LEFT JOIN providers pr ON p.provider_id = pr.id "
            "WHERE p.dead = false AND (p.author ~* '\\\\mnew\\\\M.*\\\\myork\\\\M.*\\\\mtimes\\\\M' OR pr.name ~* '\\\\mnew\\\\M.*\\\\myork\\\\M.*\\\\mtimes\\\\M') "
            "ORDER BY p.published_date DESC NULLS LAST LIMIT 9;\"]\n"
            "\nUser: 'interactive visualizations'\n"
            "Assistant: [call run_sql with \"SELECT DISTINCT p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type "
            "FROM posts p "
            "LEFT JOIN post_tags pt ON p.id = pt.post_id "
            "LEFT JOIN tags t ON pt.tag_id = t.id "
            "WHERE p.dead = false AND (t.name ~* '\\\\minteractive\\\\M' OR p.title ~* '\\\\minteractive\\\\M' "
            "OR p.title ~* '\\\\mvisualization\\\\M' OR t.name ~* '\\\\mdataviz\\\\M') "
            "ORDER BY p.published_date DESC NULLS LAST LIMIT 9;\"]\n"
        )

        # ======================
        # FINAL INSTRUCTIONS
        # ======================
        prompt += (
            "\nIMPORTANT:\n"
            "- Always return **only the raw CSV result** — no explanations, no JSON, no commentary.\n"
            "- Stop tool execution once the query result is obtained.\n"
        )

        return prompt


class SimpleUserResolver(UserResolver):
    """Resolves the current user from the `vanna_email` cookie, defaulting to a guest account."""

    async def resolve_user(self, request_context: RequestContext) -> User:
        user_email = request_context.get_cookie('vanna_email') or 'guest@example.com'
        group = 'admin' if user_email == 'admin@example.com' else 'user'
        return User(id=user_email, email=user_email, group_memberships=[group])


class VannaComponent:
    """Wires the Hugging Face LLM service, Postgres runner, tools, and agent memory into a Vanna agent."""

    def __init__(
        self,
        hf_model: str,
        hf_token: str,
        hf_provider: str,
        connection_string: str,
    ):
        llm = VannaHuggingFaceLlmService(model=hf_model, token=hf_token, provider=hf_provider)
        self.sql_runner = PostgresRunner(connection_string=connection_string)
        db_tool = RunSqlTool(sql_runner=self.sql_runner)
        agent_memory = DemoAgentMemory(max_items=1000)
        save_memory_tool = SaveQuestionToolArgsTool()
        search_memory_tool = SearchSavedCorrectToolUsesTool()
        self.user_resolver = SimpleUserResolver()

        tools = ToolRegistry()
        tools.register_local_tool(db_tool, access_groups=['admin', 'user'])
        tools.register_local_tool(save_memory_tool, access_groups=['admin'])
        tools.register_local_tool(search_memory_tool, access_groups=['admin', 'user'])

        self.agent = Agent(
            llm_service=llm,
            tool_registry=tools,
            user_resolver=self.user_resolver,
            agent_memory=agent_memory,
            system_prompt_builder=CustomSQLSystemPromptBuilder("CoJournalist", self.sql_runner),
            config=AgentConfig(stream_responses=False, max_tool_iterations=3)
        )

    async def ask(self, prompt_for_llm: str):
        ctx = RequestContext()
        print(f"\n{'='*80}")
        print(f"🙋 User Query: {prompt_for_llm}")
        print(f"{'='*80}\n")

        final_text = ""
        seen_texts = set()
        query_executed = False
        result_row_count = 0

        async for component in self.agent.send_message(request_context=ctx, message=prompt_for_llm):
            simple = getattr(component, "simple_component", None)
            text = getattr(simple, "text", "") if simple else ""
            if text and text not in seen_texts:
                print(f"💬 LLM Response: {text[:300]}...")
                final_text += text + "\n"
                seen_texts.add(text)

            sql_query = getattr(component, "sql", None)
            if sql_query:
                query_executed = True
                print(f"\n🧾 SQL Query Generated:")
                print(f"{'-'*80}")
                print(f"{sql_query}")
                print(f"{'-'*80}\n")

            metadata = getattr(component, "metadata", None)
            if metadata:
                print(f"📋 Query Metadata: {metadata}")
                result_row_count = metadata.get("row_count", 0)
                if result_row_count == 0:
                    print(f"⚠️ Query returned 0 rows - no data matched the criteria")
                else:
                    print(f"✅ Query returned {result_row_count} rows")
            component_type = getattr(component, "type", None)
            if component_type:
                print(f"🔖 Component Type: {component_type}")

        # The response text may reference a query_results_*.csv file; if it does,
        # try to read that file from the user-specific folder and return its raw contents.
        match = re.search(r"query_results_[\w-]+\.csv", final_text)
        if match:
            filename = match.group(0)

            # Calculate the user-specific folder based on the default user ID
            import hashlib
            user_hash = hashlib.sha256("guest@example.com".encode()).hexdigest()[:16]
            folder = user_hash
            full_path = os.path.join(folder, filename)
            print(f"\n📁 Looking for CSV file: {full_path}")

            # Create folder if it doesn't exist
            if not os.path.exists(folder):
                print(f"📂 Creating user directory: {folder}")
                os.makedirs(folder, exist_ok=True)

            if os.path.exists(full_path):
                print(f"✅ Found CSV file, reading contents...")
                with open(full_path, "r", encoding="utf-8") as f:
                    csv_data = f.read().strip()
                print(f"📊 CSV Data Preview: {csv_data[:200]}...")
                print(f"{'='*80}\n")
                return csv_data
            else:
                print(f"❌ CSV file not found at: {full_path}")
                # List files in the directory to help debug
                if os.path.exists(folder):
                    files = os.listdir(folder)
                    print(f"📂 Files in {folder}: {files}")

        print(f"\n{'='*80}")
        if not query_executed:
            print(f"⚠️ No SQL query was executed by the LLM")
        print(f"📤 Returning final response to user")
        print(f"{'='*80}\n")
        return final_text
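

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): shows how VannaComponent might be wired up
# and queried. The model id, provider name, and environment variable names
# (HF_TOKEN, DATABASE_URL) are placeholder assumptions, not values taken from
# this project; substitute your own configuration. Because this module uses a
# relative import, run it as a module (python -m <package>.<module>) rather
# than as a script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        vanna = VannaComponent(
            hf_model="meta-llama/Llama-3.1-8B-Instruct",   # placeholder model id
            hf_token=os.environ["HF_TOKEN"],               # assumed env var
            hf_provider="hf-inference",                    # placeholder provider
            connection_string=os.environ["DATABASE_URL"],  # assumed env var
        )
        # ask() returns raw CSV when a results file is found, otherwise the LLM text.
        result = await vanna.ask("Show me recent posts about AI")
        print(result)

    asyncio.run(_demo())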