ZDPLI's picture
Update app.py
5681f8c verified
import gradio as gr
import torch
from PIL import Image
import numpy as np
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification
from torch import nn
from torch.cuda.amp import autocast
import os
from contextlib import nullcontext
# Global configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Label mapping (HAM10K)
label_mapping = {
0: "Меланоцитарный невус",
1: "Меланома",
2: "Базальноклеточная карцинома",
3: "Актинический кератоз",
4: "Доброкачественная кератома",
5: "Дерматофиброма",
6: "Сосудистые поражения"
}
# Paths and hyperparams
CHECKPOINTS_PATH = os.getenv("CHECKPOINTS_PATH", "./")
SUBMISSIONS_PATH = os.getenv("SUBMISSIONS_PATH", "./submissions")
FT_BATCH = 32
FT_EPOCHS = 1 # adjust as needed
LR = 1e-4
os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
os.makedirs(SUBMISSIONS_PATH, exist_ok=True)
# Model definitions
def get_efficientnet():
model = models.efficientnet_v2_s(weights="IMAGENET1K_V1")
model.classifier[1] = nn.Linear(1280, len(label_mapping))
return model.to(device)
def get_deit():
model = ViTForImageClassification.from_pretrained(
'facebook/deit-base-patch16-224',
num_labels=len(label_mapping),
ignore_mismatched_sizes=True
)
return model.to(device)
# Transforms
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def transform_image(image):
return train_transform(image).unsqueeze(0).to(device)
# Model Handler
class ModelHandler:
def __init__(self):
self.efficientnet = None
self.deit = None
self.models_loaded = False
self.load_models()
def load_models(self):
try:
self.efficientnet = get_efficientnet()
eff_path = os.path.join(CHECKPOINTS_PATH, "efficientnet_best.pth")
self.efficientnet.load_state_dict(torch.load(eff_path, map_location=device))
self.efficientnet.eval()
self.deit = get_deit()
deit_path = os.path.join(CHECKPOINTS_PATH, "deit_best.pth")
self.deit.load_state_dict(torch.load(deit_path, map_location=device))
self.deit.eval()
self.models_loaded = True
print("✅ Models loaded successfully")
except Exception as e:
print(f"❌ Error loading models: {e}")
self.models_loaded = False
@torch.no_grad()
def predict(self, image, use='efficientnet'):
if not self.models_loaded:
return {"error": "Модели не загружены"}
inputs = transform_image(image)
ctx = autocast() if device.type == 'cuda' else nullcontext()
with ctx:
if use == 'efficientnet':
logits = self.efficientnet(inputs)
elif use == 'deit':
logits = self.deit(pixel_values=inputs).logits
else:
logits = (self.efficientnet(inputs) + self.deit(pixel_values=inputs).logits) / 2
probs = torch.nn.functional.softmax(logits, dim=1)
return self._format_predictions(probs)
def _format_predictions(self, probs):
top5_probs, top5_inds = torch.topk(probs, 5)
return {label_mapping[i.item()]: float(top5_probs[0][k].item())
for k, i in enumerate(top5_inds[0])}
# Initialize handler
model_handler = ModelHandler()
def predict_efficientnet(image):
return "⚠️ Загрузите изображение" if image is None else model_handler.predict(image, 'efficientnet')
def predict_deit(image):
return "⚠️ Загрузите изображение" if image is None else model_handler.predict(image, 'deit')
def predict_ensemble(image):
return "⚠️ Загрузите изображение" if image is None else model_handler.predict(image, 'ensemble')
# Finetuning logic
def finetune_models():
# Prepare dataset
dataset = ImageFolder(SUBMISSIONS_PATH, transform=train_transform)
loader = DataLoader(dataset, batch_size=8, shuffle=True)
# Finetune EfficientNet
eff = get_efficientnet()
eff.load_state_dict(torch.load(os.path.join(CHECKPOINTS_PATH, "efficientnet_best.pth"), map_location=device))
eff.train()
optimizer = torch.optim.Adam(eff.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()
for epoch in range(FT_EPOCHS):
for imgs, lbls in loader:
imgs, lbls = imgs.to(device), lbls.to(device)
optimizer.zero_grad()
outputs = eff(imgs)
loss = criterion(outputs, lbls)
loss.backward()
optimizer.step()
torch.save(eff.state_dict(), os.path.join(CHECKPOINTS_PATH, "efficientnet_best.pth"))
# Finetune DeiT
dt = get_deit()
dt.load_state_dict(torch.load(os.path.join(CHECKPOINTS_PATH, "deit_best.pth"), map_location=device))
dt.train()
optimizer = torch.optim.Adam(dt.parameters(), lr=LR)
for epoch in range(FT_EPOCHS):
for imgs, lbls in loader:
imgs, lbls = imgs.to(device), lbls.to(device)
optimizer.zero_grad()
outputs = dt(pixel_values=imgs).logits
loss = criterion(outputs, lbls)
loss.backward()
optimizer.step()
torch.save(dt.state_dict(), os.path.join(CHECKPOINTS_PATH, "deit_best.pth"))
# Reload into handler
model_handler.load_models()
print("🔄 Models fine-tuned and reloaded")
def handle_submission(image, label):
if image is None or label is None:
return "⚠️ Загрузите изображение и выберите метку"
# Save image under label folder
lbl_dir = os.path.join(SUBMISSIONS_PATH, str(label))
os.makedirs(lbl_dir, exist_ok=True)
idx = len([f for f in os.listdir(lbl_dir) if f.endswith(('.png','.jpg'))]) + 1
path = os.path.join(lbl_dir, f"{label}_{idx}.png")
image.save(path)
# Count total submissions
total = sum(len(files) for _, _, files in os.walk(SUBMISSIONS_PATH))
rem = FT_BATCH - (total % FT_BATCH)
if rem == FT_BATCH:
rem = 0 # just reached batch multiple
# Trigger finetune if batch complete
if total % FT_BATCH == 0:
finetune_models()
# Clear submissions
for root, _, files in os.walk(SUBMISSIONS_PATH):
for f in files:
os.remove(os.path.join(root, f))
return f"Осталось {rem} изображений до следующей тонкой настройки"
# Create Gradio interface
def create_interface():
with gr.Blocks() as demo:
gr.Markdown("# Диагностика кожных поражений (HAM10K)")
status = "✅ Модели готовы к предсказанию" if model_handler.models_loaded else "⚠️ Предупреждение: Модели не загружены"
gr.Markdown(f"**Состояние моделей:** {status}")
with gr.Tabs():
with gr.TabItem("EfficientNet"):
img, out = gr.Image(type="pil", label="Загрузите изображение"), gr.Label(label="Результаты")
gr.Button("Предсказать").click(predict_efficientnet, inputs=img, outputs=out)
with gr.TabItem("DeiT"):
img, out = gr.Image(type="pil", label="Загрузите изображение"), gr.Label(label="Результаты")
gr.Button("Предсказать").click(predict_deit, inputs=img, outputs=out)
with gr.TabItem("Ансамблевая модель"):
img, out = gr.Image(type="pil", label="Загрузите изображение"), gr.Label(label="Результаты")
gr.Button("Предсказать").click(predict_ensemble, inputs=img, outputs=out)
with gr.TabItem("Submit for Finetuning"):
sub_img = gr.Image(type="pil", label="Изображение для тонкой настройки")
sub_lbl = gr.Dropdown(choices=list(label_mapping.values()), label="Выберите метку")
sub_btn = gr.Button("Отправить")
sub_out = gr.Textbox(label="Статус")
sub_btn.click(handle_submission, inputs=[sub_img, sub_lbl], outputs=sub_out)
return demo
if __name__ == "__main__":
interface = create_interface()
print("🚀 Запуск интерфейса...")
interface.launch(server_port=7860)