danhtran2mind commited on
Commit
c8f6bca
·
verified ·
1 Parent(s): 34f4a67

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import os
4
+ import numpy as np
5
+ import tensorflow as tf
6
+ import requests
7
+ from skimage.color import lab2rgb
8
+
9
+ # Model paths and mapping
10
+ load_model_paths = [
11
+ "ckpts/autoencoder/autoencoder_colorization_model.h5",
12
+ "ckpts/unet/unet_colorization_model.keras",
13
+ "ckpts/unet/unet_colorization_model.keras"
14
+ ]
15
+
16
+ # Custom object needed by models
17
+ from models.auto_encoder_gray2color import SpatialAttention
18
+
19
+ # Model input size
20
+ WIDTH, HEIGHT = 512, 512
21
+
22
+ # Download models if they don't exist
23
+ def download_model(url, path):
24
+ os.makedirs(os.path.dirname(path), exist_ok=True)
25
+ print(f"Downloading model from {url}...")
26
+ with requests.get(url, stream=True) as r:
27
+ r.raise_for_status()
28
+ with open(path, "wb") as f:
29
+ for chunk in r.iter_content(chunk_size=8192):
30
+ f.write(chunk)
31
+ print("Download complete.")
32
+
33
+ # Helper to dynamically load a model
34
+ def load_model(model_path):
35
+ if not os.path.exists(model_path):
36
+ if "autoencoder" in model_path:
37
+ url = "https://huggingface.co/danhtran2mind/autoencoder-grayscale2color-landscape/resolve/main/ckpts/autoencoder_colorization_model.h5"
38
+ elif "unet" in model_path:
39
+ url = "https://huggingface.co/danhtran2mind/autoencoder-grayscale2color-landscape/resolve/main/ckpts/unet_colorization_model.keras"
40
+ else:
41
+ raise ValueError("Unknown model path for downloading.")
42
+ download_model(url, model_path)
43
+ print(f"Loading model from {model_path}...")
44
+ return tf.keras.models.load_model(
45
+ model_path,
46
+ custom_objects={'SpatialAttention': SpatialAttention}
47
+ )
48
+
49
+ # Dictionary of loaded models
50
+ loaded_models = {
51
+ "Autoencoder": load_model(load_model_paths[0]),
52
+ "U-Net v1": load_model(load_model_paths[1]),
53
+ "U-Net v2": load_model(load_model_paths[2])
54
+ }
55
+
56
+ def process_image(input_img, model_type):
57
+ model = loaded_models[model_type]
58
+
59
+ # Store original input dimensions
60
+ original_width, original_height = input_img.size
61
+
62
+ # Convert PIL Image to grayscale and resize to model input size
63
+ img = input_img.convert("L") # Grayscale
64
+ img = img.resize((WIDTH, HEIGHT)) # Resize to match model input
65
+ img_array = tf.keras.preprocessing.image.img_to_array(img) / 255.0 # Normalize
66
+ img_array = img_array[None, ..., 0:1] # Add batch dim (B, H, W, C)
67
+
68
+ # Predict a*b* channels
69
+ output_array = model.predict(img_array)
70
+ print("Model Output Shape:", output_array.shape)
71
+
72
+ L_channel = img_array[0, :, :, 0] * 100.0
73
+ ab_channels = output_array[0] * 128.0 # Denormalize ab to [-128, 128]
74
+
75
+ # Combine into Lab image
76
+ lab_image = np.stack([L_channel, ab_channels[:, :, 0], ab_channels[:, :, 1]], axis=-1)
77
+
78
+ # Convert to RGB
79
+ rgb_array = lab2rgb(lab_image)
80
+ rgb_array = np.clip(rgb_array, 0, 1) * 255.0
81
+ rgb_image = Image.fromarray(rgb_array.astype(np.uint8), 'RGB')
82
+
83
+ # Resize back to original resolution
84
+ rgb_image = rgb_image.resize((original_width, original_height), Image.Resampling.LANCZOS)
85
+
86
+ return rgb_image
87
+
88
+ custom_css = """
89
+ body {background: linear-gradient(135deg, #f0f4f8 0%, #d9e2ec 100%) !important;}
90
+ .gradio-container {background: transparent !important;}
91
+ h1, .gr-title {color: #007bff !important; font-family: 'Segoe UI', sans-serif;}
92
+ .gr-description {color: #333333 !important; font-size: 1.1em;}
93
+ .gr-input, .gr-output {border-radius: 18px !important; box-shadow: 0 4px 24px rgba(0,0,0,0.1);}
94
+ .gr-button {background: linear-gradient(90deg, #007bff 0%, #00c4cc 100%) !important; color: #fff !important; border: none !important; border-radius: 12px !important;}
95
+ """
96
+
97
+ with gr.Blocks(theme="soft", css=custom_css) as demo:
98
+ gr.Markdown("<h1 style='text-align:center;'>🌄 Gray2Color Landscape Autoencoder</h1>")
99
+ gr.Markdown(
100
+ "<div style='font-size:1.15em;line-height:1.6em;text-align:center;'>"
101
+ "Transform grayscale landscape photos into vivid color using AI.<br>"
102
+ "Upload a grayscale image and select a model to begin!"
103
+ "</div>"
104
+ )
105
+ with gr.Row():
106
+ image_input = gr.Image(type="pil", label="Upload Grayscale Landscape", image_mode="L")
107
+ image_output = gr.Image(type="pil", label="Colorized Output")
108
+ model_selector = gr.Dropdown(
109
+ choices=["Autoencoder", "U-Net v1", "U-Net v2"],
110
+ label="Select Model",
111
+ value="Autoencoder"
112
+ )
113
+ run_button = gr.Button("🎨 Colorize")
114
+ run_button.click(fn=process_image, inputs=[image_input, model_selector], outputs=image_output)
115
+
116
+ gr.Examples(
117
+ examples=[
118
+ ["examples/example_input_1.jpg"],
119
+ ["examples/example_input_2.jpg"]
120
+ ],
121
+ inputs=[image_input],
122
+ outputs=image_output,
123
+ fn=lambda x: process_image(x, "Autoencoder"), # Default example model choice
124
+ cache_examples=True
125
+ )
126
+
127
+ if __name__ == "__main__":
128
+ demo.launch()