huaweilin commited on
Commit
7d3842f
·
1 Parent(s): faff65d
Files changed (2) hide show
  1. app.py +17 -10
  2. src/model_processing.py +12 -12
app.py CHANGED
@@ -78,23 +78,30 @@ placeholder_image = Image.new("RGBA", (512, 512), (0, 0, 0, 0))
78
 
79
  @spaces.GPU
80
  def process_selected_models(uploaded_image, selected_models):
81
- results = []
 
 
82
  for model_name in model_name_mapping:
 
 
83
  if uploaded_image is None:
84
- results.append(gr.update(value=placeholder_image, label=f"{model_name_mapping[model_name]} (No input)"))
85
  elif model_name in selected_models:
86
  try:
87
  model, data_params = model_dict[model_name]
88
  pixel_values = pil_to_tensor(uploaded_image, **data_params).unsqueeze(0).to(device)
89
  output = model(pixel_values)[0]
90
  reconstructed_image = tensor_to_pil(output[0].cpu(), **data_params)
91
- results.append(gr.update(value=reconstructed_image, label=model_name_mapping[model_name]))
92
  except Exception as e:
93
  print(f"Error in model {model_name}: {e}")
94
- results.append(gr.update(value=placeholder_image, label=f"{model_name_mapping[model_name]} (Error)"))
 
95
  else:
96
- results.append(gr.update(value=placeholder_image, label=f"{model_name_mapping[model_name]} (Not selected)"))
97
- return results
 
 
98
 
99
  with gr.Blocks() as demo:
100
  gr.Markdown("## VTBench")
@@ -140,15 +147,15 @@ with gr.Blocks() as demo:
140
  run_button = gr.Button("Start Processing")
141
 
142
  image_outputs = []
143
- model_items = list(model_name_mapping.items())
144
-
145
  n_columns = 5
146
- output_rows = [model_items[i:i+n_columns] for i in range(0, len(model_items), n_columns)]
147
 
148
  with gr.Column():
149
  for row in output_rows:
150
  with gr.Row():
151
- for model_name, display_name in row:
 
152
  out_img = gr.Image(
153
  label=f"{display_name} (Not run)",
154
  value=placeholder_image,
 
78
 
79
  @spaces.GPU
80
  def process_selected_models(uploaded_image, selected_models):
81
+ selected_results = []
82
+ placeholder_results = []
83
+
84
  for model_name in model_name_mapping:
85
+ label = model_name_mapping[model_name]
86
+
87
  if uploaded_image is None:
88
+ result = gr.update(value=placeholder_image, label=f"{label} (No input)")
89
  elif model_name in selected_models:
90
  try:
91
  model, data_params = model_dict[model_name]
92
  pixel_values = pil_to_tensor(uploaded_image, **data_params).unsqueeze(0).to(device)
93
  output = model(pixel_values)[0]
94
  reconstructed_image = tensor_to_pil(output[0].cpu(), **data_params)
95
+ result = gr.update(value=reconstructed_image, label=label)
96
  except Exception as e:
97
  print(f"Error in model {model_name}: {e}")
98
+ result = gr.update(value=placeholder_image, label=f"{label} (Error)")
99
+ selected_results.append(result)
100
  else:
101
+ result = gr.update(value=placeholder_image, label=f"{label} (Not selected)")
102
+ placeholder_results.append(result)
103
+
104
+ return selected_results + placeholder_results
105
 
106
  with gr.Blocks() as demo:
107
  gr.Markdown("## VTBench")
 
147
  run_button = gr.Button("Start Processing")
148
 
149
  image_outputs = []
150
+ model_names_ordered = list(model_name_mapping.keys())
 
151
  n_columns = 5
152
+ output_rows = [model_names_ordered[i:i+n_columns] for i in range(0, len(model_names_ordered), n_columns)]
153
 
154
  with gr.Column():
155
  for row in output_rows:
156
  with gr.Row():
157
+ for model_name in row:
158
+ display_name = model_name_mapping[model_name]
159
  out_img = gr.Image(
160
  label=f"{display_name} (Not run)",
161
  value=placeholder_image,
src/model_processing.py CHANGED
@@ -279,8 +279,8 @@ def get_model(model_path, model_name):
279
  data_params = {
280
  "target_image_size": (384, 384),
281
  "lock_ratio": True,
282
- "center_crop": False,
283
- "padding": True,
284
  }
285
 
286
  elif "var" in model_name.lower():
@@ -307,8 +307,8 @@ def get_model(model_path, model_name):
307
  (512, 512) if "512" in model_name.lower() else (256, 256)
308
  ),
309
  "lock_ratio": True,
310
- "center_crop": False,
311
- "padding": True,
312
  "standardize": False,
313
  }
314
 
@@ -349,8 +349,8 @@ def get_model(model_path, model_name):
349
  data_params = {
350
  "target_image_size": (1024, 1024),
351
  "lock_ratio": True,
352
- "center_crop": False,
353
- "padding": True,
354
  "standardize": False,
355
  }
356
 
@@ -367,8 +367,8 @@ def get_model(model_path, model_name):
367
  data_params = {
368
  "target_image_size": (1024, 1024),
369
  "lock_ratio": True,
370
- "center_crop": False,
371
- "padding": True,
372
  "standardize": True,
373
  }
374
 
@@ -385,8 +385,8 @@ def get_model(model_path, model_name):
385
  data_params = {
386
  "target_image_size": (1024, 1024),
387
  "lock_ratio": True,
388
- "center_crop": False,
389
- "padding": True,
390
  "standardize": True,
391
  }
392
 
@@ -396,8 +396,8 @@ def get_model(model_path, model_name):
396
  data_params = {
397
  "target_image_size": (1024, 1024),
398
  "lock_ratio": True,
399
- "center_crop": False,
400
- "padding": True,
401
  "standardize": False,
402
  }
403
  model = GPTImage(data_params)
 
279
  data_params = {
280
  "target_image_size": (384, 384),
281
  "lock_ratio": True,
282
+ "center_crop": True,
283
+ "padding": False,
284
  }
285
 
286
  elif "var" in model_name.lower():
 
307
  (512, 512) if "512" in model_name.lower() else (256, 256)
308
  ),
309
  "lock_ratio": True,
310
+ "center_crop": True,
311
+ "padding": False,
312
  "standardize": False,
313
  }
314
 
 
349
  data_params = {
350
  "target_image_size": (1024, 1024),
351
  "lock_ratio": True,
352
+ "center_crop": True,
353
+ "padding": False,
354
  "standardize": False,
355
  }
356
 
 
367
  data_params = {
368
  "target_image_size": (1024, 1024),
369
  "lock_ratio": True,
370
+ "center_crop": True,
371
+ "padding": False,
372
  "standardize": True,
373
  }
374
 
 
385
  data_params = {
386
  "target_image_size": (1024, 1024),
387
  "lock_ratio": True,
388
+ "center_crop": True,
389
+ "padding": False,
390
  "standardize": True,
391
  }
392
 
 
396
  data_params = {
397
  "target_image_size": (1024, 1024),
398
  "lock_ratio": True,
399
+ "center_crop": True,
400
+ "padding": False,
401
  "standardize": False,
402
  }
403
  model = GPTImage(data_params)