int-brain-lab / iblrig / 15073834064

16 May 2025 05:16PM UTC coverage: 49.414% (+2.6%) from 46.79%

Pull Request #750: Online plots
Merge 8e475a77c into e481532ae (commit c98309 by web-flow, via github)

538 of 720 new or added lines in 3 files covered. (74.72%)
1000 existing lines in 20 files now uncovered.
4677 of 9465 relevant lines covered (49.41%)
0.49 hits per line

Source File: /iblrig/base_choice_world.py (coverage: 90.5%)
"""Extends the base_tasks modules by providing task logic around the Choice World protocol."""

import abc
import logging
import math
import random
import subprocess
import time
from pathlib import Path
from re import split as re_split
from string import ascii_letters
from typing import Annotated, Any

import numpy as np
import pandas as pd
from annotated_types import Interval, IsNan
from pydantic import NonNegativeFloat, NonNegativeInt

import iblrig.base_tasks
from iblrig import choiceworld, misc
from iblrig.hardware import DTYPE_AMBIENT_SENSOR_BIN, SOFTCODE
from iblrig.pydantic_definitions import TrialDataModel
from iblutil.io import binary, jsonable
from iblutil.util import Bunch
from pybpodapi.com.messaging.trial import Trial
from pybpodapi.protocol import StateMachine

log = logging.getLogger(__name__)

NTRIALS_INIT = 2000
NBLOCKS_INIT = 100

# TODO: task parameters should be verified through a pydantic model
#
# Probability = Annotated[float, Field(ge=0.0, le=1.0)]
#
# class ChoiceWorldParams(BaseModel):
#     AUTOMATIC_CALIBRATION: bool = True
#     ADAPTIVE_REWARD: bool = False
#     BONSAI_EDITOR: bool = False
#     CALIBRATION_VALUE: float = 0.067
#     CONTRAST_SET: list[Probability] = Field([1.0, 0.25, 0.125, 0.0625, 0.0], min_length=1)
#     CONTRAST_SET_PROBABILITY_TYPE: Literal['uniform', 'skew_zero'] = 'uniform'
#     GO_TONE_AMPLITUDE: float = 0.0272
#     GO_TONE_DURATION: float = 0.11
#     GO_TONE_IDX: int = Field(2, ge=0)
#     GO_TONE_FREQUENCY: float = Field(5000, gt=0)
#     FEEDBACK_CORRECT_DELAY_SECS: float = 1
#     FEEDBACK_ERROR_DELAY_SECS: float = 2
#     FEEDBACK_NOGO_DELAY_SECS: float = 2
#     INTERACTIVE_DELAY: float = 0.0
#     ITI_DELAY_SECS: float = 0.5
#     NTRIALS: int = Field(2000, gt=0)
#     PROBABILITY_LEFT: Probability = 0.5
#     QUIESCENCE_THRESHOLDS: list[float] = Field(default=[-2, 2], min_length=2, max_length=2)
#     QUIESCENT_PERIOD: float = 0.2
#     RECORD_AMBIENT_SENSOR_DATA: bool = True
#     RECORD_SOUND: bool = True
#     RESPONSE_WINDOW: float = 60
#     REWARD_AMOUNT_UL: float = 1.5
#     REWARD_TYPE: str = 'Water 10% Sucrose'
#     STIM_ANGLE: float = 0.0
#     STIM_FREQ: float = 0.1
#     STIM_GAIN: float = 4.0  # wheel to stimulus relationship (degrees visual angle per mm of wheel displacement)
#     STIM_POSITIONS: list[float] = [-35, 35]
#     STIM_SIGMA: float = 7.0
#     STIM_TRANSLATION_Z: Literal[7, 8] = 7  # 7 for ephys, 8 otherwise. -p:Stim.TranslationZ-{STIM_TRANSLATION_Z} bonsai param
#     STIM_REVERSE: bool = False
#     SYNC_SQUARE_X: float = 1.33
#     SYNC_SQUARE_Y: float = -1.03
#     USE_AUTOMATIC_STOPPING_CRITERIONS: bool = True
#     VISUAL_STIMULUS: str = 'GaborIBLTask / Gabor2D.bonsai'  # null / passiveChoiceWorld_passive.bonsai
#     WHITE_NOISE_AMPLITUDE: float = 0.05
#     WHITE_NOISE_DURATION: float = 0.5
#     WHITE_NOISE_IDX: int = 3
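

# A minimal sketch of the validation the TODO above calls for, assuming pydantic v2
# semantics. `_DemoParams` is a hypothetical, trimmed stand-in for the commented-out
# ChoiceWorldParams model; it is illustrative only and not used by the task logic.
def _demo_params_validation():
    from pydantic import BaseModel, Field, ValidationError

    class _DemoParams(BaseModel):
        PROBABILITY_LEFT: float = Field(0.5, ge=0.0, le=1.0)
        NTRIALS: int = Field(2000, gt=0)

    _DemoParams(PROBABILITY_LEFT=0.8)  # valid: the remaining fields take their defaults
    try:
        _DemoParams(PROBABILITY_LEFT=1.5)  # outside [0.0, 1.0]: raises a ValidationError
    except ValidationError as e:
        log.debug(e)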


class ChoiceWorldTrialData(TrialDataModel):
    """Pydantic Model for Trial Data."""

    contrast: Annotated[float, Interval(ge=0.0, le=1.0)]
    stim_probability_left: Annotated[float, Interval(ge=0.0, le=1.0)]
    position: float
    quiescent_period: NonNegativeFloat
    reward_amount: NonNegativeFloat
    reward_valve_time: NonNegativeFloat
    stim_angle: Annotated[float, Interval(ge=-180.0, le=180.0)]
    stim_freq: NonNegativeFloat
    stim_gain: float
    stim_phase: Annotated[float, Interval(ge=0.0, le=2 * math.pi)]
    stim_reverse: bool
    stim_sigma: float
    trial_num: NonNegativeInt
    pause_duration: NonNegativeFloat = 0.0

    # The following variables are only used in ActiveChoiceWorld.
    # We keep them here with fixed default values for the sake of compatibility.
    #
    # TODO: Yes, this should probably be done differently.
    response_side: Annotated[int, Interval(ge=0, le=0)] = 0
    response_time: IsNan[float] = np.nan
    trial_correct: Annotated[bool, Interval(ge=0, le=0)] = False
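

# A minimal usage sketch of the model above (illustrative only, assuming pydantic-style
# validation in TrialDataModel): required fields must be supplied, while the
# ActiveChoiceWorld placeholders fall back to their fixed defaults.
def _demo_trial_data():
    trial = ChoiceWorldTrialData(
        contrast=0.25,
        stim_probability_left=0.5,
        position=-35.0,
        quiescent_period=0.55,
        reward_amount=1.5,
        reward_valve_time=0.05,
        stim_angle=0.0,
        stim_freq=0.1,
        stim_gain=4.0,
        stim_phase=1.0,
        stim_reverse=False,
        stim_sigma=7.0,
        trial_num=0,
    )
    assert trial.response_side == 0 and trial.trial_correct is False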


class ChoiceWorldSession(
    iblrig.base_tasks.BonsaiRecordingMixin,
    iblrig.base_tasks.BonsaiVisualStimulusMixin,
    iblrig.base_tasks.BpodMixin,
    iblrig.base_tasks.Frame2TTLMixin,
    iblrig.base_tasks.RotaryEncoderMixin,
    iblrig.base_tasks.SoundMixin,
    iblrig.base_tasks.ValveMixin,
    iblrig.base_tasks.NetworkSession,
):
    # task_params = ChoiceWorldParams()
    base_parameters_file = Path(__file__).parent.joinpath('base_choice_world_params.yaml')
    TrialDataModel = ChoiceWorldTrialData

    def __init__(self, *args, delay_mins: float = 0, **kwargs):
        super().__init__(**kwargs)

        # session delay is handled in seconds internally
        self.task_params['SESSION_DELAY_START'] = delay_mins * 60.0

        # init behaviour data
        self.movement_left = self.device_rotary_encoder.THRESHOLD_EVENTS[self.task_params.QUIESCENCE_THRESHOLDS[0]]
        self.movement_right = self.device_rotary_encoder.THRESHOLD_EVENTS[self.task_params.QUIESCENCE_THRESHOLDS[1]]

        # init counter variables
        self.trial_num = -1
        self.block_num = -1
        self.block_trial_num = -1

        # init the tables; there are two of them: a trials table and an ambient sensor data table
        self.trials_table = self.TrialDataModel.preallocate_dataframe(NTRIALS_INIT)
        self.ambient_sensor_table = pd.DataFrame(
            np.nan, index=range(NTRIALS_INIT), columns=['Temperature_C', 'AirPressure_mb', 'RelativeHumidity'], dtype=np.float32
        )
        self.ambient_sensor_table.rename_axis('Trial', inplace=True)

    @staticmethod
    def extra_parser():
        """:return: argparse.parser()"""
        parser = super(ChoiceWorldSession, ChoiceWorldSession).extra_parser()
        parser.add_argument(
            '--delay_mins',
            dest='delay_mins',
            default=0,
            type=float,
            required=False,
            help='initial delay before starting the first trial (default: 0 min)',
        )
        return parser
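
    # Usage note: a task launched with `--delay_mins 5` (hypothetical invocation via
    # the extra parser above) reaches __init__ as delay_mins=5, which is stored as
    # SESSION_DELAY_START = 300 seconds and consumed by the first trial's state machine.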

    def start_hardware(self):
        """
        In this step we explicitly run the start methods of the various mixins.
        The super class start method is overloaded because we need to start the different hardware pieces in order.
        """
        if not self.is_mock:
            self.start_mixin_frame2ttl()
            self.start_mixin_bpod()
            self.start_mixin_valve()
            self.start_mixin_sound()
            self.start_mixin_rotary_encoder()
            self.start_mixin_bonsai_cameras()
            self.start_mixin_bonsai_microphone()
            self.start_mixin_bonsai_visual_stimulus()
            self.bpod.register_softcodes(self.softcode_dictionary())

    def _run(self):
        """Run the task with the actual state machine."""
        time_last_trial_end = time.time()
        for i in range(self.task_params.NTRIALS):  # Main loop
            # t_overhead = time.time()
            self.next_trial()
            log.info(f'Starting trial: {i}')
            # =============================================================================
            #     Start state machine definition
            # =============================================================================
            sma = self.get_state_machine_trial(i)
            log.debug('Sending state machine to bpod')
            # Send state machine description to Bpod device
            self.bpod.send_state_machine(sma)
            # t_overhead = time.time() - t_overhead
            # The ITI_DELAY_SECS param defines the grey screen period within the state machine, where the
            # Bpod TTL is HIGH. The DEAD_TIME param defines the time between the last trial and the next one.
            dead_time = self.task_params.get('DEAD_TIME', 0.5)
            dt = self.task_params.ITI_DELAY_SECS - dead_time - (time.time() - time_last_trial_end)
            # wait to achieve the desired ITI duration
191
            if dt > 0:
1✔
UNCOV
192
                time.sleep(dt)
×
193
            # Run state machine
194
            log.debug('running state machine')
1✔
195
            self.bpod.run_state_machine(sma)  # Locks until state machine 'exit' is reached
1✔
196
            time_last_trial_end = time.time()
1✔
197
            # handle pause event
198
            flag_pause = self.paths.SESSION_FOLDER.joinpath('.pause')
1✔
199
            flag_stop = self.paths.SESSION_FOLDER.joinpath('.stop')
1✔
200
            if flag_pause.exists() and i < (self.task_params.NTRIALS - 1):
1✔
201
                log.info(f'Pausing session inbetween trials {i} and {i + 1}')
1✔
202
                while flag_pause.exists() and not flag_stop.exists():
1✔
203
                    time.sleep(1)
1✔
204
                self.trials_table.at[self.trial_num, 'pause_duration'] = time.time() - time_last_trial_end
1✔
205
                if not flag_stop.exists():
1✔
206
                    log.info('Resuming session')
1✔
207

208
            # save trial and update log
209
            self.trial_completed(self.bpod.session.current_trial.export())
1✔
210
            self.show_trial_log()
1✔
211

212
            # handle stop event
213
            if flag_stop.exists():
1✔
214
                log.info('Stopping session after trial %d', i)
1✔
215
                flag_stop.unlink()
1✔
216
                break
1✔
217

218
    def mock(self, file_jsonable_fixture=None):
1✔
219
        """
220
        Instantiate a state machine and Bpod object to simulate a task's run.
221

222
        This is useful to test or display the state machine flow.
223
        """
224
        super().mock()
1✔
225

226
        if file_jsonable_fixture is not None:
1✔
227
            task_data = jsonable.read(file_jsonable_fixture)
1✔
228
            # pop-out the bpod data from the table
229
            bpod_data = []
1✔
230
            for td in task_data:
1✔
231
                bpod_data.append(td.pop('behavior_data'))
1✔
232

233
            class MockTrial(Trial):
1✔
234
                def export(self):
1✔
235
                    return np.random.choice(bpod_data)
1✔
236
        else:
237

238
            class MockTrial(Trial):
1✔
239
                def export(self):
1✔
UNCOV
240
                    return {}
×
241

242
        self.bpod.session.trials = [MockTrial()]
1✔
243
        self.bpod.send_state_machine = lambda k: None
1✔
244
        self.bpod.run_state_machine = lambda k: time.sleep(1.2)
1✔
245

246
        daction = ('dummy', 'action')
1✔
247
        self.sound = Bunch({'GO_TONE': daction, 'WHITE_NOISE': daction})
1✔
248

249
        self.bpod.actions.update(
1✔
250
            {
251
                'play_tone': daction,
252
                'play_noise': daction,
253
                'stop_sound': daction,
254
                'rotary_encoder_reset': daction,
255
                'bonsai_hide_stim': daction,
256
                'bonsai_show_stim': daction,
257
                'bonsai_closed_loop': daction,
258
                'bonsai_freeze_stim': daction,
259
                'bonsai_show_center': daction,
260
            }
261
        )
262
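
    # Usage sketch for simulating a run without hardware (constructor arguments are
    # hypothetical and depend on the concrete subclass):
    #
    #     session = HabituationChoiceWorldSession(subject='example_subject')
    #     session.mock(file_jsonable_fixture=None)
    #     session.run()
    #
    # Passing a jsonable fixture of previously recorded trials makes MockTrial.export
    # replay randomly drawn Bpod trial data instead of returning empty dictionaries.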

    def get_graphviz_task(self, output_file=None, view=True):
        """
        Get the state machine's states diagram in Digraph format.

        :param output_file: optional file name for rendering the diagram to disk
        :param view: whether to open the rendered diagram (default: True)
        :return: graphviz.Digraph instance, or None if no state machine is defined
        """
        import graphviz

        self.next_trial()
        sma = self.get_state_machine_trial(0)
        if sma is None:
            return
        states_indices = {i: k for i, k in enumerate(sma.state_names)}
        states_indices.update({(i + 10000): k for i, k in enumerate(sma.undeclared)})
        states_letters = {k: ascii_letters[i] for i, k in enumerate(sma.state_names)}
        dot = graphviz.Digraph(comment='The Great IBL Task')
        edges = []

        for i in range(len(sma.state_names)):
            letter = states_letters[sma.state_names[i]]
            dot.node(letter, sma.state_names[i])
            if ~np.isnan(sma.state_timer_matrix[i]):
                out_state = states_indices[sma.state_timer_matrix[i]]
                edges.append(f'{letter}{states_letters[out_state]}')
            for input in sma.input_matrix[i]:
                if input[0] == 0:
                    edges.append(f'{letter}{states_letters[states_indices[input[1]]]}')
        dot.edges(edges)
        if output_file is not None:
            try:
                dot.render(output_file, view=view)
            except graphviz.exceptions.ExecutableNotFound:
                log.info('Graphviz system executable not found, cannot render the graph')
        return dot
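
    # Usage sketch: render the task's state diagram to disk without opening a viewer
    # (`session` is a hypothetical instance of a concrete subclass):
    #
    #     dot = session.get_graphviz_task(output_file='state_machine', view=False)
    #
    # Rendering requires the graphviz system executable; if it is missing, the method
    # logs an info message and still returns the Digraph object.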

    def _instantiate_state_machine(self, *args, **kwargs):
        return StateMachine(self.bpod)

    def get_state_machine_trial(self, i):
        # we define the trial number here for subclasses that may need it
        sma = self._instantiate_state_machine(trial_number=i)

        if i == 0:  # First trial exception: start the camera
            session_delay_start = self.task_params.get('SESSION_DELAY_START', 0)
            log.info('First trial initializing, will move to next trial only if:')
            log.info('1. camera is detected')
            log.info(f'2. {session_delay_start} sec have elapsed')
            sma.add_state(
                state_name='trial_start',
                state_timer=0,
                state_change_conditions={'Port1In': 'delay_initiation'},
                output_actions=[('SoftCode', SOFTCODE.TRIGGER_CAMERA), ('BNC1', 255)],
            )  # start camera
            sma.add_state(
                state_name='delay_initiation',
                state_timer=session_delay_start,
                output_actions=[],
                state_change_conditions={'Tup': 'reset_rotary_encoder'},
            )
        else:
            sma.add_state(
                state_name='trial_start',
                state_timer=0,  # ~100µs hardware irreducible delay
                state_change_conditions={'Tup': 'reset_rotary_encoder'},
                output_actions=[self.bpod.actions.stop_sound, ('BNC1', 255)],
            )  # stop all sounds

        # Reset the rotary encoder by sending the following opcodes via the module's serial interface
        # - 'Z' (ASCII 90): Set current rotary encoder position to zero
        # - 'E' (ASCII 69): Enable all position thresholds (that may have been disabled by a threshold-crossing)
        # cf. https://sanworks.github.io/Bpod_Wiki/serial-interfaces/rotary-encoder-module-serial-interface/
        sma.add_state(
            state_name='reset_rotary_encoder',
            state_timer=0,
            output_actions=[self.bpod.actions.rotary_encoder_reset],
            state_change_conditions={'Tup': 'quiescent_period'},
        )

        # Quiescent Period. If the wheel is moved past one of the thresholds: reset the rotary encoder and start over.
        # Continue with the stimulation once the quiescent period has passed without triggering movement thresholds.
        sma.add_state(
            state_name='quiescent_period',
            state_timer=self.quiescent_period,
            output_actions=[],
            state_change_conditions={
                'Tup': 'stim_on',
                self.movement_left: 'reset_rotary_encoder',
                self.movement_right: 'reset_rotary_encoder',
            },
        )

        # Show the visual stimulus. This is achieved by sending a time-stamped byte-message to Bonsai via the Rotary
        # Encoder Module's ongoing USB-stream. Move to the next state once the Frame2TTL has been triggered, i.e.,
        # when the stimulus has been rendered on screen. Use the state-timer as a backup to prevent a stall.
        sma.add_state(
            state_name='stim_on',
            state_timer=0.1,
            output_actions=[self.bpod.actions.bonsai_show_stim],
            state_change_conditions={'BNC1High': 'interactive_delay', 'BNC1Low': 'interactive_delay', 'Tup': 'interactive_delay'},
        )

        # Defined delay between the visual and auditory cues
        sma.add_state(
            state_name='interactive_delay',
            state_timer=self.task_params.INTERACTIVE_DELAY,
            output_actions=[],
            state_change_conditions={'Tup': 'play_tone'},
        )

        # Play tone. Move to the next state if the sound is detected. Use the state-timer as a backup to prevent a stall.
        sma.add_state(
            state_name='play_tone',
            state_timer=0.1,
            output_actions=[self.bpod.actions.play_tone],
            state_change_conditions={'Tup': 'reset2_rotary_encoder', 'BNC2High': 'reset2_rotary_encoder'},
        )

        # Reset the rotary encoder (see above). Move on after a brief delay (to avoid race conditions in the Bonsai flow).
        sma.add_state(
            state_name='reset2_rotary_encoder',
            state_timer=0.05,
            output_actions=[self.bpod.actions.rotary_encoder_reset],
            state_change_conditions={'Tup': 'closed_loop'},
        )

        # Start the closed loop state in which the animal controls the position of the visual stimulus by means of the
        # rotary encoder. The three possible outcomes are:
        # 1) wheel has NOT been moved past a threshold: continue with no-go condition
        # 2) wheel has been moved in WRONG direction: continue with error condition
        # 3) wheel has been moved in CORRECT direction: continue with reward condition
        sma.add_state(
            state_name='closed_loop',
            state_timer=self.task_params.RESPONSE_WINDOW,
            output_actions=[self.bpod.actions.bonsai_closed_loop],
            state_change_conditions={'Tup': 'no_go', self.event_error: 'freeze_error', self.event_reward: 'freeze_reward'},
        )

        # No-go: hide the visual stimulus and play white noise. Go to exit_state after FEEDBACK_NOGO_DELAY_SECS.
        sma.add_state(
            state_name='no_go',
            state_timer=self.feedback_nogo_delay,
            output_actions=[self.bpod.actions.bonsai_hide_stim, self.bpod.actions.play_noise],
            state_change_conditions={'Tup': 'exit_state'},
        )

        # Error: freeze the stimulus and play white noise.
        # Continue to hide_stim/exit_state once FEEDBACK_ERROR_DELAY_SECS have passed.
        sma.add_state(
            state_name='freeze_error',
            state_timer=0,
            output_actions=[self.bpod.actions.bonsai_freeze_stim],
            state_change_conditions={'Tup': 'error'},
        )
        sma.add_state(
            state_name='error',
            state_timer=self.feedback_error_delay,
            output_actions=[self.bpod.actions.play_noise],
            state_change_conditions={'Tup': 'hide_stim'},
        )

        # Reward: open the valve for a defined duration (and set BNC1 to high), freeze the stimulus in the center
        # of the screen. Continue to hide_stim/exit_state once FEEDBACK_CORRECT_DELAY_SECS have passed.
        sma.add_state(
            state_name='freeze_reward',
            state_timer=0,
            output_actions=[self.bpod.actions.bonsai_show_center],
            state_change_conditions={'Tup': 'reward'},
        )
        sma.add_state(
            state_name='reward',
            state_timer=self.reward_time,
            output_actions=[('Valve1', 255), ('BNC1', 255)],
            state_change_conditions={'Tup': 'correct'},
        )
        sma.add_state(
            state_name='correct',
            state_timer=self.feedback_correct_delay - self.reward_time,
            output_actions=[],
            state_change_conditions={'Tup': 'hide_stim'},
        )

        # Hide the visual stimulus. This is achieved by sending a time-stamped byte-message to Bonsai via the Rotary
        # Encoder Module's ongoing USB-stream. Move to the next state once the Frame2TTL has been triggered, i.e.,
        # when the stimulus has been hidden on screen. Use the state-timer as a backup to prevent a stall.
        sma.add_state(
            state_name='hide_stim',
            state_timer=0.1,
            output_actions=[self.bpod.actions.bonsai_hide_stim],
            state_change_conditions={'Tup': 'exit_state', 'BNC1High': 'exit_state', 'BNC1Low': 'exit_state'},
        )

        # Wait for ITI_DELAY_SECS before ending the trial. Raise BNC1 to mark this event.
        sma.add_state(
            state_name='exit_state',
            state_timer=self.task_params.ITI_DELAY_SECS,
            output_actions=[('BNC1', 255)],
            state_change_conditions={'Tup': 'exit'},
        )

        return sma

    @abc.abstractmethod
    def next_trial(self):
        pass

    @property
    def default_reward_amount(self):
        return self.task_params.REWARD_AMOUNT_UL

    def draw_next_trial_info(self, pleft=0.5, **kwargs):
        """Draw the next trial variables.

        Calls :meth:`send_trial_info_to_bonsai`.
        This is called by the `next_trial` method before updating the Bpod state machine.
        """
        assert len(self.task_params.STIM_POSITIONS) == 2, 'Only two positions are supported'
        contrast = misc.draw_contrast(self.task_params.CONTRAST_SET, self.task_params.CONTRAST_SET_PROBABILITY_TYPE)
        position = int(np.random.choice(self.task_params.STIM_POSITIONS, p=[pleft, 1 - pleft]))
        quiescent_period = self.task_params.QUIESCENT_PERIOD + misc.truncated_exponential(
            scale=0.35, min_value=0.2, max_value=0.5
        )
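        # e.g. with the documented default QUIESCENT_PERIOD of 0.2 s, the total quiescence
        # requirement is drawn between 0.4 and 0.7 s (truncated exponential distribution)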
        self.trials_table.at[self.trial_num, 'quiescent_period'] = quiescent_period
        self.trials_table.at[self.trial_num, 'contrast'] = contrast
        self.trials_table.at[self.trial_num, 'stim_phase'] = random.uniform(0, 2 * math.pi)
        self.trials_table.at[self.trial_num, 'stim_sigma'] = self.task_params.STIM_SIGMA
        self.trials_table.at[self.trial_num, 'stim_angle'] = self.task_params.STIM_ANGLE
        self.trials_table.at[self.trial_num, 'stim_gain'] = self.stimulus_gain
        self.trials_table.at[self.trial_num, 'stim_freq'] = self.task_params.STIM_FREQ
        self.trials_table.at[self.trial_num, 'stim_reverse'] = self.task_params.STIM_REVERSE
        self.trials_table.at[self.trial_num, 'trial_num'] = self.trial_num
        self.trials_table.at[self.trial_num, 'position'] = position
        self.trials_table.at[self.trial_num, 'reward_amount'] = self.default_reward_amount
        self.trials_table.at[self.trial_num, 'stim_probability_left'] = pleft

        # use the kwargs dict to override computed values
        for key, value in kwargs.items():
            if key == 'index':
                continue  # do not write the DataFrame index back as a column
            self.trials_table.at[self.trial_num, key] = value

        self.send_trial_info_to_bonsai()

    def trial_completed(self, bpod_data: dict[str, Any]) -> None:
        # if the reward state has not been triggered, null the reward
        if np.isnan(bpod_data['States timestamps']['reward'][0][0]):
            self.trials_table.at[self.trial_num, 'reward_amount'] = 0
        self.trials_table.at[self.trial_num, 'reward_valve_time'] = self.reward_time
        # update cumulative reward value
        self.session_info.TOTAL_WATER_DELIVERED += self.trials_table.at[self.trial_num, 'reward_amount']
        self.session_info.NTRIALS += 1
        # SAVE TRIAL DATA
        self.save_trial_data_to_json(bpod_data)

        # save ambient data
        if self.hardware_settings.device_bpod.USE_AMBIENT_MODULE:
            self.ambient_sensor_table.iloc[self.trial_num] = (sensor_reading := self.bpod.get_ambient_sensor_reading())
            with self.paths['AMBIENT_FILE_PATH'].open('ab') as f:
                binary.write_array(f, [self.trial_num, *sensor_reading], DTYPE_AMBIENT_SENSOR_BIN)

        # this is a flag for the online plots; if the online plots were in PyQt5, a file-watcher could be used instead
        Path(self.paths['DATA_FILE_PATH']).parent.joinpath('new_trial.flag').touch()
        self.paths.SESSION_FOLDER.joinpath('transfer_me.flag').touch()
        self.check_sync_pulses(bpod_data=bpod_data)

    def check_sync_pulses(self, bpod_data):
        # TODO: move this to the post-trial step once we have a task flow
        if not self.bpod.is_connected:
            return
        events = bpod_data['Events timestamps']
        if not misc.get_port_events(events, name='BNC1'):
            log.warning("NO FRAME2TTL PULSES RECEIVED ON BPOD'S TTL INPUT 1")
        if not misc.get_port_events(events, name='BNC2'):
            log.warning("NO SOUND SYNC PULSES RECEIVED ON BPOD'S TTL INPUT 2")
        if not misc.get_port_events(events, name='Port1'):
            log.warning("NO CAMERA SYNC PULSES RECEIVED ON BPOD'S BEHAVIOR PORT 1")

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        """
        Log the details of the current trial.

        This method retrieves information about the current trial from the
        trials table and logs it. It can also incorporate additional information
        provided through the `extra_info` parameter.

        Parameters
        ----------
        extra_info : dict[str, Any], optional
            A dictionary containing additional information to include in the
            log.

        log_level : int, optional
            The logging level to use when logging the trial information.
            Default is logging.INFO.

        Notes
        -----
        When overloading, make sure to call the super class and pass additional
        log items by means of the extra_info parameter. See the implementation
        of :py:meth:`~iblrig.base_choice_world.ActiveChoiceWorldSession.show_trial_log` in
        :mod:`~iblrig.base_choice_world.ActiveChoiceWorldSession` for reference.
        """
        # construct base info dict
        trial_info = self.trials_table.iloc[self.trial_num]
        info_dict = {
            'Stim. Position': trial_info.position,
            'Stim. Contrast': trial_info.contrast,
            'Stim. Phase': f'{trial_info.stim_phase:.2f}',
            'Stim. p Left': trial_info.stim_probability_left,
            'Water delivered': f'{self.session_info.TOTAL_WATER_DELIVERED:.1f} µl',
            'Time from Start': self.time_elapsed,
            'Temperature': f'{self.ambient_sensor_table.loc[self.trial_num, "Temperature_C"]:.1f} °C',
            'Air Pressure': f'{self.ambient_sensor_table.loc[self.trial_num, "AirPressure_mb"]:.1f} mb',
            'Rel. Humidity': f'{self.ambient_sensor_table.loc[self.trial_num, "RelativeHumidity"]:.1f} %',
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # log info dict
        log.log(log_level, f'Outcome of Trial #{trial_info.trial_num}:')
        max_key_length = max(len(key) for key in info_dict)
        for key, value in info_dict.items():
            spaces = (max_key_length - len(key)) * ' '
            log.log(log_level, f'- {key}: {spaces}{str(value)}')

    @property
    def iti_reward(self):
        """
        Return the ITI time that needs to be set in order to achieve the desired ITI,
        by subtracting the time it takes to give a reward from the desired ITI.
        """
        return self.task_params.ITI_CORRECT - self.calibration.get('REWARD_VALVE_TIME', None)

    """
    These are the properties that are used in the state machine code.
    """

    @property
    def reward_time(self):
        return self.compute_reward_time(amount_ul=self.trials_table.at[self.trial_num, 'reward_amount'])

    @property
    def quiescent_period(self):
        return self.trials_table.at[self.trial_num, 'quiescent_period']

    @property
    def feedback_correct_delay(self):
        return self.task_params['FEEDBACK_CORRECT_DELAY_SECS']

    @property
    def feedback_error_delay(self):
        return self.task_params['FEEDBACK_ERROR_DELAY_SECS']

    @property
    def feedback_nogo_delay(self):
        return self.task_params['FEEDBACK_NOGO_DELAY_SECS']

    @property
    def position(self):
        return self.trials_table.at[self.trial_num, 'position']

    @property
    def event_error(self):
        return self.device_rotary_encoder.THRESHOLD_EVENTS[(-1 if self.task_params.STIM_REVERSE else 1) * self.position]

    @property
    def event_reward(self):
        return self.device_rotary_encoder.THRESHOLD_EVENTS[(1 if self.task_params.STIM_REVERSE else -1) * self.position]


class HabituationChoiceWorldTrialData(ChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ChoiceWorldTrialData`."""

    delay_to_stim_center: NonNegativeFloat


class HabituationChoiceWorldSession(ChoiceWorldSession):
    protocol_name = '_iblrig_tasks_habituationChoiceWorld'
    TrialDataModel = HabituationChoiceWorldTrialData

    def next_trial(self):
        self.trial_num += 1
        self.draw_next_trial_info()

    def draw_next_trial_info(self, *args, **kwargs):
        # update trial table fields specific to habituation choice world
        self.trials_table.at[self.trial_num, 'delay_to_stim_center'] = np.random.normal(self.task_params.DELAY_TO_STIM_CENTER, 2)
        super().draw_next_trial_info(*args, **kwargs)

    def get_state_machine_trial(self, i):
        sma = StateMachine(self.bpod)

        if i == 0:  # First trial exception: start the camera
            log.info('Waiting for camera pulses...')
            sma.add_state(
                state_name='iti',
                state_timer=3600,
                state_change_conditions={'Port1In': 'stim_on'},
                output_actions=[self.bpod.actions.bonsai_hide_stim, ('SoftCode', SOFTCODE.TRIGGER_CAMERA), ('BNC1', 255)],
            )  # start camera
        else:
            # NB: This state is actually the inter-trial interval, i.e. the period of grey screen between stim off and
            # stim on. During this period the Bpod TTL is HIGH and there are no stimuli. The onset of this state is
            # trial end; the offset of this state is trial start!
            sma.add_state(
                state_name='iti',
                state_timer=1,  # Stim off for 1 sec
                state_change_conditions={'Tup': 'stim_on'},
                output_actions=[self.bpod.actions.bonsai_hide_stim, ('BNC1', 255)],
            )
        # This stim_on state is considered the actual trial start
        sma.add_state(
            state_name='stim_on',
            state_timer=self.trials_table.at[self.trial_num, 'delay_to_stim_center'],
            state_change_conditions={'Tup': 'stim_center'},
            output_actions=[self.bpod.actions.bonsai_show_stim, self.bpod.actions.play_tone],
        )

        sma.add_state(
            state_name='stim_center',
            state_timer=0.5,
            state_change_conditions={'Tup': 'reward'},
            output_actions=[self.bpod.actions.bonsai_show_center],
        )

        sma.add_state(
            state_name='reward',
            state_timer=self.reward_time,  # the length of time to leave the reward valve open, i.e. the reward size
            state_change_conditions={'Tup': 'post_reward'},
            output_actions=[('Valve1', 255), ('BNC1', 255)],
        )
        # This state defines the period after reward where the Bpod TTL is LOW.
        # NB: The stimulus is on throughout this period. The stim off trigger occurs upon exit.
        # The stimulus thus remains in the screen centre for 0.5 + ITI_DELAY_SECS seconds.
        sma.add_state(
            state_name='post_reward',
            state_timer=self.task_params.ITI_DELAY_SECS - self.reward_time,
            state_change_conditions={'Tup': 'exit'},
            output_actions=[],
        )
        return sma


class ActiveChoiceWorldTrialData(ChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ChoiceWorldTrialData`."""

    response_side: Annotated[int, Interval(ge=-1, le=1)]
    response_time: NonNegativeFloat
    trial_correct: bool


class ActiveChoiceWorldSession(ChoiceWorldSession):
    """
    The ActiveChoiceWorldSession is a base class for protocols where the mouse is actively making decisions
    by turning the wheel. It has the following characteristics:

    -   it is trial based
    -   it is decision based
    -   left and right stimuli are equiprobable: there is no biased block
    -   a trial can either be correct / error / no_go depending on the side of the stimulus and the response
    -   unlike passive stimulation or habituation protocols, it has a quantifiable performance, computed as
        the proportion of correct trials

    TrainingChoiceWorld and BiasedChoiceWorld are subclasses of this class.
    """

    TrialDataModel = ActiveChoiceWorldTrialData
    plot_subprocess: subprocess.Popen | None = None

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.trials_table['stim_probability_left'] = np.zeros(NTRIALS_INIT, dtype=np.float64)

    def _run(self):
        # starts online plotting
        if self.interactive:
            log.info('Starting subprocess: online plots')
            self.plot_subprocess = subprocess.Popen(
                ['view_session', str(self.paths['SESSION_RAW_DATA_FOLDER'])],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.STDOUT,
            )
        super()._run()

    def __del__(self):
        if isinstance(self.plot_subprocess, subprocess.Popen) and self.plot_subprocess.poll() is None:
            log.info('Terminating subprocess: online plots')
            self.plot_subprocess.terminate()
            try:
                self.plot_subprocess.wait(timeout=5)
            except subprocess.TimeoutExpired:
                log.warning('Process did not terminate within 5 seconds - killing it.')
                self.plot_subprocess.kill()

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        # construct info dict
        trial_info = self.trials_table.iloc[self.trial_num]
        info_dict = {
            'Response Time': f'{trial_info.response_time:.2f} s',
            'Trial Correct': trial_info.trial_correct,
            'N Trials Correct': self.session_info.NTRIALS_CORRECT,
            'N Trials Error': self.trial_num - self.session_info.NTRIALS_CORRECT + 1,
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # call parent method
        super().show_trial_log(extra_info=info_dict, log_level=log_level)

    def trial_completed(self, bpod_data: dict) -> None:
        """
        Update the trials table with information about the behaviour coming from the bpod.

        Constraints on the state machine data:

        - mandatory states: ['correct', 'error', 'no_go', 'reward']
        - optional states : ['omit_correct', 'omit_error', 'omit_no_go']

        Parameters
        ----------
        bpod_data : dict
            The Bpod data as returned by pybpod

        Raises
        ------
        AssertionError
            If the position is zero or if the number of detected outcomes is not exactly one.
        """
        # Get the response time from the behaviour data.
        # It is defined as the time passing between the start of `stim_on` and the end of `closed_loop`.
        state_times = bpod_data['States timestamps']
        response_time = state_times['closed_loop'][0][1] - state_times['stim_on'][0][0]
        self.trials_table.at[self.trial_num, 'response_time'] = response_time

        try:
            # Get the stimulus position
            position = self.trials_table.at[self.trial_num, 'position']
            assert position != 0, 'the stimulus position should not be 0'

            # Get the trial's outcome, i.e., the states that have a matching name and a valid time-stamp.
            # Assert that we have exactly one outcome.
            outcome_names = ['correct', 'error', 'no_go', 'omit_correct', 'omit_error', 'omit_no_go']
            outcomes = [name for name, times in state_times.items() if name in outcome_names and ~np.isnan(times[0][0])]
            if (n_outcomes := len(outcomes)) != 1:
                trial_states = 'Trial states: ' + ', '.join(k for k, v in state_times.items() if ~np.isnan(v[0][0]))
                assert n_outcomes != 0, f'No outcome detected for trial {self.trial_num}.\n{trial_states}'
                assert n_outcomes == 1, f'{n_outcomes} outcomes detected for trial {self.trial_num}.\n{trial_states}'
            outcome = outcomes[0]

        except AssertionError as e:
            # write bpod_data to disk, log the exception, then raise
            self.save_trial_data_to_json(bpod_data, validate=False)
            for line in re_split(r'\n', e.args[0]):
                log.error(line)
            raise e

        # record the trial's outcome in the trials_table
        self.trials_table.at[self.trial_num, 'trial_correct'] = 'correct' in outcome
        if 'correct' in outcome:
            self.session_info.NTRIALS_CORRECT += 1
            self.trials_table.at[self.trial_num, 'response_side'] = -np.sign(position)
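            # e.g. a correct trial with the stimulus at position +35 yields response_side = -1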
        elif 'error' in outcome:
            self.trials_table.at[self.trial_num, 'response_side'] = np.sign(position)
        elif 'no_go' in outcome:
            self.trials_table.at[self.trial_num, 'response_side'] = 0

        super().trial_completed(bpod_data)


class BiasedChoiceWorldTrialData(ActiveChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ActiveChoiceWorldTrialData`."""

    block_num: NonNegativeInt = 0
    block_trial_num: NonNegativeInt = 0


class BiasedChoiceWorldSession(ActiveChoiceWorldSession):
    """
    Biased choice world session is the instantiation of ActiveChoiceWorld where the notion of biased
    blocks is introduced.
    """

    base_parameters_file = Path(__file__).parent.joinpath('base_biased_choice_world_params.yaml')
    protocol_name = '_iblrig_tasks_biasedChoiceWorld'
    TrialDataModel = BiasedChoiceWorldTrialData

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.blocks_table = pd.DataFrame(
            {'probability_left': np.zeros(NBLOCKS_INIT) * np.nan, 'block_length': np.zeros(NBLOCKS_INIT, dtype=np.int16) * -1}
        )

    def new_block(self):
        """
        Initialize a new block.

        If BLOCK_INIT_5050 is set, the first block is 90 trials long and has a 50/50
        probability of leftward stimulus.
        """
        self.block_num += 1  # the block number is zero based
        self.block_trial_num = 0

        # handles the block length logic
        if self.task_params.BLOCK_INIT_5050 and self.block_num == 0:
            block_len = 90
        else:
            block_len = int(
                misc.truncated_exponential(
                    scale=self.task_params.BLOCK_LEN_FACTOR,
                    min_value=self.task_params.BLOCK_LEN_MIN,
                    max_value=self.task_params.BLOCK_LEN_MAX,
                )
            )
        if self.block_num == 0:
            pleft = 0.5 if self.task_params.BLOCK_INIT_5050 else np.random.choice(self.task_params.BLOCK_PROBABILITY_SET)
        elif self.block_num == 1 and self.task_params.BLOCK_INIT_5050:
            pleft = np.random.choice(self.task_params.BLOCK_PROBABILITY_SET)
        else:
            # this switches the probability of leftward stim for the next block
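            # (e.g. a block with probability_left = 0.2 is followed by one with 0.8, and vice versa)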
            pleft = round(abs(1 - self.blocks_table.loc[self.block_num - 1, 'probability_left']), 1)
        self.blocks_table.at[self.block_num, 'block_length'] = block_len
        self.blocks_table.at[self.block_num, 'probability_left'] = pleft

    def next_trial(self):
        self.trial_num += 1
        # if necessary, update the block number
        self.block_trial_num += 1
        if self.block_num < 0 or self.block_trial_num > (self.blocks_table.loc[self.block_num, 'block_length'] - 1):
            self.new_block()
        # get and store probability left
        pleft = self.blocks_table.loc[self.block_num, 'probability_left']
        # update trial table fields specific to biased choice world task
        self.trials_table.at[self.trial_num, 'block_num'] = self.block_num
        self.trials_table.at[self.trial_num, 'block_trial_num'] = self.block_trial_num
        # save and send trial info to bonsai
        self.draw_next_trial_info(pleft=pleft)

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        # construct info dict
        trial_info = self.trials_table.iloc[self.trial_num]
        info_dict = {
            'Block Number': trial_info.block_num,
            'Block Length': self.blocks_table.loc[self.block_num, 'block_length'],
            'N Trials in Block': trial_info.block_trial_num,
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # call parent method
        super().show_trial_log(extra_info=info_dict, log_level=log_level)


class TrainingChoiceWorldTrialData(ActiveChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ActiveChoiceWorldTrialData`."""

    training_phase: NonNegativeInt
    debias_trial: bool


class TrainingChoiceWorldSession(ActiveChoiceWorldSession):
    """
    The TrainingChoiceWorldSession corresponds to the first training protocol of the choice world task.
    This protocol has a complicated adaptation of the number of contrasts (embodied by the training_phase
    property) and of the reward amount (embodied by the adaptive_reward property).
    """

    protocol_name = '_iblrig_tasks_trainingChoiceWorld'
    TrialDataModel = TrainingChoiceWorldTrialData

    def __init__(self, training_phase=-1, adaptive_reward=-1.0, adaptive_gain=None, **kwargs):
        super().__init__(**kwargs)
        inferred_training_phase, inferred_adaptive_reward, inferred_adaptive_gain = self.get_subject_training_info()
        if training_phase == -1:
            log.critical(f'Got training phase: {inferred_training_phase}')
            self.training_phase = inferred_training_phase
        else:
            log.critical(f'Training phase manually set to: {training_phase}')
            self.training_phase = training_phase
        if adaptive_reward == -1:
            log.critical(f'Got adaptive reward {inferred_adaptive_reward} uL')
            self.session_info['ADAPTIVE_REWARD_AMOUNT_UL'] = inferred_adaptive_reward
        else:
            log.critical(f'Adaptive reward manually set to {adaptive_reward} uL')
            self.session_info['ADAPTIVE_REWARD_AMOUNT_UL'] = adaptive_reward
        if adaptive_gain is None:
            log.critical(f'Got adaptive gain {inferred_adaptive_gain} degrees/mm')
            self.session_info['ADAPTIVE_GAIN_VALUE'] = inferred_adaptive_gain
        else:
            log.critical(f'Adaptive gain manually set to {adaptive_gain} degrees/mm')
            self.session_info['ADAPTIVE_GAIN_VALUE'] = adaptive_gain
        self.var = {'training_phase_trial_counts': np.zeros(6), 'last_10_responses_sides': np.zeros(10)}

    @property
    def default_reward_amount(self):
        return self.session_info.get('ADAPTIVE_REWARD_AMOUNT_UL', self.task_params.REWARD_AMOUNT_UL)

    @property
    def stimulus_gain(self) -> float:
        return self.session_info.get('ADAPTIVE_GAIN_VALUE')

    def get_subject_training_info(self):
        """
        Get the previous session's training information according to this session's parameters and deduce the
        training level, adaptive reward amount and adaptive gain value.

        Returns
        -------
        tuple
            (training_phase, adaptive_reward, adaptive_gain)
        """
        training_info, _ = choiceworld.get_subject_training_info(
            subject_name=self.session_info.SUBJECT_NAME,
            task_name=self.protocol_name,
            stim_gain=self.task_params.AG_INIT_VALUE,
            stim_gain_on_error=self.task_params.STIM_GAIN,
            default_reward=self.task_params.REWARD_AMOUNT_UL,
            local_path=self.iblrig_settings['iblrig_local_data_path'],
            remote_path=self.iblrig_settings['iblrig_remote_data_path'],
            lab=self.iblrig_settings['ALYX_LAB'],
            iblrig_settings=self.iblrig_settings,
        )
        return training_info['training_phase'], training_info['adaptive_reward'], training_info['adaptive_gain']

    def check_training_phase(self) -> bool:
        """Check if the mouse is ready to move on to the next training phase."""
        move_on = False
        if self.training_phase == 0:  # each of the -1, -.5, .5, 1 contrasts should be above 80% perf to switch
            performance = choiceworld.compute_performance(self.trials_table)
            passing = performance[np.abs(performance.index) >= 0.5]['last_50_perf']
            if np.all(passing > 0.8) and passing.size == 4:
                move_on = True
        elif self.training_phase == 1:  # each of the -.25, .25 contrasts should be above 80% perf to switch
            performance = choiceworld.compute_performance(self.trials_table)
            passing = performance[np.abs(performance.index) == 0.25]['last_50_perf']
            if np.all(passing > 0.8) and passing.size == 2:
                move_on = True
        elif 5 > self.training_phase >= 2:  # for the next phases, always switch after 200 trials
            if self.var['training_phase_trial_counts'][self.training_phase] >= 200:
                move_on = True
        if move_on:
            self.training_phase = np.minimum(5, self.training_phase + 1)
            log.warning(f'Moving on to training phase {self.training_phase}, {self.trial_num}')
        return move_on

    def next_trial(self):
        # update counters
        self.trial_num += 1
        self.var['training_phase_trial_counts'][self.training_phase] += 1

        # check if the subject graduates to a new training phase
        self.check_training_phase()

        # draw the next trial
        signed_contrast = choiceworld.draw_training_contrast(self.training_phase)
        position = self.task_params.STIM_POSITIONS[int(np.sign(signed_contrast) == 1)]
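        # e.g. signed_contrast = -0.5 has sign -1, selecting STIM_POSITIONS[0] (-35 by default)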
        contrast = np.abs(signed_contrast)

        # debiasing: if the previous trial was incorrect, not a no-go, and easy
        if self.task_params.DEBIAS and self.trial_num >= 1 and self.training_phase < 5:
            last_contrast = self.trials_table.loc[self.trial_num - 1, 'contrast']
            do_debias_trial = (
                (self.trials_table.loc[self.trial_num - 1, 'trial_correct'] != 1)
                and (self.trials_table.loc[self.trial_num - 1, 'response_side'] != 0)
                and last_contrast >= 0.5
            )
            self.trials_table.at[self.trial_num, 'debias_trial'] = do_debias_trial
            if do_debias_trial:
                # indices of trials that had a response
                iresponse = np.logical_and(self.trials_table['response_side'].notna(), self.trials_table['response_side'] != 0)
                iresponse = iresponse.index[iresponse]

                # take the average of right responses over the last 10 response trials
                average_right = (self.trials_table['response_side'][iresponse[-np.minimum(10, iresponse.size) :]] == 1).mean()

                # the probability of the next stimulus being on the left is a draw from a normal distribution centered
                # on the average right with sigma 0.5 - if it is less than 0.5 the next stimulus will be on the left
                position = self.task_params.STIM_POSITIONS[int(np.random.normal(average_right, 0.5) >= 0.5)]

                # the contrast is the last contrast
                contrast = last_contrast
        else:
            self.trials_table.at[self.trial_num, 'debias_trial'] = False

        # save and send trial info to bonsai
        self.draw_next_trial_info(pleft=self.task_params.PROBABILITY_LEFT, position=position, contrast=contrast)
        self.trials_table.at[self.trial_num, 'training_phase'] = self.training_phase

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        # construct info dict
        info_dict = {
            'Contrast Set': np.unique(np.abs(choiceworld.contrasts_set(self.training_phase))),
            'Training Phase': self.training_phase,
            'Debias Trial': self.trials_table.at[self.trial_num, 'debias_trial'],
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # call parent method
        super().show_trial_log(extra_info=info_dict, log_level=log_level)