dhruv2842 commited on
Commit
b050424
Β·
verified Β·
1 Parent(s): c54a3f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -3
app.py CHANGED
@@ -21,6 +21,7 @@ DB_PATH = os.path.join(OUTPUT_DIR, 'results.db')
21
 
22
 
23
  def init_db():
 
24
  conn = sqlite3.connect(DB_PATH)
25
  cursor = conn.cursor()
26
  cursor.execute("""
@@ -38,24 +39,36 @@ def init_db():
38
 
39
 
40
  init_db()
 
 
41
  from densenet_withglam import get_model_with_attention
42
- # 1️⃣ Instantiate the model
 
43
  model = get_model_with_attention('densenet169', num_classes=3) # Will have GLAM
44
  state_dict = torch.load('densenet169_seed40_best2.pt', map_location='cpu')
45
- model.load_state_dict(state_dict) # Should now match
46
  model.eval()
47
 
48
  # βœ… Class Names
49
  CLASS_NAMES = ["Normal", "Early Glaucoma", "Advanced Glaucoma"]
50
 
 
 
 
 
 
 
 
 
51
  @app.route('/')
52
  def home():
 
53
  return "Glaucoma Detection Flask API (3-Class Model) is running!"
54
 
55
  @app.route("/test_file")
56
  def test_file():
57
  """Check if the .pt model file is present and readable."""
58
- filepath = "densenet169_seed40_best.pt"
59
  if os.path.exists(filepath):
60
  return f"βœ… Model file found at: {filepath}"
61
  else:
 
21
 
22
 
23
  def init_db():
24
+ """Initialize SQLite database for storing results."""
25
  conn = sqlite3.connect(DB_PATH)
26
  cursor = conn.cursor()
27
  cursor.execute("""
 
39
 
40
 
41
  init_db()
42
+
43
+ # βœ… Import your custom GLAM model
44
  from densenet_withglam import get_model_with_attention
45
+
46
+ # βœ… Instantiate the model
47
  model = get_model_with_attention('densenet169', num_classes=3) # Will have GLAM
48
  state_dict = torch.load('densenet169_seed40_best2.pt', map_location='cpu')
49
+ model.load_state_dict(state_dict) # Load your trained weights
50
  model.eval()
51
 
52
  # βœ… Class Names
53
  CLASS_NAMES = ["Normal", "Early Glaucoma", "Advanced Glaucoma"]
54
 
55
+ # βœ… Transformation for input images
56
+ transform = transforms.Compose([
57
+ transforms.Resize((224, 224)), # Adjust size based on training
58
+ transforms.ToTensor(),
59
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], # Standard ImageNet stats
60
+ std=[0.229, 0.224, 0.225])
61
+ ])
62
+
63
  @app.route('/')
64
  def home():
65
+ """Check that the API is working."""
66
  return "Glaucoma Detection Flask API (3-Class Model) is running!"
67
 
68
  @app.route("/test_file")
69
  def test_file():
70
  """Check if the .pt model file is present and readable."""
71
+ filepath = "densenet169_seed40_best2.pt"
72
  if os.path.exists(filepath):
73
  return f"βœ… Model file found at: {filepath}"
74
  else: