# Uploaded by IsrakML ("Upload 2 files", commit e6d6007, verified).
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel
import requests
from bs4 import BeautifulSoup
import pandas as pd
# Define the model class (matching the saved architecture)
class HeadlineClassifier(torch.nn.Module):
    """Bangla BERT encoder with two linear classification heads.

    Both heads read the same dropout-regularised pooled BERT output: one
    head scores the aspect classes, the other scores the polarity classes.

    The attribute names ``bert``, ``drop``, ``aspect_out`` and
    ``polarity_out`` must not change. They determine the parameter keys that
    ``load_state_dict`` matches against the saved weights.
    """

    def __init__(self, num_aspect_classes, num_polarity_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained("sagorsarker/bangla-bert-base", return_dict=False)
        hidden_size = self.bert.config.hidden_size
        self.drop = torch.nn.Dropout(0.5)
        self.aspect_out = torch.nn.Linear(hidden_size, num_aspect_classes)
        self.polarity_out = torch.nn.Linear(hidden_size, num_polarity_classes)

    def forward(self, input_ids, attention_mask):
        """Return ``(aspect_logits, polarity_logits)`` for the given token ids."""
        # With return_dict=False the encoder yields a tuple; only the pooled
        # (second) element is used for classification.
        _sequence_output, pooled = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        features = self.drop(pooled)
        return self.aspect_out(features), self.polarity_out(features)
# Load tokenizer and model
tokenizer = BertTokenizer.from_pretrained("sagorsarker/bangla-bert-base")
model = HeadlineClassifier(num_aspect_classes=4, num_polarity_classes=3)
model.load_state_dict(torch.load('best_model_state1.bin', map_location=torch.device('cpu')))
model.eval()
# Class labels. Position i in each list is the label for output index i of
# the matching classifier head; predict_text indexes these lists with the
# argmax of the logits.
aspect_class_names = [
    "others",
    "politics",
    "religion",
    "sports",
]
polarity_class_names = [
    "negative",
    "neutral",
    "positive",
]
# Function for single text prediction
def predict_text(text, max_length=40):
    """Classify one Bangla text into an aspect class and a polarity class.

    Args:
        text: The input text, for example a news headline.
        max_length: Token length of the encoder input. Longer inputs are
            truncated and shorter ones are padded to exactly this length.
            Defaults to 40, the value this app has always used.

    Returns:
        A ``(aspect, polarity)`` tuple of label strings taken from
        ``aspect_class_names`` and ``polarity_class_names``.
    """
    # padding="max_length" replaces the deprecated pad_to_max_length=True.
    # truncation=True makes explicit the longest-first truncation the
    # tokenizer previously applied implicitly (with a warning) whenever
    # max_length was set.
    encoded = tokenizer.encode_plus(
        text,
        max_length=max_length,
        add_special_tokens=True,
        return_token_type_ids=False,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        aspect_logits, polarity_logits = model(encoded['input_ids'], encoded['attention_mask'])
    # The batch holds a single example, so .item() extracts its class index.
    aspect_index = torch.argmax(aspect_logits, dim=1).item()
    polarity_index = torch.argmax(polarity_logits, dim=1).item()
    return aspect_class_names[aspect_index], polarity_class_names[polarity_index]
# Function to scrape headlines with multiple classes
def scrape_headlines(url, max_headlines=50, timeout=15):
    """Download ``url`` and return up to ``max_headlines`` headline strings.

    A headline is the stripped visible text of an ``<a>`` element that has
    any of the CSS classes ``title-link``, ``stretched-link`` or ``Title``.

    Args:
        url: Page to download.
        max_headlines: Maximum number of headlines to return (default 50).
        timeout: Seconds to wait for the server. Without a timeout,
            ``requests.get`` can block the app indefinitely.

    Returns:
        A list of non-empty headline strings in document order. It may be
        empty.

    Raises:
        requests.exceptions.RequestException: for malformed URLs, connection
            failures or timeouts.
    """
    response = requests.get(url, timeout=timeout)
    soup = BeautifulSoup(response.content, "html.parser")
    headlines = []
    for anchor in soup.find_all("a", class_=["title-link", "stretched-link", "Title"]):
        if len(headlines) >= max_headlines:
            break
        text = anchor.get_text(strip=True)
        # Image- or icon-only links yield "". Such entries are not headlines,
        # and counting them would skew the class distribution.
        if text:
            headlines.append(text)
    return headlines
# Streamlit App Interface
def _analyze_headlines(url):
    """Scrape ``url``, classify every headline and render the class counts.

    Fetch failures (malformed URL, connection error, timeout) are reported
    with ``st.error`` instead of surfacing as an uncaught traceback.
    """
    try:
        headlines = scrape_headlines(url)
    except requests.exceptions.RequestException as err:
        st.error(f"Could not fetch the page: {err}")
        return
    if not headlines:
        st.warning("No headlines found. Please check the URL or structure of the site.")
        return

    # Every known class starts at 0 so classes with no hits still show up.
    aspect_counts = {cls: 0 for cls in aspect_class_names}
    polarity_counts = {cls: 0 for cls in polarity_class_names}
    with st.spinner(f"Classifying {len(headlines)} headlines..."):
        for headline in headlines:
            aspect, polarity = predict_text(headline)
            aspect_counts[aspect] += 1
            polarity_counts[polarity] += 1

    # Display counts
    st.write("### Aspect Class Counts")
    for cls in aspect_class_names:
        st.write(f"{cls}: {aspect_counts[cls]}")
    st.write("### Polarity Class Counts")
    for cls in polarity_class_names:
        st.write(f"{cls}: {polarity_counts[cls]}")
    # Display bar charts
    st.write("### Aspect Distribution")
    st.bar_chart(pd.DataFrame(list(aspect_counts.items()), columns=['Aspect', 'Count']).set_index('Aspect'))
    st.write("### Polarity Distribution")
    st.bar_chart(pd.DataFrame(list(polarity_counts.items()), columns=['Polarity', 'Count']).set_index('Polarity'))


st.title("Bangla Headline Aspect and Polarity Predictor")
# Radio button for functionality selection
option = st.radio("Choose Analysis Type:", ("Particular", "Overall"))
if option == "Particular":
    # Input for single text prediction
    text_input = st.text_area("Enter your Bangla text:")
    if st.button("Predict"):
        if text_input.strip():
            aspect, polarity = predict_text(text_input)
            st.write("### Original Text:")
            st.write(f"{text_input}")
            st.write(f"**Predicted Aspect Class:** {aspect}")
            st.write(f"**Predicted Polarity Class:** {polarity}")
        else:
            st.warning("Please enter some text to predict.")
elif option == "Overall":
    # Input for URL and headline analysis
    url_input = st.text_input("Enter the URL:")
    if st.button("Analyze Headlines"):
        # Whitespace picked up when copy-pasting is not part of the URL.
        url = url_input.strip()
        if url:
            _analyze_headlines(url)
        else:
            st.warning("Please enter a valid URL.")