Spaces:
Sleeping
Sleeping
File size: 4,270 Bytes
5f3ddcb 602715a 6dc2ac5 5f3ddcb 602715a 5f3ddcb 392d1a7 5f3ddcb 602715a 5f3ddcb 602715a 5f3ddcb 602715a 5f3ddcb 602715a 5f3ddcb 602715a 5f3ddcb 602715a bb5fa1c 602715a 5f3ddcb 602715a 5f3ddcb 602715a 5f3ddcb 602715a 5f3ddcb 602715a b917252 602715a 5f3ddcb 602715a fea2d22 602715a 5f3ddcb 602715a 5f3ddcb 602715a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import BertTokenizer, BertForSequenceClassification
import gradio as gr
import os
import zipfile
# --------- Sentiment Model (Binary, expanded to 3 classes) ---------
# Hub checkpoint id of the two-class sentiment classifier; the third
# ("Neutral") class is derived later from the gap between the two scores.
sentiment_model_name = "uer/roberta-base-finetuned-jd-binary-chinese"

sentiment_tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name)

# nn.Module.eval() returns the module itself, so loading and switching to
# inference mode can be written as a single expression.
sentiment_model = AutoModelForSequenceClassification.from_pretrained(
    sentiment_model_name
).eval()
# Unpack model_output.zip into the working directory unless a "result" path
# is already present; the label model further down is loaded from "result".
if not os.path.exists("result"):
    with zipfile.ZipFile("model_output.zip") as archive:  # default mode is "r"
        archive.extractall(".")
# --------- Multi-label Classification Model ---------
# Load the multi-label classifier from the extracted "result" directory.
# It is loaded exactly once, with the BERT-specific classes: loading the same
# checkpoint a second time only costs start-up time and memory, because the
# later assignment replaces the earlier one.
label_dir = "./result"
label_tokenizer = BertTokenizer.from_pretrained(label_dir)
label_model = BertForSequenceClassification.from_pretrained(label_dir)
label_model.eval()

# Label categories: model output index -> tag name shown to the user.
label_map = {
    0: "Landscape & Culture",
    1: "Service & Facilities",
    2: "Experience & Atmosphere",
    3: "Transportation Accessibility",
    4: "Interactive Activities",
    5: "Price & Consumption",
}

# Minimum sigmoid probability a tag needs in order to be reported.
threshold = 0.5
# --------- Inference Function ---------
def analyze(text):
    """Run sentiment and tag prediction on a single comment.

    Args:
        text: The comment to analyze. ``None``, the empty string and
            whitespace-only strings are all treated as missing input.

    Returns:
        A ``(sentiment_result, label_result)`` tuple of display strings.
        When no usable text is given, both elements are the same prompt
        asking for a valid comment.
    """
    # Guard against None as well as blank text so a cleared or missing input
    # yields the prompt instead of an AttributeError on ``.strip()``.
    if text is None or not text.strip():
        prompt = "Please enter a valid comment."
        return prompt, prompt

    # --- Sentiment Analysis ---
    sent_inputs = sentiment_tokenizer(
        text, return_tensors='pt', truncation=True, padding=True, max_length=128
    )
    with torch.no_grad():
        sent_outputs = sentiment_model(**sent_inputs)
    # A single comment is passed in, so row 0 is the only row of the batch.
    sent_probs = torch.softmax(sent_outputs.logits, dim=1)[0].tolist()
    # Class index 1 is read as "positive" and index 0 as "negative".
    pos_prob, neg_prob = sent_probs[1], sent_probs[0]

    # The classifier is binary; "Neutral" is reported when the two class
    # probabilities are closer together than this margin.
    neutral_margin = 0.2
    gap = abs(pos_prob - neg_prob)
    if gap < neutral_margin:
        sentiment_label = "Neutral"
    elif pos_prob > neg_prob:
        sentiment_label = "Positive"
    else:
        sentiment_label = "Negative"

    # The "Neutral" figure is a closeness heuristic (1 - gap), not a model
    # probability, so the three printed scores do not sum to 1.
    sentiment_result = (
        f"Prediction: {sentiment_label}\n\n"
        f"Sentiment Scores:\n"
        f"Positive: {pos_prob:.2f}\n"
        f"Neutral: {1 - gap:.2f} \n"
        f"Negative: {neg_prob:.2f}"
    )

    # --- Label Prediction ---
    label_inputs = label_tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=128
    )
    with torch.no_grad():
        label_outputs = label_model(**label_inputs)
    # Independent sigmoid per label (multi-label). Indexing row 0 always
    # yields a list, even when the model exposes only one label.
    label_probs = torch.sigmoid(label_outputs.logits)[0].tolist()

    # Build the display lines once for every label at or above the threshold.
    tag_lines = [
        f"{label_map[i]} ({p:.2f})"
        for i, p in enumerate(label_probs)
        if p >= threshold
    ]
    if tag_lines:
        label_result = "Predicted Tags:\n" + "\n".join(tag_lines)
    else:
        label_result = "No confident labels identified by the model."

    return sentiment_result, label_result
# --------- Gradio Web UI ---------
with gr.Blocks(title="Sentiment + Tag Analysis System") as demo:
    # Heading plus a short description of what the tool predicts.
    gr.Markdown("## 🌟 Comment Analyzer")
    intro_text = (
        "This tool analyzes **Tourist comment data** using deep learning models. "
        "It predicts both **sentiment polarity** (Positive / Neutral / Negative) and **semantic category tags** (6 themes)."
    )
    gr.Markdown(intro_text)

    # Two columns side by side: input on the left, results on the right.
    with gr.Row():
        with gr.Column():
            input_box = gr.Textbox(
                label="Enter a review",
                placeholder="e.g., The park is peaceful and the staff are friendly...",
                lines=4,
            )
            submit_btn = gr.Button("🔍 Analyze")
        with gr.Column():
            sentiment_output = gr.Textbox(label="Sentiment Result", lines=6)
            label_output = gr.Textbox(label="Tag Classification Result", lines=6)

    # Clicking the button sends the textbox content to analyze() and writes
    # its two returned strings into the two result boxes, in order.
    submit_btn.click(
        fn=analyze,
        inputs=input_box,
        outputs=[sentiment_output, label_output],
    )
# --------- Run App ---------
# Start the Gradio server only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()
|