import base64
import html
import io
import logging
import os
import zipfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from fastai.learner import load_learner
from fastai.vision.core import PILImage
from matplotlib.figure import Figure
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from transformers import ViTForImageClassification, ViTImageProcessor

logger = logging.getLogger(__name__)

# --- Load the ViT model ---
MODEL_NAME = "ahishamm/vit-base-HAM-10000-sharpened-patch-32"
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model_vit = ViTForImageClassification.from_pretrained(MODEL_NAME)
model_vit.eval()

# --- Load the Fast.ai learners ---
# NOTE(security): load_learner unpickles the file; only ever load trusted .pkl files.
model_malignancy = load_learner("ada_learn_malben.pkl")
model_norm2000 = load_learner("ada_learn_skin_norm2000.pkl")

# --- Load the EfficientNetB3 model from Hugging Face ---
EFFNET_NAME = "syaha/skin_cancer_detection_model"
model_effnet = AutoModelForImageClassification.from_pretrained(EFFNET_NAME)
extractor_effnet = AutoFeatureExtractor.from_pretrained(EFFNET_NAME)
model_effnet.eval()

# Display names for the ViT outputs. Assumed to follow the model's output index
# order — TODO confirm against model_vit.config.id2label.
CLASSES = [
    "Queratosis actínica / Bowen",
    "Carcinoma células basales",
    "Lesión queratósica benigna",
    "Dermatofibroma",
    "Melanoma maligno",
    "Nevus melanocítico",
    "Lesión vascular",
]

# Per-class bar colour for the chart and weight used in the ViT risk score.
RISK_LEVELS = {
    0: {'level': 'Moderado', 'color': '#ffaa00', 'weight': 0.6},
    1: {'level': 'Alto', 'color': '#ff4444', 'weight': 0.8},
    2: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1},
    3: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1},
    4: {'level': 'Crítico', 'color': '#cc0000', 'weight': 1.0},
    5: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1},
    6: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1},
}

MALIGNANT_INDICES = [0, 1, 4]  # indices of the high/critical-risk classes


def _fmt_conf(value, ok):
    """Format a probability as a percentage, or "N/A" when the model failed."""
    return f"{value:.1%}" if ok else "N/A"


def _predict_vit(img):
    """Run the ViT classifier.

    Returns:
        (label, confidence, probs, ok). ``probs`` always has ``len(CLASSES)``
        entries; on failure it is all zeros and ``ok`` is False.
    """
    try:
        inputs = feature_extractor(img, return_tensors="pt")
        with torch.no_grad():
            logits = model_vit(**inputs).logits
        probs = logits.softmax(dim=-1).cpu().numpy()[0]
        # The chart and risk score index probs by CLASSES; refuse a mismatch
        # instead of crashing later outside any error handling.
        if probs.shape[0] != len(CLASSES):
            raise ValueError(
                f"ViT returned {probs.shape[0]} classes, expected {len(CLASSES)}"
            )
    except Exception:
        logger.exception("ViT inference failed")
        return "Error", 0.0, np.zeros(len(CLASSES)), False
    idx = int(np.argmax(probs))
    return CLASSES[idx], float(probs[idx]), probs, True


def _predict_malignancy(img_fastai):
    """Return (probability_of_malignancy, ok) from the Fast.ai benign/malignant learner."""
    if img_fastai is None:
        return 0.0, False
    try:
        _, _, probs = model_malignancy.predict(img_fastai)
        # Index 1 is assumed to be the "malignant" class —
        # TODO confirm against model_malignancy.dls.vocab.
        return float(probs[1]), True
    except Exception:
        logger.exception("Fast.ai malignancy inference failed")
        return 0.0, False


def _predict_fastai_type(img_fastai):
    """Return the predicted lesion type from the Fast.ai classifier, or "Error"."""
    if img_fastai is None:
        return "Error"
    try:
        pred, _, _ = model_norm2000.predict(img_fastai)
        return str(pred)
    except Exception:
        logger.exception("Fast.ai type inference failed")
        return "Error"


def _predict_effnet(img):
    """Run the EfficientNetB3 classifier; return (label, confidence, ok)."""
    try:
        inputs = extractor_effnet(images=img, return_tensors="pt")
        with torch.no_grad():
            logits = model_effnet(**inputs).logits
        probs = logits.softmax(dim=-1).cpu().numpy()[0]
    except Exception:
        logger.exception("EfficientNetB3 inference failed")
        return "Error", 0.0, False
    idx = int(np.argmax(probs))
    # PretrainedConfig normalises id2label keys to int, so look up by int first;
    # the str key and a generic label are only fallbacks.
    id2label = model_effnet.config.id2label
    label = id2label.get(idx) or id2label.get(str(idx)) or f"Clase {idx}"
    return label, float(probs[idx]), True


def _recommendation(probs_vit, prob_malignant, has_signal):
    """Map the ViT-weighted risk score and malignancy probability to a triage message.

    ``has_signal`` is False when neither ViT nor the malignancy model produced a
    result; reporting "low risk" in that case would be misleading.
    """
    if not has_signal:
        return "❓ Sin recomendación – los modelos no pudieron analizar la imagen"
    cancer_risk_score = sum(
        p * RISK_LEVELS[i]['weight'] for i, p in enumerate(probs_vit)
    )
    if prob_malignant > 0.7 or cancer_risk_score > 0.6:
        return "🚨 CRÍTICO – Derivación urgente a oncología dermatológica"
    if prob_malignant > 0.4 or cancer_risk_score > 0.4:
        return "⚠️ ALTO RIESGO – Consulta con dermatólogo en 7 días"
    if cancer_risk_score > 0.2:
        return "📋 RIESGO MODERADO – Evaluación programada (2-4 semanas)"
    return "✅ BAJO RIESGO – Seguimiento de rutina (3-6 meses)"


def _build_report(rows, recommendation):
    """Render the per-model results table and recommendation as an HTML fragment.

    ``rows`` is a list of (model_name, result, confidence_text) tuples; every
    value is HTML-escaped because model labels are interpolated verbatim.
    """
    body = "".join(
        f"<tr><td>{html.escape(model)}</td><td>{html.escape(str(result))}</td>"
        f"<td>{html.escape(conf)}</td></tr>"
        for model, result, conf in rows
    )
    return (
        '<div style="font-family: sans-serif;">'
        "<h3>🧪 Diagnóstico por 4 modelos de IA</h3>"
        '<table style="width: 100%; border-collapse: collapse;">'
        "<thead><tr><th>🔍 Modelo</th><th>Resultado</th><th>Confianza</th></tr></thead>"
        f"<tbody>{body}</tbody>"
        "</table>"
        "<h4>🧪 Recomendación automática:</h4>"
        f"<p><b>{html.escape(recommendation)}</b></p>"
        "</div>"
    )


def _render_chart(probs_vit):
    """Draw the ViT probability bar chart and return it as an inline base64 <img> tag."""
    # Use the object-oriented Figure API rather than pyplot: pyplot keeps global
    # state, which is unsafe when Gradio serves requests from worker threads.
    fig = Figure(figsize=(8, 3))
    ax = fig.subplots()
    positions = np.arange(len(CLASSES))
    colors = [RISK_LEVELS[i]['color'] for i in range(len(CLASSES))]
    ax.bar(positions, np.asarray(probs_vit) * 100, color=colors)
    ax.set_title("Probabilidad ViT por tipo de lesión")
    ax.set_ylabel("Probabilidad (%)")
    ax.set_xticks(positions)
    ax.set_xticklabels(CLASSES, rotation=45, ha='right')
    ax.grid(axis='y', alpha=0.2)
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    encoded = base64.b64encode(buf.getvalue()).decode("ascii")
    return (
        f'<img src="data:image/png;base64,{encoded}" '
        'alt="Probabilidad ViT por tipo de lesión" style="max-width: 100%;"/>'
    )


def analizar_lesion_combined(img):
    """Analyse one lesion image with all four models.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or None when the
            user submits without uploading an image.

    Returns:
        (report_html, chart_html): the combined HTML report and an HTML
        ``<img>`` tag with the ViT probability chart.
    """
    if img is None:
        return "<p>⚠️ Sube una imagen de la lesión para analizarla.</p>", ""
    if isinstance(img, Image.Image):
        # Normalise RGBA/grayscale uploads so every model gets 3-channel input.
        img = img.convert("RGB")

    try:
        img_fastai = PILImage.create(img)
    except Exception:
        logger.exception("Could not convert the image for Fast.ai")
        img_fastai = None

    label_vit, conf_vit, probs_vit, ok_vit = _predict_vit(img)
    prob_malignant, ok_mal = _predict_malignancy(img_fastai)
    label_fast = _predict_fastai_type(img_fastai)
    label_eff, conf_eff, ok_eff = _predict_effnet(img)

    if ok_mal:
        malignancy_label = "Maligno" if prob_malignant > 0.5 else "Benigno"
    else:
        malignancy_label = "Error"

    rows = [
        ("🧠 ViT (transformer)", label_vit, _fmt_conf(conf_vit, ok_vit)),
        ("🧬 Fast.ai (clasificación)", label_fast, "N/A"),
        ("⚠️ Fast.ai (malignidad)", malignancy_label, _fmt_conf(prob_malignant, ok_mal)),
        ("🔬 EfficientNetB3 (HAM10000)", label_eff, _fmt_conf(conf_eff, ok_eff)),
    ]
    recommendation = _recommendation(probs_vit, prob_malignant, ok_vit or ok_mal)
    return _build_report(rows, recommendation), _render_chart(probs_vit)


# Gradio interface
demo = gr.Interface(
    fn=analizar_lesion_combined,
    inputs=gr.Image(type="pil", label="Sube una imagen de la lesión"),
    outputs=[gr.HTML(label="Informe combinado"), gr.HTML(label="Gráfico ViT")],
    title="Detector de Lesiones Cutáneas (ViT + Fast.ai + EfficientNetB3)",
    description="Comparación entre ViT transformer (HAM10000), dos modelos Fast.ai y un modelo EfficientNetB3.",
    flagging_mode="never",
)

if __name__ == "__main__":
    demo.launch()