Rendering

Composite-scene renderer that visualises multiple copies of a Playground environment in a single image.

mujorax.render.stadium.StadiumRenderer

Composite-scene renderer that visualises n_slots copies of a Playground environment in a single image.

Builds a render-only MJCF by replicating the environment's source XML n_slots times, spaced along the X axis. Each render() call rasterises the composite mj_data, which is populated by update() (or update_batched()) from caller-supplied per-slot states.

No physics happens here — the renderer copies qpos / qvel into the composite mj_data and calls mj_forward to refresh derived fields before rasterising.

Parameters:

Name Type Description Default
env VecEnv | MjxPlaygroundEnv

Source of the MJCF template. When a VecEnv is supplied, its inner environment provides the MJCF and n_slots is inferred from VecEnv.n_slots. When an MjxPlaygroundEnv (optionally wrapper-wrapped) is supplied, n_slots must be given explicitly.

required
n_slots int | None

Number of agent slots in the rendered stadium. Required when env is not a VecEnv; redundant (and validated against env.n_slots) when it is.

None
spacing float

Distance (metres) between adjacent slot origins along the X axis. Default is 5.0.

5.0
height int

Render frame height in pixels. Default is 480.

480
width int

Render frame width in pixels. Default is 640.

640

Raises:

Name Type Description
n_slots_missing ValueError

If env is not a VecEnv and n_slots is not supplied.

n_slots_conflict ValueError

If env is a VecEnv and n_slots is supplied with a value that does not match env.n_slots.

wrong_env_type TypeError

If env (after unwrapping) is not an MjxPlaygroundEnv.
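
The interaction between env and n_slots can be illustrated without MuJoCo. The following is a minimal sketch that mirrors the documented rules; `FakeEnv`, `FakeVecEnv` and `resolve_slots` are hypothetical stand-ins written for this example and are not part of mujorax.

```python
from typing import Optional, Union


class FakeEnv:
    """Hypothetical stand-in for an MjxPlaygroundEnv."""


class FakeVecEnv:
    """Hypothetical stand-in for a VecEnv wrapping one inner environment."""

    def __init__(self, env: FakeEnv, n_slots: int) -> None:
        self.env = env
        self.n_slots = n_slots


def resolve_slots(env: Union[FakeVecEnv, FakeEnv], n_slots: Optional[int] = None) -> int:
    """Return the slot count that would be used, raising as documented above."""
    if isinstance(env, FakeVecEnv):
        # A supplied n_slots is only accepted when it agrees with the VecEnv.
        if n_slots is not None and n_slots != env.n_slots:
            raise ValueError(f"n_slots={n_slots} conflicts with {env.n_slots}")
        n_slots = env.n_slots
        env = env.env
    if n_slots is None:
        raise ValueError("n_slots is required when env is not a VecEnv")
    if not isinstance(env, FakeEnv):
        raise TypeError(f"unsupported env type {type(env).__name__}")
    return n_slots


vec = FakeVecEnv(FakeEnv(), n_slots=4)
print(resolve_slots(vec))           # inferred from the VecEnv
print(resolve_slots(vec, 4))        # redundant but consistent
print(resolve_slots(FakeEnv(), 2))  # explicit for a bare environment
```

Passing `resolve_slots(vec, 3)` raises ValueError in this sketch, which corresponds to the n_slots_conflict case.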

Source code in mujorax/render/stadium.py
Python
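from pathlib import Path
from typing import List

import jax
import mujoco
import numpy as np

# VecEnv, MjxPlaygroundEnv and MjxPlaygroundState are mujorax types; their
# import lines are not part of this excerpt.


class StadiumRenderer: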
    """
    Composite-scene renderer that visualises `n_slots` copies of a
    Playground environment in a single image.

    Builds a render-only MJCF by replicating the environment's source
    XML `n_slots` times, spaced along the X axis. Each `render()` call
    rasterises the composite `mj_data`, which is populated by `update()`
    (or `update_batched()`) from caller-supplied per-slot states.

    No physics happens here — the renderer copies `qpos` / `qvel` into
    the composite `mj_data` and calls `mj_forward` to refresh derived
    fields before rasterising.

    Parameters
    ----------
    env : VecEnv | MjxPlaygroundEnv
        Source of the MJCF template. When a `VecEnv` is supplied, its
        inner environment provides the MJCF and `n_slots` is inferred
        from `VecEnv.n_slots`. When an `MjxPlaygroundEnv` (optionally
        wrapper-wrapped) is supplied, `n_slots` must be given explicitly.
    n_slots : int (optional)
        Number of agent slots in the rendered stadium. Required when
        `env` is not a `VecEnv`; redundant (and validated against
        `env.n_slots`) when it is.
    spacing : float (optional)
        Distance (metres) between adjacent slot origins along the X axis.
        Default is `5.0`.
    height : int (optional)
        Render frame height in pixels. Default is `480`.
    width : int (optional)
        Render frame width in pixels. Default is `640`.

    Raises
    ------
    n_slots_missing : ValueError
        If `env` is not a `VecEnv` and `n_slots` is not supplied.
    n_slots_conflict : ValueError
        If `env` is a `VecEnv` and `n_slots` is supplied with a value
        that does not match `env.n_slots`.
    wrong_env_type : TypeError
        If `env` (after unwrapping) is not an `MjxPlaygroundEnv`.
    """

    def __init__(
        self,
        env: VecEnv | MjxPlaygroundEnv,
        n_slots: int | None = None,
        spacing: float = 5.0,
        height: int = 480,
        width: int = 640,
    ) -> None:
        env, n_slots = self._resolve_env_and_slots(env, n_slots)

        if n_slots < 1:
            raise ValueError(f"n_slots must be >= 1, got {n_slots}.")

        self._env = env
        self._n_slots = n_slots
        self._spacing = spacing
        self._height = height
        self._width = width

        self._mj_model = self._build_composite(env.xml_path, n_slots, spacing)
        self._mj_data = mujoco.MjData(self._mj_model)
        self._qpos_slots, self._qvel_slots = self._build_slot_address_tables(
            self._mj_model, n_slots
        )

        self._renderer = mujoco.Renderer(self._mj_model, height=height, width=width)

    @staticmethod
    def _resolve_env_and_slots(
        env: VecEnv | MjxPlaygroundEnv, n_slots: int | None
    ) -> tuple[MjxPlaygroundEnv, int]:
        """
        Normalise the `(env, n_slots)` constructor inputs to a concrete pair.

        Unwraps `VecEnv` to its inner environment (inferring `n_slots`),
        walks any further wrapper layers via `unwrapped`, and asserts the
        final environment is an `MjxPlaygroundEnv`.

        Parameters
        ----------
        env : VecEnv | MjxPlaygroundEnv
            Raw constructor input.
        n_slots : int | None
            Raw constructor input. When `env` is a `VecEnv` it may be omitted;
            if supplied, it is only checked against `env.n_slots`.

        Returns
        -------
        env : MjxPlaygroundEnv
            Unwrapped Playground environment that exposes `xml_path`.
        n_slots : int
            Slot count, inferred from the `VecEnv` when applicable.

        Raises
        ------
        n_slots_missing : ValueError
            If `env` is not a `VecEnv` and `n_slots` is not supplied.
        n_slots_conflict : ValueError
            If `env` is a `VecEnv` and `n_slots` is supplied with a value
            that does not match `env.n_slots`.
        wrong_env_type : TypeError
            If `env` (after unwrapping) is not an `MjxPlaygroundEnv`.
        """
        if isinstance(env, VecEnv):
            if n_slots is not None and n_slots != env.n_slots:
                raise ValueError(
                    f"n_slots={n_slots} conflicts with VecEnv.n_slots={env.n_slots}; "
                    "omit `n_slots` when passing a VecEnv."
                )
            n_slots = env.n_slots
            env = env.env  # type: ignore

        env = getattr(env, "unwrapped", env)

        if n_slots is None:
            raise ValueError("`n_slots` is required when `env` is not a VecEnv.")

        if not isinstance(env, MjxPlaygroundEnv):
            raise TypeError(
                f"`env` must resolve to an MjxPlaygroundEnv, got {type(env).__name__}."
            )

        return env, n_slots

    @property
    def n_slots(self) -> int:
        """Number of agent slots in the rendered stadium."""
        return self._n_slots

    @property
    def mj_model(self) -> mujoco.MjModel:
        """The composite scene's `mujoco.MjModel`."""
        return self._mj_model

    @property
    def mj_data(self) -> mujoco.MjData:
        """The composite scene's `mujoco.MjData`. Populated by `update*` calls."""
        return self._mj_data

    @staticmethod
    def _build_composite(
        xml_path: Path, n_slots: int, spacing: float
    ) -> mujoco.MjModel:
        """
        Compose `n_slots` replicas of the env's MJCF into one scene.

        Parameters
        ----------
        xml_path : Path
            Path to the template environment's MJCF XML file.
        n_slots : int
            Number of replicas to attach to the composite scene.
        spacing : float
            Distance (metres) between adjacent slot origins along the X axis.

        Returns
        -------
        mj_model : mujoco.MjModel
            Compiled composite scene with one floor plane plus `n_slots`
            attached copies of the source MJCF, each prefixed `a{i}_`.
        """
        base = mujoco.MjSpec.from_file(str(xml_path))
        stadium = mujoco.MjSpec()
        stadium.option.timestep = base.option.timestep
        stadium.worldbody.add_geom(
            type=mujoco.mjtGeom.mjGEOM_PLANE,
            size=[50.0, 50.0, 0.1],
            rgba=[0.5, 0.5, 0.55, 1.0],
            contype=0,
            conaffinity=0,
        )
        for i in range(n_slots):
            child = mujoco.MjSpec.from_file(str(xml_path))
            frame = stadium.worldbody.add_frame(pos=[i * spacing, 0.0, 0.0])
            stadium.attach(child, prefix=f"a{i}_", frame=frame)

        return stadium.compile()

    @staticmethod
    def _build_slot_address_tables(
        mj_model: mujoco.MjModel, n_slots: int
    ) -> tuple[list[list[int]], list[list[int]]]:
        """
        Build per-slot `qpos` / `qvel` index lists by walking joint names.

        Joints are matched to slots via their `a{i}_` prefix (set by
        `_build_composite`). Joint type determines the per-joint
        `qpos` / `qvel` widths.

        Parameters
        ----------
        mj_model : mujoco.MjModel
            Compiled composite scene.
        n_slots : int
            Number of slots whose addresses to extract.

        Returns
        -------
        qpos_slots : list[list[int]]
            For each slot, the list of `qpos` indices belonging to its joints.
        qvel_slots : list[list[int]]
            For each slot, the list of `qvel` indices belonging to its joints.
        """
        qpos_slots: list[list[int]] = []
        qvel_slots: list[list[int]] = []
        for i in range(n_slots):
            prefix = f"a{i}_"
            q_idx, v_idx = [], []
            for j in range(mj_model.njnt):
                jname = mujoco.mj_id2name(mj_model, mujoco.mjtObj.mjOBJ_JOINT, j)
                if jname is None or not jname.startswith(prefix):
                    continue
                qadr = mj_model.jnt_qposadr[j]
                vadr = mj_model.jnt_dofadr[j]
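                jtype = int(mj_model.jnt_type[j])
                # mjtJoint values: 0 = free, 1 = ball, 2 = slide, 3 = hinge.
                # Free joints use 7 qpos / 6 qvel entries, ball joints 4 / 3,
                # slide and hinge joints 1 / 1.
                qsize = {0: 7, 1: 4, 2: 1, 3: 1}[jtype]
                vsize = {0: 6, 1: 3, 2: 1, 3: 1}[jtype]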
                q_idx.extend(range(qadr, qadr + qsize))
                v_idx.extend(range(vadr, vadr + vsize))
            qpos_slots.append(q_idx)
            qvel_slots.append(v_idx)

        return qpos_slots, qvel_slots

    def update(self, states: List[MjxPlaygroundState]) -> None:
        """
        Populate the composite `mj_data` from a list of single-env states.

        Parameters
        ----------
        states : List[MjxPlaygroundState]
            One state per slot, in slot-index order. Length must equal `n_slots`.

        Raises
        ------
        length_mismatch : ValueError
            If `len(states) != n_slots`.
        """
        if len(states) != self._n_slots:
            raise ValueError(
                f"StadiumRenderer.update expected {self._n_slots} states, "
                f"got {len(states)}."
            )

        for i, state in enumerate(states):
            self._copy_slot_state(i, state)

        mujoco.mj_forward(self._mj_model, self._mj_data)

    def update_batched(self, batched_state: MjxPlaygroundState) -> None:
        """
        Populate the composite `mj_data` from a batched (vmapped) state
        whose leading dim equals `n_slots`.

        Convenience for `VecEnv`-style states without manually unstacking.

        Parameters
        ----------
        batched_state : MjxPlaygroundState
            Single state pytree with leading batch dimension of size `n_slots`.
        """
        for i in range(self._n_slots):
            slot = jax.tree.map(lambda x, i=i: x[i], batched_state)
            self._copy_slot_state(i, slot)

        mujoco.mj_forward(self._mj_model, self._mj_data)

    def render(self) -> np.ndarray:
        """
        Render the full composite scene.

        Returns
        -------
        frame : np.ndarray
            uint8 RGB array of shape `(height, width, 3)`.
        """
        self._renderer.update_scene(self._mj_data)
        return self._renderer.render()

    def _copy_slot_state(self, slot_idx: int, state: MjxPlaygroundState) -> None:
        """
        Copy one slot's `qpos` / `qvel` into the composite `mj_data`.

        Does not refresh derived fields — callers should invoke
        `mj_forward` once after copying all slots.

        Parameters
        ----------
        slot_idx : int
            Target slot index in `[0, n_slots)`.
        state : MjxPlaygroundState
            Single-env state whose `qpos` / `qvel` populate the slot.
        """
        qpos = np.asarray(state.pg_state.data.qpos)
        qvel = np.asarray(state.pg_state.data.qvel)
        self._mj_data.qpos[self._qpos_slots[slot_idx]] = qpos
        self._mj_data.qvel[self._qvel_slots[slot_idx]] = qvel

    def __repr__(self) -> str:
        return (
            f"StadiumRenderer<{self._env.name}, "
            f"n_slots={self._n_slots}, "
            f"size={self._width}x{self._height}>"
        )
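
The per-slot address tables built by `_build_slot_address_tables` can be illustrated in plain Python. This is a sketch, not the library's implementation: it assumes joints are packed into qpos / qvel contiguously in model order, whereas the real method reads the addresses from `jnt_qposadr` and `jnt_dofadr`. The joint names and the helper `slot_address_tables` are made up for the example.

```python
# (qpos width, qvel width) for each MuJoCo joint type.
JOINT_WIDTHS = {"free": (7, 6), "ball": (4, 3), "slide": (1, 1), "hinge": (1, 1)}


def slot_address_tables(joints, n_slots):
    """Map each slot to its qpos / qvel indices.

    joints is a list of (name, type) pairs in model order; a joint belongs
    to slot i when its name starts with the prefix "a{i}_".
    """
    qpos_slots = [[] for _ in range(n_slots)]
    qvel_slots = [[] for _ in range(n_slots)]
    qadr = vadr = 0  # running start addresses (assumed contiguous packing)
    for name, jtype in joints:
        qsize, vsize = JOINT_WIDTHS[jtype]
        for i in range(n_slots):
            if name.startswith(f"a{i}_"):
                qpos_slots[i].extend(range(qadr, qadr + qsize))
                qvel_slots[i].extend(range(vadr, vadr + vsize))
        qadr += qsize
        vadr += vsize
    return qpos_slots, qvel_slots


joints = [
    ("a0_root", "free"),
    ("a0_knee", "hinge"),
    ("a1_root", "free"),
    ("a1_knee", "hinge"),
]
qpos_slots, qvel_slots = slot_address_tables(joints, n_slots=2)
print(qpos_slots[1])  # [8, 9, 10, 11, 12, 13, 14, 15]
print(qvel_slots[1])  # [7, 8, 9, 10, 11, 12, 13]
```

Each slot ends up with 8 qpos and 7 qvel indices here, because a free joint stores a quaternion in qpos (7 values) but only 6 velocity components.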

n_slots property

Number of agent slots in the rendered stadium.

mj_model property

The composite scene's mujoco.MjModel.

mj_data property

The composite scene's mujoco.MjData. Populated by update* calls.

update(states)

Populate the composite mj_data from a list of single-env states.

Parameters:

Name Type Description Default
states List[MjxPlaygroundState]

One state per slot, in slot-index order. Length must equal n_slots.

required

Raises:

Name Type Description
length_mismatch ValueError

If len(states) != n_slots.
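
The copy step behind update() amounts to scattering each slot's vector into the composite array through that slot's index list. A minimal NumPy sketch under that assumption follows; `scatter_slots` and the toy sizes are hypothetical, and the real method additionally calls `mj_forward` afterwards.

```python
import numpy as np


def scatter_slots(composite, slot_indices, per_slot_values):
    """Write one vector per slot into composite at that slot's indices."""
    if len(per_slot_values) != len(slot_indices):
        raise ValueError(
            f"expected {len(slot_indices)} states, got {len(per_slot_values)}"
        )
    for idx, values in zip(slot_indices, per_slot_values):
        composite[idx] = values  # integer-list indexing assigns in place
    return composite


# Toy layout: three slots with two qpos entries each.
qpos_slots = [[0, 1], [2, 3], [4, 5]]
composite_qpos = np.zeros(6)
per_slot_qpos = [np.array([0.1, 0.2]), np.array([1.1, 1.2]), np.array([2.1, 2.2])]

scatter_slots(composite_qpos, qpos_slots, per_slot_qpos)
print(composite_qpos)
```

Supplying fewer vectors than slots raises ValueError in the sketch, matching the length_mismatch case above.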

Source code in mujorax/render/stadium.py
Python
def update(self, states: List[MjxPlaygroundState]) -> None:
    """
    Populate the composite `mj_data` from a list of single-env states.

    Parameters
    ----------
    states : List[MjxPlaygroundState]
        One state per slot, in slot-index order. Length must equal `n_slots`.

    Raises
    ------
    length_mismatch : ValueError
        If `len(states) != n_slots`.
    """
    if len(states) != self._n_slots:
        raise ValueError(
            f"StadiumRenderer.update expected {self._n_slots} states, "
            f"got {len(states)}."
        )

    for i, state in enumerate(states):
        self._copy_slot_state(i, state)

    mujoco.mj_forward(self._mj_model, self._mj_data)

update_batched(batched_state)

Populate the composite mj_data from a batched (vmapped) state whose leading dim equals n_slots.

Convenience for VecEnv-style states without manually unstacking.

Parameters:

Name Type Description Default
batched_state MjxPlaygroundState

Single state pytree with leading batch dimension of size n_slots.

required

Source code in mujorax/render/stadium.py
Python
def update_batched(self, batched_state: MjxPlaygroundState) -> None:
    """
    Populate the composite `mj_data` from a batched (vmapped) state
    whose leading dim equals `n_slots`.

    Convenience for `VecEnv`-style states without manually unstacking.

    Parameters
    ----------
    batched_state : MjxPlaygroundState
        Single state pytree with leading batch dimension of size `n_slots`.
    """
    for i in range(self._n_slots):
        slot = jax.tree.map(lambda x, i=i: x[i], batched_state)
        self._copy_slot_state(i, slot)

    mujoco.mj_forward(self._mj_model, self._mj_data)
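
The per-slot unstacking in update_batched() indexes every leaf of the state pytree along its leading axis. The sketch below imitates that with NumPy and nested dicts so it runs without JAX; `take_slot` is a hypothetical helper, and the dict layout is only an assumption standing in for the real state type.

```python
import numpy as np


def take_slot(tree, i):
    """Return a copy of a nested dict with every array leaf indexed at i."""
    if isinstance(tree, dict):
        return {key: take_slot(value, i) for key, value in tree.items()}
    return tree[i]


# Batched toy state: leading dimension 3 plays the role of n_slots.
batched = {
    "data": {
        "qpos": np.arange(6.0).reshape(3, 2),
        "qvel": np.zeros((3, 2)),
    }
}

slot1 = take_slot(batched, 1)
print(slot1["data"]["qpos"])  # [2. 3.]
print(slot1["data"]["qvel"].shape)  # (2,)
```

With JAX, the same effect comes from `jax.tree.map(lambda x: x[i], batched_state)`, as in the source listing.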

render()

Render the full composite scene.

Returns:

Name Type Description
frame ndarray

uint8 RGB array of shape (height, width, 3).

Source code in mujorax/render/stadium.py
Python
def render(self) -> np.ndarray:
    """
    Render the full composite scene.

    Returns
    -------
    frame : np.ndarray
        uint8 RGB array of shape `(height, width, 3)`.
    """
    self._renderer.update_scene(self._mj_data)
    return self._renderer.render()