ZDPLI commited on
Commit
53b4cd0
·
verified ·
1 Parent(s): c9ddca5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +214 -0
app.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ import numpy as np
5
+ from torchvision import models
6
+ from transformers import ViTForImageClassification
7
+ from torch import nn
8
+ from torch.cuda.amp import autocast
9
+ import os
10
+
11
+ # Global configuration
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ print(f"Using device: {device}")
14
+
15
+ # Label mapping (HAM10K)
16
+ label_mapping = {
17
+ 0: "Меланома",
18
+ 1: "Меланоцитарный невус",
19
+ 2: "Базальноклеточная карцинома",
20
+ 3: "Актинический кератоз",
21
+ 4: "Доброкачественная кератоза",
22
+ 5: "Дерматофиброма",
23
+ 6: "Сосудистые поражения"
24
+ }
25
+
26
+ # Model paths
27
+ CHECKPOINTS_PATH = os.getenv("CHECKPOINTS_PATH", "./checkpoints")
28
+
29
+ # Model definitions
30
+ def get_efficientnet():
31
+ model = models.efficientnet_v2_s(weights="IMAGENET1K_V1")
32
+ model.classifier[1] = nn.Linear(1280, 7)
33
+ return model.to(device)
34
+
35
+ def get_deit():
36
+ model = ViTForImageClassification.from_pretrained(
37
+ 'facebook/deit-base-patch16-224',
38
+ num_labels=7,
39
+ ignore_mismatched_sizes=True
40
+ )
41
+ return model.to(device)
42
+
43
+ # Transforms
44
+ def transform_image(image):
45
+ """Transform PIL image to model input format"""
46
+ transform = transforms.Compose([
47
+ transforms.Resize((224, 224)),
48
+ transforms.ToTensor(),
49
+ transforms.Normalize(
50
+ mean=[0.485, 0.456, 0.406],
51
+ std=[0.229, 0.224, 0.225]
52
+ )
53
+ ])
54
+ return transform(image).unsqueeze(0).to(device)
55
+
56
+ # Model Handler
57
+ class ModelHandler:
58
+ def __init__(self):
59
+ self.efficientnet = None
60
+ self.deit = None
61
+ self.models_loaded = False
62
+ self.load_models()
63
+
64
+ def load_models(self):
65
+ try:
66
+ # Load EfficientNet
67
+ self.efficientnet = get_efficientnet()
68
+ efficientnet_path = os.path.join(CHECKPOINTS_PATH, "efficientnet_best.pth")
69
+ self.efficientnet.load_state_dict(torch.load(efficientnet_path, map_location=device))
70
+ self.efficientnet.eval()
71
+
72
+ # Load DeiT
73
+ self.deit = get_deit()
74
+ deit_path = os.path.join(CHECKPOINTS_PATH, "deit_best.pth")
75
+ self.deit.load_state_dict(torch.load(deit_path, map_location=device))
76
+ self.deit.eval()
77
+
78
+ self.models_loaded = True
79
+ print("✅ Models loaded successfully")
80
+ except Exception as e:
81
+ print(f"❌ Error loading models: {str(e)}")
82
+ self.models_loaded = False
83
+
84
+ @torch.no_grad()
85
+ def predict_efficientnet(self, image):
86
+ if not self.models_loaded:
87
+ return {"error": "Модели не загружены"}
88
+
89
+ inputs = transform_image(image)
90
+ with autocast():
91
+ outputs = self.efficientnet(inputs)
92
+ probs = torch.nn.functional.softmax(outputs, dim=1)
93
+
94
+ return self._format_predictions(probs)
95
+
96
+ @torch.no_grad()
97
+ def predict_deit(self, image):
98
+ if not self.models_loaded:
99
+ return {"error": "Модели не загружены"}
100
+
101
+ inputs = transform_image(image)
102
+ with autocast():
103
+ outputs = self.deit(inputs).logits
104
+ probs = torch.nn.functional.softmax(outputs, dim=1)
105
+
106
+ return self._format_predictions(probs)
107
+
108
+ @torch.no_grad()
109
+ def predict_ensemble(self, image):
110
+ if not self.models_loaded:
111
+ return {"error": "Модели не загружены"}
112
+
113
+ inputs = transform_image(image)
114
+ with autocast():
115
+ # Get predictions from both models
116
+ eff_probs = torch.nn.functional.softmax(self.efficientnet(inputs), dim=1)
117
+ deit_probs = torch.nn.functional.softmax(self.deit(inputs).logits, dim=1)
118
+
119
+ # Ensemble prediction (average probabilities)
120
+ ensemble_probs = (eff_probs + deit_probs) / 2
121
+
122
+ return self._format_predictions(ensemble_probs)
123
+
124
+ def _format_predictions(self, probs):
125
+ top5_probs, top5_indices = torch.topk(probs, 5)
126
+ result = {}
127
+ for i in range(5):
128
+ idx = top5_indices[0][i].item()
129
+ label = label_mapping.get(idx, f"Класс {idx}")
130
+ result[label] = float(top5_probs[0][i].item() * 100)
131
+ return result
132
+
133
+ # Initialize model handler
134
+ model_handler = ModelHandler()
135
+
136
+ # Prediction functions
137
+ def predict_efficientnet(image):
138
+ if image is None:
139
+ return "⚠️ Загрузите изображение"
140
+ return model_handler.predict_efficientnet(image)
141
+
142
+ def predict_deit(image):
143
+ if image is None:
144
+ return "⚠️ Загрузите изображение"
145
+ return model_handler.predict_deit(image)
146
+
147
+ def predict_ensemble(image):
148
+ if image is None:
149
+ return "⚠️ Загрузите изображение"
150
+ return model_handler.predict_ensemble(image)
151
+
152
+ # Gradio Interface
153
+ def create_individual_tab(model_name, predict_fn):
154
+ with gr.Blocks():
155
+ with gr.Row():
156
+ with gr.Column(scale=1):
157
+ image_input = gr.Image(label="Загрузите изображение", type="pil")
158
+ predict_btn = gr.Button("Предсказать", variant="primary")
159
+ with gr.Column(scale=1):
160
+ result_output = gr.Label(label="Результаты")
161
+
162
+ predict_btn.click(
163
+ predict_fn,
164
+ inputs=image_input,
165
+ outputs=result_output
166
+ )
167
+
168
+ gr.Examples(
169
+ examples=["examples/akiec.jpg", "examples/bcc.jpg", "examples/df.jpg"],
170
+ inputs=image_input,
171
+ label="Примеры из ISIC"
172
+ )
173
+
174
+ # Create interface
175
+ interface = gr.TabbedInterface(
176
+ interface_list=[
177
+ lambda: create_individual_tab("EfficientNet", predict_efficientnet),
178
+ lambda: create_individual_tab("DeiT", predict_deit),
179
+ lambda: create_individual_tab("Ансамблевая модель", predict_ensemble)
180
+ ],
181
+ tab_names=[
182
+ "EfficientNet",
183
+ "DeiT",
184
+ "Ансамблевая модель"
185
+ ],
186
+ title="DermVision Pro",
187
+ description="""
188
+ # Дерматологический классификатор
189
+ Выберите вкладку для использования соответствующей модели:
190
+ - EfficientNet: традиционная CNN модель
191
+ - DeiT: Vision Transformer
192
+ - Ансамблевая модель: комбинация CNN и Vision Transformer
193
+ """,
194
+ theme=gr.themes.Soft(),
195
+ css="""
196
+ .container {max-width: 1200px; margin: auto;}
197
+ .gr-button {font-size: 1.1em; padding: 8px 16px;}
198
+ .gr-textbox {font-size: 1.1em;}
199
+ .gr-column {min-width: 400px;}
200
+ """
201
+ )
202
+
203
+ # Add startup check
204
+ def check_models():
205
+ if not model_handler.models_loaded:
206
+ return "⚠️ Предупреждение: Модели не загружены"
207
+ return "✅ Модели готовы к предсказанию"
208
+
209
+ startup_status = check_models()
210
+ print(startup_status)
211
+
212
+ if __name__ == "__main__":
213
+ print("🚀 Запуск интерфейса...")
214
+ interface.launch()