• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mindflayer / python-mocket / 9059293298

13 May 2024 07:58AM UTC coverage: 97.608% (-1.4%) from 99.046%
9059293298

Pull #226

github

web-flow
Merge 345d04a67 into d154be2ec
Pull Request #226: Big response for async clients

21 of 21 new or added lines in 2 files covered. (100.0%)

8 existing lines in 1 file now uncovered.

816 of 836 relevant lines covered (97.61%)

0.98 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.5
/mocket/mocket.py
1
import collections
1✔
2
import collections.abc as collections_abc
1✔
3
import contextlib
1✔
4
import errno
1✔
5
import hashlib
1✔
6
import itertools
1✔
7
import json
1✔
8
import os
1✔
9
import select
1✔
10
import socket
1✔
11
import ssl
1✔
12
from datetime import datetime, timedelta
1✔
13
from json.decoder import JSONDecodeError
1✔
14

15
import urllib3
1✔
16
from urllib3.connection import match_hostname as urllib3_match_hostname
1✔
17
from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket
1✔
18

19
try:
1✔
20
    from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
1✔
21
except ImportError:
22
    urllib3_wrap_socket = None
23

24

25
from .compat import basestring, byte_type, decode_from_bytes, encode_to_bytes, text_type
1✔
26
from .utils import (
1✔
27
    SSL_PROTOCOL,
28
    MocketMode,
29
    MocketSocketCore,
30
    get_mocketize,
31
    hexdump,
32
    hexload,
33
)
34

35
xxh32 = None
1✔
36
try:
1✔
37
    from xxhash import xxh32
1✔
38
except ImportError:  # pragma: no cover
39
    with contextlib.suppress(ImportError):
40
        from xxhash_cffi import xxh32
41
hasher = xxh32 or hashlib.md5
1✔
42

43
try:  # pragma: no cover
44
    from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3
45

46
    pyopenssl_override = True
47
except ImportError:
48
    pyopenssl_override = False
49

50
try:  # pragma: no cover
51
    from aiohttp import TCPConnector
52

53
    aiohttp_make_ssl_context_cache_clear = TCPConnector._make_ssl_context.cache_clear
54
except (ImportError, AttributeError):
1✔
55
    aiohttp_make_ssl_context_cache_clear = None
1✔
56

57

58
true_socket = socket.socket
1✔
59
true_create_connection = socket.create_connection
1✔
60
true_gethostbyname = socket.gethostbyname
1✔
61
true_gethostname = socket.gethostname
1✔
62
true_getaddrinfo = socket.getaddrinfo
1✔
63
true_socketpair = socket.socketpair
1✔
64
true_ssl_wrap_socket = getattr(
1✔
65
    ssl, "wrap_socket", None
66
)  # in Py3.12 it's only under SSLContext
67
true_ssl_socket = ssl.SSLSocket
1✔
68
true_ssl_context = ssl.SSLContext
1✔
69
true_inet_pton = socket.inet_pton
1✔
70
true_urllib3_wrap_socket = urllib3_wrap_socket
1✔
71
true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket
1✔
72
true_urllib3_match_hostname = urllib3_match_hostname
1✔
73

74

75
class SuperFakeSSLContext:
    """Base class providing write-ignoring SSL context attributes (for Python 3.6)."""

    class FakeSetter(int):
        """Integer whose descriptor ``__set__`` discards whatever is assigned."""

        def __set__(self, *args):
            # Assignments through an instance are swallowed on purpose, so the
            # class-level value below is what readers keep seeing.
            return None

    # Declared as class attributes so that ``ctx.options = ...`` and friends
    # go through ``FakeSetter.__set__`` instead of the instance dictionary.
    verify_mode = FakeSetter(ssl.CERT_NONE)
    options = FakeSetter()
    minimum_version = FakeSetter()
1✔
85

86

87
class FakeSSLContext(SuperFakeSSLContext):
    """Replacement for ``ssl.SSLContext`` used while sockets are mocked.

    Certificate/cipher configuration calls listed in ``DUMMY_METHODS`` are
    accepted and ignored, hostname checking is always reported as disabled,
    and unknown attributes are looked up on the wrapped socket, if any.
    """

    # Names installed on every instance as do-nothing callables.
    DUMMY_METHODS = (
        "load_default_certs",
        "load_verify_locations",
        "set_alpn_protocols",
        "set_ciphers",
        "set_default_verify_paths",
    )
    sock = None
    post_handshake_auth = None
    _check_hostname = False

    @property
    def check_hostname(self):
        """Current hostname-checking flag."""
        return self._check_hostname

    @check_hostname.setter
    def check_hostname(self, *args):
        # Whatever value is assigned, the flag is stored as False.
        self._check_hostname = False

    def __init__(self, sock=None, server_hostname=None, _context=None, *args, **kwargs):
        """Build the fake context.

        When *sock* is a ``MocketSocket`` it is remembered, tagged with
        *server_hostname*, and its real socket is wrapped in a real SSL socket.
        When *sock* is an integer it is used as the argument for a real SSL
        context stored on ``self.context``.
        """
        self._set_dummy_methods()

        if isinstance(sock, MocketSocket):
            self.sock = sock
            sock._host = server_hostname
            real_context = true_ssl_context(protocol=SSL_PROTOCOL)
            sock.true_socket = true_ssl_socket(
                sock=sock.true_socket,
                server_hostname=server_hostname,
                _context=real_context,
            )
            return

        if isinstance(sock, int) and true_ssl_context:
            self.context = true_ssl_context(sock)

    def _set_dummy_methods(self):
        """Attach one shared no-op callable under every ``DUMMY_METHODS`` name."""

        def dummy_method(*args, **kwargs):
            return None

        for method_name in self.DUMMY_METHODS:
            setattr(self, method_name, dummy_method)

    @staticmethod
    def wrap_socket(sock=sock, *args, **kwargs):
        """Mark *sock* as secure, remember the wrapping kwargs on it, and return it."""
        sock.kwargs = kwargs
        sock._secure_socket = True
        return sock

    @staticmethod
    def wrap_bio(incoming, outcoming, *args, **kwargs):
        """Return a fresh ``MocketSocket`` bound to ``kwargs["server_hostname"]``."""
        fake_ssl_object = MocketSocket()
        fake_ssl_object._host = kwargs["server_hostname"]
        return fake_ssl_object

    def __getattr__(self, name):
        """Delegate unknown attributes to the wrapped socket; yield None without one."""
        wrapped = self.sock
        if wrapped is None:
            return None
        return getattr(wrapped, name)
×
143

144

145
def create_connection(address, timeout=None, source_address=None, *, all_errors=False):
    """Create a TCP socket through ``socket.socket`` and connect it to *address*.

    This is the replacement installed for ``socket.create_connection``.

    Args:
        address: Destination passed unchanged to ``connect``.
        timeout: Timeout to apply before connecting. ``None`` leaves the
            socket's timeout untouched; any other value, including ``0``,
            is forwarded to ``settimeout``.
        source_address: Accepted for signature compatibility; not used.
        all_errors: Accepted for compatibility with the Python 3.11+
            signature of ``socket.create_connection``; not used.

    Returns:
        The connected socket object.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
    # Compare against None rather than relying on truthiness, otherwise an
    # explicit zero timeout (non-blocking) would be silently dropped.
    if timeout is not None:
        s.settimeout(timeout)
    s.connect(address)
    return s
1✔
151

152

153
def socketpair(*args, **kwargs):
    """Return a genuine socket pair created by the low-level ``_socket`` module.

    The asyncio loop needs a real pair to support calls made by fastapi and
    similar services, so this bypasses the patched ``socket`` module.
    """
    from _socket import socketpair as real_socketpair

    pair = real_socketpair(*args, **kwargs)
    return pair
1✔
158

159

160
def _hash_request(h, req):
    """Return the hex digest of *req* computed with hash constructor *h*.

    The request is split on CRLF and its lines are sorted before being joined
    and hashed, so the digest does not depend on the order of the lines.
    """
    lines = req.split("\r\n")
    lines.sort()
    canonical = "".join(lines)
    return h(encode_to_bytes(canonical)).hexdigest()
1✔
162

163

164
class MocketSocket:
1✔
165
    timeout = None
1✔
166
    _fd = None
1✔
167
    family = None
1✔
168
    type = None
1✔
169
    proto = None
1✔
170
    _host = None
1✔
171
    _port = None
1✔
172
    _address = None
1✔
173
    cipher = lambda s: ("ADH", "AES256", "SHA")
1✔
174
    compression = lambda s: ssl.OP_NO_COMPRESSION
1✔
175
    _mode = None
1✔
176
    _bufsize = None
1✔
177
    _secure_socket = False
1✔
178
    _did_handshake = False
1✔
179
    _sent_non_empty_bytes = False
1✔
180
    read_fd = None
1✔
181
    write_fd = None
1✔
182

183
    def __init__(
1✔
184
        self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs
185
    ):
186
        self.true_socket = true_socket(family, type, proto)
1✔
187
        self._buflen = 65536
1✔
188
        self._entry = None
1✔
189
        self.family = int(family)
1✔
190
        self.type = int(type)
1✔
191
        self.proto = int(proto)
1✔
192
        self._truesocket_recording_dir = None
1✔
193
        self.kwargs = kwargs
1✔
194

195
    def __str__(self):
196
        return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})"
197

198
    def __enter__(self):
1✔
199
        return self
1✔
200

201
    def __exit__(self, exc_type, exc_val, exc_tb):
1✔
202
        self.close()
1✔
203

204
    @property
1✔
205
    def fd(self):
1✔
206
        if self._fd is None:
1✔
207
            self._fd = MocketSocketCore(w_fd=self.write_fd)
1✔
208
        return self._fd
1✔
209

210
    def gettimeout(self):
1✔
211
        return self.timeout
1✔
212

213
    def setsockopt(self, family, type, proto):
1✔
214
        self.family = family
1✔
215
        self.type = type
1✔
216
        self.proto = proto
1✔
217

218
        if self.true_socket:
1✔
219
            self.true_socket.setsockopt(family, type, proto)
1✔
220

221
    def settimeout(self, timeout):
1✔
222
        self.timeout = timeout
1✔
223

224
    @staticmethod
1✔
225
    def getsockopt(level, optname, buflen=None):
1✔
226
        return socket.SOCK_STREAM
×
227

228
    def do_handshake(self):
1✔
229
        self._did_handshake = True
1✔
230

231
    def getpeername(self):
1✔
232
        return self._address
1✔
233

234
    def setblocking(self, block):
1✔
235
        self.settimeout(None) if block else self.settimeout(0.0)
1✔
236

237
    def getblocking(self):
1✔
238
        return self.gettimeout() is None
1✔
239

240
    def getsockname(self):
1✔
241
        return socket.gethostbyname(self._address[0]), self._address[1]
1✔
242

243
    def getpeercert(self, *args, **kwargs):
1✔
UNCOV
244
        if not (self._host and self._port):
×
UNCOV
245
            self._address = self._host, self._port = Mocket._address
×
246

UNCOV
247
        now = datetime.now()
×
UNCOV
248
        shift = now + timedelta(days=30 * 12)
×
UNCOV
249
        return {
×
250
            "notAfter": shift.strftime("%b %d %H:%M:%S GMT"),
251
            "subjectAltName": (
252
                ("DNS", f"*.{self._host}"),
253
                ("DNS", self._host),
254
                ("DNS", "*"),
255
            ),
256
            "subject": (
257
                (("organizationName", f"*.{self._host}"),),
258
                (("organizationalUnitName", "Domain Control Validated"),),
259
                (("commonName", f"*.{self._host}"),),
260
            ),
261
        }
262

263
    def unwrap(self):
1✔
UNCOV
264
        return self
×
265

266
    def write(self, data):
1✔
267
        return self.send(encode_to_bytes(data))
1✔
268

269
    def fileno(self):
1✔
270
        if self.read_fd:
1✔
271
            return self.read_fd
1✔
272
        self.read_fd, self.write_fd = os.pipe()
1✔
273
        return self.read_fd
1✔
274

275
    def connect(self, address):
1✔
276
        self._address = self._host, self._port = address
1✔
277
        Mocket._address = address
1✔
278

279
    def makefile(self, mode="r", bufsize=-1):
1✔
280
        self._mode = mode
1✔
281
        self._bufsize = bufsize
1✔
282
        return self.fd
1✔
283

284
    def get_entry(self, data):
1✔
285
        return Mocket.get_entry(self._host, self._port, data)
1✔
286

287
    def sendall(self, data, entry=None, *args, **kwargs):
1✔
288
        if entry is None:
1✔
289
            entry = self.get_entry(data)
1✔
290

291
        if entry:
1✔
292
            consume_response = entry.collect(data)
1✔
293
            response = entry.get_response() if consume_response is not False else None
1✔
294
        else:
295
            response = self.true_sendall(data, *args, **kwargs)
1✔
296

297
        if response is not None:
1✔
298
            self.fd.seek(0)
1✔
299
            self.fd.write(response)
1✔
300
            self.fd.truncate()
1✔
301
            self.fd.seek(0)
1✔
302

303
    def read(self, buffersize):
1✔
304
        rv = self.fd.read(buffersize)
1✔
305
        if rv:
1✔
306
            self._sent_non_empty_bytes = True
1✔
307
        if self._did_handshake and not self._sent_non_empty_bytes:
1✔
308
            raise ssl.SSLWantReadError("The operation did not complete (read)")
×
309
        return rv
1✔
310

311
    def recv_into(self, buffer, buffersize=None, flags=None):
1✔
312
        if hasattr(buffer, "write"):
1✔
313
            return buffer.write(self.read(buffersize))
1✔
314
        # buffer is a memoryview
315
        data = self.read(buffersize)
×
316
        if data:
×
317
            buffer[: len(data)] = data
×
318
        return len(data)
×
319

320
    def recv(self, buffersize, flags=None):
1✔
321
        if self.read_fd:
1✔
322
            return os.read(self.read_fd, buffersize)
1✔
323
        data = self.read(buffersize)
1✔
324
        if data:
1✔
325
            return data
1✔
326
        # used by Redis mock
327
        exc = BlockingIOError()
1✔
328
        exc.errno = errno.EWOULDBLOCK
1✔
329
        exc.args = (0,)
1✔
330
        raise exc
1✔
331

332
    def true_sendall(self, data, *args, **kwargs):
        """Send *data* over the real socket, or replay a recorded response.

        When a recording directory is configured, previously recorded
        responses are loaded from ``<namespace>.json`` in that directory and
        replayed when the request signature matches; new responses obtained
        from the real connection are written back to the same file.

        Args:
            data: Raw request bytes.
            *args, **kwargs: Forwarded to the real socket's ``sendall``.

        Returns:
            The raw response bytes (possibly empty).
        """
        if not MocketMode().is_allowed((self._host, self._port)):
            MocketMode.raise_not_allowed()

        req = decode_from_bytes(data)
        # make request unique again
        req_signature = _hash_request(hasher, req)
        # port should be always a string
        port = text_type(self._port)

        # prepare responses dictionary
        responses = {}

        if Mocket.get_truesocket_recording_dir():
            path = os.path.join(
                Mocket.get_truesocket_recording_dir(), Mocket.get_namespace() + ".json"
            )
            # check if there's already a recorded session dumped to a JSON file
            try:
                with open(path) as f:
                    responses = json.load(f)
            # if not, create a new dictionary
            except (FileNotFoundError, JSONDecodeError):
                pass

        try:
            try:
                response_dict = responses[self._host][port][req_signature]
            except KeyError:
                if hasher is not hashlib.md5:
                    # Fallback for backwards compatibility
                    req_signature = _hash_request(hashlib.md5, req)
                    response_dict = responses[self._host][port][req_signature]
                else:
                    raise
        except KeyError:
            # preventing next KeyError exceptions
            # NOTE: when the MD5 fallback above ran, `req_signature` now holds
            # the MD5 signature, so the new record is keyed by it.
            responses.setdefault(self._host, {})
            responses[self._host].setdefault(port, {})
            responses[self._host][port].setdefault(req_signature, {})
            response_dict = responses[self._host][port][req_signature]

        # try to get the response from the dictionary
        try:
            encoded_response = hexload(response_dict["response"])
        # if not available, call the real sendall
        except KeyError:
            host, port = self._host, self._port
            host = true_gethostbyname(host)

            if isinstance(self.true_socket, true_socket) and self._secure_socket:
                self.true_socket = true_urllib3_ssl_wrap_socket(
                    self.true_socket,
                    **self.kwargs,
                )

            with contextlib.suppress(OSError, ValueError):
                # already connected
                self.true_socket.connect((host, port))
            self.true_socket.sendall(data, *args, **kwargs)
            encoded_response = b""
            # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L13
            while True:
                # stop once something was received and nothing more is
                # readable within the polling interval
                if (
                    not select.select([self.true_socket], [], [], 0.1)[0]
                    and encoded_response
                ):
                    break
                chunk = self.true_socket.recv(self._buflen)

                # An empty read means the peer closed the connection. Stop
                # even if nothing was received, otherwise this loop would
                # spin forever on a connection closed without a response.
                if not chunk:
                    break
                encoded_response += chunk

            # dump the resulting dictionary to a JSON file
            if Mocket.get_truesocket_recording_dir():
                # update the dictionary with request and response lines
                response_dict["request"] = req
                response_dict["response"] = hexdump(encoded_response)

                with open(path, mode="w") as f:
                    f.write(
                        decode_from_bytes(
                            json.dumps(responses, indent=4, sort_keys=True)
                        )
                    )

        # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO
        return encoded_response
1✔
421

422
    def send(self, data, *args, **kwargs):  # pragma: no cover
423
        entry = self.get_entry(data)
424
        kwargs["entry"] = entry
425
        if not entry or (entry and self._entry != entry):
426
            self.sendall(data, *args, **kwargs)
427
        else:
428
            req = Mocket.last_request()
429
            if hasattr(req, "add_data"):
430
                req.add_data(data)
431
        self._entry = entry
432
        return len(data)
433

434
    def close(self):
1✔
435
        if self.true_socket and not self.true_socket._closed:
1✔
436
            self.true_socket.close()
1✔
437
        self._fd = None
1✔
438

439
    def __getattr__(self, name):
1✔
440
        """Do nothing catchall function, for methods like shutdown()"""
441

442
        def do_nothing(*args, **kwargs):
1✔
443
            pass
1✔
444

445
        return do_nothing
1✔
446

447

448
class Mocket:
1✔
449
    _address = (None, None)
1✔
450
    _entries = collections.defaultdict(list)
1✔
451
    _requests = []
1✔
452
    _namespace = text_type(id(_entries))
1✔
453
    _truesocket_recording_dir = None
1✔
454

455
    @classmethod
1✔
456
    def register(cls, *entries):
1✔
457
        for entry in entries:
1✔
458
            cls._entries[entry.location].append(entry)
1✔
459

460
    @classmethod
    def get_entry(cls, host, port, data):
        """Return the first entry registered for ``(host, port)`` that can handle *data*.

        A missing host or port falls back to the corresponding item of
        ``Mocket._address``. ``None`` is returned when no entry matches.
        """
        location = (host or Mocket._address[0], port or Mocket._address[1])
        candidates = cls._entries.get(location, [])
        return next(
            (candidate for candidate in candidates if candidate.can_handle(data)),
            None,
        )
1✔
468

469
    @classmethod
1✔
470
    def collect(cls, data):
1✔
471
        cls.request_list().append(data)
1✔
472

473
    @classmethod
1✔
474
    def reset(cls):
1✔
475
        cls._entries = collections.defaultdict(list)
1✔
476
        cls._requests = []
1✔
477

478
    @classmethod
1✔
479
    def last_request(cls):
1✔
480
        if cls.has_requests():
1✔
481
            return cls.request_list()[-1]
1✔
482

483
    @classmethod
1✔
484
    def request_list(cls):
1✔
485
        return cls._requests
1✔
486

487
    @classmethod
1✔
488
    def remove_last_request(cls):
1✔
489
        if cls.has_requests():
1✔
490
            del cls._requests[-1]
1✔
491

492
    @classmethod
1✔
493
    def has_requests(cls):
1✔
494
        return bool(cls.request_list())
1✔
495

496
    @staticmethod
    def enable(namespace=None, truesocket_recording_dir=None):
        """Replace the socket/ssl/urllib3 entry points with their fake versions.

        Args:
            namespace: Stored as the current namespace (used for the name of
                the recording file).
            truesocket_recording_dir: Directory where JSON dumps of real
                traffic will be saved; must already exist when given.

        Raises:
            AssertionError: when *truesocket_recording_dir* is given but is
                not an existing directory. No state is changed in that case.
        """
        # Validate before touching any class state, so that a failed call
        # leaves the previous configuration untouched.
        if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir):
            # JSON dumps will be saved here
            raise AssertionError(
                "truesocket_recording_dir is not an existing directory: "
                f"{truesocket_recording_dir!r}"
            )

        Mocket._namespace = namespace
        Mocket._truesocket_recording_dir = truesocket_recording_dir

        socket.socket = socket.__dict__["socket"] = MocketSocket
        socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket
        socket.SocketType = socket.__dict__["SocketType"] = MocketSocket
        socket.create_connection = socket.__dict__["create_connection"] = (
            create_connection
        )
        # name resolution always answers with the local host
        socket.gethostname = socket.__dict__["gethostname"] = lambda: "localhost"
        socket.gethostbyname = socket.__dict__["gethostbyname"] = (
            lambda host: "127.0.0.1"
        )
        socket.getaddrinfo = socket.__dict__["getaddrinfo"] = (
            lambda host, port, family=None, socktype=None, proto=None, flags=None: [
                (2, 1, 6, "", (host, port))
            ]
        )
        socket.socketpair = socket.__dict__["socketpair"] = socketpair
        ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket
        ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext
        socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: byte_type(
            "\x7f\x00\x00\x01", "utf-8"
        )
        urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = (
            FakeSSLContext.wrap_socket
        )
        urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
            "ssl_wrap_socket"
        ] = FakeSSLContext.wrap_socket
        urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
            FakeSSLContext.wrap_socket
        )
        urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
            "ssl_wrap_socket"
        ] = FakeSSLContext.wrap_socket
        urllib3.connection.match_hostname = urllib3.connection.__dict__[
            "match_hostname"
        ] = lambda *args: None
        if pyopenssl_override:  # pragma: no cover
            # Take out the pyopenssl version - use the default implementation
            extract_from_urllib3()
        if aiohttp_make_ssl_context_cache_clear:  # pragma: no cover
            aiohttp_make_ssl_context_cache_clear()
546

547
    @staticmethod
    def disable():
        """Restore the real socket/ssl/urllib3 entry points and reset Mocket.

        Attributes that did not exist before mocking (``ssl.wrap_socket`` on
        Python 3.12+, ``urllib3.util.ssl_.wrap_socket`` on urllib3 versions
        that do not provide it) are removed instead of being left behind.
        """
        socket.socket = socket.__dict__["socket"] = true_socket
        socket._socketobject = socket.__dict__["_socketobject"] = true_socket
        socket.SocketType = socket.__dict__["SocketType"] = true_socket
        socket.create_connection = socket.__dict__["create_connection"] = (
            true_create_connection
        )
        socket.gethostname = socket.__dict__["gethostname"] = true_gethostname
        socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname
        socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo
        socket.socketpair = socket.__dict__["socketpair"] = true_socketpair
        if true_ssl_wrap_socket:
            ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket
        else:
            # There was no ssl.wrap_socket to begin with: drop the fake one
            # installed by enable() so it does not outlive the mocking.
            ssl.__dict__.pop("wrap_socket", None)
        ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context
        socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton
        if true_urllib3_wrap_socket:
            urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__[
                "wrap_socket"
            ] = true_urllib3_wrap_socket
        else:
            # Same reasoning: the import of the original failed, so the
            # attribute did not exist before enable() created it.
            urllib3.util.ssl_.__dict__.pop("wrap_socket", None)
        urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
            "ssl_wrap_socket"
        ] = true_urllib3_ssl_wrap_socket
        urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = (
            true_urllib3_ssl_wrap_socket
        )
        urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
            "ssl_wrap_socket"
        ] = true_urllib3_ssl_wrap_socket
        urllib3.connection.match_hostname = urllib3.connection.__dict__[
            "match_hostname"
        ] = true_urllib3_match_hostname
        Mocket.reset()
        if pyopenssl_override:  # pragma: no cover
            # Put the pyopenssl version back in place
            inject_into_urllib3()
        if aiohttp_make_ssl_context_cache_clear:  # pragma: no cover
            aiohttp_make_ssl_context_cache_clear()
584

585
    @classmethod
1✔
586
    def get_namespace(cls):
1✔
587
        return cls._namespace
1✔
588

589
    @classmethod
1✔
590
    def get_truesocket_recording_dir(cls):
1✔
591
        return cls._truesocket_recording_dir
1✔
592

593
    @classmethod
    def assert_fail_if_entries_not_served(cls):
        """Raise ``AssertionError`` unless every registered entry has been served."""
        registered = itertools.chain.from_iterable(cls._entries.values())
        if any(not entry._served for entry in registered):
            raise AssertionError("Some Mocket entries have not been served")
1✔
598

599

600
class MocketEntry:
    """A location together with the sequence of responses to serve for it."""

    class Response(byte_type):
        """Byte payload that exposes itself through a ``data`` property."""

        @property
        def data(self):
            return self

    response_index = 0
    request_cls = byte_type
    response_cls = Response
    responses = None
    _served = None

    def __init__(self, location, responses):
        """Store *location* and normalise *responses* into a list.

        A non-iterable value or a single string counts as one response.
        Exceptions and objects with a truthy ``data`` attribute are kept as
        they are; anything else is encoded if it is text and wrapped in
        ``response_cls``. An empty value yields one empty response.
        """
        self._served = False
        self.location = location

        is_sequence = isinstance(
            responses, collections_abc.Iterable
        ) and not isinstance(responses, basestring)
        if not is_sequence:
            responses = [responses]

        if not responses:
            self.responses = [self.response_cls(encode_to_bytes(""))]
            return

        self.responses = []
        for item in responses:
            ready = isinstance(item, BaseException) or getattr(item, "data", False)
            if not ready:
                payload = encode_to_bytes(item) if isinstance(item, text_type) else item
                item = self.response_cls(payload)
            self.responses.append(item)

    def __repr__(self):
        return f"{self.__class__.__name__}(location={self.location})"

    @staticmethod
    def can_handle(data):
        """Accept any data."""
        return True

    def collect(self, data):
        """Convert *data* with ``request_cls`` and record it on ``Mocket``."""
        Mocket.collect(self.request_cls(data))

    def get_response(self):
        """Return the current response's data, raising it if it is an exception.

        The index advances after each call until it reaches the last
        response, which is then served for every further call.
        """
        position = self.response_index
        current = self.responses[position]
        if position < len(self.responses) - 1:
            self.response_index = position + 1

        self._served = True

        if isinstance(current, BaseException):
            raise current

        return current.data
1✔
654

655

656
class Mocketizer:
    """Sync/async context manager that enables Mocket on entry and disables it on exit.

    When an *instance* is given, its optional ``mocketize_setup`` and
    ``mocketize_teardown`` methods are called right after enabling and right
    before disabling, respectively.
    """

    def __init__(
        self,
        instance=None,
        namespace=None,
        truesocket_recording_dir=None,
        strict_mode=False,
        strict_mode_allowed=None,
    ):
        """Store the configuration and apply the strict-mode settings.

        Raises:
            ValueError: when *strict_mode_allowed* is given without
                *strict_mode*.
        """
        self.instance = instance
        self.truesocket_recording_dir = truesocket_recording_dir
        self.namespace = namespace or text_type(id(self))
        MocketMode().STRICT = strict_mode
        if strict_mode:
            MocketMode().STRICT_ALLOWED = strict_mode_allowed or []
        elif strict_mode_allowed:
            raise ValueError(
                "Allowed locations are only accepted when STRICT mode is active."
            )

    def enter(self):
        """Enable Mocket, then run the instance's ``mocketize_setup`` hook."""
        Mocket.enable(
            namespace=self.namespace,
            truesocket_recording_dir=self.truesocket_recording_dir,
        )
        if self.instance:
            try:
                self.check_and_call("mocketize_setup")
            except BaseException:
                # __exit__ is not invoked when __enter__ fails, so undo the
                # patching here before letting the error propagate.
                Mocket.disable()
                raise

    def __enter__(self):
        self.enter()
        return self

    def exit(self):
        """Run the instance's ``mocketize_teardown`` hook, then disable Mocket."""
        try:
            if self.instance:
                self.check_and_call("mocketize_teardown")
        finally:
            # Always restore the real implementations, even when the
            # teardown hook raises.
            Mocket.disable()

    def __exit__(self, type, value, tb):
        self.exit()

    async def __aenter__(self, *args, **kwargs):
        self.enter()
        return self

    async def __aexit__(self, *args, **kwargs):
        self.exit()

    def check_and_call(self, method_name):
        """Call ``self.instance.<method_name>()`` if it exists and is callable."""
        method = getattr(self.instance, method_name, None)
        if callable(method):
            method()

    @staticmethod
    def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args):
        """Build a ``Mocketizer`` for the callable *test*.

        The first item of *args*, if any, is used as the instance. When a
        recording directory is given, the namespace is built from the
        instance's class module, class name and the test's name.
        """
        instance = args[0] if args else None
        namespace = None
        if truesocket_recording_dir:
            namespace = ".".join(
                (
                    instance.__class__.__module__,
                    instance.__class__.__name__,
                    test.__name__,
                )
            )

        return Mocketizer(
            instance,
            namespace=namespace,
            truesocket_recording_dir=truesocket_recording_dir,
            strict_mode=strict_mode,
            strict_mode_allowed=strict_mode_allowed,
        )
728

729

730
def wrapper(
    test,
    truesocket_recording_dir=None,
    strict_mode=False,
    strict_mode_allowed=None,
    *args,
    **kwargs,
):
    """Run *test* with ``*args``/``**kwargs`` inside a ``Mocketizer`` context.

    The context is built with ``Mocketizer.factory`` from the given options
    and the positional arguments; the test's return value is passed back.
    """
    mocketizer = Mocketizer.factory(
        test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args
    )
    with mocketizer:
        outcome = test(*args, **kwargs)
    return outcome
1✔
742

743

744
mocketize = get_mocketize(wrapper_=wrapper)
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc