ginipick commited on
Commit
ac14d96
ยท
verified ยท
1 Parent(s): c30bc0f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +327 -532
app.py CHANGED
@@ -2,43 +2,22 @@ import os
2
  import time
3
  import glob
4
  import json
5
- import yaml
6
- import torch
7
  import trimesh
8
  import argparse
9
- import mesh2sdf.core
10
- import numpy as np
11
- import skimage.measure
12
- import seaborn as sns
13
  from scipy.spatial.transform import Rotation
14
- from mesh_to_sdf import get_surface_point_cloud
15
- from accelerate.utils import set_seed
16
- from accelerate import Accelerator
17
- from huggingface_hub.file_download import hf_hub_download
18
- from huggingface_hub import list_repo_files
19
-
20
- # Animation-related imports
21
- import trimesh.transformations as tf
22
  import math
23
- import pyglet
24
-
25
- from primitive_anything.utils import path_mkdir, count_parameters
26
- from primitive_anything.utils.logger import print_log
27
 
28
  os.environ['PYOPENGL_PLATFORM'] = 'egl'
29
 
30
- import spaces
31
-
32
- repo_id = "hyz317/PrimitiveAnything"
33
- all_files = list_repo_files(repo_id, revision="main")
34
- for file in all_files:
35
- if os.path.exists(file):
36
- continue
37
- hf_hub_download(repo_id, file, local_dir="./ckpt")
38
- hf_hub_download("Maikou/Michelangelo", "checkpoints/aligned_shape_latents/shapevae-256.ckpt", local_dir="./ckpt")
39
 
40
  def parse_args():
41
- parser = argparse.ArgumentParser(description='Process 3D model files with animation')
42
 
43
  parser.add_argument(
44
  '--input',
@@ -58,7 +37,7 @@ def parse_args():
58
  '--animation_type',
59
  type=str,
60
  default='rotate',
61
- choices=['rotate', 'float', 'explode', 'assemble'],
62
  help='Type of animation to apply'
63
  )
64
 
@@ -93,595 +72,411 @@ os.makedirs(LOG_PATH, exist_ok=True)
93
 
94
  print(f"Output directory: {LOG_PATH}")
95
 
96
- CODE_SHAPE = {
97
- 0: 'SM_GR_BS_CubeBevel_001.ply',
98
- 1: 'SM_GR_BS_SphereSharp_001.ply',
99
- 2: 'SM_GR_BS_CylinderSharp_001.ply',
100
- }
101
-
102
- shapename_map = {
103
- 'SM_GR_BS_CubeBevel_001.ply': 1101002001034001,
104
- 'SM_GR_BS_SphereSharp_001.ply': 1101002001034010,
105
- 'SM_GR_BS_CylinderSharp_001.ply': 1101002001034002,
106
- }
107
-
108
- #### config
109
- bs_dir = 'data/basic_shapes_norm'
110
- config_path = './configs/infer.yml'
111
- AR_checkpoint_path = './ckpt/mesh-transformer.ckpt.60.pt'
112
- temperature= 0.0
113
- #### init model
114
- mesh_bs = {}
115
- for bs_path in glob.glob(os.path.join(bs_dir, '*.ply')):
116
- bs_name = os.path.basename(bs_path)
117
- bs = trimesh.load(bs_path)
118
- bs.visual.uv = np.clip(bs.visual.uv, 0, 1)
119
- bs.visual = bs.visual.to_color()
120
- mesh_bs[bs_name] = bs
121
-
122
- def create_model(cfg_model):
123
- kwargs = cfg_model
124
- name = kwargs.pop('name')
125
- model = get_model(name)(**kwargs)
126
- print_log("Model '{}' init: nb_params={:,}, kwargs={}".format(name, count_parameters(model), kwargs))
127
- return model
128
-
129
- from primitive_anything.primitive_transformer import PrimitiveTransformerDiscrete
130
- def get_model(name):
131
- return {
132
- 'discrete': PrimitiveTransformerDiscrete,
133
- }[name]
134
-
135
- with open(config_path, mode='r') as fp:
136
- AR_train_cfg = yaml.load(fp, Loader=yaml.FullLoader)
137
-
138
- AR_checkpoint = torch.load(AR_checkpoint_path)
139
-
140
- transformer = create_model(AR_train_cfg['model'])
141
- transformer.load_state_dict(AR_checkpoint)
142
-
143
- device = torch.device('cuda')
144
- accelerator = Accelerator(
145
- mixed_precision='fp16',
146
- )
147
- transformer = accelerator.prepare(transformer)
148
- transformer.eval()
149
- transformer.bs_pc = transformer.bs_pc.cuda()
150
- transformer.rotation_matrix_align_coord = transformer.rotation_matrix_align_coord.cuda()
151
- print('model loaded to device')
152
-
153
-
154
- def sample_surface_points(mesh, number_of_points=500000, surface_point_method='scan', sign_method='normal',
155
- scan_count=100, scan_resolution=400, sample_point_count=10000000, return_gradients=False,
156
- return_surface_pc_normals=False, normalized=False):
157
- sample_start = time.time()
158
- if surface_point_method == 'sample' and sign_method == 'depth':
159
- print("Incompatible methods for sampling points and determining sign, using sign_method='normal' instead.")
160
- sign_method = 'normal'
161
-
162
- surface_start = time.time()
163
- bound_radius = 1 if normalized else None
164
- surface_point_cloud = get_surface_point_cloud(mesh, surface_point_method, bound_radius, scan_count, scan_resolution,
165
- sample_point_count,
166
- calculate_normals=sign_method == 'normal' or return_gradients)
167
-
168
- surface_end = time.time()
169
- print('surface point cloud time cost :', surface_end - surface_start)
170
-
171
- normal_start = time.time()
172
- if return_surface_pc_normals:
173
- rng = np.random.default_rng()
174
- assert surface_point_cloud.points.shape[0] == surface_point_cloud.normals.shape[0]
175
- indices = rng.choice(surface_point_cloud.points.shape[0], number_of_points, replace=True)
176
- points = surface_point_cloud.points[indices]
177
- normals = surface_point_cloud.normals[indices]
178
- surface_points = np.concatenate([points, normals], axis=-1)
179
- else:
180
- surface_points = surface_point_cloud.get_random_surface_points(number_of_points, use_scans=True)
181
- normal_end = time.time()
182
- print('normal time cost :', normal_end - normal_start)
183
- sample_end = time.time()
184
- print('sample surface point time cost :', sample_end - sample_start)
185
- return surface_points
186
-
187
-
188
- def normalize_vertices(vertices, scale=0.9):
189
- bbmin, bbmax = vertices.min(0), vertices.max(0)
190
- center = (bbmin + bbmax) * 0.5
191
- scale = 2.0 * scale / (bbmax - bbmin).max()
192
- vertices = (vertices - center) * scale
193
- return vertices, center, scale
194
-
195
-
196
- def export_to_watertight(normalized_mesh, octree_depth: int = 7):
197
- """
198
- Convert the non-watertight mesh to watertight.
199
-
200
- Args:
201
- input_path (str): normalized path
202
- octree_depth (int):
203
-
204
- Returns:
205
- mesh(trimesh.Trimesh): watertight mesh
206
-
207
- """
208
- size = 2 ** octree_depth
209
- level = 2 / size
210
-
211
- scaled_vertices, to_orig_center, to_orig_scale = normalize_vertices(normalized_mesh.vertices)
212
- sdf = mesh2sdf.core.compute(scaled_vertices, normalized_mesh.faces, size=size)
213
- vertices, faces, normals, _ = skimage.measure.marching_cubes(np.abs(sdf), level)
214
-
215
- # watertight mesh
216
- vertices = vertices / size * 2 - 1 # -1 to 1
217
- vertices = vertices / to_orig_scale + to_orig_center
218
- mesh = trimesh.Trimesh(vertices, faces, normals=normals)
219
-
220
- return mesh
221
-
222
-
223
- def process_mesh_to_surface_pc(mesh_list, marching_cubes=False, dilated_offset=0.0, sample_num=10000):
224
- # mesh_list : list of trimesh
225
- pc_normal_list = []
226
- return_mesh_list = []
227
- for mesh in mesh_list:
228
- if marching_cubes:
229
- mesh = export_to_watertight(mesh)
230
- print("MC over!")
231
- if dilated_offset > 0:
232
- new_vertices = mesh.vertices + mesh.vertex_normals * dilated_offset
233
- mesh.vertices = new_vertices
234
- print("dilate over!")
235
-
236
- mesh.merge_vertices()
237
- mesh.update_faces(mesh.unique_faces())
238
- mesh.fix_normals()
239
-
240
- return_mesh_list.append(mesh)
241
-
242
- pc_normal = np.asarray(sample_surface_points(mesh, sample_num, return_surface_pc_normals=True))
243
- pc_normal_list.append(pc_normal)
244
- print("process mesh success")
245
- return pc_normal_list, return_mesh_list
246
-
247
-
248
- #### utils
249
- def euler_to_quat(euler):
250
- return Rotation.from_euler('XYZ', euler, degrees=True).as_quat()
251
-
252
- def SRT_quat_to_matrix(scale, quat, translation):
253
- rotation_matrix = Rotation.from_quat(quat).as_matrix()
254
- transform_matrix = np.eye(4)
255
- transform_matrix[:3, :3] = rotation_matrix * scale
256
- transform_matrix[:3, 3] = translation
257
- return transform_matrix
258
-
259
 
260
- # Animation Functions
261
- def create_rotation_animation(primitive_list, duration=3.0, fps=30):
262
- """Create a rotation animation for each primitive"""
263
  num_frames = int(duration * fps)
264
  frames = []
265
 
 
 
 
266
  for frame_idx in range(num_frames):
267
  t = frame_idx / (num_frames - 1) # Normalized time [0, 1]
268
  angle = t * 2 * math.pi # Full rotation
269
 
270
- frame_scene = trimesh.Scene()
271
- for idx, (primitive, color) in enumerate(primitive_list):
272
- # Create a copy of the primitive to animate
273
- animated_primitive = primitive.copy()
274
-
275
- # Apply rotation around Y axis
276
- rotation_matrix = tf.rotation_matrix(angle, [0, 1, 0], animated_primitive.centroid)
277
- animated_primitive.apply_transform(rotation_matrix)
278
-
279
- # Add to scene with original color
280
- frame_scene.add_geometry(animated_primitive, node_name=f'primitive_{idx}')
281
 
282
- frames.append(frame_scene)
 
 
 
 
 
283
 
284
  return frames
285
 
286
- def create_float_animation(primitive_list, duration=3.0, fps=30, amplitude=0.1):
287
- """Create a floating animation where primitives move up and down"""
288
  num_frames = int(duration * fps)
289
  frames = []
290
 
 
 
 
291
  for frame_idx in range(num_frames):
292
  t = frame_idx / (num_frames - 1) # Normalized time [0, 1]
293
- frame_scene = trimesh.Scene()
294
 
295
- for idx, (primitive, color) in enumerate(primitive_list):
296
- # Create a copy of the primitive to animate
297
- animated_primitive = primitive.copy()
298
-
299
- # Apply floating motion (sinusoidal)
300
- phase_offset = 2 * math.pi * (idx / len(primitive_list)) # Different phase for each primitive
301
- y_offset = amplitude * math.sin(2 * math.pi * t + phase_offset)
302
-
303
- translation_matrix = tf.translation_matrix([0, y_offset, 0])
304
- animated_primitive.apply_transform(translation_matrix)
305
-
306
- # Add to scene with original color
307
- frame_scene.add_geometry(animated_primitive, node_name=f'primitive_{idx}')
308
 
309
- frames.append(frame_scene)
 
 
 
 
 
 
310
 
311
  return frames
312
 
313
- def create_explode_animation(primitive_list, duration=3.0, fps=30, max_distance=0.5):
314
- """Create an explode animation where primitives move outward from center"""
315
  num_frames = int(duration * fps)
316
  frames = []
317
 
318
- # Calculate center of the model
319
- all_vertices = np.vstack([p.vertices for p, _ in primitive_list])
320
- center = np.mean(all_vertices, axis=0)
 
 
 
 
 
 
 
 
 
321
 
322
  for frame_idx in range(num_frames):
323
  t = frame_idx / (num_frames - 1) # Normalized time [0, 1]
324
- frame_scene = trimesh.Scene()
325
 
326
- for idx, (primitive, color) in enumerate(primitive_list):
327
- # Create a copy of the primitive to animate
328
- animated_primitive = primitive.copy()
329
 
330
- # Calculate direction from center to primitive centroid
331
- primitive_center = primitive.centroid
332
- direction = primitive_center - center
333
- if np.linalg.norm(direction) < 1e-10:
334
- # If primitive is at center, choose random direction
335
- direction = np.random.rand(3) - 0.5
 
 
 
 
 
 
336
 
337
- direction = direction / np.linalg.norm(direction)
 
 
 
 
 
 
338
 
339
- # Apply explosion movement
340
- translation = direction * t * max_distance
341
- translation_matrix = tf.translation_matrix(translation)
342
- animated_primitive.apply_transform(translation_matrix)
 
 
343
 
344
- # Add to scene with original color
345
- frame_scene.add_geometry(animated_primitive, node_name=f'primitive_{idx}')
 
 
 
 
 
 
 
 
346
 
347
- frames.append(frame_scene)
 
348
 
349
  return frames
350
 
351
- def create_assemble_animation(primitive_list, duration=3.0, fps=30, start_distance=1.0):
352
- """Create an assembly animation where primitives move inward to form the model"""
 
 
 
 
 
 
353
  num_frames = int(duration * fps)
354
  frames = []
355
 
356
- # Calculate center of the model
357
- all_vertices = np.vstack([p.vertices for p, _ in primitive_list])
358
- center = np.mean(all_vertices, axis=0)
359
-
360
- # Store original positions
361
- original_primitives = [(p.copy(), c) for p, c in primitive_list]
362
 
363
  for frame_idx in range(num_frames):
364
  t = frame_idx / (num_frames - 1) # Normalized time [0, 1]
365
- frame_scene = trimesh.Scene()
366
 
367
- for idx, ((original_primitive, color), (primitive, _)) in enumerate(zip(original_primitives, primitive_list)):
368
- # Create a copy of the primitive to animate
369
- animated_primitive = original_primitive.copy()
370
-
371
- # Calculate direction from center to primitive centroid
372
- primitive_center = primitive.centroid
373
- direction = primitive_center - center
374
- if np.linalg.norm(direction) < 1e-10:
375
- # If primitive is at center, choose random direction
376
- direction = np.random.rand(3) - 0.5
377
-
378
- direction = direction / np.linalg.norm(direction)
379
-
380
- # Apply assembly movement (1.0 - t for reverse of explosion)
381
- translation = direction * (1.0 - t) * start_distance
382
- translation_matrix = tf.translation_matrix(translation)
383
- animated_primitive.apply_transform(translation_matrix)
384
-
385
- # Add to scene with original color
386
- frame_scene.add_geometry(animated_primitive, node_name=f'primitive_{idx}')
387
 
388
- frames.append(frame_scene)
 
389
 
390
  return frames
391
 
392
- def generate_animated_glb(primitive_list, animation_type='rotate', duration=3.0, fps=30, output_path="animated_model.glb"):
393
- """Generate animated GLB file with primitives"""
394
- if animation_type == 'rotate':
395
- frames = create_rotation_animation(primitive_list, duration, fps)
396
- elif animation_type == 'float':
397
- frames = create_float_animation(primitive_list, duration, fps)
398
- elif animation_type == 'explode':
399
- frames = create_explode_animation(primitive_list, duration, fps)
400
- elif animation_type == 'assemble':
401
- frames = create_assemble_animation(primitive_list, duration, fps)
402
- else:
403
- raise ValueError(f"Unknown animation type: {animation_type}")
404
 
405
- # Export animation frames to GLB
406
- # For simplicity, we'll export the first frame and last frame
407
- # In a production environment, you would use a proper animation exporter
408
- first_frame = frames[0]
409
- first_frame.export(output_path)
410
 
411
- # Also create a GIF for preview
412
- gif_path = output_path.replace('.glb', '.gif')
413
- try:
414
- # Simple gif export using pyglet
415
- gif_frames = []
416
- for frame in frames:
417
- img = frame.save_image(resolution=[640, 480])
418
- gif_frames.append(img)
419
 
420
- # Use PIL to save as GIF
421
- from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
422
  gif_frames[0].save(
423
- gif_path,
424
  save_all=True,
425
  append_images=gif_frames[1:],
426
  optimize=False,
427
  duration=int(1000 / fps),
428
  loop=0
429
  )
 
 
 
 
 
 
 
 
 
430
  except Exception as e:
431
- print(f"Error creating GIF: {str(e)}")
 
432
 
433
- return output_path, gif_path
434
-
435
-
436
- def write_output(primitives, name, animation_type='rotate', duration=3.0, fps=30):
437
- out_json = {}
438
-
439
- new_group = []
440
- model_scene = trimesh.Scene()
441
- primitive_list = []
442
-
443
- color_map = sns.color_palette("hls", primitives['type_code'].squeeze().shape[0])
444
- color_map = (np.array(color_map) * 255).astype("uint8")
445
-
446
- for idx, (scale, rotation, translation, type_code) in enumerate(zip(
447
- primitives['scale'].squeeze().cpu().numpy(),
448
- primitives['rotation'].squeeze().cpu().numpy(),
449
- primitives['translation'].squeeze().cpu().numpy(),
450
- primitives['type_code'].squeeze().cpu().numpy()
451
- )):
452
- if type_code == -1:
453
- break
454
- bs_name = CODE_SHAPE[type_code]
455
- new_block = {}
456
- new_block['type_id'] = shapename_map[bs_name]
457
- new_block['data'] = {}
458
- new_block['data']['location'] = translation.tolist()
459
- new_block['data']['rotation'] = euler_to_quat(rotation).tolist()
460
- new_block['data']['scale'] = scale.tolist()
461
- new_group.append(new_block)
462
-
463
- trans = SRT_quat_to_matrix(scale, euler_to_quat(rotation), translation)
464
- bs = mesh_bs[bs_name].copy().apply_transform(trans)
465
- new_vertex_colors = np.repeat(color_map[idx:idx+1], bs.visual.vertex_colors.shape[0], axis=0)
466
- bs.visual.vertex_colors[:, :3] = new_vertex_colors
467
- vertices = bs.vertices.copy()
468
- vertices[:, 1] = bs.vertices[:, 2]
469
- vertices[:, 2] = -bs.vertices[:, 1]
470
- bs.vertices = vertices
471
-
472
- # Add to primitive list for animation
473
- primitive_list.append((bs, color_map[idx]))
474
- model_scene.add_geometry(bs)
475
 
476
- out_json['group'] = new_group
477
-
478
- # Save static model
479
- json_path = os.path.join(LOG_PATH, f'output_{name}.json')
480
- with open(json_path, 'w') as json_file:
481
- json.dump(out_json, json_file, indent=4)
482
-
483
- static_glb_path = os.path.join(LOG_PATH, f'output_{name}.glb')
484
- model_scene.export(static_glb_path)
 
 
 
 
485
 
486
- # Generate animated model
487
- animated_glb_path = os.path.join(LOG_PATH, f'animated_{name}.glb')
488
- animated_gif_path = os.path.join(LOG_PATH, f'animated_{name}.gif')
489
  try:
490
- animated_glb_path, animated_gif_path = generate_animated_glb(
491
- primitive_list,
492
- animation_type=animation_type,
493
- duration=duration,
494
- fps=fps,
495
- output_path=animated_glb_path
496
- )
497
  except Exception as e:
498
- print(f"Error generating animation: {str(e)}")
499
- animated_glb_path = static_glb_path
500
  animated_gif_path = None
 
 
501
 
502
- return animated_glb_path, animated_gif_path, out_json
503
-
504
-
505
- @torch.no_grad()
506
- def do_inference(input_3d, dilated_offset=0.0, sample_seed=0, do_sampling=False,
507
- do_marching_cubes=False, postprocess='none',
508
- animation_type='rotate', duration=3.0, fps=30):
509
- t1 = time.time()
510
- set_seed(sample_seed)
511
- input_mesh = trimesh.load(input_3d, force='mesh')
512
-
513
- # scale mesh
514
- vertices = input_mesh.vertices
515
- bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)])
516
- vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2
517
- vertices = vertices / (bounds[1] - bounds[0]).max() * 1.6
518
- input_mesh.vertices = vertices
519
-
520
- pc_list, mesh_list = process_mesh_to_surface_pc(
521
- [input_mesh],
522
- marching_cubes=do_marching_cubes,
523
- dilated_offset=dilated_offset
524
- )
525
- pc_normal = pc_list[0] # 10000, 6
526
- mesh = mesh_list[0]
527
-
528
- pc_coor = pc_normal[:, :3]
529
- normals = pc_normal[:, 3:]
530
-
531
- if dilated_offset > 0:
532
- # scale mesh and pc
533
- vertices = mesh.vertices
534
- bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)])
535
- vertices = vertices - (bounds[0] + bounds[1])[None, :] / 2
536
- vertices = vertices / (bounds[1] - bounds[0]).max() * 1.6
537
- mesh.vertices = vertices
538
- pc_coor = pc_coor - (bounds[0] + bounds[1])[None, :] / 2
539
- pc_coor = pc_coor / (bounds[1] - bounds[0]).max() * 1.6
540
-
541
- input_save_name = os.path.join(LOG_PATH, f'processed_{os.path.basename(input_3d)}')
542
- mesh.export(input_save_name)
543
-
544
- assert (np.linalg.norm(normals, axis=-1) > 0.99).all(), 'normals should be unit vectors, something wrong'
545
- normalized_pc_normal = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)
546
-
547
- input_pc = torch.tensor(normalized_pc_normal, dtype=torch.float16, device=device)[None]
548
-
549
- with accelerator.autocast():
550
- if postprocess == 'postprocess1':
551
- recon_primitives, mask = transformer.generate_w_recon_loss(pc=input_pc, temperature=temperature, single_directional=True)
552
- else:
553
- recon_primitives, mask = transformer.generate(pc=input_pc, temperature=temperature)
554
-
555
- output_animated_glb, output_animated_gif, output_json = write_output(
556
- recon_primitives,
557
- os.path.basename(input_3d)[:-4],
558
- animation_type=animation_type,
559
- duration=duration,
560
- fps=fps
561
- )
562
-
563
- return input_save_name, output_animated_glb, output_animated_gif, output_json
564
-
565
-
566
- import gradio as gr
567
-
568
- @spaces.GPU
569
- def process_3d_model(input_3d, dilated_offset, do_marching_cubes, animation_type, animation_duration, fps, postprocess_method="postprocess1"):
570
  print(f"Processing: {input_3d} with animation type: {animation_type}")
571
 
572
  try:
573
- preprocess_model_obj, output_animated_glb, output_animated_gif, output_model_json = do_inference(
 
574
  input_3d,
575
- dilated_offset=dilated_offset,
576
- do_marching_cubes=do_marching_cubes,
577
- postprocess=postprocess_method,
578
  animation_type=animation_type,
579
  duration=animation_duration,
580
  fps=fps
581
  )
582
 
583
- # Save JSON to a file
584
- json_path = os.path.join(LOG_PATH, f'output_{os.path.basename(input_3d)[:-4]}.json')
 
 
 
 
 
 
 
 
 
 
 
585
  with open(json_path, 'w') as f:
586
- json.dump(output_model_json, f, indent=4)
587
 
588
- return output_animated_glb, output_animated_gif, json_path
589
  except Exception as e:
590
- return f"Error processing file: {str(e)}", None, None
591
-
 
592
 
593
  _HEADER_ = '''
594
- <h2><b>[SIGGRAPH 2025] Animated PrimitiveAnything ๐Ÿค— Gradio Demo</b></h2>
595
 
596
- This is an enhanced demo for the SIGGRAPH 2025 paper <a href="">PrimitiveAnything: Human-Crafted 3D Primitive Assembly Generation with Auto-Regressive Transformer</a>, now with animation capabilities!
597
 
598
- Code: <a href='https://github.com/PrimitiveAnything/PrimitiveAnything' target='_blank'>GitHub</a>. Paper: <a href='https://arxiv.org/abs/2505.04622' target='_blank'>ArXiv</a>.
599
-
600
- โ—๏ธโ—๏ธโ—๏ธ**Important Notes:**
601
- - This demo supports 3D model animation. Upload your GLB file and see it come to life!
602
- - Choose from different animation styles: rotation, floating, explosion, or assembly.
603
- - For optimal results with fine structures, we apply marching cubes and dilation operations by default.
604
  '''
605
 
606
- _CITE_ = r"""
607
- If PrimitiveAnything is helpful, please help to โญ the <a href='https://github.com/PrimitiveAnything/PrimitiveAnything' target='_blank'>GitHub Repo</a>. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/PrimitiveAnything/PrimitiveAnything?style=social)](https://github.com/PrimitiveAnything/PrimitiveAnything)
608
- ---
609
- ๐Ÿ“ **Citation**
610
- If you find our work useful for your research or applications, please cite using this bibtex:
611
- ```bibtex
612
- @misc{ye2025primitiveanything,
613
- title={PrimitiveAnything: Human-Crafted 3D Primitive Assembly Generation with Auto-Regressive Transformer},
614
- author={Jingwen Ye and Yuze He and Yanning Zhou and Yiqin Zhu and Kaiwen Xiao and Yong-Jin Liu and Wei Yang and Xiao Han},
615
- year={2025},
616
- eprint={2505.04622},
617
- archivePrefix={arXiv},
618
- primaryClass={cs.GR}
619
- }
620
- ```
621
- ๐Ÿ“ง **Contact**
622
- If you have any questions, feel free to open a discussion or contact us at <b>[email protected]</b>.
623
  """
624
 
625
- with gr.Blocks(title="PrimitiveAnything with Animation: 3D Animation Generator") as demo:
626
- # Title section
627
- gr.Markdown(_HEADER_)
628
-
629
- with gr.Row():
630
- with gr.Column():
631
- # Input components
632
- input_3d = gr.Model3D(label="Upload 3D Model File")
633
-
634
- with gr.Row():
635
- dilated_offset = gr.Number(label="Dilated Offset", value=0.015)
636
- do_marching_cubes = gr.Checkbox(label="Perform Marching Cubes", value=True)
637
-
638
- with gr.Row():
639
- animation_type = gr.Dropdown(
640
- label="Animation Type",
641
- choices=["rotate", "float", "explode", "assemble"],
642
- value="rotate"
643
- )
644
-
645
- with gr.Row():
646
- animation_duration = gr.Slider(
647
- label="Animation Duration (seconds)",
648
- minimum=1.0,
649
- maximum=10.0,
650
- value=3.0,
651
- step=0.5
652
- )
653
- fps = gr.Slider(
654
- label="Frames Per Second",
655
- minimum=15,
656
- maximum=60,
657
- value=30,
658
- step=1
659
- )
660
-
661
- submit_btn = gr.Button("Process and Animate Model")
662
-
663
- with gr.Column():
664
- # Output components
665
- output_3d = gr.Model3D(label="Animated Primitive Assembly")
666
- output_gif = gr.Image(label="Animation Preview (GIF)")
667
- output_json = gr.File(label="Download JSON File")
668
-
669
- submit_btn.click(
670
- fn=process_3d_model,
671
- inputs=[input_3d, dilated_offset, do_marching_cubes, animation_type, animation_duration, fps],
672
- outputs=[output_3d, output_gif, output_json]
673
- )
 
 
 
 
 
 
 
 
 
 
 
 
674
 
675
- # Prepare examples properly
676
- example_files = [ [f] for f in glob.glob('./data/demo_glb/*.glb') ] # Note: wrapped in list and filtered for GLB
677
-
678
- example = gr.Examples(
679
- examples=example_files,
680
- inputs=[input_3d], # Only include the Model3D input
681
- examples_per_page=14,
682
- )
683
-
684
- gr.Markdown(_CITE_)
685
 
 
686
  if __name__ == "__main__":
687
- demo.launch(ssr_mode=False)
 
 
2
  import time
3
  import glob
4
  import json
5
+ import numpy as np
 
6
  import trimesh
7
  import argparse
 
 
 
 
8
  from scipy.spatial.transform import Rotation
9
+ import PIL.Image
10
+ from PIL import Image
 
 
 
 
 
 
11
  import math
12
+ import trimesh.transformations as tf
13
+ from trimesh.exchange.gltf import export_glb
 
 
14
 
15
  os.environ['PYOPENGL_PLATFORM'] = 'egl'
16
 
17
+ import gradio as gr
 
 
 
 
 
 
 
 
18
 
19
  def parse_args():
20
+ parser = argparse.ArgumentParser(description='Create animations for 3D models')
21
 
22
  parser.add_argument(
23
  '--input',
 
37
  '--animation_type',
38
  type=str,
39
  default='rotate',
40
+ choices=['rotate', 'float', 'explode', 'assemble', 'pulse', 'swing'],
41
  help='Type of animation to apply'
42
  )
43
 
 
72
 
73
  print(f"Output directory: {LOG_PATH}")
74
 
75
+ def normalize_mesh(mesh):
76
+ """Normalize mesh to fit in a unit cube centered at origin"""
77
+ vertices = mesh.vertices
78
+ bounds = np.array([vertices.min(axis=0), vertices.max(axis=0)])
79
+ center = (bounds[0] + bounds[1]) / 2
80
+ scale = 1.0 / (bounds[1] - bounds[0]).max()
81
+
82
+ # Create a copy to avoid modifying the original
83
+ normalized_mesh = mesh.copy()
84
+ normalized_mesh.vertices = (vertices - center) * scale
85
+
86
+ return normalized_mesh, center, scale
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
+ def create_rotation_animation(mesh, duration=3.0, fps=30):
89
+ """Create a rotation animation around the Y axis"""
 
90
  num_frames = int(duration * fps)
91
  frames = []
92
 
93
+ # Normalize the mesh for consistent animation
94
+ mesh, original_center, original_scale = normalize_mesh(mesh)
95
+
96
  for frame_idx in range(num_frames):
97
  t = frame_idx / (num_frames - 1) # Normalized time [0, 1]
98
  angle = t * 2 * math.pi # Full rotation
99
 
100
+ # Create a copy of the mesh to animate
101
+ animated_mesh = mesh.copy()
 
 
 
 
 
 
 
 
 
102
 
103
+ # Apply rotation around Y axis
104
+ rotation_matrix = tf.rotation_matrix(angle, [0, 1, 0])
105
+ animated_mesh.apply_transform(rotation_matrix)
106
+
107
+ # Add to frames
108
+ frames.append(animated_mesh)
109
 
110
  return frames
111
 
112
+ def create_float_animation(mesh, duration=3.0, fps=30, amplitude=0.2):
113
+ """Create a floating animation where the mesh moves up and down"""
114
  num_frames = int(duration * fps)
115
  frames = []
116
 
117
+ # Normalize the mesh for consistent animation
118
+ mesh, original_center, original_scale = normalize_mesh(mesh)
119
+
120
  for frame_idx in range(num_frames):
121
  t = frame_idx / (num_frames - 1) # Normalized time [0, 1]
 
122
 
123
+ # Create a copy of the mesh to animate
124
+ animated_mesh = mesh.copy()
 
 
 
 
 
 
 
 
 
 
 
125
 
126
+ # Apply floating motion (sinusoidal)
127
+ y_offset = amplitude * math.sin(2 * math.pi * t)
128
+ translation_matrix = tf.translation_matrix([0, y_offset, 0])
129
+ animated_mesh.apply_transform(translation_matrix)
130
+
131
+ # Add to frames
132
+ frames.append(animated_mesh)
133
 
134
  return frames
135
 
136
+ def create_explode_animation(mesh, duration=3.0, fps=30):
137
+ """Create an explode animation where parts of the mesh move outward"""
138
  num_frames = int(duration * fps)
139
  frames = []
140
 
141
+ # Normalize the mesh for consistent animation
142
+ mesh, original_center, original_scale = normalize_mesh(mesh)
143
+
144
+ # Split the mesh into components
145
+ # If the mesh can't be split, we'll just move vertices outward
146
+ try:
147
+ components = mesh.split(only_watertight=False)
148
+ if len(components) <= 1:
149
+ raise ValueError("Mesh cannot be split into components")
150
+ except:
151
+ # If splitting fails, work with the original mesh
152
+ components = None
153
 
154
  for frame_idx in range(num_frames):
155
  t = frame_idx / (num_frames - 1) # Normalized time [0, 1]
 
156
 
157
+ if components:
158
+ # Create a scene to hold all components
159
+ scene = trimesh.Scene()
160
 
161
+ # Move each component outward from center
162
+ for component in components:
163
+ # Create a copy of the component
164
+ animated_component = component.copy()
165
+
166
+ # Calculate direction from center to component centroid
167
+ direction = animated_component.centroid
168
+ if np.linalg.norm(direction) < 1e-10:
169
+ # If component is at center, choose random direction
170
+ direction = np.random.rand(3) - 0.5
171
+
172
+ direction = direction / np.linalg.norm(direction)
173
 
174
+ # Apply explosion movement
175
+ translation = direction * t * 0.5 # Scale factor for explosion
176
+ translation_matrix = tf.translation_matrix(translation)
177
+ animated_component.apply_transform(translation_matrix)
178
+
179
+ # Add to scene
180
+ scene.add_geometry(animated_component)
181
 
182
+ # Convert scene to mesh (approximation)
183
+ animated_mesh = trimesh.util.concatenate(scene.dump())
184
+ else:
185
+ # Work with vertices directly if components approach failed
186
+ animated_mesh = mesh.copy()
187
+ vertices = animated_mesh.vertices.copy()
188
 
189
+ # Calculate directions from center (0,0,0) to each vertex
190
+ directions = vertices.copy()
191
+ norms = np.linalg.norm(directions, axis=1, keepdims=True)
192
+ mask = norms > 1e-10
193
+ directions[mask] = directions[mask] / norms[mask]
194
+ directions[~mask] = np.random.rand(np.sum(~mask), 3) - 0.5
195
+
196
+ # Apply explosion factor
197
+ vertices += directions * t * 0.3
198
+ animated_mesh.vertices = vertices
199
 
200
+ # Add to frames
201
+ frames.append(animated_mesh)
202
 
203
  return frames
204
 
205
+ def create_assemble_animation(mesh, duration=3.0, fps=30):
206
+ """Create an assembly animation (reverse of explode)"""
207
+ # Get explode animation and reverse it
208
+ explode_frames = create_explode_animation(mesh, duration, fps)
209
+ return list(reversed(explode_frames))
210
+
211
+ def create_pulse_animation(mesh, duration=3.0, fps=30, min_scale=0.8, max_scale=1.2):
212
+ """Create a pulsing animation where the mesh scales up and down"""
213
  num_frames = int(duration * fps)
214
  frames = []
215
 
216
+ # Normalize the mesh for consistent animation
217
+ mesh, original_center, original_scale = normalize_mesh(mesh)
 
 
 
 
218
 
219
  for frame_idx in range(num_frames):
220
  t = frame_idx / (num_frames - 1) # Normalized time [0, 1]
 
221
 
222
+ # Create a copy of the mesh to animate
223
+ animated_mesh = mesh.copy()
224
+
225
+ # Apply pulsing motion (sinusoidal scale)
226
+ scale_factor = min_scale + (max_scale - min_scale) * (0.5 + 0.5 * math.sin(2 * math.pi * t))
227
+ scale_matrix = tf.scale_matrix(scale_factor)
228
+ animated_mesh.apply_transform(scale_matrix)
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
+ # Add to frames
231
+ frames.append(animated_mesh)
232
 
233
  return frames
234
 
235
+ def create_swing_animation(mesh, duration=3.0, fps=30, max_angle=math.pi/6):
236
+ """Create a swinging animation where the mesh rotates back and forth"""
237
+ num_frames = int(duration * fps)
238
+ frames = []
 
 
 
 
 
 
 
 
239
 
240
+ # Normalize the mesh for consistent animation
241
+ mesh, original_center, original_scale = normalize_mesh(mesh)
 
 
 
242
 
243
+ for frame_idx in range(num_frames):
244
+ t = frame_idx / (num_frames - 1) # Normalized time [0, 1]
245
+
246
+ # Create a copy of the mesh to animate
247
+ animated_mesh = mesh.copy()
 
 
 
248
 
249
+ # Apply swinging motion (sinusoidal rotation)
250
+ angle = max_angle * math.sin(2 * math.pi * t)
251
+ rotation_matrix = tf.rotation_matrix(angle, [0, 1, 0])
252
+ animated_mesh.apply_transform(rotation_matrix)
253
+
254
+ # Add to frames
255
+ frames.append(animated_mesh)
256
+
257
+ return frames
258
+
259
+ def generate_gif_from_frames(frames, output_path, fps=30, resolution=(640, 480), background_color=(255, 255, 255, 255)):
260
+ """Generate a GIF from animation frames"""
261
+ gif_frames = []
262
+
263
+ for frame in frames:
264
+ # Create a scene with the frame
265
+ scene = trimesh.Scene(frame)
266
+
267
+ # Set camera and rendering parameters
268
+ try:
269
+ # Try to get a good view of the object
270
+ scene.camera_transform = scene.camera_transform
271
+ except:
272
+ # If that fails, use a default camera position
273
+ scene.camera_transform = tf.translation_matrix([0, 0, 2])
274
+
275
+ # Render the frame
276
+ try:
277
+ img = scene.save_image(resolution=resolution, background=background_color)
278
+ gif_frames.append(Image.open(img))
279
+ except Exception as e:
280
+ print(f"Error rendering frame: {str(e)}")
281
+ # Create a blank image if rendering fails
282
+ gif_frames.append(Image.new('RGB', resolution, (255, 255, 255)))
283
+
284
+ # Save as GIF
285
+ if gif_frames:
286
  gif_frames[0].save(
287
+ output_path,
288
  save_all=True,
289
  append_images=gif_frames[1:],
290
  optimize=False,
291
  duration=int(1000 / fps),
292
  loop=0
293
  )
294
+ return output_path
295
+ else:
296
+ return None
297
+
298
+ def create_animation_mesh(input_mesh_path, animation_type='rotate', duration=3.0, fps=30):
299
+ """Create animation from input mesh based on animation type"""
300
+ # Load the mesh
301
+ try:
302
+ mesh = trimesh.load(input_mesh_path)
303
  except Exception as e:
304
+ print(f"Error loading mesh: {str(e)}")
305
+ return None, None
306
 
307
+ # Generate animation frames based on animation type
308
+ if animation_type == 'rotate':
309
+ frames = create_rotation_animation(mesh, duration, fps)
310
+ elif animation_type == 'float':
311
+ frames = create_float_animation(mesh, duration, fps)
312
+ elif animation_type == 'explode':
313
+ frames = create_explode_animation(mesh, duration, fps)
314
+ elif animation_type == 'assemble':
315
+ frames = create_assemble_animation(mesh, duration, fps)
316
+ elif animation_type == 'pulse':
317
+ frames = create_pulse_animation(mesh, duration, fps)
318
+ elif animation_type == 'swing':
319
+ frames = create_swing_animation(mesh, duration, fps)
320
+ else:
321
+ print(f"Unknown animation type: {animation_type}")
322
+ return None, None
323
+
324
+ base_filename = os.path.basename(input_mesh_path).rsplit('.', 1)[0]
325
+
326
+ # Save animated mesh as GLB
327
+ try:
328
+ animated_glb_path = os.path.join(LOG_PATH, f'animated_{base_filename}.glb')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
 
330
+ # For GLB output, we'll use the first frame for now
331
+ # In a production environment, you'd want to use proper animation keyframes
332
+ if frames and len(frames) > 0:
333
+ # First frame for static GLB
334
+ first_frame = frames[0]
335
+ # Export as GLB
336
+ scene = trimesh.Scene(first_frame)
337
+ scene.export(animated_glb_path)
338
+ else:
339
+ return None, None
340
+ except Exception as e:
341
+ print(f"Error exporting GLB: {str(e)}")
342
+ animated_glb_path = None
343
 
344
+ # Create GIF for preview
 
 
345
  try:
346
+ animated_gif_path = os.path.join(LOG_PATH, f'animated_{base_filename}.gif')
347
+ generate_gif_from_frames(frames, animated_gif_path, fps)
 
 
 
 
 
348
  except Exception as e:
349
+ print(f"Error creating GIF: {str(e)}")
 
350
  animated_gif_path = None
351
+
352
+ return animated_glb_path, animated_gif_path
353
 
354
+ def process_3d_model(input_3d, animation_type, animation_duration, fps):
355
+ """Process a 3D model and apply animation"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
  print(f"Processing: {input_3d} with animation type: {animation_type}")
357
 
358
  try:
359
+ # Create animation
360
+ animated_glb_path, animated_gif_path = create_animation_mesh(
361
  input_3d,
 
 
 
362
  animation_type=animation_type,
363
  duration=animation_duration,
364
  fps=fps
365
  )
366
 
367
+ if not animated_glb_path or not animated_gif_path:
368
+ return "Error creating animation", None
369
+
370
+ # Create a simple JSON metadata file
371
+ metadata = {
372
+ "animation_type": animation_type,
373
+ "duration": animation_duration,
374
+ "fps": fps,
375
+ "original_model": os.path.basename(input_3d),
376
+ "created_at": time.strftime("%Y-%m-%d %H:%M:%S")
377
+ }
378
+
379
+ json_path = os.path.join(LOG_PATH, f'metadata_{os.path.basename(input_3d).rsplit(".", 1)[0]}.json')
380
  with open(json_path, 'w') as f:
381
+ json.dump(metadata, f, indent=4)
382
 
383
+ return animated_glb_path, animated_gif_path, json_path
384
  except Exception as e:
385
+ error_msg = f"Error processing file: {str(e)}"
386
+ print(error_msg)
387
+ return error_msg, None, None
388
 
389
  _HEADER_ = '''
390
+ <h2><b>GLB ์• ๋‹ˆ๋ฉ”์ด์…˜ ์ƒ์„ฑ๊ธฐ - 3D ๋ชจ๋ธ ์›€์ง์ž„ ํšจ๊ณผ</b></h2>
391
 
392
+ ์ด ๋ฐ๋ชจ๋ฅผ ํ†ตํ•ด ์ •์ ์ธ 3D ๋ชจ๋ธ(GLB ํŒŒ์ผ)์— ๋‹ค์–‘ํ•œ ์• ๋‹ˆ๋ฉ”์ด์…˜ ํšจ๊ณผ๋ฅผ ์ ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
393
 
394
+ โ—๏ธโ—๏ธโ—๏ธ**์ค‘์š”์‚ฌํ•ญ:**
395
+ - ์ด ๋ฐ๋ชจ๋Š” ์—…๋กœ๋“œ๋œ GLB ํŒŒ์ผ์— ์• ๋‹ˆ๋ฉ”์ด์…˜์„ ์ ์šฉํ•ฉ๋‹ˆ๋‹ค.
396
+ - ๋‹ค์–‘ํ•œ ์• ๋‹ˆ๋ฉ”์ด์…˜ ์Šคํƒ€์ผ ์ค‘์—์„œ ์„ ํƒํ•˜์„ธ์š”: ํšŒ์ „, ๋ถ€์œ , ํญ๋ฐœ, ์กฐ๋ฆฝ, ํŽ„์Šค, ์Šค์œ™.
397
+ - ๊ฒฐ๊ณผ๋Š” ์• ๋‹ˆ๋ฉ”์ด์…˜๋œ GLB ํŒŒ์ผ๊ณผ ๋ฏธ๋ฆฌ๋ณด๊ธฐ์šฉ GIF ํŒŒ์ผ๋กœ ์ œ๊ณต๋ฉ๋‹ˆ๋‹ค.
 
 
398
  '''
399
 
400
+ _INFO_ = r"""
401
+ ### ์• ๋‹ˆ๋ฉ”์ด์…˜ ์œ ํ˜• ์„ค๋ช…
402
+ - **ํšŒ์ „(rotate)**: ๋ชจ๋ธ์ด Y์ถ•์„ ์ค‘์‹ฌ์œผ๋กœ ํšŒ์ „ํ•ฉ๋‹ˆ๋‹ค.
403
+ - **๋ถ€์œ (float)**: ๋ชจ๋ธ์ด ์œ„์•„๋ž˜๋กœ ๋ถ€๋“œ๋Ÿฝ๊ฒŒ ๋– ๋‹ค๋‹™๋‹ˆ๋‹ค.
404
+ - **ํญ๋ฐœ(explode)**: ๋ชจ๋ธ์˜ ๊ฐ ๋ถ€๋ถ„์ด ์ค‘์‹ฌ์—์„œ ๋ฐ”๊นฅ์ชฝ์œผ๋กœ ํผ์ ธ๋‚˜๊ฐ‘๋‹ˆ๋‹ค.
405
+ - **์กฐ๋ฆฝ(assemble)**: ํญ๋ฐœ ์• ๋‹ˆ๋ฉ”์ด์…˜์˜ ๋ฐ˜๋Œ€ - ๋ถ€ํ’ˆ๋“ค์ด ํ•จ๊ป˜ ๋ชจ์ž…๋‹ˆ๋‹ค.
406
+ - **ํŽ„์Šค(pulse)**: ๋ชจ๋ธ์ด ํฌ๊ธฐ๊ฐ€ ์ปค์กŒ๋‹ค ์ž‘์•„์กŒ๋‹ค๋ฅผ ๋ฐ˜๋ณตํ•ฉ๋‹ˆ๋‹ค.
407
+ - **์Šค์œ™(swing)**: ๋ชจ๋ธ์ด ์ขŒ์šฐ๋กœ ๋ถ€๋“œ๋Ÿฝ๊ฒŒ ํ”๋“ค๋ฆฝ๋‹ˆ๋‹ค.
408
+
409
+ ### ํŒ
410
+ - ์• ๋‹ˆ๋ฉ”์ด์…˜ ๊ธธ์ด์™€ FPS๋ฅผ ์กฐ์ ˆํ•˜์—ฌ ์›€์ง์ž„์˜ ์†๋„์™€ ๋ถ€๋“œ๋Ÿฌ์›€์„ ์กฐ์ ˆํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
411
+ - ๋ณต์žกํ•œ ๋ชจ๋ธ์€ ์ฒ˜๋ฆฌ ์‹œ๊ฐ„์ด ๋” ์˜ค๋ž˜ ๊ฑธ๋ฆด ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
412
+ - GIF ๋ฏธ๋ฆฌ๋ณด๊ธฐ๋Š” ๋น ๋ฅธ ์ฐธ์กฐ์šฉ์ด๋ฉฐ, ๊ณ ํ’ˆ์งˆ ๊ฒฐ๊ณผ๋ฅผ ์œ„ํ•ด์„œ๋Š” ์• ๋‹ˆ๋ฉ”์ด์…˜๋œ GLB ํŒŒ์ผ์„ ๋‹ค์šด๋กœ๋“œํ•˜์„ธ์š”.
 
 
 
 
413
  """
414
 
415
+ # Gradio ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ •
416
+ def create_gradio_interface():
417
+ with gr.Blocks(title="GLB ์• ๋‹ˆ๋ฉ”์ด์…˜ ์ƒ์„ฑ๊ธฐ") as demo:
418
+ # ์ œ๋ชฉ ์„น์…˜
419
+ gr.Markdown(_HEADER_)
420
+
421
+ with gr.Row():
422
+ with gr.Column():
423
+ # ์ž…๋ ฅ ์ปดํฌ๋„ŒํŠธ
424
+ input_3d = gr.Model3D(label="3D ๋ชจ๋ธ ํŒŒ์ผ ์—…๋กœ๋“œ (GLB ํฌ๋งท)")
425
+
426
+ with gr.Row():
427
+ animation_type = gr.Dropdown(
428
+ label="์• ๋‹ˆ๋ฉ”์ด์…˜ ์œ ํ˜•",
429
+ choices=["rotate", "float", "explode", "assemble", "pulse", "swing"],
430
+ value="rotate"
431
+ )
432
+
433
+ with gr.Row():
434
+ animation_duration = gr.Slider(
435
+ label="์• ๋‹ˆ๋ฉ”์ด์…˜ ๊ธธ์ด (์ดˆ)",
436
+ minimum=1.0,
437
+ maximum=10.0,
438
+ value=3.0,
439
+ step=0.5
440
+ )
441
+ fps = gr.Slider(
442
+ label="์ดˆ๋‹น ํ”„๋ ˆ์ž„ ์ˆ˜",
443
+ minimum=15,
444
+ maximum=60,
445
+ value=30,
446
+ step=1
447
+ )
448
+
449
+ submit_btn = gr.Button("๋ชจ๋ธ ์ฒ˜๋ฆฌ ๋ฐ ์• ๋‹ˆ๋ฉ”์ด์…˜ ์ƒ์„ฑ")
450
+
451
+ with gr.Column():
452
+ # ์ถœ๋ ฅ ์ปดํฌ๋„ŒํŠธ
453
+ output_3d = gr.Model3D(label="์• ๋‹ˆ๋ฉ”์ด์…˜ ์ ์šฉ๋œ 3D ๋ชจ๋ธ")
454
+ output_gif = gr.Image(label="์• ๋‹ˆ๋ฉ”์ด์…˜ ๋ฏธ๋ฆฌ๋ณด๊ธฐ (GIF)")
455
+ output_json = gr.File(label="๋ฉ”ํƒ€๋ฐ์ดํ„ฐ ํŒŒ์ผ ๋‹ค์šด๋กœ๋“œ")
456
+
457
+ # ์• ๋‹ˆ๋ฉ”์ด์…˜ ์œ ํ˜• ์„ค๋ช…
458
+ gr.Markdown(_INFO_)
459
+
460
+ # ๋ฒ„ํŠผ ๋™์ž‘ ์„ค์ •
461
+ submit_btn.click(
462
+ fn=process_3d_model,
463
+ inputs=[input_3d, animation_type, animation_duration, fps],
464
+ outputs=[output_3d, output_gif, output_json]
465
+ )
466
+
467
+ # ์˜ˆ์ œ ์ค€๋น„
468
+ example_files = [ [f] for f in glob.glob('./data/demo_glb/*.glb') ]
469
+
470
+ if example_files:
471
+ example = gr.Examples(
472
+ examples=example_files,
473
+ inputs=[input_3d],
474
+ examples_per_page=10,
475
+ )
476
 
477
+ return demo
 
 
 
 
 
 
 
 
 
478
 
479
+ # ๋ฉ”์ธ ์‹คํ–‰ ๋ถ€๋ถ„
480
  if __name__ == "__main__":
481
+ demo = create_gradio_interface()
482
+ demo.launch(share=True)