Coverage for trimesh/voxel/encoding.py: 78%
612 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-24 04:40 +0000
"""OO interfaces to encodings for ND arrays with caching."""

import abc

import numpy as np

from .. import caching
from ..util import ABC
from . import runlength

try:
    from scipy import sparse as sp
except BaseException as E:
    from ..exceptions import ExceptionWrapper

    # defer the import failure until `sp` is actually used
    sp = ExceptionWrapper(E)
def _empty_stripped(shape):
    """
    Build the `stripped` result for an encoding with no nonzero entries.

    Parameters
    ------------
    shape : iterable of int
      Shape of the empty encoding being stripped.

    Returns
    ------------
    encoding : DenseEncoding
      Boolean dense encoding with length zero along every axis.
    padding : (n, 2) int
      Zero padding at the start of each axis and `shape` at the end.
    """
    ndim = len(shape)
    padding = np.zeros(shape=(ndim, 2), dtype=int)
    # the whole extent of each axis counts as trailing padding
    padding[:, 1] = shape
    empty = np.zeros(shape=(0,) * ndim, dtype=bool)
    return DenseEncoding(empty), padding
class Encoding(ABC):
    """
    Base class for objects that implement a specific subset of ndarray ops.

    This presents a unified interface for various different ways of encoding
    conceptually dense arrays and to interoperate between them.

    Example implementations are ND sparse arrays, run length encoded arrays
    and dense encodings (wrappers around np.ndarrays).
    """

    def __init__(self, data):
        self._data = data
        # cached values are keyed on the hash of the wrapped data
        self._cache = caching.Cache(id_function=self._data.__hash__)

    @property
    @abc.abstractmethod
    def dtype(self):
        pass

    @property
    @abc.abstractmethod
    def shape(self):
        pass

    @property
    @abc.abstractmethod
    def sum(self):
        pass

    @property
    @abc.abstractmethod
    def size(self):
        pass

    @property
    @abc.abstractmethod
    def sparse_indices(self):
        pass

    @property
    @abc.abstractmethod
    def sparse_values(self):
        pass

    @property
    @abc.abstractmethod
    def dense(self):
        pass

    @abc.abstractmethod
    def gather_nd(self, indices):
        pass

    @abc.abstractmethod
    def mask(self, mask):
        pass

    @abc.abstractmethod
    def get_value(self, index):
        pass

    @abc.abstractmethod
    def copy(self):
        pass

    @property
    def is_empty(self):
        """True if none of the stored sparse values are nonzero."""
        return self.sparse_indices[self.sparse_values != 0].size == 0

    @caching.cache_decorator
    def stripped(self):
        """
        Get encoding with all zeros stripped from the start and end
        of each axis.

        Returns
        ------------
        encoding : DenseEncoding
          Dense encoding of the bounding box of the nonzero values.
        padding : (n, 2) int
          Padding at the start and end of each axis that was stripped
        """
        if self.is_empty:
            return _empty_stripped(self.shape)
        dense = self.dense
        shape = dense.shape
        ndims = len(shape)
        padding = []
        slices = []
        for dim, size in enumerate(shape):
            # reduce over every axis except `dim`
            axis = tuple(range(dim)) + tuple(range(dim + 1, ndims))
            filled = np.any(dense, axis=axis)
            (indices,) = np.nonzero(filled)
            lower = indices.min()
            upper = indices.max() + 1
            padding.append([lower, size - upper])
            slices.append(slice(lower, upper))
        return DenseEncoding(dense[tuple(slices)]), np.array(padding, int)

    def _flip(self, axes):
        """Lazily flip along `axes`; subclasses may return eager results."""
        return FlippedEncoding(self, axes)

    def __hash__(self):
        """
        Get the hash of the underlying data.

        Returns
        ------------
        hash : int
          Hash of the wrapped data object.
        """
        return self._data.__hash__()

    @property
    def ndims(self):
        """Number of dimensions, i.e. `len(self.shape)`."""
        return len(self.shape)

    def reshape(self, shape):
        """Reshape this encoding; a length-1 `shape` returns `self.flat`."""
        return self.flat if len(shape) == 1 else ShapedEncoding(self, shape)

    @property
    def flat(self):
        """Flattened (1D) view of this encoding."""
        return FlattenedEncoding(self)

    def flip(self, axis=0):
        """Flip along one axis or an iterable of axes."""
        return _flipped(self, axis)

    @property
    def sparse_components(self):
        """Tuple of `(sparse_indices, sparse_values)`."""
        return self.sparse_indices, self.sparse_values

    @property
    def data(self):
        return self._data

    def run_length_data(self, dtype=np.int64):
        """
        Run length encoded data of a flat encoding.

        Raises
        ------------
        ValueError
          If this encoding is not 1D.
        """
        if self.ndims != 1:
            raise ValueError("`run_length_data` only valid for flat encodings")
        return runlength.dense_to_rle(self.dense, dtype=dtype)

    def binary_run_length_data(self, dtype=np.int64):
        """
        Binary run length encoded data of a flat encoding.

        Raises
        ------------
        ValueError
          If this encoding is not 1D.
        """
        if self.ndims != 1:
            raise ValueError("`binary_run_length_data` only valid for flat encodings")
        return runlength.dense_to_brle(self.dense, dtype=dtype)

    def transpose(self, perm):
        """Transpose with `perm`; identity permutations return `self`."""
        return _transposed(self, perm)

    def _transpose(self, perm):
        return TransposedEncoding(self, perm)

    @property
    def mutable(self):
        return self._data.mutable

    @mutable.setter
    def mutable(self, value):
        self._data.mutable = value
class DenseEncoding(Encoding):
    """Simple `Encoding` implementation based on a numpy ndarray."""

    def __init__(self, data):
        """
        Parameters
        ------------
        data : np.ndarray
          Dense array to wrap; converted to a tracked array if needed.

        Raises
        ------------
        ValueError
          If `data` is not a numpy array.
        """
        if not isinstance(data, caching.TrackedArray):
            if not isinstance(data, np.ndarray):
                raise ValueError("DenseEncoding data must be a numpy array")
            data = caching.tracked_array(data)
        super().__init__(data=data)

    @property
    def dtype(self):
        return self._data.dtype

    @property
    def shape(self):
        return self._data.shape

    @caching.cache_decorator
    def sum(self):
        return self._data.sum()

    @caching.cache_decorator
    def is_empty(self):
        return not np.any(self._data)

    @property
    def size(self):
        return self._data.size

    @property
    def sparse_components(self):
        """Tuple of (m, ndims) nonzero indices and their (m,) values."""
        indices = self.sparse_indices
        # `gather` with an (m, ndims) int array would only index the first
        # axis; `gather_nd` looks up exactly one value per index row
        values = self.gather_nd(indices)
        return indices, values

    @caching.cache_decorator
    def sparse_indices(self):
        return np.column_stack(np.where(self._data))

    @caching.cache_decorator
    def sparse_values(self):
        return self.sparse_components[1]

    def _flip(self, axes):
        # eagerly flip the dense data rather than wrapping lazily
        dense = self.dense
        for a in axes:
            dense = np.flip(dense, a)
        return DenseEncoding(dense)

    @property
    def dense(self):
        return self._data

    def gather(self, indices):
        return self._data[indices]

    def gather_nd(self, indices):
        return self._data[tuple(indices.T)]

    def mask(self, mask):
        return self._data[mask if isinstance(mask, np.ndarray) else mask.dense]

    def get_value(self, index):
        return self._data[tuple(index)]

    def reshape(self, shape):
        return DenseEncoding(self._data.reshape(shape))

    def _transpose(self, perm):
        return DenseEncoding(self._data.transpose(perm))

    @property
    def flat(self):
        return DenseEncoding(self._data.reshape((-1,)))

    def copy(self):
        return DenseEncoding(self._data.copy())
class SparseEncoding(Encoding):
    """
    `Encoding` implementation based on an ND sparse implementation.

    Since the scipy.sparse implementations are for 2D arrays only, this
    implementation uses a single-column CSC matrix with index
    raveling/unraveling.
    """

    def __init__(self, indices, values, shape=None):
        """
        Parameters
        ------------
        indices : (m, n) int
          Indices of the stored entries in a conceptual rank-n array.
        values : (m,)
          Values at each of the specified indices.
        shape : (n,) iterable of int or None
          Shape of the conceptual array. If None, the maximum value of
          `indices` along each axis + 1 is used.

        Raises
        ------------
        ValueError
          If the shapes of `indices` / `values` are inconsistent or any
          index lies outside of `shape`.
        """
        data = caching.DataStore()
        super().__init__(data)
        data["indices"] = indices
        data["values"] = values
        indices = data["indices"]
        values = data["values"]
        if len(indices.shape) != 2:
            raise ValueError(f"indices must be 2D, got shaped {indices.shape!s}")
        if values.shape != (indices.shape[0],):
            raise ValueError(
                "values and indices shapes inconsistent: {} and {}".format(
                    values.shape, indices.shape
                )
            )
        if shape is None:
            self._shape = tuple(indices.max(axis=0) + 1)
        else:
            self._shape = tuple(shape)
        if not np.all(indices < self._shape):
            raise ValueError("all indices must be less than shape")
        if not np.all(indices >= 0):
            raise ValueError("all indices must be non-negative")

    @staticmethod
    def from_dense(dense_data):
        sparse_indices = np.where(dense_data)
        values = dense_data[sparse_indices]
        return SparseEncoding(
            np.stack(sparse_indices, axis=-1), values, shape=dense_data.shape
        )

    def copy(self):
        return SparseEncoding(
            indices=self.sparse_indices.copy(),
            values=self.sparse_values.copy(),
            shape=self.shape,
        )

    @property
    def sparse_indices(self):
        return self._data["indices"]

    @property
    def sparse_values(self):
        return self._data["values"]

    @property
    def dtype(self):
        return self.sparse_values.dtype

    @caching.cache_decorator
    def sum(self):
        return self.sparse_values.sum()

    @property
    def ndims(self):
        return self.sparse_indices.shape[-1]

    @property
    def shape(self):
        return self._shape

    @property
    def size(self):
        return np.prod(self.shape)

    @property
    def sparse_components(self):
        return self.sparse_indices, self.sparse_values

    @caching.cache_decorator
    def dense(self):
        sparse = self._csc
        # sparse.todense gives an `np.matrix` which cannot be reshaped
        dense = np.zeros(shape=sparse.shape, dtype=sparse.dtype)
        sparse.todense(out=dense)
        return np.reshape(dense, self.shape)

    @caching.cache_decorator
    def _csc(self):
        """Single-column (size, 1) CSC matrix keyed by flat index."""
        values = self.sparse_values
        indices = self._flat_indices(self.sparse_indices)
        indptr = [0, len(indices)]
        return sp.csc_matrix((values, indices, indptr), shape=(self.size, 1))

    def _flat_indices(self, indices):
        """
        Ravel (m, ndims) indices into flat (m,) indices into `self._csc`.

        Raises
        ------------
        ValueError
          If `indices` is not 2D with one column per dimension.
        """
        indices = np.asanyarray(indices)
        ndims = len(self.shape)
        if len(indices.shape) != 2 or indices.shape[1] != ndims:
            raise ValueError(
                f"indices must have shape (m, {ndims}), got {indices.shape!s}"
            )
        return np.ravel_multi_index(indices.T, self.shape)

    def _shaped_indices(self, flat_indices):
        return np.column_stack(np.unravel_index(flat_indices, self.shape))

    def gather_nd(self, indices):
        mat = self._csc[self._flat_indices(indices)].todense()
        # mat is a np matrix, which stays rank 2 after squeeze
        # np.asarray changes this to a standard rank 2 array.
        return np.asarray(mat).squeeze(axis=-1)

    def mask(self, mask):
        """
        Get the values at every position where `mask` is True.

        Parameters
        ------------
        mask : ndarray or Encoding
          Boolean mask with the same shape as this encoding.

        Returns
        ------------
        values : (k,)
          Values at the True positions of `mask` in C order, the same as
          `self.dense[mask]`.
        """
        if isinstance(mask, Encoding):
            mask = mask.dense
        # flat positions of the True entries index rows of the column matrix
        rows = np.flatnonzero(np.asanyarray(mask))
        return self._csc[rows].toarray()[:, 0]

    def get_value(self, index):
        return self.gather_nd(np.expand_dims(index, axis=0))[0]

    @caching.cache_decorator
    def stripped(self):
        """
        Get encoding with all zeros stripped from the start/end of each axis.

        Returns
        ------------
        encoding : SparseEncoding
          Same values with indices shifted down by `padding[:, 0]`.
        padding : (n, 2) int
          Number of entries stripped from the start/end of each axis.
        """
        if self.is_empty:
            return _empty_stripped(self.shape)
        indices = self.sparse_indices
        pad_left = np.min(indices, axis=0)
        # one past the last occupied index along each axis
        upper = np.max(indices, axis=0) + 1
        pad_right = np.asarray(self.shape) - upper
        padding = np.column_stack((pad_left, pad_right))
        encoding = SparseEncoding(
            indices - pad_left, self.sparse_values, shape=upper - pad_left
        )
        return encoding, padding
def SparseBinaryEncoding(indices, shape=None):
    """
    Factory for a boolean `SparseEncoding` whose stored values are all True.

    Parameters
    ------------
    indices : (m, n) int
      Sparse indices into a conceptual rank-n array.
    shape : length n iterable of int, or None
      Shape of the conceptual array. If None, the maximum of `indices`
      along the first axis + 1 is used.

    Returns
    ------------
    encoding : SparseEncoding
      Rank n bool encoding with True at each index.
    """
    count = indices.shape[0]
    ones = np.ones(shape=(count,), dtype=bool)
    return SparseEncoding(indices, ones, shape)
class RunLengthEncoding(Encoding):
    """1D run length encoding.

    See `trimesh.voxel.runlength` documentation for implementation details.
    """

    def __init__(self, data, dtype=None):
        """
        Parameters
        ------------
        data : 1D array
          Run length encoded data.
        dtype : dtype or None
          dtype of the decoded values, used by `dense`, `gather` and the
          other decoding methods. If None, the dtype of `data` is used.

        Raises
        ------------
        ValueError
          If `data` is not 1D.
        """
        super().__init__(data=caching.tracked_array(data))
        if dtype is None:
            dtype = self._data.dtype
        if len(self._data.shape) != 1:
            raise ValueError("data must be 1D numpy array")
        self._dtype = dtype

    @caching.cache_decorator
    def is_empty(self):
        # empty when no (even, odd) entry pair is nonzero in both positions
        return not np.any(np.logical_and(self._data[::2], self._data[1::2]))

    @property
    def ndims(self):
        return 1

    @property
    def shape(self):
        return (self.size,)

    @property
    def dtype(self):
        return self._dtype

    def __hash__(self):
        """
        Get the hash of the run length encoded data.

        Returns
        ------------
        hash : int
          Hash of the underlying tracked array.
        """
        return self._data.__hash__()

    @staticmethod
    def from_dense(dense_data, dtype=np.int64, encoding_dtype=np.int64):
        return RunLengthEncoding(
            runlength.dense_to_rle(dense_data, dtype=encoding_dtype), dtype=dtype
        )

    @staticmethod
    def from_rle(rle_data, dtype=None):
        if dtype != rle_data.dtype:
            rle_data = runlength.rle_to_rle(rle_data, dtype=dtype)
        return RunLengthEncoding(rle_data)

    @staticmethod
    def from_brle(brle_data, dtype=None):
        return RunLengthEncoding(runlength.brle_to_rle(brle_data, dtype=dtype))

    @caching.cache_decorator
    def stripped(self):
        if self.is_empty:
            return _empty_stripped(self.shape)
        data, padding = runlength.rle_strip(self._data)
        if padding == (0, 0):
            encoding = self
        else:
            encoding = RunLengthEncoding(data, dtype=self._dtype)
        padding = np.expand_dims(padding, axis=0)
        return encoding, padding

    @caching.cache_decorator
    def sum(self):
        return (self._data[::2] * self._data[1::2]).sum()

    @caching.cache_decorator
    def size(self):
        return runlength.rle_length(self._data)

    def _flip(self, axes):
        if axes != (0,):
            raise ValueError(f"encoding is 1D - cannot flip on axis {axes!s}")
        # keep the decoded dtype; without it the flipped encoding would fall
        # back to the dtype of the raw run length data
        return RunLengthEncoding(runlength.rle_reverse(self._data), dtype=self._dtype)

    @caching.cache_decorator
    def sparse_components(self):
        return runlength.rle_to_sparse(self._data)

    @caching.cache_decorator
    def sparse_indices(self):
        return self.sparse_components[0]

    @caching.cache_decorator
    def sparse_values(self):
        return self.sparse_components[1]

    @caching.cache_decorator
    def dense(self):
        return runlength.rle_to_dense(self._data, dtype=self._dtype)

    def gather(self, indices):
        return runlength.rle_gather_1d(self._data, indices, dtype=self._dtype)

    def gather_nd(self, indices):
        indices = np.squeeze(indices, axis=-1)
        return self.gather(indices)

    def sorted_gather(self, ordered_indices):
        return np.array(
            tuple(runlength.sorted_rle_gather_1d(self._data, ordered_indices)),
            dtype=self._dtype,
        )

    def mask(self, mask):
        return np.array(tuple(runlength.rle_mask(self._data, mask)), dtype=self._dtype)

    def get_value(self, index):
        for value in self.sorted_gather((index,)):
            return np.asanyarray(value, dtype=self._dtype)

    def copy(self):
        return RunLengthEncoding(self._data.copy(), dtype=self.dtype)

    def run_length_data(self, dtype=np.int64):
        return runlength.rle_to_rle(self._data, dtype=dtype)

    def binary_run_length_data(self, dtype=np.int64):
        return runlength.rle_to_brle(self._data, dtype=dtype)
class BinaryRunLengthEncoding(RunLengthEncoding):
    """1D binary run length encoding.

    See `trimesh.voxel.runlength` documentation for implementation details.
    """

    def __init__(self, data):
        """
        Parameters
        ------------
        data : 1D array
          Binary run length encoded data; decoded values are bool.
        """
        super().__init__(data=data, dtype=bool)

    @caching.cache_decorator
    def is_empty(self):
        return not np.any(self._data[1::2])

    @staticmethod
    def from_dense(dense_data, encoding_dtype=np.int64):
        return BinaryRunLengthEncoding(
            runlength.dense_to_brle(dense_data, dtype=encoding_dtype)
        )

    @staticmethod
    def from_rle(rle_data, dtype=None):
        return BinaryRunLengthEncoding(runlength.rle_to_brle(rle_data, dtype=dtype))

    @staticmethod
    def from_brle(brle_data, dtype=None):
        if dtype != brle_data.dtype:
            brle_data = runlength.brle_to_brle(brle_data, dtype=dtype)
        return BinaryRunLengthEncoding(brle_data)

    @caching.cache_decorator
    def stripped(self):
        if self.is_empty:
            return _empty_stripped(self.shape)
        data, padding = runlength.rle_strip(self._data)
        if padding == (0, 0):
            encoding = self
        else:
            encoding = BinaryRunLengthEncoding(data)
        padding = np.expand_dims(padding, axis=0)
        return encoding, padding

    @caching.cache_decorator
    def sum(self):
        return self._data[1::2].sum()

    @caching.cache_decorator
    def size(self):
        return runlength.brle_length(self._data)

    def _flip(self, axes):
        if axes != (0,):
            raise ValueError(f"encoding is 1D - cannot flip on axis {axes!s}")
        return BinaryRunLengthEncoding(runlength.brle_reverse(self._data))

    @property
    def sparse_components(self):
        return self.sparse_indices, self.sparse_values

    @caching.cache_decorator
    def sparse_values(self):
        return np.ones(shape=(self.sum,), dtype=bool)

    @caching.cache_decorator
    def sparse_indices(self):
        return runlength.brle_to_sparse(self._data)

    @caching.cache_decorator
    def dense(self):
        return runlength.brle_to_dense(self._data)

    def gather(self, indices):
        return runlength.brle_gather_1d(self._data, indices)

    def gather_nd(self, indices):
        indices = np.asanyarray(indices)
        # drop only the trailing (m, 1) index axis: a bare squeeze would turn
        # a single (1, 1) index into a 0-d scalar instead of a (1,) array
        if indices.ndim > 1:
            indices = np.squeeze(indices, axis=-1)
        return self.gather(indices)

    def sorted_gather(self, ordered_indices):
        gen = runlength.sorted_brle_gather_1d(self._data, ordered_indices)
        return np.array(tuple(gen), dtype=bool)

    def mask(self, mask):
        gen = runlength.brle_mask(self._data, mask)
        return np.array(tuple(gen), dtype=bool)

    def copy(self):
        return BinaryRunLengthEncoding(self._data.copy())

    def run_length_data(self, dtype=np.int64):
        return runlength.brle_to_rle(self._data, dtype=dtype)

    def binary_run_length_data(self, dtype=np.int64):
        return runlength.brle_to_brle(self._data, dtype=dtype)
class LazyIndexMap(Encoding):
    """
    Abstract class for implementing lazy index mapping operations.

    Implementations include transpose, flatten/reshaping and flipping

    Derived classes must implement:
        * _to_base_indices(indices)
        * _from_base_indices(base_indices)
        * shape
        * dense
        * mask(mask)
        * copy()
    """

    @abc.abstractmethod
    def _to_base_indices(self, indices):
        """Map a batch of indices of this encoding to base encoding indices."""

    @abc.abstractmethod
    def _from_base_indices(self, base_indices):
        """Map a batch of base encoding indices to indices of this encoding."""

    @property
    def is_empty(self):
        return self._data.is_empty

    @property
    def dtype(self):
        return self._data.dtype

    @property
    def sum(self):
        return self._data.sum

    @property
    def size(self):
        return self._data.size

    @property
    def sparse_indices(self):
        return self._from_base_indices(self._data.sparse_indices)

    @property
    def sparse_values(self):
        return self._data.sparse_values

    def gather_nd(self, indices):
        return self._data.gather_nd(self._to_base_indices(indices))

    def get_value(self, index):
        """
        Get the value at a single index.

        Parameters
        ------------
        index : (ndims,) int
          Index into this encoding.

        Returns
        ------------
        value
          Value of the base encoding at the mapped index.
        """
        # the index maps work on batches, so map a batch of one and unwrap it
        batch = np.expand_dims(np.asanyarray(index), axis=0)
        base_index = self._to_base_indices(batch)[0]
        # `Encoding` objects do not support `[]` indexing
        return self._data.get_value(base_index)
class FlattenedEncoding(LazyIndexMap):
    """
    Lazily flattened view of another encoding.

    The dense equivalent is `np.reshape(data, (-1,))`; `np.flatten` would
    make a copy.
    """

    @property
    def shape(self):
        return (self.size,)

    @property
    def flat(self):
        # already flat: flattening again is a no-op
        return self

    @property
    def dense(self):
        base_dense = self._data.dense
        return base_dense.reshape((-1,))

    def _to_base_indices(self, indices):
        base_shape = self._data.shape
        unraveled = np.unravel_index(indices, base_shape)
        return np.column_stack(unraveled)

    def _from_base_indices(self, base_indices):
        base_shape = self._data.shape
        flat_indices = np.ravel_multi_index(base_indices.T, base_shape)
        return np.expand_dims(flat_indices, axis=-1)

    def mask(self, mask):
        # bring the flat mask back to the base layout before masking
        reshaped = mask.reshape(self._data.shape)
        return self._data.mask(reshaped)

    def copy(self):
        return FlattenedEncoding(self._data.copy())
class ShapedEncoding(LazyIndexMap):
    """
    Lazily reshaped encoding.

    Numpy equivalent is `np.reshape`
    """

    def __init__(self, encoding, shape):
        """
        Parameters
        ------------
        encoding : Encoding
          Encoding to reshape; it is flattened first if it is not 1D.
        shape : iterable of int
          New shape. At most one entry may be -1, in which case it is
          inferred from the size of `encoding`.

        Raises
        ------------
        ValueError
          If `encoding` is not an `Encoding`, `shape` contains more than
          one -1, or `shape` is incompatible with the size of `encoding`.
        """
        if not isinstance(encoding, Encoding):
            raise ValueError("encoding must be an Encoding")
        if encoding.ndims != 1:
            encoding = encoding.flat
        super().__init__(data=encoding)
        self._shape = tuple(shape)
        size = self._data.size
        unknown = self._shape.count(-1)
        if unknown > 1:
            raise ValueError("shape cannot have more than one -1 value")
        if unknown == 1:
            # product of the known dimensions (the single -1 makes it negative)
            known = np.abs(np.prod(self._shape))
            if known == 0 or size % known != 0:
                raise ValueError(
                    f"cannot reshape encoding of size {size} into shape {self._shape!s}"
                )
            remainder = size // known
            self._shape = tuple(remainder if s == -1 else s for s in self._shape)
        elif np.prod(self._shape) != size:
            raise ValueError(
                f"cannot reshape encoding of size {size} into shape {self._shape!s}"
            )

    def _from_base_indices(self, base_indices):
        # base indices index the flat encoding; unravel them into this shape
        return np.column_stack(np.unravel_index(base_indices, self.shape))

    def _to_base_indices(self, indices):
        return np.expand_dims(np.ravel_multi_index(indices.T, self.shape), axis=-1)

    @property
    def flat(self):
        return self._data

    @property
    def shape(self):
        return self._shape

    @property
    def dense(self):
        return self._data.dense.reshape(self.shape)

    def mask(self, mask):
        # the base encoding is flat, so flatten the mask to match it;
        # `ndarray.flat` is an iterator rather than an array, so plain
        # arrays are reshaped instead
        if isinstance(mask, np.ndarray):
            return self._data.mask(mask.reshape((-1,)))
        return self._data.mask(mask.flat)

    def copy(self):
        return ShapedEncoding(encoding=self._data.copy(), shape=self.shape)
class TransposedEncoding(LazyIndexMap):
    """
    Lazily transposed encoding

    Dense equivalent is `np.transpose`: axis `k` of this encoding is axis
    `perm[k]` of the base encoding.
    """

    def __init__(self, base_encoding, perm):
        """
        Parameters
        ------------
        base_encoding : Encoding
          Encoding to transpose.
        perm : iterable of int
          Permutation of the base encoding's axes.

        Raises
        ------------
        ValueError
          If `base_encoding` is not an `Encoding` or `perm` is not a valid
          permutation of its axes.
        """
        if not isinstance(base_encoding, Encoding):
            raise ValueError(f"base_encoding must be an Encoding, got {base_encoding!s}")
        if len(base_encoding.shape) != len(perm):
            raise ValueError(
                f"base_encoding has {base_encoding.ndims} ndims - "
                f"cannot transpose with perm {perm!s}"
            )

        super().__init__(base_encoding)
        perm = np.array(perm, dtype=np.int64)
        if not all(i in perm for i in range(base_encoding.ndims)):
            raise ValueError(f"perm {perm!s} is not a valid permutation")
        inv_perm = np.zeros_like(perm)
        inv_perm[perm] = np.arange(base_encoding.ndims)
        self._perm = perm
        self._inv_perm = inv_perm

    def transpose(self, perm):
        # compose with the existing permutation and transpose the base once
        return _transposed(self._data, [self._perm[p] for p in perm])

    def _transpose(self, perm):
        raise RuntimeError("Should not be here")

    @property
    def perm(self):
        return self._perm

    @property
    def shape(self):
        shape = self._data.shape
        return tuple(shape[p] for p in self._perm)

    @staticmethod
    def _take_last_axis(indices, order):
        """Reorder the last axis of `indices` according to `order`."""
        try:
            return np.take(indices, order, axis=-1)
        except TypeError:
            # windows sometimes tries to use wrong dtypes
            return np.take(
                np.asanyarray(indices).astype(np.int64),
                order.astype(np.int64),
                axis=-1,
            )

    def _to_base_indices(self, indices):
        # axis k here is base axis perm[k], so a base index j satisfies
        # j[perm[k]] = i[k], i.e. j = i[inv_perm]
        return self._take_last_axis(indices, self._inv_perm)

    def _from_base_indices(self, base_indices):
        # inverse of `_to_base_indices`: i[k] = j[perm[k]]
        return self._take_last_axis(base_indices, self._perm)

    @property
    def dense(self):
        return self._data.dense.transpose(self._perm)

    def gather(self, indices):
        return self._data.gather(self._to_base_indices(indices))

    def mask(self, mask):
        """
        Get the values at every position where `mask` is True.

        Returns values in C order of this (transposed) encoding, the same
        as `self.dense[mask]`.
        """
        if isinstance(mask, Encoding):
            mask = mask.dense
        indices = np.column_stack(np.nonzero(mask))
        return self.gather_nd(indices)

    def get_value(self, index):
        base_index = self._to_base_indices(np.asanyarray(index))
        # `Encoding` objects do not support `[]` indexing
        return self._data.get_value(base_index)

    @property
    def data(self):
        return self._data

    def copy(self):
        return TransposedEncoding(base_encoding=self._data.copy(), perm=self._perm)
class FlippedEncoding(LazyIndexMap):
    """
    Encoding with entries flipped along one or more axes.

    Dense equivalent is `np.flip`
    """

    def __init__(self, encoding, axes):
        """
        Parameters
        ------------
        encoding : Encoding
          Encoding to flip.
        axes : int or iterable of int
          Axes to flip; negative values count from the last axis.

        Raises
        ------------
        ValueError
          If `axes` contains duplicates or out-of-range axes.
        """
        ndims = encoding.ndims
        if isinstance(axes, np.ndarray) and axes.size == 1:
            axes = (axes.item(),)
        elif isinstance(axes, (int, np.integer)):
            axes = (axes,)
        axes = tuple(a + ndims if a < 0 else a for a in axes)
        self._axes = tuple(sorted(axes))
        if len(set(self._axes)) != len(self._axes):
            raise ValueError(f"Axes cannot contain duplicates, got {self._axes!s}")
        super().__init__(encoding)
        if not all(0 <= a < self._data.ndims for a in axes):
            raise ValueError(
                f"Invalid axes {axes!s} for {self._data.ndims}-d encoding"
            )

    def _to_base_indices(self, indices):
        # np.array copies, so the caller's indices are never modified
        indices = np.array(indices)
        shape = self.shape
        for a in self._axes:
            # index i along a flipped axis of length s maps to s - 1 - i
            indices[:, a] = shape[a] - 1 - indices[:, a]
        return indices

    def _from_base_indices(self, base_indices):
        # flipping is its own inverse
        return self._to_base_indices(base_indices)

    @property
    def shape(self):
        return self._data.shape

    @property
    def dense(self):
        dense = self._data.dense
        for a in self._axes:
            dense = np.flip(dense, a)
        return dense

    def mask(self, mask):
        """
        Get the values at every position where `mask` is True.

        Returns values in C order of this (flipped) encoding, the same as
        `self.dense[mask]`.
        """
        if isinstance(mask, Encoding):
            mask = mask.dense
        indices = np.column_stack(np.nonzero(mask))
        return self.gather_nd(indices)

    def copy(self):
        return FlippedEncoding(self._data.copy(), self._axes)

    def flip(self, axis=0):
        if isinstance(axis, np.ndarray):
            axes = (axis.item(),) if axis.size == 1 else tuple(axis)
        elif isinstance(axis, (int, np.integer)):
            axes = (axis,)
        else:
            axes = tuple(axis)
        # flip the base by the combined axes; `_flipped` cancels any axis
        # that ends up flipped twice and returns the base if nothing remains
        return _flipped(self._data, self._axes + axes)

    def _flip(self, axes):
        raise RuntimeError("Should not be here")
def _flipped(encoding, axes):
    """
    Flip `encoding` along `axes`, cancelling axes given an even number of times.

    Parameters
    ------------
    encoding : Encoding
      Encoding to flip.
    axes : int or iterable of int
      Axes to flip; negative values count from the last axis.

    Returns
    ------------
    flipped : Encoding
      `encoding` itself when every axis cancels out, otherwise the result
      of `encoding._flip` with the sorted remaining axes.
    """
    if not hasattr(axes, "__iter__"):
        axes = (axes,)
    ndims = encoding.ndims
    odd = set()
    for axis in axes:
        if axis < 0:
            axis += ndims
        # toggle membership: an axis flipped twice is not flipped at all
        odd ^= {axis}
    if not odd:
        return encoding
    return encoding._flip(tuple(sorted(odd)))
def _transposed(encoding, perm):
    """
    Transpose `encoding` by `perm`, skipping identity permutations.

    Parameters
    ------------
    encoding : Encoding
      Encoding to transpose.
    perm : iterable of int
      Axis permutation; negative values count from the last axis.

    Returns
    ------------
    transposed : Encoding
      `encoding` itself for the identity permutation, otherwise the result
      of `encoding._transpose` with the normalized permutation.
    """
    ndims = encoding.ndims
    normalized = tuple(ndims + p if p < 0 else p for p in perm)
    is_identity = np.all(np.arange(ndims) == normalized)
    if is_identity:
        return encoding
    return encoding._transpose(normalized)