Spaces:
mashroo
/
Running on Zero

YoussefAnso committed on
Commit
bb9cdff
·
1 Parent(s): 732c53f

Refactor app.py by moving model initialization and argument parsing to the main execution block, enhancing clarity and organization. Correct typos in comments and variable names for improved readability. Update Gradio component configurations for better user experience.

Browse files
Files changed (1) hide show
  1. app.py +74 -95
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # Fixed version with proper error handling and compatibility
2
  import spaces
3
  import argparse
4
  import numpy as np
@@ -20,71 +19,10 @@ import argparse
20
  from model import CRM
21
  from inference import generate3d
22
 
23
- # Move model initialization into a function that will be called by workers
24
- def init_model():
25
- parser = argparse.ArgumentParser()
26
- parser.add_argument(
27
- "--stage1_config",
28
- type=str,
29
- default="configs/nf7_v3_SNR_rd_size_stroke.yaml",
30
- help="config for stage1",
31
- )
32
- parser.add_argument(
33
- "--stage2_config",
34
- type=str,
35
- default="configs/stage2-v2-snr.yaml",
36
- help="config for stage2",
37
- )
38
- parser.add_argument("--device", type=str, default="cuda")
39
- args = parser.parse_args(args=[]) # Fix: provide empty args list
40
-
41
- # Download model files
42
- crm_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="CRM.pth")
43
- specs = json.load(open("configs/specs_objaverse_total.json"))
44
- model = CRM(specs)
45
- model.load_state_dict(torch.load(crm_path, map_location="cpu"), strict=False)
46
- model = model.to(args.device)
47
-
48
- # Load configs
49
- stage1_config = OmegaConf.load(args.stage1_config).config
50
- stage2_config = OmegaConf.load(args.stage2_config).config
51
- stage2_sampler_config = stage2_config.sampler
52
- stage1_sampler_config = stage1_config.sampler
53
-
54
- stage1_model_config = stage1_config.models
55
- stage2_model_config = stage2_config.models
56
-
57
- xyz_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="ccm-diffusion.pth")
58
- pixel_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth")
59
- stage1_model_config.resume = pixel_path
60
- stage2_model_config.resume = xyz_path
61
-
62
- pipeline = TwoStagePipeline(
63
- stage1_model_config,
64
- stage2_model_config,
65
- stage1_sampler_config,
66
- stage2_sampler_config,
67
- device=args.device,
68
- dtype=torch.float32
69
- )
70
-
71
- return model, pipeline, args
72
-
73
- # Global variables to store model and pipeline
74
- model = None
75
  pipeline = None
76
- args = None
77
-
78
- @spaces.GPU
79
- def get_model():
80
- """Lazy initialization of model and pipeline"""
81
- global model, pipeline, args
82
- if model is None or pipeline is None:
83
- model, pipeline, args = init_model()
84
- return model, pipeline
85
-
86
  rembg_session = rembg.new_session()
87
 
 
88
  def expand_to_square(image, bg_color=(0, 0, 0, 0)):
89
  # expand image to 1:1
90
  width, height = image.size
@@ -97,10 +35,9 @@ def expand_to_square(image, bg_color=(0, 0, 0, 0)):
97
  return new_image
98
 
99
  def check_input_image(input_image):
100
- """Check if the input image is valid"""
101
  if input_image is None:
102
  raise gr.Error("No image uploaded!")
103
- return input_image
104
 
105
  def remove_background(
106
  image: PIL.Image.Image,
@@ -111,7 +48,7 @@ def remove_background(
111
  do_remove = True
112
  if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
113
  # explain why current do not rm bg
114
- print("alpha channel not empty, skip remove background, using alpha channel as mask")
115
  background = Image.new("RGBA", image.size, (0, 0, 0, 0))
116
  image = Image.alpha_composite(background, image)
117
  do_remove = False
@@ -121,7 +58,7 @@ def remove_background(
121
  return image
122
 
123
  def do_resize_content(original_image: Image, scale_rate):
124
- # resize image content while retaining the original image size
125
  if scale_rate != 1:
126
  # Calculate the new size after rescaling
127
  new_size = tuple(int(dim * scale_rate) for dim in original_image.size)
@@ -140,11 +77,6 @@ def add_background(image, bg_color=(255, 255, 255)):
140
  background = Image.new("RGBA", image.size, bg_color)
141
  return Image.alpha_composite(background, image)
142
 
143
- def add_random_background(image, color):
144
- # Add a random background to the image
145
- width, height = image.size
146
- background = Image.new("RGBA", image.size, color)
147
- return Image.alpha_composite(background, image)
148
 
149
  def preprocess_image(image, background_choice, foreground_ratio, backgroud_color):
150
  """
@@ -172,7 +104,53 @@ def gen_image(input_image, seed, scale, step):
172
  np_xyzs = np.concatenate(stage2_images, 1)
173
 
174
  glb_path = generate3d(model, np_imgs, np_xyzs, args.device)
175
- return Image.fromarray(np_imgs), Image.fromarray(np_xyzs), glb_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
177
  _DESCRIPTION = '''
178
  * Our [official implementation](https://github.com/thu-ml/CRM) uses UV texture instead of vertex color. It has better texture than this online demo.
@@ -180,7 +158,6 @@ _DESCRIPTION = '''
180
  * If you find the output unsatisfying, try using different seeds:)
181
  '''
182
 
183
- # Create the Gradio interface
184
  with gr.Blocks() as demo:
185
  gr.Markdown("# CRM: Single Image to 3D Textured Mesh with Convolutional Reconstruction Model")
186
  gr.Markdown(_DESCRIPTION)
@@ -198,11 +175,13 @@ with gr.Blocks() as demo:
198
  with gr.Column():
199
  with gr.Row():
200
  background_choice = gr.Radio([
201
- "Alpha as mask",
202
- "Auto Remove background"
203
- ], value="Auto Remove background",
204
- label="Background choice")
205
- back_ground_color = gr.ColorPicker(label="Background Color", value="#7F7F7F")
 
 
206
  foreground_ratio = gr.Slider(
207
  label="Foreground Ratio",
208
  minimum=0.5,
@@ -212,21 +191,21 @@ with gr.Blocks() as demo:
212
  )
213
 
214
  with gr.Column():
215
- seed = gr.Number(value=1234, label="Seed", precision=0)
216
- guidance_scale = gr.Number(value=5.5, minimum=3, maximum=10, label="Guidance scale")
217
- step = gr.Number(value=30, minimum=30, maximum=100, label="Sample steps", precision=0)
218
  text_button = gr.Button("Generate 3D shape")
219
- # if os.path.exists("examples") and os.listdir("examples"):
220
- # gr.Examples(
221
- # examples=[os.path.join("examples", i) for i in os.listdir("examples") if i.lower().endswith(('.png', '.jpg', '.jpeg'))],
222
- # inputs=[image_input],
223
- # examples_per_page=20,
224
- # )
225
  with gr.Column():
226
  image_output = gr.Image(interactive=False, label="Output RGB image")
227
- xyz_output = gr.Image(interactive=False, label="Output CCM image")
 
228
  output_model = gr.Model3D(
229
- label="Output GLB",
230
  interactive=False,
231
  )
232
  gr.Markdown("Note: Ensure that the input image is correctly pre-processed into a grey background, otherwise the results will be unpredictable.")
@@ -239,19 +218,19 @@ with gr.Blocks() as demo:
239
  ]
240
  outputs = [
241
  image_output,
242
- xyz_output,
243
  output_model,
 
244
  ]
245
 
 
246
  text_button.click(fn=check_input_image, inputs=[image_input]).success(
247
  fn=preprocess_image,
248
- inputs=[image_input, background_choice, foreground_ratio, back_ground_color],
249
  outputs=[processed_image],
250
  ).success(
251
  fn=gen_image,
252
  inputs=inputs,
253
  outputs=outputs,
254
  )
255
-
256
- if __name__ == "__main__":
257
- demo.queue().launch()
 
 
1
  import spaces
2
  import argparse
3
  import numpy as np
 
19
  from model import CRM
20
  from inference import generate3d
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  pipeline = None
 
 
 
 
 
 
 
 
 
 
23
  rembg_session = rembg.new_session()
24
 
25
+
26
  def expand_to_square(image, bg_color=(0, 0, 0, 0)):
27
  # expand image to 1:1
28
  width, height = image.size
 
35
  return new_image
36
 
37
  def check_input_image(input_image):
 
38
  if input_image is None:
39
  raise gr.Error("No image uploaded!")
40
+
41
 
42
  def remove_background(
43
  image: PIL.Image.Image,
 
48
  do_remove = True
49
  if image.mode == "RGBA" and image.getextrema()[3][0] < 255:
50
  # explain why current do not rm bg
51
+ print("alhpa channl not enpty, skip remove background, using alpha channel as mask")
52
  background = Image.new("RGBA", image.size, (0, 0, 0, 0))
53
  image = Image.alpha_composite(background, image)
54
  do_remove = False
 
58
  return image
59
 
60
  def do_resize_content(original_image: Image, scale_rate):
61
+ # resize image content wile retain the original image size
62
  if scale_rate != 1:
63
  # Calculate the new size after rescaling
64
  new_size = tuple(int(dim * scale_rate) for dim in original_image.size)
 
77
  background = Image.new("RGBA", image.size, bg_color)
78
  return Image.alpha_composite(background, image)
79
 
 
 
 
 
 
80
 
81
  def preprocess_image(image, background_choice, foreground_ratio, backgroud_color):
82
  """
 
104
  np_xyzs = np.concatenate(stage2_images, 1)
105
 
106
  glb_path = generate3d(model, np_imgs, np_xyzs, args.device)
107
+ return Image.fromarray(np_imgs), Image.fromarray(np_xyzs), glb_path#, obj_path
108
+
109
+
110
+ parser = argparse.ArgumentParser()
111
+ parser.add_argument(
112
+ "--stage1_config",
113
+ type=str,
114
+ default="configs/nf7_v3_SNR_rd_size_stroke.yaml",
115
+ help="config for stage1",
116
+ )
117
+ parser.add_argument(
118
+ "--stage2_config",
119
+ type=str,
120
+ default="configs/stage2-v2-snr.yaml",
121
+ help="config for stage2",
122
+ )
123
+
124
+ parser.add_argument("--device", type=str, default="cuda")
125
+ args = parser.parse_args()
126
+
127
+ crm_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="CRM.pth")
128
+ specs = json.load(open("configs/specs_objaverse_total.json"))
129
+ model = CRM(specs)
130
+ model.load_state_dict(torch.load(crm_path, map_location="cpu"), strict=False)
131
+ model = model.to(args.device)
132
+
133
+ stage1_config = OmegaConf.load(args.stage1_config).config
134
+ stage2_config = OmegaConf.load(args.stage2_config).config
135
+ stage2_sampler_config = stage2_config.sampler
136
+ stage1_sampler_config = stage1_config.sampler
137
+
138
+ stage1_model_config = stage1_config.models
139
+ stage2_model_config = stage2_config.models
140
+
141
+ xyz_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="ccm-diffusion.pth")
142
+ pixel_path = hf_hub_download(repo_id="Zhengyi/CRM", filename="pixel-diffusion.pth")
143
+ stage1_model_config.resume = pixel_path
144
+ stage2_model_config.resume = xyz_path
145
+
146
+ pipeline = TwoStagePipeline(
147
+ stage1_model_config,
148
+ stage2_model_config,
149
+ stage1_sampler_config,
150
+ stage2_sampler_config,
151
+ device=args.device,
152
+ dtype=torch.float32
153
+ )
154
 
155
  _DESCRIPTION = '''
156
  * Our [official implementation](https://github.com/thu-ml/CRM) uses UV texture instead of vertex color. It has better texture than this online demo.
 
158
  * If you find the output unsatisfying, try using different seeds:)
159
  '''
160
 
 
161
  with gr.Blocks() as demo:
162
  gr.Markdown("# CRM: Single Image to 3D Textured Mesh with Convolutional Reconstruction Model")
163
  gr.Markdown(_DESCRIPTION)
 
175
  with gr.Column():
176
  with gr.Row():
177
  background_choice = gr.Radio([
178
+ "Alpha as mask",
179
+ "Auto Remove background"
180
+ ], value="Auto Remove background",
181
+ label="backgroud choice")
182
+ # do_remove_background = gr.Checkbox(label=, value=True)
183
+ # force_remove = gr.Checkbox(label=, value=False)
184
+ back_groud_color = gr.ColorPicker(label="Background Color", value="#7F7F7F", interactive=False)
185
  foreground_ratio = gr.Slider(
186
  label="Foreground Ratio",
187
  minimum=0.5,
 
191
  )
192
 
193
  with gr.Column():
194
+ seed = gr.Number(value=1234, label="seed", precision=0)
195
+ guidance_scale = gr.Number(value=5.5, minimum=3, maximum=10, label="guidance_scale")
196
+ step = gr.Number(value=30, minimum=30, maximum=100, label="sample steps", precision=0)
197
  text_button = gr.Button("Generate 3D shape")
198
+ gr.Examples(
199
+ examples=[os.path.join("examples", i) for i in os.listdir("examples")],
200
+ inputs=[image_input],
201
+ examples_per_page = 20,
202
+ )
 
203
  with gr.Column():
204
  image_output = gr.Image(interactive=False, label="Output RGB image")
205
+ xyz_ouput = gr.Image(interactive=False, label="Output CCM image")
206
+
207
  output_model = gr.Model3D(
208
+ label="Output OBJ",
209
  interactive=False,
210
  )
211
  gr.Markdown("Note: Ensure that the input image is correctly pre-processed into a grey background, otherwise the results will be unpredictable.")
 
218
  ]
219
  outputs = [
220
  image_output,
221
+ xyz_ouput,
222
  output_model,
223
+ # output_obj,
224
  ]
225
 
226
+
227
  text_button.click(fn=check_input_image, inputs=[image_input]).success(
228
  fn=preprocess_image,
229
+ inputs=[image_input, background_choice, foreground_ratio, back_groud_color],
230
  outputs=[processed_image],
231
  ).success(
232
  fn=gen_image,
233
  inputs=inputs,
234
  outputs=outputs,
235
  )
236
+ demo.queue().launch()