ZDPLI commited on
Commit
647b808
·
verified ·
1 Parent(s): fddb52e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -63
app.py CHANGED
@@ -3,6 +3,7 @@ import torch
3
  from PIL import Image
4
  import numpy as np
5
  from torchvision import models
 
6
  from transformers import ViTForImageClassification
7
  from torch import nn
8
  from torch.cuda.amp import autocast
@@ -24,7 +25,7 @@ label_mapping = {
24
  }
25
 
26
  # Model paths
27
- CHECKPOINTS_PATH = os.getenv("CHECKPOINTS_PATH", "./")
28
 
29
  # Model definitions
30
  def get_efficientnet():
@@ -60,7 +61,7 @@ class ModelHandler:
60
  self.deit = None
61
  self.models_loaded = False
62
  self.load_models()
63
-
64
  def load_models(self):
65
  try:
66
  # Load EfficientNet
@@ -68,59 +69,56 @@ class ModelHandler:
68
  efficientnet_path = os.path.join(CHECKPOINTS_PATH, "efficientnet_best.pth")
69
  self.efficientnet.load_state_dict(torch.load(efficientnet_path, map_location=device))
70
  self.efficientnet.eval()
71
-
72
  # Load DeiT
73
  self.deit = get_deit()
74
  deit_path = os.path.join(CHECKPOINTS_PATH, "deit_best.pth")
75
  self.deit.load_state_dict(torch.load(deit_path, map_location=device))
76
  self.deit.eval()
77
-
78
  self.models_loaded = True
79
  print("✅ Models loaded successfully")
80
  except Exception as e:
81
  print(f"❌ Error loading models: {str(e)}")
82
  self.models_loaded = False
83
-
84
  @torch.no_grad()
85
  def predict_efficientnet(self, image):
86
  if not self.models_loaded:
87
  return {"error": "Модели не загружены"}
88
-
89
  inputs = transform_image(image)
90
  with autocast():
91
  outputs = self.efficientnet(inputs)
92
  probs = torch.nn.functional.softmax(outputs, dim=1)
93
-
94
  return self._format_predictions(probs)
95
 
96
  @torch.no_grad()
97
  def predict_deit(self, image):
98
  if not self.models_loaded:
99
  return {"error": "Модели не загружены"}
100
-
101
  inputs = transform_image(image)
102
  with autocast():
103
  outputs = self.deit(inputs).logits
104
  probs = torch.nn.functional.softmax(outputs, dim=1)
105
-
106
  return self._format_predictions(probs)
107
-
108
  @torch.no_grad()
109
  def predict_ensemble(self, image):
110
  if not self.models_loaded:
111
  return {"error": "Модели не загружены"}
112
-
113
  inputs = transform_image(image)
114
  with autocast():
115
- # Get predictions from both models
116
  eff_probs = torch.nn.functional.softmax(self.efficientnet(inputs), dim=1)
117
  deit_probs = torch.nn.functional.softmax(self.deit(inputs).logits, dim=1)
118
-
119
- # Ensemble prediction (average probabilities)
120
  ensemble_probs = (eff_probs + deit_probs) / 2
121
-
122
  return self._format_predictions(ensemble_probs)
123
-
124
  def _format_predictions(self, probs):
125
  top5_probs, top5_indices = torch.topk(probs, 5)
126
  result = {}
@@ -133,7 +131,7 @@ class ModelHandler:
133
  # Initialize model handler
134
  model_handler = ModelHandler()
135
 
136
- # Prediction functions
137
  def predict_efficientnet(image):
138
  if image is None:
139
  return "⚠️ Загрузите изображение"
@@ -149,51 +147,39 @@ def predict_ensemble(image):
149
  return "⚠️ Загрузите изображение"
150
  return model_handler.predict_ensemble(image)
151
 
152
- # Gradio Interface
153
- def create_individual_tab(model_name, predict_fn):
154
- with gr.Blocks():
155
- with gr.Row():
156
- with gr.Column(scale=1):
157
- image_input = gr.Image(label="Загрузите изображение", type="pil")
158
- predict_btn = gr.Button("Предсказать", variant="primary")
159
- with gr.Column(scale=1):
160
- result_output = gr.Label(label="Результаты")
161
-
162
- predict_btn.click(
163
- predict_fn,
164
- inputs=image_input,
165
- outputs=result_output
166
- )
167
-
168
- gr.Examples(
169
- examples=["examples/akiec.jpg", "examples/bcc.jpg", "examples/df.jpg"],
170
- inputs=image_input,
171
- label="Примеры из ISIC"
172
- )
173
-
174
- # Create interface
175
- interface = gr.TabbedInterface(
176
- interface_list=[
177
- lambda: create_individual_tab("EfficientNet", predict_efficientnet),
178
- lambda: create_individual_tab("DeiT", predict_deit),
179
- lambda: create_individual_tab("Ансамблевая модель", predict_ensemble)
180
- ],
181
- tab_names=[
182
- "EfficientNet",
183
- "DeiT",
184
- "Ансамблевая модель"
185
- ]
186
- )
187
-
188
- # Add startup check
189
- def check_models():
190
- if not model_handler.models_loaded:
191
- return "⚠️ Предупреждение: Модели не загружены"
192
- return "✅ Модели готовы к предсказанию"
193
-
194
- startup_status = check_models()
195
- print(startup_status)
196
-
197
  if __name__ == "__main__":
 
198
  print("🚀 Запуск интерфейса...")
199
- interface.launch()
 
3
  from PIL import Image
4
  import numpy as np
5
  from torchvision import models
6
+ from torchvision import transforms
7
  from transformers import ViTForImageClassification
8
  from torch import nn
9
  from torch.cuda.amp import autocast
 
25
  }
26
 
27
  # Model paths
28
+ CHECKPOINTS_PATH = os.getenv("CHECKPOINTS_PATH", "./checkpoints")
29
 
30
  # Model definitions
31
  def get_efficientnet():
 
61
  self.deit = None
62
  self.models_loaded = False
63
  self.load_models()
64
+
65
  def load_models(self):
66
  try:
67
  # Load EfficientNet
 
69
  efficientnet_path = os.path.join(CHECKPOINTS_PATH, "efficientnet_best.pth")
70
  self.efficientnet.load_state_dict(torch.load(efficientnet_path, map_location=device))
71
  self.efficientnet.eval()
72
+
73
  # Load DeiT
74
  self.deit = get_deit()
75
  deit_path = os.path.join(CHECKPOINTS_PATH, "deit_best.pth")
76
  self.deit.load_state_dict(torch.load(deit_path, map_location=device))
77
  self.deit.eval()
78
+
79
  self.models_loaded = True
80
  print("✅ Models loaded successfully")
81
  except Exception as e:
82
  print(f"❌ Error loading models: {str(e)}")
83
  self.models_loaded = False
84
+
85
  @torch.no_grad()
86
  def predict_efficientnet(self, image):
87
  if not self.models_loaded:
88
  return {"error": "Модели не загружены"}
89
+
90
  inputs = transform_image(image)
91
  with autocast():
92
  outputs = self.efficientnet(inputs)
93
  probs = torch.nn.functional.softmax(outputs, dim=1)
94
+
95
  return self._format_predictions(probs)
96
 
97
  @torch.no_grad()
98
  def predict_deit(self, image):
99
  if not self.models_loaded:
100
  return {"error": "Модели не загружены"}
101
+
102
  inputs = transform_image(image)
103
  with autocast():
104
  outputs = self.deit(inputs).logits
105
  probs = torch.nn.functional.softmax(outputs, dim=1)
106
+
107
  return self._format_predictions(probs)
108
+
109
  @torch.no_grad()
110
  def predict_ensemble(self, image):
111
  if not self.models_loaded:
112
  return {"error": "Модели не загружены"}
113
+
114
  inputs = transform_image(image)
115
  with autocast():
 
116
  eff_probs = torch.nn.functional.softmax(self.efficientnet(inputs), dim=1)
117
  deit_probs = torch.nn.functional.softmax(self.deit(inputs).logits, dim=1)
 
 
118
  ensemble_probs = (eff_probs + deit_probs) / 2
119
+
120
  return self._format_predictions(ensemble_probs)
121
+
122
  def _format_predictions(self, probs):
123
  top5_probs, top5_indices = torch.topk(probs, 5)
124
  result = {}
 
131
  # Initialize model handler
132
  model_handler = ModelHandler()
133
 
134
+ # Prediction wrappers
135
  def predict_efficientnet(image):
136
  if image is None:
137
  return "⚠️ Загрузите изображение"
 
147
  return "⚠️ Загрузите изображение"
148
  return model_handler.predict_ensemble(image)
149
 
150
+ # Create Gradio Blocks with Tabs
151
+ def create_interface():
152
+ with gr.Blocks() as demo:
153
+ gr.Markdown("# Диагностика кожных поражений (HAM10K)")
154
+ status = "✅ Модели готовы к предсказанию" if model_handler.models_loaded else "⚠️ Предупреждение: Модели не загружены"
155
+ gr.Markdown(f"**Состояние моделей:** {status}")
156
+
157
+ with gr.Tabs():
158
+ with gr.TabItem("EfficientNet"):
159
+ img = gr.Image(label="Загрузите изображение", type="pil")
160
+ btn = gr.Button("Предсказать", variant="primary")
161
+ out = gr.Label(label="Результаты")
162
+ btn.click(predict_efficientnet, inputs=img, outputs=out)
163
+ gr.Examples(examples=["examples/akiec.jpg", "examples/bcc.jpg", "examples/df.jpg"], inputs=img)
164
+
165
+ with gr.TabItem("DeiT"):
166
+ img = gr.Image(label="Загрузите изображение", type="pil")
167
+ btn = gr.Button("Предсказать", variant="primary")
168
+ out = gr.Label(label="Результаты")
169
+ btn.click(predict_deit, inputs=img, outputs=out)
170
+ gr.Examples(examples=["examples/akiec.jpg", "examples/bcc.jpg", "examples/df.jpg"], inputs=img)
171
+
172
+ with gr.TabItem("Ансамблевая модель"):
173
+ img = gr.Image(label="Загрузите изображение", type="pil")
174
+ btn = gr.Button("Предсказать", variant="primary")
175
+ out = gr.Label(label="Результаты")
176
+ btn.click(predict_ensemble, inputs=img, outputs=out)
177
+ gr.Examples(examples=["examples/akiec.jpg", "examples/bcc.jpg", "examples/df.jpg"], inputs=img)
178
+
179
+ return demo
180
+
181
+ # Launch interface
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  if __name__ == "__main__":
183
+ interface = create_interface()
184
  print("🚀 Запуск интерфейса...")
185
+ interface.launch(share=True)