• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

harvardnlp / namedtensor / 235

pending completion
235

Pull #59

travis-ci

web-flow
.
Pull Request #59: Rewrite of test suite

183 of 183 new or added lines in 4 files covered. (100.0%)

249 of 1190 relevant lines covered (20.92%)

0.21 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

2.28
/namedtensor/test_core.py
1
from . import assert_match, ntorch, NamedTensor
1✔
2
import numpy as np
1✔
3
import torch
1✔
4
from collections import OrderedDict
1✔
5
import pytest
1✔
6
import torch.nn.functional as F
1✔
7
from hypothesis import given, example
1✔
8
from hypothesis.strategies import text, composite, sets, sampled_from, data, lists, permutations, integers, floats, booleans
×
9
from hypothesis.extra.numpy import arrays, array_shapes
×
10

11

12
# Setup Hypothesis helpers
def named_tensor(dtype=np.float64, shape=array_shapes(2, 5, max_side=5)):
    """Hypothesis strategy: a random NamedTensor with dims named from 'abc'.

    ``np.float64`` replaces the deprecated ``np.float`` alias (removed in
    NumPy 1.24); ``np.float`` was the builtin ``float``, whose numpy dtype
    is float64, so generated data is unchanged.

    Args:
        dtype: numpy dtype of the generated array.
        shape: strategy for the array shape (2-5 dims, sides up to 5).
    """
    @composite
    def name(draw, array):
        array = draw(array)
        # One unique short name per axis.
        names = draw(lists(text(min_size=1, alphabet="abc"),
                           max_size=len(array.shape),
                           min_size=len(array.shape),
                           unique=True))
        return ntorch.tensor(array, names=names)
    # Exclude nan/inf so value-comparing tests are deterministic.
    return name(arrays(dtype, shape, elements=floats(allow_nan=False,
                                                     allow_infinity=False)))
24

25
def dim(tensor):
    """Hypothesis strategy: one existing dimension name of ``tensor``."""
    return sampled_from([*tensor.shape])
27

28
def dims(tensor, max_size=5):
    """Hypothesis strategy: 2..max_size distinct existing dim names."""
    return lists(dim(tensor), min_size=2, max_size=max_size, unique=True)
30

31
def name(tensor):
    """Hypothesis strategy: a fresh dim name not already used by ``tensor``."""
    def fresh(candidate):
        return candidate not in tensor.shape
    return text(alphabet="abc", min_size=1).filter(fresh)
33

34
def names(tensor, max_size=5):
    """Hypothesis strategy: 2..max_size distinct fresh dim names."""
    return lists(name(tensor), min_size=2, max_size=max_size, unique=True)
36

37
def broadcast_named_tensor(x, dtype=np.float64):
    """Hypothesis strategy: a tensor that broadcasts against ``x`` by name.

    The result shares up to 2 of ``x``'s dims (with matching sizes) and adds
    up to 2 fresh dims of size 1-4, in a random axis order.

    ``np.float64`` replaces the ``np.float`` alias removed in NumPy 1.24
    (same dtype as before).
    """
    @composite
    def fill(draw):
        shared = draw(dims(x, max_size=2))
        fresh = draw(names(x, max_size=2))
        perm = draw(permutations(range(len(fresh) + len(shared))))

        def reorder(ls):
            return [ls[perm[i]] for i in range(len(ls))]

        sizes = draw(lists(integers(min_value=1, max_value=4),
                           min_size=len(fresh), max_size=len(fresh)))
        shape = reorder([x.shape[d] for d in shared] + sizes)
        # Renamed from `np`, which shadowed the numpy module import.
        values = draw(arrays(dtype, shape=shape))

        return ntorch.tensor(values, names=reorder(shared + fresh))
    return fill()
52

53

54
def mask_named_tensor(x, dtype=np.uint8):
    """Hypothesis strategy: a 0/1 byte mask over a random subset of x's dims.

    Draws 2 of ``x``'s dims, permutes them, and fills the matching shape
    with 0/1 values.
    """
    @composite
    def fill(draw):
        ds = draw(dims(x, max_size=2))
        perm = draw(permutations(range(len(ds))))

        def reorder(ls):
            return [ls[perm[i]] for i in range(len(ls))]

        shape = reorder([x.shape[d] for d in ds])
        # Renamed from `np`, which shadowed the numpy module import.
        values = draw(arrays(dtype, shape, integers(min_value=0, max_value=1)))

        return ntorch.tensor(values, names=reorder(ds)).byte()
    return fill()
66

67

68
@pytest.mark.xfail
def test_unique_names():
    """Duplicate dim names must be rejected by the constructor."""
    raw = torch.zeros([10, 2])
    assert ntorch.tensor(raw, ("alpha", "beta", "alpha"))
72

73
def test_names():
    """Constructing a tensor with one distinct name per axis succeeds."""
    raw = torch.zeros([10, 2, 50])
    assert ntorch.tensor(raw, ("alpha", "beta", "gamma"))
76

77
@pytest.mark.xfail
def test_bad_names():
    """Construction with these names is expected to fail.

    NOTE(review): the name count matches the shape here, so presumably
    these particular names are reserved/invalid — confirm against ntorch.
    """
    raw = torch.zeros([10, 2])
    assert ntorch.tensor(raw, ("elements_dim", "input_dims"))
81

82

83
@given(data(), named_tensor())
def test_stack_basic(data, x):
    """Stacking dims under a new name removes the old dims."""
    old_dims = data.draw(dims(x))
    new_dim = data.draw(name(x))
    stacked = x.stack(list(old_dims), new_dim)
    assert new_dim in stacked.dims
    assert not (stacked.shape.keys() & old_dims)
90

91
@given(data(), named_tensor())
def test_rename(data, x):
    """rename swaps an existing dim name for a fresh one."""
    old = data.draw(dim(x))
    new = data.draw(name(x))
    renamed = x.rename(old, new)
    assert new in renamed.dims
    assert old not in renamed.dims
98

99

100
@given(data(), named_tensor())
def test_split(data, x):
    """Splitting one dim into several preserves the total element count."""
    src = data.draw(dim(x))
    parts = list(data.draw(names(x)))
    # All but the last split dim get size 1; the last absorbs the rest.
    sizes = {part: 1 for part in parts[:-1]}
    out = x.split(src, parts, **sizes)
    assert len(set(parts) & set(out.dims)) == len(parts)
    assert src not in out.dims
    assert torch.prod(torch.tensor([out.shape[p] for p in parts])) == x.shape[src]
108

109
@given(data(), named_tensor())
def test_reduce(data, x):
    """Named reductions drop exactly the reduced dims."""
    reduce_dims = data.draw(dims(x))
    method = data.draw(sampled_from(sorted(x._reduce)))

    if method not in ["logsumexp"]:
        # Full reduction should at least run without error.
        y = getattr(x, method)()
        # assert y.values == getattr(x.values, method)()

    out = getattr(x, method)(tuple(reduce_dims))
    assert set(out.dims) | set(reduce_dims) == set(x.dims)
120

121

122
@given(data(), named_tensor())
def test_binary_op(data, x):
    """Binary ops broadcast by name regardless of argument order."""
    y = data.draw(broadcast_named_tensor(x))
    op = data.draw(sampled_from(sorted(x._binop)))
    expected_dims = set(x.dims) | set(y.dims)
    assert set(getattr(x, op)(y).dims) == expected_dims
    assert set(getattr(y, op)(x).dims) == expected_dims
130

131
@given(data(), named_tensor())
def test_noshift(data, x):
    """Shape-preserving ops keep the dim set ("cuda" excluded: no GPU here)."""
    candidates = sampled_from(sorted(x._noshift))
    op = data.draw(candidates.filter(lambda m: m not in {"cuda"}))
    assert set(getattr(x, op)().dims) == set(x.dims)
136

137
@given(data(), named_tensor())
def test_apply(data, x):
    """Dim-taking shape-preserving ops leave the full shape untouched."""
    op = data.draw(sampled_from(sorted(x._noshift_dim | x._noshift_nn_dim)))
    target = data.draw(dim(x))
    out = getattr(x, op)(target)
    assert x.shape == out.shape
143

144
def test_apply2():
    """op() lifts a torch function; softmax over a dim must sum to one."""
    ntensor = ntorch.tensor(torch.zeros([10, 2, 50]), ("alpha", "beta", "gamma"))
    softmaxed = ntensor.op(F.softmax, dim="alpha")
    assert (ntorch.abs(softmaxed.sum("alpha") - 1.0) < 1e-5).all()
149

150
@given(named_tensor())
def test_sum(x):
    """Full sum matches the reduction on the underlying torch tensor.

    Dropped a leftover debug ``print(x.shape)``.
    """
    s = x.sum()
    assert s.values == x.values.sum()
155

156

157
def test_fill():
    """In-place fill_ writes the value into every element."""
    ntensor = ntorch.tensor(torch.zeros([10, 2, 50]), ("alpha", "beta", "gamma"))
    ntensor.fill_(20)
    assert (ntensor == 20).all()
162

163
@given(data(), named_tensor())
def test_mask(data, x):
    """Smoke test: byte-mask selection via method call and via indexing."""
    mask = data.draw(mask_named_tensor(x))
    selected = x.masked_select(mask, "c")
    selected = x[mask]
168

169
@pytest.mark.xfail
@given(data(), named_tensor())
def test_maskfail(data, x):
    """A broadcast (non-mask) tensor used as a mask is expected to fail.

    Fixed signature: the original took no parameters, so ``data`` and ``x``
    were undefined and the xfail passed for the wrong reason (hypothesis
    signature error, not the intended masking failure).
    """
    mask = data.draw(broadcast_named_tensor(x))
    x2 = x.masked_select(mask, "c")
    x2 = x[mask]
175

176

177
@given(data(), named_tensor(), floats(allow_nan=False, allow_infinity=False))
def test_all_scalar_ops(data, x, y):
    """Smoke test: scalar arithmetic on both sides, plus unary negation.

    No value assertions — only checks that the operator overloads run.
    """
    x = x + y
    x = x - y
    x = x * y
    x = x / y

    # Reflected (scalar-on-the-left) variants.
    x = y + x
    x = y - x
    x = y * x
    # NOTE(review): y / x is not exercised — presumably to avoid division
    # by zero elements; confirm before adding it.

    x = -x
189

190
def test_gather():
    """gather/scatter_ mirror torch semantics using named index dims."""
    plain = torch.Tensor([[1, 2], [3, 4]])
    expected = torch.gather(plain, 1, torch.LongTensor([[0, 0], [1, 0]]))

    named = ntorch.tensor(torch.Tensor([[1, 2], [3, 4]]), ("a", "b"))
    index = ntorch.tensor(torch.LongTensor([[0, 0], [1, 0]]), ("a", "c"))
    gathered = ntorch.gather(named, "b", index, "c")
    assert (gathered.values == expected).all()
    assert gathered.shape == OrderedDict([("a", 2), ("c", 2)])

    src = ntorch.tensor(torch.rand(2, 5), ("c", "b"))
    dst = ntorch.tensor(torch.rand(3, 5), ("a", "b"))
    scatter_index = ntorch.tensor(
        torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), ("c", "b")
    )
    # In-place scatter of src into dst along "a", indexed over "c".
    dst.scatter_("a", scatter_index, src, "c")
    assert dst.shape == OrderedDict([("a", 3), ("b", 5)])
211

212

213
def test_cat():
    """cat concatenates along the named dim, aligning the other dims."""
    first = ntorch.zeros(20, 10, names=("a", "b"))
    second = ntorch.ones(30, 20, names=("b", "a"))
    joined = ntorch.cat([first, second], dim="b")
    assert joined.shape == OrderedDict([("a", 20), ("b", 40)])
219

220

221
def test_stack():
    """stack adds a new dim and aligns inputs by name, not position."""
    first = ntorch.tensor(
        torch.Tensor([[1, 2], [3, 4], [5, 6]]), ("dim1", "dim2")
    )
    second = ntorch.tensor(
        torch.Tensor([[1, 2, 3], [4, 5, 6]]), ("dim2", "dim1")
    )
    stacked = ntorch.stack([first, second], "dim3")
    assert stacked.shape == OrderedDict(
        [("dim3", 2), ("dim1", 3), ("dim2", 2)]
    )
232

233

234
def test_unbind():
    """unbind splits a dim into that many tensors without it."""
    ntensor = ntorch.tensor(torch.zeros([10, 2, 50]), ("alpha", "beta", "gamma"))
    pieces = ntensor.unbind("beta")
    assert len(pieces) == 2
    assert pieces[0].shape == OrderedDict([("alpha", 10), ("gamma", 50)])

    # Unbinding a 1-d tensor yields scalars.
    vec = ntorch.tensor(torch.zeros([10]), ("alpha",))
    vec.fill_(20)
    scalars = vec.unbind("alpha")
    assert len(scalars) == 10
    assert scalars[0].item() == 20
247

248

249
@pytest.mark.xfail
def test_fail():
    """Mismatched dim names ("beta" vs "beat") must fail assert_match.

    NOTE(review): ``make_tensors`` is neither defined nor imported in this
    file, so this xfail currently passes via NameError rather than the
    intended name-mismatch failure — confirm and restore the helper.
    """
    for base1, base2 in zip(
        make_tensors([10, 2, 50]), make_tensors([10, 20, 2])
    ):
        ntensor1 = NamedTensor(base1, ("alpha", "beta", "gamma"))
        # "beat" is the deliberate typo that should break the match.
        ntensor2 = NamedTensor(base2, ("alpha", "beat", "gamma"))
        assert_match(ntensor1, ntensor2)
257

258

259
def test_multiple():
    """A named binop broadcasts, then transposes back to a fixed order."""
    base1 = torch.rand([10, 2, 50])
    base2 = torch.rand([10, 20, 2])
    ntensor1 = ntorch.tensor(base1, ("alpha", "beta", "gamma"))
    ntensor2 = ntorch.tensor(base2, ("alpha", "delta", "beta"))
    assert_match(ntensor1, ntensor2)

    # Reference result of the broadcasted multiply, built with raw views.
    expected = torch.mul(base1.view([10, 1, 2, 50]), base2.view([10, 20, 2, 1]))
    product = ntensor1.mul(ntensor2).transpose(
        "alpha", "delta", "beta", "gamma"
    )

    assert expected.shape == product.vshape
    assert (expected == product.values).all()
274

275

276
def test_contract():
    """dot contracts the named dims, matching an einsum reference."""
    base1 = torch.randn(10, 2, 50)
    ntensor1 = ntorch.tensor(base1, ("alpha", "beta", "gamma"))
    base2 = torch.randn(10, 20, 2)
    ntensor2 = ntorch.tensor(base2, ("alpha", "delta", "beta"))
    assert_match(ntensor1, ntensor2)

    # Contract everything except "alpha" with plain einsum as the oracle.
    expected = torch.einsum("abg,adb->a", (base1, base2))

    contracted = ntorch.dot(("beta", "gamma", "delta"), ntensor1, ntensor2)
    assert contracted.shape == OrderedDict([("alpha", 10)])
    assert contracted.vshape == expected.shape
    assert (np.abs(contracted._tensor - expected) < 1e-5).all()
289

290
    # ntensora = ntensor.reduce("alpha", "mean")
291
    # assert ntensora.named_shape == OrderedDict([("beta", 2),
292
    #                                        ("gamma", 50)])
293

294
    # ntensorb = ntensor.reduce("alpha gamma", "mean")
295
    # assert ntensorb.named_shape == OrderedDict([("beta", 2)])
296

297

298
# def test_lift():
299
#     def test_function(tensor):
300
#         return np.sum(tensor, dim=1)
301

302
#     base = np.random.randn(10, 70, 50)
303
#     ntensor = NamedTensor(base, 'batch alpha beta')
304

305
#     lifted = lift(test_function, ["alpha beta"], "beta")
306

307

308
#     ntensor2 = lifted(ntensor)
309
#     assert ntensor2.named_shape == OrderedDict([("batch", 10),
310
#                                             ("beta", 2)])
311

312

313
def test_unbind2():
    """Unbinding a size-2 dim unpacks into exactly two named tensors."""
    src = torch.randn(10, 2, 50)
    ntensor = ntorch.tensor(src, ("alpha", "beta", "gamma"))
    first, second = ntensor.unbind("beta")
    assert first.shape == OrderedDict([("alpha", 10), ("gamma", 50)])
318

319

320
# def test_access():
321
#     base1 = torch.randn(10, 2, 50)
322

323
#     ntensor1 = ntorch.tensor(base1, ("alpha", "beta", "gamma"))
324

325
#     assert (ntensor1.access("gamma")[45] == base1[:, :, 45]).all()
326
#     assert (ntensor1.get("gamma", 1)._tensor == base1[:, :, 1]).all()
327

328
#     assert (ntensor1.access("gamma beta")[45, 1] == base1[:, 1, 45]).all()
329

330

331
def test_takes():
    """index_select replaces the selected dim with the index tensor's dim."""
    src = torch.randn(10, 2, 50)

    ntensor = ntorch.tensor(src, ("alpha", "beta", "gamma"))
    raw_indices = torch.ones(30).long()
    index = ntorch.tensor(raw_indices, ("indices",))

    picked = ntensor.index_select("beta", index)
    # Oracle: plain torch index_select on axis 1 ("beta").
    assert (picked._tensor == src.index_select(1, raw_indices)).all()
    assert picked.shape == OrderedDict(
        [("alpha", 10), ("indices", 30), ("gamma", 50)]
    )
343

344

345
def test_narrow():
    """narrow shrinks only the named dim, leaving the others untouched."""
    ntensor = ntorch.tensor(torch.randn(10, 2, 50), ("alpha", "beta", "gamma"))
    out = ntensor.narrow("gamma", 0, 25)
    assert out.shape == OrderedDict(
        [("alpha", 10), ("beta", 2), ("gamma", 25)]
    )
353

354

355
# def test_ops():
356
#     base1 = ntorch.randn(dict(alpha=10, beta=2, gamma=50))
357
#     base2 = ntorch.log(base1)
358
#     base2 = ntorch.exp(base1)
359

360

361
@pytest.mark.xfail
def test_mask2():
    """softmax over a masked dim is expected to fail.

    Dropped a leftover debug ``print(base2._schema._masked)``.
    """
    base1 = ntorch.randn(10, 2, 50, names=("alpha", "beta", "gamma"))
    base2 = base1.mask_to("alpha")
    base2 = base2.softmax("alpha")
367

368

369
def test_unmask():
    """mask_to("") clears a previous mask so softmax runs again."""
    base = ntorch.randn(10, 2, 50, names=("alpha", "beta", "gamma"))
    masked = base.mask_to("alpha")
    unmasked = masked.mask_to("")
    unmasked = unmasked.softmax("alpha")
374

375

376
# def test_division():
377
#     base1 = NamedTensor(torch.ones(3, 4), ("short", "long"))
378
#     expected = NamedTensor(torch.ones(3) / 4, ("short",))
379
#     assert_match(base1 / base1.sum("long"), expected)
380

381

382
# def test_scalarmult():
383
#     base1 = NamedTensor(torch.ones(3, 4), ("short", "long"))
384
#     rmul = 3 * base1
385
#     lmul = base1 * 3
386
#     assert_match(rmul, lmul)
387

388

389
# def test_subtraction():
390
#     base1 = ntorch.ones(3, 4, names=("short", "long"))
391
#     base2 = ntorch.ones(3, 4, names=("short", "long"))
392
#     expect = ntorch.zeros(3, 4, names=("short", "long"))
393
#     assert_match(base1 - base2, expect)
394

395

396
# def test_rightsubtraction():
397
#     base1 = ntorch.ones(3, 4, names=("short", "long"))
398
#     expect = ntorch.zeros(3, 4, names=("short", "long"))
399
#     assert_match(1 - base1, expect)
400

401

402
# def test_rightaddition():
403
#     base1 = ntorch.ones(3, 4, names=("short", "long"))
404
#     expect = NamedTensor(2 * torch.ones(3, 4), names=("short", "long"))
405
#     assert_match(1 + base1, expect)
406

407

408
# def test_neg():
409
#     base = ntorch.ones(3, names=("short",))
410
#     expected = NamedTensor(-1 * torch.ones(3), ("short",))
411
#     assert_match(-base, expected)
412

413

414
def test_nonzero():
    """nonzero returns an (elements, inputdims) index tensor.

    Checked for all-zero, 1-d, and 2-d inputs, with both the default
    output names and caller-supplied ``names=`` (which must have length 2).
    """

    # only zeros: zero rows, one input dim
    x = ntorch.zeros(10, names=("alpha",))
    y = x.nonzero()
    assert x.shape == OrderedDict([("alpha", 10)])
    assert y.shape == OrderedDict([("elements", 0), ("inputdims", 1)])

    # `names` length must be 2
    y = x.nonzero(names=("a", "b"))
    assert y.shape == OrderedDict([("a", 0), ("b", 1)])

    # 1d tensor: three nonzero entries (1, 2, 5)
    x = ntorch.tensor([0, 1, 2, 0, 5], names=("dim",))
    y = x.nonzero()
    assert 3 == y.size("elements")
    assert x.shape == OrderedDict([("dim", 5)])
    assert y.shape == OrderedDict([("elements", 3), ("inputdims", 1)])

    # `names` length must be 2
    y = x.nonzero(names=("a", "b"))
    assert 3 == y.size("a")
    assert y.shape == OrderedDict([("a", 3), ("b", 1)])

    # 2d tensor: five nonzero entries, two index columns
    x = ntorch.tensor(
        [
            [0.6, 0.0, 0.0, 0.0],
            [0.0, 0.4, 0.0, 0.0],
            [0.0, 0.0, 1.2, 0.0],
            [2.0, 0.0, 0.0, -0.4],
        ],
        names=("alpha", "beta"),
    )
    y = x.nonzero()
    assert 5 == y.size("elements")
    assert x.shape == OrderedDict([("alpha", 4), ("beta", 4)])
    assert y.shape == OrderedDict([("elements", 5), ("inputdims", 2)])

    # `names` length must be 2
    y = x.nonzero(names=("a", "b"))
    assert 5 == y.size("a")
    assert y.shape == OrderedDict([("a", 5), ("b", 2)])
457

458

459
@pytest.mark.xfail
def test_nonzero_names():
    """nonzero ``names=`` with a length other than 2 is expected to fail."""
    # Too few names (1 instead of 2).
    vec = ntorch.tensor([0, 1, 2, 0, 5], names=("dim",))
    out = vec.nonzero(names=("a",))
    assert 2 == len(out.shape)

    # Too many names (3 instead of 2).
    vec = ntorch.tensor([0, 1, 2, 0, 5], names=("dim",))
    out = vec.nonzero(names=("a", "b", "c"))
    assert 2 == len(out.shape)
471

472

473
# def test_log_softmax():
474
#     base = (
475
#         ntorch.tensor([0, 1, 2, 0, 5], names=("dim",))
476
#         .float()
477
#         .log_softmax("dim")
478
#     )
479
#     y = F.log_softmax(torch.tensor([0, 1, 2, 0, 5]).float(), dim=0)
480
#     expected = ntorch.tensor(y, names=("dim",))
481
#     assert_match(base, expected)
482

483

484
def test_indexing_static():
    """Dict-based indexing with fixed ints and slices.

    Renamed from ``test_indexing``: a second ``def test_indexing`` later in
    the file shadowed this one, so pytest never collected it.
    """
    base = ntorch.randn(10, 2, 50, names=("alpha", "beta", "gamma"))

    # Integer index drops the indexed dim.
    base1 = base[{"alpha": 2}]
    assert base1.shape == OrderedDict([("beta", 2), ("gamma", 50)])

    base1 = base[{"beta": 0}]
    assert base1.shape == OrderedDict([("alpha", 10), ("gamma", 50)])

    # Slice index keeps the dim with the narrowed size.
    base1 = base[{"alpha": slice(2, 5)}]
    assert base1.shape == OrderedDict(
        [("alpha", 3), ("beta", 2), ("gamma", 50)]
    )
497

498
@given(data(), named_tensor())
def test_indexing(data, x):
    """Dict indexing: single int, multiple ints, and slices.

    Integer-index draws now use ``max_value=x.shape[d] - 1``: the previous
    ``max_value=x.shape[d]`` could draw an out-of-range index and fail
    spuriously (the slice-end draw below correctly uses the exclusive bound).
    """
    d = data.draw(dim(x))
    i = data.draw(integers(min_value=0, max_value=x.shape[d] - 1))
    x2 = x[{d: i}]
    assert x2.dims == x.dims - set([d])

    # Several integer indices at once: all indexed dims are dropped.
    ds = data.draw(dims(x))
    index = {}
    for d in ds:
        i = data.draw(integers(min_value=0, max_value=x.shape[d] - 1))
        index[d] = i
    x2 = x[index]
    assert x2.dims == x.dims - set(ds)

    # Several slices at once.
    ds = data.draw(dims(x))
    index = {}
    for d in ds:
        i = data.draw(integers(min_value=0, max_value=x.shape[d] - 1))
        j = data.draw(integers(min_value=i + 1, max_value=x.shape[d]))
        index[d] = slice(i, j)
    x2 = x[index]
    # NOTE(review): slicing normally preserves a dim — expected x.dims
    # here? Confirm against NamedTensor.__getitem__ semantics.
    assert x2.dims == x.dims - set(ds)
521

522

523
def test_index_set():
    """Dict-indexed assignment accepts matching named tensors."""
    base = ntorch.randn(10, 2, 50, names=("alpha", "beta", "gamma"))
    # Integer index: the remaining dims must line up by name.
    base[{"alpha": 2}] = ntorch.randn(2, 50, names=("beta", "gamma"))
    # Slice index: the written tensor keeps the sliced dim.
    base[{"alpha": slice(0, 3)}] = ntorch.randn(
        3, 2, 50, names=("alpha", "beta", "gamma")
    )
530

531

532
def test_tensor_mask():
    """Boolean-mask read and write via __getitem__/__setitem__.

    Dropped a leftover debug ``print``.
    """
    base = ntorch.zeros(10, 2, 50, names=("alpha", "beta", "gamma"))
    base[{"alpha": slice(2, 5), "gamma": slice(4, 6)}] = 1

    mask = base > 0.5
    # 3 alphas * 2 betas * 2 gammas = 12 selected entries.
    base1 = base[mask]
    assert base1.shape == OrderedDict([("on", 12)])
    base[mask] = 6
    assert base[{"alpha": 2, "gamma": 4, "beta": 0}].values == 6
542

543

544
def test_index_tensor():
    """Indexing a dim with an index tensor swaps in the index tensor's dims.

    ``names=("indices",)`` fixes the original ``names=("indices")``, which
    is a plain string (missing comma), not a 1-tuple — cf. the real tuples
    used everywhere else in this file.
    """
    base = ntorch.zeros(10, 2, 50, names=("alpha", "beta", "gamma"))
    indices = ntorch.tensor([1, 2, 3, 4], names=("indices",))
    base1 = base[{"gamma": indices}]
    assert base1.shape == OrderedDict(
        [("alpha", 10), ("beta", 2), ("indices", 4)]
    )

    indices = ntorch.tensor([1, 2, 3, 4], names=("indices",))
    base1 = base[{"alpha": 1, "gamma": indices}]
    assert base1.shape == OrderedDict([("beta", 2), ("indices", 4)])

    # A 2-d index tensor contributes both of its dims.
    indices = ntorch.tensor([[1, 2, 3], [1, 2, 3]], names=("d", "indices"))
    base1 = base[{"gamma": indices}]
    assert base1.shape == OrderedDict(
        [("alpha", 10), ("beta", 2), ("d", 2), ("indices", 3)]
    )
561

562

563
def test_setindex_tensor():
    """Tensor-indexed assignment writes through to the base tensor.

    ``names=("indices",)`` fixes the original ``names=("indices")``: a bare
    string (missing comma), not a 1-tuple.
    """
    base = ntorch.zeros(10, 2, 50, names=("alpha", "beta", "gamma")).float()
    indices = ntorch.tensor([1, 2, 3, 4], names=("indices",)).long()
    vals = ntorch.tensor([52, 23.0, 42.9, 4.2], names=("indices",)).float()
    b = base[{"alpha": 1, "beta": 1}]
    # Writing through the view must update the base tensor.
    b[{"gamma": indices}] = vals
    assert base[{"alpha": 1, "beta": 1, "gamma": 1}].values == 52

    base[{"gamma": indices}] = 2
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2024 Coveralls, Inc