# Commit 649323e by blazingbunny: "Update app.py" (verified)
# ── Streamlit must write to /tmp on Spaces ──────────────────────────────────────
# These variables are set ahead of the main import block (and therefore before
# `import streamlit`), presumably so Streamlit picks them up at import time.
import os as _os

_os.environ.update(
    {
        "STREAMLIT_CONFIG_DIR": "/tmp",
        "STREAMLIT_CACHE_DIR": "/tmp",
        "STREAMLIT_CACHE_STORAGE": "filesystem",
    }
)
# ── Imports ────────────────────────────────────────────────────────────────────
import os
import io
import json
import streamlit as st
import pandas as pd
from collections import Counter
from pypdf import PdfReader
from pyvis.network import Network
from knowledge_graph_maker import (
GraphMaker, Ontology, Document, OpenAIClient
)
# ── Page setup ──────────────────────────────────────────────────────────────────
# set_page_config is the first Streamlit call in the script.
st.set_page_config(page_title="Knowledge Graph (OpenRouter)", layout="wide")
# The title previously contained a mis-encoded em dash ("β€”").
st.title("Knowledge Graph from Text/PDF — OpenRouter")
st.caption(
    "Builds a knowledge graph with knowledge-graph-maker via OpenRouter. "
    "Pick a model, choose presets, and render via PyVis or Cytoscape.js."
)
# ── Secrets / env ───────────────────────────────────────────────────────────────
# Read the key from the environment (a Space secret on Hugging Face). Surrounding
# whitespace, such as a trailing newline from pasting, is stripped so it is not
# forwarded verbatim as part of the API key.
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY", "").strip()

# Preset OpenRouter models offered in the sidebar (extend as needed).
OPENROUTER_MODELS = [
    "openai/gpt-oss-20b:free",
    "moonshotai/kimi-k2:free",
    "google/gemini-2.0-flash-exp:free",
    "google/gemma-3-27b-it:free",
]
# ---- Preset defaults in session state ----
# Seed the slider-backed keys only once; later reruns keep the user's values.
_SESSION_DEFAULTS = {"temperature": 0.1, "top_p": 0.9}
for _key, _value in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _value
# ── Sidebar controls ───────────────────────────────────────────────────────────
with st.sidebar:
    st.subheader("Model & Generation Settings")
    model_choice = st.selectbox("OpenRouter model", OPENROUTER_MODELS, index=0)
    custom_model = st.text_input(
        "Custom model id (optional)",
        placeholder="e.g. meta-llama/llama-3.1-8b-instruct",
    )

    st.markdown("### Preset")
    PRESETS = {
        "Extractive (stable)": {"temperature": 0.1, "top_p": 0.9, "desc": "Most deterministic; best for IE"},
        "Balanced": {"temperature": 0.2, "top_p": 0.9, "desc": "Slightly more recall"},
        "Exploratory": {"temperature": 0.4, "top_p": 0.95, "desc": "More ideas; may add noise"},
    }
    preset_names = list(PRESETS.keys())
    # The tooltip is fixed before a choice is made, so it describes every
    # preset; the caption below then shows the currently selected one.
    preset = st.selectbox(
        "Choose a preset",
        preset_names,
        index=0,
        help="\n\n".join(f"**{name}**: {cfg['desc']}" for name, cfg in PRESETS.items()),
    )
    st.caption(PRESETS[preset]["desc"])
    # The sliders keyed "temperature"/"top_p" are created further down in this
    # same run, so their session-state values can still be overwritten here.
    if st.button("Apply preset"):
        st.session_state.temperature = PRESETS[preset]["temperature"]
        st.session_state.top_p = PRESETS[preset]["top_p"]
        # Must be a real emoji: st.toast validates its icon argument.
        st.toast(f"Applied: {preset}", icon="✅")
    temperature = st.slider(
        "Temperature", 0.0, 1.0, key="temperature", step=0.05,
        help="Lower = more deterministic; higher = more variety",
    )
    top_p = st.slider(
        "Top-p", 0.0, 1.0, key="top_p", step=0.05,
        help="Nucleus sampling threshold; 0.9 is a good default",
    )

    st.markdown("### Ontology (labels)")
    labels_text = st.text_area(
        "Comma-separated labels",
        value="Person, Object, Event, Place, Document, Organisation, Action, Miscellaneous",
        height=70,
    )
    relationships_text = st.text_input(
        "Relationships (comma-separated)",
        value="Relation between any pair of Entities",
    )

    st.markdown("### Visualization")
    renderer = st.radio("Renderer", ["PyVis (interactive)", "Cytoscape.js (beta)"], index=0)
    label_mode = st.radio("Edge labels", ["Always visible", "Tooltip only"], index=0)
    show_legend = st.checkbox("Show color legend", value=True)
# ── Helpers ────────────────────────────────────────────────────────────────────
def parse_labels(text: str):
    """Split a comma-separated string into a list of ontology labels.

    Each entry is stripped and blank entries are dropped. When nothing usable
    remains, a default label set is returned so the ontology is never empty.
    """
    default_labels = [
        "Person", "Object", "Event", "Place", "Document", "Organisation", "Action", "Miscellaneous",
    ]
    labels = [lbl.strip() for lbl in text.split(",") if lbl.strip()]
    return labels or default_labels
def pdf_to_text(file) -> str:
    """Extract plain text from every page of a PDF.

    Text extraction is best effort. A page whose extraction raises is skipped,
    and a page that yields no text contributes an empty string. The per-page
    strings are joined with newlines.
    """
    page_texts = []
    for pdf_page in PdfReader(file).pages:
        try:
            extracted = pdf_page.extract_text()
        except Exception:
            # One unreadable page should not abort the whole document.
            continue
        page_texts.append(extracted if extracted else "")
    return "\n".join(page_texts)
def chunk_text(text: str, chars: int = 3500) -> list[Document]:
    """Cut *text* into consecutive fixed-width windows of ``chars`` characters.

    Each window is stripped, and windows that are empty after stripping are
    dropped. ``metadata["chunk_id"]`` is the window's position (start // chars),
    so ids can have gaps where blank windows were skipped.
    """
    windows = (
        (start // chars, text[start:start + chars].strip())
        for start in range(0, len(text), chars)
    )
    return [
        Document(text=body, metadata={"chunk_id": chunk_id})
        for chunk_id, body in windows
        if body
    ]
def edges_to_rdf(edges):
    """Flatten knowledge-graph-maker edges into RDF-like triple dicts.

    Each result is ``{"subject": ..., "predicate": ..., "object": ...}`` with
    whitespace stripped. A missing or blank relationship becomes "related_to",
    and edges whose subject or object name is blank/None are skipped.
    """
    def _clean(value):
        return (value or "").strip()

    triples = []
    for edge in edges:
        subject = _clean(edge.node_1.name)
        predicate = _clean(edge.relationship) or "related_to"
        obj = _clean(edge.node_2.name)
        if not (subject and obj):
            continue
        triples.append({"subject": subject, "predicate": predicate, "object": obj})
    return triples
from collections import Counter
def count_relation_frequency(triples):
    """Count how often each triple and each predicate occurs.

    *triples* is any iterable of dicts with "subject", "predicate" and "object"
    keys. It is consumed in a single pass, so generators work too.

    Returns a ``(freq_triplet, freq_predicate)`` pair of Counters. The first is
    keyed by ``(subject, predicate, object)`` tuples, the second by predicate.
    """
    freq_triplet, freq_predicate = Counter(), Counter()
    for triple in triples:
        predicate = triple["predicate"]
        freq_triplet[(triple["subject"], predicate, triple["object"])] += 1
        freq_predicate[predicate] += 1
    return freq_triplet, freq_predicate
# Color bins for predicate frequency: (minimum count, edge color, legend label).
# Ordered from the highest threshold down; color_for_predicate scans them in
# this order and the first bin whose minimum is met wins.
COLOR_BINS = [
    (8, "#2F3B52", "freq ≥ 8"),
    (5, "#4E6E9E", "5–7"),
    (3, "#7FA6F8", "3–4"),
    (1, "#BFD3FF", "1–2"),
]


def color_for_predicate(p, freq_pred):
    """Return the edge color for predicate *p* based on how often it occurs.

    *freq_pred* maps predicate -> count (a Counter or a plain dict). Thresholds
    and colors come from COLOR_BINS. A predicate missing from *freq_pred*
    counts as 0 and gets the lightest color, the same as a count of 1–2.
    """
    freq = freq_pred.get(p, 0)
    for min_count, color, _label in COLOR_BINS:
        if freq >= min_count:
            return color
    # Below the lowest threshold (e.g. count 0): use the lightest bin's color.
    return COLOR_BINS[-1][1]
def render_color_legend(freq_pred):
    """Render a four-column legend mapping edge colors to predicate-frequency bins.

    Each entry shows a color swatch, the frequency range it stands for and, in
    parentheses, how many distinct predicates fall into that range. Nothing is
    rendered when *freq_pred* is empty.
    """
    if not freq_pred:
        return
    # Bin predicates with the same function that colors the edges, so the
    # legend can never disagree with the graph.
    per_color = Counter(color_for_predicate(p, freq_pred) for p in freq_pred)
    legend_bins = [
        ("#2F3B52", "≥8"),
        ("#4E6E9E", "5–7"),
        ("#7FA6F8", "3–4"),
        ("#BFD3FF", "1–2"),
    ]
    st.markdown("#### Legend (predicate frequency → edge color)")
    for (color, label), col in zip(legend_bins, st.columns(len(legend_bins))):
        # Single-line HTML so markdown never treats indentation as a code block.
        swatch_html = (
            '<div style="display:flex;align-items:center;gap:8px;">'
            f'<div style="width:18px;height:12px;background:{color};border:1px solid #999;"></div>'
            f'<div><b>{label}</b> <span style="color:#666">({per_color[color]})</span></div>'
            "</div>"
        )
        col.markdown(swatch_html, unsafe_allow_html=True)
# ── PyVis renderer (inline assets, optional labels) ─────────────────────────────
# vis-network options passed to Network.set_options() as a JSON string.
# Physics is enabled here instead of via net.toggle_physics(): in PyVis 0.3.x,
# set_options() replaces the Options object with a plain dict, after which
# toggle_physics() raises AttributeError.
# NOTE(review): PyVis appears to strip all spaces from the options string
# before parsing it, so string values here should not contain spaces.
PYVIS_OPTIONS = {
    "nodes": {"shape": "dot", "size": 16, "font": {"size": 14}},
    "edges": {
        "smooth": {"type": "continuous"},
        "font": {"size": 10, "align": "middle"},
        "scaling": {"min": 1, "max": 10},
    },
    "physics": {
        "enabled": True,
        "solver": "forceAtlas2Based",
        "forceAtlas2Based": {"gravitationalConstant": -60, "springLength": 120},
        "stabilization": {"iterations": 200},
    },
    "interaction": {"hover": True, "tooltipDelay": 120},
}


def edges_to_pyvis_with_freq(edges, label_mode: str):
    """Build a PyVis network from knowledge-graph-maker edges.

    Each distinct (subject, predicate, object) triple becomes one edge.
    - Its ``value`` (drawn width) is how many times that triple was extracted.
    - Its color comes from color_for_predicate().
    - Edge labels are drawn only when ``label_mode == "Always visible"``; the
      predicate is always set as the edge's hover title.

    Returns ``(net, triples, freq_triplet, freq_pred)``, where ``triples`` is
    the full (not deduplicated) list produced by edges_to_rdf().
    """
    triples = edges_to_rdf(edges)
    freq_triplet, freq_pred = count_relation_frequency(triples)
    net = Network(
        height="700px",
        width="100%",
        bgcolor="#ffffff",
        font_color="#222222",
        notebook=False,
        directed=False,
        cdn_resources="in_line",
    )
    net.set_options(json.dumps(PYVIS_OPTIONS))

    show_labels = label_mode == "Always visible"
    seen = set()
    # Iterate distinct triples: repeats are already encoded in the edge width,
    # so adding them again would only stack identical edges.
    # NOTE(review): with directed=False, PyVis's add_edge appears to skip a
    # second edge between an already-connected node pair, so only the first
    # predicate per pair is drawn. Consider merging labels per pair.
    for (s, p, o), count in freq_triplet.items():
        n1, n2 = f"Entity:{s}", f"Entity:{o}"
        for node_key, node_label in ((n1, s), (n2, o)):
            if node_key not in seen:
                net.add_node(node_key, label=node_label, title="Entity")
                seen.add(node_key)
        edge_kwargs = {
            "title": p,
            "value": max(1, int(count)),
            "color": color_for_predicate(p, freq_pred),
        }
        if show_labels:
            edge_kwargs["label"] = p
        net.add_edge(n1, n2, **edge_kwargs)
    return net, triples, freq_triplet, freq_pred
# ── Cytoscape.js renderer (embedded HTML; no new Python deps) ───────────────────
def cytoscape_html(
    triples,
    freq_triplet,
    freq_pred,
    label_mode: str,
    cdn_url: str = "https://unpkg.com/cytoscape@3.28.1/dist/cytoscape.min.js",
):
    """Return a self-contained HTML page that draws the graph with Cytoscape.js.

    Cytoscape.js is loaded from *cdn_url*. Each distinct
    (subject, predicate, object) triple becomes one edge:
    - edge width = how often that exact triple occurs (freq_triplet);
    - edge color = predicate-frequency bin from color_for_predicate();
    - nodes are always labeled. Edges are labeled when label_mode is
      "Always visible"; otherwise the predicate appears only while hovering.
    """
    node_ids = {}
    nodes, edges = [], []

    def node_id(name):
        # Give each entity name a short id ("n0", "n1", ...) on first sight.
        if name not in node_ids:
            node_ids[name] = f"n{len(node_ids)}"
            nodes.append({"data": {"id": node_ids[name], "label": name}})
        return node_ids[name]

    show_labels = label_mode == "Always visible"
    # Distinct triples only, in first-seen order; repeats are already
    # reflected in the edge width.
    unique_triples = dict.fromkeys(
        (t["subject"], t["predicate"], t["object"]) for t in triples
    )
    for s, p, o in unique_triples:
        sid, oid = node_id(s), node_id(o)
        edges.append({"data": {
            "id": f"e{len(edges)}",
            "source": sid,
            "target": oid,
            "label": p if show_labels else "",
            "title": p,
            "width": max(1, int(freq_triplet[(s, p, o)])),
            "color": color_for_predicate(p, freq_pred),
        }})

    # Labels come from LLM output on user-supplied text. Escape <, > and & so a
    # value such as "</script>" cannot terminate the inline script element.
    elements_json = (
        json.dumps(nodes + edges)
        .replace("<", "\\u003c")
        .replace(">", "\\u003e")
        .replace("&", "\\u0026")
    )

    html = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width,initial-scale=1" />
<style>
html, body, #cy {{ width: 100%; height: 700px; margin: 0; padding: 0; background: #fff; }}
</style>
<script src="{cdn_url}"></script>
</head>
<body>
<div id="cy"></div>
<script>
const elements = {elements_json};
const cy = cytoscape({{
  container: document.getElementById('cy'),
  elements: elements,
  style: [
    {{
      selector: 'node',
      style: {{
        'label': 'data(label)',
        'text-valign': 'center',
        'text-halign': 'center',
        'font-size': 12,
        'background-color': '#76A5FD',
        'color': '#222'
      }}
    }},
    {{
      selector: 'edge',
      style: {{
        'line-color': 'data(color)',
        'width': 'mapData(width, 1, 10, 1, 10)',
        'curve-style': 'bezier',
        'target-arrow-shape': 'none',
        'label': 'data(label)',
        'font-size': 10,
        'text-rotation': 'autorotate',
        'text-margin-y': -4
      }}
    }}
  ],
  layout: {{
    name: 'cose',
    animate: true,
    nodeRepulsion: 8000,
    idealEdgeLength: 120,
    gravity: 1.2,
    numIter: 1000
  }}
}});
// Show the predicate while hovering an edge (the only way to read it in
// "Tooltip only" mode, where edge labels are blank).
cy.on('mouseover', 'edge', (evt) => evt.target.style('label', evt.target.data('title')));
cy.on('mouseout', 'edge', (evt) => evt.target.removeStyle('label'));
</script>
</body>
</html>
"""
    return html
# ── Input tabs ─────────────────────────────────────────────────────────────────
# Tab titles previously contained mis-encoded emoji.
tab_text, tab_pdf = st.tabs(["📝 Paste Text", "📄 Upload PDF"])
input_text = ""
with tab_text:
    input_text = st.text_area("Paste your text here", height=220, placeholder="Paste text…")
with tab_pdf:
    pdf_file = st.file_uploader("Upload a PDF", type=["pdf"])
    # An uploaded PDF takes precedence over any pasted text.
    if pdf_file:
        input_text = pdf_to_text(pdf_file)
        if not input_text.strip():
            st.warning("No extractable text was found in this PDF (it may be a scanned image).")
# ── Action ─────────────────────────────────────────────────────────────────────
# Generated edges are kept in st.session_state["kg_edges"]. Every widget
# interaction reruns this script, including clicking st.download_button below.
# Without the stored result, the graph would disappear after the first download.
if st.button("Generate Knowledge Graph", type="primary"):
    if not input_text.strip():
        st.warning("Please provide text or a PDF.")
        st.stop()
    if not OPENROUTER_API_KEY:
        st.error("OPENROUTER_API_KEY is not set in Space Secrets.")
        st.stop()
    # Drop any previous result so a failed run does not re-show a stale graph.
    if "kg_edges" in st.session_state:
        del st.session_state["kg_edges"]

    # Route OpenAI SDK traffic through OpenRouter (OpenAI-compatible).
    os.environ["OPENAI_API_KEY"] = OPENROUTER_API_KEY
    os.environ["OPENAI_BASE_URL"] = "https://openrouter.ai/api/v1"
    # NOTE(review): OpenRouter's optional attribution headers. The OpenAI SDK is
    # not known to read an OPENAI_DEFAULT_HEADERS variable, so this probably has
    # no effect unless OpenAIClient forwards it. Confirm.
    os.environ["OPENAI_DEFAULT_HEADERS"] = (
        '{"HTTP-Referer":"https://huggingface.co/spaces/blazingbunny/rahulnyk_knowledge_graph",'
        '"X-Title":"Knowledge Graph (OpenRouter)"}'
    )
    selected_model = custom_model.strip() or model_choice

    ontology = Ontology(
        labels=parse_labels(labels_text),
        relationships=[r.strip() for r in relationships_text.split(",") if r.strip()]
        or ["Relation between any pair of Entities"],
    )
    docs = chunk_text(input_text)

    # LLM client (OpenRouter via the OpenAI client).
    llm = OpenAIClient(model=selected_model, temperature=temperature, top_p=top_p)
    gm = GraphMaker(ontology=ontology, llm_client=llm, verbose=False)
    with st.spinner("Chunking input and building graph…"):
        try:
            new_edges = gm.from_documents(docs, delay_s_between=0)
        except Exception as exc:  # UI boundary: report API/model failures instead of a traceback
            st.error(f"Graph generation failed with model '{selected_model}': {exc}")
            st.stop()
    st.session_state["kg_edges"] = new_edges

edges = st.session_state.get("kg_edges")
if edges:
    st.success(f"Graph built with {len(edges)} edges.")

    # Edges table
    df = pd.DataFrame([{
        "node_1_label": e.node_1.label, "node_1": e.node_1.name,
        "node_2_label": e.node_2.label, "node_2": e.node_2.name,
        "relationship": e.relationship or "related_to",
    } for e in edges])
    st.dataframe(df, use_container_width=True)

    # ---- Render: PyVis or Cytoscape.js
    if renderer == "PyVis (interactive)":
        net, triples, freq_triplet, freq_pred = edges_to_pyvis_with_freq(edges, label_mode)
        html = net.generate_html()  # no disk I/O
    else:
        triples = edges_to_rdf(edges)
        freq_triplet, freq_pred = count_relation_frequency(triples)
        html = cytoscape_html(triples, freq_triplet, freq_pred, label_mode)
    st.components.v1.html(html, height=750, scrolling=True)

    # Legend (optional)
    if show_legend:
        render_color_legend(freq_pred)

    # Download RDF tuples as JSON (in-memory bytes, no filesystem).
    st.download_button(
        "Download RDF tuples (JSON)",
        data=json.dumps(triples, ensure_ascii=False, indent=2).encode("utf-8"),
        file_name="rdf_tuples.json",
        mime="application/json",
    )
elif edges is not None:
    # Generation succeeded but the model extracted nothing usable.
    st.warning("The model returned no relationships for this input. Try another model or preset.")

st.markdown("---")
st.caption("Powered by knowledge-graph-maker via OpenRouter.")