Skip to content

Base

The base class, its state and config for all MjxPlaygroundEnv environments used in Mujorax.

mujorax.MjxPlaygroundEnv

Bases: JaxEnv[Box, Box, MjxPlaygroundState, MjxPlaygroundConfig]

Base wrapper that exposes a mujoco_playground environment via Envrax's JaxEnv API.

Subclasses set _PLAYGROUND_NAME to a name accepted by mujoco_playground.registry.load. Override _reward, _done, or _info to customise per-env behaviour.

Parameters:

Name Type Description Default
config MjxPlaygroundConfig

Static configuration. Defaults to MjxPlaygroundConfig().

None
Source code in mujorax/envs/_base.py
Python
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
class MjxPlaygroundEnv(JaxEnv[Box, Box, MjxPlaygroundState, MjxPlaygroundConfig]):
    """
    Adapter that presents a `mujoco_playground` environment through
    Envrax's `JaxEnv` interface.

    Concrete environments subclass this and assign `_PLAYGROUND_NAME` a
    name that `mujoco_playground.registry.load` recognises. The `_reward`,
    `_done` and `_info` hooks may be overridden to tailor per-env behaviour.

    Parameters
    ----------
    config : MjxPlaygroundConfig (optional)
        Static configuration, forwarded unchanged to the `JaxEnv`
        constructor. Defaults to `MjxPlaygroundConfig()`.
    """

    # Registry name of the wrapped Playground env; subclasses must set it.
    _PLAYGROUND_NAME: str = ""

    def __init__(self, config: MjxPlaygroundConfig | None = None) -> None:
        """
        Load the wrapped Playground environment and validate its
        observation layout.

        Parameters
        ----------
        config : MjxPlaygroundConfig (optional)
            Static configuration handed to `JaxEnv.__init__`

        Raises
        ------
        error : ValueError
            If the subclass left `_PLAYGROUND_NAME` empty.
        error : NotImplementedError
            If the loaded environment reports a non-integer
            observation size (dict observations).
        """
        if not self._PLAYGROUND_NAME:
            raise ValueError(f"{type(self).__name__} must set `_PLAYGROUND_NAME`.")

        super().__init__(config)

        load_overrides = self._resolve_overrides()
        self._env = mujoco_playground.registry.load(
            self._PLAYGROUND_NAME,
            config_overrides=load_overrides,
        )

        # Accessed only for its validation side effect: the property raises
        # NotImplementedError when the env produces dict observations.
        _ = self.observation_space

    def _resolve_overrides(self) -> Dict[str, Any] | None:
        """
        Assemble the overrides handed to `mujoco_playground.registry.load`.

        The user's `config.config_overrides` are copied, never mutated.
        When the user has not pinned `impl` and no CUDA backend is
        detected, `impl="jax"` is injected, because Playground's default
        of `impl="warp"` requires CUDA.

        Returns
        -------
        overrides : Dict[str, Any] | None
            The resolved overrides, or `None` if there are none.
        """
        user_overrides = self.config.config_overrides
        resolved: Dict[str, Any] = dict(user_overrides) if user_overrides else {}

        # `_has_cuda()` is only consulted when the user did not pin `impl`.
        impl_settled = "impl" in resolved or _has_cuda()
        if not impl_settled:
            resolved["impl"] = "jax"

        return resolved if resolved else None

    def _extract_obs(self, pg_state: mjx_env.State) -> chex.Array:
        """
        Pull the observation array out of a Playground state.

        Dict observations are already refused when the env is built; this
        is the runtime guard that narrows `pg_state.obs` to a single array.

        Parameters
        ----------
        pg_state : mjx_env.State
            Playground state to read the observation from

        Returns
        -------
        obs : chex.Array
            The observation array

        Raises
        ------
        error : TypeError
            If `pg_state.obs` is not a `jax.Array`.
        """
        observation = pg_state.obs
        if isinstance(observation, jax.Array):
            return observation

        raise TypeError(
            f"Expected `pg_state.obs` to be a `jax.Array`, got "
            f"{type(observation).__name__}. Dict observations are not supported "
            "in this release."
        )

    @property
    def observation_space(self) -> Box:
        """
        Unbounded float32 `Box` whose length is the wrapped env's
        `observation_size`.

        Raises
        ------
        error : NotImplementedError
            If `observation_size` is not an `int` (dict observations).
        """
        obs_size = self._env.observation_size
        if isinstance(obs_size, int):
            return Box(
                low=-jnp.inf,
                high=jnp.inf,
                shape=(obs_size,),
                dtype=jnp.float32,
            )

        raise NotImplementedError(
            f"{type(self).__name__} produces dict-shaped observations "
            f"({obs_size}); not supported in this release."
        )

    @property
    def action_space(self) -> Box:
        """
        Float32 `Box` bounded to `[-1, 1]` whose length is the wrapped
        env's `action_size`.
        """
        n_actions = self._env.action_size
        return Box(low=-1.0, high=1.0, shape=(n_actions,), dtype=jnp.float32)

    def reset(self, rng: chex.PRNGKey) -> Tuple[chex.Array, MjxPlaygroundState]:
        """
        Put the environment into a fresh starting state.

        Parameters
        ----------
        rng : chex.PRNGKey
            JAX PRNG key; split once, one half seeds the Playground reset
            and the other is stored in the returned state

        Returns
        -------
        obs : chex.Array
            Initial observation
        state : MjxPlaygroundState
            Initial environment state carrying the remaining key
        """
        carry_key, init_key = jax.random.split(rng)
        initial_pg = self._env.reset(init_key)

        initial_state = MjxPlaygroundState(
            rng=carry_key,
            step=jnp.int32(0),
            done=initial_pg.done.astype(jnp.bool_),
            pg_state=initial_pg,
        )
        initial_obs = self._extract_obs(initial_pg)

        return initial_obs, initial_state

    def step(
        self,
        state: MjxPlaygroundState,
        action: chex.Array,
    ) -> Tuple[chex.Array, MjxPlaygroundState, chex.Array, chex.Array, Dict[str, Any]]:
        """
        Advance the environment by one action.

        Parameters
        ----------
        state : MjxPlaygroundState
            Environment state before the step
        action : chex.Array
            Action applied to the wrapped environment

        Returns
        -------
        obs : chex.Array
            Observation after the step
        new_state : MjxPlaygroundState
            Environment state after the step
        reward : chex.Array
            Value produced by the `_reward` hook
        done : chex.Array
            Value produced by the `_done` hook
        info : Dict[str, Any]
            Value produced by the `_info` hook
        """
        next_pg = self._env.step(state.pg_state, action)  # type: ignore
        next_step = state.step + jnp.int32(1)

        # Hooks receive the pre-step wrapper state and the post-step
        # Playground state.
        reward = self._reward(state, action, next_pg)
        done = self._done(state, next_pg, next_step)

        # Advance the stored key; the second half of the split is unused.
        next_key = jax.random.split(state.rng)[0]
        new_state = state.__replace__(
            rng=next_key,
            step=next_step,
            done=done,
            pg_state=next_pg,
        )

        info = self._info(state, next_pg, next_step)
        next_obs = self._extract_obs(next_pg)

        return next_obs, new_state, reward, done, info

    def render(
        self,
        state: MjxPlaygroundState,
        height: int = 240,
        width: int = 320,
    ) -> np.ndarray:
        """
        Produce a single frame for the given environment state.

        Parameters
        ----------
        state : MjxPlaygroundState
            Environment state to render
        height : int, default 240
            Frame height passed to the wrapped env's `render`
        width : int, default 320
            Frame width passed to the wrapped env's `render`

        Returns
        -------
        frame : np.ndarray
            First frame returned by the wrapped env, converted to uint8
        """
        # The wrapped renderer is given a one-element list of states.
        trajectory = [state.pg_state]
        rendered = self._env.render(trajectory, height=height, width=width)
        first_frame = rendered[0]

        return np.asarray(first_frame, dtype=np.uint8)

    def _reward(
        self,
        state: MjxPlaygroundState,
        action: chex.Array,
        new_pg: mjx_env.State,
    ) -> chex.Array:
        """
        Reward hook for the step that just ran.

        The default passes Playground's own reward through untouched;
        subclasses can override it, e.g. to add shaping terms.

        Parameters
        ----------
        state : MjxPlaygroundState
            State before the step (unused by the default)
        action : chex.Array
            Action just taken (unused by the default)
        new_pg : mjx_env.State
            Playground state after the step

        Returns
        -------
        reward : chex.Array
            `new_pg.reward`
        """
        playground_reward = new_pg.reward
        return playground_reward

    def _done(
        self,
        state: MjxPlaygroundState,
        new_pg: mjx_env.State,
        new_step: chex.Array,
    ) -> chex.Array:
        """
        Termination hook for the step that just ran.

        The default ends the episode when Playground reports `done` or
        when `new_step` reaches `config.max_steps`.

        Parameters
        ----------
        state : MjxPlaygroundState
            State before the step (unused by the default)
        new_pg : mjx_env.State
            Playground state after the step
        new_step : chex.Array
            Episode timestep after the step

        Returns
        -------
        done : chex.Array
            Logical OR of the two termination conditions
        """
        env_terminated = new_pg.done.astype(jnp.bool_)
        hit_step_limit = new_step >= self.config.max_steps

        return jnp.logical_or(env_terminated, hit_step_limit)

    def _info(
        self,
        state: MjxPlaygroundState,
        new_pg: mjx_env.State,
        new_step: chex.Array,
    ) -> Dict[str, Any]:
        """
        Info hook: assemble the diagnostics dict returned from `step`.

        Parameters
        ----------
        state : MjxPlaygroundState
            State before the step (unused by the default)
        new_pg : mjx_env.State
            Playground state after the step
        new_step : chex.Array
            Episode timestep after the step

        Returns
        -------
        info : Dict[str, Any]
            `current_step`, `metrics`, and every entry of `new_pg.info`
        """
        info: Dict[str, Any] = {
            "current_step": new_step,
            "metrics": new_pg.metrics,
        }
        # Merged last, so a Playground entry with the same key replaces
        # the wrapper's value.
        info.update(new_pg.info)

        return info

observation_space property

Returns the observation space.

action_space property

Returns the action space.

reset(rng)

Set the environment to a starting state.

Parameters:

Name Type Description Default
rng PRNGKey

JAX PRNG key

required

Returns:

Name Type Description
obs Array

Initial observation

state MjxPlaygroundState

Initial environment state with rng embedded

Source code in mujorax/envs/_base.py
Python
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
def reset(self, rng: chex.PRNGKey) -> Tuple[chex.Array, MjxPlaygroundState]:
    """
    Put the environment into a fresh starting state.

    Parameters
    ----------
    rng : chex.PRNGKey
        JAX PRNG key; split once, one half seeds the Playground reset
        and the other is stored in the returned state

    Returns
    -------
    obs : chex.Array
        Initial observation
    state : MjxPlaygroundState
        Initial environment state carrying the remaining key
    """
    carry_key, init_key = jax.random.split(rng)
    initial_pg = self._env.reset(init_key)

    initial_state = MjxPlaygroundState(
        rng=carry_key,
        step=jnp.int32(0),
        done=initial_pg.done.astype(jnp.bool_),
        pg_state=initial_pg,
    )
    initial_obs = self._extract_obs(initial_pg)

    return initial_obs, initial_state

step(state, action)

Take an action through the environment.

Parameters:

Name Type Description Default
state MjxPlaygroundState

Current environment state

required
action Array

Action to take in the environment

required

Returns:

Name Type Description
obs Array

Observation after the step

new_state MjxPlaygroundState

Updated environment state

reward Array

Scalar reward

done Array

bool scalar — True when the episode has ended

info Dict[str, Any]

Auxiliary diagnostic information

Source code in mujorax/envs/_base.py
Python
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
def step(
    self,
    state: MjxPlaygroundState,
    action: chex.Array,
) -> Tuple[chex.Array, MjxPlaygroundState, chex.Array, chex.Array, Dict[str, Any]]:
    """
    Advance the environment by one action.

    Parameters
    ----------
    state : MjxPlaygroundState
        Environment state before the step
    action : chex.Array
        Action applied to the wrapped environment

    Returns
    -------
    obs : chex.Array
        Observation after the step
    new_state : MjxPlaygroundState
        Environment state after the step
    reward : chex.Array
        Value produced by the `_reward` hook
    done : chex.Array
        Value produced by the `_done` hook
    info : Dict[str, Any]
        Value produced by the `_info` hook
    """
    next_pg = self._env.step(state.pg_state, action)  # type: ignore
    next_step = state.step + jnp.int32(1)

    # Hooks receive the pre-step wrapper state and the post-step
    # Playground state.
    reward = self._reward(state, action, next_pg)
    done = self._done(state, next_pg, next_step)

    # Advance the stored key; the second half of the split is unused.
    next_key = jax.random.split(state.rng)[0]
    new_state = state.__replace__(
        rng=next_key,
        step=next_step,
        done=done,
        pg_state=next_pg,
    )

    info = self._info(state, next_pg, next_step)
    next_obs = self._extract_obs(next_pg)

    return next_obs, new_state, reward, done, info

render(state, height=240, width=320)

Render the environment state as an RGB frame.

Parameters:

Name Type Description Default
state MjxPlaygroundState

Current environment state to render

required
height int

Output frame height in pixels

240
width int

Output frame width in pixels

320

Returns:

Name Type Description
frame ndarray

uint8 RGB array of shape (height, width, 3)

Source code in mujorax/envs/_base.py
Python
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
def render(
    self,
    state: MjxPlaygroundState,
    height: int = 240,
    width: int = 320,
) -> np.ndarray:
    """
    Produce a single frame for the given environment state.

    Parameters
    ----------
    state : MjxPlaygroundState
        Environment state to render
    height : int, default 240
        Frame height passed to the wrapped env's `render`
    width : int, default 320
        Frame width passed to the wrapped env's `render`

    Returns
    -------
    frame : np.ndarray
        First frame returned by the wrapped env, converted to uint8
    """
    # The wrapped renderer is given a one-element list of states.
    trajectory = [state.pg_state]
    rendered = self._env.render(trajectory, height=height, width=width)
    first_frame = rendered[0]

    return np.asarray(first_frame, dtype=np.uint8)

mujorax.MjxPlaygroundState

Bases: EnvState

Environment state for a wrapped MuJoCo Playground environment.

Parameters:

Name Type Description Default
rng PRNGKey

JAX PRNG key

required
step Array

Current timestep within the episode

required
done Array

bool scalar — episode termination flag

required
pg_state State

Full Playground environment state

required
Source code in mujorax/envs/_base.py
Python
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
@chex.dataclass
class MjxPlaygroundState(EnvState):
    """
    Environment state for a wrapped MuJoCo Playground environment.

    Only `pg_state` is declared on this class. `rng`, `step` and `done`
    are not declared here and are presumably inherited from `EnvState`
    (NOTE(review): confirm against the base class).

    Parameters
    ----------
    rng : chex.PRNGKey
        JAX PRNG key
    step : chex.Array
        Current timestep within the episode
    done : chex.Array
        bool scalar — episode termination flag
    pg_state : mjx_env.State
        Full Playground environment state
    """

    # The wrapped Playground state, stored alongside the wrapper's own fields.
    pg_state: mjx_env.State

mujorax.MjxPlaygroundConfig

Bases: EnvConfig

Static configuration for a wrapped MuJoCo Playground environment.

Parameters:

Name Type Description Default
max_steps int

Maximum number of steps per episode. Default is 1000.

1000
config_overrides Dict[str, Any]

Flat overrides forwarded to mujoco_playground.registry.load. Use dotted keys for nested fields (e.g. "reward_config.scale").

{}
Source code in mujorax/envs/_base.py
Python
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
@chex.dataclass
class MjxPlaygroundConfig(EnvConfig):
    """
    Static configuration for a wrapped MuJoCo Playground environment.

    Only `config_overrides` is declared on this class. `max_steps` is not
    declared here and is presumably inherited from `EnvConfig`
    (NOTE(review): confirm the field and its default against the base class).

    Parameters
    ----------
    max_steps : int
        Maximum number of steps per episode. Default is 1000.
    config_overrides : Dict[str, Any]
        Flat overrides forwarded to `mujoco_playground.registry.load`.
        Use dotted keys for nested fields (e.g. `"reward_config.scale"`).
        Defaults to an empty dict.
    """

    # `default_factory` gives every instance its own empty dict rather than
    # one shared mutable default.
    config_overrides: Dict[str, Any] = field(default_factory=dict)