import streamlit as st
import tempfile
import os
import json
from typing import List, Dict, Any, Optional, Tuple
import traceback
# Import our modules
from src.document_processor import DocumentProcessor
from src.llm_extractor import LLMExtractor
from src.graph_builder import GraphBuilder
from src.visualizer import GraphVisualizer
from config.settings import Config
# Page config
st.set_page_config(
page_title="Knowledge Graph Extraction",
page_icon="πŸ•ΈοΈ",
layout="wide"
)
# Initialize components
@st.cache_resource
def initialize_components():
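    """Create and cache the pipeline components once per Streamlit session."""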
config = Config()
doc_processor = DocumentProcessor()
llm_extractor = LLMExtractor()
graph_builder = GraphBuilder()
visualizer = GraphVisualizer()
return config, doc_processor, llm_extractor, graph_builder, visualizer
config, doc_processor, llm_extractor, graph_builder, visualizer = initialize_components()
def process_uploaded_files(uploaded_files, api_key, batch_mode, layout_type,
show_labels, show_edge_labels, min_importance, entity_types_filter):
"""Process uploaded files and extract knowledge graph."""
try:
# Update API key
if api_key.strip():
config.OPENROUTER_API_KEY = api_key.strip()
llm_extractor.config.OPENROUTER_API_KEY = api_key.strip()
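            # The extractor pre-builds its Authorization header, so it has to be
            # refreshed alongside the config whenever the key changes.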
llm_extractor.headers["Authorization"] = f"Bearer {api_key.strip()}"
if not config.OPENROUTER_API_KEY:
st.error("❌ OpenRouter API key is required")
return None
if not uploaded_files:
st.error("❌ Please upload at least one file")
return None
progress_bar = st.progress(0)
status_text = st.empty()
status_text.text("Loading documents...")
progress_bar.progress(0.1)
# Save uploaded files to temporary location
file_paths = []
for uploaded_file in uploaded_files:
# Create temporary file
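            # Using the original filename as the suffix preserves the extension,
            # so the document processor can still detect the file type.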
with tempfile.NamedTemporaryFile(delete=False, suffix=f"_{uploaded_file.name}") as tmp_file:
tmp_file.write(uploaded_file.getvalue())
file_paths.append(tmp_file.name)
# Process documents
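        # Each result is expected to be a dict with a 'status' key plus 'chunks'
        # on success, or 'file_path'/'error' on failure (inferred from the usage below).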
doc_results = doc_processor.process_documents(file_paths, batch_mode)
# Clean up temporary files
for file_path in file_paths:
try:
os.unlink(file_path)
            except OSError:
pass
# Check for errors
failed_files = [r for r in doc_results if r['status'] == 'error']
if failed_files:
error_msg = "Failed to process files:\n" + "\n".join([f"- {r['file_path']}: {r['error']}" for r in failed_files])
if len(failed_files) == len(doc_results):
st.error(f"❌ {error_msg}")
return None
status_text.text("Extracting entities and relationships...")
progress_bar.progress(0.3)
# Extract entities and relationships
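        # process_chunks is assumed to return a dict with 'entities',
        # 'relationships', and an optional 'errors' list (inferred from the lookups below).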
all_entities = []
all_relationships = []
extraction_errors = []
for doc_result in doc_results:
if doc_result['status'] == 'success':
extraction_result = llm_extractor.process_chunks(doc_result['chunks'])
if extraction_result.get('errors'):
extraction_errors.extend(extraction_result['errors'])
all_entities.extend(extraction_result.get('entities', []))
all_relationships.extend(extraction_result.get('relationships', []))
if not all_entities:
error_msg = "No entities extracted from documents"
if extraction_errors:
error_msg += f"\nExtraction errors: {'; '.join(extraction_errors[:3])}"
st.error(f"❌ {error_msg}")
return None
status_text.text("Building knowledge graph...")
progress_bar.progress(0.6)
# Build graph
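        # build_graph appears to return a NetworkX-style graph: it exposes
        # .nodes() here and is exported to GraphML/GEXF further down.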
graph = graph_builder.build_graph(all_entities, all_relationships)
if not graph.nodes():
st.error("❌ No valid knowledge graph could be built")
return None
status_text.text("Applying filters...")
progress_bar.progress(0.7)
# Apply filters
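        # filter_graph takes no graph argument, so it presumably filters the
        # graph stored by build_graph() above and returns the filtered copy.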
filtered_graph = graph
if entity_types_filter:
filtered_graph = graph_builder.filter_graph(
entity_types=entity_types_filter,
min_importance=min_importance
)
elif min_importance > 0:
filtered_graph = graph_builder.filter_graph(min_importance=min_importance)
if not filtered_graph.nodes():
st.error("❌ No entities remain after applying filters")
return None
status_text.text("Generating visualizations...")
progress_bar.progress(0.8)
# Generate graph visualization
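        # visualize_graph is expected to render the graph to an image file and
        # return its path; the path is existence-checked before display.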
graph_image_path = visualizer.visualize_graph(
filtered_graph,
layout_type=layout_type,
show_labels=show_labels,
show_edge_labels=show_edge_labels
)
# Get statistics
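        # get_graph_statistics() takes no argument, so the stats likely describe
        # the builder's full graph rather than the filtered view summarized below.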
stats = graph_builder.get_graph_statistics()
stats_summary = visualizer.create_statistics_summary(filtered_graph, stats)
# Get entity list
entity_list = visualizer.create_entity_list(filtered_graph)
# Get central nodes
central_nodes = graph_builder.get_central_nodes()
central_nodes_text = "## Most Central Entities\n\n"
for i, (node, score) in enumerate(central_nodes, 1):
central_nodes_text += f"{i}. **{node}** (centrality: {score:.3f})\n"
status_text.text("Complete!")
progress_bar.progress(1.0)
# Success message
success_msg = f"βœ… Successfully processed {len([r for r in doc_results if r['status'] == 'success'])} document(s)"
if failed_files:
success_msg += f"\n⚠️ {len(failed_files)} file(s) failed to process"
if extraction_errors:
success_msg += f"\n⚠️ {len(extraction_errors)} extraction error(s) occurred"
return {
'success_msg': success_msg,
'graph_image_path': graph_image_path,
'stats_summary': stats_summary,
'entity_list': entity_list,
'central_nodes_text': central_nodes_text,
'graph': filtered_graph
}
except Exception as e:
st.error(f"❌ Error: {str(e)}")
st.error(f"Full traceback:\n{traceback.format_exc()}")
return None
# Main app
def main():
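    """Render the Streamlit UI and drive the extraction workflow."""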
st.title("πŸ•ΈοΈ Knowledge Graph Extraction")
st.markdown("""
Upload documents and extract knowledge graphs using LLMs via OpenRouter.
Supports PDF, TXT, DOCX, and JSON files.
""")
# Sidebar for configuration
with st.sidebar:
st.header("πŸ“ Document Upload")
uploaded_files = st.file_uploader(
"Choose files",
type=['pdf', 'txt', 'docx', 'json'],
accept_multiple_files=True
)
batch_mode = st.checkbox(
"Batch Processing Mode",
value=False,
help="Process multiple files together"
)
st.header("πŸ”‘ API Configuration")
api_key = st.text_input(
"OpenRouter API Key",
type="password",
placeholder="Enter your OpenRouter API key",
help="Get your key at openrouter.ai"
)
st.header("πŸŽ›οΈ Visualization Settings")
layout_type = st.selectbox(
"Layout Algorithm",
options=visualizer.get_layout_options(),
index=0
)
show_labels = st.checkbox("Show Node Labels", value=True)
show_edge_labels = st.checkbox("Show Edge Labels", value=False)
st.header("πŸ” Filtering Options")
min_importance = st.slider(
"Minimum Entity Importance",
min_value=0.0,
max_value=1.0,
value=0.3,
step=0.1
)
entity_types_filter = st.multiselect(
"Entity Types Filter",
options=[],
help="Filter will be populated after processing"
)
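        # NOTE: options start empty; the available entity types are only known
        # after a first extraction, so this filter has no effect until then.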
process_button = st.button("πŸš€ Extract Knowledge Graph", type="primary")
# Main content area
if process_button and uploaded_files:
with st.spinner("Processing..."):
result = process_uploaded_files(
uploaded_files, api_key, batch_mode, layout_type,
show_labels, show_edge_labels, min_importance, entity_types_filter
)
if result:
# Store results in session state
st.session_state['result'] = result
# Display success message
st.success(result['success_msg'])
# Create tabs for results
tab1, tab2, tab3, tab4 = st.tabs(["πŸ“ˆ Graph Visualization", "πŸ“‹ Statistics", "πŸ“ Entities", "🎯 Central Nodes"])
with tab1:
if result['graph_image_path'] and os.path.exists(result['graph_image_path']):
                    st.image(result['graph_image_path'], caption="Knowledge Graph", use_container_width=True)
else:
st.error("Failed to generate graph visualization")
with tab2:
st.markdown(result['stats_summary'])
with tab3:
st.markdown(result['entity_list'])
with tab4:
st.markdown(result['central_nodes_text'])
# Export options
st.header("πŸ’Ύ Export Options")
col1, col2 = st.columns(2)
with col1:
export_format = st.selectbox(
"Export Format",
options=["json", "graphml", "gexf"],
index=0
)
with col2:
if st.button("πŸ“₯ Export Graph"):
try:
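                        # export_graph serializes the graph held by the builder,
                        # which may be the unfiltered graph rather than the filtered view.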
export_data = graph_builder.export_graph(export_format)
st.text_area("Export Data", value=export_data, height=300)
# Download button
st.download_button(
label=f"Download {export_format.upper()} file",
data=export_data,
file_name=f"knowledge_graph.{export_format}",
mime="application/octet-stream"
)
except Exception as e:
st.error(f"Export failed: {str(e)}")
elif process_button and not uploaded_files:
st.warning("Please upload at least one file before processing.")
# Instructions
st.header("πŸ“š Instructions")
with st.expander("How to use this app"):
st.markdown("""
1. **Upload Documents**: Select one or more files (PDF, TXT, DOCX, JSON) using the file uploader in the sidebar
2. **Enter API Key**: Get a free API key from [OpenRouter](https://openrouter.ai) and enter it in the sidebar
3. **Configure Settings**: Adjust visualization and filtering options in the sidebar
4. **Extract Graph**: Click the "Extract Knowledge Graph" button and wait for processing
5. **Explore Results**: View the graph, statistics, and entity details in the tabs
6. **Export**: Download the graph data in various formats
""")
with st.expander("Features"):
st.markdown("""
- **Multi-format Support**: PDF, TXT, DOCX, JSON files
- **Batch Processing**: Process multiple documents together
- **Smart Extraction**: Uses LLM to identify important entities and relationships
- **Interactive Filtering**: Filter by entity type and importance
- **Multiple Layouts**: Various graph layout algorithms
- **Export Options**: JSON, GraphML, GEXF formats
- **Free Models**: Uses cost-effective OpenRouter models
""")
with st.expander("Notes"):
st.markdown("""
- File size limit: 10MB per file
- Free OpenRouter models are used to minimize costs
- Processing time depends on document size and complexity
""")
if __name__ == "__main__":
main()