fecia committed
Commit ad75092 · verified · 1 Parent(s): 8a763c7

Update app.py

Files changed (1): app.py (+255 −352)

app.py CHANGED
@@ -7,426 +7,329 @@ import tensorflow_text as tf_text
  import tensorflow_hub as tf_hub
  import numpy as np
  from PIL import Image
- from huggingface_hub import snapshot_download, HfFolder
  from sklearn.metrics.pairwise import cosine_similarity
  import traceback
  import time
- import pandas as pd  # To format the output as a table

  # --- Configuration ---
  MODEL_REPO_ID = "google/cxr-foundation"
- MODEL_DOWNLOAD_DIR = './hf_cxr_foundation_space'  # Directory inside the Space container
- # Thresholds
- SIMILARITY_DIFFERENCE_THRESHOLD = 0.0
- POSITIVE_SIMILARITY_THRESHOLD = 0.0
  print(f"Using thresholds: Comp Δ={SIMILARITY_DIFFERENCE_THRESHOLD}, Simp τ={POSITIVE_SIMILARITY_THRESHOLD}")

- # --- Prompts ---
  criteria_list_positive = [
-     "optimal centering", "optimal inspiration", "optimal penetration",
-     "complete field of view", "scapulae retracted", "sharp image", "artifact free"
  ]
  criteria_list_negative = [
-     "poorly centered", "poor inspiration", "non-diagnostic exposure",
-     "cropped image", "scapulae overlying lungs", "blurred image", "obscuring artifact"
  ]

- # --- Helper functions (integrated or adapted) ---
- # @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])  # may help performance
- def preprocess_text(text):
-     """Internal wrapper around the BERT preprocessor."""
-     return bert_preprocessor_global(text)
-
  def bert_tokenize(text, preprocessor):
-     """Tokenizes text with the globally loaded BERT preprocessor."""
      if preprocessor is None:
-         raise ValueError("The BERT preprocessor is not loaded.")
-     if not isinstance(text, str): text = str(text)
-
-     # Run the preprocessor
-     out = preprocessor(tf.constant([text.lower()]))
-
-     # Extract and post-process IDs and masks
      ids = out['input_word_ids'].numpy().astype(np.int32)
      masks = out['input_mask'].numpy().astype(np.float32)
      paddings = 1.0 - masks
-
-     # Replace the [SEP] token (102) with 0 and mark it as padding
      end_token_idx = (ids == 102)
      ids[end_token_idx] = 0
      paddings[end_token_idx] = 1.0
-
-     # Ensure the dimensions (B, T, S) -> (1, 1, 128);
-     # the preprocessor may return (1, 128), we need (1, 1, 128)
-     if ids.ndim == 2: ids = np.expand_dims(ids, axis=1)
-     if paddings.ndim == 2: paddings = np.expand_dims(paddings, axis=1)
-
-     # Verify the final shapes
-     expected_shape = (1, 1, 128)
-     if ids.shape != expected_shape:
-         # Try to reshape if needed (can happen with some versions)
-         if ids.shape == (1, 128): ids = np.expand_dims(ids, axis=1)
-         else: raise ValueError(f"Wrong shape for ids: {ids.shape}, expected {expected_shape}")
-     if paddings.shape != expected_shape:
-         if paddings.shape == (1, 128): paddings = np.expand_dims(paddings, axis=1)
-         else: raise ValueError(f"Wrong shape for paddings: {paddings.shape}, expected {expected_shape}")
-
      return ids, paddings

  def png_to_tfexample(image_array: np.ndarray) -> tf.train.Example:
-     """Creates a tf.train.Example from a grayscale NumPy array."""
      if image_array.ndim == 3 and image_array.shape[2] == 1:
-         image_array = np.squeeze(image_array, axis=2)  # ensure 2-D
      elif image_array.ndim != 2:
-         raise ValueError(f'Array must be 2-D (grayscale). Current dimensions: {image_array.ndim}')
-
      image = image_array.astype(np.float32)
-     min_val = image.min()
-     max_val = image.max()
-
-     # Avoid division by zero if the image is constant
      if max_val <= min_val:
-         # If constant, treat it as uint8 when the original range allows it,
-         # otherwise zero it out as 16-bit.
          if image_array.dtype == np.uint8 or (min_val >= 0 and max_val <= 255):
-             pixel_array = image.astype(np.uint8)
-             bitdepth = 8
-         else:  # constant float or outside the uint8 range
-             pixel_array = np.zeros_like(image, dtype=np.uint16)
-             bitdepth = 16
      else:
-         image -= min_val  # shift the minimum to zero
          current_max = max_val - min_val
-         # Scale to 16 bit for extra precision if it was not uint8 originally
          if image_array.dtype != np.uint8:
              image *= 65535 / current_max
-             pixel_array = image.astype(np.uint16)
-             bitdepth = 16
          else:
-             # If it was uint8, keep the range and type;
-             # subtracting the minimum left it in [0, current_max],
-             # so rescale to 255 if needed.
              image *= 255 / current_max
-             pixel_array = image.astype(np.uint8)
-             bitdepth = 8
-
-     # Encode as PNG
      output = io.BytesIO()
-     png.Writer(
-         width=pixel_array.shape[1],
-         height=pixel_array.shape[0],
-         greyscale=True,
-         bitdepth=bitdepth
-     ).write(output, pixel_array.tolist())
-     png_bytes = output.getvalue()
-
-     # Build the tf.train.Example
      example = tf.train.Example()
      features = example.features.feature
-     features['image/encoded'].bytes_list.value.append(png_bytes)
      features['image/format'].bytes_list.value.append(b'png')
      return example

  def generate_image_embedding(img_np, elixrc_infer, qformer_infer):
-     """Generates the final image embedding."""
      if elixrc_infer is None or qformer_infer is None:
          raise ValueError("ELIXR-C or QFormer models are not loaded.")
-
      try:
-         # 1. ELIXR-C
-         serialized_img_tf_example = png_to_tfexample(img_np).SerializeToString()
-         elixrc_output = elixrc_infer(input_example=tf.constant([serialized_img_tf_example]))
-         elixrc_embedding = elixrc_output['feature_maps_0'].numpy()
-         print(f"  ELIXR-C embedding shape: {elixrc_embedding.shape}")
-
-         # 2. QFormer (image)
-         qformer_input_img = {
-             'image_feature': elixrc_embedding.tolist(),
-             'ids': np.zeros((1, 1, 128), dtype=np.int32).tolist(),        # empty text
-             'paddings': np.ones((1, 1, 128), dtype=np.float32).tolist(),  # all padding
          }
-         qformer_output_img = qformer_infer(**qformer_input_img)
-         image_embedding = qformer_output_img['all_contrastive_img_emb'].numpy()
-
-         # Adjust the dimensions if necessary
-         if image_embedding.ndim > 2:
-             print(f"  Adjusting image embedding dimensions (original: {image_embedding.shape})")
-             image_embedding = np.mean(
-                 image_embedding,
-                 axis=tuple(range(1, image_embedding.ndim - 1))
-             )
-             if image_embedding.ndim == 1:
-                 image_embedding = np.expand_dims(image_embedding, axis=0)
-         elif image_embedding.ndim == 1:
-             image_embedding = np.expand_dims(image_embedding, axis=0)  # ensure 2-D
-
-         print(f"  Final image embedding shape: {image_embedding.shape}")
-         if image_embedding.ndim != 2:
-             raise ValueError(f"Final image embedding is not 2-D: {image_embedding.shape}")
-         return image_embedding
-
      except Exception as e:
-         print(f"Error generating the image embedding: {e}")
          traceback.print_exc()
-         raise  # re-raise so Gradio handles it
-
- def calculate_similarities_and_classify(image_embedding, bert_preprocessor, qformer_infer):
-     """Computes similarities and classifies each criterion."""
-     if image_embedding is None: raise ValueError("Image embedding is None.")
-     if bert_preprocessor is None: raise ValueError("BERT preprocessor is None.")
-     if qformer_infer is None: raise ValueError("QFormer is None.")
-
-     detailed_results = {}
-     print("\n--- Computing similarities and classifying ---")
-
-     for i in range(len(criteria_list_positive)):
-         positive_text = criteria_list_positive[i]
-         negative_text = criteria_list_negative[i]
-         criterion_name = positive_text  # use the positive prompt as the key
-
-         print(f"Processing criterion: \"{criterion_name}\"")
-         similarity_positive, similarity_negative, difference = None, None, None
-         classification_comp, classification_simp = "ERROR", "ERROR"
-
          try:
-             # 1. Positive text embedding
-             tokens_pos, paddings_pos = bert_tokenize(positive_text, bert_preprocessor)
-             qformer_input_text_pos = {
-                 'image_feature': np.zeros([1, 8, 8, 1376], dtype=np.float32).tolist(),  # dummy
-                 'ids': tokens_pos.tolist(), 'paddings': paddings_pos.tolist(),
-             }
-             text_embedding_pos = qformer_infer(**qformer_input_text_pos)['contrastive_txt_emb'].numpy()
-             if text_embedding_pos.ndim == 1: text_embedding_pos = np.expand_dims(text_embedding_pos, axis=0)
-
-             # 2. Negative text embedding
-             tokens_neg, paddings_neg = bert_tokenize(negative_text, bert_preprocessor)
-             qformer_input_text_neg = {
-                 'image_feature': np.zeros([1, 8, 8, 1376], dtype=np.float32).tolist(),  # dummy
-                 'ids': tokens_neg.tolist(), 'paddings': paddings_neg.tolist(),
-             }
-             text_embedding_neg = qformer_infer(**qformer_input_text_neg)['contrastive_txt_emb'].numpy()
-             if text_embedding_neg.ndim == 1: text_embedding_neg = np.expand_dims(text_embedding_neg, axis=0)
-
-             # Check dimension compatibility for the similarity
-             if image_embedding.shape[1] != text_embedding_pos.shape[1]:
-                 raise ValueError(f"Incompatible dimensions: image ({image_embedding.shape[1]}) vs positive text ({text_embedding_pos.shape[1]})")
-             if image_embedding.shape[1] != text_embedding_neg.shape[1]:
-                 raise ValueError(f"Incompatible dimensions: image ({image_embedding.shape[1]}) vs negative text ({text_embedding_neg.shape[1]})")
-
-             # 3. Compute the similarities
-             similarity_positive = cosine_similarity(image_embedding, text_embedding_pos)[0][0]
-             similarity_negative = cosine_similarity(image_embedding, text_embedding_neg)[0][0]
-             print(f"  Sim (+)={similarity_positive:.4f}, Sim (-)={similarity_negative:.4f}")
-
-             # 4. Classify
-             difference = similarity_positive - similarity_negative
-             classification_comp = "PASS" if difference > SIMILARITY_DIFFERENCE_THRESHOLD else "FAIL"
-             classification_simp = "PASS" if similarity_positive > POSITIVE_SIMILARITY_THRESHOLD else "FAIL"
-             print(f"  Diff={difference:.4f} -> Comp: {classification_comp}, Simp: {classification_simp}")
-
          except Exception as e:
-             print(f"  ERROR processing criterion '{criterion_name}': {e}")
-             traceback.print_exc()
-             # leave the classifications as "ERROR"
-
-         # Store the results
-         detailed_results[criterion_name] = {
-             'positive_prompt': positive_text,
-             'negative_prompt': negative_text,
-             'similarity_positive': float(similarity_positive) if similarity_positive is not None else None,
-             'similarity_negative': float(similarity_negative) if similarity_negative is not None else None,
-             'difference': float(difference) if difference is not None else None,
-             'classification_comparative': classification_comp,
-             'classification_simplified': classification_simp
          }
-     return detailed_results

  # --- Global model loading ---
- # Runs ONCE when the Gradio app / Space starts
- print("--- Starting global model load ---")
  start_time = time.time()
  models_loaded = False
- bert_preprocessor_global = None
- elixrc_infer_global = None
- qformer_infer_global = None
-
  try:
-     # Check HF authentication (useful for private models, not needed here)
-     # if HfFolder.get_token() is None:
-     #     print("Warning: no Hugging Face token found.")
-     # else:
-     #     print("Hugging Face token found.")
-
-     # Create the directory if it does not exist
      os.makedirs(MODEL_DOWNLOAD_DIR, exist_ok=True)
-     print(f"Downloading/verifying models in: {MODEL_DOWNLOAD_DIR}")
      snapshot_download(repo_id=MODEL_REPO_ID, local_dir=MODEL_DOWNLOAD_DIR,
-                       allow_patterns=['elixr-c-v2-pooled/*', 'pax-elixr-b-text/*'],
-                       local_dir_use_symlinks=False)  # avoid symlinks
-     print("Models downloaded/verified.")
-
-     # Load the BERT preprocessor from TF Hub
-     print("Loading BERT preprocessor...")
-     # Using an explicit handle can be more robust in some environments
-     bert_preprocess_handle = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
-     bert_preprocessor_global = tf_hub.KerasLayer(bert_preprocess_handle)
-     print("BERT preprocessor loaded.")
-
-     # Load ELIXR-C
-     print("Loading ELIXR-C...")
-     elixrc_model_path = os.path.join(MODEL_DOWNLOAD_DIR, 'elixr-c-v2-pooled')
-     elixrc_model = tf.saved_model.load(elixrc_model_path)
-     elixrc_infer_global = elixrc_model.signatures['serving_default']
-     print("ELIXR-C model loaded.")
-
-     # Load QFormer (ELIXR-B Text)
-     print("Loading QFormer (ELIXR-B Text)...")
-     qformer_model_path = os.path.join(MODEL_DOWNLOAD_DIR, 'pax-elixr-b-text')
-     qformer_model = tf.saved_model.load(qformer_model_path)
-     qformer_infer_global = qformer_model.signatures['serving_default']
-     print("QFormer model loaded.")
-
      models_loaded = True
-     end_time = time.time()
-     print(f"--- Models loaded globally in {end_time - start_time:.2f} seconds ---")
-
  except Exception as e:
-     models_loaded = False
-     print("--- CRITICAL ERROR DURING GLOBAL MODEL LOADING ---")
-     print(e)
      traceback.print_exc()
-     # Gradio will still start, but the analysis function will fail.

- # --- Main processing function for Gradio ---
- def assess_quality(image_pil):
-     """Function Gradio calls with the input image."""
      if not models_loaded:
-         raise gr.Error("Error: the models could not be loaded. The app cannot process images.")
      if image_pil is None:
-         # Return empty results when there is no image
-         return pd.DataFrame(), "N/A", None  # empty DataFrame, empty label, empty JSON
-
-     print("\n--- Starting assessment for a new image ---")
-     start_process_time = time.time()
-
-     try:
-         # 1. Convert the PIL image to a grayscale NumPy array
-         # (with type="pil", Gradio already delivers a PIL object)
-         img_np = np.array(image_pil.convert('L'))
-         print(f"Image converted to NumPy. Shape: {img_np.shape}, dtype: {img_np.dtype}")
-
-         # 2. Generate the image embedding
-         print("Generating image embedding...")
-         image_embedding = generate_image_embedding(img_np, elixrc_infer_global, qformer_infer_global)
-         print("Image embedding generated.")
-
-         # 3. Compute similarities and classify
-         print("Computing similarities and classifying criteria...")
-         detailed_results = calculate_similarities_and_classify(image_embedding, bert_preprocessor_global, qformer_infer_global)
-         print("Classification complete.")
-
-         # 4. Format the results for Gradio
-         output_data = []
-         passed_count = 0
-         total_count = 0
-         for criterion, details in detailed_results.items():
-             total_count += 1
-             sim_pos_str = f"{details['similarity_positive']:.4f}" if details['similarity_positive'] is not None else "N/A"
-             sim_neg_str = f"{details['similarity_negative']:.4f}" if details['similarity_negative'] is not None else "N/A"
-             diff_str = f"{details['difference']:.4f}" if details['difference'] is not None else "N/A"
-             assessment_comp = details['classification_comparative']
-             assessment_simp = details['classification_simplified']
-             output_data.append([
-                 criterion,
-                 sim_pos_str,
-                 sim_neg_str,
-                 diff_str,
-                 assessment_comp,
-                 assessment_simp
-             ])
-             if assessment_comp == "PASS":
-                 passed_count += 1
-
-         # Build the DataFrame
-         df_results = pd.DataFrame(output_data, columns=[
-             "Criterion", "Sim (+)", "Sim (-)", "Difference", "Assessment (Comp)", "Assessment (Simp)"
-         ])
-
-         # Compute the overall quality label
-         overall_quality = "Error"
-         if total_count > 0:
-             pass_rate = passed_count / total_count
-             if pass_rate >= 0.85: overall_quality = "Excellent"
-             elif pass_rate >= 0.70: overall_quality = "Good"
-             elif pass_rate >= 0.50: overall_quality = "Fair"
-             else: overall_quality = "Poor"
-         quality_label = f"{overall_quality} ({passed_count}/{total_count} criteria passed)"
-
-         end_process_time = time.time()
-         print(f"--- Assessment completed in {end_process_time - start_process_time:.2f} seconds ---")
-
-         # Return the DataFrame, the label and the JSON
-         return df_results, quality_label, detailed_results
-
-     except Exception as e:
-         print(f"Error while processing the image in Gradio: {e}")
-         traceback.print_exc()
-         # Raise a gr.Error so it shows up in the Gradio UI
-         raise gr.Error(f"Error processing the image: {str(e)}")

- # --- Define the Gradio interface ---
- css = """
- #quality-label label {
-     font-size: 1.1em;
-     font-weight: bold;
- }
- """
- with gr.Blocks(css=css) as demo:
-     gr.Markdown(
-         """
-         # Chest X-ray Technical Quality Assessment
-         Upload a chest X-ray image (PNG, JPG, etc.) to evaluate its technical quality based on 7 standard criteria
-         using the ELIXR model family (comparative strategy: positive vs negative prompts).
-         **Note:** Model loading on startup might take a minute. Processing an image can take 10-30 seconds depending on server load.
-         """
      )
      with gr.Row():
-         with gr.Column(scale=1):
-             input_image = gr.Image(type="pil", label="Upload Chest X-ray")
-             submit_button = gr.Button("Assess Quality", variant="primary")
-             # Add examples if you have sample images;
-             # make sure the 'examples' folder exists and contains them.
-             # gr.Examples(
-             #     examples=[os.path.join("examples", "sample_cxr.png")],  # list of example paths
-             #     inputs=input_image
-             # )
          with gr.Column(scale=2):
-             output_label = gr.Label(label="Overall Quality Estimate", elem_id="quality-label")
-             # --- MODIFIED LINE ---
-             output_dataframe = gr.DataFrame(
-                 headers=["Criterion", "Sim (+)", "Sim (-)", "Difference", "Assessment (Comp)", "Assessment (Simp)"],
-                 label="Detailed Quality Assessment",
-                 wrap=True
-                 # height=350  <-- REMOVED
-             )
-             # --- END MODIFIED LINE ---
-             output_json = gr.JSON(label="Raw Results (for debugging)")
-
-     # Wire the button to the processing function
-     submit_button.click(
-         fn=assess_quality,
-         inputs=input_image,
-         outputs=[output_dataframe, output_label, output_json]
      )

- # --- Launch the Gradio app ---
- # On Spaces, Gradio handles this automatically.
- # To run locally: demo.launch()
- # On Spaces it is better to let HF manage the launch.
- # demo.launch(share=True)  # for a temporary public link when running locally
  if __name__ == "__main__":
-     # share=True only if you want a temporary public link from a local run
-     # server_name="0.0.0.0" allows connections from the local network
-     # server_port=7860 is the standard HF Spaces port
-     demo.launch(server_name="0.0.0.0", server_port=7860)
  import tensorflow_hub as tf_hub
  import numpy as np
  from PIL import Image
+ from huggingface_hub import snapshot_download
  from sklearn.metrics.pairwise import cosine_similarity
  import traceback
  import time

  # --- Configuration ---
  MODEL_REPO_ID = "google/cxr-foundation"
+ MODEL_DOWNLOAD_DIR = './hf_cxr_foundation_space'
+ SIMILARITY_DIFFERENCE_THRESHOLD = 0.1
+ POSITIVE_SIMILARITY_THRESHOLD = 0.1
+
  print(f"Using thresholds: Comp Δ={SIMILARITY_DIFFERENCE_THRESHOLD}, Simp τ={POSITIVE_SIMILARITY_THRESHOLD}")
 
+ # Improved default prompts (negatives reordered so each pairs with its positive)
  criteria_list_positive = [
+     "optimal centering mediastinum",
+     "deep inspiration",
+     "adequate penetration",
+     "complete lung fields",
+     "scapulae retracted outside lungs",
+     "sharp contrast",
+     "artifact-free image"
  ]
  criteria_list_negative = [
+     "poor centering",
+     "shallow inspiration",
+     "overexposed or underexposed image",
+     "cropped lung fields",
+     "scapular overlay on lungs",
+     "blurred image",
+     "obscuring artifacts"
  ]
 
+ # --- Helper functions ---
  def bert_tokenize(text, preprocessor):
      if preprocessor is None:
+         raise ValueError("The BERT preprocessor is not loaded.")
+     text = str(text).lower()
+     out = preprocessor(tf.constant([text]))
      ids = out['input_word_ids'].numpy().astype(np.int32)
      masks = out['input_mask'].numpy().astype(np.float32)
      paddings = 1.0 - masks
+     # Replace the [SEP] end token (102) and mark it as padding
      end_token_idx = (ids == 102)
      ids[end_token_idx] = 0
      paddings[end_token_idx] = 1.0
+     # Ensure shape (1, 1, 128)
+     if ids.ndim == 2: ids = np.expand_dims(ids, 1)
+     if paddings.ndim == 2: paddings = np.expand_dims(paddings, 1)
      return ids, paddings

  def png_to_tfexample(image_array: np.ndarray) -> tf.train.Example:
+     # (unchanged: converts a grayscale NumPy array to a PNG tf.Example)
      if image_array.ndim == 3 and image_array.shape[2] == 1:
+         image_array = np.squeeze(image_array, axis=2)
      elif image_array.ndim != 2:
+         raise ValueError(f'Array must be 2-D. Dimensions: {image_array.ndim}')
      image = image_array.astype(np.float32)
+     min_val, max_val = image.min(), image.max()
      if max_val <= min_val:
          if image_array.dtype == np.uint8 or (min_val >= 0 and max_val <= 255):
+             pixel_array = image.astype(np.uint8); bitdepth = 8
+         else:
+             pixel_array = np.zeros_like(image, dtype=np.uint16); bitdepth = 16
      else:
+         image -= min_val
          current_max = max_val - min_val
          if image_array.dtype != np.uint8:
              image *= 65535 / current_max
+             pixel_array = image.astype(np.uint16); bitdepth = 16
          else:
              image *= 255 / current_max
+             pixel_array = image.astype(np.uint8); bitdepth = 8
      output = io.BytesIO()
+     png.Writer(width=pixel_array.shape[1], height=pixel_array.shape[0],
+                greyscale=True, bitdepth=bitdepth).write(output, pixel_array.tolist())
      example = tf.train.Example()
      features = example.features.feature
+     features['image/encoded'].bytes_list.value.append(output.getvalue())
      features['image/format'].bytes_list.value.append(b'png')
      return example

  def generate_image_embedding(img_np, elixrc_infer, qformer_infer):
      if elixrc_infer is None or qformer_infer is None:
          raise ValueError("ELIXR-C or QFormer models are not loaded.")
      try:
+         serialized = png_to_tfexample(img_np).SerializeToString()
+         elixrc_out = elixrc_infer(input_example=tf.constant([serialized]))
+         elixr_emb = elixrc_out['feature_maps_0'].numpy()
+         q_in = {
+             'image_feature': elixr_emb.tolist(),
+             'ids': np.zeros((1, 1, 128), dtype=np.int32).tolist(),
+             'paddings': np.ones((1, 1, 128), dtype=np.float32).tolist(),
          }
+         q_out = qformer_infer(**q_in)
+         img_emb = q_out['all_contrastive_img_emb'].numpy()
+         if img_emb.ndim > 2:
+             img_emb = img_emb.mean(axis=tuple(range(1, img_emb.ndim - 1)))
+         if img_emb.ndim == 1:
+             img_emb = img_emb[np.newaxis, :]
+         return img_emb
      except Exception as e:
+         print(f"Error embedding the image: {e}")
          traceback.print_exc()
+         raise
+
+ def calculate_similarities_and_classify(image_embedding, bert_preprocessor, qformer_infer,
+                                         criteria_positive, criteria_negative):
+     results = {}
+     for pos, neg in zip(criteria_positive, criteria_negative):
+         sim_pos = sim_neg = diff = None
+         comp = simp = "ERROR"
          try:
+             # Positive text embedding
+             ids_p, pad_p = bert_tokenize(pos, bert_preprocessor)
+             inp_p = {'image_feature': np.zeros([1, 8, 8, 1376], dtype=np.float32).tolist(),
+                      'ids': ids_p.tolist(), 'paddings': pad_p.tolist()}
+             txt_p = qformer_infer(**inp_p)['contrastive_txt_emb'].numpy()
+             # Negative text embedding
+             ids_n, pad_n = bert_tokenize(neg, bert_preprocessor)
+             inp_n = {'image_feature': np.zeros([1, 8, 8, 1376], dtype=np.float32).tolist(),
+                      'ids': ids_n.tolist(), 'paddings': pad_n.tolist()}
+             txt_n = qformer_infer(**inp_n)['contrastive_txt_emb'].numpy()
+
+             sim_pos = float(cosine_similarity(image_embedding, txt_p.reshape(1, -1))[0][0])
+             sim_neg = float(cosine_similarity(image_embedding, txt_n.reshape(1, -1))[0][0])
+             diff = sim_pos - sim_neg
+             comp = "PASS" if diff > SIMILARITY_DIFFERENCE_THRESHOLD else "FAIL"
+             simp = "PASS" if sim_pos > POSITIVE_SIMILARITY_THRESHOLD else "FAIL"
          except Exception as e:
+             print(f"Error on criterion '{pos}': {e}")
+         results[pos] = {
+             'positive_prompt': pos,
+             'negative_prompt': neg,
+             'sim_pos': sim_pos,
+             'sim_neg': sim_neg,
+             'difference': diff,
+             'comp': comp,
+             'simp': simp
          }
+     return results

  # --- Global model loading ---
+ print("--- Starting model load ---")
  start_time = time.time()
  models_loaded = False
+ bert_preproc = elixrc = qformer = None
  try:
+     hf_token = os.environ.get("HF_TOKEN")
      os.makedirs(MODEL_DOWNLOAD_DIR, exist_ok=True)
      snapshot_download(repo_id=MODEL_REPO_ID, local_dir=MODEL_DOWNLOAD_DIR,
+                       allow_patterns=['elixr-c-v2-pooled/*', 'pax-elixr-b-text/*'],
+                       local_dir_use_symlinks=False, token=hf_token)
+     bert_preproc = tf_hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
+     elixrc = tf.saved_model.load(os.path.join(MODEL_DOWNLOAD_DIR, 'elixr-c-v2-pooled')).signatures['serving_default']
+     qformer = tf.saved_model.load(os.path.join(MODEL_DOWNLOAD_DIR, 'pax-elixr-b-text')).signatures['serving_default']
      models_loaded = True
+     print(f"Models loaded in {time.time() - start_time:.2f}s")
  except Exception as e:
+     print("ERROR loading models:", e)
      traceback.print_exc()

+ # --- Main function for Gradio ---
+ def assess_quality_and_update_ui(image_pil, pos_input, neg_input):
      if not models_loaded:
+         raise gr.Error("The models could not be loaded.")
      if image_pil is None:
+         # Returns: welcome visible, results hidden, image None, label "N/A", empty HTML, empty JSON
+         return (
+             gr.update(visible=True),
+             gr.update(visible=False),
+             None,
+             "N/A",
+             "",
+             {}
+         )
+     # Parse the prompt lists
+     pos_list = [l.strip() for l in pos_input.splitlines() if l.strip()]
+     neg_list = [l.strip() for l in neg_input.splitlines() if l.strip()]
+     if len(pos_list) != len(neg_list):
+         raise gr.Error("The number of positive and negative prompts must match.")
+     # Image embedding
+     img_np = np.array(image_pil.convert('L'))
+     emb = generate_image_embedding(img_np, elixrc, qformer)
+     # Compute similarities
+     details = calculate_similarities_and_classify(emb, bert_preproc, qformer, pos_list, neg_list)
+     # Build the HTML table
+     passed = total = 0
+     rows = ""
+     for crit, d in details.items():
+         total += 1
+         if d['comp'] == "PASS": passed += 1
+         sp = f"{d['sim_pos']:.4f}" if d['sim_pos'] is not None else "N/A"
+         sn = f"{d['sim_neg']:.4f}" if d['sim_neg'] is not None else "N/A"
+         df = f"{d['difference']:.4f}" if d['difference'] is not None else "N/A"
+         c_style = "color:#22c55e;font-weight:bold;" if d['comp'] == "PASS" else "color:#ef4444;font-weight:bold;"
+         s_style = "color:#22c55e;font-weight:bold;" if d['simp'] == "PASS" else "color:#ef4444;font-weight:bold;"
+         rows += (
+             f"<tr>"
+             f"<td>{crit}</td>"
+             f"<td>{sp}</td>"
+             f"<td>{sn}</td>"
+             f"<td>{df}</td>"
+             f"<td style='{c_style}'>{d['comp']}</td>"
+             f"<td style='{s_style}'>{d['simp']}</td>"
+             f"</tr>"
+         )
+     html = f"""
+     <table style="width:100%;border-collapse:collapse;">
+       <thead style="background:#f2f2f2;">
+         <tr>
+           <th>Criterion</th><th>Sim (+)</th><th>Sim (-)</th><th>Diff</th>
+           <th>Assessment (Comp)</th><th>Assessment (Simp)</th>
+         </tr>
+       </thead>
+       <tbody>{rows}</tbody>
+     </table>
+     """
+     # Overall label
+     pass_rate = passed / total if total > 0 else 0
+     if pass_rate >= 0.85: overall = "Excellent"
+     elif pass_rate >= 0.70: overall = "Good"
+     elif pass_rate >= 0.50: overall = "Fair"
+     else: overall = "Poor"
+     quality_label = f"{overall} ({passed}/{total} passed)"
+     # Return the UI updates
+     return (
+         gr.update(visible=False),
+         gr.update(visible=True),
+         image_pil,
+         quality_label,
+         html,
+         details
+     )

+ def reset_ui():
+     return (
+         gr.update(visible=True),
+         gr.update(visible=False),
+         None,   # clear input_image
+         None,   # clear output_image
+         "N/A",  # quality label
+         "",     # HTML
+         {}      # JSON
      )
+
+ # --- Theme ---
+ dark_theme = gr.themes.Default(
+     primary_hue=gr.themes.colors.blue,
+     secondary_hue=gr.themes.colors.blue,
+     neutral_hue=gr.themes.colors.gray,
+     font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
+     font_mono=[gr.themes.GoogleFont("JetBrains Mono"), "ui-monospace", "Consolas", "monospace"],
+ ).set(
+     body_background_fill="#111827",
+     background_fill_primary="#1f2937",
+     background_fill_secondary="#374151",
+     block_background_fill="#1f2937",
+     body_text_color="#d1d5db",
+     block_label_text_color="#d1d5db",
+     block_title_text_color="#ffffff",
+     border_color_accent="#374151",
+     border_color_primary="#4b5563",
+     button_primary_background_fill="*primary_600",
+     button_primary_text_color="#ffffff",
+     button_secondary_background_fill="*neutral_700",
+     button_secondary_text_color="#ffffff",
+     input_background_fill="#374151",
+     input_border_color="#4b5563",
+     shadow_drop="rgba(0,0,0,0.2) 0px 2px 4px",
+     block_shadow="rgba(0,0,0,0.2) 0px 2px 5px",
+ )
+
+ # --- Gradio interface ---
+ with gr.Blocks(theme=dark_theme, title="CXR Quality Assessment") as demo:
+     # Header
+     gr.Markdown("""
+     # <span style="color: #e5e7eb;">CXR Quality Assessment</span>
+     <p style="color: #9ca3af;">Evaluate the technical quality of chest X-rays with AI</p>
+     """)
+     # Editable prompts
      with gr.Row():
+         positive_prompts_input = gr.Textbox(
+             label="Positive prompts (one per line)",
+             value="\n".join(criteria_list_positive),
+             lines=7
+         )
+         negative_prompts_input = gr.Textbox(
+             label="Negative prompts (one per line)",
+             value="\n".join(criteria_list_negative),
+             lines=7
+         )
+     # Main content
+     with gr.Row(equal_height=False):
+         with gr.Column(scale=1, min_width=300):
+             gr.Markdown("### 1. Image Upload")
+             input_image = gr.Image(type="pil", label="Upload your CXR", height=300)
+             with gr.Row():
+                 analyze_btn = gr.Button("Analyze", variant="primary")
+                 reset_btn = gr.Button("Reset", variant="secondary")
+             gr.Markdown("<p style='color:#9ca3af; font-size:0.9em;'>Model loading takes ~1 min; analysis takes ~15–40 s.</p>")
          with gr.Column(scale=2):
+             with gr.Column(visible=True) as welcome_block:
+                 gr.Markdown("### Welcome! Upload an X-ray and click “Analyze”.")
+             with gr.Column(visible=False) as results_block:
+                 gr.Markdown("### 2. Results")
+                 with gr.Row():
+                     output_image = gr.Image(type="pil", label="Analyzed Image", interactive=False)
+                     with gr.Column():
+                         gr.Markdown("#### Overall Quality")
+                         output_label = gr.Label(value="N/A")
+                 gr.Markdown("#### Detailed Assessment")
+                 output_html = gr.HTML()
+                 with gr.Accordion("View JSON (debug)", open=False):
+                     output_json = gr.JSON()
+     # Event wiring
+     analyze_btn.click(
+         fn=assess_quality_and_update_ui,
+         inputs=[input_image, positive_prompts_input, negative_prompts_input],
+         outputs=[welcome_block, results_block, output_image, output_label, output_html, output_json]
+     )
+     reset_btn.click(
+         fn=reset_ui,
+         inputs=None,
+         outputs=[welcome_block, results_block, input_image, output_image, output_label, output_html, output_json]
      )

  if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860)
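
For reference, the comparative PASS/FAIL rule at the heart of this commit can be exercised without downloading any models. The sketch below is a minimal, self-contained approximation: `fake_text_embed` and the random `image_emb` vector are hypothetical stand-ins for the ELIXR-C + QFormer embeddings, not part of the app; only the thresholding mirrors `calculate_similarities_and_classify`.

```python
import zlib
import numpy as np

SIMILARITY_DIFFERENCE_THRESHOLD = 0.1  # comparative rule: sim(+) - sim(-) must exceed this
POSITIVE_SIMILARITY_THRESHOLD = 0.1    # simplified rule: sim(+) alone must exceed this

def cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two 1-D vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def fake_text_embed(prompt: str, dim: int = 128) -> np.ndarray:
    """Hypothetical stand-in for the QFormer text tower: a vector seeded by the prompt."""
    seed = zlib.crc32(prompt.encode("utf-8"))
    return np.random.default_rng(seed).normal(size=dim)

# Stand-in for the ELIXR-C + QFormer image embedding.
image_emb = np.random.default_rng(42).normal(size=128)

pairs = [("deep inspiration", "shallow inspiration"),
         ("sharp contrast", "blurred image")]
for pos, neg in pairs:
    sim_pos = cosine(image_emb, fake_text_embed(pos))
    sim_neg = cosine(image_emb, fake_text_embed(neg))
    diff = sim_pos - sim_neg
    comp = "PASS" if diff > SIMILARITY_DIFFERENCE_THRESHOLD else "FAIL"
    simp = "PASS" if sim_pos > POSITIVE_SIMILARITY_THRESHOLD else "FAIL"
    print(f"{pos}: sim+={sim_pos:.4f} sim-={sim_neg:.4f} diff={diff:.4f} -> comp={comp}, simp={simp}")
```

Raising both thresholds from 0.0 to 0.1, as this commit does, means a criterion now passes only when the positive prompt is meaningfully closer to the image than its negative counterpart, rather than on any positive margin at all.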