• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

int-brain-lab / ibllib / 1761696499260742

05 Oct 2023 09:46AM UTC coverage: 55.27% (-1.4%) from 56.628%
1761696499260742

Pull #655

continuous-integration/UCL

bimac
add @sleepless decorator
Pull Request #655: add @sleepless decorator

21 of 21 new or added lines in 1 file covered. (100.0%)

10330 of 18690 relevant lines covered (55.27%)

0.55 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

32.61
/ibllib/io/extractors/video_motion.py
1
"""
1✔
2
A module for aligning the wheel motion with the rotary encoder.  Currently used by the camera QC
3
in order to check timestamp alignment.
4
"""
5
import matplotlib
1✔
6
import matplotlib.pyplot as plt
1✔
7
from matplotlib.widgets import RectangleSelector
1✔
8
import numpy as np
1✔
9
from scipy import signal
1✔
10
import cv2
1✔
11
from itertools import cycle
1✔
12
import matplotlib.animation as animation
1✔
13
import logging
1✔
14
from pathlib import Path
1✔
15

16
from one.api import ONE
1✔
17
import ibllib.io.video as vidio
1✔
18
from iblutil.util import Bunch
1✔
19
import brainbox.video as video
1✔
20
import brainbox.behavior.wheel as wh
1✔
21
import one.alf.io as alfio
1✔
22
from one.alf.spec import is_session_path, is_uuid_string
1✔
23

24

25
def find_nearest(array, value):
    """Return the index of the element of *array* nearest to *value*."""
    deltas = np.abs(np.asarray(array) - value)
    return deltas.argmin()
29

30

31
class MotionAlignment:
    """Align wheel motion with video by cross-correlating motion energy with the rotary encoder."""
    # Default ROI per camera label as ((row_start, row_stop), (col_start, col_stop));
    # converted to array slices in `align_motion` to crop frames to the wheel region.
    # NOTE(review): axis order (rows first) assumed from slice usage — confirm against frames.
    roi = {
        'left': ((800, 1020), (233, 1096)),
        'right': ((426, 510), (104, 545)),
        'body': ((402, 481), (31, 103))
    }
37

38
    def __init__(self, eid=None, one=None, log=logging.getLogger(__name__), **kwargs):
        """
        Set up an alignment session and locate the raw video files.

        :param eid: an experiment UUID used to resolve the session
        :param one: an ONE instance; a new one is instantiated when None
        :param log: a logger for alignment messages
        :param kwargs: may contain 'session_path' to bypass eid -> path resolution
        """
        self.one = one or ONE()
        self.eid = eid
        # Resolve the local session path, either explicitly or from the eid
        self.session_path = kwargs.pop('session_path', None) or self.one.eid2path(eid)
        # Human-readable session reference string used in plot titles and filenames
        self.ref = self.one.dict2ref(self.one.path2ref(self.session_path))
        self.log = log
        self.trials = self.wheel = self.camera_times = None
        raw_cam_path = self.session_path.joinpath('raw_video_data')
        camera_path = list(raw_cam_path.glob('_iblrig_*Camera.raw.*'))
        # Map of camera label (e.g. 'left') to its raw video file path
        self.video_paths = {vidio.label_from_path(x): x for x in camera_path}
        self.data = Bunch()
        self.alignment = Bunch()
50

51
    def align_all_trials(self, side='all'):
        """Align the wheel motion for each trial in turn.

        :param side: a camera label, an iterable of labels, or 'all' to align
            every camera for which a video file was found
        :raises ValueError: `side` is a string but not a known camera label
        """
        if self.data.get('trials') is None:
            self.load_data()
        if side == 'all':
            side = self.video_paths.keys()
        if not isinstance(side, str):
            # Iterate over the requested sides
            for s in side:
                self.align_all_trials(s)
            # BUGFIX: previously fell through and raised ValueError on the iterable
            return
        if side not in self.video_paths:
            raise ValueError(f'{side} camera video file not found')
        # Align each trial sequentially
        # BUGFIX: load_data populates self.data.trials, not self.trials (was None here)
        trials = self.data['trials']
        for i in np.arange(trials['intervals'].shape[0]):
            # BUGFIX: align_motion's first argument is a (start, end) period, not a
            # trial index; pass the trial's interval
            self.align_motion(period=trials['intervals'][i, :], side=side, display=False)
65

66
    @staticmethod
    def set_roi(video_path):
        """Interactively draw a rectangular ROI on the first frame of a video.

        Displays frame 0 of the video and blocks until the interactive rectangle
        selector window is closed.

        :param video_path: path of the video file to preview
        :return: (col, row) integer index arrays spanning the selected rectangle
        TODO A method for setting ROIs by label
        """
        frame = vidio.get_video_frame(str(video_path), 0)

        def line_select_callback(eclick, erelease):
            """
            Callback for line selection.

            *eclick* and *erelease* are the press and release events.
            """
            x1, y1 = eclick.xdata, eclick.ydata
            x2, y2 = erelease.xdata, erelease.ydata
            print("(%3.2f, %3.2f) --> (%3.2f, %3.2f)" % (x1, y1, x2, y2))
            return np.array([[x1, x2], [y1, y2]])

        plt.imshow(frame)
        roi = RectangleSelector(plt.gca(), line_select_callback,
                                drawtype='box', useblit=True,
                                button=[1, 3],  # don't use middle button
                                minspanx=5, minspany=5,
                                spancoords='pixels',
                                interactive=True)
        plt.show()
        # Unpack two opposing corners of the drawn rectangle
        ((x1, x2, *_), (y1, *_, y2)) = roi.corners
        col = np.arange(round(x1), round(x2), dtype=int)
        row = np.arange(round(y1), round(y2), dtype=int)
        return col, row
97

98
    def load_data(self, download=False):
        """
        Load wheel, trial and camera timestamp data into self.data.

        :param download: when True, fetch the datasets through ONE; otherwise load
            them from the session's local 'alf' folder
        :return: None; results are stored in the self.data Bunch under the keys
            'wheel', 'trials' and 'camera_times'
        """
        if download:
            self.data.wheel = self.one.load_object(self.eid, 'wheel')
            self.data.trials = self.one.load_object(self.eid, 'trials')
            cam = self.one.load(self.eid, ['camera.times'], dclass_output=True)
            # Key the camera timestamps by the camera label parsed from each URL
            self.data.camera_times = {vidio.label_from_path(url): ts
                                      for ts, url in zip(cam.data, cam.url)}
        else:
            alf_path = self.session_path / 'alf'
            self.data.wheel = alfio.load_object(alf_path, 'wheel', short_keys=True)
            self.data.trials = alfio.load_object(alf_path, 'trials')
            # One timestamps array per camera, keyed by label
            self.data.camera_times = {vidio.label_from_path(x): alfio.load_file_content(x)
                                      for x in alf_path.glob('*Camera.times*')}
        # Sanity check: nothing loaded should be None (missing keys are not detected)
        assert all(x is not None for x in self.data.values())
116

117
    def _set_eid_or_path(self, session_path_or_eid):
        """Resolve and store the session eID and local path.

        Given either a session UUID or a session path, stores it and attempts to
        resolve its counterpart through ONE.

        :param session_path_or_eid: A session eid or path
        :return:
        """
        self.eid = None
        if is_uuid_string(str(session_path_or_eid)):
            # A UUID was given; look up the corresponding local path
            self.eid = session_path_or_eid
            self.session_path = self.one.eid2path(self.eid)
            return
        if not is_session_path(session_path_or_eid):
            self.log.error('Cannot run alignment: an experiment uuid or session path is required')
            raise ValueError("'session' must be a valid session path or uuid")
        # A session path was given; attempt to resolve the eid from it
        self.session_path = Path(session_path_or_eid)
        if self.one is not None:
            self.eid = self.one.path2eid(self.session_path)
            if not self.eid:
                self.log.warning('Failed to determine eID from session path')
137

138
    def align_motion(self, period=(-np.inf, np.inf), side='left', sd_thresh=10, display=False):
        """
        Align video to the wheel using cross-correlation of the video motion signal and the rotary
        encoder.

        Parameters
        ----------
        period : (float, float)
            The time period over which to do the alignment.
        side : {'left', 'right'}
            With which camera to perform the alignment.
        sd_thresh : float
            For plotting where the motion energy goes above this standard deviation threshold.
        display : bool
            When true, displays the aligned wheel motion energy along with the rotary encoder
            signal.

        Returns
        -------
        int
            Frame offset, i.e. by how many frames the video was shifted to match the rotary encoder
            signal.  Negative values mean the video was shifted backwards with respect to the wheel
            timestamps.
        float
            The peak cross-correlation.
        numpy.ndarray
            The motion energy used in the cross-correlation, i.e. the frame difference for the
            period given.
        """
        # Get data samples within period
        wheel = self.data['wheel']
        self.alignment.label = side
        # Reusable mask function selecting timestamps within the alignment period
        self.alignment.to_mask = lambda ts: np.logical_and(ts >= period[0], ts <= period[1])
        camera_times = self.data['camera_times'][side]
        cam_mask = self.alignment.to_mask(camera_times)
        # Indices of the video frames that fall within the period
        frame_numbers, = np.where(cam_mask)

        if frame_numbers.size == 0:
            raise ValueError('No frames during given period')

        # Motion Energy
        camera_path = self.video_paths[side]
        # (row slice, column slice, channel 0) — crop frames to this camera's wheel ROI
        # NOTE(review): (rows, cols) order assumed from self.roi — confirm against frame shape
        roi = (*[slice(*r) for r in self.roi[side]], 0)
        try:
            # TODO Add function arg to make grayscale
            self.alignment.frames = \
                vidio.get_video_frames_preload(camera_path, frame_numbers, mask=roi)
            assert self.alignment.frames.size != 0
        except AssertionError:
            self.log.error('Failed to open video')
            return None, None, None
        # Frame-difference motion energy trace and its per-frame standard deviation
        self.alignment.df, stDev = video.motion_energy(self.alignment.frames, 2)
        self.alignment.period = period  # For plotting

        # Calculate rotary encoder velocity trace
        x = camera_times[cam_mask]
        Fs = 1000  # interpolation frequency (Hz) for the wheel position
        pos, t = wh.interpolate_position(wheel.timestamps, wheel.position, freq=Fs)
        v, _ = wh.velocity_filtered(pos, Fs)
        interp_mask = self.alignment.to_mask(t)
        # Convert to normalized speed
        # Indices of interpolated wheel samples nearest to each camera timestamp
        xs = np.unique([find_nearest(t[interp_mask], ts) for ts in x])
        vs = np.abs(v[interp_mask][xs])
        # Min-max normalize the speed to [0, 1]
        vs = (vs - np.min(vs)) / (np.max(vs) - np.min(vs))

        # FIXME This can be used as a goodness of fit measure
        USE_CV2 = False
        if USE_CV2:
            # convert from numpy format to openCV format
            dfCV = np.float32(self.alignment.df.reshape((-1, 1)))
            reCV = np.float32(vs.reshape((-1, 1)))

            # perform cross correlation
            resultCv = cv2.matchTemplate(dfCV, reCV, cv2.TM_CCORR_NORMED)

            # convert result back to numpy array
            xcorr = np.asarray(resultCv)
        else:
            xcorr = signal.correlate(self.alignment.df, vs)

        # Cross correlate wheel speed trace with the motion energy
        # CORRECTION: fixed two-sample offset applied to the argmax lag
        # NOTE(review): origin of this constant is not shown here — TODO confirm
        CORRECTION = 2
        self.alignment.c = max(xcorr)
        self.alignment.xcorr = np.argmax(xcorr)
        # Frame offset: lag of the correlation peak relative to zero shift
        self.alignment.dt_i = self.alignment.xcorr - xs.size + CORRECTION
        self.log.info(f'{side} camera, adjusted by {self.alignment.dt_i} frames')

        if display:
            # Plot the motion energy
            fig, ax = plt.subplots(2, 1, sharex='all')
            # Pad the diff trace so its length matches the camera timestamps
            y = np.pad(self.alignment.df, 1, 'edge')
            ax[0].plot(x, y, '-x', label='wheel motion energy')
            thresh = stDev > sd_thresh
            ax[0].vlines(x[np.array(np.pad(thresh, 1, 'constant', constant_values=False))], 0, 1,
                         linewidth=0.5, linestyle=':', label=f'>{sd_thresh} s.d. diff')
            ax[1].plot(t[interp_mask], np.abs(v[interp_mask]))

            # Plot other stuff
            # Time shift corresponding to the frame offset
            dt = np.diff(camera_times[[0, np.abs(self.alignment.dt_i)]])
            fps = 1 / np.diff(camera_times).mean()
            ax[0].plot(t[interp_mask][xs] - dt, vs, 'r-x', label='velocity (shifted)')
            ax[0].set_title('normalized motion energy, %s camera, %.0f fps' % (side, fps))
            ax[0].set_ylabel('rate of change (a.u.)')
            ax[0].legend()
            ax[1].set_ylabel('wheel speed (rad / s)')
            ax[1].set_xlabel('Time (s)')

            title = f'{self.ref}, from {period[0]:.1f}s - {period[1]:.1f}s'
            fig.suptitle(title, fontsize=16)
            fig.set_size_inches(19.2, 9.89)

        return self.alignment.dt_i, self.alignment.c, self.alignment.df
250

251
    def plot_alignment(self, energy=True, save=False):
        """
        Animate the aligned video frames alongside the wheel position trace.

        Requires `align_motion` to have been run first so that `self.alignment` is
        populated.  In interactive mode the space bar pauses/resumes playback and
        the left/right arrow keys step one frame backwards/forwards.

        :param energy: when True, animate the frame differences (motion energy)
            instead of the raw frames
        :param save: when truthy, save the animation to an MP4 file instead of
            displaying it; may be a str or Path of a directory to save into
        :return: None
        """
        if not self.alignment:
            self.log.error('No alignment data, run `align_motion` first')
            return
        # Change backend based on save flag
        backend = matplotlib.get_backend().lower()
        if (save and backend != 'agg') or (not save and backend == 'agg'):
            new_backend = 'Agg' if save else 'Qt5Agg'
            self.log.warning('Switching backend from %s to %s', backend, new_backend)
            matplotlib.use(new_backend)
        from matplotlib import pyplot as plt  # re-import after any backend switch

        # Main animated plots
        fig, axes = plt.subplots(nrows=2)
        title = f'{self.ref}'  # ', from {period[0]:.1f}s - {period[1]:.1f}s'
        fig.suptitle(title, fontsize=16)

        wheel = self.data['wheel']
        wheel_mask = self.alignment['to_mask'](wheel.timestamps)
        ts = self.data['camera_times'][self.alignment['label']]
        frame_numbers, = np.where(self.alignment['to_mask'](ts))
        if energy:
            # Switch to frame differences; diffing drops the first and last frame
            self.alignment['frames'] = video.frame_diffs(self.alignment['frames'], 2)
            frame_numbers = frame_numbers[1:-1]
        # Mutable state shared between the animation callbacks
        data = {'frame_ids': frame_numbers}

        def init_plot():
            """
            Plot the first video frame and the wheel position trace.
            :return: None
            """
            data['im'] = axes[0].imshow(self.alignment['frames'][0])
            axes[0].axis('off')
            axes[0].set_title(f'adjusted by {self.alignment["dt_i"]} frames')

            # Plot the wheel position
            ax = axes[1]
            ax.clear()
            ax.plot(wheel.timestamps[wheel_mask], wheel.position[wheel_mask], '-x')

            ts_0 = frame_numbers[0]
            data['idx_0'] = ts_0 - self.alignment['dt_i']
            ts_0 = ts[ts_0 + self.alignment['dt_i']]
            # Vertical cursor marking the current frame time
            data['ln'] = ax.axvline(x=ts_0, color='k')
            ax.set_xlim([ts_0 - (3 / 2), ts_0 + (3 / 2)])
            data['frame_num'] = 0
            mkr = find_nearest(wheel.timestamps[wheel_mask], ts_0)

            data['marker'], = ax.plot(
                wheel.timestamps[wheel_mask][mkr],
                wheel.position[wheel_mask][mkr], 'r-x')
            ax.set_ylabel('Wheel position (rad)')  # BUGFIX: label had a stray ')'
            ax.set_xlabel('Time (s)')  # BUGFIX: label had a stray ')'
            return

        def animate(i):
            """
            Callback for figure animation.  Sets image data for current frame and moves pointer
            along axis
            :param i: when negative, step one frame backwards, otherwise advance one frame
            :return: None
            """
            if i < 0:
                data['frame_num'] -= 1
                if data['frame_num'] < 0:
                    data['frame_num'] = len(self.alignment['frames']) - 1
            else:
                data['frame_num'] += 1
                if data['frame_num'] >= len(self.alignment['frames']):
                    data['frame_num'] = 0
            i = data['frame_num']  # NB: This is index for current trial's frame list

            frame = self.alignment['frames'][i]
            t_x = ts[data['idx_0'] + i]
            data['ln'].set_xdata([t_x, t_x])
            axes[1].set_xlim([t_x - (3 / 2), t_x + (3 / 2)])
            data['im'].set_data(frame)

            mkr = find_nearest(wheel.timestamps[wheel_mask], t_x)
            # BUGFIX: set_data requires sequences (scalar args removed in matplotlib >= 3.7)
            data['marker'].set_data(
                [wheel.timestamps[wheel_mask][mkr]],
                [wheel.position[wheel_mask][mkr]]
            )

            return data['im'], data['ln'], data['marker']

        anim = animation.FuncAnimation(fig, animate, init_func=init_plot,
                                       frames=(range(len(self.alignment.df))
                                               if save
                                               else cycle(range(60))),
                                       interval=20, blit=False,
                                       repeat=not save, cache_frame_data=False)
        anim.running = False

        def process_key(event):
            """
            Callback for key presses.
            :param event: a figure key_press_event
            :return: None
            """
            if event.key.isspace():
                if anim.running:
                    anim.event_source.stop()
                else:
                    anim.event_source.start()
                # BUGFIX: was `~anim.running` — bitwise NOT on a bool yields -2/-1/0,
                # not a logical toggle
                anim.running = not anim.running
            elif event.key == 'right':
                if anim.running:
                    anim.event_source.stop()
                    anim.running = False
                animate(1)
                fig.canvas.draw()
            elif event.key == 'left':
                if anim.running:
                    anim.event_source.stop()
                    anim.running = False
                animate(-1)
                fig.canvas.draw()

        fig.canvas.mpl_connect('key_press_event', process_key)

        if save:
            filename = '%s_%c.mp4' % (self.ref, self.alignment['label'][0])
            if isinstance(save, (str, Path)):
                filename = Path(save).joinpath(filename)
            # BUGFIX: the f-string had no placeholder and logged the literal text
            self.log.info('Saving to %s', filename)
            # Set up formatting for the movie files
            Writer = animation.writers['ffmpeg']
            writer = Writer(fps=24, metadata=dict(artist='Miles Wells'), bitrate=1800)
            anim.save(str(filename), writer=writer)
        else:
            plt.show()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc