Spaces:
Sleeping
Sleeping
File size: 5,930 Bytes
fa2b4a8 ba932fd fa2b4a8 ba932fd fa2b4a8 ba932fd b207263 fa2b4a8 f3d9a1b ba932fd e44c49f f3d9a1b 16c2fe3 eb8c75c b207263 f3d9a1b ba932fd fa2b4a8 ba932fd f3d9a1b fa2b4a8 f3d9a1b 34a30bc cdc152a f3d9a1b b207263 f3d9a1b b207263 f3d9a1b ba932fd f3d9a1b ba932fd b224d9f ba932fd f3d9a1b ba932fd b207263 ba932fd f3d9a1b b207263 f3d9a1b ba932fd b207263 f3d9a1b b207263 ba932fd b207263 ba932fd f3d9a1b ba932fd b207263 ba932fd fa2b4a8 f3d9a1b b207263 fa2b4a8 b207263 fa2b4a8 f3d9a1b b207263 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import torch
from transformers import ViTImageProcessor, ViTForImageClassification
from fastai.learner import load_learner
from fastai.vision.core import PILImage
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
import io
import base64
from torchvision import transforms
from efficientnet_pytorch import EfficientNet
# --- Model loading -----------------------------------------------------------
# Everything below runs once, at import time. `from_pretrained` with a Hub
# repo id fetches (and caches) the checkpoint from the Hugging Face Hub.

# ViT fine-tuned on HAM10000 (pretrained ViT, fine-tuned checkpoint).
# NOTE(review): the order of this model's outputs is defined by its
# config.id2label; any code that maps the raw argmax straight onto CLASSES
# assumes that order matches CLASSES — confirm.
TF_MODEL_NAME = "Anwarkh1/Skin_Cancer-Image_Classification"
feature_extractor_tf = ViTImageProcessor.from_pretrained(TF_MODEL_NAME)
model_tf_vit = ViTForImageClassification.from_pretrained(TF_MODEL_NAME)
model_tf_vit.eval()  # inference mode (disables dropout)
# Base ViT model (its probabilities feed the chart and the risk score).
MODEL_NAME = "ahishamm/vit-base-HAM-10000-sharpened-patch-32"
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model_vit = ViTForImageClassification.from_pretrained(MODEL_NAME)
model_vit.eval()  # inference mode (disables dropout)
# Local Fast.ai learners, used as a benign/malignant classifier and a
# lesion-type classifier respectively. load_learner unpickles the file, so
# only ever point these paths at trusted files.
model_malignancy = load_learner("ada_learn_malben.pkl")
model_norm2000 = load_learner("ada_learn_skin_norm2000.pkl")
# EfficientNet-B7 with a 2-class (benign vs malignant) head.
# NOTE(review): with num_classes != 1000, efficientnet_pytorch's
# from_pretrained loads the ImageNet backbone but NOT the final FC layer, so
# this 2-class head is randomly initialised and no fine-tuned weights are
# loaded here — its "malignant" probability is untrained. Verify / load a
# fine-tuned checkpoint before trusting it.
model_eff = EfficientNet.from_pretrained("efficientnet-b7", num_classes=2)
model_eff.eval()
# Preprocessing for EfficientNet: resize to 224x224, convert to a [0, 1]
# float tensor, then normalise with the ImageNet per-channel mean/std.
# Normalize receives exactly 3 channel statistics, so the input must be a
# 3-channel (RGB) image; RGBA or grayscale PIL images will fail here.
# NOTE(review): EfficientNet-B7 was designed for ~600 px inputs; 224 px works
# but may cost accuracy — confirm this is intentional.
eff_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
# HAM10000 diagnostic categories (Spanish display labels). Position i is the
# class index used by RISK_LEVELS and MALIGNANT_INDICES.
CLASSES = [
    "Queratosis actínica / Bowen",  # akiec
    "Carcinoma células basales",  # bcc
    "Lesión queratósica benigna",  # bkl
    "Dermatofibroma",  # df
    "Melanoma maligno",  # mel
    "Nevus melanocítico",  # nv
    "Lesión vascular",  # vasc
]

# Per-class triage metadata, keyed by CLASSES index:
#   level  - human-readable risk tier (Spanish)
#   color  - bar colour for that class in the probability chart
#   weight - multiplier applied to the class probability when the aggregate
#            risk score is computed
RISK_LEVELS = {
    index: {"level": level, "color": color, "weight": weight}
    for index, (level, color, weight) in enumerate(
        (
            ("Moderado", "#ffaa00", 0.6),
            ("Alto", "#ff4444", 0.8),
            ("Bajo", "#44ff44", 0.1),
            ("Bajo", "#44ff44", 0.1),
            ("Crítico", "#cc0000", 1.0),
            ("Bajo", "#44ff44", 0.1),
            ("Bajo", "#44ff44", 0.1),
        )
    )
}

# CLASSES indices treated as malignant: akiec, bcc, melanoma.
MALIGNANT_INDICES = [0, 1, 4]
# HAM10000 short codes -> index into CLASSES / RISK_LEVELS.
_HAM_CODES = {"akiec": 0, "bcc": 1, "bkl": 2, "df": 3, "mel": 4, "nv": 5, "vasc": 6}

# Keyword fallbacks for descriptive label names such as
# "basal_cell_carcinoma" or "melanocytic_Nevi" (matched after lower-casing and
# turning "_" / "-" into spaces).
_HAM_KEYWORDS = (
    (0, ("actinic", "bowen")),
    (1, ("basal",)),
    (2, ("benign keratosis", "keratosis like")),
    (3, ("dermatofibroma",)),
    (4, ("melanoma",)),
    (5, ("nevus", "nevi")),
    (6, ("vascular",)),
)


def _canonical_ham_index(label):
    """Return the CLASSES index matching a model label name, or None if unknown."""
    name = str(label).strip().lower().replace("_", " ").replace("-", " ")
    if name in _HAM_CODES:
        return _HAM_CODES[name]
    for index, keywords in _HAM_KEYWORDS:
        if any(keyword in name for keyword in keywords):
            return index
    return None


def _vit_probabilities(processor, model, img):
    """Run a Hugging Face ViT classifier on *img*; return probabilities in CLASSES order.

    The outputs are re-indexed through ``model.config.id2label`` when every
    label can be matched to exactly one HAM10000 class. Otherwise (e.g. generic
    ``LABEL_0`` names) the raw output order is assumed to already match CLASSES.

    Raises:
        ValueError: if the model does not produce one score per CLASSES entry.
    """
    inputs = processor(img, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = logits.softmax(dim=-1).cpu().numpy()[0]
    if len(probs) != len(CLASSES):
        raise ValueError(
            f"Expected {len(CLASSES)} output classes, got {len(probs)}"
        )
    id2label = getattr(model.config, "id2label", None) or {}
    order = [_canonical_ham_index(id2label.get(i, "")) for i in range(len(probs))]
    # Only trust the mapping if it is a full permutation of the class indices.
    if None in order or sorted(order) != list(range(len(CLASSES))):
        return probs
    reordered = np.empty_like(probs)
    reordered[order] = probs  # model output i -> canonical slot order[i]
    return reordered


def _probability_chart_html(probs):
    """Render *probs* (CLASSES order) as a bar chart inside a base64 <img> tag.

    Uses the object-oriented ``Figure`` API rather than pyplot, so concurrent
    Gradio requests never share pyplot's global "current figure" state.
    """
    from matplotlib.figure import Figure

    fig = Figure(figsize=(8, 3), tight_layout=True)
    ax = fig.subplots()
    colors = [RISK_LEVELS[i]['color'] for i in range(len(CLASSES))]
    ax.bar(CLASSES, probs * 100, color=colors)
    ax.set_title("Probabilidad ViT base por tipo de lesión")
    ax.set_ylabel("Probabilidad (%)")
    ax.set_xticks(np.arange(len(CLASSES)))
    ax.set_xticklabels(CLASSES, rotation=45, ha='right')
    ax.grid(axis='y', alpha=0.2)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    encoded = base64.b64encode(buf.getvalue()).decode()
    return f'<img src="data:image/png;base64,{encoded}" style="max-width:100%"/>'


def _triage_recommendation(prob_malign, risk, prob_eff):
    """Map the model scores to an HTML triage recommendation (most severe first)."""
    if prob_malign > 0.7 or risk > 0.6 or prob_eff > 0.7:
        return "🚨 <b>CRÍTICO</b> – Derivación urgente a oncología dermatológica"
    if prob_malign > 0.4 or risk > 0.4 or prob_eff > 0.5:
        return "⚠️ <b>ALTO RIESGO</b> – Consulta con dermatólogo en 7 días"
    if risk > 0.2:
        return "📋 <b>RIESGO MODERADO</b> – Evaluación programada en 2-4 semanas"
    return "✅ <b>BAJO RIESGO</b> – Seguimiento de rutina (3-6 meses)"


def analizar_lesion_combined(img):
    """Run every model on a skin-lesion image and build an HTML report.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading a picture.

    Returns:
        ``(informe, html_chart)``. ``informe`` is the HTML report, with one row
        per model plus an automatic triage recommendation. ``html_chart`` is an
        ``<img>`` tag holding the base-ViT probability bar chart. When no
        image is given, the chart is an empty string.
    """
    if img is None:
        return "<p>⚠️ Por favor, sube una imagen de la lesión.</p>", ""
    # The ViT processors and eff_transform normalise exactly 3 channels, so
    # RGBA / grayscale uploads must be converted first.
    if img.mode != "RGB":
        img = img.convert("RGB")
    img_fastai = PILImage.create(img)

    # Base ViT: full 7-class distribution (also drives the chart and risk).
    probs_vit = _vit_probabilities(feature_extractor, model_vit, img)
    idx_vit = int(np.argmax(probs_vit))
    class_vit = CLASSES[idx_vit]
    conf_vit = probs_vit[idx_vit]

    # Fast.ai models.
    # NOTE(review): index 1 is assumed to be the "malignant" entry of
    # model_malignancy's vocab — confirm against model_malignancy.dls.vocab.
    _, _, probs_mal = model_malignancy.predict(img_fastai)
    prob_malign = float(probs_mal[1])
    pred_fast_type, _, _ = model_norm2000.predict(img_fastai)

    # Fine-tuned ViT: top class and the derived benign/malignant verdict.
    probs_tf = _vit_probabilities(feature_extractor_tf, model_tf_vit, img)
    idx_tf = int(np.argmax(probs_tf))
    class_tf_model = CLASSES[idx_tf]
    conf_tf = probs_tf[idx_tf]
    mal_tf = "Maligno" if idx_tf in MALIGNANT_INDICES else "Benigno"

    # EfficientNet B7: probability of output 1, treated as "malignant".
    # NOTE(review): the 2-class head of model_eff is created by
    # EfficientNet.from_pretrained(..., num_classes=2), which does not load a
    # trained classifier layer, yet prob_eff feeds the triage thresholds below.
    with torch.no_grad():
        out_eff = model_eff(eff_transform(img).unsqueeze(0))
    prob_eff = torch.softmax(out_eff, dim=1)[0, 1].item()
    eff_result = "Maligno" if prob_eff > 0.5 else "Benigno"

    html_chart = _probability_chart_html(probs_vit)

    informe = f"""
<div style="font-family:sans-serif; max-width:800px; margin:auto">
<h2>🧪 Diagnóstico por múltiples modelos de IA</h2>
<table style="width:100%; font-size:16px; border-collapse:collapse">
<tr><th>Modelo</th><th>Resultado</th><th>Confianza</th></tr>
<tr><td>🧠 ViT base</td><td><b>{class_vit}</b></td><td>{conf_vit:.1%}</td></tr>
<tr><td>🧬 Fast.ai (tipo)</td><td><b>{pred_fast_type}</b></td><td>N/A</td></tr>
<tr><td>⚠️ Fast.ai (malignidad)</td><td><b>{'Maligno' if prob_malign > 0.5 else 'Benigno'}</b></td><td>{prob_malign:.1%}</td></tr>
<tr><td>🌟 ViT fine‑tuned (HAM10000)</td><td><b>{mal_tf} ({class_tf_model})</b></td><td>{conf_tf:.1%}</td></tr>
<tr><td>🏥 EfficientNet B7 (binario)</td><td><b>{eff_result}</b></td><td>{prob_eff:.1%}</td></tr>
</table><br>
<b>🩺 Recomendación automática:</b><br>
"""
    # Expected risk: base-ViT class probabilities weighted by RISK_LEVELS.
    risk = sum(p * RISK_LEVELS[i]['weight'] for i, p in enumerate(probs_vit))
    informe += _triage_recommendation(prob_malign, risk, prob_eff)
    informe += "</div>"
    return informe, html_chart
# Gradio front-end: one PIL image in, two HTML panels out (the multi-model
# report and the base-ViT probability chart).
_lesion_input = gr.Image(type="pil")
_report_outputs = [
    gr.HTML(label="Informe"),
    gr.HTML(label="Gráfico ViT base"),
]
demo = gr.Interface(
    fn=analizar_lesion_combined,
    inputs=_lesion_input,
    outputs=_report_outputs,
    title="Detector de Lesiones Cutáneas (ViT + Fast.ai + EfficientNet)",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()
|