# File size: 7,337 Bytes
# 5b89d45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import os
import glob
from typing import List, Optional
from langchain_core.tools import tool
from pydantic import BaseModel, Field

# Define Input Schemas
class ListFilesInput(BaseModel):
    # Argument schema for the `list_files` tool (passed as its args_schema).
    # Documented with `#` comments rather than a docstring because pydantic
    # copies a class docstring into the generated schema.
    # `path` has no default here, so the schema treats it as required.
    path: str = Field(description="Directory path to list files from. Use '.' for root.")

class ReadFileInput(BaseModel):
    # Argument schema for the `read_file` tool (passed as its args_schema).
    # Documented with `#` comments rather than a docstring because pydantic
    # copies a class docstring into the generated schema.
    # `read_file` joins this value onto the configured root directory.
    file_path: str = Field(description="Path to the file to read.")

# Define Tools Factory
def get_filesystem_tools(root_dir: str = ".", max_file_bytes: int = 12000) -> list:
    """Build the `list_files` and `read_file` tools, confined to ``root_dir``.

    Args:
        root_dir: Directory the tools are allowed to access. It is converted
            to an absolute path once, when this factory is called.
        max_file_bytes: Largest file size, in bytes, that `read_file` will
            return. Larger files yield an error string instead of content.
            The default of 12000 keeps a single read small enough for the
            model's context budget.

    Returns:
        ``[list_files, read_file]`` as produced by the ``@tool`` decorator.
        Both tools report every failure as a string beginning with "Error"
        instead of raising.
    """

    # Ensure root_dir is absolute so containment checks compare like with like.
    root_dir = os.path.abspath(root_dir)

    def _is_inside_root(candidate: str) -> bool:
        """Return True when ``candidate`` is ``root_dir`` or lies beneath it."""
        try:
            # commonpath compares whole path components, so a sibling such as
            # "/repo-secrets" is not mistaken for being inside "/repo" the way
            # a plain string-prefix test would be.
            return os.path.commonpath([root_dir, candidate]) == root_dir
        except ValueError:
            # Raised e.g. on Windows when the two paths are on different drives.
            return False

    # NOTE(review): paths are normalized with abspath, not realpath, so a
    # symlink inside root_dir that points elsewhere is still followed.
    # Confirm whether that is acceptable for this deployment.

    # The docstrings of the decorated functions below are left untouched:
    # the @tool decorator uses them as the tool description.

    @tool("list_files", args_schema=ListFilesInput)
    def list_files(path: str = ".") -> str:
        """Lists files in the specified directory."""
        try:
            # Joining "." (or "") onto root_dir normalizes back to root_dir.
            target_path = os.path.abspath(os.path.join(root_dir, path))

            # Security check: ensure we are inside the codebase
            if not _is_inside_root(target_path):
                return f"Error: Access denied. Path must be within the codebase: {root_dir}"

            if not os.path.exists(target_path):
                return f"Error: Path does not exist: {path}"

            # Give a clear message instead of surfacing a raw OS error.
            if not os.path.isdir(target_path):
                return f"Error: Path is not a directory: {path}"

            entries = []
            for item in os.listdir(target_path):
                # Hidden entries are skipped, with .gitignore as the one exception.
                if item.startswith(".") and item != ".gitignore":
                    continue
                is_dir = os.path.isdir(os.path.join(target_path, item))
                # Directories are marked with a trailing slash.
                entries.append(f"{item}/" if is_dir else item)

            # Sort for stability
            entries.sort()
            return "\n".join(entries)
        except Exception as e:
            return f"Error listing files: {e}"

    @tool("read_file", args_schema=ReadFileInput)
    def read_file(file_path: str) -> str:
        """Reads the content of a file."""
        try:
            full_path = os.path.abspath(os.path.join(root_dir, file_path))

            # Security check
            if not _is_inside_root(full_path):
                return "Error: Access denied. File must be within the codebase."

            if not os.path.exists(full_path):
                return f"Error: File not found: {file_path}"

            # Directories and other non-regular files cannot be read as text.
            if not os.path.isfile(full_path):
                return f"Error: Not a file: {file_path}"

            # Refuse large files so one read cannot overload the model context.
            size = os.path.getsize(full_path)
            if size > max_file_bytes:
                return f"Error: File '{file_path}' is too large ({size} bytes). Read specific lines or functions instead."

            # Decode as UTF-8 regardless of platform default; undecodable
            # bytes are dropped rather than failing the whole read.
            with open(full_path, "r", encoding="utf-8", errors="ignore") as f:
                return f.read()
        except Exception as e:
            return f"Error reading file: {e}"

    return [list_files, read_file]


# ============================================================================
# Call Graph Tools
# ============================================================================

class FindCallersInput(BaseModel):
    # Argument schema for the `find_callers` tool (passed as its args_schema).
    # Documented with `#` comments rather than a docstring because pydantic
    # copies a class docstring into the generated schema.
    # The value is forwarded unchanged to analyzer.get_callers().
    function_name: str = Field(description="Name of the function to find callers for")

class FindCalleesInput(BaseModel):
    # Argument schema for the `find_callees` tool (passed as its args_schema).
    # Documented with `#` comments rather than a docstring because pydantic
    # copies a class docstring into the generated schema.
    # The value is forwarded unchanged to analyzer.get_callees().
    function_name: str = Field(description="Name of the function to find callees for")

class FindCallChainInput(BaseModel):
    # Argument schema for the `find_call_chain` tool (passed as its args_schema).
    # Documented with `#` comments rather than a docstring because pydantic
    # copies a class docstring into the generated schema.
    # Both values are forwarded unchanged to analyzer.get_call_chain(start, end).
    start_function: str = Field(description="Name of the starting function")
    end_function: str = Field(description="Name of the target function to trace to")


def get_call_graph_tools(analyzer):
    """Build the call-graph query tools backed by ``analyzer``.

    ``analyzer`` may be None, in which case every tool answers with an error
    string. Otherwise the tools call its ``get_callers``, ``get_callees`` and
    ``get_call_chain`` methods and format whatever they return; any exception
    raised along the way is converted into an error string.

    Returns:
        ``[find_callers, find_callees, find_call_chain]``.
    """

    no_analysis_msg = "Error: No code analysis available. Index a codebase first."

    def _bullet(qualified):
        # A node of the form "<prefix>::<name>" is shown as "<name> (in <prefix>)";
        # anything that does not split into exactly two parts is printed verbatim.
        pieces = qualified.split("::")
        if len(pieces) != 2:
            return f"  - {qualified}\n"
        prefix, name = pieces
        return f"  - {name} (in {prefix})\n"

    def _chain_line(depth, node):
        # Each hop is indented two spaces per level; every hop after the
        # first is introduced by an arrow.
        pieces = node.split("::")
        label = pieces[1] if len(pieces) == 2 else node
        arrow = "-> " if depth > 0 else ""
        return f"{'  ' * depth}{arrow}{label}\n"

    # The docstrings of the decorated functions below are left untouched:
    # the @tool decorator uses them as the tool description.

    @tool("find_callers", args_schema=FindCallersInput)
    def find_callers(function_name: str) -> str:
        """Find all functions that call the specified function.
        Useful for understanding: "Who uses this function?" or "What depends on this?"
        """
        if analyzer is None:
            return no_analysis_msg

        try:
            found = analyzer.get_callers(function_name)
            if found:
                heading = f"Functions that call '{function_name}':\n"
                return heading + "".join(_bullet(entry) for entry in found)
            return f"No callers found for '{function_name}'. It may be unused or called dynamically."
        except Exception as e:
            return f"Error finding callers: {e}"

    @tool("find_callees", args_schema=FindCalleesInput)
    def find_callees(function_name: str) -> str:
        """Find all functions that are called by the specified function.
        Useful for understanding: "What does this function do?" or "What are its dependencies?"
        """
        if analyzer is None:
            return no_analysis_msg

        try:
            found = analyzer.get_callees(function_name)
            if found:
                heading = f"Functions called by '{function_name}':\n"
                return heading + "".join(_bullet(entry) for entry in found)
            return f"No callees found for '{function_name}'. It may not call any other tracked functions."
        except Exception as e:
            return f"Error finding callees: {e}"

    @tool("find_call_chain", args_schema=FindCallChainInput)
    def find_call_chain(start_function: str, end_function: str) -> str:
        """Find the call path from one function to another.
        Useful for: "How does execution flow from main() to save_to_db()?"
        """
        if analyzer is None:
            return no_analysis_msg

        try:
            paths = analyzer.get_call_chain(start_function, end_function)
            if not paths:
                return f"No call path found from '{start_function}' to '{end_function}'."

            segments = [f"Call paths from '{start_function}' to '{end_function}':\n\n"]
            # Only the first five paths are reported.
            for number, path in enumerate(paths[:5], start=1):
                segments.append(f"Path {number}:\n")
                segments.extend(
                    _chain_line(depth, node) for depth, node in enumerate(path)
                )
                segments.append("\n")
            return "".join(segments)
        except Exception as e:
            return f"Error finding call chain: {e}"

    return [find_callers, find_callees, find_call_chain]