• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

netzkolchose / django-computedfields / 16733364684

04 Aug 2025 08:21PM UTC coverage: 94.464% (-1.3%) from 95.75%
16733364684

Pull #197

github

web-flow
Merge 033fa8d63 into 1259b59f9
Pull Request #197: performance optimizations

546 of 594 branches covered (91.92%)

Branch coverage included in aggregate %.

72 of 94 new or added lines in 3 files covered. (76.6%)

20 existing lines in 1 file now uncovered.

1331 of 1393 relevant lines covered (95.55%)

11.46 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.38
/computedfields/resolver.py
1
"""
2
Contains the resolver logic for automated computed field updates.
3
"""
4
from .thread_locals import get_not_computed_context, set_not_computed_context
12✔
5
import operator
12✔
6
from functools import reduce
12✔
7
from collections import defaultdict
12✔
8

9
from django.db import transaction
12✔
10
from django.db.models import QuerySet
12✔
11

12
from .settings import settings
12✔
13
from .graph import ComputedModelsGraph, ComputedFieldsException, Graph, ModelGraph, IM2mMap
12✔
14
from .helpers import proxy_to_base_model, slice_iterator, subquery_pk, are_same, frozenset_none
12✔
15
from . import __version__
12✔
16
from .signals import resolver_start, resolver_exit, resolver_update
12✔
17

18
from . import backends
12✔
19

20
# typing imports
21
from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional, Sequence, Set,
12✔
22
                    Tuple, Type, Union, cast, overload, FrozenSet)
23
from django.db.models import Field, Model
12✔
24
from .graph import (IComputedField, IDepends, IFkMap, ILocalMroMap, ILookupMap, _ST, _GT, F,
12✔
25
                    IRecorded, IRecordedStrict, IModelUpdate, IModelUpdateCache)
26

27

28
MALFORMED_DEPENDS = """
12✔
29
Your depends keyword argument is malformed.
30

31
The depends keyword should either be None, an empty listing or
32
a listing of rules as depends=[rule1, rule2, .. ruleN].
33

34
A rule is formed as ('relation.path', ['list', 'of', 'fieldnames']) tuple.
35
The relation path either contains 'self' for fieldnames on the same model,
36
or a string as 'a.b.c', where 'a' is a relation on the current model
37
descending over 'b' to 'c' to pull fieldnames from 'c'. The denoted fieldnames
38
must be concrete fields on the rightmost model of the relation path.
39

40
Example:
41
depends=[
42
    ('self', ['name', 'status']),
43
    ('parent.color', ['value'])
44
]
45
This has 2 path rules - one for fields 'name' and 'status' on the same model,
46
and one to a field 'value' on a foreign model, which is accessible from
47
the current model through self -> parent -> color relation.
48
"""
49

50

51
class ResolverException(ComputedFieldsException):
    """
    Exception raised during model and field registration or dependency resolving.

    The resolver raises it e.g. when models or computed fields get added
    to an already sealed resolver, or when resolver data is accessed
    before the resolver reached the needed state.
    """
55

56

57
class Resolver:
12✔
58
    """
59
    Holds the needed data for graph calculations and runtime dependency resolving.
60

61
    Basic workflow:
62

63
        - On django startup a resolver gets instantiated early to track all project-wide
64
          model registrations and computed field decorations (collector phase).
65
        - On `app.ready` the computed fields are associated with their models to build
66
          a resolver-wide map of models with computed fields (``computed_models``).
67
        - After that the resolver maps are created (see `graph.ComputedModelsGraph`).
68
    """
69

70
    def __init__(self):
        """
        Create an empty resolver in collector phase (not sealed,
        not initialized, no maps loaded).
        """
        # collector phase data
        #: Models from `class_prepared` signal hook during collector phase.
        self.models: Set[Type[Model]] = set()
        #: Computed fields found during collector phase.
        self.computedfields: Set[IComputedField] = set()

        # resolving phase data and final maps
        # (graph and maps get populated by ``load_maps``,
        # ``_computed_models`` by ``initialize``)
        self._graph: Optional[ComputedModelsGraph] = None
        self._computed_models: Dict[Type[Model], Dict[str, IComputedField]] = {}
        self._map: ILookupMap = {}
        self._fk_map: IFkMap = {}
        self._local_mro: ILocalMroMap = {}
        self._m2m: IM2mMap = {}
        # proxy model --> base model, filled by ``_patch_proxy_models``
        self._proxymodels: Dict[Type[Model], Type[Model]] = {}
        self._batchsize: int = settings.COMPUTEDFIELDS_BATCHSIZE
        # the update backend is looked up by its name on the ``backends`` module
        self._update_backend: str = settings.COMPUTEDFIELDS_UPDATE_BACKEND
        self._update = getattr(backends, self._update_backend)

        # some internal states
        self._sealed: bool = False        # initial boot phase
        self._initialized: bool = False   # initialized (computed_models populated)?
        self._map_loaded: bool = False    # final stage with fully loaded maps

        # runtime caches
        # (all get emptied by ``_clear_runtime_caches``)
        # ``_cached_updates`` and ``_cached_mro`` are keyed by model,
        # then by the frozenset of update fields (or None)
        self._cached_updates: IModelUpdateCache = defaultdict(dict)
        self._cached_mro = defaultdict(dict)
        # NOTE(review): the key layout of the following caches is defined by
        # their accessor methods, which are not visible in this part of the file
        self._cached_select_related = defaultdict(dict)
        self._cached_prefetch_related = defaultdict(dict)
        self._cached_querysize = defaultdict(lambda: defaultdict(dict))
12✔
100

101
    def add_model(self, sender: Type[Model], **kwargs) -> None:
12✔
102
        """
103
        `class_prepared` signal hook to collect models during ORM registration.
104
        """
105
        if self._sealed:
12✔
106
            raise ResolverException('cannot add models on sealed resolver')
12✔
107
        self.models.add(sender)
12✔
108

109
    def add_field(self, field: IComputedField) -> None:
12✔
110
        """
111
        Collects fields from decoration stage of @computed.
112
        """
113
        if self._sealed:
12✔
114
            raise ResolverException('cannot add computed fields on sealed resolver')
12✔
115
        self.computedfields.add(field)
12✔
116

117
    def seal(self) -> None:
12✔
118
        """
119
        Seal the resolver, so no new models or computed fields can be added anymore.
120

121
        This marks the end of the collector phase and is a basic security measure
122
        to catch runtime model creations with computed fields.
123

124
        (Currently runtime creation of models with computed fields is not supported,
125
        trying to do so will raise an exception. This might change in future versions.)
126
        """
127
        self._sealed = True
12✔
128

129
    @property
12✔
130
    def models_with_computedfields(self) -> Generator[Tuple[Type[Model], Set[IComputedField]], None, None]:
12✔
131
        """
132
        Generator of tracked models with their computed fields.
133

134
        This cannot be accessed during the collector phase.
135
        """
136
        if not self._sealed:
12✔
137
            raise ResolverException('resolver must be sealed before accessing models or fields')
12✔
138

139
        field_ids: List[int] = [f.creation_counter for f in self.computedfields]
12✔
140
        for model in self.models:
12✔
141
            fields = set()
12✔
142
            for field in model._meta.fields:
12✔
143
                # for some reason the in ... check does not work for Django >= 3.2 anymore
144
                # workaround: check for _computed and the field creation_counter
145
                if hasattr(field, '_computed') and field.creation_counter in field_ids:
12✔
146
                    fields.add(field)
12✔
147
            if fields:
12✔
148
                yield (model, cast(Set[IComputedField], fields))
12✔
149

150
    @property
12✔
151
    def computedfields_with_models(self) -> Generator[Tuple[IComputedField, Set[Type[Model]]], None, None]:
12✔
152
        """
153
        Generator of tracked computed fields and their models.
154

155
        This cannot be accessed during the collector phase.
156
        """
157
        if not self._sealed:
12✔
158
            raise ResolverException('resolver must be sealed before accessing models or fields')
12✔
159

160
        for field in self.computedfields:
12✔
161
            models = set()
12✔
162
            for model in self.models:
12✔
163
                for f in model._meta.fields:
12✔
164
                    if hasattr(field, '_computed') and f.creation_counter == field.creation_counter:
12✔
165
                        models.add(model)
12✔
166
            yield (field, models)
12✔
167

168
    @property
12✔
169
    def computed_models(self) -> Dict[Type[Model], Dict[str, IComputedField]]:
12✔
170
        """
171
        Mapping of `ComputedFieldModel` models and their computed fields.
172

173
        The data is the single source of truth for the graph reduction and
174
        map creations. Thus it can be used to decide at runtime whether
175
        the active resolver respects a certain model with computed fields.
176
        
177
        .. NOTE::
178
        
179
            The resolver will only list models here, that actually have
180
            a computed field defined. A model derived from `ComputedFieldsModel`
181
            without a computed field will not be listed.
182
        """
183
        if self._initialized:
12✔
184
            return self._computed_models
12✔
185
        raise ResolverException('resolver is not properly initialized')
12✔
186

187
    def extract_computed_models(self) -> Dict[Type[Model], Dict[str, IComputedField]]:
12✔
188
        """
189
        Creates `computed_models` mapping from models and computed fields
190
        found in collector phase.
191
        """
192
        computed_models: Dict[Type[Model], Dict[str, IComputedField]] = {}
12✔
193
        for model, computedfields in self.models_with_computedfields:
12✔
194
            if not issubclass(model, _ComputedFieldsModelBase):
12✔
195
                raise ResolverException(f'{model} is not a subclass of ComputedFieldsModel')
12✔
196
            computed_models[model] = {}
12✔
197
            for field in computedfields:
12✔
198
                computed_models[model][field.name] = field
12✔
199

200
        return computed_models
12✔
201

202
    def initialize(self, models_only: bool = False) -> None:
12✔
203
        """
204
        Entrypoint for ``app.ready`` to seal the resolver and trigger
205
        the resolver map creation.
206

207
        Upon instantiation the resolver is in the collector phase, where it tracks
208
        model registrations and computed field decorations.
209

210
        After calling ``initialize`` no more models or fields can be registered
211
        to the resolver, and ``computed_models`` and the resolver maps get loaded.
212
        """
213
        # resolver must be sealed before doing any map calculations
214
        self.seal()
12✔
215
        self._computed_models = self.extract_computed_models()
12✔
216
        self._initialized = True
12✔
217
        if not models_only:
12✔
218
            self.load_maps()
12✔
219

220
    def load_maps(self, _force_recreation: bool = False) -> None:
        """
        Load all needed resolver maps. The steps are:

            - create intermodel graph of the dependencies
            - remove redundant paths with cycling check
            - create modelgraphs for local MRO
            - merge graphs to uniongraph with cycling check
            - create final resolver maps

                - `lookup_map`: intermodel dependencies as queryset access strings
                - `fk_map`: models with their contributing fk fields
                - `local_mro`: MRO of local computed fields per model

        The edgepath evaluation on the graphs is skipped, if the setting
        ``COMPUTEDFIELDS_ALLOW_RECURSION`` is set. Finally proxy models get patched
        into the maps and all runtime caches get cleared.

        `_force_recreation` is currently not evaluated by this method.
        """
        graph = ComputedModelsGraph(self.computed_models)
        self._graph = graph

        allow_recursion = getattr(settings, 'COMPUTEDFIELDS_ALLOW_RECURSION', False)
        if not allow_recursion:
            graph.get_edgepaths()
            graph.get_uniongraph().get_edgepaths()

        lookup_map, fk_map = graph.generate_maps()
        self._map = lookup_map
        self._fk_map = fk_map
        self._local_mro = graph.generate_local_mro_map()
        self._m2m = graph._m2m

        self._patch_proxy_models()
        self._map_loaded = True
        self._clear_runtime_caches()
12✔
244

245
    def _clear_runtime_caches(self):
12✔
246
        """
247
        Clear all runtime caches.
248
        """
249
        self._cached_updates.clear()
12✔
250
        self._cached_mro.clear()
12✔
251
        self._cached_select_related.clear()
12✔
252
        self._cached_prefetch_related.clear()
12✔
253
        self._cached_querysize.clear()
12✔
254

255
    def _patch_proxy_models(self) -> None:
12✔
256
        """
257
        Patch proxy models into the resolver maps.
258
        """
259
        for model in self.models:
12✔
260
            if model._meta.proxy:
12✔
261
                basemodel = proxy_to_base_model(model)
12✔
262
                if basemodel in self._map:
12✔
263
                    self._map[model] = self._map[basemodel]
12✔
264
                if basemodel in self._fk_map:
12✔
265
                    self._fk_map[model] = self._fk_map[basemodel]
12✔
266
                if basemodel in self._local_mro:
12✔
267
                    self._local_mro[model] = self._local_mro[basemodel]
12✔
268
                if basemodel in self._m2m:
12!
UNCOV
269
                    self._m2m[model] = self._m2m[basemodel]
×
270
                self._proxymodels[model] = basemodel or model
12✔
271

272
    def get_local_mro(
12✔
273
        self,
274
        model: Type[Model],
275
        update_fields: Optional[FrozenSet[str]] = None
276
    ) -> List[str]:
277
        """
278
        Return `MRO` for local computed field methods for a given set of `update_fields`.
279
        The returned list of fieldnames must be calculated in order to correctly update
280
        dependent computed field values in one pass.
281

282
        Returns computed fields as self dependent to simplify local field dependency calculation.
283
        """
284
        try:
12✔
285
            return self._cached_mro[model][update_fields]
12✔
286
        except KeyError:
12✔
287
            pass
12✔
288
        entry = self._local_mro.get(model)
12✔
289
        if not entry:
12✔
290
            self._cached_mro[model][update_fields] = []
12✔
291
            return []
12✔
292
        if update_fields is None:
12✔
293
            self._cached_mro[model][update_fields] = entry['base']
12✔
294
            return entry['base']
12✔
295
        base = entry['base']
12✔
296
        fields = entry['fields']
12✔
297
        mro = 0
12✔
298
        for field in update_fields:
12✔
299
            mro |= fields.get(field, 0)
12✔
300
        result = [name for pos, name in enumerate(base) if mro & (1 << pos)]
12✔
301
        self._cached_mro[model][update_fields] = result
12✔
302
        return result
12✔
303

304
    def get_model_updates(
12✔
305
        self,
306
        model: Type[Model],
307
        update_fields: Optional[FrozenSet[str]] = None
308
    ) -> IModelUpdate:
309
        """
310
        For a given model and updated fields this method
311
        returns a dictionary with dependent models (keys) and a tuple
312
        with dependent fields and the queryset accessor string (value).
313
        """
314
        try:
12✔
315
            return self._cached_updates[model][update_fields]
12✔
316
        except KeyError:
12✔
317
            pass
12✔
318
        modeldata = self._map.get(model)
12✔
319
        if not modeldata:
12✔
320
            self._cached_updates[model][update_fields] = {}
12✔
321
            return {}
12✔
322
        if not update_fields:
12✔
323
            updates: Set[str] = set(modeldata.keys())
12✔
324
        else:
325
            updates = set()
12✔
326
            for fieldname in update_fields:
12✔
327
                if fieldname in modeldata:
12✔
328
                    updates.add(fieldname)
12✔
329
        model_updates: IModelUpdate = defaultdict(lambda: (set(), set()))
12✔
330
        for update in updates:
12✔
331
            # aggregate fields and paths to cover
332
            # multiple comp field dependencies
333
            for m, r in modeldata[update].items():
12✔
334
                fields, paths = r
12✔
335
                m_fields, m_paths = model_updates[m]
12✔
336
                m_fields.update(fields)
12✔
337
                m_paths.update(paths)
12✔
338
        self._cached_updates[model][update_fields] = model_updates
12✔
339
        return model_updates
12✔
340

341
    def _querysets_for_update(
        self,
        model: Type[Model],
        instance: Union[Model, QuerySet],
        update_fields: Optional[Iterable[str]] = None,
        pk_list: bool = False,
    ) -> Dict[Type[Model], List[Any]]:
        """
        Returns a mapping of all dependent models, dependent fields and a
        queryset containing all dependent objects.

        The mapping has the form ``{model: [queryset, fields]}``. With `pk_list` set,
        the queryset is replaced by a set of its pks and models without any
        dependent records are left out.

        `fields` is always a fresh set, which is safe to be altered by the caller.
        """
        final: Dict[Type[Model], List[Any]] = {}
        model_updates = self.get_model_updates(model, frozenset_none(update_fields))
        if not model_updates:
            return final

        subquery = '__in' if isinstance(instance, QuerySet) else ''
        # fix #100
        # mysql does not support 'LIMIT & IN/ALL/ANY/SOME subquery'
        # thus we extract pks explicitly instead
        real_inst: Union[Model, QuerySet, Set[Any]] = instance
        if isinstance(instance, QuerySet):
            from django.db import connections
            if not instance.query.can_filter() and connections[instance.db].vendor == 'mysql':
                real_inst = set(instance.values_list('pk', flat=True).iterator())

        # generate narrowed down querysets for all cf dependencies
        for m, data in model_updates.items():
            fields, paths = data
            queryset: Union[QuerySet, Set[Any]] = m._base_manager.none()
            query_pipe_method = self._choose_optimal_query_pipe_method(paths)
            queryset = reduce(
                query_pipe_method,
                (m._base_manager.filter(**{path+subquery: real_inst}) for path in paths),
                queryset
            )
            if pk_list:
                # need pks for post_delete since the real queryset will be empty
                # after deleting the instance in question
                # since we need to interact with the db anyways
                # we can already drop empty results here
                queryset = set(queryset.values_list('pk', flat=True).iterator())
                if not queryset:
                    continue
            # hand out a copy of the fields - model_updates is a cached object
            # and bulk_updater extends the given fields in place
            # FIXME: change to tuple or dict for narrower type
            final[m] = [queryset, set(fields)]
        return final
12✔
388
    
389
    def _get_model(self, instance: Union[Model, QuerySet]) -> Type[Model]:
        """
        Return the model class of `instance` (the queryset's model for querysets,
        else the type of the instance).
        """
        if isinstance(instance, QuerySet):
            return instance.model
        return type(instance)
12✔
391

392
    def _choose_optimal_query_pipe_method(self, paths: Set[str]) -> Callable:
        """
        Choose the optimal pipe method to combine querysets.

        Returns `|` (``operator.or_``), if there is only one path, or if all paths
        have the same number of ``__`` separated parts and differ in the last part only
        (different field names on the same relation path).
        Otherwise a callable applying ``union`` is returned.
        """
        def union(x, y):
            return x.union(y)

        if len(paths) == 1:
            return operator.or_

        split_paths = tuple(path.split("__") for path in paths)
        if not are_same(*(len(parts) for parts in split_paths)):
            return union

        last_depth = len(split_paths[0]) - 1
        for depth, parts_at_depth in enumerate(zip(*split_paths)):
            if are_same(*parts_at_depth):
                continue
            # first difference found - only a difference in the last part
            # qualifies for the cheaper OR combination
            return operator.or_ if depth == last_depth else union
        return union
12✔
413

414
    def preupdate_dependent(
        self,
        instance: Union[QuerySet, Model],
        model: Optional[Type[Model]] = None,
        update_fields: Optional[Iterable[str]] = None,
    ) -> Dict[Type[Model], List[Any]]:
        """
        Create a mapping of currently associated computed field records,
        that might turn dirty by a follow-up bulk action.

        Feed the mapping back to ``update_dependent`` as `old` argument
        after your bulk action to update de-associated computed field records as well.

        Within a not_computed context the method always returns an empty mapping.
        If the context has `recover` set, the mapping gets recorded on the context instead.
        """
        ctx = get_not_computed_context()

        # exit empty, if we are in not_computed context without recovery
        # (done upfront to avoid the pk queries, as their result would be dropped anyway)
        if ctx and not ctx.recover:
            return {}

        result = self._querysets_for_update(
            model or self._get_model(instance), instance, update_fields, pk_list=True)

        # exit empty, if we are in not_computed context
        if ctx:
            if result:
                ctx.record_querysets(result)
            return {}
        return result
12✔
436

437
    def update_dependent(
        self,
        instance: Union[QuerySet, Model],
        model: Optional[Type[Model]] = None,
        update_fields: Optional[Iterable[str]] = None,
        old: Optional[Dict[Type[Model], List[Any]]] = None,
        update_local: bool = True,
        querysize: Optional[int] = None,
        _is_recursive: bool = False
    ) -> None:
        """
        Updates all dependent computed fields on related models traversing
        the dependency tree as shown in the graphs.

        This is the main entry hook of the resolver to do updates on dependent
        computed fields during runtime. While this is done automatically for
        model instance actions from signal handlers, you have to call it yourself
        after changes done by bulk actions.

        To do that, simply call this function after the update with the queryset
        containing the changed objects:

            >>> Entry.objects.filter(pub_date__year=2010).update(comments_on=False)
            >>> update_dependent(Entry.objects.filter(pub_date__year=2010))

        This can also be used with ``bulk_create``. Since ``bulk_create``
        returns the objects in a python container, you have to create the queryset
        yourself, e.g. with pks:

            >>> objs = Entry.objects.bulk_create([
            ...     Entry(headline='This is a test'),
            ...     Entry(headline='This is only a test'),
            ... ])
            >>> pks = set(obj.pk for obj in objs)
            >>> update_dependent(Entry.objects.filter(pk__in=pks))

        .. NOTE::

            Getting pks from ``bulk_create`` is not supported by all database adapters.
            With a local computed field you can "cheat" here by providing a sentinel:

                >>> MyComputedModel.objects.bulk_create([
                ...     MyComputedModel(comp='SENTINEL'), # here or as default field value
                ...     MyComputedModel(comp='SENTINEL'),
                ... ])
                >>> update_dependent(MyComputedModel.objects.filter(comp='SENTINEL'))

            If the sentinel is beyond reach of the method result, this even ensures to update
            only the newly added records.

        `instance` can also be a single model instance. Since calling ``save`` on a model instance
        will trigger this function by the `post_save` signal already it should not be called
        for single instances, if they get saved anyway.

        `update_fields` can be used to indicate, that only certain fields on the queryset changed,
        which helps to further narrow down the records to be updated.

        Special care is needed, if a bulk action contains foreign key changes,
        that are part of a computed field dependency chain. To correctly handle that case,
        provide the result of ``preupdate_dependent`` as `old` argument like this:

                >>> # given: some computed fields model depends somehow on Entry.fk_field
                >>> old_relations = preupdate_dependent(Entry.objects.filter(pub_date__year=2010))
                >>> Entry.objects.filter(pub_date__year=2010).update(fk_field=new_related_obj)
                >>> update_dependent(Entry.objects.filter(pub_date__year=2010), old=old_relations)

        `update_local=False` disables model local computed field updates of the entry node.
        (used as optimization during tree traversal). You should not disable it yourself.

        `querysize` is handed over to ``bulk_updater`` for all queryset updates
        done by this call.

        `_is_recursive` is set for calls coming from the tree traversal itself. Those calls
        skip the resolver start/exit signals, the transaction handling and the `old` records.

        Within a not_computed context no updates are done at all. If the context
        has `recover` set, the update request gets recorded on the context instead.
        """
        _model = model or self._get_model(instance)

        # bulk_updater might change fields, ensure we have set/None
        _update_fields = None if update_fields is None else set(update_fields)

        # exit early if we are in not_computed context
        if ctx := get_not_computed_context():
            if ctx.recover:
                ctx.record_update(instance, _model, _update_fields)
            return

        # Note: update_local is always off for updates triggered from the resolver
        # but True by default to avoid accidentally skipping updates called by user
        if update_local and self.has_computedfields(_model):
            # We skip a transaction here in the same sense,
            # as local cf updates are not guarded either.
            # FIXME: signals are broken here...
            if isinstance(instance, QuerySet):
                self.bulk_updater(instance, _update_fields, local_only=True, querysize=querysize)
            else:
                self.single_updater(_model, instance, _update_fields)

        updates = self._querysets_for_update(_model, instance, _update_fields).values()
        if not updates:
            return

        if _is_recursive:
            # descent during tree traversal - signals and transaction
            # are handled by the outermost call
            for queryset, fields in updates:
                self.bulk_updater(queryset, fields, return_pks=False, querysize=querysize)
            return

        resolver_start.send(sender=self)
        with transaction.atomic():
            pks_updated: Dict[Type[Model], Set[Any]] = {}
            for queryset, fields in updates:
                _pks = self.bulk_updater(queryset, fields, return_pks=True, querysize=querysize)
                if _pks:
                    pks_updated[queryset.model] = _pks
            if old:
                for model2, data in old.items():
                    pks, fields = data
                    # only old records not updated above need another run
                    # the pks were collected through _base_manager, thus use it here as well
                    # (the default manager might hide records or not be named `objects`)
                    queryset = model2._base_manager.filter(
                        pk__in=pks-pks_updated.get(model2, set()))
                    self.bulk_updater(queryset, fields, querysize=querysize)
        resolver_exit.send(sender=self)
12✔
548

549
    def single_updater(
        self,
        model,
        instance,
        update_fields
    ):
        """
        Update the local computed fields on a single model instance.

        The computed fields are recalculated in local `MRO` order for `update_fields`
        (None for all). Changed values are set on `instance` and get written
        to the database by the update backend, followed by a `resolver_update` signal.
        Without any changed value nothing is written or signalled.

        If `update_fields` is a non-empty set, it gets extended in place
        by the names of the recalculated computed fields.
        """
        # TODO: needs a couple of tests, proper typing and doc
        cf_mro = self.get_local_mro(model, frozenset_none(update_fields))
        if update_fields:
            update_fields.update(cf_mro)
        changed = []
        for fieldname in cf_mro:
            old_value = getattr(instance, fieldname)
            new_value = self._compute(instance, model, fieldname)
            if new_value != old_value:
                changed.append(fieldname)
                # set the value early, as later fields in the MRO might depend on it
                setattr(instance, fieldname, new_value)
        if changed:
            # use _base_manager in line with bulk_updater
            # (the default manager might hide records or not be named `objects`)
            self._update(model._base_manager.all(), [instance], changed)
            resolver_update.send(sender=self, model=model, fields=changed, pks=[instance.pk])
×
569

570
    def bulk_updater(
        self,
        queryset: QuerySet,
        update_fields: Optional[Set[str]] = None,
        return_pks: bool = False,
        local_only: bool = False,
        querysize: Optional[int] = None
    ) -> Optional[Set[Any]]:
        """
        Update local computed fields and descent in the dependency tree by calling
        ``update_dependent`` for dependent models.

        This method does the local field updates on `queryset`:

            - eval local `MRO` of computed fields
            - expand `update_fields`
            - apply optional `select_related` and `prefetch_related` rules to `queryset`
            - walk all records and recalculate fields in `update_fields`
            - aggregate changeset and save as batched `bulk_update` to the database

        By default this method triggers the update of dependent models by calling
        ``update_dependent`` with `update_fields` (next level of tree traversal).
        This can be suppressed by setting `local_only=True`.

        If `return_pks` is set, the method returns a set of altered pks of `queryset`.

        `querysize` is handed over to ``get_querysize`` and to the update
        of dependent models (next level of tree traversal).

        Note that a non-empty `update_fields` set is extended in place
        by the fieldnames of the local `MRO`.
        """
        model: Type[Model] = queryset.model

        # distinct issue workaround
        # the workaround is needed for already sliced/distinct querysets coming from outside
        # TODO: distinct is a major query perf smell, and is in fact only needed on back relations
        #       may need some rework in _querysets_for_update
        #       ideally we find a way to avoid it for forward relations
        #       also see #101
        if queryset.query.can_filter() and not queryset.query.distinct_fields:
            if queryset.query.combinator != "union":
                queryset = queryset.distinct()
        else:
            queryset = model._base_manager.filter(pk__in=subquery_pk(queryset, queryset.db))

        # correct update_fields by local mro
        mro: List[str] = self.get_local_mro(model, frozenset_none(update_fields))
        fields = frozenset(mro)
        if update_fields:
            update_fields.update(fields)

        # fix #167: skip prefetch/select if union was used
        # fix #193: if select or prefetch is set, extract pks on UNIONed queryset
        select = self.get_select_related(model, fields)
        prefetch = self.get_prefetch_related(model, fields)
        if (select or prefetch) and queryset.query.combinator == "union":
            queryset = model._base_manager.filter(pk__in=subquery_pk(queryset, queryset.db))
        if select:
            queryset = queryset.select_related(*select)
        if prefetch:
            queryset = queryset.prefetch_related(*prefetch)

        pks = []
        if fields:
            q_size = self.get_querysize(model, fields, querysize)
            changed_objs: List[Model] = []
            for elem in slice_iterator(queryset, q_size):
                # note on the loop: while it is technically not needed to batch things here,
                # we still prebatch to not cause memory issues for very big querysets
                has_changed = False
                for comp_field in mro:
                    new_value = self._compute(elem, model, comp_field)
                    if new_value != getattr(elem, comp_field):
                        has_changed = True
                        setattr(elem, comp_field, new_value)
                if has_changed:
                    changed_objs.append(elem)
                    pks.append(elem.pk)
                if len(changed_objs) >= self._batchsize:
                    self._update(model._base_manager.all(), changed_objs, fields)
                    changed_objs = []
            # write the remaining records of the last incomplete batch
            if changed_objs:
                self._update(model._base_manager.all(), changed_objs, fields)

            if pks:
                resolver_update.send(sender=self, model=model, fields=fields, pks=pks)

        # trigger dependent comp field updates from changed records
        # other than before we exit the update tree early, if we have no changes at all
        # also cuts the update tree for recursive deps (tree-like)
        if not local_only and pks:
            self.update_dependent(
                instance=model._base_manager.filter(pk__in=pks),
                model=model,
                update_fields=fields,
                update_local=False,
                # keep a querysize override active on deeper tree levels
                querysize=querysize,
                _is_recursive=True
            )
        return set(pks) if return_pks else None
12✔
664

665
    def _compute(self, instance: Model, model: Type[Model], fieldname: str) -> Any:
        """
        Calculate and return the value of the computed field ``fieldname``
        for ``instance``.

        This is only a thin shortcut around the registered compute function,
        it does not care for the local MRO. Use it only, if the MRO is
        respected by other means. To inspect a single computed field value,
        that is about to be written to the database, use ``compute(fieldname)``
        instead.

        For instances not yet stored (still adding or without a pk) the field's
        default is returned, if the field was declared with ``default_on_create``.
        """
        cf = self._computed_models[model][fieldname]
        is_new = instance._state.adding or not instance.pk
        if is_new and cf._computed['default_on_create']:
            # skip the function call for fresh instances, use the field default
            return cf.get_default()
        return cf._computed['func'](instance)
12✔
679

680
    def compute(self, instance: Model, fieldname: str) -> Any:
        """
        Returns the computed field value for ``fieldname``. This method allows
        to inspect the new calculated value, that would be written to the database
        by a following ``save()``.

        Other than calling ``update_computedfields`` on a model instance this call
        is not destructive for old computed field values: interim values set on
        `instance` during the calculation are rewound before returning, also if
        one of the compute functions raises.

        If the not-computed context is active, or ``fieldname`` is not part of the
        local MRO of the instance's model, the current attribute value is returned.
        """
        # Getting a single computed value prehand is quite complicated,
        # as we have to:
        # - resolve local MRO backwards (stored MRO data is optimized for forward deps)
        # - calc all local cfs, that the requested one depends on
        # - stack and rewind interim values, as we dont want to introduce side effects here
        #   (in fact the save/bulker logic might try to save db calls based on changes)
        if get_not_computed_context():
            return getattr(instance, fieldname)
        model = type(instance)
        mro = self.get_local_mro(model, None)
        if fieldname not in mro:
            return getattr(instance, fieldname)
        entries = self._local_mro[model]['fields']
        # bit of the requested field within the per-field dependency bitmasks
        pos = 1 << mro.index(fieldname)
        stack: List[Tuple[str, Any]] = []
        try:
            for field in mro:
                if field == fieldname:
                    return self._compute(instance, model, fieldname)
                if entries.get(field, 0) & pos:
                    # append old value to stack for later rewinding
                    # calc and set new value for field, if the requested one depends on it
                    stack.append((field, getattr(instance, field)))
                    setattr(instance, field, self._compute(instance, model, field))
        finally:
            # reapply old stack values - in ``finally`` to also rewind interim values,
            # if a compute function raised half way through
            for field2, old in stack:
                setattr(instance, field2, old)
12✔
717

718
    def get_select_related(
        self,
        model: Type[Model],
        fields: Optional[FrozenSet[str]] = None
    ) -> Set[str]:
        """
        Return the merged ``select_related`` rules declared on the computed
        fields `fields` of `model` (on all computed fields, if `fields` is ``None``).

        The merged set is cached per model and `fields` value.
        """
        try:
            return self._cached_select_related[model][fields]
        except KeyError:
            pass  # cache miss, build the entry below
        models = self._computed_models
        names = models[model].keys() if fields is None else fields
        # model lookup happens lazily per name, thus an empty `fields`
        # never touches the computed models mapping
        rules: Set[str] = set().union(
            *(models[model][name]._computed['select_related'] for name in names)
        )
        self._cached_select_related[model][fields] = rules
        return rules
12✔
738

739
    def get_prefetch_related(
        self,
        model: Type[Model],
        fields: Optional[FrozenSet[str]] = None
    ) -> List:
        """
        Return the concatenated ``prefetch_related`` rules declared on the computed
        fields `fields` of `model` (on all computed fields, if `fields` is ``None``).

        The resulting list is cached per model and `fields` value.
        """
        try:
            return self._cached_prefetch_related[model][fields]
        except KeyError:
            pass  # cache miss, build the entry below
        names = fields
        if names is None:
            names = frozenset(self._computed_models[model].keys())
        rules: List[Any] = [
            rule
            for name in names
            for rule in self._computed_models[model][name]._computed['prefetch_related']
        ]
        self._cached_prefetch_related[model][fields] = rules
        return rules
12✔
759

760
    def get_querysize(
        self,
        model: Type[Model],
        fields: Optional[FrozenSet[str]] = None,
        override: Optional[int] = None
    ) -> int:
        """
        Return the query size to be used for updates of the computed fields
        `fields` on `model` (all computed fields, if `fields` is ``None``).

        The result is the smallest ``querysize`` declared on the fields, where
        fields without an own declaration count with the base value. The base value
        is `override` if given, else ``settings.COMPUTEDFIELDS_QUERYSIZE``.
        An empty `fields` set yields the base value.

        The result is cached per model, `fields` and `override` value.
        """
        try:
            return self._cached_querysize[model][fields][override]
        except KeyError:
            pass
        ff = fields
        if ff is None:
            ff = frozenset(self._computed_models[model].keys())
        base = settings.COMPUTEDFIELDS_QUERYSIZE if override is None else override
        # default=base: dont fail with ValueError on an empty field set
        result = min(
            (self._computed_models[model][f]._computed['querysize'] or base for f in ff),
            default=base
        )
        self._cached_querysize[model][fields][override] = result
        return result
12✔
777

778
    def get_contributing_fks(self) -> IFkMap:
        """
        Return the mapping of models to their local foreign key fields,
        that take part in a computed fields dependency chain.

        If a bulk action is about to change one of the fields listed here,
        collect the associated records with ``preupdate_dependent`` before the
        bulk change, and hand that listing to ``update_dependent`` as `old`
        argument after the bulk change.

        With ``COMPUTEDFIELDS_ADMIN = True`` in `settings.py` the mapping can also
        be inspected as admin view.

        Raises ``ResolverException``, if the resolver maps are not loaded yet.
        """
        if self._map_loaded:
            return self._fk_map
        raise ResolverException('resolver has no maps loaded yet')  # pragma: no cover
12✔
794

795
    def _sanity_check(self, field: Field, depends: IDepends) -> None:
        """
        Basic type check for computed field arguments `field` and `depends`.
        This only checks for proper type alignment (most crude source of errors) to give
        devs an early startup error for misconfigured computed fields.
        More subtle errors like non-existing paths or fields are caught
        by the resolver during graph reduction yielding somewhat cryptic error messages.

        There is another class of misconfigured computed fields we currently cannot
        find by any safety measures - if `depends` provides valid paths and fields,
        but the function operates on different dependencies. Currently it is the devs'
        responsibility to perfectly align `depends` entries with dependencies
        used by the function to avoid faulty update behavior.

        Raises ``ResolverException``, if `field` is not a ``Field`` instance or
        a rule in `depends` is not a ``(path, fieldnames)`` pair of a string
        and a listing of strings.
        """
        if not isinstance(field, Field):
            raise ResolverException('field argument is not a Field instance')
        for rule in depends:
            try:
                path, fieldnames = rule
                well_formed = (
                    isinstance(path, str)
                    and all(isinstance(f, str) for f in fieldnames)
                )
            except (TypeError, ValueError) as exc:
                # ValueError: rule has the wrong number of items
                # TypeError: rule or fieldnames is not iterable at all
                raise ResolverException(MALFORMED_DEPENDS) from exc
            if not well_formed:
                raise ResolverException(MALFORMED_DEPENDS)
×
818

819
    def computedfield_factory(
        self,
        field: 'Field[_ST, _GT]',
        compute: Callable[..., _ST],
        depends: Optional[IDepends] = None,
        select_related: Optional[Sequence[str]] = None,
        prefetch_related: Optional[Sequence[Any]] = None,
        querysize: Optional[int] = None,
        default_on_create: Optional[bool] = False
    ) -> 'Field[_ST, _GT]':
        """
        Turn `field` into a computed field calculated by `compute`.

        The method is exposed as ``ComputedField`` for a more declarative code
        style, where field declarations and function implementations are kept
        apart. The ``computed`` decorator is built on top of it.
        As with the decorator, ``compute`` is called with a single argument,
        the model instance of the model the field got applied to.

        The declaration is stored as ``_computed`` on the field, the field is
        marked as not editable and registered with ``add_field``.
        The very same `field` object is returned.

        Usage example:

        .. code-block:: python

            from computedfields.models import ComputedField

            def calc_mul(inst):
                return inst.a * inst.b

            class MyModel(ComputedFieldsModel):
                a = models.IntegerField()
                b = models.IntegerField()
                sum = ComputedField(
                    models.IntegerField(),
                    depends=[('self', ['a', 'b'])],
                    compute=lambda inst: inst.a + inst.b
                )
                mul = ComputedField(
                    models.IntegerField(),
                    depends=[('self', ['a', 'b'])],
                    compute=calc_mul
                )
        """
        rules = depends or []
        self._sanity_check(field, rules)
        cf = cast('IComputedField[_ST, _GT]', field)
        # unset listings are normalized to empty lists
        cf._computed = dict(
            func=compute,
            depends=rules,
            select_related=select_related or [],
            prefetch_related=prefetch_related or [],
            querysize=querysize,
            default_on_create=default_on_create,
        )
        cf.editable = False
        self.add_field(cf)
        return field
12✔
874

875
    def computed(
        self,
        field: 'Field[_ST, _GT]',
        depends: Optional[IDepends] = None,
        select_related: Optional[Sequence[str]] = None,
        prefetch_related: Optional[Sequence[Any]] = None,
        querysize: Optional[int] = None,
        default_on_create: Optional[bool] = False
    ) -> Callable[[Callable[..., _ST]], 'Field[_ST, _GT]']:
        """
        Decorator to turn a model method into a computed field.

        `field` should be a concrete model field instance, that is able to hold
        the result of the decorated method. Dependencies to other model fields
        (local or related) are declared with the keyword argument `depends`.
        Changes to listed dependencies will automatically update the computed field.

        Examples:

            - a char field without any dependencies (not very useful)

            .. code-block:: python

                @computed(models.CharField(max_length=32))
                def ...

            - a char field depending on the field ``name`` on a
              foreign key relation ``fk``

            .. code-block:: python

                @computed(models.CharField(max_length=32), depends=[('fk', ['name'])])
                def ...

        Dependencies should be listed as ``['relation_name', concrete_fieldnames]``.
        The relation can span several models, simply name the relation
        in python style with a dot (e.g. ``'a.b.c'``). A relation can be any of
        foreign key, m2m, o2o and their back relations. The fieldnames must point to
        concrete fields on the foreign model.

        .. NOTE::

            Dependencies to model local fields should be listed with ``'self'`` as relation name.

        With `select_related` and `prefetch_related` the dependency resolver can be
        instructed to apply certain optimizations on the update queryset.

        .. NOTE::

            `select_related` and `prefetch_related` are stacked over computed fields
            of the same model during updates, that are marked for update.
            If your optimizations contain custom attributes (as with `to_attr` of a
            `Prefetch` object), these attributes will only be available on instances
            during updates from the resolver, never on newly constructed instances or
            model instances pulled by other means, unless you applied the same lookups manually.

            To keep the computed field methods working under any circumstances,
            it is a good idea not to rely on lookups with custom attributes,
            or to test explicitly for them in the method with an appropriate plan B.

        With `default_on_create` set to ``True`` the function calculation is skipped
        for newly created or copy-cloned instances, the value is taken from the
        inner field's `default` argument instead.

        .. CAUTION::

            With the dependency resolver you can easily create recursive dependencies
            by accident. Imagine the following:

            .. code-block:: python

                class A(ComputedFieldsModel):
                    @computed(models.CharField(max_length=32), depends=[('b_set', ['comp'])])
                    def comp(self):
                        return ''.join(b.comp for b in self.b_set.all())

                class B(ComputedFieldsModel):
                    a = models.ForeignKey(A)

                    @computed(models.CharField(max_length=32), depends=[('a', ['comp'])])
                    def comp(self):
                        return self.a.comp

            Neither an object of `A` or `B` can be saved, since the ``comp`` fields depend on
            each other. While it is quite easy to spot for this simple case it might get tricky
            for more complicated dependencies. Therefore the dependency resolver tries
            to detect cyclic dependencies and might raise a ``CycleNodeException`` during
            startup.

            If you experience this in your project try to get in-depth cycle
            information, either by using the ``rendergraph`` management command or
            by directly accessing the graph objects:

            - intermodel dependency graph: ``active_resolver._graph``
            - model local dependency graphs: ``active_resolver._graph.modelgraphs[your_model]``
            - union graph: ``active_resolver._graph.get_uniongraph()``

            Also see the graph documentation :ref:`here<graph>`.
        """
        # all declaration options are passed through to the factory untouched
        options = {
            'depends': depends,
            'select_related': select_related,
            'prefetch_related': prefetch_related,
            'querysize': querysize,
            'default_on_create': default_on_create,
        }

        def decorate(func: Callable[..., _ST]) -> 'Field[_ST, _GT]':
            return self.computedfield_factory(field, compute=func, **options)

        return decorate
12✔
985

986
    @overload
12✔
987
    def precomputed(self, f: F) -> F:
12✔
UNCOV
988
        ...
×
989
    @overload
12✔
990
    def precomputed(self, skip_after: bool) -> Callable[[F], F]:
12✔
UNCOV
991
        ...
×
992
    def precomputed(self, *dargs, **dkwargs) -> Union[F, Callable[[F], F]]:
12✔
993
        """
994
        Decorator for custom ``save`` methods, that expect local computed fields
995
        to contain already updated values on enter.
996

997
        By default local computed field values are only calculated once by the
998
        ``ComputedFieldModel.save`` method after your own save method.
999

1000
        By placing this decorator on your save method, the values will be updated
1001
        before entering your method as well. Note that this comes for the price of
1002
        doubled local computed field calculations (before and after your save method).
1003
        
1004
        To avoid a second recalculation, the decorator can be called with `skip_after=True`.
1005
        Note that this might lead to desychronized computed field values, if you do late
1006
        field changes in your save method without another resync afterwards.
1007
        """
1008
        skip: bool = False
12✔
1009
        func: Optional[F] = None
12✔
1010
        if dargs:
12✔
1011
            if len(dargs) > 1 or not callable(dargs[0]) or dkwargs:
12!
UNCOV
1012
                raise ResolverException('error in @precomputed declaration')
×
1013
            func = dargs[0]
12✔
1014
        else:
1015
            skip = dkwargs.get('skip_after', False)
12✔
1016
        
1017
        def wrap(func: F) -> F:
12✔
1018
            def _save(instance, *args, **kwargs):
12✔
1019
                new_fields = self.update_computedfields(instance, kwargs.get('update_fields'))
12✔
1020
                if new_fields:
12!
UNCOV
1021
                    kwargs['update_fields'] = new_fields
×
1022
                kwargs['skip_computedfields'] = skip
12✔
1023
                return func(instance, *args, **kwargs)
12✔
1024
            return cast(F, _save)
12✔
1025
        
1026
        return wrap(func) if func else wrap
12✔
1027

1028
    def update_computedfields(
        self,
        instance: Model,
        update_fields: Optional[Iterable[str]] = None
    ) -> Optional[Iterable[str]]:
        """
        Recalculate the local computed fields of `instance` and write
        the new values to the instance.

        In contrast to ``compute`` this call is destructive, old computed
        field values on the instance get overwritten.

        Returns ``None`` or an extended set of field names for `update_fields`.
        The returned set might contain additional computed fields, that changed
        as well based on the input fields, thus it should be used as
        `update_fields` on a following save call.

        `update_fields` is returned untouched, if the not-computed context is
        active or the model of `instance` has no computed fields.
        """
        model = type(instance)
        if get_not_computed_context() or not self.has_computedfields(model):
            return update_fields
        cf_mro = self.get_local_mro(model, frozenset_none(update_fields))
        # only a given non-empty update_fields gets extended by the computed fields
        extended = (set(update_fields) | set(cf_mro)) if update_fields else None
        for name in cf_mro:
            setattr(instance, name, self._compute(instance, model, name))
        return extended or None
12✔
1058

1059
    def has_computedfields(self, model: Type[Model]) -> bool:
        """
        Tell whether `model` is registered with computed fields.
        """
        registered = self._computed_models
        return model in registered
12✔
1064

1065
    def get_computedfields(self, model: Type[Model]) -> Iterable[str]:
        """
        Return the names of all computed fields on `model`
        (empty for models without computed fields).
        """
        local_fields = self._computed_models.get(model, {})
        return local_fields.keys()
12✔
1070

1071
    def is_computedfield(self, model: Type[Model], fieldname: str) -> bool:
        """
        Tell whether `fieldname` is a computed field on `model`.
        """
        names = self.get_computedfields(model)
        return fieldname in names
12✔
1076

1077
    def get_graphs(self) -> Tuple[Graph, Dict[Type[Model], ModelGraph], Graph]:
        """
        Return a tuple of all graphs as
        ``(intermodel_graph, {model: modelgraph, ...}, union_graph)``.

        If the resolver holds no graph, a new one is built from the computed
        models of this resolver.
        """
        graph = self._graph
        if not graph:
            # build from our own computed models instead of the global singleton,
            # otherwise a separate resolver instance would report foreign graphs
            graph = ComputedModelsGraph(self.computed_models)
            graph.get_edgepaths()
            graph.get_uniongraph()
        return (graph, graph.modelgraphs, graph.get_uniongraph())
×
1088

1089

1090
# active_resolver is currently treated as a global singleton (used in imports)
1091
#: Currently active resolver.
1092
active_resolver = Resolver()
12✔
1093

1094
# BOOT_RESOLVER: resolver that holds all startup declarations and resolve maps.
1095
# It gets deactivated after startup, thus it is currently not possible to define
1096
# new computed fields and add their resolve rules at runtime.
1097
# TODO: investigate custom resolvers at runtime, that get bootstrapped from BOOT_RESOLVER
1098
#: Resolver used during django bootstrapping.
1099
#: This is currently the same as `active_resolver` (treated as global singleton).
1100
BOOT_RESOLVER = active_resolver
12✔
1101

1102

1103
# placeholder class to test for correct model inheritance
1104
# during initial field resolving
1105
class _ComputedFieldsModelBase:
    """
    Placeholder class to test for correct model inheritance
    during initial field resolving.
    """
12✔
1107

1108

1109
class NotComputed:
    """
    Context to disable all computed field calculations and resolver updates temporarily.

    With *recover=True* the context will track all database relevant actions and update
    affected computed fields on exit of the context.

    Contexts can be nested: only the outermost context registers itself as the
    active one, entering an inner context returns the already active context.
    """
    def __init__(self, recover=False):
        # whether this instance registered itself as active context
        # (and thus has to remove itself on exit)
        self.remove_ctx = True
        self.recover = recover
        # recorded follow-up changesets as model --> {'pks', 'fields'}
        self.qs: IRecordedStrict = defaultdict(lambda: {'pks': set(), 'fields': set()})
        # recorded direct updates as model --> {'pks', 'fields'}, 'fields' might be None
        self.up: IRecorded = defaultdict(lambda: {'pks': set(), 'fields': set()})

    def __enter__(self):
        ctx = get_not_computed_context()
        if ctx:
            # nested usage: keep the outer context in charge
            self.remove_ctx = False
            return ctx
        set_not_computed_context(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        if self.remove_ctx:
            set_not_computed_context(None)
            if self.recover:
                self._resync()
        return False

    def record_querysets(
        self,
        data: Dict[Type[Model], List[Any]]
    ):
        """
        Records the results of a previous _queryset_for_updates call
        (must be called with argument *pk_list=True*).

        Does nothing, if the context was created without *recover=True*.
        """
        if not self.recover:
            return
        for model, mdata in data.items():
            pks, fields = mdata
            entry = self.qs[model]
            entry['pks'] |= pks
            # expand fields (might show a negative perf impact)
            entry['fields'] |= fields

    def record_update(
        self,
        instance: Union[QuerySet, Model],
        model: Type[Model],
        fields: Optional[Set[str]] = None
    ):
        """
        Records any update as typically given to update_dependent.

        Does nothing, if the context was created without *recover=True*.
        """
        if not self.recover:
            return
        entry = self.up[model]
        if isinstance(instance, QuerySet):
            entry['pks'].update(instance.values_list('pk', flat=True))
        else:
            entry['pks'].add(instance.pk)
        # expand fields (might show a negative perf impact)
        # special None handling in fields here is needed to preserve
        # "all" rule from update_dependent on local CF model updates
        if fields is None:
            entry['fields'] = None
        elif entry['fields'] is not None:
            entry['fields'] |= fields

    def _resync(self):
        """
        This method tries to recover from the desync state by replaying the updates
        of the recorded db actions.

        The resync does a flattening on the first update tree level:
        - determine all follow-up changesets as pk lists (next tree level)
        - merge *local_only* CF models with follow-up changesets (limited flattening)
        - update remaining *local_only* CF models
        - update remaining changesets with full descent

        The method currently favours field- and changeset merges over isolated updates.
        The final updates are done the same way as during normal operation (DFS).
        """
        if not self.qs and not self.up:
            return

        # first collect querysets from record_update for later bulk_update
        # this additional pk extraction introduces a tiny perf penalty,
        # but pays off by pk merging
        for model, local_data in self.up.items():

            # for CF models expand the local MRO before getting the querysets
            # FIXME: untangle the side effect update of fields in update_dependent <-- bulk_updater
            fields = local_data['fields']
            if fields and active_resolver.has_computedfields(model):
                fields = set(active_resolver.get_local_mro(model, frozenset(fields)))

            followups = active_resolver._querysets_for_update(
                model,
                model._base_manager.filter(pk__in=local_data['pks']),
                update_fields=fields,
                pk_list=True
            )
            for dep_model, (dep_pks, dep_fields) in followups.items():
                entry = self.qs[dep_model]
                entry['pks'] |= dep_pks
                entry['fields'] |= dep_fields

        # move CF model local_only updates to final changesets, if already there
        for model, local_entry in self.up.items():
            # patch for proxy models (resolver works internally with basemodels only)
            basemodel = proxy_to_base_model(model) if model._meta.proxy else model
            if active_resolver.has_computedfields(model) and basemodel in self.qs:
                final_entry = self.qs[basemodel]
                if local_entry['fields'] is None:
                    # None means "all": expand to the full local MRO
                    final_entry['fields'] = set(active_resolver.get_local_mro(model))
                else:
                    # merge the locally recorded fields into the final changeset
                    final_entry['fields'] |= local_entry['fields']
                final_entry['pks'] |= local_entry['pks']
                # emptied pks exclude the model from the local_only updates below
                local_entry['pks'].clear()

        # finally update all remaining changesets:
        # 1. local_only update for CF models in up
        # 2. all remaining changesets in qs
        resolver_start.send(sender=active_resolver)
        with transaction.atomic():
            for model, local_data in self.up.items():
                if local_data['pks'] and active_resolver.has_computedfields(model):
                    # postponed local_only upd for CFs models
                    # IMPORTANT: must happen before final updates
                    active_resolver.bulk_updater(
                        model._base_manager.filter(pk__in=local_data['pks']),
                        local_data['fields'],
                        local_only=True,
                        querysize=settings.COMPUTEDFIELDS_QUERYSIZE
                    )
            for model, mdata in self.qs.items():
                if mdata['pks']:
                    active_resolver.bulk_updater(
                        model._base_manager.filter(pk__in=mdata['pks']),
                        mdata['fields'],
                        querysize=settings.COMPUTEDFIELDS_QUERYSIZE
                    )
        resolver_exit.send(sender=active_resolver)
12✔
1256

1257

1258
#class NotComputed:
1259
#    """
1260
#    Context to disable all computed field calculations and resolver updates temporarily.
1261
#
1262
#    With *recover=True* the context will track all database relevant actions and update
1263
#    affected computed fields on exit of the context.
1264
#    """
1265
#    def __init__(self, recover=False):
1266
#        self.remove_ctx = True
1267
#        self.recover = recover
1268
#        self.recorded_qs = defaultdict(lambda: defaultdict(lambda: set()))
1269
#        self.recorded_up = defaultdict(lambda: defaultdict(lambda: set()))
1270
#
1271
#    def __enter__(self):
1272
#        ctx = get_not_computed_context()
1273
#        if ctx:
1274
#            self.remove_ctx = False
1275
#            return ctx
1276
#        set_not_computed_context(self)
1277
#        return self
1278
#
1279
#    def __exit__(self, exc_type, exc_value, traceback):
1280
#        if self.remove_ctx:
1281
#            set_not_computed_context(None)
1282
#            if self.recover:
1283
#                self._resync()
1284
#        return False
1285
#
1286
#    def record_querysets(
1287
#        self,
1288
#        data: Dict[Type[Model], List[Any]]
1289
#    ):
1290
#        for model, mdata in data.items():
1291
#            pks, fields = mdata
1292
#            self.recorded_qs[model][frozenset(fields)] |= pks
1293
#
1294
#    def record_update(
1295
#        self,
1296
#        instance: Union[QuerySet, Model],
1297
#        model: Type[Model],
1298
#        fields: Optional[Set[str]] = None
1299
#    ):
1300
#        ff = None if fields is None else frozenset(fields)
1301
#        if isinstance(instance, QuerySet):
1302
#            self.recorded_up[model][ff].update(instance.values_list('pk', flat=True))
1303
#        else:
1304
#            self.recorded_up[model][ff].add(instance.pk)
1305
#
1306
#    def _resync(self):
1307
#        if not self.recorded_qs and not self.recorded_up:
1308
#            return
1309
#
1310
#        # working way: move pks to recorded_qs, if model:fields is already there
1311
#        for model, data in self.recorded_up.items():
1312
#            for fields, pks in data.items():
1313
#                if fields and active_resolver.has_computedfields(model):
1314
#                    fields = set(active_resolver.get_local_mro(model, fields))
1315
#                mdata = active_resolver._querysets_for_update(
1316
#                    model,
1317
#                    model._base_manager.filter(pk__in=pks),
1318
#                    update_fields=fields,
1319
#                    pk_list=True
1320
#                )
1321
#                for qs_model, qs_data in mdata.items():
1322
#                    qs_pks, qs_fields = qs_data
1323
#                    self.recorded_qs[qs_model][frozenset(qs_fields)] |= qs_pks
1324
#
1325
#        resolver_start.send(sender=active_resolver)
1326
#        with transaction.atomic():
1327
#            for model, data in self.recorded_up.items():
1328
#                for fields, pks in data.items():
1329
#                    if active_resolver.has_computedfields(model):
1330
#                        basemodel = proxy_to_base_model(model) if model._meta.proxy else model
1331
#                        ff = frozenset(active_resolver.get_local_mro(model) if fields is None else fields)
1332
#                        if basemodel in self.recorded_qs and ff in self.recorded_qs[basemodel]:
1333
#                            self.recorded_qs[basemodel][ff] |= pks
1334
#                        else:
1335
#                            ff = None if fields is None else set(fields)
1336
#                            active_resolver.bulk_updater(
1337
#                                model._base_manager.filter(pk__in=pks),
1338
#                                ff,
1339
#                                local_only=True,
1340
#                                querysize=settings.COMPUTEDFIELDS_QUERYSIZE,
1341
#                            )
1342
#
1343
#            # attempt with merging into same recorded_qs run
1344
#            # here we would benefit from a topsorted list ;)
1345
#            # FIXME: loop needs a recursion abort
1346
#            recorded_qs = self.recorded_qs
1347
#            while recorded_qs:
1348
#                recorded_up = defaultdict(lambda: defaultdict(lambda: set()))
1349
#                done = defaultdict(lambda: set())
1350
#                for model, data in recorded_qs.items():
1351
#                    for fields, pks in data.items():
1352
#                        pks = active_resolver.bulk_updater(
1353
#                            model._base_manager.filter(pk__in=pks),
1354
#                            None if fields is None else set(fields),
1355
#                            local_only=True,
1356
#                            querysize=settings.COMPUTEDFIELDS_QUERYSIZE,
1357
#                            return_pks=True
1358
#                        )
1359
#                        done[model].add(frozenset(fields))
1360
#                        if pks:
1361
#                            fields = set(active_resolver.get_local_mro(model, fields))
1362
#                            mdata = active_resolver._querysets_for_update(
1363
#                                model,
1364
#                                model._base_manager.filter(pk__in=pks),
1365
#                                update_fields=fields,
1366
#                                pk_list=True
1367
#                            )
1368
#                            for qs_model, qs_data in mdata.items():
1369
#                                qs_pks, qs_fields = qs_data
1370
#                                ff = frozenset(qs_fields)
1371
#                                if (
1372
#                                    qs_model in recorded_qs
1373
#                                    and ff in recorded_qs[qs_model]
1374
#                                    and ff not in done[qs_model]
1375
#                                ):
1376
#                                    recorded_qs[qs_model][ff] |= qs_pks
1377
#                                else:
1378
#                                    recorded_up[qs_model][ff] |= qs_pks
1379
#                                #recorded_up[qs_model][frozenset(qs_fields)] |= qs_pks
1380
#                recorded_qs = recorded_up
1381
#        resolver_exit.send(sender=active_resolver)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc