• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

harvardnlp / namedtensor / 248

pending completion
248

Pull #59

travis-ci

web-flow
lint
Pull Request #59: Rewrite of test suite

246 of 246 new or added lines in 5 files covered. (100.0%)

1069 of 1196 relevant lines covered (89.38%)

0.89 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.79
/namedtensor/test_core.py
1
from . import assert_match, ntorch
1✔
2
import torch
1✔
3
from collections import OrderedDict
1✔
4
import pytest
1✔
5
import torch.nn.functional as F
1✔
6
from hypothesis import given
1✔
7
from .strategies import (
1✔
8
    named_tensor,
9
    broadcast_named_tensor,
10
    mask_named_tensor,
11
    dim,
12
    dims,
13
    name,
14
    names,
15
)
16
from hypothesis.strategies import (
1✔
17
    sampled_from,
18
    lists,
19
    data,
20
    floats,
21
    integers,
22
    permutations,
23
)
24

25

26
## HYPOTHESIS Tests
27
@given(data(), named_tensor())
def test_stack_basic(data, x):
    """Stacking a set of dims into one new dim removes the old names."""
    stack_dims = data.draw(dims(x))
    new_name = data.draw(name(x))
    stacked = x.stack(list(stack_dims), new_name)
    # The new dim must exist and none of the stacked dims may survive.
    assert new_name in stacked.dims
    assert not (stacked.shape.keys() & stack_dims)
34

35

36
@given(data(), named_tensor())
def test_rename(data, x):
    """rename swaps a single dim name for a fresh one."""
    old_name = data.draw(dim(x))
    new_name = data.draw(name(x))
    renamed = x.rename(old_name, new_name)
    assert new_name in renamed.dims
    assert old_name not in renamed.dims
43

44

45
@given(data(), named_tensor())
def test_split(data, x):
    """split breaks one dim into several; the new sizes multiply back."""
    src = data.draw(dim(x))
    parts = list(data.draw(names(x)))
    # Pin every new dim but the last to size 1 so the split is always valid.
    result = x.split(src, parts, **{p: 1 for p in parts[:-1]})
    assert len(set(parts) & set(result.dims)) == len(parts)
    assert src not in result.dims
    # Product of the new dim sizes equals the original dim's size.
    assert torch.prod(torch.tensor([result.shape[p] for p in parts])) == x.shape[src]
53

54

55
@given(data(), named_tensor())
def test_reduce(data, x):
    """Reductions drop exactly the reduced dims."""
    reduce_dims = data.draw(dims(x))
    method = data.draw(sampled_from(sorted(x._reduce)))

    # Full reduction with no dims given; logsumexp requires explicit dims.
    if method not in ["logsumexp"]:
        full = getattr(x, method)()
        print(full)
        # assert full.values == getattr(x.values, method)()

    reduced = getattr(x, method)(tuple(reduce_dims))
    assert set(reduced.dims) | set(reduce_dims) == set(x.dims)
67

68

69
@given(data(), named_tensor())
def test_binary_op(data, x):
    """Binary ops broadcast by name in either operand order."""
    y = data.draw(broadcast_named_tensor(x))
    method = data.draw(sampled_from(sorted(x._binop)))
    expected = set(x.dims) | set(y.dims)
    assert set(getattr(x, method)(y).dims) == expected
    assert set(getattr(y, method)(x).dims) == expected
77

78

79
@given(data(), named_tensor())
def test_noshift(data, x):
    """No-shift ops keep the dim set unchanged ("cuda" excluded: needs hardware)."""
    candidates = sampled_from(sorted(x._noshift))
    method = data.draw(candidates.filter(lambda m: m not in {"cuda"}))
    assert set(getattr(x, method)().dims) == set(x.dims)
86

87

88
@given(data(), named_tensor())
def test_apply(data, x):
    """Dim-wise apply ops preserve the full shape."""
    method = data.draw(
        sampled_from(sorted(x._noshift_dim | x._noshift_nn_dim))
    )
    target = data.draw(dim(x))
    assert x.shape == getattr(x, method)(target).shape
96

97

98
@given(named_tensor())
def test_sum(x):
    """Full sum over all dims agrees with summing the raw values."""
    total = x.sum()
    print(x.shape)
    assert total.values == x.values.sum()
103

104

105
@given(data(), named_tensor())
def test_mask(data, x):
    """masked_select and boolean indexing both run on a matching mask."""
    mask = data.draw(mask_named_tensor(x))
    # Exercise both entry points; only the second result is inspected.
    x.masked_select(mask, "c")
    selected = x[mask]
    print(selected)
111

112

113
@pytest.mark.xfail
@given(data(), named_tensor())
def test_maskfail(data, x):
    # Indexing with a general broadcast tensor (not a proper mask) is
    # expected to fail; coverage confirms the lines below never complete.
    mask = data.draw(broadcast_named_tensor(x))
    x2 = x.masked_select(mask, "c")
    x2 = x[mask]
    print(x2)
120

121

122
@given(data(), named_tensor(), floats(allow_nan=False, allow_infinity=False))
def test_all_scalar_ops(data, x, y):
    # Scalar arithmetic on a named tensor, both operand orders, plus
    # unary negation — each rebind feeds the next op.
    x = x + y
    x = x - y
    x = x * y
    x = x / y

    # Reflected (scalar-on-the-left) variants.
    x = y + x
    x = y - x
    x = y * x
    # NOTE(review): y / x is not exercised — confirm whether reflected
    # division is unsupported or simply untested (y may also be 0).

    x = -x
134

135

136
@given(data(), named_tensor())
def test_indexing(data, x):
    # A single integer index drops exactly that dim.
    d = data.draw(dim(x))
    i = data.draw(integers(min_value=0, max_value=x.shape[d] - 1))
    x2 = x[{d: i}]
    assert set(x2.dims) == set(x.dims) - set([d])

    # Multiple integer indices drop all the indexed dims.
    ds = data.draw(dims(x))
    index = {}
    for d in ds:
        i = data.draw(integers(min_value=0, max_value=x.shape[d] - 1))
        index[d] = i
    x2 = x[index]
    assert set(x2.dims) == set(x.dims) - set(ds)

    # Slice indices keep every dim (possibly shrunk); j > i so each
    # slice is non-empty.
    ds = data.draw(dims(x))
    index = {}
    for d in ds:
        i = data.draw(integers(min_value=0, max_value=x.shape[d] - 1))
        j = data.draw(integers(min_value=i + 1, max_value=x.shape[d]))
        index[d] = slice(i, j)
    x2 = x[index]
    assert set(x2.dims) == set(x.dims)
    # Scalar assignment through the same slice index is also supported.
    x[index] = 6
160

161

162
@given(data(), named_tensor())
def test_tensor_indexing(data, x):
    # Indexing dim d with a named index tensor replaces d with the index
    # tensor's dim n.
    d = data.draw(dim(x))
    indices = data.draw(
        lists(integers(min_value=0, max_value=x.shape[d] - 1), unique=True)
    )
    n = data.draw(name(x))
    ind_vector = ntorch.tensor(indices, names=n).long()
    x2 = x[{d: ind_vector}]
    assert set(x2.dims) == (set(x.dims) | set([n])) - set([d])

    # Tensor-index assignment writes through to the selected entries.
    x[{d: ind_vector}] = 5
    assert (x[{d: ind_vector}] == 5).all()
175

176

177
@given(data(), named_tensor())
def test_tensor_mask(data, x):
    """Mask assignment writes; mask selection flattens to the "on" dim."""
    mask = data.draw(mask_named_tensor(x))
    x[mask] = 6
    selected = x[mask]
    assert selected.dims == ("on",)
183

184

185
@given(data(), named_tensor())
def test_cat(data, x):
    # Concatenate along every dim shared by x and a transposed copy
    # (a transpose keeps the same dim names, so all dims are shared).
    perm = data.draw(permutations(x.dims))
    y = x.transpose(*perm)
    for s in set(x.dims) & set(y.dims):
        c = ntorch.cat([x, y], dim=s)
        c = ntorch.cat([x, c], dim=s)
        c = ntorch.cat([c, x, y], dim=s)
    # NOTE(review): `c` would be unbound if the loop body never ran —
    # confirm named_tensor() never generates a 0-dim tensor.
    print(c)
194

195

196
@given(data(), named_tensor())
def test_stack(data, x):
    """ntorch.stack over permuted copies adds exactly one new dim."""
    perm = data.draw(permutations(x.dims))
    print(perm)
    transposed = x.transpose(*perm)
    new_dim = data.draw(name(x))
    stacked = ntorch.stack([x, transposed], new_dim)
    assert set(stacked.dims) == set(x.dims) | set([new_dim])
204

205

206
@given(data(), named_tensor())
def test_dot(data, x):
    # Smoke-test dot (named contraction) for every operand/dims pairing
    # against a broadcast-compatible tensor.
    y = data.draw(broadcast_named_tensor(x))
    dsx = data.draw(dims(x))
    # NOTE(review): both dim sets are drawn from dims(x); drawing dsy
    # from dims(y) may have been intended — confirm.
    dsy = data.draw(dims(x))
    x.dot(dsx, y)
    x.dot(dsy, y)
    y.dot(dsx, x)
    y.dot(dsy, x)
215

216

217
## Old style tests
218

219

220
def test_apply2():
    """op() lifts a raw torch function onto a named dim."""
    named = ntorch.tensor(torch.zeros([10, 2, 50]), ("alpha", "beta", "gamma"))
    softmaxed = named.op(F.softmax, dim="alpha")
    # A softmax along "alpha" must sum to one over that dim.
    assert (ntorch.abs(softmaxed.sum("alpha") - 1.0) < 1e-5).all()
225

226

227
# def test_fill():
228
#     base = torch.zeros([10, 2, 50])
229
#     ntensor = ntorch.tensor(base, ("alpha", "beta", "gamma"))
230
#     ntensor.fill_(20)
231
#     assert (ntensor == 20).all()
232

233

234
def test_gather():
    """Named gather matches torch.gather; scatter_ keeps the target shape."""
    # Reference result from plain torch.gather along dim 1.
    t = torch.Tensor([[1, 2], [3, 4]])
    base = torch.gather(t, 1, torch.LongTensor([[0, 0], [1, 0]]))

    # Named gather: index along "b"; the output dim is named "c".
    t = ntorch.tensor(torch.Tensor([[1, 2], [3, 4]]), ("a", "b"))
    index = ntorch.tensor(torch.LongTensor([[0, 0], [1, 0]]), ("a", "c"))
    ntensor = ntorch.gather(t, "b", index, "c")
    assert (ntensor.values == base).all()
    assert ntensor.shape == OrderedDict([("a", 2), ("c", 2)])

    # In-place scatter_ writes x into y along "a", consuming index dim "c";
    # y's shape is unchanged by the write.
    x = ntorch.tensor(torch.rand(2, 5), ("c", "b"))
    y = ntorch.tensor(torch.rand(3, 5), ("a", "b"))
    y.scatter_(
        "a",
        ntorch.tensor(
            torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), ("c", "b")
        ),
        x,
        "c",
    )
    assert y.shape == OrderedDict([("a", 3), ("b", 5)])
255

256

257
def test_unbind():
    """unbind splits a named dim into a tuple of smaller tensors."""
    named = ntorch.tensor(torch.zeros([10, 2, 50]), ("alpha", "beta", "gamma"))
    pieces = named.unbind("beta")
    assert len(pieces) == 2
    assert pieces[0].shape == OrderedDict([("alpha", 10), ("gamma", 50)])

    # Unbinding a 1-d tensor yields scalars that keep their values.
    vec = ntorch.tensor(torch.zeros([10]), ("alpha",))
    vec.fill_(20)
    scalars = vec.unbind("alpha")
    assert len(scalars) == 10
    assert scalars[0].item() == 20
270

271

272
# @pytest.mark.xfail
273
# def test_fail():
274
#     for base1, base2 in zip(
275
#         make_tensors([10, 2, 50]), make_tensors([10, 20, 2])
276
#     ):
277
#         ntensor1 = NamedTensor(base1, ("alpha", "beta", "gamma"))
278
#         ntensor2 = NamedTensor(base2, ("alpha", "beat", "gamma"))
279
#         assert_match(ntensor1, ntensor2)
280

281

282
def test_multiple():
    """Name-based broadcasting of a binop matches explicit view+mul."""
    base1 = torch.rand([10, 2, 50])
    base2 = torch.rand([10, 20, 2])
    # The tensors share "alpha" and "beta"; "gamma"/"delta" are unique.
    ntensor1 = ntorch.tensor(base1, ("alpha", "beta", "gamma"))
    ntensor2 = ntorch.tensor(base2, ("alpha", "delta", "beta"))
    assert_match(ntensor1, ntensor2)

    # Reference: broadcast by inserting singleton axes manually, so the
    # result is (alpha, delta, beta, gamma).
    base3 = torch.mul(base1.view([10, 1, 2, 50]), base2.view([10, 20, 2, 1]))
    ntensor3 = ntensor1.mul(ntensor2).transpose(
        "alpha", "delta", "beta", "gamma"
    )

    assert base3.shape == ntensor3.vshape
    assert (base3 == ntensor3.values).all()
297

298

299
# def test_contract():
300
#     base1 = torch.randn(10, 2, 50)
301
#     ntensor1 = ntorch.tensor(base1, ("alpha", "beta", "gamma"))
302
#     base2 = torch.randn(10, 20, 2)
303
#     ntensor2 = ntorch.tensor(base2, ("alpha", "delta", "beta"))
304
#     assert_match(ntensor1, ntensor2)
305

306
#     base3 = torch.einsum("abg,adb->a", (base1, base2))
307

308
#     ntensor3 = ntorch.dot(("beta", "gamma", "delta"), ntensor1, ntensor2)
309
#     assert ntensor3.shape == OrderedDict([("alpha", 10)])
310
#     assert ntensor3.vshape == base3.shape
311
#     assert (np.abs(ntensor3._tensor - base3) < 1e-5).all()
312

313
# ntensora = ntensor.reduce("alpha", "mean")
314
# assert ntensora.named_shape == OrderedDict([("beta", 2),
315
#                                        ("gamma", 50)])
316

317
# ntensorb = ntensor.reduce("alpha gamma", "mean")
318
# assert ntensorb.named_shape == OrderedDict([("beta", 2)])
319

320

321
# def test_lift():
322
#     def test_function(tensor):
323
#         return np.sum(tensor, dim=1)
324

325
#     base = np.random.randn(10, 70, 50)
326
#     ntensor = NamedTensor(base, 'batch alpha beta')
327

328
#     lifted = lift(test_function, ["alpha beta"], "beta")
329

330

331
#     ntensor2 = lifted(ntensor)
332
#     assert ntensor2.named_shape == OrderedDict([("batch", 10),
333
#                                             ("beta", 2)])
334

335

336
def test_unbind2():
    """The result of unbind on a size-2 dim tuple-unpacks into two tensors."""
    named = ntorch.tensor(torch.randn(10, 2, 50), ("alpha", "beta", "gamma"))
    first, second = named.unbind("beta")
    assert first.shape == OrderedDict([("alpha", 10), ("gamma", 50)])
341

342

343
# def test_access():
344
#     base1 = torch.randn(10, 2, 50)
345

346
#     ntensor1 = ntorch.tensor(base1, ("alpha", "beta", "gamma"))
347

348
#     assert (ntensor1.access("gamma")[45] == base1[:, :, 45]).all()
349
#     assert (ntensor1.get("gamma", 1)._tensor == base1[:, :, 1]).all()
350

351
#     assert (ntensor1.access("gamma beta")[45, 1] == base1[:, 1, 45]).all()
352

353

354
def test_takes():
    """index_select swaps the selected dim for the index tensor's dim."""
    raw = torch.randn(10, 2, 50)

    named = ntorch.tensor(raw, ("alpha", "beta", "gamma"))
    raw_indices = torch.ones(30).long()
    index_tensor = ntorch.tensor(raw_indices, ("indices",))

    selected = named.index_select("beta", index_tensor)
    # Values match plain torch.index_select on axis 1 ("beta").
    assert (selected._tensor == raw.index_select(1, raw_indices)).all()
    assert selected.shape == OrderedDict(
        [("alpha", 10), ("indices", 30), ("gamma", 50)]
    )
366

367

368
def test_narrow():
    """narrow shrinks one named dim to a sub-range, others untouched."""
    named = ntorch.tensor(torch.randn(10, 2, 50), ("alpha", "beta", "gamma"))
    narrowed = named.narrow("gamma", 0, 25)
    assert narrowed.shape == OrderedDict(
        [("alpha", 10), ("beta", 2), ("gamma", 25)]
    )
376

377

378
# def test_ops():
379
#     base1 = ntorch.randn(dict(alpha=10, beta=2, gamma=50))
380
#     base2 = ntorch.log(base1)
381
#     base2 = ntorch.exp(base1)
382

383

384
@pytest.mark.xfail
def test_mask2():
    # Applying softmax along a dim while it is masked is expected to fail.
    base1 = ntorch.randn(10, 2, 50, names=("alpha", "beta", "gamma"))
    base2 = base1.mask_to("alpha")
    print(base2._schema._masked)
    base2 = base2.softmax("alpha")
390

391

392
def test_unmask():
    """Clearing a mask with mask_to("") re-enables ops on that dim."""
    tensor = ntorch.randn(10, 2, 50, names=("alpha", "beta", "gamma"))
    masked = tensor.mask_to("alpha")
    unmasked = masked.mask_to("")
    # Must not raise once the mask is cleared.
    unmasked.softmax("alpha")
397

398

399
# def test_division():
400
#     base1 = NamedTensor(torch.ones(3, 4), ("short", "long"))
401
#     expected = NamedTensor(torch.ones(3) / 4, ("short",))
402
#     assert_match(base1 / base1.sum("long"), expected)
403

404

405
# def test_scalarmult():
406
#     base1 = NamedTensor(torch.ones(3, 4), ("short", "long"))
407
#     rmul = 3 * base1
408
#     lmul = base1 * 3
409
#     assert_match(rmul, lmul)
410

411

412
# def test_subtraction():
413
#     base1 = ntorch.ones(3, 4, names=("short", "long"))
414
#     base2 = ntorch.ones(3, 4, names=("short", "long"))
415
#     expect = ntorch.zeros(3, 4, names=("short", "long"))
416
#     assert_match(base1 - base2, expect)
417

418

419
# def test_rightsubtraction():
420
#     base1 = ntorch.ones(3, 4, names=("short", "long"))
421
#     expect = ntorch.zeros(3, 4, names=("short", "long"))
422
#     assert_match(1 - base1, expect)
423

424

425
# def test_rightaddition():
426
#     base1 = ntorch.ones(3, 4, names=("short", "long"))
427
#     expect = NamedTensor(2 * torch.ones(3, 4), names=("short", "long"))
428
#     assert_match(1 + base1, expect)
429

430

431
# def test_neg():
432
#     base = ntorch.ones(3, names=("short",))
433
#     expected = NamedTensor(-1 * torch.ones(3), ("short",))
434
#     assert_match(-base, expected)
435

436

437
def test_nonzero():
    """nonzero returns an (elements, inputdims)-shaped tensor of indices."""

    # All-zero input: zero elements, but inputdims is still present.
    x = ntorch.zeros(10, names=("alpha",))
    y = x.nonzero()
    assert x.shape == OrderedDict([("alpha", 10)])
    assert y.shape == OrderedDict([("elements", 0), ("inputdims", 1)])

    # Custom output names; `names` length must be 2.
    y = x.nonzero(names=("a", "b"))
    assert y.shape == OrderedDict([("a", 0), ("b", 1)])

    # 1d tensor: three non-zero entries.
    x = ntorch.tensor([0, 1, 2, 0, 5], names=("dim",))
    y = x.nonzero()
    assert 3 == y.size("elements")
    assert x.shape == OrderedDict([("dim", 5)])
    assert y.shape == OrderedDict([("elements", 3), ("inputdims", 1)])

    # Same result under custom names.
    y = x.nonzero(names=("a", "b"))
    assert 3 == y.size("a")
    assert y.shape == OrderedDict([("a", 3), ("b", 1)])

    # 2d tensor: five non-zero entries, inputdims grows to 2.
    x = ntorch.tensor(
        [
            [0.6, 0.0, 0.0, 0.0],
            [0.0, 0.4, 0.0, 0.0],
            [0.0, 0.0, 1.2, 0.0],
            [2.0, 0.0, 0.0, -0.4],
        ],
        names=("alpha", "beta"),
    )
    y = x.nonzero()
    assert 5 == y.size("elements")
    assert x.shape == OrderedDict([("alpha", 4), ("beta", 4)])
    assert y.shape == OrderedDict([("elements", 5), ("inputdims", 2)])

    # Same result under custom names.
    y = x.nonzero(names=("a", "b"))
    assert 5 == y.size("a")
    assert y.shape == OrderedDict([("a", 5), ("b", 2)])
480

481

482
@pytest.mark.xfail
def test_nonzero_names():
    # nonzero's `names` must have exactly 2 entries; too few should fail.
    x = ntorch.tensor([0, 1, 2, 0, 5], names=("dim",))
    y = x.nonzero(names=("a",))
    assert 2 == len(y.shape)

    # NOTE(review): the lines below are unreachable — xfail stops at the
    # first failure above (coverage confirms ×). Consider splitting the
    # too-many-names case into its own test.
    x = ntorch.tensor([0, 1, 2, 0, 5], names=("dim",))
    y = x.nonzero(names=("a", "b", "c"))
    assert 2 == len(y.shape)
494

495

496
# def test_log_softmax():
497
#     base = (
498
#         ntorch.tensor([0, 1, 2, 0, 5], names=("dim",))
499
#         .float()
500
#         .log_softmax("dim")
501
#     )
502
#     y = F.log_softmax(torch.tensor([0, 1, 2, 0, 5]).float(), dim=0)
503
#     expected = ntorch.tensor(y, names=("dim",))
504
#     assert_match(base, expected)
505

506

507
def test_indexing_basic():
    """Dict indexing: int indices drop a dim, slices shrink it in place."""
    base = ntorch.randn(10, 2, 50, names=("alpha", "beta", "gamma"))

    assert base[{"alpha": 2}].shape == OrderedDict(
        [("beta", 2), ("gamma", 50)]
    )

    assert base[{"beta": 0}].shape == OrderedDict(
        [("alpha", 10), ("gamma", 50)]
    )

    assert base[{"alpha": slice(2, 5)}].shape == OrderedDict(
        [("alpha", 3), ("beta", 2), ("gamma", 50)]
    )
520

521

522
def test_index_set():
    """Assignment through int and slice dict-indices accepts named values."""
    base = ntorch.randn(10, 2, 50, names=("alpha", "beta", "gamma"))
    # Int index: the assigned tensor omits the indexed dim.
    base[{"alpha": 2}] = ntorch.randn(2, 50, names=("beta", "gamma"))
    # Slice index: the assigned tensor keeps all dims with the slice size.
    base[{"alpha": slice(0, 3)}] = ntorch.randn(
        3, 2, 50, names=("alpha", "beta", "gamma")
    )
528

529

530
def test_index_tensor():
    """Indexing a dim with a named tensor substitutes the index tensor's dims."""
    base = ntorch.zeros(10, 2, 50, names=("alpha", "beta", "gamma"))
    # Fixed: was names=("indices") — a bare string, not a 1-tuple
    # (missing trailing comma), unlike every other names tuple here.
    indices = ntorch.tensor([1, 2, 3, 4], names=("indices",))
    base1 = base[{"gamma": indices}]
    assert base1.shape == OrderedDict(
        [("alpha", 10), ("beta", 2), ("indices", 4)]
    )

    # Mixed int + tensor indexing drops the int-indexed dim as well.
    base1 = base[{"alpha": 1, "gamma": indices}]
    assert base1.shape == OrderedDict([("beta", 2), ("indices", 4)])

    # A 2-d index tensor contributes both of its dims to the result.
    indices = ntorch.tensor([[1, 2, 3], [1, 2, 3]], names=("d", "indices"))
    base1 = base[{"gamma": indices}]
    assert base1.shape == OrderedDict(
        [("alpha", 10), ("beta", 2), ("d", 2), ("indices", 3)]
    )
547

548

549
def test_setindex_tensor():
    """Writing through a tensor index reaches the underlying base tensor."""
    base = ntorch.zeros(10, 2, 50, names=("alpha", "beta", "gamma")).float()
    # Fixed: was names=("indices") — a bare string, not a 1-tuple
    # (missing trailing comma), on both of the lines below.
    indices = ntorch.tensor([1, 2, 3, 4], names=("indices",)).long()
    vals = ntorch.tensor([52, 23.0, 42.9, 4.2], names=("indices",)).float()
    b = base[{"alpha": 1, "beta": 1}]
    b[{"gamma": indices}] = vals
    # The write through the view must be visible in the original tensor.
    assert base[{"alpha": 1, "beta": 1, "gamma": 1}].values == 52

    # Scalar broadcast assignment through a tensor index also works.
    base[{"gamma": indices}] = 2
558

559

560
@pytest.mark.xfail
def test_unique_names():
    # A duplicated dim name (and a name count mismatching the tensor's
    # rank) must be rejected by the constructor.
    base = torch.zeros([10, 2])
    assert ntorch.tensor(base, ("alpha", "beta", "alpha"))
564

565

566
def test_names():
    """A rank-3 tensor accepts three distinct dim names (truthy result)."""
    raw = torch.zeros([10, 2, 50])
    assert ntorch.tensor(raw, ("alpha", "beta", "gamma"))
569

570

571
@pytest.mark.xfail
def test_bad_names():
    # These names are expected to be rejected — presumably because of the
    # underscores or a reserved-name clash; confirm against the validator.
    base = torch.zeros([10, 2])
    assert ntorch.tensor(base, ("elements_dim", "input_dims"))
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2024 Coveralls, Inc