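"""FastAPI service for aspect-based sentiment analysis (ABSA).

Loads a pickled vocabulary and a trained embedding + LSTM classifier and
exposes a /predict endpoint that scores a (text, aspect) pair as Negative,
Neutral, or Positive.
"""
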
import pickle
import re

import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Load the token-to-id vocabulary produced during training.
with open("vocab.pkl", "rb") as f:
    token_2_id = pickle.load(f)
print(f"Loaded vocabulary with {len(token_2_id)} tokens")


def normalize(text):
    # Lowercase, strip everything except letters, digits, and whitespace,
    # then collapse runs of whitespace into single spaces.
    text = text.lower()
    text = re.sub(r'[^a-z0-9\s]', '', text)
    return ' '.join(text.split())


def tokenize(text):
    # Simple whitespace tokenization.
    return text.split()


def convert_tokens_2_ids(tokens):
    # Map tokens to vocabulary ids; out-of-vocabulary tokens fall back to <UNK>.
    return [token_2_id.get(token, token_2_id['<UNK>']) for token in tokens]


def process_text(text, aspect):
    # Concatenate the sentence and the aspect term into one sequence, then
    # normalize, tokenize, and convert to a batched tensor of token ids.
    text_aspect_pair = text + ' ' + aspect
    normalized_text = normalize(text_aspect_pair)
    tokens = tokenize(normalized_text)
    input_ids = convert_tokens_2_ids(tokens)
    return torch.tensor(input_ids, dtype=torch.long).unsqueeze(0)
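# Illustrative walk-through (the actual ids depend on the pickled vocab):
#   process_text("The pizza was great!", "pizza")
#     -> normalized: "the pizza was great pizza"
#     -> tokens:     ["the", "pizza", "was", "great", "pizza"]
#     -> a tensor of shape (1, 5) holding the corresponding ids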


class ABSA(nn.Module):
    """Aspect-based sentiment classifier: embedding -> LSTM -> linear head."""

    def __init__(self, vocab_size, num_labels=3):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_labels = num_labels
        self.embedding_layer = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=256
        )
        self.lstm_layer = nn.LSTM(
            input_size=256,
            hidden_size=512,
            batch_first=True,
        )
        self.fc_layer = nn.Linear(
            in_features=512,
            out_features=self.num_labels
        )

    def forward(self, x):
        embeddings = self.embedding_layer(x)       # (batch, seq_len, 256)
        lstm_out, _ = self.lstm_layer(embeddings)  # (batch, seq_len, 512)
        # Classify from the final hidden state of the sequence.
        return self.fc_layer(lstm_out[:, -1, :])   # (batch, num_labels)
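# Shape sanity check (illustrative only: untrained weights, toy vocab size):
#   m = ABSA(vocab_size=10)
#   m(torch.zeros(1, 4, dtype=torch.long)).shape  # -> torch.Size([1, 3])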


model = ABSA(vocab_size=len(token_2_id), num_labels=3)
# map_location='cpu' lets weights saved on a GPU machine load on a CPU host.
model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))
model.eval()

print("Model loaded successfully")
app = FastAPI()


# Root endpoint
@app.get("/")
def greet_json():
    return {"Hello": "World!"}

# Input model for request validation
class TextAspectInput(BaseModel):
    text: str
    aspect: str


# Sentiment labels
SENTIMENT_LABELS = {0: "Negative", 1: "Neutral", 2: "Positive"}


# Predict endpoint
@app.post("/predict")
async def predict_sentiment(input_data: TextAspectInput):
    try:
        # Process input: build a (1, seq_len) tensor of token ids.
        input_ids = process_text(input_data.text, input_data.aspect)
        if input_ids.numel() == 0:
            raise HTTPException(
                status_code=400,
                detail="Input contains no usable tokens after normalization.",
            )

        # Run inference without tracking gradients.
        with torch.no_grad():
            logits = model(input_ids)
            probabilities = torch.softmax(logits, dim=-1)
            prediction = probabilities.argmax(dim=-1).item()
            sentiment = SENTIMENT_LABELS[prediction]

        return {
            "sentiment": sentiment,
            "probabilities": probabilities.squeeze().tolist(),
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
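
# Example request (host/port depend on how the app is served, e.g. via
# `uvicorn app:app`; the module name `app` here is an assumption):
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "The battery lasts all day", "aspect": "battery"}'
# The response carries "sentiment" (Negative/Neutral/Positive) and
# "probabilities" (the three softmax scores in label order).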