blazingbunny commited on
Commit
4310abc
Β·
verified Β·
1 Parent(s): dcec766

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +243 -36
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import os
 
2
  import streamlit as st
3
  import pandas as pd
 
4
  from pypdf import PdfReader
5
  from pyvis.network import Network
6
 
@@ -11,12 +13,12 @@ from knowledge_graph_maker import (
11
  # ── Page setup ──────────────────────────────────────────────────────────────────
12
  st.set_page_config(page_title="Knowledge Graph (OpenRouter)", layout="wide")
13
  st.title("Knowledge Graph from Text/PDF β€” OpenRouter")
14
- st.caption("Builds a knowledge graph with knowledge-graph-maker via OpenRouter. Paste text or upload a PDF; choose a model.")
15
 
16
  # ── Secrets / env ───────────────────────────────────────────────────────────────
17
  OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY", "")
18
 
19
- # Preset OpenRouter models (you can add more)
20
  OPENROUTER_MODELS = [
21
  "openai/gpt-oss-20b:free",
22
  "moonshotai/kimi-k2:free",
@@ -24,17 +26,15 @@ OPENROUTER_MODELS = [
24
  "google/gemma-3-27b-it:free",
25
  ]
26
 
 
 
 
 
 
 
27
  # ── Sidebar controls ───────────────────────────────────────────────────────────
28
  with st.sidebar:
29
  st.subheader("Model & Generation Settings")
30
-
31
- # Model choices
32
- OPENROUTER_MODELS = [
33
- "openai/gpt-oss-20b:free",
34
- "moonshotai/kimi-k2:free",
35
- "google/gemini-2.0-flash-exp:free",
36
- "google/gemma-3-27b-it:free",
37
- ]
38
  model_choice = st.selectbox("OpenRouter model", OPENROUTER_MODELS, index=0)
39
  custom_model = st.text_input("Custom model id (optional)", placeholder="e.g. meta-llama/llama-3.1-8b-instruct")
40
 
@@ -47,14 +47,11 @@ with st.sidebar:
47
  preset_names = list(PRESETS.keys())
48
  preset = st.selectbox("Choose a preset", preset_names, index=0,
49
  help=PRESETS[preset_names[0]]["desc"])
50
-
51
- # Apply preset button updates the sliders below
52
  if st.button("Apply preset"):
53
  st.session_state.temperature = PRESETS[preset]["temperature"]
54
  st.session_state.top_p = PRESETS[preset]["top_p"]
55
  st.toast(f"Applied: {preset}", icon="βœ…")
56
 
57
- # Sliders are bound to session state so the button can set them
58
  temperature = st.slider(
59
  "Temperature", 0.0, 1.0, key="temperature", step=0.05,
60
  help="Lower = more deterministic; higher = more variety"
@@ -64,7 +61,6 @@ with st.sidebar:
64
  help="Nucleus sampling threshold; 0.9 is a good default"
65
  )
66
 
67
- # Ontology controls
68
  st.markdown("### Ontology (labels)")
69
  labels_text = st.text_area(
70
  "Comma-separated labels",
@@ -75,6 +71,12 @@ with st.sidebar:
75
  "Relationships (comma-separated)",
76
  value="Relation between any pair of Entities",
77
  )
 
 
 
 
 
 
78
  # ── Helpers ────────────────────────────────────────────────────────────────────
79
  def parse_labels(text: str):
80
  return [lbl.strip() for lbl in text.split(",") if lbl.strip()] or [
@@ -99,8 +101,79 @@ def chunk_text(text: str, chars: int = 3500) -> list[Document]:
99
  docs.append(Document(text=chunk, metadata={"chunk_id": i // chars}))
100
  return docs
101
 
102
- def edges_to_pyvis(edges):
103
- # IMPORTANT: cdn_resources="in_line" prevents PyVis from creating a ./lib folder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  net = Network(
105
  height="700px",
106
  width="100%",
@@ -108,21 +181,137 @@ def edges_to_pyvis(edges):
108
  font_color="#222222",
109
  notebook=False,
110
  directed=False,
111
- cdn_resources="in_line",
112
  )
113
- node_ids = {}
114
- for e in edges:
115
- n1 = f"{e.node_1.label}:{e.node_1.name}"
116
- n2 = f"{e.node_2.label}:{e.node_2.name}"
117
- if n1 not in node_ids:
118
- net.add_node(n1, label=e.node_1.name, title=e.node_1.label)
119
- node_ids[n1] = True
120
- if n2 not in node_ids:
121
- net.add_node(n2, label=e.node_2.name, title=e.node_2.label)
122
- node_ids[n2] = True
123
- net.add_edge(n1, n2, title=e.relationship or "", value=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  net.toggle_physics(True)
125
- return net
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  # ── Input tabs ─────────────────────────────────────────────────────────────────
128
  tab_text, tab_pdf = st.tabs(["πŸ“ Paste Text", "πŸ“„ Upload PDF"])
@@ -168,22 +357,40 @@ if st.button("Generate Knowledge Graph", type="primary"):
168
  llm = OpenAIClient(model=selected_model, temperature=temperature, top_p=top_p)
169
 
170
  gm = GraphMaker(ontology=ontology, llm_client=llm, verbose=False)
171
- edges = gm.from_documents(docs, delay_s_between=0) # tweak delay for rate limits if needed
172
 
173
  st.success(f"Graph built with {len(edges)} edges.")
174
 
175
- # Show edge table
176
  df = pd.DataFrame([{
177
  "node_1_label": e.node_1.label, "node_1": e.node_1.name,
178
  "node_2_label": e.node_2.label, "node_2": e.node_2.name,
179
- "relationship": e.relationship
180
  } for e in edges])
181
  st.dataframe(df, use_container_width=True)
182
 
183
- # Render graph in-memory (no writes)
184
- net = edges_to_pyvis(edges)
185
- html = net.generate_html() # <- avoids creating ./lib
186
- st.components.v1.html(html, height=750, scrolling=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
 
188
  st.markdown("---")
189
  st.caption("Powered by knowledge-graph-maker via OpenRouter.")
 
1
  import os
2
+ import json
3
  import streamlit as st
4
  import pandas as pd
5
+ from collections import Counter
6
  from pypdf import PdfReader
7
  from pyvis.network import Network
8
 
 
13
  # ── Page setup ──────────────────────────────────────────────────────────────────
14
  st.set_page_config(page_title="Knowledge Graph (OpenRouter)", layout="wide")
15
  st.title("Knowledge Graph from Text/PDF β€” OpenRouter")
16
+ st.caption("Builds a knowledge graph with knowledge-graph-maker via OpenRouter. Pick a model, choose presets, and render via PyVis or Cytoscape.js.")
17
 
18
  # ── Secrets / env ───────────────────────────────────────────────────────────────
19
  OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY", "")
20
 
21
+ # Preset OpenRouter models (extend as needed)
22
  OPENROUTER_MODELS = [
23
  "openai/gpt-oss-20b:free",
24
  "moonshotai/kimi-k2:free",
 
26
  "google/gemma-3-27b-it:free",
27
  ]
28
 
29
+ # ---- Preset defaults in session state ----
30
+ if "temperature" not in st.session_state:
31
+ st.session_state.temperature = 0.1
32
+ if "top_p" not in st.session_state:
33
+ st.session_state.top_p = 0.9
34
+
35
  # ── Sidebar controls ───────────────────────────────────────────────────────────
36
  with st.sidebar:
37
  st.subheader("Model & Generation Settings")
 
 
 
 
 
 
 
 
38
  model_choice = st.selectbox("OpenRouter model", OPENROUTER_MODELS, index=0)
39
  custom_model = st.text_input("Custom model id (optional)", placeholder="e.g. meta-llama/llama-3.1-8b-instruct")
40
 
 
47
  preset_names = list(PRESETS.keys())
48
  preset = st.selectbox("Choose a preset", preset_names, index=0,
49
  help=PRESETS[preset_names[0]]["desc"])
 
 
50
  if st.button("Apply preset"):
51
  st.session_state.temperature = PRESETS[preset]["temperature"]
52
  st.session_state.top_p = PRESETS[preset]["top_p"]
53
  st.toast(f"Applied: {preset}", icon="βœ…")
54
 
 
55
  temperature = st.slider(
56
  "Temperature", 0.0, 1.0, key="temperature", step=0.05,
57
  help="Lower = more deterministic; higher = more variety"
 
61
  help="Nucleus sampling threshold; 0.9 is a good default"
62
  )
63
 
 
64
  st.markdown("### Ontology (labels)")
65
  labels_text = st.text_area(
66
  "Comma-separated labels",
 
71
  "Relationships (comma-separated)",
72
  value="Relation between any pair of Entities",
73
  )
74
+
75
+ st.markdown("### Visualization")
76
+ renderer = st.radio("Renderer", ["PyVis (interactive)", "Cytoscape.js (beta)"], index=0)
77
+ label_mode = st.radio("Edge labels", ["Always visible", "Tooltip only"], index=0)
78
+ show_legend = st.checkbox("Show color legend", value=True)
79
+
80
  # ── Helpers ────────────────────────────────────────────────────────────────────
81
  def parse_labels(text: str):
82
  return [lbl.strip() for lbl in text.split(",") if lbl.strip()] or [
 
101
  docs.append(Document(text=chunk, metadata={"chunk_id": i // chars}))
102
  return docs
103
 
104
+ def edges_to_rdf(edges):
105
+ """Convert knowledge-graph-maker edges to RDF-like triples."""
106
+ triples = []
107
+ for e in edges:
108
+ s = (e.node_1.name or "").strip()
109
+ p = (e.relationship or "").strip() or "related_to"
110
+ o = (e.node_2.name or "").strip()
111
+ if s and o:
112
+ triples.append({"subject": s, "predicate": p, "object": o})
113
+ return triples
114
+
115
+ def count_relation_frequency(triples):
116
+ """Return (freq_triplet, freq_predicate)."""
117
+ freq_triplet = Counter((t["subject"], t["predicate"], t["object"]) for t in triples)
118
+ freq_predicate = Counter(t["predicate"] for t in triples)
119
+ return freq_triplet, freq_predicate
120
+
121
+ # Color bins for predicate frequency
122
+ COLOR_BINS = [
123
+ (8, "#2F3B52", "freq β‰₯ 8"),
124
+ (5, "#4E6E9E", "5–7"),
125
+ (3, "#7FA6F8", "3–4"),
126
+ (1, "#BFD3FF", "1–2"),
127
+ ]
128
+ def color_for_predicate(p, freq_pred):
129
+ f = freq_pred[p]
130
+ if f >= 8: return "#2F3B52"
131
+ if f >= 5: return "#4E6E9E"
132
+ if f >= 3: return "#7FA6F8"
133
+ return "#BFD3FF"
134
+
135
+ def render_color_legend(freq_pred):
136
+ if not freq_pred:
137
+ return
138
+ # Display bins and a small summary of predicate counts in each bin
139
+ counts = {"β‰₯8":0, "5–7":0, "3–4":0, "1–2":0}
140
+ for p, f in freq_pred.items():
141
+ if f >= 8: counts["β‰₯8"] += 1
142
+ elif f >= 5: counts["5–7"] += 1
143
+ elif f >= 3: counts["3–4"] += 1
144
+ else: counts["1–2"] += 1
145
+
146
+ st.markdown("#### Legend (predicate frequency β†’ edge color)")
147
+ cols = st.columns(4)
148
+ bins_disp = [
149
+ ("#2F3B52", "β‰₯8", counts["β‰₯8"]),
150
+ ("#4E6E9E", "5–7", counts["5–7"]),
151
+ ("#7FA6F8", "3–4", counts["3–4"]),
152
+ ("#BFD3FF", "1–2", counts["1–2"]),
153
+ ]
154
+ for (c, label, cnt), col in zip(bins_disp, cols):
155
+ col.markdown(
156
+ f"""
157
+ <div style="display:flex;align-items:center;gap:8px;">
158
+ <div style="width:18px;height:12px;background:{c};border:1px solid #999;"></div>
159
+ <div><b>{label}</b> <span style="color:#666">({cnt})</span></div>
160
+ </div>
161
+ """,
162
+ unsafe_allow_html=True
163
+ )
164
+
165
+ # ── PyVis renderer (inline assets, optional labels) ─────────────────────────────
166
+ def edges_to_pyvis_with_freq(edges, label_mode: str):
167
+ """
168
+ Render PyVis graph with:
169
+ - visible edge labels (predicate) OR tooltip-only
170
+ - edge width scaled by exact triple frequency
171
+ - edge color based on predicate frequency
172
+ - inline assets (no filesystem writes)
173
+ """
174
+ triples = edges_to_rdf(edges)
175
+ freq_triplet, freq_pred = count_relation_frequency(triples)
176
+
177
  net = Network(
178
  height="700px",
179
  width="100%",
 
181
  font_color="#222222",
182
  notebook=False,
183
  directed=False,
184
+ cdn_resources="in_line", # avoid writing ./lib
185
  )
186
+
187
+ net.set_options("""
188
+ const options = {
189
+ edges: {
190
+ font: { size: 12, align: "middle" },
191
+ smooth: { type: "dynamic" },
192
+ scaling: { min: 1, max: 10 }
193
+ },
194
+ physics: { stabilization: true }
195
+ }
196
+ """)
197
+
198
+ seen = set()
199
+ for t in triples:
200
+ s, p, o = t["subject"], t["predicate"], t["object"]
201
+ n1, n2 = f"Entity:{s}", f"Entity:{o}"
202
+
203
+ if n1 not in seen:
204
+ net.add_node(n1, label=s, title="Entity")
205
+ seen.add(n1)
206
+ if n2 not in seen:
207
+ net.add_node(n2, label=o, title="Entity")
208
+ seen.add(n2)
209
+
210
+ width_val = int(max(1, freq_triplet[(s, p, o)]))
211
+ edge_kwargs = {
212
+ "title": p, # tooltip always available
213
+ "value": width_val, # width scales with frequency
214
+ "color": color_for_predicate(p, freq_pred),
215
+ }
216
+ if label_mode == "Always visible":
217
+ edge_kwargs["label"] = p # visible text on the edge
218
+ net.add_edge(n1, n2, **edge_kwargs)
219
+
220
  net.toggle_physics(True)
221
+ return net, triples, freq_triplet, freq_pred
222
+
223
+ # ── Cytoscape.js renderer (embedded HTML; no new Python deps) ───────────────────
224
+ def cytoscape_html(triples, freq_triplet, freq_pred, label_mode: str):
225
+ """
226
+ Build a self-contained HTML that renders Cytoscape.js via CDN.
227
+ - Edge width = exact triple frequency
228
+ - Edge color = predicate frequency bin
229
+ - Labels: nodes always have labels; edges show label depending on label_mode
230
+ """
231
+ # Build node and edge arrays
232
+ node_ids = {}
233
+ nodes = []
234
+ edges = []
235
+ def node_id(name):
236
+ if name not in node_ids:
237
+ node_ids[name] = f"n{len(node_ids)}"
238
+ nodes.append({"data": {"id": node_ids[name], "label": name}})
239
+ return node_ids[name]
240
+
241
+ for t in triples:
242
+ s, p, o = t["subject"], t["predicate"], t["object"]
243
+ sid, oid = node_id(s), node_id(o)
244
+ width_val = max(1, int(freq_triplet[(s, p, o)]))
245
+ color = color_for_predicate(p, freq_pred)
246
+ edge_label = p if label_mode == "Always visible" else "" # hide label if tooltip-only
247
+ edges.append({"data": {
248
+ "id": f"e{len(edges)}",
249
+ "source": sid, "target": oid,
250
+ "label": edge_label, "title": p,
251
+ "width": width_val, "color": color
252
+ }})
253
+
254
+ elements = nodes + edges
255
+ # Cytoscape style: show edge labels if present, else none; tooltip via title isn't native,
256
+ # but vis is clean and fast for large graphs.
257
+ html = f"""
258
+ <!DOCTYPE html>
259
+ <html>
260
+ <head>
261
+ <meta charset="utf-8" />
262
+ <meta name="viewport" content="width=device-width,initial-scale=1" />
263
+ <style>
264
+ html, body, #cy {{ width: 100%; height: 700px; margin: 0; padding: 0; background: #fff; }}
265
+ </style>
266
+ <script src="https://unpkg.com/[email protected]/dist/cytoscape.min.js"></script>
267
+ </head>
268
+ <body>
269
+ <div id="cy"></div>
270
+ <script>
271
+ const elements = {json.dumps(elements)};
272
+ const cy = cytoscape({{
273
+ container: document.getElementById('cy'),
274
+ elements: elements,
275
+ style: [
276
+ {{
277
+ selector: 'node',
278
+ style: {{
279
+ 'label': 'data(label)',
280
+ 'text-valign': 'center',
281
+ 'text-halign': 'center',
282
+ 'font-size': 12,
283
+ 'background-color': '#76A5FD',
284
+ 'color': '#222'
285
+ }}
286
+ }},
287
+ {{
288
+ selector: 'edge',
289
+ style: {{
290
+ 'line-color': 'data(color)',
291
+ 'width': 'mapData(width, 1, 10, 1, 10)',
292
+ 'curve-style': 'bezier',
293
+ 'target-arrow-shape': 'none',
294
+ 'label': 'data(label)',
295
+ 'font-size': 10,
296
+ 'text-rotation': 'autorotate',
297
+ 'text-margin-y': -4
298
+ }}
299
+ }}
300
+ ],
301
+ layout: {{
302
+ name: 'cose',
303
+ animate: true,
304
+ nodeRepulsion: 8000,
305
+ idealEdgeLength: 120,
306
+ gravity: 1.2,
307
+ numIter: 1000
308
+ }}
309
+ }});
310
+ </script>
311
+ </body>
312
+ </html>
313
+ """
314
+ return html
315
 
316
  # ── Input tabs ─────────────────────────────────────────────────────────────────
317
  tab_text, tab_pdf = st.tabs(["πŸ“ Paste Text", "πŸ“„ Upload PDF"])
 
357
  llm = OpenAIClient(model=selected_model, temperature=temperature, top_p=top_p)
358
 
359
  gm = GraphMaker(ontology=ontology, llm_client=llm, verbose=False)
360
+ edges = gm.from_documents(docs, delay_s_between=0)
361
 
362
  st.success(f"Graph built with {len(edges)} edges.")
363
 
364
+ # Show edges table
365
  df = pd.DataFrame([{
366
  "node_1_label": e.node_1.label, "node_1": e.node_1.name,
367
  "node_2_label": e.node_2.label, "node_2": e.node_2.name,
368
+ "relationship": e.relationship or "related_to"
369
  } for e in edges])
370
  st.dataframe(df, use_container_width=True)
371
 
372
+ # ---- Render: PyVis or Cytoscape.js
373
+ if renderer == "PyVis (interactive)":
374
+ net, triples, freq_triplet, freq_pred = edges_to_pyvis_with_freq(edges, label_mode)
375
+ html = net.generate_html() # no disk I/O
376
+ st.components.v1.html(html, height=750, scrolling=True)
377
+ else:
378
+ triples = edges_to_rdf(edges)
379
+ freq_triplet, freq_pred = count_relation_frequency(triples)
380
+ html = cytoscape_html(triples, freq_triplet, freq_pred, label_mode)
381
+ st.components.v1.html(html, height=750, scrolling=True)
382
+
383
+ # Legend (optional)
384
+ if show_legend:
385
+ render_color_legend(freq_pred)
386
+
387
+ # Download RDF tuples as JSON
388
+ st.download_button(
389
+ "Download RDF tuples (JSON)",
390
+ data=pd.Series(triples).to_json(orient="values"),
391
+ file_name="rdf_tuples.json",
392
+ mime="application/json"
393
+ )
394
 
395
  st.markdown("---")
396
  st.caption("Powered by knowledge-graph-maker via OpenRouter.")