Stellajin916 committed on
Commit
5f3ddcb
·
verified ·
1 Parent(s): 885e0c9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import zipfile
3
+ import gradio as gr
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
6
+
7
# Sentiment model (JD.com binary sentiment checkpoint, per its name).
sentiment_model_name = "uer/roberta-base-finetuned-jd-binary-chinese"
sentiment_tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name)
sentiment_model = AutoModelForSequenceClassification.from_pretrained(
    sentiment_model_name
)
sentiment_model.eval()  # inference mode

# Unpack the custom multi-label model, unless a "result" path already exists.
# NOTE(review): presumably the archive contains a top-level "result/"
# directory -- confirm against the uploaded zip.
if not os.path.exists("result"):
    with zipfile.ZipFile("model_output.zip", mode="r") as archive:
        archive.extractall(path=".")

# Load the custom multi-label classification model from "result".
custom_tokenizer = AutoTokenizer.from_pretrained("result")
custom_model = AutoModelForSequenceClassification.from_pretrained(
    "result",
    use_safetensors=True,
)
custom_model.eval()  # inference mode

# Multi-label category names, keyed by integer index (0-5).
label_map = dict(
    enumerate(
        (
            "Landscape & Culture",
            "Service & Facilities",
            "Experience & Atmosphere",
            "Transportation Accessibility",
            "Interactive Activities",
            "Price & Consumption",
        )
    )
)
32
+
33
+ # 推理函数
34
def analyze(text, threshold=0.5):
    """Run sentiment analysis and multi-label topic classification on a review.

    Args:
        text: The review text to analyze.
        threshold: Minimum sigmoid probability a topic label needs in order
            to be included in the output.

    Returns:
        A formatted string containing the sentiment verdict with both class
        probabilities, followed by every topic label whose probability is at
        least ``threshold`` (or a fallback message when none qualifies).
        When ``text`` is empty, ``None`` or whitespace-only, a short prompt
        is returned instead and no model is run.
    """
    # Guard: nothing meaningful can be predicted from empty input.
    if not text or not text.strip():
        return "Please enter some text to analyze."

    # --- Sentiment analysis ---
    sentiment_inputs = sentiment_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    with torch.no_grad():
        sentiment_logits = sentiment_model(**sentiment_inputs).logits
    # A single input string yields one batch row; take it explicitly.
    sentiment_probs = torch.softmax(sentiment_logits, dim=1)[0].tolist()
    # Index 1 is treated as the positive class, index 0 as the negative one.
    is_positive = torch.argmax(sentiment_logits[0]).item() == 1
    sentiment = "积极 (Positive)" if is_positive else "消极 (Negative)"
    sentiment_result = (
        f"{sentiment}\n"
        f"Positive: {sentiment_probs[1]:.2f}, Negative: {sentiment_probs[0]:.2f}"
    )

    # --- Multi-label topic classification ---
    topic_inputs = custom_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    with torch.no_grad():
        topic_logits = custom_model(**topic_inputs).logits
    # Indexing the batch row (instead of squeeze()) always produces a list,
    # even when the model exposes a single label.
    topic_probs = torch.sigmoid(topic_logits)[0].tolist()

    results = []
    for index, prob in enumerate(topic_probs):
        if prob < threshold:
            continue
        # Fall back to a generic name if the model emits more labels than
        # label_map knows about, rather than raising KeyError.
        name = label_map.get(index, f"Label {index}")
        results.append(f"{name} ({prob:.2f})")

    if results:
        label_result = "\n".join(results)
    else:
        label_result = "The model was unable to identify the correct labels."

    return f"【Sentiment analysis】\n{sentiment_result}\n\n【Category of topic】\n{label_result}"
60
+
61
# Build the Gradio page: a review text box plus a label-threshold slider,
# both passed to analyze(); the result is shown as plain text.
review_input = gr.Textbox(lines=3, label="请输入评论内容")
threshold_input = gr.Slider(
    minimum=0.1,
    maximum=0.9,
    step=0.05,
    value=0.5,
    label="分类标签阈值",
)

demo = gr.Interface(
    fn=analyze,
    inputs=[review_input, threshold_input],
    outputs="text",
    title="中文评论分析器",
    description="使用京东情感模型 + 自定义多标签模型,对评论内容进行双重分析",
)

# Start the web app.
demo.launch()