rohithk-03 commited on
Commit
dddca60
·
1 Parent(s): 537e97c

update model code

Browse files
Files changed (1) hide show
  1. app.py +15 -1
app.py CHANGED
@@ -258,8 +258,22 @@ def predict():
258
  if (is_mri_image(temp_path)):
259
  return jsonify({"message": "Not an mri image", "confidence": 0.95, "saved_path": image_save_path})
260
  a, b = model.check_file(temp_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  image = Image.open(temp_path).convert("RGB")
262
- output = model(transform(image).unsqueeze(0))
263
  stage = output.item()
264
  if stage <= 2.0:
265
  stage = "Mild"
 
258
  if (is_mri_image(temp_path)):
259
  return jsonify({"message": "Not an mri image", "confidence": 0.95, "saved_path": image_save_path})
260
  a, b = model.check_file(temp_path)
261
+
262
+ class ResNetRegression(nn.Module):
263
+ def __init__(self):
264
+ super(ResNetRegression, self).__init__()
265
+ self.model = models.resnet34(pretrained=True)
266
+ in_features = self.model.fc.in_features
267
+ # Change output layer for regression
268
+ self.model.fc = nn.Linear(in_features, 1)
269
+
270
+ def forward(self, x):
271
+ return self.model(x)
272
+
273
+ # Initialize Model, Loss, and Optimizer
274
+ model_new = ResNetRegression()
275
  image = Image.open(temp_path).convert("RGB")
276
+ output = model_new(transform(image).unsqueeze(0))
277
  stage = output.item()
278
  if stage <= 2.0:
279
  stage = "Mild"