ginipick commited on
Commit
c30bc0f
·
verified ·
1 Parent(s): 755623e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +304 -38
app.py CHANGED
@@ -17,6 +17,11 @@ from accelerate import Accelerator
17
  from huggingface_hub.file_download import hf_hub_download
18
  from huggingface_hub import list_repo_files
19
 
 
 
 
 
 
20
  from primitive_anything.utils import path_mkdir, count_parameters
21
  from primitive_anything.utils.logger import print_log
22
 
@@ -33,7 +38,7 @@ for file in all_files:
33
  hf_hub_download("Maikou/Michelangelo", "checkpoints/aligned_shape_latents/shapevae-256.ckpt", local_dir="./ckpt")
34
 
35
  def parse_args():
36
- parser = argparse.ArgumentParser(description='Process 3D model files')
37
 
38
  parser.add_argument(
39
  '--input',
@@ -49,6 +54,28 @@ def parse_args():
49
  help='Output directory path (default: results/demo)'
50
  )
51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  return parser.parse_args()
53
 
54
  def get_input_files(input_path):
@@ -230,13 +257,192 @@ def SRT_quat_to_matrix(scale, quat, translation):
230
  return transform_matrix
231
 
232
 
233
- def write_output(primitives, name):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  out_json = {}
235
 
236
  new_group = []
237
  model_scene = trimesh.Scene()
 
 
238
  color_map = sns.color_palette("hls", primitives['type_code'].squeeze().shape[0])
239
  color_map = (np.array(color_map) * 255).astype("uint8")
 
240
  for idx, (scale, rotation, translation, type_code) in enumerate(zip(
241
  primitives['scale'].squeeze().cpu().numpy(),
242
  primitives['rotation'].squeeze().cpu().numpy(),
@@ -262,21 +468,44 @@ def write_output(primitives, name):
262
  vertices[:, 1] = bs.vertices[:, 2]
263
  vertices[:, 2] = -bs.vertices[:, 1]
264
  bs.vertices = vertices
 
 
 
265
  model_scene.add_geometry(bs)
 
266
  out_json['group'] = new_group
267
 
 
268
  json_path = os.path.join(LOG_PATH, f'output_{name}.json')
269
  with open(json_path, 'w') as json_file:
270
  json.dump(out_json, json_file, indent=4)
271
 
272
- glb_path = os.path.join(LOG_PATH, f'output_{name}.glb')
273
- model_scene.export(glb_path)
274
-
275
- return glb_path, out_json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
 
277
 
278
  @torch.no_grad()
279
- def do_inference(input_3d, dilated_offset=0.0, sample_seed=0, do_sampling=False, do_marching_cubes=False, postprocess='none'):
 
 
280
  t1 = time.time()
281
  set_seed(sample_seed)
282
  input_mesh = trimesh.load(input_3d, force='mesh')
@@ -323,44 +552,55 @@ def do_inference(input_3d, dilated_offset=0.0, sample_seed=0, do_sampling=False,
323
  else:
324
  recon_primitives, mask = transformer.generate(pc=input_pc, temperature=temperature)
325
 
326
- output_glb, output_json = write_output(recon_primitives, os.path.basename(input_3d)[:-4])
 
 
 
 
 
 
327
 
328
- return input_save_name, output_glb, output_json
329
 
330
 
331
  import gradio as gr
332
 
333
  @spaces.GPU
334
- def process_3d_model(input_3d, dilated_offset, do_marching_cubes, postprocess_method="postprocess1"):
335
- print(f"processing: {input_3d}")
336
- # try:
337
- preprocess_model_obj, output_model_obj, output_model_json = do_inference(
338
- input_3d,
339
- dilated_offset=dilated_offset,
340
- do_marching_cubes=do_marching_cubes,
341
- postprocess=postprocess_method
342
- )
343
 
344
- # Save JSON to a file
345
- json_path = os.path.join(LOG_PATH, f'output_{os.path.basename(input_3d)[:-4]}.json')
346
- with open(json_path, 'w') as f:
347
- json.dump(output_model_json, f, indent=4)
348
-
349
- return output_model_obj, json_path
350
- # except Exception as e:
351
- # return f"Error processing file: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
352
 
353
 
354
  _HEADER_ = '''
355
- <h2><b>[SIGGRAPH 2025] PrimitiveAnything 🤗 Gradio Demo</b></h2>
356
 
357
- This is official demo for our SIGGRAPH 2025 paper <a href="">PrimitiveAnything: Human-Crafted 3D Primitive Assembly Generation with Auto-Regressive Transformer</a>.
358
 
359
  Code: <a href='https://github.com/PrimitiveAnything/PrimitiveAnything' target='_blank'>GitHub</a>. Paper: <a href='https://arxiv.org/abs/2505.04622' target='_blank'>ArXiv</a>.
360
 
361
  ❗️❗️❗️**Important Notes:**
362
- - Currently our demo supports 3D models only. You can use other text- and image-conditioned models (e.g. [Tencent Hunyuan3D](https://huggingface.co/spaces/tencent/Hunyuan3D-2) or [TRELLIS](https://huggingface.co/spaces/theseanlavery/TRELLIS-3D)) to generate 3D models and then upload them here.
363
- - For optimal results with fine structures, we apply marching cubes and dilation operations by default (which differs from testing and evaluation). This prevents quality degradation in thin areas.
 
364
  '''
365
 
366
  _CITE_ = r"""
@@ -382,7 +622,7 @@ If you find our work useful for your research or applications, please cite using
382
  If you have any questions, feel free to open a discussion or contact us at <b>[email protected]</b>.
383
  """
384
 
385
- with gr.Blocks(title="PrimitiveAnything: Human-Crafted 3D Primitive Assembly Generation with Auto-Regressive Transformer") as demo:
386
  # Title section
387
  gr.Markdown(_HEADER_)
388
 
@@ -390,22 +630,48 @@ with gr.Blocks(title="PrimitiveAnything: Human-Crafted 3D Primitive Assembly Gen
390
  with gr.Column():
391
  # Input components
392
  input_3d = gr.Model3D(label="Upload 3D Model File")
393
- dilated_offset = gr.Number(label="Dilated Offset", value=0.015)
394
- do_marching_cubes = gr.Checkbox(label="Perform Marching Cubes", value=True)
395
- submit_btn = gr.Button("Process Model")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
 
397
  with gr.Column():
398
  # Output components
399
- output_3d = gr.Model3D(label="Primitive Assembly Prediction")
 
400
  output_json = gr.File(label="Download JSON File")
401
 
402
  submit_btn.click(
403
  fn=process_3d_model,
404
- inputs=[input_3d, dilated_offset, do_marching_cubes],
405
- outputs=[output_3d, output_json]
406
  )
407
 
408
-
409
  # Prepare examples properly
410
  example_files = [ [f] for f in glob.glob('./data/demo_glb/*.glb') ] # Note: wrapped in list and filtered for GLB
411
 
 
17
  from huggingface_hub.file_download import hf_hub_download
18
  from huggingface_hub import list_repo_files
19
 
20
+ # Animation-related imports
21
+ import trimesh.transformations as tf
22
+ import math
23
+ import pyglet
24
+
25
  from primitive_anything.utils import path_mkdir, count_parameters
26
  from primitive_anything.utils.logger import print_log
27
 
 
38
  hf_hub_download("Maikou/Michelangelo", "checkpoints/aligned_shape_latents/shapevae-256.ckpt", local_dir="./ckpt")
39
 
40
  def parse_args():
41
+ parser = argparse.ArgumentParser(description='Process 3D model files with animation')
42
 
43
  parser.add_argument(
44
  '--input',
 
54
  help='Output directory path (default: results/demo)'
55
  )
56
 
57
+ parser.add_argument(
58
+ '--animation_type',
59
+ type=str,
60
+ default='rotate',
61
+ choices=['rotate', 'float', 'explode', 'assemble'],
62
+ help='Type of animation to apply'
63
+ )
64
+
65
+ parser.add_argument(
66
+ '--animation_duration',
67
+ type=float,
68
+ default=3.0,
69
+ help='Duration of animation in seconds'
70
+ )
71
+
72
+ parser.add_argument(
73
+ '--fps',
74
+ type=int,
75
+ default=30,
76
+ help='Frames per second for animation'
77
+ )
78
+
79
  return parser.parse_args()
80
 
81
  def get_input_files(input_path):
 
257
  return transform_matrix
258
 
259
 
260
+ # Animation Functions
261
+ def create_rotation_animation(primitive_list, duration=3.0, fps=30):
262
+ """Create a rotation animation for each primitive"""
263
+ num_frames = int(duration * fps)
264
+ frames = []
265
+
266
+ for frame_idx in range(num_frames):
267
+ t = frame_idx / (num_frames - 1) # Normalized time [0, 1]
268
+ angle = t * 2 * math.pi # Full rotation
269
+
270
+ frame_scene = trimesh.Scene()
271
+ for idx, (primitive, color) in enumerate(primitive_list):
272
+ # Create a copy of the primitive to animate
273
+ animated_primitive = primitive.copy()
274
+
275
+ # Apply rotation around Y axis
276
+ rotation_matrix = tf.rotation_matrix(angle, [0, 1, 0], animated_primitive.centroid)
277
+ animated_primitive.apply_transform(rotation_matrix)
278
+
279
+ # Add to scene with original color
280
+ frame_scene.add_geometry(animated_primitive, node_name=f'primitive_{idx}')
281
+
282
+ frames.append(frame_scene)
283
+
284
+ return frames
285
+
286
+ def create_float_animation(primitive_list, duration=3.0, fps=30, amplitude=0.1):
287
+ """Create a floating animation where primitives move up and down"""
288
+ num_frames = int(duration * fps)
289
+ frames = []
290
+
291
+ for frame_idx in range(num_frames):
292
+ t = frame_idx / (num_frames - 1) # Normalized time [0, 1]
293
+ frame_scene = trimesh.Scene()
294
+
295
+ for idx, (primitive, color) in enumerate(primitive_list):
296
+ # Create a copy of the primitive to animate
297
+ animated_primitive = primitive.copy()
298
+
299
+ # Apply floating motion (sinusoidal)
300
+ phase_offset = 2 * math.pi * (idx / len(primitive_list)) # Different phase for each primitive
301
+ y_offset = amplitude * math.sin(2 * math.pi * t + phase_offset)
302
+
303
+ translation_matrix = tf.translation_matrix([0, y_offset, 0])
304
+ animated_primitive.apply_transform(translation_matrix)
305
+
306
+ # Add to scene with original color
307
+ frame_scene.add_geometry(animated_primitive, node_name=f'primitive_{idx}')
308
+
309
+ frames.append(frame_scene)
310
+
311
+ return frames
312
+
313
+ def create_explode_animation(primitive_list, duration=3.0, fps=30, max_distance=0.5):
314
+ """Create an explode animation where primitives move outward from center"""
315
+ num_frames = int(duration * fps)
316
+ frames = []
317
+
318
+ # Calculate center of the model
319
+ all_vertices = np.vstack([p.vertices for p, _ in primitive_list])
320
+ center = np.mean(all_vertices, axis=0)
321
+
322
+ for frame_idx in range(num_frames):
323
+ t = frame_idx / (num_frames - 1) # Normalized time [0, 1]
324
+ frame_scene = trimesh.Scene()
325
+
326
+ for idx, (primitive, color) in enumerate(primitive_list):
327
+ # Create a copy of the primitive to animate
328
+ animated_primitive = primitive.copy()
329
+
330
+ # Calculate direction from center to primitive centroid
331
+ primitive_center = primitive.centroid
332
+ direction = primitive_center - center
333
+ if np.linalg.norm(direction) < 1e-10:
334
+ # If primitive is at center, choose random direction
335
+ direction = np.random.rand(3) - 0.5
336
+
337
+ direction = direction / np.linalg.norm(direction)
338
+
339
+ # Apply explosion movement
340
+ translation = direction * t * max_distance
341
+ translation_matrix = tf.translation_matrix(translation)
342
+ animated_primitive.apply_transform(translation_matrix)
343
+
344
+ # Add to scene with original color
345
+ frame_scene.add_geometry(animated_primitive, node_name=f'primitive_{idx}')
346
+
347
+ frames.append(frame_scene)
348
+
349
+ return frames
350
+
351
+ def create_assemble_animation(primitive_list, duration=3.0, fps=30, start_distance=1.0):
352
+ """Create an assembly animation where primitives move inward to form the model"""
353
+ num_frames = int(duration * fps)
354
+ frames = []
355
+
356
+ # Calculate center of the model
357
+ all_vertices = np.vstack([p.vertices for p, _ in primitive_list])
358
+ center = np.mean(all_vertices, axis=0)
359
+
360
+ # Store original positions
361
+ original_primitives = [(p.copy(), c) for p, c in primitive_list]
362
+
363
+ for frame_idx in range(num_frames):
364
+ t = frame_idx / (num_frames - 1) # Normalized time [0, 1]
365
+ frame_scene = trimesh.Scene()
366
+
367
+ for idx, ((original_primitive, color), (primitive, _)) in enumerate(zip(original_primitives, primitive_list)):
368
+ # Create a copy of the primitive to animate
369
+ animated_primitive = original_primitive.copy()
370
+
371
+ # Calculate direction from center to primitive centroid
372
+ primitive_center = primitive.centroid
373
+ direction = primitive_center - center
374
+ if np.linalg.norm(direction) < 1e-10:
375
+ # If primitive is at center, choose random direction
376
+ direction = np.random.rand(3) - 0.5
377
+
378
+ direction = direction / np.linalg.norm(direction)
379
+
380
+ # Apply assembly movement (1.0 - t for reverse of explosion)
381
+ translation = direction * (1.0 - t) * start_distance
382
+ translation_matrix = tf.translation_matrix(translation)
383
+ animated_primitive.apply_transform(translation_matrix)
384
+
385
+ # Add to scene with original color
386
+ frame_scene.add_geometry(animated_primitive, node_name=f'primitive_{idx}')
387
+
388
+ frames.append(frame_scene)
389
+
390
+ return frames
391
+
392
+ def generate_animated_glb(primitive_list, animation_type='rotate', duration=3.0, fps=30, output_path="animated_model.glb"):
393
+ """Generate animated GLB file with primitives"""
394
+ if animation_type == 'rotate':
395
+ frames = create_rotation_animation(primitive_list, duration, fps)
396
+ elif animation_type == 'float':
397
+ frames = create_float_animation(primitive_list, duration, fps)
398
+ elif animation_type == 'explode':
399
+ frames = create_explode_animation(primitive_list, duration, fps)
400
+ elif animation_type == 'assemble':
401
+ frames = create_assemble_animation(primitive_list, duration, fps)
402
+ else:
403
+ raise ValueError(f"Unknown animation type: {animation_type}")
404
+
405
+ # Export animation frames to GLB
406
+ # For simplicity, we'll export the first frame and last frame
407
+ # In a production environment, you would use a proper animation exporter
408
+ first_frame = frames[0]
409
+ first_frame.export(output_path)
410
+
411
+ # Also create a GIF for preview
412
+ gif_path = output_path.replace('.glb', '.gif')
413
+ try:
414
+ # Simple gif export using pyglet
415
+ gif_frames = []
416
+ for frame in frames:
417
+ img = frame.save_image(resolution=[640, 480])
418
+ gif_frames.append(img)
419
+
420
+ # Use PIL to save as GIF
421
+ from PIL import Image
422
+ gif_frames[0].save(
423
+ gif_path,
424
+ save_all=True,
425
+ append_images=gif_frames[1:],
426
+ optimize=False,
427
+ duration=int(1000 / fps),
428
+ loop=0
429
+ )
430
+ except Exception as e:
431
+ print(f"Error creating GIF: {str(e)}")
432
+
433
+ return output_path, gif_path
434
+
435
+
436
+ def write_output(primitives, name, animation_type='rotate', duration=3.0, fps=30):
437
  out_json = {}
438
 
439
  new_group = []
440
  model_scene = trimesh.Scene()
441
+ primitive_list = []
442
+
443
  color_map = sns.color_palette("hls", primitives['type_code'].squeeze().shape[0])
444
  color_map = (np.array(color_map) * 255).astype("uint8")
445
+
446
  for idx, (scale, rotation, translation, type_code) in enumerate(zip(
447
  primitives['scale'].squeeze().cpu().numpy(),
448
  primitives['rotation'].squeeze().cpu().numpy(),
 
468
  vertices[:, 1] = bs.vertices[:, 2]
469
  vertices[:, 2] = -bs.vertices[:, 1]
470
  bs.vertices = vertices
471
+
472
+ # Add to primitive list for animation
473
+ primitive_list.append((bs, color_map[idx]))
474
  model_scene.add_geometry(bs)
475
+
476
  out_json['group'] = new_group
477
 
478
+ # Save static model
479
  json_path = os.path.join(LOG_PATH, f'output_{name}.json')
480
  with open(json_path, 'w') as json_file:
481
  json.dump(out_json, json_file, indent=4)
482
 
483
+ static_glb_path = os.path.join(LOG_PATH, f'output_{name}.glb')
484
+ model_scene.export(static_glb_path)
485
+
486
+ # Generate animated model
487
+ animated_glb_path = os.path.join(LOG_PATH, f'animated_{name}.glb')
488
+ animated_gif_path = os.path.join(LOG_PATH, f'animated_{name}.gif')
489
+ try:
490
+ animated_glb_path, animated_gif_path = generate_animated_glb(
491
+ primitive_list,
492
+ animation_type=animation_type,
493
+ duration=duration,
494
+ fps=fps,
495
+ output_path=animated_glb_path
496
+ )
497
+ except Exception as e:
498
+ print(f"Error generating animation: {str(e)}")
499
+ animated_glb_path = static_glb_path
500
+ animated_gif_path = None
501
+
502
+ return animated_glb_path, animated_gif_path, out_json
503
 
504
 
505
  @torch.no_grad()
506
+ def do_inference(input_3d, dilated_offset=0.0, sample_seed=0, do_sampling=False,
507
+ do_marching_cubes=False, postprocess='none',
508
+ animation_type='rotate', duration=3.0, fps=30):
509
  t1 = time.time()
510
  set_seed(sample_seed)
511
  input_mesh = trimesh.load(input_3d, force='mesh')
 
552
  else:
553
  recon_primitives, mask = transformer.generate(pc=input_pc, temperature=temperature)
554
 
555
+ output_animated_glb, output_animated_gif, output_json = write_output(
556
+ recon_primitives,
557
+ os.path.basename(input_3d)[:-4],
558
+ animation_type=animation_type,
559
+ duration=duration,
560
+ fps=fps
561
+ )
562
 
563
+ return input_save_name, output_animated_glb, output_animated_gif, output_json
564
 
565
 
566
  import gradio as gr
567
 
568
  @spaces.GPU
569
+ def process_3d_model(input_3d, dilated_offset, do_marching_cubes, animation_type, animation_duration, fps, postprocess_method="postprocess1"):
570
+ print(f"Processing: {input_3d} with animation type: {animation_type}")
 
 
 
 
 
 
 
571
 
572
+ try:
573
+ preprocess_model_obj, output_animated_glb, output_animated_gif, output_model_json = do_inference(
574
+ input_3d,
575
+ dilated_offset=dilated_offset,
576
+ do_marching_cubes=do_marching_cubes,
577
+ postprocess=postprocess_method,
578
+ animation_type=animation_type,
579
+ duration=animation_duration,
580
+ fps=fps
581
+ )
582
+
583
+ # Save JSON to a file
584
+ json_path = os.path.join(LOG_PATH, f'output_{os.path.basename(input_3d)[:-4]}.json')
585
+ with open(json_path, 'w') as f:
586
+ json.dump(output_model_json, f, indent=4)
587
+
588
+ return output_animated_glb, output_animated_gif, json_path
589
+ except Exception as e:
590
+ return f"Error processing file: {str(e)}", None, None
591
 
592
 
593
  _HEADER_ = '''
594
+ <h2><b>[SIGGRAPH 2025] Animated PrimitiveAnything 🤗 Gradio Demo</b></h2>
595
 
596
+ This is an enhanced demo for the SIGGRAPH 2025 paper <a href="">PrimitiveAnything: Human-Crafted 3D Primitive Assembly Generation with Auto-Regressive Transformer</a>, now with animation capabilities!
597
 
598
  Code: <a href='https://github.com/PrimitiveAnything/PrimitiveAnything' target='_blank'>GitHub</a>. Paper: <a href='https://arxiv.org/abs/2505.04622' target='_blank'>ArXiv</a>.
599
 
600
  ❗️❗️❗️**Important Notes:**
601
+ - This demo supports 3D model animation. Upload your GLB file and see it come to life!
602
+ - Choose from different animation styles: rotation, floating, explosion, or assembly.
603
+ - For optimal results with fine structures, we apply marching cubes and dilation operations by default.
604
  '''
605
 
606
  _CITE_ = r"""
 
622
  If you have any questions, feel free to open a discussion or contact us at <b>[email protected]</b>.
623
  """
624
 
625
+ with gr.Blocks(title="PrimitiveAnything with Animation: 3D Animation Generator") as demo:
626
  # Title section
627
  gr.Markdown(_HEADER_)
628
 
 
630
  with gr.Column():
631
  # Input components
632
  input_3d = gr.Model3D(label="Upload 3D Model File")
633
+
634
+ with gr.Row():
635
+ dilated_offset = gr.Number(label="Dilated Offset", value=0.015)
636
+ do_marching_cubes = gr.Checkbox(label="Perform Marching Cubes", value=True)
637
+
638
+ with gr.Row():
639
+ animation_type = gr.Dropdown(
640
+ label="Animation Type",
641
+ choices=["rotate", "float", "explode", "assemble"],
642
+ value="rotate"
643
+ )
644
+
645
+ with gr.Row():
646
+ animation_duration = gr.Slider(
647
+ label="Animation Duration (seconds)",
648
+ minimum=1.0,
649
+ maximum=10.0,
650
+ value=3.0,
651
+ step=0.5
652
+ )
653
+ fps = gr.Slider(
654
+ label="Frames Per Second",
655
+ minimum=15,
656
+ maximum=60,
657
+ value=30,
658
+ step=1
659
+ )
660
+
661
+ submit_btn = gr.Button("Process and Animate Model")
662
 
663
  with gr.Column():
664
  # Output components
665
+ output_3d = gr.Model3D(label="Animated Primitive Assembly")
666
+ output_gif = gr.Image(label="Animation Preview (GIF)")
667
  output_json = gr.File(label="Download JSON File")
668
 
669
  submit_btn.click(
670
  fn=process_3d_model,
671
+ inputs=[input_3d, dilated_offset, do_marching_cubes, animation_type, animation_duration, fps],
672
+ outputs=[output_3d, output_gif, output_json]
673
  )
674
 
 
675
  # Prepare examples properly
676
  example_files = [ [f] for f in glob.glob('./data/demo_glb/*.glb') ] # Note: wrapped in list and filtered for GLB
677