• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mindflayer / python-mocket / 9054312782

12 May 2024 09:21PM UTC coverage: 98.57% (-0.5%) from 99.046%
9054312782

push

github

web-flow
Fix coveralls badge

827 of 839 relevant lines covered (98.57%)

2.94 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.34
/mocket/mocket.py
1
import collections
3✔
2
import collections.abc as collections_abc
3✔
3
import contextlib
3✔
4
import errno
3✔
5
import hashlib
3✔
6
import itertools
3✔
7
import json
3✔
8
import os
3✔
9
import select
3✔
10
import socket
3✔
11
import ssl
3✔
12
from datetime import datetime, timedelta
3✔
13
from json.decoder import JSONDecodeError
3✔
14

15
import urllib3
3✔
16
from urllib3.connection import match_hostname as urllib3_match_hostname
3✔
17
from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket
3✔
18

19
try:
3✔
20
    from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
3✔
21
except ImportError:
22
    urllib3_wrap_socket = None
23

24

25
from .compat import basestring, byte_type, decode_from_bytes, encode_to_bytes, text_type
3✔
26
from .utils import (
3✔
27
    SSL_PROTOCOL,
28
    MocketMode,
29
    MocketSocketCore,
30
    get_mocketize,
31
    hexdump,
32
    hexload,
33
)
34

35
# Hash constructor used for request signatures in MocketSocket.true_sendall():
# xxhash's xxh32 (or its CFFI port) when installed, otherwise hashlib.md5.
xxh32 = None
try:
    from xxhash import xxh32
except ImportError:  # pragma: no cover
    with contextlib.suppress(ImportError):
        from xxhash_cffi import xxh32
hasher = xxh32 or hashlib.md5

# When urllib3's pyOpenSSL integration is importable, Mocket.enable() calls
# extract_from_urllib3() and Mocket.disable() calls inject_into_urllib3().
try:  # pragma: no cover
    from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3

    pyopenssl_override = True
except ImportError:
    pyopenssl_override = False

# Clearer for aiohttp's cached TCPConnector._make_ssl_context, invoked by both
# Mocket.enable() and Mocket.disable() (presumably so SSL contexts built before
# or after ssl.SSLContext is swapped are not reused). None when aiohttp is
# missing or has no such cached helper.
try:  # pragma: no cover
    from aiohttp import TCPConnector

    aiohttp_make_ssl_context_cache_clear = TCPConnector._make_ssl_context.cache_clear
except (ImportError, AttributeError):
    aiohttp_make_ssl_context_cache_clear = None


# Real implementations captured at import time, before any patching. Mocket.disable()
# restores them, and MocketSocket uses them to reach the real network.
true_socket = socket.socket
true_create_connection = socket.create_connection
true_gethostbyname = socket.gethostbyname
true_gethostname = socket.gethostname
true_getaddrinfo = socket.getaddrinfo
true_socketpair = socket.socketpair
true_ssl_wrap_socket = getattr(
    ssl, "wrap_socket", None
)  # in Py3.12 it's only under SSLContext
true_ssl_socket = ssl.SSLSocket
true_ssl_context = ssl.SSLContext
true_inet_pton = socket.inet_pton
# urllib3_wrap_socket is None when urllib3 no longer provides util.ssl_.wrap_socket.
true_urllib3_wrap_socket = urllib3_wrap_socket
true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket
true_urllib3_match_hostname = urllib3_match_hostname
73

74

75
class SuperFakeSSLContext:
    """Base of the fake SSL context (kept for Python 3.6 compatibility).

    Its configuration attributes ignore assignment, so code that tweaks a
    context (``ctx.verify_mode = ...``) cannot change the values below.
    """

    class FakeSetter(int):
        """An ``int`` class attribute whose instance-level writes are discarded."""

        def __set__(self, *args):
            # Swallow the assignment; reads keep returning this int.
            return None

    minimum_version = FakeSetter()
    options = FakeSetter()
    verify_mode = FakeSetter(ssl.CERT_NONE)
85

86

87
class FakeSSLContext(SuperFakeSSLContext):
    """Stand-in for ``ssl.SSLContext`` installed while Mocket is enabled.

    Certificate/cipher configuration calls become no-ops, ``check_hostname``
    is pinned to ``False``, and ``wrap_socket`` hands back the socket it was
    given (flagged as secure) instead of performing TLS.
    """

    # Configuration methods replaced by per-instance no-ops in __init__.
    DUMMY_METHODS = (
        "load_default_certs",
        "load_verify_locations",
        "set_alpn_protocols",
        "set_ciphers",
        "set_default_verify_paths",
    )
    sock = None
    post_handshake_auth = None
    _check_hostname = False

    @property
    def check_hostname(self):
        return self._check_hostname

    @check_hostname.setter
    def check_hostname(self, *args):
        # Whatever value is requested, hostname checking stays disabled.
        self._check_hostname = False

    def __init__(self, sock=None, server_hostname=None, _context=None, *args, **kwargs):
        self._set_dummy_methods()

        if isinstance(sock, MocketSocket):
            # Adopt the Mocket socket and give its pass-through socket a real
            # SSL layer for the case where traffic reaches the network.
            self.sock = sock
            sock._host = server_hostname
            sock.true_socket = true_ssl_socket(
                sock=sock.true_socket,
                server_hostname=server_hostname,
                _context=true_ssl_context(protocol=SSL_PROTOCOL),
            )
        elif isinstance(sock, int) and true_ssl_context:
            # Called as SSLContext(protocol): keep a genuine context around too.
            self.context = true_ssl_context(sock)

    def _set_dummy_methods(self):
        def dummy_method(*args, **kwargs):
            return None

        for method_name in self.DUMMY_METHODS:
            setattr(self, method_name, dummy_method)

    @staticmethod
    def wrap_socket(sock=None, *args, **kwargs):
        # No TLS: remember the wrap options and mark the socket as "secure".
        sock.kwargs = kwargs
        sock._secure_socket = True
        return sock

    @staticmethod
    def wrap_bio(incoming, outcoming, *args, **kwargs):
        ssl_obj = MocketSocket()
        ssl_obj._host = kwargs["server_hostname"]
        return ssl_obj

    def __getattr__(self, name):
        # Unknown attributes come from the adopted socket, or are None without one.
        if self.sock is None:
            return None
        return getattr(self.sock, name)
143

144

145
def create_connection(address, timeout=None, source_address=None):
    """Replacement for :func:`socket.create_connection` installed by Mocket.

    Builds a TCP socket through ``socket.socket`` (a ``MocketSocket`` while
    Mocket is enabled) and connects it to *address*.

    :param address: ``(host, port)`` pair to connect to.
    :param timeout: timeout in seconds. ``None``, ``0`` and the stdlib's
        "use the global default" sentinel leave the socket's timeout untouched.
    :param source_address: accepted for signature compatibility; ignored.
    :return: the connected socket.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
    # Stdlib clients such as smtplib pass socket._GLOBAL_DEFAULT_TIMEOUT (a bare
    # object()) when no timeout was configured. Forwarding it to settimeout()
    # would store a non-number on a MocketSocket, or raise TypeError on a real socket.
    default_sentinel = getattr(socket, "_GLOBAL_DEFAULT_TIMEOUT", None)
    if timeout and timeout is not default_sentinel:
        s.settimeout(timeout)
    s.connect(address)
    return s
151

152

153
def socketpair(*args, **kwargs):
    """Create a genuine connected socket pair, bypassing Mocket's patching.

    The asyncio event loop (driven by FastAPI and similar services) calls
    socketpair() and needs real sockets, so this goes straight to the C-level
    ``_socket`` module instead of the patched ``socket`` namespace.
    """
    from _socket import socketpair as real_socketpair

    return real_socketpair(*args, **kwargs)
158

159

160
def _hash_request(h, req):
    """Return the hex digest of *req* computed with hash constructor *h*.

    The request is split on CRLF and its lines are sorted before hashing, so
    the signature does not depend on the order of those lines.
    """
    normalized = "".join(sorted(req.split("\r\n")))
    digest = h(encode_to_bytes(normalized))
    return digest.hexdigest()
162

163

164
class MocketSocket:
    """In-memory replacement for :class:`socket.socket` installed by Mocket.

    Bytes written with ``sendall``/``send`` are matched against the entries
    registered on ``Mocket``. The matching entry's response is buffered in
    ``self.fd`` and served by ``recv``/``read``/``recv_into``. When nothing
    matches, ``true_sendall`` talks to the real endpoint, optionally
    replaying/recording the exchange in a JSON file.
    """

    timeout = None
    _fd = None
    family = None
    type = None
    proto = None
    _host = None
    _port = None
    _address = None
    cipher = lambda s: ("ADH", "AES256", "SHA")
    compression = lambda s: ssl.OP_NO_COMPRESSION
    _mode = None
    _bufsize = None
    _secure_socket = False

    def __init__(
        self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs
    ):
        # Real socket used only for pass-through traffic (see true_sendall).
        self.true_socket = true_socket(family, type, proto)
        self._buflen = 65536
        self._entry = None
        self.family = int(family)
        self.type = int(type)
        self.proto = int(proto)
        self._truesocket_recording_dir = None
        self._did_handshake = False
        # Becomes True once read() has returned any data (see read()).
        self._sent_non_empty_bytes = False
        # Extra options (e.g. from FakeSSLContext.wrap_socket), forwarded to
        # urllib3's ssl_wrap_socket when the real network is used.
        self.kwargs = kwargs

    def __str__(self):
        return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})"

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    @property
    def fd(self):
        """Lazily created buffer holding the response to be read back."""
        if self._fd is None:
            self._fd = MocketSocketCore()
        return self._fd

    def gettimeout(self):
        return self.timeout

    def setsockopt(self, family, type, proto):
        # NOTE(review): callers pass (level, optname, value) here, yet these are
        # stored as family/type/proto. Kept as-is for compatibility; confirm
        # whether anything relies on it before changing.
        self.family = family
        self.type = type
        self.proto = proto

        if self.true_socket:
            self.true_socket.setsockopt(family, type, proto)

    def settimeout(self, timeout):
        self.timeout = timeout

    @staticmethod
    def getsockopt(level, optname, buflen=None):
        return socket.SOCK_STREAM

    def do_handshake(self):
        self._did_handshake = True

    def getpeername(self):
        return self._address

    def setblocking(self, block):
        self.settimeout(None if block else 0.0)

    def getblocking(self):
        return self.gettimeout() is None

    def getsockname(self):
        return socket.gethostbyname(self._address[0]), self._address[1]

    def getpeercert(self, *args, **kwargs):
        """Return a fake certificate dict valid for ``self._host`` and subdomains."""
        if not (self._host and self._port):
            self._address = self._host, self._port = Mocket._address

        now = datetime.now()
        shift = now + timedelta(days=30 * 12)
        return {
            "notAfter": shift.strftime("%b %d %H:%M:%S GMT"),
            "subjectAltName": (
                ("DNS", f"*.{self._host}"),
                ("DNS", self._host),
                ("DNS", "*"),
            ),
            "subject": (
                (("organizationName", f"*.{self._host}"),),
                (("organizationalUnitName", "Domain Control Validated"),),
                (("commonName", f"*.{self._host}"),),
            ),
        }

    def unwrap(self):
        return self

    def write(self, data):
        return self.send(encode_to_bytes(data))

    @staticmethod
    def fileno():
        """Return the read end of a pipe shared by all instances, creating it once.

        Presumably this gives select()/event loops a real descriptor to watch.
        """
        if Mocket.r_fd is not None:
            return Mocket.r_fd
        Mocket.r_fd, Mocket.w_fd = os.pipe()
        return Mocket.r_fd

    def connect(self, address):
        self._address = self._host, self._port = address
        Mocket._address = address

    def makefile(self, mode="r", bufsize=-1):
        self._mode = mode
        self._bufsize = bufsize
        return self.fd

    def get_entry(self, data):
        return Mocket.get_entry(self._host, self._port, data)

    def sendall(self, data, entry=None, *args, **kwargs):
        """Route *data* to a matching entry (or the real network) and buffer the reply."""
        if entry is None:
            entry = self.get_entry(data)

        if entry:
            consume_response = entry.collect(data)
            response = entry.get_response() if consume_response is not False else None
        else:
            response = self.true_sendall(data, *args, **kwargs)

        if response is not None:
            # Replace the buffer contents with the new response and rewind.
            self.fd.seek(0)
            self.fd.write(response)
            self.fd.truncate()
            self.fd.seek(0)

    def read(self, buffersize):
        """Read up to *buffersize* bytes of the buffered response.

        After do_handshake(), an empty read before any data was ever returned
        raises ``ssl.SSLWantReadError``. This is presumably so SSL-aware callers
        retry instead of treating it as EOF.
        """
        rv = self.fd.read(buffersize)
        if rv:
            self._sent_non_empty_bytes = True
        if self._did_handshake and not self._sent_non_empty_bytes:
            raise ssl.SSLWantReadError("The operation did not complete (read)")
        return rv

    def recv_into(self, buffer, buffersize=None, flags=None):
        """Read buffered response bytes into *buffer* and return how many were read.

        *buffer* may be a file-like object (anything with ``write``) or a
        writable buffer such as a ``memoryview``/``bytearray``.
        """
        if hasattr(buffer, "write"):
            return buffer.write(self.read(buffersize))
        # Writable buffer: never read more than it can hold (like
        # socket.recv_into with nbytes=0). Reading everything would make the
        # slice assignment below raise ValueError for a short buffer.
        capacity = len(buffer)
        if not buffersize or buffersize > capacity:
            buffersize = capacity
        data = self.read(buffersize)
        if data:
            buffer[: len(data)] = data
        return len(data)

    def recv(self, buffersize, flags=None):
        if Mocket.r_fd and Mocket.w_fd:
            return os.read(Mocket.r_fd, buffersize)
        data = self.read(buffersize)
        if data:
            return data
        # used by Redis mock
        exc = BlockingIOError()
        exc.errno = errno.EWOULDBLOCK
        exc.args = (0,)
        raise exc

    def true_sendall(self, data, *args, **kwargs):
        """Send *data* to the real endpoint and return the raw response bytes.

        In strict mode, a location that is not allowed is rejected through
        ``MocketMode.raise_not_allowed()``. With a recording directory, a
        response previously stored under the same request signature in
        ``<dir>/<namespace>.json`` is replayed. Otherwise the real exchange
        happens and is written to that file.
        """
        if not MocketMode().is_allowed((self._host, self._port)):
            MocketMode.raise_not_allowed()

        req = decode_from_bytes(data)
        # make request unique again
        req_signature = _hash_request(hasher, req)
        # port should be always a string
        port = text_type(self._port)

        recording_dir = Mocket.get_truesocket_recording_dir()
        path = None
        # prepare responses dictionary
        responses = {}
        if recording_dir:
            path = os.path.join(recording_dir, Mocket.get_namespace() + ".json")
            responses = self._load_recordings(path)

        try:
            try:
                response_dict = responses[self._host][port][req_signature]
            except KeyError:
                if hasher is not hashlib.md5:
                    # Fallback for backwards compatibility (MD5-keyed recordings)
                    req_signature = _hash_request(hashlib.md5, req)
                    response_dict = responses[self._host][port][req_signature]
                else:
                    raise
        except KeyError:
            # preventing next KeyError exceptions
            response_dict = (
                responses.setdefault(self._host, {})
                .setdefault(port, {})
                .setdefault(req_signature, {})
            )

        # try to get the response from the dictionary
        try:
            return hexload(response_dict["response"])
        except KeyError:
            pass  # nothing recorded yet: use the real network below

        encoded_response = self._true_exchange(data, *args, **kwargs)

        # dump the resulting dictionary to a JSON file
        if recording_dir:
            # update the dictionary with request and response lines
            response_dict["request"] = req
            response_dict["response"] = hexdump(encoded_response)

            with open(path, mode="w") as f:
                f.write(
                    decode_from_bytes(
                        json.dumps(responses, indent=4, sort_keys=True)
                    )
                )

        # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO
        return encoded_response

    @staticmethod
    def _load_recordings(path):
        """Return the recordings stored at *path*, or ``{}`` if missing or invalid JSON."""
        try:
            with open(path) as f:
                return json.load(f)
        except (FileNotFoundError, JSONDecodeError):
            return {}

    def _true_exchange(self, data, *args, **kwargs):
        """Connect the real socket if needed, send *data*, and collect the reply."""
        host, port = true_gethostbyname(self._host), self._port

        if isinstance(self.true_socket, true_socket) and self._secure_socket:
            self.true_socket = true_urllib3_ssl_wrap_socket(
                self.true_socket,
                **self.kwargs,
            )

        with contextlib.suppress(OSError, ValueError):
            # already connected
            self.true_socket.connect((host, port))
        self.true_socket.sendall(data, *args, **kwargs)

        chunks = []
        # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L13
        while True:
            # Once some data has arrived, 100 ms without readability ends it.
            if chunks and not select.select([self.true_socket], [], [], 0.1)[0]:
                break
            chunk = self.true_socket.recv(self._buflen)
            if not chunk:
                # EOF: the peer closed the connection. recv() keeps returning b""
                # from here on; the old check only stopped once data had arrived,
                # so a peer closing without replying looped forever.
                break
            chunks.append(chunk)
        return b"".join(chunks)

    def send(self, data, *args, **kwargs):  # pragma: no cover
        entry = self.get_entry(data)
        kwargs["entry"] = entry
        if not entry or self._entry != entry:
            self.sendall(data, *args, **kwargs)
        else:
            # Same entry as the previous send: append to the last request.
            req = Mocket.last_request()
            if hasattr(req, "add_data"):
                req.add_data(data)
        self._entry = entry
        return len(data)

    def close(self):
        if self.true_socket and not self.true_socket._closed:
            self.true_socket.close()
        self._fd = None

    def __getattr__(self, name):
        """Do nothing catchall function, for methods like close() and shutdown()"""

        def do_nothing(*args, **kwargs):
            pass

        return do_nothing
445

446

447
class Mocket:
    """Global registry and switchboard for socket mocking.

    Holds registered entries keyed by ``(host, port)``, the requests collected
    so far, and the recording settings. ``enable()``/``disable()`` swap the
    socket/ssl/urllib3 functions for Mocket's fakes and back.
    """

    # Last (host, port) passed to MocketSocket.connect(); get_entry() and
    # MocketSocket.getpeercert() fall back to it.
    _address = (None, None)
    _entries = collections.defaultdict(list)
    _requests = []
    _namespace = text_type(id(_entries))
    _truesocket_recording_dir = None
    # Pipe created lazily by MocketSocket.fileno(); while both ends are set,
    # MocketSocket.recv() reads from r_fd.
    r_fd = None
    w_fd = None

    @classmethod
    def register(cls, *entries):
        """Register entries under their ``location``."""
        for entry in entries:
            cls._entries[entry.location].append(entry)

    @classmethod
    def get_entry(cls, host, port, data):
        """Return the first entry for (host, port) whose ``can_handle(data)`` is true, else None.

        Missing host/port fall back to the last connected address.
        """
        host = host or Mocket._address[0]
        port = port or Mocket._address[1]
        entries = cls._entries.get((host, port), [])
        for entry in entries:
            if entry.can_handle(data):
                return entry

    @classmethod
    def collect(cls, data):
        cls.request_list().append(data)

    @classmethod
    def reset(cls):
        """Close the fileno() pipe, if any, and forget all entries and requests."""
        if cls.r_fd is not None:
            os.close(cls.r_fd)
            cls.r_fd = None
        if cls.w_fd is not None:
            os.close(cls.w_fd)
            cls.w_fd = None
        cls._entries = collections.defaultdict(list)
        cls._requests = []

    @classmethod
    def last_request(cls):
        if cls.has_requests():
            return cls.request_list()[-1]

    @classmethod
    def request_list(cls):
        return cls._requests

    @classmethod
    def remove_last_request(cls):
        if cls.has_requests():
            del cls._requests[-1]

    @classmethod
    def has_requests(cls):
        return bool(cls.request_list())

    @staticmethod
    def enable(namespace=None, truesocket_recording_dir=None):
        """Patch socket, ssl and urllib3 so network traffic goes through Mocket.

        :param namespace: name of the JSON recording file (``<namespace>.json``).
        :param truesocket_recording_dir: directory for recordings, or None to
            disable recording.
        :raises AssertionError: if *truesocket_recording_dir* is given but is
            not an existing directory.
        """
        Mocket._namespace = namespace
        Mocket._truesocket_recording_dir = truesocket_recording_dir

        if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
            # JSON dumps will be saved here
            raise AssertionError(
                f"truesocket_recording_dir {truesocket_recording_dir!r} "
                "is not an existing directory"
            )

        socket.socket = socket.__dict__["socket"] = MocketSocket
        socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket
        socket.SocketType = socket.__dict__["SocketType"] = MocketSocket
        socket.create_connection = socket.__dict__["create_connection"] = (
            create_connection
        )
        socket.gethostname = socket.__dict__["gethostname"] = lambda: "localhost"
        socket.gethostbyname = socket.__dict__["gethostbyname"] = (
            lambda host: "127.0.0.1"
        )
        socket.getaddrinfo = socket.__dict__["getaddrinfo"] = (
            lambda host, port, family=None, socktype=None, proto=None, flags=None: [
                (2, 1, 6, "", (host, port))
            ]
        )
        socket.socketpair = socket.__dict__["socketpair"] = socketpair
        ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket
        ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext
        socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: byte_type(
            "\x7f\x00\x00\x01", "utf-8"
        )
        urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
            FakeSSLContext.wrap_socket
        )
        urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
            "ssl_wrap_socket"
        ] = FakeSSLContext.wrap_socket
        urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
            FakeSSLContext.wrap_socket
        )
        urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
            "ssl_wrap_socket"
        ] = FakeSSLContext.wrap_socket
        urllib3.connection.match_hostname = urllib3.connection.__dict__[
            "match_hostname"
        ] = lambda *args: None
        if pyopenssl_override:  # pragma: no cover
            # Take out the pyopenssl version - use the default implementation
            extract_from_urllib3()
        if aiohttp_make_ssl_context_cache_clear:  # pragma: no cover
            aiohttp_make_ssl_context_cache_clear()

    @staticmethod
    def disable():
        """Restore the real implementations saved at import time and reset state."""
        socket.socket = socket.__dict__["socket"] = true_socket
        socket._socketobject = socket.__dict__["_socketobject"] = true_socket
        socket.SocketType = socket.__dict__["SocketType"] = true_socket
        socket.create_connection = socket.__dict__["create_connection"] = (
            true_create_connection
        )
        socket.gethostname = socket.__dict__["gethostname"] = true_gethostname
        socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname
        socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo
        socket.socketpair = socket.__dict__["socketpair"] = true_socketpair
        if true_ssl_wrap_socket:
            ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket
        else:
            # Python 3.12+ has no ssl.wrap_socket: drop the fake one enable()
            # installed instead of leaving it in the stdlib namespace.
            ssl.__dict__.pop("wrap_socket", None)
        ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context
        socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton
        urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
            true_urllib3_wrap_socket
        )
        urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
            "ssl_wrap_socket"
        ] = true_urllib3_ssl_wrap_socket
        urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
            true_urllib3_ssl_wrap_socket
        )
        urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
            "ssl_wrap_socket"
        ] = true_urllib3_ssl_wrap_socket
        urllib3.connection.match_hostname = urllib3.connection.__dict__[
            "match_hostname"
        ] = true_urllib3_match_hostname
        Mocket.reset()
        if pyopenssl_override:  # pragma: no cover
            # Put the pyopenssl version back in place
            inject_into_urllib3()
        if aiohttp_make_ssl_context_cache_clear:  # pragma: no cover
            aiohttp_make_ssl_context_cache_clear()

    @classmethod
    def get_namespace(cls):
        return cls._namespace

    @classmethod
    def get_truesocket_recording_dir(cls):
        return cls._truesocket_recording_dir

    @classmethod
    def assert_fail_if_entries_not_served(cls):
        """Mocket checks that all entries have been served at least once."""
        if not all(entry._served for entry in itertools.chain(*cls._entries.values())):
            raise AssertionError("Some Mocket entries have not been served")
605

606

607
class MocketEntry:
    """A mocked endpoint: a ``location`` plus the responses it hands out in order."""

    class Response(byte_type):
        """Bytes payload that exposes itself through ``.data``."""

        @property
        def data(self):
            return self

    response_index = 0
    request_cls = byte_type
    response_cls = Response
    responses = None
    _served = None

    def __init__(self, location, responses):
        self._served = False
        self.location = location

        # A lone response, including a string, counts as a one-item list.
        is_single = isinstance(responses, basestring) or not isinstance(
            responses, collections_abc.Iterable
        )
        if is_single:
            responses = [responses]

        if not responses:
            # No responses given: serve a single empty payload.
            self.responses = [self.response_cls(encode_to_bytes(""))]
            return

        self.responses = []
        for item in responses:
            # Exceptions and objects that already carry .data are kept as they are.
            keep_as_is = isinstance(item, BaseException) or getattr(item, "data", False)
            if not keep_as_is:
                payload = encode_to_bytes(item) if isinstance(item, text_type) else item
                item = self.response_cls(payload)
            self.responses.append(item)

    def __repr__(self):
        return f"{self.__class__.__name__}(location={self.location})"

    @staticmethod
    def can_handle(data):
        return True

    def collect(self, data):
        Mocket.collect(self.request_cls(data))

    def get_response(self):
        """Return the next response's data, raising it if it is an exception.

        Responses are consumed in order; the last one is repeated forever.
        """
        current = self.responses[self.response_index]
        last_index = len(self.responses) - 1
        if self.response_index < last_index:
            self.response_index += 1

        self._served = True

        if isinstance(current, BaseException):
            raise current
        return current.data
661

662

663
class Mocketizer:
    """Context manager (sync and async) that enables Mocket for its duration.

    When *instance* is given, its optional ``mocketize_setup`` and
    ``mocketize_teardown`` methods are called on enter and exit.
    """

    def __init__(
        self,
        instance=None,
        namespace=None,
        truesocket_recording_dir=None,
        strict_mode=False,
        strict_mode_allowed=None,
    ):
        """Configure the context and the global MocketMode strictness.

        :raises ValueError: if *strict_mode_allowed* is given without *strict_mode*.
        """
        # Validate before touching global state: previously MocketMode().STRICT
        # was already overwritten when this error was raised.
        if strict_mode_allowed and not strict_mode:
            raise ValueError(
                "Allowed locations are only accepted when STRICT mode is active."
            )

        self.instance = instance
        self.truesocket_recording_dir = truesocket_recording_dir
        self.namespace = namespace or text_type(id(self))
        MocketMode().STRICT = strict_mode
        if strict_mode:
            MocketMode().STRICT_ALLOWED = strict_mode_allowed or []

    def enter(self):
        Mocket.enable(
            namespace=self.namespace,
            truesocket_recording_dir=self.truesocket_recording_dir,
        )
        if self.instance:
            self.check_and_call("mocketize_setup")

    def __enter__(self):
        self.enter()
        return self

    def exit(self):
        if self.instance:
            self.check_and_call("mocketize_teardown")
        Mocket.disable()

    def __exit__(self, type, value, tb):
        self.exit()

    async def __aenter__(self, *args, **kwargs):
        self.enter()
        return self

    async def __aexit__(self, *args, **kwargs):
        self.exit()

    def check_and_call(self, method_name):
        """Call ``self.instance.<method_name>()`` if it exists and is callable."""
        method = getattr(self.instance, method_name, None)
        if callable(method):
            method()

    @staticmethod
    def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args):
        """Build a Mocketizer for *test*; ``args[0]`` (if any) is the instance.

        With recording enabled, the namespace (and so the JSON file name) is
        ``<instance module>.<instance class>.<test name>``.
        """
        instance = args[0] if args else None
        namespace = None
        if truesocket_recording_dir:
            namespace = ".".join(
                (
                    instance.__class__.__module__,
                    instance.__class__.__name__,
                    test.__name__,
                )
            )

        return Mocketizer(
            instance,
            namespace=namespace,
            truesocket_recording_dir=truesocket_recording_dir,
            strict_mode=strict_mode,
            strict_mode_allowed=strict_mode_allowed,
        )
735

736

737
def wrapper(
    test,
    truesocket_recording_dir=None,
    strict_mode=False,
    strict_mode_allowed=None,
    *args,
    **kwargs,
):
    """Run *test* inside a Mocketizer built from the given options.

    This is the function handed to ``get_mocketize`` to build ``mocketize``.
    The first positional argument, if any, is treated as the test instance.
    """
    mocketizer = Mocketizer.factory(
        test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args
    )
    with mocketizer:
        return test(*args, **kwargs)
749

750

751
# Public ``mocketize`` decorator, produced by utils.get_mocketize around ``wrapper``
# (which runs the decorated test under a Mocketizer).
mocketize = get_mocketize(wrapper_=wrapper)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc