• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mindflayer / python-mocket / 9315616737

31 May 2024 08:56AM UTC coverage: 98.824% (+0.3%) from 98.537%
9315616737

push

github

mindflayer
Bump version.

1 of 1 new or added line in 1 file covered. (100.0%)

11 existing lines in 1 file now uncovered.

840 of 850 relevant lines covered (98.82%)

4.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.82
/mocket/mocket.py
1
import collections
5✔
2
import collections.abc as collections_abc
5✔
3
import contextlib
5✔
4
import errno
5✔
5
import hashlib
5✔
6
import itertools
5✔
7
import json
5✔
8
import os
5✔
9
import select
5✔
10
import socket
5✔
11
import ssl
5✔
12
from datetime import datetime, timedelta
5✔
13
from json.decoder import JSONDecodeError
5✔
14
from typing import Optional, Tuple
5✔
15

16
import urllib3
5✔
17
from urllib3.connection import match_hostname as urllib3_match_hostname
5✔
18
from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket
5✔
19

20
try:
5✔
21
    from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
5✔
22
except ImportError:
23
    urllib3_wrap_socket = None
24

25

26
from .compat import basestring, byte_type, decode_from_bytes, encode_to_bytes, text_type
5✔
27
from .utils import (
5✔
28
    SSL_PROTOCOL,
29
    MocketMode,
30
    MocketSocketCore,
31
    get_mocketize,
32
    hexdump,
33
    hexload,
34
)
35

36
xxh32 = None
5✔
37
try:
5✔
38
    from xxhash import xxh32
5✔
39
except ImportError:  # pragma: no cover
40
    with contextlib.suppress(ImportError):
41
        from xxhash_cffi import xxh32
42
hasher = xxh32 or hashlib.md5
5✔
43

44
try:  # pragma: no cover
45
    from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3
46

47
    pyopenssl_override = True
48
except ImportError:
49
    pyopenssl_override = False
50

51
try:  # pragma: no cover
52
    from aiohttp import TCPConnector
53

54
    aiohttp_make_ssl_context_cache_clear = TCPConnector._make_ssl_context.cache_clear
55
except (ImportError, AttributeError):
×
UNCOV
56
    aiohttp_make_ssl_context_cache_clear = None
×
57

58

59
# References to the genuine implementations, captured at import time.
# Mocket.disable() reinstates them, and MocketSocket uses them whenever it
# has to talk to the real network.

# --- socket module ---
true_socket = socket.socket
true_create_connection = socket.create_connection
true_gethostbyname = socket.gethostbyname
true_gethostname = socket.gethostname
true_getaddrinfo = socket.getaddrinfo
true_socketpair = socket.socketpair
true_inet_pton = socket.inet_pton

# --- ssl module ---
# in Py3.12 wrap_socket is only available under SSLContext
true_ssl_wrap_socket = getattr(ssl, "wrap_socket", None)
true_ssl_socket = ssl.SSLSocket
true_ssl_context = ssl.SSLContext

# --- urllib3 ---
true_urllib3_wrap_socket = urllib3_wrap_socket
true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket
true_urllib3_match_hostname = urllib3_match_hostname
74

75

76
class SuperFakeSSLContext:
    """For Python 3.6 and newer."""

    class FakeSetter(int):
        """Integer placeholder whose ``__set__`` discards every assignment."""

        def __set__(self, *args):
            """Accept the assignment and drop the value."""
            return None

    # Assigning to any of these on an instance is silently ignored, because
    # the class-level FakeSetter intercepts the write.
    verify_mode = FakeSetter()
    options = FakeSetter()
    minimum_version = FakeSetter()
86

87

88
class FakeSSLContext(SuperFakeSSLContext):
    """Replacement installed as ``ssl.SSLContext`` while Mocket is enabled.

    Certificate/cipher configuration calls are turned into no-ops and
    ``wrap_socket`` merely flags the socket as secure instead of performing
    a TLS handshake.
    """

    # Methods replaced on every instance by a function that does nothing.
    DUMMY_METHODS = (
        "load_default_certs",
        "load_verify_locations",
        "set_alpn_protocols",
        "set_ciphers",
        "set_default_verify_paths",
    )
    sock = None
    post_handshake_auth = None
    _check_hostname = False

    @property
    def check_hostname(self):
        """Current hostname-checking flag."""
        return self._check_hostname

    @check_hostname.setter
    def check_hostname(self, *args):
        # Whatever value is assigned, hostname checking stays disabled.
        self._check_hostname = False

    def __init__(self, sock=None, server_hostname=None, _context=None, *args, **kwargs):
        """Build the fake context.

        When ``sock`` is a ``MocketSocket`` its underlying real socket is
        wrapped in a genuine SSL socket; when ``sock`` is an ``int`` it is
        passed to the real ``SSLContext`` constructor and the result is
        stored as ``self.context``.
        """
        self._set_dummy_methods()

        if isinstance(sock, MocketSocket):
            self.sock = sock
            sock._host = server_hostname
            sock.true_socket = true_ssl_socket(
                sock=sock.true_socket,
                server_hostname=server_hostname,
                _context=true_ssl_context(protocol=SSL_PROTOCOL),
            )
            return

        if isinstance(sock, int) and true_ssl_context:
            self.context = true_ssl_context(sock)

    def _set_dummy_methods(self):
        """Attach a do-nothing callable for each name in DUMMY_METHODS."""

        def dummy_method(*args, **kwargs):
            pass

        for method_name in self.DUMMY_METHODS:
            setattr(self, method_name, dummy_method)

    @staticmethod
    def wrap_socket(sock=sock, *args, **kwargs):
        """Mark ``sock`` as secure, remember the keyword arguments, return it."""
        # The default is the class attribute ``sock`` (None) defined above.
        sock.kwargs = kwargs
        sock._secure_socket = True
        return sock

    @staticmethod
    def wrap_bio(incoming, outcoming, *args, **kwargs):
        """Return a fresh MocketSocket bound to ``kwargs["server_hostname"]``."""
        fake_ssl_object = MocketSocket()
        fake_ssl_object._host = kwargs["server_hostname"]
        return fake_ssl_object

    def __getattr__(self, name):
        """Delegate unknown attributes to the wrapped socket, if there is one."""
        # NOTE(review): without a wrapped socket every unknown attribute
        # resolves to None rather than raising AttributeError.
        if self.sock is None:
            return None
        return getattr(self.sock, name)
144

145

146
def create_connection(
    address, timeout=None, source_address=None, *, all_errors=False
):
    """Stand-in for ``socket.create_connection``.

    Creates a TCP socket through ``socket.socket`` (whatever class is
    currently installed there) and connects it to ``address``.

    Args:
        address: ``(host, port)`` pair handed to ``connect()``.
        timeout: when truthy, applied to the socket via ``settimeout()``.
        source_address: accepted for signature compatibility; not used.
        all_errors: accepted for compatibility with the keyword-only
            argument the standard library added in Python 3.11; not used.

    Returns:
        The connected socket object.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
    if timeout:
        s.settimeout(timeout)
    s.connect(address)
    return s
152

153

154
def socketpair(*args, **kwargs):
    """Returns a real socketpair() used by asyncio loop for supporting calls made by fastapi and similar services."""
    # Go straight to the C-level module so the call does not pass through the
    # (possibly patched) ``socket`` module.
    from _socket import socketpair as low_level_socketpair

    return low_level_socketpair(*args, **kwargs)
159

160

161
def _hash_request(h, req):
    """Return the hex digest of ``req`` computed with hash constructor ``h``.

    The request is split on CRLF and its lines are sorted before hashing, so
    the ordering of the lines does not affect the signature.
    """
    ordered_lines = sorted(req.split("\r\n"))
    payload = encode_to_bytes("".join(ordered_lines))
    return h(payload).hexdigest()
163

164

165
class MocketSocket:
    """Replacement for ``socket.socket`` installed while Mocket is enabled.

    Outgoing data is matched against the entries registered on ``Mocket``;
    when an entry handles it, the entry's response is placed in an in-memory
    buffer that the read methods consume. When no entry matches, the data is
    sent through a real socket (optionally recording/replaying the exchange
    from a JSON file in the configured recording directory).
    """

    timeout = None
    _fd = None
    family = None
    type = None
    proto = None
    _host = None
    _port = None
    _address = None
    cipher = lambda s: ("ADH", "AES256", "SHA")
    compression = lambda s: ssl.OP_NO_COMPRESSION
    _mode = None
    _bufsize = None
    _secure_socket = False
    _did_handshake = False
    _sent_non_empty_bytes = False
    _io = None

    def __init__(
        self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs
    ):
        # A genuine socket is always created; it is used only for traffic
        # that no registered entry handles.
        self.true_socket = true_socket(family, type, proto)
        self._buflen = 65536
        self._entry = None
        self.family = int(family)
        self.type = int(type)
        self.proto = int(proto)
        self._truesocket_recording_dir = None
        self.kwargs = kwargs

    def __str__(self):
        return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})"

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    @property
    def io(self):
        """Response buffer, created lazily on first access."""
        if self._io is None:
            self._io = MocketSocketCore((self._host, self._port))
        return self._io

    def fileno(self):
        """Return the read end of the pipe associated with this address.

        The pipe is created and stored on ``Mocket`` the first time it is
        requested for a given ``(host, port)``.
        """
        address = (self._host, self._port)
        r_fd, _ = Mocket.get_pair(address)
        if not r_fd:
            r_fd, w_fd = os.pipe()
            Mocket.set_pair(address, (r_fd, w_fd))
        return r_fd

    def gettimeout(self):
        return self.timeout

    def setsockopt(self, family, type, proto):
        # NOTE(review): the three arguments are stored as family/type/proto,
        # although callers of the real API pass (level, optname, value) —
        # kept as-is, confirm whether this is intended.
        self.family = family
        self.type = type
        self.proto = proto

        if self.true_socket:
            self.true_socket.setsockopt(family, type, proto)

    def settimeout(self, timeout):
        self.timeout = timeout

    @staticmethod
    def getsockopt(level, optname, buflen=None):
        """Always report ``socket.SOCK_STREAM``, whatever option is queried."""
        return socket.SOCK_STREAM

    def do_handshake(self):
        self._did_handshake = True

    def getpeername(self):
        return self._address

    def setblocking(self, block):
        self.settimeout(None if block else 0.0)

    def getblocking(self):
        return self.gettimeout() is None

    def getsockname(self):
        return socket.gethostbyname(self._address[0]), self._address[1]

    def getpeercert(self, *args, **kwargs):
        """Return a fabricated certificate dictionary for the current host.

        If this socket has no host/port yet, the last address recorded on
        ``Mocket`` is adopted. ``notAfter`` is set roughly one year ahead and
        uses the ``"%b %d %H:%M:%S %Y GMT"`` layout (year included) so that
        ``ssl.cert_time_to_seconds`` can parse it.
        """
        if not (self._host and self._port):
            self._address = self._host, self._port = Mocket._address

        expiry = datetime.now() + timedelta(days=30 * 12)
        wildcard = f"*.{self._host}"
        return {
            "notAfter": expiry.strftime("%b %d %H:%M:%S %Y GMT"),
            "subjectAltName": (
                ("DNS", wildcard),
                ("DNS", self._host),
                ("DNS", "*"),
            ),
            "subject": (
                (("organizationName", wildcard),),
                (("organizationalUnitName", "Domain Control Validated"),),
                (("commonName", wildcard),),
            ),
        }

    def unwrap(self):
        return self

    def write(self, data):
        return self.send(encode_to_bytes(data))

    def connect(self, address):
        """Remember ``address`` on the socket and on ``Mocket``; no I/O happens."""
        self._address = self._host, self._port = address
        Mocket._address = address

    def makefile(self, mode="r", bufsize=-1):
        self._mode = mode
        self._bufsize = bufsize
        return self.io

    def get_entry(self, data):
        return Mocket.get_entry(self._host, self._port, data)

    def sendall(self, data, entry=None, *args, **kwargs):
        """Handle outgoing ``data`` and stage the response for reading.

        A matching entry collects the request and supplies the response
        (unless its ``collect`` returns ``False``); without an entry the
        request goes through ``true_sendall``.
        """
        if entry is None:
            entry = self.get_entry(data)

        if entry:
            consume_response = entry.collect(data)
            response = entry.get_response() if consume_response is not False else None
        else:
            response = self.true_sendall(data, *args, **kwargs)

        if response is not None:
            # Replace the buffer content and rewind it for the next read.
            self.io.seek(0)
            self.io.write(response)
            self.io.truncate()
            self.io.seek(0)

    def read(self, buffersize):
        """Read from the response buffer.

        Raises:
            ssl.SSLWantReadError: after a handshake, as long as no non-empty
                read has happened yet.
        """
        rv = self.io.read(buffersize)
        if rv:
            self._sent_non_empty_bytes = True
        if self._did_handshake and not self._sent_non_empty_bytes:
            raise ssl.SSLWantReadError("The operation did not complete (read)")
        return rv

    def recv_into(self, buffer, buffersize=None, flags=None):
        """Copy buffered response data into ``buffer``; return the byte count.

        ``buffer`` may be a file-like object (anything with ``write``) or a
        writable buffer such as a memoryview.
        """
        if hasattr(buffer, "write"):
            return buffer.write(self.read(buffersize))

        # buffer is a memoryview: never read more than it can hold, otherwise
        # the slice assignment below fails on a size mismatch.
        if buffersize is None:
            buffersize = len(buffer)

        data = self.read(buffersize)
        if data:
            buffer[: len(data)] = data
        return len(data)

    def recv(self, buffersize, flags=None):
        """Return data from the address pipe if one exists, else from the buffer.

        Raises:
            BlockingIOError: (errno ``EWOULDBLOCK``) when nothing is buffered.
        """
        r_fd, _ = Mocket.get_pair((self._host, self._port))
        if r_fd:
            return os.read(r_fd, buffersize)
        data = self.read(buffersize)
        if data:
            return data
        # used by Redis mock
        exc = BlockingIOError()
        exc.errno = errno.EWOULDBLOCK
        exc.args = (0,)
        raise exc

    def true_sendall(self, data, *args, **kwargs):
        """Send ``data`` for real, or replay a previously recorded response.

        Returns:
            The raw response bytes.
        """
        if not MocketMode().is_allowed((self._host, self._port)):
            MocketMode.raise_not_allowed()

        req = decode_from_bytes(data)
        # make request unique again
        req_signature = _hash_request(hasher, req)
        # port should be always a string
        port = text_type(self._port)

        # prepare responses dictionary
        responses = {}

        if Mocket.get_truesocket_recording_dir():
            path = os.path.join(
                Mocket.get_truesocket_recording_dir(), Mocket.get_namespace() + ".json"
            )
            # check if there's already a recorded session dumped to a JSON file
            try:
                with open(path) as f:
                    responses = json.load(f)
            # if not, create a new dictionary
            except (FileNotFoundError, JSONDecodeError):
                pass

        try:
            try:
                response_dict = responses[self._host][port][req_signature]
            except KeyError:
                if hasher is not hashlib.md5:
                    # Fallback for backwards compatibility
                    req_signature = _hash_request(hashlib.md5, req)
                    response_dict = responses[self._host][port][req_signature]
                else:
                    raise
        except KeyError:
            # preventing next KeyError exceptions
            responses.setdefault(self._host, {})
            responses[self._host].setdefault(port, {})
            responses[self._host][port].setdefault(req_signature, {})
            response_dict = responses[self._host][port][req_signature]

        # try to get the response from the dictionary
        try:
            encoded_response = hexload(response_dict["response"])
        # if not available, call the real sendall
        except KeyError:
            host, port = self._host, self._port
            host = true_gethostbyname(host)

            if isinstance(self.true_socket, true_socket) and self._secure_socket:
                self.true_socket = true_urllib3_ssl_wrap_socket(
                    self.true_socket,
                    **self.kwargs,
                )

            with contextlib.suppress(OSError, ValueError):
                # already connected
                self.true_socket.connect((host, port))
            self.true_socket.sendall(data, *args, **kwargs)
            encoded_response = b""
            # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L13
            while True:
                # Stop once something has been received and the socket stays
                # quiet for 0.1s.
                if (
                    not select.select([self.true_socket], [], [], 0.1)[0]
                    and encoded_response
                ):
                    break
                recv = self.true_socket.recv(self._buflen)

                if not recv and encoded_response:
                    break
                encoded_response += recv

            # dump the resulting dictionary to a JSON file
            if Mocket.get_truesocket_recording_dir():
                # update the dictionary with request and response lines
                response_dict["request"] = req
                response_dict["response"] = hexdump(encoded_response)

                with open(path, mode="w") as f:
                    f.write(
                        decode_from_bytes(
                            json.dumps(responses, indent=4, sort_keys=True)
                        )
                    )

        # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO
        return encoded_response

    def send(self, data, *args, **kwargs):  # pragma: no cover
        """Send ``data`` and return its length.

        If the matching entry is the one that handled the previous send, the
        data is appended to the last collected request (when it supports
        ``add_data``) instead of starting a new exchange.
        """
        entry = self.get_entry(data)
        if not entry or self._entry != entry:
            kwargs["entry"] = entry
            self.sendall(data, *args, **kwargs)
        else:
            req = Mocket.last_request()
            if hasattr(req, "add_data"):
                req.add_data(data)
        self._entry = entry
        return len(data)

    def close(self):
        if self.true_socket and not self.true_socket._closed:
            self.true_socket.close()
        self._fd = None

    def __getattr__(self, name):
        """Do nothing catchall function, for methods like shutdown()"""

        def do_nothing(*args, **kwargs):
            pass

        return do_nothing
449

450

451
class Mocket:
    """Global registry and on/off switch for the mocking machinery.

    All state lives on the class itself: registered entries, collected
    requests, per-address pipe descriptors, and the recording settings.
    """

    _socket_pairs = {}
    _address = (None, None)
    _entries = collections.defaultdict(list)
    _requests = []
    _namespace = text_type(id(_entries))
    _truesocket_recording_dir = None

    @classmethod
    def get_pair(cls, address: tuple) -> Tuple[Optional[int], Optional[int]]:
        """
        Given an address, return the pair of file descriptors stored for it
        as a tuple of two integers: (<read_fd>, <write_fd>).
        Returns (None, None) when nothing is stored for that address.
        """
        return cls._socket_pairs.get(address, (None, None))

    @classmethod
    def set_pair(cls, address: tuple, pair: Tuple[int, int]) -> None:
        """
        Store a pair of file descriptors under the key `address`
        as a tuple of two integers: (<read_fd>, <write_fd>)
        """
        cls._socket_pairs[address] = pair

    @classmethod
    def register(cls, *entries):
        """Add each entry to the list kept for its ``location``."""
        for entry in entries:
            cls._entries[entry.location].append(entry)

    @classmethod
    def get_entry(cls, host, port, data):
        """Return the first entry for ``(host, port)`` able to handle ``data``.

        Missing host/port fall back to the last address recorded on
        ``Mocket``. Returns ``None`` when no entry matches.
        """
        host = host or Mocket._address[0]
        port = port or Mocket._address[1]
        entries = cls._entries.get((host, port), [])
        for entry in entries:
            if entry.can_handle(data):
                return entry

    @classmethod
    def collect(cls, data):
        cls.request_list().append(data)

    @classmethod
    def reset(cls):
        """Close every stored pipe and drop entries and collected requests."""
        for r_fd, w_fd in cls._socket_pairs.values():
            os.close(r_fd)
            os.close(w_fd)
        cls._socket_pairs = {}
        cls._entries = collections.defaultdict(list)
        cls._requests = []

    @classmethod
    def last_request(cls):
        """Return the most recently collected request, or ``None``."""
        if cls.has_requests():
            return cls.request_list()[-1]

    @classmethod
    def request_list(cls):
        return cls._requests

    @classmethod
    def remove_last_request(cls):
        if cls.has_requests():
            del cls._requests[-1]

    @classmethod
    def has_requests(cls):
        return bool(cls.request_list())

    @staticmethod
    def enable(namespace=None, truesocket_recording_dir=None):
        """Patch ``socket``, ``ssl`` and ``urllib3`` with the Mocket fakes.

        Args:
            namespace: stored as the current namespace (used as the name of
                the recording file).
            truesocket_recording_dir: existing directory where JSON dumps of
                real exchanges are saved; ``None`` disables recording.

        Raises:
            AssertionError: if ``truesocket_recording_dir`` is given but is
                not an existing directory.
        """
        # Validate before touching any state, so a failed call changes nothing.
        if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
            # JSON dumps will be saved here
            raise AssertionError(
                "truesocket_recording_dir is not an existing directory: "
                f"{truesocket_recording_dir!r}"
            )

        Mocket._namespace = namespace
        Mocket._truesocket_recording_dir = truesocket_recording_dir

        socket.socket = socket.__dict__["socket"] = MocketSocket
        socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket
        socket.SocketType = socket.__dict__["SocketType"] = MocketSocket
        socket.create_connection = socket.__dict__["create_connection"] = (
            create_connection
        )
        socket.gethostname = socket.__dict__["gethostname"] = lambda: "localhost"
        socket.gethostbyname = socket.__dict__["gethostbyname"] = (
            lambda host: "127.0.0.1"
        )
        socket.getaddrinfo = socket.__dict__["getaddrinfo"] = (
            lambda host, port, family=None, socktype=None, proto=None, flags=None: [
                (2, 1, 6, "", (host, port))
            ]
        )
        socket.socketpair = socket.__dict__["socketpair"] = socketpair
        ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket
        ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext
        socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: byte_type(
            "\x7f\x00\x00\x01", "utf-8"
        )
        urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
            FakeSSLContext.wrap_socket
        )
        urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
            "ssl_wrap_socket"
        ] = FakeSSLContext.wrap_socket
        urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
            FakeSSLContext.wrap_socket
        )
        urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
            "ssl_wrap_socket"
        ] = FakeSSLContext.wrap_socket
        urllib3.connection.match_hostname = urllib3.connection.__dict__[
            "match_hostname"
        ] = lambda *args: None
        if pyopenssl_override:  # pragma: no cover
            # Take out the pyopenssl version - use the default implementation
            extract_from_urllib3()
        if aiohttp_make_ssl_context_cache_clear:  # pragma: no cover
            aiohttp_make_ssl_context_cache_clear()

    @staticmethod
    def disable():
        """Restore the genuine implementations and reset Mocket's state."""
        socket.socket = socket.__dict__["socket"] = true_socket
        socket._socketobject = socket.__dict__["_socketobject"] = true_socket
        socket.SocketType = socket.__dict__["SocketType"] = true_socket
        socket.create_connection = socket.__dict__["create_connection"] = (
            true_create_connection
        )
        socket.gethostname = socket.__dict__["gethostname"] = true_gethostname
        socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname
        socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo
        socket.socketpair = socket.__dict__["socketpair"] = true_socketpair
        if true_ssl_wrap_socket:
            ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket
        else:
            # The real module had no wrap_socket (Py3.12+): remove the fake
            # installed by enable() instead of leaving it behind.
            ssl.__dict__.pop("wrap_socket", None)
        ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context
        socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton
        urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
            true_urllib3_wrap_socket
        )
        urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
            "ssl_wrap_socket"
        ] = true_urllib3_ssl_wrap_socket
        urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
            true_urllib3_ssl_wrap_socket
        )
        urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
            "ssl_wrap_socket"
        ] = true_urllib3_ssl_wrap_socket
        urllib3.connection.match_hostname = urllib3.connection.__dict__[
            "match_hostname"
        ] = true_urllib3_match_hostname
        Mocket.reset()
        if pyopenssl_override:  # pragma: no cover
            # Put the pyopenssl version back in place
            inject_into_urllib3()
        if aiohttp_make_ssl_context_cache_clear:  # pragma: no cover
            aiohttp_make_ssl_context_cache_clear()

    @classmethod
    def get_namespace(cls):
        return cls._namespace

    @classmethod
    def get_truesocket_recording_dir(cls):
        return cls._truesocket_recording_dir

    @classmethod
    def assert_fail_if_entries_not_served(cls):
        """Mocket checks that all entries have been served at least once."""
        if not all(entry._served for entry in itertools.chain(*cls._entries.values())):
            raise AssertionError("Some Mocket entries have not been served")
622

623

624
class MocketEntry:
    """A canned set of responses served for one ``location``.

    Responses are handed out in order; once the last one is reached it is
    repeated for every further request.
    """

    class Response(byte_type):
        """Bytes payload exposing itself through a ``data`` property."""

        @property
        def data(self):
            return self

    response_index = 0
    request_cls = byte_type
    response_cls = Response
    responses = None
    _served = None

    def __init__(self, location, responses):
        """Normalise ``responses`` into a list of response objects.

        Args:
            location: key under which ``Mocket.register`` files this entry.
            responses: a single response or an iterable of them. Text is
                encoded to bytes; exceptions and objects with a truthy
                ``data`` attribute are stored untouched.
        """
        self._served = False
        self.location = location

        # A lone text/bytes payload is ONE response. Iterating bytes would
        # yield ints, and response_cls(int) builds that many zero bytes.
        if not isinstance(responses, collections_abc.Iterable) or isinstance(
            responses, (basestring, byte_type, bytearray)
        ):
            responses = [responses]

        if not responses:
            self.responses = [self.response_cls(encode_to_bytes(""))]
        else:
            self.responses = []
            for r in responses:
                if not isinstance(r, BaseException) and not getattr(r, "data", False):
                    if isinstance(r, text_type):
                        r = encode_to_bytes(r)
                    r = self.response_cls(r)
                self.responses.append(r)

    def __repr__(self):
        return f"{self.__class__.__name__}(location={self.location})"

    @staticmethod
    def can_handle(data):
        """Base entries accept any data."""
        return True

    def collect(self, data):
        """Wrap ``data`` in ``request_cls`` and record it on ``Mocket``."""
        req = self.request_cls(data)
        Mocket.collect(req)

    def get_response(self):
        """Return the next response's ``data`` and mark the entry as served.

        Raises:
            BaseException: the stored response itself, when it is an
                exception instance.
        """
        response = self.responses[self.response_index]
        # Advance until the last response, which is then served repeatedly.
        if self.response_index < len(self.responses) - 1:
            self.response_index += 1

        self._served = True

        if isinstance(response, BaseException):
            raise response

        return response.data
678

679

680
class Mocketizer:
    """Sync/async context manager that enables Mocket on entry and disables it on exit."""

    def __init__(
        self,
        instance=None,
        namespace=None,
        truesocket_recording_dir=None,
        strict_mode=False,
        strict_mode_allowed=None,
    ):
        """Store the settings and configure strict mode on ``MocketMode``.

        Raises:
            ValueError: if ``strict_mode_allowed`` is given while
                ``strict_mode`` is off.
        """
        self.instance = instance
        self.truesocket_recording_dir = truesocket_recording_dir
        self.namespace = namespace or text_type(id(self))
        MocketMode().STRICT = strict_mode

        if not strict_mode:
            if strict_mode_allowed:
                raise ValueError(
                    "Allowed locations are only accepted when STRICT mode is active."
                )
            return

        MocketMode().STRICT_ALLOWED = strict_mode_allowed or []

    def enter(self):
        """Enable Mocket, then run the instance's ``mocketize_setup`` hook if any."""
        Mocket.enable(
            namespace=self.namespace,
            truesocket_recording_dir=self.truesocket_recording_dir,
        )
        if self.instance:
            self.check_and_call("mocketize_setup")

    def __enter__(self):
        self.enter()
        return self

    def exit(self):
        """Run the instance's ``mocketize_teardown`` hook if any, then disable Mocket."""
        if self.instance:
            self.check_and_call("mocketize_teardown")
        Mocket.disable()

    def __exit__(self, type, value, tb):
        self.exit()

    async def __aenter__(self, *args, **kwargs):
        self.enter()
        return self

    async def __aexit__(self, *args, **kwargs):
        self.exit()

    def check_and_call(self, method_name):
        """Call ``self.instance.<method_name>()`` when it exists and is callable."""
        hook = getattr(self.instance, method_name, None)
        if not callable(hook):
            return
        hook()

    @staticmethod
    def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args):
        """Build a Mocketizer for a decorated ``test``.

        The first positional argument (if any) is taken as the instance.
        When recording is requested, the namespace is
        ``<module>.<class>.<test name>`` of that instance.
        """
        instance = None
        if args:
            instance = args[0]

        namespace = None
        if truesocket_recording_dir:
            owner = instance.__class__
            name_parts = (owner.__module__, owner.__name__, test.__name__)
            namespace = ".".join(name_parts)

        return Mocketizer(
            instance,
            namespace=namespace,
            truesocket_recording_dir=truesocket_recording_dir,
            strict_mode=strict_mode,
            strict_mode_allowed=strict_mode_allowed,
        )
752

753

754
def wrapper(
    test,
    truesocket_recording_dir=None,
    strict_mode=False,
    strict_mode_allowed=None,
    *args,
    **kwargs,
):
    """Run ``test(*args, **kwargs)`` inside a Mocketizer and return its result."""
    mocketizer = Mocketizer.factory(
        test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args
    )
    with mocketizer:
        return test(*args, **kwargs)


# Public decorator built from the wrapper above.
mocketize = get_mocketize(wrapper_=wrapper)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc