# NOTE: removed Hugging Face Spaces status-banner residue ("Spaces: Running")
# that was accidentally captured with the source — it is not part of the code.
import os

# IMPORTANT: HF_HOME must be set *before* transformers / huggingface_hub are
# imported — both resolve their cache directory when first imported.
os.environ.setdefault("HF_HOME", "/app/hf_cache")

import logging

import torch
import torch.nn as nn
import uvicorn
from fastapi import FastAPI, Request
from fastapi.templating import Jinja2Templates
from huggingface_hub import hf_hub_download
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()
templates = Jinja2Templates(directory="templates")

# Location of the fine-tuned checkpoint on the Hugging Face Hub.
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
HF_REPO = "ThanhDT127/pho-bert-bilstm"
HF_FILE = "best_model_1.pth"

# Prefer GPU when available; everything below moves tensors to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Local directory checked first before downloading the checkpoint.
MODEL_DIR = "models"
os.makedirs(MODEL_DIR, exist_ok=True)
MODEL_PATH = os.path.join(MODEL_DIR, HF_FILE)
try:
    if not os.path.isfile(MODEL_PATH):
        logger.info("Downloading model from Hugging Face Hub")
        # FIX: dropped the `force_filename=` kwarg — it was deprecated and then
        # removed from hf_hub_download (huggingface_hub >= 0.20), where passing
        # it raises a TypeError. The returned path points into HF_HOME's cache.
        MODEL_PATH = hf_hub_download(
            repo_id=HF_REPO,
            filename=HF_FILE,
            cache_dir=os.environ["HF_HOME"],
            token=HF_TOKEN,
        )
    logger.info("Loading model from %s", MODEL_PATH)
    # weights_only=True: the checkpoint is consumed by load_state_dict below,
    # so it is a plain tensor state dict — avoid unpickling arbitrary objects
    # from a downloaded file (security hardening; supported since torch 1.13).
    model_state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
    logger.info("Model loaded successfully")
except Exception as e:
    # Startup cannot proceed without weights; log and re-raise so the
    # container fails loudly instead of serving a broken model.
    logger.error("Error loading model: %s", str(e))
    raise
class TextInput(BaseModel):
    """Request body for the /predict endpoint."""

    # Raw user text to classify (Vietnamese product-review text expected).
    text: str
class BertBiLSTMClassifier(nn.Module):
    """BERT encoder + bidirectional LSTM with two kinds of output heads.

    One softmax-style head (``emotion_fc``) produces multi-class emotion
    logits; one tiny linear head per entry of ``binary_cols`` produces a
    single logit for that aspect (multi-label).
    """

    def __init__(self, bert_model_name, num_emotion_classes, binary_cols,
                 lstm_hidden_size=128, dropout=0.4):
        super().__init__()
        # Pretrained transformer backbone (e.g. PhoBERT).
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.lstm = nn.LSTM(
            input_size=self.bert.config.hidden_size,
            hidden_size=lstm_hidden_size,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(dropout)
        # Bidirectional output concatenates forward + backward states.
        bi_size = lstm_hidden_size * 2
        self.emotion_fc = nn.Linear(bi_size, num_emotion_classes)
        self.binary_fcs = nn.ModuleDict(
            {name: nn.Linear(bi_size, 1) for name in binary_cols}
        )

    def forward(self, input_ids, attention_mask):
        """Return ``(emotion_logits, {aspect: logit_batch})``."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        lstm_states, _ = self.lstm(encoded.last_hidden_state)
        # NOTE(review): this pools the hidden state at the *last sequence
        # position*; with padding='max_length' that position is usually a pad
        # token, not the last real token. The saved weights were trained with
        # this behavior, so it is deliberately left unchanged.
        pooled = self.dropout(lstm_states[:, -1, :])
        emotion_logits = self.emotion_fc(pooled)
        aspect_logits = {
            name: head(pooled).squeeze(-1)
            for name, head in self.binary_fcs.items()
        }
        return emotion_logits, aspect_logits
# Slow (pure-Python) tokenizer is requested explicitly via use_fast=False.
tokenizer = AutoTokenizer.from_pretrained(
    "vinai/phobert-base",
    use_fast=False,
    cache_dir=os.environ["HF_HOME"],
)

# Aspect columns — one sigmoid head per entry in the classifier.
binary_cols = [
    'sản phẩm', 'giá cả', 'vận chuyển',
    'thái độ và dịch vụ khách hàng', 'khác'
]

model = BertBiLSTMClassifier(
    bert_model_name="vinai/phobert-base",
    num_emotion_classes=3,
    binary_cols=binary_cols,
    lstm_hidden_size=128,
).to(device)

# Restore the fine-tuned weights loaded at startup, then switch to
# inference mode (disables dropout).
model.load_state_dict(model_state_dict)
model.eval()

# Per-aspect decision thresholds — presumably tuned on a validation set
# (values come with the checkpoint; do not change independently of it).
threshold_dict = {
    'sản phẩm': 0.28,
    'giá cả': 0.58,
    'vận chuyển': 0.58,
    'thái độ và dịch vụ khách hàng': 0.70,
    'khác': 0.6,
}
def predict(text: str):
    """Classify one piece of text.

    Returns a tuple ``(emotion_label, aspect_labels)`` where
    ``emotion_label`` is one of 'tiêu cực' / 'trung tính' / 'tích cực' and
    ``aspect_labels`` maps each aspect column to 'có' or 'không'.
    Re-raises any exception after logging it.
    """
    logger.info("Starting prediction for text: %s", text)
    try:
        encoded = tokenizer(
            text, add_special_tokens=True, max_length=128,
            padding='max_length', truncation=True, return_tensors='pt'
        )
        ids = encoded['input_ids'].to(device)
        mask = encoded['attention_mask'].to(device)
        with torch.no_grad():
            emotion_logits, aspect_logits = model(ids, mask)
        emotion_idx = torch.argmax(emotion_logits, dim=1).item()
        emo_label = ['tiêu cực', 'trung tính', 'tích cực'][emotion_idx]
        bin_labels = {}
        for col in binary_cols:
            # Sigmoid probability compared against the per-aspect threshold.
            prob = torch.sigmoid(aspect_logits[col])
            hit = (prob > threshold_dict[col]).float().item()
            bin_labels[col] = 'có' if hit == 1 else 'không'
        logger.info("Prediction completed: emotion=%s, binary=%s", emo_label, bin_labels)
        return emo_label, bin_labels
    except Exception as e:
        logger.error("Error during prediction: %s", str(e))
        raise
@app.get("/")
async def read_root(request: Request):
    """Serve the landing page rendered from templates/index.html.

    FIX: restored the route decorator — without ``@app.get("/")`` this
    handler is never registered with FastAPI, so the endpoint the log
    messages describe ("GET request for /") would 404.
    """
    logger.info("Received GET request for /")
    try:
        response = templates.TemplateResponse("index.html", {"request": request})
        logger.info("Successfully rendered index.html")
        return response
    except Exception as e:
        logger.error("Error rendering index.html: %s", str(e))
        raise
@app.post("/predict")
async def predict_text(input: TextInput):
    """Run the classifier on the submitted text.

    FIX: restored the route decorator — without ``@app.post("/predict")``
    this handler is never registered, matching the "/predict" path named
    in its own log message.

    Returns ``{"emotion": <label>, "binary": {<aspect>: 'có'|'không'}}``.
    """
    logger.info("Received POST request for /predict with input: %s", input.text)
    try:
        emotion, binary = predict(input.text)
        logger.info("Sending prediction response: emotion=%s, binary=%s", emotion, binary)
        return {"emotion": emotion, "binary": binary}
    except Exception as e:
        logger.error("Error in predict_text endpoint: %s", str(e))
        raise
if __name__ == "__main__":
    # Honor the platform-injected PORT variable, defaulting to 8000.
    serve_port = int(os.getenv("PORT", 8000))
    logger.info("Starting Uvicorn server on port %d", serve_port)
    uvicorn.run("main:app", host="0.0.0.0", port=serve_port)