# huggingface_ai_final / smolagents_agent.py
# alfulanny's picture
# Update smolagents_agent.py
# 8564202 verified
# raw
# history blame
# 9.82 kB
import os
import re
from typing import Optional
from dotenv import load_dotenv
from smolagents import InferenceClientModel
# Load environment variables
load_dotenv()
# -------------------------
# Direct Tool-Based Agent (No Code Generation)
# -------------------------
class OptimizedSmolagentsGAIAgent:
    """
    A direct agent that uses tools without code generation.

    Each question is classified by keyword/pattern heuristics and routed to
    exactly one tool method (calculator, web search, Wikipedia, webpage
    visit, or a canned image response). Every tool returns a plain string
    and reports its own failures as a string instead of raising.
    """

    # Full http(s) URL. Used both to route a question to the webpage tool and
    # to extract the URL that tool visits.
    _URL_RE = re.compile(r'https?://[^\s\)]+')
    # An operand, an arithmetic operator, then another operand. A lone '-' or
    # '/' (hyphenated names, "and/or") must not count as arithmetic.
    _ARITHMETIC_RE = re.compile(r'[\d)]\s*[-+*/]\s*[\d(]')
    # Maximal runs of characters that may legally appear in an expression.
    _MATH_CANDIDATE_RE = re.compile(r'[\d\+\-\*\/\.\(\)\s]+')

    def __init__(self):
        # Initialize model (None when no token is set or every candidate fails)
        self.model = self._initialize_model()
        # Available tools, keyed by tool name
        self.tools = {
            'calculator': self._safe_calculate,
            'web_search': self._safe_web_search,
            'wikipedia': self._safe_wikipedia_search,
            'visit_webpage': self._safe_visit_webpage,
            'image_analysis': self._safe_image_analysis
        }

    def _initialize_model(self):
        """Initialize model with multiple fallbacks.

        Reads HF_TOKEN from the environment and tries each candidate model id
        in order.

        Returns:
            The first InferenceClientModel that constructs successfully, or
            None when HF_TOKEN is unset or every candidate raises.
        """
        hf_token = os.getenv("HF_TOKEN")
        if not hf_token:
            print("HF_TOKEN not found. Using fallback mode.")
            return None
        # Try multiple models for reliability
        model_options = [
            "allenai/Olmo-3-7B-Instruct",
            "allenai/Olmo-3-7B-Think"
        ]
        for model_id in model_options:
            try:
                # Keep the id and the instance in separate names so the log
                # lines below always show the model id.
                model = InferenceClientModel(
                    model_id=model_id,
                    token=hf_token,
                    timeout=30
                )
                print(f"Using model: {model_id}")
                return model
            except Exception as e:
                print(f"Failed to initialize {model_id}: {e}")
                continue
        return None

    def _classify_question(self, question: str) -> str:
        """Classify question type for appropriate tool selection.

        Args:
            question: The raw question text.

        Returns:
            One of 'webpage', 'math', 'search', 'wikipedia' or 'image'.
            Questions that match nothing fall back to 'search'.
        """
        q_lower = question.lower()
        # A full URL wins over everything else: every "http://" contains '/',
        # which would otherwise look like a division operator.
        if self._URL_RE.search(question):
            return 'webpage'
        # Mathematical questions: an explicit keyword or a real expression
        if any(word in q_lower for word in ('calculate', 'compute', 'solve')):
            return 'math'
        if self._ARITHMETIC_RE.search(question):
            return 'math'
        # Web search questions
        if any(word in q_lower for word in ('search', 'find', 'recent', 'current', 'today')):
            return 'search'
        # Factual/Wikipedia questions
        if any(word in q_lower for word in ('who is', 'what is', 'when', 'where', 'history', 'biography')):
            return 'wikipedia'
        # Looser webpage hints (scheme-less "www." or a bare "http")
        if 'http' in question or 'www.' in question:
            return 'webpage'
        # Image questions
        if any(word in q_lower for word in ('image', 'picture', 'photo', 'visual', 'chess')):
            return 'image'
        # Default to search for general questions
        return 'search'

    def _safe_calculate(self, question: str) -> str:
        """Evaluate the first arithmetic expression found in the question.

        Args:
            question: Text that may contain an expression built from digits,
                '+', '-', '*', '/', '.', and parentheses.

        Returns:
            "The calculation result is: <value>" on success, a "no expression"
            message when no digits are found, or "Unable to calculate: <error>"
            when evaluation fails (e.g. division by zero, malformed input).
        """
        try:
            # The candidate pattern also matches whitespace, so the first
            # match in a sentence is usually a lone space. Look at every run
            # and keep only those that contain a digit.
            candidates = [
                m.group().strip()
                for m in self._MATH_CANDIDATE_RE.finditer(question)
            ]
            numeric = [c for c in candidates if any(ch.isdigit() for ch in c)]
            # Prefer a run that contains an operator over a bare number
            # (e.g. skip the year in "In 2020 compute 5 * 3").
            expr = next(
                (c for c in numeric if any(op in c for op in '+-*/')),
                numeric[0] if numeric else None,
            )
            if expr is not None:
                # SECURITY NOTE: eval() on user-supplied text. The character
                # whitelist above admits only digits, operators, dots,
                # parentheses and whitespace, so no names or calls can be
                # formed, and builtins are removed as a second guard.
                # '**' is still expressible, so a huge exponent can burn CPU;
                # consider an ast-based evaluator if input is untrusted.
                result = eval(expr, {"__builtins__": {}}, {})
                return f"The calculation result is: {result}"
            return "No clear mathematical expression found in the question."
        except Exception as e:
            return f"Unable to calculate: {str(e)}"

    def _safe_web_search(self, question: str) -> str:
        """Run a DuckDuckGo search for the question and return a short excerpt.

        The query is the question with every non-alphanumeric, non-whitespace
        character removed, capped at 100 characters. String results are
        truncated to 300 characters plus "...".

        Returns:
            The formatted result string, or "Web search error: <error>" if
            anything raises (including the tool import).
        """
        try:
            from smolagents import DuckDuckGoSearchTool
            search_tool = DuckDuckGoSearchTool()
            # Clean the query for search
            query = re.sub(r'[^a-zA-Z0-9\s]', '', question)
            if len(query) > 100:
                query = query[:100]
            result = search_tool.forward(query)
            if isinstance(result, str):
                if len(result) > 300:
                    result = result[:300] + "..."
                return f"Search results for '{query}': {result}"
            return "Search completed successfully."
        except Exception as e:
            return f"Web search error: {str(e)}"

    def _safe_wikipedia_search(self, question: str) -> str:
        """Look up a two-sentence Wikipedia summary for the question's topic.

        For "who is" / "what is" questions the lowercased question with that
        phrase removed is used as the query; otherwise the stripped question
        is used as-is. Summaries longer than 200 characters are truncated.

        Returns:
            The formatted summary, a message when the query is empty, or
            "Wikipedia search error: <error>" if anything raises.
        """
        try:
            import wikipedia
            # Extract search terms
            if 'who is' in question.lower():
                query = question.lower().replace('who is', '').strip()
            elif 'what is' in question.lower():
                query = question.lower().replace('what is', '').strip()
            else:
                query = question.strip()
            if query:
                summary = wikipedia.summary(query, sentences=2)
                if len(summary) > 200:
                    summary = summary[:200] + "..."
                return f"Information about '{query}': {summary}"
            return "Unable to extract search terms from question."
        except Exception as e:
            return f"Wikipedia search error: {str(e)}"

    def _safe_visit_webpage(self, question: str) -> str:
        """Visit the first http(s) URL in the question and return an excerpt.

        String content longer than 200 characters is truncated.

        Returns:
            The formatted page excerpt, "No URL found in the question." when
            no http(s) URL is present, or "Webpage visit error: <error>" if
            anything raises.
        """
        try:
            from smolagents import VisitWebpageTool
            visit_tool = VisitWebpageTool()
            # Extract URL from question
            urls = self._URL_RE.findall(question)
            if urls:
                url = urls[0]
                result = visit_tool.forward(url)
                if isinstance(result, str):
                    if len(result) > 200:
                        result = result[:200] + "..."
                    return f"Content from {url}: {result}"
                return f"Successfully visited {url}"
            return "No URL found in the question."
        except Exception as e:
            return f"Webpage visit error: {str(e)}"

    def _safe_image_analysis(self, question: str) -> str:
        """Return a canned response for image-related questions.

        No image is read here; the reply is chosen only from keywords in the
        question text (chess, then image/picture/photo, then a generic one).
        """
        try:
            # For chess questions
            if 'chess' in question.lower():
                return "Chess position analysis: This appears to be a chess-related question. Black's turn means black pieces need to make the next move. Without the actual board image, I cannot provide the specific move, but typical strategic considerations include developing pieces, controlling center, or castling."
            # For general image questions
            elif any(word in question.lower() for word in ['image', 'picture', 'photo']):
                return "Image analysis: The question references image content that I cannot directly access. For visual analysis tasks, please describe what you can see in the image or provide specific details about the visual elements."
            else:
                return "Image processing: Unable to analyze image content directly. Please provide more details about what visual information you need."
        except Exception as e:
            return f"Image analysis error: {str(e)}"

    def _generate_direct_answer(self, question: str, question_type: str) -> str:
        """Dispatch the question to the tool matching question_type.

        Args:
            question: The raw question text.
            question_type: A label produced by _classify_question.

        Returns:
            The chosen tool's string result; unknown labels use web search.
        """
        handlers = {
            'math': self._safe_calculate,
            'search': self._safe_web_search,
            'wikipedia': self._safe_wikipedia_search,
            'webpage': self._safe_visit_webpage,
            'image': self._safe_image_analysis,
        }
        # Default fallback is web search
        handler = handlers.get(question_type, self._safe_web_search)
        return handler(question)

    def process_question(self, question: str) -> str:
        """Process question using direct tool approach (no code generation).

        Args:
            question: The raw question text.

        Returns:
            The selected tool's answer, or a human-readable error message.
            This method does not raise.
        """
        # Handle no model case.
        # NOTE(review): none of the tool methods in this class read
        # self.model, so this guard blocks tools that could run without a
        # token. Kept because callers may rely on the message; confirm intent.
        if not self.model:
            return "No language model available. Please set HF_TOKEN in environment variables."
        try:
            # Classify question type
            question_type = self._classify_question(question)
            # Generate direct answer using appropriate tool
            answer = self._generate_direct_answer(question, question_type)
            return answer
        except Exception as e:
            error_msg = str(e)
            # Specific error handling
            if "timeout" in error_msg.lower():
                return "Request timed out. The question may be too complex. Please try a simpler question."
            elif "500" in error_msg:
                return "Server error occurred. This may be a temporary issue. Please try again later."
            else:
                return f"Unable to process question: {error_msg[:200]}"
# -------------------------
# Test the direct tool agent
# -------------------------
if __name__ == "__main__":
    # Manual smoke test: push one sample question per tool category through
    # the agent and print a truncated answer for each.
    demo_agent = OptimizedSmolagentsGAIAgent()
    sample_questions = [
        "What is the capital of France?",
        "Calculate 15 + 27 * 3",
        "Who is Mercedes Sosa?",
        "Review the chess position in the image",
        "What does this webpage say: https://example.com",
    ]
    divider = "-" * 50

    print("=== DIRECT TOOL AGENT TEST ===\n")
    for sample in sample_questions:
        print(f"Q: {sample}")
        reply = demo_agent.process_question(sample)
        # Only the first 200 characters are shown; "..." is always appended.
        preview = reply[:200]
        print(f"A: {preview}...")
        print(divider)