import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import gradio as gr
import os
import zipfile

# --------- Sentiment Model (Binary, expanded to 3 classes) ---------
sentiment_model_name = "uer/roberta-base-finetuned-jd-binary-chinese"
sentiment_tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(sentiment_model_name)
sentiment_model.eval()

# --------- Multi-label Classification Model (Your model) ---------
# Unpack the fine-tuned model on first run if the "result" directory is missing.
if not os.path.exists("result"):
    with zipfile.ZipFile("model_output.zip", "r") as zip_ref:
        zip_ref.extractall(".")

# Load the multi-label classification model.
label_tokenizer = AutoTokenizer.from_pretrained("result")
label_model = AutoModelForSequenceClassification.from_pretrained("result", use_safetensors=True)
label_model.eval()

# Multi-label categories
label_map = {
    0: "Landscape & Culture",
    1: "Service & Facilities",
    2: "Experience & Atmosphere",
    3: "Transportation Accessibility",
    4: "Interactive Activities",
    5: "Price & Consumption"
}
threshold = 0.5

# --------- Inference Function ---------
def analyze(text):
    if not text.strip():
        return "Please enter a valid comment.", "Please enter a valid comment."

    # --- Sentiment Analysis ---
    sent_inputs = sentiment_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
    with torch.no_grad():
        sent_outputs = sentiment_model(**sent_inputs)
    probs = torch.softmax(sent_outputs.logits, dim=1).squeeze().tolist()
    pos_prob, neg_prob = probs[1], probs[0]

    # Map the binary model to three classes: treat near-ties as "Neutral".
    if abs(pos_prob - neg_prob) < 0.2:
        sentiment_label = "Neutral"
    elif pos_prob > neg_prob:
        sentiment_label = "Positive"
    else:
        sentiment_label = "Negative"

    sentiment_result = (
        f"Prediction: {sentiment_label}\n\n"
        f"Sentiment Scores:\n"
        f"Positive: {pos_prob:.2f}\n"
        f"Neutral: {1 - abs(pos_prob - neg_prob):.2f}\n"
        f"Negative: {neg_prob:.2f}"
    )

    # --- Label Prediction ---
    label_inputs = label_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
    with torch.no_grad():
        label_outputs = label_model(**label_inputs)
    probs = torch.sigmoid(label_outputs.logits).squeeze().tolist()
    if isinstance(probs, float):  # single-logit edge case
        probs = [probs]

    selected = [(label_map[i], p) for i, p in enumerate(probs) if p >= threshold]
    if selected:
        label_result = "Predicted Tags:\n" + "\n".join(f"{name} ({p:.2f})" for name, p in selected)
    else:
        label_result = "No confident labels identified by the model."

    return sentiment_result, label_result

# --------- Gradio Web UI ---------
with gr.Blocks(title="Sentiment + Tag Analysis System") as demo:
    gr.Markdown("## 🌟 Comment Analyzer")
    gr.Markdown(
        "This tool analyzes **tourist comment data** using deep learning models. "
        "It predicts both **sentiment polarity** (Positive / Neutral / Negative) and **semantic category tags** (6 themes)."
    )
    with gr.Row():
        with gr.Column():
            input_box = gr.Textbox(label="Enter a review", placeholder="e.g., The park is peaceful and the staff are friendly...", lines=4)
            submit_btn = gr.Button("🔍 Analyze")
        with gr.Column():
            sentiment_output = gr.Textbox(label="Sentiment Result", lines=6)
            label_output = gr.Textbox(label="Tag Classification Result", lines=6)

    submit_btn.click(fn=analyze, inputs=input_box, outputs=[sentiment_output, label_output])

# --------- Run App ---------
if __name__ == "__main__":
    demo.launch()
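
# --------- Optional: programmatic usage (illustrative sketch) ---------
# A minimal sketch of calling analyze() directly for batch scoring, without
# launching the Gradio UI. The file name "comments.txt" (one comment per line)
# is a hypothetical placeholder, not part of the original app.
#
#     with open("comments.txt", encoding="utf-8") as f:
#         for line in f:
#             if line.strip():
#                 sentiment, tags = analyze(line.strip())
#                 print(sentiment, tags, sep="\n", end="\n\n")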