• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

int-brain-lab / iblrig / 15073834064

16 May 2025 05:16PM UTC coverage: 49.414% (+2.6%) from 46.79%
15073834064

Pull #750

github

c98309
web-flow
Merge 8e475a77c into e481532ae
Pull Request #750: Online plots

538 of 720 new or added lines in 3 files covered. (74.72%)

1000 existing lines in 20 files now uncovered.

4677 of 9465 relevant lines covered (49.41%)

0.49 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

74.63
/iblrig/gui/online_plots.py
1
import ctypes
1✔
2
import datetime
1✔
3
import json
1✔
4
import os
1✔
5
import sys
1✔
6
import time
1✔
7
from collections.abc import Iterable
1✔
8
from copy import copy
1✔
9
from dataclasses import dataclass
1✔
10
from pathlib import Path
1✔
11
from typing import Annotated, Any
1✔
12

13
import numpy as np
1✔
14
import pandas as pd
1✔
15
import pyqtgraph as pg
1✔
16
from pydantic import UUID4, AfterValidator, DirectoryPath, Field, FilePath, PlainSerializer, validate_call
1✔
17
from pydantic_settings import BaseSettings, CliPositionalArg
1✔
18
from qtpy.QtCore import (
1✔
19
    QCoreApplication,
20
    QFileSystemWatcher,
21
    QItemSelection,
22
    QModelIndex,
23
    QObject,
24
    QPoint,
25
    QRect,
26
    QRectF,
27
    QSettings,
28
    QSize,
29
    Qt,
30
    QThreadPool,
31
    Signal,
32
    Slot,
33
)
34
from qtpy.QtGui import QBrush, QColor, QFont, QGradient, QIcon, QLinearGradient, QPainter, QPixmap, QTransform
1✔
35
from qtpy.QtWidgets import (
1✔
36
    QApplication,
37
    QFileDialog,
38
    QFrame,
39
    QGraphicsRectItem,
40
    QGraphicsSceneHoverEvent,
41
    QGridLayout,
42
    QHeaderView,
43
    QLabel,
44
    QMainWindow,
45
    QSizePolicy,
46
    QStyledItemDelegate,
47
    QTableView,
48
    QVBoxLayout,
49
    QWidget,
50
)
51

52
from iblqt.core import DataFrameTableModel
1✔
53
from iblrig import __version__ as iblrig_version
1✔
54
from iblrig.choiceworld import get_subject_training_info
1✔
55
from iblrig.gui import resources_rc  # noqa: F401
1✔
56
from iblrig.gui.tools import Worker
1✔
57
from iblrig.misc import online_std
1✔
58
from iblrig.path_helper import get_local_and_remote_paths
1✔
59
from iblrig.raw_data_loaders import bpod_trial_data_to_dataframe, load_task_jsonable
1✔
60
from one.alf.spec import is_session_path
1✔
61
from one.api import ONE
1✔
62

63

64
def is_alf_path(value: Path) -> Path:
    """
    Validate that a path is a session path.

    Parameters
    ----------
    value : Path
        The path to validate.

    Returns
    -------
    Path
        The unchanged input path.

    Raises
    ------
    ValueError
        If `is_session_path` does not accept `value`.
    """
    if is_session_path(value):
        return value
    raise ValueError('Field is not a session path')
68

69

70
# A pydantic-compatible path type: validated with `is_alf_path` and serialized as a plain string.
SessionPath = Annotated[
    Path,
    AfterValidator(is_alf_path),
    PlainSerializer(lambda path: str(path), return_type=str),
]
75

76

77
@dataclass
class Colors:
    """Named color values (hex codes plus the keyword 'transparent') used by the GUI."""

    # signal colors
    RED = '#eb5757'
    GREEN = '#57eb8b'
    YELLOW = '#ede34e'

    # no color
    TRANSPARENT = 'transparent'
83

84

85
@dataclass
class EngagedCriterion:
    """Thresholds of the engagement criterion: a duration in seconds and a number of trials."""

    SECONDS = 60 * 45  # 45 minutes
    TRIAL_COUNT = 400
89

90

91
@dataclass
class DefaultSettings:
    """Default stimulus contrasts and block probabilities."""

    # zero contrast plus the powers of two from 1/16 up to 1 (all exactly representable as floats)
    CONTRAST_SET = np.array([0.0, 0.0625, 0.125, 0.25, 0.5, 1.0])
    PROBABILITY_SET = np.array([0.2, 0.5, 0.8])
95

96

97
class PlotWidget(pg.PlotWidget):
    """
    PyQtGraph PlotWidget with tuned default settings.

    Compared to the stock widget, this one has a white background with a light gray view box, no mouse
    interaction, no context menu, no auto-range buttons, and black tick labels on the left and bottom axes.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        plot_item = self.plotItem

        # colors
        self.setBackground('white')
        plot_item.getViewBox().setBackgroundColor(pg.mkColor(250, 250, 250))

        # static plot: no panning / zooming, no menu, no buttons
        plot_item.setMouseEnabled(x=False, y=False)
        plot_item.setMenuEnabled(False)
        plot_item.hideButtons()

        # black tick labels
        for side in ('left', 'bottom'):
            plot_item.getAxis(side).setTextPen('k')
109

110

111
class SingleBarChartWidget(PlotWidget):
    """
    A bar chart with a single column for use with PyQtGraph.

    Parameters
    ----------
    barColor : Any
        Color of the bar, in any form accepted by :func:`pyqtgraph.mkColor`. Default: 0.4.
    textColor : Any
        Color of the value label. Default: 1.0.
    textFormat : str
        Format string used to render the value label. Default: '{:g}'.
    """

    def __init__(self, *args, barColor: Any = 0.4, textColor: Any = 1.0, textFormat: str = '{:g}', **kwargs):
        super().__init__(*args, **kwargs)
        plot_item = self.plotItem

        # y-axis: fixed width, with grid lines
        left_axis = plot_item.getAxis('left')
        left_axis.setWidth(40)
        left_axis.setGrid(128)

        # x-axis: blank label and a single blank, invisible tick at the bar's center (x = 1)
        bottom_axis = plot_item.getAxis('bottom')
        bottom_axis.setLabel(' ')
        bottom_axis.setTicks([[(1, ' ')], []])
        bottom_axis.setStyle(tickLength=0, tickAlpha=0)
        plot_item.setXRange(min=0, max=2, padding=0)

        # the bar: full width, filled with a gradient from transparent (position 0) to barColor (position 0.9)
        fill = QLinearGradient(0, 0, 0, 1)
        fill.setCoordinateMode(QGradient.ObjectBoundingMode)
        fill.setColorAt(0.9, pg.mkColor(barColor))
        fill.setColorAt(0, pg.mkColor((255, 255, 255, 0)))
        self._barGraphItem = pg.BarGraphItem(x=1, width=2, height=0, pen=None, brush=QBrush(fill))
        self.addItem(self._barGraphItem)

        # the value label, anchored at its top center
        self._textFormat = textFormat
        self._textItem = pg.TextItem('0', anchor=(0.5, 0), color=textColor)
        self._textItem.setX(1)
        self._textItem.setY(50)
        self.addItem(self._textItem)

    @Slot(float)
    def setValue(self, value: float):
        """Set the bar height to `value` and move the formatted value label to that height."""
        bar, label = self._barGraphItem, self._textItem
        bar.setOpts(height=value)
        label.setText(self._textFormat.format(value))
        label.setY(value)
145

146

147
class FunctionWidget(PlotWidget):
    """
    A widget for psychometric and chronometric functions.

    For each probability in `probabilities`, the widget creates:

    - a line plot with circular symbols (``plotDataItems``), and
    - a shaded band between two invisible curves (``upperCurves``, ``lowerCurves``, ``fillItems``).

    All four dictionaries are keyed by probability so the curves can be updated from outside.

    Parameters
    ----------
    *args
        Positional arguments passed to :class:`PlotWidget`.
    colors : pg.ColorMap
        Color map; the n-th probability uses ``colors.getByIndex(n)``.
    probabilities : Iterable[float]
        Probabilities for which curves and legend entries (``p = 0.x``) are created.
    **kwargs
        Keyword arguments passed to :class:`PlotWidget`.
    """

    def __init__(self, *args, colors: pg.ColorMap, probabilities: Iterable[float], **kwargs):
        super().__init__(*args, **kwargs)

        # vertical reference line at x = 0 (zero signed contrast)
        self.plotItem.addItem(pg.InfiniteLine(0, 90, 'black'))

        # grid lines on both axes (the black text pen is already applied by PlotWidget)
        for axis in ('left', 'bottom'):
            self.plotItem.getAxis(axis).setGrid(128)
        self.plotItem.getAxis('bottom').setLabel('Signed Contrast')

        legend = pg.LegendItem(pen='lightgray', brush='w', offset=(45, 35), verSpacing=-5, labelTextColor='k')
        legend.setParentItem(self.plotItem.graphicsItem())
        legend.setZValue(1)

        self.plotDataItems = dict()
        self.upperCurves = dict()
        self.lowerCurves = dict()
        self.fillItems = dict()
        null_pen = pg.mkPen((0, 0, 0, 0))  # fully transparent pen

        for idx, p in enumerate(probabilities):
            line_color = colors.getByIndex(idx)
            fill_color = copy(line_color)
            fill_color.setAlpha(32)

            # shaded band between two invisible boundary curves
            self.upperCurves[p] = self.plotItem.plot(pen=null_pen)
            self.lowerCurves[p] = self.plotItem.plot(pen=null_pen)
            self.fillItems[p] = pg.FillBetweenItem(self.upperCurves[p], self.lowerCurves[p], brush=fill_color, pen=null_pen)
            self.addItem(self.fillItems[p])

            # The function itself. It starts with placeholder points that have no finite coordinates and thus render
            # nothing. Use `np.nan`: the `np.NAN` alias was removed in NumPy 2.0.
            data_item = self.plotItem.plot(connect='all')
            data_item.setData(x=[1, np.nan], y=[np.nan, 1])
            data_item.setPen(pg.mkPen(color=line_color, width=2))
            data_item.setSymbol('o')
            data_item.setSymbolPen(line_color)
            data_item.setSymbolBrush(line_color)
            data_item.setSymbolSize(5)
            self.plotDataItems[p] = data_item
            legend.addItem(data_item, f'p = {p:0.1f}')
181

182

183
class TrialsTableModel(DataFrameTableModel):
    """
    A table model that displays status tips for entries in the trials table.

    Columns are read by position:

    - 0: trial number
    - 1: stimulus position
    - 2: contrast (a fraction, shown as percent)
    - 3: debias flag
    - 4: outcome
    - 5: response time

    Column 0 is right-aligned.
    """

    def data(self, index: QModelIndex, role: int = Qt.ItemDataRole.DisplayRole) -> Any | None:
        if index.isValid():
            if role == Qt.ItemDataRole.StatusTipRole:
                return self._statusTip(index)
            if role == Qt.TextAlignmentRole and index.column() == 0:
                return Qt.AlignRight | Qt.AlignVCenter
        return super().data(index, role)

    @staticmethod
    def _statusTip(index: QModelIndex) -> str:
        """Compose the human-readable summary of the trial in the row of `index`."""
        trial, position, contrast, debias, outcome, timing = (index.siblingAtColumn(col).data() for col in range(6))
        side = 'right' if position > 0 else 'left'
        debias_text = '/ debiasing ' if debias else ''
        tip = f'Trial {trial}: {contrast * 100:g}% contrast / {abs(position):g}° {side} {debias_text}/ {outcome}'
        if outcome == 'no-go':
            return tip + '.'
        return tip + f' after {timing:0.2f} s.'
202

203

204
class TrialsTableView(QTableView):
    """
    A table view that shows a logarithmic x-grid in one column.

    Vertical grid lines are drawn behind column `grid_col` on a log10 scale that spans `norm_min` to `norm_max`:

    - minor lines at 0.2 … 0.9, 2 … 9 and 20 … 90;
    - major lines at 0.1, 1, 10 and 100.
    """

    norm_min = 0.1
    norm_max = 102.0
    norm_div = np.log10(norm_max / norm_min)
    x_minor = [i / j for j in (10, 1, 0.1) for i in range(2, 10)]
    x_major = np.power(10.0, np.arange(-1, 3))
    color_minor = QColor(238, 238, 238)
    color_major = QColor(199, 199, 199)
    grid_col = 5

    def __init__(self, parent: QObject):
        super().__init__(parent)
        self.setMouseTracking(True)

        # hidden headers; fixed column widths with the last column stretching to fill the view
        self.verticalHeader().hide()
        self.horizontalHeader().hide()
        self.horizontalHeader().setDefaultAlignment(Qt.AlignLeft)
        self.horizontalHeader().setSectionResizeMode(QHeaderView.Fixed)
        self.horizontalHeader().setStretchLastSection(True)
        self.setStyleSheet(
            'QHeaderView::section { border: none; background-color: white; }'
            'QTableView::item:selected { color: black; selection-background-color: rgba(0, 0, 0, 6%); }'
            'QTableView { background-color: rgba(0, 0, 0, 3%); }'
        )
        self.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)
        self.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff)

        # custom rendering of the stimulus (column 1) and the response time (column 5)
        self.stimulusDelegate = StimulusDelegate()
        self.responseTimeDelegate = ResponseTimeDelegate()
        self.setItemDelegateForColumn(1, self.stimulusDelegate)
        self.setItemDelegateForColumn(5, self.responseTimeDelegate)

        # no grid, frame or focus; selection of single, whole rows
        self.setShowGrid(False)
        self.setFrameShape(QTableView.NoFrame)
        self.setFocusPolicy(Qt.FocusPolicy.NoFocus)
        self.setSelectionMode(QTableView.SingleSelection)
        self.setSelectionBehavior(QTableView.SelectRows)

    def paintEvent(self, event):
        """Draw the logarithmic grid behind column `grid_col`, then let QTableView paint the cells on top."""
        viewport = self.viewport()
        viewport_pos = self.columnViewportPosition(self.grid_col)
        col_width = self.columnWidth(self.grid_col)
        height = viewport.height()

        painter = QPainter(viewport)
        try:
            # minor lines first so that major lines are drawn on top of them
            for color, ticks in ((self.color_minor, self.x_minor), (self.color_major, self.x_major)):
                painter.setPen(color)
                for x in ticks:
                    x_val = np.log10(x / self.norm_min) / self.norm_div
                    # explicit int(): drawLine(int, int, int, int) does not accept NumPy floats
                    line_x = viewport_pos + int(round(col_width * x_val))
                    painter.drawLine(line_x, 0, line_x, height)
        finally:
            # QTableView.paintEvent opens its own QPainter on the viewport; a paint device should only have one
            # active painter at a time, so finish ours first.
            painter.end()
        super().paintEvent(event)
257

258

259
class TrialsWidget(QWidget):
    """
    A titled table showing the trials history.

    Emits `trialSelected` with the row index of a newly selected trial.

    Parameters
    ----------
    parent : QObject
        The parent widget.
    model : TrialsTableModel
        The model holding the trials data. Columns 2-4 are hidden in the table view.
    """

    trialSelected = Signal(int)

    def __init__(self, parent: QObject, model: TrialsTableModel):
        super().__init__(parent)
        self.model = model

        layout = QVBoxLayout(self)
        layout.setSpacing(5)
        layout.setContentsMargins(0, 7, 0, 36)
        self.setLayout(layout)

        # title
        self.titleLabel = QLabel('Trials History')
        self.titleLabel.setAlignment(Qt.AlignHCenter)
        font = self.titleLabel.font()
        font.setPointSize(11)
        self.titleLabel.setFont(font)
        layout.addWidget(self.titleLabel)

        # table
        self.table_view = TrialsTableView(self)
        self.table_view.setModel(self.model)
        self.table_view.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents)
        self.table_view.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents)
        self.table_view.setColumnHidden(2, True)
        self.table_view.setColumnHidden(3, True)
        self.table_view.setColumnHidden(4, True)
        self.table_view.selectionModel().selectionChanged.connect(self._onSelectionChange)
        layout.addWidget(self.table_view)
        layout.setStretch(1, 1)  # the table takes up all remaining vertical space

    @Slot(QItemSelection, QItemSelection)
    def _onSelectionChange(self, selected: QItemSelection, _deselected: QItemSelection):
        """Emit `trialSelected` with the row of the newly selected index, if any."""
        indexes = selected.indexes()
        if not indexes:
            # The selection was cleared; there is no selected trial to report.
            return
        self.trialSelected.emit(indexes[0].row())
292

293

294
class StimulusDelegate(QStyledItemDelegate):
    """
    Item delegate that renders the stimulus of a trial as a filled circle.

    - Position: the circle sits at the left edge of the cell for negative positions (column 1) and at the right
      edge otherwise.
    - Fill: its lightness is ``1 - contrast`` (column 2), so higher contrasts are drawn darker.
    - Label: debiasing trials (column 3) are labeled 'DB', in white text for contrasts above 0.5 and black
      otherwise.
    """

    pen = QColor(0, 0, 0, 128)  # semi-transparent black outline

    def paint(self, painter, option, index: QModelIndex):
        super().paint(painter, option, index)
        location, contrast, debias = (index.siblingAtColumn(column).data() for column in (1, 2, 3))

        fill = QColor()
        fill.setHslF(0, 0, 1.0 - contrast)

        # circle geometry: 80 % of the row height, vertically centered, next to the left or right cell border
        cell_height = option.rect.height()
        size = round(cell_height * 0.8)
        margin = (cell_height - size) // 2
        if location < 0:
            left = option.rect.left() + margin
        else:
            left = option.rect.right() - size - margin
        top = option.rect.top() + margin

        painter.save()
        painter.setRenderHint(QPainter.Antialiasing)
        painter.setBrush(fill)
        painter.setPen(self.pen)
        painter.drawEllipse(left, top, size, size)
        if debias:
            text_color = QColor('white') if contrast > 0.5 else QColor('black')
            painter.setPen(text_color)
            painter.setFont(QFont(painter.font().family(), 9, -1, False))
            painter.drawText(QRect(left, top, size, size), Qt.AlignHCenter | Qt.AlignVCenter, 'DB')
        painter.restore()

    def displayText(self, value, locale):
        # The cell content is drawn by paint(); suppress the default text.
        return ''
327

328

329
class ResponseTimeDelegate(QStyledItemDelegate):
    """
    Item delegate that draws a response time as a horizontal bar on a logarithmic scale.

    - Width: ``log(value / norm_min) / log(norm_max / norm_min)`` times the cell width, clamped to [0, 1] so the
      bar never has a negative width and never overflows the cell.
    - Fill: a gradient from transparent to `color_correct` (outcome 'correct') or `color_error` (any other
      outcome).
    - Label: the value, with two decimals.

    The outcome is read from column 4 of the same row. Cells of 'no-go' trials get no bar.
    """

    norm_min = 0.1
    norm_max = 102.0
    norm_div = np.log(norm_max / norm_min)
    color_correct = QColor(0, 107, 90)
    color_error = QColor(219, 67, 37)
    color_nogo = QColor(192, 192, 192)
    color_text = QColor('white')
    color_gradient0 = QColor(255, 255, 255, 0)

    def paint(self, painter, option, index):
        super().paint(painter, option, index)

        # response time of this cell and outcome of the same trial (column 4)
        value = index.data()
        outcome = index.sibling(index.row(), 4).data()

        painter.fillRect(option.rect, option.backgroundBrush)
        if outcome == 'no-go':
            return

        # bar with a horizontal gradient; its width is proportional to the log-scaled value
        filled_rect = QRectF(option.rect)
        filled_rect.setWidth(filled_rect.width() * self._normalize(value))
        gradient = QLinearGradient(filled_rect.topLeft(), filled_rect.topRight())
        gradient.setColorAt(0, self.color_gradient0)
        gradient.setColorAt(1, self.color_correct if outcome == 'correct' else self.color_error)

        # save/restore so that the pen and brush set here do not leak into the painting of other cells
        painter.save()
        try:
            painter.setBrush(gradient)
            painter.setPen(Qt.NoPen)
            painter.drawRect(filled_rect)

            # value label, right-aligned within the bar with a 5 px margin
            painter.setPen(self.color_text)
            filled_rect.adjust(0, 0, -5, 0)
            painter.drawText(filled_rect, Qt.AlignVCenter | Qt.AlignRight, f'{value:.2f}')
        finally:
            painter.restore()

    @classmethod
    def _normalize(cls, value: float) -> float:
        """
        Map `value` onto [0, 1] using the logarithmic scale between `norm_min` and `norm_max`.

        Values at or below zero, as well as NaN, map to 0. Values beyond `norm_max` map to 1.
        """
        ratio = value / cls.norm_min
        if not ratio > 0:  # also catches NaN
            return 0.0
        return float(min(max(np.log(ratio) / cls.norm_div, 0.0), 1.0))

    def displayText(self, value, locale):
        # The cell content is drawn by paint(); suppress the default text.
        return ''
367

368

369
class StateMeshItem(pg.PColorMeshItem):
    """
    A graphical item for displaying a color mesh that represents Bpod states.

    This class extends the PyQtGraph's `PColorMeshItem` to provide
    functionality for emitting signals when the mouse hovers over
    different states in the mesh.

    Attributes
    ----------
    stateIndex : Signal
        A signal that emits the index of the state currently hovered over.
    """

    stateIndex = Signal(int)

    def hoverEvent(self, ev: QGraphicsSceneHoverEvent):
        """
        Handle hover events over the mesh item.

        This method emits the index of the state that the mouse is currently hovering over.

        - On exit, it emits -1, unless the item at the event's scene position is a ``QGraphicsRectItem``.
        - Nothing is emitted while no mesh data has been set, or if the position lies outside the mesh.

        Parameters
        ----------
        ev : QGraphicsSceneHoverEvent
            The event object containing information about the hover event.
        """
        if ev.exit:
            # If the mouse exits the item, emit -1 to indicate no state is hovered
            if not hasattr(ev, '_scenePos'):
                self.stateIndex.emit(-1)
            else:
                item = self.scene().itemAt(ev.scenePos(), QTransform())
                if not isinstance(item, QGraphicsRectItem):
                    self.stateIndex.emit(-1)
            return

        # Before setData() has been called there is no mesh to look up; indexing would raise a TypeError.
        if not isinstance(self.x, np.ndarray) or not isinstance(self.z, np.ndarray):
            return

        try:
            # Get the x-coordinate of the mouse position relative to the item
            x = self.mapFromParent(ev.pos()).x()
        except AttributeError:
            return
        try:
            # State of the last mesh column whose left edge lies at or before x
            i = self.z[:, np.where(self.x[0, :] <= x)[0][-1]][0]
        except IndexError:
            return  # x lies outside the mesh

        # Cast to a Python int: mesh values are NumPy integers, whereas the signal is declared as Signal(int).
        self.stateIndex.emit(int(i))
424

425

426
class BpodWidget(pg.GraphicsLayoutWidget):
    """
    A widget for visualizing Bpod data in a graphical layout.

    This widget displays digital channels and Bpod states over time,
    allowing for the visualization of trial data.
    """

    # These containers are created per instance in __init__. As class attributes with mutable defaults they would
    # be shared between all BpodWidget instances, mixing up their channels.
    data: pd.DataFrame
    labels: dict[str, pg.LabelItem]
    plots: dict[str, pg.PlotDataItem]
    meshes: dict[str, StateMeshItem]
    viewBoxes: dict[str, pg.ViewBox]

    def __init__(
        self,
        *args,
        title: str | None = None,
        alpha: int = 64,
        channels: Iterable | None = None,
        showStatusTips: bool = True,
        **kwargs,
    ):
        """
        Initialize the BpodWidget.

        Parameters
        ----------
        *args : tuple
            Positional arguments to be passed to the parent class.
        title : str | None, optional
            The title of the widget (default is None).
        alpha : int, optional
            The alpha value used in color-coding the Bpod states. Default: 64.
        channels : Iterable, optional
            An iterable of channel names to be included in the plot.
            Defaults are `BNC1`, `BNC2`, and `Port1`.
        showStatusTips : bool, optional
            Show status tips when hovering the mouse over state regions. Default: True
        **kwargs : dict
            Keyword arguments to be passed to the parent class.
        """
        super().__init__(*args, **kwargs)

        # per-instance containers (these must exist before channels are added below)
        self.data = pd.DataFrame()
        self.labels = dict()
        self.plots = dict()
        self.meshes = dict()
        self.viewBoxes = dict()
        self._showStatusTips = showStatusTips

        # set rendering hint and layout options
        self.setRenderHints(QPainter.Antialiasing)
        self.setBackground('white')
        self.centralWidget.setSpacing(0)
        self.centralWidget.setContentsMargins(0, 0, 0, 0)

        # define colormap for Bpod states (all colors share the same alpha value)
        colormap = pg.colormap.get('glasbey_light', source='colorcet')
        colors = colormap.getLookupTable(0, 1, 256, alpha=True)
        colors[:, 3] = alpha
        self.colormap = pg.ColorMap(colormap.pos, colors)

        # add title
        if title is not None:
            self.centralWidget.nextRow()
            self.addLabel(title, size='11pt', col=1, color='k')

        # add digital channels
        for channel in channels or ('BNC1', 'BNC2', 'Port1'):
            self.addDigitalChannel(channel)

        # add x axis, linked to the view box of the first channel
        self.centralWidget.nextRow()
        time_axis = pg.AxisItem(
            orientation='bottom', textPen='k', linkView=list(self.viewBoxes.values())[0], parent=self.centralWidget
        )
        time_axis.setLabel(text='Time', units='s')
        time_axis.enableAutoSIPrefix(True)
        self.centralWidget.addItem(time_axis, col=1)

    def addDigitalChannel(self, channel: str, label: str | None = None):
        """
        Add a digital channel to the widget.

        Each channel gets its own row consisting of a label, a view box, a state mesh, and a step plot.
        The view boxes keep their x-ranges in sync via :meth:`updateXRange`.

        Parameters
        ----------
        channel : str
            The name of the digital channel to add.
        label : str | None, optional
            The label for the channel (default is None, which uses the channel name).
        """
        label = channel if label is None else label
        self.centralWidget.nextRow()
        self.labels[channel] = self.addLabel(label, col=0, color='k')
        self.meshes[channel] = StateMeshItem(colorMap=self.colormap)
        self.meshes[channel].stateIndex.connect(self.showStateInfo)
        self.plots[channel] = pg.PlotDataItem(pen='k', stepMode='right')
        self.plots[channel].setSkipFiniteCheck(True)
        self.viewBoxes[channel] = self.addViewBox(col=1)
        self.viewBoxes[channel].addItem(self.meshes[channel])
        self.viewBoxes[channel].addItem(self.plots[channel])
        self.viewBoxes[channel].setMouseEnabled(x=True, y=False)
        self.viewBoxes[channel].setMenuEnabled(False)
        self.viewBoxes[channel].sigXRangeChanged.connect(self.updateXRange)

    def setData(self, data: pd.DataFrame):
        """
        Set the data for the widget and update the display.

        Parameters
        ----------
        data : pd.DataFrame
            The data to be displayed in the widget. The data needs to be organized according to the format returned by
            :py:func:`~iblrig.raw_data_loaders.bpod_trial_data_to_dataframe`.
        """
        self.data = data
        self.showTrial()

    @Slot(int)
    def showStateInfo(self, index: int):
        """
        Show information about the state in the status bar.

        Does nothing if status tips are disabled (`showStatusTips=False`) or if the top-level window has no
        ``statusBar()`` method.

        Parameters
        ----------
        index : int
            The index of the state to display. A negative value clears the status bar message.
        """
        if not self._showStatusTips:
            return
        status_bar = getattr(self.window(), 'statusBar', None)
        if status_bar is None:
            return
        if index < 0:
            status_bar().clearMessage()
        else:
            status_bar().showMessage(f'State: {self.data.State.cat.categories[index]}')

    def showTrial(self):
        """
        Display the trial data in the widget.

        This method updates the limits and plots for each digital channel.

        All x coordinates are in seconds relative to the first of the 'TrialStart' / 'TrialEnd' events. This
        assumes that 'TrialStart' precedes 'TrialEnd' in the data.
        """
        limits = self.data[self.data['Type'].isin(['TrialStart', 'TrialEnd'])]
        limits = limits.index.total_seconds()
        t_start = limits[0]
        duration = limits[1] - t_start
        self.limits = {'xMin': 0, 'xMax': duration, 'minXRange': 0.001, 'yMin': -0.2, 'yMax': 1.2}

        # state mesh: one column per state, spanning the full y-range
        state_t0 = self.data[self.data.Type == 'StateStart']
        state_t1 = self.data[self.data.Type == 'StateEnd']
        mesh_x = np.append(state_t0.index.total_seconds(), state_t1.index[-1].total_seconds()) - t_start
        mesh_x = np.tile(mesh_x, (2, 1))
        mesh_y = np.zeros(mesh_x.shape) - 0.2
        mesh_y[1, :] = 1.2
        mesh_z = state_t0.State.cat.codes.to_numpy()
        mesh_z = mesh_z[np.newaxis, :]

        for channel in self.plots:
            values = self.data.loc[self.data.Channel == channel, 'Value']
            plot_x = values.index.total_seconds().to_numpy() - t_start
            plot_y = values.to_numpy()

            # Extend the plots to both sides to include axes limits (in trial-relative time, like all other x values).
            if len(plot_x) > 0:
                plot_x = np.insert(plot_x, 0, 0)
                plot_x = np.append(plot_x, duration)
                plot_y = np.insert(plot_y, 0, not plot_y[0])
                plot_y = np.append(plot_y, plot_y[-1])

            self.plots[channel].setData(plot_x, plot_y)
            self.meshes[channel].setData(mesh_x, mesh_y, mesh_z)
            self.viewBoxes[channel].setLimits(**self.limits)

        # show the full time span of the data, expressed in trial-relative time
        list(self.viewBoxes.values())[0].setXRange(
            self.data.index[0].total_seconds() - t_start, self.data.index[-1].total_seconds() - t_start, padding=0
        )

    def updateXRange(self):
        """Apply the x-range of the view box that emitted `sigXRangeChanged` to all other view boxes."""
        sender = self.sender()
        x_range = sender.viewRange()[0]

        # Update the x-range for all other ViewBoxes
        for view_box in self.viewBoxes.values():
            if view_box is not sender:  # Avoid updating the sender
                view_box.setXRange(x_range[0], x_range[1], padding=0)
597

598

599
class OnlinePlotsModel(QObject):
    """Model for the online plots: reads a session's Task Data File and aggregates trial statistics.

    Signals
    -------
    currentTrialChanged(int)
        Emitted with the new index when the selected trial changes.
    titleChanged(str)
        Emitted with the updated title when the selected trial changes.
    titleColorChanged(str)
        Emitted when the color computed by :meth:`compute_end_session_criteria` changes.
    sessionStringAvailable(str)
        Emitted once the session summary string has been assembled by :meth:`getSessionString`.
    """

    currentTrialChanged = Signal(int)
    titleChanged = Signal(str)
    titleColorChanged = Signal(str)
    sessionStringAvailable = Signal(str)
    sessionString = ''
    probability_set = DefaultSettings.PROBABILITY_SET
    contrast_set = DefaultSettings.CONTRAST_SET
    tableModel: TrialsTableModel
    _trial_data: pd.DataFrame
    _bpod_data: list[dict[str, Any]]
    _jsonable_offset = 0
    _current_trial = 0

    @validate_call(config=dict(arbitrary_types_allowed=True))
    def __init__(self, session: FilePath | DirectoryPath | UUID4, parent: QObject | None = None):
        """Initialize the model and read the Task Data File.

        Parameters
        ----------
        session : FilePath | DirectoryPath | UUID4
            A Task Data File (``*.raw.jsonable``), a raw task data directory (name starting with
            ``raw_task_data``) or a session eID. For an eID, the Task Data File and - if available - the Task
            Settings File are downloaded via ONE. For a directory, the model polls (every 0.2 s) until the Task
            Data File exists and then watches it for changes.
        parent : QObject, optional
            The parent QObject.

        Raises
        ------
        ValueError
            If the session or its Task Data File cannot be found, or if the path is not of the expected kind.
        """
        super().__init__(parent=parent)
        is_live = False

        # Per-instance state: mutable objects must not live on the class, otherwise all instances would share
        # (and mutate) the same table model and the same list of Bpod trial data.
        self.tableModel = TrialsTableModel()
        self._trial_data = pd.DataFrame()
        self._bpod_data = []
        self._jsonable_offset = 0
        self._current_trial = 0

        # If session is a UUID ...
        if not isinstance(session, Path):
            eid = session  # keep the eID: `session` is rebound to the downloaded file below
            one = ONE()

            # assert that session exists
            session_exists = len(one.alyx.rest('sessions', 'list', id=eid)) > 0
            if not session_exists:
                raise ValueError(f'Could not find session with ID {eid}')

            # load Task Data File
            datasets = one.list_datasets(eid, filename='*taskData.raw.jsonable')
            if len(datasets) == 0:
                raise ValueError(f'Could not find Task Data File for session {eid}')
            session = one.load_dataset(eid, datasets[0], download_only=True)

            # load Task Settings File (optional) - must be queried with the eID, not the downloaded path
            datasets = one.list_datasets(eid, filename='*_iblrig_taskSettings.raw.json')
            if len(datasets) > 0:
                one.load_dataset(eid, datasets[0], download_only=True)

        # If session is a directory ...
        if session.is_dir():
            if not session.name.startswith('raw_task_data'):
                raise ValueError(f'Not a Raw Data Directory: {session}')
            self.raw_data_folder = session
            self.jsonable_file = self.raw_data_folder.joinpath('_iblrig_taskData.raw.jsonable')
            self.settings_file = self.raw_data_folder.joinpath('_iblrig_taskSettings.raw.json')
            if not self.jsonable_file.exists():
                print('Waiting for data ...')
                while not self.jsonable_file.exists():
                    time.sleep(0.2)
            is_live = True

        # If session is a file ...
        elif session.is_file():
            if not session.name.endswith('.raw.jsonable'):
                raise ValueError(f'Not a Task Data File: {session}')
            self.jsonable_file = session
            self.raw_data_folder = session.parent
            self.settings_file = self.raw_data_folder.joinpath('_iblrig_taskSettings.raw.json')

        # task settings override the default probability / contrast sets
        if self.settings_file.exists():
            with self.settings_file.open('r') as f:
                self.task_settings = json.load(f)
            self.probability_set = [self.task_settings.get('PROBABILITY_LEFT')] + self.task_settings.get(
                'BLOCK_PROBABILITY_SET', []
            )
            self.contrast_set = np.unique(np.abs(self.task_settings.get('CONTRAST_SET')))

        # mirror the non-zero contrasts to negative values (assumes contrast_set[0] is the zero contrast)
        self.signed_contrasts = np.r_[-np.flipud(self.contrast_set[1:]), self.contrast_set]
        self.psychometrics = pd.DataFrame(
            columns=['count', 'response_time', 'choice', 'response_time_std', 'choice_std'],
            index=pd.MultiIndex.from_product([self.probability_set, self.signed_contrasts]),
        )
        self.psychometrics['count'] = 0
        self.reward_amount = 0
        self._t0 = 0
        self._n_trials = 0
        self._n_trials_correct = 0
        self._n_trials_engaged = 0
        self._seconds_elapsed = 0
        self.titleColor = ''

        # get session string in separate thread
        session_string_worker = Worker(self.getSessionString)
        QThreadPool.globalInstance().start(session_string_worker)

        # read the jsonable file and - for live sessions - instantiate a QFileSystemWatcher
        self.readJsonable(self.jsonable_file)
        if is_live:
            self.jsonableWatcher = QFileSystemWatcher([str(self.jsonable_file)], parent=self)
            self.jsonableWatcher.fileChanged.connect(self.readJsonable)

    @Slot(str)
    def readJsonable(self, _: str) -> None:
        """Read trials appended to the Task Data File since the previous call and update all derived data.

        Connected to the file-system watcher's ``fileChanged`` signal; the path argument is ignored. If the read
        yields no new trials, the method returns without changing any state.
        """
        # load jsonable data, starting where the previous read stopped
        trial_data, bpod_data = load_task_jsonable(self.jsonable_file, offset=self._jsonable_offset)
        if len(trial_data) == 0 or len(bpod_data) == 0:
            # nothing new - presumably the watcher fired without a complete trial having been appended
            return
        # NOTE(review): data appended between the read above and this stat() call would be skipped on the next
        # read - confirm whether load_task_jsonable can report the offset it actually read up to.
        self._jsonable_offset = self.jsonable_file.stat().st_size
        self._trial_data = pd.concat([self._trial_data, trial_data])
        if len(self._bpod_data) == 0:
            self._t0 = bpod_data[0]['Trial start timestamp']
        self._bpod_data.extend(bpod_data)

        # update data for trial history table
        table = self._trial_data[['trial_num', 'position', 'contrast']].copy()
        table.columns = ['Trial', 'Stimulus', 'Contrast']
        table['Debias'] = self._trial_data.get('debias_trial', False)
        table['Outcome'] = self._trial_data.apply(
            lambda row: 'no-go'
            if (row.get('response_side') == 0 or row.get('response_time') > 60)
            else ('correct' if row.get('trial_correct') else 'error'),
            axis=1,
        )
        table['Response Time / s'] = self._trial_data.apply(
            lambda row: np.nan if row.get('response_side') == 0 else row.get('response_time'), axis=1
        )
        self.tableModel.setDataFrame(table)

        # update some counters; elapsed time is measured from the start of the first trial
        self._n_trials += len(trial_data)
        seconds_elapsed = np.array([trial['Trial end timestamp'] for trial in bpod_data]) - self._t0
        self._seconds_elapsed = seconds_elapsed[-1]
        self._n_trials_engaged += (seconds_elapsed <= EngagedCriterion.SECONDS).sum()
        self._n_trials_correct += trial_data['trial_correct'].sum()
        # NOTE(review): this adds the number of correct trials, while the view labels the value in µl -
        # confirm whether a per-trial reward volume should be summed instead.
        self.reward_amount += trial_data['trial_correct'].sum()

        # update psychometrics (running mean / std per stimulus probability and signed contrast)
        trial_data['signed_contrast'] = np.sign(trial_data['position']) * trial_data['contrast']
        for _, row in trial_data.iterrows():
            if row.get('response_side') == 0:
                continue  # no response: excluded from psychometric / chronometric statistics
            # True for correct trials with position > 0 and for error trials with position < 0
            choice = row.position > 0 if row.trial_correct else row.position < 0
            indexer = (row.stim_probability_left, row.signed_contrast)
            if indexer not in self.psychometrics.index:
                self.psychometrics.loc[indexer, :] = np.nan
                self.psychometrics.loc[indexer, 'count'] = 0
            self.psychometrics.loc[indexer, 'count'] += 1
            self.psychometrics.loc[indexer, 'response_time'], self.psychometrics.loc[indexer, 'response_time_std'] = online_std(
                new_sample=row.response_time,
                new_count=self.psychometrics.loc[indexer, 'count'],
                old_mean=self.psychometrics.loc[indexer, 'response_time'],
                old_std=self.psychometrics.loc[indexer, 'response_time_std'],
            )
            self.psychometrics.loc[indexer, 'choice'], self.psychometrics.loc[indexer, 'choice_std'] = online_std(
                new_sample=float(choice),
                new_count=self.psychometrics.loc[indexer, 'count'],
                old_mean=self.psychometrics.loc[indexer, 'choice'],
                old_std=self.psychometrics.loc[indexer, 'choice_std'],
            )

        self.compute_end_session_criteria()
        self.setCurrentTrial(self._n_trials - 1)

    def compute_end_session_criteria(self):
        """Implement criteria to change the color of the figure display, according to the specifications of the task.

        Emits ``titleColorChanged`` if the resulting color differs from the current one.
        """
        # Within the first part of the session we don't apply any criterion
        if self._seconds_elapsed < EngagedCriterion.SECONDS:
            color = Colors.TRANSPARENT

        # if the mouse has been training for more than 90 minutes: subject training too long
        elif self._seconds_elapsed > (90 * 60):
            color = Colors.RED

        # the mouse did not complete more than EngagedCriterion.TRIAL_COUNT trials within EngagedCriterion.SECONDS
        elif self._n_trials_engaged <= EngagedCriterion.TRIAL_COUNT:
            color = Colors.GREEN

        # the median response time over the last 20 trials is more than 5 times the overall median response time
        elif (self._trial_data['response_time'].median() * 5) < self._trial_data['response_time'].iloc[-20:].median():
            color = Colors.YELLOW

        # none of the above criteria apply
        else:
            color = Colors.TRANSPARENT

        if self.titleColor != color:
            self.titleColor = color
            self.titleColorChanged.emit(color)

    def getSessionString(self) -> None:
        """Assemble a one-line session summary from the task settings and emit ``sessionStringAvailable``.

        Started in a worker thread from ``__init__``. Does nothing if no Task Settings File was loaded.
        """
        if not hasattr(self, 'task_settings'):
            return
        training_info, _ = get_subject_training_info(
            subject_name=self.task_settings.get('SUBJECT_NAME'),
            task_name=self.task_settings.get('PYBPOD_PROTOCOL'),
            lab=self.task_settings.get('ALYX_LAB'),
        )
        use_adaptive_reward = self.task_settings.get('ADAPTIVE_REWARD', False)
        reward_amount = training_info['adaptive_reward'] if use_adaptive_reward else self.task_settings.get('REWARD_AMOUNT_UL')
        self.sessionString = (
            f'Subject: {self.task_settings.get("SUBJECT_NAME")}  ·  '
            f'Weight: {self.task_settings.get("SUBJECT_WEIGHT")} g  ·  '
            f'Training Phase: {training_info["training_phase"]}  ·  '
            f'Stimulus Gain: {self.task_settings.get("STIM_GAIN")}  ·  '
            f'{"Adaptive " if use_adaptive_reward else ""}Reward Amount: {reward_amount} µl'
        )
        self.sessionStringAvailable.emit(self.sessionString)

    @Slot(int)
    def setCurrentTrial(self, value: int) -> None:
        """Select a trial; emits ``currentTrialChanged`` and ``titleChanged`` if the selection changes."""
        if value != self._current_trial:
            self._current_trial = value
            self.currentTrialChanged.emit(value)
            self.titleChanged.emit(self.getTitle())

    def currentTrial(self) -> int:
        """Return the index of the currently selected trial."""
        return self._current_trial

    def nTrials(self) -> int:
        """Return the number of trials read so far."""
        return self._n_trials

    def timeElapsed(self) -> datetime.timedelta:
        """Return the time from the start of the first trial to the end of the currently selected trial."""
        if self._n_trials == 0:
            return datetime.timedelta(seconds=0)
        t1 = self._bpod_data[self._current_trial]['Trial end timestamp']
        return datetime.timedelta(seconds=t1 - self._t0)

    def percentCorrect(self) -> float:
        """Return the percentage of correct trials (NaN if there are no trials yet)."""
        return self._n_trials_correct / (self._n_trials if self._n_trials > 0 else np.nan) * 100

    def bpod_data(self, trial: int) -> pd.DataFrame:
        """Return the Bpod data of the given trial as a DataFrame."""
        return bpod_trial_data_to_dataframe(self._bpod_data[trial], trial)

    def getTitle(self) -> str:
        """Return the window title: task protocol, selected trial and elapsed time (whole seconds)."""
        protocol = getattr(self, 'task_settings', dict()).get('PYBPOD_PROTOCOL', 'unknown task protocol')
        spacer = '  ·  '
        t_elapsed = str(self.timeElapsed()).split('.')[0]
        return f'{protocol}{spacer}Trial {self._current_trial}{spacer}Elapsed Time: {t_elapsed}'
830

831

832
class OnlinePlotsView(QMainWindow):
1✔
833
    colormap = pg.colormap.get('tab10', source='matplotlib')
1✔
834

835
    def __init__(self, session: FilePath | DirectoryPath | UUID4, parent: QObject | None = None):
1✔
836
        super().__init__(parent)
1✔
837
        pg.setConfigOptions(antialias=True)
1✔
838
        self.model = OnlinePlotsModel(session, self)
1✔
839

840
        self.statusBar().clearMessage()
1✔
841
        self.setWindowTitle('Online Plots')
1✔
842
        self.setMinimumSize(1024, 771)
1✔
843
        self.setWindowIcon(QIcon(QPixmap(':/images/iblrig_logo')))
1✔
844

845
        # the frame that contains all the plots
846
        frame = QFrame(self)
1✔
847
        frame.setFrameStyle(QFrame.StyledPanel)
1✔
848
        frame.setStyleSheet('background-color: rgb(255, 255, 255);')
1✔
849
        self.setCentralWidget(frame)
1✔
850

851
        # use a grid layout to organize the different widgets
852
        layout = QGridLayout(frame)
1✔
853
        frame.setLayout(layout)
1✔
854
        layout.setColumnStretch(0, 1)
1✔
855
        layout.setColumnStretch(1, 2)
1✔
856

857
        # titles are arranged in a sub-layout to allow changing the background color in unison
858
        self.titleFrame = QFrame(self)
1✔
859
        title_layout = QVBoxLayout(self.titleFrame)
1✔
860
        self.titleFrame.setLayout(title_layout)
1✔
861
        layout.addWidget(self.titleFrame, 0, 0, 1, 3)
1✔
862

863
        # main title
864
        self.title = QLabel(self.model.getTitle(), self)
1✔
865
        self.title.setAlignment(Qt.AlignHCenter)
1✔
866
        font = self.title.font()
1✔
867
        font.setPointSize(15)
1✔
868
        font.setBold(True)
1✔
869
        self.title.setFont(font)
1✔
870
        self.title.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Minimum)
1✔
871
        self.setTitleBackground(self.model.titleColor)
1✔
872
        title_layout.addWidget(self.title)
1✔
873

874
        # sub title
875
        self.subtitle = QLabel(self)
1✔
876
        self.model.sessionStringAvailable.connect(self.subtitle.setText)
1✔
877
        self.subtitle.setText(self.model.sessionString)
1✔
878
        self.subtitle.setAlignment(Qt.AlignHCenter)
1✔
879
        font.setPointSize(10)
1✔
880
        self.subtitle.setFont(font)
1✔
881
        self.subtitle.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Minimum)
1✔
882
        title_layout.addWidget(self.subtitle)
1✔
883

884
        # trial history
885
        self.trials = TrialsWidget(self, self.model.tableModel)
1✔
886
        self.trials.trialSelected.connect(self.model.setCurrentTrial)
1✔
887
        layout.addWidget(self.trials, 1, 0, 2, 1)
1✔
888

889
        # psychometric function
890
        self.psychometricWidget = FunctionWidget(parent=self, colors=self.colormap, probabilities=self.model.probability_set)
1✔
891
        self.psychometricWidget.plotItem.setTitle('Psychometric Function', color='k')
1✔
892
        self.psychometricWidget.plotItem.getAxis('left').setLabel('Rightward Choices (%)')
1✔
893
        self.psychometricWidget.plotItem.addItem(pg.InfiniteLine(0.5, 0, 'black'))
1✔
894
        self.psychometricWidget.plotItem.setYRange(0, 1, padding=0.05)
1✔
895
        self.psychometricWidget.plotItem.hoverEvent = self.mouseOverFunction
1✔
896
        layout.addWidget(self.psychometricWidget, 1, 1, 1, 1)
1✔
897

898
        # chronometric function
899
        self.chronometricWidget = FunctionWidget(parent=self, colors=self.colormap, probabilities=self.model.probability_set)
1✔
900
        self.chronometricWidget.plotItem.setTitle('Chronometric Function', color='k')
1✔
901
        self.chronometricWidget.plotItem.getAxis('left').setLabel('Response Time (s)')
1✔
902
        self.chronometricWidget.plotItem.setLogMode(x=False, y=True)
1✔
903
        self.chronometricWidget.plotItem.setXLink(self.psychometricWidget.plotItem)
1✔
904
        self.chronometricWidget.plotItem.setXRange(-1, 1, padding=0.025)
1✔
905
        self.chronometricWidget.plotItem.setYRange(-1, 2, padding=0.05)
1✔
906
        self.chronometricWidget.plotItem.hoverEvent = self.mouseOverFunction
1✔
907
        layout.addWidget(self.chronometricWidget, 2, 1, 1, 1)
1✔
908

909
        # performance chart
910
        self.performanceWidget = SingleBarChartWidget(parent=self, textFormat='{:0.1f} %')
1✔
911
        self.performanceWidget.setMinimumWidth(155)
1✔
912
        self.performanceWidget.plotItem.setTitle('Performance', color='k')
1✔
913
        self.performanceWidget.plotItem.getAxis('left').setLabel('Correct Choices (%)')
1✔
914
        self.performanceWidget.plotItem.setYRange(0, 105, padding=0)
1✔
915
        self.performanceWidget.plotItem.hoverEvent = self.mouseOverBarChart
1✔
916
        layout.addWidget(self.performanceWidget, 1, 2, 1, 1)
1✔
917

918
        # reward chart
919
        self.rewardWidget = SingleBarChartWidget(parent=self, barColor=(128, 128, 255), textFormat='{:0.1f} μl')
1✔
920
        self.rewardWidget.plotItem.setTitle('Reward Amount', color='k')
1✔
921
        self.rewardWidget.plotItem.getAxis('left').setLabel('Total Reward Volume (μl)')
1✔
922
        self.rewardWidget.plotItem.setYRange(0, 1050, padding=0)
1✔
923
        self.rewardWidget.plotItem.hoverEvent = self.mouseOverBarChart
1✔
924
        layout.addWidget(self.rewardWidget, 2, 2, 1, 1)
1✔
925

926
        # bpod data
927
        self.bpodWidget = BpodWidget(self, title='Bpod States and Input Channels')
1✔
928
        self.bpodWidget.setMinimumHeight(130)
1✔
929
        self.bpodWidget.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Minimum)
1✔
930
        layout.addWidget(self.bpodWidget, 3, 0, 1, 3)
1✔
931

932
        # connect signals / slots
933
        self.model.titleChanged.connect(self.setTitle)
1✔
934
        self.model.titleColorChanged.connect(self.setTitleBackground)
1✔
935
        self.model.currentTrialChanged.connect(self.updatePlots)
1✔
936
        self.updatePlots(self.model.nTrials() - 1)
1✔
937

938
        # manage settings
939
        self.settings = QSettings()
1✔
940
        self.move(self.settings.value('pos', self.pos(), QPoint))
1✔
941
        self.resize(self.settings.value('size', self.size(), QSize))
1✔
942

943
    @Slot(str)
1✔
944
    def setTitle(self, title: str):
1✔
945
        self.title.setText(title)
1✔
946

947
    @Slot(str)
1✔
948
    def setTitleBackground(self, color: str):
1✔
949
        """Set the background color of the title area to a gradient of the specified color."""
950
        self.titleFrame.setStyleSheet(
1✔
951
            f'QFrame {{ background-color: qlineargradient(x1: 0, x2: 1, '
952
            f'stop: 0 {color}, stop: 0.2 transparent, stop: 0.8 transparent, stop: 1 {color}); }}\n'
953
            f'QLabel {{ background-color: transparent; }}'
954
        )
955

956
    def mouseOverBarChart(self, event):
1✔
NEW
957
        statusbar = self.window().statusBar()
×
NEW
958
        if event.exit:
×
NEW
959
            statusbar.clearMessage()
×
NEW
960
        elif event.currentItem.vb.sceneBoundingRect().contains(event.scenePos()):
×
NEW
961
            if event.currentItem == self.performanceWidget.plotItem:
×
NEW
962
                statusbar.showMessage(f'Performance: {self.model.percentCorrect():0.1f}% correct choices')
×
963
            else:
NEW
964
                statusbar.showMessage(f'Total reward volume: {self.model.reward_amount:0.1f} μl')
×
965

966
    def mouseOverFunction(self, event):
1✔
NEW
967
        statusbar = self.window().statusBar()
×
NEW
968
        if event.exit:
×
NEW
969
            statusbar.clearMessage()
×
NEW
970
        elif event.currentItem.vb.sceneBoundingRect().contains(event.scenePos()):
×
NEW
971
            if event.currentItem == self.psychometricWidget.plotItem:
×
NEW
972
                statusbar.showMessage('Psychometric Function, SEM')
×
973
            else:
NEW
974
                statusbar.showMessage('Chronometric Function, SEM')
×
975

976
    @Slot(int)
1✔
977
    def updatePlots(self, trial: int):
1✔
978
        self.bpodWidget.setData(self.model.bpod_data(trial))
1✔
979
        self.trials.table_view.setCurrentIndex(self.model.tableModel.index(trial, 0))
1✔
980
        self.trials.table_view.scrollTo(self.model.tableModel.index(trial, 0))
1✔
981
        for p in self.model.probability_set:
1✔
982
            data = self.model.psychometrics.loc[p].dropna(axis=0).astype(float)
1✔
983
            x = data.index.to_numpy()
1✔
984
            y = data.choice.to_numpy()
1✔
985
            sqrt_n = np.sqrt(data['count'].to_numpy())
1✔
986
            e = data.choice_std.to_numpy() / sqrt_n
1✔
987
            self.psychometricWidget.upperCurves[p].setData(x=x, y=y + e)
1✔
988
            self.psychometricWidget.lowerCurves[p].setData(x=x, y=y - e)
1✔
989
            self.psychometricWidget.plotDataItems[p].setData(x=x, y=y)
1✔
990
            y = data.response_time.to_numpy()
1✔
991
            e = data.response_time_std.to_numpy() / sqrt_n
1✔
992
            self.chronometricWidget.upperCurves[p].setData(x=x, y=y + e)
1✔
993
            self.chronometricWidget.lowerCurves[p].setData(x=x, y=np.clip(y - e, np.finfo(float).tiny, None))
1✔
994
            self.chronometricWidget.plotDataItems[p].setData(x=x, y=y)
1✔
995
        self.performanceWidget.setValue(self.model.percentCorrect())
1✔
996
        self.rewardWidget.setValue(self.model.reward_amount)
1✔
997
        self.update()
1✔
998

999
    def keyPressEvent(self, event) -> None:
1✔
1000
        """Navigate trials using directional keys."""
NEW
1001
        match event.key():
×
NEW
1002
            case Qt.Key.Key_Up:
×
NEW
1003
                if self.model.currentTrial() > 0:
×
NEW
1004
                    self.model.setCurrentTrial(self.model.currentTrial() - 1)
×
NEW
1005
            case Qt.Key.Key_Down:
×
NEW
1006
                if self.model.currentTrial() < (self.model.nTrials() - 1):
×
NEW
1007
                    self.model.setCurrentTrial(self.model.currentTrial() + 1)
×
NEW
1008
            case Qt.Key.Key_Home:
×
NEW
1009
                self.model.setCurrentTrial(0)
×
NEW
1010
            case Qt.Key.Key_End:
×
NEW
1011
                self.model.setCurrentTrial(self.model.nTrials() - 1)
×
NEW
1012
            case _:
×
NEW
1013
                return
×
NEW
1014
        event.accept()
×
1015

1016
    def moveEvent(self, event):
1✔
NEW
1017
        self.settings.setValue('pos', self.pos())
×
NEW
1018
        super().moveEvent(event)
×
1019

1020
    def resizeEvent(self, event):
1✔
NEW
1021
        self.settings.setValue('size', self.size())
×
NEW
1022
        super().resizeEvent(event)
×
1023

1024

1025
def online_plots_cli(*args):
    """Launch the Online Plots window.

    The session (Task Data File, raw task data directory or session eID) is taken from the command line. If
    no argument is given, a file dialog asks for a Task Data File; cancelling the dialog returns without
    showing a window. Otherwise the Qt event loop runs and the process exits via ``sys.exit``.

    Parameters
    ----------
    *args
        Extra command-line arguments; converted to strings and appended to ``sys.argv``.
    """
    sys.argv.extend(map(str, args))

    class CLISettings(BaseSettings, cli_parse_args=True, cli_enforce_required=False, cli_avoid_json=True):
        """Display a Session's Online Plot."""

        session: CliPositionalArg[FilePath | DirectoryPath | UUID4] = Field(description="a session's Task Data File or eID")

    # application meta-data (e.g., used by QSettings() to locate persisted settings)
    QCoreApplication.setOrganizationName('International Brain Laboratory')
    QCoreApplication.setOrganizationDomain('internationalbrainlab.org')
    QCoreApplication.setApplicationName('IBLRIG Online Plots')
    if os.name == 'nt':
        # give the process an explicit AppUserModelID on Windows
        ctypes.windll.shell32.SetCurrentProcessExplicitAppUserModelID(f'IBL.iblrig.online_plots.{iblrig_version}')

    app = QApplication([])

    if len(sys.argv) > 1:
        session = CLISettings().session
    else:
        subjects_folder = str(get_local_and_remote_paths()['local_subjects_folder'])
        session, _ = QFileDialog.getOpenFileName(
            caption='Select Task Data File', filter='Task Data (*.raw.jsonable)', directory=subjects_folder
        )
        if not session:
            return  # dialog was cancelled

    window = OnlinePlotsView(session)
    window.show()
    sys.exit(app.exec())
1056

1057

1058
if __name__ == '__main__':
    # run the command-line entry point only when the module is executed directly (not on import)
    online_plots_cli()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc