• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mindflayer / python-mocket / 7562750313

17 Jan 2024 10:50PM UTC coverage: 98.668%. Remained the same
7562750313

push

github

web-flow
Bump version

1 of 1 new or added line in 1 file covered. (100.0%)

4 existing lines in 1 file now uncovered.

815 of 826 relevant lines covered (98.67%)

5.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.23
/mocket/mocket.py
1
import collections
6✔
2
import collections.abc as collections_abc
6✔
3
import errno
6✔
4
import hashlib
6✔
5
import io
6✔
6
import itertools
6✔
7
import json
6✔
8
import os
6✔
9
import select
6✔
10
import socket
6✔
11
import ssl
6✔
12
from datetime import datetime, timedelta
6✔
13
from json.decoder import JSONDecodeError
6✔
14

15
import urllib3
6✔
16
from urllib3.connection import match_hostname as urllib3_match_hostname
6✔
17
from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket
6✔
18

19
try:
6✔
20
    from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
6✔
21
except ImportError:
22
    urllib3_wrap_socket = None
23

24
from .compat import basestring, byte_type, decode_from_bytes, encode_to_bytes, text_type
6✔
25
from .exceptions import StrictMocketException
6✔
26
from .utils import (
6✔
27
    SSL_PROTOCOL,
28
    MocketMode,
29
    MocketSocketCore,
30
    get_mocketize,
31
    hexdump,
32
    hexload,
33
)
34

35
xxh32 = None
6✔
36
try:
6✔
37
    from xxhash import xxh32
6✔
38
except ImportError:  # pragma: no cover
39
    try:
40
        from xxhash_cffi import xxh32
41
    except ImportError:
42
        pass
43
hasher = xxh32 or hashlib.md5
6✔
44

45
try:  # pragma: no cover
46
    from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3
47

48
    pyopenssl_override = True
49
except ImportError:
50
    pyopenssl_override = False
51

52
try:  # pragma: no cover
53
    from aiohttp import TCPConnector
54

55
    aiohttp_make_ssl_context_cache_clear = TCPConnector._make_ssl_context.cache_clear
56
except (ImportError, AttributeError):
1✔
57
    aiohttp_make_ssl_context_cache_clear = None
1✔
58

59

60
true_socket = socket.socket
6✔
61
true_create_connection = socket.create_connection
6✔
62
true_gethostbyname = socket.gethostbyname
6✔
63
true_gethostname = socket.gethostname
6✔
64
true_getaddrinfo = socket.getaddrinfo
6✔
65
true_socketpair = socket.socketpair
6✔
66
true_ssl_wrap_socket = getattr(
6✔
67
    ssl, "wrap_socket", None
68
)  # in Py3.12 it's only under SSLContext
69
true_ssl_socket = ssl.SSLSocket
6✔
70
true_ssl_context = ssl.SSLContext
6✔
71
true_inet_pton = socket.inet_pton
6✔
72
true_urllib3_wrap_socket = urllib3_wrap_socket
6✔
73
true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket
6✔
74
true_urllib3_match_hostname = urllib3_match_hostname
6✔
75

76

77
class SuperFakeSSLContext:
6✔
78
    """For Python 3.6"""
79

80
    class FakeSetter(int):
6✔
81
        def __set__(self, *args):
6✔
82
            pass
6✔
83

84
    minimum_version = FakeSetter()
6✔
85
    options = FakeSetter()
6✔
86
    verify_mode = FakeSetter(ssl.CERT_NONE)
6✔
87

88

89
class FakeSSLContext(SuperFakeSSLContext):
6✔
90
    DUMMY_METHODS = (
6✔
91
        "load_default_certs",
92
        "load_verify_locations",
93
        "set_alpn_protocols",
94
        "set_ciphers",
95
        "set_default_verify_paths",
96
    )
97
    sock = None
6✔
98
    post_handshake_auth = None
6✔
99
    _check_hostname = False
6✔
100

101
    @property
6✔
102
    def check_hostname(self):
6✔
103
        return self._check_hostname
5✔
104

105
    @check_hostname.setter
6✔
106
    def check_hostname(self, *args):
6✔
107
        self._check_hostname = False
6✔
108

109
    def __init__(self, sock=None, server_hostname=None, _context=None, *args, **kwargs):
6✔
110
        self._set_dummy_methods()
6✔
111

112
        if isinstance(sock, MocketSocket):
6✔
113
            self.sock = sock
×
114
            self.sock._host = server_hostname
×
115
            self.sock.true_socket = true_ssl_socket(
×
116
                sock=self.sock.true_socket,
117
                server_hostname=server_hostname,
118
                _context=true_ssl_context(protocol=SSL_PROTOCOL),
119
            )
120
        elif isinstance(sock, int) and true_ssl_context:
6✔
121
            self.context = true_ssl_context(sock)
6✔
122

123
    def _set_dummy_methods(self):
6✔
124
        def dummy_method(*args, **kwargs):
6✔
125
            pass
6✔
126

127
        for m in self.DUMMY_METHODS:
6✔
128
            setattr(self, m, dummy_method)
6✔
129

130
    @staticmethod
6✔
131
    def wrap_socket(sock=sock, *args, **kwargs):
        """Flag ``sock`` as a secure socket and hand the very same object back.

        No TLS wrapping happens here: the keyword arguments are stored on
        ``sock.kwargs`` and ``sock._secure_socket`` is set to ``True``.
        Extra positional arguments are accepted and ignored.

        The default ``sock=sock`` is evaluated while the class body runs, so
        it is the class attribute ``sock`` (``None``); calling this without a
        socket therefore fails on the first attribute assignment.
        """
        sock.kwargs = kwargs
        sock._secure_socket = True
        return sock
6✔
135

136
    @staticmethod
6✔
137
    def wrap_bio(incoming, outcoming, *args, **kwargs):
6✔
138
        ssl_obj = MocketSocket()
6✔
139
        ssl_obj._host = kwargs["server_hostname"]
6✔
140
        return ssl_obj
6✔
141

142
    def __getattr__(self, name):
6✔
143
        if self.sock is not None:
6✔
144
            return getattr(self.sock, name)
×
145

146

147
def create_connection(address, timeout=None, source_address=None, *, all_errors=False):
    """Stand-in for ``socket.create_connection`` installed by ``Mocket.enable``.

    Builds a TCP/IPv4 socket through ``socket.socket`` (whatever class is
    currently bound to that name), applies the timeout and connects it.

    :param address: ``(host, port)`` tuple handed to ``connect``.
    :param timeout: timeout to set on the socket; ``None`` leaves it untouched.
    :param source_address: accepted for signature compatibility, not used.
    :param all_errors: accepted for compatibility with the keyword added to
        ``socket.create_connection`` in Python 3.11, not used. Without it,
        callers passing ``all_errors=...`` would get a ``TypeError``.
    :return: the connected socket object.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
    # `is not None` rather than truthiness: a timeout of 0 is a legitimate
    # request for non-blocking mode and must not be silently dropped.
    if timeout is not None:
        s.settimeout(timeout)
    s.connect(address)
    return s
6✔
153

154

155
def socketpair(*args, **kwargs):
6✔
156
    """Returns a real socketpair() used by asyncio loop for supporting calls made by fastapi and similar services."""
157
    import _socket
6✔
158

159
    return _socket.socketpair(*args, **kwargs)
6✔
160

161

162
def _hash_request(h, req):
6✔
163
    return h(encode_to_bytes("".join(sorted(req.split("\r\n"))))).hexdigest()
6✔
164

165

166
class MocketSocket:
6✔
167
    timeout = None
6✔
168
    _fd = None
6✔
169
    family = None
6✔
170
    type = None
6✔
171
    proto = None
6✔
172
    _host = None
6✔
173
    _port = None
6✔
174
    _address = None
6✔
175
    cipher = lambda s: ("ADH", "AES256", "SHA")
6✔
176
    compression = lambda s: ssl.OP_NO_COMPRESSION
6✔
177
    _mode = None
6✔
178
    _bufsize = None
6✔
179
    _secure_socket = False
6✔
180

181
    def __init__(
6✔
182
        self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs
183
    ):
184
        self.true_socket = true_socket(family, type, proto)
6✔
185
        self._buflen = 65536
6✔
186
        self._entry = None
6✔
187
        self.family = int(family)
6✔
188
        self.type = int(type)
6✔
189
        self.proto = int(proto)
6✔
190
        self._truesocket_recording_dir = None
6✔
191
        self._did_handshake = False
6✔
192
        self._sent_non_empty_bytes = False
6✔
193
        self.kwargs = kwargs
6✔
194

195
    def __str__(self):
196
        return "({})(family={} type={} protocol={})".format(
197
            self.__class__.__name__, self.family, self.type, self.proto
198
        )
199

200
    def __enter__(self):
6✔
201
        return self
6✔
202

203
    def __exit__(self, exc_type, exc_val, exc_tb):
6✔
204
        self.close()
6✔
205

206
    @property
6✔
207
    def fd(self):
6✔
208
        if self._fd is None:
6✔
209
            self._fd = MocketSocketCore()
6✔
210
        return self._fd
6✔
211

212
    def gettimeout(self):
6✔
213
        return self.timeout
6✔
214

215
    def setsockopt(self, family, type, proto):
        """Record the three arguments on the mock and forward them to the real socket.

        The values are stored as ``self.family``, ``self.type`` and
        ``self.proto`` and, when a true socket is present, passed unchanged to
        its ``setsockopt``.

        NOTE(review): the standard ``socket.setsockopt`` signature is
        ``(level, optname, value)``; as written, a regular call overwrites the
        mock's family/type/proto with those option values - confirm this is
        intended before relying on the attributes afterwards.
        """
        self.family = family
        self.type = type
        self.proto = proto

        if self.true_socket:
            self.true_socket.setsockopt(family, type, proto)
6✔
222

223
    def settimeout(self, timeout):
6✔
224
        self.timeout = timeout
6✔
225

226
    @staticmethod
6✔
227
    def getsockopt(level, optname, buflen=None):
6✔
228
        return socket.SOCK_STREAM
×
229

230
    def do_handshake(self):
6✔
231
        self._did_handshake = True
6✔
232

233
    def getpeername(self):
6✔
234
        return self._address
6✔
235

236
    def setblocking(self, block):
6✔
237
        self.settimeout(None) if block else self.settimeout(0.0)
6✔
238

239
    def getsockname(self):
6✔
240
        return socket.gethostbyname(self._address[0]), self._address[1]
6✔
241

242
    def getpeercert(self, *args, **kwargs):
6✔
243
        if not (self._host and self._port):
5✔
244
            self._address = self._host, self._port = Mocket._address
5✔
245

246
        now = datetime.now()
5✔
247
        shift = now + timedelta(days=30 * 12)
5✔
248
        return {
5✔
249
            "notAfter": shift.strftime("%b %d %H:%M:%S GMT"),
250
            "subjectAltName": (
251
                ("DNS", "*.%s" % self._host),
252
                ("DNS", self._host),
253
                ("DNS", "*"),
254
            ),
255
            "subject": (
256
                (("organizationName", "*.%s" % self._host),),
257
                (("organizationalUnitName", "Domain Control Validated"),),
258
                (("commonName", "*.%s" % self._host),),
259
            ),
260
        }
261

262
    def unwrap(self):
6✔
263
        return self
4✔
264

265
    def write(self, data):
6✔
266
        return self.send(encode_to_bytes(data))
6✔
267

268
    @staticmethod
6✔
269
    def fileno():
6✔
270
        if Mocket.r_fd is not None:
6✔
271
            return Mocket.r_fd
6✔
272
        Mocket.r_fd, Mocket.w_fd = os.pipe()
6✔
273
        return Mocket.r_fd
6✔
274

275
    def connect(self, address):
6✔
276
        self._address = self._host, self._port = address
6✔
277
        Mocket._address = address
6✔
278

279
    def makefile(self, mode="r", bufsize=-1):
6✔
280
        self._mode = mode
6✔
281
        self._bufsize = bufsize
6✔
282
        return self.fd
6✔
283

284
    def get_entry(self, data):
6✔
285
        return Mocket.get_entry(self._host, self._port, data)
6✔
286

287
    def sendall(self, data, entry=None, *args, **kwargs):
6✔
288
        if entry is None:
6✔
289
            entry = self.get_entry(data)
6✔
290

291
        if entry:
6✔
292
            consume_response = entry.collect(data)
6✔
293
            if consume_response is not False:
6✔
294
                response = entry.get_response()
6✔
295
            else:
296
                response = None
6✔
297
        else:
298
            response = self.true_sendall(data, *args, **kwargs)
6✔
299

300
        if response is not None:
6✔
301
            self.fd.seek(0)
6✔
302
            self.fd.write(response)
6✔
303
            self.fd.truncate()
6✔
304
            self.fd.seek(0)
6✔
305

306
    def read(self, buffersize):
6✔
307
        rv = self.fd.read(buffersize)
6✔
308
        if rv:
6✔
309
            self._sent_non_empty_bytes = True
6✔
310
        if self._did_handshake and not self._sent_non_empty_bytes:
6✔
UNCOV
311
            raise ssl.SSLWantReadError("The operation did not complete (read)")
1✔
312
        return rv
6✔
313

314
    def recv_into(self, buffer, buffersize=None, flags=None):
6✔
315
        if hasattr(buffer, "write"):
6✔
316
            return buffer.write(self.read(buffersize))
6✔
317
        # buffer is a memoryview
UNCOV
318
        data = self.read(buffersize)
1✔
UNCOV
319
        if data:
1✔
320
            buffer[: len(data)] = data
×
UNCOV
321
        return len(data)
1✔
322

323
    def recv(self, buffersize, flags=None):
6✔
324
        if Mocket.r_fd and Mocket.w_fd:
6✔
325
            return os.read(Mocket.r_fd, buffersize)
6✔
326
        data = self.read(buffersize)
6✔
327
        if data:
6✔
328
            return data
6✔
329
        # used by Redis mock
330
        exc = BlockingIOError()
6✔
331
        exc.errno = errno.EWOULDBLOCK
6✔
332
        exc.args = (0,)
6✔
333
        raise exc
6✔
334

335
    def true_sendall(self, data, *args, **kwargs):
        """Answer ``data`` from a recording or, failing that, from the real network.

        Steps, in order:

        1. Refuse to go on when strict mode is enabled.
        2. If a recording directory is configured, load the JSON session
           file named after the current namespace (missing or invalid files
           are treated as an empty session).
        3. Look up a recorded response under host / port / request hash,
           falling back to an MD5 hash when another hasher is in use.
        4. If nothing is recorded, send ``data`` through the true socket,
           read the reply and, when recording is on, dump the updated
           session back to the JSON file.

        :param data: raw request bytes.
        :param args: extra positional arguments forwarded to the true
            socket's ``sendall``.
        :param kwargs: extra keyword arguments forwarded to the true socket's
            ``sendall``.
        :return: the response bytes (recorded or freshly received).
        :raises StrictMocketException: when strict mode forbids real traffic.
        """
        if MocketMode().STRICT:
            raise StrictMocketException("Mocket tried to use the real `socket` module.")

        req = decode_from_bytes(data)
        # make request unique again
        req_signature = _hash_request(hasher, req)
        # port should be always a string
        port = text_type(self._port)

        # prepare responses dictionary
        responses = {}

        if Mocket.get_truesocket_recording_dir():
            path = os.path.join(
                Mocket.get_truesocket_recording_dir(), Mocket.get_namespace() + ".json"
            )
            # check if there's already a recorded session dumped to a JSON file
            try:
                with io.open(path) as f:
                    responses = json.load(f)
            # if not, create a new dictionary
            except (FileNotFoundError, JSONDecodeError):
                pass

        try:
            try:
                response_dict = responses[self._host][port][req_signature]
            except KeyError:
                if hasher is not hashlib.md5:
                    # Fallback for backwards compatibility
                    req_signature = _hash_request(hashlib.md5, req)
                    response_dict = responses[self._host][port][req_signature]
                else:
                    raise
        except KeyError:
            # preventing next KeyError exceptions
            responses.setdefault(self._host, {})
            responses[self._host].setdefault(port, {})
            responses[self._host][port].setdefault(req_signature, {})
            response_dict = responses[self._host][port][req_signature]

        # try to get the response from the dictionary
        try:
            encoded_response = hexload(response_dict["response"])
        # if not available, call the real sendall
        except KeyError:
            # from here on `port` is the connection port taken from
            # Mocket._address, no longer the string key used above
            host, port = Mocket._address
            host = true_gethostbyname(host)

            if isinstance(self.true_socket, true_socket) and self._secure_socket:
                self.true_socket = true_urllib3_ssl_wrap_socket(
                    self.true_socket,
                    **self.kwargs,
                )

            try:
                self.true_socket.connect((host, port))
            except (OSError, socket.error, ValueError):
                # already connected
                pass
            self.true_socket.sendall(data, *args, **kwargs)
            encoded_response = b""
            # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L13
            while True:
                if (
                    not select.select([self.true_socket], [], [], 0.1)[0]
                    and encoded_response
                ):
                    break
                recv = self.true_socket.recv(self._buflen)

                # An empty read means the peer closed the connection. Stop
                # even if nothing was received at all; otherwise a peer that
                # closes without replying keeps this loop spinning forever.
                if not recv:
                    break
                encoded_response += recv

            # dump the resulting dictionary to a JSON file
            if Mocket.get_truesocket_recording_dir():
                # update the dictionary with request and response lines
                response_dict["request"] = req
                response_dict["response"] = hexdump(encoded_response)

                with io.open(path, mode="w") as f:
                    f.write(
                        decode_from_bytes(
                            json.dumps(responses, indent=4, sort_keys=True)
                        )
                    )

        # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO
        return encoded_response
6✔
426

427
    def send(self, data, *args, **kwargs):  # pragma: no cover
428
        entry = self.get_entry(data)
429
        if not entry or (entry and self._entry != entry):
430
            self.sendall(data, entry=entry, *args, **kwargs)
431
        else:
432
            req = Mocket.last_request()
433
            if hasattr(req, "add_data"):
434
                req.add_data(data)
435
        self._entry = entry
436
        return len(data)
437

438
    def close(self):
6✔
439
        if self.true_socket and not self.true_socket._closed:
6✔
440
            self.true_socket.close()
6✔
441
        self._fd = None
6✔
442

443
    def __getattr__(self, name):
6✔
444
        """Do nothing catchall function, for methods like close() and shutdown()"""
445

446
        def do_nothing(*args, **kwargs):
6✔
447
            pass
6✔
448

449
        return do_nothing
6✔
450

451

452
class Mocket:
6✔
453
    _address = (None, None)
6✔
454
    _entries = collections.defaultdict(list)
6✔
455
    _requests = []
6✔
456
    _namespace = text_type(id(_entries))
6✔
457
    _truesocket_recording_dir = None
6✔
458
    r_fd = None
6✔
459
    w_fd = None
6✔
460

461
    @classmethod
6✔
462
    def register(cls, *entries):
6✔
463
        for entry in entries:
6✔
464
            cls._entries[entry.location].append(entry)
6✔
465

466
    @classmethod
6✔
467
    def get_entry(cls, host, port, data):
6✔
468
        host = host or Mocket._address[0]
6✔
469
        port = port or Mocket._address[1]
6✔
470
        entries = cls._entries.get((host, port), [])
6✔
471
        for entry in entries:
6✔
472
            if entry.can_handle(data):
6✔
473
                return entry
6✔
474

475
    @classmethod
6✔
476
    def collect(cls, data):
6✔
477
        cls.request_list().append(data)
6✔
478

479
    @classmethod
6✔
480
    def reset(cls):
6✔
481
        if cls.r_fd is not None:
6✔
482
            os.close(cls.r_fd)
6✔
483
            cls.r_fd = None
6✔
484
        if cls.w_fd is not None:
6✔
485
            os.close(cls.w_fd)
6✔
486
            cls.w_fd = None
6✔
487
        cls._entries = collections.defaultdict(list)
6✔
488
        cls._requests = []
6✔
489

490
    @classmethod
6✔
491
    def last_request(cls):
6✔
492
        if cls.has_requests():
6✔
493
            return cls.request_list()[-1]
6✔
494

495
    @classmethod
6✔
496
    def request_list(cls):
6✔
497
        return cls._requests
6✔
498

499
    @classmethod
6✔
500
    def remove_last_request(cls):
6✔
501
        if cls.has_requests():
6✔
502
            del cls._requests[-1]
6✔
503

504
    @classmethod
6✔
505
    def has_requests(cls):
6✔
506
        return bool(cls.request_list())
6✔
507

508
    @staticmethod
6✔
509
    def enable(namespace=None, truesocket_recording_dir=None):
6✔
510
        Mocket._namespace = namespace
6✔
511
        Mocket._truesocket_recording_dir = truesocket_recording_dir
6✔
512

513
        if truesocket_recording_dir:
6✔
514
            # JSON dumps will be saved here
515
            if not os.path.isdir(truesocket_recording_dir):
6✔
516
                raise AssertionError
×
517

518
        socket.socket = socket.__dict__["socket"] = MocketSocket
6✔
519
        socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket
6✔
520
        socket.SocketType = socket.__dict__["SocketType"] = MocketSocket
6✔
521
        socket.create_connection = socket.__dict__[
6✔
522
            "create_connection"
523
        ] = create_connection
524
        socket.gethostname = socket.__dict__["gethostname"] = lambda: "localhost"
6✔
525
        socket.gethostbyname = socket.__dict__[
6✔
526
            "gethostbyname"
527
        ] = lambda host: "127.0.0.1"
528
        socket.getaddrinfo = socket.__dict__[
6✔
529
            "getaddrinfo"
530
        ] = lambda host, port, family=None, socktype=None, proto=None, flags=None: [
531
            (2, 1, 6, "", (host, port))
532
        ]
533
        socket.socketpair = socket.__dict__["socketpair"] = socketpair
6✔
534
        ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket
6✔
535
        ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext
6✔
536
        socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: byte_type(
6✔
537
            "\x7f\x00\x00\x01", "utf-8"
538
        )
539
        urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__[
6✔
540
            "wrap_socket"
541
        ] = FakeSSLContext.wrap_socket
542
        urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
6✔
543
            "ssl_wrap_socket"
544
        ] = FakeSSLContext.wrap_socket
545
        urllib3.util.ssl_wrap_socket = urllib3.util.__dict__[
6✔
546
            "ssl_wrap_socket"
547
        ] = FakeSSLContext.wrap_socket
548
        urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
6✔
549
            "ssl_wrap_socket"
550
        ] = FakeSSLContext.wrap_socket
551
        urllib3.connection.match_hostname = urllib3.connection.__dict__[
6✔
552
            "match_hostname"
553
        ] = lambda *args: None
554
        if pyopenssl_override:  # pragma: no cover
555
            # Take out the pyopenssl version - use the default implementation
556
            extract_from_urllib3()
557
        if aiohttp_make_ssl_context_cache_clear:  # pragma: no cover
558
            aiohttp_make_ssl_context_cache_clear()
559

560
    @staticmethod
6✔
561
    def disable():
6✔
562
        socket.socket = socket.__dict__["socket"] = true_socket
6✔
563
        socket._socketobject = socket.__dict__["_socketobject"] = true_socket
6✔
564
        socket.SocketType = socket.__dict__["SocketType"] = true_socket
6✔
565
        socket.create_connection = socket.__dict__[
6✔
566
            "create_connection"
567
        ] = true_create_connection
568
        socket.gethostname = socket.__dict__["gethostname"] = true_gethostname
6✔
569
        socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname
6✔
570
        socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo
6✔
571
        socket.socketpair = socket.__dict__["socketpair"] = true_socketpair
6✔
572
        if true_ssl_wrap_socket:
6✔
573
            ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket
5✔
574
        ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context
6✔
575
        socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton
6✔
576
        urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__[
6✔
577
            "wrap_socket"
578
        ] = true_urllib3_wrap_socket
579
        urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
6✔
580
            "ssl_wrap_socket"
581
        ] = true_urllib3_ssl_wrap_socket
582
        urllib3.util.ssl_wrap_socket = urllib3.util.__dict__[
6✔
583
            "ssl_wrap_socket"
584
        ] = true_urllib3_ssl_wrap_socket
585
        urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
6✔
586
            "ssl_wrap_socket"
587
        ] = true_urllib3_ssl_wrap_socket
588
        urllib3.connection.match_hostname = urllib3.connection.__dict__[
6✔
589
            "match_hostname"
590
        ] = true_urllib3_match_hostname
591
        Mocket.reset()
6✔
592
        if pyopenssl_override:  # pragma: no cover
593
            # Put the pyopenssl version back in place
594
            inject_into_urllib3()
595
        if aiohttp_make_ssl_context_cache_clear:  # pragma: no cover
596
            aiohttp_make_ssl_context_cache_clear()
597

598
    @classmethod
6✔
599
    def get_namespace(cls):
6✔
600
        return cls._namespace
6✔
601

602
    @classmethod
6✔
603
    def get_truesocket_recording_dir(cls):
6✔
604
        return cls._truesocket_recording_dir
6✔
605

606
    @classmethod
6✔
607
    def assert_fail_if_entries_not_served(cls):
6✔
608
        """Mocket checks that all entries have been served at least once."""
609
        if not all(entry._served for entry in itertools.chain(*cls._entries.values())):
6✔
610
            raise AssertionError("Some Mocket entries have not been served")
6✔
611

612

613
class MocketEntry:
6✔
614
    class Response(byte_type):
6✔
615
        @property
6✔
616
        def data(self):
6✔
617
            return self
6✔
618

619
    response_index = 0
6✔
620
    request_cls = byte_type
6✔
621
    response_cls = Response
6✔
622
    responses = None
6✔
623
    _served = None
6✔
624

625
    def __init__(self, location, responses):
6✔
626
        self._served = False
6✔
627
        self.location = location
6✔
628

629
        if not isinstance(responses, collections_abc.Iterable) or isinstance(
6✔
630
            responses, basestring
631
        ):
632
            responses = [responses]
6✔
633

634
        if not responses:
6✔
635
            self.responses = [self.response_cls(encode_to_bytes(""))]
6✔
636
        else:
637
            self.responses = []
6✔
638
            for r in responses:
6✔
639
                if not isinstance(r, BaseException) and not getattr(r, "data", False):
6✔
640
                    if isinstance(r, text_type):
6✔
641
                        r = encode_to_bytes(r)
6✔
642
                    r = self.response_cls(r)
6✔
643
                self.responses.append(r)
6✔
644

645
    @staticmethod
6✔
646
    def can_handle(data):
6✔
647
        return True
6✔
648

649
    def collect(self, data):
6✔
650
        req = self.request_cls(data)
6✔
651
        Mocket.collect(req)
6✔
652

653
    def get_response(self):
6✔
654
        response = self.responses[self.response_index]
6✔
655
        if self.response_index < len(self.responses) - 1:
6✔
656
            self.response_index += 1
6✔
657

658
        self._served = True
6✔
659

660
        if isinstance(response, BaseException):
6✔
661
            raise response
6✔
662

663
        return response.data
6✔
664

665

666
class Mocketizer:
6✔
667
    def __init__(
6✔
668
        self,
669
        instance=None,
670
        namespace=None,
671
        truesocket_recording_dir=None,
672
        strict_mode=False,
673
    ):
674
        self.instance = instance
6✔
675
        self.truesocket_recording_dir = truesocket_recording_dir
6✔
676
        self.namespace = namespace or text_type(id(self))
6✔
677
        MocketMode().STRICT = strict_mode
6✔
678

679
    def enter(self):
6✔
680
        Mocket.enable(
6✔
681
            namespace=self.namespace,
682
            truesocket_recording_dir=self.truesocket_recording_dir,
683
        )
684
        if self.instance:
6✔
685
            self.check_and_call("mocketize_setup")
6✔
686

687
    def __enter__(self):
6✔
688
        self.enter()
6✔
689
        return self
6✔
690

691
    def exit(self):
6✔
692
        if self.instance:
6✔
693
            self.check_and_call("mocketize_teardown")
6✔
694
        Mocket.disable()
6✔
695

696
    def __exit__(self, type, value, tb):
6✔
697
        self.exit()
6✔
698

699
    async def __aenter__(self, *args, **kwargs):
6✔
700
        self.enter()
6✔
701
        return self
6✔
702

703
    async def __aexit__(self, *args, **kwargs):
6✔
704
        self.exit()
6✔
705

706
    def check_and_call(self, method_name):
6✔
707
        method = getattr(self.instance, method_name, None)
6✔
708
        if callable(method):
6✔
709
            method()
6✔
710

711
    @staticmethod
6✔
712
    def factory(test, truesocket_recording_dir, strict_mode, args):
6✔
713
        instance = args[0] if args else None
6✔
714
        namespace = None
6✔
715
        if truesocket_recording_dir:
6✔
716
            namespace = ".".join(
6✔
717
                (
718
                    instance.__class__.__module__,
719
                    instance.__class__.__name__,
720
                    test.__name__,
721
                )
722
            )
723

724
        return Mocketizer(
6✔
725
            instance,
726
            namespace=namespace,
727
            truesocket_recording_dir=truesocket_recording_dir,
728
            strict_mode=strict_mode,
729
        )
730

731

732
def wrapper(test, truesocket_recording_dir=None, strict_mode=False, *args, **kwargs):
6✔
733
    with Mocketizer.factory(test, truesocket_recording_dir, strict_mode, args):
6✔
734
        return test(*args, **kwargs)
6✔
735

736

737
mocketize = get_mocketize(wrapper_=wrapper)
6✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc