Add application files
app.py
CHANGED
@@ -19,11 +19,15 @@ from scipy.special import kl_div as scipy_kl_div
 from skimage.metrics import structural_similarity as ssim
 import warnings
 
-
+# --- Configuration ---
+# Set the default task.
+TASK = "facades"
 PATH = os.path.join("datasets", TASK, "real")
 images = []
 perceptual_model = None
 
+# --- Model Loading ---
+# Attempt to load the VGG19 model for the perceptual loss metric.
 try:
     vgg = VGG19(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
     vgg.trainable = False
@@ -31,7 +35,9 @@ try:
 except Exception as e:
     perceptual_model = None
 
+# --- Utility Functions ---
 def safe_normalize_heatmap(heatmap):
+    """Safely normalizes a heatmap to a 0-255 range for visualization."""
     if heatmap is None or heatmap.size == 0:
         return np.zeros((64, 64), dtype=np.uint8)
     heatmap = heatmap.astype(np.float32)
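For readers skimming the diff: the hunk above only shows the top of safe_normalize_heatmap. The sketch below illustrates what such a normalizer typically does; it is an assumption-based illustration, not the exact body of the function in app.py (most of which falls outside this hunk), and the name normalize_to_uint8 is invented for the example.

import numpy as np

def normalize_to_uint8(heatmap, epsilon=1e-10):
    # Min-max scale a float heatmap into the 0-255 display range.
    heatmap = heatmap.astype(np.float32)
    lo, hi = float(heatmap.min()), float(heatmap.max())
    if hi - lo < epsilon:  # flat heatmap: avoid division by zero
        return np.zeros_like(heatmap, dtype=np.uint8)
    scaled = (heatmap - lo) / (hi - lo) * 255.0
    return np.clip(scaled, 0, 255).astype(np.uint8)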
@@ -48,7 +54,10 @@ def safe_normalize_heatmap(heatmap):
     normalized_heatmap = np.clip(normalized_heatmap, 0, 255)
     return np.uint8(normalized_heatmap)
 
+# --- Image Comparison Metrics ---
+
 def KL_divergence(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
+    """Calculates Kullback-Leibler Divergence between two images."""
     if img_real is None or img_fake is None or img_real.shape != img_fake.shape:
         return None
     try:
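The KL_divergence docstring added above is terse, so here is a small, self-contained sketch of a block-wise KL heatmap. It is an assumption-laden illustration (normalizing pixel values into probability-like terms and pooling with block_size), not the implementation in app.py; kl_heatmap is a name invented for the example.

import numpy as np
from scipy.special import kl_div as scipy_kl_div

def kl_heatmap(img_real, img_fake, epsilon=1e-10, block_size=4):
    # Treat normalized pixel intensities as probability-like values.
    real = img_real.astype(np.float32) / 255.0 + epsilon
    fake = img_fake.astype(np.float32) / 255.0 + epsilon
    # scipy_kl_div is element-wise: x*log(x/y) - x + y; sum over the channel axis.
    per_pixel = scipy_kl_div(real, fake).sum(axis=-1)
    # Pool into block_size x block_size cells to form a coarse heatmap.
    h, w = per_pixel.shape
    hb, wb = h // block_size, w // block_size
    pooled = per_pixel[:hb * block_size, :wb * block_size]
    pooled = pooled.reshape(hb, block_size, wb, block_size).mean(axis=(1, 3))
    return per_pixel.mean(), pooled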
@@ -92,6 +101,7 @@ def KL_divergence(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=
     return img_dict
 
 def L1_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
+    """Calculates L1 (Mean Absolute Error) loss between two images."""
     if img_real is None or img_fake is None or img_real.shape != img_fake.shape: return None
     try:
         img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
@@ -121,6 +131,7 @@ def L1_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False)
     return img_dict
 
 def MSE_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
+    """Calculates MSE (Mean Squared Error) loss between two images."""
     if img_real is None or img_fake is None or img_real.shape != img_fake.shape: return None
     try:
         img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
@@ -150,6 +161,7 @@ def MSE_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False
     return img_dict
 
 def SSIM_loss(img_real, img_fake, block_size=7, sum_channels=False):
+    """Calculates SSIM (Structural Similarity Index) loss between two images."""
     if img_real is None or img_fake is None or img_real.shape != img_fake.shape: return None
     try:
         img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB)
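SSIM_loss relies on the skimage import shown at the top of the file. A hedged sketch of that usage follows, assuming scikit-image 0.19+ for the channel_axis argument; the real function in app.py builds a richer result dictionary, but the core structural_similarity call with full=True looks roughly like this (ssim_loss_map is a name made up for the example).

import numpy as np
from skimage.metrics import structural_similarity as ssim

def ssim_loss_map(img_real_rgb, img_fake_rgb, win_size=7):
    # full=True returns the mean SSIM score plus the per-pixel SSIM image.
    score, ssim_map = ssim(img_real_rgb, img_fake_rgb, win_size=win_size,
                           channel_axis=2, data_range=255, full=True)
    # SSIM is a similarity; invert it so higher values mean a larger difference.
    return 1.0 - score, 1.0 - ssim_map.mean(axis=-1)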
@@ -177,6 +189,7 @@ def SSIM_loss(img_real, img_fake, block_size=7, sum_channels=False):
     return img_dict
 
 def cosine_similarity_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
+    """Calculates Cosine Similarity loss between two images."""
     if img_real is None or img_fake is None or img_real.shape != img_fake.shape: return None
     try:
         img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
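For cosine_similarity_loss, one plausible reading of the docstring is a per-pixel cosine similarity between the two RGB color vectors. The sketch below illustrates only that interpretation; it is not lifted from app.py, and cosine_dissimilarity is an invented name.

import numpy as np

def cosine_dissimilarity(img_real, img_fake, epsilon=1e-10):
    real = img_real.astype(np.float32) / 255.0
    fake = img_fake.astype(np.float32) / 255.0
    dot = (real * fake).sum(axis=-1)
    norms = np.linalg.norm(real, axis=-1) * np.linalg.norm(fake, axis=-1)
    cos_sim = dot / (norms + epsilon)   # 1.0 means identical color direction
    heatmap = 1.0 - cos_sim             # higher values indicate larger differences
    return heatmap.mean(), heatmap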
@@ -212,6 +225,7 @@ def cosine_similarity_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_
     return img_dict
 
 def TV_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False):
+    """Calculates Total Variation (TV) loss between two images."""
     if img_real is None or img_fake is None or img_real.shape != img_fake.shape: return None
     try:
         img_real_rgb = cv2.cvtColor(img_real, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
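Total variation measures how much an image varies between neighboring pixels. A minimal illustration of one way to compare two images by their TV; this is an assumption, and app.py's TV_loss may differ in the details (tv_gap and total_variation are names invented for the example).

import numpy as np

def total_variation(img):
    # Anisotropic TV: sum of absolute differences between neighboring pixels.
    return np.abs(np.diff(img, axis=0)).sum() + np.abs(np.diff(img, axis=1)).sum()

def tv_gap(img_real, img_fake):
    real = img_real.astype(np.float32) / 255.0
    fake = img_fake.astype(np.float32) / 255.0
    return abs(total_variation(real) - total_variation(fake))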
@@ -241,6 +255,7 @@ def TV_loss(img_real, img_fake, epsilon=1e-10, block_size=4, sum_channels=False)
     return img_dict
 
 def perceptual_loss(img_real, img_fake, model, block_size=4):
+    """Calculates Perceptual loss using a pre-trained VGG19 model."""
     if img_real is None or img_fake is None or model is None or img_real.shape != img_fake.shape:
         return None
     original_height, original_width, _ = img_real.shape
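perceptual_loss compares images in VGG19 feature space rather than pixel space. Below is a self-contained sketch of that idea, assuming the TensorFlow/Keras VGG19 loaded at the top of app.py; vgg_feature_distance is a helper name invented for the example, not the function in the file.

import numpy as np
import cv2
from tensorflow.keras.applications import VGG19
from tensorflow.keras.applications.vgg19 import preprocess_input

def vgg_feature_distance(img_real_bgr, img_fake_bgr, model=None):
    if model is None:
        model = VGG19(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
        model.trainable = False
    batch = []
    for img in (img_real_bgr, img_fake_bgr):
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        rgb = cv2.resize(rgb, (224, 224), interpolation=cv2.INTER_LINEAR)
        batch.append(preprocess_input(rgb.astype(np.float32)))
    feats = model.predict(np.stack(batch), verbose=0)
    # Mean squared distance between the two feature maps.
    return float(np.mean((feats[0] - feats[1]) ** 2))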
@@ -264,7 +279,10 @@ def perceptual_loss(img_real, img_fake, model, block_size=4):
     heatmap_original_size = cv2.resize(heatmap_features, (original_width, original_height), interpolation=cv2.INTER_LINEAR)
     return {"SUM": {"SUM": total_loss, "HEATMAP": heatmap_original_size.astype(np.float32)}}
 
+# --- Gradio Core Functions ---
+
 def gather_images(task):
+    """Loads a random pair of real and fake images from the selected dataset."""
     global TASK, PATH, images
     new_path = os.path.join("datasets", task, "real")
     if TASK != task or not images:
@@ -311,6 +329,7 @@ def gather_images(task):
         return placeholder, placeholder, error_msg
 
 def run_comparison(real, fake, measurement, block_size_val):
+    """Runs the selected comparison metric and generates a heatmap."""
     placeholder_heatmap = np.zeros((64, 64, 3), dtype=np.uint8)
     if real is None or fake is None or not isinstance(real, np.ndarray) or not isinstance(fake, np.ndarray):
         return placeholder_heatmap, "Error: Input image(s) missing or invalid. Please load or upload a pair of images."
@@ -353,8 +372,25 @@ def run_comparison(real, fake, measurement, block_size_val):
     return heatmap_rgb, status_msg
 
 def clear_uploads(msg):
+    """Clears the image displays and updates the status message."""
     return None, None, msg
 
+def load_and_compare_initial(task):
+    """Gathers initial images and runs a comparison on them at startup."""
+    # Step 1: Get the initial images
+    real_img, fake_img, gather_status = gather_images(task)
+
+    # Step 2: Run the default comparison
+    # We use the default values from the UI definition
+    default_measurement = "Cosine Similarity"
+    default_block_size = 8
+    heatmap, compare_status = run_comparison(real_img, fake_img, default_measurement, default_block_size)
+
+    # Step 3: Combine status messages and return all initial values
+    final_status = f"{gather_status}\n{compare_status}"
+    return real_img, fake_img, heatmap, final_status
+
+# --- Gradio UI Definition ---
 theme = gr.themes.Soft(primary_hue="blue", secondary_hue="orange")
 with gr.Blocks(theme=theme, css=".gradio-container { max-width: 1400px !important; margin: auto; }") as demo:
     gr.Markdown("# GAN vs Ground Truth Image Comparison")
@@ -368,7 +404,7 @@ with gr.Blocks(theme=theme, css=".gradio-container { max-width: 1400px !importan
     with gr.Tabs():
         with gr.TabItem("Load from Dataset"):
             task_dropdown = gr.Dropdown(
-                ["
+                ["facades"], value=TASK,
                 info="Select the dataset task.",
                 label="Dataset Task"
             )
@@ -387,7 +423,7 @@ with gr.Blocks(theme=theme, css=".gradio-container { max-width: 1400px !importan
     with gr.Row():
         measurement_dropdown = gr.Dropdown(
             ["Kullback-Leibler Divergence", "L1-Loss", "MSE", "SSIM", "Cosine Similarity", "TV", "Perceptual"],
-            value="
+            value="Cosine Similarity",
             info="Select the comparison metric.",
             label="Comparison Metric",
             scale=2
@@ -404,6 +440,15 @@ with gr.Blocks(theme=theme, css=".gradio-container { max-width: 1400px !importan
     gr.Markdown("### 3. See Result")
     heatmap_display = gr.Image(type="numpy", label="Comparison Heatmap (Difference)", height=350, interactive=False)
 
+    # --- Event Listeners ---
+
+    # Load initial sample and run comparison when the app starts
+    demo.load(
+        fn=load_and_compare_initial,
+        inputs=[task_dropdown],
+        outputs=[real_img_display, fake_img_display, heatmap_display, status_message]
+    )
+
     sample_button.click(
         fn=gather_images,
         inputs=[task_dropdown],
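The new demo.load listener is what makes the Space show a populated comparison as soon as the page opens, instead of waiting for a button click. In recent Gradio versions the pattern looks like this stripped-down sketch; the component names here are placeholders, not the ones in app.py.

import gradio as gr

with gr.Blocks() as demo:
    status = gr.Textbox(label="Status")
    # Run once when the page loads and push the result into a component.
    demo.load(fn=lambda: "ready", inputs=None, outputs=[status])

# demo.launch()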
@@ -434,7 +479,7 @@ with gr.Blocks(theme=theme, css=".gradio-container { max-width: 1400px !importan
         outputs=[real_img_display, fake_img_display, status_message]
     )
 
-
+# --- Application Entry Point ---
 if __name__ == "__main__":
     print("-------------------------------------------------------------")
     print("Verifying VGG19 model status...")
|