• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

Synss / python-mbedtls / 15400568096

02 Jun 2025 07:03PM UTC coverage: 85.758% (-1.7%) from 87.47%
15400568096

Pull #127

github

pre-commit-ci[bot]
[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
Pull Request #127: [pre-commit.ci] pre-commit autoupdate

2523 of 2942 relevant lines covered (85.76%)

0.86 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

59.17
/src/mbedtls/tls.py
1
# SPDX-License-Identifier: MIT
2
# Copyright (c) 2018, Mathias Laurin
3

4
from __future__ import annotations
1✔
5

6
import enum
1✔
7
import errno
1✔
8
import os
1✔
9
import socket as _pysocket
1✔
10
import struct
1✔
11
import sys
1✔
12
from contextlib import suppress
1✔
13
from typing import Any, Final, NoReturn, Optional, Tuple, Union, cast, overload
14

15
from ._tls import HandshakeStep as HandshakeStep
1✔
16
from ._tls import HelloVerifyRequest
1✔
17
from ._tls import MbedTLSBuffer as TLSWrappedBuffer
1✔
18
from ._tls import (
1✔
19
    Purpose,
20
    RaggedEOF,
21
    TLSSession,
22
    TrustStore,
23
    WantReadError,
24
    WantWriteError,
25
    _BaseContext,
26
    _tls_from_version,
27
    _tls_to_version,
28
    ciphers_available,
29
)
30
from ._tlsi import DTLSConfiguration as DTLSConfiguration
1✔
31
from ._tlsi import DTLSVersion as DTLSVersion
1✔
32
from ._tlsi import MaxFragmentLength as MaxFragmentLength
1✔
33
from ._tlsi import NextProtocol as NextProtocol
1✔
34
from ._tlsi import PrivateKey as PrivateKey
1✔
35
from ._tlsi import ServerNameCallback as ServerNameCallback
1✔
36
from ._tlsi import TLSConfiguration as TLSConfiguration
1✔
37
from ._tlsi import TLSVersion as TLSVersion
1✔
38

39
if sys.version_info < (3, 10):
40
    from typing_extensions import TypeAlias
41
else:
42
    from typing import TypeAlias
43

44

45
# Public names of this module, listed in sorted (ASCII) order.
__all__ = (
    "ClientContext", "DTLSConfiguration", "DTLSVersion",
    "HelloVerifyRequest", "MaxFragmentLength", "NextProtocol",
    "PrivateKey", "Purpose", "RaggedEOF",
    "ServerContext", "ServerNameCallback", "TLSConfiguration",
    "TLSRecordHeader", "TLSSession", "TLSVersion",
    "TLSWrappedBuffer", "TLSWrappedSocket", "TrustStore",
    "WantReadError", "WantWriteError", "ciphers_available",
)

# Socket address type; the definition is copied from `_Address` in
# typeshed's `_socket.pyi`.
_Address: TypeAlias = Union[Tuple[Any, ...], str]
# Destination buffers accepted by the `recv_into`/`recvfrom_into` methods.
_WriteableBuffer: TypeAlias = Union[memoryview, bytearray]
1✔
72

73

74
class TLSRecordHeader:
    """Encode/decode the TLS record protocol header.

    The header is five bytes packed with the ``struct`` format ``"!BHH"``:
    a one-byte record type, a two-byte protocol version (major byte, then
    minor byte) and a two-byte length.
    """

    __slots__ = ("record_type", "version", "length")
    fmt = "!BHH"

    class RecordType(enum.IntEnum):
        """Record content types understood by this header class."""

        CHANGE_CIPHER_SPEC = 0x14
        ALERT = 0x15
        HANDSHAKE = 0x16
        APPLICATION_DATA = 0x17

    def __init__(
        self,
        record_type: Union[int, TLSRecordHeader.RecordType],
        version: Union[int, Tuple[int, int], TLSVersion],
        length: int,
    ) -> None:
        """Create a record header.

        Args:
            record_type: The record type; converted with ``RecordType``.
            version: A ``TLSVersion``, an ``int`` passed to ``TLSVersion``,
                or a ``(major, minor)`` pair.
            length: The value of the length field, stored as given.

        Raises:
            ValueError: If ``record_type`` is not a ``RecordType`` value.

        """

        def parse_version(
            v: Union[int, Tuple[int, int], TLSVersion],
        ) -> TLSVersion:
            # ``TLSVersion`` is checked first so that an enum member is
            # returned unchanged rather than going through the int path.
            if isinstance(v, TLSVersion):
                return v
            if isinstance(v, int):
                return TLSVersion(v)
            # Map a (major, minor) pair the same way ``from_bytes`` does.
            # The former ``((v[0] & 0xFF) << 8) + v[1] & 0xFF`` parsed as
            # ``(... + v[1]) & 0xFF`` (``+`` binds tighter than ``&``) and
            # therefore kept only the low byte, losing the major version.
            return _tls_to_version((v[0] & 0xFF, v[1] & 0xFF))

        self.record_type: Final = TLSRecordHeader.RecordType(record_type)
        self.version: Final = parse_version(version)
        self.length: Final = length

    def __str__(self) -> str:
        return "{}({}, {}, {})".format(
            type(self).__name__,
            self.record_type,
            self.version,
            self.length,
        )

    def __repr__(self) -> str:
        return "{}({!r}, {!r}, {!r})".format(
            type(self).__name__,
            self.record_type,
            self.version,
            self.length,
        )

    def __eq__(self, other: object) -> bool:
        """Compare record type, version and length field by field."""
        if not isinstance(other, TLSRecordHeader):
            return NotImplemented
        return (
            self.record_type is other.record_type
            and self.version is other.version
            and self.length == other.length
        )

    def __hash__(self) -> int:
        # Mixes the same three fields that ``__eq__`` compares, so equal
        # headers hash equally.
        return 0x5AFE ^ self.record_type ^ self.version.value ^ self.length

    def __len__(self) -> int:
        """Return the size in bytes of the encoded header (always 5)."""
        return 5

    def __bytes__(self) -> bytes:
        """Pack the header into its 5-byte wire representation."""
        maj, min_ = _tls_from_version(self.version)
        version = ((maj & 0xFF) << 8) + (min_ & 0xFF)
        return struct.pack(
            TLSRecordHeader.fmt,
            self.record_type,
            version,
            self.length,
        )

    @classmethod
    def from_bytes(cls, header: bytes) -> TLSRecordHeader:
        """Parse a header from the first five bytes of ``header``.

        Raises:
            struct.error: If ``header`` is shorter than five bytes.
            ValueError: If the record type is not a ``RecordType`` value.

        """
        record_type, maj_min_version, length = struct.unpack(
            TLSRecordHeader.fmt, header[:5]
        )
        maj, min_ = (maj_min_version >> 8) & 0xFF, maj_min_version & 0xFF
        return cls(
            TLSRecordHeader.RecordType(record_type),
            _tls_to_version((maj, min_)),
            length,
        )
157

158

159
class ClientContext(_BaseContext):
    """Client-side context (PEP 543 ``ClientContext``)."""

    # _pep543.ClientContext

    @property
    def _purpose(self) -> Purpose:
        # The purpose this library associates with client contexts.
        return Purpose.CLIENT_AUTH

    def wrap_buffers(self, server_hostname: Optional[str]) -> TLSWrappedBuffer:
        """Return a new in-memory TLS stream for this context (PEP 543).

        Args:
            server_hostname: Hostname handed to the buffer; ``None`` opts
                out of hostname validation.

        """
        return TLSWrappedBuffer(self, server_hostname)

    def wrap_socket(
        self, socket: _pysocket.socket, server_hostname: Optional[str]
    ) -> TLSWrappedSocket:
        """Return ``socket`` wrapped in a ``TLSWrappedSocket``.

        Only ``SOCK_STREAM`` sockets are supported by this method.

        Args:
            socket: The Python socket to wrap.
            server_hostname: Hostname of the service being connected to,
                or ``None`` to skip hostname validation.  There is no
                default on purpose: turning hostname validation off is
                dangerous and has to be an explicit choice.

        """
        return TLSWrappedSocket(socket, self.wrap_buffers(server_hostname))
1✔
189

190

191
class ServerContext(_BaseContext):
    """Server-side context (PEP 543 ``ServerContext``)."""

    # _pep543.ServerContext

    @property
    def _purpose(self) -> Purpose:
        # The purpose this library associates with server contexts.
        return Purpose.SERVER_AUTH

    def wrap_buffers(self) -> TLSWrappedBuffer:
        """Return a new in-memory TLS stream for this context (PEP 543)."""
        return TLSWrappedBuffer(self)

    def wrap_socket(self, socket: _pysocket.socket) -> TLSWrappedSocket:
        """Return ``socket`` wrapped in a ``TLSWrappedSocket``.

        Args:
            socket: The Python socket to wrap.

        """
        return TLSWrappedSocket(socket, self.wrap_buffers())
1✔
206

207

208
class TLSWrappedSocket:
    """Socket wrapper that moves application data through a TLS buffer.

    Implements the PEP 543 ``TLSWrappedSocket`` interface.  Data to send is
    written to the ``TLSWrappedBuffer`` and the bytes it produces are sent
    on the wrapped socket.  Bytes received from the socket are fed to the
    buffer, and the buffer's decrypted output is returned to the caller.
    """

    # pylint: disable=too-many-instance-attributes
    # _pep543.TLSWrappedSocket

    # Read size used by ``do_handshake`` and by the DTLS peek in ``accept``.
    CHUNK_SIZE: Final = 4096

    def __init__(
        self, socket: _pysocket.socket, buffer: TLSWrappedBuffer
    ) -> None:
        super().__init__()
        # The ``_socket``/``_buffer`` property setters also copy the socket
        # API and the buffer's TLS accessors onto this instance.
        self._socket = socket
        self._buffer = buffer
        self._context = buffer.context
        self._closed = False

    @property
    def context(self) -> _BaseContext:
        """The context of the underlying TLS buffer."""
        return self._buffer.context

    @property
    def _buffer(self) -> TLSWrappedBuffer:
        return cast(TLSWrappedBuffer, self.__dict__["_buffer"])

    @_buffer.setter
    def _buffer(self, __buffer: TLSWrappedBuffer) -> None:
        self.__dict__["_buffer"] = __buffer
        # Expose the buffer's TLS accessors directly on the socket.
        self.setcookieparam = __buffer.setcookieparam
        self.cipher = __buffer.cipher
        self.negotiated_protocol = __buffer.negotiated_protocol
        self.negotiated_tls_version = __buffer.negotiated_tls_version

    @property
    def _socket(self) -> _pysocket.socket:
        return cast(_pysocket.socket, self.__dict__["_socket"])

    @_socket.setter
    def _socket(self, __socket: _pysocket.socket) -> None:
        self.__dict__["_socket"] = __socket
        # PEP 543 requires the full socket API.
        self.family = __socket.family
        self.proto = __socket.proto
        self.type = __socket.type
        self.bind = __socket.bind
        self.connect = __socket.connect
        self.connect_ex = __socket.connect_ex
        self.fileno = __socket.fileno
        self.getpeername = __socket.getpeername
        self.getsockname = __socket.getsockname
        self.getsockopt = __socket.getsockopt
        self.listen = __socket.listen
        self.makefile = __socket.makefile
        self.setblocking = __socket.setblocking
        self.settimeout = __socket.settimeout
        self.gettimeout = __socket.gettimeout
        self.setsockopt = __socket.setsockopt

    def __getstate__(self) -> NoReturn:
        raise TypeError(f"cannot pickle {self.__class__.__name__!r} object")

    def __enter__(self) -> TLSWrappedSocket:
        return self

    def __exit__(self, *exc_info: object) -> None:
        if not self._closed:
            self.close()

    def __str__(self) -> str:
        return str(self._socket)

    @property
    def _handshake_state(self) -> HandshakeStep:
        # pylint: disable=protected-access
        return self._buffer._handshake_state

    def setmtu(self, mtu: int) -> None:
        """Set the Maximum Transmission Unit (MTU) for DTLS.

        Set to zero to unset.

        Raises:
            OverflowError: If value cannot be converted to UInt16.

        """
        self._buffer.setmtu(mtu)

    def accept(self) -> Tuple[TLSWrappedSocket, _Address]:
        """Accept a connection and return ``(wrapped_conn, address)``.

        For ``SOCK_STREAM`` sockets this delegates to ``socket.accept()``.
        For other socket types (DTLS), the pending datagram is peeked to
        learn the client address.  A duplicate of the current socket is
        connected to that client and handed out, and this object re-binds a
        fresh socket to the same local address for the next client.
        """
        if self.type == _pysocket.SOCK_STREAM:
            conn, address = self._socket.accept()
        else:
            _, address = self._socket.recvfrom(
                TLSWrappedSocket.CHUNK_SIZE, _pysocket.MSG_PEEK
            )
            # Use this socket to communicate with the client and bind
            # another one for the next connection.  This procedure is
            # adapted from `mbedtls_net_accept()`.
            sockname = self.getsockname()
            conn = _pysocket.fromfd(self.fileno(), self.family, self.type)
            conn.connect(address)
            # Capture the socket parameters before closing it.
            family, type_, proto = self.family, self.type, self.proto
            self.close()
            self._socket = _pysocket.socket(family, type_, proto)
            # ``close()`` marked this object closed, but it now owns a new,
            # open socket.  Clear the flag so ``__exit__`` still closes it.
            self._closed = False
            self.setsockopt(_pysocket.SOL_SOCKET, _pysocket.SO_REUSEADDR, 1)
            self.bind(sockname)
        if isinstance(self.context, ClientContext):
            # pylint: disable=protected-access
            # Probably not very useful but there is not reason to forbid it.
            return (
                self.context.wrap_socket(conn, self._buffer._server_hostname),
                address,
            )
        assert isinstance(self.context, ServerContext)
        return self.context.wrap_socket(conn), address

    def close(self) -> None:
        """Shut down the TLS buffer and close the underlying socket.

        The socket is closed even if shutting down the buffer raises.
        """
        self._closed = True
        try:
            self._buffer.shutdown()
        finally:
            self._socket.close()

    def recv(self, bufsize: int, flags: int = 0) -> bytes:
        """Receive up to ``bufsize`` bytes and return the decrypted data.

        Returns ``b""`` when the underlying ``recv`` returns no data.
        """
        encrypted = self._socket.recv(bufsize, flags)
        if not encrypted:
            return b""
        self._buffer.receive_from_network(encrypted)
        return self._buffer.read(bufsize)

    def recv_into(
        self,
        buffer: _WriteableBuffer,
        nbytes: Optional[int] = None,
        flags: int = 0,
    ) -> int:
        """Receive into ``buffer`` and return the number of bytes stored.

        As with ``socket.recv_into``, an omitted or zero ``nbytes`` means
        "up to the size of ``buffer``".  Returns 0 when the underlying
        ``recv`` returns no data.
        """
        bufsize = min(len(buffer), nbytes) if nbytes else len(buffer)
        encrypted = self._socket.recv(bufsize, flags)
        if not encrypted:
            return 0
        self._buffer.receive_from_network(encrypted)
        return self._buffer.readinto(buffer, bufsize)

    def recvfrom(self, bufsize: int, flags: int = 0) -> Tuple[bytes, _Address]:
        """Like ``recv`` but also return the sender's address."""
        encrypted, addr = self._socket.recvfrom(bufsize, flags)
        if not encrypted:
            return b"", addr
        self._buffer.receive_from_network(encrypted)
        return self._buffer.read(bufsize), addr

    def recvfrom_into(
        self,
        buffer: _WriteableBuffer,
        nbytes: Optional[int] = None,
        flags: int = 0,
    ) -> Tuple[int, _Address]:
        """Like ``recv_into`` but also return the sender's address.

        An omitted or zero ``nbytes`` means "up to the size of ``buffer``".
        Larger values are clamped to ``len(buffer)`` so the read never
        requests more than the buffer can hold.
        """
        bufsize = min(len(buffer), nbytes) if nbytes else len(buffer)
        encrypted, addr = self._socket.recvfrom(bufsize, flags)
        if not encrypted:
            return 0, addr
        self._buffer.receive_from_network(encrypted)
        return self._buffer.readinto(buffer, bufsize), addr

    def send(self, message: bytes, flags: int = 0) -> int:
        """Write ``message`` to the TLS buffer and send its output.

        Returns ``len(message)``.  Only the bytes the socket actually
        accepted are consumed from the TLS buffer.  On a partial
        ``socket.send`` the remainder therefore stays queued instead of
        being silently dropped, which would corrupt the record stream.
        """
        # Maximum size supported by TLS is 16K (encrypted).
        # mbedTLS defines it in MBEDTLS_SSL_MAX_CONTENT_LEN and
        # MBEDTLS_SSL_IN_CONTENT_LEN/MBEDTLS_SSL_OUT_CONTENT_LEN.
        amt = self._buffer.write(message)
        encrypted = self._buffer.peek_outgoing(amt)
        sent = self._socket.send(encrypted, flags)
        # NOTE(review): queued leftovers go out with the next peek_outgoing;
        # confirm against the _tls buffer's write/peek semantics.
        self._buffer.consume_outgoing(sent)
        return len(message)

    def sendall(self, message: bytes, flags: int = 0) -> None:
        """Write ``message`` to the TLS buffer and send all of its output."""
        amt = self._buffer.write(message)
        encrypted = self._buffer.peek_outgoing(amt)
        self._buffer.consume_outgoing(amt)
        self._socket.sendall(encrypted, flags)

    def sendto(self, message: bytes, *args: Any) -> int:
        """Like ``send`` but with ``socket.sendto`` arguments.

        Accepts ``sendto(message, address)`` or
        ``sendto(message, flags, address)``.
        """
        if not 1 <= len(args) <= 2:
            raise TypeError(
                "sendto() takes 2 or 3 arguments (%i given)" % (1 + len(args))
            )

        amt = self._buffer.write(message)
        encrypted = self._buffer.peek_outgoing(amt)
        sent = self._socket.sendto(encrypted, *args)
        # Consume only what the socket accepted (see ``send``).
        self._buffer.consume_outgoing(sent)
        return len(message)

    def shutdown(self, how: int) -> None:
        """Shut down the TLS buffer, try to send its output, then shut down the socket."""
        self._buffer.shutdown()
        # Alerts are much smaller but it doesn't matter.
        close_notify = self._buffer.peek_outgoing(4096)
        with suppress(OSError):
            # Do not raise if the socket is already closed.
            amt = self._socket.send(close_notify)
            self._buffer.consume_outgoing(amt)
        self._socket.shutdown(how)

    # PEP 543 adds the following methods.
    @overload
    def do_handshake(self) -> None: ...

    @overload
    def do_handshake(self, address: _Address) -> None: ...

    @overload
    def do_handshake(self, flags: int, address: _Address) -> None: ...

    def do_handshake(self, *args):  # type: ignore[no-untyped-def]
        """Run the handshake until the buffer reports ``HANDSHAKE_OVER``.

        Call forms: ``do_handshake()``, ``do_handshake(address)`` and
        ``do_handshake(flags, address)``.  The forms taking an address are
        only valid on ``SOCK_DGRAM`` sockets.

        Raises:
            OSError: ``ENOTCONN`` if an address is given on a non-datagram
                socket, or if a datagram arrives from another address.
            ConnectionResetError: If a ``SOCK_STREAM`` peer closes the
                connection before the handshake completes.
            TypeError: On a wrong number of arguments or non-int ``flags``.

        """
        # pylint: disable=too-many-branches
        if args and self.type != _pysocket.SOCK_DGRAM:
            raise OSError(errno.ENOTCONN, os.strerror(errno.ENOTCONN))

        if len(args) == 0:
            flags, address = 0, None
        elif len(args) == 1:
            flags, address = 0, args[0]
        elif len(args) == 2:
            flags, address = args
            # Explicit check: an ``assert`` would be stripped under ``-O``.
            if not isinstance(flags, int):
                raise TypeError("do_handshake() flags must be an int")
        else:
            raise TypeError("do_handshake() takes 0, 1, or 2 arguments")

        while self._handshake_state is not HandshakeStep.HANDSHAKE_OVER:
            try:
                self._buffer.do_handshake()
            except WantReadError as exc:
                if address is None:
                    data = self._socket.recv(
                        TLSWrappedSocket.CHUNK_SIZE, flags
                    )
                    # On a stream socket, b"" means EOF.  Nothing more will
                    # arrive, so retrying would spin forever.  (An empty
                    # datagram is legitimate and is not treated as EOF.)
                    if not data and self.type == _pysocket.SOCK_STREAM:
                        raise ConnectionResetError(
                            errno.ECONNRESET, os.strerror(errno.ECONNRESET)
                        ) from exc
                else:
                    data, addr = self._socket.recvfrom(
                        TLSWrappedSocket.CHUNK_SIZE, flags
                    )
                    if addr != address:
                        # The error may not be the clearest but we'd better
                        # bail out in any case.
                        raise OSError(
                            errno.ENOTCONN, os.strerror(errno.ENOTCONN)
                        ) from exc
                self._buffer.receive_from_network(data)
            except WantWriteError:
                in_transit = self._buffer.peek_outgoing(
                    TLSWrappedSocket.CHUNK_SIZE
                )
                if address is None:
                    amt = self._socket.send(in_transit, flags)
                else:
                    amt = self._socket.sendto(in_transit, flags, address)
                self._buffer.consume_outgoing(amt)

    def unwrap(self) -> _pysocket.socket:
        """Shut down the TLS buffer and return the wrapped socket."""
        self._buffer.shutdown()
        return self._socket
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc