• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

google / trax / 61

pending completion
61

Pull #531

travis-ci

web-flow
Merge 3eea35a3e into 0c21b50f5
Pull Request #531: Support incomplete shapes in swapaxes

8 of 8 new or added lines in 1 file covered. (100.0%)

2681 of 10743 relevant lines covered (24.96%)

0.25 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

25.94
/trax/tf_numpy/numpy/array_methods.py
1
# coding=utf-8
2
# Copyright 2020 The Trax Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Common array methods."""
1✔
17
from __future__ import absolute_import
1✔
18
from __future__ import division
1✔
19
from __future__ import print_function
1✔
20

21
import math
1✔
22
import numpy as np
1✔
23
import six
1✔
24
import tensorflow.compat.v2 as tf
1✔
25

26
from trax.tf_numpy.numpy import array_creation
1✔
27
from trax.tf_numpy.numpy import arrays
1✔
28
from trax.tf_numpy.numpy import dtypes
1✔
29
from trax.tf_numpy.numpy import utils
1✔
30

31

32
def all(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
  """Tests whether every element (optionally along `axis`) is truthy.

  The input is first converted to a bool array (if it is not one already) and
  then reduced with `tf.reduce_all`.

  Args:
    a: array_like. Could be an ndarray, a Tensor or any object that can
      be converted to a Tensor using `tf.convert_to_tensor`.
    axis: Optional. Could be an int or a tuple of integers. If not specified,
      the reduction is performed over all array indices.
    keepdims: If true, retains reduced dimensions with length 1.

  Returns:
    An ndarray. Note that unlike NumPy this does not return a scalar bool if
    `axis` is None.
  """
  bool_t = array_creation.asarray(a, dtype=bool).data
  reduced = tf.reduce_all(input_tensor=bool_t, axis=axis, keepdims=keepdims)
  return utils.tensor_to_ndarray(reduced)
52

53

54
def any(a, axis=None, keepdims=None):  # pylint: disable=redefined-builtin
  """Tests whether at least one element (optionally along `axis`) is truthy.

  The input is first converted to a bool array (if it is not one already) and
  then reduced with `tf.reduce_any`.

  Args:
    a: array_like. Could be an ndarray, a Tensor or any object that can
      be converted to a Tensor using `tf.convert_to_tensor`.
    axis: Optional. Could be an int or a tuple of integers. If not specified,
      the reduction is performed over all array indices.
    keepdims: If true, retains reduced dimensions with length 1.

  Returns:
    An ndarray. Note that unlike NumPy this does not return a scalar bool if
    `axis` is None.
  """
  bool_t = array_creation.asarray(a, dtype=bool).data
  reduced = tf.reduce_any(input_tensor=bool_t, axis=axis, keepdims=keepdims)
  return utils.tensor_to_ndarray(reduced)
74

75

76
def compress(condition, a, axis=None):
  """Keeps the slices of `a` along `axis` whose `condition` entry is true.

  Implemented with `tf.boolean_mask`.

  Args:
    condition: 1-d array of bools. When it is shorter than the selected axis
      (or than the flattened array if `axis` is None), the missing entries are
      treated as False.
    a: array_like. Could be an ndarray, a Tensor or any object that can
      be converted to a Tensor using `tf.convert_to_tensor`.
    axis: Optional. Axis along which to select elements. If None, `condition`
      is applied to the flattened array.

  Returns:
    An ndarray.

  Raises:
    ValueError: if `condition` is not of rank 1.
  """
  condition = array_creation.asarray(condition, dtype=bool)
  a = array_creation.asarray(a)

  if condition.ndim != 1:
    raise ValueError('condition must be a 1-d array.')

  # Scalars are handled like 1-element vectors (as `np.compress` does), and a
  # missing axis means "operate on the flattened array".
  if a.ndim == 0 or axis is None:
    a = ravel(a)

  if axis is None:
    axis = 0
  elif axis < 0:
    axis += a.ndim

  assert axis >= 0 and axis < a.ndim

  # `tf.boolean_mask` needs the mask length to match the axis length, so a
  # short condition is extended with False values.
  mask_t = condition.data
  shortfall = a.shape[axis] - condition.shape[0]
  if shortfall > 0:
    mask_t = tf.concat([mask_t, tf.fill([shortfall], False)], axis=0)

  selected = tf.boolean_mask(tensor=a.data, mask=mask_t, axis=axis)
  return utils.tensor_to_ndarray(selected)
123

124

125
def copy(a):
  """Returns a new array holding a copy of the contents of `a`."""
  duplicate = array_creation.array(a, copy=True)
  return duplicate
×
128

129

130
def _cumulative_op(tf_fn, a, axis, dtype):
  """Shared implementation of `cumprod` and `cumsum`.

  Args:
    tf_fn: the TF scan function, called as `tf_fn(tensor, axis)`.
    a: array_like. Could be an ndarray, a Tensor or any object that can
      be converted to a Tensor using `tf.convert_to_tensor`.
    axis: Optional int axis. If None, the operation runs on the flattened
      array.
    dtype: Optional output dtype. If None, integer and bool inputs narrower
      than `int` are promoted to `int`.

  Returns:
    An ndarray with the same number of elements as `a`.
  """
  a = array_creation.asarray(a, dtype=dtype)

  # Integer inputs narrower than `int` are promoted to `int`. Bool inputs are
  # promoted the same way (numpy returns integers for them, and the TF
  # cumulative ops do not accept bool tensors).
  if dtype is None and (a.dtype == np.bool_ or
                        tf.as_dtype(a.dtype).is_integer):
    output_type = np.promote_types(a.dtype, int)
    if output_type != a.dtype:
      a = array_creation.asarray(a, dtype=output_type)

  # If axis is None, the input is flattened.
  if axis is None:
    a = ravel(a)
    axis = 0
  if axis < 0:
    axis += a.ndim
  assert axis >= 0 and axis < a.ndim
  return utils.tensor_to_ndarray(tf_fn(a.data, axis))


def cumprod(a, axis=None, dtype=None):
  """Returns cumulative product of `a` along an axis or the flattened array.

  Uses `tf.cumprod`.

  Args:
    a: array_like. Could be an ndarray, a Tensor or any object that can
      be converted to a Tensor using `tf.convert_to_tensor`.
    axis: Optional. Axis along which to compute products. If None, operation is
      performed on the flattened array.
    dtype: Optional. The type of the output array. If None, defaults to the
      dtype of `a` unless `a` is a bool or an integer type with precision less
      than `int`, in which case the output type is `int`.

  Returns:
    An ndarray with the same number of elements as `a`. If `axis` is None, the
    output is a 1-d array, else it has the same shape as `a`.
  """
  return _cumulative_op(tf.math.cumprod, a, axis, dtype)


def cumsum(a, axis=None, dtype=None):
  """Returns cumulative sum of `a` along an axis or the flattened array.

  Uses `tf.cumsum`.

  Args:
    a: array_like. Could be an ndarray, a Tensor or any object that can
      be converted to a Tensor using `tf.convert_to_tensor`.
    axis: Optional. Axis along which to compute sums. If None, operation is
      performed on the flattened array.
    dtype: Optional. The type of the output array. If None, defaults to the
      dtype of `a` unless `a` is a bool or an integer type with precision less
      than `int`, in which case the output type is `int`.

  Returns:
    An ndarray with the same number of elements as `a`. If `axis` is None, the
    output is a 1-d array, else it has the same shape as `a`.
  """
  return _cumulative_op(tf.cumsum, a, axis, dtype)
×
202

203

204
def imag(a):
  """Extracts the imaginary component of every element of `a`.

  Implemented with `tf.math.imag`.

  Args:
    a: array_like. Could be an ndarray, a Tensor or any object that can
      be converted to a Tensor using `tf.convert_to_tensor`.

  Returns:
    An ndarray with the same shape as `a`.
  """
  # TODO(srbs): np.imag returns a scalar if a is a scalar, whereas we always
  # return an ndarray.
  tensor = array_creation.asarray(a).data
  return utils.tensor_to_ndarray(tf.math.imag(tensor))
×
220

221

222
# Integer/bool promotion policies accepted by `_reduce` via `promote_int`.
_TO_INT64, _TO_FLOAT = 0, 1
1✔
224

225

226
def _reduce(tf_fn, a, axis=None, dtype=None, keepdims=None,
            promote_int=_TO_INT64, tf_bool_fn=None, preserve_bool=False):
  """Shared driver for the numpy-style reductions in this module.

  Args:
    tf_fn: the TF reduction function.
    a: the array to be reduced.
    axis: (optional) the axis along which to do the reduction. If None, all
      dimensions are reduced.
    dtype: (optional) the dtype of the result.
    keepdims: (optional) whether to keep the reduced dimension(s).
    promote_int: how to promote integer and bool inputs. There are three
      choices: (1) _TO_INT64: always promote them to int64 or uint64; (2)
      _TO_FLOAT: always promote them to a float type (determined by
      dtypes.default_float_type); (3) None: don't promote.
    tf_bool_fn: (optional) the TF reduction function for bool inputs. It
      will only be used if `dtype` is explicitly set to `np.bool_` or if `a`'s
      dtype is `np.bool_` and `preserve_bool` is True.
    preserve_bool: a flag to control whether to use `tf_bool_fn` if `a`'s dtype
      is `np.bool_` (some reductions such as np.sum convert bools to
      integers, while others such as np.max preserve bools).

  Returns:
    An ndarray.
  """
  if dtype:
    dtype = utils.result_type(dtype)
  keepdims = False if keepdims is None else keepdims
  a = array_creation.asarray(a, dtype=dtype)

  # A dedicated bool reduction is used when the caller asked for a bool result,
  # or when the input is bool and this reduction keeps bools as bools.
  wants_bool_result = dtype == np.bool_ or (
      preserve_bool and a.dtype == np.bool_)
  if tf_bool_fn is not None and wants_bool_result:
    bool_result = tf_bool_fn(input_tensor=a.data, axis=axis, keepdims=keepdims)
    return utils.tensor_to_ndarray(bool_result)

  if dtype is None:
    source_dtype = a.dtype
    is_bool = source_dtype == np.bool_
    if is_bool or np.issubdtype(source_dtype, np.integer):
      if promote_int == _TO_INT64:
        # numpy widens integer/bool inputs narrower than 64 bits to 64 bits.
        if is_bool:
          signed, bits = True, 8  # Any width below 64 works here.
        else:
          signed = np.issubdtype(source_dtype, np.signedinteger)
          bits = np.iinfo(source_dtype).bits
        if bits < 64:
          a = a.astype(np.int64 if signed else np.uint64)
      elif promote_int == _TO_FLOAT:
        a = a.astype(dtypes.default_float_type())

  reduced = tf_fn(input_tensor=a.data, axis=axis, keepdims=keepdims)
  return utils.tensor_to_ndarray(reduced)
283

284

285
@utils.np_doc(np.sum)
def sum(a, axis=None, dtype=None, keepdims=None):  # pylint: disable=redefined-builtin
  # With an explicit bool dtype the sum degenerates to a logical OR.
  return _reduce(
      tf.reduce_sum, a,
      axis=axis, dtype=dtype, keepdims=keepdims,
      tf_bool_fn=tf.reduce_any)
289

290

291
@utils.np_doc(np.prod)
def prod(a, axis=None, dtype=None, keepdims=None):
  # With an explicit bool dtype the product degenerates to a logical AND.
  return _reduce(
      tf.reduce_prod, a,
      axis=axis, dtype=dtype, keepdims=keepdims,
      tf_bool_fn=tf.reduce_all)
295

296

297
@utils.np_doc(np.mean)
def mean(a, axis=None, dtype=None, keepdims=None):
  # Integer and bool inputs are averaged in the default float type.
  return _reduce(
      tf.math.reduce_mean, a,
      axis=axis, dtype=dtype, keepdims=keepdims,
      promote_int=_TO_FLOAT)
301

302

303
@utils.np_doc(np.amax)
def amax(a, axis=None, keepdims=None):
  # Integers keep their dtype; bool inputs stay bool, and the maximum of
  # booleans is their logical OR.
  return _reduce(
      tf.reduce_max, a,
      axis=axis, dtype=None, keepdims=keepdims,
      promote_int=None, preserve_bool=True, tf_bool_fn=tf.reduce_any)
307

308

309
@utils.np_doc(np.amin)
def amin(a, axis=None, keepdims=None):
  # Integers keep their dtype; bool inputs stay bool, and the minimum of
  # booleans is their logical AND.
  return _reduce(
      tf.reduce_min, a,
      axis=axis, dtype=None, keepdims=keepdims,
      promote_int=None, preserve_bool=True, tf_bool_fn=tf.reduce_all)
313

314

315
@utils.np_doc(np.var)
def var(a, axis=None, keepdims=None):
  # Integer and bool inputs are promoted to the default float type first.
  return _reduce(
      tf.math.reduce_variance, a,
      axis=axis, dtype=None, keepdims=keepdims,
      promote_int=_TO_FLOAT)
319

320

321
@utils.np_doc(np.std)
def std(a, axis=None, keepdims=None):
  # Integer and bool inputs are promoted to the default float type first.
  return _reduce(
      tf.math.reduce_std, a,
      axis=axis, dtype=None, keepdims=keepdims,
      promote_int=_TO_FLOAT)
325

326

327
def ravel(a):
  """Returns `a` flattened to one dimension.

  An input that is already a 1-d ndarray is returned unchanged (the same
  object); anything else is reshaped with `tf.reshape`.

  Args:
    a: array_like. Could be an ndarray, a Tensor or any object that can
      be converted to a Tensor using `tf.convert_to_tensor`.

  Returns:
    A 1-d ndarray.
  """
  arr = array_creation.asarray(a)
  if arr.ndim != 1:
    arr = utils.tensor_to_ndarray(tf.reshape(arr.data, [-1]))
  return arr
×
345

346

347
def real(val):
  """Extracts the real component of every element of `val`.

  Implemented with `tf.math.real`.

  Args:
    val: array_like. Could be an ndarray, a Tensor or any object that can
      be converted to a Tensor using `tf.convert_to_tensor`.

  Returns:
    An ndarray with the same shape as `val`.
  """
  # TODO(srbs): np.real returns a scalar if val is a scalar, whereas we always
  # return an ndarray.
  tensor = array_creation.asarray(val).data
  return utils.tensor_to_ndarray(tf.math.real(tensor))
×
363

364

365
@utils.np_doc(np.repeat)
def repeat(a, repeats, axis=None):  # pylint: disable=missing-docstring
  source_t = array_creation.asarray(a).data
  counts_t = array_creation.asarray(repeats).data
  repeated = tf.repeat(source_t, counts_t, axis)
  return utils.tensor_to_ndarray(repeated)
×
370

371

372
@utils.np_doc(np.around)
def around(a, decimals=0):  # pylint: disable=missing-docstring
  a = array_creation.asarray(a)
  if np.issubdtype(a.dtype, np.integer):
    # Integer inputs keep their integer dtype, as in numpy. The generic path
    # below is wrong for them for two reasons:
    #  - casting a fractional `10**decimals` factor to an integer dtype gives
    #    0, so negative `decimals` computes 0/0;
    #  - `tf.math.divide` true-divides, so the result would become a float.
    if decimals >= 0:
      # Integers have no fractional digits, so the values are unchanged.
      return utils.tensor_to_ndarray(a.data)
    # Round to a multiple of 10**-decimals in float64, then cast back
    # (numpy follows the same divide / rint / multiply / cast sequence).
    factor = math.pow(10, -decimals)
    a_t = tf.round(tf.cast(a.data, tf.float64) / factor) * factor
    return utils.tensor_to_ndarray(tf.cast(a_t, a.dtype))
  factor = math.pow(10, decimals)
  factor = tf.cast(factor, a.dtype)
  a_t = tf.multiply(a.data, factor)
  # `tf.round` rounds half to even, matching numpy.
  a_t = tf.round(a_t)
  a_t = tf.math.divide(a_t, factor)
  return utils.tensor_to_ndarray(a_t)


round_ = around
setattr(arrays.ndarray, '__round__', around)
1✔
385

386

387
def reshape(a, newshape):
  """Gives `a` a new shape without changing its data.

  Args:
    a: array_like. Could be an ndarray, a Tensor or any object that can
      be converted to a Tensor using `tf.convert_to_tensor`.
    newshape: 0-d or 1-d array_like.

  Returns:
    An ndarray with the contents and dtype of `a` and shape `newshape`.
  """
  arr = array_creation.asarray(a)
  # ndarray shapes are unwrapped to their underlying tensor for tf.reshape.
  if isinstance(newshape, arrays.ndarray):
    shape_arg = newshape.data
  else:
    shape_arg = newshape
  return utils.tensor_to_ndarray(tf.reshape(arr.data, shape_arg))
×
402

403

404
def expand_dims(a, axis):
  """Inserts a new length-1 dimension into `a` at position `axis`.

  Args:
    a: array_like. Could be an ndarray, a Tensor or any object that can
      be converted to a Tensor using `tf.convert_to_tensor`.
    axis: int. axis on which to expand the shape.

  Returns:
    An ndarray with the contents and dtype of `a` and shape expanded on axis.
  """
  arr = array_creation.asarray(a)
  expanded = tf.expand_dims(arr.data, axis=axis)
  return utils.tensor_to_ndarray(expanded)
×
417

418

419
def squeeze(a, axis=None):
  """Drops length-1 dimensions from `a`.

  Args:
    a: array_like. Could be an ndarray, a Tensor or any object that can
      be converted to a Tensor using `tf.convert_to_tensor`.
    axis: scalar or list/tuple of ints.

  TODO(srbs): tf.squeeze throws error when axis is a Tensor eager execution
  is enabled. So we cannot allow axis to be array_like here. Fix.

  Returns:
    An ndarray.
  """
  arr = array_creation.asarray(a)
  squeezed = tf.squeeze(arr.data, axis)
  return utils.tensor_to_ndarray(squeezed)
×
435

436

437
def transpose(a, axes=None):
  """Reorders the dimensions of `a`.

  Args:
    a: array_like. Could be an ndarray, a Tensor or any object that can
      be converted to a Tensor using `tf.convert_to_tensor`.
    axes: array_like. A list of ints with length rank(a) or None specifying the
      order of permutation. The i'th dimension of the output array corresponds
      to axes[i]'th dimension of the `a`. If None, the axes are reversed.

  Returns:
    An ndarray.
  """
  arr = array_creation.asarray(a)
  perm = None if axes is None else array_creation.asarray(axes).data
  return utils.tensor_to_ndarray(tf.transpose(a=arr.data, perm=perm))
×
454

455

456
@utils.np_doc(np.swapaxes)
def swapaxes(a, axis1, axis2):  # pylint: disable=missing-docstring
  a_t = array_creation.asarray(a).data

  # The rank is taken as a tensor, so negative axes are wrapped dynamically
  # and the code does not rely on a statically known shape.
  rank = tf.rank(a_t)
  if axis1 < 0:
    axis1 += rank
  if axis2 < 0:
    axis2 += rank

  # Build the identity permutation and exchange the two chosen entries.
  identity_perm = tf.range(rank)
  perm = tf.tensor_scatter_nd_update(
      identity_perm, [[axis1], [axis2]], [axis2, axis1])
  return utils.tensor_to_ndarray(tf.transpose(a_t, perm))
×
471

472

473
def _setitem(arr, index, value):
  """Sets the `value` at `index` in the array `arr`.

  This works by replacing the slice at `index` in the tensor with `value`.
  Since tensors are immutable, this builds a new tensor using the `tf.concat`
  op. Currently, only 0-d and 1-d indices are supported. A negative leading
  index counts from the end of the first dimension.

  Note that this may break gradients e.g.

  a = tf_np.array([1, 2, 3])
  old_a_t = a.data

  with tf.GradientTape(persistent=True) as g:
    g.watch(a.data)
    b = a * 2
    a[0] = 5
  g.gradient(b.data, [a.data])  # [None]
  g.gradient(b.data, [old_a_t])  # [[2., 2., 2.]]

  Here `d_b / d_a` is `[None]` since a.data no longer points to the same
  tensor.

  Args:
    arr: array_like.
    index: scalar or 1-d integer array.
    value: value to set at index.

  Returns:
    ndarray

  Raises:
    ValueError: if `index` is not a scalar or 1-d array.
  """
  # TODO(srbs): Figure out a solution to the gradient problem.
  arr = array_creation.asarray(arr)
  index = array_creation.asarray(index)
  if index.ndim == 0:
    index = ravel(index)
  elif index.ndim > 1:
    raise ValueError('index must be a scalar or a 1-d array.')
  value = array_creation.asarray(value, dtype=arr.dtype)
  if arr.shape[len(index):] != value.shape:
    # Broadcast a scalar/smaller value to the shape of the replaced slice.
    value = array_creation.full(arr.shape[len(index):], value)

  # Normalize a negative leading index before slicing. Otherwise, for -1 the
  # postfix slice `[first + 1:]` becomes `[0:]`, which is the whole array,
  # and the result would contain the original data twice.
  first = index.data[0]
  dim0 = tf.cast(tf.shape(arr.data)[0], first.dtype)
  first = tf.where(first < 0, first + dim0, first)

  prefix_t = arr.data[:first]
  postfix_t = arr.data[first + 1:]
  if len(index) == 1:
    arr._data = tf.concat(  # pylint: disable=protected-access
        [prefix_t, tf.expand_dims(value.data, 0), postfix_t], 0)
  else:
    # Recurse into the selected row, then splice the updated row back in.
    subarray = arr[first]
    _setitem(subarray, index[1:], value)
    arr._data = tf.concat(  # pylint: disable=protected-access
        [prefix_t, tf.expand_dims(subarray.data, 0), postfix_t], 0)
526

527

528
# Expose these module functions as ndarray methods / item assignment.
arrays.ndarray.transpose = transpose
arrays.ndarray.reshape = reshape
arrays.ndarray.__setitem__ = _setitem
1✔
531

532

533
def pad(array, pad_width, mode, constant_values=0):
  """Pads an array.

  Args:
    array: array_like of rank N. Input array.
    pad_width: {sequence, array_like, int}.
      Number of values padded to the edges of each axis.
      ((before_1, after_1), ... (before_N, after_N)) unique pad widths
      for each axis.
      ((before, after),) yields same before and after pad for each axis.
      (pad,) or int is a shortcut for before = after = pad width for all
      axes.
    mode: string. One of the following string values:
      'constant'
          Pads with a constant value.
      'reflect'
          Pads with the reflection of the vector mirrored on
          the first and last values of the vector along each
          axis.
      'symmetric'
          Pads with the reflection of the vector mirrored
          along the edge of the array.
      **NOTE**: The supported list of `mode` does not match that of numpy's.
    constant_values: scalar with same dtype as `array`.
      Used in 'constant' mode as the pad value.  Default is 0.


  Returns:
    An ndarray padded array of rank equal to `array` with shape increased
    according to `pad_width`.

  Raises:
    ValueError: if `mode` is not supported.
  """
  if mode not in ('constant', 'reflect', 'symmetric'):
    # Use %-formatting so that a non-string `mode` (e.g. None) also raises the
    # documented ValueError rather than a TypeError from `str + None`.
    raise ValueError('Unsupported padding mode: %s' % (mode,))
  # tf.pad expects the upper-case mode names.
  mode = mode.upper()
  array = array_creation.asarray(array)
  pad_width = array_creation.asarray(pad_width, dtype=tf.int32)
  return utils.tensor_to_ndarray(tf.pad(
      tensor=array.data, paddings=pad_width.data, mode=mode,
      constant_values=constant_values))
575

576

577
def take(a, indices, axis=None):
  """Take elements from an array along an axis.

  See https://docs.scipy.org/doc/numpy/reference/generated/numpy.take.html for
  description. As in numpy, negative indices count from the end of the axis.

  Args:
    a: array_like. The source array.
    indices: array_like. The indices of the values to extract.
    axis: int, optional. The axis over which to select values. By default, the
      flattened input array is used.

  Returns:
    A ndarray. The returned array has the same type as `a`.
  """
  a = array_creation.asarray(a)
  indices = array_creation.asarray(indices)
  a_t = a.data
  if axis is None:
    a_t = tf.reshape(a_t, [-1])
    axis = 0
  indices_t = indices.data
  # `tf.gather` does not wrap negative indices (it rejects them on CPU), so
  # convert them to their non-negative equivalents first.
  axis_size = tf.cast(tf.shape(a_t)[axis], indices_t.dtype)
  indices_t = tf.where(indices_t < 0, indices_t + axis_size, indices_t)
  return utils.tensor_to_ndarray(tf.gather(a_t, indices_t, axis=axis))
×
599

600

601
def where(condition, x, y):
  """Selects elementwise from `x` where `condition` holds, else from `y`.

  Args:
    condition: array_like, bool. Where True, yield `x`, otherwise yield `y`.
    x: array_like. Values chosen where `condition` is True.
    y: array_like. Values chosen where `condition` is False. `x`, `y` and
      `condition` need to be broadcastable to some shape.

  Returns:
    An array. `x` and `y` are first promoted to a common dtype.
  """
  cond_t = array_creation.asarray(condition, dtype=np.bool_).data
  x, y = array_creation._promote_dtype(x, y)  # pylint: disable=protected-access
  return utils.tensor_to_ndarray(tf.where(cond_t, x.data, y.data))
×
616

617

618
def shape(a):
  """Reports the shape of `a` after converting it to an ndarray.

  Args:
    a: array_like. Input array.

  Returns:
    Tuple of ints.
  """
  return array_creation.asarray(a).shape
×
629

630

631
def ndim(a):
  """Returns the number of dimensions of `a` (converted to an ndarray)."""
  return array_creation.asarray(a).ndim
×
634

635

636
def isscalar(a):
  """Returns whether `a` has zero dimensions.

  This is a rank check: unlike `np.isscalar`, a 0-d array counts as a scalar.
  """
  rank = ndim(a)
  return rank == 0
×
638

639

640
def _boundaries_to_sizes(a, boundaries, axis):
  """Converts split boundaries (as in np.split) to split sizes (as in tf.split).

  Boundaries follow numpy slicing semantics: a negative boundary counts from
  the end of the axis, and boundaries outside `[0, a.shape[axis]]` are clamped
  to that range (which yields empty splits).

  Args:
    a: the array to be split. Only `a.shape` is read.
    boundaries: the boundaries, as in np.split.
    axis: the axis along which to split. May be negative.

  Returns:
    A list of sizes of the splits, as in tf.split. It has one more entry than
    `boundaries`, and the entries sum to `a.shape[axis]`.

  Raises:
    ValueError: if `axis` is out of bounds for `a`, or if a (normalized)
      boundary is smaller than the previous one. numpy would return
      overlapping sub-arrays in that case, which tf.split cannot express.
  """
  ndim = len(a.shape)
  if not -ndim <= axis < ndim:
    raise ValueError('axis %s is out of bound for shape %s' % (axis, a.shape))
  total_size = a.shape[axis]
  sizes = []
  prev = 0
  for i, b in enumerate(boundaries):
    if b < 0:
      b += total_size
    # Clamp into [0, total_size], mirroring how Python slicing clamps bounds.
    b = min(max(b, 0), total_size)
    size = b - prev
    if size < 0:
      raise ValueError('The %s-th boundary %s is smaller than the previous '
                       'boundary %s' % (i, b, prev))
    sizes.append(size)
    prev = b
  sizes.append(total_size - prev)
  return sizes
×
668

669

670
def split(a, indices_or_sections, axis=0):
  """Split an array into multiple sub-arrays.

  See https://docs.scipy.org/doc/numpy/reference/generated/numpy.split.html for
  reference.

  Args:
    a: the array to be split.
    indices_or_sections: int (including numpy integer scalars) or 1-D array,
      representing the number of even splits or the boundaries between
      splits.
    axis: the axis along which to split.

  Returns:
    A list of sub-arrays.
  """
  a = array_creation.asarray(a)
  # numpy integer scalars (e.g. np.int64) are not Python ints on Python 3, but
  # they still mean "number of sections", not a sequence of boundaries.
  if isinstance(indices_or_sections, np.integer):
    indices_or_sections = int(indices_or_sections)
  if not isinstance(indices_or_sections, six.integer_types):
    indices_or_sections = _boundaries_to_sizes(a, indices_or_sections, axis)
  result = tf.split(a.data, indices_or_sections, axis=axis)
  return [utils.tensor_to_ndarray(piece) for piece in result]
×
690

691

692
@utils.np_doc(np.broadcast_to)
def broadcast_to(array, shape):  # pylint: disable=redefined-outer-name
  # Filling an array of the target shape with `array` performs the broadcast.
  broadcasted = array_creation.full(shape, array)
  return broadcasted
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2024 Coveralls, Inc