Source code for mujoco_toolbox.builder

import os
import xml.etree.ElementTree as StdET  # For element creation
from io import BytesIO
from pathlib import Path
from typing import Union
from xml.etree.ElementTree import (  # Import for type annotations
    Element,
    ElementTree,
)

import defusedxml.ElementTree as ET  # For secure parsing


[docs] class Builder: """Builder class for loading, merging, and saving MuJoCo models from XML strings or file paths.""" def __init__(self, *args: str) -> None: """Initialize with XML strings or file paths. Args: *args: XML strings or file paths. Raises: TypeError: If any input is not a string. ValueError: If no input is provided. """ for arg in args: if not isinstance(arg, str): msg = "Input must be an XML string or a file path" raise TypeError(msg) if len(args) == 0: msg = "Input is required to initialize the Builder" raise ValueError(msg) if len(args) > 1: # Merge multiple Builder instances self.tree, self.root = self._load_model(args[0]) for additional_arg in args[1:]: additional_builder = Builder(additional_arg) self.__add__(additional_builder) else: xml_input = args[0] self.tree, self.root = self._load_model(xml_input) def _create_element_tree(self, root: Element) -> ElementTree: """Create an ElementTree from a root element in a defusedxml-safe way. Args: root: Root element for the tree. Returns: ElementTree: A new ElementTree with the given root. """ # Convert to string and parse back to create a tree xml_string = StdET.tostring(root) return ET.parse(BytesIO(xml_string)) def _load_model(self, xml_input: str) -> tuple[ElementTree, Element]: """Load a MuJoCo model from a file or XML string. Args: xml_input: XML string or file path. Returns: tuple: ElementTree and root Element. Raises: TypeError: If input is not a string. ValueError: If XML is invalid. FileNotFoundError: If file path doesn't exist. """ if not isinstance(xml_input, str): msg = "Input must be either an XML string or a file path" raise TypeError(msg) from None # Check if input is an XML string if xml_input.strip().startswith("<"): # Try parsing as XML string try: root = ET.fromstring(xml_input) # Create tree using our safe method tree = self._create_element_tree(root) # If it's a URDF (robot tag), wrap it in mujoco if root.tag == "robot": return self._wrap_urdf_in_mujoco(tree) # If root is not mujoco, wrap it if root.tag != "mujoco": mujoco_elem = StdET.Element("mujoco") mujoco_elem.append(root) # Create tree using our safe method tree = self._create_element_tree(mujoco_elem) return tree, tree.getroot() except ET.ParseError: msg = "Invalid XML string provided" raise ValueError(msg) from None else: # Treat input as file path if not Path(xml_input).exists(): msg = f"File not found: {xml_input}" raise FileNotFoundError(msg) from None # Handle URDF files if xml_input.lower().endswith(".urdf"): return self._wrap_urdf_in_mujoco(ET.parse(xml_input)) # Parse regular XML file try: tree = ET.parse(xml_input) root = tree.getroot() # If root is not mujoco, wrap it if root.tag != "mujoco": mujoco_elem = StdET.Element("mujoco") mujoco_elem.append(root) # Create tree using our safe method tree = self._create_element_tree(mujoco_elem) return tree, tree.getroot() except ET.ParseError: msg = f"Invalid XML file: {xml_input}" raise ValueError(msg) from None def _wrap_urdf_in_mujoco( self, urdf_input: str | ElementTree, ) -> tuple[ElementTree, Element]: """Wrap a URDF inside <mujoco> tags for MuJoCo compatibility. Args: urdf_input: File path, ElementTree, or XML string. Returns: tuple: ElementTree and root Element. Raises: TypeError: If input is not a string or ElementTree. ValueError: If URDF is invalid. """ # Handle string input (could be a file path or XML string) if isinstance(urdf_input, str): # Check if it's an XML string if urdf_input.strip().startswith("<"): try: root = ET.fromstring(urdf_input) # Create tree using our safe method tree = self._create_element_tree(root) except ET.ParseError: msg = "Invalid URDF XML string provided" raise ValueError(msg) from None else: # It's a file path try: tree = ET.parse(urdf_input) except Exception: msg = f"Could not parse URDF file: {urdf_input}" raise ValueError(msg) from None # Handle ElementTree input using duck typing elif hasattr(urdf_input, "getroot") and callable(urdf_input.getroot): # Handle both standard ElementTree and defusedxml ElementTree tree = urdf_input # type: ignore else: msg = "URDF input must be a string or ElementTree-like object with getroot() method" raise TypeError(msg) root = tree.getroot() # If root is already present in a mujoco element if root.tag == "mujoco": return tree, root # Create mujoco element and append the URDF root mujoco_elem = StdET.Element("mujoco") mujoco_elem.append(root) # Create a new tree with the <mujoco> root using our safe method tree = self._create_element_tree(mujoco_elem) return tree, tree.getroot() def _merge_tags(self, tag_name: str, root_1: Element, root_2: Element) -> None: """Merge a specific tag (e.g., worldbody, asset) from two models. Args: tag_name: Tag name to merge. root_1: First model root element. root_2: Second model root element. """ section_1 = root_1.find(tag_name) section_2 = root_2.find(tag_name) # If target section doesn't exist in model 1 but exists in model 2, create it if section_1 is None and section_2 is not None: section_1 = StdET.SubElement(root_1, tag_name) # Only proceed if both sections exist now if section_1 is not None and section_2 is not None: # Merge the contents by appending all elements from model 2 to model 1 for element in list(section_2): # Use list() to avoid modification section_1.append(element) def __add__(self, other: Union["Builder", str]) -> "Builder": """Implement the + operator for merging two models. Args: other: Another Builder object or XML string. Returns: Builder: Self with merged content. Raises: TypeError: If other is not a Builder or string. """ model_root: Element | None = None if isinstance(other, Builder): # Merge the current model with another Builder instance model_root = other.root elif isinstance(other, str): # If the input is a string, load as XML _, model_root = self._load_model(other) else: msg = "Addition only supported between Builder objects or with XML strings/files" raise TypeError(msg) # List of sections to merge sections_to_merge = [ "asset", "worldbody", "camera", "light", "contact", "equality", "sensor", "actuator", "default", "tendon", "include", ] # Merge all relevant sections for section in sections_to_merge: self._merge_tags(section, self.root, model_root) # Return self to allow for method chaining return self def __radd__(self, other: Union["Builder", int]) -> "Builder": """Implement the reverse + operator for merging two models. Args: other: Another Builder object or 0 (for sum()). Returns: Builder: Self with merged content. """ if other == 0: # Handle sum() starting value return self return self.__add__(other)
[docs] def save(self, file_path: str) -> str: """Save the merged model to a file. Args: file_path: Path to save the model. Returns: str: Absolute path to the saved file. Raises: ValueError: If no model is loaded. """ if self.tree is not None: # Format the XML with proper indentation before saving self._indent_xml(self.root) self.tree.write(file_path, encoding="utf-8", xml_declaration=True) else: msg = "No model loaded. Cannot save." raise ValueError(msg) return os.path.abspath(file_path)
def _indent_xml(self, elem: Element, level: int = 0) -> None: """Add proper indentation to make the XML file more readable. Args: elem: XML element to indent. level: Indentation level. """ i = "\n" + level * " " if len(elem): if not elem.text or not elem.text.strip(): elem.text = i + " " if not elem.tail or not elem.tail.strip(): elem.tail = i for sub_elem in elem: self._indent_xml(sub_elem, level + 1) if not elem[-1].tail or not elem[-1].tail.strip(): elem[-1].tail = i elif level and (not elem.tail or not elem.tail.strip()): elem.tail = i def __str__(self) -> str: """Return the XML string of the model. Returns: str: XML string representation. Raises: ValueError: If no model is loaded. """ if self.tree is not None: # Format the XML with proper indentation self._indent_xml(self.root) # Use standard ElementTree's tostring with proper encoding return StdET.tostring(self.root, encoding="unicode", method="xml") msg = "No model loaded. Cannot generate string." raise ValueError(msg) def __repr__(self) -> str: """Return the XML string of the model. Returns: str: XML string representation. """ return self.__str__() def __len__(self) -> int: """Return the number of elements in the model. Returns: int: Number of elements in the model. """ return len(self.root) if self.root is not None else 0