Using Environments
Mujorax wraps every MuJoCo Playground environment behind Envrax's JaxEnv API, so once you know how to use one, you know how to use them all!
This tutorial walks you through how to create Mujorax environments with the make() methods, how to use the environment's reset/step methods, and how the MjxPlaygroundState carries everything between calls.
Without further ado, let's jump right in!
Make Methods Overview
As Mujorax is an Envrax environment suite, it doesn't provide its own factory functions. Instead, it leans on Envrax's existing make methods so that you can utilise its centralised registry with as many environment suites as you want!
All four Envrax make() methods can build Mujorax environments as soon as you import mujorax, which registers them with Envrax's registry:
| Factory | Returns | Use for |
|---|---|---|
| envrax.make(name) | JaxEnv | A single environment |
| envrax.make_vec(name, n_envs) | VecEnv | A batched environment for parallel rollouts |
| envrax.make_multi(names) | MultiEnv | Heterogeneous environments rolled out in parallel |
| envrax.make_multi_vec(names, n_envs) | MultiVecEnv | Heterogeneous batched environments |
Canonical IDs (name) follow the mjx/<name>-v0 format. See the environment catalogue for the full list.
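The ID format can be checked mechanically. The sketch below is a hypothetical helper (parse_env_id and EnvId are not part of Mujorax or Envrax) that splits an ID of the form mjx/<name>-v0 into its suite, name, and version. It assumes that names contain no slashes and that the version suffix is a non-negative integer; the ID in the usage line is illustrative only, so take real IDs from the environment catalogue.

```python
import re
from typing import NamedTuple


class EnvId(NamedTuple):
    """Parts of a canonical ID (hypothetical helper type)."""
    suite: str
    name: str
    version: int


# Assumed grammar: <suite>/<name>-v<integer>
_ENV_ID = re.compile(r"(?P<suite>[^/]+)/(?P<name>[^/]+)-v(?P<version>\d+)")


def parse_env_id(env_id: str) -> EnvId:
    """Split an ID such as 'mjx/<name>-v0' into suite, name, and version."""
    match = _ENV_ID.fullmatch(env_id)
    if match is None:
        raise ValueError(f"not a canonical environment ID: {env_id!r}")
    return EnvId(match["suite"], match["name"], int(match["version"]))


# Illustrative ID only; it is not taken from the environment catalogue.
parsed = parse_env_id("mjx/example_task-v0")
```

A check like this can catch typos in a list of IDs before any environment is constructed.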
Already familiar with Envrax?
The same wrappers, jit_compile, pre_warm, and cache_dir parameters apply to all Mujorax environments out of the box!
For a deeper dive on the factory functions themselves, see Envrax's Make Methods tutorial.
Single Environment
To construct a single environment, pass its canonical ID to envrax.make().
By default, the returned environment is a JIT-compiled wrapper around the requested environment class, such as CartpoleBalanceEnv.
reset()
To initialise or reset the environment, call env.reset(). It takes a JAX PRNG key and returns a new observation and environment state.
obs is a jax.Array matching env.observation_space.shape. state is an MjxPlaygroundState that carries everything the environment needs to advance: the PRNG key, the current step counter, the done flag, and the wrapped Playground state with full physics data.
step()
To advance the environment and produce a new state, call step(). It takes the current state and an action, and returns a 5-tuple:
- obs: the next observation in the environment
- state: the new MjxPlaygroundState, with its step counter incremented by one
- reward: the scalar jax.Array reward obtained for taking that action in the previous state
- done: a bool scalar indicating whether the episode has ended
- info: a dict of metadata containing diagnostic information
A Full Episode
A full episode is a while loop: call reset() once, then call step() repeatedly, passing each returned state back in, until done is true.
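The loop can be sketched without any physics. The example below uses a hypothetical pure-Python stand-in (ToyEnv and ToyState are not Mujorax classes) that only mimics the reset/step shapes described above: reset takes a key and returns (obs, state), and step takes (state, action) and returns the 5-tuple. The integer key, constant reward, and fixed episode length are assumptions made purely for illustration.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    """Stand-in for MjxPlaygroundState (hypothetical; no physics)."""
    rng: int    # placeholder for a JAX PRNG key
    step: int   # current step counter
    done: bool  # episode-finished flag


class ToyEnv:
    """Stand-in env that mimics the reset/step contract; not a Mujorax class."""

    def __init__(self, max_steps: int = 5) -> None:
        self.max_steps = max_steps

    def reset(self, rng: int):
        state = ToyState(rng=rng, step=0, done=False)
        return 0.0, state  # (obs, state)

    def step(self, state: ToyState, action: float):
        step = state.step + 1
        done = step >= self.max_steps
        new_state = replace(state, step=step, done=done)
        obs = float(step)
        reward = 1.0  # constant reward, for illustration only
        return obs, new_state, reward, done, {}  # (obs, state, reward, done, info)


def run_episode(env, rng: int, policy):
    """Roll out one episode and return (total reward, number of steps)."""
    obs, state = env.reset(rng)
    total_reward, done = 0.0, False
    while not done:
        action = policy(obs)
        # Feed the returned state back in on the next iteration.
        obs, state, reward, done, info = env.step(state, action)
        total_reward += reward
    return total_reward, state.step


total, steps = run_episode(ToyEnv(max_steps=5), rng=0, policy=lambda obs: 0.0)
```

The state is never mutated in place; each call to step hands back a new state, which is the pattern the 5-tuple return value implies.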
Vectorised Environments
You can create vectorised environments using the make_vec() method. This returns a VecEnv that runs n_envs parallel copies of the environment simultaneously in a single vmap call.
With this approach, obs, reward, done, and the state values all gain a leading n_envs batch dimension.
By default, VecEnv auto-resets each parallel env when its episode ends, so done flags fire once per boundary and state is replaced by a fresh reset on the next step.
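One plausible reading of the auto-reset behaviour can be sketched in plain Python. This is not Envrax's implementation: the names below are hypothetical, the batch is an ordinary list instead of a vmapped array, and the sketch assumes that an env which reported done on the previous call is swapped for a fresh reset state on the next call instead of being stepped further.

```python
from dataclasses import dataclass
from typing import List


@dataclass(frozen=True)
class ToyState:
    """Hypothetical per-env state: just a step counter and a done flag."""
    step: int
    done: bool


MAX_STEPS = 3  # hypothetical episode length


def reset_one() -> ToyState:
    return ToyState(step=0, done=False)


def step_one(state: ToyState) -> ToyState:
    step = state.step + 1
    return ToyState(step=step, done=step >= MAX_STEPS)


def step_with_auto_reset(state: ToyState) -> ToyState:
    # Assumed behaviour: an env that finished on the previous call is
    # replaced by a fresh reset state on this call.
    if state.done:
        return reset_one()
    return step_one(state)


def step_batch(states: List[ToyState]) -> List[ToyState]:
    """Apply the auto-resetting step to every env in the batch."""
    return [step_with_auto_reset(s) for s in states]


# Two envs that start out of phase, so their episodes end on different calls.
states = [ToyState(step=0, done=False), ToyState(step=1, done=False)]
done_history = []
for _ in range(6):
    states = step_batch(states)
    done_history.append([s.done for s in states])
```

In done_history, each env reports done on exactly one call per episode boundary and is back at step 0 on the following call, while the other env carries on unaffected.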
Heterogeneous Environments
For multi-task training, evaluation suites, or meta-learning across environment types, use make_multi. For vectorised parallel copies of each environment, use make_multi_vec.
Per-env config overrides
make_multi and make_multi_vec use each env's registered default config.
To override them, use Envrax's MultiEnv or MultiVecEnv classes directly.
State and Config
Every Mujorax environment shares the same state and config types:
- MjxPlaygroundState: holds the rng key, step index, done flag, and the embedded MuJoCo Playground pg_state with full physics data.
- MjxPlaygroundConfig: holds a max_steps value (required by Envrax) plus an optional config_overrides dictionary that gets forwarded to the underlying MuJoCo Playground for advanced customisation.
Anywhere you need to inspect the physics state (joint positions, contact forces, sensor readings), use the state.pg_state.data value. The full surface is documented in the API reference.
We'll explore tweaking MjxPlaygroundConfig in the next tutorial.
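The shape of these two types can be pictured with plain dataclasses. The sketch below is a hypothetical mirror, not the real classes: only the field names rng, step, done, pg_state, data, max_steps, and config_overrides come from the description above, while the class names, the qpos field, the field types, and the use of frozen dataclasses are assumptions for illustration.

```python
from dataclasses import dataclass, replace
from typing import Any, Dict, Optional, Tuple


@dataclass(frozen=True)
class SketchPhysicsData:
    """Stand-in for the physics data; qpos is an illustrative field."""
    qpos: Tuple[float, ...]


@dataclass(frozen=True)
class SketchPlaygroundState:
    """Stand-in for the wrapped Playground state."""
    data: SketchPhysicsData


@dataclass(frozen=True)
class SketchState:
    """Mirrors the fields described for MjxPlaygroundState."""
    rng: int  # placeholder for a JAX PRNG key
    step: int
    done: bool
    pg_state: SketchPlaygroundState


@dataclass(frozen=True)
class SketchConfig:
    """Mirrors the fields described for MjxPlaygroundConfig."""
    max_steps: int
    config_overrides: Optional[Dict[str, Any]] = None


state = SketchState(
    rng=0,
    step=0,
    done=False,
    pg_state=SketchPlaygroundState(SketchPhysicsData(qpos=(0.0, 0.1))),
)

# Physics values sit two levels down, as in state.pg_state.data.
joint_positions = state.pg_state.data.qpos

# Advancing produces a new state; the old one is left untouched.
next_state = replace(state, step=state.step + 1)

# max_steps is required; the overrides dictionary is optional.
config = SketchConfig(max_steps=1000)
```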
Recap
And those are the basics! To recap:
- import mujorax registers all supported Mujorax environments under canonical IDs of the form mjx/<name>-v0.
- envrax.make(name) returns a single JIT-wrapped env; make_vec, make_multi, and make_multi_vec follow the same pattern with batching or multi-environment composition.
- Every environment has reset and step methods. reset() returns a 2-tuple (obs, state), and the step() method returns a 5-tuple (obs, state, reward, done, info).
- Mujorax's state class is called MjxPlaygroundState, which embeds the upstream Playground state on the pg_state property for full physics access to the underlying environment.
Next Steps
Next up, we'll explore how to tweak an environment's config via the MjxPlaygroundConfig class!
- Configuration: explore how to tweak an environment's config.