File size: 8,072 Bytes
82df4a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
# demo_phobert_api.py
# -*- coding: utf-8 -*-

from fastapi import FastAPI
from pydantic import BaseModel
import torch
import re
import json
import emoji
from underthesea import word_tokenize
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification

# Initialize the FastAPI application (title/description are shown in the
# auto-generated OpenAPI docs at /docs).
app = FastAPI(
    title="PhoBERT Emotion Classification API",
    description="API dự đoán cảm xúc của câu tiếng Việt sử dụng PhoBERT.",
    version="1.0"
)

###############################################################################
# EMOJI MAPPING - COPIED VERBATIM FROM THE TRAINING SCRIPT
###############################################################################
# Maps a single emoji character to a bracketed sentiment tag; must stay in
# sync with the mapping used at training time or inference inputs drift.
emoji_mapping = {
    "😀": "[joy]", "😃": "[joy]", "😄": "[joy]", "😁": "[joy]", "😆": "[joy]", "😅": "[joy]", "😂": "[joy]", "🤣": "[joy]",
    "🙂": "[love]", "🙃": "[love]", "😉": "[love]", "😊": "[love]", "😇": "[love]", "🥰": "[love]", "😍": "[love]",
    "🤩": "[love]", "😘": "[love]", "😗": "[love]", "☺": "[love]", "😚": "[love]", "😙": "[love]",
    "😋": "[satisfaction]", "😛": "[satisfaction]", "😜": "[satisfaction]", "🤪": "[satisfaction]", "😝": "[satisfaction]",
    "🤑": "[satisfaction]",
    "🤐": "[neutral]", "🤨": "[neutral]", "😐": "[neutral]", "😑": "[neutral]", "😶": "[neutral]",
    "😏": "[sarcasm]",
    "😒": "[disappointment]", "🙄": "[disappointment]", "😬": "[disappointment]",
    "😔": "[sadness]", "😪": "[sadness]", "😢": "[sadness]", "😭": "[sadness]", "😥": "[sadness]", "😓": "[sadness]",
    "😩": "[tiredness]", "😫": "[tiredness]", "🥱": "[tiredness]",
    "🤤": "[discomfort]", "🤢": "[discomfort]", "🤮": "[discomfort]", "🤧": "[discomfort]", "🥵": "[discomfort]",
    "🥶": "[discomfort]", "🥴": "[discomfort]", "😵": "[discomfort]", "🤯": "[discomfort]",
    "😕": "[confused]", "😟": "[confused]", "🙁": "[confused]", "☹": "[confused]",
    "😮": "[surprise]", "😯": "[surprise]", "😲": "[surprise]", "😳": "[surprise]", "🥺": "[pleading]",
    "😦": "[fear]", "😧": "[fear]", "😨": "[fear]", "😰": "[fear]", "😱": "[fear]",
    "😖": "[confusion]", "😣": "[confusion]", "😞": "[confusion]",
    "😤": "[anger]", "😡": "[anger]", "😠": "[anger]", "🤬": "[anger]", "😈": "[mischievous]", "👿": "[mischievous]"
}

###############################################################################
# PREPROCESSING HELPERS (COPIED FROM THE TRAINING SCRIPT)
###############################################################################
def replace_emojis(sentence, emoji_mapping):
    """Replace each mapped emoji with its bracketed tag, dropping other emoji.

    Characters found in *emoji_mapping* become their tag; characters that the
    emoji library recognizes but the mapping does not are removed; everything
    else passes through unchanged.
    """
    def _transform(ch):
        if ch in emoji_mapping:
            return emoji_mapping[ch]
        return "" if emoji.is_emoji(ch) else ch

    return "".join(_transform(ch) for ch in sentence)

def remove_profanity(sentence):
    """Strip common Vietnamese profanity from *sentence* (case-insensitive).

    Multi-word entries (e.g. "đù mé") are removed with a substring pass first
    — the original token-by-token filter could never match them because a
    whitespace split yields single words only. Remaining single-word entries
    are then filtered token by token, preserving the original behavior.
    """
    profane_words = ["loz", "vloz", "vl", "dm", "đm", "clgt", "dmm", "cc", "vc", "đù mé", "vãi"]
    # Phase 1: remove multi-word phrases via regex so they can actually match.
    for phrase in (p for p in profane_words if " " in p):
        sentence = re.sub(re.escape(phrase), " ", sentence, flags=re.IGNORECASE)
    # Phase 2: drop single-word entries exactly as before.
    single_words = {p for p in profane_words if " " not in p}
    filtered = [w for w in sentence.split() if w.lower() not in single_words]
    return ' '.join(filtered)

def remove_special_characters(sentence):
    """Delete a fixed set of markup/punctuation characters from the sentence."""
    unwanted = r"[\^\*@#&$%<>~{}|\\]"
    return re.sub(unwanted, "", sentence)

def normalize_whitespace(sentence):
    """Collapse runs of whitespace into single spaces and trim both ends."""
    tokens = sentence.split()
    return " ".join(tokens)

def remove_repeated_characters(sentence):
    """Collapse any character repeated three or more times to one occurrence."""
    triple_or_more = r"(.)\1{2,}"
    return re.sub(triple_or_more, r"\1", sentence)

def replace_numbers(sentence):
    """Substitute every run of digits with the literal token "[number]"."""
    digit_run = re.compile(r"\d+")
    return digit_run.sub("[number]", sentence)

def tokenize_underthesea(sentence):
    """Word-segment the sentence with underthesea and rejoin tokens with spaces."""
    return " ".join(word_tokenize(sentence))

# Load the optional abbreviation dictionary; fall back to an empty mapping so
# preprocessing still works when the file is missing or malformed. Narrowed
# from a bare `except Exception as e` (the variable was unused, and the broad
# catch could hide unrelated bugs): OSError covers missing/unreadable files,
# ValueError covers JSON decode and encoding errors.
try:
    with open("abbreviations.json", "r", encoding="utf-8") as f:
        abbreviations = json.load(f)
except (OSError, ValueError):
    abbreviations = {}

def preprocess_sentence(sentence):
    """Run the full training-time cleaning pipeline on one sentence.

    Order matters and mirrors the training script: lowercase -> emoji tagging
    -> profanity removal -> special-character removal -> whitespace
    normalization -> abbreviation expansion -> repeated-character collapsing
    -> number tagging -> word segmentation.
    """
    sentence = sentence.lower()
    sentence = replace_emojis(sentence, emoji_mapping)
    sentence = remove_profanity(sentence)
    sentence = remove_special_characters(sentence)
    sentence = normalize_whitespace(sentence)
    # Expand abbreviations token by token. NOTE(review): assumes each mapping
    # value is an iterable of tokens suitable for " ".join — confirm against
    # abbreviations.json's actual schema.
    expanded = [
        " ".join(abbreviations[tok]) if tok in abbreviations else tok
        for tok in sentence.split()
    ]
    sentence = " ".join(expanded)
    sentence = remove_repeated_characters(sentence)
    sentence = replace_numbers(sentence)
    return tokenize_underthesea(sentence)

###############################################################################
# LOAD CHECKPOINT
###############################################################################
checkpoint_dir = "./checkpoint"  # path to the fine-tuned checkpoint folder
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading config...")
config = AutoConfig.from_pretrained(checkpoint_dir)

# Fallback id -> label mapping (label order supplied with the checkpoint).
custom_id2label = {
    0: 'Anger',
    1: 'Disgust',
    2: 'Enjoyment',
    3: 'Fear',
    4: 'Other',
    5: 'Sadness',
    6: 'Surprise'
}

# Prefer the mapping stored in the checkpoint config. If the config only
# carries the generic "LABEL_i" placeholders, fall back to the custom map.
if hasattr(config, "id2label") and config.id2label:
    if all(label.startswith("LABEL_") for label in config.id2label.values()):
        id2label = custom_id2label
    else:
        # Config keys may be strings after the JSON round-trip; coerce to int.
        id2label = {int(k): v for k, v in config.id2label.items()}
else:
    id2label = custom_id2label

print("id2label loaded:", id2label)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)

print("Loading model...")
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir, config=config)
model.to(device)
model.eval()

###############################################################################
# PREDICT FUNCTION
###############################################################################
# Human-readable advice (in Vietnamese) appended to each predicted label in
# the API response.
label2message = {
    'Anger': 'Hãy bình tĩnh và giải quyết vấn đề một cách bình thản.',
    'Disgust': 'Hãy tránh xa những thứ khiến bạn không thích.',
    'Enjoyment': 'Chúc mừng bạn có một ngày tuyệt vời!',
    'Fear': 'Hãy đối mặt với nỗi sợ để vượt qua chúng.',
    'Other': 'Cảm xúc của bạn hiện tại không được phân loại rõ ràng.',
    'Sadness': 'Hãy tìm kiếm sự hỗ trợ khi cần thiết.',
    'Surprise': 'Thật bất ngờ! Hãy tận hưởng khoảnh khắc này.'
}

def predict_text(text: str) -> str:
    """Preprocess *text*, run the classifier, and format a Vietnamese reply.

    Returns "Dự đoán cảm xúc: <label>. <advice>" for known labels, or an
    "unknown label" message if the argmax id is not in id2label.
    """
    cleaned = preprocess_sentence(text)
    encoded = tokenizer(
        [cleaned],
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt"
    ).to(device)

    with torch.no_grad():
        pred_id = model(**encoded).logits.argmax(dim=-1).item()

    # Guard clause: unmapped class id.
    if pred_id not in id2label:
        return f"Nhãn không xác định (id={pred_id})"

    label = id2label[pred_id]
    message = label2message.get(label, "")
    base = f"Dự đoán cảm xúc: {label}."
    return f"{base} {message}" if message else base

###############################################################################
# REQUEST BODY MODEL
###############################################################################
class InputText(BaseModel):
    # The Vietnamese sentence to classify.
    text: str

###############################################################################
# API ENDPOINT
###############################################################################
@app.post("/predict")
def predict(input_text: InputText):
    """Accept a Vietnamese sentence and return the predicted emotion.

    Response body: {"result": "<human-readable prediction string>"}.
    """
    result = predict_text(input_text.text)
    return {"result": result}

###############################################################################
# RUN THE API SERVER
###############################################################################
if __name__ == "__main__":
    # Imported lazily so the module can also be served by an external runner
    # (e.g. `uvicorn demo_phobert_api:app`).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)