• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

int-brain-lab / ibllib / 1761696499260742

05 Oct 2023 09:46AM UTC coverage: 55.27% (-1.4%) from 56.628%
1761696499260742

Pull #655

continuous-integration/UCL

bimac
add @sleepless decorator
Pull Request #655: add @sleepless decorator

21 of 21 new or added lines in 1 file covered. (100.0%)

10330 of 18690 relevant lines covered (55.27%)

0.55 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.22
/ibllib/pipes/base_tasks.py
1
"""Abstract base classes for dynamic pipeline tasks."""
1✔
2
import logging
1✔
3
from pathlib import Path
1✔
4

5
from pkg_resources import parse_version
1✔
6
from one.webclient import no_cache
1✔
7
from iblutil.util import flatten
1✔
8

9
from ibllib.pipes.tasks import Task
1✔
10
import ibllib.io.session_params as sess_params
1✔
11
from ibllib.qc.base import sign_off_dict, SIGN_OFF_CATEGORIES
1✔
12
from ibllib.io.raw_daq_loaders import load_timeline_sync_and_chmap
1✔
13

14
_logger = logging.getLogger(__name__)
1✔
15

16

17
class DynamicTask(Task):
    """Base class for pipeline tasks configured from the session's experiment description.

    On construction the experiment description file is read into ``session_params`` and the
    sync settings are resolved. A non-empty keyword argument (``sync_collection``, ``sync``,
    ``sync_ext``, ``sync_namespace``) takes precedence over the value from the file.
    """

    def __init__(self, session_path, **kwargs):
        super().__init__(session_path, **kwargs)
        self.session_params = self.read_params_file()

        # TODO Which should be default?
        # Sync collection
        self.sync_collection = self.get_sync_collection(kwargs.get('sync_collection', None))
        # Sync type
        self.sync = self.get_sync(kwargs.get('sync', None))
        # Sync extension
        self.sync_ext = self.get_sync_extension(kwargs.get('sync_ext', None))
        # Sync namespace
        self.sync_namespace = self.get_sync_namespace(kwargs.get('sync_namespace', None))

    def get_sync_collection(self, sync_collection=None):
        """Return `sync_collection` if truthy, else the sync collection from the session params."""
        return sync_collection if sync_collection else sess_params.get_sync_collection(self.session_params)

    def get_sync(self, sync=None):
        """Return `sync` if truthy, else the sync label from the session params."""
        return sync if sync else sess_params.get_sync_label(self.session_params)

    def get_sync_extension(self, sync_ext=None):
        """Return `sync_ext` if truthy, else the sync file extension from the session params."""
        return sync_ext if sync_ext else sess_params.get_sync_extension(self.session_params)

    def get_sync_namespace(self, sync_namespace=None):
        """Return `sync_namespace` if truthy, else the sync namespace from the session params."""
        return sync_namespace if sync_namespace else sess_params.get_sync_namespace(self.session_params)

    def get_protocol(self, protocol=None, task_collection=None):
        """Return `protocol` if truthy, else the task protocol for `task_collection` from the session params."""
        return protocol if protocol else sess_params.get_task_protocol(self.session_params, task_collection)

    def get_task_collection(self, collection=None):
        """Return `collection` if truthy, else the task collection inferred from the session params."""
        if not collection:
            collection = sess_params.get_task_collection(self.session_params)
        # If inferring the collection from the experiment description, assert only one returned
        assert collection is None or isinstance(collection, str) or len(collection) == 1
        return collection

    def get_device_collection(self, device, device_collection=None):
        """Return the raw data collection for `device`.

        Parameters
        ----------
        device : str
            The device key to look up in the experiment description (e.g. 'cameras').
        device_collection : str, optional
            If truthy, returned as-is without consulting the session params.

        Returns
        -------
        str or None
            The collection, or None if the device is not described in the session params.
        """
        if device_collection:
            return device_collection
        # With no experiment description, session_params is {}; a description may also lack a
        # 'devices' section. Use an empty mapping instead of raising KeyError('devices').
        collection_map = sess_params.get_collections(self.session_params.get('devices', {}))
        return collection_map.get(device)

    def read_params_file(self):
        """Read the session's experiment description, returning {} if none is found."""
        params = sess_params.read_params(self.session_path)

        if params is None:
            return {}

        # TODO figure out the best way: when no local file exists, consider fetching the
        #  experiment description from Alyx via ONE (e.g. load_dataset / alyx.rest).
        return params
74

75

76
class BehaviourTask(DynamicTask):
    """Base class for tasks processing the data of a single behaviour protocol.

    The task collection, protocol name and protocol number are taken from the keyword arguments
    when given, otherwise inferred from the experiment description. Outputs go to 'alf', or to
    'alf/task_NN' when a protocol number is known.
    """

    def __init__(self, session_path, **kwargs):
        super().__init__(session_path, **kwargs)

        self.collection = self.get_task_collection(kwargs.get('collection', None))
        # Task type (protocol)
        self.protocol = self.get_protocol(kwargs.get('protocol', None), task_collection=self.collection)

        self.protocol_number = self.get_protocol_number(kwargs.get('protocol_number'), task_protocol=self.protocol)

        self.output_collection = 'alf'
        # Compare against None explicitly: protocol number 0 is valid but falsy
        if self.protocol_number is not None:
            self.output_collection += f'/task_{self.protocol_number:02}'

    def get_protocol_number(self, number=None, task_protocol=None):
        """Return `number` if not None, else the protocol number for `task_protocol` from the session params."""
        if number is None:  # Do not use "if not number" as that will return True if number is 0
            number = sess_params.get_task_protocol_number(self.session_params, task_protocol)
        # If inferring the number from the experiment description, assert only one returned (or something went wrong)
        assert number is None or isinstance(number, int)
        return number

    @staticmethod
    def _spacer_support(settings):
        """
        Spacer support was introduced in v7.1 for iblrig v7 and v8.0.1 in v8.

        Parameters
        ----------
        settings : dict
            The task settings dict.

        Returns
        -------
        bool
            True if task spacers are to be expected. False when no rig version is recorded in
            the settings.
        """
        v = parse_version
        # parse_version(None) raises, so fall back to 'IBLRIG_VERSION' (NOTE(review): key
        # presumably used by newer iblrig settings files) and finally to '100.0.0', which is
        # explicitly excluded below and therefore yields False.
        raw_version = settings.get('IBLRIG_VERSION_TAG') or settings.get('IBLRIG_VERSION') or '100.0.0'
        version = v(raw_version)
        return version not in (v('100.0.0'), v('8.0.0')) and version >= v('7.1.0')
127

128

129
class VideoTask(DynamicTask):
    """Base class for tasks processing camera data.

    Parameters
    ----------
    session_path : str, pathlib.Path
        The session to process.
    cameras : list
        The cameras to process (presumably a list of camera labels — confirm with callers).
        Also forwarded to the base class as the ``cameras`` keyword argument.
    """

    def __init__(self, session_path, cameras, **kwargs):
        super().__init__(session_path, cameras=cameras, **kwargs)
        self.cameras = cameras
        requested = kwargs.get('device_collection', 'raw_video_data')
        self.device_collection = self.get_device_collection('cameras', requested)
136

137

138
class AudioTask(DynamicTask):
    """Base class for tasks processing microphone data.

    The device collection is 'raw_behavior_data' by default. A ``device_collection`` keyword
    argument overrides it; if that argument is falsy, the collection is looked up in the
    experiment description instead.
    """

    def __init__(self, session_path, **kwargs):
        super().__init__(session_path, **kwargs)
        requested = kwargs.get('device_collection', 'raw_behavior_data')
        self.device_collection = self.get_device_collection('microphone', requested)
143

144

145
class EphysTask(DynamicTask):
    """Base class for tasks processing Neuropixel electrophysiology data.

    Sets the probe name(s) (``pname``), the number of shanks (``nshanks``) with one letter suffix
    per shank (``pextra``), and the raw ephys device collection.
    """

    def __init__(self, session_path, **kwargs):
        super().__init__(session_path, **kwargs)

        self.pname = self.get_pname(kwargs.get('pname', None))
        self.nshanks, self.pextra = self.get_nshanks(kwargs.get('nshanks', None))
        self.device_collection = self.get_device_collection('neuropixel', kwargs.get('device_collection', 'raw_ephys_data'))

    def get_pname(self, pname):
        """Return the probe name(s); a 'pname' entry in ``self.kwargs`` takes precedence.

        pname can be a list or a string.
        """
        return self.kwargs.get('pname', pname)

    def get_nshanks(self, nshanks=None):
        """Return ``(nshanks, pextra)``, where pextra holds one lowercase letter per shank.

        A 'nshanks' entry in ``self.kwargs`` takes precedence over the argument. When the number
        of shanks is None, pextra is an empty list. For example, 2 shanks gives ``['a', 'b']``.
        """
        nshanks = self.kwargs.get('nshanks', nshanks)
        if nshanks is None:
            return nshanks, []
        # Cast the count (not the loop index) so numeric strings/floats like '2' or 2.0 work.
        pextra = [chr(ord('a') + shank) for shank in range(int(nshanks))]
        return nshanks, pextra
168

169

170
class WidefieldTask(DynamicTask):
    """Base class for tasks processing widefield imaging data.

    The device collection is 'raw_widefield_data' by default. A ``device_collection`` keyword
    argument overrides it; if that argument is falsy, the collection is looked up in the
    experiment description instead.
    """

    def __init__(self, session_path, **kwargs):
        super().__init__(session_path, **kwargs)
        requested = kwargs.get('device_collection', 'raw_widefield_data')
        self.device_collection = self.get_device_collection('widefield', requested)
175

176

177
class MesoscopeTask(DynamicTask):
    """Base class for tasks processing mesoscope imaging data.

    Raw data may be split across several device collection folders ("imaging bouts"). The default
    device collection is therefore the glob pattern 'raw_imaging_data_[0-9]*', which
    ``get_signatures`` expands against the folders found on disk.
    """

    def __init__(self, session_path, **kwargs):
        super().__init__(session_path, **kwargs)

        self.device_collection = self.get_device_collection(
            'mesoscope', kwargs.get('device_collection', 'raw_imaging_data_[0-9]*'))

    def _expand_signature(self, signature, folders):
        """Expand device-collection entries to one per folder; keep all other entries once."""
        expanded = []
        for sig in signature:
            if self.device_collection in sig[1]:
                expanded.extend((sig[0], sig[1].replace(self.device_collection, folder), sig[2])
                                for folder in folders)
            else:
                expanded.append((sig[0], sig[1], sig[2]))
        return expanded

    def get_signatures(self, **kwargs):
        """
        From the template signature of the task, create the exact list of inputs and outputs to expect based on the
        available device collection folders.

        Necessary because we don't know in advance how many device collection folders ("imaging bouts") to expect.
        Entries whose collection contains the device collection pattern get one entry per matching folder. All
        other entries are kept once, unchanged, even when no imaging folder is found.
        """
        self.session_path = Path(self.session_path)
        # Glob for all device collection (raw imaging data) folders
        raw_imaging_folders = [p.name for p in self.session_path.glob(self.device_collection)]
        self.input_files = self._expand_signature(self.signature['input_files'], raw_imaging_folders)
        self.output_files = self._expand_signature(self.signature['output_files'], raw_imaging_folders)

    def load_sync(self):
        """
        Load the sync and channel map.

        This method may be expanded to support other raw DAQ data formats.

        Returns
        -------
        one.alf.io.AlfBunch
            A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses
            and the corresponding channel numbers.
        dict
            A map of channel names and their corresponding indices.

        Raises
        ------
        NotImplementedError
            The sync namespace is anything other than 'timeline'.
        """
        alf_path = self.session_path / self.sync_collection
        # Use the namespace resolved in __init__ so a 'sync_namespace' keyword override is honoured
        if self.sync_namespace == 'timeline':
            # Load the sync and channel map from the raw DAQ data
            sync, chmap = load_timeline_sync_and_chmap(alf_path)
        else:
            raise NotImplementedError(f'Sync namespace "{self.sync_namespace}" not supported')
        return sync, chmap
222

223

224
class RegisterRawDataTask(DynamicTask):
    """
    Base register raw task.

    Optionally renames input files to output files before registration. To rename files:
     1. input and output must have the same length (entries are paired in order);
     2. output files must have the full filename (no glob patterns).
    """

    priority = 100
    job_size = 'small'

    def rename_files(self, symlink_old=False):
        """Move each input file to its paired output location.

        Parameters
        ----------
        symlink_old : bool
            If true, leave a symlink at each old location pointing to the new file.

        Raises
        ------
        FileNotFoundError
            A required input file does not exist.
        """
        # If either no inputs or no outputs are given, we don't do any renaming
        if not all(map(len, (self.input_files, self.output_files))):
            return

        # Otherwise we need to make sure there is one to one correspondence for renaming files
        assert len(self.input_files) == len(self.output_files)

        for before, after in zip(self.input_files, self.output_files):
            old_file, old_collection, required = before
            # The input file name may be a glob pattern; take the first match
            old_path = next(self.session_path.joinpath(old_collection).glob(old_file), None)
            # if the file doesn't exist and it is not required we are okay to continue
            if old_path is None:
                if required:
                    raise FileNotFoundError(str(old_file))
                continue

            new_file, new_collection, _ = after
            new_path = self.session_path.joinpath(new_collection, new_file)
            if old_path == new_path:
                continue
            new_path.parent.mkdir(parents=True, exist_ok=True)
            _logger.debug('%s -> %s', old_path.relative_to(self.session_path), new_path.relative_to(self.session_path))
            old_path.replace(new_path)
            if symlink_old:
                old_path.symlink_to(new_path)

    def register_snapshots(self, unlink=False, collection=None):
        """
        Register any photos in the snapshots folder to the session. Typically imaging users will
        take numerous photos for reference.  Supported extensions: .jpg, .jpeg, .png, .tif, .tiff

        If a .txt file with the same name exists in the same location, the contents will be added
        to the note text.

        Parameters
        ----------
        unlink : bool
            If true, image files are deleted after upload (accompanying .txt files are kept), and
            the snapshots folder is removed if it is then empty.
        collection : str, list, optional
            Location of 'snapshots' folder relative to the session path. If None, uses
            'device_collection' attribute (if exists) or root session path. A string containing
            '*' is treated as a glob pattern of collections.

        Returns
        -------
        list of dict
            The newly registered Alyx notes. None is returned if the snapshots folder does not
            exist or the session is not found on Alyx.
        """
        collection = getattr(self, 'device_collection', None) if collection is None else collection
        collection = collection or ''  # If not defined, use no collection
        if collection and '*' in collection:
            collection = [p.name for p in self.session_path.glob(collection)]
            # Check whether folders on disk contain '*'; this is to stop an infinite recursion
            assert not any('*' in c for c in collection), 'folders containing asterisks not supported'
        # If more than one collection exists, register snapshots in each collection
        if collection and not isinstance(collection, str):
            return flatten(filter(None, [self.register_snapshots(unlink, c) for c in collection]))
        snapshots_path = self.session_path.joinpath(*filter(None, (collection, 'snapshots')))
        if not snapshots_path.exists():
            return

        eid = self.one.path2eid(self.session_path, query_type='remote')
        if not eid:
            _logger.warning('Failed to upload snapshots: session not found on Alyx')
            return
        note = dict(user=self.one.alyx.user, content_type='session', object_id=eid, text='')

        notes = []
        exts = ('.jpg', '.jpeg', '.png', '.tif', '.tiff')
        for snapshot in filter(lambda x: x.suffix.lower() in exts, snapshots_path.glob('*.*')):
            _logger.debug('Uploading "%s"...', snapshot.relative_to(self.session_path))
            if snapshot.with_suffix('.txt').exists():
                with open(snapshot.with_suffix('.txt'), 'r') as txt_file:
                    note['text'] = txt_file.read().strip()
            else:
                note['text'] = ''
            with open(snapshot, 'rb') as img_file:
                files = {'image': img_file}
                notes.append(self.one.alyx.rest('notes', 'create', data=note, files=files))
            if unlink:
                snapshot.unlink()
        # If nothing else in the snapshots folder, delete the folder
        if unlink and next(snapshots_path.rglob('*'), None) is None:
            snapshots_path.rmdir()
        _logger.info('%i snapshots uploaded to Alyx', len(notes))
        return notes

    def _run(self, **kwargs):
        """Rename files (if configured) and return the existing output files.

        Sets ``self.status`` to -1 if any required output file is missing.
        """
        self.rename_files(**kwargs)
        out_files = []
        n_missing = 0
        for file_sig in self.output_files:
            file_name, collection, required = file_sig
            file_path = next(self.session_path.joinpath(collection).glob(file_name), None)
            if file_path is not None:
                out_files.append(file_path)
            elif required:
                _logger.error('expected %s missing', file_sig)
                n_missing += 1

        # Count missing required files directly: comparing len(out_files) with the number of
        # required files let a present optional file mask a missing required one.
        if n_missing:
            self.status = -1

        return out_files
345

346

347
class ExperimentDescriptionRegisterRaw(RegisterRawDataTask):
    """Register the experiment description file and update the session's sign-off fields.

    When online and the registration succeeded, the sign-off dict built from the experiment
    description is written to the session's JSON field on Alyx.
    """

    # dict of list: custom sign off keys corresponding to specific devices
    sign_off_categories = SIGN_OFF_CATEGORIES

    @property
    def signature(self):
        return {
            'input_files': [],
            'output_files': [('*experiment.description.yaml', '', True)],
        }

    def _run(self, **kwargs):
        # Register experiment description file
        out_files = super()._run(**kwargs)
        if self.one.offline or self.status != 0:
            return out_files
        with no_cache(self.one.alyx):  # Ensure we don't load the cached JSON response
            eid = self.one.path2eid(self.session_path, query_type='remote')
        exp_desc = sess_params.read_params(out_files[0])
        sign_off = sign_off_dict(exp_desc, sign_off_categories=self.sign_off_categories)
        self.one.alyx.json_field_update('sessions', eid, data=sign_off)
        return out_files
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc