• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

qutech / qupulse / 16969114023

14 Aug 2025 03:09PM UTC coverage: 88.801%. First build
16969114023

Pull #920

github

web-flow
Merge 60363e790 into 63c5b4da0
Pull Request #920: Linspace followup waveforms

44 of 45 new or added lines in 6 files covered. (97.78%)

19015 of 21413 relevant lines covered (88.8%)

5.33 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.33
/qupulse/program/waveforms.py
1
# SPDX-FileCopyrightText: 2014-2024 Quantum Technology Group and Chair of Software Engineering, RWTH Aachen University
2
#
3
# SPDX-License-Identifier: LGPL-3.0-or-later
4

5
"""This module contains all waveform classes
6

7
Classes:
8
    - Waveform: An instantiated pulse which can be sampled to a raw voltage value array.
9
"""
10

11
import itertools
6✔
12
import operator
6✔
13
import warnings
6✔
14
from abc import ABCMeta, abstractmethod
6✔
15
from numbers import Real
6✔
16
from typing import (
6✔
17
    AbstractSet, Any, FrozenSet, Iterable, Mapping, NamedTuple, Sequence, Set,
18
    Tuple, Union, cast, Optional, List, Hashable)
19
from weakref import WeakValueDictionary, ref
6✔
20

21
import numpy as np
6✔
22

23
from qupulse import ChannelID
6✔
24
from qupulse.utils.performance import is_monotonic
6✔
25
from qupulse.expressions import ExpressionScalar
6✔
26
from qupulse.pulses.interpolation import InterpolationStrategy
6✔
27
from qupulse.utils import checked_int_cast, isclose
6✔
28
from qupulse.utils.types import TimeType, time_from_float
6✔
29
from qupulse.program.transformation import Transformation
6✔
30
from qupulse.utils import pairwise
6✔
31

32

33
class ConstantFunctionPulseTemplateWarning(UserWarning):
    """Warning category signalling that a constant waveform is being built from a FunctionPulseTemplate,
    e.g. from an expression that does not depend on the time variable."""
36

37

38
# Public API of this module.
__all__ = ["Waveform", "TableWaveform", "TableWaveformEntry", "FunctionWaveform", "SequenceWaveform",
           "MultiChannelWaveform", "RepetitionWaveform", "TransformingWaveform", "ArithmeticWaveform",
           "ConstantFunctionPulseTemplateWarning", "ConstantWaveform"]

# Forwarded as ``absolute_error`` to ``time_from_float`` when non-TimeType durations are converted (_to_time_type).
# TableWaveform.unsafe_sample replaces the last table time with the exact duration only when this is truthy.
PULSE_TO_WAVEFORM_ERROR = None  # error margin in pulse template to waveform conversion

#  these are private because there probably will be changes here
# Sample buffers are pre-filled with NaN, so any sample a waveform does not write remains NaN.
_ALLOCATION_FUNCTION = np.full_like  # pre_allocated = ALLOCATION_FUNCTION(sample_times, **ALLOCATION_FUNCTION_KWARGS)
_ALLOCATION_FUNCTION_KWARGS = dict(fill_value=np.nan, dtype=float)
47

48

49
def _to_time_type(duration: Real) -> TimeType:
    """Return ``duration`` as a TimeType.

    TimeType values are passed through untouched. Any other real number is converted with ``float`` followed by
    ``time_from_float``, using PULSE_TO_WAVEFORM_ERROR as the absolute error bound.
    """
    if not isinstance(duration, TimeType):
        duration = time_from_float(float(duration), absolute_error=PULSE_TO_WAVEFORM_ERROR)
    return duration
54

55

56
class Waveform(metaclass=ABCMeta):
    """Represents an instantiated PulseTemplate which can be sampled to retrieve arrays of voltage
    values for the hardware."""

    # Shared by all waveforms: maps a hash of the sampled bytes to a read-only result array. Values are held
    # weakly, so an entry disappears once no caller references that array any more.
    __sampled_cache = WeakValueDictionary()

    __slots__ = ('_duration', '_pow_2_divisor')

    def __init__(self, duration: TimeType, _pow_2_divisor: int = 0):
        self._duration = duration
        self._pow_2_divisor = _pow_2_divisor

    @property
    def duration(self) -> TimeType:
        """The duration of the waveform in time units."""
        return self._duration

    @abstractmethod
    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        """Sample the waveform at given sample times.

        The unsafe means that there are no sanity checks performed. The provided sample times are assumed to be
        monotonously increasing and lie in the range of [0, waveform.duration]

        Args:
            channel: The channel to sample. Assumed to be one of :attr:`defined_channels`.
            sample_times: Times at which this Waveform will be sampled.
            output_array: Has to be either None or an array of the same size and type as sample_times. If
                not None, the sampled values will be written here and this array will be returned
        Result:
            The sampled values of this Waveform at the provided sample times. Has the same number of
            elements as sample_times.
        """

    def get_sampled(self,
                    channel: ChannelID,
                    sample_times: np.ndarray,
                    output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        """A wrapper to the unsafe_sample method which caches the result. This method enforces the constraints
        unsafe_sample expects and caches the result to save memory.

        Args:
            channel: The channel to sample. Must be one of :attr:`defined_channels`.
            sample_times: Times at which this Waveform will be sampled. Must be monotonously increasing and lie in
                the range [0, duration].
            output_array: Has to be either None or an array of the same size and type as sample_times. If an array is
                given, the sampled values will be written into the given array and it will be returned. Otherwise, a new
                array will be created and cached to save memory.

        Raises:
            ValueError: If the sample times are not monotonously increasing, are outside of [0, duration], or if
                `output_array` has a different length than `sample_times`.
            KeyError: If `channel` is not defined on this waveform.

        Result:
            The sampled values of this Waveform at the provided sample times. Is `output_array` if provided.
            Arrays served from the cache are read-only.
        """
        if len(sample_times) == 0:
            if output_array is None:
                return np.zeros_like(sample_times, dtype=float)
            elif len(output_array) == len(sample_times):
                return output_array
            else:
                raise ValueError('Output array length and sample time length are different')

        if not is_monotonic(sample_times):
            raise ValueError('The sample times are not monotonously increasing')
        if sample_times[0] < 0 or sample_times[-1] > float(self.duration):
            raise ValueError(f'The sample times [{sample_times[0]}, ..., {sample_times[-1]}] are not in the range'
                             f' [0, duration={float(self.duration)}]')
        if channel not in self.defined_channels:
            raise KeyError('Channel not defined in this waveform: {}'.format(channel))

        # Checked for every branch: on the constant path below a scalar assignment would otherwise silently
        # broadcast into an output array of any length.
        if output_array is not None and len(output_array) != len(sample_times):
            raise ValueError('Output array length and sample time length are different')

        constant_value = self.constant_value(channel)
        if constant_value is None:
            if output_array is None:
                # cache the result to save memory: callers sampling identical data share one read-only array
                result = self.unsafe_sample(channel, sample_times)
                result.flags.writeable = False
                data = bytes(result)
                key = hash(data)
                # single lookup: a weakly referenced value may vanish between a membership test and an access
                cached = self.__sampled_cache.get(key)
                if cached is None:
                    self.__sampled_cache[key] = result
                    return result
                if cached.dtype == result.dtype and cached.shape == result.shape and bytes(cached) == data:
                    return cached
                # hash collision with different samples: never hand out another waveform's data
                return result
            else:
                # use the user provided memory
                return self.unsafe_sample(channel=channel,
                                          sample_times=sample_times,
                                          output_array=output_array)
        else:
            if output_array is None:
                output_array = np.full_like(sample_times, fill_value=constant_value, dtype=float)
            else:
                output_array[:] = constant_value
            return output_array

    def __hash__(self):
        # Only the direct subclass' __slots__ (plus the duration) take part in hashing, hence the restriction.
        if self.__class__.__base__ is not Waveform:
            raise NotImplementedError("Waveforms __hash__ and __eq__ implementation requires direct inheritance")
        return hash(tuple(getattr(self, slot) for slot in self.__slots__)) ^ hash(self._duration)

    def __eq__(self, other):
        slots = self.__slots__
        # identical __slots__ objects <=> same concrete class layout
        if slots is getattr(other, '__slots__', None):
            return self._duration == other._duration and all(getattr(self, slot) == getattr(other, slot) for slot in slots)
        # The other class might be more lenient
        return NotImplemented

    @property
    @abstractmethod
    def defined_channels(self) -> AbstractSet[ChannelID]:
        """The channels this waveform should be played on. Use
            :meth:`get_subset_for_channels` to get a waveform for a subset of these."""

    @abstractmethod
    def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'Waveform':
        """Unsafe version of :meth:`get_subset_for_channels`: `channels` is assumed to be a subset of
        :attr:`defined_channels` and is not validated."""

    def get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'Waveform':
        """Get a waveform that only describes the channels contained in `channels`.

        Args:
            channels: A channel set the return value should confine to.

        Raises:
            KeyError: If `channels` is not a subset of the waveform's defined channels.

        Returns:
            A waveform with waveform.defined_channels == `channels`
        """
        if not channels <= self.defined_channels:
            raise KeyError('Channels not defined on waveform: {}'.format(channels))
        if channels == self.defined_channels:
            return self
        return self.unsafe_get_subset_for_channels(channels=channels)

    def is_constant(self) -> bool:
        """Convenience function to check if all channels are constant. The result is equal to
        `all(waveform.constant_value(ch) is not None for ch in waveform.defined_channels)` but might be more performant.

        Returns:
            True if all channels have constant values.
        """
        return self.constant_value_dict() is not None

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        """Mapping of every defined channel to its constant value, or None if any channel is not constant."""
        result = {ch: self.constant_value(ch) for ch in self.defined_channels}
        if None in result.values():
            return None
        else:
            return result

    def constant_value(self, channel: ChannelID) -> Optional[float]:
        """Checks if the requested channel has a constant value and returns it if so.

        Guarantee that this assertion passes for every t in waveform duration::

            cv = waveform.constant_value(channel)
            assert cv is None or cv == waveform.get_sampled(channel, np.array([t]))[0]

        Args:
            channel: The channel to check

        Returns:
            None if there is no guarantee that the channel is constant. The value otherwise.
        """
        return None

    def __neg__(self):
        return FunctorWaveform.from_functor(self, {ch: np.negative for ch in self.defined_channels})

    def __pos__(self):
        return self

    def _sort_key_for_channels(self) -> Sequence[Tuple[str, int]]:
        """Makes reproducible sorting by defined channels possible"""
        # string channels sort by name, integer channels by number (behind the empty string)
        return sorted((ch, 0) if isinstance(ch, str) else ('', ch) for ch in self.defined_channels)

    def reversed(self) -> 'Waveform':
        """Returns a reversed version of this waveform."""
        # We don't check for constness here because const waveforms are supposed to override this method
        return ReversedWaveform(self)
232

233

234
class TableWaveformEntry(NamedTuple('TableWaveformEntry', [('t', Real),
                                                           ('v', float),
                                                           ('interp', InterpolationStrategy)])):
    """A (time, value, interpolation) triple describing one row of a TableWaveform."""

    def __init__(self, t: float, v: float, interp: InterpolationStrategy):
        # The tuple fields are already populated by the NamedTuple's __new__; __init__ only validates `interp`.
        if callable(interp):
            return
        raise TypeError('{} is neither callable nor of type InterpolationStrategy'.format(interp))

    def __repr__(self):
        cls_name = type(self).__name__
        return f'{cls_name}(t={self.t!r}, v={self.v!r}, interp={self.interp!r})'
243

244

245
class TableWaveform(Waveform):
    """Waveform obtained from instantiating a TablePulseTemplate."""

    EntryInInit = Union[TableWaveformEntry, Tuple[float, float, InterpolationStrategy]]

    __slots__ = ('_table', '_channel_id')

    def __init__(self,
                 channel: ChannelID,
                 waveform_table: Tuple[TableWaveformEntry, ...]) -> None:
        """Create a new TableWaveform instance.

        Prefer :meth:`from_table`, which validates the entries and returns a ConstantWaveform for constant tables.

        Args:
            channel: The channel this waveform is defined on.
            waveform_table: A tuple of instantiated and validated table entries. Other sequences are still
                accepted (with a DeprecationWarning) and are validated first.
        """
        if not isinstance(waveform_table, tuple):
            warnings.warn("Please use a tuple of TableWaveformEntry to construct TableWaveform directly",
                          category=DeprecationWarning)
            validated = self._validate_input(waveform_table)
            if isinstance(validated, tuple):
                # _validate_input reports a constant table as a (duration, value) pair. A TableWaveform needs real
                # entries, so keep the (now validated) input entries instead.
                validated = [TableWaveformEntry(*entry) for entry in waveform_table]
            # store a tuple: Waveform.__hash__ hashes every slot, and a list would make the waveform unhashable
            waveform_table = tuple(validated)

        super().__init__(duration=_to_time_type(waveform_table[-1].t))

        self._table = waveform_table
        self._channel_id = channel

    @staticmethod
    def _validate_input(input_waveform_table: Sequence[EntryInInit]) -> Union[Tuple[Real, Real],
                                                                              List[TableWaveformEntry]]:
        """ Checks that:
         - the time is increasing,
         - there are at least two entries

        Optimizations:
          - removes subsequent entries with same time or voltage values.
          - checks if the complete waveform is constant. Returns a (duration, value) tuple if this is the case

        Raises:
            ValueError:
              - there are less than two entries
              - the entries are not ordered in time
              - Any time is negative
              - The total length is zero

        Returns:
            A list of de-duplicated table entries
            OR
            A (duration, value) tuple if the waveform is constant
        """
        # we use an iterator here to avoid duplicate work and be maximally efficient for short tables
        # We never use StopIteration to abort iteration. It always signifies an error.
        input_iter = iter(input_waveform_table)
        try:
            first_t, first_v, first_interp = next(input_iter)
        except StopIteration:
            raise ValueError("Waveform table must not be empty")

        if first_t != 0.0:
            raise ValueError('First time entry is not zero.')

        previous_t = 0.0
        previous_v = first_v
        output_waveform_table = [TableWaveformEntry(0.0, first_v, first_interp)]

        try:
            t, v, interp = next(input_iter)
        except StopIteration:
            raise ValueError("Waveform table has less than two entries.")
        if t < 0:
            raise ValueError('Negative time values are not allowed.')

        # constant_v is None <=> the waveform is not constant up to the current entry
        constant_v = interp.constant_value((previous_t, previous_v), (t, v))

        for next_t, next_v, next_interp in input_iter:
            if next_t < t:
                if next_t < 0:
                    raise ValueError('Negative time values are not allowed.')
                else:
                    raise ValueError('Times are not increasing.')

            if constant_v is not None and interp.constant_value((t, v), (next_t, next_v)) != constant_v:
                constant_v = None

            if (previous_t != t or t != next_t) and (previous_v != v or v != next_v):
                # the time and the value differ both either from the next or the previous
                # otherwise we skip the entry
                previous_t = t
                previous_v = v
                output_waveform_table.append(TableWaveformEntry(t, v, interp))

            t, v, interp = next_t, next_v, next_interp

        # Until now, we only checked that the time does not decrease. We require an increase because duration == 0
        # waveforms are ill-formed. t is now the time of the last entry.
        if t == 0:
            raise ValueError('Last time entry is zero.')

        if constant_v is not None:
            # the waveform is constant
            return t, constant_v
        else:
            # we must still add the last entry to the table
            output_waveform_table.append(TableWaveformEntry(t, v, interp))
            return output_waveform_table

    def is_constant(self) -> bool:
        # only correct if `from_table` is used
        return False

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        # only correct if `from_table` is used
        return None

    @classmethod
    def from_table(cls, channel: ChannelID, table: Sequence[EntryInInit]) -> Union['TableWaveform', 'ConstantWaveform']:
        """Validate `table` and build the cheapest matching waveform: a ConstantWaveform for constant tables,
        a TableWaveform otherwise."""
        table = cls._validate_input(table)
        if isinstance(table, tuple):
            duration, amplitude = table
            return ConstantWaveform(duration=duration, amplitude=amplitude, channel=channel)
        else:
            return TableWaveform(channel, tuple(table))

    @property
    def compare_key(self) -> Any:
        warnings.warn("Waveform.compare_key is deprecated since 0.11 and will be removed in 0.12",
                      DeprecationWarning, stacklevel=2)
        return self._channel_id, self._table

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        if output_array is None:
            output_array = _ALLOCATION_FUNCTION(sample_times, **_ALLOCATION_FUNCTION_KWARGS)

        if PULSE_TO_WAVEFORM_ERROR:
            # we need to replace the last entry's t with self.duration
            *entries, last = self._table
            entries.append(TableWaveformEntry(float(self.duration), last.v, last.interp))
        else:
            entries = self._table

        # each pair of consecutive entries covers the closed interval [entry1.t, entry2.t]
        for entry1, entry2 in pairwise(entries):
            indices = slice(sample_times.searchsorted(entry1.t, 'left'),
                            sample_times.searchsorted(entry2.t, 'right'))
            output_array[indices] = \
                entry2.interp((float(entry1.t), entry1.v),
                              (float(entry2.t), entry2.v),
                              sample_times[indices])
        return output_array

    @property
    def defined_channels(self) -> AbstractSet[ChannelID]:
        return {self._channel_id}

    def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'Waveform':
        return self

    def __repr__(self):
        return f'{type(self).__name__}(channel={self._channel_id!r}, waveform_table={self._table!r})'
405

406

407
class ConstantWaveform(Waveform):
    """Single-channel waveform that outputs one fixed amplitude for its whole duration."""

    # TODO: remove
    _is_constant_waveform = True

    __slots__ = ('_amplitude', '_channel')

    def __init__(self, duration: Real, amplitude: Any, channel: ChannelID):
        """ Create a qupulse waveform corresponding to a ConstantPulseTemplate """
        super().__init__(duration=_to_time_type(duration))
        if hasattr(amplitude, 'shape'):
            # array-likes (e.g. 0-d numpy arrays) are unwrapped to their scalar; hashing fails early if the
            # unwrapped value cannot take part in Waveform.__hash__
            amplitude = amplitude[()]
            hash(amplitude)
        self._channel = channel
        self._amplitude = amplitude

    @classmethod
    def from_mapping(cls, duration: Real, constant_values: Mapping[ChannelID, float]) -> Union['ConstantWaveform',
                                                                                               'MultiChannelWaveform']:
        """Build one ConstantWaveform per (channel, value) pair; combine them into a MultiChannelWaveform
        when there is more than one channel."""
        assert constant_values
        common_duration = _to_time_type(duration)
        per_channel = [cls(common_duration, amplitude=value, channel=ch)
                       for ch, value in constant_values.items()]
        if len(per_channel) == 1:
            return per_channel[0]
        return MultiChannelWaveform(per_channel)

    def is_constant(self) -> bool:
        return True

    def constant_value(self, channel: ChannelID) -> Optional[float]:
        assert channel == self._channel
        return self._amplitude

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        return {self._channel: self._amplitude}

    @property
    def defined_channels(self) -> AbstractSet[ChannelID]:
        """The single channel this waveform is defined on, as a one-element set."""
        return {self._channel}

    @property
    def compare_key(self) -> Tuple[Any, ...]:
        warnings.warn("Waveform.compare_key is deprecated since 0.11 and will be removed in 0.12",
                      DeprecationWarning, stacklevel=2)
        return self._duration, self._amplitude, self._channel

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        if output_array is not None:
            output_array[:] = self._amplitude
            return output_array
        return np.full_like(sample_times, fill_value=self._amplitude, dtype=float)

    def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform:
        """A single-channel waveform is its own only non-empty subset."""
        return self

    def __repr__(self):
        cls_name = type(self).__name__
        return (f"{cls_name}(duration={self.duration!r}, "
                f"amplitude={self._amplitude!r}, channel={self._channel!r})")

    def reversed(self) -> 'Waveform':
        # a constant signal is invariant under time reversal
        return self
479

480

481
class FunctionWaveform(Waveform):
    """Waveform obtained from instantiating a FunctionPulseTemplate."""

    __slots__ = ('_expression', '_channel_id')

    def __init__(self, expression: ExpressionScalar,
                 duration: float,
                 channel: ChannelID) -> None:
        """Creates a new FunctionWaveform instance.

        Prefer :meth:`from_expression`, which returns a ConstantWaveform for time-independent expressions.

        Args:
            expression: The function represented by this FunctionWaveform
                as a mathematical expression where 't' denotes the time variable. It must not have other variables
            duration: The duration of the waveform
            channel: The channel this waveform is played on

        Raises:
            ValueError: If `expression` depends on any variable other than 't'.

        Warns:
            ConstantFunctionPulseTemplateWarning: If `expression` does not depend on any variable.
        """
        if set(expression.variables) - {'t'}:
            raise ValueError('FunctionWaveforms may not depend on anything but "t"')
        elif not expression.variables:
            warnings.warn("Constant FunctionWaveform is not recommended as the constant propagation will be suboptimal",
                          category=ConstantFunctionPulseTemplateWarning)
        super().__init__(duration=_to_time_type(duration))
        self._expression = expression
        self._channel_id = channel

    @classmethod
    def from_expression(cls, expression: ExpressionScalar, duration: float, channel: ChannelID) -> Union['FunctionWaveform', ConstantWaveform]:
        """Build a FunctionWaveform, or a ConstantWaveform if `expression` has no variables."""
        if expression.variables:
            return cls(expression, duration, channel)
        else:
            return ConstantWaveform(amplitude=expression.evaluate_numeric(), duration=duration, channel=channel)

    def is_constant(self) -> bool:
        # only correct if `from_expression` is used
        return False

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        # only correct if `from_expression` is used
        return None

    @property
    def defined_channels(self) -> AbstractSet[ChannelID]:
        return {self._channel_id}

    @property
    def compare_key(self) -> Any:
        warnings.warn("Waveform.compare_key is deprecated since 0.11 and will be removed in 0.12",
                      DeprecationWarning, stacklevel=2)
        return self._channel_id, self._expression, self._duration

    @property
    def duration(self) -> TimeType:
        return self._duration

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        evaluated = self._expression.evaluate_numeric(t=sample_times)
        if output_array is None:
            if self._expression.variables:
                return evaluated.astype(float)
            else:
                # a variable-free expression evaluates to a scalar that has to be broadcast to the sample shape
                return np.full_like(sample_times, fill_value=float(evaluated), dtype=float)
        else:
            output_array[:] = evaluated
            return output_array

    def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> Waveform:
        return self

    def __repr__(self):
        return f"{type(self).__name__}(duration={self.duration!r}, "\
               f"expression={self._expression!r}, channel={self._channel_id!r})"
557

558

559
class SequenceWaveform(Waveform):
    """This class allows putting multiple PulseTemplate together in one waveform on the hardware."""

    __slots__ = ('_sequenced_waveforms', )

    def __init__(self, sub_waveforms: Iterable[Waveform]):
        """Use Waveform.from_sequence for optimal construction

        :param sub_waveforms: All waveforms must have the same defined channels
        :raises ValueError: If no sub-waveforms are given or their defined channels differ.
        """
        # do not fail on iterators although we do not allow them as an argument
        sequenced_waveforms = tuple(sub_waveforms) if sub_waveforms else ()
        # checked after materializing: an exhausted/empty iterator is truthy and would otherwise slip through
        if not sequenced_waveforms:
            raise ValueError(
                "SequenceWaveform cannot be constructed without channel waveforms."
            )

        defined_channels = sequenced_waveforms[0].defined_channels
        mismatched = [waveform.defined_channels
                      for waveform in itertools.islice(sequenced_waveforms, 1, None)
                      if waveform.defined_channels != defined_channels]
        if mismatched:
            # report the offending channel sets in the exception instead of printing them to stdout
            raise ValueError(
                "SequenceWaveform cannot be constructed from waveforms of different "
                f"defined channels: {defined_channels} does not match {mismatched}."
            )

        super().__init__(duration=sum(waveform.duration for waveform in sequenced_waveforms))
        self._sequenced_waveforms = sequenced_waveforms

    @classmethod
    def from_sequence(cls, waveforms: Sequence['Waveform']) -> 'Waveform':
        """Returns a waveform that represents the given sequence of waveforms. Applies some optimizations:
        a single waveform is returned as is, nested SequenceWaveforms are flattened and a sequence in which every
        element has the same constant values collapses into a ConstantWaveform (or multi-channel equivalent)."""
        assert waveforms, "Sequence must not be empty"
        if len(waveforms) == 1:
            return waveforms[0]

        flattened = []
        constant_values = waveforms[0].constant_value_dict()
        for wf in waveforms:
            if constant_values and constant_values != wf.constant_value_dict():
                constant_values = None
            if isinstance(wf, cls):
                flattened.extend(wf.sequenced_waveforms)
            else:
                flattened.append(wf)
        if constant_values is None:
            return cls(sub_waveforms=flattened)
        else:
            duration = sum(wf.duration for wf in flattened)
            return ConstantWaveform.from_mapping(duration, constant_values)

    def is_constant(self) -> bool:
        # only correct if from_sequence is used for construction
        return False

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        # only correct if from_sequence is used for construction
        return None

    def constant_value(self, channel: ChannelID) -> Optional[float]:
        """The common constant value of `channel` over all sub-waveforms, or None if any sub-waveform is not
        constant on it or two of them disagree."""
        v = None
        for wf in self._sequenced_waveforms:
            wf_cv = wf.constant_value(channel)
            if wf_cv is None:
                return None
            elif wf_cv == v:
                continue
            elif v is None:
                v = wf_cv
            else:
                assert v != wf_cv
                return None
        return v

    @property
    def defined_channels(self) -> AbstractSet[ChannelID]:
        return self._sequenced_waveforms[0].defined_channels

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        if output_array is None:
            output_array = _ALLOCATION_FUNCTION(sample_times, **_ALLOCATION_FUNCTION_KWARGS)
        last_index = len(self._sequenced_waveforms) - 1
        time = 0
        for idx, subwaveform in enumerate(self._sequenced_waveforms):
            # before you change anything here, make sure to understand the difference between basic and advanced
            # indexing in numpy and their copy/reference behaviour
            end = time + subwaveform.duration

            # Each sub-waveform owns the half-open interval [time, end). The last one additionally owns the
            # closing sample t == duration, which the sampling contract allows and which would otherwise stay NaN.
            end_side = 'right' if idx == last_index else 'left'
            indices = slice(sample_times.searchsorted(float(time), 'left'),
                            sample_times.searchsorted(float(end), end_side))
            subwaveform.unsafe_sample(channel=channel,
                                      sample_times=sample_times[indices]-np.float64(time),
                                      output_array=output_array[indices])
            time = end
        return output_array

    @property
    def compare_key(self) -> Tuple[Waveform, ...]:
        warnings.warn("Waveform.compare_key is deprecated since 0.11 and will be removed in 0.12",
                      DeprecationWarning, stacklevel=2)
        return self._sequenced_waveforms

    @property
    def duration(self) -> TimeType:
        return self._duration

    def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'Waveform':
        return SequenceWaveform.from_sequence([
            sub_waveform.unsafe_get_subset_for_channels(channels & sub_waveform.defined_channels)
            for sub_waveform in self._sequenced_waveforms if sub_waveform.defined_channels & channels])

    @property
    def sequenced_waveforms(self) -> Sequence[Waveform]:
        return self._sequenced_waveforms

    def __repr__(self):
        return f"{type(self).__name__}({self._sequenced_waveforms})"
680

681

682
class MultiChannelWaveform(Waveform):
    """A MultiChannelWaveform combines several waveforms with disjoint channel sets into a
    single waveform that is defined on all of their channels.

    The channels defined by a MultiChannelWaveform are the union of the channels defined by its
    sub-waveforms. Sampling a channel is delegated to the one sub-waveform that defines it;
    channels are neither renamed nor remapped.

    The following constraints must hold:
     - The durations of all sub-waveforms must be equal (compared with ``isclose``).
     - No channel may be defined by more than one sub-waveform.
    """

    __slots__ = ('_sub_waveforms', '_defined_channels')

    def __init__(self, sub_waveforms: List[Waveform]) -> None:
        """Create a new MultiChannelWaveform instance.
        Use `MultiChannelWaveform.from_parallel` for optimal construction.

        Args:
            sub_waveforms: The sub-waveforms of this MultiChannelWaveform. They are stored as a
                tuple sorted by ``_sort_key_for_channels()``. The passed list itself is not
                modified.
        Raises:
            ValueError, if ``sub_waveforms`` is empty
            ValueError, if several sub-waveforms define the same channel
            ValueError, if sub-waveforms have inconsistent durations
        """

        if not sub_waveforms:
            raise ValueError(
                "MultiChannelWaveform cannot be constructed without channel waveforms."
            )

        # Sort the waveforms by their defined channels to make the stored order (and thereby
        # equality / the compare key) reproducible. ``sorted`` works on a copy so the caller's
        # list is not mutated as a side effect.
        sub_waveforms = sorted(sub_waveforms, key=lambda wf: wf._sort_key_for_channels())

        super().__init__(duration=sub_waveforms[0].duration)
        self._sub_waveforms = tuple(sub_waveforms)

        defined_channels = set()
        for waveform in self._sub_waveforms:
            duplicate_channels = waveform.defined_channels & defined_channels
            if duplicate_channels:
                raise ValueError('Channel may not be defined in multiple waveforms',
                                 duplicate_channels)
            defined_channels |= waveform.defined_channels
        self._defined_channels = frozenset(defined_channels)

        if not all(isclose(waveform.duration, self.duration) for waveform in self._sub_waveforms[1:]):
            # meaningful error message: group the channels by (approximately equal) duration
            durations = {}

            for waveform in self._sub_waveforms:
                for duration, channels in durations.items():
                    if isclose(waveform.duration, duration):
                        channels.update(waveform.defined_channels)
                        break
                else:
                    durations[waveform.duration] = set(waveform.defined_channels)

            raise ValueError(
                "MultiChannelWaveform cannot be constructed from channel waveforms of different durations.",
                durations
            )

    @staticmethod
    def from_parallel(waveforms: Sequence[Waveform]) -> Waveform:
        """Combine ``waveforms`` in parallel.

        A single waveform is returned unchanged. Nested MultiChannelWaveforms are flattened
        into their sub-waveforms before a new MultiChannelWaveform is constructed.
        """
        assert waveforms, "Argument must not be empty"
        if len(waveforms) == 1:
            return waveforms[0]

        # we do not look at constant values here because there is no benefit. We would need to construct a new
        # MultiChannelWaveform anyways

        # avoid unnecessary multi channel nesting
        flattened = []
        for waveform in waveforms:
            if isinstance(waveform, MultiChannelWaveform):
                flattened.extend(waveform._sub_waveforms)
            else:
                flattened.append(waveform)

        return MultiChannelWaveform(flattened)

    def is_constant(self) -> bool:
        """True if every sub-waveform is constant."""
        return all(wf.is_constant() for wf in self._sub_waveforms)

    def constant_value(self, channel: ChannelID) -> Optional[float]:
        """Constant value of ``channel`` as reported by the sub-waveform defining it.

        Raises:
            KeyError, if no sub-waveform defines ``channel``
        """
        return self[channel].constant_value(channel)

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        """Merged constant values of all sub-waveforms, or None if any of them is not constant."""
        d = {}
        for wf in self._sub_waveforms:
            wf_d = wf.constant_value_dict()
            if wf_d is None:
                return None
            else:
                d.update(wf_d)
        return d

    @property
    def duration(self) -> TimeType:
        return self._sub_waveforms[0].duration

    def __getitem__(self, key: ChannelID) -> Waveform:
        """Return the sub-waveform that defines channel ``key``.

        Raises:
            KeyError, if no sub-waveform defines ``key``
        """
        for waveform in self._sub_waveforms:
            if key in waveform.defined_channels:
                return waveform
        raise KeyError('Unknown channel ID: {}'.format(key), key)

    @property
    def defined_channels(self) -> AbstractSet[ChannelID]:
        return self._defined_channels

    @property
    def compare_key(self) -> Any:
        warnings.warn("Waveform.compare_key is deprecated since 0.11 and will be removed in 0.12",
                      DeprecationWarning, stacklevel=2)
        return self._sub_waveforms

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        # delegate to the single sub-waveform that defines the channel
        return self[channel].unsafe_sample(channel, sample_times, output_array)

    def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'Waveform':
        """Return a waveform restricted to ``channels``.

        Raises:
            KeyError, if none of the sub-waveforms defines any of ``channels``
        """
        relevant_sub_waveforms = [swf for swf in self._sub_waveforms if swf.defined_channels & channels]
        if len(relevant_sub_waveforms) == 1:
            return relevant_sub_waveforms[0].get_subset_for_channels(channels)
        elif len(relevant_sub_waveforms) > 1:
            return MultiChannelWaveform.from_parallel(
                [sub_waveform.get_subset_for_channels(channels & sub_waveform.defined_channels)
                 for sub_waveform in relevant_sub_waveforms])
        else:
            raise KeyError('Unknown channels: {}'.format(channels))

    def __repr__(self):
        return f"{type(self).__name__}({self._sub_waveforms!r})"
834

835

836
class RepetitionWaveform(Waveform):
    """A waveform that plays a body waveform ``repetition_count`` times back to back."""

    __slots__ = ('_body', '_repetition_count')

    def __init__(self, body: Waveform, repetition_count: int):
        """Create a repetition of ``body``.

        Args:
            body: The waveform to repeat.
            repetition_count: Number of repetitions. Must be an integer (or a value accepted by
                ``checked_int_cast``) greater than zero.
        Raises:
            ValueError, if the repetition count is not an integer > 0
        """
        repetition_count = self._checked_repetition_count(repetition_count)

        super().__init__(duration=body.duration * repetition_count)
        self._body = body
        self._repetition_count = repetition_count

    @staticmethod
    def _checked_repetition_count(repetition_count) -> int:
        """Cast ``repetition_count`` to int and ensure it is > 0 (shared by constructor and factory)."""
        repetition_count = checked_int_cast(repetition_count)
        if repetition_count < 1 or not isinstance(repetition_count, int):
            raise ValueError('Repetition count must be an integer >0')
        return repetition_count

    @classmethod
    def from_repetition_count(cls, body: Waveform, repetition_count: int) -> Waveform:
        """Repeat ``body``, collapsing to a ConstantWaveform if ``body`` is constant.

        The repetition count is validated the same way for both results, so a constant body
        cannot silently produce a zero-length or fractionally repeated waveform.

        Raises:
            ValueError, if the repetition count is not an integer > 0
        """
        repetition_count = cls._checked_repetition_count(repetition_count)
        constant_values = body.constant_value_dict()
        if constant_values is None:
            return RepetitionWaveform(body, repetition_count)
        else:
            return ConstantWaveform.from_mapping(body.duration * repetition_count, constant_values)

    @property
    def defined_channels(self) -> AbstractSet[ChannelID]:
        return self._body.defined_channels

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        if output_array is None:
            output_array = _ALLOCATION_FUNCTION(sample_times, **_ALLOCATION_FUNCTION_KWARGS)
        body_duration = self._body.duration
        time = 0
        for _ in range(self._repetition_count):
            end = time + body_duration
            # samples in [time, end) belong to this repetition; basic slicing yields views so the
            # body writes directly into output_array
            indices = slice(*sample_times.searchsorted((float(time), float(end)), 'left'))
            self._body.unsafe_sample(channel=channel,
                                     sample_times=sample_times[indices] - float(time),
                                     output_array=output_array[indices])
            time = end
        return output_array

    @property
    def compare_key(self) -> Tuple[Any, int]:
        warnings.warn("Waveform.compare_key is deprecated since 0.11 and will be removed in 0.12",
                      DeprecationWarning, stacklevel=2)
        return self._body.compare_key, self._repetition_count

    def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> Waveform:
        return RepetitionWaveform.from_repetition_count(
            body=self._body.unsafe_get_subset_for_channels(channels),
            repetition_count=self._repetition_count)

    def is_constant(self) -> bool:
        return self._body.is_constant()

    def constant_value(self, channel: ChannelID) -> Optional[float]:
        return self._body.constant_value(channel)

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        return self._body.constant_value_dict()

    def __repr__(self):
        return f"{type(self).__name__}(body={self._body!r}, repetition_count={self._repetition_count!r})"
901

902

903
class TransformingWaveform(Waveform):
    """Samples an inner waveform and passes the result through a :class:`Transformation`.

    The transformed data for the most recently used ``sample_times`` array is cached per output
    channel. The cache is keyed on the identity of that array through a weak reference, so it is
    invalidated whenever a different array is passed and does not keep the array alive.
    """
    __slots__ = ('_inner_waveform', '_transformation', '_cached_data', '_cached_times')

    def __init__(self, inner_waveform: Waveform, transformation: Transformation):
        """Wrap ``inner_waveform`` so that it is sampled through ``transformation``.

        Prefer :meth:`from_transformation`, which returns a ConstantWaveform if the inner
        waveform is constant and the transformation is constant invariant.
        """
        super(TransformingWaveform, self).__init__(duration=inner_waveform.duration)
        self._inner_waveform = inner_waveform
        self._transformation = transformation

        # cache data of inner channels based identified and invalidated by the sample times
        self._cached_data = None
        # behaves like a dead weak reference until unsafe_sample stores a real one
        self._cached_times = lambda: None

    def __hash__(self):
        return hash((self._inner_waveform, self._transformation))

    def __eq__(self, other):
        # identical __slots__ object is used as a cheap "same class" check
        if getattr(other, '__slots__', None) is self.__slots__:
            return self._inner_waveform == other._inner_waveform and self._transformation == other._transformation
        return NotImplemented

    @classmethod
    def from_transformation(cls, inner_waveform: Waveform, transformation: Transformation) -> Waveform:
        """Transform ``inner_waveform``, collapsing to a ConstantWaveform where possible."""
        constant_values = inner_waveform.constant_value_dict()

        if constant_values is None or not transformation.is_constant_invariant():
            return cls(inner_waveform, transformation)

        transformed_constant_values = dict(transformation(0., constant_values).items())
        return ConstantWaveform.from_mapping(inner_waveform.duration, transformed_constant_values)

    def is_constant(self) -> bool:
        # only true if `from_transformation` was used
        return False

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        # only true if `from_transformation` was used
        return None

    def constant_value(self, channel: ChannelID) -> Optional[float]:
        """Constant value of ``channel`` if all its input channels are constant, else None."""
        if not self._transformation.is_constant_invariant():
            return None
        in_channels = self._transformation.get_input_channels({channel})
        in_values = {ch: self._inner_waveform.constant_value(ch) for ch in in_channels}
        if any(val is None for val in in_values.values()):
            return None
        else:
            return self._transformation(0., in_values)[channel]

    @property
    def inner_waveform(self) -> Waveform:
        return self._inner_waveform

    @property
    def transformation(self) -> Transformation:
        return self._transformation

    @property
    def defined_channels(self) -> AbstractSet[ChannelID]:
        return self.transformation.get_output_channels(self.inner_waveform.defined_channels)

    @property
    def compare_key(self) -> Tuple[Waveform, Transformation]:
        warnings.warn("Waveform.compare_key is deprecated since 0.11 and will be removed in 0.12",
                      DeprecationWarning, stacklevel=2)
        return self.inner_waveform, self.transformation

    def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> 'SubsetWaveform':
        return SubsetWaveform(self, channel_subset=channels)

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        """Sample ``channel`` at ``sample_times``, reusing cached results for the same array.

        If ``output_array`` is None a fresh array is returned. The cached data itself is never
        handed out, because callers may modify the returned array in place.
        """
        if self._cached_times() is not sample_times:
            # different (or first) sample times: start a new cache
            self._cached_data = dict()
            self._cached_times = ref(sample_times)

        if channel not in self._cached_data:

            inner_channels = self.transformation.get_input_channels({channel})

            inner_data = {inner_channel: self.inner_waveform.unsafe_sample(inner_channel, sample_times)
                          for inner_channel in inner_channels}

            # the transformation may produce further output channels; cache them all
            outer_data = self.transformation(sample_times, inner_data)
            self._cached_data.update(outer_data)

        cached = self._cached_data[channel]
        if output_array is None:
            # Copy: e.g. ArithmeticWaveform and FunctorWaveform write into the returned array
            # in place, which would otherwise corrupt the cache for later calls.
            output_array = np.copy(cached)
        else:
            output_array[:] = cached

        return output_array
997

998

999
class SubsetWaveform(Waveform):
    """Exposes only a subset of the channels of an inner waveform.

    Sampling is delegated unchanged to the inner waveform; ``unsafe_sample`` does not check
    that the requested channel belongs to the subset.
    """
    __slots__ = ('_inner_waveform', '_channel_subset')

    def __init__(self, inner_waveform: Waveform, channel_subset: Set[ChannelID]):
        super().__init__(duration=inner_waveform.duration)
        self._inner_waveform = inner_waveform
        self._channel_subset = frozenset(channel_subset)

    @property
    def inner_waveform(self) -> Waveform:
        """The wrapped waveform."""
        wrapped = self._inner_waveform
        return wrapped

    @property
    def defined_channels(self) -> FrozenSet[ChannelID]:
        """The channel subset given at construction, as a frozenset."""
        subset = self._channel_subset
        return subset

    @property
    def compare_key(self) -> Tuple[frozenset, Waveform]:
        message = "Waveform.compare_key is deprecated since 0.11 and will be removed in 0.12"
        warnings.warn(message, DeprecationWarning, stacklevel=2)
        return self.defined_channels, self.inner_waveform

    def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform:
        # take the subset directly from the wrapped waveform instead of nesting subsets
        wrapped = self.inner_waveform
        return wrapped.get_subset_for_channels(channels)

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        wrapped = self.inner_waveform
        return wrapped.unsafe_sample(channel, sample_times, output_array)

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        """Constant values of the subset channels, or None if the inner waveform is not constant."""
        inner_constants = self._inner_waveform.constant_value_dict()
        if inner_constants is None:
            return None
        return {ch: inner_constants[ch] for ch in self._channel_subset}

    def constant_value(self, channel: ChannelID) -> Optional[float]:
        """Constant value of ``channel``; raises KeyError for channels outside the subset."""
        if channel in self._channel_subset:
            return self._inner_waveform.constant_value(channel)
        raise KeyError(channel)
1039

1040

1041
class ArithmeticWaveform(Waveform):
    """Channels only present in one waveform have the operations neutral element on the other.

    A channel defined only by ``lhs`` is passed through unchanged. A channel defined only by
    ``rhs`` is mapped through the unary form of the operator (``+x`` or ``-x``).
    """

    numpy_operator_map = {'+': np.add, '-': np.subtract}
    operator_map = {'+': operator.add, '-': operator.sub}

    rhs_only_map = {'+': operator.pos, '-': operator.neg}
    numpy_rhs_only_map = {'+': np.positive, '-': np.negative}

    __slots__ = ('_lhs', '_rhs', '_arithmetic_operator')

    def __init__(self,
                 lhs: Waveform,
                 arithmetic_operator: str,
                 rhs: Waveform):
        """Combine ``lhs`` and ``rhs`` with ``arithmetic_operator`` ('+' or '-').

        Prefer :meth:`from_operator`, which folds two constant operands into a ConstantWaveform.
        """
        super().__init__(duration=lhs.duration)
        self._lhs, self._rhs = lhs, rhs
        self._arithmetic_operator = arithmetic_operator

        assert np.isclose(float(self._lhs.duration), float(self._rhs.duration))
        assert arithmetic_operator in self.operator_map

    @classmethod
    def from_operator(cls, lhs: Waveform, arithmetic_operator: str, rhs: Waveform):
        """Combine two waveforms, collapsing to a ConstantWaveform if both operands are constant."""
        # both dicts are always computed; skipping rhs when lhs is not constant would be a micro-optimization
        lhs_constants = lhs.constant_value_dict()
        rhs_constants = rhs.constant_value_dict()
        if lhs_constants is not None and rhs_constants is not None:
            combine = cls.operator_map[arithmetic_operator]
            unary = cls.rhs_only_map[arithmetic_operator]

            merged = dict(lhs_constants)
            for ch, rhs_value in rhs_constants.items():
                merged[ch] = combine(merged[ch], rhs_value) if ch in merged else unary(rhs_value)

            duration = lhs.duration
            assert isclose(duration, rhs.duration)
            return ConstantWaveform.from_mapping(duration, merged)

        return cls(lhs, arithmetic_operator, rhs)

    def constant_value(self, channel: ChannelID) -> Optional[float]:
        """Constant value of ``channel`` derived from the operands' constant values, or None."""
        if channel not in self._rhs.defined_channels:
            return self._lhs.constant_value(channel)

        rhs_value = self._rhs.constant_value(channel)
        if rhs_value is None:
            return None

        if channel not in self._lhs.defined_channels:
            return self.rhs_only_map[self._arithmetic_operator](rhs_value)

        lhs_value = self._lhs.constant_value(channel)
        if lhs_value is None:
            return None
        return self.operator_map[self._arithmetic_operator](lhs_value, rhs_value)

    def is_constant(self) -> bool:
        # only correct if from_operator is used
        return False

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        # only correct if from_operator is used
        return None

    @property
    def lhs(self) -> Waveform:
        """Left operand."""
        return self._lhs

    @property
    def rhs(self) -> Waveform:
        """Right operand."""
        return self._rhs

    @property
    def arithmetic_operator(self) -> str:
        """The operator symbol, '+' or '-'."""
        return self._arithmetic_operator

    @property
    def duration(self) -> TimeType:
        return self._lhs.duration

    @property
    def defined_channels(self) -> AbstractSet[ChannelID]:
        """Union of the channels of both operands."""
        return self._lhs.defined_channels | self._rhs.defined_channels

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        lhs_samples = None
        if channel in self._lhs.defined_channels:
            lhs_samples = self._lhs.unsafe_sample(channel=channel, sample_times=sample_times, output_array=output_array)

        rhs_samples = None
        if channel in self._rhs.defined_channels:
            # the output buffer is only free for rhs if lhs did not already use it
            rhs_target = output_array if lhs_samples is None else None
            rhs_samples = self._rhs.unsafe_sample(channel=channel, sample_times=sample_times,
                                                  output_array=rhs_target)

        if lhs_samples is None:
            assert rhs_samples is not None, "channel %r not in defined channels (internal bug)" % channel
            return self.numpy_rhs_only_map[self._arithmetic_operator](rhs_samples, out=output_array)

        if rhs_samples is None:
            return lhs_samples

        # without an explicit output buffer the result is written into the lhs samples
        target = lhs_samples if output_array is None else output_array
        return self.numpy_operator_map[self._arithmetic_operator](lhs_samples, rhs_samples, out=target)

    def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform:
        # TODO: optimization possible
        return SubsetWaveform(self, channels)

    @property
    def compare_key(self) -> Tuple[str, Waveform, Waveform]:
        message = "Waveform.compare_key is deprecated since 0.11 and will be removed in 0.12"
        warnings.warn(message, DeprecationWarning, stacklevel=2)
        return self._arithmetic_operator, self._lhs, self._rhs
1173

1174

1175
class FunctorWaveform(Waveform):
    """Apply a channel wise functor to all samples of an inner waveform.

    When sampling, each channel's functor is called as ``functor(samples, out=samples)``, so it
    must accept a second, in-place ``out`` argument (numpy ufuncs such as ``np.negative`` do).
    On constant values it is called with a single argument, ``functor(value)``.
    """

    # TODO: Use Protocol to enforce that it accepts second argument has the keyword out
    Functor = callable

    __slots__ = ('_inner_waveform', '_functor')

    def __init__(self, inner_waveform: Waveform, functor: Mapping[ChannelID, Functor]):
        """Wrap ``inner_waveform`` with per-channel functors.

        Args:
            inner_waveform: The waveform whose samples are transformed.
            functor: Mapping from channel to functor. Its keys must equal
                ``inner_waveform.defined_channels``.
        """
        super(FunctorWaveform, self).__init__(duration=inner_waveform.duration)
        self._inner_waveform = inner_waveform
        self._functor = dict(functor.items())

        assert set(functor.keys()) == inner_waveform.defined_channels, ("There is no default identity mapping (yet). "
                                                                        "File an issue on github if you need it.")

    @classmethod
    def from_functor(cls, inner_waveform: Waveform, functor: Mapping[ChannelID, Functor]):
        """Apply ``functor``, collapsing to a ConstantWaveform if ``inner_waveform`` is constant."""
        constant_values = inner_waveform.constant_value_dict()
        if constant_values is None:
            return cls(inner_waveform, functor)

        funced_constant_values = {ch: functor[ch](val) for ch, val in constant_values.items()}
        return ConstantWaveform.from_mapping(inner_waveform.duration, funced_constant_values)

    def is_constant(self) -> bool:
        # only correct if `from_functor` was used
        return False

    def constant_value_dict(self) -> Optional[Mapping[ChannelID, float]]:
        # only correct if `from_functor` was used
        return None

    def constant_value(self, channel: ChannelID) -> Optional[float]:
        """Functor applied to the inner constant value of ``channel``, or None if not constant."""
        inner = self._inner_waveform.constant_value(channel)
        if inner is None:
            return None
        else:
            return self._functor[channel](inner)

    @property
    def defined_channels(self) -> AbstractSet[ChannelID]:
        return self._inner_waveform.defined_channels

    def unsafe_sample(self,
                      channel: ChannelID,
                      sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        inner_output = self._inner_waveform.unsafe_sample(channel, sample_times, output_array)
        # the functor works in place on the inner samples
        return self._functor[channel](inner_output, out=inner_output)

    def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform:
        return FunctorWaveform.from_functor(
            self._inner_waveform.unsafe_get_subset_for_channels(channels),
            {ch: self._functor[ch] for ch in channels})

    @property
    def compare_key(self) -> Tuple[Waveform, FrozenSet]:
        warnings.warn("Waveform.compare_key is deprecated since 0.11 and will be removed in 0.12",
                      DeprecationWarning, stacklevel=2)
        return self._inner_waveform, frozenset(self._functor.items())
1235

1236

1237
class ReversedWaveform(Waveform):
    """Reverses the inner waveform in time: sampling at ``t`` samples the inner waveform at
    ``duration - t``."""

    __slots__ = ('_inner',)

    def __init__(self, inner: Waveform):
        super().__init__(duration=inner.duration)
        self._inner = inner

    @classmethod
    def from_to_reverse(cls, inner: Waveform) -> Waveform:
        """Reverse ``inner``; a constant waveform is its own reversal and is returned as is."""
        if inner.constant_value_dict() is not None:
            return inner
        else:
            return cls(inner)

    def unsafe_sample(self, channel: ChannelID, sample_times: np.ndarray,
                      output_array: Union[np.ndarray, None] = None) -> np.ndarray:
        # mirror the sample times and reverse them so the inner waveform sees increasing times
        inner_sample_times = (float(self.duration) - sample_times)[::-1]
        if output_array is None:
            return self._inner.unsafe_sample(channel, inner_sample_times, None)[::-1]

        # reversed view: the inner waveform fills output_array back to front through it
        inner_view = output_array[::-1]
        inner_result = self._inner.unsafe_sample(channel, inner_sample_times, output_array=inner_view)
        if inner_result is not inner_view:
            # The inner waveform did not hand back the view it was given, so output_array may not
            # have been written. Copy explicitly (numpy resolves overlapping memory correctly).
            # Comparing ``.base`` ids instead would wrongly skip this copy when both ``.base``
            # attributes are None.
            output_array[:] = inner_result[::-1]
        return output_array

    @property
    def defined_channels(self) -> AbstractSet[ChannelID]:
        return self._inner.defined_channels

    def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'Waveform':
        return ReversedWaveform.from_to_reverse(self._inner.unsafe_get_subset_for_channels(channels))

    @property
    def compare_key(self) -> Hashable:
        warnings.warn("Waveform.compare_key is deprecated since 0.11 and will be removed in 0.12",
                      DeprecationWarning, stacklevel=2)
        return self._inner.compare_key

    def reversed(self) -> 'Waveform':
        # reversing twice yields the original waveform
        return self._inner

    def __repr__(self):
        return f"{type(self).__name__}(inner={self._inner!r})"
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc