"""FastAPI service scaffold around an LSTM text classifier (``ABSA``).

At import time this module loads a pickled token -> id vocabulary and the
model weights from the current working directory, then creates the FastAPI
application object ``app``.
"""

import pickle
import re

import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel

# Key used to look up the id assigned to out-of-vocabulary tokens.
# NOTE(review): the key is the empty string; this looks like it may have been
# a special token such as "<unk>" that was lost somewhere - confirm against
# the contents of vocab.pkl before changing it.
UNKNOWN_TOKEN = ''

# Matches every character that is not a lowercase ASCII letter, a digit or
# whitespace. Compiled once because normalize() runs on every request.
_DISALLOWED_CHARS = re.compile(r'[^a-z0-9\s]')

token_2_id = None

# Load the token -> id dictionary.
# The path is relative to the current working directory. It is written
# without the Windows-only ".\\" prefix so it resolves on POSIX hosts too.
# NOTE: pickle can execute arbitrary code while loading; only ship a
# vocab.pkl that comes from a trusted source.
with open("vocab.pkl", "rb") as f:
    token_2_id = pickle.load(f)

# NOTE(review): this dumps the entire vocabulary to stdout on every start-up.
print(token_2_id)


def normalize(text: str) -> str:
    """Return *text* lowercased, stripped of punctuation and re-spaced.

    Every character other than ``a-z``, ``0-9`` and whitespace is removed,
    then runs of whitespace are collapsed to single spaces and leading /
    trailing whitespace is dropped.
    """
    text = text.lower()
    text = _DISALLOWED_CHARS.sub('', text)
    return ' '.join(text.split())


def tokenize(text: str) -> list:
    """Split *text* on whitespace and return the list of tokens."""
    return text.split()


def convert_tokens_2_ids(tokens: list) -> list:
    """Map each token to its vocabulary id.

    Tokens missing from ``token_2_id`` receive the id stored under
    ``UNKNOWN_TOKEN``. That fallback is looked up for every token, so a
    vocabulary without the ``UNKNOWN_TOKEN`` key raises ``KeyError`` as soon
    as *tokens* is non-empty.
    """
    return [
        token_2_id.get(token, token_2_id[UNKNOWN_TOKEN])
        for token in tokens
    ]


def process_text(text: str, aspect: str) -> torch.Tensor:
    """Turn a (text, aspect) pair into a batch of token ids for the model.

    The aspect is appended to the text separated by a single space, and the
    result is normalized, tokenized and converted to ids.

    Returns:
        An integer tensor of shape ``(1, num_tokens)``.
    """
    text_aspect_pair = text + ' ' + aspect
    tokens = tokenize(normalize(text_aspect_pair))
    input_ids = convert_tokens_2_ids(tokens)
    # Pin the dtype: without it an empty token list would produce a float
    # tensor, which nn.Embedding cannot index with.
    return torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)


class ABSA(nn.Module):
    """Embedding -> LSTM -> linear classifier over token-id sequences.

    Args:
        vocab_size: Number of rows in the embedding table.
        num_labels: Number of output classes (size of the logits vector).
    """

    def __init__(self, vocab_size, num_labels=3):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_labels = num_labels
        self.embedding_layer = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=256,
        )
        self.lstm_layer = nn.LSTM(
            input_size=256,
            hidden_size=512,
            batch_first=True,
        )
        self.fc_layer = nn.Linear(
            in_features=512,
            out_features=self.num_labels,
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_labels)``.

        Args:
            x: Integer token ids laid out as ``(batch, seq_len)``
                (the LSTM is configured with ``batch_first=True``).
        """
        embeddings = self.embedding_layer(x)
        lstm_out, _ = self.lstm_layer(embeddings)
        # Classify from the LSTM output at the last time step only.
        logits = self.fc_layer(lstm_out[:, -1, :])
        return logits


model = ABSA(vocab_size=len(token_2_id), num_labels=3)
# load_state_dict() mutates the module in place and returns a report of
# missing / unexpected keys, so its result must NOT be assigned back to
# `model`. map_location lets weights saved on a GPU load on a CPU-only host.
model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))
# The service only runs inference, so keep the module in evaluation mode.
model.eval()
print("Model loaded successfully")

app = FastAPI()


# Root endpoint
@app.get("/")
def greet_json():
    """Return a static greeting payload."""
    return {"Hello": "World!"}


# NOTE(review): the endpoint below is disabled. It operates on an uploaded
# image and calls `preprocess_image`, which is not defined in this file.
#
# # Predict endpoint for JSON input
# @app.post("/predict")
# async def predict_image(file: UploadFile = File(...)):
#     try:
#         # Read and preprocess the uploaded image
#         image = Image.open(file.file)
#         image = preprocess_image(image)
#
#         # Make prediction
#         model.eval()
#         with torch.no_grad():
#             output = model(image)
#             prediction = output.argmax(dim=1).item()
#
#         return JSONResponse(content={"prediction": f"The digit is {prediction}"})
#     except Exception as e:
#         return JSONResponse(content={"error": str(e)}, status_code=500)