# /trax/tf_numpy/numpy_impl/math_ops.py
# coding=utf-8
# Copyright 2020 The Trax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mathematical operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

import numpy as np
import six

import tensorflow.compat.v2 as tf

from trax.tf_numpy.numpy_impl import array_ops
from trax.tf_numpy.numpy_impl import arrays
from trax.tf_numpy.numpy_impl import dtypes
from trax.tf_numpy.numpy_impl import utils


@utils.np_doc_only(np.dot)
def dot(a, b):  # pylint: disable=missing-docstring
  def f(a, b):  # pylint: disable=missing-docstring
    return utils.cond(
        utils.logical_or(tf.rank(a) == 0, tf.rank(b) == 0),
        lambda: a * b,
        lambda: utils.cond(  # pylint: disable=g-long-lambda
            tf.rank(b) == 1,
            lambda: tf.tensordot(a, b, axes=[[-1], [-1]]),
            lambda: tf.tensordot(a, b, axes=[[-1], [-2]])))
  return _bin_op(f, a, b)
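# A sketch of `dot`'s dispatch above (illustrative values only, assuming
# numpy-compatible semantics):
#   dot(2., 3.)                -> 6.0    (scalar operand: plain product)
#   dot([[1, 2]], [3, 4])      -> [11]   (rank-1 b: contract the last axes)
#   dot([[1, 2]], [[3], [4]])  -> [[11]] (otherwise: contract a's last axis
#                                         with b's second-to-last axis)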

# TODO(wangpeng): Make element-wise ops `ufunc`s
def _bin_op(tf_fun, a, b, promote=True):
  if promote:
    a, b = array_ops._promote_dtype(a, b)  # pylint: disable=protected-access
  else:
    a = array_ops.array(a)
    b = array_ops.array(b)
  return utils.tensor_to_ndarray(tf_fun(a.data, b.data))


@utils.np_doc(np.add)
def add(x1, x2):
  def add_or_or(x1, x2):
    if x1.dtype == tf.bool:
      assert x2.dtype == tf.bool
      return tf.logical_or(x1, x2)
    return tf.add(x1, x2)
  return _bin_op(add_or_or, x1, x2)


@utils.np_doc(np.subtract)
def subtract(x1, x2):
  return _bin_op(tf.subtract, x1, x2)


@utils.np_doc(np.multiply)
def multiply(x1, x2):
  def mul_or_and(x1, x2):
    if x1.dtype == tf.bool:
      assert x2.dtype == tf.bool
      return tf.logical_and(x1, x2)
    return tf.multiply(x1, x2)
  return _bin_op(mul_or_and, x1, x2)


@utils.np_doc(np.true_divide)
def true_divide(x1, x2):
  def _avoid_float64(x1, x2):
    if x1.dtype == x2.dtype and x1.dtype in (tf.int32, tf.int64):
      x1 = tf.cast(x1, dtype=tf.float32)
      x2 = tf.cast(x2, dtype=tf.float32)
    return x1, x2

  def f(x1, x2):
    if x1.dtype == tf.bool:
      assert x2.dtype == tf.bool
      float_ = dtypes.default_float_type()
      x1 = tf.cast(x1, float_)
      x2 = tf.cast(x2, float_)
    if not dtypes.is_allow_float64():
      # tf.math.truediv in Python3 produces float64 when both inputs are int32
      # or int64. We want to avoid that when is_allow_float64() is False.
      x1, x2 = _avoid_float64(x1, x2)
    return tf.math.truediv(x1, x2)
  return _bin_op(f, x1, x2)


divide = true_divide
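# Behavior sketch for `true_divide` (illustrative, not from the original
# file): when dtypes.is_allow_float64() is False, two int32/int64 inputs are
# divided in float32, e.g. true_divide(1, 2) -> 0.5 as float32, whereas
# numpy's true_divide(1, 2) would produce a float64.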

@utils.np_doc(np.floor_divide)
def floor_divide(x1, x2):
  def f(x1, x2):
    if x1.dtype == tf.bool:
      assert x2.dtype == tf.bool
      x1 = tf.cast(x1, tf.int8)
      x2 = tf.cast(x2, tf.int8)
    return tf.math.floordiv(x1, x2)
  return _bin_op(f, x1, x2)


@utils.np_doc(np.mod)
def mod(x1, x2):
  def f(x1, x2):
    if x1.dtype == tf.bool:
      assert x2.dtype == tf.bool
      x1 = tf.cast(x1, tf.int8)
      x2 = tf.cast(x2, tf.int8)
    return tf.math.mod(x1, x2)
  return _bin_op(f, x1, x2)


remainder = mod


@utils.np_doc(np.divmod)
def divmod(x1, x2):
  return floor_divide(x1, x2), mod(x1, x2)
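# Illustrative values for the floor_divide/mod pair above (they follow
# numpy's flooring semantics, where mod takes the sign of the divisor):
#   floor_divide(7, 2)  -> 3    mod(7, 2)  -> 1    divmod(7, 2)  -> (3, 1)
#   floor_divide(-7, 2) -> -4   mod(-7, 2) -> 1    divmod(-7, 2) -> (-4, 1)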

@utils.np_doc(np.maximum)
def maximum(x1, x2):
  def max_or_or(x1, x2):
    if x1.dtype == tf.bool:
      assert x2.dtype == tf.bool
      return tf.logical_or(x1, x2)
    return tf.math.maximum(x1, x2)
  return _bin_op(max_or_or, x1, x2)


@utils.np_doc(np.minimum)
def minimum(x1, x2):
  def min_or_and(x1, x2):
    if x1.dtype == tf.bool:
      assert x2.dtype == tf.bool
      return tf.logical_and(x1, x2)
    return tf.math.minimum(x1, x2)
  return _bin_op(min_or_and, x1, x2)


@utils.np_doc(np.clip)
def clip(a, a_min, a_max):  # pylint: disable=missing-docstring
  if a_min is None and a_max is None:
    raise ValueError('Not more than one of `a_min` and `a_max` may be `None`.')
  if a_min is None:
    return minimum(a, a_max)
  elif a_max is None:
    return maximum(a, a_min)
  else:
    a, a_min, a_max = array_ops._promote_dtype(a, a_min, a_max)  # pylint: disable=protected-access
    return utils.tensor_to_ndarray(
        tf.clip_by_value(*utils.tf_broadcast(a.data, a_min.data, a_max.data)))
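# Illustrative behavior of `clip` (values follow numpy):
#   clip([1, 5, 10], 2, 8)     -> [2, 5, 8]
#   clip([1, 5, 10], None, 8)  -> [1, 5, 8]   (upper bound only)
#   clip([1, 5, 10], 2, None)  -> [2, 5, 10]  (lower bound only)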

@utils.np_doc(np.matmul)
def matmul(x1, x2):  # pylint: disable=missing-docstring
  def f(x1, x2):
    try:
      return utils.cond(tf.rank(x2) == 1,
                        lambda: tf.tensordot(x1, x2, axes=1),
                        lambda: utils.cond(tf.rank(x1) == 1,  # pylint: disable=g-long-lambda
                                           lambda: tf.tensordot(  # pylint: disable=g-long-lambda
                                               x1, x2, axes=[[0], [-2]]),
                                           lambda: tf.matmul(x1, x2)))
    except tf.errors.InvalidArgumentError as err:
      six.reraise(ValueError, ValueError(str(err)), sys.exc_info()[2])
  return _bin_op(f, x1, x2)


@utils.np_doc(np.tensordot)
def tensordot(a, b, axes=2):
  return _bin_op(lambda a, b: tf.tensordot(a, b, axes=axes), a, b)


@utils.np_doc_only(np.inner)
def inner(a, b):
  def f(a, b):
    return utils.cond(utils.logical_or(tf.rank(a) == 0, tf.rank(b) == 0),
                      lambda: a * b,
                      lambda: tf.tensordot(a, b, axes=[[-1], [-1]]))
  return _bin_op(f, a, b)


@utils.np_doc(np.cross)
def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):  # pylint: disable=missing-docstring
  def f(a, b):  # pylint: disable=missing-docstring
    # We can't assign to captured variable `axisa`, so make a new variable
    axis_a = axisa
    axis_b = axisb
    axis_c = axisc
    if axis is not None:
      axis_a = axis
      axis_b = axis
      axis_c = axis
    if axis_a < 0:
      axis_a = utils.add(axis_a, tf.rank(a))
    if axis_b < 0:
      axis_b = utils.add(axis_b, tf.rank(b))
    def maybe_move_axis_to_last(a, axis):
      def move_axis_to_last(a, axis):
        return tf.transpose(
            a, tf.concat(
                [tf.range(axis), tf.range(axis + 1, tf.rank(a)), [axis]],
                axis=0))
      return utils.cond(
          axis == utils.subtract(tf.rank(a), 1),
          lambda: a,
          lambda: move_axis_to_last(a, axis))
    a = maybe_move_axis_to_last(a, axis_a)
    b = maybe_move_axis_to_last(b, axis_b)
    a_dim = utils.getitem(tf.shape(a), -1)
    b_dim = utils.getitem(tf.shape(b), -1)
    def maybe_pad_0(a, size_of_last_dim):
      def pad_0(a):
        return tf.pad(a, tf.concat([tf.zeros([tf.rank(a) - 1, 2], tf.int32),
                                    tf.constant([[0, 1]], tf.int32)], axis=0))
      return utils.cond(size_of_last_dim == 2,
                        lambda: pad_0(a),
                        lambda: a)
    a = maybe_pad_0(a, a_dim)
    b = maybe_pad_0(b, b_dim)
    c = tf.linalg.cross(*utils.tf_broadcast(a, b))
    if axis_c < 0:
      axis_c = utils.add(axis_c, tf.rank(c))
    def move_last_to_axis(a, axis):
      r = tf.rank(a)
      return tf.transpose(
          a, tf.concat(
              [tf.range(axis), [r - 1], tf.range(axis, r - 1)], axis=0))
    c = utils.cond(
        (a_dim == 2) & (b_dim == 2),
        lambda: c[..., 2],
        lambda: utils.cond(  # pylint: disable=g-long-lambda
            axis_c == utils.subtract(tf.rank(c), 1),
            lambda: c,
            lambda: move_last_to_axis(c, axis_c)))
    return c
  return _bin_op(f, a, b)


@utils.np_doc(np.power)
def power(x1, x2):
  return _bin_op(tf.math.pow, x1, x2)


@utils.np_doc(np.float_power)
def float_power(x1, x2):
  return power(x1, x2)


@utils.np_doc(np.arctan2)
def arctan2(x1, x2):
  return _bin_op(tf.math.atan2, x1, x2)


@utils.np_doc(np.nextafter)
def nextafter(x1, x2):
  return _bin_op(tf.math.nextafter, x1, x2)


@utils.np_doc(np.heaviside)
def heaviside(x1, x2):
  def f(x1, x2):
    return tf.where(x1 < 0, tf.constant(0, dtype=x2.dtype),
                    tf.where(x1 > 0, tf.constant(1, dtype=x2.dtype), x2))
  y = _bin_op(f, x1, x2)
  if not np.issubdtype(y.dtype, np.inexact):
    y = y.astype(dtypes.default_float_type())
  return y


@utils.np_doc(np.hypot)
def hypot(x1, x2):
  return sqrt(square(x1) + square(x2))


@utils.np_doc(np.kron)
def kron(a, b):
  # pylint: disable=protected-access,g-complex-comprehension
  a, b = array_ops._promote_dtype(a, b)
  ndim = max(a.ndim, b.ndim)
  if a.ndim < ndim:
    a = array_ops.reshape(a, array_ops._pad_left_to(ndim, a.shape))
  if b.ndim < ndim:
    b = array_ops.reshape(b, array_ops._pad_left_to(ndim, b.shape))
  a_reshaped = array_ops.reshape(a, [i for d in a.shape for i in (d, 1)])
  b_reshaped = array_ops.reshape(b, [i for d in b.shape for i in (1, d)])
  out_shape = tuple(np.multiply(a.shape, b.shape))
  return array_ops.reshape(a_reshaped * b_reshaped, out_shape)
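# How `kron` works above: the reshapes interleave a's axes with b's so the
# broadcasted product enumerates all pairwise products, which are then
# collapsed to shape a.shape * b.shape. Illustrative value (follows numpy):
#   kron([1, 2], [10, 20]) -> [10, 20, 20, 40]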

@utils.np_doc(np.outer)
def outer(a, b):
  def f(a, b):
    return tf.reshape(a, [-1, 1]) * tf.reshape(b, [-1])
  return _bin_op(f, a, b)


# This can also be implemented via tf.reduce_logsumexp
@utils.np_doc(np.logaddexp)
def logaddexp(x1, x2):
  amax = maximum(x1, x2)
  delta = x1 - x2
  return array_ops.where(
      isnan(delta),
      x1 + x2,  # NaNs or infinities of the same sign.
      amax + log1p(exp(-abs(delta))))
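# Why the formula above is numerically stable: for x1 >= x2,
#   log(exp(x1) + exp(x2)) = x1 + log(1 + exp(x2 - x1))
#                          = max(x1, x2) + log1p(exp(-|x1 - x2|)),
# so the argument of exp() is never positive and cannot overflow, e.g.
# logaddexp(1000., 1000.) -> ~1000.6931 instead of inf.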

@utils.np_doc(np.logaddexp2)
def logaddexp2(x1, x2):
  amax = maximum(x1, x2)
  delta = x1 - x2
  return array_ops.where(
      isnan(delta),
      x1 + x2,  # NaNs or infinities of the same sign.
      amax + log1p(exp2(-abs(delta))) / np.log(2))


@utils.np_doc(np.polyval)
def polyval(p, x):
  def f(p, x):
    if p.shape.rank == 0:
      p = tf.reshape(p, [1])
    p = tf.unstack(p)
    # TODO(wangpeng): Make tf version take a tensor for p instead of a list.
    y = tf.math.polyval(p, x)
    # If the polynomial is 0-order, numpy requires the result to be broadcast to
    # `x`'s shape.
    if len(p) == 1:
      y = tf.broadcast_to(y, x.shape)
    return y
  return _bin_op(f, p, x)
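# Illustrative evaluation for `polyval` (coefficients are ordered from the
# highest power down, as in numpy):
#   polyval([3, 0, 1], 5) -> 3 * 5**2 + 0 * 5 + 1 = 76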

@utils.np_doc(np.isclose)
def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):  # pylint: disable=missing-docstring
  def f(a, b):  # pylint: disable=missing-docstring
    dtype = a.dtype
    if np.issubdtype(dtype.as_numpy_dtype, np.inexact):
      rtol_ = tf.convert_to_tensor(rtol, dtype.real_dtype)
      atol_ = tf.convert_to_tensor(atol, dtype.real_dtype)
      result = (tf.math.abs(a - b) <= atol_ + rtol_ * tf.math.abs(b))
      if equal_nan:
        result = result | (tf.math.is_nan(a) & tf.math.is_nan(b))
      return result
    else:
      return a == b
  return _bin_op(f, a, b)
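# The inexact branch above mirrors numpy's closeness test
#   |a - b| <= atol + rtol * |b|,
# which is asymmetric: the relative tolerance is scaled by |b| only, so
# isclose(a, b) and isclose(b, a) can disagree near the boundary.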

@utils.np_doc(np.allclose)
def allclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
  return array_ops.all(isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan))


def _tf_gcd(x1, x2):
  def _gcd_cond_fn(x1, x2):
    return tf.reduce_any(x2 != 0)
  def _gcd_body_fn(x1, x2):
    # tf.math.mod will raise an error when any element of x2 is 0. To avoid
    # that, we change those zeros to ones. Their values don't matter because
    # they won't be used.
    x2_safe = tf.where(x2 != 0, x2, tf.constant(1, x2.dtype))
    x1, x2 = (tf.where(x2 != 0, x2, x1),
              tf.where(x2 != 0, tf.math.mod(x1, x2_safe),
                       tf.constant(0, x2.dtype)))
    return (tf.where(x1 < x2, x2, x1), tf.where(x1 < x2, x1, x2))
  if (not np.issubdtype(x1.dtype.as_numpy_dtype, np.integer) or
      not np.issubdtype(x2.dtype.as_numpy_dtype, np.integer)):
    raise ValueError("Arguments to gcd must be integers.")
  shape = tf.broadcast_static_shape(x1.shape, x2.shape)
  x1 = tf.broadcast_to(x1, shape)
  x2 = tf.broadcast_to(x2, shape)
  gcd, _ = tf.while_loop(_gcd_cond_fn, _gcd_body_fn,
                         (tf.math.abs(x1), tf.math.abs(x2)))
  return gcd


@utils.np_doc(np.gcd)
def gcd(x1, x2):
  return _bin_op(_tf_gcd, x1, x2)
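# The while_loop above runs the Euclidean algorithm in lockstep over the
# whole broadcast array: each iteration maps (x1, x2) to (x2, x1 mod x2)
# (reordered so the larger element comes first) until every element of x2
# is 0. Illustrative trace: (18, 12) -> (12, 6) -> (6, 0), so gcd(12, 18) -> 6.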

@utils.np_doc(np.lcm)
def lcm(x1, x2):
  def f(x1, x2):
    d = _tf_gcd(x1, x2)
    # Same as the `x2_safe` trick above
    d_safe = tf.where(d == 0, tf.constant(1, d.dtype), d)
    return tf.where(d == 0, tf.constant(0, d.dtype),
                    tf.math.abs(x1 * x2) // d_safe)
  return _bin_op(f, x1, x2)


def _bitwise_binary_op(tf_fn, x1, x2):
  def f(x1, x2):
    is_bool = (x1.dtype == tf.bool)
    if is_bool:
      assert x2.dtype == tf.bool
      x1 = tf.cast(x1, tf.int8)
      x2 = tf.cast(x2, tf.int8)
    r = tf_fn(x1, x2)
    if is_bool:
      r = tf.cast(r, tf.bool)
    return r
  return _bin_op(f, x1, x2)


@utils.np_doc(np.bitwise_and)
def bitwise_and(x1, x2):
  return _bitwise_binary_op(tf.bitwise.bitwise_and, x1, x2)


@utils.np_doc(np.bitwise_or)
def bitwise_or(x1, x2):
  return _bitwise_binary_op(tf.bitwise.bitwise_or, x1, x2)


@utils.np_doc(np.bitwise_xor)
def bitwise_xor(x1, x2):
  return _bitwise_binary_op(tf.bitwise.bitwise_xor, x1, x2)


@utils.np_doc(np.bitwise_not)
def bitwise_not(x):
  def f(x):
    if x.dtype == tf.bool:
      return tf.logical_not(x)
    return tf.bitwise.invert(x)
  return _scalar(f, x)
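# Bool inputs round-trip through int8 above so that tf's integer-only
# bitwise kernels can be reused. Illustrative values (follow numpy):
#   bitwise_and(True, False) -> False
#   bitwise_xor(12, 10)      -> 6       (0b1100 ^ 0b1010 == 0b0110)
#   bitwise_not(True)        -> False   (logical not for bools)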

def _scalar(tf_fn, x, promote_to_float=False):
  """Computes tf_fn(x) for each element in `x`.

  Args:
    tf_fn: function that takes a single Tensor argument.
    x: array_like. Could be an ndarray, a Tensor or any object that can
      be converted to a Tensor using `tf.convert_to_tensor`.
    promote_to_float: whether to cast the argument to a float dtype
      (`dtypes.default_float_type`) if it is not already.

  Returns:
    An ndarray with the same shape as `x`. The default output dtype is
    determined by `dtypes.default_float_type`, unless x is an ndarray with a
    floating point type, in which case the output type is the same as x.dtype.
  """
  x = array_ops.asarray(x)
  if promote_to_float and not np.issubdtype(x.dtype, np.inexact):
    x = x.astype(dtypes.default_float_type())
  return utils.tensor_to_ndarray(tf_fn(x.data))
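# Sketch of `_scalar`'s promotion behavior (illustrative; the exact float
# dtype comes from dtypes.default_float_type()):
#   log(8)    -> ~2.0794  (integer input promoted to float before tf.math.log)
#   square(3) -> 9        (promote_to_float=False keeps the integer dtype)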

@utils.np_doc(np.log)
def log(x):
  return _scalar(tf.math.log, x, True)


@utils.np_doc(np.exp)
def exp(x):
  return _scalar(tf.exp, x, True)


@utils.np_doc(np.sqrt)
def sqrt(x):
  return _scalar(tf.sqrt, x, True)


@utils.np_doc(np.abs)
def abs(x):
  return _scalar(tf.math.abs, x)


@utils.np_doc(np.absolute)
def absolute(x):
  return abs(x)


@utils.np_doc(np.fabs)
def fabs(x):
  return abs(x)


@utils.np_doc(np.ceil)
def ceil(x):
  return _scalar(tf.math.ceil, x, True)


@utils.np_doc(np.floor)
def floor(x):
  return _scalar(tf.math.floor, x, True)


@utils.np_doc(np.conj)
def conj(x):
  return _scalar(tf.math.conj, x)


@utils.np_doc(np.negative)
def negative(x):
  return _scalar(tf.math.negative, x)


@utils.np_doc(np.reciprocal)
def reciprocal(x):
  return _scalar(tf.math.reciprocal, x)


@utils.np_doc(np.signbit)
def signbit(x):
  def f(x):
    if x.dtype == tf.bool:
      return tf.fill(x.shape, False)
    return x < 0
  return _scalar(f, x)


@utils.np_doc(np.sin)
def sin(x):
  return _scalar(tf.math.sin, x, True)


@utils.np_doc(np.cos)
def cos(x):
  return _scalar(tf.math.cos, x, True)


@utils.np_doc(np.tan)
def tan(x):
  return _scalar(tf.math.tan, x, True)


@utils.np_doc(np.sinh)
def sinh(x):
  return _scalar(tf.math.sinh, x, True)


@utils.np_doc(np.cosh)
def cosh(x):
  return _scalar(tf.math.cosh, x, True)


@utils.np_doc(np.tanh)
def tanh(x):
  return _scalar(tf.math.tanh, x, True)


@utils.np_doc(np.arcsin)
def arcsin(x):
  return _scalar(tf.math.asin, x, True)


@utils.np_doc(np.arccos)
def arccos(x):
  return _scalar(tf.math.acos, x, True)


@utils.np_doc(np.arctan)
def arctan(x):
  return _scalar(tf.math.atan, x, True)


@utils.np_doc(np.arcsinh)
def arcsinh(x):
  return _scalar(tf.math.asinh, x, True)


@utils.np_doc(np.arccosh)
def arccosh(x):
  return _scalar(tf.math.acosh, x, True)


@utils.np_doc(np.arctanh)
def arctanh(x):
  return _scalar(tf.math.atanh, x, True)


@utils.np_doc(np.deg2rad)
def deg2rad(x):
  def f(x):
    return x * (np.pi / 180.0)
  return _scalar(f, x, True)


@utils.np_doc(np.rad2deg)
def rad2deg(x):
  return x * (180.0 / np.pi)


_tf_float_types = [tf.bfloat16, tf.float16, tf.float32, tf.float64]


@utils.np_doc(np.angle)
def angle(z, deg=False):
  def f(x):
    if x.dtype in _tf_float_types:
      # Workaround for b/147515503
      return tf.where(x < 0, np.pi, 0)
    else:
      return tf.math.angle(x)
  y = _scalar(f, z, True)
  if deg:
    y = rad2deg(y)
  return y


@utils.np_doc(np.cbrt)
def cbrt(x):
  def f(x):
    # __pow__ can't handle negative base, so we use `abs` here.
    rt = tf.math.abs(x) ** (1.0 / 3)
    return tf.where(x < 0, -rt, rt)
  return _scalar(f, x, True)


@utils.np_doc(np.conjugate)
def conjugate(x):
  return _scalar(tf.math.conj, x)


@utils.np_doc(np.exp2)
def exp2(x):
  def f(x):
    return 2 ** x
  return _scalar(f, x, True)


@utils.np_doc(np.expm1)
def expm1(x):
  return _scalar(tf.math.expm1, x, True)


@utils.np_doc(np.fix)
def fix(x):
  def f(x):
    return tf.where(x < 0, tf.math.ceil(x), tf.math.floor(x))
  return _scalar(f, x, True)


@utils.np_doc(np.iscomplex)
def iscomplex(x):
  return array_ops.imag(x) != 0


@utils.np_doc(np.isreal)
def isreal(x):
  return array_ops.imag(x) == 0


@utils.np_doc(np.iscomplexobj)
def iscomplexobj(x):
  x = array_ops.array(x)
  return np.issubdtype(x.dtype, np.complexfloating)


@utils.np_doc(np.isrealobj)
def isrealobj(x):
  return not iscomplexobj(x)


@utils.np_doc(np.isnan)
def isnan(x):
  return _scalar(tf.math.is_nan, x, True)

def _make_nan_reduction(onp_reduction, reduction, init_val):
  """Helper to generate nan* functions."""
  @utils.np_doc(onp_reduction)
  def nan_reduction(a, axis=None, dtype=None, keepdims=False):
    a = array_ops.array(a)
    v = array_ops.array(init_val, dtype=a.dtype)
    return reduction(
        array_ops.where(isnan(a), v, a),
        axis=axis,
        dtype=dtype,
        keepdims=keepdims)
  return nan_reduction


nansum = _make_nan_reduction(np.nansum, array_ops.sum, 0)
nanprod = _make_nan_reduction(np.nanprod, array_ops.prod, 1)


@utils.np_doc(np.nanmean)
def nanmean(a, axis=None, dtype=None, keepdims=None):  # pylint: disable=missing-docstring
  a = array_ops.array(a)
  if np.issubdtype(a.dtype, np.bool_) or np.issubdtype(a.dtype, np.integer):
    return array_ops.mean(a, axis=axis, dtype=dtype, keepdims=keepdims)
  nan_mask = logical_not(isnan(a))
  if dtype is None:
    dtype = a.dtype
  normalizer = array_ops.sum(
      nan_mask, axis=axis, dtype=dtype, keepdims=keepdims)
  return nansum(a, axis=axis, dtype=dtype, keepdims=keepdims) / normalizer


@utils.np_doc(np.isfinite)
def isfinite(x):
  return _scalar(tf.math.is_finite, x, True)


@utils.np_doc(np.isinf)
def isinf(x):
  return _scalar(tf.math.is_inf, x, True)


@utils.np_doc(np.isneginf)
def isneginf(x):
  return x == array_ops.full_like(x, -np.inf)


@utils.np_doc(np.isposinf)
def isposinf(x):
  return x == array_ops.full_like(x, np.inf)


@utils.np_doc(np.log2)
def log2(x):
  return log(x) / np.log(2)


@utils.np_doc(np.log10)
def log10(x):
  return log(x) / np.log(10)


@utils.np_doc(np.log1p)
def log1p(x):
  return _scalar(tf.math.log1p, x, True)


@utils.np_doc(np.positive)
def positive(x):
  return _scalar(lambda x: x, x)


@utils.np_doc(np.sinc)
def sinc(x):
  def f(x):
    pi_x = x * np.pi
    return tf.where(x == 0, tf.ones_like(x), tf.math.sin(pi_x) / pi_x)
  return _scalar(f, x, True)


@utils.np_doc(np.square)
def square(x):
  return _scalar(tf.math.square, x)


@utils.np_doc(np.diff)
def diff(a, n=1, axis=-1):
  def f(a):
    nd = a.shape.rank
    if (axis + nd if axis < 0 else axis) >= nd:
      raise ValueError("axis %s is out of bounds for array of dimension %s" %
                       (axis, nd))
    if n < 0:
      raise ValueError("order must be non-negative but got %s" % n)
    slice1 = [slice(None)] * nd
    slice2 = [slice(None)] * nd
    slice1[axis] = slice(1, None)
    slice2[axis] = slice(None, -1)
    slice1 = tuple(slice1)
    slice2 = tuple(slice2)
    op = tf.not_equal if a.dtype == tf.bool else tf.subtract
    for _ in range(n):
      a = op(a[slice1], a[slice2])
    return a
  return _scalar(f, a)
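# The slicing trick in `diff` above pairs each element with its successor
# along `axis` and applies the op n times. Illustrative values (follow
# numpy):
#   diff([1, 2, 4, 7])       -> [1, 2, 3]
#   diff([1, 2, 4, 7], n=2)  -> [1, 1]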

def _flip_args(f):
  def _f(a, b):
    return f(b, a)
  return _f


setattr(arrays.ndarray, '__abs__', absolute)
setattr(arrays.ndarray, '__floordiv__', floor_divide)
setattr(arrays.ndarray, '__rfloordiv__', _flip_args(floor_divide))
setattr(arrays.ndarray, '__mod__', mod)
setattr(arrays.ndarray, '__rmod__', _flip_args(mod))
setattr(arrays.ndarray, '__add__', add)
setattr(arrays.ndarray, '__radd__', _flip_args(add))
setattr(arrays.ndarray, '__sub__', subtract)
setattr(arrays.ndarray, '__rsub__', _flip_args(subtract))
setattr(arrays.ndarray, '__mul__', multiply)
setattr(arrays.ndarray, '__rmul__', _flip_args(multiply))
setattr(arrays.ndarray, '__pow__', power)
setattr(arrays.ndarray, '__rpow__', _flip_args(power))
setattr(arrays.ndarray, '__truediv__', true_divide)
setattr(arrays.ndarray, '__rtruediv__', _flip_args(true_divide))


def _comparison(tf_fun, x1, x2, cast_bool_to_int=False):
  dtype = utils.result_type(x1, x2)
  # Cast x1 and x2 to the result_type if needed.
  x1 = array_ops.array(x1, dtype=dtype)
  x2 = array_ops.array(x2, dtype=dtype)
  x1 = x1.data
  x2 = x2.data
  if cast_bool_to_int and x1.dtype == tf.bool:
    x1 = tf.cast(x1, tf.int32)
    x2 = tf.cast(x2, tf.int32)
  return utils.tensor_to_ndarray(tf_fun(x1, x2))
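# `cast_bool_to_int` above routes bool operands through int32 for the
# ordered comparisons (<, <=, >, >=), which treats True as 1 and False as 0.
# Illustrative value (follows numpy): greater(True, False) -> True.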

@utils.np_doc(np.equal)
def equal(x1, x2):
  return _comparison(tf.equal, x1, x2)


@utils.np_doc(np.not_equal)
def not_equal(x1, x2):
  return _comparison(tf.not_equal, x1, x2)


@utils.np_doc(np.greater)
def greater(x1, x2):
  return _comparison(tf.greater, x1, x2, True)


@utils.np_doc(np.greater_equal)
def greater_equal(x1, x2):
  return _comparison(tf.greater_equal, x1, x2, True)


@utils.np_doc(np.less)
def less(x1, x2):
  return _comparison(tf.less, x1, x2, True)


@utils.np_doc(np.less_equal)
def less_equal(x1, x2):
  return _comparison(tf.less_equal, x1, x2, True)


@utils.np_doc(np.array_equal)
def array_equal(a1, a2):
  def f(a1, a2):
    if a1.shape != a2.shape:
      return tf.constant(False)
    return tf.reduce_all(tf.equal(a1, a2))
  return _comparison(f, a1, a2)


def _logical_binary_op(tf_fun, x1, x2):
  x1 = array_ops.array(x1, dtype=np.bool_)
  x2 = array_ops.array(x2, dtype=np.bool_)
  return utils.tensor_to_ndarray(tf_fun(x1.data, x2.data))


@utils.np_doc(np.logical_and)
def logical_and(x1, x2):
  return _logical_binary_op(tf.logical_and, x1, x2)


@utils.np_doc(np.logical_or)
def logical_or(x1, x2):
  return _logical_binary_op(tf.logical_or, x1, x2)


@utils.np_doc(np.logical_xor)
def logical_xor(x1, x2):
  return _logical_binary_op(tf.math.logical_xor, x1, x2)


@utils.np_doc(np.logical_not)
def logical_not(x):
  x = array_ops.array(x, dtype=np.bool_)
  return utils.tensor_to_ndarray(tf.logical_not(x.data))

setattr(arrays.ndarray, '__invert__', logical_not)
setattr(arrays.ndarray, '__lt__', less)
setattr(arrays.ndarray, '__le__', less_equal)
setattr(arrays.ndarray, '__gt__', greater)
setattr(arrays.ndarray, '__ge__', greater_equal)
setattr(arrays.ndarray, '__eq__', equal)
setattr(arrays.ndarray, '__ne__', not_equal)

@utils.np_doc(np.linspace)
def linspace(  # pylint: disable=missing-docstring
    start, stop, num=50, endpoint=True, retstep=False, dtype=float, axis=0):
  if dtype:
    dtype = utils.result_type(dtype)
  start = array_ops.array(start, dtype=dtype).data
  stop = array_ops.array(stop, dtype=dtype).data
  if num < 0:
    raise ValueError('Number of samples {} must be non-negative.'.format(num))
  step = tf.convert_to_tensor(np.nan)
  if endpoint:
    result = tf.linspace(start, stop, num, axis=axis)
    if num > 1:
      step = (stop - start) / (num - 1)
  else:
    # tf.linspace does not support endpoint=False so we manually handle it
    # here.
    if num > 1:
      step = ((stop - start) / num)
      new_stop = tf.cast(stop, step.dtype) - step
      start = tf.cast(start, new_stop.dtype)
      result = tf.linspace(start, new_stop, num, axis=axis)
    else:
      result = tf.linspace(start, stop, num, axis=axis)
  if dtype:
    result = tf.cast(result, dtype)
  if retstep:
    return arrays.tensor_to_ndarray(result), arrays.tensor_to_ndarray(step)
  else:
    return arrays.tensor_to_ndarray(result)
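# Illustrative endpoint handling for `linspace` (values follow numpy):
#   linspace(0., 1., 5)                  -> [0., 0.25, 0.5, 0.75, 1.]
#   linspace(0., 1., 4, endpoint=False)  -> [0., 0.25, 0.5, 0.75]
# The endpoint=False branch shortens `stop` by one step because tf.linspace
# itself always includes the endpoint.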

@utils.np_doc(np.logspace)
def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None, axis=0):
  dtype = utils.result_type(start, stop, dtype)
  result = linspace(
      start, stop, num=num, endpoint=endpoint, dtype=dtype, axis=axis).data
  result = tf.pow(tf.cast(base, result.dtype), result)
  if dtype:
    result = tf.cast(result, dtype)
  return arrays.tensor_to_ndarray(result)


@utils.np_doc(np.ptp)
def ptp(a, axis=None, keepdims=None):
  return (array_ops.amax(a, axis=axis, keepdims=keepdims) -
          array_ops.amin(a, axis=axis, keepdims=keepdims))


@utils.np_doc_only(np.concatenate)
def concatenate(arys, axis=0):
  if not isinstance(arys, (list, tuple)):
    arys = [arys]
  if not arys:
    raise ValueError('Need at least one array to concatenate.')
  dtype = utils.result_type(*arys)
  arys = [array_ops.array(array, dtype=dtype).data for array in arys]
  return arrays.tensor_to_ndarray(tf.concat(arys, axis))


@utils.np_doc_only(np.tile)
def tile(a, reps):
  a = array_ops.array(a).data
  reps = array_ops.array(reps, dtype=tf.int32).reshape([-1]).data

  a_rank = tf.rank(a)
  reps_size = tf.size(reps)
  reps = tf.pad(
      reps, [[tf.math.maximum(a_rank - reps_size, 0), 0]],
      constant_values=1)
  a_shape = tf.pad(
      tf.shape(a), [[tf.math.maximum(reps_size - a_rank, 0), 0]],
      constant_values=1)
  a = tf.reshape(a, a_shape)

  return arrays.tensor_to_ndarray(tf.tile(a, reps))
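# `tile` above left-pads the shorter of `reps` and a's shape with 1s so both
# have the same length, matching numpy's promotion rules. Illustrative
# values (follow numpy):
#   tile([1, 2], 2)       -> [1, 2, 1, 2]
#   tile([1, 2], [2, 2])  -> [[1, 2, 1, 2], [1, 2, 1, 2]]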

@utils.np_doc(np.count_nonzero)
def count_nonzero(a, axis=None):
  return arrays.tensor_to_ndarray(
      tf.math.count_nonzero(array_ops.array(a).data, axis))


@utils.np_doc(np.argsort)
def argsort(a, axis=-1, kind='quicksort', order=None):  # pylint: disable=missing-docstring
  # TODO(nareshmodi): make string tensors also work.
  if kind not in ('quicksort', 'stable'):
    raise ValueError("Only 'quicksort' and 'stable' arguments are supported.")
  if order is not None:
    raise ValueError("'order' argument to sort is not supported.")
  stable = (kind == 'stable')

  a = array_ops.array(a).data

  def _argsort(a, axis, stable):
    if axis is None:
      a = tf.reshape(a, [-1])
      axis = 0

    return tf.argsort(a, axis, stable=stable)

  tf_ans = tf.cond(
      tf.rank(a) == 0, lambda: tf.constant([0]),
      lambda: _argsort(a, axis, stable))

  return array_ops.array(tf_ans, dtype=np.intp)


@utils.np_doc(np.sort)
def sort(a, axis=-1, kind='quicksort', order=None):  # pylint: disable=missing-docstring
  if kind != 'quicksort':
    raise ValueError("Only 'quicksort' is supported.")
  if order is not None:
    raise ValueError("'order' argument to sort is not supported.")

  a = array_ops.array(a)

  if axis is None:
    result_t = tf.sort(tf.reshape(a.data, [-1]), 0)
    return utils.tensor_to_ndarray(result_t)
  else:
    return utils.tensor_to_ndarray(tf.sort(a.data, axis))


def _argminmax(fn, a, axis=None):
  a = array_ops.array(a)
  if axis is None:
    # When axis is None numpy flattens the array.
    a_t = tf.reshape(a.data, [-1])
  else:
    a_t = array_ops.atleast_1d(a).data
  return utils.tensor_to_ndarray(fn(input=a_t, axis=axis))


@utils.np_doc(np.argmax)
def argmax(a, axis=None):
  return _argminmax(tf.argmax, a, axis)


@utils.np_doc(np.argmin)
def argmin(a, axis=None):
  return _argminmax(tf.argmin, a, axis)


@utils.np_doc(np.append)
def append(arr, values, axis=None):
  if axis is None:
    return concatenate([array_ops.ravel(arr), array_ops.ravel(values)], 0)
  else:
    return concatenate([arr, values], axis=axis)


@utils.np_doc(np.average)
def average(a, axis=None, weights=None, returned=False):  # pylint: disable=missing-docstring
  if axis is not None and not isinstance(axis, six.integer_types):
    # TODO(wangpeng): Support tuple of ints as `axis`
    raise ValueError('`axis` must be an integer. Tuple of ints is not '
                     'supported yet. Got type: %s' % type(axis))
  a = array_ops.array(a)
  if weights is None:  # Treat all weights as 1
    if not np.issubdtype(a.dtype, np.inexact):
      a = a.astype(utils.result_type(a.dtype, dtypes.default_float_type()))
    avg = tf.reduce_mean(a.data, axis=axis)
    if returned:
      if axis is None:
        weights_sum = tf.size(a.data)
      else:
        weights_sum = tf.shape(a.data)[axis]
      weights_sum = tf.cast(weights_sum, a.data.dtype)
  else:
    if np.issubdtype(a.dtype, np.inexact):
      out_dtype = utils.result_type(a.dtype, weights)
    else:
      out_dtype = utils.result_type(a.dtype, weights,
                                    dtypes.default_float_type())
    a = array_ops.array(a, out_dtype).data
    weights = array_ops.array(weights, out_dtype).data

    def rank_equal_case():
      tf.debugging.Assert(tf.reduce_all(tf.shape(a) == tf.shape(weights)),
                          [tf.shape(a), tf.shape(weights)])
      weights_sum = tf.reduce_sum(weights, axis=axis)
      avg = tf.reduce_sum(a * weights, axis=axis) / weights_sum
      return avg, weights_sum
    if axis is None:
      avg, weights_sum = rank_equal_case()
    else:
      def rank_not_equal_case():
        tf.debugging.Assert(tf.rank(weights) == 1, [tf.rank(weights)])
        weights_sum = tf.reduce_sum(weights)
        axes = tf.convert_to_tensor([[axis], [0]])
        avg = tf.tensordot(a, weights, axes) / weights_sum
        return avg, weights_sum
      # We condition on rank rather than shape equality, because if we do the
      # latter, when the shapes are partially unknown but the ranks are known
      # and different, utils.cond will run shape checking on the true branch,
      # which will raise a shape-checking error.
      avg, weights_sum = utils.cond(tf.rank(a) == tf.rank(weights),
                                    rank_equal_case, rank_not_equal_case)

  avg = array_ops.array(avg)
  if returned:
    weights_sum = array_ops.broadcast_to(weights_sum, tf.shape(avg.data))
    return avg, weights_sum
  return avg
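# Illustrative values for `average` (follow numpy):
#   average([1., 2., 3.])                       -> 2.0
#   average([1., 2., 3.], weights=[3., 1., 1.]) -> (3 + 2 + 3) / 5 = 1.6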

@utils.np_doc(np.trace)
def trace(a, offset=0, axis1=0, axis2=1, dtype=None):  # pylint: disable=missing-docstring
  if dtype:
    dtype = utils.result_type(dtype)
  a = array_ops.asarray(a, dtype).data

  if offset == 0:
    a_shape = a.shape
    if a_shape.rank is not None:
      rank = len(a_shape)
      if (axis1 == -2 or axis1 == rank - 2) and (axis2 == -1 or
                                                 axis2 == rank - 1):
        return utils.tensor_to_ndarray(tf.linalg.trace(a))

  a = array_ops.diagonal(a, offset, axis1, axis2)
  return array_ops.sum(a, -1, dtype)


@utils.np_doc(np.meshgrid)
def meshgrid(*xi, **kwargs):
  """This currently requires copy=True and sparse=False."""
  sparse = kwargs.get('sparse', False)
  if sparse:
    raise ValueError("tf.numpy doesn't support returning sparse arrays yet")

  copy = kwargs.get('copy', True)
  if not copy:
    raise ValueError('tf.numpy only supports copy=True')

  indexing = kwargs.get('indexing', 'xy')

  xi = [array_ops.asarray(arg).data for arg in xi]
  kwargs = {'indexing': indexing}

  outputs = tf.meshgrid(*xi, **kwargs)
  outputs = [utils.tensor_to_ndarray(output) for output in outputs]

  return outputs