Skip to content

Base

The base class, together with its state and config types, shared by all MjxPlaygroundEnv environments in Mujorax.

mujorax.MjxPlaygroundEnv

Bases: JaxEnv[Box, Box, MjxPlaygroundState, MjxPlaygroundConfig]

Base wrapper that exposes a mujoco_playground environment via Envrax's JaxEnv API.

Subclasses set _PLAYGROUND_NAME to a name accepted by mujoco_playground.registry.load. Override _reward, _done, or _info to customise per-env behaviour.

Parameters:

Name Type Description Default
config MjxPlaygroundConfig

Static configuration. Defaults to MjxPlaygroundConfig().

None
Source code in mujorax/envs/_base.py
Python
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
class MjxPlaygroundEnv(JaxEnv[Box, Box, MjxPlaygroundState, MjxPlaygroundConfig]):
    """
    Base wrapper that exposes a `mujoco_playground` environment via
    Envrax's `JaxEnv` API.

    Subclasses set `_PLAYGROUND_NAME` to a name accepted by
    `mujoco_playground.registry.load`. Override `_reward`, `_done`, or
    `_info` to customise per-env behaviour.

    Only flat (single-array) observations are supported: construction
    fails with `NotImplementedError` when the wrapped environment's
    `observation_size` is not an `int`.

    Parameters
    ----------
    config : MjxPlaygroundConfig (optional)
        Static configuration. Defaults to `MjxPlaygroundConfig()`.
    """

    # Registry name of the wrapped Playground environment. The empty
    # default marks "not set" and makes `__init__` raise.
    _PLAYGROUND_NAME: str = ""

    def __init__(self, config: MjxPlaygroundConfig | None = None) -> None:
        """
        Load the Playground environment named by `_PLAYGROUND_NAME`.

        Parameters
        ----------
        config : MjxPlaygroundConfig (optional)
            Static configuration, passed unchanged to `JaxEnv.__init__`.

        Raises
        ------
        error : ValueError
            If the subclass left `_PLAYGROUND_NAME` empty.
        error : NotImplementedError
            If the loaded environment reports a non-`int`
            `observation_size` (dict-shaped observations).
        """
        # Fail before any loading work when the subclass is misconfigured.
        if not self._PLAYGROUND_NAME:
            raise ValueError(f"{type(self).__name__} must set `_PLAYGROUND_NAME`.")

        # `self.config` must exist before `_resolve_overrides` reads it.
        super().__init__(config)
        self._env = mujoco_playground.registry.load(
            self._PLAYGROUND_NAME,
            config_overrides=self._resolve_overrides(),
        )
        _ = self.observation_space  # raises NotImplementedError for dict obs

    @property
    def xml_path(self) -> Path:
        """
        Path to the MJCF XML file backing this Playground environment.

        Used by composite render scenes that need to compose multiple
        copies of the environment's MJCF.

        The value is read from the `_XML_PATH` attribute of the module
        that defines the wrapped environment's class. It is converted
        with `Path(str(...))` and is not resolved or checked for
        existence here.

        Returns
        -------
        xml_path : Path
            Path to the env's MJCF file, as declared by the Playground
            module.

        Raises
        ------
        attr_missing : AttributeError
            If the underlying Playground module does not expose `_XML_PATH`.
        """
        module = importlib.import_module(type(self._env).__module__)
        if not hasattr(module, "_XML_PATH"):
            raise AttributeError(
                f"Could not locate XML path for {type(self._env).__name__}; "
                f"module {module.__name__!r} has no `_XML_PATH` attribute."
            )

        # `str(...)` first so any path-like object converts uniformly.
        return Path(str(module._XML_PATH))

    def _resolve_overrides(self) -> Dict[str, Any] | None:
        """
        Build the override dict passed to `mujoco_playground.registry.load`.

        Starts from a copy of `config.config_overrides` (so the config
        object is never mutated) and adds `config.impl` under the key
        `"impl"` unless the caller already supplied that key.

        Returns
        -------
        overrides : Dict[str, Any] | None
            Resolved overrides. This implementation always returns a
            dict containing at least `"impl"`; the annotation still
            admits `None`, which `__init__` would forward unchanged.
        """
        overrides = dict(self.config.config_overrides or {})
        # An explicit "impl" entry in `config_overrides` takes precedence.
        overrides.setdefault("impl", self.config.impl)
        # "impl" is always present at this point, so the dict can never
        # be empty and no `None` fallback is needed.
        return overrides

    def _extract_obs(self, pg_state: mjx_env.State) -> jax.Array:
        """
        Extract the observation array from a Playground state.

        Dict observations are rejected at construction time; this method
        narrows Playground's `Observation` union to a single array and
        guards against the dict case slipping through at runtime.

        Parameters
        ----------
        pg_state : mjx_env.State
            Playground state

        Returns
        -------
        obs : jax.Array
            Observation array

        Raises
        ------
        error : TypeError
            If `pg_state.obs` is not a single array.
        """
        obs = pg_state.obs
        if not isinstance(obs, jax.Array):
            raise TypeError(
                f"Expected `pg_state.obs` to be a `jax.Array`, got "
                f"{type(obs).__name__}. Dict observations are not supported "
                "in this release."
            )

        return obs

    @property
    def observation_space(self) -> Box:
        """
        Returns the observation space.

        Returns
        -------
        space : Box
            Unbounded float32 box of shape `(observation_size,)`.

        Raises
        ------
        error : NotImplementedError
            If the wrapped environment's `observation_size` is not an
            `int` (dict-shaped observations).
        """
        size = self._env.observation_size

        if not isinstance(size, int):
            raise NotImplementedError(
                f"{type(self).__name__} produces dict-shaped observations "
                f"({size}); not supported in this release."
            )

        return Box(
            low=-jnp.inf,
            high=jnp.inf,
            shape=(size,),
            dtype=jnp.float32,
        )

    @property
    def action_space(self) -> Box:
        """
        Returns the action space.

        Returns
        -------
        space : Box
            float32 box bounded to `[-1.0, 1.0]` with shape
            `(action_size,)`, where `action_size` comes from the
            wrapped environment.
        """
        return Box(
            low=-1.0,
            high=1.0,
            shape=(self._env.action_size,),
            dtype=jnp.float32,
        )

    def reset(self, rng: chex.PRNGKey) -> Tuple[jax.Array, MjxPlaygroundState]:
        """
        Set the environment to a starting state.

        Parameters
        ----------
        rng : chex.PRNGKey
            JAX PRNG key

        Returns
        -------
        obs : jax.Array
            Initial observation
        state : MjxPlaygroundState
            Initial environment state with `rng` embedded
        """
        # One half of the key seeds the Playground reset; the other half
        # is carried on the returned state.
        rng, init_rng = jax.random.split(rng)
        pg_state = self._env.reset(init_rng)

        state = MjxPlaygroundState(
            rng=rng,
            step=jnp.int32(0),
            done=pg_state.done.astype(jnp.bool_),
            pg_state=pg_state,
        )

        return self._extract_obs(pg_state), state

    def step(
        self,
        state: MjxPlaygroundState,
        action: jax.Array,
    ) -> Tuple[jax.Array, MjxPlaygroundState, jax.Array, jax.Array, Dict[str, Any]]:
        """
        Take an action through the environment.

        Reward, termination, and info are delegated to the `_reward`,
        `_done`, and `_info` hooks so subclasses can customise them.

        Parameters
        ----------
        state : MjxPlaygroundState
            Current environment state
        action : jax.Array
            Action to take in the environment

        Returns
        -------
        obs : jax.Array
            Observation after the step
        new_state : MjxPlaygroundState
            Updated environment state
        reward : jax.Array
            Scalar reward
        done : jax.Array
            bool scalar — `True` when the episode has ended
        info : Dict[str, Any]
            Auxiliary diagnostic information
        """
        new_pg = self._env.step(state.pg_state, action)  # type: ignore
        new_step = state.step + jnp.int32(1)

        reward = self._reward(state, action, new_pg)
        done = self._done(state, new_pg, new_step)
        # Advance the carried key; the second half of the split is unused.
        rng, _ = jax.random.split(state.rng)

        new_state = state.__replace__(
            rng=rng,
            step=new_step,
            done=done,
            pg_state=new_pg,
        )
        # The info hook receives the pre-step state, not `new_state`.
        info = self._info(state, new_pg, new_step)

        return self._extract_obs(new_pg), new_state, reward, done, info

    def render(
        self,
        state: MjxPlaygroundState,
        height: int = 240,
        width: int = 320,
    ) -> np.ndarray:
        """
        Render the environment state as an RGB frame.

        Parameters
        ----------
        state : MjxPlaygroundState
            Current environment state to render
        height : int, default 240
            Output frame height in pixels
        width : int, default 320
            Output frame width in pixels

        Returns
        -------
        frame : np.ndarray
            uint8 RGB array of shape `(height, width, 3)`
        """
        # The wrapped renderer is given a one-element list and its first
        # result is returned.
        frames = self._env.render([state.pg_state], height=height, width=width)
        return np.asarray(frames[0], dtype=np.uint8)

    def _reward(
        self,
        state: MjxPlaygroundState,
        action: jax.Array,
        new_pg: mjx_env.State,
    ) -> jax.Array:
        """
        Compute the reward for the most recent step.

        Defaults to Playground's own reward. Override to add shaping.
        `state` and `action` are unused here; they are part of the hook
        signature for overrides.

        Parameters
        ----------
        state : MjxPlaygroundState
            State before the step
        action : jax.Array
            Action just taken
        new_pg : mjx_env.State
            Playground state after the step

        Returns
        -------
        reward : jax.Array
            Scalar reward
        """
        return new_pg.reward

    def _done(
        self,
        state: MjxPlaygroundState,
        new_pg: mjx_env.State,
        new_step: jax.Array,
    ) -> jax.Array:
        """
        Compute the termination flag for the most recent step.

        Defaults to `new_pg.done OR new_step >= max_steps`, where
        `max_steps` is read from `self.config`.

        Parameters
        ----------
        state : MjxPlaygroundState
            State before the step
        new_pg : mjx_env.State
            Playground state after the step
        new_step : jax.Array
            Episode timestep after the step

        Returns
        -------
        done : jax.Array
            bool scalar — `True` when the episode has ended
        """
        return jnp.logical_or(
            new_pg.done.astype(jnp.bool_),
            new_step >= self.config.max_steps,
        )

    def _info(
        self,
        state: MjxPlaygroundState,
        new_pg: mjx_env.State,
        new_step: jax.Array,
    ) -> Dict[str, Any]:
        """
        Build the info dict returned from `step`.

        Contains `"current_step"`, `"metrics"` (Playground's metrics),
        and every entry of Playground's own `info`.

        Parameters
        ----------
        state : MjxPlaygroundState
            State before the step
        new_pg : mjx_env.State
            Playground state after the step
        new_step : jax.Array
            Episode timestep after the step

        Returns
        -------
        info : Dict[str, Any]
            Auxiliary diagnostic information
        """
        # `new_pg.info` is unpacked last, so on a key collision its
        # entries replace "current_step" / "metrics".
        return {
            "current_step": new_step,
            "metrics": new_pg.metrics,
            **new_pg.info,
        }

xml_path property

Path to the MJCF XML file backing this Playground environment.

Used by composite render scenes that need to compose multiple copies of the environment's MJCF.

Returns:

Name Type Description
xml_path Path

Absolute path to the env's MJCF file.

Raises:

Name Type Description
attr_missing AttributeError

If the underlying Playground module does not expose _XML_PATH.

observation_space property

Returns the observation space.

action_space property

Returns the action space.

reset(rng)

Set the environment to a starting state.

Parameters:

Name Type Description Default
rng PRNGKey

JAX PRNG key

required

Returns:

Name Type Description
obs Array

Initial observation

state MjxPlaygroundState

Initial environment state with rng embedded

Source code in mujorax/envs/_base.py
Python
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
def reset(self, rng: chex.PRNGKey) -> Tuple[jax.Array, MjxPlaygroundState]:
    """
    Start a new episode and return its first observation and state.

    The incoming key is split in two: one half seeds the Playground
    reset, the other is stored on the returned state.

    Parameters
    ----------
    rng : chex.PRNGKey
        JAX PRNG key

    Returns
    -------
    obs : jax.Array
        Initial observation
    state : MjxPlaygroundState
        Initial environment state: step counter at zero, `done` taken
        from the Playground state (cast to bool), and the carried half
        of `rng`
    """
    carry_key, reset_key = jax.random.split(rng)
    initial_pg = self._env.reset(reset_key)

    # Assemble the wrapper state field by field before extracting obs.
    state_fields = {
        "rng": carry_key,
        "step": jnp.int32(0),
        "done": initial_pg.done.astype(jnp.bool_),
        "pg_state": initial_pg,
    }
    initial_state = MjxPlaygroundState(**state_fields)

    initial_obs = self._extract_obs(initial_pg)
    return initial_obs, initial_state

step(state, action)

Take an action through the environment.

Parameters:

Name Type Description Default
state MjxPlaygroundState

Current environment state

required
action Array

Action to take in the environment

required

Returns:

Name Type Description
obs Array

Observation after the step

new_state MjxPlaygroundState

Updated environment state

reward Array

Scalar reward

done Array

bool scalar — True when the episode has ended

info Dict[str, Any]

Auxiliary diagnostic information

Source code in mujorax/envs/_base.py
Python
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
def step(
    self,
    state: MjxPlaygroundState,
    action: jax.Array,
) -> Tuple[jax.Array, MjxPlaygroundState, jax.Array, jax.Array, Dict[str, Any]]:
    """
    Advance the environment by one action.

    The wrapped Playground environment is stepped first; reward,
    termination flag, and info are then produced by the `_reward`,
    `_done`, and `_info` hooks.

    Parameters
    ----------
    state : MjxPlaygroundState
        Current environment state
    action : jax.Array
        Action to take in the environment

    Returns
    -------
    obs : jax.Array
        Observation after the step
    new_state : MjxPlaygroundState
        Updated environment state
    reward : jax.Array
        Scalar reward
    done : jax.Array
        bool scalar — `True` when the episode has ended
    info : Dict[str, Any]
        Auxiliary diagnostic information
    """
    pg_next = self._env.step(state.pg_state, action)  # type: ignore
    step_count = state.step + jnp.int32(1)

    # Hooks are called in the same order as before: reward, then done.
    step_reward = self._reward(state, action, pg_next)
    episode_over = self._done(state, pg_next, step_count)

    # Only the first half of the split is carried forward.
    carry_key, _discarded = jax.random.split(state.rng)

    updates = {
        "rng": carry_key,
        "step": step_count,
        "done": episode_over,
        "pg_state": pg_next,
    }
    next_state = state.__replace__(**updates)

    # The info hook is given the pre-step state, not `next_state`.
    diagnostics = self._info(state, pg_next, step_count)
    next_obs = self._extract_obs(pg_next)

    return next_obs, next_state, step_reward, episode_over, diagnostics

render(state, height=240, width=320)

Render the environment state as an RGB frame.

Parameters:

Name Type Description Default
state MjxPlaygroundState

Current environment state to render

required
height int

Output frame height in pixels

240
width int

Output frame width in pixels

320

Returns:

Name Type Description
frame ndarray

uint8 RGB array of shape (height, width, 3)

Source code in mujorax/envs/_base.py
Python
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
def render(
    self,
    state: MjxPlaygroundState,
    height: int = 240,
    width: int = 320,
) -> np.ndarray:
    """
    Produce a single RGB image of the given environment state.

    Parameters
    ----------
    state : MjxPlaygroundState
        Current environment state to render
    height : int, default 240
        Output frame height in pixels
    width : int, default 320
        Output frame width in pixels

    Returns
    -------
    frame : np.ndarray
        uint8 RGB array of shape `(height, width, 3)`
    """
    # The wrapped renderer takes a list of states; pass just this one.
    trajectory = [state.pg_state]
    rendered = self._env.render(trajectory, height=height, width=width)

    first_frame = rendered[0]
    return np.asarray(first_frame, dtype=np.uint8)

mujorax.MjxPlaygroundState

Bases: EnvState

Environment state for a wrapped MuJoCo Playground environment.

Parameters:

Name Type Description Default
rng PRNGKey

JAX PRNG key

required
step Array

Current timestep within the episode

required
done Array

bool scalar — episode termination flag

required
pg_state State

Full Playground environment state

required
Source code in mujorax/envs/_base.py
Python
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
@chex.dataclass
class MjxPlaygroundState(EnvState):
    """
    Environment state for a wrapped MuJoCo Playground environment.

    Extends `EnvState` with the underlying Playground state so the
    wrapper can step and render the wrapped environment.

    Only `pg_state` is declared in this class; `rng`, `step`, and
    `done` are presumably inherited from `EnvState` (verify against
    its definition).

    Parameters
    ----------
    rng : chex.PRNGKey
        JAX PRNG key
    step : jax.Array
        Current timestep within the episode
    done : jax.Array
        bool scalar — episode termination flag
    pg_state : mjx_env.State
        Full Playground environment state
    """

    # Full state object of the wrapped Playground environment.
    pg_state: mjx_env.State

mujorax.MjxPlaygroundConfig

Bases: EnvConfig

Static configuration for a wrapped MuJoCo Playground environment.

Parameters:

Name Type Description Default
max_steps int

Maximum number of steps per episode. Default is 1000.

1000
impl Literal['jax', 'warp']

MJX backend to use: jax uses pure JAX, warp uses the NVIDIA Warp FFI. Default is jax.

'jax'
config_overrides Dict[str, Any]

Flat overrides forwarded to mujoco_playground.registry.load. Use dotted keys for nested fields (e.g. "reward_config.scale").

{}
Source code in mujorax/envs/_base.py
Python
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
@chex.dataclass
class MjxPlaygroundConfig(EnvConfig):
    """
    Static configuration for a wrapped MuJoCo Playground environment.

    Only `impl` and `config_overrides` are declared in this class;
    `max_steps` is presumably inherited from `EnvConfig` (verify its
    default against that definition).

    Parameters
    ----------
    max_steps : int (optional)
        Maximum number of steps per episode. Default is `1000`.
    impl : Literal["jax", "warp"] (optional)
        MJX backend to use: `jax` uses pure JAX, `warp` uses the
        NVIDIA Warp FFI. Default is `jax`.
    config_overrides : Dict[str, Any] (optional)
        Flat overrides forwarded to `mujoco_playground.registry.load`.
        Use dotted keys for nested fields (e.g. `"reward_config.scale"`).
        Default is an empty dict.
    """

    # Backend selector; defaults to the pure-JAX implementation.
    impl: Literal["jax", "warp"] = "jax"
    # `default_factory` gives every config instance its own empty dict
    # instead of sharing one mutable default.
    config_overrides: Dict[str, Any] = field(default_factory=dict)