• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mindflayer / python-mocket / 12006176199

25 Nov 2024 08:36AM UTC coverage: 98.424% (-0.8%) from 99.255%
12006176199

Pull #271

github

web-flow
Merge 1f703d518 into 0da27224a
Pull Request #271: Recording: support removing headers from recorded data

13 of 13 new or added lines in 5 files covered. (100.0%)

8 existing lines in 3 files now uncovered.

937 of 952 relevant lines covered (98.42%)

0.98 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.77
/mocket/socket.py
1
from __future__ import annotations
1✔
2

3
import contextlib
1✔
4
import errno
1✔
5
import hashlib
1✔
6
import json
1✔
7
import os
1✔
8
import select
1✔
9
import socket
1✔
10
import re
1✔
11
from json.decoder import JSONDecodeError
1✔
12
from types import TracebackType
1✔
13
from typing import Any, Type
1✔
14

15
import urllib3.connection
1✔
16
from typing_extensions import Self
1✔
17

18
from mocket.compat import decode_from_bytes, encode_to_bytes
1✔
19
from mocket.entry import MocketEntry
1✔
20
from mocket.io import MocketSocketIO
1✔
21
from mocket.mocket import Mocket
1✔
22
from mocket.mode import MocketMode
1✔
23
from mocket.types import (
1✔
24
    Address,
25
    ReadableBuffer,
26
    WriteableBuffer,
27
    _RetAddress,
28
)
29
from mocket.utils import hexdump, hexload
1✔
30

31
# Handles on the genuine stdlib/urllib3 callables, captured at import time so the
# code below can still reach the real implementations (``true_socket`` and
# ``true_gethostbyname`` are used by ``MocketSocket``).
# NOTE(review): presumably captured before the ``mock_*`` replacements are installed
# elsewhere in the package — confirm against the code that does the patching.
true_create_connection = socket.create_connection
1✔
32
true_getaddrinfo = socket.getaddrinfo
1✔
33
true_gethostbyname = socket.gethostbyname
1✔
34
true_gethostname = socket.gethostname
1✔
35
true_inet_pton = socket.inet_pton
1✔
36
true_socket = socket.socket
1✔
37
true_socketpair = socket.socketpair
1✔
38
true_urllib3_match_hostname = urllib3.connection.match_hostname
1✔
39

40

41
# Pick the hashing constructor used to sign requests: xxhash's ``xxh32`` when either
# ``xxhash`` or ``xxhash_cffi`` can be imported, otherwise ``hashlib.md5``.
try:
    from xxhash import xxh32
except ImportError:  # pragma: no cover
    try:
        from xxhash_cffi import xxh32
    except ImportError:
        xxh32 = None

if xxh32:
    hasher = xxh32
else:
    hasher = hashlib.md5
48

49

50
def mock_create_connection(address, timeout=None, source_address=None):
    """Create a TCP/IPv4 socket through ``socket.socket`` and connect it to *address*.

    Args:
        address: ``(host, port)`` pair handed to ``connect()``.
        timeout: Applied with ``settimeout()`` when truthy; a falsy value
            (``None`` or ``0``) leaves the socket's timeout untouched.
        source_address: Accepted for signature compatibility; not used.

    Returns:
        The connected socket object.

    Raises:
        Whatever ``settimeout()`` or ``connect()`` raises; the socket is closed
        before the exception propagates so no descriptor is leaked.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
    try:
        if timeout:
            sock.settimeout(timeout)
        sock.connect(address)
    except BaseException:
        # Clean up on any failure (including interrupts), then let it propagate.
        sock.close()
        raise
    return sock
56

57

58
def mock_getaddrinfo(
    host: str,
    port: int,
    family: int = 0,
    type: int = 0,
    proto: int = 0,
    flags: int = 0,
) -> list[tuple[int, int, int, str, tuple[str, int]]]:
    """Return one fixed address-info entry that echoes *host* and *port*.

    ``family``, ``type``, ``proto`` and ``flags`` are accepted but ignored.
    """
    # NOTE(review): 2 / 1 / 6 presumably stand for AF_INET / SOCK_STREAM /
    # IPPROTO_TCP on common platforms — confirm before replacing with constants.
    fixed_family, fixed_type, fixed_proto = 2, 1, 6
    canonical_name = ""
    sockaddr = (host, port)
    return [(fixed_family, fixed_type, fixed_proto, canonical_name, sockaddr)]
67

68

69
def mock_gethostbyname(hostname: str) -> str:
    """Resolve every *hostname* to the IPv4 loopback address."""
    loopback = "127.0.0.1"
    return loopback
71

72

73
def mock_gethostname() -> str:
    """Report a fixed host name instead of querying the system."""
    fixed_name = "localhost"
    return fixed_name
75

76

77
def mock_inet_pton(address_family: int, ip_string: str) -> bytes:
    """Return the four packed bytes 7f 00 00 01, whatever the arguments are."""
    return b"\x7f\x00\x00\x01"
79

80

81
def mock_socketpair(*args, **kwargs):
    """Returns a real socketpair() used by asyncio loop for supporting calls made by fastapi and similar services."""
    # Imported locally from the low-level ``_socket`` module, and all arguments
    # are forwarded unchanged.
    from _socket import socketpair as low_level_socketpair

    pair = low_level_socketpair(*args, **kwargs)
    return pair
86

87

88
def mock_urllib3_match_hostname(*args: Any) -> None:
    """Accept any positional arguments, check nothing and return ``None``."""
90

91

92
def _hash_request(h, req):
    """Return the hex digest of *req* computed with the hash constructor *h*.

    The request is split on CRLF, its lines are sorted and re-joined without a
    separator before hashing, so the order of the lines does not affect the result.
    """
    ordered_lines = sorted(req.split("\r\n"))
    canonical = encode_to_bytes("".join(ordered_lines))
    digest = h(canonical)
    return digest.hexdigest()
94

95

96
class MocketSocket:
    """Socket stand-in that answers from registered ``Mocket`` entries.

    Data handed to :meth:`sendall` is matched against ``Mocket`` entries; when
    nothing matches, the exchange is delegated to a real socket through
    :meth:`true_sendall`, which can replay from / record to a JSON file. Responses
    are buffered in a ``MocketSocketIO`` and read back with :meth:`recv`,
    :meth:`recv_into` or :meth:`makefile`. Any attribute not defined here resolves
    to a no-op callable (see :meth:`__getattr__`).
    """

    def __init__(
        self,
        family: socket.AddressFamily | int = socket.AF_INET,
        type: socket.SocketKind | int = socket.SOCK_STREAM,
        proto: int = 0,
        fileno: int | None = None,
        **kwargs: Any,
    ) -> None:
        """Store the socket parameters and create the backing real socket.

        ``fileno`` is accepted for signature compatibility but is not used;
        extra keyword arguments are only stored.
        """
        self._family = family
        self._type = type
        self._proto = proto

        self._kwargs = kwargs
        self._true_socket = true_socket(family, type, proto)

        self._buflen = 65536
        self._timeout: float | None = None

        self._host = None
        self._port = None
        self._address = None

        self._io = None
        self._entry = None

    def __str__(self) -> str:
        return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})"

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        type_: Type[BaseException] | None,  # noqa: UP006
        value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Close the socket on leaving the ``with`` block; exceptions propagate."""
        self.close()

    @property
    def family(self) -> int:
        return self._family

    @property
    def type(self) -> int:
        return self._type

    @property
    def proto(self) -> int:
        return self._proto

    @property
    def io(self) -> MocketSocketIO:
        """Lazily created ``MocketSocketIO`` bound to the current ``(host, port)``."""
        if self._io is None:
            self._io = MocketSocketIO((self._host, self._port))
        return self._io

    def fileno(self) -> int:
        """Return the read end of the pipe registered in ``Mocket`` for this address.

        A new ``os.pipe()`` is created and registered when none exists yet.
        """
        address = (self._host, self._port)
        r_fd, _ = Mocket.get_pair(address)
        if not r_fd:
            r_fd, w_fd = os.pipe()
            Mocket.set_pair(address, (r_fd, w_fd))
        return r_fd

    def gettimeout(self) -> float | None:
        return self._timeout

    # FIXME the arguments here seem wrong. they should be `level: int, optname: int, value: int | ReadableBuffer | None`
    def setsockopt(self, family: int, type: int, proto: int) -> None:
        """Store the three values as family/type/proto and forward them to the real socket.

        Kept as-is for backward compatibility; see the FIXME above about the
        parameter naming.
        """
        self._family = family
        self._type = type
        self._proto = proto

        if self._true_socket:
            self._true_socket.setsockopt(family, type, proto)

    def settimeout(self, timeout: float | None) -> None:
        self._timeout = timeout

    @staticmethod
    def getsockopt(level: int, optname: int, buflen: int | None = None) -> int:
        """Return ``socket.SOCK_STREAM`` for every option queried."""
        return socket.SOCK_STREAM

    def getpeername(self) -> _RetAddress:
        return self._address

    def setblocking(self, block: bool) -> None:
        """Blocking mode maps to no timeout, non-blocking mode to a ``0.0`` timeout."""
        self.settimeout(None if block else 0.0)

    def getblocking(self) -> bool:
        return self.gettimeout() is None

    def getsockname(self) -> _RetAddress:
        """Return ``(ip, port)``, resolving the connected host with the real ``gethostbyname``."""
        return true_gethostbyname(self._address[0]), self._address[1]

    def connect(self, address: Address) -> None:
        """Remember *address* on this socket and on ``Mocket``; nothing is sent."""
        self._address = self._host, self._port = address
        Mocket._address = address

    def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketIO:
        """Return the shared response buffer; *mode* and *bufsize* are ignored."""
        return self.io

    def get_entry(self, data: bytes) -> MocketEntry | None:
        return Mocket.get_entry(self._host, self._port, data)

    def sendall(self, data, entry=None, *args, **kwargs):
        """Handle outgoing *data* and stage the matching response for reading.

        The entry is looked up with :meth:`get_entry` unless one is passed in.
        With an entry, its response is used unless ``entry.collect()`` returned
        ``False``; without one, :meth:`true_sendall` provides the response. A
        non-``None`` response replaces the content of :attr:`io`.
        """
        if entry is None:
            entry = self.get_entry(data)

        if entry:
            consume_response = entry.collect(data)
            response = entry.get_response() if consume_response is not False else None
        else:
            response = self.true_sendall(data, *args, **kwargs)

        if response is not None:
            # Overwrite the buffer from the start and rewind it for the reader.
            self.io.seek(0)
            self.io.write(response)
            self.io.truncate()
            self.io.seek(0)

    def recv_into(
        self,
        buffer: WriteableBuffer,
        buffersize: int | None = None,
        flags: int | None = None,
    ) -> int:
        """Receive data into *buffer* and return the number of bytes stored.

        Objects exposing ``write()`` receive the data through that method. For
        other buffers, an unspecified size (``None`` or ``0``, as with
        ``socket.recv_into``) means "up to the buffer's length", and a larger
        request is clamped to that length so the data always fits.
        """
        if hasattr(buffer, "write"):
            return buffer.write(self.recv(buffersize))

        # buffer is a memoryview (or another sliceable writable buffer)
        capacity = len(buffer)
        if not buffersize or buffersize > capacity:
            buffersize = capacity

        data = self.recv(buffersize)
        if data:
            buffer[: len(data)] = data
        return len(data)

    def recv(self, buffersize: int, flags: int | None = None) -> bytes:
        """Read up to *buffersize* bytes of pending response data.

        Reads from the pipe registered in ``Mocket`` for this address when there
        is one, otherwise from :attr:`io`.

        Raises:
            BlockingIOError: with ``errno.EWOULDBLOCK`` when the buffer is empty.
        """
        r_fd, _ = Mocket.get_pair((self._host, self._port))
        if r_fd:
            return os.read(r_fd, buffersize)
        data = self.io.read(buffersize)
        if data:
            return data
        # used by Redis mock
        exc = BlockingIOError()
        exc.errno = errno.EWOULDBLOCK
        exc.args = (0,)
        raise exc

    @staticmethod
    def _strip_headers(req: str, headers) -> str:
        """Drop every ``<name>: ...`` CRLF-terminated line of *req* whose name is in *headers*.

        Names are matched literally (regex metacharacters are escaped), only at
        the start of a line, and case-insensitively.
        """
        for header in headers:
            pattern = rf"^{re.escape(header)}: .*\r\n"
            req = re.sub(pattern, "", req, flags=re.IGNORECASE | re.MULTILINE)
        return req

    def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> bytes:
        """Return the raw response for *data*, replayed from a recording or fetched live.

        The request is identified by a hash of its lines. If a recording
        directory is configured and the JSON file for the current namespace
        holds a response for this host/port/signature, that response is
        returned. Otherwise the data goes out through the real socket, the
        reply is read until the peer goes quiet or closes, and (when recording)
        the request/response pair is written back to the JSON file.

        Raises whatever ``MocketMode.raise_not_allowed()`` raises when the
        destination is not allowed.
        """
        if not MocketMode().is_allowed((self._host, self._port)):
            MocketMode.raise_not_allowed()

        req = decode_from_bytes(data)
        # make request unique again
        req_signature = _hash_request(hasher, req)
        # port should be always a string
        port = str(self._port)

        # prepare responses dictionary
        responses = {}

        recording_dir = Mocket.get_truesocket_recording_dir()
        if recording_dir:
            path = os.path.join(
                recording_dir,
                Mocket.get_namespace() + ".json",
            )
            # check if there's already a recorded session dumped to a JSON file
            try:
                with open(path) as f:
                    responses = json.load(f)
            # if not, create a new dictionary
            except (FileNotFoundError, JSONDecodeError):
                pass

        try:
            try:
                response_dict = responses[self._host][port][req_signature]
            except KeyError:
                if hasher is not hashlib.md5:
                    # Fallback for backwards compatibility
                    req_signature = _hash_request(hashlib.md5, req)
                    response_dict = responses[self._host][port][req_signature]
                else:
                    raise
        except KeyError:
            # preventing next KeyError exceptions
            response_dict = (
                responses.setdefault(self._host, {})
                .setdefault(port, {})
                .setdefault(req_signature, {})
            )

        # try to get the response from the dictionary
        try:
            encoded_response = hexload(response_dict["response"])
        # if not available, call the real sendall
        except KeyError:
            real_host = true_gethostbyname(self._host)

            with contextlib.suppress(OSError, ValueError):
                # already connected
                self._true_socket.connect((real_host, self._port))
            self._true_socket.sendall(data, *args, **kwargs)
            encoded_response = b""
            # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12
            while True:
                more_to_read = select.select([self._true_socket], [], [], 0.1)[0]
                # stop once something was received and the socket went quiet
                if not more_to_read and encoded_response:
                    break
                new_content = self._true_socket.recv(self._buflen)
                if not new_content:
                    break
                encoded_response += new_content

            # dump the resulting dictionary to a JSON file
            if recording_dir:
                # filter out unwanted headers
                req = self._strip_headers(req, Mocket.get_recording_ignored_headers())

                # update the dictionary with request and response lines
                response_dict["request"] = req
                response_dict["response"] = hexdump(encoded_response)

                with open(path, mode="w") as f:
                    f.write(
                        decode_from_bytes(
                            json.dumps(responses, indent=4, sort_keys=True)
                        )
                    )

        # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO
        return encoded_response

    def send(
        self,
        data: ReadableBuffer,
        *args: Any,
        **kwargs: Any,
    ) -> int:  # pragma: no cover
        """Send *data* and return its length.

        A chunk with no entry, or with an entry different from the previous
        call's, goes through :meth:`sendall`; otherwise it is appended to the
        last request when that object supports ``add_data``.
        """
        entry = self.get_entry(data)
        if not entry or self._entry != entry:
            kwargs["entry"] = entry
            self.sendall(data, *args, **kwargs)
        else:
            req = Mocket.last_request()
            if hasattr(req, "add_data"):
                req.add_data(data)
        self._entry = entry
        return len(data)

    def close(self) -> None:
        """Close the backing real socket if it is still open."""
        if self._true_socket and not self._true_socket._closed:
            self._true_socket.close()

    def __getattr__(self, name: str) -> Any:
        """Do nothing catchall function, for methods like shutdown()

        Dunder names are not faked: answering protocol probes such as
        ``__deepcopy__`` with a no-op would make the caller silently get ``None``.
        """
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)

        def do_nothing(*args: Any, **kwargs: Any) -> Any:
            pass

        return do_nothing
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc