• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mindflayer / python-mocket / 12013336921

25 Nov 2024 03:28PM UTC coverage: 99.025%. Remained the same
12013336921

Pull #273

github

web-flow
Merge 5405e2ea4 into a5b5e34b8
Pull Request #273: Target `make safetest` got broken

914 of 923 relevant lines covered (99.02%)

6.88 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

99.0
/mocket/socket.py
1
from __future__ import annotations
7✔
2

3
import contextlib
7✔
4
import errno
7✔
5
import hashlib
7✔
6
import json
7✔
7
import os
7✔
8
import select
7✔
9
import socket
7✔
10
from json.decoder import JSONDecodeError
7✔
11
from types import TracebackType
7✔
12
from typing import Any, Type
7✔
13

14
from typing_extensions import Self
7✔
15

16
from mocket.compat import decode_from_bytes, encode_to_bytes
7✔
17
from mocket.entry import MocketEntry
7✔
18
from mocket.io import MocketSocketIO
7✔
19
from mocket.mocket import Mocket
7✔
20
from mocket.mode import MocketMode
7✔
21
from mocket.types import (
7✔
22
    Address,
23
    ReadableBuffer,
24
    WriteableBuffer,
25
    _RetAddress,
26
)
27
from mocket.utils import hexdump, hexload
7✔
28

29
# Original socket callables, captured at import time. ``true_socket`` is what
# MocketSocket uses to open a genuine connection, and ``true_gethostbyname`` is
# used for genuine name resolution.
true_gethostbyname = socket.gethostbyname
true_socket = socket.socket


# Hash function for request signatures: xxh32 from ``xxhash`` (or the
# ``xxhash_cffi`` fallback) when either is importable, otherwise hashlib.md5.
try:
    from xxhash import xxh32
except ImportError:  # pragma: no cover
    xxh32 = None
    with contextlib.suppress(ImportError):
        from xxhash_cffi import xxh32

hasher = hashlib.md5 if xxh32 is None else xxh32
40

41

42
def mock_create_connection(address, timeout=None, source_address=None):
    """Stand-in for :func:`socket.create_connection`.

    Creates an ``AF_INET``/``SOCK_STREAM``/``IPPROTO_TCP`` socket through
    ``socket.socket``, applies *timeout* and connects it to *address*.

    :param address: ``(host, port)`` pair handed to ``connect``.
    :param timeout: passed to ``settimeout`` unless it is the standard
        library's "use the global default" sentinel
        (``socket._GLOBAL_DEFAULT_TIMEOUT``), exactly like the real
        ``create_connection``. ``0`` therefore selects non-blocking mode
        instead of being ignored.
    :param source_address: accepted for signature compatibility; ignored.
    :returns: the connected socket.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
    # Only the sentinel means "leave the timeout alone". A truthiness test
    # would both drop a legitimate 0 and forward the (truthy) sentinel object,
    # which callers such as http.client pass by default.
    if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT:
        s.settimeout(timeout)
    s.connect(address)
    return s
48

49

50
def mock_getaddrinfo(
    host: str,
    port: int,
    family: int = 0,
    type: int = 0,
    proto: int = 0,
    flags: int = 0,
) -> list[tuple[int, int, int, str, tuple[str, int]]]:
    """Stand-in for :func:`socket.getaddrinfo` that never performs a lookup.

    Always returns a single IPv4 TCP stream candidate pointing at the given
    *host* and *port* unchanged. ``family``, ``type``, ``proto`` and ``flags``
    are accepted for signature compatibility and ignored.
    """
    # Named constants instead of the literals 2/1/6: the numeric values are
    # platform-specific (e.g. SOCK_STREAM is 2 on Solaris/illumos), and the
    # real getaddrinfo returns these same AddressFamily/SocketKind members.
    return [(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", (host, port))]
59

60

61
def mock_gethostbyname(hostname: str) -> str:
    """Resolve any *hostname* to the IPv4 loopback address, without DNS."""
    loopback = "127.0.0.1"
    return loopback
63

64

65
def mock_gethostname() -> str:
    """Report a fixed hostname instead of querying the machine."""
    fixed_name = "localhost"
    return fixed_name
67

68

69
def mock_inet_pton(address_family: int, ip_string: str) -> bytes:
    """Pretend to pack *ip_string*; always yields the packed IPv4 loopback.

    Neither argument is inspected, so every family (including ``AF_INET6``)
    receives the same 4-byte value ``127.0.0.1``.
    """
    return b"\x7f\x00\x00\x01"
71

72

73
def mock_socketpair(*args, **kwargs):
    """Return a genuine connected socket pair from the low-level ``_socket`` module.

    This is deliberately not faked: asyncio event loops need a working
    socketpair to support calls made by fastapi and similar services.
    All arguments are forwarded unchanged.
    """
    from _socket import socketpair as real_socketpair

    return real_socketpair(*args, **kwargs)
78

79

80
def _hash_request(h, req):
    """Return the hex digest of *req*, computed with hash factory *h*.

    The request is split on CRLF, the pieces are sorted and re-joined without
    separators before hashing, so the signature does not depend on line order.
    """
    canonical = "".join(sorted(req.split("\r\n")))
    digest = h(encode_to_bytes(canonical))
    return digest.hexdigest()
82

83

84
class MocketSocket:
    """In-memory stand-in for :class:`socket.socket`.

    Data written with ``sendall``/``send`` is matched against Mocket entries
    via ``Mocket.get_entry``. A matching entry produces the response.
    Otherwise ``true_sendall`` forwards the request through a real socket,
    subject to ``MocketMode``, and can replay or record responses in a JSON
    file. Responses are buffered in a ``MocketSocketIO``, which ``recv``,
    ``recv_into`` and ``makefile`` read from. Any attribute not defined here
    resolves to a no-op callable (see ``__getattr__``).
    """

    def __init__(
        self,
        family: socket.AddressFamily | int = socket.AF_INET,
        type: socket.SocketKind | int = socket.SOCK_STREAM,
        proto: int = 0,
        fileno: int | None = None,
        **kwargs: Any,
    ) -> None:
        # ``fileno`` is accepted for signature compatibility with
        # socket.socket but is not used.
        self._family = family
        self._type = type
        self._proto = proto

        self._kwargs = kwargs
        # Real socket, used by true_sendall when traffic must hit the network.
        self._true_socket = true_socket(family, type, proto)

        # Chunk size for reads from the real socket in true_sendall.
        self._buflen = 65536
        self._timeout: float | None = None

        self._host = None
        self._port = None
        self._address = None

        # Response buffer, created lazily by the ``io`` property.
        self._io = None
        # Entry matched by the previous ``send`` call.
        self._entry = None

    def __str__(self) -> str:
        return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})"

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        type_: Type[BaseException] | None,  # noqa: UP006
        value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        self.close()

    @property
    def family(self) -> int:
        """Address family given at construction."""
        return self._family

    @property
    def type(self) -> int:
        """Socket type given at construction."""
        return self._type

    @property
    def proto(self) -> int:
        """Protocol number given at construction."""
        return self._proto

    @property
    def io(self) -> MocketSocketIO:
        """Response buffer for the connected address, created on first access."""
        if self._io is None:
            self._io = MocketSocketIO((self._host, self._port))
        return self._io

    def fileno(self) -> int:
        """Return the read end of the OS pipe registered for this address.

        On first use a pipe is created with ``os.pipe`` and registered through
        ``Mocket.set_pair``, presumably so select()-style callers get a real
        descriptor.
        """
        address = (self._host, self._port)
        r_fd, _ = Mocket.get_pair(address)
        if not r_fd:
            r_fd, w_fd = os.pipe()
            Mocket.set_pair(address, (r_fd, w_fd))
        return r_fd

    def gettimeout(self) -> float | None:
        """Return the timeout last recorded by ``settimeout``/``setblocking``."""
        return self._timeout

    def setsockopt(self, family: int, type: int, proto: int | ReadableBuffer) -> None:
        """Forward a socket option to the underlying real socket.

        Despite their historical names, the three positional arguments are the
        ``level``, ``optname`` and ``value`` of
        :meth:`socket.socket.setsockopt`. The names are kept for backward
        compatibility with keyword callers.
        """
        # Options say nothing about the socket's identity. Storing them into
        # _family/_type/_proto made e.g. setsockopt(IPPROTO_TCP, TCP_NODELAY, 1),
        # as issued by http.client, report family == IPPROTO_TCP. So they are
        # only forwarded.
        if self._true_socket:
            self._true_socket.setsockopt(family, type, proto)

    def settimeout(self, timeout: float | None) -> None:
        """Record *timeout*; it is not applied to any real socket."""
        self._timeout = timeout

    @staticmethod
    def getsockopt(level: int, optname: int, buflen: int | None = None) -> int:
        """Return ``socket.SOCK_STREAM`` whatever option is queried."""
        return socket.SOCK_STREAM

    def getpeername(self) -> _RetAddress:
        """Return the address passed to ``connect`` (``None`` before that)."""
        return self._address

    def setblocking(self, block: bool) -> None:
        """Blocking mode maps to timeout ``None``, non-blocking to ``0.0``."""
        if block:
            self.settimeout(None)
        else:
            self.settimeout(0.0)

    def getblocking(self) -> bool:
        """Return ``False`` only in non-blocking mode (timeout ``0``).

        This matches :meth:`socket.socket.getblocking`, which is equivalent to
        ``gettimeout() != 0``: a socket with a positive timeout still blocks.
        """
        return self.gettimeout() != 0

    def getsockname(self) -> _RetAddress:
        """Return the connected address with its host resolved by the real DNS.

        Raises TypeError if called before ``connect`` (``_address`` is None).
        """
        return true_gethostbyname(self._address[0]), self._address[1]

    def connect(self, address: Address) -> None:
        """Remember *address* locally and as ``Mocket._address``; no I/O happens."""
        self._address = self._host, self._port = address
        Mocket._address = address

    def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketIO:
        """Return the shared response buffer; *mode* and *bufsize* are ignored."""
        return self.io

    def get_entry(self, data: bytes) -> MocketEntry | None:
        """Look up the registered entry for this host/port matching *data*."""
        return Mocket.get_entry(self._host, self._port, data)

    def sendall(self, data, entry=None, *args, **kwargs):
        """Send *data* and load the resulting response into the IO buffer.

        The response comes from *entry*, which is looked up with ``get_entry``
        when not given. When no entry matches, it comes from ``true_sendall``.
        A non-None response overwrites the buffer, which is then rewound so
        the next read starts at the response.
        """
        if entry is None:
            entry = self.get_entry(data)

        if entry:
            consume_response = entry.collect(data)
            # When ``collect`` returns exactly False, no response is fetched
            # and the buffer is left untouched.
            response = entry.get_response() if consume_response is not False else None
        else:
            response = self.true_sendall(data, *args, **kwargs)

        if response is not None:
            self.io.seek(0)
            self.io.write(response)
            self.io.truncate()
            self.io.seek(0)

    def recv_into(
        self,
        buffer: WriteableBuffer,
        buffersize: int | None = None,
        flags: int | None = None,
    ) -> int:
        """Receive into *buffer* and return the number of bytes stored.

        File-like targets (anything with ``write``) receive ``recv(buffersize)``
        through their ``write`` method. Otherwise *buffer* is treated as a
        writable bytes-like object (e.g. a memoryview) and filled from the
        start.
        """
        if hasattr(buffer, "write"):
            return buffer.write(self.recv(buffersize))

        # Like socket.recv_into, a size of None *or 0* means "up to the
        # buffer's length".
        if not buffersize:
            buffersize = len(buffer)

        data = self.recv(buffersize)
        if data:
            buffer[: len(data)] = data
        return len(data)

    def recv(self, buffersize: int, flags: int | None = None) -> bytes:
        """Return up to *buffersize* bytes of the buffered response.

        If a pipe pair is registered for this address (see ``fileno``), its
        read end is read instead.

        :raises BlockingIOError: with errno ``EWOULDBLOCK`` when the buffer is
            empty.
        """
        r_fd, _ = Mocket.get_pair((self._host, self._port))
        if r_fd:
            return os.read(r_fd, buffersize)
        data = self.io.read(buffersize)
        if data:
            return data
        # used by Redis mock
        exc = BlockingIOError()
        exc.errno = errno.EWOULDBLOCK
        exc.args = (0,)
        raise exc

    def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> bytes:
        """Answer *data* from a recording or from the real server.

        ``MocketMode.raise_not_allowed`` is invoked when real traffic to this
        host/port is not permitted. When a recording directory is configured,
        responses are looked up in ``<dir>/<namespace>.json``, keyed by host,
        port (as a string) and a hash of the request. Misses are fetched over
        the real socket and written back to that file.

        :returns: the raw response bytes.
        """
        if not MocketMode().is_allowed((self._host, self._port)):
            MocketMode.raise_not_allowed()

        req = decode_from_bytes(data)
        # make request unique again: the hash ignores the order of CRLF lines
        req_signature = _hash_request(hasher, req)
        # JSON object keys are strings, so the port is always stored as one.
        port = str(self._port)

        # host -> port -> signature -> {"request": ..., "response": ...}
        responses = {}

        # Read the setting once so ``path`` is defined iff recording is on.
        recording_dir = Mocket.get_truesocket_recording_dir()
        path = None
        if recording_dir:
            path = os.path.join(recording_dir, Mocket.get_namespace() + ".json")
            # Reuse a previously dumped session; a missing or unparsable file
            # means starting from an empty dictionary.
            try:
                with open(path) as f:
                    responses = json.load(f)
            except (FileNotFoundError, JSONDecodeError):
                pass

        try:
            try:
                response_dict = responses[self._host][port][req_signature]
            except KeyError:
                if hasher is hashlib.md5:
                    raise
                # Fallback for backwards compatibility: recordings keyed by
                # an md5 signature.
                # NOTE(review): if this lookup also misses, req_signature stays
                # the md5 one, so new recordings are stored under md5 keys even
                # when xxh32 is available. Confirm this is intended.
                req_signature = _hash_request(hashlib.md5, req)
                response_dict = responses[self._host][port][req_signature]
        except KeyError:
            # Create the missing levels so the new response can be stored.
            response_dict = (
                responses.setdefault(self._host, {})
                .setdefault(port, {})
                .setdefault(req_signature, {})
            )

        try:
            encoded_response = hexload(response_dict["response"])
        except KeyError:
            # Nothing recorded for this request: perform it for real.
            host = true_gethostbyname(self._host)

            with contextlib.suppress(OSError, ValueError):
                # the real socket may already be connected
                self._true_socket.connect((host, self._port))
            self._true_socket.sendall(data, *args, **kwargs)
            encoded_response = b""
            # Read until the peer closes, or until 0.1s pass without new data
            # once at least one chunk has arrived. See
            # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12
            while True:
                more_to_read = select.select([self._true_socket], [], [], 0.1)[0]
                if not more_to_read and encoded_response:
                    break
                new_content = self._true_socket.recv(self._buflen)
                if not new_content:
                    break
                encoded_response += new_content

            # dump the updated dictionary to the JSON recording
            if path is not None:
                response_dict["request"] = req
                response_dict["response"] = hexdump(encoded_response)

                with open(path, mode="w") as f:
                    f.write(
                        decode_from_bytes(
                            json.dumps(responses, indent=4, sort_keys=True)
                        )
                    )

        # sendall() writes this into the Mocket buffer
        return encoded_response

    def send(
        self,
        data: ReadableBuffer,
        *args: Any,
        **kwargs: Any,
    ) -> int:  # pragma: no cover
        """Send *data*, merging consecutive writes aimed at the same entry.

        A new or absent entry goes through ``sendall``. A repeat of the entry
        used by the previous ``send`` appends *data* to
        ``Mocket.last_request()``, when that request has ``add_data``.

        :returns: ``len(data)``; the whole buffer is always reported as sent.
        """
        entry = self.get_entry(data)
        if not entry or self._entry != entry:
            kwargs["entry"] = entry
            self.sendall(data, *args, **kwargs)
        else:
            req = Mocket.last_request()
            if hasattr(req, "add_data"):
                req.add_data(data)
        self._entry = entry
        return len(data)

    def close(self) -> None:
        """Close the underlying real socket if it is still open."""
        if self._true_socket and not self._true_socket._closed:
            self._true_socket.close()

    def __getattr__(self, name: str) -> Any:
        """Do nothing catchall function, for methods like shutdown()"""

        def do_nothing(*args: Any, **kwargs: Any) -> Any:
            pass

        return do_nothing
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc