# demo_phobert_api.py
# -*- coding: utf-8 -*-
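"""
FastAPI service that predicts the emotion of a Vietnamese sentence with a
fine-tuned PhoBERT checkpoint. Input text is run through the same
preprocessing pipeline as the training script (emoji tagging, profanity
removal, abbreviation expansion, underthesea word segmentation) before
inference.
"""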
from fastapi import FastAPI
from pydantic import BaseModel
import torch
import re
import json
import emoji
from underthesea import word_tokenize
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
# Initialize the FastAPI app
app = FastAPI(
    title="PhoBERT Emotion Classification API",
    description="API for predicting the emotion of a Vietnamese sentence using PhoBERT.",
    version="1.0"
)
###############################################################################
# EMOJI MAPPING - COPIED VERBATIM FROM THE TRAINING SCRIPT
###############################################################################
emoji_mapping = {
    "😀": "[joy]", "😃": "[joy]", "😄": "[joy]", "😁": "[joy]", "😆": "[joy]", "😅": "[joy]", "😂": "[joy]", "🤣": "[joy]",
    "🙂": "[love]", "🙃": "[love]", "😉": "[love]", "😊": "[love]", "😇": "[love]", "🥰": "[love]", "😍": "[love]",
    "🤩": "[love]", "😘": "[love]", "😗": "[love]", "☺": "[love]", "😚": "[love]", "😙": "[love]",
    "😋": "[satisfaction]", "😛": "[satisfaction]", "😜": "[satisfaction]", "🤪": "[satisfaction]", "😝": "[satisfaction]",
    "🤑": "[satisfaction]",
    "🤐": "[neutral]", "🤨": "[neutral]", "😐": "[neutral]", "😑": "[neutral]", "😶": "[neutral]",
    "😏": "[sarcasm]",
    "😒": "[disappointment]", "🙄": "[disappointment]", "😬": "[disappointment]",
    "😔": "[sadness]", "😪": "[sadness]", "😢": "[sadness]", "😭": "[sadness]", "😥": "[sadness]", "😓": "[sadness]",
    "😩": "[tiredness]", "😫": "[tiredness]", "🥱": "[tiredness]",
    "🤤": "[discomfort]", "🤢": "[discomfort]", "🤮": "[discomfort]", "🤧": "[discomfort]", "🥵": "[discomfort]",
    "🥶": "[discomfort]", "🥴": "[discomfort]", "😵": "[discomfort]", "🤯": "[discomfort]",
    "😕": "[confused]", "😟": "[confused]", "🙁": "[confused]", "☹": "[confused]",
    "😮": "[surprise]", "😯": "[surprise]", "😲": "[surprise]", "😳": "[surprise]", "🥺": "[pleading]",
    "😦": "[fear]", "😧": "[fear]", "😨": "[fear]", "😰": "[fear]", "😱": "[fear]",
    "😖": "[confusion]", "😣": "[confusion]", "😞": "[confusion]",
    "😤": "[anger]", "😡": "[anger]", "😠": "[anger]", "🤬": "[anger]", "😈": "[mischievous]", "👿": "[mischievous]"
}
###############################################################################
# PREPROCESSING HELPERS (COPIED FROM THE TRAINING SCRIPT)
###############################################################################
def replace_emojis(sentence, emoji_mapping):
    # Swap known emojis for their emotion tag; drop any other emoji.
    # Note: this walks the string one code point at a time, so multi-codepoint
    # emojis (ZWJ sequences, skin-tone modifiers) are handled piecewise.
    processed_sentence = []
    for char in sentence:
        if char in emoji_mapping:
            processed_sentence.append(emoji_mapping[char])
        elif not emoji.is_emoji(char):
            processed_sentence.append(char)
    return ''.join(processed_sentence)

def remove_profanity(sentence):
    # Filter out common Vietnamese profanity. Comparison is per token, so the
    # multi-word entry "đù mé" never matches; kept as-is for training parity.
    profane_words = ["loz", "vloz", "vl", "dm", "đm", "clgt", "dmm", "cc", "vc", "đù mé", "vãi"]
    words = sentence.split()
    filtered = [w for w in words if w.lower() not in profane_words]
    return ' '.join(filtered)

def remove_special_characters(sentence):
    return re.sub(r"[\^\*@#&$%<>~{}|\\]", "", sentence)

def normalize_whitespace(sentence):
    return ' '.join(sentence.split())

def remove_repeated_characters(sentence):
    # Collapse runs of three or more identical characters ("vuiiii" -> "vui").
    return re.sub(r"(.)\1{2,}", r"\1", sentence)

def replace_numbers(sentence):
    return re.sub(r"\d+", "[number]", sentence)

def tokenize_underthesea(sentence):
    # Word-segment with underthesea and re-join with spaces, mirroring the
    # training script. (PhoBERT is often fed underscore-joined segments via
    # word_tokenize(..., format="text"); keep this consistent with training.)
    tokens = word_tokenize(sentence)
    return " ".join(tokens)
# Load abbreviations.json if available; otherwise fall back to an empty map.
# Values are expected to be lists of words, since they are space-joined below.
try:
    with open("abbreviations.json", "r", encoding="utf-8") as f:
        abbreviations = json.load(f)
except (OSError, json.JSONDecodeError):
    abbreviations = {}
def preprocess_sentence(sentence):
    sentence = sentence.lower()
    sentence = replace_emojis(sentence, emoji_mapping)
    sentence = remove_profanity(sentence)
    sentence = remove_special_characters(sentence)
    sentence = normalize_whitespace(sentence)
    # Expand abbreviations that appear in the loaded map
    words = sentence.split()
    replaced = []
    for w in words:
        if w in abbreviations:
            replaced.append(" ".join(abbreviations[w]))
        else:
            replaced.append(w)
    sentence = " ".join(replaced)
    sentence = remove_repeated_characters(sentence)
    sentence = replace_numbers(sentence)
    sentence = tokenize_underthesea(sentence)
    return sentence
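
# Illustrative walk-through (approximate; the exact result depends on the
# underthesea version and the contents of abbreviations.json):
#   preprocess_sentence("Hôm nay vuiiii quá 😂 10/10")
#   lowercases, tags the emoji, collapses repeated letters, masks digits,
#   yielding roughly: "hôm nay vui quá [joy] [number]/[number]"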
###############################################################################
# LOAD CHECKPOINT
###############################################################################
checkpoint_dir = "./checkpoint"  # Path to the checkpoint folder
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Loading config...")
config = AutoConfig.from_pretrained(checkpoint_dir)
# Fallback id-to-label mapping, in the order used at training time
custom_id2label = {
    0: 'Anger',
    1: 'Disgust',
    2: 'Enjoyment',
    3: 'Fear',
    4: 'Other',
    5: 'Sadness',
    6: 'Surprise'
}
# Prefer the checkpoint's own id2label; fall back to the custom map when the
# config only carries the generic LABEL_0..LABEL_n placeholders.
if hasattr(config, "id2label") and config.id2label:
    if all(label.startswith("LABEL_") for label in config.id2label.values()):
        id2label = custom_id2label
    else:
        id2label = {int(k): v for k, v in config.id2label.items()}
else:
    id2label = custom_id2label
print("id2label loaded:", id2label)
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
print("Loading model...")
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir, config=config)
model.to(device)
model.eval()
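
# Optional smoke test (assumes the checkpoint has the 7 labels above):
#   with torch.no_grad():
#       out = model(**tokenizer(["xin chào"], return_tensors="pt").to(device))
#   print(out.logits.shape)  # expected: torch.Size([1, 7])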
###############################################################################
# PREDICTION
###############################################################################
# Vietnamese follow-up message returned alongside each predicted label
label2message = {
    'Anger': 'Hãy bình tĩnh và giải quyết vấn đề một cách bình thản.',
    'Disgust': 'Hãy tránh xa những thứ khiến bạn không thích.',
    'Enjoyment': 'Chúc mừng bạn có một ngày tuyệt vời!',
    'Fear': 'Hãy đối mặt với nỗi sợ để vượt qua chúng.',
    'Other': 'Cảm xúc của bạn hiện tại không được phân loại rõ ràng.',
    'Sadness': 'Hãy tìm kiếm sự hỗ trợ khi cần thiết.',
    'Surprise': 'Thật bất ngờ! Hãy tận hưởng khoảnh khắc này.'
}
def predict_text(text: str) -> str:
    text_proc = preprocess_sentence(text)
    inputs = tokenizer(
        [text_proc],
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    pred_id = outputs.logits.argmax(dim=-1).item()
    if pred_id in id2label:
        label = id2label[pred_id]
        message = label2message.get(label, "")
        if message:
            return f"Dự đoán cảm xúc: {label}. {message}"
        return f"Dự đoán cảm xúc: {label}."
    return f"Nhãn không xác định (id={pred_id})"
###############################################################################
# REQUEST BODY MODEL
###############################################################################
class InputText(BaseModel):
    text: str
###############################################################################
# API ENDPOINT
###############################################################################
@app.post("/predict")
def predict(input_text: InputText):
    """
    Receive a Vietnamese sentence and return the predicted emotion.
    """
    result = predict_text(input_text.text)
    return {"result": result}
###############################################################################
# RUN THE API SERVER
###############################################################################
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
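
# Alternatively, start with the uvicorn CLI:
#   uvicorn demo_phobert_api:app --host 0.0.0.0 --port 8000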