• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mindflayer / python-mocket / 11416392747

19 Oct 2024 09:20AM UTC coverage: 99.06% (+0.003%) from 99.057%
11416392747

push

github

web-flow
Switching to using `puremagic` for identifying MIME types. (#255)

10 of 10 new or added lines in 3 files covered. (100.0%)

1 existing line in 1 file now uncovered.

843 of 851 relevant lines covered (99.06%)

6.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.26
/mocket/mocket.py
1
import collections
7✔
2
import collections.abc as collections_abc
7✔
3
import contextlib
7✔
4
import errno
7✔
5
import hashlib
7✔
6
import itertools
7✔
7
import json
7✔
8
import os
7✔
9
import select
7✔
10
import socket
7✔
11
import ssl
7✔
12
from datetime import datetime, timedelta
7✔
13
from json.decoder import JSONDecodeError
7✔
14
from typing import Optional, Tuple
7✔
15

16
import urllib3
7✔
17
from urllib3.connection import match_hostname as urllib3_match_hostname
7✔
18
from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket
7✔
19

20
try:
7✔
21
    from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
7✔
22
except ImportError:
23
    urllib3_wrap_socket = None
24

25

26
from .compat import basestring, byte_type, decode_from_bytes, encode_to_bytes, text_type
7✔
27
from .utils import (
7✔
28
    SSL_PROTOCOL,
29
    MocketMode,
30
    MocketSocketCore,
31
    get_mocketize,
32
    hexdump,
33
    hexload,
34
)
35

36
# Request signatures use xxhash's xxh32 when either the `xxhash` or the
# `xxhash_cffi` package is importable; otherwise hashlib.md5 is used.
try:
    from xxhash import xxh32
except ImportError:  # pragma: no cover
    try:
        from xxhash_cffi import xxh32
    except ImportError:
        xxh32 = None
hasher = xxh32 or hashlib.md5
43

44
# `pyopenssl_override` records whether urllib3's pyOpenSSL integration is
# importable; Mocket.enable()/disable() extract and re-inject it when True.
try:  # pragma: no cover
    from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3
except ImportError:
    pyopenssl_override = False
else:
    pyopenssl_override = True

# aiohttp caches SSL contexts; when the cache-clear hook is reachable it is
# called on enable()/disable() so stale (real or fake) contexts are dropped.
# None when aiohttp is missing or lacks the attribute.
try:  # pragma: no cover
    from aiohttp import TCPConnector

    aiohttp_make_ssl_context_cache_clear = TCPConnector._make_ssl_context.cache_clear
except (ImportError, AttributeError):
    aiohttp_make_ssl_context_cache_clear = None
57

58

59
# References to the genuine implementations, captured at import time so that
# Mocket.disable() can restore them after Mocket.enable() has patched them.

# socket module
true_socket = socket.socket
true_create_connection = socket.create_connection
true_gethostbyname = socket.gethostbyname
true_gethostname = socket.gethostname
true_getaddrinfo = socket.getaddrinfo
true_socketpair = socket.socketpair
true_inet_pton = socket.inet_pton

# ssl module; from Py3.12 `wrap_socket` exists only on SSLContext, hence None
true_ssl_wrap_socket = getattr(ssl, "wrap_socket", None)
true_ssl_socket = ssl.SSLSocket
true_ssl_context = ssl.SSLContext

# urllib3 helpers (`urllib3_wrap_socket` is None when its import failed)
true_urllib3_wrap_socket = urllib3_wrap_socket
true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket
true_urllib3_match_hostname = urllib3_match_hostname
74

75

76
class SuperFakeSSLContext:
    """Base for the fake SSL context whose tuning attributes ignore writes.

    Each attribute below is a `FakeSetter`: an ``int`` (value 0) that is also
    a data descriptor, so assigning e.g. ``ctx.options = X`` is silently
    discarded and reading it back still yields the descriptor itself (0).
    """

    class FakeSetter(int):
        """Integer data descriptor whose ``__set__`` discards every value."""

        def __set__(self, *args):
            return None

    minimum_version = FakeSetter(0)
    options = FakeSetter(0)
    verify_mode = FakeSetter(0)
    verify_flags = FakeSetter(0)

88

89
class FakeSSLContext(SuperFakeSSLContext):
    """Stand-in for ``ssl.SSLContext`` installed by ``Mocket.enable()``.

    The methods named in ``DUMMY_METHODS`` become per-instance no-ops,
    ``check_hostname`` can never be switched on, and unknown attributes are
    delegated to the wrapped ``sock`` (or resolve to None without one).
    """

    DUMMY_METHODS = (
        "load_default_certs",
        "load_verify_locations",
        "set_alpn_protocols",
        "set_ciphers",
        "set_default_verify_paths",
    )
    sock = None
    post_handshake_auth = None
    _check_hostname = False

    @property
    def check_hostname(self):
        return self._check_hostname

    @check_hostname.setter
    def check_hostname(self, _):
        # Whatever the caller asks for, hostname checking stays disabled.
        self._check_hostname = False

    def __init__(self, sock=None, server_hostname=None, _context=None, *args, **kwargs):
        self._set_dummy_methods()

        if isinstance(sock, MocketSocket):
            self.sock = sock
            sock._host = server_hostname
            # NOTE(review): on Python 3.7+ `ssl.SSLSocket(...)` has no public
            # constructor and raises TypeError; this branch is reported as
            # uncovered - confirm it is still reachable.
            sock.true_socket = true_ssl_socket(
                sock=sock.true_socket,
                server_hostname=server_hostname,
                _context=true_ssl_context(protocol=SSL_PROTOCOL),
            )
            return

        if isinstance(sock, int) and true_ssl_context:
            # Called as SSLContext(protocol): keep a real context around.
            self.context = true_ssl_context(sock)

    def _set_dummy_methods(self):
        """Bind a do-nothing callable to every name in ``DUMMY_METHODS``."""

        def _noop(*_args, **_kwargs):
            return None

        for method_name in self.DUMMY_METHODS:
            setattr(self, method_name, _noop)

    @staticmethod
    def wrap_socket(sock=sock, *args, **kwargs):
        """Flag ``sock`` as secure, remember ``kwargs`` on it and return it."""
        sock.kwargs = kwargs
        sock._secure_socket = True
        return sock

    @staticmethod
    def wrap_bio(incoming, outcoming, *args, **kwargs):
        """Return a new MocketSocket bound to ``kwargs["server_hostname"]``."""
        bio_socket = MocketSocket()
        bio_socket._host = kwargs["server_hostname"]
        return bio_socket

    def __getattr__(self, name):
        if self.sock is None:
            return None
        return getattr(self.sock, name)
145

146

147
def create_connection(address, timeout=None, source_address=None):
    """Replacement for ``socket.create_connection`` used while Mocket is enabled.

    Builds an AF_INET/SOCK_STREAM/IPPROTO_TCP socket through ``socket.socket``
    (patched to ``MocketSocket`` by ``Mocket.enable()``), applies ``timeout``
    and connects it to ``address``.

    :param address: ``(host, port)`` passed to ``connect``.
    :param timeout: seconds, ``0`` for non-blocking, or ``None`` to leave the
        socket's default. The stdlib sentinel ``socket._GLOBAL_DEFAULT_TIMEOUT``
        (the default timeout of e.g. ``http.client.HTTPConnection``) also means
        "leave the default" and is never forwarded to ``settimeout``.
    :param source_address: accepted for signature compatibility; ignored.
    :return: the connected socket.
    """
    default_timeout = getattr(socket, "_GLOBAL_DEFAULT_TIMEOUT", None)
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
    try:
        # `is not None` (not truthiness) so that an explicit 0 is honoured.
        if timeout is not None and timeout is not default_timeout:
            s.settimeout(timeout)
        s.connect(address)
    except BaseException:
        # Don't leak the socket when connecting fails.
        s.close()
        raise
    return s
153

154

155
def socketpair(*args, **kwargs):
    """Return a genuine, connected socket pair from the C-level ``_socket`` module.

    The public ``socket`` module is patched while Mocket is enabled, so this
    goes straight to ``_socket`` to give asyncio loops (as used by fastapi and
    similar services) a working self-pipe.
    """
    from _socket import socketpair as _native_socketpair

    return _native_socketpair(*args, **kwargs)
160

161

162
def _hash_request(h, req):
    """Return the hex digest of ``req`` computed with hash constructor ``h``.

    The request is split on CRLF and its lines are sorted and concatenated
    first, so the signature does not depend on header order.
    """
    canonical = "".join(sorted(req.split("\r\n")))
    digest = h(encode_to_bytes(canonical))
    return digest.hexdigest()
164

165

166
class MocketSocket:
    """Drop-in replacement for ``socket.socket`` installed by ``Mocket.enable()``.

    Outgoing data is matched against registered entries through
    ``Mocket.get_entry``. A matching entry's response is written into an
    in-memory buffer (``self.io``) that later reads consume. Unmatched data
    goes to ``true_sendall``, which replays a recorded response or forwards
    the data through the real socket kept in ``self.true_socket``.
    """

    timeout = None
    _fd = None
    family = None
    type = None
    proto = None
    _host = None
    _port = None
    _address = None
    cipher = lambda s: ("ADH", "AES256", "SHA")
    compression = lambda s: ssl.OP_NO_COMPRESSION
    _mode = None
    _bufsize = None
    _secure_socket = False
    _did_handshake = False
    _sent_non_empty_bytes = False
    _io = None

    def __init__(
        self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs
    ):
        # A real socket is always created so unmatched traffic can be
        # forwarded by true_sendall().
        self.true_socket = true_socket(family, type, proto)
        self._buflen = 65536
        self._entry = None
        self.family = int(family)
        self.type = int(type)
        self.proto = int(proto)
        self._truesocket_recording_dir = None
        self.kwargs = kwargs

    def __str__(self):
        return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})"

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    @property
    def io(self):
        """Lazily created in-memory buffer holding the pending response."""
        if self._io is None:
            self._io = MocketSocketCore((self._host, self._port))
        return self._io

    def fileno(self):
        """Return the read end of the os.pipe() registered for this address.

        The pipe is created and stored via ``Mocket.set_pair`` on first use.
        """
        address = (self._host, self._port)
        r_fd, _ = Mocket.get_pair(address)
        if not r_fd:
            r_fd, w_fd = os.pipe()
            Mocket.set_pair(address, (r_fd, w_fd))
        return r_fd

    def gettimeout(self):
        return self.timeout

    def setsockopt(self, family, type, proto):
        # NOTE(review): callers pass (level, optname, value) as for
        # socket.setsockopt; storing them as family/type/proto overwrites
        # those attributes - kept as-is, confirm it is intended.
        self.family = family
        self.type = type
        self.proto = proto

        if self.true_socket:
            self.true_socket.setsockopt(family, type, proto)

    def settimeout(self, timeout):
        self.timeout = timeout

    @staticmethod
    def getsockopt(level, optname, buflen=None):
        return socket.SOCK_STREAM

    def do_handshake(self):
        self._did_handshake = True

    def getpeername(self):
        return self._address

    def setblocking(self, block):
        self.settimeout(None if block else 0.0)

    def getblocking(self):
        return self.gettimeout() is None

    def getsockname(self):
        return socket.gethostbyname(self._address[0]), self._address[1]

    def getpeercert(self, *args, **kwargs):
        """Return a fake certificate dict matching ``self._host``.

        Falls back to ``Mocket._address`` when host/port are unset; the
        ``notAfter`` date is 360 days from now.
        """
        if not (self._host and self._port):
            self._address = self._host, self._port = Mocket._address

        now = datetime.now()
        shift = now + timedelta(days=30 * 12)
        return {
            "notAfter": shift.strftime("%b %d %H:%M:%S GMT"),
            "subjectAltName": (
                ("DNS", f"*.{self._host}"),
                ("DNS", self._host),
                ("DNS", "*"),
            ),
            "subject": (
                (("organizationName", f"*.{self._host}"),),
                (("organizationalUnitName", "Domain Control Validated"),),
                (("commonName", f"*.{self._host}"),),
            ),
        }

    def unwrap(self):
        return self

    def write(self, data):
        return self.send(encode_to_bytes(data))

    def connect(self, address):
        self._address = self._host, self._port = address
        Mocket._address = address

    def makefile(self, mode="r", bufsize=-1):
        self._mode = mode
        self._bufsize = bufsize
        return self.io

    def get_entry(self, data):
        return Mocket.get_entry(self._host, self._port, data)

    def sendall(self, data, entry=None, *args, **kwargs):
        """Dispatch ``data`` to a mocked entry or to ``true_sendall``.

        A non-None response replaces the contents of ``self.io``, and the
        buffer is rewound so the next read starts at the beginning.
        """
        if entry is None:
            entry = self.get_entry(data)

        if entry:
            consume_response = entry.collect(data)
            response = entry.get_response() if consume_response is not False else None
        else:
            response = self.true_sendall(data, *args, **kwargs)

        if response is not None:
            self.io.seek(0)
            self.io.write(response)
            self.io.truncate()
            self.io.seek(0)

    def read(self, buffersize):
        """Read up to ``buffersize`` bytes from the response buffer.

        After ``do_handshake()`` and before any non-empty read, an empty read
        raises ``ssl.SSLWantReadError`` like a non-blocking TLS socket would.
        """
        rv = self.io.read(buffersize)
        if rv:
            self._sent_non_empty_bytes = True
        if self._did_handshake and not self._sent_non_empty_bytes:
            raise ssl.SSLWantReadError("The operation did not complete (read)")
        return rv

    def recv_into(self, buffer, buffersize=None, flags=None):
        """Read into ``buffer`` and return the number of bytes written.

        ``buffer`` is either a file-like object with ``write()`` or a
        writable buffer such as a ``memoryview``/``bytearray``. For the
        latter, when ``buffersize`` is omitted or 0, at most ``len(buffer)``
        bytes are read (mirroring ``socket.recv_into``) so the data fits.
        """
        if hasattr(buffer, "write"):
            return buffer.write(self.read(buffersize))
        # buffer is a memoryview
        if not buffersize:
            buffersize = len(buffer)
        data = self.read(buffersize)
        if data:
            buffer[: len(data)] = data
        return len(data)

    def recv(self, buffersize, flags=None):
        """Read from the address's registered pipe, or else from the buffer.

        Raises ``BlockingIOError`` (errno EWOULDBLOCK) when the buffer is
        exhausted.
        """
        r_fd, _ = Mocket.get_pair((self._host, self._port))
        if r_fd:
            return os.read(r_fd, buffersize)
        data = self.read(buffersize)
        if data:
            return data
        # used by Redis mock
        exc = BlockingIOError()
        exc.errno = errno.EWOULDBLOCK
        exc.args = (0,)
        raise exc

    def true_sendall(self, data, *args, **kwargs):
        """Serve ``data`` from a recording, or from the real server.

        Strict mode may reject the address via ``MocketMode.raise_not_allowed``.
        With a recording directory, responses are looked up in
        ``<dir>/<namespace>.json`` under host -> port (string) -> request
        signature. A miss is fetched from the real server and written back
        to that file.

        :return: the raw response bytes.
        """
        if not MocketMode().is_allowed((self._host, self._port)):
            MocketMode.raise_not_allowed()

        req = decode_from_bytes(data)
        # make request unique again
        req_signature = _hash_request(hasher, req)
        # port should be always a string
        port = text_type(self._port)

        # prepare responses dictionary
        responses = {}

        recording_dir = Mocket.get_truesocket_recording_dir()
        path = None
        if recording_dir:
            path = os.path.join(recording_dir, Mocket.get_namespace() + ".json")
            # check if there's already a recorded session dumped to a JSON file
            try:
                with open(path) as f:
                    responses = json.load(f)
            # if not, create a new dictionary
            except (FileNotFoundError, JSONDecodeError):
                pass

        try:
            try:
                response_dict = responses[self._host][port][req_signature]
            except KeyError:
                if hasher is not hashlib.md5:
                    # Fallback for backwards compatibility. NOTE(review): on a
                    # miss the new entry is stored under this md5 signature.
                    req_signature = _hash_request(hashlib.md5, req)
                    response_dict = responses[self._host][port][req_signature]
                else:
                    raise
        except KeyError:
            # create the nested dictionaries, preventing later KeyErrors
            response_dict = (
                responses.setdefault(self._host, {})
                .setdefault(port, {})
                .setdefault(req_signature, {})
            )

        # try to get the response from the dictionary
        try:
            encoded_response = hexload(response_dict["response"])
        # if not available, call the real sendall
        except KeyError:
            real_address = (true_gethostbyname(self._host), self._port)

            if isinstance(self.true_socket, true_socket) and self._secure_socket:
                self.true_socket = true_urllib3_ssl_wrap_socket(
                    self.true_socket,
                    **self.kwargs,
                )

            with contextlib.suppress(OSError, ValueError):
                # already connected
                self.true_socket.connect(real_address)
            self.true_socket.sendall(data, *args, **kwargs)

            chunks = []
            # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L13
            while True:
                if not select.select([self.true_socket], [], [], 0.1)[0] and chunks:
                    break
                recv = self.true_socket.recv(self._buflen)
                # An empty read means the peer closed the connection: stop even
                # when nothing was received, otherwise this would spin forever.
                if not recv:
                    break
                chunks.append(recv)
            encoded_response = b"".join(chunks)

            # dump the resulting dictionary to a JSON file
            if path is not None:
                # update the dictionary with request and response lines
                response_dict["request"] = req
                response_dict["response"] = hexdump(encoded_response)

                with open(path, mode="w") as f:
                    f.write(
                        decode_from_bytes(
                            json.dumps(responses, indent=4, sort_keys=True)
                        )
                    )

        # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO
        return encoded_response

    def send(self, data, *args, **kwargs):  # pragma: no cover
        """Send ``data``; consecutive chunks for the same entry are merged.

        The first chunk for an entry goes through ``sendall``. Later chunks
        are appended to the last request when it supports ``add_data``.
        """
        entry = self.get_entry(data)
        if not entry or self._entry != entry:
            kwargs["entry"] = entry
            self.sendall(data, *args, **kwargs)
        else:
            req = Mocket.last_request()
            if hasattr(req, "add_data"):
                req.add_data(data)
        self._entry = entry
        return len(data)

    def close(self):
        if self.true_socket and not self.true_socket._closed:
            self.true_socket.close()
        self._fd = None

    def __getattr__(self, name):
        """Do nothing catchall function, for methods like shutdown()"""

        def do_nothing(*args, **kwargs):
            pass

        return do_nothing
450

451

452
class Mocket:
    """Process-wide registry of mocked entries, recorded requests and fd pairs.

    ``enable()`` patches ``socket``, ``ssl`` and ``urllib3`` so that new
    sockets are ``MocketSocket`` instances. ``disable()`` restores the
    originals captured at import time and clears the registry via
    ``reset()``. All state lives in class attributes and is shared.
    """

    _socket_pairs = {}
    _address = (None, None)
    _entries = collections.defaultdict(list)
    _requests = []
    _namespace = text_type(id(_entries))
    _truesocket_recording_dir = None

    @classmethod
    def get_pair(cls, address: tuple) -> Tuple[Optional[int], Optional[int]]:
        """
        Given an address tuple (as used by ``MocketSocket``, ``(host, port)``),
        return the file descriptors stored for it as a tuple of two integers:
        (<read_fd>, <write_fd>), or ``(None, None)`` when none were stored.
        """
        return cls._socket_pairs.get(address, (None, None))

    @classmethod
    def set_pair(cls, address: tuple, pair: Tuple[int, int]) -> None:
        """
        Store a pair of file descriptors under the key ``address``
        as a tuple of two integers: (<read_fd>, <write_fd>)
        """
        cls._socket_pairs[address] = pair

    @classmethod
    def register(cls, *entries):
        """Register each entry under its ``location``."""
        for entry in entries:
            cls._entries[entry.location].append(entry)

    @classmethod
    def get_entry(cls, host, port, data):
        """Return the first entry for ``(host, port)`` that can handle ``data``.

        Missing host/port fall back to the last connected address. Returns
        None when no entry matches.
        """
        host = host or Mocket._address[0]
        port = port or Mocket._address[1]
        entries = cls._entries.get((host, port), [])
        for entry in entries:
            if entry.can_handle(data):
                return entry

    @classmethod
    def collect(cls, data):
        cls.request_list().append(data)

    @classmethod
    def reset(cls):
        """Close stored descriptor pairs and clear entries and requests."""
        for r_fd, w_fd in cls._socket_pairs.values():
            for fd in (r_fd, w_fd):
                # Best effort: one descriptor that is already closed must not
                # stop the rest of the registry from being cleared.
                with contextlib.suppress(OSError):
                    os.close(fd)
        cls._socket_pairs = {}
        cls._entries = collections.defaultdict(list)
        cls._requests = []

    @classmethod
    def last_request(cls):
        """Return the most recent collected request, or None if there is none."""
        if cls.has_requests():
            return cls.request_list()[-1]

    @classmethod
    def request_list(cls):
        return cls._requests

    @classmethod
    def remove_last_request(cls):
        if cls.has_requests():
            del cls._requests[-1]

    @classmethod
    def has_requests(cls):
        return bool(cls.request_list())

    @staticmethod
    def enable(namespace=None, truesocket_recording_dir=None):
        """Patch socket/ssl/urllib3 so that all traffic goes through Mocket.

        :param namespace: name of the recording file (``<namespace>.json``).
        :param truesocket_recording_dir: directory holding recordings; must
            already exist when given.
        :raises AssertionError: if ``truesocket_recording_dir`` is not an
            existing directory.
        """
        Mocket._namespace = namespace
        Mocket._truesocket_recording_dir = truesocket_recording_dir

        if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
            # JSON dumps will be saved here
            raise AssertionError(
                "truesocket_recording_dir is not an existing directory: "
                f"{truesocket_recording_dir!r}"
            )

        socket.socket = socket.__dict__["socket"] = MocketSocket
        socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket
        socket.SocketType = socket.__dict__["SocketType"] = MocketSocket
        socket.create_connection = socket.__dict__["create_connection"] = (
            create_connection
        )
        socket.gethostname = socket.__dict__["gethostname"] = lambda: "localhost"
        socket.gethostbyname = socket.__dict__["gethostbyname"] = (
            lambda host: "127.0.0.1"
        )
        socket.getaddrinfo = socket.__dict__["getaddrinfo"] = (
            lambda host, port, family=None, socktype=None, proto=None, flags=None: [
                (2, 1, 6, "", (host, port))
            ]
        )
        socket.socketpair = socket.__dict__["socketpair"] = socketpair
        ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket
        ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext
        socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: byte_type(
            "\x7f\x00\x00\x01", "utf-8"
        )
        urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
            FakeSSLContext.wrap_socket
        )
        urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
            "ssl_wrap_socket"
        ] = FakeSSLContext.wrap_socket
        urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
            FakeSSLContext.wrap_socket
        )
        urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
            "ssl_wrap_socket"
        ] = FakeSSLContext.wrap_socket
        urllib3.connection.match_hostname = urllib3.connection.__dict__[
            "match_hostname"
        ] = lambda *args: None
        if pyopenssl_override:  # pragma: no cover
            # Take out the pyopenssl version - use the default implementation
            extract_from_urllib3()
        if aiohttp_make_ssl_context_cache_clear:  # pragma: no cover
            aiohttp_make_ssl_context_cache_clear()

    @staticmethod
    def disable():
        """Restore the original socket/ssl/urllib3 callables and reset state."""
        socket.socket = socket.__dict__["socket"] = true_socket
        socket._socketobject = socket.__dict__["_socketobject"] = true_socket
        socket.SocketType = socket.__dict__["SocketType"] = true_socket
        socket.create_connection = socket.__dict__["create_connection"] = (
            true_create_connection
        )
        socket.gethostname = socket.__dict__["gethostname"] = true_gethostname
        socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname
        socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo
        socket.socketpair = socket.__dict__["socketpair"] = true_socketpair
        if true_ssl_wrap_socket:
            ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket
        ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context
        socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton
        urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
            true_urllib3_wrap_socket
        )
        urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
            "ssl_wrap_socket"
        ] = true_urllib3_ssl_wrap_socket
        urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
            true_urllib3_ssl_wrap_socket
        )
        urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
            "ssl_wrap_socket"
        ] = true_urllib3_ssl_wrap_socket
        urllib3.connection.match_hostname = urllib3.connection.__dict__[
            "match_hostname"
        ] = true_urllib3_match_hostname
        Mocket.reset()
        if pyopenssl_override:  # pragma: no cover
            # Put the pyopenssl version back in place
            inject_into_urllib3()
        if aiohttp_make_ssl_context_cache_clear:  # pragma: no cover
            aiohttp_make_ssl_context_cache_clear()

    @classmethod
    def get_namespace(cls):
        return cls._namespace

    @classmethod
    def get_truesocket_recording_dir(cls):
        return cls._truesocket_recording_dir

    @classmethod
    def assert_fail_if_entries_not_served(cls):
        """Mocket checks that all entries have been served at least once."""
        if not all(entry._served for entry in itertools.chain(*cls._entries.values())):
            raise AssertionError("Some Mocket entries have not been served")
623

624

625
class MocketEntry:
    """A mocked endpoint registered for one ``location``.

    ``responses`` may be a single value or an iterable of values. Each value
    is wrapped in ``response_cls`` (text is encoded first) unless it is an
    exception instance or already has a truthy ``data`` attribute.
    ``get_response`` hands responses out in order and keeps repeating the
    last one once the list is exhausted.
    """

    class Response(byte_type):
        @property
        def data(self):
            """The payload: the bytes object itself."""
            return self

    response_index = 0
    request_cls = byte_type
    response_cls = Response
    responses = None
    _served = None

    def __init__(self, location, responses):
        self._served = False
        self.location = location

        single_value = isinstance(responses, basestring) or not isinstance(
            responses, collections_abc.Iterable
        )
        if single_value:
            responses = [responses]

        if not responses:
            self.responses = [self.response_cls(encode_to_bytes(""))]
            return

        def as_response(item):
            # Exceptions and objects that already carry data pass through.
            if isinstance(item, BaseException) or getattr(item, "data", False):
                return item
            if isinstance(item, text_type):
                item = encode_to_bytes(item)
            return self.response_cls(item)

        self.responses = [as_response(item) for item in responses]

    def __repr__(self):
        return f"{self.__class__.__name__}(location={self.location})"

    @staticmethod
    def can_handle(data):
        return True

    def collect(self, data):
        Mocket.collect(self.request_cls(data))

    def get_response(self):
        """Return the next response's data, raising it if it is an exception."""
        current = self.responses[self.response_index]
        last_index = len(self.responses) - 1
        if self.response_index < last_index:
            self.response_index += 1

        self._served = True

        if isinstance(current, BaseException):
            raise current
        return current.data
679

680

681
class Mocketizer:
    """Sync/async context manager that enables Mocket for the enclosed block.

    On entry it calls ``Mocket.enable`` and, when an instance is given, that
    instance's optional ``mocketize_setup``. On exit it calls the optional
    ``mocketize_teardown`` and then ``Mocket.disable``.
    """

    def __init__(
        self,
        instance=None,
        namespace=None,
        truesocket_recording_dir=None,
        strict_mode=False,
        strict_mode_allowed=None,
    ):
        self.instance = instance
        self.truesocket_recording_dir = truesocket_recording_dir
        self.namespace = namespace or text_type(id(self))

        MocketMode().STRICT = strict_mode
        if not strict_mode:
            if strict_mode_allowed:
                raise ValueError(
                    "Allowed locations are only accepted when STRICT mode is active."
                )
            return
        MocketMode().STRICT_ALLOWED = strict_mode_allowed or []

    def enter(self):
        Mocket.enable(
            namespace=self.namespace,
            truesocket_recording_dir=self.truesocket_recording_dir,
        )
        if self.instance:
            self.check_and_call("mocketize_setup")

    def __enter__(self):
        self.enter()
        return self

    def exit(self):
        if self.instance:
            self.check_and_call("mocketize_teardown")
        Mocket.disable()

    def __exit__(self, type, value, tb):
        self.exit()

    async def __aenter__(self, *args, **kwargs):
        self.enter()
        return self

    async def __aexit__(self, *args, **kwargs):
        self.exit()

    def check_and_call(self, method_name):
        """Call ``self.instance.<method_name>()`` if it exists and is callable."""
        hook = getattr(self.instance, method_name, None)
        if not callable(hook):
            return
        hook()

    @staticmethod
    def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args):
        """Build a Mocketizer for ``test``, using ``args[0]`` as the instance.

        With a recording directory, the namespace is
        ``<module>.<class>.<test name>`` of that instance.
        """
        instance = args[0] if args else None
        namespace = None
        if truesocket_recording_dir:
            owner = instance.__class__
            namespace = ".".join([owner.__module__, owner.__name__, test.__name__])

        return Mocketizer(
            instance,
            namespace=namespace,
            truesocket_recording_dir=truesocket_recording_dir,
            strict_mode=strict_mode,
            strict_mode_allowed=strict_mode_allowed,
        )
753

754

755
def wrapper(
    test,
    truesocket_recording_dir=None,
    strict_mode=False,
    strict_mode_allowed=None,
    *args,
    **kwargs,
):
    """Run ``test(*args, **kwargs)`` with Mocket enabled and return its result."""
    mocketizer = Mocketizer.factory(
        test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args
    )
    with mocketizer:
        return test(*args, **kwargs)
767

768

769
# Public decorator: built by `get_mocketize` (from .utils) around `wrapper`,
# which runs the decorated test inside a Mocketizer context.
mocketize = get_mocketize(wrapper_=wrapper)
7✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc