huaweilin commited on
Commit
b460782
·
1 Parent(s): 7d3842f
Files changed (1) hide show
  1. app.py +28 -25
app.py CHANGED
@@ -1,19 +1,5 @@
1
  import os
2
  import spaces
3
- import subprocess
4
- import sys
5
-
6
- # REQUIREMENTS_FILE = "requirements.txt"
7
- # if os.path.exists(REQUIREMENTS_FILE):
8
- # try:
9
- # print("Installing dependencies from requirements.txt...")
10
- # subprocess.check_call([sys.executable, "-m", "pip", "install", "-r", REQUIREMENTS_FILE])
11
- # print("Dependencies installed successfully.")
12
- # except subprocess.CalledProcessError as e:
13
- # print(f"Failed to install dependencies: {e}")
14
- # else:
15
- # print("requirements.txt not found.")
16
-
17
  import gradio as gr
18
  from src.data_processing import pil_to_tensor, tensor_to_pil
19
  from PIL import Image
@@ -63,6 +49,8 @@ model_name_mapping = {
63
  "bsqvit": "BSQ-VIT",
64
  }
65
 
 
 
66
  def load_model(model_name):
67
  model, data_params = get_model(MODEL_DIR, model_name)
68
  model = model.to(device)
@@ -77,16 +65,24 @@ model_dict = {
77
  placeholder_image = Image.new("RGBA", (512, 512), (0, 0, 0, 0))
78
 
79
  @spaces.GPU
80
- def process_selected_models(uploaded_image, selected_models):
 
 
 
 
 
 
 
 
81
  selected_results = []
82
  placeholder_results = []
83
 
 
 
84
  for model_name in model_name_mapping:
85
  label = model_name_mapping[model_name]
86
 
87
- if uploaded_image is None:
88
- result = gr.update(value=placeholder_image, label=f"{label} (No input)")
89
- elif model_name in selected_models:
90
  try:
91
  model, data_params = model_dict[model_name]
92
  pixel_values = pil_to_tensor(uploaded_image, **data_params).unsqueeze(0).to(device)
@@ -101,11 +97,11 @@ def process_selected_models(uploaded_image, selected_models):
101
  result = gr.update(value=placeholder_image, label=f"{label} (Not selected)")
102
  placeholder_results.append(result)
103
 
104
- return selected_results + placeholder_results
 
105
 
106
  with gr.Blocks() as demo:
107
  gr.Markdown("## VTBench")
108
-
109
  gr.Markdown("---")
110
 
111
  image_input = gr.Image(
@@ -132,18 +128,23 @@ with gr.Blocks() as demo:
132
  def load_img():
133
  return Image.open(p)
134
  return load_img
135
-
136
  ex_img.select(fn=make_loader(), outputs=image_input)
137
 
138
  gr.Markdown("---")
139
-
140
  gr.Markdown("⚠️ **The more models you select, the longer the processing time will be.**")
 
 
 
 
141
  model_selector = gr.CheckboxGroup(
142
- choices=list(model_name_mapping.keys()),
143
  label="Select models to run",
144
- value=["SD3.5L", "chameleon", "janus_pro_1b"],
145
  interactive=True,
146
  )
 
 
147
  run_button = gr.Button("Start Processing")
148
 
149
  image_outputs = []
@@ -164,10 +165,12 @@ with gr.Blocks() as demo:
164
  )
165
  image_outputs.append(out_img)
166
 
 
167
  run_button.click(
168
  fn=process_selected_models,
169
  inputs=[image_input, model_selector],
170
- outputs=image_outputs
171
  )
172
 
173
  demo.launch()
 
 
1
  import os
2
  import spaces
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import gradio as gr
4
  from src.data_processing import pil_to_tensor, tensor_to_pil
5
  from PIL import Image
 
49
  "bsqvit": "BSQ-VIT",
50
  }
51
 
52
+ display_to_internal = {v: k for k, v in model_name_mapping.items()}
53
+
54
  def load_model(model_name):
55
  model, data_params = get_model(MODEL_DIR, model_name)
56
  model = model.to(device)
 
65
  placeholder_image = Image.new("RGBA", (512, 512), (0, 0, 0, 0))
66
 
67
  @spaces.GPU
68
+ def process_selected_models(uploaded_image, selected_display_names):
69
+ if uploaded_image is None:
70
+ return [gr.update(value="⚠️ Please upload an image before processing.", visible=True)] + \
71
+ [gr.update() for _ in model_name_mapping]
72
+
73
+ if not selected_display_names:
74
+ return [gr.update(value="⚠️ Please select at least one model.", visible=True)] + \
75
+ [gr.update() for _ in model_name_mapping]
76
+
77
  selected_results = []
78
  placeholder_results = []
79
 
80
+ selected_internal = [display_to_internal[d] for d in selected_display_names]
81
+
82
  for model_name in model_name_mapping:
83
  label = model_name_mapping[model_name]
84
 
85
+ if model_name in selected_internal:
 
 
86
  try:
87
  model, data_params = model_dict[model_name]
88
  pixel_values = pil_to_tensor(uploaded_image, **data_params).unsqueeze(0).to(device)
 
97
  result = gr.update(value=placeholder_image, label=f"{label} (Not selected)")
98
  placeholder_results.append(result)
99
 
100
+ return [gr.update(visible=False)] + selected_results + placeholder_results
101
+
102
 
103
  with gr.Blocks() as demo:
104
  gr.Markdown("## VTBench")
 
105
  gr.Markdown("---")
106
 
107
  image_input = gr.Image(
 
128
  def load_img():
129
  return Image.open(p)
130
  return load_img
131
+
132
  ex_img.select(fn=make_loader(), outputs=image_input)
133
 
134
  gr.Markdown("---")
 
135
  gr.Markdown("⚠️ **The more models you select, the longer the processing time will be.**")
136
+
137
+ display_names = list(model_name_mapping.values())
138
+ default_selected = ["SD3.5L", "Chameleon", "Janus Pro 1B/7B"]
139
+
140
  model_selector = gr.CheckboxGroup(
141
+ choices=display_names,
142
  label="Select models to run",
143
+ value=default_selected,
144
  interactive=True,
145
  )
146
+
147
+ status_output = gr.Markdown("", visible=False)
148
  run_button = gr.Button("Start Processing")
149
 
150
  image_outputs = []
 
165
  )
166
  image_outputs.append(out_img)
167
 
168
+
169
  run_button.click(
170
  fn=process_selected_models,
171
  inputs=[image_input, model_selector],
172
+ outputs=[status_output] + image_outputs
173
  )
174
 
175
  demo.launch()
176
+