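"""Streamlit UI for LLM-based knowledge graph extraction.

Run with (assuming this file is saved as app.py):
    streamlit run app.py
"""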
import streamlit as st
import tempfile
import os
import traceback

# Import our modules
from src.document_processor import DocumentProcessor
from src.llm_extractor import LLMExtractor
from src.graph_builder import GraphBuilder
from src.visualizer import GraphVisualizer
from config.settings import Config

# Page config
st.set_page_config(
    page_title="Knowledge Graph Extraction",
    page_icon="🕸️",
    layout="wide"
)

# Initialize components
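# st.cache_resource creates these objects once and shares the same instances
# across reruns and user sessions, so any mutable state on them (e.g. the API
# key set below) is effectively global.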
@st.cache_resource
def initialize_components():
    config = Config()
    doc_processor = DocumentProcessor()
    llm_extractor = LLMExtractor()
    graph_builder = GraphBuilder()
    visualizer = GraphVisualizer()
    return config, doc_processor, llm_extractor, graph_builder, visualizer

config, doc_processor, llm_extractor, graph_builder, visualizer = initialize_components()

def process_uploaded_files(uploaded_files, api_key, batch_mode, layout_type, 
                          show_labels, show_edge_labels, min_importance, entity_types_filter):
    """Process uploaded files and extract knowledge graph."""
    
    try:
        # Update API key
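        # The extractor presumably builds its auth header once at init, so
        # refresh both the stored config value and the cached header here.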
        key = api_key.strip()
        if key:
            config.OPENROUTER_API_KEY = key
            llm_extractor.config.OPENROUTER_API_KEY = key
            llm_extractor.headers["Authorization"] = f"Bearer {key}"
        
        if not config.OPENROUTER_API_KEY:
            st.error("❌ OpenRouter API key is required")
            return None
        
        if not uploaded_files:
            st.error("❌ Please upload at least one file")
            return None
        
        progress_bar = st.progress(0)
        status_text = st.empty()
        
        status_text.text("Loading documents...")
        progress_bar.progress(0.1)
        
        # Save uploaded files to temporary location
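        # UploadedFile objects live in memory, but DocumentProcessor takes
        # file paths, so spill each upload to disk. The suffix preserves the
        # original filename so extension-based format detection likely still
        # works.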
        file_paths = []
        for uploaded_file in uploaded_files:
            # Create temporary file
            with tempfile.NamedTemporaryFile(delete=False, suffix=f"_{uploaded_file.name}") as tmp_file:
                tmp_file.write(uploaded_file.getvalue())
                file_paths.append(tmp_file.name)
        
        # Process documents
        doc_results = doc_processor.process_documents(file_paths, batch_mode)
        
        # Clean up temporary files
        for file_path in file_paths:
            try:
                os.unlink(file_path)
            except OSError:
                # Best-effort cleanup; a stray temp file is not fatal.
                pass
        
        # Check for errors; abort only if every file failed,
        # otherwise warn and continue with the files that parsed
        failed_files = [r for r in doc_results if r['status'] == 'error']
        if failed_files:
            error_msg = "Failed to process files:\n" + "\n".join(
                f"- {r['file_path']}: {r['error']}" for r in failed_files
            )
            if len(failed_files) == len(doc_results):
                st.error(f"❌ {error_msg}")
                return None
            st.warning(f"⚠️ {error_msg}")
        
        status_text.text("Extracting entities and relationships...")
        progress_bar.progress(0.3)
        
        # Extract entities and relationships
        all_entities = []
        all_relationships = []
        extraction_errors = []
        
        for doc_result in doc_results:
            if doc_result['status'] == 'success':
                extraction_result = llm_extractor.process_chunks(doc_result['chunks'])
                
                if extraction_result.get('errors'):
                    extraction_errors.extend(extraction_result['errors'])
                
                all_entities.extend(extraction_result.get('entities', []))
                all_relationships.extend(extraction_result.get('relationships', []))
        
        if not all_entities:
            error_msg = "No entities extracted from documents"
            if extraction_errors:
                error_msg += f"\nExtraction errors: {'; '.join(extraction_errors[:3])}"
            st.error(f"❌ {error_msg}")
            return None
        
        status_text.text("Building knowledge graph...")
        progress_bar.progress(0.6)
        
        # Build graph
        graph = graph_builder.build_graph(all_entities, all_relationships)
        
        if not graph.nodes():
            st.error("❌ No valid knowledge graph could be built")
            return None
        
        status_text.text("Applying filters...")
        progress_bar.progress(0.7)
        
        # Apply filters
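        # The entity-type filter applies the importance threshold in the same
        # call; note filter_graph is called without a graph argument, so it
        # evidently filters the graph the builder just constructed.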
        filtered_graph = graph
        if entity_types_filter:
            filtered_graph = graph_builder.filter_graph(
                entity_types=entity_types_filter,
                min_importance=min_importance
            )
        elif min_importance > 0:
            filtered_graph = graph_builder.filter_graph(min_importance=min_importance)
        
        if not filtered_graph.nodes():
            st.error("❌ No entities remain after applying filters")
            return None
        
        status_text.text("Generating visualizations...")
        progress_bar.progress(0.8)
        
        # Generate graph visualization
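        # visualize_graph evidently renders to an image file and returns its
        # path; the UI checks os.path.exists on it before display.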
        graph_image_path = visualizer.visualize_graph(
            filtered_graph,
            layout_type=layout_type,
            show_labels=show_labels,
            show_edge_labels=show_edge_labels
        )
        
        # Get statistics
        stats = graph_builder.get_graph_statistics()
        stats_summary = visualizer.create_statistics_summary(filtered_graph, stats)
        
        # Get entity list
        entity_list = visualizer.create_entity_list(filtered_graph)
        
        # Get central nodes
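        # get_central_nodes() presumably yields (node, score) pairs sorted by
        # descending centrality; format them as a Markdown ranking.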
        central_nodes = graph_builder.get_central_nodes()
        central_nodes_text = "## Most Central Entities\n\n"
        for i, (node, score) in enumerate(central_nodes, 1):
            central_nodes_text += f"{i}. **{node}** (centrality: {score:.3f})\n"
        
        status_text.text("Complete!")
        progress_bar.progress(1.0)
        
        # Success message
        success_msg = f"✅ Successfully processed {len([r for r in doc_results if r['status'] == 'success'])} document(s)"
        if failed_files:
            success_msg += f"\n⚠️ {len(failed_files)} file(s) failed to process"
        if extraction_errors:
            success_msg += f"\n⚠️ {len(extraction_errors)} extraction error(s) occurred"
        
        return {
            'success_msg': success_msg,
            'graph_image_path': graph_image_path,
            'stats_summary': stats_summary,
            'entity_list': entity_list,
            'central_nodes_text': central_nodes_text,
            'graph': filtered_graph
        }
        
    except Exception as e:
        st.error(f"❌ Error: {str(e)}")
        st.error(f"Full traceback:\n{traceback.format_exc()}")
        return None

# Main app
def main():
    st.title("🕸️ Knowledge Graph Extraction")
    st.markdown("""
    Upload documents and extract knowledge graphs using LLMs via OpenRouter.
    Supports PDF, TXT, DOCX, and JSON files.
    """)
    
    # Sidebar for configuration
    with st.sidebar:
        st.header("📁 Document Upload")
        uploaded_files = st.file_uploader(
            "Choose files",
            type=['pdf', 'txt', 'docx', 'json'],
            accept_multiple_files=True
        )
        
        batch_mode = st.checkbox(
            "Batch Processing Mode",
            value=False,
            help="Process multiple files together"
        )
        
        st.header("🔑 API Configuration")
        api_key = st.text_input(
            "OpenRouter API Key",
            type="password",
            placeholder="Enter your OpenRouter API key",
            help="Get your key at openrouter.ai"
        )
        
        st.header("🎛️ Visualization Settings")
        layout_type = st.selectbox(
            "Layout Algorithm",
            options=visualizer.get_layout_options(),
            index=0
        )
        
        show_labels = st.checkbox("Show Node Labels", value=True)
        show_edge_labels = st.checkbox("Show Edge Labels", value=False)
        
        st.header("🔍 Filtering Options")
        min_importance = st.slider(
            "Minimum Entity Importance",
            min_value=0.0,
            max_value=1.0,
            value=0.3,
            step=0.1
        )
        
        entity_types_filter = st.multiselect(
            "Entity Types Filter",
            options=[],
            help="Filter will be populated after processing"
        )
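        # NOTE: options is always empty here, so this filter can never select
        # anything; populating it would require carrying the extracted entity
        # types across runs in st.session_state.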
        
        process_button = st.button("🚀 Extract Knowledge Graph", type="primary")
    
    # Main content area
    if process_button:
        if not uploaded_files:
            st.warning("Please upload at least one file before processing.")
        else:
            with st.spinner("Processing..."):
                result = process_uploaded_files(
                    uploaded_files, api_key, batch_mode, layout_type,
                    show_labels, show_edge_labels, min_importance, entity_types_filter
                )
            if result:
                # Store results in session state and show the summary
                st.session_state['result'] = result
                st.success(result['success_msg'])

    # Render from session state rather than only on the processing run:
    # clicking any widget below (e.g. the Export button) triggers a rerun
    # with process_button == False, which would otherwise discard the
    # results that were just computed.
    result = st.session_state.get('result')
    if result:
        # Create tabs for results
        tab1, tab2, tab3, tab4 = st.tabs(["📈 Graph Visualization", "📋 Statistics", "📝 Entities", "🎯 Central Nodes"])

        with tab1:
            if result['graph_image_path'] and os.path.exists(result['graph_image_path']):
                # use_container_width replaces the deprecated use_column_width
                st.image(result['graph_image_path'], caption="Knowledge Graph", use_container_width=True)
            else:
                st.error("Failed to generate graph visualization")

        with tab2:
            st.markdown(result['stats_summary'])

        with tab3:
            st.markdown(result['entity_list'])

        with tab4:
            st.markdown(result['central_nodes_text'])

        # Export options. graph_builder is shared via st.cache_resource and
        # still holds the last built graph, so export works across reruns.
        st.header("💾 Export Options")
        col1, col2 = st.columns(2)

        with col1:
            export_format = st.selectbox(
                "Export Format",
                options=["json", "graphml", "gexf"],
                index=0
            )

        with col2:
            if st.button("📥 Export Graph"):
                try:
                    export_data = graph_builder.export_graph(export_format)
                    st.text_area("Export Data", value=export_data, height=300)

                    # Download button
                    st.download_button(
                        label=f"Download {export_format.upper()} file",
                        data=export_data,
                        file_name=f"knowledge_graph.{export_format}",
                        mime="application/octet-stream"
                    )
                except Exception as e:
                    st.error(f"Export failed: {e}")
    
    # Instructions
    st.header("📚 Instructions")
    
    with st.expander("How to use this app"):
        st.markdown("""
        1. **Upload Documents**: Select one or more files (PDF, TXT, DOCX, JSON) using the file uploader in the sidebar
        2. **Enter API Key**: Get a free API key from [OpenRouter](https://openrouter.ai) and enter it in the sidebar
        3. **Configure Settings**: Adjust visualization and filtering options in the sidebar
        4. **Extract Graph**: Click the "Extract Knowledge Graph" button and wait for processing
        5. **Explore Results**: View the graph, statistics, and entity details in the tabs
        6. **Export**: Download the graph data in various formats
        """)
    
    with st.expander("Features"):
        st.markdown("""
        - **Multi-format Support**: PDF, TXT, DOCX, JSON files
        - **Batch Processing**: Process multiple documents together
        - **Smart Extraction**: Uses LLM to identify important entities and relationships
        - **Interactive Filtering**: Filter by entity type and importance
        - **Multiple Layouts**: Various graph layout algorithms
        - **Export Options**: JSON, GraphML, GEXF formats
        - **Free Models**: Uses cost-effective OpenRouter models
        """)
    
    with st.expander("Notes"):
        st.markdown("""
        - File size limit: 10MB per file
        - Free OpenRouter models are used to minimize costs
        - Processing time depends on document size and complexity
        """)

if __name__ == "__main__":
    main()