oop / app.py
Mohammed Foud
first commit
79d2a14
raw
history blame
3.37 kB
import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import io
import base64
# Load the model and tokenizer once at import time so every prediction reuses
# the same objects instead of reloading them per request.
# "./final_model" is a relative path, so it is resolved against the process's
# current working directory.
# NOTE(review): presumably a fine-tuned 3-class checkpoint, since
# predict_sentiment maps class ids 0-2 to labels — confirm against the saved
# model config.
model_path = "./final_model"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
def predict_sentiment(text):
    """Classify one piece of text with the module-level model.

    The text is lower-cased, tokenized (with truncation) and run through the
    model with gradient tracking disabled. The logits are turned into
    probabilities with a softmax over the last dimension.

    Args:
        text: The review text to classify.

    Returns:
        A ``(sentiment, prob_dict)`` tuple: ``sentiment`` is the label of the
        highest-probability class, and ``prob_dict`` maps each label to its
        probability formatted as a percentage string with two decimals.

    Raises:
        KeyError: If the model yields a class id outside the label table
            below (i.e. it is not a 3-class model).
    """
    # Class-id -> label table.
    # NOTE(review): assumes the checkpoint was trained with this id order.
    sentiment_map = {0: "Negative", 1: "Neutral", 2: "Positive"}

    encoded = tokenizer(
        text.lower(),
        return_tensors="pt",
        truncation=True,
        padding=True,
    )

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        scores = torch.nn.functional.softmax(model(**encoded).logits, dim=-1)
        winner = torch.argmax(scores, dim=-1).item()

    sentiment = sentiment_map[winner]

    # Row 0 holds the scores for the single input text.
    prob_dict = {}
    for class_id, score in enumerate(scores[0].tolist()):
        prob_dict[sentiment_map[class_id]] = f"{score*100:.2f}%"

    return sentiment, prob_dict
def _render_sentiment_chart(sentiment_counts):
    """Draw a bar chart of sentiment counts and return it as an HTML <img> tag.

    Args:
        sentiment_counts: A pandas Series whose index holds sentiment labels
            and whose values hold the number of reviews per label.

    Returns:
        An ``<img>`` tag with the PNG chart embedded as a base64 data URI.
    """
    # Use an explicit figure instead of pyplot's implicit "current figure" so
    # overlapping requests cannot draw onto each other's chart.
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.bar(sentiment_counts.index, sentiment_counts.values)
        ax.set_title('Sentiment Distribution')
        ax.set_xlabel('Sentiment')
        ax.set_ylabel('Count')
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

    plot_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    return f'<img src="data:image/png;base64,{plot_base64}" style="max-width:100%;">'


def analyze_reviews(reviews_text):
    """Run sentiment analysis on newline-separated reviews.

    Args:
        reviews_text: Text with one review per line. Blank lines and
            surrounding whitespace are ignored; ``None`` is treated as empty.

    Returns:
        A ``(table, html)`` tuple matching the two UI outputs:

        * ``table`` is a DataFrame with the columns ``Review``, ``Sentiment``
          and ``Confidence``. ``Confidence`` is a readable string such as
          ``"Negative: 1.00%, Neutral: 4.00%, Positive: 95.00%"`` so it fits
          a string-typed table column.
        * ``html`` is an ``<img>`` tag with the sentiment distribution chart.

        When no review was entered, ``table`` is empty (but still has the
        three columns) and ``html`` carries the prompt message. The message
        goes into the HTML slot because the table slot must receive tabular
        data, not a bare string.
    """
    columns = ['Review', 'Sentiment', 'Confidence']

    reviews = [r.strip() for r in (reviews_text or "").split('\n') if r.strip()]
    if not reviews:
        return pd.DataFrame(columns=columns), "<p>Please enter at least one review.</p>"

    # Process each review
    results = []
    for review in reviews:
        sentiment, probs = predict_sentiment(review)
        results.append({
            'Review': review,
            'Sentiment': sentiment,
            # Flatten the label -> percentage dict into plain text for display.
            'Confidence': ", ".join(f"{label}: {pct}" for label, pct in probs.items()),
        })

    df = pd.DataFrame(results, columns=columns)
    chart_html = _render_sentiment_chart(df['Sentiment'].value_counts())
    return df, chart_html
# Build the Gradio UI: a two-column layout with the review input and trigger
# button on one side and the results table plus chart on the other.
with gr.Blocks(title="Amazon Review Sentiment Analysis") as demo:
    gr.Markdown("# Amazon Review Sentiment Analysis")
    gr.Markdown("Enter one or more reviews (one per line) to analyze their sentiment.")

    with gr.Row():
        # Input side: free-text reviews and the button that starts the analysis.
        with gr.Column():
            reviews_input = gr.Textbox(
                lines=10,
                label="Enter Reviews",
                placeholder="Enter your reviews here (one per line)...",
            )
            analyze_btn = gr.Button("Analyze Reviews")

        # Output side: per-review results and the distribution chart (as HTML).
        with gr.Column():
            results_table = gr.Dataframe(
                col_count=(3, "fixed"),
                headers=["Review", "Sentiment", "Confidence"],
                datatype=["str", "str", "str"],
            )
            plot_output = gr.HTML()

    # The two values returned by analyze_reviews fill the table and the chart,
    # in that order.
    analyze_btn.click(
        analyze_reviews,
        inputs=reviews_input,
        outputs=[results_table, plot_output],
    )
# Start the web server only when this file is executed as a script, not when
# it is imported. share=True asks Gradio to also create a public share link.
if __name__ == "__main__":
    demo.launch(share=True)