• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mindflayer / python-mocket / 6602480756

22 Oct 2023 07:38AM UTC coverage: 98.015% (-0.7%) from 98.761%
6602480756

Pull #204

github-actions

Giorgio Salluzzo
Small refactor.
Pull Request #204: Support for Python 3.12

7 of 7 new or added lines in 2 files covered. (100.0%)

790 of 806 relevant lines covered (98.01%)

3.86 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.98
/mocket/mocket.py
1
import collections
4✔
2
import collections.abc as collections_abc
4✔
3
import errno
4✔
4
import hashlib
4✔
5
import io
4✔
6
import itertools
4✔
7
import json
4✔
8
import os
4✔
9
import select
4✔
10
import socket
4✔
11
import ssl
4✔
12
from datetime import datetime, timedelta
4✔
13
from json.decoder import JSONDecodeError
4✔
14

15
import urllib3
4✔
16
from urllib3.connection import match_hostname as urllib3_match_hostname
4✔
17
from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket
4✔
18

19
try:
4✔
20
    from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
4✔
21
except ImportError:
22
    urllib3_wrap_socket = None
23

24
from .compat import basestring, byte_type, decode_from_bytes, encode_to_bytes, text_type
4✔
25
from .exceptions import StrictMocketException
4✔
26
from .utils import (
4✔
27
    SSL_PROTOCOL,
28
    MocketMode,
29
    MocketSocketCore,
30
    get_mocketize,
31
    hexdump,
32
    hexload,
33
)
34

35
xxh32 = None
4✔
36
try:
4✔
37
    from xxhash import xxh32
4✔
38
except ImportError:  # pragma: no cover
39
    try:
40
        from xxhash_cffi import xxh32
41
    except ImportError:
42
        pass
43
hasher = xxh32 or hashlib.md5
4✔
44

45
try:  # pragma: no cover
46
    from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3
47

48
    pyopenssl_override = True
49
except ImportError:
50
    pyopenssl_override = False
51

52

53
# Originals of everything ``Mocket.enable()`` monkey-patches, captured at import
# time so that ``Mocket.disable()`` can put them back.
true_socket = socket.socket
true_create_connection = socket.create_connection
true_gethostbyname = socket.gethostbyname
true_gethostname = socket.gethostname
true_getaddrinfo = socket.getaddrinfo
true_socketpair = socket.socketpair
# ``ssl.wrap_socket`` was deprecated in Python 3.7 and removed in 3.12 (only
# ``SSLContext.wrap_socket`` remains); fall back to ``None`` so importing this
# module does not fail with AttributeError on newer interpreters.
true_ssl_wrap_socket = getattr(ssl, "wrap_socket", None)
true_ssl_socket = ssl.SSLSocket
true_ssl_context = ssl.SSLContext
true_inet_pton = socket.inet_pton
# May be ``None`` when the installed urllib3 has no ``util.ssl_.wrap_socket``.
true_urllib3_wrap_socket = urllib3_wrap_socket
true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket
true_urllib3_match_hostname = urllib3_match_hostname
66

67

68
class SuperFakeSSLContext:
    """Base of the fake SSL context (originally introduced for Python 3.6).

    ``minimum_version``, ``options`` and ``verify_mode`` read back as fixed
    ints; assigning to them on an instance is silently discarded.
    """

    class FakeSetter(int):
        """``int`` subclass acting as a data descriptor whose ``__set__`` is a no-op."""

        def __set__(self, *ignored):
            return None

    verify_mode = FakeSetter(ssl.CERT_NONE)
    options = FakeSetter()
    minimum_version = FakeSetter()
78

79

80
class FakeSSLContext(SuperFakeSSLContext):
    """Replacement for ``ssl.SSLContext`` installed by ``Mocket.enable()``.

    Configuration methods named in ``DUMMY_METHODS`` become no-ops, hostname
    checking cannot be switched on, and ``wrap_socket``/``wrap_bio`` hand back
    Mocket sockets instead of real TLS objects so traffic can be intercepted.
    """

    DUMMY_METHODS = (
        "load_default_certs",
        "load_verify_locations",
        "set_alpn_protocols",
        "set_ciphers",
    )
    sock = None
    post_handshake_auth = None
    _check_hostname = False

    @property
    def check_hostname(self):
        """Always ``False``."""
        return self._check_hostname

    @check_hostname.setter
    def check_hostname(self, *args):
        # Attempts to enable hostname checking are deliberately ignored.
        self._check_hostname = False

    def __init__(self, sock=None, server_hostname=None, _context=None, *args, **kwargs):
        """Build the fake context.

        :param sock: a ``MocketSocket`` to attach to this context, or an ``int``
            protocol constant (as in ``ssl.SSLContext(protocol)``), in which case
            a real context is created and stored on ``self.context``.
        :param server_hostname: host name recorded on an attached ``MocketSocket``.
        :param _context: accepted for signature compatibility; unused.
        """
        self._set_dummy_methods()

        if isinstance(sock, MocketSocket):
            self.sock = sock
            self.sock._host = server_hostname
            # ``ssl.SSLSocket`` has had no public constructor since Python 3.7
            # (calling it raises TypeError), so the real TLS socket must be
            # obtained through ``SSLContext.wrap_socket`` instead.
            self.sock.true_socket = true_ssl_context(
                protocol=SSL_PROTOCOL
            ).wrap_socket(
                sock=self.sock.true_socket,
                server_hostname=server_hostname,
            )
        elif isinstance(sock, int) and true_ssl_context:
            self.context = true_ssl_context(sock)

    def _set_dummy_methods(self):
        """Install one shared no-op callable for every name in ``DUMMY_METHODS``."""

        def dummy_method(*args, **kwargs):
            pass

        for m in self.DUMMY_METHODS:
            setattr(self, m, dummy_method)

    @staticmethod
    def wrap_socket(sock=None, *args, **kwargs):
        """Mark *sock* as TLS-wrapped and return the very same object.

        The keyword arguments are stored on ``sock.kwargs`` and
        ``sock._secure_socket`` is set, so ``MocketSocket.true_sendall`` can
        perform a real TLS wrap if traffic has to reach the network.
        """
        sock.kwargs = kwargs
        sock._secure_socket = True
        return sock

    @staticmethod
    def wrap_bio(incoming, outcoming, *args, **kwargs):
        """Return a fresh ``MocketSocket`` standing in for an ``SSLObject``.

        ``server_hostname`` is optional (as in ``ssl.SSLContext.wrap_bio``); when
        it is missing the socket's host stays ``None`` and entry lookups fall
        back to ``Mocket._address``.
        """
        ssl_obj = MocketSocket()
        ssl_obj._host = kwargs.get("server_hostname")
        return ssl_obj

    def __getattr__(self, name):
        """Delegate unknown attributes to the attached socket, if there is one.

        Without an attached socket unknown attributes resolve to ``None``
        instead of raising ``AttributeError``.
        """
        if self.sock is not None:
            return getattr(self.sock, name)
135

136

137
def create_connection(address, timeout=None, source_address=None):
    """Replacement for ``socket.create_connection`` installed by ``Mocket.enable()``.

    Builds a TCP socket through ``socket.socket`` (a ``MocketSocket`` while
    Mocket is enabled), applies *timeout*, optionally binds it to
    *source_address*, then connects it to *address*.

    :param address: ``(host, port)`` to connect to.
    :param timeout: seconds, ``None`` for blocking mode, or the stdlib
        ``socket._GLOBAL_DEFAULT_TIMEOUT`` sentinel meaning "keep the default".
    :param source_address: optional ``(host, port)`` to bind before connecting.
    :returns: the connected socket.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
    # Mirror the stdlib: every explicit value (including 0 and None) is applied,
    # while the sentinel object must never reach settimeout().
    if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
        s.settimeout(timeout)
    if source_address:
        s.bind(source_address)
    s.connect(address)
    return s
143

144

145
def socketpair(*args, **kwargs):
    """Hand out a genuine, connected socket pair even while Mocket is enabled.

    asyncio event loops (used by fastapi and similar services) rely on a
    working pair, so the call goes straight to the C-level ``_socket`` module
    and bypasses the patched ``socket`` module.
    """
    import _socket as _c_socket

    make_pair = _c_socket.socketpair
    return make_pair(*args, **kwargs)
150

151

152
def _hash_request(h, req):
    """Return a hex digest of *req* that does not depend on its line order.

    :param h: hash constructor, e.g. ``hashlib.md5`` or ``xxh32``.
    :param req: decoded request text; its ``"\\r\\n"``-separated lines are
        sorted and concatenated before hashing.
    """
    lines = req.split("\r\n")
    lines.sort()
    payload = encode_to_bytes("".join(lines))
    return h(payload).hexdigest()
154

155

156
class MocketSocket:
    """Fake ``socket.socket`` installed by ``Mocket.enable()``.

    Data sent through it is matched against registered Mocket entries; the
    matching entry's response is written to an in-memory buffer (``fd``) that
    ``recv``/``read``/``makefile`` read from. Unmatched data goes through
    ``true_sendall``, which replays a JSON recording or talks to the real
    network (optionally recording the exchange).
    """

    timeout = None
    _fd = None
    family = None
    type = None
    proto = None
    _host = None
    _port = None
    _address = None
    cipher = lambda s: ("ADH", "AES256", "SHA")
    compression = lambda s: ssl.OP_NO_COMPRESSION
    _mode = None
    _bufsize = None
    _secure_socket = False

    def __init__(
        self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs
    ):
        # A real socket is always created so unmatched traffic can fall
        # through to the network in ``true_sendall``.
        self.true_socket = true_socket(family, type, proto)
        self._buflen = 65536
        self._entry = None
        self.family = int(family)
        self.type = int(type)
        self.proto = int(proto)
        self._truesocket_recording_dir = None
        self.kwargs = kwargs

    def __str__(self):
        return "({})(family={} type={} protocol={})".format(
            self.__class__.__name__, self.family, self.type, self.proto
        )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    @property
    def fd(self):
        """Lazily created in-memory buffer holding the pending response."""
        if self._fd is None:
            self._fd = MocketSocketCore()
        return self._fd

    def gettimeout(self):
        return self.timeout

    def setsockopt(self, family, type, proto):
        # NOTE(review): these are really (level, optname, value), yet they
        # overwrite family/type/proto; kept unchanged for compatibility.
        self.family = family
        self.type = type
        self.proto = proto

        if self.true_socket:
            self.true_socket.setsockopt(family, type, proto)

    def settimeout(self, timeout):
        self.timeout = timeout

    @staticmethod
    def getsockopt(level, optname, buflen=None):
        return socket.SOCK_STREAM

    def do_handshake(self):
        pass

    def getpeername(self):
        return self._address

    def setblocking(self, block):
        self.settimeout(None) if block else self.settimeout(0.0)

    def getsockname(self):
        return socket.gethostbyname(self._address[0]), self._address[1]

    def getpeercert(self, *args, **kwargs):
        """Return a fake certificate for the current host, expiring in 360 days."""
        if not (self._host and self._port):
            self._address = self._host, self._port = Mocket._address

        now = datetime.now()
        shift = now + timedelta(days=30 * 12)
        return {
            "notAfter": shift.strftime("%b %d %H:%M:%S GMT"),
            "subjectAltName": (
                ("DNS", "*.%s" % self._host),
                ("DNS", self._host),
                ("DNS", "*"),
            ),
            "subject": (
                (("organizationName", "*.%s" % self._host),),
                (("organizationalUnitName", "Domain Control Validated"),),
                (("commonName", "*.%s" % self._host),),
            ),
        }

    def unwrap(self):
        return self

    def write(self, data):
        return self.send(encode_to_bytes(data))

    @staticmethod
    def fileno():
        # Each call opens a new pipe and stores both ends on ``Mocket``;
        # ``recv`` reads from it while both ends are set.
        Mocket.r_fd, Mocket.w_fd = os.pipe()
        return Mocket.r_fd

    def connect(self, address):
        self._address = self._host, self._port = address
        Mocket._address = address

    def makefile(self, mode="r", bufsize=-1):
        self._mode = mode
        self._bufsize = bufsize
        return self.fd

    def get_entry(self, data):
        return Mocket.get_entry(self._host, self._port, data)

    def sendall(self, data, entry=None, *args, **kwargs):
        """Collect *data* and load the matching response into ``fd``.

        Without a matching entry, the response comes from ``true_sendall``.
        """
        if entry is None:
            entry = self.get_entry(data)

        if entry:
            consume_response = entry.collect(data)
            if consume_response is not False:
                response = entry.get_response()
            else:
                response = None
        else:
            response = self.true_sendall(data, *args, **kwargs)

        if response is not None:
            self.fd.seek(0)
            self.fd.write(response)
            self.fd.truncate()
            self.fd.seek(0)

    def read(self, buffersize):
        return self.fd.read(buffersize)

    def recv_into(self, buffer, buffersize=None, flags=None):
        """Read pending response bytes into *buffer* and return how many.

        File-like buffers (anything with ``write``, e.g. ``io.BytesIO``) are
        written to as before. Writable buffers such as ``bytearray`` or
        ``memoryview`` (what ``socket.recv_into`` receives) are filled in
        place with at most ``buffersize`` bytes, never more than fit; a
        ``buffersize`` of ``None``/0 means the whole buffer.
        """
        if hasattr(buffer, "write"):
            return buffer.write(self.read(buffersize))
        view = memoryview(buffer).cast("B")
        limit = view.nbytes if not buffersize else min(buffersize, view.nbytes)
        data = self.read(limit)
        view[: len(data)] = data
        return len(data)

    def recv(self, buffersize, flags=None):
        if Mocket.r_fd and Mocket.w_fd:
            return os.read(Mocket.r_fd, buffersize)
        data = self.read(buffersize)
        if data:
            return data
        # used by Redis mock
        exc = BlockingIOError()
        exc.errno = errno.EWOULDBLOCK
        exc.args = (0,)
        raise exc

    def true_sendall(self, data, *args, **kwargs):
        """Answer *data* from a recording, or from the real network.

        Raises ``StrictMocketException`` in strict mode. With a recording
        directory configured, responses are looked up in (and new ones saved
        to) ``<dir>/<namespace>.json``, keyed by host, port and request hash.
        Returns the raw response bytes.
        """
        if MocketMode().STRICT:
            raise StrictMocketException("Mocket tried to use the real `socket` module.")

        req = decode_from_bytes(data)
        # make request unique again
        req_signature = _hash_request(hasher, req)
        # port should be always a string
        port = text_type(self._port)

        # prepare responses dictionary
        responses = {}

        if Mocket.get_truesocket_recording_dir():
            path = os.path.join(
                Mocket.get_truesocket_recording_dir(), Mocket.get_namespace() + ".json"
            )
            # check if there's already a recorded session dumped to a JSON file
            try:
                with io.open(path) as f:
                    responses = json.load(f)
            # if not, create a new dictionary
            except (FileNotFoundError, JSONDecodeError):
                pass

        try:
            try:
                response_dict = responses[self._host][port][req_signature]
            except KeyError:
                if hasher is not hashlib.md5:
                    # Fallback for backwards compatibility
                    req_signature = _hash_request(hashlib.md5, req)
                    response_dict = responses[self._host][port][req_signature]
                else:
                    raise
        except KeyError:
            # preventing next KeyError exceptions
            responses.setdefault(self._host, {})
            responses[self._host].setdefault(port, {})
            responses[self._host][port].setdefault(req_signature, {})
            response_dict = responses[self._host][port][req_signature]

        # try to get the response from the dictionary
        try:
            encoded_response = hexload(response_dict["response"])
        # if not available, call the real sendall
        except KeyError:
            host, port = Mocket._address
            host = true_gethostbyname(host)

            if isinstance(self.true_socket, true_socket) and self._secure_socket:
                self.true_socket = true_urllib3_ssl_wrap_socket(
                    self.true_socket,
                    **self.kwargs,
                )

            try:
                self.true_socket.connect((host, port))
            except (OSError, socket.error, ValueError):
                # already connected
                pass
            self.true_socket.sendall(data, *args, **kwargs)
            encoded_response = b""
            # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L13
            while True:
                if (
                    not select.select([self.true_socket], [], [], 0.1)[0]
                    and encoded_response
                ):
                    break
                recv = self.true_socket.recv(self._buflen)

                # An empty read means the peer closed the connection. Stop even
                # if nothing was received; otherwise a peer that closes without
                # answering would make this loop spin forever.
                if not recv:
                    break
                encoded_response += recv

            # dump the resulting dictionary to a JSON file
            if Mocket.get_truesocket_recording_dir():
                # update the dictionary with request and response lines
                response_dict["request"] = req
                response_dict["response"] = hexdump(encoded_response)

                with io.open(path, mode="w") as f:
                    f.write(
                        decode_from_bytes(
                            json.dumps(responses, indent=4, sort_keys=True)
                        )
                    )

        # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO
        return encoded_response

    def send(self, data, *args, **kwargs):  # pragma: no cover
        entry = self.get_entry(data)
        if not entry or (entry and self._entry != entry):
            self.sendall(data, entry=entry, *args, **kwargs)
        else:
            req = Mocket.last_request()
            if hasattr(req, "add_data"):
                req.add_data(data)
        self._entry = entry
        return len(data)

    def close(self):
        if self.true_socket and not self.true_socket._closed:
            self.true_socket.close()
        self._fd = None

    def __getattr__(self, name):
        """Do-nothing catch-all for socket methods not defined here (e.g. shutdown())."""

        def do_nothing(*args, **kwargs):
            pass

        return do_nothing
425

426

427
class Mocket:
    """Global registry and on/off switch for Mocket.

    Holds the registered entries keyed by location, the collected requests,
    and ``enable``/``disable``, which monkey-patch (and restore) the
    ``socket``, ``ssl`` and ``urllib3`` functions Mocket intercepts.
    """

    _address = (None, None)
    _entries = collections.defaultdict(list)
    _requests = []
    _namespace = text_type(id(_entries))
    _truesocket_recording_dir = None
    r_fd = None
    w_fd = None

    @classmethod
    def register(cls, *entries):
        """Register each entry under its ``location``."""
        for entry in entries:
            cls._entries[entry.location].append(entry)

    @classmethod
    def get_entry(cls, host, port, data):
        """Return the first entry for ``(host, port)`` whose ``can_handle(data)`` is true.

        A missing host/port falls back to the last connected address; returns
        ``None`` when nothing matches.
        """
        host = host or Mocket._address[0]
        port = port or Mocket._address[1]
        entries = cls._entries.get((host, port), [])
        for entry in entries:
            if entry.can_handle(data):
                return entry

    @classmethod
    def collect(cls, data):
        cls.request_list().append(data)

    @classmethod
    def reset(cls):
        cls.r_fd = None
        cls.w_fd = None
        cls._entries = collections.defaultdict(list)
        cls._requests = []

    @classmethod
    def last_request(cls):
        if cls.has_requests():
            return cls.request_list()[-1]

    @classmethod
    def request_list(cls):
        return cls._requests

    @classmethod
    def remove_last_request(cls):
        if cls.has_requests():
            del cls._requests[-1]

    @classmethod
    def has_requests(cls):
        return bool(cls.request_list())

    @staticmethod
    def _restore_attribute(module, name, original):
        """Put *original* back as ``module.<name>``.

        ``None`` marks a function that did not exist in this Python/urllib3
        version (see the import-time fallbacks). In that case the attribute
        added by ``enable()`` is removed instead of being left set to ``None``.
        """
        if original is None:
            module.__dict__.pop(name, None)
        else:
            setattr(module, name, original)

    @staticmethod
    def enable(namespace=None, truesocket_recording_dir=None):
        """Monkey-patch ``socket``/``ssl``/``urllib3`` so traffic goes through Mocket.

        :raises AssertionError: if *truesocket_recording_dir* is given but is
            not an existing directory.
        """
        Mocket._namespace = namespace
        Mocket._truesocket_recording_dir = truesocket_recording_dir

        if truesocket_recording_dir:
            # JSON dumps will be saved here
            if not os.path.isdir(truesocket_recording_dir):
                raise AssertionError(
                    "truesocket_recording_dir {!r} is not an existing directory".format(
                        truesocket_recording_dir
                    )
                )

        socket.socket = socket.__dict__["socket"] = MocketSocket
        socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket
        socket.SocketType = socket.__dict__["SocketType"] = MocketSocket
        socket.create_connection = socket.__dict__[
            "create_connection"
        ] = create_connection
        socket.gethostname = socket.__dict__["gethostname"] = lambda: "localhost"
        socket.gethostbyname = socket.__dict__[
            "gethostbyname"
        ] = lambda host: "127.0.0.1"
        socket.getaddrinfo = socket.__dict__[
            "getaddrinfo"
        ] = lambda host, port, family=None, socktype=None, proto=None, flags=None: [
            (2, 1, 6, "", (host, port))
        ]
        socket.socketpair = socket.__dict__["socketpair"] = socketpair
        ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket
        ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext
        socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: byte_type(
            "\x7f\x00\x00\x01", "utf-8"
        )
        urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__[
            "wrap_socket"
        ] = FakeSSLContext.wrap_socket
        urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
            "ssl_wrap_socket"
        ] = FakeSSLContext.wrap_socket
        urllib3.util.ssl_wrap_socket = urllib3.util.__dict__[
            "ssl_wrap_socket"
        ] = FakeSSLContext.wrap_socket
        urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
            "ssl_wrap_socket"
        ] = FakeSSLContext.wrap_socket
        urllib3.connection.match_hostname = urllib3.connection.__dict__[
            "match_hostname"
        ] = lambda *args: None
        if pyopenssl_override:  # pragma: no cover
            # Take out the pyopenssl version - use the default implementation
            extract_from_urllib3()

    @staticmethod
    def disable():
        """Undo ``enable()``: restore the original functions and reset state."""
        socket.socket = socket.__dict__["socket"] = true_socket
        socket._socketobject = socket.__dict__["_socketobject"] = true_socket
        socket.SocketType = socket.__dict__["SocketType"] = true_socket
        socket.create_connection = socket.__dict__[
            "create_connection"
        ] = true_create_connection
        socket.gethostname = socket.__dict__["gethostname"] = true_gethostname
        socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname
        socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo
        socket.socketpair = socket.__dict__["socketpair"] = true_socketpair
        Mocket._restore_attribute(ssl, "wrap_socket", true_ssl_wrap_socket)
        ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context
        socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton
        Mocket._restore_attribute(
            urllib3.util.ssl_, "wrap_socket", true_urllib3_wrap_socket
        )
        urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[
            "ssl_wrap_socket"
        ] = true_urllib3_ssl_wrap_socket
        urllib3.util.ssl_wrap_socket = urllib3.util.__dict__[
            "ssl_wrap_socket"
        ] = true_urllib3_ssl_wrap_socket
        urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
            "ssl_wrap_socket"
        ] = true_urllib3_ssl_wrap_socket
        urllib3.connection.match_hostname = urllib3.connection.__dict__[
            "match_hostname"
        ] = true_urllib3_match_hostname
        Mocket.reset()
        if pyopenssl_override:  # pragma: no cover
            # Put the pyopenssl version back in place
            inject_into_urllib3()

    @classmethod
    def get_namespace(cls):
        return cls._namespace

    @classmethod
    def get_truesocket_recording_dir(cls):
        return cls._truesocket_recording_dir

    @classmethod
    def assert_fail_if_entries_not_served(cls):
        """Mocket checks that all entries have been served at least once."""
        if not all(entry._served for entry in itertools.chain(*cls._entries.values())):
            raise AssertionError("Some Mocket entries have not been served")
577

578

579
class MocketEntry:
    """A fake endpoint: its ``location`` plus the ordered responses it serves.

    ``get_response`` walks through the responses and then keeps returning the
    last one. Exception instances among them are raised instead of returned.
    """

    class Response(byte_type):
        """``bytes`` payload that exposes itself through a ``data`` property."""

        @property
        def data(self):
            return self

    response_index = 0
    request_cls = byte_type
    response_cls = Response
    responses = None
    _served = None

    def __init__(self, location, responses):
        self._served = False
        self.location = location

        # A lone response (a string, or anything non-iterable) becomes a one-item list.
        if isinstance(responses, basestring) or not isinstance(
            responses, collections_abc.Iterable
        ):
            responses = [responses]

        if responses:
            self.responses = [self._as_response(item) for item in responses]
        else:
            # No responses at all: serve a single empty payload.
            self.responses = [self.response_cls(encode_to_bytes(""))]

    def _as_response(self, item):
        """Wrap *item* in ``response_cls`` unless it is an exception or has a truthy ``data``."""
        if isinstance(item, BaseException) or getattr(item, "data", False):
            return item
        payload = encode_to_bytes(item) if isinstance(item, text_type) else item
        return self.response_cls(payload)

    @staticmethod
    def can_handle(data):
        return True

    def collect(self, data):
        Mocket.collect(self.request_cls(data))

    def get_response(self):
        """Return the current response's data and advance, stopping at the last one.

        Marks the entry as served; raises the response if it is an exception.
        """
        current = self.response_index
        response = self.responses[current]
        if current < len(self.responses) - 1:
            self.response_index = current + 1

        self._served = True

        if isinstance(response, BaseException):
            raise response
        return response.data
630

631

632
class Mocketizer:
    """Context manager (sync and async) that enables Mocket around a block of code.

    When built for a test-case ``instance``, that instance's optional
    ``mocketize_setup``/``mocketize_teardown`` methods run on entry/exit.
    Mocket is always disabled again on exit, and also when setup fails.
    """

    def __init__(
        self,
        instance=None,
        namespace=None,
        truesocket_recording_dir=None,
        strict_mode=False,
    ):
        self.instance = instance
        self.truesocket_recording_dir = truesocket_recording_dir
        self.namespace = namespace or text_type(id(self))
        # Strict mode is set here, at construction time, not on entry.
        MocketMode().STRICT = strict_mode

    def enter(self):
        """Enable Mocket, then run the instance's ``mocketize_setup`` hook.

        If the hook raises, Mocket is disabled before the error propagates.
        ``__exit__`` is not called when ``__enter__`` fails, so a failed setup
        would otherwise leave sockets patched.
        """
        Mocket.enable(
            namespace=self.namespace,
            truesocket_recording_dir=self.truesocket_recording_dir,
        )
        if self.instance:
            try:
                self.check_and_call("mocketize_setup")
            except BaseException:
                Mocket.disable()
                raise

    def __enter__(self):
        self.enter()
        return self

    def exit(self):
        """Run ``mocketize_teardown`` (if any), then always disable Mocket."""
        try:
            if self.instance:
                self.check_and_call("mocketize_teardown")
        finally:
            Mocket.disable()

    def __exit__(self, type, value, tb):
        self.exit()

    async def __aenter__(self, *args, **kwargs):
        self.enter()
        return self

    async def __aexit__(self, *args, **kwargs):
        self.exit()

    def check_and_call(self, method_name):
        """Call ``self.instance.<method_name>()`` if it exists and is callable."""
        method = getattr(self.instance, method_name, None)
        if callable(method):
            method()

    @staticmethod
    def factory(test, truesocket_recording_dir, strict_mode, args):
        """Build a Mocketizer for *test*, called with positional *args*.

        The first positional argument (if any) is treated as the test-case
        instance. With a recording directory, the namespace (the recording
        file name) is ``<module>.<class>.<test name>``.
        """
        instance = args[0] if args else None
        namespace = None
        if truesocket_recording_dir:
            namespace = ".".join(
                (
                    instance.__class__.__module__,
                    instance.__class__.__name__,
                    test.__name__,
                )
            )

        return Mocketizer(
            instance,
            namespace=namespace,
            truesocket_recording_dir=truesocket_recording_dir,
            strict_mode=strict_mode,
        )
696

697

698
def wrapper(test, truesocket_recording_dir=None, strict_mode=False, *args, **kwargs):
    """Call *test* with ``*args``/``**kwargs`` inside a ``Mocketizer`` context.

    This is the function handed to ``get_mocketize`` to build ``mocketize``.
    """
    mocketizer = Mocketizer.factory(test, truesocket_recording_dir, strict_mode, args)
    with mocketizer:
        result = test(*args, **kwargs)
    return result
701

702

703
# Public ``mocketize`` decorator, built by ``get_mocketize`` (from ``.utils``)
# around ``wrapper``; presumably decorated tests end up running through
# ``wrapper`` inside a ``Mocketizer`` context — confirm in ``utils.get_mocketize``.
mocketize = get_mocketize(wrapper_=wrapper)
4✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc