• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

int-brain-lab / bpod-core / 16678188595

01 Aug 2025 02:53PM UTC coverage: 76.463%. First build
16678188595

Pull #35

github

web-flow
Merge d8e3a2d98 into 2e1e7643f
Pull Request #35: Uv

3 of 6 new or added lines in 2 files covered. (50.0%)

627 of 820 relevant lines covered (76.46%)

9.18 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

71.36
/bpod_core/bpod.py
1
"""Module for interfacing with the Bpod Finite State Machine."""
2

3
import logging
12✔
4
import re
12✔
5
import struct
12✔
6
import weakref
12✔
7
from abc import ABC, abstractmethod
12✔
8
from collections.abc import Callable
12✔
9
from dataclasses import dataclass, field
12✔
10
from threading import Event as TreadingEvent
12✔
11
from threading import Thread
12✔
12
from types import TracebackType
12✔
13
from typing import NamedTuple
12✔
14

15
import numpy as np
12✔
16
from numpy.typing import NDArray
12✔
17
from pydantic import validate_call
12✔
18
from serial import SerialException
12✔
19
from serial.tools.list_ports import comports
12✔
20
from typing_extensions import Self
12✔
21

22
from bpod_core import __version__ as bpod_core_version
12✔
23
from bpod_core.com import ExtendedSerial
12✔
24
from bpod_core.fsm import StateMachine
12✔
25
from bpod_core.misc import suggest_similar
12✔
26

27
PROJECT_NAME = 'bpod-core'

# USB vendor IDs of supported Bpod devices
VENDOR_IDS_BPOD = [0x16C0]

# Version constraints enforced when connecting to a device
MIN_BPOD_FW_VERSION = (23, 0)  # minimum supported firmware version (major, minor)
MIN_BPOD_HW_VERSION = 3  # minimum supported hardware version (machine type)
MAX_BPOD_HW_VERSION = 4  # maximum supported hardware version (machine type)

# Single-byte channel type codes, as used in the Bpod's input description array,
# mapped to human-readable channel names
CHANNEL_TYPES_INPUT = {
    b'U': 'Serial',
    b'X': 'SoftCode',
    b'Z': 'SoftCodeApp',
    b'F': 'Flex',
    b'D': 'Digital',
    b'B': 'BNC',
    b'W': 'Wire',
    b'P': 'Port',
}

# Output channels share the input codes, add valves, and use b'P' for PWM
# outputs instead of ports (b'P' keeps its position, b'V' is appended last)
CHANNEL_TYPES_OUTPUT = CHANNEL_TYPES_INPUT | {b'V': 'Valve', b'P': 'PWM'}

N_SERIAL_EVENTS_DEFAULT = 15  # default number of events per connected serial module
VALID_OPERATORS = ['exit', '>exit', '>back']  # special state transition operators

logger = logging.getLogger(__name__)
48

49

50
class VersionInfo(NamedTuple):
12✔
51
    """Represents the Bpod's on-board hardware configuration."""
52

53
    firmware: tuple[int, int]
12✔
54
    """Firmware version (major, minor)"""
12✔
55
    machine: int
12✔
56
    """Machine type (numerical)"""
12✔
57
    pcb: int | None
12✔
58
    """PCB revision, if applicable"""
12✔
59

60

61
class HardwareConfiguration(NamedTuple):
    """
    Onboard hardware configuration of a Bpod.

    All fields except the last two are values reported by the device.
    `cycle_frequency` and `n_modules` are derived values.
    """

    max_states: int
    """Upper limit on the number of states in one state machine description."""
    cycle_period: int
    """Duration of one refresh cycle of the state machine during a trial (µs)."""
    max_serial_events: int
    """Total number of behavior events that can be allocated among modules."""
    max_bytes_per_serial_message: int
    """Upper limit on the number of bytes in a single serial message."""
    n_global_timers: int
    """How many global timers the device supports."""
    n_global_counters: int
    """How many global counters the device supports."""
    n_conditions: int
    """How many condition-events the device supports."""
    n_inputs: int
    """Length of the input channel description array."""
    input_description: bytes
    """One type code per onboard input channel of the state machine."""
    n_outputs: int
    """Length of the output channel description array."""
    output_description: bytes
    """One type code per onboard output channel of the state machine."""
    cycle_frequency: int
    """Refresh rate of the state machine during a trial (Hz)."""
    n_modules: int
    """How many modules the state machine supports."""
90

91

92
class BpodError(Exception):
    """
    Raised when the Bpod device, or an operation on it, fails.

    Examples include an unsupported device, a failed handshake, or a serial port
    that cannot be detected.
    """
99

100

101
class FSMThread(Thread):
    """A thread for managing the execution of a finite state machine on the Bpod."""

    _struct_start = struct.Struct('<Q')  # start time of the state machine (uInt64)
    _struct_cycles = struct.Struct('<I')  # cycle count of an event (uInt32)
    _struct_exit = struct.Struct('<IQ')  # exit: cycles (uInt32) and micros (uInt64)

    def __init__(  # noqa: PLR0913
        self,
        serial: ExtendedSerial,
        fsm_index: int,
        confirm_fsm: bool,
        cycle_period: int,
        softcode_handler: Callable,
        state_transitions: NDArray[np.uint8],
        use_back_op: bool,
    ) -> None:
        """
        Initialize the FSMThread.

        Parameters
        ----------
        serial : ExtendedSerial
            The serial connection to the Bpod device.
        fsm_index : int
            The index of the FSM being managed.
        confirm_fsm : bool
            Whether to wait for the Bpod to confirm the state machine (a single
            0x01 byte) before reading its start time.
        cycle_period : int
            The cycle period of the Bpod device in microseconds.
        softcode_handler : Callable
            A handler function, called with the softcode number for each softcode
            received from the Bpod.
        state_transitions : np.ndarray
            The state transition matrix, indexed as ``[current_state][event]``.
            A target equal to the number of rows denotes the virtual exit state;
            a target of 255 denotes the ``>back`` operator.
        use_back_op : bool
            Whether the state machine makes use of the `>back` operator
        """
        super().__init__()
        self.daemon = True
        self.serial = serial
        self._stop_event = TreadingEvent()
        self._index = fsm_index
        self._confirm_fsm = confirm_fsm
        self._cycle_period = cycle_period
        self._softcode_handler = softcode_handler
        self._state_transitions = state_transitions
        self._use_back_op = use_back_op

    def stop(self) -> None:
        """
        Request the reading loop in :meth:`run` to terminate.

        The flag is checked before each read of the next opcode pair. A serial
        read that is already blocking is not interrupted.
        """
        self._stop_event.set()

    def run(self) -> None:
        """
        Execute the FSMThread.

        Reads the state machine's data stream from the Bpod until an exit event
        (event code 255) is received or :meth:`stop` is called. Each message
        starts with an opcode byte and a parameter byte:

        - opcode 1: ``param`` event codes followed by the cycle count (uInt32),
          from which the event timestamp is derived;
        - opcode 2: softcode ``param``, forwarded to the softcode handler.

        Finally, the cycle count (uInt32) and end time in µs (uInt64) of the
        state machine are read.

        Raises
        ------
        RuntimeError
            If the state machine is not confirmed by the Bpod or an unknown
            opcode is received.
        """
        # assign members to local variables to avoid repeated attribute lookups
        serial = self.serial
        index = self._index
        cycle_period = self._cycle_period
        struct_cycles = self._struct_cycles
        softcode_handler = self._softcode_handler
        state_transitions = self._state_transitions
        previous_state = np.uint8(0)
        current_state = np.uint8(0)
        target_exit = np.uint8(state_transitions.shape[0])
        target_back = np.uint8(255)
        use_back_op = self._use_back_op

        # create buffers for repeated serial reads
        opcode_buf = bytearray(2)  # buffer for opcodes
        event_data_buf = bytearray(259)  # max 255 events + 4 bytes for n_cycles

        # should we use debug logging?
        debug = logger.isEnabledFor(logging.DEBUG)

        # confirm the state machine
        if self._confirm_fsm:
            if serial.read(1) != b'\x01':
                # Format eagerly: exceptions do not interpolate %-style arguments
                raise RuntimeError(f'State machine #{index} was not confirmed by Bpod')
            if debug:
                logger.debug('State machine #%d confirmed by Bpod', index)

        # read the start time of the state machine (uInt64)
        t0 = self._struct_start.unpack(serial.read(8))[0]
        if debug:
            logger.debug('%d µs: Starting state machine #%d', t0, index)
            logger.debug('%d µs: State %d', t0, current_state)
        # TODO: handle start of state machine
        # TODO: handle start of state

        # enter the reading loop
        while not self._stop_event.is_set():
            # read the next two opcodes
            serial.readinto(opcode_buf)
            opcode, param = opcode_buf

            if opcode == 1:  # handle events
                # read `param` event bytes + 4 bytes for n_cycles (uInt32)
                event_data_view = memoryview(event_data_buf)[: param + 4]
                serial.readinto(event_data_view)

                # unpack the number of cycles, calculate the event's timestamp
                n_cycles = struct_cycles.unpack_from(event_data_view, param)[0]
                micros = t0 + n_cycles * cycle_period

                # handle each event
                events = event_data_view[:param]
                for event in events:
                    if debug:
                        logger.debug('%d µs: Event %d', micros, event)
                    # TODO: handle event

                # handle state transitions / exit event
                for event in events:
                    if event == 255:  # exit event
                        self.stop()
                        break
                    target_state = state_transitions[current_state][event]
                    if target_state == current_state:  # no transition
                        continue
                    if target_state == target_exit:  # virtual exit state
                        # TODO: handle end of state
                        break
                    if target_state == target_back and use_back_op:  # back
                        target_state = previous_state
                    # TODO: handle end of state
                    previous_state = current_state
                    current_state = target_state
                    # TODO: handle start of state
                    if debug:
                        logger.debug('%d µs: State %d', micros, current_state)
                    break  # only handle the first state transition

            elif opcode == 2:  # handle softcodes
                if debug:
                    logger.debug('Softcode %d', param)
                softcode_handler(param)

            else:
                raise RuntimeError(f'Unknown opcode: {opcode}')

        # exit state machine
        # read 12 bytes: cycles (uInt32) and micros (uInt64)
        cycles, micros = self._struct_exit.unpack(serial.read(12))
        if debug:
            logger.debug(
                '%d µs: Ending state machine #%d (%d cycles)', micros, index, cycles
            )
        # TODO: handle end of state machine
248

249

250
class Bpod:
12✔
251
    """Bpod class for interfacing with the Bpod Finite State Machine."""
252

253
    _version: VersionInfo
12✔
254
    _hardware: HardwareConfiguration
12✔
255
    _fsm_thread: FSMThread | None = None
12✔
256
    _next_fsm_index: int = -1
12✔
257
    _serial_buffer = bytearray()  # buffer for TrialReader thread
12✔
258
    serial0: ExtendedSerial
12✔
259
    """Primary serial device for communication with the Bpod."""
12✔
260
    serial1: ExtendedSerial | None = None
12✔
261
    """Secondary serial device for communication with the Bpod."""
12✔
262
    serial2: ExtendedSerial | None = None
12✔
263
    """Tertiary serial device for communication with the Bpod - used by Bpod 2+ only."""
12✔
264
    inputs: NamedTuple
12✔
265
    """Available input channels."""
12✔
266
    outputs: NamedTuple
12✔
267
    """Available output channels."""
12✔
268
    modules: NamedTuple
12✔
269
    """Available modules."""
12✔
270
    event_names: list[str]
12✔
271
    """List of event names."""
12✔
272
    output_actions: list[str]
12✔
273
    """List of output actions."""
12✔
274

275
    @validate_call
    def __init__(
        self, port: str | None = None, serial_number: str | None = None
    ) -> None:
        """
        Connect to a Bpod Finite State Machine and query its configuration.

        Parameters
        ----------
        port : str | None, optional
            Serial port of the Bpod's primary USB-serial device.
        serial_number : str | None, optional
            USB serial number of the Bpod. If neither `port` nor `serial_number`
            is given, an available Bpod is detected automatically.

        Raises
        ------
        BpodError
            If no suitable Bpod is found, the handshake fails, or the device's
            hardware / firmware version is not supported.
        SerialException
            If the primary serial port could not be opened.
        """
        # NOTE(review): `self.close` is a bound method of `self`, so this finalizer
        # keeps the instance reachable. Per the `weakref.finalize` docs it will then
        # only run at interpreter exit, not when the instance is garbage collected.
        weakref.finalize(self, self.close)
        logger.info('bpod_core %s', bpod_core_version)

        # per-instance state
        self.event_names = []
        self.output_actions = []
        self._waiting_for_confirmation = False
        self._state_transitions: NDArray[np.uint8] = np.empty((0, 255), dtype=np.uint8)
        self._use_back_op = False

        # locate the device, then open its primary serial port (incl. handshake)
        device_port, self._serial_number = self._identify_bpod(port, serial_number)
        self.serial0 = ExtendedSerial()
        self.serial0.port = device_port
        self.open()

        # query versions (enforcing requirements) and the hardware configuration,
        # then set up I/O channels, additional serial ports and modules
        self._get_version_info()
        self._get_hardware_configuration()
        self._configure_io()
        self._detect_additional_serial_ports()
        self.update_modules()

        # report hardware information, but only if it would actually be logged
        if not logger.isEnabledFor(logging.INFO):
            return
        machine_names = {3: 'r2.0-2.5', 4: '2+ r1.0'}
        machine = machine_names.get(self.version.machine, 'unknown')
        logger.info(
            'Connected to Bpod Finite State Machine %s on %s', machine, self.port
        )
        logger.info(
            'Firmware Version %d.%d, Serial Number %s, PCB Revision %d',
            *self.version.firmware,
            self._serial_number,
            self.version.pcb,
        )
324

325
    def __enter__(self) -> Self:
        """
        Enter the runtime context.

        The connection is already opened during construction, so this instance
        is returned unchanged.

        Returns
        -------
        Self
            This Bpod instance.
        """
        return self
328

329
    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """
        Leave the runtime context and close the connection to the Bpod.

        Exceptions raised inside the ``with`` block are propagated, not suppressed.
        """
        self.close()
337

338
    def _sends_discovery_byte(
        self,
        port: str,
        byte: bytes = b'\xde',
        timeout: float = 0.11,
        trigger: bytes | None = None,
    ) -> bool:
        r"""Tell whether the device on `port` emits the expected discovery byte.

        The port is opened temporarily. If `trigger` is given and a primary serial
        device already exists, the trigger is written to ``serial0`` first. Then a
        single byte is read from `port` and compared against `byte`.

        Parameters
        ----------
        port : str
            Name of the serial port to probe (e.g., '/dev/ttyUSB0' or 'COM3').
        byte : bytes, optional
            Discovery byte the device is expected to send. Defaults to b'\xde'.
        timeout : float, optional
            Read timeout in seconds for the probe. Defaults to 0.11.
        trigger : bytes, optional
            Command written to serial0 before reading from the probed device.

        Returns
        -------
        bool
            True if exactly the expected byte was received, False otherwise,
            including when the port cannot be opened or read.
        """
        try:
            with ExtendedSerial(port, timeout=timeout) as device:
                primary = getattr(self, 'serial0', None)
                if trigger is not None and primary is not None:
                    primary.write(trigger)
                received = device.read(1)
        except SerialException:
            return False
        return received == byte
370

371
    def _identify_bpod(
        self,
        port: str | None = None,
        serial_number: str | None = None,
    ) -> tuple[str, str | None]:
        """
        Locate a supported Bpod by port name, serial number or auto-detection.

        Without `port` and `serial_number`, the first serial device that has a
        Bpod vendor ID and sends a discovery byte is used. A given
        `serial_number` takes precedence over `port`, and the matching device
        must also send a discovery byte. A given `port` only has to exist. In
        both of these cases the device's vendor ID is checked afterwards.

        Parameters
        ----------
        port : str | None, optional
            Name of the serial port (e.g., '/dev/ttyUSB0' or 'COM3').
        serial_number : str | None, optional
            USB serial number of the device.

        Returns
        -------
        str
            The port of the identified device.
        str | None
            The USB serial number of the identified device.

        Raises
        ------
        BpodError
            If no Bpod is found or the indicated device is not supported.
        """
        if port is None and serial_number is None:
            # auto-detection: any idle device with a Bpod vendor ID
            def is_idle_bpod(candidate) -> bool:
                has_bpod_vid = getattr(candidate, 'vid', None) in VENDOR_IDS_BPOD
                return has_bpod_vid and self._sends_discovery_byte(candidate.device)

            try:
                found = next(filter(is_idle_bpod, comports()))
            except StopIteration as e:
                raise BpodError('No available Bpod found') from e
            return found.device, found.serial_number

        if serial_number is not None:
            # match by USB serial number; the device must announce itself
            def matches(candidate) -> bool:
                return candidate.serial_number == serial_number and (
                    self._sends_discovery_byte(candidate.device)
                )

            error_message = f'No device with serial number {serial_number}'
        else:
            # match by port name; the device only has to exist
            def matches(candidate) -> bool:
                return candidate.device == port

            error_message = f'Port not found: {port}'

        try:
            found = next(filter(matches, comports()))
        except (StopIteration, AttributeError) as e:
            raise BpodError(error_message) from e

        if found.vid not in VENDOR_IDS_BPOD:
            raise BpodError('Device is not a supported Bpod')
        return found.device, found.serial_number
436

437
    def _get_version_info(self) -> None:
        """
        Query and validate the Bpod's firmware version and machine type.

        The firmware major version and machine type are always queried. The minor
        version and PCB revision are only queried for firmware major versions
        above 22; otherwise they default to 0 and None. The result is stored as a
        :class:`VersionInfo` in ``_version``.

        Raises
        ------
        BpodError
            If the machine type lies outside the supported hardware range or the
            firmware is older than the minimum supported version.
        """
        logger.debug('Retrieving version information')
        query = self.serial0.query_struct

        major, machine_type = query(b'F', '<2H')
        has_extended_info = major > 22
        minor = query(b'f', '<H')[0] if has_extended_info else 0
        firmware = (major, minor)

        if machine_type < MIN_BPOD_HW_VERSION or machine_type > MAX_BPOD_HW_VERSION:
            raise BpodError(
                f'The hardware version of the Bpod on {self.port} is not supported.',
            )
        if firmware < MIN_BPOD_FW_VERSION:
            raise BpodError(
                f'The Bpod on {self.port} uses firmware v{major}.{minor} '
                f'which is not supported. Please update the device to firmware '
                f'v{MIN_BPOD_FW_VERSION[0]}.{MIN_BPOD_FW_VERSION[1]} or later.',
            )

        pcb_revision = query(b'v', '<B')[0] if has_extended_info else None
        self._version = VersionInfo(firmware, machine_type, pcb_revision)
467

468
    def _get_hardware_configuration(self) -> None:
        """
        Retrieve the Bpod's onboard hardware configuration.

        The fixed-size part is queried with ``b'H'``. The variable-length input
        and output description arrays are read afterwards. Cycle frequency and
        module count are derived, and everything is stored as a
        :class:`HardwareConfiguration` in ``_hardware``.
        """
        logger.debug('Retrieving onboard hardware configuration')
        serial = self.serial0

        # fixed-size part of the configuration, ending with n_inputs
        if self.version.firmware <= (22, 0):
            # this layout lacks max bytes per serial message; it is always 3
            fields = list(serial.query_struct(b'H', '<2H5B'))
            fields.insert(-4, 3)
        else:
            fields = list(serial.query_struct(b'H', '<2H6B'))

        # variable-length channel descriptions
        input_description, n_outputs = serial.read_struct(f'<{fields[-1]}s1B')
        (output_description,) = serial.read_struct(f'<{n_outputs}s')

        self._hardware = HardwareConfiguration(
            *fields,
            input_description,
            n_outputs,
            output_description,
            1000000 // fields[1],  # cycle_frequency (Hz) from cycle_period (µs)
            input_description.count(b'U'),  # n_modules: serial input channels
        )
488

489
    def _configure_io(self) -> None:
        """
        Configure the input and output channels of the Bpod.

        One channel object (:class:`Input` / :class:`Output`) is created per entry
        of the input and output description arrays. Each channel is named after
        its type plus a 1-based running index per type (e.g. 'Port1', 'Port2').
        The channels are stored as NamedTuples in ``inputs`` and ``outputs``, and
        the inputs' enabled state is then sent to the Bpod.

        Raises
        ------
        RuntimeError
            If a description array contains an unknown channel type.
        """
        logger.debug('Configuring I/O')
        hardware = self._hardware
        for description, channel_class, channel_names in (
            (hardware.input_description, Input, CHANNEL_TYPES_INPUT),
            (hardware.output_description, Output, CHANNEL_TYPES_OUTPUT),
        ):
            attr_name = f'{channel_class.__name__.lower()}s'
            io_keys = struct.unpack(f'<{len(description)}c', description)
            occurrences: dict[bytes, int] = {}
            channels = []
            fields = []

            for idx, io_key in enumerate(io_keys):
                if io_key not in channel_names:
                    raise RuntimeError(f'Unknown {attr_name[:-1]} type: {io_key}')
                occurrences[io_key] = occurrences.get(io_key, 0) + 1
                name = f'{channel_names[io_key]}{occurrences[io_key]}'
                channels.append(channel_class(self, name, io_key, idx))
                fields.append((name, channel_class))

            # expose the channels as a NamedTuple attribute (`inputs` / `outputs`)
            setattr(self, attr_name, NamedTuple(attr_name, fields)._make(channels))

        # send the enabled state of the input channels to the Bpod
        self._set_enable_inputs()
516

517
    def _detect_additional_serial_ports(self) -> None:
        """
        Detect the Bpod's additional USB-serial ports.

        Candidate ports share the primary port's USB serial number.
        - The secondary port (``serial1``) is the first candidate that sends
          discovery byte 222 after ``b'{'`` is written to the primary port.
        - On machine type 4 (State Machine 2+), the tertiary port (``serial2``)
          is the first candidate that sends 223 after ``b'}'``.

        Raises
        ------
        BpodError
            If a required secondary or tertiary serial port cannot be detected.
        """
        logger.debug('Detecting additional USB-serial ports')

        # Candidates already share the Bpod's USB serial number, so there is no
        # need to exclude devices that send a discovery byte on their own.
        candidate_ports = [
            info.device
            for info in comports()
            if info.serial_number == self._serial_number and info.device != self.port
        ]

        def first_responding(discovery_byte: int, trigger: bytes) -> str | None:
            # first candidate answering `trigger` with `discovery_byte`, if any
            return next(
                (
                    candidate
                    for candidate in candidate_ports
                    if self._sends_discovery_byte(
                        candidate, bytes([discovery_byte]), trigger=trigger
                    )
                ),
                None,
            )

        # Secondary USB-serial port
        if self._version.firmware >= (23, 0):
            secondary = first_responding(222, b'{')
            if secondary is not None:
                self.serial1 = ExtendedSerial()
                self.serial1.port = secondary
                candidate_ports.remove(secondary)
                logger.debug('Detected secondary USB-serial port: %s', secondary)
            if self.serial1 is None:
                raise BpodError('Could not detect secondary serial port')

        # State Machine 2+ uses a third USB-serial port for FlexIO
        if self.version.machine == 4:
            tertiary = first_responding(223, b'}')
            if tertiary is not None:
                self.serial2 = ExtendedSerial()
                self.serial2.port = tertiary
                logger.debug('Detected tertiary USB-serial port: %s', tertiary)
            if self.serial2 is None:
                raise BpodError('Could not detect tertiary serial port')
557

558
    def _handshake(self) -> None:
        """
        Perform a handshake with the Bpod.

        Calls ``serial0.verify(b'6', b'5')`` with a temporary read timeout of
        0.2 s, which is reset to ``None`` on success. The input buffer is flushed
        afterwards in every case.

        Raises
        ------
        BpodError
            If the handshake fails, including when a ``SerialException`` occurs.
        """
        try:
            self.serial0.timeout = 0.2
            if not self.serial0.verify(b'6', b'5'):
                raise BpodError(f'Handshake with device on {self.port} failed')
            # NOTE(review): on failure the 0.2 s timeout is left in place
            self.serial0.timeout = None
        except SerialException as e:
            raise BpodError(f'Handshake with device on {self.port} failed') from e
        finally:
            self.serial0.reset_input_buffer()
        logger.debug('Handshake with Bpod on %s successful', self.port)
577

578
    def _test_psram(self) -> bool:
        """
        Run the Bpod's PSRAM self-test via the ``b'_'`` command.

        Returns
        -------
        bool
            Whether the PSRAM test passed.
        """
        passed = self.serial0.verify(b'_')
        return passed
588

589
    def _set_enable_inputs(self) -> bool:
        """
        Send the enabled state of every input channel to the Bpod.

        The ``b'E'`` command is followed by one boolean per input channel.

        Returns
        -------
        bool
            True if the Bpod acknowledged the update with a 0x01 byte.
        """
        logger.debug('Updating enabled state of input channels')
        flags = [channel.enabled for channel in self.inputs]
        fmt = f'<c{self._hardware.n_inputs}?'
        self.serial0.write_struct(fmt, b'E', *flags)
        acknowledged = self.serial0.read(1) == b'\x01'
        return acknowledged
594

595
    def reset_session_clock(self) -> bool:
        """
        Reset the Bpod's session clock via the ``b'*'`` command.

        Returns
        -------
        bool
            Result of verifying the command with the Bpod (presumably True when
            the reset was acknowledged; confirm against ``ExtendedSerial.verify``).
        """
        logger.debug('Resetting session clock')
        acknowledged = self.serial0.verify(b'*')
        return acknowledged
598

599
    def _disable_all_module_relays(self) -> None:
        """Call ``set_relay(False)`` on every module in :attr:`modules`."""
        for each_module in self.modules:
            each_module.set_relay(False)
602

603
    def _compile_event_names(self) -> None:
        """
        Compile the list of event names supported by the Bpod hardware.

        Event names appear in this order:
        - names for the input channels, in the order of the input description
          array (Digital channels contribute no events);
        - global timer start/end events, global counter events and conditions;
        - 'Tup'.

        Serial events not claimed by connected modules are split evenly among the
        USB softcode channels (``b'X'`` and ``b'Z'``). The result is stored in
        :attr:`event_names`.
        """
        hardware = self._hardware
        n_serial_events = sum(len(m.event_names) for m in self.modules)
        n_softcodes = hardware.max_serial_events - n_serial_events
        n_usb = hardware.input_description.count(b'X')
        n_usb_ext = hardware.input_description.count(b'Z')
        n_usb_channels = n_usb + n_usb_ext
        # Avoid a ZeroDivisionError on devices without any USB softcode channel.
        # In that case the per-channel count is never used below.
        n_softcodes_per_usb = n_softcodes // n_usb_channels if n_usb_channels else 0
        n_app_softcodes = n_usb_ext * n_softcodes_per_usb
        self.event_names = []

        # Compile event names for input channels
        counters = dict.fromkeys(CHANNEL_TYPES_INPUT, 0)
        for io_key in [bytes([x]) for x in hardware.input_description]:
            name = CHANNEL_TYPES_INPUT[io_key]
            if io_key == b'U':  # Serial: event names are provided by the module
                names = self.modules[counters[io_key]].event_names
            elif io_key == b'X':  # SoftCode
                names = (f'{name}{i + 1}' for i in range(n_softcodes_per_usb))
            elif io_key == b'Z':  # SoftCodeApp
                names = (f'{name}{i + 1}' for i in range(n_app_softcodes))
            elif io_key == b'F':  # Flex
                names = (f'{name}{counters[io_key] + 1}_{i + 1}' for i in range(2))
            elif io_key in b'PBW':  # Port, BNC, Wire
                names = (f'{name}{counters[io_key] + 1}_{s}' for s in ('High', 'Low'))
            else:  # Digital
                continue
            self.event_names.extend(names)
            counters[io_key] += 1

        # Add global timers, global counters, conditions and 'Tup'
        for event_name, n in [
            ('GlobalTimer{}_Start', hardware.n_global_timers),
            ('GlobalTimer{}_End', hardware.n_global_timers),
            ('GlobalCounter{}_End', hardware.n_global_counters),
            ('Condition{}', hardware.n_conditions),
        ]:
            self.event_names.extend(event_name.format(i + 1) for i in range(n))
        self.event_names.append('Tup')
641

642
    def _compile_output_actions(self) -> None:
12✔
643
        """Compile the list of output actions supported by the Bpod hardware."""
644
        self.output_actions = []
12✔
645

646
        # Compile actions for output channels
647
        counters = dict.fromkeys(CHANNEL_TYPES_OUTPUT, 0)
12✔
648
        for io_key in [bytes([x]) for x in self._hardware.output_description]:
12✔
649
            if io_key == b'U':  # Serial
12✔
650
                name = self.modules[counters[io_key]].name
12✔
651
            elif io_key in b'XZ':  # SoftCode, SoftCodeApp
12✔
652
                name = CHANNEL_TYPES_OUTPUT[io_key]
12✔
653
            elif io_key in b'FVPBW':  # Flex, Valve, PWM, BNC, Wire
12✔
654
                name = f'{CHANNEL_TYPES_OUTPUT[io_key]}{counters[io_key] + 1}'
12✔
655
            else:
656
                continue
×
657
            self.output_actions.append(name)
12✔
658
            counters[io_key] += 1
12✔
659

660
        # Add output actions for global timers, global counters and analog thresholds
661
        self.output_actions.extend(
12✔
662
            ['GlobalTimerTrig', 'GlobalTimerCancel', 'GlobalCounterReset'],
663
        )
664
        if self.version.machine == 4:
12✔
665
            self.output_actions.extend(['AnalogThreshEnable', 'AnalogThreshDisable'])
12✔
666

667
    @property
    def port(self) -> str | None:
        """Name of the serial port used by the primary serial device ``serial0``."""
        primary = self.serial0
        return primary.port
671

672
    @property
    def version(self) -> VersionInfo:
        """Firmware version, machine type and PCB revision of the connected Bpod."""
        info = self._version
        return info
676

677
    def open(self) -> None:
        """
        Open the connection to the Bpod.

        Does nothing if the primary serial port is already open. Otherwise the
        port is opened and a handshake with the device is performed.

        Raises
        ------
        SerialException
            If the port could not be opened.
        BpodError
            If the handshake fails.
        """
        if self.serial0.is_open:
            return
        self.serial0.open()
        self._handshake()
692

693
    def close(self) -> None:
        """
        Close the connection to the Bpod.

        Calls :meth:`stop_state_machine` first. The primary serial port is only
        touched if ``serial0`` exists and is open: ``b'Z'`` is written to the
        device, then the port is closed.
        """
        self.stop_state_machine()
        if not hasattr(self, 'serial0') or not self.serial0.is_open:
            return
        logger.debug('Closing connection to Bpod on %s', self.port)
        self.serial0.write(b'Z')
        self.serial0.close()
700

701
    def set_status_led(self, enabled: bool) -> bool:
        """
        Switch the Bpod's status LED on or off.

        Parameters
        ----------
        enabled : bool
            Whether the status LED should be on (True) or off (False).

        Returns
        -------
        bool
            Whether the Bpod confirmed the operation.
        """
        primary = self.serial0
        primary.write_struct('<c?', b':', enabled)
        # NOTE(review): the command was already sent above; verify() with an empty
        # query presumably only reads back the confirmation - confirm against
        # ExtendedSerial.verify.
        return primary.verify(b'')
717

718
    def update_modules(self) -> None:
12✔
719
        """Update the list of connected modules and their configurations."""
720
        # self._disable_all_module_relays()
721
        self.serial0.write(b'M')
12✔
722
        modules = []
12✔
723
        for idx in range(self._hardware.n_modules):
12✔
724
            # check connection state
725
            if not (is_connected := self.serial0.read_struct('<?')[0]):
12✔
726
                module_name = f'{CHANNEL_TYPES_INPUT[b"U"]}{idx + 1}'
12✔
727
                modules.append(Module(_bpod=self, index=idx, name=module_name))
12✔
728
                continue
12✔
729

730
            # read further information if module is connected
731
            n_events = N_SERIAL_EVENTS_DEFAULT
×
732
            firmware_version, n_chars = self.serial0.read_struct('<IB')
×
733
            base_name, more_info = self.serial0.read_struct(f'<{n_chars}s?')
×
734
            base_name = base_name.decode('UTF8')
×
735
            custom_event_names = []
×
736
            while more_info:
×
737
                match self.serial0.read(1):
×
738
                    case b'#':
×
739
                        n_events = self.serial0.read_struct('<B')[0]
×
740
                    case b'E':
×
741
                        n_event_names = self.serial0.read_struct('<B')[0]
×
742
                        for _ in range(n_event_names):
×
743
                            n_chars = self.serial0.read_struct('<B')[0]
×
744
                            event_name = self.serial0.read_struct(f'<{n_chars}s')[0]
×
745
                            custom_event_names.append(event_name.decode('UTF8'))
×
746
                more_info = self.serial0.read_struct('<?')[0]
×
747

748
            # create module name with trailing index
749
            matches = [re.match(rf'^{base_name}(\d$)', m.name) for m in modules]
×
750
            index = max([int(m.group(1)) for m in matches if m is not None] + [0])
×
751
            module_name = f'{base_name}{index + 1}'
×
752
            logger.debug('Detected %s on module port %d', module_name, idx + 1)
×
753

754
            # create instance of Module
755
            modules.append(
×
756
                Module(
757
                    _bpod=self,
758
                    index=idx,
759
                    name=module_name,
760
                    is_connected=is_connected,
761
                    firmware_version=firmware_version,
762
                    n_events=n_events,
763
                    _custom_event_names=custom_event_names,
764
                ),
765
            )
766

767
        # create NamedTuple and store as class attribute
768
        self.modules = NamedTuple('modules', [(m.name, Module) for m in modules])._make(
12✔
769
            modules,
770
        )
771

772
        # update event names and output actions
773
        self._compile_event_names()
12✔
774
        self._compile_output_actions()
12✔
775

776
    def validate_state_machine(self, state_machine: StateMachine) -> None:
        """
        Check a state machine against the hardware without sending it.

        This delegates to :meth:`send_state_machine` with ``validate_only=True``.
        In that mode, the method returns after validation and before any data
        is sent to the device.

        Parameters
        ----------
        state_machine : StateMachine
            The state machine to check.

        Raises
        ------
        ValueError
            If the state machine is invalid or not compatible with the hardware.
        """
        self.send_state_machine(state_machine=state_machine, validate_only=True)
791

792
    def send_state_machine(
        self,
        state_machine: StateMachine,
        *,
        run_asap: bool = False,
        validate_only: bool = False,
    ) -> None:
        """
        Validate a state machine, compile it and send it to the Bpod.

        The state machine is first validated against the hardware. The checks
        cover:

        - the number of states, global timers, global counters and conditions;
        - state change targets and condition names;
        - output actions;
        - global timer channels.

        It is then compiled into the byte array format expected by the Bpod
        and written to ``serial0`` with the ``C`` command.

        Parameters
        ----------
        state_machine : StateMachine
            The state machine to be sent to the Bpod device.
        run_asap : bool, optional
            If True, the state machine will run immediately after the current
            one has finished. Default is False.
        validate_only : bool, optional
            If True, the state machine is only validated and not sent to the
            device. Module relays are left untouched in this case.
            Default is False.

        Raises
        ------
        ValueError
            If the state machine is invalid or exceeds hardware limitations.
        """
        # Disable all active module relays
        if not validate_only:
            self._disable_all_module_relays()

        # Ensure that the state machine has at least one state
        if (n_states := len(state_machine.states)) == 0:
            raise ValueError('State machine needs to have at least one state')

        # Check if '>back' operator is being used
        targets_used = {
            target
            for state in state_machine.states.values()
            for target in state.state_change_conditions.values()
        }
        self._use_back_op = '>back' in targets_used

        # Validate the number of states, global timers, global counters and
        # conditions. Using '>back' reduces the maximum number of states by one.
        n_global_timers = max(state_machine.global_timers.keys(), default=-1) + 1
        n_global_counters = max(state_machine.global_counters.keys(), default=-1) + 1
        n_conditions = max(state_machine.conditions.keys(), default=-1) + 1
        for name, value, maximum_value in (
            ('states', n_states, self._hardware.max_states - 1 - self._use_back_op),
            ('global timers', n_global_timers, self._hardware.n_global_timers),
            ('global counters', n_global_counters, self._hardware.n_global_counters),
            ('conditions', n_conditions, self._hardware.n_conditions),
        ):
            if value > maximum_value:
                # Exceptions do not apply %-style formatting to extra arguments,
                # so the message is built explicitly.
                raise ValueError(
                    f'Too many {name} in state machine - hardware supports a '
                    f'maximum number of {maximum_value} {name}',
                )

        # Validate states
        valid_targets = list(state_machine.states.keys()) + VALID_OPERATORS
        for state_name, state in state_machine.states.items():
            for condition_name, target in state.state_change_conditions.items():
                if target not in valid_targets:
                    target_type = 'operator' if target[0] == '>' else 'target state'
                    raise ValueError(
                        f"Invalid {target_type} '{target}' for state change condition "
                        f"'{condition_name}' in state '{state_name}'"
                        + suggest_similar(target, valid_targets),
                    )
                if condition_name not in self.event_names:
                    raise ValueError(
                        f"Invalid state change condition '{condition_name}' in state "
                        f"'{state_name}'"
                        + suggest_similar(condition_name, self.event_names),
                    )
            actions = set(state.output_actions.keys())
            if invalid_actions := actions.difference(self.output_actions):
                invalid_action = invalid_actions.pop()
                raise ValueError(
                    f"Invalid output action '{invalid_action}' in state '{state_name}'"
                    + suggest_similar(invalid_action, self.output_actions),
                )

        # Compile list of physical channels
        # TODO: this is ugly
        physical_output_channels = [m.name for m in self.modules] + [
            o.name for o in self.outputs if o.io_type != b'U'
        ]
        physical_input_channels = [m.name for m in self.modules] + [
            o.name for o in self.inputs if o.io_type != b'U'
        ]

        # Validate global timers
        for timer_id, timer in state_machine.global_timers.items():
            if timer.channel not in (*physical_output_channels, None):
                raise ValueError(
                    f"Invalid channel '{timer.channel}' for global timer {timer_id}"
                    + suggest_similar(timer.channel or '', physical_output_channels),
                )

        # TODO: validate global timer onset triggers
        # TODO: validate global counters
        # TODO: validate conditions
        # TODO: Check that sync channel is not used as state output

        # return here if we're only validating the state machine
        if validate_only:
            return

        # Compile dicts of indices to resolve strings to integers.
        # 'exit' and '>exit' both map to the index right after the last state.
        target_indices = {k: v for v, k in enumerate(state_machine.states.keys())}
        target_indices.update({'exit': n_states, '>exit': n_states})
        target_indices.update({'>back': 255} if self._use_back_op else {})
        event_indices = {k: v for v, k in enumerate(self.event_names)}
        action_indices = {k: v for v, k in enumerate(self.output_actions)}

        # Initialize bytearray. This will be appended to in the following sections.
        byte_array = bytearray(
            (n_states, n_global_timers, n_global_counters, n_conditions),
        )

        # Compile target indices for state timers and append to bytearray
        # Target indices default to the respective state's index unless 'Tup' is used
        for state_idx, state in enumerate(state_machine.states.values()):
            for event, target in state.state_change_conditions.items():
                if event == 'Tup':
                    byte_array.append(target_indices[target])
                    break
            else:
                byte_array.append(state_idx)

        # Helper function for appending events and their target indices to bytearray.
        # For each state it writes a count, followed by (event offset, target)
        # pairs for every event whose index lies in [event0, event1).
        def append_events(event0: str, event1: str) -> None:
            idx0 = event_indices[event0]
            idx1 = event_indices[event1]
            for state in state_machine.states.values():
                counter_idx = len(byte_array)
                byte_array.append(0)
                for event, target in state.state_change_conditions.items():
                    if idx0 <= (key_idx := event_indices[event]) < idx1:
                        byte_array[counter_idx] += 1
                        byte_array.extend((key_idx - idx0, target_indices[target]))

        # Append input events to bytearray (i.e., events on physical input channels)
        append_events(self.event_names[0], 'GlobalTimer1_Start')

        # Append output actions and their values to bytearray. Only actions
        # listed before 'GlobalTimerTrig' are encoded here.
        # TODO: this could be more efficient?
        i1 = action_indices['GlobalTimerTrig']
        tmp_list: list[int] = []
        for state in state_machine.states.values():
            counter_pos = len(tmp_list)
            tmp_list.append(0)
            for action, value in state.output_actions.items():
                if (key_idx := action_indices[action]) < i1:
                    tmp_list[counter_pos] += 1
                    tmp_list.extend((key_idx, value))
        format_string = 'H' if self.version.machine == 4 else 'B'
        byte_array.extend(
            struct.pack(
                f'<{len(tmp_list)}{format_string}',
                *tmp_list,
            ),
        )

        # State transition matrix: each state transitions to itself by default
        self._state_transitions = np.arange(n_states, dtype=np.uint8)[
            :,
            np.newaxis,
        ] * np.ones((1, 255), dtype=np.uint8)
        for state_idx, state in enumerate(state_machine.states.values()):
            for event, target in state.state_change_conditions.items():
                target_idx = target_indices[target]
                self._state_transitions[state_idx][event_indices[event]] = target_idx

        # Append remaining events
        append_events('GlobalTimer1_Start', 'GlobalTimer1_End')  # global timer start
        append_events('GlobalTimer1_End', 'GlobalCounter1_End')  # global timer end
        append_events('GlobalCounter1_End', 'Condition1')  # global counter end
        append_events('Condition1', 'Tup')  # conditions

        # Compile indices for global timers channels
        timer_channel_indices = {k: v for v, k in enumerate(physical_output_channels)}
        timer_channel_indices[None] = 254

        # Helper function for packing a collection of integers into byte_array
        def pack_values(values: list[int], format_str: str) -> None:
            byte_array.extend(struct.pack(f'<{len(values)}{format_str}', *values))

        # Append values for global timer channels to byte_array
        idx0 = len(byte_array)
        byte_array.extend(b'\xfe' * n_global_timers)  # default: 254
        for timer_id, global_timer in state_machine.global_timers.items():
            byte_array[idx0 + timer_id] = timer_channel_indices[global_timer.channel]

        # Append values for global timers value_on and value_off to bytearray
        # Bpod 2+ uses 16-bit values for value_on and value_off
        format_string = 'H' if self.version.machine == 4 else 'B'
        for field_name in ('value_on', 'value_off'):
            pack_values(
                [
                    getattr(state_machine.global_timers.get(idx), field_name, 0)
                    for idx in range(n_global_timers)
                ],
                format_string,
            )

        # Append values for global timers loop and send_events to bytearray
        for field_name, default in (('loop', b'\x00'), ('send_events', b'\x01')):
            idx0 = len(byte_array)
            byte_array.extend(default * n_global_timers)  # defaults: loop=0, send_events=1
            for timer_id, global_timer in state_machine.global_timers.items():
                byte_array[idx0 + timer_id] = getattr(global_timer, field_name)

        # Append global counter events to bytearray
        idx0 = len(byte_array)
        byte_array.extend((254,) * n_global_counters)
        for counter_id, global_counter in state_machine.global_counters.items():
            byte_array[idx0 + counter_id] = event_indices[global_counter.event]

        # Compile indices for condition channels
        global_timers = [
            f'GlobalTimer{i + 1}' for i in range(self._hardware.n_global_timers)
        ]
        condition_channel_indices = {
            k: v for v, k in enumerate(physical_input_channels + global_timers)
        }

        # Append values for conditions to bytearray: first all channel indices,
        # then all values (each section has n_conditions entries)
        idx0 = len(byte_array)
        byte_array.extend((0,) * n_conditions * 2)
        for condition_id, condition in state_machine.conditions.items():
            offset = idx0 + condition_id
            byte_array[offset : offset + 2 * n_conditions : n_conditions] = (
                condition_channel_indices[condition.channel],
                condition.value,
            )

        # Append global counter resets
        if self.version.firmware < (23, 0):
            byte_array.extend(
                s.output_actions.get('GlobalCounterReset', 0)
                for s in state_machine.states.values()
            )
        else:
            counter_idx = len(byte_array)
            byte_array.append(0)
            for state_idx, state in enumerate(state_machine.states.values()):
                if (value := state.output_actions.get('GlobalCounterReset', 0)) > 0:
                    byte_array[counter_idx] += 1
                    byte_array.extend([state_idx, value])

        # Enable / disable analog thresholds
        # TODO: this is just a placeholder for now
        if self.version.machine == 4:
            byte_array.extend([0, 0])

        # The format of the next values depends on the number of global timers
        if self._hardware.n_global_timers > 16:
            format_string = 'I'  # uint32
        elif self._hardware.n_global_timers > 8:
            format_string = 'H'  # uint16
        else:
            format_string = 'B'  # uint8

        # Pack global timer triggers and cancels into bytearray
        for key in ('GlobalTimerTrig', 'GlobalTimerCancel'):
            pack_values(
                [s.output_actions.get(key, 0) for s in state_machine.states.values()],
                format_string,
            )

        # Pack global timer onset triggers into bytearray
        pack_values(
            [
                getattr(state_machine.global_timers.get(idx, {}), 'onset_trigger', 0)
                for idx in range(n_global_timers)
            ],
            format_string,
        )

        # Pack state timers (converted from seconds to hardware cycles)
        pack_values(
            [
                round(s.timer * self._hardware.cycle_frequency)
                for s in state_machine.states.values()
            ],
            'I',  # uint32
        )

        # Pack global timer durations, onset delays and loop intervals
        for key in ('duration', 'onset_delay', 'loop_interval'):
            pack_values(
                [
                    round(
                        getattr(state_machine.global_timers.get(idx, {}), key, 0)
                        * self._hardware.cycle_frequency,
                    )
                    for idx in range(n_global_timers)
                ],
                'I',  # uint32
            )

        # Pack global counter thresholds
        pack_values(
            [
                getattr(state_machine.global_counters.get(idx, {}), 'threshold', 0)
                for idx in range(n_global_counters)
            ],
            'I',  # uint32
        )

        # Append additional opcodes
        # TODO: why?
        if self.version.firmware > (22, 0):
            byte_array.append(0)

        # Send to state machine
        self._next_fsm_index += 1
        logger.debug('Sending state machine #%d to Bpod', self._next_fsm_index)
        n_bytes = len(byte_array)
        self.serial0.write_struct(
            f'<c2?H{n_bytes}s', b'C', run_asap, self._use_back_op, n_bytes, byte_array
        )
        self._waiting_for_confirmation = True

        if run_asap:
            self._run_state_machine(blocking=False)
1130

1131
    @property
    def is_running(self) -> bool:
        """
        Whether a state machine is currently being executed.

        Returns
        -------
        bool
            True if an FSM thread exists and is still alive, False otherwise.
        """
        thread = self._fsm_thread
        if thread is None:
            return False
        return thread.is_alive()
1135

1136
    def wait(self) -> None:
        """
        Wait for the currently running state machine to finish.

        This method blocks until the FSM thread has finished executing. If no
        state machine is currently running, it returns immediately without
        raising.
        """
        if self.is_running:
            self._fsm_thread.join()  # type: ignore[union-attr]
1145

1146
    def run_state_machine(self, *, blocking: bool = True) -> None:
        """
        Send the run command (``R``) to the Bpod and start an FSM thread.

        Temporary run method for debugging purposes.

        Parameters
        ----------
        blocking : bool, optional
            If True (default), block until the FSM thread has finished.

        Raises
        ------
        RuntimeError
            If a state machine is already running.
        """
        if self.is_running:
            raise RuntimeError('A state machine is already running')
        serial_device = self.serial0
        serial_device.write(b'R')
        self._run_state_machine(blocking=blocking)
1152

1153
    def _run_state_machine(self, *, blocking: bool) -> None:
        """
        Start a new FSM thread for the state machine being run.

        Any state machine that is still running is waited for first. The
        current ``_waiting_for_confirmation`` flag is passed to the new thread,
        and the flag is cleared after the thread has been started.

        Parameters
        ----------
        blocking : bool
            If True, block until the new FSM thread has finished.
        """
        # Let a previously started state machine finish first
        self.wait()

        fsm_thread = FSMThread(
            self.serial0,
            self._next_fsm_index,
            self._waiting_for_confirmation,
            self._hardware.cycle_period,
            self._softcode_handler,
            self._state_transitions,
            self._use_back_op,
        )
        self._fsm_thread = fsm_thread
        fsm_thread.start()
        self._waiting_for_confirmation = False

        if blocking:
            fsm_thread.join()
1173

1174
    def stop_state_machine(self) -> None:
        """
        Stop the running state machine, if there is one.

        Sends the ``X`` command to the Bpod and then joins the FSM thread.
        Does nothing when no state machine is running.
        """
        if self.is_running:
            logger.debug('Stopping state machine')
            self.serial0.write(b'X')
            thread = self._fsm_thread
            if thread is not None:
                thread.join()
1182

1183
    @staticmethod
    def _softcode_handler(softcode: int) -> None:
        """
        Default soft code handler passed to the FSM thread.

        This implementation ignores its argument and does nothing.

        Parameters
        ----------
        softcode : int
            The soft code value (unused).
        """
        del softcode  # intentionally unused by the default handler
1186

1187

1188
class Channel(ABC):
    """Abstract base class representing a channel on the Bpod device."""

    @abstractmethod
    def __init__(self, bpod: Bpod, name: str, io_key: bytes, index: int) -> None:
        """
        Store the attributes shared by all Bpod channels.

        Parameters
        ----------
        bpod : Bpod
            The Bpod instance the channel belongs to. Only its ``serial0``
            attribute is kept (as ``_serial0``).
        name : str
            The name of the channel.
        io_key : bytes
            The I/O type of the channel (e.g., b'B', b'V', b'P'), stored as
            ``io_type``.
        index : int
            The index of the channel.
        """
        self.name, self.io_type, self.index = name, io_key, index
        self._serial0 = bpod.serial0

    def __repr__(self) -> str:
        return f'{type(self).__name__}()'
1214

1215

1216
class Input(Channel):
    """Input channel class representing a digital input channel."""

    def __init__(self, bpod: Bpod, name: str, io_key: bytes, index: int) -> None:
        """
        Input channel class representing a digital input channel.

        Parameters
        ----------
        bpod : Bpod
            The Bpod instance associated with the channel.
        name : str
            The name of the channel.
        io_key : bytes
            The I/O type of the channel (e.g., b'B', b'V', b'P').
        index : int
            The index of the channel.
        """
        super().__init__(bpod, name, io_key, index)
        self._set_enable_inputs = bpod._set_enable_inputs
        # Port, BNC, Wire and FlexIO inputs are enabled by default. An explicit
        # tuple is used, since `io_key in (b'PBWF')` is a substring test on a
        # single bytes object rather than a membership test.
        self._enabled = io_key in (b'P', b'B', b'W', b'F')

    def read(self) -> bool:
        """
        Read the state of the input channel.

        Returns
        -------
        bool
            True if the input channel is active, False otherwise.
        """
        return self._serial0.verify([b'I', self.index])

    def override(self, state: bool) -> None:
        """
        Override the state of the input channel.

        Parameters
        ----------
        state : bool
            The state to set for the input channel.
        """
        # The 'V' (virtual event) command takes the target channel index
        # followed by the value, mirroring Output.override's 'O' command.
        self._serial0.write_struct('<c2B', b'V', self.index, state)

    def enable(self, enabled: bool) -> bool:
        """
        Enable or disable the input channel.

        For I/O types other than F, D, B, W, V and P, a warning is logged that
        the change has no effect. The flag is still stored and
        ``_set_enable_inputs`` is still called.

        Parameters
        ----------
        enabled : bool
            True to enable the input channel, False to disable.

        Returns
        -------
        bool
            The return value of the Bpod's ``_set_enable_inputs``; True if the
            operation was successful, False otherwise.
        """
        if self.io_type not in (b'F', b'D', b'B', b'W', b'V', b'P'):
            logger.warning(
                '%sabling input `%s` has no effect',
                'En' if enabled else 'Dis',
                self.name,
            )
        self._enabled = enabled
        return self._set_enable_inputs()

    @property
    def enabled(self) -> bool:
        """
        Check if the input channel is enabled.

        Returns
        -------
        bool
            True if the input channel is enabled, False otherwise.
        """
        return self._enabled

    @enabled.setter
    def enabled(self, enabled: bool) -> None:
        """
        Enable or disable the input channel.

        Parameters
        ----------
        enabled : bool
            True to enable the input channel, False to disable.
        """
        self.enable(enabled)
1306

1307

1308
class Output(Channel):
    """Output channel class representing a digital output channel."""

    def __init__(self, bpod: Bpod, name: str, io_key: bytes, index: int) -> None:
        """
        Create an output channel.

        Parameters
        ----------
        bpod : Bpod
            The Bpod instance the channel belongs to.
        name : str
            The name of the channel.
        io_key : bytes
            The I/O type of the channel (e.g., b'B', b'V', b'P').
        index : int
            The index of the channel.
        """
        super().__init__(bpod, name, io_key, index)

    def override(self, state: bool | int) -> None:
        """
        Force the output channel to a given state.

        Parameters
        ----------
        state : bool or int
            The state to set for the output channel. For binary I/O types
            (b'D', b'B', b'W'), an int is reduced to ``state > 0``. For pulse
            width modulation (PWM) I/O types, provide an int (0-255).
        """
        binary_io = self.io_type in (b'D', b'B', b'W')
        value = (state > 0) if (isinstance(state, int) and binary_io) else state
        self._serial0.write_struct('<c2B', b'O', self.index, value)
1341

1342

1343
@dataclass
class Module:
    """Represents a Bpod module with its configuration and event names."""

    _bpod: Bpod
    """A reference to the Bpod."""

    index: int
    """The index of the module."""

    name: str
    """The name of the module."""

    is_connected: bool = False
    """Whether the module is connected."""

    firmware_version: int | None = None
    """The firmware version of the module."""

    n_events: int = N_SERIAL_EVENTS_DEFAULT
    """The number of events assigned to the module."""

    _custom_event_names: list[str] = field(default_factory=list)
    """A list of custom event names."""

    def __post_init__(self) -> None:
        # The relay state starts as disabled; event names are derived from the
        # dataclass fields.
        self._relay_enabled = False
        self._define_event_names()

    def _define_event_names(self) -> None:
        """
        Define the module's event names.

        There is one name per event (``n_events`` in total), each prefixed with
        the module name. The first events use the entries of
        ``_custom_event_names``. Any remaining events are numbered
        ``<name>_<n>``, where ``n`` is the 1-based event position.
        """
        n_custom = len(self._custom_event_names)
        self.event_names = [
            f'{self.name}_{self._custom_event_names[idx]}'
            if idx < n_custom
            else f'{self.name}_{idx + 1}'
            for idx in range(self.n_events)
        ]

    @validate_call
    def set_relay(self, enable: bool) -> None:
        """
        Enable or disable the serial relay for the module.

        Nothing is sent if the requested state equals the current one. Before
        this module's relay is enabled, the Bpod's
        ``_disable_all_module_relays`` is called.

        Parameters
        ----------
        enable : bool
            True to enable the relay, False to disable it.
        """
        if enable == self._relay_enabled:
            return
        if enable is True:
            self._bpod._disable_all_module_relays()
        # Pass the prefix as a plain string: a set literal ({...}) would be
        # rendered as "{'En'}" in the log message.
        logger.info(
            '%sabling relay for module %s', 'En' if enable else 'Dis', self.name
        )
        self._bpod.serial0.write_struct('<cB?', b'J', self.index, enable)
        self._relay_enabled = enable

    @property
    def relay(self) -> bool:
        """The current state of the serial relay."""
        return self._relay_enabled

    @relay.setter
    def relay(self, state: bool) -> None:
        """The current state of the serial relay."""
        self.set_relay(state)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc