blazingbunny commited on
Commit
ba3b7a5
·
verified ·
1 Parent(s): a1f3d10

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -0
app.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import tempfile
4
+ import streamlit as st
5
+ import pandas as pd
6
+ from pypdf import PdfReader
7
+ from pyvis.network import Network
8
+
9
+ from knowledge_graph_maker import (
10
+ GraphMaker, Ontology, Document,
11
+ OpenAIClient, GroqClient
12
+ )
13
+
14
+ st.set_page_config(page_title="Knowledge Graph Maker", layout="wide")
15
+
16
+ st.title("Knowledge Graph from Text/PDF (Docker Space)")
17
+ st.caption("Uses knowledge-graph-maker with OpenAI or Groq. Paste text or upload a PDF; view the interactive graph below.")
18
+
19
+ # Choose LLM client based on available env vars
20
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
21
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")
22
+
23
+ with st.sidebar:
24
+ st.subheader("Model Settings")
25
+ temperature = st.slider("Temperature", 0.0, 1.0, 0.1, 0.05)
26
+ top_p = st.slider("Top-p", 0.0, 1.0, 0.5, 0.05)
27
+ provider = st.radio("Provider", ["OpenAI", "Groq"], index=0 if OPENAI_API_KEY else 1 if GROQ_API_KEY else 0)
28
+ if provider == "OpenAI":
29
+ oai_model = st.text_input("OpenAI model", value="gpt-3.5-turbo")
30
+ else:
31
+ groq_model = st.text_input("Groq model", value="mixtral-8x7b-32768")
32
+
33
+ st.markdown("### Ontology (labels)")
34
+ default_labels = [
35
+ {"Person": "Person name without adjectives (may appear as name or pronoun)"},
36
+ {"Object": "Avoid the definite article 'the' in name"},
37
+ {"Event": "Events involving multiple people; no verbs like gives/leaves"},
38
+ "Place", "Document", "Organisation", "Action",
39
+ {"Miscellanous": "Important concept that fits none of the above"}
40
+ ]
41
+ labels_text = st.text_area("Labels (JSON or comma-separated)", value=", ".join(
42
+ [lbl if isinstance(lbl, str) else list(lbl.keys())[0] for lbl in default_labels]
43
+ ), height=80)
44
+ st.markdown("### Relationships focus")
45
+ relationships_text = st.text_input("Relationships (comma-separated)", value="Relation between any pair of Entities")
46
+
47
+ def parse_labels(text):
48
+ # Allow simple "A, B, C" input; fall back to defaults above if empty
49
+ if not text.strip():
50
+ return [ "Person","Object","Event","Place","Document","Organisation","Action","Miscellanous" ]
51
+ return [lbl.strip() for lbl in text.split(",") if lbl.strip()]
52
+
53
+ def split_pdf(file) -> str:
54
+ reader = PdfReader(file)
55
+ parts = []
56
+ for page in reader.pages:
57
+ try:
58
+ parts.append(page.extract_text() or "")
59
+ except Exception:
60
+ continue
61
+ return "\n".join(parts)
62
+
63
+ def build_graph_documents(text: str) -> list[Document]:
64
+ # Simple chunking: ~900-1000 tokens ≈ ~3000-4000 chars heuristic
65
+ # Adjust if needed.
66
+ CHARS = 3500
67
+ docs = []
68
+ for i in range(0, len(text), CHARS):
69
+ chunk = text[i:i+CHARS].strip()
70
+ if chunk:
71
+ docs.append(Document(text=chunk, metadata={"chunk_id": i//CHARS}))
72
+ return docs
73
+
74
+ def edges_to_pyvis(edges):
75
+ net = Network(height="700px", width="100%", bgcolor="#ffffff", font_color="#222222", notebook=False, directed=False)
76
+ # Simple map to keep unique node IDs
77
+ node_ids = {}
78
+
79
+ def node_key(label, name): return f"{label}:{name}"
80
+
81
+ for e in edges:
82
+ n1 = node_key(e.node_1.label, e.node_1.name)
83
+ n2 = node_key(e.node_2.label, e.node_2.name)
84
+
85
+ if n1 not in node_ids:
86
+ net.add_node(n1, label=e.node_1.name, title=e.node_1.label)
87
+ node_ids[n1] = True
88
+ if n2 not in node_ids:
89
+ net.add_node(n2, label=e.node_2.name, title=e.node_2.label)
90
+ node_ids[n2] = True
91
+
92
+ rel = e.relationship or ""
93
+ net.add_edge(n1, n2, title=rel, value=1)
94
+
95
+ net.toggle_physics(True)
96
+ return net
97
+
98
+ # Input UI
99
+ tab_text, tab_pdf = st.tabs(["📝 Paste Text", "📄 Upload PDF"])
100
+ input_text = ""
101
+ with tab_text:
102
+ input_text = st.text_area("Paste your text here", height=220, placeholder="Paste text…")
103
+ with tab_pdf:
104
+ pdf_file = st.file_uploader("Upload a PDF", type=["pdf"])
105
+ if pdf_file:
106
+ input_text = split_pdf(pdf_file)
107
+
108
+ if st.button("Generate Knowledge Graph", type="primary"):
109
+ if not input_text.strip():
110
+ st.warning("Please provide text or a PDF.")
111
+ st.stop()
112
+
113
+ # Prepare LLM client
114
+ if provider == "OpenAI":
115
+ if not OPENAI_API_KEY:
116
+ st.error("OPENAI_API_KEY is not set in the Space Secrets.")
117
+ st.stop()
118
+ llm = OpenAIClient(model=oai_model, temperature=temperature, top_p=top_p)
119
+ else:
120
+ if not GROQ_API_KEY:
121
+ st.error("GROQ_API_KEY is not set in the Space Secrets.")
122
+ st.stop()
123
+ llm = GroqClient(model=groq_model, temperature=temperature, top_p=top_p)
124
+
125
+ # Ontology
126
+ ontology = Ontology(
127
+ labels=parse_labels(labels_text),
128
+ relationships=[r.strip() for r in relationships_text.split(",") if r.strip()] or ["Relation between any pair of Entities"]
129
+ )
130
+
131
+ st.info("Chunking input and building graph… this may take a bit for longer texts.")
132
+
133
+ gm = GraphMaker(ontology=ontology, llm_client=llm, verbose=False)
134
+ docs = build_graph_documents(input_text)
135
+
136
+ edges = gm.from_documents(docs, delay_s_between=0) # tune delay for rate limits
137
+ st.success(f"Graph built with {len(edges)} edges.")
138
+
139
+ # Show edges table
140
+ df = pd.DataFrame([{
141
+ "node_1_label": e.node_1.label, "node_1": e.node_1.name,
142
+ "node_2_label": e.node_2.label, "node_2": e.node_2.name,
143
+ "relationship": e.relationship
144
+ } for e in edges])
145
+ st.dataframe(df, use_container_width=True)
146
+
147
+ # Render with PyVis inside Streamlit
148
+ net = edges_to_pyvis(edges)
149
+ with tempfile.TemporaryDirectory() as td:
150
+ html_path = os.path.join(td, "graph.html")
151
+ net.save_graph(html_path)
152
+ html_content = open(html_path, "r", encoding="utf-8").read()
153
+ st.components.v1.html(html_content, height=750, scrolling=True)
154
+
155
+ st.markdown("---")
156
+ st.caption("Built with [knowledge-graph-maker](https://github.com/rahulnyk/knowledge_graph_maker).")