Spaces:
Sleeping
Sleeping
# demo_phobert_gradio.py | |
# -*- coding: utf-8 -*- | |
import gradio as gr | |
import torch | |
import re | |
import json | |
import emoji | |
import numpy as np | |
from underthesea import word_tokenize | |
from transformers import ( | |
AutoConfig, | |
AutoTokenizer, | |
AutoModelForSequenceClassification | |
) | |
############################################################################### | |
# TẢI MAPPING EMOJI - COPY Y NGUYÊN TỪ FILE TRAIN | |
############################################################################### | |
emoji_mapping = { | |
"😀": "[joy]", "😃": "[joy]", "😄": "[joy]", "😁": "[joy]", "😆": "[joy]", "😅": "[joy]", "😂": "[joy]", "🤣": "[joy]", | |
"🙂": "[love]", "🙃": "[love]", "😉": "[love]", "😊": "[love]", "😇": "[love]", "🥰": "[love]", "😍": "[love]", | |
"🤩": "[love]", "😘": "[love]", "😗": "[love]", "☺": "[love]", "😚": "[love]", "😙": "[love]", | |
"😋": "[satisfaction]", "😛": "[satisfaction]", "😜": "[satisfaction]", "🤪": "[satisfaction]", "😝": "[satisfaction]", | |
"🤑": "[satisfaction]", | |
"🤐": "[neutral]", "🤨": "[neutral]", "😐": "[neutral]", "😑": "[neutral]", "😶": "[neutral]", | |
"😏": "[sarcasm]", | |
"😒": "[disappointment]", "🙄": "[disappointment]", "😬": "[disappointment]", | |
"😔": "[sadness]", "😪": "[sadness]", "😢": "[sadness]", "😭": "[sadness]", "😥": "[sadness]", "😓": "[sadness]", | |
"😩": "[tiredness]", "😫": "[tiredness]", "🥱": "[tiredness]", | |
"🤤": "[discomfort]", "🤢": "[discomfort]", "🤮": "[discomfort]", "🤧": "[discomfort]", "🥵": "[discomfort]", | |
"🥶": "[discomfort]", "🥴": "[discomfort]", "😵": "[discomfort]", "🤯": "[discomfort]", | |
"😕": "[confused]", "😟": "[confused]", "🙁": "[confused]", "☹": "[confused]", | |
"😮": "[surprise]", "😯": "[surprise]", "😲": "[surprise]", "😳": "[surprise]", "🥺": "[pleading]", | |
"😦": "[fear]", "😧": "[fear]", "😨": "[fear]", "😰": "[fear]", "😱": "[fear]", | |
"😖": "[confusion]", "😣": "[confusion]", "😞": "[confusion]", | |
"😤": "[anger]", "😡": "[anger]", "😠": "[anger]", "🤬": "[anger]", "😈": "[mischievous]", "👿": "[mischievous]" | |
} | |
############################################################################### | |
# HÀM XỬ LÝ (COPY TỪ FILE TRAIN) | |
############################################################################### | |
def replace_emojis(sentence, emoji_mapping): | |
processed_sentence = [] | |
for char in sentence: | |
if char in emoji_mapping: | |
processed_sentence.append(emoji_mapping[char]) | |
elif not emoji.is_emoji(char): | |
processed_sentence.append(char) | |
return ''.join(processed_sentence) | |
def remove_profanity(sentence): | |
profane_words = ["loz", "vloz", "vl", "dm", "đm", "clgt", "dmm", "cc", "vc", "đù mé", "vãi"] | |
words = sentence.split() | |
filtered = [w for w in words if w.lower() not in profane_words] | |
return ' '.join(filtered) | |
def remove_special_characters(sentence): | |
return re.sub(r"[\^\*@#&$%<>~{}|\\]", "", sentence) | |
def normalize_whitespace(sentence): | |
return ' '.join(sentence.split()) | |
def remove_repeated_characters(sentence): | |
return re.sub(r"(.)\1{2,}", r"\1", sentence) | |
def replace_numbers(sentence): | |
return re.sub(r"\d+", "[number]", sentence) | |
def tokenize_underthesea(sentence): | |
tokens = word_tokenize(sentence) | |
return " ".join(tokens) | |
# Nếu có abbreviations.json, bạn load. Nếu không thì để rỗng. | |
try: | |
with open("abbreviations.json", "r", encoding="utf-8") as f: | |
abbreviations = json.load(f) | |
except: | |
abbreviations = {} | |
def preprocess_sentence(sentence): | |
# hạ thấp | |
sentence = sentence.lower() | |
# thay thế emoji | |
sentence = replace_emojis(sentence, emoji_mapping) | |
# loại bỏ từ nhạy cảm | |
sentence = remove_profanity(sentence) | |
# bỏ ký tự đặc biệt | |
sentence = remove_special_characters(sentence) | |
# chuẩn hoá khoảng trắng | |
sentence = normalize_whitespace(sentence) | |
# thay thế viết tắt | |
words = sentence.split() | |
replaced = [] | |
for w in words: | |
if w in abbreviations: | |
replaced.append(" ".join(abbreviations[w])) | |
else: | |
replaced.append(w) | |
sentence = " ".join(replaced) | |
# bỏ bớt kí tự lặp | |
sentence = remove_repeated_characters(sentence) | |
# thay số thành [number] | |
sentence = replace_numbers(sentence) | |
# tokenize tiếng Việt | |
sentence = tokenize_underthesea(sentence) | |
return sentence | |
############################################################################### | |
# LOAD CHECKPOINT | |
############################################################################### | |
checkpoint_dir = "./checkpoint" # Folder checkpoint nằm trong cùng thư mục với file script | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
print("Loading config...") | |
config = AutoConfig.from_pretrained(checkpoint_dir) | |
# Mapping id to label theo thứ tự bạn cung cấp | |
custom_id2label = { | |
0: 'Anger', | |
1: 'Disgust', | |
2: 'Enjoyment', | |
3: 'Fear', | |
4: 'Other', | |
5: 'Sadness', | |
6: 'Surprise' | |
} | |
# Kiểm tra và sử dụng custom_id2label nếu config.id2label không đúng | |
if hasattr(config, "id2label") and config.id2label: | |
# Nếu config.id2label chứa 'LABEL_x', sử dụng custom mapping | |
if all(label.startswith("LABEL_") for label in config.id2label.values()): | |
id2label = custom_id2label | |
else: | |
id2label = {int(k): v for k, v in config.id2label.items()} | |
else: | |
id2label = custom_id2label # Sử dụng mapping mặc định nếu config không có id2label | |
print("id2label loaded:", id2label) | |
print("Loading tokenizer...") | |
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir) | |
print("Loading model...") | |
model = AutoModelForSequenceClassification.from_pretrained(checkpoint_dir, config=config) | |
model.to(device) | |
model.eval() | |
############################################################################### | |
# HÀM PREDICT | |
############################################################################### | |
# Mapping từ label đến thông điệp tương ứng | |
label2message = { | |
'Anger': 'Hãy bình tĩnh và giải quyết vấn đề một cách bình thản.', | |
'Disgust': 'Hãy tránh xa những thứ khiến bạn không thích.', | |
'Enjoyment': 'Chúc mừng bạn có một ngày tuyệt vời!', | |
'Fear': 'Hãy đối mặt với nỗi sợ để vượt qua chúng.', | |
'Other': 'Cảm xúc của bạn hiện tại không được phân loại rõ ràng.', | |
'Sadness': 'Hãy tìm kiếm sự hỗ trợ khi cần thiết.', | |
'Surprise': 'Thật bất ngờ! Hãy tận hưởng khoảnh khắc này.' | |
} | |
def predict_text(text: str) -> str: | |
"""Tiền xử lý, token hoá và chạy model => trả về label và thông điệp.""" | |
text_proc = preprocess_sentence(text) | |
inputs = tokenizer( | |
[text_proc], | |
padding=True, | |
truncation=True, | |
max_length=256, | |
return_tensors="pt" | |
).to(device) | |
with torch.no_grad(): | |
outputs = model(**inputs) | |
pred_id = outputs.logits.argmax(dim=-1).item() | |
if pred_id in id2label: | |
label = id2label[pred_id] | |
message = label2message.get(label, "") | |
if message: | |
return f"Dự đoán cảm xúc: {label}. {message}" | |
else: | |
return f"Dự đoán cảm xúc: {label}." | |
else: | |
return f"Nhãn không xác định (id={pred_id})" | |
############################################################################### | |
# GRADIO APP | |
############################################################################### | |
def run_demo(input_text): | |
predicted_emotion = predict_text(input_text) | |
return predicted_emotion | |
demo = gr.Interface( | |
fn=run_demo, | |
inputs=gr.Textbox(lines=3, label="Nhập câu tiếng Việt"), | |
outputs=gr.Textbox(label="Kết quả"), | |
title="PhoBERT Emotion Classification", | |
description="Nhập vào 1 câu tiếng Việt để dự đoán cảm xúc." | |
) | |
if __name__ == "__main__": | |
demo.launch(share=True) | |