rohithk-03 commited on
Commit
76ba6f0
·
1 Parent(s): dddca60

update model code

Browse files
Files changed (2) hide show
  1. app.py +26 -0
  2. requirements.txt +1 -1
app.py CHANGED
@@ -17,6 +17,28 @@ import cloudinary.uploader
17
  from a import main
18
  import numpy as np
19
  import torchvision.transforms as transforms
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  # Initialize Flask app
21
  app = Flask(__name__)
22
 
@@ -272,6 +294,10 @@ def predict():
272
 
273
  # Initialize Model, Loss, and Optimizer
274
  model_new = ResNetRegression()
 
 
 
 
275
  image = Image.open(temp_path).convert("RGB")
276
  output = model_new(transform(image).unsqueeze(0))
277
  stage = output.item()
 
17
  from a import main
18
  import numpy as np
19
  import torchvision.transforms as transforms
20
+ import pandas as pd
21
+ import nibabel as nib
22
+ import numpy as np
23
+ import torch
24
+ import torch.nn as nn
25
+ import torchvision
26
+ import torchvision.transforms as transforms
27
+ import cv2
28
+ from PIL import Image
29
+ from sklearn.model_selection import train_test_split
30
+ import os
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.optim as optim
34
+ import torchvision.transforms as transforms
35
+ import torchvision.models as models
36
+ from torch.utils.data import Dataset, DataLoader
37
+ import pandas as pd
38
+ from PIL import Image
39
+ import os
40
+ from sklearn.model_selection import train_test_split
41
+ from sklearn.preprocessing import MinMaxScaler
42
  # Initialize Flask app
43
  app = Flask(__name__)
44
 
 
294
 
295
  # Initialize Model, Loss, and Optimizer
296
  model_new = ResNetRegression()
297
+ checkpoint = torch.load(
298
+ "/home/user/app/a.pth", weights_only=False, map_location=torch.device('cpu'))
299
+ checkpoint = remove_module_from_checkpoint(checkpoint)
300
+ model_new.load_state_dict(checkpoint['model_state_dict'])
301
  image = Image.open(temp_path).convert("RGB")
302
  output = model_new(transform(image).unsqueeze(0))
303
  stage = output.item()
requirements.txt CHANGED
@@ -24,4 +24,4 @@ cloudinary
24
  PyPDF2
25
  scikit-learn
26
  pandas
27
- gdown
 
24
  PyPDF2
25
  scikit-learn
26
  pandas
27
+ gdown