• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

int-brain-lab / iblrig / 14196118657

01 Apr 2025 12:52PM UTC coverage: 47.634% (+0.8%) from 46.79%
14196118657

Pull #795

github

cfb5bd
web-flow
Merge 5ba5d5f25 into 58cf64236
Pull Request #795: fixes for habituation CW

11 of 12 new or added lines in 1 file covered. (91.67%)

1083 existing lines in 22 files now uncovered.

4288 of 9002 relevant lines covered (47.63%)

0.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

52.21
/iblrig/hardware.py
1
"""Hardware classes used to interact with modules."""
2

3
import logging
2✔
4
import os
2✔
5
import re
2✔
6
import shutil
2✔
7
import struct
2✔
8
import subprocess
2✔
9
import threading
2✔
10
import time
2✔
11
from collections.abc import Callable
2✔
12
from enum import IntEnum
2✔
13
from pathlib import Path
2✔
14
from typing import Annotated, Literal
2✔
15

16
import numpy as np
2✔
17
import serial
2✔
18
import sounddevice as sd
2✔
19
from annotated_types import Ge, Le
2✔
20
from pydantic import PositiveFloat, PositiveInt, validate_call
2✔
21
from serial.serialutil import SerialException
2✔
22
from serial.tools import list_ports
2✔
23

24
from iblrig.pydantic_definitions import HardwareSettingsRotaryEncoder
2✔
25
from iblutil.util import Bunch
2✔
26
from pybpod_rotaryencoder_module.module import RotaryEncoder as PybpodRotaryEncoder
2✔
27
from pybpod_rotaryencoder_module.module_api import RotaryEncoderModule as PybpodRotaryEncoderModule
2✔
28
from pybpodapi.bpod.bpod_io import BpodIO
2✔
29
from pybpodapi.bpod_modules.bpod_module import BpodModule
2✔
30
from pybpodapi.state_machine import StateMachine
2✔
31

32
class SOFTCODE(IntEnum):
    """Softcode identifiers used as ``('SoftCode', value)`` Bpod output actions (numbered from 1)."""

    STOP_SOUND = 1
    PLAY_TONE = 2
    PLAY_NOISE = 3
    TRIGGER_CAMERA = 4
33

34
# Annotated integer types, enforced by pydantic's `validate_call` where they are used as hints.
Uint8 = Annotated[int, Ge(0), Le(0xFF)]  # unsigned 8-bit integer: 0 to 255 inclusive
ActionIdx = Annotated[int, Ge(1), Le(0xFF)]  # serial-message index: 1 to 255 inclusive

log = logging.getLogger(name=__name__)
39

40

41
class Bpod(BpodIO):
    """Bpod device wrapper with one cached instance per serial port.

    Constructing ``Bpod`` with a serial port that already has a cached instance returns that very instance.
    ``__init__`` is then run again, unless ``skip_initialization=True`` and the instance was initialized before.
    """

    can_control_led = True
    softcodes: dict[int, Callable] | None = None
    _instances = {}  # cached instances, keyed by the `serial_port` argument passed to the constructor
    _lock = threading.RLock()  # guards access to `_instances`
    _is_initialized = False

    def __new__(cls, *args, **kwargs):
        # The cache key is the serial port passed positionally or by keyword ('' when none is given).
        serial_port = args[0] if len(args) > 0 else ''
        serial_port = kwargs.get('serial_port', serial_port)
        with cls._lock:
            instance = Bpod._instances.get(serial_port, None)
            if instance is not None:
                return instance
            instance = super().__new__(cls)
            Bpod._instances[serial_port] = instance
            return instance

    def __init__(self, *args, skip_initialization: bool = False, **kwargs):
        """Connect to the Bpod and reset the serial-message and action registries.

        Parameters
        ----------
        *args, **kwargs
            Forwarded to ``BpodIO.__init__``.
        skip_initialization : bool, optional
            If True and this (cached) instance has already been initialized, return immediately.
            IMPORTANT: only use this for non-critical tasks (e.g., flushing valve from GUI).

        Raises
        ------
        serial.serialutil.SerialException
            If the second connection attempt fails with a SerialException or UnicodeDecodeError.
            Other exceptions raised by the second attempt propagate unchanged.
        """
        if skip_initialization and self._is_initialized:
            return

        # try to instantiate once for nothing; on any failure wait a second and retry once
        try:
            super().__init__(*args, **kwargs)
        except Exception:
            log.warning("Couldn't instantiate BPOD, retrying once...")
            time.sleep(1)
            try:
                super().__init__(*args, **kwargs)
            except (serial.serialutil.SerialException, UnicodeDecodeError) as e:
                log.error(e)
                raise serial.serialutil.SerialException(
                    'The communication with Bpod is established but the Bpod is not responsive. '
                    'This is usually indicated by the device with a green light. '
                    'Please unplug the Bpod USB cable from the computer and plug it back in to start the task. '
                ) from e
        self.serial_messages = {}
        self.actions = Bunch({})
        self.can_control_led = self.set_status_led(True)
        self._is_initialized = True

    def close(self) -> None:
        """Close the connection and mark the instance as not initialized."""
        super().close()
        self._is_initialized = False

    def __del__(self):
        with self._lock:
            # Drop every cache entry that refers to this very instance. Matching on identity (instead of looking up
            # `self.serial_port`) never evicts a different instance registered under the same port, works when the
            # cache key ('' if no port was passed) differs from the port resolved by BpodIO, and does not fail if
            # `serial_port` was never set.
            for key in [k for k, v in Bpod._instances.items() if v is self]:
                del Bpod._instances[key]

    @property
    def is_connected(self):
        """True if the list of modules is available (``self.modules`` is not None)."""
        return self.modules is not None

    @property
    def rotary_encoder(self):
        """First module whose name matches ``^RotaryEncoder``, or None."""
        return self.get_module('rotary_encoder')

    @property
    def sound_card(self):
        """First module whose name matches ``^SoundCard``, or None."""
        return self.get_module('sound_card')

    @property
    def ambient_module(self):
        """First module whose name matches ``^AmbientModule``, or None."""
        return self.get_module('^AmbientModule')

    def get_module(self, module_name: str) -> BpodModule | None:
        """Get module by name.

        Parameters
        ----------
        module_name : str
            Regular Expression for matching a module name. The aliases 're' / 'rotary_encoder' and
            'sc' / 'sound_card' are expanded to ``^RotaryEncoder`` and ``^SoundCard`` respectively.

        Returns
        -------
        BpodModule | None
            First matching module or None
        """
        if self.modules is None:
            return None
        if module_name in ['re', 'rotary_encoder']:
            module_name = r'^RotaryEncoder'
        elif module_name in ['sc', 'sound_card']:
            module_name = r'^SoundCard'
        modules = [x for x in self.modules if re.match(module_name, x.name)]
        if len(modules) > 1:
            log.critical(f'Found several Bpod modules matching `{module_name}`. Using first match: `{modules[0].name}`')
        return modules[0] if modules else None

    @validate_call(config={'arbitrary_types_allowed': True})
    def _define_message(self, module: BpodModule | int, message: list[Uint8]) -> ActionIdx:
        """Define a serial message to be sent to a Bpod module as an output action within a state.

        Parameters
        ----------
        module : BpodModule | int
            The targeted module, defined as a BpodModule instance or the module's port index
        message : list[int]
            The message to be sent - a list of up to three 8-bit integers

        Returns
        -------
        int
            The index of the serial message (1-255)

        Raises
        ------
        TypeError
            If module is not an instance of BpodModule or int

        Examples
        --------
        >>> id_msg_bonsai_show_stim = self._define_message(self.rotary_encoder, [ord("#"), 2])
        will then be used as such in StateMachine:
        >>> output_actions=[("Serial1", id_msg_bonsai_show_stim)]
        """
        if isinstance(module, BpodModule):
            module = module.serial_port
        # message indices are assigned sequentially, starting at 1
        message_id = len(self.serial_messages) + 1
        self.load_serial_message(module, message_id, message)
        self.serial_messages.update({message_id: {'target_module': module, 'message': message}})
        return message_id

    @validate_call(config={'arbitrary_types_allowed': True})
    def define_xonar_sounds_actions(self):
        """Register 'play_tone', 'play_noise' and 'stop_sound' as softcode output actions."""
        self.actions.update(
            {
                'play_tone': ('SoftCode', SOFTCODE.PLAY_TONE),
                'play_noise': ('SoftCode', SOFTCODE.PLAY_NOISE),
                'stop_sound': ('SoftCode', SOFTCODE.STOP_SOUND),
            }
        )

    def define_harp_sounds_actions(self, module: BpodModule, go_tone_index: int = 2, noise_index: int = 3) -> None:
        """Register 'play_tone', 'play_noise' and 'stop_sound' as serial output actions targeting `module`.

        Parameters
        ----------
        module : BpodModule
            Target sound module. Note that `_define_message` validates it as ``BpodModule | int``.
        go_tone_index : int, optional
            Sound index sent with the 'P' command for 'play_tone'.
        noise_index : int, optional
            Sound index sent with the 'P' command for 'play_noise'.
        """
        module_port = f'Serial{module.serial_port if module is not None else ""}'
        self.actions.update(
            {
                'play_tone': (module_port, self._define_message(module, [ord('P'), go_tone_index])),
                'play_noise': (module_port, self._define_message(module, [ord('P'), noise_index])),
                'stop_sound': (module_port, ord('X')),
            }
        )

    def define_rotary_encoder_actions(self, module: BpodModule | None = None) -> None:
        """Register rotary-encoder reset and Bonsai stimulus-control serial messages as output actions.

        Parameters
        ----------
        module : BpodModule, optional
            Target module; defaults to ``self.rotary_encoder``. If no module is available, `_define_message`
            rejects the resulting None during validation.
        """
        if module is None:
            module = self.rotary_encoder
        module_port = f'Serial{module.serial_port if module is not None else ""}'
        self.actions.update(
            {
                'rotary_encoder_reset': (
                    module_port,
                    self._define_message(
                        module, [PybpodRotaryEncoder.COM_SETZEROPOS, PybpodRotaryEncoder.COM_ENABLE_ALLTHRESHOLDS]
                    ),
                ),
                'bonsai_hide_stim': (module_port, self._define_message(module, [ord('#'), 1])),
                'bonsai_show_stim': (module_port, self._define_message(module, [ord('#'), 8])),
                'bonsai_closed_loop': (module_port, self._define_message(module, [ord('#'), 3])),
                'bonsai_freeze_stim': (module_port, self._define_message(module, [ord('#'), 4])),
                'bonsai_show_center': (module_port, self._define_message(module, [ord('#'), 5])),
            }
        )

    def get_ambient_sensor_reading(self):
        """Read temperature, air pressure and relative humidity from the ambient module.

        Returns
        -------
        dict
            Keys 'Temperature_C', 'AirPressure_mb' and 'RelativeHumidity'. All values are NaN if no
            ambient module is connected.
        """
        ambient_module = self.ambient_module  # resolve once instead of re-matching module names on every access
        if ambient_module is None:
            return {
                'Temperature_C': np.nan,
                'AirPressure_mb': np.nan,
                'RelativeHumidity': np.nan,
            }
        ambient_module.start_module_relay()
        try:
            self.bpod_modules.module_write(ambient_module, 'R')
            reply = self.bpod_modules.module_read(ambient_module, 12)
        finally:
            # stop relaying even if the write/read fails
            ambient_module.stop_module_relay()

        # the 12-byte reply is decoded as three float32 values; the pressure value is divided by 100
        return {
            'Temperature_C': np.frombuffer(bytes(reply[:4]), np.float32)[0],
            'AirPressure_mb': np.frombuffer(bytes(reply[4:8]), np.float32)[0] / 100,
            'RelativeHumidity': np.frombuffer(bytes(reply[8:]), np.float32)[0],
        }

    def flush(self):
        """Flushes valve 1."""
        self.toggle_valve()

    def toggle_valve(self, duration: float | None = None):
        """
        Flush valve 1 for specified duration.

        Parameters
        ----------
        duration : float, optional
            Duration of valve opening in seconds. If None, the valve stays open until ENTER is pressed.
        """
        if duration is None:
            self.open_valve(open=True, valve_number=1)
            input('Press ENTER when done.')
            self.open_valve(open=False, valve_number=1)
        else:
            self.pulse_valve(open_time_s=duration)

    def open_valve(self, open: bool, valve_number: int = 1):
        """Open (``open=True``) or close a valve via manual override."""
        self.manual_override(self.ChannelTypes.OUTPUT, self.ChannelNames.VALVE, valve_number, open)

    def pulse_valve(self, open_time_s: float, valve: str = 'Valve1'):
        """Open `valve` for `open_time_s` seconds using a single-state state machine."""
        sma = StateMachine(self)
        sma.add_state(
            state_name='flush', state_timer=open_time_s, state_change_conditions={'Tup': 'exit'}, output_actions=[(valve, 255)]
        )
        self.send_state_machine(sma)
        self.run_state_machine(sma)

    @validate_call()
    def pulse_valve_repeatedly(
        self, repetitions: PositiveInt, open_time_s: PositiveFloat, close_time_s: PositiveFloat = 0.2, valve: str = 'Valve1'
    ) -> int:
        """Alternate between opening and closing `valve` until it has been opened `repetitions` times.

        Parameters
        ----------
        repetitions : int
            Number of openings after which the trial is stopped (checked on entering the 'close' state).
        open_time_s : float
            Duration of each opening in seconds.
        close_time_s : float, optional
            Duration of each closed period in seconds.
        valve : str, optional
            Name of the valve output channel.

        Returns
        -------
        int
            Number of times the 'open' state was entered (counted via softcode 1).
        """
        counter = 0

        def softcode_handler(softcode: int):
            nonlocal counter
            if softcode == 1:  # sent on entering 'open'
                counter += 1
            elif softcode == 2 and counter >= repetitions:  # sent on entering 'close'
                self.stop_trial()

        # temporarily replace the softcode handler; it is restored even if the state machine fails
        original_softcode_handler = self.softcode_handler_function
        self.softcode_handler_function = softcode_handler
        try:
            sma = StateMachine(self)
            sma.add_state(
                state_name='open',
                state_timer=open_time_s,
                state_change_conditions={'Tup': 'close'},
                output_actions=[(valve, 255), ('SoftCode', 1)],
            )
            sma.add_state(
                state_name='close',
                state_timer=close_time_s,
                state_change_conditions={'Tup': 'open'},
                output_actions=[('SoftCode', 2)],
            )
            self.send_state_machine(sma)
            self.run_state_machine(sma)
        finally:
            self.softcode_handler_function = original_softcode_handler
        return counter

    def set_status_led(self, state: bool) -> bool:
        """Enable or disable the Bpod status LED.

        Returns
        -------
        bool
            True if the device acknowledged the command (replied with 1), False otherwise. On failure the
            serial buffers are reset and a warning is logged.
        """
        if self.can_control_led and self._arcom is not None:
            try:
                log.debug(f'{"en" if state else "dis"}abling Bpod Status LED')
                command = struct.pack('cB', b':', state)
                self._arcom.serial_object.write(command)
                if self._arcom.read_uint8() == 1:
                    return True
            except (serial.SerialException, struct.error):
                pass  # treated as "LED control not supported" below
            self._arcom.serial_object.reset_input_buffer()
            self._arcom.serial_object.reset_output_buffer()
            log.warning('Bpod device does not support control of the status LED. Please update firmware.')
        return False

    def valve(self, valve_id: int, state: bool):
        """Set valve `valve_id` to `state` (True = open) via manual override."""
        self.manual_override(self.ChannelTypes.OUTPUT, self.ChannelNames.VALVE, valve_id, state)

    @validate_call
    def register_softcodes(self, softcode_dict: dict[int, Callable]) -> None:
        """
        Register softcodes to be used in the state machine.

        Parameters
        ----------
        softcode_dict : dict[int, Callable]
            dictionary of int keys with callables as values; the callable for a received softcode is called
            without arguments
        """
        self.softcodes = softcode_dict
        self.softcode_handler_function = lambda code: softcode_dict[code]()
323

324

325
class RotaryEncoderModule(PybpodRotaryEncoderModule):
    """Rotary encoder module configured from hardware settings, task thresholds and gain."""

    _name = 'Rotary Encoder Module'

    # class-level defaults; replaced by per-instance values in __init__
    ENCODER_EVENTS = list()
    THRESHOLD_EVENTS = dict()

    def __init__(self, settings: HardwareSettingsRotaryEncoder, thresholds_deg: list[float], gain: float):
        """
        Parameters
        ----------
        settings : HardwareSettingsRotaryEncoder
            Hardware settings; ``COM_ROTARY_ENCODER`` and ``WHEEL_DIAMETER_MM`` are used.
        thresholds_deg : list[float]
            Thresholds in degrees; converted to wheel degrees in `write_parameters`.
        gain : float
            Gain in deg/mm used to convert thresholds to wheel travel.
        """
        super().__init__()
        self.settings = settings

        # degrees of wheel rotation per mm travelled along the wheel circumference
        self._wheel_degree_per_mm = 360.0 / (self.settings.WHEEL_DIAMETER_MM * np.pi)
        self.thresholds_deg = thresholds_deg
        self.gain = gain
        # one event name per threshold, numbered from 1
        self.ENCODER_EVENTS = [f'RotaryEncoder1_{x + 1}' for x in range(len(thresholds_deg))]
        self.THRESHOLD_EVENTS = dict(zip(thresholds_deg, self.ENCODER_EVENTS, strict=False))

    def open(self, _=None):
        """Open the serial connection on ``settings.COM_ROTARY_ENCODER`` (the argument is ignored).

        Raises
        ------
        ValueError
            If no port is configured.
        SerialException
            If the port cannot be opened.
        Exception
            If the module does not return the handshake.
        """
        if self.settings.COM_ROTARY_ENCODER is None:
            raise ValueError(
                'The value for device_rotary_encoder:COM_ROTARY_ENCODER in settings/hardware_settings.yaml is null. '
                'Please provide a valid port name.'
            )
        try:
            super().open(self.settings.COM_ROTARY_ENCODER)
        except SerialException as e:
            raise SerialException(
                f'The {self._name} on port {self.settings.COM_ROTARY_ENCODER} is already in use. This is '
                'usually due to a Bonsai process running on the computer. Make sure all Bonsai windows are closed '
                'prior to running the task.'
            ) from e
        except Exception as e:
            raise Exception(f'The {self._name} on port {self.settings.COM_ROTARY_ENCODER} did not return the handshake.') from e
        log.debug(f'Successfully opened serial connection to {self._name} on port {self.settings.COM_ROTARY_ENCODER}')

    def write_parameters(self):
        """Zero the position, then set and enable the thresholds and enable event transmission."""
        # threshold (deg) / gain (deg/mm) -> wheel travel in mm; * wheel degrees per mm -> wheel rotation in degrees
        scaled_thresholds_deg = [x / self.gain * self._wheel_degree_per_mm for x in self.thresholds_deg]
        # 8 enable flags, the first len(thresholds) set to True (assumes at most 8 thresholds)
        enabled_thresholds = [(x < len(scaled_thresholds_deg)) for x in range(8)]

        log.info(
            f'Thresholds for {self._name} scaled to {", ".join([f"{x:0.2f}" for x in scaled_thresholds_deg])} '
            f'using gain of {self.gain:0.1f} deg/mm and wheel diameter of {self.settings.WHEEL_DIAMETER_MM:0.1f} mm.'
        )
        self.set_zero_position()
        self.set_thresholds(scaled_thresholds_deg)
        self.enable_thresholds(enabled_thresholds)
        self.enable_evt_transmission()

    def close(self):
        """Close the serial connection if one is open."""
        # `arcom` may not exist if the parent constructor failed; getattr keeps __del__ from raising AttributeError
        if getattr(self, 'arcom', None) is not None:
            log.debug(f'Closing serial connection to {self._name} on port {self.settings.COM_ROTARY_ENCODER}')
            super().close()

    def __del__(self):
        self.close()
379

380

381
def sound_device_factory(output: Literal['xonar', 'harp', 'hifi', 'sysdefault'] = 'sysdefault', samplerate: int | None = None):
2✔
382
    """
383
    Will import, configure, and return sounddevice module to play sounds using onboard sound card.
384

385
    Parameters
386
    ----------
387
    output
388
        defaults to "sysdefault", should be 'xonar' or 'harp'
389
    samplerate
390
        audio sample rate, defaults to 44100
391
    """
392
    match output:
2✔
393
        case 'xonar':
2✔
UNCOV
394
            samplerate = samplerate if samplerate is not None else 192000
×
UNCOV
395
            devices = sd.query_devices()
×
UNCOV
396
            sd.default.device = next((i for i, d in enumerate(devices) if 'XONAR SOUND CARD(64)' in d['name']), None)
×
UNCOV
397
            sd.default.latency = 'low'
×
UNCOV
398
            sd.default.channels = 2
×
UNCOV
399
            channels = 'L+TTL'
×
UNCOV
400
            sd.default.samplerate = samplerate
×
401
        case 'harp':
2✔
UNCOV
402
            samplerate = samplerate if samplerate is not None else 96000
×
UNCOV
403
            sd.default.samplerate = samplerate
×
UNCOV
404
            sd.default.channels = 2
×
UNCOV
405
            channels = 'stereo'
×
406
        case 'hifi':
2✔
UNCOV
407
            samplerate = samplerate if samplerate is not None else 192000
×
UNCOV
408
            channels = 'stereo'
×
409
        case 'sysdefault':
2✔
410
            samplerate = samplerate if samplerate is not None else 44100
2✔
411
            sd.default.latency = 'low'
2✔
412
            sd.default.channels = 2
2✔
413
            sd.default.samplerate = samplerate
2✔
414
            channels = 'stereo'
2✔
UNCOV
415
        case _:
×
UNCOV
416
            raise ValueError()
×
417
    return sd, samplerate, channels
2✔
418

419

420
def restart_com_port(regexp: str) -> bool:
    """
    Restart the communication port(s) matching the specified regular expression.

    Parameters
    ----------
    regexp : str
        A regular expression used to match the communication port(s). It is passed to
        ``serial.tools.list_ports.grep``, which matches it against each port's name, description and hardware ID.

    Returns
    -------
    bool
        True if every matched port that could be mapped to a device instance ID was restarted successfully
        (also True if no port matched), False otherwise.

    Raises
    ------
    NotImplementedError
        If the operating system is not Windows.
    FileNotFoundError
        If the required 'pnputil.exe' executable is not found.
    subprocess.CalledProcessError
        If listing the connected devices with pnputil fails.

    Examples
    --------
    >>> restart_com_port("COM3")  # Restart the port(s) matching 'COM3'
    True

    >>> restart_com_port("COM[1-3]")  # Restart the ports matching 'COM1', 'COM2' or 'COM3'
    True
    """
    if os.name != 'nt':
        raise NotImplementedError('Only implemented for Windows OS.')
    # shutil.which returns None when pnputil is not on PATH (Path(None) would raise a TypeError instead)
    pnputil_path = shutil.which('pnputil')
    if pnputil_path is None:
        raise FileNotFoundError('Could not find pnputil.exe')
    file_pnputil = Path(pnputil_path)
    if not file_pnputil.exists():
        raise FileNotFoundError('Could not find pnputil.exe')

    ports = list(list_ports.grep(regexp))
    if not ports:
        return True

    # the list of connected devices is the same for every port, so query pnputil only once
    pnputil_output = subprocess.check_output([file_pnputil, '/enum-devices', '/connected', '/class', 'ports'])
    pnputil_output = pnputil_output.decode(errors='replace')

    result = []
    for port in ports:
        if not port.serial_number:
            continue  # without a serial number the port cannot be located in the pnputil output
        # The whitespace-delimited token containing the serial number is used as the device instance ID.
        # The serial number is escaped so that any regex metacharacters in it are matched literally.
        instance_id = re.search(rf'(\S*{re.escape(port.serial_number)}\S*)', pnputil_output)
        if instance_id is None:
            continue
        # The ID is passed as a plain list element (subprocess quotes it for the command line), and `run` with
        # check=False reports a failed restart as False instead of raising CalledProcessError.
        completed = subprocess.run(
            [file_pnputil, '/restart-device', instance_id.group(1)],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.STDOUT,
            check=False,
        )
        result.append(completed.returncode == 0)
    return all(result)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc