BhuiyanMasum committed
Commit 35f6479 · Parent: ae22ae6
Files changed (3):
  1. app.py +83 -33
  2. model_weights.pth +3 -0
  3. model.pth → vocab.pkl +2 -2
app.py CHANGED
@@ -4,26 +4,76 @@ import torch
 from fastapi import FastAPI, UploadFile, File
 from fastapi.responses import JSONResponse
 from pydantic import BaseModel
-
-app = FastAPI()
-
-# Load the pre-trained model from within the container
-model_uri = "model.pth"
-model = torch.load(model_uri, weights_only=False)
-
-# Define input schema for JSON requests
-class ImageInput(BaseModel):
-    image_path: str
-
-
-# Preprocess the image
-def preprocess_image(image):
-    image = image.convert('L')  # Convert to grayscale
-    image = image.resize((28, 28))
-    image = np.array(image) / 255.0  # Normalize to [0, 1]
-    image = (image - 0.1307) / 0.3081  # Standardize
-    image = torch.tensor(image).unsqueeze(0).float()  # Convert to tensor with batch dimension
-    return image
+import torch.nn as nn
+import pickle
+import re
+
+token_2_id = None
+# Load the dictionary later
+with open(".\\vocab.pkl", "rb") as f:
+    token_2_id = pickle.load(f)
+    print(token_2_id)
+
+
+def normalize(text):
+    text = text.lower()
+    text = re.sub(r'[^a-z0-9\s]', '', text)
+    text = ' '.join(text.split())
+    return text
+
+
+def tokenize(text):
+    tokens = text.split()
+    return tokens
+
+
+def convert_tokens_2_ids(tokens):
+    input_ids = [
+        token_2_id.get(token, token_2_id['<UNK>']) for token in tokens
+    ]
+    return input_ids
+
+
+def process_text(text, aspect):
+    text_aspect_pair = text + ' ' + aspect
+    normalized_text = normalize(text_aspect_pair)
+    tokens = tokenize(normalized_text)
+    input_ids = convert_tokens_2_ids(tokens)
+    input_ids = torch.tensor(input_ids).unsqueeze(0)
+    return input_ids
+
+
+class ABSA(nn.Module):
+    def __init__(self, vocab_size, num_labels=3):
+        super(ABSA, self).__init__()
+        self.vocab_size = vocab_size
+        self.num_labels = num_labels
+        self.embedding_layer = nn.Embedding(
+            num_embeddings=vocab_size, embedding_dim=256
+        )
+        self.lstm_layer = nn.LSTM(
+            input_size=256,
+            hidden_size=512,
+            batch_first=True,
+        )
+
+        self.fc_layer = nn.Linear(
+            in_features=512,
+            out_features=self.num_labels
+        )
+
+    def forward(self, x):
+        embeddings = self.embedding_layer(x)
+        lstm_out, _ = self.lstm_layer(embeddings)
+        logits = self.fc_layer(lstm_out[:, -1, :])
+        return logits
+
+
+model = ABSA(vocab_size=len(token_2_id.keys()), num_labels=3)
+model = model.load_state_dict(torch.load('model_weights.pth'))
+
+print("Model loaded successfully")
+app = FastAPI()
 
 
 # Root endpoint
@@ -31,21 +81,21 @@ def preprocess_image(image):
 def greet_json():
     return {"Hello": "World!"}
 
-
-# Predict endpoint for JSON input
-@app.post("/predict")
-async def predict_image(file: UploadFile = File(...)):
-    try:
-        # Read and preprocess the uploaded image
-        image = Image.open(file.file)
-        image = preprocess_image(image)
-
-        # Make prediction
-        model.eval()
-        with torch.no_grad():
-            output = model(image)
-            prediction = output.argmax(dim=1).item()
-
-        return JSONResponse(content={"prediction": f"The digit is {prediction}"})
-    except Exception as e:
-        return JSONResponse(content={"error": str(e)}, status_code=500)
+#
+# # Predict endpoint for JSON input
+# @app.post("/predict")
+# async def predict_image(file: UploadFile = File(...)):
+#     try:
+#         # Read and preprocess the uploaded image
+#         image = Image.open(file.file)
+#         image = preprocess_image(image)
+#
+#         # Make prediction
+#         model.eval()
+#         with torch.no_grad():
+#             output = model(image)
+#             prediction = output.argmax(dim=1).item()
+#
+#         return JSONResponse(content={"prediction": f"The digit is {prediction}"})
+#     except Exception as e:
+#         return JSONResponse(content={"error": str(e)}, status_code=500)
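Review note on the new loading code: torch.nn.Module.load_state_dict returns a report of missing and unexpected keys rather than the module itself, so the rebinding model = model.load_state_dict(...) leaves model pointing at that report, and a later model(input_ids) call would fail. The Windows-style path ".\\vocab.pkl" is also unlikely to resolve inside a Linux Space container. A minimal corrected sketch; the map_location="cpu" and model.eval() calls are assumptions about the deployment target, not part of this commit:

import pickle
import torch

# Relative POSIX-style path; assumes the Space runs from the repo root.
with open("vocab.pkl", "rb") as f:
    token_2_id = pickle.load(f)

model = ABSA(vocab_size=len(token_2_id), num_labels=3)
state_dict = torch.load("model_weights.pth", map_location="cpu")  # CPU-only hardware assumed
model.load_state_dict(state_dict)  # load in place; do not rebind model to the return value
model.eval()  # inference mode for serving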
model_weights.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8b1a6752d09f7a562dd93c6d857987c4df789908e132287bef226ac82c6c2a80
+size 10208198
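The new weights file is consumed with load_state_dict, which implies it was exported as a state dict rather than as a fully pickled module like the old model.pth. A one-line sketch of the assumed export step; it is not shown anywhere in this commit:

# Assumed export that would produce the ~10 MB model_weights.pth tracked here via Git LFS.
torch.save(model.state_dict(), "model_weights.pth")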
model.pth → vocab.pkl RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5a25d6fbe70a15a02f24ff1e586b44b4fb0a626193293b44fb718a18851b9f12
-size 445812
+oid sha256:2e195b47d3a84c730ff3a7f25df52b335174ba06914d9404bffa7f4422603d60
+size 47213
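With the old image endpoint commented out, app.py now exposes only the root route, so nothing serves the new ABSA model yet. A possible follow-up endpoint is sketched below, assuming the model is loaded as in the corrected sketch above; the TextInput schema, the /predict route, and returning the raw class index (the commit sets num_labels=3 but defines no id-to-label mapping) are all assumptions, not part of this change:

# Hypothetical text endpoint for the ABSA model; TextInput and /predict are illustrative names.
class TextInput(BaseModel):
    text: str
    aspect: str


@app.post("/predict")
async def predict_sentiment(payload: TextInput):
    try:
        # Reuse the preprocessing helper added in this commit.
        input_ids = process_text(payload.text, payload.aspect)
        with torch.no_grad():
            logits = model(input_ids)
            label_id = logits.argmax(dim=1).item()
        # The id-to-label mapping (e.g. negative/neutral/positive) is not defined in the commit,
        # so the raw class index is returned.
        return JSONResponse(content={"label_id": label_id})
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)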