• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mindflayer / python-mocket / 12006204376

25 Nov 2024 08:37AM UTC coverage: 98.418% (-0.8%) from 99.255%
12006204376

Pull #269

github

web-flow
Merge 86872cd95 into 0da27224a
Pull Request #269: Recording: skip response cache toggle

11 of 11 new or added lines in 4 files covered. (100.0%)

8 existing lines in 3 files now uncovered.

933 of 948 relevant lines covered (98.42%)

0.98 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.73
/mocket/socket.py
1
from __future__ import annotations
1✔
2

3
import contextlib
1✔
4
import errno
1✔
5
import hashlib
1✔
6
import json
1✔
7
import os
1✔
8
import select
1✔
9
import socket
1✔
10
import uuid
1✔
11
from json.decoder import JSONDecodeError
1✔
12
from types import TracebackType
1✔
13
from typing import Any, Type
1✔
14

15
import urllib3.connection
1✔
16
from typing_extensions import Self
1✔
17

18
from mocket.compat import decode_from_bytes, encode_to_bytes
1✔
19
from mocket.entry import MocketEntry
1✔
20
from mocket.io import MocketSocketIO
1✔
21
from mocket.mocket import Mocket
1✔
22
from mocket.mode import MocketMode
1✔
23
from mocket.types import (
1✔
24
    Address,
25
    ReadableBuffer,
26
    WriteableBuffer,
27
    _RetAddress,
28
)
29
from mocket.utils import hexdump, hexload
1✔
30

31
# References to the genuine socket-module callables (plus urllib3's hostname
# matcher), captured once when this module is imported. Code below still needs
# the real implementations: e.g. MocketSocket.__init__ builds a real socket via
# true_socket, and true_gethostbyname performs real name resolution.
# NOTE(review): these only hold the originals if this module is imported before
# the socket module is patched (the patching itself happens outside this file).
true_create_connection = socket.create_connection
true_getaddrinfo = socket.getaddrinfo
true_gethostbyname = socket.gethostbyname
true_gethostname = socket.gethostname
true_inet_pton = socket.inet_pton
true_socket = socket.socket
true_socketpair = socket.socketpair
true_urllib3_match_hostname = urllib3.connection.match_hostname
39

40

41
# Hash function used for request signatures in recordings: prefer xxh32 from
# `xxhash`, then from `xxhash_cffi`, and fall back to hashlib.md5 when neither
# package is importable.
try:
    from xxhash import xxh32
except ImportError:  # pragma: no cover
    try:
        from xxhash_cffi import xxh32
    except ImportError:
        xxh32 = None
hasher = hashlib.md5 if xxh32 is None else xxh32
48

49

50
def mock_create_connection(address, timeout=None, source_address=None):
    """Stand-in for :func:`socket.create_connection`.

    Creates a TCP/IPv4 socket through the module-level ``socket.socket`` lookup
    (so whichever class is currently installed there is used), applies
    ``timeout`` and connects it to ``address``. ``source_address`` is accepted
    only for signature compatibility and is ignored.

    ``timeout`` is applied unless it is ``None`` or the stdlib's
    ``socket._GLOBAL_DEFAULT_TIMEOUT`` sentinel. Callers such as ``http.client``
    pass that sentinel when no timeout was configured. It must not reach
    ``settimeout()``: a real socket rejects it with TypeError. An explicit ``0``
    (non-blocking) is now honoured instead of being dropped as falsy.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
    if timeout is not None and timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
        s.settimeout(timeout)
    s.connect(address)
    return s
56

57

58
def mock_getaddrinfo(
    host: str,
    port: int,
    family: int = 0,
    type: int = 0,
    proto: int = 0,
    flags: int = 0,
) -> list[tuple[int, int, int, str, tuple[str, int]]]:
    """Pretend ``host`` resolves to exactly one TCP endpoint.

    Only ``host`` and ``port`` are used. The single result is
    ``(2, 1, 6, "", (host, port))``, i.e. the numeric values of AF_INET,
    SOCK_STREAM and IPPROTO_TCP on common platforms, an empty canonical name,
    and the input echoed back as the socket address. ``family``, ``type``,
    ``proto`` and ``flags`` are ignored.
    """
    addr_family, sock_type, protocol = 2, 1, 6
    canonical_name = ""
    sockaddr = (host, port)
    return [(addr_family, sock_type, protocol, canonical_name, sockaddr)]
67

68

69
def mock_gethostbyname(hostname: str) -> str:
    """Resolve any ``hostname`` to the IPv4 loopback address.

    The argument is ignored; the answer is always ``"127.0.0.1"``.
    """
    loopback = "127.0.0.1"
    return loopback
71

72

73
def mock_gethostname() -> str:
    """Report ``"localhost"`` as the machine's host name, unconditionally."""
    host_name = "localhost"
    return host_name
75

76

77
def mock_inet_pton(address_family: int, ip_string: str) -> bytes:
    """Pack any address as the 4-byte IPv4 loopback ``127.0.0.1``.

    Both arguments are ignored, so the result is always the same four bytes,
    whatever family is requested.
    """
    return bytes((127, 0, 0, 1))
79

80

81
def mock_socketpair(*args, **kwargs):
    """Return a genuine socket pair from the C-level ``_socket`` module.

    asyncio event loops use socketpair() internally (e.g. when serving apps
    built with fastapi and similar frameworks), so this never returns mocks:
    all arguments are passed straight through to ``_socket.socketpair``.
    """
    import _socket as c_socket

    make_pair = c_socket.socketpair
    return make_pair(*args, **kwargs)
86

87

88
def mock_urllib3_match_hostname(*args: Any) -> None:
    """No-op substitute for urllib3's ``match_hostname``.

    Any positional arguments are accepted and ignored, nothing is raised,
    and ``None`` is returned.
    """
90

91

92
def _hash_request(h, req):
    """Return the hex digest of ``req`` in a line-order-insensitive form.

    ``req`` is split on CRLF, the pieces are sorted and concatenated without
    any separator, converted with ``encode_to_bytes`` and hashed with the
    hash constructor ``h`` (e.g. ``hashlib.md5``). Two requests whose lines
    differ only in order therefore produce the same signature.
    """
    request_lines = req.split("\r\n")
    request_lines.sort()
    canonical = "".join(request_lines)
    digest = h(encode_to_bytes(canonical))
    return digest.hexdigest()
94

95

96
class MocketSocket:
    """Drop-in replacement for :class:`socket.socket` driven by Mocket entries.

    Data written with :meth:`sendall`/:meth:`send` is matched against the
    entries registered on :class:`Mocket` for this socket's ``(host, port)``.
    A matching entry's response is loaded into an in-memory
    :class:`MocketSocketIO` buffer that :meth:`recv`, :meth:`recv_into` and
    :meth:`makefile` read from. Unmatched traffic goes through
    :meth:`true_sendall`, which uses a real socket (created in ``__init__``)
    when :class:`MocketMode` allows it and can record the exchange to disk.
    Any attribute not defined on the class resolves to a no-op callable
    (see :meth:`__getattr__`).
    """

    def __init__(
        self,
        family: socket.AddressFamily | int = socket.AF_INET,
        type: socket.SocketKind | int = socket.SOCK_STREAM,
        proto: int = 0,
        fileno: int | None = None,
        **kwargs: Any,
    ) -> None:
        self._family = family
        self._type = type
        self._proto = proto

        self._kwargs = kwargs
        # Real socket created up front so unmatched traffic can be forwarded by
        # true_sendall(); the `fileno` argument is accepted but not used.
        self._true_socket = true_socket(family, type, proto)

        # read size for the real socket in true_sendall()
        self._buflen = 65536
        self._timeout: float | None = None

        # set by connect()
        self._host = None
        self._port = None
        self._address = None

        # lazily created response buffer (see `io`)
        self._io = None
        # entry matched by the previous send() call
        self._entry = None

    def __str__(self) -> str:
        return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})"

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        type_: Type[BaseException] | None,  # noqa: UP006
        value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        self.close()

    @property
    def family(self) -> int:
        return self._family

    @property
    def type(self) -> int:
        return self._type

    @property
    def proto(self) -> int:
        return self._proto

    @property
    def io(self) -> MocketSocketIO:
        """In-memory response buffer for ``(host, port)``, created on first use."""
        if self._io is None:
            self._io = MocketSocketIO((self._host, self._port))
        return self._io

    def fileno(self) -> int:
        """Return the read end of the pipe registered for ``(host, port)``.

        Reuses a pair previously stored with ``Mocket.set_pair``; otherwise a
        new ``os.pipe()`` is created and registered for this address.
        """
        address = (self._host, self._port)
        r_fd, _ = Mocket.get_pair(address)
        if not r_fd:
            r_fd, w_fd = os.pipe()
            Mocket.set_pair(address, (r_fd, w_fd))
        return r_fd

    def gettimeout(self) -> float | None:
        return self._timeout

    def setsockopt(self, family: int, type: int, proto: int) -> None:
        """Forward a socket option to the underlying real socket.

        The stdlib signature is ``setsockopt(level, optname, value)``. These
        parameter names are historical and kept only so existing keyword
        callers keep working. The values are NOT stored as this socket's
        family/type/proto: doing so used to corrupt those properties. For
        example, ``setsockopt(IPPROTO_TCP, TCP_NODELAY, 1)`` turned ``family``
        into ``IPPROTO_TCP``.
        """
        if self._true_socket:
            self._true_socket.setsockopt(family, type, proto)

    def settimeout(self, timeout: float | None) -> None:
        self._timeout = timeout

    @staticmethod
    def getsockopt(level: int, optname: int, buflen: int | None = None) -> int:
        """Report ``SOCK_STREAM`` whatever option is requested."""
        return socket.SOCK_STREAM

    def getpeername(self) -> _RetAddress:
        return self._address

    def setblocking(self, block: bool) -> None:
        """Blocking mode means timeout ``None``; non-blocking means ``0.0``."""
        if block:
            self.settimeout(None)
        else:
            self.settimeout(0.0)

    def getblocking(self) -> bool:
        """Return whether the socket is in blocking mode.

        As with :meth:`socket.socket.getblocking`, a socket is blocking unless
        its timeout is exactly 0. A positive timeout ("timeout mode") still
        counts as blocking.
        """
        return self.gettimeout() != 0

    def getsockname(self) -> _RetAddress:
        """Return ``(ip, port)`` of the connected address, resolved for real.

        Requires :meth:`connect` to have been called; before that ``_address``
        is ``None`` and this raises TypeError.
        """
        return true_gethostbyname(self._address[0]), self._address[1]

    def connect(self, address: Address) -> None:
        """Record ``address`` locally and on ``Mocket``; no real connection."""
        self._address = self._host, self._port = address
        Mocket._address = address

    def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketIO:
        """Return the shared in-memory buffer; ``mode``/``bufsize`` are ignored."""
        return self.io

    def get_entry(self, data: bytes) -> MocketEntry | None:
        return Mocket.get_entry(self._host, self._port, data)

    def sendall(self, data, entry=None, *args, **kwargs):
        """Deliver ``data`` to a matching entry, or to the real endpoint.

        When an entry matches, ``entry.collect(data)`` is called; unless it
        returns ``False``, the entry's next response is fetched. Without an
        entry, :meth:`true_sendall` supplies the response. A non-``None``
        response replaces the buffer's contents and the buffer is rewound so
        the next read starts at its beginning.
        """
        if entry is None:
            entry = self.get_entry(data)

        if entry:
            consume_response = entry.collect(data)
            response = entry.get_response() if consume_response is not False else None
        else:
            response = self.true_sendall(data, *args, **kwargs)

        if response is not None:
            self.io.seek(0)
            self.io.write(response)
            self.io.truncate()
            self.io.seek(0)

    def recv_into(
        self,
        buffer: WriteableBuffer,
        buffersize: int | None = None,
        flags: int | None = None,
    ) -> int:
        """Read into ``buffer`` and return the number of bytes stored.

        File-like buffers (anything with ``write``) get the data written to
        them. Otherwise ``buffer`` is treated as a writable sequence (e.g. a
        memoryview) and filled from the start. As with
        :meth:`socket.socket.recv_into`, a ``buffersize`` of ``None`` or ``0``
        means "up to ``len(buffer)``". ``flags`` is ignored.
        """
        if hasattr(buffer, "write"):
            return buffer.write(self.recv(buffersize))

        if not buffersize:
            buffersize = len(buffer)

        data = self.recv(buffersize)
        if data:
            buffer[: len(data)] = data
        return len(data)

    def recv(self, buffersize: int, flags: int | None = None) -> bytes:
        """Read up to ``buffersize`` bytes of the current response.

        If a pipe pair is registered for ``(host, port)`` the data comes from
        its read end; otherwise from the in-memory buffer. An exhausted buffer
        raises ``BlockingIOError`` with errno ``EWOULDBLOCK``.
        """
        r_fd, _ = Mocket.get_pair((self._host, self._port))
        if r_fd:
            return os.read(r_fd, buffersize)
        data = self.io.read(buffersize)
        if data:
            return data
        # used by Redis mock
        exc = BlockingIOError()
        exc.errno = errno.EWOULDBLOCK
        exc.args = (0,)
        raise exc

    def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> bytes:
        """Answer ``data`` from a recording, or by talking to the real server.

        Calls ``MocketMode.raise_not_allowed()`` when the current mode does not
        permit real traffic to ``(host, port)``. With a recording directory
        configured, responses are looked up in ``<dir>/<namespace>.json``,
        keyed by host, port (as a string) and a hash of the request. On a miss
        the request is sent over the real socket, and the exchange is written
        back to that file when recording is enabled.

        Returns the raw response bytes; this may be ``b""`` if the peer closed
        without sending anything.
        """
        if not MocketMode().is_allowed((self._host, self._port)):
            MocketMode.raise_not_allowed()

        req = decode_from_bytes(data)
        # When the response cache is skipped, a random UUID is hashed instead of
        # the request, so the signature never matches a recorded entry.
        req_signature_source = (
            str(uuid.uuid4()) if Mocket.get_skip_response_cache() else req
        )
        req_signature = _hash_request(hasher, req_signature_source)
        # JSON object keys are strings, so the port is keyed as one
        port = str(self._port)

        # host -> port -> request signature -> {"request": ..., "response": ...}
        responses = {}

        recording_dir = Mocket.get_truesocket_recording_dir()
        path = None
        if recording_dir:
            path = os.path.join(recording_dir, Mocket.get_namespace() + ".json")
            # reuse a previously dumped session; a missing or unparsable file
            # just means starting from an empty dictionary
            try:
                with open(path) as f:
                    responses = json.load(f)
            except (FileNotFoundError, JSONDecodeError):
                pass

        try:
            try:
                response_dict = responses[self._host][port][req_signature]
            except KeyError:
                if hasher is not hashlib.md5:
                    # Fallback for backwards compatibility with md5-keyed files.
                    # NOTE: if this lookup misses too, `req_signature` keeps the
                    # md5 value, so the new entry below is stored under the md5
                    # key whichever hasher is active. Presumably this keeps
                    # recordings readable with or without xxhash installed.
                    req_signature = _hash_request(hashlib.md5, req_signature_source)
                    response_dict = responses[self._host][port][req_signature]
                else:
                    raise
        except KeyError:
            # create the nested dictionaries so the new exchange can be stored
            response_dict = (
                responses.setdefault(self._host, {})
                .setdefault(port, {})
                .setdefault(req_signature, {})
            )

        try:
            encoded_response = hexload(response_dict["response"])
        except KeyError:
            # nothing recorded for this request: perform the real exchange
            true_host = true_gethostbyname(self._host)
            true_port = self._port

            with contextlib.suppress(OSError, ValueError):
                # the real socket may already be connected
                self._true_socket.connect((true_host, true_port))
            self._true_socket.sendall(data, *args, **kwargs)

            # Keep reading until the peer closes, or until it stays silent for
            # 0.1s after at least some data has arrived. A bytearray avoids
            # quadratic re-copying on large responses.
            # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12
            received = bytearray()
            while True:
                more_to_read = select.select([self._true_socket], [], [], 0.1)[0]
                if not more_to_read and received:
                    break
                new_content = self._true_socket.recv(self._buflen)
                if not new_content:
                    break
                received += new_content
            encoded_response = bytes(received)

            if path is not None:
                # update the dictionary with request and response, then dump it
                response_dict["request"] = req
                response_dict["response"] = hexdump(encoded_response)

                with open(path, mode="w") as f:
                    f.write(
                        decode_from_bytes(
                            json.dumps(responses, indent=4, sort_keys=True)
                        )
                    )

        # handed back to sendall(), which loads it into the in-memory buffer
        return encoded_response

    def send(
        self,
        data: ReadableBuffer,
        *args: Any,
        **kwargs: Any,
    ) -> int:  # pragma: no cover
        """Send ``data``, coalescing consecutive writes for the same entry.

        A write with no matching entry, or matching a different entry than the
        previous call, goes through :meth:`sendall`. A write matching the same
        entry as the previous call is appended to ``Mocket.last_request()``
        (when it supports ``add_data``) instead. The full length of ``data``
        is always reported as sent.
        """
        entry = self.get_entry(data)
        if not entry or self._entry != entry:
            # Pass the entry positionally: extra positional args (e.g. send
            # flags) would otherwise bind to sendall's `entry` parameter and
            # clash with an `entry=` keyword.
            kwargs.pop("entry", None)
            self.sendall(data, entry, *args, **kwargs)
        else:
            req = Mocket.last_request()
            if hasattr(req, "add_data"):
                req.add_data(data)
        self._entry = entry
        return len(data)

    def close(self) -> None:
        """Close the underlying real socket if it is still open."""
        if self._true_socket and not self._true_socket._closed:
            self._true_socket.close()

    def __getattr__(self, name: str) -> Any:
        """Do nothing catchall function, for methods like shutdown()"""

        def do_nothing(*args: Any, **kwargs: Any) -> Any:
            pass

        return do_nothing
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc