• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pymanopt / pymanopt / 14702850242

28 Apr 2025 07:43AM UTC coverage: 84.632% (-0.3%) from 84.932%
14702850242

Pull #296

github

web-flow
Merge 56dc45acc into 38296893c
Pull Request #296: Incorporate feedback on backend rewrite

36 of 60 new or added lines in 8 files covered. (60.0%)

2 existing lines in 2 files now uncovered.

3519 of 4158 relevant lines covered (84.63%)

3.38 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.23
/src/pymanopt/backends/tensorflow_backend.py
1
from numbers import Number
4✔
2
from typing import Any, Callable, Literal, Optional, Union
4✔
3

4
import numpy as np
4✔
5
import scipy
4✔
6
import tensorflow as tf
4✔
7

8
from pymanopt.backends.backend import Backend, DTypePrecision, TupleOrList
4✔
9
from pymanopt.tools import (
4✔
10
    bisect_sequence,
11
    unpack_singleton_sequence_return_value,
12
)
13

14

15
# Enable TensorFlow's NumPy-compatible behavior. Among other things this
# provides:
# - transposition of matrices with ``x.T``;
# - NumPy-style type promotion between floats, ints and complex numbers;
# - correct broadcasting for tensors with different ndims (in particular for
#   matrix-vector multiplication).
# See the linked documentation for details, including the effect of
# ``prefer_float32`` on dtype inference:
# https://www.tensorflow.org/api_docs/python/tf/experimental/numpy/experimental_enable_numpy_behavior
tf.experimental.numpy.experimental_enable_numpy_behavior(prefer_float32=True)


def elementary_math_function(
    f: Callable[["TensorflowBackend", tf.Tensor], tf.Tensor],
) -> Callable[
    ["TensorflowBackend", Union[tf.Tensor, Number]], Union[tf.Tensor, Number]
]:
    """Let an elementwise backend method also accept plain numbers.

    ``tf.Tensor`` arguments are passed to ``f`` unchanged. Any other argument
    is first converted with ``self.array`` (i.e. to the backend dtype), and
    the result of ``f`` is returned as a Python scalar via
    ``.numpy().item()``.

    Args:
        f: Backend method taking ``(self, tensor)`` and returning a tensor.

    Returns:
        A wrapper with the same name, docstring and qualified name as ``f``.
    """
    import functools

    @functools.wraps(f)
    def inner(
        self: "TensorflowBackend", x: Union[tf.Tensor, Number]
    ) -> Union[tf.Tensor, Number]:
        if isinstance(x, tf.Tensor):
            return f(self, x)
        # Non-tensor input: compute on a backend-dtype tensor, hand back a
        # Python scalar.
        return f(self, self.array(x)).numpy().item()

    return inner
41

42

43
class TensorflowBackend(Backend):
    """Pymanopt backend implemented with TensorFlow.

    The backend is parametrized by a real or complex ``dtype``. Arrays it
    creates (``array``, ``eye``, ``zeros``, random samples, ...) use that
    dtype. Several methods convert results with ``.numpy()``, so eager
    execution is assumed.
    """

    ##########################################################################
    # Common attributes, properties and methods
    ##########################################################################
    array_t = tf.Tensor  # type: ignore
    _dtype: tf.DType

    def __init__(self, dtype=tf.float64):
        """Create a backend operating on tensors of ``dtype``.

        Raises:
            ValueError: If ``dtype`` is not one of ``tf.float32``,
                ``tf.float64``, ``tf.complex64`` or ``tf.complex128``.
        """
        if dtype not in {
            tf.float32,
            tf.float64,
            tf.complex64,
            tf.complex128,
        }:
            raise ValueError(f"dtype {dtype} is not supported")
        self._dtype = dtype

    @property
    def dtype(self) -> tf.DType:
        """The dtype of tensors created by this backend."""
        return self._dtype

    @property
    def dtype_precision(self) -> DTypePrecision:
        """``SINGLE`` for float32/complex64, ``DOUBLE`` otherwise."""
        return (
            DTypePrecision.SINGLE
            if (self.dtype == tf.float32 or self.dtype == tf.complex64)
            else DTypePrecision.DOUBLE
        )

    @property
    def is_dtype_real(self):
        """Whether the backend dtype is a real floating point type."""
        return self.dtype in {tf.float32, tf.float64}

    @staticmethod
    def DEFAULT_REAL_DTYPE():
        return tf.float64

    @staticmethod
    def DEFAULT_COMPLEX_DTYPE():
        return tf.complex128

    def __repr__(self):
        return f"TensorflowBackend(dtype={self.dtype})"

    def to_real_backend(self) -> "TensorflowBackend":
        """Return a backend using the real dtype of matching precision."""
        if self.is_dtype_real:
            return self
        if self.dtype == tf.complex64:
            return TensorflowBackend(dtype=tf.float32)
        elif self.dtype == tf.complex128:
            return TensorflowBackend(dtype=tf.float64)
        else:
            raise ValueError(f"dtype {self.dtype} is not supported")

    def to_complex_backend(self) -> "TensorflowBackend":
        """Return a backend using the complex dtype of matching precision."""
        if not self.is_dtype_real:
            return self
        if self.dtype == tf.float32:
            return TensorflowBackend(dtype=tf.complex64)
        elif self.dtype == tf.float64:
            return TensorflowBackend(dtype=tf.complex128)
        else:
            raise ValueError(f"dtype {self.dtype} is not supported")

    def _complex_to_real_dtype(self, complex_dtype: tf.DType) -> tf.DType:
        """Map complex64 -> float32 and complex128 -> float64.

        Raises:
            ValueError: If ``complex_dtype`` is not a complex dtype.
        """
        if complex_dtype == tf.complex64:
            return tf.float32
        elif complex_dtype == tf.complex128:
            return tf.float64
        else:
            raise ValueError(f"Provided dtype {complex_dtype} is not complex.")

    ##############################################################################
    # Autodiff methods
    ##############################################################################

    def _sanitize_gradient(self, tensor, grad):
        # TensorFlow returns None for unconnected inputs; use zeros instead.
        if grad is None:
            return tf.zeros_like(tensor)
        return grad

    def _sanitize_gradients(self, tensors, grads):
        return list(map(self._sanitize_gradient, tensors, grads))

    def generate_gradient_operator(self, function, num_arguments):
        """Build a callable returning the gradients of ``function``.

        With a single argument the gradient is returned directly rather than
        as a one-element list.
        """

        def gradient(*args):
            with tf.GradientTape() as tape:
                for arg in args:
                    tape.watch(arg)
                cost = function(*args)
            # Only the forward pass needs recording; differentiate outside
            # the tape context so gradient ops themselves are not recorded.
            gradients = tape.gradient(cost, args)
            return self._sanitize_gradients(args, gradients)

        if num_arguments == 1:
            return unpack_singleton_sequence_return_value(gradient)
        return gradient

    def generate_hessian_operator(self, function, num_arguments):
        """Build a Hessian-vector product operator for ``function``.

        The returned callable takes the function arguments followed by the
        tangent vectors, which are split apart with ``bisect_sequence``. It
        returns the directional derivative of the gradient along those
        vectors (forward-over-reverse differentiation).
        """

        def hessian_vector_product(*args):
            arguments, vectors = bisect_sequence(args)
            # The reverse-mode gradient is computed while the forward
            # accumulator is active, so that its JVP along ``vectors`` is
            # tracked.
            with (
                tf.GradientTape() as tape,
                tf.autodiff.ForwardAccumulator(
                    arguments, vectors
                ) as accumulator,
            ):
                for argument in arguments:
                    tape.watch(argument)
                gradients = tape.gradient(function(*arguments), arguments)
            return self._sanitize_gradients(
                arguments, accumulator.jvp(gradients)
            )

        if num_arguments == 1:
            return unpack_singleton_sequence_return_value(
                hessian_vector_product
            )
        return hessian_vector_product

    ##############################################################################
    # Numerics functions
    ##############################################################################

    @elementary_math_function
    def abs(self, array: tf.Tensor) -> tf.Tensor:
        return tf.abs(array)

    def all(self, array: tf.Tensor) -> bool:
        """Whether every entry of ``array`` is truthy (as a Python bool)."""
        # tf.cast (unlike tf.constant with dtype=tf.bool) also accepts
        # non-boolean tensors.
        return tf.reduce_all(tf.cast(array, tf.bool)).numpy().item()

    def allclose(
        self,
        array_a: tf.Tensor,
        array_b: tf.Tensor,
        rtol: float = 1e-5,
        atol: float = 1e-8,
    ) -> bool:
        """Check ``|a - b| <= atol + rtol * |b|`` elementwise.

        Returns:
            A Python bool, consistent with ``all`` and ``any``.
        """
        close = tf.abs(array_a - array_b) <= (atol + rtol * tf.abs(array_b))
        return tf.reduce_all(close).numpy().item()

    def any(self, array: tf.Tensor) -> bool:
        """Whether some entry of ``array`` is truthy (as a Python bool)."""
        return tf.reduce_any(tf.cast(array, tf.bool)).numpy().item()

    def arange(
        self,
        start: int,
        stop: Optional[int] = None,
        step: Optional[int] = None,
    ) -> tf.Tensor:
        if stop is None:
            return tf.range(start)
        if step is None:
            return tf.range(start, stop)
        return tf.range(start, stop, step)

    @elementary_math_function
    def arccos(self, array: tf.Tensor) -> tf.Tensor:
        return tf.math.acos(array)

    @elementary_math_function
    def arccosh(self, array: tf.Tensor) -> tf.Tensor:
        return tf.math.acosh(array)

    @elementary_math_function
    def arctan(self, array: tf.Tensor) -> tf.Tensor:
        return tf.math.atan(array)

    @elementary_math_function
    def arctanh(self, array: tf.Tensor) -> tf.Tensor:
        return tf.math.atanh(array)

    def argmin(self, array: tf.Tensor):
        return tf.argmin(array)

    def argsort(self, array: tf.Tensor):
        return tf.argsort(array)

    def array(self, array: Any) -> tf.Tensor:  # type: ignore
        """Convert ``array`` to a tensor of the backend dtype.

        Complex tensors passed to a real backend are reduced to their real
        part before casting.
        """
        if isinstance(array, tf.Tensor):
            if self.is_dtype_real and self.iscomplexobj(array):
                array = tf.math.real(array)
            array = tf.cast(array, dtype=self.dtype)
        return tf.convert_to_tensor(array, dtype=self.dtype)

    def assert_allclose(
        self,
        array_a: tf.Tensor,
        array_b: tf.Tensor,
        rtol: float = 1e-6,
        atol: float = 1e-6,
    ) -> None:
        """Raise ``ValueError`` unless ``allclose(array_a, array_b)`` holds."""

        def max_abs(x):
            return tf.math.reduce_max(tf.abs(x))

        if not self.allclose(array_a, array_b, rtol, atol):
            raise ValueError(
                "Arrays are not almost equal.\n"
                f"Max absolute difference: {max_abs(array_a - array_b)}"
                f" (atol={atol})\n"
                "Max relative difference: "
                f"{max_abs(array_a - array_b) / max_abs(array_b)}"
                f" (rtol={rtol})"
            )

    def assert_equal(
        self,
        array_a: tf.Tensor,
        array_b: tf.Tensor,
    ) -> None:
        """Raise ``ValueError`` unless the arrays are elementwise equal."""
        if not tf.reduce_all(tf.equal(array_a, array_b)):
            raise ValueError(f"Arrays are not equal: {array_a} vs {array_b}")

    def concatenate(
        self, arrays: TupleOrList[tf.Tensor], axis: int = 0
    ) -> tf.Tensor:
        return tf.concat(arrays, axis)

    @elementary_math_function
    def conjugate(self, array: tf.Tensor) -> tf.Tensor:
        return tf.math.conj(array)

    @elementary_math_function
    def cos(self, array: tf.Tensor) -> tf.Tensor:
        return tf.math.cos(array)

    def diag(self, array: tf.Tensor) -> tf.Tensor:
        return tf.linalg.diag(array)

    def diagonal(self, array: tf.Tensor, axis1: int, axis2: int) -> tf.Tensor:
        """Return the diagonal taken over axes ``axis1`` and ``axis2``.

        As in ``numpy.diagonal``, the remaining axes keep their order and the
        diagonal entries are placed along a new last axis. Negative axes are
        supported.

        Raises:
            ValueError: If ``array`` has fewer than two dimensions, an axis is
                out of range, or both axes refer to the same dimension.
        """
        array = tf.convert_to_tensor(array)
        ndim = self.ndim(array)
        if ndim < 2:
            raise ValueError(
                "diagonal requires an array of at least 2 dimensions"
            )

        def normalize(axis: int) -> int:
            if not -ndim <= axis < ndim:
                raise ValueError(
                    f"axis {axis} is out of bounds for array of dimension "
                    f"{ndim}"
                )
            return axis % ndim

        first, second = normalize(axis1), normalize(axis2)
        if first == second:
            raise ValueError("axis1 and axis2 cannot be the same")
        # Move the two requested axes to the end, keeping the other axes in
        # order, so that tf.linalg.diag_part (which operates on the innermost
        # two axes) extracts the requested diagonal.
        perm = [axis for axis in range(ndim) if axis not in (first, second)]
        perm += [first, second]
        return tf.linalg.diag_part(tf.transpose(array, perm))

    def eps(self) -> float:
        """Machine epsilon of the backend dtype."""
        return tf.experimental.numpy.finfo(self.dtype).eps

    @elementary_math_function
    def exp(self, array: tf.Tensor) -> tf.Tensor:
        return tf.math.exp(array)

    def expand_dims(self, array: tf.Tensor, axis: int) -> tf.Tensor:
        return tf.expand_dims(array, axis)

    def eye(self, size: int) -> tf.Tensor:
        return tf.eye(size, dtype=self.dtype)

    def hstack(self, arrays: TupleOrList[tf.Tensor]) -> tf.Tensor:
        return tf.concat(arrays, axis=1)

    def imag(self, array: tf.Tensor) -> tf.Tensor:
        return tf.math.imag(array)

    def iscomplexobj(self, array: tf.Tensor) -> bool:
        return tf.experimental.numpy.iscomplexobj(array)

    @elementary_math_function
    def isnan(self, array: tf.Tensor) -> tf.Tensor:
        return tf.math.is_nan(array)

    def isrealobj(self, array: tf.Tensor) -> bool:
        return tf.experimental.numpy.isrealobj(array)

    def linalg_cholesky(self, array: tf.Tensor) -> tf.Tensor:
        return tf.linalg.cholesky(array)

    def linalg_det(self, array: tf.Tensor) -> tf.Tensor:
        return tf.linalg.det(array)

    def linalg_eigh(self, array: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
        """Eigendecomposition ``(w, u)``; eigenvalues ``w`` are made real."""
        w, u = tf.linalg.eigh(array)
        return tf.math.real(w), u

    def linalg_eigvalsh(
        self, array_x: tf.Tensor, array_y: Optional[tf.Tensor] = None
    ) -> tf.Tensor:
        """Eigenvalues of ``array_x``, or generalized eigenvalues if
        ``array_y`` is given.

        The generalized case is computed with ``scipy.linalg.eigvalsh``,
        vectorized over leading (batch) dimensions.
        """
        if array_y is None:
            return tf.math.real(tf.linalg.eigvalsh(array_x))
        else:
            return self.array(
                np.vectorize(
                    scipy.linalg.eigvalsh, signature="(m,m),(m,m)->(m)"
                )(array_x.numpy(), array_y.numpy())
            )

    def linalg_expm(
        self, array: tf.Tensor, symmetric: bool = False
    ) -> tf.Tensor:
        """Matrix exponential.

        With ``symmetric=True`` the result is computed from an ``eigh``
        eigendecomposition, so ``array`` must be symmetric/Hermitian.
        """
        if not symmetric:
            return tf.linalg.expm(array)

        w, v = tf.linalg.eigh(array)
        w = tf.expand_dims(tf.exp(w), axis=-1)
        expmA = v @ (w * tf.linalg.adjoint(v))
        if array.dtype in {tf.float32, tf.float64}:
            return tf.math.real(expmA)
        return expmA

    def linalg_inv(self, array: tf.Tensor) -> tf.Tensor:
        return tf.linalg.inv(array)

    def linalg_logm(
        self, array: tf.Tensor, positive_definite: bool = False
    ) -> tf.Tensor:
        """Matrix logarithm.

        In the general case the logarithm is computed in complex arithmetic
        and converted back with ``self.array`` (real part for real
        backends). With ``positive_definite=True`` it is computed from an
        ``eigh`` eigendecomposition.
        """
        if not positive_definite:
            return self.array(
                tf.linalg.logm(self.to_complex_backend().array(array))
            )

        w, v = tf.linalg.eigh(array)
        w = tf.expand_dims(tf.math.log(w), axis=-1)
        logmA = v @ (w * tf.linalg.adjoint(v))
        if array.dtype in {tf.float32, tf.float64}:
            return tf.math.real(logmA)
        return logmA

    def linalg_matrix_rank(self, array: tf.Tensor) -> int:
        return tf.linalg.matrix_rank(array).numpy().item()

    def linalg_norm(
        self,
        array: tf.Tensor,
        ord: Union[int, Literal["fro"], None] = None,
        axis: Union[int, TupleOrList[int], None] = None,
        keepdims: bool = False,
    ) -> tf.Tensor:
        """Norm of ``array``; ``"fro"`` and ``None`` map to the Euclidean
        norm. The result is always real."""
        if ord == "fro" or ord is None:
            ord = "euclidean"  # type: ignore
        return tf.math.real(
            tf.norm(array, ord=ord, axis=axis, keepdims=keepdims)
        )

    def linalg_qr(self, array: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
        """QR decomposition normalized so that ``diag(r)`` has entries of
        unit modulus (sign +1 for zero diagonal entries)."""
        q, r = tf.linalg.qr(array)
        # Compute signs or unit-modulus phase of entries of diagonal of r.
        s = tf.identity(tf.linalg.diag_part(r))
        s = tf.where(tf.equal(s, 0.0), tf.ones_like(s), s)
        s = s / tf.cast(tf.abs(s), dtype=self.dtype)
        s = tf.expand_dims(s, axis=-1)
        # Scale column j of q by s_j and row j of r by conj(s_j); since
        # |s_j| = 1 the product q @ r is unchanged.
        q = q * self.transpose(s)
        r = r * self.conjugate(s)
        return q, r

    def linalg_solve(
        self, array_a: tf.Tensor, array_b: tf.Tensor
    ) -> tf.Tensor:
        return tf.linalg.solve(array_a, array_b)

    def linalg_solve_continuous_lyapunov(
        self, array_a: tf.Tensor, array_q: tf.Tensor
    ) -> tf.Tensor:
        """Solve the continuous Lyapunov equation via SciPy (on host)."""
        return self.array(
            scipy.linalg.solve_continuous_lyapunov(
                array_a.numpy(), array_q.numpy()
            )
        )

    def linalg_svd(
        self, array: tf.Tensor, full_matrices: bool = True
    ) -> tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
        """Return ``(u, s, conjugate_transpose(v))``, reordering
        TensorFlow's ``(s, u, v)`` output."""
        s, u, v = tf.linalg.svd(array, full_matrices=full_matrices)
        return u, s, self.conjugate_transpose(v)

    def linalg_svdvals(self, array: tf.Tensor) -> tf.Tensor:
        return tf.linalg.svd(array, compute_uv=False)

    @elementary_math_function
    def log(self, array: tf.Tensor) -> tf.Tensor:
        return tf.math.log(array)

    @elementary_math_function
    def log10(self, array: tf.Tensor) -> tf.Tensor:
        # Build the constant in the input's dtype; a float32 ``10.0`` would
        # limit float64/complex128 results to single-precision accuracy.
        return tf.math.log(array) / tf.math.log(
            tf.constant(10.0, dtype=array.dtype)
        )

    def logspace(self, start: float, stop: float, num: int) -> tf.Tensor:
        return tf.experimental.numpy.logspace(
            start, stop, num, dtype=self.dtype
        )

    def matvec(self, A: tf.Tensor, x: tf.Tensor) -> tf.Tensor:
        """Matrix-vector (or batched matrix-matrix) product.

        If ``x`` has one dimension fewer than ``A`` it is treated as a
        (batch of) vector(s) and the result drops only the trailing
        singleton axis, matching ``A @ x`` in NumPy.

        Raises:
            ValueError: If ``A`` has fewer than two dimensions.
        """
        if A.ndim < 2:
            raise ValueError("Tensor must have at least 2 dimensions")
        if x.ndim == A.ndim - 1:
            # Squeeze only the axis added by expand_dims so that unit batch
            # or row dimensions are preserved.
            return tf.squeeze(A @ tf.expand_dims(x, -1), axis=-1)
        return tf.matmul(A, x)

    def matmul(self, A: tf.Tensor, B: tf.Tensor) -> tf.Tensor:
        return tf.matmul(A, B)

    def multieye(self, k: int, n: int) -> tf.Tensor:
        """Stack of ``k`` identity matrices of size ``n``."""
        return tf.eye(n, batch_shape=[k], dtype=self.dtype)

    def ndim(self, array: tf.Tensor) -> int:
        return array.shape.rank

    def ones(self, shape: TupleOrList[int]) -> tf.Tensor:
        return tf.ones(shape, dtype=self.dtype)

    def ones_bool(self, shape: TupleOrList[int]) -> tf.Tensor:
        return tf.ones(shape, dtype=tf.bool)

    def prod(self, array: tf.Tensor) -> float:
        return tf.reduce_prod(array).numpy().item()

    def random_normal(
        self,
        loc: float = 0.0,
        scale: float = 1.0,
        size: Union[int, TupleOrList[int], None] = 1,
    ) -> tf.Tensor:
        """Sample from a normal distribution with mean ``loc`` and standard
        deviation ``scale``.

        For complex backends the real and imaginary parts are sampled
        independently, each with mean ``loc`` and standard deviation
        ``scale``. If ``size`` is ``None`` a Python scalar is returned.
        """
        # pre-process the size
        if isinstance(size, int):
            new_size = (size,)
        elif size is None:
            new_size = (1,)
        else:
            new_size = size
        new_size = tf.constant(new_size)
        # sample
        if self.is_dtype_real:
            samples = tf.random.normal(
                shape=new_size, mean=loc, stddev=scale, dtype=self.dtype
            )
        else:
            real_dtype = self._complex_to_real_dtype(self.dtype)

            def sample_part():
                return tf.cast(
                    tf.random.normal(
                        shape=new_size,
                        mean=loc,
                        stddev=scale,
                        dtype=real_dtype,
                    ),
                    self.dtype,
                )

            # Real part is drawn first, then the imaginary part.
            samples = sample_part() + 1j * sample_part()
        # post-process
        return samples.numpy().item() if size is None else samples

    def random_uniform(
        self, size: Union[int, TupleOrList[int], None] = None
    ) -> tf.Tensor:
        """Sample uniformly from ``[0, 1)`` (independently for real and
        imaginary parts on complex backends). If ``size`` is ``None`` a
        Python scalar is returned."""
        # pre-process the size
        if isinstance(size, int):
            new_size = (size,)
        elif size is None:
            new_size = (1,)
        else:
            new_size = size
        new_size = tf.constant(new_size)
        # sample
        if self.is_dtype_real:
            samples = tf.random.uniform(shape=new_size, dtype=self.dtype)
        else:
            real_dtype = self._complex_to_real_dtype(self.dtype)
            samples = tf.cast(
                tf.random.uniform(shape=new_size, dtype=real_dtype), self.dtype
            ) + 1j * tf.cast(
                tf.random.uniform(shape=new_size, dtype=real_dtype), self.dtype
            )
        # post-process
        return samples.numpy().item() if size is None else samples

    @elementary_math_function
    def real(self, array: tf.Tensor) -> tf.Tensor:
        return tf.math.real(array)

    def reshape(
        self, array: tf.Tensor, newshape: TupleOrList[int]
    ) -> tf.Tensor:
        return tf.reshape(array, newshape)

    @elementary_math_function
    def sin(self, array: tf.Tensor) -> tf.Tensor:
        return tf.math.sin(array)

    @elementary_math_function
    def sinc(self, array: tf.Tensor) -> tf.Tensor:
        return tf.experimental.numpy.sinc(array)

    def sort(self, array: tf.Tensor, descending: bool = False) -> tf.Tensor:
        return tf.sort(
            array, direction="DESCENDING" if descending else "ASCENDING"
        )

    @elementary_math_function
    def sqrt(self, array: tf.Tensor) -> tf.Tensor:
        return tf.math.sqrt(array)

    def squeeze(self, array: tf.Tensor) -> tf.Tensor:
        return tf.squeeze(array)

    def stack(
        self, arrays: TupleOrList[tf.Tensor], axis: int = 0
    ) -> tf.Tensor:
        return tf.stack(arrays, axis=axis)

    def sum(
        self,
        array: tf.Tensor,
        axis: Union[int, TupleOrList[int], None] = None,
        keepdims: bool = False,
    ) -> tf.Tensor:
        return tf.reduce_sum(array, axis=axis, keepdims=keepdims)

    @elementary_math_function
    def tan(self, array: tf.Tensor) -> tf.Tensor:
        return tf.math.tan(array)

    @elementary_math_function
    def tanh(self, array: tf.Tensor) -> tf.Tensor:
        return tf.math.tanh(array)

    def tensordot(
        self, a: tf.Tensor, b: tf.Tensor, axes: int = 2
    ) -> tf.Tensor:
        return tf.tensordot(a, b, axes=axes)

    def tile(
        self, array: tf.Tensor, reps: Union[int, TupleOrList[int]]
    ) -> tf.Tensor:
        return tf.tile(array, reps)

    def trace(self, array: tf.Tensor) -> tf.Tensor:
        """Trace; a Python scalar for 2-D input, else a tensor of traces."""
        return (
            tf.linalg.trace(array).numpy().item()
            if array.ndim == 2
            else tf.linalg.trace(array)
        )

    def transpose(self, array: tf.Tensor) -> tf.Tensor:
        """Swap the last two axes of ``array``."""
        perm = list(range(self.ndim(array)))
        perm[-1], perm[-2] = perm[-2], perm[-1]
        return tf.transpose(array, perm)

    def triu(self, array: tf.Tensor, k: int = 0) -> tf.Tensor:
        return tf.experimental.numpy.triu(array, k)

    def vstack(self, arrays: TupleOrList[tf.Tensor]) -> tf.Tensor:
        return tf.concat(arrays, axis=0)

    def where(
        self,
        condition: tf.Tensor,
        x: Optional[tf.Tensor] = None,
        y: Optional[tf.Tensor] = None,
    ) -> tf.Tensor:
        """``tf.where`` with either both or neither of ``x``/``y`` given.

        Raises:
            ValueError: If exactly one of ``x`` and ``y`` is given.
        """
        if x is None and y is None:
            return tf.where(condition)
        elif x is not None and y is not None:
            return tf.where(condition, x, y)
        else:
            raise ValueError(
                f"Both x and y have to be specified but are respectively {x} and {y}"
            )

    def zeros(self, shape: TupleOrList[int]) -> tf.Tensor:
        return tf.zeros(shape, dtype=self.dtype)

    def zeros_bool(self, shape: TupleOrList[int]) -> tf.Tensor:
        return tf.zeros(shape, tf.bool)

    def zeros_like(self, array: tf.Tensor) -> tf.Tensor:
        return tf.zeros_like(array, dtype=self.dtype)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc