ZDPLI commited on
Commit
49730b6
·
verified ·
1 Parent(s): c868b41

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -92
app.py CHANGED
@@ -2,8 +2,9 @@ import gradio as gr
2
  import torch
3
  from PIL import Image
4
  import numpy as np
5
- from torchvision import models
6
- from torchvision import transforms
 
7
  from transformers import ViTForImageClassification
8
  from torch import nn
9
  from torch.cuda.amp import autocast
@@ -25,35 +26,39 @@ label_mapping = {
25
  6: "Сосудистые поражения"
26
  }
27
 
28
- # Model paths
29
- CHECKPOINTS_PATH = os.getenv("CHECKPOINTS_PATH", "./")
 
 
 
 
 
 
 
30
 
31
  # Model definitions
32
  def get_efficientnet():
33
  model = models.efficientnet_v2_s(weights="IMAGENET1K_V1")
34
- model.classifier[1] = nn.Linear(1280, 7)
35
  return model.to(device)
36
 
37
  def get_deit():
38
  model = ViTForImageClassification.from_pretrained(
39
  'facebook/deit-base-patch16-224',
40
- num_labels=7,
41
  ignore_mismatched_sizes=True
42
  )
43
  return model.to(device)
44
 
45
  # Transforms
 
 
 
 
 
 
46
  def transform_image(image):
47
- """Transform PIL image to model input format"""
48
- transform = transforms.Compose([
49
- transforms.Resize((224, 224)),
50
- transforms.ToTensor(),
51
- transforms.Normalize(
52
- mean=[0.485, 0.456, 0.406],
53
- std=[0.229, 0.224, 0.225]
54
- )
55
- ])
56
- return transform(image).unsqueeze(0).to(device)
57
 
58
  # Model Handler
59
  class ModelHandler:
@@ -65,13 +70,11 @@ class ModelHandler:
65
 
66
  def load_models(self):
67
  try:
68
- # Load EfficientNet
69
  self.efficientnet = get_efficientnet()
70
- efficientnet_path = os.path.join(CHECKPOINTS_PATH, "efficientnet_best.pth")
71
- self.efficientnet.load_state_dict(torch.load(efficientnet_path, map_location=device))
72
  self.efficientnet.eval()
73
 
74
- # Load DeiT
75
  self.deit = get_deit()
76
  deit_path = os.path.join(CHECKPOINTS_PATH, "deit_best.pth")
77
  self.deit.load_state_dict(torch.load(deit_path, map_location=device))
@@ -80,79 +83,111 @@ class ModelHandler:
80
  self.models_loaded = True
81
  print("✅ Models loaded successfully")
82
  except Exception as e:
83
- print(f"❌ Error loading models: {str(e)}")
84
  self.models_loaded = False
85
 
86
  @torch.no_grad()
87
- def predict_efficientnet(self, image):
88
- if not self.models_loaded:
89
- return {"error": "Модели не загружены"}
90
-
91
- inputs = transform_image(image)
92
- # Handle autocast based on device
93
- ctx = autocast() if device.type == 'cuda' else nullcontext()
94
- with ctx:
95
- outputs = self.efficientnet(inputs)
96
- probs = torch.nn.functional.softmax(outputs, dim=1)
97
-
98
- return self._format_predictions(probs)
99
-
100
- @torch.no_grad()
101
- def predict_deit(self, image):
102
  if not self.models_loaded:
103
  return {"error": "Модели не загружены"}
104
-
105
  inputs = transform_image(image)
106
  ctx = autocast() if device.type == 'cuda' else nullcontext()
107
  with ctx:
108
- outputs = self.deit(pixel_values=inputs).logits # Corrected parameter
109
- probs = torch.nn.functional.softmax(outputs, dim=1)
110
-
 
 
 
 
111
  return self._format_predictions(probs)
112
 
113
- @torch.no_grad()
114
- def predict_ensemble(self, image):
115
- if not self.models_loaded:
116
- return {"error": "Модели не загружены"}
117
 
118
- inputs = transform_image(image)
119
- ctx = autocast() if device.type == 'cuda' else nullcontext()
120
- with ctx:
121
- eff_probs = torch.nn.functional.softmax(self.efficientnet(inputs), dim=1)
122
- deit_probs = torch.nn.functional.softmax(self.deit(pixel_values=inputs).logits, dim=1)
123
- ensemble_probs = (eff_probs + deit_probs) / 2
124
-
125
- return self._format_predictions(ensemble_probs)
126
-
127
- def _format_predictions(self, probs): # Corrected indentation
128
- top5_probs, top5_indices = torch.topk(probs, 5)
129
- result = {}
130
- for i in range(5):
131
- idx = top5_indices[0][i].item()
132
- label = label_mapping.get(idx, f"Класс {idx}")
133
- result[label] = float(top5_probs[0][i].item())
134
- return result
135
-
136
- # Initialize model handler
137
  model_handler = ModelHandler()
138
 
139
- # Prediction wrappers
140
  def predict_efficientnet(image):
141
- if image is None:
142
- return "⚠️ Загрузите изображение"
143
- return model_handler.predict_efficientnet(image)
144
 
145
  def predict_deit(image):
146
- if image is None:
147
- return "⚠️ Загрузите изображение"
148
- return model_handler.predict_deit(image)
149
 
150
  def predict_ensemble(image):
151
- if image is None:
152
- return "⚠️ Загрузите изображение"
153
- return model_handler.predict_ensemble(image)
154
-
155
- # Create Gradio Blocks with Tabs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
  def create_interface():
157
  with gr.Blocks() as demo:
158
  gr.Markdown("# Диагностика кожных поражений (HAM10K)")
@@ -161,30 +196,27 @@ def create_interface():
161
 
162
  with gr.Tabs():
163
  with gr.TabItem("EfficientNet"):
164
- img = gr.Image(label="Загрузите изображение", type="pil")
165
- btn = gr.Button("Предсказать", variant="primary")
166
- out = gr.Label(label="Результаты")
167
- btn.click(predict_efficientnet, inputs=img, outputs=out)
168
- gr.Examples(examples=["examples/akiec.jpg", "examples/bcc.jpg", "examples/df.jpg"], inputs=img)
169
 
170
  with gr.TabItem("DeiT"):
171
- img = gr.Image(label="Загрузите изображение", type="pil")
172
- btn = gr.Button("Предсказать", variant="primary")
173
- out = gr.Label(label="Результаты")
174
- btn.click(predict_deit, inputs=img, outputs=out)
175
- gr.Examples(examples=["examples/akiec.jpg", "examples/bcc.jpg", "examples/df.jpg"], inputs=img)
176
 
177
  with gr.TabItem("Ансамблевая модель"):
178
- img = gr.Image(label="Загрузите изображение", type="pil")
179
- btn = gr.Button("Предсказать", variant="primary")
180
- out = gr.Label(label="Результаты")
181
- btn.click(predict_ensemble, inputs=img, outputs=out)
182
- gr.Examples(examples=["examples/akiec.jpg", "examples/bcc.jpg", "examples/df.jpg"], inputs=img)
 
 
 
 
183
 
184
  return demo
185
 
186
- # Launch interface
187
  if __name__ == "__main__":
188
  interface = create_interface()
189
  print("🚀 Запуск интерфейса...")
190
- interface.launch(server_port=7860) # Explicitly set port if needed
 
2
  import torch
3
  from PIL import Image
4
  import numpy as np
5
+ from torchvision import models, transforms
6
+ from torchvision.datasets import ImageFolder
7
+ from torch.utils.data import DataLoader
8
  from transformers import ViTForImageClassification
9
  from torch import nn
10
  from torch.cuda.amp import autocast
 
26
  6: "Сосудистые поражения"
27
  }
28
 
29
+ # Paths and hyperparams
30
+ CHECKPOINTS_PATH = os.getenv("CHECKPOINTS_PATH", "./checkpoints")
31
+ SUBMISSIONS_PATH = os.getenv("SUBMISSIONS_PATH", "./submissions")
32
+ FT_BATCH = 32
33
+ FT_EPOCHS = 1 # adjust as needed
34
+ LR = 1e-4
35
+
36
+ os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
37
+ os.makedirs(SUBMISSIONS_PATH, exist_ok=True)
38
 
39
  # Model definitions
40
  def get_efficientnet():
41
  model = models.efficientnet_v2_s(weights="IMAGENET1K_V1")
42
+ model.classifier[1] = nn.Linear(1280, len(label_mapping))
43
  return model.to(device)
44
 
45
  def get_deit():
46
  model = ViTForImageClassification.from_pretrained(
47
  'facebook/deit-base-patch16-224',
48
+ num_labels=len(label_mapping),
49
  ignore_mismatched_sizes=True
50
  )
51
  return model.to(device)
52
 
53
  # Transforms
54
+ train_transform = transforms.Compose([
55
+ transforms.Resize((224, 224)),
56
+ transforms.ToTensor(),
57
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
58
+ ])
59
+
60
  def transform_image(image):
61
+ return train_transform(image).unsqueeze(0).to(device)
 
 
 
 
 
 
 
 
 
62
 
63
  # Model Handler
64
  class ModelHandler:
 
70
 
71
  def load_models(self):
72
  try:
 
73
  self.efficientnet = get_efficientnet()
74
+ eff_path = os.path.join(CHECKPOINTS_PATH, "efficientnet_best.pth")
75
+ self.efficientnet.load_state_dict(torch.load(eff_path, map_location=device))
76
  self.efficientnet.eval()
77
 
 
78
  self.deit = get_deit()
79
  deit_path = os.path.join(CHECKPOINTS_PATH, "deit_best.pth")
80
  self.deit.load_state_dict(torch.load(deit_path, map_location=device))
 
83
  self.models_loaded = True
84
  print("✅ Models loaded successfully")
85
  except Exception as e:
86
+ print(f"❌ Error loading models: {e}")
87
  self.models_loaded = False
88
 
89
  @torch.no_grad()
90
+ def predict(self, image, use='efficientnet'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  if not self.models_loaded:
92
  return {"error": "Модели не загружены"}
 
93
  inputs = transform_image(image)
94
  ctx = autocast() if device.type == 'cuda' else nullcontext()
95
  with ctx:
96
+ if use == 'efficientnet':
97
+ logits = self.efficientnet(inputs)
98
+ elif use == 'deit':
99
+ logits = self.deit(pixel_values=inputs).logits
100
+ else:
101
+ logits = (self.efficientnet(inputs) + self.deit(pixel_values=inputs).logits) / 2
102
+ probs = torch.nn.functional.softmax(logits, dim=1)
103
  return self._format_predictions(probs)
104
 
105
+ def _format_predictions(self, probs):
106
+ top5_probs, top5_inds = torch.topk(probs, 5)
107
+ return {label_mapping[i.item()]: float(top5_probs[0][k].item())
108
+ for k, i in enumerate(top5_inds[0])}
109
 
110
+ # Initialize handler
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  model_handler = ModelHandler()
112
 
 
113
  def predict_efficientnet(image):
114
+ return "⚠️ Загрузите изображение" if image is None else model_handler.predict(image, 'efficientnet')
 
 
115
 
116
  def predict_deit(image):
117
+ return "⚠️ Загрузите изображение" if image is None else model_handler.predict(image, 'deit')
 
 
118
 
119
  def predict_ensemble(image):
120
+ return "⚠️ Загрузите изображение" if image is None else model_handler.predict(image, 'ensemble')
121
+
122
+ # Finetuning logic
123
+
124
+ def finetune_models():
125
+ # Prepare dataset
126
+ dataset = ImageFolder(SUBMISSIONS_PATH, transform=train_transform)
127
+ loader = DataLoader(dataset, batch_size=8, shuffle=True)
128
+
129
+ # Finetune EfficientNet
130
+ eff = get_efficientnet()
131
+ eff.load_state_dict(torch.load(os.path.join(CHECKPOINTS_PATH, "efficientnet_best.pth"), map_location=device))
132
+ eff.train()
133
+ optimizer = torch.optim.Adam(eff.parameters(), lr=LR)
134
+ criterion = nn.CrossEntropyLoss()
135
+ for epoch in range(FT_EPOCHS):
136
+ for imgs, lbls in loader:
137
+ imgs, lbls = imgs.to(device), lbls.to(device)
138
+ optimizer.zero_grad()
139
+ outputs = eff(imgs)
140
+ loss = criterion(outputs, lbls)
141
+ loss.backward()
142
+ optimizer.step()
143
+ torch.save(eff.state_dict(), os.path.join(CHECKPOINTS_PATH, "efficientnet_best.pth"))
144
+
145
+ # Finetune DeiT
146
+ dt = get_deit()
147
+ dt.load_state_dict(torch.load(os.path.join(CHECKPOINTS_PATH, "deit_best.pth"), map_location=device))
148
+ dt.train()
149
+ optimizer = torch.optim.Adam(dt.parameters(), lr=LR)
150
+ for epoch in range(FT_EPOCHS):
151
+ for imgs, lbls in loader:
152
+ imgs, lbls = imgs.to(device), lbls.to(device)
153
+ optimizer.zero_grad()
154
+ outputs = dt(pixel_values=imgs).logits
155
+ loss = criterion(outputs, lbls)
156
+ loss.backward()
157
+ optimizer.step()
158
+ torch.save(dt.state_dict(), os.path.join(CHECKPOINTS_PATH, "deit_best.pth"))
159
+
160
+ # Reload into handler
161
+ model_handler.load_models()
162
+ print("🔄 Models fine-tuned and reloaded")
163
+
164
+
165
+ def handle_submission(image, label):
166
+ if image is None or label is None:
167
+ return "⚠️ Загрузите изображение и выберите метку"
168
+ # Save image under label folder
169
+ lbl_dir = os.path.join(SUBMISSIONS_PATH, str(label))
170
+ os.makedirs(lbl_dir, exist_ok=True)
171
+ idx = len([f for f in os.listdir(lbl_dir) if f.endswith(('.png','.jpg'))]) + 1
172
+ path = os.path.join(lbl_dir, f"{label}_{idx}.png")
173
+ image.save(path)
174
+
175
+ # Count total submissions
176
+ total = sum(len(files) for _, _, files in os.walk(SUBMISSIONS_PATH))
177
+ rem = FT_BATCH - (total % FT_BATCH)
178
+ if rem == FT_BATCH:
179
+ rem = 0 # just reached batch multiple
180
+ # Trigger finetune if batch complete
181
+ if total % FT_BATCH == 0:
182
+ finetune_models()
183
+ # Clear submissions
184
+ for root, _, files in os.walk(SUBMISSIONS_PATH):
185
+ for f in files:
186
+ os.remove(os.path.join(root, f))
187
+
188
+ return f"Осталось {rem} изображений до следующей тонкой настройки"
189
+
190
+ # Create Gradio interface
191
  def create_interface():
192
  with gr.Blocks() as demo:
193
  gr.Markdown("# Диагностика кожных поражений (HAM10K)")
 
196
 
197
  with gr.Tabs():
198
  with gr.TabItem("EfficientNet"):
199
+ img, out = gr.Image(type="pil", label="Загрузите изображение"), gr.Label(label="Результаты")
200
+ gr.Button("Предсказать").click(predict_efficientnet, inputs=img, outputs=out)
 
 
 
201
 
202
  with gr.TabItem("DeiT"):
203
+ img, out = gr.Image(type="pil", label="Загрузите изображение"), gr.Label(label="Результаты")
204
+ gr.Button("Предсказать").click(predict_deit, inputs=img, outputs=out)
 
 
 
205
 
206
  with gr.TabItem("Ансамблевая модель"):
207
+ img, out = gr.Image(type="pil", label="Загрузите изображение"), gr.Label(label="Результаты")
208
+ gr.Button("Предсказать").click(predict_ensemble, inputs=img, outputs=out)
209
+
210
+ with gr.TabItem("Submit for Finetuning"):
211
+ sub_img = gr.Image(type="pil", label="Изображение для тонкой настройки")
212
+ sub_lbl = gr.Dropdown(choices=list(label_mapping.values()), label="Выберите метку")
213
+ sub_btn = gr.Button("Отправить")
214
+ sub_out = gr.Textbox(label="Статус")
215
+ sub_btn.click(handle_submission, inputs=[sub_img, sub_lbl], outputs=sub_out)
216
 
217
  return demo
218
 
 
219
  if __name__ == "__main__":
220
  interface = create_interface()
221
  print("🚀 Запуск интерфейса...")
222
+ interface.launch(server_port=7860)