• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mindflayer / python-mocket / 11930726772

20 Nov 2024 09:42AM UTC coverage: 99.255% (-0.07%) from 99.324%
11930726772

push

github

web-flow
Merge pull request #265 from betaboon/refactor-split-socket-and-ssl-socket

Refactor split socket and ssl socket

197 of 204 new or added lines in 6 files covered. (96.57%)

1 existing line in 1 file now uncovered.

933 of 940 relevant lines covered (99.26%)

6.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.58
/mocket/socket.py
1
from __future__ import annotations
7✔
2

3
import contextlib
7✔
4
import errno
7✔
5
import hashlib
7✔
6
import json
7✔
7
import os
7✔
8
import select
7✔
9
import socket
7✔
10
import ssl
7✔
11
from datetime import datetime, timedelta
7✔
12
from json.decoder import JSONDecodeError
7✔
13
from types import TracebackType
7✔
14
from typing import Any, Type
7✔
15

16
import urllib3.connection
7✔
17
from typing_extensions import Self
7✔
18

19
from mocket.compat import decode_from_bytes, encode_to_bytes
7✔
20
from mocket.entry import MocketEntry
7✔
21
from mocket.io import MocketSocketCore
7✔
22
from mocket.mocket import Mocket
7✔
23
from mocket.mode import MocketMode
7✔
24
from mocket.types import (
7✔
25
    Address,
26
    ReadableBuffer,
27
    WriteableBuffer,
28
    _PeerCertRetDictType,
29
    _RetAddress,
30
)
31
from mocket.utils import hexdump, hexload
7✔
32

33
# References to the genuine implementations, captured at import time so they
# remain reachable after Mocket monkey-patches the corresponding attributes.

# Name resolution helpers from the socket module.
true_getaddrinfo = socket.getaddrinfo
true_gethostbyname = socket.gethostbyname
true_gethostname = socket.gethostname
true_inet_pton = socket.inet_pton

# Socket constructors / factories.
true_create_connection = socket.create_connection
true_socket = socket.socket
true_socketpair = socket.socketpair

# urllib3's TLS hostname verification.
true_urllib3_match_hostname = urllib3.connection.match_hostname
7✔
41

42

43
# Pick the request-hashing function: xxh32 (from xxhash or xxhash_cffi) when
# either package is importable, otherwise hashlib.md5.
try:
    from xxhash import xxh32
except ImportError:  # pragma: no cover
    try:
        from xxhash_cffi import xxh32
    except ImportError:
        xxh32 = None
hasher = xxh32 or hashlib.md5
7✔
50

51

52
def mock_create_connection(address, timeout=None, source_address=None):
    """Stand-in for ``socket.create_connection()`` used while Mocket is active.

    Builds an IPv4/TCP socket through ``socket.socket`` (whatever that name is
    bound to at call time), applies *timeout* and connects it to *address*.

    :param address: ``(host, port)`` pair passed straight to ``connect()``.
    :param timeout: value forwarded to ``settimeout()``; ``None`` leaves the
        socket's default untouched.
    :param source_address: accepted for signature compatibility, ignored.
    :returns: the connected socket object.
    """
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
    # Compare against None rather than testing truthiness: a timeout of 0 is
    # meaningful (it requests non-blocking mode) and must not be dropped.
    if timeout is not None:
        s.settimeout(timeout)
    s.connect(address)
    return s
7✔
58

59

60
def mock_getaddrinfo(
    host: str,
    port: int,
    family: int = 0,
    type: int = 0,
    proto: int = 0,
    flags: int = 0,
) -> list[tuple[int, int, int, str, tuple[str, int]]]:
    """Stand-in for ``socket.getaddrinfo()`` that never performs a lookup.

    Always returns a single IPv4/TCP entry whose sockaddr is *host*/*port*
    unchanged. The *family*, *type*, *proto* and *flags* filters are accepted
    for signature compatibility but ignored.
    """
    # Named constants instead of the literals 2/1/6: the numeric values are
    # platform-dependent (SOCK_STREAM is not 1 on every OS), and the real
    # getaddrinfo() returns these enum members rather than bare ints.
    return [(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", (host, port))]
7✔
69

70

71
def mock_gethostbyname(hostname: str) -> str:
    """Resolve every *hostname* to the IPv4 loopback address."""
    loopback = "127.0.0.1"
    return loopback
7✔
73

74

75
def mock_gethostname() -> str:
    """Report a fixed host name regardless of the machine Mocket runs on."""
    name = "localhost"
    return name
7✔
77

78

79
def mock_inet_pton(address_family: int, ip_string: str) -> bytes:
    """Pretend every address parses, always yielding packed 127.0.0.1.

    Both arguments are ignored; the result is the 4-byte form of 127.0.0.1.
    """
    packed_loopback = b"\x7f\x00\x00\x01"
    return packed_loopback
7✔
81

82

83
def mock_socketpair(*args, **kwargs):
    """Hand out a genuine socket pair even while ``socket`` is patched.

    The asyncio loop uses socketpair(), and calls made by FastAPI and similar
    services depend on it, so this goes straight to the C-level ``_socket``
    module instead of the (possibly mocked) ``socket`` module.
    """
    import _socket as _native

    return _native.socketpair(*args, **kwargs)
7✔
88

89

90
def mock_urllib3_match_hostname(*args: Any, **kwargs: Any) -> None:
    """No-op replacement for ``urllib3.connection.match_hostname``.

    Accepts any positional or keyword arguments (e.g. the certificate, the
    hostname and ``hostname_checks_common_name``) and never raises, so
    hostname verification always "succeeds" while Mocket is active.
    """
    return None
×
92

93

94
def _hash_request(h, req):
    """Return the hex digest of *req* with its CRLF-separated lines sorted.

    Sorting the lines makes the signature independent of the order in which
    the request's lines were emitted. *h* is a hashlib-style constructor.
    """
    ordered_lines = sorted(req.split("\r\n"))
    payload = encode_to_bytes("".join(ordered_lines))
    digest = h(payload)
    return digest.hexdigest()
7✔
96

97

98
class MocketSocket:
    """Replacement for ``socket.socket`` used while Mocket is enabled.

    Data sent through :meth:`sendall` is matched against entries registered
    with :class:`Mocket`. When no entry matches, :meth:`true_sendall` forwards
    it to a real socket (unless :class:`MocketMode` forbids it). It can also
    record and replay such exchanges in a JSON file under the configured
    recording directory.
    """

    def __init__(
        self,
        family: socket.AddressFamily | int = socket.AF_INET,
        type: socket.SocketKind | int = socket.SOCK_STREAM,
        proto: int = 0,
        fileno: int | None = None,
        **kwargs: Any,
    ) -> None:
        # `fileno` is accepted for signature compatibility but not used.
        self._family = family
        self._type = type
        self._proto = proto

        self._kwargs = kwargs
        # A real socket is kept for the pass-through path in true_sendall().
        self._true_socket = true_socket(family, type, proto)

        self._buflen = 65536
        self._timeout: float | None = None

        self._host = None
        self._port = None
        self._address = None

        self._io = None
        self._entry = None

    def __str__(self) -> str:
        return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})"

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        type_: Type[BaseException] | None,  # noqa: UP006
        value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        self.close()

    @property
    def family(self) -> int:
        return self._family

    @property
    def type(self) -> int:
        return self._type

    @property
    def proto(self) -> int:
        return self._proto

    @property
    def io(self) -> MocketSocketCore:
        """Lazily created in-memory buffer holding responses for this address."""
        if self._io is None:
            self._io = MocketSocketCore((self._host, self._port))
        return self._io

    def fileno(self) -> int:
        """Return the read end of the pipe registered for this address.

        A new ``os.pipe()`` is created and registered with Mocket the first
        time a descriptor is requested for a given ``(host, port)``.
        """
        address = (self._host, self._port)
        r_fd, _ = Mocket.get_pair(address)
        if not r_fd:
            r_fd, w_fd = os.pipe()
            Mocket.set_pair(address, (r_fd, w_fd))
        return r_fd

    def gettimeout(self) -> float | None:
        return self._timeout

    def setsockopt(self, family: int, type: int, proto: int) -> None:
        """Forward a socket option to the underlying real socket.

        Despite their names (kept so existing callers keep working), the three
        positional arguments are the ``level``, ``optname`` and ``value`` of
        ``socket.socket.setsockopt``. They are deliberately NOT stored as this
        socket's family/type/proto: doing so corrupted those properties on
        ordinary calls such as ``setsockopt(IPPROTO_TCP, TCP_NODELAY, 1)``.
        """
        if self._true_socket:
            self._true_socket.setsockopt(family, type, proto)

    def settimeout(self, timeout: float | None) -> None:
        self._timeout = timeout

    @staticmethod
    def getsockopt(level: int, optname: int, buflen: int | None = None) -> int:
        """Always report ``socket.SOCK_STREAM``, whatever option is queried."""
        return socket.SOCK_STREAM

    def getpeername(self) -> _RetAddress:
        return self._address

    def setblocking(self, block: bool) -> None:
        """Blocking mode maps to no timeout, non-blocking to a zero timeout."""
        if block:
            self.settimeout(None)
        else:
            self.settimeout(0.0)

    def getblocking(self) -> bool:
        """Return False only in non-blocking mode, like ``socket.getblocking()``.

        As in the stdlib this is ``gettimeout() != 0``: a socket with a
        positive timeout still counts as blocking.
        """
        return self.gettimeout() != 0

    def getsockname(self) -> _RetAddress:
        # NOTE(review): fails with TypeError if called before connect(),
        # because self._address is still None.
        return true_gethostbyname(self._address[0]), self._address[1]

    def connect(self, address: Address) -> None:
        """Remember the target address; no real connection is opened here."""
        self._address = self._host, self._port = address
        Mocket._address = address

    def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketCore:
        """Return the shared in-memory buffer; *mode* and *bufsize* are ignored."""
        return self.io

    def get_entry(self, data: bytes) -> MocketEntry | None:
        return Mocket.get_entry(self._host, self._port, data)

    def sendall(self, data, entry=None, *args, **kwargs):
        """Route *data* to a matching entry, or to the real socket otherwise.

        When a response is produced it replaces the content of :attr:`io`,
        which is rewound so the next read starts at the beginning.
        """
        if entry is None:
            entry = self.get_entry(data)

        if entry:
            consume_response = entry.collect(data)
            response = entry.get_response() if consume_response is not False else None
        else:
            response = self.true_sendall(data, *args, **kwargs)

        if response is not None:
            self.io.seek(0)
            self.io.write(response)
            self.io.truncate()
            self.io.seek(0)

    def recv_into(
        self,
        buffer: WriteableBuffer,
        buffersize: int | None = None,
        flags: int | None = None,
    ) -> int:
        """Read into *buffer* and return the number of bytes stored.

        File-like buffers (anything with ``write``) receive the data through
        ``write()``. Other buffers (e.g. memoryview/bytearray) are filled in
        place; as with ``socket.recv_into()``, a *buffersize* of None or 0
        means "up to the buffer's length".
        """
        if hasattr(buffer, "write"):
            return buffer.write(self.recv(buffersize))

        if not buffersize:
            buffersize = len(buffer)

        data = self.recv(buffersize)
        if data:
            buffer[: len(data)] = data
        return len(data)

    def recv(self, buffersize: int, flags: int | None = None) -> bytes:
        """Read from the registered pipe if any, else from :attr:`io`.

        :raises BlockingIOError: (errno EWOULDBLOCK) when :attr:`io` has no
            data left.
        """
        r_fd, _ = Mocket.get_pair((self._host, self._port))
        if r_fd:
            return os.read(r_fd, buffersize)
        data = self.io.read(buffersize)
        if data:
            return data
        # used by Redis mock
        exc = BlockingIOError()
        exc.errno = errno.EWOULDBLOCK
        exc.args = (0,)
        raise exc

    def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> bytes:
        """Send *data* for real (or replay a recording) and return the response bytes.

        :raises: whatever ``MocketMode.raise_not_allowed()`` raises when the
            current mode forbids real traffic to this address.
        """
        if not MocketMode().is_allowed((self._host, self._port)):
            MocketMode.raise_not_allowed()

        req = decode_from_bytes(data)
        # make request unique again
        req_signature = _hash_request(hasher, req)
        # port should be always a string (it is a JSON object key)
        port_key = str(self._port)

        # prepare responses dictionary
        responses = {}

        # Query the recording directory once, so loading and dumping always
        # agree on whether recording is active and where the file lives.
        recording_dir = Mocket.get_truesocket_recording_dir()
        path = None
        if recording_dir:
            path = os.path.join(recording_dir, Mocket.get_namespace() + ".json")
            # check if there's already a recorded session dumped to a JSON file
            try:
                with open(path) as f:
                    responses = json.load(f)
            # if not, keep the empty dictionary
            except (FileNotFoundError, JSONDecodeError):
                pass

        try:
            try:
                response_dict = responses[self._host][port_key][req_signature]
            except KeyError:
                if hasher is not hashlib.md5:
                    # Fallback for backwards compatibility
                    req_signature = _hash_request(hashlib.md5, req)
                    response_dict = responses[self._host][port_key][req_signature]
                else:
                    raise
        except KeyError:
            # NOTE: when both lookups miss, req_signature holds the md5
            # signature, so new entries are keyed by md5 either way.
            responses.setdefault(self._host, {})
            responses[self._host].setdefault(port_key, {})
            responses[self._host][port_key].setdefault(req_signature, {})
            response_dict = responses[self._host][port_key][req_signature]

        # try to get the response from the dictionary
        try:
            encoded_response = hexload(response_dict["response"])
        # if not available, call the real sendall
        except KeyError:
            host = true_gethostbyname(self._host)

            with contextlib.suppress(OSError, ValueError):
                # already connected
                self._true_socket.connect((host, self._port))
            self._true_socket.sendall(data, *args, **kwargs)

            # Collect chunks in a list and join once (avoids quadratic +=).
            chunks = []
            # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12
            while True:
                more_to_read = select.select([self._true_socket], [], [], 0.1)[0]
                # Stop once data has arrived and nothing more is pending.
                if not more_to_read and chunks:
                    break
                new_content = self._true_socket.recv(self._buflen)
                if not new_content:
                    break
                chunks.append(new_content)
            encoded_response = b"".join(chunks)

            # dump the resulting dictionary to a JSON file
            if recording_dir:
                # update the dictionary with request and response lines
                response_dict["request"] = req
                response_dict["response"] = hexdump(encoded_response)

                with open(path, mode="w") as f:
                    f.write(
                        decode_from_bytes(
                            json.dumps(responses, indent=4, sort_keys=True)
                        )
                    )

        # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO
        return encoded_response

    def send(
        self,
        data: ReadableBuffer,
        *args: Any,
        **kwargs: Any,
    ) -> int:  # pragma: no cover
        """Send *data*, appending to the last request when the entry is unchanged.

        Always reports the full length of *data* as sent.
        """
        entry = self.get_entry(data)
        if not entry or self._entry != entry:
            kwargs["entry"] = entry
            self.sendall(data, *args, **kwargs)
        else:
            req = Mocket.last_request()
            if hasattr(req, "add_data"):
                req.add_data(data)
        self._entry = entry
        return len(data)

    def close(self) -> None:
        """Close the underlying real socket if it is still open."""
        if self._true_socket and not self._true_socket._closed:
            self._true_socket.close()

    def __getattr__(self, name: str) -> Any:
        """Do nothing catchall function, for methods like shutdown()

        Any attribute not found normally resolves to a no-op callable that
        accepts arbitrary arguments and returns None.
        """

        def do_nothing(*args: Any, **kwargs: Any) -> Any:
            pass

        return do_nothing
7✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc