• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

qutech / qupulse / 21392850829

27 Jan 2026 10:03AM UTC coverage: 88.721%. First build
21392850829

Pull #933

github

web-flow
Merge 1c6516f74 into 2d86c016d
Pull Request #933: Refactor context management for program builders

454 of 485 new or added lines in 26 files covered. (93.61%)

19170 of 21607 relevant lines covered (88.72%)

5.32 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.97
/qupulse/program/linspace.py
1
# SPDX-FileCopyrightText: 2014-2024 Quantum Technology Group and Chair of Software Engineering, RWTH Aachen University
2
#
3
# SPDX-License-Identifier: LGPL-3.0-or-later
4

5
import contextlib
6✔
6
import dataclasses
6✔
7
import warnings
6✔
8

9
import numpy as np
6✔
10
import math
6✔
11
import copy
6✔
12

13
from dataclasses import dataclass
6✔
14
from abc import ABC, abstractmethod
6✔
15
from typing import Mapping, Optional, Sequence, ContextManager, Iterable, Tuple, \
6✔
16
    Union, Dict, List, Set, ClassVar, Callable, Any, AbstractSet
17
from collections import OrderedDict
6✔
18

19
from qupulse import ChannelID, MeasurementWindow
6✔
20
from qupulse.parameter_scope import Scope, MappedScope, FrozenDict
6✔
21
from qupulse.program.protocol import (ProgramBuilder, Waveform, BaseProgramBuilder, Program, )
6✔
22
from qupulse.program.values import RepetitionCount, HardwareTime, HardwareVoltage, DynamicLinearValue, TimeType
6✔
23
from qupulse.program.volatile import VolatileRepetitionCount, InefficientVolatility
6✔
24
from qupulse.program.waveforms import TransformingWaveform
6✔
25

26
# Grid on which increment factors are rounded when building a DepKey, so that factors which only differ by
# floating point noise map to the same key. The increment values themselves are kept as floats.
DEFAULT_INCREMENT_RESOLUTION: float = 1e-9
6✔
29

30

31
@dataclass(frozen=True)
class DepKey:
    """The key that identifies how a certain set command depends on iteration indices.

    The factors are rounded to integer multiples of a given resolution so that the key is independent of floating
    point rounding errors. Trailing factors that round to zero are dropped, so dependencies that only differ in
    (rounded) trailing zeros share one key.

    These objects allow backends which support it to track multiple amplitudes at once.
    """
    factors: Tuple[int, ...]

    @classmethod
    def from_voltages(cls, voltages: Sequence[float], resolution: float):
        """Create a key from per-iteration voltage factors.

        Args:
            voltages: Voltage factor of each iteration index.
            resolution: Grid the factors are rounded to.

        Returns:
            A key holding the rounded factors without trailing zeros.
        """
        # Round first, strip afterwards: a factor like 1e-15 rounds to 0 and must not produce a different key than
        # an exact zero would. Building a list also avoids the truthiness check on e.g. numpy arrays.
        rounded = [int(round(voltage / resolution)) for voltage in voltages]
        while rounded and rounded[-1] == 0:
            rounded.pop()
        return cls(tuple(rounded))
6✔
46

47

48
@dataclass
class LinSpaceNode(ABC):
    """Abstract AST node of a linspace program.

    The tree supports nested sequencing, repetitions and iterations whose set points are linearly spaced.
    """

    @abstractmethod
    def dependencies(self) -> Mapping[int, set[Tuple[float, ...]]]:
        """Collect which iteration indices the channels used inside this node depend on.

        Returns:
             Mapping of channel index to the set of iteration index dependencies of that channel inside this node.
        """
        raise NotImplementedError

    def reversed(self, offset: int, lengths: list):
        """Return a time reversed version of this node.

        Reversal is a non-local operation, so the caller passes the surrounding iteration context.

        Args:
            offset: Number of active iterations that are not reversed.
            lengths: Lengths of the active iterations that have to be reversed.

        Returns:
            The time reversed node.
        """
        raise NotImplementedError
×
74

75

76
@dataclass
class LinSpaceHold(LinSpaceNode):
    """Hold a set of voltages for some time.

    Every voltage and the duration may depend linearly on the active iteration indices.
    """

    # Constant part of the voltage of each channel (by channel index).
    bases: Tuple[float, ...]
    # Per channel: one factor per active iteration index, or None for a constant voltage.
    factors: Tuple[Optional[Tuple[float, ...]], ...]

    duration_base: TimeType
    # One factor per active iteration index, or None for a constant duration.
    duration_factors: Optional[Tuple[TimeType, ...]]

    def dependencies(self) -> Mapping[int, set]:
        deps = {}
        for channel, channel_factors in enumerate(self.factors):
            if channel_factors:
                deps[channel] = {channel_factors}
        return deps

    def reversed(self, offset: int, lengths: list):
        if not lengths:
            return self

        # Running an iteration of length n backwards starts at index n - 1.
        steps = [length - 1 for length in lengths]

        def reverse_affine(base, coefficients, zero):
            """Move base to the last point of the reversed iterations and negate their coefficients."""
            if coefficients is None or len(coefficients) <= offset:
                # does not depend on any of the reversed iterations
                return base, coefficients
            reversed_part = coefficients[offset:]
            shift = sum((step * coefficient for coefficient, step in zip(reversed_part, steps)), zero)
            return base + shift, coefficients[:offset] + tuple(-c for c in reversed_part)

        new_bases = []
        new_factors = []
        for channel_base, channel_factors in zip(self.bases, self.factors):
            flipped_base, flipped_factors = reverse_affine(channel_base, channel_factors, 0)
            new_bases.append(flipped_base)
            new_factors.append(flipped_factors)

        new_duration_base, new_duration_factors = reverse_affine(self.duration_base, self.duration_factors,
                                                                 TimeType(0))
        return LinSpaceHold(tuple(new_bases), tuple(new_factors),
                            duration_base=new_duration_base,
                            duration_factors=new_duration_factors)
6✔
117

118

119
@dataclass
class LinSpaceArbitraryWaveform(LinSpaceNode):
    """Wrapper node that passes an arbitrary waveform through the linspace machinery."""
    waveform: Waveform
    channels: Tuple[ChannelID, ...]

    def dependencies(self) -> Mapping[int, set[Tuple[float, ...]]]:
        # An arbitrary waveform has no iteration dependent set points.
        return {}

    def reversed(self, offset: int, lengths: list):
        # The iteration context is irrelevant here; the waveform reverses itself.
        flipped = self.waveform.reversed()
        return LinSpaceArbitraryWaveform(waveform=flipped, channels=self.channels)
133

134

135
@dataclass
class LinSpaceRepeat(LinSpaceNode):
    """Execute ``body`` ``count`` times in a row."""
    body: Tuple[LinSpaceNode, ...]
    count: int

    def dependencies(self):
        merged = {}
        for child in self.body:
            for channel, child_deps in child.dependencies().items():
                if channel not in merged:
                    merged[channel] = set()
                merged[channel].update(child_deps)
        return merged

    def reversed(self, offset: int, counts: list):
        # Reverse the order of the children and reverse every child itself.
        new_body = tuple(child.reversed(offset, counts) for child in self.body[::-1])
        return LinSpaceRepeat(new_body, self.count)
×
150

151

152
@dataclass
class LinSpaceIter(LinSpaceNode):
    """Iterate ``body`` with an iteration index running over ``range(length)``.

    Offsets and spacing are not stored here but in the hold nodes of the body."""
    body: Tuple[LinSpaceNode, ...]
    length: int

    def dependencies(self):
        result = {}
        for child in self.body:
            for channel, child_deps in child.dependencies().items():
                # Drop the last element of every dependency: it belongs to this iteration, which is set inside this
                # node and is therefore not an external dependency.
                outer = {dep[:-1] for dep in child_deps}
                if outer == {()}:
                    continue
                result.setdefault(channel, set()).update(outer)
        return result

    def reversed(self, offset: int, lengths: list):
        # Register this iteration while the children are reversed so that they reverse it as well.
        lengths.append(self.length)
        new_body = tuple(child.reversed(offset, lengths) for child in self.body[::-1])
        lengths.pop()
        return LinSpaceIter(new_body, self.length)
6✔
175

176

177
class LinSpaceBuilder(BaseProgramBuilder):
    """This program builder supports efficient translation of pulse templates that use symbolic linearly
    spaced voltages and durations.

    The channel identifiers are reduced to their index in the given channel tuple.

    Arbitrary waveforms are passed through as :class:`LinSpaceArbitraryWaveform` nodes. Measurements are ignored.
    """

    def __init__(self, channels: Tuple[ChannelID, ...]):
        super().__init__()
        self._name_to_idx = {name: idx for idx, name in enumerate(channels)}
        self._idx_to_name = channels

        # Stack of node lists. The bottom list collects the program root; every open repetition, iteration or time
        # reversal pushes a list that collects its body.
        self._stack = [[]]
        # (index name, range) of the currently open iterations, outermost first.
        self._ranges = []

    def _root(self):
        return self._stack[0]

    def _get_rng(self, idx_name: str) -> range:
        return self._get_ranges()[idx_name]

    def inner_scope(self, scope: Scope) -> Scope:
        """This function is necessary to inject program builder specific parameter implementations into the build
        process."""
        if self._ranges:
            name, _ = self._ranges[-1]
            return scope.overwrite({name: DynamicLinearValue(base=0, factors={name: 1})})
        else:
            return scope

    def _get_ranges(self):
        return dict(self._ranges)

    @staticmethod
    def _linearize(value: DynamicLinearValue, ranges: Mapping[str, range], zero) -> Tuple[Any, tuple]:
        """Rewrite a linear value as a base plus one increment per open iteration index.

        ``value.factors`` is keyed by iteration name and refers to the iteration *value*
        ``rng.start + index * rng.step``. The LinSpace nodes need the dependency on the iteration *index*, so the
        start is folded into the base and the step into the increment. Factors of names that are not an open
        iteration are ignored.

        Args:
            value: The linear value to convert.
            ranges: Open iterations by index name, outermost first.
            zero: Value the accumulation of start and step starts from.

        Returns:
            ``(base, increments)`` with one increment per entry of ``ranges``.
        """
        base = value.base
        increments = []
        for rng_name, rng in ranges.items():
            start = zero
            step = zero
            factor = value.factors.get(rng_name, None)
            if factor:
                start += rng.start * factor
                step += rng.step * factor
            base += start
            increments.append(step)
        return base, tuple(increments)

    def _transformed_hold_voltage(self, duration: HardwareTime, voltages: Mapping[ChannelID, HardwareVoltage]):
        # order the voltages by channel index
        voltages = sorted((self._name_to_idx[ch_name], value) for ch_name, value in voltages.items())
        voltages = [value for _, value in voltages]

        ranges = self._get_ranges()
        factors = []
        bases = []
        for value in voltages:
            if isinstance(value, DynamicLinearValue):
                base, incs = self._linearize(value, ranges, 0.)
                bases.append(base)
                factors.append(incs)
            else:
                # Constant voltage. Testing for DynamicLinearValue (as for the duration below) instead of ``float``
                # also accepts ints and other real scalars.
                bases.append(value)
                factors.append(None)

        if isinstance(duration, DynamicLinearValue):
            # LinSpaceHold.duration_factors is a tuple with one entry per open iteration (it is sliced by
            # LinSpaceHold.reversed), not the name keyed mapping of the DynamicLinearValue.
            duration_base, duration_incs = self._linearize(duration, ranges, 0)
            # a duration that depends on none of the open iterations is constant
            duration_factors = duration_incs if any(duration_incs) else None
        else:
            duration_base = duration
            duration_factors = None

        set_cmd = LinSpaceHold(bases=tuple(bases),
                               factors=tuple(factors),
                               duration_base=duration_base,
                               duration_factors=duration_factors)

        self._stack[-1].append(set_cmd)

    def _transformed_play_arbitrary_waveform(self, waveform: Waveform):
        self._stack[-1].append(LinSpaceArbitraryWaveform(waveform, self._idx_to_name))

    def measure(self, measurements: Optional[Sequence[MeasurementWindow]]):
        """Ignores measurements"""
        pass

    def with_repetition(self, repetition_count: RepetitionCount,
                        measurements: Optional[Sequence[MeasurementWindow]] = None) -> Iterable['ProgramBuilder']:
        """Yield this builder once; everything added meanwhile becomes the body of a LinSpaceRepeat node."""
        if repetition_count == 0:
            return
        if isinstance(repetition_count, VolatileRepetitionCount):
            warnings.warn(f"{type(self).__name__} does not support volatile repetition counts.",
                          category=InefficientVolatility)

        self._stack.append([])
        yield self
        blocks = self._stack.pop()
        if blocks:
            self._stack[-1].append(LinSpaceRepeat(body=tuple(blocks), count=repetition_count))

    @contextlib.contextmanager
    def with_sequence(self,
                      measurements: Optional[Sequence[MeasurementWindow]] = None) -> Iterable['ProgramBuilder']:
        # sequences need no node of their own: the children are appended to the current body
        yield self

    def new_subprogram(self, global_transformation: 'Transformation' = None) -> ContextManager['ProgramBuilder']:
        raise NotImplementedError('Not implemented yet (postponed)')

    def with_iteration(self, index_name: str, rng: range,
                       measurements: Optional[Sequence[MeasurementWindow]] = None) -> Iterable['ProgramBuilder']:
        """Yield this builder once with ``index_name`` bound to a DynamicLinearValue; everything added meanwhile
        becomes the body of a LinSpaceIter node."""
        if len(rng) == 0:
            return
        self._stack.append([])
        self._ranges.append((index_name, rng))
        scope = self.build_context.scope.overwrite({index_name: DynamicLinearValue(base=0, factors={index_name: 1})})
        with self._with_patched_context(scope=scope):
            yield self
        cmds = self._stack.pop()
        self._ranges.pop()
        if cmds:
            self._stack[-1].append(LinSpaceIter(body=tuple(cmds), length=len(rng)))

    @contextlib.contextmanager
    def time_reversed(self) -> Iterable['LinSpaceBuilder']:
        """Collect the nodes added inside the context and append their time reversed versions in reverse order."""
        self._stack.append([])
        yield self
        inner = self._stack.pop()
        # iterations opened outside of this context are not reversed
        offset = len(self._ranges)
        self._stack[-1].extend(node.reversed(offset, []) for node in reversed(inner))

    def to_program(self) -> Optional['LinSpaceProgram']:
        """Return the built program, or None if nothing was added."""
        if root := self._root():
            return LinSpaceProgram(
                root=tuple(root),
                defined_channels=self._idx_to_name,
            )
        else:
            return None
×
312

313

314
@dataclass
class LoopLabel:
    """Start of a loop whose body (up to the LoopJmp with the same ``idx``) is executed ``count`` times."""
    # identifier shared with the matching LoopJmp
    idx: int
    # total number of executions of the loop body
    count: int
6✔
318

319

320
@dataclass
class Increment:
    """Add ``value`` to the register of ``channel`` selected by ``dependency_key`` and output the new value."""
    channel: int
    value: float
    dependency_key: DepKey
6✔
325

326

327
@dataclass
class Set:
    """Output ``value`` on ``channel`` and store it in the register identified by ``key``."""
    channel: int
    value: float
    # the empty key is used for voltages that do not depend on any iteration
    key: DepKey = dataclasses.field(default_factory=lambda: DepKey(()))
6✔
332

333

334
@dataclass
class Wait:
    """Keep the current output values for ``duration``."""
    duration: TimeType
6✔
337

338

339
@dataclass
class LoopJmp:
    """End of the loop with the same ``idx``: jump back to the loop start while repetitions remain."""
    idx: int
6✔
342

343

344
@dataclass
class Play:
    """Play an arbitrary ``waveform`` on ``channels``."""
    waveform: Waveform
    # Variable length: holds the full channel tuple of the LinSpaceArbitraryWaveform node it is created from.
    # The previous annotation ``Tuple[ChannelID]`` denoted a tuple of exactly one channel.
    channels: Tuple[ChannelID, ...]
6✔
348

349

350
# Any command emitted by ``to_increment_commands`` and executed by ``LinSpaceVM``.
Command = Union[Increment, Set, LoopLabel, LoopJmp, Wait, Play]
6✔
351

352

353
@dataclass(frozen=True)
class DepState:
    """Snapshot of an iteration dependent register: the base voltage of the hold that last set it and the indices
    of all active iterations at that moment."""
    base: float
    iterations: Tuple[int, ...]

    def required_increment_from(self, previous: 'DepState', factors: Sequence[float]) -> float:
        """Compute the increment that moves the register from ``previous`` to this state.

        By convention every iteration index is either 0 or its last value, so each iteration contributes one of
        three increments: nothing, one regular step, or a jump back to the start.

        ``previous`` may have a different number of iterations if the trailing factors are zero now or were zero
        during the last iteration.

        Args:
            previous: The state to start from. It has to belong to the same DepKey.
            factors: Voltage factor per iteration; as many as this state has iterations.

        Returns:
            The increment
        """
        assert len(self.iterations) == len(factors)

        total = self.base - previous.base
        for factor, old, new in zip(factors, previous.iterations, self.iterations):
            if old == new:
                # still at the same point of this sweep
                continue
            if new > old:
                assert old == 0
                # Regular step. The new index is usually larger than 1, but the increment is applied repeatedly by
                # the loop, so a single factor suffices.
                total += factor
            else:
                assert new == 0
                # jump back to the start: undo all ``old`` steps
                total -= factor * old
        return total
6✔
397

398

399
@dataclass
class _TranslationState:
    """This is the state of a translation of a LinSpace program to a command sequence."""

    # next unused loop label index
    label_num: int = dataclasses.field(default=0)
    # commands emitted so far
    commands: List[Command] = dataclasses.field(default_factory=list)
    # current index of every active iteration (only 0 or the last index are used)
    iterations: List[int] = dataclasses.field(default_factory=list)
    # dependency key whose register currently drives each channel
    active_dep: Dict[int, DepKey] = dataclasses.field(default_factory=dict)
    # last known state of each (channel, dependency key) register
    dep_states: Dict[int, Dict[DepKey, DepState]] = dataclasses.field(default_factory=dict)
    # last iteration independent voltage set on each channel
    plain_voltage: Dict[int, float] = dataclasses.field(default_factory=dict)
    # resolution used to build dependency keys
    resolution: float = dataclasses.field(default_factory=lambda: DEFAULT_INCREMENT_RESOLUTION)

    def new_loop(self, count: int):
        """Create a label/jump pair with a fresh loop index."""
        label = LoopLabel(self.label_num, count)
        jmp = LoopJmp(self.label_num)
        self.label_num += 1
        return label, jmp

    def get_dependency_state(self, dependencies: Mapping[int, set]):
        """Return the set of current register states (None for registers not yet used) of the given dependencies."""
        return {
            self.dep_states.get(ch, {}).get(DepKey.from_voltages(dep, self.resolution), None)
            for ch, deps in dependencies.items()
            for dep in deps
        }

    def set_voltage(self, channel: int, value: float):
        """Emit a Set for an iteration independent voltage unless the channel already outputs it."""
        key = DepKey(())
        if self.active_dep.get(channel, None) != key or self.plain_voltage.get(channel, None) != value:
            self.commands.append(Set(channel, value, key))
            self.active_dep[channel] = key
            self.plain_voltage[channel] = value

    def _add_repetition_node(self, node: LinSpaceRepeat):
        dependencies = node.dependencies()
        pre_dep_state = self.get_dependency_state(dependencies)
        label, jmp = self.new_loop(node.count)
        initial_position = len(self.commands)
        self.commands.append(label)
        self.add_node(node.body)
        post_dep_state = self.get_dependency_state(dependencies)
        if pre_dep_state != post_dep_state:
            # The first pass changed the register states, so later passes start from a different state than the one
            # the first pass was translated for. Keep the first pass unrolled in front of the loop and translate the
            # body again (relative to the new state) as loop body for the remaining repetitions.
            self.commands.pop(initial_position)
            label.count -= 1
            if label.count == 0:
                # The unrolled pass was the only repetition. A loop with count 0 must not be emitted because the
                # loop body is always executed at least once, which would play the body twice.
                return
            self.commands.append(label)
            self.add_node(node.body)
        self.commands.append(jmp)

    def _add_iteration_node(self, node: LinSpaceIter):
        # First pass at index 0, then a loop for the remaining length - 1 indices. Inside the loop the iteration is
        # recorded at its last index so that the next required increment can be derived.
        self.iterations.append(0)
        self.add_node(node.body)

        if node.length > 1:
            self.iterations[-1] = node.length - 1
            label, jmp = self.new_loop(node.length - 1)
            self.commands.append(label)
            self.add_node(node.body)
            self.commands.append(jmp)
        self.iterations.pop()

    def _set_indexed_voltage(self, channel: int, base: float, factors: Sequence[float]):
        """Emit a Set or Increment that brings the (channel, dependency key) register to the given state."""
        dep_key = DepKey.from_voltages(voltages=factors, resolution=self.resolution)
        new_dep_state = DepState(
            base,
            iterations=tuple(self.iterations)
        )

        current_dep_state = self.dep_states.setdefault(channel, {}).get(dep_key, None)
        if current_dep_state is None:
            # a register is initialized with a Set, which is only correct at the iteration origin
            assert all(it == 0 for it in self.iterations)
            self.commands.append(Set(channel, base, dep_key))
            self.active_dep[channel] = dep_key

        else:
            inc = new_dep_state.required_increment_from(previous=current_dep_state, factors=factors)

            # we insert all inc here (also inc == 0) because it signals to activate this amplitude register
            if inc or self.active_dep.get(channel, None) != dep_key:
                self.commands.append(Increment(channel, inc, dep_key))
            self.active_dep[channel] = dep_key
        self.dep_states[channel][dep_key] = new_dep_state

    def _add_hold_node(self, node: LinSpaceHold):
        if node.duration_factors:
            raise NotImplementedError("TODO")

        for ch, (base, factors) in enumerate(zip(node.bases, node.factors)):
            if factors is None:
                self.set_voltage(ch, base)
            else:
                self._set_indexed_voltage(ch, base, factors)

        self.commands.append(Wait(node.duration_base))

    def add_node(self, node: Union[LinSpaceNode, Sequence[LinSpaceNode]]):
        """Translate a (sequence of) linspace node(s) to commands and add it to the internal command list."""
        if isinstance(node, Sequence):
            for lin_node in node:
                self.add_node(lin_node)

        elif isinstance(node, LinSpaceRepeat):
            self._add_repetition_node(node)

        elif isinstance(node, LinSpaceIter):
            self._add_iteration_node(node)

        elif isinstance(node, LinSpaceHold):
            self._add_hold_node(node)

        elif isinstance(node, LinSpaceArbitraryWaveform):
            self.commands.append(Play(node.waveform, node.channels))

        else:
            raise TypeError("The node type is not handled", type(node), node)
×
514

515

516
def to_increment_commands(linspace_nodes: 'LinSpaceProgram') -> List[Command]:
    """Translate a linspace program into a minimal flat sequence of set, increment, wait, play and loop commands.

    Args:
        linspace_nodes: Program whose ``root`` nodes are translated.

    Returns:
        The generated command list.
    """
    translation = _TranslationState()
    translation.add_node(linspace_nodes.root)
    return translation.commands
6✔
521

522

523
def transform_linspace_commands(command_list: List[Command],
                                channel_transformations: Mapping[ChannelID, 'ChannelTransformation'],
                                ) -> List[Command]:
    """Apply channel transformations to the Set and Increment commands, modifying the commands in place.

    Set values get the optional voltage transformation applied, then the offset subtracted, then are divided by
    the amplitude. Increment values are only divided by the amplitude because the offset cancels in a difference.

    Args:
        command_list: Commands to transform. The command objects are mutated.
        channel_transformations: Transformations whose iteration order is taken as channel index order.

    Returns:
        ``command_list`` itself.

    Raises:
        RuntimeError: If an Increment's channel has a voltage transformation.
        NotImplementedError: For unknown command types.
    """
    # NOTE(review): channel indices are matched to the mapping's iteration order; confirm that it follows the
    # program's channel order.
    trafos = list(channel_transformations.values())

    for cmd in command_list:
        if isinstance(cmd, Set):
            trafo = trafos[cmd.channel]
            if trafo.voltage_transformation:
                cmd.value = float(trafo.voltage_transformation(cmd.value))
            cmd.value -= trafo.offset
            cmd.value /= trafo.amplitude
        elif isinstance(cmd, Increment):
            trafo = trafos[cmd.channel]
            if trafo.voltage_transformation:
                raise RuntimeError("Cannot apply a voltage transformation to a linspace increment command")
            cmd.value /= trafo.amplitude
        elif isinstance(cmd, (LoopLabel, LoopJmp, Play, Wait)):
            # control flow and timing are unaffected; Play is handled by transforming the sampled waveform
            continue
        else:
            raise NotImplementedError(cmd)

    return command_list
×
548

549

550
def _get_waveforms_dict(transformed_commands: Sequence[Command]) -> Mapping[Waveform, Any]:
    """Collect the waveforms of all Play commands without duplicates, in order of first occurrence.

    All values of the returned mapping are None.
    """
    played = (cmd.waveform for cmd in transformed_commands if isinstance(cmd, Play))
    return OrderedDict.fromkeys(played)
553

554

555
@dataclass
class LinSpaceProgram(Program):
    """Program produced by :class:`LinSpaceBuilder`: the root nodes and the channels they are defined on."""
    root: Tuple[LinSpaceNode, ...]
    defined_channels: Tuple[ChannelID, ...]

    @property
    def duration(self) -> TimeType:
        """Total duration of the program.

        Raises:
            NotImplementedError: If a hold has an iteration dependent duration.
        """
        def node_duration(node: LinSpaceNode) -> TimeType:
            if isinstance(node, LinSpaceHold):
                if node.duration_factors:
                    raise NotImplementedError("The duration of iteration dependent holds is not implemented yet.")
                return node.duration_base
            if isinstance(node, LinSpaceArbitraryWaveform):
                return node.waveform.duration
            if isinstance(node, LinSpaceRepeat):
                # NOTE(review): assumes count supports multiplication with TimeType (volatile counts untested)
                return node.count * body_duration(node.body)
            if isinstance(node, LinSpaceIter):
                # Iteration dependent durations raise above, so every iteration of the body takes equally long.
                return node.length * body_duration(node.body)
            raise TypeError("The node type is not handled", type(node), node)

        def body_duration(nodes) -> TimeType:
            return sum((node_duration(node) for node in nodes), TimeType(0))

        return body_duration(self.root)

    def get_defined_channels(self) -> AbstractSet[ChannelID]:
        return set(self.defined_channels)

    def dependencies(self):
        """Merge the channel dependencies of all root nodes."""
        dependencies = {}
        for node in self.root:
            for idx, deps in node.dependencies().items():
                dependencies.setdefault(idx, set()).update(deps)
        return dependencies

    def get_waveforms_dict(self,
                           channels: Sequence[ChannelID],  # NOTE: currently unused
                           channel_transformations: Mapping[ChannelID, 'ChannelTransformation'],
                           ) -> Mapping[Waveform, Any]:
        """Translate the program, apply the channel transformations and collect the played waveforms."""
        commands = to_increment_commands(self)
        commands_transformed = transform_linspace_commands(commands, channel_transformations)
        return _get_waveforms_dict(commands_transformed)
×
581
        
582

583
class LinSpaceVM:
    """Reference interpreter for linspace command lists.

    The output is recorded in ``history`` as ``(time, values)`` entries. A Wait adds an entry unless the output is
    unchanged since the last entry; a Play adds one entry per waveform sample.
    """

    def __init__(self, channels: int,
                 sample_resolution: TimeType = TimeType.from_fraction(1, 2)):
        """
        Args:
            channels: Number of output channels.
            sample_resolution: Sample spacing used when playing arbitrary waveforms.
        """
        # current output of every channel; NaN until the channel is set
        self.current_values = [np.nan] * channels
        self.sample_resolution = sample_resolution
        self.time = TimeType(0)
        # per channel: register value per dependency key
        self.registers = tuple(dict() for _ in range(channels))

        self.history: List[Tuple[TimeType, List[float]]] = []

        # program state, initialized by set_commands
        self.commands = None
        self.label_targets = None
        self.label_counts = None
        self.current_command = None

    def _play_arbitrary(self, play: Play):
        """Play an arbitrary waveform by sampling it on a grid of ``self.sample_resolution``.

        Each sample is treated like a Set of all channels followed by a Wait of one sample period, or of the
        remaining time for the last sample.

        Args:
            play: The waveform to play
        """
        start_time = copy.copy(self.time)

        period = self.sample_resolution
        duration = play.waveform.duration
        # ceil so that the grid covers the whole duration; the last sample may be held shorter
        sample_count = math.ceil(duration / period)
        grid = np.array([period * k for k in range(sample_count)], dtype=np.float64)
        channel_samples = [play.waveform.get_sampled(channel=ch, sample_times=grid) for ch in play.channels]

        end_time = self.time + duration
        for values in zip(*channel_samples):
            # deliberately explicit per-sample "set" and "wait", not efficient
            self.current_values[:] = values
            self.history.append((self.time, self.current_values.copy()))
            self.time += min(period, end_time - self.time)

        assert self.time == start_time + duration

    def change_state(self, cmd: Union[Set, Increment, Wait, Play]):
        """Execute a single non-control-flow command."""
        if isinstance(cmd, Set):
            self.current_values[cmd.channel] = cmd.value
            self.registers[cmd.channel][cmd.key] = cmd.value
        elif isinstance(cmd, Increment):
            channel_registers = self.registers[cmd.channel]
            new_value = channel_registers[cmd.dependency_key] + cmd.value
            channel_registers[cmd.dependency_key] = new_value
            self.current_values[cmd.channel] = new_value
        elif isinstance(cmd, Wait):
            # skip entries that would not change the recorded output
            if not (self.history and self.history[-1][1] == self.current_values):
                self.history.append((self.time, self.current_values.copy()))
            self.time += cmd.duration
        elif isinstance(cmd, Play):
            self._play_arbitrary(cmd)
        else:
            raise NotImplementedError(cmd)

    def set_commands(self, commands: Sequence[Command]):
        """Load a command list and reset the program counter."""
        self.commands = []
        self.label_targets = {}
        self.label_counts = {}
        self.current_command = None

        for position, cmd in enumerate(commands):
            self.commands.append(cmd)
            if isinstance(cmd, LoopLabel):
                # The label itself (re)sets the loop counter; jumps target the command following it.
                assert cmd.idx not in self.label_targets
                self.label_targets[cmd.idx] = position + 1

        self.current_command = 0

    def step(self):
        """Execute the command at the program counter and advance it."""
        cmd = self.commands[self.current_command]
        next_command = self.current_command + 1
        if isinstance(cmd, LoopLabel):
            # the first pass is starting now, so count - 1 jumps back remain
            self.label_counts[cmd.idx] = cmd.count - 1
        elif isinstance(cmd, LoopJmp):
            if self.label_counts[cmd.idx] > 0:
                self.label_counts[cmd.idx] -= 1
                next_command = self.label_targets[cmd.idx]
            # otherwise the jump is ignored
        else:
            self.change_state(cmd)
        self.current_command = next_command

    def run(self):
        """Execute commands until the end of the command list."""
        while self.current_command < len(self.commands):
            self.step()
6✔
691

692

STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc