Coverage for trimesh/voxel/encoding.py: 78%

612 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-24 04:40 +0000

1"""OO interfaces to encodings for ND arrays with caching."""

2 

3import abc 

4 

5import numpy as np 

6 

7from .. import caching 

8from ..util import ABC 

9from . import runlength 

10 

11try: 

12 from scipy import sparse as sp 

13except BaseException as E: 

14 from ..exceptions import ExceptionWrapper 

15 

16 sp = ExceptionWrapper(E) 

17 

18 

def _empty_stripped(shape):
    """
    Build the `stripped` result for an encoding with no non-zero values.

    Parameters
    ------------
    shape : sequence of int
      Shape of the encoding being stripped

    Returns
    ------------
    encoding : DenseEncoding
      Boolean encoding with length zero along every axis
    padding : (len(shape), 2) int
      Zeros in the first column and `shape` in the second column
    """
    rank = len(shape)
    # nothing is kept, so the whole extent of each axis is reported
    # in the second (trailing) column
    pad = np.zeros(shape=(rank, 2), dtype=int)
    pad[:, 1] = shape
    hollow = np.zeros(shape=(0,) * rank, dtype=bool)
    return DenseEncoding(hollow), pad

25 

26 

class Encoding(ABC):
    """
    Base class for objects that implement a specific subset of ndarray ops.

    This presents a unified interface for various different ways of encoding
    conceptually dense arrays and to interoperate between them.

    Example implementations are ND sparse arrays, run length encoded arrays
    and dense encodings (wrappers around np.ndarrays).
    """

    def __init__(self, data):
        """
        Parameters
        ------------
        data : object
          Backing store for this encoding. Its `__hash__` is handed to
          `caching.Cache` as the id function of the cache.
        """
        self._data = data
        self._cache = caching.Cache(id_function=self._data.__hash__)

    @property
    @abc.abstractmethod
    def dtype(self):
        """Data type of the encoded values."""
        pass

    @property
    @abc.abstractmethod
    def shape(self):
        """Shape of the conceptual dense array."""
        pass

    @property
    @abc.abstractmethod
    def sum(self):
        """Sum of the encoded values."""
        pass

    @property
    @abc.abstractmethod
    def size(self):
        """Number of elements in the conceptual dense array."""
        pass

    @property
    @abc.abstractmethod
    def sparse_indices(self):
        """Indices of the sparse representation of the data."""
        pass

    @property
    @abc.abstractmethod
    def sparse_values(self):
        """Values of the sparse representation of the data."""
        pass

    @property
    @abc.abstractmethod
    def dense(self):
        """Dense (array) representation of the data."""
        pass

    @abc.abstractmethod
    def gather_nd(self, indices):
        """Get the values at the given ND indices."""
        pass

    @abc.abstractmethod
    def mask(self, mask):
        """Get the entries selected by `mask`."""
        pass

    @abc.abstractmethod
    def get_value(self, index):
        """Get the single value at `index`."""
        pass

    @abc.abstractmethod
    def copy(self):
        """Get a copy of this encoding."""
        pass

    @property
    def is_empty(self):
        """True if the sparse representation has no non-zero value."""
        return self.sparse_indices[self.sparse_values != 0].size == 0

    @caching.cache_decorator
    def stripped(self):
        """
        Get encoding with all zeros stripped from the start and end
        of each axis.

        Returns
        ------------
        encoding : DenseEncoding
          Dense encoding cropped to the occupied region
        padding : (n, 2) int
          Padding at the start and end that was stripped
        """
        if self.is_empty:
            return _empty_stripped(self.shape)
        dense = self.dense
        shape = dense.shape
        ndims = len(shape)
        padding = []
        slices = []
        for dim, size in enumerate(shape):
            # reduce over every axis except `dim` to find which
            # positions along `dim` contain any non-zero value
            axis = tuple(range(dim)) + tuple(range(dim + 1, ndims))
            filled = np.any(dense, axis=axis)
            (indices,) = np.nonzero(filled)
            lower = indices.min()
            upper = indices.max() + 1
            padding.append([lower, size - upper])
            slices.append(slice(lower, upper))
        return DenseEncoding(dense[tuple(slices)]), np.array(padding, int)

    def _flip(self, axes):
        """Wrap this encoding in a lazy `FlippedEncoding` over `axes`."""
        return FlippedEncoding(self, axes)

    def __hash__(self):
        """
        Get the hash of the underlying data.

        Returns
        ------------
        hash
          Whatever `self._data.__hash__()` returns
        """
        return self._data.__hash__()

    @property
    def ndims(self):
        """Number of dimensions, i.e. `len(self.shape)`."""
        return len(self.shape)

    def reshape(self, shape):
        """
        Get a reshaped view of this encoding.

        A rank-1 `shape` returns `self.flat`; anything else returns
        a lazy `ShapedEncoding`.
        """
        return self.flat if len(shape) == 1 else ShapedEncoding(self, shape)

    @property
    def flat(self):
        """Lazily flattened view of this encoding."""
        return FlattenedEncoding(self)

    def flip(self, axis=0):
        """Get this encoding flipped along one or more axes."""
        return _flipped(self, axis)

    @property
    def sparse_components(self):
        """Tuple of `(sparse_indices, sparse_values)`."""
        return self.sparse_indices, self.sparse_values

    @property
    def data(self):
        """The backing data passed to the constructor."""
        return self._data

    def run_length_data(self, dtype=np.int64):
        """
        Get the run length encoding of this (flat) encoding.

        Raises
        ------------
        ValueError
          If the encoding is not one dimensional
        """
        if self.ndims != 1:
            raise ValueError("`run_length_data` only valid for flat encodings")
        return runlength.dense_to_rle(self.dense, dtype=dtype)

    def binary_run_length_data(self, dtype=np.int64):
        """
        Get the binary run length encoding of this (flat) encoding.

        Raises
        ------------
        ValueError
          If the encoding is not one dimensional
        """
        if self.ndims != 1:
            # name the method that was actually called
            raise ValueError(
                "`binary_run_length_data` only valid for flat encodings"
            )
        return runlength.dense_to_brle(self.dense, dtype=dtype)

    def transpose(self, perm):
        """Get this encoding with axes permuted by `perm`."""
        return _transposed(self, perm)

    def _transpose(self, perm):
        """Wrap this encoding in a lazy `TransposedEncoding`."""
        return TransposedEncoding(self, perm)

    @property
    def mutable(self):
        """Pass-through to the `mutable` flag of the backing data."""
        return self._data.mutable

    @mutable.setter
    def mutable(self, value):
        self._data.mutable = value

185 

186 

class DenseEncoding(Encoding):
    """Simple `Encoding` implementation based on a numpy ndarray."""

    def __init__(self, data):
        """
        Parameters
        ------------
        data : np.ndarray or caching.TrackedArray
          Dense values; plain arrays are wrapped with
          `caching.tracked_array`.

        Raises
        ------------
        ValueError
          If `data` is not a numpy array
        """
        if not isinstance(data, caching.TrackedArray):
            if not isinstance(data, np.ndarray):
                raise ValueError("DenseEncoding data must be a numpy array")
            data = caching.tracked_array(data)
        super().__init__(data=data)

    @property
    def dtype(self):
        """Data type of the wrapped array."""
        return self._data.dtype

    @property
    def shape(self):
        """Shape of the wrapped array."""
        return self._data.shape

    @caching.cache_decorator
    def sum(self):
        """Sum of every element of the wrapped array."""
        return self._data.sum()

    @caching.cache_decorator
    def is_empty(self):
        """True if no element of the wrapped array is truthy."""
        return not np.any(self._data)

    @property
    def size(self):
        """Number of elements in the wrapped array."""
        return self._data.size

    @property
    def sparse_components(self):
        """
        Tuple of `(indices, values)` for the non-zero entries.

        `indices` is `(m, n)`, one row per entry, so the values have
        to be looked up with `gather_nd`; `gather` would fancy-index
        the first axis only.
        """
        indices = self.sparse_indices
        values = self.gather_nd(indices)
        return indices, values

    @caching.cache_decorator
    def sparse_indices(self):
        """(m, n) indices of the truthy entries, one row per entry."""
        return np.column_stack(np.where(self._data))

    @caching.cache_decorator
    def sparse_values(self):
        """(m,) values at `sparse_indices`."""
        return self.sparse_components[1]

    def _flip(self, axes):
        """Eagerly flip the array along each axis in `axes`."""
        dense = self.dense
        for a in axes:
            dense = np.flip(dense, a)
        return DenseEncoding(dense)

    @property
    def dense(self):
        """The wrapped array itself (not a copy)."""
        return self._data

    def gather(self, indices):
        """Index the wrapped array directly with `indices`."""
        return self._data[indices]

    def gather_nd(self, indices):
        """
        Get the values at `(m, n)` indices, one ND index per row.
        """
        return self._data[tuple(indices.T)]

    def mask(self, mask):
        """
        Get the values selected by `mask`, which may be an array
        or an object exposing a `dense` attribute.
        """
        return self._data[mask if isinstance(mask, np.ndarray) else mask.dense]

    def get_value(self, index):
        """Get the single value at the ND `index`."""
        return self._data[tuple(index)]

    def reshape(self, shape):
        """Get a `DenseEncoding` of the array reshaped to `shape`."""
        return DenseEncoding(self._data.reshape(shape))

    def _transpose(self, perm):
        """Get a `DenseEncoding` of the array transposed by `perm`."""
        return DenseEncoding(self._data.transpose(perm))

    @property
    def flat(self):
        """`DenseEncoding` of the array reshaped to one dimension."""
        return DenseEncoding(self._data.reshape((-1,)))

    def copy(self):
        """Get a `DenseEncoding` wrapping a copy of the array."""
        return DenseEncoding(self._data.copy())

265 

266 

class SparseEncoding(Encoding):
    """
    `Encoding` implementation based on an ND sparse implementation.

    Since the scipy.sparse implementations are for 2D arrays only, this
    implementation uses a single-column CSC matrix with index
    raveling/unraveling.
    """

    def __init__(self, indices, values, shape=None):
        """
        Parameters
        ------------
        indices : (m, n) int
          One ND index per row
        values : (m,) dtype
          Value stored at each index
        shape : (n,) iterable of int or None
          Shape of the conceptual dense array. If None, the maximum
          value of `indices` along each column + 1 is used.

        Raises
        ------------
        ValueError
          If `indices` is not 2D, `values` does not have one entry per
          index, or any index is negative or outside `shape`.
        """
        data = caching.DataStore()
        super().__init__(data)
        data["indices"] = indices
        data["values"] = values
        # read back whatever the store holds rather than the raw argument
        indices = data["indices"]
        if len(indices.shape) != 2:
            raise ValueError(f"indices must be 2D, got shaped {indices.shape!s}")
        if data["values"].shape != (indices.shape[0],):
            # report the shapes, not the full array contents
            raise ValueError(
                "values and indices shapes inconsistent: {} and {}".format(
                    data["values"].shape, indices.shape
                )
            )
        if shape is None:
            self._shape = tuple(data["indices"].max(axis=0) + 1)
        else:
            self._shape = tuple(shape)
        if not np.all(indices < self._shape):
            raise ValueError("all indices must be less than shape")
        if not np.all(indices >= 0):
            raise ValueError("all indices must be non-negative")

    @staticmethod
    def from_dense(dense_data):
        """Build a `SparseEncoding` from the truthy entries of an array."""
        sparse_indices = np.where(dense_data)
        values = dense_data[sparse_indices]
        return SparseEncoding(
            np.stack(sparse_indices, axis=-1), values, shape=dense_data.shape
        )

    def copy(self):
        """Get a `SparseEncoding` built from copies of indices and values."""
        return SparseEncoding(
            indices=self.sparse_indices.copy(),
            values=self.sparse_values.copy(),
            shape=self.shape,
        )

    @property
    def sparse_indices(self):
        """(m, n) stored indices."""
        return self._data["indices"]

    @property
    def sparse_values(self):
        """(m,) stored values."""
        return self._data["values"]

    @property
    def dtype(self):
        """Data type of the stored values."""
        return self.sparse_values.dtype

    @caching.cache_decorator
    def sum(self):
        """Sum of the stored values."""
        return self.sparse_values.sum()

    @property
    def ndims(self):
        """Number of dimensions, taken from the index column count."""
        return self.sparse_indices.shape[-1]

    @property
    def shape(self):
        """Shape of the conceptual dense array."""
        return self._shape

    @property
    def size(self):
        """Product of `shape`."""
        return np.prod(self.shape)

    @property
    def sparse_components(self):
        """Tuple of `(sparse_indices, sparse_values)`."""
        return self.sparse_indices, self.sparse_values

    @caching.cache_decorator
    def dense(self):
        """Dense array of `shape` filled from the sparse entries."""
        sparse = self._csc
        # sparse.todense gives an `np.matrix` which cannot be reshaped
        dense = np.zeros(shape=sparse.shape, dtype=sparse.dtype)
        sparse.todense(out=dense)
        return np.reshape(dense, self.shape)

    @caching.cache_decorator
    def _csc(self):
        """Single-column `(size, 1)` CSC matrix keyed by raveled index."""
        values = self.sparse_values
        indices = self._flat_indices(self.sparse_indices)
        # a single column owns every stored entry
        indptr = [0, len(indices)]
        return sp.csc_matrix((values, indices, indptr), shape=(self.size, 1))

    def _flat_indices(self, indices):
        """
        Ravel `(m, n)` ND indices into `(m,)` flat indices.

        The column count must match the rank of this encoding; the
        rank is checked first so a 1D input fails the assertion
        instead of raising an IndexError.
        """
        assert len(indices.shape) == 2 and indices.shape[1] == len(self.shape)
        return np.ravel_multi_index(indices.T, self.shape)

    def _shaped_indices(self, flat_indices):
        """Unravel flat indices into `(m, n)` ND indices."""
        return np.column_stack(np.unravel_index(flat_indices, self.shape))

    def gather_nd(self, indices):
        """Get the `(m,)` values at `(m, n)` indices."""
        mat = self._csc[self._flat_indices(indices)].todense()
        # mat is a np matrix, which stays rank 2 after squeeze
        # np.asarray changes this to a standard rank 2 array.
        return np.asarray(mat).squeeze(axis=-1)

    def mask(self, mask):
        """
        Get ND indices derived from the non-zero rows of the CSC
        matrix after selecting rows with the flattened `mask`.

        NOTE(review): the row numbers come from the masked sub-matrix,
        and other encodings in this file return values here rather
        than indices - confirm the intended contract.
        """
        i, _ = np.where(self._csc[mask.reshape((-1,))])
        return self._shaped_indices(i)

    def get_value(self, index):
        """Get the single value at the ND `index`."""
        # gather_nd expects a batch, so add and then drop a leading axis
        return self.gather_nd(np.expand_dims(index, axis=0))[0]

    @caching.cache_decorator
    def stripped(self):
        """
        Get encoding with all zeros stripped from the start/end of each axis.

        Returns
        ------------
        encoding : SparseEncoding
          Same values but indices shifted down by padding[:, 0]
        padding : (n, 2) int
          Padding at the start/end that was stripped
        """
        if self.is_empty:
            return _empty_stripped(self.shape)
        indices = self.sparse_indices
        pad_left = np.min(indices, axis=0)
        # the last occupied cell is at `max`, so `max + 1` cells are
        # kept up to there and the remainder of the axis is padding
        pad_right = np.asarray(self.shape) - (np.max(indices, axis=0) + 1)
        padding = np.column_stack((pad_left, pad_right))
        return SparseEncoding(indices - pad_left, self.sparse_values), padding

409 

410 

def SparseBinaryEncoding(indices, shape=None):
    """
    Factory for a boolean `SparseEncoding` whose stored values are all True.

    Parameters
    ------------
    indices : (m, n) int
      Sparse indices into a conceptual rank-n array
    shape : length n iterable or None
      Passed through to `SparseEncoding` unchanged

    Returns
    ------------
    encoding : SparseEncoding
      Rank n bool encoding with a True value at each index
    """
    count = indices.shape[0]
    all_true = np.ones(shape=(count,), dtype=bool)
    return SparseEncoding(indices, all_true, shape)

426 

427 

class RunLengthEncoding(Encoding):
    """1D run length encoding.

    See `trimesh.voxel.runlength` documentation for implementation details.
    """

    def __init__(self, data, dtype=None):
        """
        Parameters
        ------------
        data : (n,) array
          Run length encoded data, wrapped with `caching.tracked_array`
        dtype : dtype or None
          dtype reported by this encoding and handed to the runlength
          helpers; defaults to the dtype of the tracked data.

        Raises
        ------------
        ValueError
          If the tracked data is not one dimensional
        """
        tracked = caching.tracked_array(data)
        super().__init__(data=tracked)
        if len(self._data.shape) != 1:
            raise ValueError("data must be 1D numpy array")
        self._dtype = self._data.dtype if dtype is None else dtype

    @caching.cache_decorator
    def is_empty(self):
        """True unless some even/odd pair of entries is non-zero in both."""
        evens = self._data[::2]
        odds = self._data[1::2]
        return not np.any(np.logical_and(evens, odds))

    @property
    def ndims(self):
        """Always 1."""
        return 1

    @property
    def shape(self):
        """One-element tuple holding `size`."""
        return (self.size,)

    @property
    def dtype(self):
        """dtype chosen at construction."""
        return self._dtype

    def __hash__(self):
        """
        Get the hash of the tracked run length data.

        Returns
        ------------
        hash
          Whatever `self._data.__hash__()` returns
        """
        return self._data.__hash__()

    @staticmethod
    def from_dense(dense_data, dtype=np.int64, encoding_dtype=np.int64):
        """Encode a dense array with `runlength.dense_to_rle`."""
        encoded = runlength.dense_to_rle(dense_data, dtype=encoding_dtype)
        return RunLengthEncoding(encoded, dtype=dtype)

    @staticmethod
    def from_rle(rle_data, dtype=None):
        """
        Wrap run length data, converting it with `runlength.rle_to_rle`
        when `dtype` differs from the dtype of `rle_data`.
        """
        mismatch = dtype != rle_data.dtype
        source = runlength.rle_to_rle(rle_data, dtype=dtype) if mismatch else rle_data
        return RunLengthEncoding(source)

    @staticmethod
    def from_brle(brle_data, dtype=None):
        """Convert binary run length data with `runlength.brle_to_rle`."""
        converted = runlength.brle_to_rle(brle_data, dtype=dtype)
        return RunLengthEncoding(converted)

    @caching.cache_decorator
    def stripped(self):
        """
        Get the encoding with leading/trailing zeros removed.

        Returns
        ------------
        encoding : Encoding
          `self` when nothing was stripped, otherwise a new encoding
        padding : array
          Padding from `runlength.rle_strip` with a leading axis added
        """
        if self.is_empty:
            return _empty_stripped(self.shape)
        trimmed, padding = runlength.rle_strip(self._data)
        result = (
            self
            if padding == (0, 0)
            else RunLengthEncoding(trimmed, dtype=self._dtype)
        )
        return result, np.expand_dims(padding, axis=0)

    @caching.cache_decorator
    def sum(self):
        """Sum over the products of each even/odd pair of entries."""
        evens = self._data[::2]
        odds = self._data[1::2]
        return (evens * odds).sum()

    @caching.cache_decorator
    def size(self):
        """Decoded length as given by `runlength.rle_length`."""
        return runlength.rle_length(self._data)

    def _flip(self, axes):
        """
        Reverse the encoding; only `axes == (0,)` is accepted.

        Raises
        ------------
        ValueError
          For any other value of `axes`
        """
        if axes != (0,):
            raise ValueError(f"encoding is 1D - cannot flip on axis {axes!s}")
        reversed_data = runlength.rle_reverse(self._data)
        return RunLengthEncoding(reversed_data)

    @caching.cache_decorator
    def sparse_components(self):
        """Result of `runlength.rle_to_sparse` on the tracked data."""
        return runlength.rle_to_sparse(self._data)

    @caching.cache_decorator
    def sparse_indices(self):
        """First element of `sparse_components`."""
        indices = self.sparse_components[0]
        return indices

    @caching.cache_decorator
    def sparse_values(self):
        """Second element of `sparse_components`."""
        values = self.sparse_components[1]
        return values

    @caching.cache_decorator
    def dense(self):
        """Decoded data from `runlength.rle_to_dense`."""
        return runlength.rle_to_dense(self._data, dtype=self._dtype)

    def gather(self, indices):
        """Look up values at 1D `indices` via `runlength.rle_gather_1d`."""
        return runlength.rle_gather_1d(self._data, indices, dtype=self._dtype)

    def gather_nd(self, indices):
        """Drop the trailing length-1 axis of `indices`, then `gather`."""
        flat = np.squeeze(indices, axis=-1)
        return self.gather(flat)

    def sorted_gather(self, ordered_indices):
        """Collect `runlength.sorted_rle_gather_1d` output into an array."""
        gathered = tuple(runlength.sorted_rle_gather_1d(self._data, ordered_indices))
        return np.array(gathered, dtype=self._dtype)

    def mask(self, mask):
        """Collect `runlength.rle_mask` output into an array."""
        selected = tuple(runlength.rle_mask(self._data, mask))
        return np.array(selected, dtype=self._dtype)

    def get_value(self, index):
        """
        Get the value at a single `index`.

        Returns None when the gather produces no entries.
        """
        gathered = self.sorted_gather((index,))
        if len(gathered) == 0:
            return None
        return np.asanyarray(gathered[0], dtype=self._dtype)

    def copy(self):
        """Get a new encoding over a copy of the tracked data."""
        duplicate = self._data.copy()
        return RunLengthEncoding(duplicate, dtype=self.dtype)

    def run_length_data(self, dtype=np.int64):
        """Run length data converted with `runlength.rle_to_rle`."""
        return runlength.rle_to_rle(self._data, dtype=dtype)

    def binary_run_length_data(self, dtype=np.int64):
        """Binary run length data from `runlength.rle_to_brle`."""
        return runlength.rle_to_brle(self._data, dtype=dtype)

561 

562 

class BinaryRunLengthEncoding(RunLengthEncoding):
    """1D binary run length encoding.

    See `trimesh.voxel.runlength` documentation for implementation details.
    """

    def __init__(self, data):
        """
        Parameters
        ------------
        data : (n,) array
          Binary run length encoded data; the encoding dtype is bool.
        """
        super().__init__(data=data, dtype=bool)

    @caching.cache_decorator
    def is_empty(self):
        """True if every odd-position entry of the data is zero."""
        return not np.any(self._data[1::2])

    @staticmethod
    def from_dense(dense_data, encoding_dtype=np.int64):
        """Encode a dense array with `runlength.dense_to_brle`."""
        return BinaryRunLengthEncoding(
            runlength.dense_to_brle(dense_data, dtype=encoding_dtype)
        )

    @staticmethod
    def from_rle(rle_data, dtype=None):
        """Convert run length data with `runlength.rle_to_brle`."""
        return BinaryRunLengthEncoding(runlength.rle_to_brle(rle_data, dtype=dtype))

    @staticmethod
    def from_brle(brle_data, dtype=None):
        """
        Wrap binary run length data, converting it with
        `runlength.brle_to_brle` when `dtype` differs from its dtype.
        """
        if dtype != brle_data.dtype:
            brle_data = runlength.brle_to_brle(brle_data, dtype=dtype)
        return BinaryRunLengthEncoding(brle_data)

    @caching.cache_decorator
    def stripped(self):
        """
        Get the encoding with leading/trailing zeros removed.

        Returns
        ------------
        encoding : Encoding
          `self` when nothing was stripped, otherwise a new encoding
        padding : array
          Padding from the strip helper with a leading axis added
        """
        if self.is_empty:
            return _empty_stripped(self.shape)
        # the data here is BRLE, so it needs the BRLE stripper: the RLE
        # one would read the entries as pairs of a different layout
        # NOTE(review): assumes `runlength.brle_strip` mirrors
        # `rle_strip` and returns (data, padding) - confirm.
        data, padding = runlength.brle_strip(self._data)
        if padding == (0, 0):
            encoding = self
        else:
            encoding = BinaryRunLengthEncoding(data)
        padding = np.expand_dims(padding, axis=0)
        return encoding, padding

    @caching.cache_decorator
    def sum(self):
        """Sum of the odd-position entries of the data."""
        return self._data[1::2].sum()

    @caching.cache_decorator
    def size(self):
        """Decoded length as given by `runlength.brle_length`."""
        return runlength.brle_length(self._data)

    def _flip(self, axes):
        """
        Reverse the encoding; only `axes == (0,)` is accepted.

        Raises
        ------------
        ValueError
          For any other value of `axes`
        """
        if axes != (0,):
            raise ValueError(f"encoding is 1D - cannot flip on axis {axes!s}")
        return BinaryRunLengthEncoding(runlength.brle_reverse(self._data))

    @property
    def sparse_components(self):
        """Tuple of `(sparse_indices, sparse_values)`."""
        return self.sparse_indices, self.sparse_values

    @caching.cache_decorator
    def sparse_values(self):
        """(sum,) array of True."""
        return np.ones(shape=(self.sum,), dtype=bool)

    @caching.cache_decorator
    def sparse_indices(self):
        """Result of `runlength.brle_to_sparse` on the tracked data."""
        return runlength.brle_to_sparse(self._data)

    @caching.cache_decorator
    def dense(self):
        """Decoded data from `runlength.brle_to_dense`."""
        return runlength.brle_to_dense(self._data)

    def gather(self, indices):
        """Look up values at 1D `indices` via `runlength.brle_gather_1d`."""
        return runlength.brle_gather_1d(self._data, indices)

    def gather_nd(self, indices):
        """
        Drop the trailing length-1 axis of `indices`, then `gather`.

        Only the last axis is squeezed so that a single `(1, 1)` index
        stays a length-1 batch; input that is already 1D is passed on
        unchanged.
        """
        indices = np.asanyarray(indices)
        if indices.ndim > 1:
            indices = np.squeeze(indices, axis=-1)
        return self.gather(indices)

    def sorted_gather(self, ordered_indices):
        """Collect `runlength.sorted_brle_gather_1d` output into a bool array."""
        gen = runlength.sorted_brle_gather_1d(self._data, ordered_indices)
        return np.array(tuple(gen), dtype=bool)

    def mask(self, mask):
        """Collect `runlength.brle_mask` output into a bool array."""
        gen = runlength.brle_mask(self._data, mask)
        return np.array(tuple(gen), dtype=bool)

    def copy(self):
        """Get a new encoding over a copy of the tracked data."""
        return BinaryRunLengthEncoding(self._data.copy())

    def run_length_data(self, dtype=np.int64):
        """Run length data from `runlength.brle_to_rle`."""
        return runlength.brle_to_rle(self._data, dtype=dtype)

    def binary_run_length_data(self, dtype=np.int64):
        """Binary run length data converted with `runlength.brle_to_brle`."""
        return runlength.brle_to_brle(self._data, dtype=dtype)

661 

662 

class LazyIndexMap(Encoding):
    """
    Abstract class for implementing lazy index mapping operations.

    The wrapped (base) encoding is stored as `self._data`.

    Implementations include transpose, flatten/reshaping and flipping

    Derived classes must implement:
    * _to_base_indices(indices)
    * _from_base_indices(base_indices)
    * shape
    * dense
    * mask(mask)
    """

    @abc.abstractmethod
    def _to_base_indices(self, indices):
        """Map indices of this encoding to indices of the base encoding."""
        pass

    @abc.abstractmethod
    def _from_base_indices(self, base_indices):
        """Map indices of the base encoding to indices of this encoding."""
        pass

    @property
    def is_empty(self):
        """`is_empty` of the base encoding."""
        return self._data.is_empty

    @property
    def dtype(self):
        """`dtype` of the base encoding."""
        return self._data.dtype

    @property
    def sum(self):
        """`sum` of the base encoding."""
        return self._data.sum

    @property
    def size(self):
        """`size` of the base encoding."""
        return self._data.size

    @property
    def sparse_indices(self):
        """Sparse indices of the base encoding mapped into this frame."""
        return self._from_base_indices(self._data.sparse_indices)

    @property
    def sparse_values(self):
        """`sparse_values` of the base encoding."""
        return self._data.sparse_values

    def gather_nd(self, indices):
        """Get values by mapping `indices` to the base encoding."""
        return self._data.gather_nd(self._to_base_indices(indices))

    def get_value(self, index):
        """
        Get the single value at `index`.

        The base is an `Encoding`, which cannot be subscripted, so the
        lookup goes through `gather_nd` with a batch of one index.
        """
        batched = np.expand_dims(np.asanyarray(index), axis=0)
        return self.gather_nd(batched)[0]

714 

715 

class FlattenedEncoding(LazyIndexMap):
    """
    Lazily flattened view of another encoding.

    The dense equivalent is np.reshape(data, (-1,)).
    """

    def _to_base_indices(self, indices):
        """Unravel flat indices against the base shape, one column per axis."""
        per_axis = np.unravel_index(indices, self._data.shape)
        return np.column_stack(per_axis)

    def _from_base_indices(self, base_indices):
        """Ravel base ND indices and append a trailing length-1 axis."""
        raveled = np.ravel_multi_index(base_indices.T, self._data.shape)
        return np.expand_dims(raveled, axis=-1)

    @property
    def shape(self):
        """One-element tuple holding `size`."""
        return (self.size,)

    @property
    def dense(self):
        """Dense data of the base encoding reshaped to one dimension."""
        base_dense = self._data.dense
        return base_dense.reshape((-1,))

    def mask(self, mask):
        """Reshape `mask` to the base shape and delegate to the base."""
        shaped_mask = mask.reshape(self._data.shape)
        return self._data.mask(shaped_mask)

    @property
    def flat(self):
        """Already flat, so this returns `self`."""
        return self

    def copy(self):
        """Get a flattened view over a copy of the base encoding."""
        base_copy = self._data.copy()
        return FlattenedEncoding(base_copy)

748 

749 

class ShapedEncoding(LazyIndexMap):
    """
    Lazily reshaped encoding.

    Numpy equivalent is `np.reshape`
    """

    def __init__(self, encoding, shape):
        """
        Parameters
        ------------
        encoding : Encoding
          Encoding to reshape; it is flattened first when not 1D
        shape : iterable of int
          Target shape. At most one entry may be -1, which is then
          inferred from the size of `encoding`.

        Raises
        ------------
        ValueError
          If `encoding` is not an Encoding, `shape` has more than
          one -1, or `shape` is incompatible with the encoding size.
        """
        if not isinstance(encoding, Encoding):
            raise ValueError("encoding must be an Encoding")
        if encoding.ndims != 1:
            encoding = encoding.flat
        super().__init__(data=encoding)
        self._shape = tuple(shape)

        # check for several -1 entries first: their product can be
        # positive and would otherwise slip through the size check
        unknown = self._shape.count(-1)
        if unknown > 1:
            raise ValueError("shape cannot have more than one -1 value")

        total = self._data.size
        if unknown == 1:
            # the single -1 makes the product negative; its magnitude
            # is the product of the explicitly given dimensions
            known = abs(int(np.prod(self._shape)))
            if known == 0 or total % known != 0:
                raise ValueError(
                    f"cannot reshape encoding of size {total} "
                    f"into shape {self._shape!s}"
                )
            inferred = total // known
            self._shape = tuple(inferred if s == -1 else s for s in self._shape)
        elif np.prod(self._shape) != total:
            raise ValueError(
                f"cannot reshape encoding of size {total} "
                f"into shape {self._shape!s}"
            )

    def _from_base_indices(self, base_indices):
        """Unravel flat base indices against this shape."""
        return np.column_stack(np.unravel_index(base_indices, self.shape))

    def _to_base_indices(self, indices):
        """Ravel ND indices and append a trailing length-1 axis."""
        return np.expand_dims(np.ravel_multi_index(indices.T, self.shape), axis=-1)

    @property
    def flat(self):
        """The underlying flat encoding."""
        return self._data

    @property
    def shape(self):
        """Resolved target shape."""
        return self._shape

    @property
    def dense(self):
        """Dense data of the base encoding reshaped to `shape`."""
        return self._data.dense.reshape(self.shape)

    def mask(self, mask):
        """Delegate to the base encoding using `mask.flat`."""
        return self._data.mask(mask.flat)

    def copy(self):
        """Get a reshaped view over a copy of the base encoding."""
        return ShapedEncoding(encoding=self._data.copy(), shape=self.shape)

810 

811 

class TransposedEncoding(LazyIndexMap):
    """
    Lazily transposed encoding

    Dense equivalent is `np.transpose`
    """

    def __init__(self, base_encoding, perm):
        """
        Parameters
        ------------
        base_encoding : Encoding
          Encoding to transpose
        perm : (n,) int
          Permutation of the axes, as for `np.transpose`

        Raises
        ------------
        ValueError
          If `base_encoding` is not an Encoding or `perm` is not a
          permutation of its axes.
        """
        if not isinstance(base_encoding, Encoding):
            raise ValueError(f"base_encoding must be an Encoding, got {base_encoding!s}")
        if len(base_encoding.shape) != len(perm):
            raise ValueError(
                f"base_encoding has {base_encoding.ndims} ndims - "
                f"cannot transpose with perm {perm!s}"
            )

        super().__init__(base_encoding)
        perm = np.array(perm, dtype=np.int64)
        if not all(i in perm for i in range(base_encoding.ndims)):
            raise ValueError(f"perm {perm!s} is not a valid permutation")
        # inv_perm[perm[i]] == i
        inv_perm = np.zeros_like(perm)
        inv_perm[perm] = np.arange(base_encoding.ndims)
        self._perm = perm
        self._inv_perm = inv_perm

    def transpose(self, perm):
        """Compose `perm` with the stored one and apply it to the base."""
        return _transposed(self._data, [self._perm[p] for p in perm])

    def _transpose(self, perm):
        raise RuntimeError("Should not be here")

    @property
    def perm(self):
        """Permutation applied to the base encoding."""
        return self._perm

    @property
    def shape(self):
        """Base shape permuted by `perm`."""
        shape = self._data.shape
        return tuple(shape[p] for p in self._perm)

    def _to_base_indices(self, indices):
        """
        Map indices of this encoding to indices of the base.

        Axis `i` here is axis `perm[i]` of the base, so base indices
        are recovered with the inverse permutation.
        """
        return np.take(indices, self._inv_perm, axis=-1)

    def _from_base_indices(self, base_indices):
        """
        Map indices of the base to indices of this encoding.

        Axis `i` here is axis `perm[i]` of the base, so the columns
        are picked with the forward permutation.
        """
        try:
            return np.take(base_indices, self._perm, axis=-1)
        except TypeError:
            # windows sometimes tries to use wrong dtypes
            return np.take(
                base_indices.astype(np.int64), self._perm.astype(np.int64), axis=-1
            )

    @property
    def dense(self):
        """Dense data of the base encoding transposed by `perm`."""
        return self._data.dense.transpose(self._perm)

    def gather(self, indices):
        """Delegate to the base `gather` with mapped indices."""
        return self._data.gather(self._to_base_indices(indices))

    def mask(self, mask):
        """
        Apply `mask`, moved to the base axis order, to the base.

        NOTE(review): the base result is transposed by `perm`, which
        presumes it keeps the rank of the encoding - confirm.
        """
        return self._data.mask(mask.transpose(self._inv_perm)).transpose(self._perm)

    def get_value(self, index):
        """Get the single value at the ND `index`."""
        # the base is an Encoding and cannot be subscripted, so go
        # through its own `get_value`
        base_index = self._to_base_indices(np.asanyarray(index))
        return self._data.get_value(base_index)

    @property
    def data(self):
        """The base encoding."""
        return self._data

    def copy(self):
        """Get a transposed view over a copy of the base encoding."""
        return TransposedEncoding(base_encoding=self._data.copy(), perm=self._perm)

884 

885 

class FlippedEncoding(LazyIndexMap):
    """
    Encoding with entries flipped along one or more axes.

    Dense equivalent is `np.flip`
    """

    def __init__(self, encoding, axes):
        """
        Parameters
        ------------
        encoding : Encoding
          Encoding to flip
        axes : int or iterable of int
          Axes to flip; negative values count from the last axis

        Raises
        ------------
        ValueError
          If `axes` has duplicates or values outside the encoding rank
        """
        ndims = encoding.ndims
        if isinstance(axes, np.ndarray) and axes.size == 1:
            axes = (axes.item(),)
        elif not hasattr(axes, "__iter__"):
            # any scalar, including numpy integer types
            axes = (axes,)
        axes = tuple(a + ndims if a < 0 else a for a in axes)
        self._axes = tuple(sorted(axes))
        if len(set(self._axes)) != len(self._axes):
            raise ValueError(f"Axes cannot contain duplicates, got {self._axes!s}")
        super().__init__(encoding)
        if not all(0 <= a < self._data.ndims for a in axes):
            raise ValueError(
                f"Invalid axes {axes!s} for {self._data.ndims}-d encoding"
            )

    def _to_base_indices(self, indices):
        """
        Mirror the flipped columns of `(m, n)` indices.

        Along a flipped axis of length `s`, index `i` corresponds
        to `s - 1 - i` in the base encoding.
        """
        indices = np.array(indices)
        shape = self.shape
        for a in self._axes:
            indices[:, a] = shape[a] - 1 - indices[:, a]
        return indices

    def _from_base_indices(self, base_indices):
        """Flipping is its own inverse, so reuse `_to_base_indices`."""
        return self._to_base_indices(base_indices)

    @property
    def shape(self):
        """Shape of the base encoding (flipping preserves shape)."""
        return self._data.shape

    @property
    def dense(self):
        """Dense data of the base flipped along each stored axis."""
        dense = self._data.dense
        for a in self._axes:
            dense = np.flip(dense, a)
        return dense

    def mask(self, mask):
        """
        Flip `mask` into the base frame and delegate to the base.

        NOTE(review): `.flip` is called on the base result, which
        presumes it is an object providing `flip` - confirm.
        """
        if not isinstance(mask, Encoding):
            mask = DenseEncoding(mask)
        mask = mask.flip(self._axes)
        return self._data.mask(mask).flip(self._axes)

    def copy(self):
        """Get a flipped view over a copy of the base encoding."""
        return FlippedEncoding(self._data.copy(), self._axes)

    def flip(self, axis=0):
        """
        Flip along further axes.

        The requested axes are combined with the stored ones and
        applied to the base encoding, so flipping the same axis
        twice cancels out.
        """
        if isinstance(axis, np.ndarray):
            axes = (axis.item(),) if axis.size == 1 else tuple(axis)
        elif hasattr(axis, "__iter__"):
            axes = tuple(axis)
        else:
            axes = (axis,)
        # go through the base: `_flip` on this class always raises
        return _flipped(self._data, self._axes + axes)

    def _flip(self, axes):
        raise RuntimeError("Should not be here")

954 

955 

def _flipped(encoding, axes):
    """
    Flip `encoding` along `axes`, cancelling axes listed an even
    number of times.

    Parameters
    ------------
    encoding : object
      Must provide `ndims` and `_flip(axes)`
    axes : int or iterable of int
      Axes to flip; negative values have `encoding.ndims` added

    Returns
    ------------
    flipped
      `encoding` itself when no axis remains, otherwise the result
      of `encoding._flip` called with the sorted remaining axes.
    """
    requested = axes if hasattr(axes, "__iter__") else (axes,)
    rank = encoding.ndims
    odd = set()
    for ax in requested:
        if ax < 0:
            ax = ax + rank
        # toggle membership: a second occurrence undoes the first
        odd ^= {ax}
    if odd:
        return encoding._flip(tuple(sorted(odd)))
    return encoding

971 

972 

def _transposed(encoding, perm):
    """
    Transpose `encoding` by `perm`, skipping the identity permutation.

    Parameters
    ------------
    encoding : object
      Must provide `ndims` and `_transpose(perm)`
    perm : iterable of int
      Axis permutation; negative values have `encoding.ndims` added

    Returns
    ------------
    transposed
      `encoding` itself for the identity permutation, otherwise the
      result of `encoding._transpose` with the normalized tuple.
    """
    rank = encoding.ndims
    normalized = tuple(axis + rank if axis < 0 else axis for axis in perm)
    is_identity = np.all(np.arange(rank) == normalized)
    if not is_identity:
        return encoding._transpose(normalized)
    return encoding