lionelgarnier committed on
Commit
07db937
·
1 Parent(s): 862266f
Files changed (1) hide show
  1. app.py +211 -211
app.py CHANGED
@@ -33,176 +33,176 @@ def end_session(req: gr.Request):
33
  shutil.rmtree(user_dir)
34
 
35
 
36
- def preprocess_image(image: Image.Image) -> Image.Image:
37
- """
38
- Preprocess the input image.
39
-
40
- Args:
41
- image (Image.Image): The input image.
42
-
43
- Returns:
44
- Image.Image: The preprocessed image.
45
- """
46
- processed_image = pipeline.preprocess_image(image)
47
- return processed_image
48
-
49
-
50
- def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
51
- return {
52
- 'gaussian': {
53
- **gs.init_params,
54
- '_xyz': gs._xyz.cpu().numpy(),
55
- '_features_dc': gs._features_dc.cpu().numpy(),
56
- '_scaling': gs._scaling.cpu().numpy(),
57
- '_rotation': gs._rotation.cpu().numpy(),
58
- '_opacity': gs._opacity.cpu().numpy(),
59
- },
60
- 'mesh': {
61
- 'vertices': mesh.vertices.cpu().numpy(),
62
- 'faces': mesh.faces.cpu().numpy(),
63
- },
64
- }
65
 
66
 
67
- def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
68
- gs = Gaussian(
69
- aabb=state['gaussian']['aabb'],
70
- sh_degree=state['gaussian']['sh_degree'],
71
- mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
72
- scaling_bias=state['gaussian']['scaling_bias'],
73
- opacity_bias=state['gaussian']['opacity_bias'],
74
- scaling_activation=state['gaussian']['scaling_activation'],
75
- )
76
- gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
77
- gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
78
- gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
79
- gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
80
- gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
81
 
82
- mesh = edict(
83
- vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
84
- faces=torch.tensor(state['mesh']['faces'], device='cuda'),
85
- )
86
 
87
- return gs, mesh
88
-
89
-
90
- def get_seed(randomize_seed: bool, seed: int) -> int:
91
- """
92
- Get the random seed.
93
- """
94
- return np.random.randint(0, MAX_SEED) if randomize_seed else seed
95
-
96
-
97
- @spaces.GPU
98
- def image_to_3d(
99
- image: Image.Image,
100
- seed: int,
101
- ss_guidance_strength: float,
102
- ss_sampling_steps: int,
103
- slat_guidance_strength: float,
104
- slat_sampling_steps: int,
105
- req: gr.Request,
106
- ) -> Tuple[dict, str]:
107
- """
108
- Convert an image to a 3D model.
109
-
110
- Args:
111
- image (Image.Image): The input image.
112
- seed (int): The random seed.
113
- ss_guidance_strength (float): The guidance strength for sparse structure generation.
114
- ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
115
- slat_guidance_strength (float): The guidance strength for structured latent generation.
116
- slat_sampling_steps (int): The number of sampling steps for structured latent generation.
117
-
118
- Returns:
119
- dict: The information of the generated 3D model.
120
- str: The path to the video of the 3D model.
121
- """
122
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
123
- outputs = pipeline.run(
124
- image,
125
- seed=seed,
126
- formats=["gaussian", "mesh"],
127
- preprocess_image=False,
128
- sparse_structure_sampler_params={
129
- "steps": ss_sampling_steps,
130
- "cfg_strength": ss_guidance_strength,
131
- },
132
- slat_sampler_params={
133
- "steps": slat_sampling_steps,
134
- "cfg_strength": slat_guidance_strength,
135
- },
136
- )
137
- video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
138
- video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
139
- video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
140
- video_path = os.path.join(user_dir, 'sample.mp4')
141
- imageio.mimsave(video_path, video, fps=15)
142
- state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
143
- torch.cuda.empty_cache()
144
- return state, video_path
145
-
146
-
147
- @spaces.GPU(duration=90)
148
- def extract_glb(
149
- state: dict,
150
- mesh_simplify: float,
151
- texture_size: int,
152
- req: gr.Request,
153
- ) -> Tuple[str, str]:
154
- """
155
- Extract a GLB file from the 3D model.
156
-
157
- Args:
158
- state (dict): The state of the generated 3D model.
159
- mesh_simplify (float): The mesh simplification factor.
160
- texture_size (int): The texture resolution.
161
-
162
- Returns:
163
- str: The path to the extracted GLB file.
164
- """
165
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
166
- gs, mesh = unpack_state(state)
167
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
168
- glb_path = os.path.join(user_dir, 'sample.glb')
169
- glb.export(glb_path)
170
- torch.cuda.empty_cache()
171
- return glb_path, glb_path
172
-
173
-
174
- @spaces.GPU
175
- def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
176
- """
177
- Extract a Gaussian file from the 3D model.
178
-
179
- Args:
180
- state (dict): The state of the generated 3D model.
181
-
182
- Returns:
183
- str: The path to the extracted Gaussian file.
184
- """
185
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
186
- gs, _ = unpack_state(state)
187
- gaussian_path = os.path.join(user_dir, 'sample.ply')
188
- gs.save_ply(gaussian_path)
189
- torch.cuda.empty_cache()
190
- return gaussian_path, gaussian_path
191
-
192
-
193
- def split_image(image: Image.Image) -> List[Image.Image]:
194
- """
195
- Split an image into multiple views.
196
- """
197
- image = np.array(image)
198
- alpha = image[..., 3]
199
- alpha = np.any(alpha>0, axis=0)
200
- start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
201
- end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
202
- images = []
203
- for s, e in zip(start_pos, end_pos):
204
- images.append(Image.fromarray(image[:, s:e+1]))
205
- return [preprocess_image(image) for image in images]
206
 
207
 
208
  with gr.Blocks(delete_cache=(600, 600)) as demo:
@@ -257,52 +257,52 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
257
  demo.unload(end_session)
258
 
259
 
260
- image_prompt.upload(
261
- preprocess_image,
262
- inputs=[image_prompt],
263
- outputs=[image_prompt],
264
- )
265
-
266
- generate_btn.click(
267
- get_seed,
268
- inputs=[randomize_seed, seed],
269
- outputs=[seed],
270
- ).then(
271
- image_to_3d,
272
- inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
273
- outputs=[output_buf, video_output],
274
- ).then(
275
- lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
276
- outputs=[extract_glb_btn, extract_gs_btn],
277
- )
278
-
279
- video_output.clear(
280
- lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
281
- outputs=[extract_glb_btn, extract_gs_btn],
282
- )
283
-
284
- extract_glb_btn.click(
285
- extract_glb,
286
- inputs=[output_buf, mesh_simplify, texture_size],
287
- outputs=[model_output, download_glb],
288
- ).then(
289
- lambda: gr.Button(interactive=True),
290
- outputs=[download_glb],
291
- )
292
 
293
- extract_gs_btn.click(
294
- extract_gaussian,
295
- inputs=[output_buf],
296
- outputs=[model_output, download_gs],
297
- ).then(
298
- lambda: gr.Button(interactive=True),
299
- outputs=[download_gs],
300
- )
301
-
302
- model_output.clear(
303
- lambda: gr.Button(interactive=False),
304
- outputs=[download_glb],
305
- )
306
 
307
 
308
  # Launch the Gradio app
 
33
  shutil.rmtree(user_dir)
34
 
35
 
36
+ # def preprocess_image(image: Image.Image) -> Image.Image:
37
+ # """
38
+ # Preprocess the input image.
39
+
40
+ # Args:
41
+ # image (Image.Image): The input image.
42
+
43
+ # Returns:
44
+ # Image.Image: The preprocessed image.
45
+ # """
46
+ # processed_image = pipeline.preprocess_image(image)
47
+ # return processed_image
48
+
49
+
50
+ # def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
51
+ # return {
52
+ # 'gaussian': {
53
+ # **gs.init_params,
54
+ # '_xyz': gs._xyz.cpu().numpy(),
55
+ # '_features_dc': gs._features_dc.cpu().numpy(),
56
+ # '_scaling': gs._scaling.cpu().numpy(),
57
+ # '_rotation': gs._rotation.cpu().numpy(),
58
+ # '_opacity': gs._opacity.cpu().numpy(),
59
+ # },
60
+ # 'mesh': {
61
+ # 'vertices': mesh.vertices.cpu().numpy(),
62
+ # 'faces': mesh.faces.cpu().numpy(),
63
+ # },
64
+ # }
65
 
66
 
67
+ # def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
68
+ # gs = Gaussian(
69
+ # aabb=state['gaussian']['aabb'],
70
+ # sh_degree=state['gaussian']['sh_degree'],
71
+ # mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
72
+ # scaling_bias=state['gaussian']['scaling_bias'],
73
+ # opacity_bias=state['gaussian']['opacity_bias'],
74
+ # scaling_activation=state['gaussian']['scaling_activation'],
75
+ # )
76
+ # gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
77
+ # gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
78
+ # gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
79
+ # gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
80
+ # gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
81
 
82
+ # mesh = edict(
83
+ # vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
84
+ # faces=torch.tensor(state['mesh']['faces'], device='cuda'),
85
+ # )
86
 
87
+ # return gs, mesh
88
+
89
+
90
+ # def get_seed(randomize_seed: bool, seed: int) -> int:
91
+ # """
92
+ # Get the random seed.
93
+ # """
94
+ # return np.random.randint(0, MAX_SEED) if randomize_seed else seed
95
+
96
+
97
+ # @spaces.GPU
98
+ # def image_to_3d(
99
+ # image: Image.Image,
100
+ # seed: int,
101
+ # ss_guidance_strength: float,
102
+ # ss_sampling_steps: int,
103
+ # slat_guidance_strength: float,
104
+ # slat_sampling_steps: int,
105
+ # req: gr.Request,
106
+ # ) -> Tuple[dict, str]:
107
+ # """
108
+ # Convert an image to a 3D model.
109
+
110
+ # Args:
111
+ # image (Image.Image): The input image.
112
+ # seed (int): The random seed.
113
+ # ss_guidance_strength (float): The guidance strength for sparse structure generation.
114
+ # ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
115
+ # slat_guidance_strength (float): The guidance strength for structured latent generation.
116
+ # slat_sampling_steps (int): The number of sampling steps for structured latent generation.
117
+
118
+ # Returns:
119
+ # dict: The information of the generated 3D model.
120
+ # str: The path to the video of the 3D model.
121
+ # """
122
+ # user_dir = os.path.join(TMP_DIR, str(req.session_hash))
123
+ # outputs = pipeline.run(
124
+ # image,
125
+ # seed=seed,
126
+ # formats=["gaussian", "mesh"],
127
+ # preprocess_image=False,
128
+ # sparse_structure_sampler_params={
129
+ # "steps": ss_sampling_steps,
130
+ # "cfg_strength": ss_guidance_strength,
131
+ # },
132
+ # slat_sampler_params={
133
+ # "steps": slat_sampling_steps,
134
+ # "cfg_strength": slat_guidance_strength,
135
+ # },
136
+ # )
137
+ # video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
138
+ # video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
139
+ # video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
140
+ # video_path = os.path.join(user_dir, 'sample.mp4')
141
+ # imageio.mimsave(video_path, video, fps=15)
142
+ # state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
143
+ # torch.cuda.empty_cache()
144
+ # return state, video_path
145
+
146
+
147
+ # @spaces.GPU(duration=90)
148
+ # def extract_glb(
149
+ # state: dict,
150
+ # mesh_simplify: float,
151
+ # texture_size: int,
152
+ # req: gr.Request,
153
+ # ) -> Tuple[str, str]:
154
+ # """
155
+ # Extract a GLB file from the 3D model.
156
+
157
+ # Args:
158
+ # state (dict): The state of the generated 3D model.
159
+ # mesh_simplify (float): The mesh simplification factor.
160
+ # texture_size (int): The texture resolution.
161
+
162
+ # Returns:
163
+ # str: The path to the extracted GLB file.
164
+ # """
165
+ # user_dir = os.path.join(TMP_DIR, str(req.session_hash))
166
+ # gs, mesh = unpack_state(state)
167
+ # glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
168
+ # glb_path = os.path.join(user_dir, 'sample.glb')
169
+ # glb.export(glb_path)
170
+ # torch.cuda.empty_cache()
171
+ # return glb_path, glb_path
172
+
173
+
174
+ # @spaces.GPU
175
+ # def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
176
+ # """
177
+ # Extract a Gaussian file from the 3D model.
178
+
179
+ # Args:
180
+ # state (dict): The state of the generated 3D model.
181
+
182
+ # Returns:
183
+ # str: The path to the extracted Gaussian file.
184
+ # """
185
+ # user_dir = os.path.join(TMP_DIR, str(req.session_hash))
186
+ # gs, _ = unpack_state(state)
187
+ # gaussian_path = os.path.join(user_dir, 'sample.ply')
188
+ # gs.save_ply(gaussian_path)
189
+ # torch.cuda.empty_cache()
190
+ # return gaussian_path, gaussian_path
191
+
192
+
193
+ # def split_image(image: Image.Image) -> List[Image.Image]:
194
+ # """
195
+ # Split an image into multiple views.
196
+ # """
197
+ # image = np.array(image)
198
+ # alpha = image[..., 3]
199
+ # alpha = np.any(alpha>0, axis=0)
200
+ # start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
201
+ # end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
202
+ # images = []
203
+ # for s, e in zip(start_pos, end_pos):
204
+ # images.append(Image.fromarray(image[:, s:e+1]))
205
+ # return [preprocess_image(image) for image in images]
206
 
207
 
208
  with gr.Blocks(delete_cache=(600, 600)) as demo:
 
257
  demo.unload(end_session)
258
 
259
 
260
+ # image_prompt.upload(
261
+ # preprocess_image,
262
+ # inputs=[image_prompt],
263
+ # outputs=[image_prompt],
264
+ # )
265
+
266
+ # generate_btn.click(
267
+ # get_seed,
268
+ # inputs=[randomize_seed, seed],
269
+ # outputs=[seed],
270
+ # ).then(
271
+ # image_to_3d,
272
+ # inputs=[image_prompt, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
273
+ # outputs=[output_buf, video_output],
274
+ # ).then(
275
+ # lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
276
+ # outputs=[extract_glb_btn, extract_gs_btn],
277
+ # )
278
+
279
+ # video_output.clear(
280
+ # lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
281
+ # outputs=[extract_glb_btn, extract_gs_btn],
282
+ # )
283
+
284
+ # extract_glb_btn.click(
285
+ # extract_glb,
286
+ # inputs=[output_buf, mesh_simplify, texture_size],
287
+ # outputs=[model_output, download_glb],
288
+ # ).then(
289
+ # lambda: gr.Button(interactive=True),
290
+ # outputs=[download_glb],
291
+ # )
292
 
293
+ # extract_gs_btn.click(
294
+ # extract_gaussian,
295
+ # inputs=[output_buf],
296
+ # outputs=[model_output, download_gs],
297
+ # ).then(
298
+ # lambda: gr.Button(interactive=True),
299
+ # outputs=[download_gs],
300
+ # )
301
+
302
+ # model_output.clear(
303
+ # lambda: gr.Button(interactive=False),
304
+ # outputs=[download_glb],
305
+ # )
306
 
307
 
308
  # Launch the Gradio app