|
import streamlit as st |
|
import tempfile |
|
import os |
|
import json |
|
from typing import List, Dict, Any, Optional, Tuple |
|
import traceback |
|
|
|
|
|
from src.document_processor import DocumentProcessor |
|
from src.llm_extractor import LLMExtractor |
|
from src.graph_builder import GraphBuilder |
|
from src.visualizer import GraphVisualizer |
|
from config.settings import Config |
|
|
|
|
|
# Configure the Streamlit page. This must be the first Streamlit command
# executed in the script, before any other st.* call.
st.set_page_config(
    page_title="Knowledge Graph Extraction",
    page_icon="πΈοΈ",
    layout="wide"
)
|
|
|
|
|
@st.cache_resource
def initialize_components():
    """Construct and cache the app's long-lived service objects.

    Returns:
        Tuple of (Config, DocumentProcessor, LLMExtractor, GraphBuilder,
        GraphVisualizer). ``st.cache_resource`` ensures each is built only
        once per Streamlit server process and shared across reruns.
    """
    return (
        Config(),
        DocumentProcessor(),
        LLMExtractor(),
        GraphBuilder(),
        GraphVisualizer(),
    )
|
|
|
config, doc_processor, llm_extractor, graph_builder, visualizer = initialize_components() |
|
|
|
def process_uploaded_files(uploaded_files, api_key, batch_mode, visualization_type, layout_type,
                           show_labels, show_edge_labels, min_importance, entity_types_filter):
    """Run the full pipeline: save uploads, extract, build, filter, visualize.

    Args:
        uploaded_files: Streamlit UploadedFile objects to process.
        api_key: OpenRouter API key from the UI; overrides config if non-empty.
        batch_mode: Passed through to the document processor.
        visualization_type: "plotly", "pyvis", "vis.js", or anything else for
            the static matplotlib fallback.
        layout_type: Layout algorithm name understood by the visualizer.
        show_labels: Show node labels in the static visualization.
        show_edge_labels: Show edge labels in the static visualization.
        min_importance: Entities below this importance are filtered out.
        entity_types_filter: Entity types to keep (falsy = no type filter).

    Returns:
        dict with keys 'success_msg', 'graph_image_path', 'graph_viz',
        'visualization_type', 'stats_summary', 'entity_list',
        'central_nodes_text', 'graph'; or None on failure (errors are
        surfaced to the UI via st.error).
    """
    try:
        # A key typed in the UI overrides the configured one everywhere it
        # is cached (config object, extractor config, extractor HTTP header).
        if api_key.strip():
            key = api_key.strip()
            config.OPENROUTER_API_KEY = key
            llm_extractor.config.OPENROUTER_API_KEY = key
            llm_extractor.headers["Authorization"] = f"Bearer {key}"

        if not config.OPENROUTER_API_KEY:
            st.error("β OpenRouter API key is required")
            return None

        if not uploaded_files:
            st.error("β Please upload at least one file")
            return None

        progress_bar = st.progress(0)
        status_text = st.empty()

        status_text.text("Loading documents...")
        progress_bar.progress(0.1)

        # The processor works on filesystem paths, so persist the in-memory
        # uploads to temporary files first.
        file_paths = _save_uploads_to_tempfiles(uploaded_files)

        doc_results = doc_processor.process_documents(file_paths, batch_mode)

        # Temp files are no longer needed once processed.
        _cleanup_temp_files(file_paths)

        failed_files = [r for r in doc_results if r['status'] == 'error']
        if failed_files:
            error_msg = "Failed to process files:\n" + "\n".join(
                [f"- {r['file_path']}: {r['error']}" for r in failed_files]
            )
            # Only abort when *every* document failed; partial failures are
            # reported in the final success message instead.
            if len(failed_files) == len(doc_results):
                st.error(f"β {error_msg}")
                return None

        status_text.text("Extracting entities and relationships...")
        progress_bar.progress(0.3)

        all_entities = []
        all_relationships = []
        extraction_errors = []

        for doc_result in doc_results:
            if doc_result['status'] == 'success':
                extraction_result = llm_extractor.process_chunks(doc_result['chunks'])

                if extraction_result.get('errors'):
                    extraction_errors.extend(extraction_result['errors'])

                all_entities.extend(extraction_result.get('entities', []))
                all_relationships.extend(extraction_result.get('relationships', []))

        if not all_entities:
            error_msg = "No entities extracted from documents"
            if extraction_errors:
                # Cap at three errors to keep the UI message readable.
                error_msg += f"\nExtraction errors: {'; '.join(extraction_errors[:3])}"
            st.error(f"β {error_msg}")
            return None

        status_text.text("Building knowledge graph...")
        progress_bar.progress(0.6)

        graph = graph_builder.build_graph(all_entities, all_relationships)

        if not graph.nodes():
            st.error("β No valid knowledge graph could be built")
            return None

        status_text.text("Applying filters...")
        progress_bar.progress(0.7)

        filtered_graph = graph
        if entity_types_filter:
            filtered_graph = graph_builder.filter_graph(
                entity_types=entity_types_filter,
                min_importance=min_importance
            )
        elif min_importance > 0:
            filtered_graph = graph_builder.filter_graph(min_importance=min_importance)

        if not filtered_graph.nodes():
            st.error("β No entities remain after applying filters")
            return None

        status_text.text("Generating visualizations...")
        progress_bar.progress(0.8)

        # Each backend yields either an in-memory object (graph_viz) or a
        # file on disk (graph_image_path), never both.
        if visualization_type == "plotly":
            graph_viz = visualizer.create_plotly_interactive(filtered_graph, layout_type)
            graph_image_path = None
        elif visualization_type == "pyvis":
            graph_image_path = visualizer.create_pyvis_interactive(filtered_graph, layout_type)
            graph_viz = None
        elif visualization_type == "vis.js":
            graph_viz = visualizer.create_interactive_html(filtered_graph)
            graph_image_path = None
        else:
            # Fallback: static matplotlib image written to disk.
            graph_image_path = visualizer.visualize_graph(
                filtered_graph,
                layout_type=layout_type,
                show_labels=show_labels,
                show_edge_labels=show_edge_labels
            )
            graph_viz = None

        stats = graph_builder.get_graph_statistics()
        stats_summary = visualizer.create_statistics_summary(filtered_graph, stats)

        entity_list = visualizer.create_entity_list(filtered_graph)

        central_nodes = graph_builder.get_central_nodes()
        central_nodes_text = "## Most Central Entities\n\n"
        for i, (node, score) in enumerate(central_nodes, 1):
            central_nodes_text += f"{i}. **{node}** (centrality: {score:.3f})\n"

        status_text.text("Complete!")
        progress_bar.progress(1.0)

        # Fix: build the summary as a single f-string literal; the original
        # split the literal across a raw newline, which is a syntax error.
        success_count = len([r for r in doc_results if r['status'] == 'success'])
        success_msg = f"β Successfully processed {success_count} document(s)"
        if failed_files:
            success_msg += f"\nβ οΈ {len(failed_files)} file(s) failed to process"
        if extraction_errors:
            success_msg += f"\nβ οΈ {len(extraction_errors)} extraction error(s) occurred"

        return {
            'success_msg': success_msg,
            'graph_image_path': graph_image_path,
            'graph_viz': graph_viz,
            'visualization_type': visualization_type,
            'stats_summary': stats_summary,
            'entity_list': entity_list,
            'central_nodes_text': central_nodes_text,
            'graph': filtered_graph
        }

    except Exception as e:
        # Top-level boundary: surface the error and full traceback in the UI.
        st.error(f"β Error: {str(e)}")
        st.error(f"Full traceback:\n{traceback.format_exc()}")
        return None


def _save_uploads_to_tempfiles(uploaded_files):
    """Write each upload to a delete=False temp file; return the paths."""
    file_paths = []
    for uploaded_file in uploaded_files:
        # Keep the original filename as a suffix so downstream
        # extension-based format detection still works.
        with tempfile.NamedTemporaryFile(delete=False, suffix=f"_{uploaded_file.name}") as tmp_file:
            tmp_file.write(uploaded_file.getvalue())
            file_paths.append(tmp_file.name)
    return file_paths


def _cleanup_temp_files(file_paths):
    """Best-effort deletion of temporary files.

    Fix: catches OSError specifically instead of the original bare
    ``except:``, which also swallowed KeyboardInterrupt/SystemExit.
    """
    for file_path in file_paths:
        try:
            os.unlink(file_path)
        except OSError:
            pass
|
|
|
|
|
def main():
    """Render the Streamlit page: sidebar inputs, processing flow, results.

    Streamlit executes this top-to-bottom on every rerun; widget order here
    is the on-screen order, so statements must not be reordered.
    """
    st.title("πΈοΈ Knowledge Graph Extraction")
    st.markdown("""
    Upload documents and extract knowledge graphs using LLMs via OpenRouter.
    Supports PDF, TXT, DOCX, and JSON files.
    """)

    # ---- Sidebar: all user inputs ------------------------------------
    with st.sidebar:
        st.header("π Document Upload")
        uploaded_files = st.file_uploader(
            "Choose files",
            type=['pdf', 'txt', 'docx', 'json'],
            accept_multiple_files=True
        )

        batch_mode = st.checkbox(
            "Batch Processing Mode",
            value=False,
            help="Process multiple files together"
        )

        st.header("π API Configuration")
        api_key = st.text_input(
            "OpenRouter API Key",
            type="password",
            placeholder="Enter your OpenRouter API key",
            help="Get your key at openrouter.ai"
        )

        st.header("ποΈ Visualization Settings")
        visualization_type = st.selectbox(
            "Visualization Type",
            options=visualizer.get_visualization_options(),
            index=1,
            help="Choose visualization method"
        )

        layout_type = st.selectbox(
            "Layout Algorithm",
            options=visualizer.get_layout_options(),
            index=0
        )

        show_labels = st.checkbox("Show Node Labels", value=True)
        show_edge_labels = st.checkbox("Show Edge Labels", value=False)

        st.header("π Filtering Options")
        min_importance = st.slider(
            "Minimum Entity Importance",
            min_value=0.0,
            max_value=1.0,
            value=0.3,
            step=0.1
        )

        # NOTE(review): options is always [] here, so this multiselect can
        # never be populated despite the help text — confirm intent.
        entity_types_filter = st.multiselect(
            "Entity Types Filter",
            options=[],
            help="Filter will be populated after processing"
        )

        process_button = st.button("π Extract Knowledge Graph", type="primary")

    # ---- Processing & results ----------------------------------------
    if process_button and uploaded_files:
        with st.spinner("Processing..."):
            result = process_uploaded_files(
                uploaded_files, api_key, batch_mode, visualization_type, layout_type,
                show_labels, show_edge_labels, min_importance, entity_types_filter
            )

        if result:
            # Stash the result; only read back on later reruns if needed.
            st.session_state['result'] = result

            st.success(result['success_msg'])

            tab1, tab2, tab3, tab4 = st.tabs(["π Graph Visualization", "π Statistics", "π Entities", "π― Central Nodes"])

            with tab1:
                viz_type = result['visualization_type']

                if viz_type == "plotly" and result['graph_viz']:
                    st.plotly_chart(result['graph_viz'], use_container_width=True)
                    st.info("π― Interactive Plotly graph: Hover over nodes for details, drag to pan, scroll to zoom")

                elif viz_type == "pyvis" and result['graph_image_path'] and os.path.exists(result['graph_image_path']):
                    # Pyvis writes an HTML file; embed its contents inline.
                    with open(result['graph_image_path'], 'r', encoding='utf-8') as f:
                        html_content = f.read()
                    st.components.v1.html(html_content, height=600, scrolling=True)
                    st.info("π― Interactive Pyvis graph: Drag nodes to rearrange, hover for details")

                elif viz_type == "vis.js" and result['graph_viz']:
                    # graph_viz is an HTML string for the vis.js backend.
                    st.components.v1.html(result['graph_viz'], height=600, scrolling=True)
                    st.info("π― Interactive vis.js graph: Drag nodes, hover for details, use physics simulation")

                elif viz_type == "matplotlib" and result['graph_image_path'] and os.path.exists(result['graph_image_path']):
                    st.image(result['graph_image_path'], caption="Knowledge Graph", use_column_width=True)
                    st.info("π Static matplotlib visualization")

                else:
                    st.error("Failed to generate graph visualization")

            with tab2:
                st.markdown(result['stats_summary'])

            with tab3:
                st.markdown(result['entity_list'])

            with tab4:
                st.markdown(result['central_nodes_text'])

            # Export section only appears after a successful extraction.
            st.header("πΎ Export Options")
            col1, col2 = st.columns(2)

            with col1:
                export_format = st.selectbox(
                    "Export Format",
                    options=["json", "graphml", "gexf"],
                    index=0
                )

            with col2:
                if st.button("π₯ Export Graph"):
                    try:
                        export_data = graph_builder.export_graph(export_format)
                        st.text_area("Export Data", value=export_data, height=300)

                        st.download_button(
                            label=f"Download {export_format.upper()} file",
                            data=export_data,
                            file_name=f"knowledge_graph.{export_format}",
                            mime="application/octet-stream"
                        )
                    except Exception as e:
                        st.error(f"Export failed: {str(e)}")

    elif process_button and not uploaded_files:
        st.warning("Please upload at least one file before processing.")

    # ---- Static help content shown on every run ----------------------
    st.header("π Instructions")

    with st.expander("How to use this app"):
        st.markdown("""
        1. **Upload Documents**: Select one or more files (PDF, TXT, DOCX, JSON) using the file uploader in the sidebar
        2. **Enter API Key**: Get a free API key from [OpenRouter](https://openrouter.ai) and enter it in the sidebar
        3. **Configure Settings**: Adjust visualization and filtering options in the sidebar
        4. **Extract Graph**: Click the "Extract Knowledge Graph" button and wait for processing
        5. **Explore Results**: View the graph, statistics, and entity details in the tabs
        6. **Export**: Download the graph data in various formats
        """)

    with st.expander("Features"):
        st.markdown("""
        - **Multi-format Support**: PDF, TXT, DOCX, JSON files
        - **Batch Processing**: Process multiple documents together
        - **Smart Extraction**: Uses LLM to identify important entities and relationships
        - **Interactive Filtering**: Filter by entity type and importance
        - **Multiple Layouts**: Various graph layout algorithms
        - **Export Options**: JSON, GraphML, GEXF formats
        - **Free Models**: Uses cost-effective OpenRouter models
        """)

    with st.expander("Notes"):
        st.markdown("""
        - File size limit: 10MB per file
        - Free OpenRouter models are used to minimize costs
        - Processing time depends on document size and complexity
        """)
|
|
|
# Standard script entry point: run the UI only when executed directly.
if __name__ == "__main__":
    main()
|
|