# Import required libraries
import gradio as gr
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
import numpy as np
import io
import matplotlib
matplotlib.use("Agg")  # Non-interactive backend so plotting works inside Gradio's worker threads
import matplotlib.pyplot as plt
import warnings
import os

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

# --- 1. Load Model and Processor at Startup ---
print("Loading model and processor...")
try:
    # Hugging Face Hub ID of the chest X-ray classification model
    MODEL_NAME = "codewithdark/vit-chest-xray"

    # Load the processor and model from Hugging Face
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)

    # Check for a GPU and move the model to it for faster inference
    if torch.cuda.is_available():
        print(f"GPU available: {torch.cuda.get_device_name(0)}. Moving model to GPU.")
        model.to("cuda")
    else:
        print("Warning: GPU not available, using CPU. Inference may be slow.")

    # Define the labels for classification
    LABEL_COLUMNS = ['Cardiomegaly', 'Edema', 'Consolidation', 'Pneumonia', 'No Finding']

    print("Model and processor loaded successfully.")
except Exception as e:
    # Handle errors during model loading
    print(f"Fatal Error: Could not load model or processor. {e}")
    processor, model = None, None


# --- 2. Define Prediction and Visualization Functions ---
def predict_xray(image: Image.Image) -> dict:
    """
    Takes a PIL Image, preprocesses it, and returns the predicted label and confidence.
    """
    if model is None or processor is None:
        return {"error": "Model not loaded. Please restart the application."}

    # Ensure the image is in RGB format, as required by the model
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Prepare the image for the model
    inputs = processor(images=image, return_tensors="pt")

    # Move inputs to GPU if available
    if torch.cuda.is_available():
        inputs = {k: v.to("cuda") for k, v in inputs.items()}

    # Perform inference
    with torch.no_grad():
        outputs = model(**inputs)

    # Get probabilities and find the top prediction
    logits = outputs.logits
    probabilities = torch.softmax(logits, dim=-1)
    confidence = torch.max(probabilities).item()
    predicted_class_idx = torch.argmax(probabilities, dim=-1).item()

    # Map index to label
    predicted_label = LABEL_COLUMNS[predicted_class_idx]

    return {"label": predicted_label, "confidence": confidence}


def visualize_results(image: Image.Image, result: dict) -> Image.Image:
    """
    Overlays the prediction results on the input image for visualization.
    """
    image_viz = image.copy().convert("RGB")
    img_array = np.array(image_viz)

    plt.figure(figsize=(6, 6))
    plt.imshow(img_array)
    plt.axis("off")

    # Create the text to display on the image
    text = f"{result['label']}: {result['confidence']:.2%}"
    plt.text(10, 30, text, color="white", fontsize=12,
             bbox=dict(facecolor="black", alpha=0.6))

    # Save the plot to a memory buffer and return it as a PIL Image
    buffer = io.BytesIO()
    plt.savefig(buffer, format="png", bbox_inches="tight", pad_inches=0)
    plt.close()
    buffer.seek(0)
    return Image.open(buffer)


# --- 3. Main Function for Gradio ---
def analyze_chest_xray(image: Image.Image) -> tuple:
    """
    The main function that connects the UI to the backend prediction logic.
    """
    if image is None:
        return None, "Please upload an image to analyze."
    # Get the prediction
    prediction_result = predict_xray(image)
    if "error" in prediction_result:
        return None, prediction_result["error"]

    # Create the text output
    text_result = f"Prediction: {prediction_result['label']}\nConfidence: {prediction_result['confidence']:.2%}"

    # Create the annotated image for display
    annotated_image = visualize_results(image, prediction_result)

    return annotated_image, text_result


# --- 4. Define and Launch the Gradio Interface ---
# Spotify-inspired CSS for a modern, dark theme
custom_css = """
.gradio-container {
    background: #121212;
    font-family: 'Circular', 'Helvetica Neue', sans-serif;
    color: #ffffff;
}
.title {
    font-size: 2.5em;
    color: #1db954;
    text-align: center;
    margin-bottom: 10px;
    font-weight: bold;
}
.gr-button {
    background: #1db954 !important;
    color: #000000 !important;
    border-radius: 500px !important;
    font-weight: 700 !important;
    text-transform: uppercase;
}
.gr-button:hover {
    background: #1ed760 !important;
    transform: scale(1.05);
}
.output-image, .input-image {
    border: 2px solid #1db954 !important;
    border-radius: 12px !important;
    box-shadow: 0 4px 12px rgba(29, 185, 84, 0.3);
}
.gr-textbox {
    background: #282828 !important;
    color: #ffffff !important;
    border: 1px solid #1db954 !important;
    border-radius: 8px !important;
}
"""

with gr.Blocks(css=custom_css, title="Chest X-Ray Detection") as demo:
    gr.Markdown("