import base64
import io
import json
import os
import tempfile
from collections import Counter
from typing import Dict, List, Any, Tuple, Optional

import matplotlib
# The backend must be selected before pyplot is imported.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from pyvis.network import Network
|
|
class GraphVisualizer:
    """Render a directed knowledge graph as images, interactive HTML or Markdown.

    Node attributes read by the renderers (all optional): ``type``, ``size``,
    ``importance`` and ``description``.  Edge attributes read: ``relationship``
    and ``description``.  Missing attributes fall back to the defaults visible
    in each method.
    """

    # Colour used for any entity type that is absent from ``color_map``.
    _FALLBACK_COLOR = '#95A5A6'

    def __init__(self):
        """Create the entity-type -> hex colour palette shared by all renderers."""
        self.color_map = {
            'PERSON': '#FF6B6B',
            'ORGANIZATION': '#4ECDC4',
            'LOCATION': '#45B7D1',
            'CONCEPT': '#96CEB4',
            'EVENT': '#FFEAA7',
            'OBJECT': '#DDA0DD',
            'UNKNOWN': '#95A5A6',
        }

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _color_for(self, node_type: Any) -> str:
        """Return the palette colour for ``node_type`` (grey when unmapped)."""
        return self.color_map.get(node_type, self._FALLBACK_COLOR)

    @staticmethod
    def _reserve_temp_path(suffix: str) -> str:
        """Create an empty, persistent temp file and return its path.

        The handle is closed before returning so that libraries which reopen
        the file by name can write to it on every platform, and so that no
        file descriptor is leaked.  The caller owns (and must eventually
        delete) the file.
        """
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as handle:
            return handle.name

    @staticmethod
    def _script_safe_json(payload: Any) -> str:
        """Serialise ``payload`` to JSON that is safe inside a ``<script>`` tag.

        ``<``, ``>`` and ``&`` only ever occur inside JSON string values, so
        replacing them with ``\\uXXXX`` escapes keeps the decoded data
        identical while preventing a value such as ``</script>`` from
        terminating the surrounding script element.
        """
        return (
            json.dumps(payload)
            .replace("<", "\\u003c")
            .replace(">", "\\u003e")
            .replace("&", "\\u0026")
        )

    def _save_figure(self, fig) -> str:
        """Write ``fig`` to a new temporary PNG file and return the file path."""
        path = self._reserve_temp_path('.png')
        fig.savefig(path, format='png', dpi=150, bbox_inches='tight')
        return path

    @staticmethod
    def _spring_layout(graph: "nx.DiGraph"):
        """Spring layout with the parameters used as default and as fallback."""
        return nx.spring_layout(graph, k=1, iterations=50)

    # ------------------------------------------------------------------
    # Matplotlib rendering
    # ------------------------------------------------------------------

    def visualize_graph(self,
                        graph: "nx.DiGraph",
                        layout_type: str = "spring",
                        show_labels: bool = True,
                        show_edge_labels: bool = False,
                        node_size_factor: float = 1.0,
                        figsize: Tuple[int, int] = (12, 8)) -> str:
        """Draw the graph with matplotlib and return the path of a PNG file.

        Args:
            graph: Graph to draw.
            layout_type: Layout name; unrecognised names use the spring layout.
            show_labels: Draw each node's name with its importance beneath it.
            show_edge_labels: Draw each edge's ``relationship`` attribute.
            node_size_factor: Multiplier applied to every node's ``size``.
            figsize: Figure size in inches, ``(width, height)``.

        Returns:
            Path to a temporary ``.png`` file.  The file is not deleted
            automatically.
        """
        if not graph.nodes():
            return self._create_empty_graph_image()

        fig, ax = plt.subplots(figsize=figsize)
        try:
            pos = self._calculate_layout(graph, layout_type)

            node_colors = [
                self._color_for(attrs.get('type', 'UNKNOWN'))
                for _, attrs in graph.nodes(data=True)
            ]
            node_sizes = [
                attrs.get('size', 20) * node_size_factor * 10
                for _, attrs in graph.nodes(data=True)
            ]

            nx.draw_networkx_nodes(graph, pos,
                                   node_color=node_colors,
                                   node_size=node_sizes,
                                   alpha=0.8,
                                   ax=ax)
            nx.draw_networkx_edges(graph, pos,
                                   edge_color='gray',
                                   arrows=True,
                                   arrowsize=20,
                                   alpha=0.6,
                                   width=1.5,
                                   ax=ax)

            if show_labels:
                labels = {
                    node: f"{node}\n({attrs.get('importance', 0.0):.2f})"
                    for node, attrs in graph.nodes(data=True)
                }
                nx.draw_networkx_labels(graph, pos, labels, font_size=8, ax=ax)

            if show_edge_labels:
                edge_labels = {
                    (u, v): attrs.get('relationship', '')
                    for u, v, attrs in graph.edges(data=True)
                }
                nx.draw_networkx_edge_labels(graph, pos, edge_labels,
                                             font_size=6, ax=ax)

            ax.set_title("Knowledge Graph", fontsize=16, fontweight='bold')
            ax.set_axis_off()
            fig.tight_layout()

            return self._save_figure(fig)
        finally:
            # Always release the figure, even when layout or drawing fails.
            plt.close(fig)

    def _calculate_layout(self, graph: "nx.DiGraph", layout_type: str) -> Dict[str, Tuple[float, float]]:
        """Return node positions computed by the requested layout algorithm.

        Unknown ``layout_type`` values use the spring layout.  If the chosen
        algorithm raises, the spring layout is used instead; an error raised
        by that fallback propagates to the caller.
        """
        algorithms = {
            "spring": self._spring_layout,
            "circular": nx.circular_layout,
            "shell": nx.shell_layout,
            "kamada_kawai": nx.kamada_kawai_layout,
            "random": nx.random_layout,
        }
        compute = algorithms.get(layout_type, self._spring_layout)
        try:
            return compute(graph)
        except Exception:
            # Best-effort: a failing layout should not prevent rendering.
            # ``Exception`` (not a bare except) lets Ctrl-C / SystemExit through.
            return self._spring_layout(graph)

    def _create_empty_graph_image(self) -> str:
        """Return the path of a PNG placeholder saying there is nothing to show."""
        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            ax.text(0.5, 0.5, 'No graph data to display',
                    horizontalalignment='center', verticalalignment='center',
                    fontsize=16, transform=ax.transAxes)
            ax.set_axis_off()
            return self._save_figure(fig)
        finally:
            plt.close(fig)

    # ------------------------------------------------------------------
    # vis.js rendering
    # ------------------------------------------------------------------

    def create_interactive_html(self, graph: "nx.DiGraph") -> str:
        """Return a standalone HTML page that renders the graph with vis.js.

        The page loads vis-network from the unpkg CDN.  Node and edge data are
        embedded as JSON literals inside the page's ``<script>`` element.
        For an empty graph a short placeholder ``<div>`` is returned instead.
        """
        if not graph.nodes():
            return "<div>No graph data to display</div>"

        nodes = [
            {
                "id": node,
                "label": node,
                "color": self._color_for(data.get('type', 'UNKNOWN')),
                "size": data.get('size', 20),
                "title": f"Type: {data.get('type', 'UNKNOWN')}<br>"
                         f"Importance: {data.get('importance', 0.0):.2f}<br>"
                         f"Description: {data.get('description', 'N/A')}",
            }
            for node, data in graph.nodes(data=True)
        ]
        edges = [
            {
                "from": u,
                "to": v,
                "label": data.get('relationship', ''),
                "title": data.get('description', ''),
                "arrows": {"to": {"enabled": True}},
            }
            for u, v, data in graph.edges(data=True)
        ]

        # Graph content is arbitrary text; escape it so it cannot close the
        # <script> element it is embedded in.
        nodes_json = self._script_safe_json(nodes)
        edges_json = self._script_safe_json(edges)

        html_template = f"""
<!DOCTYPE html>
<html>
<head>
<script src="https://unpkg.com/vis-network/standalone/umd/vis-network.min.js"></script>
<style>
#mynetworkid {{
width: 100%;
height: 600px;
border: 1px solid lightgray;
}}
</style>
</head>
<body>
<div id="mynetworkid"></div>

<script>
var nodes = new vis.DataSet({nodes_json});
var edges = new vis.DataSet({edges_json});
var container = document.getElementById('mynetworkid');

var data = {{
nodes: nodes,
edges: edges
}};

var options = {{
nodes: {{
shape: 'dot',
scaling: {{
min: 10,
max: 30
}},
font: {{
size: 12,
face: 'Tahoma'
}}
}},
edges: {{
font: {{align: 'middle'}},
color: {{color:'gray'}},
arrows: {{to: {{enabled: true, scaleFactor: 1}}}}
}},
physics: {{
enabled: true,
stabilization: {{enabled: true, iterations: 200}}
}},
interaction: {{
hover: true,
tooltipDelay: 200
}}
}};

var network = new vis.Network(container, data, options);
</script>
</body>
</html>
"""

        return html_template

    # ------------------------------------------------------------------
    # Markdown summaries
    # ------------------------------------------------------------------

    def create_statistics_summary(self, graph: "nx.DiGraph", stats: Dict[str, Any]) -> str:
        """Return a Markdown summary of graph metrics and type counts.

        Args:
            graph: Graph whose node types and edge relationships are counted.
            stats: Precomputed metrics.  Must contain ``num_nodes``,
                ``num_edges``, ``density``, ``is_connected``,
                ``num_components`` and ``avg_degree``.

        Raises:
            KeyError: If ``stats`` lacks one of the required keys (only
                reached for a non-empty graph).
        """
        if not graph.nodes():
            return "No graph statistics available."

        type_counts = Counter(
            data.get('type', 'UNKNOWN') for _, data in graph.nodes(data=True)
        )
        rel_counts = Counter(
            data.get('relationship', 'unknown')
            for _, _, data in graph.edges(data=True)
        )

        header = "\n".join([
            "",
            "## Graph Statistics",
            "",
            "**Basic Metrics:**",
            f"- Nodes: {stats['num_nodes']}",
            f"- Edges: {stats['num_edges']}",
            f"- Density: {stats['density']:.3f}",
            f"- Connected: {'Yes' if stats['is_connected'] else 'No'}",
            f"- Components: {stats['num_components']}",
            f"- Average Degree: {stats['avg_degree']:.2f}",
            "",
            "**Entity Types:**",
            "",
        ])

        parts = [header]
        parts.extend(
            f"\n- {entity_type}: {count}"
            for entity_type, count in sorted(type_counts.items())
        )
        parts.append("\n\n**Relationship Types:**")
        parts.extend(
            f"\n- {rel_type}: {count}"
            for rel_type, count in sorted(rel_counts.items())
        )
        return "".join(parts)

    def create_entity_list(self, graph: "nx.DiGraph", sort_by: str = "importance") -> str:
        """Return a Markdown listing of every node with its key attributes.

        Args:
            graph: Graph whose nodes are listed.
            sort_by: ``"importance"`` or ``"connections"`` (both descending) or
                ``"name"`` (ascending).  Any other value keeps graph order.
        """
        if not graph.nodes():
            return "No entities found."

        entities = [
            {
                'name': node,
                'type': data.get('type', 'UNKNOWN'),
                'importance': data.get('importance', 0.0),
                'description': data.get('description', 'N/A'),
                'connections': graph.degree(node),
            }
            for node, data in graph.nodes(data=True)
        ]

        # sort key -> (entity field, descending?)
        orderings = {
            "importance": ('importance', True),
            "connections": ('connections', True),
            "name": ('name', False),
        }
        if sort_by in orderings:
            field, descending = orderings[sort_by]
            entities.sort(key=lambda entity: entity[field], reverse=descending)

        sections = ["## Entities\n\n"]
        sections.extend(
            f"\n**{entity['name']}** ({entity['type']})\n"
            f"- Importance: {entity['importance']:.2f}\n"
            f"- Connections: {entity['connections']}\n"
            f"- Description: {entity['description']}\n"
            "\n"
            for entity in entities
        )
        return "".join(sections)

    # ------------------------------------------------------------------
    # Option / metadata queries
    # ------------------------------------------------------------------

    def get_layout_options(self) -> List[str]:
        """Return the layout names accepted by the ``layout_type`` parameters."""
        return ["spring", "circular", "shell", "kamada_kawai", "random"]

    def get_entity_types(self, graph: "nx.DiGraph") -> List[str]:
        """Return the sorted distinct node ``type`` values (default ``UNKNOWN``)."""
        return sorted({
            data.get('type', 'UNKNOWN') for _, data in graph.nodes(data=True)
        })

    # ------------------------------------------------------------------
    # Plotly rendering
    # ------------------------------------------------------------------

    def create_plotly_interactive(self, graph: "nx.DiGraph", layout_type: str = "spring") -> "go.Figure":
        """Return a Plotly figure with one trace for edges and one for nodes.

        Hovering a node shows its type, importance, degree and description.
        An empty graph yields a figure containing only a placeholder
        annotation.
        """
        if not graph.nodes():
            fig = go.Figure()
            fig.add_annotation(
                text="No graph data to display",
                xref="paper", yref="paper",
                x=0.5, y=0.5, xanchor='center', yanchor='middle',
                showarrow=False, font=dict(size=16)
            )
            return fig

        pos = self._calculate_layout(graph, layout_type)

        node_x = []
        node_y = []
        node_text = []
        node_info = []
        node_colors = []
        node_sizes = []

        for node, data in graph.nodes(data=True):
            x, y = pos[node]
            node_x.append(x)
            node_y.append(y)

            node_type = data.get('type', 'UNKNOWN')
            importance = data.get('importance', 0.0)
            description = data.get('description', 'N/A')
            connections = graph.degree(node)

            node_text.append(node)
            node_info.append(
                f"<b>{node}</b><br>"
                f"Type: {node_type}<br>"
                f"Importance: {importance:.2f}<br>"
                f"Connections: {connections}<br>"
                f"Description: {description}"
            )
            node_colors.append(self._color_for(node_type))
            node_sizes.append(max(10, data.get('size', 20)))

        # All edges share one line trace; ``None`` breaks the line between
        # consecutive segments.
        edge_x = []
        edge_y = []
        for u, v in graph.edges():
            x0, y0 = pos[u]
            x1, y1 = pos[v]
            edge_x.extend([x0, x1, None])
            edge_y.extend([y0, y1, None])

        edge_trace = go.Scatter(
            x=edge_x, y=edge_y,
            line=dict(width=2, color='gray'),
            hoverinfo='none',
            mode='lines'
        )

        node_trace = go.Scatter(
            x=node_x, y=node_y,
            mode='markers+text',
            hoverinfo='text',
            text=node_text,
            hovertext=node_info,
            textposition="middle center",
            marker=dict(
                size=node_sizes,
                color=node_colors,
                line=dict(width=2, color='white')
            )
        )

        fig = go.Figure(
            data=[edge_trace, node_trace],
            layout=go.Layout(
                # ``title.font`` replaces the deprecated ``titlefont``
                # shorthand, which newer Plotly releases no longer accept.
                title=dict(text='Interactive Knowledge Graph',
                           font=dict(size=16)),
                showlegend=False,
                hovermode='closest',
                margin=dict(b=20, l=5, r=5, t=40),
                annotations=[dict(
                    text="Hover over nodes for details. Drag to pan, scroll to zoom.",
                    showarrow=False,
                    xref="paper", yref="paper",
                    x=0.005, y=-0.002,
                    xanchor='left', yanchor='bottom',
                    font=dict(color="gray", size=12)
                )],
                xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
                plot_bgcolor='white'
            )
        )

        return fig

    # ------------------------------------------------------------------
    # pyvis rendering
    # ------------------------------------------------------------------

    def create_pyvis_interactive(self, graph: "nx.DiGraph", layout_type: str = "spring") -> str:
        """Write an interactive pyvis page and return the HTML file path.

        Args:
            graph: Graph to render.
            layout_type: Accepted for signature parity with the other
                renderers; node placement is left to the physics options set
                below, so this value is not used.

        Returns:
            Path to a temporary ``.html`` file.  The file is not deleted
            automatically.
        """
        if not graph.nodes():
            return self._create_empty_pyvis_graph()

        net = Network(height="600px", width="100%", bgcolor="#ffffff", font_color="black")

        net.set_options("""
        {
            "physics": {
                "enabled": true,
                "stabilization": {"enabled": true, "iterations": 200},
                "barnesHut": {
                    "gravitationalConstant": -2000,
                    "centralGravity": 0.3,
                    "springLength": 95,
                    "springConstant": 0.04,
                    "damping": 0.09
                }
            },
            "interaction": {
                "hover": true,
                "tooltipDelay": 200,
                "hideEdgesOnDrag": false
            }
        }
        """)

        for node, data in graph.nodes(data=True):
            node_type = data.get('type', 'UNKNOWN')
            importance = data.get('importance', 0.0)
            description = data.get('description', 'N/A')
            connections = graph.degree(node)

            title = (
                "\n"
                f"<b>{node}</b><br>\n"
                f"Type: {node_type}<br>\n"
                f"Importance: {importance:.2f}<br>\n"
                f"Connections: {connections}<br>\n"
                f"Description: {description}\n"
            )

            net.add_node(node, label=node, title=title,
                         color=self._color_for(node_type),
                         size=max(10, data.get('size', 20)))

        for u, v, data in graph.edges(data=True):
            relationship = data.get('relationship', 'connected')
            title = f"{u} → {v}<br>Relationship: {relationship}"
            net.add_edge(u, v, title=title, arrows="to", color="gray")

        # Reserve (and close) the temp file first so pyvis can open it by name.
        path = self._reserve_temp_path('.html')
        net.save_graph(path)
        return path

    def _create_empty_pyvis_graph(self) -> str:
        """Write a pyvis page holding one placeholder node; return its path."""
        net = Network(height="600px", width="100%", bgcolor="#ffffff", font_color="black")
        net.add_node(1, label="No graph data", color="#cccccc")

        path = self._reserve_temp_path('.html')
        net.save_graph(path)
        return path

    def get_visualization_options(self) -> List[str]:
        """Return the names of the available visualization back ends."""
        return ["matplotlib", "plotly", "pyvis", "vis.js"]

    def get_relationship_types(self, graph: "nx.DiGraph") -> List[str]:
        """Return the sorted distinct edge ``relationship`` values (default ``unknown``)."""
        return sorted({
            data.get('relationship', 'unknown')
            for _, _, data in graph.edges(data=True)
        })
|
|