File size: 5,930 Bytes
fa2b4a8
ba932fd
 
 
 
fa2b4a8
ba932fd
 
fa2b4a8
ba932fd
b207263
 
fa2b4a8
f3d9a1b
 
 
 
 
 
 
ba932fd
 
 
 
e44c49f
f3d9a1b
16c2fe3
 
eb8c75c
b207263
 
 
 
 
 
 
 
 
 
 
f3d9a1b
ba932fd
 
 
 
 
fa2b4a8
ba932fd
f3d9a1b
 
 
 
 
 
fa2b4a8
f3d9a1b
34a30bc
cdc152a
f3d9a1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b207263
f3d9a1b
 
 
 
 
 
 
 
 
b207263
 
 
 
 
 
 
f3d9a1b
 
ba932fd
f3d9a1b
 
ba932fd
b224d9f
ba932fd
 
 
 
 
 
f3d9a1b
ba932fd
b207263
ba932fd
 
f3d9a1b
 
 
 
 
 
 
b207263
f3d9a1b
 
ba932fd
b207263
 
f3d9a1b
b207263
ba932fd
b207263
ba932fd
f3d9a1b
 
ba932fd
 
b207263
ba932fd
 
 
fa2b4a8
 
f3d9a1b
 
b207263
fa2b4a8
b207263
fa2b4a8
 
f3d9a1b
b207263
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import torch
from transformers import ViTImageProcessor, ViTForImageClassification
from fastai.learner import load_learner
from fastai.vision.core import PILImage
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
import io
import base64
from torchvision import transforms
from efficientnet_pytorch import EfficientNet

# --- Load ViT model fine-tuned on skin-cancer images (HuggingFace hub) ---
TF_MODEL_NAME = "Anwarkh1/Skin_Cancer-Image_Classification"
feature_extractor_tf = ViTImageProcessor.from_pretrained(TF_MODEL_NAME)
model_tf_vit = ViTForImageClassification.from_pretrained(TF_MODEL_NAME)
model_tf_vit.eval()  # inference mode: disables dropout/batch-norm updates

# 🔹 Load base ViT model trained on HAM10000
MODEL_NAME = "ahishamm/vit-base-HAM-10000-sharpened-patch-32"
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model_vit = ViTForImageClassification.from_pretrained(MODEL_NAME)
model_vit.eval()

# 🔹 Load local Fast.ai learners (pickled exports expected in the working dir)
model_malignancy = load_learner("ada_learn_malben.pkl")   # binary: benign vs malignant
model_norm2000 = load_learner("ada_learn_skin_norm2000.pkl")  # lesion-type classifier

# 🔹 EfficientNet B7 for binary classification (benign vs malignant)
# NOTE(review): from_pretrained loads ImageNet backbone weights, but with
# num_classes=2 the final classification head is freshly initialized — unless
# fine-tuned weights are loaded elsewhere, its outputs are untrained. Verify.
model_eff = EfficientNet.from_pretrained("efficientnet-b7", num_classes=2)
model_eff.eval()

# Standard ImageNet preprocessing (resize + normalize) for EfficientNet input.
eff_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# Standard HAM10000 class names (Spanish display labels).
# NOTE(review): assumes every ViT checkpoint above emits logits in exactly this
# index order (akiec, bcc, bkl, df, mel, nv, vasc) — confirm against each
# model's id2label config before trusting per-class output.
CLASSES = [
    "Queratosis actínica / Bowen", "Carcinoma células basales",
    "Lesión queratósica benigna", "Dermatofibroma",
    "Melanoma maligno", "Nevus melanocítico", "Lesión vascular"
]
# Per-class display color and risk weight used for the aggregate risk score.
RISK_LEVELS = {
    0: {'level': 'Moderado', 'color': '#ffaa00', 'weight': 0.6},
    1: {'level': 'Alto',     'color': '#ff4444', 'weight': 0.8},
    2: {'level': 'Bajo',     'color': '#44ff44', 'weight': 0.1},
    3: {'level': 'Bajo',     'color': '#44ff44', 'weight': 0.1},
    4: {'level': 'Crítico',  'color': '#cc0000', 'weight': 1.0},
    5: {'level': 'Bajo',     'color': '#44ff44', 'weight': 0.1},
    6: {'level': 'Bajo',     'color': '#44ff44', 'weight': 0.1}
}
MALIGNANT_INDICES = [0, 1, 4]  # akiec, bcc, melanoma

def _vit_probs(model, processor, img):
    """Run a HuggingFace ViT classifier on *img* and return its softmax
    probability vector as a 1-D numpy array."""
    inputs = processor(img, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    return logits.softmax(dim=-1).cpu().numpy()[0]


def _chart_html(probs_vit):
    """Render the base-ViT class probabilities as an inline base64 <img> tag.

    Uses figure-level Matplotlib calls (``fig.tight_layout`` / ``fig.savefig``)
    instead of the ``plt.*`` global-state API, so concurrent Gradio requests
    cannot accidentally save each other's "current figure".
    """
    colors = [RISK_LEVELS[i]['color'] for i in range(len(CLASSES))]
    fig, ax = plt.subplots(figsize=(8, 3))
    ax.bar(CLASSES, probs_vit * 100, color=colors)
    ax.set_title("Probabilidad ViT base por tipo de lesión")
    ax.set_ylabel("Probabilidad (%)")
    ax.set_xticks(np.arange(len(CLASSES)))
    ax.set_xticklabels(CLASSES, rotation=45, ha='right')
    ax.grid(axis='y', alpha=0.2)
    fig.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)  # free the figure; pyplot keeps figures alive otherwise
    return f'<img src="data:image/png;base64,{base64.b64encode(buf.getvalue()).decode()}" style="max-width:100%"/>'


def analizar_lesion_combined(img):
    """Run four skin-lesion models on *img* and build an HTML diagnostic report.

    Parameters
    ----------
    img : PIL.Image.Image
        Lesion photograph as delivered by the Gradio ``Image(type="pil")`` input.

    Returns
    -------
    tuple[str, str]
        ``(informe, html_chart)``: the HTML report table with an automatic
        risk recommendation, and an ``<img>`` tag embedding the base-ViT
        probability bar chart.
    """
    img_fastai = PILImage.create(img)

    # --- Base ViT (7-class HAM10000) ---
    probs_vit = _vit_probs(model_vit, feature_extractor, img)
    idx_vit = int(np.argmax(probs_vit))
    class_vit = CLASSES[idx_vit]
    conf_vit = probs_vit[idx_vit]

    # --- Fast.ai models ---
    # NOTE(review): assumes index 1 of the malignancy learner's probability
    # vector is the "malignant" class — confirm against the learner's vocab.
    _, _, probs_mal = model_malignancy.predict(img_fastai)
    prob_malign = float(probs_mal[1])
    pred_fast_type, _, _ = model_norm2000.predict(img_fastai)

    # --- Fine-tuned ViT (recommended model) ---
    probs_tf = _vit_probs(model_tf_vit, feature_extractor_tf, img)
    idx_tf = int(np.argmax(probs_tf))
    class_tf_model = CLASSES[idx_tf]
    conf_tf = probs_tf[idx_tf]
    mal_tf = "Maligno" if idx_tf in MALIGNANT_INDICES else "Benigno"

    # --- EfficientNet B7 (binary benign vs malignant) ---
    img_eff = eff_transform(img).unsqueeze(0)
    with torch.no_grad():
        out_eff = model_eff(img_eff)
        prob_eff = torch.softmax(out_eff, dim=1)[0, 1].item()
    eff_result = "Maligno" if prob_eff > 0.5 else "Benigno"

    # Bar chart of the base-ViT distribution.
    html_chart = _chart_html(probs_vit)

    # Build the HTML report table.
    informe = f"""
    <div style="font-family:sans-serif; max-width:800px; margin:auto">
    <h2>🧪 Diagnóstico por múltiples modelos de IA</h2>
    <table style="width:100%; font-size:16px; border-collapse:collapse">
      <tr><th>Modelo</th><th>Resultado</th><th>Confianza</th></tr>
      <tr><td>🧠 ViT base</td><td><b>{class_vit}</b></td><td>{conf_vit:.1%}</td></tr>
      <tr><td>🧬 Fast.ai (tipo)</td><td><b>{pred_fast_type}</b></td><td>N/A</td></tr>
      <tr><td>⚠️ Fast.ai (malignidad)</td><td><b>{'Maligno' if prob_malign > 0.5 else 'Benigno'}</b></td><td>{prob_malign:.1%}</td></tr>
      <tr><td>🌟 ViT fined‑tuned (HAM10000)</td><td><b>{mal_tf} ({class_tf_model})</b></td><td>{conf_tf:.1%}</td></tr>
      <tr><td>🏥 EfficientNet B7 (binario)</td><td><b>{eff_result}</b></td><td>{prob_eff:.1%}</td></tr>
    </table><br>
    <b>🩺 Recomendación automática:</b><br>
    """

    # Aggregate risk: probability-weighted sum of per-class risk weights,
    # escalated by whichever binary model crosses its threshold first.
    risk = sum(probs_vit[i] * RISK_LEVELS[i]['weight'] for i in range(len(CLASSES)))
    if prob_malign > 0.7 or risk > 0.6 or prob_eff > 0.7:
        informe += "🚨 <b>CRÍTICO</b> – Derivación urgente a oncología dermatológica"
    elif prob_malign > 0.4 or risk > 0.4 or prob_eff > 0.5:
        informe += "⚠️ <b>ALTO RIESGO</b> – Consulta con dermatólogo en 7 días"
    elif risk > 0.2:
        informe += "📋 <b>RIESGO MODERADO</b> – Evaluación programada en 2-4 semanas"
    else:
        informe += "✅ <b>BAJO RIESGO</b> – Seguimiento de rutina (3-6 meses)"
    informe += "</div>"

    return informe, html_chart

# Gradio app wiring: one PIL image input, two HTML outputs (report + chart).
demo = gr.Interface(
    fn=analizar_lesion_combined,
    inputs=gr.Image(type="pil"),
    outputs=[gr.HTML(label="Informe"), gr.HTML(label="Gráfico ViT base")],
    title="Detector de Lesiones Cutáneas (ViT + Fast.ai + EfficientNet)",
)

# Launch the web server only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()