# demo_phobert_api.py
# -*- coding: utf-8 -*-
"""FastAPI service that predicts the emotion of a Vietnamese sentence with PhoBERT."""

import json
import re

import emoji
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
from underthesea import word_tokenize

# Create the FastAPI application (user-facing metadata kept verbatim).
app = FastAPI(
    title="PhoBERT Emotion Classification API",
    description="API dự đoán cảm xúc của câu tiếng Việt sử dụng PhoBERT.",
    version="1.0"
)

###############################################################################
# EMOJI MAPPING — copied verbatim from the training script.
# Must stay in sync with training so inference sees the same tokens.
###############################################################################
emoji_mapping = {
    "😀": "[joy]", "😃": "[joy]", "😄": "[joy]", "😁": "[joy]",
    "😆": "[joy]", "😅": "[joy]", "😂": "[joy]", "🤣": "[joy]",
    "🙂": "[love]", "🙃": "[love]", "😉": "[love]", "😊": "[love]",
    "😇": "[love]", "🥰": "[love]", "😍": "[love]", "🤩": "[love]",
    "😘": "[love]", "😗": "[love]", "☺": "[love]", "😚": "[love]",
    "😙": "[love]",
    "😋": "[satisfaction]", "😛": "[satisfaction]", "😜": "[satisfaction]",
    "🤪": "[satisfaction]", "😝": "[satisfaction]", "🤑": "[satisfaction]",
    "🤐": "[neutral]", "🤨": "[neutral]", "😐": "[neutral]",
    "😑": "[neutral]", "😶": "[neutral]",
    "😏": "[sarcasm]",
    "😒": "[disappointment]", "🙄": "[disappointment]", "😬": "[disappointment]",
    "😔": "[sadness]", "😪": "[sadness]", "😢": "[sadness]",
    "😭": "[sadness]", "😥": "[sadness]", "😓": "[sadness]",
    "😩": "[tiredness]", "😫": "[tiredness]", "🥱": "[tiredness]",
    "🤤": "[discomfort]", "🤢": "[discomfort]", "🤮": "[discomfort]",
    "🤧": "[discomfort]", "🥵": "[discomfort]", "🥶": "[discomfort]",
    "🥴": "[discomfort]", "😵": "[discomfort]", "🤯": "[discomfort]",
    "😕": "[confused]", "😟": "[confused]", "🙁": "[confused]", "☹": "[confused]",
    "😮": "[surprise]", "😯": "[surprise]", "😲": "[surprise]", "😳": "[surprise]",
    "🥺": "[pleading]",
    "😦": "[fear]", "😧": "[fear]", "😨": "[fear]", "😰": "[fear]", "😱": "[fear]",
    "😖": "[confusion]", "😣": "[confusion]", "😞": "[confusion]",
    "😤": "[anger]", "😡": "[anger]", "😠": "[anger]", "🤬": "[anger]",
    "😈": "[mischievous]", "👿": "[mischievous]"
}

###############################################################################
# PREPROCESSING HELPERS (copied from the training script)
###############################################################################

def replace_emojis(sentence, emoji_mapping):
    """Replace known emojis with bracketed emotion tags; drop any other emoji.

    Characters that are emojis but not in *emoji_mapping* are removed;
    non-emoji characters pass through unchanged.
    """
    processed = []
    for char in sentence:
        if char in emoji_mapping:
            processed.append(emoji_mapping[char])
        elif not emoji.is_emoji(char):
            processed.append(char)
    return ''.join(processed)


# Hoisted to module level so the list is not rebuilt on every call.
# BUG FIX: the multi-word entry "đù mé" could never match in the original,
# which compared whitespace-split single words against the list; multi-word
# phrases are now removed with a phrase-level substitution first.
_PROFANE_WORDS = frozenset(
    ["loz", "vloz", "vl", "dm", "đm", "clgt", "dmm", "cc", "vc", "vãi"]
)
_PROFANE_PHRASE_RE = re.compile(re.escape("đù mé"), flags=re.IGNORECASE)


def remove_profanity(sentence):
    """Remove profane words and phrases (case-insensitive) from *sentence*."""
    sentence = _PROFANE_PHRASE_RE.sub("", sentence)
    words = sentence.split()
    return ' '.join(w for w in words if w.lower() not in _PROFANE_WORDS)


def remove_special_characters(sentence):
    """Strip a fixed set of symbol characters (^ * @ # & $ % < > ~ { } | \\)."""
    return re.sub(r"[\^\*@#&$%<>~{}|\\]", "", sentence)


def normalize_whitespace(sentence):
    """Collapse runs of whitespace to single spaces and trim both ends."""
    return ' '.join(sentence.split())


def remove_repeated_characters(sentence):
    """Collapse any character repeated 3+ times in a row to one occurrence."""
    return re.sub(r"(.)\1{2,}", r"\1", sentence)


def replace_numbers(sentence):
    """Replace every run of digits with the literal token "[number]"."""
    return re.sub(r"\d+", "[number]", sentence)


def tokenize_underthesea(sentence):
    """Word-segment Vietnamese text with underthesea and rejoin with spaces."""
    return " ".join(word_tokenize(sentence))

# If abbreviations.json exists, load it; otherwise fall back to an empty mapping.
# Best-effort load of the abbreviation-expansion table. Narrowed from a broad
# `except Exception` so only the expected failure modes (file missing/unreadable
# or invalid JSON) fall back silently — programming errors still surface.
try:
    with open("abbreviations.json", "r", encoding="utf-8") as f:
        abbreviations = json.load(f)
except (OSError, json.JSONDecodeError):
    abbreviations = {}


def preprocess_sentence(sentence):
    """Apply the full training-time preprocessing pipeline to one sentence.

    The order mirrors the training script and matters: lowercase, emoji tags,
    profanity removal, special-character removal, whitespace normalization,
    abbreviation expansion, repeated-character collapse, number replacement,
    and finally underthesea word segmentation.
    """
    sentence = sentence.lower()
    sentence = replace_emojis(sentence, emoji_mapping)
    sentence = remove_profanity(sentence)
    sentence = remove_special_characters(sentence)
    sentence = normalize_whitespace(sentence)
    # Expand abbreviations when a mapping entry exists.
    # NOTE(review): `" ".join(abbreviations[w])` assumes each value is a LIST
    # of words; if values are plain strings this space-separates the individual
    # characters — confirm against the contents of abbreviations.json.
    words = sentence.split()
    replaced = []
    for w in words:
        if w in abbreviations:
            replaced.append(" ".join(abbreviations[w]))
        else:
            replaced.append(w)
    sentence = " ".join(replaced)
    sentence = remove_repeated_characters(sentence)
    sentence = replace_numbers(sentence)
    sentence = tokenize_underthesea(sentence)
    return sentence

###############################################################################
# LOAD CHECKPOINT
###############################################################################
checkpoint_dir = "./checkpoint"  # folder holding the fine-tuned checkpoint
device = "cuda" if torch.cuda.is_available() else "cpu"

print("Loading config...")
config = AutoConfig.from_pretrained(checkpoint_dir)

# id -> label mapping in the order used during training.
custom_id2label = {
    0: 'Anger',
    1: 'Disgust',
    2: 'Enjoyment',
    3: 'Fear',
    4: 'Other',
    5: 'Sadness',
    6: 'Surprise'
}

# Prefer the checkpoint's own id2label unless it only contains the generic
# "LABEL_i" placeholders, in which case fall back to the custom mapping.
# Keys are coerced to int because configs may serialize them as strings.
if hasattr(config, "id2label") and config.id2label:
    if all(label.startswith("LABEL_") for label in config.id2label.values()):
        id2label = custom_id2label
    else:
        id2label = {int(k): v for k, v in config.id2label.items()}
else:
    id2label = custom_id2label
print("id2label loaded:", id2label)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
print("Loading model...")
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir, config=config)
model.to(device)
model.eval()

###############################################################################
# PREDICTION
###############################################################################
# Per-label advice appended to the prediction (user-facing Vietnamese strings
# kept verbatim).
label2message = {
    'Anger': 'Hãy bình tĩnh và giải quyết vấn đề một cách bình thản.',
    'Disgust': 'Hãy tránh xa những thứ khiến bạn không thích.',
    'Enjoyment': 'Chúc mừng bạn có một ngày tuyệt vời!',
    'Fear': 'Hãy đối mặt với nỗi sợ để vượt qua chúng.',
    'Other': 'Cảm xúc của bạn hiện tại không được phân loại rõ ràng.',
    'Sadness': 'Hãy tìm kiếm sự hỗ trợ khi cần thiết.',
    'Surprise': 'Thật bất ngờ! Hãy tận hưởng khoảnh khắc này.'
}


def predict_text(text: str) -> str:
    """Preprocess *text*, run the classifier, and return a formatted message.

    Returns a Vietnamese sentence naming the predicted emotion plus an
    optional per-label advice string, or an "unknown label" message if the
    argmax id is not present in ``id2label``.
    """
    text_proc = preprocess_sentence(text)
    inputs = tokenizer(
        [text_proc],
        padding=True,
        truncation=True,
        max_length=256,
        return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    pred_id = outputs.logits.argmax(dim=-1).item()
    # Guard clause: unknown ids short-circuit before any message lookup.
    if pred_id not in id2label:
        return f"Nhãn không xác định (id={pred_id})"
    label = id2label[pred_id]
    message = label2message.get(label, "")
    if message:
        return f"Dự đoán cảm xúc: {label}. {message}"
    return f"Dự đoán cảm xúc: {label}."

###############################################################################
# REQUEST MODEL
###############################################################################
class InputText(BaseModel):
    # Raw Vietnamese sentence to classify.
    text: str

###############################################################################
# API ENDPOINT
###############################################################################
@app.post("/predict")
def predict(input_text: InputText):
    """
    Nhận một câu tiếng Việt và trả về dự đoán cảm xúc.
    """
    result = predict_text(input_text.text)
    return {"result": result}

###############################################################################
# RUN API SERVER
###############################################################################
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)