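"""Graph-based Text Analysis (Streamlit Space).

Turns the input text into a simple word-chain graph (one node per word,
edges between consecutive words), runs it through a Graphormer
graph-classification model, and reports a sentiment label together with a
confidence score, basic word statistics, and a visualization of the graph.

Note: Graphormer's preprocessing utilities in `transformers` rely on
compiled Cython helpers, so Cython is expected to be available in the
environment.
"""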
import streamlit as st
from transformers import GraphormerForGraphClassification
from datasets import Dataset
from transformers.models.graphormer.collating_graphormer import preprocess_item, GraphormerDataCollator
import torch
import networkx as nx
import matplotlib.pyplot as plt
from collections import Counter

@st.cache_resource
def load_model():
    model = GraphormerForGraphClassification.from_pretrained(
        "clefourrier/pcqm4mv2_graphormer_base",
        num_classes=2,  # Binary classification (positive/negative sentiment)
        ignore_mismatched_sizes=True,
    )
    return model
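
# Note on the checkpoint loaded above: "clefourrier/pcqm4mv2_graphormer_base"
# was pretrained on PCQM4Mv2 (molecular property prediction), and loading it
# with num_classes=2 plus ignore_mismatched_sizes=True re-initializes the
# classification head. The sentiment scores below therefore come from an
# untrained head unless the model is fine-tuned on labeled text graphs first.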

def text_to_graph(text):
    # Build a chain graph: one node per word, an edge between consecutive words.
    words = text.split()
    G = nx.Graph()
    for i, word in enumerate(words):
        G.add_node(i, word=word)
        if i > 0:
            G.add_edge(i - 1, i)
    # Store each undirected edge in both directions for the model input.
    edge_index = [[e[0] for e in G.edges()] + [e[1] for e in G.edges()],
                  [e[1] for e in G.edges()] + [e[0] for e in G.edges()]]
    return {
        "edge_index": edge_index,
        "num_nodes": len(G.nodes()),
        "node_feat": [[ord(word[0])] for word in words],  # ASCII value of first letter as feature
        "edge_attr": [[1] for _ in range(len(G.edges()) * 2)],  # All edges share the same attribute
        "y": [1],  # Placeholder label, ignored during inference
    }
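
# For illustration (hypothetical input), text_to_graph("good movie") yields:
#   {
#       "edge_index": [[0, 1], [1, 0]],   # the single chain edge, both directions
#       "num_nodes": 2,
#       "node_feat": [[103], [109]],      # ord("g"), ord("m")
#       "edge_attr": [[1], [1]],
#       "y": [1],
#   }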

def analyze_text(text, model):
    graph = text_to_graph(text)
    # Wrap the single graph in a Dataset so it can go through Graphormer's
    # preprocessing, then collate it into model-ready tensors.
    dataset = Dataset.from_list([graph])
    dataset_processed = dataset.map(preprocess_item, batched=False)
    inputs = GraphormerDataCollator()([dataset_processed[0]])
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    probabilities = torch.softmax(logits, dim=1)
    sentiment = "Positive" if probabilities[0][1] > probabilities[0][0] else "Negative"
    confidence = probabilities[0][1].item() if sentiment == "Positive" else probabilities[0][0].item()
    return sentiment, confidence, graph
st.title("Graph-based Text Analysis")
model = load_model()
text_input = st.text_area("Enter text for analysis:", height=200)

if st.button("Analyze Text"):
    if text_input:
        sentiment, confidence, graph = analyze_text(text_input, model)
        st.write(f"Sentiment: {sentiment}")
        st.write(f"Confidence: {confidence:.2f}")

        # Additional analysis
        word_count = len(text_input.split())
        st.write(f"Word count: {word_count}")

        # Most common words
        words = [word.lower() for word in text_input.split() if word.isalnum()]
        word_freq = Counter(words).most_common(5)
        st.write("Top 5 most common words:")
        for word, freq in word_freq:
            st.write(f"- {word}: {freq}")

        # Visualize the word-chain graph
        G = nx.Graph()
        G.add_edges_from(zip(graph["edge_index"][0], graph["edge_index"][1]))
        fig = plt.figure(figsize=(10, 6))
        nx.draw(G, with_labels=False, node_size=30, node_color='lightblue', edge_color='gray')
        plt.title("Text as Graph")
        st.pyplot(fig)
    else:
        st.write("Please enter some text to analyze.")