File size: 3,913 Bytes
41f73cb
510db06
41f73cb
cb06d03
 
87c68a6
510db06
8436e7c
41f73cb
 
510db06
 
 
 
 
 
 
 
 
 
 
cb06d03
510db06
 
41f73cb
cb06d03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510db06
 
 
 
 
cb06d03
 
510db06
 
 
 
 
 
 
 
 
 
cb06d03
510db06
41f73cb
510db06
c178699
510db06
 
 
cb06d03
87c68a6
 
8436e7c
cb06d03
87c68a6
510db06
41f73cb
510db06
41f73cb
cb06d03
87c68a6
41f73cb
 
510db06
87c68a6
 
 
 
 
 
 
cb06d03
 
87c68a6
 
cb06d03
87c68a6
 
cb06d03
 
 
 
 
 
 
 
 
 
41f73cb
87c68a6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import streamlit as st
from transformers import AutoTokenizer, AutoModel
import torch
import networkx as nx
import matplotlib.pyplot as plt
from collections import Counter
import graphrag  # Import the graphrag library

@st.cache_resource
def load_model():
    """Load the BERT tokenizer and build the GraphRAG model around a BERT encoder.

    The function is decorated with ``st.cache_resource`` so Streamlit can
    reuse the loaded objects instead of rebuilding them on every script run.

    Returns:
        A ``(tokenizer, graph_rag_model)`` tuple.
    """
    checkpoint = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    encoder = AutoModel.from_pretrained(checkpoint)

    # NOTE(review): these constructor options are the original author's guess
    # at GraphRAG's interface -- confirm against the installed graphrag package.
    graphrag_options = {
        "num_labels": 2,  # binary sentiment classification
        "num_hidden_layers": 2,
        "hidden_size": 768,
        "intermediate_size": 3072,
    }

    return tokenizer, graphrag.GraphRAG(encoder, **graphrag_options)

def text_to_graph(text):
    """Convert text into a chain graph over its whitespace-separated words.

    Word ``i`` becomes node ``i`` and each pair of consecutive words is joined
    by an undirected edge, stored as two directed edges (forward and reverse).
    The chain is computed directly from the word count; no intermediate graph
    object is needed for a simple path.

    Args:
        text: Input string. Empty or whitespace-only text yields an empty graph.

    Returns:
        A dict with:
            ``edge_index``: two parallel lists ``[sources, targets]``; the
                forward edges ``(i, i + 1)`` come first, then the reversed ones.
            ``num_nodes``: number of words.
            ``node_feat``: one single-element feature per word, the Unicode
                code point of the word's first character.
            ``edge_attr``: a constant ``[1]`` attribute per directed edge.
    """
    words = text.split()
    num_nodes = len(words)

    # Forward edges of the chain: (0, 1), (1, 2), ..., (n - 2, n - 1).
    # Both ranges are empty when there are fewer than two words.
    sources = list(range(num_nodes - 1))
    targets = list(range(1, num_nodes))

    return {
        # Forward direction followed by the reverse direction.
        "edge_index": [sources + targets, targets + sources],
        "num_nodes": num_nodes,
        # str.split() never yields empty strings, so word[0] always exists.
        "node_feat": [[ord(word[0])] for word in words],
        "edge_attr": [[1] for _ in range(2 * len(sources))],
    }

def analyze_text(text, tokenizer, model):
    """Run the model on ``text`` and report a binary sentiment.

    The text is tokenized (truncated to 512 tokens) and also converted to a
    word graph via ``text_to_graph``; both are passed to ``model`` as keyword
    arguments.

    NOTE(review): the keyword names fed to the model and the shape of its
    output follow the original author's assumptions about GraphRAG -- confirm
    against the real interface.

    Args:
        text: The raw text to analyze.
        tokenizer: Callable tokenizer returning ``input_ids`` and
            ``attention_mask`` entries.
        model: Callable model; its result is used directly as logits unless it
            exposes a ``logits`` attribute.

    Returns:
        A ``(sentiment, confidence, graph)`` tuple: ``sentiment`` is
        ``"Positive"`` or ``"Negative"``, ``confidence`` is the softmax
        probability of the chosen class as a Python float, and ``graph`` is
        the dict produced by ``text_to_graph``.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    graph = text_to_graph(text)

    # Assemble the keyword arguments: token tensors first, then graph tensors.
    model_kwargs = {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
    }
    graph_tensor_dtypes = (
        ("edge_index", torch.long),
        ("node_feat", torch.float),
        ("edge_attr", torch.float),
    )
    for key, dtype in graph_tensor_dtypes:
        model_kwargs[key] = torch.tensor(graph[key], dtype=dtype)
    model_kwargs["num_nodes"] = graph["num_nodes"]

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        outputs = model(**model_kwargs)

    # Use the ``logits`` attribute when present, else the raw output.
    logits = getattr(outputs, "logits", outputs)
    probabilities = torch.softmax(logits, dim=1)

    negative_prob = probabilities[0][0]
    positive_prob = probabilities[0][1]
    # A tie is reported as "Negative" because the comparison is strict.
    if positive_prob > negative_prob:
        return "Positive", positive_prob.item(), graph
    return "Negative", negative_prob.item(), graph

# Streamlit page: collect text, run the analysis on demand, render the results.
st.title("GraphRAG-based Text Analysis")

tokenizer, model = load_model()

text_input = st.text_area("Enter text for analysis:", height=200)

if st.button("Analyze Text"):
    # Whitespace-only input contains no words, so treat it like empty input
    # instead of sending an empty graph to the model.
    if text_input.strip():
        sentiment, confidence, graph = analyze_text(text_input, tokenizer, model)
        st.write(f"Sentiment: {sentiment}")
        st.write(f"Confidence: {confidence:.2f}")

        # Additional analysis
        word_count = len(text_input.split())
        st.write(f"Word count: {word_count}")

        # Most common words (only purely alphanumeric tokens are counted)
        words = [word.lower() for word in text_input.split() if word.isalnum()]
        word_freq = Counter(words).most_common(5)

        st.write("Top 5 most common words:")
        for word, freq in word_freq:
            st.write(f"- {word}: {freq}")

        # Visualize graph
        G = nx.Graph()
        # Add the nodes explicitly: a single-word text has no edges, and
        # building the graph from edges alone would draw nothing.
        G.add_nodes_from(range(graph["num_nodes"]))
        G.add_edges_from(zip(graph["edge_index"][0], graph["edge_index"][1]))

        # Keep a handle on the figure so it can be passed to Streamlit and
        # closed afterwards; otherwise each click leaves another open figure.
        fig = plt.figure(figsize=(10, 6))
        nx.draw(G, with_labels=False, node_size=30, node_color='lightblue', edge_color='gray')
        plt.title("Text as Graph")
        st.pyplot(fig)
        plt.close(fig)

    else:
        st.write("Please enter some text to analyze.")