• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

int-brain-lab / iblrig / 9032957364

10 May 2024 01:25PM UTC coverage: 48.538% (+1.7%) from 46.79%
9032957364

Pull #643

github

74d2ec
web-flow
Merge aebf2c9af into ec2d8e4fe
Pull Request #643: 8.19.0

377 of 1074 new or added lines in 38 files covered. (35.1%)

977 existing lines in 19 files now uncovered.

3253 of 6702 relevant lines covered (48.54%)

0.97 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

52.77
/iblrig/hardware.py
1
"""
2
This modules contains hardware classes used to interact with modules.
3
"""
4

5
import logging
2✔
6
import os
2✔
7
import re
2✔
8
import shutil
2✔
9
import struct
2✔
10
import subprocess
2✔
11
import threading
2✔
12
import time
2✔
13
from collections.abc import Callable
2✔
14
from enum import IntEnum
2✔
15
from pathlib import Path
2✔
16
from typing import Annotated, Literal
2✔
17

18
import numpy as np
2✔
19
import serial
2✔
20
import sounddevice as sd
2✔
21
from annotated_types import Ge, Le
2✔
22
from pydantic import validate_call
2✔
23
from serial.tools import list_ports
2✔
24

25
from iblrig.tools import static_vars
2✔
26
from iblutil.util import Bunch
2✔
27
from pybpod_rotaryencoder_module.module import RotaryEncoder
2✔
28
from pybpod_rotaryencoder_module.module_api import RotaryEncoderModule
2✔
29
from pybpodapi.bpod.bpod_io import BpodIO
2✔
30
from pybpodapi.bpod_modules.bpod_module import BpodModule
2✔
31
from pybpodapi.state_machine import StateMachine
2✔
32

33
class SOFTCODE(IntEnum):
    """Softcode identifiers used in 'SoftCode' output actions of Bpod state machines."""

    STOP_SOUND = 1
    PLAY_TONE = 2
    PLAY_NOISE = 3
    TRIGGER_CAMERA = 4
34

35
# some annotated types, used as parameter/return annotations of `validate_call`-decorated methods below
# integer constrained to the range 0-255 (inclusive)
Uint8 = Annotated[int, Ge(0), Le(255)]
# serial message index constrained to the range 1-255 (inclusive)
ActionIdx = Annotated[int, Ge(1), Le(255)]

log = logging.getLogger(__name__)
40

41

42
class Bpod(BpodIO):
    """Bpod state machine interface, cached as one instance per serial port.

    ``__new__`` returns the instance already registered in ``Bpod._instances`` for the requested
    serial port (first positional argument or ``serial_port`` keyword argument), so constructing
    ``Bpod`` repeatedly with the same port yields the same object.
    """

    can_control_led = True
    softcodes: dict[int, Callable] | None = None
    _instances = {}  # serial port -> cached Bpod instance
    _lock = threading.Lock()  # guards access to `_instances`
    _is_initialized = False

    def __new__(cls, *args, **kwargs):
        serial_port = args[0] if len(args) > 0 else ''
        serial_port = kwargs.get('serial_port', serial_port)
        with cls._lock:
            instance = Bpod._instances.get(serial_port, None)
            if instance is not None:
                return instance
            instance = super().__new__(cls)
            Bpod._instances[serial_port] = instance
            return instance

    def __init__(self, *args, skip_initialization: bool = False, **kwargs):
        """Connect to Bpod, retrying once after one second if the first attempt fails.

        Parameters
        ----------
        skip_initialization : bool
            If True and this (cached) instance has already been initialized, return immediately.
            IMPORTANT: only use this for non-critical tasks (e.g., flushing valve from GUI).

        Raises
        ------
        serial.serialutil.SerialException
            If the retry fails with a SerialException or UnicodeDecodeError.
        """
        if skip_initialization and self._is_initialized:
            return

        # try to instantiate once for nothing
        try:
            super().__init__(*args, **kwargs)
        except Exception:
            log.warning("Couldn't instantiate BPOD, retrying once...")
            time.sleep(1)
            try:
                super().__init__(*args, **kwargs)
            except (serial.serialutil.SerialException, UnicodeDecodeError) as e:
                log.error(e)
                raise serial.serialutil.SerialException(
                    'The communication with Bpod is established but the Bpod is not responsive. '
                    'This is usually indicated by the device with a green light. '
                    'Please unplug the Bpod USB cable from the computer and plug it back in to start the task. '
                ) from e
        self.serial_messages = {}
        self.actions = Bunch({})
        self.can_control_led = self.set_status_led(True)
        self._is_initialized = True

    def close(self) -> None:
        """Close the connection and mark the instance as no longer initialized."""
        super().close()
        self._is_initialized = False

    def __del__(self):
        # `serial_port` is only assigned by BpodIO.__init__; a failed initialization leaves it unset
        serial_port = getattr(self, 'serial_port', None)
        with self._lock:
            if serial_port in Bpod._instances:
                Bpod._instances.pop(serial_port)

    @property
    def is_connected(self):
        """True if the module list has been populated (i.e., ``self.modules`` is not None)."""
        return self.modules is not None

    @property
    def rotary_encoder(self):
        """First module whose name starts with 'RotaryEncoder', or None."""
        return self.get_module('rotary_encoder')

    @property
    def sound_card(self):
        """First module whose name starts with 'SoundCard', or None."""
        return self.get_module('sound_card')

    @property
    def ambient_module(self):
        """First module whose name starts with 'AmbientModule', or None."""
        return self.get_module('^AmbientModule')

    def get_module(self, module_name: str) -> BpodModule | None:
        """Get module by name

        Parameters
        ----------
        module_name : str
            Regular Expression for matching a module name. The aliases 're'/'rotary_encoder' and
            'sc'/'sound_card' are expanded to '^RotaryEncoder' and '^SoundCard' respectively.

        Returns
        -------
        BpodModule | None
            First matching module or None
        """
        if self.modules is None:
            return None
        if module_name in ['re', 'rotary_encoder']:
            module_name = r'^RotaryEncoder'
        elif module_name in ['sc', 'sound_card']:
            module_name = r'^SoundCard'
        modules = [x for x in self.modules if re.match(module_name, x.name)]
        if len(modules) > 1:
            log.critical(f'Found several Bpod modules matching `{module_name}`. Using first match: `{modules[0].name}`')
        return modules[0] if modules else None

    @validate_call(config={'arbitrary_types_allowed': True})
    def _define_message(self, module: BpodModule | int, message: list[Uint8]) -> ActionIdx:
        """Define a serial message to be sent to a Bpod module as an output action within a state

        Parameters
        ----------
        module : BpodModule | int
            The targeted module, defined as a BpodModule instance or the module's port index
        message : list[int]
            The message to be sent - a list of up to three 8-bit integers

        Returns
        -------
        int
            The index of the serial message (1-255)

        Raises
        ------
        TypeError
            If module is not an instance of BpodModule or int

        Examples
        --------
        >>> id_msg_bonsai_show_stim = self._define_message(self.rotary_encoder, [ord("#"), 2])
        will then be used as such in StateMachine:
        >>> output_actions=[("Serial1", id_msg_bonsai_show_stim)]
        """
        if isinstance(module, BpodModule):
            module = module.serial_port
        # message indices are allocated sequentially, starting at 1
        message_id = len(self.serial_messages) + 1
        self.load_serial_message(module, message_id, message)
        self.serial_messages.update({message_id: {'target_module': module, 'message': message}})
        return message_id

    @validate_call(config={'arbitrary_types_allowed': True})
    def define_xonar_sounds_actions(self):
        """Register 'play_tone', 'play_noise' and 'stop_sound' actions as SoftCode output actions."""
        self.actions.update(
            {
                'play_tone': ('SoftCode', SOFTCODE.PLAY_TONE),
                'play_noise': ('SoftCode', SOFTCODE.PLAY_NOISE),
                'stop_sound': ('SoftCode', SOFTCODE.STOP_SOUND),
            }
        )

    def define_harp_sounds_actions(self, module: BpodModule, go_tone_index: int = 2, noise_index: int = 3) -> None:
        """Register sound actions that send serial messages to the given (Harp) sound module.

        Parameters
        ----------
        module : BpodModule
            Target sound module.
        go_tone_index : int
            Sound index sent with the 'P' command for 'play_tone'.
        noise_index : int
            Sound index sent with the 'P' command for 'play_noise'.
        """
        module_port = f"Serial{module.serial_port if module is not None else ''}"
        self.actions.update(
            {
                'play_tone': (module_port, self._define_message(module, [ord('P'), go_tone_index])),
                'play_noise': (module_port, self._define_message(module, [ord('P'), noise_index])),
                'stop_sound': (module_port, ord('X')),
            }
        )

    def define_rotary_encoder_actions(self, module: BpodModule | None = None) -> None:
        """Register rotary-encoder reset and Bonsai stimulus-control actions.

        Parameters
        ----------
        module : BpodModule, optional
            Target module; defaults to ``self.rotary_encoder``.
        """
        if module is None:
            module = self.rotary_encoder
        module_port = f"Serial{module.serial_port if module is not None else ''}"
        self.actions.update(
            {
                'rotary_encoder_reset': (
                    module_port,
                    self._define_message(module, [RotaryEncoder.COM_SETZEROPOS, RotaryEncoder.COM_ENABLE_ALLTHRESHOLDS]),
                ),
                'bonsai_hide_stim': (module_port, self._define_message(module, [ord('#'), 1])),
                'bonsai_show_stim': (module_port, self._define_message(module, [ord('#'), 8])),
                'bonsai_closed_loop': (module_port, self._define_message(module, [ord('#'), 3])),
                'bonsai_freeze_stim': (module_port, self._define_message(module, [ord('#'), 4])),
                'bonsai_show_center': (module_port, self._define_message(module, [ord('#'), 5])),
            }
        )

    def get_ambient_sensor_reading(self):
        """Read temperature, air pressure and relative humidity from the ambient module.

        Returns
        -------
        dict
            Keys 'Temperature_C', 'AirPressure_mb' and 'RelativeHumidity'; all values are NaN
            when no ambient module is attached.
        """
        # resolve the property once: every access re-scans the module list
        ambient_module = self.ambient_module
        if ambient_module is None:
            # np.nan rather than np.NaN: the capitalized alias was removed in NumPy 2.0
            return {
                'Temperature_C': np.nan,
                'AirPressure_mb': np.nan,
                'RelativeHumidity': np.nan,
            }
        ambient_module.start_module_relay()
        try:
            self.bpod_modules.module_write(ambient_module, 'R')
            reply = self.bpod_modules.module_read(ambient_module, 12)
        finally:
            # always release the relay, even if the read/write fails
            ambient_module.stop_module_relay()

        # the 12-byte reply holds three float32 values
        return {
            'Temperature_C': np.frombuffer(bytes(reply[:4]), np.float32)[0],
            'AirPressure_mb': np.frombuffer(bytes(reply[4:8]), np.float32)[0] / 100,  # presumably Pa -> mbar
            'RelativeHumidity': np.frombuffer(bytes(reply[8:]), np.float32)[0],
        }

    def flush(self):
        """Open valve 1 until the user presses ENTER, then close it."""
        self.toggle_valve()

    def toggle_valve(self, duration=None):
        """Open valve 1, either until ENTER is pressed or for `duration` seconds.

        Parameters
        ----------
        duration : float, optional
            Opening time in seconds. If None, the valve stays open until the user presses ENTER.
        """
        if duration is None:
            self.open_valve(open=True, valve_number=1)
            input('Press ENTER when done.')
            self.open_valve(open=False, valve_number=1)
        else:
            self.pulse_valve(open_time_s=duration)

    def open_valve(self, open: bool, valve_number: int = 1):
        """Manually set the state of a valve output (True = open)."""
        self.manual_override(self.ChannelTypes.OUTPUT, self.ChannelNames.VALVE, valve_number, open)

    def pulse_valve(self, open_time_s: float, valve: str = 'Valve1'):
        """Open `valve` for `open_time_s` seconds using a single-state state machine."""
        sma = StateMachine(self)
        sma.add_state(
            state_name='flush', state_timer=open_time_s, state_change_conditions={'Tup': 'exit'}, output_actions=[(valve, 255)]
        )
        self.send_state_machine(sma)
        self.run_state_machine(sma)

    def pulse_valve_repeatedly(
        self, repetitions: int, open_time_s: float, close_time_s: float = 0.2, valve: str = 'Valve1'
    ) -> int:
        """Alternately open and close `valve` until a global timer of `repetitions` cycles expires.

        Returns
        -------
        int
            Number of valve openings, counted via a temporary softcode handler.
        """
        counter = 0

        def softcode_handler(_):
            nonlocal counter
            counter += 1

        original_softcode_handler = self.softcode_handler_function
        self.softcode_handler_function = softcode_handler
        try:
            sma = StateMachine(self)
            sma.set_global_timer(timer_id=1, timer_duration=(open_time_s + close_time_s) * repetitions)
            sma.add_state(state_name='start_timer', state_change_conditions={'Tup': 'open'}, output_actions=[('GlobalTimerTrig', 1)])
            sma.add_state(
                state_name='open',
                state_timer=open_time_s,
                state_change_conditions={'Tup': 'close'},
                output_actions=[(valve, 255), ('SoftCode', 1)],
            )
            sma.add_state(
                state_name='close', state_timer=close_time_s, state_change_conditions={'Tup': 'open', 'GlobalTimer1_End': 'exit'}
            )
            self.send_state_machine(sma)
            self.run_state_machine(sma)
        finally:
            # restore the caller's handler even if sending/running the state machine fails
            self.softcode_handler_function = original_softcode_handler
        return counter

    @static_vars(supported=True)
    def set_status_led(self, state: bool) -> bool:
        """Enable or disable the Bpod status LED.

        Returns
        -------
        bool
            True if the device acknowledged the command, False otherwise.
        """
        if self.can_control_led and self._arcom is not None:
            try:
                log.debug(f'{"en" if state else "dis"}abling Bpod Status LED')
                command = struct.pack('cB', b':', state)
                self._arcom.serial_object.write(command)
                if self._arcom.read_uint8() == 1:
                    return True
            except serial.SerialException:
                pass
            # no acknowledgment: discard any leftover bytes so later communication is not corrupted
            self._arcom.serial_object.reset_input_buffer()
            self._arcom.serial_object.reset_output_buffer()
            log.warning('Bpod device does not support control of the status LED. Please update firmware.')
        return False

    def valve(self, valve_id: int, state: bool):
        """Manually set the state of valve `valve_id`."""
        self.manual_override(self.ChannelTypes.OUTPUT, self.ChannelNames.VALVE, valve_id, state)

    @validate_call
    def register_softcodes(self, softcode_dict: dict[int, Callable]) -> None:
        """
        Register softcodes to be used in the state machine

        Parameters
        ----------
        softcode_dict : dict[int, Callable]
            dictionary of int keys with callables as values
        """
        self.softcodes = softcode_dict
        self.softcode_handler_function = lambda code: softcode_dict[code]()
316

317

318
class MyRotaryEncoder:
2✔
319
    def __init__(self, all_thresholds, gain, com, connect=False):
2✔
320
        self.RE_PORT = com
2✔
321
        self.WHEEL_PERIM = 31 * 2 * np.pi  # = 194,778744523
2✔
322
        self.deg_mm = 360 / self.WHEEL_PERIM
2✔
323
        self.mm_deg = self.WHEEL_PERIM / 360
2✔
324
        self.factor = 1 / (self.mm_deg * gain)
2✔
325
        self.SET_THRESHOLDS = [x * self.factor for x in all_thresholds]
2✔
326
        self.ENABLE_THRESHOLDS = [(x != 0) for x in self.SET_THRESHOLDS]
2✔
327
        # ENABLE_THRESHOLDS needs 8 bools even if only 2 thresholds are set
328
        while len(self.ENABLE_THRESHOLDS) < 8:
2✔
329
            self.ENABLE_THRESHOLDS.append(False)
2✔
330

331
        # Names of the RE events generated by Bpod
332
        self.ENCODER_EVENTS = [f'RotaryEncoder1_{x}' for x in list(range(1, len(all_thresholds) + 1))]
2✔
333
        # Dict mapping threshold crossings with name ov RE event
334
        self.THRESHOLD_EVENTS = dict(zip(all_thresholds, self.ENCODER_EVENTS, strict=False))
2✔
335
        if connect:
2✔
UNCOV
336
            self.connect()
×
337

338
    def connect(self):
2✔
UNCOV
339
        if self.RE_PORT == 'COM#':
×
UNCOV
340
            return
×
UNCOV
341
        m = RotaryEncoderModule(self.RE_PORT)
×
UNCOV
342
        m.set_zero_position()  # Not necessarily needed
×
UNCOV
343
        m.set_thresholds(self.SET_THRESHOLDS)
×
UNCOV
344
        m.enable_thresholds(self.ENABLE_THRESHOLDS)
×
UNCOV
345
        m.enable_evt_transmission()
×
UNCOV
346
        m.close()
×
347

348

349
def sound_device_factory(output: Literal['xonar', 'harp', 'hifi', 'sysdefault'] = 'sysdefault', samplerate: int | None = None):
2✔
350
    """
351
    Will import, configure, and return sounddevice module to play sounds using onboard sound card.
352
    Parameters
353
    ----------
354
    output
355
        defaults to "sysdefault", should be 'xonar' or 'harp'
356
    samplerate
357
        audio sample rate, defaults to 44100
358
    """
359
    match output:
2✔
360
        case 'xonar':
2✔
UNCOV
361
            samplerate = samplerate if samplerate is not None else 192000
×
UNCOV
362
            devices = sd.query_devices()
×
UNCOV
363
            sd.default.device = next((i for i, d in enumerate(devices) if 'XONAR SOUND CARD(64)' in d['name']), None)
×
UNCOV
364
            sd.default.latency = 'low'
×
UNCOV
365
            sd.default.channels = 2
×
UNCOV
366
            channels = 'L+TTL'
×
UNCOV
367
            sd.default.samplerate = samplerate
×
368
        case 'harp':
2✔
UNCOV
369
            samplerate = samplerate if samplerate is not None else 96000
×
UNCOV
370
            sd.default.samplerate = samplerate
×
UNCOV
371
            sd.default.channels = 2
×
UNCOV
372
            channels = 'stereo'
×
373
        case 'hifi':
2✔
UNCOV
374
            samplerate = samplerate if samplerate is not None else 192000
×
UNCOV
375
            channels = 'stereo'
×
376
        case 'sysdefault':
2✔
377
            samplerate = samplerate if samplerate is not None else 44100
2✔
378
            sd.default.latency = 'low'
2✔
379
            sd.default.channels = 2
2✔
380
            sd.default.samplerate = samplerate
2✔
381
            channels = 'stereo'
2✔
UNCOV
382
        case _:
×
UNCOV
383
            raise ValueError()
×
384
    return sd, samplerate, channels
2✔
385

386

387
def restart_com_port(regexp: str) -> bool:
    """
    Restart the communication port(s) matching the specified regular expression.

    Parameters
    ----------
    regexp : str
        A regular expression used to match the communication port(s), as accepted by
        `serial.tools.list_ports.grep`.

    Returns
    -------
    bool
        True if all matched ports are successfully restarted (also True when no port matches),
        False if pnputil reports a failure for any of them.

    Raises
    ------
    NotImplementedError
        If the operating system is not Windows.

    FileNotFoundError
        If the required 'pnputil.exe' executable is not found.

    Examples
    --------
    >>> restart_com_port("COM3")  # Restart the communication port matching 'COM3'
    True

    >>> restart_com_port("COM[1-3]")  # Restart communication ports matching 'COM1', 'COM2' or 'COM3'
    True
    """
    if os.name != 'nt':
        raise NotImplementedError('Only implemented for Windows OS.')
    # shutil.which returns None when the executable cannot be found
    if (pnputil_path := shutil.which('pnputil')) is None:
        raise FileNotFoundError('Could not find pnputil.exe')
    file_pnputil = Path(pnputil_path)

    # ports without a serial number cannot be matched to a device instance ID
    ports = [port for port in list_ports.grep(regexp) if port.serial_number]
    if not ports:
        return True

    # enumerating the connected ports is loop-invariant: run pnputil only once
    pnputil_output = subprocess.check_output([file_pnputil, '/enum-devices', '/connected', '/class', 'ports']).decode(
        errors='replace'
    )
    result = []
    for port in ports:
        instance_id = re.search(rf'(\S*{re.escape(port.serial_number)}\S*)', pnputil_output)
        if instance_id is None:
            continue
        # pass the ID without manual quotes: list arguments are quoted by subprocess as needed
        completed = subprocess.run(
            [file_pnputil, '/restart-device', instance_id.group(1)],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.STDOUT,
            check=False,
        )
        result.append(completed.returncode == 0)
    return all(result)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc