File size: 9,407 Bytes
ca46f55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image as keras_image
from tensorflow.keras import backend as K
import matplotlib.pyplot as plt
from PIL import Image
import io
import cv2

# --- Load the trained Keras classifier and its class-name list ---
# labels.txt is read line by line; the stripped text of line i is used as the
# display name of model output i further down in this file.
model = load_model("checkpoints/keras_model.h5")
with open("labels.txt", "r") as labels_file:
    class_labels = list(map(str.strip, labels_file))

# --- Preprocess input ---
def preprocess_input(img):
    """Turn a PIL image into a batched, [0, 1]-scaled array for the model.

    Args:
        img: PIL image in any mode (RGB, RGBA, grayscale, palette, ...).

    Returns:
        Array of shape (1, 224, 224, 3) holding the pixel values divided by
        255.
    """
    # Force three channels first: an RGBA or grayscale image would otherwise
    # become a 4- or 1-channel array, unlike the 3 channels an RGB image gives.
    rgb = img.convert("RGB").resize((224, 224))
    arr = keras_image.img_to_array(rgb)
    arr = arr / 255.0
    # Add the leading batch axis expected by model.predict.
    return np.expand_dims(arr, axis=0)

# --- Enhanced Grad-CAM implementation for Keras ---
def get_gradcam_heatmap(img_array, model, class_index, last_conv_layer_name="conv5_block3_out"):
    """Compute a Grad-CAM heatmap for one class of one image.

    Args:
        img_array: Batched model input; only the first item is explained.
        model: Keras model to explain.
        class_index: Index of the output unit whose evidence is visualised.
        last_conv_layer_name: Preferred layer to take activations from. When
            the model has no layer of that name, the last layer whose name
            contains "conv" is used instead.

    Returns:
        2-D float NumPy array scaled to [0, 1], or None when no candidate
        layer exists or no gradient reaches the chosen layer.
    """
    try:
        target_layer = model.get_layer(last_conv_layer_name)
    except ValueError:
        # get_layer raises ValueError for an unknown name. Search from the
        # output side so the deepest matching layer is chosen: Grad-CAM is
        # defined on the last convolutional feature map, not the first.
        target_layer = next(
            (
                layer
                for layer in reversed(model.layers)
                if "conv" in layer.name.lower()
            ),
            None,
        )
        if target_layer is None:
            return None

    # Auxiliary model exposing both the chosen activations and the predictions.
    grad_model = tf.keras.models.Model(
        model.inputs, [target_layer.output, model.output]
    )

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        loss = predictions[:, class_index]

    grads = tape.gradient(loss, conv_outputs)
    if grads is None:
        # The chosen layer does not influence the selected output.
        return None

    # Drop the batch axis, then average each channel's gradient over the two
    # spatial axes only, giving one importance weight per channel.
    pooled_grads = tf.reduce_mean(grads[0], axis=(0, 1))
    heatmap = tf.reduce_sum(conv_outputs[0] * pooled_grads, axis=-1)

    # Stay in TensorFlow ops until the final conversion: a NumPy ufunc here
    # would already yield an ndarray, which has no .numpy() method.
    heatmap = tf.maximum(heatmap, 0)
    heatmap = heatmap / (tf.reduce_max(heatmap) + K.epsilon())
    return heatmap.numpy()

# --- Enhanced Overlay heatmap on image ---
def overlay_gradcam(original_img, heatmap):
    """Blend a Grad-CAM heatmap over an image.

    Args:
        original_img: PIL image to draw on.
        heatmap: 2-D array of activations, or None.

    Returns:
        PIL RGB image made of 60 % original pixels and 40 % JET-coloured
        heatmap, or ``original_img`` unchanged when ``heatmap`` is None.
    """
    if heatmap is None:
        return original_img

    # cv2.resize takes (width, height), which is what PIL's .size provides.
    # Casting to float32 also lets integer heatmaps through the division below.
    heatmap = cv2.resize(np.asarray(heatmap, dtype=np.float32), original_img.size)

    # Clip negatives and rescale to [0, 1]; an all-zero map stays all zero.
    heatmap = np.maximum(heatmap, 0)
    peak = np.max(heatmap)
    if peak != 0:
        heatmap = heatmap / peak
    heatmap = np.uint8(255 * heatmap)

    # Apply JET colormap for better medical visualization. OpenCV emits BGR,
    # so convert to RGB before mixing with the PIL (RGB) pixels; otherwise hot
    # regions would show up blue and cold regions red.
    heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)

    original_array = np.array(original_img.convert("RGB"))
    superimposed_img = cv2.addWeighted(original_array, 0.6, heatmap_color, 0.4, 0)

    return Image.fromarray(superimposed_img)

# --- Enhanced Prediction Function ---
def classify_and_explain(img):
    """Classify an image and produce its Grad-CAM overlay.

    Args:
        img: PIL image, or None when nothing has been uploaded.

    Returns:
        ``(overlay, confidences)``. ``overlay`` is a 224x224 PIL image (the
        plain resized input if Grad-CAM fails) and ``confidences`` maps each
        class label to its score as a float. Returns ``(None, {})`` when
        ``img`` is None.
    """
    if img is None:
        # Same two-element shape as the normal return, so callers that unpack
        # two values keep working on this path too.
        return None, {}

    img_array = preprocess_input(img)
    predictions = model.predict(img_array, verbose=0)[0]
    pred_idx = int(np.argmax(predictions))
    # zip stops at the shorter sequence, so a labels file with more entries
    # than the model has outputs cannot cause an IndexError.
    confidence_dict = {
        label: float(score) for label, score in zip(class_labels, predictions)
    }

    resized = img.resize((224, 224))
    try:
        heatmap = get_gradcam_heatmap(img_array, model, pred_idx)
        gradcam_img = overlay_gradcam(resized, heatmap)
    except Exception as e:
        # Best effort: the explanation is optional, the classification is not.
        print(f"Grad-CAM error: {e}")
        gradcam_img = resized  # fallback image

    return gradcam_img, confidence_dict

# --- Custom CSS for Dark Mode Medical Interface ---
# Stylesheet handed to gr.Blocks(css=css) below. The HTML emitted by this file
# references the classes main-header, section-title, confidence-container and
# confidence-bar; the remaining selectors are not applied by any markup
# visible in this file.
css = """
.gradio-container {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    background: #1a1a1a;
    min-height: 100vh;
    padding: 20px;
    color: #ffffff;
}

.main-header {
    text-align: center;
    color: white;
    margin-bottom: 2rem;
    padding: 2rem 0;
}

.main-header h1 {
    font-size: 2.5rem;
    margin-bottom: 0.5rem;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.5);
    color: #ffffff;
}

.confidence-bar {
    background: linear-gradient(90deg, #3498db 0%, #2ecc71 100%);
    height: 25px;
    border-radius: 12px;
    margin: 8px 0;
    transition: all 0.3s ease;
    box-shadow: 0 2px 4px rgba(0,0,0,0.3);
}

.confidence-container {
    margin: 15px 0;
    padding: 20px;
    border-radius: 12px;
    background: rgba(255,255,255,0.1);
    backdrop-filter: blur(10px);
    box-shadow: 0 8px 32px rgba(0,0,0,0.3);
    border: 1px solid rgba(255,255,255,0.1);
}

.input-section, .output-section {
    background: rgba(255,255,255,0.05);
    padding: 25px;
    border-radius: 15px;
    margin: 15px;
    backdrop-filter: blur(10px);
    box-shadow: 0 8px 32px rgba(0,0,0,0.3);
    border: 1px solid rgba(255,255,255,0.1);
}

.section-title {
    color: #ffffff;
    font-size: 1.3rem;
    font-weight: 600;
    margin-bottom: 15px;
    border-bottom: 2px solid #3498db;
    padding-bottom: 8px;
}

.gradio-button {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    border: none;
    color: white;
    padding: 12px 24px;
    border-radius: 25px;
    font-weight: 600;
    transition: all 0.3s ease;
    box-shadow: 0 4px 15px rgba(0,0,0,0.3);
}

.gradio-button:hover {
    transform: translateY(-2px);
    box-shadow: 0 6px 20px rgba(0,0,0,0.4);
}

.gradio-image {
    border-radius: 12px;
    box-shadow: 0 4px 15px rgba(0,0,0,0.3);
    border: 1px solid rgba(255,255,255,0.1);
}

.gradio-textbox, .gradio-number {
    border-radius: 8px;
    border: 2px solid #333333;
    padding: 12px;
    font-size: 1rem;
    background: rgba(255,255,255,0.05);
    color: #ffffff;
}

.gradio-textbox:focus, .gradio-number:focus {
    border-color: #3498db;
    box-shadow: 0 0 0 0.2rem rgba(52,152,219,0.25);
}

.gradio-label {
    color: #ffffff !important;
}

.heatmap-container {
    background: rgba(255,255,255,0.05);
    padding: 15px;
    border-radius: 12px;
    border: 1px solid rgba(255,255,255,0.1);
    margin: 10px 0;
}

.prediction-container {
    background: rgba(52,152,219,0.1);
    padding: 20px;
    border-radius: 12px;
    border-left: 5px solid #3498db;
    margin: 15px 0;
}
"""

# --- Function to create confidence bars HTML ---
def create_confidence_bars(confidence_dict):
    """Render per-class confidences as an HTML block of coloured bars.

    Args:
        confidence_dict: Mapping of class name to score, nominally in [0, 1].

    Returns:
        HTML string: one ``confidence-container`` div holding, per class, a
        label row and a bar. The colour is green above 70 %, yellow above
        40 %, red otherwise. An empty mapping yields an empty container.
    """
    # Function-scope import so the block does not depend on the file header.
    from html import escape

    rows = []
    for class_name, confidence in confidence_dict.items():
        percentage = confidence * 100
        # Color coding based on confidence
        if percentage > 70:
            color = "#28a745"  # Green for high confidence
        elif percentage > 40:
            color = "#ffc107"  # Yellow for medium confidence
        else:
            color = "#dc3545"  # Red for low confidence

        # Labels come from a text file; escape them so characters such as
        # '&' or '<' are shown literally instead of being parsed as markup.
        label = escape(str(class_name))
        # The CSS width must stay inside 0-100 % even for out-of-range scores;
        # the printed number still shows the raw value.
        bar_width = min(max(percentage, 0.0), 100.0)

        rows.append(f"""
            <div style='margin: 12px 0;'>
                <div style='display: flex; justify-content: space-between; margin-bottom: 8px;'>
                    <span style='font-weight: bold; color: {color};'>{label}</span>
                    <span style='font-weight: bold; color: {color};'>{percentage:.1f}%</span>
                </div>
                <div class='confidence-bar' style='width: {bar_width:.2f}%; background: {color};'></div>
            </div>
        """)

    return "<div class='confidence-container'>" + "".join(rows) + "</div>"

# --- Enhanced Prediction Function with Dark Mode Interface ---
def enhanced_classify_and_explain(img):
    """Gradio callback producing the four outputs shown in the interface.

    Args:
        img: Uploaded PIL image, or None when the input is empty.

    Returns:
        ``(overlay, label, score, bars_html)``: the Grad-CAM overlay image,
        the highest-scoring class name, that class's score, and the HTML
        confidence bars for every class. For a None input the placeholders
        ``(None, "No image provided", 0, "")`` are returned.
    """
    if img is None:
        return None, "No image provided", 0, ""

    overlay, scores = classify_and_explain(img)

    # Highest-scoring class and its score; the first entry wins on a tie.
    top_label, top_score = max(scores.items(), key=lambda item: item[1])

    return overlay, top_label, top_score, create_confidence_bars(scores)

# --- Enhanced Gradio Interface ---
# Layout: a header, then one row with the upload widget (left) and the textual
# results (right), then a second row with the Grad-CAM image. Components are
# attached to whichever Row/Column context is open when they are created, so
# the statement order below defines the page layout.
with gr.Blocks(css=css, title="Wound Classification") as demo:
    gr.HTML("""
        <div class="main-header">
            <h1>Wound Classification</h1>
        </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.HTML("<div class='section-title'>Input Image</div>")
            # type="pil" makes the callback receive a PIL image, which is what
            # preprocess_input and overlay_gradcam operate on.
            input_image = gr.Image(
                label="Upload wound image",
                type="pil",
                height=350,
                container=True
            )

        with gr.Column(scale=1):
            gr.HTML("<div class='section-title'>Analysis Results</div>")

            # Prediction results
            prediction_output = gr.Textbox(
                label="Predicted Wound Type",
                interactive=False,
                container=True
            )

            # Score of the predicted class as returned by the callback
            # (a fraction, not a percentage).
            confidence_output = gr.Number(
                label="Confidence Score",
                interactive=False,
                container=True
            )

            # Confidence bars for all classes
            confidence_bars = gr.HTML(
                label="Confidence Scores by Class",
                container=True
            )

    with gr.Row():
        with gr.Column():
            gr.HTML("<div class='section-title'>Model Focus Visualization</div>")
            cam_output = gr.Image(
                label="Grad-CAM Heatmap - Shows which areas the model focused on",
                height=350,
                container=True
            )

    # Event handlers
    # Re-run the analysis whenever the input image changes. The outputs list
    # must stay in the same order as the tuple returned by
    # enhanced_classify_and_explain: overlay, label, score, bars HTML.
    input_image.change(
        fn=enhanced_classify_and_explain,
        inputs=[input_image],
        outputs=[cam_output, prediction_output, confidence_output, confidence_bars]
    )

# --- Launch the enhanced interface ---
# Runs only when the file is executed as a script, not when it is imported.
if __name__ == "__main__":
    # NOTE(review): "0.0.0.0" together with share=True presumably exposes the
    # app beyond localhost (all interfaces plus a public share link) - confirm
    # this is intended for the deployment target.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True
    )