• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

feihoo87 / waveforms / 7087368175

04 Dec 2023 01:33PM UTC coverage: 43.134% (-0.5%) from 43.641%
7087368175

push

github

feihoo87
fix workflow

7413 of 17186 relevant lines covered (43.13%)

2.58 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

57.1
/waveforms/waveform.py
1
import pickle
6✔
2
from bisect import bisect_left
6✔
3
from fractions import Fraction
6✔
4
from itertools import chain, product
6✔
5
from math import comb
6✔
6

7
import numpy as np
6✔
8
import scipy.special as special
6✔
9
from numpy import e, inf, pi
6✔
10
from scipy.signal import sosfilt
6✔
11

12
from ._waveform import (_D, NDIGITS, _add, _baseFunc, _baseFunc_latex,
6✔
13
                        _basic_wave, _calc_parts, _const, _half, _is_const,
14
                        _merge_waveform, _mul, _one, _pow, _shift, _zero,
15
                        registerBaseFunc, registerBaseFuncLatex,
16
                        registerDerivative, wave_sum)
17

18

19
def _cos_power_n(x, n):
6✔
20
    _, w, shift = x
6✔
21
    ret = _zero
6✔
22
    for k in range(0, n // 2 + 1):
6✔
23
        if n == 2 * k:
6✔
24
            a = _const(comb(n, k) / 2**n)
6✔
25
            ret = _add(ret, a)
6✔
26
        else:
27
            expr = (((((COS, (n - 2 * k) * w, shift), ), (1, )), ),
6✔
28
                    (comb(n, k) / 2**(n - 1), ))
29
            ret = _add(ret, expr)
6✔
30
    return ret
6✔
31

32

33
def _trigMul_t(x, y, v):
6✔
34
    """cos(a)cos(b) = 0.5*cos(a+b)+0.5*cos(a-b)"""
35
    _, w1, t1 = x
6✔
36
    _, w2, t2 = y
6✔
37
    if w2 > w1:
6✔
38
        t1, t2 = t2, t1
6✔
39
        w1, w2 = w2, w1
6✔
40
    exp1 = (COS, w1 + w2, (w1 * t1 + w2 * t2) / (w1 + w2))
6✔
41
    if w1 == w2:
6✔
42
        c = v * np.cos(w1 * t1 - w2 * t2) / 2
6✔
43
        if c == 0:
6✔
44
            return (((exp1, ), (1, )), ), (0.5 * v, )
×
45
        else:
46
            return (((), ()), ((exp1, ), (1, ))), (c, 0.5 * v)
6✔
47
    else:
48
        exp2 = (COS, w1 - w2, (w1 * t1 - w2 * t2) / (w1 - w2))
6✔
49
        if exp2[1] > exp1[1]:
6✔
50
            exp2, exp1 = exp1, exp2
×
51
        return (((exp2, ), (1, )), ((exp1, ), (1, ))), (0.5 * v, 0.5 * v)
6✔
52

53

54
def _trigMul(x, y):
6✔
55
    if _is_const(x) or _is_const(y):
6✔
56
        return _mul(x, y)
6✔
57
    ret = _zero
6✔
58
    for (t1, t2), (v1, v2) in zip(product(x[0], y[0]), product(x[1], y[1])):
6✔
59
        v = v1 * v2
6✔
60
        tmp = _one
6✔
61
        trig = []
6✔
62
        for mt, n in zip(chain(t1[0], t2[0]), chain(t1[1], t2[1])):
6✔
63
            if mt[0] == COS:
6✔
64
                trig.append(mt)
6✔
65
            else:
66
                tmp = _mul(tmp, ((((mt, ), (n, )), ), (1, )))
×
67
        if len(trig) == 1:
6✔
68
            x = ((((trig[0], ), (1, )), ), (v, ))
6✔
69
            expr = _mul(tmp, x)
6✔
70
        elif len(trig) == 2:
6✔
71
            expr = _trigMul_t(trig[0], trig[1], v)
6✔
72
            expr = _mul(tmp, expr)
6✔
73
        else:
74
            expr = _mul(tmp, _const(v))
×
75
        ret = _add(ret, expr)
6✔
76
    return ret
6✔
77

78

79
def _exp_trig_Reduce(mtlist, v):
6✔
80
    trig = _one
6✔
81
    alpha = 0
6✔
82
    shift = 0
6✔
83
    ml, nl = [], []
6✔
84
    for mt, n in zip(*mtlist):
6✔
85
        if mt[0] == COS:
6✔
86
            trig = _trigMul(trig, _cos_power_n(mt, n))
6✔
87
        elif mt[0] == EXP:
6✔
88
            x = alpha * shift + n * mt[1] * mt[-1]
6✔
89
            alpha += n * mt[1]
6✔
90
            if alpha == 0:
6✔
91
                shift = 0
×
92
            else:
93
                shift = x / alpha
6✔
94
        elif mt[0] == GAUSSIAN and n != 1:
6✔
95
            ml.append((mt[0], mt[1] / np.sqrt(n), mt[2]))
×
96
            nl.append(1)
×
97
        else:
98
            ml.append(mt)
6✔
99
            nl.append(n)
6✔
100
    ret = (((tuple(ml), tuple(nl)), ), (v, ))
6✔
101

102
    if alpha != 0:
6✔
103
        ret = _mul(ret, _basic_wave(EXP, alpha, shift=shift))
6✔
104

105
    return _mul(ret, trig)
6✔
106

107

108
def _get_freq(t):
6✔
109
    t2 = [[], []]
6✔
110
    freq, shift = 0, 0
6✔
111
    for mt, n in zip(*t):
6✔
112
        if mt[0] == COS:
6✔
113
            if freq != 0:
6✔
114
                raise ValueError("run _exp_trig_Reduce first")
×
115
            freq = mt[1]
6✔
116
            shift = mt[-1]
6✔
117
        else:
118
            t2[0].append(mt)
6✔
119
            t2[1].append(n)
6✔
120
    t2 = (tuple(t2[0]), tuple(t2[1]))
6✔
121
    return freq, shift, t2
6✔
122

123

124
def _simplify(expr, eps):
6✔
125
    d = {}
6✔
126
    for t, v in zip(*expr):
6✔
127
        for t, v in zip(*_exp_trig_Reduce(t, v)):
6✔
128
            freq, shift, t = _get_freq(t)
6✔
129
            v_r, v_i, shift_r, shift_i = v.real, v.imag, shift, shift
6✔
130
            if (t, freq) in d:
6✔
131
                v0_r, shift0_r, v0_i, shift0_i = d[(t, freq)]
6✔
132
                if freq == 0:
6✔
133
                    v_r, v_i = v.real + v0_r, v.imag + v0_i
×
134
                else:
135
                    a = v0_r * np.cos(freq * shift0_r) + v_r * np.cos(
6✔
136
                        freq * shift_r)
137
                    b = v0_r * np.sin(freq * shift0_r) + v_r * np.sin(
6✔
138
                        freq * shift_r)
139
                    shift_r = np.arctan2(b, a) / freq
6✔
140
                    v_r = np.sqrt(a**2 + b**2)
6✔
141

142
                    a = v0_i * np.cos(freq * shift0_i) + v_i * np.cos(
6✔
143
                        freq * shift_i)
144
                    b = v0_i * np.sin(freq * shift0_i) + v_i * np.sin(
6✔
145
                        freq * shift_i)
146
                    shift_i = np.arctan2(b, a) / freq
6✔
147
                    v_i = np.sqrt(a**2 + b**2)
6✔
148
            d[(t, freq)] = v_r, shift_r, v_i, shift_i
6✔
149
    ret = _zero
6✔
150
    for (t, freq), (v_r, shift_r, v_i, shift_i) in d.items():
6✔
151
        if freq == 0 and abs(v) >= eps:
6✔
152
            if v_i == 0:
6✔
153
                ret = _add(ret, ((t, ), (v_r, )))
6✔
154
            else:
155
                ret = _add(ret, ((t, ), (v_r + 1j * v_i, )))
×
156
        else:
157
            if abs(v_i) < eps and abs(v_r) < eps:
6✔
158
                continue
×
159
            elif abs(v_i) < eps and abs(v_r) >= eps:
6✔
160
                expr = (((((COS, freq, shift_r), ), (1, )), ), (v_r, ))
6✔
161
            elif abs(v_i) >= eps and abs(v_r) < eps:
6✔
162
                expr = (((((COS, freq, shift_i), ), (1, )), ), (v_i * 1j, ))
×
163
            elif abs(v_i) >= eps and abs(v_r) >= eps:
6✔
164
                expr = (((((COS, freq, shift_r), ), (1, )),
6✔
165
                         (((COS, freq, shift_i), ), (1, ))), (v_r, v_i * 1j))
166
            else:
167
                pass  # Never reach here
×
168

169
            expr = _mul(((t, ), (1, )), expr)
6✔
170
            ret = _add(ret, expr)
6✔
171
    return ret
6✔
172

173

174
def _filter(expr, low, high, eps):
6✔
175
    expr = _simplify(expr, eps)
6✔
176
    ret = _zero
6✔
177
    for t, v in zip(*expr):
6✔
178
        for i, (mt, n) in enumerate(zip(*t)):
6✔
179
            if mt[0] == COS:
6✔
180
                if low <= mt[1] < high:
6✔
181
                    ret = _add(ret, ((t, ), (v, )))
6✔
182
                break
6✔
183
            elif mt[0] == SINC and n == 1:
6✔
184
                pass
×
185
            elif mt[0] == GAUSSIAN and n == 1:
6✔
186
                pass
×
187
        else:
188
            if low <= 0:
×
189
                ret = _add(ret, ((t, ), (v, )))
×
190
    return ret
6✔
191

192

193
def _test_spec_num(num, spec):
6✔
194
    x = Fraction(num / spec).limit_denominator(1000000000)
×
195
    if x.denominator <= 24:
×
196
        return True, x, 1
×
197
    x = Fraction(spec * num).limit_denominator(1000000000)
×
198
    if x.denominator <= 24:
×
199
        return True, x, -1
×
200
    return False, x, 0
×
201

202

203
def _spec_num_latex(num):
6✔
204
    for spec, spec_latex in [(1, ''), (np.sqrt(2), '\\sqrt{2}'),
×
205
                             (np.sqrt(3), '\\sqrt{3}'),
206
                             (np.sqrt(5), '\\sqrt{5}'),
207
                             (np.log(2), '\\log{2}'), (np.log(3), '\\log{3}'),
208
                             (np.log(5), '\\log{5}'), (np.e, 'e'),
209
                             (np.pi, '\\pi'), (np.pi**2, '\\pi^2'),
210
                             (np.sqrt(np.pi), '\\sqrt{\\pi}')]:
211
        flag, x, sign = _test_spec_num(num, spec)
×
212
        if flag:
×
213
            if sign < 0:
×
214
                spec_latex = f"\\frac{{{1}}}{{{spec_latex}}}"
×
215
            if x.denominator == 1:
×
216
                if x.numerator == 1:
×
217
                    return f"{spec_latex}"
×
218
                else:
219
                    return f"{x.numerator:g}{spec_latex}"
×
220
            else:
221
                if x.numerator < 0:
×
222
                    return f"-\\frac{{{-x.numerator}}}{{{x.denominator}}}{spec_latex}"
×
223
                else:
224
                    return f"\\frac{{{x.numerator}}}{{{x.denominator}}}{spec_latex}"
×
225
    return f"{num:g}"
×
226

227

228
def _num_latex(num):
6✔
229
    if num == -np.inf:
×
230
        return r"-\infty"
×
231
    elif num == np.inf:
×
232
        return r"\infty"
×
233
    if num.imag > 0:
×
234
        return f"\\left({_num_latex(num.real)}+{_num_latex(num.imag)}j\\right)"
×
235
    elif num.imag < 0:
×
236
        return f"\\left({_num_latex(num.real)}-{_num_latex(-num.imag)}j\\right)"
×
237
    s = _spec_num_latex(num.real)
×
238
    if s == '' and round(num.real) == 1:
×
239
        return '1'
×
240
    if "e" in s:
×
241
        a, n = s.split("e")
×
242
        n = float(n)
×
243
        s = f"{a} \\times 10^{{{n:g}}}"
×
244
    return s
×
245

246

247
def _fun_latex(fun):
6✔
248
    funID, *args, shift = fun
×
249
    if _baseFunc_latex[funID] is None:
×
250
        shift = _num_latex(shift)
×
251
        if shift == "0":
×
252
            shift = ""
×
253
        elif shift[0] != '-':
×
254
            shift = "+" + shift
×
255
        return r"\mathrm{Func}" + f"{funID}(t{shift}, ...)"
×
256
    return _baseFunc_latex[funID](shift, *args)
×
257

258

259
def _wav_latex(wav):
6✔
260
    from waveforms.waveform import _is_const, _zero
×
261

262
    if wav == _zero:
×
263
        return "0"
×
264
    elif _is_const(wav):
×
265
        return f"{wav[1][0]}"
×
266

267
    sum_expr = []
×
268
    for mul, amp in zip(*wav):
×
269
        if mul == ((), ()):
×
270
            sum_expr.append(_num_latex(amp))
×
271
            continue
×
272
        mul_expr = []
×
273
        amp = _num_latex(amp)
×
274
        if amp != "1":
×
275
            mul_expr.append(amp)
×
276
        for fun, n in zip(*mul):
×
277
            fun_expr = _fun_latex(fun)
×
278
            if n != 1:
×
279
                mul_expr.append(fun_expr + "^{" + f"{n}" + "}")
×
280
            else:
281
                mul_expr.append(fun_expr)
×
282
        sum_expr.append(''.join(mul_expr))
×
283

284
    ret = sum_expr[0]
×
285
    for expr in sum_expr[1:]:
×
286
        if expr[0] == '-':
×
287
            ret += expr
×
288
        else:
289
            ret += "+" + expr
×
290
    return ret
×
291

292

293
class Waveform:
6✔
294
    __slots__ = ('bounds', 'seq', 'max', 'min', 'start', 'stop', 'sample_rate',
6✔
295
                 'filters', 'label')
296

297
    def __init__(self, bounds=(+inf, ), seq=(_zero, ), min=-inf, max=inf):
6✔
298
        self.bounds = bounds
6✔
299
        self.seq = seq
6✔
300
        self.max = max
6✔
301
        self.min = min
6✔
302
        self.start = None
6✔
303
        self.stop = None
6✔
304
        self.sample_rate = None
6✔
305
        self.filters = None
6✔
306
        self.label = None
6✔
307

308
    def _begin(self):
6✔
309
        for i, s in enumerate(self.seq):
×
310
            if s is not _zero:
×
311
                if i == 0:
×
312
                    return -inf
×
313
                return self.bounds[i - 1]
×
314
        return inf
×
315

316
    def _end(self):
6✔
317
        N = len(self.bounds)
×
318
        for i, s in enumerate(self.seq[::-1]):
×
319
            if s is not _zero:
×
320
                if i == 0:
×
321
                    return inf
×
322
                return self.bounds[N - i - 1]
×
323
        return -inf
×
324

325
    @property
6✔
326
    def begin(self):
6✔
327
        if self.start is None:
×
328
            return self._begin()
×
329
        else:
330
            return max(self.start, self._begin())
×
331

332
    @property
6✔
333
    def end(self):
6✔
334
        if self.stop is None:
×
335
            return self._end()
×
336
        else:
337
            return min(self.stop, self._end())
×
338

339
    def sample(self,
6✔
340
               sample_rate=None,
341
               out=None,
342
               chunk_size=None,
343
               function_lib=None,
344
               filters=None):
345
        if sample_rate is None:
6✔
346
            sample_rate = self.sample_rate
6✔
347
        if self.start is None or self.stop is None or sample_rate is None:
6✔
348
            raise ValueError(
×
349
                f'Waveform is not initialized. {self.start=}, {self.stop=}, {sample_rate=}'
350
            )
351
        if filters is None:
6✔
352
            filters = self.filters
6✔
353
        if chunk_size is None:
6✔
354
            x = np.arange(self.start, self.stop, 1 / sample_rate)
6✔
355
            sig = self.__call__(x, out=out, function_lib=function_lib)
6✔
356
            if filters is not None:
6✔
357
                sos, initial = filters
6✔
358
                if initial:
6✔
359
                    sig = sosfilt(sos, sig - initial) + initial
×
360
                else:
361
                    sig = sosfilt(sos, sig)
6✔
362
            return sig
6✔
363
        else:
364
            return self._sample_iter(sample_rate, chunk_size, out,
×
365
                                     function_lib, filters)
366

367
    def _sample_iter(self, sample_rate, chunk_size, out, function_lib,
6✔
368
                     filters):
369
        start = self.start
×
370
        start_n = 0
×
371
        if filters is not None:
×
372
            sos, initial = filters
×
373
            # zi = sosfilt_zi(sos)
374
            zi = np.zeros((sos.shape[0], 2))
×
375
        length = chunk_size / sample_rate
×
376
        while start < self.stop:
×
377
            if start + length > self.stop:
×
378
                length = self.stop - start
×
379
                stop = self.stop
×
380
                size = round((stop - start) * sample_rate)
×
381
            else:
382
                stop = start + length
×
383
                size = chunk_size
×
384
            x = np.linspace(start, stop, size, endpoint=False)
×
385
            if out is not None:
×
386
                if filters is None:
×
387
                    yield self.__call__(x,
×
388
                                        out=out[start_n:],
389
                                        function_lib=function_lib)
390
                else:
391
                    if initial:
×
392
                        sig -= initial
×
393
                    sig, zi = sosfilt(sos,
×
394
                                      self.__call__(x,
395
                                                    function_lib=function_lib),
396
                                      zi=zi)
397
                    if initial:
×
398
                        sig += initial
×
399
                    out[start_n:start_n + size] = sig
×
400
                    yield sig
×
401
            else:
402
                if filters is None:
×
403
                    yield self.__call__(x, function_lib=function_lib)
×
404
                else:
405
                    if initial:
×
406
                        sig -= initial
×
407
                    sig, zi = sosfilt(sos,
×
408
                                      self.__call__(x,
409
                                                    function_lib=function_lib),
410
                                      zi=zi)
411
                    if initial:
×
412
                        sig += initial
×
413
                    yield sig
×
414
            start = stop
×
415
            start_n += chunk_size
×
416

417
    @staticmethod
6✔
418
    def _tolist(bounds, seq, ret=None):
6✔
419
        if ret is None:
6✔
420
            ret = []
×
421
        ret.append(len(bounds))
6✔
422
        for seq, b in zip(seq, bounds):
6✔
423
            ret.append(b)
6✔
424
            tlist, amplist = seq
6✔
425
            ret.append(len(amplist))
6✔
426
            for t, amp in zip(tlist, amplist):
6✔
427
                ret.append(amp)
6✔
428
                mtlist, nlist = t
6✔
429
                ret.append(len(nlist))
6✔
430
                for fun, n in zip(mtlist, nlist):
6✔
431
                    ret.append(n)
6✔
432
                    ret.append(len(fun))
6✔
433
                    ret.extend(fun)
6✔
434
        return ret
6✔
435

436
    @staticmethod
6✔
437
    def _fromlist(l, pos=0):
6✔
438

439
        def _read(l, pos, size):
6✔
440
            try:
6✔
441
                return tuple(l[pos:pos + size]), pos + size
6✔
442
            except:
×
443
                raise ValueError('Invalid waveform format')
×
444

445
        (nseg, ), pos = _read(l, pos, 1)
6✔
446
        bounds = []
6✔
447
        seq = []
6✔
448
        for _ in range(nseg):
6✔
449
            (b, nsum), pos = _read(l, pos, 2)
6✔
450
            bounds.append(b)
6✔
451
            amp = []
6✔
452
            t = []
6✔
453
            for _ in range(nsum):
6✔
454
                (a, nmul), pos = _read(l, pos, 2)
6✔
455
                amp.append(a)
6✔
456
                nlst = []
6✔
457
                mt = []
6✔
458
                for _ in range(nmul):
6✔
459
                    (n, nfun), pos = _read(l, pos, 2)
6✔
460
                    nlst.append(n)
6✔
461
                    fun, pos = _read(l, pos, nfun)
6✔
462
                    mt.append(fun)
6✔
463
                t.append((tuple(mt), tuple(nlst)))
6✔
464
            seq.append((tuple(t), tuple(amp)))
6✔
465

466
        return tuple(bounds), tuple(seq), pos
6✔
467

468
    def tolist(self):
6✔
469
        l = [self.max, self.min, self.start, self.stop, self.sample_rate]
6✔
470
        if self.filters is None:
6✔
471
            l.append(None)
6✔
472
        else:
473
            sos, initial = self.filters
6✔
474
            sos = list(sos.reshape(-1))
6✔
475
            l.append(len(sos))
6✔
476
            l.extend(sos)
6✔
477
            l.append(initial)
6✔
478

479
        return self._tolist(self.bounds, self.seq, l)
6✔
480

481
    @classmethod
6✔
482
    def fromlist(cls, l):
6✔
483
        w = cls()
6✔
484
        pos = 6
6✔
485
        (w.max, w.min, w.start, w.stop, w.sample_rate, sos_size) = l[:pos]
6✔
486
        if sos_size is not None:
6✔
487
            sos = np.array(l[pos:pos + sos_size]).reshape(-1, 6)
6✔
488
            pos += sos_size
6✔
489
            initial = l[pos]
6✔
490
            pos += 1
6✔
491
            w.filters = sos, initial
6✔
492

493
        w.bounds, w.seq, pos = cls._fromlist(l, pos)
6✔
494
        return w
6✔
495

496
    def totree(self):
6✔
497
        if self.filters is None:
6✔
498
            header = (self.max, self.min, self.start, self.stop,
6✔
499
                      self.sample_rate, None)
500
        else:
501
            header = (self.max, self.min, self.start, self.stop,
6✔
502
                      self.sample_rate, self.filters)
503
        body = []
6✔
504

505
        for seq, b in zip(self.seq, self.bounds):
6✔
506
            tlist, amplist = seq
6✔
507
            new_seq = []
6✔
508
            for t, amp in zip(tlist, amplist):
6✔
509
                mtlist, nlist = t
6✔
510
                new_t = []
6✔
511
                for fun, n in zip(mtlist, nlist):
6✔
512
                    new_t.append((n, fun))
6✔
513
                new_seq.append((amp, tuple(new_t)))
6✔
514
            body.append((b, tuple(new_seq)))
6✔
515
        return header, tuple(body)
6✔
516

517
    @staticmethod
6✔
518
    def fromtree(tree):
6✔
519
        w = Waveform()
6✔
520
        header, body = tree
6✔
521

522
        (w.max, w.min, w.start, w.stop, w.sample_rate, w.filters) = header
6✔
523
        bounds = []
6✔
524
        seqs = []
6✔
525
        for b, seq in body:
6✔
526
            bounds.append(b)
6✔
527
            amp_list = []
6✔
528
            t_list = []
6✔
529
            for amp, t in seq:
6✔
530
                amp_list.append(amp)
6✔
531
                n_list = []
6✔
532
                mt_list = []
6✔
533
                for n, mt in t:
6✔
534
                    n_list.append(n)
6✔
535
                    mt_list.append(mt)
6✔
536
                t_list.append((tuple(mt_list), tuple(n_list)))
6✔
537
            seqs.append((tuple(t_list), tuple(amp_list)))
6✔
538
        w.bounds = tuple(bounds)
6✔
539
        w.seq = tuple(seqs)
6✔
540
        return w
6✔
541

542
    def simplify(self, eps=1e-15):
6✔
543
        seq = [_simplify(self.seq[0], eps)]
6✔
544
        bounds = [self.bounds[0]]
6✔
545
        for expr, b in zip(self.seq[1:], self.bounds[1:]):
6✔
546
            expr = _simplify(expr, eps)
6✔
547
            if expr == seq[-1]:
6✔
548
                seq.pop()
×
549
                bounds.pop()
×
550
            seq.append(expr)
6✔
551
            bounds.append(b)
6✔
552
        return Waveform(tuple(bounds), tuple(seq))
6✔
553

554
    def filter(self, low=0, high=inf, eps=1e-15):
6✔
555
        seq = []
6✔
556
        for expr in self.seq:
6✔
557
            seq.append(_filter(expr, low, high, eps))
6✔
558
        return Waveform(self.bounds, tuple(seq))
6✔
559

560
    def _comb(self, other, oper):
6✔
561
        return Waveform(*_merge_waveform(self.bounds, self.seq, other.bounds,
6✔
562
                                         other.seq, oper))
563

564
    def __pow__(self, n):
6✔
565
        return Waveform(self.bounds, tuple(_pow(w, n) for w in self.seq))
6✔
566

567
    def __add__(self, other):
6✔
568
        if isinstance(other, Waveform):
6✔
569
            return self._comb(other, _add)
6✔
570
        else:
571
            return self + const(other)
6✔
572

573
    def __radd__(self, v):
6✔
574
        return const(v) + self
×
575

576
    def __ior__(self, other):
6✔
577
        return self | other
×
578

579
    def __or__(self, other):
6✔
580
        if isinstance(other, (int, float, complex)):
×
581
            other = const(other)
×
582
        w = self.marker + other.marker
×
583

584
        def _or(a, b):
×
585
            if a != _zero or b != _zero:
×
586
                return _one
×
587
            else:
588
                return _zero
×
589

590
        return self._comb(other, _or)
×
591

592
    def __iand__(self, other):
6✔
593
        return self & other
×
594

595
    def __and__(self, other):
6✔
596
        if isinstance(other, (int, float, complex)):
×
597
            other = const(other)
×
598
        w = self.marker + other.marker
×
599

600
        def _and(a, b):
×
601
            if a != _zero and b != _zero:
×
602
                return _one
×
603
            else:
604
                return _zero
×
605

606
        return self._comb(other, _and)
×
607

608
    @property
6✔
609
    def marker(self):
6✔
610
        w = self.simplify()
×
611
        return Waveform(w.bounds,
×
612
                        tuple(_zero if s == _zero else _one for s in w.seq))
613

614
    def mask(self, edge=0):
6✔
615
        w = self.marker
×
616
        in_wave = w.seq[0] == _zero
×
617
        bounds = []
×
618
        seq = []
×
619

620
        if w.seq[0] == _zero:
×
621
            in_wave = False
×
622
            b = w.bounds[0] - edge
×
623
            bounds.append(b)
×
624
            seq.append(_zero)
×
625

626
        for b, s in zip(w.bounds[1:], w.seq[1:]):
×
627
            if not in_wave and s != _zero:
×
628
                in_wave = True
×
629
                bounds.append(b + edge)
×
630
                seq.append(_one)
×
631
            elif in_wave and s == _zero:
×
632
                in_wave = False
×
633
                b = b - edge
×
634
                if b > bounds[-1]:
×
635
                    bounds.append(b)
×
636
                    seq.append(_zero)
×
637
                else:
638
                    bounds.pop()
×
639
                    bounds.append(b)
×
640
        return Waveform(tuple(bounds), tuple(seq))
×
641

642
    def __mul__(self, other):
6✔
643
        if isinstance(other, Waveform):
6✔
644
            return self._comb(other, _mul)
6✔
645
        else:
646
            return self * const(other)
6✔
647

648
    def __rmul__(self, v):
6✔
649
        return const(v) * self
6✔
650

651
    def __truediv__(self, other):
6✔
652
        if isinstance(other, Waveform):
6✔
653
            raise TypeError('division by waveform')
×
654
        else:
655
            return self * const(1 / other)
6✔
656

657
    def __neg__(self):
6✔
658
        return -1 * self
6✔
659

660
    def __sub__(self, other):
6✔
661
        return self + (-other)
6✔
662

663
    def __rsub__(self, v):
6✔
664
        return v + (-self)
×
665

666
    def __rshift__(self, time):
6✔
667
        return Waveform(
6✔
668
            tuple(round(bound + time, NDIGITS) for bound in self.bounds),
669
            tuple(_shift(expr, time) for expr in self.seq))
670

671
    def __lshift__(self, time):
6✔
672
        return self >> (-time)
6✔
673

674
    @staticmethod
6✔
675
    def _merge_parts(
6✔
676
        parts: list[tuple[int, int, np.ndarray | int | float | complex]],
677
        out: list[tuple[int, int, np.ndarray | int | float | complex]]
678
    ) -> list[tuple[int, int, np.ndarray | int | float | complex]]:
679
        # TODO: merge parts
680
        raise NotImplementedError
×
681

682
    @staticmethod
6✔
683
    def _fill_parts(parts, out):
6✔
684
        for start, stop, part in parts:
6✔
685
            out[start:stop] += part
6✔
686

687
    def __call__(self,
6✔
688
                 x,
689
                 frag=False,
690
                 out=None,
691
                 accumulate=False,
692
                 function_lib=None):
693
        if function_lib is None:
6✔
694
            function_lib = _baseFunc
6✔
695
        if isinstance(x, (int, float, complex)):
6✔
696
            return self.__call__(np.array([x]), function_lib=function_lib)[0]
×
697
        parts, dtype = _calc_parts(self.bounds, self.seq, x, function_lib,
6✔
698
                                   self.min, self.max)
699
        if not frag:
6✔
700
            if out is None:
6✔
701
                out = np.zeros_like(x, dtype=dtype)
6✔
702
            elif not accumulate:
×
703
                out *= 0
×
704
            self._fill_parts(parts, out)
6✔
705
        else:
706
            if out is None:
×
707
                return parts
×
708
            else:
709
                if not accumulate:
×
710
                    out.clear()
×
711
                    out.extend(parts)
×
712
                else:
713
                    self._merge_parts(parts, out)
×
714
        return out
6✔
715

716
    def __hash__(self):
6✔
717
        return hash((self.max, self.min, self.start, self.stop,
×
718
                     self.sample_rate, self.bounds, self.seq))
719

720
    def __eq__(self, o: object) -> bool:
6✔
721
        if isinstance(o, (int, float, complex)):
6✔
722
            return self == const(o)
×
723
        elif isinstance(o, Waveform):
6✔
724
            a = self.simplify()
6✔
725
            b = o.simplify()
6✔
726
            return a.seq == b.seq and a.bounds == b.bounds and (
6✔
727
                a.max, a.min, a.start, a.stop) == (b.max, b.min, b.start,
728
                                                   b.stop)
729
        else:
730
            return False
×
731

732
    def _repr_latex_(self):
6✔
733
        parts = []
×
734
        start = -np.inf
×
735
        for end, wav in zip(self.bounds, self.seq):
×
736
            e_str = _wav_latex(wav)
×
737
            start_str = _num_latex(start)
×
738
            end_str = _num_latex(end)
×
739
            parts.append(e_str + r",~~&t\in" + f"({start_str},{end_str}" +
×
740
                         (']' if end < np.inf else ')'))
741
            start = end
×
742
        if len(parts) == 1:
×
743
            expr = ''.join(['f(t)=', *parts[0].split('&')])
×
744
        else:
745
            expr = '\n'.join([
×
746
                r"f(t)=\begin{cases}", (r"\\" + '\n').join(parts),
747
                r"\end{cases}"
748
            ])
749
        return "$$\n{}\n$$".format(expr)
×
750

751
    def _play(self, time_unit, volume=1.0):
6✔
752
        import pyaudio
×
753

754
        CHUNK = 1024
×
755
        RATE = 48000
×
756

757
        dynamic_volume = 1.0
×
758
        amp = 2**15 * 0.999 * volume * dynamic_volume
×
759

760
        p = pyaudio.PyAudio()
×
761
        try:
×
762
            stream = p.open(format=pyaudio.paInt16,
×
763
                            channels=1,
764
                            rate=RATE,
765
                            output=True)
766
            try:
×
767
                for data in self.sample(sample_rate=RATE / time_unit,
×
768
                                        chunk_size=CHUNK):
769
                    lim = np.abs(data).max()
×
770
                    if lim > 0 and dynamic_volume > 1.0 / lim:
×
771
                        dynamic_volume = 1.0 / lim
×
772
                        amp = 2**15 * 0.99 * volume * dynamic_volume
×
773
                    data = (amp * data).astype(np.int16)
×
774
                    stream.write(bytes(data.data))
×
775
            finally:
776
                stream.stop_stream()
×
777
                stream.close()
×
778
        finally:
779
            p.terminate()
×
780

781
    def play(self, time_unit=1, volume=1.0):
6✔
782
        import multiprocessing as mp
×
783
        p = mp.Process(target=self._play,
×
784
                       args=(time_unit, volume),
785
                       daemon=True)
786
        p.start()
×
787

788

789
class WaveVStack(Waveform):
6✔
790

791
    def __init__(self, wlist: list[Waveform] = []):
6✔
792
        self.wlist = [(w.bounds, w.seq) for w in wlist]
6✔
793
        self.start = None
6✔
794
        self.stop = None
6✔
795
        self.sample_rate = None
6✔
796
        self.offset = 0
6✔
797
        self.shift = 0
6✔
798
        self.filters = None
6✔
799
        self.label = None
6✔
800
        self.function_lib = None
6✔
801

802
    def __call__(self, x, frag=False, out=None, function_lib=None):
6✔
803
        assert frag is False, 'WaveVStack does not support frag mode'
6✔
804
        out = np.full_like(x, self.offset, dtype=complex)
6✔
805
        if self.shift != 0:
6✔
806
            x = x - self.shift
6✔
807
        if function_lib is None:
6✔
808
            if self.function_lib is None:
6✔
809
                function_lib = _baseFunc
6✔
810
            else:
811
                function_lib = self.function_lib
×
812
        for bounds, seq in self.wlist:
6✔
813
            parts, dtype = _calc_parts(bounds, seq, x, function_lib)
6✔
814
            self._fill_parts(parts, out)
6✔
815
        return out.real
6✔
816

817
    def tolist(self):
6✔
818
        l = [
6✔
819
            self.start,
820
            self.stop,
821
            self.offset,
822
            self.shift,
823
            self.sample_rate,
824
        ]
825
        if self.filters is None:
6✔
826
            l.append(None)
6✔
827
        else:
828
            sos, initial = self.filters
6✔
829
            sos = list(sos.reshape(-1))
6✔
830
            l.append(len(sos))
6✔
831
            l.extend(sos)
6✔
832
            l.append(initial)
6✔
833
        l.append(len(self.wlist))
6✔
834
        for bounds, seq in self.wlist:
6✔
835
            self._tolist(bounds, seq, l)
6✔
836
        return l
6✔
837

838
    @classmethod
6✔
839
    def fromlist(cls, l):
6✔
840
        w = cls()
6✔
841
        pos = 6
6✔
842
        w.start, w.stop, w.offset, w.shift, w.sample_rate, sos_size = l[:pos]
6✔
843
        if sos_size is not None:
6✔
844
            sos = np.array(l[pos:pos + sos_size]).reshape(-1, 6)
6✔
845
            pos += sos_size
6✔
846
            initial = l[pos]
6✔
847
            pos += 1
6✔
848
            w.filters = sos, initial
6✔
849
        n = l[pos]
6✔
850
        pos += 1
6✔
851
        for _ in range(n):
6✔
852
            bounds, seq, pos = cls._fromlist(l, pos)
6✔
853
            w.wlist.append((bounds, seq))
6✔
854
        return w
6✔
855

856
    def simplify(self, eps=1e-15):
6✔
857
        if not self.wlist:
6✔
858
            return zero()
6✔
859
        bounds, seq = wave_sum(self.wlist)
6✔
860
        wav = Waveform(bounds=bounds, seq=seq)
6✔
861
        if self.offset != 0:
6✔
862
            wav += self.offset
×
863
        if self.shift != 0:
6✔
864
            wav >>= self.shift
×
865
        wav = wav.simplify(eps)
6✔
866
        wav.start = self.start
6✔
867
        wav.stop = self.stop
6✔
868
        wav.sample_rate = self.sample_rate
6✔
869
        return wav
6✔
870

871
    @staticmethod
6✔
872
    def _rshift(wlist, time):
6✔
873
        if time == 0:
×
874
            return wlist
×
875
        return [(tuple(round(bound + time, NDIGITS) for bound in bounds),
×
876
                 tuple(_shift(expr, time) for expr in seq))
877
                for bounds, seq in wlist]
878

879
    def __rshift__(self, time):
6✔
880
        ret = WaveVStack()
6✔
881
        ret.wlist = self.wlist
6✔
882
        ret.sample_rate = self.sample_rate
6✔
883
        ret.start = self.start
6✔
884
        ret.stop = self.stop
6✔
885
        ret.shift = self.shift + time
6✔
886
        ret.offset = self.offset
6✔
887
        return ret
6✔
888

889
    def __add__(self, other):
6✔
890
        ret = WaveVStack()
6✔
891
        ret.wlist.extend(self.wlist)
6✔
892
        if isinstance(other, WaveVStack):
6✔
893
            if other.shift != self.shift:
×
894
                ret.wlist = self._rshift(ret.wlist, self.shift)
×
895
                ret.wlist.extend(self._rshift(other.wlist, other.shift))
×
896
            else:
897
                ret.wlist.extend(other.wlist)
×
898
            ret.offset = self.offset + other.offset
×
899
        elif isinstance(other, Waveform):
6✔
900
            other <<= self.shift
6✔
901
            ret.wlist.append((other.bounds, other.seq))
6✔
902
        else:
903
            # ret.wlist.append(((+inf, ), (_const(1.0 * other), )))
904
            ret.offset += other
6✔
905
        return ret
6✔
906

907
    def __radd__(self, v):
6✔
908
        return self + v
×
909

910
    def __mul__(self, other):
6✔
911
        if isinstance(other, Waveform):
6✔
912
            other = other.simplify() << self.shift
6✔
913
            ret = WaveVStack([Waveform(*w) * other for w in self.wlist])
6✔
914
            if self.offset != 0:
6✔
915
                w = other * self.offset
×
916
                ret.wlist.append((w.bounds, w.seq))
×
917
            return ret
6✔
918
        else:
919
            ret = WaveVStack([Waveform(*w) * other for w in self.wlist])
×
920
            ret.offset = self.offset * other
×
921
            return ret
×
922

923
    def __rmul__(self, v):
6✔
924
        return self * v
×
925

926
    def __eq__(self, other):
6✔
927
        if self.wlist:
×
928
            return False
×
929
        else:
930
            return zero() == other
×
931

932
    def _repr_latex_(self):
6✔
933
        return r"\sum_{i=1}^{" + f"{len(self.wlist)}" + r"}" + r"f_i(t)"
×
934

935
    def __getstate__(self) -> tuple:
6✔
936
        function_lib = self.function_lib
6✔
937
        if function_lib:
6✔
938
            try:
×
939
                import dill
×
940
                function_lib = dill.dumps(function_lib)
×
941
            except:
×
942
                function_lib = None
×
943
        return (self.wlist, self.start, self.stop, self.sample_rate,
6✔
944
                self.offset, self.shift, self.filters, self.label,
945
                function_lib)
946

947
    def __setstate__(self, state: tuple) -> None:
6✔
948
        (self.wlist, self.start, self.stop, self.sample_rate, self.offset,
6✔
949
         self.shift, self.filters, self.label, function_lib) = state
950
        if function_lib:
6✔
951
            try:
×
952
                import dill
×
953
                function_lib = dill.loads(function_lib)
×
954
            except:
×
955
                function_lib = None
×
956
        self.function_lib = function_lib
6✔
957

958

959
def play(data, rate=48000):
6✔
960
    import io
×
961

962
    import pyaudio
×
963

964
    CHUNK = 1024
×
965

966
    max_amp = np.max(np.abs(data))
×
967

968
    if max_amp > 1:
×
969
        data /= max_amp
×
970

971
    data = np.array(2**15 * 0.999 * data, dtype=np.int16)
×
972
    buff = io.BytesIO(data.data)
×
973
    p = pyaudio.PyAudio()
×
974

975
    try:
×
976
        stream = p.open(format=pyaudio.paInt16,
×
977
                        channels=1,
978
                        rate=rate,
979
                        output=True)
980
        try:
×
981
            while True:
×
982
                data = buff.read(CHUNK)
×
983
                if data:
×
984
                    stream.write(data)
×
985
                else:
986
                    break
×
987
        finally:
988
            stream.stop_stream()
×
989
            stream.close()
×
990
    finally:
991
        p.terminate()
×
992

993

994
_zero_waveform = Waveform()
6✔
995
_one_waveform = Waveform(seq=(_one, ))
6✔
996

997

998
def zero():
6✔
999
    return _zero_waveform
6✔
1000

1001

1002
def one():
6✔
1003
    return _one_waveform
×
1004

1005

1006
def const(c):
6✔
1007
    return Waveform(seq=(_const(1.0 * c), ))
6✔
1008

1009

1010
# register base function
1011
def _format_LINEAR(shift, *args):
6✔
1012
    if shift != 0:
×
1013
        shift = _num_latex(-shift)
×
1014
        if shift[0] == '-':
×
1015
            return f"(t{shift})"
×
1016
        else:
1017
            return f"(t+{shift})"
×
1018
    else:
1019
        return 't'
×
1020

1021

1022
def _format_GAUSSIAN(shift, *args):
6✔
1023
    sigma = _num_latex(args[0] / np.sqrt(2))
×
1024
    shift = _num_latex(-shift)
×
1025
    if shift != '0':
×
1026
        if shift[0] != '-':
×
1027
            shift = '+' + shift
×
1028
        if sigma == '1':
×
1029
            return ('\\exp\\left[-\\frac{\\left(t' + shift +
×
1030
                    '\\right)^2}{2}\\right]')
1031
        else:
1032
            return ('\\exp\\left[-\\frac{1}{2}\\left(\\frac{t' + shift + '}{' +
×
1033
                    sigma + '}\\right)^2\\right]')
1034
    else:
1035
        if sigma == '1':
×
1036
            return ('\\exp\\left(-\\frac{t^2}{2}\\right)')
×
1037
        else:
1038
            return ('\\exp\\left[-\\frac{1}{2}\\left(\\frac{t}{' + sigma +
×
1039
                    '}\\right)^2\\right]')
1040

1041

1042
def _format_SINC(shift, *args):
6✔
1043
    shift = _num_latex(-shift)
×
1044
    bw = _num_latex(args[0])
×
1045
    if shift != '0':
×
1046
        if shift[0] != '-':
×
1047
            shift = '+' + shift
×
1048
        if bw == '1':
×
1049
            return '\\mathrm{sinc}(t' + shift + ')'
×
1050
        else:
1051
            return '\\mathrm{sinc}[' + bw + '(t' + shift + ')]'
×
1052
    else:
1053
        if bw == '1':
×
1054
            return '\\mathrm{sinc}(t)'
×
1055
        else:
1056
            return '\\mathrm{sinc}(' + bw + 't)'
×
1057

1058

1059
def _format_COSINE(shift, *args):
6✔
1060
    freq = args[0] / 2 / np.pi
×
1061
    phase = -shift * freq
×
1062
    freq = _num_latex(freq)
×
1063
    if freq == '1':
×
1064
        freq = ''
×
1065
    phase = _num_latex(phase)
×
1066
    if phase == '0':
×
1067
        phase = ''
×
1068
    elif phase[0] != '-':
×
1069
        phase = '+' + phase
×
1070
    if phase != '':
×
1071
        return f'\\cos\\left[2\\pi\\left({freq}t{phase}\\right)\\right]'
×
1072
    elif freq != '':
×
1073
        return f'\\cos\\left(2\\pi\\times {freq}t\\right)'
×
1074
    else:
1075
        return '\\cos\\left(2\\pi t\\right)'
×
1076

1077

1078
def _format_ERF(shift, *args):
6✔
1079
    if shift > 0:
×
1080
        return '\\mathrm{erf}(\\frac{t-' + f"{_num_latex(shift)}" + '}{' + f'{args[0]:g}' + '})'
×
1081
    elif shift < 0:
×
1082
        return '\\mathrm{erf}(\\frac{t+' + f"{_num_latex(-shift)}" + '}{' + f'{args[0]:g}' + '})'
×
1083
    else:
1084
        return '\\mathrm{erf}(\\frac{t}{' + f'{args[0]:g}' + '})'
×
1085

1086

1087
def _format_COSH(shift, *args):
6✔
1088
    if shift > 0:
×
1089
        return '\\cosh(\\frac{t-' + f"{_num_latex(shift)}" + '}{' + f'{1/args[0]:g}' + '})'
×
1090
    elif shift < 0:
×
1091
        return '\\cosh(\\frac{t+' + f"{_num_latex(-shift)}" + '}{' + f'{1/args[0]:g}' + '})'
×
1092
    else:
1093
        return '\\cosh(\\frac{t}{' + f'{1/args[0]:g}' + '})'
×
1094

1095

1096
def _format_SINH(shift, *args):
6✔
1097
    if shift > 0:
×
1098
        return '\\sinh(\\frac{t-' + f"{_num_latex(shift)}" + '}{' + f'{args[0]:g}' + '})'
×
1099
    elif shift < 0:
×
1100
        return '\\sinh(\\frac{t+' + f"{_num_latex(-shift)}" + '}{' + f'{args[0]:g}' + '})'
×
1101
    else:
1102
        return '\\sinh(\\frac{t}{' + f'{args[0]:g}' + '})'
×
1103

1104

1105
def _format_EXP(shift, *args):
6✔
1106
    if _num_latex(shift) and shift > 0:
×
1107
        return '\\exp\\left(-' + f'{args[0]:g}' + '\\left(t-' + f"{_num_latex(shift)}" + '\\right)\\right)'
×
1108
    elif _num_latex(-shift) and shift < 0:
×
1109
        return '\\exp\\left(-' + f'{args[0]:g}' + '\\left(t+' + f"{_num_latex(-shift)}" + '\\right)\\right)'
×
1110
    else:
1111
        return '\\exp\\left(-' + f'{args[0]:g}' + 't\\right)'
×
1112

1113

1114
LINEAR = registerBaseFunc(lambda t: t)
6✔
1115
GAUSSIAN = registerBaseFunc(lambda t, std_sq2: np.exp(-(t / std_sq2)**2))
6✔
1116
ERF = registerBaseFunc(lambda t, std_sq2: special.erf(t / std_sq2))
6✔
1117
COS = registerBaseFunc(lambda t, w: np.cos(w * t))
6✔
1118
SINC = registerBaseFunc(lambda t, bw: np.sinc(bw * t))
6✔
1119
EXP = registerBaseFunc(lambda t, alpha: np.exp(alpha * t))
6✔
1120
INTERP = registerBaseFunc(lambda t, start, stop, points: np.interp(
6✔
1121
    t, np.linspace(start, stop, len(points)), points))
1122
LINEARCHIRP = registerBaseFunc(lambda t, f0, f1, T, phi0: np.sin(
6✔
1123
    phi0 + 2 * np.pi * ((f1 - f0) / (2 * T) * t**2 + f0 * t)))
1124
EXPONENTIALCHIRP = registerBaseFunc(lambda t, f0, alpha, phi0: np.sin(
6✔
1125
    phi0 + 2 * pi * f0 * (np.exp(alpha * t) - 1) / alpha))
1126
HYPERBOLICCHIRP = registerBaseFunc(lambda t, f0, k, phi0: np.sin(
6✔
1127
    phi0 + 2 * np.pi * f0 / k * np.log(1 + k * t)))
1128
COSH = registerBaseFunc(lambda t, w: np.cosh(w * t))
6✔
1129
SINH = registerBaseFunc(lambda t, w: np.sinh(w * t))
6✔
1130

1131
registerBaseFuncLatex(LINEAR, _format_LINEAR)
6✔
1132
registerBaseFuncLatex(GAUSSIAN, _format_GAUSSIAN)
6✔
1133
registerBaseFuncLatex(ERF, _format_ERF)
6✔
1134
registerBaseFuncLatex(COS, _format_COSINE)
6✔
1135
registerBaseFuncLatex(SINC, _format_SINC)
6✔
1136
registerBaseFuncLatex(EXP, _format_EXP)
6✔
1137
registerBaseFuncLatex(COSH, _format_COSH)
6✔
1138
registerBaseFuncLatex(SINH, _format_SINH)
6✔
1139

1140

1141
def _drag(t: np.ndarray, t0: float, freq: float, width: float, delta: float,
6✔
1142
          block_freq: float, phase: float):
1143

1144
    o = np.pi / width
6✔
1145
    Omega_x = np.sin(o * (t - t0))**2
6✔
1146
    wt = 2 * np.pi * (freq + delta) * t - (2 * np.pi * delta * t0 + phase)
6✔
1147

1148
    if block_freq is None or block_freq - delta == 0:
6✔
1149
        return Omega_x * np.cos(wt)
×
1150

1151
    b = 1 / np.pi / 2 / (block_freq - delta)
6✔
1152
    Omega_y = -b * o * np.sin(2 * o * (t - t0))
6✔
1153

1154
    return Omega_x * np.cos(wt) + Omega_y * np.sin(wt)
6✔
1155

1156

1157
def _format_DRAG(shift, *args):
6✔
1158
    return f"DRAG(...)"
×
1159

1160

1161
DRAG = registerBaseFunc(_drag)
6✔
1162
registerBaseFuncLatex(DRAG, _format_DRAG)
6✔
1163

1164
# register derivative
1165
registerDerivative(LINEAR, lambda shift, *args: _one)
6✔
1166

1167
registerDerivative(
6✔
1168
    GAUSSIAN, lambda shift, *args: (((((LINEAR, shift),
1169
                                       (GAUSSIAN, *args, shift)), (1, 1)), ),
1170
                                    (-2 / args[0]**2, )))
1171

1172
registerDerivative(
6✔
1173
    ERF, lambda shift, *args: (((((GAUSSIAN, *args, shift), ), (1, )), ),
1174
                               (2 / args[0] / np.sqrt(pi), )))
1175

1176
registerDerivative(
6✔
1177
    COS, lambda shift, *args: (((((COS, args[0], shift - pi / args[0] / 2), ),
1178
                                 (1, )), ), (args[0], )))
1179

1180
registerDerivative(
6✔
1181
    SINC, lambda shift, *args:
1182
    (((((LINEAR, shift), (COS, *args, shift)), (-1, 1)),
1183
      (((LINEAR, shift), (COS, args[0], args[1] - pi / 2, shift)), (-2, 1))),
1184
     (1, -1 / args[0])))
1185

1186
registerDerivative(
6✔
1187
    EXP, lambda shift, *args: (((((EXP, *args, shift), ), (1, )), ),
1188
                               (args[0], )))
1189

1190
registerDerivative(
6✔
1191
    INTERP, lambda shift, start, stop, points:
1192
    (((((INTERP, start, stop, tuple(np.gradient(np.asarray(points))), shift),
1193
        ), (1, )), ), ((len(points) - 1) / (stop - start), )))
1194

1195
registerDerivative(
6✔
1196
    COSH, lambda shift, *args: (((((SINH, *args, shift), ), (1, )), ),
1197
                                (args[0], )))
1198

1199
registerDerivative(
6✔
1200
    SINH, lambda shift, *args: (((((COSH, *args, shift), ), (1, )), ),
1201
                                (args[0], )))
1202

1203

1204
def _d_LINEARCHIRP(shift, f0, f1, T, phi0):
6✔
1205
    tlist = (
×
1206
        (((LINEARCHIRP, f0, f1, T, phi0 + pi / 2, shift), ), (1, )),
1207
        (((LINEAR, shift), (LINEARCHIRP, f0, f1, T, phi0 + pi / 2, shift)),
1208
         (1, 1)),
1209
    )
1210
    alist = (2 * pi * f0, 2 * pi * (f1 - f0) / T)
×
1211

1212
    if f0 == 0:
×
1213
        return tlist[1:], alist[1:]
×
1214
    else:
1215
        return tlist, alist
×
1216

1217

1218
registerDerivative(LINEARCHIRP, _d_LINEARCHIRP)
6✔
1219
registerDerivative(
6✔
1220
    EXPONENTIALCHIRP, lambda shift, f0, alpha, phi0:
1221
    (((((EXP, alpha, shift),
1222
        (EXPONENTIALCHIRP, f0, alpha, phi0 + pi / 2, shift)), (1, 1)), ),
1223
     (2 * pi * f0, )))
1224
registerDerivative(
6✔
1225
    HYPERBOLICCHIRP, lambda shift, f0, k, phi0:
1226
    (((((LINEAR, shift - 1 / k),
1227
        (HYPERBOLICCHIRP, f0, k, phi0 + pi / 2, shift)), (-1, 1)), ),
1228
     (2 * pi * f0, )))
1229

1230

1231
def D(wav):
6✔
1232
    """derivative
1233
    """
1234
    return Waveform(bounds=wav.bounds, seq=tuple(_D(x) for x in wav.seq))
×
1235

1236

1237
def convolve(a, b):
6✔
1238
    pass
×
1239

1240

1241
def sign():
6✔
1242
    return Waveform(bounds=(0, +inf), seq=(_const(-1), _one))
×
1243

1244

1245
def step(edge, type='erf'):
6✔
1246
    """
1247
    type: "erf", "cos", "linear"
1248
    """
1249
    if edge == 0:
6✔
1250
        return Waveform(bounds=(0, +inf), seq=(_zero, _one))
6✔
1251
    if type == 'cos':
6✔
1252
        rise = _add(_half,
×
1253
                    _mul(_half, _basic_wave(COS, pi / edge, shift=0.5 * edge)))
1254
        return Waveform(bounds=(round(-edge / 2,
×
1255
                                      NDIGITS), round(edge / 2,
1256
                                                      NDIGITS), +inf),
1257
                        seq=(_zero, rise, _one))
1258
    elif type == 'linear':
6✔
1259
        rise = _add(_half, _mul(_const(1 / edge), _basic_wave(LINEAR)))
6✔
1260
        return Waveform(bounds=(round(-edge / 2,
6✔
1261
                                      NDIGITS), round(edge / 2,
1262
                                                      NDIGITS), +inf),
1263
                        seq=(_zero, rise, _one))
1264
    else:
1265
        std_sq2 = edge / 5
6✔
1266
        # rise = _add(_half, _mul(_half, _basic_wave(ERF, std_sq2)))
1267
        rise = ((((), ()), (((ERF, std_sq2, 0), ), (1, ))), (0.5, 0.5))
6✔
1268
        return Waveform(bounds=(-round(edge, NDIGITS), round(edge,
6✔
1269
                                                             NDIGITS), +inf),
1270
                        seq=(_zero, rise, _one))
1271

1272

1273
def square(width, edge=0, type='erf'):
6✔
1274
    if width <= 0:
6✔
1275
        return zero()
×
1276
    if edge == 0:
6✔
1277
        return Waveform(bounds=(round(-0.5 * width,
6✔
1278
                                      NDIGITS), round(0.5 * width,
1279
                                                      NDIGITS), +inf),
1280
                        seq=(_zero, _one, _zero))
1281
    else:
1282
        return ((step(edge, type=type) << width / 2) -
6✔
1283
                (step(edge, type=type) >> width / 2))
1284

1285

1286
def gaussian(width, plateau=0.0):
6✔
1287
    if width <= 0 and plateau <= 0.0:
6✔
1288
        return zero()
×
1289
    # width is two times FWHM
1290
    # std_sq2 = width / (4 * np.sqrt(np.log(2)))
1291
    std_sq2 = width / 3.3302184446307908
6✔
1292
    # std is set to give total pulse area same as a square
1293
    # std_sq2 = width/np.sqrt(np.pi)
1294
    if round(0.5 * plateau, NDIGITS) <= 0.0:
6✔
1295
        return Waveform(bounds=(round(-0.75 * width,
6✔
1296
                                      NDIGITS), round(0.75 * width,
1297
                                                      NDIGITS), +inf),
1298
                        seq=(_zero, _basic_wave(GAUSSIAN, std_sq2), _zero))
1299
    else:
1300
        return Waveform(bounds=(round(-0.75 * width - 0.5 * plateau,
×
1301
                                      NDIGITS), round(-0.5 * plateau, NDIGITS),
1302
                                round(0.5 * plateau, NDIGITS),
1303
                                round(0.75 * width + 0.5 * plateau,
1304
                                      NDIGITS), +inf),
1305
                        seq=(_zero,
1306
                             _basic_wave(GAUSSIAN,
1307
                                         std_sq2,
1308
                                         shift=-0.5 * plateau), _one,
1309
                             _basic_wave(GAUSSIAN,
1310
                                         std_sq2,
1311
                                         shift=0.5 * plateau), _zero))
1312

1313

1314
def cos(w, phi=0):
6✔
1315
    if w == 0:
6✔
1316
        return const(np.cos(phi))
×
1317
    if w < 0:
6✔
1318
        phi = -phi
6✔
1319
        w = -w
6✔
1320
    return Waveform(seq=(_basic_wave(COS, w, shift=-phi / w), ))
6✔
1321

1322

1323
def sin(w, phi=0):
6✔
1324
    if w == 0:
6✔
1325
        return const(np.sin(phi))
×
1326
    if w < 0:
6✔
1327
        phi = -phi + pi
6✔
1328
        w = -w
6✔
1329
    return Waveform(seq=(_basic_wave(COS, w, shift=(pi / 2 - phi) / w), ))
6✔
1330

1331

1332
def exp(alpha):
6✔
1333
    if isinstance(alpha, complex):
6✔
1334
        if alpha.real == 0:
6✔
1335
            return cos(alpha.imag) + 1j * sin(alpha.imag)
×
1336
        else:
1337
            return exp(alpha.real) * (cos(alpha.imag) + 1j * sin(alpha.imag))
6✔
1338
    else:
1339
        return Waveform(seq=(_basic_wave(EXP, alpha), ))
6✔
1340

1341

1342
def sinc(bw):
6✔
1343
    if bw <= 0:
×
1344
        return zero()
×
1345
    width = 100 / bw
×
1346
    return Waveform(bounds=(round(-0.5 * width,
×
1347
                                  NDIGITS), round(0.5 * width, NDIGITS), +inf),
1348
                    seq=(_zero, _basic_wave(SINC, bw), _zero))
1349

1350

1351
def cosPulse(width, plateau=0.0):
6✔
1352
    # cos = _basic_wave(COS, 2*np.pi/width)
1353
    # pulse = _mul(_add(cos, _one), _half)
1354
    if round(0.5 * plateau, NDIGITS) > 0:
6✔
1355
        return square(plateau + 0.5 * width, edge=0.5 * width, type='cos')
×
1356
    if width <= 0:
6✔
1357
        return zero()
×
1358
    pulse = ((((), ()), (((COS, 6.283185307179586 / width, 0), ), (1, ))),
6✔
1359
             (0.5, 0.5))
1360
    return Waveform(bounds=(round(-0.5 * width,
6✔
1361
                                  NDIGITS), round(0.5 * width, NDIGITS), +inf),
1362
                    seq=(_zero, pulse, _zero))
1363

1364

1365
def hanning(width, plateau=0.0):
6✔
1366
    return cosPulse(width, plateau=plateau)
×
1367

1368

1369
def cosh(w):
6✔
1370
    return Waveform(seq=(_basic_wave(COSH, w), ))
×
1371

1372

1373
def sinh(w):
6✔
1374
    return Waveform(seq=(_basic_wave(SINH, w), ))
×
1375

1376

1377
def coshPulse(width, eps=1.0, plateau=0.0):
6✔
1378
    """Cosine hyperbolic pulse with the following im
1379

1380
    pulse edge shape:
1381
            cosh(eps / 2) - cosh(eps * t / T)
1382
    f(t) = -----------------------------------
1383
                  cosh(eps / 2) - 1
1384
    where T is the pulse width and eps is the pulse edge steepness.
1385
    The pulse is defined for t in [-T/2, T/2].
1386

1387
    In case of plateau > 0, the pulse is defined as:
1388
           | f(t + plateau/2)   if t in [-T/2 - plateau/2, - plateau/2]
1389
    g(t) = | 1                  if t in [-plateau/2, plateau/2]
1390
           | f(t - plateau/2)   if t in [plateau/2, T/2 + plateau/2]
1391

1392
    Parameters
1393
    ----------
1394
    width : float
1395
        Pulse width.
1396
    eps : float
1397
        Pulse edge steepness.
1398
    plateau : float
1399
        Pulse plateau.
1400
    """
1401
    if width <= 0 and plateau <= 0:
×
1402
        return zero()
×
1403
    w = eps / width
×
1404
    A = np.cosh(eps / 2)
×
1405

1406
    if plateau == 0.0 or round(-0.5 * plateau, NDIGITS) == round(
×
1407
            0.5 * plateau, NDIGITS):
1408
        pulse = ((((), ()), (((COSH, w, 0), ), (1, ))), (A / (A - 1),
×
1409
                                                         -1 / (A - 1)))
1410
        return Waveform(bounds=(round(-0.5 * width,
×
1411
                                      NDIGITS), round(0.5 * width,
1412
                                                      NDIGITS), +inf),
1413
                        seq=(_zero, pulse, _zero))
1414
    else:
1415
        raising = ((((), ()), (((COSH, w, -0.5 * plateau), ), (1, ))),
×
1416
                   (A / (A - 1), -1 / (A - 1)))
1417
        falling = ((((), ()), (((COSH, w, 0.5 * plateau), ), (1, ))),
×
1418
                   (A / (A - 1), -1 / (A - 1)))
1419
        return Waveform(bounds=(round(-0.5 * width - 0.5 * plateau,
×
1420
                                      NDIGITS), round(-0.5 * plateau, NDIGITS),
1421
                                round(0.5 * plateau, NDIGITS),
1422
                                round(0.5 * width + 0.5 * plateau,
1423
                                      NDIGITS), +inf),
1424
                        seq=(_zero, raising, _one, falling, _zero))
1425

1426

1427
def general_cosine(duration, *arg):
6✔
1428
    wav = zero()
×
1429
    arg = np.asarray(arg)
×
1430
    arg /= arg[::2].sum()
×
1431
    for i, a in enumerate(arg, start=1):
×
1432
        wav += a / 2 * (1 - (-1)**i * cos(i * 2 * pi / duration))
×
1433
    return wav * square(duration)
×
1434

1435

1436
def slepian(duration, *arg):
6✔
1437
    wav = zero()
×
1438
    arg = np.asarray(arg)
×
1439
    arg /= arg[::2].sum()
×
1440
    for i, a in enumerate(arg, start=1):
×
1441
        wav += a / 2 * (1 - (-1)**i * cos(i * 2 * pi / duration))
×
1442
    return wav * square(duration)
×
1443

1444

1445
def _poly(*a):
6✔
1446
    """
1447
    a[0] + a[1] * t + a[2] * t**2 + ...
1448
    """
1449
    t = []
6✔
1450
    amp = []
6✔
1451
    if a[0] != 0:
6✔
1452
        t.append(((), ()))
6✔
1453
        amp.append(a[0])
6✔
1454
    for n, a_ in enumerate(a[1:], start=1):
6✔
1455
        if a_ != 0:
6✔
1456
            t.append((((LINEAR, 0), ), (n, )))
6✔
1457
            amp.append(a_)
6✔
1458
    return tuple(t), tuple(a)
6✔
1459

1460

1461
def poly(a):
6✔
1462
    """
1463
    a[0] + a[1] * t + a[2] * t**2 + ...
1464
    """
1465
    return Waveform(seq=(_poly(*a), ))
6✔
1466

1467

1468
def t():
6✔
1469
    return Waveform(seq=((((LINEAR, 0), ), (1, )), (1, )))
×
1470

1471

1472
def drag(freq, width, plateau=0, delta=0, block_freq=None, phase=0, t0=0):
6✔
1473
    phase += pi * delta * (width + plateau)
6✔
1474
    if plateau <= 0:
6✔
1475
        return Waveform(seq=(_zero,
6✔
1476
                             _basic_wave(DRAG, t0, freq, width, delta,
1477
                                         block_freq, phase), _zero),
1478
                        bounds=(round(t0, NDIGITS), round(t0 + width,
1479
                                                          NDIGITS), +inf))
1480
    elif width <= 0:
×
1481
        w = 2 * pi * (freq + delta)
×
1482
        return Waveform(
×
1483
            seq=(_zero,
1484
                 _basic_wave(COS, w,
1485
                             shift=(phase + 2 * pi * delta * t0) / w), _zero),
1486
            bounds=(round(t0, NDIGITS), round(t0 + plateau, NDIGITS), +inf))
1487
    else:
1488
        w = 2 * pi * (freq + delta)
×
1489
        return Waveform(
×
1490
            seq=(_zero,
1491
                 _basic_wave(DRAG, t0, freq, width, delta, block_freq, phase),
1492
                 _basic_wave(COS, w, shift=(phase + 2 * pi * delta * t0) / w),
1493
                 _basic_wave(DRAG, t0 + plateau, freq, width, delta,
1494
                             block_freq,
1495
                             phase - 2 * pi * delta * plateau), _zero),
1496
            bounds=(round(t0, NDIGITS), round(t0 + width / 2, NDIGITS),
1497
                    round(t0 + width / 2 + plateau,
1498
                          NDIGITS), round(t0 + width + plateau,
1499
                                          NDIGITS), +inf))
1500

1501

1502
def chirp(f0, f1, T, phi0=0, type='linear'):
6✔
1503
    """
1504
    A chirp is a signal in which the frequency increases (up-chirp)
1505
    or decreases (down-chirp) with time. In some sources, the term
1506
    chirp is used interchangeably with sweep signal.
1507

1508
    type: "linear", "exponential", "hyperbolic"
1509
    """
1510
    if f0 == f1:
6✔
1511
        return sin(f0, phi0)
×
1512
    if T <= 0:
6✔
1513
        raise ValueError('T must be positive')
×
1514

1515
    if type == 'linear':
6✔
1516
        # f(t) = f1 * (t/T) + f0 * (1 - t/T)
1517
        return Waveform(bounds=(0, round(T, NDIGITS), +inf),
6✔
1518
                        seq=(_zero, _basic_wave(LINEARCHIRP, f0, f1, T,
1519
                                                phi0), _zero))
1520
    elif type in ['exp', 'exponential', 'geometric']:
6✔
1521
        # f(t) = f0 * (f1/f0) ** (t/T)
1522
        if f0 == 0:
6✔
1523
            raise ValueError('f0 must be non-zero')
×
1524
        alpha = np.log(f1 / f0) / T
6✔
1525
        return Waveform(bounds=(0, round(T, NDIGITS), +inf),
6✔
1526
                        seq=(_zero,
1527
                             _basic_wave(EXPONENTIALCHIRP, f0, alpha,
1528
                                         phi0), _zero))
1529
    elif type in ['hyperbolic', 'hyp']:
6✔
1530
        # f(t) = f0 * f1 / (f0 * (t/T) + f1 * (1-t/T))
1531
        if f0 * f1 == 0:
6✔
1532
            return const(np.sin(phi0))
×
1533
        k = (f0 - f1) / (f1 * T)
6✔
1534
        return Waveform(bounds=(0, round(T, NDIGITS), +inf),
6✔
1535
                        seq=(_zero, _basic_wave(HYPERBOLICCHIRP, f0, k,
1536
                                                phi0), _zero))
1537
    else:
1538
        raise ValueError(f'unknown type {type}')
×
1539

1540

1541
def interp(x, y):
6✔
1542
    seq, bounds = [_zero], [x[0]]
×
1543
    for x1, x2, y1, y2 in zip(x[:-1], x[1:], y[:-1], y[1:]):
×
1544
        if x2 == x1:
×
1545
            continue
×
1546
        seq.append(
×
1547
            _add(
1548
                _mul(_const((y2 - y1) / (x2 - x1)),
1549
                     _basic_wave(LINEAR, shift=x1)), _const(y1)))
1550
        bounds.append(x2)
×
1551
    bounds.append(inf)
×
1552
    seq.append(_zero)
×
1553
    return Waveform(seq=tuple(seq),
×
1554
                    bounds=tuple(round(b, NDIGITS)
1555
                                 for b in bounds)).simplify()
1556

1557

1558
def cut(wav, start=None, stop=None, head=None, tail=None, min=None, max=None):
6✔
1559
    offset = 0
×
1560
    if start is not None and head is not None:
×
1561
        offset = head - wav(np.array([1.0 * start]))[0]
×
1562
    elif stop is not None and tail is not None:
×
1563
        offset = tail - wav(np.array([1.0 * stop]))[0]
×
1564
    wav = wav + offset
×
1565

1566
    if start is not None:
×
1567
        wav = wav * (step(0) >> start)
×
1568
    if stop is not None:
×
1569
        wav = wav * ((1 - step(0)) >> stop)
×
1570
    if min is not None:
×
1571
        wav.min = min
×
1572
    if max is not None:
×
1573
        wav.max = max
×
1574
    return wav
×
1575

1576

1577
def function(fun, *args, start=None, stop=None):
6✔
1578
    TYPEID = registerBaseFunc(fun)
×
1579
    seq = (_basic_wave(TYPEID, *args), )
×
1580
    wav = Waveform(seq=seq)
×
1581
    if start is not None:
×
1582
        wav = wav * (step(0) >> start)
×
1583
    if stop is not None:
×
1584
        wav = wav * ((1 - step(0)) >> stop)
×
1585
    return wav
×
1586

1587

1588
def samplingPoints(start, stop, points):
6✔
1589
    return Waveform(bounds=(round(start, NDIGITS), round(stop, NDIGITS), inf),
×
1590
                    seq=(_zero, _basic_wave(INTERP, start, stop,
1591
                                            tuple(points)), _zero))
1592

1593

1594
def mixing(I,
           Q=None,
           *,
           phase=0.0,
           freq=0.0,
           ratioIQ=1.0,
           phaseDiff=0.0,
           block_freq=None,
           DRAGScaling=None):
    """SSB or envelope mixing of an I/Q pair, with optional DRAG correction.

    Args:
        I: In-phase input.
        Q: Quadrature input. Defaults to ``zero()`` when omitted.
        phase: Phase offset in radians. It enters the carriers (SSB) or the
            constant rotation coefficients (envelope) as ``-phase``.
        freq: Mixing frequency. The angular frequency used is
            ``w = 2 * pi * freq``. A nonzero value selects SSB mixing with
            ``cos``/``sin`` carrier waveforms. Zero selects envelope mixing
            with constant ``np.cos``/``np.sin`` coefficients.
        ratioIQ: Factor applied to the Q output as the very last step.
        phaseDiff: Extra phase added to the Q-channel carrier. Only used in
            SSB mode (``freq != 0``); envelope mixing does not apply it.
        block_freq: If given and different from ``freq``, apply
            ``I' = a*I + b/(2*pi)*D(Q)`` and ``Q' = a*Q - b/(2*pi)*D(I)``
            with ``a = block_freq / (block_freq - freq)`` and
            ``b = 1 / (block_freq - freq)``. Here ``D`` is presumably the
            time-derivative operator of this module.
        DRAGScaling: Alternative parametrisation of the same correction.
            It is used only when the ``block_freq`` branch is not taken and
            the value is nonzero. It matches ``block_freq`` when
            ``2 * pi * DRAGScaling * (freq - block_freq) == 1``.

    Returns:
        Tuple ``(Iout, Qout)``.
    """
    if Q is None:
        Q = zero()

    w = 2 * pi * freq
    if freq != 0.0:
        # SSB mixing
        Iout = I * cos(w, -phase) + Q * sin(w, -phase)
        Qout = -I * sin(w, -phase + phaseDiff) + Q * cos(w, -phase + phaseDiff)
    else:
        # envelope mixing: constant rotation by -phase.
        # NOTE(review): phaseDiff is not applied here, unlike the SSB branch
        # above -- confirm this is intended.
        Iout = I * np.cos(-phase) + Q * np.sin(-phase)
        Qout = -I * np.sin(-phase) + Q * np.cos(-phase)

    # apply DRAG
    if block_freq is not None and block_freq != freq:
        # block_freq == freq is excluded: a and b would divide by zero.
        a = block_freq / (block_freq - freq)
        b = 1 / (block_freq - freq)
        # Both new channels are computed from the pre-correction outputs.
        Iout, Qout = (a * Iout + b / (2 * pi) * D(Qout),
                      a * Qout - b / (2 * pi) * D(Iout))
    elif DRAGScaling is not None and DRAGScaling != 0:
        # 2 * pi * scaling * (freq - block_freq) = 1
        Iout, Qout = ((1 - w * DRAGScaling) * Iout - DRAGScaling * D(Qout),
                      (1 - w * DRAGScaling) * Qout + DRAGScaling * D(Iout))

    Qout = ratioIQ * Qout

    return Iout, Qout
6✔
1635

1636

1637
# Public names exported by ``from <this module> import *``.
# Order is kept as originally declared.
__all__ = [
    'D',
    'Waveform',
    'chirp',
    'const',
    'cos',
    'cosh',
    'coshPulse',
    'cosPulse',
    'cut',
    'drag',
    'exp',
    'function',
    'gaussian',
    'general_cosine',
    'hanning',
    'interp',
    'mixing',
    'one',
    'poly',
    'registerBaseFunc',
    'registerDerivative',
    'samplingPoints',
    'sign',
    'sin',
    'sinc',
    'sinh',
    'square',
    'step',
    't',
    'zero',
]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc