File size: 5,915 Bytes
fa2b4a8
2b23c99
fa2b4a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95a8fa7
 
 
 
 
 
 
fa2b4a8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cdc152a
2b23c99
 
 
 
 
 
 
 
 
cdc152a
2b23c99
fa2b4a8
2b23c99
cdc152a
2b23c99
 
 
 
 
 
 
 
fe3b436
2b23c99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe3b436
fa2b4a8
2b23c99
 
 
 
 
 
 
fa2b4a8
dcd58f1
2b23c99
 
 
 
 
 
 
 
 
dcd58f1
2b23c99
 
dcd58f1
2b23c99
 
 
 
 
dcd58f1
fa2b4a8
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# app.py

import gradio as gr
import torch
import numpy as np
import matplotlib.pyplot as plt
import base64
import io
from fastai.vision.all import *
import tensorflow as tf
from tensorflow import keras
import zipfile
import os
import traceback

# --- ISIC TensorFlow model -------------------------------------------------
# The SavedModel ships zipped; unpack it only when the "saved_model" directory
# is not already present, so later starts reuse the extracted copy.
if not os.path.exists("saved_model"):
    with zipfile.ZipFile("saved_model.zip", "r") as archive:
        archive.extractall(".")

# Wrap the SavedModel's "serving_default" signature as a callable Keras layer.
from keras.layers import TFSMLayer

try:
    model_isic = TFSMLayer("saved_model", call_endpoint="serving_default")
except Exception as load_error:
    # Log why loading failed, then re-raise: the app cannot start without it.
    print("🔴 Error al cargar el modelo ISIC con TFSMLayer:", load_error)
    raise

# --- fastai learners --------------------------------------------------------
# NOTE(review): load_learner deserializes a pickled Learner, so these .pkl
# files must be trusted artifacts shipped with the app.
model_malignancy = load_learner("modelo_malignancy.pkl")
model_norm2000 = load_learner("modelo_norm2000.pkl")

# --- ViT classifier -----------------------------------------------------------
# The image processor and the classification head come from the same
# checkpoint so their preprocessing and label space match.
from transformers import AutoImageProcessor, AutoModelForImageClassification

_VIT_CHECKPOINT = "nateraw/vit-skin-cancer"
feature_extractor = AutoImageProcessor.from_pretrained(_VIT_CHECKPOINT)
model_vit = AutoModelForImageClassification.from_pretrained(_VIT_CHECKPOINT)

# Lesion class codes (HAM10000 short names), in the output-index order this
# file assumes for the ViT and ISIC models.
CLASSES = ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']

# Per-class (bar colour, weight in the ViT-based risk score).
_CLASS_STYLE = {
    'akiec': ("#FF6F61", 0.9),
    'bcc': ("#FF8C42", 0.7),
    'bkl': ("#FFD166", 0.3),
    'df': ("#06D6A0", 0.1),
    'mel': ("#EF476F", 1.0),
    'nv': ("#118AB2", 0.2),
    'vasc': ("#073B4C", 0.4),
}

# Output index -> {"label", "color", "weight"}; built from CLASSES so the
# labels here can never disagree with the class list above.
RISK_LEVELS = {
    index: {
        "label": code,
        "color": _CLASS_STYLE[code][0],
        "weight": _CLASS_STYLE[code][1],
    }
    for index, code in enumerate(CLASSES)
}

# Preprocessing for the ISIC TensorFlow model.
def preprocess_image_isic(pil_image, size=(224, 224)):
    """Turn a PIL image into a one-image batch for the ISIC TensorFlow model.

    The image is first converted to 3-channel RGB so RGBA, grayscale or
    palette uploads still yield an (H, W, 3) array. It is then resized to
    ``size`` (PIL order: width, height) and scaled from [0, 255] to [0, 1].

    Args:
        pil_image: a PIL image (any mode).
        size: target ``(width, height)``; defaults to the 224x224 input used
            so far.

    Returns:
        float32 array of shape ``(1, size[1], size[0], 3)``.
    """
    rgb = pil_image.convert("RGB").resize(size)
    # NOTE(review): float32 assumes the "serving_default" signature declares
    # float32 inputs (the usual Keras export) - confirm against the SavedModel.
    array = np.asarray(rgb, dtype=np.float32) / 255.0
    return np.expand_dims(array, axis=0)

# Función de análisis (como ya la tienes)
def analizar_lesion_combined(img):
    try:
        img_fastai = PILImage.create(img)
        inputs = feature_extractor(img, return_tensors="pt")
        with torch.no_grad():
            outputs = model_vit(**inputs)
            probs_vit = outputs.logits.softmax(dim=-1).cpu().numpy()[0]
        pred_idx_vit = int(np.argmax(probs_vit))
        pred_class_vit = CLASSES[pred_idx_vit]
        confidence_vit = probs_vit[pred_idx_vit]

        pred_fast_malignant, _, probs_fast_mal = model_malignancy.predict(img_fastai)
        prob_malignant = float(probs_fast_mal[1])
        pred_fast_type, _, probs_fast_type = model_norm2000.predict(img_fastai)

        x_isic = preprocess_image_isic(img)
        preds_isic_dict = model_isic(x_isic)
        print("🔍 Claves de salida de model_isic:", preds_isic_dict.keys())
        key = list(preds_isic_dict.keys())[0]
        preds_isic = preds_isic_dict[key].numpy()[0]
        pred_idx_isic = int(np.argmax(preds_isic))
        pred_class_isic = CLASSES[pred_idx_isic]
        confidence_isic = preds_isic[pred_idx_isic]

        colors_bars = [RISK_LEVELS[i]['color'] for i in range(7)]
        fig, ax = plt.subplots(figsize=(8, 3))
        ax.bar(CLASSES, probs_vit*100, color=colors_bars)
        ax.set_title("Probabilidad ViT por tipo de lesión")
        ax.set_ylabel("Probabilidad (%)")
        ax.set_xticks(np.arange(len(CLASSES)))
        ax.set_xticklabels(CLASSES, rotation=45, ha='right')
        ax.grid(axis='y', alpha=0.2)
        plt.tight_layout()
        buf = io.BytesIO()
        plt.savefig(buf, format="png")
        plt.close(fig)
        img_bytes = buf.getvalue()
        img_b64 = base64.b64encode(img_bytes).decode("utf-8")
        html_chart = f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%"/>'

        informe = f"""<div style="font-family:sans-serif; max-width:800px; margin:auto">
        <h2>🧪 Diagnóstico por 4 modelos de IA</h2>
        <table style="border-collapse: collapse; width:100%; font-size:16px">
            <tr><th style="text-align:left">🔍 Modelo</th><th>Resultado</th><th>Confianza</th></tr>
            <tr><td>🧠 ViT (transformer)</td><td><b>{pred_class_vit}</b></td><td>{confidence_vit:.1%}</td></tr>
            <tr><td>🧬 Fast.ai (clasificación)</td><td><b>{pred_fast_type}</b></td><td>N/A</td></tr>
            <tr><td>⚠️ Fast.ai (malignidad)</td><td><b>{"Maligno" if prob_malignant > 0.5 else "Benigno"}</b></td><td>{prob_malignant:.1%}</td></tr>
            <tr><td>🔬 ISIC TensorFlow</td><td><b>{pred_class_isic}</b></td><td>{confidence_isic:.1%}</td></tr>
        </table><br><b>🩺 Recomendación automática:</b><br>"""

        cancer_risk_score = sum(probs_vit[i] * RISK_LEVELS[i]['weight'] for i in range(7))
        if prob_malignant > 0.7 or cancer_risk_score > 0.6:
            informe += "🚨 <b>CRÍTICO</b> – Derivación urgente a oncología dermatológica"
        elif prob_malignant > 0.4 or cancer_risk_score > 0.4:
            informe += "⚠️ <b>ALTO RIESGO</b> – Consulta con dermatólogo en 7 días"
        elif cancer_risk_score > 0.2:
            informe += "📋 <b>RIESGO MODERADO</b> – Evaluación programada (2-4 semanas)"
        else:
            informe += "✅ <b>BAJO RIESGO</b> – Seguimiento de rutina (3-6 meses)"

        informe += "</div>"
        return informe, html_chart

    except Exception as e:
        print("🔴 ERROR en analizar_lesion_combined:")
        print(str(e))
        traceback.print_exc()
        return f"<b>Error interno:</b> {str(e)}", ""

# --- Gradio interface -----------------------------------------------------------
# One image upload in; two HTML panels out (combined report, ViT chart).
_lesion_input = gr.Image(type="pil", label="Sube una imagen de la lesión")
_report_panels = [
    gr.HTML(label="Informe combinado"),
    gr.HTML(label="Gráfico ViT"),
]

demo = gr.Interface(
    fn=analizar_lesion_combined,
    inputs=_lesion_input,
    outputs=_report_panels,
    title="Detector de Lesiones Cutáneas (ViT + Fast.ai + ISIC TensorFlow)",
    description="Comparación entre ViT transformer (HAM10000), dos modelos Fast.ai y el modelo ISIC TensorFlow.",
    flagging_mode="never",
)


def _main():
    """Start the Gradio server for the interface defined above."""
    demo.launch()


# --- Launch ------------------------------------------------------------------
if __name__ == "__main__":
    _main()