
Getting Started

To get started, set up a Python 3.13+ environment and install the package.

Project Setup

If you don't already have a Python project, spin one up with your tool of choice:

With uv:

uv init --python 3.13 my-project
cd my-project

With venv on macOS/Linux:

mkdir my-project && cd my-project
python3.13 -m venv .venv
source .venv/bin/activate

With venv on Windows:

mkdir my-project && cd my-project
py -3.13 -m venv .venv
.venv\Scripts\activate

With Poetry:

poetry new --python ">=3.13" my-project
cd my-project

Install Package

Then, install the package:

With uv:

uv add mujorax

With pip:

pip install mujorax

With Poetry:

poetry add mujorax

If you're new, or want a refresher, head on over to the tutorials or try out the example below!

Example Usage

A simple cartpole rollout:

import jax
import mujorax  # registers the suite at import
import envrax

# Create the environment
env = envrax.make("mjx/cartpole_balance-v0")

# Reset to get the first observation and the environment state
rng = jax.random.PRNGKey(0)
obs, state = env.reset(rng)

# Step through 1000 timesteps with random actions
for _ in range(1000):
    rng, action_rng = jax.random.split(rng)
    action = env.action_space.sample(action_rng)
    obs, state, reward, done, info = env.step(state, action)

    # If episode has ended, reset to start a new one
    if done:
        rng, reset_rng = jax.random.split(rng)
        obs, state = env.reset(reset_rng)

This code should work "as is".
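The Python loop above dispatches every step from the host. When throughput matters, a common JAX pattern is to write the rollout as a jax.lax.scan under jax.jit and reset finished episodes branch-free with jnp.where. The sketch below uses a hypothetical toy environment (toy_reset and toy_step are not part of mujorax or envrax) so it runs without MuJoCo; putting env.reset and env.step in their place assumes those are pure, jit-compatible functions, which this page does not state.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy environment (NOT part of mujorax/envrax): the state holds a
# step counter and a 1-D position; every episode ends after EPISODE_LEN steps.
EPISODE_LEN = 10

def toy_reset(rng):
    pos = jax.random.uniform(rng, (), minval=-0.05, maxval=0.05)
    state = {"t": jnp.array(0, dtype=jnp.int32), "pos": pos}
    return pos[None], state  # (obs, state)

def toy_step(state, action):
    t = state["t"] + 1
    pos = state["pos"] + 0.1 * action[0]
    reward = 1.0 - jnp.abs(pos)
    done = t >= EPISODE_LEN
    return pos[None], {"t": t, "pos": pos}, reward, done, {}

def rollout(rng, n_steps):
    """Run n_steps with random actions, resetting whenever an episode ends."""
    rng, reset_rng = jax.random.split(rng)
    _, state = toy_reset(reset_rng)

    def body(carry, _):
        rng, state = carry
        rng, action_rng, reset_rng = jax.random.split(rng, 3)
        action = jax.random.uniform(action_rng, (1,), minval=-1.0, maxval=1.0)
        _, next_state, reward, done, _ = toy_step(state, action)
        # Branch-free auto-reset: take the fresh state leaf-by-leaf where done.
        _, fresh_state = toy_reset(reset_rng)
        next_state = jax.tree_util.tree_map(
            lambda fresh, cur: jnp.where(done, fresh, cur), fresh_state, next_state
        )
        return (rng, next_state), (reward, done)

    _, (rewards, dones) = jax.lax.scan(body, (rng, state), None, length=n_steps)
    return rewards, dones

# n_steps fixes the scan length, so it must be static under jit.
rollout_jit = jax.jit(rollout, static_argnums=1)
rewards, dones = rollout_jit(jax.random.PRNGKey(0), 1000)
print(int(dones.sum()))  # 100 (an episode ends every 10th step)
```

The fresh reset is computed on every step and discarded when done is False; that small amount of wasted work is what lets the whole loop compile into a single XLA program.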

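Run Parallel Copies

make_vec batches n_envs copies of one environment; observations, rewards, and done flags gain a leading batch axis of size n_envs: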

import jax
import jax.numpy as jnp
import envrax
import mujorax  # registers the suite at import

vec_env = envrax.make_vec("mjx/cartpole_balance-v0", n_envs=512)
obs, state = vec_env.reset(jax.random.PRNGKey(0))   # obs: float32[512, 5]

actions = jnp.zeros((512, 1), dtype=jnp.float32)
obs, state, rewards, dones, infos = vec_env.step(state, actions)
# rewards: float32[512]
# dones:   bool[512]

This code should work "as is".
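This page doesn't describe how make_vec batches its copies internally. One common way to get the same batched shapes in JAX is to jax.vmap single-environment reset and step functions over a leading axis, giving each copy its own PRNG key. A minimal sketch, using hypothetical reset and step functions (not the envrax or mujorax API):

```python
import jax
import jax.numpy as jnp

# Hypothetical single-environment functions (NOT the envrax/mujorax API).
def reset(rng):
    pos = jax.random.uniform(rng, (), minval=-0.05, maxval=0.05)
    return jnp.stack([pos, jnp.zeros(())]), pos  # obs: float32[2], state: position

def step(state, action):
    pos = state + 0.1 * action[0]
    obs = jnp.stack([pos, action[0]])
    reward = 1.0 - jnp.abs(pos)
    done = jnp.abs(pos) > 1.0
    return obs, pos, reward, done, {}

n_envs = 512
# One key per copy; vmap maps reset/step over the leading batch axis.
keys = jax.random.split(jax.random.PRNGKey(0), n_envs)
obs, state = jax.vmap(reset)(keys)  # obs: float32[512, 2]

actions = jnp.zeros((n_envs, 1), dtype=jnp.float32)
obs, state, rewards, dones, infos = jax.jit(jax.vmap(step))(state, actions)
# rewards: float32[512], dones: bool[512]
```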

Combine Heterogeneous Environments

import jax
import envrax
import mujorax  # registers the suite at import

# Roll out across two different envs at once
multi_env = envrax.make_multi([
    "mjx/cartpole_balance-v0",
    "mjx/cheetah_run-v0",
])
obs_list, state_list = multi_env.reset(jax.random.PRNGKey(0))  # one entry per env
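Cartpole and cheetah have differently shaped observations and actions, so their outputs can't be stacked into a single array; returning one entry per environment avoids that. A sketch of the idea with two hypothetical reset functions (not the envrax API, and the observation sizes are arbitrary), where each environment receives its own key split from the parent:

```python
import jax
import jax.numpy as jnp

# Hypothetical reset functions with different observation sizes
# (NOT the envrax/mujorax API; sizes chosen arbitrarily).
def reset_small(rng):
    return jax.random.normal(rng, (5,)), {"t": jnp.array(0)}

def reset_large(rng):
    return jax.random.normal(rng, (17,)), {"t": jnp.array(0)}

resets = [reset_small, reset_large]

def multi_reset(rng):
    # Split one key into an independent key per environment.
    keys = jax.random.split(rng, len(resets))
    results = [reset(k) for reset, k in zip(resets, keys)]
    obs_list = [obs for obs, _ in results]
    state_list = [state for _, state in results]
    return obs_list, state_list

obs_list, state_list = multi_reset(jax.random.PRNGKey(0))
print([o.shape for o in obs_list])  # [(5,), (17,)]
```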

For vectorised parallel copies of each, use make_multi_vec:

import jax
import envrax
import mujorax  # registers the suite at import

multi_vec_env = envrax.make_multi_vec(
    ["mjx/cartpole_balance-v0", "mjx/cheetah_run-v0"],
    n_envs=64,
)
obs_list, state_list = multi_vec_env.reset(jax.random.PRNGKey(0))
# each entry shaped (64, *single_obs_shape)

This code should work "as is".
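The (64, *single_obs_shape) layout can be sketched by vmapping each environment type over its own batch of keys and keeping the batches in a list, since their trailing observation shapes differ and can't be stacked. This uses hypothetical reset functions (not the envrax API) with arbitrary observation sizes:

```python
import jax
import jax.numpy as jnp

# Hypothetical per-environment reset functions (NOT the envrax/mujorax API).
def reset_a(rng):
    return jax.random.normal(rng, (5,)), jnp.zeros(())

def reset_b(rng):
    return jax.random.normal(rng, (17,)), jnp.zeros(())

n_envs = 64
env_keys = jax.random.split(jax.random.PRNGKey(0), 2)  # one key per env type
obs_list, state_list = [], []
for reset, key in zip([reset_a, reset_b], env_keys):
    batch_keys = jax.random.split(key, n_envs)  # one key per parallel copy
    obs, state = jax.vmap(reset)(batch_keys)
    obs_list.append(obs)
    state_list.append(state)

print([o.shape for o in obs_list])  # [(64, 5), (64, 17)]
```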

Next Steps