• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

int-brain-lab / iblrig / 10568073180

26 Aug 2024 10:13PM UTC coverage: 47.538% (+0.7%) from 46.79%
10568073180

Pull #711

github

eeff82
web-flow
Merge 599c9edfb into ad41db25f
Pull Request #711: 8.23.2

121 of 135 new or added lines in 8 files covered. (89.63%)

1025 existing lines in 22 files now uncovered.

4084 of 8591 relevant lines covered (47.54%)

0.95 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

52.74
/iblrig/hardware.py
1
"""Hardware classes used to interact with modules."""
2

3
import logging
2✔
4
import os
2✔
5
import re
2✔
6
import shutil
2✔
7
import struct
2✔
8
import subprocess
2✔
9
import threading
2✔
10
import time
2✔
11
from collections.abc import Callable
2✔
12
from enum import IntEnum
2✔
13
from pathlib import Path
2✔
14
from typing import Annotated, Literal
2✔
15

16
import numpy as np
2✔
17
import serial
2✔
18
import sounddevice as sd
2✔
19
from annotated_types import Ge, Le
2✔
20
from pydantic import PositiveFloat, PositiveInt, validate_call
2✔
21
from serial.tools import list_ports
2✔
22

23
from iblrig.tools import static_vars
2✔
24
from iblutil.util import Bunch
2✔
25
from pybpod_rotaryencoder_module.module import RotaryEncoder
2✔
26
from pybpod_rotaryencoder_module.module_api import RotaryEncoderModule
2✔
27
from pybpodapi.bpod.bpod_io import BpodIO
2✔
28
from pybpodapi.bpod_modules.bpod_module import BpodModule
2✔
29
from pybpodapi.state_machine import StateMachine
2✔
30

31
# Soft codes exchanged between the Bpod state machine and the host.
# Members are numbered from 1 in declaration order (STOP_SOUND=1 ... TRIGGER_CAMERA=4).
SOFTCODE = IntEnum('SOFTCODE', 'STOP_SOUND PLAY_TONE PLAY_NOISE TRIGGER_CAMERA')

# Constrained integer types, validated by pydantic's `validate_call`
Uint8 = Annotated[int, Ge(0), Le(0xFF)]  # unsigned 8-bit value: 0 to 255
ActionIdx = Annotated[int, Ge(1), Le(0xFF)]  # index of a serial message: 1 to 255

log = logging.getLogger(name=__name__)
38

39

40
class Bpod(BpodIO):
    """
    Bpod state-machine interface with IBL-specific helpers.

    Instances are cached per ``serial_port`` constructor argument. Constructing a ``Bpod`` for a
    port that already has a cached instance returns that same object (see ``__new__``).
    """

    can_control_led = True
    softcodes: dict[int, Callable] | None = None
    _instances = {}  # cached instances, keyed by the `serial_port` argument passed to the constructor
    _lock = threading.RLock()  # guards access to `_instances`
    _is_initialized = False

    def __new__(cls, *args, **kwargs):
        """Return the cached instance for the requested serial port, creating and caching it if needed."""
        serial_port = args[0] if len(args) > 0 else ''
        serial_port = kwargs.get('serial_port', serial_port)
        with cls._lock:
            instance = Bpod._instances.get(serial_port, None)
            # explicit None check: do not depend on the truthiness of the instance
            if instance is not None:
                return instance
            instance = super().__new__(cls)
            Bpod._instances[serial_port] = instance
            return instance

    def __init__(self, *args, skip_initialization: bool = False, **kwargs):
        """
        Connect to the Bpod, retrying once on failure.

        Parameters
        ----------
        *args, **kwargs
            Forwarded to ``BpodIO.__init__``.
        skip_initialization : bool, optional
            If True and this (cached) instance has already been initialized, return immediately.
            IMPORTANT: only use this for non-critical tasks (e.g., flushing valve from GUI).

        Raises
        ------
        serial.serialutil.SerialException
            If the retry fails with a SerialException or UnicodeDecodeError.
        """
        # skip initialization if it has already been performed before
        if skip_initialization and self._is_initialized:
            return

        # try to instantiate once for nothing
        try:
            super().__init__(*args, **kwargs)
        except Exception:
            log.warning("Couldn't instantiate BPOD, retrying once...")
            time.sleep(1)
            try:
                super().__init__(*args, **kwargs)
            except (serial.serialutil.SerialException, UnicodeDecodeError) as e:
                log.error(e)
                raise serial.serialutil.SerialException(
                    'The communication with Bpod is established but the Bpod is not responsive. '
                    'This is usually indicated by the device with a green light. '
                    'Please unplug the Bpod USB cable from the computer and plug it back in to start the task. '
                ) from e
        self.serial_messages = {}
        self.actions = Bunch({})
        self.can_control_led = self.set_status_led(True)
        self._is_initialized = True

    def close(self) -> None:
        """Close the connection and mark the instance as not initialized."""
        super().close()
        self._is_initialized = False

    def __del__(self):
        # `serial_port` may be missing if __init__ failed before BpodIO set it
        serial_port = getattr(self, 'serial_port', None)
        with self._lock:
            # only evict the cache entry if it actually refers to this instance
            if Bpod._instances.get(serial_port) is self:
                Bpod._instances.pop(serial_port)

    @property
    def is_connected(self):
        """bool: True if the Bpod's module list is available."""
        return self.modules is not None

    @property
    def rotary_encoder(self):
        """The first module whose name starts with 'RotaryEncoder', or None."""
        return self.get_module('rotary_encoder')

    @property
    def sound_card(self):
        """The first module whose name starts with 'SoundCard', or None."""
        return self.get_module('sound_card')

    @property
    def ambient_module(self):
        """The first module whose name starts with 'AmbientModule', or None."""
        return self.get_module('^AmbientModule')

    def get_module(self, module_name: str) -> BpodModule | None:
        """Get module by name.

        Parameters
        ----------
        module_name : str
            Regular Expression for matching a module name. The aliases 're'/'rotary_encoder' and
            'sc'/'sound_card' are translated to '^RotaryEncoder' and '^SoundCard' respectively.

        Returns
        -------
        BpodModule | None
            First matching module or None (also None if the module list is unavailable)
        """
        if self.modules is None:
            return None
        if module_name in ('re', 'rotary_encoder'):
            module_name = r'^RotaryEncoder'
        elif module_name in ('sc', 'sound_card'):
            module_name = r'^SoundCard'
        modules = [x for x in self.modules if re.match(module_name, x.name)]
        if len(modules) > 1:
            log.critical(f'Found several Bpod modules matching `{module_name}`. Using first match: `{modules[0].name}`')
        if len(modules) > 0:
            return modules[0]
        return None

    @validate_call(config={'arbitrary_types_allowed': True})
    def _define_message(self, module: BpodModule | int, message: list[Uint8]) -> ActionIdx:
        """Define a serial message to be sent to a Bpod module as an output action within a state.

        Parameters
        ----------
        module : BpodModule | int
            The targeted module, defined as a BpodModule instance or the module's port index
        message : list[int]
            The message to be sent - a list of up to three 8-bit integers

        Returns
        -------
        int
            The index of the serial message (1-255); indices are assigned consecutively per instance

        Raises
        ------
        TypeError
            If module is not an instance of BpodModule or int

        Examples
        --------
        >>> id_msg_bonsai_show_stim = self._define_message(self.rotary_encoder, [ord("#"), 2])
        will then be used as such in StateMachine:
        >>> output_actions=[("Serial1", id_msg_bonsai_show_stim)]
        """
        if isinstance(module, BpodModule):
            module = module.serial_port
        message_id = len(self.serial_messages) + 1
        self.load_serial_message(module, message_id, message)
        self.serial_messages.update({message_id: {'target_module': module, 'message': message}})
        return message_id

    @validate_call(config={'arbitrary_types_allowed': True})
    def define_xonar_sounds_actions(self):
        """Register sound actions that are handled on the host via soft codes."""
        self.actions.update(
            {
                'play_tone': ('SoftCode', SOFTCODE.PLAY_TONE),
                'play_noise': ('SoftCode', SOFTCODE.PLAY_NOISE),
                'stop_sound': ('SoftCode', SOFTCODE.STOP_SOUND),
            }
        )

    def define_harp_sounds_actions(self, module: BpodModule, go_tone_index: int = 2, noise_index: int = 3) -> None:
        """
        Register sound actions that are sent as serial messages to a Harp sound card module.

        Parameters
        ----------
        module : BpodModule
            The sound card module.
        go_tone_index : int, optional
            Sound index played by the 'play_tone' action, defaults to 2.
        noise_index : int, optional
            Sound index played by the 'play_noise' action, defaults to 3.
        """
        module_port = f"Serial{module.serial_port if module is not None else ''}"
        self.actions.update(
            {
                'play_tone': (module_port, self._define_message(module, [ord('P'), go_tone_index])),
                'play_noise': (module_port, self._define_message(module, [ord('P'), noise_index])),
                'stop_sound': (module_port, ord('X')),
            }
        )

    def define_rotary_encoder_actions(self, module: BpodModule | None = None) -> None:
        """
        Register rotary-encoder reset and Bonsai stimulus-control actions as serial messages.

        Parameters
        ----------
        module : BpodModule, optional
            The rotary encoder module. Defaults to ``self.rotary_encoder``.
        """
        if module is None:
            module = self.rotary_encoder
        module_port = f"Serial{module.serial_port if module is not None else ''}"
        self.actions.update(
            {
                'rotary_encoder_reset': (
                    module_port,
                    self._define_message(module, [RotaryEncoder.COM_SETZEROPOS, RotaryEncoder.COM_ENABLE_ALLTHRESHOLDS]),
                ),
                'bonsai_hide_stim': (module_port, self._define_message(module, [ord('#'), 1])),
                'bonsai_show_stim': (module_port, self._define_message(module, [ord('#'), 8])),
                'bonsai_closed_loop': (module_port, self._define_message(module, [ord('#'), 3])),
                'bonsai_freeze_stim': (module_port, self._define_message(module, [ord('#'), 4])),
                'bonsai_show_center': (module_port, self._define_message(module, [ord('#'), 5])),
            }
        )

    def get_ambient_sensor_reading(self):
        """
        Read temperature, air pressure and humidity from the ambient module.

        Returns
        -------
        dict
            Keys 'Temperature_C', 'AirPressure_mb' and 'RelativeHumidity'. All values are NaN if no
            ambient module is connected. Otherwise the 12-byte reply is decoded as three float32
            values, and the pressure is divided by 100.
        """
        # evaluate the property once: each access performs a regex scan over all modules
        ambient_module = self.ambient_module
        if ambient_module is None:
            # np.NaN was removed in NumPy 2.0; np.nan is the canonical spelling
            return {
                'Temperature_C': np.nan,
                'AirPressure_mb': np.nan,
                'RelativeHumidity': np.nan,
            }
        ambient_module.start_module_relay()
        try:
            self.bpod_modules.module_write(ambient_module, 'R')
            reply = self.bpod_modules.module_read(ambient_module, 12)
        finally:
            # always stop relaying, even if the read/write fails
            ambient_module.stop_module_relay()

        return {
            'Temperature_C': np.frombuffer(bytes(reply[:4]), np.float32)[0],
            'AirPressure_mb': np.frombuffer(bytes(reply[4:8]), np.float32)[0] / 100,
            'RelativeHumidity': np.frombuffer(bytes(reply[8:]), np.float32)[0],
        }

    def flush(self):
        """Flushes valve 1."""
        self.toggle_valve()

    def toggle_valve(self, duration: float | None = None):
        """
        Flush valve 1 for specified duration.

        Parameters
        ----------
        duration : float, optional
            Duration of valve opening in seconds. If None, the valve stays open until ENTER is pressed.
        """
        if duration is None:
            self.open_valve(open=True, valve_number=1)
            try:
                input('Press ENTER when done.')
            finally:
                # close the valve even if the prompt is interrupted (e.g. Ctrl+C / EOF)
                self.open_valve(open=False, valve_number=1)
        else:
            self.pulse_valve(open_time_s=duration)

    def open_valve(self, open: bool, valve_number: int = 1):
        """Open (``open=True``) or close (``open=False``) the given valve via manual override."""
        self.manual_override(self.ChannelTypes.OUTPUT, self.ChannelNames.VALVE, valve_number, open)

    def pulse_valve(self, open_time_s: float, valve: str = 'Valve1'):
        """Open `valve` for `open_time_s` seconds by running a single-state state machine."""
        sma = StateMachine(self)
        sma.add_state(
            state_name='flush', state_timer=open_time_s, state_change_conditions={'Tup': 'exit'}, output_actions=[(valve, 255)]
        )
        self.send_state_machine(sma)
        self.run_state_machine(sma)

    @validate_call()
    def pulse_valve_repeatedly(
        self, repetitions: PositiveInt, open_time_s: PositiveFloat, close_time_s: PositiveFloat = 0.2, valve: str = 'Valve1'
    ) -> int:
        """
        Repeatedly open and close a valve.

        Parameters
        ----------
        repetitions : int
            Number of valve openings after which the trial is stopped.
        open_time_s : float
            Duration of each opening in seconds.
        close_time_s : float, optional
            Duration of each closed period in seconds, defaults to 0.2.
        valve : str, optional
            Name of the valve output channel, defaults to 'Valve1'.

        Returns
        -------
        int
            Number of valve openings counted (soft code 1 is emitted on every entry into the 'open' state).
        """
        counter = 0

        def softcode_handler(softcode: int):
            nonlocal counter
            if softcode == 1:
                counter += 1
            elif softcode == 2 and counter >= repetitions:
                self.stop_trial()

        original_softcode_handler = self.softcode_handler_function
        self.softcode_handler_function = softcode_handler
        try:
            sma = StateMachine(self)
            sma.add_state(
                state_name='open',
                state_timer=open_time_s,
                state_change_conditions={'Tup': 'close'},
                output_actions=[(valve, 255), ('SoftCode', 1)],
            )
            sma.add_state(
                state_name='close',
                state_timer=close_time_s,
                state_change_conditions={'Tup': 'open'},
                output_actions=[('SoftCode', 2)],
            )
            self.send_state_machine(sma)
            self.run_state_machine(sma)
        finally:
            # restore the caller's handler even if the state machine fails
            self.softcode_handler_function = original_softcode_handler
        return counter

    @static_vars(supported=True)
    def set_status_led(self, state: bool) -> bool:
        """
        Enable or disable the Bpod status LED.

        Parameters
        ----------
        state : bool
            True to enable the LED, False to disable it.

        Returns
        -------
        bool
            True if the device acknowledged the command (replied 1), False otherwise.
        """
        if self.can_control_led and self._arcom is not None:
            try:
                log.debug(f'{"en" if state else "dis"}abling Bpod Status LED')
                command = struct.pack('cB', b':', state)
                self._arcom.serial_object.write(command)
                if self._arcom.read_uint8() == 1:
                    return True
            except serial.SerialException:
                # treated like a missing acknowledgement: handled below
                pass
            # discard any partial/unexpected reply before continuing
            self._arcom.serial_object.reset_input_buffer()
            self._arcom.serial_object.reset_output_buffer()
            log.warning('Bpod device does not support control of the status LED. Please update firmware.')
        return False

    def valve(self, valve_id: int, state: bool):
        """Set valve `valve_id` open (True) or closed (False) via manual override."""
        self.manual_override(self.ChannelTypes.OUTPUT, self.ChannelNames.VALVE, valve_id, state)

    @validate_call
    def register_softcodes(self, softcode_dict: dict[int, Callable]) -> None:
        """
        Register softcodes to be used in the state machine.

        Parameters
        ----------
        softcode_dict : dict[int, Callable]
            dictionary of int keys with callables as values; an unregistered soft code raises KeyError
            when received
        """
        self.softcodes = softcode_dict
        self.softcode_handler_function = lambda code: softcode_dict[code]()
321

322

323
class MyRotaryEncoder:
2✔
324
    def __init__(self, all_thresholds, gain, com, connect=False):
2✔
325
        self.RE_PORT = com
2✔
326
        self.WHEEL_PERIM = 31 * 2 * np.pi  # = 194,778744523
2✔
327
        self.deg_mm = 360 / self.WHEEL_PERIM
2✔
328
        self.mm_deg = self.WHEEL_PERIM / 360
2✔
329
        self.factor = 1 / (self.mm_deg * gain)
2✔
330
        self.SET_THRESHOLDS = [x * self.factor for x in all_thresholds]
2✔
331
        self.ENABLE_THRESHOLDS = [(x != 0) for x in self.SET_THRESHOLDS]
2✔
332
        # ENABLE_THRESHOLDS needs 8 bools even if only 2 thresholds are set
333
        while len(self.ENABLE_THRESHOLDS) < 8:
2✔
334
            self.ENABLE_THRESHOLDS.append(False)
2✔
335

336
        # Names of the RE events generated by Bpod
337
        self.ENCODER_EVENTS = [f'RotaryEncoder1_{x}' for x in list(range(1, len(all_thresholds) + 1))]
2✔
338
        # Dict mapping threshold crossings with name ov RE event
339
        self.THRESHOLD_EVENTS = dict(zip(all_thresholds, self.ENCODER_EVENTS, strict=False))
2✔
340
        if connect:
2✔
UNCOV
341
            self.connect()
×
342

343
    def connect(self):
2✔
UNCOV
344
        if self.RE_PORT == 'COM#':
×
UNCOV
345
            return
×
UNCOV
346
        m = RotaryEncoderModule(self.RE_PORT)
×
UNCOV
347
        m.set_zero_position()  # Not necessarily needed
×
UNCOV
348
        m.set_thresholds(self.SET_THRESHOLDS)
×
UNCOV
349
        m.enable_thresholds(self.ENABLE_THRESHOLDS)
×
UNCOV
350
        m.enable_evt_transmission()
×
UNCOV
351
        m.close()
×
352

353

354
def sound_device_factory(output: Literal['xonar', 'harp', 'hifi', 'sysdefault'] = 'sysdefault', samplerate: int | None = None):
2✔
355
    """
356
    Will import, configure, and return sounddevice module to play sounds using onboard sound card.
357

358
    Parameters
359
    ----------
360
    output
361
        defaults to "sysdefault", should be 'xonar' or 'harp'
362
    samplerate
363
        audio sample rate, defaults to 44100
364
    """
365
    match output:
2✔
366
        case 'xonar':
2✔
UNCOV
367
            samplerate = samplerate if samplerate is not None else 192000
×
UNCOV
368
            devices = sd.query_devices()
×
UNCOV
369
            sd.default.device = next((i for i, d in enumerate(devices) if 'XONAR SOUND CARD(64)' in d['name']), None)
×
UNCOV
370
            sd.default.latency = 'low'
×
UNCOV
371
            sd.default.channels = 2
×
UNCOV
372
            channels = 'L+TTL'
×
UNCOV
373
            sd.default.samplerate = samplerate
×
374
        case 'harp':
2✔
UNCOV
375
            samplerate = samplerate if samplerate is not None else 96000
×
UNCOV
376
            sd.default.samplerate = samplerate
×
UNCOV
377
            sd.default.channels = 2
×
UNCOV
378
            channels = 'stereo'
×
379
        case 'hifi':
2✔
UNCOV
380
            samplerate = samplerate if samplerate is not None else 192000
×
UNCOV
381
            channels = 'stereo'
×
382
        case 'sysdefault':
2✔
383
            samplerate = samplerate if samplerate is not None else 44100
2✔
384
            sd.default.latency = 'low'
2✔
385
            sd.default.channels = 2
2✔
386
            sd.default.samplerate = samplerate
2✔
387
            channels = 'stereo'
2✔
UNCOV
388
        case _:
×
UNCOV
389
            raise ValueError()
×
390
    return sd, samplerate, channels
2✔
391

392

393
def restart_com_port(regexp: str) -> bool:
    """
    Restart the communication port(s) matching the specified regular expression.

    Every serial port returned by ``serial.tools.list_ports.grep(regexp)`` is looked up by its serial
    number in the output of ``pnputil /enum-devices``. The matching device instance is then restarted
    with ``pnputil /restart-device``.

    Parameters
    ----------
    regexp : str
        A regular expression used to match the communication port(s).

    Returns
    -------
    bool
        False if restarting any identified port returned a non-zero exit code, True otherwise. This
        includes the case where no port matched. Ports without a serial number, or without a matching
        pnputil device entry, are skipped with a warning.

    Raises
    ------
    NotImplementedError
        If the operating system is not Windows.

    FileNotFoundError
        If the required 'pnputil.exe' executable is not found.

    Examples
    --------
    >>> restart_com_port("COM3")  # Restart the communication port named 'COM3'
    True

    >>> restart_com_port("COM[1-3]")  # Restart communication ports 'COM1', 'COM2' and 'COM3'
    True
    """
    if os.name != 'nt':
        raise NotImplementedError('Only implemented for Windows OS.')
    # shutil.which returns None when the executable is missing - check before wrapping it in a Path
    pnputil = shutil.which('pnputil')
    if pnputil is None or not (file_pnputil := Path(pnputil)).exists():
        raise FileNotFoundError('Could not find pnputil.exe')

    ports = list(list_ports.grep(regexp))
    if not ports:
        return True

    # the device enumeration does not depend on the port, so run it only once
    pnputil_output = subprocess.check_output([file_pnputil, '/enum-devices', '/connected', '/class', 'ports']).decode(
        errors='replace'
    )

    result = []
    for port in ports:
        if not port.serial_number:
            log.warning(f'Cannot restart {port.device}: no serial number available')
            continue
        # escape the serial number so it is matched literally
        instance_id = re.search(rf'(\S*{re.escape(port.serial_number)}\S*)', pnputil_output)
        if instance_id is None:
            log.warning(f'Cannot restart {port.device}: no matching device found by pnputil')
            continue
        # pass the matched instance ID itself (not the bound `group` method); subprocess handles quoting
        completed = subprocess.run(
            [file_pnputil, '/restart-device', instance_id.group(1)],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.STDOUT,
            check=False,
        )
        result.append(completed.returncode == 0)
    return all(result)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc