#!/usr/bin/env python3 """ Code Flow Analyzer with Gradio Interface - Hugging Face Spaces & Colab Compatible A single-file application that uses LangChain agents with the Gemini model to analyze code structure and generate Mermaid.js flowchart diagrams through a web interface. """ import ast import re import os import traceback import sys from typing import Dict, Any, List, Tuple import getpass # Check if running in Colab try: import google.colab IN_COLAB = True print("š¢ Running in Google Colab") except ImportError: IN_COLAB = False print("š” Running locally or in Hugging Face Spaces") # Install dependencies if in Colab if IN_COLAB: print("š¦ Installing dependencies...") os.system("pip install -q gradio langchain langgraph langchain-google-genai") print("ā Dependencies installed") import gradio as gr from langchain.chat_models import init_chat_model from langchain_google_genai import ChatGoogleGenerativeAI from langchain.tools import tool from langgraph.prebuilt import create_react_agent from langgraph.checkpoint.memory import MemorySaver # Sample code examples (unchanged) SAMPLE_PYTHON = '''def main(): user_input = get_user_input() if user_input: result = process_data(user_input) if result > 0: display_result(result) else: show_error() else: show_help() def get_user_input(): return input("Enter data: ") def process_data(data): for i in range(len(data)): if data[i].isdigit(): return int(data[i]) return -1 def display_result(result): print(f"Result: {result}") def show_error(): print("Error processing data") def show_help(): print("Please provide valid input")''' SAMPLE_JAVASCRIPT = '''function calculateTotal(items) { let total = 0; for (let item of items) { if (item.price > 0) { total += item.price; } } return total; } function processOrder(order) { if (validateOrder(order)) { const total = calculateTotal(order.items); if (total > 100) { applyDiscount(order); } return generateReceipt(order); } else { throw new Error("Invalid order"); } } function validateOrder(order) { return order && order.items && order.items.length > 0; } function applyDiscount(order) { order.discount = 0.1; // 10% discount } function generateReceipt(order) { return { items: order.items, total: calculateTotal(order.items), timestamp: new Date() }; }''' SAMPLE_JAVA = '''public class Calculator { public static void main(String[] args) { Calculator calc = new Calculator(); int result = calc.performCalculation(); calc.displayResult(result); } public int performCalculation() { int a = getUserInput(); int b = getUserInput(); if (a > b) { return multiply(a, b); } else { return add(a, b); } } private int add(int x, int y) { return x + y; } private int multiply(int x, int y) { return x * y; } private int getUserInput() { return 5; // Simplified for demo } private void displayResult(int result) { System.out.println("Result: " + result); } }''' # --- Gemini API Key Setup --- def setup_api_key(): """Setup API key for Colab, Hugging Face Spaces, and local environments""" api_key = os.getenv("GOOGLE_API_KEY") if not api_key: if IN_COLAB: print("š Please enter your Google API key:") print(" Get a key from: https://aistudio.google.com/app/apikey") api_key = getpass.getpass("GOOGLE_API_KEY: ") if api_key: os.environ["GOOGLE_API_KEY"] = api_key print("ā API key set successfully") else: print("ā ļø No API key provided - agent features will be disabled") else: print("ā ļø GOOGLE_API_KEY not found in environment variables") print(" Set it with: export GOOGLE_API_KEY='your-key-here'") print(" In Hugging Face Spaces, use the 'Secrets' tab to set the key.") else: print("ā Google API key found") return api_key or os.getenv("GOOGLE_API_KEY") # Setup API key api_key = setup_api_key() # Initialize LangChain components model = None memory = None agent_executor = None if api_key: try: # Use a stable, updated model name to avoid the 404 error model = ChatGoogleGenerativeAI(model="gemini-2.0-flash-001", temperature=0) print("ā Gemini model initialized successfully: gemini-2.0-flash-001") memory = MemorySaver() except Exception as e: print(f"ā Could not initialize Gemini model: {e}") print(" Please check your API key and internet connection.") model = None memory = None # --- Tool Definitions (unchanged) --- @tool def analyze_code_structure(source_code: str) -> Dict[str, Any]: """ Analyzes source code structure to identify functions, control flow, and dependencies. Returns structured data about the code that can be used to generate flow diagrams. """ try: # Try to parse as Python first try: tree = ast.parse(source_code) return _analyze_python_ast(tree) except SyntaxError: # If Python parsing fails, do basic text analysis return _analyze_code_text(source_code) except Exception as e: return {"error": f"Analysis error: {str(e)}"} def _analyze_python_ast(tree) -> Dict[str, Any]: """Analyze Python AST""" analysis = { "functions": [], "classes": [], "control_flows": [], "imports": [], "call_graph": {} } class CodeAnalyzer(ast.NodeVisitor): def __init__(self): self.current_function = None def visit_FunctionDef(self, node): func_info = { "name": node.name, "line": node.lineno, "args": [arg.arg for arg in node.args.args], "calls": [], "conditions": [], "loops": [] } self.current_function = node.name analysis["call_graph"][node.name] = [] # Analyze function body for child in ast.walk(node): if isinstance(child, ast.Call): if hasattr(child.func, 'id'): func_info["calls"].append(child.func.id) analysis["call_graph"][node.name].append(child.func.id) elif hasattr(child.func, 'attr'): func_info["calls"].append(child.func.attr) elif isinstance(child, ast.If): func_info["conditions"].append(f"if condition at line {child.lineno}") elif isinstance(child, (ast.For, ast.While)): loop_type = "for" if isinstance(child, ast.For) else "while" func_info["loops"].append(f"{loop_type} loop at line {child.lineno}") analysis["functions"].append(func_info) self.generic_visit(node) def visit_ClassDef(self, node): class_info = { "name": node.name, "line": node.lineno, "methods": [] } for item in node.body: if isinstance(item, ast.FunctionDef): class_info["methods"].append(item.name) analysis["classes"].append(class_info) self.generic_visit(node) def visit_Import(self, node): for alias in node.names: analysis["imports"].append(alias.name) self.generic_visit(node) def visit_ImportFrom(self, node): module = node.module or "" for alias in node.names: analysis["imports"].append(f"{module}.{alias.name}") self.generic_visit(node) analyzer = CodeAnalyzer() analyzer.visit(tree) return analysis def _analyze_code_text(source_code: str) -> Dict[str, Any]: """Basic text-based code analysis for non-Python code""" lines = source_code.split('\n') analysis = { "functions": [], "classes": [], "control_flows": [], "imports": [], "call_graph": {} } for i, line in enumerate(lines, 1): line = line.strip() # JavaScript function detection js_func_match = re.match(r'function\s+(\w+)\s*\(', line) if js_func_match: func_name = js_func_match.group(1) analysis["functions"].append({ "name": func_name, "line": i, "args": [], "calls": [], "conditions": [], "loops": [] }) analysis["call_graph"][func_name] = [] # Java/C++ method detection java_method_match = re.match(r'(?:public|private|protected)?\s*(?:static)?\s*\w+\s+(\w+)\s*\(', line) if java_method_match and not js_func_match: func_name = java_method_match.group(1) if func_name not in ['class', 'if', 'for', 'while']: # Avoid keywords analysis["functions"].append({ "name": func_name, "line": i, "args": [], "calls": [], "conditions": [], "loops": [] }) analysis["call_graph"][func_name] = [] # Control structures if re.match(r'\s*(if|else|elif|switch)\s*[\(\{]', line): analysis["control_flows"].append(f"condition at line {i}") if re.match(r'\s*(for|while|do)\s*[\(\{]', line): analysis["control_flows"].append(f"loop at line {i}") return analysis @tool def generate_mermaid_diagram(analysis_data: Dict[str, Any]) -> str: """ Generates a Mermaid.js flowchart diagram from code analysis data. Creates a visual representation of the code flow including function calls and control structures. """ if "error" in analysis_data: return f"flowchart TD\n Error[ā {analysis_data['error']}]" functions = analysis_data.get("functions", []) call_graph = analysis_data.get("call_graph", {}) if not functions: return """flowchart TD Start([š Program Start]) --> NoFunc[No Functions Found] NoFunc --> End([š Program End]) classDef startEnd fill:#e1f5fe,stroke:#01579b,stroke-width:2px classDef warning fill:#fff3e0,stroke:#e65100,stroke-width:2px class Start,End startEnd class NoFunc warning""" mermaid_lines = ["flowchart TD"] mermaid_lines.append(" Start([š Program Start]) --> Main") # Create nodes for each function func_nodes = [] for i, func in enumerate(functions): func_name = func["name"] safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', func_name) node_id = f"F{i}_{safe_name}" func_nodes.append(node_id) # Function node with emoji mermaid_lines.append(f" {node_id}[āļø {func_name}()]") # Add control structures within function conditions = func.get("conditions", []) loops = func.get("loops", []) if conditions: for j, condition in enumerate(conditions[:2]): # Limit to 2 conditions per function cond_id = f"{node_id}_C{j}" mermaid_lines.append(f" {node_id} --> {cond_id}{{š¤ Decision}}") mermaid_lines.append(f" {cond_id} -->|Yes| {node_id}_Y{j}[ā True Path]") mermaid_lines.append(f" {cond_id} -->|No| {node_id}_N{j}[ā False Path]") if loops: for j, loop in enumerate(loops[:1]): # Limit to 1 loop per function loop_id = f"{node_id}_L{j}" loop_type = "š Loop" if "for" in loop else "ā° While Loop" mermaid_lines.append(f" {node_id} --> {loop_id}[{loop_type}]") mermaid_lines.append(f" {loop_id} --> {loop_id}") # Self-loop # Connect main flow if func_nodes: mermaid_lines.append(f" Main --> {func_nodes[0]}") # Connect functions in sequence (simplified) for i in range(len(func_nodes) - 1): mermaid_lines.append(f" {func_nodes[i]} --> {func_nodes[i + 1]}") # Connect to end mermaid_lines.append(f" {func_nodes[-1]} --> End([š Program End])") # Add function call relationships (simplified to avoid clutter) call_count = 0 for caller, callees in call_graph.items(): if call_count >= 3: # Limit number of call relationships break caller_node = None for node in func_nodes: if caller.lower() in node.lower(): caller_node = node break if caller_node: for callee in callees[:2]: # Limit callees per function callee_node = None for node in func_nodes: if callee.lower() in node.lower(): callee_node = node break if callee_node and callee_node != caller_node: mermaid_lines.append(f" {caller_node} -.->|calls| {callee_node}") call_count += 1 # Add styling mermaid_lines.extend([ "", " classDef startEnd fill:#e1f5fe,stroke:#01579b,stroke-width:3px,color:#000", " classDef process fill:#f3e5f5,stroke:#4a148c,stroke-width:2px,color:#000", " classDef decision fill:#fff3e0,stroke:#e65100,stroke-width:2px,color:#000", " classDef success fill:#e8f5e8,stroke:#2e7d32,stroke-width:2px,color:#000", " classDef error fill:#ffebee,stroke:#c62828,stroke-width:2px,color:#000", "", " class Start,End startEnd", f" class {','.join(func_nodes)} process" if func_nodes else "" ]) return "\n".join(mermaid_lines) @tool def calculate_complexity_score(analysis_data: Dict[str, Any]) -> int: """ Calculates a complexity score for the code based on various metrics. Higher scores indicate more complex code structure. """ if "error" in analysis_data: return 0 score = 0 functions = analysis_data.get("functions", []) # Base score for number of functions score += len(functions) * 3 # Add score for control structures for func in functions: score += len(func.get("conditions", [])) * 4 # Conditions add complexity score += len(func.get("loops", [])) * 3 # Loops add complexity score += len(func.get("calls", [])) * 1 # Function calls add some complexity score += len(func.get("args", [])) * 1 # Parameters add complexity # Add score for classes score += len(analysis_data.get("classes", [])) * 5 return min(score, 100) # Cap at 100 # Create the agent if model is available if model and memory: tools = [analyze_code_structure, generate_mermaid_diagram, calculate_complexity_score] agent_executor = create_react_agent(model, tools, checkpointer=memory) print("ā LangChain agent created successfully") else: agent_executor = None print("ā LangChain agent not available") def analyze_code_with_agent(source_code: str, language: str = "auto") -> Tuple[str, str, List[str], int, str]: """ Main function that uses the LangChain agent to analyze code and generate diagrams. Returns: (mermaid_diagram, analysis_summary, functions_found, complexity_score, error_message) """ if not source_code.strip(): return "", "No code provided", [], 0, "Please enter some source code to analyze" if not agent_executor: return "", "Agent not available", [], 0, "ā LangChain agent not initialized. Please check your GOOGLE_API_KEY" try: # Detect language if auto if language == "auto": if "def " in source_code or "import " in source_code: language = "Python" elif "function " in source_code or "const " in source_code or "let " in source_code: language = "JavaScript" elif ("public " in source_code and "class " in source_code) or "System.out" in source_code: language = "Java" elif "#include" in source_code or "std::" in source_code: language = "C++" else: language = "Unknown" config = { "configurable": {"thread_id": f"session_{hash(source_code) % 10000}"}, "recursion_limit": 100 } # Refined prompt for better tool use prompt = f""" You are a code analysis expert. Analyze the following {language} source code. Your task is to: 1. Use the 'analyze_code_structure' tool with the full source code provided below. 2. Use the 'generate_mermaid_diagram' tool with the output of the first tool. 3. Use the 'calculate_complexity_score' tool with the output of the first tool. 4. Provide a brief, human-readable summary of the analysis, including the generated Mermaid diagram, complexity score, and a list of functions found. 5. Present the final result in a clear, easy-to-read format. Source Code to Analyze: ```{language.lower()} {source_code} ``` """ result = agent_executor.invoke( {"messages": [{"role": "user", "content": prompt}]}, config ) if result and "messages" in result: response_content = result["messages"][-1].content # Extract Mermaid diagram mermaid_match = re.search(r'```mermaid\n(.*?)\n```', response_content, re.DOTALL) mermaid_diagram = mermaid_match.group(1) if mermaid_match else "" # Extract complexity score complexity_match = re.search(r'complexity.*?(\d+)', response_content, re.IGNORECASE) complexity_score = int(complexity_match.group(1)) if complexity_match else 0 # Extract functions functions_found = [] func_matches = re.findall(r'Functions found:.*?([^\n]+)', response_content, re.IGNORECASE) if func_matches: functions_found = [f.strip() for f in func_matches[0].split(',')] else: # Fallback: extract from analysis analysis_result = analyze_code_structure.invoke({"source_code": source_code}) functions_found = [f["name"] for f in analysis_result.get("functions", [])] # Clean up the response for summary summary = re.sub(r'```mermaid.*?```', '', response_content, flags=re.DOTALL) summary = re.sub(r'flowchart TD.*?(?=\n\n|\Z)', '', summary, flags=re.DOTALL) summary = summary.strip() if not mermaid_diagram and not summary: # Last resort fallback if agent fails entirely analysis_result = analyze_code_structure.invoke({"source_code": source_code}) mermaid_diagram = generate_mermaid_diagram.invoke({"analysis_data": analysis_result}) complexity_score = calculate_complexity_score.invoke({"analysis_data": analysis_result}) functions_found = [f["name"] for f in analysis_result.get("functions", [])] summary = "Agent failed to provide a detailed summary, but a fallback analysis was successful." return mermaid_diagram, summary, functions_found, complexity_score, "" except Exception as e: error_msg = f"ā Analysis failed: {str(e)}" print(f"Error details: {traceback.format_exc()}") return "", "", [], 0, error_msg def create_mermaid_html(mermaid_code: str) -> str: """Create HTML to render Mermaid diagram""" if not mermaid_code.strip(): return "