ducdatit2002 commited on
Commit
8092054
·
verified ·
1 Parent(s): 2d5c298

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +205 -0
app.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # demo_phobert_gradio.py
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import gradio as gr
5
+ import torch
6
+ import re
7
+ import json
8
+ import emoji
9
+ import numpy as np
10
+ from underthesea import word_tokenize
11
+
12
+ from transformers import (
13
+ AutoConfig,
14
+ AutoTokenizer,
15
+ AutoModelForSequenceClassification
16
+ )
17
+
18
+ ###############################################################################
19
+ # TẢI MAPPING EMOJI - COPY Y NGUYÊN TỪ FILE TRAIN
20
+ ###############################################################################
21
+ emoji_mapping = {
22
+ "😀": "[joy]", "😃": "[joy]", "😄": "[joy]", "😁": "[joy]", "😆": "[joy]", "😅": "[joy]", "😂": "[joy]", "🤣": "[joy]",
23
+ "🙂": "[love]", "🙃": "[love]", "😉": "[love]", "😊": "[love]", "😇": "[love]", "🥰": "[love]", "😍": "[love]",
24
+ "🤩": "[love]", "😘": "[love]", "😗": "[love]", "☺": "[love]", "😚": "[love]", "😙": "[love]",
25
+ "😋": "[satisfaction]", "😛": "[satisfaction]", "😜": "[satisfaction]", "🤪": "[satisfaction]", "😝": "[satisfaction]",
26
+ "🤑": "[satisfaction]",
27
+ "🤐": "[neutral]", "🤨": "[neutral]", "😐": "[neutral]", "😑": "[neutral]", "😶": "[neutral]",
28
+ "😏": "[sarcasm]",
29
+ "😒": "[disappointment]", "🙄": "[disappointment]", "😬": "[disappointment]",
30
+ "😔": "[sadness]", "😪": "[sadness]", "😢": "[sadness]", "😭": "[sadness]", "😥": "[sadness]", "😓": "[sadness]",
31
+ "😩": "[tiredness]", "😫": "[tiredness]", "🥱": "[tiredness]",
32
+ "🤤": "[discomfort]", "🤢": "[discomfort]", "🤮": "[discomfort]", "🤧": "[discomfort]", "🥵": "[discomfort]",
33
+ "🥶": "[discomfort]", "🥴": "[discomfort]", "😵": "[discomfort]", "🤯": "[discomfort]",
34
+ "😕": "[confused]", "😟": "[confused]", "🙁": "[confused]", "☹": "[confused]",
35
+ "😮": "[surprise]", "😯": "[surprise]", "😲": "[surprise]", "😳": "[surprise]", "🥺": "[pleading]",
36
+ "😦": "[fear]", "😧": "[fear]", "😨": "[fear]", "😰": "[fear]", "😱": "[fear]",
37
+ "😖": "[confusion]", "😣": "[confusion]", "😞": "[confusion]",
38
+ "😤": "[anger]", "😡": "[anger]", "😠": "[anger]", "🤬": "[anger]", "😈": "[mischievous]", "👿": "[mischievous]"
39
+ }
40
+
41
+ ###############################################################################
42
+ # HÀM XỬ LÝ (COPY TỪ FILE TRAIN)
43
+ ###############################################################################
44
+ def replace_emojis(sentence, emoji_mapping):
45
+ processed_sentence = []
46
+ for char in sentence:
47
+ if char in emoji_mapping:
48
+ processed_sentence.append(emoji_mapping[char])
49
+ elif not emoji.is_emoji(char):
50
+ processed_sentence.append(char)
51
+ return ''.join(processed_sentence)
52
+
53
+ def remove_profanity(sentence):
54
+ profane_words = ["loz", "vloz", "vl", "dm", "đm", "clgt", "dmm", "cc", "vc", "đù mé", "vãi"]
55
+ words = sentence.split()
56
+ filtered = [w for w in words if w.lower() not in profane_words]
57
+ return ' '.join(filtered)
58
+
59
+ def remove_special_characters(sentence):
60
+ return re.sub(r"[\^\*@#&$%<>~{}|\\]", "", sentence)
61
+
62
+ def normalize_whitespace(sentence):
63
+ return ' '.join(sentence.split())
64
+
65
+ def remove_repeated_characters(sentence):
66
+ return re.sub(r"(.)\1{2,}", r"\1", sentence)
67
+
68
+ def replace_numbers(sentence):
69
+ return re.sub(r"\d+", "[number]", sentence)
70
+
71
+ def tokenize_underthesea(sentence):
72
+ tokens = word_tokenize(sentence)
73
+ return " ".join(tokens)
74
+
75
+ # Nếu có abbreviations.json, bạn load. Nếu không thì để rỗng.
76
+ try:
77
+ with open("abbreviations.json", "r", encoding="utf-8") as f:
78
+ abbreviations = json.load(f)
79
+ except:
80
+ abbreviations = {}
81
+
82
+ def preprocess_sentence(sentence):
83
+ # hạ thấp
84
+ sentence = sentence.lower()
85
+ # thay thế emoji
86
+ sentence = replace_emojis(sentence, emoji_mapping)
87
+ # loại bỏ từ nhạy cảm
88
+ sentence = remove_profanity(sentence)
89
+ # bỏ ký tự đặc biệt
90
+ sentence = remove_special_characters(sentence)
91
+ # chuẩn hoá khoảng trắng
92
+ sentence = normalize_whitespace(sentence)
93
+ # thay thế viết tắt
94
+ words = sentence.split()
95
+ replaced = []
96
+ for w in words:
97
+ if w in abbreviations:
98
+ replaced.append(" ".join(abbreviations[w]))
99
+ else:
100
+ replaced.append(w)
101
+ sentence = " ".join(replaced)
102
+ # bỏ bớt kí tự lặp
103
+ sentence = remove_repeated_characters(sentence)
104
+ # thay số thành [number]
105
+ sentence = replace_numbers(sentence)
106
+ # tokenize tiếng Việt
107
+ sentence = tokenize_underthesea(sentence)
108
+ return sentence
109
+
110
+ ###############################################################################
111
+ # LOAD CHECKPOINT
112
+ ###############################################################################
113
+ checkpoint_dir = "./checkpoint" # Folder checkpoint nằm trong cùng thư mục với file script
114
+ device = "cuda" if torch.cuda.is_available() else "cpu"
115
+
116
+ print("Loading config...")
117
+ config = AutoConfig.from_pretrained(checkpoint_dir)
118
+
119
+ # Mapping id to label theo thứ tự bạn cung cấp
120
+ custom_id2label = {
121
+ 0: 'Anger',
122
+ 1: 'Disgust',
123
+ 2: 'Enjoyment',
124
+ 3: 'Fear',
125
+ 4: 'Other',
126
+ 5: 'Sadness',
127
+ 6: 'Surprise'
128
+ }
129
+
130
+ # Kiểm tra và sử dụng custom_id2label nếu config.id2label không đúng
131
+ if hasattr(config, "id2label") and config.id2label:
132
+ # Nếu config.id2label chứa 'LABEL_x', sử dụng custom mapping
133
+ if all(label.startswith("LABEL_") for label in config.id2label.values()):
134
+ id2label = custom_id2label
135
+ else:
136
+ id2label = {int(k): v for k, v in config.id2label.items()}
137
+ else:
138
+ id2label = custom_id2label # Sử dụng mapping mặc định nếu config không có id2label
139
+
140
+ print("id2label loaded:", id2label)
141
+
142
+ print("Loading tokenizer...")
143
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
144
+
145
+ print("Loading model...")
146
+ model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir, config=config)
147
+ model.to(device)
148
+ model.eval()
149
+
150
+ ###############################################################################
151
+ # HÀM PREDICT
152
+ ###############################################################################
153
+ # Mapping từ label đến thông điệp tương ứng
154
+ label2message = {
155
+ 'Anger': 'Hãy bình tĩnh và giải quyết vấn đề một cách bình thản.',
156
+ 'Disgust': 'Hãy tránh xa những thứ khiến bạn không thích.',
157
+ 'Enjoyment': 'Chúc mừng bạn có một ngày tuyệt vời!',
158
+ 'Fear': 'Hãy đối mặt với nỗi sợ để vượt qua chúng.',
159
+ 'Other': 'Cảm xúc của bạn hiện tại không được phân loại rõ ràng.',
160
+ 'Sadness': 'Hãy tìm kiếm sự hỗ trợ khi cần thiết.',
161
+ 'Surprise': 'Thật bất ngờ! Hãy tận hưởng khoảnh khắc này.'
162
+ }
163
+
164
+ def predict_text(text: str) -> str:
165
+ """Tiền xử lý, token hoá và chạy model => trả về label và thông điệp."""
166
+ text_proc = preprocess_sentence(text)
167
+ inputs = tokenizer(
168
+ [text_proc],
169
+ padding=True,
170
+ truncation=True,
171
+ max_length=256,
172
+ return_tensors="pt"
173
+ ).to(device)
174
+
175
+ with torch.no_grad():
176
+ outputs = model(**inputs)
177
+ pred_id = outputs.logits.argmax(dim=-1).item()
178
+
179
+ if pred_id in id2label:
180
+ label = id2label[pred_id]
181
+ message = label2message.get(label, "")
182
+ if message:
183
+ return f"Dự đoán cảm xúc: {label}. {message}"
184
+ else:
185
+ return f"Dự đoán cảm xúc: {label}."
186
+ else:
187
+ return f"Nhãn không xác định (id={pred_id})"
188
+
189
+ ###############################################################################
190
+ # GRADIO APP
191
+ ###############################################################################
192
+ def run_demo(input_text):
193
+ predicted_emotion = predict_text(input_text)
194
+ return predicted_emotion
195
+
196
+ demo = gr.Interface(
197
+ fn=run_demo,
198
+ inputs=gr.Textbox(lines=3, label="Nhập câu tiếng Việt"),
199
+ outputs=gr.Textbox(label="Kết quả"),
200
+ title="PhoBERT Emotion Classification",
201
+ description="Nhập vào 1 câu tiếng Việt để dự đoán cảm xúc."
202
+ )
203
+
204
+ if __name__ == "__main__":
205
+ demo.launch(share=True)