Spaces:
Sleeping
Sleeping
| from typing import ClassVar | |
| import numpy as np | |
| class JointConfig: | |
| # Standard joint names used in LeRobot training data | |
| LEROBOT_JOINT_NAMES: ClassVar = [ | |
| "shoulder_pan_joint", | |
| "shoulder_lift_joint", | |
| "elbow_joint", | |
| "wrist_1_joint", | |
| "wrist_2_joint", | |
| "wrist_3_joint", | |
| ] | |
| # Our custom joint names (more intuitive for users) | |
| CUSTOM_JOINT_NAMES: ClassVar = [ | |
| "base_rotation", | |
| "shoulder_tilt", | |
| "elbow_bend", | |
| "wrist_rotate", | |
| "wrist_tilt", | |
| "wrist_twist", | |
| ] | |
| # Mapping from our custom names to LeRobot standard names | |
| CUSTOM_TO_LEROBOT_NAMES: ClassVar = { | |
| "base_rotation": "shoulder_pan_joint", | |
| "shoulder_tilt": "shoulder_lift_joint", | |
| "elbow_bend": "elbow_joint", | |
| "wrist_rotate": "wrist_1_joint", | |
| "wrist_tilt": "wrist_2_joint", | |
| "wrist_twist": "wrist_3_joint", | |
| } | |
| # Reverse mapping for convenience | |
| LEROBOT_TO_CUSTOM_NAMES: ClassVar = { | |
| v: k for k, v in CUSTOM_TO_LEROBOT_NAMES.items() | |
| } | |
| # Joint limits in normalized values (-100 to +100 for most joints, 0 to 100 for gripper) | |
| JOINT_LIMITS: ClassVar = { | |
| "base_rotation": (-100.0, 100.0), | |
| "shoulder_tilt": (-100.0, 100.0), | |
| "elbow_bend": (-100.0, 100.0), | |
| "wrist_rotate": (-100.0, 100.0), | |
| "wrist_tilt": (-100.0, 100.0), | |
| "wrist_twist": (-100.0, 100.0), | |
| } | |
| def get_joint_index(cls, joint_name: str) -> int | None: | |
| """ | |
| Get the index of a joint in the standard joint order. | |
| Args: | |
| joint_name: Name of the joint (can be custom or LeRobot name) | |
| Returns: | |
| Index of the joint, or None if not found | |
| """ | |
| # Try custom names first | |
| if joint_name in cls.CUSTOM_JOINT_NAMES: | |
| return cls.CUSTOM_JOINT_NAMES.index(joint_name) | |
| # Try LeRobot names | |
| if joint_name in cls.LEROBOT_JOINT_NAMES: | |
| return cls.LEROBOT_JOINT_NAMES.index(joint_name) | |
| # Try case-insensitive matching | |
| joint_name_lower = joint_name.lower() | |
| for i, name in enumerate(cls.CUSTOM_JOINT_NAMES): | |
| if name.lower() == joint_name_lower: | |
| return i | |
| for i, name in enumerate(cls.LEROBOT_JOINT_NAMES): | |
| if name.lower() == joint_name_lower: | |
| return i | |
| return None | |
| def parse_joint_data(cls, joints_data, policy_type: str = "act") -> list[float]: | |
| """ | |
| Parse joint data from Arena message into standard order. | |
| Expected format: dict with joint names as keys and normalized values. | |
| All values are already normalized from the training data pipeline. | |
| Args: | |
| joints_data: Joint data from Arena message | |
| policy_type: Type of policy (for logging purposes) | |
| Returns: | |
| List of 6 normalized joint values in LeRobot standard order | |
| """ | |
| try: | |
| # Handle different possible data formats | |
| if hasattr(joints_data, "data"): | |
| joint_dict = joints_data.data | |
| else: | |
| joint_dict = joints_data | |
| if not isinstance(joint_dict, dict): | |
| return [0.0] * 6 | |
| # Extract joint values in LeRobot standard order | |
| joint_values = [] | |
| for lerobot_name in cls.LEROBOT_JOINT_NAMES: | |
| value = None | |
| # Try LeRobot name directly | |
| if lerobot_name in joint_dict: | |
| value = float(joint_dict[lerobot_name]) | |
| else: | |
| # Try custom name | |
| custom_name = cls.LEROBOT_TO_CUSTOM_NAMES.get(lerobot_name) | |
| if custom_name and custom_name in joint_dict: | |
| value = float(joint_dict[custom_name]) | |
| else: | |
| # Try various common formats | |
| for key in [ | |
| lerobot_name, | |
| f"joint_{lerobot_name}", | |
| lerobot_name.upper(), | |
| custom_name, | |
| f"joint_{custom_name}" if custom_name else None, | |
| ]: | |
| if key and key in joint_dict: | |
| value = float(joint_dict[key]) | |
| break | |
| joint_values.append(value if value is not None else 0.0) | |
| return joint_values | |
| except Exception: | |
| # Return zeros if parsing fails | |
| return [0.0] * 6 | |
| def create_joint_commands(cls, action_values: np.ndarray) -> list[dict]: | |
| """ | |
| Create joint command dictionaries from action values. | |
| Args: | |
| action_values: Array of 6 joint values in LeRobot standard order | |
| Returns: | |
| List of joint command dictionaries with custom names | |
| """ | |
| commands = [] | |
| for i, custom_name in enumerate(cls.CUSTOM_JOINT_NAMES): | |
| if i < len(action_values): | |
| commands.append({"name": custom_name, "value": float(action_values[i])}) | |
| return commands | |
| def validate_joint_values(cls, joint_values: np.ndarray) -> np.ndarray: | |
| """ | |
| Validate and clamp joint values to their limits. | |
| Args: | |
| joint_values: Array of joint values | |
| Returns: | |
| Clamped joint values | |
| """ | |
| if len(joint_values) != 6: | |
| # Pad or truncate to 6 values | |
| padded = np.zeros(6, dtype=np.float32) | |
| n = min(len(joint_values), 6) | |
| padded[:n] = joint_values[:n] | |
| joint_values = padded | |
| # Clamp to limits | |
| for i, custom_name in enumerate(cls.CUSTOM_JOINT_NAMES): | |
| min_val, max_val = cls.JOINT_LIMITS[custom_name] | |
| joint_values[i] = np.clip(joint_values[i], min_val, max_val) | |
| return joint_values | |