Spaces: Running on Zero

add new model & timer

app.py CHANGED
@@ -7,6 +7,7 @@ import torchvision.transforms.v2 as transforms
 from torchvision.io import read_image
 from typing import Dict
 import os
+import time
 from huggingface_hub import login
 
 
@@ -46,7 +47,7 @@ class VAETester:
         endpoints = {
             "sd-vae-ft-mse": "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud",
             "sdxl-vae": "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud",
-            "FLUX.1
+            "FLUX.1": "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud",
         }
         return endpoints[base_name]
 
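These endpoints back the commented-out "(remote)" model variants. Judging by the `return_type="pt"` / `partial_postprocess=False` arguments further down in this file, the remote path appears to use diffusers' hybrid-inference helper; a minimal sketch under that assumption, with a placeholder latent and scaling factor (per-VAE values differ), would look roughly like:

```python
# Hedged sketch, not part of this commit: decoding a latent against one of
# the endpoints above, assuming diffusers.utils.remote_utils.remote_decode.
import torch
from diffusers.utils.remote_utils import remote_decode

latent = torch.randn(1, 4, 64, 64)  # placeholder SD-style latent
decoded = remote_decode(
    endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
    tensor=latent,
    scaling_factor=0.18215,  # sd-vae-ft-mse value; other VAEs use different factors
    return_type="pt",
    partial_postprocess=False,
)
```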
@@ -57,8 +58,8 @@ class VAETester:
             "sd-vae-ft-mse": AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(self.device),
             "sdxl-vae": AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(self.device),
             "stable-diffusion-3-medium": AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae").to(self.device),
-            "FLUX.1
-            "
+            "FLUX.1": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae").to(self.device),
+            "CogView4-6B": AutoencoderKL.from_pretrained("THUDM/CogView4-6B", subfolder="vae").to(self.device),
         }
         # Define the desired order of models
         order = [
@@ -68,9 +69,9 @@ class VAETester:
             "sdxl-vae",
             #"sdxl-vae (remote)",
             "stable-diffusion-3-medium",
-            "FLUX.1
-            #"FLUX.1
-            "
+            "FLUX.1",
+            #"FLUX.1 (remote)",
+            "CogView4-6B",
         ]
 
         # Construct the vae_models dictionary in the specified order
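Both new models load through the same `AutoencoderKL.from_pretrained(..., subfolder="vae")` pattern as the existing ones. A standalone sketch of just the two added loads (black-forest-labs/FLUX.1-dev is a gated repo, which is presumably why the app imports `login` from huggingface_hub):

```python
# Sketch of the two VAE loads added in this commit, outside the Space.
import torch
from diffusers import AutoencoderKL

device = "cuda" if torch.cuda.is_available() else "cpu"

# Gated repo: requires an accepted license and a Hugging Face token
# (e.g. via huggingface_hub.login) before from_pretrained will succeed.
flux_vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae"
).to(device)

cogview_vae = AutoencoderKL.from_pretrained(
    "THUDM/CogView4-6B", subfolder="vae"
).to(device)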
@@ -95,6 +96,9 @@ class VAETester:
         img_transformed = self.input_transform(img).to(self.device).unsqueeze(0)
         original_base = self.base_transform(img).cpu()
 
+        # Start timer
+        start_time = time.time()
+
         if model_config["type"] == "local":
             vae = model_config["vae"]
             with torch.no_grad():
@@ -112,6 +116,10 @@
                     return_type="pt",
                     partial_postprocess=False,
                 )
+
+        # End timer
+        processing_time = time.time() - start_time
+
         decoded_transformed = self.output_transform(decoded.squeeze(0)).cpu()
         reconstructed = decoded_transformed.clip(0, 1)
         diff = (original_base - reconstructed).abs()
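One caveat on the new timer: CUDA kernels launch asynchronously, so a bare `time.time()` delta around GPU work can under-count the decode itself, and here the timer stops before the `.cpu()` copy that would force a sync for the local models. A sync-aware variant, as a sketch (the `timed` helper is illustrative, not part of the app):

```python
import time

import torch

def timed(fn, *args, **kwargs):
    """Call fn and return (result, seconds), flushing queued CUDA work
    before and after so the wall-clock delta covers the GPU execution."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()
    result = fn(*args, **kwargs)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return result, time.time() - start
```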
@@ -119,14 +127,13 @@
         diff_image = transforms.ToPILImage()(bw_diff)
         recon_image = transforms.ToPILImage()(reconstructed)
         diff_score = bw_diff.sum().item()
-        return diff_image, recon_image, diff_score
+        return diff_image, recon_image, diff_score, processing_time
 
     def process_all_models(self, img: torch.Tensor, tolerance: float):
         """Process image through all configured VAEs"""
         results = {}
         for name, model_config in self.vae_models.items():
-            diff_img, recon_img, score = self.process_image(img, model_config, tolerance)
-            results[name] = (diff_img, recon_img, score)
+            results[name] = self.process_image(img, model_config, tolerance)
         return results
 
 @spaces.GPU(duration=15)
@@ -142,10 +149,10 @@ def test_all_vaes(image_path: str, tolerance: float, img_size: int):
         scores = []
 
         for name in tester.vae_models.keys():
-            diff_img, recon_img, score = results[name]
+            diff_img, recon_img, score, proc_time = results[name]
             diff_images.append((diff_img, name))
             recon_images.append((recon_img, name))
-            scores.append(f"{name:<25}: {score
+            scores.append(f"{name:<25}: {score:7,.0f} | {proc_time:.4f}s")
 
         return diff_images, recon_images, "\n".join(scores)
     except Exception as e:
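For reference, the new format spec left-pads the model name to 25 characters, renders the score at width 7 with thousands separators and no decimals, and gives the time four decimal places:

```python
# Quick check of the new score line with made-up values.
print(f"{'sdxl-vae':<25}: {123456:7,.0f} | {0.0312:.4f}s")
# -> sdxl-vae                 : 123,456 | 0.0312s
```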
@@ -157,13 +164,13 @@ examples = [f"examples/{img_filename}" for img_filename in sorted(os.listdir("ex
 with gr.Blocks(title="VAE Performance Tester", css=".monospace-text {font-family: 'Courier New', Courier, monospace;}") as demo:
     gr.Markdown("# VAE Comparison Tool")
     gr.Markdown("""
-    Upload an image or select an example to compare how different VAEs reconstruct it.
+    Upload an image or select an example to compare how different VAEs reconstruct it.
     1. The image is padded to a square and resized to the selected size (512 or 1024 pixels).
-    2. Each VAE
+    2. Each VAE encodes the image into a latent space and decodes it back.
     3. Outputs include:
        - **Difference Maps**: Where reconstruction differs from the original (white = difference > tolerance).
        - **Reconstructed Images**: Outputs from each VAE.
-       - **Sum of Differences**: Total pixels exceeding tolerance (lower is better).
+       - **Sum of Differences and Time**: Total pixels exceeding tolerance (lower is better) and processing time in seconds.
     Adjust tolerance to change sensitivity.
     """)
 
@@ -185,7 +192,7 @@ with gr.Blocks(title="VAE Performance Tester", css=".monospace-text {font-family
     with gr.Row():
         diff_gallery = gr.Gallery(label="Difference Maps", columns=4, height=512)
         recon_gallery = gr.Gallery(label="Reconstructed Images", columns=4, height=512)
-    scores_output = gr.Textbox(label="Sum of differences (lower is better)", lines=9, elem_classes="monospace-text")
+    scores_output = gr.Textbox(label="Sum of differences (lower is better) | Processing time (lower is faster)", lines=9, elem_classes="monospace-text")
 
     if examples:
         with gr.Row():