# Asish Karthikeya Gogineni
# Refactor: Code Structure Update & UI Redesign
# a3bdcf1
import os
import glob
from typing import List, Optional
from langchain_core.tools import tool
from pydantic import BaseModel, Field
# Define Input Schemas
# Argument schema for the `list_files` tool (passed as its `args_schema`).
# Documented with comments rather than a class docstring so that the schema
# generated from this model stays exactly as it was.
class ListFilesInput(BaseModel):
    # Directory to list; the tool joins it onto its configured root directory.
    path: str = Field(
        description="Directory path to list files from. Use '.' for root.",
    )
# Argument schema for the `read_file` tool (passed as its `args_schema`).
# Documented with comments rather than a class docstring so that the schema
# generated from this model stays exactly as it was.
class ReadFileInput(BaseModel):
    # File to read; the tool joins it onto its configured root directory.
    file_path: str = Field(
        description="Path to the file to read.",
    )
# Define Tools Factory
def get_filesystem_tools(root_dir: str = ".", max_file_bytes: int = 12000):
    """Build the ``list_files`` and ``read_file`` tools, confined to ``root_dir``.

    Args:
        root_dir: Directory the tools are allowed to access. It is made
            absolute once, here, so the sandbox does not move if the process
            later changes its working directory.
        max_file_bytes: Largest file size, in bytes, that ``read_file`` will
            return. Bigger files produce an error string instead of content.

    Returns:
        A two-element list ``[list_files, read_file]``.

    Both tools report every failure as an ``"Error ..."`` string rather than
    raising, so the calling agent always receives text.
    """
    root_dir = os.path.abspath(root_dir)

    def _resolve(relative_path: str) -> Optional[str]:
        """Return the absolute path for ``relative_path``, or None if it escapes ``root_dir``."""
        candidate = os.path.abspath(os.path.join(root_dir, relative_path))
        try:
            # Compare whole path components. A plain ``startswith`` would
            # treat "/srv/app_secrets" as being inside "/srv/app".
            inside = os.path.commonpath([root_dir, candidate]) == root_dir
        except ValueError:
            # commonpath raises when the paths cannot be compared at all
            # (for example different drives on Windows): treat as outside.
            inside = False
        # NOTE(review): the check is lexical; symlinks inside root_dir that
        # point elsewhere are still followed. Confirm whether that is wanted.
        return candidate if inside else None

    # NOTE: the docstrings of the decorated functions below are kept verbatim;
    # LangChain's @tool uses them as the tool description given to the model.

    @tool("list_files", args_schema=ListFilesInput)
    def list_files(path: str = ".") -> str:
        """Lists files in the specified directory."""
        try:
            target_path = _resolve(path)
            if target_path is None:
                return f"Error: Access denied. Path must be within the codebase: {root_dir}"
            if not os.path.exists(target_path):
                return f"Error: Path does not exist: {path}"
            if not os.path.isdir(target_path):
                return f"Error: Not a directory: {path}"

            # Hide dotfiles except .gitignore; mark directories with a
            # trailing slash so they can be told apart from files.
            entries = [
                f"{name}/" if os.path.isdir(os.path.join(target_path, name)) else name
                for name in os.listdir(target_path)
                if not name.startswith(".") or name == ".gitignore"
            ]
            # Sort for stability
            return "\n".join(sorted(entries))
        except Exception as e:
            # Tool boundary: surface the failure as text instead of raising.
            return f"Error listing files: {e}"

    @tool("read_file", args_schema=ReadFileInput)
    def read_file(file_path: str) -> str:
        """Reads the content of a file."""
        try:
            full_path = _resolve(file_path)
            if full_path is None:
                return "Error: Access denied. File must be within the codebase."
            if not os.path.exists(full_path):
                return f"Error: File not found: {file_path}"
            if os.path.isdir(full_path):
                return f"Error: Path is a directory, not a file: {file_path}"

            # Check file size to avoid overloading context.
            # Groq TPM limit is ~12k tokens. 12000 chars is roughly 3k tokens.
            # We strictly prevent reading massive files to keep the agent alive.
            size = os.path.getsize(full_path)
            if size > max_file_bytes:
                return f"Error: File '{file_path}' is too large ({size} bytes). Read specific lines or functions instead."

            # Explicit UTF-8 keeps decoding independent of the platform
            # locale; undecodable bytes are dropped rather than raising.
            with open(full_path, "r", encoding="utf-8", errors="ignore") as f:
                return f.read()
        except Exception as e:
            # Tool boundary: surface the failure as text instead of raising.
            return f"Error reading file: {e}"

    return [list_files, read_file]
# ============================================================================
# Call Graph Tools
# ============================================================================
# Argument schema for the `find_callers` tool (passed as its `args_schema`).
# Documented with comments rather than a class docstring so that the schema
# generated from this model stays exactly as it was.
class FindCallersInput(BaseModel):
    # Function whose callers should be looked up.
    function_name: str = Field(
        description="Name of the function to find callers for",
    )
# Argument schema for the `find_callees` tool (passed as its `args_schema`).
# Documented with comments rather than a class docstring so that the schema
# generated from this model stays exactly as it was.
class FindCalleesInput(BaseModel):
    # Function whose callees should be looked up.
    function_name: str = Field(
        description="Name of the function to find callees for",
    )
# Argument schema for the `find_call_chain` tool (passed as its `args_schema`).
# Documented with comments rather than a class docstring so that the schema
# generated from this model stays exactly as it was.
class FindCallChainInput(BaseModel):
    # Function the traced path starts from.
    start_function: str = Field(
        description="Name of the starting function",
    )
    # Function the traced path should end at.
    end_function: str = Field(
        description="Name of the target function to trace to",
    )
def _split_node_id(node: str):
    """Split a call-graph node id into ``(location, name)``.

    Ids containing exactly one ``"::"`` are split around it; any other id is
    returned whole as ``(None, node)``.
    """
    # presumably ids look like "<file>::<function>" -- verify against the analyzer
    parts = node.split("::")
    if len(parts) == 2:
        return parts[0], parts[1]
    return None, node


def _format_node_list(header: str, nodes) -> str:
    """Render ``header`` followed by one `` - name (in location)`` line per node."""
    lines = [header]
    for node in nodes:
        location, name = _split_node_id(node)
        if location is None:
            lines.append(f" - {node}")
        else:
            lines.append(f" - {name} (in {location})")
    return "\n".join(lines) + "\n"


def get_call_graph_tools(analyzer, max_paths: int = 5):
    """Returns tools for querying the call graph.

    Args:
        analyzer: Object providing ``get_callers(name)``, ``get_callees(name)``
            and ``get_call_chain(start, end)``. May be ``None``, in which case
            every tool answers with a "no code analysis available" error.
        max_paths: Maximum number of call paths ``find_call_chain`` prints.
            When more paths exist, a trailing note says how many were found.

    Returns:
        A three-element list ``[find_callers, find_callees, find_call_chain]``.

    All tools report failures as ``"Error ..."`` strings instead of raising.
    """
    # NOTE: the docstrings of the decorated functions below are kept verbatim;
    # LangChain's @tool uses them as the tool description given to the model.

    @tool("find_callers", args_schema=FindCallersInput)
    def find_callers(function_name: str) -> str:
        """Find all functions that call the specified function.
        Useful for understanding: "Who uses this function?" or "What depends on this?"
        """
        if analyzer is None:
            return "Error: No code analysis available. Index a codebase first."
        try:
            callers = analyzer.get_callers(function_name)
            if not callers:
                return f"No callers found for '{function_name}'. It may be unused or called dynamically."
            return _format_node_list(f"Functions that call '{function_name}':", callers)
        except Exception as e:
            # Tool boundary: surface the failure as text instead of raising.
            return f"Error finding callers: {e}"

    @tool("find_callees", args_schema=FindCalleesInput)
    def find_callees(function_name: str) -> str:
        """Find all functions that are called by the specified function.
        Useful for understanding: "What does this function do?" or "What are its dependencies?"
        """
        if analyzer is None:
            return "Error: No code analysis available. Index a codebase first."
        try:
            callees = analyzer.get_callees(function_name)
            if not callees:
                return f"No callees found for '{function_name}'. It may not call any other tracked functions."
            return _format_node_list(f"Functions called by '{function_name}':", callees)
        except Exception as e:
            # Tool boundary: surface the failure as text instead of raising.
            return f"Error finding callees: {e}"

    @tool("find_call_chain", args_schema=FindCallChainInput)
    def find_call_chain(start_function: str, end_function: str) -> str:
        """Find the call path from one function to another.
        Useful for: "How does execution flow from main() to save_to_db()?"
        """
        if analyzer is None:
            return "Error: No code analysis available. Index a codebase first."
        try:
            chains = analyzer.get_call_chain(start_function, end_function)
            if not chains:
                return f"No call path found from '{start_function}' to '{end_function}'."

            # Header, then a blank line, then one indented block per path.
            lines = [f"Call paths from '{start_function}' to '{end_function}':", ""]
            for number, chain in enumerate(chains[:max_paths], 1):
                lines.append(f"Path {number}:")
                for depth, node in enumerate(chain):
                    _, func_name = _split_node_id(node)
                    # Each hop is indented one more space and prefixed with an
                    # arrow; the starting function has neither.
                    indent = " " * depth
                    arrow = "-> " if depth > 0 else ""
                    lines.append(f"{indent}{arrow}{func_name}")
                lines.append("")

            total = len(chains)
            if total > max_paths:
                # Make the cap visible instead of silently dropping paths.
                lines.append(f"(Showing first {max_paths} of {total} paths.)")
            return "\n".join(lines) + "\n"
        except Exception as e:
            # Tool boundary: surface the failure as text instead of raising.
            return f"Error finding call chain: {e}"

    return [find_callers, find_callees, find_call_chain]