# Skin-lesion classification demo (HAM10K) — Gradio app for Hugging Face Spaces.
# --- Imports -----------------------------------------------------------------
# Standard library
import os
from contextlib import nullcontext

# Third-party
import gradio as gr
import numpy as np
import torch
from PIL import Image
from torch import nn
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torchvision.datasets import ImageFolder
from transformers import ViTForImageClassification

# --- Global configuration ----------------------------------------------------
# Run on GPU when available; models and input tensors are all moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Class-index -> human-readable label mapping for the HAM10000 dataset.
# NOTE: labels are intentionally in Russian — they are shown verbatim in the UI
# and used as submission folder names; do not translate them.
label_mapping = {
    0: "Меланоцитарный невус",
    1: "Меланома",
    2: "Базальноклеточная карцинома",
    3: "Актинический кератоз",
    4: "Доброкачественная кератома",
    5: "Дерматофиброма",
    6: "Сосудистые поражения"
}

# --- Paths and hyperparameters -----------------------------------------------
CHECKPOINTS_PATH = os.getenv("CHECKPOINTS_PATH", "./")             # model weights
SUBMISSIONS_PATH = os.getenv("SUBMISSIONS_PATH", "./submissions")  # user uploads
FT_BATCH = 32    # number of submissions that triggers a fine-tuning round
FT_EPOCHS = 1    # adjust as needed
LR = 1e-4        # fine-tuning learning rate

os.makedirs(CHECKPOINTS_PATH, exist_ok=True)
os.makedirs(SUBMISSIONS_PATH, exist_ok=True)
# --- Model definitions -------------------------------------------------------
def get_efficientnet():
    """Build an EfficientNetV2-S with its head resized for HAM10K classes.

    Starts from ImageNet-pretrained weights and replaces the final linear
    layer (1280 -> ``len(label_mapping)``) so the output indices line up
    with ``label_mapping``. Returns the model moved to ``device``.
    """
    model = models.efficientnet_v2_s(weights="IMAGENET1K_V1")
    model.classifier[1] = nn.Linear(1280, len(label_mapping))
    return model.to(device)
def get_deit():
    """Build a DeiT-base image classifier with a HAM10K-sized head.

    ``ignore_mismatched_sizes=True`` lets the pretrained checkpoint load even
    though its original 1000-class head is swapped for a
    ``len(label_mapping)``-way head. Returns the model moved to ``device``.
    """
    model = ViTForImageClassification.from_pretrained(
        'facebook/deit-base-patch16-224',
        num_labels=len(label_mapping),
        ignore_mismatched_sizes=True
    )
    return model.to(device)
# --- Preprocessing -----------------------------------------------------------
# Shared transform for both training and inference: resize to the 224x224
# input both backbones expect, then normalize with ImageNet statistics
# (the same stats the pretrained weights were trained with).
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
def transform_image(image):
    """Preprocess one PIL image into a (1, 3, 224, 224) batch on ``device``."""
    return train_transform(image).unsqueeze(0).to(device)
# --- Model handler -----------------------------------------------------------
class ModelHandler:
    """Owns both classifiers and serves single-image predictions."""

    def __init__(self):
        self.efficientnet = None  # EfficientNetV2-S, set by load_models()
        self.deit = None          # DeiT-base, set by load_models()
        self.models_loaded = False
        self.load_models()

    def load_models(self):
        """(Re)load both models from their checkpoints into eval mode.

        Sets ``models_loaded`` so the UI can report readiness; any failure
        (e.g. missing checkpoint files) is caught and logged rather than
        crashing the app at import time.
        """
        try:
            self.efficientnet = get_efficientnet()
            eff_path = os.path.join(CHECKPOINTS_PATH, "efficientnet_best.pth")
            self.efficientnet.load_state_dict(torch.load(eff_path, map_location=device))
            self.efficientnet.eval()

            self.deit = get_deit()
            deit_path = os.path.join(CHECKPOINTS_PATH, "deit_best.pth")
            self.deit.load_state_dict(torch.load(deit_path, map_location=device))
            self.deit.eval()

            self.models_loaded = True
            print("✅ Models loaded successfully")
        except Exception as e:
            print(f"❌ Error loading models: {e}")
            self.models_loaded = False

    def predict(self, image, use='efficientnet'):
        """Classify a PIL image and return top-5 class probabilities.

        Parameters:
            image: input PIL image.
            use: 'efficientnet', 'deit', or any other value for the
                mean-logit ensemble of both models.

        Returns:
            dict mapping label name -> probability for the top-5 classes,
            or ``{"error": ...}`` when the models are not loaded.
        """
        if not self.models_loaded:
            return {"error": "Модели не загружены"}
        inputs = transform_image(image)
        # Mixed precision only applies on CUDA; use a no-op context on CPU.
        # no_grad avoids building an autograd graph during inference
        # (fix: the original tracked gradients needlessly, wasting memory).
        amp_ctx = autocast() if device.type == 'cuda' else nullcontext()
        with torch.no_grad(), amp_ctx:
            if use == 'efficientnet':
                logits = self.efficientnet(inputs)
            elif use == 'deit':
                logits = self.deit(pixel_values=inputs).logits
            else:
                # Simple ensemble: average the raw logits of both models.
                logits = (self.efficientnet(inputs) + self.deit(pixel_values=inputs).logits) / 2
        probs = torch.nn.functional.softmax(logits, dim=1)
        return self._format_predictions(probs)

    def _format_predictions(self, probs):
        """Map the top-5 probabilities of a (1, C) tensor to label names."""
        top5_probs, top5_inds = torch.topk(probs, 5)
        return {label_mapping[i.item()]: float(top5_probs[0][k].item())
                for k, i in enumerate(top5_inds[0])}
# Initialize the shared handler once at import time; the Gradio callbacks
# below close over this single instance.
model_handler = ModelHandler()
def predict_efficientnet(image):
    """Gradio callback: classify with EfficientNet only."""
    if image is None:
        return "⚠️ Загрузите изображение"
    return model_handler.predict(image, 'efficientnet')
def predict_deit(image):
    """Gradio callback: classify with DeiT only."""
    if image is None:
        return "⚠️ Загрузите изображение"
    return model_handler.predict(image, 'deit')
def predict_ensemble(image):
    """Gradio callback: classify with the mean-logit ensemble of both models."""
    if image is None:
        return "⚠️ Загрузите изображение"
    return model_handler.predict(image, 'ensemble')
# --- Finetuning logic --------------------------------------------------------
def finetune_models():
    """Fine-tune both models on user-submitted images, then reload them.

    Submissions live in ``SUBMISSIONS_PATH/<label name>/``. ``ImageFolder``
    assigns class indices by *alphabetical* folder order, which does not
    match ``label_mapping``'s indices — so targets are remapped to the
    indices the models were trained with (fix: the original trained against
    ImageFolder's arbitrary ordering, corrupting the checkpoints).
    """
    # Prepare dataset from the submission folders.
    dataset = ImageFolder(SUBMISSIONS_PATH, transform=train_transform)
    name_to_idx = {name: idx for idx, name in label_mapping.items()}
    # folder-assigned index -> model's label index; unknown folders pass through.
    folder_to_model_idx = {
        folder_idx: name_to_idx[folder]
        for folder, folder_idx in dataset.class_to_idx.items()
        if folder in name_to_idx
    }
    dataset.target_transform = lambda t: folder_to_model_idx.get(t, t)
    loader = DataLoader(dataset, batch_size=8, shuffle=True)
    criterion = nn.CrossEntropyLoss()  # shared by both fine-tuning loops

    # Fine-tune EfficientNet from its current best checkpoint.
    eff = get_efficientnet()
    eff.load_state_dict(torch.load(os.path.join(CHECKPOINTS_PATH, "efficientnet_best.pth"), map_location=device))
    eff.train()
    optimizer = torch.optim.Adam(eff.parameters(), lr=LR)
    for epoch in range(FT_EPOCHS):
        for imgs, lbls in loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            optimizer.zero_grad()
            loss = criterion(eff(imgs), lbls)
            loss.backward()
            optimizer.step()
    torch.save(eff.state_dict(), os.path.join(CHECKPOINTS_PATH, "efficientnet_best.pth"))

    # Fine-tune DeiT the same way.
    dt = get_deit()
    dt.load_state_dict(torch.load(os.path.join(CHECKPOINTS_PATH, "deit_best.pth"), map_location=device))
    dt.train()
    optimizer = torch.optim.Adam(dt.parameters(), lr=LR)
    for epoch in range(FT_EPOCHS):
        for imgs, lbls in loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            optimizer.zero_grad()
            loss = criterion(dt(pixel_values=imgs).logits, lbls)
            loss.backward()
            optimizer.step()
    torch.save(dt.state_dict(), os.path.join(CHECKPOINTS_PATH, "deit_best.pth"))

    # Hot-swap the updated weights into the serving handler.
    model_handler.load_models()
    print("🔄 Models fine-tuned and reloaded")
def handle_submission(image, label):
    """Gradio callback: store a labeled image and maybe trigger fine-tuning.

    The image is saved as PNG under ``SUBMISSIONS_PATH/<label>/``; once the
    total number of stored files reaches a multiple of ``FT_BATCH``, both
    models are fine-tuned on the submissions and the files are cleared.
    Returns a status string for the UI.
    """
    if image is None or label is None:
        return "⚠️ Загрузите изображение и выберите метку"
    # Save the image in a folder named after its label.
    lbl_dir = os.path.join(SUBMISSIONS_PATH, str(label))
    os.makedirs(lbl_dir, exist_ok=True)
    idx = len([f for f in os.listdir(lbl_dir) if f.endswith(('.png', '.jpg'))]) + 1
    image.save(os.path.join(lbl_dir, f"{label}_{idx}.png"))
    # Count total submissions across all label folders.
    total = sum(len(files) for _, _, files in os.walk(SUBMISSIONS_PATH))
    rem = FT_BATCH - (total % FT_BATCH)
    if rem == FT_BATCH:
        rem = 0  # just reached a batch multiple
    # Trigger fine-tuning once a full batch has accumulated.
    if total % FT_BATCH == 0:
        finetune_models()
        # Clear consumed submissions (files only; label folders are kept).
        for root, _, files in os.walk(SUBMISSIONS_PATH):
            for f in files:
                os.remove(os.path.join(root, f))
    return f"Осталось {rem} изображений до следующей тонкой настройки"
# --- Gradio interface --------------------------------------------------------
def create_interface():
    """Build the Gradio Blocks UI: three prediction tabs plus a submission tab."""
    with gr.Blocks() as demo:
        gr.Markdown("# Диагностика кожных поражений (HAM10K)")
        status = "✅ Модели готовы к предсказанию" if model_handler.models_loaded else "⚠️ Предупреждение: Модели не загружены"
        gr.Markdown(f"**Состояние моделей:** {status}")
        with gr.Tabs():
            # The three prediction tabs are identical except title + callback.
            for tab_title, predict_fn in (
                ("EfficientNet", predict_efficientnet),
                ("DeiT", predict_deit),
                ("Ансамблевая модель", predict_ensemble),
            ):
                with gr.TabItem(tab_title):
                    img = gr.Image(type="pil", label="Загрузите изображение")
                    out = gr.Label(label="Результаты")
                    gr.Button("Предсказать").click(predict_fn, inputs=img, outputs=out)
            with gr.TabItem("Submit for Finetuning"):
                sub_img = gr.Image(type="pil", label="Изображение для тонкой настройки")
                sub_lbl = gr.Dropdown(choices=list(label_mapping.values()), label="Выберите метку")
                sub_btn = gr.Button("Отправить")
                sub_out = gr.Textbox(label="Статус")
                sub_btn.click(handle_submission, inputs=[sub_img, sub_lbl], outputs=sub_out)
    return demo
if __name__ == "__main__":
    # Entry point: build the UI and serve it on the standard Spaces port.
    interface = create_interface()
    print("🚀 Запуск интерфейса...")
    interface.launch(server_port=7860)