Asish Karthikeya Gogineni
Refactor: Code Structure Update & UI Redesign
a3bdcf1
"""
Enhanced Code Analysis with AST + Call Graph + Control Flow
This module provides comprehensive code analysis using:
1. AST (Abstract Syntax Tree) - Code structure
2. Call Graph - Function-to-function relationships
3. Import Graph - Module dependencies
4. Class Hierarchy - Inheritance relationships
Uses tree-sitter for multi-language support.
"""
import logging
import networkx as nx
import os
from typing import List, Dict, Optional, Set, Tuple
from dataclasses import dataclass, field
from tree_sitter import Language, Parser
import tree_sitter_python
import tree_sitter_javascript
# Configure the root logger at import time with INFO as the threshold.
# NOTE(review): this runs as a side effect of importing the module, so it can
# affect the host application's logging setup — confirm this is intended
# rather than leaving configuration to the entry point.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@dataclass
class FunctionInfo:
    """Record describing a single function or method definition."""

    # Bare identifier of the function (no class prefix).
    name: str
    # Path of the file the definition was found in.
    file_path: str
    # First and last line of the definition.
    start_line: int
    end_line: int
    # Whether the definition is treated as a method, and of which class.
    is_method: bool = False
    class_name: Optional[str] = None
    # Callee expressions recorded for this function.
    calls: List[str] = field(default_factory=list)
    # Parameter names of the definition.
    parameters: List[str] = field(default_factory=list)

    @property
    def full_name(self) -> str:
        """Return ``<class_name>.<name>`` when a class name is set, else ``name``."""
        return f"{self.class_name}.{self.name}" if self.class_name else self.name

    @property
    def node_id(self) -> str:
        """Return the identifier ``<file_path>::<full_name>``."""
        return "{}::{}".format(self.file_path, self.full_name)
@dataclass
class ClassInfo:
    """Record describing a single class definition."""

    # Identifier of the class.
    name: str
    # Path of the file the class was found in.
    file_path: str
    # First and last line of the definition.
    start_line: int
    end_line: int
    # Names of the parent classes.
    bases: List[str] = field(default_factory=list)
    # Names of the methods belonging to the class.
    methods: List[str] = field(default_factory=list)
@dataclass
class ImportInfo:
    """Record describing a single import statement."""

    # Name of the imported module.
    module: str
    # Specific names pulled from the module (empty when none were recorded).
    names: List[str] = field(default_factory=list)
    # True for the ``from X import Y`` form.
    is_from_import: bool = False
class EnhancedCodeAnalyzer:
    """
    Enhanced code analyzer that builds:
    - AST-based structure graph
    - Function call graph
    - Import dependency graph
    - Class hierarchy graph

    Files are fed in one at a time through ``add_file``. Calls are only
    recorded while parsing; the ``calls`` edges are added to the graph when
    ``resolve_call_graph`` (or ``save_graph``, which invokes it) runs.

    NOTE(review): symbol extraction matches Python node types
    (``class_definition``, ``function_definition``, ``call``, ...), so
    JavaScript files presumably contribute little beyond their file node —
    confirm against the JavaScript grammar.
    """

    def __init__(self):
        """Create an empty analyzer and initialise the available parsers."""
        # Main knowledge graph
        self.graph = nx.DiGraph()
        # Specialized indices for faster lookups
        self.functions: Dict[str, FunctionInfo] = {}  # node_id -> FunctionInfo
        self.classes: Dict[str, ClassInfo] = {}  # node_id -> ClassInfo
        self.imports: Dict[str, List[ImportInfo]] = {}  # file_path -> imports
        self.definitions: Dict[str, List[str]] = {}  # name -> [node_ids]
        # Track unresolved calls for later resolution
        self.unresolved_calls: List[Tuple[str, str, int]] = []  # (caller_id, callee_name, line)
        # Last non-ASCII source and its UTF-8 encoding, reused by _get_text
        self._encoded_source: Optional[Tuple[str, bytes]] = None
        # Parsers
        self.parsers = {}
        self._init_parsers()

    def _init_parsers(self):
        """Initialize tree-sitter parsers for supported languages.

        Each language is set up independently so that a failure in one
        grammar does not prevent the other from being registered.
        """
        grammars = (
            (tree_sitter_python, ("python", "py")),
            (tree_sitter_javascript, ("javascript", "js", "jsx")),
        )
        for grammar, aliases in grammars:
            try:
                parser = Parser(Language(grammar.language()))
            except Exception as e:
                logger.error(f"Error initializing parsers: {e}")
                continue
            for alias in aliases:
                self.parsers[alias] = parser

    def add_file(self, file_path: str, content: str):
        """Parse a file and add it to the knowledge graph.

        Args:
            file_path: Path of the file; its extension selects the parser.
            content: Source text of the file.

        Files whose extension has no registered parser are ignored. Parse or
        extraction errors are logged and swallowed.
        """
        # splitext only looks at the final path component, so a file that is
        # literally named "py" is not mistaken for a Python source file.
        ext = os.path.splitext(file_path)[1].lstrip(".").lower()
        parser = self.parsers.get(ext)
        if not parser:
            return
        try:
            tree = parser.parse(bytes(content, "utf8"))
            root_node = tree.root_node
            # Add file node
            self.graph.add_node(
                file_path,
                type="file",
                name=os.path.basename(file_path),
                language=ext,
            )
            # Extract all symbols
            self._extract_symbols(root_node, file_path, content)
        except Exception as e:
            logger.error(f"Failed to parse {file_path}: {e}")

    def _extract_symbols(self, node, file_path: str, content: str,
                         current_class: Optional[str] = None,
                         current_function: Optional[str] = None):
        """Recursively extract symbols from AST node.

        Args:
            node: Syntax-tree node to inspect.
            file_path: Path of the file being analysed.
            content: Source text of the file.
            current_class: Name of the enclosing class, if any.
            current_function: Node id of the enclosing function, if any.
        """
        node_type = node.type

        # ========== IMPORTS ==========
        if node_type == "import_statement":
            self._process_import(node, file_path, content)
        elif node_type == "import_from_statement":
            self._process_from_import(node, file_path, content)
        # ========== CLASSES ==========
        elif node_type == "class_definition":
            class_info = self._process_class(node, file_path, content)
            if class_info:
                # Recurse into class body with class context
                for child in node.children:
                    if child.type == "block":
                        self._extract_symbols(child, file_path, content,
                                              current_class=class_info.name)
                return  # Don't recurse again below
        # ========== FUNCTIONS/METHODS ==========
        elif node_type == "function_definition":
            func_info = self._process_function(node, file_path, content, current_class)
            if func_info:
                # Recurse into function body to find calls
                for child in node.children:
                    if child.type == "block":
                        self._extract_symbols(child, file_path, content,
                                              current_class=current_class,
                                              current_function=func_info.node_id)
                return  # Don't recurse again below
        # ========== FUNCTION CALLS ==========
        elif node_type == "call":
            # Calls outside any function are attributed to the file itself.
            self._process_call(node, file_path, content, current_function or file_path)

        # Recurse into children
        for child in node.children:
            self._extract_symbols(child, file_path, content,
                                  current_class, current_function)

    def _process_import(self, node, file_path: str, content: str):
        """Process ``import module1, module2`` (including ``import x as y``).

        Records an ``ImportInfo`` per module and adds an ``imports`` edge from
        the file to the module name.
        """
        for child in node.children:
            if child.type == "dotted_name":
                target = child
            elif child.type == "aliased_import":
                # `import a.b as c`: the module path sits under the alias node.
                target = child.child_by_field_name("name")
            else:
                continue
            if target is None:
                continue
            module_name = self._get_text(target, content)
            self.imports.setdefault(file_path, []).append(ImportInfo(module=module_name))
            # Add import edge
            self.graph.add_edge(file_path, module_name, relation="imports")

    def _imported_name(self, node, content: str) -> Optional[str]:
        """Return the original (un-aliased) name for one imported item, or None."""
        if node.type == "aliased_import":
            name_node = node.child_by_field_name("name")
            if name_node is None and node.children:
                name_node = node.children[0]
            return self._get_text(name_node, content) if name_node is not None else None
        if node.type in ("dotted_name", "identifier"):
            return self._get_text(node, content)
        # Punctuation, parentheses and wildcard imports carry no name.
        return None

    def _process_from_import(self, node, file_path: str, content: str):
        """Process ``from X import Y`` statement.

        The module is whatever precedes the ``import`` keyword (a dotted name
        or a relative import such as ``.pkg``); everything after the keyword
        is treated as an imported name. An ``import_from_list`` wrapper node is
        also accepted in case the grammar in use emits one.
        """
        module_name = None
        names = []
        past_import_keyword = False
        for child in node.children:
            kind = child.type
            if kind == "import":
                past_import_keyword = True
            elif not past_import_keyword:
                if module_name is None and kind in ("dotted_name", "relative_import"):
                    module_name = self._get_text(child, content)
            elif kind == "import_from_list":
                for name_node in child.children:
                    imported = self._imported_name(name_node, content)
                    if imported:
                        names.append(imported)
            else:
                imported = self._imported_name(child, content)
                if imported:
                    names.append(imported)

        if not module_name:
            return

        import_info = ImportInfo(module=module_name, names=names, is_from_import=True)
        self.imports.setdefault(file_path, []).append(import_info)
        # Add import edge
        self.graph.add_edge(file_path, module_name, relation="imports")
        # Register imported names as potential definitions
        for name in names:
            # Relative modules already end in a dot (e.g. "."), so avoid "..name".
            separator = "" if module_name.endswith(".") else "."
            qualified = f"{module_name}{separator}{name}"
            targets = self.definitions.setdefault(name, [])
            if qualified not in targets:
                targets.append(qualified)

    def _process_class(self, node, file_path: str, content: str) -> Optional["ClassInfo"]:
        """Process class definition.

        Registers the class in ``self.classes``, ``self.definitions`` and the
        graph (``defines`` edge from the file, ``inherits_from`` edges to each
        base name).

        Returns:
            The created ``ClassInfo``, or None when the node has no name.
        """
        name_node = node.child_by_field_name("name")
        if not name_node:
            return None
        class_name = self._get_text(name_node, content)
        node_id = f"{file_path}::{class_name}"

        # Get base classes; dotted bases such as `pkg.Base` are kept verbatim.
        bases = [
            self._get_text(arg, content)
            for child in node.children
            if child.type == "argument_list"
            for arg in child.children
            if arg.type in ("identifier", "attribute")
        ]

        class_info = ClassInfo(
            name=class_name,
            file_path=file_path,
            start_line=node.start_point[0] + 1,
            end_line=node.end_point[0] + 1,
            bases=bases,
        )
        self.classes[node_id] = class_info
        # Add to graph
        self.graph.add_node(
            node_id,
            type="class",
            name=class_name,
            start_line=class_info.start_line,
            end_line=class_info.end_line,
        )
        self.graph.add_edge(file_path, node_id, relation="defines")
        # Add inheritance edges
        for base in bases:
            self.graph.add_edge(node_id, base, relation="inherits_from")
        # Register definition
        self.definitions.setdefault(class_name, []).append(node_id)
        return class_info

    def _parameter_name(self, param_node, content: str) -> Optional[str]:
        """Return the name of one parameter node, or None for separators."""
        name_like = ("identifier", "list_splat_pattern", "dictionary_splat_pattern")
        kind = param_node.type
        if kind in name_like:
            # Splat parameters keep their stars, e.g. "*args" / "**kwargs".
            return self._get_text(param_node, content)
        if kind in ("typed_parameter", "default_parameter", "typed_default_parameter"):
            name_node = param_node.child_by_field_name("name")
            if name_node is None:
                # `typed_parameter` may not expose a `name` field (depends on
                # the grammar version); fall back to its first name-like child.
                name_node = next(
                    (c for c in param_node.children if c.type in name_like), None
                )
            if name_node is not None:
                return self._get_text(name_node, content)
        return None

    def _process_function(self, node, file_path: str, content: str,
                          current_class: Optional[str] = None) -> Optional["FunctionInfo"]:
        """Process function/method definition.

        Registers the function in ``self.functions``, ``self.definitions`` and
        the graph, links it to its class (``has_method``) or file
        (``defines``), and records it in the owning ``ClassInfo.methods``.

        Returns:
            The created ``FunctionInfo``, or None when the node has no name.
        """
        name_node = node.child_by_field_name("name")
        if not name_node:
            return None
        func_name = self._get_text(name_node, content)

        # Get parameters
        params = []
        params_node = node.child_by_field_name("parameters")
        if params_node:
            for child in params_node.children:
                param_name = self._parameter_name(child, content)
                if param_name:
                    params.append(param_name)

        func_info = FunctionInfo(
            name=func_name,
            file_path=file_path,
            start_line=node.start_point[0] + 1,
            end_line=node.end_point[0] + 1,
            is_method=current_class is not None,
            class_name=current_class,
            parameters=params,
        )
        node_id = func_info.node_id
        self.functions[node_id] = func_info
        # Add to graph
        self.graph.add_node(
            node_id,
            type="function" if not current_class else "method",
            name=func_name,
            full_name=func_info.full_name,
            start_line=func_info.start_line,
            end_line=func_info.end_line,
            parameters=",".join(params),
        )
        # Link to parent (file or class)
        if current_class:
            class_id = f"{file_path}::{current_class}"
            self.graph.add_edge(class_id, node_id, relation="has_method")
            owner = self.classes.get(class_id)
            if owner is not None and func_name not in owner.methods:
                owner.methods.append(func_name)
        else:
            self.graph.add_edge(file_path, node_id, relation="defines")
        # Register definition
        self.definitions.setdefault(func_name, []).append(node_id)
        return func_info

    def _process_call(self, node, file_path: str, content: str, caller_id: str):
        """Process function call.

        Stores the callee expression on the calling ``FunctionInfo`` (when the
        caller is a known function) and queues the call for later resolution.
        """
        func_node = node.child_by_field_name("function")
        if not func_node:
            return
        callee_name = self._get_text(func_node, content)
        call_line = node.start_point[0] + 1
        # Track call in function info
        caller = self.functions.get(caller_id)
        if caller is not None:
            caller.calls.append(callee_name)
        # Store for later resolution
        self.unresolved_calls.append((caller_id, callee_name, call_line))

    def _get_text(self, node, content: str) -> str:
        """Get text content of a node.

        ``add_file`` parses the UTF-8 encoding of ``content``, so the node's
        ``start_byte``/``end_byte`` index into those bytes rather than into
        the ``str``. For pure-ASCII sources the two coincide; otherwise the
        encoded form is sliced and decoded.
        """
        start, end = node.start_byte, node.end_byte
        if content.isascii():
            return content[start:end]
        # Encode once per source string instead of once per node.
        cached = getattr(self, "_encoded_source", None)
        if cached is None or cached[0] is not content:
            cached = (content, content.encode("utf8"))
            self._encoded_source = cached
        return cached[1][start:end].decode("utf8", errors="replace")

    def resolve_call_graph(self):
        """Resolve all function calls to their definitions.

        Every recorded call is matched against ``self.definitions`` by its
        full callee text and, for dotted callees such as ``self.method``, by
        its last component. A ``calls`` edge is added for each match.
        """
        resolved_count = 0
        for caller_id, callee_name, line in self.unresolved_calls:
            # Handle method calls like "self.method" or "obj.method"
            simple_name = callee_name.split(".")[-1]
            # Try to find definition
            target_ids = []
            # Check direct match
            if callee_name in self.definitions:
                target_ids.extend(self.definitions[callee_name])
            # Check simple name (for methods)
            if simple_name in self.definitions and simple_name != callee_name:
                target_ids.extend(self.definitions[simple_name])
            # Add call edges (each target once per call site)
            for target_id in dict.fromkeys(target_ids):
                # NOTE(review): the graph is a DiGraph, so this replaces the
                # attributes of any existing edge between the same two nodes —
                # e.g. a file's "defines" edge to a function it also calls at
                # module level. Confirm whether that is acceptable.
                self.graph.add_edge(
                    caller_id,
                    target_id,
                    relation="calls",
                    line=line,
                )
                resolved_count += 1
        logger.info(f"Resolved {resolved_count} function calls in call graph")

    def get_callers(self, function_name: str) -> List[str]:
        """Find all functions that call the specified function.

        Returns:
            Node ids with a ``calls`` edge into any definition registered
            under ``function_name``.
        """
        callers = []
        for target_id in self.definitions.get(function_name, []):
            # Imported names may be registered without ever becoming a node.
            if target_id not in self.graph:
                continue
            # Find incoming "calls" edges
            for pred, attrs in self.graph.pred[target_id].items():
                if attrs.get("relation") == "calls":
                    callers.append(pred)
        return callers

    def get_callees(self, function_name: str) -> List[str]:
        """Find all functions called by the given function.

        Returns:
            Node ids reached by a ``calls`` edge from any definition
            registered under ``function_name``.
        """
        callees = []
        for caller_id in self.definitions.get(function_name, []):
            if caller_id not in self.graph:
                continue
            # Find outgoing "calls" edges
            for succ, attrs in self.graph.succ[caller_id].items():
                if attrs.get("relation") == "calls":
                    callees.append(succ)
        return callees

    def get_call_chain(self, start_func: str, end_func: str, max_depth: int = 5) -> List[List[str]]:
        """Find call paths from start_func to end_func.

        Only ``calls`` edges are followed, so every consecutive pair in a
        returned path is an actual caller/callee pair.

        Args:
            start_func: Name of the function the chain starts at.
            end_func: Name of the function the chain ends at.
            max_depth: Maximum number of edges in a path.

        Returns:
            A list of paths, each a list of node ids from start to end.
        """
        paths = []
        start_ids = self.definitions.get(start_func, [])
        end_ids = self.definitions.get(end_func, [])
        if not start_ids or not end_ids:
            return paths

        # Restrict the search to the call graph proper.
        call_edges = [
            (u, v)
            for u, v, relation in self.graph.edges(data="relation")
            if relation == "calls"
        ]
        call_graph = self.graph.edge_subgraph(call_edges)

        for start_id in start_ids:
            if start_id not in call_graph:
                continue
            for end_id in end_ids:
                if end_id not in call_graph:
                    continue
                for path in nx.all_simple_paths(call_graph, start_id, end_id, cutoff=max_depth):
                    # A chain needs at least one call edge.
                    if len(path) > 1:
                        paths.append(list(path))
        return paths

    def get_file_dependencies(self, file_path: str) -> Dict[str, List[str]]:
        """Get all dependencies of a file (imports, calls to other files).

        Returns:
            A dict with ``imports`` (module names), ``calls_to`` (other files
            containing definitions called from this file's functions) and
            ``called_by`` (other files that call this file's functions).
        """
        deps = {
            # Direct imports
            "imports": [imp.module for imp in self.imports.get(file_path, [])],
            "calls_to": [],
            "called_by": [],
        }
        for func_id, func_info in self.functions.items():
            if func_info.file_path != file_path or func_id not in self.graph:
                continue
            # Follow edges of this exact definition; a lookup by bare name
            # would also pick up same-named functions from other files.
            for callee, attrs in self.graph.succ[func_id].items():
                # Ids without "::" are imported names, already under "imports".
                if attrs.get("relation") != "calls" or "::" not in callee:
                    continue
                callee_file = callee.split("::")[0]
                if callee_file != file_path and callee_file not in deps["calls_to"]:
                    deps["calls_to"].append(callee_file)
            for caller, attrs in self.graph.pred[func_id].items():
                if attrs.get("relation") != "calls":
                    continue
                caller_file = caller.split("::")[0]
                if caller_file != file_path and caller_file not in deps["called_by"]:
                    deps["called_by"].append(caller_file)
        return deps

    def get_related_nodes(self, node_id: str, depth: int = 2) -> List[str]:
        """Get nodes related to the given node via graph traversal.

        Args:
            node_id: A graph node id, or a definition name to look up.
            depth: Breadth-first depth limit.

        Returns:
            The reachable node ids (including the start node), or an empty
            list when nothing matches.
        """
        if node_id in self.graph:
            return list(nx.bfs_tree(self.graph, node_id, depth_limit=depth))

        # Try to find by name
        related = {}
        for nid in self.definitions.get(node_id, []):
            if nid in self.graph:
                related.update(dict.fromkeys(nx.bfs_tree(self.graph, nid, depth_limit=depth)))
        return list(related)

    def get_statistics(self) -> Dict:
        """Get analysis statistics."""
        method_count = sum(1 for f in self.functions.values() if f.is_method)
        return {
            "total_nodes": self.graph.number_of_nodes(),
            "total_edges": self.graph.number_of_edges(),
            "files": sum(1 for _, d in self.graph.nodes(data=True) if d.get("type") == "file"),
            "classes": len(self.classes),
            "functions": len(self.functions) - method_count,
            "methods": method_count,
            "imports": sum(len(imps) for imps in self.imports.values()),
            "call_edges": sum(
                1 for _, _, d in self.graph.edges(data=True) if d.get("relation") == "calls"
            ),
        }

    def save_graph(self, path: str):
        """Save the graph to a GraphML file.

        Resolves the call graph first, logs the statistics, then writes the
        graph to ``path``.
        """
        # Resolve call graph first
        self.resolve_call_graph()
        # Log statistics
        stats = self.get_statistics()
        logger.info(f"Graph Statistics: {stats}")
        nx.write_graphml(self.graph, path)
        logger.info(f"Graph saved to {path}")
# Backward compatibility alias
class ASTGraphBuilder(EnhancedCodeAnalyzer):
    """Legacy name kept for existing callers; inherits everything from EnhancedCodeAnalyzer."""