• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mindflayer / python-mocket / 11988860782

23 Nov 2024 04:57PM UTC coverage: 98.876% (-0.4%) from 99.255%
11988860782

Pull #268

github

web-flow
Merge 199e90344 into 0da27224a
Pull Request #268: Improve readability and backwards-compatibility of injection-code

101 of 105 new or added lines in 4 files covered. (96.19%)

7 existing lines in 2 files now uncovered.

968 of 979 relevant lines covered (98.88%)

6.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.57
/mocket/socket.py
1
from __future__ import annotations
7✔
2

3
import contextlib
7✔
4
import errno
7✔
5
import hashlib
7✔
6
import json
7✔
7
import os
7✔
8
import select
7✔
9
import socket
7✔
10
from json.decoder import JSONDecodeError
7✔
11
from types import TracebackType
7✔
12
from typing import Any, Type
7✔
13

14
import urllib3
7✔
15
from typing_extensions import Self
7✔
16

17
from mocket.compat import decode_from_bytes, encode_to_bytes
7✔
18
from mocket.entry import MocketEntry
7✔
19
from mocket.io import MocketSocketIO
7✔
20
from mocket.mocket import Mocket
7✔
21
from mocket.mode import MocketMode
7✔
22
from mocket.types import (
7✔
23
    Address,
24
    ReadableBuffer,
25
    WriteableBuffer,
26
    _RetAddress,
27
)
28
from mocket.utils import hexdump, hexload
7✔
29

30
# Genuine implementations, captured when this module is imported so that they
# stay reachable even if the mock_* functions below are swapped in later
# (MocketSocket itself uses true_socket_socket and true_socket_gethostbyname
# to talk to real servers).
true_socket_socket = socket.socket
true_socket_socket_type = socket.SocketType
true_socket_create_connection = socket.create_connection
true_socket_gethostname = socket.gethostname
true_socket_gethostbyname = socket.gethostbyname
true_socket_getaddrinfo = socket.getaddrinfo
true_socket_socketpair = socket.socketpair
true_socket_inet_pton = socket.inet_pton

# urllib3's hostname verification, kept for the same reason.
true_urllib3_match_hostname = urllib3.connection.match_hostname
39

40

41
# Hash constructor used to build request signatures (see _hash_request):
# xxhash's xxh32 when importable, then the xxhash_cffi build, and finally
# hashlib.md5 when neither package is available.
try:
    from xxhash import xxh32
except ImportError:  # pragma: no cover
    try:
        from xxhash_cffi import xxh32
    except ImportError:
        xxh32 = None
hasher = xxh32 or hashlib.md5
48

49

50
def mock_create_connection(address, timeout=None, source_address=None):
    """Stand-in for :func:`socket.create_connection`.

    Creates an IPv4 TCP socket through ``socket.socket`` (whatever that name
    is bound to at call time), applies ``timeout``, optionally binds
    ``source_address`` and connects to ``address``.

    :param address: ``(host, port)`` pair to connect to.
    :param timeout: timeout in seconds (``0`` means non-blocking). ``None`` or
        the stdlib ``socket._GLOBAL_DEFAULT_TIMEOUT`` sentinel leave the
        socket's default timeout untouched.
    :param source_address: optional ``(host, port)`` to bind before connecting.
    :returns: the connected socket.
    """
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
    try:
        # The sentinel is not a real timeout value (settimeout() would reject
        # it), while 0 / 0.0 are valid and must not be skipped as "falsy".
        if timeout is not None and timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
            sock.settimeout(timeout)
        if source_address:
            sock.bind(source_address)
        sock.connect(address)
    except BaseException:
        # Don't leak the socket when binding or connecting fails.
        sock.close()
        raise
    return sock
56

57

58
def mock_getaddrinfo(
    host: str,
    port: int,
    family: int = 0,
    type: int = 0,
    proto: int = 0,
    flags: int = 0,
) -> list[tuple[int, int, int, str, tuple[str, int]]]:
    """Stand-in for :func:`socket.getaddrinfo` that performs no resolution.

    Always returns a single IPv4/TCP entry whose sockaddr is ``(host, port)``
    unchanged. The ``family``, ``type``, ``proto`` and ``flags`` hints are
    accepted for signature compatibility and ignored.
    """
    # Use the platform constants rather than the literals 2/1/6: their numeric
    # values are not universal (on MIPS Linux SOCK_STREAM is 2 and 1 means
    # SOCK_DGRAM).
    return [
        (socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", (host, port)),
    ]
67

68

69
def mock_gethostbyname(hostname: str) -> str:
    """Resolve any ``hostname`` to the IPv4 loopback address without a lookup."""
    loopback = "127.0.0.1"
    return loopback
71

72

73
def mock_gethostname() -> str:
    """Report a fixed host name, ``"localhost"``."""
    fixed_name = "localhost"
    return fixed_name
75

76

77
def mock_inet_pton(address_family: int, ip_string: str) -> bytes:
    """Pack every address as the 4-byte form of 127.0.0.1.

    Both ``address_family`` and ``ip_string`` are ignored.
    """
    return b"\x7f\x00\x00\x01"
79

80

81
def mock_socketpair(*args, **kwargs):
    """Return a genuine connected socket pair from the low-level ``_socket`` module.

    asyncio event loops need a working socketpair (e.g. for calls made by
    fastapi and similar services); going through ``_socket`` directly
    presumably sidesteps any patched ``socket`` attributes.
    """
    import _socket

    make_pair = _socket.socketpair
    return make_pair(*args, **kwargs)
86

87

88
def mock_urllib3_match_hostname(*args: Any) -> None:
    """Replacement for urllib3's ``match_hostname``: performs no check, returns None."""
90

91

92
def _hash_request(h, req):
    """Return a hex digest of ``req`` that does not depend on its line order.

    The request is split on CRLF, the lines are sorted and concatenated, and
    the result is hashed with the hash constructor ``h`` (e.g. ``hasher`` or
    ``hashlib.md5``).
    """
    canonical = "".join(sorted(req.split("\r\n")))
    digest = h(encode_to_bytes(canonical))
    return digest.hexdigest()
94

95

96
class MocketSocket:
    """Fake socket object mirroring the :class:`socket.socket` interface.

    Outgoing data is matched against the entries registered with ``Mocket``.
    A matching entry produces the response. Otherwise the data is sent over a
    real socket by :meth:`true_sendall`, with optional record/replay through a
    JSON file. Responses are served back from an in-memory buffer (:attr:`io`)
    or from a pipe registered with ``Mocket``.
    """

    def __init__(
        self,
        family: socket.AddressFamily | int = socket.AF_INET,
        type: socket.SocketKind | int = socket.SOCK_STREAM,
        proto: int = 0,
        fileno: int | None = None,
        **kwargs: Any,
    ) -> None:
        self._family = family
        self._type = type
        self._proto = proto

        self._kwargs = kwargs
        # Real socket used when traffic is passed through (see true_sendall).
        self._true_socket = true_socket_socket(family, type, proto)

        # Chunk size used when reading from the real socket.
        self._buflen = 65536
        self._timeout: float | None = None

        # Set by connect().
        self._host = None
        self._port = None
        self._address = None

        # Response buffer, created lazily by the ``io`` property.
        self._io = None
        # Entry matched by the previous send() call.
        self._entry = None

    def __str__(self) -> str:
        return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})"

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        type_: Type[BaseException] | None,  # noqa: UP006
        value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        self.close()

    @property
    def family(self) -> int:
        return self._family

    @property
    def type(self) -> int:
        return self._type

    @property
    def proto(self) -> int:
        return self._proto

    @property
    def io(self) -> MocketSocketIO:
        """In-memory buffer holding the latest response, created on first access."""
        if self._io is None:
            self._io = MocketSocketIO((self._host, self._port))
        return self._io

    def fileno(self) -> int:
        """Return the read end of the pipe Mocket keeps for this address.

        The pipe is created with :func:`os.pipe` and registered through
        ``Mocket.set_pair`` on first use.
        """
        address = (self._host, self._port)
        r_fd, _ = Mocket.get_pair(address)
        if not r_fd:
            r_fd, w_fd = os.pipe()
            Mocket.set_pair(address, (r_fd, w_fd))
        return r_fd

    def gettimeout(self) -> float | None:
        return self._timeout

    def setsockopt(self, family: int, type: int, proto: int) -> None:
        """Forward a socket option to the underlying real socket.

        Despite their names, the three positional parameters are ``level``,
        ``optname`` and ``value`` as in :meth:`socket.socket.setsockopt`. The
        names are kept only so existing keyword callers keep working.

        The values are deliberately NOT stored as this socket's
        family/type/proto. Doing so made, for example,
        ``setsockopt(IPPROTO_TCP, TCP_NODELAY, 1)`` report ``family == 6``.
        """
        if self._true_socket:
            self._true_socket.setsockopt(family, type, proto)

    def settimeout(self, timeout: float | None) -> None:
        self._timeout = timeout

    @staticmethod
    def getsockopt(level: int, optname: int, buflen: int | None = None) -> int:
        # Every option reads back as SOCK_STREAM, whatever was asked for.
        return socket.SOCK_STREAM

    def getpeername(self) -> _RetAddress:
        return self._address

    def setblocking(self, block: bool) -> None:
        # Blocking means "no timeout"; non-blocking means a 0.0 timeout.
        self.settimeout(None if block else 0.0)

    def getblocking(self) -> bool:
        """Return False only in non-blocking mode, i.e. when the timeout is 0.

        Like :meth:`socket.socket.getblocking`, this is ``gettimeout() != 0``.
        A socket with a positive timeout is still in (timed) blocking mode.
        """
        return self.gettimeout() != 0

    def getsockname(self) -> _RetAddress:
        # Resolves the connected host with the real gethostbyname().
        return true_socket_gethostbyname(self._address[0]), self._address[1]

    def connect(self, address: Address) -> None:
        # No real connection is made: only the target address is remembered.
        self._address = self._host, self._port = address
        Mocket._address = address

    def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketIO:
        return self.io

    def get_entry(self, data: bytes) -> MocketEntry | None:
        return Mocket.get_entry(self._host, self._port, data)

    def sendall(self, data, entry=None, *args, **kwargs):
        """Deliver ``data`` to the matching Mocket entry, or to the real server.

        ``entry`` is looked up from ``data`` when not given. A matching entry
        collects the data and supplies the response, unless ``collect()``
        returns False, in which case nothing is queued. Without an entry, the
        data goes through :meth:`true_sendall`.

        A response, if any, replaces the contents of :attr:`io`, and the read
        position is rewound to the start.
        """
        if entry is None:
            entry = self.get_entry(data)

        if entry:
            consume_response = entry.collect(data)
            response = entry.get_response() if consume_response is not False else None
        else:
            response = self.true_sendall(data, *args, **kwargs)

        if response is not None:
            self.io.seek(0)
            self.io.write(response)
            self.io.truncate()
            self.io.seek(0)

    def recv_into(
        self,
        buffer: WriteableBuffer,
        buffersize: int | None = None,
        flags: int | None = None,
    ) -> int:
        """Receive data into ``buffer`` and return the number of bytes stored.

        Objects with a ``write`` method are written to directly. For writable
        buffers such as a memoryview, a ``buffersize`` of None or 0 means "up
        to ``len(buffer)``", as with :meth:`socket.socket.recv_into`.
        """
        if hasattr(buffer, "write"):
            return buffer.write(self.recv(buffersize))

        # buffer is a memoryview
        if not buffersize:
            buffersize = len(buffer)

        data = self.recv(buffersize)
        if data:
            buffer[: len(data)] = data
        return len(data)

    def recv(self, buffersize: int, flags: int | None = None) -> bytes:
        """Return up to ``buffersize`` bytes of the pending response.

        Reads from the pipe registered with Mocket for this address when one
        exists, otherwise from :attr:`io`.

        :raises BlockingIOError: with errno EWOULDBLOCK when :attr:`io` has
            nothing to return.
        """
        r_fd, _ = Mocket.get_pair((self._host, self._port))
        if r_fd:
            return os.read(r_fd, buffersize)
        data = self.io.read(buffersize)
        if data:
            return data
        # used by Redis mock
        exc = BlockingIOError()
        exc.errno = errno.EWOULDBLOCK
        exc.args = (0,)
        raise exc

    def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> bytes:
        """Send ``data`` to the real server (or replay a recording) and return the response.

        When a recording directory is configured, ``<dir>/<namespace>.json`` is
        consulted first. Recordings are keyed by host, port (as a string) and
        a signature of the request. A request missing from the file is sent
        over the real socket, and its request/response pair (response
        hex-dumped) is written back. Extra ``args``/``kwargs`` are forwarded
        to the real ``sendall``.

        Calls ``MocketMode.raise_not_allowed()`` when real traffic to this
        address is not allowed.
        """
        if not MocketMode().is_allowed((self._host, self._port)):
            MocketMode.raise_not_allowed()

        req = decode_from_bytes(data)
        # make request unique again
        req_signature = _hash_request(hasher, req)
        # port should be always a string
        port = str(self._port)

        recording_dir = Mocket.get_truesocket_recording_dir()
        path = None
        # prepare responses dictionary
        responses = {}
        if recording_dir:
            path = os.path.join(recording_dir, Mocket.get_namespace() + ".json")
            # reuse a session previously dumped to the JSON file, if any
            try:
                with open(path, encoding="utf-8") as f:
                    responses = json.load(f)
            except (FileNotFoundError, JSONDecodeError):
                pass

        try:
            try:
                response_dict = responses[self._host][port][req_signature]
            except KeyError:
                if hasher is hashlib.md5:
                    raise
                # Fallback for backwards compatibility with md5-keyed recordings
                req_signature = _hash_request(hashlib.md5, req)
                response_dict = responses[self._host][port][req_signature]
        except KeyError:
            # Not recorded yet: create the empty nested entry.
            # NOTE(review): when hasher is not md5, req_signature now holds the
            # md5 signature, so new recordings are keyed by md5. This is
            # unchanged from before; presumably it is deliberate so that
            # recordings stay readable without xxhash. Confirm before changing.
            response_dict = (
                responses.setdefault(self._host, {})
                .setdefault(port, {})
                .setdefault(req_signature, {})
            )

        # try to get the response from the dictionary
        try:
            encoded_response = hexload(response_dict["response"])
        # if not available, call the real sendall
        except KeyError:
            encoded_response = self._exchange_with_true_socket(data, *args, **kwargs)

            if path is not None:
                # update the dictionary with request and response lines,
                # then dump the whole dictionary to the JSON file
                response_dict["request"] = req
                response_dict["response"] = hexdump(encoded_response)
                with open(path, mode="w", encoding="utf-8") as f:
                    f.write(
                        decode_from_bytes(
                            json.dumps(responses, indent=4, sort_keys=True)
                        )
                    )

        # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO
        return encoded_response

    def _exchange_with_true_socket(
        self, data: ReadableBuffer, *args: Any, **kwargs: Any
    ) -> bytes:
        """Send ``data`` over the real socket and return all bytes received back.

        Reading stops when the peer closes the connection (``recv`` returns
        ``b""``). It also stops when no more data is readable within 0.1s
        after at least some data has arrived.
        """
        host = true_socket_gethostbyname(self._host)
        with contextlib.suppress(OSError, ValueError):
            # already connected
            self._true_socket.connect((host, self._port))
        self._true_socket.sendall(data, *args, **kwargs)

        chunks: list[bytes] = []
        # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12
        while True:
            readable = select.select([self._true_socket], [], [], 0.1)[0]
            if not readable and chunks:
                break
            chunk = self._true_socket.recv(self._buflen)
            if not chunk:
                break
            chunks.append(chunk)
        return b"".join(chunks)

    def send(
        self,
        data: ReadableBuffer,
        *args: Any,
        **kwargs: Any,
    ) -> int:  # pragma: no cover
        """Send ``data`` and report it as fully sent (returns ``len(data)``).

        Consecutive chunks matching the same entry are appended to the last
        request (via ``add_data``) instead of being dispatched again.
        """
        entry = self.get_entry(data)
        if not entry or self._entry != entry:
            # sendall()'s second positional parameter is ``entry``. Passing it
            # as a keyword while a positional ``flags`` argument is also
            # forwarded raised "got multiple values for argument 'entry'".
            kwargs.pop("entry", None)
            self.sendall(data, entry, *args, **kwargs)
        else:
            req = Mocket.last_request()
            if hasattr(req, "add_data"):
                req.add_data(data)
        self._entry = entry
        return len(data)

    def close(self) -> None:
        """Close the underlying real socket if it is still open."""
        if self._true_socket and not self._true_socket._closed:
            self._true_socket.close()

    def __getattr__(self, name: str) -> Any:
        """Do nothing catchall function, for methods like shutdown()"""

        def do_nothing(*args: Any, **kwargs: Any) -> Any:
            pass

        return do_nothing
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc