• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

int-brain-lab / iblrig / 11407201950

18 Oct 2024 04:12PM UTC coverage: 47.898% (+1.1%) from 46.79%
11407201950

Pull #730

github

86ab26
web-flow
Merge 9801a3e94 into 0f4a57326
Pull Request #730: 8.24.4

47 of 68 new or added lines in 8 files covered. (69.12%)

1013 existing lines in 22 files now uncovered.

4170 of 8706 relevant lines covered (47.9%)

0.96 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

52.59
/iblrig/hardware.py
1
"""Hardware classes used to interact with modules."""
2

3
import logging
2✔
4
import os
2✔
5
import re
2✔
6
import shutil
2✔
7
import struct
2✔
8
import subprocess
2✔
9
import threading
2✔
10
import time
2✔
11
from collections.abc import Callable
2✔
12
from enum import IntEnum
2✔
13
from pathlib import Path
2✔
14
from typing import Annotated, Literal
2✔
15

16
import numpy as np
2✔
17
import serial
2✔
18
import sounddevice as sd
2✔
19
from annotated_types import Ge, Le
2✔
20
from pydantic import PositiveFloat, PositiveInt, validate_call
2✔
21
from serial.serialutil import SerialException
2✔
22
from serial.tools import list_ports
2✔
23

24
from iblrig.pydantic_definitions import HardwareSettingsRotaryEncoder
2✔
25
from iblrig.tools import static_vars
2✔
26
from iblutil.util import Bunch
2✔
27
from pybpod_rotaryencoder_module.module import RotaryEncoder as PybpodRotaryEncoder
2✔
28
from pybpod_rotaryencoder_module.module_api import RotaryEncoderModule as PybpodRotaryEncoderModule
2✔
29
from pybpodapi.bpod.bpod_io import BpodIO
2✔
30
from pybpodapi.bpod_modules.bpod_module import BpodModule
2✔
31
from pybpodapi.state_machine import StateMachine
2✔
32

33
# Softcodes sent from the state machine to the computer.
# The functional IntEnum API numbers members from 1:
# STOP_SOUND=1, PLAY_TONE=2, PLAY_NOISE=3, TRIGGER_CAMERA=4.
SOFTCODE = IntEnum('SOFTCODE', 'STOP_SOUND PLAY_TONE PLAY_NOISE TRIGGER_CAMERA')

# some annotated types
# integer constrained to the unsigned 8-bit range [0, 255]
Uint8 = Annotated[int, Ge(0), Le(255)]
# serial-message index, constrained to [1, 255]
ActionIdx = Annotated[int, Ge(1), Le(255)]

log = logging.getLogger(__name__)
40

41

42
class Bpod(BpodIO):
    """Bpod state-machine interface with one cached instance per serial port.

    Constructing ``Bpod(port)`` a second time with the same serial port returns the
    already existing object (see ``__new__``). On top of ``BpodIO`` this class adds
    helpers for module lookup, serial-message definition, valve control,
    ambient-sensor readout, status-LED control and softcode registration.
    """

    can_control_led = True
    softcodes: dict[int, Callable] | None = None
    # registry of created instances, keyed by the serial port passed to the constructor
    _instances = {}
    # guards access to ``_instances``
    _lock = threading.RLock()
    _is_initialized = False

    def __new__(cls, *args, **kwargs):
        """Return the cached instance for the requested serial port, creating it if needed."""
        serial_port = args[0] if len(args) > 0 else ''
        serial_port = kwargs.get('serial_port', serial_port)
        with cls._lock:
            instance = Bpod._instances.get(serial_port, None)
            # identity test: a cached object must be reused regardless of its truthiness
            if instance is not None:
                return instance
            instance = super().__new__(cls)
            Bpod._instances[serial_port] = instance
            return instance

    def __init__(self, *args, skip_initialization: bool = False, **kwargs):
        """Connect to the Bpod, retrying once if the first attempt fails.

        Parameters
        ----------
        *args, **kwargs
            Forwarded to ``BpodIO.__init__``. The first positional argument (or the
            ``serial_port`` keyword) also selects the cached instance in ``__new__``.
        skip_initialization : bool, optional
            If True and this instance has already been initialized, return immediately.
            IMPORTANT: only use this for non-critical tasks (e.g., flushing valve from GUI).

        Raises
        ------
        serial.serialutil.SerialException
            If the retry fails with a SerialException or UnicodeDecodeError.
        """
        # skip initialization if it has already been performed before
        if skip_initialization and self._is_initialized:
            return

        # try to instantiate once for nothing
        try:
            super().__init__(*args, **kwargs)
        except Exception:
            log.warning("Couldn't instantiate BPOD, retrying once...")
            time.sleep(1)
            try:
                super().__init__(*args, **kwargs)
            except (serial.serialutil.SerialException, UnicodeDecodeError) as e:
                log.error(e)
                raise serial.serialutil.SerialException(
                    'The communication with Bpod is established but the Bpod is not responsive. '
                    'This is usually indicated by the device with a green light. '
                    'Please unplug the Bpod USB cable from the computer and plug it back in to start the task. '
                ) from e
        self.serial_messages = {}
        self.actions = Bunch({})
        self.can_control_led = self.set_status_led(True)
        self._is_initialized = True

    def close(self) -> None:
        """Close the connection and mark the instance as not initialized."""
        super().close()
        self._is_initialized = False

    def __del__(self):
        # ``serial_port`` may be missing if ``__init__`` failed before BpodIO assigned it
        if not hasattr(self, 'serial_port'):
            return
        with self._lock:
            Bpod._instances.pop(self.serial_port, None)

    @property
    def is_connected(self):
        """bool: True if the module list has been populated (i.e., ``self.modules`` is not None)."""
        return self.modules is not None

    @property
    def rotary_encoder(self):
        """BpodModule | None: the first module whose name starts with 'RotaryEncoder'."""
        return self.get_module('rotary_encoder')

    @property
    def sound_card(self):
        """BpodModule | None: the first module whose name starts with 'SoundCard'."""
        return self.get_module('sound_card')

    @property
    def ambient_module(self):
        """BpodModule | None: the first module whose name starts with 'AmbientModule'."""
        return self.get_module('^AmbientModule')

    def get_module(self, module_name: str) -> BpodModule | None:
        """Get module by name.

        Parameters
        ----------
        module_name : str
            Regular Expression for matching a module name. The aliases 're' /
            'rotary_encoder' and 'sc' / 'sound_card' are expanded to
            ``^RotaryEncoder`` and ``^SoundCard`` respectively.

        Returns
        -------
        BpodModule | None
            First matching module or None
        """
        if self.modules is None:
            return None
        if module_name in ['re', 'rotary_encoder']:
            module_name = r'^RotaryEncoder'
        elif module_name in ['sc', 'sound_card']:
            module_name = r'^SoundCard'
        modules = [x for x in self.modules if re.match(module_name, x.name)]
        if len(modules) > 1:
            log.critical(f'Found several Bpod modules matching `{module_name}`. Using first match: `{modules[0].name}`')
        if len(modules) > 0:
            return modules[0]
        return None

    @validate_call(config={'arbitrary_types_allowed': True})
    def _define_message(self, module: BpodModule | int, message: list[Uint8]) -> ActionIdx:
        """Define a serial message to be sent to a Bpod module as an output action within a state.

        The message is loaded onto the Bpod under the next free index and recorded in
        ``self.serial_messages``.

        Parameters
        ----------
        module : BpodModule | int
            The targeted module, defined as a BpodModule instance or the module's port index
        message : list[int]
            The message to be sent - a list of 8-bit integers (0-255)

        Returns
        -------
        int
            The index of the serial message (1-255)

        Raises
        ------
        pydantic.ValidationError
            If module is not an instance of BpodModule or int, or if message contains
            values outside 0-255 (raised by ``validate_call``).

        Examples
        --------
        >>> id_msg_bonsai_show_stim = self._define_message(self.rotary_encoder, [ord("#"), 2])
        will then be used as such in StateMachine:
        >>> output_actions=[("Serial1", id_msg_bonsai_show_stim)]
        """
        if isinstance(module, BpodModule):
            module = module.serial_port
        message_id = len(self.serial_messages) + 1
        self.load_serial_message(module, message_id, message)
        self.serial_messages.update({message_id: {'target_module': module, 'message': message}})
        return message_id

    @validate_call(config={'arbitrary_types_allowed': True})
    def define_xonar_sounds_actions(self):
        """Register sound actions that are dispatched to the computer via softcodes."""
        self.actions.update(
            {
                'play_tone': ('SoftCode', SOFTCODE.PLAY_TONE),
                'play_noise': ('SoftCode', SOFTCODE.PLAY_NOISE),
                'stop_sound': ('SoftCode', SOFTCODE.STOP_SOUND),
            }
        )

    def define_harp_sounds_actions(self, module: BpodModule, go_tone_index: int = 2, noise_index: int = 3) -> None:
        """Register sound actions that send serial messages to a Harp sound card module.

        Parameters
        ----------
        module : BpodModule
            The sound card module.
        go_tone_index : int, optional
            Sound index sent with the 'P' command for 'play_tone'.
        noise_index : int, optional
            Sound index sent with the 'P' command for 'play_noise'.
        """
        module_port = f"Serial{module.serial_port if module is not None else ''}"
        self.actions.update(
            {
                'play_tone': (module_port, self._define_message(module, [ord('P'), go_tone_index])),
                'play_noise': (module_port, self._define_message(module, [ord('P'), noise_index])),
                'stop_sound': (module_port, ord('X')),
            }
        )

    def define_rotary_encoder_actions(self, module: BpodModule | None = None) -> None:
        """Register rotary-encoder reset and Bonsai stimulus-control actions.

        Parameters
        ----------
        module : BpodModule, optional
            The rotary encoder module. Defaults to ``self.rotary_encoder``.
        """
        if module is None:
            module = self.rotary_encoder
        module_port = f"Serial{module.serial_port if module is not None else ''}"
        self.actions.update(
            {
                'rotary_encoder_reset': (
                    module_port,
                    self._define_message(
                        module, [PybpodRotaryEncoder.COM_SETZEROPOS, PybpodRotaryEncoder.COM_ENABLE_ALLTHRESHOLDS]
                    ),
                ),
                'bonsai_hide_stim': (module_port, self._define_message(module, [ord('#'), 1])),
                'bonsai_show_stim': (module_port, self._define_message(module, [ord('#'), 8])),
                'bonsai_closed_loop': (module_port, self._define_message(module, [ord('#'), 3])),
                'bonsai_freeze_stim': (module_port, self._define_message(module, [ord('#'), 4])),
                'bonsai_show_center': (module_port, self._define_message(module, [ord('#'), 5])),
            }
        )

    def get_ambient_sensor_reading(self):
        """Read temperature, air pressure and relative humidity from the ambient module.

        Returns
        -------
        dict
            Keys 'Temperature_C', 'AirPressure_mb' and 'RelativeHumidity'. All values
            are NaN if no ambient module is connected.
        """
        # look the module up once instead of re-running the regex match per access
        ambient_module = self.ambient_module
        if ambient_module is None:
            return {
                'Temperature_C': np.nan,
                'AirPressure_mb': np.nan,
                'RelativeHumidity': np.nan,
            }
        ambient_module.start_module_relay()
        try:
            self.bpod_modules.module_write(ambient_module, 'R')
            # 12 bytes: three consecutive float32 values
            reply = self.bpod_modules.module_read(ambient_module, 12)
        finally:
            # always stop relaying, even if the exchange fails
            ambient_module.stop_module_relay()

        return {
            'Temperature_C': np.frombuffer(bytes(reply[:4]), np.float32)[0],
            # raw value divided by 100 (presumably Pa -> mbar; confirm against module firmware)
            'AirPressure_mb': np.frombuffer(bytes(reply[4:8]), np.float32)[0] / 100,
            'RelativeHumidity': np.frombuffer(bytes(reply[8:]), np.float32)[0],
        }

    def flush(self):
        """Flushes valve 1."""
        self.toggle_valve()

    def toggle_valve(self, duration: float | None = None):
        """
        Flush valve 1, either until the user confirms or for a specified duration.

        Parameters
        ----------
        duration : float, optional
            Duration of valve opening in seconds. If None, the valve stays open until
            ENTER is pressed.
        """
        if duration is None:
            self.open_valve(open=True, valve_number=1)
            try:
                input('Press ENTER when done.')
            finally:
                # close the valve even if the prompt is interrupted (e.g. Ctrl+C)
                self.open_valve(open=False, valve_number=1)
        else:
            self.pulse_valve(open_time_s=duration)

    def open_valve(self, open: bool, valve_number: int = 1):
        """Open (``open=True``) or close (``open=False``) a valve via manual override."""
        self.manual_override(self.ChannelTypes.OUTPUT, self.ChannelNames.VALVE, valve_number, open)

    def pulse_valve(self, open_time_s: float, valve: str = 'Valve1'):
        """Open a valve for ``open_time_s`` seconds using a single-state state machine.

        Parameters
        ----------
        open_time_s : float
            Opening duration in seconds.
        valve : str, optional
            Name of the valve output channel.
        """
        sma = StateMachine(self)
        sma.add_state(
            state_name='flush', state_timer=open_time_s, state_change_conditions={'Tup': 'exit'}, output_actions=[(valve, 255)]
        )
        self.send_state_machine(sma)
        self.run_state_machine(sma)

    @validate_call()
    def pulse_valve_repeatedly(
        self, repetitions: PositiveInt, open_time_s: PositiveFloat, close_time_s: PositiveFloat = 0.2, valve: str = 'Valve1'
    ) -> int:
        """Open and close a valve repeatedly.

        The state machine alternates between an 'open' state (valve on, softcode 1)
        and a 'close' state (softcode 2). Softcode 1 increments a counter; softcode 2
        stops the trial once the counter has reached ``repetitions``.

        Parameters
        ----------
        repetitions : int
            Number of valve openings.
        open_time_s : float
            Duration of each opening in seconds.
        close_time_s : float, optional
            Duration of each closed phase in seconds.
        valve : str, optional
            Name of the valve output channel.

        Returns
        -------
        int
            Number of openings counted via softcode 1.
        """
        counter = 0

        def softcode_handler(softcode: int):
            nonlocal counter
            if softcode == 1:
                counter += 1
            elif softcode == 2 and counter >= repetitions:
                self.stop_trial()

        original_softcode_handler = self.softcode_handler_function
        self.softcode_handler_function = softcode_handler
        try:
            sma = StateMachine(self)
            sma.add_state(
                state_name='open',
                state_timer=open_time_s,
                state_change_conditions={'Tup': 'close'},
                output_actions=[(valve, 255), ('SoftCode', 1)],
            )
            sma.add_state(
                state_name='close',
                state_timer=close_time_s,
                state_change_conditions={'Tup': 'open'},
                output_actions=[('SoftCode', 2)],
            )
            self.send_state_machine(sma)
            self.run_state_machine(sma)
        finally:
            # restore the caller's handler even if the state machine fails
            self.softcode_handler_function = original_softcode_handler
        return counter

    @static_vars(supported=True)
    def set_status_led(self, state: bool) -> bool:
        """Enable or disable the Bpod status LED.

        Parameters
        ----------
        state : bool
            True to enable, False to disable.

        Returns
        -------
        bool
            True if the device replied with 1. False otherwise, including when LED
            control is disabled or there is no serial connection.
        """
        if self.can_control_led and self._arcom is not None:
            try:
                log.debug(f'{"en" if state else "dis"}abling Bpod Status LED')
                command = struct.pack('cB', b':', state)
                self._arcom.serial_object.write(command)
                if self._arcom.read_uint8() == 1:
                    return True
            except serial.SerialException:
                # handled below: buffers are reset and a warning is logged
                pass
            # discard any partial reply so later communication is not desynchronized
            self._arcom.serial_object.reset_input_buffer()
            self._arcom.serial_object.reset_output_buffer()
            log.warning('Bpod device does not support control of the status LED. Please update firmware.')
        return False

    def valve(self, valve_id: int, state: bool):
        """Set valve ``valve_id`` to ``state`` via manual override."""
        self.manual_override(self.ChannelTypes.OUTPUT, self.ChannelNames.VALVE, valve_id, state)

    @validate_call
    def register_softcodes(self, softcode_dict: dict[int, Callable]) -> None:
        """
        Register softcodes to be used in the state machine.

        Parameters
        ----------
        softcode_dict : dict[int, Callable]
            dictionary of int keys with callables as values; the callable for a
            received softcode is called without arguments
        """
        self.softcodes = softcode_dict
        self.softcode_handler_function = lambda code: softcode_dict[code]()

326

327
class RotaryEncoderModule(PybpodRotaryEncoderModule):
    """Rotary encoder module configured from the rig's hardware settings.

    Thresholds are provided in degrees and converted to wheel degrees using the
    ``gain`` (deg/mm) and the wheel diameter from ``settings``.
    """

    _name = 'Rotary Encoder Module'

    ENCODER_EVENTS = []
    THRESHOLD_EVENTS = {}

    def __init__(self, settings: HardwareSettingsRotaryEncoder, thresholds_deg: list[float], gain: float):
        super().__init__()
        self.settings = settings

        wheel_circumference_mm = self.settings.WHEEL_DIAMETER_MM * np.pi
        self._wheel_degree_per_mm = 360.0 / wheel_circumference_mm
        self.thresholds_deg = thresholds_deg
        self.gain = gain
        # one event name per threshold: RotaryEncoder1_1, RotaryEncoder1_2, ...
        self.ENCODER_EVENTS = [f'RotaryEncoder1_{number}' for number in range(1, len(thresholds_deg) + 1)]
        self.THRESHOLD_EVENTS = dict(zip(thresholds_deg, self.ENCODER_EVENTS, strict=False))

    def open(self, _=None):
        """Open the serial connection on the port configured in the hardware settings.

        The positional argument is ignored; the port always comes from
        ``settings.COM_ROTARY_ENCODER``.
        """
        port = self.settings.COM_ROTARY_ENCODER
        if port is None:
            raise ValueError(
                'The value for device_rotary_encoder:COM_ROTARY_ENCODER in settings/hardware_settings.yaml is null. '
                'Please provide a valid port name.'
            )
        try:
            super().open(port)
        except SerialException as e:
            raise SerialException(
                f'The {self._name} on port {port} is already in use. This is '
                'usually due to a Bonsai process running on the computer. Make sure all Bonsai windows are closed '
                'prior to running the task.'
            ) from e
        except Exception as e:
            raise Exception(f'The {self._name} on port {port} did not return the handshake.') from e
        log.debug(f'Successfully opened serial connection to {self._name} on port {port}')

    def write_parameters(self):
        """Scale the thresholds to wheel degrees and write them to the module."""
        scaled_thresholds_deg = [threshold / self.gain * self._wheel_degree_per_mm for threshold in self.thresholds_deg]
        # the module exposes 8 threshold slots; enable the first len(thresholds)
        n_thresholds = len(scaled_thresholds_deg)
        enabled_thresholds = [slot < n_thresholds for slot in range(8)]

        thresholds_text = ', '.join(f'{value:0.2f}' for value in scaled_thresholds_deg)
        log.info(
            f'Thresholds for {self._name} scaled to {thresholds_text} '
            f'using gain of {self.gain:0.1f} deg/mm and wheel diameter of {self.settings.WHEEL_DIAMETER_MM:0.1f} mm.'
        )
        self.set_zero_position()
        self.set_thresholds(scaled_thresholds_deg)
        self.enable_thresholds(enabled_thresholds)
        self.enable_evt_transmission()

    def close(self):
        """Close the serial connection if one was opened (i.e., ``arcom`` exists)."""
        if not hasattr(self, 'arcom'):
            return
        log.debug(f'Closing serial connection to {self._name} on port {self.settings.COM_ROTARY_ENCODER}')
        super().close()

    def __del__(self):
        self.close()
381

382

383
def sound_device_factory(output: Literal['xonar', 'harp', 'hifi', 'sysdefault'] = 'sysdefault', samplerate: int | None = None):
2✔
384
    """
385
    Will import, configure, and return sounddevice module to play sounds using onboard sound card.
386

387
    Parameters
388
    ----------
389
    output
390
        defaults to "sysdefault", should be 'xonar' or 'harp'
391
    samplerate
392
        audio sample rate, defaults to 44100
393
    """
394
    match output:
2✔
395
        case 'xonar':
2✔
UNCOV
396
            samplerate = samplerate if samplerate is not None else 192000
×
UNCOV
397
            devices = sd.query_devices()
×
UNCOV
398
            sd.default.device = next((i for i, d in enumerate(devices) if 'XONAR SOUND CARD(64)' in d['name']), None)
×
UNCOV
399
            sd.default.latency = 'low'
×
UNCOV
400
            sd.default.channels = 2
×
UNCOV
401
            channels = 'L+TTL'
×
UNCOV
402
            sd.default.samplerate = samplerate
×
403
        case 'harp':
2✔
UNCOV
404
            samplerate = samplerate if samplerate is not None else 96000
×
UNCOV
405
            sd.default.samplerate = samplerate
×
UNCOV
406
            sd.default.channels = 2
×
UNCOV
407
            channels = 'stereo'
×
408
        case 'hifi':
2✔
UNCOV
409
            samplerate = samplerate if samplerate is not None else 192000
×
UNCOV
410
            channels = 'stereo'
×
411
        case 'sysdefault':
2✔
412
            samplerate = samplerate if samplerate is not None else 44100
2✔
413
            sd.default.latency = 'low'
2✔
414
            sd.default.channels = 2
2✔
415
            sd.default.samplerate = samplerate
2✔
416
            channels = 'stereo'
2✔
UNCOV
417
        case _:
×
UNCOV
418
            raise ValueError()
×
419
    return sd, samplerate, channels
2✔
420

421

422
def restart_com_port(regexp: str) -> bool:
    """
    Restart the communication port(s) matching the specified regular expression.

    Parameters
    ----------
    regexp : str
        A regular expression matched (via ``serial.tools.list_ports.grep``) against
        the port name, description and hardware ID.

    Returns
    -------
    bool
        Returns True if all matched ports are successfully restarted, False otherwise.
        Ports without a serial number, or without a matching pnputil instance ID, are skipped.

    Raises
    ------
    NotImplementedError
        If the operating system is not Windows.

    FileNotFoundError
        If the required 'pnputil.exe' executable is not found.

    Examples
    --------
    >>> restart_com_port("COM3")  # Restart the communication port named 'COM3'
    True

    >>> restart_com_port("COM[1-3]")  # Restart communication ports 'COM1', 'COM2', 'COM3'
    True
    """
    if not os.name == 'nt':
        raise NotImplementedError('Only implemented for Windows OS.')
    pnputil = shutil.which('pnputil')
    # shutil.which returns None when the executable is not on PATH
    if pnputil is None or not (file_pnputil := Path(pnputil)).exists():
        raise FileNotFoundError('Could not find pnputil.exe')

    ports = list(list_ports.grep(regexp))
    if not ports:
        return True

    # the list of connected port devices does not depend on the loop variable: query it once
    pnputil_output = subprocess.check_output([file_pnputil, '/enum-devices', '/connected', '/class', 'ports']).decode()

    result = []
    for port in ports:
        # an empty serial number would match an arbitrary device; None cannot be searched for
        if not port.serial_number:
            continue
        instance_id = re.search(rf'(\S*{re.escape(port.serial_number)}\S*)', pnputil_output)
        if instance_id is None:
            continue
        # pass the instance ID as a plain list element: subprocess does the quoting itself
        completed = subprocess.run(
            [file_pnputil, '/restart-device', instance_id.group(1)], stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT
        )
        result.append(completed.returncode == 0)
    return all(result)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc