• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

datajoint / datajoint-python / #12880

pending completion
#12880

push

travis-ci

web-flow
Merge pull request #1067 from CBroz1/master

Add support for insert CSV

4 of 4 new or added lines in 1 file covered. (100.0%)

3102 of 3424 relevant lines covered (90.6%)

0.91 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

82.86
/datajoint/autopopulate.py
1
"""This module defines class dj.AutoPopulate"""
2
import logging
1✔
3
import datetime
1✔
4
import traceback
1✔
5
import random
1✔
6
import inspect
1✔
7
from tqdm import tqdm
1✔
8
from .hash import key_hash
1✔
9
from .expression import QueryExpression, AndList
1✔
10
from .errors import DataJointError, LostConnectionError
1✔
11
import signal
1✔
12
import multiprocessing as mp
1✔
13
import contextlib
1✔
14

15
# noinspection PyExceptionInherit,PyCallingNonCallable
16

17
logger = logging.getLogger(__name__.split(".")[0])
1✔
18

19

20
# --- helper functions for multiprocessing --
21

22

23
def _initialize_populate(table, jobs, populate_kwargs):
1✔
24
    """
25
    Initialize the process for mulitprocessing.
26
    Saves the unpickled copy of the table to the current process and reconnects.
27
    """
28
    process = mp.current_process()
×
29
    process.table = table
×
30
    process.jobs = jobs
×
31
    process.populate_kwargs = populate_kwargs
×
32
    table.connection.connect()  # reconnect
×
33

34

35
def _call_populate1(key):
1✔
36
    """
37
    Call current process' table._populate1()
38
    :key - a dict specifying job to compute
39
    :return: key, error if error, otherwise None
40
    """
41
    process = mp.current_process()
×
42
    return process.table._populate1(key, process.jobs, **process.populate_kwargs)
×
43

44

45
class AutoPopulate:
    """
    AutoPopulate is a mixin class that adds the method populate() to a Table class.
    Auto-populated tables must inherit from both Table and AutoPopulate,
    must define the property `key_source`, and must define the callback method `make`.
    """

    _key_source = None  # cached default key_source query expression
    _allow_insert = False  # True only while make() executes within populate()

    @property
    def key_source(self):
        """
        :return: the query expression that yields primary key values to be passed,
        sequentially, to the ``make`` method when populate() is called.
        The default value is the join of the parent tables referenced from the primary key.
        Subclasses may override the key_source to change the scope or the granularity
        of the make calls.
        """

        def _rename_attributes(table, props):
            # project out aliased foreign-key attributes so joins use original names
            return (
                table.proj(
                    **{
                        attr: ref
                        for attr, ref in props["attr_map"].items()
                        if attr != ref
                    }
                )
                if props["aliased"]
                else table.proj()
            )

        if self._key_source is None:
            parents = self.target.parents(
                primary=True, as_objects=True, foreign_key_info=True
            )
            if not parents:
                raise DataJointError(
                    "A table must have dependencies "
                    "from its primary key for auto-populate to work"
                )
            # the default key source is the join of all primary-key parents
            self._key_source = _rename_attributes(*parents[0])
            for q in parents[1:]:
                self._key_source *= _rename_attributes(*q)
        return self._key_source

    def make(self, key):
        """
        Derived classes must implement method `make` that fetches data from tables
        above them in the dependency hierarchy, restricting by the given key,
        computes secondary attributes, and inserts the new tuples into self.

        :param key: dict of primary key values identifying one job
        :raises NotImplementedError: always, unless overridden by a subclass
        """
        raise NotImplementedError(
            "Subclasses of AutoPopulate must implement the method `make`"
        )

    @property
    def target(self):
        """
        :return: table to be populated.
        In the typical case, dj.AutoPopulate is mixed into a dj.Table class by
        inheritance and the target is self.
        """
        return self

    def _job_key(self, key):
        """
        :param key: the key returned for the job from the key source
        :return: the dict to use to generate the job reservation hash
        This method allows subclasses to control the job reservation granularity.
        """
        return key

    def _jobs_to_do(self, restrictions):
        """
        :param restrictions: restrictions to apply to the key source
        :return: the query yielding the keys to be computed (derived from self.key_source)
        :raises DataJointError: if called on a restricted table or key_source is invalid
        """
        if self.restriction:
            raise DataJointError(
                "Cannot call populate on a restricted table. "
                "Instead, pass conditions to populate() as arguments."
            )
        todo = self.key_source

        # key_source is a QueryExpression subclass -- trigger instantiation
        if inspect.isclass(todo) and issubclass(todo, QueryExpression):
            todo = todo()

        if not isinstance(todo, QueryExpression):
            raise DataJointError("Invalid key_source value")

        try:
            # check if target lacks any attributes from the primary key of key_source
            raise DataJointError(
                "The populate target lacks attribute %s "
                "from the primary key of key_source"
                % next(
                    name
                    for name in todo.heading.primary_key
                    if name not in self.target.heading
                )
            )
        except StopIteration:
            pass  # all key_source primary key attributes exist in the target
        return (todo & AndList(restrictions)).proj()

    def populate(
        self,
        *restrictions,
        suppress_errors=False,
        return_exception_objects=False,
        reserve_jobs=False,
        order="original",
        limit=None,
        max_calls=None,
        display_progress=False,
        processes=1,
        make_kwargs=None,
    ):
        """
        ``table.populate()`` calls ``table.make(key)`` for every primary key in
        ``self.key_source`` for which there is not already a tuple in table.

        :param restrictions: a list of restrictions each restrict
            (table.key_source - target.proj())
        :param suppress_errors: if True, do not terminate execution.
        :param return_exception_objects: return error objects instead of just error messages
        :param reserve_jobs: if True, reserve jobs to populate in asynchronous fashion
        :param order: "original"|"reverse"|"random"  - the order of execution
        :param limit: if not None, check at most this many keys
        :param max_calls: if not None, populate at most this many keys
        :param display_progress: if True, report progress_bar
        :param processes: number of processes to use. Set to None to use all cores
        :param make_kwargs: Keyword arguments which do not affect the result of computation
            to be passed down to each ``make()`` call. Computation arguments should be
            specified within the pipeline e.g. using a `dj.Lookup` table.
        :type make_kwargs: dict, optional
        :return: a list of errors when suppress_errors=True, otherwise None
        """
        if self.connection.in_transaction:
            raise DataJointError("Populate cannot be called during a transaction.")

        valid_order = ["original", "reverse", "random"]
        if order not in valid_order:
            raise DataJointError(
                "The order argument must be one of %s" % str(valid_order)
            )
        jobs = (
            self.connection.schemas[self.target.database].jobs if reserve_jobs else None
        )

        # define and set up signal handler for SIGTERM:
        if reserve_jobs:

            def handler(signum, frame):
                logger.info("Populate terminated by SIGTERM")
                raise SystemExit("SIGTERM received")

            old_handler = signal.signal(signal.SIGTERM, handler)

        error_list = []
        try:
            keys = (self._jobs_to_do(restrictions) - self.target).fetch(
                "KEY", limit=limit
            )

            # exclude "error" or "ignore" jobs
            if reserve_jobs:
                exclude_key_hashes = (
                    jobs
                    & {"table_name": self.target.table_name}
                    & 'status in ("error", "ignore")'
                ).fetch("key_hash")
                keys = [key for key in keys if key_hash(key) not in exclude_key_hashes]

            if order == "reverse":
                keys.reverse()
            elif order == "random":
                random.shuffle(keys)

            logger.debug("Found %d keys to populate" % len(keys))

            keys = keys[:max_calls]  # slicing handles max_calls=None
            nkeys = len(keys)
            if nkeys:
                # cap worker count by the number of jobs and the available cores
                processes = min(n for n in (processes, nkeys, mp.cpu_count()) if n)

                populate_kwargs = dict(
                    suppress_errors=suppress_errors,
                    return_exception_objects=return_exception_objects,
                    make_kwargs=make_kwargs,
                )

                if processes == 1:
                    for key in (
                        tqdm(keys, desc=self.__class__.__name__)
                        if display_progress
                        else keys
                    ):
                        error = self._populate1(key, jobs, **populate_kwargs)
                        if error is not None:
                            error_list.append(error)
                else:
                    # spawn multiple processes
                    self.connection.close()  # disconnect parent process from MySQL server
                    del self.connection._conn.ctx  # SSLContext is not pickleable
                    try:
                        with mp.Pool(
                            processes,
                            _initialize_populate,
                            (self, jobs, populate_kwargs),
                        ) as pool, (
                            tqdm(desc="Processes: ", total=nkeys)
                            if display_progress
                            else contextlib.nullcontext()
                        ) as progress_bar:
                            for error in pool.imap(_call_populate1, keys, chunksize=1):
                                if error is not None:
                                    error_list.append(error)
                                if display_progress:
                                    progress_bar.update()
                    finally:
                        # fix: reconnect the parent process to the MySQL server even
                        # if a worker raised; previously an error left it disconnected
                        self.connection.connect()
        finally:
            # fix: restore the original SIGTERM handler even on error or when
            # there were no keys; previously early return/raise skipped this
            if reserve_jobs:
                signal.signal(signal.SIGTERM, old_handler)

        if suppress_errors:
            return error_list

    def _populate1(
        self, key, jobs, suppress_errors, return_exception_objects, make_kwargs=None
    ):
        """
        populates table for one source key, calling self.make inside a transaction.

        :param key: dict specifying job to populate
        :param jobs: the jobs table or None if not reserve_jobs
        :param suppress_errors: bool if errors should be suppressed and returned
        :param return_exception_objects: if True, errors must be returned as objects
        :param make_kwargs: optional dict of keyword arguments passed through to make()
        :return: (key, error) when suppress_errors=True and an error occurred, otherwise None
        """
        # legacy pipelines may define _make_tuples instead of make
        make = self._make_tuples if hasattr(self, "_make_tuples") else self.make

        if jobs is None or jobs.reserve(self.target.table_name, self._job_key(key)):
            self.connection.start_transaction()
            if key in self.target:  # already populated
                self.connection.cancel_transaction()
                if jobs is not None:
                    jobs.complete(self.target.table_name, self._job_key(key))
            else:
                logger.debug(f"Making {key} -> {self.target.full_table_name}")
                self.__class__._allow_insert = True
                try:
                    make(dict(key), **(make_kwargs or {}))
                except (KeyboardInterrupt, SystemExit, Exception) as error:
                    try:
                        self.connection.cancel_transaction()
                    except LostConnectionError:
                        pass  # rollback is moot if the connection is gone
                    error_message = "{exception}{msg}".format(
                        exception=error.__class__.__name__,
                        msg=": " + str(error) if str(error) else "",
                    )
                    logger.debug(
                        f"Error making {key} -> {self.target.full_table_name} - {error_message}"
                    )
                    if jobs is not None:
                        # show error name and error message (if any)
                        jobs.error(
                            self.target.table_name,
                            self._job_key(key),
                            error_message=error_message,
                            error_stack=traceback.format_exc(),
                        )
                    if not suppress_errors or isinstance(error, SystemExit):
                        raise
                    else:
                        logger.error(error)
                        return key, error if return_exception_objects else error_message
                else:
                    self.connection.commit_transaction()
                    logger.debug(
                        f"Success making {key} -> {self.target.full_table_name}"
                    )
                    if jobs is not None:
                        jobs.complete(self.target.table_name, self._job_key(key))
                finally:
                    self.__class__._allow_insert = False

    def progress(self, *restrictions, display=True):
        """
        Report the progress of populating the table.

        :param restrictions: restrictions passed through to _jobs_to_do
        :param display: if True, print a formatted progress line
        :return: (remaining, total) -- numbers of tuples to be populated
        """
        todo = self._jobs_to_do(restrictions)
        total = len(todo)
        remaining = len(todo - self.target)
        if display:
            print(
                "%-20s" % self.__class__.__name__,
                "Completed %d of %d (%2.1f%%)   %s"
                % (
                    total - remaining,
                    total,
                    # +1e-12 avoids division by zero when total == 0
                    100 - 100 * remaining / (total + 1e-12),
                    datetime.datetime.strftime(
                        datetime.datetime.now(), "%Y-%m-%d %H:%M:%S"
                    ),
                ),
                flush=True,
            )
        return remaining, total
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc