# Spaces: Sleeping
# (status banner captured from the hosting page; not part of the program)
import pickle
import re
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel
# Vocabulary: token -> integer id mapping used to build model inputs.
# The path is resolved next to this file (not the current working directory),
# and uses pathlib rather than ".\\vocab.pkl", which only resolves on Windows
# (elsewhere it names a file literally called ".\vocab.pkl").
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only ever load a vocab file you trust, never a user-supplied one.
with open(Path(__file__).resolve().parent / "vocab.pkl", "rb") as f:
    token_2_id = pickle.load(f)
# Report the size only; printing the whole mapping floods the logs.
print(f"Loaded vocabulary with {len(token_2_id)} tokens")
def normalize(text):
    """Return *text* lower-cased, stripped of punctuation, with tidy spacing.

    Only ASCII letters ``a-z``, digits ``0-9`` and whitespace survive the
    character filter; every run of whitespace is then squeezed to a single
    space, and leading/trailing whitespace is dropped.
    """
    cleaned = re.sub(r'[^a-z0-9\s]', '', text.lower())
    words = cleaned.split()
    return ' '.join(words)
def tokenize(text):
    """Split *text* on runs of whitespace and return the list of tokens."""
    return text.split(sep=None)
def convert_tokens_2_ids(tokens, vocab=None):
    """Map tokens to vocabulary ids, using the ``'<UNK>'`` id for unknown tokens.

    Args:
        tokens: iterable of token strings.
        vocab: token -> id mapping. Defaults to the module-level
            ``token_2_id``; pass one explicitly to use a different
            vocabulary (e.g. in tests).

    Returns:
        List of integer ids, one per token, in input order.

    Raises:
        KeyError: if the vocabulary has no ``'<UNK>'`` entry.
    """
    if vocab is None:
        vocab = token_2_id
    # The fallback id does not depend on the token, so look it up once
    # instead of re-evaluating it for every element.
    unk_id = vocab['<UNK>']
    return [vocab.get(token, unk_id) for token in tokens]
def process_text(text, aspect):
    """Turn a (sentence, aspect) pair into a batch-of-one tensor of token ids.

    The aspect is appended to the sentence (separated by a space) and the
    combined string is passed through ``normalize`` -> ``tokenize`` ->
    ``convert_tokens_2_ids``.

    Returns:
        ``torch.long`` tensor of shape ``(1, num_tokens)``. If normalisation
        leaves no tokens the shape is ``(1, 0)``, which the LSTM model cannot
        consume — callers should reject such input.
    """
    text_aspect_pair = text + ' ' + aspect
    tokens = tokenize(normalize(text_aspect_pair))
    input_ids = convert_tokens_2_ids(tokens)
    # Force int64 indices: nn.Embedding needs an integer tensor, and
    # torch.tensor([]) would otherwise default to float32.
    return torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)
class ABSA(nn.Module):
    """Sequence classifier: embedding -> single-layer LSTM -> linear head.

    Token ids are embedded, run through a unidirectional ``batch_first``
    LSTM, and the LSTM output at the final time step is projected to
    ``num_labels`` logits.

    Args:
        vocab_size: number of rows in the embedding table.
        num_labels: number of output classes.
        embedding_dim: embedding width. The default (256) matches the
            original hard-coded value, so existing weights still load.
        hidden_size: LSTM hidden width. The default (512) matches the
            original hard-coded value.
    """

    def __init__(self, vocab_size, num_labels=3, embedding_dim=256, hidden_size=512):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_labels = num_labels
        # The attribute names below become the state_dict key prefixes
        # (embedding_layer.*, lstm_layer.*, fc_layer.*); renaming them would
        # break loading of previously saved weights.
        self.embedding_layer = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embedding_dim
        )
        self.lstm_layer = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            batch_first=True,
        )
        self.fc_layer = nn.Linear(
            in_features=hidden_size,
            out_features=self.num_labels,
        )

    def forward(self, x):
        """Return class logits for a batch of token-id sequences.

        Args:
            x: integer tensor of shape ``(batch, seq_len)`` with
                ``seq_len >= 1``.

        Returns:
            Tensor of shape ``(batch, num_labels)``.
        """
        embeddings = self.embedding_layer(x)       # (batch, seq_len, embedding_dim)
        lstm_out, _ = self.lstm_layer(embeddings)  # (batch, seq_len, hidden_size)
        # Classify from the last time step. NOTE(review): with right-padded
        # batches this would read a padding position; it is fine for the
        # unpadded batch-of-one inputs built by process_text.
        return self.fc_layer(lstm_out[:, -1, :])
# Build the network with the training-time layout and restore its weights.
model = ABSA(vocab_size=len(token_2_id), num_labels=3)
# map_location="cpu" lets weights saved on a GPU load on a CPU-only host.
state_dict = torch.load(
    Path(__file__).resolve().parent / "model_weights.pth", map_location="cpu"
)
# load_state_dict() updates the module in place and returns a
# (missing_keys, unexpected_keys) named tuple; its result must NOT be
# assigned back to `model`, or the network itself is lost.
model.load_state_dict(state_dict)
# Inference only: put the module in evaluation mode.
model.eval()
print("Model loaded successfully")
app = FastAPI()


# Root endpoint (GET /). Without the decorator the function was never
# registered with the app, so "/" returned 404.
@app.get("/")
def greet_json():
    """Return a constant JSON greeting."""
    return {"Hello": "World!"}
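#
# TODO: the disabled endpoint below is left over from an image-classification
# ("The digit is ...") template. It reads an uploaded image and calls
# `preprocess_image`, which does not appear to be defined in this module, so it
# cannot work with the text-based ABSA model above. A replacement should take
# the review `text` and `aspect`, build the input with `process_text`, and
# return `model(...).argmax(dim=1).item()` (the meaning of each label index is
# not defined here -- confirm against the training code).
#
# # Predict endpoint for an uploaded image file
# @app.post("/predict")
# async def predict_image(file: UploadFile = File(...)):
#     try:
#         # Read and preprocess the uploaded image
#         image = Image.open(file.file)
#         image = preprocess_image(image)
#
#         # Make prediction
#         model.eval()
#         with torch.no_grad():
#             output = model(image)
#             prediction = output.argmax(dim=1).item()
#
#         return JSONResponse(content={"prediction": f"The digit is {prediction}"})
#     except Exception as e:
#         return JSONResponse(content={"error": str(e)}, status_code=500)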