# Spaces: Sleeping — page-header text captured along with this file; not part of the program.
import torch | |
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
import torch.nn as nn | |
import pickle | |
import re | |
# Token -> integer-id vocabulary used to encode model inputs. It is loaded
# eagerly at import time because the encoding helpers and the model's
# embedding size (len(token_2_id)) both depend on it.
# NOTE(review): pickle.load can execute arbitrary code. This is acceptable only
# because vocab.pkl ships with the app; never point it at user-supplied files.
with open("vocab.pkl", "rb") as f:
    token_2_id = pickle.load(f)
# Report the vocabulary size instead of dumping the whole mapping to stdout.
print(f"Loaded vocabulary with {len(token_2_id)} tokens")
_STRIP_PATTERN = re.compile(r'[^a-z0-9\s]')


def normalize(text):
    """Canonicalize free text before tokenization.

    Steps:
    - Lower-case the input.
    - Delete every character that is not an ASCII lowercase letter, a digit,
      or whitespace. Punctuation and accented letters disappear.
    - Collapse runs of whitespace into single spaces and trim both ends.
    """
    cleaned = _STRIP_PATTERN.sub('', text.lower())
    return ' '.join(cleaned.split())
def tokenize(text):
    """Split already-normalized text into whitespace-delimited tokens."""
    return text.split()
def convert_tokens_2_ids(tokens, vocab=None):
    """Map tokens to integer ids, using the '<UNK>' id for unknown tokens.

    Args:
        tokens: iterable of token strings.
        vocab: token -> id mapping. Defaults to the module-level
            ``token_2_id`` loaded from vocab.pkl. Must contain '<UNK>'.

    Returns:
        list[int] with one id per input token, in order.

    Raises:
        KeyError: if ``vocab`` has no '<UNK>' entry.
    """
    if vocab is None:
        vocab = token_2_id
    # Look up the fallback id once. Passing vocab['<UNK>'] as the .get()
    # default would re-evaluate it for every token.
    unk_id = vocab['<UNK>']
    return [vocab.get(token, unk_id) for token in tokens]
def process_text(text, aspect):
    """Encode a (text, aspect) pair as a model-ready batch of token ids.

    Processing steps:
    - Append the aspect to the text, separated by a space.
    - Normalize the combined string and split it on whitespace.
    - Map each token to its vocabulary id (unknown tokens become the <UNK> id).

    Returns:
        torch.LongTensor of shape (1, num_tokens): a batch of one sequence.

    Raises:
        ValueError: if nothing survives normalization, e.g. both inputs are
            empty or punctuation-only. An empty id list would otherwise
            become a float tensor and fail deep inside the embedding layer.
    """
    combined = text + ' ' + aspect
    tokens = tokenize(normalize(combined))
    if not tokens:
        raise ValueError("text and aspect contain no usable tokens after normalization")
    input_ids = convert_tokens_2_ids(tokens)
    # Explicit integer dtype: nn.Embedding requires integer indices.
    return torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)
class ABSA(nn.Module):
    """Aspect-based sentiment classifier: embedding -> single-layer LSTM -> linear head.

    The LSTM output at the last time step of each sequence is projected to
    ``num_labels`` logits.

    Args:
        vocab_size: number of rows in the embedding table.
        num_labels: number of output classes.
        embedding_dim: size of each token embedding. The default of 256
            presumably matches the shipped checkpoint; confirm before changing.
        hidden_size: LSTM hidden-state size. The default is 512.

    The submodule attribute names (``embedding_layer``, ``lstm_layer``,
    ``fc_layer``) become the state-dict keys. Renaming them breaks loading of
    saved checkpoints.
    """

    def __init__(self, vocab_size, num_labels=3, embedding_dim=256, hidden_size=512):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_labels = num_labels
        self.embedding_layer = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embedding_dim
        )
        self.lstm_layer = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            batch_first=True,
        )
        self.fc_layer = nn.Linear(
            in_features=hidden_size,
            out_features=self.num_labels,
        )

    def forward(self, x):
        """Return logits of shape (batch, num_labels).

        ``x`` holds integer token ids of shape (batch, seq_len). It must be
        2-D because of ``batch_first=True`` and the ``[:, -1, :]`` indexing,
        and seq_len must be >= 1. For right-padded batches the last step may
        be padding.
        """
        embeddings = self.embedding_layer(x)
        lstm_out, _ = self.lstm_layer(embeddings)
        # Classify from the hidden state at the final time step.
        logits = self.fc_layer(lstm_out[:, -1, :])
        return logits
# Build the network with the vocabulary size from vocab.pkl. It must match the
# checkpoint's embedding table, or load_state_dict will fail. Then restore the
# trained weights.
model = ABSA(vocab_size=len(token_2_id), num_labels=3)
# map_location="cpu" lets a checkpoint saved on a GPU load on CPU-only hosts.
model.load_state_dict(torch.load('model_weights.pth', map_location="cpu"))
# Switch to evaluation mode before serving predictions.
model.eval()
print("Model loaded successfully")
app = FastAPI()


# Root endpoint. The route decorator was missing, so this handler was never
# registered with the app.
@app.get("/")
def greet_json():
    """Return a fixed JSON payload; handy as a quick check that the service is up."""
    return {"Hello": "World!"}
# Input model for request validation: the JSON body of a prediction request.
# Both fields are required (no defaults), so pydantic rejects requests that
# omit either one.
class TextAspectInput(BaseModel):
    # The review or sentence to analyse.
    text: str
    # The aspect term whose sentiment is requested. The prediction handler
    # passes it to process_text, which appends it to `text` before tokenizing.
    aspect: str
# Class index -> human-readable sentiment label, built from the ordered label
# names. NOTE(review): the order is assumed to match the label encoding used
# during training; confirm against the training code.
SENTIMENT_LABELS = dict(enumerate(("Negative", "Neutral", "Positive")))
# Predict endpoint.
# NOTE(review): SOURCE has no route decorator, so the handler was never
# registered. "/predict" is inferred from the comment above; confirm the path
# that clients actually call.
@app.post("/predict")
async def predict_sentiment(input_data: TextAspectInput):
    """Classify the sentiment expressed toward ``aspect`` in ``text``.

    Returns:
        A JSON object with two keys:
        - "sentiment": the predicted label from SENTIMENT_LABELS.
        - "probabilities": the softmax probability of every class, in
          class-index order.

    Raises:
        HTTPException(422): the input has no usable tokens after
            normalization (``process_text`` raised ValueError).
        HTTPException(500): any other failure during encoding or inference.
    """
    try:
        input_ids = process_text(input_data.text, input_data.aspect)
    except ValueError as e:
        # Bad client input, not a server fault.
        raise HTTPException(status_code=422, detail=str(e)) from e
    except Exception as e:
        print(f"Failed to encode input: {e!r}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    try:
        # NOTE(review): inference is synchronous and runs on the event loop
        # because this handler is async. That is fine for a small LSTM; a
        # plain `def` handler would move it to FastAPI's threadpool.
        with torch.no_grad():
            logits = model(input_ids)
        probabilities = torch.softmax(logits, dim=-1)
        prediction = probabilities.argmax(dim=-1).item()
        sentiment = SENTIMENT_LABELS[prediction]
    except Exception as e:
        # Fail loudly. Swallowing the error here would leave `sentiment` and
        # `probabilities` unbound and hide the real cause from the client.
        print(f"Inference failed: {e!r}")
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Batch of one: return the single row of class probabilities as a list.
    return {"sentiment": sentiment, "probabilities": probabilities[0].tolist()}