File size: 3,586 Bytes
41f73cb
510db06
41f73cb
cb06d03
 
87c68a6
c753736
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8436e7c
41f73cb
 
510db06
 
 
c753736
 
 
510db06
 
c753736
cb06d03
510db06
 
41f73cb
cb06d03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
510db06
 
 
 
 
cb06d03
 
510db06
c753736
510db06
 
 
 
 
 
 
 
cb06d03
510db06
41f73cb
510db06
c178699
510db06
c753736
510db06
cb06d03
87c68a6
 
8436e7c
cb06d03
87c68a6
c753736
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import streamlit as st
from transformers import AutoTokenizer, AutoModel
import torch
import networkx as nx
import matplotlib.pyplot as plt
from collections import Counter
import graphrag
import inspect

st.title("GraphRAG Module Exploration and Text Analysis")

# Diagnostic section: list every attribute of the graphrag module and, for
# classes and plain functions, also show their signature and docstring.
st.header("GraphRAG Module Contents")
graphrag_contents = dir(graphrag)
st.write("Available attributes and methods in graphrag module:")
for item in graphrag_contents:
    st.write(f"- {item}")
    attr = getattr(graphrag, item)
    if inspect.isclass(attr) or inspect.isfunction(attr):
        # inspect.signature raises ValueError/TypeError for some callables
        # (e.g. classes implemented in C); one such attribute must not abort
        # the whole diagnostics page.
        try:
            signature = inspect.signature(attr)
        except (ValueError, TypeError):
            signature = "(unavailable)"
        st.write(f"  Signature: {signature}")
        st.write(f"  Docstring: {attr.__doc__}")

# Attempt to find a suitable model class: the first *class* whose name contains
# "model". Submodules, functions and constants are skipped because load_model()
# calls the result as a constructor.
model_class = None
for item in graphrag_contents:
    if 'model' not in item.lower():
        continue
    candidate = getattr(graphrag, item)
    if not inspect.isclass(candidate):
        continue
    model_class = candidate
    st.write(f"Found potential model class: {item}")
    break

if model_class is None:
    st.error("Could not find a suitable model class in graphrag module.")
    st.stop()

@st.cache_resource
def load_model():
    """Load the BERT tokenizer/encoder and wrap the encoder in the graphrag model.

    Decorated with ``st.cache_resource`` so Streamlit reuses the loaded objects
    across reruns instead of reloading the weights every time.

    Returns:
        A ``(tokenizer, graph_rag_model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    bert_model = AutoModel.from_pretrained("bert-base-uncased")

    # Initialize graphrag model.
    # NOTE(review): placeholder call -- the constructor arguments are a guess and
    # must be adjusted to whatever `model_class` actually accepts.
    graph_rag_model = model_class(
        bert_model,
        num_labels=2,  # For binary sentiment classification
        # Add or remove parameters based on the actual model's requirements
    )

    # The model is only used for inference. A freshly constructed nn.Module is in
    # training mode, which keeps any dropout in the wrapper active and makes
    # predictions non-deterministic, so switch it (and its submodules) to eval
    # mode. The isinstance guard keeps this safe if model_class is not a Module.
    if isinstance(graph_rag_model, torch.nn.Module):
        graph_rag_model.eval()

    return tokenizer, graph_rag_model

def text_to_graph(text):
    """Turn whitespace-separated words into a chain-graph description.

    Word ``i`` becomes node ``i`` and each pair of consecutive words is linked,
    so the graph is a simple chain. Every undirected edge is listed once in each
    direction.

    Args:
        text: Input string; it is split on runs of whitespace.

    Returns:
        A dict with:
          - ``"edge_index"``: two parallel lists ``[sources, targets]``; the
            forward edges ``i -> i+1`` come first, followed by the reversed ones.
          - ``"num_nodes"``: number of words.
          - ``"node_feat"``: one single-element list per word holding the Unicode
            code point of the word's first character.
          - ``"edge_attr"``: ``[1]`` for every directed edge.
        Empty or whitespace-only text yields no nodes and no edges.
    """
    words = text.split()
    n_words = len(words)

    # The graph is always a chain, so its edges are computed directly rather than
    # building a networkx Graph and walking its edge view four times.
    heads = list(range(n_words - 1))  # 0 .. n-2 (empty when n_words <= 1)
    tails = list(range(1, n_words))   # 1 .. n-1

    return {
        "edge_index": [heads + tails, tails + heads],
        "num_nodes": n_words,
        # str.split() never yields empty strings, so word[0] is always valid.
        "node_feat": [[ord(word[0])] for word in words],
        # One fresh [1] list per directed edge (no shared inner lists).
        "edge_attr": [[1] for _ in range(2 * len(heads))],
    }

def analyze_text(text, tokenizer, model):
    """Classify the sentiment of ``text`` from its tokens plus its word graph.

    Args:
        text: Raw input string.
        tokenizer: Hugging Face tokenizer; called with ``return_tensors="pt"``.
        model: Model returned by ``load_model()``; called with keyword arguments.

    Returns:
        ``(sentiment, confidence, graph)`` where ``sentiment`` is ``"Positive"``
        or ``"Negative"``, ``confidence`` is the softmax probability of that
        label as a Python float, and ``graph`` is the dict from
        ``text_to_graph``.
    """
    # Tokenize the text (inputs are truncated to 512 tokens).
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

    # Create graph representation.
    # NOTE(review): the graph is built from whitespace words of the *full* text,
    # while the token ids are word pieces truncated to 512, so node count and
    # token count generally differ -- confirm the model expects this.
    graph = text_to_graph(text)

    # Combine tokenized input with graph representation.
    # Note: This is a placeholder. Adjust based on the actual model's input requirements.
    # reshape(-1, 1) keeps the one-feature-per-node / per-edge tensors 2-D even for
    # empty text, where torch.tensor([]) alone would produce shape (0,).
    combined_input = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "edge_index": torch.tensor(graph["edge_index"], dtype=torch.long),
        "node_feat": torch.tensor(graph["node_feat"], dtype=torch.float).reshape(-1, 1),
        "edge_attr": torch.tensor(graph["edge_attr"], dtype=torch.float).reshape(-1, 1),
        "num_nodes": graph["num_nodes"],
    }

    # Perform inference
    with torch.no_grad():
        outputs = model(**combined_input)

    # Process outputs. Accept an output object exposing `.logits`, a tuple/list
    # whose first element is assumed to be the logits (the transformers
    # return_dict=False convention when no labels are passed), or a bare tensor.
    if hasattr(outputs, "logits"):
        logits = outputs.logits
    elif isinstance(outputs, (tuple, list)):
        logits = outputs[0]
    else:
        logits = outputs
    # Treat a 1-D (num_labels,) result as a batch of one so the indexing below works.
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)
    # Softmax over the label axis; identical to dim=1 for 2-D (batch, labels) logits.
    probabilities = torch.softmax(logits, dim=-1)
    negative_p = probabilities[0][0].item()
    positive_p = probabilities[0][1].item()
    # Ties resolve to "Negative".
    if positive_p > negative_p:
        sentiment, confidence = "Positive", positive_p
    else:
        sentiment, confidence = "Negative", negative_p

    return sentiment, confidence, graph

# Rest of the Streamlit app (text input, analysis button, etc.) remains the same...