Getting Started
To get started, set up a Python 3.13+ environment and install the package.
Project Setup
If you don't already have a Python project, spin one up with your tool of choice:
Install Package
Then, install the package:
If you're new, or want a refresher, head on over to the tutorials or try out the example below!
Example Usage
A simple cartpole rollout:
```python
import jax
import mujorax  # registers the suite at import
import envrax

# Init the environment
env = envrax.make("mjx/cartpole_balance-v0")

# Set its initial state
rng = jax.random.PRNGKey(0)
obs, state = env.reset(rng)

# Iterate through 1000 timesteps
for _ in range(1000):
    rng, action_rng = jax.random.split(rng)
    action = env.action_space.sample(action_rng)
    obs, state, reward, done, info = env.step(state, action)

    # If episode has ended, reset to start a new one
    if done:
        rng, reset_rng = jax.random.split(rng)
        obs, state = env.reset(reset_rng)
```
This code should work "as is".
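The loop above steps from Python one timestep at a time. If an environment's reset and step are pure JAX functions (an assumption, not something this page states about envrax), the same rollout can be compiled end to end with jax.lax.scan, replacing the Python `if done:` branch with a jnp.where-based auto-reset. The sketch below uses a hypothetical 1-D toy environment (toy_reset, toy_step) so it runs without the mujorax suite.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy environment (NOT part of envrax/mujorax): a 1-D point that
# moves by the action; the episode ends once |position| exceeds 1.
def toy_reset(rng):
    pos = jax.random.uniform(rng, (), minval=-0.1, maxval=0.1)
    return pos[None], pos  # (obs, state)

def toy_step(state, action):
    pos = state + action[0]
    done = jnp.abs(pos) > 1.0
    reward = jnp.where(done, 0.0, 1.0)  # 1 per step while still "alive"
    return pos[None], pos, reward, done, {}

def rollout(rng, n_steps):
    _, state = toy_reset(rng)

    def body(carry, _):
        rng, state = carry
        rng, action_rng, reset_rng = jax.random.split(rng, 3)
        action = jax.random.uniform(action_rng, (1,), minval=-0.5, maxval=0.5)
        _, next_state, reward, done, _ = toy_step(state, action)
        # Branch-free auto-reset: take the fresh state wherever done is True.
        _, reset_state = toy_reset(reset_rng)
        next_state = jax.tree_util.tree_map(
            lambda r, s: jnp.where(done, r, s), reset_state, next_state
        )
        return (rng, next_state), (reward, done)

    _, (rewards, dones) = jax.lax.scan(body, (rng, state), None, length=n_steps)
    return rewards, dones

# n_steps sets the scan length, so it must be static under jit.
rewards, dones = jax.jit(rollout, static_argnums=1)(jax.random.PRNGKey(0), 1000)
```

Using tree_map keeps the auto-reset valid when the state is a pytree (for example a dict of arrays) rather than a single array. The trade-off is that a reset state is computed every step, which costs a little extra work in exchange for code jit can compile without Python control flow.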
Make Parallel Copies of It
```python
import jax
import jax.numpy as jnp
import envrax
import mujorax

vec_env = envrax.make_vec("mjx/cartpole_balance-v0", n_envs=512)
obs, state = vec_env.reset(jax.random.PRNGKey(0))  # obs: float32[512, 5]

actions = jnp.zeros((512, 1), dtype=jnp.float32)
obs, state, rewards, dones, infos = vec_env.step(state, actions)
# rewards: float32[512]
# dones:   bool[512]
```
This code should work "as is".
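A rough mental model for make_vec is a single environment whose reset and step are mapped over a leading batch axis with jax.vmap, each copy seeded with its own key. The sketch below shows that pattern with hypothetical toy functions (toy_reset, toy_step) whose shapes mirror the example above; it assumes pure-JAX environment functions and is not a description of how envrax actually implements make_vec.

```python
import jax
import jax.numpy as jnp

N_ENVS = 512

# Hypothetical single-env functions (NOT envrax API):
# obs/state are float32[5], the action is float32[1].
def toy_reset(rng):
    state = jax.random.normal(rng, (5,))
    return state, state  # (obs, state): here the observation is the full state

def toy_step(state, action):
    next_state = state.at[0].add(action[0])
    reward = -jnp.sum(next_state ** 2)
    done = jnp.abs(next_state[0]) > 2.0
    return next_state, next_state, reward, done, {}

# vmap maps over the leading axis of every argument: one slice per env copy.
batched_reset = jax.jit(jax.vmap(toy_reset))
batched_step = jax.jit(jax.vmap(toy_step))

keys = jax.random.split(jax.random.PRNGKey(0), N_ENVS)  # one key per copy
obs, state = batched_reset(keys)                         # obs: float32[512, 5]

actions = jnp.zeros((N_ENVS, 1), dtype=jnp.float32)
obs, state, rewards, dones, infos = batched_step(state, actions)
# rewards: float32[512], dones: bool[512]
```

Splitting one key into N_ENVS sub-keys, rather than reusing a single key, is what keeps the copies from starting in identical states.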
Combine Heterogeneous Environments
```python
import jax
import envrax
import mujorax

# Roll out across two different envs at once
multi_env = envrax.make_multi([
    "mjx/cartpole_balance-v0",
    "mjx/cheetah_run-v0",
])
obs_list, state_list = multi_env.reset(jax.random.PRNGKey(0))  # one entry per env
```
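Environments such as cartpole and cheetah have different observation and action sizes, so their arrays cannot be stacked along a shared axis; that is presumably why reset returns one entry per environment. The sketch below illustrates the idea with a hypothetical make_toy_env factory (not envrax API), with sizes chosen purely for illustration: each environment gets its own sub-key, and states and actions are paired with their environment by list position.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy environment factory (NOT envrax API).
def make_toy_env(obs_dim, act_dim):
    def reset(rng):
        state = jax.random.normal(rng, (obs_dim,))
        return state, state  # (obs, state)

    def step(state, action):
        next_state = state.at[:act_dim].add(action)  # action drives the first dims
        reward = -jnp.sum(next_state ** 2)
        done = jnp.any(jnp.abs(next_state) > 5.0)
        return next_state, next_state, reward, done, {}

    return reset, step

# Two environments with incompatible shapes (sizes are illustrative only).
envs = [make_toy_env(obs_dim=5, act_dim=1), make_toy_env(obs_dim=17, act_dim=6)]

# One independent sub-key per env; results stay in lists because shapes differ.
keys = jax.random.split(jax.random.PRNGKey(0), len(envs))
results = [reset(k) for (reset, _), k in zip(envs, keys)]
obs_list = [obs for obs, _ in results]
state_list = [state for _, state in results]

# Stepping works the same way: one action per env, matched by list position.
actions = [jnp.zeros(1), jnp.zeros(6)]
step_results = [step(s, a) for (_, step), s, a in zip(envs, state_list, actions)]
rewards = [r for _, _, r, _, _ in step_results]
```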
For vectorised parallel copies of each environment, use make_multi_vec:
```python
import jax
import envrax
import mujorax

multi_vec_env = envrax.make_multi_vec(
    ["mjx/cartpole_balance-v0", "mjx/cheetah_run-v0"],
    n_envs=64,
)
obs_list, state_list = multi_vec_env.reset(jax.random.PRNGKey(0))
# each entry shaped (64, *single_obs_shape)
```
This code should work "as is".
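Conceptually, make_multi_vec combines the two earlier patterns on this page: batching within each environment (as with make_vec) and a plain list across environments whose shapes differ (as with make_multi). Below is a minimal sketch under that assumption, using a hypothetical make_toy_reset factory rather than the real suite; it reproduces the (64, *single_obs_shape) layout per entry but makes no claim about envrax internals.

```python
import jax

N_ENVS = 64

# Hypothetical reset factory (NOT envrax API); obs_dim differs per environment.
def make_toy_reset(obs_dim):
    def reset(rng):
        state = jax.random.normal(rng, (obs_dim,))
        return state, state  # (obs, state)
    return reset

# Vectorise *within* each environment with vmap; keep a list *across* them.
batched_resets = [jax.jit(jax.vmap(make_toy_reset(d))) for d in (5, 17)]

env_keys = jax.random.split(jax.random.PRNGKey(0), len(batched_resets))
obs_list, state_list = [], []
for batched_reset, env_key in zip(batched_resets, env_keys):
    copy_keys = jax.random.split(env_key, N_ENVS)  # one key per parallel copy
    obs, state = batched_reset(copy_keys)
    obs_list.append(obs)
    state_list.append(state)

# obs_list[0].shape == (64, 5); obs_list[1].shape == (64, 17)
```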
Next Steps