dhruv2842 commited on
Commit
751c399
Β·
verified Β·
1 Parent(s): 5b79cd8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -64
app.py CHANGED
@@ -1,10 +1,11 @@
1
  from flask import Flask, request, jsonify, send_file
2
- from tensorflow.keras.models import load_model, Model
3
  from PIL import Image
4
- import numpy as np
 
 
5
  import os
6
  import cv2
7
- import tensorflow as tf
8
  from datetime import datetime
9
  import sqlite3
10
 
@@ -17,6 +18,7 @@ if not os.path.exists(OUTPUT_DIR):
17
 
18
  DB_PATH = os.path.join(OUTPUT_DIR, 'results.db')
19
 
 
20
  def init_db():
21
  conn = sqlite3.connect(DB_PATH)
22
  cursor = conn.cursor()
@@ -33,65 +35,27 @@ def init_db():
33
  conn.commit()
34
  conn.close()
35
 
 
36
  init_db()
37
 
38
- # βœ… Load Model
39
- model = load_model('mobilenet_glaucoma_model.h5', compile=False)
 
40
 
41
  # βœ… Preprocess Image
42
- def preprocess_image(img):
43
- img = img.resize((224, 224))
44
- img = np.array(img) / 255.0
45
- img = np.expand_dims(img, axis=0)
46
- return img
47
-
48
- # βœ… Grad-CAM Generation
49
- def make_gradcam(img_array, model, last_conv_layer_name='Conv_1_bn'):
50
- """Generate Grad-CAM for the given image and model."""
51
- last_conv_layer = model.get_layer(last_conv_layer_name)
52
- grad_model = Model(inputs=model.inputs, outputs=[last_conv_layer.output, model.output])
53
-
54
- with tf.GradientTape() as tape:
55
- conv_outputs, predictions = grad_model(img_array)
56
- loss = predictions[:, 0]
57
- grads = tape.gradient(loss, conv_outputs)
58
-
59
- pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
60
- conv_outputs = conv_outputs[0].numpy()
61
- pooled_grads = pooled_grads.numpy()
62
-
63
- for i in range(conv_outputs.shape[-1]):
64
- conv_outputs[..., i] *= pooled_grads[i]
65
-
66
- heatmap = tf.reduce_mean(conv_outputs, axis=-1).numpy()
67
- heatmap = np.maximum(heatmap, 0)
68
- heatmap /= np.max(heatmap)
69
-
70
- return heatmap
71
-
72
- # βœ… Save Grad-CAM Overlay
73
- def save_gradcam_image(original_img, heatmap, filename='gradcam.png', output_dir=OUTPUT_DIR):
74
- """Save the Grad-CAM overlay image and return its path."""
75
- img = np.array(original_img.resize((224, 224)))
76
- heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
77
- heatmap = np.uint8(255 * heatmap)
78
-
79
- heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
80
- overlay = cv2.addWeighted(img, 0.6, heatmap, 0.4, 0)
81
-
82
- filepath = os.path.join(output_dir, filename)
83
- cv2.imwrite(filepath, overlay)
84
-
85
- return filepath
86
 
87
  @app.route('/')
88
  def home():
89
- return "Glaucoma Detection Flask API is running!"
90
 
91
  @app.route("/test_file")
92
  def test_file():
93
- """Check if the model file is present and readable."""
94
- filepath = "mobilenet_glaucoma_model.h5"
95
  if os.path.exists(filepath):
96
  return f"βœ… Model file found at: {filepath}"
97
  else:
@@ -99,7 +63,7 @@ def test_file():
99
 
100
  @app.route('/predict', methods=['POST'])
101
  def predict():
102
- """Perform prediction, save results (including uploaded image), and save to SQLite database."""
103
  if 'file' not in request.files:
104
  return jsonify({'error': 'No file uploaded'}), 400
105
 
@@ -116,26 +80,24 @@ def predict():
116
 
117
  # βœ… Perform prediction
118
  img = Image.open(uploaded_file_path).convert('RGB')
119
- img_array = preprocess_image(img)
120
 
121
- prediction = model.predict(img_array)[0]
 
 
 
122
  glaucoma_prob = 1 - prediction[0]
123
  normal_prob = prediction[0]
124
  result = 'Glaucoma' if glaucoma_prob > normal_prob else 'Normal'
125
  confidence = float(glaucoma_prob) if result == 'Glaucoma' else float(normal_prob)
126
 
127
- # βœ… Grad-CAM
128
- heatmap = make_gradcam(img_array, model, last_conv_layer_name='Conv_1_bn')
129
- gradcam_filename = f"gradcam_{timestamp}.png"
130
- save_gradcam_image(img, heatmap, filename=gradcam_filename)
131
-
132
- # βœ… Save results to SQLite
133
  conn = sqlite3.connect(DB_PATH)
134
  cursor = conn.cursor()
135
  cursor.execute("""
136
  INSERT INTO results (image_filename, prediction, confidence, gradcam_filename, timestamp)
137
  VALUES (?, ?, ?, ?, ?)
138
- """, (uploaded_filename, result, confidence, gradcam_filename, datetime.now().isoformat()))
139
  conn.commit()
140
  conn.close()
141
 
@@ -144,13 +106,14 @@ def predict():
144
  'confidence': confidence,
145
  'normal_probability': float(normal_prob),
146
  'glaucoma_probability': float(glaucoma_prob),
147
- 'gradcam_image': gradcam_filename,
148
  'image_filename': uploaded_filename
149
  })
150
 
151
  except Exception as e:
152
  return jsonify({'error': str(e)}), 500
153
 
 
154
  @app.route('/results', methods=['GET'])
155
  def results():
156
  """List all results from the SQLite database."""
@@ -175,7 +138,7 @@ def results():
175
 
176
  @app.route('/gradcam/<filename>')
177
  def get_gradcam(filename):
178
- """Serve the Grad-CAM overlay image."""
179
  filepath = os.path.join(OUTPUT_DIR, filename)
180
  if os.path.exists(filepath):
181
  return send_file(filepath, mimetype='image/png')
@@ -191,5 +154,6 @@ def get_image(filename):
191
  else:
192
  return jsonify({'error': 'File not found'}), 404
193
 
 
194
  if __name__ == '__main__':
195
  app.run(host='0.0.0.0', port=7860)
 
1
  from flask import Flask, request, jsonify, send_file
 
2
  from PIL import Image
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torchvision import transforms
6
  import os
7
  import cv2
8
+ import numpy as np
9
  from datetime import datetime
10
  import sqlite3
11
 
 
18
 
19
  DB_PATH = os.path.join(OUTPUT_DIR, 'results.db')
20
 
21
+
22
  def init_db():
23
  conn = sqlite3.connect(DB_PATH)
24
  cursor = conn.cursor()
 
35
  conn.commit()
36
  conn.close()
37
 
38
+
39
  init_db()
40
 
41
+ # βœ… Load PyTorch Model
42
+ model = torch.load('your_model.pt', map_location=torch.device('cpu'))
43
+ model.eval()
44
 
45
  # βœ… Preprocess Image
46
+ transform = transforms.Compose([
47
+ transforms.Resize((224, 224)),
48
+ transforms.ToTensor(),
49
+ ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  @app.route('/')
52
  def home():
53
+ return "Glaucoma Detection Flask API (PyTorch) is running!"
54
 
55
  @app.route("/test_file")
56
  def test_file():
57
+ """Check if the .pt model file is present and readable."""
58
+ filepath = "densenet169_seed40_best.pt"
59
  if os.path.exists(filepath):
60
  return f"βœ… Model file found at: {filepath}"
61
  else:
 
63
 
64
  @app.route('/predict', methods=['POST'])
65
  def predict():
66
+ """Perform prediction using PyTorch, save results (including uploaded image), and save to SQLite database."""
67
  if 'file' not in request.files:
68
  return jsonify({'error': 'No file uploaded'}), 400
69
 
 
80
 
81
  # βœ… Perform prediction
82
  img = Image.open(uploaded_file_path).convert('RGB')
83
+ input_tensor = transform(img).unsqueeze(0)
84
 
85
+ with torch.no_grad():
86
+ prediction = model(input_tensor).numpy()[0]
87
+
88
+ # βœ… Interpret the output
89
  glaucoma_prob = 1 - prediction[0]
90
  normal_prob = prediction[0]
91
  result = 'Glaucoma' if glaucoma_prob > normal_prob else 'Normal'
92
  confidence = float(glaucoma_prob) if result == 'Glaucoma' else float(normal_prob)
93
 
94
+ # βœ… Save results to SQLite (no Grad-CAM generation for now)
 
 
 
 
 
95
  conn = sqlite3.connect(DB_PATH)
96
  cursor = conn.cursor()
97
  cursor.execute("""
98
  INSERT INTO results (image_filename, prediction, confidence, gradcam_filename, timestamp)
99
  VALUES (?, ?, ?, ?, ?)
100
+ """, (uploaded_filename, result, confidence, '', datetime.now().isoformat()))
101
  conn.commit()
102
  conn.close()
103
 
 
106
  'confidence': confidence,
107
  'normal_probability': float(normal_prob),
108
  'glaucoma_probability': float(glaucoma_prob),
109
+ 'gradcam_image': '', # No Grad-CAM for PyTorch for now
110
  'image_filename': uploaded_filename
111
  })
112
 
113
  except Exception as e:
114
  return jsonify({'error': str(e)}), 500
115
 
116
+
117
  @app.route('/results', methods=['GET'])
118
  def results():
119
  """List all results from the SQLite database."""
 
138
 
139
  @app.route('/gradcam/<filename>')
140
  def get_gradcam(filename):
141
+ """Serve the Grad-CAM overlay image (no-op if not used)."""
142
  filepath = os.path.join(OUTPUT_DIR, filename)
143
  if os.path.exists(filepath):
144
  return send_file(filepath, mimetype='image/png')
 
154
  else:
155
  return jsonify({'error': 'File not found'}), 404
156
 
157
+
158
  if __name__ == '__main__':
159
  app.run(host='0.0.0.0', port=7860)