fecia committed
Commit 6045f26 (verified) · 1 Parent(s): 1b15490

Update app.py

Files changed (1):
  1. app.py  +263 -249
app.py CHANGED
@@ -15,10 +15,9 @@ import pandas as pd  # To format the output as a table

  # --- Configuration ---
  MODEL_REPO_ID = "google/cxr-foundation"
- MODEL_DOWNLOAD_DIR = './hf_cxr_foundation_space'  # Directory inside the Space container
- # Thresholds
- SIMILARITY_DIFFERENCE_THRESHOLD = 0.0
- POSITIVE_SIMILARITY_THRESHOLD = 0.0
+ MODEL_DOWNLOAD_DIR = './hf_cxr_foundation_space'
+ SIMILARITY_DIFFERENCE_THRESHOLD = 0.1
+ POSITIVE_SIMILARITY_THRESHOLD = 0.1
  print(f"Usando umbrales: Comp Δ={SIMILARITY_DIFFERENCE_THRESHOLD}, Simp τ={POSITIVE_SIMILARITY_THRESHOLD}")

  # --- Prompts ---
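
The functional core of this commit is the threshold bump from 0.0 to 0.1, which makes both classifiers stricter. A minimal sketch of the two decision rules these constants drive, with illustrative similarity values only (the real scores come from the ELIXR image/text embeddings in the code below):

SIMILARITY_DIFFERENCE_THRESHOLD = 0.1
POSITIVE_SIMILARITY_THRESHOLD = 0.1

def classify(sim_pos: float, sim_neg: float) -> tuple:
    """Returns the (comparative, simplified) PASS/FAIL pair."""
    comparative = "PASS" if (sim_pos - sim_neg) > SIMILARITY_DIFFERENCE_THRESHOLD else "FAIL"
    simplified = "PASS" if sim_pos > POSITIVE_SIMILARITY_THRESHOLD else "FAIL"
    return comparative, simplified

# Under the old 0.0 thresholds any positive margin passed; 0.1 demands a clear one.
print(classify(0.32, 0.28))  # ('FAIL', 'PASS'): margin 0.04 < 0.1, but Sim(+) 0.32 > 0.1
print(classify(0.45, 0.20))  # ('PASS', 'PASS'): both the margin and Sim(+) clear 0.1
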
@@ -31,255 +30,158 @@ criteria_list_negative = [
      "cropped image", "scapulae overlying lungs", "blurred image", "obscuring artifact"
  ]

- # --- Helper functions (integrated or adapted) ---
- # @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])  # may help performance
- def preprocess_text(text):
-     """Internal BERT preprocessor function."""
-     return bert_preprocessor_global(text)
+ # --- Helper functions (SAME as in the previous Gradio version) ---
+ # @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
+ # def preprocess_text(text):
+ #     return bert_preprocessor_global(text)  # assumes bert_preprocessor_global is loaded

  def bert_tokenize(text, preprocessor):
-     """Tokenizes text using the globally loaded BERT preprocessor."""
-     if preprocessor is None:
-         raise ValueError("BERT preprocessor no está cargado.")
+     if preprocessor is None: raise ValueError("BERT preprocessor no está cargado.")
      if not isinstance(text, str): text = str(text)
-
-     # Run the preprocessor
      out = preprocessor(tf.constant([text.lower()]))
-
-     # Extract and process the IDs and masks
      ids = out['input_word_ids'].numpy().astype(np.int32)
      masks = out['input_mask'].numpy().astype(np.float32)
      paddings = 1.0 - masks
-
-     # Replace the [SEP] token (102) with 0 and mark it as padding
      end_token_idx = (ids == 102)
      ids[end_token_idx] = 0
      paddings[end_token_idx] = 1.0
-
-     # Ensure the dimensions (B, T, S) -> (1, 1, 128);
-     # the preprocessor may return (1, 128), we need (1, 1, 128)
      if ids.ndim == 2: ids = np.expand_dims(ids, axis=1)
      if paddings.ndim == 2: paddings = np.expand_dims(paddings, axis=1)
-
-     # Verify the final shapes
      expected_shape = (1, 1, 128)
      if ids.shape != expected_shape:
-         # Try to reshape if needed (can happen with some versions)
          if ids.shape == (1,128): ids = np.expand_dims(ids, axis=1)
          else: raise ValueError(f"Shape incorrecta para ids: {ids.shape}, esperado {expected_shape}")
      if paddings.shape != expected_shape:
          if paddings.shape == (1,128): paddings = np.expand_dims(paddings, axis=1)
          else: raise ValueError(f"Shape incorrecta para paddings: {paddings.shape}, esperado {expected_shape}")
-
      return ids, paddings

  def png_to_tfexample(image_array: np.ndarray) -> tf.train.Example:
-     """Creates a tf.train.Example from a grayscale NumPy array."""
      if image_array.ndim == 3 and image_array.shape[2] == 1:
-         image_array = np.squeeze(image_array, axis=2)  # ensure 2-D
+         image_array = np.squeeze(image_array, axis=2)
      elif image_array.ndim != 2:
-         raise ValueError(f'Array debe ser 2-D (escala de grises). Dimensiones actuales: {image_array.ndim}')
-
+         raise ValueError(f'Array debe ser 2-D. Dimensiones: {image_array.ndim}')
      image = image_array.astype(np.float32)
-     min_val = image.min()
-     max_val = image.max()
-
-     # Avoid division by zero when the image is constant
+     min_val, max_val = image.min(), image.max()
      if max_val <= min_val:
-         # If constant, treat it as uint8 when the original range allows it,
-         # or simply zero it out if it is float.
          if image_array.dtype == np.uint8 or (min_val >= 0 and max_val <= 255):
-             pixel_array = image.astype(np.uint8)
-             bitdepth = 8
-         else:  # constant float, or outside the uint8 range
-             pixel_array = np.zeros_like(image, dtype=np.uint16)
-             bitdepth = 16
+             pixel_array = image.astype(np.uint8); bitdepth = 8
+         else:
+             pixel_array = np.zeros_like(image, dtype=np.uint16); bitdepth = 16
      else:
-         image -= min_val  # shift the minimum to zero
+         image -= min_val
          current_max = max_val - min_val
-         # Scale to 16 bit for extra precision when the input was not uint8
          if image_array.dtype != np.uint8:
             image *= 65535 / current_max
-             pixel_array = image.astype(np.uint16)
-             bitdepth = 16
+             pixel_array = image.astype(np.uint16); bitdepth = 16
          else:
-             # If it was uint8, keep the range and type; subtracting the min
-             # already left it in [0, current_max], so scale to 255 if needed.
             image *= 255 / current_max
-             pixel_array = image.astype(np.uint8)
-             bitdepth = 8
-
-     # Encode as PNG
+             pixel_array = image.astype(np.uint8); bitdepth = 8
      output = io.BytesIO()
-     png.Writer(
-         width=pixel_array.shape[1],
-         height=pixel_array.shape[0],
-         greyscale=True,
-         bitdepth=bitdepth
-     ).write(output, pixel_array.tolist())
-     png_bytes = output.getvalue()
-
-     # Build the tf.train.Example
+     png.Writer(width=pixel_array.shape[1], height=pixel_array.shape[0], greyscale=True, bitdepth=bitdepth).write(output, pixel_array.tolist())
      example = tf.train.Example()
      features = example.features.feature
-     features['image/encoded'].bytes_list.value.append(png_bytes)
+     features['image/encoded'].bytes_list.value.append(output.getvalue())
      features['image/format'].bytes_list.value.append(b'png')
      return example

  def generate_image_embedding(img_np, elixrc_infer, qformer_infer):
-     """Generates the final image embedding."""
-     if elixrc_infer is None or qformer_infer is None:
-         raise ValueError("Modelos ELIXR-C o QFormer no cargados.")
-
+     if elixrc_infer is None or qformer_infer is None: raise ValueError("Modelos ELIXR-C o QFormer no cargados.")
      try:
-         # 1. ELIXR-C
          serialized_img_tf_example = png_to_tfexample(img_np).SerializeToString()
          elixrc_output = elixrc_infer(input_example=tf.constant([serialized_img_tf_example]))
          elixrc_embedding = elixrc_output['feature_maps_0'].numpy()
-         print(f"  Embedding ELIXR-C shape: {elixrc_embedding.shape}")
-
-         # 2. QFormer (image)
         qformer_input_img = {
             'image_feature': elixrc_embedding.tolist(),
-             'ids': np.zeros((1, 1, 128), dtype=np.int32).tolist(),  # empty text
-             'paddings': np.ones((1, 1, 128), dtype=np.float32).tolist(),  # all padding
+             'ids': np.zeros((1, 1, 128), dtype=np.int32).tolist(),
+             'paddings': np.ones((1, 1, 128), dtype=np.float32).tolist(),
         }
         qformer_output_img = qformer_infer(**qformer_input_img)
         image_embedding = qformer_output_img['all_contrastive_img_emb'].numpy()
-
-         # Adjust the dimensions if needed
         if image_embedding.ndim > 2:
-             print(f"  Ajustando dimensiones embedding imagen (original: {image_embedding.shape})")
-             image_embedding = np.mean(
-                 image_embedding,
-                 axis=tuple(range(1, image_embedding.ndim - 1))
-             )
-             if image_embedding.ndim == 1:
-                 image_embedding = np.expand_dims(image_embedding, axis=0)
-         elif image_embedding.ndim == 1:
-             image_embedding = np.expand_dims(image_embedding, axis=0)  # ensure 2-D
-
-         print(f"  Embedding final imagen shape: {image_embedding.shape}")
-         if image_embedding.ndim != 2:
-             raise ValueError(f"Embedding final de imagen no tiene 2 dimensiones: {image_embedding.shape}")
+             image_embedding = np.mean(image_embedding, axis=tuple(range(1, image_embedding.ndim - 1)))
+         if image_embedding.ndim == 1: image_embedding = np.expand_dims(image_embedding, axis=0)
+         if image_embedding.ndim != 2: raise ValueError(f"Embedding final no tiene 2 dims: {image_embedding.shape}")
         return image_embedding
-
     except Exception as e:
-         print(f"Error generando embedding de imagen: {e}")
-         traceback.print_exc()
-         raise  # re-raise so that Gradio handles the exception
+         print(f"Error generando embedding imagen: {e}"); traceback.print_exc(); raise

  def calculate_similarities_and_classify(image_embedding, bert_preprocessor, qformer_infer):
-     """Computes the similarities and classifies each criterion."""
-     if image_embedding is None: raise ValueError("Embedding de imagen es None.")
+     if image_embedding is None: raise ValueError("Embedding imagen es None.")
     if bert_preprocessor is None: raise ValueError("Preprocesador BERT es None.")
     if qformer_infer is None: raise ValueError("QFormer es None.")
-
     detailed_results = {}
-     print("\n--- Calculando similitudes y clasificando ---")
-
+     print("\n--- Calculando similitudes ---")
     for i in range(len(criteria_list_positive)):
-         positive_text = criteria_list_positive[i]
-         negative_text = criteria_list_negative[i]
-         criterion_name = positive_text  # use the positive prompt as the key
-
-         print(f"Procesando criterio: \"{criterion_name}\"")
+         positive_text, negative_text = criteria_list_positive[i], criteria_list_negative[i]
+         criterion_name = positive_text
+         print(f"Procesando: \"{criterion_name}\"")
         similarity_positive, similarity_negative, difference = None, None, None
         classification_comp, classification_simp = "ERROR", "ERROR"
-
         try:
-             # 1. Positive-text embedding
             tokens_pos, paddings_pos = bert_tokenize(positive_text, bert_preprocessor)
-             qformer_input_text_pos = {
-                 'image_feature': np.zeros([1, 8, 8, 1376], dtype=np.float32).tolist(),  # dummy
-                 'ids': tokens_pos.tolist(), 'paddings': paddings_pos.tolist(),
-             }
-             text_embedding_pos = qformer_infer(**qformer_input_text_pos)['contrastive_txt_emb'].numpy()
+             qformer_input_pos = {'image_feature': np.zeros([1, 8, 8, 1376], dtype=np.float32).tolist(), 'ids': tokens_pos.tolist(), 'paddings': paddings_pos.tolist()}
+             text_embedding_pos = qformer_infer(**qformer_input_pos)['contrastive_txt_emb'].numpy()
             if text_embedding_pos.ndim == 1: text_embedding_pos = np.expand_dims(text_embedding_pos, axis=0)

-             # 2. Negative-text embedding
             tokens_neg, paddings_neg = bert_tokenize(negative_text, bert_preprocessor)
-             qformer_input_text_neg = {
-                 'image_feature': np.zeros([1, 8, 8, 1376], dtype=np.float32).tolist(),  # dummy
-                 'ids': tokens_neg.tolist(), 'paddings': paddings_neg.tolist(),
-             }
-             text_embedding_neg = qformer_infer(**qformer_input_text_neg)['contrastive_txt_emb'].numpy()
+             qformer_input_neg = {'image_feature': np.zeros([1, 8, 8, 1376], dtype=np.float32).tolist(), 'ids': tokens_neg.tolist(), 'paddings': paddings_neg.tolist()}
+             text_embedding_neg = qformer_infer(**qformer_input_neg)['contrastive_txt_emb'].numpy()
             if text_embedding_neg.ndim == 1: text_embedding_neg = np.expand_dims(text_embedding_neg, axis=0)

-             # Check that the dimensions are compatible for the similarity
-             if image_embedding.shape[1] != text_embedding_pos.shape[1]:
-                 raise ValueError(f"Dimensión incompatible: Imagen ({image_embedding.shape[1]}) vs Texto Pos ({text_embedding_pos.shape[1]})")
-             if image_embedding.shape[1] != text_embedding_neg.shape[1]:
-                 raise ValueError(f"Dimensión incompatible: Imagen ({image_embedding.shape[1]}) vs Texto Neg ({text_embedding_neg.shape[1]})")
+             if image_embedding.shape[1] != text_embedding_pos.shape[1]: raise ValueError(f"Dim mismatch: Img ({image_embedding.shape[1]}) vs Pos ({text_embedding_pos.shape[1]})")
+             if image_embedding.shape[1] != text_embedding_neg.shape[1]: raise ValueError(f"Dim mismatch: Img ({image_embedding.shape[1]}) vs Neg ({text_embedding_neg.shape[1]})")

-             # 3. Compute the similarities
             similarity_positive = cosine_similarity(image_embedding, text_embedding_pos)[0][0]
             similarity_negative = cosine_similarity(image_embedding, text_embedding_neg)[0][0]
-             print(f"  Sim (+)={similarity_positive:.4f}, Sim (-)={similarity_negative:.4f}")

-             # 4. Classify
             difference = similarity_positive - similarity_negative
             classification_comp = "PASS" if difference > SIMILARITY_DIFFERENCE_THRESHOLD else "FAIL"
             classification_simp = "PASS" if similarity_positive > POSITIVE_SIMILARITY_THRESHOLD else "FAIL"
-             print(f"  Diff={difference:.4f} -> Comp: {classification_comp}, Simp: {classification_simp}")
-
+             print(f"  Sim(+)={similarity_positive:.4f}, Sim(-)={similarity_negative:.4f}, Diff={difference:.4f} -> Comp:{classification_comp}, Simp:{classification_simp}")
         except Exception as e:
-             print(f"  ERROR procesando criterio '{criterion_name}': {e}")
-             traceback.print_exc()
-             # keep both classifications as "ERROR"
-
-         # Store the results
+             print(f"  ERROR criterio '{criterion_name}': {e}"); traceback.print_exc()
         detailed_results[criterion_name] = {
-             'positive_prompt': positive_text,
-             'negative_prompt': negative_text,
+             'positive_prompt': positive_text, 'negative_prompt': negative_text,
             'similarity_positive': float(similarity_positive) if similarity_positive is not None else None,
             'similarity_negative': float(similarity_negative) if similarity_negative is not None else None,
             'difference': float(difference) if difference is not None else None,
-             'classification_comparative': classification_comp,
-             'classification_simplified': classification_simp
+             'classification_comparative': classification_comp, 'classification_simplified': classification_simp
         }
     return detailed_results

  # --- Global model loading ---
- # Runs ONCE when the Gradio app / Space starts
  print("--- Iniciando carga global de modelos ---")
  start_time = time.time()
  models_loaded = False
  bert_preprocessor_global = None
  elixrc_infer_global = None
  qformer_infer_global = None
-
  try:
-     # Check HF authentication (useful for private models, although not needed here)
-     # if HfFolder.get_token() is None:
-     #     print("Advertencia: No se encontró token de Hugging Face.")
-     # else:
-     #     print("Token de Hugging Face encontrado.")
+     # Add a token if needed (for private or gated repos)
+     hf_token = os.environ.get("HF_TOKEN")  # read the token from the Space secrets
+     # if hf_token:
+     #     print("Usando HF_TOKEN para autenticación.")
+     #     HfFolder.save_token(hf_token)

-     # Create the directory if it does not exist
     os.makedirs(MODEL_DOWNLOAD_DIR, exist_ok=True)
     print(f"Descargando/verificando modelos en: {MODEL_DOWNLOAD_DIR}")
     snapshot_download(repo_id=MODEL_REPO_ID, local_dir=MODEL_DOWNLOAD_DIR,
                       allow_patterns=['elixr-c-v2-pooled/*', 'pax-elixr-b-text/*'],
-                       local_dir_use_symlinks=False)  # avoid symlinks
+                       local_dir_use_symlinks=False, token=hf_token)  # pass the token here
     print("Modelos descargados/verificados.")

-     # Load the BERT preprocessor from TF Hub
     print("Cargando Preprocesador BERT...")
-     # Using an explicit handle can be more robust in some environments
     bert_preprocess_handle = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"
     bert_preprocessor_global = tf_hub.KerasLayer(bert_preprocess_handle)
     print("Preprocesador BERT cargado.")

-     # Load ELIXR-C
     print("Cargando ELIXR-C...")
     elixrc_model_path = os.path.join(MODEL_DOWNLOAD_DIR, 'elixr-c-v2-pooled')
     elixrc_model = tf.saved_model.load(elixrc_model_path)
     elixrc_infer_global = elixrc_model.signatures['serving_default']
     print("Modelo ELIXR-C cargado.")

-     # Load the QFormer (ELIXR-B Text)
     print("Cargando QFormer (ELIXR-B Text)...")
     qformer_model_path = os.path.join(MODEL_DOWNLOAD_DIR, 'pax-elixr-b-text')
     qformer_model = tf.saved_model.load(qformer_model_path)
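
A detail of `bert_tokenize` that is easy to miss in the diff: the TF Hub preprocessor emits a trailing [SEP] token (id 102), which the code zeroes out and folds into the padding mask before the ids reach the QFormer. A standalone numpy sketch of that trick, with made-up token ids (the real ones come from the preprocessor):

import numpy as np

# Made-up ids: [CLS]=101, two word pieces, [SEP]=102, then zero padding.
ids = np.array([[101, 7592, 2088, 102, 0, 0]], dtype=np.int32)
masks = np.array([[1, 1, 1, 1, 0, 0]], dtype=np.float32)
paddings = 1.0 - masks

# Same steps as bert_tokenize: drop [SEP] and mark its slot as padding.
end_token_idx = (ids == 102)
ids[end_token_idx] = 0
paddings[end_token_idx] = 1.0

print(ids)       # [[ 101 7592 2088    0    0    0]]
print(paddings)  # [[0. 0. 0. 1. 1. 1.]]
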
@@ -289,143 +191,255 @@ try:
     models_loaded = True
     end_time = time.time()
     print(f"--- Modelos cargados globalmente con éxito en {end_time - start_time:.2f} segundos ---")
-
  except Exception as e:
     models_loaded = False
-     print(f"--- ERROR CRÍTICO DURANTE LA CARGA GLOBAL DE MODELOS ---")
-     print(e)
-     traceback.print_exc()
-     # Gradio will still start, but the analysis function will fail.
+     print(f"--- ERROR CRÍTICO DURANTE LA CARGA GLOBAL DE MODELOS ---"); print(e); traceback.print_exc()

  # --- Main processing function for Gradio ---
- def assess_quality(image_pil):
-     """Function that Gradio calls with the input image."""
+ def assess_quality_and_update_ui(image_pil):
+     """Processes the image and returns updates for the UI."""
     if not models_loaded:
         raise gr.Error("Error: Los modelos no se pudieron cargar. La aplicación no puede procesar imágenes.")
     if image_pil is None:
-         # Return empty results or an error message when there is no image
-         return pd.DataFrame(), "N/A", None  # empty dataframe, empty label, empty JSON
+         # Return default/empty values and control the visibility
+         return (
+             gr.update(visible=True),   # show the welcome block
+             gr.update(visible=False),  # hide the results block
+             None,                      # clear the output image
+             gr.update(value="N/A"),    # clear the label
+             pd.DataFrame(),            # clear the dataframe
+             None                       # clear the JSON
+         )

     print("\n--- Iniciando evaluación para nueva imagen ---")
     start_process_time = time.time()
-
     try:
-         # 1. Convert the PIL image to a grayscale NumPy array;
-         # Gradio with type="pil" already delivers it as a PIL object
+         # 1. Convert to NumPy
         img_np = np.array(image_pil.convert('L'))
-         print(f"Imagen convertida a NumPy. Shape: {img_np.shape}, Tipo: {img_np.dtype}")
-
-         # 2. Generate the image embedding
-         print("Generando embedding de imagen...")
+         # 2. Generate the embedding
         image_embedding = generate_image_embedding(img_np, elixrc_infer_global, qformer_infer_global)
-         print("Embedding de imagen generado.")
-
-         # 3. Compute similarities and classify
-         print("Calculando similitudes y clasificando criterios...")
+         # 3. Classify
         detailed_results = calculate_similarities_and_classify(image_embedding, bert_preprocessor_global, qformer_infer_global)
-         print("Clasificación completada.")
-
-         # 4. Format the results for Gradio
-         output_data = []
-         passed_count = 0
-         total_count = 0
+         # 4. Format the results
+         output_data, passed_count, total_count = [], 0, 0
         for criterion, details in detailed_results.items():
             total_count += 1
-             sim_pos_str = f"{details['similarity_positive']:.4f}" if details['similarity_positive'] is not None else "N/A"
-             sim_neg_str = f"{details['similarity_negative']:.4f}" if details['similarity_negative'] is not None else "N/A"
-             diff_str = f"{details['difference']:.4f}" if details['difference'] is not None else "N/A"
-             assessment_comp = details['classification_comparative']
-             assessment_simp = details['classification_simplified']
-             output_data.append([
-                 criterion,
-                 sim_pos_str,
-                 sim_neg_str,
-                 diff_str,
-                 assessment_comp,
-                 assessment_simp
-             ])
-             if assessment_comp == "PASS":
-                 passed_count += 1
-
-         # Build the DataFrame
-         df_results = pd.DataFrame(output_data, columns=[
-             "Criterion", "Sim (+)", "Sim (-)", "Difference", "Assessment (Comp)", "Assessment (Simp)"
-         ])
-
-         # Compute the overall quality label
-         overall_quality = "Error"
+             sim_pos = details['similarity_positive']
+             sim_neg = details['similarity_negative']
+             diff = details['difference']
+             comp = details['classification_comparative']
+             simp = details['classification_simplified']
+             output_data.append([criterion, f"{sim_pos:.4f}" if sim_pos else "N/A",
+                                 f"{sim_neg:.4f}" if sim_neg else "N/A", f"{diff:.4f}" if diff else "N/A", comp, simp])
+             if comp == "PASS": passed_count += 1
+         df_results = pd.DataFrame(output_data, columns=["Criterion", "Sim (+)", "Sim (-)", "Difference", "Assessment (Comp)", "Assessment (Simp)"])
+         overall_quality = "Error"; pass_rate = 0
         if total_count > 0:
             pass_rate = passed_count / total_count
             if pass_rate >= 0.85: overall_quality = "Excellent"
             elif pass_rate >= 0.70: overall_quality = "Good"
             elif pass_rate >= 0.50: overall_quality = "Fair"
             else: overall_quality = "Poor"
-         quality_label = f"{overall_quality} ({passed_count}/{total_count} criteria passed)"
-
+         quality_label = f"{overall_quality} ({passed_count}/{total_count} passed)"
         end_process_time = time.time()
-         print(f"--- Evaluación completada en {end_process_time - start_process_time:.2f} segundos ---")
-
-         # Return the DataFrame, the label and the JSON
-         return df_results, quality_label, detailed_results
-
+         print(f"--- Evaluación completada en {end_process_time - start_process_time:.2f} seg ---")
+         # Return the results and update the visibility
+         return (
+             gr.update(visible=False),        # hide the welcome block
+             gr.update(visible=True),         # show the results block
+             image_pil,                       # show the processed image
+             gr.update(value=quality_label),  # update the label
+             df_results,                      # update the dataframe
+             detailed_results                 # update the JSON
+         )
     except Exception as e:
-         print(f"Error durante el procesamiento de la imagen en Gradio: {e}")
-         traceback.print_exc()
-         # Raise a gr.Error so it is shown in the Gradio UI
-         raise gr.Error(f"Error procesando la imagen: {str(e)}")
-
-
- # --- Define the Gradio interface ---
- css = """
- #quality-label label {
-     font-size: 1.1em;
-     font-weight: bold;
- }
- """
- with gr.Blocks(css=css) as demo:
-     gr.Markdown(
-         """
-         # Chest X-ray Technical Quality Assessment
-         Upload a chest X-ray image (PNG, JPG, etc.) to evaluate its technical quality based on 7 standard criteria
-         using the ELIXR model family (comparative strategy: Positive vs Negative prompts).
-         **Note:** Model loading on startup might take a minute. Processing an image can take 10-30 seconds depending on server load.
-         """
+         print(f"Error durante procesamiento Gradio: {e}"); traceback.print_exc()
+         raise gr.Error(f"Error procesando imagen: {str(e)}")
+
+ # --- UI reset function ---
+ def reset_ui():
+     print("Reseteando UI...")
+     return (
+         gr.update(visible=True),   # show the welcome block
+         gr.update(visible=False),  # hide the results block
+         None,                      # clear the input image
+         None,                      # clear the output image
+         gr.update(value="N/A"),    # clear the label
+         pd.DataFrame(),            # clear the dataframe
+         None                       # clear the JSON
     )
+
+ # --- Define a custom dark theme ---
+ # Inspired by the colors of the original HTML and Tailwind dark grays/blues
+ dark_theme = gr.themes.Default(
+     primary_hue=gr.themes.colors.blue,    # blue as the primary color
+     secondary_hue=gr.themes.colors.blue,  # secondary blue
+     neutral_hue=gr.themes.colors.gray,    # neutral gray
+     font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
+     font_mono=[gr.themes.GoogleFont("JetBrains Mono"), "ui-monospace", "Consolas", "monospace"],
+ ).set(
+     # Backgrounds
+     body_background_fill="#111827",       # very dark main background (gray-900)
+     background_fill_primary="#1f2937",    # component background (gray-800)
+     background_fill_secondary="#374151",  # secondary background (gray-700)
+     block_background_fill="#1f2937",      # block background (gray-800)
+
+     # Text
+     body_text_color="#d1d5db",            # light main text (gray-300)
+     text_color_subdued="#9ca3af",         # secondary text (gray-400)
+     block_label_text_color="#d1d5db",     # block labels (gray-300)
+     block_title_text_color="#ffffff",     # block titles (white)
+
+     # Borders
+     border_color_accent="#374151",        # border (gray-700)
+     border_color_primary="#4b5563",       # primary border (gray-600)
+
+     # Buttons and interactive elements
+     button_primary_background_fill="*primary_600",  # uses the primary (blue) color
+     button_primary_text_color="#ffffff",
+     button_secondary_background_fill="*neutral_700",
+     button_secondary_text_color="#ffffff",
+     input_background_fill="#374151",      # input background (gray-700)
+     input_border_color="#4b5563",         # input border (gray-600)
+     input_text_color="#ffffff",           # input text
+
+     # Shadows and radii
+     shadow_drop="rgba(0,0,0,0.2) 0px 2px 4px",
+     block_shadow="rgba(0,0,0,0.2) 0px 2px 5px",
+     radius_size="*radius_lg",             # rounded corners
+ )
+
+
+ # --- Define the Gradio interface with Blocks and the theme ---
+ with gr.Blocks(theme=dark_theme, title="CXR Quality Assessment") as demo:
+     # --- Header ---
     with gr.Row():
-         with gr.Column(scale=1):
-             input_image = gr.Image(type="pil", label="Upload Chest X-ray")
-             submit_button = gr.Button("Assess Quality", variant="primary")
+         gr.Markdown(
+             """
+             # <span style="color: #e5e7eb;">CXR Quality Assessment</span>
+             <p style="color: #9ca3af;">Evaluate chest X-ray technical quality using AI (ELIXR family)</p>
+             """,  # use white / light gray for the header text
+             elem_id="app-header"
+         )
+
+     # --- Main content (two columns) ---
+     with gr.Row(equal_height=False):  # allow different heights
+
+         # --- Left column (upload) ---
+         with gr.Column(scale=1, min_width=350):
+             gr.Markdown("### 1. Upload Image", elem_id="upload-title")
+             input_image = gr.Image(type="pil", label="Upload Chest X-ray", height=300)  # fixed height for the input image
+             with gr.Row():
+                 analyze_btn = gr.Button("Analyze Image", variant="primary", scale=2)
+                 reset_btn = gr.Button("Reset", variant="secondary", scale=1)
             # Add examples here if you have sample images
-             # Make sure the 'examples' folder exists and contains the images
             # gr.Examples(
-             #     examples=[os.path.join("examples", "sample_cxr.png")],  # list of example paths
-             #     inputs=input_image
+             #     examples=[os.path.join("examples", "sample_cxr.png")],
+             #     inputs=input_image, label="Example CXR"
             # )
+             gr.Markdown(
+                 "<p style='color:#9ca3af; font-size:0.9em;'>Model loading on startup takes ~1 min. Analysis takes ~15-40 sec.</p>"
+             )
+
+
+         # --- Right column (welcome / results) ---
         with gr.Column(scale=2):
-             output_label = gr.Label(label="Overall Quality Estimate", elem_id="quality-label")
-             # ***** MODIFIED LINE: height=350 removed *****
-             output_dataframe = gr.DataFrame(
-                 headers=["Criterion", "Sim (+)", "Sim (-)", "Difference", "Assessment (Comp)", "Assessment (Simp)"],
-                 label="Detailed Quality Assessment",
-                 wrap=True
-                 # height=350  # <- removed
+
+             # --- Welcome block (visible initially) ---
+             with gr.Column(visible=True, elem_id="welcome-section") as welcome_block:
+                 gr.Markdown(
+                     """
+                     ### Welcome!
+                     Upload a chest X-ray image (PNG, JPG, etc.) on the left panel and click "Analyze Image".
+
+                     The system will evaluate its technical quality based on 7 standard criteria using the ELIXR model family.
+                     The results will appear here once the analysis is complete.
+                     """, elem_id="welcome-text"
                 )
-             output_json = gr.JSON(label="Raw Results (for debugging)")
-
-     # Wire the button to the processing function
-     submit_button.click(
-         fn=assess_quality,
-         inputs=input_image,
-         outputs=[output_dataframe, output_label, output_json]
+                 # You could add an icon or image here if you want
+                 # gr.Image("path/to/welcome_icon.png", interactive=False, show_label=False, show_download_button=False)
+
+
+             # --- Results block (hidden initially) ---
+             with gr.Column(visible=False, elem_id="results-section") as results_block:
+                 gr.Markdown("### 2. Quality Assessment Results", elem_id="results-title")
+                 with gr.Row():  # row for the output image and the summary
+                     with gr.Column(scale=1):
+                         output_image = gr.Image(type="pil", label="Analyzed Image", interactive=False)
+                     with gr.Column(scale=1):
+                         gr.Markdown("#### Summary", elem_id="summary-title")
+                         output_label = gr.Label(value="N/A", label="Overall Quality Estimate", elem_id="quality-label")
+                         # we could add more summary text here if we wanted
+
+                 gr.Markdown("#### Detailed Criteria Evaluation", elem_id="detailed-title")
+                 output_dataframe = gr.DataFrame(
+                     headers=["Criterion", "Sim (+)", "Sim (-)", "Difference", "Assessment (Comp)", "Assessment (Simp)"],
+                     label=None,  # drop the redundant label
+                     wrap=True,
+                     # height is now handled automatically or via CSS
+                     # row_count=(7, "dynamic")  # show 7 rows, allow scrolling for more
+                     max_rows=10,  # limit the visible rows, with scrolling
+                     overflow_row_behaviour="show_ends",  # show start/end while scrolling
+                     interactive=False,  # not editable
+                     elem_id="results-dataframe"
+                 )
+                 with gr.Accordion("Raw JSON Output (for debugging)", open=False):
+                     output_json = gr.JSON(label=None)
+
+                 gr.Markdown(
+                     f"""
+                     #### Technical Notes
+                     * **Criterion:** Quality aspect evaluated.
+                     * **Sim (+/-):** Cosine similarity with positive/negative prompt.
+                     * **Difference:** Sim (+) - Sim (-).
+                     * **Assessment (Comp):** PASS if Difference > {SIMILARITY_DIFFERENCE_THRESHOLD}. (Main Result)
+                     * **Assessment (Simp):** PASS if Sim (+) > {POSITIVE_SIMILARITY_THRESHOLD}.
+                     """, elem_id="notes-text"
+                 )
+
+     # --- Footer ---
+     gr.Markdown(
+         """
+         ----
+         <p style='text-align:center; color:#9ca3af; font-size:0.8em;'>
+         CXR Quality Assessment Tool | Model: google/cxr-foundation | Interface: Gradio
+         </p>
+         """, elem_id="app-footer"
+     )
+
+
+     # --- Event wiring ---
+     analyze_btn.click(
+         fn=assess_quality_and_update_ui,
+         inputs=[input_image],
+         outputs=[
+             welcome_block,     # -> updates the welcome visibility
+             results_block,     # -> updates the results visibility
+             output_image,      # -> shows the analyzed image
+             output_label,      # -> updates the summary label
+             output_dataframe,  # -> updates the table
+             output_json        # -> updates the JSON
+         ]
     )

+     reset_btn.click(
+         fn=reset_ui,
+         inputs=None,  # no inputs needed
+         outputs=[
+             welcome_block,
+             results_block,
+             input_image,   # -> clears the input image
+             output_image,
+             output_label,
+             output_dataframe,
+             output_json
+         ]
+     )
+
+
  # --- Launch the Gradio app ---
- # When deployed on Spaces, Gradio takes care of this automatically.
- # To run locally: demo.launch()
- # For Spaces it is better to let HF handle the launch.
- # demo.launch(share=True)  # for a temporary public link when running locally
  if __name__ == "__main__":
-     # share=True only if you want a temporary public link from a local run
-     # server_name="0.0.0.0" to allow local-network connections
+     # server_name="0.0.0.0" for local-network accessibility
     # server_port=7860 is the standard HF Spaces port
-     demo.launch(server_name="0.0.0.0", server_port=7860)
+     # auth=("user", "password")  # if you want basic auth locally
+     demo.launch(server_name="0.0.0.0", server_port=7860)  #, share=True)  # drop share=True for a normal deployment
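
For reference, the overall label computed in `assess_quality_and_update_ui` is a plain pass-rate bucketing over the comparative assessments. A standalone sketch with the same cutoffs as the diff:

def overall_label(passed_count: int, total_count: int) -> str:
    """Buckets the comparative pass rate into a quality label."""
    if total_count == 0:
        return "Error"
    pass_rate = passed_count / total_count
    if pass_rate >= 0.85: overall_quality = "Excellent"
    elif pass_rate >= 0.70: overall_quality = "Good"
    elif pass_rate >= 0.50: overall_quality = "Fair"
    else: overall_quality = "Poor"
    return f"{overall_quality} ({passed_count}/{total_count} passed)"

print(overall_label(6, 7))  # 'Excellent (6/7 passed)' since 6/7 ≈ 0.857 >= 0.85
print(overall_label(3, 7))  # 'Poor (3/7 passed)' since 3/7 ≈ 0.43 < 0.50
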
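
The UI rework in this commit hinges on one Gradio pattern: the click handler returns `gr.update(visible=...)` values for the welcome and results containers, so one panel swaps for the other. A minimal self-contained sketch of that pattern, using generic components rather than the app's real ones:

import gradio as gr

with gr.Blocks() as demo:
    welcome = gr.Markdown("Welcome! Click the button to see results.", visible=True)
    results = gr.Markdown(visible=False)
    run_btn = gr.Button("Run")

    def show_results():
        # Return one update per output component, in order: hide the
        # welcome text, then reveal the results text with a new value.
        return gr.update(visible=False), gr.update(visible=True, value="Done: 7/7 criteria passed")

    run_btn.click(fn=show_results, inputs=None, outputs=[welcome, results])

if __name__ == "__main__":
    demo.launch()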