LoloSemper commited on
Commit
34a30bc
verified
1 Parent(s): 6ae859b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -14
app.py CHANGED
@@ -21,6 +21,7 @@ if not os.path.exists(extract_dir):
21
  zip_ref.extractall(extract_dir)
22
 
23
  model_tf = tf.saved_model.load(extract_dir)
 
24
 
25
  # Funci贸n helper para inferencia TensorFlow
26
  def predict_tf(img: Image.Image):
@@ -35,17 +36,17 @@ def predict_tf(img: Image.Image):
35
  infer = model_tf.signatures["serving_default"]
36
  output = infer(img_tf)
37
  pred = list(output.values())[0].numpy()[0]
38
- probs = tf.nn.softmax(pred).numpy()
39
  return probs
40
  except Exception as e:
41
  print(f"Error en predict_tf: {e}")
42
- return np.zeros(2)
43
 
 
44
  MODEL_NAME = "ahishamm/vit-base-HAM-10000-sharpened-patch-32"
45
  feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
46
  model_vit = ViTForImageClassification.from_pretrained(MODEL_NAME)
47
  model_vit.eval()
48
-
49
  model_malignancy = load_learner("ada_learn_malben.pkl")
50
  model_norm2000 = load_learner("ada_learn_skin_norm2000.pkl")
51
 
@@ -64,6 +65,9 @@ RISK_LEVELS = {
64
  6: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1}
65
  }
66
 
 
 
 
67
  def analizar_lesion_combined(img):
68
  try:
69
  img_fastai = PILImage.create(img)
@@ -74,7 +78,7 @@ def analizar_lesion_combined(img):
74
  pred_idx_vit = int(np.argmax(probs_vit))
75
  pred_class_vit = CLASSES[pred_idx_vit]
76
  confidence_vit = probs_vit[pred_idx_vit]
77
- except:
78
  pred_class_vit = "Error"
79
  confidence_vit = 0.0
80
  probs_vit = np.zeros(len(CLASSES))
@@ -92,13 +96,12 @@ def analizar_lesion_combined(img):
92
 
93
  try:
94
  probs_tf = predict_tf(img)
95
- if len(probs_tf) == 2:
96
- benign_prob, malignant_prob = probs_tf
97
- pred_class_tf = "Maligno" if malignant_prob > benign_prob else "Benigno"
98
- confidence_tf = max(probs_tf)
99
  else:
100
- pred_class_tf = "Modelo no binario"
101
- confidence_tf = 0.0
102
  except:
103
  pred_class_tf = "Error"
104
  confidence_tf = 0.0
@@ -115,13 +118,12 @@ def analizar_lesion_combined(img):
115
  buf = io.BytesIO()
116
  plt.savefig(buf, format="png")
117
  plt.close(fig)
118
- img_bytes = buf.getvalue()
119
- img_b64 = base64.b64encode(img_bytes).decode("utf-8")
120
  html_chart = f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%"/>'
121
 
122
  informe = f"""
123
  <div style="font-family:sans-serif; max-width:800px; margin:auto">
124
- <h2>馃Β Diagn贸stico por 4 modelos de IA</h2>
125
  <table style="border-collapse: collapse; width:100%; font-size:16px">
126
  <tr><th style="text-align:left">馃攳 Modelo</th><th>Resultado</th><th>Confianza</th></tr>
127
  <tr><td>馃 ViT (transformer)</td><td><b>{pred_class_vit}</b></td><td>{confidence_vit:.1%}</td></tr>
@@ -130,7 +132,7 @@ def analizar_lesion_combined(img):
130
  <tr><td>馃敩 TensorFlow (saved_model)</td><td><b>{pred_class_tf}</b></td><td>{confidence_tf:.1%}</td></tr>
131
  </table>
132
  <br>
133
- <b>馃Ε Recomendaci贸n autom谩tica:</b><br>
134
  """
135
 
136
  cancer_risk_score = sum(probs_vit[i] * RISK_LEVELS[i]['weight'] for i in range(7))
@@ -146,6 +148,7 @@ def analizar_lesion_combined(img):
146
  informe += "</div>"
147
  return informe, html_chart
148
 
 
149
  demo = gr.Interface(
150
  fn=analizar_lesion_combined,
151
  inputs=gr.Image(type="pil", label="Sube una imagen de la lesi贸n"),
@@ -158,3 +161,4 @@ demo = gr.Interface(
158
  if __name__ == "__main__":
159
  demo.launch()
160
 
 
 
21
  zip_ref.extractall(extract_dir)
22
 
23
  model_tf = tf.saved_model.load(extract_dir)
24
+ TF_NUM_CLASSES = 7 # asumimos que son las mismas que CLASSES
25
 
26
  # Funci贸n helper para inferencia TensorFlow
27
  def predict_tf(img: Image.Image):
 
36
  infer = model_tf.signatures["serving_default"]
37
  output = infer(img_tf)
38
  pred = list(output.values())[0].numpy()[0]
39
+ probs = tf.nn.softmax(pred[:TF_NUM_CLASSES]).numpy()
40
  return probs
41
  except Exception as e:
42
  print(f"Error en predict_tf: {e}")
43
+ return np.zeros(TF_NUM_CLASSES)
44
 
45
+ # --- Cargar modelos ---
46
  MODEL_NAME = "ahishamm/vit-base-HAM-10000-sharpened-patch-32"
47
  feature_extractor = ViTImageProcessor.from_pretrained(MODEL_NAME)
48
  model_vit = ViTForImageClassification.from_pretrained(MODEL_NAME)
49
  model_vit.eval()
 
50
  model_malignancy = load_learner("ada_learn_malben.pkl")
51
  model_norm2000 = load_learner("ada_learn_skin_norm2000.pkl")
52
 
 
65
  6: {'level': 'Bajo', 'color': '#44ff44', 'weight': 0.1}
66
  }
67
 
68
+ MALIGNANT_INDICES = [0, 1, 4] # clases de riesgo alto/cr铆tico
69
+
70
+
71
  def analizar_lesion_combined(img):
72
  try:
73
  img_fastai = PILImage.create(img)
 
78
  pred_idx_vit = int(np.argmax(probs_vit))
79
  pred_class_vit = CLASSES[pred_idx_vit]
80
  confidence_vit = probs_vit[pred_idx_vit]
81
+ except Exception as e:
82
  pred_class_vit = "Error"
83
  confidence_vit = 0.0
84
  probs_vit = np.zeros(len(CLASSES))
 
96
 
97
  try:
98
  probs_tf = predict_tf(img)
99
+ pred_idx_tf = int(np.argmax(probs_tf))
100
+ confidence_tf = probs_tf[pred_idx_tf]
101
+ if pred_idx_tf < len(CLASSES):
102
+ pred_class_tf = "Maligno" if pred_idx_tf in MALIGNANT_INDICES else "Benigno"
103
  else:
104
+ pred_class_tf = f"Desconocido"
 
105
  except:
106
  pred_class_tf = "Error"
107
  confidence_tf = 0.0
 
118
  buf = io.BytesIO()
119
  plt.savefig(buf, format="png")
120
  plt.close(fig)
121
+ img_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
 
122
  html_chart = f'<img src="data:image/png;base64,{img_b64}" style="max-width:100%"/>'
123
 
124
  informe = f"""
125
  <div style="font-family:sans-serif; max-width:800px; margin:auto">
126
+ <h2>馃И Diagn贸stico por 4 modelos de IA</h2>
127
  <table style="border-collapse: collapse; width:100%; font-size:16px">
128
  <tr><th style="text-align:left">馃攳 Modelo</th><th>Resultado</th><th>Confianza</th></tr>
129
  <tr><td>馃 ViT (transformer)</td><td><b>{pred_class_vit}</b></td><td>{confidence_vit:.1%}</td></tr>
 
132
  <tr><td>馃敩 TensorFlow (saved_model)</td><td><b>{pred_class_tf}</b></td><td>{confidence_tf:.1%}</td></tr>
133
  </table>
134
  <br>
135
+ <b>馃┖ Recomendaci贸n autom谩tica:</b><br>
136
  """
137
 
138
  cancer_risk_score = sum(probs_vit[i] * RISK_LEVELS[i]['weight'] for i in range(7))
 
148
  informe += "</div>"
149
  return informe, html_chart
150
 
151
+ # Interfaz Gradio
152
  demo = gr.Interface(
153
  fn=analizar_lesion_combined,
154
  inputs=gr.Image(type="pil", label="Sube una imagen de la lesi贸n"),
 
161
  if __name__ == "__main__":
162
  demo.launch()
163
 
164
+