Spaces:
Running
Running
File size: 7,337 Bytes
5b89d45 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 | import os
import glob
from typing import List, Optional
from langchain_core.tools import tool
from pydantic import BaseModel, Field
# Define Input Schemas
class ListFilesInput(BaseModel):
    # Argument schema passed as ``args_schema`` to the ``list_files`` tool.
    # Documented with comments (not a docstring) so the generated schema
    # text is left exactly as it was.
    path: str = Field(
        description="Directory path to list files from. Use '.' for root.",
    )
class ReadFileInput(BaseModel):
    # Argument schema passed as ``args_schema`` to the ``read_file`` tool.
    # Documented with comments (not a docstring) so the generated schema
    # text is left exactly as it was.
    file_path: str = Field(
        description="Path to the file to read.",
    )
# Define Tools Factory
def get_filesystem_tools(root_dir: str = ".", max_file_bytes: int = 12000):
    """Return the ``list_files`` and ``read_file`` tools sandboxed to ``root_dir``.

    Both tools report problems by returning a string that starts with
    ``"Error"`` instead of raising, so the calling agent always receives text.

    Args:
        root_dir: Directory the tools are allowed to access. Relative segments
            and symlinks are resolved once, when the tools are created.
        max_file_bytes: Largest file size (in bytes) that ``read_file`` will
            return. Larger files produce an error message instead of content.
            The default of 12000 keeps a single read to roughly 3k tokens.

    Returns:
        A two-element list: ``[list_files, read_file]``.
    """
    # realpath (not abspath) so that a symlink inside the root that points
    # somewhere else cannot be used to step outside the sandbox.
    root_dir = os.path.realpath(root_dir)

    def _resolve_inside_root(user_path: str) -> Optional[str]:
        # Resolve ``user_path`` against the root; return None if it escapes.
        candidate = os.path.realpath(os.path.join(root_dir, user_path))
        try:
            # Component-wise comparison: a plain ``startswith`` test would
            # accept a sibling such as "<root>-backup" that merely shares
            # the root's textual prefix.
            inside = os.path.commonpath([root_dir, candidate]) == root_dir
        except ValueError:
            # commonpath raises when the paths cannot be compared
            # (e.g. different drives on Windows) - treat as outside.
            inside = False
        return candidate if inside else None

    # NOTE: the docstrings of the two tool functions below are kept verbatim;
    # the ``tool`` decorator uses them as the tool descriptions.
    @tool("list_files", args_schema=ListFilesInput)
    def list_files(path: str = ".") -> str:
        """Lists files in the specified directory."""
        try:
            target_path = _resolve_inside_root(path)
            if target_path is None:
                return f"Error: Access denied. Path must be within the codebase: {root_dir}"
            if not os.path.exists(target_path):
                return f"Error: Path does not exist: {path}"
            if not os.path.isdir(target_path):
                return f"Error: Not a directory: {path}"

            entries = []
            for item in os.listdir(target_path):
                # Hide dotfiles, except .gitignore which is still listed.
                if item.startswith(".") and item != ".gitignore":
                    continue
                if os.path.isdir(os.path.join(target_path, item)):
                    # Trailing slash marks directories in the listing.
                    entries.append(f"{item}/")
                else:
                    entries.append(item)

            # Sort for stability
            entries.sort()
            return "\n".join(entries)
        except Exception as e:
            # Tool boundary: always answer with text rather than raising.
            return f"Error listing files: {e}"

    @tool("read_file", args_schema=ReadFileInput)
    def read_file(file_path: str) -> str:
        """Reads the content of a file."""
        try:
            full_path = _resolve_inside_root(file_path)
            if full_path is None:
                return "Error: Access denied. File must be within the codebase."
            if not os.path.exists(full_path):
                return f"Error: File not found: {file_path}"
            if not os.path.isfile(full_path):
                return f"Error: Not a file: {file_path}"

            # Check file size to avoid overloading context.
            # Groq TPM limit is ~12k tokens. 12000 chars is roughly 3k tokens.
            # We strictly prevent reading massive files to keep the agent alive.
            size = os.path.getsize(full_path)
            if size > max_file_bytes:
                return f"Error: File '{file_path}' is too large ({size} bytes). Read specific lines or functions instead."

            # Explicit UTF-8 so the result does not depend on the platform's
            # default encoding; undecodable bytes are dropped, as before.
            with open(full_path, "r", encoding="utf-8", errors="ignore") as f:
                return f.read()
        except Exception as e:
            # Tool boundary: always answer with text rather than raising.
            return f"Error reading file: {e}"

    return [list_files, read_file]
# ============================================================================
# Call Graph Tools
# ============================================================================
class FindCallersInput(BaseModel):
    # Argument schema passed as ``args_schema`` to the ``find_callers`` tool.
    function_name: str = Field(
        description="Name of the function to find callers for",
    )
class FindCalleesInput(BaseModel):
    # Argument schema passed as ``args_schema`` to the ``find_callees`` tool.
    function_name: str = Field(
        description="Name of the function to find callees for",
    )
class FindCallChainInput(BaseModel):
    # Argument schema passed as ``args_schema`` to the ``find_call_chain`` tool.
    # Field order is kept: start first, then target.
    start_function: str = Field(
        description="Name of the starting function",
    )
    end_function: str = Field(
        description="Name of the target function to trace to",
    )
def get_call_graph_tools(analyzer):
    """Build the ``find_callers``, ``find_callees`` and ``find_call_chain`` tools.

    ``analyzer`` is queried through ``get_callers``, ``get_callees`` and
    ``get_call_chain``. It may be ``None``, in which case every tool answers
    with a fixed "no analysis available" message. All tools return text and
    convert any exception from the analyzer into an error string.
    """
    no_analysis = "Error: No code analysis available. Index a codebase first."

    def _bullet(qualified):
        # Ids with exactly one "::" separator ("A::B") are rendered as
        # "B (in A)"; anything else is shown verbatim.
        pieces = qualified.split("::")
        if len(pieces) != 2:
            return f" - {qualified}\n"
        location, name = pieces
        return f" - {name} (in {location})\n"

    # NOTE: the docstrings of the tool functions below are reproduced as-is;
    # the ``tool`` decorator uses them as the tool descriptions.
    @tool("find_callers", args_schema=FindCallersInput)
    def find_callers(function_name: str) -> str:
        """Find all functions that call the specified function.
        Useful for understanding: "Who uses this function?" or "What depends on this?"
        """
        if analyzer is None:
            return no_analysis
        try:
            found = analyzer.get_callers(function_name)
            if found:
                heading = f"Functions that call '{function_name}':\n"
                return heading + "".join(_bullet(entry) for entry in found)
            return f"No callers found for '{function_name}'. It may be unused or called dynamically."
        except Exception as exc:
            return f"Error finding callers: {exc}"

    @tool("find_callees", args_schema=FindCalleesInput)
    def find_callees(function_name: str) -> str:
        """Find all functions that are called by the specified function.
        Useful for understanding: "What does this function do?" or "What are its dependencies?"
        """
        if analyzer is None:
            return no_analysis
        try:
            found = analyzer.get_callees(function_name)
            if found:
                heading = f"Functions called by '{function_name}':\n"
                return heading + "".join(_bullet(entry) for entry in found)
            return f"No callees found for '{function_name}'. It may not call any other tracked functions."
        except Exception as exc:
            return f"Error finding callees: {exc}"

    @tool("find_call_chain", args_schema=FindCallChainInput)
    def find_call_chain(start_function: str, end_function: str) -> str:
        """Find the call path from one function to another.
        Useful for: "How does execution flow from main() to save_to_db()?"
        """
        if analyzer is None:
            return no_analysis
        try:
            paths = analyzer.get_call_chain(start_function, end_function)
            if not paths:
                return f"No call path found from '{start_function}' to '{end_function}'."

            sections = [f"Call paths from '{start_function}' to '{end_function}':\n\n"]
            # Only the first five paths are reported.
            for number, path in enumerate(paths[:5], 1):
                sections.append(f"Path {number}:\n")
                for depth, node in enumerate(path):
                    pieces = node.split("::")
                    shown = node if len(pieces) != 2 else pieces[1]
                    # Each step is indented one level deeper; every step
                    # after the first is prefixed with an arrow.
                    prefix = " " * depth + ("-> " if depth else "")
                    sections.append(f"{prefix}{shown}\n")
                sections.append("\n")
            return "".join(sections)
        except Exception as exc:
            return f"Error finding call chain: {exc}"

    return [find_callers, find_callees, find_call_chain]
|