• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

feihoo87 / waveforms / 6534953321

16 Oct 2023 02:19PM UTC coverage: 35.674% (-22.7%) from 58.421%
6534953321

push

github

feihoo87
fix Coveralls

5913 of 16575 relevant lines covered (35.67%)

3.21 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

54.24
/waveforms/waveform.py
1
import pickle
9✔
2
from bisect import bisect_left
9✔
3
from itertools import chain, product
9✔
4
from math import comb
9✔
5

6
import numpy as np
9✔
7
import scipy.special as special
9✔
8
from numpy import e, inf, pi
9✔
9

10
NDIGITS = 15
9✔
11

12
_zero = ((), ())
9✔
13

14

15
def _const(c):
9✔
16
    if c == 0:
9✔
17
        return _zero
9✔
18
    return (((), ()), ), (c, )
9✔
19

20

21
_one = _const(1.0)
9✔
22
_half = _const(1 / 2)
9✔
23
_two = _const(2.0)
9✔
24
_pi = _const(pi)
9✔
25
_two_pi = _const(2 * pi)
9✔
26
_half_pi = _const(pi / 2)
9✔
27

28

29
def _is_const(x):
9✔
30
    return x == _zero or x[0] == (((), ()), )
9✔
31

32

33
def _basic_wave(Type, *args, shift=0):
9✔
34
    return ((((Type, *args, shift), ), (1, )), ), (1.0, )
9✔
35

36

37
def _insert_type_value_pair(t_list, v_list, t, v, lo, hi):
9✔
38
    i = bisect_left(t_list, t, lo, hi)
9✔
39
    if i < hi and t_list[i] == t:
9✔
40
        v += v_list[i]
9✔
41
        if v == 0:
9✔
42
            t_list.pop(i)
9✔
43
            v_list.pop(i)
9✔
44
            return i, hi - 1
9✔
45
        else:
46
            v_list[i] = v
9✔
47
            return i, hi
9✔
48
    else:
49
        t_list.insert(i, t)
9✔
50
        v_list.insert(i, v)
9✔
51
        return i, hi + 1
9✔
52

53

54
def _mul(x, y):
9✔
55
    t_list, v_list = [], []
9✔
56
    xt_list, xv_list = x
9✔
57
    yt_list, yv_list = y
9✔
58
    lo, hi = 0, 0
9✔
59
    for (t1, t2), (v1, v2) in zip(product(xt_list, yt_list),
9✔
60
                                  product(xv_list, yv_list)):
61
        if v1 * v2 == 0:
9✔
62
            continue
×
63
        t = _add(t1, t2)
9✔
64
        lo, hi = _insert_type_value_pair(t_list, v_list, t, v1 * v2, lo, hi)
9✔
65
    return tuple(t_list), tuple(v_list)
9✔
66

67

68
def _add(x, y):
9✔
69
    #x, y = (x, y) if len(x[0]) >= len(y[0]) else (y, x)
70
    t_list, v_list = list(x[0]), list(x[1])
9✔
71
    lo, hi = 0, len(t_list)
9✔
72
    for t, v in zip(*y):
9✔
73
        lo, hi = _insert_type_value_pair(t_list, v_list, t, v, lo, hi)
9✔
74
    return tuple(t_list), tuple(v_list)
9✔
75

76

77
def _shift(x, time):
9✔
78
    if _is_const(x):
9✔
79
        return x
9✔
80

81
    t_list = []
9✔
82

83
    for pre_mtlist, nlist in x[0]:
9✔
84
        mtlist = []
9✔
85
        for Type, *args, shift in pre_mtlist:
9✔
86
            mtlist.append((Type, *args, shift + time))
9✔
87
        t_list.append((tuple(mtlist), nlist))
9✔
88
    return tuple(t_list), x[1]
9✔
89

90

91
def _pow(x, n):
9✔
92
    if x == _zero:
9✔
93
        return _zero
×
94
    if n == 0:
9✔
95
        return _one
×
96
    if _is_const(x):
9✔
97
        return _const(x[1][0]**n)
×
98

99
    if len(x[0]) == 1:
9✔
100
        t_list, v_list = [], []
×
101
        for (mtlist, pre_nlist), v in zip(*x):
×
102
            nlist = []
×
103
            for m in pre_nlist:
×
104
                nlist.append(n * m)
×
105
            t_list.append((mtlist, tuple(nlist)))
×
106
            v_list.append(v**n)
×
107
        return tuple(t_list), tuple(v_list)
×
108
    else:
109
        assert isinstance(n, int) and n > 0
9✔
110
        ret = _one
9✔
111
        for i in range(n):
9✔
112
            ret = _mul(ret, x)
9✔
113
        return ret
9✔
114

115

116
def _cos_power_n(x, n):
9✔
117
    _, w, shift = x
9✔
118
    ret = _zero
9✔
119
    for k in range(0, n // 2 + 1):
9✔
120
        if n == 2 * k:
9✔
121
            a = _const(comb(n, k) / 2**n)
9✔
122
            ret = _add(ret, a)
9✔
123
        else:
124
            expr = (((((COS, (n - 2 * k) * w, shift), ), (1, )), ),
9✔
125
                    (comb(n, k) / 2**(n - 1), ))
126
            ret = _add(ret, expr)
9✔
127
    return ret
9✔
128

129

130
def _trigMul_t(x, y, v):
9✔
131
    """cos(a)cos(b) = 0.5*cos(a+b)+0.5*cos(a-b)"""
132
    _, w1, t1 = x
9✔
133
    _, w2, t2 = y
9✔
134
    if w2 > w1:
9✔
135
        t1, t2 = t2, t1
9✔
136
        w1, w2 = w2, w1
9✔
137
    exp1 = (COS, w1 + w2, (w1 * t1 + w2 * t2) / (w1 + w2))
9✔
138
    if w1 == w2:
9✔
139
        c = v * np.cos(w1 * t1 - w2 * t2) / 2
9✔
140
        if c == 0:
9✔
141
            return (((exp1, ), (1, )), ), (0.5 * v, )
×
142
        else:
143
            return (((), ()), ((exp1, ), (1, ))), (c, 0.5 * v)
9✔
144
    else:
145
        exp2 = (COS, w1 - w2, (w1 * t1 - w2 * t2) / (w1 - w2))
9✔
146
        if exp2[1] > exp1[1]:
9✔
147
            exp2, exp1 = exp1, exp2
×
148
        return (((exp2, ), (1, )), ((exp1, ), (1, ))), (0.5 * v, 0.5 * v)
9✔
149

150

151
def _trigMul(x, y):
9✔
152
    if _is_const(x) or _is_const(y):
9✔
153
        return _mul(x, y)
9✔
154
    ret = _zero
9✔
155
    for (t1, t2), (v1, v2) in zip(product(x[0], y[0]), product(x[1], y[1])):
9✔
156
        v = v1 * v2
9✔
157
        tmp = _one
9✔
158
        trig = []
9✔
159
        for mt, n in zip(chain(t1[0], t2[0]), chain(t1[1], t2[1])):
9✔
160
            if mt[0] == COS:
9✔
161
                trig.append(mt)
9✔
162
            else:
163
                tmp = _mul(tmp, ((((mt, ), (n, )), ), (1, )))
×
164
        if len(trig) == 1:
9✔
165
            x = ((((trig[0], ), (1, )), ), (v, ))
9✔
166
            expr = _mul(tmp, x)
9✔
167
        elif len(trig) == 2:
9✔
168
            expr = _trigMul_t(trig[0], trig[1], v)
9✔
169
            expr = _mul(tmp, expr)
9✔
170
        else:
171
            expr = _mul(tmp, _const(v))
×
172
        ret = _add(ret, expr)
9✔
173
    return ret
9✔
174

175

176
def _exp_trig_Reduce(mtlist, v):
9✔
177
    trig = _one
9✔
178
    alpha = 0
9✔
179
    shift = 0
9✔
180
    ml, nl = [], []
9✔
181
    for mt, n in zip(*mtlist):
9✔
182
        if mt[0] == COS:
9✔
183
            trig = _trigMul(trig, _cos_power_n(mt, n))
9✔
184
        elif mt[0] == EXP:
9✔
185
            x = alpha * shift + n * mt[1] * mt[-1]
9✔
186
            alpha += n * mt[1]
9✔
187
            if alpha == 0:
9✔
188
                shift = 0
×
189
            else:
190
                shift = x / alpha
9✔
191
        elif mt[0] == GAUSSIAN and n != 1:
9✔
192
            ml.append((mt[0], mt[1] / np.sqrt(n), mt[2]))
×
193
            nl.append(1)
×
194
        else:
195
            ml.append(mt)
9✔
196
            nl.append(n)
9✔
197
    ret = (((tuple(ml), tuple(nl)), ), (v, ))
9✔
198

199
    if alpha != 0:
9✔
200
        ret = _mul(ret, _basic_wave(EXP, alpha, shift=shift))
9✔
201

202
    return _mul(ret, trig)
9✔
203

204

205
def _simplify(expr):
9✔
206
    ret = _zero
9✔
207
    for t, v in zip(*expr):
9✔
208
        y = _exp_trig_Reduce(t, v)
9✔
209
        ret = _add(ret, y)
9✔
210
    return ret
9✔
211

212

213
def _filter(expr, low, high):
9✔
214
    expr = _simplify(expr)
9✔
215
    ret = _zero
9✔
216
    for t, v in zip(*expr):
9✔
217
        for i, (mt, n) in enumerate(zip(*t)):
9✔
218
            if mt[0] == COS:
9✔
219
                if low <= mt[1] < high:
9✔
220
                    ret = _add(ret, ((t, ), (v, )))
9✔
221
                break
9✔
222
            elif mt[0] == SINC and n == 1:
9✔
223
                pass
×
224
            elif mt[0] == GAUSSIAN and n == 1:
9✔
225
                pass
×
226
        else:
227
            if low <= 0:
×
228
                ret = _add(ret, ((t, ), (v, )))
×
229
    return ret
9✔
230

231

232
def _apply(function_lib, func_id, x, shift, *args):
9✔
233
    return function_lib[func_id](x - shift, *args)
9✔
234

235

236
def _calc(wav, x, function_lib):
9✔
237
    lru_cache = {}
9✔
238

239
    def _calc_m(t, x):
9✔
240
        ret = 1
9✔
241
        for mt, n in zip(*t):
9✔
242
            if mt not in lru_cache:
9✔
243
                func_id, *args, shift = mt
9✔
244
                lru_cache[mt] = _apply(function_lib, func_id, x, shift, *args)
9✔
245
            if n == 1:
9✔
246
                ret = ret * lru_cache[mt]
9✔
247
            else:
248
                ret = ret * lru_cache[mt]**n
9✔
249
        return ret
9✔
250

251
    ret = 0
9✔
252
    for t, v in zip(*wav):
9✔
253
        ret = ret + v * _calc_m(t, x)
9✔
254
    return ret
9✔
255

256

257
def _num_latex(num):
9✔
258
    if num == -np.inf:
×
259
        return r"-\infty"
×
260
    elif num == np.inf:
×
261
        return r"\infty"
×
262
    s = f"{num:g}"
×
263
    if "e" in s:
×
264
        a, n = s.split("e")
×
265
        n = float(n)
×
266
        s = f"{a} \\times 10^{{{n:g}}}"
×
267
    return s
×
268

269

270
def _fun_latex(fun):
9✔
271
    funID, *args, shift = fun
×
272
    if _baseFunc_latex[funID] is None:
×
273
        shift = _num_latex(shift)
×
274
        if shift == "0":
×
275
            shift = ""
×
276
        elif shift[0] != '-':
×
277
            shift = "+" + shift
×
278
        return r"\mathrm{Func}" + f"{funID}(t{shift}, ...)"
×
279
    return _baseFunc_latex[funID](shift, *args)
×
280

281

282
def _wav_latex(wav):
9✔
283
    from waveforms.waveform import _is_const, _zero
×
284

285
    if wav == _zero:
×
286
        return "0"
×
287
    elif _is_const(wav):
×
288
        return f"{wav[1][0]}"
×
289

290
    sum_expr = []
×
291
    for mul, amp in zip(*wav):
×
292
        if mul == ((), ()):
×
293
            sum_expr.append(_num_latex(amp))
×
294
            continue
×
295
        mul_expr = []
×
296
        amp = _num_latex(amp)
×
297
        if amp != "1":
×
298
            mul_expr.append(amp)
×
299
        for fun, n in zip(*mul):
×
300
            fun_expr = _fun_latex(fun)
×
301
            if n != 1:
×
302
                mul_expr.append(fun_expr + "^{" + f"{n}" + "}")
×
303
            else:
304
                mul_expr.append(fun_expr)
×
305
        sum_expr.append(''.join(mul_expr))
×
306

307
    ret = sum_expr[0]
×
308
    for expr in sum_expr[1:]:
×
309
        if expr[0] == '-':
×
310
            ret += expr
×
311
        else:
312
            ret += "+" + expr
×
313
    return ret
×
314

315

316
class Waveform:
9✔
317
    __slots__ = ('bounds', 'seq', 'max', 'min', 'start', 'stop', 'sample_rate')
9✔
318

319
    def __init__(self, bounds=(+inf, ), seq=(_zero, ), min=-inf, max=inf):
9✔
320
        self.bounds = bounds
9✔
321
        self.seq = seq
9✔
322
        self.max = max
9✔
323
        self.min = min
9✔
324
        self.start = None
9✔
325
        self.stop = None
9✔
326
        self.sample_rate = None
9✔
327

328
    def _head(self):
9✔
329
        for i, s in enumerate(self.seq):
×
330
            if s is not _zero:
×
331
                if i == 0:
×
332
                    return -inf
×
333
                return self.bounds[i - 1]
×
334
        return inf
×
335

336
    def _tail(self):
9✔
337
        N = len(self.bounds)
×
338
        for i, s in enumerate(self.seq[::-1]):
×
339
            if s is not _zero:
×
340
                if i == 0:
×
341
                    return inf
×
342
                return self.bounds[N - i - 1]
×
343
        return -inf
×
344

345
    @property
9✔
346
    def head(self):
9✔
347
        if self.start is None:
×
348
            return self._head()
×
349
        else:
350
            return max(self.start, self._head())
×
351

352
    @property
9✔
353
    def tail(self):
9✔
354
        if self.stop is None:
×
355
            return self._tail()
×
356
        else:
357
            return min(self.stop, self._tail())
×
358

359
    def sample(self,
9✔
360
               sample_rate=None,
361
               out=None,
362
               chunk_size=None,
363
               function_lib=None):
364
        if sample_rate is None:
9✔
365
            sample_rate = self.sample_rate
9✔
366
        if self.start is None or self.stop is None or sample_rate is None:
9✔
367
            raise ValueError('Waveform is not initialized')
×
368
        if chunk_size is None:
9✔
369
            x = np.arange(self.start, self.stop, 1 / sample_rate)
9✔
370
            return self.__call__(x, out=out, function_lib=function_lib)
9✔
371
        else:
372
            return self._sample_iter(sample_rate, chunk_size, out,
×
373
                                     function_lib)
374

375
    def _sample_iter(self, sample_rate, chunk_size, out, function_lib):
9✔
376
        start = self.start
×
377
        start_n = 0
×
378
        length = chunk_size / sample_rate
×
379
        while start < self.stop:
×
380
            if start + length > self.stop:
×
381
                length = self.stop - start
×
382
                stop = self.stop
×
383
                size = round((stop - start) * sample_rate)
×
384
            else:
385
                stop = start + length
×
386
                size = chunk_size
×
387
            x = np.linspace(start, stop, size, endpoint=False)
×
388
            if out is not None:
×
389
                yield self.__call__(x,
×
390
                                    out=out[start_n:],
391
                                    function_lib=function_lib)
392
            else:
393
                yield self.__call__(x, function_lib=function_lib)
×
394
            start = stop
×
395
            start_n += chunk_size
×
396

397
    def tolist(self):
9✔
398
        ret = [self.max, self.min, self.start, self.stop, self.sample_rate]
9✔
399

400
        ret.append(len(self.bounds))
9✔
401
        for seq, b in zip(self.seq, self.bounds):
9✔
402
            ret.append(b)
9✔
403
            tlist, amplist = seq
9✔
404
            ret.append(len(amplist))
9✔
405
            for t, amp in zip(tlist, amplist):
9✔
406
                ret.append(amp)
9✔
407
                mtlist, nlist = t
9✔
408
                ret.append(len(nlist))
9✔
409
                for fun, n in zip(mtlist, nlist):
9✔
410
                    ret.append(n)
9✔
411
                    ret.append(len(fun))
9✔
412
                    ret.extend(fun)
9✔
413

414
        return ret
9✔
415

416
    @staticmethod
9✔
417
    def fromlist(l, return_pointer=False):
9✔
418

419
        def _read(l, pos, size):
9✔
420
            try:
9✔
421
                return tuple(l[pos:pos + size]), pos + size
9✔
422
            except:
×
423
                raise ValueError('Invalid waveform format')
×
424

425
        w = Waveform()
9✔
426
        (w.max, w.min, w.start, w.stop, w.sample_rate,
9✔
427
         nseg), pos = _read(l, 0, 6)
428
        bounds = []
9✔
429
        seq = []
9✔
430
        for _ in range(nseg):
9✔
431
            (b, nsum), pos = _read(l, pos, 2)
9✔
432
            bounds.append(b)
9✔
433
            amp = []
9✔
434
            t = []
9✔
435
            for _ in range(nsum):
9✔
436
                (a, nmul), pos = _read(l, pos, 2)
9✔
437
                amp.append(a)
9✔
438
                nlst = []
9✔
439
                mt = []
9✔
440
                for _ in range(nmul):
9✔
441
                    (n, nfun), pos = _read(l, pos, 2)
9✔
442
                    nlst.append(n)
9✔
443
                    fun, pos = _read(l, pos, nfun)
9✔
444
                    mt.append(fun)
9✔
445
                t.append((tuple(mt), tuple(nlst)))
9✔
446
            seq.append((tuple(t), tuple(amp)))
9✔
447
        w.seq = tuple(seq)
9✔
448
        w.bounds = tuple(bounds)
9✔
449
        if return_pointer:
9✔
450
            return w, pos
×
451
        return w
9✔
452

453
    def totree(self):
9✔
454
        header = (self.max, self.min, self.start, self.stop, self.sample_rate)
9✔
455
        body = []
9✔
456

457
        for seq, b in zip(self.seq, self.bounds):
9✔
458
            tlist, amplist = seq
9✔
459
            new_seq = []
9✔
460
            for t, amp in zip(tlist, amplist):
9✔
461
                mtlist, nlist = t
9✔
462
                new_t = []
9✔
463
                for fun, n in zip(mtlist, nlist):
9✔
464
                    new_t.append((n, fun))
9✔
465
                new_seq.append((amp, tuple(new_t)))
9✔
466
            body.append((b, tuple(new_seq)))
9✔
467
        return header, tuple(body)
9✔
468

469
    @staticmethod
9✔
470
    def fromtree(tree):
9✔
471
        w = Waveform()
9✔
472
        header, body = tree
9✔
473

474
        (w.max, w.min, w.start, w.stop, w.sample_rate) = header
9✔
475
        bounds = []
9✔
476
        seqs = []
9✔
477
        for b, seq in body:
9✔
478
            bounds.append(b)
9✔
479
            amp_list = []
9✔
480
            t_list = []
9✔
481
            for amp, t in seq:
9✔
482
                amp_list.append(amp)
9✔
483
                n_list = []
9✔
484
                mt_list = []
9✔
485
                for n, mt in t:
9✔
486
                    n_list.append(n)
9✔
487
                    mt_list.append(mt)
9✔
488
                t_list.append((tuple(mt_list), tuple(n_list)))
9✔
489
            seqs.append((tuple(t_list), tuple(amp_list)))
9✔
490
        w.bounds = tuple(bounds)
9✔
491
        w.seq = tuple(seqs)
9✔
492
        return w
9✔
493

494
    def simplify(self):
9✔
495
        seq = [_simplify(self.seq[0])]
9✔
496
        bounds = [self.bounds[0]]
9✔
497
        for expr, b in zip(self.seq[1:], self.bounds[1:]):
9✔
498
            expr = _simplify(expr)
9✔
499
            if expr == seq[-1]:
9✔
500
                seq.pop()
×
501
                bounds.pop()
×
502
            seq.append(expr)
9✔
503
            bounds.append(b)
9✔
504
        return Waveform(tuple(bounds), tuple(seq))
9✔
505

506
    def filter(self, low=0, high=inf):
9✔
507
        seq = []
9✔
508
        for expr in self.seq:
9✔
509
            seq.append(_filter(expr, low, high))
9✔
510
        return Waveform(self.bounds, tuple(seq))
9✔
511

512
    def _comb(self, other, oper):
9✔
513
        bounds = []
9✔
514
        seq = []
9✔
515
        i1, i2 = 0, 0
9✔
516
        h1, h2 = len(self.bounds), len(other.bounds)
9✔
517
        while i1 < h1 or i2 < h2:
9✔
518
            s = oper(self.seq[i1], other.seq[i2])
9✔
519
            b = min(self.bounds[i1], other.bounds[i2])
9✔
520
            if seq and s == seq[-1]:
9✔
521
                bounds[-1] = b
9✔
522
            else:
523
                bounds.append(b)
9✔
524
                seq.append(s)
9✔
525
            if b == self.bounds[i1]:
9✔
526
                i1 += 1
9✔
527
            if b == other.bounds[i2]:
9✔
528
                i2 += 1
9✔
529
        return Waveform(tuple(bounds), tuple(seq))
9✔
530

531
    def __pow__(self, n):
9✔
532
        return Waveform(self.bounds, tuple(_pow(w, n) for w in self.seq))
9✔
533

534
    def __add__(self, other):
9✔
535
        if isinstance(other, Waveform):
9✔
536
            return self._comb(other, _add)
9✔
537
        else:
538
            return self + const(other)
×
539

540
    def __radd__(self, v):
9✔
541
        return const(v) + self
×
542

543
    def append(self, other):
9✔
544
        if not isinstance(other, Waveform):
×
545
            raise TypeError('connect Waveform by other type')
×
546
        if len(self.bounds) == 1:
×
547
            self.bounds = other.bounds
×
548
            self.seq = self.seq + other.seq[1:]
×
549
            return
×
550

551
        assert self.bounds[-2] <= other.bounds[
×
552
            0], f"connect waveforms with overlaped domain {self.bounds}, {other.bounds}"
553
        if self.bounds[-2] < other.bounds[0]:
×
554
            self.bounds = self.bounds[:-1] + other.bounds
×
555
            self.seq = self.seq + other.seq[1:]
×
556
        else:
557
            self.bounds = self.bounds[:-2] + other.bounds
×
558
            self.seq = self.seq[:-1] + other.seq[1:]
×
559

560
    def __ior__(self, other):
9✔
561
        return self | other
×
562

563
    def __or__(self, other):
9✔
564
        if isinstance(other, (int, float, complex)):
×
565
            other = const(other)
×
566
        w = self.marker + other.marker
×
567

568
        def _or(a, b):
×
569
            if a != _zero or b != _zero:
×
570
                return _one
×
571
            else:
572
                return _zero
×
573

574
        return self._comb(other, _or)
×
575

576
    def __iand__(self, other):
9✔
577
        return self & other
×
578

579
    def __and__(self, other):
9✔
580
        if isinstance(other, (int, float, complex)):
×
581
            other = const(other)
×
582
        w = self.marker + other.marker
×
583

584
        def _and(a, b):
×
585
            if a != _zero and b != _zero:
×
586
                return _one
×
587
            else:
588
                return _zero
×
589

590
        return self._comb(other, _and)
×
591

592
    @property
9✔
593
    def marker(self):
9✔
594
        w = self.simplify()
×
595
        return Waveform(w.bounds,
×
596
                        tuple(_zero if s == _zero else _one for s in w.seq))
597

598
    def mask(self, edge=0):
9✔
599
        w = self.marker
×
600
        in_wave = w.seq[0] == _zero
×
601
        bounds = []
×
602
        seq = []
×
603

604
        if w.seq[0] == _zero:
×
605
            in_wave = False
×
606
            b = w.bounds[0] - edge
×
607
            bounds.append(b)
×
608
            seq.append(_zero)
×
609

610
        for b, s in zip(w.bounds[1:], w.seq[1:]):
×
611
            if not in_wave and s != _zero:
×
612
                in_wave = True
×
613
                bounds.append(b + edge)
×
614
                seq.append(_one)
×
615
            elif in_wave and s == _zero:
×
616
                in_wave = False
×
617
                b = b - edge
×
618
                if b > bounds[-1]:
×
619
                    bounds.append(b)
×
620
                    seq.append(_zero)
×
621
                else:
622
                    bounds.pop()
×
623
                    bounds.append(b)
×
624
        return Waveform(tuple(bounds), tuple(seq))
×
625

626
    def __mul__(self, other):
9✔
627
        if isinstance(other, Waveform):
9✔
628
            return self._comb(other, _mul)
9✔
629
        else:
630
            return self * const(other)
9✔
631

632
    def __rmul__(self, v):
9✔
633
        return const(v) * self
9✔
634

635
    def __truediv__(self, other):
9✔
636
        if isinstance(other, Waveform):
9✔
637
            raise TypeError('division by waveform')
×
638
        else:
639
            return self * const(1 / other)
9✔
640

641
    def __neg__(self):
9✔
642
        return -1 * self
9✔
643

644
    def __sub__(self, other):
9✔
645
        return self + (-other)
9✔
646

647
    def __rsub__(self, v):
9✔
648
        return v + (-self)
×
649

650
    def __rshift__(self, time):
9✔
651
        return Waveform(
9✔
652
            tuple(round(bound + time, NDIGITS) for bound in self.bounds),
653
            tuple(_shift(expr, time) for expr in self.seq))
654

655
    def __lshift__(self, time):
9✔
656
        return self >> (-time)
9✔
657

658
    def _calc_parts(self, x, function_lib):
9✔
659
        range_list = np.searchsorted(x, self.bounds)
9✔
660
        parts = []
9✔
661
        start, stop = 0, 0
9✔
662
        dtype = float
9✔
663
        for i, stop in enumerate(range_list):
9✔
664
            if start < stop and self.seq[i] != _zero:
9✔
665
                part = np.clip(_calc(self.seq[i], x[start:stop], function_lib),
9✔
666
                               self.min, self.max)
667
                if (isinstance(part, complex) or isinstance(part, np.ndarray)
9✔
668
                        and isinstance(part[0], complex)):
669
                    dtype = complex
9✔
670
                parts.append((start, stop, part))
9✔
671
            start = stop
9✔
672
        return parts, dtype
9✔
673

674
    def _merge_parts(self, parts, out):
9✔
675
        lo = 0
×
676
        for start, stop, part in parts:
×
677
            i = bisect_left(out, (start, stop), lo, key=lambda x: x[0])
×
678
            j = bisect_left(out, (start, stop), i, key=lambda x: x[1])
×
679
            # assert start <= out[i][0]
680
            # assert stop <= out[j][1]
681
            if i == j:
×
682
                start2, stop2, part2 = out[j]
×
683
                if stop < start2:
×
684
                    out.insert(i, (start, stop, part))
×
685
                elif stop == start2:
×
686
                    if isinstance(part, np.ndarray) and isinstance(
×
687
                            part2, np.ndarray):
688
                        out[j] = (start, stop2, np.hstack([part, part2]))
×
689
                    elif isinstance(
×
690
                            part, (int, float, complex)) and isinstance(
691
                                part2,
692
                                (int, float, complex)) and part == part2:
693
                        out[i] = (start, stop2, part)
×
694
                    else:
695
                        out.insert(i, (start, stop, part))
×
696
                else:
697
                    match (isinstance(part, np.ndarray),
×
698
                           isinstance(part2, np.ndarray)):
699
                        case (True, True):
×
700
                            out[j] = (start, stop2,
×
701
                                      np.hstack([
702
                                          part[:start2 - start],
703
                                          part[start2 - start:stop - start] +
704
                                          part2[:stop - start2],
705
                                          part2[stop - start2:]
706
                                      ]))
707
                        case (True, False):
×
708
                            out[j] = (stop, stop2, part2)
×
709
                            part[start2 - start:] += part2
×
710
                            out.insert(j, (start, stop, part))
×
711
                        case (False, True):
×
712
                            part2[:stop - start2] += part
×
713
                            out.insert(j, (start, start2, part))
×
714
                        case (False, False):
×
715
                            out[j] = (stop, stop2, part2)
×
716
                            out.insert(j, (start2, stop, part + part2))
×
717
                            out.insert(j, (start, start2, part))
×
718
            else:
719
                for k in range(i, j):
×
720
                    if isinstance(part, (int, float, complex)):
×
721
                        out[k] = (out[k][0], out[k][1], out[k][2] + part)
×
722
                    else:
723
                        out[k] = (out[k][0], out[k][1],
×
724
                                  out[k][2] + part[k - i])
725

726
                out[i:j] = [(out[i][0], start, out[i][2]), (start, stop, part),
×
727
                            (stop, out[j][1], out[j][2])]
728
            lo = j
×
729

730
    def _fill_parts(self, parts, out):
9✔
731
        for start, stop, part in parts:
9✔
732
            out[start:stop] += part
9✔
733

734
    def __call__(self,
9✔
735
                 x,
736
                 frag=False,
737
                 out=None,
738
                 accumulate=False,
739
                 function_lib=None):
740
        if function_lib is None:
9✔
741
            function_lib = _baseFunc
9✔
742
        if isinstance(x, (int, float, complex)):
9✔
743
            return self.__call__(np.array([x]), function_lib=function_lib)[0]
×
744
        parts, dtype = self._calc_parts(x, function_lib)
9✔
745
        if not frag:
9✔
746
            if out is None:
9✔
747
                out = np.zeros_like(x, dtype=dtype)
9✔
748
            elif not accumulate:
×
749
                out *= 0
×
750
            self._fill_parts(parts, out)
9✔
751
        else:
752
            if out is None:
×
753
                return parts
×
754
            else:
755
                if not accumulate:
×
756
                    out.clear()
×
757
                    out.extend(parts)
×
758
                else:
759
                    self._merge_parts(parts, out)
×
760
        return out
9✔
761

762
    def __hash__(self):
9✔
763
        return hash((self.max, self.min, self.start, self.stop,
×
764
                     self.sample_rate, self.bounds, self.seq))
765

766
    def __eq__(self, o: object) -> bool:
9✔
767
        if isinstance(o, (int, float, complex)):
9✔
768
            return self == const(o)
×
769
        elif isinstance(o, Waveform):
9✔
770
            a = self.simplify()
9✔
771
            b = o.simplify()
9✔
772
            return a.seq == b.seq and a.bounds == b.bounds and (
9✔
773
                a.max, a.min, a.start, a.stop,
774
                a.sample_rate) == (b.max, b.min, b.start, b.stop,
775
                                   b.sample_rate)
776
        else:
777
            return False
×
778

779
    def _repr_latex_(self):
9✔
780
        parts = []
×
781
        start = -np.inf
×
782
        for end, wav in zip(self.bounds, self.seq):
×
783
            e_str = _wav_latex(wav)
×
784
            start_str = _num_latex(start)
×
785
            end_str = _num_latex(end)
×
786
            parts.append(e_str + r",~~&t\in" + f"({start_str},{end_str}" +
×
787
                         (']' if end < np.inf else ')'))
788
            start = end
×
789
        if len(parts) == 1:
×
790
            expr = ''.join(['f(t)=', *parts[0].split('&')])
×
791
        else:
792
            expr = '\n'.join([
×
793
                r"f(t)=\begin{cases}", (r"\\" + '\n').join(parts),
794
                r"\end{cases}"
795
            ])
796
        return "$$\n{}\n$$".format(expr)
×
797

798
    def _play(self, time_unit, volume=1.0):
9✔
799
        import pyaudio
×
800

801
        CHUNK = 1024
×
802
        RATE = 48000
×
803

804
        dynamic_volume = 1.0
×
805
        amp = 2**15 * 0.999 * volume * dynamic_volume
×
806

807
        p = pyaudio.PyAudio()
×
808
        try:
×
809
            stream = p.open(format=pyaudio.paInt16,
×
810
                            channels=1,
811
                            rate=RATE,
812
                            output=True)
813
            try:
×
814
                for data in self.sample(sample_rate=RATE / time_unit,
×
815
                                        chunk_size=CHUNK):
816
                    lim = np.abs(data).max()
×
817
                    if lim > 0 and dynamic_volume > 1.0 / lim:
×
818
                        dynamic_volume = 1.0 / lim
×
819
                        amp = 2**15 * 0.99 * volume * dynamic_volume
×
820
                    data = (amp * data).astype(np.int16)
×
821
                    stream.write(bytes(data.data))
×
822
            finally:
823
                stream.stop_stream()
×
824
                stream.close()
×
825
        finally:
826
            p.terminate()
×
827

828
    def play(self, time_unit=1, volume=1.0):
9✔
829
        import multiprocessing as mp
×
830
        p = mp.Process(target=self._play,
×
831
                       args=(time_unit, volume),
832
                       daemon=True)
833
        p.start()
×
834

835

836
class WaveVStack(Waveform):
9✔
837

838
    def __init__(self, wlist: list[Waveform] = []):
9✔
839
        self.wlist = wlist
9✔
840
        self.start = None
9✔
841
        self.stop = None
9✔
842
        self.sample_rate = None
9✔
843

844
    def __call__(self, x, frag=False, out=None, function_lib=None):
9✔
845
        assert frag is False, 'WaveVStack does not support frag mode'
×
846
        out = np.zeros_like(x, dtype=complex)
×
847
        for w in self.wlist:
×
848
            w(x, False, out, accumulate=True, function_lib=function_lib)
×
849
        return out.real
×
850

851
    def tolist(self):
9✔
852
        ret = [self.start, self.stop, self.sample_rate, len(self.wlist)]
×
853
        for w in self.wlist:
×
854
            ret.extend(w.tolist())
×
855
        return ret
×
856

857
    @staticmethod
9✔
858
    def fromlist(l):
9✔
859
        w = WaveVStack()
×
860
        w.start, w.stop, w.sample_rate, n = l[:4]
×
861
        l = l[4:]
×
862
        for _ in range(n):
×
863
            wav, pos = Waveform.fromlist(l, True)
×
864
            w.wlist.append(wav)
×
865
            l = l[pos:]
×
866
        return w
×
867

868
    def simplify(self):
9✔
869
        wav = wave_sum(*self.wlist)
9✔
870
        wav.start = self.start
9✔
871
        wav.stop = self.stop
9✔
872
        wav.sample_rate = self.sample_rate
9✔
873
        return wav
9✔
874

875
    def __rshift__(self, time):
9✔
876
        return WaveVStack([w >> time for w in self.wlist])
×
877

878
    def __add__(self, other):
9✔
879
        if isinstance(other, WaveVStack):
×
880
            return WaveVStack(self.wlist + other.wlist)
×
881
        elif isinstance(other, Waveform):
×
882
            return WaveVStack(self.wlist + [other])
×
883
        else:
884
            return WaveVStack(self.wlist + [const(other)])
×
885

886
    def __radd__(self, v):
9✔
887
        return self + v
×
888

889
    def __mul__(self, other):
9✔
890
        if isinstance(other, Waveform):
×
891
            other = other.simplify()
×
892
            return WaveVStack([w * other for w in self.wlist])
×
893
        else:
894
            return WaveVStack([w * other for w in self.wlist])
×
895

896
    def __rmul__(self, v):
9✔
897
        return self * v
×
898

899
    def _repr_latex_(self):
9✔
900
        return r"\sum_{i=1}^{" + f"{len(self.wlist)}" + r"}" + r"f_i(t)"
×
901

902

903
def wave_sum(*waves):
9✔
904
    if not waves:
9✔
905
        return Waveform()
×
906

907
    bounds = list(waves[0].bounds)
9✔
908
    seq = list(waves[0].seq)
9✔
909

910
    for wave in waves[1:]:
9✔
911
        lo = 0
9✔
912
        for b, s in zip(wave.bounds, wave.seq):
9✔
913
            i = bisect_left(bounds, b, lo)
9✔
914
            if bounds[i] != b:
9✔
915
                bounds.insert(i, b)
9✔
916
                seq.insert(i, seq[i])
9✔
917
            for j in range(lo + 1, i + 1):
9✔
918
                seq[j] = _add(seq[j], s)
9✔
919
            lo = i
9✔
920

921
    return Waveform(tuple(bounds), tuple(seq))
9✔
922

923

924
def play(data, rate=48000):
9✔
925
    import io
×
926

927
    import pyaudio
×
928

929
    CHUNK = 1024
×
930

931
    max_amp = np.max(np.abs(data))
×
932

933
    if max_amp > 1:
×
934
        data /= max_amp
×
935

936
    data = np.array(2**15 * 0.999 * data, dtype=np.int16)
×
937
    buff = io.BytesIO(data.data)
×
938
    p = pyaudio.PyAudio()
×
939

940
    try:
×
941
        stream = p.open(format=pyaudio.paInt16,
×
942
                        channels=1,
943
                        rate=rate,
944
                        output=True)
945
        try:
×
946
            while True:
×
947
                data = buff.read(CHUNK)
×
948
                if data:
×
949
                    stream.write(data)
×
950
                else:
951
                    break
×
952
        finally:
953
            stream.stop_stream()
×
954
            stream.close()
×
955
    finally:
956
        p.terminate()
×
957

958

959
_zero_waveform = Waveform()
9✔
960
_one_waveform = Waveform(seq=(_one, ))
9✔
961

962

963
def zero():
9✔
964
    return _zero_waveform
9✔
965

966

967
def one():
9✔
968
    return _one_waveform
×
969

970

971
def const(c):
9✔
972
    return Waveform(seq=(_const(1.0 * c), ))
9✔
973

974

975
__TypeIndex = 1
9✔
976
_baseFunc = {}
9✔
977
_derivativeBaseFunc = {}
9✔
978
_baseFunc_latex = {}
9✔
979

980

981
def registerBaseFunc(func, latex=None):
9✔
982
    global __TypeIndex
983
    Type = __TypeIndex
9✔
984
    __TypeIndex += 1
9✔
985

986
    _baseFunc[Type] = func
9✔
987
    _baseFunc_latex[Type] = latex
9✔
988

989
    return Type
9✔
990

991

992
def packBaseFunc():
9✔
993
    return pickle.dumps(_baseFunc)
×
994

995

996
def updateBaseFunc(buf):
9✔
997
    _baseFunc.update(pickle.loads(buf))
×
998

999

1000
def registerDerivative(Type, dFunc):
9✔
1001
    _derivativeBaseFunc[Type] = dFunc
9✔
1002

1003

1004
# register base function
1005
def _format_LINEAR(shift, *args):
9✔
1006
    if shift != 0:
×
1007
        shift = _num_latex(-shift)
×
1008
        if shift[0] == '-':
×
1009
            return f"(t{shift})"
×
1010
        else:
1011
            return f"(t+{shift})"
×
1012
    else:
1013
        return 't'
×
1014

1015

1016
def _format_GAUSSIAN(shift, *args):
9✔
1017
    sigma = _num_latex(args[0] / np.sqrt(2))
×
1018
    shift = _num_latex(-shift)
×
1019
    if shift != '0':
×
1020
        if shift[0] != '-':
×
1021
            shift = '+' + shift
×
1022
        if sigma == '1':
×
1023
            return ('\\exp\\left[-\\frac{\\left(t' + shift +
×
1024
                    '\\right)^2}{2}\\right]')
1025
        else:
1026
            return ('\\exp\\left[-\\frac{1}{2}\\left(\\frac{t' + shift + '}{' +
×
1027
                    sigma + '}\\right)^2\\right]')
1028
    else:
1029
        if sigma == '1':
×
1030
            return ('\\exp\\left(-\\frac{t^2}{2}\\right)')
×
1031
        else:
1032
            return ('\\exp\\left[-\\frac{1}{2}\\left(\\frac{t}{' + sigma +
×
1033
                    '}\\right)^2\\right]')
1034

1035

1036
def _format_SINC(shift, *args):
9✔
1037
    shift = _num_latex(-shift)
×
1038
    bw = _num_latex(args[0])
×
1039
    if shift != '0':
×
1040
        if shift[0] != '-':
×
1041
            shift = '+' + shift
×
1042
        if bw == '1':
×
1043
            return '\\mathrm{sinc}(t' + shift + ')'
×
1044
        else:
1045
            return '\\mathrm{sinc}[' + bw + '(t' + shift + ')]'
×
1046
    else:
1047
        if bw == '1':
×
1048
            return '\\mathrm{sinc}(t)'
×
1049
        else:
1050
            return '\\mathrm{sinc}(' + bw + 't)'
×
1051

1052

1053
def _format_COSINE(shift, *args):
9✔
1054
    freq = args[0] / 2 / np.pi
×
1055
    phase = -shift * freq
×
1056
    freq = _num_latex(freq)
×
1057
    if freq == '1':
×
1058
        freq = ''
×
1059
    phase = _num_latex(phase)
×
1060
    if phase == '0':
×
1061
        phase = ''
×
1062
    elif phase[0] != '-':
×
1063
        phase = '+' + phase
×
1064
    if phase != '':
×
1065
        return f'\\cos\\left[2\\pi({freq}t{phase})\\right]'
×
1066
    elif freq != '':
×
1067
        return f'\\cos\\left(2\\pi\\times {freq}t\\right)'
×
1068
    else:
1069
        return '\\cos(2\\pi t)'
×
1070

1071

1072
def _format_ERF(shift, *args):
9✔
1073
    if shift != 0:
×
1074
        return '\\mathrm{erf}(\\frac{t-' + f"{shift:g}" + '}{' + f'{args[0]:g}' + '})'
×
1075
    else:
1076
        return '\\mathrm{erf}(\\frac{t}{' + f'{args[0]:g}' + '})'
×
1077

1078

1079
def _format_COSH(shift, *args):
9✔
1080
    if shift != 0:
×
1081
        return '\\cosh(\\frac{t-' + f"{shift:g}" + '}{' + f'{1/args[0]:g}' + '})'
×
1082
    else:
1083
        return '\\cosh(\\frac{t}{' + f'{1/args[0]:g}' + '})'
×
1084

1085

1086
def _format_SINH(shift, *args):
9✔
1087
    if shift != 0:
×
1088
        return '\\sinh(\\frac{t-' + f"{shift:g}" + '}{' + f'{args[0]:g}' + '})'
×
1089
    else:
1090
        return '\\sinh(\\frac{t}{' + f'{args[0]:g}' + '})'
×
1091

1092

1093
def _format_EXP(shift, *args):
9✔
1094
    if shift != 0:
×
1095
        return '\\exp(-' + f'{args[0]:g}' + '(t-' + f"{shift:g}" + '))'
×
1096
    else:
1097
        return '\\exp(-' + f'{args[0]:g}' + 't)'
×
1098

1099

1100
LINEAR = registerBaseFunc(lambda t: t, _format_LINEAR)
9✔
1101
GAUSSIAN = registerBaseFunc(lambda t, std_sq2: np.exp(-(t / std_sq2)**2),
9✔
1102
                            _format_GAUSSIAN)
1103
ERF = registerBaseFunc(lambda t, std_sq2: special.erf(t / std_sq2),
9✔
1104
                       _format_ERF)
1105
COS = registerBaseFunc(lambda t, w: np.cos(w * t), _format_COSINE)
9✔
1106
SINC = registerBaseFunc(lambda t, bw: np.sinc(bw * t), _format_SINC)
9✔
1107
EXP = registerBaseFunc(lambda t, alpha: np.exp(alpha * t), _format_EXP)
9✔
1108
INTERP = registerBaseFunc(lambda t, start, stop, points: np.interp(
9✔
1109
    t, np.linspace(start, stop, len(points)), points))
1110
LINEARCHIRP = registerBaseFunc(lambda t, f0, f1, T, phi0: np.sin(
9✔
1111
    phi0 + 2 * np.pi * ((f1 - f0) / (2 * T) * t**2 + f0 * t)))
1112
EXPONENTIALCHIRP = registerBaseFunc(lambda t, f0, alpha, phi0: np.sin(
9✔
1113
    phi0 + 2 * pi * f0 * (np.exp(alpha * t) - 1) / alpha))
1114
HYPERBOLICCHIRP = registerBaseFunc(lambda t, f0, k, phi0: np.sin(
9✔
1115
    phi0 + 2 * np.pi * f0 / k * np.log(1 + k * t)))
1116
COSH = registerBaseFunc(lambda t, w: np.cosh(w * t), _format_COSH)
9✔
1117
SINH = registerBaseFunc(lambda t, w: np.sinh(w * t), _format_SINH)
9✔
1118

1119

1120
def _drag(t: np.ndarray, t0: float, freq: float, width: float, delta: float,
9✔
1121
          block_freq: float, phase: float):
1122

1123
    o = np.pi / width
×
1124
    Omega_x = np.sin(o * (t - t0))**2
×
1125
    wt = 2 * np.pi * (freq + delta) * t - (2 * np.pi * delta * t0 + phase)
×
1126

1127
    if block_freq is None or block_freq - delta == 0:
×
1128
        return Omega_x * np.cos(wt)
×
1129

1130
    b = 1 / np.pi / 2 / (block_freq - delta)
×
1131
    Omega_y = -b * o * np.sin(2 * o * (t - t0))
×
1132

1133
    return Omega_x * np.cos(wt) + Omega_y * np.sin(wt)
×
1134

1135

1136
def _format_DRAG(shift, *args):
9✔
1137
    return f"DRAG(...)"
×
1138

1139

1140
DRAG = registerBaseFunc(_drag, _format_DRAG)
9✔
1141

1142
# register derivative
1143
registerDerivative(LINEAR, lambda shift, *args: _one)
9✔
1144

1145
registerDerivative(
9✔
1146
    GAUSSIAN, lambda shift, *args: (((((LINEAR, shift),
1147
                                       (GAUSSIAN, *args, shift)), (1, 1)), ),
1148
                                    (-2 / args[0]**2, )))
1149

1150
registerDerivative(
9✔
1151
    ERF, lambda shift, *args: (((((GAUSSIAN, *args, shift), ), (1, )), ),
1152
                               (2 / args[0] / np.sqrt(pi), )))
1153

1154
registerDerivative(
9✔
1155
    COS, lambda shift, *args: (((((COS, args[0], shift - pi / args[0] / 2), ),
1156
                                 (1, )), ), (args[0], )))
1157

1158
registerDerivative(
9✔
1159
    SINC, lambda shift, *args:
1160
    (((((LINEAR, shift), (COS, *args, shift)), (-1, 1)),
1161
      (((LINEAR, shift), (COS, args[0], args[1] - pi / 2, shift)), (-2, 1))),
1162
     (1, -1 / args[0])))
1163

1164
registerDerivative(
9✔
1165
    EXP, lambda shift, *args: (((((EXP, *args, shift), ), (1, )), ),
1166
                               (args[0], )))
1167

1168
registerDerivative(
9✔
1169
    INTERP, lambda shift, start, stop, points:
1170
    (((((INTERP, start, stop, tuple(np.gradient(np.asarray(points))), shift),
1171
        ), (1, )), ), ((len(points) - 1) / (stop - start), )))
1172

1173
registerDerivative(
9✔
1174
    COSH, lambda shift, *args: (((((SINH, *args, shift), ), (1, )), ),
1175
                                (args[0], )))
1176

1177
registerDerivative(
9✔
1178
    SINH, lambda shift, *args: (((((COSH, *args, shift), ), (1, )), ),
1179
                                (args[0], )))
1180

1181

1182
def _d_LINEARCHIRP(shift, f0, f1, T, phi0):
9✔
1183
    tlist = (
×
1184
        (((LINEARCHIRP, f0, f1, T, phi0 + pi / 2, shift), ), (1, )),
1185
        (((LINEAR, shift), (LINEARCHIRP, f0, f1, T, phi0 + pi / 2, shift)),
1186
         (1, 1)),
1187
    )
1188
    alist = (2 * pi * f0, 2 * pi * (f1 - f0) / T)
×
1189

1190
    if f0 == 0:
×
1191
        return tlist[1:], alist[1:]
×
1192
    else:
1193
        return tlist, alist
×
1194

1195

1196
registerDerivative(LINEARCHIRP, _d_LINEARCHIRP)
9✔
1197
registerDerivative(
9✔
1198
    EXPONENTIALCHIRP, lambda shift, f0, alpha, phi0:
1199
    (((((EXP, alpha, shift),
1200
        (EXPONENTIALCHIRP, f0, alpha, phi0 + pi / 2, shift)), (1, 1)), ),
1201
     (2 * pi * f0, )))
1202
registerDerivative(
9✔
1203
    HYPERBOLICCHIRP, lambda shift, f0, k, phi0:
1204
    (((((LINEAR, shift - 1 / k),
1205
        (HYPERBOLICCHIRP, f0, k, phi0 + pi / 2, shift)), (-1, 1)), ),
1206
     (2 * pi * f0, )))
1207

1208

1209
def _D_base(m):
9✔
1210
    Type, *args, shift = m
×
1211
    return _derivativeBaseFunc[Type](shift, *args)
×
1212

1213

1214
def _D(x):
9✔
1215
    if _is_const(x):
×
1216
        return _zero
×
1217
    t_list, v_list = x
×
1218
    if len(v_list) == 1:
×
1219
        (m_list, n_list), v = t_list[0], v_list[0]
×
1220
        if len(m_list) == 1:
×
1221
            m, n = m_list[0], n_list[0]
×
1222
            if n == 1:
×
1223
                return _mul(_D_base(m), _const(v))
×
1224
            else:
1225
                return _mul(((((m, ), (n - 1, )), ), (n * v, )),
×
1226
                            _D(((((m, ), (1, )), ), (1, ))))
1227
        else:
1228
            a = (((m_list[:1], n_list[:1]), ), (v, ))
×
1229
            b = (((m_list[1:], n_list[1:]), ), (1, ))
×
1230
            return _add(_mul(a, _D(b)), _mul(_D(a), b))
×
1231
    else:
1232
        return _add(_D((t_list[:1], v_list[:1])), _D((t_list[1:], v_list[1:])))
×
1233

1234

1235
def D(wav):
9✔
1236
    """derivative
1237
    """
1238
    return Waveform(bounds=wav.bounds, seq=tuple(_D(x) for x in wav.seq))
×
1239

1240

1241
def convolve(a, b):
9✔
1242
    pass
×
1243

1244

1245
def sign():
9✔
1246
    return Waveform(bounds=(0, +inf), seq=(_const(-1), _one))
×
1247

1248

1249
def step(edge, type='erf'):
9✔
1250
    """
1251
    type: "erf", "cos", "linear"
1252
    """
1253
    if edge == 0:
9✔
1254
        return Waveform(bounds=(0, +inf), seq=(_zero, _one))
×
1255
    if type == 'cos':
9✔
1256
        rise = _add(_half,
×
1257
                    _mul(_half, _basic_wave(COS, pi / edge, shift=0.5 * edge)))
1258
        return Waveform(bounds=(round(-edge / 2,
×
1259
                                      NDIGITS), round(edge / 2,
1260
                                                      NDIGITS), +inf),
1261
                        seq=(_zero, rise, _one))
1262
    elif type == 'linear':
9✔
1263
        rise = _add(_half, _mul(_const(1 / edge), _basic_wave(LINEAR)))
9✔
1264
        return Waveform(bounds=(round(-edge / 2,
9✔
1265
                                      NDIGITS), round(edge / 2,
1266
                                                      NDIGITS), +inf),
1267
                        seq=(_zero, rise, _one))
1268
    else:
1269
        std_sq2 = edge / 5
9✔
1270
        # rise = _add(_half, _mul(_half, _basic_wave(ERF, std_sq2)))
1271
        rise = ((((), ()), (((ERF, std_sq2, 0), ), (1, ))), (0.5, 0.5))
9✔
1272
        return Waveform(bounds=(-round(edge, NDIGITS), round(edge,
9✔
1273
                                                             NDIGITS), +inf),
1274
                        seq=(_zero, rise, _one))
1275

1276

1277
def square(width, edge=0, type='erf'):
9✔
1278
    if width <= 0:
9✔
1279
        return zero()
×
1280
    if edge == 0:
9✔
1281
        return Waveform(bounds=(round(-0.5 * width,
9✔
1282
                                      NDIGITS), round(0.5 * width,
1283
                                                      NDIGITS), +inf),
1284
                        seq=(_zero, _one, _zero))
1285
    else:
1286
        return ((step(edge, type=type) << width / 2) -
9✔
1287
                (step(edge, type=type) >> width / 2))
1288

1289

1290
def gaussian(width, plateau=0.0):
9✔
1291
    if width <= 0 and plateau <= 0.0:
9✔
1292
        return zero()
×
1293
    # width is two times FWHM
1294
    # std_sq2 = width / (4 * np.sqrt(np.log(2)))
1295
    std_sq2 = width / 3.3302184446307908
9✔
1296
    # std is set to give total pulse area same as a square
1297
    #std_sq2 = width/np.sqrt(np.pi)
1298
    if round(0.5 * plateau, NDIGITS) <= 0.0:
9✔
1299
        return Waveform(bounds=(round(-0.75 * width,
9✔
1300
                                      NDIGITS), round(0.75 * width,
1301
                                                      NDIGITS), +inf),
1302
                        seq=(_zero, _basic_wave(GAUSSIAN, std_sq2), _zero))
1303
    else:
1304
        return Waveform(bounds=(round(-0.75 * width - 0.5 * plateau,
×
1305
                                      NDIGITS), round(-0.5 * plateau, NDIGITS),
1306
                                round(0.5 * plateau, NDIGITS),
1307
                                round(0.75 * width + 0.5 * plateau,
1308
                                      NDIGITS), +inf),
1309
                        seq=(_zero,
1310
                             _basic_wave(GAUSSIAN,
1311
                                         std_sq2,
1312
                                         shift=-0.5 * plateau), _one,
1313
                             _basic_wave(GAUSSIAN,
1314
                                         std_sq2,
1315
                                         shift=0.5 * plateau), _zero))
1316

1317

1318
def cos(w, phi=0):
9✔
1319
    if w == 0:
9✔
1320
        return const(np.cos(phi))
×
1321
    if w < 0:
9✔
1322
        phi = -phi
9✔
1323
        w = -w
9✔
1324
    return Waveform(seq=(_basic_wave(COS, w, shift=-phi / w), ))
9✔
1325

1326

1327
def sin(w, phi=0):
9✔
1328
    if w == 0:
9✔
1329
        return const(np.sin(phi))
×
1330
    if w < 0:
9✔
1331
        phi = -phi + pi
9✔
1332
        w = -w
9✔
1333
    return Waveform(seq=(_basic_wave(COS, w, shift=(pi / 2 - phi) / w), ))
9✔
1334

1335

1336
def exp(alpha):
9✔
1337
    if isinstance(alpha, complex):
9✔
1338
        if alpha.real == 0:
9✔
1339
            return cos(alpha.imag) + 1j * sin(alpha.imag)
×
1340
        else:
1341
            return exp(alpha.real) * (cos(alpha.imag) + 1j * sin(alpha.imag))
9✔
1342
    else:
1343
        return Waveform(seq=(_basic_wave(EXP, alpha), ))
9✔
1344

1345

1346
def sinc(bw):
9✔
1347
    if bw <= 0:
×
1348
        return zero()
×
1349
    width = 100 / bw
×
1350
    return Waveform(bounds=(round(-0.5 * width,
×
1351
                                  NDIGITS), round(0.5 * width, NDIGITS), +inf),
1352
                    seq=(_zero, _basic_wave(SINC, bw), _zero))
1353

1354

1355
def cosPulse(width, plateau=0.0):
9✔
1356
    # cos = _basic_wave(COS, 2*np.pi/width)
1357
    # pulse = _mul(_add(cos, _one), _half)
1358
    if round(0.5 * plateau, NDIGITS) > 0:
9✔
1359
        return square(plateau + 0.5 * width, edge=0.5 * width, type='cos')
×
1360
    if width <= 0:
9✔
1361
        return zero()
×
1362
    pulse = ((((), ()), (((COS, 6.283185307179586 / width, 0), ), (1, ))),
9✔
1363
             (0.5, 0.5))
1364
    return Waveform(bounds=(round(-0.5 * width,
9✔
1365
                                  NDIGITS), round(0.5 * width, NDIGITS), +inf),
1366
                    seq=(_zero, pulse, _zero))
1367

1368

1369
def hanning(width, plateau=0.0):
9✔
1370
    return cosPulse(width, plateau=plateau)
×
1371

1372

1373
def cosh(w):
9✔
1374
    return Waveform(seq=(_basic_wave(COSH, w), ))
×
1375

1376

1377
def sinh(w):
9✔
1378
    return Waveform(seq=(_basic_wave(SINH, w), ))
×
1379

1380

1381
def coshPulse(width, eps=1.0, plateau=0.0):
9✔
1382
    """Cosine hyperbolic pulse with the following im
1383

1384
    pulse edge shape:
1385
            cosh(eps / 2) - cosh(eps * t / T)
1386
    f(t) = -----------------------------------
1387
                  cosh(eps / 2) - 1
1388
    where T is the pulse width and eps is the pulse edge steepness.
1389
    The pulse is defined for t in [-T/2, T/2].
1390

1391
    In case of plateau > 0, the pulse is defined as:
1392
           | f(t + plateau/2)   if t in [-T/2 - plateau/2, - plateau/2]
1393
    g(t) = | 1                  if t in [-plateau/2, plateau/2]
1394
           | f(t - plateau/2)   if t in [plateau/2, T/2 + plateau/2]
1395

1396
    Parameters
1397
    ----------
1398
    width : float
1399
        Pulse width.
1400
    eps : float
1401
        Pulse edge steepness.
1402
    plateau : float
1403
        Pulse plateau.
1404
    """
1405
    if width <= 0 and plateau <= 0:
×
1406
        return zero()
×
1407
    w = eps / width
×
1408
    A = np.cosh(eps / 2)
×
1409

1410
    if plateau == 0.0 or round(-0.5 * plateau, NDIGITS) == round(
×
1411
            0.5 * plateau, NDIGITS):
1412
        pulse = ((((), ()), (((COSH, w, 0), ), (1, ))), (A / (A - 1),
×
1413
                                                         -1 / (A - 1)))
1414
        return Waveform(bounds=(round(-0.5 * width,
×
1415
                                      NDIGITS), round(0.5 * width,
1416
                                                      NDIGITS), +inf),
1417
                        seq=(_zero, pulse, _zero))
1418
    else:
1419
        raising = ((((), ()), (((COSH, w, -0.5 * plateau), ), (1, ))),
×
1420
                   (A / (A - 1), -1 / (A - 1)))
1421
        falling = ((((), ()), (((COSH, w, 0.5 * plateau), ), (1, ))),
×
1422
                   (A / (A - 1), -1 / (A - 1)))
1423
        return Waveform(bounds=(round(-0.5 * width - 0.5 * plateau,
×
1424
                                      NDIGITS), round(-0.5 * plateau, NDIGITS),
1425
                                round(0.5 * plateau, NDIGITS),
1426
                                round(0.5 * width + 0.5 * plateau,
1427
                                      NDIGITS), +inf),
1428
                        seq=(_zero, raising, _one, falling, _zero))
1429

1430

1431
def general_cosine(duration, *arg):
9✔
1432
    wav = zero()
×
1433
    arg = np.asarray(arg)
×
1434
    arg /= arg[::2].sum()
×
1435
    for i, a in enumerate(arg, start=1):
×
1436
        wav += a / 2 * (1 - (-1)**i * cos(i * 2 * pi / duration))
×
1437
    return wav * square(duration)
×
1438

1439

1440
def slepian(duration, *arg):
9✔
1441
    wav = zero()
×
1442
    arg = np.asarray(arg)
×
1443
    arg /= arg[::2].sum()
×
1444
    for i, a in enumerate(arg, start=1):
×
1445
        wav += a / 2 * (1 - (-1)**i * cos(i * 2 * pi / duration))
×
1446
    return wav * square(duration)
×
1447

1448

1449
def _poly(*a):
9✔
1450
    """
1451
    a[0] + a[1] * t + a[2] * t**2 + ...
1452
    """
1453
    t = []
9✔
1454
    amp = []
9✔
1455
    if a[0] != 0:
9✔
1456
        t.append(((), ()))
9✔
1457
        amp.append(a[0])
9✔
1458
    for n, a_ in enumerate(a[1:], start=1):
9✔
1459
        if a_ != 0:
9✔
1460
            t.append((((LINEAR, 0), ), (n, )))
9✔
1461
            amp.append(a_)
9✔
1462
    return tuple(t), tuple(a)
9✔
1463

1464

1465
def poly(a):
9✔
1466
    """
1467
    a[0] + a[1] * t + a[2] * t**2 + ...
1468
    """
1469
    return Waveform(seq=(_poly(*a), ))
9✔
1470

1471

1472
def t():
9✔
1473
    return Waveform(seq=((((LINEAR, 0), ), (1, )), (1, )))
×
1474

1475

1476
def drag(freq, width, plateau=0, delta=0, block_freq=None, phase=0, t0=0):
9✔
1477
    phase += pi * delta * (width + plateau)
×
1478
    if plateau <= 0:
×
1479
        return Waveform(seq=(_zero,
×
1480
                             _basic_wave(DRAG, t0, freq, width, delta,
1481
                                         block_freq, phase), _zero),
1482
                        bounds=(round(t0, NDIGITS), round(t0 + width,
1483
                                                          NDIGITS), +inf))
1484
    elif width <= 0:
×
1485
        w = 2 * pi * (freq + delta)
×
1486
        return Waveform(
×
1487
            seq=(_zero,
1488
                 _basic_wave(COS, w,
1489
                             shift=(phase + 2 * pi * delta * t0) / w), _zero),
1490
            bounds=(round(t0, NDIGITS), round(t0 + plateau, NDIGITS), +inf))
1491
    else:
1492
        w = 2 * pi * (freq + delta)
×
1493
        return Waveform(
×
1494
            seq=(_zero,
1495
                 _basic_wave(DRAG, t0, freq, width, delta, block_freq, phase),
1496
                 _basic_wave(COS, w, shift=(phase + 2 * pi * delta * t0) / w),
1497
                 _basic_wave(DRAG, t0 + plateau, freq, width, delta,
1498
                             block_freq,
1499
                             phase - 2 * pi * delta * plateau), _zero),
1500
            bounds=(round(t0, NDIGITS), round(t0 + width / 2, NDIGITS),
1501
                    round(t0 + width / 2 + plateau,
1502
                          NDIGITS), round(t0 + width + plateau,
1503
                                          NDIGITS), +inf))
1504

1505

1506
def chirp(f0, f1, T, phi0=0, type='linear'):
9✔
1507
    """
1508
    A chirp is a signal in which the frequency increases (up-chirp)
1509
    or decreases (down-chirp) with time. In some sources, the term
1510
    chirp is used interchangeably with sweep signal.
1511

1512
    type: "linear", "exponential", "hyperbolic"
1513
    """
1514
    if f0 == f1:
9✔
1515
        return sin(f0, phi0)
×
1516
    if T <= 0:
9✔
1517
        raise ValueError('T must be positive')
×
1518

1519
    if type == 'linear':
9✔
1520
        # f(t) = f1 * (t/T) + f0 * (1 - t/T)
1521
        return Waveform(bounds=(0, round(T, NDIGITS), +inf),
9✔
1522
                        seq=(_zero, _basic_wave(LINEARCHIRP, f0, f1, T,
1523
                                                phi0), _zero))
1524
    elif type in ['exp', 'exponential', 'geometric']:
9✔
1525
        # f(t) = f0 * (f1/f0) ** (t/T)
1526
        if f0 == 0:
9✔
1527
            raise ValueError('f0 must be non-zero')
×
1528
        alpha = np.log(f1 / f0) / T
9✔
1529
        return Waveform(bounds=(0, round(T, NDIGITS), +inf),
9✔
1530
                        seq=(_zero,
1531
                             _basic_wave(EXPONENTIALCHIRP, f0, alpha,
1532
                                         phi0), _zero))
1533
    elif type in ['hyperbolic', 'hyp']:
9✔
1534
        # f(t) = f0 * f1 / (f0 * (t/T) + f1 * (1-t/T))
1535
        if f0 * f1 == 0:
9✔
1536
            return const(np.sin(phi0))
×
1537
        k = (f0 - f1) / (f1 * T)
9✔
1538
        return Waveform(bounds=(0, round(T, NDIGITS), +inf),
9✔
1539
                        seq=(_zero, _basic_wave(HYPERBOLICCHIRP, f0, k,
1540
                                                phi0), _zero))
1541
    else:
1542
        raise ValueError(f'unknown type {type}')
×
1543

1544

1545
def interp(x, y):
9✔
1546
    seq, bounds = [_zero], [x[0]]
×
1547
    for x1, x2, y1, y2 in zip(x[:-1], x[1:], y[:-1], y[1:]):
×
1548
        if x2 == x1:
×
1549
            continue
×
1550
        seq.append(
×
1551
            _add(
1552
                _mul(_const((y2 - y1) / (x2 - x1)),
1553
                     _basic_wave(LINEAR, shift=x1)), _const(y1)))
1554
        bounds.append(x2)
×
1555
    bounds.append(inf)
×
1556
    seq.append(_zero)
×
1557
    return Waveform(seq=tuple(seq),
×
1558
                    bounds=tuple(round(b, NDIGITS)
1559
                                 for b in bounds)).simplify()
1560

1561

1562
def cut(wav, start=None, stop=None, head=None, tail=None, min=None, max=None):
9✔
1563
    offset = 0
×
1564
    if start is not None and head is not None:
×
1565
        offset = head - wav(np.array([1.0 * start]))[0]
×
1566
    elif stop is not None and tail is not None:
×
1567
        offset = tail - wav(np.array([1.0 * stop]))[0]
×
1568
    wav = wav + offset
×
1569

1570
    if start is not None:
×
1571
        wav = wav * (step(0) >> start)
×
1572
    if stop is not None:
×
1573
        wav = wav * ((1 - step(0)) >> stop)
×
1574
    if min is not None:
×
1575
        wav.min = min
×
1576
    if max is not None:
×
1577
        wav.max = max
×
1578
    return wav
×
1579

1580

1581
def function(fun, *args, start=None, stop=None):
9✔
1582
    TYPEID = registerBaseFunc(fun)
×
1583
    seq = (_basic_wave(TYPEID, *args), )
×
1584
    wav = Waveform(seq=seq)
×
1585
    if start is not None:
×
1586
        wav = wav * (step(0) >> start)
×
1587
    if stop is not None:
×
1588
        wav = wav * ((1 - step(0)) >> stop)
×
1589
    return wav
×
1590

1591

1592
def samplingPoints(start, stop, points):
9✔
1593
    return Waveform(bounds=(round(start, NDIGITS), round(stop, NDIGITS), inf),
×
1594
                    seq=(_zero, _basic_wave(INTERP, start, stop,
1595
                                            tuple(points)), _zero))
1596

1597

1598
def mixing(I,
9✔
1599
           Q=None,
1600
           *,
1601
           phase=0.0,
1602
           freq=0.0,
1603
           ratioIQ=1.0,
1604
           phaseDiff=0.0,
1605
           block_freq=None,
1606
           DRAGScaling=None):
1607
    """SSB or envelope mixing
1608
    """
1609
    if Q is None:
9✔
1610
        I = I
9✔
1611
        Q = zero()
9✔
1612

1613
    w = 2 * pi * freq
9✔
1614
    if freq != 0.0:
9✔
1615
        # SSB mixing
1616
        Iout = I * cos(w, -phase) + Q * sin(w, -phase)
9✔
1617
        Qout = -I * sin(w, -phase + phaseDiff) + Q * cos(w, -phase + phaseDiff)
9✔
1618
    else:
1619
        # envelope mixing
1620
        Iout = I * np.cos(-phase) + Q * np.sin(-phase)
9✔
1621
        Qout = -I * np.sin(-phase) + Q * np.cos(-phase)
9✔
1622

1623
    # apply DRAG
1624
    if block_freq is not None and block_freq != freq:
9✔
1625
        a = block_freq / (block_freq - freq)
×
1626
        b = 1 / (block_freq - freq)
×
1627
        I = a * Iout + b / (2 * pi) * D(Qout)
×
1628
        Q = a * Qout - b / (2 * pi) * D(Iout)
×
1629
        Iout, Qout = I, Q
×
1630
    elif DRAGScaling is not None and DRAGScaling != 0:
9✔
1631
        # 2 * pi * scaling * (freq - block_freq) = 1
1632
        I = (1 - w * DRAGScaling) * Iout - DRAGScaling * D(Qout)
×
1633
        Q = (1 - w * DRAGScaling) * Qout + DRAGScaling * D(Iout)
×
1634
        Iout, Qout = I, Q
×
1635

1636
    Qout = ratioIQ * Qout
9✔
1637

1638
    return Iout, Qout
9✔
1639

1640

1641
__all__ = [
9✔
1642
    'D', 'Waveform', 'chirp', 'const', 'cos', 'cosh', 'coshPulse', 'cosPulse',
1643
    'cut', 'drag', 'exp', 'function', 'gaussian', 'general_cosine', 'hanning',
1644
    'interp', 'mixing', 'one', 'poly', 'registerBaseFunc',
1645
    'registerDerivative', 'samplingPoints', 'sign', 'sin', 'sinc', 'sinh',
1646
    'square', 'step', 't', 'zero'
1647
]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc