rizavelioglu committed (verified)
Commit a766367 · 1 Parent(s): b9da44c

add new model & timer

Files changed (1):
  1. app.py +22 -15
app.py CHANGED
@@ -7,6 +7,7 @@ import torchvision.transforms.v2 as transforms
 from torchvision.io import read_image
 from typing import Dict
 import os
+import time
 from huggingface_hub import login
 
 
@@ -46,7 +47,7 @@ class VAETester:
         endpoints = {
             "sd-vae-ft-mse": "https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud",
             "sdxl-vae": "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud",
-            "FLUX.1-schnell": "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud",
+            "FLUX.1": "https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud",
         }
         return endpoints[base_name]
 
@@ -57,8 +58,8 @@ class VAETester:
             "sd-vae-ft-mse": AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(self.device),
             "sdxl-vae": AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(self.device),
             "stable-diffusion-3-medium": AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="vae").to(self.device),
-            "FLUX.1-schnell": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="vae").to(self.device),
-            "FLUX.1-dev": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae").to(self.device),
+            "FLUX.1": AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae").to(self.device),
+            "CogView4-6B": AutoencoderKL.from_pretrained("THUDM/CogView4-6B", subfolder="vae").to(self.device),
         }
         # Define the desired order of models
         order = [
@@ -68,9 +69,9 @@ class VAETester:
             "sdxl-vae",
             #"sdxl-vae (remote)",
             "stable-diffusion-3-medium",
-            "FLUX.1-schnell",
-            #"FLUX.1-schnell (remote)",
-            "FLUX.1-dev",
+            "FLUX.1",
+            #"FLUX.1 (remote)",
+            "CogView4-6B",
         ]
 
         # Construct the vae_models dictionary in the specified order
@@ -95,6 +96,9 @@ class VAETester:
         img_transformed = self.input_transform(img).to(self.device).unsqueeze(0)
         original_base = self.base_transform(img).cpu()
 
+        # Start timer
+        start_time = time.time()
+
         if model_config["type"] == "local":
             vae = model_config["vae"]
             with torch.no_grad():
@@ -112,6 +116,10 @@ class VAETester:
                 return_type="pt",
                 partial_postprocess=False,
             )
+
+        # End timer
+        processing_time = time.time() - start_time
+
         decoded_transformed = self.output_transform(decoded.squeeze(0)).cpu()
         reconstructed = decoded_transformed.clip(0, 1)
         diff = (original_base - reconstructed).abs()
@@ -119,14 +127,13 @@ class VAETester:
         diff_image = transforms.ToPILImage()(bw_diff)
         recon_image = transforms.ToPILImage()(reconstructed)
         diff_score = bw_diff.sum().item()
-        return diff_image, recon_image, diff_score
+        return diff_image, recon_image, diff_score, processing_time
 
     def process_all_models(self, img: torch.Tensor, tolerance: float):
         """Process image through all configured VAEs"""
         results = {}
         for name, model_config in self.vae_models.items():
-            diff_img, recon_img, score = self.process_image(img, model_config, tolerance)
-            results[name] = (diff_img, recon_img, score)
+            results[name] = self.process_image(img, model_config, tolerance)
         return results
 
 @spaces.GPU(duration=15)
@@ -142,10 +149,10 @@ def test_all_vaes(image_path: str, tolerance: float, img_size: int):
         scores = []
 
         for name in tester.vae_models.keys():
-            diff_img, recon_img, score = results[name]
+            diff_img, recon_img, score, proc_time = results[name]
             diff_images.append((diff_img, name))
             recon_images.append((recon_img, name))
-            scores.append(f"{name:<25}: {score:,.0f}")
+            scores.append(f"{name:<25}: {score:7,.0f} | {proc_time:.4f}s")
 
         return diff_images, recon_images, "\n".join(scores)
     except Exception as e:
@@ -157,13 +164,13 @@ examples = [f"examples/{img_filename}" for img_filename in sorted(os.listdir("examples"))]
 with gr.Blocks(title="VAE Performance Tester", css=".monospace-text {font-family: 'Courier New', Courier, monospace;}") as demo:
     gr.Markdown("# VAE Comparison Tool")
     gr.Markdown("""
-    Upload an image or select an example to compare how different VAEs reconstruct it. Now includes remote VAEs via Hugging Face's remote decoding feature!
+    Upload an image or select an example to compare how different VAEs reconstruct it.
     1. The image is padded to a square and resized to the selected size (512 or 1024 pixels).
-    2. Each VAE (local or remote) encodes the image into a latent space and decodes it back.
+    2. Each VAE encodes the image into a latent space and decodes it back.
     3. Outputs include:
       - **Difference Maps**: Where reconstruction differs from the original (white = difference > tolerance).
      - **Reconstructed Images**: Outputs from each VAE.
-      - **Sum of Differences**: Total pixels exceeding tolerance (lower is better).
+      - **Sum of Differences and Time**: Total pixels exceeding tolerance (lower is better) and processing time in seconds.
    Adjust tolerance to change sensitivity.
    """)
 
@@ -185,7 +192,7 @@ with gr.Blocks(title="VAE Performance Tester", css=".monospace-text {font-family: 'Courier New', Courier, monospace;}") as demo:
     with gr.Row():
         diff_gallery = gr.Gallery(label="Difference Maps", columns=4, height=512)
         recon_gallery = gr.Gallery(label="Reconstructed Images", columns=4, height=512)
-        scores_output = gr.Textbox(label="Sum of differences (lower is better)", lines=9, elem_classes="monospace-text")
+        scores_output = gr.Textbox(label="Sum of differences (lower is better) | Processing time (lower is faster)", lines=9, elem_classes="monospace-text")
 
     if examples:
         with gr.Row():
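
The timer added in process_image brackets both the encode and the decode step with time.time(). A minimal, self-contained sketch of that timed round-trip, assuming the standard diffusers AutoencoderKL API; the dummy input tensor and the FLUX.1-dev VAE are illustrative choices, not lifted from the rest of app.py:

import time

import torch
from diffusers import AutoencoderKL

device = "cuda" if torch.cuda.is_available() else "cpu"

# One of the VAEs this commit registers (illustrative choice).
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="vae"
).to(device)

# Dummy batch standing in for the transformed input image, scaled to [-1, 1].
img = torch.rand(1, 3, 512, 512, device=device) * 2 - 1

start_time = time.time()
with torch.no_grad():
    latents = vae.encode(img).latent_dist.sample()  # image -> latent
    decoded = vae.decode(latents).sample            # latent -> image
processing_time = time.time() - start_time

print(f"round-trip: {processing_time:.4f}s, decoded shape: {tuple(decoded.shape)}")

One caveat: CUDA kernels launch asynchronously, so stopping the clock right after vae.decode can under-report GPU time, and the diff above ends its timer before the .cpu() call that would force a synchronization. Calling torch.cuda.synchronize() before reading time.time() would give a stricter measurement.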