Coverage for trimesh/scene/transforms.py: 93%
305 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-24 04:40 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-24 04:40 +0000
1import collections
2import itertools
3from copy import deepcopy
5import numpy as np
7from .. import caching, util
8from ..caching import hash_fast
9from ..transformations import fix_rigid, quaternion_matrix, rotation_matrix
10from ..typed import ArrayLike, Hashable, NDArray, Sequence
# identity is compared against constantly, so keep a single shared
# copy and lock it so no caller can accidentally modify it in place
_identity = np.eye(4)
_identity.setflags(write=False)
class SceneGraph:
    """
    Hold data about positions and instances of geometry
    in a scene. This includes a forest (i.e. multi-root tree)
    of transforms and information on which node is the base
    frame, and which geometries are affiliated with which
    nodes.
    """

    def __init__(self, base_frame="world", repair_rigid=1e-5):
        """
        Create a scene graph, holding homogeneous transformation
        matrices and instance information about geometry.

        Parameters
        -----------
        base_frame : any
          The root node transforms will be positioned from.
        repair_rigid : None or float
          If a float will attempt to repair rotation matrices
          where `M @ M.T` differs from an identity matrix by
          more than floating point zero but less than this value.
          This can happen in a deep tree with a lot of matrix
          multiplies.
        """
        # the forest structure holding node and edge data
        self.transforms = EnforcedForest()
        # hashable, the base or root frame
        self.base_frame = base_frame
        # if passed as a float try to repair rigid transforms
        # that have accumulated floating point error
        self.repair_rigid = repair_rigid
        # cache for transforms and cached properties, validated
        # against this graph's `__hash__`
        self._cache = caching.Cache(self.__hash__)

    def update(self, frame_to, frame_from=None, **kwargs):
        """
        Update a transform in the tree.

        Parameters
        ------------
        frame_from : hashable object
          Usually a string (eg 'world').
          If left as None it will be set to self.base_frame
        frame_to : hashable object
          Usually a string (eg 'mesh_0')
        matrix : (4,4) float
          Homogeneous transformation matrix
        quaternion : (4,) float
          Quaternion ordered [w, x, y, z]
        axis : (3,) float
          Axis of rotation
        angle : float
          Angle of rotation, in radians
        translation : (3,) float
          Distance to translate
        geometry : hashable
          Geometry object name, e.g. 'mesh_0'
        metadata: dictionary
          Optional metadata attached to the new frame
          (exports to glTF node 'extras').
        """
        # if no frame specified, use base frame
        if frame_from is None:
            frame_from = self.base_frame

        # only geometry and metadata are passed through to the edge
        attr = {k: v for k, v in kwargs.items() if k in {"geometry", "metadata"}}
        # convert the various transform kwargs to a single matrix
        attr["matrix"] = kwargs_to_matrix(**kwargs)

        # add the edge for the transform; the return value
        # (whether anything changed) is not needed here
        self.transforms.add_edge(frame_from, frame_to, **attr)

        # set the node attribute with the geometry information
        if "geometry" in kwargs:
            self.transforms.node_data[frame_to]["geometry"] = kwargs["geometry"]

    def get(
        self, frame_to: Hashable, frame_from: Hashable | None = None
    ) -> tuple[NDArray[np.float64], Hashable | None]:
        """
        Get the transform from one frame to another.

        Parameters
        ------------
        frame_to : hashable
          Node name, usually a string (eg 'mesh_0')
        frame_from : hashable
          Node name, usually a string (eg 'world').
          If None it will be set to self.base_frame

        Returns
        ----------
        transform : (4, 4) float
          Homogeneous transformation matrix, marked read-only.
        geometry
          The name of the geometry if it exists

        Raises
        -----------
        ValueError
          If the frames aren't connected.
        """
        # use base frame if not specified
        if frame_from is None:
            frame_from = self.base_frame

        # look up transform to see if we have it already
        key = (frame_from, frame_to)
        if key in self._cache:
            return self._cache[key]

        # `node_data` is a defaultdict: use `.get` so that querying
        # an unknown node doesn't silently add it to the graph
        node_attr = self.transforms.node_data.get(frame_to)
        geometry = None if node_attr is None else node_attr.get("geometry")

        # get a local reference to edge data
        data = self.transforms.edge_data

        if frame_from == frame_to:
            # if we're going from ourself return identity
            matrix = _identity
        elif key in data:
            # if the path is just an edge return early
            matrix = data[key]["matrix"]
        else:
            # we have a 3+ node path
            # get the path from the forest always going from
            # parent -> child -> child
            path = self.transforms.shortest_path(frame_from, frame_to)
            # the path should always start with `frame_from`
            assert path[0] == frame_from
            # and end with the `frame_to` node
            assert path[-1] == frame_to

            # loop through pairs of the path
            matrices = []
            for u, v in itertools.pairwise(path):
                forward = data.get((u, v))
                if forward is not None and "matrix" in forward:
                    # append the matrix from u to v
                    matrices.append(forward["matrix"])
                    continue
                # `.get` rather than indexing: indexing the defaultdict
                # would insert an empty edge instead of failing
                backward = data.get((v, u))
                if backward is None:
                    if forward is None:
                        raise ValueError(f"No edge between {u} and {v}!")
                    # forward edge exists but stores no matrix: identity
                    continue
                if "matrix" in backward:
                    # append the inverted backwards matrix
                    matrices.append(np.linalg.inv(backward["matrix"]))
            # filter out any identity matrices
            matrices = [m for m in matrices if np.abs(m - _identity).max() > 1e-8]
            if len(matrices) == 0:
                matrix = _identity
            elif len(matrices) == 1:
                matrix = matrices[0]
            else:
                # multiply matrices into single transform
                matrix = util.multi_dot(matrices)

        # if instructed to repair rigid transforms do it here
        if self.repair_rigid is not None:
            matrix = fix_rigid(matrix, max_deviance=self.repair_rigid)

        # matrix being edited in-place leads to subtle bugs
        # NOTE: for a single-edge path this may be the stored edge
        # array itself, which then also becomes read-only
        matrix.flags["WRITEABLE"] = False

        # store the result
        self._cache[key] = (matrix, geometry)

        return matrix, geometry

    def __hash__(self):
        """
        Hash of the underlying forest's data.
        """
        return self.transforms.__hash__()

    def copy(self):
        """
        Return a copy of the current SceneGraph.

        Returns
        ------------
        copied : SceneGraph
          Copy of current object, with the same base frame and
          `repair_rigid` setting but an empty cache.
        """
        # create a copy without transferring cache, keeping the
        # rigid-repair tolerance so `get` behaves the same
        copied = SceneGraph(
            base_frame=deepcopy(self.base_frame), repair_rigid=self.repair_rigid
        )
        copied.transforms = deepcopy(self.transforms)
        return copied

    def to_flattened(self):
        """
        Export the current transform graph with all
        transforms baked into world->instance.

        Returns
        ---------
        flat : dict
          Keyed {node : {"transform": list, "geometry": name}}
          for every node except the base frame.
        """
        flat = {}
        base_frame = self.base_frame
        for node in self.nodes:
            if node == base_frame:
                continue
            # get the matrix and geometry name
            matrix, geometry = self.get(frame_to=node, frame_from=base_frame)
            # store matrix as list rather than numpy array
            flat[node] = {"transform": matrix.tolist(), "geometry": geometry}

        return flat

    def to_gltf(self, scene, mesh_index=None):
        """
        Export a transforms as the 'nodes' section of the
        GLTF header dict.

        Parameters
        ------------
        scene : trimesh.Scene
          Scene with geometry.
        mesh_index : dict or None
          Mapping { key in scene.geometry : int }

        Returns
        --------
        gltf : dict
          With 'nodes' referencing a list of dicts, and
          'extensionsUsed' (sorted) if any node has extensions.
        """
        if mesh_index is None:
            # geometry is an OrderedDict
            # map mesh name to index: {geometry key : index}
            mesh_index = {name: i for i, name in enumerate(scene.geometry.keys())}

        # get graph information into local scope before loop
        graph = self.transforms
        node_data = graph.node_data
        edge_data = graph.edge_data
        parents = graph.parents
        base_frame = self.base_frame

        # list of dict, in gltf format
        # start with base frame as first node index
        result = [{"name": base_frame}]
        # {node name : node index in gltf}
        lookup = {base_frame: 0}

        # collect the nodes in order
        for node in node_data.keys():
            if node == base_frame:
                continue
            # assign the index to the node-name lookup
            lookup[node] = len(result)
            # populate a result at the correct index
            result.append({"name": node})

        # get generated properties outside of loop
        # does the scene have a defined camera to export
        has_camera = scene.has_camera
        children = graph.children

        extensions_used = set()

        # then iterate through to collect data
        for info in result:
            # name of the scene node
            node = info["name"]

            # get the original node names for children
            childs = children.get(node, [])
            if len(childs) > 0:
                info["children"] = [lookup[k] for k in childs]

            # `.get` so the base frame isn't inserted into the
            # defaultdict node data if it has no entry yet
            attr = node_data.get(node, {})
            # if we have a mesh store by index
            if "geometry" in attr and attr["geometry"] in mesh_index:
                info["mesh"] = mesh_index[attr["geometry"]]
            # check to see if we have camera node
            if has_camera and node == scene.camera.name:
                info["camera"] = 0

            # the base frame and any other parentless roots
            # have no incoming edge to export
            if node == base_frame or node not in parents:
                continue
            node_edge = edge_data[(parents[node], node)]

            # get the matrix from this edge; a missing matrix is
            # treated as identity, as `EnforcedForest.add_edge` does
            matrix = node_edge.get("matrix")
            # only include if it's not an identity matrix
            if matrix is not None and not util.allclose(matrix, _identity):
                info["matrix"] = matrix.T.reshape(-1).tolist()

            # if an extra was stored on this edge
            extras = node_edge.get("metadata")
            if extras:
                # copy so we don't stomp on the stored metadata
                extras = extras.copy()

                # if extensions were stored on this edge
                extensions = extras.pop("gltf_extensions", None)
                if isinstance(extensions, dict):
                    info["extensions"] = extensions
                    extensions_used.update(extensions.keys())

                # convert any numpy arrays to lists
                extras.update(
                    {k: v.tolist() for k, v in extras.items() if hasattr(v, "tolist")}
                )
                info["extras"] = extras

        gltf = {"nodes": result}
        if len(extensions_used) > 0:
            # sort so the exported header is deterministic
            gltf["extensionsUsed"] = sorted(extensions_used)
        return gltf

    def to_edgelist(self):
        """
        Export the current transforms as a list of
        edge lists, with each having the format:
        [node_a, node_b, {metadata}]

        Returns
        ---------
        edgelist : (n,) list
          Of edge lists, with numpy values converted to lists.
        """
        # save local reference to node_data
        nodes = self.transforms.node_data
        # save cleaned edges
        export = []
        # loop through (node, node, edge attributes)
        for edge, attr in self.transforms.edge_data.items():
            # node indexes from edge
            a, b = edge
            # geometry is a node property but save it to the
            # edge so we don't need two dictionaries
            b_attr = nodes[b]
            # make sure we're not stomping on original
            attr_new = attr.copy()
            # apply node geometry to edge attributes
            if "geometry" in b_attr:
                attr_new["geometry"] = b_attr["geometry"]
            # convert any numpy arrays to regular lists
            attr_new.update(
                {k: v.tolist() for k, v in attr_new.items() if hasattr(v, "tolist")}
            )
            export.append([a, b, attr_new])
        return export

    def from_edgelist(self, edges, strict=True):
        """
        Load transform data from an edge list into the current
        scene graph.

        Parameters
        -------------
        edgelist : (n,) tuples
          Keyed (node_a, node_b, {key: value})
        strict : bool
          If True raise a ValueError when a
          malformed edge is passed in a tuple.

        Raises
        ----------
        ValueError
          If `strict` and an edge is not length 2 or 3.
        """
        # loop through each edge
        for edge in edges:
            # edge contains attributes
            if len(edge) == 3:
                self.update(edge[1], edge[0], **edge[2])
            # edge just contains nodes
            elif len(edge) == 2:
                self.update(edge[1], edge[0])
            # edge is broken
            elif strict:
                raise ValueError(f"edge incorrect shape: {edge}")

    def to_networkx(self):
        """
        Return a `networkx` copy of this graph.

        Returns
        ----------
        graph : networkx.DiGraph
          Directed graph.
        """
        import networkx

        return networkx.from_edgelist(self.to_edgelist(), create_using=networkx.DiGraph)

    def show(self, **kwargs):
        """
        Plot the scene graph using `networkx.draw_networkx`
        which uses matplotlib to display the graph.

        Parameters
        -----------
        kwargs : dict
          Passed to `networkx.draw_networkx`; `with_labels`
          defaults to True if not passed.
        """
        import matplotlib.pyplot as plt  # noqa
        import networkx

        # only applied if not passed explicitly
        kwargs.setdefault("with_labels", True)
        networkx.draw_networkx(G=self.to_networkx(), **kwargs)

        plt.show()

    def load(self, edgelist):
        """
        Load transform data from an edge list into the current
        scene graph.

        Parameters
        -------------
        edgelist : (n,) tuples
          Structured (node_a, node_b, {key: value})
        """
        self.from_edgelist(edgelist, strict=True)

    @caching.cache_decorator
    def nodes(self):
        """
        A list of every node in the graph.

        Returns
        -------------
        nodes : (n,) array
          All node names.
        """
        return self.transforms.nodes

    @caching.cache_decorator
    def nodes_geometry(self):
        """
        The nodes in the scene graph with geometry attached.

        Returns
        ------------
        nodes_geometry : (m,) array
          Node names which have geometry associated
        """
        return [n for n, attr in self.transforms.node_data.items() if "geometry" in attr]

    @caching.cache_decorator
    def geometry_nodes(self):
        """
        Which nodes have this geometry? Inverse
        of `nodes_geometry`.

        Returns
        ------------
        geometry_nodes : collections.defaultdict
          Keyed {geometry_name : [node name, ...]}
        """
        res = collections.defaultdict(list)
        for node, attr in self.transforms.node_data.items():
            if "geometry" in attr:
                res[attr["geometry"]].append(node)
        return res

    def remove_geometries(self, geometries: str | set | Sequence):
        """
        Remove the reference for specified geometries
        from nodes without deleting the node.

        Parameters
        ------------
        geometries : list or str
          Name of scene.geometry to dereference.
        """
        # make sure we have a set of geometries to remove
        if isinstance(geometries, str):
            geometries = [geometries]
        geometries = set(geometries)

        forest = self.transforms
        # geometry is stored on the node AND on the incoming edge
        # (see `EnforcedForest.add_edge`), so strip it from both or
        # `to_edgelist` and the hash will still carry it
        for attrib in itertools.chain(
            forest.node_data.values(), forest.edge_data.values()
        ):
            if "geometry" in attrib and attrib["geometry"] in geometries:
                attrib.pop("geometry")

        # invalidate the hash first, then drop everything cached:
        # `nodes_geometry`, `geometry_nodes`, and cached `get`
        # results all carry geometry names
        forest._hash = None
        self._cache.clear()

    def __contains__(self, key: Hashable) -> bool:
        """
        Is `key` a node in the graph.
        """
        return key in self.transforms.node_data

    def __getitem__(self, key: Hashable) -> tuple[NDArray[np.float64], Hashable | None]:
        """
        Transform and geometry of `key` relative to the base frame.
        """
        return self.get(key)

    def __setitem__(self, key: Hashable, value: ArrayLike):
        """
        Set the transform of `key` from the base frame.

        Raises
        ---------
        ValueError
          If `value` is not a (4, 4) matrix.
        """
        value = np.asanyarray(value, dtype=np.float64)
        if value.shape != (4, 4):
            raise ValueError("Matrix must be specified!")
        return self.update(key, matrix=value)

    def clear(self):
        """
        Remove every node and edge and drop all cached values.
        """
        self.transforms = EnforcedForest()
        self._cache.clear()
class EnforcedForest:
    """
    A simple forest graph data structure: every node
    is allowed to have exactly one parent. This makes
    traversal and implementation much simpler than a
    full graph data type; by storing only one parent
    reference, it enforces the structure for "free."
    """

    def __init__(self):
        # since every node can have only one parent
        # this data structure transparently enforces
        # the forest data structure without checks
        # a dict {child : parent}
        self.parents = {}

        # store data for a particular edge keyed by tuple
        # {(u, v) : data }
        self.edge_data = collections.defaultdict(dict)
        # {u: data}
        self.node_data = collections.defaultdict(dict)

        # if multiple calls are made for the same path
        # but the connectivity hasn't changed return cached
        # (also holds the "children" mapping)
        self._cache = {}
        # lazily computed `__hash__` value; anything that
        # modifies the data must reset this to None
        self._hash = None

    def add_edge(self, u, v, **kwargs):
        """
        Add an edge to the forest cleanly.

        If `v` already has a different parent, the edge from
        that old parent is removed so `v` keeps exactly one
        incoming edge.

        Parameters
        -----------
        u : any
          Hashable node key.
        v : any
          Hashable node key.
        kwargs : dict
          Stored as (u, v) edge data.

        Returns
        --------
        changed : bool
          Return if this operation changed anything.
        """
        self._hash = None

        # `.get` avoids inserting into the defaultdict
        edge = self.edge_data.get((u, v))
        if edge is None:
            # topology has changed so clear cache
            self._cache = {}
            # `v` may only have one parent: if it is being moved
            # drop the old incoming edge, otherwise its stale
            # transform is still found by lookups and exports
            if v in self.parents and self.parents[v] != u:
                self.edge_data.pop((self.parents[v], v), None)
        else:
            # check to see if matrix, geometry and metadata are identical;
            # metadata may hold numpy arrays so it is compared by identity,
            # and only counts as a change when it was explicitly passed
            same_matrix = util.allclose(
                kwargs.get("matrix", _identity), edge.get("matrix", _identity), 1e-8
            )
            same_geometry = edge.get("geometry") == kwargs.get("geometry")
            same_metadata = (
                "metadata" not in kwargs
                or kwargs["metadata"] is edge.get("metadata")
            )
            if same_matrix and same_geometry and same_metadata:
                return False

        # store a parent reference for traversal
        self.parents[v] = u
        # store kwargs for edge data keyed with tuple
        self.edge_data[(u, v)] = kwargs
        # make sure both nodes exist in the node data
        self.node_data.setdefault(u, {})
        child = self.node_data.setdefault(v, {})
        if "geometry" in kwargs:
            child["geometry"] = kwargs["geometry"]

        return True

    def remove_node(self, u):
        """
        Remove a node from the forest.

        Children of the removed node lose their parent
        reference and become roots.

        Parameters
        -----------
        u : any
          Hashable node key.

        Returns
        --------
        changed : bool
          Return if this operation changed anything.
        """
        # check if node is part of forest
        if u not in self.node_data:
            return False

        # topology will change so clear cache
        self._cache = {}
        self._hash = None

        # delete all children's references and parent reference
        children = [child for (child, parent) in self.parents.items() if parent == u]
        for c in children:
            del self.parents[c]
        if u in self.parents:
            del self.parents[u]

        # delete edge data
        edges = [(a, b) for (a, b) in self.edge_data if a == u or b == u]
        for e in edges:
            del self.edge_data[e]

        # delete node data
        del self.node_data[u]

        return True

    def shortest_path(self, u, v):
        """
        Find the shortest path between `u` and `v`, returning
        a path where the first element is always `u` and the
        last element is always `v`, disregarding edge direction.

        Parameters
        -----------
        u : any
          Hashable node key.
        v : any
          Hashable node key.

        Returns
        -----------
        path : (n,)
          Path between `u` and `v`, or an empty list if `u == v`.

        Raises
        -----------
        ValueError
          If `u` and `v` are not connected.
        """
        if u == v:
            # the path between itself is an edge case
            return []
        # see if we've already computed this path
        if (u, v) in self._cache:
            return self._cache[(u, v)]
        if (v, u) in self._cache:
            # return the same path for either direction
            return self._cache[(v, u)][::-1]

        # local reference to parent dict for performance
        parents = self.parents
        # walk up towards the root from both ends at once;
        # once a walk passes the root it keeps appending None
        forward = [u]
        backward = [v]

        # cap iteration to number of total nodes
        for _ in range(len(parents) + 1):
            f = parents.get(forward[-1])
            b = parents.get(backward[-1])
            forward.append(f)
            backward.append(b)

            if f == v:
                # `v` is an ancestor of `u`
                self._cache[(u, v)] = forward
                return forward
            elif b == u:
                # `u` is an ancestor of `v`: return reversed path
                backward = backward[::-1]
                self._cache[(u, v)] = backward
                return backward
            elif (b in forward) or (f is None and b is None):
                # we have either a common node between both
                # traversal directions or we have consumed the whole
                # tree in both directions so try to find the common node
                common = set(backward).intersection(forward).difference({None})
                if len(common) == 0:
                    raise ValueError(f"No path from {u}->{v}!")
                # the first common node reached walking up from `u`
                # is where the two branches meet
                link = next(node for node in forward if node in common)

                # combine the forward and backwards traversals
                head = forward[: forward.index(link) + 1]
                tail = backward[: backward.index(link)]
                path = head + tail[::-1]

                # verify we didn't screw up the order
                assert path[0] == u
                assert path[-1] == v

                self._cache[(u, v)] = path

                return path

        raise ValueError("Iteration limit exceeded!")

    @property
    def nodes(self):
        """
        Get every node.

        Returns
        -----------
        nodes : dict_keys
          Set-like view of every node currently stored.
        """
        return self.node_data.keys()

    @property
    def children(self):
        """
        Get the children of each node.

        Returns
        ----------
        children : dict
          Keyed {node : [child, child, ...]}, only for
          nodes that have at least one child.
        """
        cached = self._cache.get("children")
        if cached is not None:
            return cached
        child = collections.defaultdict(list)
        for node, parent in self.parents.items():
            # skip self-references to avoid a node loop
            if node != parent:
                child[parent].append(node)

        # cache and return as a vanilla dict
        self._cache["children"] = dict(child)
        return self._cache["children"]

    def successors(self, node):
        """
        Get all nodes that are successors to specified node,
        including the specified node.

        Parameters
        -------------
        node : any
          Hashable key for a node.

        Returns
        ------------
        successors : set
          Nodes that succeed specified node.
        """
        # get mapping of {parent : [child, ...]}
        children = self.children
        # if node has no children return early
        if node not in children:
            return {node}

        # children we need to collect
        queue = [node]
        # start collecting values with children of source
        collected = set(queue)

        # cap maximum iterations
        for _ in range(len(self.node_data) + 1):
            if len(queue) == 0:
                # no more nodes to visit so we're done
                return collected
            # add the children of this node to be processed
            childs = children.get(queue.pop())
            if childs is not None:
                queue.extend(childs)
                collected.update(childs)
        return collected

    def __hash__(self):
        """
        Actually hash all of the data, but use a "dirty" mechanism
        in functions that modify the data, which MUST
        all invalidate the hash by setting `self._hash = None`

        This was optimized a bit, and is evaluating on an
        older laptop on a scene with 77 nodes and 76 edges
        10,000 times in 0.7s which seems fast enough.
        """
        # see if there is an available hash value
        # if you are seeing cache bugs this is the thing
        # to try eliminating because it is very likely that
        # someone somewhere is modifying the data without
        # setting `self._hash = None`
        # (`getattr` tolerates instances created without `_hash`)
        hashed = getattr(self, "_hash", None)
        if hashed is not None:
            return hashed

        # geometry is any hashable, so convert it with `str`
        # rather than assuming it can be concatenated directly
        text = "".join(
            str(hash(k)) + str(v.get("geometry", "")) for k, v in self.edge_data.items()
        ) + "".join(
            str(k) + str(v.get("geometry", "")) for k, v in self.node_data.items()
        )
        matrices = b"".join(
            v["matrix"].tobytes() for v in self.edge_data.values() if "matrix" in v
        )
        hashed = hash_fast(text.encode("utf-8") + matrices)
        self._hash = hashed
        return hashed
def kwargs_to_matrix(
    matrix=None, quaternion=None, translation=None, axis=None, angle=None, **kwargs
):
    """
    Collapse the supported transform keyword arguments into a
    single homogeneous transformation matrix.

    Precedence is an explicit `matrix` (returned as a float64
    copy, ignoring every other argument), then `quaternion`,
    then `axis` together with `angle`, and finally identity.
    A `translation` is added onto the translation column of
    any of the non-`matrix` results. Other keyword arguments
    (such as `geometry` or `metadata`) are accepted and ignored.

    Returns
    ---------
    matrix : (4, 4) float
      Homogeneous transformation matrix.
    """
    # an explicit matrix short-circuits everything else,
    # including any translation that was also passed
    if matrix is not None:
        return np.array(matrix, dtype=np.float64)

    if quaternion is not None:
        result = quaternion_matrix(quaternion)
    elif axis is None or angle is None:
        # a rotation needs both an axis and an angle
        result = np.eye(4)
    else:
        result = rotation_matrix(angle, axis)

    if translation is None:
        return result

    # translation composes with any of the rotation forms above
    result[:3, 3] += translation
    return result