# Xray_ViT / app.py
# Import required libraries
import gradio as gr
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
import numpy as np
import io
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, safe for headless servers
import matplotlib.pyplot as plt
import warnings

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")
# --- 1. Load Model and Processor at Startup ---
print("Loading model and processor...")
try:
    # Use the model name from your working example
    MODEL_NAME = "codewithdark/vit-chest-xray"

    # Load the processor and model from Hugging Face
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)

    # Check for a GPU and move the model there for faster inference
    if torch.cuda.is_available():
        print(f"GPU available: {torch.cuda.get_device_name(0)}. Moving model to GPU.")
        model.to("cuda")
    else:
        print("Warning: GPU not available, using CPU. Inference may be slow.")

    # Define the labels for classification (order must match the model's output indices)
    LABEL_COLUMNS = ['Cardiomegaly', 'Edema', 'Consolidation', 'Pneumonia', 'No Finding']
    print("Model and processor loaded successfully.")
except Exception as e:
    # Handle errors during model loading
    print(f"Fatal Error: Could not load model or processor. {e}")
    processor, model = None, None
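# Optional startup sanity check, sketched under the assumption that this
# checkpoint populates the standard Hugging Face id2label config field:
# compare the hard-coded LABEL_COLUMNS against the model's own label order.
if model is not None:
    config_labels = [model.config.id2label[i] for i in range(model.config.num_labels)]
    if config_labels != LABEL_COLUMNS:
        print(f"Note: checkpoint labels {config_labels} differ from LABEL_COLUMNS; "
              "verify the label order before trusting predictions.")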
# --- 2. Define Prediction and Visualization Functions ---
def predict_xray(image: Image.Image) -> dict:
    """
    Takes a PIL Image, preprocesses it, and returns the predicted label and confidence.
    """
    if model is None or processor is None:
        return {"error": "Model not loaded. Please restart the application."}

    # Ensure the image is in RGB format, as required by the model
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Prepare the image for the model
    inputs = processor(images=image, return_tensors="pt")

    # Move inputs to the GPU if available
    if torch.cuda.is_available():
        inputs = {k: v.to("cuda") for k, v in inputs.items()}

    # Perform inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Get probabilities and find the top prediction
    logits = outputs.logits
    probabilities = torch.softmax(logits, dim=-1)
    confidence = torch.max(probabilities).item()
    predicted_class_idx = torch.argmax(probabilities, dim=-1).item()

    # Map index to label
    predicted_label = LABEL_COLUMNS[predicted_class_idx]
    return {"label": predicted_label, "confidence": confidence}
def visualize_results(image: Image.Image, result: dict) -> Image.Image:
    """
    Overlays the prediction results on the input image for visualization.
    """
    image_viz = image.copy().convert("RGB")
    img_array = np.array(image_viz)

    plt.figure(figsize=(6, 6))
    plt.imshow(img_array)
    plt.axis("off")

    # Create the text to display on the image
    text = f"{result['label']}: {result['confidence']:.2%}"
    plt.text(10, 30, text, color="white", fontsize=12,
             bbox=dict(facecolor="black", alpha=0.6))

    # Save the plot to a memory buffer and return it as a PIL Image
    buffer = io.BytesIO()
    plt.savefig(buffer, format="png", bbox_inches="tight", pad_inches=0)
    plt.close()
    buffer.seek(0)
    return Image.open(buffer)
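# Alternative sketch: the same overlay can be drawn with PIL's ImageDraw,
# avoiding a full matplotlib figure per request. Assumes Pillow >= 8.0 for
# textbbox; uses the default bitmap font so no font files are required.
from PIL import ImageDraw

def visualize_results_pil(image: Image.Image, result: dict) -> Image.Image:
    """Draw the label/confidence text directly onto a copy of the image."""
    annotated = image.copy().convert("RGB")
    draw = ImageDraw.Draw(annotated)
    text = f"{result['label']}: {result['confidence']:.2%}"
    left, top, right, bottom = draw.textbbox((10, 10), text)
    draw.rectangle((left - 5, top - 5, right + 5, bottom + 5), fill=(0, 0, 0))
    draw.text((10, 10), text, fill=(255, 255, 255))
    return annotated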
# --- 3. Main Function for Gradio ---
def analyze_chest_xray(image: Image.Image) -> tuple:
    """
    The main function that connects the UI to the backend prediction logic.
    """
    if image is None:
        return None, "Please upload an image to analyze."

    # Get the prediction
    prediction_result = predict_xray(image)
    if "error" in prediction_result:
        return None, prediction_result["error"]

    # Create the text output
    text_result = (f"Prediction: {prediction_result['label']}\n"
                   f"Confidence: {prediction_result['confidence']:.2%}")

    # Create the annotated image for display
    annotated_image = visualize_results(image, prediction_result)
    return annotated_image, text_result
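# Quick smoke test, a sketch only (the sample path is hypothetical): the whole
# pipeline can be exercised without the UI when debugging on a headless box.
#
#   img = Image.open("samples/chest_xray.png")
#   annotated, details = analyze_chest_xray(img)
#   print(details)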
# --- 4. Define and Launch the Gradio Interface ---
# Spotify-inspired CSS for a modern, dark theme
custom_css = """
.gradio-container { background: #121212; font-family: 'Circular', 'Helvetica Neue', sans-serif; color: #ffffff; }
.title { font-size: 2.5em; color: #1db954; text-align: center; margin-bottom: 10px; font-weight: bold; }
.gr-button { background: #1db954 !important; color: #000000 !important; border-radius: 500px !important; font-weight: 700 !important; text-transform: uppercase; }
.gr-button:hover { background: #1ed760 !important; transform: scale(1.05); }
.output-image, .input-image { border: 2px solid #1db954 !important; border-radius: 12px !important; box-shadow: 0 4px 12px rgba(29, 185, 84, 0.3); }
.gr-textbox { background: #282828 !important; color: #ffffff !important; border: 1px solid #1db954 !important; border-radius: 8px !important; }
"""
with gr.Blocks(css=custom_css, title="Chest X-Ray Detection") as demo:
    gr.Markdown("<h1 class='title'>Chest X-Ray Analysis System</h1>", elem_classes="title")
    gr.Markdown("Upload a chest X-ray image to classify potential abnormalities using a Vision Transformer (ViT).")

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Chest X-Ray", sources=["upload"], elem_classes="input-image")
        output_image = gr.Image(label="Annotated Result", elem_classes="output-image")

    predict_button = gr.Button("Analyze X-Ray")
    output_text = gr.Textbox(label="Prediction Details", lines=2)

    predict_button.click(
        fn=analyze_chest_xray,
        inputs=[image_input],
        outputs=[output_image, output_text]
    )
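    # Optional extension, sketched only: gr.Examples can pre-populate the
    # uploader with sample images. The paths are hypothetical, and the call
    # must sit inside this gr.Blocks context:
    #
    #   gr.Examples(examples=["examples/normal.png", "examples/edema.png"],
    #               inputs=image_input)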
if __name__ == "__main__":
    if model and processor:
        print("Launching Gradio interface...")
        demo.launch(share=True, debug=True)
    else:
        print("Gradio interface could not be launched because the model failed to load.")