"""Sim module for managing MuJoCo simulations.
This module provides a `Simulation` class to handle MuJoCo simulations, including
loading models, running simulations, capturing data, and rendering frames.
"""
from __future__ import annotations
import os
import sys
import threading
import time
import warnings
from collections import defaultdict
from contextlib import nullcontext
from multiprocessing import cpu_count
from pathlib import Path
from typing import TYPE_CHECKING, Any, Self, TypeAlias
import defusedxml.ElementTree as ET
import imageio.v3 as iio
import matplotlib.pyplot as plt
import mujoco
import mujoco.viewer
import numpy as np
import yaml
from IPython.display import HTML, clear_output
from matplotlib import animation
from tqdm.auto import tqdm
from .builder import Builder
from .loader import Loader
if TYPE_CHECKING:
    # Only referenced inside annotations (postponed via the __future__ import).
    from collections.abc import Callable

# Short aliases used in the annotations below.
mjModel: TypeAlias = mujoco.MjModel  # pylint: disable=E1101 # noqa: N816
mjData: TypeAlias = mujoco.MjData  # pylint: disable=E1101 # noqa: N816

# pylint: disable=E1101
# MuJoCo object categories, in the order name lookups probe them
# (bodies first, plugins last).
_MJ_OBJ_TYPES = [
    getattr(mujoco.mjtObj, f"mjOBJ_{suffix}")
    for suffix in (
        "BODY",
        "JOINT",
        "GEOM",
        "SITE",
        "CAMERA",
        "LIGHT",
        "MESH",
        "HFIELD",
        "TEXTURE",
        "MATERIAL",
        "PAIR",
        "EXCLUDE",
        "EQUALITY",
        "TENDON",
        "ACTUATOR",
        "SENSOR",
        "NUMERIC",
        "TEXT",
        "KEY",
        "PLUGIN",
    )
]
class Simulation:
    """Simulation class for managing MuJoCo simulations.

    Builds an ``MjModel``/``MjData`` pair from one or more XML sources or
    ``Builder`` objects and offers batch execution (``run``), a live viewer
    (``launch``), playback/export of rendered frames (``show``/``save``) and
    access to the data captured during ``run``.
    """

    # Key codes that close the live viewer opened by ``launch``/``_window``.
    # NOTE(review): mujoco.viewer is expected to pass GLFW key codes
    # (ESC = 256, letters as upper-case ASCII); 27 and ord("q") are kept from
    # the previous implementation — confirm against the viewer docs.
    _QUIT_KEYS = frozenset({27, 256, ord("q"), ord("Q")})

    def __new__(cls, *args: Any, **kwargs: Any) -> Self:
        """Create the instance, optionally clearing the console first.

        The console is cleared when the package-level ``PROGRESS_BAR_ENABLED``
        flag is set, unless ``clear_screen=False`` is passed.
        """
        from . import PROGRESS_BAR_ENABLED  # pylint: disable=E0401

        if PROGRESS_BAR_ENABLED and kwargs.get("clear_screen", True):
            os.system("clear || cls")  # noqa: S605, S607 # Clear the console
            clear_output(wait=True)
        return super().__new__(cls)

    # pylint: disable=E1101
    def __init__(
        self,
        *xml_args: str | Builder,
        duration: int = 10,
        data_rate: int = 100,
        fps: int = 30,
        resolution: tuple[int, int] | None = None,
        initial_conditions: dict[str, list] | None = None,
        keyframe: int | None = None,
        controller: Callable[[mjModel, mjData, Any], None] | None = None,
        meshdir: str = "meshes/",
        **kwargs: Any,
    ) -> None:
        """Initialize the Simulation class for managing MuJoCo simulations.

        Args:
            xml_args (str | Builder): One or more XML file paths, XML strings,
                or Builder objects defining the model.
            duration (int, optional): Duration of the simulation in seconds.
                Defaults to 10.
            data_rate (int, optional): Data capture rate in captures per
                simulated second. Defaults to 100.
            fps (int, optional): Frames per second for rendering. Defaults to 30.
            resolution (tuple[int, int] | None, optional): Resolution of the
                simulation in pixels (width, height). If None, defaults to
                values from the XML or (400, 300).
            initial_conditions (dict[str, list] | None, optional): MjData
                attribute values applied before each run.
            keyframe (int | None, optional): Keyframe index used to reset the
                simulation at the start of each run.
            controller (Callable[[mjModel, mjData, Any], None] | None, optional):
                Custom controller installed as the MuJoCo control callback.
            meshdir (str, optional): Directory containing mesh files for URDF
                models. Defaults to "meshes/".
            **kwargs: Optional overrides ``ts`` (timestep) and ``gravity``;
                ``clear_screen`` is consumed by ``__new__``.

        Raises:
            ValueError: If no XML arguments are provided or a parameter is
                invalid.
        """
        if not xml_args:
            msg = "At least one XML file, string, or Builder is required."
            raise ValueError(msg)

        self._builder = Builder.merge(xml_args, meshdir=meshdir)
        self._meshdir = meshdir
        self._loader = Loader(self._builder)
        # Validate meshes after loading
        self._loader.validate_meshes()

        self.xml = self._loader.xml
        self._model = self._loader.model
        # Viewer threads started by launch(); joined in __exit__.
        self._threads: list[threading.Thread] = []

        # Model-level overrides are applied first because the data_rate check
        # below depends on the final timestep.
        # TODO(#8): @MGross21 Currently Causing Bugs when occluded from XML
        self.ts = kwargs.get("ts", self._model.opt.timestep)
        self.gravity = kwargs.get("gravity", self._model.opt.gravity)

        # Simulation Parameters
        self.duration = duration
        self.fps = fps
        self.data_rate = data_rate
        self.controller = controller
        self.resolution = resolution or self._extract_resolution()

        self._data = mujoco.MjData(self._model)
        self.keyframe = keyframe  # validated against model.nkey
        self.initial_conditions = initial_conditions or {}  # **after data**

        # Populated by run()
        self._frames: np.ndarray | None = None
        self._captured_data: _SimulationData | None = None
        self._initialize_names()

    def _initialize_names(self) -> None:
        """Populate body, geom, joint, and actuator name lists from the model."""
        model = self._model
        self.body_names = [model.body(i).name for i in range(model.nbody)]
        self.geom_names = [model.geom(i).name for i in range(model.ngeom)]
        self.joint_names = [model.joint(i).name for i in range(model.njnt)]
        self.actuator_names = [model.actuator(i).name for i in range(model.nu)]

    def _apply_initial_conditions(self) -> None:
        """Write the stored initial conditions into the current MjData."""
        for key, value in getattr(self, "init_conditions", {}).items():
            if hasattr(self._data, key):
                setattr(self._data, key, value)

    def _extract_resolution(self) -> tuple[int, int]:
        """Read ``visual/global`` offwidth/offheight from the XML.

        Returns (400, 300) when the tag is absent or the XML/values cannot be
        parsed; a missing attribute falls back to 400 or 300 individually.
        """
        try:
            root = ET.fromstring(self.xml)
            global_tag = root.find("visual/global")
            if global_tag is not None:
                offwidth = int(global_tag.get("offwidth", 400))
                offheight = int(global_tag.get("offheight", 300))
                return (offwidth, offheight)
        except (ET.ParseError, ValueError, TypeError):
            pass
        return (400, 300)

    def reload(self: Simulation) -> Simulation:
        """Rebuild the model and data objects from ``self.xml``.

        Names are re-read from the new model and the stored initial
        conditions are re-applied to the fresh MjData.

        Returns:
            Simulation: Self for method chaining.
        """
        # NOTE(review): overrides written into the previous model (resolution,
        # ts, gravity) are not carried over to the reloaded one — confirm.
        loader = Loader(self.xml, meshdir=self._meshdir)
        self._model = loader.model
        self._data = mujoco.MjData(self._model)
        self._initialize_names()
        self._apply_initial_conditions()
        return self

    def __str__(self) -> str:  # noqa: D105
        return self._model.__str__()

    def __repr__(self) -> str:  # noqa: D105
        MAX_LINE_ITEMS = 5  # noqa: N806 # pylint: disable=C0103

        def _truncate(names: list[str]) -> list[str]:
            # Show at most MAX_LINE_ITEMS names, marking any overflow.
            extra = ["..."] if len(names) > MAX_LINE_ITEMS else []
            return names[:MAX_LINE_ITEMS] + extra

        controller = self.controller
        # partials / callable objects have no __name__; fall back to the type.
        controller_name = (
            None
            if controller is None
            else getattr(controller, "__name__", type(controller).__name__)
        )
        return (
            f"{self.__class__.__name__}(\n"
            f" Duration: {self.duration}s "
            f"[fps={self.fps}, ts={self.ts:.0e}]\n"
            f" Gravity: {self.gravity},\n"
            f" Resolution: {self.resolution[0]}W x {self.resolution[1]}H\n"
            f" Bodies ({self.model.nbody}): {', '.join(_truncate(self.body_names))}\n"
            f" Joints ({self.model.njnt}): {', '.join(_truncate(self.joint_names))}\n"
            f" Actuators ({self.model.nu}): "
            f"{', '.join(_truncate(self.actuator_names))}\n"
            f" Controller: "
            f"{controller_name}\n"
            f")"
        )

    def __enter__(self) -> Self:  # noqa: D105
        return self

    def __exit__(self: Simulation, *args, **kwargs) -> None:
        """Clear the control callback and wait for viewers started by ``launch``.

        Only threads created by this instance are joined; joining arbitrary
        non-main threads could block forever on unrelated daemon threads.
        """
        mujoco.set_mjcb_control(None)
        current = threading.current_thread()
        for thread in getattr(self, "_threads", []):
            if thread is not current:
                thread.join()
        self._threads = []

    @property
    def model(self) -> mjModel:
        """Read-only property to access the MjModel object."""
        return self._model

    @property
    def data(self) -> mjData:
        """Read-only property to access the MjData single-step object.

        Use `captured_data` to access the entire simulation data.
        """
        return self._data

    @property
    def keyframe(self) -> int | None:
        """Keyframe index for the simulation (None disables keyframe resets)."""
        return self._keyframe

    @keyframe.setter
    def keyframe(self, value: int | None) -> None:
        if value is None:
            self._keyframe = None
            return
        if not isinstance(value, (int, np.integer)):
            msg = "Keyframe must be an integer."
            raise ValueError(msg)
        nkey = self._model.nkey
        # Valid indices are 0 .. nkey - 1.
        if not 0 <= value < nkey:
            if nkey == 0:
                msg = f"Model defines no keyframes. Got {value}."
            else:
                msg = f"Keyframe must be between 0 and {nkey - 1}. Got {value}."
            raise ValueError(msg)
        self._keyframe = int(value)

    @property
    def captured_data(self) -> dict[str, np.ndarray]:
        """Read-only property to access the entire captured simulation data.

        Raises:
            ValueError: If ``run`` has not captured any data yet.
        """
        if self._captured_data is None:
            msg = "No simulation data captured yet."
            raise ValueError(msg)
        return self._captured_data.unwrap()

    @captured_data.deleter
    def captured_data(self) -> None:
        self._captured_data = None

    @property
    def frames(self) -> np.ndarray:
        """Read-only property to access the frames captured by ``run(render=True)``.

        Raises:
            AttributeError: If no frames have been captured.
        """
        if getattr(self, "_frames", None) is None:
            msg = (
                "No frames captured yet. "
                "Run the simulation with render=True to capture frames."
            )
            raise AttributeError(msg) from None
        return self._frames

    @frames.deleter
    def frames(self) -> None:
        import gc

        captured = getattr(self, "_frames", None)
        if isinstance(captured, list):
            captured.clear()
        # Always drop the reference, whether frames are a list or an ndarray.
        self._frames = None
        gc.collect()

    @property
    def duration(self) -> float:
        """Duration of the simulation in seconds (must be > 0)."""
        return self._duration

    @duration.setter
    def duration(self, value: float) -> None:
        if value <= 0:
            msg = "Duration must be greater than zero."
            raise ValueError(msg)
        self._duration = value

    @property
    def fps(self) -> float:
        """Frames per second used for rendering and playback (must be > 0)."""
        return self._fps

    @fps.setter
    def fps(self, value: float) -> None:
        # fps is used as a divisor in run/show/f2t, so zero is rejected.
        if value <= 0:
            msg = "FPS must be greater than zero."
            raise ValueError(msg)
        self._fps = value

    @property
    def resolution(self) -> tuple[int, int]:
        """Resolution of the simulation in pixels (w,h), stored on the model."""
        return (
            self._model.vis.global_.offwidth,
            self._model.vis.global_.offheight,
        )

    @resolution.setter
    def resolution(self, values: tuple[int, int]) -> None:
        if not (isinstance(values, tuple) and len(values) == 2):
            msg = "Resolution must be a tuple of width and height."
            raise ValueError(msg)
        if any(v < 1 for v in values):
            msg = "Resolution must be at least 1x1 pixels."
            raise ValueError(msg)
        # The offscreen buffer size lives on the model's visual settings.
        self._model.vis.global_.offwidth = int(values[0])
        self._model.vis.global_.offheight = int(values[1])

    @property
    def initial_conditions(self) -> dict[str, list]:
        """Initial MjData attribute values, applied before every run."""
        return self.init_conditions

    @initial_conditions.setter
    def initial_conditions(self, values: dict[str, list]) -> None:
        if not isinstance(values, dict):
            msg = "Initial conditions must be a dictionary."
            raise TypeError(msg)
        valid_attrs = _SimulationData.get_public_keys(self._data)
        invalid_keys = [key for key in values if key not in valid_attrs]
        if invalid_keys:
            msg = (
                f"Invalid initial condition attributes: {', '.join(invalid_keys)}.\n"
                f"Valid attributes include: {', '.join(sorted(valid_attrs))}"
            )
            raise ValueError(msg)
        # Save and apply immediately; run() re-applies after resetting data.
        self.init_conditions = values
        self._apply_initial_conditions()

    @property
    def controller(self) -> Callable[[mjModel, mjData, Any], None] | None:
        """Controller function installed as the MuJoCo control callback."""
        return self._controller

    @controller.setter
    def controller(self, func: Callable[[mjModel, mjData, Any], None] | None) -> None:
        if func is not None and not callable(func):
            msg = "Controller must be a callable function."
            raise ValueError(msg)
        self._controller = func

    @property
    def ts(self) -> float:
        """Timestep of the simulation in seconds (model.opt.timestep)."""
        return self._model.opt.timestep

    @ts.setter
    def ts(self, value: float) -> None:
        if value <= 0:
            msg = "Timestep must be greater than 0."
            raise ValueError(msg)
        self._model.opt.timestep = value

    @property
    def data_rate(self) -> int:
        """Data capture rate, in captures per simulated second."""
        return self._dr

    @data_rate.setter
    def data_rate(self, value: int) -> None:
        if not isinstance(value, int):
            msg = "Data rate must be an integer."
            raise ValueError(msg)
        if value <= 0:
            msg = "Data rate must be greater than 0."
            raise ValueError(msg)
        # The finest possible capture is once per physics step, i.e. 1/ts per
        # second; run() clamps higher rates to every step.
        max_rate = max(1, round(1.0 / self.ts))
        if value > max_rate:
            warnings.warn(
                f"{value} exceeds the maximum data rate of {max_rate}; "
                "data will be captured every step.",
                stacklevel=2,
            )
        self._dr = value

    @property
    def gravity(self) -> np.ndarray:
        """Gravity vector of the simulation (model.opt.gravity)."""
        return self._model.opt.gravity  # pylint: disable=E1101

    @gravity.setter
    def gravity(self, values: list | tuple | np.ndarray) -> None:
        if not isinstance(values, (list, tuple, np.ndarray)) or len(values) != 3:
            msg = "Gravity must be a 3D vector."
            raise ValueError(msg)
        self._model.opt.gravity = np.array(values)

    def run(
        self,
        *,
        render: bool = False,
        camera: str | None = None,
        interactive: bool = False,
        show_menu: bool = True,  # TODO@MGross21: Implement this with launch
        multi_thread: bool = False,
    ) -> Simulation:
        """Run the simulation with optional rendering.

        The data is reset (to ``keyframe`` when set), the stored initial
        conditions are re-applied, and the model is stepped
        ``int(duration / ts)`` times. MjData attributes are captured every
        ``int(1 / (data_rate * ts))`` steps (at least every step) and, when
        rendering, a frame every ``int(1 / (fps * ts))`` steps, up to
        ``int(duration * fps)`` frames.

        Args:
            render (bool): If True, renders frames into ``self.frames``.
            camera (str | None): Camera to render from; None passes -1 to
                ``update_scene`` (MuJoCo's default free camera).
            interactive (bool): Not yet implemented; must be False.
            show_menu (bool): Reserved for interactive mode; currently unused.
            multi_thread (bool): Not yet implemented; must be False.

        Returns:
            Simulation: The current Simulation object for method chaining.

        Raises:
            NotImplementedError: If ``interactive`` or ``multi_thread`` is set.
            RuntimeError: If anything fails while resetting or stepping.
        """
        # TODO: Integrate interactive mujoco.viewer into this method
        if interactive:
            msg = "Interactive mode (w/ menu option) is not yet implemented."
            raise NotImplementedError(msg)
        if multi_thread:
            msg = "Multi-threading is not yet implemented."
            raise NotImplementedError(msg)

        # Bound before the try so the finally clause never sees unbound names
        # when an early step (e.g. an invalid keyframe reset) fails.
        sim_data = _SimulationData()
        frames: np.ndarray | None = None
        frame_count = 0
        try:
            m, d = self._model, self._data
            mujoco.mj_resetData(m, d)
            if self._controller is not None:
                mujoco.set_mjcb_control(self._controller)
            if self._keyframe is not None:
                mujoco.mj_resetDataKeyframe(m, d, self._keyframe)
            # mj_resetData discards what the initial_conditions setter wrote,
            # so re-apply it on top of the reset (or keyframe) state.
            self._apply_initial_conditions()

            from . import (  # pylint: disable=E0405
                MAX_GEOM_SCALAR,
                PROGRESS_BAR_ENABLED,
            )

            # Cache frequently used functions for the hot loop
            mj_step1 = mujoco.mj_step1
            mj_step2 = mujoco.mj_step2

            total_steps = int(self._duration / self.ts)
            # Physics steps between captures; never less than one.
            capture_interval = max(1, int(1.0 / (self._dr * self.ts)))

            render_interval = 1
            max_frames = 0
            if render:
                w, h = self.resolution
                render_interval = max(1, int(1.0 / (self._fps * self.ts)))
                max_frames = int(self._duration * self._fps)
                frames = np.zeros((max_frames, h, w, 3), dtype=np.uint8)
                renderer_ctx = mujoco.Renderer(m, h, w, m.ngeom * MAX_GEOM_SCALAR)
            else:
                renderer_ctx = nullcontext()

            progress_ctx = (
                tqdm(
                    total=total_steps,
                    desc="Simulation",
                    unit=" steps",
                    leave=False,
                )
                if PROGRESS_BAR_ENABLED
                else nullcontext()
            )

            with renderer_ctx as renderer, progress_ctx as pbar:
                for step in range(total_steps):
                    mj_step1(m, d)
                    if step % capture_interval == 0:
                        sim_data.capture(d)
                    if (
                        renderer is not None
                        and step % render_interval == 0
                        and frame_count < max_frames
                    ):
                        renderer.update_scene(d, camera if camera else -1)
                        # Copied into the preallocated frame buffer.
                        frames[frame_count] = renderer.render()
                        frame_count += 1
                    mj_step2(m, d)
                    if pbar is not None:
                        pbar.update(1)
        except Exception as e:
            msg = "An error occurred while running the simulation."
            raise RuntimeError(msg) from e
        finally:
            mujoco.set_mjcb_control(None)
            self._captured_data = sim_data
            self._frames = frames[:frame_count] if frames is not None else None
        return self

    def _window(self, show_menu: bool = True) -> None:  # noqa: FBT001, FBT002
        """Open a passive viewer and step the simulation in real time.

        Each loop iteration advances one physics step and then sleeps for the
        remainder of ``model.opt.timestep``. Runs until the window is closed,
        a quit key is pressed, or a KeyboardInterrupt is received.

        Raises:
            RuntimeError: If the viewer or stepping fails.
        """
        quit_requested = threading.Event()

        def key_callback(key: int) -> None:
            # The viewer ignores the return value, so flag the request and let
            # the loop below shut down.
            if key in self._QUIT_KEYS:
                quit_requested.set()

        try:
            m, d = self._model, self._data
            viewer_handle = mujoco.viewer.launch_passive(  # pylint: disable=C0103
                m,
                d,
                show_left_ui=show_menu,
                show_right_ui=show_menu,
                key_callback=key_callback,
            )
            with viewer_handle as viewer:
                viewer.sync()
                try:
                    while viewer.is_running() and not quit_requested.is_set():
                        step_start = time.perf_counter()
                        mujoco.mj_step(m, d)
                        viewer.sync()
                        # Keep wall-clock time in step with simulated time.
                        remaining = m.opt.timestep - (time.perf_counter() - step_start)
                        if remaining > 0:
                            time.sleep(remaining)
                except KeyboardInterrupt:
                    viewer.close()
        except Exception as e:
            msg = "An error occurred while running the simulation."
            raise RuntimeError(msg) from e
        finally:
            mujoco.set_mjcb_control(None)

    def launch(self, show_menu: bool = True) -> None:  # noqa: FBT001, FBT002
        """Open a window to display the simulation in real time.

        The viewer runs in a background thread that is joined on ``__exit__``
        when the Simulation is used as a context manager.
        """
        gui = threading.Thread(
            target=self._window,
            kwargs={"show_menu": show_menu},
        )
        if not hasattr(self, "_threads"):
            self._threads = []
        self._threads.append(gui)
        gui.start()

    def _get_index(  # noqa: C901
        self,
        frame_idx: int | tuple[int, int] | None = None,
        time_idx: float | tuple[float, float] | None = None,
    ) -> np.ndarray:
        """Validate and extract frames based on frame or time indices.

        Args:
            frame_idx (int or tuple, optional): Single frame index or
                (start, stop) frame indices; stop is exclusive.
            time_idx (float or tuple, optional): Single time or
                (start, end) times in seconds, converted with ``t2f``.

        Returns:
            The selected frames (all frames when neither index is given).

        Raises:
            ValueError: For missing frames or invalid input parameters.
        """
        if self._frames is None or len(self._frames) == 0:
            msg = "No frames captured to render."
            raise ValueError(msg)
        if frame_idx is not None and time_idx is not None:
            msg = "Can only specify either frame_idx or time_idx, not both."
            raise ValueError(msg)
        if frame_idx is None and time_idx is None:
            return self._frames

        if time_idx is not None:
            if isinstance(time_idx, (int, float)):
                frame_idx = self.t2f(time_idx)
            elif isinstance(time_idx, tuple):
                frame_idx = (self.t2f(time_idx[0]), self.t2f(time_idx[1]))
            else:
                msg = "time_idx must be a number or a tuple of numbers."
                raise ValueError(msg)

        # A single index selects exactly one frame.
        if isinstance(frame_idx, (int, float)):
            frame_idx = (frame_idx, frame_idx + 1)

        # Slicing requires integers; float indices are truncated.
        start, stop = (int(i) for i in frame_idx)
        max_frames = len(self._frames)
        if start < 0:
            msg = f"Start index must be non-negative. Got {start}."
            raise ValueError(msg)
        if stop > max_frames:
            msg = (
                f"Stop index must not exceed total frames ({max_frames}). "
                f"Got {stop}."
            )
            raise ValueError(msg)
        if start >= stop:
            msg = f"Start index ({start}) must be less than stop index ({stop})."
            raise ValueError(msg)
        return self._frames[start:stop]

    @staticmethod
    def _is_jupyter() -> bool:
        """Best-effort check for a running Jupyter/IPython kernel."""
        try:
            from IPython import get_ipython

            return "ipykernel" in sys.modules or "IPKernelApp" in get_ipython().config
        except Exception:  # noqa: BLE001
            return False

    def show(
        self,
        title: str | None = None,
        *,
        frame_idx: int | tuple[int, int] | None = None,
        time_idx: float | tuple[float, float] | None = None,
    ) -> HTML | None:
        """Render specific frame(s) as an animation.

        In Jupyter an ``HTML`` animation is returned; otherwise the frames
        are played in an interactive matplotlib window and None is returned.

        Args:
            title (str, optional): Title for the rendered media.
            frame_idx (int or tuple, optional): Single frame index or
                (start, stop) frame indices.
            time_idx (float or tuple, optional): Single time or
                (start, end) times in seconds.

        Returns:
            HTML | None: The animation in Jupyter, else None.

        Raises:
            ValueError: If no frames are captured or invalid input parameters.
            Exception: If plotting fails.
        """
        if getattr(self, "_frames", None) is None or len(self._frames) == 0:
            msg = "No frames captured to render. Re-run the simulation with render=True."
            raise ValueError(msg)

        # Outside the try so invalid indices surface as ValueError.
        subset_frames = self._get_index(frame_idx=frame_idx, time_idx=time_idx)
        try:
            fig, ax = plt.subplots()
            # Size the placeholder from the frames themselves, which may differ
            # from the current resolution if it changed after run().
            im = ax.imshow(np.zeros_like(subset_frames[0]), interpolation="nearest")
            ax.set_axis_off()
            ax.set_title(title)
            plt.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)

            if self._is_jupyter():
                ani = animation.FuncAnimation(
                    fig,
                    lambda frame: (im.set_data(frame), [im])[1],
                    frames=subset_frames,
                    interval=(1000 / self._fps),
                    blit=True,
                )
                plt.close(fig)
                return HTML(ani.to_jshtml())

            plt.ion()
            delay = 1.0 / self._fps
            for frame in subset_frames:
                im.set_data(frame)
                plt.pause(delay)
            plt.ioff()
            plt.close(fig)
        except Exception as e:
            msg = "Error while showing video subset."
            raise Exception(msg) from e  # noqa: TRY002
        return None

    def save(
        self,
        title: str = "output.gif",
        *,
        frame_idx: int | tuple[int, int] | None = None,
        time_idx: float | tuple[float, float] | None = None,
    ) -> str:
        """Save specific frame(s) as a video or GIF to a file.

        Args:
            title (str, optional): Filename for the saved media. Filename
                should end with the desired codec extension (e.g., .mp4, .gif)
            frame_idx (int or tuple, optional): Single frame index or
                (start, stop) frame indices.
            time_idx (float or tuple, optional): Single time or
                (start, end) times in seconds.

        Returns:
            str: Absolute path to the saved file.

        Raises:
            ValueError: If no frames are captured or invalid input parameters.
            RuntimeError: Propagated unchanged from the writer.
            Exception: For any other failure while writing.
        """
        if getattr(self, "_frames", None) is None or len(self._frames) == 0:
            msg = "No frames captured to render. Re-run the simulation with render=True."
            raise ValueError(msg)

        subset_frames = self._get_index(frame_idx=frame_idx, time_idx=time_idx)
        title_path = Path(title)
        try:
            iio.imwrite(
                title,
                subset_frames,
                # A single still frame is written at 1 fps.
                fps=self._fps if len(subset_frames) != 1 else 1,
            )
            return str(title_path.resolve())
        except RuntimeError:
            raise
        except Exception as e:
            msg = "Error while saving video subset."
            raise Exception(msg) from e  # noqa: TRY002

    def t2f(self, t: float) -> int:
        """Convert a time in seconds to a 0-based frame index.

        The result is capped at the last nominal frame, ``int(duration * fps) - 1``.
        """
        return min(
            int(t * self._fps),
            int(self._duration * self._fps) - 1,
        )

    def f2t(self, frame: int) -> float:
        """Convert a frame index to a time in seconds."""
        return frame / self._fps

    def body_data(
        self, body_name: str, data_name: str | None = None,
    ) -> dict[str, np.ndarray] | np.ndarray:
        """Get captured per-body data for a specific body.

        Only MjData fields that MuJoCo exposes per body (the public
        attributes of ``MjData.body(...)``) and that were captured as arrays
        of shape (steps, nbody, ...) are considered.

        Args:
            body_name (str): The name of the body to retrieve data for.
            data_name (str | None): A single field to retrieve (e.g. "xpos").
                If None, every available per-body field is returned.

        Returns:
            dict[str, np.ndarray] | np.ndarray: A mapping of field name to the
            body's time series, or the single requested time series.

        Raises:
            ValueError: If the body is unknown, nothing was captured, or the
                field is not available for the body.
        """
        if body_name not in self.body_names:
            msg = f"Body '{body_name}' not found in the model."
            raise ValueError(msg)
        body_id = self._model.body(body_name).id
        if self._captured_data is None:
            msg = "No simulation data captured yet."
            raise ValueError(msg)

        unwrapped_data = self._captured_data.unwrap()
        # Field names MuJoCo indexes by body; "id"/"name" are view metadata.
        body_fields = _SimulationData.get_public_keys(self._data.body(body_id)) - {
            "id",
            "name",
        }
        per_body: dict[str, np.ndarray] = {}
        for field in sorted(body_fields):
            series = unwrapped_data.get(field)
            # Captured body fields are stacked as (steps, nbody, ...).
            if isinstance(series, np.ndarray) and series.ndim >= 2 and series.shape[1] > body_id:
                per_body[field] = series[:, body_id]

        if data_name is None:
            return per_body
        if data_name not in per_body:
            msg = f"Data '{data_name}' not found for body '{body_name}'."
            raise ValueError(msg)
        return per_body[data_name]

    def name2id(self, name: str) -> int:
        """Get the ID of an object given its name.

        Object types are probed in the order of ``_MJ_OBJ_TYPES`` (bodies
        first), so a name shared by objects of different types resolves to
        the earliest type in that list.

        Args:
            name (str): The name of the object.

        Returns:
            int: The ID of the first matching object.

        Raises:
            ValueError: If no object of any probed type has that name.
        """
        for obj_type in _MJ_OBJ_TYPES:
            try:
                obj_id = mujoco.mj_name2id(self._model, obj_type, name)
            except Exception:  # noqa: BLE001, S112 - probe the next type
                continue
            if obj_id >= 0:
                return obj_id
        msg = f"Object with name '{name}' not found."
        raise ValueError(msg)

    def id2name(
        self,
        id: int,  # noqa: A002 - public parameter name kept for compatibility
        obj_type: mujoco.mjtObj = mujoco.mjtObj.mjOBJ_BODY,
    ) -> str:
        """Get the name of an object given its ID.

        IDs are only unique within one object type, so the type must be
        known; it defaults to bodies.

        Args:
            id (int): The ID of the object.
            obj_type (mujoco.mjtObj, optional): Object type the ID refers to.
                Defaults to ``mjOBJ_BODY``.

        Returns:
            str: The name of the object.

        Raises:
            ValueError: If MuJoCo returns no name for that ID and type.
        """
        obj_name = mujoco.mj_id2name(self._model, obj_type, id)
        if obj_name is None:
            msg = f"ID '{id}' not found."
            raise ValueError(msg)
        return obj_name

    @staticmethod
    def _to_builtin(value: Any) -> Any:
        """Recursively convert NumPy arrays/scalars into plain Python values."""
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, (list, tuple)):
            return [Simulation._to_builtin(v) for v in value]
        return value

    def to_yaml(self, name: str = "Model") -> None:
        """Save captured simulation data to a YAML file.

        Args:
            name (str): The filename; ".yml" is appended unless it already
                ends in ".yml" or ".yaml".

        Raises:
            ValueError: If there is no captured data or writing fails.
        """
        if not name.endswith((".yml", ".yaml")):
            name += ".yml"
        try:
            # NumPy values (including those inside list fallbacks) become
            # plain lists/scalars so the YAML holds no Python object tags.
            serialized_data = {
                k: self._to_builtin(v) for k, v in self.captured_data.items()
            }
            with Path(name).open("w", encoding="utf-8") as f:
                yaml.dump(serialized_data, f, default_flow_style=False)
        except Exception as e:
            msg = f"Failed to save simulation data to '{name}'"
            raise ValueError(msg) from e
class _SimulationData:
"""A class to store and manage simulation data."""
__slots__ = ["_d"]
def __init__(self) -> None:
self._d: dict[str, list] = defaultdict(list)
def _is_capture_all(self, params) -> bool:
"""Check if all data is captured."""
if params is all:
return True
if isinstance(params, set):
return ("all" in map(str.lower, params))
if isinstance(params, str):
return params.lower() == "all"
return None
def capture(self, mj_data) -> None:
"""Capture data from MjData, storing specified or all public attributes."""
from . import CAPTURE_PARAMETERS
if (self._is_capture_all(CAPTURE_PARAMETERS)):
keys = self.get_public_keys(mj_data) # TODO: Fix this to be more efficient. Is cycling on every sim step.
else:
keys = CAPTURE_PARAMETERS
for key in keys:
value = getattr(mj_data, key, None)
if value is None:
continue
if isinstance(value, np.ndarray):
self._d[key].append(value.copy())
elif np.isscalar(value):
self._d[key].append(value)
elif hasattr(value, "copy") and callable(value.copy):
self._d[key].append(value.copy())
else:
self._d[key].append(value)
def unwrap(self) -> dict[str, np.ndarray]:
"""Unwrap simulation data into a structured format with NumPy arrays.
Returns:
dict[str, np.ndarray]: Unwrapped data for each key.
"""
unwrapped_data = {}
for key, value_list in self._d.items():
if not value_list:
unwrapped_data[key] = np.array([])
continue
first = value_list[0]
try:
if isinstance(first, np.ndarray):
shape = first.shape
if all(v.shape == shape for v in value_list):
unwrapped_data[key] = np.stack(value_list)
else:
unwrapped_data[key] = value_list # Inconsistent shapes
else:
unwrapped_data[key] = np.array(value_list)
except (ValueError, TypeError):
unwrapped_data[key] = value_list # Fallback
return unwrapped_data
@property
def shape(self) -> dict[str, tuple]:
"""Return the shape of the captured data per key."""
if not self._d:
return {}
shapes: dict[str, tuple[Any, ...]] = {}
for key, value_list in self._d.items():
if not value_list:
shapes[key] = ()
continue
first_value = value_list[0]
if isinstance(first_value, np.ndarray):
shapes[key] = (len(value_list), *first_value.shape)
elif isinstance(first_value, list) and all(isinstance(v, list) for v in value_list):
shapes[key] = (len(value_list), len(first_value))
else:
shapes[key] = (len(value_list),)
return shapes
def clear(self) -> None:
"""Clear all captured data."""
self._d.clear()
def keys(self) -> set[str]:
"""Return a set of all captured data keys."""
return set(self._d.keys())
def items(self) -> dict[str, list]:
"""Return raw captured data as a dict of lists."""
return dict(self._d)
def __len__(self) -> int:
"""Return the number of captured steps (based on first key)."""
if not self._d:
return 0
first_key = next(iter(self._d))
return len(self._d[first_key])
def __str__(self) -> str:
return f"{self.__class__.__name__}({len(self)} Step(s) Captured)"
def __repr__(self) -> str:
return self.__str__()
def __del__(self) -> None:
"""Safely clean up resources during object deletion."""
if hasattr(self, "_d"):
self._d.clear()
@staticmethod
def get_public_keys(obj: object) -> set[str]:
"""Get all public (non-callable) attributes of an object."""
return {
name
for name in dir(obj)
if not name.startswith("_") and not callable(getattr(obj, name))
}
class Wrapper(Simulation):
    """Deprecated alias of ``Simulation``.

    Construction emits a ``DeprecationWarning`` and is refused with a
    ``RuntimeError`` once the package version string compares >= "1.0.0".
    """

    def __init__(self, *args, **kwargs) -> None:
        from . import __version__

        already_removed = __version__ >= "1.0.0"
        if already_removed:
            raise RuntimeError("Wrapper was removed in v1.0.0. Use Simulation instead.")

        notice = (
            f"Wrapper is deprecated and will be removed in v1.0.0 "
            f"(Current: {__version__}). Use Simulation instead."
        )
        warnings.warn(notice, DeprecationWarning, stacklevel=2)
        super().__init__(*args, **kwargs)