import torch from fastapi import FastAPI, HTTPException from pydantic import BaseModel import torch.nn as nn import pickle import re token_2_id = None # Load the dictionary later with open(r"vocab.pkl", "rb") as f: token_2_id = pickle.load(f) print(token_2_id) def normalize(text): text = text.lower() text = re.sub(r'[^a-z0-9\s]', '', text) text = ' '.join(text.split()) return text def tokenize(text): tokens = text.split() return tokens def convert_tokens_2_ids(tokens): input_ids = [ token_2_id.get(token, token_2_id['']) for token in tokens ] return input_ids def process_text(text, aspect): text_aspect_pair = text + ' ' + aspect normalized_text = normalize(text_aspect_pair) tokens = tokenize(normalized_text) input_ids = convert_tokens_2_ids(tokens) input_ids = torch.tensor(input_ids).unsqueeze(0) return input_ids class ABSA(nn.Module): def __init__(self, vocab_size, num_labels=3): super(ABSA, self).__init__() self.vocab_size = vocab_size self.num_labels = num_labels self.embedding_layer = nn.Embedding( num_embeddings=vocab_size, embedding_dim=256 ) self.lstm_layer = nn.LSTM( input_size=256, hidden_size=512, batch_first=True, ) self.fc_layer = nn.Linear( in_features=512, out_features=self.num_labels ) def forward(self, x): embeddings = self.embedding_layer(x) lstm_out, _ = self.lstm_layer(embeddings) logits = self.fc_layer(lstm_out[:, -1, :]) return logits model = ABSA(vocab_size=len(token_2_id.keys()), num_labels=3) model.load_state_dict(torch.load('model_weights.pth')) model.eval() print("Model loaded successfully") app = FastAPI() # Root endpoint @app.get("/") def greet_json(): return {"Hello": "World!"} # Input model for request validation class TextAspectInput(BaseModel): text: str aspect: str # Sentiment labels SENTIMENT_LABELS = {0: "Negative", 1: "Neutral", 2: "Positive"} # Predict endpoint @app.post("/predict") async def predict_sentiment(input_data: TextAspectInput): print(input_data) try: # Extract text and aspect text = input_data.text aspect = input_data.aspect # Process input input_ids = process_text(text, aspect) print("Process text: ", input_ids) # Make prediction try: with torch.no_grad(): logits = model(input_ids) probabilities = torch.softmax(logits, dim=-1) prediction = probabilities.argmax(dim=-1).item() sentiment = SENTIMENT_LABELS[prediction] except Exception as e: print(e) return {"sentiment": sentiment, "probabilities": probabilities.squeeze().tolist()} except Exception as e: raise HTTPException(status_code=500, detail=str(e))