Spaces:
Runtime error
Runtime error
Upload 4 files
Browse files- init_mujoco.py +4 -0
- samples/locomotion.ipynb +1758 -0
- samples/manipulation.ipynb +650 -0
- samples/tutorial.ipynb +2258 -0
init_mujoco.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Import The Playground (this will clone Mujoco assets)
|
2 |
+
|
3 |
+
from mujoco_playground import wrapper
|
4 |
+
from mujoco_playground import registry
|
samples/locomotion.ipynb
ADDED
@@ -0,0 +1,1758 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "MpkYHwCqk7W-"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"\n",
|
10 |
+
"\n",
|
11 |
+
"\n",
|
12 |
+
"\n",
|
13 |
+
"\n",
|
14 |
+
"\n"
|
15 |
+
]
|
16 |
+
},
|
17 |
+
{
|
18 |
+
"cell_type": "markdown",
|
19 |
+
"metadata": {
|
20 |
+
"id": "xBSdkbmGN2K-"
|
21 |
+
},
|
22 |
+
"source": [
|
23 |
+
"### Copyright notice"
|
24 |
+
]
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"cell_type": "markdown",
|
28 |
+
"metadata": {
|
29 |
+
"id": "_UbO9uhtBSX5"
|
30 |
+
},
|
31 |
+
"source": [
|
32 |
+
"> <p><small><small>Copyright 2025 DeepMind Technologies Limited.</small></p>\n",
|
33 |
+
"> <p><small><small>Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at <a href=\"http://www.apache.org/licenses/LICENSE-2.0\">http://www.apache.org/licenses/LICENSE-2.0</a>.</small></small></p>\n",
|
34 |
+
"> <p><small><small>Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.</small></small></p>"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "markdown",
|
39 |
+
"metadata": {
|
40 |
+
"id": "dNIJkb_FM2Ux"
|
41 |
+
},
|
42 |
+
"source": [
|
43 |
+
"# Locomotion in The Playground! <a href=\"https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/locomotion.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" width=\"140\" align=\"center\"/></a>\n",
|
44 |
+
"\n",
|
45 |
+
"In this notebook, we'll walk through a few locomotion environments available in MuJoCo Playground.\n",
|
46 |
+
"\n",
|
47 |
+
"**A Hugging Face Space with GPU acceleration is required.**\n"
|
48 |
+
]
|
49 |
+
},
|
50 |
+
{
|
51 |
+
"cell_type": "code",
|
52 |
+
"execution_count": 1,
|
53 |
+
"metadata": {
|
54 |
+
"cellView": "form",
|
55 |
+
"id": "Xqo7pyX-n72M"
|
56 |
+
},
|
57 |
+
"outputs": [
|
58 |
+
{
|
59 |
+
"name": "stdout",
|
60 |
+
"output_type": "stream",
|
61 |
+
"text": [
|
62 |
+
"Collecting jax[cuda12]\n",
|
63 |
+
" Downloading jax-0.6.2-py3-none-any.whl.metadata (13 kB)\n",
|
64 |
+
"Collecting jaxlib<=0.6.2,>=0.6.2 (from jax[cuda12])\n",
|
65 |
+
" Downloading jaxlib-0.6.2-cp313-cp313-manylinux2014_x86_64.whl.metadata (1.3 kB)\n",
|
66 |
+
"Collecting ml_dtypes>=0.5.0 (from jax[cuda12])\n",
|
67 |
+
" Downloading ml_dtypes-0.5.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB)\n",
|
68 |
+
"Collecting numpy>=1.26 (from jax[cuda12])\n",
|
69 |
+
" Downloading numpy-2.3.1-cp313-cp313-manylinux_2_28_x86_64.whl.metadata (62 kB)\n",
|
70 |
+
"Collecting opt_einsum (from jax[cuda12])\n",
|
71 |
+
" Downloading opt_einsum-3.4.0-py3-none-any.whl.metadata (6.3 kB)\n",
|
72 |
+
"Collecting scipy>=1.12 (from jax[cuda12])\n",
|
73 |
+
" Downloading scipy-1.16.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (61 kB)\n",
|
74 |
+
"Collecting jax-cuda12-plugin<=0.6.2,>=0.6.2 (from jax-cuda12-plugin[with-cuda]<=0.6.2,>=0.6.2; extra == \"cuda12\"->jax[cuda12])\n",
|
75 |
+
" Downloading jax_cuda12_plugin-0.6.2-cp313-cp313-manylinux2014_x86_64.whl.metadata (1.7 kB)\n",
|
76 |
+
"Collecting jax-cuda12-pjrt==0.6.2 (from jax-cuda12-plugin<=0.6.2,>=0.6.2->jax-cuda12-plugin[with-cuda]<=0.6.2,>=0.6.2; extra == \"cuda12\"->jax[cuda12])\n",
|
77 |
+
" Downloading jax_cuda12_pjrt-0.6.2-py3-none-manylinux2014_x86_64.whl.metadata (579 bytes)\n",
|
78 |
+
"Collecting nvidia-cublas-cu12>=12.1.3.1 (from jax-cuda12-plugin[with-cuda]<=0.6.2,>=0.6.2; extra == \"cuda12\"->jax[cuda12])\n",
|
79 |
+
" Downloading nvidia_cublas_cu12-12.9.1.4-py3-none-manylinux_2_27_x86_64.whl.metadata (1.7 kB)\n",
|
80 |
+
"Collecting nvidia-cuda-cupti-cu12>=12.1.105 (from jax-cuda12-plugin[with-cuda]<=0.6.2,>=0.6.2; extra == \"cuda12\"->jax[cuda12])\n",
|
81 |
+
" Downloading nvidia_cuda_cupti_cu12-12.9.79-py3-none-manylinux_2_25_x86_64.whl.metadata (1.8 kB)\n",
|
82 |
+
"Collecting nvidia-cuda-nvcc-cu12>=12.6.85 (from jax-cuda12-plugin[with-cuda]<=0.6.2,>=0.6.2; extra == \"cuda12\"->jax[cuda12])\n",
|
83 |
+
" Downloading nvidia_cuda_nvcc_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.7 kB)\n",
|
84 |
+
"Collecting nvidia-cuda-runtime-cu12>=12.1.105 (from jax-cuda12-plugin[with-cuda]<=0.6.2,>=0.6.2; extra == \"cuda12\"->jax[cuda12])\n",
|
85 |
+
" Downloading nvidia_cuda_runtime_cu12-12.9.79-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.7 kB)\n",
|
86 |
+
"Collecting nvidia-cudnn-cu12<10.0,>=9.8 (from jax-cuda12-plugin[with-cuda]<=0.6.2,>=0.6.2; extra == \"cuda12\"->jax[cuda12])\n",
|
87 |
+
" Downloading nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl.metadata (1.8 kB)\n",
|
88 |
+
"Collecting nvidia-cufft-cu12>=11.0.2.54 (from jax-cuda12-plugin[with-cuda]<=0.6.2,>=0.6.2; extra == \"cuda12\"->jax[cuda12])\n",
|
89 |
+
" Downloading nvidia_cufft_cu12-11.4.1.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.8 kB)\n",
|
90 |
+
"Collecting nvidia-cusolver-cu12>=11.4.5.107 (from jax-cuda12-plugin[with-cuda]<=0.6.2,>=0.6.2; extra == \"cuda12\"->jax[cuda12])\n",
|
91 |
+
" Downloading nvidia_cusolver_cu12-11.7.5.82-py3-none-manylinux_2_27_x86_64.whl.metadata (1.9 kB)\n",
|
92 |
+
"Collecting nvidia-cusparse-cu12>=12.1.0.106 (from jax-cuda12-plugin[with-cuda]<=0.6.2,>=0.6.2; extra == \"cuda12\"->jax[cuda12])\n",
|
93 |
+
" Downloading nvidia_cusparse_cu12-12.5.10.65-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.8 kB)\n",
|
94 |
+
"Collecting nvidia-nccl-cu12>=2.18.1 (from jax-cuda12-plugin[with-cuda]<=0.6.2,>=0.6.2; extra == \"cuda12\"->jax[cuda12])\n",
|
95 |
+
" Downloading nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.0 kB)\n",
|
96 |
+
"Collecting nvidia-nvjitlink-cu12>=12.1.105 (from jax-cuda12-plugin[with-cuda]<=0.6.2,>=0.6.2; extra == \"cuda12\"->jax[cuda12])\n",
|
97 |
+
" Downloading nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.7 kB)\n",
|
98 |
+
"Collecting nvidia-cuda-nvrtc-cu12>=12.1.55 (from jax-cuda12-plugin[with-cuda]<=0.6.2,>=0.6.2; extra == \"cuda12\"->jax[cuda12])\n",
|
99 |
+
" Downloading nvidia_cuda_nvrtc_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.7 kB)\n",
|
100 |
+
"Collecting nvidia-nvshmem-cu12>=3.2.5 (from jax-cuda12-plugin[with-cuda]<=0.6.2,>=0.6.2; extra == \"cuda12\"->jax[cuda12])\n",
|
101 |
+
" Downloading nvidia_nvshmem_cu12-3.3.9-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (2.1 kB)\n",
|
102 |
+
"Downloading jax-0.6.2-py3-none-any.whl (2.7 MB)\n",
|
103 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.7/2.7 MB\u001b[0m \u001b[31m161.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
104 |
+
"\u001b[?25hDownloading jax_cuda12_plugin-0.6.2-cp313-cp313-manylinux2014_x86_64.whl (15.9 MB)\n",
|
105 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m15.9/15.9 MB\u001b[0m \u001b[31m167.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
106 |
+
"\u001b[?25hDownloading jax_cuda12_pjrt-0.6.2-py3-none-manylinux2014_x86_64.whl (125.3 MB)\n",
|
107 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m125.3/125.3 MB\u001b[0m \u001b[31m246.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
|
108 |
+
"\u001b[?25hDownloading jaxlib-0.6.2-cp313-cp313-manylinux2014_x86_64.whl (89.9 MB)\n",
|
109 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m89.9/89.9 MB\u001b[0m \u001b[31m221.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
|
110 |
+
"\u001b[?25hDownloading nvidia_cudnn_cu12-9.10.2.21-py3-none-manylinux_2_27_x86_64.whl (706.8 MB)\n",
|
111 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m706.8/706.8 MB\u001b[0m \u001b[31m71.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
|
112 |
+
"\u001b[?25hDownloading ml_dtypes-0.5.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.7 MB)\n",
|
113 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.7/4.7 MB\u001b[0m \u001b[31m284.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
114 |
+
"\u001b[?25hDownloading numpy-2.3.1-cp313-cp313-manylinux_2_28_x86_64.whl (16.6 MB)\n",
|
115 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m16.6/16.6 MB\u001b[0m \u001b[31m213.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
116 |
+
"\u001b[?25hDownloading nvidia_cublas_cu12-12.9.1.4-py3-none-manylinux_2_27_x86_64.whl (581.2 MB)\n",
|
117 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m581.2/581.2 MB\u001b[0m \u001b[31m87.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
|
118 |
+
"\u001b[?25hDownloading nvidia_cuda_cupti_cu12-12.9.79-py3-none-manylinux_2_25_x86_64.whl (10.8 MB)\n",
|
119 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.8/10.8 MB\u001b[0m \u001b[31m259.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
120 |
+
"\u001b[?25hDownloading nvidia_cuda_nvcc_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (40.5 MB)\n",
|
121 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━���━━━━━━━━━━━━━━\u001b[0m \u001b[32m40.5/40.5 MB\u001b[0m \u001b[31m236.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
122 |
+
"\u001b[?25hDownloading nvidia_cuda_nvrtc_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (89.6 MB)\n",
|
123 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m89.6/89.6 MB\u001b[0m \u001b[31m226.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0ma \u001b[36m0:00:01\u001b[0m\n",
|
124 |
+
"\u001b[?25hDownloading nvidia_cuda_runtime_cu12-12.9.79-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (3.5 MB)\n",
|
125 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.5/3.5 MB\u001b[0m \u001b[31m220.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
126 |
+
"\u001b[?25hDownloading nvidia_cufft_cu12-11.4.1.4-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (200.9 MB)\n",
|
127 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m200.9/200.9 MB\u001b[0m \u001b[31m199.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
|
128 |
+
"\u001b[?25hDownloading nvidia_cusolver_cu12-11.7.5.82-py3-none-manylinux_2_27_x86_64.whl (338.1 MB)\n",
|
129 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m338.1/338.1 MB\u001b[0m \u001b[31m155.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
|
130 |
+
"\u001b[?25hDownloading nvidia_cusparse_cu12-12.5.10.65-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (366.5 MB)\n",
|
131 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m366.5/366.5 MB\u001b[0m \u001b[31m144.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
|
132 |
+
"\u001b[?25hDownloading nvidia_nccl_cu12-2.27.5-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (322.3 MB)\n",
|
133 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m322.3/322.3 MB\u001b[0m \u001b[31m169.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
|
134 |
+
"\u001b[?25hDownloading nvidia_nvjitlink_cu12-12.9.86-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl (39.7 MB)\n",
|
135 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m39.7/39.7 MB\u001b[0m \u001b[31m267.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
136 |
+
"\u001b[?25hDownloading nvidia_nvshmem_cu12-3.3.9-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (124.6 MB)\n",
|
137 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.6/124.6 MB\u001b[0m \u001b[31m198.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
|
138 |
+
"\u001b[?25hDownloading scipy-1.16.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (35.1 MB)\n",
|
139 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m35.1/35.1 MB\u001b[0m \u001b[31m203.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
140 |
+
"\u001b[?25hDownloading opt_einsum-3.4.0-py3-none-any.whl (71 kB)\n",
|
141 |
+
"Installing collected packages: jax-cuda12-pjrt, opt_einsum, nvidia-nvshmem-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-nvcc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, numpy, jax-cuda12-plugin, scipy, nvidia-cusparse-cu12, nvidia-cufft-cu12, nvidia-cudnn-cu12, ml_dtypes, nvidia-cusolver-cu12, jaxlib, jax\n",
|
142 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20/20\u001b[0m [jax]32m19/20\u001b[0m [jax]ib]cusolver-cu12]2]2]\n",
|
143 |
+
"\u001b[1A\u001b[2KSuccessfully installed jax-0.6.2 jax-cuda12-pjrt-0.6.2 jax-cuda12-plugin-0.6.2 jaxlib-0.6.2 ml_dtypes-0.5.1 numpy-2.3.1 nvidia-cublas-cu12-12.9.1.4 nvidia-cuda-cupti-cu12-12.9.79 nvidia-cuda-nvcc-cu12-12.9.86 nvidia-cuda-nvrtc-cu12-12.9.86 nvidia-cuda-runtime-cu12-12.9.79 nvidia-cudnn-cu12-9.10.2.21 nvidia-cufft-cu12-11.4.1.4 nvidia-cusolver-cu12-11.7.5.82 nvidia-cusparse-cu12-12.5.10.65 nvidia-nccl-cu12-2.27.5 nvidia-nvjitlink-cu12-12.9.86 nvidia-nvshmem-cu12-3.3.9 opt_einsum-3.4.0 scipy-1.16.0\n",
|
144 |
+
"Collecting mujoco\n",
|
145 |
+
" Downloading mujoco-3.3.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (44 kB)\n",
|
146 |
+
"Collecting absl-py (from mujoco)\n",
|
147 |
+
" Downloading absl_py-2.3.0-py3-none-any.whl.metadata (2.4 kB)\n",
|
148 |
+
"Collecting etils[epath] (from mujoco)\n",
|
149 |
+
" Downloading etils-1.12.2-py3-none-any.whl.metadata (6.5 kB)\n",
|
150 |
+
"Collecting glfw (from mujoco)\n",
|
151 |
+
" Downloading glfw-2.9.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-manylinux_2_28_x86_64.whl.metadata (5.4 kB)\n",
|
152 |
+
"Requirement already satisfied: numpy in /home/user/miniconda/lib/python3.13/site-packages (from mujoco) (2.3.1)\n",
|
153 |
+
"Collecting pyopengl (from mujoco)\n",
|
154 |
+
" Downloading PyOpenGL-3.1.9-py3-none-any.whl.metadata (3.3 kB)\n",
|
155 |
+
"Collecting fsspec (from etils[epath]->mujoco)\n",
|
156 |
+
" Downloading fsspec-2025.5.1-py3-none-any.whl.metadata (11 kB)\n",
|
157 |
+
"Collecting importlib_resources (from etils[epath]->mujoco)\n",
|
158 |
+
" Downloading importlib_resources-6.5.2-py3-none-any.whl.metadata (3.9 kB)\n",
|
159 |
+
"Requirement already satisfied: typing_extensions in /home/user/miniconda/lib/python3.13/site-packages (from etils[epath]->mujoco) (4.12.2)\n",
|
160 |
+
"Collecting zipp (from etils[epath]->mujoco)\n",
|
161 |
+
" Downloading zipp-3.23.0-py3-none-any.whl.metadata (3.6 kB)\n",
|
162 |
+
"Downloading mujoco-3.3.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.6 MB)\n",
|
163 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.6/6.6 MB\u001b[0m \u001b[31m178.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
164 |
+
"\u001b[?25hDownloading absl_py-2.3.0-py3-none-any.whl (135 kB)\n",
|
165 |
+
"Downloading etils-1.12.2-py3-none-any.whl (167 kB)\n",
|
166 |
+
"Downloading fsspec-2025.5.1-py3-none-any.whl (199 kB)\n",
|
167 |
+
"Downloading glfw-2.9.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-manylinux_2_28_x86_64.whl (243 kB)\n",
|
168 |
+
"Downloading importlib_resources-6.5.2-py3-none-any.whl (37 kB)\n",
|
169 |
+
"Downloading PyOpenGL-3.1.9-py3-none-any.whl (3.2 MB)\n",
|
170 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.2/3.2 MB\u001b[0m \u001b[31m228.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
171 |
+
"\u001b[?25hDownloading zipp-3.23.0-py3-none-any.whl (10 kB)\n",
|
172 |
+
"Installing collected packages: pyopengl, glfw, zipp, importlib_resources, fsspec, etils, absl-py, mujoco\n",
|
173 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8/8\u001b[0m [mujoco]2m7/8\u001b[0m [mujoco]]\n",
|
174 |
+
"\u001b[1A\u001b[2KSuccessfully installed absl-py-2.3.0 etils-1.12.2 fsspec-2025.5.1 glfw-2.9.0 importlib_resources-6.5.2 mujoco-3.3.3 pyopengl-3.1.9 zipp-3.23.0\n",
|
175 |
+
"Collecting mujoco_mjx\n",
|
176 |
+
" Downloading mujoco_mjx-3.3.3-py3-none-any.whl.metadata (3.4 kB)\n",
|
177 |
+
"Requirement already satisfied: absl-py in /home/user/miniconda/lib/python3.13/site-packages (from mujoco_mjx) (2.3.0)\n",
|
178 |
+
"Requirement already satisfied: etils[epath] in /home/user/miniconda/lib/python3.13/site-packages (from mujoco_mjx) (1.12.2)\n",
|
179 |
+
"Requirement already satisfied: jax in /home/user/miniconda/lib/python3.13/site-packages (from mujoco_mjx) (0.6.2)\n",
|
180 |
+
"Requirement already satisfied: jaxlib in /home/user/miniconda/lib/python3.13/site-packages (from mujoco_mjx) (0.6.2)\n",
|
181 |
+
"Requirement already satisfied: mujoco>=3.3.3.dev0 in /home/user/miniconda/lib/python3.13/site-packages (from mujoco_mjx) (3.3.3)\n",
|
182 |
+
"Requirement already satisfied: scipy in /home/user/miniconda/lib/python3.13/site-packages (from mujoco_mjx) (1.16.0)\n",
|
183 |
+
"Collecting trimesh (from mujoco_mjx)\n",
|
184 |
+
" Downloading trimesh-4.6.13-py3-none-any.whl.metadata (18 kB)\n",
|
185 |
+
"Requirement already satisfied: glfw in /home/user/miniconda/lib/python3.13/site-packages (from mujoco>=3.3.3.dev0->mujoco_mjx) (2.9.0)\n",
|
186 |
+
"Requirement already satisfied: numpy in /home/user/miniconda/lib/python3.13/site-packages (from mujoco>=3.3.3.dev0->mujoco_mjx) (2.3.1)\n",
|
187 |
+
"Requirement already satisfied: pyopengl in /home/user/miniconda/lib/python3.13/site-packages (from mujoco>=3.3.3.dev0->mujoco_mjx) (3.1.9)\n",
|
188 |
+
"Requirement already satisfied: fsspec in /home/user/miniconda/lib/python3.13/site-packages (from etils[epath]->mujoco_mjx) (2025.5.1)\n",
|
189 |
+
"Requirement already satisfied: importlib_resources in /home/user/miniconda/lib/python3.13/site-packages (from etils[epath]->mujoco_mjx) (6.5.2)\n",
|
190 |
+
"Requirement already satisfied: typing_extensions in /home/user/miniconda/lib/python3.13/site-packages (from etils[epath]->mujoco_mjx) (4.12.2)\n",
|
191 |
+
"Requirement already satisfied: zipp in /home/user/miniconda/lib/python3.13/site-packages (from etils[epath]->mujoco_mjx) (3.23.0)\n",
|
192 |
+
"Requirement already satisfied: ml_dtypes>=0.5.0 in /home/user/miniconda/lib/python3.13/site-packages (from jax->mujoco_mjx) (0.5.1)\n",
|
193 |
+
"Requirement already satisfied: opt_einsum in /home/user/miniconda/lib/python3.13/site-packages (from jax->mujoco_mjx) (3.4.0)\n",
|
194 |
+
"Downloading mujoco_mjx-3.3.3-py3-none-any.whl (6.7 MB)\n",
|
195 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.7/6.7 MB\u001b[0m \u001b[31m138.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
196 |
+
"\u001b[?25hDownloading trimesh-4.6.13-py3-none-any.whl (712 kB)\n",
|
197 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m712.4/712.4 kB\u001b[0m \u001b[31m124.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
198 |
+
"\u001b[?25hInstalling collected packages: trimesh, mujoco_mjx\n",
|
199 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2/2\u001b[0m [mujoco_mjx]2\u001b[0m [mujoco_mjx]\n",
|
200 |
+
"\u001b[1A\u001b[2KSuccessfully installed mujoco_mjx-3.3.3 trimesh-4.6.13\n",
|
201 |
+
"Collecting brax\n",
|
202 |
+
" Downloading brax-0.12.4-py3-none-any.whl.metadata (20 kB)\n",
|
203 |
+
"Requirement already satisfied: absl-py in /home/user/miniconda/lib/python3.13/site-packages (from brax) (2.3.0)\n",
|
204 |
+
"Requirement already satisfied: etils in /home/user/miniconda/lib/python3.13/site-packages (from brax) (1.12.2)\n",
|
205 |
+
"Collecting flask (from brax)\n",
|
206 |
+
" Downloading flask-3.1.1-py3-none-any.whl.metadata (3.0 kB)\n",
|
207 |
+
"Collecting flask-cors (from brax)\n",
|
208 |
+
" Downloading flask_cors-6.0.1-py3-none-any.whl.metadata (5.3 kB)\n",
|
209 |
+
"Collecting flax (from brax)\n",
|
210 |
+
" Downloading flax-0.10.6-py3-none-any.whl.metadata (11 kB)\n",
|
211 |
+
"Requirement already satisfied: jax>=0.4.6 in /home/user/miniconda/lib/python3.13/site-packages (from brax) (0.6.2)\n",
|
212 |
+
"Requirement already satisfied: jaxlib>=0.4.6 in /home/user/miniconda/lib/python3.13/site-packages (from brax) (0.6.2)\n",
|
213 |
+
"Collecting jaxopt (from brax)\n",
|
214 |
+
" Downloading jaxopt-0.8.5-py3-none-any.whl.metadata (3.3 kB)\n",
|
215 |
+
"Requirement already satisfied: jinja2 in /home/user/miniconda/lib/python3.13/site-packages (from brax) (3.1.6)\n",
|
216 |
+
"Collecting ml-collections (from brax)\n",
|
217 |
+
" Downloading ml_collections-1.1.0-py3-none-any.whl.metadata (22 kB)\n",
|
218 |
+
"Requirement already satisfied: mujoco in /home/user/miniconda/lib/python3.13/site-packages (from brax) (3.3.3)\n",
|
219 |
+
"Requirement already satisfied: mujoco-mjx in /home/user/miniconda/lib/python3.13/site-packages (from brax) (3.3.3)\n",
|
220 |
+
"Requirement already satisfied: numpy in /home/user/miniconda/lib/python3.13/site-packages (from brax) (2.3.1)\n",
|
221 |
+
"Collecting optax (from brax)\n",
|
222 |
+
" Downloading optax-0.2.5-py3-none-any.whl.metadata (7.5 kB)\n",
|
223 |
+
"Collecting orbax-checkpoint (from brax)\n",
|
224 |
+
" Downloading orbax_checkpoint-0.11.17-py3-none-any.whl.metadata (2.2 kB)\n",
|
225 |
+
"Collecting pillow (from brax)\n",
|
226 |
+
" Downloading pillow-11.3.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (9.0 kB)\n",
|
227 |
+
"Requirement already satisfied: scipy in /home/user/miniconda/lib/python3.13/site-packages (from brax) (1.16.0)\n",
|
228 |
+
"Collecting tensorboardx (from brax)\n",
|
229 |
+
" Downloading tensorboardx-2.6.4-py3-none-any.whl.metadata (6.2 kB)\n",
|
230 |
+
"Requirement already satisfied: trimesh in /home/user/miniconda/lib/python3.13/site-packages (from brax) (4.6.13)\n",
|
231 |
+
"Requirement already satisfied: typing-extensions in /home/user/miniconda/lib/python3.13/site-packages (from brax) (4.12.2)\n",
|
232 |
+
"Requirement already satisfied: ml_dtypes>=0.5.0 in /home/user/miniconda/lib/python3.13/site-packages (from jax>=0.4.6->brax) (0.5.1)\n",
|
233 |
+
"Requirement already satisfied: opt_einsum in /home/user/miniconda/lib/python3.13/site-packages (from jax>=0.4.6->brax) (3.4.0)\n",
|
234 |
+
"Collecting blinker>=1.9.0 (from flask->brax)\n",
|
235 |
+
" Downloading blinker-1.9.0-py3-none-any.whl.metadata (1.6 kB)\n",
|
236 |
+
"Collecting click>=8.1.3 (from flask->brax)\n",
|
237 |
+
" Downloading click-8.2.1-py3-none-any.whl.metadata (2.5 kB)\n",
|
238 |
+
"Collecting itsdangerous>=2.2.0 (from flask->brax)\n",
|
239 |
+
" Downloading itsdangerous-2.2.0-py3-none-any.whl.metadata (1.9 kB)\n",
|
240 |
+
"Requirement already satisfied: markupsafe>=2.1.1 in /home/user/miniconda/lib/python3.13/site-packages (from flask->brax) (3.0.2)\n",
|
241 |
+
"Collecting werkzeug>=3.1.0 (from flask->brax)\n",
|
242 |
+
" Downloading werkzeug-3.1.3-py3-none-any.whl.metadata (3.7 kB)\n",
|
243 |
+
"Collecting msgpack (from flax->brax)\n",
|
244 |
+
" Downloading msgpack-1.1.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.4 kB)\n",
|
245 |
+
"Collecting tensorstore (from flax->brax)\n",
|
246 |
+
" Downloading tensorstore-0.1.75-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (21 kB)\n",
|
247 |
+
"Requirement already satisfied: rich>=11.1 in /home/user/miniconda/lib/python3.13/site-packages (from flax->brax) (13.9.4)\n",
|
248 |
+
"Requirement already satisfied: PyYAML>=5.4.1 in /home/user/miniconda/lib/python3.13/site-packages (from flax->brax) (6.0.2)\n",
|
249 |
+
"Collecting treescope>=0.1.7 (from flax->brax)\n",
|
250 |
+
" Downloading treescope-0.1.9-py3-none-any.whl.metadata (6.6 kB)\n",
|
251 |
+
"Requirement already satisfied: markdown-it-py>=2.2.0 in /home/user/miniconda/lib/python3.13/site-packages (from rich>=11.1->flax->brax) (2.2.0)\n",
|
252 |
+
"Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /home/user/miniconda/lib/python3.13/site-packages (from rich>=11.1->flax->brax) (2.19.1)\n",
|
253 |
+
"Requirement already satisfied: mdurl~=0.1 in /home/user/miniconda/lib/python3.13/site-packages (from markdown-it-py>=2.2.0->rich>=11.1->flax->brax) (0.1.0)\n",
|
254 |
+
"Requirement already satisfied: glfw in /home/user/miniconda/lib/python3.13/site-packages (from mujoco->brax) (2.9.0)\n",
|
255 |
+
"Requirement already satisfied: pyopengl in /home/user/miniconda/lib/python3.13/site-packages (from mujoco->brax) (3.1.9)\n",
|
256 |
+
"Requirement already satisfied: fsspec in /home/user/miniconda/lib/python3.13/site-packages (from etils[epath]->mujoco->brax) (2025.5.1)\n",
|
257 |
+
"Requirement already satisfied: importlib_resources in /home/user/miniconda/lib/python3.13/site-packages (from etils[epath]->mujoco->brax) (6.5.2)\n",
|
258 |
+
"Requirement already satisfied: zipp in /home/user/miniconda/lib/python3.13/site-packages (from etils[epath]->mujoco->brax) (3.23.0)\n",
|
259 |
+
"Collecting chex>=0.1.87 (from optax->brax)\n",
|
260 |
+
" Downloading chex-0.1.89-py3-none-any.whl.metadata (17 kB)\n",
|
261 |
+
"Requirement already satisfied: setuptools in /home/user/miniconda/lib/python3.13/site-packages (from chex>=0.1.87->optax->brax) (78.1.1)\n",
|
262 |
+
"Collecting toolz>=0.9.0 (from chex>=0.1.87->optax->brax)\n",
|
263 |
+
" Downloading toolz-1.0.0-py3-none-any.whl.metadata (5.1 kB)\n",
|
264 |
+
"Requirement already satisfied: nest_asyncio in /home/user/miniconda/lib/python3.13/site-packages (from orbax-checkpoint->brax) (1.6.0)\n",
|
265 |
+
"Collecting protobuf (from orbax-checkpoint->brax)\n",
|
266 |
+
" Downloading protobuf-6.31.1-cp39-abi3-manylinux2014_x86_64.whl.metadata (593 bytes)\n",
|
267 |
+
"Collecting humanize (from orbax-checkpoint->brax)\n",
|
268 |
+
" Downloading humanize-4.12.3-py3-none-any.whl.metadata (7.8 kB)\n",
|
269 |
+
"Collecting simplejson>=3.16.0 (from orbax-checkpoint->brax)\n",
|
270 |
+
" Downloading simplejson-3.20.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.3 kB)\n",
|
271 |
+
"Requirement already satisfied: packaging in /home/user/miniconda/lib/python3.13/site-packages (from tensorboardx->brax) (24.2)\n",
|
272 |
+
"Downloading brax-0.12.4-py3-none-any.whl (341 kB)\n",
|
273 |
+
"Downloading flask-3.1.1-py3-none-any.whl (103 kB)\n",
|
274 |
+
"Downloading blinker-1.9.0-py3-none-any.whl (8.5 kB)\n",
|
275 |
+
"Downloading click-8.2.1-py3-none-any.whl (102 kB)\n",
|
276 |
+
"Downloading itsdangerous-2.2.0-py3-none-any.whl (16 kB)\n",
|
277 |
+
"Downloading werkzeug-3.1.3-py3-none-any.whl (224 kB)\n",
|
278 |
+
"Downloading flask_cors-6.0.1-py3-none-any.whl (13 kB)\n",
|
279 |
+
"Downloading flax-0.10.6-py3-none-any.whl (447 kB)\n",
|
280 |
+
"Downloading treescope-0.1.9-py3-none-any.whl (182 kB)\n",
|
281 |
+
"Downloading jaxopt-0.8.5-py3-none-any.whl (172 kB)\n",
|
282 |
+
"Downloading ml_collections-1.1.0-py3-none-any.whl (76 kB)\n",
|
283 |
+
"Downloading msgpack-1.1.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (423 kB)\n",
|
284 |
+
"Downloading optax-0.2.5-py3-none-any.whl (354 kB)\n",
|
285 |
+
"Downloading chex-0.1.89-py3-none-any.whl (99 kB)\n",
|
286 |
+
"Downloading toolz-1.0.0-py3-none-any.whl (56 kB)\n",
|
287 |
+
"Downloading orbax_checkpoint-0.11.17-py3-none-any.whl (479 kB)\n",
|
288 |
+
"Downloading simplejson-3.20.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (152 kB)\n",
|
289 |
+
"Downloading tensorstore-0.1.75-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.8 MB)\n",
|
290 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m18.8/18.8 MB\u001b[0m \u001b[31m25.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m00:01\u001b[0m00:01\u001b[0m\n",
|
291 |
+
"\u001b[?25hDownloading humanize-4.12.3-py3-none-any.whl (128 kB)\n",
|
292 |
+
"Downloading pillow-11.3.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (6.6 MB)\n",
|
293 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.6/6.6 MB\u001b[0m \u001b[31m40.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
294 |
+
"\u001b[?25hDownloading protobuf-6.31.1-cp39-abi3-manylinux2014_x86_64.whl (321 kB)\n",
|
295 |
+
"Downloading tensorboardx-2.6.4-py3-none-any.whl (87 kB)\n",
|
296 |
+
"Installing collected packages: werkzeug, treescope, toolz, simplejson, protobuf, pillow, msgpack, ml-collections, itsdangerous, humanize, click, blinker, tensorstore, tensorboardx, flask, flask-cors, orbax-checkpoint, jaxopt, chex, optax, flax, brax\n",
|
297 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m22/22\u001b[0m [brax]2m21/22\u001b[0m [brax]]]heckpoint]\n",
|
298 |
+
"\u001b[1A\u001b[2KSuccessfully installed blinker-1.9.0 brax-0.12.4 chex-0.1.89 click-8.2.1 flask-3.1.1 flask-cors-6.0.1 flax-0.10.6 humanize-4.12.3 itsdangerous-2.2.0 jaxopt-0.8.5 ml-collections-1.1.0 msgpack-1.1.1 optax-0.2.5 orbax-checkpoint-0.11.17 pillow-11.3.0 protobuf-6.31.1 simplejson-3.20.1 tensorboardx-2.6.4 tensorstore-0.1.75 toolz-1.0.0 treescope-0.1.9 werkzeug-3.1.3\n",
|
299 |
+
"Collecting mediapy\n",
|
300 |
+
" Downloading mediapy-1.2.4-py3-none-any.whl.metadata (4.8 kB)\n",
|
301 |
+
"Requirement already satisfied: ipython in /home/user/miniconda/lib/python3.13/site-packages (from mediapy) (9.4.0)\n",
|
302 |
+
"Collecting matplotlib (from mediapy)\n",
|
303 |
+
" Downloading matplotlib-3.10.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)\n",
|
304 |
+
"Requirement already satisfied: numpy in /home/user/miniconda/lib/python3.13/site-packages (from mediapy) (2.3.1)\n",
|
305 |
+
"Requirement already satisfied: Pillow in /home/user/miniconda/lib/python3.13/site-packages (from mediapy) (11.3.0)\n",
|
306 |
+
"Requirement already satisfied: decorator in /home/user/miniconda/lib/python3.13/site-packages (from ipython->mediapy) (5.2.1)\n",
|
307 |
+
"Requirement already satisfied: ipython-pygments-lexers in /home/user/miniconda/lib/python3.13/site-packages (from ipython->mediapy) (1.1.1)\n",
|
308 |
+
"Requirement already satisfied: jedi>=0.16 in /home/user/miniconda/lib/python3.13/site-packages (from ipython->mediapy) (0.19.2)\n",
|
309 |
+
"Requirement already satisfied: matplotlib-inline in /home/user/miniconda/lib/python3.13/site-packages (from ipython->mediapy) (0.1.7)\n",
|
310 |
+
"Requirement already satisfied: pexpect>4.3 in /home/user/miniconda/lib/python3.13/site-packages (from ipython->mediapy) (4.9.0)\n",
|
311 |
+
"Requirement already satisfied: prompt_toolkit<3.1.0,>=3.0.41 in /home/user/miniconda/lib/python3.13/site-packages (from ipython->mediapy) (3.0.51)\n",
|
312 |
+
"Requirement already satisfied: pygments>=2.4.0 in /home/user/miniconda/lib/python3.13/site-packages (from ipython->mediapy) (2.19.1)\n",
|
313 |
+
"Requirement already satisfied: stack_data in /home/user/miniconda/lib/python3.13/site-packages (from ipython->mediapy) (0.6.3)\n",
|
314 |
+
"Requirement already satisfied: traitlets>=5.13.0 in /home/user/miniconda/lib/python3.13/site-packages (from ipython->mediapy) (5.14.3)\n",
|
315 |
+
"Requirement already satisfied: wcwidth in /home/user/miniconda/lib/python3.13/site-packages (from prompt_toolkit<3.1.0,>=3.0.41->ipython->mediapy) (0.2.13)\n",
|
316 |
+
"Requirement already satisfied: parso<0.9.0,>=0.8.4 in /home/user/miniconda/lib/python3.13/site-packages (from jedi>=0.16->ipython->mediapy) (0.8.4)\n",
|
317 |
+
"Requirement already satisfied: ptyprocess>=0.5 in /home/user/miniconda/lib/python3.13/site-packages (from pexpect>4.3->ipython->mediapy) (0.7.0)\n",
|
318 |
+
"Collecting contourpy>=1.0.1 (from matplotlib->mediapy)\n",
|
319 |
+
" Downloading contourpy-1.3.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.5 kB)\n",
|
320 |
+
"Collecting cycler>=0.10 (from matplotlib->mediapy)\n",
|
321 |
+
" Downloading cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)\n",
|
322 |
+
"Collecting fonttools>=4.22.0 (from matplotlib->mediapy)\n",
|
323 |
+
" Downloading fonttools-4.58.4-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl.metadata (106 kB)\n",
|
324 |
+
"Collecting kiwisolver>=1.3.1 (from matplotlib->mediapy)\n",
|
325 |
+
" Downloading kiwisolver-1.4.8-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.2 kB)\n",
|
326 |
+
"Requirement already satisfied: packaging>=20.0 in /home/user/miniconda/lib/python3.13/site-packages (from matplotlib->mediapy) (24.2)\n",
|
327 |
+
"Collecting pyparsing>=2.3.1 (from matplotlib->mediapy)\n",
|
328 |
+
" Downloading pyparsing-3.2.3-py3-none-any.whl.metadata (5.0 kB)\n",
|
329 |
+
"Requirement already satisfied: python-dateutil>=2.7 in /home/user/miniconda/lib/python3.13/site-packages (from matplotlib->mediapy) (2.9.0.post0)\n",
|
330 |
+
"Requirement already satisfied: six>=1.5 in /home/user/miniconda/lib/python3.13/site-packages (from python-dateutil>=2.7->matplotlib->mediapy) (1.17.0)\n",
|
331 |
+
"Requirement already satisfied: executing>=1.2.0 in /home/user/miniconda/lib/python3.13/site-packages (from stack_data->ipython->mediapy) (2.2.0)\n",
|
332 |
+
"Requirement already satisfied: asttokens>=2.1.0 in /home/user/miniconda/lib/python3.13/site-packages (from stack_data->ipython->mediapy) (3.0.0)\n",
|
333 |
+
"Requirement already satisfied: pure-eval in /home/user/miniconda/lib/python3.13/site-packages (from stack_data->ipython->mediapy) (0.2.3)\n",
|
334 |
+
"Downloading mediapy-1.2.4-py3-none-any.whl (26 kB)\n",
|
335 |
+
"Downloading matplotlib-3.10.3-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.6 MB)\n",
|
336 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.6/8.6 MB\u001b[0m \u001b[31m49.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
337 |
+
"\u001b[?25hDownloading contourpy-1.3.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (322 kB)\n",
|
338 |
+
"Downloading cycler-0.12.1-py3-none-any.whl (8.3 kB)\n",
|
339 |
+
"Downloading fonttools-4.58.4-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl (4.9 MB)\n",
|
340 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.9/4.9 MB\u001b[0m \u001b[31m34.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
341 |
+
"\u001b[?25hDownloading kiwisolver-1.4.8-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.5 MB)\n",
|
342 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.5/1.5 MB\u001b[0m \u001b[31m13.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
343 |
+
"\u001b[?25hDownloading pyparsing-3.2.3-py3-none-any.whl (111 kB)\n",
|
344 |
+
"Installing collected packages: pyparsing, kiwisolver, fonttools, cycler, contourpy, matplotlib, mediapy\n",
|
345 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7/7\u001b[0m [mediapy]m6/7\u001b[0m [mediapy]ib]\n",
|
346 |
+
"\u001b[1A\u001b[2KSuccessfully installed contourpy-1.3.2 cycler-0.12.1 fonttools-4.58.4 kiwisolver-1.4.8 matplotlib-3.10.3 mediapy-1.2.4 pyparsing-3.2.3\n"
|
347 |
+
]
|
348 |
+
}
|
349 |
+
],
|
350 |
+
"source": [
|
351 |
+
"#@title Install pre-requisites\n",
|
352 |
+
"!pip install \"jax[cuda12]\"\n",
|
353 |
+
"!pip install mujoco\n",
|
354 |
+
"!pip install mujoco_mjx\n",
|
355 |
+
"!pip install brax\n",
|
356 |
+
"!pip install mediapy"
|
357 |
+
]
|
358 |
+
},
|
359 |
+
{
|
360 |
+
"cell_type": "code",
|
361 |
+
"execution_count": 2,
|
362 |
+
"metadata": {
|
363 |
+
"cellView": "form",
|
364 |
+
"id": "IbZxYDxzoz5R"
|
365 |
+
},
|
366 |
+
"outputs": [
|
367 |
+
{
|
368 |
+
"name": "stdout",
|
369 |
+
"output_type": "stream",
|
370 |
+
"text": [
|
371 |
+
"Tue Jul 1 12:12:35 2025 \n",
|
372 |
+
"+-----------------------------------------------------------------------------------------+\n",
|
373 |
+
"| NVIDIA-SMI 570.158.01 Driver Version: 570.158.01 CUDA Version: 12.8 |\n",
|
374 |
+
"|-----------------------------------------+------------------------+----------------------+\n",
|
375 |
+
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
|
376 |
+
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
|
377 |
+
"| | | MIG M. |\n",
|
378 |
+
"|=========================================+========================+======================|\n",
|
379 |
+
"| 0 NVIDIA L40S On | 00000000:30:00.0 Off | 0 |\n",
|
380 |
+
"| N/A 38C P8 36W / 350W | 0MiB / 46068MiB | 0% Default |\n",
|
381 |
+
"| | | N/A |\n",
|
382 |
+
"+-----------------------------------------+------------------------+----------------------+\n",
|
383 |
+
" \n",
|
384 |
+
"+-----------------------------------------------------------------------------------------+\n",
|
385 |
+
"| Processes: |\n",
|
386 |
+
"| GPU GI CI PID Type Process name GPU Memory |\n",
|
387 |
+
"| ID ID Usage |\n",
|
388 |
+
"|=========================================================================================|\n",
|
389 |
+
"| No running processes found |\n",
|
390 |
+
"+-----------------------------------------------------------------------------------------+\n",
|
391 |
+
"Setting environment variable to use GPU rendering:\n",
|
392 |
+
"env: MUJOCO_GL=egl\n",
|
393 |
+
"Checking that the installation succeeded:\n",
|
394 |
+
"Installation successful.\n"
|
395 |
+
]
|
396 |
+
}
|
397 |
+
],
|
398 |
+
"source": [
|
399 |
+
"# @title Check if MuJoCo installation was successful\n",
|
400 |
+
"\n",
|
401 |
+
"import distutils.util\n",
|
402 |
+
"import os\n",
|
403 |
+
"import subprocess\n",
|
404 |
+
"\n",
|
405 |
+
"if subprocess.run('nvidia-smi').returncode:\n",
|
406 |
+
" raise RuntimeError(\n",
|
407 |
+
" 'Cannot communicate with GPU. '\n",
|
408 |
+
" 'Make sure you are using a GPU Colab runtime. '\n",
|
409 |
+
" 'Go to the Runtime menu and select Choose runtime type.'\n",
|
410 |
+
" )\n",
|
411 |
+
"\n",
|
412 |
+
"# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.\n",
|
413 |
+
"# This is usually installed as part of an Nvidia driver package, but the Colab\n",
|
414 |
+
"# kernel doesn't install its driver via APT, and as a result the ICD is missing.\n",
|
415 |
+
"# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)\n",
|
416 |
+
"NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'\n",
|
417 |
+
"#if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):\n",
|
418 |
+
"# with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:\n",
|
419 |
+
"# f.write(\"\"\"{\n",
|
420 |
+
"# \"file_format_version\" : \"1.0.0\",\n",
|
421 |
+
"# \"ICD\" : {\n",
|
422 |
+
"# \"library_path\" : \"libEGL_nvidia.so.0\"\n",
|
423 |
+
"# }\n",
|
424 |
+
"#}\n",
|
425 |
+
"#\"\"\")\n",
|
426 |
+
"\n",
|
427 |
+
"# Configure MuJoCo to use the EGL rendering backend (requires GPU)\n",
|
428 |
+
"print('Setting environment variable to use GPU rendering:')\n",
|
429 |
+
"%env MUJOCO_GL=egl\n",
|
430 |
+
"\n",
|
431 |
+
"try:\n",
|
432 |
+
" print('Checking that the installation succeeded:')\n",
|
433 |
+
" import mujoco\n",
|
434 |
+
"\n",
|
435 |
+
" mujoco.MjModel.from_xml_string('<mujoco/>')\n",
|
436 |
+
"except Exception as e:\n",
|
437 |
+
" raise e from RuntimeError(\n",
|
438 |
+
" 'Something went wrong during installation. Check the shell output above '\n",
|
439 |
+
" 'for more information.\\n'\n",
|
440 |
+
" 'If using a hosted Colab runtime, make sure you enable GPU acceleration '\n",
|
441 |
+
" 'by going to the Runtime menu and selecting \"Choose runtime type\".'\n",
|
442 |
+
" )\n",
|
443 |
+
"\n",
|
444 |
+
"print('Installation successful.')\n",
|
445 |
+
"\n",
|
446 |
+
"# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs\n",
|
447 |
+
"xla_flags = os.environ.get('XLA_FLAGS', '')\n",
|
448 |
+
"xla_flags += ' --xla_gpu_triton_gemm_any=True'\n",
|
449 |
+
"os.environ['XLA_FLAGS'] = xla_flags"
|
450 |
+
]
|
451 |
+
},
|
452 |
+
{
|
453 |
+
"cell_type": "code",
|
454 |
+
"execution_count": 3,
|
455 |
+
"metadata": {
|
456 |
+
"cellView": "form",
|
457 |
+
"id": "T5f4w3Kq2X14"
|
458 |
+
},
|
459 |
+
"outputs": [],
|
460 |
+
"source": [
|
461 |
+
"# @title Import packages for plotting and creating graphics\n",
|
462 |
+
"import json\n",
|
463 |
+
"import itertools\n",
|
464 |
+
"import time\n",
|
465 |
+
"from typing import Callable, List, NamedTuple, Optional, Union\n",
|
466 |
+
"import numpy as np\n",
|
467 |
+
"\n",
|
468 |
+
"# Graphics and plotting.\n",
|
469 |
+
"#print(\"Installing mediapy:\")\n",
|
470 |
+
"#!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)\n",
|
471 |
+
"#!pip install -q mediapy\n",
|
472 |
+
"import mediapy as media\n",
|
473 |
+
"import matplotlib.pyplot as plt\n",
|
474 |
+
"\n",
|
475 |
+
"# More legible printing from numpy.\n",
|
476 |
+
"np.set_printoptions(precision=3, suppress=True, linewidth=100)"
|
477 |
+
]
|
478 |
+
},
|
479 |
+
{
|
480 |
+
"cell_type": "code",
|
481 |
+
"execution_count": 4,
|
482 |
+
"metadata": {
|
483 |
+
"cellView": "form",
|
484 |
+
"id": "ObF1UXrkb0Nd"
|
485 |
+
},
|
486 |
+
"outputs": [],
|
487 |
+
"source": [
|
488 |
+
"# @title Import MuJoCo, MJX, and Brax\n",
|
489 |
+
"from datetime import datetime\n",
|
490 |
+
"import functools\n",
|
491 |
+
"import os\n",
|
492 |
+
"from typing import Any, Dict, Sequence, Tuple, Union\n",
|
493 |
+
"from brax import base\n",
|
494 |
+
"from brax import envs\n",
|
495 |
+
"from brax import math\n",
|
496 |
+
"from brax.base import Base, Motion, Transform\n",
|
497 |
+
"from brax.base import State as PipelineState\n",
|
498 |
+
"from brax.envs.base import Env, PipelineEnv, State\n",
|
499 |
+
"from brax.io import html, mjcf, model\n",
|
500 |
+
"from brax.mjx.base import State as MjxState\n",
|
501 |
+
"from brax.training.agents.ppo import networks as ppo_networks\n",
|
502 |
+
"from brax.training.agents.ppo import train as ppo\n",
|
503 |
+
"from brax.training.agents.sac import networks as sac_networks\n",
|
504 |
+
"from brax.training.agents.sac import train as sac\n",
|
505 |
+
"from etils import epath\n",
|
506 |
+
"from flax import struct\n",
|
507 |
+
"from flax.training import orbax_utils\n",
|
508 |
+
"from IPython.display import HTML, clear_output\n",
|
509 |
+
"import jax\n",
|
510 |
+
"from jax import numpy as jp\n",
|
511 |
+
"from matplotlib import pyplot as plt\n",
|
512 |
+
"import mediapy as media\n",
|
513 |
+
"from ml_collections import config_dict\n",
|
514 |
+
"import mujoco\n",
|
515 |
+
"from mujoco import mjx\n",
|
516 |
+
"import numpy as np\n",
|
517 |
+
"from orbax import checkpoint as ocp"
|
518 |
+
]
|
519 |
+
},
|
520 |
+
{
|
521 |
+
"cell_type": "code",
|
522 |
+
"execution_count": 6,
|
523 |
+
"metadata": {
|
524 |
+
"cellView": "form",
|
525 |
+
"id": "UoTLSx4cFRdy"
|
526 |
+
},
|
527 |
+
"outputs": [
|
528 |
+
{
|
529 |
+
"name": "stdout",
|
530 |
+
"output_type": "stream",
|
531 |
+
"text": [
|
532 |
+
"Collecting playground\n",
|
533 |
+
" Downloading playground-0.0.5-py3-none-any.whl.metadata (8.7 kB)\n",
|
534 |
+
"Requirement already satisfied: brax>=0.12.1 in /home/user/miniconda/lib/python3.13/site-packages (from playground) (0.12.4)\n",
|
535 |
+
"Requirement already satisfied: etils in /home/user/miniconda/lib/python3.13/site-packages (from playground) (1.12.2)\n",
|
536 |
+
"Requirement already satisfied: flax in /home/user/miniconda/lib/python3.13/site-packages (from playground) (0.10.6)\n",
|
537 |
+
"Requirement already satisfied: jax in /home/user/miniconda/lib/python3.13/site-packages (from playground) (0.6.2)\n",
|
538 |
+
"Collecting lxml (from playground)\n",
|
539 |
+
" Downloading lxml-6.0.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (6.6 kB)\n",
|
540 |
+
"Requirement already satisfied: ml-collections in /home/user/miniconda/lib/python3.13/site-packages (from playground) (1.1.0)\n",
|
541 |
+
"Requirement already satisfied: mujoco-mjx>=3.2.7 in /home/user/miniconda/lib/python3.13/site-packages (from playground) (3.3.3)\n",
|
542 |
+
"Requirement already satisfied: mujoco>=3.2.7 in /home/user/miniconda/lib/python3.13/site-packages (from playground) (3.3.3)\n",
|
543 |
+
"Requirement already satisfied: tqdm in /home/user/miniconda/lib/python3.13/site-packages (from playground) (4.67.1)\n",
|
544 |
+
"Requirement already satisfied: absl-py in /home/user/miniconda/lib/python3.13/site-packages (from brax>=0.12.1->playground) (2.3.0)\n",
|
545 |
+
"Requirement already satisfied: flask in /home/user/miniconda/lib/python3.13/site-packages (from brax>=0.12.1->playground) (3.1.1)\n",
|
546 |
+
"Requirement already satisfied: flask-cors in /home/user/miniconda/lib/python3.13/site-packages (from brax>=0.12.1->playground) (6.0.1)\n",
|
547 |
+
"Requirement already satisfied: jaxlib>=0.4.6 in /home/user/miniconda/lib/python3.13/site-packages (from brax>=0.12.1->playground) (0.6.2)\n",
|
548 |
+
"Requirement already satisfied: jaxopt in /home/user/miniconda/lib/python3.13/site-packages (from brax>=0.12.1->playground) (0.8.5)\n",
|
549 |
+
"Requirement already satisfied: jinja2 in /home/user/miniconda/lib/python3.13/site-packages (from brax>=0.12.1->playground) (3.1.6)\n",
|
550 |
+
"Requirement already satisfied: numpy in /home/user/miniconda/lib/python3.13/site-packages (from brax>=0.12.1->playground) (2.3.1)\n",
|
551 |
+
"Requirement already satisfied: optax in /home/user/miniconda/lib/python3.13/site-packages (from brax>=0.12.1->playground) (0.2.5)\n",
|
552 |
+
"Requirement already satisfied: orbax-checkpoint in /home/user/miniconda/lib/python3.13/site-packages (from brax>=0.12.1->playground) (0.11.17)\n",
|
553 |
+
"Requirement already satisfied: pillow in /home/user/miniconda/lib/python3.13/site-packages (from brax>=0.12.1->playground) (11.3.0)\n",
|
554 |
+
"Requirement already satisfied: scipy in /home/user/miniconda/lib/python3.13/site-packages (from brax>=0.12.1->playground) (1.16.0)\n",
|
555 |
+
"Requirement already satisfied: tensorboardx in /home/user/miniconda/lib/python3.13/site-packages (from brax>=0.12.1->playground) (2.6.4)\n",
|
556 |
+
"Requirement already satisfied: trimesh in /home/user/miniconda/lib/python3.13/site-packages (from brax>=0.12.1->playground) (4.6.13)\n",
|
557 |
+
"Requirement already satisfied: typing-extensions in /home/user/miniconda/lib/python3.13/site-packages (from brax>=0.12.1->playground) (4.12.2)\n",
|
558 |
+
"Requirement already satisfied: ml_dtypes>=0.5.0 in /home/user/miniconda/lib/python3.13/site-packages (from jax->playground) (0.5.1)\n",
|
559 |
+
"Requirement already satisfied: opt_einsum in /home/user/miniconda/lib/python3.13/site-packages (from jax->playground) (3.4.0)\n",
|
560 |
+
"Requirement already satisfied: glfw in /home/user/miniconda/lib/python3.13/site-packages (from mujoco>=3.2.7->playground) (2.9.0)\n",
|
561 |
+
"Requirement already satisfied: pyopengl in /home/user/miniconda/lib/python3.13/site-packages (from mujoco>=3.2.7->playground) (3.1.9)\n",
|
562 |
+
"Requirement already satisfied: fsspec in /home/user/miniconda/lib/python3.13/site-packages (from etils[epath]->mujoco>=3.2.7->playground) (2025.5.1)\n",
|
563 |
+
"Requirement already satisfied: importlib_resources in /home/user/miniconda/lib/python3.13/site-packages (from etils[epath]->mujoco>=3.2.7->playground) (6.5.2)\n",
|
564 |
+
"Requirement already satisfied: zipp in /home/user/miniconda/lib/python3.13/site-packages (from etils[epath]->mujoco>=3.2.7->playground) (3.23.0)\n",
|
565 |
+
"Requirement already satisfied: blinker>=1.9.0 in /home/user/miniconda/lib/python3.13/site-packages (from flask->brax>=0.12.1->playground) (1.9.0)\n",
|
566 |
+
"Requirement already satisfied: click>=8.1.3 in /home/user/miniconda/lib/python3.13/site-packages (from flask->brax>=0.12.1->playground) (8.2.1)\n",
|
567 |
+
"Requirement already satisfied: itsdangerous>=2.2.0 in /home/user/miniconda/lib/python3.13/site-packages (from flask->brax>=0.12.1->playground) (2.2.0)\n",
|
568 |
+
"Requirement already satisfied: markupsafe>=2.1.1 in /home/user/miniconda/lib/python3.13/site-packages (from flask->brax>=0.12.1->playground) (3.0.2)\n",
|
569 |
+
"Requirement already satisfied: werkzeug>=3.1.0 in /home/user/miniconda/lib/python3.13/site-packages (from flask->brax>=0.12.1->playground) (3.1.3)\n",
|
570 |
+
"Requirement already satisfied: msgpack in /home/user/miniconda/lib/python3.13/site-packages (from flax->playground) (1.1.1)\n",
|
571 |
+
"Requirement already satisfied: tensorstore in /home/user/miniconda/lib/python3.13/site-packages (from flax->playground) (0.1.75)\n",
|
572 |
+
"Requirement already satisfied: rich>=11.1 in /home/user/miniconda/lib/python3.13/site-packages (from flax->playground) (13.9.4)\n",
|
573 |
+
"Requirement already satisfied: PyYAML>=5.4.1 in /home/user/miniconda/lib/python3.13/site-packages (from flax->playground) (6.0.2)\n",
|
574 |
+
"Requirement already satisfied: treescope>=0.1.7 in /home/user/miniconda/lib/python3.13/site-packages (from flax->playground) (0.1.9)\n",
|
575 |
+
"Requirement already satisfied: markdown-it-py>=2.2.0 in /home/user/miniconda/lib/python3.13/site-packages (from rich>=11.1->flax->playground) (2.2.0)\n",
|
576 |
+
"Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /home/user/miniconda/lib/python3.13/site-packages (from rich>=11.1->flax->playground) (2.19.1)\n",
|
577 |
+
"Requirement already satisfied: mdurl~=0.1 in /home/user/miniconda/lib/python3.13/site-packages (from markdown-it-py>=2.2.0->rich>=11.1->flax->playground) (0.1.0)\n",
|
578 |
+
"Requirement already satisfied: chex>=0.1.87 in /home/user/miniconda/lib/python3.13/site-packages (from optax->brax>=0.12.1->playground) (0.1.89)\n",
|
579 |
+
"Requirement already satisfied: setuptools in /home/user/miniconda/lib/python3.13/site-packages (from chex>=0.1.87->optax->brax>=0.12.1->playground) (78.1.1)\n",
|
580 |
+
"Requirement already satisfied: toolz>=0.9.0 in /home/user/miniconda/lib/python3.13/site-packages (from chex>=0.1.87->optax->brax>=0.12.1->playground) (1.0.0)\n",
|
581 |
+
"Requirement already satisfied: nest_asyncio in /home/user/miniconda/lib/python3.13/site-packages (from orbax-checkpoint->brax>=0.12.1->playground) (1.6.0)\n",
|
582 |
+
"Requirement already satisfied: protobuf in /home/user/miniconda/lib/python3.13/site-packages (from orbax-checkpoint->brax>=0.12.1->playground) (6.31.1)\n",
|
583 |
+
"Requirement already satisfied: humanize in /home/user/miniconda/lib/python3.13/site-packages (from orbax-checkpoint->brax>=0.12.1->playground) (4.12.3)\n",
|
584 |
+
"Requirement already satisfied: simplejson>=3.16.0 in /home/user/miniconda/lib/python3.13/site-packages (from orbax-checkpoint->brax>=0.12.1->playground) (3.20.1)\n",
|
585 |
+
"Requirement already satisfied: packaging in /home/user/miniconda/lib/python3.13/site-packages (from tensorboardx->brax>=0.12.1->playground) (24.2)\n",
|
586 |
+
"Downloading playground-0.0.5-py3-none-any.whl (7.4 MB)\n",
|
587 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m88.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
588 |
+
"\u001b[?25hDownloading lxml-6.0.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (5.2 MB)\n",
|
589 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.2/5.2 MB\u001b[0m \u001b[31m163.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
|
590 |
+
"\u001b[?25hInstalling collected packages: lxml, playground\n",
|
591 |
+
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2/2\u001b[0m [playground]2\u001b[0m [playground]\n",
|
592 |
+
"\u001b[1A\u001b[2KSuccessfully installed lxml-6.0.0 playground-0.0.5\n"
|
593 |
+
]
|
594 |
+
}
|
595 |
+
],
|
596 |
+
"source": [
|
597 |
+
"#@title Install MuJoCo Playground\n",
|
598 |
+
"!pip install playground"
|
599 |
+
]
|
600 |
+
},
|
601 |
+
{
|
602 |
+
"cell_type": "code",
|
603 |
+
"execution_count": 7,
|
604 |
+
"metadata": {
|
605 |
+
"cellView": "form",
|
606 |
+
"id": "gYm2h7m8w3Nv"
|
607 |
+
},
|
608 |
+
"outputs": [
|
609 |
+
{
|
610 |
+
"name": "stdout",
|
611 |
+
"output_type": "stream",
|
612 |
+
"text": [
|
613 |
+
"mujoco_menagerie not found. Downloading...\n"
|
614 |
+
]
|
615 |
+
},
|
616 |
+
{
|
617 |
+
"name": "stderr",
|
618 |
+
"output_type": "stream",
|
619 |
+
"text": [
|
620 |
+
"Cloning mujoco_menagerie: ██████████| 100/100 [00:13<00:00]\n"
|
621 |
+
]
|
622 |
+
},
|
623 |
+
{
|
624 |
+
"name": "stdout",
|
625 |
+
"output_type": "stream",
|
626 |
+
"text": [
|
627 |
+
"Checking out commit 14ceccf557cc47240202f2354d684eca58ff8de4\n",
|
628 |
+
"Successfully downloaded mujoco_menagerie\n"
|
629 |
+
]
|
630 |
+
}
|
631 |
+
],
|
632 |
+
"source": [
|
633 |
+
"#@title Import The Playground\n",
|
634 |
+
"\n",
|
635 |
+
"from mujoco_playground import wrapper\n",
|
636 |
+
"from mujoco_playground import registry"
|
637 |
+
]
|
638 |
+
},
|
639 |
+
{
|
640 |
+
"cell_type": "markdown",
|
641 |
+
"metadata": {
|
642 |
+
"id": "LcibXbyKt4FI"
|
643 |
+
},
|
644 |
+
"source": [
|
645 |
+
"# Locomotion\n",
|
646 |
+
"\n",
|
647 |
+
"MuJoCo Playground contains a host of quadrupedal and bipedal environments (all listed below after running the command)."
|
648 |
+
]
|
649 |
+
},
|
650 |
+
{
|
651 |
+
"cell_type": "code",
|
652 |
+
"execution_count": 8,
|
653 |
+
"metadata": {
|
654 |
+
"id": "ox0Gze9Ct5AM"
|
655 |
+
},
|
656 |
+
"outputs": [
|
657 |
+
{
|
658 |
+
"data": {
|
659 |
+
"text/plain": [
|
660 |
+
"('ApolloJoystickFlatTerrain',\n",
|
661 |
+
" 'BarkourJoystick',\n",
|
662 |
+
" 'BerkeleyHumanoidJoystickFlatTerrain',\n",
|
663 |
+
" 'BerkeleyHumanoidJoystickRoughTerrain',\n",
|
664 |
+
" 'G1JoystickFlatTerrain',\n",
|
665 |
+
" 'G1JoystickRoughTerrain',\n",
|
666 |
+
" 'Go1JoystickFlatTerrain',\n",
|
667 |
+
" 'Go1JoystickRoughTerrain',\n",
|
668 |
+
" 'Go1Getup',\n",
|
669 |
+
" 'Go1Handstand',\n",
|
670 |
+
" 'Go1Footstand',\n",
|
671 |
+
" 'H1InplaceGaitTracking',\n",
|
672 |
+
" 'H1JoystickGaitTracking',\n",
|
673 |
+
" 'Op3Joystick',\n",
|
674 |
+
" 'SpotFlatTerrainJoystick',\n",
|
675 |
+
" 'SpotGetup',\n",
|
676 |
+
" 'SpotJoystickGaitTracking',\n",
|
677 |
+
" 'T1JoystickFlatTerrain',\n",
|
678 |
+
" 'T1JoystickRoughTerrain')"
|
679 |
+
]
|
680 |
+
},
|
681 |
+
"execution_count": 8,
|
682 |
+
"metadata": {},
|
683 |
+
"output_type": "execute_result"
|
684 |
+
}
|
685 |
+
],
|
686 |
+
"source": [
|
687 |
+
"registry.locomotion.ALL_ENVS"
|
688 |
+
]
|
689 |
+
},
|
690 |
+
{
|
691 |
+
"cell_type": "markdown",
|
692 |
+
"metadata": {
|
693 |
+
"id": "_R01tjWfI-i6"
|
694 |
+
},
|
695 |
+
"source": [
|
696 |
+
"# Quadrupedal\n",
|
697 |
+
"\n",
|
698 |
+
"Let's jump right into quadrupedal locomotion! While we have environments available for the Google Barkour and Boston Dynamics Spot robots, the Unitree Go1 environment contains the most trainable policies that were transferred onto the real robot. We'll go right ahead and show a few policies using the Unitree Go1!\n",
|
699 |
+
"\n",
|
700 |
+
"First, let's train a joystick policy, which tracks linear and yaw velocity commands."
|
701 |
+
]
|
702 |
+
},
|
703 |
+
{
|
704 |
+
"cell_type": "code",
|
705 |
+
"execution_count": 9,
|
706 |
+
"metadata": {
|
707 |
+
"id": "kPJeoQeEJBSA"
|
708 |
+
},
|
709 |
+
"outputs": [
|
710 |
+
{
|
711 |
+
"name": "stderr",
|
712 |
+
"output_type": "stream",
|
713 |
+
"text": [
|
714 |
+
"WARNING:2025-07-01 11:54:03,909:jax._src.xla_bridge:794: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n",
|
715 |
+
"WARNING:jax._src.xla_bridge:An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
|
716 |
+
]
|
717 |
+
}
|
718 |
+
],
|
719 |
+
"source": [
|
720 |
+
"env_name = 'Go1JoystickFlatTerrain'\n",
|
721 |
+
"env = registry.load(env_name)\n",
|
722 |
+
"env_cfg = registry.get_default_config(env_name)"
|
723 |
+
]
|
724 |
+
},
|
725 |
+
{
|
726 |
+
"cell_type": "code",
|
727 |
+
"execution_count": null,
|
728 |
+
"metadata": {
|
729 |
+
"id": "6n9UT9N1wR5K"
|
730 |
+
},
|
731 |
+
"outputs": [],
|
732 |
+
"source": [
|
733 |
+
"env_cfg"
|
734 |
+
]
|
735 |
+
},
|
736 |
+
{
|
737 |
+
"cell_type": "markdown",
|
738 |
+
"metadata": {
|
739 |
+
"id": "Thm7nZueM4cz"
|
740 |
+
},
|
741 |
+
"source": [
|
742 |
+
"## Joystick\n",
|
743 |
+
"\n",
|
744 |
+
"Let's train the joystick policy and visualize rollouts:"
|
745 |
+
]
|
746 |
+
},
|
747 |
+
{
|
748 |
+
"cell_type": "code",
|
749 |
+
"execution_count": null,
|
750 |
+
"metadata": {
|
751 |
+
"id": "B9T_UVZYLDdM"
|
752 |
+
},
|
753 |
+
"outputs": [],
|
754 |
+
"source": [
|
755 |
+
"from mujoco_playground.config import locomotion_params\n",
|
756 |
+
"ppo_params = locomotion_params.brax_ppo_config(env_name)\n",
|
757 |
+
"ppo_params"
|
758 |
+
]
|
759 |
+
},
|
760 |
+
{
|
761 |
+
"cell_type": "markdown",
|
762 |
+
"metadata": {
|
763 |
+
"id": "Aefr2OS01D9g"
|
764 |
+
},
|
765 |
+
"source": [
|
766 |
+
"Domain randomization was used to make the policy robust to sim-to-real transfer. Certain environments in the Playground have domain randomization functions implemented. They're available in the registry and can be passed directly to brax RL algorithms. The [domain randomization](https://github.com/google-deepmind/mujoco_playground/blob/main/mujoco_playground/_src/locomotion/go1/randomize.py) function randomizes over friction, armature, center of mass of the torso, and link masses, amongst other simulation parameters."
|
767 |
+
]
|
768 |
+
},
|
769 |
+
{
|
770 |
+
"cell_type": "code",
|
771 |
+
"execution_count": null,
|
772 |
+
"metadata": {
|
773 |
+
"id": "UVA4Bn681DZT"
|
774 |
+
},
|
775 |
+
"outputs": [],
|
776 |
+
"source": [
|
777 |
+
"registry.get_domain_randomizer(env_name)"
|
778 |
+
]
|
779 |
+
},
|
780 |
+
{
|
781 |
+
"cell_type": "markdown",
|
782 |
+
"metadata": {
|
783 |
+
"id": "vBEEQyY6M5OC"
|
784 |
+
},
|
785 |
+
"source": [
|
786 |
+
"### Train\n",
|
787 |
+
"\n",
|
788 |
+
"The policy takes 7 minutes to train on an RTX 4090."
|
789 |
+
]
|
790 |
+
},
|
791 |
+
{
|
792 |
+
"cell_type": "code",
|
793 |
+
"execution_count": null,
|
794 |
+
"metadata": {
|
795 |
+
"id": "XKFzyP7wM5OD"
|
796 |
+
},
|
797 |
+
"outputs": [],
|
798 |
+
"source": [
|
799 |
+
"x_data, y_data, y_dataerr = [], [], []\n",
|
800 |
+
"times = [datetime.now()]\n",
|
801 |
+
"\n",
|
802 |
+
"\n",
|
803 |
+
"def progress(num_steps, metrics):\n",
|
804 |
+
" clear_output(wait=True)\n",
|
805 |
+
"\n",
|
806 |
+
" times.append(datetime.now())\n",
|
807 |
+
" x_data.append(num_steps)\n",
|
808 |
+
" y_data.append(metrics[\"eval/episode_reward\"])\n",
|
809 |
+
" y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n",
|
810 |
+
"\n",
|
811 |
+
" plt.xlim([0, ppo_params[\"num_timesteps\"] * 1.25])\n",
|
812 |
+
" plt.xlabel(\"# environment steps\")\n",
|
813 |
+
" plt.ylabel(\"reward per episode\")\n",
|
814 |
+
" plt.title(f\"y={y_data[-1]:.3f}\")\n",
|
815 |
+
" plt.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n",
|
816 |
+
"\n",
|
817 |
+
" display(plt.gcf())\n",
|
818 |
+
"\n",
|
819 |
+
"randomizer = registry.get_domain_randomizer(env_name)\n",
|
820 |
+
"ppo_training_params = dict(ppo_params)\n",
|
821 |
+
"network_factory = ppo_networks.make_ppo_networks\n",
|
822 |
+
"if \"network_factory\" in ppo_params:\n",
|
823 |
+
" del ppo_training_params[\"network_factory\"]\n",
|
824 |
+
" network_factory = functools.partial(\n",
|
825 |
+
" ppo_networks.make_ppo_networks,\n",
|
826 |
+
" **ppo_params.network_factory\n",
|
827 |
+
" )\n",
|
828 |
+
"\n",
|
829 |
+
"train_fn = functools.partial(\n",
|
830 |
+
" ppo.train, **dict(ppo_training_params),\n",
|
831 |
+
" network_factory=network_factory,\n",
|
832 |
+
" randomization_fn=randomizer,\n",
|
833 |
+
" progress_fn=progress\n",
|
834 |
+
")"
|
835 |
+
]
|
836 |
+
},
|
837 |
+
{
|
838 |
+
"cell_type": "code",
|
839 |
+
"execution_count": null,
|
840 |
+
"metadata": {
|
841 |
+
"id": "FGrlulWbM5OD"
|
842 |
+
},
|
843 |
+
"outputs": [],
|
844 |
+
"source": [
|
845 |
+
"make_inference_fn, params, metrics = train_fn(\n",
|
846 |
+
" environment=env,\n",
|
847 |
+
" eval_env=registry.load(env_name, config=env_cfg),\n",
|
848 |
+
" wrap_env_fn=wrapper.wrap_for_brax_training,\n",
|
849 |
+
")\n",
|
850 |
+
"print(f\"time to jit: {times[1] - times[0]}\")\n",
|
851 |
+
"print(f\"time to train: {times[-1] - times[1]}\")"
|
852 |
+
]
|
853 |
+
},
|
854 |
+
{
|
855 |
+
"cell_type": "markdown",
|
856 |
+
"metadata": {
|
857 |
+
"id": "AUxSNhq3UqmC"
|
858 |
+
},
|
859 |
+
"source": [
|
860 |
+
"Let's rollout and render the resulting policy!"
|
861 |
+
]
|
862 |
+
},
|
863 |
+
{
|
864 |
+
"cell_type": "code",
|
865 |
+
"execution_count": null,
|
866 |
+
"metadata": {
|
867 |
+
"id": "RBM89g5A2Yoi"
|
868 |
+
},
|
869 |
+
"outputs": [],
|
870 |
+
"source": [
|
871 |
+
"# Enable perturbation in the eval env.\n",
|
872 |
+
"env_cfg = registry.get_default_config(env_name)\n",
|
873 |
+
"env_cfg.pert_config.enable = True\n",
|
874 |
+
"env_cfg.pert_config.velocity_kick = [3.0, 6.0]\n",
|
875 |
+
"env_cfg.pert_config.kick_wait_times = [5.0, 15.0]\n",
|
876 |
+
"env_cfg.command_config.a = [1.5, 0.8, 2*jp.pi]\n",
|
877 |
+
"eval_env = registry.load(env_name, config=env_cfg)\n",
|
878 |
+
"velocity_kick_range = [0.0, 0.0] # Disable velocity kick.\n",
|
879 |
+
"kick_duration_range = [0.05, 0.2]\n",
|
880 |
+
"\n",
|
881 |
+
"jit_reset = jax.jit(eval_env.reset)\n",
|
882 |
+
"jit_step = jax.jit(eval_env.step)\n",
|
883 |
+
"jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))"
|
884 |
+
]
|
885 |
+
},
|
886 |
+
{
|
887 |
+
"cell_type": "code",
|
888 |
+
"execution_count": null,
|
889 |
+
"metadata": {
|
890 |
+
"cellView": "form",
|
891 |
+
"id": "C_1CY9xDoUKw"
|
892 |
+
},
|
893 |
+
"outputs": [],
|
894 |
+
"source": [
|
895 |
+
"#@title Rollout and Render\n",
|
896 |
+
"from mujoco_playground._src.gait import draw_joystick_command\n",
|
897 |
+
"\n",
|
898 |
+
"x_vel = 0.0 #@param {type: \"number\"}\n",
|
899 |
+
"y_vel = 0.0 #@param {type: \"number\"}\n",
|
900 |
+
"yaw_vel = 3.14 #@param {type: \"number\"}\n",
|
901 |
+
"\n",
|
902 |
+
"\n",
|
903 |
+
"def sample_pert(rng):\n",
|
904 |
+
" rng, key1, key2 = jax.random.split(rng, 3)\n",
|
905 |
+
" pert_mag = jax.random.uniform(\n",
|
906 |
+
" key1, minval=velocity_kick_range[0], maxval=velocity_kick_range[1]\n",
|
907 |
+
" )\n",
|
908 |
+
" duration_seconds = jax.random.uniform(\n",
|
909 |
+
" key2, minval=kick_duration_range[0], maxval=kick_duration_range[1]\n",
|
910 |
+
" )\n",
|
911 |
+
" duration_steps = jp.round(duration_seconds / eval_env.dt).astype(jp.int32)\n",
|
912 |
+
" state.info[\"pert_mag\"] = pert_mag\n",
|
913 |
+
" state.info[\"pert_duration\"] = duration_steps\n",
|
914 |
+
" state.info[\"pert_duration_seconds\"] = duration_seconds\n",
|
915 |
+
" return rng\n",
|
916 |
+
"\n",
|
917 |
+
"\n",
|
918 |
+
"rng = jax.random.PRNGKey(0)\n",
|
919 |
+
"rollout = []\n",
|
920 |
+
"modify_scene_fns = []\n",
|
921 |
+
"\n",
|
922 |
+
"swing_peak = []\n",
|
923 |
+
"rewards = []\n",
|
924 |
+
"linvel = []\n",
|
925 |
+
"angvel = []\n",
|
926 |
+
"track = []\n",
|
927 |
+
"foot_vel = []\n",
|
928 |
+
"rews = []\n",
|
929 |
+
"contact = []\n",
|
930 |
+
"command = jp.array([x_vel, y_vel, yaw_vel])\n",
|
931 |
+
"\n",
|
932 |
+
"state = jit_reset(rng)\n",
|
933 |
+
"if state.info[\"steps_since_last_pert\"] < state.info[\"steps_until_next_pert\"]:\n",
|
934 |
+
" rng = sample_pert(rng)\n",
|
935 |
+
"state.info[\"command\"] = command\n",
|
936 |
+
"for i in range(env_cfg.episode_length):\n",
|
937 |
+
" if state.info[\"steps_since_last_pert\"] < state.info[\"steps_until_next_pert\"]:\n",
|
938 |
+
" rng = sample_pert(rng)\n",
|
939 |
+
" act_rng, rng = jax.random.split(rng)\n",
|
940 |
+
" ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
|
941 |
+
" state = jit_step(state, ctrl)\n",
|
942 |
+
" state.info[\"command\"] = command\n",
|
943 |
+
" rews.append(\n",
|
944 |
+
" {k: v for k, v in state.metrics.items() if k.startswith(\"reward/\")}\n",
|
945 |
+
" )\n",
|
946 |
+
" rollout.append(state)\n",
|
947 |
+
" swing_peak.append(state.info[\"swing_peak\"])\n",
|
948 |
+
" rewards.append(\n",
|
949 |
+
" {k[7:]: v for k, v in state.metrics.items() if k.startswith(\"reward/\")}\n",
|
950 |
+
" )\n",
|
951 |
+
" linvel.append(env.get_global_linvel(state.data))\n",
|
952 |
+
" angvel.append(env.get_gyro(state.data))\n",
|
953 |
+
" track.append(\n",
|
954 |
+
" env._reward_tracking_lin_vel(\n",
|
955 |
+
" state.info[\"command\"], env.get_local_linvel(state.data)\n",
|
956 |
+
" )\n",
|
957 |
+
" )\n",
|
958 |
+
"\n",
|
959 |
+
" feet_vel = state.data.sensordata[env._foot_linvel_sensor_adr]\n",
|
960 |
+
" vel_xy = feet_vel[..., :2]\n",
|
961 |
+
" vel_norm = jp.sqrt(jp.linalg.norm(vel_xy, axis=-1))\n",
|
962 |
+
" foot_vel.append(vel_norm)\n",
|
963 |
+
"\n",
|
964 |
+
" contact.append(state.info[\"last_contact\"])\n",
|
965 |
+
"\n",
|
966 |
+
" xyz = np.array(state.data.xpos[env._torso_body_id])\n",
|
967 |
+
" xyz += np.array([0, 0, 0.2])\n",
|
968 |
+
" x_axis = state.data.xmat[env._torso_body_id, 0]\n",
|
969 |
+
" yaw = -np.arctan2(x_axis[1], x_axis[0])\n",
|
970 |
+
" modify_scene_fns.append(\n",
|
971 |
+
" functools.partial(\n",
|
972 |
+
" draw_joystick_command,\n",
|
973 |
+
" cmd=state.info[\"command\"],\n",
|
974 |
+
" xyz=xyz,\n",
|
975 |
+
" theta=yaw,\n",
|
976 |
+
" scl=abs(state.info[\"command\"][0])\n",
|
977 |
+
" / env_cfg.command_config.a[0],\n",
|
978 |
+
" )\n",
|
979 |
+
" )\n",
|
980 |
+
"\n",
|
981 |
+
"\n",
|
982 |
+
"render_every = 2\n",
|
983 |
+
"fps = 1.0 / eval_env.dt / render_every\n",
|
984 |
+
"traj = rollout[::render_every]\n",
|
985 |
+
"mod_fns = modify_scene_fns[::render_every]\n",
|
986 |
+
"\n",
|
987 |
+
"scene_option = mujoco.MjvOption()\n",
|
988 |
+
"scene_option.geomgroup[2] = True\n",
|
989 |
+
"scene_option.geomgroup[3] = False\n",
|
990 |
+
"scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True\n",
|
991 |
+
"scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False\n",
|
992 |
+
"scene_option.flags[mujoco.mjtVisFlag.mjVIS_PERTFORCE] = True\n",
|
993 |
+
"\n",
|
994 |
+
"frames = eval_env.render(\n",
|
995 |
+
" traj,\n",
|
996 |
+
" camera=\"track\",\n",
|
997 |
+
" scene_option=scene_option,\n",
|
998 |
+
" width=640,\n",
|
999 |
+
" height=480,\n",
|
1000 |
+
" modify_scene_fns=mod_fns,\n",
|
1001 |
+
")\n",
|
1002 |
+
"media.show_video(frames, fps=fps, loop=False)"
|
1003 |
+
]
|
1004 |
+
},
|
1005 |
+
{
|
1006 |
+
"cell_type": "markdown",
|
1007 |
+
"metadata": {
|
1008 |
+
"id": "1QHdoJ2r30En"
|
1009 |
+
},
|
1010 |
+
"source": [
|
1011 |
+
"Let's visualize the feet positions and the positional drift compared to the commanded linear and angular velocity. This is useful for debugging how well the policy follows the commands!"
|
1012 |
+
]
|
1013 |
+
},
|
1014 |
+
{
|
1015 |
+
"cell_type": "code",
|
1016 |
+
"execution_count": null,
|
1017 |
+
"metadata": {
|
1018 |
+
"cellView": "form",
|
1019 |
+
"id": "gyyynm3ozEet"
|
1020 |
+
},
|
1021 |
+
"outputs": [],
|
1022 |
+
"source": [
|
1023 |
+
"#@title Plot each foot in a 2x2 grid.\n",
|
1024 |
+
"\n",
|
1025 |
+
"swing_peak = jp.array(swing_peak)\n",
|
1026 |
+
"names = [\"FR\", \"FL\", \"RR\", \"RL\"]\n",
|
1027 |
+
"colors = [\"r\", \"g\", \"b\", \"y\"]\n",
|
1028 |
+
"fig, axs = plt.subplots(2, 2)\n",
|
1029 |
+
"for i, ax in enumerate(axs.flat):\n",
|
1030 |
+
" ax.plot(swing_peak[:, i], color=colors[i])\n",
|
1031 |
+
" ax.set_ylim([0, env_cfg.reward_config.max_foot_height * 1.25])\n",
|
1032 |
+
" ax.axhline(env_cfg.reward_config.max_foot_height, color=\"k\", linestyle=\"--\")\n",
|
1033 |
+
" ax.set_title(names[i])\n",
|
1034 |
+
" ax.set_xlabel(\"time\")\n",
|
1035 |
+
" ax.set_ylabel(\"height\")\n",
|
1036 |
+
"plt.tight_layout()\n",
|
1037 |
+
"plt.show()\n",
|
1038 |
+
"\n",
|
1039 |
+
"linvel_x = jp.array(linvel)[:, 0]\n",
|
1040 |
+
"linvel_y = jp.array(linvel)[:, 1]\n",
|
1041 |
+
"angvel_yaw = jp.array(angvel)[:, 2]\n",
|
1042 |
+
"\n",
|
1043 |
+
"# Plot whether velocity is within the command range.\n",
|
1044 |
+
"linvel_x = jp.convolve(linvel_x, jp.ones(10) / 10, mode=\"same\")\n",
|
1045 |
+
"linvel_y = jp.convolve(linvel_y, jp.ones(10) / 10, mode=\"same\")\n",
|
1046 |
+
"angvel_yaw = jp.convolve(angvel_yaw, jp.ones(10) / 10, mode=\"same\")\n",
|
1047 |
+
"\n",
|
1048 |
+
"fig, axes = plt.subplots(3, 1, figsize=(10, 10))\n",
|
1049 |
+
"axes[0].plot(linvel_x)\n",
|
1050 |
+
"axes[1].plot(linvel_y)\n",
|
1051 |
+
"axes[2].plot(angvel_yaw)\n",
|
1052 |
+
"\n",
|
1053 |
+
"axes[0].set_ylim(\n",
|
1054 |
+
" -env_cfg.command_config.a[0], env_cfg.command_config.a[0]\n",
|
1055 |
+
")\n",
|
1056 |
+
"axes[1].set_ylim(\n",
|
1057 |
+
" -env_cfg.command_config.a[1], env_cfg.command_config.a[1]\n",
|
1058 |
+
")\n",
|
1059 |
+
"axes[2].set_ylim(\n",
|
1060 |
+
" -env_cfg.command_config.a[2], env_cfg.command_config.a[2]\n",
|
1061 |
+
")\n",
|
1062 |
+
"\n",
|
1063 |
+
"for i, ax in enumerate(axes):\n",
|
1064 |
+
" ax.axhline(state.info[\"command\"][i], color=\"red\", linestyle=\"--\")\n",
|
1065 |
+
"\n",
|
1066 |
+
"labels = [\"dx\", \"dy\", \"dyaw\"]\n",
|
1067 |
+
"for i, ax in enumerate(axes):\n",
|
1068 |
+
" ax.set_ylabel(labels[i])"
|
1069 |
+
]
|
1070 |
+
},
|
1071 |
+
{
|
1072 |
+
"cell_type": "markdown",
|
1073 |
+
"metadata": {
|
1074 |
+
"id": "t1QAHuYBQBbl"
|
1075 |
+
},
|
1076 |
+
"source": [
|
1077 |
+
"Now let's visualize what it looks like to slowly increase linear velocity commands."
|
1078 |
+
]
|
1079 |
+
},
|
1080 |
+
{
|
1081 |
+
"cell_type": "code",
|
1082 |
+
"execution_count": null,
|
1083 |
+
"metadata": {
|
1084 |
+
"cellView": "form",
|
1085 |
+
"id": "Q0EuQiVlzh5u"
|
1086 |
+
},
|
1087 |
+
"outputs": [],
|
1088 |
+
"source": [
|
1089 |
+
"#@title Slowly increase linvel commands\n",
|
1090 |
+
"\n",
|
1091 |
+
"rng = jax.random.PRNGKey(0)\n",
|
1092 |
+
"rollout = []\n",
|
1093 |
+
"modify_scene_fns = []\n",
|
1094 |
+
"swing_peak = []\n",
|
1095 |
+
"linvel = []\n",
|
1096 |
+
"angvel = []\n",
|
1097 |
+
"\n",
|
1098 |
+
"x = -0.25\n",
|
1099 |
+
"command = jp.array([x, 0, 0])\n",
|
1100 |
+
"state = jit_reset(rng)\n",
|
1101 |
+
"for i in range(1_400):\n",
|
1102 |
+
" # Increase the forward velocity by 0.25 m/s every 200 steps.\n",
|
1103 |
+
" if i % 200 == 0:\n",
|
1104 |
+
" x += 0.25\n",
|
1105 |
+
" print(f\"Setting x to {x}\")\n",
|
1106 |
+
" command = jp.array([x, 0, 0])\n",
|
1107 |
+
" state.info[\"command\"] = command\n",
|
1108 |
+
" if state.info[\"steps_since_last_pert\"] < state.info[\"steps_until_next_pert\"]:\n",
|
1109 |
+
" rng = sample_pert(rng)\n",
|
1110 |
+
" act_rng, rng = jax.random.split(rng)\n",
|
1111 |
+
" ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
|
1112 |
+
" state = jit_step(state, ctrl)\n",
|
1113 |
+
" rollout.append(state)\n",
|
1114 |
+
" swing_peak.append(state.info[\"swing_peak\"])\n",
|
1115 |
+
" linvel.append(env.get_global_linvel(state.data))\n",
|
1116 |
+
" angvel.append(env.get_gyro(state.data))\n",
|
1117 |
+
" xyz = np.array(state.data.xpos[env._torso_body_id])\n",
|
1118 |
+
" xyz += np.array([0, 0, 0.2])\n",
|
1119 |
+
" x_axis = state.data.xmat[env._torso_body_id, 0]\n",
|
1120 |
+
" yaw = -np.arctan2(x_axis[1], x_axis[0])\n",
|
1121 |
+
" modify_scene_fns.append(\n",
|
1122 |
+
" functools.partial(\n",
|
1123 |
+
" draw_joystick_command,\n",
|
1124 |
+
" cmd=command,\n",
|
1125 |
+
" xyz=xyz,\n",
|
1126 |
+
" theta=yaw,\n",
|
1127 |
+
" scl=abs(command[0]) / env_cfg.command_config.a[0],\n",
|
1128 |
+
" )\n",
|
1129 |
+
" )\n",
|
1130 |
+
"\n",
|
1131 |
+
"\n",
|
1132 |
+
"# Plot each foot in a 2x2 grid.\n",
|
1133 |
+
"swing_peak = jp.array(swing_peak)\n",
|
1134 |
+
"names = [\"FR\", \"FL\", \"RR\", \"RL\"]\n",
|
1135 |
+
"colors = [\"r\", \"g\", \"b\", \"y\"]\n",
|
1136 |
+
"fig, axs = plt.subplots(2, 2)\n",
|
1137 |
+
"for i, ax in enumerate(axs.flat):\n",
|
1138 |
+
" ax.plot(swing_peak[:, i], color=colors[i])\n",
|
1139 |
+
" ax.set_ylim([0, env_cfg.reward_config.max_foot_height * 1.25])\n",
|
1140 |
+
" ax.axhline(env_cfg.reward_config.max_foot_height, color=\"k\", linestyle=\"--\")\n",
|
1141 |
+
" ax.set_title(names[i])\n",
|
1142 |
+
" ax.set_xlabel(\"time\")\n",
|
1143 |
+
" ax.set_ylabel(\"height\")\n",
|
1144 |
+
"plt.tight_layout()\n",
|
1145 |
+
"plt.show()\n",
|
1146 |
+
"\n",
|
1147 |
+
"linvel_x = jp.array(linvel)[:, 0]\n",
|
1148 |
+
"linvel_y = jp.array(linvel)[:, 1]\n",
|
1149 |
+
"angvel_yaw = jp.array(angvel)[:, 2]\n",
|
1150 |
+
"\n",
|
1151 |
+
"# Plot whether velocity is within the command range.\n",
|
1152 |
+
"linvel_x = jp.convolve(linvel_x, jp.ones(10) / 10, mode=\"same\")\n",
|
1153 |
+
"linvel_y = jp.convolve(linvel_y, jp.ones(10) / 10, mode=\"same\")\n",
|
1154 |
+
"angvel_yaw = jp.convolve(angvel_yaw, jp.ones(10) / 10, mode=\"same\")\n",
|
1155 |
+
"\n",
|
1156 |
+
"fig, axes = plt.subplots(3, 1, figsize=(10, 10))\n",
|
1157 |
+
"axes[0].plot(linvel_x)\n",
|
1158 |
+
"axes[1].plot(linvel_y)\n",
|
1159 |
+
"axes[2].plot(angvel_yaw)\n",
|
1160 |
+
"\n",
|
1161 |
+
"axes[0].set_ylim(\n",
|
1162 |
+
" -env_cfg.command_config.a[0], env_cfg.command_config.a[0]\n",
|
1163 |
+
")\n",
|
1164 |
+
"axes[1].set_ylim(\n",
|
1165 |
+
" -env_cfg.command_config.a[1], env_cfg.command_config.a[1]\n",
|
1166 |
+
")\n",
|
1167 |
+
"axes[2].set_ylim(\n",
|
1168 |
+
" -env_cfg.command_config.a[2], env_cfg.command_config.a[2]\n",
|
1169 |
+
")\n",
|
1170 |
+
"\n",
|
1171 |
+
"for i, ax in enumerate(axes):\n",
|
1172 |
+
" ax.axhline(state.info[\"command\"][i], color=\"red\", linestyle=\"--\")\n",
|
1173 |
+
"\n",
|
1174 |
+
"labels = [\"dx\", \"dy\", \"dyaw\"]\n",
|
1175 |
+
"for i, ax in enumerate(axes):\n",
|
1176 |
+
" ax.set_ylabel(labels[i])\n",
|
1177 |
+
"\n",
|
1178 |
+
"\n",
|
1179 |
+
"render_every = 2\n",
|
1180 |
+
"fps = 1.0 / eval_env.dt / render_every\n",
|
1181 |
+
"print(f\"fps: {fps}\")\n",
|
1182 |
+
"\n",
|
1183 |
+
"traj = rollout[::render_every]\n",
|
1184 |
+
"mod_fns = modify_scene_fns[::render_every]\n",
|
1185 |
+
"assert len(traj) == len(mod_fns)\n",
|
1186 |
+
"\n",
|
1187 |
+
"scene_option = mujoco.MjvOption()\n",
|
1188 |
+
"scene_option.geomgroup[2] = True\n",
|
1189 |
+
"scene_option.geomgroup[3] = False\n",
|
1190 |
+
"scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True\n",
|
1191 |
+
"scene_option.flags[mujoco.mjtVisFlag.mjVIS_PERTFORCE] = True\n",
|
1192 |
+
"\n",
|
1193 |
+
"frames = eval_env.render(\n",
|
1194 |
+
" traj,\n",
|
1195 |
+
" camera=\"track\",\n",
|
1196 |
+
" height=480,\n",
|
1197 |
+
" width=640,\n",
|
1198 |
+
" modify_scene_fns=mod_fns,\n",
|
1199 |
+
" scene_option=scene_option,\n",
|
1200 |
+
")\n",
|
1201 |
+
"media.show_video(frames, fps=fps, loop=False)"
|
1202 |
+
]
|
1203 |
+
},
|
1204 |
+
{
|
1205 |
+
"cell_type": "markdown",
|
1206 |
+
"metadata": {
|
1207 |
+
"id": "0RHZvXgmzrEJ"
|
1208 |
+
},
|
1209 |
+
"source": [
|
1210 |
+
"## Handstand\n",
|
1211 |
+
"\n",
|
1212 |
+
"Additional policies are available for the Unitree Go1 such as fall-recovery, handstand, and footstand policies. We'll use the handstand policy as an opportunity to demonstrate finetuning policies from prior checkpoints. This will allow us to quickly iterate on training curriculums by modifying the enviornment config between runs.\n",
|
1213 |
+
"\n",
|
1214 |
+
"For the Go1 handstand policy, we'll first train with the default configuration, and then add an energy penalty to make the policy smoother and more likely to transfer onto the robot."
|
1215 |
+
]
|
1216 |
+
},
|
1217 |
+
{
|
1218 |
+
"cell_type": "code",
|
1219 |
+
"execution_count": null,
|
1220 |
+
"metadata": {
|
1221 |
+
"id": "RYriZOAxzEk_"
|
1222 |
+
},
|
1223 |
+
"outputs": [],
|
1224 |
+
"source": [
|
1225 |
+
"from mujoco_playground.config import locomotion_params\n",
|
1226 |
+
"\n",
|
1227 |
+
"env_name = 'Go1Handstand'\n",
|
1228 |
+
"env = registry.load(env_name)\n",
|
1229 |
+
"env_cfg = registry.get_default_config(env_name)\n",
|
1230 |
+
"ppo_params = locomotion_params.brax_ppo_config(env_name)"
|
1231 |
+
]
|
1232 |
+
},
|
1233 |
+
{
|
1234 |
+
"cell_type": "markdown",
|
1235 |
+
"metadata": {
|
1236 |
+
"id": "3nB5ugbdS5kk"
|
1237 |
+
},
|
1238 |
+
"source": [
|
1239 |
+
"Let's create a checkpoint directory and then train a policy with checkpointing."
|
1240 |
+
]
|
1241 |
+
},
|
1242 |
+
{
|
1243 |
+
"cell_type": "code",
|
1244 |
+
"execution_count": null,
|
1245 |
+
"metadata": {
|
1246 |
+
"id": "EyEDpHisS7eO"
|
1247 |
+
},
|
1248 |
+
"outputs": [],
|
1249 |
+
"source": [
|
1250 |
+
"ckpt_path = epath.Path(\"checkpoints\").resolve() / env_name\n",
|
1251 |
+
"ckpt_path.mkdir(parents=True, exist_ok=True)\n",
|
1252 |
+
"print(f\"{ckpt_path}\")\n",
|
1253 |
+
"\n",
|
1254 |
+
"with open(ckpt_path / \"config.json\", \"w\") as fp:\n",
|
1255 |
+
" json.dump(env_cfg.to_dict(), fp, indent=4)"
|
1256 |
+
]
|
1257 |
+
},
|
1258 |
+
{
|
1259 |
+
"cell_type": "code",
|
1260 |
+
"execution_count": null,
|
1261 |
+
"metadata": {
|
1262 |
+
"id": "lCRUYofXSNGT"
|
1263 |
+
},
|
1264 |
+
"outputs": [],
|
1265 |
+
"source": [
|
1266 |
+
"#@title Training fn definition\n",
|
1267 |
+
"x_data, y_data, y_dataerr = [], [], []\n",
|
1268 |
+
"times = [datetime.now()]\n",
|
1269 |
+
"\n",
|
1270 |
+
"\n",
|
1271 |
+
"def policy_params_fn(current_step, make_policy, params):\n",
|
1272 |
+
" del make_policy # Unused.\n",
|
1273 |
+
" orbax_checkpointer = ocp.PyTreeCheckpointer()\n",
|
1274 |
+
" save_args = orbax_utils.save_args_from_target(params)\n",
|
1275 |
+
" path = ckpt_path / f\"{current_step}\"\n",
|
1276 |
+
" orbax_checkpointer.save(path, params, force=True, save_args=save_args)\n",
|
1277 |
+
"\n",
|
1278 |
+
"\n",
|
1279 |
+
"def progress(num_steps, metrics):\n",
|
1280 |
+
" clear_output(wait=True)\n",
|
1281 |
+
"\n",
|
1282 |
+
" times.append(datetime.now())\n",
|
1283 |
+
" x_data.append(num_steps)\n",
|
1284 |
+
" y_data.append(metrics[\"eval/episode_reward\"])\n",
|
1285 |
+
" y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n",
|
1286 |
+
"\n",
|
1287 |
+
" plt.xlim([0, ppo_params[\"num_timesteps\"] * 1.25])\n",
|
1288 |
+
" plt.xlabel(\"# environment steps\")\n",
|
1289 |
+
" plt.ylabel(\"reward per episode\")\n",
|
1290 |
+
" plt.title(f\"y={y_data[-1]:.3f}\")\n",
|
1291 |
+
" plt.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n",
|
1292 |
+
"\n",
|
1293 |
+
" display(plt.gcf())\n",
|
1294 |
+
"\n",
|
1295 |
+
"randomizer = registry.get_domain_randomizer(env_name)\n",
|
1296 |
+
"ppo_training_params = dict(ppo_params)\n",
|
1297 |
+
"network_factory = ppo_networks.make_ppo_networks\n",
|
1298 |
+
"if \"network_factory\" in ppo_params:\n",
|
1299 |
+
" del ppo_training_params[\"network_factory\"]\n",
|
1300 |
+
" network_factory = functools.partial(\n",
|
1301 |
+
" ppo_networks.make_ppo_networks,\n",
|
1302 |
+
" **ppo_params.network_factory\n",
|
1303 |
+
" )\n",
|
1304 |
+
"\n",
|
1305 |
+
"train_fn = functools.partial(\n",
|
1306 |
+
" ppo.train, **dict(ppo_training_params),\n",
|
1307 |
+
" network_factory=network_factory,\n",
|
1308 |
+
" randomization_fn=randomizer,\n",
|
1309 |
+
" progress_fn=progress,\n",
|
1310 |
+
" policy_params_fn=policy_params_fn,\n",
|
1311 |
+
")"
|
1312 |
+
]
|
1313 |
+
},
|
1314 |
+
{
|
1315 |
+
"cell_type": "markdown",
|
1316 |
+
"metadata": {
|
1317 |
+
"id": "A1oK80x1anPp"
|
1318 |
+
},
|
1319 |
+
"source": [
|
1320 |
+
"The initial policy takes 8 minutes to train on an RTX 4090."
|
1321 |
+
]
|
1322 |
+
},
|
1323 |
+
{
|
1324 |
+
"cell_type": "code",
|
1325 |
+
"execution_count": null,
|
1326 |
+
"metadata": {
|
1327 |
+
"id": "MY6P3abhSNGU"
|
1328 |
+
},
|
1329 |
+
"outputs": [],
|
1330 |
+
"source": [
|
1331 |
+
"make_inference_fn, params, metrics = train_fn(\n",
|
1332 |
+
" environment=registry.load(env_name, config=env_cfg),\n",
|
1333 |
+
" eval_env=registry.load(env_name, config=env_cfg),\n",
|
1334 |
+
" wrap_env_fn=wrapper.wrap_for_brax_training,\n",
|
1335 |
+
")\n",
|
1336 |
+
"print(f\"time to jit: {times[1] - times[0]}\")\n",
|
1337 |
+
"print(f\"time to train: {times[-1] - times[1]}\")"
|
1338 |
+
]
|
1339 |
+
},
|
1340 |
+
{
|
1341 |
+
"cell_type": "markdown",
|
1342 |
+
"metadata": {
|
1343 |
+
"id": "4s6PkZ4GWV4Z"
|
1344 |
+
},
|
1345 |
+
"source": [
|
1346 |
+
"Let's visualize the current policy."
|
1347 |
+
]
|
1348 |
+
},
|
1349 |
+
{
|
1350 |
+
"cell_type": "code",
|
1351 |
+
"execution_count": null,
|
1352 |
+
"metadata": {
|
1353 |
+
"cellView": "form",
|
1354 |
+
"id": "WiWOtc_6WbcX"
|
1355 |
+
},
|
1356 |
+
"outputs": [],
|
1357 |
+
"source": [
|
1358 |
+
"#@title Rollout and Render\n",
|
1359 |
+
"inference_fn = make_inference_fn(params, deterministic=True)\n",
|
1360 |
+
"jit_inference_fn = jax.jit(inference_fn)\n",
|
1361 |
+
"\n",
|
1362 |
+
"eval_env = registry.load(env_name, config=env_cfg)\n",
|
1363 |
+
"jit_reset = jax.jit(eval_env.reset)\n",
|
1364 |
+
"jit_step = jax.jit(eval_env.step)\n",
|
1365 |
+
"\n",
|
1366 |
+
"rng = jax.random.PRNGKey(12345)\n",
|
1367 |
+
"rollout = []\n",
|
1368 |
+
"rewards = []\n",
|
1369 |
+
"torso_height = []\n",
|
1370 |
+
"actions = []\n",
|
1371 |
+
"torques = []\n",
|
1372 |
+
"power = []\n",
|
1373 |
+
"qfrc_constraint = []\n",
|
1374 |
+
"qvels = []\n",
|
1375 |
+
"power1 = []\n",
|
1376 |
+
"power2 = []\n",
|
1377 |
+
"for _ in range(10):\n",
|
1378 |
+
" rng, reset_rng = jax.random.split(rng)\n",
|
1379 |
+
" state = jit_reset(reset_rng)\n",
|
1380 |
+
" for i in range(env_cfg.episode_length // 2):\n",
|
1381 |
+
" act_rng, rng = jax.random.split(rng)\n",
|
1382 |
+
" ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
|
1383 |
+
" actions.append(ctrl)\n",
|
1384 |
+
" state = jit_step(state, ctrl)\n",
|
1385 |
+
" rollout.append(state)\n",
|
1386 |
+
" rewards.append(\n",
|
1387 |
+
" {k[7:]: v for k, v in state.metrics.items() if k.startswith(\"reward/\")}\n",
|
1388 |
+
" )\n",
|
1389 |
+
" torso_height.append(state.data.qpos[2])\n",
|
1390 |
+
" torques.append(state.data.actuator_force)\n",
|
1391 |
+
" qvel = state.data.qvel[6:]\n",
|
1392 |
+
" power.append(jp.sum(jp.abs(qvel * state.data.actuator_force)))\n",
|
1393 |
+
" qfrc_constraint.append(jp.linalg.norm(state.data.qfrc_constraint[6:]))\n",
|
1394 |
+
" qvels.append(jp.max(jp.abs(qvel)))\n",
|
1395 |
+
" frc = state.data.actuator_force\n",
|
1396 |
+
" qvel = state.data.qvel[6:]\n",
|
1397 |
+
" power1.append(jp.sum(frc * qvel))\n",
|
1398 |
+
" power2.append(jp.sum(jp.abs(frc * qvel)))\n",
|
1399 |
+
"\n",
|
1400 |
+
"\n",
|
1401 |
+
"render_every = 2\n",
|
1402 |
+
"fps = 1.0 / eval_env.dt / render_every\n",
|
1403 |
+
"traj = rollout[::render_every]\n",
|
1404 |
+
"\n",
|
1405 |
+
"scene_option = mujoco.MjvOption()\n",
|
1406 |
+
"scene_option.geomgroup[2] = True\n",
|
1407 |
+
"scene_option.geomgroup[3] = False\n",
|
1408 |
+
"scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True\n",
|
1409 |
+
"scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = False\n",
|
1410 |
+
"scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False\n",
|
1411 |
+
"\n",
|
1412 |
+
"frames = eval_env.render(\n",
|
1413 |
+
" traj, camera=\"side\", scene_option=scene_option, height=480, width=640\n",
|
1414 |
+
")\n",
|
1415 |
+
"media.show_video(frames, fps=fps, loop=False)\n",
|
1416 |
+
"\n",
|
1417 |
+
"power = jp.array(power1)\n",
|
1418 |
+
"print(f\"Max power: {jp.max(power)}\")"
|
1419 |
+
]
|
1420 |
+
},
|
1421 |
+
{
|
1422 |
+
"cell_type": "markdown",
|
1423 |
+
"metadata": {
|
1424 |
+
"id": "v5p0Z3PPSRik"
|
1425 |
+
},
|
1426 |
+
"source": [
|
1427 |
+
"Notice that the above policy looks jittery and unlikely to transfer on the robot. The max power output is also quite high.\n",
|
1428 |
+
"\n",
|
1429 |
+
"The sim-to-real deployment of the handstand policy was trained using a curriculum on the `energy_termination_threshold`, `energy` and `dof_acc`, which are config values that penalize high torques and high power output. Let's finetune the above policy with a decreased `energy_termination_threshold`, as well as non-zero values for `energy` and `dof_acc` rewards to get a smoother policy."
|
1430 |
+
]
|
1431 |
+
},
|
1432 |
+
{
|
1433 |
+
"cell_type": "markdown",
|
1434 |
+
"metadata": {
|
1435 |
+
"id": "hrjoVL-_WN-r"
|
1436 |
+
},
|
1437 |
+
"source": [
|
1438 |
+
"### Finetune the previous checkpoint"
|
1439 |
+
]
|
1440 |
+
},
|
1441 |
+
{
|
1442 |
+
"cell_type": "code",
|
1443 |
+
"execution_count": null,
|
1444 |
+
"metadata": {
|
1445 |
+
"id": "jTxAySRSSu96"
|
1446 |
+
},
|
1447 |
+
"outputs": [],
|
1448 |
+
"source": [
|
1449 |
+
"env_cfg = registry.get_default_config(env_name)\n",
|
1450 |
+
"env_cfg.energy_termination_threshold = 400 # lower energy termination threshold\n",
|
1451 |
+
"env_cfg.reward_config.energy = -0.003 # non-zero negative `energy` reward\n",
|
1452 |
+
"env_cfg.reward_config.dof_acc = -2.5e-7 # non-zero negative `dof_acc` reward\n",
|
1453 |
+
"\n",
|
1454 |
+
"FINETUNE_PATH = epath.Path(ckpt_path)\n",
|
1455 |
+
"latest_ckpts = list(FINETUNE_PATH.glob(\"*\"))\n",
|
1456 |
+
"latest_ckpts = [ckpt for ckpt in latest_ckpts if ckpt.is_dir()]\n",
|
1457 |
+
"latest_ckpts.sort(key=lambda x: int(x.name))\n",
|
1458 |
+
"latest_ckpt = latest_ckpts[-1]\n",
|
1459 |
+
"restore_checkpoint_path = latest_ckpt"
|
1460 |
+
]
|
1461 |
+
},
|
1462 |
+
{
|
1463 |
+
"cell_type": "code",
|
1464 |
+
"execution_count": null,
|
1465 |
+
"metadata": {
|
1466 |
+
"id": "_M5IqOR6z4bV"
|
1467 |
+
},
|
1468 |
+
"outputs": [],
|
1469 |
+
"source": [
|
1470 |
+
"x_data, y_data, y_dataerr = [], [], []\n",
|
1471 |
+
"times = [datetime.now()]\n",
|
1472 |
+
"\n",
|
1473 |
+
"make_inference_fn, params, metrics = train_fn(\n",
|
1474 |
+
" environment=registry.load(env_name, config=env_cfg),\n",
|
1475 |
+
" eval_env=registry.load(env_name, config=env_cfg),\n",
|
1476 |
+
" wrap_env_fn=wrapper.wrap_for_brax_training,\n",
|
1477 |
+
" restore_checkpoint_path=restore_checkpoint_path, # restore from the checkpoint!\n",
|
1478 |
+
" seed=1,\n",
|
1479 |
+
")\n",
|
1480 |
+
"print(f\"time to jit: {times[1] - times[0]}\")\n",
|
1481 |
+
"print(f\"time to train: {times[-1] - times[1]}\")"
|
1482 |
+
]
|
1483 |
+
},
|
1484 |
+
{
|
1485 |
+
"cell_type": "code",
|
1486 |
+
"execution_count": null,
|
1487 |
+
"metadata": {
|
1488 |
+
"cellView": "form",
|
1489 |
+
"id": "tzG8eY2lz4dk"
|
1490 |
+
},
|
1491 |
+
"outputs": [],
|
1492 |
+
"source": [
|
1493 |
+
"#@title Rollout and Render Finetune Policy\n",
|
1494 |
+
"inference_fn = make_inference_fn(params, deterministic=True)\n",
|
1495 |
+
"jit_inference_fn = jax.jit(inference_fn)\n",
|
1496 |
+
"\n",
|
1497 |
+
"eval_env = registry.load(env_name, config=env_cfg)\n",
|
1498 |
+
"jit_reset = jax.jit(eval_env.reset)\n",
|
1499 |
+
"jit_step = jax.jit(eval_env.step)\n",
|
1500 |
+
"\n",
|
1501 |
+
"rng = jax.random.PRNGKey(12345)\n",
|
1502 |
+
"rollout = []\n",
|
1503 |
+
"rewards = []\n",
|
1504 |
+
"torso_height = []\n",
|
1505 |
+
"actions = []\n",
|
1506 |
+
"torques = []\n",
|
1507 |
+
"power = []\n",
|
1508 |
+
"qfrc_constraint = []\n",
|
1509 |
+
"qvels = []\n",
|
1510 |
+
"power1 = []\n",
|
1511 |
+
"power2 = []\n",
|
1512 |
+
"for _ in range(10):\n",
|
1513 |
+
" rng, reset_rng = jax.random.split(rng)\n",
|
1514 |
+
" state = jit_reset(reset_rng)\n",
|
1515 |
+
" for i in range(env_cfg.episode_length // 2):\n",
|
1516 |
+
" act_rng, rng = jax.random.split(rng)\n",
|
1517 |
+
" ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
|
1518 |
+
" actions.append(ctrl)\n",
|
1519 |
+
" state = jit_step(state, ctrl)\n",
|
1520 |
+
" rollout.append(state)\n",
|
1521 |
+
" rewards.append(\n",
|
1522 |
+
" {k[7:]: v for k, v in state.metrics.items() if k.startswith(\"reward/\")}\n",
|
1523 |
+
" )\n",
|
1524 |
+
" torso_height.append(state.data.qpos[2])\n",
|
1525 |
+
" torques.append(state.data.actuator_force)\n",
|
1526 |
+
" qvel = state.data.qvel[6:]\n",
|
1527 |
+
" power.append(jp.sum(jp.abs(qvel * state.data.actuator_force)))\n",
|
1528 |
+
" qfrc_constraint.append(jp.linalg.norm(state.data.qfrc_constraint[6:]))\n",
|
1529 |
+
" qvels.append(jp.max(jp.abs(qvel)))\n",
|
1530 |
+
" frc = state.data.actuator_force\n",
|
1531 |
+
" qvel = state.data.qvel[6:]\n",
|
1532 |
+
" power1.append(jp.sum(frc * qvel))\n",
|
1533 |
+
" power2.append(jp.sum(jp.abs(frc * qvel)))\n",
|
1534 |
+
"\n",
|
1535 |
+
"\n",
|
1536 |
+
"render_every = 2\n",
|
1537 |
+
"fps = 1.0 / eval_env.dt / render_every\n",
|
1538 |
+
"traj = rollout[::render_every]\n",
|
1539 |
+
"\n",
|
1540 |
+
"scene_option = mujoco.MjvOption()\n",
|
1541 |
+
"scene_option.geomgroup[2] = True\n",
|
1542 |
+
"scene_option.geomgroup[3] = False\n",
|
1543 |
+
"scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True\n",
|
1544 |
+
"scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = False\n",
|
1545 |
+
"scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False\n",
|
1546 |
+
"\n",
|
1547 |
+
"frames = eval_env.render(\n",
|
1548 |
+
" traj, camera=\"side\", scene_option=scene_option, height=480, width=640\n",
|
1549 |
+
")\n",
|
1550 |
+
"media.show_video(frames, fps=fps, loop=False)\n",
|
1551 |
+
"\n",
|
1552 |
+
"power = jp.array(power1)\n",
|
1553 |
+
"print(f\"Max power: {jp.max(power)}\")"
|
1554 |
+
]
|
1555 |
+
},
|
1556 |
+
{
|
1557 |
+
"cell_type": "markdown",
|
1558 |
+
"metadata": {
|
1559 |
+
"id": "yCyibqGMiAca"
|
1560 |
+
},
|
1561 |
+
"source": [
|
1562 |
+
"The final policy should exhibit smoother behavior and have less power output! Feel free to finetune the policy some more using different reward terms to get the best behavior."
|
1563 |
+
]
|
1564 |
+
},
|
1565 |
+
{
|
1566 |
+
"cell_type": "markdown",
|
1567 |
+
"metadata": {
|
1568 |
+
"id": "26o77FfWXvVp"
|
1569 |
+
},
|
1570 |
+
"source": [
|
1571 |
+
"# Bipedal\n",
|
1572 |
+
"\n",
|
1573 |
+
"MuJoCo Playground also comes with a host of bipedal environments, such as the Berkely Humanoid and the Unitree G1/H1. Let's demonstrate a joystick policy on the Berkeley Humanoid. The initial policy takes 17 minutes to train on an RTX 4090."
|
1574 |
+
]
|
1575 |
+
},
|
1576 |
+
{
|
1577 |
+
"cell_type": "code",
|
1578 |
+
"execution_count": null,
|
1579 |
+
"metadata": {
|
1580 |
+
"id": "ESNd18FUanPt"
|
1581 |
+
},
|
1582 |
+
"outputs": [],
|
1583 |
+
"source": [
|
1584 |
+
"env_name = 'BerkeleyHumanoidJoystickFlatTerrain'\n",
|
1585 |
+
"env = registry.load(env_name)\n",
|
1586 |
+
"env_cfg = registry.get_default_config(env_name)\n",
|
1587 |
+
"ppo_params = locomotion_params.brax_ppo_config(env_name)"
|
1588 |
+
]
|
1589 |
+
},
|
1590 |
+
{
|
1591 |
+
"cell_type": "code",
|
1592 |
+
"execution_count": null,
|
1593 |
+
"metadata": {
|
1594 |
+
"id": "nibLoRu8anPt"
|
1595 |
+
},
|
1596 |
+
"outputs": [],
|
1597 |
+
"source": [
|
1598 |
+
"x_data, y_data, y_dataerr = [], [], []\n",
|
1599 |
+
"times = [datetime.now()]\n",
|
1600 |
+
"\n",
|
1601 |
+
"randomizer = registry.get_domain_randomizer(env_name)\n",
|
1602 |
+
"ppo_training_params = dict(ppo_params)\n",
|
1603 |
+
"network_factory = ppo_networks.make_ppo_networks\n",
|
1604 |
+
"if \"network_factory\" in ppo_params:\n",
|
1605 |
+
" del ppo_training_params[\"network_factory\"]\n",
|
1606 |
+
" network_factory = functools.partial(\n",
|
1607 |
+
" ppo_networks.make_ppo_networks,\n",
|
1608 |
+
" **ppo_params.network_factory\n",
|
1609 |
+
" )\n",
|
1610 |
+
"\n",
|
1611 |
+
"train_fn = functools.partial(\n",
|
1612 |
+
" ppo.train, **dict(ppo_training_params),\n",
|
1613 |
+
" network_factory=network_factory,\n",
|
1614 |
+
" randomization_fn=randomizer,\n",
|
1615 |
+
" progress_fn=progress\n",
|
1616 |
+
")"
|
1617 |
+
]
|
1618 |
+
},
|
1619 |
+
{
|
1620 |
+
"cell_type": "code",
|
1621 |
+
"execution_count": null,
|
1622 |
+
"metadata": {
|
1623 |
+
"id": "16dqomv0anPt"
|
1624 |
+
},
|
1625 |
+
"outputs": [],
|
1626 |
+
"source": [
|
1627 |
+
"make_inference_fn, params, metrics = train_fn(\n",
|
1628 |
+
" environment=env,\n",
|
1629 |
+
" eval_env=registry.load(env_name, config=env_cfg),\n",
|
1630 |
+
" wrap_env_fn=wrapper.wrap_for_brax_training,\n",
|
1631 |
+
")\n",
|
1632 |
+
"print(f\"time to jit: {times[1] - times[0]}\")\n",
|
1633 |
+
"print(f\"time to train: {times[-1] - times[1]}\")"
|
1634 |
+
]
|
1635 |
+
},
|
1636 |
+
{
|
1637 |
+
"cell_type": "code",
|
1638 |
+
"execution_count": null,
|
1639 |
+
"metadata": {
|
1640 |
+
"cellView": "form",
|
1641 |
+
"id": "sBHDF-JFanPt"
|
1642 |
+
},
|
1643 |
+
"outputs": [],
|
1644 |
+
"source": [
|
1645 |
+
"#@title Rollout and Render\n",
|
1646 |
+
"from mujoco_playground._src.gait import draw_joystick_command\n",
|
1647 |
+
"\n",
|
1648 |
+
"env = registry.load(env_name)\n",
|
1649 |
+
"eval_env = registry.load(env_name)\n",
|
1650 |
+
"jit_reset = jax.jit(eval_env.reset)\n",
|
1651 |
+
"jit_step = jax.jit(eval_env.step)\n",
|
1652 |
+
"jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))\n",
|
1653 |
+
"\n",
|
1654 |
+
"rng = jax.random.PRNGKey(1)\n",
|
1655 |
+
"\n",
|
1656 |
+
"rollout = []\n",
|
1657 |
+
"modify_scene_fns = []\n",
|
1658 |
+
"\n",
|
1659 |
+
"x_vel = 1.0 #@param {type: \"number\"}\n",
|
1660 |
+
"y_vel = 0.0 #@param {type: \"number\"}\n",
|
1661 |
+
"yaw_vel = 0.0 #@param {type: \"number\"}\n",
|
1662 |
+
"command = jp.array([x_vel, y_vel, yaw_vel])\n",
|
1663 |
+
"\n",
|
1664 |
+
"phase_dt = 2 * jp.pi * eval_env.dt * 1.5\n",
|
1665 |
+
"phase = jp.array([0, jp.pi])\n",
|
1666 |
+
"\n",
|
1667 |
+
"for j in range(1):\n",
|
1668 |
+
" print(f\"episode {j}\")\n",
|
1669 |
+
" state = jit_reset(rng)\n",
|
1670 |
+
" state.info[\"phase_dt\"] = phase_dt\n",
|
1671 |
+
" state.info[\"phase\"] = phase\n",
|
1672 |
+
" for i in range(env_cfg.episode_length):\n",
|
1673 |
+
" act_rng, rng = jax.random.split(rng)\n",
|
1674 |
+
" ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
|
1675 |
+
" state = jit_step(state, ctrl)\n",
|
1676 |
+
" if state.done:\n",
|
1677 |
+
" break\n",
|
1678 |
+
" state.info[\"command\"] = command\n",
|
1679 |
+
" rollout.append(state)\n",
|
1680 |
+
"\n",
|
1681 |
+
" xyz = np.array(state.data.xpos[eval_env.mj_model.body(\"torso\").id])\n",
|
1682 |
+
" xyz += np.array([0, 0.0, 0])\n",
|
1683 |
+
" x_axis = state.data.xmat[eval_env._torso_body_id, 0]\n",
|
1684 |
+
" yaw = -np.arctan2(x_axis[1], x_axis[0])\n",
|
1685 |
+
" modify_scene_fns.append(\n",
|
1686 |
+
" functools.partial(\n",
|
1687 |
+
" draw_joystick_command,\n",
|
1688 |
+
" cmd=state.info[\"command\"],\n",
|
1689 |
+
" xyz=xyz,\n",
|
1690 |
+
" theta=yaw,\n",
|
1691 |
+
" scl=np.linalg.norm(state.info[\"command\"]),\n",
|
1692 |
+
" )\n",
|
1693 |
+
" )\n",
|
1694 |
+
"\n",
|
1695 |
+
"render_every = 1\n",
|
1696 |
+
"fps = 1.0 / eval_env.dt / render_every\n",
|
1697 |
+
"print(f\"fps: {fps}\")\n",
|
1698 |
+
"traj = rollout[::render_every]\n",
|
1699 |
+
"mod_fns = modify_scene_fns[::render_every]\n",
|
1700 |
+
"\n",
|
1701 |
+
"scene_option = mujoco.MjvOption()\n",
|
1702 |
+
"scene_option.geomgroup[2] = True\n",
|
1703 |
+
"scene_option.geomgroup[3] = False\n",
|
1704 |
+
"scene_option.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True\n",
|
1705 |
+
"scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = False\n",
|
1706 |
+
"scene_option.flags[mujoco.mjtVisFlag.mjVIS_PERTFORCE] = False\n",
|
1707 |
+
"\n",
|
1708 |
+
"frames = eval_env.render(\n",
|
1709 |
+
" traj,\n",
|
1710 |
+
" camera=\"track\",\n",
|
1711 |
+
" scene_option=scene_option,\n",
|
1712 |
+
" width=640*2,\n",
|
1713 |
+
" height=480,\n",
|
1714 |
+
" modify_scene_fns=mod_fns,\n",
|
1715 |
+
")\n",
|
1716 |
+
"media.show_video(frames, fps=fps, loop=False)"
|
1717 |
+
]
|
1718 |
+
},
|
1719 |
+
{
|
1720 |
+
"cell_type": "markdown",
|
1721 |
+
"metadata": {
|
1722 |
+
"id": "CBtrAqns35sI"
|
1723 |
+
},
|
1724 |
+
"source": [
|
1725 |
+
"🙌 Hasta la vista!"
|
1726 |
+
]
|
1727 |
+
}
|
1728 |
+
],
|
1729 |
+
"metadata": {
|
1730 |
+
"accelerator": "GPU",
|
1731 |
+
"colab": {
|
1732 |
+
"gpuType": "A100",
|
1733 |
+
"machine_shape": "hm",
|
1734 |
+
"private_outputs": true,
|
1735 |
+
"provenance": [],
|
1736 |
+
"toc_visible": true
|
1737 |
+
},
|
1738 |
+
"kernelspec": {
|
1739 |
+
"display_name": "Python 3 (ipykernel)",
|
1740 |
+
"language": "python",
|
1741 |
+
"name": "python3"
|
1742 |
+
},
|
1743 |
+
"language_info": {
|
1744 |
+
"codemirror_mode": {
|
1745 |
+
"name": "ipython",
|
1746 |
+
"version": 3
|
1747 |
+
},
|
1748 |
+
"file_extension": ".py",
|
1749 |
+
"mimetype": "text/x-python",
|
1750 |
+
"name": "python",
|
1751 |
+
"nbconvert_exporter": "python",
|
1752 |
+
"pygments_lexer": "ipython3",
|
1753 |
+
"version": "3.13.5"
|
1754 |
+
}
|
1755 |
+
},
|
1756 |
+
"nbformat": 4,
|
1757 |
+
"nbformat_minor": 4
|
1758 |
+
}
|
samples/manipulation.ipynb
ADDED
@@ -0,0 +1,650 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "MpkYHwCqk7W-"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"\n",
|
10 |
+
"\n",
|
11 |
+
"\n",
|
12 |
+
"\n",
|
13 |
+
"\n",
|
14 |
+
"\n"
|
15 |
+
]
|
16 |
+
},
|
17 |
+
{
|
18 |
+
"cell_type": "markdown",
|
19 |
+
"metadata": {
|
20 |
+
"id": "xBSdkbmGN2K-"
|
21 |
+
},
|
22 |
+
"source": [
|
23 |
+
"### Copyright notice"
|
24 |
+
]
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"cell_type": "markdown",
|
28 |
+
"metadata": {
|
29 |
+
"id": "_UbO9uhtBSX5"
|
30 |
+
},
|
31 |
+
"source": [
|
32 |
+
"> <p><small><small>Copyright 2025 DeepMind Technologies Limited.</small></p>\n",
|
33 |
+
"> <p><small><small>Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at <a href=\"http://www.apache.org/licenses/LICENSE-2.0\">http://www.apache.org/licenses/LICENSE-2.0</a>.</small></small></p>\n",
|
34 |
+
"> <p><small><small>Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.</small></small></p>"
|
35 |
+
]
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "markdown",
|
39 |
+
"metadata": {
|
40 |
+
"id": "dNIJkb_FM2Ux"
|
41 |
+
},
|
42 |
+
"source": [
|
43 |
+
"# Manipulation in The Playground! <a href=\"https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/manipulation.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" width=\"140\" align=\"center\"/></a>\n",
|
44 |
+
"\n",
|
45 |
+
"In this notebook, we'll walk through a couple manipulation environments available in MuJoCo Playground.\n",
|
46 |
+
"\n",
|
47 |
+
"**A Colab runtime with GPU acceleration is required.** If you're using a CPU-only runtime, you can switch using the menu \"Runtime > Change runtime type\".\n"
|
48 |
+
]
|
49 |
+
},
|
50 |
+
{
|
51 |
+
"cell_type": "code",
|
52 |
+
"execution_count": null,
|
53 |
+
"metadata": {
|
54 |
+
"id": "Xqo7pyX-n72M",
|
55 |
+
"cellView": "form"
|
56 |
+
},
|
57 |
+
"outputs": [],
|
58 |
+
"source": [
|
59 |
+
"#@title Install pre-requisites\n",
|
60 |
+
"!pip install mujoco\n",
|
61 |
+
"!pip install mujoco_mjx\n",
|
62 |
+
"!pip install brax"
|
63 |
+
]
|
64 |
+
},
|
65 |
+
{
|
66 |
+
"cell_type": "code",
|
67 |
+
"execution_count": null,
|
68 |
+
"metadata": {
|
69 |
+
"cellView": "form",
|
70 |
+
"id": "IbZxYDxzoz5R"
|
71 |
+
},
|
72 |
+
"outputs": [],
|
73 |
+
"source": [
|
74 |
+
"# @title Check if MuJoCo installation was successful\n",
|
75 |
+
"\n",
|
76 |
+
"import distutils.util\n",
|
77 |
+
"import os\n",
|
78 |
+
"import subprocess\n",
|
79 |
+
"\n",
|
80 |
+
"if subprocess.run('nvidia-smi').returncode:\n",
|
81 |
+
" raise RuntimeError(\n",
|
82 |
+
" 'Cannot communicate with GPU. '\n",
|
83 |
+
" 'Make sure you are using a GPU Colab runtime. '\n",
|
84 |
+
" 'Go to the Runtime menu and select Choose runtime type.'\n",
|
85 |
+
" )\n",
|
86 |
+
"\n",
|
87 |
+
"# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.\n",
|
88 |
+
"# This is usually installed as part of an Nvidia driver package, but the Colab\n",
|
89 |
+
"# kernel doesn't install its driver via APT, and as a result the ICD is missing.\n",
|
90 |
+
"# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)\n",
|
91 |
+
"NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'\n",
|
92 |
+
"if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):\n",
|
93 |
+
" with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:\n",
|
94 |
+
" f.write(\"\"\"{\n",
|
95 |
+
" \"file_format_version\" : \"1.0.0\",\n",
|
96 |
+
" \"ICD\" : {\n",
|
97 |
+
" \"library_path\" : \"libEGL_nvidia.so.0\"\n",
|
98 |
+
" }\n",
|
99 |
+
"}\n",
|
100 |
+
"\"\"\")\n",
|
101 |
+
"\n",
|
102 |
+
"# Configure MuJoCo to use the EGL rendering backend (requires GPU)\n",
|
103 |
+
"print('Setting environment variable to use GPU rendering:')\n",
|
104 |
+
"%env MUJOCO_GL=egl\n",
|
105 |
+
"\n",
|
106 |
+
"try:\n",
|
107 |
+
" print('Checking that the installation succeeded:')\n",
|
108 |
+
" import mujoco\n",
|
109 |
+
"\n",
|
110 |
+
" mujoco.MjModel.from_xml_string('<mujoco/>')\n",
|
111 |
+
"except Exception as e:\n",
|
112 |
+
" raise e from RuntimeError(\n",
|
113 |
+
" 'Something went wrong during installation. Check the shell output above '\n",
|
114 |
+
" 'for more information.\\n'\n",
|
115 |
+
" 'If using a hosted Colab runtime, make sure you enable GPU acceleration '\n",
|
116 |
+
" 'by going to the Runtime menu and selecting \"Choose runtime type\".'\n",
|
117 |
+
" )\n",
|
118 |
+
"\n",
|
119 |
+
"print('Installation successful.')\n",
|
120 |
+
"\n",
|
121 |
+
"# Tell XLA to use Triton GEMM, this improves steps/sec by ~30% on some GPUs\n",
|
122 |
+
"xla_flags = os.environ.get('XLA_FLAGS', '')\n",
|
123 |
+
"xla_flags += ' --xla_gpu_triton_gemm_any=True'\n",
|
124 |
+
"os.environ['XLA_FLAGS'] = xla_flags"
|
125 |
+
]
|
126 |
+
},
|
127 |
+
{
|
128 |
+
"cell_type": "code",
|
129 |
+
"execution_count": null,
|
130 |
+
"metadata": {
|
131 |
+
"id": "T5f4w3Kq2X14",
|
132 |
+
"cellView": "form"
|
133 |
+
},
|
134 |
+
"outputs": [],
|
135 |
+
"source": [
|
136 |
+
"# @title Import packages for plotting and creating graphics\n",
|
137 |
+
"import json\n",
|
138 |
+
"import itertools\n",
|
139 |
+
"import time\n",
|
140 |
+
"from typing import Callable, List, NamedTuple, Optional, Union\n",
|
141 |
+
"import numpy as np\n",
|
142 |
+
"\n",
|
143 |
+
"# Graphics and plotting.\n",
|
144 |
+
"print(\"Installing mediapy:\")\n",
|
145 |
+
"!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)\n",
|
146 |
+
"!pip install -q mediapy\n",
|
147 |
+
"import mediapy as media\n",
|
148 |
+
"import matplotlib.pyplot as plt\n",
|
149 |
+
"\n",
|
150 |
+
"# More legible printing from numpy.\n",
|
151 |
+
"np.set_printoptions(precision=3, suppress=True, linewidth=100)"
|
152 |
+
]
|
153 |
+
},
|
154 |
+
{
|
155 |
+
"cell_type": "code",
|
156 |
+
"execution_count": null,
|
157 |
+
"metadata": {
|
158 |
+
"cellView": "form",
|
159 |
+
"id": "ObF1UXrkb0Nd"
|
160 |
+
},
|
161 |
+
"outputs": [],
|
162 |
+
"source": [
|
163 |
+
"# @title Import MuJoCo, MJX, and Brax\n",
|
164 |
+
"from datetime import datetime\n",
|
165 |
+
"import functools\n",
|
166 |
+
"import os\n",
|
167 |
+
"from typing import Any, Dict, Sequence, Tuple, Union\n",
|
168 |
+
"from brax import base\n",
|
169 |
+
"from brax import envs\n",
|
170 |
+
"from brax import math\n",
|
171 |
+
"from brax.base import Base, Motion, Transform\n",
|
172 |
+
"from brax.base import State as PipelineState\n",
|
173 |
+
"from brax.envs.base import Env, PipelineEnv, State\n",
|
174 |
+
"from brax.io import html, mjcf, model\n",
|
175 |
+
"from brax.mjx.base import State as MjxState\n",
|
176 |
+
"from brax.training.agents.ppo import networks as ppo_networks\n",
|
177 |
+
"from brax.training.agents.ppo import train as ppo\n",
|
178 |
+
"from brax.training.agents.sac import networks as sac_networks\n",
|
179 |
+
"from brax.training.agents.sac import train as sac\n",
|
180 |
+
"from etils import epath\n",
|
181 |
+
"from flax import struct\n",
|
182 |
+
"from flax.training import orbax_utils\n",
|
183 |
+
"from IPython.display import HTML, clear_output\n",
|
184 |
+
"import jax\n",
|
185 |
+
"from jax import numpy as jp\n",
|
186 |
+
"from matplotlib import pyplot as plt\n",
|
187 |
+
"import mediapy as media\n",
|
188 |
+
"from ml_collections import config_dict\n",
|
189 |
+
"import mujoco\n",
|
190 |
+
"from mujoco import mjx\n",
|
191 |
+
"import numpy as np\n",
|
192 |
+
"from orbax import checkpoint as ocp"
|
193 |
+
]
|
194 |
+
},
|
195 |
+
{
|
196 |
+
"cell_type": "code",
|
197 |
+
"source": [
|
198 |
+
"#@title Install MuJoCo Playground\n",
|
199 |
+
"!pip install playground"
|
200 |
+
],
|
201 |
+
"metadata": {
|
202 |
+
"cellView": "form",
|
203 |
+
"id": "UoTLSx4cFRdy"
|
204 |
+
},
|
205 |
+
"execution_count": null,
|
206 |
+
"outputs": []
|
207 |
+
},
|
208 |
+
{
|
209 |
+
"cell_type": "code",
|
210 |
+
"source": [
|
211 |
+
"#@title Import The Playground\n",
|
212 |
+
"\n",
|
213 |
+
"from mujoco_playground import wrapper\n",
|
214 |
+
"from mujoco_playground import registry"
|
215 |
+
],
|
216 |
+
"metadata": {
|
217 |
+
"cellView": "form",
|
218 |
+
"id": "gYm2h7m8w3Nv"
|
219 |
+
},
|
220 |
+
"execution_count": null,
|
221 |
+
"outputs": []
|
222 |
+
},
|
223 |
+
{
|
224 |
+
"cell_type": "markdown",
|
225 |
+
"source": [
|
226 |
+
"# Manipulation\n",
|
227 |
+
"\n",
|
228 |
+
"MuJoCo Playground contains several manipulation environments (all listed below after running the command)."
|
229 |
+
],
|
230 |
+
"metadata": {
|
231 |
+
"id": "LcibXbyKt4FI"
|
232 |
+
}
|
233 |
+
},
|
234 |
+
{
|
235 |
+
"cell_type": "code",
|
236 |
+
"source": [
|
237 |
+
"registry.manipulation.ALL_ENVS"
|
238 |
+
],
|
239 |
+
"metadata": {
|
240 |
+
"id": "ox0Gze9Ct5AM"
|
241 |
+
},
|
242 |
+
"execution_count": null,
|
243 |
+
"outputs": []
|
244 |
+
},
|
245 |
+
{
|
246 |
+
"cell_type": "markdown",
|
247 |
+
"source": [
|
248 |
+
"# Franka Emika Panda\n",
|
249 |
+
"\n",
|
250 |
+
"Let's start off with the simplest environment, simply picking up a cube with the Franka Emika Panda."
|
251 |
+
],
|
252 |
+
"metadata": {
|
253 |
+
"id": "_R01tjWfI-i6"
|
254 |
+
}
|
255 |
+
},
|
256 |
+
{
|
257 |
+
"cell_type": "code",
|
258 |
+
"source": [
|
259 |
+
"env_name = 'PandaPickCubeOrientation'\n",
|
260 |
+
"env = registry.load(env_name)\n",
|
261 |
+
"env_cfg = registry.get_default_config(env_name)"
|
262 |
+
],
|
263 |
+
"metadata": {
|
264 |
+
"id": "kPJeoQeEJBSA"
|
265 |
+
},
|
266 |
+
"execution_count": null,
|
267 |
+
"outputs": []
|
268 |
+
},
|
269 |
+
{
|
270 |
+
"cell_type": "code",
|
271 |
+
"source": [
|
272 |
+
"env_cfg"
|
273 |
+
],
|
274 |
+
"metadata": {
|
275 |
+
"id": "6n9UT9N1wR5K"
|
276 |
+
},
|
277 |
+
"execution_count": null,
|
278 |
+
"outputs": []
|
279 |
+
},
|
280 |
+
{
|
281 |
+
"cell_type": "markdown",
|
282 |
+
"source": [
|
283 |
+
"## Train Policy\n",
|
284 |
+
"\n",
|
285 |
+
"Let's train the pick cube policy and visualize rollouts. The policy takes roughly 3 minutes to train on an RTX 4090."
|
286 |
+
],
|
287 |
+
"metadata": {
|
288 |
+
"id": "Thm7nZueM4cz"
|
289 |
+
}
|
290 |
+
},
|
291 |
+
{
|
292 |
+
"cell_type": "code",
|
293 |
+
"source": [
|
294 |
+
"from mujoco_playground.config import manipulation_params\n",
|
295 |
+
"ppo_params = manipulation_params.brax_ppo_config(env_name)\n",
|
296 |
+
"ppo_params"
|
297 |
+
],
|
298 |
+
"metadata": {
|
299 |
+
"id": "B9T_UVZYLDdM"
|
300 |
+
},
|
301 |
+
"execution_count": null,
|
302 |
+
"outputs": []
|
303 |
+
},
|
304 |
+
{
|
305 |
+
"cell_type": "markdown",
|
306 |
+
"metadata": {
|
307 |
+
"id": "vBEEQyY6M5OC"
|
308 |
+
},
|
309 |
+
"source": [
|
310 |
+
"### PPO"
|
311 |
+
]
|
312 |
+
},
|
313 |
+
{
|
314 |
+
"cell_type": "code",
|
315 |
+
"execution_count": null,
|
316 |
+
"metadata": {
|
317 |
+
"id": "XKFzyP7wM5OD"
|
318 |
+
},
|
319 |
+
"outputs": [],
|
320 |
+
"source": [
|
321 |
+
"x_data, y_data, y_dataerr = [], [], []\n",
|
322 |
+
"times = [datetime.now()]\n",
|
323 |
+
"\n",
|
324 |
+
"\n",
|
325 |
+
"def progress(num_steps, metrics):\n",
|
326 |
+
" clear_output(wait=True)\n",
|
327 |
+
"\n",
|
328 |
+
" times.append(datetime.now())\n",
|
329 |
+
" x_data.append(num_steps)\n",
|
330 |
+
" y_data.append(metrics[\"eval/episode_reward\"])\n",
|
331 |
+
" y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n",
|
332 |
+
"\n",
|
333 |
+
" plt.xlim([0, ppo_params[\"num_timesteps\"] * 1.25])\n",
|
334 |
+
" plt.xlabel(\"# environment steps\")\n",
|
335 |
+
" plt.ylabel(\"reward per episode\")\n",
|
336 |
+
" plt.title(f\"y={y_data[-1]:.3f}\")\n",
|
337 |
+
" plt.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n",
|
338 |
+
"\n",
|
339 |
+
" display(plt.gcf())\n",
|
340 |
+
"\n",
|
341 |
+
"ppo_training_params = dict(ppo_params)\n",
|
342 |
+
"network_factory = ppo_networks.make_ppo_networks\n",
|
343 |
+
"if \"network_factory\" in ppo_params:\n",
|
344 |
+
" del ppo_training_params[\"network_factory\"]\n",
|
345 |
+
" network_factory = functools.partial(\n",
|
346 |
+
" ppo_networks.make_ppo_networks,\n",
|
347 |
+
" **ppo_params.network_factory\n",
|
348 |
+
" )\n",
|
349 |
+
"\n",
|
350 |
+
"train_fn = functools.partial(\n",
|
351 |
+
" ppo.train, **dict(ppo_training_params),\n",
|
352 |
+
" network_factory=network_factory,\n",
|
353 |
+
" progress_fn=progress,\n",
|
354 |
+
" seed=1\n",
|
355 |
+
")"
|
356 |
+
]
|
357 |
+
},
|
358 |
+
{
|
359 |
+
"cell_type": "code",
|
360 |
+
"execution_count": null,
|
361 |
+
"metadata": {
|
362 |
+
"id": "FGrlulWbM5OD"
|
363 |
+
},
|
364 |
+
"outputs": [],
|
365 |
+
"source": [
|
366 |
+
"make_inference_fn, params, metrics = train_fn(\n",
|
367 |
+
" environment=env,\n",
|
368 |
+
" wrap_env_fn=wrapper.wrap_for_brax_training,\n",
|
369 |
+
")\n",
|
370 |
+
"print(f\"time to jit: {times[1] - times[0]}\")\n",
|
371 |
+
"print(f\"time to train: {times[-1] - times[1]}\")"
|
372 |
+
]
|
373 |
+
},
|
374 |
+
{
|
375 |
+
"cell_type": "markdown",
|
376 |
+
"source": [
|
377 |
+
"## Visualize Rollouts"
|
378 |
+
],
|
379 |
+
"metadata": {
|
380 |
+
"id": "mHVmccs-oMSo"
|
381 |
+
}
|
382 |
+
},
|
383 |
+
{
|
384 |
+
"cell_type": "code",
|
385 |
+
"execution_count": null,
|
386 |
+
"metadata": {
|
387 |
+
"id": "sG5a2FFXoUKw"
|
388 |
+
},
|
389 |
+
"outputs": [],
|
390 |
+
"source": [
|
391 |
+
"jit_reset = jax.jit(env.reset)\n",
|
392 |
+
"jit_step = jax.jit(env.step)\n",
|
393 |
+
"jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))"
|
394 |
+
]
|
395 |
+
},
|
396 |
+
{
|
397 |
+
"cell_type": "code",
|
398 |
+
"execution_count": null,
|
399 |
+
"metadata": {
|
400 |
+
"id": "C_1CY9xDoUKw"
|
401 |
+
},
|
402 |
+
"outputs": [],
|
403 |
+
"source": [
|
404 |
+
"rng = jax.random.PRNGKey(42)\n",
|
405 |
+
"rollout = []\n",
|
406 |
+
"n_episodes = 1\n",
|
407 |
+
"\n",
|
408 |
+
"for _ in range(n_episodes):\n",
|
409 |
+
" state = jit_reset(rng)\n",
|
410 |
+
" rollout.append(state)\n",
|
411 |
+
" for i in range(env_cfg.episode_length):\n",
|
412 |
+
" act_rng, rng = jax.random.split(rng)\n",
|
413 |
+
" ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
|
414 |
+
" state = jit_step(state, ctrl)\n",
|
415 |
+
" rollout.append(state)\n",
|
416 |
+
"\n",
|
417 |
+
"render_every = 1\n",
|
418 |
+
"frames = env.render(rollout[::render_every])\n",
|
419 |
+
"rewards = [s.reward for s in rollout]\n",
|
420 |
+
"media.show_video(frames, fps=1.0 / env.dt / render_every)"
|
421 |
+
]
|
422 |
+
},
|
423 |
+
{
|
424 |
+
"cell_type": "markdown",
|
425 |
+
"source": [
|
426 |
+
"While the above policy is very simple, the work was extended using the Madrona batch renderer, and policies were transferred on a real robot. We encourage folks to check out the Madrona-MJX tutorial notebooks ([part 1](https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/training_vision_1.ipynb) and [part 2](https://colab.research.google.com/github/google-deepmind/mujoco_playground/blob/main/learning/notebooks/training_vision_2.ipynb))!"
|
427 |
+
],
|
428 |
+
"metadata": {
|
429 |
+
"id": "v5r4FwivlnoG"
|
430 |
+
}
|
431 |
+
},
|
432 |
+
{
|
433 |
+
"cell_type": "markdown",
|
434 |
+
"source": [
|
435 |
+
"# Dexterous Manipulation\n",
|
436 |
+
"\n",
|
437 |
+
"Let's now train a policy that was transferred onto a real Leap Hand robot with the `LeapCubeReorient` environment! The environment contains a cube placed in the center of the hand, and the goal is to re-orient the cube in SO(3)."
|
438 |
+
],
|
439 |
+
"metadata": {
|
440 |
+
"id": "YVQsrEE3mMj8"
|
441 |
+
}
|
442 |
+
},
|
443 |
+
{
|
444 |
+
"cell_type": "code",
|
445 |
+
"source": [
|
446 |
+
"env_name = 'LeapCubeReorient'\n",
|
447 |
+
"env = registry.load(env_name)\n",
|
448 |
+
"env_cfg = registry.get_default_config(env_name)"
|
449 |
+
],
|
450 |
+
"metadata": {
|
451 |
+
"id": "oPaTdWqVmuPt"
|
452 |
+
},
|
453 |
+
"execution_count": null,
|
454 |
+
"outputs": []
|
455 |
+
},
|
456 |
+
{
|
457 |
+
"cell_type": "code",
|
458 |
+
"source": [
|
459 |
+
"env_cfg"
|
460 |
+
],
|
461 |
+
"metadata": {
|
462 |
+
"id": "c0OII08RmuPt"
|
463 |
+
},
|
464 |
+
"execution_count": null,
|
465 |
+
"outputs": []
|
466 |
+
},
|
467 |
+
{
|
468 |
+
"cell_type": "markdown",
|
469 |
+
"source": [
|
470 |
+
"## Train Policy\n",
|
471 |
+
"\n",
|
472 |
+
"Let's train an initial policy and visualize the rollouts. Notice that the PPO parameters contain `policy_obs_key` and `value_obs_key` fields, which allow us to train brax PPO with [asymmetric](https://arxiv.org/abs/1710.06542) observations for the actor and the critic. While the actor recieves proprioceptive state similar in nature to the real-world camera tracking sensors, the critic network recieves privileged state only available in the simulator. This enables more sample efficient learning, and we are able to train an initial policy in 33 minutes on a single RTX 4090.\n",
|
473 |
+
"\n",
|
474 |
+
"Depending on the GPU device and topology, training can be brought down to 10-20 minutes as shown in the MuJoCo Playground technical report."
|
475 |
+
],
|
476 |
+
"metadata": {
|
477 |
+
"id": "3g335ZYFmuPt"
|
478 |
+
}
|
479 |
+
},
|
480 |
+
{
|
481 |
+
"cell_type": "code",
|
482 |
+
"source": [
|
483 |
+
"from mujoco_playground.config import manipulation_params\n",
|
484 |
+
"ppo_params = manipulation_params.brax_ppo_config(env_name)\n",
|
485 |
+
"ppo_params"
|
486 |
+
],
|
487 |
+
"metadata": {
|
488 |
+
"id": "cc1Ka4hYmuPt"
|
489 |
+
},
|
490 |
+
"execution_count": null,
|
491 |
+
"outputs": []
|
492 |
+
},
|
493 |
+
{
|
494 |
+
"cell_type": "markdown",
|
495 |
+
"metadata": {
|
496 |
+
"id": "Ulr1ih6PmuPu"
|
497 |
+
},
|
498 |
+
"source": [
|
499 |
+
"### PPO"
|
500 |
+
]
|
501 |
+
},
|
502 |
+
{
|
503 |
+
"cell_type": "code",
|
504 |
+
"execution_count": null,
|
505 |
+
"metadata": {
|
506 |
+
"id": "gzwRjUGLmuPu"
|
507 |
+
},
|
508 |
+
"outputs": [],
|
509 |
+
"source": [
|
510 |
+
"x_data, y_data, y_dataerr = [], [], []\n",
|
511 |
+
"times = [datetime.now()]\n",
|
512 |
+
"\n",
|
513 |
+
"\n",
|
514 |
+
"def progress(num_steps, metrics):\n",
|
515 |
+
" clear_output(wait=True)\n",
|
516 |
+
"\n",
|
517 |
+
" times.append(datetime.now())\n",
|
518 |
+
" x_data.append(num_steps)\n",
|
519 |
+
" y_data.append(metrics[\"eval/episode_reward\"])\n",
|
520 |
+
" y_dataerr.append(metrics[\"eval/episode_reward_std\"])\n",
|
521 |
+
"\n",
|
522 |
+
" plt.xlim([0, ppo_params[\"num_timesteps\"] * 1.25])\n",
|
523 |
+
" plt.xlabel(\"# environment steps\")\n",
|
524 |
+
" plt.ylabel(\"reward per episode\")\n",
|
525 |
+
" plt.title(f\"y={y_data[-1]:.3f}\")\n",
|
526 |
+
" plt.errorbar(x_data, y_data, yerr=y_dataerr, color=\"blue\")\n",
|
527 |
+
"\n",
|
528 |
+
" display(plt.gcf())\n",
|
529 |
+
"\n",
|
530 |
+
"ppo_training_params = dict(ppo_params)\n",
|
531 |
+
"network_factory = ppo_networks.make_ppo_networks\n",
|
532 |
+
"if \"network_factory\" in ppo_params:\n",
|
533 |
+
" del ppo_training_params[\"network_factory\"]\n",
|
534 |
+
" network_factory = functools.partial(\n",
|
535 |
+
" ppo_networks.make_ppo_networks,\n",
|
536 |
+
" **ppo_params.network_factory\n",
|
537 |
+
" )\n",
|
538 |
+
"\n",
|
539 |
+
"train_fn = functools.partial(\n",
|
540 |
+
" ppo.train, **dict(ppo_training_params),\n",
|
541 |
+
" network_factory=network_factory,\n",
|
542 |
+
" progress_fn=progress,\n",
|
543 |
+
" seed=1\n",
|
544 |
+
")"
|
545 |
+
]
|
546 |
+
},
|
547 |
+
{
|
548 |
+
"cell_type": "code",
|
549 |
+
"execution_count": null,
|
550 |
+
"metadata": {
|
551 |
+
"id": "YmortADGmuPu"
|
552 |
+
},
|
553 |
+
"outputs": [],
|
554 |
+
"source": [
|
555 |
+
"make_inference_fn, params, metrics = train_fn(\n",
|
556 |
+
" environment=env,\n",
|
557 |
+
" wrap_env_fn=wrapper.wrap_for_brax_training,\n",
|
558 |
+
")\n",
|
559 |
+
"print(f\"time to jit: {times[1] - times[0]}\")\n",
|
560 |
+
"print(f\"time to train: {times[-1] - times[1]}\")"
|
561 |
+
]
|
562 |
+
},
|
563 |
+
{
|
564 |
+
"cell_type": "markdown",
|
565 |
+
"source": [
|
566 |
+
"## Visualize Rollouts"
|
567 |
+
],
|
568 |
+
"metadata": {
|
569 |
+
"id": "xIB_677emuPu"
|
570 |
+
}
|
571 |
+
},
|
572 |
+
{
|
573 |
+
"cell_type": "code",
|
574 |
+
"execution_count": null,
|
575 |
+
"metadata": {
|
576 |
+
"id": "xBgGvZpTmuPu"
|
577 |
+
},
|
578 |
+
"outputs": [],
|
579 |
+
"source": [
|
580 |
+
"jit_reset = jax.jit(env.reset)\n",
|
581 |
+
"jit_step = jax.jit(env.step)\n",
|
582 |
+
"jit_inference_fn = jax.jit(make_inference_fn(params, deterministic=True))"
|
583 |
+
]
|
584 |
+
},
|
585 |
+
{
|
586 |
+
"cell_type": "code",
|
587 |
+
"execution_count": null,
|
588 |
+
"metadata": {
|
589 |
+
"id": "Ksj6_9PwmuPu"
|
590 |
+
},
|
591 |
+
"outputs": [],
|
592 |
+
"source": [
|
593 |
+
"rng = jax.random.PRNGKey(42)\n",
|
594 |
+
"rollout = []\n",
|
595 |
+
"n_episodes = 1\n",
|
596 |
+
"\n",
|
597 |
+
"for _ in range(n_episodes):\n",
|
598 |
+
" state = jit_reset(rng)\n",
|
599 |
+
" rollout.append(state)\n",
|
600 |
+
" for i in range(env_cfg.episode_length):\n",
|
601 |
+
" act_rng, rng = jax.random.split(rng)\n",
|
602 |
+
" ctrl, _ = jit_inference_fn(state.obs, act_rng)\n",
|
603 |
+
" state = jit_step(state, ctrl)\n",
|
604 |
+
" rollout.append(state)\n",
|
605 |
+
"\n",
|
606 |
+
"render_every = 1\n",
|
607 |
+
"frames = env.render(rollout[::render_every])\n",
|
608 |
+
"rewards = [s.reward for s in rollout]\n",
|
609 |
+
"media.show_video(frames, fps=1.0 / env.dt / render_every)"
|
610 |
+
]
|
611 |
+
},
|
612 |
+
{
|
613 |
+
"cell_type": "markdown",
|
614 |
+
"source": [
|
615 |
+
"The above policy solves the task, but may look a little bit jittery. To get robust sim-to-real transfer, we retrained from previous checkpoints using a curriculum on the maximum torque to facilitate exploration early on in the curriculum, and to produce smoother actions for the final policy. More details can be found in the MuJoCo Playground technical report!"
|
616 |
+
],
|
617 |
+
"metadata": {
|
618 |
+
"id": "dWIVTxq5nhs5"
|
619 |
+
}
|
620 |
+
},
|
621 |
+
{
|
622 |
+
"cell_type": "markdown",
|
623 |
+
"metadata": {
|
624 |
+
"id": "CBtrAqns35sI"
|
625 |
+
},
|
626 |
+
"source": [
|
627 |
+
"🙌 Thanks for stopping by The Playground!"
|
628 |
+
]
|
629 |
+
}
|
630 |
+
],
|
631 |
+
"metadata": {
|
632 |
+
"colab": {
|
633 |
+
"private_outputs": true,
|
634 |
+
"toc_visible": true,
|
635 |
+
"provenance": [],
|
636 |
+
"machine_shape": "hm",
|
637 |
+
"gpuType": "A100"
|
638 |
+
},
|
639 |
+
"kernelspec": {
|
640 |
+
"display_name": "Python 3",
|
641 |
+
"name": "python3"
|
642 |
+
},
|
643 |
+
"language_info": {
|
644 |
+
"name": "python"
|
645 |
+
},
|
646 |
+
"accelerator": "GPU"
|
647 |
+
},
|
648 |
+
"nbformat": 4,
|
649 |
+
"nbformat_minor": 0
|
650 |
+
}
|
samples/tutorial.ipynb
ADDED
@@ -0,0 +1,2258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {
|
6 |
+
"id": "MpkYHwCqk7W-"
|
7 |
+
},
|
8 |
+
"source": [
|
9 |
+
"\n",
|
10 |
+
"\n",
|
11 |
+
"# <h1><center>Tutorial <a href=\"https://colab.research.google.com/github/google-deepmind/mujoco/blob/main/python/tutorial.ipynb\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" width=\"140\" align=\"center\"/></a></center></h1>\n",
|
12 |
+
"\n",
|
13 |
+
"This notebook provides an introductory tutorial for [**MuJoCo** physics](https://github.com/google-deepmind/mujoco#readme), using the native Python bindings.\n",
|
14 |
+
"\n",
|
15 |
+
"<!-- Copyright 2021 DeepMind Technologies Limited\n",
|
16 |
+
"\n",
|
17 |
+
" Licensed under the Apache License, Version 2.0 (the \"License\");\n",
|
18 |
+
" you may not use this file except in compliance with the License.\n",
|
19 |
+
" You may obtain a copy of the License at\n",
|
20 |
+
"\n",
|
21 |
+
" http://www.apache.org/licenses/LICENSE-2.0\n",
|
22 |
+
"\n",
|
23 |
+
" Unless required by applicable law or agreed to in writing, software\n",
|
24 |
+
" distributed under the License is distributed on an \"AS IS\" BASIS,\n",
|
25 |
+
" WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
|
26 |
+
" See the License for the specific language governing permissions and\n",
|
27 |
+
" limitations under the License.\n",
|
28 |
+
"-->"
|
29 |
+
]
|
30 |
+
},
|
31 |
+
{
|
32 |
+
"cell_type": "markdown",
|
33 |
+
"metadata": {
|
34 |
+
"id": "YvyGCsgSCxHQ"
|
35 |
+
},
|
36 |
+
"source": [
|
37 |
+
"# All imports"
|
38 |
+
]
|
39 |
+
},
|
40 |
+
{
|
41 |
+
"cell_type": "code",
|
42 |
+
"execution_count": null,
|
43 |
+
"metadata": {
|
44 |
+
"id": "Xqo7pyX-n72M"
|
45 |
+
},
|
46 |
+
"outputs": [],
|
47 |
+
"source": [
|
48 |
+
"!pip install mujoco\n",
|
49 |
+
"\n",
|
50 |
+
"# Set up GPU rendering.\n",
|
51 |
+
"from google.colab import files\n",
|
52 |
+
"import distutils.util\n",
|
53 |
+
"import os\n",
|
54 |
+
"import subprocess\n",
|
55 |
+
"if subprocess.run('nvidia-smi').returncode:\n",
|
56 |
+
" raise RuntimeError(\n",
|
57 |
+
" 'Cannot communicate with GPU. '\n",
|
58 |
+
" 'Make sure you are using a GPU Colab runtime. '\n",
|
59 |
+
" 'Go to the Runtime menu and select Choose runtime type.')\n",
|
60 |
+
"\n",
|
61 |
+
"# Add an ICD config so that glvnd can pick up the Nvidia EGL driver.\n",
|
62 |
+
"# This is usually installed as part of an Nvidia driver package, but the Colab\n",
|
63 |
+
"# kernel doesn't install its driver via APT, and as a result the ICD is missing.\n",
|
64 |
+
"# (https://github.com/NVIDIA/libglvnd/blob/master/src/EGL/icd_enumeration.md)\n",
|
65 |
+
"NVIDIA_ICD_CONFIG_PATH = '/usr/share/glvnd/egl_vendor.d/10_nvidia.json'\n",
|
66 |
+
"if not os.path.exists(NVIDIA_ICD_CONFIG_PATH):\n",
|
67 |
+
" with open(NVIDIA_ICD_CONFIG_PATH, 'w') as f:\n",
|
68 |
+
" f.write(\"\"\"{\n",
|
69 |
+
" \"file_format_version\" : \"1.0.0\",\n",
|
70 |
+
" \"ICD\" : {\n",
|
71 |
+
" \"library_path\" : \"libEGL_nvidia.so.0\"\n",
|
72 |
+
" }\n",
|
73 |
+
"}\n",
|
74 |
+
"\"\"\")\n",
|
75 |
+
"\n",
|
76 |
+
"# Configure MuJoCo to use the EGL rendering backend (requires GPU)\n",
|
77 |
+
"print('Setting environment variable to use GPU rendering:')\n",
|
78 |
+
"%env MUJOCO_GL=egl\n",
|
79 |
+
"\n",
|
80 |
+
"# Check if installation was succesful.\n",
|
81 |
+
"try:\n",
|
82 |
+
" print('Checking that the installation succeeded:')\n",
|
83 |
+
" import mujoco\n",
|
84 |
+
" mujoco.MjModel.from_xml_string('<mujoco/>')\n",
|
85 |
+
"except Exception as e:\n",
|
86 |
+
" raise e from RuntimeError(\n",
|
87 |
+
" 'Something went wrong during installation. Check the shell output above '\n",
|
88 |
+
" 'for more information.\\n'\n",
|
89 |
+
" 'If using a hosted Colab runtime, make sure you enable GPU acceleration '\n",
|
90 |
+
" 'by going to the Runtime menu and selecting \"Choose runtime type\".')\n",
|
91 |
+
"\n",
|
92 |
+
"print('Installation successful.')\n",
|
93 |
+
"\n",
|
94 |
+
"# Other imports and helper functions\n",
|
95 |
+
"import time\n",
|
96 |
+
"import itertools\n",
|
97 |
+
"import numpy as np\n",
|
98 |
+
"\n",
|
99 |
+
"# Graphics and plotting.\n",
|
100 |
+
"print('Installing mediapy:')\n",
|
101 |
+
"!command -v ffmpeg >/dev/null || (apt update && apt install -y ffmpeg)\n",
|
102 |
+
"!pip install -q mediapy\n",
|
103 |
+
"import mediapy as media\n",
|
104 |
+
"import matplotlib.pyplot as plt\n",
|
105 |
+
"\n",
|
106 |
+
"# More legible printing from numpy.\n",
|
107 |
+
"np.set_printoptions(precision=3, suppress=True, linewidth=100)\n",
|
108 |
+
"\n",
|
109 |
+
"from IPython.display import clear_output\n",
|
110 |
+
"clear_output()\n"
|
111 |
+
]
|
112 |
+
},
|
113 |
+
{
|
114 |
+
"cell_type": "markdown",
|
115 |
+
"metadata": {
|
116 |
+
"id": "t0CF6Gvkt_Cw"
|
117 |
+
},
|
118 |
+
"source": [
|
119 |
+
"# MuJoCo basics\n",
|
120 |
+
"\n",
|
121 |
+
"We begin by defining and loading a simple model:"
|
122 |
+
]
|
123 |
+
},
|
124 |
+
{
|
125 |
+
"cell_type": "code",
|
126 |
+
"execution_count": null,
|
127 |
+
"metadata": {
|
128 |
+
"id": "3KJVqak6xdJa"
|
129 |
+
},
|
130 |
+
"outputs": [],
|
131 |
+
"source": [
|
132 |
+
"xml = \"\"\"\n",
|
133 |
+
"<mujoco>\n",
|
134 |
+
" <worldbody>\n",
|
135 |
+
" <geom name=\"red_box\" type=\"box\" size=\".2 .2 .2\" rgba=\"1 0 0 1\"/>\n",
|
136 |
+
" <geom name=\"green_sphere\" pos=\".2 .2 .2\" size=\".1\" rgba=\"0 1 0 1\"/>\n",
|
137 |
+
" </worldbody>\n",
|
138 |
+
"</mujoco>\n",
|
139 |
+
"\"\"\"\n",
|
140 |
+
"model = mujoco.MjModel.from_xml_string(xml)"
|
141 |
+
]
|
142 |
+
},
|
143 |
+
{
|
144 |
+
"cell_type": "markdown",
|
145 |
+
"metadata": {
|
146 |
+
"id": "slhf39lGxvDI"
|
147 |
+
},
|
148 |
+
"source": [
|
149 |
+
"The `xml` string is written in MuJoCo's [MJCF](http://www.mujoco.org/book/modeling.html), which is an [XML](https://en.wikipedia.org/wiki/XML#Key_terminology)-based modeling language.\n",
|
150 |
+
" - The only required element is `<mujoco>`. The smallest valid MJCF model is `<mujoco/>` which is a completely empty model.\n",
|
151 |
+
" - All physical elements live inside the `<worldbody>` which is always the top-level body and constitutes the global origin in Cartesian coordinates.\n",
|
152 |
+
" - We define two geoms in the world named `red_box` and `green_sphere`.\n",
|
153 |
+
" - **Question:** The `red_box` has no position, the `green_sphere` has no type, why is that?\n",
|
154 |
+
" - **Answer:** MJCF attributes have *default values*. The default position is `0 0 0`, the default geom type is `sphere`. The MJCF language is described in the documentation's [XML Reference chapter](https://mujoco.readthedocs.io/en/latest/XMLreference.html).\n",
|
155 |
+
"\n",
|
156 |
+
"The `from_xml_string()` method invokes the model compiler, which creates a binary `mjModel` instance."
|
157 |
+
]
|
158 |
+
},
|
159 |
+
{
|
160 |
+
"cell_type": "markdown",
|
161 |
+
"metadata": {
|
162 |
+
"id": "gf9h_wi9weet"
|
163 |
+
},
|
164 |
+
"source": [
|
165 |
+
"## mjModel\n",
|
166 |
+
"\n",
|
167 |
+
"MuJoCo's `mjModel`, contains the *model description*, i.e., all quantities which *do not change over time*. The complete description of `mjModel` can be found at the end of the header file [`mjmodel.h`](https://github.com/google-deepmind/mujoco/blob/main/include/mujoco/mjmodel.h). Note that the header files contain short, useful inline comments, describing each field.\n",
|
168 |
+
"\n",
|
169 |
+
"Examples of quantities that can be found in `mjModel` are `ngeom`, the number of geoms in the scene and `geom_rgba`, their respective colors:"
|
170 |
+
]
|
171 |
+
},
|
172 |
+
{
|
173 |
+
"cell_type": "code",
|
174 |
+
"execution_count": null,
|
175 |
+
"metadata": {
|
176 |
+
"id": "F40Pe6DY3Q0g"
|
177 |
+
},
|
178 |
+
"outputs": [],
|
179 |
+
"source": [
|
180 |
+
"model.ngeom"
|
181 |
+
]
|
182 |
+
},
|
183 |
+
{
|
184 |
+
"cell_type": "code",
|
185 |
+
"execution_count": null,
|
186 |
+
"metadata": {
|
187 |
+
"id": "MOIJG9pzx8cA"
|
188 |
+
},
|
189 |
+
"outputs": [],
|
190 |
+
"source": [
|
191 |
+
"model.geom_rgba"
|
192 |
+
]
|
193 |
+
},
|
194 |
+
{
|
195 |
+
"cell_type": "markdown",
|
196 |
+
"metadata": {
|
197 |
+
"id": "bzcLjdY23Kvp"
|
198 |
+
},
|
199 |
+
"source": [
|
200 |
+
"## Named access\n",
|
201 |
+
"\n",
|
202 |
+
"The MuJoCo Python bindings provide convenient [accessors](https://mujoco.readthedocs.io/en/latest/python.html#named-access) using names. Calling the `model.geom()` accessor without a name string generates a convenient error that tells us what the valid names are."
|
203 |
+
]
|
204 |
+
},
|
205 |
+
{
|
206 |
+
"cell_type": "code",
|
207 |
+
"execution_count": null,
|
208 |
+
"metadata": {
|
209 |
+
"id": "9AuTwbLFyJxQ"
|
210 |
+
},
|
211 |
+
"outputs": [],
|
212 |
+
"source": [
|
213 |
+
"try:\n",
|
214 |
+
" model.geom()\n",
|
215 |
+
"except KeyError as e:\n",
|
216 |
+
" print(e)"
|
217 |
+
]
|
218 |
+
},
|
219 |
+
{
|
220 |
+
"cell_type": "markdown",
|
221 |
+
"metadata": {
|
222 |
+
"id": "qkfLK3h2zrqr"
|
223 |
+
},
|
224 |
+
"source": [
|
225 |
+
"Calling the named accessor without specifying a property will tell us what all the valid properties are:"
|
226 |
+
]
|
227 |
+
},
|
228 |
+
{
|
229 |
+
"cell_type": "code",
|
230 |
+
"execution_count": null,
|
231 |
+
"metadata": {
|
232 |
+
"id": "9X95TlWnyEEw"
|
233 |
+
},
|
234 |
+
"outputs": [],
|
235 |
+
"source": [
|
236 |
+
"model.geom('green_sphere')"
|
237 |
+
]
|
238 |
+
},
|
239 |
+
{
|
240 |
+
"cell_type": "markdown",
|
241 |
+
"metadata": {
|
242 |
+
"id": "mS9qDLevKsJq"
|
243 |
+
},
|
244 |
+
"source": [
|
245 |
+
"Let's read the `green_sphere`'s rgba values:"
|
246 |
+
]
|
247 |
+
},
|
248 |
+
{
|
249 |
+
"cell_type": "code",
|
250 |
+
"execution_count": null,
|
251 |
+
"metadata": {
|
252 |
+
"id": "xsBlJAV7zpHb"
|
253 |
+
},
|
254 |
+
"outputs": [],
|
255 |
+
"source": [
|
256 |
+
"model.geom('green_sphere').rgba"
|
257 |
+
]
|
258 |
+
},
|
259 |
+
{
|
260 |
+
"cell_type": "markdown",
|
261 |
+
"metadata": {
|
262 |
+
"id": "8a8hswjjKyIa"
|
263 |
+
},
|
264 |
+
"source": [
|
265 |
+
"This functionality is a convenience shortcut for MuJoCo's [`mj_name2id`](https://mujoco.readthedocs.io/en/latest/APIreference.html?highlight=mj_name2id#mj-name2id) function:"
|
266 |
+
]
|
267 |
+
},
|
268 |
+
{
|
269 |
+
"cell_type": "code",
|
270 |
+
"execution_count": null,
|
271 |
+
"metadata": {
|
272 |
+
"id": "Ng92hNUoKnVq"
|
273 |
+
},
|
274 |
+
"outputs": [],
|
275 |
+
"source": [
|
276 |
+
"id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_GEOM, 'green_sphere')\n",
|
277 |
+
"model.geom_rgba[id, :]"
|
278 |
+
]
|
279 |
+
},
|
280 |
+
{
|
281 |
+
"cell_type": "markdown",
|
282 |
+
"metadata": {
|
283 |
+
"id": "5WL_SaJPLl3r"
|
284 |
+
},
|
285 |
+
"source": [
|
286 |
+
"Similarly, the read-only `id` and `name` properties can be used to convert from id to name and back:"
|
287 |
+
]
|
288 |
+
},
|
289 |
+
{
|
290 |
+
"cell_type": "code",
|
291 |
+
"execution_count": null,
|
292 |
+
"metadata": {
|
293 |
+
"id": "2CbGSmRZeE5p"
|
294 |
+
},
|
295 |
+
"outputs": [],
|
296 |
+
"source": [
|
297 |
+
"print('id of \"green_sphere\": ', model.geom('green_sphere').id)\n",
|
298 |
+
"print('name of geom 1: ', model.geom(1).name)\n",
|
299 |
+
"print('name of body 0: ', model.body(0).name)"
|
300 |
+
]
|
301 |
+
},
|
302 |
+
{
|
303 |
+
"cell_type": "markdown",
|
304 |
+
"metadata": {
|
305 |
+
"id": "3RIizubaL_du"
|
306 |
+
},
|
307 |
+
"source": [
|
308 |
+
"Note that the 0th body is always the `world`. It cannot be renamed.\n",
|
309 |
+
"\n",
|
310 |
+
"The `id` and `name` attributes are useful in Python comprehensions:"
|
311 |
+
]
|
312 |
+
},
|
313 |
+
{
|
314 |
+
"cell_type": "code",
|
315 |
+
"execution_count": null,
|
316 |
+
"metadata": {
|
317 |
+
"id": "m3MtIE5F1K7s"
|
318 |
+
},
|
319 |
+
"outputs": [],
|
320 |
+
"source": [
|
321 |
+
"[model.geom(i).name for i in range(model.ngeom)]"
|
322 |
+
]
|
323 |
+
},
|
324 |
+
{
|
325 |
+
"cell_type": "markdown",
|
326 |
+
"metadata": {
|
327 |
+
"id": "t5hY0fyXFLcf"
|
328 |
+
},
|
329 |
+
"source": [
|
330 |
+
"## `mjData`\n",
|
331 |
+
"`mjData` contains the *state* and quantities that depend on it. The state is made up of time, [generalized](https://en.wikipedia.org/wiki/Generalized_coordinates) positions and generalized velocities. These are respectively `data.time`, `data.qpos` and `data.qvel`. In order to make a new `mjData`, all we need is our `mjModel`"
|
332 |
+
]
|
333 |
+
},
|
334 |
+
{
|
335 |
+
"cell_type": "code",
|
336 |
+
"execution_count": null,
|
337 |
+
"metadata": {
|
338 |
+
"id": "FV2Hy6m948nr"
|
339 |
+
},
|
340 |
+
"outputs": [],
|
341 |
+
"source": [
|
342 |
+
"data = mujoco.MjData(model)"
|
343 |
+
]
|
344 |
+
},
|
345 |
+
{
|
346 |
+
"cell_type": "markdown",
|
347 |
+
"metadata": {
|
348 |
+
"id": "-KmNuvlJ46u0"
|
349 |
+
},
|
350 |
+
"source": [
|
351 |
+
"`mjData` also contains *functions of the state*, for example the Cartesian positions of objects in the world frame. The (x, y, z) positions of our two geoms are in `data.geom_xpos`:"
|
352 |
+
]
|
353 |
+
},
|
354 |
+
{
|
355 |
+
"cell_type": "code",
|
356 |
+
"execution_count": null,
|
357 |
+
"metadata": {
|
358 |
+
"id": "CPwDcAQ0-uUE"
|
359 |
+
},
|
360 |
+
"outputs": [],
|
361 |
+
"source": [
|
362 |
+
"print(data.geom_xpos)"
|
363 |
+
]
|
364 |
+
},
|
365 |
+
{
|
366 |
+
"cell_type": "markdown",
|
367 |
+
"metadata": {
|
368 |
+
"id": "Sjst5xGXX3sr"
|
369 |
+
},
|
370 |
+
"source": [
|
371 |
+
"Wait, why are both of our geoms at the origin? Didn't we offset the green sphere? The answer is that derived quantities in `mjData` need to be explicitly propagated (see [below](#scrollTo=QY1gpms1HXeN)). In our case, the minimal required function is [`mj_kinematics`](https://mujoco.readthedocs.io/en/latest/APIreference.html#mj-kinematics), which computes global Cartesian poses for all objects (excluding cameras and lights)."
|
372 |
+
]
|
373 |
+
},
|
374 |
+
{
|
375 |
+
"cell_type": "code",
|
376 |
+
"execution_count": null,
|
377 |
+
"metadata": {
|
378 |
+
"id": "tfe0YeZRYNTr"
|
379 |
+
},
|
380 |
+
"outputs": [],
|
381 |
+
"source": [
|
382 |
+
"mujoco.mj_kinematics(model, data)\n",
|
383 |
+
"print('raw access:\\n', data.geom_xpos)\n",
|
384 |
+
"\n",
|
385 |
+
"# MjData also supports named access:\n",
|
386 |
+
"print('\\nnamed access:\\n', data.geom('green_sphere').xpos)"
|
387 |
+
]
|
388 |
+
},
|
389 |
+
{
|
390 |
+
"cell_type": "markdown",
|
391 |
+
"metadata": {
|
392 |
+
"id": "eU7uWNsTwmcZ"
|
393 |
+
},
|
394 |
+
"source": [
|
395 |
+
"# Basic rendering, simulation, and animation\n",
|
396 |
+
"\n",
|
397 |
+
"In order to render we'll need to instantiate a `Renderer` object and call its `render` method.\n",
|
398 |
+
"\n",
|
399 |
+
"We'll also reload our model to make the colab's sections independent."
|
400 |
+
]
|
401 |
+
},
|
402 |
+
{
|
403 |
+
"cell_type": "code",
|
404 |
+
"execution_count": null,
|
405 |
+
"metadata": {
|
406 |
+
"id": "xK3c0-UDxMrN"
|
407 |
+
},
|
408 |
+
"outputs": [],
|
409 |
+
"source": [
|
410 |
+
"xml = \"\"\"\n",
|
411 |
+
"<mujoco>\n",
|
412 |
+
" <worldbody>\n",
|
413 |
+
" <geom name=\"red_box\" type=\"box\" size=\".2 .2 .2\" rgba=\"1 0 0 1\"/>\n",
|
414 |
+
" <geom name=\"green_sphere\" pos=\".2 .2 .2\" size=\".1\" rgba=\"0 1 0 1\"/>\n",
|
415 |
+
" </worldbody>\n",
|
416 |
+
"</mujoco>\n",
|
417 |
+
"\"\"\"\n",
|
418 |
+
"# Make model and data\n",
|
419 |
+
"model = mujoco.MjModel.from_xml_string(xml)\n",
|
420 |
+
"data = mujoco.MjData(model)\n",
|
421 |
+
"\n",
|
422 |
+
"# Make renderer, render and show the pixels\n",
|
423 |
+
"with mujoco.Renderer(model) as renderer:\n",
|
424 |
+
" media.show_image(renderer.render())"
|
425 |
+
]
|
426 |
+
},
|
427 |
+
{
|
428 |
+
"cell_type": "markdown",
|
429 |
+
"metadata": {
|
430 |
+
"id": "ZkFSHeYGxlT5"
|
431 |
+
},
|
432 |
+
"source": [
|
433 |
+
"Hmmm, why the black pixels?\n",
|
434 |
+
"\n",
|
435 |
+
"**Answer:** For the same reason as above, we first need to propagate the values in `mjData`. This time we'll call [`mj_forward`](https://mujoco.readthedocs.io/en/latest/APIreference/APIfunctions.html#mj-forward), which invokes the entire pipeline up to the computation of accelerations i.e., it computes $\\dot x = f(x)$, where $x$ is the state. This function does more than we actually need, but unless we care about saving computation time, it's good practice to call `mj_forward` since then we know we are not missing anything.\n",
|
436 |
+
"\n",
|
437 |
+
"We also need to update the `mjvScene` which is an object held by the renderer describing the visual scene. We'll later see that the scene can include visual objects which are not part of the physical model."
|
438 |
+
]
|
439 |
+
},
|
440 |
+
{
|
441 |
+
"cell_type": "code",
|
442 |
+
"execution_count": null,
|
443 |
+
"metadata": {
|
444 |
+
"id": "pvh47r97huS4"
|
445 |
+
},
|
446 |
+
"outputs": [],
|
447 |
+
"source": [
|
448 |
+
"with mujoco.Renderer(model) as renderer:\n",
|
449 |
+
" mujoco.mj_forward(model, data)\n",
|
450 |
+
" renderer.update_scene(data)\n",
|
451 |
+
"\n",
|
452 |
+
" media.show_image(renderer.render())"
|
453 |
+
]
|
454 |
+
},
|
455 |
+
{
|
456 |
+
"cell_type": "markdown",
|
457 |
+
"metadata": {
|
458 |
+
"id": "6oDW1dOUifw6"
|
459 |
+
},
|
460 |
+
"source": [
|
461 |
+
"This worked, but this image is a bit dark. Let's add a light and re-render."
|
462 |
+
]
|
463 |
+
},
|
464 |
+
{
|
465 |
+
"cell_type": "code",
|
466 |
+
"execution_count": null,
|
467 |
+
"metadata": {
|
468 |
+
"id": "iqzJj2NIr_2V"
|
469 |
+
},
|
470 |
+
"outputs": [],
|
471 |
+
"source": [
|
472 |
+
"xml = \"\"\"\n",
|
473 |
+
"<mujoco>\n",
|
474 |
+
" <worldbody>\n",
|
475 |
+
" <light name=\"top\" pos=\"0 0 1\"/>\n",
|
476 |
+
" <geom name=\"red_box\" type=\"box\" size=\".2 .2 .2\" rgba=\"1 0 0 1\"/>\n",
|
477 |
+
" <geom name=\"green_sphere\" pos=\".2 .2 .2\" size=\".1\" rgba=\"0 1 0 1\"/>\n",
|
478 |
+
" </worldbody>\n",
|
479 |
+
"</mujoco>\n",
|
480 |
+
"\"\"\"\n",
|
481 |
+
"model = mujoco.MjModel.from_xml_string(xml)\n",
|
482 |
+
"data = mujoco.MjData(model)\n",
|
483 |
+
"\n",
|
484 |
+
"with mujoco.Renderer(model) as renderer:\n",
|
485 |
+
" mujoco.mj_forward(model, data)\n",
|
486 |
+
" renderer.update_scene(data)\n",
|
487 |
+
"\n",
|
488 |
+
" media.show_image(renderer.render())"
|
489 |
+
]
|
490 |
+
},
|
491 |
+
{
|
492 |
+
"cell_type": "markdown",
|
493 |
+
"metadata": {
|
494 |
+
"id": "HS4K38Eirww9"
|
495 |
+
},
|
496 |
+
"source": [
|
497 |
+
"Much better!\n",
|
498 |
+
"\n",
|
499 |
+
"Note that all values in the `mjModel` instance are writable. While it's generally not recommended to do this but rather to change the values in the XML, because it's easy to make an invalid model, some values are safe to write into, for example colors:"
|
500 |
+
]
|
501 |
+
},
|
502 |
+
{
|
503 |
+
"cell_type": "code",
|
504 |
+
"execution_count": null,
|
505 |
+
"metadata": {
|
506 |
+
"id": "GBNcQVYJrt2h"
|
507 |
+
},
|
508 |
+
"outputs": [],
|
509 |
+
"source": [
|
510 |
+
"# Run this cell multiple times for different colors\n",
|
511 |
+
"model.geom('red_box').rgba[:3] = np.random.rand(3)\n",
|
512 |
+
"with mujoco.Renderer(model) as renderer:\n",
|
513 |
+
" renderer.update_scene(data)\n",
|
514 |
+
"\n",
|
515 |
+
" media.show_image(renderer.render())"
|
516 |
+
]
|
517 |
+
},
|
518 |
+
{
|
519 |
+
"cell_type": "markdown",
|
520 |
+
"metadata": {
|
521 |
+
"id": "-P95E-QHizQq"
|
522 |
+
},
|
523 |
+
"source": [
|
524 |
+
"# Simulation\n",
|
525 |
+
"\n",
|
526 |
+
"Now let's simulate and make a video. We'll use MuJoCo's main high level function `mj_step`, which steps the state $x_{t+h} = f(x_t)$.\n",
|
527 |
+
"\n",
|
528 |
+
"Note that in the code block below we are *not* rendering after each call to `mj_step`. This is because the default timestep is 2ms, and we want a 60fps video, not 500fps."
|
529 |
+
]
|
530 |
+
},
|
531 |
+
{
|
532 |
+
"cell_type": "code",
|
533 |
+
"execution_count": null,
|
534 |
+
"metadata": {
|
535 |
+
"id": "NdVnHOYisiKl"
|
536 |
+
},
|
537 |
+
"outputs": [],
|
538 |
+
"source": [
|
539 |
+
"duration = 3.8 # (seconds)\n",
|
540 |
+
"framerate = 60 # (Hz)\n",
|
541 |
+
"\n",
|
542 |
+
"# Simulate and display video.\n",
|
543 |
+
"frames = []\n",
|
544 |
+
"mujoco.mj_resetData(model, data) # Reset state and time.\n",
|
545 |
+
"with mujoco.Renderer(model) as renderer:\n",
|
546 |
+
" while data.time < duration:\n",
|
547 |
+
" mujoco.mj_step(model, data)\n",
|
548 |
+
" if len(frames) < data.time * framerate:\n",
|
549 |
+
" renderer.update_scene(data)\n",
|
550 |
+
" pixels = renderer.render()\n",
|
551 |
+
" frames.append(pixels)\n",
|
552 |
+
"\n",
|
553 |
+
"media.show_video(frames, fps=framerate)"
|
554 |
+
]
|
555 |
+
},
|
556 |
+
{
|
557 |
+
"cell_type": "markdown",
|
558 |
+
"metadata": {
|
559 |
+
"id": "tYN4sL9RnsCU"
|
560 |
+
},
|
561 |
+
"source": [
|
562 |
+
"Hmmm, the video is playing, but nothing is moving, why is that?\n",
|
563 |
+
"\n",
|
564 |
+
"This is because this model has no [degrees of freedom](https://www.google.com/url?sa=D&q=https%3A%2F%2Fen.wikipedia.org%2Fwiki%2FDegrees_of_freedom_(mechanics)) (DoFs). The things that move (and which have inertia) are called *bodies*. We add DoFs by adding *joints* to bodies, specifying how they can move with respect to their parents. Let's make a new body that contains our geoms, add a hinge joint and re-render, while visualizing the joint axis using the visualization option object `MjvOption`."
|
565 |
+
]
|
566 |
+
},
|
567 |
+
{
|
568 |
+
"cell_type": "code",
|
569 |
+
"execution_count": null,
|
570 |
+
"metadata": {
|
571 |
+
"id": "LbWf84VYst5m"
|
572 |
+
},
|
573 |
+
"outputs": [],
|
574 |
+
"source": [
|
575 |
+
"xml = \"\"\"\n",
|
576 |
+
"<mujoco>\n",
|
577 |
+
" <worldbody>\n",
|
578 |
+
" <light name=\"top\" pos=\"0 0 1\"/>\n",
|
579 |
+
" <body name=\"box_and_sphere\" euler=\"0 0 -30\">\n",
|
580 |
+
" <joint name=\"swing\" type=\"hinge\" axis=\"1 -1 0\" pos=\"-.2 -.2 -.2\"/>\n",
|
581 |
+
" <geom name=\"red_box\" type=\"box\" size=\".2 .2 .2\" rgba=\"1 0 0 1\"/>\n",
|
582 |
+
" <geom name=\"green_sphere\" pos=\".2 .2 .2\" size=\".1\" rgba=\"0 1 0 1\"/>\n",
|
583 |
+
" </body>\n",
|
584 |
+
" </worldbody>\n",
|
585 |
+
"</mujoco>\n",
|
586 |
+
"\"\"\"\n",
|
587 |
+
"model = mujoco.MjModel.from_xml_string(xml)\n",
|
588 |
+
"data = mujoco.MjData(model)\n",
|
589 |
+
"\n",
|
590 |
+
"# enable joint visualization option:\n",
|
591 |
+
"scene_option = mujoco.MjvOption()\n",
|
592 |
+
"scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = True\n",
|
593 |
+
"\n",
|
594 |
+
"duration = 3.8 # (seconds)\n",
|
595 |
+
"framerate = 60 # (Hz)\n",
|
596 |
+
"\n",
|
597 |
+
"# Simulate and display video.\n",
|
598 |
+
"frames = []\n",
|
599 |
+
"mujoco.mj_resetData(model, data)\n",
|
600 |
+
"with mujoco.Renderer(model) as renderer:\n",
|
601 |
+
" while data.time < duration:\n",
|
602 |
+
" mujoco.mj_step(model, data)\n",
|
603 |
+
" if len(frames) < data.time * framerate:\n",
|
604 |
+
" renderer.update_scene(data, scene_option=scene_option)\n",
|
605 |
+
" pixels = renderer.render()\n",
|
606 |
+
" frames.append(pixels)\n",
|
607 |
+
"\n",
|
608 |
+
"media.show_video(frames, fps=framerate)"
|
609 |
+
]
|
610 |
+
},
|
611 |
+
{
|
612 |
+
"cell_type": "markdown",
|
613 |
+
"metadata": {
|
614 |
+
"id": "Ymv-tvWCpl6V"
|
615 |
+
},
|
616 |
+
"source": [
|
617 |
+
"Note that we rotated the `box_and_sphere` body by 30° around the Z (vertical) axis, with the directive `euler=\"0 0 -30\"`. This was made to emphasize that the poses of elements in the [kinematic tree](https://www.google.com/url?sa=D&q=https%3A%2F%2Fen.wikipedia.org%2Fwiki%2FKinematic_chain) are always with respect to their *parent body*, so our two geoms were also rotated by this transformation.\n",
|
618 |
+
"\n",
|
619 |
+
"Physics options live in `mjModel.opt`, for example the timestep:"
|
620 |
+
]
|
621 |
+
},
|
622 |
+
{
|
623 |
+
"cell_type": "code",
|
624 |
+
"execution_count": null,
|
625 |
+
"metadata": {
|
626 |
+
"id": "5yvAJokcpyX_"
|
627 |
+
},
|
628 |
+
"outputs": [],
|
629 |
+
"source": [
|
630 |
+
"model.opt.timestep"
|
631 |
+
]
|
632 |
+
},
|
633 |
+
{
|
634 |
+
"cell_type": "markdown",
|
635 |
+
"metadata": {
|
636 |
+
"id": "SdkwLeGUp9B2"
|
637 |
+
},
|
638 |
+
"source": [
|
639 |
+
"Let's flip gravity and re-render:"
|
640 |
+
]
|
641 |
+
},
|
642 |
+
{
|
643 |
+
"cell_type": "code",
|
644 |
+
"execution_count": null,
|
645 |
+
"metadata": {
|
646 |
+
"id": "ocjPQG8Dp2F-"
|
647 |
+
},
|
648 |
+
"outputs": [],
|
649 |
+
"source": [
|
650 |
+
"print('default gravity', model.opt.gravity)\n",
|
651 |
+
"model.opt.gravity = (0, 0, 10)\n",
|
652 |
+
"print('flipped gravity', model.opt.gravity)\n",
|
653 |
+
"\n",
|
654 |
+
"# Simulate and display video.\n",
|
655 |
+
"frames = []\n",
|
656 |
+
"mujoco.mj_resetData(model, data)\n",
|
657 |
+
"with mujoco.Renderer(model) as renderer:\n",
|
658 |
+
" while data.time < duration:\n",
|
659 |
+
" mujoco.mj_step(model, data)\n",
|
660 |
+
" if len(frames) < data.time * framerate:\n",
|
661 |
+
" renderer.update_scene(data, scene_option=scene_option)\n",
|
662 |
+
" pixels = renderer.render()\n",
|
663 |
+
" frames.append(pixels)\n",
|
664 |
+
"\n",
|
665 |
+
"media.show_video(frames, fps=60)"
|
666 |
+
]
|
667 |
+
},
|
668 |
+
{
|
669 |
+
"cell_type": "markdown",
|
670 |
+
"metadata": {
|
671 |
+
"id": "FsxDDgXBqg_J"
|
672 |
+
},
|
673 |
+
"source": [
|
674 |
+
"We could also have done this in XML using the top-level `<option>` element:\n",
|
675 |
+
"```xml\n",
|
676 |
+
"<mujoco>\n",
|
677 |
+
" <option gravity=\"0 0 10\"/>\n",
|
678 |
+
" ...\n",
|
679 |
+
"</mujoco>\n",
|
680 |
+
"```"
|
681 |
+
]
|
682 |
+
},
|
683 |
+
{
|
684 |
+
"cell_type": "markdown",
|
685 |
+
"metadata": {
|
686 |
+
"id": "QY1gpms1HXeN"
|
687 |
+
},
|
688 |
+
"source": [
|
689 |
+
"### Understanding Degrees of Freedom\n",
|
690 |
+
"\n",
|
691 |
+
"In the real world, all rigid objects have 6 degrees-of-freedom: 3 translations and 3 rotations. Real-world joints act as constraints, removing relative degrees-of-freedom from bodies connected by joints. Some physics simulation software use this representation which is known as the \"Cartesian\" or \"subtractive\" representation, but it is inefficient. MuJoCo uses a representation known as the \"Lagrangian\", \"generalized\" or \"additive\" representation, whereby objects have no degrees of freedom unless explicitly added using joints.\n",
|
692 |
+
"\n",
|
693 |
+
"Our model, which has a single hinge joint, has one degree of freedom, and the entire state is defined by this joint's angle and angular velocity. These are the system's generalized position and velocity."
|
694 |
+
]
|
695 |
+
},
|
696 |
+
{
|
697 |
+
"cell_type": "code",
|
698 |
+
"execution_count": null,
|
699 |
+
"metadata": {
|
700 |
+
"id": "wEdfGEfSKAOC"
|
701 |
+
},
|
702 |
+
"outputs": [],
|
703 |
+
"source": [
|
704 |
+
"print('Total number of DoFs in the model:', model.nv)\n",
|
705 |
+
"print('Generalized positions:', data.qpos)\n",
|
706 |
+
"print('Generalized velocities:', data.qvel)"
|
707 |
+
]
|
708 |
+
},
|
709 |
+
{
|
710 |
+
"cell_type": "markdown",
|
711 |
+
"metadata": {
|
712 |
+
"id": "Z8E-P5xONUSn"
|
713 |
+
},
|
714 |
+
"source": [
|
715 |
+
"MuJoCo's use of generalized coordinates is the reason that calling a function (e.g. [`mj_forward`](https://mujoco.readthedocs.io/en/latest/APIreference.html#mj-forward)) is required before rendering or reading the global poses of objects – Cartesian positions are *derived* from the generalized positions and need to be explicitly computed."
|
716 |
+
]
|
717 |
+
},
|
718 |
+
{
|
719 |
+
"cell_type": "markdown",
|
720 |
+
"metadata": {
|
721 |
+
"id": "SHppAOjvSupc"
|
722 |
+
},
|
723 |
+
"source": [
|
724 |
+
"# Example: Simulating free bodies with the self-inverting \"tippe-top\"\n",
|
725 |
+
"\n",
|
726 |
+
"A free body is a body with a [free joint](https://www.google.com/url?sa=D&q=https%3A%2F%2Fmujoco.readthedocs.io%2Fen%2Flatest%2FXMLreference.html%3Fhighlight%3Dfreejoint%23body-freejoint) having 6 DoFs, i.e., 3 translations and 3 rotations. We could give our `box_and_sphere` body a free joint and watch it fall, but let's look at something more interesting. A \"tippe top\" is a spinning toy which flips itself ([video](https://www.youtube.com/watch?v=kbYpVrdcszQ), [Wikipedia](https://en.wikipedia.org/wiki/Tippe_top)). We model it as follows:"
|
727 |
+
]
|
728 |
+
},
|
729 |
+
{
|
730 |
+
"cell_type": "code",
|
731 |
+
"execution_count": null,
|
732 |
+
"metadata": {
|
733 |
+
"id": "xasXQpVMjIwA"
|
734 |
+
},
|
735 |
+
"outputs": [],
|
736 |
+
"source": [
|
737 |
+
"tippe_top = \"\"\"\n",
|
738 |
+
"<mujoco model=\"tippe top\">\n",
|
739 |
+
" <option integrator=\"RK4\"/>\n",
|
740 |
+
"\n",
|
741 |
+
" <asset>\n",
|
742 |
+
" <texture name=\"grid\" type=\"2d\" builtin=\"checker\" rgb1=\".1 .2 .3\"\n",
|
743 |
+
" rgb2=\".2 .3 .4\" width=\"300\" height=\"300\"/>\n",
|
744 |
+
" <material name=\"grid\" texture=\"grid\" texrepeat=\"8 8\" reflectance=\".2\"/>\n",
|
745 |
+
" </asset>\n",
|
746 |
+
"\n",
|
747 |
+
" <worldbody>\n",
|
748 |
+
" <geom size=\".2 .2 .01\" type=\"plane\" material=\"grid\"/>\n",
|
749 |
+
" <light pos=\"0 0 .6\"/>\n",
|
750 |
+
" <camera name=\"closeup\" pos=\"0 -.1 .07\" xyaxes=\"1 0 0 0 1 2\"/>\n",
|
751 |
+
" <body name=\"top\" pos=\"0 0 .02\">\n",
|
752 |
+
" <freejoint/>\n",
|
753 |
+
" <geom name=\"ball\" type=\"sphere\" size=\".02\" />\n",
|
754 |
+
" <geom name=\"stem\" type=\"cylinder\" pos=\"0 0 .02\" size=\"0.004 .008\"/>\n",
|
755 |
+
" <geom name=\"ballast\" type=\"box\" size=\".023 .023 0.005\" pos=\"0 0 -.015\"\n",
|
756 |
+
" contype=\"0\" conaffinity=\"0\" group=\"3\"/>\n",
|
757 |
+
" </body>\n",
|
758 |
+
" </worldbody>\n",
|
759 |
+
"\n",
|
760 |
+
" <keyframe>\n",
|
761 |
+
" <key name=\"spinning\" qpos=\"0 0 0.02 1 0 0 0\" qvel=\"0 0 0 0 1 200\" />\n",
|
762 |
+
" </keyframe>\n",
|
763 |
+
"</mujoco>\n",
|
764 |
+
"\"\"\"\n",
|
765 |
+
"model = mujoco.MjModel.from_xml_string(tippe_top)\n",
|
766 |
+
"data = mujoco.MjData(model)\n",
|
767 |
+
"\n",
|
768 |
+
"mujoco.mj_forward(model, data)\n",
|
769 |
+
"with mujoco.Renderer(model) as renderer:\n",
|
770 |
+
" renderer.update_scene(data, camera=\"closeup\")\n",
|
771 |
+
"\n",
|
772 |
+
" media.show_image(renderer.render())"
|
773 |
+
]
|
774 |
+
},
|
775 |
+
{
|
776 |
+
"cell_type": "markdown",
|
777 |
+
"metadata": {
|
778 |
+
"id": "bvHlr6maJYIG"
|
779 |
+
},
|
780 |
+
"source": [
|
781 |
+
"Note several new features of this model definition:\n",
|
782 |
+
"1. A 6-DoF free joint is added with the `<freejoint/>` clause.\n",
|
783 |
+
"2. We use the `<option/>` clause to set the integrator to the 4th order [Runge Kutta](https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods). Runge-Kutta has a higher rate of convergence than the default Euler integrator, which in many cases increases the accuracy at a given timestep size.\n",
|
784 |
+
"3. We define the floor's grid material inside the `<asset/>` clause and reference it in the `\"floor\"` geom.\n",
|
785 |
+
"4. We use an invisible and non-colliding box geom called `ballast` to move the top's center-of-mass lower. Having a low center of mass is (counter-intuitively) required for the flipping behavior to occur.\n",
|
786 |
+
"5. We save our initial spinning state as a *keyframe*. It has a high rotational velocity around the Z-axis, but is not perfectly oriented with the world, which introduces the symmetry-breaking required for the flipping.\n",
|
787 |
+
"6. We define a `<camera>` in our model, and then render from it using the `camera` argument to `update_scene()`.\n",
|
788 |
+
"Let us examine the state:\n"
|
789 |
+
]
|
790 |
+
},
|
791 |
+
{
|
792 |
+
"cell_type": "code",
|
793 |
+
"execution_count": null,
|
794 |
+
"metadata": {
|
795 |
+
"id": "o4S9nYhHOKmb"
|
796 |
+
},
|
797 |
+
"outputs": [],
|
798 |
+
"source": [
|
799 |
+
"print('positions', data.qpos)\n",
|
800 |
+
"print('velocities', data.qvel)"
|
801 |
+
]
|
802 |
+
},
|
803 |
+
{
|
804 |
+
"cell_type": "markdown",
|
805 |
+
"metadata": {
|
806 |
+
"id": "71UgzBAqWdtZ"
|
807 |
+
},
|
808 |
+
"source": [
|
809 |
+
"The velocities are easy to interpret, 6 zeros, one for each DoF. What about the length 7 positions? We can see the initial 2cm height of the body; the subsequent four numbers are the 3D orientation, defined by a *unit quaternion*. 3D orientations are represented with **4** numbers while angular velocities are **3** numbers. For more information see the Wikipedia article on [quaternions and spatial rotation](https://en.wikipedia.org/wiki/Quaternions_and_spatial_rotation).\n",
|
810 |
+
"\n",
|
811 |
+
"Let's make a video:"
|
812 |
+
]
|
813 |
+
},
|
814 |
+
{
|
815 |
+
"cell_type": "code",
|
816 |
+
"execution_count": null,
|
817 |
+
"metadata": {
|
818 |
+
"id": "5P4HkhKNGQvs"
|
819 |
+
},
|
820 |
+
"outputs": [],
|
821 |
+
"source": [
|
822 |
+
"duration = 7 # (seconds)\n",
|
823 |
+
"framerate = 60 # (Hz)\n",
|
824 |
+
"\n",
|
825 |
+
"# Simulate and display video.\n",
|
826 |
+
"frames = []\n",
|
827 |
+
"mujoco.mj_resetDataKeyframe(model, data, 0) # Reset the state to keyframe 0\n",
|
828 |
+
"with mujoco.Renderer(model) as renderer:\n",
|
829 |
+
" while data.time < duration:\n",
|
830 |
+
" mujoco.mj_step(model, data)\n",
|
831 |
+
" if len(frames) < data.time * framerate:\n",
|
832 |
+
" renderer.update_scene(data, \"closeup\")\n",
|
833 |
+
" pixels = renderer.render()\n",
|
834 |
+
" frames.append(pixels)\n",
|
835 |
+
"\n",
|
836 |
+
"media.show_video(frames, fps=framerate)"
|
837 |
+
]
|
838 |
+
},
|
839 |
+
{
|
840 |
+
"cell_type": "markdown",
|
841 |
+
"metadata": {
|
842 |
+
"id": "rRuFKD2ubPgu"
|
843 |
+
},
|
844 |
+
"source": [
|
845 |
+
"### Measuring values from `mjData`\n",
|
846 |
+
"As mentioned above, the `mjData` structure contains the dynamic variables and intermediate results produced by the simulation which are *expected to change* on each timestep. Below we simulate for 2000 timesteps and plot the angular velocity of the top and height of the stem as a function of time."
|
847 |
+
]
|
848 |
+
},
|
849 |
+
{
|
850 |
+
"cell_type": "code",
|
851 |
+
"execution_count": null,
|
852 |
+
"metadata": {
|
853 |
+
"id": "1XXB6asJoZ2N"
|
854 |
+
},
|
855 |
+
"outputs": [],
|
856 |
+
"source": [
|
857 |
+
"timevals = []\n",
|
858 |
+
"angular_velocity = []\n",
|
859 |
+
"stem_height = []\n",
|
860 |
+
"\n",
|
861 |
+
"# Simulate and save data\n",
|
862 |
+
"mujoco.mj_resetDataKeyframe(model, data, 0)\n",
|
863 |
+
"while data.time < duration:\n",
|
864 |
+
" mujoco.mj_step(model, data)\n",
|
865 |
+
" timevals.append(data.time)\n",
|
866 |
+
" angular_velocity.append(data.qvel[3:6].copy())\n",
|
867 |
+
" stem_height.append(data.geom_xpos[2,2]);\n",
|
868 |
+
"\n",
|
869 |
+
"dpi = 120\n",
|
870 |
+
"width = 600\n",
|
871 |
+
"height = 800\n",
|
872 |
+
"figsize = (width / dpi, height / dpi)\n",
|
873 |
+
"_, ax = plt.subplots(2, 1, figsize=figsize, dpi=dpi, sharex=True)\n",
|
874 |
+
"\n",
|
875 |
+
"ax[0].plot(timevals, angular_velocity)\n",
|
876 |
+
"ax[0].set_title('angular velocity')\n",
|
877 |
+
"ax[0].set_ylabel('radians / second')\n",
|
878 |
+
"\n",
|
879 |
+
"ax[1].plot(timevals, stem_height)\n",
|
880 |
+
"ax[1].set_xlabel('time (seconds)')\n",
|
881 |
+
"ax[1].set_ylabel('meters')\n",
|
882 |
+
"_ = ax[1].set_title('stem height')"
|
883 |
+
]
|
884 |
+
},
|
885 |
+
{
|
886 |
+
"cell_type": "markdown",
|
887 |
+
"metadata": {
|
888 |
+
"id": "u_zN8vATwcGy"
|
889 |
+
},
|
890 |
+
"source": [
|
891 |
+
"# Example: A chaotic pendulum"
|
892 |
+
]
|
893 |
+
},
|
894 |
+
{
|
895 |
+
"cell_type": "markdown",
|
896 |
+
"metadata": {
|
897 |
+
"id": "g1MKUEL_eSCM"
|
898 |
+
},
|
899 |
+
"source": [
|
900 |
+
"Below is a model of a chaotic pendulum, similar to [this one](https://www.exploratorium.edu/exhibits/chaotic-pendulum) in the San Francisco Exploratorium."
|
901 |
+
]
|
902 |
+
},
|
903 |
+
{
|
904 |
+
"cell_type": "code",
|
905 |
+
"execution_count": null,
|
906 |
+
"metadata": {
|
907 |
+
"id": "3jHYTV-bwfrS"
|
908 |
+
},
|
909 |
+
"outputs": [],
|
910 |
+
"source": [
|
911 |
+
"chaotic_pendulum = \"\"\"\n",
|
912 |
+
"<mujoco>\n",
|
913 |
+
" <option timestep=\".001\">\n",
|
914 |
+
" <flag energy=\"enable\" contact=\"disable\"/>\n",
|
915 |
+
" </option>\n",
|
916 |
+
"\n",
|
917 |
+
" <default>\n",
|
918 |
+
" <joint type=\"hinge\" axis=\"0 -1 0\"/>\n",
|
919 |
+
" <geom type=\"capsule\" size=\".02\"/>\n",
|
920 |
+
" </default>\n",
|
921 |
+
"\n",
|
922 |
+
" <worldbody>\n",
|
923 |
+
" <light pos=\"0 -.4 1\"/>\n",
|
924 |
+
" <camera name=\"fixed\" pos=\"0 -1 0\" xyaxes=\"1 0 0 0 0 1\"/>\n",
|
925 |
+
" <body name=\"0\" pos=\"0 0 .2\">\n",
|
926 |
+
" <joint name=\"root\"/>\n",
|
927 |
+
" <geom fromto=\"-.2 0 0 .2 0 0\" rgba=\"1 1 0 1\"/>\n",
|
928 |
+
" <geom fromto=\"0 0 0 0 0 -.25\" rgba=\"1 1 0 1\"/>\n",
|
929 |
+
" <body name=\"1\" pos=\"-.2 0 0\">\n",
|
930 |
+
" <joint/>\n",
|
931 |
+
" <geom fromto=\"0 0 0 0 0 -.2\" rgba=\"1 0 0 1\"/>\n",
|
932 |
+
" </body>\n",
|
933 |
+
" <body name=\"2\" pos=\".2 0 0\">\n",
|
934 |
+
" <joint/>\n",
|
935 |
+
" <geom fromto=\"0 0 0 0 0 -.2\" rgba=\"0 1 0 1\"/>\n",
|
936 |
+
" </body>\n",
|
937 |
+
" <body name=\"3\" pos=\"0 0 -.25\">\n",
|
938 |
+
" <joint/>\n",
|
939 |
+
" <geom fromto=\"0 0 0 0 0 -.2\" rgba=\"0 0 1 1\"/>\n",
|
940 |
+
" </body>\n",
|
941 |
+
" </body>\n",
|
942 |
+
" </worldbody>\n",
|
943 |
+
"</mujoco>\n",
|
944 |
+
"\"\"\"\n",
|
945 |
+
"model = mujoco.MjModel.from_xml_string(chaotic_pendulum)\n",
|
946 |
+
"data = mujoco.MjData(model)\n",
|
947 |
+
"height = 480\n",
|
948 |
+
"width = 640\n",
|
949 |
+
"\n",
|
950 |
+
"with mujoco.Renderer(model, height, width) as renderer:\n",
|
951 |
+
" mujoco.mj_forward(model, data)\n",
|
952 |
+
" renderer.update_scene(data, camera=\"fixed\")\n",
|
953 |
+
"\n",
|
954 |
+
" media.show_image(renderer.render())"
|
955 |
+
]
|
956 |
+
},
|
957 |
+
{
|
958 |
+
"cell_type": "markdown",
|
959 |
+
"metadata": {
|
960 |
+
"id": "EKZrTBSS5f49"
|
961 |
+
},
|
962 |
+
"source": [
|
963 |
+
"## Timing\n",
|
964 |
+
"Let's see a video of it in action while we time the components:"
|
965 |
+
]
|
966 |
+
},
|
967 |
+
{
|
968 |
+
"cell_type": "code",
|
969 |
+
"execution_count": null,
|
970 |
+
"metadata": {
|
971 |
+
"id": "-kNWvE9dNwYW"
|
972 |
+
},
|
973 |
+
"outputs": [],
|
974 |
+
"source": [
|
975 |
+
"# setup\n",
|
976 |
+
"n_seconds = 6\n",
|
977 |
+
"framerate = 30 # Hz\n",
|
978 |
+
"n_frames = int(n_seconds * framerate)\n",
|
979 |
+
"frames = []\n",
|
980 |
+
"height = 240\n",
|
981 |
+
"width = 320\n",
|
982 |
+
"\n",
|
983 |
+
"# set initial state\n",
|
984 |
+
"mujoco.mj_resetData(model, data)\n",
|
985 |
+
"data.joint('root').qvel = 10\n",
|
986 |
+
"\n",
|
987 |
+
"# simulate and record frames\n",
|
988 |
+
"frame = 0\n",
|
989 |
+
"sim_time = 0\n",
|
990 |
+
"render_time = 0\n",
|
991 |
+
"n_steps = 0\n",
|
992 |
+
"with mujoco.Renderer(model, height, width) as renderer:\n",
|
993 |
+
" for i in range(n_frames):\n",
|
994 |
+
" while data.time * framerate < i:\n",
|
995 |
+
" tic = time.time()\n",
|
996 |
+
" mujoco.mj_step(model, data)\n",
|
997 |
+
" sim_time += time.time() - tic\n",
|
998 |
+
" n_steps += 1\n",
|
999 |
+
" tic = time.time()\n",
|
1000 |
+
" renderer.update_scene(data, \"fixed\")\n",
|
1001 |
+
" frame = renderer.render()\n",
|
1002 |
+
" render_time += time.time() - tic\n",
|
1003 |
+
" frames.append(frame)\n",
|
1004 |
+
"\n",
|
1005 |
+
"# print timing and play video\n",
|
1006 |
+
"step_time = 1e6*sim_time/n_steps\n",
|
1007 |
+
"step_fps = n_steps/sim_time\n",
|
1008 |
+
"print(f'simulation: {step_time:5.3g} μs/step ({step_fps:5.0f}Hz)')\n",
|
1009 |
+
"frame_time = 1e6*render_time/n_frames\n",
|
1010 |
+
"frame_fps = n_frames/render_time\n",
|
1011 |
+
"print(f'rendering: {frame_time:5.3g} μs/frame ({frame_fps:5.0f}Hz)')\n",
|
1012 |
+
"print('\\n')\n",
|
1013 |
+
"\n",
|
1014 |
+
"# show video\n",
|
1015 |
+
"media.show_video(frames, fps=framerate)"
|
1016 |
+
]
|
1017 |
+
},
|
1018 |
+
{
|
1019 |
+
"cell_type": "markdown",
|
1020 |
+
"metadata": {
|
1021 |
+
"id": "Iqi_m8HT-X5k"
|
1022 |
+
},
|
1023 |
+
"source": [
|
1024 |
+
"Note that rendering is **much** slower than the simulated physics.\n",
|
1025 |
+
"\n",
|
1026 |
+
"## Chaos\n",
|
1027 |
+
"This is a [chaotic](https://en.wikipedia.org/wiki/Chaos_theory) system (small pertubations in initial conditions accumulate quickly):"
|
1028 |
+
]
|
1029 |
+
},
|
1030 |
+
{
|
1031 |
+
"cell_type": "code",
|
1032 |
+
"execution_count": null,
|
1033 |
+
"metadata": {
|
1034 |
+
"id": "Pa_19EfvOzzg"
|
1035 |
+
},
|
1036 |
+
"outputs": [],
|
1037 |
+
"source": [
|
1038 |
+
"PERTURBATION = 1e-7\n",
|
1039 |
+
"SIM_DURATION = 10 # seconds\n",
|
1040 |
+
"NUM_REPEATS = 8\n",
|
1041 |
+
"\n",
|
1042 |
+
"# preallocate\n",
|
1043 |
+
"n_steps = int(SIM_DURATION / model.opt.timestep)\n",
|
1044 |
+
"sim_time = np.zeros(n_steps)\n",
|
1045 |
+
"angle = np.zeros(n_steps)\n",
|
1046 |
+
"energy = np.zeros(n_steps)\n",
|
1047 |
+
"\n",
|
1048 |
+
"# prepare plotting axes\n",
|
1049 |
+
"_, ax = plt.subplots(2, 1, figsize=(8, 6), sharex=True)\n",
|
1050 |
+
"\n",
|
1051 |
+
"# simulate NUM_REPEATS times with slightly different initial conditions\n",
|
1052 |
+
"for _ in range(NUM_REPEATS):\n",
|
1053 |
+
" # initialize\n",
|
1054 |
+
" mujoco.mj_resetData(model, data)\n",
|
1055 |
+
" data.qvel[0] = 10 # root joint velocity\n",
|
1056 |
+
" # perturb initial velocities\n",
|
1057 |
+
" data.qvel[:] += PERTURBATION * np.random.randn(model.nv)\n",
|
1058 |
+
"\n",
|
1059 |
+
" # simulate\n",
|
1060 |
+
" for i in range(n_steps):\n",
|
1061 |
+
" mujoco.mj_step(model, data)\n",
|
1062 |
+
" sim_time[i] = data.time\n",
|
1063 |
+
" angle[i] = data.joint('root').qpos\n",
|
1064 |
+
" energy[i] = data.energy[0] + data.energy[1]\n",
|
1065 |
+
"\n",
|
1066 |
+
" # plot\n",
|
1067 |
+
" ax[0].plot(sim_time, angle)\n",
|
1068 |
+
" ax[1].plot(sim_time, energy)\n",
|
1069 |
+
"\n",
|
1070 |
+
"# finalize plot\n",
|
1071 |
+
"ax[0].set_title('root angle')\n",
|
1072 |
+
"ax[0].set_ylabel('radian')\n",
|
1073 |
+
"ax[1].set_title('total energy')\n",
|
1074 |
+
"ax[1].set_ylabel('Joule')\n",
|
1075 |
+
"ax[1].set_xlabel('second')\n",
|
1076 |
+
"plt.tight_layout()"
|
1077 |
+
]
|
1078 |
+
},
|
1079 |
+
{
|
1080 |
+
"cell_type": "markdown",
|
1081 |
+
"metadata": {
|
1082 |
+
"id": "daSIA_ewFGxV"
|
1083 |
+
},
|
1084 |
+
"source": [
|
1085 |
+
"## Timestep and accuracy\n",
|
1086 |
+
"**Question:** Why is the energy varying at all? There is no friction or damping, this system should conserve energy.\n",
|
1087 |
+
"\n",
|
1088 |
+
"**Answer:** Because of the discretization of time.\n",
|
1089 |
+
"\n",
|
1090 |
+
"If we decrease the timestep we'll get better accuracy and better energy conservation:"
|
1091 |
+
]
|
1092 |
+
},
|
1093 |
+
{
|
1094 |
+
"cell_type": "code",
|
1095 |
+
"execution_count": null,
|
1096 |
+
"metadata": {
|
1097 |
+
"id": "4z-7KN_fFme-"
|
1098 |
+
},
|
1099 |
+
"outputs": [],
|
1100 |
+
"source": [
|
1101 |
+
"SIM_DURATION = 10 # (seconds)\n",
|
1102 |
+
"TIMESTEPS = np.power(10, np.linspace(-2, -4, 5))\n",
|
1103 |
+
"\n",
|
1104 |
+
"# prepare plotting axes\n",
|
1105 |
+
"_, ax = plt.subplots(1, 1)\n",
|
1106 |
+
"\n",
|
1107 |
+
"for dt in TIMESTEPS:\n",
|
1108 |
+
" # set timestep, print\n",
|
1109 |
+
" model.opt.timestep = dt\n",
|
1110 |
+
"\n",
|
1111 |
+
" # allocate\n",
|
1112 |
+
" n_steps = int(SIM_DURATION / model.opt.timestep)\n",
|
1113 |
+
" sim_time = np.zeros(n_steps)\n",
|
1114 |
+
" energy = np.zeros(n_steps)\n",
|
1115 |
+
"\n",
|
1116 |
+
" # initialize\n",
|
1117 |
+
" mujoco.mj_resetData(model, data)\n",
|
1118 |
+
" data.qvel[0] = 9 # root joint velocity\n",
|
1119 |
+
"\n",
|
1120 |
+
" # simulate\n",
|
1121 |
+
" print('{} steps at dt = {:2.2g}ms'.format(n_steps, 1000*dt))\n",
|
1122 |
+
" for i in range(n_steps):\n",
|
1123 |
+
" mujoco.mj_step(model, data)\n",
|
1124 |
+
" sim_time[i] = data.time\n",
|
1125 |
+
" energy[i] = data.energy[0] + data.energy[1]\n",
|
1126 |
+
"\n",
|
1127 |
+
" # plot\n",
|
1128 |
+
" ax.plot(sim_time, energy, label='timestep = {:2.2g}ms'.format(1000*dt))\n",
|
1129 |
+
"\n",
|
1130 |
+
"# finalize plot\n",
|
1131 |
+
"ax.set_title('energy')\n",
|
1132 |
+
"ax.set_ylabel('Joule')\n",
|
1133 |
+
"ax.set_xlabel('second')\n",
|
1134 |
+
"ax.legend(frameon=True);\n",
|
1135 |
+
"plt.tight_layout()"
|
1136 |
+
]
|
1137 |
+
},
|
1138 |
+
{
|
1139 |
+
"cell_type": "markdown",
|
1140 |
+
"metadata": {
|
1141 |
+
"id": "jsVkUm7QKb9I"
|
1142 |
+
},
|
1143 |
+
"source": [
|
1144 |
+
"## Timestep and divergence\n",
|
1145 |
+
"When we increase the time step, the simulation quickly diverges:"
|
1146 |
+
]
|
1147 |
+
},
|
1148 |
+
{
|
1149 |
+
"cell_type": "code",
|
1150 |
+
"execution_count": null,
|
1151 |
+
"metadata": {
|
1152 |
+
"id": "FbdUA4zDPbDP"
|
1153 |
+
},
|
1154 |
+
"outputs": [],
|
1155 |
+
"source": [
|
1156 |
+
"SIM_DURATION = 10 # (seconds)\n",
|
1157 |
+
"TIMESTEPS = np.power(10, np.linspace(-2, -1.5, 7))\n",
|
1158 |
+
"\n",
|
1159 |
+
"# get plotting axes\n",
|
1160 |
+
"ax = plt.gca()\n",
|
1161 |
+
"\n",
|
1162 |
+
"for dt in TIMESTEPS:\n",
|
1163 |
+
" # set timestep\n",
|
1164 |
+
" model.opt.timestep = dt\n",
|
1165 |
+
"\n",
|
1166 |
+
" # allocate\n",
|
1167 |
+
" n_steps = int(SIM_DURATION / model.opt.timestep)\n",
|
1168 |
+
" sim_time = np.zeros(n_steps)\n",
|
1169 |
+
" energy = np.zeros(n_steps) * np.nan\n",
|
1170 |
+
" speed = np.zeros(n_steps) * np.nan\n",
|
1171 |
+
"\n",
|
1172 |
+
" # initialize\n",
|
1173 |
+
" mujoco.mj_resetData(model, data)\n",
|
1174 |
+
" data.qvel[0] = 11 # set root joint velocity\n",
|
1175 |
+
"\n",
|
1176 |
+
" # simulate\n",
|
1177 |
+
" print('simulating {} steps at dt = {:2.2g}ms'.format(n_steps, 1000*dt))\n",
|
1178 |
+
" for i in range(n_steps):\n",
|
1179 |
+
" mujoco.mj_step(model, data)\n",
|
1180 |
+
" if data.warning.number.any():\n",
|
1181 |
+
" warning_index = np.nonzero(data.warning.number)[0][0]\n",
|
1182 |
+
" warning = mujoco.mjtWarning(warning_index).name\n",
|
1183 |
+
" print(f'stopped due to divergence ({warning}) at timestep {i}.\\n')\n",
|
1184 |
+
" break\n",
|
1185 |
+
" sim_time[i] = data.time\n",
|
1186 |
+
" energy[i] = sum(abs(data.qvel))\n",
|
1187 |
+
" speed[i] = np.linalg.norm(data.qvel)\n",
|
1188 |
+
"\n",
|
1189 |
+
" # plot\n",
|
1190 |
+
" ax.plot(sim_time, energy, label='timestep = {:2.2g}ms'.format(1000*dt))\n",
|
1191 |
+
" ax.set_yscale('log')\n",
|
1192 |
+
"\n",
|
1193 |
+
"# finalize plot\n",
|
1194 |
+
"ax.set_ybound(1, 1e3)\n",
|
1195 |
+
"ax.set_title('energy')\n",
|
1196 |
+
"ax.set_ylabel('Joule')\n",
|
1197 |
+
"ax.set_xlabel('second')\n",
|
1198 |
+
"ax.legend(frameon=True, loc='lower right');\n",
|
1199 |
+
"plt.tight_layout()"
|
1200 |
+
]
|
1201 |
+
},
|
1202 |
+
{
|
1203 |
+
"cell_type": "markdown",
|
1204 |
+
"metadata": {
|
1205 |
+
"id": "FITYfGyy3XPL"
|
1206 |
+
},
|
1207 |
+
"source": [
|
1208 |
+
"# Contacts\n",
|
1209 |
+
"\n",
|
1210 |
+
"Let's go back to our box and sphere example and give it a free joint:"
|
1211 |
+
]
|
1212 |
+
},
|
1213 |
+
{
|
1214 |
+
"cell_type": "code",
|
1215 |
+
"execution_count": null,
|
1216 |
+
"metadata": {
|
1217 |
+
"id": "2n1VNVv_FkbB"
|
1218 |
+
},
|
1219 |
+
"outputs": [],
|
1220 |
+
"source": [
|
1221 |
+
"free_body_MJCF = \"\"\"\n",
|
1222 |
+
"<mujoco>\n",
|
1223 |
+
" <asset>\n",
|
1224 |
+
" <texture name=\"grid\" type=\"2d\" builtin=\"checker\" rgb1=\".1 .2 .3\"\n",
|
1225 |
+
" rgb2=\".2 .3 .4\" width=\"300\" height=\"300\" mark=\"edge\" markrgb=\".2 .3 .4\"/>\n",
|
1226 |
+
" <material name=\"grid\" texture=\"grid\" texrepeat=\"2 2\" texuniform=\"true\"\n",
|
1227 |
+
" reflectance=\".2\"/>\n",
|
1228 |
+
" </asset>\n",
|
1229 |
+
"\n",
|
1230 |
+
" <worldbody>\n",
|
1231 |
+
" <light pos=\"0 0 1\" mode=\"trackcom\"/>\n",
|
1232 |
+
" <geom name=\"ground\" type=\"plane\" pos=\"0 0 -.5\" size=\"2 2 .1\" material=\"grid\" solimp=\".99 .99 .01\" solref=\".001 1\"/>\n",
|
1233 |
+
" <body name=\"box_and_sphere\" pos=\"0 0 0\">\n",
|
1234 |
+
" <freejoint/>\n",
|
1235 |
+
" <geom name=\"red_box\" type=\"box\" size=\".1 .1 .1\" rgba=\"1 0 0 1\" solimp=\".99 .99 .01\" solref=\".001 1\"/>\n",
|
1236 |
+
" <geom name=\"green_sphere\" size=\".06\" pos=\".1 .1 .1\" rgba=\"0 1 0 1\"/>\n",
|
1237 |
+
" <camera name=\"fixed\" pos=\"0 -.6 .3\" xyaxes=\"1 0 0 0 1 2\"/>\n",
|
1238 |
+
" <camera name=\"track\" pos=\"0 -.6 .3\" xyaxes=\"1 0 0 0 1 2\" mode=\"track\"/>\n",
|
1239 |
+
" </body>\n",
|
1240 |
+
" </worldbody>\n",
|
1241 |
+
"</mujoco>\n",
|
1242 |
+
"\"\"\"\n",
|
1243 |
+
"model = mujoco.MjModel.from_xml_string(free_body_MJCF)\n",
|
1244 |
+
"data = mujoco.MjData(model)\n",
|
1245 |
+
"height = 400\n",
|
1246 |
+
"width = 600\n",
|
1247 |
+
"\n",
|
1248 |
+
"with mujoco.Renderer(model, height, width) as renderer:\n",
|
1249 |
+
" mujoco.mj_forward(model, data)\n",
|
1250 |
+
" renderer.update_scene(data, \"fixed\")\n",
|
1251 |
+
"\n",
|
1252 |
+
" media.show_image(renderer.render())"
|
1253 |
+
]
|
1254 |
+
},
|
1255 |
+
{
|
1256 |
+
"cell_type": "markdown",
|
1257 |
+
"metadata": {
|
1258 |
+
"id": "Z2amdQCn8REu"
|
1259 |
+
},
|
1260 |
+
"source": [
|
1261 |
+
"Let render this body rolling on the floor, in slow-motion, while visualizing contact points and forces:"
|
1262 |
+
]
|
1263 |
+
},
|
1264 |
+
{
|
1265 |
+
"cell_type": "code",
|
1266 |
+
"execution_count": null,
|
1267 |
+
"metadata": {
|
1268 |
+
"id": "HlRhFs_d3WLP"
|
1269 |
+
},
|
1270 |
+
"outputs": [],
|
1271 |
+
"source": [
|
1272 |
+
"n_frames = 200\n",
|
1273 |
+
"height = 240\n",
|
1274 |
+
"width = 320\n",
|
1275 |
+
"frames = []\n",
|
1276 |
+
"\n",
|
1277 |
+
"# visualize contact frames and forces, make body transparent\n",
|
1278 |
+
"options = mujoco.MjvOption()\n",
|
1279 |
+
"mujoco.mjv_defaultOption(options)\n",
|
1280 |
+
"options.flags[mujoco.mjtVisFlag.mjVIS_CONTACTPOINT] = True\n",
|
1281 |
+
"options.flags[mujoco.mjtVisFlag.mjVIS_CONTACTFORCE] = True\n",
|
1282 |
+
"options.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = True\n",
|
1283 |
+
"\n",
|
1284 |
+
"# tweak scales of contact visualization elements\n",
|
1285 |
+
"model.vis.scale.contactwidth = 0.1\n",
|
1286 |
+
"model.vis.scale.contactheight = 0.03\n",
|
1287 |
+
"model.vis.scale.forcewidth = 0.05\n",
|
1288 |
+
"model.vis.map.force = 0.3\n",
|
1289 |
+
"\n",
|
1290 |
+
"# random initial rotational velocity:\n",
|
1291 |
+
"mujoco.mj_resetData(model, data)\n",
|
1292 |
+
"data.qvel[3:6] = 5*np.random.randn(3)\n",
|
1293 |
+
"\n",
|
1294 |
+
"# Simulate and display video.\n",
|
1295 |
+
"with mujoco.Renderer(model, height, width) as renderer:\n",
|
1296 |
+
" for i in range(n_frames):\n",
|
1297 |
+
" while data.time < i/120.0: #1/4x real time\n",
|
1298 |
+
" mujoco.mj_step(model, data)\n",
|
1299 |
+
" renderer.update_scene(data, \"track\", options)\n",
|
1300 |
+
" frame = renderer.render()\n",
|
1301 |
+
" frames.append(frame)\n",
|
1302 |
+
"\n",
|
1303 |
+
"media.show_video(frames, fps=30)"
|
1304 |
+
]
|
1305 |
+
},
|
1306 |
+
{
|
1307 |
+
"cell_type": "markdown",
|
1308 |
+
"metadata": {
|
1309 |
+
"id": "_181TbtVSMBl"
|
1310 |
+
},
|
1311 |
+
"source": [
|
1312 |
+
"## Analysis of contact forces\n",
|
1313 |
+
"\n",
|
1314 |
+
"Let's rerun the above simulation (with a different random initial condition) and\n",
|
1315 |
+
"plot some values related to the contacts"
|
1316 |
+
]
|
1317 |
+
},
|
1318 |
+
{
|
1319 |
+
"cell_type": "code",
|
1320 |
+
"execution_count": null,
|
1321 |
+
"metadata": {
|
1322 |
+
"id": "BMqyWeHki8Eg"
|
1323 |
+
},
|
1324 |
+
"outputs": [],
|
1325 |
+
"source": [
|
1326 |
+
"n_steps = 499\n",
|
1327 |
+
"\n",
|
1328 |
+
"# allocate\n",
|
1329 |
+
"sim_time = np.zeros(n_steps)\n",
|
1330 |
+
"ncon = np.zeros(n_steps)\n",
|
1331 |
+
"force = np.zeros((n_steps,3))\n",
|
1332 |
+
"velocity = np.zeros((n_steps, model.nv))\n",
|
1333 |
+
"penetration = np.zeros(n_steps)\n",
|
1334 |
+
"acceleration = np.zeros((n_steps, model.nv))\n",
|
1335 |
+
"forcetorque = np.zeros(6)\n",
|
1336 |
+
"\n",
|
1337 |
+
"# random initial rotational velocity:\n",
|
1338 |
+
"mujoco.mj_resetData(model, data)\n",
|
1339 |
+
"data.qvel[3:6] = 2*np.random.randn(3)\n",
|
1340 |
+
"\n",
|
1341 |
+
"# simulate and save data\n",
|
1342 |
+
"for i in range(n_steps):\n",
|
1343 |
+
" mujoco.mj_step(model, data)\n",
|
1344 |
+
" sim_time[i] = data.time\n",
|
1345 |
+
" ncon[i] = data.ncon\n",
|
1346 |
+
" velocity[i] = data.qvel[:]\n",
|
1347 |
+
" acceleration[i] = data.qacc[:]\n",
|
1348 |
+
" # iterate over active contacts, save force and distance\n",
|
1349 |
+
" for j,c in enumerate(data.contact):\n",
|
1350 |
+
" mujoco.mj_contactForce(model, data, j, forcetorque)\n",
|
1351 |
+
" force[i] += forcetorque[0:3]\n",
|
1352 |
+
" penetration[i] = min(penetration[i], c.dist)\n",
|
1353 |
+
" # we could also do\n",
|
1354 |
+
" # force[i] += data.qfrc_constraint[0:3]\n",
|
1355 |
+
" # do you see why?\n",
|
1356 |
+
"\n",
|
1357 |
+
"# plot\n",
|
1358 |
+
"_, ax = plt.subplots(3, 2, sharex=True, figsize=(10, 10))\n",
|
1359 |
+
"\n",
|
1360 |
+
"lines = ax[0,0].plot(sim_time, force)\n",
|
1361 |
+
"ax[0,0].set_title('contact force')\n",
|
1362 |
+
"ax[0,0].set_ylabel('Newton')\n",
|
1363 |
+
"ax[0,0].legend(lines, ('normal z', 'friction x', 'friction y'));\n",
|
1364 |
+
"\n",
|
1365 |
+
"ax[1,0].plot(sim_time, acceleration)\n",
|
1366 |
+
"ax[1,0].set_title('acceleration')\n",
|
1367 |
+
"ax[1,0].set_ylabel('(meter,radian)/s/s')\n",
|
1368 |
+
"ax[1,0].legend(['ax', 'ay', 'az', 'αx', 'αy', 'αz'])\n",
|
1369 |
+
"\n",
|
1370 |
+
"ax[2,0].plot(sim_time, velocity)\n",
|
1371 |
+
"ax[2,0].set_title('velocity')\n",
|
1372 |
+
"ax[2,0].set_ylabel('(meter,radian)/s')\n",
|
1373 |
+
"ax[2,0].set_xlabel('second')\n",
|
1374 |
+
"ax[2,0].legend(['vx', 'vy', 'vz', 'ωx', 'ωy', 'ωz'])\n",
|
1375 |
+
"\n",
|
1376 |
+
"ax[0,1].plot(sim_time, ncon)\n",
|
1377 |
+
"ax[0,1].set_title('number of contacts')\n",
|
1378 |
+
"ax[0,1].set_yticks(range(6))\n",
|
1379 |
+
"\n",
|
1380 |
+
"ax[1,1].plot(sim_time, force[:,0])\n",
|
1381 |
+
"ax[1,1].set_yscale('log')\n",
|
1382 |
+
"ax[1,1].set_title('normal (z) force - log scale')\n",
|
1383 |
+
"ax[1,1].set_ylabel('Newton')\n",
|
1384 |
+
"z_gravity = -model.opt.gravity[2]\n",
|
1385 |
+
"mg = model.body(\"box_and_sphere\").mass[0] * z_gravity\n",
|
1386 |
+
"mg_line = ax[1,1].plot(sim_time, np.ones(n_steps)*mg, label='m*g', linewidth=1)\n",
|
1387 |
+
"ax[1,1].legend()\n",
|
1388 |
+
"\n",
|
1389 |
+
"ax[2,1].plot(sim_time, 1000*penetration)\n",
|
1390 |
+
"ax[2,1].set_title('penetration depth')\n",
|
1391 |
+
"ax[2,1].set_ylabel('millimeter')\n",
|
1392 |
+
"ax[2,1].set_xlabel('second')\n",
|
1393 |
+
"\n",
|
1394 |
+
"plt.tight_layout()"
|
1395 |
+
]
|
1396 |
+
},
|
1397 |
+
{
|
1398 |
+
"cell_type": "markdown",
|
1399 |
+
"metadata": {
|
1400 |
+
"id": "zV5PkYzFXu42"
|
1401 |
+
},
|
1402 |
+
"source": [
|
1403 |
+
"## Friction\n",
|
1404 |
+
"\n",
|
1405 |
+
"Let's see the effect of changing friction values"
|
1406 |
+
]
|
1407 |
+
},
|
1408 |
+
{
|
1409 |
+
"cell_type": "code",
|
1410 |
+
"execution_count": null,
|
1411 |
+
"metadata": {
|
1412 |
+
"id": "2R_gKoYyXwda"
|
1413 |
+
},
|
1414 |
+
"outputs": [],
|
1415 |
+
"source": [
|
1416 |
+
"MJCF = \"\"\"\n",
|
1417 |
+
"<mujoco>\n",
|
1418 |
+
" <asset>\n",
|
1419 |
+
" <texture name=\"grid\" type=\"2d\" builtin=\"checker\" rgb1=\".1 .2 .3\"\n",
|
1420 |
+
" rgb2=\".2 .3 .4\" width=\"300\" height=\"300\" mark=\"none\"/>\n",
|
1421 |
+
" <material name=\"grid\" texture=\"grid\" texrepeat=\"6 6\"\n",
|
1422 |
+
" texuniform=\"true\" reflectance=\".2\"/>\n",
|
1423 |
+
" <material name=\"wall\" rgba='.5 .5 .5 1'/>\n",
|
1424 |
+
" </asset>\n",
|
1425 |
+
"\n",
|
1426 |
+
" <default>\n",
|
1427 |
+
" <geom type=\"box\" size=\".05 .05 .05\" />\n",
|
1428 |
+
" <joint type=\"free\"/>\n",
|
1429 |
+
" </default>\n",
|
1430 |
+
"\n",
|
1431 |
+
" <worldbody>\n",
|
1432 |
+
" <light name=\"light\" pos=\"-.2 0 1\"/>\n",
|
1433 |
+
" <geom name=\"ground\" type=\"plane\" size=\".5 .5 10\" material=\"grid\"\n",
|
1434 |
+
" zaxis=\"-.3 0 1\" friction=\".1\"/>\n",
|
1435 |
+
" <camera name=\"y\" pos=\"-.1 -.6 .3\" xyaxes=\"1 0 0 0 1 2\"/>\n",
|
1436 |
+
" <body pos=\"0 0 .1\">\n",
|
1437 |
+
" <joint/>\n",
|
1438 |
+
" <geom/>\n",
|
1439 |
+
" </body>\n",
|
1440 |
+
" <body pos=\"0 .2 .1\">\n",
|
1441 |
+
" <joint/>\n",
|
1442 |
+
" <geom friction=\".33\"/>\n",
|
1443 |
+
" </body>\n",
|
1444 |
+
" </worldbody>\n",
|
1445 |
+
"\n",
|
1446 |
+
"</mujoco>\n",
|
1447 |
+
"\"\"\"\n",
|
1448 |
+
"n_frames = 60\n",
|
1449 |
+
"height = 300\n",
|
1450 |
+
"width = 300\n",
|
1451 |
+
"frames = []\n",
|
1452 |
+
"\n",
|
1453 |
+
"# load\n",
|
1454 |
+
"model = mujoco.MjModel.from_xml_string(MJCF)\n",
|
1455 |
+
"data = mujoco.MjData(model)\n",
|
1456 |
+
"\n",
|
1457 |
+
"# Simulate and display video.\n",
|
1458 |
+
"with mujoco.Renderer(model, height, width) as renderer:\n",
|
1459 |
+
" mujoco.mj_resetData(model, data)\n",
|
1460 |
+
" for i in range(n_frames):\n",
|
1461 |
+
" while data.time < i/30.0:\n",
|
1462 |
+
" mujoco.mj_step(model, data)\n",
|
1463 |
+
" renderer.update_scene(data, \"y\")\n",
|
1464 |
+
" frame = renderer.render()\n",
|
1465 |
+
" frames.append(frame)\n",
|
1466 |
+
"\n",
|
1467 |
+
"media.show_video(frames, fps=30)"
|
1468 |
+
]
|
1469 |
+
},
|
1470 |
+
{
|
1471 |
+
"cell_type": "markdown",
|
1472 |
+
"metadata": {
|
1473 |
+
"id": "ArmmaPqGP6W7"
|
1474 |
+
},
|
1475 |
+
"source": [
|
1476 |
+
"# Tendons, actuators and sensors"
|
1477 |
+
]
|
1478 |
+
},
|
1479 |
+
{
|
1480 |
+
"cell_type": "code",
|
1481 |
+
"execution_count": null,
|
1482 |
+
"metadata": {
|
1483 |
+
"id": "VJz84c97c8Df"
|
1484 |
+
},
|
1485 |
+
"outputs": [],
|
1486 |
+
"source": [
|
1487 |
+
"MJCF = \"\"\"\n",
|
1488 |
+
"<mujoco>\n",
|
1489 |
+
" <asset>\n",
|
1490 |
+
" <texture name=\"grid\" type=\"2d\" builtin=\"checker\" rgb1=\".1 .2 .3\"\n",
|
1491 |
+
" rgb2=\".2 .3 .4\" width=\"300\" height=\"300\" mark=\"none\"/>\n",
|
1492 |
+
" <material name=\"grid\" texture=\"grid\" texrepeat=\"1 1\"\n",
|
1493 |
+
" texuniform=\"true\" reflectance=\".2\"/>\n",
|
1494 |
+
" </asset>\n",
|
1495 |
+
"\n",
|
1496 |
+
" <worldbody>\n",
|
1497 |
+
" <light name=\"light\" pos=\"0 0 1\"/>\n",
|
1498 |
+
" <geom name=\"floor\" type=\"plane\" pos=\"0 0 -.5\" size=\"2 2 .1\" material=\"grid\"/>\n",
|
1499 |
+
" <site name=\"anchor\" pos=\"0 0 .3\" size=\".01\"/>\n",
|
1500 |
+
" <camera name=\"fixed\" pos=\"0 -1.3 .5\" xyaxes=\"1 0 0 0 1 2\"/>\n",
|
1501 |
+
"\n",
|
1502 |
+
" <geom name=\"pole\" type=\"cylinder\" fromto=\".3 0 -.5 .3 0 -.1\" size=\".04\"/>\n",
|
1503 |
+
" <body name=\"bat\" pos=\".3 0 -.1\">\n",
|
1504 |
+
" <joint name=\"swing\" type=\"hinge\" damping=\"1\" axis=\"0 0 1\"/>\n",
|
1505 |
+
" <geom name=\"bat\" type=\"capsule\" fromto=\"0 0 .04 0 -.3 .04\"\n",
|
1506 |
+
" size=\".04\" rgba=\"0 0 1 1\"/>\n",
|
1507 |
+
" </body>\n",
|
1508 |
+
"\n",
|
1509 |
+
" <body name=\"box_and_sphere\" pos=\"0 0 0\">\n",
|
1510 |
+
" <joint name=\"free\" type=\"free\"/>\n",
|
1511 |
+
" <geom name=\"red_box\" type=\"box\" size=\".1 .1 .1\" rgba=\"1 0 0 1\"/>\n",
|
1512 |
+
" <geom name=\"green_sphere\" size=\".06\" pos=\".1 .1 .1\" rgba=\"0 1 0 1\"/>\n",
|
1513 |
+
" <site name=\"hook\" pos=\"-.1 -.1 -.1\" size=\".01\"/>\n",
|
1514 |
+
" <site name=\"IMU\"/>\n",
|
1515 |
+
" </body>\n",
|
1516 |
+
" </worldbody>\n",
|
1517 |
+
"\n",
|
1518 |
+
" <tendon>\n",
|
1519 |
+
" <spatial name=\"wire\" limited=\"true\" range=\"0 0.35\" width=\"0.003\">\n",
|
1520 |
+
" <site site=\"anchor\"/>\n",
|
1521 |
+
" <site site=\"hook\"/>\n",
|
1522 |
+
" </spatial>\n",
|
1523 |
+
" </tendon>\n",
|
1524 |
+
"\n",
|
1525 |
+
" <actuator>\n",
|
1526 |
+
" <motor name=\"my_motor\" joint=\"swing\" gear=\"1\"/>\n",
|
1527 |
+
" </actuator>\n",
|
1528 |
+
"\n",
|
1529 |
+
" <sensor>\n",
|
1530 |
+
" <accelerometer name=\"accelerometer\" site=\"IMU\"/>\n",
|
1531 |
+
" </sensor>\n",
|
1532 |
+
"</mujoco>\n",
|
1533 |
+
"\"\"\"\n",
|
1534 |
+
"model = mujoco.MjModel.from_xml_string(MJCF)\n",
|
1535 |
+
"data = mujoco.MjData(model)\n",
|
1536 |
+
"height = 480\n",
|
1537 |
+
"width = 480\n",
|
1538 |
+
"\n",
|
1539 |
+
"with mujoco.Renderer(model, height, width) as renderer:\n",
|
1540 |
+
" mujoco.mj_forward(model, data)\n",
|
1541 |
+
" renderer.update_scene(data, \"fixed\")\n",
|
1542 |
+
"\n",
|
1543 |
+
" media.show_image(renderer.render())"
|
1544 |
+
]
|
1545 |
+
},
|
1546 |
+
{
|
1547 |
+
"cell_type": "markdown",
|
1548 |
+
"metadata": {
|
1549 |
+
"id": "u8z2vrOr_RVD"
|
1550 |
+
},
|
1551 |
+
"source": [
|
1552 |
+
"actuated bat and passive \"piñata\":"
|
1553 |
+
]
|
1554 |
+
},
|
1555 |
+
{
|
1556 |
+
"cell_type": "code",
|
1557 |
+
"execution_count": null,
|
1558 |
+
"metadata": {
|
1559 |
+
"id": "z-zoBCuBv2Xi"
|
1560 |
+
},
|
1561 |
+
"outputs": [],
|
1562 |
+
"source": [
|
1563 |
+
"n_frames = 180\n",
|
1564 |
+
"height = 240\n",
|
1565 |
+
"width = 320\n",
|
1566 |
+
"frames = []\n",
|
1567 |
+
"fps = 60.0\n",
|
1568 |
+
"times = []\n",
|
1569 |
+
"sensordata = []\n",
|
1570 |
+
"\n",
|
1571 |
+
"# constant actuator signal\n",
|
1572 |
+
"mujoco.mj_resetData(model, data)\n",
|
1573 |
+
"data.ctrl = 20\n",
|
1574 |
+
"\n",
|
1575 |
+
"# Simulate and display video.\n",
|
1576 |
+
"with mujoco.Renderer(model, height, width) as renderer:\n",
|
1577 |
+
" for i in range(n_frames):\n",
|
1578 |
+
" while data.time < i/fps:\n",
|
1579 |
+
" mujoco.mj_step(model, data)\n",
|
1580 |
+
" times.append(data.time)\n",
|
1581 |
+
" sensordata.append(data.sensor('accelerometer').data.copy())\n",
|
1582 |
+
" renderer.update_scene(data, \"fixed\")\n",
|
1583 |
+
" frame = renderer.render()\n",
|
1584 |
+
" frames.append(frame)\n",
|
1585 |
+
"\n",
|
1586 |
+
"media.show_video(frames, fps=fps)"
|
1587 |
+
]
|
1588 |
+
},
|
1589 |
+
{
|
1590 |
+
"cell_type": "markdown",
|
1591 |
+
"metadata": {
|
1592 |
+
"id": "gwHMy_iRA7Jh"
|
1593 |
+
},
|
1594 |
+
"source": [
|
1595 |
+
"Let's plot the values measured by our accelerometer sensor:"
|
1596 |
+
]
|
1597 |
+
},
|
1598 |
+
{
|
1599 |
+
"cell_type": "code",
|
1600 |
+
"execution_count": null,
|
1601 |
+
"metadata": {
|
1602 |
+
"id": "uy4wSEMAAJgn"
|
1603 |
+
},
|
1604 |
+
"outputs": [],
|
1605 |
+
"source": [
|
1606 |
+
"ax = plt.gca()\n",
|
1607 |
+
"\n",
|
1608 |
+
"ax.plot(np.asarray(times), np.asarray(sensordata), label=[f\"axis {v}\" for v in ['x', 'y', 'z']])\n",
|
1609 |
+
"\n",
|
1610 |
+
"# finalize plot\n",
|
1611 |
+
"ax.set_title('Accelerometer values')\n",
|
1612 |
+
"ax.set_ylabel('meter/second^2')\n",
|
1613 |
+
"ax.set_xlabel('second')\n",
|
1614 |
+
"ax.legend(frameon=True, loc='lower right')\n",
|
1615 |
+
"plt.tight_layout()"
|
1616 |
+
]
|
1617 |
+
},
|
1618 |
+
{
|
1619 |
+
"cell_type": "markdown",
|
1620 |
+
"metadata": {
|
1621 |
+
"id": "0YKSTtJ_BQ7x"
|
1622 |
+
},
|
1623 |
+
"source": [
|
1624 |
+
"Note how the moments when the body is hit by the bat are clearly visible in the accelerometer measurements."
|
1625 |
+
]
|
1626 |
+
},
|
1627 |
+
{
|
1628 |
+
"cell_type": "markdown",
|
1629 |
+
"metadata": {
|
1630 |
+
"id": "1kOs1wTc7uCZ"
|
1631 |
+
},
|
1632 |
+
"source": [
|
1633 |
+
"# Advanced rendering\n",
|
1634 |
+
"\n",
|
1635 |
+
"Like joint visualization, additional rendering options are exposed as parameters to the `render` method.\n",
|
1636 |
+
"\n",
|
1637 |
+
"Let's bring back our first model:"
|
1638 |
+
]
|
1639 |
+
},
|
1640 |
+
{
|
1641 |
+
"cell_type": "code",
|
1642 |
+
"execution_count": null,
|
1643 |
+
"metadata": {
|
1644 |
+
"id": "mTDgsk2xcgwH"
|
1645 |
+
},
|
1646 |
+
"outputs": [],
|
1647 |
+
"source": [
|
1648 |
+
"xml = \"\"\"\n",
|
1649 |
+
"<mujoco>\n",
|
1650 |
+
" <worldbody>\n",
|
1651 |
+
" <light name=\"top\" pos=\"0 0 1\"/>\n",
|
1652 |
+
" <body name=\"box_and_sphere\" euler=\"0 0 -30\">\n",
|
1653 |
+
" <joint name=\"swing\" type=\"hinge\" axis=\"1 -1 0\" pos=\"-.2 -.2 -.2\"/>\n",
|
1654 |
+
" <geom name=\"red_box\" type=\"box\" size=\".2 .2 .2\" rgba=\"1 0 0 1\"/>\n",
|
1655 |
+
" <geom name=\"green_sphere\" pos=\".2 .2 .2\" size=\".1\" rgba=\"0 1 0 1\"/>\n",
|
1656 |
+
" </body>\n",
|
1657 |
+
" </worldbody>\n",
|
1658 |
+
"</mujoco>\n",
|
1659 |
+
"\"\"\"\n",
|
1660 |
+
"model = mujoco.MjModel.from_xml_string(xml)\n",
|
1661 |
+
"data = mujoco.MjData(model)\n",
|
1662 |
+
"\n",
|
1663 |
+
"with mujoco.Renderer(model) as renderer:\n",
|
1664 |
+
" mujoco.mj_forward(model, data)\n",
|
1665 |
+
" renderer.update_scene(data)\n",
|
1666 |
+
" media.show_image(renderer.render())"
|
1667 |
+
]
|
1668 |
+
},
|
1669 |
+
{
|
1670 |
+
"cell_type": "code",
|
1671 |
+
"execution_count": null,
|
1672 |
+
"metadata": {
|
1673 |
+
"id": "VePXamL_6XUc"
|
1674 |
+
},
|
1675 |
+
"outputs": [],
|
1676 |
+
"source": [
|
1677 |
+
"#@title Enable transparency and frame visualization {vertical-output: true}\n",
|
1678 |
+
"\n",
|
1679 |
+
"scene_option.frame = mujoco.mjtFrame.mjFRAME_GEOM\n",
|
1680 |
+
"scene_option.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = True\n",
|
1681 |
+
"with mujoco.Renderer(model) as renderer:\n",
|
1682 |
+
" renderer.update_scene(data, scene_option=scene_option)\n",
|
1683 |
+
" frame = renderer.render()\n",
|
1684 |
+
" media.show_image(frame)"
|
1685 |
+
]
|
1686 |
+
},
|
1687 |
+
{
|
1688 |
+
"cell_type": "code",
|
1689 |
+
"execution_count": null,
|
1690 |
+
"metadata": {
|
1691 |
+
"id": "PVcpcvww9lZ8"
|
1692 |
+
},
|
1693 |
+
"outputs": [],
|
1694 |
+
"source": [
|
1695 |
+
"#@title Depth rendering {vertical-output: true}\n",
|
1696 |
+
"\n",
|
1697 |
+
"with mujoco.Renderer(model) as renderer:\n",
|
1698 |
+
" # update renderer to render depth\n",
|
1699 |
+
" renderer.enable_depth_rendering()\n",
|
1700 |
+
"\n",
|
1701 |
+
" # reset the scene\n",
|
1702 |
+
" renderer.update_scene(data)\n",
|
1703 |
+
"\n",
|
1704 |
+
" # depth is a float array, in meters.\n",
|
1705 |
+
" depth = renderer.render()\n",
|
1706 |
+
"\n",
|
1707 |
+
" # Shift nearest values to the origin.\n",
|
1708 |
+
" depth -= depth.min()\n",
|
1709 |
+
" # Scale by 2 mean distances of near rays.\n",
|
1710 |
+
" depth /= 2*depth[depth <= 1].mean()\n",
|
1711 |
+
" # Scale to [0, 255]\n",
|
1712 |
+
" pixels = 255*np.clip(depth, 0, 1)\n",
|
1713 |
+
"\n",
|
1714 |
+
" media.show_image(pixels.astype(np.uint8))"
|
1715 |
+
]
|
1716 |
+
},
|
1717 |
+
{
|
1718 |
+
"cell_type": "code",
|
1719 |
+
"execution_count": null,
|
1720 |
+
"metadata": {
|
1721 |
+
"id": "PNwiIrgpx7T8"
|
1722 |
+
},
|
1723 |
+
"outputs": [],
|
1724 |
+
"source": [
|
1725 |
+
"#@title Segmentation rendering {vertical-output: true}\n",
|
1726 |
+
"\n",
|
1727 |
+
"with mujoco.Renderer(model) as renderer:\n",
|
1728 |
+
" renderer.disable_depth_rendering()\n",
|
1729 |
+
"\n",
|
1730 |
+
" # update renderer to render segmentation\n",
|
1731 |
+
" renderer.enable_segmentation_rendering()\n",
|
1732 |
+
"\n",
|
1733 |
+
" # reset the scene\n",
|
1734 |
+
" renderer.update_scene(data)\n",
|
1735 |
+
"\n",
|
1736 |
+
" seg = renderer.render()\n",
|
1737 |
+
"\n",
|
1738 |
+
" # Display the contents of the first channel, which contains object\n",
|
1739 |
+
" # IDs. The second channel, seg[:, :, 1], contains object types.\n",
|
1740 |
+
" geom_ids = seg[:, :, 0]\n",
|
1741 |
+
" # Infinity is mapped to -1\n",
|
1742 |
+
" geom_ids = geom_ids.astype(np.float64) + 1\n",
|
1743 |
+
" # Scale to [0, 1]\n",
|
1744 |
+
" geom_ids = geom_ids / geom_ids.max()\n",
|
1745 |
+
" pixels = 255*geom_ids\n",
|
1746 |
+
" media.show_image(pixels.astype(np.uint8))"
|
1747 |
+
]
|
1748 |
+
},
|
1749 |
+
{
|
1750 |
+
"cell_type": "markdown",
|
1751 |
+
"metadata": {
|
1752 |
+
"id": "wo72mo0mGIXr"
|
1753 |
+
},
|
1754 |
+
"source": [
|
1755 |
+
"## The camera matrix\n",
|
1756 |
+
"\n",
|
1757 |
+
"For a description of the camera matrix see the article [Camera matrix](https://en.wikipedia.org/wiki/Camera_matrix) on Wikipedia."
|
1758 |
+
]
|
1759 |
+
},
|
1760 |
+
{
|
1761 |
+
"cell_type": "code",
|
1762 |
+
"execution_count": null,
|
1763 |
+
"metadata": {
|
1764 |
+
"id": "sDYwClpxaxab"
|
1765 |
+
},
|
1766 |
+
"outputs": [],
|
1767 |
+
"source": [
|
1768 |
+
"def compute_camera_matrix(renderer, data):\n",
|
1769 |
+
" \"\"\"Returns the 3x4 camera matrix.\"\"\"\n",
|
1770 |
+
" # If the camera is a 'free' camera, we get its position and orientation\n",
|
1771 |
+
" # from the scene data structure. It is a stereo camera, so we average over\n",
|
1772 |
+
" # the left and right channels. Note: we call `self.update()` in order to\n",
|
1773 |
+
" # ensure that the contents of `scene.camera` are correct.\n",
|
1774 |
+
" renderer.update_scene(data)\n",
|
1775 |
+
" pos = np.mean([camera.pos for camera in renderer.scene.camera], axis=0)\n",
|
1776 |
+
" z = -np.mean([camera.forward for camera in renderer.scene.camera], axis=0)\n",
|
1777 |
+
" y = np.mean([camera.up for camera in renderer.scene.camera], axis=0)\n",
|
1778 |
+
" rot = np.vstack((np.cross(y, z), y, z))\n",
|
1779 |
+
" fov = model.vis.global_.fovy\n",
|
1780 |
+
"\n",
|
1781 |
+
" # Translation matrix (4x4).\n",
|
1782 |
+
" translation = np.eye(4)\n",
|
1783 |
+
" translation[0:3, 3] = -pos\n",
|
1784 |
+
"\n",
|
1785 |
+
" # Rotation matrix (4x4).\n",
|
1786 |
+
" rotation = np.eye(4)\n",
|
1787 |
+
" rotation[0:3, 0:3] = rot\n",
|
1788 |
+
"\n",
|
1789 |
+
" # Focal transformation matrix (3x4).\n",
|
1790 |
+
" focal_scaling = (1./np.tan(np.deg2rad(fov)/2)) * renderer.height / 2.0\n",
|
1791 |
+
" focal = np.diag([-focal_scaling, focal_scaling, 1.0, 0])[0:3, :]\n",
|
1792 |
+
"\n",
|
1793 |
+
" # Image matrix (3x3).\n",
|
1794 |
+
" image = np.eye(3)\n",
|
1795 |
+
" image[0, 2] = (renderer.width - 1) / 2.0\n",
|
1796 |
+
" image[1, 2] = (renderer.height - 1) / 2.0\n",
|
1797 |
+
" return image @ focal @ rotation @ translation"
|
1798 |
+
]
|
1799 |
+
},
|
1800 |
+
{
|
1801 |
+
"cell_type": "code",
|
1802 |
+
"execution_count": null,
|
1803 |
+
"metadata": {
|
1804 |
+
"id": "My0N4_7PDJ_q"
|
1805 |
+
},
|
1806 |
+
"outputs": [],
|
1807 |
+
"source": [
|
1808 |
+
"#@title Project from world to camera coordinates {vertical-output: true}\n",
|
1809 |
+
"\n",
|
1810 |
+
"with mujoco.Renderer(model) as renderer:\n",
|
1811 |
+
" renderer.disable_segmentation_rendering()\n",
|
1812 |
+
" # reset the scene\n",
|
1813 |
+
" renderer.update_scene(data)\n",
|
1814 |
+
"\n",
|
1815 |
+
" # Get the world coordinates of the box corners\n",
|
1816 |
+
" box_pos = data.geom_xpos[model.geom('red_box').id]\n",
|
1817 |
+
" box_mat = data.geom_xmat[model.geom('red_box').id].reshape(3, 3)\n",
|
1818 |
+
" box_size = model.geom_size[model.geom('red_box').id]\n",
|
1819 |
+
" offsets = np.array([-1, 1]) * box_size[:, None]\n",
|
1820 |
+
" xyz_local = np.stack(list(itertools.product(*offsets))).T\n",
|
1821 |
+
" xyz_global = box_pos[:, None] + box_mat @ xyz_local\n",
|
1822 |
+
"\n",
|
1823 |
+
" # Camera matrices multiply homogenous [x, y, z, 1] vectors.\n",
|
1824 |
+
" corners_homogeneous = np.ones((4, xyz_global.shape[1]), dtype=float)\n",
|
1825 |
+
" corners_homogeneous[:3, :] = xyz_global\n",
|
1826 |
+
"\n",
|
1827 |
+
" # Get the camera matrix.\n",
|
1828 |
+
" m = compute_camera_matrix(renderer, data)\n",
|
1829 |
+
"\n",
|
1830 |
+
" # Project world coordinates into pixel space. See:\n",
|
1831 |
+
" # https://en.wikipedia.org/wiki/3D_projection#Mathematical_formula\n",
|
1832 |
+
" xs, ys, s = m @ corners_homogeneous\n",
|
1833 |
+
" # x and y are in the pixel coordinate system.\n",
|
1834 |
+
" x = xs / s\n",
|
1835 |
+
" y = ys / s\n",
|
1836 |
+
"\n",
|
1837 |
+
" # Render the camera view and overlay the projected corner coordinates.\n",
|
1838 |
+
" pixels = renderer.render()\n",
|
1839 |
+
" fig, ax = plt.subplots(1, 1)\n",
|
1840 |
+
" ax.imshow(pixels)\n",
|
1841 |
+
" ax.plot(x, y, '+', c='w')\n",
|
1842 |
+
" ax.set_axis_off()"
|
1843 |
+
]
|
1844 |
+
},
|
1845 |
+
{
|
1846 |
+
"cell_type": "markdown",
|
1847 |
+
"metadata": {
|
1848 |
+
"id": "AGm5-e0sHEAF"
|
1849 |
+
},
|
1850 |
+
"source": [
|
1851 |
+
"## Modifying the scene\n",
|
1852 |
+
"\n",
|
1853 |
+
"Let's add some arbitrary geometry to the `mjvScene`."
|
1854 |
+
]
|
1855 |
+
},
|
1856 |
+
{
|
1857 |
+
"cell_type": "code",
|
1858 |
+
"execution_count": null,
|
1859 |
+
"metadata": {
|
1860 |
+
"id": "Z6NDYJ8IOVt7"
|
1861 |
+
},
|
1862 |
+
"outputs": [],
|
1863 |
+
"source": [
|
1864 |
+
"def get_geom_speed(model, data, geom_name):\n",
|
1865 |
+
" \"\"\"Returns the speed of a geom.\"\"\"\n",
|
1866 |
+
" geom_vel = np.zeros(6)\n",
|
1867 |
+
" geom_type = mujoco.mjtObj.mjOBJ_GEOM\n",
|
1868 |
+
" geom_id = data.geom(geom_name).id\n",
|
1869 |
+
" mujoco.mj_objectVelocity(model, data, geom_type, geom_id, geom_vel, 0)\n",
|
1870 |
+
" return np.linalg.norm(geom_vel)\n",
|
1871 |
+
"\n",
|
1872 |
+
"def add_visual_capsule(scene, point1, point2, radius, rgba):\n",
|
1873 |
+
" \"\"\"Adds one capsule to an mjvScene.\"\"\"\n",
|
1874 |
+
" if scene.ngeom >= scene.maxgeom:\n",
|
1875 |
+
" return\n",
|
1876 |
+
" scene.ngeom += 1 # increment ngeom\n",
|
1877 |
+
" # initialise a new capsule, add it to the scene using mjv_connector\n",
|
1878 |
+
" mujoco.mjv_initGeom(scene.geoms[scene.ngeom-1],\n",
|
1879 |
+
" mujoco.mjtGeom.mjGEOM_CAPSULE, np.zeros(3),\n",
|
1880 |
+
" np.zeros(3), np.zeros(9), rgba.astype(np.float32))\n",
|
1881 |
+
" mujoco.mjv_connector(scene.geoms[scene.ngeom-1],\n",
|
1882 |
+
" mujoco.mjtGeom.mjGEOM_CAPSULE, radius,\n",
|
1883 |
+
" point1, point2)\n",
|
1884 |
+
"\n",
|
1885 |
+
" # traces of time, position and speed\n",
|
1886 |
+
"times = []\n",
|
1887 |
+
"positions = []\n",
|
1888 |
+
"speeds = []\n",
|
1889 |
+
"offset = model.jnt_axis[0]/16 # offset along the joint axis\n",
|
1890 |
+
"\n",
|
1891 |
+
"def modify_scene(scn):\n",
|
1892 |
+
" \"\"\"Draw position trace, speed modifies width and colors.\"\"\"\n",
|
1893 |
+
" if len(positions) > 1:\n",
|
1894 |
+
" for i in range(len(positions)-1):\n",
|
1895 |
+
" rgba=np.array((np.clip(speeds[i]/10, 0, 1),\n",
|
1896 |
+
" np.clip(1-speeds[i]/10, 0, 1),\n",
|
1897 |
+
" .5, 1.))\n",
|
1898 |
+
" radius=.003*(1+speeds[i])\n",
|
1899 |
+
" point1 = positions[i] + offset*times[i]\n",
|
1900 |
+
" point2 = positions[i+1] + offset*times[i+1]\n",
|
1901 |
+
" add_visual_capsule(scn, point1, point2, radius, rgba)\n",
|
1902 |
+
"\n",
|
1903 |
+
"duration = 6 # (seconds)\n",
|
1904 |
+
"framerate = 30 # (Hz)\n",
|
1905 |
+
"\n",
|
1906 |
+
"# Simulate and display video.\n",
|
1907 |
+
"frames = []\n",
|
1908 |
+
"\n",
|
1909 |
+
"# Reset state and time.\n",
|
1910 |
+
"mujoco.mj_resetData(model, data)\n",
|
1911 |
+
"mujoco.mj_forward(model, data)\n",
|
1912 |
+
"\n",
|
1913 |
+
"with mujoco.Renderer(model) as renderer:\n",
|
1914 |
+
" while data.time < duration:\n",
|
1915 |
+
" # append data to the traces\n",
|
1916 |
+
" positions.append(data.geom_xpos[data.geom(\"green_sphere\").id].copy())\n",
|
1917 |
+
" times.append(data.time)\n",
|
1918 |
+
" speeds.append(get_geom_speed(model, data, \"green_sphere\"))\n",
|
1919 |
+
" mujoco.mj_step(model, data)\n",
|
1920 |
+
" if len(frames) < data.time * framerate:\n",
|
1921 |
+
" renderer.update_scene(data)\n",
|
1922 |
+
" modify_scene(renderer.scene)\n",
|
1923 |
+
" pixels = renderer.render()\n",
|
1924 |
+
" frames.append(pixels)\n",
|
1925 |
+
"\n",
|
1926 |
+
"media.show_video(frames, fps=framerate)"
|
1927 |
+
]
|
1928 |
+
},
|
1929 |
+
{
|
1930 |
+
"cell_type": "markdown",
|
1931 |
+
"metadata": {
|
1932 |
+
"id": "p6wHrjRg1EGF"
|
1933 |
+
},
|
1934 |
+
"source": [
|
1935 |
+
"## Multiple frames in the same scene\n",
|
1936 |
+
"\n",
|
1937 |
+
"Sometimes one would like to draw the same geometry multiple times, for example when a model is tracking states from motion-capture, it's nice to have the data\n",
|
1938 |
+
"visualized next to the model. Unlike `mjv_updateScene` (called by the `Renderer`'s `update_scene` method) which clears the scene at every call, `mjv_addGeoms` will add visual geoms to an existing scene:"
|
1939 |
+
]
|
1940 |
+
},
|
1941 |
+
{
|
1942 |
+
"cell_type": "code",
|
1943 |
+
"execution_count": null,
|
1944 |
+
"metadata": {
|
1945 |
+
"id": "T4b_8n6t1ASk"
|
1946 |
+
},
|
1947 |
+
"outputs": [],
|
1948 |
+
"source": [
|
1949 |
+
"# Get MuJoCo's standard humanoid model.\n",
|
1950 |
+
"print('Getting MuJoCo humanoid XML description from GitHub:')\n",
|
1951 |
+
"!git clone https://github.com/google-deepmind/mujoco\n",
|
1952 |
+
"with open('mujoco/model/humanoid/humanoid.xml', 'r') as f:\n",
|
1953 |
+
" xml = f.read()\n",
|
1954 |
+
"\n",
|
1955 |
+
"# Load the model, make two MjData's.\n",
|
1956 |
+
"model = mujoco.MjModel.from_xml_string(xml)\n",
|
1957 |
+
"data = mujoco.MjData(model)\n",
|
1958 |
+
"data2 = mujoco.MjData(model)\n",
|
1959 |
+
"\n",
|
1960 |
+
"# Episode parameters.\n",
|
1961 |
+
"duration = 3 # (seconds)\n",
|
1962 |
+
"framerate = 60 # (Hz)\n",
|
1963 |
+
"data.qpos[0:2] = [-.5, -.5] # Initial x-y position (m)\n",
|
1964 |
+
"data.qvel[2] = 4 # Initial vertical velocity (m/s)\n",
|
1965 |
+
"ctrl_phase = 2 * np.pi * np.random.rand(model.nu) # Control phase\n",
|
1966 |
+
"ctrl_freq = 1 # Control frequency\n",
|
1967 |
+
"\n",
|
1968 |
+
"# Visual options for the \"ghost\" model.\n",
|
1969 |
+
"vopt2 = mujoco.MjvOption()\n",
|
1970 |
+
"vopt2.flags[mujoco.mjtVisFlag.mjVIS_TRANSPARENT] = True # Transparent.\n",
|
1971 |
+
"pert = mujoco.MjvPerturb() # Empty MjvPerturb object\n",
|
1972 |
+
"# We only want dynamic objects (the humanoid). Static objects (the floor)\n",
|
1973 |
+
"# should not be re-drawn. The mjtCatBit flag lets us do that, though we could\n",
|
1974 |
+
"# equivalently use mjtVisFlag.mjVIS_STATIC\n",
|
1975 |
+
"catmask = mujoco.mjtCatBit.mjCAT_DYNAMIC\n",
|
1976 |
+
"\n",
|
1977 |
+
"# Simulate and render.\n",
|
1978 |
+
"frames = []\n",
|
1979 |
+
"with mujoco.Renderer(model, 480, 640) as renderer:\n",
|
1980 |
+
" while data.time < duration:\n",
|
1981 |
+
" # Sinusoidal control signal.\n",
|
1982 |
+
" data.ctrl = np.sin(ctrl_phase + 2 * np.pi * data.time * ctrl_freq)\n",
|
1983 |
+
" mujoco.mj_step(model, data)\n",
|
1984 |
+
" if len(frames) < data.time * framerate:\n",
|
1985 |
+
" # This draws the regular humanoid from `data`.\n",
|
1986 |
+
" renderer.update_scene(data)\n",
|
1987 |
+
"\n",
|
1988 |
+
" # Copy qpos to data2, move the humanoid sideways, call mj_forward.\n",
|
1989 |
+
" data2.qpos = data.qpos\n",
|
1990 |
+
" data2.qpos[0] += 1.5\n",
|
1991 |
+
" data2.qpos[1] += 1\n",
|
1992 |
+
" mujoco.mj_forward(model, data2)\n",
|
1993 |
+
"\n",
|
1994 |
+
" # Call mjv_addGeoms to add the ghost humanoid to the scene.\n",
|
1995 |
+
" mujoco.mjv_addGeoms(model, data2, vopt2, pert, catmask, renderer.scene)\n",
|
1996 |
+
"\n",
|
1997 |
+
" # Render and add the frame.\n",
|
1998 |
+
" pixels = renderer.render()\n",
|
1999 |
+
" frames.append(pixels)\n",
|
2000 |
+
"\n",
|
2001 |
+
"# Render video at half real-time.\n",
|
2002 |
+
"media.show_video(frames, fps=framerate/2)"
|
2003 |
+
]
|
2004 |
+
},
|
2005 |
+
{
|
2006 |
+
"cell_type": "markdown",
|
2007 |
+
"metadata": {
|
2008 |
+
"id": "Zzzugf-qPExb"
|
2009 |
+
},
|
2010 |
+
"source": [
|
2011 |
+
"## Camera control\n",
|
2012 |
+
"\n",
|
2013 |
+
"Cameras can be controlled dynamically in order to achieve cinematic effects. Run the three cells below to see the difference between rendering from a static and moving camera.\n",
|
2014 |
+
"\n",
|
2015 |
+
"The camera-control code smoothly transitions between two trajectories, one orbiting a fixed point, the other tracking a moving object. Parameter values in the code were obtained by iterating quickly on low-res videos."
|
2016 |
+
]
|
2017 |
+
},
|
2018 |
+
{
|
2019 |
+
"cell_type": "code",
|
2020 |
+
"execution_count": null,
|
2021 |
+
"metadata": {
|
2022 |
+
"cellView": "form",
|
2023 |
+
"id": "-SW-K9WuPGrp"
|
2024 |
+
},
|
2025 |
+
"outputs": [],
|
2026 |
+
"source": [
|
2027 |
+
"#@title Load the \"dominos\" model\n",
|
2028 |
+
"\n",
|
2029 |
+
"dominos_xml = \"\"\"\n",
|
2030 |
+
"<mujoco>\n",
|
2031 |
+
" <asset>\n",
|
2032 |
+
" <texture type=\"skybox\" builtin=\"gradient\" rgb1=\".3 .5 .7\" rgb2=\"0 0 0\" width=\"32\" height=\"512\"/>\n",
|
2033 |
+
" <texture name=\"grid\" type=\"2d\" builtin=\"checker\" width=\"512\" height=\"512\" rgb1=\".1 .2 .3\" rgb2=\".2 .3 .4\"/>\n",
|
2034 |
+
" <material name=\"grid\" texture=\"grid\" texrepeat=\"2 2\" texuniform=\"true\" reflectance=\".2\"/>\n",
|
2035 |
+
" </asset>\n",
|
2036 |
+
"\n",
|
2037 |
+
" <statistic meansize=\".01\"/>\n",
|
2038 |
+
"\n",
|
2039 |
+
" <visual>\n",
|
2040 |
+
" <global offheight=\"2160\" offwidth=\"3840\"/>\n",
|
2041 |
+
" <quality offsamples=\"8\"/>\n",
|
2042 |
+
" </visual>\n",
|
2043 |
+
"\n",
|
2044 |
+
" <default>\n",
|
2045 |
+
" <geom type=\"box\" solref=\".005 1\"/>\n",
|
2046 |
+
" <default class=\"static\">\n",
|
2047 |
+
" <geom rgba=\".3 .5 .7 1\"/>\n",
|
2048 |
+
" </default>\n",
|
2049 |
+
" </default>\n",
|
2050 |
+
"\n",
|
2051 |
+
" <option timestep=\"5e-4\"/>\n",
|
2052 |
+
"\n",
|
2053 |
+
" <worldbody>\n",
|
2054 |
+
" <light pos=\".3 -.3 .8\" mode=\"trackcom\" diffuse=\"1 1 1\" specular=\".3 .3 .3\"/>\n",
|
2055 |
+
" <light pos=\"0 -.3 .4\" mode=\"targetbodycom\" target=\"box\" diffuse=\".8 .8 .8\" specular=\".3 .3 .3\"/>\n",
|
2056 |
+
" <geom name=\"floor\" type=\"plane\" size=\"3 3 .01\" pos=\"-0.025 -0.295 0\" material=\"grid\"/>\n",
|
2057 |
+
" <geom name=\"ramp\" pos=\".25 -.45 -.03\" size=\".04 .1 .07\" euler=\"-30 0 0\" class=\"static\"/>\n",
|
2058 |
+
" <camera name=\"top\" pos=\"-0.37 -0.78 0.49\" xyaxes=\"0.78 -0.63 0 0.27 0.33 0.9\"/>\n",
|
2059 |
+
"\n",
|
2060 |
+
" <body name=\"ball\" pos=\".25 -.45 .1\">\n",
|
2061 |
+
" <freejoint name=\"ball\"/>\n",
|
2062 |
+
" <geom name=\"ball\" type=\"sphere\" size=\".02\" rgba=\".65 .81 .55 1\"/>\n",
|
2063 |
+
" </body>\n",
|
2064 |
+
"\n",
|
2065 |
+
" <body pos=\".26 -.3 .03\" euler=\"0 0 -90.0\">\n",
|
2066 |
+
" <freejoint/>\n",
|
2067 |
+
" <geom size=\".0015 .015 .03\" rgba=\"1 .5 .5 1\"/>\n",
|
2068 |
+
" </body>\n",
|
2069 |
+
"\n",
|
2070 |
+
" <body pos=\".26 -.27 .04\" euler=\"0 0 -81.0\">\n",
|
2071 |
+
" <freejoint/>\n",
|
2072 |
+
" <geom size=\".002 .02 .04\" rgba=\"1 1 .5 1\"/>\n",
|
2073 |
+
" </body>\n",
|
2074 |
+
"\n",
|
2075 |
+
" <body pos=\".24 -.21 .06\" euler=\"0 0 -63.0\">\n",
|
2076 |
+
" <freejoint/>\n",
|
2077 |
+
" <geom size=\".003 .03 .06\" rgba=\".5 1 .5 1\"/>\n",
|
2078 |
+
" </body>\n",
|
2079 |
+
"\n",
|
2080 |
+
" <body pos=\".2 -.16 .08\" euler=\"0 0 -45.0\">\n",
|
2081 |
+
" <freejoint/>\n",
|
2082 |
+
" <geom size=\".004 .04 .08\" rgba=\".5 1 1 1\"/>\n",
|
2083 |
+
" </body>\n",
|
2084 |
+
"\n",
|
2085 |
+
" <body pos=\".15 -.12 .1\" euler=\"0 0 -27.0\">\n",
|
2086 |
+
" <freejoint/>\n",
|
2087 |
+
" <geom size=\".005 .05 .1\" rgba=\".5 .5 1 1\"/>\n",
|
2088 |
+
" </body>\n",
|
2089 |
+
"\n",
|
2090 |
+
" <body pos=\".09 -.1 .12\" euler=\"0 0 -9.0\">\n",
|
2091 |
+
" <freejoint/>\n",
|
2092 |
+
" <geom size=\".006 .06 .12\" rgba=\"1 .5 1 1\"/>\n",
|
2093 |
+
" </body>\n",
|
2094 |
+
"\n",
|
2095 |
+
" <body name=\"seasaw_wrapper\" pos=\"-.23 -.1 0\" euler=\"0 0 30\">\n",
|
2096 |
+
" <geom size=\".01 .01 .015\" pos=\"0 .05 .015\" class=\"static\"/>\n",
|
2097 |
+
" <geom size=\".01 .01 .015\" pos=\"0 -.05 .015\" class=\"static\"/>\n",
|
2098 |
+
" <geom type=\"cylinder\" size=\".01 .0175\" pos=\"-.09 0 .0175\" class=\"static\"/>\n",
|
2099 |
+
" <body name=\"seasaw\" pos=\"0 0 .03\">\n",
|
2100 |
+
" <joint axis=\"0 1 0\"/>\n",
|
2101 |
+
" <geom type=\"cylinder\" size=\".005 .039\" zaxis=\"0 1 0\" rgba=\".84 .15 .33 1\"/>\n",
|
2102 |
+
" <geom size=\".1 .02 .005\" pos=\"0 0 .01\" rgba=\".84 .15 .33 1\"/>\n",
|
2103 |
+
" </body>\n",
|
2104 |
+
" </body>\n",
|
2105 |
+
"\n",
|
2106 |
+
" <body name=\"box\" pos=\"-.3 -.14 .05501\" euler=\"0 0 -30\">\n",
|
2107 |
+
" <freejoint name=\"box\"/>\n",
|
2108 |
+
" <geom name=\"box\" size=\".01 .01 .01\" rgba=\".0 .7 .79 1\"/>\n",
|
2109 |
+
" </body>\n",
|
2110 |
+
" </worldbody>\n",
|
2111 |
+
"</mujoco>\n",
|
2112 |
+
"\"\"\"\n",
|
2113 |
+
"model = mujoco.MjModel.from_xml_string(dominos_xml)\n",
|
2114 |
+
"data = mujoco.MjData(model)\n"
|
2115 |
+
]
|
2116 |
+
},
|
2117 |
+
{
|
2118 |
+
"cell_type": "code",
|
2119 |
+
"execution_count": null,
|
2120 |
+
"metadata": {
|
2121 |
+
"cellView": "form",
|
2122 |
+
"id": "a2WruafiPhPk"
|
2123 |
+
},
|
2124 |
+
"outputs": [],
|
2125 |
+
"source": [
|
2126 |
+
"#@title Render from fixed camera\n",
|
2127 |
+
"\n",
|
2128 |
+
"duration = 2.5 # (seconds)\n",
|
2129 |
+
"framerate = 60 # (Hz)\n",
|
2130 |
+
"height = 1024\n",
|
2131 |
+
"width = 1440\n",
|
2132 |
+
"\n",
|
2133 |
+
"# Simulate and display video.\n",
|
2134 |
+
"frames = []\n",
|
2135 |
+
"mujoco.mj_resetData(model, data) # Reset state and time.\n",
|
2136 |
+
"with mujoco.Renderer(model, height, width) as renderer:\n",
|
2137 |
+
" while data.time < duration:\n",
|
2138 |
+
" mujoco.mj_step(model, data)\n",
|
2139 |
+
" if len(frames) < data.time * framerate:\n",
|
2140 |
+
" renderer.update_scene(data, camera='top')\n",
|
2141 |
+
" pixels = renderer.render()\n",
|
2142 |
+
" frames.append(pixels)\n",
|
2143 |
+
"\n",
|
2144 |
+
"media.show_video(frames, fps=framerate)"
|
2145 |
+
]
|
2146 |
+
},
|
2147 |
+
{
|
2148 |
+
"cell_type": "code",
|
2149 |
+
"execution_count": null,
|
2150 |
+
"metadata": {
|
2151 |
+
"cellView": "form",
|
2152 |
+
"id": "Kie3y-27bQ3J"
|
2153 |
+
},
|
2154 |
+
"outputs": [],
|
2155 |
+
"source": [
|
2156 |
+
"#@title Render from moving camera\n",
|
2157 |
+
"\n",
|
2158 |
+
"duration = 3 # (seconds)\n",
|
2159 |
+
"height = 1024\n",
|
2160 |
+
"width = 1440\n",
|
2161 |
+
"\n",
|
2162 |
+
"# find time when box is thrown (speed > 2cm/s)\n",
|
2163 |
+
"throw_time = 0.0\n",
|
2164 |
+
"mujoco.mj_resetData(model, data)\n",
|
2165 |
+
"while data.time < duration and not throw_time:\n",
|
2166 |
+
" mujoco.mj_step(model, data)\n",
|
2167 |
+
" box_speed = np.linalg.norm(data.joint('box').qvel[:3])\n",
|
2168 |
+
" if box_speed > 0.02:\n",
|
2169 |
+
" throw_time = data.time\n",
|
2170 |
+
"assert throw_time > 0\n",
|
2171 |
+
"\n",
|
2172 |
+
"def mix(time, t0=0.0, width=1.0):\n",
|
2173 |
+
" \"\"\"Sigmoidal mixing function.\"\"\"\n",
|
2174 |
+
" t = (time - t0) / width\n",
|
2175 |
+
" s = 1 / (1 + np.exp(-t))\n",
|
2176 |
+
" return 1 - s, s\n",
|
2177 |
+
"\n",
|
2178 |
+
"def unit_cos(t):\n",
|
2179 |
+
" \"\"\"Unit cosine sigmoid from (0,0) to (1,1).\"\"\"\n",
|
2180 |
+
" return 0.5 - np.cos(np.pi*np.clip(t, 0, 1))/2\n",
|
2181 |
+
"\n",
|
2182 |
+
"def orbit_motion(t):\n",
|
2183 |
+
" \"\"\"Return orbit trajectory.\"\"\"\n",
|
2184 |
+
" distance = 0.9\n",
|
2185 |
+
" azimuth = 140 + 100 * unit_cos(t)\n",
|
2186 |
+
" elevation = -30\n",
|
2187 |
+
" lookat = data.geom('floor').xpos.copy()\n",
|
2188 |
+
" return distance, azimuth, elevation, lookat\n",
|
2189 |
+
"\n",
|
2190 |
+
"def track_motion():\n",
|
2191 |
+
" \"\"\"Return box-track trajectory.\"\"\"\n",
|
2192 |
+
" distance = 0.08\n",
|
2193 |
+
" azimuth = 280\n",
|
2194 |
+
" elevation = -10\n",
|
2195 |
+
" lookat = data.geom('box').xpos.copy()\n",
|
2196 |
+
" return distance, azimuth, elevation, lookat\n",
|
2197 |
+
"\n",
|
2198 |
+
"def cam_motion():\n",
|
2199 |
+
" \"\"\"Return sigmoidally-mixed {orbit, box-track} trajectory.\"\"\"\n",
|
2200 |
+
" d0, a0, e0, l0 = orbit_motion(data.time / throw_time)\n",
|
2201 |
+
" d1, a1, e1, l1 = track_motion()\n",
|
2202 |
+
" mix_time = 0.3\n",
|
2203 |
+
" w0, w1 = mix(data.time, throw_time, mix_time)\n",
|
2204 |
+
" return w0*d0+w1*d1, w0*a0+w1*a1, w0*e0+w1*e1, w0*l0+w1*l1\n",
|
2205 |
+
"\n",
|
2206 |
+
"# Make a camera.\n",
|
2207 |
+
"cam = mujoco.MjvCamera()\n",
|
2208 |
+
"mujoco.mjv_defaultCamera(cam)\n",
|
2209 |
+
"\n",
|
2210 |
+
"# Simulate and display video.\n",
|
2211 |
+
"framerate = 60 # (Hz)\n",
|
2212 |
+
"slowdown = 4 # 4x slow-down\n",
|
2213 |
+
"mujoco.mj_resetData(model, data)\n",
|
2214 |
+
"frames = []\n",
|
2215 |
+
"with mujoco.Renderer(model, height, width) as renderer:\n",
|
2216 |
+
" while data.time < duration:\n",
|
2217 |
+
" mujoco.mj_step(model, data)\n",
|
2218 |
+
" if len(frames) < data.time * framerate * slowdown:\n",
|
2219 |
+
" cam.distance, cam.azimuth, cam.elevation, cam.lookat = cam_motion()\n",
|
2220 |
+
" renderer.update_scene(data, cam)\n",
|
2221 |
+
" pixels = renderer.render()\n",
|
2222 |
+
" frames.append(pixels)\n",
|
2223 |
+
"\n",
|
2224 |
+
"media.show_video(frames, fps=framerate)"
|
2225 |
+
]
|
2226 |
+
}
|
2227 |
+
],
|
2228 |
+
"metadata": {
|
2229 |
+
"accelerator": "GPU",
|
2230 |
+
"colab": {
|
2231 |
+
"collapsed_sections": [
|
2232 |
+
"YvyGCsgSCxHQ"
|
2233 |
+
],
|
2234 |
+
"gpuClass": "premium",
|
2235 |
+
"private_outputs": true
|
2236 |
+
},
|
2237 |
+
"gpuClass": "premium",
|
2238 |
+
"kernelspec": {
|
2239 |
+
"display_name": "Python 3 (ipykernel)",
|
2240 |
+
"language": "python",
|
2241 |
+
"name": "python3"
|
2242 |
+
},
|
2243 |
+
"language_info": {
|
2244 |
+
"codemirror_mode": {
|
2245 |
+
"name": "ipython",
|
2246 |
+
"version": 3
|
2247 |
+
},
|
2248 |
+
"file_extension": ".py",
|
2249 |
+
"mimetype": "text/x-python",
|
2250 |
+
"name": "python",
|
2251 |
+
"nbconvert_exporter": "python",
|
2252 |
+
"pygments_lexer": "ipython3",
|
2253 |
+
"version": "3.9.5"
|
2254 |
+
}
|
2255 |
+
},
|
2256 |
+
"nbformat": 4,
|
2257 |
+
"nbformat_minor": 4
|
2258 |
+
}
|