int-brain-lab / iblrig / build 15738036488

18 Jun 2025 04:10PM UTC coverage: 48.249% (+1.5%) from 46.79%

Pull Request #815: extended tests for photometry copier
Merge fd70c12e3 into 5c537cbb7 (commit 9b495a by web-flow, via github)

23 of 32 new or added lines in 1 file covered. (71.88%)

1106 existing lines in 22 files now uncovered.

4408 of 9136 relevant lines covered (48.25%)

0.96 hits per line

Source File (92.04% covered): /iblrig/base_choice_world.py
"""Extends the base_tasks modules by providing task logic around the Choice World protocol."""

import abc
import logging
import math
import random
import subprocess
import time
from pathlib import Path
from re import split as re_split
from string import ascii_letters
from typing import Annotated, Any

import numpy as np
import pandas as pd
from annotated_types import Interval, IsNan
from pydantic import NonNegativeFloat, NonNegativeInt

import iblrig.base_tasks
from iblrig import choiceworld, misc
from iblrig.hardware import DTYPE_AMBIENT_SENSOR_BIN, SOFTCODE
from iblrig.pydantic_definitions import TrialDataModel
from iblutil.io import binary, jsonable
from iblutil.util import Bunch
from pybpodapi.com.messaging.trial import Trial
from pybpodapi.protocol import StateMachine

log = logging.getLogger(__name__)

NTRIALS_INIT = 2000
NBLOCKS_INIT = 100

# TODO: task parameters should be verified through a pydantic model
#
# Probability = Annotated[float, Field(ge=0.0, le=1.0)]
#
# class ChoiceWorldParams(BaseModel):
#     AUTOMATIC_CALIBRATION: bool = True
#     ADAPTIVE_REWARD: bool = False
#     BONSAI_EDITOR: bool = False
#     CALIBRATION_VALUE: float = 0.067
#     CONTRAST_SET: list[Probability] = Field([1.0, 0.25, 0.125, 0.0625, 0.0], min_length=1)
#     CONTRAST_SET_PROBABILITY_TYPE: Literal['uniform', 'skew_zero'] = 'uniform'
#     GO_TONE_AMPLITUDE: float = 0.0272
#     GO_TONE_DURATION: float = 0.11
#     GO_TONE_IDX: int = Field(2, ge=0)
#     GO_TONE_FREQUENCY: float = Field(5000, gt=0)
#     FEEDBACK_CORRECT_DELAY_SECS: float = 1
#     FEEDBACK_ERROR_DELAY_SECS: float = 2
#     FEEDBACK_NOGO_DELAY_SECS: float = 2
#     INTERACTIVE_DELAY: float = 0.0
#     ITI_DELAY_SECS: float = 0.5
#     NTRIALS: int = Field(2000, gt=0)
#     PROBABILITY_LEFT: Probability = 0.5
#     QUIESCENCE_THRESHOLDS: list[float] = Field(default=[-2, 2], min_length=2, max_length=2)
#     QUIESCENT_PERIOD: float = 0.2
#     RECORD_AMBIENT_SENSOR_DATA: bool = True
#     RECORD_SOUND: bool = True
#     RESPONSE_WINDOW: float = 60
#     REWARD_AMOUNT_UL: float = 1.5
#     REWARD_TYPE: str = 'Water 10% Sucrose'
#     STIM_ANGLE: float = 0.0
#     STIM_FREQ: float = 0.1
#     STIM_GAIN: float = 4.0  # wheel to stimulus relationship (degrees visual angle per mm of wheel displacement)
#     STIM_POSITIONS: list[float] = [-35, 35]
#     STIM_SIGMA: float = 7.0
#     STIM_TRANSLATION_Z: Literal[7, 8] = 7  # 7 for ephys, 8 otherwise. -p:Stim.TranslationZ-{STIM_TRANSLATION_Z} bonsai param
#     STIM_REVERSE: bool = False
#     SYNC_SQUARE_X: float = 1.33
#     SYNC_SQUARE_Y: float = -1.03
#     USE_AUTOMATIC_STOPPING_CRITERIONS: bool = True
#     VISUAL_STIMULUS: str = 'GaborIBLTask / Gabor2D.bonsai'  # null / passiveChoiceWorld_passive.bonsai
#     WHITE_NOISE_AMPLITUDE: float = 0.05
#     WHITE_NOISE_DURATION: float = 0.5
#     WHITE_NOISE_IDX: int = 3
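#
# A usage sketch for the model above (illustrative only; ChoiceWorldParams is not
# yet implemented). Pydantic would then reject out-of-range parameter values at
# instantiation time:
#
# params = ChoiceWorldParams()                      # all defaults validate
# params = ChoiceWorldParams(PROBABILITY_LEFT=1.5)  # raises pydantic.ValidationError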


class ChoiceWorldTrialData(TrialDataModel):
    """Pydantic Model for Trial Data."""

    contrast: Annotated[float, Interval(ge=0.0, le=1.0)]
    stim_probability_left: Annotated[float, Interval(ge=0.0, le=1.0)]
    position: float
    quiescent_period: NonNegativeFloat
    reward_amount: NonNegativeFloat
    reward_valve_time: NonNegativeFloat
    stim_angle: Annotated[float, Interval(ge=-180.0, le=180.0)]
    stim_freq: NonNegativeFloat
    stim_gain: float
    stim_phase: Annotated[float, Interval(ge=0.0, le=2 * math.pi)]
    stim_reverse: bool
    stim_sigma: float
    trial_num: NonNegativeInt
    pause_duration: NonNegativeFloat = 0.0

    # The following variables are only used in ActiveChoiceWorld.
    # We keep them here with fixed default values for the sake of compatibility.
    #
    # TODO: Yes, this should probably be done differently.
    response_side: Annotated[int, Interval(ge=0, le=0)] = 0
    response_time: IsNan[float] = np.nan
    trial_correct: Annotated[bool, Interval(ge=0, le=0)] = False
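
# Usage sketch (illustrative only, not part of the task): constructing a single
# validated trial record. The values are arbitrary but satisfy the constraints
# above; this assumes the TrialDataModel base adds no further required fields.
#
# trial = ChoiceWorldTrialData(
#     contrast=0.25, stim_probability_left=0.5, position=-35.0,
#     quiescent_period=0.45, reward_amount=1.5, reward_valve_time=0.05,
#     stim_angle=0.0, stim_freq=0.1, stim_gain=4.0, stim_phase=math.pi,
#     stim_reverse=False, stim_sigma=7.0, trial_num=0,
# )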


class ChoiceWorldSession(
    iblrig.base_tasks.BonsaiRecordingMixin,
    iblrig.base_tasks.BonsaiVisualStimulusMixin,
    iblrig.base_tasks.BpodMixin,
    iblrig.base_tasks.Frame2TTLMixin,
    iblrig.base_tasks.RotaryEncoderMixin,
    iblrig.base_tasks.SoundMixin,
    iblrig.base_tasks.ValveMixin,
    iblrig.base_tasks.NetworkSession,
):
    # task_params = ChoiceWorldParams()
    base_parameters_file = Path(__file__).parent.joinpath('base_choice_world_params.yaml')
    TrialDataModel = ChoiceWorldTrialData

    def __init__(self, *args, delay_mins: float = 0, **kwargs):
        super().__init__(**kwargs)

        # session delay is handled in seconds internally
        self.task_params['SESSION_DELAY_START'] = delay_mins * 60.0

        # init behaviour data
        self.movement_left = self.device_rotary_encoder.THRESHOLD_EVENTS[self.task_params.QUIESCENCE_THRESHOLDS[0]]
        self.movement_right = self.device_rotary_encoder.THRESHOLD_EVENTS[self.task_params.QUIESCENCE_THRESHOLDS[1]]

        # init counter variables
        self.trial_num = -1
        self.block_num = -1
        self.block_trial_num = -1

        # init the tables; there are two of them: a trials table and an ambient sensor data table
        self.trials_table = self.TrialDataModel.preallocate_dataframe(NTRIALS_INIT)
        self.ambient_sensor_table = pd.DataFrame(
            np.nan, index=range(NTRIALS_INIT), columns=['Temperature_C', 'AirPressure_mb', 'RelativeHumidity'], dtype=np.float32
        )
        self.ambient_sensor_table.rename_axis('Trial', inplace=True)

    @staticmethod
    def extra_parser():
        """Return the parser with task-specific arguments added (:class:`argparse.ArgumentParser`)."""
        parser = super(ChoiceWorldSession, ChoiceWorldSession).extra_parser()
        parser.add_argument(
            '--delay_mins',
            dest='delay_mins',
            default=0,
            type=float,
            required=False,
            help='initial delay before starting the first trial (default: 0 min)',
        )
        return parser

    def start_hardware(self):
        """
        In this step we explicitly run the start methods of the various mixins.
        The super class start method is overloaded because we need to start the different hardware pieces in order.
        """
        if not self.is_mock:
            self.start_mixin_frame2ttl()
            self.start_mixin_bpod()
            self.start_mixin_valve()
            self.start_mixin_sound()
            self.start_mixin_rotary_encoder()
            self.start_mixin_bonsai_cameras()
            self.start_mixin_bonsai_microphone()
            self.start_mixin_bonsai_visual_stimulus()
            self.bpod.register_softcodes(self.softcode_dictionary())

    def _run(self):
        """Run the task with the actual state machine."""
        time_last_trial_end = time.time()
        for i in range(self.task_params.NTRIALS):  # Main loop
            # t_overhead = time.time()
            self.next_trial()
            log.info(f'Starting trial: {i}')
            # =============================================================================
            #     Start state machine definition
            # =============================================================================
            sma = self.get_state_machine_trial(i)
            log.debug('Sending state machine to bpod')
            # Send state machine description to Bpod device
            self.bpod.send_state_machine(sma)
            # t_overhead = time.time() - t_overhead
            # The ITI_DELAY_SECS defines the grey screen period within the state machine, where the
            # Bpod TTL is HIGH. The DEAD_TIME param defines the time between the last trial and the next.
            dead_time = self.task_params.get('DEAD_TIME', 0.5)
            dt = self.task_params.ITI_DELAY_SECS - dead_time - (time.time() - time_last_trial_end)
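            # Worked example with illustrative values: if ITI_DELAY_SECS = 1.0,
            # DEAD_TIME = 0.5 and 0.3 s have already elapsed since the last trial
            # ended, then dt = 1.0 - 0.5 - 0.3 = 0.2 s of extra sleep is needed.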
            # wait to achieve the desired ITI duration
            if dt > 0:
                time.sleep(dt)
            # Run state machine
            log.debug('running state machine')
            self.bpod.run_state_machine(sma)  # Locks until state machine 'exit' is reached
            time_last_trial_end = time.time()
            # handle pause event
            flag_pause = self.paths.SESSION_FOLDER.joinpath('.pause')
            flag_stop = self.paths.SESSION_FOLDER.joinpath('.stop')
            if flag_pause.exists() and i < (self.task_params.NTRIALS - 1):
                log.info(f'Pausing session in between trials {i} and {i + 1}')
                while flag_pause.exists() and not flag_stop.exists():
                    time.sleep(1)
                self.trials_table.at[self.trial_num, 'pause_duration'] = time.time() - time_last_trial_end
                if not flag_stop.exists():
                    log.info('Resuming session')

            # save trial and update log
            self.trial_completed(self.bpod.session.current_trial.export())
            self.show_trial_log()

            # handle stop event
            if flag_stop.exists():
                log.info('Stopping session after trial %d', i)
                flag_stop.unlink()
                break

    def mock(self, file_jsonable_fixture=None):
        """
        Instantiate a state machine and Bpod object to simulate a task's run.

        This is useful to test or display the state machine flow.
        """
        super().mock()

        if file_jsonable_fixture is not None:
            task_data = jsonable.read(file_jsonable_fixture)
            # pop-out the bpod data from the table
            bpod_data = []
            for td in task_data:
                bpod_data.append(td.pop('behavior_data'))

            class MockTrial(Trial):
                def export(self):
                    return np.random.choice(bpod_data)
        else:

            class MockTrial(Trial):
                def export(self):
                    return {}

        self.bpod.session.trials = [MockTrial()]
        self.bpod.send_state_machine = lambda k: None
        self.bpod.run_state_machine = lambda k: time.sleep(1.2)

        daction = ('dummy', 'action')
        self.sound = Bunch({'GO_TONE': daction, 'WHITE_NOISE': daction})

        self.bpod.actions.update(
            {
                'play_tone': daction,
                'play_noise': daction,
                'stop_sound': daction,
                'rotary_encoder_reset': daction,
                'bonsai_hide_stim': daction,
                'bonsai_show_stim': daction,
                'bonsai_closed_loop': daction,
                'bonsai_freeze_stim': daction,
                'bonsai_show_center': daction,
                'bonsai_freeze_center': daction,
            }
        )

    def get_graphviz_task(self, output_file=None, view=True):
        """
        Get the state machine's states diagram in Digraph format.

        :param output_file: file path to render the graph to (optional)
        :param view: whether to open the rendered graph (default: True)
        :return: the graphviz.Digraph object, or None if there is no state machine
        """
        import graphviz

        self.next_trial()
        sma = self.get_state_machine_trial(0)
        if sma is None:
            return
        states_indices = {i: k for i, k in enumerate(sma.state_names)}
        states_indices.update({(i + 10000): k for i, k in enumerate(sma.undeclared)})
        states_letters = {k: ascii_letters[i] for i, k in enumerate(sma.state_names)}
        dot = graphviz.Digraph(comment='The Great IBL Task')
        edges = []

        for i in range(len(sma.state_names)):
            letter = states_letters[sma.state_names[i]]
            dot.node(letter, sma.state_names[i])
            if ~np.isnan(sma.state_timer_matrix[i]):
                out_state = states_indices[sma.state_timer_matrix[i]]
                edges.append(f'{letter}{states_letters[out_state]}')
            for inputs in sma.input_matrix[i]:
                if inputs[0] == 0:
                    edges.append(f'{letter}{states_letters[states_indices[inputs[1]]]}')
        dot.edges(edges)
        if output_file is not None:
            try:
                dot.render(output_file, view=view)
            except graphviz.exceptions.ExecutableNotFound:
                log.info('Graphviz system executable not found, cannot render the graph')
        return dot

    def _instantiate_state_machine(self, *args, **kwargs):
        return StateMachine(self.bpod)

    def get_state_machine_trial(self, i):
        # we define the trial number here for subclasses that may need it
        sma = self._instantiate_state_machine(trial_number=i)

        if i == 0:  # First trial exception: start camera
            session_delay_start = self.task_params.get('SESSION_DELAY_START', 0)
            log.info('First trial initializing, will move to next trial only if:')
            log.info('1. camera is detected')
            log.info(f'2. {session_delay_start} sec have elapsed')
            sma.add_state(
                state_name='trial_start',
                state_timer=0,
                state_change_conditions={'Port1In': 'delay_initiation'},
                output_actions=[('SoftCode', SOFTCODE.TRIGGER_CAMERA), ('BNC1', 255)],
            )  # start camera
            sma.add_state(
                state_name='delay_initiation',
                state_timer=session_delay_start,
                output_actions=[],
                state_change_conditions={'Tup': 'reset_rotary_encoder'},
            )
        else:
            sma.add_state(
                state_name='trial_start',
                state_timer=0,  # ~100µs hardware irreducible delay
                state_change_conditions={'Tup': 'reset_rotary_encoder'},
                output_actions=[self.bpod.actions.stop_sound, ('BNC1', 255)],
            )  # stop all sounds

        # Reset the rotary encoder by sending the following opcodes via the module's serial interface:
        # - 'Z' (ASCII 90): Set current rotary encoder position to zero
        # - 'E' (ASCII 69): Enable all position thresholds (that may have been disabled by a threshold-crossing)
        # cf. https://sanworks.github.io/Bpod_Wiki/serial-interfaces/rotary-encoder-module-serial-interface/
        sma.add_state(
            state_name='reset_rotary_encoder',
            state_timer=0,
            output_actions=[self.bpod.actions.rotary_encoder_reset],
            state_change_conditions={'Tup': 'quiescent_period'},
        )

        # Quiescent Period. If the wheel is moved past one of the thresholds: reset the rotary encoder and start over.
        # Continue with the stimulation once the quiescent period has passed without triggering movement thresholds.
        sma.add_state(
            state_name='quiescent_period',
            state_timer=self.quiescent_period,
            output_actions=[],
            state_change_conditions={
                'Tup': 'stim_on',
                self.movement_left: 'reset_rotary_encoder',
                self.movement_right: 'reset_rotary_encoder',
            },
        )

        # Show the visual stimulus. This is achieved by sending a time-stamped byte-message to Bonsai via the Rotary
        # Encoder Module's ongoing USB-stream. Move to the next state once the Frame2TTL has been triggered, i.e.,
        # when the stimulus has been rendered on screen. Use the state-timer as a backup to prevent a stall.
        sma.add_state(
            state_name='stim_on',
            state_timer=0.1,
            output_actions=[self.bpod.actions.bonsai_show_stim],
            state_change_conditions={'BNC1High': 'interactive_delay', 'BNC1Low': 'interactive_delay', 'Tup': 'interactive_delay'},
        )

        # Defined delay between visual and auditory cue
        sma.add_state(
            state_name='interactive_delay',
            state_timer=self.task_params.INTERACTIVE_DELAY,
            output_actions=[],
            state_change_conditions={'Tup': 'play_tone'},
        )

        # Play tone. Move to next state if sound is detected. Use the state-timer as a backup to prevent a stall.
        sma.add_state(
            state_name='play_tone',
            state_timer=0.1,
            output_actions=[self.bpod.actions.play_tone],
            state_change_conditions={'Tup': 'reset2_rotary_encoder', 'BNC2High': 'reset2_rotary_encoder'},
        )

        # Reset rotary encoder (see above). Move on after a brief delay (to avoid a race condition in the Bonsai flow).
        sma.add_state(
            state_name='reset2_rotary_encoder',
            state_timer=0.05,
            output_actions=[self.bpod.actions.rotary_encoder_reset],
            state_change_conditions={'Tup': 'closed_loop'},
        )

        # Start the closed loop state in which the animal controls the position of the visual stimulus by means of the
        # rotary encoder. The three possible outcomes are:
        # 1) wheel has NOT been moved past a threshold: continue with no-go condition
        # 2) wheel has been moved in WRONG direction: continue with error condition
        # 3) wheel has been moved in CORRECT direction: continue with reward condition

        sma.add_state(
            state_name='closed_loop',
            state_timer=self.task_params.RESPONSE_WINDOW,
            output_actions=[self.bpod.actions.bonsai_closed_loop],
            state_change_conditions={'Tup': 'no_go', self.event_error: 'freeze_error', self.event_reward: 'freeze_reward'},
        )

        # No-go: hide the visual stimulus and play white noise. Go to exit_state after FEEDBACK_NOGO_DELAY_SECS.
        sma.add_state(
            state_name='no_go',
            state_timer=self.feedback_nogo_delay,
            output_actions=[self.bpod.actions.bonsai_hide_stim, self.bpod.actions.play_noise],
            state_change_conditions={'Tup': 'exit_state'},
        )

        # Error: Freeze the stimulus and play white noise.
        # Continue to hide_stim/exit_state once FEEDBACK_ERROR_DELAY_SECS have passed.
        sma.add_state(
            state_name='freeze_error',
            state_timer=0,
            output_actions=[self.bpod.actions.bonsai_freeze_stim],
            state_change_conditions={'Tup': 'error'},
        )
        sma.add_state(
            state_name='error',
            state_timer=self.feedback_error_delay,
            output_actions=[self.bpod.actions.play_noise],
            state_change_conditions={'Tup': 'hide_stim'},
        )

        # Reward: open the valve for a defined duration (and set BNC1 to high), freeze stimulus in center of screen.
        # Continue to hide_stim/exit_state once FEEDBACK_CORRECT_DELAY_SECS have passed.
        sma.add_state(
            state_name='freeze_reward',
            state_timer=0,
            output_actions=[self.bpod.actions.bonsai_freeze_center],
            state_change_conditions={'Tup': 'reward'},
        )
        sma.add_state(
            state_name='reward',
            state_timer=self.reward_time,
            output_actions=[('Valve1', 255), ('BNC1', 255)],
            state_change_conditions={'Tup': 'correct'},
        )
        sma.add_state(
            state_name='correct',
            state_timer=self.feedback_correct_delay - self.reward_time,
            output_actions=[],
            state_change_conditions={'Tup': 'hide_stim'},
        )

        # Hide the visual stimulus. This is achieved by sending a time-stamped byte-message to Bonsai via the Rotary
        # Encoder Module's ongoing USB-stream. Move to the next state once the Frame2TTL has been triggered, i.e.,
        # when the stimulus has been hidden on screen. Use the state-timer as a backup to prevent a stall.
        sma.add_state(
            state_name='hide_stim',
            state_timer=0.1,
            output_actions=[self.bpod.actions.bonsai_hide_stim],
            state_change_conditions={'Tup': 'exit_state', 'BNC1High': 'exit_state', 'BNC1Low': 'exit_state'},
        )

        # Wait for ITI_DELAY_SECS before ending the trial. Raise BNC1 to mark this event.
        sma.add_state(
            state_name='exit_state',
            state_timer=self.task_params.ITI_DELAY_SECS,
            output_actions=[('BNC1', 255)],
            state_change_conditions={'Tup': 'exit'},
        )

        return sma
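
    # Overview of the trial state flow assembled above (for reference; transitions
    # are taken from the state_change_conditions of each state):
    #
    #   trial_start [→ delay_initiation] → reset_rotary_encoder → quiescent_period
    #     → stim_on → interactive_delay → play_tone → reset2_rotary_encoder → closed_loop
    #   closed_loop → no_go ──────────────────────────────────────→ exit_state
    #   closed_loop → freeze_error  → error  → hide_stim ─────────→ exit_state
    #   closed_loop → freeze_reward → reward → correct → hide_stim → exit_state
    #   exit_state → exit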

    @abc.abstractmethod
    def next_trial(self):
        pass

    @property
    def default_reward_amount(self):
        return self.task_params.REWARD_AMOUNT_UL

    def draw_next_trial_info(self, pleft=0.5, **kwargs):
        """Draw the next trial's variables.

        Calls :meth:`send_trial_info_to_bonsai`.
        This is called by the `next_trial` method before updating the Bpod state machine.
        """
        assert len(self.task_params.STIM_POSITIONS) == 2, 'Only two positions are supported'
        contrast = misc.draw_contrast(self.task_params.CONTRAST_SET, self.task_params.CONTRAST_SET_PROBABILITY_TYPE)
        position = int(np.random.choice(self.task_params.STIM_POSITIONS, p=[pleft, 1 - pleft]))
        quiescent_period = self.task_params.QUIESCENT_PERIOD + misc.truncated_exponential(
            scale=0.35, min_value=0.2, max_value=0.5
        )
        self.trials_table.at[self.trial_num, 'quiescent_period'] = quiescent_period
        self.trials_table.at[self.trial_num, 'contrast'] = contrast
        self.trials_table.at[self.trial_num, 'stim_phase'] = random.uniform(0, 2 * math.pi)
        self.trials_table.at[self.trial_num, 'stim_sigma'] = self.task_params.STIM_SIGMA
        self.trials_table.at[self.trial_num, 'stim_angle'] = self.task_params.STIM_ANGLE
        self.trials_table.at[self.trial_num, 'stim_gain'] = self.stimulus_gain
        self.trials_table.at[self.trial_num, 'stim_freq'] = self.task_params.STIM_FREQ
        self.trials_table.at[self.trial_num, 'stim_reverse'] = self.task_params.STIM_REVERSE
        self.trials_table.at[self.trial_num, 'trial_num'] = self.trial_num
        self.trials_table.at[self.trial_num, 'position'] = position
        self.trials_table.at[self.trial_num, 'reward_amount'] = self.default_reward_amount
        self.trials_table.at[self.trial_num, 'stim_probability_left'] = pleft

        # use the kwargs dict to override computed values
        for key, value in kwargs.items():
            if key == 'index':
                continue  # do not write 'index' into the trials table
            self.trials_table.at[self.trial_num, key] = value

        self.send_trial_info_to_bonsai()
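
    # Note on the quiescent period drawn above: misc.truncated_exponential samples an
    # exponential distribution (scale 0.35 s) truncated to [0.2, 0.5] s, so the total
    # quiescence is QUIESCENT_PERIOD plus an unpredictable 0.2-0.5 s jitter.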

    def trial_completed(self, bpod_data: dict[str, Any]) -> None:
        # if the reward state has not been triggered, null the reward
        if np.isnan(bpod_data['States timestamps']['reward'][0][0]):
            self.trials_table.at[self.trial_num, 'reward_amount'] = 0
        self.trials_table.at[self.trial_num, 'reward_valve_time'] = self.reward_time
        # update cumulative reward value
        self.session_info.TOTAL_WATER_DELIVERED += self.trials_table.at[self.trial_num, 'reward_amount']
        self.session_info.NTRIALS += 1
        # SAVE TRIAL DATA
        self.save_trial_data_to_json(bpod_data)

        # save ambient data
        if self.hardware_settings.device_bpod.USE_AMBIENT_MODULE:
            self.ambient_sensor_table.iloc[self.trial_num] = (sensor_reading := self.bpod.get_ambient_sensor_reading())
            with self.paths['AMBIENT_FILE_PATH'].open('ab') as f:
                binary.write_array(f, [self.trial_num, *sensor_reading], DTYPE_AMBIENT_SENSOR_BIN)

        # this is a flag for the online plots; if the online plots were implemented in pyqt5,
        # a file watcher could be used instead
        Path(self.paths['DATA_FILE_PATH']).parent.joinpath('new_trial.flag').touch()
        self.paths.SESSION_FOLDER.joinpath('transfer_me.flag').touch()
        self.check_sync_pulses(bpod_data=bpod_data)

    def check_sync_pulses(self, bpod_data):
        # TODO: move this to the post-trial handling once we have a task flow
        if not self.bpod.is_connected:
            return
        events = bpod_data['Events timestamps']
        if not misc.get_port_events(events, name='BNC1'):
            log.warning("NO FRAME2TTL PULSES RECEIVED ON BPOD'S TTL INPUT 1")
        if not misc.get_port_events(events, name='BNC2'):
            log.warning("NO SOUND SYNC PULSES RECEIVED ON BPOD'S TTL INPUT 2")
        if not misc.get_port_events(events, name='Port1'):
            log.warning("NO CAMERA SYNC PULSES RECEIVED ON BPOD'S BEHAVIOR PORT 1")

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        """
        Log the details of the current trial.

        This method retrieves information about the current trial from the
        trials table and logs it. It can also incorporate additional information
        provided through the `extra_info` parameter.

        Parameters
        ----------
        extra_info : dict[str, Any], optional
            A dictionary containing additional information to include in the log.
        log_level : int, optional
            The logging level to use when logging the trial information.
            Default is logging.INFO.

        Notes
        -----
        When overloading, make sure to call the super class and pass additional
        log items by means of the extra_info parameter. See the implementation
        of :py:meth:`~iblrig.base_choice_world.ActiveChoiceWorldSession.show_trial_log` in
        :mod:`~iblrig.base_choice_world.ActiveChoiceWorldSession` for reference.
        """
        # construct base info dict
        trial_info = self.trials_table.iloc[self.trial_num]
        info_dict = {
            'Stim. Position': trial_info.position,
            'Stim. Contrast': trial_info.contrast,
            'Stim. Phase': f'{trial_info.stim_phase:.2f}',
            'Stim. p Left': trial_info.stim_probability_left,
            'Water delivered': f'{self.session_info.TOTAL_WATER_DELIVERED:.1f} µl',
            'Time from Start': self.time_elapsed,
            'Temperature': f'{self.ambient_sensor_table.loc[self.trial_num, "Temperature_C"]:.1f} °C',
            'Air Pressure': f'{self.ambient_sensor_table.loc[self.trial_num, "AirPressure_mb"]:.1f} mb',
            'Rel. Humidity': f'{self.ambient_sensor_table.loc[self.trial_num, "RelativeHumidity"]:.1f} %',
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # log info dict
        log.log(log_level, f'Outcome of Trial #{trial_info.trial_num}:')
        max_key_length = max(len(key) for key in info_dict)
        for key, value in info_dict.items():
            spaces = (max_key_length - len(key)) * ' '
            log.log(log_level, f'- {key}: {spaces}{str(value)}')

    @property
    def iti_reward(self):
        """
        Return the ITI time that needs to be set in order to achieve the desired ITI,
        by subtracting the time it takes to give a reward from the desired ITI.
        """
        return self.task_params.ITI_CORRECT - self.calibration.get('REWARD_VALVE_TIME', None)

    """
    These are the properties that are used in the state machine code.
    """

    @property
    def reward_time(self):
        return self.compute_reward_time(amount_ul=self.trials_table.at[self.trial_num, 'reward_amount'])

    @property
    def quiescent_period(self):
        return self.trials_table.at[self.trial_num, 'quiescent_period']

    @property
    def feedback_correct_delay(self):
        return self.task_params['FEEDBACK_CORRECT_DELAY_SECS']

    @property
    def feedback_error_delay(self):
        return self.task_params['FEEDBACK_ERROR_DELAY_SECS']

    @property
    def feedback_nogo_delay(self):
        return self.task_params['FEEDBACK_NOGO_DELAY_SECS']

    @property
    def position(self):
        return self.trials_table.at[self.trial_num, 'position']

    @property
    def event_error(self):
        return self.device_rotary_encoder.THRESHOLD_EVENTS[(-1 if self.task_params.STIM_REVERSE else 1) * self.position]

    @property
    def event_reward(self):
        return self.device_rotary_encoder.THRESHOLD_EVENTS[(1 if self.task_params.STIM_REVERSE else -1) * self.position]
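
    # Note on event_error / event_reward above: with the stimulus drawn at `position`,
    # the rewarded wheel movement is the one bringing the stimulus to the centre, i.e.
    # the rotary-encoder threshold at -position; the threshold at +position signals an
    # error. Setting STIM_REVERSE swaps the two event mappings.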


class HabituationChoiceWorldTrialData(ChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ChoiceWorldTrialData`."""

    delay_to_stim_center: NonNegativeFloat


class HabituationChoiceWorldSession(ChoiceWorldSession):
    protocol_name = '_iblrig_tasks_habituationChoiceWorld'
    TrialDataModel = HabituationChoiceWorldTrialData

    def next_trial(self):
        self.trial_num += 1
        self.draw_next_trial_info()

    def draw_next_trial_info(self, *args, **kwargs):
        # update trial table fields specific to habituation choice world
        self.trials_table.at[self.trial_num, 'delay_to_stim_center'] = np.random.normal(self.task_params.DELAY_TO_STIM_CENTER, 2)
        super().draw_next_trial_info(*args, **kwargs)

    def get_state_machine_trial(self, i):
        sma = StateMachine(self.bpod)

        if i == 0:  # First trial exception: start camera
            log.info('Waiting for camera pulses...')
            sma.add_state(
                state_name='iti',
                state_timer=3600,
                state_change_conditions={'Port1In': 'stim_on'},
                output_actions=[self.bpod.actions.bonsai_hide_stim, ('SoftCode', SOFTCODE.TRIGGER_CAMERA), ('BNC1', 255)],
            )  # start camera
        else:
            # NB: This state is actually the inter-trial interval, i.e. the period of grey screen between stim off and stim on.
            # During this period the Bpod TTL is HIGH and there are no stimuli. The onset of this state is trial end;
            # the offset of this state is trial start!
            sma.add_state(
                state_name='iti',
                state_timer=1,  # Stim off for 1 sec
                state_change_conditions={'Tup': 'stim_on'},
                output_actions=[self.bpod.actions.bonsai_hide_stim, ('BNC1', 255)],
            )
        # This stim_on state is considered the actual trial start
        sma.add_state(
            state_name='stim_on',
            state_timer=self.trials_table.at[self.trial_num, 'delay_to_stim_center'],
            state_change_conditions={'Tup': 'stim_center'},
            output_actions=[self.bpod.actions.bonsai_show_stim, self.bpod.actions.play_tone],
        )

        sma.add_state(
            state_name='stim_center',
            state_timer=0.5,
            state_change_conditions={'Tup': 'reward'},
            output_actions=[self.bpod.actions.bonsai_show_center],
        )

        sma.add_state(
            state_name='reward',
            state_timer=self.reward_time,  # the length of time to leave the reward valve open, i.e. reward size
            state_change_conditions={'Tup': 'post_reward'},
            output_actions=[('Valve1', 255), ('BNC1', 255)],
        )
        # This state defines the period after reward where the Bpod TTL is LOW.
        # NB: The stimulus is on throughout this period. The stim off trigger occurs upon exit.
        # The stimulus thus remains in the screen centre for 0.5 + ITI_DELAY_SECS seconds.
        sma.add_state(
            state_name='post_reward',
            state_timer=self.task_params.ITI_DELAY_SECS - self.reward_time,
            state_change_conditions={'Tup': 'exit'},
            output_actions=[],
        )
        return sma
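
    # Nominal habituation state flow assembled above (reference only):
    #   iti → stim_on → stim_center → reward → post_reward → exit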


class ActiveChoiceWorldTrialData(ChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ChoiceWorldTrialData`."""

    response_side: Annotated[int, Interval(ge=-1, le=1)]
    response_time: NonNegativeFloat
    trial_correct: bool


class ActiveChoiceWorldSession(ChoiceWorldSession):
    """
    The ActiveChoiceWorldSession is a base class for protocols where the mouse is actively making decisions
    by turning the wheel. It has the following characteristics:

    -   it is trial based
    -   it is decision based
    -   left and right stimuli are equiprobable: there is no biased block
    -   a trial can either be correct / error / no_go depending on the side of the stimulus and the response
    -   it has a quantifiable performance, computed as the proportion of correct trials, unlike passive
        stimulation or habituation protocols

    TrainingChoiceWorld and BiasedChoiceWorld are subclasses of this class.
    """

    TrialDataModel = ActiveChoiceWorldTrialData

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.trials_table['stim_probability_left'] = np.zeros(NTRIALS_INIT, dtype=np.float64)

    def _run(self):
        # starts online plotting
        if self.interactive:
            subprocess.Popen(
                ['view_session', str(self.paths['DATA_FILE_PATH']), str(self.paths['SETTINGS_FILE_PATH'])],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.STDOUT,
            )
        super()._run()

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        # construct info dict
        trial_info = self.trials_table.iloc[self.trial_num]
        info_dict = {
            'Response Time': f'{trial_info.response_time:.2f} s',
            'Trial Correct': trial_info.trial_correct,
            'N Trials Correct': self.session_info.NTRIALS_CORRECT,
            'N Trials Error': self.trial_num - self.session_info.NTRIALS_CORRECT + 1,
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # call parent method
        super().show_trial_log(extra_info=info_dict, log_level=log_level)

    def trial_completed(self, bpod_data: dict) -> None:
        """
        Update the trials table with information about the behaviour coming from the bpod.

        Constraints on the state machine data:

        - mandatory states: ['correct', 'error', 'no_go', 'reward']
        - optional states : ['omit_correct', 'omit_error', 'omit_no_go']

        Parameters
        ----------
        bpod_data : dict
            The Bpod data as returned by pybpod

        Raises
        ------
        AssertionError
            If the position is zero or if the number of detected outcomes is not exactly one.
        """
        # Get the response time from the behaviour data.
        # It is defined as the time elapsed between the start of `stim_on` and the end of `closed_loop`.
        state_times = bpod_data['States timestamps']
        response_time = state_times['closed_loop'][0][1] - state_times['stim_on'][0][0]
        self.trials_table.at[self.trial_num, 'response_time'] = response_time

        try:
            # Get the stimulus position
            position = self.trials_table.at[self.trial_num, 'position']
            assert position != 0, 'the stimulus position should not be 0'

            # Get the trial's outcome, i.e., the states that have a matching name and a valid time-stamp.
            # Assert that we have exactly one outcome.
            outcome_names = ['correct', 'error', 'no_go', 'omit_correct', 'omit_error', 'omit_no_go']
            outcomes = [name for name, times in state_times.items() if name in outcome_names and ~np.isnan(times[0][0])]
            if (n_outcomes := len(outcomes)) != 1:
                trial_states = 'Trial states: ' + ', '.join(k for k, v in state_times.items() if ~np.isnan(v[0][0]))
                assert n_outcomes != 0, f'No outcome detected for trial {self.trial_num}.\n{trial_states}'
                assert n_outcomes == 1, f'{n_outcomes} outcomes detected for trial {self.trial_num}.\n{trial_states}'
            outcome = outcomes[0]

        except AssertionError as e:
            # write bpod_data to disk, log the exception, then raise
            self.save_trial_data_to_json(bpod_data, validate=False)
            for line in re_split(r'\n', e.args[0]):
                log.error(line)
            raise e

        # record the trial's outcome in the trials_table
        self.trials_table.at[self.trial_num, 'trial_correct'] = 'correct' in outcome
        if 'correct' in outcome:
            self.session_info.NTRIALS_CORRECT += 1
            self.trials_table.at[self.trial_num, 'response_side'] = -np.sign(position)
        elif 'error' in outcome:
            self.trials_table.at[self.trial_num, 'response_side'] = np.sign(position)
        elif 'no_go' in outcome:
            self.trials_table.at[self.trial_num, 'response_side'] = 0

        super().trial_completed(bpod_data)
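
    # Sign convention for response_side above: a correct response moves the stimulus
    # to the centre, i.e. opposite to the stimulus position, hence -np.sign(position)
    # for correct trials, +np.sign(position) for errors, and 0 for no-go trials.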


class BiasedChoiceWorldTrialData(ActiveChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ActiveChoiceWorldTrialData`."""

    block_num: NonNegativeInt = 0
    block_trial_num: NonNegativeInt = 0


class BiasedChoiceWorldSession(ActiveChoiceWorldSession):
    """
    Biased choice world session is the instantiation of ActiveChoiceWorld where the notion of biased
    blocks is introduced.
    """

    base_parameters_file = Path(__file__).parent.joinpath('base_biased_choice_world_params.yaml')
    protocol_name = '_iblrig_tasks_biasedChoiceWorld'
    TrialDataModel = BiasedChoiceWorldTrialData

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.blocks_table = pd.DataFrame(
            {'probability_left': np.zeros(NBLOCKS_INIT) * np.nan, 'block_length': np.zeros(NBLOCKS_INIT, dtype=np.int16) * -1}
        )

    def new_block(self):
        """
        Initialize a new block.

        If BLOCK_INIT_5050 is set, the first block has a 50/50 probability of
        leftward stimulus and is 90 trials long.
        """
        self.block_num += 1  # the block number is zero based
        self.block_trial_num = 0

        # handles the block length logic
        if self.task_params.BLOCK_INIT_5050 and self.block_num == 0:
            block_len = 90
        else:
            block_len = int(
                misc.truncated_exponential(
                    scale=self.task_params.BLOCK_LEN_FACTOR,
                    min_value=self.task_params.BLOCK_LEN_MIN,
                    max_value=self.task_params.BLOCK_LEN_MAX,
                )
            )
        if self.block_num == 0:
            pleft = 0.5 if self.task_params.BLOCK_INIT_5050 else np.random.choice(self.task_params.BLOCK_PROBABILITY_SET)
        elif self.block_num == 1 and self.task_params.BLOCK_INIT_5050:
            pleft = np.random.choice(self.task_params.BLOCK_PROBABILITY_SET)
        else:
            # this switches the probability of leftward stim for the next block
            pleft = round(abs(1 - self.blocks_table.loc[self.block_num - 1, 'probability_left']), 1)
        self.blocks_table.at[self.block_num, 'block_length'] = block_len
        self.blocks_table.at[self.block_num, 'probability_left'] = pleft
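
    # Worked example of the block logic above, assuming BLOCK_PROBABILITY_SET = [0.2, 0.8]
    # (illustrative; the actual set comes from the task parameters): with BLOCK_INIT_5050,
    # a session could run p_left = 0.5 (90 trials), then 0.8, 0.2, 0.8, ... where each
    # subsequent block sets p_left = round(abs(1 - previous_p_left), 1).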

    def next_trial(self):
        self.trial_num += 1
        # if necessary update the block number
        self.block_trial_num += 1
        if self.block_num < 0 or self.block_trial_num > (self.blocks_table.loc[self.block_num, 'block_length'] - 1):
            self.new_block()
        # get and store probability left
        pleft = self.blocks_table.loc[self.block_num, 'probability_left']
        # update trial table fields specific to biased choice world task
        self.trials_table.at[self.trial_num, 'block_num'] = self.block_num
        self.trials_table.at[self.trial_num, 'block_trial_num'] = self.block_trial_num
        # save and send trial info to bonsai
        self.draw_next_trial_info(pleft=pleft)

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        # construct info dict
        trial_info = self.trials_table.iloc[self.trial_num]
        info_dict = {
            'Block Number': trial_info.block_num,
            'Block Length': self.blocks_table.loc[self.block_num, 'block_length'],
            'N Trials in Block': trial_info.block_trial_num,
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # call parent method
        super().show_trial_log(extra_info=info_dict, log_level=log_level)


class TrainingChoiceWorldTrialData(ActiveChoiceWorldTrialData):
    """Pydantic Model for Trial Data, extended from :class:`~.iblrig.base_choice_world.ActiveChoiceWorldTrialData`."""

    training_phase: NonNegativeInt
    debias_trial: bool


class TrainingChoiceWorldSession(ActiveChoiceWorldSession):
    """
    The TrainingChoiceWorldSession corresponds to the first training protocol of the choice world task.
    This protocol has a complicated adaptation of the number of contrasts (embodied by the training_phase
    property) and of the reward amount (embodied by the adaptive_reward property).
    """

    protocol_name = '_iblrig_tasks_trainingChoiceWorld'
    TrialDataModel = TrainingChoiceWorldTrialData

    def __init__(self, training_phase=-1, adaptive_reward=-1.0, adaptive_gain=None, **kwargs):
        super().__init__(**kwargs)
        inferred_training_phase, inferred_adaptive_reward, inferred_adaptive_gain = self.get_subject_training_info()
        if training_phase == -1:
            log.critical(f'Got training phase: {inferred_training_phase}')
            self.training_phase = inferred_training_phase
        else:
            log.critical(f'Training phase manually set to: {training_phase}')
            self.training_phase = training_phase
        if adaptive_reward == -1:
            log.critical(f'Got Adaptive reward {inferred_adaptive_reward} uL')
            self.session_info['ADAPTIVE_REWARD_AMOUNT_UL'] = inferred_adaptive_reward
        else:
            log.critical(f'Adaptive reward manually set to {adaptive_reward} uL')
            self.session_info['ADAPTIVE_REWARD_AMOUNT_UL'] = adaptive_reward
        if adaptive_gain is None:
            log.critical(f'Got Adaptive gain {inferred_adaptive_gain} degrees/mm')
            self.session_info['ADAPTIVE_GAIN_VALUE'] = inferred_adaptive_gain
        else:
            log.critical(f'Adaptive gain manually set to {adaptive_gain} degrees/mm')
            self.session_info['ADAPTIVE_GAIN_VALUE'] = adaptive_gain
        self.var = {'training_phase_trial_counts': np.zeros(6), 'last_10_responses_sides': np.zeros(10)}

    @property
    def default_reward_amount(self):
        return self.session_info.get('ADAPTIVE_REWARD_AMOUNT_UL', self.task_params.REWARD_AMOUNT_UL)

    @property
    def stimulus_gain(self) -> float:
        return self.session_info.get('ADAPTIVE_GAIN_VALUE')

    def get_subject_training_info(self):
        """
        Get the previous session's training information according to this session's parameters
        and deduce the training level, adaptive reward amount and adaptive gain value.

        Returns
        -------
        tuple
            (training_phase, adaptive_reward, adaptive_gain)
        """
        training_info, _ = choiceworld.get_subject_training_info(
            subject_name=self.session_info.SUBJECT_NAME,
            task_name=self.protocol_name,
            stim_gain=self.task_params.AG_INIT_VALUE,
            stim_gain_on_error=self.task_params.STIM_GAIN,
            default_reward=self.task_params.REWARD_AMOUNT_UL,
            local_path=self.iblrig_settings['iblrig_local_data_path'],
            remote_path=self.iblrig_settings['iblrig_remote_data_path'],
            lab=self.iblrig_settings['ALYX_LAB'],
            iblrig_settings=self.iblrig_settings,
        )
        return training_info['training_phase'], training_info['adaptive_reward'], training_info['adaptive_gain']

    def check_training_phase(self) -> bool:
        """Check if the mouse is ready to move to the next training phase."""
        move_on = False
        if self.training_phase == 0:  # each of the -1, -.5, .5, 1 contrasts should be above 80% perf to switch
            performance = choiceworld.compute_performance(self.trials_table)
            passing = performance[np.abs(performance.index) >= 0.5]['last_50_perf']
            if np.all(passing > 0.8) and passing.size == 4:
                move_on = True
        elif self.training_phase == 1:  # each of the -.25, .25 contrasts should be above 80% perf to switch
            performance = choiceworld.compute_performance(self.trials_table)
            passing = performance[np.abs(performance.index) == 0.25]['last_50_perf']
            if np.all(passing > 0.8) and passing.size == 2:
                move_on = True
        elif 5 > self.training_phase >= 2:  # for the next phases, always switch after 200 trials
            if self.var['training_phase_trial_counts'][self.training_phase] >= 200:
                move_on = True
        if move_on:
            self.training_phase = np.minimum(5, self.training_phase + 1)
            log.warning(f'Moving on to training phase {self.training_phase}, {self.trial_num}')
        return move_on

    def next_trial(self):
        # update counters
        self.trial_num += 1
        self.var['training_phase_trial_counts'][self.training_phase] += 1

        # check if the subject graduates to a new training phase
        self.check_training_phase()

        # draw the next trial
        signed_contrast = choiceworld.draw_training_contrast(self.training_phase)
        position = self.task_params.STIM_POSITIONS[int(np.sign(signed_contrast) == 1)]
        contrast = np.abs(signed_contrast)

        # debiasing: if the previous trial was incorrect, not a no-go, and easy
        if self.task_params.DEBIAS and self.trial_num >= 1 and self.training_phase < 5:
            last_contrast = self.trials_table.loc[self.trial_num - 1, 'contrast']
            do_debias_trial = (
                (self.trials_table.loc[self.trial_num - 1, 'trial_correct'] != 1)
                and (self.trials_table.loc[self.trial_num - 1, 'response_side'] != 0)
                and last_contrast >= 0.5
            )
            self.trials_table.at[self.trial_num, 'debias_trial'] = do_debias_trial
            if do_debias_trial:
                # indices of trials that had a response
                iresponse = np.logical_and(self.trials_table['response_side'].notna(), self.trials_table['response_side'] != 0)
                iresponse = iresponse.index[iresponse]

                # take the average of right responses over the last 10 response trials
                average_right = (self.trials_table['response_side'][iresponse[-np.minimum(10, iresponse.size) :]] == 1).mean()

                # the probability of the next stimulus being on the left is a draw from a normal distribution centered
                # on the average right with sigma 0.5 - if it is less than 0.5 the next stimulus will be on the left
                position = self.task_params.STIM_POSITIONS[int(np.random.normal(average_right, 0.5) >= 0.5)]

                # contrast is the last contrast
                contrast = last_contrast
        else:
            self.trials_table.at[self.trial_num, 'debias_trial'] = False

        # save and send trial info to bonsai
        self.draw_next_trial_info(pleft=self.task_params.PROBABILITY_LEFT, position=position, contrast=contrast)
        self.trials_table.at[self.trial_num, 'training_phase'] = self.training_phase
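
    # Worked example of the debiasing draw above (illustrative numbers): if 8 of the last
    # 10 responded trials had response_side == 1, average_right = 0.8 and the draw from
    # N(0.8, 0.5) falls at >= 0.5 about 73% of the time, so the repeated high-contrast
    # stimulus mostly lands at STIM_POSITIONS[1].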

    def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO):
        # construct info dict
        info_dict = {
            'Contrast Set': np.unique(np.abs(choiceworld.contrasts_set(self.training_phase))),
            'Training Phase': self.training_phase,
            'Debias Trial': self.trials_table.at[self.trial_num, 'debias_trial'],
        }

        # update info dict with extra_info dict
        if isinstance(extra_info, dict):
            info_dict.update(extra_info)

        # call parent method
        super().show_trial_log(extra_info=info_dict, log_level=log_level)