"""Streamlit app: upload documents, extract a knowledge graph via an LLM
(OpenRouter), and visualize/export the result.

Pipeline: DocumentProcessor -> LLMExtractor -> GraphBuilder -> GraphVisualizer.
"""

import json
import os
import tempfile
import traceback
from typing import Any, Dict, List, Optional, Tuple

import streamlit as st

# Project-local modules (document loading, LLM extraction, graph assembly,
# rendering). Not part of the standard library.
from src.document_processor import DocumentProcessor
from src.llm_extractor import LLMExtractor
from src.graph_builder import GraphBuilder
from src.visualizer import GraphVisualizer
from config.settings import Config

# Page config
st.set_page_config(
    page_title="Knowledge Graph Extraction",
    page_icon="πŸ•ΈοΈ",
    layout="wide",
)


# Initialize components
@st.cache_resource
def initialize_components():
    """Build the heavyweight pipeline components once per server process.

    Returns:
        Tuple of (Config, DocumentProcessor, LLMExtractor, GraphBuilder,
        GraphVisualizer).

    NOTE(review): ``st.cache_resource`` shares these singletons across ALL
    user sessions. ``process_uploaded_files`` mutates ``config`` and
    ``llm_extractor`` with a per-user API key, so concurrent users can
    clobber each other's key — confirm whether multi-user use is expected.
    """
    config = Config()
    doc_processor = DocumentProcessor()
    llm_extractor = LLMExtractor()
    graph_builder = GraphBuilder()
    visualizer = GraphVisualizer()
    return config, doc_processor, llm_extractor, graph_builder, visualizer


config, doc_processor, llm_extractor, graph_builder, visualizer = initialize_components()


def process_uploaded_files(uploaded_files, api_key, batch_mode, visualization_type,
                           layout_type, show_labels, show_edge_labels,
                           min_importance, entity_types_filter):
    """Process uploaded files and extract a knowledge graph.

    Runs the full pipeline (load -> extract -> build -> filter -> visualize),
    reporting progress through Streamlit widgets.

    Args:
        uploaded_files: Streamlit ``UploadedFile`` objects from the uploader.
        api_key: OpenRouter API key (may be empty to use the configured one).
        batch_mode: Whether to process multiple files together.
        visualization_type: One of "plotly", "pyvis", "vis.js", or matplotlib.
        layout_type: Graph layout algorithm name.
        show_labels: Show node labels (matplotlib only).
        show_edge_labels: Show edge labels (matplotlib only).
        min_importance: Minimum entity importance kept after filtering.
        entity_types_filter: Entity types to keep (empty = keep all).

    Returns:
        A dict with the visualization, statistics, and filtered graph on
        success, or ``None`` on any failure (an error is shown in the UI).
    """
    try:
        # Update API key. Hoist strip() once instead of recomputing it.
        key = api_key.strip()
        if key:
            config.OPENROUTER_API_KEY = key
            llm_extractor.config.OPENROUTER_API_KEY = key
            llm_extractor.headers["Authorization"] = f"Bearer {key}"

        if not config.OPENROUTER_API_KEY:
            st.error("❌ OpenRouter API key is required")
            return None

        if not uploaded_files:
            st.error("❌ Please upload at least one file")
            return None

        progress_bar = st.progress(0)
        status_text = st.empty()

        status_text.text("Loading documents...")
        progress_bar.progress(0.1)

        # Save uploaded files to a temporary location so the document
        # processor can work with real file paths.
        file_paths = []
        for uploaded_file in uploaded_files:
            with tempfile.NamedTemporaryFile(delete=False,
                                             suffix=f"_{uploaded_file.name}") as tmp_file:
                tmp_file.write(uploaded_file.getvalue())
                file_paths.append(tmp_file.name)

        # Process documents; the finally-block guarantees the temp files are
        # removed even if processing raises (the original leaked them on error).
        try:
            doc_results = doc_processor.process_documents(file_paths, batch_mode)
        finally:
            for file_path in file_paths:
                try:
                    os.unlink(file_path)
                except OSError:
                    # Best-effort cleanup: a vanished/locked temp file is not
                    # worth failing the whole run over.
                    pass

        # Check for errors; abort only when every file failed, otherwise
        # continue with the successful subset and warn later.
        failed_files = [r for r in doc_results if r['status'] == 'error']
        if failed_files:
            error_msg = "Failed to process files:\n" + "\n".join(
                [f"- {r['file_path']}: {r['error']}" for r in failed_files])
            if len(failed_files) == len(doc_results):
                st.error(f"❌ {error_msg}")
                return None

        status_text.text("Extracting entities and relationships...")
        progress_bar.progress(0.3)

        # Extract entities and relationships from each successful document.
        all_entities = []
        all_relationships = []
        extraction_errors = []

        for doc_result in doc_results:
            if doc_result['status'] == 'success':
                extraction_result = llm_extractor.process_chunks(doc_result['chunks'])
                if extraction_result.get('errors'):
                    extraction_errors.extend(extraction_result['errors'])
                all_entities.extend(extraction_result.get('entities', []))
                all_relationships.extend(extraction_result.get('relationships', []))

        if not all_entities:
            error_msg = "No entities extracted from documents"
            if extraction_errors:
                # Show only the first few errors to keep the message readable.
                error_msg += f"\nExtraction errors: {'; '.join(extraction_errors[:3])}"
            st.error(f"❌ {error_msg}")
            return None

        status_text.text("Building knowledge graph...")
        progress_bar.progress(0.6)

        # Build graph
        graph = graph_builder.build_graph(all_entities, all_relationships)

        if not graph.nodes():
            st.error("❌ No valid knowledge graph could be built")
            return None

        status_text.text("Applying filters...")
        progress_bar.progress(0.7)

        # Apply filters. Skip the call entirely when no filter is active.
        filtered_graph = graph
        if entity_types_filter:
            filtered_graph = graph_builder.filter_graph(
                entity_types=entity_types_filter,
                min_importance=min_importance
            )
        elif min_importance > 0:
            filtered_graph = graph_builder.filter_graph(min_importance=min_importance)

        if not filtered_graph.nodes():
            st.error("❌ No entities remain after applying filters")
            return None

        status_text.text("Generating visualizations...")
        progress_bar.progress(0.8)

        # Generate the visualization. Plotly and vis.js return an in-memory
        # object/HTML ('graph_viz'); pyvis and matplotlib write a file and
        # return its path ('graph_image_path').
        if visualization_type == "plotly":
            graph_viz = visualizer.create_plotly_interactive(filtered_graph, layout_type)
            graph_image_path = None
        elif visualization_type == "pyvis":
            graph_image_path = visualizer.create_pyvis_interactive(filtered_graph, layout_type)
            graph_viz = None
        elif visualization_type == "vis.js":
            graph_viz = visualizer.create_interactive_html(filtered_graph)
            graph_image_path = None
        else:  # matplotlib
            graph_image_path = visualizer.visualize_graph(
                filtered_graph,
                layout_type=layout_type,
                show_labels=show_labels,
                show_edge_labels=show_edge_labels
            )
            graph_viz = None

        # Get statistics
        stats = graph_builder.get_graph_statistics()
        stats_summary = visualizer.create_statistics_summary(filtered_graph, stats)

        # Get entity list
        entity_list = visualizer.create_entity_list(filtered_graph)

        # Get central nodes, rendered as a numbered markdown list.
        central_nodes = graph_builder.get_central_nodes()
        central_nodes_text = "## Most Central Entities\n\n"
        for i, (node, score) in enumerate(central_nodes, 1):
            central_nodes_text += f"{i}. **{node}** (centrality: {score:.3f})\n"

        status_text.text("Complete!")
        progress_bar.progress(1.0)

        # Success message, with warnings for partial failures.
        success_count = len([r for r in doc_results if r['status'] == 'success'])
        success_msg = f"βœ… Successfully processed {success_count} document(s)"
        if failed_files:
            success_msg += f"\n⚠️ {len(failed_files)} file(s) failed to process"
        if extraction_errors:
            success_msg += f"\n⚠️ {len(extraction_errors)} extraction error(s) occurred"

        return {
            'success_msg': success_msg,
            'graph_image_path': graph_image_path,
            'graph_viz': graph_viz,
            'visualization_type': visualization_type,
            'stats_summary': stats_summary,
            'entity_list': entity_list,
            'central_nodes_text': central_nodes_text,
            'graph': filtered_graph
        }

    except Exception as e:
        # Top-level UI boundary: surface the error (with traceback) to the
        # user rather than crashing the Streamlit script.
        st.error(f"❌ Error: {str(e)}")
        st.error(f"Full traceback:\n{traceback.format_exc()}")
        return None


# Main app
def main():
    """Render the Streamlit UI and wire it to the extraction pipeline."""
    st.title("πŸ•ΈοΈ Knowledge Graph Extraction")
    st.markdown("""
    Upload documents and extract knowledge graphs using LLMs via OpenRouter.
    Supports PDF, TXT, DOCX, and JSON files.
    """)

    # Sidebar for configuration
    with st.sidebar:
        st.header("πŸ“ Document Upload")
        uploaded_files = st.file_uploader(
            "Choose files",
            type=['pdf', 'txt', 'docx', 'json'],
            accept_multiple_files=True
        )

        batch_mode = st.checkbox(
            "Batch Processing Mode",
            value=False,
            help="Process multiple files together"
        )

        st.header("πŸ”‘ API Configuration")
        api_key = st.text_input(
            "OpenRouter API Key",
            type="password",
            placeholder="Enter your OpenRouter API key",
            help="Get your key at openrouter.ai"
        )

        st.header("πŸŽ›οΈ Visualization Settings")
        visualization_type = st.selectbox(
            "Visualization Type",
            options=visualizer.get_visualization_options(),
            index=1,  # Default to plotly for interactivity
            help="Choose visualization method"
        )

        layout_type = st.selectbox(
            "Layout Algorithm",
            options=visualizer.get_layout_options(),
            index=0
        )

        show_labels = st.checkbox("Show Node Labels", value=True)
        show_edge_labels = st.checkbox("Show Edge Labels", value=False)

        st.header("πŸ” Filtering Options")
        min_importance = st.slider(
            "Minimum Entity Importance",
            min_value=0.0,
            max_value=1.0,
            value=0.3,
            step=0.1
        )

        # NOTE(review): options is always empty, so this filter can never be
        # set before processing despite the help text — confirm whether it
        # should be populated from a previous run's session state.
        entity_types_filter = st.multiselect(
            "Entity Types Filter",
            options=[],
            help="Filter will be populated after processing"
        )

        process_button = st.button("πŸš€ Extract Knowledge Graph", type="primary")

    # Main content area
    if process_button and uploaded_files:
        with st.spinner("Processing..."):
            result = process_uploaded_files(
                uploaded_files, api_key, batch_mode, visualization_type,
                layout_type, show_labels, show_edge_labels, min_importance,
                entity_types_filter
            )

        if result:
            # Store results in session state
            st.session_state['result'] = result

            # Display success message
            st.success(result['success_msg'])

            # Create tabs for results
            tab1, tab2, tab3, tab4 = st.tabs(["πŸ“ˆ Graph Visualization", "πŸ“‹ Statistics", "πŸ“ Entities", "🎯 Central Nodes"])

            with tab1:
                viz_type = result['visualization_type']
                if viz_type == "plotly" and result['graph_viz']:
                    st.plotly_chart(result['graph_viz'], use_container_width=True)
                    st.info("🎯 Interactive Plotly graph: Hover over nodes for details, drag to pan, scroll to zoom")
                elif viz_type == "pyvis" and result['graph_image_path'] and os.path.exists(result['graph_image_path']):
                    # Read the generated HTML file and embed it in the page.
                    with open(result['graph_image_path'], 'r', encoding='utf-8') as f:
                        html_content = f.read()
                    st.components.v1.html(html_content, height=600, scrolling=True)
                    st.info("🎯 Interactive Pyvis graph: Drag nodes to rearrange, hover for details")
                elif viz_type == "vis.js" and result['graph_viz']:
                    st.components.v1.html(result['graph_viz'], height=600, scrolling=True)
                    st.info("🎯 Interactive vis.js graph: Drag nodes, hover for details, use physics simulation")
                elif viz_type == "matplotlib" and result['graph_image_path'] and os.path.exists(result['graph_image_path']):
                    # NOTE(review): use_column_width is deprecated in newer
                    # Streamlit in favor of use_container_width — confirm the
                    # deployed Streamlit version before switching.
                    st.image(result['graph_image_path'], caption="Knowledge Graph", use_column_width=True)
                    st.info("πŸ“Š Static matplotlib visualization")
                else:
                    st.error("Failed to generate graph visualization")

            with tab2:
                st.markdown(result['stats_summary'])

            with tab3:
                st.markdown(result['entity_list'])

            with tab4:
                st.markdown(result['central_nodes_text'])

            # Export options
            st.header("πŸ’Ύ Export Options")
            col1, col2 = st.columns(2)

            with col1:
                export_format = st.selectbox(
                    "Export Format",
                    options=["json", "graphml", "gexf"],
                    index=0
                )

            with col2:
                if st.button("πŸ“₯ Export Graph"):
                    try:
                        export_data = graph_builder.export_graph(export_format)
                        st.text_area("Export Data", value=export_data, height=300)

                        # Download button
                        st.download_button(
                            label=f"Download {export_format.upper()} file",
                            data=export_data,
                            file_name=f"knowledge_graph.{export_format}",
                            mime="application/octet-stream"
                        )
                    except Exception as e:
                        st.error(f"Export failed: {str(e)}")

    elif process_button and not uploaded_files:
        st.warning("Please upload at least one file before processing.")

    # Instructions
    st.header("πŸ“š Instructions")

    with st.expander("How to use this app"):
        st.markdown("""
        1. **Upload Documents**: Select one or more files (PDF, TXT, DOCX, JSON) using the file uploader in the sidebar
        2. **Enter API Key**: Get a free API key from [OpenRouter](https://openrouter.ai) and enter it in the sidebar
        3. **Configure Settings**: Adjust visualization and filtering options in the sidebar
        4. **Extract Graph**: Click the "Extract Knowledge Graph" button and wait for processing
        5. **Explore Results**: View the graph, statistics, and entity details in the tabs
        6. **Export**: Download the graph data in various formats
        """)

    with st.expander("Features"):
        st.markdown("""
        - **Multi-format Support**: PDF, TXT, DOCX, JSON files
        - **Batch Processing**: Process multiple documents together
        - **Smart Extraction**: Uses LLM to identify important entities and relationships
        - **Interactive Filtering**: Filter by entity type and importance
        - **Multiple Layouts**: Various graph layout algorithms
        - **Export Options**: JSON, GraphML, GEXF formats
        - **Free Models**: Uses cost-effective OpenRouter models
        """)

    with st.expander("Notes"):
        st.markdown("""
        - File size limit: 10MB per file
        - Free OpenRouter models are used to minimize costs
        - Processing time depends on document size and complexity
        """)


if __name__ == "__main__":
    main()