File size: 5,904 Bytes
ba3b7a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import io
import json
import os
import tempfile

import pandas as pd
import streamlit as st
from pypdf import PdfReader
from pyvis.network import Network

from knowledge_graph_maker import (
    GraphMaker, Ontology, Document,
    OpenAIClient, GroqClient
)

st.set_page_config(page_title="Knowledge Graph Maker", layout="wide")

st.title("Knowledge Graph from Text/PDF (Docker Space)")
st.caption("Uses knowledge-graph-maker with OpenAI or Groq. Paste text or upload a PDF; view the interactive graph below.")

# API keys are read from the environment; an empty string means "not configured".
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")

# The provider radio pre-selects Groq only when Groq is the sole provider with a
# key; in every other case (OpenAI key present, or no key at all) it starts on OpenAI.
_PROVIDERS = ["OpenAI", "Groq"]
_default_provider = 1 if (GROQ_API_KEY and not OPENAI_API_KEY) else 0

with st.sidebar:
    st.subheader("Model Settings")
    temperature = st.slider("Temperature", min_value=0.0, max_value=1.0, value=0.1, step=0.05)
    top_p = st.slider("Top-p", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
    provider = st.radio("Provider", _PROVIDERS, index=_default_provider)
    # Only the selected provider's model box is rendered, so on any given run
    # exactly one of `oai_model` / `groq_model` is defined.
    if provider == "OpenAI":
        oai_model = st.text_input("OpenAI model", value="gpt-3.5-turbo")
    else:
        groq_model = st.text_input("Groq model", value="mixtral-8x7b-32768")

    st.markdown("### Ontology (labels)")
    # Built-in labels: either a bare name or a one-entry {name: description} dict.
    # NOTE(review): "Miscellanous" is misspelled; it is kept as-is because it is
    # sent to the LLM as a label name and matches parse_labels' fallback list.
    default_labels = [
        {"Person": "Person name without adjectives (may appear as name or pronoun)"},
        {"Object": "Avoid the definite article 'the' in name"},
        {"Event": "Events involving multiple people; no verbs like gives/leaves"},
        "Place", "Document", "Organisation", "Action",
        {"Miscellanous": "Important concept that fits none of the above"}
    ]
    # The text area is pre-filled with the label names only (descriptions are not shown).
    _default_label_names = [
        next(iter(entry)) if isinstance(entry, dict) else entry
        for entry in default_labels
    ]
    labels_text = st.text_area(
        "Labels (JSON or comma-separated)",
        value=", ".join(_default_label_names),
        height=80,
    )
    st.markdown("### Relationships focus")
    relationships_text = st.text_input(
        "Relationships (comma-separated)",
        value="Relation between any pair of Entities",
    )

def parse_labels(text):
    """Turn the sidebar "Labels" field into a label list for ``Ontology``.

    Two input forms are accepted, matching the widget's "JSON or
    comma-separated" caption:

    * JSON (text starting with ``[`` or ``{``): either a list whose items are
      label names (``"Place"``) or ``{name: description}`` objects, or a single
      object mapping names to descriptions, which is expanded into one
      ``{name: description}`` item per key. Blank names and list items that
      are neither strings nor non-empty objects are dropped.
    * Plain text: ``"A, B, C"``. Names are split on commas and stripped.

    Text that looks like JSON but fails to parse is treated as plain text. If
    nothing usable remains (blank field, only commas, empty JSON), the
    built-in default label names are returned.

    Args:
        text: Raw contents of the labels text area.

    Returns:
        A new list whose items are ``str`` names or ``{name: description}`` dicts.
    """
    defaults = ["Person", "Object", "Event", "Place", "Document", "Organisation", "Action", "Miscellanous"]
    stripped = text.strip()

    labels = None
    if stripped.startswith(("[", "{")):
        try:
            parsed = json.loads(stripped)
        except json.JSONDecodeError:
            parsed = None  # not valid JSON: fall through to comma splitting
        if isinstance(parsed, dict):
            labels = [{name: desc} for name, desc in parsed.items() if name.strip()]
        elif isinstance(parsed, list):
            labels = []
            for item in parsed:
                if isinstance(item, str) and item.strip():
                    labels.append(item.strip())
                elif isinstance(item, dict) and item:
                    labels.append(item)

    if labels is None:
        labels = [lbl.strip() for lbl in stripped.split(",") if lbl.strip()]

    # An empty label list would give the LLM nothing to classify into.
    return labels or defaults

def split_pdf(file) -> str:
    """Extract the text of every page of a PDF, one page per line group.

    Extraction is best-effort. A page whose ``extract_text()`` raises is left
    out entirely, and a page that yields no text contributes an empty string.

    Args:
        file: Anything ``PdfReader`` accepts (here, a Streamlit upload).

    Returns:
        The page texts joined with newline separators.
    """
    def _page_text(page):
        # None marks a page to skip; "" is a real (empty) page.
        try:
            return page.extract_text() or ""
        except Exception:
            return None

    page_texts = (_page_text(page) for page in PdfReader(file).pages)
    return "\n".join(txt for txt in page_texts if txt is not None)

def build_graph_documents(text: str, chunk_size: int = 3500) -> list[Document]:
    """Split *text* into ``Document`` chunks for the graph maker.

    ``chunk_size`` is measured in characters. The default of 3500 is a rough
    heuristic for ~900-1000 tokens per chunk; adjust it for the model in use.

    Each chunk ends on whitespace where possible, so words (and therefore
    entity names) are not cut in half across two chunks. The break is taken at
    the last whitespace character in the back half of the window. If there is
    none (e.g. one very long token), the chunk is cut at exactly
    ``chunk_size`` characters.

    Args:
        text: Full input text.
        chunk_size: Maximum characters per chunk before stripping; must be > 0.

    Returns:
        One ``Document`` per non-blank chunk, stripped of surrounding
        whitespace, with ``metadata={"chunk_id": k}``. ``k`` counts the
        emitted chunks from 0.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")

    docs = []
    start, length = 0, len(text)
    while start < length:
        end = min(start + chunk_size, length)
        if end < length:
            # Only accept a break in the back half, so early whitespace
            # doesn't produce tiny chunks.
            floor = start + chunk_size // 2
            cut = max(text.rfind(ws, floor, end) for ws in (" ", "\n", "\t", "\r"))
            if cut != -1:
                end = cut + 1  # cut >= floor >= start, so the loop always advances
        chunk = text[start:end].strip()
        if chunk:
            docs.append(Document(text=chunk, metadata={"chunk_id": len(docs)}))
        start = end
    return docs

def edges_to_pyvis(edges):
    """Build an undirected, physics-enabled PyVis network from extracted edges.

    Nodes are keyed by ``"<label>:<name>"``, so the same name under two
    different labels becomes two distinct nodes. Each node displays its name
    and uses its label as the hover title.

    Edges are grouped per unordered node pair before being added. PyVis's
    ``Network.add_edge`` ignores a repeat edge between the same two nodes on
    an undirected network, so adding edges one by one would keep only the
    first relationship text for each pair. Grouping keeps every distinct
    relationship in the hover title, joined with ``"; "``. It also sets the
    edge ``value`` (rendered width) to the number of extracted edges for that
    pair; a pair seen once gets the same edge as before.

    Args:
        edges: Iterable of edge objects exposing ``node_1`` / ``node_2`` (each
            with ``label`` and ``name``) and ``relationship``.

    Returns:
        The populated ``pyvis.network.Network``.
    """
    net = Network(height="700px", width="100%", bgcolor="#ffffff", font_color="#222222", notebook=False, directed=False)

    def node_key(node):
        return f"{node.label}:{node.name}"

    added_nodes = set()
    # frozenset({a, b}) -> (first-seen source key, first-seen target key, relationship texts).
    # Keyed by an unordered pair because the network is undirected.
    grouped = {}

    for e in edges:
        endpoint_keys = []
        for node in (e.node_1, e.node_2):
            key = node_key(node)
            if key not in added_nodes:
                net.add_node(key, label=node.name, title=node.label)
                added_nodes.add(key)
            endpoint_keys.append(key)
        n1, n2 = endpoint_keys
        _, _, rels = grouped.setdefault(frozenset((n1, n2)), (n1, n2, []))
        rels.append(e.relationship or "")

    for n1, n2, rels in grouped.values():
        # Drop blanks and exact repeats while keeping first-seen order.
        distinct = list(dict.fromkeys(r for r in rels if r))
        net.add_edge(n1, n2, title="; ".join(distinct), value=len(rels))

    net.toggle_physics(True)
    return net

# Input UI. Streamlit runs both tab bodies on every rerun; an uploaded PDF with
# extractable text takes precedence over pasted text.
tab_text, tab_pdf = st.tabs(["📝 Paste Text", "📄 Upload PDF"])
input_text = ""
with tab_text:
    input_text = st.text_area("Paste your text here", height=220, placeholder="Paste text…")
with tab_pdf:
    pdf_file = st.file_uploader("Upload a PDF", type=["pdf"])
    if pdf_file:
        try:
            pdf_text = split_pdf(pdf_file)
        except Exception as err:
            # A file the PDF parser rejects would otherwise abort the whole
            # script with a raw traceback on every rerun.
            st.error(f"Could not read the PDF: {err}")
        else:
            if pdf_text.strip():
                input_text = pdf_text
            else:
                st.warning("No extractable text found in this PDF; using the pasted text instead.")

if st.button("Generate Knowledge Graph", type="primary"):
    if not input_text.strip():
        st.warning("Please provide text or a PDF.")
        st.stop()

    # Prepare the LLM client for the provider chosen in the sidebar; each one
    # requires its API key, read from the environment at startup.
    if provider == "OpenAI":
        if not OPENAI_API_KEY:
            st.error("OPENAI_API_KEY is not set in the Space Secrets.")
            st.stop()
        llm = OpenAIClient(model=oai_model, temperature=temperature, top_p=top_p)
    else:
        if not GROQ_API_KEY:
            st.error("GROQ_API_KEY is not set in the Space Secrets.")
            st.stop()
        llm = GroqClient(model=groq_model, temperature=temperature, top_p=top_p)

    # Ontology; a blank relationships field falls back to a generic relationship.
    ontology = Ontology(
        labels=parse_labels(labels_text),
        relationships=[r.strip() for r in relationships_text.split(",") if r.strip()] or ["Relation between any pair of Entities"]
    )

    st.info("Chunking input and building graph… this may take a bit for longer texts.")

    gm = GraphMaker(ontology=ontology, llm_client=llm, verbose=False)
    docs = build_graph_documents(input_text)

    try:
        edges = gm.from_documents(docs, delay_s_between=0)  # tune delay for rate limits
    except Exception as err:
        # Top-level UI boundary: show any failure from the LLM-backed
        # extraction as a readable message instead of a raw traceback.
        st.error(f"Graph extraction failed: {err}")
        st.stop()

    if not edges:
        st.warning("No relationships were extracted from the input.")
        st.stop()

    st.success(f"Graph built with {len(edges)} edges.")

    # Show edges table
    df = pd.DataFrame([{
        "node_1_label": e.node_1.label, "node_1": e.node_1.name,
        "node_2_label": e.node_2.label, "node_2": e.node_2.name,
        "relationship": e.relationship
    } for e in edges])
    st.dataframe(df, use_container_width=True)

    # Render with PyVis inside Streamlit: save to a temp file, read it back
    # (closing the handle before the temp dir is removed), then embed the HTML.
    net = edges_to_pyvis(edges)
    with tempfile.TemporaryDirectory() as td:
        html_path = os.path.join(td, "graph.html")
        net.save_graph(html_path)
        with open(html_path, "r", encoding="utf-8") as fh:
            html_content = fh.read()
    st.components.v1.html(html_content, height=750, scrolling=True)

st.markdown("---")
st.caption("Built with [knowledge-graph-maker](https://github.com/rahulnyk/knowledge_graph_maker).")