File size: 3,403 Bytes
41f73cb
4bf193c
cb06d03
 
41f73cb
cb06d03
 
87c68a6
8436e7c
41f73cb
 
cb06d03
 
 
 
 
4bf193c
 
41f73cb
cb06d03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4bf193c
cb06d03
 
 
 
 
 
 
41f73cb
 
c178699
cb06d03
 
87c68a6
 
8436e7c
cb06d03
87c68a6
cb06d03
41f73cb
4bf193c
41f73cb
cb06d03
87c68a6
41f73cb
 
4bf193c
87c68a6
 
 
 
 
 
 
cb06d03
 
87c68a6
 
cb06d03
87c68a6
 
cb06d03
 
 
 
 
 
 
 
 
 
41f73cb
87c68a6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import streamlit as st
from transformers import GraphormerForGraphClassification, GraphormerFeatureExtractor
from datasets import Dataset
from transformers.models.graphormer.collating_graphormer import preprocess_item, GraphormerDataCollator
import torch
import networkx as nx
import matplotlib.pyplot as plt
from collections import Counter

@st.cache_resource
def load_model():
    """Load the Graphormer classifier and its feature extractor.

    The function is wrapped in ``st.cache_resource`` so Streamlit can reuse
    the returned objects instead of calling it again on every script rerun.

    Returns:
        A ``(model, feature_extractor)`` tuple, both created from the same
        pretrained checkpoint.
    """
    checkpoint = "clefourrier/pcqm4mv2_graphormer_base"

    # The checkpoint's original head does not have two outputs, so mismatched
    # head weights are skipped and a fresh 2-class head (positive/negative
    # sentiment) is created instead.
    classifier_kwargs = {
        "num_classes": 2,
        "ignore_mismatched_sizes": True,
    }
    model = GraphormerForGraphClassification.from_pretrained(
        checkpoint, **classifier_kwargs
    )

    # NOTE(review): confirm that the installed transformers version exports
    # GraphormerFeatureExtractor; the extractor is not used by the visible code.
    feature_extractor = GraphormerFeatureExtractor.from_pretrained(checkpoint)

    return model, feature_extractor

def text_to_graph(text):
    """Convert text into a path graph over its whitespace-separated words.

    Every word becomes one node (numbered by position) and each pair of
    consecutive words is linked. Each link is stored in both directions:
    first all forward edges ``(i, i + 1)``, then all reverse edges.

    Args:
        text: The input string; it is split on whitespace.

    Returns:
        A dict with the keys ``edge_index`` (two parallel lists of source and
        target node ids), ``num_nodes``, ``node_feat`` (one single-element
        feature list per node), ``edge_attr`` (one ``[1]`` per directed edge)
        and ``y`` (a placeholder label). Text without any words yields an
        empty graph (``num_nodes == 0`` and empty lists).
    """
    # Feature id shared by every word whose first character is not ASCII.
    # NOTE(review): Graphormer looks node features up in a fixed-size
    # embedding table (config ``num_atoms``); raw code points such as a curly
    # quote (8220) or CJK characters could index past it. Confirm the bound
    # against the checkpoint's config.
    non_ascii_bucket = 128

    words = text.split()
    num_nodes = len(words)

    # A path graph's edges are fully determined by the node count, so they are
    # computed directly instead of building an intermediate graph object.
    sources = list(range(num_nodes - 1))
    targets = list(range(1, num_nodes))
    edge_index = [sources + targets, targets + sources]

    return {
        "edge_index": edge_index,
        "num_nodes": num_nodes,
        # ASCII value of the first character; anything beyond ASCII is
        # clamped into the shared bucket.
        "node_feat": [[min(ord(word[0]), non_ascii_bucket)] for word in words],
        # All edges carry the same attribute.
        "edge_attr": [[1] for _ in range(len(edge_index[0]))],
        # Placeholder label; it carries no information about the text.
        "y": [1],
    }

def analyze_text(text, model, feature_extractor):
    """Classify the sentiment of ``text`` with a Graphormer model.

    The text is converted to a graph with ``text_to_graph``, preprocessed with
    ``preprocess_item``, batched with ``GraphormerDataCollator`` and passed
    through ``model`` without gradient tracking.

    Args:
        text: The text to analyze.
        model: The graph classification model; its two output logits are
            read as (negative, positive).
        feature_extractor: Unused here; kept so existing callers keep working.

    Returns:
        A ``(sentiment, confidence, graph)`` tuple: ``sentiment`` is
        ``"Positive"`` or ``"Negative"``, ``confidence`` is the softmax
        probability of that class as a float, and ``graph`` is the dict
        produced by ``text_to_graph``.

    Raises:
        ValueError: If ``text`` contains no words, since there is no graph to
            feed to the model.
    """
    graph = text_to_graph(text)
    if graph["num_nodes"] == 0:
        raise ValueError("Cannot analyze text that contains no words.")

    # Each graph field must be its own column (one row per graph) so that
    # preprocess_item sees keys such as "edge_index" at the top level of the row.
    dataset = Dataset.from_dict({key: [value] for key, value in graph.items()})
    dataset_processed = dataset.map(preprocess_item, batched=False)

    # The collator expects a list of per-graph feature dicts.
    batch = GraphormerDataCollator()(list(dataset_processed))

    # The label is only a placeholder, so it is dropped rather than letting the
    # model compute a meaningless loss from it.
    inputs = {
        name: tensor.to(model.device)
        for name, tensor in batch.items()
        if name != "labels"
    }

    with torch.no_grad():
        outputs = model(**inputs)

    probabilities = torch.softmax(outputs.logits, dim=1)[0]
    negative = probabilities[0].item()
    positive = probabilities[1].item()

    # A tie is reported as "Negative".
    if positive > negative:
        return "Positive", positive, graph
    return "Negative", negative, graph

# Streamlit page: collects text, runs the sentiment analysis and shows some
# simple word statistics together with a drawing of the word graph.
st.title("Graph-based Text Analysis")

model, feature_extractor = load_model()

text_input = st.text_area("Enter text for analysis:", height=200)

if st.button("Analyze Text"):
    # Whitespace-only input has no words to analyze, so treat it like empty input.
    if text_input.strip():
        sentiment, confidence, graph = analyze_text(text_input, model, feature_extractor)
        st.write(f"Sentiment: {sentiment}")
        st.write(f"Confidence: {confidence:.2f}")

        # Additional analysis
        tokens = text_input.split()
        st.write(f"Word count: {len(tokens)}")

        # Most common words. Only purely alphanumeric tokens are counted, so a
        # word with attached punctuation (e.g. "good,") is skipped.
        words = [word.lower() for word in tokens if word.isalnum()]
        word_freq = Counter(words).most_common(5)

        st.write("Top 5 most common words:")
        for word, freq in word_freq:
            st.write(f"- {word}: {freq}")

        # Visualize graph. Nodes are added explicitly so that a single-word
        # text (which has no edges) still shows its one node.
        G = nx.Graph()
        G.add_nodes_from(range(graph["num_nodes"]))
        G.add_edges_from(zip(graph["edge_index"][0], graph["edge_index"][1]))

        fig, ax = plt.subplots(figsize=(10, 6))
        nx.draw(
            G,
            ax=ax,
            with_labels=False,
            node_size=30,
            node_color='lightblue',
            edge_color='gray',
        )
        ax.set_title("Text as Graph")
        st.pyplot(fig)
        # Close the figure so repeated clicks do not accumulate open figures.
        plt.close(fig)

    else:
        st.write("Please enter some text to analyze.")