File size: 8,877 Bytes
53b4cd0 49730b6 53b4cd0 981adbd 53b4cd0 c868b41 53b4cd0 981adbd 53b4cd0 49730b6 5681f8c 49730b6 53b4cd0 49730b6 53b4cd0 49730b6 53b4cd0 49730b6 53b4cd0 49730b6 53b4cd0 647b808 53b4cd0 49730b6 53b4cd0 647b808 53b4cd0 647b808 53b4cd0 49730b6 53b4cd0 647b808 53b4cd0 49730b6 53b4cd0 981adbd 49730b6 53b4cd0 647b808 49730b6 647b808 49730b6 53b4cd0 49730b6 53b4cd0 49730b6 53b4cd0 49730b6 647b808 49730b6 647b808 49730b6 647b808 49730b6 647b808 53b4cd0 647b808 53b4cd0 49730b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 |
import gradio as gr
import torch
from PIL import Image
import numpy as np
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification
from torch import nn
from torch.cuda.amp import autocast
import os
from contextlib import nullcontext
# Global configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Label mapping (HAM10K)
label_mapping = {
0: "Меланоцитарный невус",
1: "Меланома",
2: "Базальноклеточная карцинома",
3: "Актинический кератоз",
4: "Доброкачественная кератома",
5: "Дерматофиброма",
6: "Сосудистые поражения"
}
# Paths and hyperparams
CHECKPOINTS_PATH = os.getenv("CHECKPOINTS_PATH", "./")
SUBMISSIONS_PATH = os.getenv("SUBMISSIONS_PATH", "./submissions")
FT_BATCH = 32
FT_EPOCHS = 1 # adjust as needed
LR = 1e-4
os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
os.makedirs(SUBMISSIONS_PATH, exist_ok=True)
# Model definitions
def get_efficientnet():
model = models.efficientnet_v2_s(weights="IMAGENET1K_V1")
model.classifier[1] = nn.Linear(1280, len(label_mapping))
return model.to(device)
def get_deit():
model = ViTForImageClassification.from_pretrained(
'facebook/deit-base-patch16-224',
num_labels=len(label_mapping),
ignore_mismatched_sizes=True
)
return model.to(device)
# Transforms
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def transform_image(image):
return train_transform(image).unsqueeze(0).to(device)
# Model Handler
class ModelHandler:
def __init__(self):
self.efficientnet = None
self.deit = None
self.models_loaded = False
self.load_models()
def load_models(self):
try:
self.efficientnet = get_efficientnet()
eff_path = os.path.join(CHECKPOINTS_PATH, "efficientnet_best.pth")
self.efficientnet.load_state_dict(torch.load(eff_path, map_location=device))
self.efficientnet.eval()
self.deit = get_deit()
deit_path = os.path.join(CHECKPOINTS_PATH, "deit_best.pth")
self.deit.load_state_dict(torch.load(deit_path, map_location=device))
self.deit.eval()
self.models_loaded = True
print("✅ Models loaded successfully")
except Exception as e:
print(f"❌ Error loading models: {e}")
self.models_loaded = False
@torch.no_grad()
def predict(self, image, use='efficientnet'):
if not self.models_loaded:
return {"error": "Модели не загружены"}
inputs = transform_image(image)
ctx = autocast() if device.type == 'cuda' else nullcontext()
with ctx:
if use == 'efficientnet':
logits = self.efficientnet(inputs)
elif use == 'deit':
logits = self.deit(pixel_values=inputs).logits
else:
logits = (self.efficientnet(inputs) + self.deit(pixel_values=inputs).logits) / 2
probs = torch.nn.functional.softmax(logits, dim=1)
return self._format_predictions(probs)
def _format_predictions(self, probs):
top5_probs, top5_inds = torch.topk(probs, 5)
return {label_mapping[i.item()]: float(top5_probs[0][k].item())
for k, i in enumerate(top5_inds[0])}
# Initialize handler
model_handler = ModelHandler()
def predict_efficientnet(image):
return "⚠️ Загрузите изображение" if image is None else model_handler.predict(image, 'efficientnet')
def predict_deit(image):
return "⚠️ Загрузите изображение" if image is None else model_handler.predict(image, 'deit')
def predict_ensemble(image):
return "⚠️ Загрузите изображение" if image is None else model_handler.predict(image, 'ensemble')
# Finetuning logic
def finetune_models():
# Prepare dataset
dataset = ImageFolder(SUBMISSIONS_PATH, transform=train_transform)
loader = DataLoader(dataset, batch_size=8, shuffle=True)
# Finetune EfficientNet
eff = get_efficientnet()
eff.load_state_dict(torch.load(os.path.join(CHECKPOINTS_PATH, "efficientnet_best.pth"), map_location=device))
eff.train()
optimizer = torch.optim.Adam(eff.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()
for epoch in range(FT_EPOCHS):
for imgs, lbls in loader:
imgs, lbls = imgs.to(device), lbls.to(device)
optimizer.zero_grad()
outputs = eff(imgs)
loss = criterion(outputs, lbls)
loss.backward()
optimizer.step()
torch.save(eff.state_dict(), os.path.join(CHECKPOINTS_PATH, "efficientnet_best.pth"))
# Finetune DeiT
dt = get_deit()
dt.load_state_dict(torch.load(os.path.join(CHECKPOINTS_PATH, "deit_best.pth"), map_location=device))
dt.train()
optimizer = torch.optim.Adam(dt.parameters(), lr=LR)
for epoch in range(FT_EPOCHS):
for imgs, lbls in loader:
imgs, lbls = imgs.to(device), lbls.to(device)
optimizer.zero_grad()
outputs = dt(pixel_values=imgs).logits
loss = criterion(outputs, lbls)
loss.backward()
optimizer.step()
torch.save(dt.state_dict(), os.path.join(CHECKPOINTS_PATH, "deit_best.pth"))
# Reload into handler
model_handler.load_models()
print("🔄 Models fine-tuned and reloaded")
def handle_submission(image, label):
if image is None or label is None:
return "⚠️ Загрузите изображение и выберите метку"
# Save image under label folder
lbl_dir = os.path.join(SUBMISSIONS_PATH, str(label))
os.makedirs(lbl_dir, exist_ok=True)
idx = len([f for f in os.listdir(lbl_dir) if f.endswith(('.png','.jpg'))]) + 1
path = os.path.join(lbl_dir, f"{label}_{idx}.png")
image.save(path)
# Count total submissions
total = sum(len(files) for _, _, files in os.walk(SUBMISSIONS_PATH))
rem = FT_BATCH - (total % FT_BATCH)
if rem == FT_BATCH:
rem = 0 # just reached batch multiple
# Trigger finetune if batch complete
if total % FT_BATCH == 0:
finetune_models()
# Clear submissions
for root, _, files in os.walk(SUBMISSIONS_PATH):
for f in files:
os.remove(os.path.join(root, f))
return f"Осталось {rem} изображений до следующей тонкой настройки"
# Create Gradio interface
def create_interface():
with gr.Blocks() as demo:
gr.Markdown("# Диагностика кожных поражений (HAM10K)")
status = "✅ Модели готовы к предсказанию" if model_handler.models_loaded else "⚠️ Предупреждение: Модели не загружены"
gr.Markdown(f"**Состояние моделей:** {status}")
with gr.Tabs():
with gr.TabItem("EfficientNet"):
img, out = gr.Image(type="pil", label="Загрузите изображение"), gr.Label(label="Результаты")
gr.Button("Предсказать").click(predict_efficientnet, inputs=img, outputs=out)
with gr.TabItem("DeiT"):
img, out = gr.Image(type="pil", label="Загрузите изображение"), gr.Label(label="Результаты")
gr.Button("Предсказать").click(predict_deit, inputs=img, outputs=out)
with gr.TabItem("Ансамблевая модель"):
img, out = gr.Image(type="pil", label="Загрузите изображение"), gr.Label(label="Результаты")
gr.Button("Предсказать").click(predict_ensemble, inputs=img, outputs=out)
with gr.TabItem("Submit for Finetuning"):
sub_img = gr.Image(type="pil", label="Изображение для тонкой настройки")
sub_lbl = gr.Dropdown(choices=list(label_mapping.values()), label="Выберите метку")
sub_btn = gr.Button("Отправить")
sub_out = gr.Textbox(label="Статус")
sub_btn.click(handle_submission, inputs=[sub_img, sub_lbl], outputs=sub_out)
return demo
if __name__ == "__main__":
interface = create_interface()
print("🚀 Запуск интерфейса...")
interface.launch(server_port=7860)
|