# MyDockerAPI / app.py
import torch
import torch.nn as nn
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import pickle
import re

# Load the token-to-id vocabulary built during training.
# A plain relative path is used so the file resolves inside the Docker container.
with open("vocab.pkl", "rb") as f:
    token_2_id = pickle.load(f)
print(f"Loaded vocabulary with {len(token_2_id)} tokens")
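# token_2_id maps token strings to integer ids; '<UNK>' is the fallback id
# for any token that was not seen when the vocabulary was built.
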
def normalize(text):
    # Lowercase, keep only letters, digits and whitespace, and collapse
    # repeated whitespace into single spaces.
    text = text.lower()
    text = re.sub(r'[^a-z0-9\s]', '', text)
    text = ' '.join(text.split())
    return text


def tokenize(text):
    # Simple whitespace tokenizer.
    tokens = text.split()
    return tokens

def convert_tokens_2_ids(tokens):
    input_ids = [
        token_2_id.get(token, token_2_id['<UNK>']) for token in tokens
    ]
    return input_ids

def process_text(text, aspect):
    # Concatenate the sentence and the aspect term, then normalize, tokenize
    # and convert to ids; returns a (1, seq_len) tensor ready for the model.
    text_aspect_pair = text + ' ' + aspect
    normalized_text = normalize(text_aspect_pair)
    tokens = tokenize(normalized_text)
    input_ids = convert_tokens_2_ids(tokens)
    input_ids = torch.tensor(input_ids).unsqueeze(0)
    return input_ids
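# Example (hypothetical input strings):
#   process_text("the battery lasts long but the screen is dim", "screen")
#   -> tensor of shape (1, seq_len) holding the token ids of the
#      normalized "text + aspect" string.
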
class ABSA(nn.Module):
    """Aspect-based sentiment classifier: embedding -> LSTM -> linear head."""

    def __init__(self, vocab_size, num_labels=3):
        super(ABSA, self).__init__()
        self.vocab_size = vocab_size
        self.num_labels = num_labels
        self.embedding_layer = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=256
        )
        self.lstm_layer = nn.LSTM(
            input_size=256,
            hidden_size=512,
            batch_first=True,
        )
        self.fc_layer = nn.Linear(
            in_features=512,
            out_features=self.num_labels
        )

    def forward(self, x):
        embeddings = self.embedding_layer(x)
        lstm_out, _ = self.lstm_layer(embeddings)
        # Classify from the hidden state of the last timestep.
        logits = self.fc_layer(lstm_out[:, -1, :])
        return logits
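# Shape flow for a single example (sequence length T):
#   input_ids (1, T) -> embeddings (1, T, 256) -> lstm_out (1, T, 512)
#   -> last timestep (1, 512) -> logits (1, 3)
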
model = ABSA(vocab_size=len(token_2_id), num_labels=3)
# load_state_dict() returns a key-match report rather than the model,
# so the model object must not be reassigned to its return value.
# Weights are mapped to CPU here; change map_location if serving on a GPU.
model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))
model.eval()
print("Model loaded successfully")
app = FastAPI()
# Root endpoint
@app.get("/")
def greet_json():
    return {"Hello": "World!"}

# Request schema for the predict endpoint (JSON input).
class PredictRequest(BaseModel):
    text: str
    aspect: str


# Predict endpoint: classifies the sentiment of `text` towards `aspect`.
@app.post("/predict")
async def predict_sentiment(request: PredictRequest):
    try:
        # Preprocess the text/aspect pair and run the model.
        input_ids = process_text(request.text, request.aspect)
        with torch.no_grad():
            output = model(input_ids)
            prediction = output.argmax(dim=1).item()

        # NOTE: the label order below is an assumption; adjust it to match
        # the label encoding used during training.
        labels = ["negative", "neutral", "positive"]
        return JSONResponse(content={"prediction": labels[prediction]})
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)