• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

bagerard / mongoengine / 10550339234

25 Aug 2024 09:45PM UTC coverage: 94.526%. First build
10550339234

push

github

bagerard
Fix pillow deprecation warning related with LANCZOS filter

2 of 3 new or added lines in 1 file covered. (66.67%)

5267 of 5572 relevant lines covered (94.53%)

1.89 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.94
/mongoengine/fields.py
1
import datetime
2✔
2
import decimal
2✔
3
import inspect
2✔
4
import itertools
2✔
5
import re
2✔
6
import socket
2✔
7
import time
2✔
8
import uuid
2✔
9
from inspect import isclass
2✔
10
from io import BytesIO
2✔
11
from operator import itemgetter
2✔
12

13
import gridfs
2✔
14
import pymongo
2✔
15
from bson import SON, Binary, DBRef, ObjectId
2✔
16
from bson.decimal128 import Decimal128, create_decimal128_context
2✔
17
from bson.int64 import Int64
2✔
18
from pymongo import ReturnDocument
2✔
19

20
try:
2✔
21
    import dateutil
2✔
22
except ImportError:
2✔
23
    dateutil = None
2✔
24
else:
25
    import dateutil.parser
×
26

27
from mongoengine.base import (
2✔
28
    BaseDocument,
29
    BaseField,
30
    ComplexBaseField,
31
    GeoJsonBaseField,
32
    LazyReference,
33
    ObjectIdField,
34
    get_document,
35
)
36
from mongoengine.base.utils import LazyRegexCompiler
2✔
37
from mongoengine.common import _import_class
2✔
38
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
2✔
39
from mongoengine.document import Document, EmbeddedDocument
2✔
40
from mongoengine.errors import (
2✔
41
    DoesNotExist,
42
    InvalidQueryError,
43
    ValidationError,
44
)
45
from mongoengine.queryset import DO_NOTHING
2✔
46
from mongoengine.queryset.base import BaseQuerySet
2✔
47
from mongoengine.queryset.transform import STRING_OPERATORS
2✔
48

49
try:
2✔
50
    from PIL import Image, ImageOps
2✔
51

52
    if hasattr(Image, "Resampling"):
2✔
53
        LANCZOS = Image.Resampling.LANCZOS
2✔
54
    else:
NEW
55
        LANCZOS = Image.LANCZOS
×
56
except ImportError:
×
57
    # pillow is optional so may not be installed
58
    Image = None
×
59
    ImageOps = None
×
60

61

62
# Names exported by ``from mongoengine.fields import *`` (order preserved).
__all__ = (
    # string-based fields
    "StringField",
    "URLField",
    "EmailField",
    # numeric / boolean fields
    "IntField",
    "LongField",
    "FloatField",
    "DecimalField",
    "BooleanField",
    # date / time fields
    "DateTimeField",
    "DateField",
    "ComplexDateTimeField",
    # embedded / dynamic / container fields
    "EmbeddedDocumentField",
    "ObjectIdField",
    "GenericEmbeddedDocumentField",
    "DynamicField",
    "ListField",
    "SortedListField",
    "EmbeddedDocumentListField",
    "DictField",
    "MapField",
    # reference fields
    "ReferenceField",
    "CachedReferenceField",
    "LazyReferenceField",
    "GenericLazyReferenceField",
    "GenericReferenceField",
    # binary / GridFS / image
    "BinaryField",
    "GridFSError",
    "GridFSProxy",
    "FileField",
    "ImageGridFsProxy",
    "ImproperlyConfigured",
    "ImageField",
    # geo, sequence, uuid, enum
    "GeoPointField",
    "PointField",
    "LineStringField",
    "PolygonField",
    "SequenceField",
    "UUIDField",
    "EnumField",
    "MultiPointField",
    "MultiLineStringField",
    "MultiPolygonField",
    "GeoJsonBaseField",
    "Decimal128Field",
)

# Passing this string as a document type means "the document that declares the
# field"; EmbeddedDocumentField.document_type resolves it to ``owner_document``.
RECURSIVE_REFERENCE_CONSTANT = "self"
2✔
110

111

112
class StringField(BaseField):
    """A unicode string field.

    :meth:`validate` enforces the optional ``max_length``/``min_length`` bounds
    and ``regex``; :meth:`prepare_query_value` turns the value of the string
    query operators (``contains``, ``istartswith``, ``exact``, ...) into a
    compiled regular expression.
    """

    def __init__(self, regex=None, max_length=None, min_length=None, **kwargs):
        """
        :param regex: (optional) pattern that must match at the start of the value
            (``re.match`` semantics) during validation
        :param max_length: (optional) largest ``len(value)`` accepted during validation
        :param min_length: (optional) smallest ``len(value)`` accepted during validation
        :param kwargs: Keyword arguments passed into the parent :class:`~mongoengine.BaseField`
        """
        if regex:
            self.regex = re.compile(regex)
        else:
            self.regex = None
        self.max_length = max_length
        self.min_length = min_length
        super().__init__(**kwargs)

    def to_python(self, value):
        """Return ``value`` as a ``str`` when possible.

        Strings pass through; anything else gets a best-effort UTF-8 decode
        (e.g. ``bytes``) and is returned unchanged if decoding is impossible.
        """
        if not isinstance(value, str):
            try:
                value = value.decode("utf-8")
            except Exception:
                pass
        return value

    def validate(self, value):
        """Check type, length bounds and the optional regex, in that order."""
        if not isinstance(value, str):
            self.error("StringField only accepts string values")

        upper, lower = self.max_length, self.min_length
        if upper is not None and len(value) > upper:
            self.error("String value is too long")

        if lower is not None and len(value) < lower:
            self.error("String value is too short")

        if self.regex is not None and not self.regex.match(value):
            self.error("String value did not match validation regex")

    def lookup_member(self, member_name):
        """Strings have no sub-fields, so member lookup always yields ``None``."""
        return None

    def prepare_query_value(self, op, value):
        """Compile the value of a string operator into a regex before querying.

        A leading ``i`` on the operator (``icontains``, ``iexact``, ...) makes
        the match case-insensitive. Except for ``regex``/``iregex``, the value
        is escaped with :func:`re.escape` so it is matched literally.
        """
        if not isinstance(op, str):
            return value

        if op in STRING_OPERATORS:
            flags = re.IGNORECASE if op.startswith("i") else 0
            op = op.lstrip("i")

            if op == "regex":
                value = re.compile(value, flags)
            else:
                template = {
                    "startswith": r"^%s",
                    "endswith": r"%s$",
                    "exact": r"^%s$",
                    "wholeword": r"\b%s\b",
                }.get(op, r"%s")
                # escape unsafe characters which could lead to a re.error
                value = re.compile(template % re.escape(value), flags)

        return super().prepare_query_value(op, value)
2✔
181

182

183
class URLField(StringField):
    """A field that validates input as an URL.

    Validation runs the :class:`StringField` checks first (type, optional
    length bounds and regex), then checks the scheme against ``schemes`` and
    finally matches the whole value against ``url_regex``.
    """

    _URL_REGEX = LazyRegexCompiler(
        r"^(?:[a-z0-9\.\-]*)://"  # scheme is validated separately
        r"(?:(?:[A-Z0-9](?:[A-Z0-9-_]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?)|"  # domain...
        r"localhost|"  # localhost...
        r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|"  # ...or ipv4
        r"\[?[A-F0-9]*:[A-F0-9:]+\]?)"  # ...or ipv6
        r"(?::\d+)?"  # optional port
        r"(?:/?|[/?]\S+)$",
        re.IGNORECASE,
    )
    _URL_SCHEMES = ["http", "https", "ftp", "ftps"]

    def __init__(self, url_regex=None, schemes=None, **kwargs):
        """
        :param url_regex: (optional) Overwrite the default regex used for validation
        :param schemes: (optional) Overwrite the default URL schemes that are allowed
        :param kwargs: Keyword arguments passed into the parent :class:`~mongoengine.StringField`
        """
        self.url_regex = url_regex or self._URL_REGEX
        self.schemes = schemes or self._URL_SCHEMES
        super().__init__(**kwargs)

    def validate(self, value):
        # Run the StringField checks first. This rejects non-string values with
        # a ValidationError (instead of an AttributeError from value.split) and
        # enforces max_length/min_length/regex, which __init__ accepts via
        # kwargs but were previously never checked.
        super().validate(value)

        # Check first if the scheme is valid
        scheme = value.split("://")[0].lower()
        if scheme not in self.schemes:
            self.error(f"Invalid scheme {scheme} in URL: {value}")

        # Then check full URL
        if not self.url_regex.match(value):
            self.error(f"Invalid URL: {value}")
2✔
217

218

219
class EmailField(StringField):
    """A field that validates input as an email address.

    The address is split on its last ``@``. The user part is matched against
    ``USER_REGEX`` (or ``UTF8_USER_REGEX`` when ``allow_utf8_user`` is set).
    The domain part is accepted if it is whitelisted, matches
    ``DOMAIN_REGEX``, is a bracketed IPv4/IPv6 literal (when
    ``allow_ip_domain`` is set), or passes those checks after IDNA encoding.
    """

    USER_REGEX = LazyRegexCompiler(
        # `dot-atom` defined in RFC 5322 Section 3.2.3.
        r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z]+)*\Z"
        # `quoted-string` defined in RFC 5322 Section 3.2.4.
        r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-\011\013\014\016-\177])*"\Z)',
        re.IGNORECASE,
    )

    UTF8_USER_REGEX = LazyRegexCompiler(
        (
            # RFC 6531 Section 3.3 extends `atext` (used by dot-atom) to
            # include `UTF8-non-ascii`.
            r"(^[-!#$%&'*+/=?^_`{}|~0-9A-Z\u0080-\U0010FFFF]+(\.[-!#$%&'*+/=?^_`{}|~0-9A-Z\u0080-\U0010FFFF]+)*\Z"
            # `quoted-string`
            r'|^"([\001-\010\013\014\016-\037!#-\[\]-\177]|\\[\001-\011\013\014\016-\177])*"\Z)'
        ),
        re.IGNORECASE | re.UNICODE,
    )

    DOMAIN_REGEX = LazyRegexCompiler(
        r"((?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+)(?:[A-Z0-9-]{2,63}(?<!-))\Z",
        re.IGNORECASE,
    )

    error_msg = "Invalid email address: %s"

    def __init__(
        self,
        domain_whitelist=None,
        allow_utf8_user=False,
        allow_ip_domain=False,
        *args,
        **kwargs,
    ):
        """
        :param domain_whitelist: (optional) list of valid domain names applied during validation
        :param allow_utf8_user: Allow user part of the email to contain utf8 char
        :param allow_ip_domain: Allow domain part of the email to be an IPv4 or IPv6 address
        :param kwargs: Keyword arguments passed into the parent :class:`~mongoengine.StringField`
        """
        self.domain_whitelist = domain_whitelist or []
        self.allow_utf8_user = allow_utf8_user
        self.allow_ip_domain = allow_ip_domain
        super().__init__(*args, **kwargs)

    def validate_user_part(self, user_part):
        """Validate the user part of the email address.

        Returns the regex match object (truthy) if valid and ``None`` otherwise.
        """
        if self.allow_utf8_user:
            return self.UTF8_USER_REGEX.match(user_part)
        return self.USER_REGEX.match(user_part)

    def validate_domain_part(self, domain_part):
        """Validate the domain part of the email address. Return True if
        valid and False otherwise.
        """
        # Skip domain validation if it's in the whitelist.
        if domain_part in self.domain_whitelist:
            return True

        if self.DOMAIN_REGEX.match(domain_part):
            return True

        # Validate IPv4/IPv6, e.g. user@[192.168.0.1]. startswith/endswith
        # (rather than indexing) keep an empty domain ("user@") from raising
        # IndexError.
        if (
            self.allow_ip_domain
            and domain_part.startswith("[")
            and domain_part.endswith("]")
        ):
            address = domain_part[1:-1]
            for addr_family in (socket.AF_INET, socket.AF_INET6):
                try:
                    socket.inet_pton(addr_family, address)
                    return True
                # ValueError also covers UnicodeEncodeError and the
                # "embedded null character" error raised by inet_pton.
                except (OSError, ValueError):
                    pass

        return False

    def validate(self, value):
        super().validate(value)

        if "@" not in value:
            self.error(self.error_msg % value)

        user_part, domain_part = value.rsplit("@", 1)

        # Validate the user part.
        if not self.validate_user_part(user_part):
            self.error(self.error_msg % value)

        # Validate the domain and, if invalid, see if it's IDN-encoded.
        if not self.validate_domain_part(domain_part):
            try:
                domain_part = domain_part.encode("idna").decode("ascii")
            except UnicodeError:
                self.error(f"{self.error_msg % value} (domain failed IDN encoding)")
            else:
                if not self.validate_domain_part(domain_part):
                    self.error(
                        f"{self.error_msg % value} (domain validation failed)"
                    )
326

327

328
class IntField(BaseField):
    """Integer field.

    Historically documented as a 32-bit integer field. This class itself
    performs no 32-bit range check; only the optional
    ``min_value``/``max_value`` bounds are enforced in :meth:`validate`.
    """

    def __init__(self, min_value=None, max_value=None, **kwargs):
        """
        :param min_value: (optional) A min value that will be applied during validation
        :param max_value: (optional) A max value that will be applied during validation
        :param kwargs: Keyword arguments passed into the parent :class:`~mongoengine.BaseField`
        """
        self.min_value, self.max_value = min_value, max_value
        super().__init__(**kwargs)

    def to_python(self, value):
        """Coerce ``value`` with :func:`int`, returning it unchanged on failure."""
        try:
            value = int(value)
        except (TypeError, ValueError, OverflowError):
            # OverflowError: int(float("inf")).
            pass
        return value

    def validate(self, value):
        """Check that ``value`` converts to int and lies within the optional bounds."""
        try:
            value = int(value)
        except (TypeError, ValueError, OverflowError):
            # f-string instead of "%s" % value: %-formatting unpacks a tuple
            # value and would raise TypeError instead of a ValidationError.
            self.error(f"{value} could not be converted to int")

        if self.min_value is not None and value < self.min_value:
            self.error("Integer value is too small")

        if self.max_value is not None and value > self.max_value:
            self.error("Integer value is too large")

    def prepare_query_value(self, op, value):
        """Coerce non-None query values with :func:`int` before delegating."""
        if value is None:
            return value

        return super().prepare_query_value(op, int(value))
2✔
364

365

366
class LongField(IntField):
    """64-bit integer field.

    Conversion and validation are inherited unchanged from :class:`IntField`
    (Python 3 has a single ``int`` type). The only difference is
    :meth:`to_mongo`, which wraps the value in :class:`bson.int64.Int64` so it
    is encoded as a BSON int64 even when it would fit in 32 bits.
    """

    def to_mongo(self, value):
        # Int64 is bson's explicit marker type for 64-bit integers.
        wrapped = Int64(value)
        return wrapped
2✔
371

372

373
class FloatField(BaseField):
    """Floating point number field."""

    def __init__(self, min_value=None, max_value=None, **kwargs):
        """
        :param min_value: (optional) A min value that will be applied during validation
        :param max_value: (optional) A max value that will be applied during validation
        :param kwargs: Keyword arguments passed into the parent :class:`~mongoengine.BaseField`
        """
        self.min_value, self.max_value = min_value, max_value
        super().__init__(**kwargs)

    def to_python(self, value):
        """Coerce ``value`` with :func:`float`, returning it unchanged on failure."""
        try:
            value = float(value)
        except (TypeError, ValueError, OverflowError):
            # Mirror IntField.to_python: unconvertible values (None, lists,
            # ints too large for a float, ...) are left for validate() to reject
            # instead of crashing the load.
            pass
        return value

    def validate(self, value):
        """Accept floats and ints (converted to float) within the optional bounds."""
        if isinstance(value, int):
            try:
                value = float(value)
            except OverflowError:
                self.error("The value is too large to be converted to float")

        if not isinstance(value, float):
            self.error("FloatField only accepts float and integer values")

        if self.min_value is not None and value < self.min_value:
            self.error("Float value is too small")

        if self.max_value is not None and value > self.max_value:
            self.error("Float value is too large")

    def prepare_query_value(self, op, value):
        """Coerce non-None query values with :func:`float` before delegating."""
        if value is None:
            return value

        return super().prepare_query_value(op, float(value))
2✔
413

414

415
class DecimalField(BaseField):
    """Disclaimer: This field is kept for historical reason but since it converts the values to float, it
    is not suitable for true decimal storage. Consider using :class:`~mongoengine.fields.Decimal128Field`.

    Fixed-point decimal number field. Stores the value as a float by default unless `force_string` is used.
    If using floats, beware of Decimal to float conversion (potential precision loss)
    """

    def __init__(
        self,
        min_value=None,
        max_value=None,
        force_string=False,
        precision=2,
        rounding=decimal.ROUND_HALF_UP,
        **kwargs,
    ):
        """
        :param min_value: (optional) A min value that will be applied during validation
        :param max_value: (optional) A max value that will be applied during validation
        :param force_string: Store the value as a string (instead of a float).
         Be aware that this affects query sorting and operation like lte, gte (as string comparison is applied)
         and some query operator won't work (e.g. inc, dec)
        :param precision: Number of decimal places to store (a non-negative int;
            0 rounds to whole numbers).
        :param rounding: The rounding rule from the python decimal library:

            - decimal.ROUND_CEILING (towards Infinity)
            - decimal.ROUND_DOWN (towards zero)
            - decimal.ROUND_FLOOR (towards -Infinity)
            - decimal.ROUND_HALF_DOWN (to nearest with ties going towards zero)
            - decimal.ROUND_HALF_EVEN (to nearest with ties going to nearest even integer)
            - decimal.ROUND_HALF_UP (to nearest with ties going away from zero)
            - decimal.ROUND_UP (away from zero)
            - decimal.ROUND_05UP (away from zero if last digit after rounding towards zero would have been 0 or 5; otherwise towards zero)

            Defaults to: ``decimal.ROUND_HALF_UP``
        :param kwargs: Keyword arguments passed into the parent :class:`~mongoengine.BaseField`
        """
        self.min_value = min_value
        self.max_value = max_value
        self.force_string = force_string

        # Check the type before comparing: ``None < 0`` (or ``"2" < 0``) would
        # raise TypeError instead of reporting the invalid precision.
        if not isinstance(precision, int) or precision < 0:
            self.error("precision must be a positive integer")

        self.precision = precision
        self.rounding = rounding

        super().__init__(**kwargs)

    def to_python(self, value):
        """Convert ``value`` to a Decimal quantized to ``precision`` places.

        Values that cannot be parsed as a Decimal are returned unchanged.
        """
        # Go through the string form so floats convert from their short repr
        # (Decimal("0.1")) rather than their exact binary value (Decimal(0.1)).
        try:
            value = decimal.Decimal("%s" % value)
        except (TypeError, ValueError, decimal.InvalidOperation):
            return value
        # quantize() signals InvalidOperation for infinities, so non-finite
        # values (Infinity, NaN) are returned as parsed instead of crashing.
        if not value.is_finite():
            return value
        if self.precision > 0:
            return value.quantize(
                decimal.Decimal(".%s" % ("0" * self.precision)), rounding=self.rounding
            )
        else:
            return value.quantize(decimal.Decimal(), rounding=self.rounding)

    def to_mongo(self, value):
        """Store as ``str`` when ``force_string`` is set, otherwise as ``float``."""
        if self.force_string:
            return str(self.to_python(value))
        return float(self.to_python(value))

    def validate(self, value):
        """Check that ``value`` converts to a Decimal within the optional bounds."""
        if not isinstance(value, decimal.Decimal):
            if not isinstance(value, str):
                value = str(value)
            try:
                value = decimal.Decimal(value)
            except (TypeError, ValueError, decimal.InvalidOperation) as exc:
                self.error("Could not convert value to decimal: %s" % exc)

        if self.min_value is not None and value < self.min_value:
            self.error("Decimal value is too small")

        if self.max_value is not None and value > self.max_value:
            self.error("Decimal value is too large")

    def prepare_query_value(self, op, value):
        """Serialise non-None query values with :meth:`to_mongo` before delegating."""
        if value is None:
            return value
        return super().prepare_query_value(op, self.to_mongo(value))
2✔
502

503

504
class BooleanField(BaseField):
    """Boolean field type.

    :meth:`to_python` coerces with :func:`bool`, while :meth:`validate` only
    accepts genuine ``bool`` instances (so ``0``/``1`` are rejected).
    """

    def to_python(self, value):
        """Return ``bool(value)``, or ``value`` itself if truth-testing fails."""
        try:
            return bool(value)
        except (ValueError, TypeError):
            return value

    def validate(self, value):
        """Reject anything that is not a ``bool``."""
        if isinstance(value, bool):
            return
        self.error("BooleanField only accepts boolean values")
2✔
517

518

519
class DateTimeField(BaseField):
    """Datetime field.

    String values are parsed with python-dateutil when it is installed.
    Otherwise the parser falls back to time.strptime and accepts
    ``YYYY-MM-DD HH:MM:SS``, ``YYYY-MM-DD HH:MM`` or ``YYYY-MM-DD``, each
    optionally followed by a fractional part of up to six digits
    (``.ffffff``). python-dateutil's parser is fully featured and, when
    installed, converts many more date formats into python datetime objects.

    Note: To default the field to the current datetime, use: DateTimeField(default=datetime.utcnow)

    Note: Microseconds are rounded to the nearest millisecond.
      Pre UTC microsecond support is effectively broken.
      Use :class:`~mongoengine.fields.ComplexDateTimeField` if you
      need accurate microsecond support.
    """

    def validate(self, value):
        new_value = self.to_mongo(value)
        if not isinstance(new_value, (datetime.datetime, datetime.date)):
            # f-string instead of %-formatting, which would raise TypeError
            # (not a ValidationError) for tuple values.
            self.error(f'cannot parse date "{value}"')

    def to_mongo(self, value):
        """Convert ``value`` to a ``datetime`` (or ``None`` if it can't be).

        Dates become midnight datetimes, callables are invoked, and strings
        are parsed with :meth:`_parse_datetime`.
        """
        if value is None:
            return value
        if isinstance(value, datetime.datetime):
            return value
        if isinstance(value, datetime.date):
            return datetime.datetime(value.year, value.month, value.day)
        if callable(value):
            return value()

        if isinstance(value, str):
            return self._parse_datetime(value)
        else:
            return None

    @staticmethod
    def _parse_datetime(value):
        """Parse a datetime string, returning ``None`` if it cannot be parsed."""
        value = value.strip()
        if not value:
            return None

        if dateutil:
            try:
                return dateutil.parser.parse(value)
            except (TypeError, ValueError, OverflowError):
                return None

        # time.strptime returns a struct_time, which has no sub-second field,
        # so any fractional part is split off and converted separately.
        if "." in value:
            try:
                value, fraction = value.split(".")
            except ValueError:  # more than one "."
                return None
            # The digits are a decimal fraction of a second, so right-pad to six
            # digits: ".5" is 500000 microseconds, not 5. More than six digits
            # cannot be represented as microseconds and are rejected.
            if not fraction.isdecimal() or len(fraction) > 6:
                return None
            usecs = int(fraction.ljust(6, "0"))
        else:
            usecs = 0
        kwargs = {"microsecond": usecs}
        try:  # Seconds are optional, so try converting seconds first.
            return datetime.datetime(
                *time.strptime(value, "%Y-%m-%d %H:%M:%S")[:6], **kwargs
            )
        except ValueError:
            try:  # Try without seconds.
                return datetime.datetime(
                    *time.strptime(value, "%Y-%m-%d %H:%M")[:5], **kwargs
                )
            except ValueError:  # Try without hour/minutes/seconds.
                try:
                    return datetime.datetime(
                        *time.strptime(value, "%Y-%m-%d")[:3], **kwargs
                    )
                except ValueError:
                    return None

    def prepare_query_value(self, op, value):
        return super().prepare_query_value(op, self.to_mongo(value))
2✔
597

598

599
class DateField(DateTimeField):
    """Date-only variant of :class:`DateTimeField`.

    Values are stored as midnight ``datetime`` objects (time and tzinfo are
    dropped), and ``datetime`` values read back are turned into ``date``.
    """

    def to_mongo(self, value):
        mongo_value = super().to_mongo(value)
        if not isinstance(mongo_value, datetime.datetime):
            return mongo_value
        # drop hours, minutes, seconds
        return datetime.datetime(mongo_value.year, mongo_value.month, mongo_value.day)

    def to_python(self, value):
        py_value = super().to_python(value)
        if not isinstance(py_value, datetime.datetime):
            return py_value
        # convert datetime to date
        return datetime.date(py_value.year, py_value.month, py_value.day)
2✔
613

614

615
class ComplexDateTimeField(StringField):
    """
    ComplexDateTimeField handles microseconds exactly instead of rounding
    like DateTimeField does.

    Derives from a StringField so you can do `gte` and `lte` filtering by
    using lexicographical comparison when filtering / sorting strings.

    The stored string has the following format:

        YYYY,MM,DD,HH,MM,SS,NNNNNN

    Where NNNNNN is the number of microseconds of the represented `datetime`.
    The `,` as the separator can be easily modified by passing the `separator`
    keyword when initializing the field.

    Note: To default the field to the current datetime, use: ComplexDateTimeField(default=datetime.utcnow)
    """

    def __init__(self, separator=",", **kwargs):
        """
        :param separator: Allows to customize the separator used for storage (default ``,``)
        :param kwargs: Keyword arguments passed into the parent :class:`~mongoengine.StringField`
        """
        self.separator = separator
        self.format = separator.join(["%Y", "%m", "%d", "%H", "%M", "%S", "%f"])
        super().__init__(**kwargs)

    def _convert_from_datetime(self, val):
        """
        Convert a `datetime` object to a string representation (which will be
        stored in MongoDB). This is the reverse function of
        `_convert_from_string`.

        >>> a = datetime.datetime(2011, 6, 8, 20, 26, 24, 92284)
        >>> ComplexDateTimeField()._convert_from_datetime(a)
        '2011,06,08,20,26,24,092284'
        """
        return val.strftime(self.format)

    def _convert_from_string(self, data):
        """
        Convert a string representation to a `datetime` object (the object you
        will manipulate). This is the reverse function of
        `_convert_from_datetime`.

        >>> a = '2011,06,08,20,26,24,092284'
        >>> ComplexDateTimeField()._convert_from_string(a)
        datetime.datetime(2011, 6, 8, 20, 26, 24, 92284)
        """
        values = [int(d) for d in data.split(self.separator)]
        return datetime.datetime(*values)

    def __get__(self, instance, owner):
        if instance is None:
            return self

        data = super().__get__(instance, owner)

        if isinstance(data, datetime.datetime) or data is None:
            return data
        return self._convert_from_string(data)

    def __set__(self, instance, value):
        super().__set__(instance, value)
        value = instance._data[self.name]
        # Datetimes are kept in their string form; None and strings are left
        # exactly as the parent __set__ stored them.
        if isinstance(value, datetime.datetime):
            instance._data[self.name] = self._convert_from_datetime(value)

    def validate(self, value):
        value = self.to_python(value)
        if not isinstance(value, datetime.datetime):
            self.error("Only datetime objects may used in a ComplexDateTimeField")

    def to_python(self, value):
        """Parse the stored string into a datetime; return anything unparsable as-is."""
        original_value = value
        try:
            return self._convert_from_string(value)
        except Exception:
            return original_value

    def to_mongo(self, value):
        """Serialise a datetime (or a string in the stored format) to the stored string."""
        value = self.to_python(value)
        return self._convert_from_datetime(value)

    def prepare_query_value(self, op, value):
        if value is None:
            return value
        # to_mongo accepts datetimes and strings already in the stored format;
        # calling _convert_from_datetime directly raised AttributeError on the
        # latter. For datetime values the result is identical.
        return super().prepare_query_value(op, self.to_mongo(value))
2✔
707

708

709
class EmbeddedDocumentField(BaseField):
    """An embedded document field - with a declared document_type.

    ``document_type`` may be an :class:`~mongoengine.EmbeddedDocument`
    subclass, a class name (looked up with ``get_document``), or ``"self"``
    for the owning document. String forms are resolved on first access of
    :attr:`document_type`.
    """

    def __init__(self, document_type, **kwargs):
        # Strings are accepted here and only checked once resolved.
        accepted = isinstance(document_type, str) or issubclass(
            document_type, EmbeddedDocument
        )
        if not accepted:
            self.error(
                "Invalid embedded document class provided to an "
                "EmbeddedDocumentField"
            )

        self.document_type_obj = document_type
        super().__init__(**kwargs)

    @property
    def document_type(self):
        """The embedded document class, resolving (and caching) string names."""
        declared = self.document_type_obj
        if not isinstance(declared, str):
            return declared

        if declared == RECURSIVE_REFERENCE_CONSTANT:
            resolved = self.owner_document
        else:
            resolved = get_document(declared)

        # Because the class is resolved late, it may turn out not to be an
        # EmbeddedDocument after all (#1661).
        if not issubclass(resolved, EmbeddedDocument):
            self.error(
                "Invalid embedded document class provided to an "
                "EmbeddedDocumentField"
            )
        self.document_type_obj = resolved
        return self.document_type_obj

    def to_python(self, value):
        """Build a document instance from raw SON unless one is already given."""
        doc_type = self.document_type
        if isinstance(value, doc_type):
            return value
        return doc_type._from_son(value, _auto_dereference=self._auto_dereference)

    def to_mongo(self, value, use_db_field=True, fields=None):
        """Serialise document instances; pass any other value through unchanged."""
        doc_type = self.document_type
        if isinstance(value, doc_type):
            return doc_type.to_mongo(value, use_db_field, fields)
        return value

    def validate(self, value, clean=True):
        """Make sure that the document instance is an instance of the
        EmbeddedDocument subclass provided when the document was defined,
        then validate the embedded document itself.
        """
        # isinstance also accepts subclasses of the declared document type.
        if not isinstance(value, self.document_type):
            self.error(
                "Invalid embedded document instance provided to an "
                "EmbeddedDocumentField"
            )
        value.validate(clean=clean)

    def lookup_member(self, member_name):
        """Find ``member_name`` on the document type or its direct subclasses."""
        doc_type = self.document_type
        for candidate in [doc_type] + doc_type.__subclasses__():
            field = candidate._fields.get(member_name)
            if field:
                return field
        return None

    def prepare_query_value(self, op, value):
        """Convert a raw query value into the embedded document's mongo form.

        Dicts made only of ``$`` operators are passed through untouched.
        """
        if value is not None and not isinstance(value, self.document_type):
            # Short circuit for special operators, returning them as is
            if isinstance(value, dict) and all(key.startswith("$") for key in value):
                return value
            try:
                value = self.document_type._from_son(value)
            except ValueError:
                raise InvalidQueryError(
                    "Querying the embedded document '%s' failed, due to an invalid query value"
                    % (self.document_type._class_name,)
                )
        super().prepare_query_value(op, value)
        return self.to_mongo(value)
2✔
791

792

793
class GenericEmbeddedDocumentField(BaseField):
    """Embedded document field that accepts an instance of any
    :class:`~mongoengine.EmbeddedDocument` subclass.

    The concrete class name is written under ``_cls`` next to the document
    data, so the right class can be rebuilt when loading.

    .. note ::
        Pass ``choices`` to restrict which EmbeddedDocument classes are
        accepted.
    """

    def prepare_query_value(self, op, value):
        """Serialise ``value`` with :meth:`to_mongo`, then apply the generic
        query preparation."""
        mongo_value = self.to_mongo(value)
        return super().prepare_query_value(op, mongo_value)

    def to_python(self, value):
        """Rebuild an embedded document from its stored dict (using its
        ``_cls`` entry); any other value is returned unchanged."""
        if not isinstance(value, dict):
            return value
        doc_cls = get_document(value["_cls"])
        return doc_cls._from_son(value)

    def validate(self, value, clean=True):
        """Validate an embedded document value.

        A raw ``SON`` value whose ``_cls`` matches one of ``choices`` is
        accepted directly and ``True`` is returned. Otherwise the value must
        be an :class:`~mongoengine.EmbeddedDocument`, which is then validated
        with ``clean`` forwarded.
        """
        if self.choices and isinstance(value, SON):
            stored_cls_name = value["_cls"]
            if any(stored_cls_name == choice._class_name for choice in self.choices):
                return True

        if not isinstance(value, EmbeddedDocument):
            error_message = (
                "Invalid embedded document instance provided to an "
                "GenericEmbeddedDocumentField"
            )
            self.error(error_message)

        value.validate(clean=clean)

    def lookup_member(self, member_name):
        """Return the first truthy field named ``member_name`` found on the
        ``choices`` classes or their direct subclasses, else ``None``."""
        candidates = (
            doc_type
            for choice in (self.choices or [])
            for doc_type in (choice, *choice.__subclasses__())
        )
        for doc_type in candidates:
            found = doc_type._fields.get(member_name)
            if found:
                return found
        return None

    def to_mongo(self, document, use_db_field=True, fields=None):
        """Serialise ``document`` and make sure the result carries ``_cls``.

        ``None`` is passed through as ``None``.
        """
        if document is None:
            return None
        son = document.to_mongo(use_db_field, fields)
        # Record the concrete class so to_python can rebuild the right type
        if "_cls" not in son:
            son["_cls"] = document._class_name
        return son
844

845

846
class DynamicField(BaseField):
    """A truly dynamic field type capable of handling different and varying
    types of data.

    Used by :class:`~mongoengine.DynamicDocument` to handle dynamic data"""

    def to_mongo(self, value, use_db_field=True, fields=None):
        """Convert a Python type to a MongoDB compatible type.

        * strings are returned unchanged;
        * a :class:`~mongoengine.Document` is stored as a reference:
          ``{"_ref": <DBRef>, "_cls": <class name>}``;
        * any other object exposing ``to_mongo`` is serialised with it, and
          an :class:`~mongoengine.EmbeddedDocument` additionally gets a
          ``_cls`` key holding its class name;
        * dicts, lists and tuples are converted recursively (a dict yields a
          plain ``dict``, a list or tuple yields a ``list``);
        * anything else is returned unchanged.
        """
        if isinstance(value, str):
            return value

        if isinstance(value, Document):
            # Only a reference is stored for a top-level document, so there is
            # no need to serialise the whole document first.
            return {"_ref": value.to_dbref(), "_cls": value.__class__.__name__}

        if hasattr(value, "to_mongo"):
            val = value.to_mongo(use_db_field, fields)
            if isinstance(value, EmbeddedDocument):
                # Record the class name so to_python can rebuild the instance.
                val["_cls"] = value.__class__.__name__
            return val

        if not isinstance(value, (dict, list, tuple)):
            return value

        if hasattr(value, "items"):
            return {
                k: self.to_mongo(v, use_db_field, fields) for k, v in value.items()
            }

        # Lists and tuples are converted item by item, preserving order.
        return [self.to_mongo(v, use_db_field, fields) for v in value]

    def to_python(self, value):
        """Convert stored data back to Python.

        A dict carrying ``_cls`` is rebuilt as an instance of that document
        class; if it also carries ``_ref``, the referenced document is first
        fetched through the class's database handle. Everything else goes
        through the inherited conversion.
        """
        if isinstance(value, dict) and "_cls" in value:
            doc_cls = get_document(value["_cls"])
            if "_ref" in value:
                value = doc_cls._get_db().dereference(value["_ref"])
            return doc_cls._from_son(value)

        return super().to_python(value)

    def lookup_member(self, member_name):
        # Dynamic data declares no sub-fields: the member name itself is returned.
        return member_name

    def prepare_query_value(self, op, value):
        """Prepare ``value`` for a query: strings are handled like a
        StringField would handle them, other values are serialised with
        :meth:`to_mongo` first."""
        if isinstance(value, str):
            return StringField().prepare_query_value(op, value)
        return super().prepare_query_value(op, self.to_mongo(value))

    def validate(self, value, clean=True):
        # Only values exposing their own ``validate`` (e.g. embedded documents)
        # are checked; anything else is accepted as-is.
        if hasattr(value, "validate"):
            value.validate(clean=clean)
905

906

907
class ListField(ComplexBaseField):
    """A list field that wraps a standard field, allowing multiple instances
    of the field to be used as a list in the database.

    If using with ReferenceFields see: :ref:`many-to-many-with-listfields`

    .. note::
        Required means it cannot be empty - as the default for ListFields is []
    """

    def __init__(self, field=None, *, max_length=None, **kwargs):
        """
        :param field: The field used for each item of the list (optional).
        :param max_length: Maximum number of items allowed, or ``None`` for
            no limit.
        :param kwargs: Keyword arguments passed into the parent
            :class:`~mongoengine.base.ComplexBaseField`; ``default`` falls
            back to ``list`` when not given.
        """
        self.max_length = max_length
        kwargs.setdefault("default", list)
        super().__init__(field=field, **kwargs)

    def __get__(self, instance, owner):
        if instance is None:
            # Document class being used rather than a document object
            return self
        value = instance._data.get(self.name)
        LazyReferenceField = _import_class("LazyReferenceField")
        GenericLazyReferenceField = _import_class("GenericLazyReferenceField")
        if (
            isinstance(self.field, (LazyReferenceField, GenericLazyReferenceField))
            and value
        ):
            # Turn every stored item into a lazy reference object before the
            # generic ComplexBaseField handling runs.
            instance._data[self.name] = [self.field.build_lazyref(x) for x in value]
        return super().__get__(instance, owner)

    def validate(self, value):
        """Make sure that a list of valid fields is being used."""
        if not isinstance(value, (list, tuple, BaseQuerySet)):
            self.error("Only lists and tuples may be used in a list field")

        # Validate that max_length is not exceeded.
        # NOTE It's still possible to bypass this enforcement by using $push.
        # However, if the document is reloaded after $push and then re-saved,
        # the validation error will be raised.
        if self.max_length is not None and len(value) > self.max_length:
            self.error("List is too long")

        super().validate(value)

    def prepare_query_value(self, op, value):
        """Prepare ``value`` for a query or update against this list field.

        Iterable values other than strings and documents are treated as a
        whole list. For the operators listed below, each item is prepared
        with the inner field. Any other value is prepared as a single item.

        ``max_length`` is enforced for ``set`` only when a whole list is
        given; a single item is never measured with ``len()``.
        """
        # If the value is iterable and it's not a string nor a
        # BaseDocument, it is treated as a whole list rather than one item.
        is_iter = hasattr(value, "__iter__")
        eligible_iter = is_iter and not isinstance(value, (str, BaseDocument))

        # Validate that the `set` operator doesn't contain more items than
        # `max_length`. Only whole lists are measured: len() of a single item
        # (a string's characters, a document's fields) is not a list length.
        if (
            op == "set"
            and self.max_length is not None
            and eligible_iter
            and len(value) > self.max_length
        ):
            self.error("List is too long")

        if self.field:
            # Call prepare_query_value for each item of a whole list.
            if (
                op in ("set", "unset", "gt", "gte", "lt", "lte", "ne", None)
                and eligible_iter
            ):
                return [self.field.prepare_query_value(op, v) for v in value]

            return self.field.prepare_query_value(op, value)

        return super().prepare_query_value(op, value)
969

970

971
class EmbeddedDocumentListField(ListField):
    """A :class:`~mongoengine.ListField` whose items are embedded documents of
    one declared type, giving access to additional query helpers.

    .. note::
        Only subclasses of :class:`~mongoengine.EmbeddedDocument` are valid
        list items.
    """

    def __init__(self, document_type, **kwargs):
        """
        :param document_type: The :class:`~mongoengine.EmbeddedDocument` type
          held by every item of the list.
        :param kwargs: Keyword arguments forwarded to the parent
          :class:`~mongoengine.ListField`
        """
        item_field = EmbeddedDocumentField(document_type)
        super().__init__(field=item_field, **kwargs)
987

988

989
class SortedListField(ListField):
    """A :class:`~mongoengine.ListField` that sorts its items every time it is
    converted for storage, so the list read back from the database is always
    in order.

    .. warning::
        Writing the whole list is subject to a race condition: another
        process saving the whole list at the same time can overwrite your
        changes. Appending with a push operation is the safest way to add
        an item.
    """

    def __init__(self, field, **kwargs):
        """
        :param field: The field used for each item of the list.
        :keyword ordering: Optional key; when given, items are sorted by
            ``item[ordering]``.
        :keyword reverse: Sort in descending order when true (default ``False``).
        """
        ordering = kwargs.pop("ordering", None)
        reverse = kwargs.pop("reverse", False)
        self._ordering = ordering
        self._order_reverse = reverse
        super().__init__(field, **kwargs)

    def to_mongo(self, value, use_db_field=True, fields=None):
        """Serialise the list as ListField does, then return it sorted."""
        mongo_items = super().to_mongo(value, use_db_field, fields)
        # key=None sorts by natural order, like a plain sorted() call
        sort_key = None if self._ordering is None else itemgetter(self._ordering)
        return sorted(mongo_items, key=sort_key, reverse=self._order_reverse)
1013

1014

1015
def key_not_string(d):
    """Return ``True`` if any key of ``d``, or of a dict nested in its
    values, is not a string, and ``False`` otherwise.

    Only values that are themselves dicts are searched recursively; dicts
    inside lists are not inspected.
    """
    return any(
        not isinstance(k, str) or (isinstance(v, dict) and key_not_string(v))
        for k, v in d.items()
    )
1022

1023

1024
def key_starts_with_dollar(d):
    """Return ``True`` if any key of ``d``, or of a dict nested in its
    values, is a string starting with ``$``, and ``False`` otherwise.

    Non-string keys are ignored here (``key_not_string`` reports them)
    instead of raising ``AttributeError``. Only values that are themselves
    dicts are searched recursively; dicts inside lists are not inspected.
    """
    return any(
        (isinstance(k, str) and k.startswith("$"))
        or (isinstance(v, dict) and key_starts_with_dollar(v))
        for k, v in d.items()
    )
1031

1032

1033
class DictField(ComplexBaseField):
    """A dictionary field that wraps a standard Python dictionary. This is
    similar to an embedded document, but the structure is not defined.

    .. note::
        Required means it cannot be empty - as the default for DictFields is {}
    """

    def __init__(self, field=None, *args, **kwargs):
        """
        :param field: Optional field that every value of the dict must match.
        :param kwargs: Keyword arguments passed into the parent
            :class:`~mongoengine.base.ComplexBaseField`; ``default`` falls
            back to ``dict`` when not given.
        """
        kwargs.setdefault("default", dict)
        super().__init__(*args, field=field, **kwargs)
        # Values of a DictField are not dereferenced automatically.
        self.set_auto_dereferencing(False)

    def validate(self, value):
        """Make sure that ``value`` is a dict whose keys, including the keys of
        nested dicts, are all strings that do not start with ``$``."""
        if not isinstance(value, dict):
            self.error("Only dictionaries may be used in a DictField")

        if key_not_string(value):
            msg = "Invalid dictionary key - documents must have only string keys"
            self.error(msg)

        # Following condition applies to MongoDB >= 3.6
        # older Mongo has stricter constraints but
        # it will be rejected upon insertion anyway
        # Having a validation that depends on the MongoDB version
        # is not straightforward as the field isn't aware of the connected Mongo
        if key_starts_with_dollar(value):
            self.error(
                'Invalid dictionary key name - keys may not startswith "$" characters'
            )
        super().validate(value)

    def lookup_member(self, member_name):
        # The structure is not declared, so any member resolves to an untyped DictField.
        return DictField(db_field=member_name)

    def prepare_query_value(self, op, value):
        # String-matching operators applied to a str value are prepared the
        # way a StringField would prepare them.
        if op in STRING_OPERATORS and isinstance(value, str):
            return StringField().prepare_query_value(op, value)

        if hasattr(
            self.field, "field"
        ):  # Used for instance when using DictField(ListField(IntField()))
            if op in ("set", "unset") and isinstance(value, dict):
                return {
                    k: self.field.prepare_query_value(op, v) for k, v in value.items()
                }
            return self.field.prepare_query_value(op, value)

        return super().prepare_query_value(op, value)
1085

1086

1087
class MapField(DictField):
    """A :class:`DictField` whose values must all match one declared field
    type, e.g. ``MapField(IntField())`` maps string keys to integers.
    """

    def __init__(self, field=None, *args, **kwargs):
        """
        :param field: The field instance that each value of the mapping must
            match; anything that is not a field instance is rejected.
        """
        value_field_ok = isinstance(field, BaseField)
        if not value_field_ok:
            # XXX ValidationError raised outside the "validate" method.
            self.error("Argument to MapField constructor must be a valid field")
        super().__init__(*args, field=field, **kwargs)
1098

1099

1100
class ReferenceField(BaseField):
    """A reference to a document that will be automatically dereferenced on
    access (lazily).

    Note this means you will get a database I/O access the first time you
    access this field. This is necessary because the field returns a
    :class:`~mongoengine.Document` whose precise type can depend on the value
    of the `_cls` field present in the document in database.
    In short, using this type of field can lead to poor performance (especially
    if you access this field only to retrieve its `pk` field, which is already
    known before dereferencing). To solve this you should consider using the
    :class:`~mongoengine.fields.LazyReferenceField`.

    Use the `reverse_delete_rule` to handle what should happen if the document
    the field is referencing is deleted.  EmbeddedDocuments, DictFields and
    MapFields do not support reverse_delete_rule and an `InvalidDocumentError`
    will be raised if trying to set on one of these Document / Field types.

    The options are:

      * DO_NOTHING (0)  - don't do anything (default).
      * NULLIFY    (1)  - Updates the reference to null.
      * CASCADE    (2)  - Deletes the documents associated with the reference.
      * DENY       (3)  - Prevent the deletion of the reference object.
      * PULL       (4)  - Pull the reference from a :class:`~mongoengine.fields.ListField` of references

    Alternative syntax for registering delete rules (useful when implementing
    bi-directional delete rules)

    .. code-block:: python

        class Org(Document):
            owner = ReferenceField('User')

        class User(Document):
            org = ReferenceField('Org', reverse_delete_rule=CASCADE)

        User.register_delete_rule(Org, 'owner', DENY)
    """

    def __init__(
        self, document_type, dbref=False, reverse_delete_rule=DO_NOTHING, **kwargs
    ):
        """Initialises the Reference Field.

        :param document_type: The type of Document that will be referenced:
          a :class:`~mongoengine.Document` subclass or its name as a string
        :param dbref:  Store the reference as :class:`~bson.dbref.DBRef`
          or as the plain id value (e.g. an :class:`~bson.objectid.ObjectId`).
        :param reverse_delete_rule: Determines what to do when the referenced
          document is deleted
        :param kwargs: Keyword arguments passed into the parent :class:`~mongoengine.BaseField`

        .. note ::
            A reference to an abstract document type is always stored as a
            :class:`~bson.dbref.DBRef`, regardless of the value of `dbref`.
        """
        # XXX ValidationError raised outside of the "validate" method.
        if not (
            isinstance(document_type, str)
            or (isclass(document_type) and issubclass(document_type, Document))
        ):
            self.error(
                "Argument to ReferenceField constructor must be a "
                "document class or a string"
            )

        self.dbref = dbref
        self.document_type_obj = document_type
        self.reverse_delete_rule = reverse_delete_rule
        super().__init__(**kwargs)

    @property
    def document_type(self):
        """The referenced Document class.

        A name given at construction time is resolved on first access:
        ``RECURSIVE_REFERENCE_CONSTANT`` resolves to the owner document, any
        other name is looked up with ``get_document``. The resolved class
        replaces the name, so the lookup happens only once.
        """
        if isinstance(self.document_type_obj, str):
            if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
                self.document_type_obj = self.owner_document
            else:
                self.document_type_obj = get_document(self.document_type_obj)
        return self.document_type_obj

    @staticmethod
    def _lazy_load_ref(ref_cls, dbref):
        """Fetch the document pointed to by ``dbref`` as a ``ref_cls`` instance.

        :raises DoesNotExist: if the database returns nothing for ``dbref``.
        """
        dereferenced_son = ref_cls._get_db().dereference(dbref)
        if dereferenced_son is None:
            raise DoesNotExist(f"Trying to dereference unknown document {dbref}")

        return ref_cls._from_son(dereferenced_son)

    def __get__(self, instance, owner):
        """Descriptor to allow lazy dereferencing.

        When auto-dereferencing is enabled and the stored value is still a
        DBRef, the referenced document is loaded and written back into
        ``instance._data``, so later accesses reuse it.
        """
        if instance is None:
            # Document class being used rather than a document object
            return self

        # Get value from document instance if available
        ref_value = instance._data.get(self.name)
        auto_dereference = instance._fields[self.name]._auto_dereference
        # Dereference DBRefs
        if auto_dereference and isinstance(ref_value, DBRef):
            if hasattr(ref_value, "cls"):
                # Dereference using the class type specified in the reference
                cls = get_document(ref_value.cls)
            else:
                cls = self.document_type

            instance._data[self.name] = self._lazy_load_ref(cls, ref_value)

        return super().__get__(instance, owner)

    def to_mongo(self, document):
        """Convert ``document`` to the value stored in MongoDB.

        A DBRef input is returned as-is when ``dbref`` is true, or reduced to
        its id otherwise. A saved Document or a raw id is converted with the
        id field. The result is then a DBRef (with ``cls`` set) for an
        abstract referenced type, a DBRef when ``dbref`` is true, and the
        bare id otherwise.
        """
        if isinstance(document, DBRef):
            if not self.dbref:
                return document.id
            return document

        if isinstance(document, Document):
            # We need the id from the saved object to create the DBRef
            id_ = document.pk

            # XXX ValidationError raised outside of the "validate" method.
            if id_ is None:
                self.error(
                    "You can only reference documents once they have"
                    " been saved to the database"
                )

            # Use the attributes from the document instance, so that they
            # override the attributes of this field's document type
            cls = document
        else:
            id_ = document
            cls = self.document_type

        id_field_name = cls._meta["id_field"]
        id_field = cls._fields[id_field_name]

        id_ = id_field.to_mongo(id_)
        if self.document_type._meta.get("abstract"):
            collection = cls._get_collection_name()
            return DBRef(collection, id_, cls=cls._class_name)
        elif self.dbref:
            collection = cls._get_collection_name()
            return DBRef(collection, id_)

        return id_

    def to_python(self, value):
        """Convert a MongoDB-compatible type to a Python type.

        With ``dbref`` disabled, a bare stored id is wrapped into a DBRef on
        the referenced collection so that ``__get__`` can dereference it.
        """
        if not self.dbref and not isinstance(
            value, (DBRef, Document, EmbeddedDocument)
        ):
            collection = self.document_type._get_collection_name()
            value = DBRef(collection, self.document_type.id.to_python(value))
        return value

    def prepare_query_value(self, op, value):
        if value is None:
            return None
        # The parent call is made for its checks only; its result is unused.
        super().prepare_query_value(op, value)
        return self.to_mongo(value)

    def validate(self, value):
        if not isinstance(value, (self.document_type, LazyReference, DBRef, ObjectId)):
            self.error(
                "A ReferenceField only accepts DBRef, LazyReference, ObjectId or documents"
            )

        if isinstance(value, Document) and value.id is None:
            self.error(
                "You can only reference documents once they have been "
                "saved to the database"
            )

    def lookup_member(self, member_name):
        return self.document_type._fields.get(member_name)
1275

1276

1277
class CachedReferenceField(BaseField):
    """A reference field that also stores a copy of selected fields of the
    referenced document, to provide pseudo-joins.

    The stored value is a sub-document holding the referenced document's
    ``_id`` plus the cached ``fields``.
    """

    def __init__(self, document_type, fields=None, auto_sync=True, **kwargs):
        """Initialises the Cached Reference Field.

        :param document_type: The type of Document that will be referenced
          (a :class:`~mongoengine.Document` subclass or its name as a string)
        :param fields:  A list of fields to be cached in document
        :param auto_sync: if True documents are auto updated
        :param kwargs: Keyword arguments passed into the parent :class:`~mongoengine.BaseField`
        """
        if fields is None:
            fields = []

        # XXX ValidationError raised outside of the "validate" method.
        if not isinstance(document_type, str) and not (
            inspect.isclass(document_type) and issubclass(document_type, Document)
        ):
            self.error(
                "Argument to CachedReferenceField constructor must be a"
                " document class or a string"
            )

        self.auto_sync = auto_sync
        self.document_type_obj = document_type
        self.fields = fields
        super().__init__(**kwargs)

    def start_listener(self):
        """Connect :meth:`on_document_pre_save` to the ``post_save`` signal
        sent by the referenced document type."""
        from mongoengine import signals

        signals.post_save.connect(self.on_document_pre_save, sender=self.document_type)

    def on_document_pre_save(self, sender, document, created, **kwargs):
        """Propagate changes of cached fields to the owner documents.

        Despite its name, this handler is connected to ``post_save`` by
        :meth:`start_listener`. Newly created documents are ignored. For an
        update, the entries of ``document._delta()[0]`` whose key is a cached
        field are written into every owner document that references
        ``document``.
        """
        if created:
            return None

        update_kwargs = {
            f"set__{self.name}__{key}": val
            for key, val in document._delta()[0].items()
            if key in self.fields
        }
        if update_kwargs:
            filter_kwargs = {self.name: document}
            self.owner_document.objects(**filter_kwargs).update(**update_kwargs)

    def to_python(self, value):
        """Resolve a stored sub-document (dict) into the full referenced
        document, fetched from the database by its ``_id``; other values are
        returned unchanged."""
        if isinstance(value, dict):
            collection = self.document_type._get_collection_name()
            value = DBRef(collection, self.document_type.id.to_python(value["_id"]))
            return self.document_type._from_son(
                self.document_type._get_db().dereference(value)
            )

        return value

    @property
    def document_type(self):
        """The referenced Document class, resolving a name given at
        construction time on first access (``RECURSIVE_REFERENCE_CONSTANT``
        means the owner document)."""
        if isinstance(self.document_type_obj, str):
            if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
                self.document_type_obj = self.owner_document
            else:
                self.document_type_obj = get_document(self.document_type_obj)
        return self.document_type_obj

    @staticmethod
    def _lazy_load_ref(ref_cls, dbref):
        """Fetch the document pointed to by ``dbref`` as a ``ref_cls`` instance.

        :raises DoesNotExist: if the database returns nothing for ``dbref``.
        """
        dereferenced_son = ref_cls._get_db().dereference(dbref)
        if dereferenced_son is None:
            raise DoesNotExist(f"Trying to dereference unknown document {dbref}")

        return ref_cls._from_son(dereferenced_son)

    def __get__(self, instance, owner):
        if instance is None:
            # Document class being used rather than a document object
            return self

        # Get value from document instance if available
        value = instance._data.get(self.name)
        auto_dereference = instance._fields[self.name]._auto_dereference

        # Dereference DBRefs
        if auto_dereference and isinstance(value, DBRef):
            instance._data[self.name] = self._lazy_load_ref(self.document_type, value)

        return super().__get__(instance, owner)

    def to_mongo(self, document, use_db_field=True, fields=None):
        """Build the stored sub-document: ``_id`` plus the cached fields.

        When ``fields`` is given, only cached fields also listed in it are
        included.
        """
        id_field_name = self.document_type._meta["id_field"]
        id_field = self.document_type._fields[id_field_name]

        # XXX ValidationError raised outside of the "validate" method.
        if isinstance(document, Document):
            # We need the id from the saved object to create the DBRef
            id_ = document.pk
            if id_ is None:
                self.error(
                    "You can only reference documents once they have"
                    " been saved to the database"
                )
        else:
            self.error("Only accept a document object")

        value = SON((("_id", id_field.to_mongo(id_)),))

        if fields:
            new_fields = [f for f in self.fields if f in fields]
        else:
            new_fields = self.fields

        value.update(dict(document.to_mongo(use_db_field, fields=new_fields)))
        return value

    def prepare_query_value(self, op, value):
        """Build the query value for a saved Document: its pk under ``_id``
        plus the current value of each cached field.

        :raises NotImplementedError: for any other non-``None`` value.
        """
        if value is None:
            return None

        # XXX ValidationError raised outside of the "validate" method.
        if isinstance(value, Document):
            if value.pk is None:
                self.error(
                    "You can only reference documents once they have"
                    " been saved to the database"
                )
            value_dict = {"_id": value.pk}
            for field in self.fields:
                value_dict.update({field: value[field]})

            return value_dict

        raise NotImplementedError(
            "CachedReferenceField can only be queried with a Document instance"
        )

    def validate(self, value):
        if not isinstance(value, self.document_type):
            self.error("A CachedReferenceField only accepts documents")

        if isinstance(value, Document) and value.id is None:
            self.error(
                "You can only reference documents once they have been "
                "saved to the database"
            )

    def lookup_member(self, member_name):
        return self.document_type._fields.get(member_name)

    def sync_all(self):
        """
        Sync all cached fields on demand.

        Caution: this operation may be slow, as it issues one update query
        per document of the referenced type.
        """
        update_key = f"set__{self.name}"

        for doc in self.document_type.objects:
            filter_kwargs = {self.name: doc}
            update_kwargs = {update_key: doc}
            self.owner_document.objects(**filter_kwargs).update(**update_kwargs)
1440

1441

1442
class GenericReferenceField(BaseField):
    """A reference to *any* :class:`~mongoengine.document.Document` subclass
    that will be automatically dereferenced on access (lazily).

    Note this field works the same way as :class:`~mongoengine.fields.ReferenceField`,
    doing database I/O access the first time it is accessed (even if it's to access
    its ``pk`` or ``id`` field).
    To solve this you should consider using the
    :class:`~mongoengine.fields.GenericLazyReferenceField`.

    .. note ::
        * Any documents used as a generic reference must be registered in the
          document registry.  Importing the model will automatically register
          it.

        * You can use the choices param to limit the acceptable Document types
    """

    def __init__(self, *args, **kwargs):
        """
        :param choices: optional iterable of Document subclasses and/or class
            names; stored as a list of class names.
        """
        choices = kwargs.pop("choices", None)
        super().__init__(*args, **kwargs)
        self.choices = []
        # Keep the choices as a list of allowed Document class names
        if choices:
            for choice in choices:
                if isinstance(choice, str):
                    self.choices.append(choice)
                elif isinstance(choice, type) and issubclass(choice, Document):
                    self.choices.append(choice._class_name)
                else:
                    # XXX ValidationError raised outside of the "validate"
                    # method.
                    self.error(
                        "Invalid choices provided: must be a list of "
                        "Document subclasses and/or str"
                    )

    def _validate_choices(self, value):
        # Choices hold class names, so compare using the value's class name.
        if isinstance(value, dict):
            # If the field has not been dereferenced, it is still a dict
            # of class and DBRef
            value = value.get("_cls")
        elif isinstance(value, Document):
            value = value._class_name
        super()._validate_choices(value)

    @staticmethod
    def _lazy_load_ref(ref_cls, dbref):
        """Fetch the document pointed to by ``dbref`` as a ``ref_cls`` instance.

        :raises DoesNotExist: if the database returns nothing for ``dbref``.
        """
        dereferenced_son = ref_cls._get_db().dereference(dbref)
        if dereferenced_son is None:
            raise DoesNotExist(f"Trying to dereference unknown document {dbref}")

        return ref_cls._from_son(dereferenced_son)

    def __get__(self, instance, owner):
        if instance is None:
            return self

        value = instance._data.get(self.name)

        auto_dereference = instance._fields[self.name]._auto_dereference
        if auto_dereference and isinstance(value, dict):
            # Stored form is {"_cls": <class name>, "_ref": <DBRef>}; replace it
            # with the loaded document so later accesses reuse it.
            doc_cls = get_document(value["_cls"])
            instance._data[self.name] = self._lazy_load_ref(doc_cls, value["_ref"])

        return super().__get__(instance, owner)

    def validate(self, value):
        if not isinstance(value, (Document, DBRef, dict, SON)):
            self.error("GenericReferences can only contain documents")

        if isinstance(value, (dict, SON)):
            if "_ref" not in value or "_cls" not in value:
                self.error("GenericReferences can only contain documents")

        # We need the id from the saved object to create the DBRef
        elif isinstance(value, Document) and value.id is None:
            self.error(
                "You can only reference documents once they have been"
                " saved to the database"
            )

    def to_mongo(self, document):
        """Convert a saved Document to ``SON(_cls=<class name>, _ref=<DBRef>)``.

        ``None`` and values that are already dict/SON/ObjectId/DBRef are
        returned unchanged.
        """
        if document is None:
            return None

        if isinstance(document, (dict, SON, ObjectId, DBRef)):
            return document

        id_field_name = document.__class__._meta["id_field"]
        id_field = document.__class__._fields[id_field_name]

        if isinstance(document, Document):
            # We need the id from the saved object to create the DBRef
            id_ = document.id
            if id_ is None:
                # XXX ValidationError raised outside of the "validate" method.
                self.error(
                    "You can only reference documents once they have"
                    " been saved to the database"
                )
        else:
            id_ = document

        id_ = id_field.to_mongo(id_)
        collection = document._get_collection_name()
        ref = DBRef(collection, id_)
        return SON((("_cls", document._class_name), ("_ref", ref)))

    def prepare_query_value(self, op, value):
        if value is None:
            return None

        return self.to_mongo(value)
1556

1557

1558
class BinaryField(BaseField):
    """A binary data field.

    :param max_bytes: optional maximum length (in bytes) accepted by
        :meth:`validate`; ``None`` (the default) disables the length check.
    """

    def __init__(self, max_bytes=None, **kwargs):
        self.max_bytes = max_bytes
        super().__init__(**kwargs)

    def __set__(self, instance, value):
        """Store ``bytearray`` values as immutable ``bytes`` before assigning."""
        if isinstance(value, bytearray):
            value = bytes(value)
        return super().__set__(instance, value)

    def to_mongo(self, value):
        """Wrap ``value`` in a BSON :class:`~bson.binary.Binary`."""
        return Binary(value)

    def validate(self, value):
        """Check the type of ``value`` and, if configured, its length.

        Errors are reported through ``self.error``.
        """
        if not isinstance(value, (bytes, Binary)):
            # Lists exactly the accepted types: (bytes, Binary).
            self.error(
                "BinaryField only accepts instances of "
                "(%s, %s)" % (bytes.__name__, Binary.__name__)
            )

        if self.max_bytes is not None and len(value) > self.max_bytes:
            self.error("Binary value is too long")

    def prepare_query_value(self, op, value):
        """Convert a non-``None`` query operand to ``Binary`` before delegating."""
        if value is None:
            return value
        return super().prepare_query_value(op, self.to_mongo(value))
1588

1589

1590
class EnumField(BaseField):
    """Enumeration Field.

    Members are stored by their underlying ``.value``. The enum values must
    therefore be simple, BSON-encodable types (str, int, etc).

    Example usage:

    .. code-block:: python

        class Status(Enum):
            NEW = 'new'
            ONGOING = 'ongoing'
            DONE = 'done'

        class ModelWithEnum(Document):
            status = EnumField(Status, default=Status.NEW)

        ModelWithEnum(status='done')
        ModelWithEnum(status=Status.DONE)

    Queries accept either the enum member or its raw value:

    .. code-block:: python

        ModelWithEnum.objects(status='new').count()
        ModelWithEnum.objects(status=Status.NEW).count()

    Pass ``choices`` to restrict the field to a subset of the enum's members:

    .. code-block:: python

        class ModelWithEnum(Document):
            status = EnumField(Status, choices=[Status.NEW, Status.DONE])

    :raises ValueError: at construction time, if any entry of ``choices`` is
        not a member of ``enum``.
    """

    def __init__(self, enum, **kwargs):
        self._enum_cls = enum
        choices = kwargs.get("choices")
        if not choices:
            # No explicit choices: every enum member is allowed, which also
            # acts as the implicit validator.
            kwargs["choices"] = list(self._enum_cls)
        else:
            invalid_choices = [c for c in choices if not isinstance(c, enum)]
            if invalid_choices:
                raise ValueError("Invalid choices: %r" % invalid_choices)
        super().__init__(**kwargs)

    def validate(self, value):
        enum_cls = self._enum_cls
        if not isinstance(value, enum_cls):
            # Raw values are acceptable as long as the enum can build a member.
            try:
                enum_cls(value)
            except ValueError:
                self.error(f"{value} is not a valid {enum_cls}")
            return
        return super().validate(value)

    def to_python(self, value):
        """Turn ``value`` into an enum member when possible, else return it as is."""
        value = super().to_python(value)
        if isinstance(value, self._enum_cls):
            return value
        try:
            return self._enum_cls(value)
        except ValueError:
            return value

    def __set__(self, instance, value):
        return super().__set__(instance, self.to_python(value))

    def to_mongo(self, value):
        """Store enum members by their ``.value``; other values pass through."""
        return value.value if isinstance(value, self._enum_cls) else value

    def prepare_query_value(self, op, value):
        if value is None:
            return value
        return super().prepare_query_value(op, self.to_mongo(value))
1667

1668

1669
class GridFSError(Exception):
    """Raised by :class:`GridFSProxy` when a write would overwrite an existing file."""
1671

1672

1673
class GridFSProxy:
    """Proxy object to handle writing and reading of files to and from GridFS.

    This is the value a :class:`FileField` exposes on a document. Attribute
    lookups the proxy does not define itself are forwarded to the object
    returned by :meth:`get` (the stored file, or ``None``).

    :param grid_id: id of the stored GridFS file, if any
    :param key: name of the owning field on ``instance``
    :param instance: owning document, notified through ``_mark_as_changed``
    :param db_alias: connection alias used to build the ``gridfs.GridFS``
    :param collection_name: collection name passed to ``gridfs.GridFS``
    """

    _fs = None

    def __init__(
        self,
        grid_id=None,
        key=None,
        instance=None,
        db_alias=DEFAULT_CONNECTION_NAME,
        collection_name="fs",
    ):
        self.grid_id = grid_id  # Store GridFS id for file
        self.key = key
        self.instance = instance
        self.db_alias = db_alias
        self.collection_name = collection_name
        self.newfile = None  # Used for partial writes
        self.gridout = None  # Cached result of ``fs.get(grid_id)``

    def __getattr__(self, name):
        # Only reached when normal lookup fails. The proxy's own attributes
        # are never forwarded: ``get()`` reads them, so forwarding a missing
        # one (e.g. before __dict__ is populated) would recurse indefinitely.
        attrs = (
            "_fs",
            "grid_id",
            "key",
            "instance",
            "db_alias",
            "collection_name",
            "newfile",
            "gridout",
        )
        if name in attrs:
            return self.__getattribute__(name)
        obj = self.get()
        if hasattr(obj, name):
            return getattr(obj, name)
        raise AttributeError(
            f"{type(self).__name__!r} object has no attribute {name!r}"
        )

    def __get__(self, instance, value):
        return self

    def __bool__(self):
        # Truthy only when the proxy references a stored file.
        return bool(self.grid_id)

    def __getstate__(self):
        # Drop the cached GridFS handle (rebuilt lazily by ``fs``). Work on a
        # copy so pickling/copying does not reset the handle of this proxy.
        state = self.__dict__.copy()
        state["_fs"] = None
        return state

    def __copy__(self):
        # Preserve the concrete proxy class (e.g. ImageGridFsProxy), otherwise
        # FileField.__get__ would not recognise the copy as its proxy_class.
        copied = type(self)()
        copied.__dict__.update(self.__getstate__())
        return copied

    def __deepcopy__(self, memo):
        return self.__copy__()

    def __repr__(self):
        return f"<{self.__class__.__name__}: {self.grid_id}>"

    def __str__(self):
        gridout = self.get()
        filename = gridout.filename if gridout else "<no file>"
        return f"<{self.__class__.__name__}: {filename} ({self.grid_id})>"

    def __eq__(self, other):
        if isinstance(other, GridFSProxy):
            return (
                (self.grid_id == other.grid_id)
                and (self.collection_name == other.collection_name)
                and (self.db_alias == other.db_alias)
            )
        else:
            return False

    def __ne__(self, other):
        return not self == other

    @property
    def fs(self):
        """Lazily created ``gridfs.GridFS`` for this proxy's alias/collection."""
        if not self._fs:
            self._fs = gridfs.GridFS(get_db(self.db_alias), self.collection_name)
        return self._fs

    def get(self, grid_id=None):
        """Return the stored file (cached after the first fetch), or ``None``.

        :param grid_id: if given, re-point the proxy at that file first.

        ``None`` is returned when no file is referenced or when fetching it
        fails (e.g. the file has been deleted).
        """
        if grid_id:
            if grid_id != self.grid_id:
                # The cached GridOut belongs to the previously referenced file.
                self.gridout = None
            self.grid_id = grid_id

        if self.grid_id is None:
            return None

        try:
            if self.gridout is None:
                self.gridout = self.fs.get(self.grid_id)
            return self.gridout
        except Exception:
            # File has been deleted (or could not be fetched)
            return None

    def new_file(self, **kwargs):
        """Open a new GridFS file for partial writes and reference it."""
        self.newfile = self.fs.new_file(**kwargs)
        self.grid_id = self.newfile._id
        self._mark_as_changed()

    def put(self, file_obj, **kwargs):
        """Store ``file_obj`` as a new GridFS file.

        :raises GridFSError: if the proxy already references a file.
        """
        if self.grid_id:
            raise GridFSError(
                "This document already has a file. Either delete "
                "it or call replace to overwrite it"
            )
        self.grid_id = self.fs.put(file_obj, **kwargs)
        self._mark_as_changed()

    def _prepare_newfile(self):
        """Make ``self.newfile`` ready for a partial write.

        Opens a new file when none is referenced. Refuses to write when the
        proxy references a stored file that was not opened via
        :meth:`new_file`, so that file is never silently orphaned.

        :raises GridFSError: if a stored file would be overwritten.
        """
        if self.grid_id:
            if not self.newfile:
                raise GridFSError(
                    "This document already has a file. Either "
                    "delete it or call replace to overwrite it"
                )
        else:
            self.new_file()

    def write(self, string):
        self._prepare_newfile()
        self.newfile.write(string)

    def writelines(self, lines):
        # Same guard as write(); new_file() already records the grid_id.
        self._prepare_newfile()
        self.newfile.writelines(lines)

    def read(self, size=-1):
        """Read from the stored file; ``None`` if there is no file.

        NOTE(review): on read errors this returns ``""`` (str) while successful
        reads return the file's data as returned by GridOut.read.
        """
        gridout = self.get()
        if gridout is None:
            return None
        else:
            try:
                return gridout.read(size)
            except Exception:
                return ""

    def delete(self):
        # Delete file from GridFS, FileField still remains
        self.fs.delete(self.grid_id)
        self.grid_id = None
        self.gridout = None
        self._mark_as_changed()

    def replace(self, file_obj, **kwargs):
        """Delete the current file, then store ``file_obj`` in its place."""
        self.delete()
        self.put(file_obj, **kwargs)

    def close(self):
        """Close the file opened for partial writes, if any."""
        if self.newfile:
            self.newfile.close()

    def _mark_as_changed(self):
        """Inform the instance that `self.key` has been changed"""
        if self.instance:
            self.instance._mark_as_changed(self.key)
1833

1834

1835
class FileField(BaseField):
    """A GridFS storage field.

    Only the GridFS file id is stored in the document. On the Python side the
    value is exposed as a :attr:`proxy_class` instance.

    :param db_alias: connection alias handed to the proxy for GridFS access
    :param collection_name: GridFS collection name handed to the proxy
    """

    proxy_class = GridFSProxy

    def __init__(
        self, db_alias=DEFAULT_CONNECTION_NAME, collection_name="fs", **kwargs
    ):
        super().__init__(**kwargs)
        self.collection_name = collection_name
        self.db_alias = db_alias

    def __get__(self, instance, owner):
        if instance is None:
            return self

        # Check if a file already exists for this model
        grid_file = instance._data.get(self.name)
        if not isinstance(grid_file, self.proxy_class):
            # Any other stored value (including None) becomes an empty proxy.
            grid_file = self.get_proxy_obj(key=self.name, instance=instance)
            instance._data[self.name] = grid_file

        if not grid_file.key:
            # Proxies built by to_python() carry no owner; attach it so that
            # writes can mark this field as changed.
            grid_file.key = self.name
            grid_file.instance = instance
        return grid_file

    def __set__(self, instance, value):
        key = self.name
        if (
            hasattr(value, "read") and not isinstance(value, GridFSProxy)
        ) or isinstance(value, (bytes, str)):
            # using "FileField() = file/string" notation
            grid_file = instance._data.get(self.name)
            # If a file already exists, delete it
            if grid_file:
                try:
                    grid_file.delete()
                except Exception:
                    # Best effort: a failed delete must not block the new upload.
                    pass

            # Create a new proxy object as we don't already have one
            instance._data[key] = self.get_proxy_obj(key=key, instance=instance)
            instance._data[key].put(value)
        else:
            instance._data[key] = value

        instance._mark_as_changed(key)

    def get_proxy_obj(self, key, instance, db_alias=None, collection_name=None):
        """Build a :attr:`proxy_class`, defaulting alias/collection to the field's."""
        if db_alias is None:
            db_alias = self.db_alias
        if collection_name is None:
            collection_name = self.collection_name

        return self.proxy_class(
            key=key,
            instance=instance,
            db_alias=db_alias,
            collection_name=collection_name,
        )

    def to_mongo(self, value):
        # Store the GridFS file id in MongoDB
        if isinstance(value, self.proxy_class) and value.grid_id is not None:
            return value.grid_id
        return None

    def to_python(self, value):
        """Wrap a stored file id in a proxy; ``None`` stays ``None``."""
        if value is not None:
            return self.proxy_class(
                value, collection_name=self.collection_name, db_alias=self.db_alias
            )
        return None

    def validate(self, value):
        # Reject non-proxy values first: reading ``value.grid_id`` on them
        # would raise AttributeError instead of a ValidationError.
        if not isinstance(value, (GridFSProxy, self.proxy_class)):
            self.error("FileField only accepts GridFSProxy values")
        if value.grid_id is not None:
            if not isinstance(value, self.proxy_class):
                self.error("FileField only accepts GridFSProxy values")
            if not isinstance(value.grid_id, ObjectId):
                self.error("Invalid GridFSProxy value")
1915

1916

1917
class ImageGridFsProxy(GridFSProxy):
    """Proxy for ImageField.

    :meth:`put` resizes the image and optionally stores a thumbnail,
    according to the owning field's ``size`` and ``thumbnail_size``.
    Partial writes (:meth:`write`, :meth:`writelines`) are not supported.
    """

    def put(self, file_obj, **kwargs):
        """Insert an image in the database, applying the field's ``size`` and
        ``thumbnail_size`` settings.

        :param file_obj: anything ``PIL.Image.open`` accepts
        :param kwargs: forwarded to GridFS; ``progressive=True`` additionally
            requests progressive encoding for JPEG images
        :raises ValidationError: if the data cannot be opened as an image
        """
        field = self.instance._fields[self.key]
        # Handle nested fields: use the wrapped FileField of a container field
        if hasattr(field, "field") and isinstance(field.field, FileField):
            field = field.field

        try:
            img = Image.open(file_obj)
            img_format = img.format
        except Exception as e:
            raise ValidationError("Invalid image: %s" % e) from e

        # Progressive encoding only for JPEGs, and only when the caller passes
        # an explicit boolean ``progressive=True``.
        requested = kwargs.get("progressive")
        progressive = (
            isinstance(requested, bool) and requested and img_format == "JPEG"
        )

        if field.size and (
            img.size[0] > field.size["width"] or img.size[1] > field.size["height"]
        ):
            size = field.size

            if size["force"]:
                # Resize and crop to exactly width x height.
                img = ImageOps.fit(img, (size["width"], size["height"]), LANCZOS)
            else:
                # Shrink in place, keeping the aspect ratio.
                img.thumbnail((size["width"], size["height"]), LANCZOS)

        thumbnail = None
        if field.thumbnail_size:
            size = field.thumbnail_size

            if size["force"]:
                thumbnail = ImageOps.fit(img, (size["width"], size["height"]), LANCZOS)
            else:
                thumbnail = img.copy()
                thumbnail.thumbnail((size["width"], size["height"]), LANCZOS)

        if thumbnail is not None:
            thumb_id = self._put_thumbnail(thumbnail, img_format, progressive)
        else:
            thumb_id = None

        w, h = img.size

        io = BytesIO()
        # Save with the format detected on open: resized images lose ``.format``.
        img.save(io, img_format, progressive=progressive)
        io.seek(0)

        return super().put(
            io, width=w, height=h, format=img_format, thumbnail_id=thumb_id, **kwargs
        )

    def delete(self, *args, **kwargs):
        """Delete the stored image and, if present, its thumbnail."""
        # deletes thumbnail
        out = self.get()
        if out and out.thumbnail_id:
            self.fs.delete(out.thumbnail_id)

        return super().delete()

    def _put_thumbnail(self, thumbnail, format, progressive, **kwargs):
        """Store ``thumbnail`` in GridFS and return its file id."""
        w, h = thumbnail.size

        io = BytesIO()
        thumbnail.save(io, format, progressive=progressive)
        io.seek(0)

        return self.fs.put(io, width=w, height=h, format=format, **kwargs)

    @property
    def size(self):
        """
        return a width, height of image
        """
        out = self.get()
        if out:
            return out.width, out.height

    @property
    def format(self):
        """
        return format of image
        ex: PNG, JPEG, GIF, etc
        """
        out = self.get()
        if out:
            return out.format

    @property
    def thumbnail(self):
        """
        return a gridfs.grid_file.GridOut
        representing a thumbnail of Image
        """
        out = self.get()
        if out and out.thumbnail_id:
            return self.fs.get(out.thumbnail_id)

    def write(self, *args, **kwargs):
        raise RuntimeError('Please use "put" method instead')

    def writelines(self, *args, **kwargs):
        raise RuntimeError('Please use "put" method instead')
2035

2036

2037
class ImproperlyConfigured(Exception):
    """Raised when a field requires a library that is unavailable (e.g. PIL)."""
2039

2040

2041
class ImageField(FileField):
    """An image file storage field.

    :param size: maximum size to store images, given as (width, height, force);
        larger images are resized automatically (ex: size=(800, 600, True))
    :param thumbnail_size: size of the generated thumbnail, given as
        (width, height, force)
    :raises ImproperlyConfigured: if PIL is not installed
    """

    proxy_class = ImageGridFsProxy

    def __init__(
        self, size=None, thumbnail_size=None, collection_name="images", **kwargs
    ):
        if not Image:
            raise ImproperlyConfigured("PIL library was not found")

        dims = ("width", "height", "force")

        def _to_spec(raw):
            # Only tuples/lists are understood; anything else disables the option.
            # Missing trailing entries are filled with None.
            if not isinstance(raw, (tuple, list)):
                return None
            return dict(itertools.zip_longest(dims, raw, fillvalue=None))

        self.size = _to_spec(size)
        self.thumbnail_size = _to_spec(thumbnail_size)

        super().__init__(collection_name=collection_name, **kwargs)
2068

2069

2070
class SequenceField(BaseField):
    """Provides a sequential counter see:
     https://docs.mongodb.com/manual/reference/method/ObjectId/#ObjectIDs-SequenceNumbers

    .. note::

             Although traditional databases often use increasing sequence
             numbers for primary keys. In MongoDB, the preferred approach is to
             use Object IDs instead.  The concept is that in a very large
             cluster of machines, it is easier to create an object ID than have
             global, uniformly increasing sequence numbers.

    :param collection_name:  Name of the counter collection (default 'mongoengine.counters')
    :param sequence_name: Name of the sequence in the collection (default 'ClassName.counter')
    :param value_decorator: Any callable to use as a counter (default int)

    Use any callable as `value_decorator` to transform calculated counter into
    any value suitable for your needs, e.g. string or hexadecimal
    representation of the default integer counter value.

    .. note::

        In case the counter is defined in the abstract document, it will be
        common to all inherited documents and the default sequence name will
        be the class name of the abstract document.
    """

    _auto_gen = True
    COLLECTION_NAME = "mongoengine.counters"
    VALUE_DECORATOR = int

    def __init__(
        self,
        collection_name=None,
        db_alias=None,
        sequence_name=None,
        value_decorator=None,
        *args,
        **kwargs,
    ):
        self.collection_name = collection_name or self.COLLECTION_NAME
        self.db_alias = db_alias or DEFAULT_CONNECTION_NAME
        self.sequence_name = sequence_name
        # A non-callable value_decorator silently falls back to the default.
        self.value_decorator = (
            value_decorator if callable(value_decorator) else self.VALUE_DECORATOR
        )
        super().__init__(*args, **kwargs)

    def _get_counter(self):
        """Return ``(collection, sequence_id)`` locating this field's counter document."""
        sequence_id = f"{self.get_sequence_name()}.{self.name}"
        collection = get_db(alias=self.db_alias)[self.collection_name]
        return collection, sequence_id

    def generate(self):
        """
        Generate and Increment the counter
        """
        collection, sequence_id = self._get_counter()
        counter = collection.find_one_and_update(
            filter={"_id": sequence_id},
            update={"$inc": {"next": 1}},
            return_document=ReturnDocument.AFTER,
            upsert=True,
        )
        return self.value_decorator(counter["next"])

    def set_next_value(self, value):
        """Helper method to set the next sequence value"""
        collection, sequence_id = self._get_counter()
        counter = collection.find_one_and_update(
            filter={"_id": sequence_id},
            update={"$set": {"next": value}},
            return_document=ReturnDocument.AFTER,
            upsert=True,
        )
        return self.value_decorator(counter["next"])

    def get_next_value(self):
        """Helper method to get the next value for previewing.

        .. warning:: There is no guarantee this will be the next value
        as it is only fixed on set.
        """
        collection, sequence_id = self._get_counter()
        data = collection.find_one({"_id": sequence_id})

        if data:
            return self.value_decorator(data["next"] + 1)

        return self.value_decorator(1)

    def get_sequence_name(self):
        """Return the explicit sequence name, or derive one from the owner document.

        Concrete documents use their collection name. Abstract or non-Document
        owners use their class name in snake_case.
        """
        if self.sequence_name:
            return self.sequence_name
        owner = self.owner_document
        if issubclass(owner, Document) and not owner._meta.get("abstract"):
            return owner._get_collection_name()
        else:
            return (
                "".join("_%s" % c if c.isupper() else c for c in owner._class_name)
                .strip("_")
                .lower()
            )

    def __get__(self, instance, owner):
        value = super().__get__(instance, owner)
        if value is None and instance._initialised:
            # First access on an initialised document allocates the next value.
            value = self.generate()
            instance._data[self.name] = value
            instance._mark_as_changed(self.name)

        return value

    def __set__(self, instance, value):
        if value is None and instance._initialised:
            value = self.generate()

        return super().__set__(instance, value)

    def prepare_query_value(self, op, value):
        """
        This method is overridden in order to convert the query value into to required
        type. We need to do this in order to be able to successfully compare query
        values passed as string, the base implementation returns the value as is.

        ``None`` is passed through unchanged (like the other fields in this
        module) instead of being fed to ``value_decorator`` (``int(None)``
        would raise TypeError).
        """
        if value is None:
            return value
        return self.value_decorator(value)

    def to_python(self, value):
        if value is None:
            value = self.generate()
        return value
2203

2204

2205
class UUIDField(BaseField):
    """A UUID field, stored as BSON binary (default) or as a string."""

    _binary = None

    def __init__(self, binary=True, **kwargs):
        """
        Store UUID data in the database

        :param binary: if False store as a string.
        """
        self._binary = binary
        super().__init__(**kwargs)

    def to_python(self, value):
        # Binary storage already round-trips as uuid.UUID; nothing to do.
        if self._binary:
            return value
        try:
            return uuid.UUID(value if isinstance(value, str) else str(value))
        except (ValueError, TypeError, AttributeError):
            # Unparseable values are handed back untouched.
            return value

    def to_mongo(self, value):
        if not self._binary:
            return str(value)
        if isinstance(value, str):
            return uuid.UUID(value)
        return value

    def prepare_query_value(self, op, value):
        return None if value is None else self.to_mongo(value)

    def validate(self, value):
        if isinstance(value, uuid.UUID):
            return
        text = value if isinstance(value, str) else str(value)
        try:
            uuid.UUID(text)
        except (ValueError, TypeError, AttributeError) as exc:
            self.error("Could not convert to UUID: %s" % exc)
2250

2251

2252
class GeoPointField(BaseField):
    """A list storing a longitude and latitude coordinate.

    .. note:: this models a generic point in a 2D plane, the legacy way of
        representing a geo point. It supports 2d indexes but not "2dsphere"
        indexes in MongoDB > 2.4, which are more natural for geospatial points.
        See :ref:`geospatial-indexes`
    """

    _geo_index = pymongo.GEO2D

    def validate(self, value):
        """Check that ``value`` is a two-element list/tuple of numbers."""
        if not isinstance(value, (list, tuple)):
            self.error("GeoPointField can only accept tuples or lists of (x, y)")

        if len(value) != 2:
            self.error("Value (%s) must be a two-dimensional point" % repr(value))
        elif not all(isinstance(coord, (float, int)) for coord in value):
            self.error("Both values (%s) in point must be float or int" % repr(value))
2274

2275

2276
class PointField(GeoJsonBaseField):
    """A GeoJSON field holding a single longitude/latitude coordinate pair.

    Stored as:

    .. code-block:: js

        {'type' : 'Point' ,
         'coordinates' : [x, y]}

    The value may be assigned as the full GeoJSON dict or as a plain list.

    Requires mongodb >= 2.4
    """

    _type = "Point"
2293

2294

2295
class LineStringField(GeoJsonBaseField):
    """A GeoJSON field holding a line made of longitude/latitude points.

    Stored as:

    .. code-block:: js

        {'type' : 'LineString' ,
         'coordinates' : [[x1, y1], [x2, y2] ... [xn, yn]]}

    The value may be assigned as the full GeoJSON dict or as a list of points.

    Requires mongodb >= 2.4
    """

    _type = "LineString"
2311

2312

2313
class PolygonField(GeoJsonBaseField):
    """A GeoJSON field storing a polygon of longitude and latitude coordinates.

    The data is represented as:

    .. code-block:: js

        {'type' : 'Polygon' ,
         'coordinates' : [[[x1, y1], [x2, y2] ... [xn, yn]],
                          [[x1, y1], [x2, y2] ... [xn, yn]]]}

    You can either pass a dict with the full information or a list
    of LineStrings. The first LineString is the outer boundary and the
    remaining ones are holes.

    Requires mongodb >= 2.4
    """

    _type = "Polygon"
2332

2333

2334
class MultiPointField(GeoJsonBaseField):
    """A GeoJSON field holding a collection of Points.

    Stored as:

    .. code-block:: js

        {'type' : 'MultiPoint' ,
         'coordinates' : [[x1, y1], [x2, y2]]}

    The value may be assigned as the full GeoJSON dict or as a plain list.

    Requires mongodb >= 2.6
    """

    _type = "MultiPoint"
2351

2352

2353
class MultiLineStringField(GeoJsonBaseField):
    """A GeoJSON field storing a list of LineStrings.

    The data is represented as:

    .. code-block:: js

        {'type' : 'MultiLineString' ,
         'coordinates' : [[[x1, y1], [x2, y2] ... [xn, yn]],
                          [[x1, y1], [x2, y2] ... [xn, yn]]]}

    You can either pass a dict with the full information or a list of
    LineStrings (each a list of points).

    Requires mongodb >= 2.6
    """

    _type = "MultiLineString"
2370

2371

2372
class MultiPolygonField(GeoJsonBaseField):
    """A GeoJSON field storing a list of Polygons.

    The data is represented as:

    .. code-block:: js

        {'type' : 'MultiPolygon' ,
         'coordinates' : [[
               [[x1, y1], [x2, y2] ... [xn, yn]],
               [[x1, y1], [x2, y2] ... [xn, yn]]
           ], [
               [[x1, y1], [x2, y2] ... [xn, yn]],
               [[x1, y1], [x2, y2] ... [xn, yn]]
           ]]
        }

    You can either pass a dict with the full information or a list
    of Polygons.

    Requires mongodb >= 2.6
    """

    _type = "MultiPolygon"
2396

2397

2398
class LazyReferenceField(BaseField):
    """A really lazy reference to a document.

    Unlike the :class:`~mongoengine.fields.ReferenceField` it will
    **not** be automatically (lazily) dereferenced on access.
    Instead, access will return a :class:`~mongoengine.base.LazyReference` class
    instance, allowing access to `pk` or manual dereference by using
    ``fetch()`` method.
    """

    def __init__(
        self,
        document_type,
        passthrough=False,
        dbref=False,
        reverse_delete_rule=DO_NOTHING,
        **kwargs,
    ):
        """Initialises the Reference Field.

        :param document_type: The referenced :class:`~mongoengine.Document`
          subclass, or its name as a string. A string is resolved on first
          use of :attr:`document_type`; ``RECURSIVE_REFERENCE_CONSTANT``
          refers to the owner document itself.
        :param passthrough: When trying to access unknown fields, the
          :class:`~mongoengine.base.datastructure.LazyReference` instance will
          automatically call `fetch()` and try to retrieve the field on the fetched
          document. Note this only work getting field (not setting or deleting).
        :param dbref:  Store the reference as :class:`~pymongo.dbref.DBRef`
          or as the :class:`~pymongo.objectid.ObjectId`.id .
        :param reverse_delete_rule: Determines what to do when the referring
          object is deleted
        :raises ValidationError: if ``document_type`` is neither a string nor
          a ``Document`` subclass.
        """
        # XXX ValidationError raised outside of the "validate" method.
        # ``issubclass`` raises TypeError for non-class arguments (e.g. a
        # document *instance*), so check ``isclass`` first to always report
        # the intended ValidationError instead of a TypeError.
        if not isinstance(document_type, str) and not (
            isclass(document_type) and issubclass(document_type, Document)
        ):
            self.error(
                "Argument to LazyReferenceField constructor must be a "
                "document class or a string"
            )

        self.dbref = dbref
        self.passthrough = passthrough
        self.document_type_obj = document_type
        self.reverse_delete_rule = reverse_delete_rule
        super().__init__(**kwargs)

    @property
    def document_type(self):
        """Return the referenced document class.

        A string given at construction time is resolved here on first access,
        either to the owner document (``RECURSIVE_REFERENCE_CONSTANT``) or via
        ``get_document``. The resolved class is cached back onto
        ``document_type_obj``.
        """
        if isinstance(self.document_type_obj, str):
            if self.document_type_obj == RECURSIVE_REFERENCE_CONSTANT:
                self.document_type_obj = self.owner_document
            else:
                self.document_type_obj = get_document(self.document_type_obj)
        return self.document_type_obj

    def build_lazyref(self, value):
        """Normalise ``value`` into a LazyReference using this field's settings.

        Accepts an existing LazyReference, an instance of the referenced
        document, a DBRef, or a bare primary key. ``None`` is returned as-is.
        """
        if isinstance(value, LazyReference):
            # Re-wrap only when the passthrough behaviour differs from ours.
            if value.passthrough != self.passthrough:
                value = LazyReference(
                    value.document_type, value.pk, passthrough=self.passthrough
                )
        elif value is not None:
            if isinstance(value, self.document_type):
                value = LazyReference(
                    self.document_type, value.pk, passthrough=self.passthrough
                )
            elif isinstance(value, DBRef):
                value = LazyReference(
                    self.document_type, value.id, passthrough=self.passthrough
                )
            else:
                # value is the primary key of the referenced document
                value = LazyReference(
                    self.document_type, value, passthrough=self.passthrough
                )
        return value

    def __get__(self, instance, owner):
        """Descriptor to allow lazy dereferencing."""
        if instance is None:
            # Document class being used rather than a document object
            return self

        # Cache the normalised LazyReference so later accesses reuse it.
        value = self.build_lazyref(instance._data.get(self.name))
        if value:
            instance._data[self.name] = value

        return super().__get__(instance, owner)

    def to_mongo(self, value):
        """Convert ``value`` to the stored form: a primary key, or a DBRef.

        The primary key is passed through the referenced document's id field
        ``to_mongo`` so it is stored with the id field's own representation.
        """
        if isinstance(value, LazyReference):
            pk = value.pk
        elif isinstance(value, self.document_type):
            pk = value.pk
        elif isinstance(value, DBRef):
            pk = value.id
        else:
            # value is the primary key of the referenced document
            pk = value
        id_field_name = self.document_type._meta["id_field"]
        id_field = self.document_type._fields[id_field_name]
        pk = id_field.to_mongo(pk)
        if self.dbref:
            return DBRef(self.document_type._get_collection_name(), pk)
        else:
            return pk

    def to_python(self, value):
        """Convert a MongoDB-compatible type to a Python type.

        A raw primary key is wrapped into a DBRef and then a LazyReference;
        DBRef / Document / EmbeddedDocument values are returned unchanged
        (``__get__`` normalises them on access).
        """
        if not isinstance(value, (DBRef, Document, EmbeddedDocument)):
            collection = self.document_type._get_collection_name()
            value = DBRef(collection, self.document_type.id.to_python(value))
            value = self.build_lazyref(value)
        return value

    def validate(self, value):
        """Ensure ``value`` points to a saved document of the right type.

        :raises ValidationError: on a reference to another collection, a value
          that is not a valid primary key, or an unsaved document (``pk`` is None).
        """
        if isinstance(value, LazyReference):
            if value.collection != self.document_type._get_collection_name():
                self.error("Reference must be on a `%s` document." % self.document_type)
            pk = value.pk
        elif isinstance(value, self.document_type):
            pk = value.pk
        elif isinstance(value, DBRef):
            collection = self.document_type._get_collection_name()
            if value.collection != collection:
                self.error("DBRef on bad collection (must be on `%s`)" % collection)
            pk = value.id
        else:
            # value is the primary key of the referenced document
            id_field_name = self.document_type._meta["id_field"]
            id_field = getattr(self.document_type, id_field_name)
            pk = value
            try:
                id_field.validate(pk)
            except ValidationError:
                self.error(
                    "value should be `{0}` document, LazyReference or DBRef on `{0}` "
                    "or `{0}`'s primary key (i.e. `{1}`)".format(
                        self.document_type.__name__, type(id_field).__name__
                    )
                )

        if pk is None:
            self.error(
                "You can only reference documents once they have been "
                "saved to the database"
            )

    def prepare_query_value(self, op, value):
        """Validate via the base class, then convert ``value`` with ``to_mongo``."""
        if value is None:
            return None
        super().prepare_query_value(op, value)
        return self.to_mongo(value)

    def lookup_member(self, member_name):
        """Return the referenced document's field named ``member_name``, or None."""
        return self.document_type._fields.get(member_name)
×
2552

2553

2554
class GenericLazyReferenceField(GenericReferenceField):
    """A lazy reference to *any* :class:`~mongoengine.document.Document` subclass.

    In contrast to :class:`~mongoengine.fields.GenericReferenceField`, reading
    this field does **not** dereference it. The value comes back as a
    :class:`~mongoengine.base.LazyReference` instance, which exposes `pk` and
    lets the caller dereference explicitly via its ``fetch()`` method.

    .. note ::
        * Any documents used as a generic reference must be registered in the
          document registry.  Importing the model will automatically register
          it.

        * You can use the choices param to limit the acceptable Document types
    """

    def __init__(self, *args, **kwargs):
        # ``passthrough`` belongs to this class only; strip it before the
        # remaining kwargs reach the parent constructor.
        passthrough = kwargs.pop("passthrough", False)
        self.passthrough = passthrough
        super().__init__(*args, **kwargs)

    def _validate_choices(self, value):
        # A LazyReference is checked against ``choices`` by its class name.
        if isinstance(value, LazyReference):
            candidate = value.document_type._class_name
        else:
            candidate = value
        super()._validate_choices(candidate)

    def build_lazyref(self, value):
        """Turn a stored/assigned value into a LazyReference where possible.

        Handles existing LazyReferences (re-wrapped only when their passthrough
        setting differs), ``{"_cls": ..., "_ref": DBRef}`` mappings and
        Document instances. Anything else, including ``None``, is returned
        untouched.
        """
        if value is None:
            return value

        if isinstance(value, LazyReference):
            if value.passthrough != self.passthrough:
                return LazyReference(
                    value.document_type, value.pk, passthrough=self.passthrough
                )
            return value

        if isinstance(value, (dict, SON)):
            target_cls = get_document(value["_cls"])
            return LazyReference(
                target_cls, value["_ref"].id, passthrough=self.passthrough
            )

        if isinstance(value, Document):
            return LazyReference(type(value), value.pk, passthrough=self.passthrough)

        return value

    def __get__(self, instance, owner):
        if instance is None:
            return self

        lazy_value = self.build_lazyref(instance._data.get(self.name))
        if lazy_value:
            # Store the normalised form so subsequent reads reuse it.
            instance._data[self.name] = lazy_value

        return super().__get__(instance, owner)

    def validate(self, value):
        # An unsaved target has no primary key and cannot be referenced yet.
        is_unsaved_ref = isinstance(value, LazyReference) and value.pk is None
        if is_unsaved_ref:
            self.error(
                "You can only reference documents once they have been saved to the database"
            )
        return super().validate(value)

    def to_mongo(self, document):
        if document is None:
            return None

        if not isinstance(document, LazyReference):
            return super().to_mongo(document)

        # Same ``{_cls, _ref}`` layout that the parent class produces.
        target_cls = document.document_type
        reference = DBRef(target_cls._get_collection_name(), document.pk)
        return SON([("_cls", target_cls._class_name), ("_ref", reference)])
2✔
2634

2635

2636
class Decimal128Field(BaseField):
    """
    Exact-precision decimal field backed by BSON's 128-bit decimal type.

    Python code sees :class:`decimal.Decimal` values while MongoDB stores a
    `bson.Decimal128`; suited to money, scientific figures and other data where
    binary floating-point rounding is unacceptable.
    """

    # Shared Decimal128-compatible context used to coerce non-Decimal inputs.
    DECIMAL_CONTEXT = create_decimal128_context()

    def __init__(self, min_value=None, max_value=None, **kwargs):
        self.min_value, self.max_value = min_value, max_value
        super().__init__(**kwargs)

    def to_mongo(self, value):
        """Return ``value`` as a ``Decimal128`` (``None`` passes through)."""
        if value is None or isinstance(value, Decimal128):
            return value
        if isinstance(value, decimal.Decimal):
            return Decimal128(value)
        # Coerce other inputs through a local copy of the Decimal128 context.
        with decimal.localcontext(self.DECIMAL_CONTEXT) as local_ctx:
            coerced = local_ctx.create_decimal(value)
        return Decimal128(coerced)

    def to_python(self, value):
        """Return ``value`` as a ``decimal.Decimal`` (``None`` passes through)."""
        return None if value is None else self.to_mongo(value).to_decimal()

    def validate(self, value):
        """Check convertibility to ``Decimal128`` and the optional bounds."""
        if isinstance(value, Decimal128):
            candidate = value
        else:
            try:
                candidate = Decimal128(value)
            except (TypeError, ValueError, decimal.InvalidOperation) as exc:
                self.error("Could not convert value to Decimal128: %s" % exc)

        lower, upper = self.min_value, self.max_value
        if lower is not None and candidate.to_decimal() < lower:
            self.error("Decimal value is too small")

        if upper is not None and candidate.to_decimal() > upper:
            self.error("Decimal value is too large")

    def prepare_query_value(self, op, value):
        mongo_value = self.to_mongo(value)
        return super().prepare_query_value(op, mongo_value)
2✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc