Spaces:
Sleeping
Sleeping
File size: 6,389 Bytes
fa2b4a8 ba932fd fa2b4a8 ba932fd fa2b4a8 ba932fd fa2b4a8 ba932fd fa2b4a8 ba932fd 34a30bc ba932fd 8cfacf4 b224d9f 8cfacf4 b224d9f 8cfacf4 b224d9f ba932fd 8cfacf4 34a30bc 8cfacf4 34a30bc ba932fd 34a30bc ba932fd 16c2fe3 eb8c75c ba932fd fa2b4a8 ba932fd fa2b4a8 34a30bc cdc152a 8cfacf4 34a30bc 8cfacf4 ba932fd 8cfacf4 b224d9f 6ae859b 8cfacf4 ba932fd 8cfacf4 6ae859b 8cfacf4 3f789e4 8cfacf4 34a30bc 8cfacf4 34a30bc 6ae859b 8cfacf4 ba932fd b224d9f ba932fd 34a30bc ba932fd 34a30bc ba932fd c57cd86 ba932fd 34a30bc fa2b4a8 ba932fd fa2b4a8 8658ccb 34a30bc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
import torch
from transformers import ViTImageProcessor, ViTForImageClassification
from fastai.learner import load_learner
from fastai.vision.core import PILImage
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
import io
import base64
import os
import zipfile
import tensorflow as tf
# --- Extract and load the TensorFlow SavedModel shipped as a zip archive ---
zip_path = "saved_model.zip"
extract_dir = "saved_model"

# Unpack whenever the directory does not yet hold a SavedModel. Testing only
# for the directory's existence would treat a directory left behind by an
# interrupted extraction as complete and never unpack the archive again.
if not tf.saved_model.contains_saved_model(extract_dir):
    os.makedirs(extract_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)

model_tf = tf.saved_model.load(extract_dir)

# Number of leading model outputs that predict_tf keeps.
# Assumed to match the seven entries of CLASSES — TODO confirm against the model.
TF_NUM_CLASSES = 7
# Helper for TensorFlow inference
def predict_tf(img: Image.Image):
    """Run the TensorFlow SavedModel on one image and return class probabilities.

    The image is resized to 224x224, reduced to three RGB channels, scaled to
    [0, 1] and passed as a batch of one to the ``serving_default`` signature.
    The first ``TF_NUM_CLASSES`` values of the first output are passed through
    softmax.

    Args:
        img: PIL image in any mode (RGB, RGBA, grayscale, palette, ...).

    Returns:
        NumPy array of ``TF_NUM_CLASSES`` probabilities, or an all-zero array
        of that length when inference fails (the error is printed, not raised).
    """
    try:
        # Resize first, then convert: for RGB/RGBA input this yields the same
        # pixels as slicing off the alpha channel after the resize, while
        # grayscale and palette images also end up with three channels
        # instead of reaching the model with the wrong rank.
        img_rgb = img.resize((224, 224)).convert("RGB")
        img_np = np.asarray(img_rgb) / 255.0
        img_tf = tf.convert_to_tensor(img_np, dtype=tf.float32)
        img_tf = tf.expand_dims(img_tf, axis=0)  # add the batch dimension
        infer = model_tf.signatures["serving_default"]
        output = infer(img_tf)
        # Take the first (presumably only) output tensor, first batch item.
        pred = list(output.values())[0].numpy()[0]
        # NOTE(review): assumes the signature emits logits; if it already
        # emits probabilities this second softmax flattens them — confirm.
        probs = tf.nn.softmax(pred[:TF_NUM_CLASSES]).numpy()
        return probs
    except Exception as e:
        # Best-effort: callers receive an all-zero vector instead of an error.
        print(f"Error en predict_tf: {e}")
        return np.zeros(TF_NUM_CLASSES)
# --- Load models ---
MODEL_NAME = "ahishamm/vit-base-HAM-10000-sharpened-patch-32"
feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model_vit = ViTForImageClassification.from_pretrained(MODEL_NAME)
model_vit.eval()  # inference mode

# NOTE(review): fastai learner exports are pickle-based — only load .pkl
# files that come from a trusted source.
model_malignancy = load_learner("ada_learn_malben.pkl")
model_norm2000 = load_learner("ada_learn_skin_norm2000.pkl")

# Display labels indexed by class id.
# Assumed to follow the ViT model's output order — TODO confirm against
# model_vit.config.id2label.
CLASSES = [
    "Queratosis actínica / Bowen",
    "Carcinoma células basales",
    "Lesión queratósica benigna",
    "Dermatofibroma",
    "Melanoma maligno",
    "Nevus melanocítico",
    "Lesión vascular",
]

# Per-class risk metadata, keyed by the class index used in CLASSES:
# display level, bar colour for the chart and weight in the risk score.
RISK_LEVELS = {
    0: dict(level='Moderado', color='#ffaa00', weight=0.6),
    1: dict(level='Alto', color='#ff4444', weight=0.8),
    2: dict(level='Bajo', color='#44ff44', weight=0.1),
    3: dict(level='Bajo', color='#44ff44', weight=0.1),
    4: dict(level='Crítico', color='#cc0000', weight=1.0),
    5: dict(level='Bajo', color='#44ff44', weight=0.1),
    6: dict(level='Bajo', color='#44ff44', weight=0.1),
}

# Class indices reported as "Maligno". Index 0 is rated 'Moderado' in
# RISK_LEVELS; indices 1 and 4 are rated 'Alto' and 'Crítico'.
MALIGNANT_INDICES = [0, 1, 4]
def _vit_chart_html(probs_vit):
    """Render the ViT class probabilities as a bar chart inside an <img> tag."""
    colors_bars = [RISK_LEVELS[i]['color'] for i in range(len(CLASSES))]
    fig, ax = plt.subplots(figsize=(8, 3))
    try:
        ax.bar(CLASSES, probs_vit * 100, color=colors_bars)
        ax.set_title("Probabilidad ViT por tipo de lesión")
        ax.set_ylabel("Probabilidad (%)")
        ax.set_xticks(np.arange(len(CLASSES)))
        ax.set_xticklabels(CLASSES, rotation=45, ha='right')
        ax.grid(axis='y', alpha=0.2)
        # Operate on this figure explicitly rather than on pyplot's global
        # "current figure".
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
    img_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%"/>'


def _risk_recommendation(prob_malignant, cancer_risk_score):
    """Map the malignancy probability and weighted ViT score to a recommendation."""
    if prob_malignant > 0.7 or cancer_risk_score > 0.6:
        return "🚨 <b>CRÍTICO</b> – Derivación urgente a oncología dermatológica"
    if prob_malignant > 0.4 or cancer_risk_score > 0.4:
        return "⚠️ <b>ALTO RIESGO</b> – Consulta con dermatólogo en 7 días"
    if cancer_risk_score > 0.2:
        return "📋 <b>RIESGO MODERADO</b> – Evaluación programada (2-4 semanas)"
    return "✅ <b>BAJO RIESGO</b> – Seguimiento de rutina (3-6 meses)"


def analizar_lesion_combined(img):
    """Analyse a skin-lesion image with all four models and build the report.

    Each model runs in its own try block, so one failing model shows up as
    "Error" in its table row instead of aborting the whole analysis.

    Args:
        img: PIL image supplied by the Gradio input, or None when the form
            was submitted without an image.

    Returns:
        Tuple ``(informe, html_chart)``: the HTML report (results table plus
        automatic recommendation) and an HTML ``<img>`` tag with the ViT
        probability bar chart. With no image, an error message and an empty
        string are returned.
    """
    if img is None:
        return "<b>Error:</b> no se ha proporcionado ninguna imagen.", ""

    n_classes = len(CLASSES)

    # --- ViT classifier ---
    vit_ok = False
    pred_class_vit = "Error"
    confidence_vit = 0.0
    probs_vit = np.zeros(n_classes)
    try:
        inputs = feature_extractor(img, return_tensors="pt")
        with torch.no_grad():
            outputs = model_vit(**inputs)
        vit_probs = outputs.logits.softmax(dim=-1).cpu().numpy()[0]
        if len(vit_probs) != n_classes:
            raise ValueError(
                f"ViT devolvió {len(vit_probs)} clases, se esperaban {n_classes}"
            )
        pred_idx_vit = int(np.argmax(vit_probs))
        # Publish the results only once every step succeeded, so a late
        # failure cannot leave a half-updated state.
        pred_class_vit = CLASSES[pred_idx_vit]
        confidence_vit = vit_probs[pred_idx_vit]
        probs_vit = vit_probs
        vit_ok = True
    except Exception as e:
        print(f"Error en ViT: {e}")

    # --- Fast.ai models ---
    # Converted separately from the ViT step so a conversion failure here
    # does not also discard the ViT result.
    img_fastai = None
    try:
        img_fastai = PILImage.create(img)
    except Exception as e:
        print(f"Error al preparar la imagen para Fast.ai: {e}")

    malignancy_ok = False
    prob_malignant = 0.0
    pred_fast_type = "Error"
    if img_fastai is not None:
        try:
            _, _, probs_fast_mal = model_malignancy.predict(img_fastai)
            # Assumes index 1 of the learner's vocab is the malignant class
            # — TODO confirm against model_malignancy.dls.vocab.
            prob_malignant = float(probs_fast_mal[1])
            malignancy_ok = True
        except Exception as e:
            print(f"Error en Fast.ai (malignidad): {e}")
        try:
            pred_fast_type, _, _ = model_norm2000.predict(img_fastai)
        except Exception as e:
            pred_fast_type = "Error"
            print(f"Error en Fast.ai (clasificación): {e}")

    if not malignancy_ok:
        malignancy_label = "Error"
    else:
        malignancy_label = "Maligno" if prob_malignant > 0.5 else "Benigno"

    # --- TensorFlow model ---
    pred_class_tf = "Error"
    confidence_tf = 0.0
    try:
        probs_tf = predict_tf(img)
        # predict_tf signals failure with an all-zero vector; a softmax output
        # is never all zeros. Without this check argmax would pick index 0 and
        # the failure would be reported as "Maligno".
        if np.any(probs_tf):
            pred_idx_tf = int(np.argmax(probs_tf))
            confidence_tf = probs_tf[pred_idx_tf]
            if pred_idx_tf < n_classes:
                pred_class_tf = "Maligno" if pred_idx_tf in MALIGNANT_INDICES else "Benigno"
            else:
                pred_class_tf = "Desconocido"
    except Exception as e:
        print(f"Error en TensorFlow: {e}")

    html_chart = _vit_chart_html(probs_vit)

    informe = f"""
<div style="font-family:sans-serif; max-width:800px; margin:auto">
<h2>🧪 Diagnóstico por 4 modelos de IA</h2>
<table style="border-collapse: collapse; width:100%; font-size:16px">
<tr><th style="text-align:left">🔍 Modelo</th><th>Resultado</th><th>Confianza</th></tr>
<tr><td>🧠 ViT (transformer)</td><td><b>{pred_class_vit}</b></td><td>{confidence_vit:.1%}</td></tr>
<tr><td>🧬 Fast.ai (clasificación)</td><td><b>{pred_fast_type}</b></td><td>N/A</td></tr>
<tr><td>⚠️ Fast.ai (malignidad)</td><td><b>{malignancy_label}</b></td><td>{prob_malignant:.1%}</td></tr>
<tr><td>🔬 TensorFlow (saved_model)</td><td><b>{pred_class_tf}</b></td><td>{confidence_tf:.1%}</td></tr>
</table>
<br>
<b>🧪 Recomendación automática:</b><br>
"""
    # Risk score: ViT probabilities weighted by each class's risk weight.
    cancer_risk_score = sum(
        probs_vit[i] * RISK_LEVELS[i]['weight'] for i in range(n_classes)
    )
    if not (vit_ok or malignancy_ok):
        # Both inputs of the recommendation are missing; zeros would
        # otherwise be read as "low risk".
        informe += (
            "❓ <b>NO EVALUABLE</b> – Los modelos no pudieron analizar la imagen; "
            "repita la captura o consulte a un dermatólogo"
        )
    else:
        informe += _risk_recommendation(prob_malignant, cancer_risk_score)
    informe += "</div>"
    return informe, html_chart
# Gradio interface: one image input, two HTML outputs (report and chart).
_lesion_input = gr.Image(type="pil", label="Sube una imagen de la lesión")
_report_output = gr.HTML(label="Informe combinado")
_chart_output = gr.HTML(label="Gráfico ViT")

demo = gr.Interface(
    fn=analizar_lesion_combined,
    inputs=_lesion_input,
    outputs=[_report_output, _chart_output],
    title="Detector de Lesiones Cutáneas (ViT + Fast.ai + TensorFlow)",
    description="Comparación entre ViT transformer (HAM10000), dos modelos Fast.ai y un modelo TensorFlow.",
    flagging_mode="never",
)

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|