"""Streamlit app that predicts the aspect and polarity of Bangla headlines.

Two modes are offered:

* "Particular" - classify a single piece of text typed by the user.
* "Overall"    - scrape headlines from a URL, classify each one and show the
  per-class counts as text and as bar charts.
"""

import pandas as pd
import requests
import streamlit as st
import torch
import torch.nn.functional as F
from bs4 import BeautifulSoup
from transformers import BertTokenizer, BertModel

# Name of the pretrained BERT checkpoint used for both tokenizer and encoder.
PRETRAINED_NAME = "sagorsarker/bangla-bert-base"
# Fine-tuned weights for HeadlineClassifier (state dict saved with torch.save).
WEIGHTS_PATH = 'best_model_state1.bin'
# Every input is padded/truncated to this many tokens before inference.
MAX_TOKENS = 40
# An <a> tag carrying any one of these CSS classes is treated as a headline.
HEADLINE_LINK_CLASSES = ["title-link", "stretched-link", "Title"]

# Class labels; the index in each list is the class id produced by the model.
aspect_class_names = ["others", "politics", "religion", "sports"]
polarity_class_names = ["negative", "neutral", "positive"]


# Define the model class (matching the saved architecture)
class HeadlineClassifier(torch.nn.Module):
    """BERT encoder with two linear heads: one for aspect, one for polarity.

    The attribute names (``bert``, ``drop``, ``aspect_out``, ``polarity_out``)
    are the keys of the saved state dict, so they must not be renamed.
    """

    def __init__(self, num_aspect_classes, num_polarity_classes):
        """Build the encoder and both classification heads.

        Args:
            num_aspect_classes: Output size of the aspect head.
            num_polarity_classes: Output size of the polarity head.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(PRETRAINED_NAME, return_dict=False)
        self.drop = torch.nn.Dropout(0.5)
        hidden_size = self.bert.config.hidden_size
        self.aspect_out = torch.nn.Linear(hidden_size, num_aspect_classes)
        self.polarity_out = torch.nn.Linear(hidden_size, num_polarity_classes)

    def forward(self, input_ids, attention_mask):
        """Return ``(aspect_logits, polarity_logits)`` for the given batch."""
        # With return_dict=False BERT returns a tuple; the second element is
        # the pooled output, which feeds both heads after dropout.
        _, pooled_output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        output = self.drop(pooled_output)
        aspect_output = self.aspect_out(output)
        polarity_output = self.polarity_out(output)
        return aspect_output, polarity_output


@st.cache_resource
def _load_tokenizer_and_model():
    """Load the tokenizer and the fine-tuned model once per server process.

    Streamlit re-executes the whole script on every widget interaction;
    caching keeps the (slow) model construction and checkpoint load from
    being repeated on each rerun.
    """
    loaded_tokenizer = BertTokenizer.from_pretrained(PRETRAINED_NAME)
    loaded_model = HeadlineClassifier(
        num_aspect_classes=len(aspect_class_names),
        num_polarity_classes=len(polarity_class_names),
    )
    # NOTE: torch.load unpickles the file - only point WEIGHTS_PATH at a
    # checkpoint you trust.
    state_dict = torch.load(WEIGHTS_PATH, map_location=torch.device('cpu'))
    loaded_model.load_state_dict(state_dict)
    loaded_model.eval()  # disable dropout for inference
    return loaded_tokenizer, loaded_model


# Load tokenizer and model
tokenizer, model = _load_tokenizer_and_model()


# Function for single text prediction
def predict_text(text):
    """Classify one piece of text.

    Args:
        text: The text to classify.

    Returns:
        A ``(aspect_name, polarity_name)`` tuple of label strings taken from
        ``aspect_class_names`` and ``polarity_class_names``.
    """
    encoded = tokenizer.encode_plus(
        text,
        max_length=MAX_TOKENS,
        add_special_tokens=True,
        return_token_type_ids=False,
        # Explicit padding/truncation replaces the deprecated
        # pad_to_max_length flag and the library's implicit truncation
        # fallback, so every input is exactly MAX_TOKENS long.
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']

    with torch.no_grad():
        aspect_output, polarity_output = model(input_ids, attention_mask)

    aspect_prediction = torch.argmax(aspect_output, dim=1).item()
    polarity_prediction = torch.argmax(polarity_output, dim=1).item()
    return (
        aspect_class_names[aspect_prediction],
        polarity_class_names[polarity_prediction],
    )


# Function to scrape headlines with multiple classes
def scrape_headlines(url, timeout=10, max_headlines=50):
    """Fetch a page and return the text of its headline links.

    Args:
        url: Address of the page to scrape.
        timeout: Seconds to wait for the server before giving up.
        max_headlines: Upper bound on the number of headlines returned.

    Returns:
        A list of non-empty headline strings, in document order.

    Raises:
        requests.RequestException: If the request fails, times out, or the
            server answers with an HTTP error status.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    soup = BeautifulSoup(response.content, "html.parser")

    # Extract headlines with the specified classes
    links = soup.find_all("a", class_=HEADLINE_LINK_CLASSES)
    texts = (link.get_text(strip=True) for link in links)
    # Drop links with no visible text before applying the cap so that empty
    # anchors do not use up headline slots.
    return [text for text in texts if text][:max_headlines]


def _render_particular():
    """Render the single-text prediction form and its result."""
    # Input for single text prediction
    text_input = st.text_area("Enter your Bangla text:")
    if not st.button("Predict"):
        return
    if not text_input.strip():
        st.warning("Please enter some text to predict.")
        return

    aspect, polarity = predict_text(text_input)
    st.write("### Original Text:")
    st.write(f"{text_input}")
    st.write(f"**Predicted Aspect Class:** {aspect}")
    st.write(f"**Predicted Polarity Class:** {polarity}")


def _render_overall():
    """Render the URL form, classify scraped headlines and show the totals."""
    # Input for URL and headline analysis
    url_input = st.text_input("Enter the URL:")
    if not st.button("Analyze Headlines"):
        return
    if not url_input.strip():
        st.warning("Please enter a valid URL.")
        return

    try:
        headlines = scrape_headlines(url_input.strip())
    except requests.RequestException as exc:
        # Show a readable message instead of a raw traceback for bad URLs,
        # timeouts and HTTP errors.
        st.error(f"Could not fetch headlines from the URL: {exc}")
        return

    if not headlines:
        st.warning("No headlines found. Please check the URL or structure of the site.")
        return

    # Initialize counters (every class starts at zero so it is always shown)
    aspect_counts = {cls: 0 for cls in aspect_class_names}
    polarity_counts = {cls: 0 for cls in polarity_class_names}

    # Process each headline
    for headline in headlines:
        aspect, polarity = predict_text(headline)
        aspect_counts[aspect] += 1
        polarity_counts[polarity] += 1

    # Display counts
    st.write("### Aspect Class Counts")
    for cls in aspect_class_names:
        st.write(f"{cls}: {aspect_counts[cls]}")

    st.write("### Polarity Class Counts")
    for cls in polarity_class_names:
        st.write(f"{cls}: {polarity_counts[cls]}")

    # Display bar charts
    st.write("### Aspect Distribution")
    aspect_frame = pd.DataFrame(list(aspect_counts.items()), columns=['Aspect', 'Count'])
    st.bar_chart(aspect_frame.set_index('Aspect'))

    st.write("### Polarity Distribution")
    polarity_frame = pd.DataFrame(list(polarity_counts.items()), columns=['Polarity', 'Count'])
    st.bar_chart(polarity_frame.set_index('Polarity'))


# Streamlit App Interface
st.title("Bangla Headline Aspect and Polarity Predictor")

# Radio button for functionality selection
option = st.radio("Choose Analysis Type:", ("Particular", "Overall"))

if option == "Particular":
    _render_particular()
elif option == "Overall":
    _render_overall()