#!/usr/bin/env python3
"""
MagicArticulate API - Enhanced Version
Handles user-uploaded 3D model files and multi-user result management.
"""
import os
import sys
import uuid
import json
import time
import shutil
import logging
import tempfile
import traceback
from pathlib import Path
from datetime import datetime
from typing import Dict, Any, List, Optional, Tuple
import torch
import trimesh
import numpy as np
from tqdm import tqdm
from accelerate import Accelerator
from accelerate.utils import set_seed, DistributedDataParallelKwargs
# Add the parent directory to sys.path so project modules import correctly
parent_dir = str(Path(__file__).parent.parent)
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)
print(f"🔍 ARTICULATE_API DEBUG: Current working directory: {os.getcwd()}")
print(f"🔍 ARTICULATE_API DEBUG: Script file path: {__file__}")
print(f"🔍 ARTICULATE_API DEBUG: Parent directory: {parent_dir}")
print(f"🔍 ARTICULATE_API DEBUG: sys.path includes:")
for i, path in enumerate(sys.path[:10]):  # show only the first 10 entries to keep the output short
    print(f"  {i}: {path}")
# Check the expected directory structure
utils_path = os.path.join(parent_dir, 'utils')
skeleton_path = os.path.join(parent_dir, 'skeleton_models')
print(f"🔍 ARTICULATE_API DEBUG: utils path exists: {os.path.exists(utils_path)}")
print(f"🔍 ARTICULATE_API DEBUG: skeleton_models path exists: {os.path.exists(skeleton_path)}")
if os.path.exists(utils_path):
print(f"🔍 ARTICULATE_API DEBUG: utils contents: {os.listdir(utils_path)}")
from skeleton_models.skeletongen import SkeletonGPT
from utils.mesh_to_pc import MeshProcessor
from utils.save_utils import (
    save_mesh, pred_joints_and_bones, save_skeleton_to_txt,
    save_args, remove_duplicate_joints, save_skeleton_obj,
    render_mesh_with_skeleton
)
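# Expected repository layout (inferred from the imports above): the parent
# directory added to sys.path must contain skeleton_models/skeletongen.py,
# utils/mesh_to_pc.py, and utils/save_utils.py.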
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class ModelValidator:
"""3D模型验证和修复类"""
SUPPORTED_FORMATS = {'.obj', '.glb', '.gltf', '.ply', '.stl', '.fbx', '.dae'}
MAX_VERTICES = 100000 # 最大顶点数
MIN_VERTICES = 100 # 最小顶点数
MAX_FILE_SIZE_MB = 100 # 最大文件大小
    @staticmethod
    def validate_file(file_path: str) -> Tuple[bool, str, Dict[str, Any]]:
        """
        Validate a 3D model file.

        Returns:
            (is_valid, error_message, model_info)
        """
        try:
            # Check that the file exists
            if not os.path.exists(file_path):
                return False, "File does not exist", {}

            # Check the file size
            file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
            if file_size_mb > ModelValidator.MAX_FILE_SIZE_MB:
                return False, f"File too large: {file_size_mb:.1f}MB > {ModelValidator.MAX_FILE_SIZE_MB}MB", {}

            # Check the file format
            file_ext = Path(file_path).suffix.lower()
            if file_ext not in ModelValidator.SUPPORTED_FORMATS:
                return False, f"Unsupported file format: {file_ext}", {}

            # Try to load the model
            mesh = trimesh.load(file_path, force='mesh')

            # Check that it is a valid mesh
            if not hasattr(mesh, 'vertices') or not hasattr(mesh, 'faces'):
                return False, "File contains no valid mesh data", {}

            # Check the vertex count
            vertex_count = len(mesh.vertices)
            if vertex_count < ModelValidator.MIN_VERTICES:
                return False, f"Too few vertices: {vertex_count} < {ModelValidator.MIN_VERTICES}", {}
            if vertex_count > ModelValidator.MAX_VERTICES:
                return False, f"Too many vertices: {vertex_count} > {ModelValidator.MAX_VERTICES}", {}

            # Collect model information
            model_info = {
                'file_name': os.path.basename(file_path),
                'file_size_mb': file_size_mb,
                'format': file_ext,
                'vertex_count': vertex_count,
                'face_count': len(mesh.faces) if hasattr(mesh, 'faces') else 0,
                'bounds': mesh.bounds.tolist() if hasattr(mesh, 'bounds') else None,
                'is_watertight': mesh.is_watertight if hasattr(mesh, 'is_watertight') else False,
                'volume': float(mesh.volume) if hasattr(mesh, 'volume') else 0.0,
                'area': float(mesh.area) if hasattr(mesh, 'area') else 0.0,
            }
            return True, "Validation passed", model_info
        except Exception as e:
            return False, f"Model validation failed: {str(e)}", {}
    @staticmethod
    def auto_repair_mesh(mesh: trimesh.Trimesh) -> Tuple[trimesh.Trimesh, List[str]]:
        """
        Automatically repair common mesh problems.

        Returns:
            (repaired_mesh, repair_log)
        """
        repair_log = []
        try:
            # Merge duplicate vertices (applies to any mesh, not only watertight volumes)
            original_vertices = len(mesh.vertices)
            mesh.merge_vertices()
            if len(mesh.vertices) < original_vertices:
                repair_log.append(f"Removed {original_vertices - len(mesh.vertices)} duplicate vertices")

            # Fix normals
            if not hasattr(mesh, 'vertex_normals') or mesh.vertex_normals is None:
                mesh.fix_normals()
                repair_log.append("Fixed vertex normals")

            # Remove degenerate faces
            original_faces = len(mesh.faces)
            mesh.remove_degenerate_faces()
            if len(mesh.faces) < original_faces:
                repair_log.append(f"Removed {original_faces - len(mesh.faces)} degenerate faces")

            # Fill holes (if needed)
            if not mesh.is_watertight and hasattr(mesh, 'fill_holes'):
                try:
                    mesh.fill_holes()
                    repair_log.append("Filled mesh holes")
                except Exception:
                    repair_log.append("Hole filling failed; continuing anyway")

            return mesh, repair_log
        except Exception as e:
            logger.warning(f"Error during mesh repair: {str(e)}")
            return mesh, repair_log + [f"Repair error: {str(e)}"]
class ModelPreprocessor:
"""模型预处理类"""
    @staticmethod
    def convert_format(input_path: str, output_format: str = '.obj') -> str:
        """
        Convert a model to another format.

        Args:
            input_path: input file path
            output_format: output format (default: .obj)

        Returns:
            Output file path
        """
        try:
            mesh = trimesh.load(input_path, force='mesh')
            # Build the output path
            base_name = os.path.splitext(os.path.basename(input_path))[0]
            output_path = os.path.join(
                os.path.dirname(input_path),
                f"{base_name}_converted{output_format}"
            )
            # Export in the requested format
            mesh.export(output_path)
            logger.info(f"Format conversion finished: {input_path} -> {output_path}")
            return output_path
        except Exception as e:
            logger.error(f"Format conversion failed: {str(e)}")
            raise
    @staticmethod
    def simplify_mesh(mesh: trimesh.Trimesh, target_faces: int = 5000) -> trimesh.Trimesh:
        """
        Simplify a mesh.

        Args:
            mesh: input mesh
            target_faces: target face count

        Returns:
            Simplified mesh
        """
        try:
            if len(mesh.faces) <= target_faces:
                return mesh
            # Simplify via quadric decimation
            # (note: newer trimesh releases rename this method to simplify_quadric_decimation)
            simplified = mesh.simplify_quadratic_decimation(target_faces)
            logger.info(f"Mesh simplified: {len(mesh.faces)} -> {len(simplified.faces)} faces")
            return simplified
        except Exception as e:
            logger.warning(f"Mesh simplification failed: {str(e)}; using the original mesh")
            return mesh
    @staticmethod
    def normalize_mesh(mesh: trimesh.Trimesh, scale_factor: float = 0.95) -> Tuple[trimesh.Trimesh, Dict[str, Any]]:
        """
        Normalize a mesh into a canonical coordinate space.

        Args:
            mesh: input mesh
            scale_factor: scaling factor

        Returns:
            (normalized_mesh, transform_info)
        """
        try:
            # Compute the bounding box
            bounds = mesh.bounds
            center = (bounds[0] + bounds[1]) / 2
            size = bounds[1] - bounds[0]
            max_size = size.max()

            # Compute the transform parameters
            scale = (2.0 * scale_factor) / max_size
            translation = -center

            # Apply the transform
            vertices = mesh.vertices.copy()
            vertices = (vertices + translation) * scale

            # Build a new mesh
            normalized_mesh = trimesh.Trimesh(vertices=vertices, faces=mesh.faces)

            # Record the transform for later inversion
            transform_info = {
                'original_center': center.tolist(),
                'original_size': size.tolist(),
                'scale': float(scale),
                'translation': translation.tolist()
            }
            logger.info(f"Mesh normalization finished: scale={scale:.4f}")
            return normalized_mesh, transform_info
        except Exception as e:
            logger.error(f"Mesh normalization failed: {str(e)}")
            raise
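    # Note: the forward transform above is v' = (v + translation) * scale, which maps
    # the mesh into roughly [-scale_factor, scale_factor]^3. To return a result to the
    # original frame, invert it:
    #   v = v' / transform_info['scale'] - np.array(transform_info['translation'])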
class UserSessionManager:
    """Manages per-user sessions."""

    def __init__(self, base_dir: str = "user_sessions"):
        self.base_dir = Path(base_dir)
        self.base_dir.mkdir(exist_ok=True)
        # Metadata file
        self.metadata_file = self.base_dir / "sessions_metadata.json"
        self.load_metadata()
    def load_metadata(self):
        """Load session metadata."""
        if self.metadata_file.exists():
            with open(self.metadata_file, 'r', encoding='utf-8') as f:
                self.sessions = json.load(f)
        else:
            self.sessions = {}
    def save_metadata(self):
        """Persist session metadata."""
        with open(self.metadata_file, 'w', encoding='utf-8') as f:
            json.dump(self.sessions, f, indent=2, ensure_ascii=False)
    def create_session(self, user_id: Optional[str] = None) -> str:
        """
        Create a new user session.

        Args:
            user_id: user ID (optional)

        Returns:
            session_id
        """
        session_id = str(uuid.uuid4())
        session_dir = self.base_dir / session_id
        session_dir.mkdir(exist_ok=True)

        # Create subdirectories
        (session_dir / "uploads").mkdir(exist_ok=True)
        (session_dir / "outputs").mkdir(exist_ok=True)
        (session_dir / "temp").mkdir(exist_ok=True)

        # Record session info
        self.sessions[session_id] = {
            'user_id': user_id,
            'created_at': datetime.now().isoformat(),
            'status': 'active',
            'processed_models': [],
            'last_activity': datetime.now().isoformat()
        }
        self.save_metadata()
        logger.info(f"Created new session: {session_id}")
        return session_id
    def get_session_dir(self, session_id: str) -> Path:
        """Return the session directory."""
        session_dir = self.base_dir / session_id
        if not session_dir.exists():
            raise ValueError(f"Session does not exist: {session_id}")
        return session_dir
    def update_activity(self, session_id: str):
        """Refresh the session's last-activity timestamp."""
        if session_id in self.sessions:
            self.sessions[session_id]['last_activity'] = datetime.now().isoformat()
            self.save_metadata()
    def add_processed_model(self, session_id: str, model_info: Dict[str, Any]):
        """Record a processed model."""
        if session_id in self.sessions:
            self.sessions[session_id]['processed_models'].append(model_info)
            self.update_activity(session_id)
    def cleanup_old_sessions(self, max_age_days: int = 7):
        """Remove sessions older than max_age_days."""
        cutoff_time = datetime.now().timestamp() - (max_age_days * 24 * 3600)
        sessions_to_remove = []
        for session_id, session_info in self.sessions.items():
            last_activity = datetime.fromisoformat(session_info['last_activity'])
            if last_activity.timestamp() < cutoff_time:
                sessions_to_remove.append(session_id)
        for session_id in sessions_to_remove:
            try:
                session_dir = self.base_dir / session_id
                if session_dir.exists():
                    shutil.rmtree(session_dir)
                del self.sessions[session_id]
                logger.info(f"Cleaned up old session: {session_id}")
            except Exception as e:
                logger.error(f"Failed to clean up session {session_id}: {str(e)}")
        if sessions_to_remove:
            self.save_metadata()
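# Usage sketch: sessions are plain directories plus a JSON index, so a typical
# flow (names taken from the class above; the user ID is hypothetical) is:
#   mgr = UserSessionManager("user_sessions")
#   sid = mgr.create_session(user_id="alice")
#   uploads_dir = mgr.get_session_dir(sid) / "uploads"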
class MagicArticulateAPI:
    """Main MagicArticulate API class."""

    def __init__(self,
                 model_weights_path: Optional[str] = None,
                 device: str = "auto",
                 session_base_dir: str = "user_sessions"):
        self.device = self._setup_device(device)
        self.model = None
        self.accelerator = None
        self.model_weights_path = model_weights_path

        # Initialize the session manager
        self.session_manager = UserSessionManager(session_base_dir)

        # Default processing parameters - matching the original demo.py settings
        self.default_args = {
            'input_pc_num': 8192,
            'num_beams': 1,
            'n_discrete_size': 128,
            'n_max_bones': 100,
            'pad_id': -1,
            'precision': 'fp16',
            'batchsize_per_gpu': 1,
            'apply_marching_cubes': False,
            'octree_depth': 7,
            'hier_order': False,  # matches the demo.py default
            'save_render': False,
            'llm': 'facebook/opt-350m'  # matches the demo.py default
        }
        self.initialized = False
        logger.info("MagicArticulate API initialized")
    def _setup_device(self, device: str) -> torch.device:
        """Select the compute device."""
        if device == "auto":
            if torch.cuda.is_available():
                device = "cuda"
                logger.info(f"Using GPU: {torch.cuda.get_device_name()}")
            else:
                device = "cpu"
                logger.info("Using CPU")
        return torch.device(device)
    def initialize_model(self) -> bool:
        """Initialize the model."""
        try:
            if self.initialized:
                return True
            logger.info("Initializing the MagicArticulate model...")

            # Set up the accelerator
            kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
            self.accelerator = Accelerator(
                kwargs_handlers=[kwargs],
                mixed_precision=self.default_args['precision'],
            )

            # Build the model
            args = self._create_args_object()
            self.model = SkeletonGPT(args)
            if self.device.type == "cuda":
                self.model = self.model.cuda()

            # Load pretrained weights
            if self.model_weights_path and os.path.exists(self.model_weights_path):
                logger.info(f"Loading model weights: {self.model_weights_path}")
                pkg = torch.load(self.model_weights_path, map_location=self.device)
                self.model.load_state_dict(pkg["model"])
            else:
                error_msg = "Pretrained weights must be provided! Running with random initialization; results will be inaccurate."
                logger.error(error_msg)
                # Do not raise, but warn loudly
                logger.error("⚠️ WARNING: no pretrained weights, the generated skeleton will be inaccurate!")

            self.model.eval()
            set_seed(0)

            # Prepare the model
            if self.accelerator:
                self.model = self.accelerator.prepare(self.model)

            self.initialized = True
            logger.info("✅ Model initialized successfully")
            return True
        except Exception as e:
            logger.error(f"❌ Model initialization failed: {str(e)}")
            logger.error(traceback.format_exc())
            return False
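    # Checkpoint layout implied by the loading code above: the weights file is a
    # dict with the state dict stored under the "model" key, i.e. written with
    # something like torch.save({'model': model.state_dict()}, path).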
    def process_uploaded_model(self,
                               file_path: str,
                               session_id: Optional[str] = None,
                               user_prompt: str = "",
                               processing_options: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """
        Process a user-uploaded 3D model.

        Args:
            file_path: model file path
            session_id: session ID (optional)
            user_prompt: user prompt
            processing_options: processing options

        Returns:
            Result dictionary
        """
        start_time = time.time()
        try:
            # Create a session if none was given
            if not session_id:
                session_id = self.session_manager.create_session()
            logger.info(f"Start processing model: {file_path}, session: {session_id}")

            # Step 1: validate the model file
            is_valid, error_msg, model_info = ModelValidator.validate_file(file_path)
            if not is_valid:
                return self._create_error_result(error_msg, session_id, start_time)
            logger.info(f"Model validation passed: {model_info}")

            # Step 2: copy the file into the session directory
            session_dir = self.session_manager.get_session_dir(session_id)
            uploaded_file = session_dir / "uploads" / os.path.basename(file_path)
            shutil.copy2(file_path, uploaded_file)

            # Step 3: preprocess the model
            processed_mesh, preprocessing_log = self._preprocess_model(
                str(uploaded_file),
                processing_options or {}
            )

            # Step 4: generate the skeleton
            if not self.initialized:
                if not self.initialize_model():
                    return self._create_error_result("Model initialization failed", session_id, start_time)
            skeleton_result = self._generate_skeleton(
                processed_mesh,
                model_info['file_name'],
                user_prompt
            )

            # Step 5: save the results
            output_files = self._save_results(
                skeleton_result,
                processed_mesh,
                model_info,
                session_dir,
                user_prompt
            )

            # Step 6: assemble the result record
            processing_time = time.time() - start_time
            result = {
                'success': True,
                'session_id': session_id,
                'processing_time': processing_time,
                'model_info': model_info,
                'preprocessing_log': preprocessing_log,
                'skeleton_data': skeleton_result,
                'output_files': output_files,
                'user_prompt': user_prompt,
                'timestamp': datetime.now().isoformat()
            }

            # Update the session record
            self.session_manager.add_processed_model(session_id, {
                'file_name': model_info['file_name'],
                'processing_time': processing_time,
                'timestamp': datetime.now().isoformat(),
                'success': True
            })
            logger.info(f"✅ Model processed in {processing_time:.2f}s")
            return result
        except Exception as e:
            processing_time = time.time() - start_time
            error_msg = f"Error during processing: {str(e)}"
            logger.error(f"❌ {error_msg}")
            logger.error(traceback.format_exc())
            return self._create_error_result(error_msg, session_id, start_time)
    def _preprocess_model(self, file_path: str, options: Dict[str, Any]) -> Tuple[trimesh.Trimesh, List[str]]:
        """Preprocess a model."""
        preprocessing_log = []
        try:
            # Load the model
            mesh = trimesh.load(file_path, force='mesh')
            preprocessing_log.append(f"Loaded model: {len(mesh.vertices)} vertices, {len(mesh.faces)} faces")

            # Automatic repair
            if options.get('auto_repair', True):
                mesh, repair_log = ModelValidator.auto_repair_mesh(mesh)
                preprocessing_log.extend(repair_log)

            # Simplify the mesh (if needed)
            target_faces = options.get('target_faces', 10000)
            if len(mesh.faces) > target_faces:
                mesh = ModelPreprocessor.simplify_mesh(mesh, target_faces)
                preprocessing_log.append(f"Simplified mesh to {len(mesh.faces)} faces")

            # Normalize the mesh
            mesh, transform_info = ModelPreprocessor.normalize_mesh(mesh)
            preprocessing_log.append(f"Normalized mesh: scale={transform_info['scale']:.4f}")

            return mesh, preprocessing_log
        except Exception as e:
            error_msg = f"Preprocessing failed: {str(e)}"
            logger.error(error_msg)
            raise RuntimeError(error_msg)
    def _generate_skeleton(self, mesh: trimesh.Trimesh, file_name: str, user_prompt: str) -> Dict[str, Any]:
        """Generate the skeleton structure."""
        try:
            # Convert the mesh to a point cloud
            points_per_mesh = self.default_args['input_pc_num']
            apply_marching_cubes = self.default_args['apply_marching_cubes']
            octree_depth = self.default_args['octree_depth']
            point_clouds = MeshProcessor.convert_meshes_to_point_clouds(
                [mesh],
                points_per_mesh,
                apply_marching_cubes,
                octree_depth
            )
            pc_data = point_clouds[0]

            # Normalize the point cloud as in the original demo
            pc_coor = pc_data[:, :3]
            normals = pc_data[:, 3:]
            bounds = np.array([pc_coor.min(axis=0), pc_coor.max(axis=0)])

            # Keep the transform so the output can be denormalized later
            trans = (bounds[0] + bounds[1])[None, :] / 2
            scale = ((bounds[1] - bounds[0]).max() + 1e-5)

            # Normalize coordinates - identical to the original demo
            pc_coor = pc_coor - (bounds[0] + bounds[1])[None, :] / 2
            pc_coor = pc_coor / np.abs(pc_coor).max() * 0.9995

            # Combine coordinates and normals
            pc_coor = pc_coor.astype(np.float32)
            normals = normals.astype(np.float32)
            pc_normal_data = np.concatenate([pc_coor, normals], axis=-1, dtype=np.float16)

            # Prepare the input tensor
            pc_normal = torch.from_numpy(pc_normal_data).unsqueeze(0)
            if self.device.type == "cuda":
                pc_normal = pc_normal.cuda()

            # Transform info for the mesh itself
            mesh_bounds = np.array([mesh.vertices.min(axis=0), mesh.vertices.max(axis=0)])
            mesh_trans = (mesh_bounds[0] + mesh_bounds[1])[None, :] / 2
            mesh_scale = ((mesh_bounds[1] - mesh_bounds[0]).max() + 1e-5)

            batch_data = {
                'pc_normal': pc_normal,
                'file_name': [os.path.splitext(file_name)[0]],
                'trans': torch.from_numpy(mesh_trans).unsqueeze(0),
                'scale': torch.tensor(mesh_scale).unsqueeze(0),
                'vertices': torch.from_numpy(mesh.vertices).unsqueeze(0),
                'faces': torch.from_numpy(mesh.faces).unsqueeze(0)
            }

            # Generate the skeleton
            with torch.no_grad():
                if self.accelerator:
                    with self.accelerator.autocast():
                        pred_bone_coords = self.model.generate(batch_data)
                else:
                    pred_bone_coords = self.model.generate(batch_data)

            # Post-process the output - follows the original demo exactly
            trans = batch_data['trans'][0].cpu().numpy()
            scale = batch_data['scale'][0].cpu().numpy()
            vertices = batch_data['vertices'][0].cpu().numpy()
            faces = batch_data['faces'][0].cpu().numpy()
            skeleton = pred_bone_coords[0].cpu().numpy().squeeze()
            pred_joints, pred_bones = pred_joints_and_bones(skeleton)

            # Deduplicate joints
            if self.default_args['hier_order']:
                pred_joints, pred_bones, pred_root_index = remove_duplicate_joints(
                    pred_joints, pred_bones, root_index=pred_bones[0][0]
                )
            else:
                pred_joints, pred_bones = remove_duplicate_joints(pred_joints, pred_bones)
                pred_root_index = 0

            # Important: denormalize the joints back into the original model's coordinate frame
            pred_joints_denorm = pred_joints * scale + trans

            return {
                'joints': pred_joints_denorm.tolist(),      # denormalized joints
                'joints_normalized': pred_joints.tolist(),  # normalized joints, kept for visualization
                'bones': pred_bones,
                'root_index': pred_root_index,
                'joint_count': len(pred_joints),
                'bone_count': len(pred_bones),
                'raw_skeleton': skeleton.tolist(),
                'user_prompt': user_prompt,
                'transform_info': {
                    'trans': trans.tolist(),
                    'scale': float(scale)
                }
            }
        except Exception as e:
            error_msg = f"Skeleton generation failed: {str(e)}"
            logger.error(error_msg)
            raise RuntimeError(error_msg)
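    # Denormalization used above, summarized: with c = the mesh's bounding-box
    # center ('trans') and s = its maximum extent ('scale'), the model predicts
    # joints in the normalized frame, and joints_world = joints_norm * s + c
    # recovers original-model coordinates (both values are returned under
    # 'transform_info' so callers can re-apply or invert the mapping).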
    def _save_results(self,
                      skeleton_result: Dict[str, Any],
                      mesh: trimesh.Trimesh,
                      model_info: Dict[str, Any],
                      session_dir: Path,
                      user_prompt: str) -> Dict[str, str]:
        """Save processing results."""
        try:
            output_dir = session_dir / "outputs"
            base_name = os.path.splitext(model_info['file_name'])[0]
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            output_files = {}

            # JSON output removed to avoid serialization problems

            # Save the skeleton OBJ - uses the denormalized joints
            obj_file = output_dir / f"{base_name}_{timestamp}_skeleton.obj"
            save_skeleton_obj(
                np.array(skeleton_result['joints']),
                skeleton_result['bones'],
                str(obj_file),
                skeleton_result.get('root_index', 0),
                use_cone=self.default_args['hier_order']
            )
            output_files['skeleton_obj'] = str(obj_file)

            # Save the skeleton TXT
            txt_file = output_dir / f"{base_name}_{timestamp}_rig.txt"
            save_skeleton_to_txt(
                np.array(skeleton_result['joints']),
                skeleton_result['bones'],
                skeleton_result.get('root_index', 0),
                self.default_args['hier_order'],
                mesh.vertices,
                str(txt_file)
            )
            output_files['skeleton_txt'] = str(txt_file)

            # Save the processed mesh
            mesh_file = output_dir / f"{base_name}_{timestamp}_processed.obj"
            mesh.export(str(mesh_file))
            output_files['processed_mesh'] = str(mesh_file)

            # Save a plain-text processing report
            report_file = output_dir / f"{base_name}_{timestamp}_report.txt"
            report_content = f"""MagicArticulate Processing Report
=====================================
File: {model_info['file_name']}
Processing Time: {datetime.now().isoformat()}
User Prompt: {user_prompt}

Model Information:
- Vertices: {model_info.get('vertex_count', 'N/A')}
- Faces: {model_info.get('face_count', 'N/A')}
- File Size: {model_info.get('file_size_mb', 'N/A')} MB
- Format: {model_info.get('format', 'N/A')}

Skeleton Results:
- Joints: {skeleton_result.get('joint_count', 'N/A')}
- Bones: {skeleton_result.get('bone_count', 'N/A')}
- Root Index: {skeleton_result.get('root_index', 'N/A')}

Generated Files:
- Skeleton OBJ: {base_name}_{timestamp}_skeleton.obj
- Skeleton TXT: {base_name}_{timestamp}_rig.txt
- Processed Mesh: {base_name}_{timestamp}_processed.obj
"""
            with open(report_file, 'w', encoding='utf-8') as f:
                f.write(report_content)
            output_files['report'] = str(report_file)

            logger.info(f"Results saved: {len(output_files)} files")
            return output_files
        except Exception as e:
            error_msg = f"Failed to save results: {str(e)}"
            logger.error(error_msg)
            raise RuntimeError(error_msg)
    def _create_error_result(self, error_message: str, session_id: str, start_time: float) -> Dict[str, Any]:
        """Build an error result."""
        processing_time = time.time() - start_time
        return {
            'success': False,
            'session_id': session_id,
            'error': error_message,
            'processing_time': processing_time,
            'timestamp': datetime.now().isoformat()
        }
    def _make_json_serializable(self, obj):
        """Recursively convert an object into a JSON-serializable form."""
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, dict):
            return {key: self._make_json_serializable(value) for key, value in obj.items()}
        elif isinstance(obj, list):
            return [self._make_json_serializable(item) for item in obj]
        else:
            return obj
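    # Sketch: the result dict from process_uploaded_model can be written to disk
    # safely (numpy scalars and arrays converted first) with, for example:
    #   json.dumps(self._make_json_serializable(result), ensure_ascii=False)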
    def _create_args_object(self):
        """Build an argument object from the default parameters."""
        class Args:
            def __init__(self, **kwargs):
                for key, value in kwargs.items():
                    setattr(self, key, value)
        return Args(**self.default_args)
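    # Note: the Args shim above behaves like argparse.Namespace; an equivalent
    # one-liner would be `return argparse.Namespace(**self.default_args)`
    # (argparse would then need a module-level import).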
    def get_session_info(self, session_id: str) -> Dict[str, Any]:
        """Return session info."""
        if session_id not in self.session_manager.sessions:
            raise ValueError(f"Session does not exist: {session_id}")
        return self.session_manager.sessions[session_id].copy()
    def list_user_sessions(self, user_id: Optional[str] = None) -> List[Dict[str, Any]]:
        """List a user's sessions."""
        sessions = []
        for session_id, session_info in self.session_manager.sessions.items():
            if user_id is None or session_info.get('user_id') == user_id:
                sessions.append({
                    'session_id': session_id,
                    **session_info
                })
        return sorted(sessions, key=lambda x: x['created_at'], reverse=True)
    def cleanup_sessions(self, max_age_days: int = 7):
        """Clean up old sessions."""
        self.session_manager.cleanup_old_sessions(max_age_days)
# Simplified convenience interface
def process_model_file(file_path: str,
                       user_prompt: str = "",
                       model_weights_path: Optional[str] = None,
                       output_dir: Optional[str] = None) -> Dict[str, Any]:
    """
    Simplified model-processing interface.

    Args:
        file_path: model file path
        user_prompt: user prompt
        model_weights_path: model weights path
        output_dir: output directory

    Returns:
        Processing result
    """
    api = MagicArticulateAPI(
        model_weights_path=model_weights_path,
        session_base_dir=output_dir or "temp_sessions"
    )
    result = api.process_uploaded_model(
        file_path=file_path,
        user_prompt=user_prompt
    )
    return result
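# Usage sketch (paths are hypothetical):
#   result = process_model_file("assets/horse.obj",
#                               model_weights_path="checkpoints/magicarticulate.pth")
#   if result['success']:
#       print(result['output_files']['skeleton_obj'])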
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="MagicArticulate API test")
    parser.add_argument("--input", required=True, help="input model file path")
    parser.add_argument("--prompt", default="", help="user prompt")
    parser.add_argument("--weights", help="model weights path")
    parser.add_argument("--output", default="api_outputs", help="output directory")
    args = parser.parse_args()

    # Exercise the API
    result = process_model_file(
        file_path=args.input,
        user_prompt=args.prompt,
        model_weights_path=args.weights,
        output_dir=args.output
    )
    if result['success']:
        print("✅ Processing succeeded!")
        print(f"Session ID: {result['session_id']}")
        print(f"Processing time: {result['processing_time']:.2f}s")
        print(f"Joint count: {result['skeleton_data']['joint_count']}")
        print(f"Bone count: {result['skeleton_data']['bone_count']}")
        print(f"Output files: {len(result['output_files'])}")
    else:
        print("❌ Processing failed!")
        print(f"Error: {result['error']}")