Spaces:
Running
Running
# Import required libraries | |
import gradio as gr | |
from PIL import Image | |
import torch | |
from transformers import AutoImageProcessor, AutoModelForImageClassification | |
import numpy as np | |
import io | |
import matplotlib.pyplot as plt | |
import warnings | |
import os | |
# Suppress warnings for cleaner output.
# NOTE(review): this silences *every* warning process-wide (including
# deprecation notices from the libraries imported above); consider narrowing it.
warnings.filterwarnings("ignore")

# --- 1. Load Model and Processor at Startup ---
# Hugging Face Hub id of the fine-tuned ViT classifier.
MODEL_NAME = "codewithdark/vit-chest-xray"
# Human-readable class names, indexed by the model's output logit position.
# Defined outside the try block: these constants cannot raise, and keeping them
# out guarantees the names exist even when loading fails.
# NOTE(review): the order must match the label order the model was trained
# with — confirm against the model's config.id2label / model card.
LABEL_COLUMNS = ['Cardiomegaly', 'Edema', 'Consolidation', 'Pneumonia', 'No Finding']

print("Loading model and processor...")
try:
    # Load the processor and model from Hugging Face.
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
    # Move the model to the GPU when one is available for faster inference.
    if torch.cuda.is_available():
        print(f"GPU available: {torch.cuda.get_device_name(0)}. Moving model to GPU.")
        model.to("cuda")
    else:
        print("Warning: GPU not available, using CPU. Inference may be slow.")
    print("Model and processor loaded successfully.")
except Exception as e:
    # Broad catch on purpose: any load failure leaves the app in the
    # "model not loaded" state (both names None) that callers check for.
    print(f"Fatal Error: Could not load model or processor. {e}")
    processor, model = None, None
else:
    # Surface a label-list / checkpoint mismatch at startup instead of as an
    # IndexError or a silently mislabeled prediction at request time.
    if model.config.num_labels != len(LABEL_COLUMNS):
        print(
            f"Warning: model has {model.config.num_labels} output labels but "
            f"LABEL_COLUMNS lists {len(LABEL_COLUMNS)}; predictions may be mislabeled."
        )
# --- 2. Define Prediction and Visualization Functions --- | |
def predict_xray(image: Image.Image) -> dict:
    """
    Classify a chest X-ray and return the top-1 label with its probability.

    Args:
        image: Input PIL image; converted to RGB when it is in any other mode.

    Returns:
        ``{"label": str, "confidence": float}`` on success, where confidence is
        the softmax probability of the winning class; or ``{"error": str}``
        when the model/processor failed to load at startup.
    """
    if model is None or processor is None:
        return {"error": "Model not loaded. Please restart the application."}

    # The model expects three-channel input; uploads may be grayscale or RGBA.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    inputs = processor(images=image, return_tensors="pt")
    # Send inputs to wherever the model's weights actually live instead of
    # re-deriving the device independently, so the two can never disagree.
    device = model.device
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # NOTE(review): softmax treats the classes as mutually exclusive; if the
    # checkpoint was trained multi-label (sigmoid), this should change — confirm.
    probabilities = torch.softmax(logits, dim=-1)
    # A single image yields a batch of one: take that row's best class.
    top_prob, top_idx = probabilities[0].max(dim=-1)
    confidence = top_prob.item()
    predicted_class_idx = top_idx.item()

    # Prefer the curated names; fall back to the checkpoint's own labels rather
    # than raising IndexError if the model has more classes than the list.
    if predicted_class_idx < len(LABEL_COLUMNS):
        predicted_label = LABEL_COLUMNS[predicted_class_idx]
    else:
        predicted_label = model.config.id2label.get(
            predicted_class_idx, f"LABEL_{predicted_class_idx}"
        )
    return {"label": predicted_label, "confidence": confidence}
def visualize_results(image: Image.Image, result: dict) -> Image.Image:
    """
    Render the image with its predicted label and confidence drawn on top.

    Args:
        image: Source PIL image; it is not modified (an RGB copy is rendered).
        result: Prediction dict with ``"label"`` (str) and ``"confidence"``
            (a fraction, formatted as a percentage).

    Returns:
        A new PIL image decoded from the PNG rendering of the figure.
    """
    # Local import keeps this fix self-contained; matplotlib is already a
    # dependency of this file (matplotlib.pyplot is imported at the top).
    from matplotlib.figure import Figure

    img_array = np.asarray(image.convert("RGB"))
    text = f"{result['label']}: {result['confidence']:.2%}"

    # Build a standalone Figure instead of using pyplot's global "current
    # figure": handlers may run concurrently, and with shared pyplot state one
    # request could draw on, or plt.close() away, another request's figure.
    # A bare Figure is not registered with pyplot, so nothing needs closing.
    fig = Figure(figsize=(6, 6))
    ax = fig.subplots()
    ax.imshow(img_array)
    ax.axis("off")
    # Label box in data (pixel) coordinates near the top-left corner.
    ax.text(10, 30, text, color="white", fontsize=12,
            bbox=dict(facecolor="black", alpha=0.6))

    # Render to an in-memory PNG and hand it back as a PIL image.
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", bbox_inches="tight", pad_inches=0)
    buffer.seek(0)
    return Image.open(buffer)
# --- 3. Main Function for Gradio --- | |
def analyze_chest_xray(image: Image.Image) -> tuple:
    """
    Gradio click handler: classify the uploaded X-ray and build both outputs.

    Args:
        image: The uploaded PIL image, or None when nothing was uploaded.

    Returns:
        ``(annotated_image, message)``. ``annotated_image`` is None whenever
        ``message`` explains why no analysis was produced.
    """
    # Guard: the button can be clicked before anything is uploaded.
    if image is None:
        return None, "Please upload an image to analyze."

    result = predict_xray(image)
    # predict_xray signals a missing model through an "error" key, not an exception.
    if "error" in result:
        return None, result["error"]

    label, confidence = result["label"], result["confidence"]
    summary = f"Prediction: {label}\nConfidence: {confidence:.2%}"
    return visualize_results(image, result), summary
# --- 4. Define and Launch the Gradio Interface --- | |
# Spotify-inspired CSS for a modern, dark theme.
# Passed to gr.Blocks(css=...). Besides Gradio's own container class, the
# selectors target the custom classes assigned via elem_classes in the UI
# definition ("title", "input-image", "output-image").
# NOTE(review): `.gr-button` / `.gr-textbox` look like legacy Gradio class
# names and may match nothing in newer Gradio versions — verify in the browser.
# No @font-face is declared here, so 'Circular' only renders if the client has
# it installed; otherwise the listed fallbacks apply.
custom_css = """
.gradio-container { background: #121212; font-family: 'Circular', 'Helvetica Neue', sans-serif; color: #ffffff; }
.title { font-size: 2.5em; color: #1db954; text-align: center; margin-bottom: 10px; font-weight: bold; }
.gr-button { background: #1db954 !important; color: #000000 !important; border-radius: 500px !important; font-weight: 700 !important; text-transform: uppercase; }
.gr-button:hover { background: #1ed760 !important; transform: scale(1.05); }
.output-image, .input-image { border: 2px solid #1db954 !important; border-radius: 12px !important; box-shadow: 0 4px 12px rgba(29, 185, 84, 0.3); }
.gr-textbox { background: #282828 !important; color: #ffffff !important; border: 1px solid #1db954 !important; border-radius: 8px !important; }
"""
# Build the UI declaratively; the order in which components are created
# inside the context managers determines the on-page layout.
with gr.Blocks(css=custom_css, title="Chest X-Ray Detection") as demo:
    # NOTE(review): the "title" class is applied both to the Markdown wrapper
    # (elem_classes) and to the <h1> inside it; with `.title { font-size: 2.5em }`
    # the em sizes may compound — check the rendered heading size.
    gr.Markdown("<h1 class='title'>Chest X-Ray Analysis System</h1>", elem_classes="title")
    gr.Markdown("Upload a chest X-ray image to classify potential abnormalities using a Vision Transformer (ViT).")
    # Input and annotated output side by side.
    with gr.Row():
        # type="pil" makes the handler receive a PIL image; the handler also
        # accepts None (nothing uploaded).
        image_input = gr.Image(type="pil", label="Upload Chest X-Ray", sources=["upload"], elem_classes="input-image")
        output_image = gr.Image(label="Annotated Result", elem_classes="output-image")
    predict_button = gr.Button("Analyze X-Ray")
    output_text = gr.Textbox(label="Prediction Details", lines=2)
    # analyze_chest_xray returns (annotated_image, text) in the same order as
    # the outputs list below.
    predict_button.click(
        fn=analyze_chest_xray,
        inputs=[image_input],
        outputs=[output_image, output_text]
    )
if __name__ == "__main__":
    # Explicit None checks, matching predict_xray's guard; plain truthiness
    # would misfire on any loaded object that defines __len__ or __bool__.
    if model is not None and processor is not None:
        print("Launching Gradio interface...")
        # NOTE(review): share=True requests a public share link per Gradio's
        # launch() docs — reconsider for a medical-imaging app.
        demo.launch(share=True, debug=True)
    else:
        print("Gradio interface could not be launched because the model failed to load.")
        # Exit non-zero so whatever runs this script sees the fatal failure
        # instead of a clean exit.
        raise SystemExit(1)