• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

feihoo87 / waveforms / 6924783199

20 Nov 2023 02:30AM UTC coverage: 43.548%. First build
6924783199

push

github

feihoo87
fix WaveVStack.__getstate___

3 of 13 new or added lines in 2 files covered. (23.08%)

7523 of 17275 relevant lines covered (43.55%)

3.91 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

59.33
/waveforms/waveform.py
1
import pickle
9✔
2
from bisect import bisect_left
9✔
3
from fractions import Fraction
9✔
4
from itertools import chain, product
9✔
5
from math import comb
9✔
6

7
import numpy as np
9✔
8
import scipy.special as special
9✔
9
from numpy import e, inf, pi
9✔
10
from scipy.signal import sosfilt
9✔
11

12
# Number of decimal digits kept when segment bounds are shifted in time
# (used by ``round(bound + time, NDIGITS)`` in ``Waveform.__rshift__``).
NDIGITS = 15

# The empty expression: no terms and no coefficients, i.e. the constant 0.
# Every expression is a pair of parallel tuples ``(terms, coefficients)``.
_zero = ((), ())
15

16

17
def _const(c):
9✔
18
    if c == 0:
9✔
19
        return _zero
9✔
20
    return (((), ()), ), (c, )
9✔
21

22

23
# Frequently used constant expressions; each is a single term with an empty
# monomial built by ``_const``.
_one = _const(1.0)
_half = _const(1 / 2)
_two = _const(2.0)
_pi = _const(pi)
_two_pi = _const(2 * pi)
_half_pi = _const(pi / 2)
29

30

31
def _is_const(x):
    """Return True if expression ``x`` is zero or a single factor-free term."""
    if x == _zero:
        return True
    # A non-zero constant has exactly one term whose monomial is empty.
    return x[0] == (((), ()), )
33

34

35
def _basic_wave(Type, *args, shift=0):
9✔
36
    return ((((Type, *args, shift), ), (1, )), ), (1.0, )
9✔
37

38

39
def _insert_type_value_pair(t_list, v_list, t, v, lo, hi):
9✔
40
    i = bisect_left(t_list, t, lo, hi)
9✔
41
    if i < hi and t_list[i] == t:
9✔
42
        v += v_list[i]
9✔
43
        if v == 0:
9✔
44
            t_list.pop(i)
9✔
45
            v_list.pop(i)
9✔
46
            return i, hi - 1
9✔
47
        else:
48
            v_list[i] = v
9✔
49
            return i, hi
9✔
50
    else:
51
        t_list.insert(i, t)
9✔
52
        v_list.insert(i, v)
9✔
53
        return i, hi + 1
9✔
54

55

56
def _mul(x, y):
    """Return the product of expressions ``x`` and ``y``.

    Every pair of terms is multiplied: monomials are combined with ``_add``,
    which merges factor lists and adds exponents, and coefficients are
    multiplied.  The products are accumulated into a sorted term list in which
    equal monomials are merged; products with a zero coefficient are skipped.
    """
    t_list, v_list = [], []
    hi = 0
    for (t1, t2), (v1, v2) in zip(product(x[0], y[0]),
                                  product(x[1], y[1])):
        v = v1 * v2
        if v == 0:
            continue
        # Pairwise products are NOT generated in sorted order, so every
        # insertion must search the whole list.  Reusing the previous
        # insertion index as the lower bound (valid in ``_add``, whose input
        # is already sorted) can insert out of order or miss an equal
        # monomial that should be merged.
        _, hi = _insert_type_value_pair(t_list, v_list, _add(t1, t2), v, 0,
                                        hi)
    return tuple(t_list), tuple(v_list)
68

69

70
def _add(x, y):
    """Return the sum of expressions ``x`` and ``y``.

    The terms of ``y`` are merged into a copy of ``x`` with
    ``_insert_type_value_pair``.  Both inputs are expected to be sorted,
    which lets each search start at the previous insertion index.  The same
    routine multiplies monomials: applied to two ``(factors, exponents)``
    pairs, it merges the factors and adds their exponents.
    """
    terms = list(x[0])
    coeffs = list(x[1])
    lo, hi = 0, len(terms)
    y_terms, y_coeffs = y
    for term, coeff in zip(y_terms, y_coeffs):
        lo, hi = _insert_type_value_pair(terms, coeffs, term, coeff, lo, hi)
    return tuple(terms), tuple(coeffs)
77

78

79
def _shift(x, time):
    """Delay expression ``x`` by ``time``.

    Every factor ``(Type, *args, shift)`` gets ``shift + time``; exponents
    and coefficients are unchanged, and constants are returned as-is.
    """
    if _is_const(x):
        return x
    shifted_terms = tuple(
        (tuple((Type, *args, shift + time) for Type, *args, shift in mtlist),
         nlist) for mtlist, nlist in x[0])
    return shifted_terms, x[1]
91

92

93
def _pow(x, n):
    """Raise expression ``x`` to the power ``n``.

    Zero stays zero, ``n == 0`` gives one, and constants are raised directly.
    A single-term expression accepts any exponent: every factor exponent is
    multiplied by ``n`` and the coefficient is raised to ``n``.  A multi-term
    expression is expanded by repeated multiplication and therefore needs a
    positive integer ``n`` (a Python or NumPy integer).

    Raises:
        ValueError: if ``x`` has several terms and ``n`` is not a positive
            integer.
    """
    if x == _zero:
        return _zero
    if n == 0:
        return _one
    if _is_const(x):
        return _const(x[1][0]**n)

    if len(x[0]) == 1:
        t_list, v_list = [], []
        for (mtlist, pre_nlist), v in zip(*x):
            t_list.append((mtlist, tuple(n * m for m in pre_nlist)))
            v_list.append(v**n)
        return tuple(t_list), tuple(v_list)

    import operator

    # operator.index accepts any integer-like value (including NumPy
    # integers) and rejects floats.  An explicit raise replaces an ``assert``
    # that would be stripped under ``python -O``.
    try:
        count = operator.index(n)
    except TypeError:
        count = None
    if count is None or count <= 0:
        raise ValueError('power of a multi-term expression requires a '
                         f'positive integer exponent, got {n!r}')
    ret = _one
    for _ in range(count):
        ret = _mul(ret, x)
    return ret
116

117

118
def _cos_power_n(x, n):
    """Expand ``cos(w*(t - shift))**n`` into a sum of cosines of multiples of ``w``.

    Uses ``cos^n(a) = 2^-n * sum_k C(n, k) cos((n - 2k) a)``, pairing ``k`` with
    ``n - k``.  Paired terms get coefficient ``C(n, k) / 2^(n-1)``; the
    unpaired middle term (even ``n``) is the constant ``C(n, n/2) / 2^n``.
    ``x`` is a COS factor ``(COS, w, shift)``.
    """
    _, w, shift = x
    ret = _zero
    for k in range(n // 2 + 1):
        order = n - 2 * k
        if order == 0:
            expr = _const(comb(n, k) / 2**n)
        else:
            factor = (COS, order * w, shift)
            term = ((factor, ), (1, ))
            expr = (term, ), (comb(n, k) / 2**(n - 1), )
        ret = _add(ret, expr)
    return ret
130

131

132
def _trigMul_t(x, y, v):
    """Return ``v * cos(w1 (t - t1)) * cos(w2 (t - t2))`` as a sum of cosines.

    Uses cos(a)cos(b) = 0.5*cos(a+b)+0.5*cos(a-b).  ``x`` and ``y`` are COS
    factors ``(COS, w, shift)``.  When the frequencies are equal, the
    difference term is the constant ``v * cos(w1 t1 - w2 t2) / 2``, which is
    omitted if it is exactly zero.  Terms are returned lower frequency first.
    """
    _, wa, ta = x
    _, wb, tb = y
    if wa < wb:
        wa, wb, ta, tb = wb, wa, tb, ta
    half = 0.5 * v
    total = wa + wb
    high = (COS, total, (wa * ta + wb * tb) / total)
    if wa != wb:
        diff = wa - wb
        low = (COS, diff, (wa * ta - wb * tb) / diff)
        if low[1] > high[1]:
            low, high = high, low
        return (((low, ), (1, )), ((high, ), (1, ))), (half, half)
    dc = v * np.cos(wa * ta - wb * tb) / 2
    if dc == 0:
        return (((high, ), (1, )), ), (half, )
    return (((), ()), ((high, ), (1, ))), (dc, half)
151

152

153
def _trigMul(x, y):
    """Multiply expressions ``x`` and ``y``, expanding products of cosines.

    Constants are handled by ``_mul``.  Otherwise, for each pair of terms, the
    non-cosine factors are multiplied with ``_mul`` and the cosine factors are
    combined: one cosine is kept as is, and two are turned into a sum of
    cosines by ``_trigMul_t``, so every resulting term has at most one cosine.
    Exponents of cosine factors are not inspected; the callers in this module
    pass first-power cosines (as produced by ``_cos_power_n`` and
    ``_trigMul_t``).

    Raises:
        ValueError: if a product term contains more than two cosine factors
            (previously such cosines were silently dropped).
    """
    if _is_const(x) or _is_const(y):
        return _mul(x, y)
    ret = _zero
    for (t1, t2), (v1, v2) in zip(product(x[0], y[0]), product(x[1], y[1])):
        v = v1 * v2
        rest = _one
        cosines = []
        for mt, n in zip(chain(t1[0], t2[0]), chain(t1[1], t2[1])):
            if mt[0] == COS:
                cosines.append(mt)
            else:
                rest = _mul(rest, ((((mt, ), (n, )), ), (1, )))
        if not cosines:
            expr = _mul(rest, _const(v))
        elif len(cosines) == 1:
            expr = _mul(rest, ((((cosines[0], ), (1, )), ), (v, )))
        elif len(cosines) == 2:
            expr = _mul(rest, _trigMul_t(cosines[0], cosines[1], v))
        else:
            raise ValueError(
                f'cannot reduce a product of {len(cosines)} cosine factors')
        ret = _add(ret, expr)
    return ret
176

177

178
def _exp_trig_Reduce(mtlist, v):
    """Canonicalise one term ``v * prod(f_i ** n_i)``.

    ``mtlist`` is the term's ``(factors, exponents)`` pair.  Factors are
    handled as follows:

    * Cosine powers are expanded and multiplied into a sum of single cosines.
    * EXP factors are merged into one ``exp(alpha * (t - shift))``.
    * A GAUSSIAN raised to ``n != 1`` becomes one GAUSSIAN with its width
      parameter divided by ``sqrt(n)``.

    Returns the resulting expression.

    The EXP merge relies on the same assumption as the shift formula: the EXP
    basis function evaluates ``exp(alpha * t)`` (NOTE(review): defined
    outside this view).  Under that assumption the product of the EXP factors
    is ``exp(alpha * t - offset)``, where ``alpha = sum(n * rate)`` and
    ``offset = sum(n * rate * shift)``.
    """
    trig = _one
    alpha = 0  # sum of n * rate over EXP factors
    offset = 0  # sum of n * rate * shift over EXP factors
    ml, nl = [], []
    for mt, n in zip(*mtlist):
        if mt[0] == COS:
            trig = _trigMul(trig, _cos_power_n(mt, n))
        elif mt[0] == EXP:
            alpha += n * mt[1]
            offset += n * mt[1] * mt[-1]
        elif mt[0] == GAUSSIAN and n != 1:
            ml.append((mt[0], mt[1] / np.sqrt(n), mt[2]))
            nl.append(1)
        else:
            ml.append(mt)
            nl.append(n)

    # When the rates cancel, the EXP product is the constant exp(-offset).
    # It must be folded into the coefficient instead of being discarded.
    if alpha == 0 and offset != 0:
        v = v * np.exp(-offset)
    ret = (((tuple(ml), tuple(nl)), ), (v, ))

    if alpha != 0:
        ret = _mul(ret, _basic_wave(EXP, alpha, shift=offset / alpha))

    return _mul(ret, trig)
205

206

207
def _get_freq(t):
    """Split the (single) cosine factor off monomial ``t``.

    Returns ``(freq, shift, rest)``.  ``freq``/``shift`` come from the COS
    factor, or are 0 when there is none; ``rest`` is the monomial without it.

    Raises:
        ValueError: if a second cosine with a non-zero frequency is found.
    """
    freq, shift = 0, 0
    others, powers = [], []
    for factor, n in zip(*t):
        if factor[0] != COS:
            others.append(factor)
            powers.append(n)
            continue
        if freq != 0:
            raise ValueError("run _exp_trig_Reduce first")
        freq, shift = factor[1], factor[-1]
    return freq, shift, (tuple(others), tuple(powers))
221

222

223
def _merge_cos_components(amp0, shift0, amp1, shift1, freq):
    """Return ``(amp, shift)`` with ``amp*cos(freq*(t-shift))`` equal to the sum of two such cosines."""
    a = amp0 * np.cos(freq * shift0) + amp1 * np.cos(freq * shift1)
    b = amp0 * np.sin(freq * shift0) + amp1 * np.sin(freq * shift1)
    return np.sqrt(a**2 + b**2), np.arctan2(b, a) / freq


def _simplify(expr, eps):
    """Return a canonical form of ``expr`` with negligible terms removed.

    Each term is reduced with ``_exp_trig_Reduce``.  Terms that share the same
    non-cosine monomial and cosine frequency are then merged:

    * Constant (frequency 0) parts are summed.
    * Cosines are combined as phasors, separately for the real and imaginary
      parts of the coefficients.

    A constant is dropped when its magnitude is below ``eps``.  For cosines,
    each real/imaginary component below ``eps`` is dropped.
    """
    merged = {}
    for term, coeff in zip(*expr):
        for t, v in zip(*_exp_trig_Reduce(term, coeff)):
            freq, shift, t = _get_freq(t)
            v_r, v_i, shift_r, shift_i = v.real, v.imag, shift, shift
            key = (t, freq)
            if key in merged:
                v0_r, shift0_r, v0_i, shift0_i = merged[key]
                if freq == 0:
                    v_r, v_i = v_r + v0_r, v_i + v0_i
                else:
                    v_r, shift_r = _merge_cos_components(
                        v0_r, shift0_r, v_r, shift_r, freq)
                    v_i, shift_i = _merge_cos_components(
                        v0_i, shift0_i, v_i, shift_i, freq)
            merged[key] = v_r, shift_r, v_i, shift_i

    ret = _zero
    for (t, freq), (v_r, shift_r, v_i, shift_i) in merged.items():
        if freq == 0:
            # Test this entry's own accumulated amplitude.  Testing a leftover
            # loop variable made the outcome depend on whichever term happened
            # to be reduced last.
            if abs(v_r + 1j * v_i) < eps:
                continue
            value = v_r if v_i == 0 else v_r + 1j * v_i
            ret = _add(ret, ((t, ), (value, )))
            continue

        cos_terms, coeffs = [], []
        if abs(v_r) >= eps:
            cos_terms.append((((COS, freq, shift_r), ), (1, )))
            coeffs.append(v_r)
        if abs(v_i) >= eps:
            cos_terms.append((((COS, freq, shift_i), ), (1, )))
            coeffs.append(v_i * 1j)
        if not cos_terms:
            continue
        ret = _add(ret, _mul(((t, ), (1, )), (tuple(cos_terms), tuple(coeffs))))
    return ret
271

272

273
def _filter(expr, low, high, eps):
    """Keep only the terms of ``expr`` whose frequency lies in ``[low, high)``.

    ``expr`` is simplified first.  A term's frequency is the frequency of its
    first COS factor.  Terms without a cosine count as frequency 0 and are
    kept only when ``low <= 0``.
    """
    simplified = _simplify(expr, eps)
    ret = _zero
    for term, coeff in zip(*simplified):
        first_cos = next((mt for mt, _ in zip(*term) if mt[0] == COS), None)
        if first_cos is None:
            keep = low <= 0
        else:
            keep = low <= first_cos[1] < high
        if keep:
            ret = _add(ret, ((term, ), (coeff, )))
    return ret
290

291

292
def _apply(function_lib, func_id, x, shift, *args):
9✔
293
    return function_lib[func_id](x - shift, *args)
9✔
294

295

296
def _calc(wav, x, function_lib):
    """Evaluate expression ``wav`` at ``x`` using ``function_lib``.

    Each distinct factor is evaluated once per call (memoised in a local
    dict), then raised to its exponent and multiplied into its term.
    """
    cache = {}

    def factor_value(mt):
        # Evaluate factor ``mt`` at most once for this ``x``.
        if mt not in cache:
            func_id, *args, shift = mt
            cache[mt] = _apply(function_lib, func_id, x, shift, *args)
        return cache[mt]

    def term_value(t):
        # Product of the factors of monomial ``t`` raised to their exponents.
        prod = 1
        for mt, n in zip(*t):
            val = factor_value(mt)
            prod = prod * val if n == 1 else prod * val**n
        return prod

    total = 0
    for t, v in zip(*wav):
        total = total + v * term_value(t)
    return total
315

316

317
def _test_spec_num(num, spec):
9✔
318
    x = Fraction(num / spec).limit_denominator(1000000000)
×
319
    if x.denominator <= 24:
×
320
        return True, x, 1
×
321
    x = Fraction(spec * num).limit_denominator(1000000000)
×
322
    if x.denominator <= 24:
×
323
        return True, x, -1
×
324
    return False, x, 0
×
325

326

327
def _spec_num_latex(num):
    """Render real ``num`` as LaTeX, recognising simple multiples of common constants.

    The constants tried, in order, are 1, sqrt(2), sqrt(3), sqrt(5), log 2,
    log 3, log 5, e, pi, pi^2 and sqrt(pi).  For each one, ``_test_spec_num``
    checks whether ``num`` equals a fraction with denominator <= 24 times the
    constant or its reciprocal.  If none match, ``f"{num:g}"`` is returned.
    """
    for spec, spec_latex in [(1, ''), (np.sqrt(2), '\\sqrt{2}'),
                             (np.sqrt(3), '\\sqrt{3}'),
                             (np.sqrt(5), '\\sqrt{5}'),
                             (np.log(2), '\\log{2}'), (np.log(3), '\\log{3}'),
                             (np.log(5), '\\log{5}'), (np.e, 'e'),
                             (np.pi, '\\pi'), (np.pi**2, '\\pi^2'),
                             (np.sqrt(np.pi), '\\sqrt{\\pi}')]:
        flag, x, sign = _test_spec_num(num, spec)
        if not flag:
            continue
        if sign < 0:
            spec_latex = f"\\frac{{{1}}}{{{spec_latex}}}"
        if x.denominator == 1:
            # Unit coefficients are implied ("\pi", "-\pi"); with no constant
            # attached they must be spelled out as "1" / "-1".
            if x.numerator == 1:
                return spec_latex or "1"
            if x.numerator == -1:
                return f"-{spec_latex}" if spec_latex else "-1"
            return f"{x.numerator:g}{spec_latex}"
        if x.numerator < 0:
            return f"-\\frac{{{-x.numerator}}}{{{x.denominator}}}{spec_latex}"
        return f"\\frac{{{x.numerator}}}{{{x.denominator}}}{spec_latex}"
    return f"{num:g}"
350

351

352
def _num_latex(num):
    """Render a real or complex number as LaTeX.

    Infinities map to the LaTeX infinity symbol.  Complex values are rendered
    as ``(re + im j)``; real parts go through ``_spec_num_latex``.  Plain
    scientific notation from its ``:g`` fallback (e.g. ``1.5e-07``) is
    rewritten as a power of ten.
    """
    if num == -np.inf:
        return r"-\infty"
    elif num == np.inf:
        return r"\infty"
    if num.imag > 0:
        return f"\\left({_num_latex(num.real)}+{_num_latex(num.imag)}j\\right)"
    elif num.imag < 0:
        return f"\\left({_num_latex(num.real)}-{_num_latex(-num.imag)}j\\right)"
    s = _spec_num_latex(num.real)
    if s == '' and round(num.real) == 1:
        return '1'
    if "e" in s:
        # Only genuine float strings such as "1.5e-07" are exponents.
        # Renderings of Euler's number ("e", "2e", "\frac{1}{e}") also
        # contain an "e" and used to crash the split/float below.
        try:
            float(s)
        except ValueError:
            return s
        a, n = s.split("e")
        n = float(n)
        s = f"{a} \\times 10^{{{n:g}}}"
    return s
369

370

371
def _fun_latex(fun):
    """Render one factor ``(func_id, *args, shift)`` as LaTeX.

    If ``_baseFunc_latex[func_id]`` provides a formatter, it is called with
    ``(shift, *args)``.  Otherwise a generic ``Func<id>(t..., ...)`` is
    produced.  Factors are evaluated at ``x - shift`` (see ``_apply``), so
    the generic argument is printed as ``t - shift``.
    """
    funID, *args, shift = fun
    formatter = _baseFunc_latex[funID]
    if formatter is not None:
        return formatter(shift, *args)
    # Render the offset added to t, which is -shift.
    offset = _num_latex(-shift)
    if offset == "0":
        offset = ""
    elif offset[0] != '-':
        offset = "+" + offset
    return r"\mathrm{Func}" + f"{funID}(t{offset}, ...)"
381

382

383
def _wav_latex(wav):
    """Render one segment expression as a LaTeX sum of products.

    Each term is rendered as its coefficient (omitted for 1, shown as a bare
    minus sign for -1) followed by its factors, with exponents other than 1
    written as superscripts.  Terms are joined with "+" unless they already
    start with "-".
    """
    # ``_zero`` and ``_is_const`` are defined in this module.  The former
    # self-import ``from waveforms.waveform import ...`` was redundant and
    # breaks when the module is loaded under another package path.
    if wav == _zero:
        return "0"
    elif _is_const(wav):
        # Render constants like every other coefficient.
        return _num_latex(wav[1][0])

    sum_expr = []
    for mul, amp in zip(*wav):
        if mul == ((), ()):
            sum_expr.append(_num_latex(amp))
            continue
        mul_expr = []
        amp = _num_latex(amp)
        if amp == "-1":
            mul_expr.append("-")
        elif amp != "1":
            mul_expr.append(amp)
        for fun, n in zip(*mul):
            fun_expr = _fun_latex(fun)
            if n != 1:
                mul_expr.append(fun_expr + "^{" + f"{n}" + "}")
            else:
                mul_expr.append(fun_expr)
        sum_expr.append(''.join(mul_expr))

    ret = sum_expr[0]
    for expr in sum_expr[1:]:
        ret += expr if expr[0] == '-' else "+" + expr
    return ret
415

416

417
class Waveform:
    """Piecewise waveform built from symbolic expressions.

    ``bounds`` is an increasing tuple of segment end points, the last one
    normally ``+inf``.  ``seq`` holds one expression per segment, and segment
    ``i`` lies between ``bounds[i-1]`` and ``bounds[i]`` (the first starts at
    ``-inf``).

    An expression is a pair of parallel tuples ``(terms, coefficients)``.
    Each term is a pair ``(factors, exponents)``, and each factor is
    ``(func_id, *args, shift)``.

    Evaluated values are clipped to ``[min, max]``.  ``start``, ``stop``,
    ``sample_rate`` and ``filters`` (an ``(sos, initial)`` pair) configure
    :meth:`sample`.
    """
    __slots__ = ('bounds', 'seq', 'max', 'min', 'start', 'stop', 'sample_rate',
                 'filters', 'label')

    def __init__(self, bounds=(+inf, ), seq=(_zero, ), min=-inf, max=inf):
        self.bounds = bounds
        self.seq = seq
        self.max = max
        self.min = min
        self.start = None
        self.stop = None
        self.sample_rate = None
        self.filters = None
        self.label = None

    def _begin(self):
        """Return the left edge of the first non-zero segment (``inf`` if all zero)."""
        for i, s in enumerate(self.seq):
            # Compare by value: computed expressions equal to ``_zero`` are
            # usually fresh tuples, so an identity test would miss them.
            if s != _zero:
                return -inf if i == 0 else self.bounds[i - 1]
        return inf

    def _end(self):
        """Return the right edge of the last non-zero segment (``-inf`` if all zero)."""
        N = len(self.bounds)
        for i, s in enumerate(reversed(self.seq)):
            if s != _zero:
                return inf if i == 0 else self.bounds[N - i - 1]
        return -inf

    @property
    def begin(self):
        """Start of the non-zero support, limited by ``start`` when set."""
        if self.start is None:
            return self._begin()
        else:
            return max(self.start, self._begin())

    @property
    def end(self):
        """End of the non-zero support, limited by ``stop`` when set."""
        if self.stop is None:
            return self._end()
        else:
            return min(self.stop, self._end())

    def sample(self,
               sample_rate=None,
               out=None,
               chunk_size=None,
               function_lib=None,
               filters=None):
        """Sample the waveform on ``[start, stop)``.

        Returns the sampled array, or a generator of chunks when
        ``chunk_size`` is given.  ``sample_rate`` and ``filters`` default to
        the instance attributes.  ``filters`` is an ``(sos, initial)`` pair
        applied with ``scipy.signal.sosfilt``; a truthy ``initial`` is
        subtracted before filtering and added back afterwards.

        Raises:
            ValueError: if ``start``, ``stop`` or the sample rate is unset.
        """
        if sample_rate is None:
            sample_rate = self.sample_rate
        if self.start is None or self.stop is None or sample_rate is None:
            raise ValueError(
                f'Waveform is not initialized. {self.start=}, {self.stop=}, {sample_rate=}'
            )
        if filters is None:
            filters = self.filters
        if chunk_size is None:
            x = np.arange(self.start, self.stop, 1 / sample_rate)
            sig = self.__call__(x, out=out, function_lib=function_lib)
            if filters is not None:
                sos, initial = filters
                if initial:
                    sig = sosfilt(sos, sig - initial) + initial
                else:
                    sig = sosfilt(sos, sig)
            return sig
        else:
            return self._sample_iter(sample_rate, chunk_size, out,
                                     function_lib, filters)

    def _sample_iter(self, sample_rate, chunk_size, out, function_lib,
                     filters):
        """Yield the sampled waveform in chunks of ``chunk_size`` samples.

        With ``filters``, the filter state ``zi`` is carried across chunks.
        With ``out``, each chunk is also written to its slice of ``out``.
        """
        start = self.start
        start_n = 0
        if filters is not None:
            sos, initial = filters
            zi = np.zeros((sos.shape[0], 2))
        length = chunk_size / sample_rate
        while start < self.stop:
            if start + length > self.stop:
                length = self.stop - start
                stop = self.stop
                size = round((stop - start) * sample_rate)
            else:
                stop = start + length
                size = chunk_size
            x = np.linspace(start, stop, size, endpoint=False)
            if filters is None:
                if out is None:
                    chunk = self.__call__(x, function_lib=function_lib)
                else:
                    # Pass exactly this chunk's slice so the yielded array
                    # has the chunk's length rather than the rest of ``out``.
                    chunk = self.__call__(x,
                                          out=out[start_n:start_n + size],
                                          function_lib=function_lib)
            else:
                # Evaluate first: the offset must be removed from this
                # chunk's samples (previously ``sig`` was used before it was
                # assigned).
                sig = self.__call__(x, function_lib=function_lib)
                if initial:
                    sig = sig - initial
                sig, zi = sosfilt(sos, sig, zi=zi)
                if initial:
                    sig += initial
                if out is not None:
                    out[start_n:start_n + size] = sig
                chunk = sig
            yield chunk
            start = stop
            start_n += size

    @staticmethod
    def _tolist(bounds, seq, ret=None):
        """Append the flat encoding of ``(bounds, seq)`` to ``ret`` and return it.

        Layout: ``nseg``, then per segment ``bound, nterms``, then per term
        ``amp, nfactors``, then per factor ``n, len(fun), *fun``.
        """
        if ret is None:
            ret = []
        ret.append(len(bounds))
        for expr, b in zip(seq, bounds):
            ret.append(b)
            tlist, amplist = expr
            ret.append(len(amplist))
            for t, amp in zip(tlist, amplist):
                ret.append(amp)
                mtlist, nlist = t
                ret.append(len(nlist))
                for fun, n in zip(mtlist, nlist):
                    ret.append(n)
                    ret.append(len(fun))
                    ret.extend(fun)
        return ret

    @staticmethod
    def _fromlist(l, pos=0):
        """Decode the layout written by ``_tolist`` starting at index ``pos``.

        Returns ``(bounds, seq, pos)`` with ``pos`` just past the decoded data.

        Raises:
            ValueError: if ``l`` is truncated or malformed.
        """

        def _read(l, pos, size):
            # Slicing silently returns fewer items past the end of ``l``, so
            # check the length explicitly to detect truncated input.
            try:
                chunk = tuple(l[pos:pos + size])
            except (TypeError, IndexError) as err:
                raise ValueError('Invalid waveform format') from err
            if len(chunk) != size:
                raise ValueError('Invalid waveform format')
            return chunk, pos + size

        (nseg, ), pos = _read(l, pos, 1)
        bounds = []
        seq = []
        for _ in range(nseg):
            (b, nsum), pos = _read(l, pos, 2)
            bounds.append(b)
            amp = []
            t = []
            for _ in range(nsum):
                (a, nmul), pos = _read(l, pos, 2)
                amp.append(a)
                nlst = []
                mt = []
                for _ in range(nmul):
                    (n, nfun), pos = _read(l, pos, 2)
                    nlst.append(n)
                    fun, pos = _read(l, pos, nfun)
                    mt.append(fun)
                t.append((tuple(mt), tuple(nlst)))
            seq.append((tuple(t), tuple(amp)))

        return tuple(bounds), tuple(seq), pos

    def tolist(self):
        """Serialise to a flat list.

        Header: ``max, min, start, stop, sample_rate``, then either ``None``
        or ``len(sos), *sos.ravel(), initial``.  The ``_tolist`` body follows.
        """
        l = [self.max, self.min, self.start, self.stop, self.sample_rate]
        if self.filters is None:
            l.append(None)
        else:
            sos, initial = self.filters
            sos = list(sos.reshape(-1))
            l.append(len(sos))
            l.extend(sos)
            l.append(initial)

        return self._tolist(self.bounds, self.seq, l)

    @classmethod
    def fromlist(cls, l):
        """Inverse of :meth:`tolist`."""
        w = cls()
        pos = 6
        (w.max, w.min, w.start, w.stop, w.sample_rate, sos_size) = l[:pos]
        if sos_size is not None:
            sos = np.array(l[pos:pos + sos_size]).reshape(-1, 6)
            pos += sos_size
            initial = l[pos]
            pos += 1
            w.filters = sos, initial

        w.bounds, w.seq, pos = cls._fromlist(l, pos)
        return w

    def totree(self):
        """Serialise to nested tuples ``(header, body)``.

        ``header`` is ``(max, min, start, stop, sample_rate, filters)`` and
        ``body`` is ``((bound, ((amp, ((n, fun), ...)), ...)), ...)``.
        """
        header = (self.max, self.min, self.start, self.stop,
                  self.sample_rate, self.filters)
        body = []

        for seq, b in zip(self.seq, self.bounds):
            tlist, amplist = seq
            new_seq = []
            for t, amp in zip(tlist, amplist):
                mtlist, nlist = t
                new_t = tuple((n, fun) for fun, n in zip(mtlist, nlist))
                new_seq.append((amp, new_t))
            body.append((b, tuple(new_seq)))
        return header, tuple(body)

    @staticmethod
    def fromtree(tree):
        """Inverse of :meth:`totree`."""
        w = Waveform()
        header, body = tree

        (w.max, w.min, w.start, w.stop, w.sample_rate, w.filters) = header
        bounds = []
        seqs = []
        for b, seq in body:
            bounds.append(b)
            amp_list = []
            t_list = []
            for amp, t in seq:
                amp_list.append(amp)
                n_list = []
                mt_list = []
                for n, mt in t:
                    n_list.append(n)
                    mt_list.append(mt)
                t_list.append((tuple(mt_list), tuple(n_list)))
            seqs.append((tuple(t_list), tuple(amp_list)))
        w.bounds = tuple(bounds)
        w.seq = tuple(seqs)
        return w

    def simplify(self, eps=1e-15):
        """Return a new waveform with every segment simplified.

        Adjacent segments with identical simplified expressions are merged.
        Metadata (min/max, start/stop, sample_rate, filters, label) is not
        copied.
        """
        seq = [_simplify(self.seq[0], eps)]
        bounds = [self.bounds[0]]
        for expr, b in zip(self.seq[1:], self.bounds[1:]):
            expr = _simplify(expr, eps)
            if expr == seq[-1]:
                seq.pop()
                bounds.pop()
            seq.append(expr)
            bounds.append(b)
        return Waveform(tuple(bounds), tuple(seq))

    def filter(self, low=0, high=inf, eps=1e-15):
        """Return a new waveform keeping only frequency components in ``[low, high)``."""
        seq = []
        for expr in self.seq:
            seq.append(_filter(expr, low, high, eps))
        return Waveform(self.bounds, tuple(seq))

    def _comb(self, other, oper):
        """Combine two waveforms segment by segment with ``oper(expr1, expr2)``.

        The result's bounds are the union of both bound lists.  Adjacent
        equal results are merged.
        """
        bounds = []
        seq = []
        i1, i2 = 0, 0
        h1, h2 = len(self.bounds), len(other.bounds)
        while i1 < h1 or i2 < h2:
            s = oper(self.seq[i1], other.seq[i2])
            b = min(self.bounds[i1], other.bounds[i2])
            if seq and s == seq[-1]:
                bounds[-1] = b
            else:
                bounds.append(b)
                seq.append(s)
            if b == self.bounds[i1]:
                i1 += 1
            if b == other.bounds[i2]:
                i2 += 1
        return Waveform(tuple(bounds), tuple(seq))

    def __pow__(self, n):
        return Waveform(self.bounds, tuple(_pow(w, n) for w in self.seq))

    def __add__(self, other):
        if isinstance(other, Waveform):
            return self._comb(other, _add)
        else:
            return self + const(other)

    def __radd__(self, v):
        return const(v) + self

    def __ior__(self, other):
        return self | other

    def __or__(self, other):
        """Logical OR of the non-zero supports, as a 0/1 marker waveform."""
        if isinstance(other, (int, float, complex)):
            other = const(other)

        def _or(a, b):
            if a != _zero or b != _zero:
                return _one
            else:
                return _zero

        # Combine the (simplified) markers.  The original computed them and
        # then discarded the result, combining the raw expressions instead.
        return self.marker._comb(other.marker, _or)

    def __iand__(self, other):
        return self & other

    def __and__(self, other):
        """Logical AND of the non-zero supports, as a 0/1 marker waveform."""
        if isinstance(other, (int, float, complex)):
            other = const(other)

        def _and(a, b):
            if a != _zero and b != _zero:
                return _one
            else:
                return _zero

        return self.marker._comb(other.marker, _and)

    @property
    def marker(self):
        """0/1 waveform that is one wherever the simplified waveform is non-zero."""
        w = self.simplify()
        return Waveform(w.bounds,
                        tuple(_zero if s == _zero else _one for s in w.seq))

    def mask(self, edge=0):
        """Return a 0/1 marker whose non-zero regions are widened by ``edge``.

        Each non-zero region of :attr:`marker` is extended by ``edge`` on
        both sides.  Zero gaps that close up are absorbed into the
        neighbouring non-zero regions.
        """
        w = self.marker
        bounds = []
        seq = []

        # The first segment starts at -inf, so only its right edge moves.  If
        # it is non-zero the mask starts "inside" a region; the original
        # always started outside and lost that region.
        if w.seq[0] == _zero:
            in_wave = False
            bounds.append(w.bounds[0] - edge)
            seq.append(_zero)
        else:
            in_wave = True
            bounds.append(w.bounds[0] + edge)
            seq.append(_one)

        for b, s in zip(w.bounds[1:], w.seq[1:]):
            if not in_wave and s != _zero:
                in_wave = True
                bounds.append(b + edge)
                seq.append(_one)
            elif in_wave and s == _zero:
                in_wave = False
                b = b - edge
                if b > bounds[-1]:
                    bounds.append(b)
                    seq.append(_zero)
                else:
                    # The gap vanished after widening; the next non-zero
                    # region continues the current one.
                    bounds[-1] = b
        return Waveform(tuple(bounds), tuple(seq))

    def __mul__(self, other):
        if isinstance(other, Waveform):
            return self._comb(other, _mul)
        else:
            return self * const(other)

    def __rmul__(self, v):
        return const(v) * self

    def __truediv__(self, other):
        if isinstance(other, Waveform):
            raise TypeError('division by waveform')
        else:
            return self * const(1 / other)

    def __neg__(self):
        return -1 * self

    def __sub__(self, other):
        return self + (-other)

    def __rsub__(self, v):
        return v + (-self)

    def __rshift__(self, time):
        """Delay by ``time``; bounds are rounded to ``NDIGITS`` digits."""
        return Waveform(
            tuple(round(bound + time, NDIGITS) for bound in self.bounds),
            tuple(_shift(expr, time) for expr in self.seq))

    def __lshift__(self, time):
        return self >> (-time)

    @staticmethod
    def _calc_parts(bounds, seq, x, function_lib, min=-inf, max=inf):
        """Evaluate each non-zero segment on its slice of the sorted samples ``x``.

        Returns ``(parts, dtype)``.  ``parts`` is a list of
        ``(start, stop, values)`` index ranges into ``x``, and ``dtype`` is
        ``complex`` if any part is complex, else ``float``.
        """
        range_list = np.searchsorted(x, bounds)
        parts = []
        start = 0
        dtype = float
        for i, stop in enumerate(range_list):
            if start < stop and seq[i] != _zero:
                part = np.clip(_calc(seq[i], x[start:stop], function_lib), min,
                               max)
                # iscomplexobj also detects complex64 arrays, which are not
                # instances of Python's ``complex``.
                if np.iscomplexobj(part):
                    dtype = complex
                parts.append((start, stop, part))
            start = stop
        return parts, dtype

    @staticmethod
    def _merge_parts(
        parts: list[tuple[int, int, np.ndarray | int | float | complex]],
        out: list[tuple[int, int, np.ndarray | int | float | complex]]
    ) -> list[tuple[int, int, np.ndarray | int | float | complex]]:
        # TODO: merge parts
        raise NotImplementedError

    @staticmethod
    def _fill_parts(parts, out):
        """Add each ``(start, stop, values)`` part into ``out[start:stop]``."""
        for start, stop, part in parts:
            out[start:stop] += part

    def __call__(self,
                 x,
                 frag=False,
                 out=None,
                 accumulate=False,
                 function_lib=None):
        """Evaluate the waveform at sorted sample times ``x``.

        A scalar ``x`` returns a scalar.  With ``frag=True``, the list of
        ``(start, stop, values)`` parts is returned (or written to the list
        ``out``) instead of a dense array.  ``accumulate`` adds to ``out``
        instead of overwriting it.
        """
        if function_lib is None:
            function_lib = _baseFunc
        if isinstance(x, (int, float, complex)):
            return self.__call__(np.array([x]), function_lib=function_lib)[0]
        parts, dtype = self._calc_parts(self.bounds, self.seq, x, function_lib,
                                        self.min, self.max)
        if not frag:
            if out is None:
                out = np.zeros_like(x, dtype=dtype)
            elif not accumulate:
                out *= 0
            self._fill_parts(parts, out)
        else:
            if out is None:
                return parts
            else:
                if not accumulate:
                    out.clear()
                    out.extend(parts)
                else:
                    self._merge_parts(parts, out)
        return out

    def __hash__(self):
        # Hash exactly what __eq__ compares (the simplified structure plus
        # this object's min/max/start/stop), so equal waveforms hash equally.
        # sample_rate is excluded because __eq__ ignores it.
        w = self.simplify()
        return hash((self.max, self.min, self.start, self.stop, w.bounds,
                     w.seq))

    def __eq__(self, o: object) -> bool:
        """Compare simplified structure and min/max/start/stop.

        Numbers are compared as constant waveforms.
        """
        if isinstance(o, (int, float, complex)):
            return self == const(o)
        elif isinstance(o, Waveform):
            a = self.simplify()
            b = o.simplify()
            # Metadata must come from the operands themselves: simplify()
            # returns fresh waveforms with default min/max/start/stop.
            return a.seq == b.seq and a.bounds == b.bounds and (
                self.max, self.min, self.start,
                self.stop) == (o.max, o.min, o.start, o.stop)
        else:
            return False

    def _repr_latex_(self):
        parts = []
        start = -np.inf
        for end, wav in zip(self.bounds, self.seq):
            e_str = _wav_latex(wav)
            start_str = _num_latex(start)
            end_str = _num_latex(end)
            parts.append(e_str + r",~~&t\in" + f"({start_str},{end_str}" +
                         (']' if end < np.inf else ')'))
            start = end
        if len(parts) == 1:
            expr = ''.join(['f(t)=', *parts[0].split('&')])
        else:
            expr = '\n'.join([
                r"f(t)=\begin{cases}", (r"\\" + '\n').join(parts),
                r"\end{cases}"
            ])
        return "$$\n{}\n$$".format(expr)

    def _play(self, time_unit, volume=1.0):
        """Stream the waveform to the default audio output via pyaudio (blocking)."""
        import pyaudio

        CHUNK = 1024
        RATE = 48000
        # Scale slightly below int16 full scale.  The same factor is used for
        # the initial and the re-normalised amplitude (they previously
        # differed: 0.999 vs 0.99).
        HEADROOM = 0.999

        dynamic_volume = 1.0
        amp = 2**15 * HEADROOM * volume * dynamic_volume

        p = pyaudio.PyAudio()
        try:
            stream = p.open(format=pyaudio.paInt16,
                            channels=1,
                            rate=RATE,
                            output=True)
            try:
                for data in self.sample(sample_rate=RATE / time_unit,
                                        chunk_size=CHUNK):
                    lim = np.abs(data).max()
                    if lim > 0 and dynamic_volume > 1.0 / lim:
                        dynamic_volume = 1.0 / lim
                        amp = 2**15 * HEADROOM * volume * dynamic_volume
                    data = (amp * data).astype(np.int16)
                    stream.write(bytes(data.data))
            finally:
                stream.stop_stream()
                stream.close()
        finally:
            p.terminate()

    def play(self, time_unit=1, volume=1.0):
        """Play the waveform in a background daemon process."""
        import multiprocessing as mp
        p = mp.Process(target=self._play,
                       args=(time_unit, volume),
                       daemon=True)
        p.start()

944

945
class WaveVStack(Waveform):
    """Lazily evaluated sum of piecewise waveforms.

    ``wlist`` holds ``(bounds, seq)`` pairs, one per component waveform.
    Calling the stack evaluates ``offset + sum_i w_i(x - shift)`` and
    returns the real part. ``simplify`` merges the components into a
    single ``Waveform``.
    """

    def __init__(self, wlist: 'list[Waveform] | None' = None):
        # ``None`` instead of a shared mutable ``[]`` default.
        if wlist is None:
            wlist = []
        self.wlist = [(w.bounds, w.seq) for w in wlist]
        self.start = None
        self.stop = None
        self.sample_rate = None
        self.offset = 0
        self.shift = 0
        self.filters = None
        self.label = None
        self.function_lib = None

    def __call__(self, x, frag=False, out=None, function_lib=None):
        """Evaluate the stack at the points ``x`` and return the real part.

        ``frag`` mode is not supported. ``out`` is accepted for signature
        compatibility but is not written to; a new array is always
        allocated. ``function_lib`` defaults to ``self.function_lib``,
        then to the global base-function registry.
        """
        assert frag is False, 'WaveVStack does not support frag mode'
        out = np.full_like(x, self.offset, dtype=complex)
        if self.shift != 0:
            x = x - self.shift
        if function_lib is None:
            function_lib = (_baseFunc if self.function_lib is None else
                            self.function_lib)
        for bounds, seq in self.wlist:
            parts, _ = self._calc_parts(bounds, seq, x, function_lib)
            self._fill_parts(parts, out)
        return out.real

    def tolist(self):
        """Flatten the stack into a list (the inverse of ``fromlist``).

        Layout:
        - ``[start, stop, offset, shift, sample_rate]``;
        - then either ``None`` (no filters) or ``len(sos)`` followed by the
          flattened ``sos`` values and ``initial``;
        - then the number of components, each appended by ``_tolist``.

        ``label`` and ``function_lib`` are not stored.
        """
        l = [
            self.start,
            self.stop,
            self.offset,
            self.shift,
            self.sample_rate,
        ]
        if self.filters is None:
            l.append(None)
        else:
            sos, initial = self.filters
            sos = list(sos.reshape(-1))
            l.append(len(sos))
            l.extend(sos)
            l.append(initial)
        l.append(len(self.wlist))
        for bounds, seq in self.wlist:
            self._tolist(bounds, seq, l)
        return l

    @classmethod
    def fromlist(cls, l):
        """Rebuild a stack from the list produced by ``tolist``."""
        w = cls()
        pos = 6
        w.start, w.stop, w.offset, w.shift, w.sample_rate, sos_size = l[:pos]
        if sos_size is not None:
            # sos coefficients come in rows of 6
            sos = np.array(l[pos:pos + sos_size]).reshape(-1, 6)
            pos += sos_size
            initial = l[pos]
            pos += 1
            w.filters = sos, initial
        n = l[pos]
        pos += 1
        for _ in range(n):
            bounds, seq, pos = cls._fromlist(l, pos)
            w.wlist.append((bounds, seq))
        return w

    def simplify(self, eps=1e-15):
        """Collapse the stack into one simplified ``Waveform``.

        The offset and shift are applied, then ``simplify(eps)`` is called
        on the sum. ``start``, ``stop`` and ``sample_rate`` are copied over.
        """
        wav = wave_sum(self.wlist)
        if self.offset != 0:
            wav += self.offset
        if self.shift != 0:
            wav >>= self.shift
        wav = wav.simplify(eps)
        wav.start = self.start
        wav.stop = self.stop
        wav.sample_rate = self.sample_rate
        return wav

    @staticmethod
    def _rshift(wlist, time):
        """Return ``wlist`` with every component delayed by ``time``.

        Bounds are rounded to ``NDIGITS``. ``wlist`` itself is returned
        when ``time == 0``.
        """
        if time == 0:
            return wlist
        return [(tuple(round(bound + time, NDIGITS) for bound in bounds),
                 tuple(_shift(expr, time) for expr in seq))
                for bounds, seq in wlist]

    def __rshift__(self, time):
        """Return a copy of the stack delayed by ``time``.

        The component list is copied rather than shared, so later in-place
        changes to either stack do not leak into the other. Filters, label
        and the function library are carried over.
        """
        ret = WaveVStack()
        ret.wlist = list(self.wlist)
        ret.sample_rate = self.sample_rate
        ret.start = self.start
        ret.stop = self.stop
        ret.shift = self.shift + time
        ret.offset = self.offset
        ret.filters = self.filters
        ret.label = self.label
        ret.function_lib = self.function_lib
        return ret

    def __add__(self, other):
        """Return ``self + other`` as a new stack.

        The result keeps ``self.shift`` and ``self.offset``; both were
        previously reset to 0, which silently changed the waveform. Two
        stacks with different shifts are combined by baking each shift
        into its own components.
        """
        ret = WaveVStack()
        ret.wlist.extend(self.wlist)
        ret.shift = self.shift
        ret.offset = self.offset
        if isinstance(other, WaveVStack):
            if other.shift != self.shift:
                # Apply both shifts to the components; result is unshifted.
                ret.wlist = self._rshift(ret.wlist, self.shift)
                ret.wlist.extend(self._rshift(other.wlist, other.shift))
                ret.shift = 0
            else:
                ret.wlist.extend(other.wlist)
            ret.offset = self.offset + other.offset
        elif isinstance(other, Waveform):
            # Pre-compensate for the shift subtracted at evaluation time.
            other = other << self.shift
            ret.wlist.append((other.bounds, other.seq))
        else:
            ret.offset = self.offset + other
        return ret

    def __radd__(self, v):
        return self + v

    def __mul__(self, other):
        """Return ``self * other`` as a new stack with the same shift."""
        if isinstance(other, Waveform):
            # Express ``other`` in the stack's shifted time frame.
            other = other.simplify() << self.shift
            ret = WaveVStack([Waveform(*w) * other for w in self.wlist])
            if self.offset != 0:
                w = other * self.offset
                ret.wlist.append((w.bounds, w.seq))
        else:
            ret = WaveVStack([Waveform(*w) * other for w in self.wlist])
            ret.offset = self.offset * other
        # Components stay unshifted, so the product must keep self.shift
        # (it was previously dropped).
        ret.shift = self.shift
        return ret

    def __rmul__(self, v):
        return self * v

    def __eq__(self, other):
        """Compare with ``other``.

        A stack with components is conservatively reported as unequal to
        everything. An empty stack is the constant ``offset``.
        """
        if self.wlist:
            return False
        if self.offset == 0:
            return zero() == other
        return const(self.offset) == other

    def _repr_latex_(self):
        return r"\sum_{i=1}^{" + f"{len(self.wlist)}" + r"}" + r"f_i(t)"

    def __getstate__(self) -> tuple:
        """Pickle support.

        A custom ``function_lib`` is serialized with ``dill`` when
        possible. If that fails it is dropped and stored as ``None``.
        """
        function_lib = self.function_lib
        if function_lib:
            try:
                import dill
                function_lib = dill.dumps(function_lib)
            except Exception:
                # Best effort (dill missing or object not serializable).
                # Unlike a bare ``except``, this lets KeyboardInterrupt and
                # SystemExit propagate.
                function_lib = None
        return (self.wlist, self.start, self.stop, self.sample_rate,
                self.offset, self.shift, self.filters, self.label,
                function_lib)

    def __setstate__(self, state: tuple) -> None:
        """Restore from ``__getstate__``.

        Undecodable function libraries become ``None``. NOTE: dill.loads
        executes arbitrary code, so only unpickle trusted data.
        """
        (self.wlist, self.start, self.stop, self.sample_rate, self.offset,
         self.shift, self.filters, self.label, function_lib) = state
        if function_lib:
            try:
                import dill
                function_lib = dill.loads(function_lib)
            except Exception:
                function_lib = None
        self.function_lib = function_lib
1110

1111

1112
def wave_sum(waves):
    """Sum a list of piecewise waveforms into a single ``Waveform``.

    Parameters
    ----------
    waves : sequence of (bounds, seq)
        ``seq[k]`` is the expression of the piece ending at ``bounds[k]``;
        piece 0 extends to -inf. Each ``bounds`` is assumed to be sorted
        and to end with ``+inf`` (TODO confirm against ``Waveform``).

    Returns
    -------
    Waveform
        ``zero()`` for an empty input. Otherwise the merged waveform, with
        adjacent pieces that have identical expressions coalesced.
    """
    if not waves:
        return zero()

    bounds, seq = waves[0]
    if not waves[1:]:
        return Waveform(bounds, seq)
    bounds, seq = list(bounds), list(seq)

    for bounds_, seq_ in waves[1:]:
        if len(bounds_) == 1:
            # The addend is one expression over the whole real line.
            for i, s in enumerate(seq):
                seq[i] = _add(s, seq_[0])
        elif len(bounds) == 1:
            # The accumulator is one expression: adopt the addend's bounds.
            bounds = list(bounds_)
            seq = [_add(seq[0], s) for s in seq_]
        else:
            # ``lo`` is the index of the first accumulator piece not yet
            # covered by an earlier addend piece. It starts at 0 so piece 0
            # also receives the addend's first expression (previously
            # skipped).
            lo = 0
            for b, s in zip(bounds_, seq_):
                i = bisect_left(bounds, b, lo=lo)
                if bounds[i] > b:
                    # Split accumulator piece i at b. The new left part
                    # keeps the old expression and also receives s. This
                    # holds for i == 0 too; the old special case dropped
                    # seq[0].
                    bounds.insert(i, b)
                    seq.insert(i, _add(s, seq[i]))
                    up = i - 1
                else:
                    up = i
                # Pieces lo..up lie entirely inside the addend's piece.
                for j in range(lo, up + 1):
                    seq[j] = _add(seq[j], s)
                lo = i + 1

    # Coalesce equal neighbours. Piece i ends at bounds[i], so merging
    # pieces i and i+1 removes bounds[i]. Removing bounds[i+1] would drop
    # the outer bound (e.g. the final +inf).
    i = 0
    while i < len(bounds) - 1:
        if seq[i] == seq[i + 1]:
            del seq[i + 1]
            del bounds[i]
        else:
            i += 1

    return Waveform(tuple(bounds), tuple(seq))
1154

1155

1156
def play(data, rate=48000):
    """Play a mono signal through the default audio output (needs pyaudio).

    Parameters
    ----------
    data : array_like
        Samples, nominally in [-1, 1]. If the peak magnitude exceeds 1,
        the signal is scaled down to fit. The caller's array is never
        modified; it was previously divided in place, which also failed
        for integer arrays. Empty input plays nothing.
    rate : int
        Sample rate in Hz of the output stream.
    """
    import io

    import pyaudio

    CHUNK = 1024

    samples = np.asarray(data)
    if samples.size == 0:
        return
    max_amp = np.max(np.abs(samples))
    if max_amp > 1:
        samples = samples / max_amp

    pcm = np.array(2**15 * 0.999 * samples, dtype=np.int16)
    buff = io.BytesIO(pcm.tobytes())
    p = pyaudio.PyAudio()

    try:
        stream = p.open(format=pyaudio.paInt16,
                        channels=1,
                        rate=rate,
                        output=True)
        try:
            while chunk := buff.read(CHUNK):
                stream.write(chunk)
        finally:
            stream.stop_stream()
            stream.close()
    finally:
        p.terminate()
1189

1190

1191
# Shared singleton waveforms handed out by zero() and one().
_zero_waveform, _one_waveform = Waveform(), Waveform(seq=(_one, ))
1193

1194

1195
def zero():
    """Return the shared module-level zero waveform (not a fresh copy)."""
    singleton = _zero_waveform
    return singleton
1197

1198

1199
def one():
    """Return the shared module-level constant-one waveform (not a copy)."""
    singleton = _one_waveform
    return singleton
1201

1202

1203
def const(c):
    """Return a waveform equal to the constant ``c`` everywhere."""
    value = 1.0 * c  # promote ints to float; complex stays complex
    return Waveform(seq=(_const(value), ))
1205

1206

1207
# Next integer type id to be handed out by registerBaseFunc.
__TypeIndex = 1
# Registries keyed by type id:
#   _baseFunc           -> callable evaluated as f(t, *args)
#   _derivativeBaseFunc -> derivative builder d(shift, *args)
#   _baseFunc_latex     -> LaTeX formatter (or None)
_baseFunc, _derivativeBaseFunc, _baseFunc_latex = {}, {}, {}
1211

1212

1213
def registerBaseFunc(func, latex=None):
    """Register ``func`` as a new base function and return its type id.

    Ids are consecutive integers handed out in registration order.
    ``latex`` is stored alongside (or ``None``); presumably it should
    accept ``(shift, *args)`` like the ``_format_*`` helpers.
    """
    global __TypeIndex
    new_id, __TypeIndex = __TypeIndex, __TypeIndex + 1
    _baseFunc[new_id] = func
    _baseFunc_latex[new_id] = latex
    return new_id
1222

1223

1224
def packBaseFunc():
    """Serialize the base-function registry to bytes.

    The registry contains lambdas, which the standard ``pickle`` module
    cannot serialize (it pickles functions by qualified name). ``dill`` is
    therefore used when installed. Without it this falls back to
    ``pickle`` and raises ``pickle.PicklingError`` for lambda entries.
    """
    try:
        import dill
    except ImportError:
        return pickle.dumps(_baseFunc)
    return dill.dumps(_baseFunc)
1226

1227

1228
def updateBaseFunc(buf):
    """Merge a registry serialized by ``packBaseFunc`` into ``_baseFunc``.

    WARNING: unpickling can execute arbitrary code; only pass trusted data.
    """
    registry = pickle.loads(buf)
    _baseFunc.update(registry)
1230

1231

1232
def registerDerivative(Type, dFunc):
    """Register ``dFunc(shift, *args)`` as the derivative builder of ``Type``.

    The builder returns the derivative as a ``(terms, coefficients)``
    expression.
    """
    _derivativeBaseFunc.update({Type: dFunc})
1234

1235

1236
# register base function
1237
def _format_LINEAR(shift, *args):
9✔
1238
    if shift != 0:
×
1239
        shift = _num_latex(-shift)
×
1240
        if shift[0] == '-':
×
1241
            return f"(t{shift})"
×
1242
        else:
1243
            return f"(t+{shift})"
×
1244
    else:
1245
        return 't'
×
1246

1247

1248
def _format_GAUSSIAN(shift, *args):
    """LaTeX for the GAUSSIAN base function ``exp(-((t - shift)/std_sq2)**2)``.

    The width is displayed as ``sigma = std_sq2 / sqrt(2)``, i.e. in the
    form ``exp(-(t - shift)**2 / (2 sigma**2))``.
    """
    sigma = _num_latex(args[0] / np.sqrt(2))
    offset = _num_latex(-shift)
    if offset == '0':
        if sigma == '1':
            return ('\\exp\\left(-\\frac{t^2}{2}\\right)')
        return ('\\exp\\left[-\\frac{1}{2}\\left(\\frac{t}{' + sigma +
                '}\\right)^2\\right]')
    if offset[0] != '-':
        offset = '+' + offset
    if sigma == '1':
        return ('\\exp\\left[-\\frac{\\left(t' + offset +
                '\\right)^2}{2}\\right]')
    return ('\\exp\\left[-\\frac{1}{2}\\left(\\frac{t' + offset + '}{' +
            sigma + '}\\right)^2\\right]')
1266

1267

1268
def _format_SINC(shift, *args):
    """LaTeX for the SINC base function ``sinc(bw * (t - shift))``."""
    offset = _num_latex(-shift)
    bw = _num_latex(args[0])
    if offset == '0':
        return ('\\mathrm{sinc}(t)' if bw == '1' else
                '\\mathrm{sinc}(' + bw + 't)')
    if offset[0] != '-':
        offset = '+' + offset
    if bw == '1':
        return '\\mathrm{sinc}(t' + offset + ')'
    return '\\mathrm{sinc}[' + bw + '(t' + offset + ')]'
1283

1284

1285
def _format_COSINE(shift, *args):
    """LaTeX for ``cos(w (t - shift))``, shown as ``cos(2 pi (f t + phase))``."""
    f = args[0] / 2 / np.pi
    freq_tex = _num_latex(f)
    phase_tex = _num_latex(-shift * f)
    if freq_tex == '1':
        freq_tex = ''
    if phase_tex == '0':
        phase_tex = ''
    elif phase_tex[0] != '-':
        phase_tex = '+' + phase_tex

    if phase_tex:
        return f'\\cos\\left[2\\pi\\left({freq_tex}t{phase_tex}\\right)\\right]'
    if freq_tex:
        return f'\\cos\\left(2\\pi\\times {freq_tex}t\\right)'
    return '\\cos\\left(2\\pi t\\right)'
1302

1303

1304
def _format_ERF(shift, *args):
9✔
1305
    if shift > 0:
×
1306
        return '\\mathrm{erf}(\\frac{t-' + f"{_num_latex(shift)}" + '}{' + f'{args[0]:g}' + '})'
×
1307
    elif shift < 0:
×
1308
        return '\\mathrm{erf}(\\frac{t+' + f"{_num_latex(-shift)}" + '}{' + f'{args[0]:g}' + '})'
×
1309
    else:
1310
        return '\\mathrm{erf}(\\frac{t}{' + f'{args[0]:g}' + '})'
×
1311

1312

1313
def _format_COSH(shift, *args):
9✔
1314
    if shift > 0:
×
1315
        return '\\cosh(\\frac{t-' + f"{_num_latex(shift)}" + '}{' + f'{1/args[0]:g}' + '})'
×
1316
    elif shift < 0:
×
1317
        return '\\cosh(\\frac{t+' + f"{_num_latex(-shift)}" + '}{' + f'{1/args[0]:g}' + '})'
×
1318
    else:
1319
        return '\\cosh(\\frac{t}{' + f'{1/args[0]:g}' + '})'
×
1320

1321

1322
def _format_SINH(shift, *args):
9✔
1323
    if shift > 0:
×
1324
        return '\\sinh(\\frac{t-' + f"{_num_latex(shift)}" + '}{' + f'{args[0]:g}' + '})'
×
1325
    elif shift < 0:
×
1326
        return '\\sinh(\\frac{t+' + f"{_num_latex(-shift)}" + '}{' + f'{args[0]:g}' + '})'
×
1327
    else:
1328
        return '\\sinh(\\frac{t}{' + f'{args[0]:g}' + '})'
×
1329

1330

1331
def _format_EXP(shift, *args):
9✔
1332
    if _num_latex(shift) and shift > 0:
×
1333
        return '\\exp\\left(-' + f'{args[0]:g}' + '\\left(t-' + f"{_num_latex(shift)}" + '\\right)\\right)'
×
1334
    elif _num_latex(-shift) and shift < 0:
×
1335
        return '\\exp\\left(-' + f'{args[0]:g}' + '\\left(t+' + f"{_num_latex(-shift)}" + '\\right)\\right)'
×
1336
    else:
1337
        return '\\exp\\left(-' + f'{args[0]:g}' + 't\\right)'
×
1338

1339

1340
# Register the built-in base functions. registerBaseFunc hands out
# consecutive integer ids, so the ORDER of these statements determines each
# constant's numeric value. NOTE(review): ids presumably appear in
# serialized waveforms (tolist / pickle); confirm before reordering or
# inserting entries.
# Each callable is evaluated as f(t, *args).
LINEAR = registerBaseFunc(lambda t: t, _format_LINEAR)
# exp(-(t / std_sq2)**2)
GAUSSIAN = registerBaseFunc(lambda t, std_sq2: np.exp(-(t / std_sq2)**2),
                            _format_GAUSSIAN)
# erf(t / std_sq2)
ERF = registerBaseFunc(lambda t, std_sq2: special.erf(t / std_sq2),
                       _format_ERF)
# cos(w t), w in rad per time unit
COS = registerBaseFunc(lambda t, w: np.cos(w * t), _format_COSINE)
# np.sinc(bw t) = sin(pi bw t) / (pi bw t)
SINC = registerBaseFunc(lambda t, bw: np.sinc(bw * t), _format_SINC)
# exp(alpha t)
EXP = registerBaseFunc(lambda t, alpha: np.exp(alpha * t), _format_EXP)
# Linear interpolation of `points` spaced evenly on [start, stop]
# (no LaTeX formatter registered).
INTERP = registerBaseFunc(lambda t, start, stop, points: np.interp(
    t, np.linspace(start, stop, len(points)), points))
# Chirps, frequencies in cycles per time unit (no LaTeX formatters):
# linear: instantaneous frequency f0 + (f1 - f0) t / T
LINEARCHIRP = registerBaseFunc(lambda t, f0, f1, T, phi0: np.sin(
    phi0 + 2 * np.pi * ((f1 - f0) / (2 * T) * t**2 + f0 * t)))
# exponential: instantaneous frequency f0 * exp(alpha t)
EXPONENTIALCHIRP = registerBaseFunc(lambda t, f0, alpha, phi0: np.sin(
    phi0 + 2 * pi * f0 * (np.exp(alpha * t) - 1) / alpha))
# hyperbolic: instantaneous frequency f0 / (1 + k t)
HYPERBOLICCHIRP = registerBaseFunc(lambda t, f0, k, phi0: np.sin(
    phi0 + 2 * np.pi * f0 / k * np.log(1 + k * t)))
COSH = registerBaseFunc(lambda t, w: np.cosh(w * t), _format_COSH)
SINH = registerBaseFunc(lambda t, w: np.sinh(w * t), _format_SINH)
1358

1359

1360
def _drag(t: np.ndarray, t0: float, freq: float, width: float, delta: float,
9✔
1361
          block_freq: float, phase: float):
1362

1363
    o = np.pi / width
×
1364
    Omega_x = np.sin(o * (t - t0))**2
×
1365
    wt = 2 * np.pi * (freq + delta) * t - (2 * np.pi * delta * t0 + phase)
×
1366

1367
    if block_freq is None or block_freq - delta == 0:
×
1368
        return Omega_x * np.cos(wt)
×
1369

1370
    b = 1 / np.pi / 2 / (block_freq - delta)
×
1371
    Omega_y = -b * o * np.sin(2 * o * (t - t0))
×
1372

1373
    return Omega_x * np.cos(wt) + Omega_y * np.sin(wt)
×
1374

1375

1376
def _format_DRAG(shift, *args):
9✔
1377
    return f"DRAG(...)"
×
1378

1379

1380
# Register the DRAG pulse (evaluated by _drag) after the built-ins above;
# its type id depends on this statement's position in the file.
DRAG = registerBaseFunc(_drag, _format_DRAG)
1381

1382
# register derivative
1383
# d/dt (t - shift) = 1
registerDerivative(LINEAR, lambda shift, *args: _one)

# d/dt exp(-(t/s)^2) = -2/s^2 * t * exp(-(t/s)^2)
registerDerivative(
    GAUSSIAN, lambda shift, *args: (((((LINEAR, shift),
                                       (GAUSSIAN, *args, shift)), (1, 1)), ),
                                    (-2 / args[0]**2, )))

# d/dt erf(t/s) = 2 / (s sqrt(pi)) * exp(-(t/s)^2)
registerDerivative(
    ERF, lambda shift, *args: (((((GAUSSIAN, *args, shift), ), (1, )), ),
                               (2 / args[0] / np.sqrt(pi), )))

# d/dt cos(w t) = w cos(w t + pi/2), i.e. a shift of -pi / (2 w)
registerDerivative(
    COS, lambda shift, *args: (((((COS, args[0], shift - pi / args[0] / 2), ),
                                 (1, )), ), (args[0], )))


def _d_SINC(shift, bw):
    """Derivative of the SINC base function ``np.sinc(bw t')``, t' = t - shift.

    With ``np.sinc(x) = sin(pi x) / (pi x)`` and ``w = pi * bw``::

        d/dt sinc(bw t') = cos(w t') / t' - sin(w t') / (w t'^2)

    ``sin(w t')`` is written as ``cos(w (t' - pi / (2 w)))``. The previous
    lambda read a nonexistent ``args[1]`` and omitted the factor pi. Both
    terms are singular at ``t' = 0``, where they cancel analytically.
    """
    w = pi * bw
    cos_term = (((LINEAR, shift), (COS, w, shift)), (-1, 1))
    sin_term = (((LINEAR, shift), (COS, w, shift + pi / w / 2)), (-2, 1))
    return (cos_term, sin_term), (1, -1 / w)


registerDerivative(SINC, _d_SINC)

# d/dt exp(alpha t) = alpha exp(alpha t)
registerDerivative(
    EXP, lambda shift, *args: (((((EXP, *args, shift), ), (1, )), ),
                               (args[0], )))

# Interpolated points: differentiate per sample index, then rescale by the
# sample spacing (stop - start) / (len(points) - 1).
registerDerivative(
    INTERP, lambda shift, start, stop, points:
    (((((INTERP, start, stop, tuple(np.gradient(np.asarray(points))), shift),
        ), (1, )), ), ((len(points) - 1) / (stop - start), )))

# d/dt cosh(w t) = w sinh(w t)
registerDerivative(
    COSH, lambda shift, *args: (((((SINH, *args, shift), ), (1, )), ),
                                (args[0], )))

# d/dt sinh(w t) = w cosh(w t)
registerDerivative(
    SINH, lambda shift, *args: (((((COSH, *args, shift), ), (1, )), ),
                                (args[0], )))
1420

1421

1422
def _d_LINEARCHIRP(shift, f0, f1, T, phi0):
    """Derivative of the LINEARCHIRP base function.

    For ``f = sin(phi0 + 2 pi ((f1 - f0) / (2 T) t'^2 + f0 t'))``::

        f' = [2 pi f0 + 2 pi (f1 - f0) / T * t'] * cos(...)

    ``cos(...)`` is the same chirp with phase ``phi0 + pi/2``. The
    constant-frequency term is omitted when ``f0 == 0``.
    """
    quadrature = (LINEARCHIRP, f0, f1, T, phi0 + pi / 2, shift)
    const_term = ((quadrature, ), (1, ))
    ramp_term = (((LINEAR, shift), quadrature), (1, 1))
    const_amp = 2 * pi * f0
    ramp_amp = 2 * pi * (f1 - f0) / T
    if f0 == 0:
        return (ramp_term, ), (ramp_amp, )
    return (const_term, ramp_term), (const_amp, ramp_amp)
1434

1435

1436
registerDerivative(LINEARCHIRP, _d_LINEARCHIRP)

# d/dt sin(phi0 + 2 pi f0 (e^{alpha t} - 1) / alpha)
#     = 2 pi f0 e^{alpha t} cos(...), cos written as sin(... + pi/2)
registerDerivative(
    EXPONENTIALCHIRP, lambda shift, f0, alpha, phi0:
    (((((EXP, alpha, shift),
        (EXPONENTIALCHIRP, f0, alpha, phi0 + pi / 2, shift)), (1, 1)), ),
     (2 * pi * f0, )))

# d/dt sin(phi0 + 2 pi f0 / k ln(1 + k t))
#     = 2 pi f0 / (1 + k t) cos(...)
#     = (2 pi f0 / k) (t + 1/k)^-1 cos(...)
# The factor 1/k was previously missing from the coefficient.
registerDerivative(
    HYPERBOLICCHIRP, lambda shift, f0, k, phi0:
    (((((LINEAR, shift - 1 / k),
        (HYPERBOLICCHIRP, f0, k, phi0 + pi / 2, shift)), (-1, 1)), ),
     (2 * pi * f0 / k, )))
1447

1448

1449
def _D_base(m):
    """Differentiate one base-function factor ``(Type, *args, shift)``."""
    Type, *args, shift = m
    derivative = _derivativeBaseFunc[Type]
    return derivative(shift, *args)
1452

1453

1454
def _D(x):
    """Symbolic time derivative of an expression ``(terms, coefficients)``.

    Sums are differentiated term by term and products with the product
    rule. A power ``m**n`` with ``n != 1`` uses ``n m**(n-1) m'``. Single
    base functions use their registered derivative builders.
    """
    if _is_const(x):
        return _zero
    t_list, v_list = x
    if len(v_list) != 1:
        # Sum rule: first term + the rest.
        head = (t_list[:1], v_list[:1])
        tail = (t_list[1:], v_list[1:])
        return _add(_D(head), _D(tail))

    (m_list, n_list), v = t_list[0], v_list[0]
    if len(m_list) != 1:
        # Product rule on (first factor, carrying v) * (remaining factors).
        a = (((m_list[:1], n_list[:1]), ), (v, ))
        b = (((m_list[1:], n_list[1:]), ), (1, ))
        return _add(_mul(a, _D(b)), _mul(_D(a), b))

    m, n = m_list[0], n_list[0]
    if n == 1:
        return _mul(_D_base(m), _const(v))
    # Power rule: d(v m^n) = n v m^(n-1) * m'
    return _mul(((((m, ), (n - 1, )), ), (n * v, )),
                _D(((((m, ), (1, )), ), (1, ))))
1473

1474

1475
def D(wav):
    """derivative

    Return the time derivative of ``wav`` piece by piece, keeping its bounds.
    """
    derivs = tuple(map(_D, wav.seq))
    return Waveform(bounds=wav.bounds, seq=derivs)
1479

1480

1481
def convolve(a, b):
    """Convolution of two waveforms (not implemented).

    Raises
    ------
    NotImplementedError
        Always. The previous stub silently returned ``None``.
    """
    raise NotImplementedError('convolve is not implemented yet')
1483

1484

1485
def sign():
    """Sign waveform: -1 before t = 0 and +1 after it."""
    negative = _const(-1)
    return Waveform(bounds=(0, +inf), seq=(negative, _one))
1487

1488

1489
def step(edge, type='erf'):
    """
    type: "erf", "cos", "linear"

    Smooth step from 0 to 1 centred on t = 0.

    ``edge == 0`` gives a hard step. For "cos" and "linear" the rise spans
    ``[-edge/2, edge/2]``. Any other ``type`` means "erf", whose rise spans
    ``[-edge, edge]`` with ``std_sq2 = edge / 5``.
    """
    if edge == 0:
        return Waveform(bounds=(0, +inf), seq=(_zero, _one))

    if type == 'cos':
        rise = _add(_half,
                    _mul(_half, _basic_wave(COS, pi / edge, shift=0.5 * edge)))
        lo, hi = round(-edge / 2, NDIGITS), round(edge / 2, NDIGITS)
    elif type == 'linear':
        rise = _add(_half, _mul(_const(1 / edge), _basic_wave(LINEAR)))
        lo, hi = round(-edge / 2, NDIGITS), round(edge / 2, NDIGITS)
    else:
        std_sq2 = edge / 5
        # 0.5 + 0.5 * erf(t / std_sq2)
        rise = ((((), ()), (((ERF, std_sq2, 0), ), (1, ))), (0.5, 0.5))
        lo, hi = -round(edge, NDIGITS), round(edge, NDIGITS)
    return Waveform(bounds=(lo, hi, +inf), seq=(_zero, rise, _one))
1515

1516

1517
def square(width, edge=0, type='erf'):
    """Rectangular pulse of ``width`` centred on t = 0 (zero if width <= 0).

    With ``edge != 0`` the pulse is the difference of two smooth steps (see
    ``step``) placed at ``-width/2`` and ``+width/2``.
    """
    if width <= 0:
        return zero()
    if edge != 0:
        rising = step(edge, type=type) << width / 2
        falling = step(edge, type=type) >> width / 2
        return rising - falling
    half = 0.5 * width
    return Waveform(bounds=(round(-half, NDIGITS), round(half, NDIGITS), +inf),
                    seq=(_zero, _one, _zero))
1528

1529

1530
def gaussian(width, plateau=0.0):
    """Gaussian pulse centred on t = 0, optionally with a flat top.

    ``width`` is twice the FWHM. Each Gaussian edge is truncated at
    ``0.75 * width`` from its centre. With a positive ``plateau`` the two
    halves are separated by a section of constant 1.
    """
    if width <= 0 and plateau <= 0.0:
        return zero()
    # width = 2 * FWHM  =>  std_sq2 = width / (4 * sqrt(ln 2))
    std_sq2 = width / 3.3302184446307908
    # (alternative: std_sq2 = width / sqrt(pi) gives the area of a square)
    if round(0.5 * plateau, NDIGITS) > 0.0:
        bounds = (round(-0.75 * width - 0.5 * plateau, NDIGITS),
                  round(-0.5 * plateau, NDIGITS),
                  round(0.5 * plateau, NDIGITS),
                  round(0.75 * width + 0.5 * plateau, NDIGITS), +inf)
        rising = _basic_wave(GAUSSIAN, std_sq2, shift=-0.5 * plateau)
        falling = _basic_wave(GAUSSIAN, std_sq2, shift=0.5 * plateau)
        return Waveform(bounds=bounds,
                        seq=(_zero, rising, _one, falling, _zero))
    edge = 0.75 * width
    return Waveform(bounds=(round(-edge, NDIGITS), round(edge, NDIGITS), +inf),
                    seq=(_zero, _basic_wave(GAUSSIAN, std_sq2), _zero))
1556

1557

1558
def cos(w, phi=0):
    """Return ``cos(w t + phi)`` (``w`` in rad per time unit)."""
    if w == 0:
        return const(np.cos(phi))
    if w < 0:
        # cos(-|w| t + phi) == cos(|w| t - phi)
        w, phi = -w, -phi
    return Waveform(seq=(_basic_wave(COS, w, shift=-phi / w), ))
1565

1566

1567
def sin(w, phi=0):
    """Return ``sin(w t + phi)``, stored as a shifted cosine."""
    if w == 0:
        return const(np.sin(phi))
    if w < 0:
        # sin(-|w| t + phi) == sin(|w| t - phi + pi)
        w, phi = -w, -phi + pi
    # sin(x) == cos(x - pi/2)
    return Waveform(seq=(_basic_wave(COS, w, shift=(pi / 2 - phi) / w), ))
1574

1575

1576
def exp(alpha):
    """Return ``exp(alpha t)``.

    A complex ``alpha`` becomes ``exp(Re t) * (cos(Im t) + 1j sin(Im t))``.
    """
    if not isinstance(alpha, complex):
        return Waveform(seq=(_basic_wave(EXP, alpha), ))
    oscillation = cos(alpha.imag) + 1j * sin(alpha.imag)
    if alpha.real == 0:
        return oscillation
    return exp(alpha.real) * oscillation
1584

1585

1586
def sinc(bw):
    """``np.sinc(bw t)`` truncated to ``[-50/bw, 50/bw]``; zero if ``bw <= 0``."""
    if bw <= 0:
        return zero()
    width = 100 / bw
    half = 0.5 * width
    edges = (round(-half, NDIGITS), round(half, NDIGITS), +inf)
    return Waveform(bounds=edges, seq=(_zero, _basic_wave(SINC, bw), _zero))
1593

1594

1595
def cosPulse(width, plateau=0.0):
    """Raised-cosine pulse of total ``width`` centred on t = 0.

    With a positive ``plateau`` the result is a flat top of length
    ``plateau`` with raised-cosine edges of ``width / 2`` each.
    """
    if round(0.5 * plateau, NDIGITS) > 0:
        return square(plateau + 0.5 * width, edge=0.5 * width, type='cos')
    if width <= 0:
        return zero()
    # 0.5 + 0.5 * cos(2 pi t / width)
    pulse = ((((), ()), (((COS, 6.283185307179586 / width, 0), ), (1, ))),
             (0.5, 0.5))
    half = 0.5 * width
    return Waveform(bounds=(round(-half, NDIGITS), round(half, NDIGITS), +inf),
                    seq=(_zero, pulse, _zero))
1607

1608

1609
def hanning(width, plateau=0.0):
    """Alias of ``cosPulse`` (Hann-shaped pulse)."""
    pulse = cosPulse(width, plateau=plateau)
    return pulse
1611

1612

1613
def cosh(w):
    """Return ``cosh(w t)`` as a waveform."""
    expr = _basic_wave(COSH, w)
    return Waveform(seq=(expr, ))
1615

1616

1617
def sinh(w):
    """Return ``sinh(w t)`` as a waveform."""
    expr = _basic_wave(SINH, w)
    return Waveform(seq=(expr, ))
1619

1620

1621
def coshPulse(width, eps=1.0, plateau=0.0):
    """Cosine hyperbolic pulse with the following edge shape.

    pulse edge shape:
            cosh(eps / 2) - cosh(eps * t / T)
    f(t) = -----------------------------------
                  cosh(eps / 2) - 1
    where T is the pulse width and eps is the pulse edge steepness
    (must be non-zero). The pulse is defined for t in [-T/2, T/2].

    In case of plateau > 0, the pulse is defined as:
           | f(t + plateau/2)   if t in [-T/2 - plateau/2, - plateau/2]
    g(t) = | 1                  if t in [-plateau/2, plateau/2]
           | f(t - plateau/2)   if t in [plateau/2, T/2 + plateau/2]

    If ``width <= 0`` but ``plateau > 0``, only the flat plateau remains
    (``square(plateau)``). This avoids the division ``eps / width``.

    Parameters
    ----------
    width : float
        Pulse width.
    eps : float
        Pulse edge steepness.
    plateau : float
        Pulse plateau.
    """
    if width <= 0 and plateau <= 0:
        return zero()
    if width <= 0:
        return square(plateau)
    w = eps / width
    A = np.cosh(eps / 2)

    if plateau == 0.0 or round(-0.5 * plateau, NDIGITS) == round(
            0.5 * plateau, NDIGITS):
        pulse = ((((), ()), (((COSH, w, 0), ), (1, ))), (A / (A - 1),
                                                         -1 / (A - 1)))
        return Waveform(bounds=(round(-0.5 * width,
                                      NDIGITS), round(0.5 * width,
                                                      NDIGITS), +inf),
                        seq=(_zero, pulse, _zero))
    else:
        raising = ((((), ()), (((COSH, w, -0.5 * plateau), ), (1, ))),
                   (A / (A - 1), -1 / (A - 1)))
        falling = ((((), ()), (((COSH, w, 0.5 * plateau), ), (1, ))),
                   (A / (A - 1), -1 / (A - 1)))
        return Waveform(bounds=(round(-0.5 * width - 0.5 * plateau,
                                      NDIGITS), round(-0.5 * plateau, NDIGITS),
                                round(0.5 * plateau, NDIGITS),
                                round(0.5 * width + 0.5 * plateau,
                                      NDIGITS), +inf),
                        seq=(_zero, raising, _one, falling, _zero))
1669

1670

1671
def general_cosine(duration, *arg):
    """Generalized cosine window of length ``duration`` centred on t = 0.

    With harmonic weights ``arg = (a_1, a_2, ...)`` the result is::

        sum_i a_i / 2 * (1 - (-1)**i * cos(2 pi i t / duration))

    restricted to ``[-duration/2, duration/2]``. The weights are first
    normalised so that a_1 + a_3 + ... == 1, which makes the value at
    t = 0 equal to 1.
    """
    wav = zero()
    # Float copy: in-place division of an integer array (e.g. from
    # general_cosine(d, 1, 1)) raises in numpy.
    arg = np.asarray(arg, dtype=float)
    arg /= arg[::2].sum()
    for i, a in enumerate(arg, start=1):
        wav += a / 2 * (1 - (-1)**i * cos(i * 2 * pi / duration))
    return wav * square(duration)
1678

1679

1680
def slepian(duration, *arg):
    """Window built from the same cosine expansion as ``general_cosine``.

    NOTE(review): despite the name, this computes exactly the
    generalized-cosine sum of ``general_cosine``, not a DPSS/Slepian
    window; confirm the intended definition. The weights are converted to
    float first, so integer weights no longer fail on in-place division.
    """
    weights = np.asarray(arg, dtype=float)
    return general_cosine(duration, *weights)
1687

1688

1689
def _poly(*a):
9✔
1690
    """
1691
    a[0] + a[1] * t + a[2] * t**2 + ...
1692
    """
1693
    t = []
9✔
1694
    amp = []
9✔
1695
    if a[0] != 0:
9✔
1696
        t.append(((), ()))
9✔
1697
        amp.append(a[0])
9✔
1698
    for n, a_ in enumerate(a[1:], start=1):
9✔
1699
        if a_ != 0:
9✔
1700
            t.append((((LINEAR, 0), ), (n, )))
9✔
1701
            amp.append(a_)
9✔
1702
    return tuple(t), tuple(a)
9✔
1703

1704

1705
def poly(a):
    """
    a[0] + a[1] * t + a[2] * t**2 + ...

    Return the polynomial with coefficients ``a`` as a waveform.
    """
    expr = _poly(*a)
    return Waveform(seq=(expr, ))
1710

1711

1712
def t():
    """The identity waveform ``f(t) = t``.

    ``seq`` must be a tuple of expressions. The previous literal
    ``((term, (1,)))`` produced a two-element ``seq`` of a bare term and
    ``(1,)``. The expression is now built with ``_basic_wave(LINEAR)``,
    as ``step(..., type='linear')`` does.
    """
    return Waveform(seq=(_basic_wave(LINEAR), ))
1714

1715

1716
def drag(freq, width, plateau=0, delta=0, block_freq=None, phase=0, t0=0):
    """DRAG pulse starting at ``t0``, optionally with a flat-top plateau.

    The rising and falling halves use the DRAG base function (``_drag``).
    The plateau is a carrier ``cos(2 pi (freq + delta) t - phi)`` whose
    phase ``phi`` matches the DRAG halves at both joins. If
    ``freq + delta == 0``, the plateau is the constant ``cos(phi)``; the
    shifted-cosine form would otherwise divide by zero.
    """
    phase += pi * delta * (width + plateau)
    if plateau <= 0:
        return Waveform(seq=(_zero,
                             _basic_wave(DRAG, t0, freq, width, delta,
                                         block_freq, phase), _zero),
                        bounds=(round(t0, NDIGITS), round(t0 + width,
                                                          NDIGITS), +inf))

    def _carrier():
        # Plateau expression cos(w t - phi) with phi = phase + 2 pi delta t0.
        w = 2 * pi * (freq + delta)
        phi = phase + 2 * pi * delta * t0
        if w == 0:
            return _const(np.cos(phi))
        return _basic_wave(COS, w, shift=phi / w)

    if width <= 0:
        return Waveform(
            seq=(_zero, _carrier(), _zero),
            bounds=(round(t0, NDIGITS), round(t0 + plateau, NDIGITS), +inf))
    return Waveform(
        seq=(_zero,
             _basic_wave(DRAG, t0, freq, width, delta, block_freq, phase),
             _carrier(),
             _basic_wave(DRAG, t0 + plateau, freq, width, delta,
                         block_freq,
                         phase - 2 * pi * delta * plateau), _zero),
        bounds=(round(t0, NDIGITS), round(t0 + width / 2, NDIGITS),
                round(t0 + width / 2 + plateau,
                      NDIGITS), round(t0 + width + plateau,
                                      NDIGITS), +inf))
1744

1745

1746
def chirp(f0, f1, T, phi0=0, type='linear'):
    """
    A chirp is a signal in which the frequency increases (up-chirp)
    or decreases (down-chirp) with time. In some sources, the term
    chirp is used interchangeably with sweep signal.

    type: "linear", "exponential", "hyperbolic"

    The chirp lives on ``[0, T]`` and is zero elsewhere. ``f0`` and ``f1``
    are the start and end frequencies in cycles per time unit, and
    ``phi0`` is the initial phase of the sine.

    Raises ValueError for ``T <= 0``, an unknown ``type``, or ``f0 == 0``
    with the exponential type.
    """
    if T <= 0:
        raise ValueError('T must be positive')
    window = (0, round(T, NDIGITS), +inf)

    if f0 == f1:
        # Constant frequency: every type reduces to sin(phi0 + 2 pi f0 t)
        # on [0, T]. The linear form is used because the exponential and
        # hyperbolic parameterizations divide by zero here. (Previously
        # this returned an unwindowed sin(f0, phi0), treating f0 as
        # rad per unit.)
        return Waveform(bounds=window,
                        seq=(_zero, _basic_wave(LINEARCHIRP, f0, f1, T,
                                                phi0), _zero))

    if type == 'linear':
        # f(t) = f1 * (t/T) + f0 * (1 - t/T)
        return Waveform(bounds=window,
                        seq=(_zero, _basic_wave(LINEARCHIRP, f0, f1, T,
                                                phi0), _zero))
    elif type in ['exp', 'exponential', 'geometric']:
        # f(t) = f0 * (f1/f0) ** (t/T)
        if f0 == 0:
            raise ValueError('f0 must be non-zero')
        alpha = np.log(f1 / f0) / T
        return Waveform(bounds=window,
                        seq=(_zero,
                             _basic_wave(EXPONENTIALCHIRP, f0, alpha,
                                         phi0), _zero))
    elif type in ['hyperbolic', 'hyp']:
        # f(t) = f0 * f1 / (f0 * (t/T) + f1 * (1-t/T))
        if f0 * f1 == 0:
            # The formula above gives zero frequency, so the signal is the
            # constant sin(phi0), windowed to [0, T] like the other cases.
            return Waveform(bounds=window,
                            seq=(_zero, _const(np.sin(phi0)), _zero))
        k = (f0 - f1) / (f1 * T)
        return Waveform(bounds=window,
                        seq=(_zero, _basic_wave(HYPERBOLICCHIRP, f0, k,
                                                phi0), _zero))
    else:
        raise ValueError(f'unknown type {type}')
1783

1784

1785
def interp(x, y):
    """Build a piecewise-linear waveform through the points (x[i], y[i]).

    Between consecutive sample positions the waveform is the straight line
    joining the two samples. It is zero before ``x[0]`` and after ``x[-1]``.
    Repeated positions are allowed: the zero-width segment is skipped,
    which leaves a jump at that position.

    Args:
        x: Sample positions, non-decreasing.
        y: Sample values, same length as ``x``.

    Returns:
        The simplified piecewise-linear Waveform.

    Raises:
        ValueError: if ``x`` is empty, ``x`` and ``y`` differ in length,
            or ``x`` decreases anywhere (bounds must be sorted).
    """
    if len(x) == 0:
        raise ValueError('x must not be empty')
    if len(x) != len(y):
        # zip() below would otherwise silently drop the extra samples.
        raise ValueError('x and y must have the same length')
    if any(x2 < x1 for x1, x2 in zip(x[:-1], x[1:])):
        raise ValueError('x must be non-decreasing')

    seq, bounds = [_zero], [x[0]]
    for x1, x2, y1, y2 in zip(x[:-1], x[1:], y[:-1], y[1:]):
        if x2 == x1:
            # Zero-width segment: skip it, leaving a jump at x1.
            continue
        slope = (y2 - y1) / (x2 - x1)
        # slope * LINEAR(shift=x1) + y1 on the interval ending at x2.
        seq.append(
            _add(_mul(_const(slope), _basic_wave(LINEAR, shift=x1)),
                 _const(y1)))
        bounds.append(x2)
    bounds.append(inf)
    seq.append(_zero)
    return Waveform(seq=tuple(seq),
                    bounds=tuple(round(b, NDIGITS)
                                 for b in bounds)).simplify()
1800

1801

1802
def cut(wav, start=None, stop=None, head=None, tail=None, min=None, max=None):
    """Shift, truncate and annotate a waveform.

    If both ``start`` and ``head`` are given, a constant is added so that the
    value at ``start`` becomes ``head``. Otherwise, if both ``stop`` and
    ``tail`` are given, a constant is added so that the value at ``stop``
    becomes ``tail``. The ``head`` pair wins when both pairs are supplied.

    The result is then multiplied by a rising step at ``start`` and/or a
    falling step at ``stop`` (whichever are given). ``min``/``max`` are
    finally stored as attributes on the returned waveform.
    """
    # Pick the (position, desired value) pair used to align the waveform.
    if start is not None and head is not None:
        anchor, target = start, head
    elif stop is not None and tail is not None:
        anchor, target = stop, tail
    else:
        anchor = target = None

    if anchor is None:
        offset = 0
    else:
        offset = target - wav(np.array([1.0 * anchor]))[0]
    # Always add, even when offset is 0, so later attribute assignments act
    # on the result of the addition rather than directly on the argument.
    shifted = wav + offset

    if start is not None:
        shifted = shifted * (step(0) >> start)
    if stop is not None:
        shifted = shifted * ((1 - step(0)) >> stop)

    if min is not None:
        shifted.min = min
    if max is not None:
        shifted.max = max
    return shifted
1819

1820

1821
def function(fun, *args, start=None, stop=None):
    """Wrap a user-supplied callable as a waveform.

    ``fun`` is registered through ``registerBaseFunc``. The resulting base
    function, parameterised by ``args``, forms the single segment of a new
    Waveform. ``start`` / ``stop``, when given, multiply the result by a
    rising step at ``start`` and a falling step at ``stop`` respectively.
    """
    type_id = registerBaseFunc(fun)
    result = Waveform(seq=(_basic_wave(type_id, *args), ))

    gates = []
    if start is not None:
        gates.append(step(0) >> start)
    if stop is not None:
        gates.append((1 - step(0)) >> stop)
    for gate in gates:
        result = result * gate
    return result
1830

1831

1832
def samplingPoints(start, stop, points):
    """Waveform defined by sample values ``points`` spanning [start, stop].

    The samples are frozen into a tuple and stored in an INTERP base
    function over [start, stop]. That base function presumably interpolates
    between them. Outside the interval the waveform is zero.
    """
    edges = (round(start, NDIGITS), round(stop, NDIGITS), inf)
    inside = _basic_wave(INTERP, start, stop, tuple(points))
    return Waveform(bounds=edges, seq=(_zero, inside, _zero))
1836

1837

1838
def mixing(I,
           Q=None,
           *,
           phase=0.0,
           freq=0.0,
           ratioIQ=1.0,
           phaseDiff=0.0,
           block_freq=None,
           DRAGScaling=None):
    """SSB or envelope mixing of an I/Q pair.

    Args:
        I: In-phase input waveform.
        Q: Quadrature input waveform; ``None`` means a zero waveform.
        phase: Phase offset; passed to the carrier as ``-phase``.
        freq: Carrier frequency. Non-zero selects SSB mixing with carriers
            ``cos(2 * pi * freq, -phase)`` / ``sin(2 * pi * freq, -phase)``.
            Zero selects envelope mixing, a constant rotation by ``-phase``.
        ratioIQ: Gain applied to the output Q channel (after any DRAG step).
        phaseDiff: Extra phase added to the Q-channel carrier in SSB mode.
            NOTE(review): it is ignored when ``freq == 0`` — confirm that
            this is intended.
        block_freq: If given and different from ``freq``, apply a DRAG-style
            derivative correction parameterised by this frequency. Takes
            precedence over ``DRAGScaling``.
        DRAGScaling: DRAG coefficient used only when the ``block_freq``
            correction is not applied. It corresponds to
            ``1 / (2 * pi * (freq - block_freq))``.

    Returns:
        Tuple ``(Iout, Qout)`` of waveforms.
    """
    if Q is None:
        Q = zero()

    w = 2 * pi * freq
    if freq != 0.0:
        # SSB mixing
        Iout = I * cos(w, -phase) + Q * sin(w, -phase)
        Qout = -I * sin(w, -phase + phaseDiff) + Q * cos(w, -phase + phaseDiff)
    else:
        # envelope mixing: a plain rotation, no carrier
        Iout = I * np.cos(-phase) + Q * np.sin(-phase)
        Qout = -I * np.sin(-phase) + Q * np.cos(-phase)

    # apply DRAG (D(...) is the module's derivative operator)
    if block_freq is not None and block_freq != freq:
        a = block_freq / (block_freq - freq)
        b = 1 / (block_freq - freq)
        I = a * Iout + b / (2 * pi) * D(Qout)
        Q = a * Qout - b / (2 * pi) * D(Iout)
        Iout, Qout = I, Q
    elif DRAGScaling is not None and DRAGScaling != 0:
        # 2 * pi * scaling * (freq - block_freq) = 1; with that choice this
        # branch produces the same coefficients as the block_freq branch.
        I = (1 - w * DRAGScaling) * Iout - DRAGScaling * D(Qout)
        Q = (1 - w * DRAGScaling) * Qout + DRAGScaling * D(Iout)
        Iout, Qout = I, Q

    Qout = ratioIQ * Qout

    return Iout, Qout
1879

1880

1881
# Public names exported by ``from <this module> import *``.
__all__ = [
    'D',
    'Waveform',
    'chirp',
    'const',
    'cos',
    'cosh',
    'coshPulse',
    'cosPulse',
    'cut',
    'drag',
    'exp',
    'function',
    'gaussian',
    'general_cosine',
    'hanning',
    'interp',
    'mixing',
    'one',
    'poly',
    'registerBaseFunc',
    'registerDerivative',
    'samplingPoints',
    'sign',
    'sin',
    'sinc',
    'sinh',
    'square',
    'step',
    't',
    'zero',
]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc