{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "MpkYHwCqk7W-" }, "source": [ "![MuJoCo banner](https://raw.githubusercontent.com/google-deepmind/mujoco/main/banner.png)\n", "\n", "\n", "\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "xBSdkbmGN2K-" }, "source": [ "### Copyright notice" ] }, { "cell_type": "markdown", "metadata": { "id": "_UbO9uhtBSX5" }, "source": [ ">

Copyright 2025 DeepMind Technologies Limited.

\n", ">

Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0.

\n", ">

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

" ] }, { "cell_type": "markdown", "metadata": { "id": "dNIJkb_FM2Ux" }, "source": [ "# Manipulation in The Playground!\n", "\n", "In this notebook, we'll walk through a couple of manipulation environments available in MuJoCo Playground.\n", "\n", "**A Colab runtime with GPU acceleration is required.** If you're using a CPU-only runtime, you can switch using the menu \"Runtime > Change runtime type\".\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Xqo7pyX-n72M", "cellView": "form" }, "outputs": [], "source": [ "#@title Install pre-requisites\n", "!pip install mujoco\n", "!pip install mujoco_mjx\n", "!pip install brax" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "IbZxYDxzoz5R" }, "outputs": [], "source": [ "# @title Check if MuJoCo installation was successful\n", "\n", "import os\n", "import subprocess\n", "\n", "if subprocess.run('nvidia-smi').returncode:\n", "  raise RuntimeError(\n", "      'Cannot communicate with GPU. '\n", "      'Make sure you are using a GPU Colab runtime. 
'\n", "      'Go to the Runtime menu and select Change runtime type.'\n", "  )\n", "\n", "# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.\n", "# This is usually installed as part of an Nvidia driver package, but the Colab\n", "# kernel doesn't install its driver via APT, and as a result the ICD is missing.\n", "# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)\n", "NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'\n", "if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):\n", "  with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:\n", "    f.write(\"\"\"{\n", "    \"file_format_version\" : \"1.0.0\",\n", "    \"ICD\" : {\n", "        \"library_path\" : \"libEGL_nvidia.so.0\"\n", "    }\n", "}\n", "\"\"\")\n", "\n", "# Configure MuJoCo to use the EGL rendering backend (requires GPU)\n", "print('Setting environment variable to use GPU rendering:')\n", "%env MUJOCO_GL=egl\n", "\n", "try:\n", "  print('Checking that the installation succeeded:')\n", "  import mujoco\n", "\n", "  mujoco.MjModel.from_xml_string('<mujoco/>')\n", "except Exception as e:\n", "  raise e from RuntimeError(\n", "      'Something went wrong during installation. 
Check the shell output above '\n", "      'for more information.\\n'\n", "      'If using a hosted Colab runtime, make sure you enable GPU acceleration '\n", "      'by going to the Runtime menu and selecting \"Change runtime type\".'\n", "  )\n", "\n", "print('Installation successful.')\n", "\n", "# Tell XLA to use Triton GEMM; this improves steps/sec by ~30% on some GPUs.\n", "xla_flags = os.environ.get('XLA_FLAGS', '')\n", "xla_flags += ' --xla_gpu_triton_gemm_any=True'\n", "os.environ['XLA_FLAGS'] = xla_flags" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "T5f4w3Kq2X14", "cellView": "form" }, "outputs": [], "source": [ "# @title Import packages for plotting and creating graphics\n", "import json\n", "import itertools\n", "import time\n", "from typing import Callable, List, NamedTuple, Optional, Union\n", "import numpy as np\n", "\n", "# Graphics and plotting.\n", "print(\"Installing mediapy:\")\n", "!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)\n", "!pip install -q mediapy\n", "import mediapy as media\n", "import matplotlib.pyplot as plt\n", "\n", "# More legible printing from numpy.\n", "np.set_printoptions(precision=3, suppress=True, linewidth=100)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "ObF1UXrkb0Nd" }, "outputs": [], "source": [ "# @title Import MuJoCo, MJX, and Brax\n", "from datetime import datetime\n", "import functools\n", "import os\n", "from typing import Any, Dict, Sequence, Tuple, Union\n", "from brax import base\n", "from brax import envs\n", "from brax import math\n", "from brax.base import Base, Motion, Transform\n", "from brax.base import State as PipelineState\n", "from brax.envs.base import Env, PipelineEnv, State\n", "from brax.io import html, mjcf, model\n", "from brax.mjx.base import State as MjxState\n", "from brax.training.agents.ppo import networks as ppo_networks\n", "from brax.training.agents.ppo import train as ppo\n", "from 
brax.training.agents.sac import networks as sac_networks\n", "from brax.training.agents.sac import train as sac\n", "from etils import epath\n", "from flax import struct\n", "from flax.training import orbax_utils\n", "from IPython.display import HTML, clear_output\n", "import jax\n", "from jax import numpy as jp\n", "from matplotlib import pyplot as plt\n", "import mediapy as media\n", "from ml_collections import config_dict\n", "import mujoco\n", "from mujoco import mjx\n", "import numpy as np\n", "from orbax import checkpoint as ocp" ] }, { "cell_type": "code", "source": [ "#@title Install MuJoCo Playground\n", "!pip install playground" ], "metadata": { "cellView": "form", "id": "UoTLSx4cFRdy" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "#@title Import The Playground\n", "\n", "from mujoco_playground import wrapper\n", "from mujoco_playground import registry" ], "metadata": { "cellView": "form", "id": "gYm2h7m8w3Nv" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Manipulation\n", "\n", "MuJoCo Playground contains several manipulation environments. Run the cell below to list them all." ], "metadata": { "id": "LcibXbyKt4FI" } }, { "cell_type": "code", "source": [ "registry.manipulation.ALL_ENVS" ], "metadata": { "id": "ox0Gze9Ct5AM" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "# Franka Emika Panda\n", "\n", "Let's start with the simplest environment: picking up a cube with the Franka Emika Panda." 
], "metadata": { "id": "_R01tjWfI-i6" } }, { "cell_type": "code", "source": [ "env_name = 'PandaPickCubeOrientation'\n", "env = registry.load(env_name)\n", "env_cfg = registry.get_default_config(env_name)" ], "metadata": { "id": "kPJeoQeEJBSA" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "env_cfg" ], "metadata": { "id": "6n9UT9N1wR5K" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Train Policy\n", "\n", "Let's train the pick cube policy and visualize rollouts. The policy takes roughly 3 minutes to train on an RTX 4090." ], "metadata": { "id": "Thm7nZueM4cz" } }, { "cell_type": "code", "source": [ "from mujoco_playground.config import manipulation_params\n", "ppo_params = manipulation_params.brax_ppo_config(env_name)\n", "ppo_params" ], "metadata": { "id": "B9T_UVZYLDdM" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "vBEEQyY6M5OC" }, "source": [ "### PPO" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XKFzyP7wM5OD" }, "outputs": [], "source": [ "x_data, y_data, y_dataerr = [], [], []\n", "times = [datetime.now()]\n", "\n", "\n", "def progress(num_steps, metrics):\n", "  clear_output(wait=True)\n", "\n", "  times.append(datetime.now())\n", "  x_data.append(num_steps)\n", "  y_data.append(metrics[\"eval/episode_reward\"])\n", "  y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n", "\n", "  plt.xlim([0, ppo_params[\"num_timesteps\"] * 1.25])\n", "  plt.xlabel(\"# environment steps\")\n", "  plt.ylabel(\"reward per episode\")\n", "  plt.title(f\"y={y_data[-1]:.3f}\")\n", "  plt.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n", "\n", "  display(plt.gcf())\n", "\n", "ppo_training_params = dict(ppo_params)\n", "network_factory = ppo_networks.make_ppo_networks\n", "if \"network_factory\" in ppo_params:\n", "  del ppo_training_params[\"network_factory\"]\n", "  network_factory = functools.partial(\n", "      
ppo_networks.make_ppo_networks,\n", "      **ppo_params.network_factory\n", "  )\n", "\n", "train_fn = functools.partial(\n", "    ppo.train, **dict(ppo_training_params),\n", "    network_factory=network_factory,\n", "    progress_fn=progress,\n", "    seed=1\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FGrlulWbM5OD" }, "outputs": [], "source": [ "make_inference_fn, params, metrics = train_fn(\n", "    environment=env,\n", "    wrap_env_fn=wrapper.wrap_for_brax_training,\n", ")\n", "print(f\"time to jit: {times[1] - times[0]}\")\n", "print(f\"time to train: {times[-1] - times[1]}\")" ] }, { "cell_type": "markdown", "source": [ "## Visualize Rollouts" ], "metadata": { "id": "mHVmccs-oMSo" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sG5a2FFXoUKw" }, "outputs": [], "source": [ "jit_reset = jax.jit(env.reset)\n", "jit_step = jax.jit(env.step)\n", "jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "C_1CY9xDoUKw" }, "outputs": [], "source": [ "rng = jax.random.PRNGKey(42)\n", "rollout = []\n", "n_episodes = 1\n", "\n", "for _ in range(n_episodes):\n", "  state = jit_reset(rng)\n", "  rollout.append(state)\n", "  for i in range(env_cfg.episode_length):\n", "    act_rng, rng = jax.random.split(rng)\n", "    ctrl, _ = jit_inference_fn(state.obs, act_rng)\n", "    state = jit_step(state, ctrl)\n", "    rollout.append(state)\n", "\n", "render_every = 1\n", "frames = env.render(rollout[::render_every])\n", "rewards = [s.reward for s in rollout]\n", "media.show_video(frames, fps=1.0 / env.dt / render_every)" ] }, { "cell_type": "markdown", "source": [ "Although the policy above is very simple, this work was extended using the Madrona batch renderer, and the resulting policies were transferred to a real robot. 
We encourage folks to check out the Madrona-MJX tutorial notebooks ([part 1](https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/training_vision_1.ipynb) and [part 2](https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/training_vision_2.ipynb))!" ], "metadata": { "id": "v5r4FwivlnoG" } }, { "cell_type": "markdown", "source": [ "# Dexterous Manipulation\n", "\n", "Let's now use the `LeapCubeReorient` environment to train a policy that was transferred onto a real Leap Hand robot! The environment contains a cube placed in the center of the hand, and the goal is to reorient the cube in SO(3)." ], "metadata": { "id": "YVQsrEE3mMj8" } }, { "cell_type": "code", "source": [ "env_name = 'LeapCubeReorient'\n", "env = registry.load(env_name)\n", "env_cfg = registry.get_default_config(env_name)" ], "metadata": { "id": "oPaTdWqVmuPt" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "env_cfg" ], "metadata": { "id": "c0OII08RmuPt" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "## Train Policy\n", "\n", "Let's train an initial policy and visualize the rollouts. Notice that the PPO parameters contain `policy_obs_key` and `value_obs_key` fields, which allow us to train Brax PPO with [asymmetric](https://arxiv.org/abs/1710.06542) observations for the actor and the critic. While the actor receives proprioceptive state similar in nature to the real-world camera tracking sensors, the critic network receives privileged state only available in the simulator. This enables more sample-efficient learning, and we are able to train an initial policy in 33 minutes on a single RTX 4090.\n", "\n", "Depending on the GPU device and topology, training can be brought down to 10-20 minutes, as shown in the MuJoCo Playground technical report." 
], "metadata": { "id": "3g335ZYFmuPt" } }, { "cell_type": "code", "source": [ "from mujoco_playground.config import manipulation_params\n", "ppo_params = manipulation_params.brax_ppo_config(env_name)\n", "ppo_params" ], "metadata": { "id": "cc1Ka4hYmuPt" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "Ulr1ih6PmuPu" }, "source": [ "### PPO" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gzwRjUGLmuPu" }, "outputs": [], "source": [ "x_data, y_data, y_dataerr = [], [], []\n", "times = [datetime.now()]\n", "\n", "\n", "def progress(num_steps, metrics):\n", "  clear_output(wait=True)\n", "\n", "  times.append(datetime.now())\n", "  x_data.append(num_steps)\n", "  y_data.append(metrics[\"eval/episode_reward\"])\n", "  y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n", "\n", "  plt.xlim([0, ppo_params[\"num_timesteps\"] * 1.25])\n", "  plt.xlabel(\"# environment steps\")\n", "  plt.ylabel(\"reward per episode\")\n", "  plt.title(f\"y={y_data[-1]:.3f}\")\n", "  plt.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n", "\n", "  display(plt.gcf())\n", "\n", "ppo_training_params = dict(ppo_params)\n", "network_factory = ppo_networks.make_ppo_networks\n", "if \"network_factory\" in ppo_params:\n", "  del ppo_training_params[\"network_factory\"]\n", "  network_factory = functools.partial(\n", "      ppo_networks.make_ppo_networks,\n", "      **ppo_params.network_factory\n", "  )\n", "\n", "train_fn = functools.partial(\n", "    ppo.train, **dict(ppo_training_params),\n", "    network_factory=network_factory,\n", "    progress_fn=progress,\n", "    seed=1\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "YmortADGmuPu" }, "outputs": [], "source": [ "make_inference_fn, params, metrics = train_fn(\n", "    environment=env,\n", "    wrap_env_fn=wrapper.wrap_for_brax_training,\n", ")\n", "print(f\"time to jit: {times[1] - times[0]}\")\n", "print(f\"time to train: {times[-1] - times[1]}\")" ] }, { "cell_type": 
"markdown", "source": [ "## Visualize Rollouts" ], "metadata": { "id": "xIB_677emuPu" } }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xBgGvZpTmuPu" }, "outputs": [], "source": [ "jit_reset = jax.jit(env.reset)\n", "jit_step = jax.jit(env.step)\n", "jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ksj6_9PwmuPu" }, "outputs": [], "source": [ "rng = jax.random.PRNGKey(42)\n", "rollout = []\n", "n_episodes = 1\n", "\n", "for _ in range(n_episodes):\n", "  state = jit_reset(rng)\n", "  rollout.append(state)\n", "  for i in range(env_cfg.episode_length):\n", "    act_rng, rng = jax.random.split(rng)\n", "    ctrl, _ = jit_inference_fn(state.obs, act_rng)\n", "    state = jit_step(state, ctrl)\n", "    rollout.append(state)\n", "\n", "render_every = 1\n", "frames = env.render(rollout[::render_every])\n", "rewards = [s.reward for s in rollout]\n", "media.show_video(frames, fps=1.0 / env.dt / render_every)" ] }, { "cell_type": "markdown", "source": [ "The above policy solves the task but may look a little jittery. To get robust sim-to-real transfer, we retrained from previous checkpoints using a curriculum on the maximum torque, both to facilitate exploration early in the curriculum and to produce smoother actions in the final policy. More details can be found in the MuJoCo Playground technical report!" ], "metadata": { "id": "dWIVTxq5nhs5" } }, { "cell_type": "markdown", "metadata": { "id": "CBtrAqns35sI" }, "source": [ "🙌 Thanks for stopping by The Playground!" ] } ], "metadata": { "colab": { "private_outputs": true, "toc_visible": true, "provenance": [], "machine_shape": "hm", "gpuType": "A100" }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" }, "accelerator": "GPU" }, "nbformat": 4, "nbformat_minor": 0 }