• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mindflayer / python-mocket / 11884044999

17 Nov 2024 07:34PM UTC coverage: 99.324% (-0.2%) from 99.545%
11884044999

push

github

web-flow
Refactor: introduce state object (#264)

* refactor: move enable- and disable-functions from mocket.mocket to mocket.inject
* refactor: Mocket - add typing and get rid of cyclic import

104 of 107 new or added lines in 9 files covered. (97.2%)

1 existing line in 1 file now uncovered.

881 of 887 relevant lines covered (99.32%)

6.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.58
/mocket/socket.py
1
import contextlib
7✔
2
import errno
7✔
3
import hashlib
7✔
4
import json
7✔
5
import os
7✔
6
import select
7✔
7
import socket
7✔
8
import ssl
7✔
9
from datetime import datetime, timedelta
7✔
10
from json.decoder import JSONDecodeError
7✔
11

12
from mocket.compat import decode_from_bytes, encode_to_bytes
7✔
13
from mocket.inject import (
7✔
14
    true_gethostbyname,
15
    true_socket,
16
    true_urllib3_ssl_wrap_socket,
17
)
18
from mocket.io import MocketSocketCore
7✔
19
from mocket.mocket import Mocket
7✔
20
from mocket.mode import MocketMode
7✔
21
from mocket.utils import hexdump, hexload
7✔
22

23
# Pick the fastest available hash constructor for request signatures:
# the xxhash C extension, then its CFFI build, and finally hashlib.md5.
try:
    from xxhash import xxh32
except ImportError:  # pragma: no cover
    try:
        from xxhash_cffi import xxh32
    except ImportError:
        xxh32 = None
hasher = xxh32 or hashlib.md5
30

31

32
def create_connection(address, timeout=None, source_address=None):
    """Create a TCP socket via ``socket.socket`` and connect it to *address*.

    Mirrors ``socket.create_connection`` closely enough for Mocket's purposes.
    ``socket.socket`` is looked up at call time, so it is whatever class the
    ``socket`` module currently exposes.

    Args:
        address: ``(host, port)`` pair passed to ``connect()``.
        timeout: Timeout applied with ``settimeout()``. Two values leave the
            socket's default untouched: ``None``, and the
            ``socket._GLOBAL_DEFAULT_TIMEOUT`` sentinel that ``http.client``
            passes by default. Any other value, including ``0`` for
            non-blocking mode, is applied.
        source_address: Accepted for signature compatibility; ignored.

    Returns:
        The connected socket object.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
    # Truthiness is the wrong test here: it drops an explicit 0 timeout, and
    # it lets the (truthy) module sentinel reach settimeout().
    default_sentinel = getattr(socket, "_GLOBAL_DEFAULT_TIMEOUT", None)
    if timeout is not None and timeout is not default_sentinel:
        s.settimeout(timeout)
    s.connect(address)
    return s
38

39

40
def socketpair(*args, **kwargs):
    """Return a genuine ``socketpair()`` from the low-level ``_socket`` module.

    The asyncio event loop relies on a real socket pair, which is what calls
    made by fastapi and similar services end up using. This bypasses any
    patched ``socket`` module attributes.
    """
    from _socket import socketpair as _real_socketpair

    return _real_socketpair(*args, **kwargs)
45

46

47
def _hash_request(h, req):
    """Return the hex digest, under hash constructor *h*, of *req* in canonical form.

    The request is split on CRLF and its lines are sorted before hashing, so the
    signature does not depend on the order of the request lines.
    """
    request_lines = req.split("\r\n")
    request_lines.sort()
    canonical = "".join(request_lines)
    return h(encode_to_bytes(canonical)).hexdigest()
49

50

51
class MocketSocket:
    """Socket stand-in used by Mocket in place of a real ``socket.socket``.

    Data sent through it is matched against Mocket entries
    (``Mocket.get_entry``). A matching entry's response is written into the
    in-memory ``io`` buffer, where ``read``/``recv``/``recv_into`` pick it up.
    Without a matching entry the data goes through ``true_sendall``, which may
    use a real socket and can record/replay responses as JSON when a
    recording directory is configured. Attributes that do not exist resolve to
    a no-op callable (see ``__getattr__``).
    """

    timeout = None
    _fd = None
    family = None
    type = None
    proto = None
    _host = None
    _port = None
    _address = None
    _mode = None
    _bufsize = None
    # Not set anywhere in this class; presumably flipped by an SSL wrapper
    # elsewhere in mocket — TODO confirm.
    _secure_socket = False
    _did_handshake = False
    _sent_non_empty_bytes = False
    _io = None

    def __init__(
        self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs
    ):
        # A real socket is created up front so unmatched traffic can be
        # forwarded by true_sendall().
        self.true_socket = true_socket(family, type, proto)
        self._buflen = 65536  # chunk size for reads from the real socket
        self._entry = None
        self.family = int(family)
        self.type = int(type)
        self.proto = int(proto)
        self._truesocket_recording_dir = None
        # Extra kwargs are passed to true_urllib3_ssl_wrap_socket() in true_sendall().
        self.kwargs = kwargs

    def __str__(self):
        return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})"

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def cipher(self):
        """Return a fixed, fake cipher description."""
        return ("ADH", "AES256", "SHA")

    def compression(self):
        """Return ``ssl.OP_NO_COMPRESSION`` (fixed stub value)."""
        return ssl.OP_NO_COMPRESSION

    @property
    def io(self):
        """Lazily created ``MocketSocketCore`` buffer for this socket's address."""
        if self._io is None:
            self._io = MocketSocketCore((self._host, self._port))
        return self._io

    def fileno(self):
        """Return the read end of an OS pipe registered for this address.

        The pipe is created with ``os.pipe()`` and registered through
        ``Mocket.set_pair`` on first use.
        """
        address = (self._host, self._port)
        r_fd, _ = Mocket.get_pair(address)
        if not r_fd:
            r_fd, w_fd = os.pipe()
            Mocket.set_pair(address, (r_fd, w_fd))
        return r_fd

    def gettimeout(self):
        return self.timeout

    def setsockopt(self, family, type, proto):
        """Forward a socket option to the real socket, if there is one.

        The arguments are ``(level, optname, value)`` as in
        ``socket.setsockopt``; the historical parameter names are kept for
        keyword compatibility. They are deliberately NOT stored on
        ``self.family``/``self.type``/``self.proto``. Doing so used to overwrite
        the socket's family/type/protocol with option constants, e.g. after
        ``setsockopt(IPPROTO_TCP, TCP_NODELAY, 1)``.
        """
        if self.true_socket:
            self.true_socket.setsockopt(family, type, proto)

    def settimeout(self, timeout):
        self.timeout = timeout

    @staticmethod
    def getsockopt(level, optname, buflen=None):
        """Stub: always returns ``socket.SOCK_STREAM``, whatever option is asked for."""
        return socket.SOCK_STREAM

    def do_handshake(self):
        # Recorded so that read() can emulate SSLWantReadError before data arrives.
        self._did_handshake = True

    def getpeername(self):
        return self._address

    def setblocking(self, block):
        """Blocking mode clears the timeout; non-blocking sets it to ``0.0``."""
        self.settimeout(None if block else 0.0)

    def getblocking(self):
        """Return True unless the socket is in non-blocking mode.

        ``socket.getblocking()`` is documented as equivalent to
        ``gettimeout() != 0``, so a positive timeout still counts as blocking.
        """
        return self.gettimeout() != 0

    def getsockname(self):
        return socket.gethostbyname(self._address[0]), self._address[1]

    def getpeercert(self, *args, **kwargs):
        """Return a fake peer certificate dict for the connected host.

        If host/port are unset, they are taken from ``Mocket._address``.
        """
        if not (self._host and self._port):
            self._address = self._host, self._port = Mocket._address

        not_after = datetime.now() + timedelta(days=30 * 12)
        return {
            # Same layout as a real getpeercert() date, including the year, so
            # that ssl.cert_time_to_seconds() can parse it.
            "notAfter": not_after.strftime("%b %d %H:%M:%S %Y GMT"),
            "subjectAltName": (
                ("DNS", f"*.{self._host}"),
                ("DNS", self._host),
                ("DNS", "*"),
            ),
            "subject": (
                (("organizationName", f"*.{self._host}"),),
                (("organizationalUnitName", "Domain Control Validated"),),
                (("commonName", f"*.{self._host}"),),
            ),
        }

    def unwrap(self):
        return self

    def write(self, data):
        return self.send(encode_to_bytes(data))

    def connect(self, address):
        """Record *address* locally and as ``Mocket._address``; nothing is dialed."""
        self._address = self._host, self._port = address
        Mocket._address = address

    def makefile(self, mode="r", bufsize=-1):
        """Return the shared ``io`` buffer; *mode* and *bufsize* are only stored."""
        self._mode = mode
        self._bufsize = bufsize
        return self.io

    def get_entry(self, data):
        return Mocket.get_entry(self._host, self._port, data)

    def sendall(self, data, entry=None, *args, **kwargs):
        """Send *data* to a matching entry, or to ``true_sendall`` if none.

        Any response (unless the entry's ``collect`` returned ``False``)
        replaces the contents of the ``io`` buffer, which is rewound for reading.
        """
        if entry is None:
            entry = self.get_entry(data)

        if entry:
            consume_response = entry.collect(data)
            response = entry.get_response() if consume_response is not False else None
        else:
            response = self.true_sendall(data, *args, **kwargs)

        if response is not None:
            self.io.seek(0)
            self.io.write(response)
            self.io.truncate()
            self.io.seek(0)

    def read(self, buffersize):
        """Read up to *buffersize* bytes from the ``io`` buffer.

        Raises:
            ssl.SSLWantReadError: after ``do_handshake()`` while no non-empty
                read has happened yet.
        """
        rv = self.io.read(buffersize)
        if rv:
            self._sent_non_empty_bytes = True
        if self._did_handshake and not self._sent_non_empty_bytes:
            raise ssl.SSLWantReadError("The operation did not complete (read)")
        return rv

    def recv_into(self, buffer, buffersize=None, flags=None):
        """Read into *buffer* and return the number of bytes stored.

        *buffer* is either file-like (has ``write``) or a writable bytes-like
        object such as a memoryview. In the latter case a missing/zero
        *buffersize* means ``len(buffer)``, as with ``socket.recv_into``, so
        the read never returns more than the buffer can hold.
        """
        if hasattr(buffer, "write"):
            return buffer.write(self.read(buffersize))
        if not buffersize:
            buffersize = len(buffer)
        data = self.read(buffersize)
        if data:
            buffer[: len(data)] = data
        return len(data)

    def recv(self, buffersize, flags=None):
        """Read from the registered pipe if one exists, else from ``io``.

        Raises:
            BlockingIOError: with errno ``EWOULDBLOCK`` when no data is buffered.
        """
        r_fd, _ = Mocket.get_pair((self._host, self._port))
        if r_fd:
            return os.read(r_fd, buffersize)
        data = self.read(buffersize)
        if data:
            return data
        # used by Redis mock
        exc = BlockingIOError()
        exc.errno = errno.EWOULDBLOCK
        exc.args = (0,)
        raise exc

    def true_sendall(self, data, *args, **kwargs):
        """Serve *data* from a recording, or from a real connection.

        Calls ``MocketMode.raise_not_allowed()`` when the address is not
        allowed. With a recording directory configured, responses are looked up
        in ``<dir>/<namespace>.json``, keyed by host, port (as a string) and a
        hash of the request. A response missing there is fetched through the
        real socket and written back to that file.

        Returns:
            The raw response bytes.
        """
        if not MocketMode().is_allowed((self._host, self._port)):
            MocketMode.raise_not_allowed()

        req = decode_from_bytes(data)
        # Order-independent signature of the request lines.
        req_signature = _hash_request(hasher, req)
        # JSON object keys are strings, so the port is stored as one.
        port = str(self._port)

        responses = {}
        recording_dir = Mocket.get_truesocket_recording_dir()
        if recording_dir:
            path = os.path.join(recording_dir, Mocket.get_namespace() + ".json")
            # Reuse a previously dumped session when a readable one exists.
            try:
                with open(path, encoding="utf-8") as f:
                    responses = json.load(f)
            except (FileNotFoundError, JSONDecodeError):
                pass

        try:
            try:
                response_dict = responses[self._host][port][req_signature]
            except KeyError:
                if hasher is hashlib.md5:
                    raise
                # Fallback for backwards compatibility with md5-keyed recordings.
                req_signature = _hash_request(hashlib.md5, req)
                response_dict = responses[self._host][port][req_signature]
        except KeyError:
            # Create the nested slot so a fresh response can be recorded in place.
            response_dict = (
                responses.setdefault(self._host, {})
                .setdefault(port, {})
                .setdefault(req_signature, {})
            )

        try:
            encoded_response = hexload(response_dict["response"])
        except KeyError:
            # Nothing recorded: perform the real exchange.
            real_host = true_gethostbyname(self._host)

            if isinstance(self.true_socket, true_socket) and self._secure_socket:
                self.true_socket = true_urllib3_ssl_wrap_socket(
                    self.true_socket,
                    **self.kwargs,
                )

            # The real socket may already be connected; ignore errors here.
            with contextlib.suppress(OSError, ValueError):
                self.true_socket.connect((real_host, self._port))
            self.true_socket.sendall(data, *args, **kwargs)

            # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12
            chunks = []
            while True:
                more_to_read = select.select([self.true_socket], [], [], 0.1)[0]
                if not more_to_read and chunks:
                    break
                new_content = self.true_socket.recv(self._buflen)
                if not new_content:
                    break
                chunks.append(new_content)
            encoded_response = b"".join(chunks)

            if recording_dir:
                response_dict["request"] = req
                response_dict["response"] = hexdump(encoded_response)
                with open(path, mode="w", encoding="utf-8") as f:
                    f.write(
                        decode_from_bytes(
                            json.dumps(responses, indent=4, sort_keys=True)
                        )
                    )

        # Returned to sendall(), which writes it into the io buffer.
        return encoded_response

    def send(self, data, *args, **kwargs):  # pragma: no cover
        """Send *data*, appending to the last request when the entry is unchanged.

        Returns ``len(data)``, as a real ``send`` reports bytes written.
        """
        entry = self.get_entry(data)
        if not entry or self._entry != entry:
            kwargs["entry"] = entry
            self.sendall(data, *args, **kwargs)
        else:
            req = Mocket.last_request()
            if hasattr(req, "add_data"):
                req.add_data(data)
        self._entry = entry
        return len(data)

    def close(self):
        if self.true_socket and not self.true_socket._closed:
            self.true_socket.close()
        self._fd = None

    def __getattr__(self, name):
        """Do nothing catchall function, for methods like shutdown()"""

        def do_nothing(*args, **kwargs):
            pass

        return do_nothing
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc