# ThoraxAI / app.py
# Author: CristinaLA
# Commit message: "Update app.py"
# Commit: 368855b (verified)
import gradio as gr
from huggingface_hub import snapshot_download
import tensorflow as tf
import numpy as np
from PIL import Image
import io
import png
# --- Part 1: Load the embedding models ---
# Local directory that receives the downloaded model files.
MODEL_DIR = "./cxr_foundation_models"

# Download only the two sub-model directories this app loads below
# from the "google/cxr-foundation" Hugging Face repository.
snapshot_download(
    repo_id="google/cxr-foundation",
    local_dir=MODEL_DIR,
    allow_patterns=["elixr-c-v2-pooled/*", "pax-elixr-b-text/*"],
)

# Load each downloaded directory as a TensorFlow SavedModel; their
# 'serving_default' signatures are called in procesar_radiografia.
elixrc_model = tf.saved_model.load(MODEL_DIR + "/elixr-c-v2-pooled")
qformer_model = tf.saved_model.load(MODEL_DIR + "/pax-elixr-b-text")
# Helper: encode an image array as a PNG-backed tf.train.Example.
def png_to_tfexample(image_array: np.ndarray) -> tf.train.Example:
    """Encode a 2-D image array as PNG bytes inside a ``tf.train.Example``.

    Pixel values are first shifted so the minimum becomes 0. A ``uint8``
    input is then written as an 8-bit PNG without rescaling; any other dtype
    is rescaled so its maximum maps to 65535 (when the maximum is positive)
    and written as a 16-bit PNG.

    Args:
        image_array: Grayscale image; must be 2-D.

    Returns:
        An Example whose ``'image/encoded'`` feature holds the PNG bytes and
        whose ``'image/format'`` feature holds ``b'png'``.

    Raises:
        ValueError: if the array is not 2-D.
    """
    shifted = image_array.astype(np.float32)
    shifted -= shifted.min()

    if image_array.dtype != np.uint8:
        # Stretch non-8-bit data to the full 16-bit range.
        peak = shifted.max()
        if peak > 0:
            shifted *= 65535 / peak
        bitdepth = 16
        pixels = shifted.astype(np.uint16)
    else:
        # 8-bit data is only shifted, never rescaled.
        bitdepth = 8
        pixels = shifted.astype(np.uint8)

    if pixels.ndim != 2:
        raise ValueError(f'Array must be 2-D. Actual dimensions: {pixels.ndim}')

    height, width = pixels.shape
    writer = png.Writer(
        width=width,
        height=height,
        greyscale=True,
        bitdepth=bitdepth,
    )
    buffer = io.BytesIO()
    writer.write(buffer, pixels.tolist())

    example = tf.train.Example()
    feature_map = example.features.feature
    feature_map['image/encoded'].bytes_list.value.append(buffer.getvalue())
    feature_map['image/format'].bytes_list.value.append(b'png')
    return example
# --- Part 2: Application logic with a demo classifier ---
def _calcular_embeddings(img_array: np.ndarray):
    """Run the two-stage embedding pipeline on a 2-D grayscale array.

    Returns the ``'all_contrastive_img_emb'`` output of the second
    (``qformer_model``) signature.
    """
    # Stage 1: the first signature takes a batch of serialized Examples and
    # returns the 'feature_maps_0' tensor.
    serialized = png_to_tfexample(img_array).SerializeToString()
    elixrc_infer = elixrc_model.signatures['serving_default']
    elixrc_output = elixrc_infer(input_example=tf.constant([serialized]))
    elixrc_embedding = elixrc_output['feature_maps_0'].numpy()

    # Stage 2: feed the feature maps to the second model. 'ids' and
    # 'paddings' are all-zero placeholders because this flow passes no text.
    qformer_input = {
        'image_feature': elixrc_embedding.tolist(),
        'ids': np.zeros((1, 1, 128), dtype=np.int32).tolist(),
        'paddings': np.zeros((1, 1, 128), dtype=np.float32).tolist(),
    }
    qformer_infer = qformer_model.signatures['serving_default']
    qformer_output = qformer_infer(**qformer_input)
    return qformer_output['all_contrastive_img_emb']


def procesar_radiografia(imagen: Image.Image):
    """Compute embeddings for an uploaded chest X-ray and return demo labels.

    Args:
        imagen: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        Dict mapping label name to score, shown by the ``"label"`` output.
        NOTE: the scores are fixed placeholders and do not depend on the image.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message
            instead of an ``AttributeError`` traceback.
    """
    if imagen is None:
        raise gr.Error("Por favor, sube una radiografía antes de enviar.")

    # Step 1: generate the embedding from the image converted to PIL mode 'L'
    # (8-bit grayscale).
    img_array = np.array(imagen.convert('L'))
    # Not consumed yet: a real classifier would take these embeddings as input.
    elixrb_embeddings = _calcular_embeddings(img_array)  # noqa: F841

    # Step 2: simulated classification. In a real project the trained
    # classifier would run here; for now a fixed result is returned.
    etiquetas = {
        "Normal": 0.8,
        "Neumonía": 0.15,
        "Cardiomegalia": 0.05,
    }
    # Return the result in the label-dict format expected by the output.
    return etiquetas
# Build the Gradio interface: one PIL image input, one label output.
interfaz = gr.Interface(
    fn=procesar_radiografia,
    inputs=gr.Image(type="pil"),
    outputs="label",
    title="Asistente de Análisis de Radiografías de Tórax (Demo)",
    description="Sube una radiografía y el modelo de IA proporcionará una clasificación preliminar. **Nota: Esto es una herramienta demostrativa y no un diagnóstico médico.**",
)

# Launch the web server only when executed as a script (e.g. `python app.py`),
# so importing this module does not start a server.
if __name__ == "__main__":
    interfaz.launch()