derekl35 HF Staff committed on
Commit
6165359
·
verified ·
1 Parent(s): 8546d90

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +344 -130
app.py CHANGED
@@ -8,8 +8,47 @@ from pathlib import Path
8
  from PIL import Image
9
  import os
10
  import time
 
 
11
  import spaces
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
  print(f"Using device: {DEVICE}")
15
 
@@ -18,12 +57,10 @@ DEFAULT_WIDTH = 1024
18
  DEFAULT_GUIDANCE_SCALE = 3.5
19
  DEFAULT_NUM_INFERENCE_STEPS = 15
20
  DEFAULT_MAX_SEQUENCE_LENGTH = 512
21
- GENERATION_SEED = 0 # could use a random number generator to set this, for more variety
22
  HF_TOKEN = os.environ.get("HF_ACCESS_TOKEN")
23
 
24
  CACHED_PIPES = {}
25
  def load_bf16_pipeline():
26
- """Loads the original FLUX.1-dev pipeline in BF16 precision."""
27
  print("Loading BF16 pipeline...")
28
  MODEL_ID = "black-forest-labs/FLUX.1-dev"
29
  if MODEL_ID in CACHED_PIPES:
@@ -36,7 +73,6 @@ def load_bf16_pipeline():
36
  token=HF_TOKEN
37
  )
38
  pipe.to(DEVICE)
39
- # pipe.enable_model_cpu_offload()
40
  end_time = time.time()
41
  mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
42
  print(f"BF16 Pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
@@ -44,10 +80,9 @@ def load_bf16_pipeline():
44
  return pipe
45
  except Exception as e:
46
  print(f"Error loading BF16 pipeline: {e}")
47
- raise # Re-raise exception to be caught in generate_images
48
 
49
  def load_bnb_8bit_pipeline():
50
- """Loads the FLUX.1-dev pipeline with 8-bit quantized components."""
51
  print("Loading 8-bit BNB pipeline...")
52
  MODEL_ID = "derekl35/FLUX.1-dev-bnb-8bit"
53
  if MODEL_ID in CACHED_PIPES:
@@ -70,7 +105,6 @@ def load_bnb_8bit_pipeline():
70
  raise
71
 
72
  def load_bnb_4bit_pipeline():
73
- """Loads the FLUX.1-dev pipeline with 4-bit quantized components."""
74
  print("Loading 4-bit BNB pipeline...")
75
  MODEL_ID = "derekl35/FLUX.1-dev-nf4"
76
  if MODEL_ID in CACHED_PIPES:
@@ -89,20 +123,17 @@ def load_bnb_4bit_pipeline():
89
  CACHED_PIPES[MODEL_ID] = pipe
90
  return pipe
91
  except Exception as e:
92
- print(f"4-bit BNB pipeline: {e}")
93
  raise
94
 
95
  @spaces.GPU(duration=240)
96
  def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm=True)):
97
- """Loads original and selected quantized model, generates one image each, shuffles results."""
98
  if not prompt:
99
- return None, {}, gr.update(value="Please enter a prompt.", interactive=False), gr.update(choices=[], value=None)
100
 
101
  if not quantization_choice:
102
- # Return updates for all outputs to clear them or show warning
103
- return None, {}, gr.update(value="Please select a quantization method.", interactive=False), gr.update(choices=[], value=None)
104
 
105
- # Determine which quantized model to load
106
  if quantization_choice == "8-bit":
107
  quantized_load_func = load_bnb_8bit_pipeline
108
  quantized_label = "Quantized (8-bit)"
@@ -110,12 +141,11 @@ def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm
110
  quantized_load_func = load_bnb_4bit_pipeline
111
  quantized_label = "Quantized (4-bit)"
112
  else:
113
- # Should not happen with Radio choices, but good practice
114
- return None, {}, gr.update(value="Invalid quantization choice.", interactive=False), gr.update(choices=[], value=None)
115
 
116
  model_configs = [
117
  ("Original", load_bf16_pipeline),
118
- (quantized_label, quantized_load_func), # Use the specific label here
119
  ]
120
 
121
  results = []
@@ -128,8 +158,6 @@ def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm
128
  "max_sequence_length": DEFAULT_MAX_SEQUENCE_LENGTH,
129
  }
130
 
131
- current_pipe = None # Keep track of the current pipe for cleanup
132
-
133
  seed = random.getrandbits(64)
134
  print(f"Using seed: {seed}")
135
 
@@ -147,7 +175,6 @@ def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm
147
  gen_start_time = time.time()
148
  image_list = current_pipe(**pipe_kwargs, generator=torch.manual_seed(seed)).images
149
  image = image_list[0]
150
- # image.save(f"{load_start_time}.png")
151
  gen_end_time = time.time()
152
  results.append({"label": label, "image": image})
153
  print(f"--- Finished Generation with {label} Model in {gen_end_time - gen_start_time:.2f} seconds ---")
@@ -156,64 +183,42 @@ def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm
156
 
157
  except Exception as e:
158
  print(f"Error during {label} model processing: {e}")
159
- # Return error state to Gradio - update all outputs
160
- return None, {}, gr.update(value=f"Error processing {label} model: {e}", interactive=False), gr.update(choices=[], value=None)
161
 
162
- # No finally block needed here, cleanup happens before next load or after loop
163
 
164
  if len(results) != len(model_configs):
165
- print("Generation did not complete for all models.")
166
- # Update all outputs
167
- return None, {}, gr.update(value="Failed to generate images for all model types.", interactive=False), gr.update(choices=[], value=None)
168
 
169
- # Shuffle the results for display
170
  shuffled_results = results.copy()
171
  random.shuffle(shuffled_results)
172
-
173
- # Create the gallery data: [(image, caption), (image, caption)]
174
  shuffled_data_for_gallery = [(res["image"], f"Image {i+1}") for i, res in enumerate(shuffled_results)]
175
-
176
- # Create the mapping: display_index -> correct_label (e.g., {0: 'Original', 1: 'Quantized (8-bit)'})
177
  correct_mapping = {i: res["label"] for i, res in enumerate(shuffled_results)}
178
  print("Correct mapping (hidden):", correct_mapping)
179
 
180
- guess_radio_update = gr.update(choices=["Image 1", "Image 2"], value=None, interactive=True)
181
 
182
- # Return shuffled images, the correct mapping state, status message, and update the guess radio
183
- return shuffled_data_for_gallery, correct_mapping, gr.update(value="Generation complete! Make your guess.", interactive=False), guess_radio_update
184
 
185
-
186
- # --- Guess Verification Function ---
187
  def check_guess(user_guess, correct_mapping_state):
188
- """Compares the user's guess with the correct mapping stored in the state."""
189
-
190
  if not isinstance(correct_mapping_state, dict) or not correct_mapping_state:
191
  return "Please generate images first (state is empty or invalid)."
192
-
193
  if user_guess is None:
194
  return "Please select which image you think is quantized."
195
 
196
- # Find which display index (0 or 1) corresponds to the quantized image
197
  quantized_image_index = -1
198
  quantized_label_actual = ""
199
  for index, label in correct_mapping_state.items():
200
- if "Quantized" in label: # Check if the label indicates quantization
201
  quantized_image_index = index
202
- quantized_label_actual = label # Store the full label e.g. "Quantized (8-bit)"
203
  break
204
-
205
  if quantized_image_index == -1:
206
- # This shouldn't happen if generation was successful
207
  return "Error: Could not find the quantized image in the mapping data."
208
 
209
- # Determine what the user *should* have selected based on the index
210
- correct_guess_label = f"Image {quantized_image_index + 1}" # "Image 1" or "Image 2"
211
-
212
  if user_guess == correct_guess_label:
213
  feedback = f"Correct! {correct_guess_label} used the {quantized_label_actual} model."
214
  else:
215
  feedback = f"Incorrect. The quantized image ({quantized_label_actual}) was {correct_guess_label}."
216
-
217
  return feedback
218
 
219
  EXAMPLE_DIR = Path(__file__).parent / "examples"
@@ -221,12 +226,14 @@ EXAMPLES = [
221
  {
222
  "prompt": "A photorealistic portrait of an astronaut on Mars",
223
  "files": ["astronauts_seed_6456306350371904162.png", "astronauts_bnb_8bit.png"],
224
- "quantized_idx": 1, # which of the two files is the quantized result
 
225
  },
226
  {
227
  "prompt": "Water-color painting of a cat wearing sunglasses",
228
  "files": ["watercolor_cat_bnb_8bit.png", "watercolor_cat_seed_14269059182221286790.png"],
229
  "quantized_idx": 0,
 
230
  },
231
  # {
232
  # "prompt": "Neo-tokyo cyberpunk cityscape at night, rain-soaked streets, 8-K",
@@ -236,97 +243,304 @@ EXAMPLES = [
236
  ]
237
 
238
  def load_example(idx):
239
- """Return [(PIL.Image, caption)...], mapping dict, and feedback string"""
240
  ex = EXAMPLES[idx]
241
  imgs = [Image.open(EXAMPLE_DIR / f) for f in ex["files"]]
242
  gallery_items = [(img, f"Image {i+1}") for i, img in enumerate(imgs)]
243
- mapping = {i: ("Quantized" if i == ex["quantized_idx"] else "Original")
244
  for i in range(2)}
245
  return gallery_items, mapping, f"{ex['prompt']}"
246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as demo:
248
  gr.Markdown("# FLUX Model Quantization Challenge")
249
- gr.Markdown(
250
- "Compare the original FLUX.1-dev (BF16) model against a quantized version (4-bit or 8-bit). "
251
- "Enter a prompt, choose the quantization method, and generate two images. "
252
- "The images will be shuffled, can you spot which one was quantized?"
253
- )
254
-
255
- gr.Markdown("### Examples")
256
- ex_selector = gr.Radio(
257
- choices=[f"Example {i+1}" for i in range(len(EXAMPLES))],
258
- label="Choose an example prompt",
259
- interactive=True,
260
- )
261
- gr.Markdown("### …or create your own comparison")
262
- with gr.Row():
263
- prompt_input = gr.Textbox(label="Enter Prompt", scale=3)
264
- quantization_choice_radio = gr.Radio(
265
- choices=["8-bit", "4-bit"],
266
- label="Select Quantization",
267
- value="8-bit", # Default choice
268
- scale=1
269
- )
270
- generate_button = gr.Button("Generate & Compare", variant="primary", scale=1)
271
-
272
- output_gallery = gr.Gallery(
273
- label="Generated Images",
274
- columns=2,
275
- height=512,
276
- object_fit="contain",
277
- allow_preview=True,
278
- show_label=True,
279
- )
280
-
281
- gr.Markdown("### Which image used the selected quantization method?")
282
- with gr.Row():
283
- image1_btn = gr.Button("Image 1")
284
- image2_btn = gr.Button("Image 2")
285
-
286
- feedback_box = gr.Textbox(label="Feedback", interactive=False, lines=1)
287
-
288
- # Hidden state to store the correct mapping after shuffling
289
- # e.g., {0: 'Original', 1: 'Quantized (8-bit)'} or {0: 'Quantized (4-bit)', 1: 'Original'}
290
- correct_mapping_state = gr.State({})
291
-
292
- def _load_example(sel):
293
- idx = int(sel.split()[-1]) - 1
294
- return load_example(idx)
295
-
296
- ex_selector.change(
297
- fn=_load_example,
298
- inputs=ex_selector,
299
- outputs=[output_gallery, correct_mapping_state, prompt_input],
300
- ).then(
301
- lambda: (gr.update(interactive=True), gr.update(interactive=True)),
302
- outputs=[image1_btn, image2_btn],
303
- )
304
-
305
- generate_button.click(
306
- fn=generate_images,
307
- inputs=[prompt_input, quantization_choice_radio],
308
- outputs=[output_gallery, correct_mapping_state] #, feedback_box],
309
- ).then(
310
- lambda: (gr.update(interactive=True),
311
- gr.update(interactive=True),
312
- ""), # clear feedback
313
- outputs=[image1_btn, image2_btn, feedback_box],
314
- )
315
-
316
- def choose(choice_string, mapping):
317
- feedback = check_guess(choice_string, mapping)
318
- return feedback, gr.update(interactive=False), gr.update(interactive=False)
319
-
320
- image1_btn.click(
321
- fn=lambda mapping: choose("Image 1", mapping),
322
- inputs=[correct_mapping_state],
323
- outputs=[feedback_box, image1_btn, image2_btn],
324
- )
325
- image2_btn.click(
326
- fn=lambda mapping: choose("Image 2", mapping),
327
- inputs=[correct_mapping_state],
328
- outputs=[feedback_box, image1_btn, image2_btn],
329
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
 
331
  if __name__ == "__main__":
332
  demo.launch(share=True)
 
8
  from PIL import Image
9
  import os
10
  import time
11
+ import json
12
+ from fasteners import InterProcessLock
13
  import spaces
14
 
15
+ AGG_FILE = Path(__file__).parent / "agg_stats.json"
16
+ LOCK_FILE = AGG_FILE.with_suffix(".lock")
17
+
18
+ def _load_agg_stats() -> dict:
19
+ if AGG_FILE.exists():
20
+ with open(AGG_FILE, "r") as f:
21
+ try:
22
+ return json.load(f)
23
+ except json.JSONDecodeError:
24
+ print(f"Warning: {AGG_FILE} is corrupted. Starting with empty stats.")
25
+ return {"8-bit": {"attempts": 0, "correct": 0}, "4-bit": {"attempts": 0, "correct": 0}}
26
+ return {"8-bit": {"attempts": 0, "correct": 0},
27
+ "4-bit": {"attempts": 0, "correct": 0}}
28
+
29
def _save_agg_stats(stats: dict) -> None:
    """Persist aggregate stats to AGG_FILE while holding the inter-process lock."""
    with InterProcessLock(str(LOCK_FILE)), open(AGG_FILE, "w") as f:
        json.dump(stats, f, indent=2)
33
+
34
+ USER_STATS_FILE = Path(__file__).parent / "user_stats.json"
35
+ USER_STATS_LOCK_FILE = USER_STATS_FILE.with_suffix(".lock")
36
+
37
+ def _load_user_stats() -> dict:
38
+ if USER_STATS_FILE.exists():
39
+ with open(USER_STATS_FILE, "r") as f:
40
+ try:
41
+ return json.load(f)
42
+ except json.JSONDecodeError:
43
+ print(f"Warning: {USER_STATS_FILE} is corrupted. Starting with empty user stats.")
44
+ return {}
45
+ return {}
46
+
47
def _save_user_stats(stats: dict) -> None:
    """Persist per-user stats to USER_STATS_FILE while holding its lock."""
    with InterProcessLock(str(USER_STATS_LOCK_FILE)), open(USER_STATS_FILE, "w") as f:
        json.dump(stats, f, indent=2)
51
+
52
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
53
  print(f"Using device: {DEVICE}")
54
 
 
57
  DEFAULT_GUIDANCE_SCALE = 3.5
58
  DEFAULT_NUM_INFERENCE_STEPS = 15
59
  DEFAULT_MAX_SEQUENCE_LENGTH = 512
 
60
  HF_TOKEN = os.environ.get("HF_ACCESS_TOKEN")
61
 
62
  CACHED_PIPES = {}
63
  def load_bf16_pipeline():
 
64
  print("Loading BF16 pipeline...")
65
  MODEL_ID = "black-forest-labs/FLUX.1-dev"
66
  if MODEL_ID in CACHED_PIPES:
 
73
  token=HF_TOKEN
74
  )
75
  pipe.to(DEVICE)
 
76
  end_time = time.time()
77
  mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
78
  print(f"BF16 Pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
 
80
  return pipe
81
  except Exception as e:
82
  print(f"Error loading BF16 pipeline: {e}")
83
+ raise
84
 
85
  def load_bnb_8bit_pipeline():
 
86
  print("Loading 8-bit BNB pipeline...")
87
  MODEL_ID = "derekl35/FLUX.1-dev-bnb-8bit"
88
  if MODEL_ID in CACHED_PIPES:
 
105
  raise
106
 
107
  def load_bnb_4bit_pipeline():
 
108
  print("Loading 4-bit BNB pipeline...")
109
  MODEL_ID = "derekl35/FLUX.1-dev-nf4"
110
  if MODEL_ID in CACHED_PIPES:
 
123
  CACHED_PIPES[MODEL_ID] = pipe
124
  return pipe
125
  except Exception as e:
126
+ print(f"Error loading 4-bit BNB pipeline: {e}")
127
  raise
128
 
129
  @spaces.GPU(duration=240)
130
  def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm=True)):
 
131
  if not prompt:
132
+ return None, {}, gr.update(value="Please enter a prompt.", interactive=False), gr.update(choices=[], value=None), gr.update(interactive=True), gr.update(interactive=True)
133
 
134
  if not quantization_choice:
135
+ return None, {}, gr.update(value="Please select a quantization method.", interactive=False), gr.update(choices=[], value=None), gr.update(interactive=True), gr.update(interactive=True)
 
136
 
 
137
  if quantization_choice == "8-bit":
138
  quantized_load_func = load_bnb_8bit_pipeline
139
  quantized_label = "Quantized (8-bit)"
 
141
  quantized_load_func = load_bnb_4bit_pipeline
142
  quantized_label = "Quantized (4-bit)"
143
  else:
144
+ return None, {}, gr.update(value="Invalid quantization choice.", interactive=False), gr.update(choices=[], value=None), gr.update(interactive=True), gr.update(interactive=True)
 
145
 
146
  model_configs = [
147
  ("Original", load_bf16_pipeline),
148
+ (quantized_label, quantized_load_func),
149
  ]
150
 
151
  results = []
 
158
  "max_sequence_length": DEFAULT_MAX_SEQUENCE_LENGTH,
159
  }
160
 
 
 
161
  seed = random.getrandbits(64)
162
  print(f"Using seed: {seed}")
163
 
 
175
  gen_start_time = time.time()
176
  image_list = current_pipe(**pipe_kwargs, generator=torch.manual_seed(seed)).images
177
  image = image_list[0]
 
178
  gen_end_time = time.time()
179
  results.append({"label": label, "image": image})
180
  print(f"--- Finished Generation with {label} Model in {gen_end_time - gen_start_time:.2f} seconds ---")
 
183
 
184
  except Exception as e:
185
  print(f"Error during {label} model processing: {e}")
186
+ return None, {}, gr.update(value=f"Error processing {label} model: {e}", interactive=False), gr.update(choices=[], value=None), gr.update(interactive=True), gr.update(interactive=True)
 
187
 
 
188
 
189
  if len(results) != len(model_configs):
190
+ return None, {}, gr.update(value="Failed to generate images for all model types.", interactive=False), gr.update(choices=[], value=None), gr.update(interactive=True), gr.update(interactive=True)
 
 
191
 
 
192
  shuffled_results = results.copy()
193
  random.shuffle(shuffled_results)
 
 
194
  shuffled_data_for_gallery = [(res["image"], f"Image {i+1}") for i, res in enumerate(shuffled_results)]
 
 
195
  correct_mapping = {i: res["label"] for i, res in enumerate(shuffled_results)}
196
  print("Correct mapping (hidden):", correct_mapping)
197
 
198
+ return shuffled_data_for_gallery, correct_mapping, "Generation complete! Make your guess.", None, gr.update(interactive=True), gr.update(interactive=True)
199
 
 
 
200
 
 
 
201
def check_guess(user_guess, correct_mapping_state):
    """Return feedback text for the user's guess against the shuffled mapping.

    Args:
        user_guess: "Image 1" / "Image 2" as chosen by the user, or None.
        correct_mapping_state: dict mapping display slot (0 or 1) to the
            hidden label, e.g. {0: "Original", 1: "Quantized (8-bit)"}.
    """
    if not isinstance(correct_mapping_state, dict) or not correct_mapping_state:
        return "Please generate images first (state is empty or invalid)."
    if user_guess is None:
        return "Please select which image you think is quantized."

    # Locate the display slot holding the quantized image.
    quantized_entry = next(
        ((idx, lbl) for idx, lbl in correct_mapping_state.items() if "Quantized" in lbl),
        None,
    )
    if quantized_entry is None:
        return "Error: Could not find the quantized image in the mapping data."
    slot, actual_label = quantized_entry

    expected_choice = f"Image {slot + 1}"
    if user_guess == expected_choice:
        return f"Correct! {expected_choice} used the {actual_label} model."
    return f"Incorrect. The quantized image ({actual_label}) was {expected_choice}."
223
 
224
  EXAMPLE_DIR = Path(__file__).parent / "examples"
 
226
  {
227
  "prompt": "A photorealistic portrait of an astronaut on Mars",
228
  "files": ["astronauts_seed_6456306350371904162.png", "astronauts_bnb_8bit.png"],
229
+ "quantized_idx": 1,
230
+ "quant_method": "bnb 8-bit",
231
  },
232
  {
233
  "prompt": "Water-color painting of a cat wearing sunglasses",
234
  "files": ["watercolor_cat_bnb_8bit.png", "watercolor_cat_seed_14269059182221286790.png"],
235
  "quantized_idx": 0,
236
+ "quant_method": "bnb 8-bit",
237
  },
238
  # {
239
  # "prompt": "Neo-tokyo cyberpunk cityscape at night, rain-soaked streets, 8-K",
 
243
  ]
244
 
245
def load_example(idx):
    """Load a pre-generated example: gallery items, answer mapping, prompt text."""
    example = EXAMPLES[idx]
    images = [Image.open(EXAMPLE_DIR / name) for name in example["files"]]
    gallery_items = [(img, f"Image {pos + 1}") for pos, img in enumerate(images)]
    mapping = {
        pos: (f"Quantized {example['quant_method']}" if pos == example["quantized_idx"]
              else "Original")
        for pos in range(2)
    }
    return gallery_items, mapping, f"{example['prompt']}"
252
 
253
+ def _accuracy_string(correct: int, attempts: int) -> tuple[str, float]:
254
+ if attempts:
255
+ pct = 100 * correct / attempts
256
+ return f"{pct:.1f}%", pct
257
+ return "N/A", -1.0
258
+
259
+ def _add_medals(user_rows):
260
+ MEDALS = {0: "🥇 ", 1: "🥈 ", 2: "🥉 "}
261
+ return [
262
+ [MEDALS.get(i, "") + row[0], *row[1:]]
263
+ for i, row in enumerate(user_rows)
264
+ ]
265
+
266
def update_leaderboards_data():
    """Build the row data for the two leaderboard DataFrames.

    Returns:
        (quant_rows, user_rows): quant_rows sorted by detectability ascending
        (hardest-to-spot method first, no-attempt methods last); user_rows
        sorted by accuracy descending with total attempts as tiebreaker and
        medals prepended to the top three names.
    """
    agg = _load_agg_stats()
    quant_rows = []
    for method, stats in agg.items():
        acc_str, _ = _accuracy_string(stats["correct"], stats["attempts"])
        quant_rows.append([method, stats["correct"], stats["attempts"], acc_str])
    # Lower detectability first; methods with no attempts sink to the bottom.
    quant_rows.sort(key=lambda r: r[1] / r[2] if r[2] else float("inf"))

    user_stats = _load_user_stats()
    keyed_rows = []
    for user, st in user_stats.items():
        acc_str, acc_val = _accuracy_string(st["total_correct"], st["total_attempts"])
        # Carry the numeric accuracy alongside the row instead of re-parsing
        # the formatted percentage string for sorting (the string round-trip
        # was fragile, and its "N/A" branch sorted no-attempt users to the
        # TOP of the board instead of the bottom).
        keyed_rows.append((acc_val, [user, st["total_correct"], st["total_attempts"], acc_str]))
    # Highest accuracy first, then most attempts; "N/A" (acc_val == -1.0) last.
    keyed_rows.sort(key=lambda item: (-item[0], -item[1][2]))
    user_rows = _add_medals([row for _, row in keyed_rows])

    return quant_rows, user_rows
288
+
289
  with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as demo:
290
  gr.Markdown("# FLUX Model Quantization Challenge")
291
+ with gr.Tabs():
292
+ with gr.TabItem("Challenge"):
293
+ gr.Markdown(
294
+ "Compare the original FLUX.1-dev (BF16) model against a quantized version (4-bit or 8-bit). "
295
+ "Enter a prompt, choose the quantization method, and generate two images. "
296
+ "The images will be shuffled, can you spot which one was quantized?"
297
+ )
298
+
299
+ gr.Markdown("### Examples")
300
+ ex_selector = gr.Radio(
301
+ choices=[f"Example {i+1}" for i in range(len(EXAMPLES))],
302
+ label="Choose an example prompt",
303
+ interactive=True,
304
+ )
305
+ gr.Markdown("### …or create your own comparison")
306
+ with gr.Row():
307
+ prompt_input = gr.Textbox(label="Enter Prompt", scale=3)
308
+ quantization_choice_radio = gr.Radio(
309
+ choices=["8-bit", "4-bit"],
310
+ label="Select Quantization",
311
+ value="8-bit",
312
+ scale=1
313
+ )
314
+ generate_button = gr.Button("Generate & Compare", variant="primary", scale=1)
315
+
316
+ output_gallery = gr.Gallery(
317
+ label="Generated Images",
318
+ columns=2,
319
+ height=606,
320
+ object_fit="contain",
321
+ allow_preview=True,
322
+ show_label=True,
323
+ )
324
+
325
+ gr.Markdown("### Which image used the selected quantization method?")
326
+ with gr.Row():
327
+ image1_btn = gr.Button("Image 1")
328
+ image2_btn = gr.Button("Image 2")
329
+
330
+ feedback_box = gr.Textbox(label="Feedback", interactive=False, lines=1)
331
+
332
+ with gr.Row():
333
+ session_score_box = gr.Textbox(label="Your accuracy this session", interactive=False)
334
+
335
+ with gr.Row(equal_height=False):
336
+ username_input = gr.Textbox(
337
+ label="Enter Your Name for Leaderboard",
338
+ placeholder="YourName",
339
+ visible=False,
340
+ interactive=True,
341
+ scale=2
342
+ )
343
+ add_score_button = gr.Button(
344
+ "Add My Score to Leaderboard",
345
+ visible=False,
346
+ variant="secondary",
347
+ scale=1
348
+ )
349
+ add_score_feedback = gr.Textbox(
350
+ label="Leaderboard Update",
351
+ visible=False,
352
+ interactive=False,
353
+ lines=1
354
+ )
355
+
356
+ correct_mapping_state = gr.State({})
357
+ session_stats_state = gr.State(
358
+ {"8-bit": {"attempts": 0, "correct": 0},
359
+ "4-bit": {"attempts": 0, "correct": 0}}
360
+ )
361
+ is_example_state = gr.State(False)
362
+ has_added_score_state = gr.State(False)
363
+
364
+ def _load_example(sel):
365
+ idx = int(sel.split()[-1]) - 1
366
+ gallery_items, mapping, prompt = load_example(idx)
367
+ quant_data, user_data = update_leaderboards_data()
368
+ return gallery_items, mapping, prompt, True, quant_data, user_data
369
+
370
+ ex_selector.change(
371
+ fn=_load_example,
372
+ inputs=ex_selector,
373
+ outputs=[output_gallery, correct_mapping_state, prompt_input, is_example_state, quant_df, user_df],
374
+ ).then(
375
+ lambda: (gr.update(interactive=True), gr.update(interactive=True)),
376
+ outputs=[image1_btn, image2_btn],
377
+ )
378
+
379
+ generate_button.click(
380
+ fn=generate_images,
381
+ inputs=[prompt_input, quantization_choice_radio],
382
+ outputs=[output_gallery, correct_mapping_state]
383
+ ).then(
384
+ lambda: (False, # for is_example_state
385
+ False, # for has_added_score_state
386
+ gr.update(visible=False, value="", interactive=True), # username_input reset
387
+ gr.update(visible=False), # add_score_button reset
388
+ gr.update(visible=False, value="")), # add_score_feedback reset
389
+ outputs=[is_example_state,
390
+ has_added_score_state,
391
+ username_input,
392
+ add_score_button,
393
+ add_score_feedback]
394
+ ).then(
395
+ lambda: (gr.update(interactive=True),
396
+ gr.update(interactive=True),
397
+ ""),
398
+ outputs=[image1_btn, image2_btn, feedback_box],
399
+ )
400
+
401
+ def choose(choice_string, mapping, session_stats, is_example, has_added_score_curr):
402
+ feedback = check_guess(choice_string, mapping)
403
+
404
+ quant_label = next(label for label in mapping.values() if "Quantized" in label)
405
+ quant_key = "8-bit" if "8-bit" in quant_label else "4-bit"
406
+
407
+ got_it_right = "Correct!" in feedback
408
+
409
+ sess = session_stats.copy()
410
+ if not is_example and not has_added_score_curr:
411
+ sess[quant_key]["attempts"] += 1
412
+ if got_it_right:
413
+ sess[quant_key]["correct"] += 1
414
+ session_stats = sess
415
+
416
+ AGG_STATS = _load_agg_stats()
417
+ AGG_STATS[quant_key]["attempts"] += 1
418
+ if got_it_right:
419
+ AGG_STATS[quant_key]["correct"] += 1
420
+ _save_agg_stats(AGG_STATS)
421
+
422
+ def _fmt(d):
423
+ a, c = d["attempts"], d["correct"]
424
+ pct = 100 * c / a if a else 0
425
+ return f"{c} / {a} ({pct:.1f}%)"
426
+
427
+ session_msg = ", ".join(
428
+ f"{k}: {_fmt(v)}" for k, v in sess.items()
429
+ )
430
+ current_agg_stats = _load_agg_stats()
431
+ global_msg = ", ".join(
432
+ f"{k}: {_fmt(v)}" for k, v in current_agg_stats.items()
433
+ )
434
+
435
+ username_input_update = gr.update(visible=False, interactive=True)
436
+ add_score_button_update = gr.update(visible=False)
437
+ # Keep existing feedback if score was already added and feedback is visible
438
+ current_feedback_text = add_score_feedback.value if hasattr(add_score_feedback, 'value') and add_score_feedback.value else ""
439
+ add_score_feedback_update = gr.update(visible=has_added_score_curr, value=current_feedback_text)
440
+
441
+ session_total_attempts = sum(stats["attempts"] for stats in sess.values())
442
+
443
+ if not is_example and not has_added_score_curr:
444
+ if session_total_attempts >= 1 : # Show button if more than 1 attempt
445
+ username_input_update = gr.update(visible=True, interactive=True)
446
+ add_score_button_update = gr.update(visible=True, interactive=True)
447
+ add_score_feedback_update = gr.update(visible=False, value="")
448
+ else: # Less than 1 attempts, keep hidden
449
+ username_input_update = gr.update(visible=False, value=username_input.value if hasattr(username_input, 'value') else "")
450
+ add_score_button_update = gr.update(visible=False)
451
+ add_score_feedback_update = gr.update(visible=False, value="")
452
+ elif has_added_score_curr:
453
+ username_input_update = gr.update(visible=True, interactive=False, value=username_input.value if hasattr(username_input, 'value') else "")
454
+ add_score_button_update = gr.update(visible=True, interactive=False)
455
+ add_score_feedback_update = gr.update(visible=True)
456
+
457
+ # disable the buttons so the user can't vote twice
458
+ quant_data, user_data = update_leaderboards_data() # Get updated leaderboard data
459
+ return (feedback,
460
+ gr.update(interactive=False),
461
+ gr.update(interactive=False),
462
+ session_msg,
463
+ session_stats,
464
+ quant_data,
465
+ user_data,
466
+ username_input_update,
467
+ add_score_button_update,
468
+ add_score_feedback_update)
469
+
470
+
471
+ image1_btn.click(
472
+ fn=lambda mapping, sess, is_ex, has_added: choose("Image 1", mapping, sess, is_ex, has_added),
473
+ inputs=[correct_mapping_state, session_stats_state, is_example_state, has_added_score_state],
474
+ outputs=[feedback_box, image1_btn, image2_btn,
475
+ session_score_box, session_stats_state,
476
+ quant_df, user_df,
477
+ username_input, add_score_button, add_score_feedback],
478
+ )
479
+ image2_btn.click(
480
+ fn=lambda mapping, sess, is_ex, has_added: choose("Image 2", mapping, sess, is_ex, has_added),
481
+ inputs=[correct_mapping_state, session_stats_state, is_example_state, has_added_score_state],
482
+ outputs=[feedback_box, image1_btn, image2_btn,
483
+ session_score_box, session_stats_state,
484
+ quant_df, user_df,
485
+ username_input, add_score_button, add_score_feedback],
486
+ )
487
+
488
+ def handle_add_score_to_leaderboard(username_str, current_session_stats_dict):
489
+ if not username_str or not username_str.strip():
490
+ return ("Username is required.", # Feedback for add_score_feedback
491
+ gr.update(interactive=True), # username_input
492
+ gr.update(interactive=True), # add_score_button
493
+ False, # has_added_score_state
494
+ None, None) # quant_df, user_df
495
+
496
+ user_stats = _load_user_stats()
497
+ user_key = username_str.strip()
498
+
499
+ session_total_correct = sum(stats["correct"] for stats in current_session_stats_dict.values())
500
+ session_total_attempts = sum(stats["attempts"] for stats in current_session_stats_dict.values())
501
+
502
+ if session_total_attempts == 0:
503
+ return ("No attempts made in this session to add to leaderboard.",
504
+ gr.update(interactive=True),
505
+ gr.update(interactive=True),
506
+ False, None, None)
507
+
508
+ if user_key in user_stats:
509
+ user_stats[user_key]["total_correct"] += session_total_correct
510
+ user_stats[user_key]["total_attempts"] += session_total_attempts
511
+ else:
512
+ user_stats[user_key] = {
513
+ "total_correct": session_total_correct,
514
+ "total_attempts": session_total_attempts
515
+ }
516
+ _save_user_stats(user_stats)
517
+
518
+ new_quant_data, new_user_data = update_leaderboards_data()
519
+ feedback_msg = f"Score for '{user_key}' submitted to leaderboard!"
520
+ return (feedback_msg, # To add_score_feedback
521
+ gr.update(interactive=False), # username_input
522
+ gr.update(interactive=False), # add_score_button
523
+ True, # has_added_score_state (set to true)
524
+ new_quant_data, # To quant_df
525
+ new_user_data) # To user_df
526
+
527
+ add_score_button.click(
528
+ fn=handle_add_score_to_leaderboard,
529
+ inputs=[username_input, session_stats_state],
530
+ outputs=[add_score_feedback, username_input, add_score_button, has_added_score_state, quant_df, user_df]
531
+ )
532
+ with gr.TabItem("Leaderboard"):
533
+ gr.Markdown("## Quantization Method Leaderboard *(Lower % ⇒ harder to detect)*")
534
+ quant_df = gr.DataFrame(
535
+ headers=["Method", "Correct Guesses", "Total Attempts", "Detectability %"],
536
+ interactive=False, col_count=(4, "fixed")
537
+ )
538
+ gr.Markdown("## User Leaderboard *(Higher % ⇒ better spotter)*")
539
+ user_df = gr.DataFrame(
540
+ headers=["User", "Correct Guesses", "Total Attempts", "Accuracy %"],
541
+ interactive=False, col_count=(4, "fixed")
542
+ )
543
+ demo.load(update_leaderboards_data, outputs=[quant_df, user_df])
544
 
545
  if __name__ == "__main__":
546
  demo.launch(share=True)