import gradio as gr import torch from PIL import Image import numpy as np from torchvision import models, transforms from torchvision.datasets import ImageFolder from torch.utils.data import DataLoader from transformers import ViTForImageClassification from torch import nn from torch.cuda.amp import autocast import os from contextlib import nullcontext # Global configuration device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") # Label mapping (HAM10K) label_mapping = { 0: "Меланоцитарный невус", 1: "Меланома", 2: "Базальноклеточная карцинома", 3: "Актинический кератоз", 4: "Доброкачественная кератома", 5: "Дерматофиброма", 6: "Сосудистые поражения" } # Paths and hyperparams CHECKPOINTS_PATH = os.getenv("CHECKPOINTS_PATH", "./") SUBMISSIONS_PATH = os.getenv("SUBMISSIONS_PATH", "./submissions") FT_BATCH = 32 FT_EPOCHS = 1 # adjust as needed LR = 1e-4 os.makedirs(CHECKPOINTS_PATH, exist_ok=True) os.makedirs(SUBMISSIONS_PATH, exist_ok=True) # Model definitions def get_efficientnet(): model = models.efficientnet_v2_s(weights="IMAGENET1K_V1") model.classifier[1] = nn.Linear(1280, len(label_mapping)) return model.to(device) def get_deit(): model = ViTForImageClassification.from_pretrained( 'facebook/deit-base-patch16-224', num_labels=len(label_mapping), ignore_mismatched_sizes=True ) return model.to(device) # Transforms train_transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) def transform_image(image): return train_transform(image).unsqueeze(0).to(device) # Model Handler class ModelHandler: def __init__(self): self.efficientnet = None self.deit = None self.models_loaded = False self.load_models() def load_models(self): try: self.efficientnet = get_efficientnet() eff_path = os.path.join(CHECKPOINTS_PATH, "efficientnet_best.pth") self.efficientnet.load_state_dict(torch.load(eff_path, map_location=device)) self.efficientnet.eval() self.deit = get_deit() deit_path = os.path.join(CHECKPOINTS_PATH, "deit_best.pth") self.deit.load_state_dict(torch.load(deit_path, map_location=device)) self.deit.eval() self.models_loaded = True print("✅ Models loaded successfully") except Exception as e: print(f"❌ Error loading models: {e}") self.models_loaded = False @torch.no_grad() def predict(self, image, use='efficientnet'): if not self.models_loaded: return {"error": "Модели не загружены"} inputs = transform_image(image) ctx = autocast() if device.type == 'cuda' else nullcontext() with ctx: if use == 'efficientnet': logits = self.efficientnet(inputs) elif use == 'deit': logits = self.deit(pixel_values=inputs).logits else: logits = (self.efficientnet(inputs) + self.deit(pixel_values=inputs).logits) / 2 probs = torch.nn.functional.softmax(logits, dim=1) return self._format_predictions(probs) def _format_predictions(self, probs): top5_probs, top5_inds = torch.topk(probs, 5) return {label_mapping[i.item()]: float(top5_probs[0][k].item()) for k, i in enumerate(top5_inds[0])} # Initialize handler model_handler = ModelHandler() def predict_efficientnet(image): return "⚠️ Загрузите изображение" if image is None else model_handler.predict(image, 'efficientnet') def predict_deit(image): return "⚠️ Загрузите изображение" if image is None else model_handler.predict(image, 'deit') def predict_ensemble(image): return "⚠️ Загрузите изображение" if image is None else model_handler.predict(image, 'ensemble') # Finetuning logic def finetune_models(): # Prepare dataset dataset = ImageFolder(SUBMISSIONS_PATH, transform=train_transform) loader = DataLoader(dataset, batch_size=8, shuffle=True) # Finetune EfficientNet eff = get_efficientnet() eff.load_state_dict(torch.load(os.path.join(CHECKPOINTS_PATH, "efficientnet_best.pth"), map_location=device)) eff.train() optimizer = torch.optim.Adam(eff.parameters(), lr=LR) criterion = nn.CrossEntropyLoss() for epoch in range(FT_EPOCHS): for imgs, lbls in loader: imgs, lbls = imgs.to(device), lbls.to(device) optimizer.zero_grad() outputs = eff(imgs) loss = criterion(outputs, lbls) loss.backward() optimizer.step() torch.save(eff.state_dict(), os.path.join(CHECKPOINTS_PATH, "efficientnet_best.pth")) # Finetune DeiT dt = get_deit() dt.load_state_dict(torch.load(os.path.join(CHECKPOINTS_PATH, "deit_best.pth"), map_location=device)) dt.train() optimizer = torch.optim.Adam(dt.parameters(), lr=LR) for epoch in range(FT_EPOCHS): for imgs, lbls in loader: imgs, lbls = imgs.to(device), lbls.to(device) optimizer.zero_grad() outputs = dt(pixel_values=imgs).logits loss = criterion(outputs, lbls) loss.backward() optimizer.step() torch.save(dt.state_dict(), os.path.join(CHECKPOINTS_PATH, "deit_best.pth")) # Reload into handler model_handler.load_models() print("🔄 Models fine-tuned and reloaded") def handle_submission(image, label): if image is None or label is None: return "⚠️ Загрузите изображение и выберите метку" # Save image under label folder lbl_dir = os.path.join(SUBMISSIONS_PATH, str(label)) os.makedirs(lbl_dir, exist_ok=True) idx = len([f for f in os.listdir(lbl_dir) if f.endswith(('.png','.jpg'))]) + 1 path = os.path.join(lbl_dir, f"{label}_{idx}.png") image.save(path) # Count total submissions total = sum(len(files) for _, _, files in os.walk(SUBMISSIONS_PATH)) rem = FT_BATCH - (total % FT_BATCH) if rem == FT_BATCH: rem = 0 # just reached batch multiple # Trigger finetune if batch complete if total % FT_BATCH == 0: finetune_models() # Clear submissions for root, _, files in os.walk(SUBMISSIONS_PATH): for f in files: os.remove(os.path.join(root, f)) return f"Осталось {rem} изображений до следующей тонкой настройки" # Create Gradio interface def create_interface(): with gr.Blocks() as demo: gr.Markdown("# Диагностика кожных поражений (HAM10K)") status = "✅ Модели готовы к предсказанию" if model_handler.models_loaded else "⚠️ Предупреждение: Модели не загружены" gr.Markdown(f"**Состояние моделей:** {status}") with gr.Tabs(): with gr.TabItem("EfficientNet"): img, out = gr.Image(type="pil", label="Загрузите изображение"), gr.Label(label="Результаты") gr.Button("Предсказать").click(predict_efficientnet, inputs=img, outputs=out) with gr.TabItem("DeiT"): img, out = gr.Image(type="pil", label="Загрузите изображение"), gr.Label(label="Результаты") gr.Button("Предсказать").click(predict_deit, inputs=img, outputs=out) with gr.TabItem("Ансамблевая модель"): img, out = gr.Image(type="pil", label="Загрузите изображение"), gr.Label(label="Результаты") gr.Button("Предсказать").click(predict_ensemble, inputs=img, outputs=out) with gr.TabItem("Submit for Finetuning"): sub_img = gr.Image(type="pil", label="Изображение для тонкой настройки") sub_lbl = gr.Dropdown(choices=list(label_mapping.values()), label="Выберите метку") sub_btn = gr.Button("Отправить") sub_out = gr.Textbox(label="Статус") sub_btn.click(handle_submission, inputs=[sub_img, sub_lbl], outputs=sub_out) return demo if __name__ == "__main__": interface = create_interface() print("🚀 Запуск интерфейса...") interface.launch(server_port=7860)