• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

int-brain-lab / iblrig / 15738036488

18 Jun 2025 04:10PM UTC coverage: 48.249% (+1.5%) from 46.79%
15738036488

Pull #815

github

9b495a
web-flow
Merge fd70c12e3 into 5c537cbb7
Pull Request #815: extended tests for photometry copier

23 of 32 new or added lines in 1 file covered. (71.88%)

1106 existing lines in 22 files now uncovered.

4408 of 9136 relevant lines covered (48.25%)

0.96 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

52.57
/iblrig/hardware.py
1
"""Hardware classes used to interact with modules."""
2

3
import logging
2✔
4
import os
2✔
5
import re
2✔
6
import shutil
2✔
7
import struct
2✔
8
import subprocess
2✔
9
import threading
2✔
10
import time
2✔
11
from collections.abc import Callable
2✔
12
from enum import IntEnum
2✔
13
from pathlib import Path
2✔
14
from typing import Annotated, Literal
2✔
15

16
import numpy as np
2✔
17
import serial
2✔
18
import sounddevice as sd
2✔
19
from annotated_types import Ge, Le
2✔
20
from pydantic import PositiveFloat, PositiveInt, validate_call
2✔
21
from serial.serialutil import SerialException
2✔
22
from serial.tools import list_ports
2✔
23

24
from iblrig.pydantic_definitions import HardwareSettingsRotaryEncoder
2✔
25
from iblutil.util import Bunch
2✔
26
from pybpod_rotaryencoder_module.module import RotaryEncoder as PybpodRotaryEncoder
2✔
27
from pybpod_rotaryencoder_module.module_api import RotaryEncoderModule as PybpodRotaryEncoderModule
2✔
28
from pybpodapi.bpod.bpod_io import BpodIO
2✔
29
from pybpodapi.bpod_modules.bpod_module import BpodModule
2✔
30
from pybpodapi.state_machine import StateMachine
2✔
31

32
# Softcode identifiers used as ('SoftCode', SOFTCODE.<NAME>) output actions; members are numbered from 1.
SOFTCODE = IntEnum('SOFTCODE', 'STOP_SOUND PLAY_TONE PLAY_NOISE TRIGGER_CAMERA')

# Layout of a raw ambient-sensor reading: three float32 fields.
DTYPE_AMBIENT_SENSOR_RAW = np.dtype(
    [
        ('Temperature_C', np.float32),
        ('AirPressure_mb', np.float32),
        ('RelativeHumidity', np.float32),
    ]
)
# Same fields as DTYPE_AMBIENT_SENSOR_RAW, prefixed with a uint16 trial number.
DTYPE_AMBIENT_SENSOR_BIN = np.dtype([('Trial', np.uint16), *DTYPE_AMBIENT_SENSOR_RAW.descr])

# some annotated types
Uint8 = Annotated[int, Ge(0), Le(255)]  # integer constrained to 0..255
ActionIdx = Annotated[int, Ge(1), Le(255)]  # serial-message index constrained to 1..255

log = logging.getLogger(__name__)
43

44

45
class Bpod(BpodIO):
    """Bpod state-machine interface with one cached instance per serial port.

    Constructing ``Bpod`` with a serial port that already has a cached instance returns that instance
    (see ``__new__``). On top of ``BpodIO`` this class adds module lookup, definition of serial-message
    output actions, valve control, ambient-sensor readout, status-LED control and softcode registration.
    """

    can_control_led = True
    softcodes: dict[int, Callable] | None = None
    _instances = {}  # serial port -> cached Bpod instance
    _lock = threading.RLock()
    _is_initialized = False

    def __new__(cls, *args, **kwargs):
        # The serial port may be given positionally (first argument) or as the `serial_port` keyword.
        serial_port = args[0] if len(args) > 0 else ''
        serial_port = kwargs.get('serial_port', serial_port)
        with cls._lock:
            instance = Bpod._instances.get(serial_port, None)
            if instance is not None:
                return instance
            instance = super().__new__(cls)
            Bpod._instances[serial_port] = instance
            return instance

    def __init__(self, *args, skip_initialization: bool = False, **kwargs):
        """Connect to the Bpod, retrying once on failure.

        Parameters
        ----------
        *args, **kwargs
            Forwarded to ``BpodIO.__init__``.
        skip_initialization : bool
            If True and this (cached) instance was already initialized, return immediately.

        Raises
        ------
        serial.serialutil.SerialException
            If the second connection attempt fails with a serial or decoding error.
        """
        # skip initialization if it has already been performed before
        # IMPORTANT: only use this for non-critical tasks (e.g., flushing valve from GUI)
        if skip_initialization and self._is_initialized:
            return

        # try to instantiate once for nothing
        try:
            super().__init__(*args, **kwargs)
        except Exception:
            log.warning("Couldn't instantiate BPOD, retrying once...")
            time.sleep(1)
            try:
                super().__init__(*args, **kwargs)
            except (serial.serialutil.SerialException, UnicodeDecodeError) as e:
                log.error(e)
                raise serial.serialutil.SerialException(
                    'The communication with Bpod is established but the Bpod is not responsive. '
                    'This is usually indicated by the device with a green light. '
                    'Please unplug the Bpod USB cable from the computer and plug it back in to start the task. '
                ) from e
        self.serial_messages = {}
        self.actions = Bunch({})
        self.can_control_led = self.set_status_led(True)
        self._is_initialized = True

    def close(self) -> None:
        """Close the connection and mark the instance as not initialized."""
        super().close()
        self._is_initialized = False

    def __del__(self):
        # Drop this port's entry from the instance cache. `serial_port` may be missing if construction
        # failed part-way, so look it up defensively instead of raising inside a finalizer.
        with self._lock:
            serial_port = getattr(self, 'serial_port', None)
            if serial_port in Bpod._instances:
                Bpod._instances.pop(serial_port)

    @property
    def is_connected(self):
        """True when the module list has been populated (i.e. ``self.modules`` is not None)."""
        return self.modules is not None

    @property
    def rotary_encoder(self):
        """First module whose name starts with 'RotaryEncoder', or None."""
        return self.get_module('rotary_encoder')

    @property
    def sound_card(self):
        """First module whose name starts with 'SoundCard', or None."""
        return self.get_module('sound_card')

    @property
    def ambient_module(self):
        """First module whose name starts with 'AmbientModule', or None."""
        return self.get_module('^AmbientModule')

    def get_module(self, module_name: str) -> BpodModule | None:
        """Get module by name.

        Parameters
        ----------
        module_name : str
            Regular Expression for matching a module name. The aliases 're'/'rotary_encoder' and
            'sc'/'sound_card' are translated to '^RotaryEncoder' and '^SoundCard' respectively.

        Returns
        -------
        BpodModule | None
            First matching module or None
        """
        if self.modules is None:
            return None
        if module_name in ['re', 'rotary_encoder']:
            module_name = r'^RotaryEncoder'
        elif module_name in ['sc', 'sound_card']:
            module_name = r'^SoundCard'
        modules = [x for x in self.modules if re.match(module_name, x.name)]
        if len(modules) > 1:
            log.critical(f'Found several Bpod modules matching `{module_name}`. Using first match: `{modules[0].name}`')
        if len(modules) > 0:
            return modules[0]
        return None

    @validate_call(config={'arbitrary_types_allowed': True})
    def _define_message(self, module: BpodModule | int, message: list[Uint8]) -> ActionIdx:
        """Define a serial message to be sent to a Bpod module as an output action within a state.

        Parameters
        ----------
        module : BpodModule | int
            The targeted module, defined as a BpodModule instance or the module's port index
        message : list[int]
            The message to be sent - a list of up to three 8-bit integers

        Returns
        -------
        int
            The index of the serial message (1-255)

        Raises
        ------
        TypeError
            If module is not an instance of BpodModule or int

        Examples
        --------
        >>> id_msg_bonsai_show_stim = self._define_message(self.rotary_encoder, [ord("#"), 2])
        will then be used as such in StateMachine:
        >>> output_actions=[("Serial1", id_msg_bonsai_show_stim)]
        """
        if isinstance(module, BpodModule):
            module = module.serial_port
        # Message IDs are assigned sequentially, starting at 1.
        message_id = len(self.serial_messages) + 1
        self.load_serial_message(module, message_id, message)
        self.serial_messages.update({message_id: {'target_module': module, 'message': message}})
        return message_id

    @validate_call(config={'arbitrary_types_allowed': True})
    def define_xonar_sounds_actions(self):
        """Register sound actions that are triggered through softcodes (Xonar sound output)."""
        self.actions.update(
            {
                'play_tone': ('SoftCode', SOFTCODE.PLAY_TONE),
                'play_noise': ('SoftCode', SOFTCODE.PLAY_NOISE),
                'stop_sound': ('SoftCode', SOFTCODE.STOP_SOUND),
            }
        )

    def define_harp_sounds_actions(self, module: BpodModule, go_tone_index: int = 2, noise_index: int = 3) -> None:
        """Register sound actions that send serial messages to a Harp sound card module.

        Parameters
        ----------
        module : BpodModule
            The sound card module.
        go_tone_index : int
            Sound index played by the 'play_tone' action.
        noise_index : int
            Sound index played by the 'play_noise' action.
        """
        module_port = f'Serial{module.serial_port if module is not None else ""}'
        self.actions.update(
            {
                'play_tone': (module_port, self._define_message(module, [ord('P'), go_tone_index])),
                'play_noise': (module_port, self._define_message(module, [ord('P'), noise_index])),
                'stop_sound': (module_port, ord('X')),
            }
        )

    def define_rotary_encoder_actions(self, module: BpodModule | None = None) -> None:
        """Register rotary-encoder reset and Bonsai stimulus-control actions.

        Parameters
        ----------
        module : BpodModule | None
            The rotary encoder module; defaults to ``self.rotary_encoder``.
        """
        if module is None:
            module = self.rotary_encoder
        module_port = f'Serial{module.serial_port if module is not None else ""}'
        self.actions.update(
            {
                'rotary_encoder_reset': (
                    module_port,
                    self._define_message(
                        module, [PybpodRotaryEncoder.COM_SETZEROPOS, PybpodRotaryEncoder.COM_ENABLE_ALLTHRESHOLDS]
                    ),
                ),
                'bonsai_hide_stim': (module_port, self._define_message(module, [ord('#'), 1])),
                'bonsai_show_stim': (module_port, self._define_message(module, [ord('#'), 8])),
                'bonsai_closed_loop': (module_port, self._define_message(module, [ord('#'), 3])),
                'bonsai_freeze_stim': (module_port, self._define_message(module, [ord('#'), 4])),
                'bonsai_show_center': (module_port, self._define_message(module, [ord('#'), 5])),
                'bonsai_freeze_center': (module_port, self._define_message(module, [ord('#'), 9])),
            }
        )

    def get_ambient_sensor_reading(self) -> np.ndarray:
        """
        Retrieve ambient sensor readings.

        If the ambient sensor module is not available, returns an array filled with NaN values.
        Otherwise, retrieves the temperature, air pressure, and relative humidity readings.

        Returns
        -------
        np.ndarray
            A NumPy array containing the sensor readings in the following order:

            - [0] : Temperature in degrees Celsius
            - [1] : Air pressure in millibars
            - [2] : Relative humidity in percentage
        """
        # Resolve the module once: the property performs a regex search over all modules on every access.
        module = self.ambient_module
        if module is None:
            return np.full(3, np.nan, np.float32)
        module.start_module_relay()
        try:
            self.bpod_modules.module_write(module, 'R')
            reply = self.bpod_modules.module_read(module, 12)  # three float32 values
        finally:
            # Always stop relaying, even if the read/write failed.
            module.stop_module_relay()
        data = np.frombuffer(bytes(reply), dtype=np.float32).copy()
        data[1] /= 100
        return data

    def flush(self):
        """Open valve 1 until the user presses ENTER, then close it."""
        self.toggle_valve()

    def toggle_valve(self, duration: float | None = None):
        """
        Flush valve 1 for specified duration.

        Parameters
        ----------
        duration : float, optional
            Duration of valve opening in seconds. If None, the valve stays open until ENTER is pressed.
        """
        if duration is None:
            self.open_valve(state=True, valve_number=1)
            input('Press ENTER when done.')
            self.open_valve(state=False, valve_number=1)
        else:
            self.pulse_valve(open_time_s=duration)

    def open_valve(self, state: bool, valve_number: int = 1):
        """Set the state of a valve via manual override (True = open)."""
        self.manual_override(self.ChannelTypes.OUTPUT, self.ChannelNames.VALVE, valve_number, state)

    def pulse_valve(self, open_time_s: float, valve: str = 'Valve1'):
        """Open `valve` for `open_time_s` seconds by running a single-state state machine."""
        sma = StateMachine(self)
        sma.add_state(
            state_name='flush', state_timer=open_time_s, state_change_conditions={'Tup': 'exit'}, output_actions=[(valve, 255)]
        )
        self.send_state_machine(sma)
        self.run_state_machine(sma)

    @validate_call()
    def pulse_valve_repeatedly(
        self, repetitions: PositiveInt, open_time_s: PositiveFloat, close_time_s: PositiveFloat = 0.2, valve: str = 'Valve1'
    ) -> int:
        """Alternately open and close `valve` until at least `repetitions` openings have occurred.

        Parameters
        ----------
        repetitions : int
            Minimum number of valve openings.
        open_time_s : float
            Duration of each opening in seconds.
        close_time_s : float
            Pause between openings in seconds.
        valve : str
            Name of the valve output channel.

        Returns
        -------
        int
            Number of openings counted via softcode 1.
        """
        counter = 0

        def softcode_handler(softcode: int):
            nonlocal counter
            if softcode == 1:  # emitted on entering 'open'
                counter += 1
            elif softcode == 2 and counter >= repetitions:  # emitted on entering 'close'
                self.stop_trial()

        original_softcode_handler = self.softcode_handler_function
        self.softcode_handler_function = softcode_handler
        try:
            sma = StateMachine(self)
            sma.add_state(
                state_name='open',
                state_timer=open_time_s,
                state_change_conditions={'Tup': 'close'},
                output_actions=[(valve, 255), ('SoftCode', 1)],
            )
            sma.add_state(
                state_name='close',
                state_timer=close_time_s,
                state_change_conditions={'Tup': 'open'},
                output_actions=[('SoftCode', 2)],
            )
            self.send_state_machine(sma)
            self.run_state_machine(sma)
        finally:
            # Restore the caller's handler even if the state machine failed.
            self.softcode_handler_function = original_softcode_handler
        return counter

    def set_status_led(self, state: bool) -> bool:
        """Enable or disable the Bpod status LED.

        Returns
        -------
        bool
            True if the device acknowledged the command, False otherwise.
        """
        if self.can_control_led and self._arcom is not None:
            try:
                log.debug(f'{"en" if state else "dis"}abling Bpod Status LED')
                command = struct.pack('cB', b':', state)
                self._arcom.serial_object.write(command)
                if self._arcom.read_uint8() == 1:
                    return True
            except (serial.SerialException, struct.error):
                pass  # best effort: fall through to buffer reset and warning below
            self._arcom.serial_object.reset_input_buffer()
            self._arcom.serial_object.reset_output_buffer()
            log.warning('Bpod device does not support control of the status LED. Please update firmware.')
        return False

    def valve(self, valve_id: int, state: bool):
        """Set the state of valve `valve_id` via manual override (True = open)."""
        self.manual_override(self.ChannelTypes.OUTPUT, self.ChannelNames.VALVE, valve_id, state)

    @validate_call
    def register_softcodes(self, softcode_dict: dict[int, Callable]) -> None:
        """
        Register softcodes to be used in the state machine.

        Parameters
        ----------
        softcode_dict : dict[int, Callable]
            dictionary of int keys with callables as values
        """
        self.softcodes = softcode_dict
        self.softcode_handler_function = lambda code: softcode_dict[code]()
337

338

339
class RotaryEncoderModule(PybpodRotaryEncoderModule):
    """Rotary encoder module configured from the rig's hardware settings.

    Thresholds are given in degrees of visual stimulus and converted to wheel units using the
    configured gain and wheel diameter (see ``write_parameters``).
    """

    _name = 'Rotary Encoder Module'

    ENCODER_EVENTS = list()
    THRESHOLD_EVENTS = dict()

    def __init__(self, settings: HardwareSettingsRotaryEncoder, thresholds_deg: list[float], gain: float):
        """
        Parameters
        ----------
        settings : HardwareSettingsRotaryEncoder
            Hardware settings providing COM_ROTARY_ENCODER and WHEEL_DIAMETER_MM.
        thresholds_deg : list[float]
            Threshold positions; each maps to an event 'RotaryEncoder1_<n>' (n starting at 1).
        gain : float
            Gain used to scale the thresholds in ``write_parameters``.
        """
        super().__init__()
        self.settings = settings

        self._wheel_degree_per_mm = 360.0 / (self.settings.WHEEL_DIAMETER_MM * np.pi)
        self.thresholds_deg = thresholds_deg
        self.gain = gain
        self.ENCODER_EVENTS = [f'RotaryEncoder1_{x + 1}' for x in range(len(thresholds_deg))]
        self.THRESHOLD_EVENTS = dict(zip(thresholds_deg, self.ENCODER_EVENTS, strict=False))

    def open(self, _=None):
        """Open the serial connection on the configured COM port.

        Raises
        ------
        ValueError
            If COM_ROTARY_ENCODER is not set.
        SerialException
            If the port cannot be opened.
        Exception
            For any other failure while opening the connection.
        """
        if self.settings.COM_ROTARY_ENCODER is None:
            raise ValueError(
                'The value for device_rotary_encoder:COM_ROTARY_ENCODER in settings/hardware_settings.yaml is null. '
                'Please provide a valid port name.'
            )
        try:
            super().open(self.settings.COM_ROTARY_ENCODER)
        except SerialException as e:
            raise SerialException(
                f'The {self._name} on port {self.settings.COM_ROTARY_ENCODER} is already in use. This is '
                f'usually due to a Bonsai process running on the computer. Make sure all Bonsai windows are closed '
                f'prior to running the task.'
            ) from e
        except Exception as e:
            raise Exception(f'The {self._name} on port {self.settings.COM_ROTARY_ENCODER} did not return the handshake.') from e
        log.debug(f'Successfully opened serial connection to {self._name} on port {self.settings.COM_ROTARY_ENCODER}')

    def write_parameters(self):
        """Reset the zero position, upload the scaled thresholds and enable event transmission."""
        scaled_thresholds_deg = [x / self.gain * self._wheel_degree_per_mm for x in self.thresholds_deg]
        # One flag per threshold slot (8 slots); only the slots actually used are enabled.
        enabled_thresholds = [(x < len(scaled_thresholds_deg)) for x in range(8)]

        log.info(
            f'Thresholds for {self._name} scaled to {", ".join([f"{x:0.2f}" for x in scaled_thresholds_deg])} '
            f'using gain of {self.gain:0.1f} deg/mm and wheel diameter of {self.settings.WHEEL_DIAMETER_MM:0.1f} mm.'
        )
        self.set_zero_position()
        self.set_thresholds(scaled_thresholds_deg)
        self.enable_thresholds(enabled_thresholds)
        self.enable_evt_transmission()

    def close(self):
        """Close the serial connection if one is open."""
        # `arcom` may be absent if construction failed before the base class set it; `__del__` still
        # calls close() in that case, so default to None instead of raising AttributeError.
        if getattr(self, 'arcom', None) is not None:
            log.debug(f'Closing serial connection to {self._name} on port {self.settings.COM_ROTARY_ENCODER}')
            super().close()

    def __del__(self):
        self.close()
393

394

395
def sound_device_factory(output: Literal['xonar', 'harp', 'hifi', 'sysdefault'] = 'sysdefault', samplerate: int | None = None):
2✔
396
    """
397
    Will import, configure, and return sounddevice module to play sounds using onboard sound card.
398

399
    Parameters
400
    ----------
401
    output
402
        defaults to "sysdefault", should be 'xonar' or 'harp'
403
    samplerate
404
        audio sample rate, defaults to 44100
405
    """
406
    match output:
2✔
407
        case 'xonar':
2✔
UNCOV
408
            samplerate = samplerate if samplerate is not None else 192000
×
UNCOV
409
            devices = sd.query_devices()
×
UNCOV
410
            sd.default.device = next((i for i, d in enumerate(devices) if 'XONAR SOUND CARD(64)' in d['name']), None)
×
UNCOV
411
            sd.default.latency = 'low'
×
UNCOV
412
            sd.default.channels = 2
×
UNCOV
413
            channels = 'L+TTL'
×
UNCOV
414
            sd.default.samplerate = samplerate
×
415
        case 'harp':
2✔
UNCOV
416
            samplerate = samplerate if samplerate is not None else 96000
×
UNCOV
417
            sd.default.samplerate = samplerate
×
UNCOV
418
            sd.default.channels = 2
×
UNCOV
419
            channels = 'stereo'
×
420
        case 'hifi':
2✔
UNCOV
421
            samplerate = samplerate if samplerate is not None else 192000
×
UNCOV
422
            channels = 'stereo'
×
423
        case 'sysdefault':
2✔
424
            samplerate = samplerate if samplerate is not None else 44100
2✔
425
            sd.default.latency = 'low'
2✔
426
            sd.default.channels = 2
2✔
427
            sd.default.samplerate = samplerate
2✔
428
            channels = 'stereo'
2✔
UNCOV
429
        case _:
×
UNCOV
430
            raise ValueError()
×
431
    return sd, samplerate, channels
2✔
432

433

434
def restart_com_port(regexp: str) -> bool:
    """
    Restart the communication port(s) matching the specified regular expression.

    Parameters
    ----------
    regexp : str
        A regular expression used to match the communication port(s) (passed to
        ``serial.tools.list_ports.grep``).

    Returns
    -------
    bool
        True if every matched port whose device was found by pnputil was restarted successfully
        (also True if nothing matched), False if any restart returned a non-zero exit code.

    Raises
    ------
    NotImplementedError
        If the operating system is not Windows.

    FileNotFoundError
        If the required 'pnputil.exe' executable is not found.

    Examples
    --------
    >>> restart_com_port("COM3")  # Restart the port matching 'COM3'
    True

    >>> restart_com_port("COM[1-3]")  # Restart ports matching 'COM1', 'COM2' or 'COM3'
    True
    """
    if os.name != 'nt':
        raise NotImplementedError('Only implemented for Windows OS.')
    # shutil.which returns None when the executable is not on PATH; Path(None) would raise TypeError.
    pnputil_path = shutil.which('pnputil')
    if pnputil_path is None or not (file_pnputil := Path(pnputil_path)).exists():
        raise FileNotFoundError('Could not find pnputil.exe')

    result = []
    device_listing = None  # pnputil device enumeration, fetched once on first use
    for port in list_ports.grep(regexp):
        if port.serial_number is None:
            continue  # nothing to match the device instance ID against
        if device_listing is None:
            device_listing = subprocess.check_output(
                [file_pnputil, '/enum-devices', '/connected', '/class', 'ports']
            ).decode(errors='replace')
        instance_id = re.search(rf'(\S*{re.escape(port.serial_number)}\S*)', device_listing)
        if instance_id is None:
            continue
        # Pass the instance ID as its own argument (no embedded quotes) and report failure via the
        # return code rather than raising, as documented.
        completed = subprocess.run(
            [file_pnputil, '/restart-device', instance_id.group(1)],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.STDOUT,
            check=False,
        )
        result.append(completed.returncode == 0)
    return all(result)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc