import torch
from transformers import ViTImageProcessor, ViTForImageClassification
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from fastai.learner import load_learner
from fastai.vision.core import PILImage
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
import io
import base64

# --- Load ViT model ---
MODEL_NAME = "ahishamm/vit-base-HAM-10000-sharpened-patch-32"
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model_vit = ViTForImageClassification.from_pretrained(MODEL_NAME)
model_vit.eval()

# --- Load Fast.ai models ---
model_malignancy = load_learner("ada_learn_malben.pkl")
model_norm2000 = load_learner("ada_learn_skin_norm2000.pkl")
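# (both .pkl learners are loaded by relative path, so they must be present in the working directory)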

# --- Load EfficientNetB3 model from Hugging Face ---
model_effnet = AutoModelForImageClassification.from_pretrained("syaha/skin_cancer_detection_model")
extractor_effnet = AutoFeatureExtractor.from_pretrained("syaha/skin_cancer_detection_model")
model_effnet.eval()

CLASSES = [
    "Queratosis actínica / Bowen", "Carcinoma células basales",
    "Lesión queratósica benigna", "Dermatofibroma",
    "Melanoma maligno", "Nevus melanocítico", "Lesión vascular"
]

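# Per-class risk level, display color, and weight used for the combined risk score below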
RISK_LEVELS = {
    0: {'level': 'Moderado', 'color': '#ffaa00', 'weight': 0.6},
    1: {'level': 'Alto', 'color': '#ff4444', 'weight': 0.8},
    2: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1},
    3: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1},
    4: {'level': 'Crítico', 'color': '#cc0000', 'weight': 1.0},
    5: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1},
    6: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1}
}

MALIGNANT_INDICES = [0, 1, 4]  # high/critical-risk classes (not referenced below)

def analizar_lesion_combined(img):
    """Run the four models on one image and return (HTML report, HTML chart)."""
    # Gradio passes a PIL image; convert it once for the Fast.ai learners
    img_fastai = PILImage.create(img)

    # --- ViT prediction ---
    try:
        inputs = feature_extractor(img, return_tensors="pt")
        with torch.no_grad():
            outputs = model_vit(**inputs)
            probs_vit = outputs.logits.softmax(dim=-1).cpu().numpy()[0]
        pred_idx_vit = int(np.argmax(probs_vit))
        pred_class_vit = CLASSES[pred_idx_vit]
        confidence_vit = probs_vit[pred_idx_vit]
    except Exception:
        pred_class_vit = "Error"
        confidence_vit = 0.0
        probs_vit = np.zeros(len(CLASSES))

    # --- Fast.ai malignancy model ---
    try:
        _, _, probs_fast_mal = model_malignancy.predict(img_fastai)
        prob_malignant = float(probs_fast_mal[1])  # index 1 is assumed to be the "malignant" class
    except Exception:
        prob_malignant = 0.0

    # --- Fast.ai lesion-type model ---
    try:
        pred_fast_type, _, _ = model_norm2000.predict(img_fastai)
    except Exception:
        pred_fast_type = "Error"

    # --- EfficientNetB3 prediction ---
    try:
        inputs_eff = extractor_effnet(images=img, return_tensors="pt")
        with torch.no_grad():
            outputs_eff = model_effnet(**inputs_eff)
            probs_eff = outputs_eff.logits.softmax(dim=-1).cpu().numpy()[0]
        pred_idx_eff = int(np.argmax(probs_eff))
        confidence_eff = probs_eff[pred_idx_eff]
        # id2label keys are integers once the config is loaded, so index with the int directly
        pred_class_eff = model_effnet.config.id2label[pred_idx_eff]
    except Exception:
        pred_class_eff = "Error"
        confidence_eff = 0.0

    # --- Bar chart of the ViT class probabilities, returned as inline HTML ---
    colors_bars = [RISK_LEVELS[i]['color'] for i in range(len(CLASSES))]
    fig, ax = plt.subplots(figsize=(8, 3))
    ax.bar(CLASSES, probs_vit * 100, color=colors_bars)
    ax.set_title("Probabilidad ViT por tipo de lesión")
    ax.set_ylabel("Probabilidad (%)")
    ax.set_xticks(np.arange(len(CLASSES)))
    ax.set_xticklabels(CLASSES, rotation=45, ha='right')
    ax.grid(axis='y', alpha=0.2)
    plt.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    img_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    html_chart = f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%"/>'

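    # --- HTML report combining the four model outputs ---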
    informe = f"""
    <div style="font-family:sans-serif; max-width:800px; margin:auto">
    <h2>🧪 Diagnóstico por 4 modelos de IA</h2>
    <table style="border-collapse: collapse; width:100%; font-size:16px">
        <tr><th style="text-align:left">🔍 Modelo</th><th>Resultado</th><th>Confianza</th></tr>
        <tr><td>🧠 ViT (transformer)</td><td><b>{pred_class_vit}</b></td><td>{confidence_vit:.1%}</td></tr>
        <tr><td>🧬 Fast.ai (clasificación)</td><td><b>{pred_fast_type}</b></td><td>N/A</td></tr>
        <tr><td>⚠️ Fast.ai (malignidad)</td><td><b>{"Maligno" if prob_malignant > 0.5 else "Benigno"}</b></td><td>{prob_malignant:.1%}</td></tr>
        <tr><td>🔬 EfficientNetB3 (HAM10000)</td><td><b>{pred_class_eff}</b></td><td>{confidence_eff:.1%}</td></tr>
    </table>
    <br>
    <b>🧪 Recomendación automática:</b><br>
    """

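    # Weighted risk score: ViT class probabilities weighted by the RISK_LEVELS weights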
    cancer_risk_score = sum(probs_vit[i] * RISK_LEVELS[i]['weight'] for i in range(7))
    if prob_malignant > 0.7 or cancer_risk_score > 0.6:
        informe += "🚨 <b>CRÍTICO</b> – Derivación urgente a oncología dermatológica"
    elif prob_malignant > 0.4 or cancer_risk_score > 0.4:
        informe += "⚠️ <b>ALTO RIESGO</b> – Consulta con dermatólogo en 7 días"
    elif cancer_risk_score > 0.2:
        informe += "📋 <b>RIESGO MODERADO</b> – Evaluación programada (2-4 semanas)"
    else:
        informe += "✅ <b>BAJO RIESGO</b> – Seguimiento de rutina (3-6 meses)"

    informe += "</div>"
    return informe, html_chart

# Gradio interface ("flagging_mode" replaces the older "allow_flagging" argument in recent Gradio releases)
demo = gr.Interface(
    fn=analizar_lesion_combined,
    inputs=gr.Image(type="pil", label="Sube una imagen de la lesión"),
    outputs=[gr.HTML(label="Informe combinado"), gr.HTML(label="Gráfico ViT")],
    title="Detector de Lesiones Cutáneas (ViT + Fast.ai + EfficientNetB3)",
    description="Comparación entre ViT transformer (HAM10000), dos modelos Fast.ai y un modelo EfficientNetB3.",
    flagging_mode="never"
)

if __name__ == "__main__":
    demo.launch()