lionelgarnier committed
Commit 008680f · Parent: bd71366

deactivate 3D

Files changed (1): app.py (+185, -185)
app.py CHANGED
@@ -163,7 +163,7 @@ def validate_dimensions(width, height):
     return True, None
 
 @spaces.GPU()
-def infer(prompt, seed=DEFAULT_SEED,
+def generate_image(prompt, seed=DEFAULT_SEED,
           randomize_seed=DEFAULT_RANDOMIZE_SEED,
           width=DEFAULT_WIDTH,
           height=DEFAULT_HEIGHT,
@@ -251,136 +251,136 @@ def preload_models():
     return success, status
 
 
-def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
-    return {
-        'gaussian': {
-            **gs.init_params,
-            '_xyz': gs._xyz.cpu().numpy(),
-            '_features_dc': gs._features_dc.cpu().numpy(),
-            '_scaling': gs._scaling.cpu().numpy(),
-            '_rotation': gs._rotation.cpu().numpy(),
-            '_opacity': gs._opacity.cpu().numpy(),
-        },
-        'mesh': {
-            'vertices': mesh.vertices.cpu().numpy(),
-            'faces': mesh.faces.cpu().numpy(),
-        },
-    }
+# def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
+#     return {
+#         'gaussian': {
+#             **gs.init_params,
+#             '_xyz': gs._xyz.cpu().numpy(),
+#             '_features_dc': gs._features_dc.cpu().numpy(),
+#             '_scaling': gs._scaling.cpu().numpy(),
+#             '_rotation': gs._rotation.cpu().numpy(),
+#             '_opacity': gs._opacity.cpu().numpy(),
+#         },
+#         'mesh': {
+#             'vertices': mesh.vertices.cpu().numpy(),
+#             'faces': mesh.faces.cpu().numpy(),
+#         },
+#     }
 
 
-def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
-    gs = Gaussian(
-        aabb=state['gaussian']['aabb'],
-        sh_degree=state['gaussian']['sh_degree'],
-        mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
-        scaling_bias=state['gaussian']['scaling_bias'],
-        opacity_bias=state['gaussian']['opacity_bias'],
-        scaling_activation=state['gaussian']['scaling_activation'],
-    )
-    gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
-    gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
-    gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
-    gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
-    gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
+# def unpack_state(state: dict) -> Tuple[Gaussian, edict, str]:
+#     gs = Gaussian(
+#         aabb=state['gaussian']['aabb'],
+#         sh_degree=state['gaussian']['sh_degree'],
+#         mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
+#         scaling_bias=state['gaussian']['scaling_bias'],
+#         opacity_bias=state['gaussian']['opacity_bias'],
+#         scaling_activation=state['gaussian']['scaling_activation'],
+#     )
+#     gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
+#     gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
+#     gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
+#     gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
+#     gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
 
-    mesh = edict(
-        vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
-        faces=torch.tensor(state['mesh']['faces'], device='cuda'),
-    )
+#     mesh = edict(
+#         vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
+#         faces=torch.tensor(state['mesh']['faces'], device='cuda'),
+#     )
 
-    return gs, mesh
-
-
-@spaces.GPU
-def image_to_3d(
-    image: Image.Image,
-    seed: int,
-    ss_guidance_strength: float,
-    ss_sampling_steps: int,
-    slat_guidance_strength: float,
-    slat_sampling_steps: int,
-) -> Tuple[dict, str]:
-    try:
-        # Use a fixed temp directory instead of user-specific
-        temp_dir = os.path.join(TMP_DIR, "temp_output")
-        os.makedirs(temp_dir, exist_ok=True)
+#     return gs, mesh
+
+
+# @spaces.GPU
+# def image_to_3d(
+#     image: Image.Image,
+#     seed: int,
+#     ss_guidance_strength: float,
+#     ss_sampling_steps: int,
+#     slat_guidance_strength: float,
+#     slat_sampling_steps: int,
+# ) -> Tuple[dict, str]:
+#     try:
+#         # Use a fixed temp directory instead of user-specific
+#         temp_dir = os.path.join(TMP_DIR, "temp_output")
+#         os.makedirs(temp_dir, exist_ok=True)
 
-        # Get the pipeline using the getter function
-        pipeline = get_trellis_pipeline()
-        if pipeline is None:
-            return None, "Trellis pipeline is unavailable."
+#         # Get the pipeline using the getter function
+#         pipeline = get_trellis_pipeline()
+#         if pipeline is None:
+#             return None, "Trellis pipeline is unavailable."
 
-        outputs = pipeline.run(
-            image,
-            seed=seed,
-            formats=["gaussian", "mesh"],
-            preprocess_image=False,
-            sparse_structure_sampler_params={
-                "steps": ss_sampling_steps,
-                "cfg_strength": ss_guidance_strength,
-            },
-            slat_sampler_params={
-                "steps": slat_sampling_steps,
-                "cfg_strength": slat_guidance_strength,
-            },
-        )
-
-        video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
-        video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
-        video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
-        video_path = os.path.join(temp_dir, 'sample.mp4')
-        imageio.mimsave(video_path, video, fps=15)
-        state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
-        torch.cuda.empty_cache()
-        return state, video_path
-    except Exception as e:
-        print(f"Error in image_to_3d: {str(e)}")
-        return None, f"Error generating 3D model: {str(e)}"
-
-
-@spaces.GPU(duration=90)
-def extract_glb(
-    state: dict,
-    mesh_simplify: float,
-    texture_size: int,
-) -> Tuple[str, str]:
-    """
-    Extract a GLB file from the 3D model.
-
-    Args:
-        state (dict): The state of the generated 3D model.
-        mesh_simplify (float): The mesh simplification factor.
-        texture_size (int): The texture resolution.
-
-    Returns:
-        str: The path to the extracted GLB file.
-    """
-    temp_dir = os.path.join(TMP_DIR, "temp_output")
-    gs, mesh = unpack_state(state)
-    glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
-    glb_path = os.path.join(temp_dir, 'sample.glb')
-    glb.export(glb_path)
-    torch.cuda.empty_cache()
-    return glb_path, glb_path
-
-
-@spaces.GPU
-def extract_gaussian(state: dict) -> Tuple[str, str]:
-    """
-    Extract a Gaussian file from the 3D model.
-
-    Args:
-        state (dict): The state of the generated 3D model.
-
-    Returns:
-        str: The path to the extracted Gaussian file.
-    """
-    temp_dir = os.path.join(TMP_DIR, "temp_output")
-    gs, _ = unpack_state(state)
-    gaussian_path = os.path.join(temp_dir, 'sample.ply')
-    gs.save_ply(gaussian_path)
-    torch.cuda.empty_cache()
-    return gaussian_path, gaussian_path
+#         outputs = pipeline.run(
+#             image,
+#             seed=seed,
+#             formats=["gaussian", "mesh"],
+#             preprocess_image=False,
+#             sparse_structure_sampler_params={
+#                 "steps": ss_sampling_steps,
+#                 "cfg_strength": ss_guidance_strength,
+#             },
+#             slat_sampler_params={
+#                 "steps": slat_sampling_steps,
+#                 "cfg_strength": slat_guidance_strength,
+#             },
+#         )
+
+#         video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
+#         video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
+#         video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
+#         video_path = os.path.join(temp_dir, 'sample.mp4')
+#         imageio.mimsave(video_path, video, fps=15)
+#         state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
+#         torch.cuda.empty_cache()
+#         return state, video_path
+#     except Exception as e:
+#         print(f"Error in image_to_3d: {str(e)}")
+#         return None, f"Error generating 3D model: {str(e)}"
+
+
+# @spaces.GPU(duration=90)
+# def extract_glb(
+#     state: dict,
+#     mesh_simplify: float,
+#     texture_size: int,
+# ) -> Tuple[str, str]:
+#     """
+#     Extract a GLB file from the 3D model.
+
+#     Args:
+#         state (dict): The state of the generated 3D model.
+#         mesh_simplify (float): The mesh simplification factor.
+#         texture_size (int): The texture resolution.
+
+#     Returns:
+#         str: The path to the extracted GLB file.
+#     """
+#     temp_dir = os.path.join(TMP_DIR, "temp_output")
+#     gs, mesh = unpack_state(state)
+#     glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
+#     glb_path = os.path.join(temp_dir, 'sample.glb')
+#     glb.export(glb_path)
+#     torch.cuda.empty_cache()
+#     return glb_path, glb_path
+
+
+# @spaces.GPU
+# def extract_gaussian(state: dict) -> Tuple[str, str]:
+#     """
+#     Extract a Gaussian file from the 3D model.
+
+#     Args:
+#         state (dict): The state of the generated 3D model.
+
+#     Returns:
+#         str: The path to the extracted Gaussian file.
+#     """
+#     temp_dir = os.path.join(TMP_DIR, "temp_output")
+#     gs, _ = unpack_state(state)
+#     gaussian_path = os.path.join(temp_dir, 'sample.ply')
+#     gs.save_ply(gaussian_path)
+#     torch.cuda.empty_cache()
+#     return gaussian_path, gaussian_path
 
 
 # Create a combined function that handles the whole pipeline from example to image
@@ -435,14 +435,14 @@ def create_interface():
         visual_button = gr.Button("Create visual with Flux")
 
         generated_image = gr.Image(show_label=False)
-        gen3d_button = gr.Button("Create 3D visual with Trellis")
+        # gen3d_button = gr.Button("Create 3D visual with Trellis")
 
-        video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
-        model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
+        # video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
+        # model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
 
-        with gr.Row():
-            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
-            download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
+        # with gr.Row():
+        #     download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
+        #     download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
 
         message_box = gr.Textbox(
             label="Status Messages",
@@ -487,28 +487,28 @@ def create_interface():
                     value=DEFAULT_NUM_INFERENCE_STEPS,
                 )
 
-            with gr.Tab("3D Generation Settings"):
-                trellis_seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
-                trellis_randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
-                gr.Markdown("Stage 1: Sparse Structure Generation")
-                with gr.Row():
-                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
-                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
-                gr.Markdown("Stage 2: Structured Latent Generation")
-                with gr.Row():
-                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
-                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
-
-            with gr.Tab("GLB Extraction Settings"):
-                mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
-                texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
+            # with gr.Tab("3D Generation Settings"):
+            #     trellis_seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
+            #     trellis_randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+            #     gr.Markdown("Stage 1: Sparse Structure Generation")
+            #     with gr.Row():
+            #         ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
+            #         ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
+            #     gr.Markdown("Stage 2: Structured Latent Generation")
+            #     with gr.Row():
+            #         slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
+            #         slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
+
+            # with gr.Tab("GLB Extraction Settings"):
+            #     mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
+            #     texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
 
-            with gr.Row():
-                extract_glb_btn = gr.Button("Extract GLB", interactive=False)
-                extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
-            gr.Markdown("""
-            *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
-            """)
+            # with gr.Row():
+            #     extract_glb_btn = gr.Button("Extract GLB", interactive=False)
+            #     extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
+            # gr.Markdown("""
+            # *NOTE: Gaussian file can be very large (~50MB), it will take a while to display and download.*
+            # """)
 
         output_buf = gr.State()
 
@@ -531,44 +531,44 @@ def create_interface():
 
         gr.on(
             triggers=[visual_button.click],
-            fn=infer,
+            fn=generate_image,
             inputs=[refined_prompt, flux_seed, flux_randomize_seed, width, height, num_inference_steps],
             outputs=[generated_image, message_box]
         )
 
-        gr.on(
-            triggers=[gen3d_button.click],
-            fn=image_to_3d,
-            inputs=[generated_image, trellis_seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
-            outputs=[output_state, video_output],
-        ).then(
-            # Update button states after successful 3D generation
-            lambda: (gr.Button.update(interactive=True), gr.Button.update(interactive=True), "3D model generated successfully"),
-            outputs=[extract_glb_btn, extract_gs_btn, message_box]
-        )
+        # gr.on(
+        #     triggers=[gen3d_button.click],
+        #     fn=image_to_3d,
+        #     inputs=[generated_image, trellis_seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
+        #     outputs=[output_state, video_output],
+        # ).then(
+        #     # Update button states after successful 3D generation
+        #     lambda: (gr.Button.update(interactive=True), gr.Button.update(interactive=True), "3D model generated successfully"),
+        #     outputs=[extract_glb_btn, extract_gs_btn, message_box]
+        # )
 
-        # Add handlers for GLB and Gaussian extraction
-        gr.on(
-            triggers=[extract_glb_btn.click],
-            fn=extract_glb,
-            inputs=[output_state, mesh_simplify, texture_size],
-            outputs=[model_output, download_glb]
-        ).then(
-            lambda path: (gr.DownloadButton.update(interactive=True, value=path), "GLB extraction completed"),
-            inputs=[model_output],
-            outputs=[download_glb, message_box]
-        )
-
-        gr.on(
-            triggers=[extract_gs_btn.click],
-            fn=extract_gaussian,
-            inputs=[output_state],
-            outputs=[model_output, download_gs]
-        ).then(
-            lambda path: (gr.DownloadButton.update(interactive=True, value=path), "Gaussian extraction completed"),
-            inputs=[model_output],
-            outputs=[download_gs, message_box]
-        )
+        # # Add handlers for GLB and Gaussian extraction
+        # gr.on(
+        #     triggers=[extract_glb_btn.click],
+        #     fn=extract_glb,
+        #     inputs=[output_state, mesh_simplify, texture_size],
+        #     outputs=[model_output, download_glb]
+        # ).then(
+        #     lambda path: (gr.DownloadButton.update(interactive=True, value=path), "GLB extraction completed"),
+        #     inputs=[model_output],
+        #     outputs=[download_glb, message_box]
+        # )
+
+        # gr.on(
+        #     triggers=[extract_gs_btn.click],
+        #     fn=extract_gaussian,
+        #     inputs=[output_state],
+        #     outputs=[model_output, download_gs]
+        # ).then(
+        #     lambda path: (gr.DownloadButton.update(interactive=True, value=path), "Gaussian extraction completed"),
+        #     inputs=[model_output],
+        #     outputs=[download_gs, message_box]
+        # )
 
     return demo
 
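The blocks commented out above serialize the Trellis outputs so they can be held in Gradio state between the generation and extraction steps: `pack_state` moves every GPU tensor into CPU NumPy arrays inside a plain dict, and `unpack_state` rebuilds tensors from that dict. A minimal, self-contained sketch of that round-trip follows; the helper names and toy tensors here are illustrative only and are not part of app.py.

```python
import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

def pack(tensors: dict) -> dict:
    # Detach and copy each tensor to a CPU NumPy array so the dict holds
    # no live CUDA references and is safe to keep in session state.
    return {name: t.detach().cpu().numpy() for name, t in tensors.items()}

def unpack(state: dict) -> dict:
    # Rebuild tensors on the target device from the packed arrays.
    return {name: torch.tensor(arr, device=device) for name, arr in state.items()}

if __name__ == "__main__":
    original = {
        "_xyz": torch.rand(8, 3, device=device),       # splat positions (illustrative)
        "vertices": torch.rand(16, 3, device=device),  # mesh vertices (illustrative)
    }
    restored = unpack(pack(original))
    assert all(torch.allclose(original[k], restored[k]) for k in original)
```

With this commit, only the Flux image handler stays wired up (`fn=generate_image`); re-enabling the 3D path would mean uncommenting the functions and UI components above and restoring the `gr.on` handlers removed in the last hunk.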