• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mindflayer / python-mocket / 11880859300

17 Nov 2024 05:32PM UTC coverage: 99.529% (-0.002%) from 99.531%
11880859300

push

github

web-flow
Merge pull request #262 from betaboon/refactor-absolute-imports-and-remove-compat

Refactor absolute imports and remove compat

28 of 28 new or added lines in 8 files covered. (100.0%)

4 existing lines in 1 file now uncovered.

845 of 849 relevant lines covered (99.53%)

6.94 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

99.11
/mocket/mocket.py
1
import collections
7✔
2
import collections.abc as collections_abc
7✔
3
import contextlib
7✔
4
import errno
7✔
5
import hashlib
7✔
6
import itertools
7✔
7
import json
7✔
8
import os
7✔
9
import select
7✔
10
import socket
7✔
11
import ssl
7✔
12
from datetime import datetime, timedelta
7✔
13
from json.decoder import JSONDecodeError
7✔
14
from typing import Optional, Tuple
7✔
15

16
import urllib3
7✔
17
from urllib3.connection import match_hostname as urllib3_match_hostname
7✔
18
from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket
7✔
19

20
try:
7✔
21
    from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
7✔
22
except ImportError:
23
    urllib3_wrap_socket = None
24

25

26
from mocket.compat import decode_from_bytes, encode_to_bytes
7✔
27
from mocket.utils import (
7✔
28
    MocketMode,
29
    MocketSocketCore,
30
    get_mocketize,
31
    hexdump,
32
    hexload,
33
)
34

35
# Hash constructor used by _hash_request() to fingerprint recorded requests.
# The ``xxhash`` package is tried first, then ``xxhash_cffi``; if neither is
# importable ``xxh32`` stays ``None`` and ``hashlib.md5`` is used instead.
xxh32 = None
try:
    from xxhash import xxh32
except ImportError:  # pragma: no cover
    with contextlib.suppress(ImportError):
        from xxhash_cffi import xxh32
hasher = xxh32 or hashlib.md5

# When urllib3's pyOpenSSL integration can be imported, Mocket.enable() takes
# it out (extract_from_urllib3) and Mocket.disable() puts it back
# (inject_into_urllib3).
try:  # pragma: no cover
    from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3

    pyopenssl_override = True
except ImportError:
    pyopenssl_override = False

# References to the genuine implementations, captured at import time before
# Mocket.enable() monkey-patches the modules. Mocket.disable() uses them to
# restore the originals, and MocketSocket.true_sendall() uses some of them to
# reach real servers when no entry matches.
true_socket = socket.socket
true_create_connection = socket.create_connection
true_gethostbyname = socket.gethostbyname
true_gethostname = socket.gethostname
true_getaddrinfo = socket.getaddrinfo
true_socketpair = socket.socketpair
true_ssl_wrap_socket = getattr(
    ssl, "wrap_socket", None
)  # from Py3.12 it's only under SSLContext
true_ssl_socket = ssl.SSLSocket
true_ssl_context = ssl.SSLContext
true_inet_pton = socket.inet_pton
# May be None: urllib3 versions without urllib3.util.ssl_.wrap_socket.
true_urllib3_wrap_socket = urllib3_wrap_socket
true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket
true_urllib3_match_hostname = urllib3_match_hostname
65

66

67
class SuperFakeSSLContext:
    """Base of the fake ``ssl.SSLContext``: attributes that ignore assignment.

    Client code configures a context by assigning attributes such as
    ``options`` or ``verify_mode``. Under Mocket those settings are
    irrelevant, so each one is a descriptor that silently discards writes.
    """

    class FakeSetter(int):
        """An ``int`` (value 0) acting as a data descriptor with a no-op setter.

        It defines ``__set__`` but no ``__get__``. Reading the attribute
        therefore yields this object itself (equal to 0), and assignments are
        swallowed.
        """

        def __set__(self, *args):
            pass

    minimum_version = FakeSetter()
    options = FakeSetter()
    verify_mode = FakeSetter()
    verify_flags = FakeSetter()


class FakeSSLContext(SuperFakeSSLContext):
    """Replacement for ``ssl.SSLContext`` installed by ``Mocket.enable()``.

    - Configuration methods named in ``DUMMY_METHODS`` accept anything and do
      nothing.
    - Hostname checking is permanently disabled.
    - Wrapping returns Mocket sockets instead of performing a TLS handshake.
    """

    # Context methods callers commonly invoke; each instance binds them to a
    # function that accepts any arguments and returns None.
    DUMMY_METHODS = (
        "load_default_certs",
        "load_verify_locations",
        "set_alpn_protocols",
        "set_ciphers",
        "set_default_verify_paths",
    )
    sock = None
    post_handshake_auth = None
    _check_hostname = False

    @property
    def check_hostname(self):
        """Always ``False``: Mocket never verifies hostnames."""
        return self._check_hostname

    @check_hostname.setter
    def check_hostname(self, _):
        # The requested value is ignored on purpose; checks stay disabled.
        self._check_hostname = False

    def __init__(self, *args, **kwargs):
        # Arguments accepted by ssl.SSLContext() (e.g. protocol) are ignored.
        self._set_dummy_methods()

    def _set_dummy_methods(self):
        """Bind every name in ``DUMMY_METHODS`` to a no-op on this instance."""

        def dummy_method(*args, **kwargs):
            pass

        for m in self.DUMMY_METHODS:
            setattr(self, m, dummy_method)

    @staticmethod
    def wrap_socket(sock, *args, **kwargs):
        """Mark *sock* as secure and return it unchanged.

        The keyword arguments are stored on ``sock.kwargs``. For a
        MocketSocket they are later forwarded to urllib3's
        ``ssl_wrap_socket`` if the request has to reach a real server.
        """
        sock.kwargs = kwargs
        sock._secure_socket = True
        return sock

    @staticmethod
    def wrap_bio(
        incoming, outcoming, server_side=False, server_hostname=None, *args, **kwargs
    ):
        """Return a new MocketSocket standing in for an ``ssl.SSLObject``.

        Mirrors ``ssl.SSLContext.wrap_bio(incoming, outgoing,
        server_side=False, server_hostname=None, session=None)``:

        - *server_hostname* may be given positionally or by keyword and
          defaults to ``None``; it becomes the socket's host.
        - The BIOs, *server_side* and any other arguments are ignored.
        """
        ssl_obj = MocketSocket()
        ssl_obj._host = server_hostname
        return ssl_obj
121

122

123
def create_connection(address, timeout=None, source_address=None):
    """Replacement for ``socket.create_connection`` installed by ``Mocket.enable()``.

    Opens a TCP socket through ``socket.socket`` (a MocketSocket while Mocket
    is enabled), applies *timeout*, optionally binds to *source_address*, then
    connects to *address*.

    :param address: ``(host, port)`` to connect to.
    :param timeout: seconds (``0`` means non-blocking). ``None`` or the stdlib
        sentinel ``socket._GLOBAL_DEFAULT_TIMEOUT`` leave the default untouched.
    :param source_address: optional ``(host, port)`` to bind before connecting.
    :return: the connected socket.
    :raises OSError: if connecting fails; the socket is closed first.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
    # http.client passes the private sentinel when no timeout was requested.
    # It is not a number and must never reach settimeout().
    default_timeout = getattr(socket, "_GLOBAL_DEFAULT_TIMEOUT", None)
    try:
        if timeout is not None and timeout is not default_timeout:
            s.settimeout(timeout)
        if source_address:
            s.bind(source_address)
        s.connect(address)
    except BaseException:
        # Do not leak the descriptor of a socket we failed to set up.
        s.close()
        raise
    return s
129

130

131
def socketpair(*args, **kwargs):
    """Create a genuine connected socket pair, bypassing Mocket's patching.

    The asyncio event loop (used by fastapi and similar services) needs a
    working ``socketpair()``. The C-level ``_socket.socketpair`` is called
    directly, so the patched ``socket.socket`` is never involved.
    """
    from _socket import socketpair as _native_socketpair

    return _native_socketpair(*args, **kwargs)
136

137

138
def _hash_request(h, req):
    """Return the hex digest identifying the request text *req*.

    The text is split on CRLF, the lines are sorted and concatenated without
    a separator, encoded to bytes and fed to the hash constructor *h*. Because
    of the sort, the same lines in a different order give the same digest.
    """
    normalized = "".join(sorted(req.split("\r\n")))
    digest = h(encode_to_bytes(normalized))
    return digest.hexdigest()
140

141

142
class MocketSocket:
    """Fake socket that ``Mocket.enable()`` installs in place of ``socket.socket``.

    Outgoing data is matched against the registered Mocket entries. A matching
    entry's response is written to an in-memory buffer (``self.io``), which
    the read/recv methods consume. When no entry matches, ``true_sendall()``
    forwards the request to a real socket, optionally replaying or recording
    the exchange in a JSON file.
    """

    timeout = None
    _fd = None
    family = None
    type = None
    proto = None
    _host = None
    _port = None
    _address = None
    cipher = lambda s: ("ADH", "AES256", "SHA")
    compression = lambda s: ssl.OP_NO_COMPRESSION
    _mode = None
    _bufsize = None
    _secure_socket = False
    _did_handshake = False
    _sent_non_empty_bytes = False
    _io = None

    def __init__(
        self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs
    ):
        # A real socket is created up front so unmatched traffic can be
        # forwarded by true_sendall().
        self.true_socket = true_socket(family, type, proto)
        self._buflen = 65536
        self._entry = None
        self.family = int(family)
        self.type = int(type)
        self.proto = int(proto)
        self._truesocket_recording_dir = None
        self.kwargs = kwargs

    def __str__(self):
        return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})"

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    @property
    def io(self):
        """Lazily created in-memory buffer holding the pending mocked response."""
        if self._io is None:
            self._io = MocketSocketCore((self._host, self._port))
        return self._io

    def fileno(self):
        """Return the read end of a pipe shared by all sockets for this address.

        The pipe is created on first use and stored with ``Mocket.set_pair``.
        ``Mocket.reset()`` closes it.
        """
        address = (self._host, self._port)
        r_fd, _ = Mocket.get_pair(address)
        if not r_fd:
            r_fd, w_fd = os.pipe()
            Mocket.set_pair(address, (r_fd, w_fd))
        return r_fd

    def gettimeout(self):
        return self.timeout

    def setsockopt(self, family, type, proto):
        # NOTE(review): socket.setsockopt() takes (level, optname, value), yet
        # the values are stored as family/type/proto here. Kept unchanged for
        # backward compatibility; confirm before altering.
        self.family = family
        self.type = type
        self.proto = proto

        if self.true_socket:
            self.true_socket.setsockopt(family, type, proto)

    def settimeout(self, timeout):
        self.timeout = timeout

    @staticmethod
    def getsockopt(level, optname, buflen=None):
        """Report ``SOCK_STREAM`` for any option that is queried."""
        return socket.SOCK_STREAM

    def do_handshake(self):
        # From now on read() raises SSLWantReadError until data has arrived.
        self._did_handshake = True

    def getpeername(self):
        return self._address

    def setblocking(self, block):
        self.settimeout(None if block else 0.0)

    def getblocking(self):
        # NOTE(review): the stdlib defines this as ``gettimeout() != 0``; here
        # any non-None timeout counts as non-blocking. Kept as-is.
        return self.gettimeout() is None

    def getsockname(self):
        return socket.gethostbyname(self._address[0]), self._address[1]

    def getpeercert(self, *args, **kwargs):
        """Return a fake, always-valid peer certificate for the connected host.

        If host or port is not known yet, the address last passed to
        ``connect()`` (stored in ``Mocket._address``) is adopted.

        The dict mimics ``ssl.SSLSocket.getpeercert()``:

        - wildcard ``subjectAltName`` and ``subject`` entries for the host;
        - ``notAfter`` about a year ahead, in the ``"%b %d %H:%M:%S %Y GMT"``
          form that ``ssl.cert_time_to_seconds()`` parses.
        """
        from datetime import timezone

        if not (self._host and self._port):
            self._address = self._host, self._port = Mocket._address

        # Computed in UTC so that the "GMT" suffix is truthful.
        not_after = datetime.now(timezone.utc) + timedelta(days=30 * 12)
        return {
            "notAfter": not_after.strftime("%b %d %H:%M:%S %Y GMT"),
            "subjectAltName": (
                ("DNS", f"*.{self._host}"),
                ("DNS", self._host),
                ("DNS", "*"),
            ),
            "subject": (
                (("organizationName", f"*.{self._host}"),),
                (("organizationalUnitName", "Domain Control Validated"),),
                (("commonName", f"*.{self._host}"),),
            ),
        }

    def unwrap(self):
        return self

    def write(self, data):
        return self.send(encode_to_bytes(data))

    def connect(self, address):
        # Also remembered globally so Mocket.get_entry()/getpeercert() can
        # fall back to it when a socket has no address of its own.
        self._address = self._host, self._port = address
        Mocket._address = address

    def makefile(self, mode="r", bufsize=-1):
        self._mode = mode
        self._bufsize = bufsize
        return self.io

    def get_entry(self, data):
        return Mocket.get_entry(self._host, self._port, data)

    def sendall(self, data, entry=None, *args, **kwargs):
        """Send *data* to the matching entry, or to the real server if none.

        A non-None response replaces the content of ``self.io`` and rewinds
        it, ready to be read. An entry whose ``collect()`` returns ``False``
        produces no response.
        """
        if entry is None:
            entry = self.get_entry(data)

        if entry:
            consume_response = entry.collect(data)
            response = entry.get_response() if consume_response is not False else None
        else:
            response = self.true_sendall(data, *args, **kwargs)

        if response is not None:
            self.io.seek(0)
            self.io.write(response)
            self.io.truncate()
            self.io.seek(0)

    def read(self, buffersize):
        """Read up to *buffersize* bytes of the pending response.

        :raises ssl.SSLWantReadError: after ``do_handshake()``, until some
            non-empty data has been read at least once.
        """
        rv = self.io.read(buffersize)
        if rv:
            self._sent_non_empty_bytes = True
        if self._did_handshake and not self._sent_non_empty_bytes:
            raise ssl.SSLWantReadError("The operation did not complete (read)")
        return rv

    def recv_into(self, buffer, buffersize=None, flags=None):
        """Read pending response bytes into *buffer* and return the count.

        *buffer* may be a file-like object (anything with ``write``), which
        receives up to *buffersize* bytes. It may also be a writable buffer
        such as a ``bytearray`` or ``memoryview``. For the latter, a missing
        or zero *buffersize* means "at most the size of *buffer*", as with
        ``socket.recv_into``.
        """
        if hasattr(buffer, "write"):
            return buffer.write(self.read(buffersize))
        # buffer is a memoryview / bytearray-like object
        if not buffersize:
            with memoryview(buffer) as view:
                buffersize = view.nbytes
        data = self.read(buffersize)
        if data:
            buffer[: len(data)] = data
        return len(data)

    def recv(self, buffersize, flags=None):
        """Return up to *buffersize* bytes.

        Data comes from this address's pipe if one was created by
        ``fileno()``, otherwise from the pending response.

        :raises BlockingIOError: (errno ``EWOULDBLOCK``) when no data is
            pending.
        """
        r_fd, _ = Mocket.get_pair((self._host, self._port))
        if r_fd:
            return os.read(r_fd, buffersize)
        data = self.read(buffersize)
        if data:
            return data
        # used by Redis mock
        exc = BlockingIOError()
        exc.errno = errno.EWOULDBLOCK
        exc.args = (0,)
        raise exc

    def true_sendall(self, data, *args, **kwargs):
        """Obtain a response for *data* from a recording or from the real server.

        1. In strict mode, addresses that are not allowed are rejected via
           ``MocketMode.raise_not_allowed()``.
        2. With a recording directory, ``<dir>/<namespace>.json`` is searched
           under host, port and request hash. The md5 fallback supports
           recordings made before xxhash was used.
        3. Otherwise the request is sent through the real socket. It is
           TLS-wrapped via urllib3 when this socket was marked secure. Data is
           read until the socket is idle for 0.1 s or reaches EOF, and the
           exchange is recorded when a directory is set.

        :return: the raw response bytes.
        """
        if not MocketMode().is_allowed((self._host, self._port)):
            MocketMode.raise_not_allowed()

        req = decode_from_bytes(data)
        # make request unique again
        req_signature = _hash_request(hasher, req)
        # port should be always a string
        port = str(self._port)

        # prepare responses dictionary
        responses = {}

        if Mocket.get_truesocket_recording_dir():
            path = os.path.join(
                Mocket.get_truesocket_recording_dir(), Mocket.get_namespace() + ".json"
            )
            # check if there's already a recorded session dumped to a JSON file
            try:
                with open(path) as f:
                    responses = json.load(f)
            # if not, create a new dictionary
            except (FileNotFoundError, JSONDecodeError):
                pass

        try:
            try:
                response_dict = responses[self._host][port][req_signature]
            except KeyError:
                if hasher is not hashlib.md5:
                    # Fallback for backwards compatibility
                    req_signature = _hash_request(hashlib.md5, req)
                    response_dict = responses[self._host][port][req_signature]
                else:
                    raise
        except KeyError:
            # preventing next KeyError exceptions
            responses.setdefault(self._host, {})
            responses[self._host].setdefault(port, {})
            responses[self._host][port].setdefault(req_signature, {})
            response_dict = responses[self._host][port][req_signature]

        # try to get the response from the dictionary
        try:
            encoded_response = hexload(response_dict["response"])
        # if not available, call the real sendall
        except KeyError:
            host, port = self._host, self._port
            host = true_gethostbyname(host)

            if isinstance(self.true_socket, true_socket) and self._secure_socket:
                self.true_socket = true_urllib3_ssl_wrap_socket(
                    self.true_socket,
                    **self.kwargs,
                )

            with contextlib.suppress(OSError, ValueError):
                # already connected
                self.true_socket.connect((host, port))
            self.true_socket.sendall(data, *args, **kwargs)
            encoded_response = b""
            # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12
            while True:
                more_to_read = select.select([self.true_socket], [], [], 0.1)[0]
                if not more_to_read and encoded_response:
                    break
                new_content = self.true_socket.recv(self._buflen)
                if not new_content:
                    break
                encoded_response += new_content

            # dump the resulting dictionary to a JSON file
            if Mocket.get_truesocket_recording_dir():
                # update the dictionary with request and response lines
                response_dict["request"] = req
                response_dict["response"] = hexdump(encoded_response)

                with open(path, mode="w") as f:
                    f.write(
                        decode_from_bytes(
                            json.dumps(responses, indent=4, sort_keys=True)
                        )
                    )

        # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO
        return encoded_response

    def send(self, data, *args, **kwargs):  # pragma: no cover
        """Send *data* and return ``len(data)``.

        The data goes through a full ``sendall()`` when there is no matching
        entry or the entry differs from the previous send's entry. Otherwise,
        for consecutive sends to the same entry, it is appended to the last
        collected request (when that request supports ``add_data``).
        """
        entry = self.get_entry(data)
        if not entry or self._entry != entry:
            kwargs["entry"] = entry
            self.sendall(data, *args, **kwargs)
        else:
            req = Mocket.last_request()
            if hasattr(req, "add_data"):
                req.add_data(data)
        self._entry = entry
        return len(data)

    def close(self):
        """Close the underlying real socket (if still open) and drop the fd."""
        if self.true_socket and not self.true_socket._closed:
            self.true_socket.close()
        self._fd = None

    def __getattr__(self, name):
        """Do nothing catchall function, for methods like shutdown()"""

        def do_nothing(*args, **kwargs):
            pass

        return do_nothing
423

424

425
class Mocket:
    """Process-wide registry and on/off switch for Mocket.

    It holds:

    - the registered entries, keyed by ``entry.location``;
    - the requests collected so far;
    - the per-address pipe pairs handed out by ``MocketSocket.fileno()``;
    - the recording settings.

    ``enable()`` monkey-patches ``socket``, ``ssl`` and ``urllib3``;
    ``disable()`` restores the originals and resets the state.
    """

    _socket_pairs = {}
    _address = (None, None)
    _entries = collections.defaultdict(list)
    _requests = []
    _namespace = str(id(_entries))
    _truesocket_recording_dir = None

    @classmethod
    def get_pair(cls, address: tuple) -> Tuple[Optional[int], Optional[int]]:
        """
        Given an ``(host, port)`` address, return the file descriptors stored
        for it as ``(<read_fd>, <write_fd>)``, or ``(None, None)`` if none.
        """
        return cls._socket_pairs.get(address, (None, None))

    @classmethod
    def set_pair(cls, address: tuple, pair: Tuple[int, int]) -> None:
        """
        Store a pair of file descriptors under *address*
        as a tuple of two integers: ``(<read_fd>, <write_fd>)``.
        """
        cls._socket_pairs[address] = pair

    @classmethod
    def register(cls, *entries):
        """Register *entries* under their ``location``, after existing ones."""
        for entry in entries:
            cls._entries[entry.location].append(entry)

    @classmethod
    def get_entry(cls, host, port, data):
        """Return the first entry for ``(host, port)`` that can handle *data*.

        A missing *host* or *port* falls back to the last connected address
        (``Mocket._address``). Returns ``None`` when nothing matches.
        """
        host = host or Mocket._address[0]
        port = port or Mocket._address[1]
        entries = cls._entries.get((host, port), [])
        for entry in entries:
            if entry.can_handle(data):
                return entry
        return None

    @classmethod
    def collect(cls, data):
        """Record *data* as the most recent request."""
        cls.request_list().append(data)

    @classmethod
    def reset(cls):
        """Close every stored fd pair and forget pairs, entries and requests.

        Closing is best effort: a descriptor that is already closed does not
        abort the reset halfway through.
        """
        for r_fd, w_fd in cls._socket_pairs.values():
            for fd in (r_fd, w_fd):
                with contextlib.suppress(OSError):
                    os.close(fd)
        cls._socket_pairs = {}
        cls._entries = collections.defaultdict(list)
        cls._requests = []

    @classmethod
    def last_request(cls):
        """Return the most recent request, or ``None`` if there is none."""
        if cls.has_requests():
            return cls.request_list()[-1]
        return None

    @classmethod
    def request_list(cls):
        return cls._requests

    @classmethod
    def remove_last_request(cls):
        if cls.has_requests():
            del cls._requests[-1]

    @classmethod
    def has_requests(cls):
        return bool(cls.request_list())

    @staticmethod
    def enable(namespace=None, truesocket_recording_dir=None):
        """Monkey-patch ``socket``, ``ssl`` and ``urllib3`` with Mocket's fakes.

        :param namespace: name of the JSON recording file (without ``.json``).
        :param truesocket_recording_dir: existing directory for recordings.
        :raises AssertionError: if *truesocket_recording_dir* is given but is
            not a directory. Validation happens before any state changes.
        """
        if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
            # JSON dumps will be saved here
            raise AssertionError(
                f"truesocket_recording_dir {truesocket_recording_dir!r} "
                "is not an existing directory"
            )

        Mocket._namespace = namespace
        Mocket._truesocket_recording_dir = truesocket_recording_dir

        socket.socket = socket.__dict__["socket"] = MocketSocket
        socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket
        socket.SocketType = socket.__dict__["SocketType"] = MocketSocket
        socket.create_connection = socket.__dict__["create_connection"] = (
            create_connection
        )
        socket.gethostname = socket.__dict__["gethostname"] = lambda: "localhost"
        socket.gethostbyname = socket.__dict__["gethostbyname"] = (
            lambda host: "127.0.0.1"
        )
        # (2, 1, 6) == (AF_INET, SOCK_STREAM, IPPROTO_TCP)
        socket.getaddrinfo = socket.__dict__["getaddrinfo"] = (
            lambda host, port, family=None, socktype=None, proto=None, flags=None: [
                (2, 1, 6, "", (host, port))
            ]
        )
        socket.socketpair = socket.__dict__["socketpair"] = socketpair
        ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket
        ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext
        # Every address resolves to the packed form of 127.0.0.1.
        socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: bytes(
            "\x7f\x00\x00\x01", "utf-8"
        )
        urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
            FakeSSLContext.wrap_socket
        )
        urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
            "ssl_wrap_socket"
        ] = FakeSSLContext.wrap_socket
        urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
            FakeSSLContext.wrap_socket
        )
        urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
            "ssl_wrap_socket"
        ] = FakeSSLContext.wrap_socket
        urllib3.connection.match_hostname = urllib3.connection.__dict__[
            "match_hostname"
        ] = lambda *args: None
        if pyopenssl_override:  # pragma: no cover
            # Take out the pyopenssl version - use the default implementation
            extract_from_urllib3()

    @staticmethod
    def disable():
        """Restore the genuine functions captured at import time and reset state."""
        socket.socket = socket.__dict__["socket"] = true_socket
        socket._socketobject = socket.__dict__["_socketobject"] = true_socket
        socket.SocketType = socket.__dict__["SocketType"] = true_socket
        socket.create_connection = socket.__dict__["create_connection"] = (
            true_create_connection
        )
        socket.gethostname = socket.__dict__["gethostname"] = true_gethostname
        socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname
        socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo
        socket.socketpair = socket.__dict__["socketpair"] = true_socketpair
        if true_ssl_wrap_socket:
            ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket
        ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context
        socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton
        urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
            true_urllib3_wrap_socket
        )
        urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
            "ssl_wrap_socket"
        ] = true_urllib3_ssl_wrap_socket
        urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
            true_urllib3_ssl_wrap_socket
        )
        urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
            "ssl_wrap_socket"
        ] = true_urllib3_ssl_wrap_socket
        urllib3.connection.match_hostname = urllib3.connection.__dict__[
            "match_hostname"
        ] = true_urllib3_match_hostname
        Mocket.reset()
        if pyopenssl_override:  # pragma: no cover
            # Put the pyopenssl version back in place
            inject_into_urllib3()

    @classmethod
    def get_namespace(cls):
        return cls._namespace

    @classmethod
    def get_truesocket_recording_dir(cls):
        return cls._truesocket_recording_dir

    @classmethod
    def assert_fail_if_entries_not_served(cls):
        """Mocket checks that all entries have been served at least once."""
        if not all(entry._served for entry in itertools.chain(*cls._entries.values())):
            raise AssertionError("Some Mocket entries have not been served")
592

593

594
class MocketEntry:
    """A fake endpoint: a location plus an ordered list of responses.

    ``get_response()`` serves the responses in order and keeps repeating the
    last one. A response that is an exception instance is raised instead of
    returned.
    """

    class Response(bytes):
        """Raw response payload; ``.data`` returns the bytes themselves."""

        @property
        def data(self):
            return self

    response_index = 0
    request_cls = bytes
    response_cls = Response
    responses = None
    _served = None

    def __init__(self, location, responses):
        """
        :param location: key the entry is registered under. ``Mocket.get_entry``
            looks entries up by ``(host, port)``.
        :param responses: a single response or an iterable of them. A ``str``,
            ``bytes`` or ``bytearray`` counts as ONE response. Each response is
            handled as follows:

            - an exception instance is kept and raised when served;
            - an object with a truthy ``data`` attribute is kept as-is;
            - anything else is wrapped in ``response_cls``, with ``str``
              encoded first.

            An empty value yields a single empty response.
        """
        self._served = False
        self.location = location

        # Text and binary payloads are iterable, but iterating them would split
        # the payload into characters / integers, so treat them as one response.
        if isinstance(responses, (str, bytes, bytearray)) or not isinstance(
            responses, collections_abc.Iterable
        ):
            responses = [responses]

        if not responses:
            self.responses = [self.response_cls(encode_to_bytes(""))]
        else:
            self.responses = []
            for r in responses:
                if not isinstance(r, BaseException) and not getattr(r, "data", False):
                    if isinstance(r, str):
                        r = encode_to_bytes(r)
                    r = self.response_cls(r)
                self.responses.append(r)

    def __repr__(self):
        return f"{self.__class__.__name__}(location={self.location})"

    @staticmethod
    def can_handle(data):
        """Accept any request; subclasses narrow this down."""
        return True

    def collect(self, data):
        """Convert *data* with ``request_cls`` and record it on ``Mocket``."""
        req = self.request_cls(data)
        Mocket.collect(req)

    def get_response(self):
        """Return the next response's ``data``, sticking to the last one.

        Marks the entry as served.

        :raises BaseException: the response itself, when it is an exception
            instance.
        """
        response = self.responses[self.response_index]
        if self.response_index < len(self.responses) - 1:
            self.response_index += 1

        self._served = True

        if isinstance(response, BaseException):
            raise response

        return response.data
646

647

648
class Mocketizer:
    """Sync/async context manager that enables Mocket for a block of code.

    On enter it calls ``Mocket.enable()`` and then, if an *instance* was
    given, its ``mocketize_setup()`` hook. On exit it calls
    ``mocketize_teardown()`` and then ``Mocket.disable()``. The patching is
    undone even if either hook raises.
    """

    def __init__(
        self,
        instance=None,
        namespace=None,
        truesocket_recording_dir=None,
        strict_mode=False,
        strict_mode_allowed=None,
    ):
        """
        :param instance: object whose optional ``mocketize_setup`` /
            ``mocketize_teardown`` methods are called on enter/exit
            (presumably a test-case instance; verify against callers).
        :param namespace: recording namespace; defaults to ``str(id(self))``.
        :param truesocket_recording_dir: directory for JSON recordings, or None.
        :param strict_mode: value assigned to ``MocketMode().STRICT``.
        :param strict_mode_allowed: locations allowed in strict mode.
        :raises ValueError: if *strict_mode_allowed* is given without
            *strict_mode*.
        """
        self.instance = instance
        self.truesocket_recording_dir = truesocket_recording_dir
        self.namespace = namespace or str(id(self))
        MocketMode().STRICT = strict_mode
        if strict_mode:
            MocketMode().STRICT_ALLOWED = strict_mode_allowed or []
        elif strict_mode_allowed:
            raise ValueError(
                "Allowed locations are only accepted when STRICT mode is active."
            )

    def enter(self):
        """Enable Mocket, then run the instance's setup hook."""
        Mocket.enable(
            namespace=self.namespace,
            truesocket_recording_dir=self.truesocket_recording_dir,
        )
        if self.instance:
            try:
                self.check_and_call("mocketize_setup")
            except BaseException:
                # __exit__ is not invoked when __enter__ raises, so undo the
                # patching here instead of leaving sockets mocked.
                Mocket.disable()
                raise

    def __enter__(self):
        self.enter()
        return self

    def exit(self):
        """Run the instance's teardown hook, then always disable Mocket."""
        try:
            if self.instance:
                self.check_and_call("mocketize_teardown")
        finally:
            # Restore the real socket/ssl/urllib3 functions even if the
            # teardown hook fails.
            Mocket.disable()

    def __exit__(self, type, value, tb):
        self.exit()

    async def __aenter__(self, *args, **kwargs):
        self.enter()
        return self

    async def __aexit__(self, *args, **kwargs):
        self.exit()

    def check_and_call(self, method_name):
        """Call ``self.instance.<method_name>()`` if it exists and is callable."""
        method = getattr(self.instance, method_name, None)
        if callable(method):
            method()

    @staticmethod
    def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args):
        """Build a Mocketizer for the test function *test* called with *args*.

        The first positional argument (if any) is used as the hook instance.
        When recording, the namespace is
        ``"<instance module>.<instance class>.<test name>"``.
        """
        instance = args[0] if args else None
        namespace = None
        if truesocket_recording_dir:
            namespace = ".".join(
                (
                    instance.__class__.__module__,
                    instance.__class__.__name__,
                    test.__name__,
                )
            )

        return Mocketizer(
            instance,
            namespace=namespace,
            truesocket_recording_dir=truesocket_recording_dir,
            strict_mode=strict_mode,
            strict_mode_allowed=strict_mode_allowed,
        )
720

721

722
def wrapper(
    test,
    truesocket_recording_dir=None,
    strict_mode=False,
    strict_mode_allowed=None,
    *args,
    **kwargs,
):
    """Run *test* inside a Mocketizer built by ``Mocketizer.factory``.

    The positional *args* are forwarded to *test* and also given to the
    factory, which uses the first of them (if any) as the hook instance.
    Returns whatever *test* returns.
    """
    mocketizer = Mocketizer.factory(
        test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args
    )
    with mocketizer:
        outcome = test(*args, **kwargs)
    return outcome


# Public decorator, built by mocket.utils.get_mocketize around ``wrapper``.
mocketize = get_mocketize(wrapper_=wrapper)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc