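# FastAPI service wrapping a PhoBERT + BiLSTM classifier for Vietnamese
# reviews: a 3-class emotion head plus one binary head per review aspect.
# HF_HOME is set before transformers is imported so the cache location
# applies to every model and tokenizer download below.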
import os
os.environ.setdefault("HF_HOME", "/app/hf_cache")
import logging
from fastapi import FastAPI, Request
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
import uvicorn
from huggingface_hub import hf_hub_download
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
templates = Jinja2Templates(directory="templates")
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
HF_REPO = "ThanhDT127/pho-bert-bilstm"
HF_FILE = "best_model_1.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_DIR = "models"
os.makedirs(MODEL_DIR, exist_ok=True)
MODEL_PATH = os.path.join(MODEL_DIR, HF_FILE)
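# Reuse a previously fetched checkpoint if present; otherwise pull it from
# the Hub (hf_hub_download caches, so the file is only downloaded once).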
try:
    if not os.path.isfile(MODEL_PATH):
        logger.info("Downloading model from Hugging Face Hub")
        MODEL_PATH = hf_hub_download(
            repo_id=HF_REPO,
            filename=HF_FILE,
            cache_dir=os.environ["HF_HOME"],
            token=HF_TOKEN
        )
    logger.info("Loading model from %s", MODEL_PATH)
    model_state_dict = torch.load(MODEL_PATH, map_location=device)
    logger.info("Model loaded successfully")
except Exception as e:
    logger.error("Error loading model: %s", str(e))
    raise
class TextInput(BaseModel):
    text: str
class BertBiLSTMClassifier(nn.Module):
    def __init__(self, bert_model_name, num_emotion_classes, binary_cols, lstm_hidden_size=128, dropout=0.4):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.lstm = nn.LSTM(
            input_size=self.bert.config.hidden_size,
            hidden_size=lstm_hidden_size,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )
        self.dropout = nn.Dropout(dropout)
        # Shared BiLSTM trunk feeding two kinds of heads: one softmax head
        # over emotion classes and one sigmoid head per binary aspect column.
        self.emotion_fc = nn.Linear(lstm_hidden_size * 2, num_emotion_classes)
        self.binary_fcs = nn.ModuleDict({
            col: nn.Linear(lstm_hidden_size * 2, 1)
            for col in binary_cols
        })

    def forward(self, input_ids, attention_mask):
        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        seq_out = bert_out.last_hidden_state
        lstm_out, _ = self.lstm(seq_out)
        # Pool the BiLSTM's last timestep as the sequence representation.
        last_hidden = lstm_out[:, -1, :]
        dropped = self.dropout(last_hidden)
        emo_logits = self.emotion_fc(dropped)
        bin_logits = {
            col: self.binary_fcs[col](dropped).squeeze(-1)
            for col in self.binary_fcs
        }
        return emo_logits, bin_logits
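# Shape sketch for one forward pass (illustrative; assumes phobert-base's
# 768-dim hidden states and the default lstm_hidden_size=128):
#   input_ids           (batch, seq_len)
#   BERT last_hidden    (batch, seq_len, 768)
#   BiLSTM output       (batch, seq_len, 256)   # 2 * lstm_hidden_size
#   emo_logits          (batch, num_emotion_classes)
#   bin_logits[col]     (batch,)                # one logit per aspect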
# PhoBERT ships only a slow (Python) tokenizer, hence use_fast=False.
tokenizer = AutoTokenizer.from_pretrained(
    "vinai/phobert-base",
    use_fast=False,
    cache_dir=os.environ["HF_HOME"]
)
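# Aspect labels, kept in Vietnamese because they are the model's dict keys:
# product, price, shipping, customer attitude/service, other.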
binary_cols = [
    'sản phẩm', 'giá cả', 'vận chuyển',
    'thái độ và dịch vụ khách hàng', 'khác'
]
model = BertBiLSTMClassifier(
    bert_model_name="vinai/phobert-base",
    num_emotion_classes=3,
    binary_cols=binary_cols,
    lstm_hidden_size=128
).to(device)
# Load model state dict
model.load_state_dict(model_state_dict)
model.eval()
# Per-aspect decision thresholds (presumably tuned on validation data,
# hence the values differing from a uniform 0.5).
threshold_dict = {
    'sản phẩm': 0.28,
    'giá cả': 0.58,
    'vận chuyển': 0.58,
    'thái độ và dịch vụ khách hàng': 0.70,
    'khác': 0.6
}
def predict(text: str):
    logger.info("Starting prediction for text: %s", text)
    try:
        enc = tokenizer(
            text, add_special_tokens=True, max_length=128,
            padding='max_length', truncation=True, return_tensors='pt'
        )
        input_ids = enc['input_ids'].to(device)
        attention_mask = enc['attention_mask'].to(device)
        with torch.no_grad():
            emo_logits, bin_logits = model(input_ids, attention_mask)
            emo_pred = torch.argmax(emo_logits, dim=1).item()
            # Each aspect head is cut at its own threshold instead of 0.5.
            bin_pred = {
                col: (torch.sigmoid(bin_logits[col]) > threshold_dict[col]).float().item()
                for col in binary_cols
            }
        # Emotion labels: 'tiêu cực' = negative, 'trung tính' = neutral,
        # 'tích cực' = positive; aspect answers: 'có' = yes, 'không' = no.
        emo_label = ['tiêu cực', 'trung tính', 'tích cực'][emo_pred]
        bin_labels = {col: ('có' if bin_pred[col] == 1 else 'không') for col in binary_cols}
        logger.info("Prediction completed: emotion=%s, binary=%s", emo_label, bin_labels)
        return emo_label, bin_labels
    except Exception as e:
        logger.error("Error during prediction: %s", str(e))
        raise
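# Illustrative direct call (hypothetical sample sentence, not from any dataset):
#   emotion, aspects = predict("Giao hàng nhanh, sản phẩm đúng mô tả")
#   # emotion -> e.g. 'tích cực' (positive)
#   # aspects -> e.g. {'sản phẩm': 'có', 'vận chuyển': 'có', 'giá cả': 'không', ...}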
@app.get("/")
async def read_root(request: Request):
logger.info("Received GET request for /")
try:
response = templates.TemplateResponse("index.html", {"request": request})
logger.info("Successfully rendered index.html")
return response
except Exception as e:
logger.error("Error rendering index.html: %s", str(e))
raise
@app.post("/predict")
async def predict_text(input: TextInput):
logger.info("Received POST request for /predict with input: %s", input.text)
try:
emotion, binary = predict(input.text)
logger.info("Sending prediction response: emotion=%s, binary=%s", emotion, binary)
return {"emotion": emotion, "binary": binary}
except Exception as e:
logger.error("Error in predict_text endpoint: %s", str(e))
raise
if __name__ == "__main__":
    port = int(os.getenv("PORT", 8000))
    logger.info("Starting Uvicorn server on port %d", port)
    uvicorn.run("main:app", host="0.0.0.0", port=port)
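# Example request once the server is running (assumes the default port 8000;
# the sample text is hypothetical):
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Sản phẩm rất tốt"}'
# Response shape: {"emotion": "...", "binary": {"sản phẩm": "có" | "không", ...}}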