File size: 5,125 Bytes
e6d6007
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel
import requests
from bs4 import BeautifulSoup
import pandas as pd

# Define the model class (matching the saved architecture)
class HeadlineClassifier(torch.nn.Module):
    """Bangla BERT encoder with two linear classification heads.

    The pooled BERT output is passed through dropout and then scored by two
    independent linear layers: one over aspect classes and one over polarity
    classes.

    The attribute names ``bert``, ``drop``, ``aspect_out`` and
    ``polarity_out`` must stay as they are: they determine the state-dict
    keys that the saved checkpoint is loaded into.
    """

    def __init__(self, num_aspect_classes, num_polarity_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained("sagorsarker/bangla-bert-base", return_dict=False)
        hidden = self.bert.config.hidden_size
        self.drop = torch.nn.Dropout(0.5)
        self.aspect_out = torch.nn.Linear(hidden, num_aspect_classes)
        self.polarity_out = torch.nn.Linear(hidden, num_polarity_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized ``(aspect_scores, polarity_scores)`` for a batch."""
        # With return_dict=False BERT yields a tuple; only the pooled output is used.
        _sequence, pooled = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        features = self.drop(pooled)
        return self.aspect_out(features), self.polarity_out(features)

# Class labels, listed in the index order of the model's output heads
# (predict_text maps argmax indices straight into these lists).
aspect_class_names = ["others", "politics", "religion", "sports"]
polarity_class_names = ["negative", "neutral", "positive"]


@st.cache_resource
def _load_tokenizer_and_model():
    """Load the Bangla BERT tokenizer and the fine-tuned classifier once.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps the tokenizer and model in memory across
    reruns instead of re-downloading/re-reading BERT and the checkpoint on
    every click.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is on CPU in eval mode.
    """
    tok = BertTokenizer.from_pretrained("sagorsarker/bangla-bert-base")
    # Head sizes are derived from the label lists so the two cannot drift apart.
    net = HeadlineClassifier(
        num_aspect_classes=len(aspect_class_names),
        num_polarity_classes=len(polarity_class_names),
    )
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load('best_model_state1.bin', map_location=torch.device('cpu'))
    net.load_state_dict(state_dict)
    net.eval()  # inference mode: disables dropout
    return tok, net


# Load tokenizer and model (cached across Streamlit reruns)
tokenizer, model = _load_tokenizer_and_model()

# Function for single text prediction
def predict_text(text, max_length=40):
    """Classify one piece of text into an aspect label and a polarity label.

    Args:
        text: Input text (Bangla expected).
        max_length: Token length the input is padded/truncated to. Defaults
            to 40, the value this app has always used — presumably the
            length the model was fine-tuned with (TODO confirm).

    Returns:
        A ``(aspect, polarity)`` tuple of label strings taken from
        ``aspect_class_names`` and ``polarity_class_names``.
    """
    encoded = tokenizer.encode_plus(
        text,
        max_length=max_length,
        add_special_tokens=True,
        return_token_type_ids=False,
        # Explicit replacements for the deprecated ``pad_to_max_length=True``;
        # truncation was previously switched on implicitly by ``max_length``.
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    with torch.no_grad():
        aspect_scores, polarity_scores = model(encoded['input_ids'], encoded['attention_mask'])

    aspect_idx = torch.argmax(aspect_scores, dim=1).item()
    polarity_idx = torch.argmax(polarity_scores, dim=1).item()
    return aspect_class_names[aspect_idx], polarity_class_names[polarity_idx]

# Function to scrape headlines with multiple classes
def scrape_headlines(url, limit=50, timeout=10):
    """Fetch ``url`` and return the texts of up to ``limit`` headline links.

    A headline is the visible text of an ``<a>`` element carrying any of the
    CSS classes ``title-link``, ``stretched-link`` or ``Title``.

    Args:
        url: Page to scrape.
        limit: Maximum number of headlines to return.
        timeout: Seconds to wait for the server before giving up; without it
            a stalled server would hang the app indefinitely.

    Returns:
        A list of non-empty headline strings. Returns an empty list (after
        showing an error in the app) when the page cannot be fetched.
    """
    try:
        response = requests.get(url, timeout=timeout)
        # Don't parse 4xx/5xx error pages as if they were news content.
        response.raise_for_status()
    except requests.RequestException as exc:
        # Covers bad/missing schemes, DNS/connection failures, timeouts and
        # HTTP errors; report instead of crashing the Streamlit script.
        st.error(f"Could not fetch the page: {exc}")
        return []

    soup = BeautifulSoup(response.content, "html.parser")

    headlines = []
    for link in soup.find_all("a", class_=["title-link", "stretched-link", "Title"]):
        if len(headlines) >= limit:
            break
        text = link.get_text(strip=True)
        # Links with no visible text are not headlines; classifying them
        # would only inflate the counts.
        if text:
            headlines.append(text)
    return headlines

# Streamlit App Interface
st.title("Bangla Headline Aspect and Polarity Predictor")

# Radio button for functionality selection
option = st.radio("Choose Analysis Type:", ("Particular", "Overall"))

if option == "Particular":
    # Input for single text prediction
    text_input = st.text_area("Enter your Bangla text:")
    if st.button("Predict"):
        if text_input.strip():
            aspect, polarity = predict_text(text_input)
            st.write("### Original Text:")
            st.write(text_input)
            st.write(f"**Predicted Aspect Class:** {aspect}")
            st.write(f"**Predicted Polarity Class:** {polarity}")
        else:
            st.warning("Please enter some text to predict.")

elif option == "Overall":
    # Input for URL and headline analysis
    url_input = st.text_input("Enter the URL:")
    if st.button("Analyze Headlines"):
        # Use the same stripped value that was validated.
        url = url_input.strip()
        if url:
            # requests refuses scheme-less URLs such as "example.com".
            if not url.lower().startswith(("http://", "https://")):
                url = "https://" + url

            # Scraping plus one model pass per headline can take a while.
            with st.spinner("Fetching and analyzing headlines..."):
                headlines = scrape_headlines(url)

                # Every class starts at 0 so empty classes still appear in the charts.
                aspect_counts = dict.fromkeys(aspect_class_names, 0)
                polarity_counts = dict.fromkeys(polarity_class_names, 0)
                for headline in headlines:
                    aspect, polarity = predict_text(headline)
                    aspect_counts[aspect] += 1
                    polarity_counts[polarity] += 1

            if not headlines:
                st.warning("No headlines found. Please check the URL or structure of the site.")
            else:
                # Display counts
                st.write("### Aspect Class Counts")
                for cls in aspect_class_names:
                    st.write(f"{cls}: {aspect_counts[cls]}")

                st.write("### Polarity Class Counts")
                for cls in polarity_class_names:
                    st.write(f"{cls}: {polarity_counts[cls]}")

                # Display bar charts
                st.write("### Aspect Distribution")
                st.bar_chart(pd.DataFrame(list(aspect_counts.items()), columns=['Aspect', 'Count']).set_index('Aspect'))

                st.write("### Polarity Distribution")
                st.bar_chart(pd.DataFrame(list(polarity_counts.items()), columns=['Polarity', 'Count']).set_index('Polarity'))
        else:
            st.warning("Please enter a valid URL.")