# Source code for mujoco_toolbox.wrapper

"""Wrapper module for managing MuJoCo simulations.

This module provides a `Wrapper` class to handle MuJoCo simulations, including
loading models, running simulations, capturing data, and rendering frames.
"""

import os
import sys
import threading
import time
from collections import defaultdict
from collections.abc import Callable
from contextlib import nullcontext
from multiprocessing import cpu_count
from pathlib import Path
from typing import Any, TypeAlias

import cv2
import defusedxml.ElementTree as ET
import mediapy as media
import mujoco
import mujoco.viewer
import numpy as np
import yaml
from IPython.display import clear_output
from tqdm.auto import tqdm

from .builder import Builder
from .loader import Loader

# Short aliases for the MuJoCo model/data structs, used in annotations below.
mjModel: TypeAlias = mujoco.MjModel  # pylint: disable=E1101  # noqa: N816
mjData: TypeAlias = mujoco.MjData  # pylint: disable=E1101  # noqa: N816

# pylint: disable=E1101
# Object categories searched, in this order, by the name/id lookup helpers.
# Built from the ``mujoco.mjtObj.mjOBJ_<SUFFIX>`` enum members.
mujoco_object_types = [
    getattr(mujoco.mjtObj, f"mjOBJ_{suffix}")
    for suffix in (
        "BODY",
        "JOINT",
        "GEOM",
        "SITE",
        "CAMERA",
        "LIGHT",
        "MESH",
        "HFIELD",
        "TEXTURE",
        "MATERIAL",
        "PAIR",
        "EXCLUDE",
        "EQUALITY",
        "TENDON",
        "ACTUATOR",
        "SENSOR",
        "NUMERIC",
        "TEXT",
        "KEY",
        "PLUGIN",
    )
]

# When True, constructing a Wrapper clears the console/notebook output and
# ``Wrapper.run`` shows a tqdm progress bar.
PROGRESS_BAR_ENABLED = True


class Wrapper:
    """Wrapper class for managing MuJoCo simulations."""

    # MjData fields whose rows are indexed by body id. A captured stack of one
    # of these has shape (steps, nbody, ...). Used by `body_data` when no
    # specific field is requested.
    _BODY_FIELDS = frozenset(
        {
            "xfrc_applied",
            "xpos",
            "xquat",
            "xmat",
            "xipos",
            "ximat",
            "subtree_com",
            "cinert",
            "crb",
            "cvel",
            "subtree_linvel",
            "subtree_angmom",
            "cacc",
            "cfrc_int",
            "cfrc_ext",
        },
    )

    def __new__(cls, *args: Any, **kwargs: Any) -> "Wrapper":
        """Create the instance, clearing console/notebook output first.

        The output is cleared only when PROGRESS_BAR_ENABLED is set.
        """
        if PROGRESS_BAR_ENABLED:
            os.system("clear || cls")  # Clear the console
            clear_output(wait=True)
        return super().__new__(cls)  # pylint: disable=E1101

    def __init__(
        self,
        *xml_args: str | os.PathLike | Builder,
        duration: int = 10,
        data_rate: int = 100,
        fps: int = 30,
        resolution: tuple[int, int] | None = None,
        initial_conditions: dict[str, list] | None = None,
        keyframe: int | None = None,
        controller: Callable[[mjModel, mjData, Any], None] | None = None,
        meshdir: str = "meshes/",
        **kwargs: Any,
    ) -> None:
        """Initialize the Wrapper class for managing MuJoCo simulations.

        Args:
            xml_args (str | os.PathLike | Builder): One or more XML file paths,
                XML strings, or Builder objects defining the model. Builders
                are summed; strings/paths are merged in as one extra Builder.
            duration (int, optional): Duration of the simulation in seconds.
                Defaults to 10.
            data_rate (int, optional): Data capture rate in frames per second.
                Defaults to 100.
            fps (int, optional): Frames per second for rendering.
                Defaults to 30.
            resolution (tuple[int, int] | None, optional): Resolution of the
                simulation in pixels (width, height). If None, defaults to
                values from the XML or (400, 300).
            initial_conditions (dict[str, list] | None, optional): MjData
                attribute values applied before each run.
            keyframe (int | None, optional): Keyframe index for resetting the
                simulation.
            controller (Callable[[mjModel, mjData, Any], None] | None,
                optional): Custom controller function for the simulation.
            meshdir (str, optional): Directory containing mesh files for URDF
                models. Defaults to "meshes/".
            **kwargs: Optional ``ts`` (timestep) and ``gravity`` overrides.

        Raises:
            ValueError: If no XML arguments are provided or a parameter is
                invalid.
        """
        if not xml_args:
            msg = "At least one XML file, string, or Builder is required."
            raise ValueError(msg)

        # Viewer threads started by `launch`; `__exit__` joins only these.
        self._threads: list[threading.Thread] = []

        # Separate Builders from XML strings/paths.
        builders = [arg for arg in xml_args if isinstance(arg, Builder)]
        strings = [
            str(arg)
            for arg in xml_args
            if isinstance(arg, (str, os.PathLike))
        ]

        if builders:
            builder = sum(builders[1:], builders[0])  # sum Builders
            if strings:
                builder += Builder(*strings)  # merge str as Builder Instances
        else:
            builder = Builder(*strings)

        self._builder = builder

        # Load the model
        self._meshdir = meshdir
        loader = Loader(builder, meshdir)
        self.xml = loader.xml
        self._model = loader.model

        # Predefined simulation parameters but can be overridden
        # TODO(#8): @MGross21 Currently Causing Bugs when occluded from XML
        # Applied before `data_rate`, whose validation depends on the timestep.
        self.ts = kwargs.get("ts", self._model.opt.timestep)
        self.gravity = kwargs.get("gravity", self._model.opt.gravity)

        # Simulation Parameters
        self.duration = duration
        self.fps = fps
        self.data_rate = data_rate
        self.controller = controller
        self.resolution = resolution or self._extract_resolution()

        self._data = mujoco.MjData(self._model)
        self._keyframe = keyframe
        self.initial_conditions = initial_conditions or {}  # **after data**

        # Populated by `run`.
        self._frames: np.ndarray | None = None
        self._captured_data: _SimulationData | None = None

        self._initialize_names()

    def _initialize_names(self) -> None:
        """Cache body, geom, joint, and actuator names from the model."""
        self.body_names = [
            self._model.body(i).name for i in range(self._model.nbody)
        ]
        self.geom_names = [
            self._model.geom(i).name for i in range(self._model.ngeom)
        ]
        self.joint_names = [
            self._model.joint(i).name for i in range(self._model.njnt)
        ]
        self.actuator_names = [
            self._model.actuator(i).name for i in range(self._model.nu)
        ]

    def _extract_resolution(self) -> tuple[int, int]:
        """Read ``visual/global`` offwidth/offheight from the XML.

        Returns:
            tuple[int, int]: (width, height). Missing attributes default to 400
            and 300. (400, 300) is returned if the tag is absent or the XML or
            attribute values cannot be parsed.
        """
        try:
            root = ET.fromstring(self.xml)
            global_tag = root.find("visual/global")
            if global_tag is not None:
                offwidth = int(global_tag.get("offwidth", 400))
                offheight = int(global_tag.get("offheight", 300))
                return (offwidth, offheight)
        except (ET.ParseError, ValueError, TypeError):
            pass
        return (400, 300)

    def reload(self: "Wrapper") -> "Wrapper":
        """Reload the model and data objects from the stored XML.

        Rebuilds the model and data, refreshes the cached names, and
        re-applies the stored initial conditions.

        Returns:
            Wrapper: Self for method chaining.
        """
        # Use the Loader to handle model reloading
        loader = Loader(self.xml, meshdir=self._meshdir)
        self._model = loader.model
        self._data = mujoco.MjData(self._model)
        self._initialize_names()

        # Apply initial conditions to the fresh MjData.
        for key, value in getattr(self, "init_conditions", {}).items():
            if hasattr(self._data, key):
                setattr(self._data, key, value)
        return self

    def __str__(self) -> str:  # noqa: D105
        return self._model.__str__()

    def __repr__(self) -> str:  # noqa: D105
        MAX_LINE_ITEMS = 5  # noqa: N806 # pylint: disable=C0103

        def _clip(names: list[str]) -> list[str]:
            # Show at most MAX_LINE_ITEMS names, plus "..." when truncated.
            extra = ["..."] if len(names) > MAX_LINE_ITEMS else []
            return names[:MAX_LINE_ITEMS] + extra

        body_names = _clip(self.body_names)
        joint_names = _clip(self.joint_names)
        actuator_names = _clip(self.actuator_names)

        return (
            f"{self.__class__.__name__}(\n"
            f"  Duration: {self.duration}s "
            f"[fps={self.fps}, ts={self.ts:.0e}]\n"
            f"  Gravity: {self.gravity},\n"
            f"  Resolution: {self.resolution[0]}W x {self.resolution[1]}H\n"
            f"  Bodies ({self.model.nbody}): {', '.join(body_names)}\n"
            f"  Joints ({self.model.njnt}): {', '.join(joint_names)}\n"
            f"  Actuators ({self.model.nu}): {', '.join(actuator_names)}\n"
            f"  Controller: "
            f"{self.controller.__name__ if self.controller else None}\n"
            f")"
        )

    def __enter__(self) -> "Wrapper":  # noqa: D105
        return self

    def __exit__(self: "Wrapper", *args, **kwargs) -> None:
        """Clear the global control callback and wait for viewer threads.

        Only threads started by `launch` are joined. Joining every thread in
        the process would block forever on unrelated long-lived threads owned
        by other libraries.
        """
        mujoco.set_mjcb_control(None)
        current = threading.current_thread()
        for thread in getattr(self, "_threads", []):
            if thread is not current:
                thread.join()
        self._threads = []

    @property
    def model(self) -> mjModel:
        """Read-only property to access the MjModel object."""
        return self._model

    @property
    def data(self) -> mjData:
        """Read-only property to access the MjData single-step object.

        Use `captured_data` to access the entire simulation data.
        """
        return self._data

    @property
    def keyframe(self) -> int | None:
        """Keyframe index for the simulation."""
        return self._keyframe

    @keyframe.setter
    def keyframe(self, value: int | None) -> None:
        if value is not None and not isinstance(value, int):
            msg = "Keyframe must be an integer."
            raise ValueError(msg)
        self._keyframe = value

    @property
    def captured_data(self) -> dict[str, np.ndarray]:
        """Read-only property to access the entire captured simulation data."""
        if self._captured_data is None:
            msg = "No simulation data captured yet."
            raise ValueError(msg)
        return self._captured_data.unwrap()

    @captured_data.deleter
    def captured_data(self) -> None:
        self._captured_data = None

    @property
    def frames(self) -> np.ndarray:
        """Read-only property to access the captured frames."""
        if not hasattr(self, "_frames") or self._frames is None:
            msg = (
                "No frames captured yet. "
                "Run the simulation with render=True to capture frames."
            )
            raise AttributeError(msg) from None
        return self._frames

    @frames.deleter
    def frames(self) -> None:
        # Safely delete frames if they exist
        if hasattr(self, "_frames") and isinstance(self._frames, list):
            self._frames.clear()
        self._frames = None
        import gc

        gc.collect()

    @property
    def duration(self) -> float:
        """Duration of the simulation in seconds."""
        return self._duration

    @duration.setter
    def duration(self, value: float) -> None:
        # Zero is rejected too, matching the error message.
        if value <= 0:
            msg = "Duration must be greater than zero."
            raise ValueError(msg)
        self._duration = value

    @property
    def fps(self) -> float:
        """Frames per second."""
        return self._fps

    @fps.setter
    def fps(self, value: float) -> None:
        # fps is used as a divisor when rendering, showing and converting
        # frames <-> time, so zero must be rejected.
        if value <= 0:
            msg = "FPS must be greater than zero."
            raise ValueError(msg)
        self._fps = value

    @property
    def resolution(self) -> tuple[int, int]:
        """Resolution of the simulation in pixels (w,h)."""
        return (
            self._model.vis.global_.offwidth,
            self._model.vis.global_.offheight,
        )

    @resolution.setter
    def resolution(self, values: tuple[int, int]) -> None:
        if not (isinstance(values, tuple) and len(values) == 2):
            msg = "Resolution must be a tuple of width and height."
            raise ValueError(msg)
        if any(v < 1 for v in values):
            msg = "Resolution must be at least 1x1 pixels."
            raise ValueError(msg)
        # Stored on the model's offscreen buffer size used by the renderer.
        self._model.vis.global_.offwidth = int(values[0])
        self._model.vis.global_.offheight = int(values[1])

    @property
    def initial_conditions(self) -> dict[str, list]:
        """Initial conditions for the simulation."""
        return self.init_conditions

    @initial_conditions.setter
    def initial_conditions(self, values: dict[str, list]) -> None:
        if not isinstance(values, dict):
            msg = "Initial conditions must be a dictionary."
            raise TypeError(msg)

        data = self._data
        valid_attrs = _SimulationData.get_public_keys(data)

        invalid_keys = [key for key in values if key not in valid_attrs]
        if invalid_keys:
            msg = (
                f"Invalid initial condition attributes: {', '.join(invalid_keys)}.\n"
                f"Valid attributes include: {', '.join(valid_attrs)}"
            )
            raise ValueError(msg)

        # Save (re-applied by `run`/`reload`) and apply now.
        self.init_conditions = values
        for k, v in values.items():
            setattr(data, k, v)

    @property
    def controller(self) -> Callable[[mjModel, mjData, Any], None] | None:
        """Controller Function."""
        return self._controller

    @controller.setter
    def controller(
        self,
        func: Callable[[mjModel, mjData, Any], None] | None,
    ) -> None:
        if func is not None and not callable(func):
            msg = "Controller must be a callable function."
            raise ValueError(msg)
        self._controller = func

    @property
    def ts(self) -> float:
        """Timestep of the simulation in seconds."""
        return self._model.opt.timestep

    @ts.setter
    def ts(self, value: float) -> None:
        if value <= 0:
            msg = "Timestep must be greater than 0."
            raise ValueError(msg)
        self._model.opt.timestep = value

    @property
    def data_rate(self) -> int:
        """Data rate of the simulation in frames per second."""
        return self._dr

    @data_rate.setter
    def data_rate(self, value: int) -> None:
        if not isinstance(value, int):
            msg = "Data rate must be an integer."
            raise ValueError(msg)
        if value <= 0:
            msg = "Data rate must be greater than 0."
            raise ValueError(msg)
        # NOTE(review): this bound is the total step count (duration / ts),
        # not the per-second step rate 1 / ts; kept for compatibility.
        max_rate = int(self._duration / self.ts)
        if value > max_rate:
            msg = f"{value} exceeds the maximum data rate of {max_rate}."
            raise ValueError(msg)
        self._dr = value

    @property
    def gravity(self) -> np.ndarray:
        """Gravity vector of the simulation."""
        return self._model.opt.gravity  # pylint: disable=E1101

    @gravity.setter
    def gravity(self, values: list | tuple | np.ndarray) -> None:
        if (
            not isinstance(values, (list, tuple, np.ndarray))
            or len(values) != 3
        ):
            msg = "Gravity must be a 3D vector."
            raise ValueError(msg)
        self._model.opt.gravity = np.array(values)

    def run(
        self,
        *,
        render: bool = False,
        camera: str | None = None,
        interactive: bool = False,
        show_menu: bool = True,  # TODO@MGross21: Implement this with launch
        multi_thread: bool = False,
    ) -> "Wrapper":
        """Run the simulation, capturing data and optionally rendering frames.

        Before stepping, the data is reset (to ``keyframe`` if set) and the
        stored initial conditions are re-applied. The model is then stepped
        ``int(duration / ts)`` times. Data is captured every
        ``max(1, int(1 / (data_rate * ts)))`` steps. When rendering, frame k is
        taken at the first step at or after ``k / fps`` seconds, so the clip
        spans the whole duration.

        Args:
            render (bool): If True, renders frames (see `frames`).
            camera (str): Camera name to render from. When falsy, camera id
                -1 is passed to the renderer.
            interactive (bool): Not implemented yet; must be False.
            show_menu (bool): Reserved for the interactive viewer.
            multi_thread (bool): Not implemented yet; must be False.

        Returns:
            self: The current Wrapper object for method chaining.

        Raises:
            NotImplementedError: If ``interactive`` or ``multi_thread`` is True.
            RuntimeError: If an error occurs while running the simulation.
        """
        if interactive:
            msg = "Interactive mode (w/ menu option) is not yet implemented."
            raise NotImplementedError(msg)
        if multi_thread:
            msg = "Multi-threading is not yet implemented."
            raise NotImplementedError(msg)

        # Bound before `try` so the `finally` clause can always store results,
        # even if the failure happens before stepping starts.
        sim_data = _SimulationData()
        frames: np.ndarray | None = None
        frame_count = 0

        try:
            m, d = self._model, self._data
            mujoco.mj_resetData(m, d)

            if self._controller is not None:
                mujoco.set_mjcb_control(self._controller)

            if self._keyframe is not None:
                mujoco.mj_resetDataKeyframe(m, d, self._keyframe)

            # mj_resetData discards whatever the initial_conditions setter
            # wrote into MjData, so apply them again on the reset state.
            for key, value in getattr(self, "init_conditions", {}).items():
                setattr(d, key, value)

            # Cache frequently used functions for the hot loop.
            mj_step1 = mujoco.mj_step1
            mj_step2 = mujoco.mj_step2

            total_steps = int(self._duration / self.ts)
            capture_interval = max(1, int(1.0 / (self._dr * self.ts)))

            max_frames = 0
            steps_per_frame = 1.0
            next_render_step = 0
            renderer_cm: Any = nullcontext()
            if render:
                w, h = self.resolution
                max_frames = int(self._duration * self._fps)
                frames = np.zeros((max_frames, h, w, 3), dtype=np.uint8)
                # Fractional step spacing; rounding per frame (instead of
                # truncating a fixed interval) keeps frames on the 1/fps grid.
                steps_per_frame = 1.0 / (self._fps * self.ts)

                from . import MAX_GEOM_SCALAR  # pylint: disable=E0405

                renderer_cm = mujoco.Renderer(
                    m, h, w, m.ngeom * MAX_GEOM_SCALAR,
                )

            pbar_cm: Any = (
                tqdm(total=total_steps, desc="Simulation", leave=False)
                if PROGRESS_BAR_ENABLED
                else nullcontext()
            )

            with renderer_cm as renderer, pbar_cm as pbar:
                for step in range(total_steps):
                    mj_step1(m, d)

                    if step % capture_interval == 0:
                        sim_data.capture(d)

                    if (
                        frames is not None
                        and frame_count < max_frames
                        and step >= next_render_step
                    ):
                        renderer.update_scene(d, camera if camera else -1)
                        frames[frame_count] = renderer.render()
                        frame_count += 1
                        next_render_step = round(frame_count * steps_per_frame)

                    mj_step2(m, d)
                    if pbar is not None:
                        pbar.update(1)

        except Exception as e:
            msg = "An error occurred while running the simulation."
            raise RuntimeError(msg) from e
        finally:
            mujoco.set_mjcb_control(None)
            self._captured_data = sim_data
            self._frames = frames[:frame_count] if frames is not None else None

        return self

    def _window(self, show_menu: bool = True) -> None:  # noqa: FBT001, FBT002
        """Step the simulation in a passive viewer window until it closes.

        One physics step is taken per displayed frame, and each frame is
        padded with sleep to roughly ``1 / fps`` seconds of wall-clock time.

        Raises:
            RuntimeError: If the viewer or stepping fails.
        """
        try:
            m = self._model
            d = self._data

            def key_callback(key: int) -> bool:
                # 27 = ESC key, 'q' to quit.
                # NOTE(review): confirm the viewer honors this return value.
                return key in (27, ord("q"))

            # NOTE: launch_passive may blocking
            _Viewer = mujoco.viewer.launch_passive(  # noqa: N806, pylint: disable=C0103
                m,
                d,
                show_left_ui=show_menu,
                show_right_ui=show_menu,
                key_callback=key_callback,
            )
            with _Viewer as viewer:
                viewer.sync()
                try:
                    while viewer.is_running():
                        frame_start = time.perf_counter()
                        mujoco.mj_step(m, d)
                        viewer.sync()
                        # Sleep only for what is left of this frame's budget.
                        # Measuring since the previous iteration would include
                        # the previous sleep and make the delay oscillate.
                        elapsed = time.perf_counter() - frame_start
                        time.sleep(max(0.0, 1.0 / self.fps - elapsed))
                except KeyboardInterrupt:
                    viewer.close()
        except Exception as e:
            msg = "An error occurred while running the simulation."
            raise RuntimeError(msg) from e
        finally:
            mujoco.set_mjcb_control(None)

    def launch(self, show_menu: bool = True) -> None:  # noqa: FBT001, FBT002
        """Open a window to display the simulation in a background thread.

        The thread is remembered so that leaving a ``with`` block waits for it.
        """
        gui = threading.Thread(
            target=self._window,
            kwargs={"show_menu": show_menu},
        )
        # Forget finished threads, then track the new one for `__exit__`.
        alive = [t for t in getattr(self, "_threads", []) if t.is_alive()]
        alive.append(gui)
        self._threads = alive
        gui.start()

    def _get_index(  # noqa: C901
        self,
        frame_idx: int | tuple[int, int] | None = None,
        time_idx: float | tuple[float, float] | None = None,
    ) -> list[Any]:
        """Validate and extract frames based on frame or time indices.

        Args:
            frame_idx (int or tuple, optional): Single frame index or
                (start, stop) frame indices.
            time_idx (float or tuple, optional): Single time or (start, end)
                times in seconds, converted with `t2f`.

        Returns:
            The selected frames (all frames if neither index is given).

        Raises:
            ValueError: For invalid input parameters.
        """
        if self._frames is None or len(self._frames) == 0:
            msg = "No frames captured to render."
            raise ValueError(msg)

        if frame_idx is not None and time_idx is not None:
            msg = "Can only specify either frame_idx or time_idx, not both."
            raise ValueError(msg)

        if frame_idx is None and time_idx is None:
            return self._frames

        if time_idx is not None:
            if isinstance(time_idx, (int, float)):
                frame_idx = self.t2f(time_idx)
            elif isinstance(time_idx, tuple):
                frame_idx = (self.t2f(time_idx[0]), self.t2f(time_idx[1]))
            else:
                msg = "time_idx must be a number or a tuple of numbers."
                raise ValueError(msg)

        # A single index selects a one-frame range.
        if isinstance(frame_idx, (int, float)):
            frame_idx = (frame_idx, frame_idx + 1)

        if frame_idx is None:
            msg = "frame_idx cannot be None when unpacking."
            raise ValueError(msg)

        start, stop = frame_idx
        max_frames = len(self._frames)

        if start < 0:
            msg = f"Start index must be non-negative. Got {start}."
            raise ValueError(msg)
        if stop > max_frames:
            msg = (
                f"Stop index must not exceed total frames ({max_frames}). "
                f"Got {stop}."
            )
            raise ValueError(msg)
        if start >= stop:
            msg = (
                f"Start index ({start}) must be less than stop index ({stop})."
            )
            raise ValueError(msg)

        return self._frames[start:stop]

    def show(
        self,
        title: str | None = None,
        codec: str = "gif",
        frame_idx: int | tuple[int, int] | None = None,
        time_idx: float | tuple[float, float] | None = None,
    ) -> None:
        """Render specific frame(s) as a video or GIF in a window.

        In Jupyter the frames are shown with ``mediapy``. Otherwise they are
        played in an OpenCV window; press ``q`` to stop early.

        Args:
            title (str, optional): Title for the rendered media.
            codec (str, optional): Video codec/format. Defaults to "gif".
            frame_idx (int or tuple, optional): Single frame index or
                (start, stop) frame indices.
            time_idx (float or tuple, optional): Single time or (start, end)
                times in seconds.

        Raises:
            ValueError: If no frames are captured.
            Exception: Wrapping any failure (including invalid indices) while
                showing the video.
        """
        if (
            not hasattr(self, "_frames")
            or self._frames is None
            or len(self._frames) == 0
        ):
            msg = "No frames captured to render. Re-run the simulation with render=True."
            raise ValueError(msg)

        try:
            subset_frames = self._get_index(
                frame_idx=frame_idx,
                time_idx=time_idx,
            )

            def is_jupyter() -> bool:
                try:
                    from IPython import get_ipython

                    return "ipykernel" in sys.modules or "IPKernelApp" in get_ipython().config
                except Exception:
                    return False

            if is_jupyter():
                media.show_video(
                    subset_frames,
                    fps=1 if len(subset_frames) == 1 else self._fps,
                    width=self.resolution[0],
                    height=self.resolution[1],
                    codec=codec,
                    title=title,
                )
            else:
                # waitKey(0) blocks indefinitely, so never let the delay
                # truncate to 0 ms at very high fps.
                delay_ms = max(1, int(1000 / self._fps))
                for frame in subset_frames:
                    # Rendered frames are RGB; OpenCV windows expect BGR.
                    cv2.imshow("Video", cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                    if cv2.waitKey(delay_ms) & 0xFF == ord("q"):
                        break
                cv2.waitKey(1)
                cv2.destroyAllWindows()
        except Exception as e:
            msg = "Error while showing video subset."
            raise Exception(msg) from e  # noqa: TRY002

    def save(
        self,
        title: str = "render",
        codec: str = "gif",
        frame_idx: int | tuple[int, int] | None = None,
        time_idx: float | tuple[float, float] | None = None,
    ) -> str:
        """Save specific frame(s) as a video or GIF to a file.

        Args:
            title (str, optional): Filename for the saved media. Its suffix
                is replaced by ``.{codec}`` when it differs.
            codec (str, optional): Video codec/format. Defaults to "gif".
            frame_idx (int or tuple, optional): Single frame index or
                (start, stop) frame indices.
            time_idx (float or tuple, optional): Single time or (start, end)
                times in seconds.

        Returns:
            str: Absolute path to the saved file.

        Raises:
            ValueError: If no frames are captured.
            Exception: Wrapping any failure while selecting or writing frames.
        """
        if (
            not hasattr(self, "_frames")
            or self._frames is None
            or len(self._frames) == 0
        ):
            msg = "No frames captured to render. Re-run the simulation with render=True."
            raise ValueError(msg)

        try:
            subset_frames = self._get_index(
                frame_idx=frame_idx,
                time_idx=time_idx,
            )

            title_path = Path(title)
            if title_path.suffix != f".{codec}":
                title_path = title_path.with_suffix(f".{codec}")

            media.write_video(
                str(title_path),
                subset_frames,
                fps=1 if len(subset_frames) == 1 else self._fps,
                codec=codec,
            )
            return str(title_path.resolve())

        except Exception as e:
            msg = "Error while saving video subset."
            raise Exception(msg) from e  # noqa: TRY002

    def t2f(self, t: float) -> int:
        """Convert time (s) to a frame index, clamped to the last frame."""
        return min(
            int(t * self._fps),
            int(self._duration * self._fps) - 1,
        )  # Subtract 1 to convert to 0-based index

    def f2t(self, frame: int) -> float:
        """Convert frame index to time (s)."""
        return frame / self._fps

    def body_data(
        self,
        body_name: str,
        data_name: str | None = None,
    ) -> dict[str, np.ndarray] | np.ndarray:
        """Get captured per-step data for one body.

        Captured MjData fields are stacked as (steps, ...). For body-indexed
        fields that is (steps, nbody, ...), and this returns the body's slice
        ``[:, body_id]``.

        Args:
            body_name (str): The name of the body to retrieve data for.
            data_name (str, optional): A captured MjData field name. If None,
                every captured known body-indexed field is returned.

        Returns:
            dict[str, np.ndarray] | np.ndarray: ``{field: (steps, ...)}`` when
            ``data_name`` is None, otherwise that field's (steps, ...) slice.

        Raises:
            ValueError: If the body or field is unknown, nothing has been
                captured, or the field is not indexed by body.
        """
        if body_name not in self.body_names:
            msg = f"Body '{body_name}' not found in the model."
            raise ValueError(msg)

        body_id = self._model.body(body_name).id

        if self._captured_data is None:
            msg = "No simulation data captured yet."
            raise ValueError(msg)

        unwrapped_data = self._captured_data.unwrap()
        nbody = self._model.nbody

        def _is_body_indexed(arr: Any) -> bool:
            # A stacked body field looks like (steps, nbody, ...).
            return (
                isinstance(arr, np.ndarray)
                and arr.ndim >= 2
                and arr.shape[1] == nbody
            )

        if data_name is None:
            return {
                key: arr[:, body_id]
                for key, arr in unwrapped_data.items()
                if key in self._BODY_FIELDS and _is_body_indexed(arr)
            }

        if data_name not in unwrapped_data:
            msg = f"Data '{data_name}' not found for body '{body_name}'."
            raise ValueError(msg)

        arr = unwrapped_data[data_name]
        if not _is_body_indexed(arr):
            msg = (
                f"Data '{data_name}' is not indexed by body "
                f"(expected shape (steps, {nbody}, ...))."
            )
            raise ValueError(msg)
        return arr[:, body_id]

    def name2id(
        self,
        name: str,
        obj_type: "mujoco.mjtObj | None" = None,
    ) -> int:
        """Get the ID of a named model object.

        Without ``obj_type``, object types are searched in the order of
        ``mujoco_object_types`` and the first match wins. Names are only
        unique per type, so pass ``obj_type`` to disambiguate.

        Args:
            name (str): The object's name.
            obj_type (mujoco.mjtObj, optional): Restrict the search to one
                object type.

        Returns:
            int: The object's ID within its type.

        Raises:
            ValueError: If no object with that name is found.
        """
        search_types = mujoco_object_types if obj_type is None else [obj_type]
        for candidate in search_types:
            try:
                obj_id = mujoco.mj_name2id(self._model, candidate, name)
            except Exception:  # noqa: BLE001, S112 - best effort per type
                continue
            if obj_id >= 0:
                return obj_id
        msg = f"Object with name '{name}' not found."
        raise ValueError(msg)

    def id2name(
        self,
        id: int,  # noqa: A002 - public parameter name kept for compatibility
        obj_type: "mujoco.mjtObj" = mujoco.mjtObj.mjOBJ_BODY,
    ) -> str:
        """Get the name of a model object given its type and ID.

        IDs are only unique per object type, so the type must be explicit.
        It defaults to bodies.

        Args:
            id (int): The object's ID within ``obj_type``.
            obj_type (mujoco.mjtObj, optional): The object type. Defaults to
                ``mjOBJ_BODY``.

        Returns:
            str: The object's name.

        Raises:
            ValueError: If MuJoCo returns no name for that type/ID.
        """
        obj_name = mujoco.mj_id2name(self._model, obj_type, id)
        if obj_name is None:
            msg = f"ID '{id}' not found."
            raise ValueError(msg)
        return obj_name

    def save_yaml(self, name: str = "Model") -> None:
        """Save captured simulation data to a YAML file.

        Args:
            name (str): The filename. ``.yml`` is appended unless it already
                ends in ``.yml`` or ``.yaml``.

        Raises:
            ValueError: If the data cannot be serialized or written.
        """
        if not name.endswith((".yml", ".yaml")):
            name += ".yml"

        try:
            serialized_data = {
                k: self._to_builtin(v) for k, v in self.captured_data.items()
            }
            with Path(name).open("w", encoding="utf-8") as f:
                yaml.dump(serialized_data, f, default_flow_style=False)
        except Exception as e:
            msg = f"Failed to save simulation data to '{name}'"
            raise ValueError(msg) from e

    @staticmethod
    def _to_builtin(value: Any) -> Any:
        """Recursively convert NumPy arrays/scalars to plain Python values.

        Lists and tuples are converted element-wise. This covers the
        inconsistent-shape fallback of `captured_data`, which yields a list of
        arrays.
        """
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, (list, tuple)):
            return [Wrapper._to_builtin(v) for v in value]
        return value
class _SimulationData:
    """Accumulate per-step snapshots of MjData attributes.

    Each captured attribute is stored as a list with one entry per `capture`
    call. `unwrap` stacks those lists into NumPy arrays.
    """

    __slots__ = ["_d", "_keys"]

    def __init__(self) -> None:
        self._d: dict[str, list] = defaultdict(list)
        # Attribute names to capture, resolved on the first `capture` call.
        self._keys: list[str] | None = None

    def capture(self, mj_data) -> None:
        """Capture data from MjData, storing specified or all public attributes.

        The attribute list comes from the package-level
        ``CAPTURE_PARAMETERS``: either "all" (every public, non-callable
        attribute) or an explicit iterable of names. It is resolved once per
        instance, because rescanning ``dir(mj_data)`` on every capture is
        expensive and its result does not change between steps.
        """
        if self._keys is None:
            from . import CAPTURE_PARAMETERS

            if CAPTURE_PARAMETERS == "all":
                # Sorted so captured-key order is deterministic across runs.
                self._keys = sorted(self.get_public_keys(mj_data))
            else:
                self._keys = list(CAPTURE_PARAMETERS)

        for key in self._keys:
            value = getattr(mj_data, key, None)
            if value is None:
                continue
            if isinstance(value, np.ndarray):
                # MjData arrays are live views; copy to freeze this step.
                self._d[key].append(value.copy())
            elif np.isscalar(value):
                self._d[key].append(value)
            elif hasattr(value, "copy") and callable(value.copy):
                self._d[key].append(value.copy())
            else:
                self._d[key].append(value)

    def unwrap(self) -> dict[str, np.ndarray]:
        """Unwrap simulation data into a structured format with NumPy arrays.

        Returns:
            dict[str, np.ndarray]: Per key, a stacked array of shape
            (steps, ...). Keys whose snapshots have inconsistent shapes, or
            cannot be converted, are returned as the raw list.
        """
        unwrapped_data = {}
        for key, value_list in self._d.items():
            if not value_list:
                unwrapped_data[key] = np.array([])
                continue

            first = value_list[0]
            try:
                if isinstance(first, np.ndarray):
                    shape = first.shape
                    if all(v.shape == shape for v in value_list):
                        unwrapped_data[key] = np.stack(value_list)
                    else:
                        unwrapped_data[key] = value_list  # Inconsistent shapes
                else:
                    unwrapped_data[key] = np.array(value_list)
            except (ValueError, TypeError):
                unwrapped_data[key] = value_list  # Fallback
        return unwrapped_data

    @property
    def shape(self) -> dict[str, tuple]:
        """Return the shape of the captured data per key."""
        if not self._d:
            return {}

        shapes: dict[str, tuple[Any, ...]] = {}
        for key, value_list in self._d.items():
            if not value_list:
                shapes[key] = ()
                continue
            first_value = value_list[0]
            if isinstance(first_value, np.ndarray):
                shapes[key] = (len(value_list), *first_value.shape)
            elif isinstance(first_value, list) and all(isinstance(v, list) for v in value_list):
                shapes[key] = (len(value_list), len(first_value))
            else:
                shapes[key] = (len(value_list),)
        return shapes

    def clear(self) -> None:
        """Clear all captured data and the resolved attribute list."""
        self._d.clear()
        self._keys = None

    def keys(self) -> set[str]:
        """Return a set of all captured data keys."""
        return set(self._d.keys())

    def items(self) -> dict[str, list]:
        """Return raw captured data as a dict of lists."""
        return dict(self._d)

    def __len__(self) -> int:
        """Return the number of captured steps (based on first key)."""
        if not self._d:
            return 0
        first_key = next(iter(self._d))
        return len(self._d[first_key])

    def __str__(self) -> str:
        return f"{self.__class__.__name__}({len(self)} Step(s) Captured)"

    def __repr__(self) -> str:
        return self.__str__()

    def __del__(self) -> None:
        """Safely clean up resources during object deletion."""
        if hasattr(self, "_d"):
            self._d.clear()

    @staticmethod
    def get_public_keys(obj: object) -> set[str]:
        """Get all public (non-callable) attributes of an object."""
        return {
            name
            for name in dir(obj)
            if not name.startswith("_") and not callable(getattr(obj, name))
        }