File size: 5,985 Bytes
effeed6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# Import required libraries
import gradio as gr
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
import numpy as np
import io
import matplotlib.pyplot as plt
import warnings
import os

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

# --- 1. Load Model and Processor at Startup ---

# Hugging Face model id (taken from the author's working example).
MODEL_NAME = "codewithdark/vit-chest-xray"

# Human-readable class names, indexed by the model's output logit position.
# Defined outside the try-block so the name always exists, even when loading
# fails below (predict_xray reads it).
# NOTE(review): hard-coded rather than read from model.config.id2label —
# confirm this order matches the checkpoint's label order.
LABEL_COLUMNS = ['Cardiomegaly', 'Edema', 'Consolidation', 'Pneumonia', 'No Finding']

print("Loading model and processor...")
try:
    # Load the processor and model from Hugging Face
    processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
    model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)

    # Move the model to the GPU when one is present for faster inference.
    if torch.cuda.is_available():
        print(f"GPU available: {torch.cuda.get_device_name(0)}. Moving model to GPU.")
        model.to("cuda")
    else:
        print("Warning: GPU not available, using CPU. Inference may be slow.")

    # If the checkpoint's output size disagrees with LABEL_COLUMNS, predictions
    # would be mapped to the wrong names; surface that once at startup instead
    # of silently mislabeling every request.
    _num_labels = getattr(model.config, "num_labels", len(LABEL_COLUMNS))
    if _num_labels != len(LABEL_COLUMNS):
        print(
            f"Warning: model reports {_num_labels} labels but LABEL_COLUMNS "
            f"has {len(LABEL_COLUMNS)}; predicted label names may be wrong."
        )

    print("Model and processor loaded successfully.")

except Exception as e:
    # Top-level boundary: keep the module importable and let the __main__
    # guard (and predict_xray) detect the failure through the None values.
    print(f"Fatal Error: Could not load model or processor. {e}")
    processor, model = None, None

# --- 2. Define Prediction and Visualization Functions ---

def predict_xray(image: Image.Image) -> dict:
    """
    Classify a chest X-ray image and return the single top prediction.

    Args:
        image: Input PIL image. Converted to RGB first if it is in any other
            mode (e.g. grayscale "L" or "RGBA").

    Returns:
        ``{"label": str, "confidence": float}`` for the highest-probability
        class, where ``confidence`` is its softmax probability in [0, 1];
        or ``{"error": str}`` when the model/processor failed to load.
    """
    if model is None or processor is None:
        return {"error": "Model not loaded. Please restart the application."}

    # Ensure the image is in RGB format, as required by the model
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Prepare the image for the model
    inputs = processor(images=image, return_tensors="pt")

    # Send the inputs to wherever the model actually lives, rather than
    # re-deriving the device from CUDA availability, so the two cannot drift.
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # One image per call -> row 0. A single max() yields the probability and
    # its index together, so both always refer to the same class.
    probabilities = torch.softmax(logits, dim=-1)[0]
    confidence, predicted_class_idx = probabilities.max(dim=-1)
    idx = predicted_class_idx.item()

    # Map index to label; fall back to the model's own label map instead of
    # raising IndexError if the checkpoint has more outputs than LABEL_COLUMNS.
    if idx < len(LABEL_COLUMNS):
        predicted_label = LABEL_COLUMNS[idx]
    else:
        id2label = getattr(model.config, "id2label", None) or {}
        predicted_label = id2label.get(idx, f"class_{idx}")

    return {"label": predicted_label, "confidence": confidence.item()}

def visualize_results(image: Image.Image, result: dict) -> Image.Image:
    """
    Render the prediction caption on top of the input image.

    Args:
        image: The original PIL image; it is not modified.
        result: Prediction dict with ``"label"`` (str) and ``"confidence"``
            (float, formatted as a percentage) keys, as produced by
            ``predict_xray``.

    Returns:
        A new PIL image (decoded from an in-memory PNG) showing ``image``
        with a "label: confidence%" caption near the top-left corner.
    """
    # Use a standalone Figure instead of pyplot: pyplot keeps a global
    # "current figure", which is unsafe when Gradio handles several requests
    # at once, and an exception before plt.close() would leak the figure.
    # A Figure created this way is not registered with pyplot, so it is
    # simply garbage-collected when this function returns.
    from matplotlib.figure import Figure

    img_array = np.array(image.convert("RGB"))

    fig = Figure(figsize=(6, 6))
    ax = fig.add_subplot()
    ax.imshow(img_array)
    ax.axis("off")

    # Caption text at data (pixel) coordinates (10, 30) from the top-left.
    text = f"{result['label']}: {result['confidence']:.2%}"
    ax.text(10, 30, text, color="white", fontsize=12,
            bbox=dict(facecolor="black", alpha=0.6))

    # Save the figure to a memory buffer and return it as a PIL Image.
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", bbox_inches="tight", pad_inches=0)
    buffer.seek(0)

    return Image.open(buffer)

# --- 3. Main Function for Gradio ---

def analyze_chest_xray(image: Image.Image) -> tuple:
    """
    Gradio click handler: classify ``image`` and build both UI outputs.

    Returns:
        A pair ``(annotated_image, message)``. ``annotated_image`` is None
        whenever no prediction was produced (nothing uploaded, or the model
        is unavailable); ``message`` then explains why.
    """
    # Guard clause: the button can be pressed before anything is uploaded.
    if image is None:
        return None, "Please upload an image to analyze."

    outcome = predict_xray(image)
    if "error" in outcome:
        return None, outcome["error"]

    label = outcome["label"]
    score = outcome["confidence"]
    summary = "\n".join((f"Prediction: {label}", f"Confidence: {score:.2%}"))

    return visualize_results(image, outcome), summary

# --- 4. Define and Launch the Gradio Interface ---

# Spotify-inspired CSS for a modern, dark theme.
# .title, .input-image and .output-image are attached to components below via
# elem_classes. .gr-button and .gr-textbox are attached the same way:
# NOTE(review): this Gradio version (implied by `sources=[...]` on gr.Image)
# does not appear to emit those legacy class names itself, so without explicit
# elem_classes those two rules would not match anything.
custom_css = """
.gradio-container { background: #121212; font-family: 'Circular', 'Helvetica Neue', sans-serif; color: #ffffff; }
.title { font-size: 2.5em; color: #1db954; text-align: center; margin-bottom: 10px; font-weight: bold; }
.gr-button { background: #1db954 !important; color: #000000 !important; border-radius: 500px !important; font-weight: 700 !important; text-transform: uppercase; }
.gr-button:hover { background: #1ed760 !important; transform: scale(1.05); }
.output-image, .input-image { border: 2px solid #1db954 !important; border-radius: 12px !important; box-shadow: 0 4px 12px rgba(29, 185, 84, 0.3); }
.gr-textbox { background: #282828 !important; color: #ffffff !important; border: 1px solid #1db954 !important; border-radius: 8px !important; }
"""

with gr.Blocks(css=custom_css, title="Chest X-Ray Detection") as demo:
    gr.Markdown("<h1 class='title'>Chest X-Ray Analysis System</h1>", elem_classes="title")
    gr.Markdown("Upload a chest X-ray image to classify potential abnormalities using a Vision Transformer (ViT).")

    # Input and annotated output side by side.
    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Chest X-Ray", sources=["upload"], elem_classes="input-image")
        output_image = gr.Image(label="Annotated Result", elem_classes="output-image")

    # Explicit elem_classes so the .gr-button / .gr-textbox CSS rules apply.
    predict_button = gr.Button("Analyze X-Ray", elem_classes="gr-button")
    output_text = gr.Textbox(label="Prediction Details", lines=2, elem_classes="gr-textbox")

    # analyze_chest_xray returns (annotated image, text), matching outputs order.
    predict_button.click(
        fn=analyze_chest_xray,
        inputs=[image_input],
        outputs=[output_image, output_text]
    )

if __name__ == "__main__":
    # Explicit None checks, matching predict_xray, instead of relying on the
    # truthiness of the model/processor objects.
    if model is not None and processor is not None:
        print("Launching Gradio interface...")
        # share=True requests a public Gradio share link, which makes uploaded
        # X-ray images reachable from outside this machine. Keep that as the
        # default for backward compatibility, but allow opting out with
        # GRADIO_SHARE=0 / false / no.
        share = os.environ.get("GRADIO_SHARE", "1").strip().lower() not in {"0", "false", "no"}
        demo.launch(share=share, debug=True)
    else:
        print("Gradio interface could not be launched because the model failed to load.")
        # Report the failure to the shell / process supervisor instead of
        # exiting with status 0.
        raise SystemExit(1)